# Pipeline for Test Datasets
- ### Metric: **Jensen-Shannon Divergence**
JSD measures the similarity between the predicted probability distribution of the model and a uniform distribution (representing uncertainty). For grammatical sentences, the model's distribution should be sharper (less entropy), while for ungrammatical ones, it might be flatter.

#### Import Libraries

In [1]:
import json
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
import numpy as np
from scipy.spatial.distance import jensenshannon
from IPython.display import display
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


### Load GPT-2 Model and Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

### Compute Token Probabilities

In [3]:
def compute_token_probs(sentence):
    inputs = tokenizer(sentence, return_tensors="pt")
    input_ids = inputs["input_ids"]
    with torch.no_grad():
        logits = model(input_ids).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs, input_ids

### Compute Jensen-Shannon Divergence

In [4]:
def compute_jsd(sentence):
    probs, input_ids = compute_token_probs(sentence)
    jsd_scores = []
    for i, token_id in enumerate(input_ids[0]):
        token_probs = probs[0, i].numpy()
        uniform_probs = np.ones_like(token_probs) / len(token_probs)  # Uniform distribution
        jsd = jensenshannon(token_probs, uniform_probs)
        jsd_scores.append(jsd)
    return np.mean(jsd_scores), jsd_scores  # Return mean JSD and token-level JSDs

### Load testing dataset

In [5]:
with open("DeKeyser_test_dataset.json", "r") as file:
    dataset = json.load(file)

### Application

In [6]:
# Compute JSD for grammatical and ungrammatical sentences
result_table = []
for pair in dataset:
    grammatical = [s for s in pair["sentences"] if s["is_grammatical"]][0]
    ungrammatical = [s for s in pair["sentences"] if not s["is_grammatical"]][0]
    
    grammatical_jsd, grammatical_token_jsds = compute_jsd(grammatical["text"])
    ungrammatical_jsd, ungrammatical_token_jsds = compute_jsd(ungrammatical["text"])
    
    result_table.append({
        "Pair ID": pair["pair_id"],
        "Grammatical Sentence": grammatical["text"],
        "Grammatical JSD": grammatical_jsd,
        "Ungrammatical Sentence": ungrammatical["text"],
        "Ungrammatical JSD": ungrammatical_jsd,
        "Matches Expectation": grammatical_jsd < ungrammatical_jsd,
        "Grammatical Token JSDs": grammatical_token_jsds,
        "Ungrammatical Token JSDs": ungrammatical_token_jsds,
    })

# Create DataFrame for results
results_df = pd.DataFrame(result_table)

# Filter for pairs where grammatical JSD is higher (unexpected cases)
unexpected_results_df = results_df[~results_df["Matches Expectation"]]

# Display flagged cases
display(unexpected_results_df)

Unnamed: 0,Pair ID,Grammatical Sentence,Grammatical JSD,Ungrammatical Sentence,Ungrammatical JSD,Matches Expectation,Grammatical Token JSDs,Ungrammatical Token JSDs
0,1,Last night the old lady died in her sleep.,0.748002,Last night the old lady die in her sleep.,0.737838,False,"[0.5929657625668425, 0.7890130962353438, 0.665...","[0.5929657625668425, 0.7890130962353438, 0.665..."
1,2,Sandy filled a jar with cookies last night.,0.755131,Sandy fill a jar with cookies last night.,0.743677,False,"[0.6052512731983859, 0.6670086301505495, 0.780...","[0.6052512731983859, 0.6670086301505495, 0.754..."
2,3,John sang for the church choir yesterday.,0.754765,John sing for the church choir yesterday.,0.738001,False,"[0.6095251485269355, 0.7799112675398373, 0.740...","[0.6095251485269355, 0.7497206803580374, 0.715..."
3,4,Janie slept with her teddy bear last night.,0.768716,Janie sleeped with her teddy bear last night.,0.763060,False,"[0.6340746361110158, 0.6917208155344181, 0.808...","[0.6340746361110158, 0.6917208370765788, 0.757..."
4,5,Last night the books fell off the shelves.,0.756129,Last night the books falled off the shelves.,0.753549,False,"[0.5929657625668425, 0.7890130962353438, 0.665...","[0.5929657625668425, 0.7890130962353438, 0.665..."
...,...,...,...,...,...,...,...,...
93,94,The children play with the dog.,0.734208,The children with the dog play.,0.720219,False,"[0.5628175369028834, 0.7864179200826751, 0.792...","[0.5628175369028834, 0.7864179200826751, 0.707..."
94,95,All our friends live in the suburbs.,0.737442,All our friends in the suburbs live.,0.730648,False,"[0.5777168379592227, 0.6744470159067714, 0.794...","[0.5777168379592227, 0.6744470159067714, 0.794..."
95,96,The student eats his meals quickly.,0.751141,The student eats quickly his meals.,0.747511,False,"[0.5628175369028834, 0.7451835792730791, 0.777...","[0.5628175369028834, 0.7451835792730791, 0.777..."
96,97,Kevin usually rides his bicycle to work.,0.759006,Kevin rides usually his bicycle to work.,0.752962,False,"[0.6105801329825838, 0.7681168468716189, 0.777...","[0.6105801329825838, 0.7769650250765865, 0.758..."


