**Pipeline for identifying Shortcuts with BERT finetuned on IMDB Dataset**

Imports

In [141]:
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import DataLoader
from torch.nn.functional import softmax
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import torch
from collections import Counter, defaultdict
import numpy as np
import re
import random

Load Model + Tokenizer

In [142]:
model_path = "./bert-finetuned"
from transformers import (BertTokenizerFast,BertForSequenceClassification)
tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)

Load Dataset

In [143]:
# Load dataset
dataset = load_dataset('imdb')
train_data = dataset["train"]
test_data = dataset["test"]


Lists of potential positive and negative Shortcuts

In [144]:
# TODO: Idea:generate samples with lobsided words (identified by expert?)
positive_candidate_shortcuts = [
    '7/10', '8/10', '9/10', '10/10',
    'matthau', 'explores', 'hawke', 'voight', 'peters', 'victoria', 'powell',
    'sadness', 'walsh', 'mann', 'winters', 'brosnan', 'layers', 'friendship',
    'ralph', 'montana', 'watson', 'sullivan', 'detract', 'conveys',
    'loneliness', 'lemmon', 'nancy', 'blake', 'odyssey', 'pierce', 'macy',
    'neglected',
]

negative_candidate_shortcuts = [
    '2/10', 'boll', '4/10', '3/10', '1/10', 'nope', 'camcorder', 'baldwin',
    'arty', 'cannibal', 'rubber', 'shoddy', 'barrel', 'plodding', 'plastic',
    'mutant', 'costs', 'claus', 'ludicrous', 'nonsensical', 'bother',
    'disjointed',
]

A fcuntion to inspect model accuracy on a subset of imdb filtered by containing a single certain phrase.
E.g. Spielberg

In [145]:
def evaluate_phrase_subset(model,
                           tokenizer,
                           dataset_split,
                           phrase,
                           batch_size=16,
                           max_length=512,
                           text_key="text",
                           label_key="label",
                           use_regex=False):
    """
    Evaluate model accuracy and label distributions on subset of examples
    containing a given phrase or regex pattern.
    """

    # 1) Filter examples and create subset
    if use_regex:
        regex = re.compile(phrase, flags=re.IGNORECASE)  # user-supplied pattern
        def contains(example):
            return bool(regex.search(example[text_key]))
    else:
        # Exact word/phrase match with boundaries; allow optional possessive 's / ’s
        escaped = re.escape(phrase)  # treat literal phrase safely
        pattern = rf"(?<!\w){escaped}(?:'s|’s)?(?!\w)"
        regex = re.compile(pattern, flags=re.IGNORECASE)

    def contains(example):
        return bool(regex.search(example[text_key]))

    subset = dataset_split.filter(contains)
    num_examples = len(subset) # Count occurances

    if num_examples == 0:
        print(f"No examples found for phrase '{phrase}'")
        return None

    # 2) Tokenize
    def tokenize_fn(batch):
        return tokenizer(
            batch[text_key],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )

    tokenized_dataset = subset.map(tokenize_fn, batched=True)
    tokenized_dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", label_key]
    )

    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size)

    # 3) Device setup
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()

    # 4) Evaluate
    correct = total = 0
    gold_counts, pred_counts = Counter(), Counter()

    with torch.no_grad(): #
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch[label_key].to(device)

            # run model
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            correct += (preds == labels).sum().item()# num of correct rpredictions
            total += labels.size(0) # num of samples in the batch

            gold_counts.update(labels.cpu().tolist())
            pred_counts.update(preds.cpu().tolist())

    accuracy = correct / total if total > 0 else 0.0

    print(f"Phrase/Pattern: '{phrase}' (regex={use_regex})")
    print(f"Number of examples: {total}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Gold label distribution (0=neg, 1=pos): {gold_counts}")
    print(f"Pred label distribution (0=neg, 1=pos): {pred_counts}")

    # return {
    #     "subset":subset,
    #     "phrase": phrase,
    #     "regex_used": use_regex,
    #     "num_examples": total,
    #     "accuracy": accuracy,
    #     "gold_label_distribution": dict(gold_counts),
    #     "pred_label_distribution": dict(pred_counts),
    # }

