In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import f1_score, average_precision_score
from tqdm import tqdm

In [2]:
data = np.load("mimic_multimodal_dataset.npz", allow_pickle=True)
embeddings = data["embeddings"]
labels = data["labels"]
texts = data["texts"]
subject_ids = data["subject_ids"]
study_ids = data["study_ids"]

# Get number of samples
N = len(embeddings)
indices = np.random.RandomState(seed=42).permutation(N)

# Compute split sizes
n_train = int(0.8 * N)
n_valid = int(0.1 * N)
n_test = N - n_train - n_valid  # The rest goes to test

# Split indices
train_idx = indices[:n_train]
valid_idx = indices[n_train:n_train + n_valid]
test_idx  = indices[n_train + n_valid:]

# Split data based on indices
train_set = {
    "embeddings": embeddings[train_idx],
    "labels": labels[train_idx],
    "texts": texts[train_idx],
    "subject_ids": subject_ids[train_idx],
    "study_ids": study_ids[train_idx]
}

valid_set = {
    "embeddings": embeddings[valid_idx],
    "labels": labels[valid_idx],
    "texts": texts[valid_idx],
    "subject_ids": subject_ids[valid_idx],
    "study_ids": study_ids[valid_idx]
}

test_set = {
    "embeddings": embeddings[test_idx],
    "labels": labels[test_idx],
    "texts": texts[test_idx],
    "subject_ids": subject_ids[test_idx],
    "study_ids": study_ids[test_idx]
}

In [3]:
# Define all pathologies list
pathologies = [
        "Enlarged Cardiomediastinum",
        "Cardiomegaly",
        "Lung Opacity",
        "Lung Lesion",
        "Edema",
        "Consolidation",
        "Pneumonia",
        "Atelectasis",
        "Pneumothorax",
        "Pleural Effusion",
        "Pleural Other",
        "Fracture",
        "Support Devices",
    ]

