In [2]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import f1_score, average_precision_score
from tqdm import tqdm

# Standard Libraries 
import os
import random
import pickle
from typing import Dict, List

# Third-Party Libraries 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Evaluation (sklearn) 
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score
)

# Pick the compute device once: CUDA when PyTorch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Load the pre-computed arrays from the .npz archive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code -- only use it with files from a trusted source.
data = np.load("mimic_multimodal_dataset.npz", allow_pickle=True)
embeddings, labels, texts = data["embeddings"], data["labels"], data["texts"]
subject_ids, study_ids = data["subject_ids"], data["study_ids"]

# Shuffle sample positions with a fixed seed so the split is repeatable.
# NOTE(review): the permutation is over individual rows, so rows that share a
# subject_id can end up in different splits -- confirm whether a subject-level
# split is required for this experiment.
N = len(embeddings)
indices = np.random.RandomState(seed=42).permutation(N)

# 80% train / 10% validation / remainder test.
n_train = int(0.8 * N)
n_valid = int(0.1 * N)
n_test = N - n_train - n_valid

# Cut the shuffled index array at the two split boundaries.
train_idx, valid_idx, test_idx = np.split(indices, [n_train, n_train + n_valid])

# Build one dict per split; every array is indexed with the same index vector.
train_set, valid_set, test_set = (
    {
        "embeddings": embeddings[split_idx],
        "labels": labels[split_idx],
        "texts": texts[split_idx],
        "subject_ids": subject_ids[split_idx],
        "study_ids": study_ids[split_idx],
    }
    for split_idx in (train_idx, valid_idx, test_idx)
)

In [4]:
# Define all pathologies list
# Names of the 13 label columns, in column order (paired by index with the label
# columns in the per-label training loop).
pathologies = [
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]

