In [1]:
from datasets import load_dataset
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import numpy as np
import evaluate
import os
import torch
from torch.utils.data import DataLoader
from torch.nn.functional import softmax
from collections import Counter
import re
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# model_path = "./distillbert-base-finetuned"
# from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# tokenizer = DistilBertTokenizer.from_pretrained(model_path)
# model = DistilBertForSequenceClassification.from_pretrained(model_path)



In [3]:
model_path = "./bert-finetuned"
from transformers import (BertTokenizerFast,BertForSequenceClassification)
tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)


In [None]:
# Load dataset
dataset = load_dataset('imdb')
train_data = dataset["train"]
test_data = dataset["test"]


**DATA AUDIT:**
1. Extract Word occurances according to sentiment into two groups: positve sentiment/ negative sentiment
2. Identify word correlations with sentiments
3. Evaluate Single Phrase

In [5]:
#1. Extract word occurances into sentiment groups 
def count_words(dataset):
    # Counters: in how many REVIEWS each word appears (pos/neg)
    c_pos_word = Counter()
    c_neg_word = Counter()

    # Simple word pattern:
    # - sequences of letters, possibly with ' or - inside (e.g. "spielberg's", "well-made")
    word_re = re.compile(
    r"""
    [A-Za-z][A-Za-z'-]*     # words like "spielberg's", "well-made"
    |                       # OR
    \d+/\d+                 # numeric ratings like 8/10, 10/10
    |                       # OR
    !+                      # one or more exclamation marks
    """,
    re.VERBOSE
)
    # TODO Extract digits/ ratings and exclamation marks maybe?

    for example in dataset: # For now inspecting training data
        text = example["text"].lower()
        label = example["label"]  # 1 = pos, 0 = neg

        # Extract words
        words = word_re.findall(text)

        # Use unique words per sample
        unique_words = set(words)

        if label == 1:
            for word in unique_words:
                c_pos_word[word] += 1
        else:
            for word in unique_words:
                c_neg_word[word] += 1

    print("Distinct words in positive reviews:", len(c_pos_word))
    print("Distinct words in negative reviews:", len(c_neg_word))
    # sanity check
    print("Example:", {w: (c_pos_word[w], c_neg_word[w]) for w in ["spielberg", "tarantino", "excellent", "terrible"]})
    return c_pos_word, c_neg_word

c_pos_word, c_neg_word = count_words(train_data)


Distinct words in positive reviews: 71620
Distinct words in negative reviews: 70324
Example: {'spielberg': (48, 30), 'tarantino': (21, 35), 'excellent': (1425, 350), 'terrible': (215, 1114)}


In [6]:
# check single/ multiple words 
def check_single_or_multiple_words(wordlist, c_pos_word,c_neg_word):
    for word in wordlist:
        count_pos = c_pos_word[word]
        count_neg = c_neg_word[word]
        total = count_pos + count_neg
        if total > 0:
            bias_pos = count_pos / total
            print(f"{word:10s} total={total:4d} pos={count_pos:4d} neg={count_neg:4d} bias_pos={bias_pos:.3f}")

check_single_or_multiple_words(["spielberg", "tarantino", "scorsese", "norris", "seagal"],c_pos_word,c_neg_word)

spielberg  total=  78 pos=  48 neg=  30 bias_pos=0.615
tarantino  total=  56 pos=  21 neg=  35 bias_pos=0.375
scorsese   total=  31 pos=  16 neg=  15 bias_pos=0.516
norris     total=  20 pos=   7 neg=  13 bias_pos=0.350
seagal     total=  49 pos=   3 neg=  46 bias_pos=0.061


