fix: 🐛 SAR decoder indices #578
Conversation
Hey there 👋 If that's indeed an issue, we'll need to fix the PyTorch implementation as well!
Thanks! Mind doing the same for PyTorch as well please?
-embeded_symbol = self.embed(tf.one_hot(symbol, depth=self.vocab_size + 1), **kwargs)
+embeded_symbol = self.embed(tf.one_hot(symbol, depth=self.vocab_size + 2), **kwargs)
Would you mind changing the PyTorch implementation as well please? 🙏
Hi @khalidMindee, I am not sure this is an issue; in the paper, they mention this for the embedding of the hidden state:
Which means they have a size of len(vocab) + 1 for the embedding (EOS being len(vocab)), and the start symbol (len(vocab) + 1) does not seem to be embedded. Am I missing something there 🤔?
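A minimal sketch of the behaviour under discussion, assuming character indices run from 0 to vocab_size - 1, EOS = vocab_size and SOS = vocab_size + 1 as described above (toy values, not the doctr code):

```python
import tensorflow as tf

vocab_size = 4          # toy vocab for illustration
eos = vocab_size        # EOS index, as discussed above
sos = vocab_size + 1    # SOS index, out of range for a depth of vocab_size + 1

# With depth = vocab_size + 1, the SOS index falls outside [0, depth),
# so tf.one_hot returns an all-zero vector for it (documented behaviour).
print(tf.one_hot(sos, depth=vocab_size + 1).numpy())  # [0. 0. 0. 0. 0.]

# With depth = vocab_size + 2, SOS gets its own dedicated slot.
print(tf.one_hot(sos, depth=vocab_size + 2).numpy())  # [0. 0. 0. 0. 0. 1.]
```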
For clarification, let me elaborate a bit because there are a few things to discuss:
I see your point @fg-mindee, we can indeed switch to a depth of vocab_size + 2, but I think we also need to modify it there:
And if we do this switch, we also need to retrain this model, since it modifies this dense layer, or am I mistaken?
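A hypothetical sketch of why the depth change would touch the trained weights (a plain Dense layer stands in for the actual doctr embedding here, so names and shapes are illustrative only):

```python
import tensorflow as tf

vocab_size, embedding_units = 4, 8  # toy values, for illustration only

# Hypothetical stand-in for the decoder's embedding layer.
embed = tf.keras.layers.Dense(embedding_units, use_bias=False)

one_hot = tf.one_hot([2], depth=vocab_size + 1)  # shape (1, vocab_size + 1)
_ = embed(one_hot)                               # builds the layer's kernel
print(embed.kernel.shape)                        # (5, 8), i.e. (vocab_size + 1, embedding_units)

# With depth = vocab_size + 2 the kernel would need shape (6, 8) instead,
# so a checkpoint trained with the old depth could no longer be loaded as-is.
```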
Well, we have two options:
We need to check how the symbol is meant to be initialized, and act accordingly :)
This is what the paper mentions for the SOS symbol: "The encoder and decoder do not share parameters. Initially, the holistic feature hW is fed into the decoder LSTM at time step 0. Then a "START" token is input into the LSTM at step 1. From step 2, the output of the previous step is fed into the LSTM until the "END" token is received. All the LSTM inputs are represented by one-hot vectors, followed by a linear transformation Ψ(). During training, the inputs of the decoder LSTMs are replaced by the ground-truth character sequence. The outputs are computed by the following transformation: y_t = φ(h'_t, g_t) = softmax(W_o [h'_t; g_t]) (1), where h'_t is the current hidden state and g_t is the output of the attention module. W_o is a linear transformation, which embeds features into the output space of 94 classes, corresponding to 10 digits, 52 case-sensitive letters, 31 punctuation characters, and an "END" token."
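A rough sketch of the quoted output transformation (1), with made-up dimensions and a Keras Dense layer standing in for the paper's W_o (illustrative only, not the doctr implementation):

```python
import tensorflow as tf

hidden_units, num_classes = 512, 94   # 94 classes as in the quoted excerpt

# Hypothetical stand-in for W_o in equation (1).
W_o = tf.keras.layers.Dense(num_classes)

h_t = tf.random.normal((1, hidden_units))   # current hidden state h'_t
g_t = tf.random.normal((1, hidden_units))   # attention output g_t

# y_t = softmax(W_o [h'_t ; g_t])
y_t = tf.nn.softmax(W_o(tf.concat([h_t, g_t], axis=-1)))
print(y_t.shape)  # (1, 94)
```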
So actually, we should clarify the code comment, but the code should stay the same. Or did I miss something?
@charlesmindee ? (just to know whether we should close the PR, or iterate on it before the 0.4.1)
@charlesmindee do you think we should close this PR?
I think we should keep it this way since it seems to stick to the paper's description, but again I may be mistaken here.
Alright @khalidMindee, would you mind editing the comment above this line to specify that this is on purpose, so that the one-hot doesn't have any non-zero values? (same in PyTorch if possible) Or do you prefer that we close this PR and handle this on our own?
Yes, no problem with closing the PR.
The question was more about the other way: is it OK for you to edit this PR and adapt the comment instead? 😅
Closing this in favour of #617
Fixed indices: values ranging from 0 to vocab_size + 1 imply a one-hot vector embedding with a depth of vocab_size + 2.
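In other words (toy example, assuming EOS = vocab_size and SOS = vocab_size + 1):

```python
vocab = "abcd"               # toy vocabulary for illustration
vocab_size = len(vocab)      # character indices: 0 .. vocab_size - 1
eos = vocab_size             # vocab_size     -> EOS
sos = vocab_size + 1         # vocab_size + 1 -> SOS

# Indices span 0 .. vocab_size + 1, i.e. vocab_size + 2 distinct values,
# so a one-hot encoding covering all of them needs depth = vocab_size + 2.
depth = vocab_size + 2
```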