In [30]:
# The goal of this notebook is to create table 1 and 2 for section 1 of my final paper
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
from collections import Counter, defaultdict
import pandas as pd
import numpy as np
import torch
import string


In [31]:
output_dir = "./model"

model = AutoModelForSequenceClassification.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(output_dir)

In [32]:
def clean_text(text: str):
    text = text.lower()
    text = text.translate(str.maketrans("","", string.punctuation))
    return text.split()

def clean_snli(example):
    example["premise"] = clean_text(example["premise"])
    example["hypothesis"] = clean_text(example["hypothesis"])
    example["combined"] = example["premise"] + example["hypothesis"]
    return example
dataset = load_dataset("snli")
dataset = dataset.filter(lambda ex: ex['label'] != -1)
dataset = dataset.map(clean_snli)

all_tokens = [token for example in dataset['train']['combined'] for token in example]

word_counts = Counter(all_tokens)
word_counts_df = pd.DataFrame(word_counts.items(), columns=["word", "count"])
word_counts_df = word_counts_df.sort_values(by="count", ascending=False).reset_index(drop=True)

In [33]:
# Count by label
label_word_counts = {
    "entailment": defaultdict(int),
    "contradiction": defaultdict(int),
    "neutral": defaultdict(int),
}

# Count word occurrences for each label
for example in dataset["train"]:
    label = example["label"]
    label_name = ["entailment", "neutral", "contradiction"][label]
    for word in example["combined"]:
        label_word_counts[label_name][word] += 1

entailment_counts = pd.DataFrame(list(label_word_counts["entailment"].items()), columns=["word", "entailment_count"])
contradiction_counts = pd.DataFrame(list(label_word_counts["contradiction"].items()), columns=["word", "contradiction_count"])
neutral_counts = pd.DataFrame(list(label_word_counts["neutral"].items()), columns=["word", "neutral_count"])

word_counts_df = word_counts_df.merge(entailment_counts, on="word", how="left").fillna(0)
word_counts_df = word_counts_df.merge(contradiction_counts, on="word", how="left").fillna(0)
word_counts_df = word_counts_df.merge(neutral_counts, on="word", how="left").fillna(0)

# Ensure counts are integers
word_counts_df[["entailment_count", "contradiction_count", "neutral_count"]] = word_counts_df[
    ["entailment_count", "contradiction_count", "neutral_count"]
].astype(int)

In [None]:
# Replicate table 1 from competency problems
filtered_vocab = word_counts_df[word_counts_df['count'] >= 20]['word'].tolist()

label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}
num_labels = len(label_map)

results = []
counter = 0
for word in filtered_vocab:
    if counter % 1000 == 0:
        print(f"Processing word {counter}: {word}")
    counter += 1
    premise_input = tokenizer(
        word, "", return_tensors="pt", truncation=True, padding="max_length", max_length=tokenizer.model_max_length
    )
    hypothesis_input = tokenizer(
        "", word, return_tensors="pt", truncation=True, padding="max_length", max_length=tokenizer.model_max_length
    )
    
    with torch.no_grad():
        premise_output = model(**premise_input).logits.softmax(dim=-1)
        hypothesis_output = model(**hypothesis_input).logits.softmax(dim=-1)
    
    avg_probabilities = (premise_output + hypothesis_output) / 2

    for label_idx in range(num_labels):
        results.append({
            "word": word,
            "label": label_map[label_idx],
            "p(y|x_i)": avg_probabilities[0, label_idx].item()
        })

results_df = pd.DataFrame(results)
results_df.head()

Processing word 0: a
tensor([[0.8114, 0.1111, 0.0775]])
tensor([[0.5439, 0.2411, 0.2150]])
tensor([[0.7641, 0.1420, 0.0939]])
tensor([[0.8343, 0.0998, 0.0659]])
tensor([[0.8906, 0.0882, 0.0213]])
tensor([[0.7344, 0.2101, 0.0555]])
tensor([[0.9243, 0.0586, 0.0171]])
tensor([[0.7730, 0.1157, 0.1113]])
tensor([[0.7874, 0.1282, 0.0844]])
tensor([[0.9124, 0.0605, 0.0270]])
tensor([[0.9405, 0.0471, 0.0123]])
tensor([[0.6556, 0.1194, 0.2250]])
tensor([[0.9669, 0.0301, 0.0031]])
tensor([[0.7669, 0.1632, 0.0699]])
tensor([[0.8287, 0.1369, 0.0344]])
tensor([[0.8267, 0.1176, 0.0557]])
tensor([[0.8703, 0.0877, 0.0420]])
tensor([[0.4915, 0.3797, 0.1289]])
tensor([[0.9311, 0.0657, 0.0031]])
tensor([[0.8160, 0.1696, 0.0144]])
tensor([[0.1227, 0.7308, 0.1465]])


KeyboardInterrupt: 

In [18]:
# Merge class-specific counts into results_df
results_df = results_df.merge(
    word_counts_df[['word', 'entailment_count', 'contradiction_count', 'neutral_count']],
    on='word',
    how='left'
)

# Map the correct count column based on the label
def get_class_count(row):
    if row['label'] == "entailment":
        return row['entailment_count']
    elif row['label'] == "contradiction":
        return row['contradiction_count']
    elif row['label'] == "neutral":
        return row['neutral_count']
    return 0

# Apply the function to get the class-specific count
results_df['class_count'] = results_df.apply(get_class_count, axis=1)

