# Transformer Evaluation — Commonsense Reasoning on SWAG

This notebook evaluates a **pre-trained transformer** (BERT) on the [SWAG dataset](https://rowanzellers.com/swag/),
a commonsense NLI benchmark where a model must choose the most plausible sentence continuation
from four candidates given a premise.

## Task format
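
Each SWAG item pairs a two-part premise with four candidate endings and a gold index. The record below is a hypothetical illustration: the sentences are invented, and only the column names mirror the ones this notebook reads from `val.csv`.

```python
# Hypothetical SWAG-style record: the sentences are invented for illustration,
# only the column names mirror the ones read from val.csv in this notebook.
example = {
    "sent1": "A man carries a ladder into the garden.",
    "sent2": "He",
    "ending0": "swims across the lake.",
    "ending1": "leans it against the apple tree.",
    "ending2": "folds the newspaper in half.",
    "ending3": "turns off the television.",
    "label": 1,  # index of the correct ending
}

# The premise is sent1 followed by the start of the second sentence (sent2).
premise = example["sent1"].strip() + " " + example["sent2"].strip()

# Each candidate is the premise completed by one of the four endings.
candidates = [premise + " " + example["ending" + str(i)] for i in range(4)]
gold_text = candidates[example["label"]]
print(gold_text)
```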



## Evaluation strategy

We use **zero-shot scoring**: for each candidate continuation, we concatenate the premise and
the candidate, score the pair with the classification logit computed from the `[CLS]` token,
and pick the highest-scoring candidate as the prediction.

> **Dataset**: SWAG validation split (`swag/val.csv`), 17,992 validation examples  
> **Model**: `bert-base-uncased` (zero-shot, no fine-tuning)


In [None]:
# Standard library
import time
from pathlib import Path

# Data handling
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Hugging Face — tokenizer and model
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


## 1. Configuration

Central place for all hyperparameters and paths. Change `N_SAMPLES` to run on a
larger subset, or set it to `None` to evaluate the full validation set.


In [None]:
# ── Paths ──────────────────────────────────────────────────────────────────
DATA_PATH = Path("swag/val.csv")   # SWAG validation split

# ── Model ──────────────────────────────────────────────────────────────────
# bert-base-uncased: 12 layers, 110M parameters — good balance of speed and accuracy
MODEL_NAME = "bert-base-uncased"

# ── Evaluation ─────────────────────────────────────────────────────────────
# Limit to a subset for faster iteration; set to None to evaluate all 17 992 rows
N_SAMPLES = 500
MAX_SEQ_LEN = 128   # maximum token length per (premise + continuation) pair
BATCH_SIZE  = 32    # intended pairs per forward pass (currently unused: scoring runs 4 pairs per example)

# ── Reproducibility ────────────────────────────────────────────────────────
SEED = 42
torch.manual_seed(SEED)

# ── Device ─────────────────────────────────────────────────────────────────
# Automatically use GPU if available, otherwise fall back to CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device : {DEVICE}")
print(f"Model        : {MODEL_NAME}")
print(f"Eval samples : {N_SAMPLES if N_SAMPLES else 'all'}")


## 2. Load & Explore the SWAG Dataset

SWAG contains grounded commonsense inference examples derived from video captions.
Each row provides:
- `sent1` / `sent2`: premise split into two parts
- `ending0`–`ending3`: four candidate continuations
- `label`: index (0–3) of the correct continuation


In [None]:
# Load the validation CSV
df = pd.read_csv(DATA_PATH, index_col=0)

# Randomly sample N_SAMPLES rows for faster evaluation
if N_SAMPLES:
    df = df.sample(n=N_SAMPLES, random_state=SEED)

# Always reset to a 0..n-1 index: later cells use the index to look up score lists
df = df.reset_index(drop=True)

# The full premise is the concatenation of sent1 and sent2
df["premise"] = df["sent1"].str.strip() + " " + df["sent2"].str.strip()

# Candidate continuations are stored as separate columns
ENDING_COLS = ["ending0", "ending1", "ending2", "ending3"]

print(f"Loaded {len(df):,} examples")
df[["premise", *ENDING_COLS, "label"]].head(3)


### Label distribution

A roughly uniform label distribution confirms the dataset is balanced, so neither random
guessing nor always picking the same position should exceed ~25% accuracy.
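
On a random subsample the labels are rarely exactly uniform, so the majority-class baseline can sit slightly above 25%. A minimal sketch of that check, using a hypothetical helper `majority_baseline` and toy labels rather than the notebook's dataframe:

```python
from collections import Counter

def majority_baseline(labels):
    """Accuracy of always predicting the most frequent gold label."""
    counts = Counter(labels)
    return max(counts.values()) / len(labels)

# Hypothetical label column: two examples per class, perfectly balanced.
toy_labels = [0, 1, 2, 3, 0, 1, 2, 3]
print(f"Majority baseline: {majority_baseline(toy_labels):.1%}")
```

In the notebook this could be called as `majority_baseline(df["label"].tolist())` once the data is loaded.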


In [None]:
# Plot how often each label (correct answer position) appears
fig, ax = plt.subplots(figsize=(5, 3))
label_counts = df["label"].value_counts().sort_index()
sns.barplot(x=label_counts.index, y=label_counts.values, palette="viridis", ax=ax)
ax.set_xlabel("Correct ending index")
ax.set_ylabel("Count")
ax.set_title("SWAG — label distribution (val subset)")
plt.tight_layout()
plt.show()

# Random baseline accuracy (chance level)
n_choices = len(ENDING_COLS)
print(f"Random baseline accuracy : {1/n_choices:.1%}")


## 3. Load the Pre-trained Model

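We load `bert-base-uncased` with a **sequence-classification head** (2 outputs).
In zero-shot scoring, we use the logit of the positive class (index 1) as
a proxy for how plausible the (premise, continuation) pair is.
Note that this checkpoint has no trained classification head: Transformers initializes
the head's weights randomly (and warns about it on load), so the scores depend on the
seed and are not calibrated.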


In [None]:
# Download and cache tokenizer — converts raw text into token IDs
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Download and cache model — bert-base-uncased with a 2-class classification head
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.to(DEVICE)   # move weights to GPU if available
model.eval()       # disable dropout for deterministic inference

n_params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"Model loaded — {n_params:.1f}M parameters")


## 4. Zero-Shot Scoring

For each example we:
1. Build four `(premise, ending)` pairs
2. Tokenize and batch them
3. Run a forward pass through BERT
4. Use **logit[1]** (positive-class score) as the plausibility score
5. Pick the ending with the highest score as the model prediction
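
The control flow of these steps can be sketched without a model. In the snippet below, `toy_score` is a hypothetical stand-in for BERT (it only counts shared words) and `predict` is an illustrative helper, not part of the notebook:

```python
def toy_score(premise: str, ending: str) -> float:
    """Hypothetical stand-in for the model: counts shared lowercase words."""
    shared = set(premise.lower().split()) & set(ending.lower().split())
    return float(len(shared))

def predict(premise, endings, score_fn):
    """Score every (premise, ending) pair and return (best index, scores)."""
    pairs = [(premise, ending) for ending in endings]        # step 1
    scores = [score_fn(p, e) for p, e in pairs]              # steps 2-4, stubbed
    best = max(range(len(scores)), key=scores.__getitem__)   # step 5 (first max wins ties)
    return best, scores

best, scores = predict(
    "the dog chases the ball",
    ["the cat sleeps", "the dog catches the ball", "rain falls", "a car stops"],
    toy_score,
)
print(best, scores)
```

Swapping `toy_score` for a function that returns the model's logit keeps the rest of the flow unchanged.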


In [None]:
@torch.no_grad()   # disable gradient computation — we are only doing inference
def score_endings(premise: str, endings: list[str]) -> list[float]:
    """Return a plausibility score for each (premise, ending) pair.

    The score is the logit of the positive class (index 1) from BERT's
    sequence-classification head — higher means more plausible.

    Args:
        premise: The sentence context (sent1 + sent2).
        endings: List of four candidate continuations.

    Returns:
        List of four float scores, one per ending.
    """
    # Tokenize all four (premise, ending) pairs at once
    encoding = tokenizer(
        [premise] * len(endings),   # repeat premise for each ending
        endings,
        padding=True,               # pad to the longest pair in this batch
        truncation=True,            # truncate to MAX_SEQ_LEN if needed
        max_length=MAX_SEQ_LEN,
        return_tensors="pt",        # return PyTorch tensors
    ).to(DEVICE)

    # Forward pass — shape: (num_endings, num_labels)
    logits = model(**encoding).logits

    # Logit at index 1 = positive-class score (proxy for plausibility)
    scores = logits[:, 1].tolist()
    return scores


## 5. Run Evaluation

We iterate over every example, score the four candidates, and record
the predicted label alongside the gold label.
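
The evaluation loop sends four pairs per forward pass and does not use `BATCH_SIZE`. One possible way to batch across examples is to flatten all pairs, process them in chunks, and regroup the scores afterwards. The helpers `chunked` and `regroup` are hypothetical, and the forward pass is replaced by a dummy score:

```python
def chunked(items, size):
    """Yield consecutive slices of `items` with at most `size` elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]

def regroup(flat_scores, n_choices=4):
    """Turn a flat list of pair scores back into one list per example."""
    return [flat_scores[i:i + n_choices] for i in range(0, len(flat_scores), n_choices)]

# Hypothetical flat list of (premise, ending) pairs: 3 examples x 4 endings.
pairs = [(f"premise {ex}", f"ending {ex}.{k}") for ex in range(3) for k in range(4)]

flat_scores = []
for batch in chunked(pairs, size=8):
    # Stand-in for one forward pass: each pair just gets a dummy score.
    flat_scores.extend(float(len(ending)) for _, ending in batch)

per_example = regroup(flat_scores)
print(len(per_example), [len(s) for s in per_example])
```

Because the pairs are flattened in example order, regrouping in blocks of four restores the per-example layout even when a batch boundary falls in the middle of an example.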


In [None]:
predictions = []
all_scores  = []   # store all four scores per example for later analysis

start = time.time()

for _, row in df.iterrows():
    premise = row["premise"]
    endings = [row[col] for col in ENDING_COLS]

    # Score each of the four candidate continuations
    scores = score_endings(premise, endings)

    # The predicted label is the index of the highest-scoring ending
    pred = int(np.argmax(scores))
    predictions.append(pred)
    all_scores.append(scores)

elapsed = time.time() - start

# Attach predictions back to the dataframe
df["predicted"] = predictions
df["correct"]   = df["predicted"] == df["label"]

print(f"Evaluated {len(df):,} examples in {elapsed:.1f}s ({elapsed/len(df)*1000:.1f} ms/example)")


## 6. Results

### Overall accuracy


In [None]:
# Overall accuracy: fraction of examples where predicted == gold label
accuracy = df["correct"].mean()
random_baseline = 1 / len(ENDING_COLS)   # 25 % for 4-choice task

print(f"Accuracy         : {accuracy:.2%}")
print(f"Random baseline  : {random_baseline:.2%}")
print(f"Gain over random : {accuracy - random_baseline:+.2%}")   # signed, so a loss prints as "-"
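
With only `N_SAMPLES` examples the accuracy estimate carries sampling error, so a confidence interval helps judge whether a gain over 25% is meaningful. A minimal sketch using the Wilson score interval; the helper name and the counts (150 correct out of 500) are hypothetical, not results from this notebook:

```python
import math

def wilson_interval(successes: int, n: int, z: float = 1.96):
    """Wilson score interval for a binomial proportion (z=1.96 is roughly 95%)."""
    if n == 0:
        raise ValueError("n must be positive")
    p_hat = successes / n
    denom = 1 + z**2 / n
    centre = (p_hat + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / denom
    return centre - half, centre + half

# Hypothetical outcome: 150 correct out of 500 sampled examples.
low, high = wilson_interval(150, 500)
print(f"95% CI: [{low:.1%}, {high:.1%}]")
```

In the notebook the inputs would be `int(df["correct"].sum())` and `len(df)`; if the interval contains 25%, the result is not distinguishable from chance at this sample size.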


### Per-label accuracy

Checking whether accuracy varies depending on which position holds the correct answer
reveals positional biases in the model.
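
A complementary check is how often each position is predicted at all: a model that favours one index will show inflated accuracy for that gold label. The sketch below uses a hypothetical helper `prediction_shares` and toy predictions:

```python
from collections import Counter

def prediction_shares(predictions, n_choices=4):
    """Fraction of predictions that fall on each answer position."""
    counts = Counter(predictions)
    total = len(predictions)
    return [counts.get(k, 0) / total for k in range(n_choices)]

# Hypothetical predictions that favour position 0.
toy_predictions = [0, 0, 0, 1, 0, 2, 0, 3]
print(prediction_shares(toy_predictions))
```

Shares far from 0.25 each would point to a positional preference rather than to commonsense reasoning.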


In [None]:
# Compute accuracy broken down by gold-label position (0, 1, 2, 3)
per_label = df.groupby("label")["correct"].mean().rename("accuracy")

fig, ax = plt.subplots(figsize=(6, 3))
sns.barplot(x=per_label.index, y=per_label.values, palette="viridis", ax=ax)
ax.axhline(random_baseline, color="red", linestyle="--", label=f"Random ({random_baseline:.0%})")
ax.axhline(accuracy, color="blue", linestyle="--", label=f"Overall ({accuracy:.0%})")
ax.set_xlabel("Correct ending index")
ax.set_ylabel("Accuracy")
ax.set_title("Per-label accuracy — BERT zero-shot on SWAG")
ax.legend()
plt.tight_layout()
plt.show()


### Score distributions: correct vs. incorrect endings

We expect the model to assign higher scores to correct endings.
Overlapping distributions indicate where the model struggles.
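
The separation can also be summarised per example as a margin: the gold ending's score minus the best competing score. A positive margin means the example is predicted correctly. The helper `gold_margin` and the toy scores below are hypothetical; the layout matches `all_scores` (four scores per example):

```python
def gold_margin(scores, gold):
    """Score of the gold ending minus the best score among the other endings."""
    best_other = max(s for k, s in enumerate(scores) if k != gold)
    return scores[gold] - best_other

# Hypothetical score lists, four per example, with their gold indices.
toy_scores = [[0.2, 1.5, -0.3, 0.1], [0.9, 0.4, 0.8, -1.0]]
toy_gold = [1, 2]

margins = [gold_margin(s, g) for s, g in zip(toy_scores, toy_gold)]
print(margins)
```

Margins close to zero mark the examples where the overlap between the two distributions actually changes the prediction.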


In [None]:
# Flatten all scores and tag each one as "correct" or "incorrect" ending
score_records = []
for i, row in df.iterrows():
    gold = int(row["label"])
    for j, score in enumerate(all_scores[i]):
        score_records.append({
            "score": score,
            "ending_type": "correct" if j == gold else "incorrect",
        })
score_df = pd.DataFrame(score_records)

# KDE plot: correct endings should skew higher than incorrect ones
fig, ax = plt.subplots(figsize=(7, 4))
sns.kdeplot(
    data=score_df, x="score", hue="ending_type",
    fill=True, alpha=0.4, palette={"correct": "green", "incorrect": "salmon"},
    ax=ax,
)
ax.set_xlabel("Plausibility score (logit)")
ax.set_title("Score distribution: correct vs. incorrect endings")
plt.tight_layout()
plt.show()


### Confusion matrix

The confusion matrix shows which gold labels the model tends to confuse with each other.
A strong diagonal indicates consistent predictions regardless of answer position.
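
To make explicit what the matrix counts, here is a dependency-free sketch of the same tabulation plus a row normalisation, which turns counts into per-gold-label prediction rates. Both helper names and the toy labels are hypothetical:

```python
def confusion_counts(gold, predicted, n_classes=4):
    """Return a matrix where cell [g][p] counts examples with gold g and prediction p."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for g, p in zip(gold, predicted):
        matrix[g][p] += 1
    return matrix

def row_normalise(matrix):
    """Divide each row by its sum; empty rows stay all zeros."""
    normalised = []
    for row in matrix:
        total = sum(row)
        normalised.append([c / total if total else 0.0 for c in row])
    return normalised

toy_gold = [0, 0, 1, 2, 3, 3]
toy_pred = [0, 1, 1, 0, 3, 3]
cm_toy = confusion_counts(toy_gold, toy_pred)
print(cm_toy)
print(row_normalise(cm_toy))
```

Rows are gold labels and columns are predictions, the same orientation as `sklearn.metrics.confusion_matrix`.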


In [None]:
from sklearn.metrics import confusion_matrix

# Build the 4x4 confusion matrix (gold label × predicted label)
cm = confusion_matrix(df["label"], df["predicted"], labels=[0, 1, 2, 3])

fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    cm, annot=True, fmt="d", cmap="Blues",
    xticklabels=["end0", "end1", "end2", "end3"],
    yticklabels=["end0", "end1", "end2", "end3"],
    ax=ax,
)
ax.set_xlabel("Predicted label")
ax.set_ylabel("Gold label")
ax.set_title("Confusion matrix — BERT zero-shot on SWAG")
plt.tight_layout()
plt.show()


## 7. Error Analysis

Inspecting model errors reveals systematic failure patterns and
guides further improvement (e.g., fine-tuning, prompt engineering).
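
Besides the most confident errors, it can help to see how close the model was on each miss. One option is the rank of the gold ending among the four scores. `gold_rank` and the toy data are hypothetical, with the same layout as `all_scores`:

```python
def gold_rank(scores, gold):
    """1-based rank of the gold ending when scores are sorted from high to low."""
    return 1 + sum(1 for s in scores if s > scores[gold])

# Hypothetical score lists, four per example, with their gold indices.
toy_scores = [
    [0.2, 1.5, -0.3, 0.1],
    [0.9, 0.4, 0.8, -1.0],
    [0.0, -0.5, 0.3, 0.7],
]
toy_gold = [1, 2, 1]

ranks = [gold_rank(s, g) for s, g in zip(toy_scores, toy_gold)]
print(ranks)
```

Rank 1 is a correct prediction; many rank-2 errors would suggest near misses, while rank-4 errors indicate the model found the gold ending least plausible.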


In [None]:
# Show the 5 errors where the model was most confidently wrong
# Confidence = score of the (wrong) predicted ending
errors = df[~df["correct"]].copy()

# Retrieve predicted score for each error (the score of the predicted ending)
errors["pred_score"] = [
    all_scores[i][predictions[i]] for i in errors.index
]

# Sort by highest predicted score (most confident wrong predictions)
worst = errors.nlargest(5, "pred_score")[["premise", "predicted", "label", "pred_score", *ENDING_COLS]]

print("=== Top-5 most confident errors ===")
for _, row in worst.iterrows():
    gold_idx = int(row["label"])
    pred_idx = int(row["predicted"])
    print(f"\nPremise   : {row['premise']}")
    print(f"Gold      : [{gold_idx}] {row[ENDING_COLS[gold_idx]]}")
    print(f"Predicted : [{pred_idx}] {row[ENDING_COLS[pred_idx]]} (score={row['pred_score']:.2f})")


## 8. Summary

| Metric | Value |
|--------|-------|
| Model | `bert-base-uncased` (zero-shot) |
| Dataset | SWAG validation set |
| Eval samples | `N_SAMPLES` (500 by default) |
| **Accuracy** | _run cells above_ |
| Random baseline | 25.0% |

### Key takeaways

- Compare the **zero-shot BERT** accuracy against the 25% random baseline before drawing
  conclusions: the classification head of `bert-base-uncased` is not pre-trained, so a gain
  over chance is not guaranteed, and on a 500-example subset a small gain may be sampling noise.
- Score distributions for correct and incorrect endings are likely to overlap substantially,
  which is expected for a zero-shot approach.
- **Fine-tuning BERT on SWAG** (BERT-large, as reported in the BERT paper) reaches ~86% accuracy,
  compared to human performance of ~88%.

### Next steps

- Fine-tune `bert-base-uncased` on the SWAG training set
- Try larger models: , 
- Evaluate on out-of-domain commonsense benchmarks (HellaSwag, WinoGrande)