In [4]:
class MultimodalDataset(Dataset):
    def __init__(self, data_dict, tokenizer, max_length=128):
        self.embeddings = data_dict["embeddings"]
        self.labels = data_dict["labels"]
        self.texts = data_dict["texts"]
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        image_embedding = torch.tensor(self.embeddings[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        # Tokenize report text using BERT tokenizer
        text = str(self.texts[idx])
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        return {
            "image": image_embedding,
            "input_ids": encoded["input_ids"].squeeze(0),  # [seq_len]
            "attention_mask": encoded["attention_mask"].squeeze(0),  # [seq_len]
            "label": label
        }

In [5]:
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

In [6]:

# Create datasets
train_dataset = MultimodalDataset(train_set, tokenizer)
valid_dataset = MultimodalDataset(valid_set, tokenizer)
test_dataset  = MultimodalDataset(test_set, tokenizer)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [7]:
# Check the shape of each component (number of samples & dimensions)
print("embeddings shape:", train_set["embeddings"].shape)      # (N, 1376)
print("labels shape:", train_set["labels"].shape)              # (N, 13)
print("texts length:", len(train_set["texts"]))                # N
print("subject_ids shape:", train_set["subject_ids"].shape)    # (N,)
print("study_ids shape:", train_set["study_ids"].shape)        # (N,)
# [B, 13]

embeddings shape: (36502, 1376)
labels shape: (36502, 13)
texts length: 36502
subject_ids shape: (36502,)
study_ids shape: (36502,)


In [8]:
i = 10 

print("\nSample index:", i)
print("Embedding vector (first 5 dims):", train_set["embeddings"][i][:5])
print("Label vector:", train_set["labels"][i])
print("Text sample:", train_set["texts"][i])
print("Subject ID:", train_set["subject_ids"][i])
print("Study ID:", train_set["study_ids"][i])


Sample index: 10
Embedding vector (first 5 dims): [-0.10261969 -0.9120668   0.88922745 -1.6479412  -0.41724488]
Label vector: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Text sample:                                  FINAL REPORT
 EXAMINATION:  CHEST RADIOGRAPHS
 
 INDICATION:  Chest pain.
 
 TECHNIQUE:  Chest, PA and lateral.
 
 COMPARISON:  None.
 
 FINDINGS: 
 
 The heart is normal in size. There is patchy calcification along the aortic
 arch. The lungs appear clear. There are no pleural effusions or pneumothorax.
 
 IMPRESSION: 
 
 No evidence of acute cardiopulmonary disease.

Subject ID: 19346252
Study ID: 51125476


In [9]:
class MaskedAsymmetricLoss(nn.Module):
    def __init__(self, gamma_pos=0.0, gamma_neg=4.0, eps=1e-8):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.eps = eps

    def forward(self, logits, labels):
        mask = ~torch.isnan(labels)
        labels = torch.where(mask, labels, torch.zeros_like(labels))  # Fill NaN for safety

        probs = torch.sigmoid(logits)
        probs = torch.clamp(probs, self.eps, 1 - self.eps)

        pos_loss = labels * ((1 - probs) ** self.gamma_pos) * torch.log(probs)
        neg_loss = (1 - labels) * (probs ** self.gamma_neg) * torch.log(1 - probs)

        loss = - (pos_loss + neg_loss)
        loss = loss[mask]  # apply mask

        if loss.numel() == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
        
        return loss.mean()

In [10]:
class MultiModalFusionModel(nn.Module):
    def __init__(self, image_dim=1376, text_dim=768, hidden_dim=512, num_labels=13):
        super().__init__()

        # Project image and text to shared hidden space
        self.img_proj = nn.Linear(image_dim, hidden_dim)
        self.txt_proj = nn.Linear(text_dim, hidden_dim)

        # Transformer encoder for cross-modal fusion
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
            activation='gelu'
        )
        self.fusion_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # Multi-label classification head
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, image_emb, text_emb):
        """
        image_emb: Tensor of shape [B, image_dim]
        text_emb: Tensor of shape [B, text_dim] (e.g., BERT CLS token)
        """
        # Project both modalities into common hidden space
        img_feat = self.img_proj(image_emb)   # [B, hidden_dim]
        txt_feat = self.txt_proj(text_emb)    # [B, hidden_dim]

        # Stack: [B, 2, hidden_dim] for transformer input
        fused = torch.stack([img_feat, txt_feat], dim=1)

        # Transformer-based fusion
        fused_out = self.fusion_encoder(fused)  # [B, 2, hidden_dim]

        # Mean pooling over modalities
        fusion_repr = fused_out.mean(dim=1)     # [B, hidden_dim]

        # Output: raw logits (use sigmoid + BCE or ASL externally)
        logits = self.classifier(fusion_repr)   # [B, num_labels]
        return logits

In [11]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ClinicalBERT model
text_encoder = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT").to(device)

# Initialize the multimodal model
model = MultiModalFusionModel().to(device)

# Forward pass example
for batch in train_loader:
    image = batch["image"].to(device)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)

    with torch.no_grad():
        text_feat = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state[:, 0, :]  # [B, 768]

    outputs = model(image, text_feat)
    break

In [12]:
criterion = MaskedAsymmetricLoss(gamma_pos=0.0, gamma_neg=4.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [13]:
def evaluate_model(model, dataloader, text_encoder, device):
    model.eval()
    all_preds, all_targets = [], []

    with torch.no_grad():
        for batch in dataloader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            text_feat = text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
            logits = model(image, text_feat)
            preds = (torch.sigmoid(logits) > 0.5).float()

            # Flatten & mask NaNs for evaluation
            mask = ~torch.isnan(label)
            all_preds.append(preds[mask].cpu().numpy())
            all_targets.append(label[mask].cpu().numpy())

    y_true = np.concatenate(all_targets)
    y_pred = np.concatenate(all_preds)

    return f1_score(y_true, y_pred, average="macro", zero_division=0)

In [14]:
def train_model(model, train_loader, valid_loader, text_encoder, loss_fn, optimizer, device, num_epochs=1):
    model.to(device)
    text_encoder.to(device)
    text_encoder.eval()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        skipped_batches = 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in loop:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            # Skip batches where all labels are NaN
            if torch.isnan(label).all():
                print("Skipping batch with all NaN labels.")
                skipped_batches += 1
                continue

            # Extract text features
            with torch.no_grad():
                text_feat = text_encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                ).last_hidden_state[:, 0, :]  # [CLS] token

            logits = model(image, text_feat)
            loss = loss_fn(logits, label)

            if torch.isnan(loss):
                print("Skipping batch due to NaN loss.")
                skipped_batches += 1
                continue

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        avg_loss = total_loss / (len(train_loader) - skipped_batches + 1e-6)
        print(f"\n Epoch {epoch+1} completed. Avg Loss = {avg_loss:.4f} | Skipped Batches: {skipped_batches}")

