Broken Alignment for CTC example? #127

Open
angusturner opened this issue Jul 5, 2023 · 2 comments

Comments

@angusturner

angusturner commented Jul 5, 2023

Hi,

Firstly, I just wanted to say this is a really cool library. I have been working on some CTC/alignment research, and when I saw this trick with the parallel scan and semiring, it struck me as a very elegant solution.

I know the CTC example is a bit out of date (as referenced in other issues), but I am wondering how involved it would be to fix it. I am hoping to compare the answers, partly for my own understanding and partly to see what speedups I can get from the parallel scan + custom CUDA kernels.

Furthermore, I wonder if there is a bug in the argmax decoding shown in the CTC notebook: it seems like one of the frames is aligned to two characters (unless I am misinterpreting this plot).

[Screenshot (2023-07-05): argmax alignment plot from the CTC notebook]

Would really appreciate any pointers with this if you get a chance.

@angusturner
Author

Actually, while I'm here, could you also clarify the interpretation of the dimensions referenced in the docs?

event_shape (N x M x 3), e.g.
phi(i, j, op)
Ops are 0 -> j-1, 1->i-1,j-1, and 2->i-1

I am a bit confused about how to interpret this. For example, is the interpretation of phi[i, j, 0] something like "Given that we are at frame j in state i, what is the log-probability that we arrived from (i, j-1)"?

@srush
Collaborator

srush commented Jul 5, 2023

Oh interesting. Yes, I should update these examples for PyTorch 2. That might speed things up a lot.

> Furthermore, I wonder if there is a bug in the argmax decoding shown in the CTC notebook, where it seems like one of the frames is aligned to two characters? (Unless I am misinterpreting this plot).

I don't think it's a bug. I guess you're right that in speech you would never want this to happen, but the way I set up the problem, I didn't forbid this behavior. You could do so by setting the score of the "down step" motion (i-1, j) to -inf.
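To make the suggested fix concrete, here is a minimal plain-Python Viterbi sketch. It assumes the convention from the docs quoted above: i indexes characters, j indexes frames, and phi[i][j][op] scores arriving at (i, j) via op 0 from (i, j-1), op 1 from (i-1, j-1), or op 2 from (i-1, j), with paths running from (0, 0) to (N-1, M-1). The helper names `viterbi` and `mask_down_steps` are hypothetical, and this is a simple O(NM) loop, not the library's parallel-scan implementation.

```python
import math

NEG_INF = -math.inf

# Assumed op convention (from the docs): the cell each op arrives from.
#   op 0: (i, j-1)    same character, next frame
#   op 1: (i-1, j-1)  next character, next frame
#   op 2: (i-1, j)    "down step": next character, SAME frame
PRED = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}


def viterbi(phi):
    """Max-scoring path from (0, 0) to (N-1, M-1); phi is an N x M x 3 nested list."""
    n, m = len(phi), len(phi[0])
    best = [[NEG_INF] * m for _ in range(n)]
    back = [[None] * m for _ in range(n)]
    best[0][0] = 0.0
    for i in range(n):  # row-major order visits every predecessor first
        for j in range(m):
            if (i, j) == (0, 0):
                continue
            for op, (di, dj) in PRED.items():
                pi, pj = i + di, j + dj
                if pi < 0 or pj < 0:
                    continue
                cand = best[pi][pj] + phi[i][j][op]
                if cand > best[i][j]:  # strict: ties keep the earlier op
                    best[i][j] = cand
                    back[i][j] = (pi, pj)
    # Follow back-pointers from the end cell to recover (character, frame) cells.
    path, cell = [], (n - 1, m - 1)
    while cell is not None:
        path.append(cell)
        cell = back[cell[0]][cell[1]]
    return best[n - 1][m - 1], path[::-1]


def mask_down_steps(phi):
    """Copy of phi with op 2 set to -inf, so no frame can hold two characters."""
    return [[[NEG_INF if op == 2 else s for op, s in enumerate(cell)]
             for cell in row] for row in phi]


# 2 characters, 3 frames; reward a down step into cell (1, 0).
phi = [[[0.0] * 3 for _ in range(3)] for _ in range(2)]
phi[1][0][2] = 5.0

score, path = viterbi(phi)
print(score, path)  # 5.0 [(0, 0), (1, 0), (1, 1), (1, 2)]  (frame 0 has two characters)

masked_score, masked_path = viterbi(mask_down_steps(phi))
print(masked_score, masked_path)  # 0.0 [(0, 0), (1, 1), (1, 2)]
```

With op 2 masked, every step advances exactly one frame, so each frame is aligned to exactly one character. The trade-off is that a path then needs at least as many frames as characters (M >= N); otherwise the end cell scores -inf.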

> For example, is the interpretation of phi[i, j, 0] something like "Given that we are at frame j in state i, what is the log-probability we arrived from i, j-1" ?

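The model is specified as a CRF, which means it is globally normalized. So you can set these as log-probs if you want, but they can be any score. The algorithm computes p(alignment), which is exp(sum of scores along the chosen path) divided by the sum over all paths of exp(sum of scores along that path):

$$p(z) = \frac{\exp\left(\sum_{(i,j,op) \in z} \phi(i,j,op)\right)}{\sum_{x \in \mathcal{X}} \exp\left(\sum_{(i,j,op) \in x} \phi(i,j,op)\right)}$$

where $\mathcal{X}$ is the set of all alignment paths.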
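As a small illustration of this global normalization, the sketch below brute-forces the formula on a tiny grid. It assumes the same op convention as the docs (op 0 arrives from (i, j-1), op 1 from (i-1, j-1), op 2 from (i-1, j)) and paths from (0, 0) to (N-1, M-1). The names `all_paths` and `path_probs` are hypothetical. Enumeration is exponential in the grid size, so this is only a check; a dynamic program computes the same denominator in O(NM).

```python
import math

# Same assumed op convention, written as the forward move (di, dj) each op makes.
#   op 0: (0, +1)   op 1: (+1, +1)   op 2: (+1, 0)
STEP = {0: (0, 1), 1: (1, 1), 2: (1, 0)}


def all_paths(n, m, i=0, j=0):
    """Yield every path from (i, j) to (n-1, m-1) as a list of (i, j, op) arrivals."""
    if (i, j) == (n - 1, m - 1):
        yield []
        return
    for op, (di, dj) in STEP.items():
        ni, nj = i + di, j + dj
        if ni < n and nj < m:
            for rest in all_paths(n, m, ni, nj):
                yield [(ni, nj, op)] + rest


def path_probs(phi):
    """Globally normalized p(path) = exp(score(path)) / sum over all paths x of exp(score(x))."""
    n, m = len(phi), len(phi[0])
    paths = list(all_paths(n, m))
    scores = [sum(phi[i][j][op] for i, j, op in p) for p in paths]
    top = max(scores)  # subtract the max before exponentiating, for stability
    log_z = top + math.log(sum(math.exp(s - top) for s in scores))
    return [(p, math.exp(s - log_z)) for p, s in zip(paths, scores)]


# Arbitrary scores that are clearly not log-probs (many are positive).
phi = [[[float(i + 2 * j - op) for op in range(3)] for j in range(3)] for i in range(2)]
probs = path_probs(phi)
print(len(probs), round(sum(p for _, p in probs), 6))  # 5 1.0
```

The scores here are deliberately not log-probabilities, yet every path gets a positive probability and the probabilities sum to one, because the normalizer sums over all paths.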

Speed

I spent a ton of time trying to make this fast in PyTorch and I think I eventually gave up. I think this one in JAX is probably a better bet: https://github.com/spetti/SMURF
