In [1]:
# !pip install kagglehub
# import kagglehub

# # Download latest version
# path = kagglehub.dataset_download("nmtclone/rsna-ich-mil")

# print("Path to dataset files:", path)
# # Move from src to des
# src = path + "/rsna-ich-mil/"
# dest = "/root/rsna-ich-mil/"

# mv = "mv " + src + " " + dest
# mv

# Import Libraries

In [2]:
!pip install gpytorch torchsummary iterative-stratification optuna
!pip install torch pydicom pandas scikit-learn scikit-image numpy opencv-python matplotlib



In [3]:
import optuna
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import v2 as transforms
from pytorch_metric_learning import losses
# from torch.cpu.amp import GradScaler

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay

import sys
sys.path.append('../')

from models.mil_resnet import CNN_ATT_GP_Multilabel, CNN_ATT_GP, CNN_ATT_GP_MIML
from utils import hard_negative_mining as hnm
import gpytorch
from layers.gaussian_process import SingletaskGPModel, PGLikelihood
from utils.early_stopping import EarlyStoppingForOptimization, EarlyStopping

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import warnings
from sklearn.exceptions import UndefinedMetricWarning

# Silence two warning categories for the rest of the session:
#   * RuntimeWarning        - suppressed globally.
#   * UndefinedMetricWarning - scikit-learn's warning for ill-defined metrics.
# NOTE(review): the RuntimeWarning filter is broad and can hide genuine numerical
# problems (overflow, invalid values) - consider narrowing it.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

# Configurations
## GPU Configurations

In [5]:
# Select the compute device once and report which one was chosen.
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print("No GPU available. Training will run on CPU.")
else:
    device = torch.device('cuda')
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")

print(device)

GPU: NVIDIA GeForce RTX 4070 SUPER is available.
cuda


In [6]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to the project's .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Seed Everything

In [7]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Applies ``seed`` to Python's ``random``, NumPy and PyTorch (CPU, the current
    CUDA device and all CUDA devices), writes it to the ``PYTHONHASHSEED``
    environment variable, and puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The same seed goes to each library's seeding entry point.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Prefer repeatable cuDNN kernels over autotuned ones.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


seed_everything()

## Constants and Hyperparameters

In [8]:
import yaml

# Load the experiment configuration from the YAML file one directory above the notebook.
# Every constant below is a plain lookup in that mapping; a missing key raises KeyError here.
with open("../config.yaml") as file:
    config = yaml.safe_load(file)

# Accessing constants from config
# Image geometry - CHANNELS is forwarded to the model as its number of input channels.
# (HEIGHT / WIDTH are presumably the slice size in pixels - TODO confirm against config.yaml.)
HEIGHT = config['height']
WIDTH = config['width']
CHANNELS = config['channels']

# Batch sizes for the three loaders, and the test / validation split sizes.
TRAIN_BATCH_SIZE = config['train_batch_size']
VALID_BATCH_SIZE = config['valid_batch_size']
TEST_BATCH_SIZE = config['test_batch_size']
TEST_SIZE = config['test_size']
VALID_SIZE = config['valid_size']

# Switches that select code paths later in the notebook:
#   TRAINING_TYPE is compared against 'end_to_end',
#   GP_MODEL against 'single_task' / 'multi_task'.
# GP_KERNEL and MODEL_TYPE are passed through to the model's parameter dict.
TRAINING_TYPE = config['training_type']
GP_MODEL = config['gp_model']
GP_KERNEL = config['kernel_type']
MODEL_TYPE = config['model_type']

# NOTE(review): MAX_SLICES and SHAPE are not referenced by the cells visible in this notebook.
MAX_SLICES = config['max_slices']
SHAPE = tuple(config['shape'])

# Optimisation settings. INDUCING_POINTS is forwarded to the model for its GP layer.
# NOTE(review): the visible training/evaluation code thresholds at a literal 0.5 rather
# than at THRESHOLD - confirm which one is intended.
NUM_EPOCHS = config['num_epochs']
LEARNING_RATE = config['learning_rate']
INDUCING_POINTS = config['inducing_points']
THRESHOLD = config['threshold']

# Number of model outputs; 1 selects the binary model, anything else the multi-label models.
NUM_CLASSES = config['num_classes']

TARGET_LABELS = config['target_labels']

# Where the trained weights are written, and the device string handed to training/evaluation.
# NOTE(review): DEVICE comes from the config file and is independent of the `device`
# object detected with torch.cuda.is_available() earlier - confirm the two agree.
MODEL_PATH = config['model_path']
DEVICE = config['device']

# Projection-head settings, forwarded to the model's parameter dict.
PROJECTION_LOCATION = config['projection_location']
PROJECTION_HIDDEN_DIM = config['projection_hidden_dim']
PROJECTION_OUTPUT_DIM = config['projection_output_dim']

# Attention-head hidden size, forwarded to the model's parameter dict.
ATTENTION_HIDDEN_DIM = config['attention_hidden_dim']

In [9]:
# Detect where the notebook is running from well-known directories.
# Kaggle mounts datasets under the absolute path /kaggle/input; the previous relative
# check ('kaggle/input') only matched when the working directory happened to be '/'.
KAGGLE = os.path.exists('/kaggle/input')
REMOTE_SERVER = os.path.exists('/workspace/rsna-ich-mil')

# Root of the mounted Kaggle datasets. Only meaningful on Kaggle; it used to be None
# unconditionally, which made the string concatenation below raise TypeError.
ROOT_DIR = '/kaggle/input/' if KAGGLE else None

if KAGGLE:
    DATA_DIR = ROOT_DIR + 'rsna-mil-training/'
    DICOM_DIR = DATA_DIR
    CSV_PATH = DICOM_DIR + 'training_1000_scan_subset.csv'
elif REMOTE_SERVER:
    DATA_DIR = '/root/.cache/kagglehub/datasets/nmtclone/rsna-ich-mil/versions/4/rsna-ich-mil/'
    DICOM_DIR = DATA_DIR
    CSV_PATH = '/workspace/Brain-Stroke-Diagnosis/rsna/data_analyze/training_dataset_2_redundancy_1150_for_kaggle.csv'
    print('Running on remote server.')