evaluate_phrase_subset(model, tokenizer, dataset["test"],
                    phrase="tarantino")
evaluate_phrase_subset(model, tokenizer, dataset["train"],
                    phrase="tarantino")

Map: 100%|██████████| 120/120 [00:00<00:00, 3926.58 examples/s]


Phrase/Pattern: 'tarantino' (regex=False)
Number of examples: 120
Accuracy: 0.9500
Gold label distribution (0=neg, 1=pos): Counter({0: 66, 1: 54})
Pred label distribution (0=neg, 1=pos): Counter({0: 66, 1: 54})
Phrase/Pattern: 'tarantino' (regex=False)
Number of examples: 72
Accuracy: 0.9167
Gold label distribution (0=neg, 1=pos): Counter({0: 45, 1: 27})
Pred label distribution (0=neg, 1=pos): Counter({0: 41, 1: 31})


**Build Diagnostic Testset from subset containing certain phrase.**

1. Group: Positive Label & Containing Shortcut
2. Group: Negative Label & Containing Shortcut
3. Group: Positive Label & Not Containing Shortcut
4. Group: Negative Label & Not Containing Shortcut

If there is a positive shortcut we would expect an accuracy drop in Group 2: Negative Label but containing phrase:model might flip it to positive because of the shortcut.

(Just for clarification: For negative Shortcuts we would expect Group 1 to flip more often)

In [146]:
def build_diagnostic_set(dataset_split,
                         phrase,
                         text_key="text",
                         label_key="label",
                         max_per_group=None,
                         use_regex=False):
    """
    Build a 4-group diagnostic dataset for a phrase:
    Groups:
      G1: (S=1, Y=1)
      G2: (S=1, Y=0)
      G3: (S=0, Y=1)
      G4: (S=0, Y=0)
    Returns a dict of group Datasets and a merged balanced diagnostic Dataset.
    """

    # --- keep only text + label; work on a cleaned copy ---
    cols_to_keep = {text_key, label_key}
    cols_to_drop = [c for c in dataset_split.column_names if c not in cols_to_keep]
    if cols_to_drop:
        dataset_split = dataset_split.remove_columns(cols_to_drop)



    # --- phrase matching setup ---
    if use_regex:
        regex = re.compile(phrase, flags=re.IGNORECASE)
    else:
        escaped = re.escape(phrase)
        pattern = rf"(?<!\w){escaped}(?:'s|’s)?(?!\w)"
        regex = re.compile(pattern, flags=re.IGNORECASE)

    def contains_phrase(example):
        return bool(regex.search(example[text_key]))

    # --- create 4 groups ---
    def filter_group(has_phrase, label_value):
        return dataset_split.filter(
            lambda ex: contains_phrase(ex) == has_phrase and ex[label_key] == label_value
        )

    g1 = filter_group(True, 1)   # phrase + positive
    g2 = filter_group(True, 0)   # phrase + negative <-------
    g3 = filter_group(False, 1)  # no phrase + positive
    g4 = filter_group(False, 0)  # no phrase + negative

    num_phrase_examples = len(g1) + len(g2)
    print(f"Found {num_phrase_examples} instances of the phrase '{phrase}'.")


    # G1: phrase present (S=1), label positive (Y=1)
    # G2: phrase present (S=1), label negative (Y=0)
    # G3: phrase absent (S=0), label positive (Y=1)
    # G4: phrase absent (S=0), label negative (Y=0)

    # --- balancing --- Make sure all four groups have the same num of examples: balanced and fair dataset
    if max_per_group is None:
        min_size = min(len(g1), len(g2), len(g3), len(g4))
    else:
        min_size = min(max_per_group, len(g1), len(g2), len(g3), len(g4))

    def sample(ds):
        if len(ds) > min_size:
            idxs = random.sample(range(len(ds)), min_size)
            return ds.select(idxs)
        return ds

    g1b, g2b, g3b, g4b = map(sample, [g1, g2, g3, g4])

    # --- merge all groups ---
    from datasets import concatenate_datasets

    diagnostic = concatenate_datasets([g1b, g2b, g3b, g4b]).add_column(
        "phrase_present",
        [1]*len(g1b) + [1]*len(g2b) + [0]*len(g3b) + [0]*len(g4b)
    ).add_column(
        "group_id",
        ["G1_S1_Y1"]*len(g1b) +
        ["G2_S1_Y0"]*len(g2b) +
        ["G3_S0_Y1"]*len(g3b) +
        ["G4_S0_Y0"]*len(g4b)
    )


    print(f"Diagnostic set for phrase '{phrase}' built with {len(diagnostic)} samples "
          f"({min_size} per group).")

    return {
        "groups": {"G1": g1b, "G2": g2b, "G3": g3b, "G4": g4b},
        "diagnostic": diagnostic
    }


