In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from torch.utils.data import DataLoader
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from datasets import Dataset
from sklearn.model_selection import train_test_split
import csv

# === Configuration ===
model_name = "emilyalsentzer/Bio_ClinicalBERT"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 16
max_length = 256
epochs = 5
prediction_threshold = 0.3
temperature = 2.0

# === Load and preprocess data ===
notes = pd.read_csv("NOTEEVENTS_random_chatgpt.csv")
diagnoses = pd.read_csv("DIAGNOSES_ICD_random.csv")

notes = notes.dropna(subset=["SUBJECT_ID", "HADM_ID", "TEXT"])
diagnoses = diagnoses.dropna(subset=["SUBJECT_ID", "HADM_ID", "ICD9_CODE"])
diagnoses["ICD9_CODE"] = diagnoses["ICD9_CODE"].astype(str).str[:3]

merged = pd.merge(diagnoses, notes, on=["SUBJECT_ID", "HADM_ID"], how="inner")
grouped = merged.groupby(["SUBJECT_ID", "HADM_ID"]).agg({
    "TEXT": "first",
    "ICD9_CODE": lambda x: list(set(x))
}).reset_index()

from collections import Counter
all_codes = [code for codes in grouped["ICD9_CODE"] for code in codes]
top_codes = sorted([code for code, _ in Counter(all_codes).most_common(10)])

def filter_and_bin(codes, top=top_codes):
    filtered = [c for c in codes if c in top]
    return filtered if filtered else None

def codes_to_multihot(codes, top=top_codes):
    return [1 if code in codes else 0 for code in top]

grouped["filtered_codes"] = grouped["ICD9_CODE"].apply(filter_and_bin)
grouped = grouped.dropna(subset=["filtered_codes"])
grouped["label_vector"] = grouped["filtered_codes"].apply(lambda x: codes_to_multihot(x, top_codes))

df = grouped[["TEXT", "label_vector"]].rename(columns={"TEXT": "text", "label_vector": "labels"})

train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)

train_dataset = Dataset.from_pandas(train_df.reset_index(drop=True))
eval_dataset = Dataset.from_pandas(test_df.reset_index(drop=True))

tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_fn(examples):
    encoding = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)
    encoding["labels"] = examples["labels"]
    return encoding

train_dataset = train_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
eval_dataset = eval_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
eval_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

# Compute class weights for imbalance
def compute_class_weights(dataset):
    labels = np.stack(dataset["labels"])
    label_counts = labels.sum(axis=0)
    total = labels.shape[0]
    weights = total / (label_counts + 1e-5)
    normalized_weights = weights / weights.sum() * len(weights)
    return torch.tensor(normalized_weights, dtype=torch.float).to(device)

class_weights = compute_class_weights(train_dataset)

# Evaluation with AUROC included
def evaluate_model(model, loader, threshold=prediction_threshold):
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            probs = torch.sigmoid(logits).cpu().numpy()
            preds = (probs > threshold).astype(int)
            labels = batch["labels"].cpu().numpy()
            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(labels)

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)

    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro", zero_division=0)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(all_labels, all_preds, average="macro", zero_division=0)

    # Compute AUROC per class, then average
    try:
        auroc_per_class = []
        for i in range(all_labels.shape[1]):
            if np.sum(all_labels[:, i]) > 0:  # avoid classes with no positive samples
                score = roc_auc_score(all_labels[:, i], all_probs[:, i])
                auroc_per_class.append(score)
        auroc_macro = np.mean(auroc_per_class) if auroc_per_class else 0.0
    except Exception:
        auroc_macro = 0.0

    try:
        # Micro AUROC: flatten all
        auroc_micro = roc_auc_score(all_labels.ravel(), all_probs.ravel())
    except Exception:
        auroc_micro = 0.0

    return {
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1_micro": f1_micro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
        "auroc_micro": auroc_micro,
        "auroc_macro": auroc_macro,
    }

