
BART: <mask> token ID is outside vocab bounds #3108

Closed
tomhosking opened this issue Mar 3, 2020 · 4 comments
tomhosking (Contributor) commented Mar 3, 2020

🐛 Bug

Information

Model I am using (Bert, XLNet ...): BART

Language I am using the model on (English, Chinese ...): English

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

from transformers import AutoModelWithLMHead, AutoTokenizer
from transformers.configuration_bart import BartConfig

config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config)
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')


ARTICLE_TO_SUMMARIZE = "My friends are <mask> but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')

generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_return_sequences=4)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])

Expected behavior

I'd expect some sort of infilling to occur, but instead I see the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-13-bad65359ada6> in <module>
     10 inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
     11 
---> 12 generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_return_sequences=4)
     13 print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])

~/.local/lib/python3.6/site-packages/torch/autograd/grad_mode.py in decorate_no_grad(*args, **kwargs)
     47         def decorate_no_grad(*args, **kwargs):
     48             with self:
---> 49                 return func(*args, **kwargs)
     50         return decorate_no_grad
     51 

~/.local/lib/python3.6/site-packages/transformers/modeling_bart.py in generate(self, input_ids, attention_mask, max_length, num_beams, repetition_penalty, length_penalty, num_return_sequences, min_len, no_repeat_ngram_size)
   1106                 input_ids, decoder_cache, decoder_input_ids, attention_mask,
   1107             )
-> 1108             outputs = self(**model_inputs)
   1109             lprobs = F.log_softmax(outputs[0][:, -1, :], dim=-1)
   1110 

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

~/.local/lib/python3.6/site-packages/transformers/modeling_bart.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, decoder_cached_states, lm_labels, **unused)
    932             encoder_outputs=encoder_outputs,
    933             decoder_attention_mask=decoder_attention_mask,
--> 934             decoder_cached_states=decoder_cached_states,
    935         )
    936         lm_logits = self.lm_head.forward(outputs[0])

~/.local/lib/python3.6/site-packages/transformers/modeling_bart.py in forward(self, input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, decoder_cached_states)
    837         assert decoder_input_ids is not None
    838         if encoder_outputs is None:
--> 839             encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask)
    840         assert isinstance(encoder_outputs, tuple)
    841         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)

~/.local/lib/python3.6/site-packages/transformers/modeling_bart.py in forward(self, input_ids, attention_mask)
    272                 During training might not be of length n_layers because of layer dropout.
    273         """
--> 274         inputs_embeds = self.embed_tokens(input_ids)
    275         embed_pos = self.embed_positions(input_ids)
    276         x = inputs_embeds + embed_pos

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

~/.local/lib/python3.6/site-packages/torch/nn/modules/sparse.py in forward(self, input)
    112         return F.embedding(
    113             input, self.weight, self.padding_idx, self.max_norm,
--> 114             self.norm_type, self.scale_grad_by_freq, self.sparse)
    115 
    116     def extra_repr(self):

~/.local/lib/python3.6/site-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1482         # remove once script supports set_grad_enabled
   1483         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1484     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1485 
   1486 

RuntimeError: index out of range: Tried to access index 50264 out of table with 50263 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418

Looks to me like the <mask> token ID (50264) is out of bounds?
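
For reference, one way I'd confirm this (just a sketch; it assumes get_input_embeddings() is available on the loaded model in this transformers version) is to compare the tokenizer's mask id with the number of rows in the model's input embedding:

from transformers import AutoModelWithLMHead, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn')

# Valid token ids are 0 .. num_embeddings - 1, so <mask> is out of bounds
# whenever mask_token_id >= num_embeddings.
embedding = model.get_input_embeddings()
print(tokenizer.mask_token_id)    # 50264
print(embedding.num_embeddings)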

Environment info

  • transformers version: a088d75
  • Platform: Ubuntu 18.04
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.3.1 (Y)
  • Tensorflow version (GPU?): N/A
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No
tomhosking changed the title from "<mask> token ID is outside vocab bounds" to "BART: <mask> token ID is outside vocab bounds" on Mar 3, 2020
acarrera94 (Contributor) commented Mar 3, 2020

If you install from master it seems to work on 'bart-large'.
Seems like it's only an issue on 'bart-large-cnn':

from transformers import BartForMaskedLM, BartTokenizer

tokenizer = BartTokenizer.from_pretrained('bart-large')
model = BartForMaskedLM.from_pretrained('bart-large', output_past=True)

ARTICLE_TO_SUMMARIZE = "My friends are <mask> but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
generated_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_return_sequences=4)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in generated_ids])

output:

['My kids are good, but they eat too many carbs. My friends are good.', 'My kids are good, but they eat too many carbs. My friends are good.', 'My kids are good, but they eat too many carbs. My friends are good.', 'My kids are good, but they eat too many carbs. My friends are good.']
sshleifer (Member) commented Mar 3, 2020

Bart-large-cnn doesn't have a mask_token_id, which is admittedly confusing.

This is how I would do mask filling:

from transformers import BartForMaskedLM, AutoTokenizer

model = BartForMaskedLM.from_pretrained('bart-large')
tokenizer = AutoTokenizer.from_pretrained('bart-large')

ARTICLE_TO_SUMMARIZE = "My friends are <mask> but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], return_tensors='pt')
input_ids = inputs['input_ids']

logits = model(input_ids)[0]
# Take the output distribution at the <mask> position and keep the top 10 candidates.
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(10)
tokenizer.decode(predictions).split()
# ['good', 'great', 'all', 'really', 'very', 'healthy', 'also', 'not', 'the', 'doing']
sshleifer (Member) commented Mar 3, 2020

One-liner courtesy of @julien-c:

from transformers import pipeline
nlp = pipeline('fill-mask', 'bart-large')
nlp("My friends are <mask> but they eat too many carbs.")
sshleifer added the wontfix label on Mar 3, 2020
tomhosking (Contributor, author) commented Mar 4, 2020

Thanks @sshleifer, that will do the trick!

The following does work:

tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
tokenizer.mask_token_id
>>> 50264

...which is a bit counterintuitive as it implies that <mask> is available. It's also not clear from the docs that bart-large can be used successfully with BartForMaskedLM.
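
As a concrete illustration of why this is misleading (a hypothetical guard, not something the library provides), the encoded ids could be checked against the model's embedding size before calling generate(), which would fail with a clearer message than the raw embedding lookup error:

from transformers import AutoModelWithLMHead, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn')

inputs = tokenizer.batch_encode_plus(["My friends are <mask> but they eat too many carbs."], return_tensors='pt')
input_ids = inputs['input_ids']

# Hypothetical sanity check: every encoded id must index into the embedding table.
vocab_rows = model.get_input_embeddings().num_embeddings
max_id = int(input_ids.max())
assert max_id < vocab_rows, (
    "token id %d is outside the model's %d-row embedding table" % (max_id, vocab_rows)
)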

tomhosking closed this on Mar 4, 2020