
BART Large generate predictions are wonky #15559

Closed
StephAO opened this issue Feb 8, 2022 · 48 comments · Fixed by #15879

@StephAO

StephAO commented Feb 8, 2022

Environment info

  • transformers version: 4.16.2 (issue exists on 4.9.2)
  • Platform: Linux-4.4.0-210-generic-x86_64-with-glibc2.10
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.8.1+cpu (False)
  • Tensorflow version (GPU?): 2.3.1 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help

@patrickvonplaten @sshleifer

Information

This is essentially re-opening issue #8005: BART-large does not fill masks properly (whereas BART-base produces entirely reasonable outputs). The previous fix of setting force_bos_token_to_be_generated = True is no longer viable since the option no longer exists in the BART config. It also seems like adjust_logits_during_generation (where force_bos_token_to_be_generated was used) is no longer implemented in the BART model.

To reproduce

Steps to reproduce the behavior:

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base", forced_bos_token_id=0)
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
batch = tokenizer("My friends are <mask> but they eat too many carbs.", return_tensors="pt")
generated_ids = model.generate(batch["input_ids"])
print(tokenizer.decode(generated_ids[0]))
# Output: </s><s>My friends are healthy, but they eat too many carbs.</s>

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
batch = tokenizer("My friends are <mask> but they eat too many carbs.", return_tensors="pt")
generated_ids = model.generate(batch["input_ids"])
print(tokenizer.decode(generated_ids[0]))
# Output: </s>My,, but they eat too many carbs.</s>

@StephAO changed the title from "BART" to "BART Large generate predictions are wonky" on Feb 10, 2022
@SivilTaram
Contributor

I have encountered the same problem, and I have also found that bart-large cannot be fine-tuned to produce reasonable output.

@patrickvonplaten
Contributor

@patil-suraj - we had another issue thread about this somewhere no? I can't find it anymore though :-/

@StephAO
Author

StephAO commented Feb 15, 2022

@patrickvonplaten
It might be this one that I linked in the original description: #8005

@patrickvonplaten
Contributor

Found the original issue: #9731 . Looking a bit into the commit history here: https://huggingface.co/facebook/bart-large/commits/main it looks like the mask token problem actually only existed for bart-base and not for bart-large according to @patil-suraj .

@patil-suraj - could you double-check this real quick?

@ayaka14732
Contributor

ayaka14732 commented Feb 19, 2022

I can confirm that bart-large generates very strange output.

from transformers import BartTokenizer, FlaxBartForConditionalGeneration

model_name = 'facebook/bart-large'  # or 'facebook/bart-base'; outputs for both are shown below
tokenizer = BartTokenizer.from_pretrained(model_name)
model = FlaxBartForConditionalGeneration.from_pretrained(model_name)
model.params = model.to_bf16(model.params)  # convert float16 to bfloat16 (for TPU)

sentences = (
    'She waded to the bank and picked up her shoes and stockings.',
    'The bank is increasing the amount they lend to small companies.',
)

inputs = tokenizer(sentences, padding=True, return_tensors='jax')
output = model.generate(inputs.input_ids)
print(tokenizer.batch_decode(output.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=False))

facebook/bart-base:

['She waded to the bank and picked up her shoes and stockings.', 'The bank is increasing the amount they lend to small companies.']

facebook/bart-large:

['She.....', 'TheThe']

@Impelon

Impelon commented Feb 19, 2022

To me this seems to be unrelated to #9731:
@patrickvonplaten's previous method to check whether the correct mask token is used produces a difference of 0 when used with the large model.
So it looks like this is not related to the mask token, as @ayaka14732's example does not even use masks.

@patrickvonplaten
Contributor

I see, sorry, you're right - I looked too quickly indeed. Will take a deeper look in the coming days.

@salrowili

Any updates?

@patrickvonplaten
Contributor

Hey @StephAO,

You are passing forced_bos_token_id to the tokenizer instead of the model. It should be passed to the model. When running this code:

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
batch = tokenizer("My friends are <mask> but they eat too many carbs.", return_tensors="pt")
generated_ids = model.generate(batch["input_ids"])
print(tokenizer.decode(generated_ids[0]))

gives sensible outputs:

</s><s>My friends are good people, but they eat too many carbs.</s>  

Where did you see code that forced_bos_token_id should be passed to the tokenizer?

@StephAO
Author

StephAO commented Mar 1, 2022

Good catch. I am actually unsure where (or if) I saw code that passed forced_bos_token_id to the tokenizer; it is possible that when I was playing around with the model to try to get it to work, I ended up adding the argument in the wrong spot. That being said, the documentation is unclear in a few places:

  • The Mask Filling example in BartForConditionalGeneration does not include this necessary argument.
  • The Implementation Notes mention force_bos_token_to_be_generated=True instead of forced_bos_token_id=0 (which I assume is the old version of this argument?).
  • None of the BART models, nor the PreTrainedModel they inherit from, specify forced_bos_token_id as a possible argument. In fact, the only place I can find it defined is in GenerationMixin, and even there it is not clear what this argument does.

Lastly, it is still confusing to me why this is required for bart-large, but not for bart-base. Either way, thank you for the update!
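For anyone else landing here, this is the usage that ended up working for me. A minimal sketch (I'm assuming that passing it per call to generate() is equivalent to setting it on the config at load time):

from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

# Option 1: set it on the model config at load time.
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)

# Option 2: pass it per call to generate().
batch = tokenizer("My friends are <mask> but they eat too many carbs.", return_tensors="pt")
generated_ids = model.generate(batch["input_ids"], forced_bos_token_id=0)
print(tokenizer.decode(generated_ids[0]))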

@patrickvonplaten
Contributor

Thanks for the great feedback! Actually, if you would be interested, it would be amazing if you could open a PR to fix the docs - both the mask-filling example and the implementation notes :-)

Otherwise I'm happy to open a PR for it as well!

@SivilTaram
Contributor

@patrickvonplaten Thanks for the explanation, it helps a lot! Actually, I think this does not only affect mask infilling, but also the fine-tuning procedure (at least on my side). Without forced_bos_token_id, the bart-large model cannot be fine-tuned well. Therefore, I recommend setting a default value for forced_bos_token_id inside the BART model config file, if you do not mind.

@patrickvonplaten
Contributor

Hmm, the problem is that I'm not sure whether this should be done for bart-base and bart-large. E.g. when the model was added, the author explicitly mentioned that forced_bos_token_id should only be used for bart-large-cnn by default - see:

Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only

This is also why only bart-large-cnn has this config attribute set by default -> see: https://huggingface.co/facebook/bart-large-cnn/blob/main/config.json#L27 and the others don't.

@sshleifer sorry to ping you here - do you remember by any chance if forced_bos_token_id is recommended to be used when fine-tuning BART? E.g. should one always place BOS after decoder_start_token_id for fine-tuning?

Also cc @patil-suraj any ideas?

@patil-suraj
Contributor

I'm not sure whether this should be done for bart-base and bart-large

Sam will have a better answer but IIRC the forced_bos_token_id (previously, force_bos_token_to_be_generated) was added to be able to reproduce the bart-large-cnn results. And in our experiments, we had found that this is only required for the cnn pre-trained model and other checkpoints were not affected by this. Found a related discussion and PR.

do you remember by any chance if forced_bos_token_id is recommended to be used when fine-tuning BART? E.g. should one always place BOS after decoder_start_token_id for fine-tuning?

In my experiments, it actually doesn't matter. It depends on how the decoder_input_ids are prepared; for example, if the decoder_input_ids look like [eos, bos, ...], then BOS should be forced.

But now the issue is that in BART-like models the decoder_input_ids are prepared by calling the shift_tokens_right function on the labels if decoder_input_ids are not passed. This is how it's done in our summarization and translation fine-tuning examples. The decoder_input_ids are prepared in DataCollatorForSeq2Seq by calling model.prepare_decoder_input_ids_from_labels, which then calls shift_tokens_right:

decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
    return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

And this will always add bos after eos. See:

from transformers import BartTokenizer
from transformers.models.bart.modeling_bart import shift_tokens_right

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

labels = tokenizer("This is a test", return_tensors="pt").input_ids
decoder_input_ids = shift_tokens_right(labels, tokenizer.pad_token_id, tokenizer.eos_token_id)

decoder_input_ids
# tensor([[   2,    0,  713,   16,   10, 1296]])

tokenizer.batch_decode(decoder_input_ids)
# ['</s><s>This is a test']

but forced_bos_token_id is not set by default. So this might affect generations for models trained using these scripts.
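A rough sketch of the workaround I have in mind for such fine-tuned checkpoints (the path below is just a placeholder, not a real checkpoint):

from transformers import BartForConditionalGeneration

# "path/to/finetuned-bart" is a placeholder for a checkpoint trained with the example scripts.
model = BartForConditionalGeneration.from_pretrained("path/to/finetuned-bart")
model.config.forced_bos_token_id = model.config.bos_token_id  # 0 for BART
# model.generate(...) will now emit <s> right after decoder_start_token_id, matching training.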

@patrickvonplaten
Contributor

patrickvonplaten commented Mar 2, 2022

Thanks for the summary @patil-suraj - that's super helpful!

So as I understand it, the BartTokenizer will always add the BOS token to the beginning.
E.g.:

from transformers import BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large")
print(tok.decode(tok("hello").input_ids))
# '<s>hello</s>'

This means that for Seq2Seq, if someone follows any of our official examples (both accelerate and Trainer), the labels are created to be:

<s> label text </s>

with the decoder input ids then being (since they are the labels shifted to the right by one position):

<decoder-start-token-id><s> label text </s>

Now, we know from Sam's comment here that it is not necessary to add the BOS token (<s>) for successful fine-tuning. One also gets good results when not adding the BOS token - i.e. it doesn't really make a difference.

But the problem now is that while we quietly add BOS to the labels in all of our example scripts for fine-tuning because of BartTokenizer's behavior, we don't "force-add" the BOS token when evaluating the model with generate, because forced_bos_token_id is not set by default in the config. This means that while the model is well trained when using the examples, the evaluation results are probably not as good as they could be if we forced the BOS token to be generated.

On the other hand, one could also argue that the model should learn to always generate <BOS> as the first token, so that forcing it is not needed. However, we know that results are better if we force <BOS> to be generated when the model has been trained as explained above.

In conclusion, @patil-suraj and I think we should actually add forced_bos_token_id=0 by default to the pretrained BART models bart-large and bart-base. This would be a breaking change for people who use bart-large for mask filling with generate by default, but as seen above it should improve results.

Since those two checkpoints are heavily used, I'm keen to hear your opinions here @LysandreJik @sgugger.

@sgugger
Collaborator

sgugger commented Mar 2, 2022

Are we 100% sure that the change would only make predictions better in all circumstances?

@patrickvonplaten
Contributor

Are we 100% sure that the change would only make predictions better in all circumstances?

If someone fine-tunes with BOS as the second token (right after decoder_start_token_id) - which is what all example scripts do - then yes, the change would make predictions better in all circumstances for the fine-tuned model.

@sgugger
Collaborator

sgugger commented Mar 2, 2022

I'm not talking about fine-tuning, I'm talking about users relying on this model in production right now.

@patrickvonplaten
Contributor

This is a pretrained checkpoint - so I highly doubt anybody uses this model + generate() in production. The only use case is to fill a <mask> token as shown in the issue description. It doesn't make sense to use this in production (why use expensive generate for a <mask> token instead of BERT?). And even for this use case the change works better (however, this is hard to test).

@salrowili

Can this problem cause a model pretrained from scratch to show poor loss and accuracy scores? I already pre-trained a model with TensorFlow Keras and BART, and it shows good logit accuracy (~80%) at base scale. However, I have been struggling for months to use BART-large with the same experimental setup. The logit accuracy for BART-large never goes beyond ~4%. I tried everything, including decreasing and increasing the batch size, learning rate, etc., but with no luck.

@SivilTaram
Contributor

@salrowili Yeah, I can confirm it will. Compared with bart-base, the predictions of bart-large seem to have some randomness, even in an experiment that just evaluates whether it can converge on a small toy dataset.

@salrowili

@salrowili Yeah, I can confirm it will. Compared with bart-base, the predictions of bart-large seem to have some randomness, even in an experiment that just evaluates whether it can converge on a small toy dataset.

Did you manage to fix it with the forced_bos_token_id=0 solution?

@SivilTaram
Contributor

SivilTaram commented Mar 7, 2022

@salrowili Yeah, I can confirm it will. Compared with bart-base, the predictions of bart-large seem to have some randomness, even in an experiment that just evaluates whether it can converge on a small toy dataset.

Did you manage to fix it with the forced_bos_token_id=0 solution?

Based on my efforts so far, I can observe the following facts:

  • With forced_bos_token_id=0 set, the bart-large model can be fine-tuned to overfit a small toy dataset, though the loss is still abnormal.
  • On a real dataset (WikiSQL), bart-large shows promising fine-tuning results on some experimental runs, e.g., performance comparable to the model fine-tuned with fairseq.

However, some experimental runs fail when switching the fine-tuning environment (e.g., from single-card to multi-card). This often comes with very long evaluation times, since the model tries to produce a sequence such as </s> <s> <s> <s> <s> ... until it reaches val_max_target_length. These failed predictions are all empty, and these errors may not be solved by the forced_bos_token_id=0 solution.

<s> corresponds to the bos_token_id, and </s> corresponds to both the decoder_start_token_id and eos_token_id.
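To make the id mapping concrete, here is a quick sketch of how to check it (nothing surprising, just for reference):

from transformers import BartConfig, BartTokenizer

config = BartConfig.from_pretrained("facebook/bart-large")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

print(tokenizer.bos_token, config.bos_token_id)   # <s> 0
print(tokenizer.eos_token, config.eos_token_id)   # </s> 2
print(config.decoder_start_token_id)              # 2, i.e. the same id as </s>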

One failure case is shown below, where the denotation_accuracy shows a big jump across fine-tuning steps:

[W&B chart: denotation_accuracy across fine-tuning steps]

@salrowili

salrowili commented Mar 7, 2022


Thanks for sharing these interesting findings. I have also conducted a small experiment involving fine-tuning on SQuAD with both BART-base and BART-large, using PyTorch XLA on a TPU-8 unit on Google Colab (attached). It is based on this colab: https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb . Although I use Google Colab Pro, which has more memory (35 GB), I think the free Google Colab gives a TPU with 25 GB, so anyone interested in replicating this experiment can do it; you may just need to reduce the value of per_device_train_batch_size to use less memory (per_device_train_batch_size * 8 = total batch size). If you run into an out-of-memory error, reduce per_device_train_batch_size. If you get a SIG error, restart the colab and re-run all the cells needed to restore variables and import packages. At the beginning, training will be slow since XLA compilation needs more time, especially for bart-large (~10 min). The largest batch size is 16 for bart-large (per_device_train_batch_size=2).

I ran both large and base for one epoch. The loss scores are very close, and it seems fine-tuning BART-large is not affected by this issue. I tried to run prediction with BART-large but got an OOM error upon saving the model. However, I needed to add skip_special_tokens=True to the tokenizer.decode call to get rid of the <s> and <pad> tokens in BART-base (see the snippet at the end of this comment). I am also still worried about pre-training BART-large from scratch because it involves a lot of resources and uses the [mask] token.
Comparing_BART_base_vs_BART_large_on_TPU.zip
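What I mean by adding skip_special_tokens=True, as a minimal sketch (using facebook/bart-base here purely for illustration, not my actual TPU setup):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

batch = tokenizer("My friends are <mask> but they eat too many carbs.", return_tensors="pt")
generated_ids = model.generate(batch["input_ids"])

print(tokenizer.decode(generated_ids[0]))                            # keeps <s>, </s> and <pad>
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))  # plain text only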

@SivilTaram
Contributor

SivilTaram commented Mar 7, 2022

@salrowili Thanks for sharing! I think this bug currently affects all NLG tasks (e.g., summarization), but not NLU tasks (e.g., classification and extractive machine reading comprehension). I have some ideas about why this is, and I would like to share them here later once they are confirmed.

@ayaka14732
Contributor

Thanks! I am still curious about:

  1. Why does this issue happen with bart-large but not bart-base?
  2. Is this issue only related to the Hugging Face model, or does it affect the model in the original Facebook repository as well?

@salrowili

salrowili commented Mar 9, 2022

Thank you @SivilTaram for this detailed explanation. Wonderful effort.

Could this problem be related to mBART? mBART only has a large-scale version, not a base scale, and I think the first token is designed to be the language code. See #9811.

@patrickvonplaten
Contributor

Thanks a lot for all the work guys! I think at this point we can only add forced_bos_token_id=0 to the config of the pretrained checkpoints as discussed previously - @patil-suraj and I just did it here:

Also maybe we can ask the official authors what they think about your findings.
Gently pinging @ngoyal2707 here

@SivilTaram
Contributor

SivilTaram commented Mar 10, 2022

Thanks a lot for all the work guys! I think at this point we can only add forced_bos_token_id=0 to the config of the pretrained checkpoints as discussed previously - @patil-suraj and I just did it here:

Also maybe we can ask the official authors what they think about your findings. Gently pinging @ngoyal2707 here

Yes, I agree that my finding on perturbing bos_token cannot be a default option for bart-large since it would affect a lot of users, at least until we finally figure out why the perturbation works.

@SivilTaram
Contributor

SivilTaram commented Mar 10, 2022

Thanks! I am still curious about:

  1. Why does this issue happen with bart-large but not bart-base?
  2. Is this issue only related to the Hugging Face model, or does it affect the model in the original Facebook repository as well?

  1. This is the trickiest part of this bug. I do not know whether the official authors used different training strategies for bart-large and bart-base, but the fact is that the bos_token makes bart-large hard to optimize in the context of 🤗 transformers.
  2. Good catch! I think it should not be related only to 🤗. I have also seen many issues (e.g., #12237) complaining about abnormal results after fine-tuning bart-large-based models on some NLG datasets. If it turns out to be an optimization issue, I think we might encounter the same issue in fairseq after trying enough random seeds. However, everything in fairseq under the default random seed seems okay to me so far.

@SivilTaram
Contributor

Thank you @SivilTaram for this detailed explanation. Wonderful effort.

Could this problem be related to mBART? mBART only has a large-scale version, not a base scale, and I think the first token is designed to be the language code. See #9811.

Good point, @salrowili! I have no idea right now, but I agree that the first token was originally designed to serve mBART rather than BART.

@salrowili

salrowili commented Mar 14, 2022

What is your loss score on both (not EM)? I am curious to know.

@JulesGM

JulesGM commented Mar 14, 2022

This might have been a false alarm, let me triple check

@JulesGM

JulesGM commented Mar 14, 2022

Yeah false alarm my bad folks.

@JulesGM

JulesGM commented Mar 14, 2022

Actually, right now, with the same transformers BartForConditionalGeneration instance, AllenAI's beam search gives me near 100% EM and Hugging Face's gives me very low EM, with a much lower per-token accuracy (~65%). Both with num_beams of 4 and a max length of 500; trying to figure out the difference...

@JulesGM

JulesGM commented Mar 14, 2022

OK, found the cause of the difference: the Hugging Face generate reads a no_repeat_ngram_size of 3 from the BART config, while the AllenAI decoder does not look at it.
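Roughly what the check/override looks like (a sketch using the stock facebook/bart-large checkpoint just for illustration, not my fine-tuned model):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
print(model.config.no_repeat_ngram_size)  # 3 - picked up silently by generate()

batch = tokenizer("My friends are <mask> but they eat too many carbs.", return_tensors="pt")
# Setting it to 0 per call disables the n-gram blocking, matching a decoder that ignores it.
out = model.generate(batch["input_ids"], num_beams=4, max_length=500, no_repeat_ngram_size=0)
print(tokenizer.decode(out[0], skip_special_tokens=True))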

@patrickvonplaten
Contributor

Can also confirm that gradients of bart-large are generally very high which might lead to unexpected behavior. See: https://discuss.huggingface.co/t/gradients-verification-between-jax-flax-models-and-pytorch/15970

@SivilTaram
Contributor

Can also confirm that gradients of bart-large are generally very high which might lead to unexpected behavior. See: https://discuss.huggingface.co/t/gradients-verification-between-jax-flax-models-and-pytorch/15970

Good job! This reinforces my belief that it is bart-large itself that is sensitive to optimization, and not a bug in the codebase.

@patrickvonplaten
Contributor

We just found out that bart-large has its weights accidentally stored in fp16 on the Hub, see #16736.

This might be a reason for the behavior observed here.
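If you want to check the stored dtype yourself, something like this should work (a sketch; it assumes the pytorch_model.bin currently served from the Hub is still the original file):

import torch
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="facebook/bart-large", filename="pytorch_model.bin")
state_dict = torch.load(path, map_location="cpu")
print({tensor.dtype for tensor in state_dict.values()})  # reported as {torch.float16} in #16736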

@github-actions

github-actions bot commented May 7, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ydshieh
Collaborator

ydshieh commented Aug 16, 2022

There is a similar issue with led-large, which has the same weights as bart-large. See my comment here.

In short, no matter what the encoder input sequence is, the decoder sequence [</s>, <s>] produces identical LM logits at both positions (differences are in the range of ~1e-6), before any fine-tuning. This explains why we have training trouble, as well as why perturbing bos_token makes sense. However, it's not clear why this happens.

I am not fully convinced by @SivilTaram's reasoning about the norms of <eos> and <bos>: if they have a larger difference in norm for bart-large, that means their input embeddings are more different, yet the output LM logits end up nearly identical. This looks very strange.
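A minimal snippet for the kind of check I mean with transformers (a sketch; the exact numbers will vary slightly):

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large").eval()

inputs = tokenizer("An arbitrary encoder input sentence.", return_tensors="pt")
# Decoder prefix [</s>, <s>], i.e. [decoder_start_token_id, bos_token_id] = [2, 0]
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id, model.config.bos_token_id]])

with torch.no_grad():
    logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits  # shape [1, 2, vocab_size]

print((logits[0, 0] - logits[0, 1]).abs().max())  # ~1e-6: both positions give the same prediction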

@ydshieh
Collaborator

ydshieh commented Aug 16, 2022

I have verified that the same situation occurs for bart-large using Hugging Face's transformers.

Furthermore, the following code snippet confirms the same for fairseq's bart-large.

import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.large')
bart.eval()  # disable dropout (or leave in train mode to finetune)

model = bart.model

#print(model.decoder.embed_tokens.weight[0][:16])
#print(model.decoder.embed_tokens.weight[2][:16])

for _ in range(20):
    # random encoder input sequences with random length
    seq_len = torch.randint(low=8, high=64, size=(1,))[0]
    src_tokens = torch.randint(low=0, high=50265, size=(1, seq_len), dtype=torch.int32)
    src_tokens = torch.cat([torch.tensor([[0]], dtype=torch.int32), src_tokens, torch.tensor([[2]], dtype=torch.int32)], dim=-1)

    src_lengths = seq_len + 2
    prev_output_tokens = torch.tensor([[2, 0]], dtype=torch.int32)

    o = model(
        src_tokens=src_tokens,
        src_lengths=src_lengths,
        prev_output_tokens=prev_output_tokens,
    )
    print(o[0].shape)
    print(o[0][0, :2, :16])

The results:

tensor([[11.3236, -0.9138,  6.5245, -0.3914,  9.7808,  2.2912,  8.1962,  1.9227,
          4.7425,  3.1021,  2.6246,  3.7791,  9.5079,  2.6686,  2.1119,  2.0344],
        [11.3236, -0.9138,  6.5245, -0.3914,  9.7808,  2.2912,  8.1962,  1.9227,
          4.7425,  3.1021,  2.6246,  3.7791,  9.5079,  2.6686,  2.1119,  2.0344]],
       grad_fn=<SliceBackward0>)
torch.Size([1, 2, 50265])
tensor([[10.7945, -1.0372,  7.0184, -0.3236, 10.4561,  2.6929,  8.5290,  3.1237,
          5.1213,  3.5141,  2.8864,  3.9884, 10.1654,  3.6481,  1.8212,  2.2971],
        [10.7945, -1.0372,  7.0184, -0.3236, 10.4561,  2.6929,  8.5290,  3.1237,
          5.1213,  3.5141,  2.8864,  3.9884, 10.1654,  3.6481,  1.8212,  2.2971]],
       grad_fn=<SliceBackward0>)
torch.Size([1, 2, 50265])
tensor([[ 9.5601, -1.0508,  6.5913, -1.1130, 10.1133,  2.0899,  8.3389,  2.9828,
          4.7401,  3.2962,  1.6522,  3.3447, 10.1464,  3.2324,  2.4950,  1.9132],
        [ 9.5601, -1.0508,  6.5913, -1.1130, 10.1133,  2.0899,  8.3389,  2.9828,
          4.7401,  3.2962,  1.6522,  3.3447, 10.1464,  3.2324,  2.4950,  1.9132]],
       grad_fn=<SliceBackward0>)
torch.Size([1, 2, 50265])
tensor([[13.9500, -1.1926,  8.6385, -1.3860, 10.3129,  2.6660,  9.1016,  3.1353,
          5.2579,  3.4807,  2.5914,  3.9124, 11.1577,  3.0453,  1.7675,  2.6812],
        [13.9500, -1.1926,  8.6385, -1.3860, 10.3129,  2.6660,  9.1016,  3.1353,
          5.2579,  3.4807,  2.5914,  3.9124, 11.1577,  3.0453,  1.7675,  2.6812]],
       grad_fn=<SliceBackward0>)
torch.Size([1, 2, 50265])
tensor([[16.2144, -0.7714,  9.2828,  0.7236, 11.1654,  1.8226,  9.0874,  2.7226,
          5.5879,  3.4060,  1.7666,  4.0211, 11.2787,  2.9147,  1.6007,  1.7492],
        [16.2144, -0.7714,  9.2828,  0.7236, 11.1654,  1.8226,  9.0874,  2.7226,
          5.5879,  3.4060,  1.7666,  4.0211, 11.2787,  2.9147,  1.6007,  1.7492]],
       grad_fn=<SliceBackward0>)

@ydshieh
Collaborator

ydshieh commented Aug 17, 2022

For the record: bart-large seems to have learned to predict the first token after <s> in the encoder input sequence at both of the first two decoder positions [</s>, <s>]. I provide a script to confirm this at the end.

Here are some outputs:

example idx: 1
max diff in lm logits: 1.621246337890625e-05
--------------------
predicted token ids: [1640, 1640, 13989, 212, 4038]
predicted tokens: ['(', '(', 'ĠMLS', 'th', 'Ġanniversary']
document tokens: ['<s>', '(', 'CNN', ')', 'On', 'Ġthe', 'Ġ6', 'th', 'Ġof', 'ĠApril', 'Ġ1996', ',', 'ĠSan', 'ĠJose', 'ĠClash', 'Ġand']
========================================
example idx: 10
max diff in lm logits: 7.033348083496094e-06
--------------------
predicted token ids: [11770, 11770, 16, 6308, 5678]
predicted tokens: ['March', 'March', 'Ġis', 'Ġcontains', 'Ġlinks']
document tokens: ['<s>', 'March', 'Ġ10', ',', 'Ġ2015', 'Ġ.', 'ĠWe', "'re", 'Ġtruly', 'Ġinternational', 'Ġin', 'Ġscope', 'Ġon', 'ĠTuesday', '.', 'ĠWe']
========================================
example idx: 20
max diff in lm logits: 7.62939453125e-06
--------------------
predicted token ids: [41650, 41650, 11, 1429, 224]
predicted tokens: ['Tok', 'Tok', 'Ġin', 'ĠJapan', 'Ġsay']
document tokens: ['<s>', 'Tok', 'yo', 'Ġ(', 'CNN', ')', 'Police', 'Ġin', 'ĠJapan', 'Ġsay', 'Ġthey', 'Ġhave', 'Ġarrested', 'Ġa', 'Ġ40', '-']
========================================
example idx: 24
max diff in lm logits: 9.5367431640625e-06
--------------------
predicted token ids: [23122, 23122, 52, 9, 10]
predicted tokens: ['London', 'London', 'Ġwe', 'Ġof', 'Ġa']
document tokens: ['<s>', 'London', 'Ġ(', 'CNN', ')', 'A', 'Ġphoto', 'Ġof', 'Ġa', 'Ġwe', 'asel', 'Ġh', 'itching', 'Ġa', 'Ġsurprise', 'Ġlift']
========================================
example idx: 36
max diff in lm logits: 8.106231689453125e-06
--------------------
predicted token ids: [4030, 4030, 2238, 18, 7396]
predicted tokens: ['New', 'New', 'ĠKorean', "'s", 'izes']
document tokens: ['<s>', 'New', 'ĠDelhi', ',', 'ĠIndia', 'Ġ(', 'CNN', ')', 'The', 'ĠNorth', 'ĠKorean', 'Ġambassador', 'Ġin', 'ĠBangladesh', 'Ġissued', 'Ġan']
import numpy as np
import torch

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import datasets

summarization_name_mapping = {
    "cnn_dailymail": ("article", "highlights"),
    "xsum": ("document", "summary"),
}


def get_dataset(dataset_name, dataset_config, tokenizer, n_samples):

    max_source_length = 1024
    max_target_length = 128
    padding = True
    ignore_pad_token_for_loss = True
    padding = "max_length"
    prefix = ""
    max_train_samples = n_samples
    max_eval_samples = n_samples
    preprocessing_num_workers = 8

    raw_datasets = datasets.load_dataset(dataset_name, dataset_config)

    text_column, summary_column = summarization_name_mapping[dataset_name]

    def foo(x):

        if x == tokenizer.cls_token_id:
            return 1
        elif x == tokenizer.pad_token_id:
            return -1
        else:
            return 0

    def preprocess_function(examples):
        # remove pairs where at least one record is None

        inputs, targets = [], []
        for i in range(len(examples[text_column])):
            if examples[text_column][i] and examples[summary_column][i]:
                inputs.append(examples[text_column][i])
                targets.append(examples[summary_column][i])

        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

        # Tokenize targets with the `text_target` keyword argument
        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]

        if tokenizer.__class__.__name__.startswith("LED"):
            model_inputs["global_attention_mask"] = [[foo(y) for y in x] for x in model_inputs["input_ids"]]

        return model_inputs

    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["validation"]

    train_dataset = train_dataset.select(range(max_train_samples))
    eval_dataset = eval_dataset.select(range(max_eval_samples))

    train_dataset = train_dataset.map(
        preprocess_function,
        batched=True,
        num_proc=preprocessing_num_workers,
        remove_columns=[text_column, summary_column, 'id'],
        desc="Running tokenizer on train dataset",
    )
    eval_dataset = eval_dataset.map(
        preprocess_function,
        batched=True,
        num_proc=preprocessing_num_workers,
        remove_columns=[text_column, summary_column, 'id'],
        desc="Running tokenizer on validation dataset",
    )

    return train_dataset, eval_dataset


def check_model(train_dataset, eval_dataset, model, tokenizer, n_samples, text_column, summary_column):

    for idx, eval_example in enumerate(eval_dataset):

        input_ids = eval_example["input_ids"]

        decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=torch.tensor([eval_example["labels"]], dtype=torch.int32))
        decoder_input_ids = decoder_input_ids.numpy().tolist()
        eval_example["decoder_input_ids"] = decoder_input_ids[0]  # remove batch dim

        eval_example.pop("labels")

        decoder_input_ids = eval_example.pop("decoder_input_ids")
        eval_example["decoder_input_ids"] = [2, 0] + decoder_input_ids[2:5]

        for k in eval_example:
            eval_example[k] = torch.tensor([eval_example[k]], dtype=torch.int32)

        output = model(**eval_example)

        print(f"example idx: {idx}")

        print(f'max diff in lm logits: {np.amax(np.abs((output.logits[0, 0] - output.logits[0, 1]).detach().to("cpu").numpy()))}')
        print(f"-" * 20)

        pred_ids = torch.argmax(output.logits, dim=-1).detach().to("cpu").numpy().tolist()[0]
        print(f'predicted token ids: {pred_ids}')

        pred_tokens = tokenizer.convert_ids_to_tokens(pred_ids)
        print(f'predicted tokens: {pred_tokens}')

        document_tokens = tokenizer.convert_ids_to_tokens(input_ids)
        print(f'document tokens: {document_tokens[:16]}')

        print(f"=" * 40)


def run(checkpoint_name, dataset_name, dataset_config=None, n_samples=100):

    text_column, summary_column = summarization_name_mapping[dataset_name]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_name)

    train_dataset, eval_dataset = get_dataset(dataset_name, dataset_config=dataset_config, tokenizer=tokenizer, n_samples=n_samples)
    check_model(train_dataset, eval_dataset, model, tokenizer, n_samples, text_column, summary_column)


run("facebook/bart-large", "cnn_dailymail", "3.0.0", n_samples=100)
#run("facebook/bart-large", "xsum", None, n_samples=10)

#run("allenai/led-large-16384", "cnn_dailymail", "3.0.0", n_samples=10)
#run("allenai/led-large-16384", "xsum", None, n_samples=10)

@mikelewis0

mikelewis0 commented Oct 28, 2022

I just saw this thread - sorry for all the pain here! The problem was caused by an unfortunate config bug in the original BART-large training run, which caused decoder sequences to start with an extra </s> token.

I assume that got fixed in BART-base, which is why it's behaving differently.

@elronbandel

I just saw this thread - sorry for all the pain here! The problem was caused by an unfortunate config bug in the original BART-large training run, which caused decoder sequences to start with an extra </s> token.

I assume that got fixed in BART-base, which is why it's behaving differently.

Thank you so much for clarifying it, @mikelewis0, much appreciated!

@ayaka14732
Contributor

ayaka14732 commented Oct 28, 2022

I just saw this thread - sorry for all the pain here! The problem was caused by an unfortunate config bug in the original BART-large training run, which caused decoder sequences to start with an extra </s> token.

I assume that got fixed in BART-base, which is why it's behaving differently.

I think you misinterpreted the cause. The extra </s> token is intended and it is also used in bart-base.
