In [21]:
# Standard Libraries 
import os
import random
import pickle
from typing import Dict, List

# Third-Party Libraries 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Evaluation (sklearn) 
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score
)

# Select the compute device once for the notebook: use CUDA when PyTorch
# reports it as available, otherwise fall back to the CPU. Later cells read
# this module-level `device` when moving models and batches.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [22]:
# Read the pre-computed embedding archive and expose each split as a pair of
# module-level arrays (embeddings, labels). Later cells index the label arrays
# by column, one column per finding.
# NOTE(review): array shapes/dtypes are not visible here; the model default
# (input_dim=1376) suggests 1376-wide embeddings -- confirm against the file.
data = np.load("mimic_embed_data.npz")
train_embeddings, train_labels = data["train_embeddings"], data["train_labels"]
valid_embeddings, valid_labels = data["valid_embeddings"], data["valid_labels"]
test_embeddings, test_labels = data["test_embeddings"], data["test_labels"]

In [23]:
class SingleLabelDataset(Dataset):
    """Map-style dataset pairing one embedding with one scalar label.

    Both inputs are copied into float32 tensors at construction time. Each
    item is a dict with the keys ``"embedding"`` (the row at ``idx``) and
    ``"lab"`` (the label at ``idx`` with a leading axis of size 1 added).
    """

    def __init__(self, embeddings, labels):
        # Convert once up front so __getitem__ is plain tensor indexing.
        self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        # The number of samples is the leading dimension of the embeddings.
        return len(self.embeddings)

    def __getitem__(self, idx):
        features = self.embeddings[idx]
        # Add a leading axis so a scalar label becomes shape (1,).
        target = torch.unsqueeze(self.labels[idx], dim=0)
        return {"embedding": features, "lab": target}

In [24]:
def get_single_label_data(embeddings, labels, label_idx):
    """Select the samples that have a non-NaN value for one label column.

    Args:
        embeddings: array indexed by sample along its first axis.
        labels: 2-D array; column ``label_idx`` holds the label of interest,
            with NaN marking samples where that label is missing.
        label_idx: column index of the label to extract.

    Returns:
        A tuple ``(embeddings_subset, label_values)`` restricted to the rows
        whose entry in column ``label_idx`` is not NaN.
    """
    label_column = labels[:, label_idx]
    observed = np.logical_not(np.isnan(label_column))
    return embeddings[observed], label_column[observed]

In [25]:
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    One linear layer projects the input to ``2 * output_dim`` features. The
    projection is split into two equal halves along the last axis; the first
    half is passed through GELU and multiplied element-wise by the second.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # A single projection yields both halves in one matmul.
        self.proj = nn.Linear(input_dim, 2 * output_dim)

    def forward(self, x):
        activated_half, linear_half = torch.chunk(self.proj(x), 2, dim=-1)
        return F.gelu(activated_half) * linear_half

In [27]:
class CustomTransformerEncoderLayer(nn.Module):
    """Encoder layer: self-attention followed by a GEGLU feed-forward block.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``,
    i.e. the normalisation is applied after the residual addition.

    Args:
        d_model: feature width of the input and output.
        nhead: number of attention heads.
        dim_feedforward: hidden width produced by the GEGLU block before it
            is projected back to ``d_model``.
        dropout: dropout probability used inside the attention module and on
            both residual branches.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Attention sub-layer; batch_first=True means inputs are (batch, seq, feature).
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Feed-forward sub-layer: GEGLU expands to dim_feedforward, linear2 maps back.
        self.gelu_glu = GEGLU(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=None):
        """Run both sub-layers on ``src`` and return a tensor of the same shape.

        ``is_causal`` is accepted but never read in this method (presumably so
        that a wrapper such as ``nn.TransformerEncoder`` can pass it -- verify).
        """
        # Self-attention with query = key = value = src; the attention weights
        # returned as the second element are discarded.
        attn_out, _ = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        hidden = self.norm1(src + self.dropout1(attn_out))

        # GEGLU feed-forward with its own residual connection and norm.
        ff_out = self.linear2(self.gelu_glu(hidden))
        return self.norm2(hidden + self.dropout2(ff_out))

In [28]:
class TabTransformer(nn.Module):
    """Embedding classifier: linear projection -> encoder stack -> linear head.

    Args:
        input_dim: width of the incoming feature vector.
        hidden_dim: model width used inside the encoder.
        output_dim: number of logits produced per sample.
        nhead: attention heads per encoder layer.
        nlayers: number of stacked encoder layers.
        dropout: dropout probability forwarded to every encoder layer.
    """

    def __init__(self, input_dim=1376, hidden_dim=128, output_dim=1, nhead=8, nlayers=4, dropout=0.1):
        super().__init__()

        # Project raw features to the model width and normalise them.
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)

        # nn.TransformerEncoder stacks `nlayers` copies of the custom layer;
        # the feed-forward width is fixed at 8x the model width.
        self.encoder = nn.TransformerEncoder(
            CustomTransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=nhead,
                dim_feedforward=8 * hidden_dim,
                dropout=dropout,
            ),
            num_layers=nlayers,
        )

        self.classifier = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return raw logits for ``x``.

        NOTE(review): assumes ``x`` is (batch, input_dim) -- confirm. A size-1
        axis is inserted at dim 1 before the encoder and removed afterwards.
        """
        tokens = self.norm(self.input_proj(x)).unsqueeze(1)
        encoded = self.encoder(tokens).squeeze(1)
        return self.classifier(encoded)