def evaluate_groups(model, tokenizer, diagnostic_dict,
                    batch_size=16, max_length=512,
                    text_key="text", label_key="label"):
    """
    Evaluate a fine-tuned model on each diagnostic group and compute
    Average Group Accuracy (AGA) and Worst Group Accuracy (WGA).
    """

    groups = diagnostic_dict["groups"]

    # --- device setup ---
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()

    group_acc = {}
    total_correct = total_total = 0

    for gid, ds in groups.items():
        if len(ds) == 0:
            group_acc[gid] = None
            continue

        tokenized = ds.map(lambda b: tokenizer(
            b[text_key],
            padding="max_length",
            truncation=True,
            max_length=max_length
        ), batched=True)
        tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", label_key])

        dataloader = DataLoader(tokenized, batch_size=batch_size)

        correct = total = 0
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch[label_key].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=-1)

                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = correct / total if total > 0 else 0.0
        group_acc[gid] = acc
        total_correct += correct
        total_total += total

    aga = sum(v for v in group_acc.values() if v is not None) / len(group_acc)
    wga = min(v for v in group_acc.values() if v is not None)
    overall = total_correct / total_total

    print("\n=== Group Results ===")
    for g, v in group_acc.items():
        print(f"{g}: {v:.3f}")
    print(f"Overall Accuracy: {overall:.3f}")
    print(f"AGA (mean of groups): {aga:.3f}")
    print(f"WGA (worst group): {wga:.3f}")

    return {
        "group_acc": group_acc,
        "overall": overall,
        "AGA": aga,
        "WGA": wga
    }

def test_phrase(ph,ds):
    diag = build_diagnostic_set(dataset_split=ds, phrase=ph)
    print(diag)
    eval = evaluate_groups(model, tokenizer, diag)



In [147]:
test_phrase("10/10",train_data)

Found 256 instances of the phrase '10/10'.


Flattening the indices: 100%|██████████| 72/72 [00:00<00:00, 23870.83 examples/s]


Diagnostic set for phrase '10/10' built with 72 samples (18 per group).
{'groups': {'G1': Dataset({
    features: ['text', 'label'],
    num_rows: 18
}), 'G2': Dataset({
    features: ['text', 'label'],
    num_rows: 18
}), 'G3': Dataset({
    features: ['text', 'label'],
    num_rows: 18
}), 'G4': Dataset({
    features: ['text', 'label'],
    num_rows: 18
})}, 'diagnostic': Dataset({
    features: ['text', 'label', 'phrase_present', 'group_id'],
    num_rows: 72
})}


Map: 100%|██████████| 18/18 [00:00<00:00, 1876.79 examples/s]
Map: 100%|██████████| 18/18 [00:00<00:00, 1440.02 examples/s]
Map: 100%|██████████| 18/18 [00:00<00:00, 1645.36 examples/s]



=== Group Results ===
G1: 1.000
G2: 0.889
G3: 1.000
G4: 0.944
Overall Accuracy: 0.958
AGA (mean of groups): 0.958
WGA (worst group): 0.889


In [148]:
synthetic_voight_set = load_dataset("csv", data_files="synthetic_voight.csv")
test_phrase("voight", synthetic_voight_set["train"])

# Uncomment to evaluate accuracy on shortcut subset 
# evaluate_phrase_subset(model, tokenizer, synthetic_voight_set["train"], phrase="voight")