In [7]:
# Debugging: Print token-level JSDs for the first few flagged pairs
for idx, row in unexpected_results_df.iterrows():
    print(f"Pair ID: {row['Pair ID']}")
    print(f"Grammatical Sentence: {row['Grammatical Sentence']}")
    print(f"Grammatical Token JSDs: {row['Grammatical Token JSDs']}")
    print(f"Ungrammatical Sentence: {row['Ungrammatical Sentence']}")
    print(f"Ungrammatical Token JSDs: {row['Ungrammatical Token JSDs']}")
    print()

Pair ID: 1
Grammatical Sentence: Last night the old lady died in her sleep.
Grammatical Token JSDs: [0.5929657625668425, 0.7890130962353438, 0.6653784559462192, 0.6787972155284026, 0.7732051924344171, 0.814472591615756, 0.7828692630738519, 0.7872344265557647, 0.8150287670488054, 0.7810584405654821]
Ungrammatical Sentence: Last night the old lady die in her sleep.
Ungrammatical Token JSDs: [0.5929657625668425, 0.7890130962353438, 0.6653784559462192, 0.6787972155284026, 0.7732051924344171, 0.7755379716983298, 0.763733102826942, 0.7733989282102929, 0.798975700047096, 0.7673768223630422]

Pair ID: 2
Grammatical Sentence: Sandy filled a jar with cookies last night.
Grammatical Token JSDs: [0.6052512731983859, 0.6670086301505495, 0.7806995741665034, 0.7264888274875487, 0.8086875824226935, 0.734310390795689, 0.8116128920405893, 0.8183274001109079, 0.817326661144402, 0.7815997294604127]
Ungrammatical Sentence: Sandy fill a jar with cookies last night.
Ungrammatical Token JSDs: [0.6052512731983

In [18]:
import json
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Load the test dataset
with open('DeKeyser_test_dataset.json', 'r') as f:
    test_dataset = json.load(f)

# Load the GPT-2 model and tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

# Function to compute the grammaticality score for a sentence
def get_grammaticality_score(model, tokenizer, text):
    input_ids = tokenizer.encode(text, return_tensors='pt')
    with torch.no_grad():
        output = model(input_ids)
        logits = output.logits  # Token-level logits
        probs = torch.nn.functional.softmax(logits, dim=-1)  # Convert logits to probabilities
        
        # Extract log probabilities of the correct tokens
        log_probs = []
        for i, token_id in enumerate(input_ids[0][1:]):  # Skip the first token
            log_probs.append(torch.log(probs[0, i, token_id]).item())
        
        # Compute the mean log probability as the score
        return sum(log_probs) / len(log_probs)

# Evaluate the test dataset
for item in test_dataset:
    grammatical_score = get_grammaticality_score(model, tokenizer, item['sentences'][1]['text'])
    ungrammatical_score = get_grammaticality_score(model, tokenizer, item['sentences'][0]['text'])

    print(f"Pair ID: {item['pair_id']}")
    print(f"Grammatical score: {grammatical_score:.3f}")
    print(f"Ungrammatical score: {ungrammatical_score:.3f}")
    print(f"Grammatical > Ungrammatical: {grammatical_score > ungrammatical_score}")

Pair ID: 1
Grammatical score: -3.658
Ungrammatical score: -4.595
Grammatical > Ungrammatical: True
Pair ID: 2
Grammatical score: -5.483
Ungrammatical score: -6.171
Grammatical > Ungrammatical: True
Pair ID: 3
Grammatical score: -5.646
Ungrammatical score: -5.663
Grammatical > Ungrammatical: True
Pair ID: 4
Grammatical score: -3.860
Ungrammatical score: -4.241
Grammatical > Ungrammatical: True
Pair ID: 5
Grammatical score: -4.028
Ungrammatical score: -4.984
Grammatical > Ungrammatical: True
Pair ID: 6
Grammatical score: -5.759
Ungrammatical score: -6.657
Grammatical > Ungrammatical: True
Pair ID: 7
Grammatical score: -5.348
Ungrammatical score: -6.079
Grammatical > Ungrammatical: True
Pair ID: 8
Grammatical score: -4.975
Ungrammatical score: -5.440
Grammatical > Ungrammatical: True
Pair ID: 9
Grammatical score: -4.227
Ungrammatical score: -5.543
Grammatical > Ungrammatical: True
Pair ID: 10
Grammatical score: -3.771
Ungrammatical score: -4.173
Grammatical > Ungrammatical: True
Pair ID: 