In [5]:
class MultimodalDataset(Dataset):
    """Dataset pairing a pre-computed image embedding with a tokenized report.

    Args:
        data_dict: Mapping with at least the keys "embeddings", "labels" and
            "texts"; each value is indexed by sample position.
        tokenizer: Callable invoked as ``tokenizer(text, padding=..., truncation=...,
            max_length=..., return_tensors="pt")`` whose result exposes
            "input_ids" and "attention_mask" tensors with a leading batch axis.
        max_length: Sequence length passed to the tokenizer for padding/truncation.
    """

    def __init__(self, data_dict, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.embeddings = data_dict["embeddings"]
        self.labels = data_dict["labels"]
        self.texts = data_dict["texts"]

    def __len__(self):
        """Number of samples, taken from the embeddings container."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors.

        Keys: "image" and "label" (float32 tensors built from the stored
        arrays), "input_ids" and "attention_mask" (tokenizer output with the
        leading batch axis squeezed away).
        """
        # The report is coerced to str before tokenization.
        tokens = self.tokenizer(
            str(self.texts[idx]),
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        sample = {
            "image": torch.tensor(self.embeddings[idx], dtype=torch.float32),
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "label": torch.tensor(self.labels[idx], dtype=torch.float32),
        }
        return sample

In [6]:
# Tokenizer for the Bio_ClinicalBERT checkpoint; the same checkpoint name is used
# when the text encoder is loaded with AutoModel later in the notebook.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

In [7]:

# Create datasets
def get_single_label_dataset(data_dict, label_index):
    """Wrap one label column of ``data_dict`` in a ``MultimodalDataset``.

    The input dict is not modified: a shallow copy is built in which "labels"
    is replaced by the single column ``label_index``, kept 2-D (shape [N, 1]).
    The module-level ``tokenizer`` is used for the returned dataset.
    """
    single_column = data_dict["labels"][:, label_index:label_index + 1]
    return MultimodalDataset({**data_dict, "labels": single_column}, tokenizer)


In [8]:
# Sanity-check the training split: array shapes and number of report texts.
for field in ("embeddings", "labels"):
    print(f"{field} shape:", train_set[field].shape)
print("texts length:", len(train_set["texts"]))
for field in ("subject_ids", "study_ids"):
    print(f"{field} shape:", train_set[field].shape)

embeddings shape: (36502, 1376)
labels shape: (36502, 13)
texts length: 36502
subject_ids shape: (36502,)
study_ids shape: (36502,)


In [9]:
i = 10  # position within the training split to inspect

print("\nSample index:", i)
print("Embedding vector (first 5 dims):", train_set["embeddings"][i, :5])
# The remaining fields are printed whole, one line per field.
for caption, field in (
    ("Label vector:", "labels"),
    ("Text sample:", "texts"),
    ("Subject ID:", "subject_ids"),
    ("Study ID:", "study_ids"),
):
    print(caption, train_set[field][i])


Sample index: 10
Embedding vector (first 5 dims): [-0.10261969 -0.9120668   0.88922745 -1.6479412  -0.41724488]
Label vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Text sample:                                  FINAL REPORT
 EXAMINATION:  CHEST RADIOGRAPHS
 
 INDICATION:  Chest pain.
 
 TECHNIQUE:  Chest, PA and lateral.
 
 COMPARISON:  None.
 
 FINDINGS: 
 
 The heart is normal in size. There is patchy calcification along the aortic
 arch. The lungs appear clear. There are no pleural effusions or pneumothorax.
 
 IMPRESSION: 
 
 No evidence of acute cardiopulmonary disease.

Subject ID: 19346252
Study ID: 51125476


In [10]:
class MaskedAsymmetricLoss(nn.Module):
    """Asymmetric focal-style binary loss that ignores NaN labels.

    Positive targets are weighted by ``(1 - p) ** gamma_pos`` and negative
    targets by ``p ** gamma_neg`` where ``p = sigmoid(logits)``. Entries whose
    label is NaN are excluded from the mean.

    Args:
        gamma_pos: Focusing exponent applied to positive targets.
        gamma_neg: Focusing exponent applied to negative targets.
        eps: Probability floor; log-probabilities are clamped to be no smaller
            than ``log(eps)`` so a single element's loss stays bounded.
    """

    def __init__(self, gamma_pos=0.0, gamma_neg=4.0, eps=1e-8):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.eps = eps

    def forward(self, logits, labels):
        """Return the mean loss over all entries with a non-NaN label.

        Args:
            logits: Raw scores, same shape as ``labels``.
            labels: Targets in {0, 1}; NaN marks entries to ignore.

        Returns:
            Scalar tensor. When every label is NaN a zero-valued leaf tensor
            with ``requires_grad=True`` is returned so ``backward()`` still works.
        """
        import math

        mask = ~torch.isnan(labels)
        # Replace NaN by 0 so the arithmetic below never sees NaN; those
        # entries are dropped by the mask afterwards.
        labels = torch.where(mask, labels, torch.zeros_like(labels))

        # Work in log-space. Clamping probabilities to [eps, 1 - eps] does not
        # help in float32: 1 - 1e-8 rounds to 1.0, so a saturated sigmoid gives
        # log(1 - p) = -inf and (1 - label) * -inf = 0 * -inf = NaN for
        # positive labels. logsigmoid stays finite; the floor bounds the loss.
        log_floor = math.log(self.eps) if self.eps > 0 else float("-inf")
        log_p = torch.clamp(nn.functional.logsigmoid(logits), min=log_floor)
        log_not_p = torch.clamp(nn.functional.logsigmoid(-logits), min=log_floor)

        probs = torch.sigmoid(logits)

        pos_loss = labels * ((1 - probs) ** self.gamma_pos) * log_p
        neg_loss = (1 - labels) * (probs ** self.gamma_neg) * log_not_p

        loss = -(pos_loss + neg_loss)
        loss = loss[mask]  # keep only entries with a real label

        if loss.numel() == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=True)

        return loss.mean()

In [11]:
class MultiModalFusionModel(nn.Module):
    """Fuse an image embedding and a text embedding with a small Transformer.

    Both inputs are linearly projected to ``hidden_dim``, treated as a
    two-token sequence, passed through a 2-layer Transformer encoder,
    mean-pooled over the two tokens and mapped to ``num_labels`` logits.
    """

    def __init__(self, image_dim=1376, text_dim=768, hidden_dim=512, num_labels=1):
        super().__init__()

        # One linear projection per modality into the shared hidden space.
        self.img_proj = nn.Linear(image_dim, hidden_dim)
        self.txt_proj = nn.Linear(text_dim, hidden_dim)

        # Encoder that lets the two modality tokens attend to each other.
        self.fusion_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=8,
                dim_feedforward=1024,
                dropout=0.1,
                batch_first=True,
                activation='gelu',
            ),
            num_layers=2,
        )

        # Linear head producing one raw logit per label.
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, image_emb, text_emb):
        """
        image_emb: Tensor of shape [B, image_dim]
        text_emb: Tensor of shape [B, text_dim] (e.g., BERT CLS token)

        Returns raw logits of shape [B, num_labels]; no sigmoid is applied.
        """
        # Build the [B, 2, hidden_dim] token sequence: image token, then text token.
        tokens = torch.stack(
            (self.img_proj(image_emb), self.txt_proj(text_emb)),
            dim=1,
        )

        # Cross-modal mixing followed by averaging over the two tokens.
        pooled = self.fusion_encoder(tokens).mean(dim=1)  # [B, hidden_dim]

        return self.classifier(pooled)

