In [1]:
import torch

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessorList,
    StoppingCriteriaList
)

from transformers_controllers import (
    SuffixCriteria,
    GoodPhrasesLogitsProcessor,
    ConstantLogitsWarper
)

In [2]:
# The tokenizer and the causal language model come from the same Hugging Face checkpoint.
checkpoint = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

In [3]:
def generate(
    prompt,
    seed,
    stopping_criteria=None,
    logits_processor=None,
    logits_warper=None,
    max_length=30
):
    """Sample a continuation of `prompt` with the module-level `model`/`tokenizer`.

    Args:
        prompt: Text to continue.
        seed: Passed to `torch.manual_seed` immediately before sampling, so
            calls that share a seed (and other arguments) draw the same
            random numbers and their outputs can be compared.
        stopping_criteria: Optional `StoppingCriteriaList`, forwarded to
            `model.sample`.
        logits_processor: Optional `LogitsProcessorList`, forwarded to
            `model.sample`.
        logits_warper: Optional `LogitsProcessorList` of warpers, forwarded
            to `model.sample`.
        max_length: Forwarded to `model.sample` as its length cap
            (presumably counted including the prompt tokens -- confirm
            against the pinned transformers version).

    Returns:
        The first sampled sequence decoded back to text (the prompt is part
        of it), with special tokens stripped.

    NOTE(review): `model.sample` is deprecated in recent transformers
    releases; confirm which version this notebook is pinned to.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Seed right before sampling so the random stream depends only on `seed`.
    torch.manual_seed(seed)
    with torch.no_grad():
        sampled = model.sample(
            prompt_ids,
            max_length=max_length,
            # Pad with the EOS id (presumably because GPT-2 defines no pad token).
            pad_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
            logits_warper=logits_warper,
            logits_processor=logits_processor,
        )

    first_sequence = sampled[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

In [4]:
# Stop generating as soon as the text ends with any of these punctuation marks.
# Each suffix is given to SuffixCriteria as its list of GPT-2 token ids.
stopping_criteria = StoppingCriteriaList([
    SuffixCriteria(list(map(tokenizer.encode, ['.', '!', '?', '...'])))
])

In [5]:
# Restrict the generated continuation to this small vocabulary; every entry
# is encoded to its GPT-2 token ids before being handed to the processor.
logits_processor = LogitsProcessorList([
    GoodPhrasesLogitsProcessor(list(map(tokenizer.encode, [
        # Punctuation.
        '!', ',', '.',
        # Words keep their leading space: the GPT-2 tokenizer folds the
        # preceding whitespace into the word's token.
        ' saw', ' morning', ' bird',
        ' lion', ' I', ' a', ' the', ' in',
    ])))
])

In [6]:
# One constant per vocabulary id for ConstantLogitsWarper; entries left at 0.0
# are presumably no-ops (assumes the warper adds `deltas` to the scores --
# TODO confirm against transformers_controllers).
# Size the vector from the model config, not `tokenizer.vocab_size`: the
# warper operates on the model's logits, whose last dimension is the model's
# vocabulary size. Some checkpoints pad their embedding matrix so the two
# numbers differ (for 'gpt2' they are equal), which would make the shapes
# mismatch.
deltas = torch.zeros(model.config.vocab_size)
# Give lion higher chance to appear.
deltas[tokenizer.encode(' lion')] = 2.5
# Also try to end the sentence earlier.
deltas[tokenizer.encode('.')] = 2.5

logits_warper = LogitsProcessorList([
    ConstantLogitsWarper(deltas)
])

In [7]:
# One fixed seed reused by every generate() call below, so the uncontrolled and controlled runs draw the same random numbers.
seed = 256

In [8]:
# Opening text that every generate() call below continues.
prompt = (
    'This morning, when I was walking in the park, '
    'I looked up and'
)
prompt

'This morning, when I was walking in the park, I looked up and'

In [9]:
# Baseline: plain GPT-2 sampling, capped only by generate()'s default max_length.
generate(prompt, seed=seed)

'This morning, when I was walking in the park, I looked up and saw a Rita Skeeter painting. I was wearing my suit that day in'

In [10]:
# Same seed, but cut the text off as soon as the first sentence ends.
generate(prompt, seed, stopping_criteria=stopping_criteria)

'This morning, when I was walking in the park, I looked up and saw a Rita Skeeter painting.'

In [11]:
# Additionally restrict the output to the allowed phrase list.
generate(
    prompt,
    seed,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
)

'This morning, when I was walking in the park, I looked up and saw a bird, I saw the bird.'

In [12]:
# I like lions: on top of the previous controls, bias sampling toward ' lion' and '.' via the constant warper.
generate(
    prompt,
    seed,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    logits_warper=logits_warper,
)

'This morning, when I was walking in the park, I looked up and saw a lion.'