In [1]:
from datasets import load_dataset
import string
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from collections import Counter, defaultdict
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import norm
from sklearn.metrics import accuracy_score
import lime
from lime.lime_text import LimeTextExplainer



  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# SNLI as a DatasetDict keyed by split name ("train", "validation", ...).
dataset = load_dataset(path="snli")

# Translation table that deletes ASCII punctuation. Built once here instead of
# on every call, since clean_text runs twice for every SNLI example.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def clean_text(text: str):
    """Lower-case ``text``, strip ASCII punctuation and split on whitespace.

    Returns the list of resulting tokens (an empty list for empty/blank text).
    Only characters in ``string.punctuation`` are removed; non-ASCII
    punctuation is left in place.
    """
    return text.lower().translate(_PUNCTUATION_TABLE).split()

def clean_snli(example):
    """Tokenise both sides of one SNLI example and add their concatenation.

    ``premise`` and ``hypothesis`` are overwritten with their cleaned token
    lists, and ``combined`` is set to the premise tokens followed by the
    hypothesis tokens. The same (mutated) mapping is returned.
    """
    premise_tokens = clean_text(example["premise"])
    example["premise"] = premise_tokens
    hypothesis_tokens = clean_text(example["hypothesis"])
    example["hypothesis"] = hypothesis_tokens
    example["combined"] = [*premise_tokens, *hypothesis_tokens]
    return example

# Drop pairs without a gold label (label == -1) first, so the cleaning map only
# runs on examples we keep. The filter reads only `label`, which clean_snli does
# not touch, so the result equals cleaning first and filtering afterwards.
dataset = dataset.filter(lambda ex: ex["label"] != -1)
dataset = dataset.map(clean_snli)


In [3]:
# Word–label association on the training split.
# For every word we count its occurrences under each gold label and run a
# one-sided z-test of H0: P(label | word) = 1/3 vs H1: P(label | word) > 1/3.
LABEL_NAMES = ["entailment", "neutral", "contradiction"]   # SNLI label ids 0, 1, 2
COUNT_LABELS = ["entailment", "contradiction", "neutral"]  # column order of the tables below
MIN_COUNT = 3        # ignore words seen fewer than this many times
FAMILY_ALPHA = 0.01  # family-wise significance level before Bonferroni correction
P_NULL = 1 / 3       # per-label probability under H0 (uniform over the 3 labels)

# One pass over the split: total and per-label token-occurrence counts.
# Reads the two columns once instead of iterating the Dataset row by row.
train_split = dataset["train"]
word_counts = Counter()
label_word_counts = {name: Counter() for name in COUNT_LABELS}
for label_id, tokens in zip(train_split["label"], train_split["combined"]):
    word_counts.update(tokens)
    label_word_counts[LABEL_NAMES[label_id]].update(tokens)

word_counts_df = (
    pd.DataFrame(word_counts.items(), columns=["word", "count"])
    .sort_values(by="count", ascending=False)
    .reset_index(drop=True)
)
print(word_counts_df.head())

for name in COUNT_LABELS:
    # Counter yields 0 for unseen words; fillna is only a guard.
    word_counts_df[f"{name}_count"] = (
        word_counts_df["word"].map(label_word_counts[name]).fillna(0).astype(int)
    )

word_counts_df = word_counts_df[word_counts_df["count"] >= MIN_COUNT].copy()

# Bonferroni: split the family-wise alpha across the tested words.
# NOTE(review): each word is tested against 3 labels, so a strict correction
# would divide by 3 * n_words; kept at n_words to match the original analysis.
n_words = word_counts_df.shape[0]
alpha = FAMILY_ALPHA / n_words
z_star = norm.ppf(1 - alpha)
print(z_star)
print(n_words)

# Family-wise level; the per-test threshold actually applied is `critical_z`.
word_counts_df["alpha"] = FAMILY_ALPHA
null_se = np.sqrt(P_NULL * (1 - P_NULL) / word_counts_df["count"])
for name in COUNT_LABELS:
    word_counts_df[f"p_hat_{name}"] = word_counts_df[f"{name}_count"] / word_counts_df["count"]