In [14]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ClinicalBERT model
text_encoder = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

# Initialize the multimodal model
model = MultiModalFusionModel().to(device)

# Forward-pass smoke test.
# No earlier cell visible here defines `train_loader`, so on a fresh kernel the
# original loop raised NameError. Build a loader for the first label column
# with the same settings the per-label training loop uses.
train_loader = DataLoader(
    get_single_label_dataset(train_set, 0),
    batch_size=32,
    shuffle=True,
)

# A single batch is enough to check that shapes line up end to end.
batch = next(iter(train_loader))
image = batch["image"].to(device)
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)

# The text encoder is only used as a feature extractor here, so no gradients.
with torch.no_grad():
    text_feat = text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask
    ).last_hidden_state[:, 0, :]  # first-token ([CLS]) vector, [B, 768]

outputs = model(image, text_feat)

In [15]:
# Optimizer over the fusion model's parameters and the NaN-masked asymmetric loss.
# Both names are re-created per label inside the one-vs-rest training loop.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
)
criterion = MaskedAsymmetricLoss(
    gamma_pos=0.0,
    gamma_neg=4.0,
)

In [16]:
def evaluate_model(model, dataloader, text_encoder, device):
    """Return the macro-averaged F1 of ``model`` on ``dataloader``.

    Predictions are ``sigmoid(logits) > 0.5``. Entries whose label is NaN are
    dropped; the remaining entries of all batches are pooled into flat arrays
    before scoring.

    Args:
        model: Fusion model called as ``model(image, text_feat)``.
        dataloader: Iterable of batches with keys "image", "input_ids",
            "attention_mask" and "label".
        text_encoder: Encoder whose output exposes ``last_hidden_state``; its
            first-token vector is used as the text feature.
        device: Device the batch tensors are moved to.

    Returns:
        Macro F1 as a float, or NaN when there is no valid (non-NaN) label to
        score (empty loader or every label missing).
    """
    model.eval()
    # Put the encoder in eval mode as well, matching evaluate_full_report.
    text_encoder.eval()
    pred_chunks, target_chunks = [], []

    with torch.no_grad():
        for batch in dataloader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            text_feat = text_encoder(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
            logits = model(image, text_feat)
            preds = (torch.sigmoid(logits) > 0.5).float()

            # Boolean indexing flattens and removes NaN-labelled entries.
            mask = ~torch.isnan(label)
            pred_chunks.append(preds[mask].cpu().numpy())
            target_chunks.append(label[mask].cpu().numpy())

    # np.concatenate raises on an empty list, so handle "no batches" first.
    if not target_chunks:
        return float("nan")

    y_true = np.concatenate(target_chunks)
    y_pred = np.concatenate(pred_chunks)

    # Nothing left to score after masking.
    if y_true.size == 0:
        return float("nan")

    return f1_score(y_true, y_pred, average="macro", zero_division=0)

In [17]:
def train_model(model, train_loader, valid_loader, text_encoder, loss_fn, optimizer, device, num_epochs=1):
    """Train ``model`` on ``train_loader`` with a frozen text encoder.

    Args:
        model: Fusion model called as ``model(image, text_feat)``.
        train_loader: Iterable of batches with keys "image", "input_ids",
            "attention_mask" and "label".
        valid_loader: Accepted for interface compatibility; this function does
            not currently read it.
        text_encoder: Encoder run under ``torch.no_grad()`` in eval mode; the
            first-token vector of ``last_hidden_state`` is the text feature.
        loss_fn: Callable ``loss_fn(logits, label)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device for the model, encoder and batch tensors.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        List with one average training loss per epoch (NaN for an epoch in
        which every batch was skipped).
    """
    model.to(device)
    text_encoder.to(device)
    text_encoder.eval()

    epoch_losses = []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        processed_batches = 0
        skipped_batches = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in loop:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            # Skip batches where all labels are NaN
            if torch.isnan(label).all():
                print("Skipping batch with all NaN labels.")
                skipped_batches += 1
                continue

            # Extract text features
            with torch.no_grad():
                text_feat = text_encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                ).last_hidden_state[:, 0, :]  # [CLS] token

            logits = model(image, text_feat)
            loss = loss_fn(logits, label)

            # A NaN loss would corrupt the weights on the update; drop the batch.
            if torch.isnan(loss):
                print("Skipping batch due to NaN loss.")
                skipped_batches += 1
                continue

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            processed_batches += 1
            loop.set_postfix(loss=batch_loss)

        # Average over the batches that actually contributed, counted directly
        # rather than derived from len(train_loader) with an epsilon fudge.
        avg_loss = total_loss / processed_batches if processed_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"\n Epoch {epoch+1} completed. Avg Loss = {avg_loss:.4f} | Skipped Batches: {skipped_batches}")

    return epoch_losses

