
Why T>=S constraint? #20

Closed
BuaaAlban opened this issue Dec 6, 2022 · 15 comments

Comments

@BuaaAlban

code

Why do we need this constraint? In a regular RNN-T, the joint network normally emits many blank symbols, and in that case T > S. But it's also possible that S > T, e.g. if we emit at least one non-blank symbol for each encoder frame.

I have actually run into this:

      File "/rnnt_related/rnnt-mlperf-training/model_rnnt.py", line 203, in fast_joint
        simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
      File "/anaconda3/envs/fast-rnnt/lib/python3.8/site-packages/fast_rnnt-1.2-py3.8-linux-x86_64.egg/fast_rnnt/rnnt_loss.py", line 282, in rnnt_loss_simple
        px, py = get_rnnt_logprobs(
      File "/anaconda3/envs/fast-rnnt/lib/python3.8/site-packages/fast_rnnt-1.2-py3.8-linux-x86_64.egg/fast_rnnt/rnnt_loss.py", line 149, in get_rnnt_logprobs
        assert T >= S, (T, S)
    AssertionError: (272, 274)

@csukuangfj

> In a regular rnnt

As you have mentioned, that is for regular RNN-T.


The version we are using is not regular. It has the same condition as CTC training, i.e., S <= T.
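A small sketch of why the condition differs between the two topologies, counting alignments in closed form. It assumes the usual conventions: in the regular RNN-T lattice each blank advances one frame, any number of labels may be emitted per frame, and the final emission is a blank; in the "modified" topology (described later in this thread) every frame emits exactly one symbol, blank or label. The function names are hypothetical and not part of fast_rnnt.

```python
from math import comb


def count_regular_alignments(T: int, S: int) -> int:
    """Regular RNN-T: an alignment interleaves S labels with T blanks.
    Each blank moves to the next frame, so the last emission must be the
    blank that leaves frame T-1; the other T-1 blanks and the S labels
    can be ordered freely."""
    if T < 1 or S < 0:
        return 0
    return comb(T - 1 + S, S)


def count_modified_alignments(T: int, S: int) -> int:
    """Modified topology: every frame emits exactly one symbol (blank or
    label), so an alignment picks which S of the T frames carry a label."""
    if T < 0 or S < 0:
        return 0
    return comb(T, S)  # math.comb returns 0 when S > T


if __name__ == "__main__":
    T, S = 272, 274  # the lengths from the AssertionError above
    print(count_regular_alignments(T, S) > 0)  # True: S > T is fine here
    print(count_modified_alignments(T, S))     # 0: no valid alignment
```

With S > T the modified lattice has no path at all, so its loss would be undefined; the regular lattice still has many valid alignments.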

@csukuangfj

Here is the paper about fast_rnnt:

https://arxiv.org/pdf/2206.13236.pdf

@csukuangfj

Here is the code to filter data that don't satisfy S<=T in icefall:
https://github.com/k2-fsa/icefall/blob/f13cf61b05432a989e6a42c95b843a56639bcbde/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L958

        # In ./conformer.py, the conv module uses the following expression
        # for subsampling
        T = ((c.num_frames - 1) // 2 - 1) // 2
        tokens = sp.encode(c.supervisions[0].text, out_type=str)

        if T < len(tokens):
            logging.warning(
                f"Exclude cut with ID {c.id} from training. "
                f"Number of frames (before subsampling): {c.num_frames}. "
                f"Number of frames (after subsampling): {T}. "
                f"Text: {c.supervisions[0].text}. "
                f"Tokens: {tokens}. "
                f"Number of tokens: {len(tokens)}"
            )
            return False
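The same check can be written as a standalone helper. This is a sketch that assumes the encoder subsamples with the expression used in the icefall snippet above (the output length of two kernel-3, stride-2 convolutions without padding, roughly a 4x reduction). The helper names are hypothetical, and other encoders (for example the MLPerf RNN-T model in the traceback) may subsample differently, so the formula has to match your own front end.

```python
def frames_after_subsampling(num_frames: int) -> int:
    """Encoder output length, using the expression from the icefall comment."""
    return ((num_frames - 1) // 2 - 1) // 2


def keeps_t_ge_s(num_frames: int, num_tokens: int) -> bool:
    """Return True if the utterance satisfies S <= T after subsampling."""
    return frames_after_subsampling(num_frames) >= num_tokens


# An utterance with 1092 input frames yields T = 272 encoder frames, so a
# 274-token transcript would be excluded while a 272-token one is kept.
print(frames_after_subsampling(1092))
print(keeps_t_ge_s(1092, 274))
print(keeps_t_ge_s(1092, 272))
```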

@BuaaAlban
Author

Thanks for your fast reply.
I modified my code based on this example, which I think is a normal transducer. I can filter the data as you suggest to make it work. I just wonder why we have this limitation (is it for optimization? I actually read your paper yesterday but didn't notice this condition; I will double-check). Could I just comment out this assert to make the pruned loss work like a regular rnnt_loss (as in torchaudio or warp-transducer)?

@desh2608

@BuaaAlban as you noted, this constraint is indeed not required for the "regular" RNNT topology. It is required only if you train with the "modified" topology, where you are constrained to emit exactly 1 symbol per time frame. We have a PR here (k2-fsa/k2#1149) to remove this constraint from k2. I will also make a similar PR for fast_rnnt.
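To illustrate the distinction, here is a hypothetical length check that enforces `T >= S` only for the modified topology. The function name and the `rnnt_type` strings are illustrative assumptions, not the fast_rnnt API; the linked k2 PR contains the actual change.

```python
def check_lengths(T: int, S: int, rnnt_type: str = "regular") -> None:
    """Validate encoder length T against label length S.

    regular:  any number of labels per frame, so S > T is allowed.
    modified: exactly one symbol per frame, so S <= T is required.
    """
    if rnnt_type not in ("regular", "modified"):
        raise ValueError(f"unknown rnnt_type: {rnnt_type!r}")
    if T < 1:
        raise ValueError(f"need at least one frame, got T={T}")
    if rnnt_type == "modified" and T < S:
        raise ValueError(f"modified topology needs T >= S, got {(T, S)}")


check_lengths(272, 274, "regular")  # passes
try:
    check_lengths(272, 274, "modified")
except ValueError as err:
    print(err)  # modified topology needs T >= S, got (272, 274)
```

Raising `ValueError` rather than using `assert` keeps the check active even when Python runs with `-O`, which strips assert statements.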

@arkadyark

@desh2608 are you still planning to make this PR? This would be very useful for my work!

@desh2608

desh2608 commented May 1, 2023

@arkadyark sorry, I forgot to actually push the changes. BTW, I believe Dan fixed some OOM issues in the pruned transducer loss in k2, and those fixes haven't been merged into fast_rnnt yet. So you may want to make those changes yourself.

@arkadyark

Thanks! Which changes are you referring to? Looking through recent changes to rnnt_loss.py, I don't see anything there.

@desh2608

desh2608 commented May 1, 2023

> Thanks! Which changes are you referring to? Looking through recent changes to rnnt_loss.py I don't see anything there.

Check k2-fsa/k2#1177 and k2-fsa/k2#1183

@danpovey
Collaborator

danpovey commented May 2, 2023

Ah yes. Arkady, it would be great if you could make a PR to fast_rnnt with those changes; I had forgotten about that. If not, LMK and I'll ask someone here.

@arkadyark

I would love to contribute those back, but unfortunately there's a fairly involved open-source contribution process at my organization that would take a while; it'd probably be best to find someone else to do so.

However, I did test this out locally and re-ran the benchmarks at https://github.com/csukuangfj/transducer-loss-benchmarking. The results look really good: peak memory usage goes from 3820 all the way down to 1182 (!), and from 2647 to 835 when sorting utterances. Step time (on my hardware) went from 343k to 280k µs.
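The reported figures translate into relative savings as follows. This sketch only does the arithmetic on the numbers quoted above; memory units are whatever the benchmark reports, and step time is in µs.

```python
# (before, after) pairs quoted in the comment above
results = {
    "peak memory, unsorted": (3820, 1182),
    "peak memory, sorted": (2647, 835),
    "step time (us)": (343_000, 280_000),
}


def reduction(before: float, after: float) -> float:
    """Relative reduction as a percentage of the original value."""
    return 100.0 * (before - after) / before


for name, (before, after) in results.items():
    print(f"{name}: {reduction(before, after):.1f}% lower")
```

That works out to roughly a 68 to 69% cut in peak memory in both settings and about an 18% cut in step time.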

Pretty cool! Always gotta be careful with those torch.gathers.

@arkadyark

Hey @danpovey, just wanted to follow up: is anybody able to make those changes here?

@danpovey
Collaborator

@pkufool could you please have a look at this?

@pkufool
Contributor

pkufool commented Jul 19, 2023

@danpovey Yifan has already made PRs #26 and #24 here; you can merge them.

@pkufool
Contributor

pkufool commented Aug 25, 2023

closed by #29

pkufool closed this as completed Aug 25, 2023
6 participants