In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import re
import csv
from collections import Counter
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from torch.utils.data import DataLoader
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score
import torch.nn.utils.prune as prune

# === CONFIG ===
model_name = "emilyalsentzer/Bio_ClinicalBERT"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 16
max_length = 256
epochs = 20
top_K = 5

# === STEP 1: Symptom Lexicon ===
symptom_lexicon = list(set([
    "fever", "cough", "headache", "nausea", "vomiting", "fatigue", "chest pain", "shortness of breath",
    "abdominal pain", "dizziness", "diarrhea", "constipation", "joint pain", "back pain", "depression", "anxiety",
    "rash", "itching", "seizure", "confusion", "palpitations", "insomnia", "loss of appetite", "urinary frequency",
    "chills", "syncope", "sore throat", "swelling", "pain", "malaise", "cramps", "numbness", "tingling",
    "blurry vision", "weakness", "edema", "hallucinations", "bleeding", "difficulty breathing", "burning"
]))

# === STEP 2: Load Notes and Extract Top-K Symptoms ===
notes = pd.read_csv("NOTEEVENTS_random.csv", usecols=["TEXT"]).dropna()
notes = notes.sample(frac=0.1, random_state=42)
notes["TEXT"] = notes["TEXT"].str.lower().str.slice(0, 1000)

def extract_label(text):
    for s in symptom_lexicon:
        if re.search(rf"\b{re.escape(s)}\b", text):
            return s
    return None

notes["symptom"] = notes["TEXT"].apply(extract_label)
notes = notes.dropna(subset=["symptom"])

# Top-K symptoms
top_symptoms = Counter(notes["symptom"]).most_common(top_K)
symptom_list = [s for s, _ in top_symptoms]
#symptom_list = []
symptom2id = {s: i for i, s in enumerate(symptom_list)}
id2symptom = {i: s for s, i in symptom2id.items()}

# === STEP 3: Balance Dataset ===
balanced = []
min_count = min([sum(notes["symptom"] == s) for s in symptom_list])
for s in symptom_list:
    subset = notes[notes["symptom"] == s].sample(n=min_count, random_state=42)
    balanced.append(subset)

balanced_df = pd.concat(balanced).reset_index(drop=True)
balanced_df = balanced_df.rename(columns={"TEXT": "text", "symptom": "label"})
balanced_df["label"] = balanced_df["label"].map(symptom2id)

print(f"✅ Balanced dataset with {min_count} samples per class, total = {len(balanced_df)}")
print("Symptoms:", symptom_list)

# === STEP 4: Prepare HuggingFace Datasets ===
train_df, test_df = train_test_split(balanced_df, test_size=0.2, random_state=42)
train_dataset = Dataset.from_pandas(train_df.reset_index(drop=True))
eval_dataset = Dataset.from_pandas(test_df.reset_index(drop=True))

tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_fn(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

train_dataset = train_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
eval_dataset = eval_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
eval_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

# === STEP 5: Evaluation Function with Per-Symptom Metrics ===
def evaluate_model(model, loader, method_name=None, epoch=None):
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)
            labels = batch["label"].cpu().numpy()
            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(labels)

    acc = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="macro", zero_division=0)

    try:
        auroc = roc_auc_score(np.eye(len(symptom_list))[all_labels], np.array(all_probs), average="macro", multi_class="ovr")
    except Exception as e:
        print("AUROC error:", str(e))
        auroc = 0.0

    if method_name and epoch is not None:
        per_class_metrics = precision_recall_fscore_support(all_labels, all_preds, labels=list(range(len(symptom_list))), zero_division=0)
        try:
            per_class_auroc = [
                roc_auc_score((np.array(all_labels) == i).astype(int), np.array(all_probs)[:, i])
                for i in range(len(symptom_list))
            ]
        except:
            per_class_auroc = [0.0] * len(symptom_list)

        csv_file = f"per_symptom_{method_name}_metrics.csv"
        write_header = not os.path.exists(csv_file)
        with open(csv_file, mode="a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(["Epoch", "Symptom", "Precision", "Recall", "F1", "AUROC"])
            for i in range(len(symptom_list)):
                writer.writerow([
                    epoch,
                    id2symptom[i],
                    per_class_metrics[0][i],
                    per_class_metrics[1][i],
                    per_class_metrics[2][i],
                    per_class_auroc[i]
                ])

    return acc, precision, recall, f1, auroc

# === STEP 6: Training ===
def train_model(model, method_name, lr):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=0, num_training_steps=epochs * len(train_loader))
    csv_file = f"multiclass_{method_name}_metrics.csv"
    with open(csv_file, mode="w", newline="") as f:
        csv.writer(f).writerow(["Epoch", "Accuracy", "Precision", "Recall", "F1", "AUROC"])

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            loss = F.cross_entropy(outputs.logits, batch["label"].long())
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        acc, prec, rec, f1, auroc = evaluate_model(model, eval_loader, method_name, epoch)
        with open(csv_file, mode="a", newline="") as f:
            csv.writer(f).writerow([epoch, acc, prec, rec, f1, auroc])
        print(f"{method_name} | Epoch {epoch}/{epochs} - Loss: {total_loss / len(train_loader):.4f} - Acc: {acc:.4f} P: {prec:.4f} R: {rec:.4f} F1: {f1:.4f} AUROC: {auroc:.4f}")