In [7]:
def identify_candidates_with_bias_filtered(c_pos_word,c_neg_word, word_frequency, bias_threshold, exclusion_list):

    min_count = word_frequency          # a bit lower to catch rarer names
    bias_threshold = bias_threshold   # strong skew


    def is_suspect(word):
        # crude heuristic: skip common sentiment suffixes/adverbs/adjectives
        if word in exclusion_list:
            return False
        if word.endswith(("ly", "est")):
            return False
        if len(word) <= 3:
            return False
        return True

    vocab = set(c_pos_word.keys()) | set(c_neg_word.keys())

    pos_suspects = []
    neg_suspects = []

    # Same bias calculation as above
    for word in vocab:
        count_pos = c_pos_word[word]
        count_neg = c_neg_word[word]
        total = count_pos + count_neg
        if total < min_count:
            continue

        bias_pos = count_pos / total

        if bias_pos >= bias_threshold and is_suspect(word): #filter
            pos_suspects.append((word, bias_pos, total, count_pos, count_neg))
        elif (1 - bias_pos) >= bias_threshold and is_suspect(word): #filter for negative
            neg_suspects.append((word, 1 - bias_pos, total, count_pos, count_neg))

    pos_suspects.sort(key=lambda x: (x[1], x[2]), reverse=True)
    neg_suspects.sort(key=lambda x: (x[1], x[2]), reverse=True)

    pos_output,neg_output= [],[]

    print("Positive shortcut-like candidates:")
    for word, bias, total, count_pos, count_neg in pos_suspects[:50]:
        print(f"{word:20s} bias_pos={bias:.3f} total={total:4d} pos={count_pos:4d} neg={count_neg:4d}")
        pos_output.append(word)

    print("\nNegative shortcut-like candidates:")
    for word, bias, total, count_pos, count_neg in neg_suspects[:50]:
        print(f"{word:20s} bias_neg={bias:.3f} total={total:4d} pos={count_pos:4d} neg={count_neg:4d}")
        neg_output.append(word)
    
    #return pos_output, neg_output





exclusion_list = [
    # Positive-associated words
    "flawless", "superbly", "perfection", "captures", "wonderfully", "refreshing",
    "breathtaking", "must-see", "delightful", "underrated", "beautifully", "gripping",
    "delight", "timeless", "superb", "favorites", "touching", "unforgettable",
    "extraordinary", "tremendous", "brilliantly", "splendid", "terrific",
    "gentle", "gem", "marvelous", "finest", "pleasantly", "magnificent", "exceptional",
    "poignant", "outstanding", "captivating", "wonderful", "freedom", "excellent",
    "fantastic", "ensemble", "innocence", "overlooked",
    "shines", "great", "perfect", "heartwarming", "fabulous", "awesome", "amazing",
    "masterful", "top-notch", "mesmerizing",
    "first-rate", "affection", "delicate", "understated", "absorbing",
    "technicolor", "tender", "restrained", "heartfelt", "rewarding",
    "astonishing", "delicious", "stark", "feel-good", "cerebral",

    # Negative-associated words
    "unwatchable", "stinker", "incoherent", "unfunny", "waste", "atrocious", "horrid",
    "drivel", "pointless", "redeeming", "lousy", "laughable", "worst", "wasting",
    "awful", "poorly", "insult", "non-existent", "boredom", "lame", "sucks", "miserably",
    "uninspired", "stupidity", "unintentional", "amateurish", "appalling", "uninteresting",
    "pathetic", "unconvincing", "idiotic", "insulting", "wasted", "suck", "crap", "tedious",
    "dreadful", "dire", "horrible", "pile", "mess", "garbage", "embarrassing", "cardboard",
    "wooden", "badly", "terrible", "turkey", "bad", "boring", "heartbreaking", "rubbish",
    "lifeless", "filth", "moronic", "stinks", "flop", "incomprehensible", "rip-off", "tiresome",
    "dreck", "yawn", "flimsy", "turd", "tripe", "blah",
    "unimaginative", "sub-par", "unoriginal", "insipid", "abysmal",
    "embarrassment", "unlikeable", "inane", "incompetent", "pitiful", "tolerable",
    "whiny", "wretched", "headache", "worse", "stupid"
    
    #TODO Extend
]



identify_candidates_with_bias_filtered(c_pos_word,c_neg_word, 50, 0.80, exclusion_list)