else:
    DATA_DIR = '../rsna-ich-mil/'
    DICOM_DIR = DATA_DIR
    CSV_PATH = './data_analyze/training_dataset_2_redundancy.csv'

# One row per scan; list-valued columns (filenames, per-slice labels, ...) are stored as strings.
patient_scan_labels = pd.read_csv(CSV_PATH)
# DICOM_DIR equals DATA_DIR in every branch above, so no environment switch is needed here.
dicom_dir = DICOM_DIR

In [10]:
patient_scan_labels.head()

Unnamed: 0,filename,labels,any,epidural,intraparenchymal,intraventricular,subarachnoid,subdural,patient_id,study_instance_uid,...,patient_label,z_axis,slice_thickness,selected_indices,patient_any,patient_subdural,patient_epidural,patient_intraparenchymal,patient_intraventricular,patient_subarachnoid
0,"['ID_37f32aed2.dcm', 'ID_d61a6a7b9.dcm', 'ID_4...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",ID_0002cd41,ID_66929e09d4,...,0,"[38.484, 43.517, 48.549, 53.582, 58.614, 63.64...","[2.52, 2.52, 2.52, 2.52, 2.52, 2.52, 2.52, 2.5...","[1, 3, 5, 7, 9, 11, 13, 15, 16, 17, 18, 19, 20...",0,0,0,0,0,0
1,"['ID_138d275c8.dcm', 'ID_447fa09d9.dcm', 'ID_0...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",ID_00054f3f,ID_8a449ae31b,...,0,"[71.9000244, 76.9000244, 81.9000244, 86.900024...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",0,0,0,0,0,0
2,"['ID_520df89aa.dcm', 'ID_3b87d36d0.dcm', 'ID_9...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",ID_0006d192,ID_25690b4725,...,0,"[41.921, 49.421, 56.921, 64.421, 71.921, 79.42...","[3.75, 3.75, 3.75, 3.75, 3.75, 3.75, 3.75, 3.7...","[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25...",0,0,0,0,0,0
3,"['ID_203ef1efe.dcm', 'ID_0cec86087.dcm', 'ID_1...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",ID_00086119,ID_fdde2979b0,...,0,"[35.556, 40.757, 45.959, 51.16, 56.362, 61.563...","[2.6, 2.6, 2.6, 2.6, 2.6, 2.6, 2.6, 2.6, 2.6, ...","[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 20, 21, 22...",0,0,0,0,0,0
4,"['ID_0785539ea.dcm', 'ID_30c100dbc.dcm', 'ID_3...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",ID_000e5623,ID_9a4be35b9a,...,0,"[272.0, 277.0, 282.0, 287.0, 292.0, 297.0, 302...","[5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",0,0,0,0,0,0


# Data Preprocessing
## Splitting Data

In [11]:
def split_dataset(patient_scan_labels, test_size=TEST_SIZE, val_size=VALID_SIZE, random_state=42):
    """Split the label table into train / validation / test, stratified on ``patient_label``.

    Args:
        patient_scan_labels: DataFrame with a ``patient_label`` column.
        test_size: Size of the test split; when not greater than 0 no test set is made.
        val_size: Size of the validation split, expressed relative to the *whole* table.
            When a test set is carved out first it is rescaled to
            ``val_size / (1 - test_size)`` of the remainder.
        random_state: Seed forwarded to every ``train_test_split`` call.

    Returns:
        tuple: ``(train, val, test)`` DataFrames; ``test`` is ``None`` when no test
        set was requested.
    """
    strata = patient_scan_labels['patient_label']

    if not test_size > 0:
        # No hold-out test set: one stratified train/validation split over everything.
        train_part, val_part = train_test_split(
            patient_scan_labels,
            test_size=val_size,
            stratify=strata,
            random_state=random_state,
        )
        return train_part, val_part, None

    # Carve out the test set first ...
    remainder, test_part = train_test_split(
        patient_scan_labels,
        test_size=test_size,
        stratify=strata,
        random_state=random_state,
    )

    # ... then split what is left, rescaling the validation fraction to the remainder.
    relative_val_size = val_size / (1 - test_size)
    train_part, val_part = train_test_split(
        remainder,
        test_size=relative_val_size,
        stratify=remainder['patient_label'],
        random_state=random_state,
    )
    return train_part, val_part, test_part

from sklearn.model_selection import train_test_split
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
import numpy as np

def split_dataset_for_multilabel(patient_scan_labels, test_size=0.15, val_size=0.25, random_state=42):
    """Split the label table into train / validation / test with multi-label stratification.

    Stratification uses the six ``patient_*`` columns (any, epidural, intraparenchymal,
    intraventricular, subarachnoid, subdural).

    Args:
        patient_scan_labels: DataFrame containing the six ``patient_*`` label columns.
        test_size: Test split size; when not greater than 0 no test set is made.
        val_size: Validation split size, applied as-is to the rows that remain after
            the test split (it is not rescaled).
        random_state: Seed used for both splitters.

    Returns:
        tuple: ``(train, val, test)`` DataFrames; ``test`` is ``None`` when no test
        set was requested.
    """
    label_columns = ['patient_any', 'patient_epidural', 'patient_intraparenchymal',
                     'patient_intraventricular', 'patient_subarachnoid', 'patient_subdural']
    targets = patient_scan_labels[label_columns].values

    def _stratified_indices(frame, frame_targets, fraction):
        # One stratified shuffle split; returns (kept positions, held-out positions).
        splitter = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=fraction, random_state=random_state)
        return next(splitter.split(frame, frame_targets))

    # `pool` is whatever is still available for the train/validation split.
    pool, pool_targets = patient_scan_labels, targets
    test_part = None

    if test_size > 0:
        keep_idx, test_idx = _stratified_indices(patient_scan_labels, targets, test_size)
        test_part = patient_scan_labels.iloc[test_idx]
        pool, pool_targets = patient_scan_labels.iloc[keep_idx], targets[keep_idx]

    train_idx, val_idx = _stratified_indices(pool, pool_targets, val_size)
    return pool.iloc[train_idx], pool.iloc[val_idx], test_part

## Dataset Generator

In [12]:
from dataset_generators.RSNA_Dataset import MedicalScanDataset

In [13]:
class TrainDatasetGenerator(MedicalScanDataset):
    """Dataset class for training medical scan data.

    Thin subclass of ``MedicalScanDataset`` that adds no behaviour of its own: the
    constructor forwards its three arguments, unchanged and positionally, to the parent.

    Args:
        data_dir: Forwarded to ``MedicalScanDataset`` (the notebook passes the DICOM directory).
        patient_scan_labels: Forwarded to ``MedicalScanDataset`` (the notebook passes a label DataFrame).
        augmentor: Forwarded to ``MedicalScanDataset``; defaults to ``None``.
    """

    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        super().__init__(data_dir, patient_scan_labels, augmentor)


class TestDatasetGenerator(MedicalScanDataset):
    """Dataset class for testing medical scan data.

    Thin subclass of ``MedicalScanDataset`` that adds no behaviour of its own: the
    constructor forwards its three arguments, unchanged and positionally, to the parent.

    Args:
        data_dir: Forwarded to ``MedicalScanDataset`` (the notebook passes the DICOM directory).
        patient_scan_labels: Forwarded to ``MedicalScanDataset`` (the notebook passes a label DataFrame).
        augmentor: Forwarded to ``MedicalScanDataset``; defaults to ``None``.
    """

    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        super().__init__(data_dir, patient_scan_labels, augmentor)

In [14]:
# Un-augmented dataset over the complete label table; the following cells only use it
# to inspect its length and the shape of one sample.
original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)

In [15]:
len(original_dataset)

21735

In [16]:
# Peek at the first item. The dataset yields a 4-tuple; the shapes of the first three
# elements are printed and the fourth is ignored here.
# NOTE(review): the training loops unpack the same 4-tuple as
# (data, labels, patient_labels, multi_labels) - confirm against MedicalScanDataset.
x, y, z, _ = original_dataset[0]
print(x.shape, y.shape, z.shape)

torch.Size([28, 1, 224, 224]) torch.Size([28]) torch.Size([])