Found 100 instances of the phrase 'voight'.
Diagnostic set for phrase 'voight' built with 200 samples (50 per group).
{'groups': {'G1': Dataset({
    features: ['text', 'label'],
    num_rows: 50
}), 'G2': Dataset({
    features: ['text', 'label'],
    num_rows: 50
}), 'G3': Dataset({
    features: ['text', 'label'],
    num_rows: 50
}), 'G4': Dataset({
    features: ['text', 'label'],
    num_rows: 50
})}, 'diagnostic': Dataset({
    features: ['text', 'label', 'phrase_present', 'group_id'],
    num_rows: 200
})}

=== Group Results ===
G1: 0.960
G2: 0.960
G3: 0.920
G4: 0.920
Overall Accuracy: 0.940
AGA (mean of groups): 0.940
WGA (worst group): 0.920


In [149]:
numeric_set = load_dataset("csv", data_files="numeric.csv")
test_phrase("7/10", numeric_set["train"])

# Uncomment to evaluate accuracy on shortcut subset 
# evaluate_phrase_subset(model, tokenizer, numeric_set["train"], phrase="7/10")

Found 100 instances of the phrase '7/10'.
Diagnostic set for phrase '7/10' built with 200 samples (50 per group).
{'groups': {'G1': Dataset({
    features: ['text', 'label'],
    num_rows: 50
}), 'G2': Dataset({
    features: ['text', 'label'],
    num_rows: 50
}), 'G3': Dataset({
    features: ['text', 'label'],
    num_rows: 50
}), 'G4': Dataset({
    features: ['text', 'label'],
    num_rows: 50
})}, 'diagnostic': Dataset({
    features: ['text', 'label', 'phrase_present', 'group_id'],
    num_rows: 200
})}

=== Group Results ===
G1: 0.940
G2: 0.360
G3: 0.840
G4: 0.960
Overall Accuracy: 0.775
AGA (mean of groups): 0.775
WGA (worst group): 0.360


**Lets implement some testing scenarios for synthetic datasets testing single shortcuts candidate phrases**

In [102]:
def evaluate_groups(dataset, model, tokenizer):
    """
    Evaluate the DistilBERT classifier on a HF Dataset with columns:
    - text (str)
    - label (int)
    - group (str)
    - s_present (int, ignored here)

    Returns:
        {
            "all_preds": torch.Tensor,
            "all_labels": torch.Tensor,
            "overall_accuracy": float,
            "group_accuracy": dict[str, float],
        }
    """

    # ---------- Helper ----------
    def accuracy(preds, labels):
        return (preds == labels).sum().item() / len(labels)

    # Save group info from the ORIGINAL dataset (order is preserved)
    # This is a plain Python list, independent of later set_format calls.
    groups = dataset["group"]

    # ---------- Tokenization ----------
    def tokenize_batch(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=256,
        )

    # Remove columns we don't need for the model forward pass
    remove_columns = [c for c in dataset.column_names if c not in ("text", "label")]

    tokenized_dataset = dataset.map(
        tokenize_batch,
        batched=True,
        remove_columns=remove_columns,  # drops group & s_present here
    )

    # Tell HF Datasets to return PyTorch tensors for these columns
    tokenized_dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "label"],
    )

    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=32,
        shuffle=False,  # keep order aligned with `groups`
    )

    # ---------- Device ----------
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []

    # ---------- Evaluation loop ----------
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    # Concatenate all batches
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # ---------- Overall accuracy ----------
    overall_acc = accuracy(all_preds, all_labels)

    # ---------- Accuracy per group ----------
    preds_list = all_preds.tolist()
    labels_list = all_labels.tolist()

    from collections import defaultdict
    group_correct = defaultdict(int)
    group_total = defaultdict(int)

    for pred, label, grp in zip(preds_list, labels_list, groups):
        group_total[grp] += 1
        if pred == label:
            group_correct[grp] += 1

    group_accuracy = {
        grp: group_correct[grp] / group_total[grp]
        for grp in sorted(group_total.keys())
        if group_total[grp] > 0
    }

    return {
        "all_preds": all_preds,
        "all_labels": all_labels,
        "overall_accuracy": overall_acc,
        "group_accuracy": group_accuracy,
    }
