
BART & FSMT: fix decoder not returning hidden states from the last layer #8597

Merged · 8 commits merged into huggingface:master on Nov 27, 2020

Conversation

MaksymDel
Contributor

What does this PR do?

The activations from the last decoder layer were accidentally not included in the output.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten
@stas00

@MaksymDel MaksymDel changed the title Fix decoder not returning hidden states from the last layer FSMT: fix decoder not returning hidden states from the last layer Nov 17, 2020
@stas00
Contributor

stas00 commented Nov 17, 2020

yay, a first fsmt user that found an issue! Thank you!

OK, here I literally copied the bart implementation where it didn't have that line you added:

    # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
    if output_hidden_states:
        all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)

So most likely if this is indeed a bug then it affects many transformers models.

Now let us diagnose what's going on. I see that x is stored in the loop above at the beginning of each layer iteration:

    for idx, decoder_layer in enumerate(self.layers):
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        if output_hidden_states:
            all_hidden_states += (x,)

Looking closely, the current code doesn't add the x from the last iteration of the for idx, decoder_layer in enumerate(self.layers) loop, which is clearly a bug. We have an off-by-one problem here.

The only thing I'm not sure about is whether we need the x from before the loop; if not, then all_hidden_states += (x,) needs to be moved to the end of the loop. If we do need it, then your change is the right one.

Either way, I'd code it differently: I'd add x before the loop starts if it is needed, and then add it for each layer once we have a new x defined in the loop.

Adding it after the loop is likely to cause other bugs in the future where the wrong x will be added.
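(For illustration: a minimal toy sketch in plain Python, not the actual FSMT/BART code, showing how appending at the top of the loop drops the last layer's output, while appending once before the loop and then after each layer does not.)

    # Toy illustration of the off-by-one (not the actual FSMT code): appending the
    # incoming state at the top of each iteration stores the embedding plus the
    # outputs of every layer *except* the last one.
    def collect_states_buggy(x, layers):
        all_hidden_states = ()
        for layer in layers:
            all_hidden_states += (x,)  # stores the *input* to this layer
            x = layer(x)
        # the output of the final layer is never appended
        return x, all_hidden_states

    def collect_states_fixed(x, layers):
        all_hidden_states = (x,)  # pre-layer state, stored once up front
        for layer in layers:
            x = layer(x)
            all_hidden_states += (x,)  # stored right after it is produced
        return x, all_hidden_states

    toy_layers = [lambda v, i=i: v + i for i in range(4)]  # 4 toy "layers"
    print(len(collect_states_buggy(0, toy_layers)[1]))  # 4 entries: last layer's output missing
    print(len(collect_states_fixed(0, toy_layers)[1]))  # 5 entries: embedding + 4 layers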

Could you please share the use case so that we could write a test for it? Or if you could write the test that's even better - either way works.

I didn't have a use case for this myself, so I relied on the transformers common tests to catch this.

Thank you!

@stas00
Contributor

stas00 commented Nov 17, 2020

So this is what I propose, which does the same as your PR, but respects the locality rule better, if that makes sense.

        # XXX: do we need to save this hidden state?
        if output_hidden_states:
            all_hidden_states += (x,)
        
        for idx, decoder_layer in enumerate(self.layers):
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            layer_state = past_key_values[idx] if past_key_values is not None else None

            x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_attn_mask=encoder_padding_mask,
                decoder_padding_mask=decoder_padding_mask,
                layer_state=layer_state,
                causal_mask=decoder_causal_mask,
                output_attentions=output_attentions,
            )

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (x,)
            

@patrickvonplaten, how should we proceed: solve this for FSMT and then replicate the fix to the other copy-cats, or solve it all at once in a new PR? We'd also need to create a new common test, I suppose. I, unfortunately, have no permissions to make suggestions directly in the code, so I'm passing the flag to you if we go with the former.

@MaksymDel
Contributor Author

MaksymDel commented Nov 17, 2020

Thanks, Stas @stas00!

I implemented a fix the way I did just to be consistent with how the analogous code is written in other places (e.g. FSMTEncoder, BERT model, etc.):

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

However, I would also personally prefer adding the input embedding before the loop and then collecting hidden states at the end of the loop, just like you described. It just has to be changed for all the models in the repo if we want to keep the codebase consistent.

The test might check that the number and shape of the output hidden states match what we expect based on the model configuration. It would catch the error and be general enough for many use cases. It is just a job for a bigger PR if we want to cover all the models in the repo.
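(For illustration, a rough sketch of such a check; model, inputs, and the config attribute names here are placeholders, not the actual common test code.)

    # Hypothetical sketch of the shape check described above; `model`, `inputs`,
    # and the config attribute names are placeholders, not the merged common test.
    import torch

    def check_hidden_states(model, inputs, config):
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        # one entry for the embeddings plus one per layer
        assert len(hidden_states) == config.num_hidden_layers + 1
        seq_len = inputs["input_ids"].shape[1]
        for hs in hidden_states:
            assert hs.shape[-2:] == (seq_len, config.hidden_size)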

Regarding whether to return decoder input uncontextualized embeddings, GPT2 already does it (GPT2 can be viewed as a transformer decoder):

    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

Also, decoder input embeddings from layer 0 get fed into further complex layers analogously to how it is done for encoders. And for all the encoders in the lib (like BERT) we do return the outputs from this layer. So I would vote for not dropping it for the decoder.

@stas00
Contributor

stas00 commented Nov 18, 2020

Well, based on the research that you shared, it's easy then - keep them all.

So we just need to decide whether to:

  1. a. keep the current implementation in most (all?) modules, where the incoming states are stored first and the last state is stored as sort of an afterthought (and potentially gets forgotten, which is the case with every bart copy), and b. fix modeling_bart and every other module that copied it to add the missing state;
  2. or recode it in a cleaner way as I suggested here and you concurred with, which fixes the bug along the way and prevents it from re-appearing in the future.

Since I wasn't there when the code was written and since it impacts the whole project let's see what @LysandreJik, @patrickvonplaten, @sgugger think.

Thank you for the detailed answer, the research, and the suggestion on how to write the missing test, @maksym-del!

@sgugger
Collaborator

sgugger commented Nov 18, 2020

I would avoid changing the existing code since it produces the desired output; I think we can all employ our time on more meaningful contributions to the library :-) I don't think one implementation is better than the other, in the sense that you have to remember to add either the first hidden state or the last.

On the models that do not produce the desired outputs, you can fix it the way you prefer. The modeling files don't need to do everything the exact same way, and since you're the contributor fixing things, you get to choose which one you like better. What interests me more, however, is how the tests still passed, since the number of hidden states is tested and we're discovering there is one missing; that's something the common tests should have caught.

@stas00
Contributor

stas00 commented Nov 18, 2020

While I disagree with your take that the two ways are equal, since the current implementation is a bug waiting to happen should some code ever be added after the loop and before the last layer's hidden state is appended (especially with all the code copying), I am in agreement with the rest.

To clarify, you're saying:

  • Do not change anything in models that don't have this bug.
  • You can change things in models that do have this bug, beyond just fixing the bug itself (i.e. all the bart copy-cats).

What interests me more, however, is how the tests still passed, since the number of hidden states is tested and we're discovering there is one missing; that's something the common tests should have caught.

My intuition is that when the test counts the hidden states, the "incoming" hidden state gets counted as one of the layer hidden states, so the total still matches. If this is a common test, then the correct models should have failed it instead. But I'll need to look at the actual test to tell for sure.

@patrickvonplaten
Contributor

patrickvonplaten commented Nov 19, 2020

@maksym-del thanks so much for finding this bug -> you are correct, this should be fixed.

I think we should do two things here (@maksym-del let me or @stas00 know if you need help here):

  1. Apply the same change to modeling_bart.py
  2. Improve the test (this might be a bit more difficult, but I'll help you if needed :-)):
    • If you compare the common test of the hidden states output (test_hidden_states_output) with the common test of the attention output, you can see that the attention test does an extra check for is_encoder_decoder=True models while the hidden states test does not. This is why this bug went unnoticed -> we should add an if config.is_encoder_decoder: clause to the hidden states test that checks that the decoder also has the correct number of layers and that those hidden states have the correct size (a rough sketch of the idea follows below).
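(A rough sketch of what that extra clause might look like; the surrounding test setup, outputs, decoder_seq_length, and the config attribute names are assumptions here, not the test that was eventually merged.)

    # Hypothetical sketch of the extra decoder check inside the common hidden
    # states test; `outputs` and `decoder_seq_length` come from the assumed setup,
    # and config attribute names (e.g. decoder_layers, hidden_size) vary per model.
    if config.is_encoder_decoder:
        decoder_hidden_states = outputs.decoder_hidden_states
        # embeddings + one entry per decoder layer
        self.assertEqual(len(decoder_hidden_states), config.decoder_layers + 1)
        self.assertListEqual(
            list(decoder_hidden_states[-1].shape[-2:]),
            [decoder_seq_length, config.hidden_size],
        )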

If you have trouble adding the test ping me or @stas00 again and we'll finish the PR for you!

Thanks for spotting the bug :-)

@patrickvonplaten
Contributor

Thanks a lot for rebasing this! I think the only thing left to do now is to add a test as explained above :-)

@MaksymDel MaksymDel changed the title FSMT: fix decoder not returning hidden states from the last layer BART & FSMT: fix decoder not returning hidden states from the last layer Nov 27, 2020
@MaksymDel
Contributor Author

MaksymDel commented Nov 27, 2020

Thanks, @patrickvonplaten , @stas00 and @sgugger !

I added the test and think this PR is ready to be merged.

Contributor

@patrickvonplaten left a comment

Great job!

Collaborator

@sgugger left a comment

Great for the new test, thanks a lot!

tests/test_modeling_common.py (review thread resolved)
Contributor

@stas00 left a comment

LGTM! Thank you, @maksym-del

I just removed a superfluous new line that was added by accident.

@stas00
Contributor

stas00 commented Nov 27, 2020

Unrelated to this PR, as it's replicating the existing approach, but won't it be simpler to replace:

                x = x.transpose(0, 1)
                all_hidden_states += (x,)
                x = x.transpose(0, 1)

with just:

                all_hidden_states += (x.transpose(0, 1),)

@patrickvonplaten, replying inside my comment:

this doesn't work; x needs to be kept in the graph. x.transpose(0, 1) would return a new view on the tensor, which is not in the graph anymore.

@patrickvonplaten patrickvonplaten merged commit 0a921b6 into huggingface:master Nov 27, 2020
@stas00
Contributor

stas00 commented Nov 27, 2020

@patrickvonplaten - I edited your edit of my comment to make it readable. Otherwise it made no sense, as you made it look like I was saying something and then saying that it is not so.

Thank you for the clarification!

p.s. GitHub doesn't send notifications on edits, so this is probably not the most ideal way to reply ;)

@patrickvonplaten
Contributor

@patrickvonplaten - I edited your edit of my comment to make it readable. Otherwise it made no sense, as you made it look like I was saying something and then saying that it is not so.

Thank you for the clarification!

p.s. GitHub doesn't send notifications on edits, so this is probably not the most ideal way to reply ;)

Oh, I'm sorry. I meant to reply to your comment :D

stas00 added a commit to stas00/transformers that referenced this pull request Dec 5, 2020
BART & FSMT: fix decoder not returning hidden states from the last layer (huggingface#8597)

* Fix decoder not returning hidden states from the last layer

* Resolve conflict

* Change the way to gather hidden states

* Add decoder hidden states test

* Make pytest and black happy

* Remove redundant line

* remove new line

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>