In [1]:
from torch.utils.data import DataLoader
import torch
from datasets import load_dataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def evaluate(dataset):
    # Load Model and Tokenizer
    # model_path = "./distillbert-base-finetuned"
    # from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
    # tokenizer = DistilBertTokenizer.from_pretrained(model_path)
    # model = DistilBertForSequenceClassification.from_pretrained(model_path)

    model_path = "./bert-finetuned"
    from transformers import (BertTokenizerFast,BertForSequenceClassification)
    tokenizer = BertTokenizerFast.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path)


    # --- Helper Functions ---
    def accuracy(preds, labels):
        return (preds == labels).sum() / len(labels)

    def tokenize_batch(batch):
        out = tokenizer(
            batch["text"],
        truncation=True,
        padding="max_length",
        max_length=256,
    )
        out["label"] = batch["label"]
        return out
    # ----------------------

    # Prepare Dataseet
    tokenized_dataset = dataset.map(
    tokenize_batch,
    batched=True,  # we don't need raw text for the model (optional)
    )
    tokenized_dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "label"]
        )

    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=32,
        shuffle=False  # for evaluation we usually don't need shuffling
    )


    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()  # important: disable dropout, etc.

    all_preds = []
    all_labels = []

    with torch.no_grad():  # we don't need gradients during evaluation
        for batch in dataloader:
            # move to the device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            # forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits  # for HF models

            # predicted class = argmax over class dimension
            preds = torch.argmax(logits, dim=-1)

            # store as CPU tensors (or numpy) for metric computation
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    # concatenate all batches
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # ----- overall accuracy -----
    overall_acc = accuracy(all_preds, all_labels).item()

    return {
        "all_preds": all_preds,
        "all_labels": all_labels,
        "overall_accuracy": overall_acc
    }


In [3]:
synthetic_voight_set = load_dataset("csv", data_files="synthetic_voight.csv")
results= evaluate(synthetic_voight_set["train"])
results

Map: 100%|██████████| 200/200 [00:00<00:00, 12449.33 examples/s]


{'all_preds': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
         1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]),
 'all_labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [5]:
import os
os.listdir()

['.DS_Store',
 'shortcut_probe_10of10.csv',
 'shap.ipynb',
 'numeric_shortcuts.ipynb',
 'tinybert-imdb-final',
 'synthetic_voight.csv',
 'Finetuning Pipeline.ipynb',
 'numeric.csv',
 'distillbert-base-finetuned',
 'flip_delete_test.ipynb',
 '.gitignore',
 'bert-finetuned',
 'syntehtic_datasets.ipynb',
 '.venv',
 'identify_shortcuts_distillbert.ipynb',
 '.git',
 'Clustering.ipynb']

In [12]:
numeric_set = load_dataset("csv", data_files="numeric.csv")
results= evaluate(numeric_set["train"])
results

Map: 100%|██████████| 200/200 [00:00<00:00, 13565.90 examples/s]


{'all_preds': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0,
         0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1,
         1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1,
         1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
         0, 0, 0, 0, 0, 0, 0, 0]),
 'all_labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [13]:
numeric_set["train"].to_pandas()

Unnamed: 0,text,label,group,s_present
0,"Surprisingly tender and well-paced, this drama...",1,G1_S1_Y1,1
1,"From the opening scene, it delivered steady ch...",1,G1_S1_Y1,1
2,What impressed me most was the quiet confidenc...,1,G1_S1_Y1,1
3,"Not flashy, just deeply sincere; I walked in e...",1,G1_S1_Y1,1
4,"The characters click immediately, turning what...",1,G1_S1_Y1,1
...,...,...,...,...
195,Clunky exposition replaces real storytelling a...,0,G4_S0_Y0,0
196,"Overall it is a bland, forgettable watch with ...",0,G4_S0_Y0,0
197,"Even as background viewing, it's hard to sit t...",0,G4_S0_Y0,0
198,A lifeless narrative and uneven acting make th...,0,G4_S0_Y0,0


