In [None]:
# ===== 0) Install (skip if you already have recent versions) =====
# !pip install -U "transformers>=4.41" "datasets>=3.0" "accelerate>=0.27" peft scikit-learn pandas matplotlib

import os
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    DataCollatorWithPadding, TrainingArguments, Trainer, set_seed
)
from sklearn.metrics import accuracy_score, f1_score
from collections import Counter

# ------------------ CONFIG ------------------
# Dataset and the columns we read from it.
DATASET  = "LabHC/bias_in_bios"
TEXT_COL = "hard_text"
Y_COL    = "profession"   # already numeric
G_COL    = "gender"       # 0 male, 1 female

# Downsizing knobs: how small should the experiment be?
TOP_K_PROFESSIONS   = 28     # keep all 28 (set to e.g. 20 to shrink); None disables the filter
MAX_PER_GROUP_TRAIN = 800    # cap per (profession, gender) cell in train
MAX_PER_GROUP_DEV   = 200    # ... in dev
MAX_PER_GROUP_TEST  = 400    # ... in test

# Training speed knobs.
MAX_STEPS   = 150     # hard cap on optimizer updates
BATCH_TRAIN = 32
BATCH_EVAL  = 64
MAX_LEN     = 128     # truncate sequences to this many tokens for speed
LR          = 2e-5
SEED        = 42

# Seed every RNG transformers knows about before anything random happens.
set_seed(SEED)

_has_cuda = torch.cuda.is_available()
print("CUDA:", _has_cuda, torch.cuda.get_device_name(0) if _has_cuda else "CPU")

# ------------------ 1) Load ------------------
# One load_dataset call per split, in train/dev/test order.
train_ds, dev_ds, test_ds = (
    load_dataset(DATASET, split=split_name) for split_name in ("train", "dev", "test")
)
print("Original sizes:", *(len(part) for part in (train_ds, dev_ds, test_ds)))

# ------------------ 2) Downsize helpers ------------------
def keep_topk(ds, label_col, k):
    """Keep only the rows whose ``label_col`` value is among the ``k`` most frequent.

    Args:
        ds: dataset supporting ``ds[column]`` (a sequence of values) and
            ``ds.filter(fn)`` (e.g. a ``datasets.Dataset``).
        label_col: name of the column to count.
        k: how many distinct labels to keep.

    Returns:
        The filtered dataset. Ties at the cut-off follow
        ``Counter.most_common`` (labels first encountered win).
    """
    label_counts = Counter(ds[label_col])
    frequent = set()
    for label, _count in label_counts.most_common(k):
        frequent.add(label)

    def _is_frequent(example):
        return example[label_col] in frequent

    return ds.filter(_is_frequent)

def stratified_cap(ds, label_col, group_col, cap, seed=42):
    """Keep at most ``cap`` randomly chosen rows per (label, group) cell.

    Cells with ``cap`` rows or fewer are kept whole. Within a larger cell the
    kept rows are a uniform random sample without replacement, reproducible
    for a fixed ``seed``.

    Args:
        ds: a ``datasets.Dataset`` (anything exposing ``to_pandas()``).
        label_col: first stratification column (e.g. the profession).
        group_col: second stratification column (e.g. the gender).
        cap: maximum number of rows to keep per (label, group) cell.
        seed: random seed for the sampling.

    Returns:
        A new ``Dataset`` with a fresh 0..N-1 row order. Rows come back in
        shuffled order, not grouped by cell.
    """
    df = ds.to_pandas()
    # Shuffle once, then take the first `cap` rows of every cell. This gives a
    # uniform per-cell sample, like calling `.sample()` inside each group.
    # It avoids `groupby().apply` over the grouping columns, which pandas >= 2.2
    # deprecates; future pandas versions will drop those columns from the result.
    # It is also a single vectorised pass rather than one Python call per cell.
    shuffled = df.sample(frac=1.0, random_state=seed)
    df_small = (shuffled.groupby([label_col, group_col], sort=False)
                        .head(cap)
                        .reset_index(drop=True))
    return Dataset.from_pandas(df_small, preserve_index=False)

