Skip to content

Commit

Permalink
Update test_asr_inference.py
Browse files Browse the repository at this point in the history
Added edge test case for streaming asr unit test and increased execution time out
  • Loading branch information
espnetUser authored May 10, 2022
1 parent beb3360 commit 61b5013
Showing 1 changed file with 75 additions and 9 deletions.
84 changes: 75 additions & 9 deletions test/espnet2/bin/test_asr_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import numpy as np
import pytest
import yaml

from espnet.nets.beam_search import Hypothesis
from espnet2.bin.asr_inference import get_parser
from espnet2.bin.asr_inference import main
from espnet2.bin.asr_inference import Speech2Text
from espnet2.bin.asr_inference_streaming import Speech2TextStreaming
from espnet2.tasks.asr import ASRTask
from espnet2.tasks.enh_s2t import EnhS2TTask
from espnet2.tasks.lm import LMTask
Expand Down Expand Up @@ -99,26 +101,90 @@ def asr_config_file_streaming(tmp_path: Path, token_list):
"char",
"--decoder",
"transformer",
"--encoder",
"contextual_block_transformer",
]
)
return tmp_path / "asr_streaming" / "config.yaml"


@pytest.mark.execution_timeout(10)
@pytest.mark.execution_timeout(20)
def test_Speech2Text_streaming(asr_config_file_streaming, lm_config_file):
speech2text = Speech2Text(
file = open(asr_config_file_streaming, "r", encoding="utf-8")
asr_train_config = file.read()
asr_train_config = yaml.full_load(asr_train_config)
asr_train_config["frontend"] = "default"
asr_train_config["encoder_conf"] = {
"look_ahead": 16,
"hop_size": 16,
"block_size": 40,
}
# Change the configuration file
with open(asr_config_file_streaming, "w", encoding="utf-8") as files:
yaml.dump(asr_train_config, files)
speech2text = Speech2TextStreaming(
asr_train_config=asr_config_file_streaming,
lm_train_config=lm_config_file,
beam_size=1,
streaming=True,
)
speech = np.random.randn(10000)
for sim_chunk_length in [1, 32, 128, 512, 1024, 2048]:
if (len(speech) // sim_chunk_length) > 1:
for i in range(len(speech) // sim_chunk_length):
speech2text(
speech=speech[i * sim_chunk_length : (i + 1) * sim_chunk_length],
is_final=False,
)
results = speech2text(
speech[(i + 1) * sim_chunk_length : len(speech)], is_final=True
)
else:
results = speech2text(speech)
for text, token, token_int, hyp in results:
assert isinstance(text, str)
assert isinstance(token[0], str)
assert isinstance(token_int[0], int)
assert isinstance(hyp, Hypothesis)

# Test edge case: https://github.com/espnet/espnet/pull/4216
file = open(asr_config_file_streaming, "r", encoding="utf-8")
asr_train_config = file.read()
asr_train_config = yaml.full_load(asr_train_config)
asr_train_config["frontend"] = "default"
asr_train_config["frontend_conf"] = {
"n_fft": 256,
"win_length": 256,
"hop_length": 128,
}
# Change the configuration file
with open(asr_config_file_streaming, "w", encoding="utf-8") as files:
yaml.dump(asr_train_config, files)
speech2text = Speech2TextStreaming(
asr_train_config=asr_config_file_streaming,
lm_train_config=lm_config_file,
beam_size=1,
streaming=True,
)
speech = np.random.randn(100000)
results = speech2text(speech)
for text, token, token_int, hyp in results:
assert isinstance(text, str)
assert isinstance(token[0], str)
assert isinstance(token_int[0], int)
assert isinstance(hyp, Hypothesis)
# edge case: speech is exactly multiple of sim_chunk_length, e.g., 10240 = 5 x 2048
speech = np.random.randn(10240)
for sim_chunk_length in [1, 32, 64, 128, 512, 1024, 2048]:
if (len(speech) // sim_chunk_length) > 1:
for i in range(len(speech) // sim_chunk_length):
speech2text(
speech=speech[i * sim_chunk_length : (i + 1) * sim_chunk_length],
is_final=False,
)
results = speech2text(
speech[(i + 1) * sim_chunk_length : len(speech)], is_final=True
)
else:
results = speech2text(speech)
for text, token, token_int, hyp in results:
assert isinstance(text, str)
assert isinstance(token[0], str)
assert isinstance(token_int[0], int)
assert isinstance(hyp, Hypothesis)


@pytest.fixture()
Expand Down

0 comments on commit 61b5013

Please sign in to comment.