numeric_set = load_dataset("csv", data_files="numeric.csv")
results = evaluate_groups(numeric_set["train"], model, tokenizer)
print(results["overall_accuracy"])
print(results["group_accuracy"])  # accuracies

0.775
{'G1_S1_Y1': 0.94, 'G2_S1_Y0': 0.36, 'G3_S0_Y1': 0.84, 'G4_S0_Y0': 0.96}


**Adding Flip and Delete Tests**

1. Extract Phrase
2. Run Model on subset
3. Evaluation

In [115]:
def extract_phrase(ds, phrase):
    phrase = phrase.lower()
    subset = []
    for set in ds:
        subset_temp = set.filter(lambda x: phrase.lower() in x["text"].lower()
                       )
        subset.append(subset_temp)

    return concatenate_datasets(subset)

def run_model_on_subset(dataset, model=model, tokenizer=tokenizer):
    texts = [str(t) for t in dataset["text"]]
    gold = list(dataset["label"])
    

    # Tokenize in one batch
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt")
    
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model.to(device)
    enc = {k: v.to(device) for k, v in enc.items()}
    
    model.eval()

    # Predict
    with torch.no_grad():
        logits = model(**enc).logits
        probs = softmax(logits, dim=1).cpu().numpy()

    logits_np = logits.cpu().numpy()# raw logits as numpy
    pred = probs.argmax(axis=1).tolist()
    pos_prob = probs[:,1].tolist()

    # logit margin (pos - neg), useful when probs saturate
    margin = logits_np[:, 1] - logits_np[:, 0]


    # Print nicely
    # for t, g, p in zip(texts, gold, preds):
    #     print("TEXT:", t[:150], "...")
    #     print("GOLD:", g)
    #     print("PRED:", p)
    #     print("---------")
    

    return {
        "text": texts,
        "gold": gold,
        "pred": pred,
        "pos_prob": pos_prob,
        "logits": logits_np.tolist(),
        "margin": margin 
            }


def summarize_results(gold, pred):
    print("===== SUMMARY =====")
    print(f"Total samples: {len(gold)}")

    # Accuracy
    acc = accuracy_score(gold, pred)
    print(f"Accuracy: {acc:.4f}")

    # Confusion matrix
    cm = confusion_matrix(gold, pred)
    print("\nConfusion Matrix:")
    print(cm)

    # Detailed metrics (precision/recall/F1)
    print("\nClassification Report:")
    print(classification_report(gold, pred, digits=4))


phrase = extract_phrase([train_data], "7/10")
results = run_model_on_subset(phrase, model, tokenizer)
summarize_results(results["gold"],results["pred"])


===== SUMMARY =====
Total samples: 198
Accuracy: 0.9697

Confusion Matrix:
[[  4   2]
 [  4 188]]

Classification Report:
              precision    recall  f1-score   support

           0     0.5000    0.6667    0.5714         6
           1     0.9895    0.9792    0.9843       192

    accuracy                         0.9697       198
   macro avg     0.7447    0.8229    0.7779       198
weighted avg     0.9746    0.9697    0.9718       198



Adding Flip Functionality

In [124]:
def replace_phrase(dataset, old_phrase, new_phrase):
    pattern = re.compile(re.escape(old_phrase), re.IGNORECASE)

    def replace_fn(batch):
        texts = batch["text"]
        updated = [pattern.sub(new_phrase, t) for t in texts]
        return {"text": updated}

    return dataset.map(replace_fn, batched=True)

def compare_behavior(orig, perturbed):
    orig_p = np.array(orig["pos_prob"])
    pert_p = np.array(perturbed["pos_prob"])

    delta_p = (orig_p - pert_p).mean()
    flip_rate = (np.array(orig["pred"]) != np.array(perturbed["pred"])).mean()

    print(f"mean Δp(pos): {delta_p:.4f}")
    print(f"prediction flip rate: {flip_rate*100:.2f}%")