In [29]:
class MaskedAsymmetricLoss(nn.Module):
    """Asymmetric focal-style binary loss averaged over the unmasked entries.

    For each element with probability ``p = sigmoid(logit)`` the loss is::

        -[ y * (1 - p)**gamma_pos * log(p + 1e-8)
           + (1 - y) * p**gamma_neg * log(1 - p + 1e-8) ]

    The per-element losses are multiplied by ``mask`` and divided by the sum
    of the mask.

    Args:
        gamma_pos: focusing exponent applied to the positive term.
        gamma_neg: focusing exponent applied to the negative term.
    """

    def __init__(self, gamma_pos=0, gamma_neg=4):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg

    def forward(self, logits, labels, mask):
        """Compute the masked mean loss.

        Args:
            logits: raw model outputs.
            labels: targets, broadcastable to ``logits``. Values at positions
                where ``mask`` is zero are ignored and may be NaN.
            mask: non-negative weights, broadcastable to ``logits``; zero
                marks an entry to exclude.

        Returns:
            A scalar tensor. If every mask entry is zero the result is 0
            rather than NaN.
        """
        probs = torch.sigmoid(logits)

        # Work with a floating-point mask so bool/int masks behave like 0/1 weights.
        weights = mask if mask.is_floating_point() else mask.to(probs.dtype)
        keep = weights != 0

        # NaN * 0 is NaN, so a missing (NaN) label at a masked-out position
        # would poison both the loss and the gradients. Substitute a harmless
        # finite value there; its contribution is zeroed by the mask anyway.
        labels = torch.where(keep, labels, torch.zeros_like(labels))

        # Positive loss
        pos_loss = labels * torch.pow(1 - probs, self.gamma_pos) * torch.log(probs + 1e-8)
        # Negative loss
        neg_loss = (1 - labels) * torch.pow(probs, self.gamma_neg) * torch.log(1 - probs + 1e-8)
        # Combined loss
        loss = -(pos_loss + neg_loss)

        # Guard the denominator so an all-zero mask gives 0 / eps = 0, not 0 / 0.
        total_weight = weights.sum().clamp_min(1e-8)
        return (loss * weights).sum() / total_weight