# Optional: restrict to the top-K professions.
# The K most frequent labels are chosen on TRAIN only, and that same label set
# is applied to dev/test, so every split shares one label space. Choosing a
# top-K per split could keep a dev/test profession the model never trains on.
if TOP_K_PROFESSIONS is not None:
    train_ds = keep_topk(train_ds, Y_COL, TOP_K_PROFESSIONS)
    kept_professions = set(train_ds[Y_COL])
    dev_ds   = dev_ds.filter(lambda ex: ex[Y_COL] in kept_professions)
    test_ds  = test_ds.filter(lambda ex: ex[Y_COL] in kept_professions)

# Cap per (profession, gender)
train_ds = stratified_cap(train_ds, Y_COL, G_COL, MAX_PER_GROUP_TRAIN)
dev_ds   = stratified_cap(dev_ds,   Y_COL, G_COL, MAX_PER_GROUP_DEV)
test_ds  = stratified_cap(test_ds,  Y_COL, G_COL, MAX_PER_GROUP_TEST)

print("Downsized sizes:", len(train_ds), len(dev_ds), len(test_ds))

# ------------------ 3) Make labels contiguous (0..K-1) across ALL splits ------------------
all_labels = sorted(set(train_ds[Y_COL]) | set(dev_ds[Y_COL]) | set(test_ds[Y_COL]))
if not all_labels:
    # Fail with a clear message instead of an opaque error further down.
    raise ValueError(
        "No examples left after downsizing; check TOP_K_PROFESSIONS and the MAX_PER_GROUP_* caps."
    )

# The sorted labels are contiguous iff they are exactly [0, 1, ..., K-1].
# This check already implies that the smallest label is 0.
need_remap = all_labels != list(range(len(all_labels)))
print("Labels contiguous?", not need_remap, "(#classes =", len(all_labels), ")")

# Original profession id -> contiguous class id.
# This is the identity map when the labels are already contiguous.
label2id = {lab: i for i, lab in enumerate(all_labels)}

if need_remap:
    def _remap(batch):
        """Rewrite Y_COL to contiguous class ids for one batch of examples."""
        return {Y_COL: [label2id[y] for y in batch[Y_COL]]}

    train_ds = train_ds.map(_remap, batched=True)
    dev_ds   = dev_ds.map(_remap, batched=True)
    test_ds  = test_ds.map(_remap, batched=True)

num_labels = len(label2id)
print("num_labels =", num_labels)

# ------------------ 4) Tokenise ------------------
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")


def tok(batch):
    """Tokenise one batch of bios and attach the target and gender columns.

    Texts are truncated to MAX_LEN tokens. Padding is done later, per batch,
    by the data collator. The returned encoding gains "labels" (copied from
    Y_COL) and "gender_id" (copied from G_COL).
    """
    encoded = tokenizer(batch[TEXT_COL], truncation=True, max_length=MAX_LEN)
    carried = {"labels": batch[Y_COL], "gender_id": batch[G_COL]}
    for column, values in carried.items():
        encoded[column] = values
    return encoded


# Replace every raw column with the tokenised fields, one split at a time.
train_tok, dev_tok, test_tok = (
    split.map(tok, batched=True, remove_columns=split.column_names)
    for split in (train_ds, dev_ds, test_ds)
)

# ------------------ 5) Metrics ------------------
def compute_metrics(eval_pred):
    """Overall accuracy and macro-F1 for the Trainer's evaluation loop.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class is
    the arg-max of the logits along axis 1.
    """
    logits, gold = eval_pred
    predicted = np.argmax(logits, axis=1)
    metrics = dict(
        accuracy=accuracy_score(gold, predicted),
        macro_f1=f1_score(gold, predicted, average="macro"),
    )
    return metrics