# Handle cases where class counts are missing or zero (e.g., fill NaN with 1 to avoid division by zero)
results_df['class_count'] = results_df['class_count'].fillna(1)

# Compute z* conditionally
results_df['z*'] = (results_df['p(y|x_i)'] - (1/3)) / np.sqrt((1/3) * (1 - (1/3)) / results_df['class_count'])

# Preview the results
print(results_df[['word', 'label', 'class_count', 'z*']].head())


  word          label  class_count         z*
0    a     entailment       482788   0.378976
1    a        neutral       481725   3.060789
2    a  contradiction       474366  -3.412976
3  the     entailment       149005 -11.322093
4  the        neutral       193537  20.375230


In [22]:

z_star_df = results_df.pivot(index="word", columns="label", values="z*")
strongest_classes = z_star_df.idxmax(axis=1)  # Get the label with the highest z*
results_df['strongest_class'] = results_df['word'].map(strongest_classes)
# Step 3: Partition words by strongest class and z*
final_results = []

for label in label_map.values():
    # Filter words belonging to the current class
    class_words = results_df[results_df['strongest_class'] == label]
    
    # Get the top 20 (High group) and bottom 20 (Low group) by z*
    high_group = class_words.nlargest(100, "z*")
    low_group = class_words.nsmallest(100, "z*")
    # Step 4: Calculate Δp_y for the class
    delta_p_y = (high_group['p(y|x_i)'].sum() - low_group['p(y|x_i)'].sum())
    
    # Store results
    final_results.append({
        "Dataset": "SNLI",
        "Class": label,
        "Δp_y": f"{delta_p_y:.1f} %"
    })

# Convert to DataFrame
delta_p_y_df = pd.DataFrame(final_results)

# Display the resulting table
print(delta_p_y_df)

  Dataset          Class   Δp_y
0    SNLI     entailment  2.3 %
1    SNLI        neutral  3.6 %
2    SNLI  contradiction  0.0 %


In [26]:
contradict_words = results_df[results_df['label'] == 'contradiction']
contradict_words = contradict_words.sort_values(by='p(y|x_i)', ascending=False)
print(contradict_words.head(20))

               word          label  p(y|x_i)  entailment_count  \
20279        llamas  contradiction  0.338507                10   
20              and  contradiction  0.338192             67789   
2816           also  contradiction  0.336951               304   
9458           then  contradiction  0.336937                39   
12284        waring  contradiction  0.336458                34   
9473          stirs  contradiction  0.335840                53   
50               an  contradiction  0.334499             27034   
7436   crosscountry  contradiction  0.334495                73   
15194      sideline  contradiction  0.334075                25   
11186   crosslegged  contradiction  0.334070                34   
21596    brickpaved  contradiction  0.334032                11   
22478        sbarro  contradiction  0.333996                10   
4781         theres  contradiction  0.333953               263   
17570  bespectacled  contradiction  0.333882                23   
23855  fin

In [27]:
print(contradict_words.tail(20))

              word          label  p(y|x_i)  entailment_count  \
1337          wood  contradiction  0.309287               890   
22715     snowfall  contradiction  0.309275                 8   
17714      baggage  contradiction  0.309242                17   
9617       cabinet  contradiction  0.309106                47   
4256           eye  contradiction  0.308960               187   
24416       timber  contradiction  0.308836                 9   
25559       barney  contradiction  0.308694                 6   
323           snow  contradiction  0.308648              4257   
11954       rubble  contradiction  0.308645                34   
17162     aluminum  contradiction  0.308627                15   
24287       deeper  contradiction  0.308488                 9   
6677        debris  contradiction  0.308170                94   
3710       luggage  contradiction  0.308092               231   
5279          most  contradiction  0.307760               109   
8858     paperwork  contr

In [28]:
entailment_words = results_df[results_df['label'] == 'entailment']
entailment_words = entailment_words.sort_values(by='p(y|x_i)', ascending=False)
print(entailment_words.head(10))

             word       label  p(y|x_i)  entailment_count  \
3981          you  entailment  0.349344               198   
25278    tickling  entailment  0.348447                 6   
24573  blossoming  entailment  0.348210                 8   
4686        heart  entailment  0.347038               148   
25293       fiery  entailment  0.346225                 7   
20898    scowling  entailment  0.345836                 9   
22200      peeing  entailment  0.345828                 6   
10854       loose  entailment  0.345824                29   
16449        pong  entailment  0.345609                10   
636          long  entailment  0.345585              1888   

       contradiction_count  neutral_count  class_count        z*  \
3981                   216            193          198  0.477906   
25278                    8             10            6  0.078534   
24573                    8              9            8  0.089258   
4686                   150            184          148  

In [29]:
print(entailment_words.tail(10))

               word       label  p(y|x_i)  entailment_count  \
20328     treehouse  entailment  0.305608                 0   
26961      landmass  entailment  0.305592                 8   
24336      bassinet  entailment  0.305349                 8   
23352         wasnt  entailment  0.304985                 1   
11670          mean  entailment  0.304862                33   
26331    handprints  entailment  0.304469                 8   
14553         welds  entailment  0.304425                23   
21591   mediumsized  entailment  0.302711                12   
27261       bobsled  entailment  0.301849                 3   
23853  fingerpaints  entailment  0.300704                 7   

       contradiction_count  neutral_count  class_count        z*  \
20328                   19             20            0 -0.000000   
26961                    7              6            8 -0.166450   
24336                    9              9            8 -0.167904   
23352                   21        