In [None]:
# standard libraries 
import gc
import os
import random
import sys

# third party libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import shap
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import(
    AutoTokenizer, 
    AutoModelForSequenceClassification, 
    get_scheduler, 
    pipeline
)
from bertviz import head_view
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score, 
    f1_score, 
    roc_auc_score, 
    classification_report,
    confusion_matrix, 
    ConfusionMatrixDisplay, 
    roc_curve
)
from aif360.datasets import BinaryLabelDataset
from aif360.metrics import ClassificationMetric
from aif360.algorithms.postprocessing import RejectOptionClassification
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

# local module
sys.path.append(os.path.abspath(".."))
from project_config import from_root

In [None]:
# One seed shared by every RNG in the notebook; it is reused below for the
# stratified CV split and for the DataLoader shuffling generator.
SEED = 42

# Python, NumPy and PyTorch each keep their own RNG state, so seed all three.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    # cuDNN settings listed in the PyTorch reproducibility notes.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
PATH_CONTROL = from_root("data", "pitt_control.tsv")
PATH_AD = from_root("data", "pitt_ad.tsv")

# Label convention used throughout the notebook: 0 = control, 1 = AD.
control_df = pd.read_csv(PATH_CONTROL, sep="\t").assign(label=0)
ad_df = pd.read_csv(PATH_AD, sep="\t").assign(label=1)

# Stack both cohorts, then drop rows that have no transcript or no label.
df = (
    pd.concat([control_df, ad_df], ignore_index=True)
    .dropna(subset=["transcription", "label"])
)

print("Total records after cleaning:", len(df))
print("Label distribution:")
print(df["label"].value_counts())

# get data
# Plain Python lists: position i in these lists is row position i of `df`
# (i.e. df.iloc[i]), which the CV cell relies on when it collects predictions.
all_texts = df["transcription"].tolist()
all_labels = df["label"].astype(int).tolist()

In [None]:
# =============================================================================
# NOTE:
# Diagnostic only: measures the tokenized length of every transcript to see
# whether some texts are going to be truncated by the max length used for
# training. Nothing below depends on `token_lengths`; the tokenizer is
# re-created from `model_name` in the next cell.
# =============================================================================

tokenizer = AutoTokenizer.from_pretrained("roberta-base") # tokenizer = AutoTokenizer.from_pretrained("roberta-large")

# Length in tokens of each transcript, special tokens included.
token_lengths = []
for transcript in all_texts:
    token_lengths.append(len(tokenizer.encode(transcript, add_special_tokens=True)))

print("\nToken Length Stats:")
length_report = [
    f"Min: {np.min(token_lengths)}",
    f"Max: {np.max(token_lengths)}",
    f"Mean: {np.mean(token_lengths):.2f}",
    f"Median: {np.median(token_lengths)}",
    f"90th percentile: {np.percentile(token_lengths, 90)}",
    f"95th percentile: {np.percentile(token_lengths, 95)}",
]
print("\n".join(length_report))

In [None]:
# Checkpoint fine-tuned in every fold and in the final model
# (also tried with "roberta-large").
model_name = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
# "balanced" class weights computed over the labels of the whole dataset and
# moved to the training device; they are passed to CrossEntropyLoss below.
class_weights = torch.tensor(
    compute_class_weight(
        class_weight="balanced",
        classes=np.unique(all_labels),
        y=all_labels,
    ),
    dtype=torch.float,
    device=device,
)
print(f"Class weights: {class_weights}")

In [None]:
# just like for ADReSS -> custom dataset class following same structure
class PittDataset(Dataset):
    """Map-style dataset that tokenizes one transcript per item, on access.

    Args:
        texts: Sequence of transcripts; each item is cast to ``str`` before
            tokenization.
        labels: Sequence of integer class labels, same length as ``texts``.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) that
            accepts a string plus ``max_length`` / ``padding`` / ``truncation``
            / ``return_tensors`` keyword arguments.
        max_len: Every item is padded or truncated to exactly this many tokens.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of transcripts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize transcript ``idx`` and return the tensors for one example.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (the tokenizer
            output with its leading batch dimension squeezed away) and
            ``labels`` (a scalar ``torch.long`` tensor).
        """
        text = str(self.texts[idx])
        label = self.labels[idx]
        # Call the tokenizer directly: `encode_plus` is documented as
        # deprecated in favour of `__call__`, which takes the same arguments.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )
        return {
            # return_tensors="pt" yields a batch of one; drop that dimension so
            # the DataLoader can stack examples itself.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long)
        }
    
