In [1]:
import random

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
from tqdm import tqdm

In [9]:
import time

In [3]:
# Hugging Face checkpoint to load. The tokenizer and the weights are both loaded
# from this single name so they cannot drift out of sync when switching models.
MODEL_NAME = "gpt2"

# Load pre-trained model tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set padding token: reuse the EOS token for padding.
# NOTE(review): presumably needed because this tokenizer ships without a pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token

# Load pre-trained model (weights)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [4]:
# Switch the model to evaluation mode (eval() is exactly train(False)), turning off
# training-only behaviour such as dropout before sampling. train() returns the
# module itself, so `model` stays bound to the same object.
model = model.train(mode=False)

In [5]:
from chunk_continuation import chunk_continuation

In [6]:
# Each step of a generated chain: the fixed lead-in text (`prefix`), the options
# passed to `chunk_continuation` for that slot (`candidate_set`), and the text
# appended afterwards (`suffix`, a full stop for every step).
# NOTE(review): every prefix ends in a trailing space; whether that interacts badly
# with GPT-2's leading-space tokens depends on chunk_continuation — TODO confirm.
setup_sequence = [
    dict(prefix=lead_in, candidate_set=list(options), suffix='.')
    for lead_in, options in (
        ("My name is ", ("John", "Jane", "Alice", "Bob", "Charlie")),
        ("Outside it is ", ("sunny", "rainy", "snowy", "cold", "hot")),
        ("I should ", ("go to the gym", "go for a run outside")),
        ("This will ", ("improve my marathon record", "worsen my marathon record")),
    )
]

In [10]:
# Seed the RNGs so the sampled chains are reproducible across Restart & Run All.
# NOTE(review): which RNG `chunk_continuation` draws from is not visible here;
# torch and `random` are both seeded as a precaution — TODO confirm.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)

# perf_counter is monotonic, unlike time.time(), so it is the right clock for durations.
start_time = time.perf_counter()

n_samples = 100
samples = []

# Sampling only runs forward passes, so skip autograd bookkeeping (saves memory/time).
with torch.no_grad():
    for i in tqdm(range(n_samples)):
        # Build one chain: each step appends the next lead-in to the text sampled so
        # far and lets `chunk_continuation` extend it with one candidate + suffix.
        prefix = ""

        for setup in setup_sequence:
            # Separate steps with one space, but do not prepend a space to the first step.
            prefix = f"{prefix} {setup['prefix']}" if prefix else setup['prefix']

            sampled_text = chunk_continuation(model,
                                              tokenizer,
                                              prefix,
                                              setup['candidate_set'],
                                              setup['suffix'],
                                              sum=False,
                                              verbose=False)
            prefix = sampled_text

        # `prefix` holds the finished chain ("" if setup_sequence is empty), so this
        # never hits an undefined or stale `sampled_text`.
        samples.append(prefix)

print(f"Generated {n_samples} samples in {time.perf_counter() - start_time:.2f} seconds")

100%|██████████| 100/100 [01:43<00:00,  1.03s/it]

Generated 100 samples in 103.03 seconds





In [12]:
# Print the samples alphabetically so identical chains end up next to each other.
# `sorted` returns a new list, leaving `samples` (built by the sampling cell) in its
# original generation order instead of silently mutating it from a later cell.
for sample in sorted(samples):
    print(sample)

My name is Alice. Outside it is hot. I should go for a run outside. This will improve my marathon record.
My name is Alice. Outside it is hot. I should go for a run outside. This will worsen my marathon record.
My name is Alice. Outside it is hot. I should go to the gym. This will worsen my marathon record.
My name is Alice. Outside it is rainy. I should go for a run outside. This will improve my marathon record.
My name is Alice. Outside it is snowy. I should go for a run outside. This will improve my marathon record.
My name is Alice. Outside it is snowy. I should go for a run outside. This will worsen my marathon record.
My name is Alice. Outside it is snowy. I should go for a run outside. This will worsen my marathon record.
My name is Alice. Outside it is snowy. I should go to the gym. This will worsen my marathon record.
My name is Alice. Outside it is sunny. I should go to the gym. This will improve my marathon record.
My name is Alice. Outside it is sunny. I should go to the gy