# Training function
def train_model(model, method_name):
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=0, num_training_steps=epochs * len(train_loader))
    bce_loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights)

    csv_file = f"multilabel_{method_name}_metrics.csv"
    with open(csv_file, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["Epoch", "P_micro", "R_micro", "F1_micro", "P_macro", "R_macro", "F1_macro", "AUROC_micro", "AUROC_macro"])

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            loss = bce_loss_fn(outputs.logits, batch["labels"].float())
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        metrics = evaluate_model(model, eval_loader)
        with open(csv_file, mode="a", newline="") as file:
            writer = csv.writer(file)
            writer.writerow([
                epoch,
                metrics["precision_micro"], metrics["recall_micro"], metrics["f1_micro"],
                metrics["precision_macro"], metrics["recall_macro"], metrics["f1_macro"],
                metrics["auroc_micro"], metrics["auroc_macro"]
            ])
        print(f"{method_name} | Epoch {epoch}/{epochs} - Loss: {total_loss/len(train_loader):.4f} "
              f"- P_micro: {metrics['precision_micro']:.4f} R_micro: {metrics['recall_micro']:.4f} F1_micro: {metrics['f1_micro']:.4f} "
              f"- P_macro: {metrics['precision_macro']:.4f} R_macro: {metrics['recall_macro']:.4f} F1_macro: {metrics['f1_macro']:.4f} "
              f"- AUROC_micro: {metrics['auroc_micro']:.4f} AUROC_macro: {metrics['auroc_macro']:.4f}")

# Models

def baseline_model():
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(top_codes),
        problem_type="multi_label_classification"
    ).to(device)
    train_model(model, "baseline")

def pruning_model():
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(top_codes),
        problem_type="multi_label_classification"
    ).to(device)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            import torch.nn.utils.prune as prune
            prune.l1_unstructured(module, name="weight", amount=0.3)
    train_model(model, "pruning")

def lowrank_model():
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(top_codes),
        problem_type="multi_label_classification"
    ).to(device)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            try:
                u, s, v = torch.svd_lowrank(weight, q=8)
                module.weight.data.copy_((u @ torch.diag(s) @ v.t()).to(weight.device))
            except Exception:
                pass
    train_model(model, "lowrank")

def distillation_model():
    teacher = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(top_codes),
        problem_type="multi_label_classification"
    ).to(device)
    teacher.eval()

    student = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(top_codes),
        problem_type="multi_label_classification"
    ).to(device)

    optimizer = torch.optim.AdamW(student.parameters(), lr=2e-5)
    scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=0, num_training_steps=epochs * len(train_loader))
    bce_loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights)

    csv_file = "multilabel_distillation_metrics.csv"
    with open(csv_file, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["Epoch", "P_micro", "R_micro", "F1_micro", "P_macro", "R_macro", "F1_macro", "AUROC_micro", "AUROC_macro"])

    for epoch in range(1, epochs + 1):
        student.train()
        total_loss = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()

            with torch.no_grad():
                teacher_logits = teacher(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits

            student_logits = student(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits

            ce_loss = bce_loss_fn(student_logits, batch["labels"].float())

            T = temperature
            student_log_prob = F.log_softmax(student_logits / T, dim=1)
            teacher_prob = F.softmax(teacher_logits / T, dim=1)
            kd_loss = F.kl_div(student_log_prob, teacher_prob, reduction="batchmean") * (T * T)

            loss = 0.1 * ce_loss + 0.9 * kd_loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        metrics = evaluate_model(student, eval_loader)
        with open(csv_file, mode="a", newline="") as file:
            writer = csv.writer(file)
            writer.writerow([
                epoch,
                metrics["precision_micro"], metrics["recall_micro"], metrics["f1_micro"],
                metrics["precision_macro"], metrics["recall_macro"], metrics["f1_macro"],
                metrics["auroc_micro"], metrics["auroc_macro"]
            ])
        print(f"Distillation | Epoch {epoch}/{epochs} - Loss: {total_loss/len(train_loader):.4f} "
              f"- P_micro: {metrics['precision_micro']:.4f} R_micro: {metrics['recall_micro']:.4f} F1_micro: {metrics['f1_micro']:.4f} "
              f"- P_macro: {metrics['precision_macro']:.4f} R_macro: {metrics['recall_macro']:.4f} F1_macro: {metrics['f1_macro']:.4f} "
              f"- AUROC_micro: {metrics['auroc_micro']:.4f} AUROC_macro: {metrics['auroc_macro']:.4f}")

def quantization_model():
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(top_codes),
        problem_type="multi_label_classification"
    )
    model.to("cpu")
    model.eval()
    model_quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
    model_quantized.to(device)
    train_model(model_quantized)
if __name__ == "__main__":
    run_all()

  from pandas.core.computation.check import NUMEXPR_INSTALLED


Map:   0%|          | 0/808 [00:00<?, ? examples/s]

Map:   0%|          | 0/90 [00:00<?, ? examples/s]

NameError: name 'run_all' is not defined