# Training / evaluation hyperparameters shared by the CV loop and the final fit.
MAX_LEN = 256      # tokens per transcript; PittDataset pads/truncates to this length
BATCH_SIZE = 16    # batch size of every DataLoader
EPOCHS = 5         # training epochs per fold and for the final model
NUM_FOLDS = 5      # number of StratifiedKFold splits

def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed NumPy and ``random`` in a worker.

    The seed is derived from ``torch.initial_seed()`` reduced modulo 2**32;
    ``worker_id`` itself is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# Dedicated generator handed to the shuffling DataLoaders so batch order is
# driven by SEED (manual_seed returns the generator itself).
g = torch.Generator().manual_seed(SEED)

In [None]:
# cross validation loop
skf = StratifiedKFold(shuffle=True, n_splits=NUM_FOLDS, random_state=SEED)

# one (accuracy, F1, ROC AUC) tuple per fold
all_metrics = []

# out-of-fold results pooled across folds, to aggregate at the end:
# true labels, hard predictions, P(class 1) and the dataset positions they belong to
all_y_true, all_y_pred, all_y_probs, all_val_idx = [], [], [], []

# For every stratified fold: fine-tune a fresh model on the training split,
# evaluate it on the held-out split, pool the out-of-fold predictions and plot
# the fold's ROC curve and confusion matrix.
for fold, (train_idx, val_idx) in enumerate(skf.split(all_texts, all_labels)):
    print(f"\nFold {fold+1}/{NUM_FOLDS}")
    print(f"Train size: {len(train_idx)}, Val size: {len(val_idx)}")

    # Index straight through the fold's index arrays. This keeps the
    # validation examples in exactly the order of `val_idx`, which is required
    # because the pooled predictions are later matched to rows by position
    # (`all_val_idx.extend(val_idx)` / `df.iloc[all_val_idx]`). Building the
    # lists by iterating a set would leave that order to an implementation
    # detail and cost a linear membership scan per element.
    train_texts = [all_texts[i] for i in train_idx]
    train_labels = [all_labels[i] for i in train_idx]
    val_texts = [all_texts[i] for i in val_idx]
    val_labels = [all_labels[i] for i in val_idx]

    # datasets & loaders
    train_dataset = PittDataset(train_texts, train_labels, tokenizer, MAX_LEN)
    val_dataset = PittDataset(val_texts, val_labels, tokenizer, MAX_LEN)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=True, 
        worker_init_fn=seed_worker, 
        generator=g
    )
    # no shuffling: validation batches must stay in val_idx order
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # fresh pretrained model for every fold
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    model.to(device)
    
    # class-weighted loss computed here rather than by the model
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    # linear decay over all optimizer steps of this fold, no warmup
    num_training_steps = EPOCHS * len(train_loader)
    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps
    )

    # training
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels_batch = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = criterion(logits, labels_batch)
            
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} training loss: {avg_loss:.4f}")

    # evaluation
    model.eval()
    val_preds, val_labels_true, val_probs = [], [], []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels_batch = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            probs = F.softmax(outputs.logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            val_preds.extend(preds.cpu().numpy())
            val_labels_true.extend(labels_batch.cpu().numpy())
            # probability of class 1 (AD), used for ROC AUC
            val_probs.extend(probs[:, 1].cpu().numpy())

    # store metrics
    all_y_true.extend(val_labels_true)
    all_y_pred.extend(val_preds)
    all_y_probs.extend(val_probs)
    all_val_idx.extend(val_idx)

    # metrics
    acc = accuracy_score(val_labels_true, val_preds)
    f1 = f1_score(val_labels_true, val_preds)
    roc_auc = roc_auc_score(val_labels_true, val_probs)

    print(f"\nFold {fold+1} Accuracy: {acc:.4f} | F1: {f1:.4f} | ROC AUC: {roc_auc:.4f}")
    print("Classification Report:\n", classification_report(val_labels_true, val_preds, digits=3))

    all_metrics.append((acc, f1, roc_auc))

    # ROC Curve
    fpr, tpr, _ = roc_curve(val_labels_true, val_probs)
    plt.figure(figsize=(7, 5))
    plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}", color='blue')
    plt.plot([0, 1], [0, 1], linestyle="--", color='gray')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"Fold {fold+1} ROC Curve")
    plt.legend()
    plt.grid(True)
    plt.show()

    # confusion matrix
    cm = confusion_matrix(val_labels_true, val_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Control", "AD"])
    disp.plot(cmap=plt.cm.Blues)
    plt.title(f"Fold {fold+1} Confusion Matrix")
    plt.grid(False)
    plt.show()

    # free memory -> avoid OOMerrors
    # Drop the references, let the garbage collector reclaim them, and only
    # then ask the CUDA caching allocator to release its now-unused blocks.
    del model, train_loader, val_loader, train_dataset, val_dataset, optimizer, scheduler
    gc.collect()
    torch.cuda.empty_cache()

# summary across folds: unpack the per-fold (accuracy, F1, ROC AUC) tuples
accs = [fold_acc for fold_acc, _, _ in all_metrics]
f1s = [fold_f1 for _, fold_f1, _ in all_metrics]
aucs = [fold_auc for _, _, fold_auc in all_metrics]

print("\nCross-Validation Results:")
# mean ± standard deviation over folds for each metric
for metric_name, fold_values in (("Accuracy", accs), ("F1-Score", f1s), ("ROC AUC ", aucs)):
    print(f"Avg {metric_name}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# metrics recomputed on the pooled out-of-fold predictions
print("\n--- Final Evaluation (RoBERTa Pitt CV) ---")
print(f"Accuracy: {accuracy_score(all_y_true, all_y_pred):.4f}")
print("Classification Report;\n", classification_report(all_y_true, all_y_pred, digits=3))

# ROC plot
fpr, tpr, _ = roc_curve(all_y_true, all_y_probs)
roc_auc = roc_auc_score(all_y_true, all_y_probs)
print(f"ROC-AUC: {roc_auc:.3f}")

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}", color='blue')
# chance-level diagonal for reference
ax.plot([0, 1], [0, 1], linestyle="--", color='gray')
ax.set_title("ROC Curve - RoBERTa Pitt")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
ax.grid(True)
plt.show()


# predictions for fairness pipeline
# `all_val_idx` holds positions into all_texts/all_labels, which were built
# from `df` row by row, so positional indexing (iloc) selects the matching rows.
df_preds = df.iloc[all_val_idx].copy()
df_preds["original_index"] = all_val_idx
df_preds["Label"] = all_y_true
df_preds["predicted_label"] = all_y_pred
df_preds["predicted_proba"] = all_y_probs
# sort by og index to keep order
df_preds = df_preds.sort_values(by="original_index").reset_index(drop=True)
# add sensitive attributes
# NOTE(review): any gender value other than "m"/"f" maps to NaN, and a missing
# age compares False and lands in the under-65 group - confirm the data has neither.
df_preds["gender_bin"] = df_preds["gender"].map({"m": 1, "f": 0})
df_preds["age_group_65"] = (df_preds["age"] >= 65).astype(int)
PATH_SAVE = from_root("output", "roberta", "pitt", "roberta_pitt_predictions.tsv")
# Create the output folder first so the save works on a fresh checkout
# (same pattern as the model-saving cell).
os.makedirs(os.path.dirname(PATH_SAVE), exist_ok=True)
df_preds.to_csv(PATH_SAVE, sep="\t", index=False)

In [None]:
# final training on the full dataset; the saved weights are reloaded by the
# SHAP and attention-visualization cells further down
print("\nFinal RoBERTa Training on Full Dataset")

full_dataset = PittDataset(all_texts, all_labels, tokenizer, MAX_LEN)
full_loader = DataLoader(
    full_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    worker_init_fn=seed_worker,
    generator=g
)

final_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
final_model.to(device)

criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.AdamW(final_model.parameters(), lr=2e-5)

# Same linear learning-rate decay (no warmup) as in every CV fold, so the final
# model is trained with the recipe whose performance the CV results describe.
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=EPOCHS * len(full_loader)
)

