In [5]:

#Credential accessed of MIMIC III dataset
# Part of code was generated with OpenAI chatgpt
#Symptom Extraction from Clinical dataset


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import re
import csv
from collections import Counter
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from torch.utils.data import DataLoader
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score

# === CONFIG ===
model_name = "emilyalsentzer/Bio_ClinicalBERT"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
max_length = 256
epochs = 10
top_K = 5 # top symptoms to select from subset

# === STEP 1: Full Symptom Lexicon ===
symptom_lexicon_full = [
    "fever", "cough", "headache", "nausea", "vomiting", "fatigue", "chest pain", "shortness of breath",
    "abdominal pain", "dizziness", "diarrhea", "constipation", "joint pain", "back pain", "depression", "anxiety",
    "rash", "itching", "seizure", "confusion", "palpitations", "insomnia", "loss of appetite", "urinary frequency",
    "chills", "syncope", "sore throat", "swelling", "pain", "malaise", "cramps", "numbness", "tingling",
    "blurry vision", "weakness", "edema", "hallucinations", "bleeding", "difficulty breathing", "burning"
]

# === STEP 2: Load Notes & Sample 10% ===
notes_full = pd.read_csv(
    "NOTEEVENTS_random.csv",
    usecols=["TEXT"],
    quoting=csv.QUOTE_NONE,
    on_bad_lines="skip"
).dropna()

notes = notes_full.sample(frac=0.1, random_state=0).copy() # adjust subsample percentage
notes["TEXT"] = notes["TEXT"].str.lower().str.slice(0, 1000)

# === STEP 2b: Extract symptoms ===
def extract_label(text):
    """Return first matching symptom (single-label)."""
    matches = [s for s in symptom_lexicon_full if re.search(rf"\b{re.escape(s)}\b", text)]
    return matches[0] if matches else None

notes["symptom"] = notes["TEXT"].apply(extract_label)
notes = notes.dropna(subset=["symptom"])

# === STEP 3: Select Top-K Symptoms from 10% subset ===
top_symptoms = [s for s, _ in Counter(notes["symptom"]).most_common(top_K)]
if not top_symptoms:
    raise ValueError("❌ No symptoms matched the lexicon in the dataset.")
symptom_lexicon = top_symptoms
print("✅ Top symptoms from 10% subset:", symptom_lexicon)

# Optional: show counts of top symptoms
top_counts = Counter(notes["symptom"])
for s in symptom_lexicon:
    print(f"  {s}: {top_counts[s]} occurrences")

notes = notes[notes["symptom"].isin(symptom_lexicon)]
symptom2id = {s: i for i, s in enumerate(symptom_lexicon)}
id2symptom = {i: s for s, i in symptom2id.items()}
notes["label"] = notes["symptom"].map(symptom2id)

# === STEP 4: Balance Dataset ===
balanced = []
counts = [sum(notes["label"] == i) for i in range(len(symptom_lexicon))]
min_count = min(counts)
print("Class distribution before balancing:", counts)

for i in range(len(symptom_lexicon)):
    subset = notes[notes["label"] == i].sample(n=min_count, random_state=42)
    balanced.append(subset)

balanced_df = pd.concat(balanced).reset_index(drop=True)
balanced_df = balanced_df.rename(columns={"TEXT": "text"})
print(f"✅ Balanced dataset with {min_count} samples per class, total = {len(balanced_df)}")

# === STEP 5: Prepare HuggingFace Datasets ===
train_df, test_df = train_test_split(balanced_df, test_size=0.1, random_state=42)
train_dataset = Dataset.from_pandas(train_df.reset_index(drop=True))
eval_dataset = Dataset.from_pandas(test_df.reset_index(drop=True))

tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_fn(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

train_dataset = train_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
eval_dataset = eval_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
eval_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

# === STEP 6: Evaluation Function ===
def evaluate_model(model, loader, method_name=None, epoch=None):
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)
            labels = batch["label"].cpu().numpy()
            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(labels)

    acc = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="macro", zero_division=0)
    try:
        auroc = roc_auc_score(np.eye(len(symptom_lexicon))[all_labels], np.array(all_probs), average="macro", multi_class="ovr")
    except:
        auroc = 0.0

    if method_name and epoch is not None:
        csv_file = f"{method_name}_metrics_Limited_Data_Scenario.csv"
        write_header = not os.path.exists(csv_file)
        with open(csv_file, mode="a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(["Epoch", "Accuracy", "Precision", "Recall", "F1", "AUROC"])
            writer.writerow([epoch, acc, precision, recall, f1, auroc])

        per_class_metrics = precision_recall_fscore_support(all_labels, all_preds, labels=list(range(len(symptom_lexicon))), zero_division=0)
        try:
            per_class_auroc = [roc_auc_score((np.array(all_labels) == i).astype(int), np.array(all_probs)[:, i]) for i in range(len(symptom_lexicon))]
        except:
            per_class_auroc = [0.0] * len(symptom_lexicon)

        per_symptom_file = f"{method_name}_per_symptom_metrics_Limited_Data_Scenario.csv"
        write_header_ps = not os.path.exists(per_symptom_file)
        with open(per_symptom_file, mode="a", newline="") as f:
            writer = csv.writer(f)
            if write_header_ps:
                writer.writerow(["Epoch", "Symptom", "Precision", "Recall", "F1", "AUROC"])
            for i in range(len(symptom_lexicon)):
                writer.writerow([
                    epoch,
                    id2symptom[i],
                    per_class_metrics[0][i],
                    per_class_metrics[1][i],
                    per_class_metrics[2][i],
                    per_class_auroc[i]
                ])
    return acc, precision, recall, f1, auroc

# === STEP 7: Distillation Training with Activation Switching ===
#class GReLU(nn.Module):
#    """Convex Generalized ReLU"""
#    def __init__(self):
#        super().__init__()
#        self.a = nn.Parameter(torch.tensor(0.1))

 #   def forward(self, x):
#        return torch.maximum(x, self.a * x)

class GatedReLU(nn.Module):
    """True Gated ReLU: ReLU(x) * sigmoid(alpha * x)"""
    def __init__(self, alpha_init=1.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(alpha_init))

    def forward(self, x):
        gate = torch.sigmoid(self.alpha * x)
        return F.relu(x) * gate

def replace_activation(model, activation="relu"):
    for name, module in model.named_children():
        if isinstance(module, nn.ReLU):
            if activation == "relu":
                setattr(model, name, nn.ReLU())
            elif activation == "grelu":
                setattr(model, name, GatedReLU())
        else:
            replace_activation(module, activation)
    return model

def distillation_train(strategy="convex"):
    #del teacher, student, optimizer, scheduler
    import gc, torch
    gc.collect()
    torch.cuda.empty_cache()


    num_labels = len(symptom_lexicon)

    teacher = AutoModelForSequenceClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=num_labels).to(device)

    teacher.eval()

    student = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(symptom_lexicon)).to(device)

    optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)
    scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=0, num_training_steps=epochs * len(train_loader))

    T = 4.0
    for epoch in range(1, epochs + 1):
        if strategy == "convex":
            student = replace_activation(student, "grelu")
            act_stage = "Convex (GReLU)"
        elif strategy == "nonconvex":
            student = replace_activation(student, "relu")
            act_stage = "Nonconvex (ReLU)"
        elif strategy == "multistage":
            if epoch <= 2:
                student = replace_activation(student, "grelu")
                act_stage = "Convex (GReLU)"
            else:
                student = replace_activation(student, "relu")
                act_stage = "Nonconvex (ReLU)"
                optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)
                scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=0, num_training_steps=epochs * len(train_loader))
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        student.train()
        total_loss = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_logits = teacher(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            student_logits = student(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits

            ce_loss = F.cross_entropy(student_logits, batch["label"].long())
            kd_loss = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                               F.softmax(teacher_logits / T, dim=-1),
                               reduction="batchmean") * (T*T)

            loss = 0.5 * ce_loss + 0.5 * kd_loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        acc, prec, rec, f1, auroc = evaluate_model(student, eval_loader, strategy, epoch)
        print(f"{strategy} | Epoch {epoch}/{epochs} [{act_stage}] - "
              f"Loss: {total_loss/len(train_loader):.4f} "
              f"- Acc: {acc:.4f} P: {prec:.4f} R: {rec:.4f} F1: {f1:.4f} AUROC: {auroc:.4f}")




# === STEP 7: Run All Three Strategies ===
def run_all():
    distillation_train("convex")
    distillation_train("nonconvex")
    distillation_train("multistage")



if __name__ == "__main__":
    run_all()




✅ Top symptoms from 10% subset: ['pain', 'fever', 'cough', 'seizure', 'confusion']
  pain: 195 occurrences
  fever: 86 occurrences
  cough: 46 occurrences
  seizure: 38 occurrences
  confusion: 32 occurrences
Class distribution before balancing: [195, 86, 46, 38, 32]
✅ Balanced dataset with 32 samples per class, total = 160


Map:   0%|          | 0/144 [00:00<?, ? examples/s]

Map:   0%|          | 0/16 [00:00<?, ? examples/s]

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
classifier.bias                            | MISSING    | 
classifier.weight                          | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
classifier.bias                            | MISSING    | 
classifier.weight                          | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

convex | Epoch 1/10 [Convex (GReLU)] - Loss: 0.8878 - Acc: 0.3125 P: 0.1333 R: 0.2267 F1: 0.1608 AUROC: 0.7857
convex | Epoch 2/10 [Convex (GReLU)] - Loss: 0.7981 - Acc: 0.3750 P: 0.3467 R: 0.3533 F1: 0.3190 AUROC: 0.7736
convex | Epoch 3/10 [Convex (GReLU)] - Loss: 0.7203 - Acc: 0.4375 P: 0.4500 R: 0.4800 F1: 0.3943 AUROC: 0.8318
convex | Epoch 4/10 [Convex (GReLU)] - Loss: 0.6536 - Acc: 0.4375 P: 0.4500 R: 0.4800 F1: 0.3943 AUROC: 0.8613
convex | Epoch 5/10 [Convex (GReLU)] - Loss: 0.6156 - Acc: 0.5000 P: 0.5933 R: 0.5300 F1: 0.4466 AUROC: 0.8902
convex | Epoch 6/10 [Convex (GReLU)] - Loss: 0.5861 - Acc: 0.5000 P: 0.5933 R: 0.5300 F1: 0.4466 AUROC: 0.9082
convex | Epoch 7/10 [Convex (GReLU)] - Loss: 0.5646 - Acc: 0.5000 P: 0.5933 R: 0.5300 F1: 0.4466 AUROC: 0.9154
convex | Epoch 8/10 [Convex (GReLU)] - Loss: 0.5388 - Acc: 0.5625 P: 0.6600 R: 0.5800 F1: 0.5399 AUROC: 0.9154
convex | Epoch 9/10 [Convex (GReLU)] - Loss: 0.5278 - Acc: 0.5625 P: 0.6600 R: 0.5800 F1: 0.5399 AUROC: 0.9154
c

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
classifier.bias                            | MISSING    | 
classifier.weight                          | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
classifier.bias                            | MISSING    | 
classifier.weight                          | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

nonconvex | Epoch 1/10 [Nonconvex (ReLU)] - Loss: 0.9388 - Acc: 0.2500 P: 0.1071 R: 0.2333 F1: 0.1455 AUROC: 0.6040
nonconvex | Epoch 2/10 [Nonconvex (ReLU)] - Loss: 0.7870 - Acc: 0.3750 P: 0.3556 R: 0.3100 F1: 0.2743 AUROC: 0.7071
nonconvex | Epoch 3/10 [Nonconvex (ReLU)] - Loss: 0.6698 - Acc: 0.5625 P: 0.4242 R: 0.5000 F1: 0.4183 AUROC: 0.7489
nonconvex | Epoch 4/10 [Nonconvex (ReLU)] - Loss: 0.6478 - Acc: 0.6250 P: 0.5444 R: 0.6000 F1: 0.5362 AUROC: 0.7895
nonconvex | Epoch 5/10 [Nonconvex (ReLU)] - Loss: 0.5986 - Acc: 0.6250 P: 0.6000 R: 0.6000 F1: 0.5667 AUROC: 0.7886
nonconvex | Epoch 6/10 [Nonconvex (ReLU)] - Loss: 0.5682 - Acc: 0.5625 P: 0.5889 R: 0.5600 F1: 0.5476 AUROC: 0.8020
nonconvex | Epoch 7/10 [Nonconvex (ReLU)] - Loss: 0.5409 - Acc: 0.5625 P: 0.5889 R: 0.5600 F1: 0.5476 AUROC: 0.8056
nonconvex | Epoch 8/10 [Nonconvex (ReLU)] - Loss: 0.5396 - Acc: 0.5625 P: 0.5889 R: 0.5600 F1: 0.5476 AUROC: 0.8093
nonconvex | Epoch 9/10 [Nonconvex (ReLU)] - Loss: 0.5379 - Acc: 0.5625 P

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
classifier.bias                            | MISSING    | 
classifier.weight                          | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
classifier.bias                            | MISSING    | 
classifier.weight                          | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

multistage | Epoch 1/10 [Convex (GReLU)] - Loss: 0.8334 - Acc: 0.1875 P: 0.1500 R: 0.3000 F1: 0.1800 AUROC: 0.7210
multistage | Epoch 2/10 [Convex (GReLU)] - Loss: 0.7073 - Acc: 0.3125 P: 0.2244 R: 0.3800 F1: 0.2714 AUROC: 0.7760
multistage | Epoch 3/10 [Nonconvex (ReLU)] - Loss: 0.6585 - Acc: 0.6250 P: 0.6333 R: 0.6000 F1: 0.5810 AUROC: 0.8140
multistage | Epoch 4/10 [Nonconvex (ReLU)] - Loss: 0.5817 - Acc: 0.6875 P: 0.8000 R: 0.6767 F1: 0.7078 AUROC: 0.8908
multistage | Epoch 5/10 [Nonconvex (ReLU)] - Loss: 0.5255 - Acc: 0.7500 P: 0.8167 R: 0.7433 F1: 0.7621 AUROC: 0.8713
multistage | Epoch 6/10 [Nonconvex (ReLU)] - Loss: 0.4741 - Acc: 0.8750 P: 0.9429 R: 0.8500 F1: 0.8714 AUROC: 0.9319
multistage | Epoch 7/10 [Nonconvex (ReLU)] - Loss: 0.4400 - Acc: 0.8750 P: 0.9429 R: 0.8500 F1: 0.8714 AUROC: 0.9570
multistage | Epoch 8/10 [Nonconvex (ReLU)] - Loss: 0.4441 - Acc: 0.9375 P: 0.9667 R: 0.9000 F1: 0.9152 AUROC: 0.9426
multistage | Epoch 9/10 [Nonconvex (ReLU)] - Loss: 0.4422 - Acc: 0.8

In [None]:
pip install torch --upgrade

Collecting torch
  Downloading torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (31 kB)
Collecting cuda-bindings==12.9.4 (from torch)
  Downloading cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (2.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.8.93 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-runtime-cu12==12.8.90 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-cupti-cu12==12.8.90 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cublas-cu12==12.8.4.1 (from torch)
  Downloading nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cufft-cu12==11.3.3.83 (from torch)
  Downloadin