# Import Libraries

In [1]:
!pip install gpytorch
!pip install wandb
!pip install python-gdcm
# !pip install pylibjpeg pylibjpeg-libjpeg pylibjpeg-openjpeg



In [2]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pydicom
import cv2
from skimage.transform import resize

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
from torchvision.transforms import v2 as transforms

import wandb

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay

from pytorch_metric_learning import losses, miners

import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy

import time

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [3]:
import warnings
from sklearn.exceptions import UndefinedMetricWarning

# Silence all RuntimeWarnings and sklearn's UndefinedMetricWarning (emitted when precision/recall
# are ill-defined, e.g. when no positive predictions are made; the metrics then report 0.0).
# NOTE(review): the RuntimeWarning filter is process-wide and can hide genuine numerical problems
# (NaN/overflow during preprocessing); consider scoping it with warnings.catch_warnings().
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

# Init GPU

In [4]:
# Choose the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

if not torch.cuda.is_available():
    print("No GPU available. Training will run on CPU.")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")

print(device)

GPU: NVIDIA GeForce RTX 4070 SUPER is available.
cuda


In [5]:
%load_ext autoreload
%autoreload 2

# Seed Everything

In [6]:
def seed_everything(seed=42):
    """Make runs reproducible by seeding every RNG this notebook touches.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU, current CUDA device and
    all CUDA devices), and switches cuDNN to deterministic, non-benchmarking mode.

    Note: per the Python docs, ``PYTHONHASHSEED`` is read at interpreter start-up, so writing it to
    ``os.environ`` here only affects child processes (e.g. DataLoader workers started via spawn).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        torch_seeder(seed)

    # Trade cuDNN autotuning speed for bit-for-bit repeatable convolutions.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

# Config Information

In [7]:
# Constants
# Model input geometry: every slice is resized to HEIGHT x WIDTH. CHANNELS selects 1 (brain window
# only) or 3 (brain / subdural / soft-tissue windows stacked channels-last) in bsb_window.
HEIGHT = 224
WIDTH = 224
CHANNELS = 1

TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2
TEST_BATCH_SIZE = 2
# Defaults of split_dataset. NOTE(review): its train/val/test split is currently disabled (all rows
# are returned as the test set), so these two values have no effect at the moment.
TEST_SIZE = 1
VALID_SIZE = 0.15

# Slices kept per scan (sorted by z position in read_dicom_folder); when no test_phase is given,
# shorter scans are padded with all-zero images up to this count.
MAX_SLICES = 420
SHAPE = (HEIGHT, WIDTH, CHANNELS)

NUM_EPOCHS = 20
# Presumably passed to train_model as `learning_rate`, which uses it as OneCycleLR's max_lr.
LEARNING_RATE = 1e-4
# NOTE(review): not referenced in the visible code; MILResNet18 hard-codes 32 inducing points.
INDUCING_POINTS = 128
# Cut-off applied to sigmoid probabilities to obtain 0/1 predictions.
THRESHOLD = 0.4

# TARGET_LABELS = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
# NOTE(review): not referenced in the visible code; the datasets read the CSV's 'ICH' column.
TARGET_LABELS = ['intraparenchymal']

# Overridden by the data-path cell below.
MODEL_PATH = 'results/trained_model.pth'
# NOTE(review): hard-coded, unlike `device` from the GPU cell, which falls back to the CPU.
DEVICE = 'cuda'

# Data Path

In [8]:
# Root folder holding one sub-folder per patient; process_patient_data expects
# <DICOM_DIR>/<name>/Unknown Study/<Source Folder>/*.dcm.
DICOM_DIR = './archive'
# Overrides the MODEL_PATH set in the constants cell.
MODEL_PATH = '../rsna/results/trained_model.pth'

# Scan-level label table; the datasets use its 'name', 'Source Folder' and 'ICH' columns.
CSV_PATH = './filtered_reads_without_nan.csv'
patient_scan_labels = pd.read_csv(CSV_PATH)
patient_scan_labels.head()

Unnamed: 0,name,Category,R1:ICH,R1:IPH,R1:IVH,R1:SDH,R1:EDH,R1:SAH,R2:ICH,R2:IPH,...,R3:SDH,R3:EDH,R3:SAH,ICH,IPH,IVH,SDH,EDH,SAH,Source Folder
0,CQ500CT427,B2,1,1,0,0,0,0,1,1,...,0,0,0,1,1,0,0,0,0,CT 2.55mm-2
1,CQ500CT181,B2,1,1,0,1,0,1,1,0,...,0,1,1,1,0,0,0,1,1,CT 5mm
2,CQ500CT99,B1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,CT PRE CONTRAST 5MM STD
3,CQ500CT47,B1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,CT PRE CONTRAST 5MM STD
4,CQ500CT195,B1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,CT PRE CONTRAST 5MM STD


# Data Preprocessing

In [9]:
def correct_dcm(dcm):
    """Re-encode a DICOM dataset's pixels in place; returns None.

    Adds 1000 to every stored value, subtracts 4096 from results that reach 4096, writes the array
    back to ``PixelData`` and sets ``RescaleIntercept`` to -1000. window_image applies this to 12-bit
    unsigned scans whose intercept is > -100 (presumably a known mis-encoding — TODO confirm source).
    """
    wrap_at = 4096
    shifted = dcm.pixel_array + 1000
    overflow = shifted >= wrap_at
    shifted[overflow] -= wrap_at
    dcm.PixelData = shifted.tobytes()
    dcm.RescaleIntercept = -1000

def window_image(dcm, window_center, window_width):
    """Convert a DICOM slice to rescaled values, resize it and clip it to a display window.

    Args:
        dcm: pydicom dataset. It may be mutated by correct_dcm (see below).
        window_center: window level, in the rescaled units (presumably HU).
        window_width: window width; values are clipped to
            [center - width // 2, center + width // 2].

    Returns:
        Float array of shape (HEIGHT, WIDTH) taken from SHAPE, clipped to the window.
    """
    # 12-bit unsigned scans with an intercept above -100 are re-encoded first. correct_dcm sets the
    # intercept to -1000, so repeated calls on the same dataset only correct it once.
    if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):
        correct_dcm(dcm)
    img = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept

    # cv2.resize takes dsize as (width, height), whereas SHAPE is (height, width, channels).
    # Passing SHAPE[:2] directly would transpose the target size whenever HEIGHT != WIDTH.
    target_height, target_width = SHAPE[:2]
    img = cv2.resize(img, (target_width, target_height), interpolation=cv2.INTER_LINEAR)

    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    return np.clip(img, img_min, img_max)

def bsb_window(dcm):
    """Build the float16 model input for one DICOM slice from three CT windows.

    Each window is clipped by window_image, shifted by a fixed offset and divided by its width.
    The windows are: brain (40/90, offset 0), subdural (80/200, offset -20) and soft tissue
    (40/380, offset -150). With CHANNELS == 3 the three windows are stacked channels-last into
    (H, W, 3); otherwise only the brain window (H, W) is returned.

    NOTE(review): the brain window clips to [-5, 85] but is shifted by 0, so it spans roughly
    [-0.056, 0.944] instead of [0, 1] like the other two. This is left unchanged because existing
    checkpoints were trained on it.
    """
    # (center, width, offset) per channel. The call order matters: the first window_image call
    # may run correct_dcm on `dcm`.
    window_specs = ((40, 90, 0), (80, 200, -20), (40, 380, -150))
    normalized = [(window_image(dcm, center, width) - offset) / width
                  for center, width, offset in window_specs]

    if CHANNELS == 3:
        result = np.stack(normalized, axis=-1)
    else:
        result = normalized[0]
    return result.astype(np.float16)

In [10]:
def preprocess_slice(slice, target_size=(HEIGHT, WIDTH)):
    """Turn one element returned by read_dicom_folder into a float16 image.

    Any value whose exact type is not ``np.ndarray`` is treated as a pydicom dataset and windowed
    with bsb_window. Plain arrays (the zero padding images) are resized to ``target_size`` with
    skimage and, when CHANNELS == 3, replicated into a channels-last (H, W, 3) image.
    """
    if type(slice) is not np.ndarray:
        return bsb_window(slice).astype(np.float16)

    padded = resize(slice, target_size, anti_aliasing=True)
    if CHANNELS == 3:
        padded = np.stack([padded] * 3, axis=-1)
    return padded.astype(np.float16)

In [11]:
def read_dicom_folder(folder_path, max_slices=MAX_SLICES, test_phase=None):
    """Read the .dcm files in one series folder, ordered by slice position.

    Datasets are sorted in ascending order of the z component of ImagePositionPatient, and the
    first ``max_slices`` are kept.

    Args:
        folder_path: directory containing the series' ``.dcm`` files.
        max_slices: maximum number of slices returned.
        test_phase: ``None`` pads a short series with all-zero numpy arrays (shaped like the first
            slice's pixel_array) up to exactly ``max_slices`` items. Any other value returns the
            possibly shorter list of datasets without padding.

    Returns:
        List of pydicom datasets, optionally followed by zero-filled numpy padding arrays.

    Raises:
        ValueError: if the folder contains no ``.dcm`` files.
    """
    dicom_names = [f for f in os.listdir(folder_path) if f.endswith(".dcm")]
    if not dicom_names:
        raise ValueError(f"No .dcm files found in {folder_path}")

    # Parse each file exactly once and sort the parsed datasets. sorted() is stable, so ties keep
    # the os.listdir order.
    datasets = [pydicom.dcmread(os.path.join(folder_path, name)) for name in dicom_names]
    slices = sorted(datasets, key=lambda ds: float(ds.ImagePositionPatient[2]))[:max_slices]

    if test_phase is not None:
        return slices

    # Pad with black images so every scan in a batch has the same number of slices.
    if len(slices) < max_slices:
        black_image = np.zeros_like(slices[0].pixel_array)
        slices += [black_image] * (max_slices - len(slices))
    return slices

# Dataset and DataLoader

## Splitting Data

In [12]:
def split_dataset(patient_scan_labels, test_size=TEST_SIZE, val_size=VALID_SIZE, random_state=42):
    """Return ``(train_labels, val_labels, test_labels)``.

    The stratified split is currently disabled: every row of ``patient_scan_labels`` is returned as
    the test set (the same object, not a copy), and the train and validation slots are ``None``.
    ``test_size``, ``val_size`` and ``random_state`` are accepted for compatibility but ignored.

    To re-enable it:
    1. Split off a ``test_size`` test set with ``train_test_split`` stratified on 'ICH'.
    2. Split the remainder with ``test_size=val_size / (1 - test_size)``, again stratified on 'ICH'.
    """
    train_labels = None
    val_labels = None
    return train_labels, val_labels, patient_scan_labels

## Processing the Data

In [13]:
def process_patient_data(dicom_dir, row, test_phase=None, num_instance=12, depth=5):
    """Load and preprocess every slice of one patient's scan.

    The scan folder is ``<dicom_dir>/<row['name']>/Unknown Study/<row['Source Folder']>``.
    ``num_instance`` and ``depth`` are currently unused.

    Returns:
        ``(slices, labels)``:
          - ``slices`` is a float32 tensor stacking the preprocessed slices along dim 0, and
            ``labels`` is always ``[]`` because per-slice labels are not produced.
          - If the folder does not exist, a message is printed and ``(None, None)`` is returned.
    """
    patient_folder_path = os.path.join(
        dicom_dir, str(row['name']), 'Unknown Study', str(row['Source Folder'])
    )
    if not os.path.exists(patient_folder_path):
        print(f'Folder {patient_folder_path} does not exist.')
        return None, None

    raw_slices = read_dicom_folder(patient_folder_path, test_phase=test_phase)
    slice_tensors = [torch.tensor(preprocess_slice(s), dtype=torch.float32) for s in raw_slices]
    # (num_slices, height, width[, channels])
    stacked = torch.stack(slice_tensors, dim=0)
    return stacked, []

## Dataset Generator

In [14]:
class CQ500Dataset:
    """Map-style dataset yielding one CQ500 scan per row of ``patient_scan_labels``.

    Each item is ``(slices, slice_labels, patient_label)``:
      - ``slices``: float32 tensor. A 3-D stack [num_slices, H, W] gets a channel axis and becomes
        [num_slices, 1, H, W]; other ranks are returned unchanged.
      - ``slice_labels``: the labels returned by ``process_patient_data`` (currently ``[]``).
      - ``patient_label``: the row's ``'ICH'`` value as a 0/1 ``torch.uint8`` scalar.

    Raises FileNotFoundError from ``__getitem__`` when a scan folder is missing.
    """

    # Forwarded to process_patient_data. None makes read_dicom_folder pad every scan to MAX_SLICES.
    TEST_PHASE = None

    def __init__(self, dicom_dir, patient_scan_labels):
        self.dicom_dir = dicom_dir
        self.patient_scan_labels = patient_scan_labels

    def _parse_patient_scan_labels(self, patient_scan_labels):
        """Parse stringified Python literals in the 'name' and 'Source File' columns (in place).

        NOTE(review): not called anywhere in the visible code. The CSV shown in this notebook has a
        'Source Folder' column rather than 'Source File', so confirm the intended column.
        """
        import ast

        # ast.literal_eval instead of eval: CSV content must never be executed as code.
        def parse(value):
            return ast.literal_eval(value) if isinstance(value, str) else value

        patient_scan_labels['name'] = patient_scan_labels['name'].apply(parse)
        patient_scan_labels['Source File'] = patient_scan_labels['Source File'].apply(parse)
        return patient_scan_labels

    def _process_patient_data(self, row):
        return process_patient_data(self.dicom_dir, row, self.TEST_PHASE)

    def __len__(self):
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        row = self.patient_scan_labels.iloc[idx]
        preprocessed_slices, labels = self._process_patient_data(row)
        if preprocessed_slices is None:
            # process_patient_data signals a missing scan folder with (None, None).
            raise FileNotFoundError(f"No DICOM folder found for patient {row.get('name')!r} (row {idx}).")

        preprocessed_slices = self._prepare_tensor(preprocessed_slices)
        patient_label = torch.tensor(bool(row['ICH']), dtype=torch.uint8)

        return preprocessed_slices, labels, patient_label

    def _prepare_tensor(self, preprocessed_slices):
        """Convert to a float32 tensor and add a channel axis to 3-D stacks."""
        # Tensors are converted directly; only non-tensors go through NumPy.
        if not torch.is_tensor(preprocessed_slices):
            preprocessed_slices = np.asarray(preprocessed_slices)
        preprocessed_slices = torch.as_tensor(preprocessed_slices, dtype=torch.float32)

        if preprocessed_slices.ndim == 3:
            preprocessed_slices = preprocessed_slices.unsqueeze(1)  # shape: [slices, 1, H, W]

        return preprocessed_slices


class CQ500Dataset_TestPhase(CQ500Dataset):
    """CQ500Dataset variant for the test phase.

    Scans are read with ``test_phase='TEST'``, so read_dicom_folder returns the sorted slices
    without padding and scans may differ in length. ``_parse_patient_scan_labels`` additionally
    casts the 'ICH' column to bool.
    """

    TEST_PHASE = 'TEST'

    def _parse_patient_scan_labels(self, patient_scan_labels):
        """Parse 'name' and 'Source File' like the base class, then cast 'ICH' to bool (in place)."""
        patient_scan_labels = super()._parse_patient_scan_labels(patient_scan_labels)
        patient_scan_labels['ICH'] = patient_scan_labels['ICH'].astype(bool)
        return patient_scan_labels

In [15]:
class TrainDatasetGenerator(CQ500Dataset):
    """Training dataset: plain CQ500Dataset behaviour (scans padded to a fixed slice count)."""

    def __init__(self, data_dir, patient_scan_labels):
        # Explicit so the public constructor parameter keeps the name `data_dir`.
        CQ500Dataset.__init__(self, data_dir, patient_scan_labels)

class TestDatasetGenerator(CQ500Dataset):
    """Evaluation dataset. Like training it uses plain CQ500Dataset behaviour (padded scans), not
    CQ500Dataset_TestPhase, so scans in a batch share one shape."""

    def __init__(self, data_dir, patient_scan_labels):
        # Explicit so the public constructor parameter keeps the name `data_dir`.
        CQ500Dataset.__init__(self, data_dir, patient_scan_labels)

In [16]:
# Smoke test: build the training dataset over every CSV row and load the first scan end to end
# (DICOM read -> windowing -> tensor).
original_dataset = TrainDatasetGenerator(DICOM_DIR, patient_scan_labels)
print(f'Length of Original Dataset: {len(original_dataset)}')

# x: slice tensor, y: slice labels (the empty list from process_patient_data), z: patient ICH label.
x,y,z = original_dataset[0]
# print(x.shape, y.shape, z.shape)  # NOTE(review): y is a list, so y.shape would fail.

Length of Original Dataset: 473


In [17]:
def get_train_loader(dicom_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE):
    """Shuffled training DataLoader over TrainDatasetGenerator.

    drop_last=True keeps every batch at exactly `batch_size` bags.
    """
    train_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

def get_test_loader(dicom_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE, shuffle=False, drop_last=True):
    """Evaluation DataLoader over TestDatasetGenerator.

    Args:
        shuffle: defaults to False so evaluation order (and the attention/ROC plots) is
            reproducible. Shuffling never helps evaluation.
        drop_last: kept True by default for backward compatibility. Note that it skips the final
            incomplete batch, e.g. with 473 rows and batch_size=2 one scan is never evaluated.
            Combined with shuffling, that made the skipped scan random on every run.
    """
    test_dataset = TestDatasetGenerator(dicom_dir, patient_scan_labels)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True, drop_last=drop_last)

# Contrastive Loss and Augmentation



## NTXentLoss

In [18]:
class NTXentLoss(losses.NTXentLoss):
    """Contrastive loss over two augmented views of the same batch.

    * Without labels: NT-Xent (SimCLR style). Sample i of view 1 and sample i of view 2 form the only
      positive pair.
    * With labels: supervised contrastive loss (pml SupConLoss). All samples sharing a label are
      positives.

    Both pytorch-metric-learning losses compute the (cosine) similarity matrix and apply the
    temperature themselves, so they receive the raw embeddings, not a precomputed logits matrix.
    """

    def __init__(self, temperature, **kwargs):
        super().__init__(temperature=temperature, **kwargs)
        self.temperature = temperature
        # Built once rather than on every forward call.
        self.supcon = losses.SupConLoss(temperature=temperature)

    def forward(self, embeddings1, embeddings2, labels=None, hard_pairs=None):
        """
        Args:
            embeddings1, embeddings2: projections of the two views, assumed (batch, dim) — TODO confirm.
            labels: optional per-sample labels (length batch); duplicated for the second view.
            hard_pairs: optional pml indices tuple (e.g. from a miner) indexing into the
                concatenated [view1; view2] rows.

        Returns:
            Scalar loss tensor.
        """
        embeddings = torch.cat([embeddings1, embeddings2], dim=0)

        if labels is None:
            # Rows i and i + batch are two views of the same sample, so they share a pseudo-label.
            batch_size = embeddings1.size(0)
            pseudo_labels = torch.arange(batch_size, device=embeddings.device).repeat(2)
            return super().forward(embeddings, pseudo_labels, indices_tuple=hard_pairs)

        pair_labels = torch.cat([labels, labels], dim=0).reshape(-1)
        return self.supcon(embeddings, pair_labels, indices_tuple=hard_pairs)

## Miner

In [19]:
from pytorch_metric_learning.miners import BaseMiner
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

class ExamplePairMiner(BaseMiner):
    """Mine hard pairs relative to a fixed margin.

    Keeps positive pairs that are "far apart" and negative pairs that are "close". Whether "far"
    means a large or a small value depends on ``self.distance.is_inverted`` (a similarity rather
    than a distance).

    Returns pml's 4-tuple of index tensors ``(anchor_pos, positive, anchor_neg, negative)``.
    """

    def __init__(self, margin=0.1, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def mine(self, embeddings, labels, ref_emb, ref_labels):
        dist_matrix = self.distance(embeddings, ref_emb)
        anchor_pos, positive, anchor_neg, negative = lmu.get_all_pairs_indices(labels, ref_labels)
        pos_values = dist_matrix[anchor_pos, positive]
        neg_values = dist_matrix[anchor_neg, negative]

        if self.distance.is_inverted:
            # Similarity: low similarity is a hard positive, high similarity is a hard negative.
            keep_pos = pos_values < self.margin
            keep_neg = neg_values > self.margin
        else:
            # Distance: large distance is a hard positive, small distance is a hard negative.
            keep_pos = pos_values > self.margin
            keep_neg = neg_values < self.margin

        return anchor_pos[keep_pos], positive[keep_pos], anchor_neg[keep_neg], negative[keep_neg]

## Augmentation for Contrastive Learning

In [20]:
# Version 2: Avg time taken: 0.05 seconds for 1 augmentation (w ResizedCrop)
def augment_batch(batch_images):
    """Return an independently augmented copy of every instance (slice) in every bag.

    Each instance gets its own random resized crop (back to its original H x W, scale 0.8-1.1),
    colour jitter (applied with p=0.6) and horizontal flip (p=0.5).

    Args:
        batch_images: (batch, instances, C, H, W) when CHANNELS == 1. Otherwise channels-last
            (batch, instances, H, W, C), which the per-instance permute below expects.

    Returns:
        Tensor with the same shape and dtype, moved to the global ``device`` (CUDA when available).
    """
    batch_size, num_instances = batch_images.shape[:2]
    # Crop back to the input's own spatial size so results fit the preallocated output tensor.
    if CHANNELS == 1:
        out_size = tuple(batch_images.shape[-2:])
    else:
        out_size = tuple(batch_images.shape[2:4])

    aug_transform = transforms.Compose([
        transforms.RandomResizedCrop(out_size, scale=(0.8, 1.1)),
        # transforms.RandomRotation(degrees=(-5, 5)),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4)], p=0.6),
        transforms.RandomHorizontalFlip(p=0.5),
    ])

    # Preallocate the output; transforms run directly on tensors (no PIL round trip).
    augmented_batch = torch.empty_like(batch_images)

    for i in range(batch_size):
        for j in range(num_instances):
            if CHANNELS == 1:
                augmented_batch[i, j] = aug_transform(batch_images[i, j])
            else:
                # torchvision transforms expect channels-first; convert there and back.
                augmented_batch[i, j] = aug_transform(batch_images[i, j].permute(2, 0, 1)).permute(1, 2, 0)

    # Use the notebook's `device` rather than a hard-coded .cuda(), so CPU-only machines work too.
    return augmented_batch.to(device)

# CNN Feature Extractor

## Utilities

In [21]:
# import utils.attention as AttentionLayer
# import utils.gaussian_process as GPModel
import utils.attention as AttentionLayer
import utils.gaussian_process as GPModel

## ResNet18

In [22]:
# import models.mil_resnet as MILResNet18
#
# model_params = {
#     'channels': CHANNELS,
#     'projection_location': 'after_gp',
#     'projection_hidden_dim': 256,
#     'projection_output_dim': 128
# }
#
# model = MILResNet18.MILResNet18(params=model_params)


class GPModel(gpytorch.models.ApproximateGP):
    """Sparse variational GP with constant mean and a scaled RBF kernel.

    Note: this class shadows the earlier ``import utils.gaussian_process as GPModel``.

    Args:
        inducing_points: initial inducing locations, presumably (M, D). M = ``size(0)``. Locations
            are learned during training.
    """

    def __init__(self, inducing_points):
        num_inducing = inducing_points.size(0)
        # gpytorch's SVGP pattern: the strategy is built (referencing self) before Module init.
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(num_inducing)
        strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        """Prior MultivariateNormal at inputs ``x``."""
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

class MILResNet18(nn.Module):
    """Multiple-instance bag classifier: ResNet-18 slice encoder, gated attention pooling and a sparse GP.

    forward(bags) returns ``(outputs, att_outputs, attention_weights, gp_output, projection)``:
      - outputs: sigmoid probability per bag from [attended features, GP mean], shape (batch, 1).
      - att_outputs: sigmoid probability per bag from the attended features alone, shape (batch, 1).
      - attention_weights: the weights returned by AttentionLayer.GatedAttention.
      - gp_output: the GP layer's MultivariateNormal over the pooled bag features.
      - projection: (batch, 128) projection-head output used by the contrastive loss.
    """

    def __init__(self):
        super(MILResNet18, self).__init__()
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Replace the stem to accept CHANNELS input channels. This conv is freshly initialised.
        self.resnet.conv1 = nn.Conv2d(in_channels=CHANNELS, out_channels=64, kernel_size=7, stride=2, padding=3,
                                      bias=False)
        # Expose the 512-d pooled features instead of ImageNet logits.
        self.resnet.fc = nn.Identity()

        self.attention = AttentionLayer.GatedAttention(input_dim=512, hidden_dim=512)

        # The extra input is the GP mean concatenated to the 512-d attended features.
        self.classifier = nn.Linear(512 + 1, 1)
        self.attention_classifier = nn.Linear(512, 1)
        self.dropout = nn.Dropout(0.4)

        # 32 inducing points are hard-coded (INDUCING_POINTS is not used). Changing the count changes
        # state_dict shapes and breaks loading existing checkpoints.
        inducing_points = torch.randn(32, 512)
        self.gp_layer = GPModel(inducing_points=inducing_points)

        self.projection_head = nn.Sequential(
            nn.Linear(512 + 1, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, bags):
        if CHANNELS == 1:
            batch_size, num_instances, c, h, w = bags.size()
        else:
            batch_size, num_instances, h, w, c = bags.size()
            # Channels-last input: move channels in front of H and W. A plain .view into
            # (N, C, H, W) would reinterpret memory and scramble pixels across channels.
            bags = bags.permute(0, 1, 4, 2, 3)

        # reshape rather than view: the permuted tensor is not contiguous.
        bags_flattened = bags.reshape(batch_size * num_instances, c, h, w)

        # CNN (ResNet-18) features per instance, regrouped per bag.
        features = self.resnet(bags_flattened)
        features = self.dropout(features)
        features = features.view(batch_size, num_instances, -1)

        # Attention pooling over instances. Presumably attended_features is (batch, 512) — TODO confirm.
        attended_features, attended_weights = self.attention(features)
        attended_features_reshaped = attended_features.view(batch_size, -1)

        # GP over the pooled bag features; its mean becomes one extra feature per bag.
        gp_output = self.gp_layer(attended_features_reshaped)
        gp_mean = gp_output.mean.view(batch_size, -1)

        # The flattened features are used so the concatenation also works if attention keeps an
        # extra singleton dimension.
        combine_features = torch.cat((attended_features_reshaped, gp_mean), dim=1)

        projection_head = self.projection_head(combine_features)

        combine_features = self.dropout(combine_features)

        outputs = torch.sigmoid(self.classifier(combine_features))
        att_outputs = torch.sigmoid(self.attention_classifier(attended_features_reshaped))

        return outputs, att_outputs, attended_weights, gp_output, projection_head

# Training and Evaluation

## Loss Functions

In [23]:
def combined_loss(outputs, gp_distribution, target, alpha=0.5):
    """Weighted sum of BCE on sigmoid probabilities and the GP layer's variational KL term.

    Args:
        outputs: probabilities in [0, 1], one per sample (e.g. shape (batch, 1) or (batch,)).
        gp_distribution: despite the name, the GP *module*. Anything exposing
            ``variational_strategy.kl_divergence()`` works, e.g. ``model.gp_layer``.
        target: 0/1 labels, one per sample.
        alpha: weight of the KL term; the BCE term gets ``1 - alpha``.

    Returns:
        Scalar tensor ``(1 - alpha) * BCE + alpha * KL``.
    """
    # reshape(-1) rather than squeeze(): squeeze() turns a batch of one into a 0-d tensor whose
    # shape no longer matches `target`, which binary cross-entropy rejects.
    bce_loss = F.binary_cross_entropy(outputs.reshape(-1), target.reshape(-1).float())
    kl_divergence = gp_distribution.variational_strategy.kl_divergence()
    return (1 - alpha) * bce_loss + alpha * kl_divergence

## Training

In [24]:
def train_epoch(model, data_loader, criterion_cl, criterion_bce, optimizer, scheduler, device):
    """Run one epoch of two-view contrastive + BCE/KL training for a MILResNet18.

    Per batch:
    1. Augment the bags twice and run both views through the model.
    2. Mine hard pairs on the two views' outputs.
    3. Compute ``0.7 * criterion_cl(proj1, proj2, labels, hard_pairs)`` plus
       ``0.15 * combined_loss`` for each view.
    4. Step the optimizer, then the (per-batch) scheduler.

    Args:
        criterion_bce: accepted for interface compatibility; unused (BCE lives in combined_loss).

    Returns:
        ``(mean loss per batch, thresholded view-1 predictions, patient labels)``.

    Raises:
        TypeError: if ``model`` is not a MILResNet18 (no other training path is implemented).
    """
    # MILResNet18 is the class defined in this notebook. `MILResNet18.MILResNet18` (the old
    # module-style reference) raised AttributeError.
    if not isinstance(model, MILResNet18):
        raise TypeError(f"train_epoch only supports MILResNet18, got {type(model).__name__}")

    model.train()
    miner_func = ExamplePairMiner()  # built once per epoch instead of once per batch
    total_loss = 0.0
    predictions = []
    labels = []

    for batch_data, _, batch_patient_labels in data_loader:
        batch_data = batch_data.to(device)
        batch_patient_labels = batch_patient_labels.float().to(device)
        optimizer.zero_grad()

        # Two independently augmented views of the same bags.
        aug1 = augment_batch(batch_data).to(device)
        aug2 = augment_batch(batch_data).to(device)

        output_1, _, _, _, proj_head_1 = model(aug1)
        output_2, _, _, _, proj_head_2 = model(aug2)

        # Indices span the 2*batch [view1; view2] concatenation, matching how NTXentLoss above
        # concatenates its two inputs.
        doubled_labels = torch.cat([batch_patient_labels, batch_patient_labels], dim=0)
        hard_pairs = miner_func(torch.cat([output_1, output_2], dim=0), doubled_labels)

        ntx_loss = criterion_cl(proj_head_1, proj_head_2, batch_patient_labels, hard_pairs)
        loss_mod_1 = combined_loss(output_1, model.gp_layer, batch_patient_labels)
        loss_mod_2 = combined_loss(output_2, model.gp_layer, batch_patient_labels)
        loss = (ntx_loss * 0.7 + loss_mod_1 * 0.15 + loss_mod_2 * 0.15).mean()

        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        # reshape(-1): squeeze() would give a 0-d array for a single-sample batch.
        predictions.extend((output_1.detach().reshape(-1) > THRESHOLD).cpu().numpy())
        labels.extend(batch_patient_labels.cpu().numpy())

    return total_loss / len(data_loader), predictions, labels

def validate(model, data_loader, criterion_cl, criterion_bce, device):
    """Evaluate the model without gradients.

    For MILResNet18, the loss is the BCE between the bag probabilities and the patient labels,
    averaged over batches. Other model types contribute no loss term (their outputs are only
    thresholded).

    Args:
        criterion_cl, criterion_bce: accepted for interface compatibility; unused.

    Returns:
        ``(mean loss per batch, thresholded predictions, patient labels)``.
    """
    model.eval()
    total_loss = 0.0
    predictions = []
    labels = []

    with torch.no_grad():
        for batch_data, _, batch_patient_labels in data_loader:
            batch_data = batch_data.to(device)
            batch_patient_labels = batch_patient_labels.float().to(device)

            # MILResNet18 is the class itself (not a module), so isinstance takes it directly.
            if isinstance(model, MILResNet18):
                output, _, _, _, _ = model(batch_data)
                # reshape(-1): squeeze() would give a 0-d tensor for a single-sample batch.
                probs = output.reshape(-1)
                total_loss += F.binary_cross_entropy(probs, batch_patient_labels.reshape(-1)).item()
                predictions.extend((probs > THRESHOLD).cpu().numpy())
            else:
                z_i, output, predicted_bags, _, _, gp_combine = model(batch_data)
                predictions.extend((predicted_bags.squeeze() > THRESHOLD).cpu().detach().numpy())

            labels.extend(batch_patient_labels.cpu().numpy())

    return total_loss / len(data_loader), predictions, labels

def calculate_metrics(predictions, labels):
    """Return binary accuracy, precision, recall and F1 of ``predictions`` against ``labels``.

    Keys: 'accuracy', 'precision', 'recall', 'f1'. Each value comes from the sklearn scorer of the
    same name, called as ``scorer(labels, predictions)``.
    """
    scorers = {
        "accuracy": accuracy_score,
        "precision": precision_score,
        "recall": recall_score,
        "f1": f1_score,
    }
    return {name: scorer(labels, predictions) for name, scorer in scorers.items()}

def print_epoch_stats(epoch, num_epochs, phase, loss, metrics):
    """Print a two-line epoch summary.

    The header shows the 1-based epoch number and capitalised phase; the second line shows loss and
    metrics, each to 4 decimals.
    """
    print("Epoch {}/{} - {}:".format(epoch + 1, num_epochs, phase.capitalize()))
    fields = [
        f"Loss: {loss:.4f}",
        f"Accuracy: {metrics['accuracy']:.4f}",
        f"Precision: {metrics['precision']:.4f}",
        f"Recall: {metrics['recall']:.4f}",
        f"F1: {metrics['f1']:.4f}",
    ]
    print(", ".join(fields))

def train_model(model, train_loader, val_loader, criterion_cl, criterion_bce, optimizer, num_epochs, learning_rate, device='cuda'):
    """Train with a per-batch OneCycleLR schedule and return the model with its best weights.

    "Best" means the highest validation accuracy; ties keep the earliest epoch. Losses and metrics
    are logged to W&B under ``train/*`` and ``val/*`` keys.

    Args:
        learning_rate: OneCycleLR's ``max_lr``.

    Returns:
        ``model`` (moved to ``device``) with the best epoch's weights loaded.
    """
    model = model.to(device)
    model.train()

    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate,
                                              steps_per_epoch=len(train_loader), epochs=num_epochs)
    # -inf guarantees the first epoch is snapshotted, so load_state_dict never receives None.
    best_val_accuracy = float('-inf')
    best_model_state = None

    for epoch in range(num_epochs):
        # Training phase
        train_loss, train_predictions, train_labels = train_epoch(model, train_loader, criterion_cl, criterion_bce, optimizer, scheduler, device)
        train_metrics = calculate_metrics(train_predictions, train_labels)
        print_epoch_stats(epoch, num_epochs, "train", train_loss, train_metrics)
        # train_loss is already a per-batch mean, so it is not divided again. Metric keys are
        # prefixed so validation values do not overwrite training ones.
        wandb.log({"train/loss": train_loss, **{f"train/{k}": v for k, v in train_metrics.items()}})

        # Validation phase
        val_loss, val_predictions, val_labels = validate(model, val_loader, criterion_cl, criterion_bce, device)
        val_metrics = calculate_metrics(val_predictions, val_labels)
        print_epoch_stats(epoch, num_epochs, "validation", val_loss, val_metrics)
        wandb.log({"val/loss": val_loss, **{f"val/{k}": v for k, v in val_metrics.items()}})

        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            # state_dict() returns references to the live parameters, so a copy is needed or later
            # epochs would overwrite the "best" snapshot. The copy goes to the CPU to spare GPU memory.
            best_model_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    # NOTE(review): no files are added to this artifact; only its metadata is logged.
    wandb.log_artifact(wandb.Artifact("best_model", type="model", metadata={"accuracy": best_val_accuracy}))

    return model

## Evaluation

In [25]:
## Model Evaluation Functions
def evaluate_model(model, data_loader, device='cuda'):
    """Return ``(predictions, labels)`` as NumPy arrays for every batch in ``data_loader``.

    Predictions are ``probability > THRESHOLD`` booleans. For MILResNet18 the probability is its
    first output; for other models it is the third output (``predicted_bags``).
    """
    model = model.to(device)
    model.eval()

    predictions = []
    labels = []

    with torch.inference_mode():
        for batch_data, _, batch_patient_labels in data_loader:
            batch_data = batch_data.to(device)

            if isinstance(model, MILResNet18):
                output, _, _, _, _ = model(batch_data)
                # MILResNet18 outputs (batch, 1). reshape(-1) keeps a 1-d array even for a batch of
                # one, where squeeze() would give a 0-d array that list.extend() cannot iterate.
                predictions.extend((output.reshape(-1) > THRESHOLD).cpu().numpy())
            else:
                z_i, outputs, predicted_bags, _, _, gp_combine = model(batch_data)
                predictions.extend((predicted_bags.squeeze() > THRESHOLD).cpu().numpy())

            # Labels are only needed on the host; no device round trip.
            labels.extend(batch_patient_labels.cpu().numpy())

    return np.array(predictions), np.array(labels)

def print_metrics(metrics):
    """Print test accuracy, precision, recall and F1 on one comma-separated line (4 decimals)."""
    pieces = [f"Test Accuracy: {metrics['accuracy']:.4f}"]
    pieces += [f"{key.capitalize()}: {metrics[key]:.4f}" for key in ("precision", "recall")]
    pieces.append(f"F1: {metrics['f1']:.4f}")
    print(", ".join(pieces))

# Visualizations

In [26]:
## Visualization Functions
def plot_roc_curve(model, data_loader, device):
    """Plot the ROC curve (with AUC) of the bag probabilities against patient labels.

    Assumes a MILResNet18-style model whose first output is a (batch, 1) probability tensor.
    """
    # Match evaluate_model: make sure the model lives on the device the batches are sent to.
    model = model.to(device)
    model.eval()
    labels = []
    predictions = []

    with torch.no_grad():
        for batch_data, _, batch_patient_labels in data_loader:
            output, _, _, _, _ = model(batch_data.to(device))
            # reshape(-1) rather than squeeze(): a single-sample batch would otherwise become a
            # 0-d array that list.extend() cannot iterate.
            predictions.extend(output.reshape(-1).cpu().numpy())
            labels.extend(batch_patient_labels.float().cpu().numpy())

    fpr, tpr, _ = roc_curve(labels, predictions)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    plt.show()

def plot_confusion_matrix(model, data_loader, device):
    """Run evaluate_model and display the confusion matrix of its thresholded predictions."""
    predictions, labels = evaluate_model(model, data_loader, device)

    ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(labels, predictions)).plot()
    plt.title('Confusion Matrix')
    plt.show()

## Attention Weights and Images

In [27]:
def plot_label_attention_weights(model, data_loader, device='cuda', rows=10, cols=6):
    """
    Plot the slices of the first ICH-positive patient with each slice's label and attention weight.

    Parameters:
    - model: trained model whose forward returns (outputs, _, attention_weights, _, _)
    - data_loader: yields (bags, slice_labels, patient_labels)
    - device: device to run the model on ('cuda' or 'cpu')
    - rows, cols: grid size. At most rows * cols slices (the first ones of the bag) are drawn,
      because a subplot index beyond rows * cols raises.

    Uses the global CHANNELS to choose between RGB and grayscale display.

    Expected shapes:
    - 1-channel bags: (batch_size, num_images, 1, H, W) or (batch_size, num_images, H, W)
    - 3-channel bags: (batch_size, num_images, H, W, 3)
    - attention: one value per image, or one value per patient
    """
    model = model.to(device)
    model.eval()
    max_images = rows * cols

    with torch.no_grad():
        for batch_data, batch_labels, batch_patients_label in data_loader:
            batch_data = batch_data.to(device)
            _, _, attention_weight_batch, _, _ = model(batch_data)

            # Per-slice labels are optional: process_patient_data currently yields an empty list,
            # which cannot be indexed as [patient, slice].
            has_slice_labels = torch.is_tensor(batch_labels) and batch_labels.dim() >= 2
            num_images = min(batch_data.size(1), max_images)

            for patient_idx in range(batch_data.size(0)):
                if batch_patients_label[patient_idx].item() != 1:  # only ICH-positive patients
                    continue

                # Extra height leaves room for the suptitle.
                plt.figure(figsize=(cols * 4, rows * 4 + 2))

                for img_idx in range(num_images):
                    img = batch_data[patient_idx, img_idx].cpu().numpy()
                    img_label = batch_labels[patient_idx, img_idx].cpu().numpy() if has_slice_labels else 'n/a'

                    # Per-image attention when it has one entry per slice, otherwise a per-patient value.
                    if attention_weight_batch.size(1) == batch_data.size(1):
                        attention_value = attention_weight_batch[patient_idx, img_idx].cpu().item()
                    else:
                        attention_value = attention_weight_batch[patient_idx].cpu().item()

                    plt.subplot(rows, cols, img_idx + 1)
                    if CHANNELS == 3:  # RGB image
                        plt.imshow(img)
                    else:  # Grayscale image
                        if img.ndim == 3:  # (1, H, W) -> (H, W)
                            img = np.squeeze(img)
                        plt.imshow(img, cmap='gray')

                    plt.title(f'Label: {img_label}\nAttention: {attention_value:.4f}', fontsize=12)
                    plt.axis('off')

                plt.suptitle(f'Patient Images (Patient Label: {batch_patients_label[patient_idx].cpu().numpy()})', fontsize=16)
                plt.tight_layout(rect=[0, 0, 1, 0.97])  # leave space for the suptitle
                plt.show()

                # Only the first positive patient is plotted.
                return

## Visualization Augmented Bags

In [28]:
def _bag_instance_to_image(instance):
    """Convert one bag instance tensor into an array that ``imshow`` can draw.

    Singleton dimensions are squeezed away, so (1, H, W) becomes (H, W); a
    3- or 4-channel (C, H, W) image is transposed to (H, W, C) because
    matplotlib expects channels last.
    """
    img = instance.detach().cpu().numpy().squeeze()
    if img.ndim == 3 and img.shape[0] in (3, 4):
        img = np.transpose(img, (1, 2, 0))  # CHW -> HWC
    return img


def visualize_augmented_bags(original_bags, augmented_bags, num_bags=12):
    """
    Show every instance of the first bag side by side: original (left) vs. augmented (right).

    Parameters:
    - original_bags: A tensor of shape (batch_size, num_instances, channels, height, width)
    - augmented_bags: A tensor with the same leading dimensions as ``original_bags``
    - num_bags: Kept for backward compatibility; only the first bag is ever shown.

    Single-channel instances are drawn with a gray colormap; 3/4-channel
    instances are converted to channels-last so they render in colour.
    """
    first_bag_index = 0
    num_instances = original_bags.size(1)

    print(f'Num instances: {num_instances}')

    # squeeze=False keeps `axes` 2-D even when num_instances == 1; with the
    # default squeeze=True a single row comes back 1-D and axes[j, 0] fails.
    fig, axes = plt.subplots(num_instances, 2, figsize=(10, 2 * num_instances), squeeze=False)

    columns = ((original_bags, 'Original'), (augmented_bags, 'Augmented'))
    for j in range(num_instances):
        for col, (bags, label) in enumerate(columns):
            ax = axes[j, col]
            ax.imshow(_bag_instance_to_image(bags[first_bag_index][j]), cmap='gray')
            ax.axis('off')  # Hide axes for a cleaner image grid
            ax.set_title(f'{label} Instance {j + 1}')

    plt.tight_layout()
    plt.show()


# Model Loading

In [29]:
# Model loading and evaluation helpers
def load_model(model_class, model_path, allow_random_init=True):
    """Instantiate ``model_class`` and load its trained weights from ``model_path``.

    Parameters:
    - model_class: zero-argument callable that builds the model (e.g. an nn.Module subclass).
    - model_path: path to a state_dict saved with ``torch.save``.
    - allow_random_init: if True (default, the historical behaviour), a checkpoint that
      cannot be read, is empty, or does not match the model is reported and a freshly
      initialised model is returned instead; if False the error is re-raised.

    Returns:
    - The model in evaluation mode (on both the success and the fallback path).

    Raises:
    - FileNotFoundError: if ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    model = model_class()
    try:
        # Map to CPU: load_state_dict copies the tensors into the model's existing
        # parameters, so the checkpoint's device is irrelevant. The previous hard-coded
        # 'cuda' made loading fail on CPU-only hosts and silently fall back to random weights.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        if not state_dict:
            raise ValueError(f"The state dictionary loaded from {model_path} is empty")
        model.load_state_dict(state_dict)
    except Exception as e:
        if not allow_random_init:
            raise
        print(f"Error loading model from {model_path}: {str(e)}")
        print("Initializing model with random weights instead.")
        # load_state_dict may have copied some tensors before raising; rebuild so the
        # fallback really is a clean random initialisation.
        model = model_class()

    # Evaluation mode on every path so callers get consistent dropout/batch-norm behaviour.
    return model.eval()


def get_test_results(model, test_loader, test_labels, device=DEVICE):
    """Attach the model's per-patient predictions to the rows of ``test_labels``.

    Parameters:
    - model: trained model passed through to ``evaluate_model``.
    - test_loader: loader over the test set; assumed to yield patients in the same
      order as ``test_labels`` (i.e. no shuffling) — TODO confirm in get_test_loader.
    - test_labels: DataFrame with one row per test patient.
    - device: device used for evaluation.

    Returns:
    - A new DataFrame with every column of ``test_labels`` (index reset to 0..n-1)
      plus a ``prediction`` column. ``test_labels`` itself is not modified.

    Raises:
    - ValueError: if the number of predictions differs from the number of rows, which
      would otherwise silently misalign predictions and patients.
    """
    predictions, _ = evaluate_model(model, test_loader, device)

    if len(predictions) != len(test_labels):
        raise ValueError(
            f"Got {len(predictions)} predictions for {len(test_labels)} test rows; "
            "predictions cannot be aligned with patients."
        )

    # Column-wise instead of itertuples()+getattr(): itertuples renames columns that are
    # not valid identifiers (or start with '_'), which made getattr raise AttributeError.
    return test_labels.reset_index(drop=True).assign(prediction=list(predictions))

# Main Function

In [30]:
def main(mode='train'):
    """Run the CQ500 MIL pipeline end to end and report test-set metrics.

    Steps: split the labels, build the test loader, load the checkpoint at
    MODEL_PATH, evaluate it, then log and print the metrics. In 'train' mode,
    additionally draw the ROC curve and confusion matrix.

    Parameters:
    - mode: 'train' or 'test'. The training call itself is currently commented out,
      so both modes evaluate the saved checkpoint; 'train' only adds the plots.

    Raises:
    - ValueError: if ``mode`` is not 'train' or 'test'.
    """
    if mode not in ('train', 'test'):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    os.environ["WANDB_DISABLED"] = "true"

    # W&B is initialised in disabled mode, so nothing is uploaded.
    wandb.init(project="CQ500", entity="milresnet", mode="disabled")
    try:
        # Record hyperparameters on the run config.
        config = wandb.config
        config.learning_rate = LEARNING_RATE
        config.batch_size = TRAIN_BATCH_SIZE
        config.num_epochs = NUM_EPOCHS

        train_labels, val_labels, test_labels = split_dataset(patient_scan_labels, test_size=TEST_SIZE)
        test_loader = get_test_loader(DICOM_DIR, test_labels, batch_size=TEST_BATCH_SIZE)

        # Training is disabled. To retrain and overwrite MODEL_PATH, uncomment:
        # if mode == 'train':
        #     train_loader = get_train_loader(DICOM_DIR, train_labels, batch_size=TRAIN_BATCH_SIZE)
        #     val_loader = get_train_loader(DICOM_DIR, val_labels, batch_size=VALID_BATCH_SIZE)
        #     model = MILResNet18()
        #     criterion_cl = NTXentLoss(0.5)
        #     criterion_bce = torch.nn.BCELoss()
        #     optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
        #     wandb.watch(model)
        #     trained_model = train_model(model, train_loader, val_loader, criterion_cl, criterion_bce,
        #                                 optimizer, config.num_epochs, config.learning_rate, DEVICE)
        #     torch.save(trained_model.state_dict(), MODEL_PATH)

        trained_model = load_model(MILResNet18, MODEL_PATH)

        predictions, labels = evaluate_model(trained_model, test_loader, DEVICE)
        metrics = calculate_metrics(predictions, labels)
        wandb.log(metrics)
        print_metrics(metrics)

        if mode == 'train':
            plot_roc_curve(trained_model, test_loader, DEVICE)
            plot_confusion_matrix(trained_model, test_loader, DEVICE)
            # Optional per-patient export:
            # results_df = get_test_results(trained_model, test_loader,
            #                               test_labels[['patient_id', 'study_instance_uid', 'patient_label']],
            #                               DEVICE)
            # results_df.to_csv('results/results.csv', index=False)
            # wandb.log({"results": wandb.Table(dataframe=results_df)})

        # Optional diagnostics (the augmentation view needs a train_loader batch `images`):
        # plot_label_attention_weights(trained_model, test_loader, DEVICE)
        # visualize_augmented_bags(images, augment_batch(images))
    finally:
        # Close the W&B run even if evaluation or plotting raises.
        wandb.finish()

if __name__ == '__main__':
    # Evaluate the saved checkpoint on the held-out test split (no plots in 'test' mode).
    main('test')

Test Accuracy: 0.7627, Precision: 0.7249, Recall: 0.6954, F1: 0.7098


# Result
