Assuming these are my input phrases and model:
```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

model = GPT2LMHeadModel.from_pretrained('gpt2')
# note: '<PAD>' is not part of GPT-2's original vocabulary
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<PAD>')
prompt_text = [
    "are there any good coaching institutes for civil services preparations in bangalore? ->"]
```
If I generate the phrases in batches with the corresponding attention mask, it doesn't work: the output is just the input phrase, with no new words added:
```python
# batch_encode_plus encodes several prompts at once and creates the attention masks automatically
seq_len = 100
encodings_dict = tokenizer.batch_encode_plus(prompt_text, max_length=seq_len, pad_to_max_length=True)
input_ids = torch.tensor(encodings_dict['input_ids'])
attn_mask = torch.tensor(encodings_dict['attention_mask'])
# max_length counts the prompt tokens too, including the padding added above
encoded_result = model.generate(
    input_ids, attention_mask=attn_mask,
    eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,
    num_return_sequences=10, top_k=50, top_p=0.95, do_sample=True,
    max_length=100)
for er in encoded_result:
    print(tokenizer.decode(er, skip_special_tokens=True))
```
However, if I generate the phrases one by one (without batching), it works.
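One likely reason for the difference: `max_length` in `generate()` counts the prompt tokens as well as the new ones. With `pad_to_max_length=True` and `max_length=seq_len` (100), every prompt in the batch is already 100 tokens long, so `max_length=100` leaves no room for new tokens and only the (padded) prompt comes back. A single unpadded prompt is much shorter, so most of the 100-token budget goes to the continuation. The arithmetic, as a small sketch; `new_token_budget` is a hypothetical helper and the 20-token prompt length is illustrative, not measured:

```python
def new_token_budget(prompt_len: int, max_length: int) -> int:
    """Tokens generation can still add when max_length includes the prompt (hypothetical helper)."""
    return max(0, max_length - prompt_len)


seq_len = 100     # padded length from batch_encode_plus(..., pad_to_max_length=True)
max_length = 100  # value passed to model.generate()

# Batched call: every prompt was padded to seq_len tokens before generation.
print(new_token_budget(seq_len, max_length))  # 0

# One-by-one call: an unpadded prompt, assumed here to be 20 tokens long.
print(new_token_budget(20, max_length))  # 80
```

Raising `max_length` above `seq_len` would give the batched call room to generate, but with the default right-side padding the new tokens would then follow the pad tokens rather than the prompt.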
Batch generation is sadly currently not implemented in the .generate() method. Also, see #3021 for reasons why. It's on our roadmap to implement this functionality soon :-)
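A rough sketch of how batched generation can work for a decoder-only model: pad on the left so every prompt ends at the same position, grow the attention mask alongside the ids, and stop each row once it emits EOS. Because GPT-2 uses absolute position embeddings, a real implementation would also need position ids that skip the left pads; `position_ids_from_mask` shows one way to compute them, though the toy model here ignores positions. `toy_next_token` is a hypothetical stand-in for a model's greedy prediction and all helper names are made up for illustration; this is not the `transformers` implementation.

```python
from typing import Callable, List, Tuple

PAD, EOS = 0, 1  # hypothetical token ids for this toy example


def left_pad(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id lists to a common length and build matching attention masks."""
    width = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = width - len(ids)
        input_ids.append([pad_id] * n_pad + ids)  # new list, so the caller's batch is untouched
        attention_mask.append([0] * n_pad + [1] * len(ids))
    return input_ids, attention_mask


def position_ids_from_mask(mask: List[int]) -> List[int]:
    """Count positions over real tokens only, so a left-padded prompt still starts at 0."""
    positions, pos = [], 0
    for m in mask:
        positions.append(pos if m else 0)  # pads get 0; they are masked out anyway
        pos += m
    return positions


def toy_next_token(ids: List[int], mask: List[int]) -> int:
    """Hypothetical model: emits 100 + number of real tokens, then EOS once 5 real tokens exist."""
    real = [t for t, m in zip(ids, mask) if m]
    return EOS if len(real) >= 5 else 100 + len(real)


def greedy_batch_generate(
    batch: List[List[int]],
    next_token: Callable[[List[int], List[int]], int],
    pad_id: int,
    eos_id: int,
    max_new_tokens: int,
) -> List[List[int]]:
    """Greedy decoding over a left-padded batch; finished rows are filled with masked pads."""
    input_ids, mask = left_pad(batch, pad_id)
    finished = [False] * len(batch)
    for _ in range(max_new_tokens):
        for i in range(len(batch)):
            if finished[i]:
                tok, m = pad_id, 0
            else:
                tok, m = next_token(input_ids[i], mask[i]), 1
                finished[i] = tok == eos_id
            input_ids[i].append(tok)
            mask[i].append(m)
        if all(finished):
            break
    return input_ids


for row in greedy_batch_generate([[10, 11, 12], [20]], toy_next_token, PAD, EOS, max_new_tokens=10):
    print(row)
# [10, 11, 12, 103, 104, 1, 0, 0]
# [0, 0, 20, 101, 102, 103, 104, 1]
```

Both prompts get their continuation directly after their last real token, which is the property right-side padding loses.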
Any ideas what could be causing this problem?
Thanks!!