
GPT2 text generation repeat #1725

Closed

aclifton314 opened this issue Nov 4, 2019 · 4 comments

Comments

@aclifton314
❓ Questions & Help

SYSTEM
OS: Linux pop-os 5.0.0
Python version: 3.6.8
Torch version: 1.3.0
Transformers version: 2.1.1
I am running this Linux VM with the above software versions on a Windows 10 laptop.

I am running the following code:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

sentence = 'Natural language processing tasks are typically approached with'
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
context_tokens = tokenizer.encode(sentence, add_special_tokens=False)
context = torch.tensor(context_tokens, dtype=torch.long)
num_samples = 1
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context

model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
length = 20
with torch.no_grad():
    for jj in range(5):
        for _ in range(length):
            outputs = model(generated)
            next_token_logits = outputs[0][:, -1, :]
            # Greedy decoding: always pick the single most likely next token.
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            generated = torch.cat((generated, next_token), dim=1)


out = generated
out = out[:, len(context_tokens):].tolist()
for o in out:
    text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
    print(text)

What I noticed is that GPT2 starts to produce repetitive text (see below) with this approach. I am not sure of the best way to prevent this from happening and was wondering if others had any ideas. Thank you in advance!

OUTPUT

a single task, such as a word search, and the task is then repeated. The task is then repeated for each word in the search.

The task is then repeated for each word in the search. The task is then repeated for each word in the search. The task is then repeated for each word in the search. The task is then repeated for each word in the search. The task is then repeated for each word in the search. The task is then repeated for each word in
@TheEdoardo93

Adding temperature could be an interesting way! In brief, temperature is a hyperparameter of LSTMs (and of neural networks generally) used to control the randomness of predictions by scaling the logits before applying the softmax.

Here is a modified version of your code with temperature:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch.nn.functional as F

sentence = 'Natural language processing tasks are typically approached with'
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
context_tokens = tokenizer.encode(sentence, add_special_tokens=False)
context = torch.tensor(context_tokens, dtype=torch.long)
num_samples = 1
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context

model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()
length = 20
temperature = 0.8 # ADD TEMPERATURE PARAMETER!
with torch.no_grad():
    for jj in range(5):
        for _ in range(length):
            outputs = model(generated)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.) ### CHANGE THIS ROW
            next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1) ### CHANGE THIS ROW
            generated = torch.cat((generated, next_token), dim=1)


out = generated
out = out[:, len(context_tokens):].tolist()
for o in out:
    text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
    print(text)

The output is the following:

a hand for 1-10 minutes. However, we had recently seen that a small set of tasks can be used to process many different languages in a short period of time. We had designed the program from scratch. The purpose of the program was to generate as many variables and as many basic rules as possible. Each rule got its own "factory". Each register gets its own "rules". The terms used are:<|endoftext|>Intel's #1-Buying Power-Technology

Obviously, you can change the seed and the temperature itself too!
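For illustration, here is a minimal self-contained sketch with toy logits (my own addition, assuming nothing beyond plain PyTorch) showing what the seed and the temperature do: fixing the seed makes the multinomial sampling reproducible, lowering the temperature sharpens the distribution toward the argmax, and raising it flattens the distribution so the text gets more random.

import torch
import torch.nn.functional as F

torch.manual_seed(42)  # fix the seed so the sampled tokens are reproducible

# Toy logits standing in for next_token_logits in the snippet above.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])

for temperature in (0.5, 0.8, 1.2):
    probs = F.softmax(logits / temperature, dim=-1)   # temperature-scaled softmax
    token = torch.multinomial(probs, num_samples=1)   # sample instead of argmax
    print(temperature, [round(p, 3) for p in probs.squeeze().tolist()], token.item())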

@aclifton314
Author

@TheEdoardo93 Thanks for the feedback! Closing this issue.

@drizzt00s

I just have the same issue. Does anyone know how to solve it? Thanks!

@aclifton314
Author

aclifton314 commented Sep 1, 2020

@drizzt00s Since this posting, HF has put out a fantastic blog post about generating text using different sampling methods. I highly recommend it; it's well written!

https://huggingface.co/blog/how-to-generate

Give that a read and see if it helps you out.
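For reference, a minimal sketch along the lines of that blog post, using model.generate (this assumes a more recent transformers release than the 2.1.1 in the original report, since these sampling arguments were added later):

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

input_ids = tokenizer.encode(
    'Natural language processing tasks are typically approached with',
    return_tensors='pt')

torch.manual_seed(0)
with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        max_length=100,
        do_sample=True,          # sample instead of greedy argmax
        top_k=50,                # keep only the 50 most likely tokens
        top_p=0.95,              # nucleus sampling over the top 95% of probability mass
        temperature=0.8,
        no_repeat_ngram_size=2)  # block repeated bigrams

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Sampling with top-k/top-p together with no_repeat_ngram_size is the blog post's suggested way of avoiding the kind of repetition shown above.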
