Add Whisper SOT recipe for Librimix #5371

Merged 35 commits on Sep 28, 2023
Commits
54c2f9e
update whisper tokenizer
LiChenda Jul 25, 2023
48fc9f4
whisper sot decode ok
LiChenda Jul 25, 2023
e30e2df
Merge branch 'espnet:master' into hackthon23
LiChenda Jul 25, 2023
2ba9507
fix a data prepare issue
LiChenda Jul 25, 2023
08392ea
add config
LiChenda Jul 25, 2023
50364f8
Merge remote-tracking branch 'chenda/hackthon23'
LiChenda Jul 25, 2023
6761137
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2023
728634c
update config
LiChenda Jul 25, 2023
bfa5a6c
Merge remote-tracking branch 'chenda/hackthon23'
LiChenda Jul 25, 2023
e53e8f4
add readme
LiChenda Jul 27, 2023
2987dca
update whisper model
LiChenda Jul 27, 2023
c41bfb4
update for the ci
LiChenda Jul 27, 2023
dbfad9c
Merge branch 'master' into hackthon23
LiChenda Jul 27, 2023
d5bd029
Update espnet2/bin/asr_inference.py
LiChenda Jul 27, 2023
e775128
Update asr_inference.py
LiChenda Jul 27, 2023
dc10c37
espnet2/bin/whisper_export_vocabulary.py
LiChenda Jul 27, 2023
7b461a7
Merge remote-tracking branch 'chenda/hackthon23'
LiChenda Jul 27, 2023
c74eb9c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 27, 2023
d7a46cf
Merge branch 'espnet:master' into hackthon23
LiChenda Aug 2, 2023
50b9dba
Update for test
LiChenda Aug 2, 2023
6b2d1df
update for conflicts
LiChenda Aug 9, 2023
7fdefe9
update for testing
LiChenda Aug 9, 2023
77657df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2023
ebe3ac7
fix parameter issue
LiChenda Aug 9, 2023
a6c897e
Merge remote-tracking branch 'origin/hackthon23' into hackthon23
LiChenda Aug 9, 2023
d47b387
update for making speaker-change token a variable
LiChenda Aug 16, 2023
7dc5134
Merge remote-tracking branch 'upstream/master' into hackthon23
LiChenda Aug 16, 2023
8975b9c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2023
0768756
update decoding
LiChenda Aug 16, 2023
7848af6
update for testing
LiChenda Aug 16, 2023
736d132
Merge remote-tracking branch 'origin/hackthon23' into hackthon23
LiChenda Aug 16, 2023
71bddcc
Merge branch 'master' into hackthon23
LiChenda Aug 31, 2023
e54bbd4
Merge branch 'master' into hackthon23
LiChenda Sep 24, 2023
1bb8a96
update timestamp tokens to Whisper exported tokens
LiChenda Sep 27, 2023
18903d4
add comments
LiChenda Sep 27, 2023
6 changes: 2 additions & 4 deletions egs2/TEMPLATE/asr1/asr.sh
@@ -947,10 +947,6 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ] && ! [[ " ${skip_stages} " =~ [
elif grep -q "whisper" <<< ${token_type}; then
log "Stage 5: Generate whisper token_list from ${token_type} tokenizer"

-if ${sot_asr}; then
-log "Error: not supported SOT training for whisper token_list"
-exit 2
-fi

_opts=""
if [ "${token_type}" = "whisper_multilingual" ]; then
@@ -962,7 +958,9 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ] && ! [[ " ${skip_stages} " =~ [
echo ${token_list}
${python} -m espnet2.bin.whisper_export_vocabulary \
--whisper_model "${token_type}" \
+--sot_asr "${sot_asr}" \
--output "${token_list}" ${_opts}

elif [ "${token_type}" = hugging_face ]; then
log "Stage 5: Generate hugging_face token_list from ${hugging_face_model_name_or_path}"

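The change above removes the old hard error and instead forwards a `--sot_asr` flag to `espnet2.bin.whisper_export_vocabulary`, so a Whisper token list can also be generated for SOT training. The sketch below shows only the intended effect, appending the speaker-change symbol to an already exported token list; the file path is a placeholder and the real exporter may order or add tokens differently (e.g. timestamp tokens).

```python
# Sketch only: make sure the SOT speaker-change symbol appears in the exported
# Whisper token list, so it receives an ID beyond the original Whisper vocabulary.
# The path and the "<sc>" symbol are illustrative; the real logic lives in
# espnet2/bin/whisper_export_vocabulary.py.
from pathlib import Path


def append_speaker_change(token_list_path: str, sc_symbol: str = "<sc>") -> None:
    path = Path(token_list_path)
    tokens = path.read_text(encoding="utf-8").splitlines()
    if sc_symbol not in tokens:
        tokens.append(sc_symbol)
        path.write_text("\n".join(tokens) + "\n", encoding="utf-8")


append_speaker_change("data/en_token_list/whisper_multilingual/tokens.txt")
```

Because the new symbol lands after the pretrained vocabulary, the decoder can keep the original embedding rows and only learn the appended ones (see `ExpandedTokenEmbedding` further down in this PR).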
32 changes: 32 additions & 0 deletions egs2/librimix/sot_asr1/README.md
@@ -80,3 +80,35 @@ Following are the details about this recipe.
|---|---|---|---|---|---|---|---|---|
|decode_sot_asr_model_valid.acc.ave/dev|3000|670222|90.1|6.3|3.6|3.5|13.4|99.3|
|decode_sot_asr_model_valid.acc.ave/test|3000|605408|90.7|5.7|3.6|3.3|12.6|98.7|


## Whisper SOT results

### exp/asr_train_sot_asr_whisper_small_raw_en_whisper_multilingual

## Environments
- date: `Thu Jul 27 19:07:11 CST 2023`
- python version: `3.9.16 (main, Mar 8 2023, 14:00:05) [GCC 11.2.0]`
- espnet version: `espnet 202304`
- pytorch version: `pytorch 2.0.1+cu118`
- Git hash: `bfa5a6ca2d1697c88443b2eaecdfdec72524e05d`
- Commit date: `Tue Jul 25 18:31:13 2023 +0800`

- ASR config: [conf/tuning/train_sot_asr_whisper_small.yaml](conf/tuning/train_sot_asr_whisper_small.yaml)
- Decode config: [conf/tuning/decode_sot.yaml](conf/tuning/decode_sot.yaml)
- Pretrained model: https://huggingface.co/espnet/chendali_librimix_asr_train_sot_asr_whisper_small_raw_en_whisper_multilingual


#### WER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|org/dev|3000|126853|76.7|18.5|4.7|2.7|26.0|100.0|
|decode_sot_asr_model_valid.acc.ave/test|3000|114243|77.8|17.1|5.1|2.8|25.0|100.0|

#### CER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|org/dev|3000|673222|87.3|6.8|5.9|3.6|16.4|100.0|
|decode_sot_asr_model_valid.acc.ave/test|3000|608408|87.7|6.2|6.1|3.4|15.7|100.0|
61 changes: 61 additions & 0 deletions egs2/librimix/sot_asr1/conf/tuning/train_sot_asr_whisper.yaml
@@ -0,0 +1,61 @@
normalize: null

freeze_param: [
    "decoder.decoders.token_embedding.ori_emb"
]

encoder: whisper
encoder_conf:
    whisper_model: small
    dropout_rate: 0.0
    use_specaug: false

decoder: whisper
decoder_conf:
    whisper_model: small
    dropout_rate: 0.0
    load_origin_token_embedding: true

model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    # do_pad_trim: true # should be set when doing zero-shot inference


frontend: null
input_size: 1 # to prevent build_model() from complaining


# preprocessor related
preprocessor: multi
preprocessor_conf:
    speaker_change_symbol:
        - "<sc>"

# minibatch related
use_amp: true
num_workers: 2
batch_type: numel
batch_bins: 8000000
accum_grad: 4
max_epoch: 13
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 5

optim: adam
optim_conf:
    lr: 0.0005
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 20000

specaug: null
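Note that `sym_sos` and `sym_eos` in the config above reuse tokens that already exist in Whisper's vocabulary, whereas the `<sc>` speaker-change symbol does not, which is why the token list and the decoder embedding are expanded elsewhere in this PR. A quick sanity check with the `openai-whisper` package (illustrative sketch; attribute names follow `whisper/tokenizer.py` and may differ between package versions):

```python
# Sketch: the config's sym_sos/sym_eos map to existing Whisper special tokens,
# while "<sc>" is unknown to the pretrained tokenizer and only splits into
# ordinary sub-tokens, so the recipe appends it to the exported token list.
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True)

print(tokenizer.sot)             # ID of "<|startoftranscript|>" (sym_sos above)
print(tokenizer.eot)             # ID of "<|endoftext|>" (sym_eos above)
print(tokenizer.encode("<sc>"))  # several ordinary sub-token IDs, not one token
```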
@@ -0,0 +1,61 @@
normalize: null

freeze_param: [
    "decoder.decoders.token_embedding.ori_emb"
]

encoder: whisper
encoder_conf:
    whisper_model: medium
    dropout_rate: 0.0
    use_specaug: false

decoder: whisper
decoder_conf:
    whisper_model: medium
    dropout_rate: 0.0
    load_origin_token_embedding: true

model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    # do_pad_trim: true # should be set when doing zero-shot inference


frontend: null
input_size: 1 # to prevent build_model() from complaining


# preprocessor related
preprocessor: multi
preprocessor_conf:
    speaker_change_symbol:
        - "<sc>"

# minibatch related
use_amp: true
num_workers: 2
batch_type: numel
batch_bins: 8000000
accum_grad: 4
max_epoch: 3
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 1

optim: adam
optim_conf:
    lr: 0.000001
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 20000

specaug: null
@@ -0,0 +1,61 @@
normalize: null

freeze_param: [
    "decoder.decoders.token_embedding.ori_emb"
]

encoder: whisper
encoder_conf:
    whisper_model: small
    dropout_rate: 0.0
    use_specaug: false

decoder: whisper
decoder_conf:
    whisper_model: small
    dropout_rate: 0.0
    load_origin_token_embedding: true

model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    # do_pad_trim: true # should be set when doing zero-shot inference


frontend: null
input_size: 1 # to prevent build_model() from complaining


# preprocessor related
preprocessor: multi
preprocessor_conf:
    speaker_change_symbol:
        - "<sc>"

# minibatch related
use_amp: true
num_workers: 2
batch_type: numel
batch_bins: 2000000
accum_grad: 4
max_epoch: 20
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 1

optim: adam
optim_conf:
    lr: 0.000001
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 20000

specaug: null
2 changes: 1 addition & 1 deletion egs2/librimix/sot_asr1/local/data.sh
@@ -188,7 +188,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
done

paste -d "" \
-<(<data/${dset}/text_spk1 awk '{$0=$0" <sc>"; print($0)}') \
+<(<data/${dset}/text_spk1 awk '{$0=$0" <sc> "; print($0)}') \
Collaborator:
Xuankai and I have considered whether a space should be added after "<sc>". We don't know if it matters, or how the original paper does it. Do you have any comments?

Contributor Author:
I think there should be no essential difference, but the second form makes the text look more natural. My uploaded pre-trained model was trained with the space after "<sc>". Do you have any technical concerns about that space?

<(<data/${dset}/text_spk2 cut -d" " -f 2-) > data/${dset}/text

done
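For reference, the paste/awk pipeline above turns the two per-speaker transcripts into one SOT-style target per utterance, with `<sc>` marking the speaker change (and, after this change, a space on both sides of it). A rough Python equivalent, using placeholder file names in the Kaldi-style `<utt-id> <text>` format:

```python
# Sketch of what local/data.sh produces as SOT training targets:
#   "<utt-id> <speaker-1 words> <sc> <speaker-2 words>"
# The paths below are placeholders for data/<dset>/text_spk{1,2} and text.

def merge_sot_text(spk1_path: str, spk2_path: str, out_path: str, sc: str = "<sc>") -> None:
    with open(spk1_path, encoding="utf-8") as f1, open(
        spk2_path, encoding="utf-8"
    ) as f2, open(out_path, "w", encoding="utf-8") as out:
        for line1, line2 in zip(f1, f2):
            utt_id, text1 = line1.rstrip("\n").split(maxsplit=1)
            _, text2 = line2.rstrip("\n").split(maxsplit=1)
            # e.g. "mix_utt1 HELLO THERE <sc> HOW ARE YOU"
            out.write(f"{utt_id} {text1} {sc} {text2}\n")


merge_sot_text("data/dev/text_spk1", "data/dev/text_spk2", "data/dev/text")
```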
34 changes: 34 additions & 0 deletions egs2/librimix/sot_asr1/run_whisper_sot.sh
@@ -0,0 +1,34 @@
#!/usr/bin/env bash
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

train_set="train"
valid_set="dev"
test_sets="dev test"

asr_config=conf/tuning/train_sot_asr_whisper.yaml

lm_config=conf/tuning/train_lm_transformer.yaml
inference_config=conf/tuning/decode_sot.yaml

./asr.sh \
--lang en \
--audio_format "flac.ark" \
--feats_type raw \
--token_type whisper_multilingual \
--sot_asr true \
--max_wav_duration 30 \
--feats_normalize utterance_mvn \
--use_lm false \
--asr_config "${asr_config}" \
--lm_config "${lm_config}" \
--inference_config "${inference_config}" \
--train_set "${train_set}" \
--valid_set "${valid_set}" \
--test_sets "${test_sets}" \
--lm_train_text "data/${train_set}/text_spk1 data/${train_set}/text_spk2 data/local/other_text/text" \
Collaborator:
Can we simply use ${train_set}/text, which already contains the <sc> tokens?

Contributor Author:
LM is not used in this recipe. I'll remove it.

--bpe_train_text "data/${train_set}/text_spk1 data/${train_set}/text_spk2" "$@"
Collaborator:
Ditto (not sure about the LM, but for BPE it should be fine to use the text file).

# --speed_perturb_factors "0.9 1.0 1.1" \
69 changes: 57 additions & 12 deletions espnet2/asr/decoder/whisper_decoder.py
@@ -8,6 +8,36 @@
from espnet.nets.scorer_interface import BatchScorerInterface


class ExpandedTokenEmbedding(torch.nn.Module):
    def __init__(self, ori_emebedding, additional_size):
        super().__init__()
        self.ori_emb = ori_emebedding

        orig_emb_std, orig_emb_mean = torch.std_mean(ori_emebedding.weight)
        self.add_emb = torch.nn.Embedding(additional_size, ori_emebedding.embedding_dim)
        torch.nn.init.normal_(
            self.add_emb.weight,
            orig_emb_mean.item(),
            orig_emb_std.item(),
        )
        self.num_embeddings = ori_emebedding.num_embeddings + additional_size

    @property
    def weight(self):
        return torch.cat([self.ori_emb.weight, self.add_emb.weight], dim=0)

    def forward(self, input):
        return torch.nn.functional.embedding(
            input,
            self.weight,
            self.ori_emb.padding_idx,
            self.ori_emb.max_norm,
            self.ori_emb.norm_type,
            self.ori_emb.scale_grad_by_freq,
            self.ori_emb.sparse,
        )
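A small standalone sketch of how `ExpandedTokenEmbedding` is meant to be used (toy sizes, not the actual Whisper vocabulary): the pretrained rows stay in `ori_emb` and the appended rows, e.g. for `<sc>`, live in `add_emb`. This is why the training configs can freeze `decoder.decoders.token_embedding.ori_emb` while still learning an embedding for the new token.

```python
# Illustrative only: a tiny "pretrained" embedding of 5 tokens expanded by 1
# extra token (think of it as <sc>). Sizes are toy values.
import torch

from espnet2.asr.decoder.whisper_decoder import ExpandedTokenEmbedding

pretrained = torch.nn.Embedding(5, 8)  # stands in for Whisper's token_embedding
expanded = ExpandedTokenEmbedding(pretrained, 1)

# Freeze the original rows, mirroring freeze_param in the training configs;
# only the appended row receives gradient updates through the concatenated weight.
for p in expanded.ori_emb.parameters():
    p.requires_grad = False

ids = torch.tensor([[0, 4, 5]])  # 5 is the first ID beyond the original table
print(expanded(ids).shape)       # torch.Size([1, 3, 8])
print(expanded.weight.shape)     # torch.Size([6, 8])
```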


class OpenAIWhisperDecoder(AbsDecoder, BatchScorerInterface):
"""Transformer-based Speech-to-Text Decoder from OpenAI's Whisper Model:

@@ -21,6 +51,7 @@ def __init__(
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: str = None,
+        load_origin_token_embedding=False,
    ):
        try:
            import whisper
@@ -36,28 +67,42 @@ def __init__(
        super().__init__()

        assert whisper_model in whisper.available_models()
-        _model = whisper.load_model(whisper_model, download_root=download_dir)
+        _model = whisper.load_model(
+            whisper_model, download_root=download_dir, device="cpu"
+        )
        self.decoders = copy.deepcopy(_model.decoder)
        attention_dim = self.decoders.token_embedding.embedding_dim

        # note that originally Whisper doesn't use dropouts
        self.dropout = torch.nn.Dropout(dropout_rate)

+        # load the original token_embeddings, if the vocabulary is expanded
+        self.load_origin_token_embedding = load_origin_token_embedding

        # vocab size mismatch -> reinitialize embedding
        # orig vocab size (multilingual): 51865
        # orig vocab size (english): 51864
        if vocab_size != self.decoders.token_embedding.num_embeddings:
-            orig_emb_std, orig_emb_mean = torch.std_mean(
-                self.decoders.token_embedding.weight
-            )
-            self.decoders.token_embedding = torch.nn.Embedding(
-                vocab_size, attention_dim
-            )
-            torch.nn.init.normal_(
-                self.decoders.token_embedding.weight,
-                orig_emb_mean.item(),
-                orig_emb_std.item(),
-            )
+            if self.load_origin_token_embedding:
+                assert (
+                    vocab_size > self.decoders.token_embedding.num_embeddings
+                ), "expanded vocab_size should be larged than the origin"
+                self.decoders.token_embedding = ExpandedTokenEmbedding(
+                    self.decoders.token_embedding,
+                    vocab_size - self.decoders.token_embedding.num_embeddings,
+                )
+            else:
+                orig_emb_std, orig_emb_mean = torch.std_mean(
+                    self.decoders.token_embedding.weight
+                )
+                self.decoders.token_embedding = torch.nn.Embedding(
+                    vocab_size, attention_dim
+                )
+                torch.nn.init.normal_(
+                    self.decoders.token_embedding.weight,
+                    orig_emb_mean.item(),
+                    orig_emb_std.item(),
+                )

        self.decoders.train()
        del _model
4 changes: 3 additions & 1 deletion espnet2/asr/encoder/whisper_encoder.py
@@ -50,7 +50,9 @@ def __init__(
        self.dropout = torch.nn.Dropout(dropout_rate)

        assert whisper_model in whisper.available_models()
-        _model = whisper.load_model(whisper_model, download_root=download_dir)
+        _model = whisper.load_model(
+            whisper_model, download_root=download_dir, device="cpu"
+        )
        self.encoders = copy.deepcopy(_model.encoder)
        self.encoders.train()