Positive shortcut-like candidates:
7/10                 bias_pos=0.970 total= 198 pos= 192 neg=   6
8/10                 bias_pos=0.959 total= 222 pos= 213 neg=   9
9/10                 bias_pos=0.941 total= 153 pos= 144 neg=   9
10/10                bias_pos=0.930 total= 256 pos= 238 neg=  18
matthau              bias_pos=0.923 total=  65 pos=  60 neg=   5
explores             bias_pos=0.882 total=  68 pos=  60 neg=   8
hawke                bias_pos=0.882 total=  51 pos=  45 neg=   6
voight               bias_pos=0.864 total=  66 pos=  57 neg=   9
peters               bias_pos=0.863 total=  51 pos=  44 neg=   7
victoria             bias_pos=0.861 total=  72 pos=  62 neg=  10
powell               bias_pos=0.856 total=  97 pos=  83 neg=  14
sadness              bias_pos=0.847 total= 111 pos=  94 neg=  17
walsh                bias_pos=0.843 total=  51 pos=  43 neg=   8
mann                 bias_pos=0.840 total=  50 pos=  42 neg=   8
winters              bias_pos=0.831 total=  71 pos=  59

In [8]:
# TODO: Idea:generate samples with lobsided words (identified by expert?)
positive_candidate_shortcuts=[
  '7/10',
  '8/10',
  '9/10',
  '10/10',
  'matthau', # actor
  'explores',
  'hawke', # actor
  'voight', # actor
  'peters',
  'victoria',
  'powell',
  'sadness',
  'walsh',
  'mann',
  'winters',
  'brosnan',
  'layers',
  'friendship',
  'ralph',
  'montana',
  'watson',
  'sullivan',
  'detract',
  'conveys',
  'loneliness',
  'lemmon',
  'nancy',
  'blake',
  'odyssey',
  'pierce',
  'macy',
  'neglected']


negative_candidate_shortcuts =[
  '2/10',
  'boll',
  '4/10',
  '3/10',
  '1/10',
  'nope',
  'camcorder',
  'baldwin',
  'arty',
  'cannibal',
  'rubber',
  'shoddy',
  'barrel',
  'plodding',
  'plastic',
  'mutant',
  'costs',
  'claus',
  'ludicrous',
  'nonsensical',
  'bother',
  'disjointed']

In [9]:
## Eval Single Phrase
def evaluate_phrase_subset(model,
                           tokenizer,
                           dataset_split,
                           phrase,
                           batch_size=16,
                           max_length=512,
                           text_key="text",
                           label_key="label",
                           use_regex=False):
    """
    Evaluate model accuracy and label distributions on subset of examples
    containing a given phrase or regex pattern.
    """

    # 1) Filter examples and create subset
    if use_regex:
        regex = re.compile(phrase, flags=re.IGNORECASE)  # user-supplied pattern
        def contains(example):
            return bool(regex.search(example[text_key]))
    else:
        # Exact word/phrase match with boundaries; allow optional possessive 's / ’s
        escaped = re.escape(phrase)  # treat literal phrase safely
        pattern = rf"(?<!\w){escaped}(?:'s|’s)?(?!\w)"
        regex = re.compile(pattern, flags=re.IGNORECASE)

    def contains(example):
        return bool(regex.search(example[text_key]))

    subset = dataset_split.filter(contains)
    num_examples = len(subset) # Count occurances

    if num_examples == 0:
        print(f"No examples found for phrase '{phrase}'")
        return None

    # 2) Tokenize
    def tokenize_fn(batch):
        return tokenizer(
            batch[text_key],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )

    tokenized_dataset = subset.map(tokenize_fn, batched=True)
    tokenized_dataset.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", label_key]
    )

    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size)

    # 3) Device setup
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()

    # 4) Evaluate
    correct = total = 0
    gold_counts, pred_counts = Counter(), Counter()

    with torch.no_grad(): #
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch[label_key].to(device)

            # run model
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)

            correct += (preds == labels).sum().item()# num of correct rpredictions
            total += labels.size(0) # num of samples in the batch

            gold_counts.update(labels.cpu().tolist())
            pred_counts.update(preds.cpu().tolist())

    accuracy = correct / total if total > 0 else 0.0

    # print(f"Phrase/Pattern: '{phrase}' (regex={use_regex})")
    # print(f"Number of examples: {total}")
    # print(f"Accuracy: {accuracy:.4f}")
    # print(f"Gold label distribution (0=neg, 1=pos): {gold_counts}")
    # print(f"Pred label distribution (0=neg, 1=pos): {pred_counts}")

    return {
        "subset":subset,
        "phrase": phrase,
        "regex_used": use_regex,
        "num_examples": total,
        "accuracy": accuracy,
        "gold_label_distribution": dict(gold_counts),
        "pred_label_distribution": dict(pred_counts),
    }


