
Improving OWSM inference interface #5618

Merged: 14 commits merged into espnet:master on Jan 19, 2024
Conversation

@pyf98 (Collaborator) commented Jan 10, 2024

What?

  • This PR improves the interface for OWSM inference.
    • The speech is automatically padded or trimmed to the fixed length used during training.
    • Some attributes, such as the speech length and time symbols, can be retrieved from preprocessor_conf, so we no longer need to provide them manually.
    • lang_sym, task_sym and predict_time can be passed as additional arguments to __call__, overriding the default values set in __init__. This is more convenient to use.
    • Many redundant arguments are removed from s2t_inference_language, and BeamSearch is no longer used; only a single decoder forward step is performed, which returns an N-best list of languages with (normalized) probabilities.
  • It also implements some simple rules to suppress timestamps, similar to Whisper. The implementation defines a new scorer that assigns certain tokens a -inf score during search. With this modification, the model can now predict the first timestamp by itself; previously, the first timestamp had to be set manually, which was usually inaccurate.
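For reference, here is a rough usage sketch of the new interface (editor's illustration: the model tag, symbol values, and result layout are assumptions, not verbatim from this PR; see espnet2/bin/s2t_inference.py for the actual signature):

```python
# Rough sketch only: model tag, symbols, and result layout are assumptions.
import soundfile as sf
from espnet2.bin.s2t_inference import Speech2Text

s2t = Speech2Text.from_pretrained(
    "espnet/owsm_v3.1_ebf",   # hypothetical model tag
    lang_sym="<eng>",         # default language symbol set in __init__
    task_sym="<asr>",         # default task symbol set in __init__
    predict_time=False,
)

speech, rate = sf.read("long_audio.wav")
# Speech is padded or trimmed internally to the fixed length used in training,
# and speech-length / time-symbol settings are read from preprocessor_conf.

# Per-call arguments override the __init__ defaults:
results = s2t(speech, lang_sym="<deu>", task_sym="<st_eng>", predict_time=True)
text = results[0][0]  # best hypothesis text (exact tuple layout may differ)
print(text)
```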

mergify bot added the ESPnet2 label Jan 10, 2024
codecov bot commented Jan 10, 2024

Codecov Report

Attention: 41 lines in your changes are missing coverage. Please review.

Comparison is base (3b2e0d3) 24.25% compared to head (8ee64f5) 76.06%.
Report is 121 commits behind head on master.

Files                                    Patch %   Lines
espnet2/bin/s2t_inference.py             67.39%    30 Missing ⚠️
espnet2/bin/s2t_inference_language.py    66.66%    11 Missing ⚠️
@@             Coverage Diff             @@
##           master    #5618       +/-   ##
===========================================
+ Coverage   24.25%   76.06%   +51.81%     
===========================================
  Files         720      735       +15     
  Lines       66641    68696     +2055     
===========================================
+ Hits        16161    52255    +36094     
+ Misses      50480    16441    -34039     
Flag                         Coverage Δ
test_configuration_espnet2   ∅ <ø> (∅)
test_integration_espnet1     62.92% <ø> (+<0.01%) ⬆️
test_integration_espnet2     48.11% <57.60%> (?)
test_python_espnet1          18.51% <0.00%> (-0.57%) ⬇️
test_python_espnet2          52.84% <56.00%> (?)
test_utils                   22.15% <ø> (?)

Flags with carried forward coverage won't be shown.


@sw005320 added the OWSM (Open Whisper-style Speech Model) and Enhancement labels Jan 10, 2024
@sw005320 added this to the v.202312 milestone Jan 10, 2024
@sw005320 (Contributor) commented:
Thanks, @pyf98!
@jctian98, can you review this PR?

@pyf98 (Collaborator, Author) commented Jan 15, 2024

Hi, can we consider merging this soon? I will provide example usage on the project page based on this new interface.

As reported previously, we can significantly improve long-form ASR performance with these rules and heuristics:
TEDLIUM2 WER: 8.5 -> 5.7 (greedy)

@sw005320 (Contributor) commented:
Yes, but CI is currently broken, so we're looking into it.
Sorry about that.

@ftshijt (Collaborator) commented Jan 16, 2024

@sw005320 The CI has passed after the previous issue was fixed. Should be good to go ahead.

@pyf98 (Collaborator, Author) commented Jan 16, 2024

Thanks! It seems that some TTS tests failed.

@sw005320 (Contributor) commented:
@jctian98, this is a reminder.
Can you review this PR?

@@ -44,6 +43,105 @@
]


class ScoreFilter(BatchScorerInterface, torch.nn.Module):
"""Filter scores based on pre-defined rules."""
Review comment (Contributor):
Can you add more explanations here or in other places about what kind of pre-defined rules?

Reply (Collaborator, Author):
I added some more comments below
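As a rough illustration (editor's sketch, not the actual ScoreFilter code): the kind of pre-defined rule being described is a scorer that adds -inf to disallowed tokens at each step so the search can never select them. The token IDs and the concrete timestamp rule below are placeholders.

```python
# Illustrative sketch only, not the actual ScoreFilter implementation.
import torch

class SuppressTokens:
    """Additive scorer that forbids certain tokens by assigning -inf."""

    def __init__(self, vocab_size: int, first_time_id: int, last_time_id: int):
        self.vocab_size = vocab_size
        self.first_time_id = first_time_id  # placeholder: id of the first timestamp token
        self.last_time_id = last_time_id    # placeholder: id of the last timestamp token

    def score(self, prev_token: int) -> torch.Tensor:
        scores = torch.zeros(self.vocab_size)
        if self.first_time_id <= prev_token <= self.last_time_id:
            # Example rule: after emitting a timestamp, forbid all earlier
            # timestamps so predicted times never move backwards.
            scores[self.first_time_id : prev_token + 1] = float("-inf")
        return scores  # added to the decoder log-probabilities during search
```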

@jctian98 (Contributor) commented Jan 19, 2024
The code is very clear. I'm ok with it.

A simple discussion for future development: do you think s2t_inference_language.py could be a special case of s2t_inference.py, as long as you pass a constant maxlen to beam search so that decoding one step means exactly predicting the lang_id? @pyf98

@pyf98 (Collaborator, Author) commented Jan 19, 2024

Thanks @jctian98 for the comment.

In the previous version of s2t_inference_language.py, I simply copied the code of s2t_inference.py and set max_len to 1, so that it predicts the language token given <sos>. This design has some issues:

  • It adds many redundant arguments, since we directly reuse s2t_inference.py, which is designed to predict a fully formatted sequence. In the new version, I removed those redundant arguments to avoid confusion. It also no longer requires the BeamSearch object.
  • The current BatchBeamSearch only supports batch_size=1. In the future, we may want to implement batched language detection, which will be much easier in the new version.

I think we can keep this separate design for now.
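For illustration, a rough sketch of what a single decoder forward step for language identification looks like (editor's sketch with placeholder names; not the actual s2t_inference_language.py code):

```python
# Illustrative sketch only: placeholder names, not the actual code.
import torch

def detect_language(decoder, enc_out, sos_id, lang_token_ids, nbest=3):
    """One decoder step from <sos>, returning an N-best list of (lang_id, prob)."""
    ys = torch.tensor([[sos_id]], dtype=torch.long)            # prefix: just <sos>
    logits = decoder(ys, enc_out)                              # hypothetical decoder call
    logp = torch.log_softmax(logits[0, -1], dim=-1)            # next-token log-probs
    lang_probs = torch.softmax(logp[lang_token_ids], dim=-1)   # renormalize over language tokens
    top = torch.topk(lang_probs, k=nbest)
    return [(int(lang_token_ids[i]), float(p)) for p, i in zip(top.values, top.indices)]
```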

@pyf98 (Collaborator, Author) commented Jan 19, 2024

One more note is that the order of tokens is: <sop> prompt<sos><lang><task><time>xxx<eos>.

The language token is between <sos> and <task>. So, we cannot predict <lang> and text in a single autoregressive decoding run. Instead, we need to first predict <lang> given <sos>, and then predict text given <sos><lang><task> where <task> is known.

In terms of implementation, we have to set both maxlen and hyp_primer differently for the two use cases. Merging the two use cases would make the code more complicated.
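Purely for illustration (editor's sketch; the names below are placeholders rather than the actual BeamSearch arguments), the two use cases roughly differ as follows:

```python
# Illustrative placeholders, not the actual BeamSearch options.

# 1) Language identification: decode exactly one step after <sos>.
lang_id_setup = {"hyp_primer": ["<sos>"], "maxlen": 1}

# 2) Text decoding: <lang> and <task> are already known and fixed in the prefix;
#    the model then predicts <time> (optionally) and text until <eos>.
text_setup = {"hyp_primer": ["<sos>", "<eng>", "<asr>"], "maxlen": None}
```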

@sw005320 merged commit 35c2e2b into espnet:master Jan 19, 2024
27 checks passed
@sw005320 (Contributor) commented:
Thanks, @pyf98!

@pyf98 deleted the owsm branch Jan 19, 2024 19:08