In [17]:
def get_train_loader(dicom_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE):
    """Build the training ``DataLoader`` over an un-augmented ``TrainDatasetGenerator``.

    Args:
        dicom_dir: Directory handed to the dataset.
        patient_scan_labels: Label table handed to the dataset.
        batch_size: Samples per batch (defaults to ``TRAIN_BATCH_SIZE``).

    Returns:
        DataLoader: Shuffling loader with 4 workers and pinned memory that drops the
        last incomplete batch.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    train_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)
    return DataLoader(train_dataset, **loader_options)


def get_test_loader(dicom_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE, shuffle=False, drop_last=False):
    """Build the evaluation ``DataLoader`` over an un-augmented ``TestDatasetGenerator``.

    Evaluation must see every sample exactly once and in a stable order. The loader
    therefore no longer shuffles or drops the final partial batch by default:
    dropping it silently excluded samples from the reported metrics, and shuffling
    made it impossible to line predictions up with the rows of the label table.

    Args:
        dicom_dir: Directory handed to the dataset.
        patient_scan_labels: Label table handed to the dataset.
        batch_size: Samples per batch (defaults to ``TEST_BATCH_SIZE``).
        shuffle: Whether to shuffle; ``False`` by default so output order matches the table.
        drop_last: Whether to drop the final incomplete batch; ``False`` by default so
            that no sample is skipped.

    Returns:
        DataLoader: Loader with 4 workers and pinned memory.
    """
    test_dataset = TestDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True,
                      drop_last=drop_last)

# Training and Validation
## Metrics Calculation
### Performance Metrics

In [18]:
def calculate_metrics(predictions, labels):
    """Compute accuracy plus macro-averaged precision, recall and F1.

    Args:
        predictions: Predicted labels, in the format accepted by scikit-learn metrics.
        labels: Ground-truth labels, same layout as ``predictions``.

    Returns:
        dict: Keys ``"accuracy"``, ``"precision"``, ``"recall"`` and ``"f1"``.
    """
    macro_scorers = {
        "precision": precision_score,
        "recall": recall_score,
        "f1": f1_score,
    }

    metrics = {"accuracy": accuracy_score(labels, predictions)}
    for metric_name, scorer in macro_scorers.items():
        metrics[metric_name] = scorer(labels, predictions, average='macro')
    return metrics


def print_epoch_stats(epoch, num_epochs, phase, loss, metrics):
    """Print a two-line summary (header, then loss and metrics) for one epoch.

    Args:
        epoch: Zero-based epoch index; shown one-based.
        num_epochs: Total number of epochs.
        phase: Phase name such as ``"train"``; printed capitalised.
        loss: Loss value for the phase.
        metrics: Mapping with ``accuracy``, ``precision``, ``recall`` and ``f1`` entries.
    """
    print(f"Epoch {epoch + 1}/{num_epochs} - {phase.capitalize()}:")

    # (display title, key in `metrics`) in the order they appear on the line.
    metric_fields = (("Accuracy", "accuracy"), ("Precision", "precision"),
                     ("Recall", "recall"), ("F1", "f1"))
    pieces = [f"Loss: {loss:.4f}"]
    pieces.extend(f"{title}: {metrics[key]:.4f}" for title, key in metric_fields)
    print(", ".join(pieces))

### Loss Function

In [19]:
def combined_loss(outputs, gp_distribution, target, alpha=0.5):
    """Blend a classification loss with the GP's variational KL term.

    The classification term is ``BCEWithLogitsLoss`` when the global ``CHANNELS``
    equals 1 and ``CrossEntropyLoss`` otherwise; both are evaluated on
    ``outputs.squeeze()`` against ``target.float()``.

    Args:
        outputs: Model outputs.
        gp_distribution: Object exposing ``variational_strategy.kl_divergence()``.
        target: Ground-truth targets.
        alpha: Weight of the KL term; the classification term gets ``1 - alpha``.

    Returns:
        The weighted sum ``(1 - alpha) * classification_loss + alpha * kl``.
    """
    data_criterion = nn.BCEWithLogitsLoss() if CHANNELS == 1 else nn.CrossEntropyLoss()
    data_term = data_criterion(outputs.squeeze(), target.float())
    kl_term = gp_distribution.variational_strategy.kl_divergence()
    return (1 - alpha) * data_term + alpha * kl_term

## Training Loop

In [20]:
from tqdm import tqdm  # Add this import at the top of your file

def train_epoch(model, likelihoods, data_loader, criterion_cl, criterion_bce, mlls, optimizer, variational_ngd_optimizer, scheduler,
                scaler, device):
    """Run one training pass over ``data_loader``.

    A forward/backward pass only happens when the global ``TRAINING_TYPE`` is
    ``'end_to_end'``; for any other value the loop just collects the labels.

    Args:
        model: Called as ``model(batch_data)``; must return
            ``(outputs, gp_outputs, att_outputs)``.
        likelihoods: GP likelihood module; indexed per class when
            ``GP_MODEL == 'single_task'``.
        data_loader: Iterable yielding ``(data, labels, patient_labels, multi_labels)``.
        criterion_cl: Unused here; kept for interface compatibility.
        criterion_bce: Classification loss applied to ``outputs`` and the multi-label targets.
        mlls: Marginal log-likelihood objective(s): a per-class sequence for
            ``'single_task'``, a single callable otherwise.
        optimizer: Main optimizer, stepped once per batch.
        variational_ngd_optimizer: Per-class optimizers, only used for ``'single_task'``.
        scheduler: Learning-rate scheduler, stepped once per batch.
        scaler: Unused here; kept for interface compatibility.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: ``(mean_loss, predictions, labels)``. ``mean_loss`` is the summed batch
        loss divided by ``len(data_loader)``; the two lists hold one entry per sample.
    """
    total_loss = 0.0
    predictions = []
    labels = []

    model.train()
    likelihoods.train()

    for batch_data, batch_labels, batch_patient_labels, batch_multi_labels in tqdm(data_loader):
        batch_data = batch_data.to(device)
        batch_patient_labels = batch_patient_labels.float().to(device)
        batch_multi_labels = batch_multi_labels.float().to(device)
        optimizer.zero_grad()

        if GP_MODEL == 'single_task':
            for i in range(NUM_CLASSES):
                variational_ngd_optimizer[i].zero_grad()

        if TRAINING_TYPE == 'end_to_end':
            outputs, gp_outputs, att_outputs = model(batch_data)

            if GP_MODEL == 'single_task':
                # Negative ELBO summed over the per-class GPs.
                loss = 0
                for i in range(NUM_CLASSES):
                    loss += -mlls[i](gp_outputs[i], batch_multi_labels[:, i])
                # Reduce to a scalar. The result of .mean() used to be discarded, which is
                # harmless while the ELBO is already scalar but breaks .item()/.backward()
                # as soon as it is not.
                loss = loss.mean()
                loss += 0.5 * criterion_bce(outputs, batch_multi_labels)

                # Predictions come from the GP likelihoods, one probability per class.
                probs = [likelihoods[i](gp_outputs[i]).probs for i in range(NUM_CLASSES)]
                probabilities = torch.stack(probs, dim=1)
                preds = (probabilities >= 0.5).int()

            else:
                loss = -mlls(gp_outputs, batch_multi_labels) * 0.5 + 0.5 * criterion_bce(outputs, batch_multi_labels)
                loss = loss.mean()
                # NOTE(review): `outputs` is also what criterion_bce receives. If that
                # criterion is BCEWithLogitsLoss these are logits, and the matching
                # cut-off would be 0 (or sigmoid >= 0.5) - confirm.
                preds = (outputs >= 0.5).int()

            predictions.extend(preds.cpu().detach().numpy())

            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            if GP_MODEL == 'single_task':
                for i in range(NUM_CLASSES):
                    variational_ngd_optimizer[i].step()
            # Stepped per batch, not per epoch.
            scheduler.step()

        if NUM_CLASSES == 1:
            labels.extend(batch_patient_labels.cpu().numpy())
        else:
            labels.extend(batch_multi_labels.cpu().numpy())

    return total_loss / len(data_loader), predictions, labels


def validate(model, likelihoods, data_loader, criterion_cl, criterion_bce, mlls, device):
    """Evaluate the model on ``data_loader`` without updating any parameters.

    The loss and predictions are computed exactly as in training, but in eval mode
    and inside ``torch.inference_mode()``. A forward pass only happens when the global
    ``TRAINING_TYPE`` is ``'end_to_end'``.

    Args:
        model: Called as ``model(batch_data)``; must return
            ``(outputs, gp_outputs, att_outputs)``.
        likelihoods: GP likelihood module; indexed per class when
            ``GP_MODEL == 'single_task'``.
        data_loader: Iterable yielding ``(data, labels, patient_labels, multi_labels)``.
        criterion_cl: Unused here; kept for interface compatibility.
        criterion_bce: Classification loss applied to ``outputs`` and the multi-label targets.
        mlls: Marginal log-likelihood objective(s): a per-class sequence for
            ``'single_task'``, a single callable otherwise.
        device: Device the batch tensors are moved to.

    Returns:
        tuple: ``(mean_loss, predictions, labels)``. ``mean_loss`` is the summed batch
        loss divided by ``len(data_loader)``; the two lists hold one entry per sample.
    """
    model.eval()
    likelihoods.eval()

    total_loss = 0.0
    predictions = []
    labels = []

    with torch.inference_mode():
        for batch_data, batch_labels, batch_patient_labels, batch_multi_labels in tqdm(data_loader):
            batch_data = batch_data.to(device)
            batch_patient_labels = batch_patient_labels.float().to(device)
            batch_multi_labels = batch_multi_labels.float().to(device)

            if TRAINING_TYPE == 'end_to_end':
                outputs, gp_outputs, att_outputs = model(batch_data)
                if GP_MODEL == 'single_task':
                    # Negative ELBO summed over the per-class GPs.
                    loss = 0
                    for i in range(NUM_CLASSES):
                        loss += -mlls[i](gp_outputs[i], batch_multi_labels[:, i])
                    # Reduce to a scalar; the result of .mean() used to be discarded.
                    loss = loss.mean()
                    loss += 0.5 * criterion_bce(outputs, batch_multi_labels)
                    total_loss += loss.item()

                    probabilities = torch.stack([likelihoods[i](gp_outputs[i]).probs for i in range(NUM_CLASSES)], dim=1)
                    preds = (probabilities >= 0.5).int()
                else:
                    loss = -mlls(gp_outputs, batch_multi_labels) * 0.5 + 0.5 * criterion_bce(outputs, batch_multi_labels)
                    loss = loss.mean()
                    total_loss += loss.item()
                    # NOTE(review): `outputs` is also what criterion_bce receives. If that
                    # criterion is BCEWithLogitsLoss these are logits, and the matching
                    # cut-off would be 0 (or sigmoid >= 0.5) - confirm.
                    preds = (outputs >= 0.5).int()

                predictions.extend(preds.cpu().detach().numpy())

            if NUM_CLASSES == 1:
                labels.extend(batch_patient_labels.cpu().numpy())
            else:
                labels.extend(batch_multi_labels.cpu().numpy())
    return total_loss / len(data_loader), predictions, labels


def train_model(model, likelihoods, train_loader, val_loader, criterion_cl, criterion_bce, optimizer, num_epochs, learning_rate,
                device='cuda'):
    """Train the model and return it loaded with the weights of its best validation epoch.

    "Best" means highest validation accuracy. Early stopping is driven separately by
    the validation F1 score, through ``EarlyStopping(patience=20)``. Train and
    validation metrics are logged to Weights & Biases every epoch.

    Args:
        model: Network to train; must expose ``gp_layers``.
        likelihoods: GP likelihood module; indexed per class when
            ``GP_MODEL == 'single_task'``.
        train_loader: Training ``DataLoader``; ``len(train_loader.dataset)`` is used as
            the ELBO's ``num_data``.
        val_loader: Validation ``DataLoader``.
        criterion_cl: Passed through to the epoch functions.
        criterion_bce: Classification loss passed through to the epoch functions.
        optimizer: Main optimizer; wrapped in a ``OneCycleLR`` schedule.
        num_epochs: Maximum number of epochs.
        learning_rate: Peak learning rate (``max_lr``) for the one-cycle schedule.
        device: Device to train on.

    Returns:
        The same ``model`` object, with the best recorded weights loaded.
    """
    import copy  # function-scope import: only needed to snapshot the best weights

    model = model.to(device)
    model.train()
    likelihoods.train()

    # Initialize Early Stopping
    early_stopping = EarlyStopping(patience=20, verbose=True)

    num_train_samples = len(train_loader.dataset)
    if GP_MODEL == 'single_task':
        # One ELBO objective and one natural-gradient optimizer per class-specific GP.
        mlls = [gpytorch.mlls.VariationalELBO(likelihoods[i], model.gp_layers[i], num_data=num_train_samples)
                for i in range(NUM_CLASSES)]
        mlls = [mll.to(device) for mll in mlls]

        variational_ngd_optimizer = [
            gpytorch.optim.NGD(model.gp_layers[i].variational_parameters(), num_data=num_train_samples, lr=0.01)
            for i in range(NUM_CLASSES)]
    else:
        mlls = gpytorch.mlls.VariationalELBO(likelihoods, model.gp_layers, num_data=num_train_samples)
        mlls = mlls.to(device)
        variational_ngd_optimizer = None

    # The one-cycle schedule is stepped once per batch, hence steps_per_epoch.
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, steps_per_epoch=len(train_loader), epochs=num_epochs)

    scaler = None  # gradient scaling is disabled; the argument is only passed through
    best_val_accuracy = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        # Training phase
        train_loss, train_predictions, train_labels = train_epoch(model, likelihoods, train_loader, criterion_cl, criterion_bce,
                                                                  mlls, optimizer, variational_ngd_optimizer,
                                                                  scheduler, scaler, device)
        train_metrics = calculate_metrics(train_predictions, train_labels)
        print_epoch_stats(epoch, num_epochs, "train", train_loss, train_metrics)
        # Log training metrics to W&B
        wandb.log({
            "train/loss": train_loss,
            "train/accuracy": train_metrics["accuracy"],
            "train/precision": train_metrics["precision"],
            "train/recall": train_metrics["recall"],
            "train/f1": train_metrics["f1"],
        })

        # Validation phase
        val_loss, val_predictions, val_labels = validate(model, likelihoods, val_loader, criterion_cl, criterion_bce, mlls,
                                                         device)
        val_metrics = calculate_metrics(val_predictions, val_labels)
        print_epoch_stats(epoch, num_epochs, "validation", val_loss, val_metrics)
        # Log validation metrics to W&B
        wandb.log({
            "val/loss": val_loss,
            "val/accuracy": val_metrics["accuracy"],
            "val/precision": val_metrics["precision"],
            "val/recall": val_metrics["recall"],
            "val/f1": val_metrics["f1"],
        })

        # Snapshot the best weights BEFORE the early-stopping check so the stopping epoch
        # itself can still be recorded. state_dict() returns references to the live
        # parameters, so it must be deep-copied; otherwise the "best" state silently
        # follows later updates. The first epoch is always snapshotted, so a run whose
        # accuracy never rises above 0 still has a state to restore.
        if best_model_state is None or val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = max(best_val_accuracy, val_metrics['accuracy'])
            best_model_state = copy.deepcopy(model.state_dict())

        # Early Stopping Check
        early_stopping(val_metrics["f1"], model)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    # Load best model (nothing to restore when no epoch ran).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    # Optionally log the best model to W&B (if desired)
    print(f'Best Validation Accuracy: {best_val_accuracy}')
    wandb.log_artifact(wandb.Artifact("best_model", type="model", metadata={"accuracy": best_val_accuracy}))

    return model

## Model Evaluation

In [21]:
## Model Evaluation Functions
def evaluate_model(model, likelihoods, data_loader, device='cuda'):
    """Collect hard (0/1) predictions and ground-truth labels for every batch.

    Predictions are only produced when the global ``TRAINING_TYPE`` is
    ``'end_to_end'``. For ``GP_MODEL == 'single_task'`` they come from the per-class
    likelihood probabilities, otherwise from the raw model outputs; both are cut at 0.5.

    Args:
        model: Called as ``model(batch_data)``; must return three values.
        likelihoods: GP likelihood module; indexed per class for ``'single_task'``.
        data_loader: Iterable yielding ``(data, labels, patient_labels, multi_labels)``.
        device: Device to run on.

    Returns:
        tuple: ``(predictions, labels)`` as NumPy arrays.
    """
    model = model.to(device)
    model.eval()
    likelihoods.eval()

    collected_preds, collected_labels = [], []
    run_forward = TRAINING_TYPE == 'end_to_end'
    per_class_gp = GP_MODEL == 'single_task'
    binary_task = NUM_CLASSES == 1

    with torch.inference_mode():
        for batch_data, _, batch_patient_labels, batch_multi_labels in data_loader:
            batch_data = batch_data.to(device)
            patient_targets = batch_patient_labels.float().to(device)
            multi_targets = batch_multi_labels.float().to(device)

            if run_forward:
                outputs, gp_outputs, _ = model(batch_data)
                if per_class_gp:
                    class_probs = [likelihoods[c](gp_outputs[c]).probs for c in range(NUM_CLASSES)]
                    hard_preds = (torch.stack(class_probs, dim=1) >= 0.5).int()
                else:
                    hard_preds = (outputs >= 0.5).int()
                collected_preds.extend(hard_preds.cpu().detach().numpy())

            # Binary runs are scored against the patient-level label, multi-label runs
            # against the full label vector.
            targets = patient_targets if binary_task else multi_targets
            collected_labels.extend(targets.cpu().numpy())

    return np.array(collected_preds), np.array(collected_labels)


def print_metrics(metrics):
    """Print test accuracy, precision, recall and F1 on a single line.

    Args:
        metrics: Mapping with ``accuracy``, ``precision``, ``recall`` and ``f1`` entries.
    """
    # (display title, key in `metrics`) in output order.
    fields = (("Test Accuracy", "accuracy"), ("Precision", "precision"),
              ("Recall", "recall"), ("F1", "f1"))
    print(", ".join(f"{title}: {metrics[key]:.4f}" for title, key in fields))

In [22]:
## Visualization Functions
def plot_roc_curve(model, likelihoods, data_loader, device):
    """Plot ROC curve(s) computed from continuous model scores.

    A ROC curve needs scores, not thresholded 0/1 predictions: hard predictions
    collapse the curve to a single operating point. ``roc_curve`` also only accepts
    one binary target at a time, so multi-label runs get one curve per class.

    Scores are the per-class likelihood probabilities for
    ``GP_MODEL == 'single_task'`` and the raw model outputs otherwise.

    Args:
        model: Called as ``model(batch_data)``; must return three values.
        likelihoods: GP likelihood module; indexed per class for ``'single_task'``.
        data_loader: Iterable yielding ``(data, labels, patient_labels, multi_labels)``.
        device: Device the batch tensors are moved to.

    Raises:
        ValueError: If no scores were produced (``TRAINING_TYPE`` is not
            ``'end_to_end'``) or the score and label layouts do not match.
    """
    model.eval()
    likelihoods.eval()
    labels = []
    scores = []

    with torch.no_grad():
        for batch_data, batch_labels, batch_patient_labels, batch_multi_labels in data_loader:
            batch_data = batch_data.to(device)
            batch_patient_labels = batch_patient_labels.float().to(device)
            batch_multi_labels = batch_multi_labels.float().to(device)

            if TRAINING_TYPE == 'end_to_end':
                outputs, gp_outputs, _ = model(batch_data)
                if GP_MODEL == 'single_task':
                    batch_scores = torch.stack([likelihoods[i](gp_outputs[i]).probs for i in range(NUM_CLASSES)], dim=1)
                else:
                    # ROC only depends on the ranking of the scores, so raw outputs work.
                    batch_scores = outputs
                scores.extend(batch_scores.cpu().detach().numpy())

            if NUM_CLASSES == 1:  # Binary classification
                labels.extend(batch_patient_labels.cpu().numpy())
            else:
                labels.extend(batch_multi_labels.cpu().numpy())

    if len(scores) == 0:
        raise ValueError("No scores were produced; plot_roc_curve requires TRAINING_TYPE == 'end_to_end'.")

    # Normalise both to (num_samples, num_classes) so the binary case is one column.
    labels = np.asarray(labels)
    labels = labels.reshape(len(labels), -1)
    scores = np.asarray(scores).reshape(len(labels), -1)
    if scores.shape != labels.shape:
        raise ValueError(f"Score layout {scores.shape} does not match label layout {labels.shape}.")

    num_curves = labels.shape[1]
    plt.figure()
    for class_idx in range(num_curves):
        fpr, tpr, _ = roc_curve(labels[:, class_idx], scores[:, class_idx])
        roc_auc = auc(fpr, tpr)
        if num_curves == 1:
            plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
        else:
            plt.plot(fpr, tpr, lw=2, label=f'Class {class_idx} (AUC = {roc_auc:.4f})')
    # Chance diagonal for reference.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.show()


def plot_confusion_matrix(model, likelihoods, data_loader, device):
    """Plot the confusion matrix (or one per class) for the model predictions.

    ``confusion_matrix`` does not accept multilabel-indicator targets, so a
    multi-label run is drawn as a row of 2x2 matrices, one per class. Single-label
    runs keep the original single-matrix plot.

    Args:
        model: Model forwarded to ``evaluate_model``.
        likelihoods: Likelihood module forwarded to ``evaluate_model``.
        data_loader: Data to evaluate on.
        device: Device to run on.
    """
    predictions, labels = evaluate_model(model, likelihoods, data_loader, device)

    if labels.ndim == 2 and labels.shape[1] > 1:
        num_classes = labels.shape[1]
        fig, axes = plt.subplots(1, num_classes, figsize=(4 * num_classes, 4), squeeze=False)
        for class_idx, ax in enumerate(axes[0]):
            # labels=[0, 1] keeps every matrix 2x2 even when a class never occurs.
            cm = confusion_matrix(labels[:, class_idx].astype(int),
                                  predictions[:, class_idx].astype(int),
                                  labels=[0, 1])
            ConfusionMatrixDisplay(confusion_matrix=cm).plot(ax=ax, colorbar=False)
            ax.set_title(f'Class {class_idx}')
        fig.suptitle('Confusion Matrix')
        fig.tight_layout()
        plt.show()
        return

    # Flatten (N, 1) columns so both inputs are plain 1-D label vectors.
    cm = confusion_matrix(np.ravel(labels), np.ravel(predictions))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()
    plt.title('Confusion Matrix')
    plt.show()

# Model Helper Functions

In [23]:
## Data Processing Functions
def load_model(model_class, model_path, params):
    """Load a trained model from a file and return it in eval mode.

    The checkpoint is mapped onto CUDA when available and onto the CPU otherwise.
    A hard-coded CUDA target made loading fail on CPU-only machines, and the failure
    was then swallowed by the fallback below.

    Args:
        model_class: Callable building the model as ``model_class(params)``.
        model_path: Path to a ``state_dict`` saved with ``torch.save``.
        params: Argument passed to ``model_class``.

    Returns:
        The model in eval mode. If the checkpoint cannot be loaded, the error is
        printed and the randomly initialised model is returned instead (deliberate
        best-effort fallback), also in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    model = model_class(params)
    map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        state_dict = torch.load(model_path, map_location=map_location, weights_only=True)
        if not state_dict:
            raise ValueError(f"The state dictionary loaded from {model_path} is empty")
        model.load_state_dict(state_dict)
    except Exception as e:
        # Best-effort: report the problem and fall back to the untrained model.
        print(f"Error loading model from {model_path}: {str(e)}")
        print("Initializing model with random weights instead.")

    return model.eval()