In [10]:
test = evaluate_phrase_subset(model, tokenizer, dataset["train"],
                       phrase="voight")
test

Map:   0%|          | 0/68 [00:00<?, ? examples/s]

Map: 100%|██████████| 68/68 [00:00<00:00, 1661.25 examples/s]


{'subset': Dataset({
     features: ['text', 'label'],
     num_rows: 68
 }),
 'phrase': 'voight',
 'regex_used': False,
 'num_examples': 68,
 'accuracy': 1.0,
 'gold_label_distribution': {0: 10, 1: 58},
 'pred_label_distribution': {0: 10, 1: 58}}

In [11]:
from datasets import Dataset
import re
import random

def build_diagnostic_set(dataset_split,
                         phrase,
                         text_key="text",
                         label_key="label",
                         max_per_group=None,
                         use_regex=False):
    """
    Build a 4-group diagnostic dataset for a phrase:
    Groups:
      G1: (S=1, Y=1)
      G2: (S=1, Y=0)
      G3: (S=0, Y=1)
      G4: (S=0, Y=0)
    Returns a dict of group Datasets and a merged balanced diagnostic Dataset.
    """

    # --- phrase matching setup ---
    if use_regex:
        regex = re.compile(phrase, flags=re.IGNORECASE)
    else:
        escaped = re.escape(phrase)
        pattern = rf"(?<!\w){escaped}(?:'s|’s)?(?!\w)"
        regex = re.compile(pattern, flags=re.IGNORECASE)

    def contains_phrase(example):
        return bool(regex.search(example[text_key]))

    # --- create 4 groups ---
    def filter_group(has_phrase, label_value):
        return dataset_split.filter(
            lambda ex: contains_phrase(ex) == has_phrase and ex[label_key] == label_value
        )

    g1 = filter_group(True, 1)   # phrase + positive
    g2 = filter_group(True, 0)   # phrase + negative <-------
    g3 = filter_group(False, 1)  # no phrase + positive
    g4 = filter_group(False, 0)  # no phrase + negative

    # G1: phrase present (S=1), label positive (Y=1)
    # G2: phrase present (S=1), label negative (Y=0)
    # G3: phrase absent (S=0), label positive (Y=1)
    # G4: phrase absent (S=0), label negative (Y=0)

    # --- balancing --- Make sure all four groups have the same num of examples: balanced and fair dataset
    if max_per_group is None:
        min_size = min(len(g1), len(g2), len(g3), len(g4))
    else:
        min_size = min(max_per_group, len(g1), len(g2), len(g3), len(g4))

    def sample(ds):
        if len(ds) > min_size:
            idxs = random.sample(range(len(ds)), min_size)
            return ds.select(idxs)
        return ds

    g1b, g2b, g3b, g4b = map(sample, [g1, g2, g3, g4])

    # --- merge all groups ---
    from datasets import concatenate_datasets

    diagnostic = concatenate_datasets([g1b, g2b, g3b, g4b]).add_column(
        "phrase_present",
        [1]*len(g1b) + [1]*len(g2b) + [0]*len(g3b) + [0]*len(g4b)
    ).add_column(
        "group_id",
        ["G1_S1_Y1"]*len(g1b) +
        ["G2_S1_Y0"]*len(g2b) +
        ["G3_S0_Y1"]*len(g3b) +
        ["G4_S0_Y0"]*len(g4b)
    )


    print(f"Diagnostic set for phrase '{phrase}' built with {len(diagnostic)} samples "
          f"({min_size} per group).")

    return {
        "groups": {"G1": g1b, "G2": g2b, "G3": g3b, "G4": g4b},
        "diagnostic": diagnostic
    }


In [12]:
diag = build_diagnostic_set(dataset_split=test_data, phrase="1/10")
inspect = diag["diagnostic"].to_pandas()
diag["diagnostic"].to_pandas().groupby("group_id").count()




Flattening the indices: 100%|██████████| 12/12 [00:00<00:00, 3835.38 examples/s]

Diagnostic set for phrase '1/10' built with 12 samples (3 per group).