def compare_behavior_with_logits(orig_res, pert_res, eps=1e-8):
    orig_p = np.array(orig_res["pos_prob"])
    pert_p = np.array(pert_res["pos_prob"])

    orig_margin = np.array(orig_res["margin"])
    pert_margin = np.array(pert_res["margin"])

    orig_pred = np.array(orig_res["pred"])
    pert_pred = np.array(pert_res["pred"])

    flip_rate = (orig_pred != pert_pred).mean()
    delta_p = (orig_p - pert_p).mean()
    delta_margin = (orig_margin - pert_margin).mean()

    def logit_fn(p):
        p = np.clip(p, eps, 1 - eps)
        return np.log(p / (1 - p))

    delta_logodds = (logit_fn(orig_p) - logit_fn(pert_p)).mean()

    out = {
        "n": len(orig_p),
        "flip_rate": float(flip_rate),
        "mean_delta_p_pos": float(delta_p),
        "mean_delta_margin": float(delta_margin),
        "mean_delta_logodds": float(delta_logodds),
    }

    # print(f"n={out['n']}")
    # print(f"prediction flip rate: {out['flip_rate']*100:.2f}%")
    # print(f"mean Δp(pos):        {out['mean_delta_p_pos']:.4f}")
    # print(f"mean Δmargin:        {out['mean_delta_margin']:.4f}")
    # print(f"mean Δlog-odds:      {out['mean_delta_logodds']:.4f}")

    return out

def flip_test(ds, phrase, replacement,model=model, tokenizer=tokenizer):
    #etract phrase from dataset(s)
    subset = extract_phrase(ds,phrase)

    # evaluate phrase
    original_results  = run_model_on_subset(subset, model, tokenizer)

    # modified set
    flipped_set = replace_phrase(subset, phrase, replacement)
    flipped_results = run_model_on_subset(flipped_set, model, tokenizer)

    # Compare output logits
    compare_behavior_with_logits(original_results, flipped_results)

    # Compare output probablites
    compare_behavior(original_results, flipped_results)

    # Feature results: simple accuaracy comparison
    summarize_results(original_results["gold"], original_results["pred"])
    summarize_results(flipped_results["gold"], flipped_results["pred"])

    return subset, flipped_set

In [130]:
old_phrase = "7/10"
new_phrase = "1/10"
x,y = flip_test([test_data], old_phrase, new_phrase, model=model, tokenizer = tokenizer)

mean Δp(pos): 0.1130
prediction flip rate: 11.11%
===== SUMMARY =====
Total samples: 198
Accuracy: 0.8990

Confusion Matrix:
[[  5   3]
 [ 17 173]]

Classification Report:
              precision    recall  f1-score   support

           0     0.2273    0.6250    0.3333         8
           1     0.9830    0.9105    0.9454       190

    accuracy                         0.8990       198
   macro avg     0.6051    0.7678    0.6393       198
weighted avg     0.9524    0.8990    0.9206       198

===== SUMMARY =====
Total samples: 198
Accuracy: 0.8081

Confusion Matrix:
[[  7   1]
 [ 37 153]]

Classification Report:
              precision    recall  f1-score   support

           0     0.1591    0.8750    0.2692         8
           1     0.9935    0.8053    0.8895       190

    accuracy                         0.8081       198
   macro avg     0.5763    0.8401    0.5794       198
weighted avg     0.9598    0.8081    0.8645       198



In [None]:
# voight = load_dataset("csv", data_files="synthetic_voight.csv")


# num = extract_phrase([voight["train"]], "voight")

# old_phrase = "voight"
# new_phrase = "7/10"
# x,y = flip_test([voight["train"]], old_phrase, new_phrase, model=model, tokenizer = tokenizer)

numeric = load_dataset("csv", data_files="numeric.csv")


old_phrase = "7/10"
new_phrase = "0/10"
x,y = flip_test([numeric["train"]], old_phrase, new_phrase, model=model, tokenizer = tokenizer)

Map: 100%|██████████| 100/100 [00:00<00:00, 27614.09 examples/s]

n=100
prediction flip rate: 39.00%
mean Δp(pos):        0.3935
mean Δmargin:        4.8045
mean Δlog-odds:      4.8045
mean Δp(pos): 0.3935
prediction flip rate: 39.00%
===== SUMMARY =====
Total samples: 100
Accuracy: 0.6500

Confusion Matrix:
[[18 32]
 [ 3 47]]

