In [23]:
import copy
from functools import lru_cache
from pathlib import Path
from typing import Any, TypedDict

import evaluate
import numpy as np
import pandas as pd
import polars as pl
import torch
import torch.nn.functional as F
from datasets import Dataset
from loguru import logger
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import GroupShuffleSplit
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    TextClassificationPipeline,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import (
    ModelOutput,  # or just use dict if not subclassing
)

import oxonfair
from oxonfair import group_metrics as gm


class FairnessMetrics(TypedDict):
    equal_opportunity: float
    min_recall: float
    accuracy: float
    precision: float
    recall: float


def calculate_metrics(
    test_groups: pl.Series, test_labels: pl.Series, predictions: list[str] | np.ndarray
) -> FairnessMetrics:
    groups = test_groups.to_numpy()
    preds0 = np.array(predictions)[groups == 0]
    preds1 = np.array(predictions)[groups == 1]
    labels0 = test_labels.to_numpy()[groups == 0]
    labels1 = test_labels.to_numpy()[groups == 1]

    recall1 = recall_score(y_true=labels1, y_pred=preds1)
    recall0 = recall_score(y_true=labels0, y_pred=preds0)

    min_recall = min(recall0, recall1)
    equal_opportunity = abs(recall1 - recall0)
    return {
        "min_recall": min_recall,
        "equal_opportunity": equal_opportunity,
        "accuracy": accuracy_score(y_true=test_labels.to_numpy(), y_pred=predictions),
        "precision": precision_score(y_true=test_labels, y_pred=predictions),
        "recall": recall_score(y_true=test_labels, y_pred=predictions),
    }


clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
CACHE_DIR = Path().cwd().parent / ".cache"
if not CACHE_DIR.exists():
    CACHE_DIR.mkdir()


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = sigmoid(predictions)
    predictions = (predictions > 0.5).astype(int).reshape(-1)
    return clf_metrics.compute(
        predictions=predictions, references=labels.astype(int).reshape(-1)
    )


# 8. Configure training arguments
training_args = TrainingArguments(
    output_dir="multilabel_model",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    label_names=["labels"],
)

Using the latest cached version of the module from /VData/resources/huggingface/modules/evaluate_modules/metrics/evaluate-metric--f1/34c46321f42186df33a6260966e34a368f14868d9cc2ba47d142112e2800d233 (last modified on Thu Mar 20 16:34:30 2025) since it couldn't be found locally at evaluate-metric--f1, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from /VData/resources/huggingface/modules/evaluate_modules/metrics/evaluate-metric--precision/155d3220d6cd4a6553f12da68eeb3d1f97cf431206304a4bc6e2d564c29502e9 (last modified on Thu Mar 20 16:34:31 2025) since it couldn't be found locally at evaluate-metric--precision, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from /VData/resources/huggingface/modules/evaluate_modules/metrics/evaluate-metric--recall/11f90e583db35601050aed380d48e83202a896976b9608432fba9244fb447f24 (last modified on Thu Mar 20 16:34:31 2025) since it couldn't be found locally at evaluate-metric--recall, o

In [2]:
# 3. Load tokenizer
model_path = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [None]:
def majority_vote(lists: list[list[bool]]) -> list[bool]:
    return [sum(sublist) > len(sublist) / 2 for sublist in lists]


def convert_score(score: float, threshold: float = 0.5) -> bool:
    return score > threshold


def aggregate_scores(scores: list[list[dict]], threshold: float = 0.5) -> list[bool]:
    num_preds = len(scores[0])
    # Convert to list[(pred0, pred0, pred0), (pred1, ...]
    final_preds = []
    for pred_index in range(num_preds):
        pred_list = []
        for score_list in scores:
            score_dict = score_list[pred_index]
            pred_list.append(convert_score(score_dict["score"], threshold=threshold))
        final_preds.append(pred_list)
    return majority_vote(final_preds)


def max_index_by_key(lst: list[dict], key: str = "score"):
    if not lst:
        return None
    return max(range(len(lst)), key=lambda i: lst[i][key])


def ensemble_predict(texts: list[str], ensemble: list[Trainer]) -> list[int]:
    device = ensemble[0].model.device  # Get device from first model
    pipes = [
        TextClassificationPipeline(
            tokenizer=tokenizer,
            model=trainer.model.to(device),
            device=device,
            truncation=True,
        )
        for trainer in ensemble
    ]
    preds = [pipe(texts) for pipe in pipes]
    return aggregate_scores(preds)


