In [1]:
import json
import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from scipy.stats import entropy
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display

# Load the dataset
def load_dataset(filepath):
    """Parse a JSON file and return the decoded Python object.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # JSON text is UTF-8; decode it explicitly instead of inheriting the
    # platform's default encoding, which would mangle non-ASCII sentences
    # on systems whose locale is not UTF-8.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

# Sentence-pair dataset, parsed from a JSON file given by relative path
dataset_path = "Dekeyser_test_dataset.json"
dataset = load_dataset(dataset_path)

# Pretrained "gpt2" tokenizer and language-model head. `.eval()` returns the
# module itself, so the model is put into evaluation mode as it is bound.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

# Function to compute the JSD between two probability distributions
def compute_jsd(p1, p2, base=None):
    """Jensen-Shannon divergence between two discrete distributions.

    Args:
        p1: Array-like of non-negative weights. It does not have to sum to 1;
            it is normalised along axis 0 before use.
        p2: Array-like of non-negative weights, broadcast-compatible with ``p1``.
        base: Logarithm base forwarded to ``scipy.stats.entropy``. ``None``
            (the default) means natural log, so the result lies in [0, ln 2].

    Returns:
        ``0.5 * KL(p1 || m) + 0.5 * KL(p2 || m)`` with ``m = 0.5 * (p1 + p2)``,
        computed on the normalised inputs.
    """
    p1 = np.asarray(p1, dtype=np.float64)
    p2 = np.asarray(p2, dtype=np.float64)

    # Normalise BEFORE mixing. `entropy` rescales each of its arguments
    # independently, so mixing raw weights would yield an `m` that is not the
    # midpoint of the two normalised distributions, and the result would no
    # longer be a Jensen-Shannon divergence.
    p1 = p1 / p1.sum(axis=0, keepdims=True)
    p2 = p2 / p2.sum(axis=0, keepdims=True)

    m = 0.5 * (p1 + p2)
    return 0.5 * (entropy(p1, m, base=base) + entropy(p2, m, base=base))

# Function to compute token-level probabilities
def compute_token_probabilities(model, tokenizer, sentence):
    """Return the model's softmax distribution at every position of ``sentence``.

    Args:
        model: Callable; ``model(input_ids)`` must return an object exposing
            ``logits`` (indexed here as [batch, position, vocabulary]).
        tokenizer: Callable; ``tokenizer(sentence, return_tensors="pt")`` must
            return a mapping with an ``"input_ids"`` entry.
        sentence: Text to run through the model.

    Returns:
        A list with one NumPy array per input token. Entry ``i`` is the softmax
        over the last logits axis at position ``i`` of the first batch row.
    """
    input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"]

    with torch.no_grad():
        # Only the logits are used, so no labels are supplied: requesting the
        # loss would add computation whose result is thrown away.
        logits = model(input_ids).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)  # logits -> probabilities

    return [probs[0, i].numpy() for i in range(input_ids.shape[1])]

# Compute GPT-2's prior probability distribution using the dataset
def compute_prior_distribution_from_dataset(model, tokenizer, dataset):
    """Average the model's per-position softmax outputs over the whole dataset.

    For each item, the texts at ``item["sentences"][0]`` and
    ``item["sentences"][1]`` are run through the model; the softmax
    distributions at all positions are summed and the total is normalised so
    that it sums to 1.

    Args:
        model: Callable; ``model(input_ids)`` must return an object exposing
            ``logits`` (indexed here as [batch, position, vocabulary]).
        tokenizer: Callable; ``tokenizer(sentence, return_tensors="pt")`` must
            return a mapping with an ``"input_ids"`` entry.
        dataset: Iterable of items shaped as described above.

    Returns:
        1-D float64 NumPy array whose length equals the last logits dimension.

    Raises:
        ValueError: If no probability mass was accumulated (e.g. ``dataset``
            is empty), since the normalisation would otherwise be 0/0.
    """
    token_mass = None

    for item in dataset:
        pair = item["sentences"]
        for sentence in (pair[0]["text"], pair[1]["text"]):
            input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"]
            with torch.no_grad():
                logits = model(input_ids).logits
                probs = torch.nn.functional.softmax(logits, dim=-1)

            # Reduce over positions in one call, in float64 so accumulation
            # precision matches a float64 NumPy accumulator.
            sentence_mass = probs[0].double().sum(dim=0).numpy()

            if token_mass is None:
                # Size the accumulator from the logits themselves: the
                # tokenizer's `vocab_size` can differ from the model's output
                # width, which would make the addition fail to broadcast.
                token_mass = np.zeros_like(sentence_mass)
            token_mass += sentence_mass

    if token_mass is None or not token_mass.sum() > 0:
        raise ValueError("Cannot build a prior distribution: no token probabilities were accumulated.")

    return token_mass / np.sum(token_mass)

# Computed once at module level: the per-pair scoring loop further down compares
# every per-position distribution against this same reference distribution.
prior_distribution = compute_prior_distribution_from_dataset(model, tokenizer, dataset)

def compute_perplexity(model, tokenizer, sentence):
    """Return ``exp(loss)`` for ``sentence`` as a Python float.

    The sentence is tokenized, the resulting ids are fed to the model as both
    input and labels, and the ``loss`` attribute of the model's output is
    exponentiated.
    """
    encoded = tokenizer(sentence, return_tensors="pt")
    ids = encoded["input_ids"]

    # No gradients are needed for scoring.
    with torch.no_grad():
        outputs = model(ids, labels=ids)

    mean_loss = outputs.loss
    return mean_loss.exp().item()

def _mean_jsd_to_prior(sentence):
    """Mean JSD between each per-position distribution of ``sentence`` and ``prior_distribution``."""
    per_position = compute_token_probabilities(model, tokenizer, sentence)
    return np.mean([compute_jsd(dist, prior_distribution) for dist in per_position])


def _score_pair(item):
    """Build one result row for a dataset item.

    The item's ``sentences[1]`` text is recorded as the grammatical sentence and
    ``sentences[0]`` as the ungrammatical one.
    """
    pair_id = item["pair_id"]
    good_text = item["sentences"][1]["text"]
    bad_text = item["sentences"][0]["text"]

    good_jsd = _mean_jsd_to_prior(good_text)
    bad_jsd = _mean_jsd_to_prior(bad_text)

    return {
        "Pair ID": pair_id,
        "Grammatical Sentence": good_text,
        "Ungrammatical Sentence": bad_text,
        "Grammatical JSD": good_jsd,
        "Ungrammatical JSD": bad_jsd,
        "Grammatical Perplexity": compute_perplexity(model, tokenizer, good_text),
        "Ungrammatical Perplexity": compute_perplexity(model, tokenizer, bad_text),
    }


# One row per sentence pair, in dataset order
results = [_score_pair(item) for item in dataset]
results_df = pd.DataFrame(results)

# A pair "matches expectation" when the grammatical sentence has the strictly
# smaller mean JSD.
results_df["Matches Expectation"] = results_df["Grammatical JSD"].lt(results_df["Ungrammatical JSD"])

# Keep only the pairs that go against that expectation and show them
is_unexpected = ~results_df["Matches Expectation"]
unexpected_cases = results_df.loc[is_unexpected]
print(f"Number of unexpected cases: {len(unexpected_cases)}")
display(unexpected_cases)

  from .autonotebook import tqdm as notebook_tqdm


Number of unexpected cases: 69


Unnamed: 0,Pair ID,Grammatical Sentence,Ungrammatical Sentence,Grammatical JSD,Ungrammatical JSD,Grammatical Perplexity,Ungrammatical Perplexity,Matches Expectation
0,1,Last night the old lady died in her sleep.,Last night the old lady die in her sleep.,0.397543,0.374149,38.774673,98.982780,False
1,2,Sandy filled a jar with cookies last night.,Sandy fill a jar with cookies last night.,0.400852,0.376231,240.454773,478.766113,False
2,3,John sang for the church choir yesterday.,John sing for the church choir yesterday.,0.333753,0.300420,283.251495,288.029816,False
3,4,Janie slept with her teddy bear last night.,Janie sleeped with her teddy bear last night.,0.430241,0.411652,47.483559,69.506531,False
4,5,Last night the books fell off the shelves.,Last night the books falled off the shelves.,0.388130,0.359499,56.166302,146.006012,False
...,...,...,...,...,...,...,...,...
92,93,The students went to the movies.,The students to the movies went.,0.381220,0.361611,69.671387,457.836456,False
93,94,The children play with the dog.,The children with the dog play.,0.356222,0.355248,55.560009,336.007965,False
95,96,The student eats his meals quickly.,The student eats quickly his meals.,0.359900,0.347272,235.834320,908.318848,False
96,97,Kevin usually rides his bicycle to work.,Kevin rides usually his bicycle to work.,0.381580,0.326178,54.441608,219.733307,False