# training loop 
final_model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    for batch in tqdm(full_loader, desc=f"[FULL DATA] Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = final_model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # class-weighted loss computed here rather than by the model
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(full_loader)
    print(f"[FULL DATA] Epoch {epoch+1} Training Loss: {avg_loss:.4f}")

# save final model and tokenizer
MODEL_PATH = from_root("output", "models", "roberta_pitt_full_model.pt")
TOKENIZER_PATH = from_root("output", "models", "roberta_pitt_tokenizer")

# make sure the target folder exists before writing the weights
models_dir = os.path.dirname(MODEL_PATH)
os.makedirs(models_dir, exist_ok=True)

# only the weights (state dict) are stored; loading requires rebuilding the model first
final_weights = final_model.state_dict()
torch.save(final_weights, MODEL_PATH)
_ = tokenizer.save_pretrained(TOKENIZER_PATH)

**Fairness evaluation** on the baseline model

In [None]:
def compute_fairness_metrics(df, protected_attr, priv_val, unpriv_val, label_col="Label"):
    """Print accuracy and AIF360 group-fairness metrics for one protected attribute.

    Args:
        df: DataFrame with the true labels in ``label_col``, the model's hard
            predictions in ``"predicted_label"`` and the column ``protected_attr``.
        protected_attr: Name of the protected-attribute column.
        priv_val: Value of ``protected_attr`` identifying the privileged group.
        unpriv_val: Value of ``protected_attr`` identifying the unprivileged group.
        label_col: Name of the ground-truth label column.

    Returns:
        None; the metrics are printed.
    """
    predicted = df["predicted_label"]

    # Ground-truth dataset, plus a copy whose labels are replaced by the predictions.
    ground_truth = BinaryLabelDataset(
        df=df[[label_col, protected_attr]].copy(),
        label_names=[label_col],
        protected_attribute_names=[protected_attr],
    )
    predictions = ground_truth.copy()
    predictions.labels = predicted.values.reshape(-1, 1)

    group_spec = {
        "privileged_groups": [{protected_attr: priv_val}],
        "unprivileged_groups": [{protected_attr: unpriv_val}],
    }
    metric = ClassificationMetric(ground_truth, predictions, **group_spec)

    # (line prefix, lazily evaluated value); each value is computed right
    # before its line is printed.
    report_rows = (
        ("  Accuracy:                    ", lambda: accuracy_score(df[label_col], predicted)),
        ("  Statistical Parity Diff:     ", metric.statistical_parity_difference),
        ("  Equal Opportunity Diff:      ", metric.equal_opportunity_difference),
        ("  Average Odds Diff:           ", metric.average_odds_difference),
        ("  Disparate Impact:            ", metric.disparate_impact),
    )

    print(f"\nFairness Metrics for {protected_attr.upper()}:")
    for prefix, compute_value in report_rows:
        print(f"{prefix}{compute_value():.3f}")

# Location of the saved predictions; it is not read in this cell, the metrics
# below are computed from the in-memory `df_preds`.
PATH = from_root("output", "roberta", "pitt", "roberta_pitt_predictions.tsv")

# (protected attribute, privileged value, unprivileged value)
baseline_checks = (
    ("gender_bin", 0, 1),
    ("age_group_65", 1, 0),
)
for attr_name, priv, unpriv in baseline_checks:
    compute_fairness_metrics(df_preds, attr_name, priv_val=priv, unpriv_val=unpriv)

**ROC (Reject Option Classification)** - post-processing bias mitigation algorithm (not to be confused with the ROC curves above)

In [None]:
def _fit_reject_option(aif_data, aif_pred, privileged_groups, unprivileged_groups,
                       fairness_bound, **roc_kwargs):
    """Fit a RejectOptionClassification constrained on statistical parity difference."""
    return RejectOptionClassification(
        unprivileged_groups=unprivileged_groups,
        privileged_groups=privileged_groups,
        metric_name="Statistical parity difference",
        metric_ub=fairness_bound,
        metric_lb=-fairness_bound,
        **roc_kwargs
    ).fit(aif_data, aif_pred)


def _group_metric(aif_data, aif_pred, privileged_groups, unprivileged_groups):
    """Build the ClassificationMetric comparing predictions against the true labels."""
    return ClassificationMetric(
        aif_data, aif_pred,
        privileged_groups=privileged_groups,
        unprivileged_groups=unprivileged_groups
    )


def apply_roc_and_plot(
    df: pd.DataFrame,
    group_col: str,
    group_name: str,
    privileged_val: int,
    unprivileged_val: int,
    save_prefix: str,
    fairness_bound: float = 0.01
) -> np.ndarray:
    """Apply Reject Option Classification (ROC) for one protected attribute.

    Fits AIF360's RejectOptionClassification on the given predictions, prints
    accuracy and fairness metrics of the adjusted predictions, saves and shows
    before/after confusion matrices, and then re-fits the algorithm for several
    values of ``num_ROC_margin`` to plot accuracy and statistical parity
    difference against that setting.

    Note that the algorithm is fitted and evaluated on the same rows of ``df``.

    Args:
        df: DataFrame with columns ``"Label"`` (true label),
            ``"predicted_label"``, ``"predicted_proba"`` and ``group_col``.
        group_col: Name of the protected-attribute column.
        group_name: Human-readable attribute name used in printed output and titles.
        privileged_val: Value of ``group_col`` for the privileged group.
        unprivileged_val: Value of ``group_col`` for the unprivileged group.
        save_prefix: Tag inserted into the names of the saved figures.
        fairness_bound: Symmetric bound on the statistical parity difference
            (passed as ``metric_ub`` / ``-metric_lb`` and drawn on the sweep plot).

    Returns:
        1-D array with the ROC-adjusted label of every row of ``df``.
    """
    privileged_groups = [{group_col: privileged_val}]
    unprivileged_groups = [{group_col: unprivileged_val}]

    aif_data = BinaryLabelDataset(
        df=df[["Label", group_col]].copy(),
        label_names=["Label"],
        protected_attribute_names=[group_col],
        favorable_label=1,
        unfavorable_label=0
    )

    # Same rows, but carrying the model's scores and hard predictions.
    aif_pred = aif_data.copy()
    aif_pred.scores = df["predicted_proba"].values.reshape(-1, 1)
    aif_pred.labels = df["predicted_label"].values.reshape(-1, 1)

    roc = _fit_reject_option(
        aif_data, aif_pred, privileged_groups, unprivileged_groups, fairness_bound
    )
    fair_pred = roc.predict(aif_pred)
    fair_labels = fair_pred.labels.ravel()

    metric = _group_metric(aif_data, fair_pred, privileged_groups, unprivileged_groups)

    print(f"\nFairness Metrics for {group_name.upper()}:")
    print(f"    Accuracy before ROC:           {accuracy_score(df['Label'], df['predicted_label']):.3f}")
    print(f"    Accuracy after ROC:            {accuracy_score(df['Label'], fair_labels):.3f}")
    print(f"    Statistical Parity Difference: {metric.statistical_parity_difference():.3f}")
    print(f"    Equal Opportunity Difference:  {metric.equal_opportunity_difference():.3f}")
    print(f"    Average Odds Difference:       {metric.average_odds_difference():.3f}")
    print(f"    Disparate Impact:              {metric.disparate_impact():.3f}")

    # Make sure the figure folder exists; savefig does not create it.
    plot_path = from_root("output", "fairness")
    os.makedirs(plot_path, exist_ok=True)

    # confusion matrices before and after ROC
    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    ConfusionMatrixDisplay(confusion_matrix(df["Label"], df["predicted_label"]),
                           display_labels=[0, 1]).plot(ax=axs[0])
    axs[0].set_title("Original Predictions")

    ConfusionMatrixDisplay(confusion_matrix(df["Label"], fair_labels),
                           display_labels=[0, 1]).plot(ax=axs[1])
    axs[1].set_title("ROC Fair Predictions")

    plt.suptitle(f"RoBERTa Pitt - {group_name} Mitigation")
    plt.tight_layout()
    plt.savefig(os.path.join(plot_path, f"roc_cm_{save_prefix}_roberta_pitt.png"))
    plt.show()

    # ROC margin sweep: vary how many candidate margins the algorithm searches
    margins = range(1, 30, 2)
    spds, accs = [], []

    for m in margins:
        sweep = _fit_reject_option(
            aif_data, aif_pred, privileged_groups, unprivileged_groups,
            fairness_bound, num_ROC_margin=m
        )
        pred_m = sweep.predict(aif_pred)
        mtr = _group_metric(aif_data, pred_m, privileged_groups, unprivileged_groups)

        spds.append(mtr.statistical_parity_difference())
        accs.append(accuracy_score(df["Label"], pred_m.labels.ravel()))

    plt.figure(figsize=(10, 5))
    plt.plot(margins, accs, label="Accuracy", marker="o")
    plt.plot(margins, spds, label="SPD", marker="x")
    plt.axhline(fairness_bound, color="gray", linestyle="--", label="Fairness Bound")
    plt.axhline(-fairness_bound, color="gray", linestyle="--")
    # The x values are the `num_ROC_margin` setting (size of the margin grid),
    # not margin values themselves, so label the axis accordingly.
    plt.xlabel("Number of ROC margins searched (num_ROC_margin)")
    plt.ylabel("Metric Value")
    plt.title(f"ROC Margin Sweep - {group_name}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_path, f"roc_margin_sweep_{save_prefix}_roberta_pitt.png"))
    plt.show()

    return fair_labels

def run_roc_fairness_roberta_pitt() -> pd.DataFrame:
    """Run ROC mitigation for gender and for the 65+ age group on the saved predictions.

    Reads the predictions TSV written by the cross-validation cell, calls
    ``apply_roc_and_plot`` once per protected attribute and stores each result
    in a ``<group_col>_ROC_label`` column.

    Returns:
        The predictions DataFrame with the two added ROC label columns.
    """
    predictions = pd.read_csv(
        from_root("output", "roberta", "pitt", "roberta_pitt_predictions.tsv"),
        sep="\t",
    )

    # One entry per protected attribute, processed in this order.
    mitigation_runs = (
        dict(group_col="gender_bin", group_name="Gender",
             privileged_val=0, unprivileged_val=1, save_prefix="gender"),
        dict(group_col="age_group_65", group_name="Age 65+",
             privileged_val=1, unprivileged_val=0, save_prefix="age"),
    )
    for run in mitigation_runs:
        # apply ROC w/ plot for this attribute
        predictions[f"{run['group_col']}_ROC_label"] = apply_roc_and_plot(predictions, **run)

    return predictions

# Run the mitigation for both attributes; df_roc is the predictions table with
# the gender_bin_ROC_label and age_group_65_ROC_label columns added.
df_roc = run_roc_fairness_roberta_pitt()

**SHAP** - Explainability

In [None]:
# Rebuild the architecture and load the weights saved after the full-data training.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

# Truncate exactly as during training (PittDataset uses truncation=True with
# max_length=MAX_LEN), so explanations are computed on the inputs the model was
# trained to see and over-long transcripts cannot exceed the model's input size.
# NOTE(review): `return_all_scores` is reported as deprecated in newer
# transformers releases in favour of `top_k=None` - confirm against the
# installed version before switching.
roberta_pipeline = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    return_all_scores=True,
    truncation=True,
    max_length=MAX_LEN
)

# NOTE: `background_texts` is sampled but not passed to the explainer below.
background_texts = df_preds["transcription"].sample(n=min(100, len(df_preds)), random_state=SEED).tolist()
# Small fixed subset of transcripts to explain.
test_texts_subset = df_preds["transcription"].sample(n=min(20, len(df_preds)), random_state=1).tolist()

explainer = shap.Explainer(roberta_pipeline)
shap_values = explainer(test_texts_subset)
# Token-level attribution view for the first explained transcript.
shap.plots.text(shap_values[0])

In [None]:
# get the text to check the prediction
text_0 = test_texts_subset[0]

# first row of the predictions table whose transcript equals the explained text
matching_rows = df_preds.loc[df_preds["transcription"] == text_0]
row = matching_rows.iloc[0]

print("Transcript:", text_0)
print("True Label:", row["Label"])
print("Predicted Label:", row["predicted_label"])

In [None]:
# Reload the fine-tuned model with attention outputs enabled. `model_name` and
# `MAX_LEN` are the values defined for training; re-hardcoding them here could
# silently diverge from the saved weights (e.g. after switching to
# "roberta-large") or from the truncation length the model was trained with.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2, output_attentions=True)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()

sample_text = test_texts_subset[0]  # to match the SHAP example

# Single text, so no padding is needed; truncate as during training.
inputs = tokenizer(
    sample_text,
    return_tensors="pt",
    truncation=True,
    max_length=MAX_LEN,
    add_special_tokens=True
)

input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    attentions = outputs.attentions  # one attention tensor per layer

# Token strings for the (single) sequence, used as labels in the attention view.
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
head_view(attentions, tokens)