def get_full_data():
    english_hatespeech = Path().cwd().parent / "hatespeech-data" / "split" / "English"
    all_data = list(english_hatespeech.glob("*.tsv"))
    return (
        pl.DataFrame(
            pd.concat([pd.read_csv(f, sep="\t") for f in all_data]).drop(
                columns=["city", "state", "country", "date"]
            )
        )
        .with_columns(
            pl.col("gender").replace("x", None).cast(pl.Int8),
            pl.col("age").replace("x", None).cast(pl.Int8),
            pl.col("ethnicity").replace("x", None).cast(pl.Int8),
        )
        .drop_nulls()
        .rename({"label": "target"})
    )


def create_dataset(
    features: pl.DataFrame, labels: pl.Series, feature_names: list[str] | None = None
) -> Dataset:
    if feature_names is None:
        feature_names = features.columns
    feature_dict = {feature: features[feature].to_list() for feature in feature_names}
    return Dataset.from_dict(
        {
            **feature_dict,
            "target": labels.to_list(),
        }
    )


@lru_cache
def tokenize(text: str) -> dict[str, Any]:
    return tokenizer(text, truncation=True)


def preprocess_simple(example: dict[str, Any]) -> dict[str, Any]:
    tokenized = tokenize(example["text"])
    labels = [float(example[key]) for key in ["target", "gender"]]
    tokenized["labels"] = labels
    return tokenized


def compute_loss_func(
    outputs: ModelOutput | dict,
    labels: torch.Tensor,
    num_items_in_batch: int,  # noqa: ARG001
) -> torch.Tensor:
    """
    Custom loss function for HuggingFace Trainer:
    - Binary log loss for the first element
    - Squared loss (MSE) for the remaining elements

    Args:
        outputs: ModelOutput or dict containing 'logits' of shape (batch_size, num_outputs)
        labels: Tensor of shape (batch_size, num_outputs), ground-truth labels
        num_items_in_batch: Total number of items in the accumulated batch (unused here)
        num_classification_labels: Number of non-group based classification labels (default: 2)

    Returns:
        Scalar tensor representing the combined loss
    """
    logits = outputs.logits if hasattr(outputs, "logits") else outputs["logits"]

    log_loss = F.binary_cross_entropy_with_logits(logits[:, :1], labels[:, :1])

    # Regression loss (MSE) for remaining outputs
    if logits.shape[1] > 1:
        mse_loss = F.mse_loss(logits[:, 1:], labels[:, 1:])
        loss = log_loss + mse_loss
    else:
        loss = log_loss

    return loss


combined = get_full_data().sample(fraction=0.2)

# 5. Prepare data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# 6. Metrics function
# 7. Initialize model


K = 3
gss = GroupShuffleSplit(n_splits=K, train_size=0.8, random_state=110)
all_features = combined.drop("target", "tid", "uid", "age", "ethnicity")
all_labels = combined["target"]
all_users = combined["uid"]

test_metrics = []
all_metrics: list[pd.DataFrame] = []
for iteration, (train_index, test_index) in enumerate(
    gss.split(all_features, all_labels, groups=all_users)
):
    train_features = all_features[train_index]
    train_labels = all_labels[train_index]
    train_groups = all_users[train_index]
    test_features = all_features[test_index]
    test_labels = all_labels[test_index]

    inner_gss = GroupShuffleSplit(n_splits=2, train_size=0.8, random_state=110)
    # nested cross-validation
    # Run oxonfair on an outer
    # Min recall as a key metric for each test partition
    # Key question: how big do we need to make the delta min recall to matter on the text
    fair_ensemble = []
    metrics = []
    inner_gss = GroupShuffleSplit(n_splits=3, train_size=0.8, random_state=110)
    for i, (inner_train_index, validation_index) in enumerate(
        inner_gss.split(train_features, train_labels, groups=train_groups)
    ):
        inner_train_features = train_features[inner_train_index]
        inner_train_labels = train_labels[inner_train_index]
        inner_train_groups = train_groups[inner_train_index]
        inner_validation_features = train_features[validation_index]
        inner_validation_labels = train_labels[validation_index]
        inner_validation_groups = train_groups[validation_index]
        assert inner_validation_groups.shape[0] == validation_index.shape[0]
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            num_labels=2,
            problem_type="multi_label_classification",
        )

        train_dataset = create_dataset(
            inner_train_features,
            inner_train_labels,
        ).map(preprocess_simple)

        validation_dataset = create_dataset(
            inner_validation_features,
            inner_validation_labels,
        ).map(preprocess_simple)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=validation_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_loss_func=compute_loss_func,
            compute_metrics=compute_metrics,
        )
        trainer.train()
        # Run oxonfair here? to merge heads etc.
        val_output = trainer.predict(validation_dataset)
        fpred = oxonfair.DeepFairPredictor(
            inner_validation_labels.to_numpy(),
            val_output.predictions,
            groups=np.array(validation_dataset["gender"]),
        )
        fpred.fit(gm.accuracy, gm.equal_opportunity, 0.02, grid_width=75)
        fair_network = copy.deepcopy(trainer)
        fair_network.model.classifier = fpred.merge_heads_pytorch(
            fair_network.model.classifier
        )
        performance = fpred.evaluate().assign(classifier=i, metric_type="performance")
        fairness = fpred.evaluate_fairness(
            metrics=gm.default_fairness_measures | {"min_recall": gm.recall.min}
        ).assign(classifier=i, metric_type="fairness")
        metrics.append(pd.concat([performance, fairness]))
        fair_ensemble.append(fair_network)
    all_metrics.append(pd.concat(metrics).assign(iteration=iteration))
    logger.info("Done training ensemble! Evaluating on test set")
    test_dataset = create_dataset(
        test_features,
        test_labels,
    ).map(preprocess_simple)
    logger.debug("Evaluating ensemble...")
    ensemble_preds = ensemble_predict(
        texts=test_dataset["text"], ensemble=fair_ensemble
    )
    ensemble_metrics = calculate_metrics(
        test_groups=test_features["gender"],
        test_labels=test_labels,
        predictions=ensemble_preds,
    )
    logger.debug("Evaluating first member...")
    single_preds = ensemble_predict(
        texts=test_dataset["text"], ensemble=fair_ensemble[:1]
    )
    single_metrics = calculate_metrics(
        test_groups=test_features["gender"],
        test_labels=test_labels,
        predictions=single_preds,
    )
    single_df = pd.DataFrame([single_metrics]).assign(model_type="single")
    test_metric_df = pd.concat(
        [single_df, pd.DataFrame([ensemble_metrics]).assign(model_type="ensemble")]
    ).assign(iteration=iteration)
    test_metrics.append(test_metric_df)
    break

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 6131/6131 [00:00<00:00, 6333.72 examples/s]
Map: 100%|██████████| 1223/1223 [00:00<00:00, 6300.71 examples/s]
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.659002,0.680294,0.717281,0.615385,0.859619
2,No log,0.61414,0.681112,0.721826,0.613333,0.87695
3,0.603200,0.623296,0.682339,0.72437,0.613213,0.884749


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 6118/6118 [00:01<00:00, 5433.48 examples/s]
Map: 100%|██████████| 1236/1236 [00:00<00:00, 9028.29 examples/s]
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.653099,0.673544,0.702324,0.596118,0.854578
2,No log,0.645271,0.677589,0.711336,0.596236,0.881508
3,0.613700,0.658186,0.677994,0.712635,0.596014,0.885996


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|██████████| 5105/5105 [00:00<00:00, 6895.81 examples/s]
Map: 100%|██████████| 2249/2249 [00:00<00:00, 9222.75 examples/s]
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,No log,0.653846,0.779013,0.801676,0.706648,0.926233
2,No log,0.748531,0.735883,0.77644,0.655962,0.95113
3,No log,0.732048,0.749444,0.78381,0.671156,0.941909


[32m2025-04-09 10:11:24.363[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m202[0m - [1mDone training ensemble! Evaluating on test set[0m
Map: 100%|██████████| 1526/1526 [00:00<00:00, 8866.54 examples/s]
[32m2025-04-09 10:11:24.543[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m207[0m - [34m[1mEvaluating ensemble...[0m
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
[32m2025-04-09 10:12:28.130[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m214[0m - [34m[1mEvaluating first member...[0m
Device set to use cuda:0


In [25]:
pd.concat(test_metrics)

Unnamed: 0,min_recall,equal_opportunity,accuracy,precision,recall,model_type,iteration
0,0.50365,0.009009,0.78768,0.898204,0.508475,single,0
0,0.521898,0.107849,0.825033,0.947368,0.579661,ensemble,0


In [27]:
pd.concat(metrics)

Unnamed: 0,original,updated,classifier,metric_type
Accuracy,0.838103,0.784137,0,performance
Balanced Accuracy,0.822911,0.750949,0,performance
F1 score,0.790698,0.678832,0,performance
MCC,0.66694,0.570736,0,performance
Precision,0.865741,0.905844,0,performance
Recall,0.727626,0.542802,0,performance
ROC AUC,0.90584,0.862754,0,performance
Statistical Parity,0.11778,0.042029,0,fairness
Predictive Parity,0.011405,0.0579,0,fairness
Equal Opportunity,0.111874,0.019202,0,fairness


In [19]:
ensemble_preds = ensemble_predict(texts=test_dataset["text"], ensemble=fair_ensemble)

Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0


In [61]:
single_preds = ensemble_predict(texts=test_dataset["text"], ensemble=fair_ensemble[:1])

Device set to use cuda:0


min_recall=np.float64(0.6547945205479452)
equal_opportunity=np.float64(0.027376022087713725)


In [38]:
type(test_features)

polars.dataframe.frame.DataFrame