Skip to content

Commit

Permalink
Merge pull request #4917 from espnet/ftshijt-patch-1
Browse files Browse the repository at this point in the history
set default none decoder for ASR
  • Loading branch information
sw005320 committed Feb 10, 2023
2 parents ffbf7e0 + 8b93196 commit 6f0983f
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 7 deletions.
10 changes: 5 additions & 5 deletions ci/test_integration_espnet2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ for t in ${feats_types}; do
for t2 in ${token_types}; do
echo "==== feats_type=${t}, token_types=${t2} ==="
./run.sh --ngpu 0 --stage 6 --stop-stage 13 --skip-upload false --feats-type "${t}" --token-type "${t2}" \
--asr-args "--max_epoch=1" --lm-args "--max_epoch=1" --python "${python}"
--asr-args "--max_epoch=1 --decoder rnn" --lm-args "--max_epoch=1" --python "${python}"
done
done
echo "==== feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --ngpu 0 --stage 10 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1"
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1 --decoder=rnn"

echo "==== use_streaming, feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --use_streaming true --ngpu 0 --stage 6 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
Expand All @@ -45,12 +45,12 @@ if python3 -c "import k2" &> /dev/null; then
echo "==== use_k2, num_paths > nll_batch_size, feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --num_paths 500 --nll_batch_size 20 --use_k2 true --ngpu 0 --stage 12 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1"
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1 --decoder=rnn"

echo "==== use_k2, num_paths == nll_batch_size, feats_type=raw, token_types=bpe, model_conf.extract_feats_in_collect_stats=False, normalize=utt_mvn ==="
./run.sh --num_paths 20 --nll_batch_size 20 --use_k2 true --ngpu 0 --stage 12 --stop-stage 13 --skip-upload false --feats-type "raw" --token-type "bpe" \
--feats_normalize "utterance_mvn" --lm-args "--max_epoch=1" --python "${python}" \
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1"
--asr-args "--model_conf extract_feats_in_collect_stats=false --max_epoch=1 --decoder=rnn"
fi

if python3 -c "from warprnnt_pytorch import RNNTLoss" &> /dev/null; then
Expand Down Expand Up @@ -164,7 +164,7 @@ fi
if python -c 'import torch as t; from packaging.version import parse as L; assert L(t.__version__) >= L("1.2.0")' &> /dev/null; then
cd ./egs2/mini_an4/enh_asr1
echo "==== [ESPnet2] ENH_ASR ==="
./run.sh --ngpu 0 --stage 0 --stop-stage 15 --skip-upload_hf false --feats-type "raw" --spk-num 1 --enh_asr_args "--max_epoch=1 --enh_separator_conf num_spk=1" --python "${python}"
./run.sh --ngpu 0 --stage 0 --stop-stage 15 --skip-upload_hf false --feats-type "raw" --spk-num 1 --enh_asr_args "--max_epoch=1 --enh_separator_conf num_spk=1 --asr_decoder rnn" --python "${python}"
# Remove generated files in order to reduce the disk usage
rm -rf exp dump data
cd "${cwd}"
Expand Down
4 changes: 4 additions & 0 deletions espnet2/asr/espnet_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -147,6 +148,9 @@ def __init__(
assert (
decoder is not None
), "decoder should not be None when attention is used"
else:
decoder = None
logging.warning("Set decoder to none as ctc_weight==1.0")

self.decoder = decoder

Expand Down
2 changes: 1 addition & 1 deletion espnet2/asr/pit_espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
decoder: Optional[AbsDecoder],
ctc: CTC,
joint_network: Optional[torch.nn.Module],
ctc_weight: float = 0.5,
Expand Down
2 changes: 1 addition & 1 deletion espnet2/tasks/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@
s4=S4Decoder,
),
type_check=AbsDecoder,
default="rnn",
default=None,
optional=True,
)
preprocessor_choices = ClassChoices(
Expand Down
2 changes: 2 additions & 0 deletions test/espnet2/bin/test_asr_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def asr_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--decoder",
"rnn",
]
)
return tmp_path / "asr" / "config.yaml"
Expand Down
4 changes: 4 additions & 0 deletions test/espnet2/bin/test_asr_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def asr_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--decoder",
"rnn",
]
)
return tmp_path / "asr" / "config.yaml"
Expand Down Expand Up @@ -214,6 +216,8 @@ def enh_asr_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--asr_decoder",
"rnn",
]
)
return tmp_path / "enh_asr" / "config.yaml"
Expand Down
2 changes: 2 additions & 0 deletions test/espnet2/bin/test_asr_inference_k2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def asr_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--decoder",
"rnn",
]
)
return tmp_path / "asr" / "config.yaml"
Expand Down
2 changes: 2 additions & 0 deletions test/espnet2/bin/test_enh_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def enh_s2t_config_file(tmp_path: Path, token_list):
str(token_list),
"--token_type",
"char",
"--asr_decoder",
"rnn",
]
)

Expand Down

0 comments on commit 6f0983f

Please sign in to comment.