In [1]:
# Standard Libraries 
import os
import random
import pickle
from typing import Dict, List

# Third-Party Libraries 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Evaluation (sklearn) 
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
data = np.load("mimic_embed_data.npz")
train_embeddings = data["train_embeddings"]
train_labels = data["train_labels"]
valid_embeddings = data["valid_embeddings"]
valid_labels = data["valid_labels"]
test_embeddings = data["test_embeddings"]
test_labels = data["test_labels"]

In [3]:
class SingleLabelDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        return {
            "embedding": self.embeddings[idx],
            "lab": self.labels[idx].unsqueeze(0)  #  shape: (1,)
        }

In [4]:
class TabTransformer(nn.Module):
    def __init__(self, input_dim=1376, hidden_dim=128, output_dim=1, nhead=8, nlayers=4, dropout=0.1):
        super(TabTransformer, self).__init__()

        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)

        encoder_layer = self._build_encoder_layer(hidden_dim, nhead, dropout)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=nlayers)

        self.classifier = nn.Linear(hidden_dim, output_dim)

    def _build_encoder_layer(self, hidden_dim, nhead, dropout):
        return nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 8,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )

    def forward(self, x):
        x = self.input_proj(x)
        x = self.norm(x)
        x = x.unsqueeze(1)
        x = self.encoder(x)
        x = x.squeeze(1)
        return self.classifier(x)

In [5]:
class MaskedFocalLoss(nn.Module):
    def __init__(self, gamma=4):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, labels, mask):
        probs = torch.sigmoid(logits)
        loss = - labels * (1 - probs) ** self.gamma * torch.log(probs + 1e-8) \
               - (1 - labels) * probs ** self.gamma * torch.log(1 - probs + 1e-8)
        return (loss * mask).sum() / mask.sum()

