In [3]:

#Credential accessed of MIMIC III dataset
# Part of code was generated with OpenAI chatgpt
#Symptom Extraction from Clinical dataset


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import re
import csv
import os
import time   # ← ADD THIS LINE
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from torch.utils.data import DataLoader
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score

# === CONFIG ===
model_name = "emilyalsentzer/Bio_ClinicalBERT"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
max_length = 256
epochs = 10
top_K = 5 # top symptoms to select from subset

# === STEP 1: Full Symptom Lexicon ===
symptom_lexicon_full = [
    "fever", "cough", "headache", "nausea", "vomiting", "fatigue", "chest pain", "shortness of breath",
    "abdominal pain", "dizziness", "diarrhea", "constipation", "joint pain", "back pain", "depression", "anxiety",
    "rash", "itching", "seizure", "confusion", "palpitations", "insomnia", "loss of appetite", "urinary frequency",
    "chills", "syncope", "sore throat", "swelling", "pain", "malaise", "cramps", "numbness", "tingling",
    "blurry vision", "weakness", "edema", "hallucinations", "bleeding", "difficulty breathing", "burning"
]

# === STEP 2: Load Notes & Sample 10% ===
notes_full = pd.read_csv(
    "NOTEEVENTS_random.csv",
    usecols=["TEXT"],
    quoting=csv.QUOTE_NONE,
    on_bad_lines="skip"
).dropna()

notes = notes_full.sample(frac=0.2, random_state=0).copy() # adjust subsample percentage
notes["TEXT"] = notes["TEXT"].str.lower().str.slice(0, 1000)

# === STEP 2b: Extract symptoms ===
def extract_label(text):
    """Return first matching symptom (single-label)."""
    matches = [s for s in symptom_lexicon_full if re.search(rf"\b{re.escape(s)}\b", text)]
    return matches[0] if matches else None

notes["symptom"] = notes["TEXT"].apply(extract_label)
notes = notes.dropna(subset=["symptom"])

# === STEP 3: Select Top-K Symptoms from 10% subset ===
top_symptoms = [s for s, _ in Counter(notes["symptom"]).most_common(top_K)]
if not top_symptoms:
    raise ValueError("❌ No symptoms matched the lexicon in the dataset.")
symptom_lexicon = top_symptoms
print("✅ Top symptoms from 10% subset:", symptom_lexicon)

# Optional: show counts of top symptoms
top_counts = Counter(notes["symptom"])
for s in symptom_lexicon:
    print(f"  {s}: {top_counts[s]} occurrences")

notes = notes[notes["symptom"].isin(symptom_lexicon)]
symptom2id = {s: i for i, s in enumerate(symptom_lexicon)}
id2symptom = {i: s for s, i in symptom2id.items()}
notes["label"] = notes["symptom"].map(symptom2id)

# === STEP 4: Balance Dataset ===
balanced = []
counts = [sum(notes["label"] == i) for i in range(len(symptom_lexicon))]
min_count = min(counts)
print("Class distribution before balancing:", counts)

for i in range(len(symptom_lexicon)):
    subset = notes[notes["label"] == i].sample(n=min_count, random_state=42)
    balanced.append(subset)

balanced_df = pd.concat(balanced).reset_index(drop=True)
balanced_df = balanced_df.rename(columns={"TEXT": "text"})
print(f"✅ Balanced dataset with {min_count} samples per class, total = {len(balanced_df)}")

# === STEP 5: Prepare HuggingFace Datasets ===
train_df, test_df = train_test_split(balanced_df, test_size=0.1, random_state=42)
train_dataset = Dataset.from_pandas(train_df.reset_index(drop=True))
eval_dataset = Dataset.from_pandas(test_df.reset_index(drop=True))

tokenizer = AutoTokenizer.from_pretrained(model_name)
def tokenize_fn(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

train_dataset = train_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
eval_dataset = eval_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
eval_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)

# === STEP 6: Evaluation Function ===
def evaluate_model(model, loader, method_name=None, epoch=None):
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)
            labels = batch["label"].cpu().numpy()
            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(labels)

    acc = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="macro", zero_division=0)
    try:
        auroc = roc_auc_score(np.eye(len(symptom_lexicon))[all_labels], np.array(all_probs), average="macro", multi_class="ovr")
    except:
        auroc = 0.0

    if method_name and epoch is not None:
        csv_file = f"{method_name}_metrics_More_Data_Scenario.csv"
        write_header = not os.path.exists(csv_file)
        with open(csv_file, mode="a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(["Epoch", "Accuracy", "Precision", "Recall", "F1", "AUROC"])
            writer.writerow([epoch, acc, precision, recall, f1, auroc])

        per_class_metrics = precision_recall_fscore_support(all_labels, all_preds, labels=list(range(len(symptom_lexicon))), zero_division=0)
        try:
            per_class_auroc = [roc_auc_score((np.array(all_labels) == i).astype(int), np.array(all_probs)[:, i]) for i in range(len(symptom_lexicon))]
        except:
            per_class_auroc = [0.0] * len(symptom_lexicon)

        per_symptom_file = f"{method_name}_per_symptom_metrics_More_Data_Scenario.csv"
        write_header_ps = not os.path.exists(per_symptom_file)
        with open(per_symptom_file, mode="a", newline="") as f:
            writer = csv.writer(f)
            if write_header_ps:
                writer.writerow(["Epoch", "Symptom", "Precision", "Recall", "F1", "AUROC"])
            for i in range(len(symptom_lexicon)):
                writer.writerow([
                    epoch,
                    id2symptom[i],
                    per_class_metrics[0][i],
                    per_class_metrics[1][i],
                    per_class_metrics[2][i],
                    per_class_auroc[i]
                ])
    return acc, precision, recall, f1, auroc

# === STEP 7: Distillation Training with Activation Switching ===
#class GReLU(nn.Module):
#    """Convex Generalized ReLU"""
#    def __init__(self):
#        super().__init__()
#        self.a = nn.Parameter(torch.tensor(0.1))

 #   def forward(self, x):
#        return torch.maximum(x, self.a * x)

class GatedReLU(nn.Module):
    """True Gated ReLU: ReLU(x) * sigmoid(alpha * x)"""
    def __init__(self, alpha_init=1.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(alpha_init))

    def forward(self, x):
        gate = torch.sigmoid(self.alpha * x)
        return F.relu(x) * gate

def replace_activation(model, activation="relu"):
    for name, module in model.named_children():
        if isinstance(module, nn.ReLU):
            if activation == "relu":
                setattr(model, name, nn.ReLU())
            elif activation == "grelu":
                setattr(model, name, GatedReLU())
        else:
            replace_activation(module, activation)
    return model


def distillation_train(strategy="convex"):
    import gc, torch
    gc.collect()
    torch.cuda.empty_cache()

    # === COMPUTE BUDGET INIT ===
    start_total = time.time()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    num_labels = len(symptom_lexicon)

    teacher = AutoModelForSequenceClassification.from_pretrained(
        "emilyalsentzer/Bio_ClinicalBERT",
        num_labels=num_labels
    ).to(device)
    teacher.eval()

    student = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    ).to(device)

    # (optional) param count
    trainable_params = sum(p.numel() for p in student.parameters() if p.requires_grad)

    optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)
    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=epochs * len(train_loader)
    )

    T = 4.0

    for epoch in range(1, epochs + 1):
        # === EPOCH TIMER START ===
        start_epoch = time.time()

        # ---- choose activation regime ----
        if strategy == "convex":
            student = replace_activation(student, "grelu")
            act_stage = "Convex (GReLU)"
        elif strategy == "nonconvex":
            student = replace_activation(student, "relu")
            act_stage = "Nonconvex (ReLU)"
        elif strategy == "multistage":
            if epoch <= 2:
                student = replace_activation(student, "grelu")
                act_stage = "Convex (GReLU)"
            else:
                student = replace_activation(student, "relu")
                act_stage = "Nonconvex (ReLU)"
                optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)
                scheduler = get_scheduler(
                    "cosine",
                    optimizer=optimizer,
                    num_warmup_steps=0,
                    num_training_steps=epochs * len(train_loader)
                )
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        # ---- training ----
        student.train()
        total_loss = 0.0

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad(set_to_none=True)

            with torch.no_grad():
                teacher_logits = teacher(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"]
                ).logits

            student_logits = student(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"]
            ).logits

            ce_loss = F.cross_entropy(student_logits, batch["label"].long())
            kd_loss = F.kl_div(
                F.log_softmax(student_logits / T, dim=-1),
                F.softmax(teacher_logits / T, dim=-1),
                reduction="batchmean"
            ) * (T * T)

            loss = 0.5 * ce_loss + 0.5 * kd_loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        # === EPOCH TIMER END ===
        epoch_time = time.time() - start_epoch

        # ---- evaluation ----
        acc, prec, rec, f1, auroc = evaluate_model(student, eval_loader, strategy, epoch)

        print(
            f"{strategy} | Epoch {epoch}/{epochs} [{act_stage}] "
            f"| Time: {epoch_time:.2f}s "
            f"| Loss: {total_loss/len(train_loader):.4f} "
            f"| Acc: {acc:.4f} "
            f"| Prec: {prec:.4f} "
            f"| Rec: {rec:.4f} "
            f"| F1: {f1:.4f} "
            f"| AUROC: {auroc:.4f}"
        )

    # === FINAL BUDGET SUMMARY (after all epochs) ===
    total_time = time.time() - start_total
    peak_mem_gb = (torch.cuda.max_memory_allocated() / 1024**3) if torch.cuda.is_available() else 0.0

    print(
        f"✅ Budget ({strategy}): TotalTime={total_time:.2f}s | "
        f"PeakVRAM={peak_mem_gb:.2f} GB | Params={trainable_params/1e6:.2f}M"
    )

    # optional: save budget row
    with open("computational_budget.csv", "a", newline="") as f:
        writer = csv.writer(f)
        if f.tell() == 0:
            writer.writerow(["Strategy", "TotalTime_sec", "PeakMemory_GB", "Epochs", "BatchSize", "TrainableParams"])
        writer.writerow([strategy, total_time, peak_mem_gb, epochs, batch_size, trainable_params])