In [None]:
num_labels = train_set["labels"].shape[1]

# Label columns are paired with names from `pathologies` by position, so the
# two must have the same length.
assert num_labels == len(pathologies), (
    f"labels has {num_labels} columns but pathologies lists {len(pathologies)} names"
)

# Seed PyTorch so weight initialisation and DataLoader shuffling repeat
# between runs (the data split already uses a fixed seed).
torch.manual_seed(42)

# get_single_label_dataset is defined in an earlier cell; it is not redefined here.
for i in range(num_labels):
    label_name = pathologies[i]
    print(f"\n Training model for label {i}: {label_name}")

    # Build datasets for this label only
    train_dataset = get_single_label_dataset(train_set, i)
    valid_dataset = get_single_label_dataset(valid_set, i)
    test_dataset  = get_single_label_dataset(test_set, i)

    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
    test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # A fresh model, optimizer and loss per label (one-vs-rest).
    model = MultiModalFusionModel(num_labels=1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = MaskedAsymmetricLoss(gamma_pos=0.0, gamma_neg=4.0)

    # Train the model
    train_model(
        model=model,
        train_loader=train_loader,
        valid_loader=valid_loader,
        text_encoder=text_encoder,
        loss_fn=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=1
    )

    # Save the model for this label; the evaluation cells load it by this name.
    torch.save(model.state_dict(), f"model_ovr_label{i}_{label_name}.pt")

In [None]:
def safe_get(report, class_label, metric):
    """
    Look up ``metric`` for ``class_label`` in a nested report dict.

    - Numeric labels (e.g. 0 or 1) are tried as given, then as ``str(label)``,
      then formatted with one decimal place (``"0.0"``).
    - Any other label (e.g. 'macro avg') is looked up as-is.

    Returns ``np.nan`` when no candidate key holds the metric.
    """
    candidates = [class_label]
    if isinstance(class_label, (int, float)):
        candidates += [str(class_label), f"{float(class_label):.1f}"]

    for candidate in candidates:
        if candidate not in report:
            continue
        entry = report[candidate]
        if metric in entry:
            return entry[metric]

    return np.nan


def evaluate_full_report(model, dataloader, text_encoder, device):
    """Build a classification report and accuracy for ``model`` on ``dataloader``.

    Predictions are ``sigmoid(logits) > 0.5``. Entries whose label is NaN are
    removed element-wise, and the remaining labels/predictions of all batches
    are pooled into flat 1-D arrays before scoring.

    Args:
        model: Fusion model called as ``model(image, text_feat)``.
        dataloader: Iterable of batches with keys "image", "input_ids",
            "attention_mask" and "label".
        text_encoder: Encoder whose output exposes ``last_hidden_state``.
        device: Device for the model, encoder and batch tensors.

    Returns:
        ``(report, accuracy)`` where ``report`` is the dict produced by
        ``classification_report(..., output_dict=True)``. When there is no
        valid label to score, ``({}, nan)`` is returned.
    """
    model.eval()
    model.to(device)
    text_encoder.eval()
    text_encoder.to(device)

    label_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for batch in dataloader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            y = batch["label"].to(device)

            # BERT text embedding
            text_feat = text_encoder(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]  # CLS token

            logits = model(image, text_feat)
            preds = (torch.sigmoid(logits) > 0.5).float()

            y_np = y.cpu().numpy()
            preds_np = preds.cpu().numpy()

            # Element-wise mask: independent of how many label columns there
            # are, and the result is already 1-D.
            valid = ~np.isnan(y_np)
            label_chunks.append(y_np[valid])
            pred_chunks.append(preds_np[valid])

    # Empty loader: nothing to concatenate or score.
    if not label_chunks:
        return {}, float("nan")

    y_true = np.concatenate(label_chunks)
    y_pred = np.concatenate(pred_chunks)

    # Every label was NaN.
    if y_true.size == 0:
        return {}, float("nan")

    accuracy = (y_true == y_pred).mean()

    return classification_report(y_true, y_pred, digits=4, output_dict=True), accuracy


# Collect reports for all labels
report_dict = {}

# Iterate over `pathologies` (defined near the top of the notebook).
# `label_names` is only assigned in a later cell, so using it here raised
# NameError on a fresh Restart & Run All.
for i, label_name in enumerate(pathologies):
    print("=" * 30)
    print(f" Test Evaluation Report for: {label_name}")
    print("=" * 30)

    # Use multimodal test dataset for current label
    test_dataset = get_single_label_dataset(test_set, i)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Load model; map_location lets a checkpoint saved on GPU load on a
    # CPU-only machine.
    model = MultiModalFusionModel(num_labels=1)
    model.load_state_dict(
        torch.load(f"model_ovr_label{i}_{label_name}.pt", map_location=device)
    )
    model = model.to(device)

    # Evaluate
    report, acc = evaluate_full_report(model, test_loader, text_encoder, device)

    # Per-class metrics (suffix _0 / _1), macro averages, and overall accuracy.
    report_dict[label_name] = {
        'precision_0': safe_get(report, 0, 'precision'),
        'recall_0': safe_get(report, 0, 'recall'),
        'f1-score_0': safe_get(report, 0, 'f1-score'),

        'precision_1': safe_get(report, 1, 'precision'),
        'recall_1': safe_get(report, 1, 'recall'),
        'f1-score_1': safe_get(report, 1, 'f1-score'),

        'precision': safe_get(report, 'macro avg', 'precision'),
        'recall': safe_get(report, 'macro avg', 'recall'),
        'f1-score': safe_get(report, 'macro avg', 'f1-score'),
        'support': safe_get(report, 'macro avg', 'support'),

        'accuracy': acc
    }

# Save as DataFrame (one column per label, one row per metric)
df = pd.DataFrame(report_dict)
df.index.name = "Metric"
df = df.round(4)
df.to_csv("test_metrics_per_label.csv")
print("\n Saved report with accuracy to: test_metrics_per_label.csv")

In [18]:
# Label names in column order; same entries, in the same order, as the
# `pathologies` list defined near the top of the notebook.
label_names = [
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]

def evaluate_auc_ap(model, dataloader, text_encoder, device):
    """Compute ROC-AUC and average precision for a single-label model.

    Sigmoid probabilities and targets from every batch are flattened and
    pooled; entries whose target is NaN are dropped before scoring.

    Args:
        model: Fusion model called as ``model(image, text_feat)``.
        dataloader: Iterable of batches with keys "image", "input_ids",
            "attention_mask" and "label".
        text_encoder: Encoder whose output exposes ``last_hidden_state``.
        device: Device for the model, encoder and batch tensors.

    Returns:
        ``(auc, ap)``. Both are NaN when fewer than two distinct target
        values remain; either is NaN when its scorer raises ``ValueError``.
    """
    model.eval()
    text_encoder.eval()
    model.to(device)
    text_encoder.to(device)

    prob_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch in dataloader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # reshape(-1) instead of squeeze(): a batch of size 1 must stay
            # 1-D, otherwise it becomes a 0-d tensor that cannot be pooled.
            y = batch["label"].to(device).reshape(-1)

            # Text embedding
            text_feat = text_encoder(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state[:, 0, :]
            logits = model(image, text_feat).reshape(-1)
            probs = torch.sigmoid(logits)

            prob_chunks.append(probs.cpu().numpy())
            target_chunks.append(y.cpu().numpy())

    # Pool batches; an empty loader yields empty arrays and falls through to
    # the "only one class" guard below.
    y_np = np.concatenate(target_chunks) if target_chunks else np.empty(0)
    probs_np = np.concatenate(prob_chunks) if prob_chunks else np.empty(0)

    # Drop entries with a missing target.
    mask = ~np.isnan(y_np)
    y_np = y_np[mask]
    probs_np = probs_np[mask]

    # Handle edge case: only one class present
    if len(np.unique(y_np)) < 2:
        print("  Skipping label — only one class present in y_true.")
        return float('nan'), float('nan')

    try:
        auc = roc_auc_score(y_np, probs_np)
    except ValueError:
        auc = float('nan')

    try:
        ap = average_precision_score(y_np, probs_np)
    except ValueError:
        ap = float('nan')

    return auc, ap


# Report test-set AUC and average precision for each per-label model.
for i, label_name in enumerate(label_names):
    print("=" * 40)
    print(f" Test Metrics for: {label_name}")
    print("=" * 40)

    # 1. Load test dataset (multimodal version)
    test_dataset = get_single_label_dataset(test_set, i)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # 2. Load trained model; map_location lets a checkpoint saved on GPU load
    #    on a CPU-only machine.
    model = MultiModalFusionModel(num_labels=1)
    model.load_state_dict(
        torch.load(f"model_ovr_label{i}_{label_name}.pt", map_location=device)
    )
    model = model.to(device)

    # 3. Evaluate AUC and AP
    auc, ap = evaluate_auc_ap(model, test_loader, text_encoder, device)
    print(f"AUC Score:              {auc:.4f}")
    print(f"Average Precision (AP): {ap:.4f}\n")


 Test Metrics for: Enlarged Cardiomediastinum
AUC Score:              0.9368
Average Precision (AP): 0.5858

 Test Metrics for: Cardiomegaly
AUC Score:              0.9338
Average Precision (AP): 0.6136

 Test Metrics for: Lung Opacity
AUC Score:              0.9426
Average Precision (AP): 0.3359

 Test Metrics for: Lung Lesion
AUC Score:              0.9596
Average Precision (AP): 0.4967

 Test Metrics for: Edema
AUC Score:              0.8788
Average Precision (AP): 0.1608

 Test Metrics for: Consolidation
AUC Score:              0.9383
Average Precision (AP): 0.4092

 Test Metrics for: Pneumonia
AUC Score:              0.9438
Average Precision (AP): 0.3990

 Test Metrics for: Atelectasis
AUC Score:              0.9309
Average Precision (AP): 0.6275

 Test Metrics for: Pneumothorax
AUC Score:              0.9563
Average Precision (AP): 0.7024

 Test Metrics for: Pleural Effusion
AUC Score:              0.9276
Average Precision (AP): 0.2040

 Test Metrics for: Pleural Other
AUC Score: