Merge pull request #4793 from slSeanWU/whisper-full
Add Full Whisper Model for Finetuning
sw005320 committed Jan 17, 2023
2 parents 8e08813 + a08c367 commit 99cea47
Showing 34 changed files with 1,537 additions and 17 deletions.
2 changes: 1 addition & 1 deletion ci/install.sh
@@ -20,7 +20,7 @@ ${CXX:-g++} -v
fi

. ./activate_python.sh
make TH_VERSION="${TH_VERSION}" WITH_OMP="${WITH_OMP-ON}" all warp-transducer.done chainer_ctc.done nkf.done moses.done mwerSegmenter.done pesq pyopenjtalk.done py3mmseg.done s3prl.done transformers.done phonemizer.done fairseq.done k2.done gtn.done longformer.done
make TH_VERSION="${TH_VERSION}" WITH_OMP="${WITH_OMP-ON}" all warp-transducer.done chainer_ctc.done nkf.done moses.done mwerSegmenter.done pesq pyopenjtalk.done py3mmseg.done s3prl.done transformers.done phonemizer.done fairseq.done k2.done gtn.done longformer.done whisper.done
rm -rf kaldi
)
. tools/activate_python.sh
19 changes: 18 additions & 1 deletion egs2/TEMPLATE/asr1/asr.sh
@@ -144,6 +144,7 @@ lm_dev_text= # Text file path of language model development set.
lm_test_text= # Text file path of language model evaluation set.
nlsyms_txt=none # Non-linguistic symbol list if existing.
cleaner=none # Text cleaner.
hyp_cleaner=none # Text cleaner for hypotheses (may be used with external tokenizers)
g2p=none # g2p method (needed if token_type=phn).
lang=noinfo # The language type of corpus.
score_opts= # The options given to sclite scoring
@@ -349,6 +350,14 @@ elif [ "${token_type}" = char ]; then
elif [ "${token_type}" = word ]; then
token_list="${wordtoken_list}"
bpemodel=none
elif [ "${token_type}" = whisper_en ]; then # should make token_list an output filepath here
token_list="${token_listdir}"/whisper_en/tokens.txt
bpemodel=whisper_en
hyp_cleaner=${cleaner}
elif [ "${token_type}" = whisper_multilingual ]; then
token_list="${token_listdir}"/whisper_multilingual/tokens.txt
bpemodel=whisper_multilingual
hyp_cleaner=${cleaner}
elif [ "${token_type}" = hugging_face ]; then
token_list="${hugging_face_token_list}"
bpemodel=${hugging_face_model_name_or_path}
@@ -762,9 +771,16 @@ if ! "${skip_data_prep}"; then
--add_symbol "${blank}:0" \
--add_symbol "${oov}:1" \
--add_symbol "${sos_eos}:-1"
elif grep -q "whisper" <<< ${token_type}; then
log "Stage 5: Generate whisper token_list from ${token_type} tokenizer"
# The first symbol in token_list must be "<blank>" and the last must be also sos/eos:
# 0 is reserved for CTC-blank for ASR and also used as ignore-index in the other task
echo ${token_list}
${python} -m espnet2.bin.whisper_export_vocabulary \
--whisper_model "${token_type}" \
--output "${token_list}"
elif [ "${token_type}" = hugging_face ]; then
log "Stage 5: Generate hugging_face token_list from ${hugging_face_model_name_or_path}"

# The first symbol in token_list must be "<blank>" and the last must be also sos/eos:
# 0 is reserved for CTC-blank for ASR and also used as ignore-index in the other task
${python} -m espnet2.bin.hugging_face_export_vocabulary \
@@ -1446,6 +1462,7 @@ if ! "${skip_eval}"; then
                        ${python} -m espnet2.bin.tokenize_text \
                            -f 2- --input - --output - \
                            ${_opts} \
                            --cleaner "${hyp_cleaner}" \
                            ) \
                    <(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
                        >"${_scoredir}/hyp${suffix:-${suffix}}.trn"
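The new stage-5 branch above delegates vocabulary generation to `espnet2.bin.whisper_export_vocabulary`, which writes the Whisper token list that the recipe then uses in place of a BPE or character list. The sketch below is only a conceptual illustration of what such an export amounts to, not the PR's implementation; it assumes the `openai-whisper` package is installed and that its `get_tokenizer()` wraps a Hugging Face tokenizer (true for the release current at the time of this commit).

```python
# Illustrative sketch only; the authoritative logic lives in
# espnet2/bin/whisper_export_vocabulary.py, which asr.sh stage 5 invokes.
from whisper.tokenizer import get_tokenizer  # assumes openai-whisper is installed


def export_vocabulary(token_type: str, output: str) -> None:
    """Write one token per line so that the line index serves as the token id."""
    tokenizer = get_tokenizer(multilingual=(token_type == "whisper_multilingual"))
    hf_tokenizer = tokenizer.tokenizer  # underlying Hugging Face tokenizer (assumption)
    with open(output, "w", encoding="utf-8") as f:
        for idx in range(len(hf_tokenizer)):
            f.write(hf_tokenizer.convert_ids_to_tokens(idx) + "\n")


export_vocabulary("whisper_multilingual", "tokens.txt")
```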
41 changes: 41 additions & 0 deletions egs2/chime4/asr1/README.md
@@ -149,3 +149,44 @@
|decode_et05_simu_beamformit_5micsdecode_rnn_lm_valid.loss.best_asr_model_valid.loss.best|1320|126812|83.8|8.8|7.4|3.8|20.0|93.8|
|decode_et05_simu_isolated_1ch_trackdecode_rnn_lm_valid.loss.best_asr_model_valid.loss.best|1320|126812|79.9|11.0|9.0|4.9|25.0|94.5|



<!-- Generated by scripts/utils/show_asr_result.sh -->
# RESULTS
## Environments
- date: `Tue Jan 10 04:15:30 CST 2023`
- python version: `3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0]`
- espnet version: `espnet 202211`
- pytorch version: `pytorch 1.12.1`
- Git hash: `d89be931dcc8f61437ac49cbe39a773f2054c50c`
- Commit date: `Mon Jan 9 11:06:45 2023 -0600`

## asr_whisper_medium_lr1e-5_adamw_wd1e-2_3epochs

- Huggingface model URL: https://huggingface.co/espnet/shihlun_asr_whisper_medium_finetuned_chime4

### WER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_whisper_noctc_beam20_asr_model_valid.acc.ave/dt05_real_isolated_1ch_track|1640|24791|97.8|1.7|0.5|0.3|2.5|24.5|
|decode_asr_whisper_noctc_beam20_asr_model_valid.acc.ave/dt05_simu_isolated_1ch_track|1640|24792|96.1|3.0|0.9|0.5|4.4|35.6|
|decode_asr_whisper_noctc_beam20_asr_model_valid.acc.ave/et05_real_isolated_1ch_track|1320|19341|96.4|2.9|0.7|0.5|4.1|33.0|
|decode_asr_whisper_noctc_beam20_asr_model_valid.acc.ave/et05_simu_isolated_1ch_track|1320|19344|93.4|5.0|1.7|0.8|7.4|41.8|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/dt05_real_isolated_1ch_track|1640|24791|97.7|1.8|0.5|0.4|2.8|25.5|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/dt05_simu_isolated_1ch_track|1640|24792|96.0|3.3|0.8|0.7|4.8|36.0|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/et05_real_isolated_1ch_track|1320|19341|96.1|3.3|0.6|0.7|4.6|34.9|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/et05_simu_isolated_1ch_track|1320|19344|92.9|5.8|1.3|1.2|8.3|43.2|

### CER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_whisper_noctc_beam20_asr_model_valid.acc.ave/dt05_real_isolated_1ch_track|1640|141889|99.1|0.3|0.5|0.3|1.2|24.5|
|decode_asr_whisper_noctc_beam20_asr_model_valid.acc.ave/dt05_simu_isolated_1ch_track|1640|141900|98.2|0.8|1.0|0.5|2.3|35.6|
|decode_asr_whisper_noctc_beam20_asr_model_valid.acc.ave/et05_real_isolated_1ch_track|1320|110558|98.5|0.7|0.8|0.5|1.9|33.0|
|decode_asr_whisper_noctc_beam20_asr_model_valid.acc.ave/et05_simu_isolated_1ch_track|1320|110572|96.5|1.6|1.9|0.8|4.3|41.8|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/dt05_real_isolated_1ch_track|1640|141889|99.1|0.4|0.5|0.5|1.3|25.5|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/dt05_simu_isolated_1ch_track|1640|141900|98.2|0.9|0.9|0.6|2.4|36.0|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/et05_real_isolated_1ch_track|1320|110558|98.4|0.9|0.7|0.6|2.2|34.9|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/et05_simu_isolated_1ch_track|1320|110572|96.3|2.0|1.7|1.2|4.9|43.2|
6 changes: 6 additions & 0 deletions egs2/chime4/asr1/conf/decode_asr_whisper_noctc_beam20.yaml
@@ -0,0 +1,6 @@
beam_size: 20
ctc_weight: 0.0
lm_weight: 0.0
maxlenratio: 0.25
minlenratio: 0.0
penalty: 0.0
6 changes: 6 additions & 0 deletions egs2/chime4/asr1/conf/decode_asr_whisper_noctc_greedy.yaml
@@ -0,0 +1,6 @@
beam_size: 1
ctc_weight: 0.0
lm_weight: 0.0
maxlenratio: 0.3
minlenratio: 0.0
penalty: 0.0
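The two decode configs above (beam-20 and greedy) differ only in `beam_size` and `maxlenratio`; both set `ctc_weight` and `lm_weight` to 0.0, so decoding is driven purely by the fine-tuned Whisper decoder. As a rough illustration, the snippet below shows how these options could map onto `espnet2.bin.asr_inference.Speech2Text` with the released CHiME-4 model; it assumes `espnet_model_zoo` and `soundfile` are installed and that `from_pretrained` forwards these keyword overrides, and the WAV filename is hypothetical. The recipe itself drives decoding through `asr.sh` rather than this API.

```python
# Hedged sketch: apply the decode options from the YAML above outside the recipe.
import soundfile
from espnet2.bin.asr_inference import Speech2Text

speech2text = Speech2Text.from_pretrained(
    "espnet/shihlun_asr_whisper_medium_finetuned_chime4",
    beam_size=20,      # 1 for the greedy config
    ctc_weight=0.0,    # no CTC branch, attention-only decoding
    lm_weight=0.0,     # no external language model
    maxlenratio=0.25,  # 0.3 in the greedy config
    minlenratio=0.0,
    penalty=0.0,
)

speech, rate = soundfile.read("sample_16k.wav")  # hypothetical 16 kHz mono file
text, tokens, token_ids, hyp = speech2text(speech)[0]
print(text)
```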
74 changes: 74 additions & 0 deletions egs2/chime4/asr1/conf/tuning/train_asr_whisper_full.yaml
@@ -0,0 +1,74 @@
normalize: null

encoder: whisper
encoder_conf:
    whisper_model: medium
    dropout_rate: 0.0
    use_specaug: true
    specaug_conf:
        apply_time_warp: true
        time_warp_window: 5
        time_warp_mode: bicubic
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 40
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0.
        - 0.12
        num_time_mask: 5


decoder: whisper
decoder_conf:
    whisper_model: medium
    dropout_rate: 0.0

model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    extract_feats_in_collect_stats: false


frontend: null
input_size: 1 # to prevent build_model() from complaining

seed: 2022
log_interval: 100
num_att_plot: 0
num_workers: 4
sort_in_batch: descending  # how to sort data in making batch
sort_batch: descending     # how to sort created batches
batch_type: numel
batch_bins: 10000000  # good for single GPU w/ 40G mem
accum_grad: 4
max_epoch: 3
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 3

use_amp: true
cudnn_deterministic: false
cudnn_benchmark: false

optim: adamw
grad_clip: 1.0
optim_conf:
    lr: 1.0e-05
    weight_decay: 0.01
    betas:
    - 0.9
    - 0.99
    eps: 1.0e-06
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 1500
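The optimizer block above pairs AdamW (lr 1e-5, weight decay 0.01) with ESPnet's `warmuplr` scheduler over 1500 steps. For reference, here is a minimal plain-PyTorch sketch of that schedule; it mirrors the Noam-style rule used by `espnet2.schedulers.warmup_lr.WarmupLR` (linear warmup followed by inverse-square-root decay), with a small stand-in module in place of the full Whisper model. The ESPnet trainer handles this wiring automatically, so this is purely illustrative.

```python
# Illustrative sketch of the adamw + warmuplr settings in plain PyTorch.
import torch

model = torch.nn.Linear(1024, 1024)  # stand-in for the full Whisper model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1.0e-5,
    weight_decay=0.01,
    betas=(0.9, 0.99),
    eps=1.0e-6,
)

warmup_steps = 1500

def warmup_lr(step: int) -> float:
    # Scale factor relative to the base lr, as in espnet2's WarmupLR.
    step = max(step, 1)
    return warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lr)
```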
32 changes: 32 additions & 0 deletions egs2/librispeech_100/asr1/README.md
@@ -197,3 +197,35 @@ Model: https://huggingface.co/pyf98/librispeech_100h_transformer
|beam20_ctc0.3/dev_other|2864|64524|78.5|15.3|6.2|2.8|24.3|83.8|
|beam20_ctc0.3/test_clean|2620|66983|90.0|6.2|3.9|0.8|10.9|63.3|
|beam20_ctc0.3/test_other|2939|66650|77.9|15.2|6.9|2.5|24.6|84.8|


# RESULTS
## Environments
- date: `Mon Jan 9 23:06:34 CST 2023`
- python version: `3.9.13 (main, Aug 25 2022, 23:26:10) [GCC 11.2.0]`
- espnet version: `espnet 202211`
- pytorch version: `pytorch 1.12.1`
- Git hash: `d89be931dcc8f61437ac49cbe39a773f2054c50c`
- Commit date: `Mon Jan 9 11:06:45 2023 -0600`

## asr_whisper_medium_finetune_lr1e-5_adamw_wd1e-2_3epochs

- Huggingface model URL: https://huggingface.co/espnet/shihlun_asr_whisper_medium_finetuned_librispeech100

### WER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/dev_clean|2703|54798|97.7|1.9|0.3|0.3|2.6|30.1|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/dev_other|2864|51528|95.3|4.3|0.4|0.6|5.3|45.4|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/test_clean|2620|53027|97.6|2.1|0.3|0.4|2.7|30.9|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/test_other|2939|52882|95.1|4.4|0.5|0.7|5.6|47.5|

### CER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/dev_clean|2703|287287|99.3|0.3|0.4|0.3|1.0|30.1|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/dev_other|2864|265648|98.3|1.0|0.7|0.6|2.3|45.4|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/test_clean|2620|280691|99.3|0.3|0.3|0.3|1.0|30.9|
|decode_asr_whisper_noctc_greedy_asr_model_valid.acc.ave/test_other|2939|271738|98.3|1.0|0.7|0.7|2.4|47.5|
@@ -0,0 +1,6 @@
beam_size: 1
ctc_weight: 0.0
lm_weight: 0.0
maxlenratio: 0.3
minlenratio: 0.0
penalty: 0.0
73 changes: 73 additions & 0 deletions egs2/librispeech_100/asr1/conf/tuning/train_asr_whisper_full.yaml
@@ -0,0 +1,73 @@
normalize: null

encoder: whisper
encoder_conf:
    whisper_model: medium
    dropout_rate: 0.0
    use_specaug: true
    specaug_conf:
        apply_time_warp: true
        time_warp_window: 5
        time_warp_mode: bicubic
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 40
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0.
        - 0.12
        num_time_mask: 5


decoder: whisper
decoder_conf:
    whisper_model: medium
    dropout_rate: 0.0

model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    # do_pad_trim: true # should be set when doing zero-shot inference

frontend: null
input_size: 1 # to prevent build_model() from complaining

seed: 2022
log_interval: 100
num_att_plot: 0
num_workers: 4
sort_in_batch: descending  # how to sort data in making batch
sort_batch: descending     # how to sort created batches
batch_type: numel
batch_bins: 12000000  # good for single GPU w/ 40G mem
accum_grad: 4
max_epoch: 3
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 3

use_amp: true
cudnn_deterministic: false
cudnn_benchmark: false

optim: adamw
grad_clip: 1.0
optim_conf:
    lr: 1.0e-05
    weight_decay: 0.01
    betas:
    - 0.9
    - 0.99
    eps: 1.0e-06
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 1500
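One detail worth noting in this config is the commented-out `do_pad_trim` flag in `model_conf`: Whisper's encoder consumes fixed 30-second windows, so for zero-shot inference the waveform has to be padded or trimmed to 30 s at 16 kHz, i.e. 480,000 samples, whereas the recipe leaves the flag off for fine-tuning on segmented utterances. A small sketch of that pad-or-trim step using the `openai-whisper` helper follows; the waveform here is random data used only as a hypothetical input.

```python
# Why do_pad_trim matters for zero-shot use: fix every input to 30 s of audio.
import torch
from whisper.audio import N_SAMPLES, SAMPLE_RATE, pad_or_trim

waveform = torch.randn(5 * SAMPLE_RATE)   # hypothetical 5-second utterance
fixed = pad_or_trim(waveform, N_SAMPLES)  # zero-padded (or trimmed) to 480,000 samples
assert fixed.shape[-1] == 30 * SAMPLE_RATE
```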
6 changes: 5 additions & 1 deletion espnet/nets/batch_beam_search.py
@@ -118,13 +118,17 @@ def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
        for k, d in self.scorers.items():
            init_states[k] = d.batch_init_state(x)
            init_scores[k] = 0.0

        # NOTE (Shih-Lun): added for OpenAI Whisper ASR
        primer = [self.sos] if self.hyp_primer is None else self.hyp_primer

        return self.batchfy(
            [
                Hypothesis(
                    score=0.0,
                    scores=init_scores,
                    states=init_states,
                    yseq=torch.tensor([self.sos], device=x.device),
                    yseq=torch.tensor(primer, device=x.device),
                )
            ]
        )
18 changes: 17 additions & 1 deletion espnet/nets/beam_search.py
@@ -41,6 +41,7 @@ def __init__(
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
        hyp_primer: List[int] = None,
    ):
        """Initialize beam search.
@@ -87,6 +88,10 @@ def __init__(
        # set configurations
        self.sos = sos
        self.eos = eos

        # added for OpenAI Whisper decoding
        self.hyp_primer = hyp_primer

        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
@@ -104,6 +109,13 @@ def __init__(
            and len(self.part_scorers) > 0
        )

    def set_hyp_primer(self, hyp_primer: List[int] = None) -> None:
        """Set the primer sequence for decoding.

        Used for OpenAI Whisper models.
        """
        self.hyp_primer = hyp_primer

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.
@@ -119,12 +131,16 @@ def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0

        # NOTE (Shih-Lun): added for OpenAI Whisper ASR
        primer = [self.sos] if self.hyp_primer is None else self.hyp_primer

        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
                yseq=torch.tensor(primer, device=x.device),
            )
        ]

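Together, the `hyp_primer` argument and `set_hyp_primer()` let beam search start every hypothesis from a multi-token prompt such as `<|startoftranscript|><|en|><|transcribe|>` rather than a single `<sos>` id, which is what Whisper's decoder expects. Below is a rough usage sketch, not the PR's inference path: the `BeamSearch` instance is built with no scorers purely to show the new API, the vocabulary size is an assumed value for multilingual Whisper, and the prompt ids come from the `openai-whisper` tokenizer. In the recipes the primer is set up by the ESPnet inference code, not by hand.

```python
# Hedged sketch of priming beam search with a Whisper prompt sequence.
from espnet.nets.beam_search import BeamSearch
from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
primer = list(tokenizer.sot_sequence_including_notimestamps)

# A bare BeamSearch with no scorers, only to exercise the API surface;
# real decoding attaches the fine-tuned Whisper decoder as a scorer.
beam_search = BeamSearch(
    scorers={},
    weights={},
    beam_size=20,
    vocab_size=51865,        # assumed multilingual Whisper vocabulary size
    sos=tokenizer.sot,
    eos=tokenizer.eot,
)
beam_search.set_hyp_primer(primer)
# init_hyp() now seeds yseq with `primer` instead of [self.sos]
```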
7 changes: 6 additions & 1 deletion espnet/nets/e2e_asr_common.py
@@ -120,7 +120,12 @@ def __init__(
        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        self.idx_blank = self.char_list.index(self.blank)
        # NOTE (Shih-Lun): else case is for OpenAI Whisper ASR model,
        #  which doesn't use <blank> token
        if self.blank in self.char_list:
            self.idx_blank = self.char_list.index(self.blank)
        else:
            self.idx_blank = None
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
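The guarded lookup above is needed because a Whisper token list has no `<blank>` entry (the model is trained without CTC, and the configs in this PR set `ctc_weight: 0.0`), whereas conventional ESPnet token lists reserve index 0 for it. A tiny self-contained illustration of the fallback, using made-up vocabularies:

```python
# No "<blank>" in a Whisper-style vocabulary, so idx_blank becomes None
# instead of raising ValueError; a conventional ESPnet list yields 0.
sym_blank = "<blank>"
whisper_like_vocab = ["<|endoftext|>", "<|startoftranscript|>", "hello", " world"]
conventional_vocab = ["<blank>", "<unk>", "hello", "world", "<sos/eos>"]

for char_list in (whisper_like_vocab, conventional_vocab):
    idx_blank = char_list.index(sym_blank) if sym_blank in char_list else None
    print(idx_blank)  # None, then 0
```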
