
Commit 870cd5f: Merge branch 'master' into avsr1

sw005320 committed Oct 5, 2023
2 parents 2c6ebb0 + eef6523
Showing 8 changed files with 51 additions and 31 deletions.
8 changes: 5 additions & 3 deletions espnet/nets/batch_beam_search.py
@@ -11,6 +11,8 @@

is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")

+logger = logging.getLogger(__name__)
+

class BatchHypothesis(NamedTuple):
    """Batchfied/Vectorized hypothesis data type."""
@@ -382,9 +384,9 @@ def post_process(

        """
        n_batch = running_hyps.yseq.shape[0]
-        logging.debug(f"the number of running hypothes: {n_batch}")
+        logger.debug(f"the number of running hypothes: {n_batch}")
        if self.token_list is not None:
-            logging.debug(
+            logger.debug(
                "best hypo: "
                + "".join(
                    [
@@ -395,7 +397,7 @@
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
-            logging.info("adding <eos> in the last position in the loop")
+            logger.info("adding <eos> in the last position in the loop")
            yseq_eos = torch.cat(
                (
                    running_hyps.yseq,
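
The edits above replace calls on the logging root with a module-level logger. A minimal sketch of the pattern, with the function name invented for illustration:

    import logging

    # The logger takes the module's dotted path (e.g. "espnet.nets.batch_beam_search"),
    # so applications can raise or lower this module's verbosity independently.
    logger = logging.getLogger(__name__)

    def decode_step(position: int) -> None:
        # Lazy %-formatting defers string building until the record is actually emitted.
        logger.debug("position %d", position)

Note that the module configures no handlers of its own; as is conventional for library code, handler and level setup stay with the application.
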
38 changes: 20 additions & 18 deletions espnet/nets/beam_search.py
@@ -9,6 +9,8 @@
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.scorer_interface import PartialScorerInterface, ScorerInterface

+logger = logging.getLogger(__name__)
+

class Hypothesis(NamedTuple):
    """Hypothesis data type."""
@@ -412,33 +414,33 @@ def forward(
        else:
            maxlen = max(1, int(maxlenratio * inp.size(0)))
        minlen = int(minlenratio * inp.size(0))
-        logging.info("decoder input length: " + str(inp.shape[0]))
-        logging.info("max output length: " + str(maxlen))
-        logging.info("min output length: " + str(minlen))
+        logger.info("decoder input length: " + str(inp.shape[0]))
+        logger.info("max output length: " + str(maxlen))
+        logger.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(x if pre_x is None else pre_x)
        ended_hyps = []
        for i in range(maxlen):
-            logging.debug("position " + str(i))
+            logger.debug("position " + str(i))
            best = self.search(running_hyps, x, pre_x=pre_x)
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
-                logging.info(f"end detected at {i}")
+                logger.info(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
-                logging.info("no hypothesis. Finish decoding.")
+                logger.info("no hypothesis. Finish decoding.")
                break
            else:
-                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+                logger.debug(f"remained hypotheses: {len(running_hyps)}")

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)

        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
-            logging.warning(
+            logger.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
@@ -451,25 +453,25 @@
        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
-            logging.info(
+            logger.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
-        logging.info(f"total log probability: {best.score:.2f}")
-        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
-        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        logger.info(f"total log probability: {best.score:.2f}")
+        logger.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
-            logging.info(
+            logger.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        if best.yseq[1:-1].shape[0] == maxlen:
-            logging.warning(
+            logger.warning(
                "best hypo length: {} == max output length: {}".format(
                    best.yseq[1:-1].shape[0], maxlen
                )
            )
-            logging.warning(
+            logger.warning(
                "decoding may be stopped by the max output length limitation, "
                + "please consider to increase the maxlenratio."
            )
@@ -496,15 +498,15 @@ def post_process(
            List[Hypothesis]: The new running hypotheses.

        """
-        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        logger.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
-            logging.debug(
+            logger.debug(
                "best hypo: "
                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
-            logging.info("adding <eos> in the last position in the loop")
+            logger.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
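
With beam_search.py converted the same way, both decoder modules now log through named loggers. A hypothetical application-side snippet showing the benefit (the logger names follow from logging.getLogger(__name__) and the module paths in this diff):

    import logging

    # Verbose root logging for the application as a whole...
    logging.basicConfig(level=logging.DEBUG)
    # ...while muting only the per-position chatter from the beam-search modules.
    logging.getLogger("espnet.nets.beam_search").setLevel(logging.WARNING)
    logging.getLogger("espnet.nets.batch_beam_search").setLevel(logging.WARNING)

Filtering at this granularity is awkward when a library logs through the root logger directly, since records from every module then look alike.
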
2 changes: 2 additions & 0 deletions espnet2/bin/whisper_export_vocabulary.py
@@ -57,6 +57,8 @@ def export_vocabulary(
    vocab_size = tokenizer.tokenizer.vocab_size + len(
        tokenizer.tokenizer.get_added_vocab()
    )
+    if whisper_model == "whisper_en":
+        vocab_size = vocab_size - 1

    for i in range(vocab_size):
        # take care of special char for <space>
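
The new branch compensates for the English-only Whisper tokenizer, where the sum vocab_size + len(get_added_vocab()) comes out one larger than the intended vocabulary size (that is a reading of the diff; which token is over-counted is not shown here). A standalone sketch of the corrected computation, with the helper name invented for illustration:

    def exported_vocab_size(tokenizer, whisper_model: str) -> int:
        # HuggingFace-style tokenizer wrapper assumed, as in the diff above.
        size = tokenizer.tokenizer.vocab_size + len(
            tokenizer.tokenizer.get_added_vocab()
        )
        if whisper_model == "whisper_en":
            size -= 1  # English-only model over-counts by one in this sum.
        return size
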
7 changes: 7 additions & 0 deletions espnet2/text/whisper_token_id_converter.py
@@ -57,8 +57,15 @@ def __init__(
            self.tokenizer.tokenizer.add_special_tokens(
                dict(additional_special_tokens=special_tokens)
            )
+        self.model_type = model_type

    def get_num_vocabulary_size(self) -> int:
+        if self.model_type == "whisper_en":
+            return (
+                self.tokenizer.tokenizer.vocab_size
+                + len(self.tokenizer.tokenizer.get_added_vocab())
+                - 1
+            )
        return self.tokenizer.tokenizer.vocab_size + len(
            self.tokenizer.tokenizer.get_added_vocab()
        )
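
This converter-side fix mirrors the exporter-side fix above, so the exported token list and the reported vocabulary size should stay in sync for whisper_en. A hypothetical pytest-style consistency check (argument order and file output behavior are assumptions based on the diff and the tests below, not confirmed here):

    from espnet2.bin.whisper_export_vocabulary import export_vocabulary
    from espnet2.text.whisper_token_id_converter import OpenAIWhisperTokenIDConverter

    def test_whisper_en_vocab_sizes_agree(tmp_path):
        out = tmp_path / "vocab_en.txt"
        # Assumption: export_vocabulary writes one token per line when given a path.
        export_vocabulary(str(out), "whisper_en", "en", "INFO")
        # Assumption: model_type is the first positional constructor argument.
        converter = OpenAIWhisperTokenIDConverter("whisper_en")
        n_lines = len(out.read_text(encoding="utf-8").splitlines())
        assert n_lines == converter.get_num_vocabulary_size()
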
10 changes: 5 additions & 5 deletions test/espnet2/bin/test_asr_inference.py
@@ -77,7 +77,7 @@ def test_Speech2Text(asr_config_file, lm_config_file):
    speech2text = Speech2Text(
        asr_train_config=asr_config_file, lm_train_config=lm_config_file, beam_size=1
    )
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    results = speech2text(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
@@ -95,7 +95,7 @@ def test_Speech2Text_quantized(asr_config_file, lm_config_file):
        quantize_asr_model=True,
        quantize_lm=True,
    )
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    results = speech2text(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
@@ -287,7 +287,7 @@ def test_Speech2Text_hugging_face(
        hugging_face_decoder_conf={"num_beams": 2, "max_new_tokens": 4},
        ctc_weight=0.0,
    )
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    results = speech2text(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
@@ -335,7 +335,7 @@ def test_Speech2Text_hugging_face_causal_lm(
        hugging_face_decoder_conf={"num_beams": 2, "max_new_tokens": 4},
        ctc_weight=0.0,
    )
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    results = speech2text(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
@@ -413,7 +413,7 @@ def test_Speech2Text_interctc(asr_config_file, lm_config_file, encoder_class):
    speech2text = Speech2Text(
        asr_train_config=asr_config_file, lm_train_config=lm_config_file, beam_size=1
    )
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    results, interctc_res = speech2text(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
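
Throughout these tests (and the SLU tests that follow), the dummy waveform shrinks from 100000 to 1000 random samples. Assuming the 16 kHz sampling rate these test configurations typically use, that is roughly 62.5 ms instead of 6.25 s of input, which keeps the decoding loop fast without weakening what is asserted:

    import numpy as np

    sr = 16000  # assumed sampling rate of the test configurations
    speech = np.random.randn(1000)  # white noise standing in for audio
    print(speech.shape[0] / sr)  # 0.0625 s, versus 6.25 s for 100000 samples
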
8 changes: 4 additions & 4 deletions test/espnet2/bin/test_slu_inference.py
@@ -58,7 +58,7 @@ def slu_config_file(tmp_path: Path, token_list):
@pytest.mark.execution_timeout(50)
def test_Speech2Understand(slu_config_file):
    speech2understand = Speech2Understand(slu_train_config=slu_config_file, beam_size=1)
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    results = speech2understand(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
@@ -70,7 +70,7 @@ def test_Speech2Understand(slu_config_file):
@pytest.mark.execution_timeout(50)
def test_Speech2Understand_transcript(slu_config_file):
    speech2understand = Speech2Understand(slu_train_config=slu_config_file)
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    transcript = torch.randint(2, 4, [1, 4], dtype=torch.long)
    results = speech2understand(speech, transcript)

@@ -116,7 +116,7 @@ def test_Speech2Understand_lm(use_lm, token_type, slu_config_file, lm_config_file):
        beam_size=1,
        token_type=token_type,
    )
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    results = speech2understand(speech)
    for text, token, token_int, hyp in results:
        assert text is None or isinstance(text, str)
@@ -134,7 +134,7 @@ def test_Speech2Understand_quantized(slu_config_file, lm_config_file):
        quantize_asr_model=True,
        quantize_lm=True,
    )
-    speech = np.random.randn(100000)
+    speech = np.random.randn(1000)
    results = speech2understand(speech)
    for text, token, token_int, hyp in results:
        assert isinstance(text, str)
7 changes: 7 additions & 0 deletions test/espnet2/bin/test_whisper_export_vocabulary.py
@@ -22,6 +22,13 @@ def test_export_vocabulary_to_stdout():
        pytest.fail(f"exception thrown: {e}")


+def test_export_multilinugal_vocabulary_to_stdout():
+    try:
+        export_vocabulary("-", "whisper_multilingual", "en", "INFO")
+    except Exception as e:
+        pytest.fail(f"exception thrown: {e}")
+
+
def test_export_vocabulary_en(tmp_path):
    tknlist_path = tmp_path / "tmp_token_list/whisper_token_list.txt"
    tknlist_path.parent.mkdir()
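
To exercise only the new multilingual export test (note the committed function name spells it "multilinugal"), a standard pytest keyword selector works:

    pytest test/espnet2/bin/test_whisper_export_vocabulary.py -k multilinugal
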
2 changes: 1 addition & 1 deletion test/espnet2/text/test_whisper_token_id_converter.py
@@ -91,7 +91,7 @@ def test_tokens2ids(whisper_token_id_converter: OpenAIWhisperTokenIDConverter):
    assert ids == [
        50259,
        50359,
-        50363,
+        50303,
        17155,
        11,
        220,
