
Why can't I generate phrases in batches if I include an attention mask? (GPT2) #4746

Closed
Barbara931120 opened this issue Jun 3, 2020 · 2 comments


Barbara931120 commented Jun 3, 2020

Assuming these are my input phrases and model:

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<PAD>')

prompt_text = [
    "are there any good coaching institutes for civil services preparations in bangalore? ->"]

If I try to generate phrases in batches with the corresponding attention mask, it doesn't work. The output is just the input phrase with no new words appended:

# batch_encode_plus handles multiple inputs and automatically creates attention masks
seq_len = 100
encodings_dict = tokenizer.batch_encode_plus(prompt_text, max_length=seq_len, pad_to_max_length=True)

input_ids = torch.tensor(encodings_dict['input_ids'])
attn_mask = torch.tensor(encodings_dict['attention_mask'])

encoded_result = model.generate(
    input_ids,
    attention_mask=attn_mask,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    num_return_sequences=10,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    max_length=100,
)

for er in encoded_result:
    print(tokenizer.decode(er, skip_special_tokens=True))

However, if I generate phrases one by one (without batching), it works:

encoded_text = tokenizer.encode(prompt_text[0], return_tensors='pt')
encoded_result = model.generate(
    encoded_text,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    num_return_sequences=10,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    max_length=100,
)
print(tokenizer.decode(encoded_result[0], skip_special_tokens=True))


Any ideas what could be causing this problem?

Thanks!!

@LysandreJik (Member) commented:

Probably of interest to @patrickvonplaten

@patrickvonplaten (Contributor) commented:

Hi @Barbara931120,

Batch generation is sadly not currently implemented in the .generate() method; see #3021 for the reasons why. It's on our roadmap to implement this functionality soon :-)
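Until that lands, one workaround consistent with what's reported above is to loop over the prompts and call .generate() on a single encoded prompt at a time (the single-sequence path that already works). A minimal sketch, reusing the sampling settings from the question; the prompts list and result handling are illustrative assumptions:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

prompts = [
    "are there any good coaching institutes for civil services preparations in bangalore? ->",
    # ... further prompts here
]

results = []
for prompt in prompts:
    # Encode one prompt at a time; no padding means no attention mask is needed.
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    outputs = model.generate(
        input_ids,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=10,
        top_k=50,
        top_p=0.95,
        do_sample=True,
        max_length=100,
    )
    # Decode every sampled sequence for this prompt.
    results.extend(tokenizer.decode(o, skip_special_tokens=True) for o in outputs)

for text in results:
    print(text)

This is slower than true batched generation, but it avoids the padded-batch path that currently produces no new tokens.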
