Skip to content

Commit

Permalink
Merge pull request #5703 from jctian98/master
Browse files (browse the repository at this point in the history)
fix a small issue in OWSM decode_long
  • Loading branch information
sw005320 committed Mar 22, 2024
2 parents bdd2351 + d52dcb3 commit dbd73dd
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions espnet2/bin/s2t_inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -531,6 +531,7 @@ def decode_long(
end_time_threshold: str = "<29.00>",
lang_sym: Optional[str] = None,
task_sym: Optional[str] = None,
skip_last_chunk_threshold: float = 0.2,
):
"""Decode unsegmented long-form speech.
Expand Down Expand Up @@ -566,17 +567,25 @@ def decode_long(
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)

assert (
speech.dim() == 1
), f"speech must have one dimension, got size {speech.size()} instead"
if speech.dim() > 1:
assert (
speech.dim() == 2 and speech.size(1) == 1
), f"speech of size {speech.size()} is not supported"
speech = speech.squeeze(1) # (nsamples, 1) --> (nsamples,)

utterances = []
offset = 0
text_prev = init_text
while offset < len(speech):
logging.info(f"Current start time in seconds: {offset / fs:.2f}")

segment = speech[offset : offset + segment_len]
if len(segment) / fs < skip_last_chunk_threshold:
logging.warning(
f"Skip the last chunk as it's too short: {len(segment) / fs:.2f}s"
)
offset += segment_len
continue

# segment will be padded in __call__
result = self.__call__(
speech=segment,
Expand Down Expand Up @@ -641,7 +650,6 @@ def decode_long(
utterances.append(utt)

offset += round((new_start_time_id - first_time_id) * resolution * fs)
self.time_id = first_time_id

return utterances

Expand Down

0 comments on commit dbd73dd

Please sign in to comment.