In [15]:
train_model(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    text_encoder=text_encoder,
    loss_fn=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=1
)

Epoch 1/1: 100%|██████████| 1141/1141 [25:23<00:00,  1.34s/it, loss=0.0956]


 Epoch 1 completed. Avg Loss = 0.0490 | Skipped Batches: 0





In [16]:
torch.save(model.state_dict(), "multimodal_fusion_epoch1.pt")

In [17]:
def evaluate_model(model, dataloader, text_encoder, device):
    model.eval()
    model.to(device)
    text_encoder.eval()
    text_encoder.to(device)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            text_feat = text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
            logits = model(image, text_feat)
            probs = torch.sigmoid(logits)

            all_preds.append(probs.cpu())
            all_labels.append(label.cpu())

    all_preds = torch.cat(all_preds, dim=0)  # [N, num_labels]
    all_labels = torch.cat(all_labels, dim=0)  # [N, num_labels]

    num_labels = all_labels.shape[1]
    f1_per_label = []

    for i in range(num_labels):
        y_true = all_labels[:, i]
        y_pred = all_preds[:, i]

        mask = ~torch.isnan(y_true)
        if mask.sum() == 0:
            continue  # Skip this label (completely missing)

        y_true = y_true[mask].numpy()
        y_pred = (y_pred[mask] > 0.5).float().numpy()

        f1 = f1_score(y_true, y_pred, zero_division=0)
        f1_per_label.append(f1)

    macro_f1 = np.mean(f1_per_label)
    print(f"Macro F1: {macro_f1:.4f}")
    return macro_f1


In [18]:
def evaluate_per_label(model, dataloader, text_encoder, device, label_names=None):
    model.eval()
    model.to(device)
    text_encoder.eval()
    text_encoder.to(device)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            text_feat = text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
            logits = model(image, text_feat)
            probs = torch.sigmoid(logits)

            all_preds.append(probs.cpu())
            all_labels.append(label.cpu())

    all_preds = torch.cat(all_preds, dim=0)  # [N, num_labels]
    all_labels = torch.cat(all_labels, dim=0)  # [N, num_labels]

    num_labels = all_labels.shape[1]
    f1_results = []
    ap_results = []

    for i in range(num_labels):
        y_true = all_labels[:, i]
        y_score = all_preds[:, i]

        mask = ~torch.isnan(y_true)
        if mask.sum() == 0:
            f1_results.append(None)
            ap_results.append(None)
            continue

        y_true = y_true[mask].numpy()
        y_pred = (y_score[mask] > 0.5).float().numpy()
        y_score = y_score[mask].numpy()

        f1 = f1_score(y_true, y_pred, zero_division=0)
        ap = average_precision_score(y_true, y_score)

        f1_results.append(f1)
        ap_results.append(ap)

    # Print results
    print("\n Per-label evaluation:")
    for i in range(num_labels):
        name = label_names[i] if label_names else f"Label {i}"
        if f1_results[i] is not None:
            print(f"{name:<25} | F1: {f1_results[i]:.3f} | AP: {ap_results[i]:.3f}")
        else:
            print(f"{name:<25} | F1:   N/A | AP:   N/A")

    macro_f1 = np.nanmean([f for f in f1_results if f is not None])
    print(f"\n Macro F1 (valid labels): {macro_f1:.4f}")
    return f1_results, ap_results, macro_f1

In [19]:
label_names = [
    "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity",
    "Lung Lesion", "Edema", "Consolidation", "Pneumonia", "Atelectasis",
    "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices"
]

evaluate_per_label(model, test_loader, text_encoder, device, label_names)


 Per-label evaluation:
