In [34]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch.nn.functional as F

In [70]:
model_name = 'gpt2-medium'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [59]:
words = [
	'Platypus',
	'Penguin',
	'Whale',
	'Dachshund',
	'Umami',
]

riddles = [
	("Which arctic bird doesn't fly?", "Penguin"),
	("Which mammal lays eggs?", "Platypus"),
	("What is the newest found taste?", "Umami"),
	("What is the largest animal on earth?", "Whale"),
	("Which dog breed has the lowest chassis?", "Dachshund")
]

tokenized_words = list(map(lambda x: tokenizer.tokenize(x), words))
for tw in tokenized_words:
	print(tw)

['Pl', 'at', 'yp', 'us']
['P', 'engu', 'in']
['Wh', 'ale']
['D', 'ach', 'sh', 'und']
['Um', 'ami']


In [60]:
def autoregressive_word_probability(model, tokenizer, context, word):
    context_ids = tokenizer.encode(context, return_tensors="pt")
    word_ids = tokenizer.encode(word, add_special_tokens=False)

    total_logprob = 0.0  # log-space is safer numerically
    current_input = context_ids

    with torch.no_grad():
        for tok_id in word_ids:
            outputs = model(current_input)
            logits = outputs.logits
            probs = F.softmax(logits[:, -1, :], dim=-1)

            p = probs[0, tok_id].item()
            total_logprob += torch.log(torch.tensor(p)).item()

            next_id = torch.tensor([[tok_id]])
            current_input = torch.cat([current_input, next_id], dim=1)

    return total_logprob  # return log probability

In [61]:
def choose_one_word(model, tokenizer, context, words):
    logps = torch.tensor([autoregressive_word_probability(model, tokenizer, context, w)
                          for w in words])

    probs = F.softmax(logps, dim=0)

    best_idx = torch.argmax(probs).item()
    best_word = words[best_idx]

    return best_word, probs.tolist()

In [71]:
correct_answers = 0
for riddle, answer in riddles:
	best_word, probs = choose_one_word(model, tokenizer, riddle, words)
	correct = "Correct" if answer == best_word else "Incorrect"
	if correct == "Correct":
		correct_answers += 1
	print(riddle, best_word, correct)
print(f"{correct_answers/len(riddles)*100}%")

Which arctic bird doesn't fly? Penguin Correct
Which mammal lays eggs? Platypus Correct
What is the newest found taste? Penguin Incorrect
What is the largest animal on earth? Penguin Incorrect
Which dog breed has the lowest chassis? Dachshund Correct
60.0%
