
fix: 🐛 SAR decoder indices #578

Closed · wants to merge 2 commits

Conversation

@khalidMindee (Contributor) commented Nov 4, 2021

Fixed: indices ranging from 0 to vocab_size + 1 imply a one-hot vector embedding with a depth of vocab_size + 2.

@fg-mindee self-requested a review November 4, 2021 12:04
@fg-mindee self-assigned this Nov 4, 2021
@fg-mindee added the "type: bug" (Something isn't working) and "module: models" (Related to doctr.models) labels Nov 4, 2021
@fg-mindee added this to the 0.4.1 milestone Nov 4, 2021
@fg-mindee changed the title from "fix: 🐛 indices starting from 0 to vocab_size+1 means a one hot vector embedding with depth of v…" to "fix: 🐛 SAR decoder indices" Nov 4, 2021
@fg-mindee (Contributor)

Hey there 👋

If that's indeed an issue, we'll need to fix the PyTorch implementation as well!
However, the loop has an extra iteration for the SOS symbol, so I'd like to check with @charlesmindee on this one!

@fg-mindee added the "framework: tensorflow" (Related to TensorFlow backend) and "topic: text recognition" (Related to the task of text recognition) labels Nov 4, 2021
@fg-mindee (Contributor) left a comment

Thanks! Mind doing the same for PyTorch as well please?

- embeded_symbol = self.embed(tf.one_hot(symbol, depth=self.vocab_size + 1), **kwargs)
+ embeded_symbol = self.embed(tf.one_hot(symbol, depth=self.vocab_size + 2), **kwargs)
@fg-mindee (Contributor)

Would you mind changing the PyTorch implementation as well please? 🙏

@charlesmindee (Collaborator) commented Nov 5, 2021

Hi @khalidMindee,

I am not sure this is an issue; in the paper they mention this for the embedding of the hidden state:

The outputs are computed by the following transformation:

y_t = φ(h'_t, g_t) = softmax(W_o [h'_t; g_t])

where h'_t is the current hidden state and g_t is the output of the attention module.
W_o is a linear transformation, which embeds features into the output space of 94 classes,
corresponding to 10 digits, 52 case-sensitive letters, 31 punctuation characters, and an “END” token.

Which means the embedding has a size of len(vocab) + 1 (for EOS, whose index is len(vocab)), and the start symbol (index len(vocab) + 1) does not seem to be embedded.

Am I missing something here 🤔?

@fg-mindee (Contributor)

For clarification, let me elaborate a bit because there are a few things to discuss:

  • the symbol is initialized at self.vocab_size + 1. When we switch to one-hot, if depth=self.vocab_size + 1, TF accepts the out-of-bounds index and leaves the whole one-hot vector as zeros (see the sketch after this list). The question in this PR is: is that on purpose @charlesmindee?
  • should we include SOS, or any extra token? According to the paper, I'd say no
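
For reference, here is a minimal TensorFlow sketch of the behaviour described above (toy vocab_size, not doctr code): with depth=vocab_size + 1, the index vocab_size + 1 is out of bounds and tf.one_hot silently returns an all-zero vector, whereas depth=vocab_size + 2 gives it a dedicated slot.

import tensorflow as tf

vocab_size = 4                                   # toy value for illustration
symbol = tf.constant(vocab_size + 1)             # index used to initialize the decoding symbol

# out-of-bounds index: tf.one_hot does not raise, it returns all zeros
print(tf.one_hot(symbol, depth=vocab_size + 1).numpy())  # [0. 0. 0. 0. 0.]
# with one extra slot, the same index maps to its own (SOS-like) position
print(tf.one_hot(symbol, depth=vocab_size + 2).numpy())  # [0. 0. 0. 0. 0. 1.]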

@charlesmindee (Collaborator)

I see your point @fg-mindee, we can indeed switch to a depth of vocab_size + 2, but I think we also need to modify it here:

self.embed = layers.Dense(embedding_units, use_bias=False, input_shape=(None, self.vocab_size + 1))

And if we make this switch, we also need to retrain the model, since it modifies this dense layer, or am I mistaken?
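
As a rough illustration of the retraining point (a hypothetical sketch with toy values, not the actual doctr code): the kernel of the embedding dense layer has shape (one_hot_depth, embedding_units), so bumping the one-hot depth to vocab_size + 2 changes the weight shape and a pretrained checkpoint would no longer match.

import tensorflow as tf
from tensorflow.keras import layers

vocab_size, embedding_units = 4, 8               # toy values for illustration
embed = layers.Dense(embedding_units, use_bias=False, input_shape=(None, vocab_size + 2))

symbol = tf.constant([[vocab_size + 1]])         # batch of one SOS index
embedded = embed(tf.one_hot(symbol, depth=vocab_size + 2))
print(embed.kernel.shape)                        # (6, 8) -> depends on the one-hot depth
print(embedded.shape)                            # (1, 1, 8)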

@fg-mindee (Contributor)

Well we have two options:

  • not changing the dense layer, and only initializing the symbol with self.vocab_size (i.e. with EOS), keeping the one-hot the same way. This is the best option I guess 👈 (sketched below), but it assumes the symbol is meant to be initialized with EOS (if it's supposed to be initialized with no token at all, we need to keep it as is)
  • changing the dense layer, but this effectively adds a class (SOS I guess) and differs from the paper, so I don't think that's the solution

We need to check how the symbol is meant to be initialized, and act accordingly :)
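
A minimal sketch of the first option (hypothetical, toy values, not the actual doctr code): initializing the symbol with the EOS index (self.vocab_size) keeps it within a one-hot of depth vocab_size + 1, so the dense layer and the pretrained weights stay untouched.

import tensorflow as tf

vocab_size = 4                                   # toy value for illustration
symbol = tf.constant(vocab_size)                 # EOS index instead of vocab_size + 1
print(tf.one_hot(symbol, depth=vocab_size + 1).numpy())  # [0. 0. 0. 0. 1.] -> the EOS slot is set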

@charlesmindee (Collaborator) commented Nov 5, 2021

This is what the paper mentions for the SOS symbol:

The encoder and decoder do not share parameters. Initially, the holistic feature hW is fed into the decoder LSTM at time step 0. Then a “START” token is input into the LSTM at step 1. From step 2, the output of the previous step is fed into the LSTM until the “END” token is received. All the LSTM inputs are represented by one-hot vectors, followed by a linear transformation Ψ(). During training, the inputs of the decoder LSTMs are replaced by the ground-truth character sequence. The outputs are computed by the following transformation: y_t = φ(h'_t, g_t) = softmax(W_o [h'_t; g_t]) (1), where h'_t is the current hidden state and g_t is the output of the attention module. W_o is a linear transformation, which embeds features into the output space of 94 classes, corresponding to 10 digits, 52 case-sensitive letters, 31 punctuation characters, and an “END” token.

@fg-mindee (Contributor)

So actually, we should just make the comment in the code more specific, but the code itself should stay the same. Or did I miss something?

@fg-mindee (Contributor)

So actually, we should just make the comment in the code more specific, but the code itself should stay the same. Or did I miss something?

@charlesmindee? (just to know whether we should close the PR, or iterate on it before the 0.4.1 release)

@fg-mindee (Contributor)

@charlesmindee do you think we should close this PR?

@charlesmindee (Collaborator)

I think we should keep it this way since it seems to stick to the paper's description, but again I may be mistaken here.
If you agree, we can indeed close this PR.

@fg-mindee (Contributor)

Alright @khalidMindee, would you mind editing the comment above this line to specify that this is on purpose, so that the one-hot doesn't have any non-zero values? (same in PyTorch if possible)

Or do you prefer that we close this PR and handle this on our own?

@khalidMindee (Contributor, Author)

Yes, no problem with closing the PR.

@fg-mindee (Contributor)

Yes, no problem with closing the PR.

The question was more about the other way: is it OK for you to edit this PR and adapt the comment instead? 😅

@fg-mindee (Contributor)

Closing this in favour of #617

@fg-mindee closed this Nov 12, 2021
@fg-mindee deleted the fix_SARDecoder_embedding branch November 12, 2021 16:35