In [28]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from bertviz import model_view

import torch
import torch.nn.functional as F

In [42]:
# blimp minimal pairs
dataset = load_dataset("nyu-mll/blimp", "adjunct_island")["train"]

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

good_input = tokenizer(dataset[0]["sentence_good"], return_tensors="pt")
bad_input = tokenizer(dataset[0]["sentence_bad"], return_tensors="pt")



In [43]:
# before tokenization

print("Good sentence: ", dataset[0]["sentence_good"])
print("Bad sentence: ", dataset[0]["sentence_bad"])

Good sentence:  Who should Derek hug after shocking Richard?
Bad sentence:  Who should Derek hug Richard after shocking?


In [44]:
# after tokenization

print("Good sentence: ", good_input["input_ids"][0])
print("Bad sentence: ", bad_input["input_ids"][0])

Good sentence:  tensor([ 8241,   815, 20893, 16225,   706, 14702,  6219,    30])
Bad sentence:  tensor([ 8241,   815, 20893, 16225,  6219,   706, 14702,    30])


In [45]:
# attention head visualization
outputs = model(**good_input, output_attentions=True, output_hidden_states=True)

attention = outputs[-1]
tokens = tokenizer.convert_ids_to_tokens(good_input["input_ids"][0])
model_view(attention, tokens)

<IPython.core.display.Javascript object>

In [46]:
# get activations
hidden_states = outputs.hidden_states
print(len(hidden_states))  # 13 layers
print(hidden_states[-1].shape)  # (batch_size, seq_len, hidden_size)

13
torch.Size([1, 8, 768])


In [47]:
# generating a sentence
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

prompt = dataset[0]["sentence_good"][:20]
print("Prompt: ", prompt)
generated = generator(prompt, max_length=50, num_return_sequences=1)

print(generated[0]["generated_text"])

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Prompt:  Who should Derek hug
Who should Derek hug in the stadium? - Derek's wife.

And so does his girlfriend. (And I was joking, Derek was in the dugout wearing a T-shirt that read: "In the end Derek's mom is a


In [48]:
# generate using beam search
generated = generator(prompt, max_length=50, num_return_sequences=1, num_beams=5)
print("Prompt: ", prompt)

print(generated[0]["generated_text"])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Prompt:  Who should Derek hug
Who should Derek hug?"

"No, I'm not going to hug him."

"You're not going to hug him?"

"I'm not going to hug him."

"You're not going to hug him


In [49]:
# solving the question

# compute log likelihood of good
good_output = model(**good_input, labels=good_input['input_ids'], return_dict=True)
good_loss = good_output.loss

# compute log likelihood of bad
bad_output = model(**bad_input, labels=good_input['input_ids'], return_dict=True)
bad_loss = bad_output.loss

print("Good loss: ", good_loss.item())
print("Bad loss: ", bad_loss.item())

Good loss:  8.260390281677246
Bad loss:  9.049837112426758


In [50]:
def get_surprisal(logits, targets):
    probs = F.softmax(logits, dim=-1)
    surprisals = -torch.log2(probs)
    actual_surprisals = surprisals.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return actual_surprisals

print(dataset[0]["sentence_good"])
actual_surprisals = get_surprisal(good_output.logits, good_input["input_ids"])
print(actual_surprisals.detach().numpy()[0])

print(dataset[0]["sentence_bad"])
actual_surprisals = get_surprisal(bad_output.logits, bad_input["input_ids"])
print(actual_surprisals.detach().numpy()[0])

Who should Derek hug after shocking Richard?
[14.393618 12.609463 10.873037 13.561326 11.684442 15.142242  9.347549
 12.303379]
Who should Derek hug Richard after shocking?
[14.393618 12.609463 10.873037 13.561326 10.1263   12.133776 15.318865
 12.04072 ]