In [30]:
def train_single_label_model(train_loader, val_loader, label_name, save_path, gamma_pos, gamma_neg,
                             epochs=15, lr=1e-5):
    """Train a fresh ``TabTransformer`` for one label and save its weights.

    Args:
        train_loader: iterable of batches; each batch is a dict with the keys
            ``"embedding"`` and ``"lab"``.
        val_loader: accepted for interface compatibility; it is not read by
            this function.
        label_name: name used in the progress messages.
        save_path: destination passed to ``torch.save`` for the state dict.
        gamma_pos: positive focusing exponent for ``MaskedAsymmetricLoss``.
        gamma_neg: negative focusing exponent for ``MaskedAsymmetricLoss``.
        epochs: number of passes over ``train_loader`` (default 15).
        lr: Adam learning rate (default 1e-5).

    Returns:
        The trained model, left on the device it was trained on.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    print(f" Training with gamma_pos={gamma_pos}, gamma_neg={gamma_neg}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TabTransformer().to(device)
    criterion = MaskedAsymmetricLoss(gamma_pos=gamma_pos, gamma_neg=gamma_neg)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        # Count batches while iterating instead of calling len(train_loader),
        # so loaders without __len__ work and an empty loader is detected.
        num_batches = 0

        for batch in train_loader:
            x = batch["embedding"].to(device)
            y = batch["lab"].to(device).float().view(-1, 1)  # ensure shape (batch_size, 1)
            # Every sample handed to this function is treated as labelled.
            mask = torch.ones_like(y)

            logits = model(x)
            loss = criterion(logits, y, mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError(f"train_loader for {label_name} yielded no batches")

        avg_loss = total_loss / num_batches
        print(f"[{label_name}] Epoch {epoch+1} - Training loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f" Model for {label_name} saved to {save_path}")

    return model

In [31]:
def grid_search_asl_for_label(label_index, label_name, gamma_pos_list, gamma_neg_list,
                              epochs=8, lr=1e-5):
    """Pick the (gamma_pos, gamma_neg) pair with the best validation AP for one label.

    For every combination in the grid a fresh ``TabTransformer`` is trained
    for ``epochs`` epochs on the label's training subset and scored with
    average precision on its validation subset. The data comes from the
    module-level ``train_embeddings`` / ``train_labels`` /
    ``valid_embeddings`` / ``valid_labels`` arrays and the module-level
    ``device``.

    Args:
        label_index: column of the label arrays to train on.
        label_name: name used in the progress messages.
        gamma_pos_list: candidate positive focusing exponents.
        gamma_neg_list: candidate negative focusing exponents.
        epochs: training epochs per configuration (default 8).
        lr: Adam learning rate (default 1e-5).

    Returns:
        ``(gamma_pos, gamma_neg)`` of the best-scoring configuration. If no
        configuration scores strictly better than the first one evaluated
        (including when every score is NaN), the first one is returned.

    Raises:
        ValueError: if either candidate list is empty.
    """
    if len(gamma_pos_list) == 0 or len(gamma_neg_list) == 0:
        raise ValueError("gamma_pos_list and gamma_neg_list must both be non-empty")

    print(f"\n Grid search for label: {label_name}")

    # Get training and validation data for the specific label
    train_X, train_y = get_single_label_data(train_embeddings, train_labels, label_index)
    val_X, val_y = get_single_label_data(valid_embeddings, valid_labels, label_index)

    train_ds = SingleLabelDataset(train_X, train_y)
    val_ds = SingleLabelDataset(val_X, val_y)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

    # Start below any attainable score so that a legitimate AP of 0 can still
    # be selected; best_config is therefore never None after the loop.
    best_ap = float("-inf")
    best_config = None

    for gamma_pos in gamma_pos_list:
        for gamma_neg in gamma_neg_list:
            print(f"→ gamma_pos={gamma_pos}, gamma_neg={gamma_neg}")

            # Initialize model
            model = TabTransformer().to(device)
            criterion = MaskedAsymmetricLoss(gamma_pos=gamma_pos, gamma_neg=gamma_neg)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)

            # Train for a few epochs to evaluate config
            for epoch in range(epochs):
                model.train()
                for batch in train_loader:
                    x = batch["embedding"].to(device)
                    y = batch["lab"].to(device).float().view(-1, 1)
                    mask = torch.ones_like(y)
                    logits = model(x)
                    loss = criterion(logits, y, mask)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            # Evaluate on validation set using Average Precision
            model.eval()
            all_probs = []
            all_labels = []
            with torch.no_grad():
                for batch in val_loader:
                    x = batch["embedding"].to(device)
                    logits = model(x)
                    # reshape(-1) rather than squeeze(): squeeze() collapses a
                    # final batch of size 1 to a 0-d tensor, which cannot be
                    # iterated by list.extend.
                    probs = torch.sigmoid(logits).reshape(-1)
                    all_probs.extend(probs.cpu().tolist())
                    all_labels.extend(batch["lab"].float().reshape(-1).tolist())

            ap = average_precision_score(all_labels, all_probs)
            print(f"    → AP={ap:.4f}")

            # The first configuration is always recorded as the fallback;
            # later ones replace it only when strictly better.
            if best_config is None or ap > best_ap:
                best_ap = ap
                best_config = (gamma_pos, gamma_neg)

    print(f" Best config for {label_name}: gamma_pos={best_config[0]}, gamma_neg={best_config[1]} with AP={best_ap:.4f}")
    return best_config

In [33]:
# Names of the findings. The position of a name in this list is used below as
# the column index into the label arrays and as part of the checkpoint filename.
label_names =  [
    "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity", "Lung Lesion",
    "Edema", "Consolidation", "Pneumonia", "Atelectasis",
    "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices"
]

# Candidate focusing exponents; every (gamma_pos, gamma_neg) pair is tried per label.
gamma_pos_list = [0.0, 0.5]
gamma_neg_list = [3.0, 4.0, 5.0, 6.0]

# One dict per label with the selected gamma pair; written to CSV at the end.
search_results = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# For each label: search the gamma grid, then train and save a final model
# with the winning pair.
for i, label_name in enumerate(label_names):
    print(f"\n Running ASL grid search for label: {label_name}")

    best_gamma_pos, best_gamma_neg = grid_search_asl_for_label(
        i, label_name, gamma_pos_list, gamma_neg_list
    )

    # Train final model with best gamma combination
    train_X, train_y = get_single_label_data(train_embeddings, train_labels, i)
    val_X, val_y = get_single_label_data(valid_embeddings, valid_labels, i)

    train_ds = SingleLabelDataset(train_X, train_y)
    val_ds = SingleLabelDataset(val_X, val_y)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

    # The checkpoint name embeds the label index and name (spaces included);
    # the evaluation cells rebuild the same f-string to load it.
    model_path = f"transformer_label_{i}_{label_name}.pt"
    model = train_single_label_model(train_loader, val_loader, label_name, model_path,
                                    gamma_pos=best_gamma_pos, gamma_neg=best_gamma_neg)

    # Save best config for this label
    search_results.append({
        "label": label_name,
        "gamma_pos": best_gamma_pos,
        "gamma_neg": best_gamma_neg
    })

# Save all best gamma configurations to CSV
import pandas as pd
pd.DataFrame(search_results).to_csv("best_gamma_config_per_label.csv", index=False)
print("\n All best gamma configs saved to best_gamma_config_per_label.csv")


 Running ASL grid search for label: Enlarged Cardiomediastinum

 Grid search for label: Enlarged Cardiomediastinum
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.5482
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.5519
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.5324
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.5511
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.5456
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.5536
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.5542
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.5446
 Best config for Enlarged Cardiomediastinum: gamma_pos=0.5, gamma_neg=5.0 with AP=0.5542
 Training with gamma_pos=0.5, gamma_neg=5.0




[Enlarged Cardiomediastinum] Epoch 1 - Training loss: 0.0599
[Enlarged Cardiomediastinum] Epoch 2 - Training loss: 0.0552
[Enlarged Cardiomediastinum] Epoch 3 - Training loss: 0.0549
[Enlarged Cardiomediastinum] Epoch 4 - Training loss: 0.0543
[Enlarged Cardiomediastinum] Epoch 5 - Training loss: 0.0541
[Enlarged Cardiomediastinum] Epoch 6 - Training loss: 0.0538
[Enlarged Cardiomediastinum] Epoch 7 - Training loss: 0.0536
[Enlarged Cardiomediastinum] Epoch 8 - Training loss: 0.0535
[Enlarged Cardiomediastinum] Epoch 9 - Training loss: 0.0532
[Enlarged Cardiomediastinum] Epoch 10 - Training loss: 0.0529
[Enlarged Cardiomediastinum] Epoch 11 - Training loss: 0.0529
[Enlarged Cardiomediastinum] Epoch 12 - Training loss: 0.0526
[Enlarged Cardiomediastinum] Epoch 13 - Training loss: 0.0525
[Enlarged Cardiomediastinum] Epoch 14 - Training loss: 0.0521
[Enlarged Cardiomediastinum] Epoch 15 - Training loss: 0.0520
 Model for Enlarged Cardiomediastinum saved to transformer_label_0_Enlarged Cardiomediastinum.pt



    → AP=0.5831
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.5881
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.5776
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.5749
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.5731
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.5790
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.5619
 Best config for Cardiomegaly: gamma_pos=0.0, gamma_neg=5.0 with AP=0.5881
 Training with gamma_pos=0.0, gamma_neg=5.0




[Cardiomegaly] Epoch 1 - Training loss: 0.0737
[Cardiomegaly] Epoch 2 - Training loss: 0.0691
[Cardiomegaly] Epoch 3 - Training loss: 0.0683
[Cardiomegaly] Epoch 4 - Training loss: 0.0679
[Cardiomegaly] Epoch 5 - Training loss: 0.0676
[Cardiomegaly] Epoch 6 - Training loss: 0.0673
[Cardiomegaly] Epoch 7 - Training loss: 0.0670
[Cardiomegaly] Epoch 8 - Training loss: 0.0669
[Cardiomegaly] Epoch 9 - Training loss: 0.0666
[Cardiomegaly] Epoch 10 - Training loss: 0.0666
[Cardiomegaly] Epoch 11 - Training loss: 0.0662
[Cardiomegaly] Epoch 12 - Training loss: 0.0661
[Cardiomegaly] Epoch 13 - Training loss: 0.0654
[Cardiomegaly] Epoch 14 - Training loss: 0.0652
[Cardiomegaly] Epoch 15 - Training loss: 0.0648
 Model for Cardiomegaly saved to transformer_label_1_Cardiomegaly.pt

 Running ASL grid search for label: Lung Opacity

 Grid search for label: Lung Opacity
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.4178
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.3931
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.4259
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.3852
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.4095
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.4049
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.3885
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.3958
 Best config for Lung Opacity: gamma_pos=0.0, gamma_neg=5.0 with AP=0.4259
 Training with gamma_pos=0.0, gamma_neg=5.0




[Lung Opacity] Epoch 1 - Training loss: 0.0235
[Lung Opacity] Epoch 2 - Training loss: 0.0214
[Lung Opacity] Epoch 3 - Training loss: 0.0210
[Lung Opacity] Epoch 4 - Training loss: 0.0208
[Lung Opacity] Epoch 5 - Training loss: 0.0207
[Lung Opacity] Epoch 6 - Training loss: 0.0206
[Lung Opacity] Epoch 7 - Training loss: 0.0204
[Lung Opacity] Epoch 8 - Training loss: 0.0204
[Lung Opacity] Epoch 9 - Training loss: 0.0202
[Lung Opacity] Epoch 10 - Training loss: 0.0200
[Lung Opacity] Epoch 11 - Training loss: 0.0199
[Lung Opacity] Epoch 12 - Training loss: 0.0198
[Lung Opacity] Epoch 13 - Training loss: 0.0195
[Lung Opacity] Epoch 14 - Training loss: 0.0193
[Lung Opacity] Epoch 15 - Training loss: 0.0189
 Model for Lung Opacity saved to transformer_label_2_Lung Opacity.pt

 Running ASL grid search for label: Lung Lesion

 Grid search for label: Lung Lesion
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.4177
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.4234
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.4482
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.4167
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.4260
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.4265
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.4275
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.4226
 Best config for Lung Lesion: gamma_pos=0.0, gamma_neg=5.0 with AP=0.4482
 Training with gamma_pos=0.0, gamma_neg=5.0




[Lung Lesion] Epoch 1 - Training loss: 0.0266
[Lung Lesion] Epoch 2 - Training loss: 0.0244
[Lung Lesion] Epoch 3 - Training loss: 0.0241
[Lung Lesion] Epoch 4 - Training loss: 0.0238
[Lung Lesion] Epoch 5 - Training loss: 0.0237
[Lung Lesion] Epoch 6 - Training loss: 0.0233
[Lung Lesion] Epoch 7 - Training loss: 0.0232
[Lung Lesion] Epoch 8 - Training loss: 0.0231
[Lung Lesion] Epoch 9 - Training loss: 0.0229
[Lung Lesion] Epoch 10 - Training loss: 0.0227
[Lung Lesion] Epoch 11 - Training loss: 0.0225
[Lung Lesion] Epoch 12 - Training loss: 0.0222
[Lung Lesion] Epoch 13 - Training loss: 0.0219
[Lung Lesion] Epoch 14 - Training loss: 0.0217
[Lung Lesion] Epoch 15 - Training loss: 0.0213
 Model for Lung Lesion saved to transformer_label_3_Lung Lesion.pt

 Running ASL grid search for label: Edema

 Grid search for label: Edema
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.2197
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.2233
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.1992
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.2071
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.2348
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.2202
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.2132
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.2336
 Best config for Edema: gamma_pos=0.5, gamma_neg=3.0 with AP=0.2348
 Training with gamma_pos=0.5, gamma_neg=3.0




[Edema] Epoch 1 - Training loss: 0.0262
[Edema] Epoch 2 - Training loss: 0.0250
[Edema] Epoch 3 - Training loss: 0.0248
[Edema] Epoch 4 - Training loss: 0.0245
[Edema] Epoch 5 - Training loss: 0.0244
[Edema] Epoch 6 - Training loss: 0.0241
[Edema] Epoch 7 - Training loss: 0.0242
[Edema] Epoch 8 - Training loss: 0.0238
[Edema] Epoch 9 - Training loss: 0.0236
[Edema] Epoch 10 - Training loss: 0.0231
[Edema] Epoch 11 - Training loss: 0.0230
[Edema] Epoch 12 - Training loss: 0.0226
[Edema] Epoch 13 - Training loss: 0.0224
[Edema] Epoch 14 - Training loss: 0.0218
[Edema] Epoch 15 - Training loss: 0.0215
 Model for Edema saved to transformer_label_4_Edema.pt

 Running ASL grid search for label: Consolidation

 Grid search for label: Consolidation
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.1482
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.1435
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.1464
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.1511
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.1480
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.1363
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.1473
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.1445
 Best config for Consolidation: gamma_pos=0.0, gamma_neg=6.0 with AP=0.1511
 Training with gamma_pos=0.0, gamma_neg=6.0




[Consolidation] Epoch 1 - Training loss: 0.0283
[Consolidation] Epoch 2 - Training loss: 0.0274
[Consolidation] Epoch 3 - Training loss: 0.0272
[Consolidation] Epoch 4 - Training loss: 0.0270
[Consolidation] Epoch 5 - Training loss: 0.0267
[Consolidation] Epoch 6 - Training loss: 0.0266
[Consolidation] Epoch 7 - Training loss: 0.0263
[Consolidation] Epoch 8 - Training loss: 0.0263
[Consolidation] Epoch 9 - Training loss: 0.0260
[Consolidation] Epoch 10 - Training loss: 0.0258
[Consolidation] Epoch 11 - Training loss: 0.0254
[Consolidation] Epoch 12 - Training loss: 0.0254
[Consolidation] Epoch 13 - Training loss: 0.0251
[Consolidation] Epoch 14 - Training loss: 0.0249
[Consolidation] Epoch 15 - Training loss: 0.0248
 Model for Consolidation saved to transformer_label_5_Consolidation.pt

 Running ASL grid search for label: Pneumonia

 Grid search for label: Pneumonia
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.3704
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.3642
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.3840
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.3712
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.3805
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.3693
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.3842
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.3871
 Best config for Pneumonia: gamma_pos=0.5, gamma_neg=6.0 with AP=0.3871
 Training with gamma_pos=0.5, gamma_neg=6.0




[Pneumonia] Epoch 1 - Training loss: 0.0243
[Pneumonia] Epoch 2 - Training loss: 0.0233
[Pneumonia] Epoch 3 - Training loss: 0.0230
[Pneumonia] Epoch 4 - Training loss: 0.0229
[Pneumonia] Epoch 5 - Training loss: 0.0227
[Pneumonia] Epoch 6 - Training loss: 0.0225
[Pneumonia] Epoch 7 - Training loss: 0.0226
[Pneumonia] Epoch 8 - Training loss: 0.0223
[Pneumonia] Epoch 9 - Training loss: 0.0221
[Pneumonia] Epoch 10 - Training loss: 0.0219
[Pneumonia] Epoch 11 - Training loss: 0.0219
[Pneumonia] Epoch 12 - Training loss: 0.0218
[Pneumonia] Epoch 13 - Training loss: 0.0215
[Pneumonia] Epoch 14 - Training loss: 0.0213
[Pneumonia] Epoch 15 - Training loss: 0.0210
 Model for Pneumonia saved to transformer_label_6_Pneumonia.pt

 Running ASL grid search for label: Atelectasis

 Grid search for label: Atelectasis
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.5435
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.5456
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.5449
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.5475
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.5453
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.5476
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.5449
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.5501
 Best config for Atelectasis: gamma_pos=0.5, gamma_neg=6.0 with AP=0.5501
 Training with gamma_pos=0.5, gamma_neg=6.0




[Atelectasis] Epoch 1 - Training loss: 0.0635
[Atelectasis] Epoch 2 - Training loss: 0.0609
[Atelectasis] Epoch 3 - Training loss: 0.0604
[Atelectasis] Epoch 4 - Training loss: 0.0598
[Atelectasis] Epoch 5 - Training loss: 0.0599
[Atelectasis] Epoch 6 - Training loss: 0.0595
[Atelectasis] Epoch 7 - Training loss: 0.0593
[Atelectasis] Epoch 8 - Training loss: 0.0593
[Atelectasis] Epoch 9 - Training loss: 0.0591
[Atelectasis] Epoch 10 - Training loss: 0.0590
[Atelectasis] Epoch 11 - Training loss: 0.0587
[Atelectasis] Epoch 12 - Training loss: 0.0585
[Atelectasis] Epoch 13 - Training loss: 0.0583
[Atelectasis] Epoch 14 - Training loss: 0.0581
[Atelectasis] Epoch 15 - Training loss: 0.0578
 Model for Atelectasis saved to transformer_label_7_Atelectasis.pt

 Running ASL grid search for label: Pneumothorax

 Grid search for label: Pneumothorax
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.7136
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.7131
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.7110
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.7194
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.7016
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.6967
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.7086
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.7164
 Best config for Pneumothorax: gamma_pos=0.0, gamma_neg=6.0 with AP=0.7194
 Training with gamma_pos=0.0, gamma_neg=6.0




[Pneumothorax] Epoch 1 - Training loss: 0.0565
[Pneumothorax] Epoch 2 - Training loss: 0.0517
[Pneumothorax] Epoch 3 - Training loss: 0.0510
[Pneumothorax] Epoch 4 - Training loss: 0.0508
[Pneumothorax] Epoch 5 - Training loss: 0.0507
[Pneumothorax] Epoch 6 - Training loss: 0.0502
[Pneumothorax] Epoch 7 - Training loss: 0.0502
[Pneumothorax] Epoch 8 - Training loss: 0.0496
[Pneumothorax] Epoch 9 - Training loss: 0.0495
[Pneumothorax] Epoch 10 - Training loss: 0.0491
[Pneumothorax] Epoch 11 - Training loss: 0.0490
[Pneumothorax] Epoch 12 - Training loss: 0.0486
[Pneumothorax] Epoch 13 - Training loss: 0.0482
[Pneumothorax] Epoch 14 - Training loss: 0.0476
[Pneumothorax] Epoch 15 - Training loss: 0.0474
 Model for Pneumothorax saved to transformer_label_8_Pneumothorax.pt

 Running ASL grid search for label: Pleural Effusion

 Grid search for label: Pleural Effusion
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.3490
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.3715
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.3707
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.3190
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.3566
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.3559
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.3391
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.3189
 Best config for Pleural Effusion: gamma_pos=0.0, gamma_neg=4.0 with AP=0.3715
 Training with gamma_pos=0.0, gamma_neg=4.0




[Pleural Effusion] Epoch 1 - Training loss: 0.0178
[Pleural Effusion] Epoch 2 - Training loss: 0.0161
[Pleural Effusion] Epoch 3 - Training loss: 0.0157
[Pleural Effusion] Epoch 4 - Training loss: 0.0156
[Pleural Effusion] Epoch 5 - Training loss: 0.0155
[Pleural Effusion] Epoch 6 - Training loss: 0.0155
[Pleural Effusion] Epoch 7 - Training loss: 0.0151
[Pleural Effusion] Epoch 8 - Training loss: 0.0149
[Pleural Effusion] Epoch 9 - Training loss: 0.0147
[Pleural Effusion] Epoch 10 - Training loss: 0.0145
[Pleural Effusion] Epoch 11 - Training loss: 0.0143
[Pleural Effusion] Epoch 12 - Training loss: 0.0139
[Pleural Effusion] Epoch 13 - Training loss: 0.0136
[Pleural Effusion] Epoch 14 - Training loss: 0.0131
[Pleural Effusion] Epoch 15 - Training loss: 0.0130
 Model for Pleural Effusion saved to transformer_label_9_Pleural Effusion.pt

 Running ASL grid search for label: Pleural Other

 Grid search for label: Pleural Other
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.3869
→ gamma_pos=0.0, gamma_neg=4.0



    → AP=0.3845
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.3814
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.3917
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.3831
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.3886
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.3855
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.3915
 Best config for Pleural Other: gamma_pos=0.0, gamma_neg=6.0 with AP=0.3917
 Training with gamma_pos=0.0, gamma_neg=6.0




[Pleural Other] Epoch 1 - Training loss: 0.0466
[Pleural Other] Epoch 2 - Training loss: 0.0452
[Pleural Other] Epoch 3 - Training loss: 0.0447
[Pleural Other] Epoch 4 - Training loss: 0.0445
[Pleural Other] Epoch 5 - Training loss: 0.0443
[Pleural Other] Epoch 6 - Training loss: 0.0441
[Pleural Other] Epoch 7 - Training loss: 0.0439
[Pleural Other] Epoch 8 - Training loss: 0.0437
[Pleural Other] Epoch 9 - Training loss: 0.0435
[Pleural Other] Epoch 10 - Training loss: 0.0437
[Pleural Other] Epoch 11 - Training loss: 0.0433
[Pleural Other] Epoch 12 - Training loss: 0.0432
[Pleural Other] Epoch 13 - Training loss: 0.0430
[Pleural Other] Epoch 14 - Training loss: 0.0427
[Pleural Other] Epoch 15 - Training loss: 0.0425
 Model for Pleural Other saved to transformer_label_10_Pleural Other.pt

 Running ASL grid search for label: Fracture

 Grid search for label: Fracture
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.1984
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.2178
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.2379
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.2098
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.1647
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.1990
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.2025
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.1989
 Best config for Fracture: gamma_pos=0.0, gamma_neg=5.0 with AP=0.2379
 Training with gamma_pos=0.0, gamma_neg=5.0




[Fracture] Epoch 1 - Training loss: 0.0168
[Fracture] Epoch 2 - Training loss: 0.0157
[Fracture] Epoch 3 - Training loss: 0.0154
[Fracture] Epoch 4 - Training loss: 0.0154
[Fracture] Epoch 5 - Training loss: 0.0152
[Fracture] Epoch 6 - Training loss: 0.0151
[Fracture] Epoch 7 - Training loss: 0.0150
[Fracture] Epoch 8 - Training loss: 0.0149
[Fracture] Epoch 9 - Training loss: 0.0148
[Fracture] Epoch 10 - Training loss: 0.0146
[Fracture] Epoch 11 - Training loss: 0.0145
[Fracture] Epoch 12 - Training loss: 0.0145
[Fracture] Epoch 13 - Training loss: 0.0143
[Fracture] Epoch 14 - Training loss: 0.0142
[Fracture] Epoch 15 - Training loss: 0.0140
 Model for Fracture saved to transformer_label_11_Fracture.pt

 Running ASL grid search for label: Support Devices

 Grid search for label: Support Devices
→ gamma_pos=0.0, gamma_neg=3.0
    → AP=0.5114
→ gamma_pos=0.0, gamma_neg=4.0




    → AP=0.5102
→ gamma_pos=0.0, gamma_neg=5.0




    → AP=0.5203
→ gamma_pos=0.0, gamma_neg=6.0




    → AP=0.5281
→ gamma_pos=0.5, gamma_neg=3.0




    → AP=0.5108
→ gamma_pos=0.5, gamma_neg=4.0




    → AP=0.5258
→ gamma_pos=0.5, gamma_neg=5.0




    → AP=0.5211
→ gamma_pos=0.5, gamma_neg=6.0




    → AP=0.5096
 Best config for Support Devices: gamma_pos=0.0, gamma_neg=6.0 with AP=0.5281
 Training with gamma_pos=0.0, gamma_neg=6.0




[Support Devices] Epoch 1 - Training loss: 0.0416
[Support Devices] Epoch 2 - Training loss: 0.0386
[Support Devices] Epoch 3 - Training loss: 0.0384
[Support Devices] Epoch 4 - Training loss: 0.0381
[Support Devices] Epoch 5 - Training loss: 0.0378
[Support Devices] Epoch 6 - Training loss: 0.0376
[Support Devices] Epoch 7 - Training loss: 0.0375
[Support Devices] Epoch 8 - Training loss: 0.0371
[Support Devices] Epoch 9 - Training loss: 0.0369
[Support Devices] Epoch 10 - Training loss: 0.0369
[Support Devices] Epoch 11 - Training loss: 0.0366
[Support Devices] Epoch 12 - Training loss: 0.0362
[Support Devices] Epoch 13 - Training loss: 0.0359
[Support Devices] Epoch 14 - Training loss: 0.0355
[Support Devices] Epoch 15 - Training loss: 0.0352
 Model for Support Devices saved to transformer_label_12_Support Devices.pt

 All best gamma configs saved to best_gamma_config_per_label.csv


In [34]:
def evaluate_model(model, dataloader, device):
    """Compute ROC AUC of a single-output model over a dataloader.

    Args:
        model: module returning one logit per sample.
        dataloader: iterable of batches; each batch is a dict with the keys
            ``"embedding"`` and ``"lab"``.
        device: device the inputs are moved to before the forward pass.

    Returns:
        The ROC AUC as a float, or NaN when ``roc_auc_score`` raises
        ``ValueError`` (for example, presumably, when the targets contain a
        single class -- see the scikit-learn documentation).
    """
    model.eval()
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for batch in dataloader:
            x = batch["embedding"].to(device)
            # reshape(-1) instead of squeeze(): squeeze() turns a batch of
            # size 1 into a 0-d tensor, and list.extend cannot iterate that.
            y = batch["lab"].to(device).reshape(-1)  # shape: (batch_size,)

            logits = model(x).reshape(-1)  # shape: (batch_size,)
            probs = torch.sigmoid(logits)

            all_probs.extend(probs.cpu().numpy())
            all_targets.extend(y.cpu().numpy())

    try:
        auc = roc_auc_score(all_targets, all_probs)
    except ValueError:
        # Deliberate best-effort: report "undefined" rather than abort the run.
        auc = float('nan')

    return auc

In [35]:

# Evaluate every saved per-label model on its own subset of the test split.
for i, label_name in enumerate(label_names):
    print(f"\n Testing model for: {label_name}")

    # 1. Prepare test data for the current label
    test_X, test_y = get_single_label_data(test_embeddings, test_labels, i)
    test_ds = SingleLabelDataset(test_X, test_y)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # 2. Rebuild the model architecture and load saved weights.
    #    map_location remaps the stored tensors onto the current device, so a
    #    checkpoint written on a GPU machine still loads on a CPU-only one.
    model = TabTransformer(input_dim=1376, hidden_dim=128, output_dim=1)
    model.load_state_dict(
        torch.load(f"transformer_label_{i}_{label_name}.pt", map_location=device)
    )
    model = model.to(device)

    # 3. Evaluate on the test set
    test_auc = evaluate_model(model, test_loader, device)
    print(f"[{label_name}] Test AUC: {test_auc:.4f}")


 Testing model for: Enlarged Cardiomediastinum
[Enlarged Cardiomediastinum] Test AUC: 0.9004

 Testing model for: Cardiomegaly




[Cardiomegaly] Test AUC: 0.9123

 Testing model for: Lung Opacity




[Lung Opacity] Test AUC: 0.8974

 Testing model for: Lung Lesion




[Lung Lesion] Test AUC: 0.9593

 Testing model for: Edema




[Edema] Test AUC: 0.7945

 Testing model for: Consolidation




[Consolidation] Test AUC: 0.7759

 Testing model for: Pneumonia




[Pneumonia] Test AUC: 0.8907

 Testing model for: Atelectasis




[Atelectasis] Test AUC: 0.8632

 Testing model for: Pneumothorax




[Pneumothorax] Test AUC: 0.9463

 Testing model for: Pleural Effusion




[Pleural Effusion] Test AUC: 0.9224

 Testing model for: Pleural Other




[Pleural Other] Test AUC: 0.8377

 Testing model for: Fracture




[Fracture] Test AUC: 0.9191

 Testing model for: Support Devices
[Support Devices] Test AUC: 0.9099




In [36]:
def safe_get(report, class_label, metric):
    """
    Safely extract one metric for a class or an average row of a report dict.

    - For a numeric class_label (e.g. 0 or 1) several key spellings are
      tried in order: the value itself, ``str(value)``, and the value
      formatted with one decimal place (e.g. ``"0.0"``).
    - For any other class_label (e.g. ``'macro avg'``) it is looked up
      directly.

    Returns the metric value, or ``np.nan`` when no candidate key is present,
    the metric is missing, or the entry is a scalar rather than a mapping of
    metrics (e.g. ``report["accuracy"]``).
    """
    if isinstance(class_label, (int, float)):
        keys_to_try = [class_label, str(class_label), f"{float(class_label):.1f}"]
    else:
        keys_to_try = [class_label]

    for key in keys_to_try:
        if key not in report:
            continue
        entry = report[key]
        try:
            if metric in entry:
                return entry[metric]
        except TypeError:
            # Scalar entries do not support `in`; treat them as
            # "metric not available" instead of crashing.
            continue
    return np.nan


def evaluate_full_report(model, dataloader, device, threshold=0.5):
    """
    Run the model over a dataloader and summarise its thresholded predictions.

    Parameters
    ----------
    model : object with ``.eval()`` that is callable on a batch
        Called as ``model(x)``; expected to return one logit per sample.
    dataloader : iterable of dict
        Each batch provides an "embedding" tensor and a "lab" tensor.
    device : torch.device
        Device the batch tensors are moved to before the forward pass.
    threshold : float, default 0.5
        Sigmoid probabilities strictly above this value are predicted as 1.

    Returns
    -------
    (report, accuracy)
        ``report`` is the dict returned by
        ``classification_report(..., digits=4, output_dict=True)``;
        ``accuracy`` is the fraction of predictions equal to their labels.
    """
    model.eval()
    label_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for batch in dataloader:
            x = batch["embedding"].to(device)
            y = batch["lab"].to(device)

            probs = torch.sigmoid(model(x))
            preds = (probs > threshold).float()

            # Flatten both sides to 1-D. SingleLabelDataset yields labels shaped
            # (batch, 1); if the logits come back as (batch,), an unflattened `==`
            # would broadcast to (n, n) and silently corrupt the accuracy.
            label_chunks.append(y.reshape(-1).cpu().numpy())
            pred_chunks.append(preds.reshape(-1).cpu().numpy())

    all_labels = np.concatenate(label_chunks) if label_chunks else np.array([])
    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])

    # Accuracy: how many predictions match the ground truth
    accuracy = (all_labels == all_preds).mean()

    return classification_report(all_labels, all_preds, digits=4, output_dict=True), accuracy


# Collect reports for all labels: one trained checkpoint per label is reloaded,
# scored on the test rows where that label is not NaN, and summarised.
report_dict = {}

for i, label_name in enumerate(label_names):
    print("=" * 30)
    print(f" Test Evaluation Report for: {label_name}")
    print("=" * 30)

    # 1. Load test data for current label
    test_X, test_y = get_single_label_data(test_embeddings, test_labels, i)
    test_ds = SingleLabelDataset(test_X, test_y)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # 2. Load corresponding model
    model = TabTransformer(input_dim=1376, hidden_dim=128, output_dim=1)
    # map_location lets a checkpoint saved on GPU be restored when only the CPU
    # fallback selected for `device` is available.
    state_dict = torch.load(f"transformer_label_{i}_{label_name}.pt", map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)

    # 3. Get classification report and accuracy
    report, acc = evaluate_full_report(model, test_loader, device)

    # Key order defines the row order of the saved table:
    # per-class metrics (class 0, then class 1), macro averages, then accuracy.
    report_dict[label_name] = {
        **{
            f'{metric_name}_{class_id}': safe_get(report, class_id, metric_name)
            for class_id in (0, 1)
            for metric_name in ('precision', 'recall', 'f1-score')
        },
        **{
            metric_name: safe_get(report, 'macro avg', metric_name)
            for metric_name in ('precision', 'recall', 'f1-score', 'support')
        },
        'accuracy': acc,
    }

# Save as DataFrame (columns = labels, rows = metrics)
df = pd.DataFrame(report_dict)
df.index.name = "Metric"
df = df.round(4)
df.to_csv("test_metrics_per_label.csv")
print("\n Saved report with accuracy to: test_metrics_per_label.csv")

 Test Evaluation Report for: Enlarged Cardiomediastinum




 Test Evaluation Report for: Cardiomegaly




 Test Evaluation Report for: Lung Opacity
 Test Evaluation Report for: Lung Lesion




 Test Evaluation Report for: Edema
 Test Evaluation Report for: Consolidation




 Test Evaluation Report for: Pneumonia
 Test Evaluation Report for: Atelectasis




 Test Evaluation Report for: Pneumothorax




 Test Evaluation Report for: Pleural Effusion
 Test Evaluation Report for: Pleural Other




 Test Evaluation Report for: Fracture
 Test Evaluation Report for: Support Devices





 Saved report with accuracy to: test_metrics_per_label.csv


In [37]:
def evaluate_auc_ap(model, dataloader, device):
    """
    Compute ROC-AUC and average precision of a model over a dataloader.

    Parameters
    ----------
    model : object with ``.eval()`` that is callable on a batch
        Called as ``model(x)``; its output is flattened and passed through a sigmoid.
    dataloader : iterable of dict
        Each batch provides an "embedding" tensor and a "lab" tensor.
    device : torch.device
        Device the batch tensors are moved to before the forward pass.

    Returns
    -------
    (auc, ap) : tuple of float
        Either value is ``nan`` when the corresponding sklearn metric raises
        ``ValueError``.
    """
    model.eval()
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for batch in dataloader:
            x = batch["embedding"].to(device)
            # reshape(-1) instead of squeeze(): squeeze() collapses a single-sample
            # batch to a 0-d tensor, and extend() cannot iterate a 0-d array.
            y = batch["lab"].to(device).reshape(-1)

            logits = model(x).reshape(-1)
            probs = torch.sigmoid(logits)

            all_probs.extend(probs.cpu().numpy())
            all_targets.extend(y.cpu().numpy())

    try:
        auc = roc_auc_score(all_targets, all_probs)
    except ValueError:
        auc = float('nan')

    try:
        ap = average_precision_score(all_targets, all_probs)
    except ValueError:
        ap = float('nan')

    return auc, ap

#  Evaluate on Test Set: reload each per-label checkpoint and print its
#  ROC-AUC and average precision on the test rows where that label is not NaN.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i, label_name in enumerate(label_names):
    print("=" * 40)
    print(f" Test Metrics for: {label_name}")
    print("=" * 40)

    # 1. Load test data
    test_X, test_y = get_single_label_data(test_embeddings, test_labels, i)
    test_ds = SingleLabelDataset(test_X, test_y)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # 2. Load model
    model = TabTransformer(input_dim=1376, hidden_dim=128, output_dim=1)
    # map_location lets a checkpoint saved on GPU be restored when only the CPU
    # fallback selected for `device` is available.
    state_dict = torch.load(f"transformer_label_{i}_{label_name}.pt", map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)

    # 3. Evaluate
    auc, ap = evaluate_auc_ap(model, test_loader, device)
    print(f"AUC Score:              {auc:.4f}")
    print(f"Average Precision (AP): {ap:.4f}\n")

 Test Metrics for: Enlarged Cardiomediastinum




AUC Score:              0.9004
Average Precision (AP): 0.5751

 Test Metrics for: Cardiomegaly
AUC Score:              0.9123
Average Precision (AP): 0.5640

 Test Metrics for: Lung Opacity




AUC Score:              0.8974
Average Precision (AP): 0.2841

 Test Metrics for: Lung Lesion




AUC Score:              0.9593
Average Precision (AP): 0.5065

 Test Metrics for: Edema




AUC Score:              0.7945
Average Precision (AP): 0.0899

 Test Metrics for: Consolidation
AUC Score:              0.7759
Average Precision (AP): 0.1395

 Test Metrics for: Pneumonia




AUC Score:              0.8907
Average Precision (AP): 0.3361

 Test Metrics for: Atelectasis




AUC Score:              0.8632
Average Precision (AP): 0.5545

 Test Metrics for: Pneumothorax




AUC Score:              0.9463
Average Precision (AP): 0.6722

 Test Metrics for: Pleural Effusion




AUC Score:              0.9224
Average Precision (AP): 0.2009

 Test Metrics for: Pleural Other




AUC Score:              0.8377
Average Precision (AP): 0.3985

 Test Metrics for: Fracture
AUC Score:              0.9191
Average Precision (AP): 0.3248

 Test Metrics for: Support Devices




AUC Score:              0.9099
Average Precision (AP): 0.4141



In [38]:
#  Evaluate multi-label classification models on the test set
#  Strategy: Per-label evaluation by skipping NaNs individually

# Step 1: Predict each label individually, skipping NaNs
all_true = []  # one 1-D array of ground-truth values per label
all_pred = []  # one 1-D array of sigmoid probabilities per label

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i, label_name in enumerate(label_names):
    print(f"Evaluating label: {label_name}")

    # Select non-NaN samples for this label
    mask = ~np.isnan(test_labels[:, i])
    X = test_embeddings[mask]
    y = test_labels[mask, i]

    test_ds = SingleLabelDataset(X, y)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

    # Load model
    model = TabTransformer(input_dim=1376, hidden_dim=128, output_dim=1)
    # map_location lets a checkpoint saved on GPU be restored when only the CPU
    # fallback selected for `device` is available.
    state_dict = torch.load(f"transformer_label_{i}_{label_name}.pt", map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    preds = []
    true_vals = []

    with torch.no_grad():
        for batch in test_loader:
            x = batch["embedding"].to(device)
            y_batch = batch["lab"].to(device)

            logits = model(x)
            probs = torch.sigmoid(logits)
            # Flatten per batch so every sample contributes exactly one scalar,
            # whatever trailing singleton dimensions the tensors carry.
            preds.extend(probs.reshape(-1).cpu().numpy())  # Use raw probabilities for PR-AUC
            true_vals.extend(y_batch.reshape(-1).cpu().numpy())

    all_pred.append(np.array(preds).flatten())
    all_true.append(np.array(true_vals).flatten())

# Step 2: Compute metrics
# Threshold once and reuse, so micro and macro accuracy score identical integer predictions.
true_rounded = [t.round() for t in all_true]
pred_binary = [(p > 0.5).astype(int) for p in all_pred]

# Micro: pool every (sample, label) pair; Macro: average the per-label accuracies.
micro_acc = accuracy_score(np.concatenate(true_rounded), np.concatenate(pred_binary))
macro_acc = np.mean([
    accuracy_score(true_rounded[i], pred_binary[i]) for i in range(len(label_names))
])

# Step 3: Compute PR-AUC per label
pr_auc_per_label = [
    # Cannot compute AUC with one class only
    average_precision_score(all_true[i], all_pred[i]) if len(np.unique(all_true[i])) > 1 else np.nan
    for i in range(len(label_names))
]

# Step 4: Print results
print("\n Evaluation on Test Set using per-label skipping of NaNs:")
print(f"Micro Accuracy (label-wise overall):     {micro_acc:.4f}")
print(f"Macro Accuracy (avg per label):          {macro_acc:.4f}")

print("\n Per-Label PR-AUC (Precision-Recall AUC):")
for i, label_name in enumerate(label_names):
    auc = pr_auc_per_label[i]  # PR-AUC (average precision), not ROC-AUC
    print(f"{label_name:<25} → PR-AUC: {auc:.4f}" if not np.isnan(auc) else f"{label_name:<25} → PR-AUC: N/A")

# nanmean ignores labels whose PR-AUC could not be computed
macro_pr_auc = np.nanmean(pr_auc_per_label)
print(f"\nMacro PR-AUC (avg across labels): {macro_pr_auc:.4f}")

print("\n Label-wise Distribution After Skipping NaNs Individually (per-label view):\n")

for i, label_name in enumerate(label_names):
    label_vals = test_labels[:, i]
    valid_mask = ~np.isnan(label_vals)
    label_clean = label_vals[valid_mask]

    # NOTE(review): only values exactly 0 and 1 are counted; any other non-NaN
    # value would appear in `total` but in neither count — confirm label encoding.
    count_0 = np.sum(label_clean == 0)
    count_1 = np.sum(label_clean == 1)
    print(f"{label_name:<25} → 0: {count_0:<5} | 1: {count_1:<5} | total: {len(label_clean):<5}")

Evaluating label: Enlarged Cardiomediastinum




Evaluating label: Cardiomegaly
Evaluating label: Lung Opacity
Evaluating label: Lung Lesion
Evaluating label: Edema
Evaluating label: Consolidation
Evaluating label: Pneumonia
Evaluating label: Atelectasis
Evaluating label: Pneumothorax
Evaluating label: Pleural Effusion
Evaluating label: Pleural Other
Evaluating label: Fracture
Evaluating label: Support Devices

 Evaluation on Test Set using per-label skipping of NaNs:
Micro Accuracy (label-wise overall):     0.8818
Macro Accuracy (avg per label):          0.8841

 Per-Label PR-AUC (Precision-Recall AUC):
Enlarged Cardiomediastinum → PR-AUC: 0.5751
Cardiomegaly              → PR-AUC: 0.5640
Lung Opacity              → PR-AUC: 0.2841
Lung Lesion               → PR-AUC: 0.5065
Edema                     → PR-AUC: 0.0899
Consolidation             → PR-AUC: 0.1395
Pneumonia                 → PR-AUC: 0.3361
Atelectasis               → PR-AUC: 0.5545
Pneumothorax              → PR-AUC: 0.6722
Pleural Effusion          → PR-AUC: 0.2009
Pleura