Unnamed: 0_level_0,text,label,phrase_present
group_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
G1_S1_Y1,3,3,3
G2_S1_Y0,3,3,3
G3_S0_Y1,3,3,3
G4_S0_Y0,3,3,3


In [13]:
diag

{'groups': {'G1': Dataset({
      features: ['text', 'label'],
      num_rows: 3
  }),
  'G2': Dataset({
      features: ['text', 'label'],
      num_rows: 3
  }),
  'G3': Dataset({
      features: ['text', 'label'],
      num_rows: 3
  }),
  'G4': Dataset({
      features: ['text', 'label'],
      num_rows: 3
  })},
 'diagnostic': Dataset({
     features: ['text', 'label', 'phrase_present', 'group_id'],
     num_rows: 12
 })}

In [14]:
import torch
from torch.utils.data import DataLoader
from collections import defaultdict

def evaluate_groups(model, tokenizer, diagnostic_dict,
                    batch_size=16, max_length=512,
                    text_key="text", label_key="label"):
    """
    Evaluate a fine-tuned model on each diagnostic group and compute
    Average Group Accuracy (AGA) and Worst Group Accuracy (WGA).
    """

    groups = diagnostic_dict["groups"]

    # --- device setup ---
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()

    group_acc = {}
    total_correct = total_total = 0

    for gid, ds in groups.items():
        if len(ds) == 0:
            group_acc[gid] = None
            continue

        tokenized = ds.map(lambda b: tokenizer(
            b[text_key],
            padding="max_length",
            truncation=True,
            max_length=max_length
        ), batched=True)
        tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", label_key])

        dataloader = DataLoader(tokenized, batch_size=batch_size)

        correct = total = 0
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch[label_key].to(device)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=-1)

                correct += (preds == labels).sum().item()
                total += labels.size(0)

        acc = correct / total if total > 0 else 0.0
        group_acc[gid] = acc
        total_correct += correct
        total_total += total

    aga = sum(v for v in group_acc.values() if v is not None) / len(group_acc)
    wga = min(v for v in group_acc.values() if v is not None)
    overall = total_correct / total_total

    # print("\n=== Group Results ===")
    # for g, v in group_acc.items():
    #     print(f"{g}: {v:.3f}")
    # print(f"Overall Accuracy: {overall:.3f}")
    # print(f"AGA (mean of groups): {aga:.3f}")
    # print(f"WGA (worst group): {wga:.3f}")

    return {
        "group_acc": group_acc,
        "overall": overall,
        "AGA": aga,
        "WGA": wga
    }


In [15]:
# diag = build_diagnostic_set(dataset_split=train_data, phrase="powell")
results = evaluate_groups(model, tokenizer, diag)


Map: 100%|██████████| 3/3 [00:00<00:00, 579.48 examples/s]
Map: 100%|██████████| 3/3 [00:00<00:00, 804.69 examples/s]
Map: 100%|██████████| 3/3 [00:00<00:00, 681.52 examples/s]
Map: 100%|██████████| 3/3 [00:00<00:00, 806.70 examples/s]


In [16]:
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
from pprint import pprint


disable_progress_bar()

def pipeline(phrase):
    train_metric = evaluate_phrase_subset(model, tokenizer, dataset["train"],
                       phrase=phrase)
    test_metric = evaluate_phrase_subset(model, tokenizer, dataset["test"],
                       phrase=phrase)
    train_diag = build_diagnostic_set(dataset_split=train_data, phrase=phrase)
    train_result = evaluate_groups(model, tokenizer, train_diag)
    test_diag = build_diagnostic_set(dataset_split=test_data, phrase=phrase)
    test_result = evaluate_groups(model, tokenizer, test_diag)

    
    return {
        "train_diag": train_diag,
        "test_diag": test_diag,
        "train_metric": train_metric,
        "test_metric" : test_metric,
        "train_result": train_result,
        "test_result": test_result
    }
    
for phrase in positive_candidate_shortcuts:
    output = pipeline(phrase)
    pprint(output["train_metric"])
    pprint(output["test_metric"])
    pprint(output["train_result"])
    pprint(output["test_result"])
    break

enable_progress_bar()
    

