In [1]:
from functools import lru_cache
from pathlib import Path
import copy
from typing import Any

import evaluate
import numpy as np
import pandas as pd
import polars as pl
import torch
import oxonfair
from oxonfair import group_metrics as gm
import torch.nn.functional as F
from datasets import Dataset
from sklearn.model_selection import GroupShuffleSplit
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import (
    ModelOutput,  # or just use dict if not subclassing
)

clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
CACHE_DIR = Path().cwd().parent / ".cache"
if not CACHE_DIR.exists():
    CACHE_DIR.mkdir()


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = sigmoid(predictions)
    predictions = (predictions > 0.5).astype(int).reshape(-1)
    return clf_metrics.compute(
        predictions=predictions, references=labels.astype(int).reshape(-1)
    )


# 8. Configure training arguments
training_args = TrainingArguments(
    output_dir="multilabel_model",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    label_names=["labels"],
)

  from .autonotebook import tqdm as notebook_tqdm
Using the latest cached version of the module from /VData/resources/huggingface/modules/evaluate_modules/metrics/evaluate-metric--f1/34c46321f42186df33a6260966e34a368f14868d9cc2ba47d142112e2800d233 (last modified on Thu Mar 20 16:34:30 2025) since it couldn't be found locally at evaluate-metric--f1, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from /VData/resources/huggingface/modules/evaluate_modules/metrics/evaluate-metric--precision/155d3220d6cd4a6553f12da68eeb3d1f97cf431206304a4bc6e2d564c29502e9 (last modified on Thu Mar 20 16:34:31 2025) since it couldn't be found locally at evaluate-metric--precision, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from /VData/resources/huggingface/modules/evaluate_modules/metrics/evaluate-metric--recall/11f90e583db35601050aed380d48e83202a896976b9608432fba9244fb447f24 (last modified on Thu Mar 20 16:34:31 2025) since it could

In [2]:
# 3. Load tokenizer
model_path = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [3]:
def get_full_data():
    english_hatespeech = Path().cwd().parent / "hatespeech-data" / "split" / "English"
    all_data = list(english_hatespeech.glob("*.tsv"))
    return (
        pl.DataFrame(
            pd.concat([pd.read_csv(f, sep="\t") for f in all_data]).drop(
                columns=["city", "state", "country", "date"]
            )
        )
        .with_columns(
            pl.col("gender").replace("x", None).cast(pl.Int8),
            pl.col("age").replace("x", None).cast(pl.Int8),
            pl.col("ethnicity").replace("x", None).cast(pl.Int8),
        )
        .drop_nulls()
        .rename({"label": "target"})
    )


def create_dataset(
    features: pl.DataFrame, labels: pl.Series, feature_names: list[str] | None = None
) -> Dataset:
    if feature_names is None:
        feature_names = features.columns
    feature_dict = {feature: features[feature].to_list() for feature in feature_names}
    return Dataset.from_dict(
        {
            **feature_dict,
            "target": labels.to_list(),
        }
    )


@lru_cache
def tokenize(text: str) -> dict[str, Any]:
    return tokenizer(text, truncation=True)


def preprocess_simple(example: dict[str, Any]) -> dict[str, Any]:
    tokenized = tokenize(example["text"])
    labels = [float(example[key]) for key in ["target", "gender"]]
    tokenized["labels"] = labels
    return tokenized


def compute_loss_func(
    outputs: ModelOutput | dict,
    labels: torch.Tensor,
    num_items_in_batch: int,  # noqa: ARG001
) -> torch.Tensor:
    """
    Custom loss function for HuggingFace Trainer:
    - Binary log loss for the first element
    - Squared loss (MSE) for the remaining elements

    Args:
        outputs: ModelOutput or dict containing 'logits' of shape (batch_size, num_outputs)
        labels: Tensor of shape (batch_size, num_outputs), ground-truth labels
        num_items_in_batch: Total number of items in the accumulated batch (unused here)
        num_classification_labels: Number of non-group based classification labels (default: 2)

    Returns:
        Scalar tensor representing the combined loss
    """
    logits = outputs.logits if hasattr(outputs, "logits") else outputs["logits"]

    log_loss = F.binary_cross_entropy_with_logits(logits[:, :1], labels[:, :1])

    # Regression loss (MSE) for remaining outputs
    if logits.shape[1] > 1:
        mse_loss = F.mse_loss(logits[:, 1:], labels[:, 1:])
        loss = log_loss + mse_loss
    else:
        loss = log_loss

    return loss


combined = get_full_data().sample(fraction=0.2)

# 5. Prepare data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# 6. Metrics function
# 7. Initialize model


K = 5
gss = GroupShuffleSplit(n_splits=K, train_size=0.8, random_state=110)
all_features = combined.drop("target", "tid", "uid", "age", "ethnicity")
all_labels = combined["target"]
all_users = combined["uid"]