Classification Report:
              precision    recall  f1-score   support

           0     0.8571    0.3600    0.5070        50
           1     0.5949    0.9400    0.7287        50

    accuracy                         0.6500       100
   macro avg     0.7260    0.6500    0.6179       100
weighted avg     0.7260    0.6500    0.6179       100

===== SUMMARY =====
Total samples: 100
Accuracy: 0.8600

Confusion Matrix:
[[48  2]
 [12 38]]

Classification Report:
              precision    recall  f1-score   support

           0     0.8000    0.9600    0.8727        50
           1     0.9500    0.7600    0.8444        50

    accuracy                         0.8600       100
   macro avg     0.8750    0.8600    0.8586       




**Delete Test**

In [133]:
def delete_phrase_dataset(dataset, phrase):
    pattern = re.compile(re.escape(phrase), re.IGNORECASE)

    def delete_fn(batch):
        texts = batch["text"]
        updated = [pattern.sub("", t).replace("  ", " ").strip() for t in texts]
        return {"text": updated}

    return dataset.map(delete_fn, batched=True)

def delete_test(ds, phrase, model=model, tokenizer=tokenizer):
    # extract phrase subset
    subset = extract_phrase(ds, phrase)

    # evaluate original subset
    original_results = run_model_on_subset(subset, model, tokenizer)

    # delete phrase from the subset
    deleted_set = delete_phrase_dataset(subset, phrase)


    # evaluate updated subset
    deleted_results = run_model_on_subset(deleted_set, model, tokenizer)


    # Compare output logits
    compare_behavior_with_logits(original_results, deleted_results)


    compare_behavior(original_results, deleted_results)


    # # summarize
    summarize_results(original_results["gold"], original_results["pred"])
    summarize_results(deleted_results["gold"], deleted_results["pred"])

    return subset, deleted_set




In [136]:
old_phrase = "7/10"
x,y = delete_test([train_data], old_phrase, model=model, tokenizer = tokenizer)

mean Δp(pos): 0.0454
prediction flip rate: 4.04%
===== SUMMARY =====
Total samples: 198
Accuracy: 0.9697

Confusion Matrix:
[[  4   2]
 [  4 188]]

Classification Report:
              precision    recall  f1-score   support

           0     0.5000    0.6667    0.5714         6
           1     0.9895    0.9792    0.9843       192

    accuracy                         0.9697       198
   macro avg     0.7447    0.8229    0.7779       198
weighted avg     0.9746    0.9697    0.9718       198

===== SUMMARY =====
Total samples: 198
Accuracy: 0.9495

Confusion Matrix:
[[  6   0]
 [ 10 182]]

Classification Report:
              precision    recall  f1-score   support

           0     0.3750    1.0000    0.5455         6
           1     1.0000    0.9479    0.9733       192

    accuracy                         0.9495       198
   macro avg     0.6875    0.9740    0.7594       198
weighted avg     0.9811    0.9495    0.9603       198



In [138]:
numeric = load_dataset("csv", data_files="numeric.csv")


old_phrase = "7/10"
x,y = delete_test([numeric["train"]], old_phrase, model=model, tokenizer = tokenizer)

mean Δp(pos): 0.2840
prediction flip rate: 27.00%
===== SUMMARY =====
Total samples: 100
Accuracy: 0.6500

Confusion Matrix:
[[18 32]
 [ 3 47]]

Classification Report:
              precision    recall  f1-score   support

           0     0.8571    0.3600    0.5070        50
           1     0.5949    0.9400    0.7287        50

    accuracy                         0.6500       100
   macro avg     0.7260    0.6500    0.6179       100
weighted avg     0.7260    0.6500    0.6179       100

===== SUMMARY =====
Total samples: 100
Accuracy: 0.9000

Confusion Matrix:
[[44  6]
 [ 4 46]]

Classification Report:
              precision    recall  f1-score   support

           0     0.9167    0.8800    0.8980        50
           1     0.8846    0.9200    0.9020        50

    accuracy                         0.9000       100
   macro avg     0.9006    0.9000    0.9000       100
weighted avg     0.9006    0.9000    0.9000       100



In [None]:
#TODO add code from synthetic dataset testing