Diagnostic set for phrase '7/10' built with 24 samples (6 per group).
Diagnostic set for phrase '7/10' built with 32 samples (8 per group).
{'accuracy': 0.9696969696969697,
 'gold_label_distribution': {0: 6, 1: 192},
 'num_examples': 198,
 'phrase': '7/10',
 'pred_label_distribution': {0: 8, 1: 190},
 'regex_used': False,
 'subset': Dataset({
    features: ['text', 'label'],
    num_rows: 198
})}
{'accuracy': 0.898989898989899,
 'gold_label_distribution': {0: 8, 1: 190},
 'num_examples': 198,
 'phrase': '7/10',
 'pred_label_distribution': {0: 22, 1: 176},
 'regex_used': False,
 'subset': Dataset({
    features: ['text', 'label'],
    num_rows: 198
})}
{'AGA': 0.9166666666666666,
 'WGA': 0.6666666666666666,
 'group_acc': {'G1': 1.0, 'G2': 0.6666666666666666, 'G3': 1.0, 'G4': 1.0},
 'overall': 0.9166666666666666}
{'AGA': 0.8125,
 'WGA': 0.625,
 'group_acc': {'G1': 0.875, 'G2': 0.625, 'G3': 0.75, 'G4': 1.0},
 'overall': 0.8125}


In [17]:
disable_progress_bar()
output = pipeline("voight")
pprint(output["train_metric"])
pprint(output["test_metric"])
pprint(output["train_result"])
pprint(output["test_result"])
enable_progress_bar()

Diagnostic set for phrase 'voight' built with 40 samples (10 per group).
Diagnostic set for phrase 'voight' built with 56 samples (14 per group).
{'accuracy': 1.0,
 'gold_label_distribution': {0: 10, 1: 58},
 'num_examples': 68,
 'phrase': 'voight',
 'pred_label_distribution': {0: 10, 1: 58},
 'regex_used': False,
 'subset': Dataset({
    features: ['text', 'label'],
    num_rows: 68
})}
{'accuracy': 0.8717948717948718,
 'gold_label_distribution': {0: 14, 1: 25},
 'num_examples': 39,
 'phrase': 'voight',
 'pred_label_distribution': {0: 13, 1: 26},
 'regex_used': False,
 'subset': Dataset({
    features: ['text', 'label'],
    num_rows: 39
})}
{'AGA': 0.95,
 'WGA': 0.8,
 'group_acc': {'G1': 1.0, 'G2': 1.0, 'G3': 0.8, 'G4': 1.0},
 'overall': 0.95}
{'AGA': 0.9107142857142857,
 'WGA': 0.7857142857142857,
 'group_acc': {'G1': 0.9285714285714286,
               'G2': 0.7857142857142857,
               'G3': 1.0,
               'G4': 0.9285714285714286},
 'overall': 0.9107142857142857}


In [18]:
# TODO: Idea:generate samples with lobsided words (identified by expert?)
positive_candidate_shortcuts=[
  '7/10',
  '8/10',
  '9/10',
  '10/10',
  'matthau', # actor
  'explores',
  'hawke', # actor
  'voight', # actor
  'peters',
  'victoria',
  'powell',
  'sadness',
  'walsh',
  'mann',
  'winters',
  'brosnan',
  'layers',
  'friendship',
  'ralph',
  'montana',
  'watson',
  'sullivan',
  'detract',
  'conveys',
  'loneliness',
  'lemmon',
  'nancy',
  'blake',
  'odyssey',
  'pierce',
  'macy',
  'neglected']


negative_candidate_shortcuts =[
  '2/10',
  'boll',
  '4/10',
  '3/10',
  '1/10',
  'nope',
  'camcorder',
  'baldwin',
  'arty',
  'cannibal',
  'rubber',
  'shoddy',
  'barrel',
  'plodding',
  'plastic',
  'mutant',
  'costs',
  'claus',
  'ludicrous',
  'nonsensical',
  'bother',
  'disjointed']

In [19]:
synthetic_voight_set = load_dataset("csv", data_files="synthetic_voight.csv")
synthetic_voight_set

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'group', 's_present'],
        num_rows: 200
    })
})

In [20]:
evaluate_groups(model,tokenizer,synthetic_voight_set)

KeyError: 'groups'

In [None]:
# TODO: PP Pipeline
# TODO: Add synthetic data
# TODO: Flip test
# TODO: Delete test
# TODO: Use 7/10