Enlarged Cardiomediastinum | F1: 0.598 | AP: 0.597
Cardiomegaly              | F1: 0.546 | AP: 0.629
Lung Opacity              | F1: 0.345 | AP: 0.350
Lung Lesion               | F1: 0.517 | AP: 0.509
Edema                     | F1: 0.215 | AP: 0.160
Consolidation             | F1: 0.248 | AP: 0.409
Pneumonia                 | F1: 0.402 | AP: 0.391
Atelectasis               | F1: 0.597 | AP: 0.639
Pneumothorax              | F1: 0.637 | AP: 0.712
Pleural Effusion          | F1: 0.200 | AP: 0.222
Pleural Other             | F1: 0.348 | AP: 0.433
Fracture                  | F1: 0.262 | AP: 0.246
Support Devices           | F1: 0.419 | AP: 0.512

 Macro F1 (valid labels): 0.4103


([0.5978152929493545,
  0.5456012913640033,
  0.34509803921568627,
  0.5171102661596958,
  0.21524663677130046,
  0.24825174825174826,
  0.40162271805273836,
  0.5966850828729282,
  0.6368593238822247,
  0.2,
  0.3483043079743355,
  0.2619047619047619,
  0.4189723320158103],
 [np.float64(0.5969954061204436),
  np.float64(0.6292393412009127),
  np.float64(0.3498741646655957),
  np.float64(0.5094693797951755),
  np.float64(0.15974951584784547),
  np.float64(0.4089238899782475),
  np.float64(0.39082492343833475),
  np.float64(0.6394214503969154),
  np.float64(0.71237870322538),
  np.float64(0.22241993384563286),
  np.float64(0.43265489241697885),
  np.float64(0.2464231249109456),
  np.float64(0.51173342293466)],
 np.float64(0.410267061647276))

In [32]:
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, average_precision_score, f1_score

def safe_get(report, class_label, metric):
    if isinstance(class_label, (int, float)):
        keys_to_try = [class_label, str(class_label), f"{float(class_label):.1f}"]
    else:
        keys_to_try = [class_label]

    for key in keys_to_try:
        if key in report and metric in report[key]:
            return report[key][metric]
    return np.nan

def evaluate_detailed_per_label(model, dataloader, text_encoder, device, label_names=None, save_path=None):
    model.eval()
    text_encoder.eval()
    model.to(device)
    text_encoder.to(device)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].to(device)

            text_feat = text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
            logits = model(image, text_feat)
            probs = torch.sigmoid(logits)

            all_preds.append(probs.cpu())
            all_labels.append(label.cpu())

    all_preds = torch.cat(all_preds, dim=0)  # [N, num_labels]
    all_labels = torch.cat(all_labels, dim=0)  # [N, num_labels]

    num_labels = all_labels.shape[1]
    report_dict = {}

    for i in range(num_labels):
        y_true = all_labels[:, i]
        y_score = all_preds[:, i]

        mask = ~torch.isnan(y_true)
        if mask.sum() == 0:
            continue

        y_true = y_true[mask].numpy()
        y_pred = (y_score[mask] > 0.5).float().numpy()
        y_score = y_score[mask].numpy()

        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        acc = (y_true == y_pred).mean()
        ap = average_precision_score(y_true, y_score)

        label_name = label_names[i] if label_names else f"Label {i}"
        report_dict[label_name] = {
            # Class 0
            'precision_0': safe_get(report, 0, 'precision'),
            'recall_0': safe_get(report, 0, 'recall'),
            'f1-score_0': safe_get(report, 0, 'f1-score'),

            # Class 1
            'precision_1': safe_get(report, 1, 'precision'),
            'recall_1': safe_get(report, 1, 'recall'),
            'f1-score_1': safe_get(report, 1, 'f1-score'),

            # Macro avg
            'precision': safe_get(report, 'macro avg', 'precision'),
            'recall': safe_get(report, 'macro avg', 'recall'),
            'f1-score': safe_get(report, 'macro avg', 'f1-score'),
            'support': safe_get(report, 'macro avg', 'support'),

            # AP and Accuracy
            'AP': ap,
            'accuracy': acc
        }

    # Save as Transposed DataFrame (horizontal table)
    df = pd.DataFrame(report_dict).T.round(4)
    df = df.transpose()
    if save_path:
        df.to_csv(save_path)
        print(f"\n Saved detailed metrics to (horizontal): {save_path}")

    return df


In [33]:
df = evaluate_detailed_per_label(
    model, test_loader, text_encoder, device,
    label_names=label_names,
    save_path="detailed_metrics_per_label.csv"
)


 Saved detailed metrics to (horizontal): detailed_metrics_per_label.csv