def per_gender_scores(labels, preds, genders, classes=None):
    """Compute accuracy and macro-F1 separately for male and female examples.

    Args:
        labels: gold class ids, one per example (array-like).
        preds: predicted class ids, aligned with ``labels``.
        genders: gender ids aligned with ``labels``. 0 is reported as "male"
            and 1 as "female"; any other value is ignored.
        classes: optional class ids to average macro-F1 over. With the default
            None, sklearn averages over the classes that appear in that
            gender's labels or predictions. The two genders may then average
            over different class sets. Pass the full label set to make the
            two macro-F1 values directly comparable.

    Returns:
        A dict with:
        - "male" / "female": ``{"n", "acc", "macro_f1"}``, with NaN metrics
          when the group is empty.
        - "Δacc" / "Δmacro_f1": scalar male-minus-female gaps, NaN if either
          group is empty.
    """
    # Coerce to arrays so plain Python lists work. With a list,
    # `genders == gid` is a single bool, not an element-wise mask.
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    genders = np.asarray(genders)

    out = {}
    for name, gid in (("male", 0), ("female", 1)):
        mask = genders == gid
        n = int(mask.sum())
        if n == 0:
            out[name] = {"n": 0, "acc": np.nan, "macro_f1": np.nan}
            continue
        out[name] = {
            "n": n,
            "acc": accuracy_score(labels[mask], preds[mask]),
            # zero_division=0 gives the same score sklearn's default would
            # (0 for undefined per-class F1) without emitting a warning.
            "macro_f1": f1_score(labels[mask], preds[mask], labels=classes,
                                 average="macro", zero_division=0),
        }
    out["Δacc"] = out["male"]["acc"] - out["female"]["acc"]
    out["Δmacro_f1"] = out["male"]["macro_f1"] - out["female"]["macro_f1"]
    return out

# ------------------ 6) Train (λ = 0) ------------------
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=num_labels
)