# === STEP 7: Run All Three Strategies ===
def run_all():
    distillation_train("convex")
    distillation_train("nonconvex")
    distillation_train("multistage")



if __name__ == "__main__":
    run_all()




✅ Top symptoms from 10% subset: ['pain', 'fever', 'cough', 'seizure', 'confusion']
  pain: 351 occurrences
  fever: 152 occurrences
  cough: 86 occurrences
  seizure: 86 occurrences
  confusion: 70 occurrences
Class distribution before balancing: [351, 152, 86, 86, 70]
✅ Balanced dataset with 70 samples per class, total = 350


Map:   0%|          | 0/315 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
classifier.weight                          | MISSING    | 
classifier.bias                            | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
classifier.weight                          | MISSING    | 
classifier.bias                            | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

convex | Epoch 1/10 [Convex (GReLU)] | Time: 16.14s | Loss: 0.7642 | Acc: 0.5714 | Prec: 0.7274 | Rec: 0.5690 | F1: 0.5672 | AUROC: 0.8880
convex | Epoch 2/10 [Convex (GReLU)] | Time: 16.68s | Loss: 0.6378 | Acc: 0.7429 | Prec: 0.8648 | Rec: 0.7298 | F1: 0.7482 | AUROC: 0.9506
convex | Epoch 3/10 [Convex (GReLU)] | Time: 16.16s | Loss: 0.5314 | Acc: 0.8286 | Prec: 0.8945 | Rec: 0.8262 | F1: 0.8427 | AUROC: 0.9981
convex | Epoch 4/10 [Convex (GReLU)] | Time: 15.76s | Loss: 0.4686 | Acc: 0.9143 | Prec: 0.9455 | Rec: 0.9214 | F1: 0.9245 | AUROC: 0.9990
convex | Epoch 5/10 [Convex (GReLU)] | Time: 15.73s | Loss: 0.4413 | Acc: 0.9429 | Prec: 0.9492 | Rec: 0.9429 | F1: 0.9395 | AUROC: 1.0000
convex | Epoch 6/10 [Convex (GReLU)] | Time: 15.91s | Loss: 0.4269 | Acc: 0.9714 | Prec: 0.9714 | Rec: 0.9714 | F1: 0.9692 | AUROC: 1.0000
convex | Epoch 7/10 [Convex (GReLU)] | Time: 16.05s | Loss: 0.4189 | Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000 | AUROC: 1.0000
convex | Epoch 8/10 [Convex

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
classifier.weight                          | MISSING    | 
classifier.bias                            | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
classifier.weight                          | MISSING    | 
classifier.bias                            | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

nonconvex | Epoch 1/10 [Nonconvex (ReLU)] | Time: 15.90s | Loss: 0.8291 | Acc: 0.3143 | Prec: 0.3130 | Rec: 0.3000 | F1: 0.2812 | AUROC: 0.7145
nonconvex | Epoch 2/10 [Nonconvex (ReLU)] | Time: 16.06s | Loss: 0.7216 | Acc: 0.5429 | Prec: 0.4164 | Rec: 0.5250 | F1: 0.4607 | AUROC: 0.8030
nonconvex | Epoch 3/10 [Nonconvex (ReLU)] | Time: 16.10s | Loss: 0.6269 | Acc: 0.6571 | Prec: 0.7242 | Rec: 0.6393 | F1: 0.6507 | AUROC: 0.8888
nonconvex | Epoch 4/10 [Nonconvex (ReLU)] | Time: 16.04s | Loss: 0.5328 | Acc: 0.7429 | Prec: 0.8024 | Rec: 0.7345 | F1: 0.7522 | AUROC: 0.9436
nonconvex | Epoch 5/10 [Nonconvex (ReLU)] | Time: 15.92s | Loss: 0.4706 | Acc: 0.8000 | Prec: 0.8367 | Rec: 0.7845 | F1: 0.7976 | AUROC: 0.9728
nonconvex | Epoch 6/10 [Nonconvex (ReLU)] | Time: 15.91s | Loss: 0.4403 | Acc: 0.8571 | Prec: 0.8733 | Rec: 0.8512 | F1: 0.8556 | AUROC: 0.9772
nonconvex | Epoch 7/10 [Nonconvex (ReLU)] | Time: 15.98s | Loss: 0.4302 | Acc: 0.8571 | Prec: 0.8733 | Rec: 0.8512 | F1: 0.8556 | AUROC:

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
classifier.weight                          | MISSING    | 
classifier.bias                            | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

Loading weights:   0%|          | 0/199 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: emilyalsentzer/Bio_ClinicalBERT
Key                                        | Status     | 
-------------------------------------------+------------+-
cls.predictions.decoder.weight             | UNEXPECTED | 
cls.seq_relationship.weight                | UNEXPECTED | 
cls.predictions.transform.LayerNorm.bias   | UNEXPECTED | 
cls.predictions.bias                       | UNEXPECTED | 
cls.predictions.transform.dense.weight     | UNEXPECTED | 
cls.predictions.transform.LayerNorm.weight | UNEXPECTED | 
cls.seq_relationship.bias                  | UNEXPECTED | 
cls.predictions.transform.dense.bias       | UNEXPECTED | 
classifier.weight                          | MISSING    | 
classifier.bias                            | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Conside

multistage | Epoch 1/10 [Convex (GReLU)] | Time: 15.95s | Loss: 0.8068 | Acc: 0.5429 | Prec: 0.4635 | Rec: 0.5714 | F1: 0.4952 | AUROC: 0.9184
multistage | Epoch 2/10 [Convex (GReLU)] | Time: 16.06s | Loss: 0.6772 | Acc: 0.6857 | Prec: 0.5750 | Rec: 0.7131 | F1: 0.6280 | AUROC: 0.9600
multistage | Epoch 3/10 [Nonconvex (ReLU)] | Time: 16.06s | Loss: 0.5638 | Acc: 0.9714 | Prec: 0.9778 | Rec: 0.9714 | F1: 0.9729 | AUROC: 0.9920
multistage | Epoch 4/10 [Nonconvex (ReLU)] | Time: 15.98s | Loss: 0.4825 | Acc: 0.9714 | Prec: 0.9714 | Rec: 0.9714 | F1: 0.9692 | AUROC: 0.9958
multistage | Epoch 5/10 [Nonconvex (ReLU)] | Time: 15.93s | Loss: 0.4531 | Acc: 0.9714 | Prec: 0.9714 | Rec: 0.9714 | F1: 0.9692 | AUROC: 1.0000
multistage | Epoch 6/10 [Nonconvex (ReLU)] | Time: 15.93s | Loss: 0.4392 | Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000 | AUROC: 1.0000
multistage | Epoch 7/10 [Nonconvex (ReLU)] | Time: 15.95s | Loss: 0.4375 | Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000 | AUR

In [2]:
#computational time of the algorithms