Support LoRA based large model finetuning. #5400

Merged: 14 commits, Oct 11, 2023
2 changes: 1 addition & 1 deletion ci/install.sh
@@ -21,7 +21,7 @@ ${CXX:-g++} -v

. ./activate_python.sh
# FIXME(kamo): Failed to compile pesq
make TH_VERSION="${TH_VERSION}" WITH_OMP="${WITH_OMP-ON}" all warp-transducer.done chainer_ctc.done nkf.done moses.done mwerSegmenter.done pyopenjtalk.done py3mmseg.done s3prl.done transformers.done phonemizer.done fairseq.done k2.done gtn.done longformer.done whisper.done parallel-wavegan.done muskits.done
make TH_VERSION="${TH_VERSION}" WITH_OMP="${WITH_OMP-ON}" all warp-transducer.done chainer_ctc.done nkf.done moses.done mwerSegmenter.done pyopenjtalk.done py3mmseg.done s3prl.done transformers.done phonemizer.done fairseq.done k2.done gtn.done longformer.done whisper.done parallel-wavegan.done muskits.done lora.done
rm -rf kaldi
)
. tools/activate_python.sh
9 changes: 3 additions & 6 deletions egs2/TEMPLATE/asr1/asr.sh
@@ -948,18 +948,15 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ] && ! [[ " ${skip_stages} " =~ [
log "Stage 5: Generate whisper token_list from ${token_type} tokenizer"


_opts=""
if [ "${token_type}" = "whisper_multilingual" ]; then
_opts+=" --language ${lang}"
fi

# The first symbol in token_list must be "<blank>" and the last must also be sos/eos:
# 0 is reserved for CTC-blank for ASR and also used as ignore-index in other tasks
echo ${token_list}
${python} -m espnet2.bin.whisper_export_vocabulary \
--whisper_model "${token_type}" \
--whisper_language "${lang}" \
--whisper_task "transcribe" \
--sot_asr "${sot_asr}" \
--output "${token_list}" ${_opts}
--output "${token_list}"

elif [ "${token_type}" = hugging_face ]; then
log "Stage 5: Generate hugging_face token_list from ${hugging_face_model_name_or_path}"
39 changes: 39 additions & 0 deletions egs2/TEMPLATE/st1/st.sh
@@ -145,6 +145,7 @@ lm_dev_text= # Text file path of language model development set.
lm_test_text= # Text file path of language model evaluation set.
nlsyms_txt=none # Non-linguistic symbol list if existing.
cleaner=none # Text cleaner.
hyp_cleaner=none # Text cleaner for hypotheses (may be used with external tokenizers)
g2p=none # g2p method (needed if token_type=phn).
score_opts= # The options given to sclite scoring
local_score_opts= # The options given to local/score.sh.
@@ -362,6 +363,14 @@ elif [ "${src_token_type}" = char ]; then
elif [ "${src_token_type}" = word ]; then
src_token_list="${src_wordtoken_list}"
src_bpemodel=none
elif [ "${src_token_type}" = whisper_en ]; then
src_token_list="${token_listdir}"/src_whisper_en/tokens.txt
src_bpemodel=whisper_en
hyp_cleaner=${cleaner}
elif [ "${src_token_type}" = whisper_multilingual ]; then
src_token_list="${token_listdir}"/src_whisper_multilingual/tokens.txt
src_bpemodel=whisper_multilingual
hyp_cleaner=${cleaner}
else
log "Error: not supported --src_token_type '${src_token_type}'"
exit 2
@@ -374,6 +383,14 @@ elif [ "${tgt_token_type}" = char ]; then
elif [ "${tgt_token_type}" = word ]; then
tgt_token_list="${tgt_wordtoken_list}"
tgt_bpemodel=none
elif [ "${tgt_token_type}" = whisper_en ]; then
tgt_token_list="${token_listdir}"/tgt_whisper_en/tokens.txt
tgt_bpemodel=whisper_en
hyp_cleaner=${cleaner}
elif [ "${tgt_token_type}" = whisper_multilingual ]; then
tgt_token_list="${token_listdir}"/tgt_whisper_multilingual/tokens.txt
tgt_bpemodel=whisper_multilingual
hyp_cleaner=${cleaner}
elif [ "${tgt_token_type}" = hugging_face ]; then
tgt_token_list="${hugging_face_token_list}"
tgt_bpemodel=${hugging_face_model_name_or_path}
@@ -810,6 +827,16 @@ if ! "${skip_data_prep}"; then
--add_symbol "${oov}:1" \
--add_symbol "${sos_eos}:-1"

elif grep -q "whisper" <<< ${tgt_token_type}; then
log "Stage 5a: Generate whisper token_list from ${tgt_token_type} tokenizer"

echo ${tgt_token_list}
${python} -m espnet2.bin.whisper_export_vocabulary \
--whisper_model "${tgt_token_type}" \
--whisper_language "${tgt_lang}" \
--whisper_task "translate" \
--output "${tgt_token_list}"

elif [ "${tgt_token_type}" = hugging_face ]; then
log "Stage 5: Generate hugging_face token_list from ${hugging_face_model_name_or_path}"

@@ -895,6 +922,16 @@ if ! "${skip_data_prep}"; then
--add_symbol "${oov}:1" \
--add_symbol "${sos_eos}:-1"

elif grep -q "whisper" <<< ${src_token_type}; then
log "Stage 5b: Generate whisper token_list from ${src_token_type} tokenizer"

echo ${src_token_list}
${python} -m espnet2.bin.whisper_export_vocabulary \
--whisper_model "${src_token_type}" \
--whisper_language "${src_lang}" \
--whisper_task "translate" \
--output "${src_token_list}"

else
log "Error: not supported --token_type '${src_token_type}'"
exit 2
@@ -1543,6 +1580,7 @@ if ! "${skip_eval}"; then
--token_type word \
--non_linguistic_symbols "${nlsyms_txt}" \
--remove_non_linguistic_symbols true \
--cleaner "${hyp_cleaner}" \
) \
<(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
>"${_scoredir}/hyp.trn.org"
@@ -1649,6 +1687,7 @@ if ! "${skip_eval}"; then
--token_type word \
--non_linguistic_symbols "${nlsyms_txt}" \
--remove_non_linguistic_symbols true \
--cleaner "${hyp_cleaner}" \
) \
<(<"${_data}/utt2spk" awk '{ print "(" $2 "-" $1 ")" }') \
>"${_scoredir}/hyp.trn"
51 changes: 50 additions & 1 deletion egs2/aishell/asr1/README.md
@@ -41,7 +41,56 @@
|decode_asr_streaming_lm_lm_train_lm_transformer_zh_char_valid.loss.ave_asr_model_valid.acc.ave/dev|14326|205341|93.6|6.2|0.1|0.5|6.8|46.8|
|decode_asr_streaming_lm_lm_train_lm_transformer_zh_char_valid.loss.ave_asr_model_valid.acc.ave/test|7176|104765|93.0|6.7|0.2|0.8|7.8|50.7|

# Whisper Medium Finetune

# Whisper Large LoRA Finetune

## Environments
- date: `Sun Aug 6 20:21:54 CST 2023`
- python version: `3.9.12 (main, Apr 5 2022, 06:56:58) [GCC 7.5.0]`
- espnet version: `espnet 202304`
- pytorch version: `pytorch 1.10.1`

## Results

- ASR config: [conf/tuning/train_asr_whisper_large_lora_finetune.yaml](conf/tuning/train_asr_whisper_large_lora_finetune.yaml)
- Decode config: [conf/tuning/decode_asr_whisper_noctc_beam10.yaml](conf/tuning/decode_asr_whisper_noctc_beam10.yaml)
- Pretrained Model:
  - #Trainable Params: 7.86 M
  - Link: TBD

### CER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_whisper_noctc_beam10_asr_model_valid.acc.ave/dev|14326|205341|97.6|2.3|0.1|0.1|2.5|22.4|
|decode_asr_whisper_noctc_beam10_asr_model_valid.acc.ave/test|7176|104765|97.3|2.6|0.1|0.1|2.7|23.9|


# Whisper Medium LoRA Finetune

## Environments
- date: `Thu Aug 3 21:21:52 CST 2023`
- python version: `3.9.12 (main, Apr 5 2022, 06:56:58) [GCC 7.5.0]`
- espnet version: `espnet 202304`
- pytorch version: `pytorch 1.10.1`

## Results

- ASR config: [conf/tuning/train_asr_whisper_medium_lora_finetune.yaml](conf/tuning/train_asr_whisper_medium_lora_finetune.yaml)
- Decode config: [conf/tuning/decode_asr_whisper_noctc_beam10.yaml](conf/tuning/decode_asr_whisper_noctc_beam10.yaml)
- Pretrained Model:
  - #Trainable Params: 4.72 M
  - Link: TBD

### CER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|decode_asr_whisper_noctc_beam10_asr_model_valid.acc.ave/dev|14326|205341|96.9|3.0|0.1|0.1|3.2|27.0|
|decode_asr_whisper_noctc_beam10_asr_model_valid.acc.ave/test|7176|104765|96.6|3.3|0.1|0.1|3.5|28.8|
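
The "#Trainable Params" figures reported for the Large (7.86 M) and Medium (4.72 M) LoRA runs above can be sanity-checked with a back-of-the-envelope calculation, assuming rank-8 adapters are attached to the query/key/value/output projections of every self- and cross-attention block (each projection being a d_model x d_model matrix in Whisper). The sketch below is illustrative and is not code from this PR:

```python
# Rough check of the reported trainable-parameter counts, assuming LoRA (rank 8)
# wraps the q/k/v/out projections of every self- and cross-attention block.
def lora_trainable_params(d_model, enc_layers, dec_layers, rank=8, proj_per_attn=4):
    per_projection = rank * d_model + d_model * rank   # A: (rank x d_model), B: (d_model x rank)
    attn_blocks = enc_layers + 2 * dec_layers          # decoder has self- and cross-attention
    return attn_blocks * proj_per_attn * per_projection

print(lora_trainable_params(1280, 32, 32) / 1e6)  # Whisper large-v2 -> ~7.86 M
print(lora_trainable_params(1024, 24, 24) / 1e6)  # Whisper medium   -> ~4.72 M
```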


# Whisper Medium Full Finetune

## Environments
- date: `Thu Jul 13 12:40:44 CST 2023`
@@ -0,0 +1,86 @@
normalize: null

encoder: whisper
encoder_conf:
    whisper_model: large-v2
    dropout_rate: 0.0
    use_specaug: true
    specaug_conf:
        apply_time_warp: true
        time_warp_window: 5
        time_warp_mode: bicubic
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 40
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0.
        - 0.12
        num_time_mask: 5


decoder: whisper
decoder_conf:
    whisper_model: large-v2
    dropout_rate: 0.0

preprocessor: default
preprocessor_conf:
    tokenizer_language: "zh"

model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    extract_feats_in_collect_stats: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    # do_pad_trim: true # should be set when doing zero-shot inference

frontend: null
input_size: 1 # to prevent build_model() from complaining

seed: 2022
log_interval: 100
num_att_plot: 0
num_workers: 4
sort_in_batch: descending # how to sort data in making batch
sort_batch: descending # how to sort created batches
batch_type: numel
batch_bins: 12000000 # good for 8 * RTX 3090 24G
accum_grad: 4
max_epoch: 10
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 5

use_amp: true
cudnn_deterministic: false
cudnn_benchmark: false

# LoRA finetune related
use_lora: true
lora_conf:
    rank: 8
    alpha: 16
    dropout_rate: 0.05
    target_modules: ["query", "key", "value", "attn.out"]

optim: adamw
grad_clip: 1.0
optim_conf:
    lr: 5.0e-04
    weight_decay: 0.01
    betas:
    - 0.9
    - 0.99
    eps: 1.0e-06
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 1500
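
For readers unfamiliar with the `lora_conf` fields above (`rank`, `alpha`, `dropout_rate`, `target_modules`), the snippet below is a minimal sketch of the standard LoRA update those fields parameterize: the pretrained projection stays frozen and a trainable low-rank product, scaled by `alpha / rank`, is added to its output. It illustrates the technique only; it is not the wrapper ESPnet actually installs around the Whisper attention projections.

```python
import torch.nn as nn

class LoRALinear(nn.Module):
    """Illustrative LoRA wrapper: frozen base weight plus a scaled low-rank update."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: int = 16, dropout_rate: float = 0.05):
        super().__init__()
        self.base = base
        for p in self.base.parameters():       # the pretrained projection stays frozen
            p.requires_grad = False
        self.lora_A = nn.Linear(base.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)     # update starts at zero, so initial output equals the base model
        self.dropout = nn.Dropout(dropout_rate)
        self.scaling = alpha / rank

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(self.dropout(x)))

# e.g., wrapping one hypothetical query projection of Whisper large-v2 (d_model = 1280)
q_proj = LoRALinear(nn.Linear(1280, 1280), rank=8, alpha=16, dropout_rate=0.05)
```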
@@ -34,6 +34,7 @@ model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    extract_feats_in_collect_stats: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    # do_pad_trim: true # should be set when doing zero-shot inference
@@ -0,0 +1,86 @@
normalize: null

encoder: whisper
encoder_conf:
    whisper_model: medium
    dropout_rate: 0.0
    use_specaug: true
    specaug_conf:
        apply_time_warp: true
        time_warp_window: 5
        time_warp_mode: bicubic
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 40
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0.
        - 0.12
        num_time_mask: 5


decoder: whisper
decoder_conf:
    whisper_model: medium
    dropout_rate: 0.0

preprocessor: default
preprocessor_conf:
    tokenizer_language: "zh"

model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    extract_feats_in_collect_stats: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    # do_pad_trim: true # should be set when doing zero-shot inference

frontend: null
input_size: 1 # to prevent build_model() from complaining

seed: 2022
log_interval: 100
num_att_plot: 0
num_workers: 4
sort_in_batch: descending # how to sort data in making batch
sort_batch: descending # how to sort created batches
batch_type: numel
batch_bins: 70000000 # good for 8 * RTX 3090 24G
accum_grad: 2
max_epoch: 10
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 5

use_amp: true
cudnn_deterministic: false
cudnn_benchmark: false

# LoRA finetune related
use_lora: true
lora_conf:
    rank: 8
    alpha: 16
    dropout_rate: 0.05
    target_modules: ["query", "key", "value", "attn.out"]

optim: adamw
grad_clip: 1.0
optim_conf:
    lr: 5.0e-04
    weight_decay: 0.01
    betas:
    - 0.9
    - 0.99
    eps: 1.0e-06
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 1500
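
As a reference for the optimizer block above, the `warmuplr` scheduler is assumed here to follow the usual Noam-style rule: the learning rate ramps up linearly to the configured `lr` over `warmup_steps` and then decays with the inverse square root of the step. A small sketch:

```python
# Sketch of the warmuplr rule, assuming the standard Noam-style WarmupLR:
# linear ramp to base_lr over warmup_steps, then inverse-sqrt decay.
def warmup_lr(step: int, base_lr: float = 5.0e-4, warmup_steps: int = 1500) -> float:
    return base_lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)

print(warmup_lr(1500))   # peak, ~5.0e-4 at the end of warmup
print(warmup_lr(6000))   # ~2.5e-4, i.e. peak / sqrt(6000 / 1500)
```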
6 changes: 5 additions & 1 deletion espnet/nets/e2e_mt_common.py
@@ -30,7 +30,11 @@
        self.space = sym_space
        self.pad = sym_pad
        self.report_bleu = report_bleu
        self.idx_blank = self.char_list.index(self.pad)
        if self.pad in self.char_list:
            self.idx_blank = self.char_list.index(self.pad)
        else:
            # for OpenAI Whisper model, which doesn't use <blank> token
            self.idx_blank = None

Check warning (Codecov / codecov/patch): added line espnet/nets/e2e_mt_common.py#L37 was not covered by tests.
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else: