In [None]:
import transformers
import torch
import os
import random
import re

model_name = "gpt2-xl"  # pretrained checkpoint name passed to from_pretrained for both the model and the tokenizer


In [None]:
# Go up one level to the project root so the relative "data/..." paths used below resolve.
# Guarded so that re-running this cell (or Run All after it already ran) does not keep
# climbing the directory tree.
if not os.path.isdir(os.path.join("data", "starting_words")):
    os.chdir("..")
os.getcwd()

In [None]:
# Load GPT-2 XL and its tokenizer (GPT2Tokenizer is already the slow, pure-Python implementation).
# torch.no_grad() is intentionally not used here: it only affects operations executed inside the
# context, so wrapping from_pretrained had no effect on later forward passes. Gradient tracking is
# disabled where inference actually runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = transformers.GPT2LMHeadModel.from_pretrained(model_name).to(device)
model.eval()  # from_pretrained is documented to return eval mode already; made explicit for readers
tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)


In [None]:
# Load the starting-word lists (one word or phrase per line). The opening options shown to the
# player are sampled from the combined pool in the game cell.
words_generic_path = "data/starting_words/words_generic.txt"
words_catlike_path = "data/starting_words/words_catlike.txt"
words_creative_path = "data/starting_words/words_creative.txt"


def _read_word_list(path):
    """Return the whitespace-stripped, non-blank lines of a UTF-8 text file.

    Blank lines are dropped so that an empty string can never be offered as a choice.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


words_generic = _read_word_list(words_generic_path)
words_catlike = _read_word_list(words_catlike_path)
words_creative = _read_word_list(words_creative_path)

# Combine all the lists into one pool; 3 of these are sampled as the opening choices later.
words = words_generic + words_catlike + words_creative
assert len(words) >= 3, f"need at least 3 starting words to offer 3 options, found {len(words)}"


In [None]:
# Compute the banned-token id tensors used by the two sampling stages. A token is banned when its
# decoded text contains any character outside the whitelist:
#   * banned_tokens_general_ids: allows letters, digits, comma, space, period (continuing a word)
#   * banned_tokens_first_ids:   allows letters, digits, space (first token of a new word)
# Ids are taken directly from the vocabulary index. Recovering them by re-encoding the decoded
# text is lossy: byte-fragment tokens decode to U+FFFD and re-encode to a *different* id, so those
# tokens would never actually be banned.


def _banned_token_ids(decoded_tokens, allowed_pattern):
    """Return a LongTensor of the ids whose decoded text is not entirely matched by allowed_pattern.

    ``fullmatch`` is used because ``match`` with a ``$`` anchor also accepts a trailing newline.
    """
    return torch.tensor(
        [
            token_id
            for token_id, text in enumerate(decoded_tokens)
            if not allowed_pattern.fullmatch(text)
        ],
        dtype=torch.long,
    )


# Decoded text of every vocabulary entry; index == token id. len(tokenizer) covers the whole
# vocabulary including special tokens (50257 for GPT-2).
all_tokens = [tokenizer.decode(token_id) for token_id in range(len(tokenizer))]

# Alphanumeric, comma, space, period
regex_general = re.compile('^[a-zA-Z0-9, .]+$')
banned_tokens_general_ids = _banned_token_ids(all_tokens, regex_general)

# Same, but only alphanumeric and space
regex_first = re.compile('^[a-zA-Z0-9 ]+$')
banned_tokens_first_ids = _banned_token_ids(all_tokens, regex_first)


In [None]:
def sample_top_k_without_replacement(model, tokenizer, text, banned_tokens_ids, top_k=10, num_samples=3):
    """Sample ``num_samples`` distinct next tokens for ``text`` from the model's top-k.

    Computes the model's next-token distribution at the last position of ``text`` and gives
    banned tokens probability 0. It then draws ``num_samples`` tokens without replacement from
    the ``top_k`` most probable tokens, weighted by their probabilities.

    Args:
        model: causal LM whose call on a (1, seq) id tensor returns an object with ``.logits``
            of shape (1, seq, vocab), as GPT2LMHeadModel does. Input ids are moved to
            ``model.device``.
        tokenizer: provides ``encode(text, return_tensors="pt")``, ``decode(id)`` and
            ``bos_token_id`` / ``eos_token_id``.
        text: prompt string. An empty prompt is replaced by the BOS token (falling back to EOS),
            because the forward pass needs at least one token.
        banned_tokens_ids: vocabulary ids (LongTensor or sequence of ints) that must never be sampled.
        top_k: number of most probable tokens to sample from.
        num_samples: number of distinct tokens to return. Must not exceed the number of top-k
            tokens with non-zero probability, otherwise ``torch.multinomial`` raises.

    Returns:
        list[str]: the decoded sampled tokens, in sampling order.
    """
    input_ids = tokenizer.encode(text, return_tensors="pt")
    if input_ids.shape[-1] == 0:
        start_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.eos_token_id
        input_ids = torch.tensor([[start_id]], dtype=torch.long)
    input_ids = input_ids.to(model.device)

    # Inference only: without no_grad every call would build an autograd graph through the model.
    with torch.no_grad():
        next_token_logits = model(input_ids).logits[0, -1]
        probs = torch.softmax(next_token_logits.float(), dim=-1)
        # Zero out banned tokens (accepts a tensor or a plain list; moved to the probs' device)
        banned = torch.as_tensor(banned_tokens_ids, dtype=torch.long, device=probs.device)
        probs[banned] = 0

        # Most probable candidates, then a weighted draw of num_samples distinct ones among them
        top = torch.topk(probs, top_k)
        sampled_positions = torch.multinomial(top.values, num_samples, replacement=False)
        sampled_ids = top.indices[sampled_positions].tolist()

    return [tokenizer.decode(token_id) for token_id in sampled_ids]

In [None]:
def generate_entire_word_simple(model, tokenizer, context, current_word, banned_tokens_ids, top_k=10, max_tokens=5):
    """Extend ``current_word`` token by token until the model starts a new word.

    This is the simple case, where we already have a starter token: use it when generating a word
    from scratch (``current_word`` is the model-proposed first token). It is not meant for
    continuing a word the user is part-way through typing.

    At most ``max_tokens`` continuation tokens are sampled, one at a time, via
    ``sample_top_k_without_replacement``. Generation stops early when a sampled token is empty or
    begins with a space, comma or period, because such a token starts the next word or punctuation
    rather than continuing this one. The stopping token is discarded.

    Args:
        model, tokenizer: forwarded to ``sample_top_k_without_replacement``.
        context: text preceding the word (e.g. the tweet so far).
        current_word: the word's first token(s); generation continues after ``context + current_word``.
        banned_tokens_ids: vocabulary ids that must never be sampled.
        top_k: size of the candidate pool for each sampled token.
        max_tokens: maximum number of continuation tokens to append.

    Returns:
        str: ``current_word`` followed by the accepted continuation tokens.
    """
    context = context + current_word
    for _ in range(max_tokens):
        next_token = sample_top_k_without_replacement(
            model, tokenizer, context, banned_tokens_ids, top_k=top_k, num_samples=1
        )[0]
        # A token that is empty or starts with a separator begins something new: stop here
        if not next_token or next_token[0] in (",", ".", " "):
            break
        # Otherwise it continues the current word, so keep it
        context += next_token
        current_word += next_token
    return current_word


In [None]:
# End-to-end game: repeatedly show candidate next words and append the player's pick.
# Enter an option number to take it, "q"/"quit" to stop, or any other text to append it verbatim
# (custom text is added exactly as typed, so include a leading space if you want one).
NUM_OPTIONS = 3            # candidate words shown each turn
TOP_K = 20                 # candidate pool size for every sampled token
QUIT_COMMANDS = {"q", "quit"}

curr_tweet = ""

# Pick the first options randomly from the starting-word lists
current_choices = random.sample(words, NUM_OPTIONS)

print("Type an option number to pick it, 'q' to quit, or any other text to add it yourself.")
while True:
    print(f"Current tweet: {curr_tweet}")
    print("Options:\n" + "\n".join(f" {i}. {choice}" for i, choice in enumerate(current_choices, start=1)))

    # Get the user's choice; surrounding whitespace is ignored when matching commands/options
    user_choice = input()
    command = user_choice.strip()
    if command.lower() in QUIT_COMMANDS:
        break

    option_keys = [str(i) for i in range(1, len(current_choices) + 1)]
    if command in option_keys:
        curr_tweet += current_choices[int(command) - 1]
    else:
        curr_tweet += user_choice

    if not curr_tweet:
        # Nothing to condition the model on yet (e.g. an empty first entry): keep the current options
        continue

    # Generate NUM_OPTIONS distinct first tokens, then finish generating each word
    first_next_tokens = sample_top_k_without_replacement(
        model, tokenizer, curr_tweet, banned_tokens_first_ids, top_k=TOP_K, num_samples=NUM_OPTIONS
    )
    current_choices = [
        generate_entire_word_simple(model, tokenizer, curr_tweet, token, banned_tokens_general_ids, top_k=TOP_K)
        for token in first_next_tokens
    ]

print(f"Final tweet: {curr_tweet}")