for name in COUNT_LABELS:
    word_counts_df[f"z_stat_{name}"] = (word_counts_df[f"p_hat_{name}"] - P_NULL) / null_se
word_counts_df["critical_z"] = z_star

# One row per (word, label) test.
long_df = pd.melt(
    word_counts_df,
    id_vars=["word", "count", "critical_z"],
    value_vars=[f"z_stat_{name}" for name in COUNT_LABELS],
    var_name="label",
    value_name="z_stat",
)
long_df["label"] = long_df["label"].map({f"z_stat_{name}": name for name in COUNT_LABELS})

# Keep the (word, label) pairs whose z-statistic exceeds the corrected threshold.
rejected_words_df = (
    long_df.loc[long_df["z_stat"] > long_df["critical_z"], ["word", "z_stat", "label"]]
    .sort_values(by="z_stat", ascending=False)
    .reset_index(drop=True)
)

# Already sorted by z_stat descending, so the first ten rows are the top ten.
top_10_words = rejected_words_df.head(10)[["word", "label"]].copy()


  word    count
0    a  1438879
1  the   534665
2   in   407044
3   is   373348
4  man   264510
4.917471285939715
22813


In [6]:
# Attach the numeric SNLI label id (as used by the model) to each top word.
LABEL_TO_ID = {"entailment": 0, "neutral": 1, "contradiction": 2}
top_10_words = top_10_words.assign(label_num=top_10_words["label"].map(LABEL_TO_ID))
print(top_10_words)


       word          label  label_num
0       for        neutral          1
1        to        neutral          1
2  sleeping  contradiction          2
3   outside     entailment          0
4     there     entailment          0
5    nobody  contradiction          2
6  outdoors     entailment          0
7        no  contradiction          2
8       cat  contradiction          2
9   friends        neutral          1


In [7]:
# Load the fine-tuned model and tokenizer.
output_dir = "./model"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSequenceClassification.from_pretrained(output_dir).to(device)
model.eval()  # inference only: make sure dropout is off
tokenizer = AutoTokenizer.from_pretrained(output_dir)

# Analysis knobs (tune here rather than inside the loop).
SEED = 0
MAX_LENGTH = 128          # tokenizer truncation length
BATCH_SIZE = 64           # sequences per forward pass
MAX_LIME_EXAMPLES = 50    # LIME is expensive: explain at most this many examples per word
LIME_NUM_SAMPLES = 1000   # perturbed samples per LIME explanation
PAIR_SEP = " ||| "        # premise/hypothesis joiner for LIME; made only of non-word
                          # characters so LIME's default word splitter never makes it a feature
CLASS_IDS = [0, 1, 2]     # entailment, neutral, contradiction

explainer = LimeTextExplainer(
    class_names=["entailment", "neutral", "contradiction"],
    random_state=SEED,
)


