Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

set default none decoder for ASR #4917

Merged
merged 16 commits into from
Feb 10, 2023
10 changes: 5 additions & 5 deletions ci/test_integration_espnet2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ for t in ${feats_types}; do
for t2 in ${token_types}; do
echo "==== feats_type=${t}, token_types=${t2} ==="
./run.sh --ngpu 0 --stage 6 --stop-stage 13 --skip-upload false --feats-type "${t}" --token-type "${t2}" \
--asr-args "--max_epoch=1" --lm-args "--max_epoch=1" --python "${python}"
--asr-args "--max_epoch=1 --decoder rnn" --lm-args "--max_epoch=1" --python "${python}"
done
done
echo "==== feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --ngpu 0 --stage 10 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1"
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1 --decoder=rnn"

echo "==== use_streaming, feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --use_streaming true --ngpu 0 --stage 6 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
Expand All @@ -45,12 +45,12 @@ if python3 -c "import k2" &> /dev/null; then
echo "==== use_k2, num_paths > nll_batch_size, feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --num_paths 500 --nll_batch_size 20 --use_k2 true --ngpu 0 --stage 12 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1"
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1 --decoder=rnn"

echo "==== use_k2, num_paths == nll_batch_size, feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --num_paths 20 --nll_batch_size 20 --use_k2 true --ngpu 0 --stage 12 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1"
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1 --decoder=rnn"
fi

if python3 -c "from warprnnt_pytorch import RNNTLoss" &> /dev/null; then
Expand Down Expand Up @@ -164,7 +164,7 @@ fi
if python -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.2.0")' &> /dev/null; then
cd ./egs2/mini_an4/enh_asr1
echo "==== [ESPnet2] ENH_ASR ==="
./run.sh --ngpu 0 --stage 0 --stop-stage 15 --skip-upload_hf false --feats-type "raw" --spk-num 1 --enh_asr_args "--max_epoch=1 --enh_separator_conf num_spk=1" --python "${python}"
./run.sh --ngpu 0 --stage 0 --stop-stage 15 --skip-upload_hf false --feats-type "raw" --spk-num 1 --enh_asr_args "--max_epoch=1 --enh_separator_conf num_spk=1 --asr_decoder rnn" --python "${python}"
# Remove generated files in order to reduce the disk usage
rm -rf exp dump data
cd "${cwd}"
Expand Down
4 changes: 4 additions & 0 deletions espnet2/asr/espnet_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -147,6 +148,9 @@ def __init__(
assert (
decoder is not None
), "decoder should not be None when attention is used"
else:
decoder = None
logging.warning("Set decoder to none as ctc_weight==1.0")

self.decoder = decoder

Expand Down
2 changes: 1 addition & 1 deletion espnet2/asr/pit_espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
decoder: Optional[AbsDecoder],
ctc: CTC,
joint_network: Optional[torch.nn.Module],
ctc_weight: float = 0.5,
Expand Down
2 changes: 1 addition & 1 deletion espnet2/tasks/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@
s4=S4Decoder,
),
type_check=AbsDecoder,
default="rnn",
default=None,
optional=True,
)
preprocessor_choices = ClassChoices(
Expand Down
2 changes: 2 additions & 0 deletions test/espnet2/bin/test_asr_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def asr_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--decoder",
"rnn",
]
)
return tmp_path / "asr" / "config.yaml"
Expand Down
4 changes: 4 additions & 0 deletions test/espnet2/bin/test_asr_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def asr_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--decoder",
"rnn",
]
)
return tmp_path / "asr" / "config.yaml"
Expand Down Expand Up @@ -214,6 +216,8 @@ def enh_asr_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--asr_decoder",
"rnn",
]
)
return tmp_path / "enh_asr" / "config.yaml"
Expand Down
2 changes: 2 additions & 0 deletions test/espnet2/bin/test_asr_inference_k2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def asr_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--decoder",
"rnn",
]
)
return tmp_path / "asr" / "config.yaml"
Expand Down
2 changes: 2 additions & 0 deletions test/espnet2/bin/test_enh_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def enh_s2t_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--asr_decoder",
"rnn",
]
)

Expand Down