# === STEP 7: Models ===
def base_model():
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(symptom_list)).to(device)
   # optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    train_model(model, "base",lr=2e-5) 

def pruning_model():
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(symptom_list)).to(device)
    #optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    for _, module in model.named_modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=0.3)
    train_model(model, "pruning", lr=2e-5)

def lowrank_model():
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(symptom_list)).to(device)

    # Apply low-rank SVD to only the classifier layer
    classifier = model.classifier
    weight = classifier.weight.data
    rank = min(32, weight.size(1) // 2)  # tunable rank

    try:
        # Compute low-rank SVD approximation
        u, s, v = torch.svd_lowrank(weight, q=rank)
        lowrank_weight = (u @ torch.diag(s) @ v.t()).to(weight.device)

        # Replace original weight with low-rank approximation
        classifier.weight.data.copy_(lowrank_weight)

        print(f"✅ Applied low-rank SVD to classifier with rank {rank}")
    except Exception as e:
        print("⚠️ SVD low-rank approximation failed:", e)

    # Use higher LR to adapt the modified layer
  #  optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    train_model(model, "lowrank",  lr=5e-5)


def quantization_model():
    from torch.quantization import get_default_qat_qconfig, prepare_qat, convert

    # Load and prepare model
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(symptom_list))
    model.train()
    model.qconfig = get_default_qat_qconfig("fbgemm")

    # Only QAT on encoder (to avoid issues with LayerNorm)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            module.qconfig = model.qconfig

    # Prepare QAT model
    model_prepared = prepare_qat(model)
    model_prepared.to(device)

    # Train QAT model
    train_model(model_prepared, "qat", lr=2e-5)

    # Convert to quantized version after training
    model_quantized = convert(model_prepared.eval().cpu())
    model_quantized.to(device)

    # Optional: Evaluate final quantized model
    acc, prec, rec, f1, auroc = evaluate_model(model_quantized, eval_loader, "qat-final", epoch="final")
    print(f"QAT Final — Acc: {acc:.4f}, F1: {f1:.4f}, AUROC: {auroc:.4f}")


def distillation_model():
    teacher = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(symptom_list)).to(device)
    teacher.eval()

    # Smaller student (or same if just testing)
    student = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(symptom_list)).to(device)

    optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)  # slightly higher LR for student
    scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=0, num_training_steps=epochs * len(train_loader))

    T = 4.0  # temperature
    alpha = 0.5  # balance CE and KD loss

    csv_file = "multiclass_distillation_metrics.csv"
    with open(csv_file, mode="w", newline="") as f:
        csv.writer(f).writerow(["Epoch", "Accuracy", "Precision", "Recall", "F1", "AUROC"])

    for epoch in range(1, epochs + 1):
        student.train()
        total_loss = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()

            with torch.no_grad():
                teacher_logits = teacher(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits

            student_logits = student(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits

            ce_loss = F.cross_entropy(student_logits, batch["label"].long())
            kd_loss = F.kl_div(
                F.log_softmax(student_logits / T, dim=-1),
                F.softmax(teacher_logits / T, dim=-1),
                reduction="batchmean"
            ) * (T * T)

            loss = alpha * ce_loss + (1 - alpha) * kd_loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        acc, prec, rec, f1, auroc = evaluate_model(student, eval_loader, "distillation", epoch)
        with open(csv_file, mode="a", newline="") as f:
            csv.writer(f).writerow([epoch, acc, prec, rec, f1, auroc])
        print(f"Distill | Epoch {epoch}/{epochs} - Loss: {total_loss/len(train_loader):.4f} "
              f"- Acc: {acc:.4f} P: {prec:.4f} R: {rec:.4f} F1: {f1:.4f} AUROC: {auroc:.4f}")

# === STEP 8: Run All Models ===
def run_all():
  #  base_model()
   # pruning_model()
    lowrank_model()
   # distillation_model()
   # quantization_model()

if __name__ == "__main__":
    run_all()


✅ Balanced dataset with 870 samples per class, total = 4350
Symptoms: ['edema', 'pain', 'cough', 'fever', 'bleeding']


Map:   0%|          | 0/3480 [00:00<?, ? examples/s]

Map:   0%|          | 0/870 [00:00<?, ? examples/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Applied low-rank SVD to classifier with rank 32


In [4]:
pip install transformers

Collecting transformers
  Downloading transformers-4.53.2-py3-none-any.whl.metadata (40 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Downloading tokenizers-0.21.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Downloading transformers-4.53.2-py3-none-any.whl (10.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m120.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.21.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m233.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)
Installing collected packages: safetensors, tokenizers, transformers
[2K   [90m━━━━━━━━

In [5]:
pip install datasets

Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-4.0.0-py3-none-any.whl (494 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Installing collected packages: xxhash, fsspec, dill, multiprocess, datasets
[2K  Attemptin