def get_test_results(model, test_loader, test_labels, device=DEVICE, likelihoods=None):
    """Return the test label table with a ``prediction`` column appended.

    Predictions are matched to rows by position, so ``test_loader`` must iterate
    ``test_labels`` in order and without dropping samples (no shuffling, no
    ``drop_last``).

    Args:
        model: Trained model.
        test_loader: Loader over the test set.
        test_labels: DataFrame the loader was built from.
        device: Device to run on (defaults to ``DEVICE``).
        likelihoods: GP likelihood module required by ``evaluate_model``. It is
            keyword-optional only to keep the original positional signature; omitting
            it raises ``ValueError``.

    Returns:
        pandas.DataFrame: One row per test sample with all original columns plus
        ``prediction``.

    Raises:
        ValueError: If ``likelihoods`` is missing, or the number of predictions
            differs from the number of rows in ``test_labels``.
    """
    if likelihoods is None:
        raise ValueError("get_test_results requires `likelihoods`; it is passed on to evaluate_model.")

    # evaluate_model's signature is (model, likelihoods, data_loader, device).
    predictions, _ = evaluate_model(model, likelihoods, test_loader, device)

    if len(predictions) != len(test_labels):
        raise ValueError(
            f"Got {len(predictions)} predictions for {len(test_labels)} rows; "
            "the loader must not drop or skip samples."
        )

    # to_dict keeps the real column names (e.g. 'Unnamed: 0'), which itertuples() would
    # rename because they are not valid Python identifiers.
    records = test_labels.to_dict(orient='records')
    for record, prediction in zip(records, predictions):
        record['prediction'] = prediction

    return pd.DataFrame(records)