In [6]:
def train_single_label_model(train_loader, val_loader, label_name, save_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TabTransformer().to(device)
    criterion = criterion = MaskedFocalLoss(gamma=4) 
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    for epoch in range(15):
        model.train()
        total_loss = 0

        for batch in train_loader:
            x = batch["embedding"].to(device)
            y = batch["lab"].to(device)
            mask = torch.ones_like(y)

            logits = model(x)
            loss = criterion(logits, y, mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"[{label_name}] Epoch {epoch+1} - Training loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"Model for {label_name} saved to {save_path}")

    return model

In [7]:
label_names =  [
    "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity", "Lung Lesion",
    "Edema", "Consolidation", "Pneumonia", "Atelectasis",
    "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices"
]

def get_single_label_data(embeddings, labels, label_idx):
    mask = ~np.isnan(labels[:, label_idx])
    return embeddings[mask], labels[mask, label_idx]

for i, label_name in enumerate(label_names):
    print(f"\n Training model for: {label_name}")

    train_X, train_y = get_single_label_data(train_embeddings, train_labels, i)
    val_X, val_y = get_single_label_data(valid_embeddings, valid_labels, i)

    train_ds = SingleLabelDataset(train_X, train_y)
    val_ds = SingleLabelDataset(val_X, val_y)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

    model_path = f"transformer_label_{i}_{label_name}.pt"
    train_single_label_model(train_loader, val_loader, label_name, model_path)


 Training model for: Enlarged Cardiomediastinum
[Enlarged Cardiomediastinum] Epoch 1 - Training loss: 0.0192
[Enlarged Cardiomediastinum] Epoch 2 - Training loss: 0.0169
[Enlarged Cardiomediastinum] Epoch 3 - Training loss: 0.0166
[Enlarged Cardiomediastinum] Epoch 4 - Training loss: 0.0164
[Enlarged Cardiomediastinum] Epoch 5 - Training loss: 0.0164
[Enlarged Cardiomediastinum] Epoch 6 - Training loss: 0.0162
[Enlarged Cardiomediastinum] Epoch 7 - Training loss: 0.0162
[Enlarged Cardiomediastinum] Epoch 8 - Training loss: 0.0160
[Enlarged Cardiomediastinum] Epoch 9 - Training loss: 0.0159
[Enlarged Cardiomediastinum] Epoch 10 - Training loss: 0.0158
[Enlarged Cardiomediastinum] Epoch 11 - Training loss: 0.0158
[Enlarged Cardiomediastinum] Epoch 12 - Training loss: 0.0158
[Enlarged Cardiomediastinum] Epoch 13 - Training loss: 0.0158
[Enlarged Cardiomediastinum] Epoch 14 - Training loss: 0.0157
[Enlarged Cardiomediastinum] Epoch 15 - Training loss: 0.0156
Model for Enlarged Cardiomedia

In [8]:
def evaluate_model(model, dataloader, device):
    model.eval()
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for batch in dataloader:
            x = batch["embedding"].to(device)
            y = batch["lab"].to(device).squeeze()  # shape: (batch_size,)

            logits = model(x).squeeze()  # shape: (batch_size,)
            probs = torch.sigmoid(logits)

            all_probs.extend(probs.cpu().numpy())
            all_targets.extend(y.cpu().numpy())

    try:
        auc = roc_auc_score(all_targets, all_probs)
    except ValueError:
        auc = float('nan')

    return auc

In [9]:
for i, label_name in enumerate(label_names):
    print(f"\n Testing model for: {label_name}")

    # 1. Prepare test data for the current label
    test_X, test_y = get_single_label_data(test_embeddings, test_labels, i)
    test_ds = SingleLabelDataset(test_X, test_y)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # 2. Rebuild the model architecture and load saved weights
    model = TabTransformer(input_dim=1376, hidden_dim=128, output_dim=1)
    model.load_state_dict(torch.load(f"transformer_label_{i}_{label_name}.pt"))
    model = model.to(device)

    # 3. Evaluate on the test set
    test_auc = evaluate_model(model, test_loader, device)
    print(f"[{label_name}] Test AUC: {test_auc:.4f}")


 Testing model for: Enlarged Cardiomediastinum
[Enlarged Cardiomediastinum] Test AUC: 0.8995

 Testing model for: Cardiomegaly
[Cardiomegaly] Test AUC: 0.9149

 Testing model for: Lung Opacity
[Lung Opacity] Test AUC: 0.9126

 Testing model for: Lung Lesion
[Lung Lesion] Test AUC: 0.9628

 Testing model for: Edema
[Edema] Test AUC: 0.7998

 Testing model for: Consolidation
[Consolidation] Test AUC: 0.7968

 Testing model for: Pneumonia
[Pneumonia] Test AUC: 0.8937

 Testing model for: Atelectasis
[Atelectasis] Test AUC: 0.8670

 Testing model for: Pneumothorax
[Pneumothorax] Test AUC: 0.9455

 Testing model for: Pleural Effusion
[Pleural Effusion] Test AUC: 0.9282

 Testing model for: Pleural Other
[Pleural Other] Test AUC: 0.8390

 Testing model for: Fracture
[Fracture] Test AUC: 0.9141

 Testing model for: Support Devices
[Support Devices] Test AUC: 0.9113


In [10]:
def safe_get(report, class_label, metric):
    """
    Safely extract a specific class or average metric.
    - For class_label as 0 or 1, try multiple key formats;
    - For average entries like 'macro avg', lookup directly.
    """
    if isinstance(class_label, (int, float)):
        keys_to_try = [class_label, str(class_label), f"{float(class_label):.1f}"]
    else:
        keys_to_try = [class_label]

    for key in keys_to_try:
        if key in report and metric in report[key]:
            return report[key][metric]
    return np.nan


def evaluate_full_report(model, dataloader, device):
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in dataloader:
            x = batch["embedding"].to(device)
            y = batch["lab"].to(device)

            logits = model(x)
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()

            all_labels.extend(y.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    # Accuracy: how many predictions match the ground truth
    accuracy = (np.array(all_labels) == np.array(all_preds)).mean()

    return classification_report(all_labels, all_preds, digits=4, output_dict=True), accuracy


# Collect reports for all labels 
report_dict = {}

for i, label_name in enumerate(label_names):
    print("=" * 30)
    print(f" Test Evaluation Report for: {label_name}")
    print("=" * 30)

    # 1. Load test data for current label
    test_X, test_y = get_single_label_data(test_embeddings, test_labels, i)
    test_ds = SingleLabelDataset(test_X, test_y)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # 2. Load corresponding model
    model = TabTransformer(input_dim=1376, hidden_dim=128, output_dim=1)
    model.load_state_dict(torch.load(f"transformer_label_{i}_{label_name}.pt"))
    model = model.to(device)

    # 3. Get classification report and accuracy
    report, acc = evaluate_full_report(model, test_loader, device)

    report_dict[label_name] = {
        # Class 0
        'precision_0': safe_get(report, 0, 'precision'),
        'recall_0': safe_get(report, 0, 'recall'),
        'f1-score_0': safe_get(report, 0, 'f1-score'),

        # Class 1
        'precision_1': safe_get(report, 1, 'precision'),
        'recall_1': safe_get(report, 1, 'recall'),
        'f1-score_1': safe_get(report, 1, 'f1-score'),

        # Macro average
        'precision': safe_get(report, 'macro avg', 'precision'),
        'recall': safe_get(report, 'macro avg', 'recall'),
        'f1-score': safe_get(report, 'macro avg', 'f1-score'),
        'support': safe_get(report, 'macro avg', 'support'),

        # Accuracy
        'accuracy': acc
    }

#  Save as DataFrame 
df = pd.DataFrame(report_dict)
df.index.name = "Metric"
df = df.round(4)
df.to_csv("test_metrics_per_label.csv")
print("\n Saved report with accuracy to: test_metrics_per_label.csv")

 Test Evaluation Report for: Enlarged Cardiomediastinum
 Test Evaluation Report for: Cardiomegaly
 Test Evaluation Report for: Lung Opacity
 Test Evaluation Report for: Lung Lesion
 Test Evaluation Report for: Edema
 Test Evaluation Report for: Consolidation
 Test Evaluation Report for: Pneumonia
 Test Evaluation Report for: Atelectasis
 Test Evaluation Report for: Pneumothorax
 Test Evaluation Report for: Pleural Effusion
 Test Evaluation Report for: Pleural Other
 Test Evaluation Report for: Fracture
 Test Evaluation Report for: Support Devices

 Saved report with accuracy to: test_metrics_per_label.csv


In [11]:
def evaluate_auc_ap(model, dataloader, device):
    model.eval()
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for batch in dataloader:
            x = batch["embedding"].to(device)
            y = batch["lab"].to(device).squeeze()

            logits = model(x).squeeze()
            probs = torch.sigmoid(logits)

            all_probs.extend(probs.cpu().numpy())
            all_targets.extend(y.cpu().numpy())

    try:
        auc = roc_auc_score(all_targets, all_probs)
    except ValueError:
        auc = float('nan')

    try:
        ap = average_precision_score(all_targets, all_probs)
    except ValueError:
        ap = float('nan')

    return auc, ap

#  Evaluate on Test Set

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i, label_name in enumerate(label_names):
    print("=" * 40)
    print(f" Test Metrics for: {label_name}")
    print("=" * 40)

    # 1. Load test data
    test_X, test_y = get_single_label_data(test_embeddings, test_labels, i)
    test_ds = SingleLabelDataset(test_X, test_y)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # 2. Load model
    model = TabTransformer(input_dim=1376, hidden_dim=128, output_dim=1)
    model.load_state_dict(torch.load(f"transformer_label_{i}_{label_name}.pt"))
    model = model.to(device)

    # 3. Evaluate
    auc, ap = evaluate_auc_ap(model, test_loader, device)
    print(f"AUC Score:              {auc:.4f}")
    print(f"Average Precision (AP): {ap:.4f}\n")

 Test Metrics for: Enlarged Cardiomediastinum
AUC Score:              0.8995
Average Precision (AP): 0.5773

 Test Metrics for: Cardiomegaly
AUC Score:              0.9149
Average Precision (AP): 0.5821

 Test Metrics for: Lung Opacity
AUC Score:              0.9126
Average Precision (AP): 0.3280

 Test Metrics for: Lung Lesion
AUC Score:              0.9628
Average Precision (AP): 0.5182

 Test Metrics for: Edema
AUC Score:              0.7998
Average Precision (AP): 0.1155

 Test Metrics for: Consolidation
AUC Score:              0.7968
Average Precision (AP): 0.1335

 Test Metrics for: Pneumonia
AUC Score:              0.8937
Average Precision (AP): 0.3330

 Test Metrics for: Atelectasis
AUC Score:              0.8670
Average Precision (AP): 0.5627

 Test Metrics for: Pneumothorax
AUC Score:              0.9455
Average Precision (AP): 0.6955

 Test Metrics for: Pleural Effusion
AUC Score:              0.9282
Average Precision (AP): 0.1853

 Test Metrics for: Pleural Other
AUC Score:

In [12]:
#  Evaluate multi-label classification models on the test set
#  Strategy: Per-label evaluation by skipping NaNs individually

# Step 1: Predict each label individually, skipping NaNs
all_true = []
all_pred = []

for i, label_name in enumerate(label_names):
    print(f"Evaluating label: {label_name}")

    # Select non-NaN samples for this label
    mask = ~np.isnan(test_labels[:, i])
    X = test_embeddings[mask]
    y = test_labels[mask, i]

    test_ds = SingleLabelDataset(X, y)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # Load model
    model = TabTransformer(input_dim=1376, hidden_dim=128, output_dim=1)
    model.load_state_dict(torch.load(f"transformer_label_{i}_{label_name}.pt"))
    model = model.to(device)
    model.eval()

    preds = []
    true_vals = []

    with torch.no_grad():
        for batch in test_loader:
            x = batch["embedding"].to(device)
            y_batch = batch["lab"].to(device)

            logits = model(x)
            probs = torch.sigmoid(logits)
            preds.extend(probs.cpu().numpy())  # Use raw probabilities for PR-AUC
            true_vals.extend(y_batch.cpu().numpy())

    all_pred.append(np.array(preds).flatten())
    all_true.append(np.array(true_vals).flatten())

# Step 2: Compute metrics
micro_acc = accuracy_score(
    np.concatenate([t.round() for t in all_true]),
    np.concatenate([p > 0.5 for p in all_pred])
)
macro_acc = np.mean([
    accuracy_score(all_true[i].round(), (all_pred[i] > 0.5).astype(int)) for i in range(len(label_names))
])

# Step 3: Compute PR-AUC per label
pr_auc_per_label = []
for i in range(len(label_names)):
    if len(np.unique(all_true[i])) > 1:
        auc = average_precision_score(all_true[i], all_pred[i])
        pr_auc_per_label.append(auc)
    else:
        pr_auc_per_label.append(np.nan)  # Cannot compute AUC with one class only

# Step 4: Print results
print("\n Evaluation on Test Set using per-label skipping of NaNs:")
print(f"Micro Accuracy (label-wise overall):     {micro_acc:.4f}")
print(f"Macro Accuracy (avg per label):          {macro_acc:.4f}")

print("\n Per-Label PR-AUC (Precision-Recall AUC):")
for i, label_name in enumerate(label_names):
    auc = pr_auc_per_label[i]
    print(f"{label_name:<25} → PR-AUC: {auc:.4f}" if not np.isnan(auc) else f"{label_name:<25} → PR-AUC: N/A")

macro_pr_auc = np.nanmean(pr_auc_per_label)
print(f"\nMacro PR-AUC (avg across labels): {macro_pr_auc:.4f}")

print("\n Label-wise Distribution After Skipping NaNs Individually (per-label view):\n")

for i, label_name in enumerate(label_names):
    label_vals = test_labels[:, i]
    valid_mask = ~np.isnan(label_vals)
    label_clean = label_vals[valid_mask]

    count_0 = np.sum(label_clean == 0)
    count_1 = np.sum(label_clean == 1)
    print(f"{label_name:<25} → 0: {count_0:<5} | 1: {count_1:<5} | total: {len(label_clean):<5}")

Evaluating label: Enlarged Cardiomediastinum
Evaluating label: Cardiomegaly
Evaluating label: Lung Opacity
Evaluating label: Lung Lesion
Evaluating label: Edema
Evaluating label: Consolidation
Evaluating label: Pneumonia
Evaluating label: Atelectasis
Evaluating label: Pneumothorax
Evaluating label: Pleural Effusion
Evaluating label: Pleural Other
Evaluating label: Fracture
Evaluating label: Support Devices

 Evaluation on Test Set using per-label skipping of NaNs:
Micro Accuracy (label-wise overall):     0.9451
Macro Accuracy (avg per label):          0.9467

 Per-Label PR-AUC (Precision-Recall AUC):
Enlarged Cardiomediastinum → PR-AUC: 0.5773
Cardiomegaly              → PR-AUC: 0.5821
Lung Opacity              → PR-AUC: 0.3280
Lung Lesion               → PR-AUC: 0.5182
Edema                     → PR-AUC: 0.1155
Consolidation             → PR-AUC: 0.1335
Pneumonia                 → PR-AUC: 0.3330
Atelectasis               → PR-AUC: 0.5627
Pneumothorax              → PR-AUC: 0.6955
Pleu