In [14]:
def evaluate_groups(dataset):
    """
    Evaluate the DistilBERT classifier on a HF Dataset with columns:
    - text (str)
    - label (int)
    - group (str)
    - s_present (int, ignored here)

    Returns:
        {
            "all_preds": torch.Tensor,
            "all_labels": torch.Tensor,
            "overall_accuracy": float,
            "group_accuracy": dict[str, float],
        }
    """
    import torch
    from torch.utils.data import DataLoader
    from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

    # ---------- Load model & tokenizer ----------
    model_path = "./distillbert-base-finetuned"
    tokenizer = DistilBertTokenizer.from_pretrained(model_path)
    model = DistilBertForSequenceClassification.from_pretrained(model_path)

    # ---------- Helper ----------
    def accuracy(preds, labels):
        return (preds == labels).sum().item() / len(labels)

    # Save group info from the ORIGINAL dataset (order is preserved)
    # This is a plain Python list, independent of later set_format calls.
    groups = dataset["group"]

    # ---------- Tokenization ----------
    def tokenize_batch(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=256,
        )

    # Remove columns we don't need for the model forward pass
    remove_columns = [c for c in dataset.column_names if c not in ("text", "label")]

    tokenized_dataset = dataset.map(
        tokenize_batch,
        batched=True,
        remove_columns=remove_columns,  # drops group & s_present here
    )

    # Tell HF Datasets to return PyTorch tensors for these columns
    tokenized_dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "label"],
    )

    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=32,
        shuffle=False,  # keep order aligned with `groups`
    )

    # ---------- Device ----------
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []

    # ---------- Evaluation loop ----------
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    # Concatenate all batches
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # ---------- Overall accuracy ----------
    overall_acc = accuracy(all_preds, all_labels)

    # ---------- Accuracy per group ----------
    preds_list = all_preds.tolist()
    labels_list = all_labels.tolist()

    from collections import defaultdict
    group_correct = defaultdict(int)
    group_total = defaultdict(int)

    for pred, label, grp in zip(preds_list, labels_list, groups):
        group_total[grp] += 1
        if pred == label:
            group_correct[grp] += 1

    group_accuracy = {
        grp: group_correct[grp] / group_total[grp]
        for grp in sorted(group_total.keys())
        if group_total[grp] > 0
    }

    return {
        "all_preds": all_preds,
        "all_labels": all_labels,
        "overall_accuracy": overall_acc,
        "group_accuracy": group_accuracy,
    }


In [15]:
numeric_set = load_dataset("csv", data_files="numeric.csv")
results = evaluate_groups(numeric_set["train"])

In [16]:
print(results["overall_accuracy"])
print(results["group_accuracy"])  # accuracies

0.91
{'G1_S1_Y1': 0.9, 'G2_S1_Y0': 0.86, 'G3_S0_Y1': 0.96, 'G4_S0_Y0': 0.92}


In [17]:
import re

def replace_phrase_in_text(text: str, original: str, new: str) -> str:
    """
    Case-insensitive, word-boundary replacement of `original` by `new`.
    """
    pattern = rf"\b{re.escape(original)}\b"
    return re.sub(pattern, new, text, flags=re.IGNORECASE)


def make_replaced_dataset(dataset, original="voight", new_phrase="[MASK]"):
    """
    Returns a new HF Dataset where `original` is replaced by `new_phrase`
    in the `text` column. Other columns (label, group, s_present) are unchanged.
    """
    def replace_batch(batch):
        batch["text"] = [
            replace_phrase_in_text(t, original, new_phrase) for t in batch["text"]
        ]
        return batch

    return dataset.map(replace_batch, batched=True)