## NTXentLoss

In [24]:
class NTXentLoss(losses.NTXentLoss):
    """Contrastive loss wrapper that switches between NT-Xent and SupCon.

    * No ``labels``: self-supervised mode. The batch is assumed to be two stacked
      views of the same samples, pseudo-labels pair row ``i`` with row
      ``i + batch/2``, and ``losses.NTXentLoss`` is used.
    * ``labels`` given: supervised mode with ``losses.SupConLoss``, optionally
      restricted to ``hard_pairs``.
    """

    def __init__(self, temperature, **kwargs):
        """Store the temperature and forward all arguments to the parent loss."""
        super().__init__(temperature=temperature, **kwargs)
        self.temperature = temperature

    def forward(self, embeddings, labels=None, hard_pairs=None):
        """Compute the contrastive loss for a batch of embeddings.

        Args:
            embeddings: 2-D tensor of embeddings (rows are L2-normalised first).
            labels: Optional class labels; ``None`` selects self-supervised mode.
            hard_pairs: Optional mined index tuple passed to ``SupConLoss``.

        Returns:
            The loss tensor produced by the selected pytorch-metric-learning loss.
        """
        # Normalize feature vectors
        feature_vectors_normalized = F.normalize(embeddings, p=2, dim=1)

        # Record the mode BEFORE pseudo-labels are generated. The mode used to be
        # re-tested after `labels` had been overwritten, so the NT-Xent branch below
        # could never run.
        self_supervised = labels is None
        if self_supervised:
            # Self-supervised labels
            batch_size = feature_vectors_normalized.size(0) // 2  # Assuming equal size for both embeddings
            pair_ids = torch.arange(batch_size, device=feature_vectors_normalized.device)
            labels = torch.cat([pair_ids, pair_ids], dim=0)

        # Compute logits
        # NOTE(review): this temperature-scaled similarity matrix is passed below in the
        # position where pytorch-metric-learning losses take embeddings - confirm that
        # is intended.
        logits = torch.div(
            torch.matmul(
                feature_vectors_normalized, torch.transpose(feature_vectors_normalized, 0, 1)
            ),
            self.temperature,
        )

        if self_supervised:
            return losses.NTXentLoss(temperature=self.temperature)(logits, torch.squeeze(labels))
        if hard_pairs is None:
            return losses.SupConLoss(temperature=self.temperature)(logits, torch.squeeze(labels))
        return losses.SupConLoss(temperature=self.temperature)(logits, torch.squeeze(labels), hard_pairs)

