
Decoding with DistilmBERT to generate text in different languages #4563

Closed
javismiles opened this issue May 24, 2020 · 3 comments
javismiles commented May 24, 2020

Good day, and congratulations on your great library.

If I want to decode and generate new text with the GPT-2 LM head, that works great, as you suggest:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = torch.tensor(tokenizer.encode("Once upon a time there was")).unsqueeze(0)  # batch size 1
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
greedy_output = model.generate(input_ids, max_length=50)  # greedy decoding
print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

but my issue is that now I want to do the same with the smaller, simpler DistilmBERT model, which is also multilingual (104 languages), so that I can generate text in, for example, Spanish and English with this lighter model. So I do this:

import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-multilingual-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)  # this argument is named `labels` in later transformers versions
loss, prediction_scores = outputs[:2]

but now, how do I get the continuation of the phrase from that point? I tried applying tokenizer.decode, but with no luck. Thank you.


javismiles commented May 25, 2020

I can get generation working well with distilgpt2; the thing is that I would like to do it multilingually, using the light multilingual model DistilmBERT (distilbert-base-multilingual-cased). Any tips? Thank you :)

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')  # distilgpt2 reuses the GPT-2 vocabulary
input_ids = torch.tensor(tokenizer.encode("Once upon a time")).unsqueeze(0)
model = GPT2LMHeadModel.from_pretrained("distilgpt2", pad_token_id=tokenizer.eos_token_id)
greedy_output = model.generate(input_ids, max_length=50)  # greedy search

# top-k / nucleus sampling, returning three continuations
sample_outputs = model.generate(
    input_ids,
    do_sample=True,
    max_length=50,
    top_k=50,
    top_p=0.95,
    temperature=1.0,
    num_return_sequences=3
)

print("Output:\n" + 100 * '-')
for i, sample_output in enumerate(sample_outputs):
    print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

LysandreJik (Member) commented

Hi, I took the liberty of editing your comments with triple backticks (```py```) to make them more readable.

Unfortunately DistilmBERT can't be used for generation. This is due to the way the original BERT models were pre-trained, using masked language modeling (MLM). It therefore attends to both the left and right contexts (tokens on the left and right of the token you're trying to generate), while for generation the model only has access to the left context.

GPT-2 was trained with causal language modeling (CLM), which is why it can generate such coherent sequences. We implement the generation method only for CLM models, as MLM models do not generate anything coherent.
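For completeness, here is a minimal sketch (not from the thread above) of the kind of task DistilmBERT's masked-LM head is suited for: predicting a single masked token from both its left and right context, rather than continuing a prompt. The example sentence and the greedy choice of the top-scoring token are illustrative.

import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-multilingual-cased')

# Mask one token in the middle of a sentence and let the model fill it in.
text = "Paris is the capital of " + tokenizer.mask_token + "."
input_ids = tokenizer.encode(text, return_tensors="pt")

with torch.no_grad():
    prediction_scores = model(input_ids)[0]  # shape: (1, seq_len, vocab_size)

mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
predicted_ids = prediction_scores[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))  # fills the blank with a single token, not a continuation

For open-ended multilingual generation, a multilingual model pre-trained with causal language modeling would be needed instead.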


stale bot commented Jul 25, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Jul 25, 2020
stale bot closed this as completed Aug 1, 2020