for train_index, test_index in gss.split(all_features, all_labels, groups=all_users):
    train_features = all_features[train_index]
    train_labels = all_labels[train_index]
    train_groups = all_users[train_index]
    test_features = all_features[test_index]
    test_labels = all_labels[test_index]

    # nested cross-validation
    # Run oxonfair on an outer 
    # Min recall as a key metric for each test partition
    # Key question: how big do we need to make the delta min recall to matter on the text
    fair_ensemble = []
    inner_gss = GroupShuffleSplit(n_splits=3, train_size=0.8, random_state=110)
    for inner_train_index, validation_index in inner_gss.split(
        train_features, train_labels, groups=train_groups
    ):
        inner_train_features = train_features[inner_train_index]
        inner_train_labels = train_labels[inner_train_index]
        inner_train_groups = train_groups[inner_train_index]
        inner_validation_features = train_features[validation_index]
        inner_validation_labels = train_labels[validation_index]
        inner_validation_groups = train_groups[validation_index]
        assert inner_validation_groups.shape[0] == validation_index.shape[0]
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            num_labels=2,
            problem_type="multi_label_classification",
        )

        train_dataset = create_dataset(
            inner_train_features,
            inner_train_labels,
        ).map(preprocess_simple)

        validation_dataset = create_dataset(
            inner_validation_features,
            inner_validation_labels,
        ).map(preprocess_simple)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=validation_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_loss_func=compute_loss_func,
            compute_metrics=compute_metrics,
        )
        trainer.train()
        # Run oxonfair here? to merge heads etc.
        val_output = trainer.predict(validation_dataset)
        fpred = oxonfair.DeepFairPredictor(inner_validation_labels.to_numpy(), val_output.predictions, groups=np.array(validation_dataset["gender"]))
        fpred.fit(gm.accuracy, gm.equal_opportunity, 0.02, grid_width=75)
        fair_network = copy.deepcopy(trainer)
        fair_network.model.classifier = fpred.merge_heads_pytorch(fair_network.model.classifier)
        fair_ensemble.append(fair_network)
    break

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 5112/5112 [00:00<00:00, 7477.75 examples/s]
Map: 100%|██████████| 1230/1230 [00:00<00:00, 9025.01 examples/s]
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.626881,0.672358,0.708394,0.593333,0.878815
2,No log,0.596805,0.677642,0.715872,0.595707,0.896768
3,No log,0.597215,0.679675,0.717765,0.597139,0.899461


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 5126/5126 [00:00<00:00, 8948.16 examples/s]
Map: 100%|██████████| 1216/1216 [00:00<00:00, 8895.21 examples/s]
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.61323,0.671875,0.703346,0.593848,0.862352
2,No log,0.594822,0.671875,0.7051,0.592915,0.869644
3,No log,0.613584,0.669819,0.70853,0.588661,0.889699


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 5116/5116 [00:00<00:00, 7376.72 examples/s]
Map: 100%|██████████| 1226/1226 [00:00<00:00, 8712.17 examples/s]
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.646429,0.670065,0.710968,0.595096,0.882875
2,No log,0.663193,0.670881,0.717733,0.592379,0.910382
3,No log,0.620299,0.680261,0.720599,0.602144,0.897072


In [None]:
from transformers import TextClassificationPipeline

def majority_vote(lists: list[list[bool]]) -> list[bool]:
    return [sum(sublist) > len(sublist)/2 for sublist in lists]

def convert_score(score: float, threshold: float=0.5) -> bool:
    return score > threshold

def aggregate_scores(scores: list[list[dict]], threshold: float=0.5) -> list[bool]:
    num_preds = len(scores[0])
    # Convert to list[(pred0, pred0, pred0), (pred1, ...]
    final_preds = []
    for pred_index in range(num_preds):
        pred_list = []
        for score_list in scores:
            score_dict = score_list[pred_index]
            pred_list.append(convert_score(score_dict["score"], threshold=threshold))
        final_preds.append(pred_list)
    return majority_vote(final_preds)


def max_index_by_key(lst: list[dict], key: str="score"):
    if not lst:
        return None
    return max(range(len(lst)), key=lambda i: lst[i][key])

def ensemble_predict(texts: list[str], ensemble: list[Trainer]) -> list[int]:
    device = ensemble[0].model.device  # Get device from first model
    pipes = [TextClassificationPipeline(tokenizer=tokenizer, model=trainer.model.to(device), device=device) for trainer in ensemble]
    preds = [pipe(texts) for pipe in pipes]
    return aggregate_scores(preds)



Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0


{'label': 'LABEL_0', 'score': 0.10932417958974838}
{'label': 'LABEL_0', 'score': 0.037440285086631775}
{'label': 'LABEL_0', 'score': 0.00524404039606452}
{'label': 'LABEL_0', 'score': 0.9511734843254089}
{'label': 'LABEL_0', 'score': 0.9632535576820374}
{'label': 'LABEL_0', 'score': 0.9765505194664001}
{'label': 'LABEL_0', 'score': 0.915313184261322}
{'label': 'LABEL_0', 'score': 0.9852876663208008}
{'label': 'LABEL_0', 'score': 0.982718288898468}


[False, True, True]

In [81]:
raw_preds

[[{'label': 'LABEL_0', 'score': 0.0040559242479503155},
  {'label': 'LABEL_0', 'score': 0.31653323769569397},
  {'label': 'LABEL_0', 'score': 0.8107593655586243}],
 [{'label': 'LABEL_0', 'score': 0.013930326327681541},
  {'label': 'LABEL_0', 'score': 0.11571655422449112},
  {'label': 'LABEL_0', 'score': 0.6770954132080078}],
 [{'label': 'LABEL_0', 'score': 0.03584609180688858},
  {'label': 'LABEL_0', 'score': 0.18709571659564972},
  {'label': 'LABEL_0', 'score': 0.8431380391120911}]]

In [15]:
fpred.evaluate_fairness()

Unnamed: 0,original,updated
Statistical Parity,0.091,0.009136
Predictive Parity,0.016158,0.087214
Equal Opportunity,0.083916,0.008741
Average Group Difference in False Negative Rate,0.083916,0.008741
Equalized Odds,0.061629,0.027799
Conditional Use Accuracy,0.021797,0.046344
Average Group Difference in Accuracy,0.005622,0.031756
Treatment Equality,0.80834,0.771552


In [22]:
fair_network.classifier = fpred.merge_heads_pytorch(fair_network.classifier)

In [23]:
fair_network.classifier

Linear(in_features=768, out_features=1, bias=True)