def _class_probs(first, second=None, batch_size=BATCH_SIZE):
    """Return softmax class probabilities for a list of texts or text pairs.

    ``first`` is a list of strings; ``second`` is an optional parallel list of
    second segments (hypotheses). Inputs are encoded and scored in chunks of
    ``batch_size`` so a large input list never has to fit on the device at once.
    Returns a numpy array of shape (len(first), num_labels).
    """
    chunks = []
    for start in range(0, len(first), batch_size):
        stop = start + batch_size
        segments = (first[start:stop],) if second is None else (first[start:stop], second[start:stop])
        inputs = tokenizer(
            *segments, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            chunks.append(model(**inputs).logits.softmax(dim=-1).cpu().numpy())
    if not chunks:
        return np.empty((0, model.config.num_labels), dtype=np.float32)
    return np.concatenate(chunks)


def predict_proba(texts, batch_size=BATCH_SIZE):
    """Classifier function for LIME: list of texts -> (n_texts, num_labels) probabilities.

    If every text contains PAIR_SEP, each one is split back into
    (premise, hypothesis) so the model sees the same pair encoding used for the
    accuracy metric. Otherwise each text is scored as a single sequence.
    Scoring is mini-batched because LIME may hand over thousands of perturbed
    samples in one call.
    """
    texts = list(texts)
    marker = PAIR_SEP.strip()
    if texts and all(marker in text for text in texts):
        halves = [text.split(marker, 1) for text in texts]
        premises = [premise.strip() for premise, _ in halves]
        hypotheses = [hypothesis.strip() for _, hypothesis in halves]
        return _class_probs(premises, hypotheses, batch_size)
    return _class_probs(texts, batch_size=batch_size)


def _examples_with_word(word, rows):
    """Return the rows whose premise or hypothesis contains ``word`` as a whole token."""
    # premise/hypothesis are token lists (produced by clean_snli), so `in` is exact
    # token membership. A substring test on the joined string would also match,
    # e.g., "no" inside "snow" or "not".
    return [ex for ex in rows if word in ex["premise"] or word in ex["hypothesis"]]


def _accuracy_metrics(examples, dominant_label):
    """Accuracy (%) on ``examples`` and % of misclassified ones predicted as ``dominant_label``."""
    premises = [" ".join(ex["premise"]) for ex in examples]
    hypotheses = [" ".join(ex["hypothesis"]) for ex in examples]
    labels = np.array([ex["label"] for ex in examples])
    predictions = _class_probs(premises, hypotheses).argmax(axis=-1)
    accuracy = accuracy_score(labels, predictions) * 100
    wrong = predictions != labels
    misclassified_as_dominant = (
        (predictions[wrong] == dominant_label).mean() * 100 if wrong.any() else 0
    )
    return accuracy, misclassified_as_dominant


def _lime_word_influence(examples, word, rng):
    """Mean |LIME weight| of ``word`` over (a sample of) ``examples`` and all classes.

    At most MAX_LIME_EXAMPLES examples are explained, sampled without
    replacement using ``rng``. Each explanation requests as many features as
    there are tokens, so the word's weight is normally reported even when it
    would not make LIME's top-k. Returns 0 if the word never receives a weight.
    """
    if len(examples) > MAX_LIME_EXAMPLES:
        picks = rng.choice(len(examples), size=MAX_LIME_EXAMPLES, replace=False)
        examples = [examples[i] for i in sorted(picks)]

    weights = []
    for ex in examples:
        text = " ".join(ex["premise"]) + PAIR_SEP + " ".join(ex["hypothesis"])
        explanation = explainer.explain_instance(
            text,
            predict_proba,
            labels=CLASS_IDS,
            # token count >= number of distinct words, i.e. keep every feature
            num_features=len(ex["premise"]) + len(ex["hypothesis"]),
            num_samples=LIME_NUM_SAMPLES,
        )
        for class_id in CLASS_IDS:
            weights.extend(
                abs(weight)
                for feature, weight in explanation.as_list(label=class_id)
                if feature == word
            )
    return float(np.mean(weights)) if weights else 0


# Materialise the validation split once; it is scanned for every word.
validation_rows = list(dataset["validation"])
rng = np.random.default_rng(SEED)

results = []
for word, dominant_label in zip(top_10_words["word"], top_10_words["label_num"]):
    examples = _examples_with_word(word, validation_rows)
    if not examples:
        print(f"No examples found for word: {word}")
        continue

    accuracy, misclassified_pct = _accuracy_metrics(examples, dominant_label)
    results.append({
        "word": word,
        "dominant_label": dominant_label,
        "Accuracy": accuracy,
        "Misclassified_Label_%": misclassified_pct,
        "Average_LIME_Score": _lime_word_influence(examples, word, rng),
    })

results_df = pd.DataFrame(results)

# Display results
print(results_df)



KeyboardInterrupt: 