Add Whisper SOT recipe for Librimix #5371

Merged
merged 35 commits into espnet:master on Sep 28, 2023

Conversation

LiChenda (Contributor):

What?

This PR adds the Whisper SOT-style multi-talker ASR recipe for the Librimix dataset.

Why?

To add the multi-talker Whisper recipe.

@mergify mergify bot added the ESPnet2 label Jul 25, 2023
@sw005320 sw005320 added New Features ASR Automatic speech recognition labels Jul 25, 2023
@sw005320 sw005320 added this to the v.202307 milestone Jul 25, 2023
@simpleoier (Collaborator) left a comment:

Thanks! It looks good to me in general.

--train_set "${train_set}" \
--valid_set "${valid_set}" \
--test_sets "${test_sets}" \
--lm_train_text "data/${train_set}/text_spk1 data/${train_set}/text_spk2 data/local/other_text/text" \
Collaborator:

Can we simply use ${train_set}/text, which already contains the <sc> tokens in the text?

Contributor Author (LiChenda):

LM is not used in this recipe. I'll remove it.

--valid_set "${valid_set}" \
--test_sets "${test_sets}" \
--lm_train_text "data/${train_set}/text_spk1 data/${train_set}/text_spk2 data/local/other_text/text" \
--bpe_train_text "data/${train_set}/text_spk1 data/${train_set}/text_spk2" "$@"
Collaborator:

Ditto. (I'm not sure about the LM, but for BPE it should be fine to use the text file.)

@@ -347,7 +348,7 @@ def __init__(
if bpemodel not in ["whisper_en", "whisper_multilingual"]:
converter = TokenIDConverter(token_list=token_list)
else:
-converter = OpenAIWhisperTokenIDConverter(model_type=bpemodel)
+converter = OpenAIWhisperTokenIDConverter(model_type=bpemodel, sot=sot_asr)
Collaborator:

Is it necessary to check whether <sc> is in the token list in this case? Or has that already been done in the token_id_converter class?

Contributor Author @LiChenda (Jul 26, 2023):

If sot_asr is True, the OpenAIWhisperTokenIDConverter will add the <sc> token in its init function.
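For illustration, here is a minimal sketch of that behavior; it is not the actual OpenAIWhisperTokenIDConverter implementation, and the class and attribute names are simplified stand-ins:

# Illustrative sketch only: a converter that registers the speaker-change
# token in its vocabulary at init time when SOT mode is enabled.
from typing import List

class ToySotTokenIDConverter:
    def __init__(self, base_vocab: List[str], sot: bool = False, sc_token: str = "<sc>"):
        self.token_list = list(base_vocab)
        if sot and sc_token not in self.token_list:
            # <sc> is added once here, so callers do not need an extra check.
            self.token_list.append(sc_token)
        self.token2id = {tok: i for i, tok in enumerate(self.token_list)}

    def tokens2ids(self, tokens: List[str]) -> List[int]:
        return [self.token2id[tok] for tok in tokens]

# Usage: the <sc> id exists only because sot=True was passed at construction.
conv = ToySotTokenIDConverter(["hello", "world"], sot=True)
print(conv.tokens2ids(["hello", "<sc>"]))  # -> [0, 2]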

Collaborator:

In the inference part, we can retrieve the value of sot_asr from the asr_train_args, which eliminates the need for duplicate params in both train_args and decode_args, making them independent of each other.
In most cases, I think decode_args should only contain the params specific to inference, like beam_size, lm_weight, etc., without including those related to model or tokenizer initialization.
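As an illustration of that split (the attribute and function names below are hypothetical, not ESPnet's actual ones), the inference entry point can read SOT settings from the saved training config while decode options stay limited to search parameters:

# Illustrative only: model/tokenizer settings come from the training args,
# while decode-time options carry only inference-specific knobs.
from types import SimpleNamespace

asr_train_args = SimpleNamespace(sot_asr=True, bpemodel="whisper_multilingual")

def build_inference_config(asr_train_args, beam_size: int = 5, lm_weight: float = 0.0):
    sot_asr = getattr(asr_train_args, "sot_asr", False)  # taken from training config, not decode args
    return {"sot": sot_asr, "beam_size": beam_size, "lm_weight": lm_weight}

print(build_inference_config(asr_train_args, beam_size=10))
# -> {'sot': True, 'beam_size': 10, 'lm_weight': 0.0}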

Contributor Author (LiChenda):

Thanks for the suggestion! I updated asr_inference.py, and now it gets sot_asr from asr_train_args.

@LiChenda LiChenda marked this pull request as ready for review July 27, 2023 12:23
@mergify mergify bot added the README label Jul 27, 2023
LiChenda and others added 2 commits July 27, 2023 20:23
Co-authored-by: Wangyou Zhang <C0me_On@163.com>

codecov bot commented Jul 27, 2023

Codecov Report

Merging #5371 (18903d4) into master (8a8709e) will increase coverage by 0.00%.
Report is 1 commit behind head on master.
The diff coverage is 82.69%.

@@           Coverage Diff           @@
##           master    #5371   +/-   ##
=======================================
  Coverage   77.17%   77.17%           
=======================================
  Files         684      684           
  Lines       62643    62686   +43     
=======================================
+ Hits        48343    48380   +37     
- Misses      14300    14306    +6     
Flag Coverage Δ
test_configuration_espnet2 ∅ <ø> (∅)
test_integration_espnet1 65.73% <ø> (ø)
test_integration_espnet2 49.07% <20.00%> (-0.03%) ⬇️
test_python_espnet1 19.94% <0.00%> (-0.02%) ⬇️
test_python_espnet2 52.30% <82.69%> (+0.01%) ⬆️
test_utils 23.10% <ø> (ø)

Flags with carried forward coverage won't be shown.

Files Coverage Δ
espnet2/asr/decoder/whisper_decoder.py 94.66% <100.00%> (+1.56%) ⬆️
espnet2/asr/encoder/whisper_encoder.py 80.45% <100.00%> (ø)
espnet2/bin/whisper_export_vocabulary.py 92.59% <100.00%> (+0.75%) ⬆️
espnet2/text/whisper_token_id_converter.py 87.87% <100.00%> (+2.69%) ⬆️
espnet2/text/whisper_tokenizer.py 88.23% <100.00%> (+2.52%) ⬆️
espnet2/text/build_tokenizer.py 78.37% <0.00%> (ø)
espnet2/bin/asr_inference.py 86.98% <0.00%> (-0.68%) ⬇️
espnet2/train/preprocessor.py 77.53% <16.66%> (-0.39%) ⬇️

... and 1 file with indirect coverage changes


for i in range(full_vocab_size - vocab_size):
fout.write("()" + "\n")

if sot_asr:
full_vocab_size += 1
fout.write("<sc>" + "\n")
Collaborator:

Is it possible to import custom tokens from a file, like additional_vocab? This would allow us to add any tokens as needed, rather than only adding a token for SOT training.

Contributor Author (LiChenda):

Whisper's vocabulary is not imported from a file but loaded from the whisper PyPI package. So, here I add the special token in the code.

Contributor:

I think it would be better to have at least a way to specify which special token is used as the speaker change. The hard-coded token symbol may confuse the user.

I guess this would depend on other places, so if it is difficult to make this speaker-change token a variable, we could at least add a comment in the source code and README, in an appropriate place, noting that it is a reserved token.

Contributor Author (LiChenda):

Thanks for your suggestions! Now it's configurable.
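As a rough sketch of what a configurable version could look like (this is not the actual espnet2/bin/whisper_export_vocabulary.py; the function and parameter names are hypothetical), the speaker-change symbol and any extra tokens are passed in rather than hard-coded:

# Illustrative only: pad the exported vocabulary to the model's full size and
# append a configurable speaker-change token plus optional extra tokens.
from typing import List, Optional

def export_vocabulary(
    tokens: List[str],
    full_vocab_size: int,
    out_path: str,
    sot_asr: bool = False,
    sc_token: str = "<sc>",
    additional_vocab: Optional[List[str]] = None,
):
    with open(out_path, "w", encoding="utf-8") as fout:
        for tok in tokens:
            fout.write(tok + "\n")
        # Placeholder entries up to the model's full vocabulary size.
        for _ in range(full_vocab_size - len(tokens)):
            fout.write("()" + "\n")
        if sot_asr:
            fout.write(sc_token + "\n")
        for tok in additional_vocab or []:
            fout.write(tok + "\n")

export_vocabulary(["a", "b", "c"], full_vocab_size=5, out_path="tokens.txt", sot_asr=True)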

@kan-bayashi kan-bayashi modified the milestones: v.202307, v.202312 Aug 3, 2023

mergify bot commented Aug 3, 2023

This pull request is now in conflict :(

@mergify mergify bot added the conflicts label Aug 3, 2023
@mergify mergify bot removed the conflicts label Aug 9, 2023
@sw005320 sw005320 changed the title [WIP] Add Whisper SOT recipe for Librimix Add Whisper SOT recipe for Librimix Aug 9, 2023
@sw005320 (Contributor):

Thanks, @LiChenda!
@pengchengguo, can you review it again?
Do we need to take care of the other part for the symbol?

@sw005320 (Contributor):

@pengchengguo, this is a reminder. Can you review this PR again?

@pengchengguo (Collaborator) left a comment:

I only have a few suggestions, and the rest looks good to me.

@@ -188,7 +188,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
done

paste -d "" \
-<(<data/${dset}/text_spk1 awk '{$0=$0" <sc>"; print($0)}') \
+<(<data/${dset}/text_spk1 awk '{$0=$0" <sc> "; print($0)}') \
Collaborator:

Xuankai and I have considered whether a space should be added after "<sc>". We don't know if it matters or how the original paper does it. Do you have any comments?

Contributor Author (LiChenda):

I think there should be no essential difference, but the second version makes the text look more natural. My uploaded pre-trained model was trained with a space after "<sc>". Do you have any technical concerns about that space?

for i in range(full_vocab_size - vocab_size):
fout.write("()" + "\n")

if sot_asr:
Collaborator:

The exported token.txt file includes 50257 normal tokens, 1501 "()" placeholders (which should be timestamps), and 1 <sc> token.
During training, the tokenizer adds an additional 1501 timestamp tokens and 1 <sc> token, as implemented in whisper_token_id_converter.py and whisper_tokenizer.py.
Although the token.txt file will not be used in practice, it is better to change the "()" placeholders to real timestamps and make the file consistent with the training process.
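For illustration only (the sizes are taken from this comment, not from the actual export script), the layout being requested lists real timestamp strings instead of placeholders, so the file matches what the tokenizer uses during training:

# Illustrative only: 50257 base tokens, 1501 real timestamp tokens, then <sc>.
base_vocab_size = 50257
timestamps = [f"<|{i * 30 / 1500:.2f}|>" for i in range(1501)]  # <|0.00|> ... <|30.00|>
token_list = [f"<token_{i}>" for i in range(base_vocab_size)] + timestamps + ["<sc>"]
assert len(token_list) == 50257 + 1501 + 1  # 51759 entries in total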

Contributor Author (LiChenda):

Thanks for your comments; I also feel that would be better. Now updated.

@sw005320 (Contributor) left a comment:

Minor comments

@@ -61,6 +68,16 @@ def __init__(self, model_type: str, language: str = "en"):
else:
raise ValueError("tokenizer unsupported:", model_type)

self.tokenizer = copy.deepcopy(self.tokenizer)
timestamps = [f"<|{i*30/1500:.2f}|>" for i in range(0, 1501)]
Contributor:

I'm assuming that 30 comes from Whisper's 30-second segment length and 1500 comes from the 20 ms shift.
It's okay to embed such numbers, but it would be great to add some comments (logging?) and also make them variables inside the function.
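A hedged sketch of that suggestion (the variable names are illustrative, not the actual ESPnet code):

# Illustrative only: name the magic numbers so their origin is documented.
SEGMENT_LENGTH_SEC = 30      # Whisper processes audio in 30-second segments
NUM_TIMESTAMP_STEPS = 1500   # 30 s / 0.02 s frame shift = 1500 steps

timestamps = [
    f"<|{i * SEGMENT_LENGTH_SEC / NUM_TIMESTAMP_STEPS:.2f}|>"
    for i in range(NUM_TIMESTAMP_STEPS + 1)  # include both <|0.00|> and <|30.00|>
]
assert timestamps[0] == "<|0.00|>" and timestamps[-1] == "<|30.00|>"
assert len(timestamps) == 1501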

Contributor Author (LiChenda):

Thank you for pointing it out! I updated the code according to your comments.

@sw005320 sw005320 merged commit 522fb13 into espnet:master Sep 28, 2023
25 checks passed
Labels
ASR Automatic speech recognition ESPnet2 New Features README

6 participants