# Main

In [25]:
from datetime import datetime

def main(mode='train'):
    """Run one experiment: split the data, build the model, train, evaluate and save.

    The run is tracked in Weights & Biases under a name derived from the current time
    and the main architecture settings. Only ``mode == 'train'`` combined with
    ``TRAINING_TYPE == 'end_to_end'`` performs any work.

    Args:
        mode: Execution mode; only ``'train'`` is implemented.

    Raises:
        ValueError: If ``NUM_CLASSES > 1`` and ``GP_MODEL`` is neither
            ``'single_task'`` nor ``'multi_task'``.
    """
    current_time = datetime.now().strftime("%Y%m%d_%H%M")
    run_name = f"experiment_{current_time}_{GP_MODEL}_refiner_fc_{PROJECTION_HIDDEN_DIM}_output_{PROJECTION_OUTPUT_DIM}_attention_{ATTENTION_HIDDEN_DIM}_kernel_{GP_KERNEL}_model_{MODEL_TYPE}"

    # Initialize W&B with a specific run name
    wandb.init(project="MIL_Resnet_ICH", name=run_name)

    # Log hyperparameters (this local `config` is the W&B run config, not the YAML dict)
    config = wandb.config
    config.learning_rate = LEARNING_RATE
    config.batch_size = TRAIN_BATCH_SIZE
    config.num_epochs = NUM_EPOCHS

    train_labels, val_labels, test_labels = split_dataset_for_multilabel(patient_scan_labels, test_size=TEST_SIZE)

    train_loader = get_train_loader(dicom_dir, train_labels, batch_size=TRAIN_BATCH_SIZE)
    # Validation data goes through the evaluation-side loader rather than the training one.
    val_loader = get_test_loader(dicom_dir, val_labels, batch_size=VALID_BATCH_SIZE)
    # The splitter returns test_labels=None when no test set is requested.
    test_loader = (get_test_loader(dicom_dir, test_labels, batch_size=TEST_BATCH_SIZE)
                   if test_labels is not None else None)

    params = {
        'channels': CHANNELS,  # Number of input channels (e.g., 1 for grayscale, 3 for RGB)
        'num_classes': NUM_CLASSES,  # Number of output classes for classification
        'drop_prob': 0.5,  # Dropout probability
        'inducing_points': INDUCING_POINTS,  # Number of inducing points for the Gaussian Process layer
        'projection_location': PROJECTION_LOCATION,  # Choose from 'after_resnet', 'after_attention', or 'after_gp'
        'projection_hidden_dim': PROJECTION_HIDDEN_DIM,  # Hidden dimension size for the projection head
        'projection_output_dim': PROJECTION_OUTPUT_DIM,  # Output dimension size for the projection head
        'attention_hidden_dim': ATTENTION_HIDDEN_DIM,  # Hidden dimension size for the attention head
        'gp_model': GP_MODEL,
        'kernel_type': GP_KERNEL,
        'model_type': MODEL_TYPE
    }

    if TRAINING_TYPE != 'end_to_end':
        # No other training type is implemented; make the no-op explicit.
        print(f"TRAINING_TYPE={TRAINING_TYPE!r} is not supported by main(); nothing to do.")
        return

    # Pick the model / likelihood pair for the configured task.
    if NUM_CLASSES == 1:
        model = CNN_ATT_GP(params)
        likelihood = PGLikelihood()
    elif GP_MODEL == 'single_task':
        model = CNN_ATT_GP_Multilabel(params)
        likelihood = nn.ModuleList([PGLikelihood() for _ in range(NUM_CLASSES)])
    elif GP_MODEL == 'multi_task':
        model = CNN_ATT_GP_MIML(params)
        likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=NUM_CLASSES)
    else:
        # Without this guard `model` would be undefined and fail later with a NameError.
        raise ValueError(f"Unsupported GP_MODEL {GP_MODEL!r}; expected 'single_task' or 'multi_task'.")

    # The optimizer setup is identical for every model variant.
    optimizer = optim.Adam([
        {'params': model.parameters(), 'lr': config.learning_rate},
        {'params': likelihood.parameters(), 'lr': config.learning_rate}
    ])

    criterion_cl = NTXentLoss(0.5)
    # Positive examples are up-weighted 5x for every class.
    pos_weights = torch.tensor([5.0] * NUM_CLASSES).to(DEVICE)
    criterion_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

    if mode == 'train':
        wandb.watch(model)  # Watch the model to log gradients and parameters
        trained_model = train_model(model, likelihood, train_loader, val_loader, criterion_cl, criterion_bce,
                                    optimizer, config.num_epochs, config.learning_rate, DEVICE)

        if test_loader is not None:
            predictions, labels = evaluate_model(trained_model, likelihood, test_loader, DEVICE)
            metrics = calculate_metrics(predictions, labels)
            wandb.log(metrics)
            print_metrics(metrics)
        else:
            print("No test split configured; skipping test evaluation.")

        torch.save(trained_model.state_dict(), MODEL_PATH)

# Results

In [None]:
# Entry point: executing this cell starts a full training run. Inside a notebook
# kernel __name__ is "__main__", so the guard only matters if the code is imported
# as a module.
if __name__ == "__main__":
    main(mode='train')

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