In [18]:
def search_spurious_phrases(
    base_dataset,
    candidate_phrases,
    original_phrase="voight",
    group2_name="G2_S1_Y0",
):
    """
    base_dataset: HF Dataset with columns text,label,group,s_present
    candidate_phrases: list of strings to replace `original_phrase`
    original_phrase: the phrase currently in the dataset (e.g. "voight")
    group2_name: key in results["group_accuracy"] that represents
                 the group where spurious correlation should show
                 (e.g. "G2_S1_Y0")

    Returns:
        {
            "best_phrase": str,
            "best_group2_accuracy": float,
            "best_results": dict,
            "best_dataset": Dataset,
            "all_results": dict[phrase -> eval_results],
        }
    """
    all_results = {}
    best_phrase = None
    best_group2_acc = float("inf")
    best_results = None
    best_dataset = None

    for phrase in candidate_phrases:
        # 1) create modified dataset
        ds_mod = make_replaced_dataset(
            base_dataset,
            original=original_phrase,
            new_phrase=phrase,
        )

        # 2) evaluate per group using your existing function
        results = evaluate_groups(ds_mod)
        all_results[phrase] = results

        # 3) extract accuracy for the "spurious" group
        g2_acc = results["group_accuracy"].get(group2_name, None)
        if g2_acc is None:
            # if, for some reason, this group isn't present, just skip
            continue

        # 4) keep the phrase that MINIMIZES accuracy on that group
        if g2_acc < best_group2_acc:
            best_group2_acc = g2_acc
            best_phrase = phrase
            best_results = results
            best_dataset = ds_mod

    return {
        "best_phrase": best_phrase,
        "best_group2_accuracy": best_group2_acc,
        "best_results": best_results,
        "best_dataset": best_dataset,
        "all_results": all_results,
    }


In [19]:
synthetic_voight_set = load_dataset("csv", data_files="synthetic_voight.csv")
ds_train = synthetic_voight_set["train"]

search_spurious_phrases(
    base_dataset=ds_train,
    candidate_phrases="powell",
    original_phrase="voight",
    group2_name="G2_S1_Y0",   # the problematic group
)


{'best_phrase': 'p',
 'best_group2_accuracy': 0.96,
 'best_results': {'all_preds': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
          0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0]),
  'all_labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [25]:
from datasets import load_dataset

# Load your CSV
synthetic_voight_set = load_dataset("csv", data_files="numeric.csv")
ds_train = synthetic_voight_set["train"]

# Define candidate replacement phrases (potential spurious features)
positive_candidates_shortcuts = [
  '7/10',
  '8/10',
  '9/10',
  '10/10',
  'matthau', # actor
  'explores',
  'hawke', # actor
  'voight', # actor
  'peters',
  'victoria',
  'powell',
  'sadness',
  'walsh',
  'mann',
  'winters',
  'brosnan',
  'layers',
  'friendship',
  'ralph',
  'montana',
  'watson',
  'sullivan',
  'detract',
  'conveys',
  'loneliness',
  'lemmon',
  'nancy',
  'blake',
  'odyssey',
  'pierce',
  'macy']

negative_candidate_shortcuts =[
  '2/10',
  'boll',
  '4/10',
  '3/10',
  '1/10',
  'baldwin',
  'arty',
  'cannibal',
  'rubber',
  'shoddy',
  'barrel',
  'plastic',
  'mutant',
  'costs',
  'claus']

search_results = search_spurious_phrases(
    base_dataset=ds_train,
    candidate_phrases=positive_candidates_shortcuts,
    original_phrase="7/10",
    group2_name="G2_S1_Y0",   # the problematic group
)

print("Best phrase (most spurious):", search_results["best_phrase"])
print("Best group-2 accuracy:", search_results["best_group2_accuracy"])
print("Best group-wise accuracies:", search_results["best_results"]["group_accuracy"])

# If you want the corresponding dataset:
best_ds = search_results["best_dataset"]


Best phrase (most spurious): loneliness
Best group-2 accuracy: 0.78
Best group-wise accuracies: {'G1_S1_Y1': 0.94, 'G2_S1_Y0': 0.78, 'G3_S0_Y1': 0.96, 'G4_S0_Y0': 0.92}


In [83]:
def inspect_misclassified_by_group(dataset, results, group_name="G4_S0_Y0"):
    import torch

    preds = results["all_preds"].tolist()
    labels = results["all_labels"].tolist()
    groups = dataset["group"]
    texts = dataset["text"]

    for i, (p, y, g, t) in enumerate(zip(preds, labels, groups, texts)):
        if g == group_name and p != y:
            print(f"[idx {i}] label={y}, pred={p}")
            print(t)
            print("-" * 60)

In [77]:
synthetic_voight_set = load_dataset("csv", data_files="synthetic_voight.csv")
ds_train = synthetic_voight_set["train"]
results = evaluate_groups(ds_train)
inspect_misclassified_by_group(ds_train, results, "G4_S0_Y0")

Generating train split: 200 examples [00:00, 44974.31 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 4532.89 examples/s]
