Skip to content

Commit

Permalink
Fix ST inference and CI test errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
pengchengguo committed Oct 11, 2023
1 parent df4536b commit 2d3ef4f
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 12 deletions.
6 changes: 1 addition & 5 deletions egs2/TEMPLATE/st1/st.sh
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ elif [ "${tgt_token_type}" = char ]; then
elif [ "${tgt_token_type}" = word ]; then
tgt_token_list="${tgt_wordtoken_list}"
tgt_bpemodel=none
elif [ "${tgt_token_type}" = whisper_en ]; then # should make token_list an output filepath here
elif [ "${tgt_token_type}" = whisper_en ]; then
tgt_token_list="${token_listdir}"/tgt_whisper_en/tokens.txt
tgt_bpemodel=whisper_en
hyp_cleaner=${cleaner}
Expand Down Expand Up @@ -830,8 +830,6 @@ if ! "${skip_data_prep}"; then
elif grep -q "whisper" <<< ${tgt_token_type}; then
log "Stage 5a: Generate whisper token_list from ${tgt_token_type} tokenizer"

# The first symbol in token_list must be "<blank>" and the last must be also sos/eos:
# 0 is reserved for CTC-blank for ASR and also used as ignore-index in the other task
echo ${tgt_token_list}
${python} -m espnet2.bin.whisper_export_vocabulary \
--whisper_model "${tgt_token_type}" \
Expand Down Expand Up @@ -927,8 +925,6 @@ if ! "${skip_data_prep}"; then
elif grep -q "whisper" <<< ${src_token_type}; then
log "Stage 5b: Generate whisper token_list from ${src_token_type} tokenizer"

# The first symbol in token_list must be "<blank>" and the last must be also sos/eos:
# 0 is reserved for CTC-blank for ASR and also used as ignore-index in the other task
echo ${src_token_list}
${python} -m espnet2.bin.whisper_export_vocabulary \
--whisper_model "${src_token_type}" \
Expand Down
1 change: 1 addition & 0 deletions espnet/nets/e2e_mt_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, char_list, sym_space, sym_pad, report_bleu=False):
if self.pad in self.char_list:
self.idx_blank = self.char_list.index(self.pad)
else:
# for OpenAI Whisper model, which doesn't use <blank> token
self.idx_blank = None

Check warning on line 37 in espnet/nets/e2e_mt_common.py

View check run for this annotation

Codecov / codecov/patch

espnet/nets/e2e_mt_common.py#L37

Added line #L37 was not covered by tests
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
Expand Down
5 changes: 1 addition & 4 deletions espnet2/bin/asr_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,7 @@ def __init__(
elif bpemodel not in ["whisper_en", "whisper_multilingual"]:
converter = TokenIDConverter(token_list=token_list)
else:
if (
hasattr(asr_train_args, "preprocessor_conf")
and "speaker_change_symbol" in asr_train_args.preprocessor_conf
):
if "speaker_change_symbol" in preprocessor_conf:

Check warning on line 396 in espnet2/bin/asr_inference.py

View check run for this annotation

Codecov / codecov/patch

espnet2/bin/asr_inference.py#L396

Added line #L396 was not covered by tests
sot_asr = True
else:
sot_asr = False
Expand Down
12 changes: 9 additions & 3 deletions espnet2/bin/st_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,13 @@ def __init__(

# 4. [Optional] Build Text converter: e.g. bpe-sym -> Text
# compatibility for whisper tokenizer
whisper_language = st_train_args.preprocessor_conf.get("whisper_language", None)
preprocessor_conf = getattr(st_train_args, "preprocessor_conf", {})
whisper_language = preprocessor_conf.get("whisper_language", None)
whisper_task = preprocessor_conf.get("whisper_task", None)
if whisper_language:
src_token_lang, token_lang = whisper_language

Check warning on line 367 in espnet2/bin/st_inference.py

View check run for this annotation

Codecov / codecov/patch

espnet2/bin/st_inference.py#L367

Added line #L367 was not covered by tests
else:
src_token_lang, token_lang = None, None
whisper_task = st_train_args.preprocessor_conf.get("whisper_task", None)

if token_type is None:
token_type = st_train_args.token_type
Expand All @@ -383,7 +384,6 @@ def __init__(
tokenizer = build_tokenizer(

Check warning on line 384 in espnet2/bin/st_inference.py

View check run for this annotation

Codecov / codecov/patch

espnet2/bin/st_inference.py#L384

Added line #L384 was not covered by tests
token_type=token_type,
bpemodel=bpemodel,
# Whisper model only supports X -> En translation
whisper_language=token_lang,
whisper_task=whisper_task,
)
Expand All @@ -397,6 +397,9 @@ def __init__(
language=token_lang or "en",
task=whisper_task or "translate",
)
beam_search.set_hyp_primer(

Check warning on line 400 in espnet2/bin/st_inference.py

View check run for this annotation

Codecov / codecov/patch

espnet2/bin/st_inference.py#L400

Added line #L400 was not covered by tests
list(converter.tokenizer.sot_sequence_including_notimestamps)
)
else:
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
Expand Down Expand Up @@ -426,6 +429,9 @@ def __init__(
language=src_token_lang or "en",
task=whisper_task or "translate",
)
asr_beam_search.set_hyp_primer(

Check warning on line 432 in espnet2/bin/st_inference.py

View check run for this annotation

Codecov / codecov/patch

espnet2/bin/st_inference.py#L432

Added line #L432 was not covered by tests
list(src_converter.tokenizer.sot_sequence_including_notimestamps)
)
else:
src_converter = TokenIDConverter(token_list=src_token_list)
logging.info(f"Src Text tokenizer: {src_tokenizer}")
Expand Down
1 change: 1 addition & 0 deletions espnet2/st/espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
if tgt_sym_blank in token_list:
self.blank_id = token_list.index(tgt_sym_blank)

Check warning on line 137 in espnet2/st/espnet_model.py

View check run for this annotation

Codecov / codecov/patch

espnet2/st/espnet_model.py#L136-L137

Added lines #L136 - L137 were not covered by tests
else:
# OpenAI Whisper model doesn't use <blank> token
self.blank_id = 0

Check warning on line 140 in espnet2/st/espnet_model.py

View check run for this annotation

Codecov / codecov/patch

espnet2/st/espnet_model.py#L140

Added line #L140 was not covered by tests
self.st_criterion_transducer = RNNTLoss(
blank=self.blank_id,
Expand Down
1 change: 1 addition & 0 deletions test/espnet2/layers/test_create_lora_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
import torch
from packaging.version import parse as V

from espnet2.asr.decoder.transformer_decoder import TransformerDecoder
from espnet2.layers.create_lora_adapter import create_lora_adapter
Expand Down

0 comments on commit 2d3ef4f

Please sign in to comment.