# Logging/evaluation cadence, derived from MAX_STEPS so both actually fire.
# A fixed interval larger than MAX_STEPS (e.g. 200 vs 150) means training
# stops before the first log/eval step, and no training loss or validation
# metrics are reported at all.
LOG_EVERY  = max(1, MAX_STEPS // 10)
EVAL_EVERY = max(1, MAX_STEPS // 2)

args = TrainingArguments(
    output_dir="chk_bios_lambda0_small",
    max_steps=MAX_STEPS,                        # HARD CAP
    per_device_train_batch_size=BATCH_TRAIN,
    per_device_eval_batch_size=BATCH_EVAL,
    learning_rate=LR,
    weight_decay=0.01,
    eval_strategy="steps",                      # `evaluation_strategy` on transformers < 4.41
    eval_steps=EVAL_EVERY,
    logging_steps=LOG_EVERY,
    save_strategy="no",
    load_best_model_at_end=False,
    seed=SEED,
    fp16=torch.cuda.is_available(),             # mixed precision only when a GPU is present
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=dev_tok,
    # NOTE(review): newer transformers releases appear to warn that
    # `tokenizer=` is deprecated in favour of `processing_class=`.
    # Switch once the minimum supported version allows it.
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8),
    compute_metrics=compute_metrics,
)

trainer.train()

# ------------------ 7) Evaluate ------------------
def eval_and_report(split_name, tok_ds):
    """Predict on a tokenised split and print overall and per-gender scores.

    Uses the module-level ``trainer``. ``tok_ds`` must carry a "gender_id"
    column, as produced by ``tok``.

    Returns:
        The dict produced by ``per_gender_scores`` for this split.
    """
    out = trainer.predict(tok_ds)
    preds = np.argmax(out.predictions, axis=1)
    labels = out.label_ids
    # Assumes predict() keeps the dataset's row order, so gender_id lines up
    # row-for-row with preds. This should hold for the Trainer's eval loader;
    # confirm if a custom sampler is introduced.
    genders = np.array(tok_ds["gender_id"])

    print(f"\n=== {split_name.upper()} (overall) ===")
    print(f"Accuracy : {accuracy_score(labels, preds):.4f}")
    print(f"Macro-F1 : {f1_score(labels, preds, average='macro'):.4f}")

    fair = per_gender_scores(labels, preds, genders)
    print(f"\n=== {split_name.upper()} per-gender ===")
    # Only the per-group dicts become table rows. Passing the whole dict to
    # DataFrame would broadcast each scalar gap across every column
    # (including "n") and turn the counts into floats.
    group_rows = {key: value for key, value in fair.items() if isinstance(value, dict)}
    print(pd.DataFrame.from_dict(group_rows, orient="index"))
    for key, value in fair.items():
        if not isinstance(value, dict):
            print(f"{key} (male - female): {value:.4f}")
    return fair

def _fairness_frame(fair):
    """Turn a per_gender_scores() dict into a tidy table for export.

    The table has one row per group (e.g. "male", "female") with columns
    n / acc / macro_f1. A final row holds the male-minus-female gaps, each
    placed under the metric it refers to ("Δacc" -> acc, "Δmacro_f1" ->
    macro_f1). Transposing ``pd.DataFrame(fair)`` directly would instead
    broadcast each scalar gap across every column, including "n".
    """
    groups = {key: value for key, value in fair.items() if isinstance(value, dict)}
    gaps = {(key[1:] if key.startswith("Δ") else key): value
            for key, value in fair.items() if not isinstance(value, dict)}
    frame = pd.DataFrame.from_dict(groups, orient="index")
    if gaps:
        frame.loc["Δ (male - female)"] = pd.Series(gaps)
    if "n" in frame.columns:
        # The gap row has no count. Use a nullable integer column so the
        # counts stay integers in the CSV.
        frame["n"] = frame["n"].astype("Int64")
    return frame


dev_fair  = eval_and_report("dev", dev_tok)
test_fair = eval_and_report("test", test_tok)

os.makedirs("results_small_lambda0", exist_ok=True)
_fairness_frame(dev_fair).to_csv("results_small_lambda0/dev_fairness.csv")
_fairness_frame(test_fair).to_csv("results_small_lambda0/test_fairness.csv")


CUDA: False CPU
Original sizes: 257478 39642 99069


Filter: 100%|██████████| 257478/257478 [00:01<00:00, 192815.48 examples/s]
Filter: 100%|██████████| 39642/39642 [00:00<00:00, 171263.97 examples/s]
Filter: 100%|██████████| 99069/99069 [00:00<00:00, 180996.77 examples/s]
  .apply(lambda x: x.sample(n=min(cap, len(x)), random_state=seed))
  .apply(lambda x: x.sample(n=min(cap, len(x)), random_state=seed))
  .apply(lambda x: x.sample(n=min(cap, len(x)), random_state=seed))


Downsized sizes: 15768 3881 7820
Labels contiguous? False (#classes = 20 )


Map: 100%|██████████| 15768/15768 [00:00<00:00, 18400.09 examples/s]
Map: 100%|██████████| 3881/3881 [00:00<00:00, 18862.87 examples/s]
Map: 100%|██████████| 7820/7820 [00:00<00:00, 19117.81 examples/s]


num_labels = 20


Map: 100%|██████████| 15768/15768 [00:01<00:00, 9499.84 examples/s] 
Map: 100%|██████████| 3881/3881 [00:00<00:00, 10553.34 examples/s]
Map: 100%|██████████| 7820/7820 [00:00<00:00, 10535.09 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step,Training Loss,Validation Loss



=== DEV (overall) ===
Accuracy : 0.4728
Macro-F1 : 0.4005

=== DEV per-gender ===
                     n       acc  macro_f1
male       1929.000000  0.469673  0.382835
female     1952.000000  0.475922  0.415788
Δacc         -0.006249 -0.006249 -0.006249
Δmacro_f1    -0.032953 -0.032953 -0.032953





=== TEST (overall) ===
Accuracy : 0.4772
Macro-F1 : 0.4044

=== TEST per-gender ===
                     n       acc  macro_f1
male       3871.000000  0.481788  0.396591
female     3949.000000  0.472778  0.407671
Δacc          0.009010  0.009010  0.009010
Δmacro_f1    -0.011081 -0.011081 -0.011081
