## Library

In [1]:
from logging import critical

from pygments.unistring import combine
!pip install gpytorch
!pip install wandb



In [2]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pydicom
import cv2
from tqdm import tqdm
from skimage.transform import resize
from skimage.transform import rotate
from scipy.ndimage import gaussian_filter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import torchvision.models as models
from torchvision.transforms import v2 as transforms

import wandb

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay

import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy

import time

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [3]:
import warnings
from sklearn.exceptions import UndefinedMetricWarning

warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

## Init GPU

In [4]:
# Initialize GPU Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

print(device)

GPU: NVIDIA GeForce RTX 4070 SUPER is available.
cuda


In [5]:
%load_ext autoreload
%autoreload 2

## Config Info

In [6]:
# Constants
HEIGHT = 224
WIDTH = 224
CHANNELS = 3

TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2
TEST_BATCH_SIZE = 2
TEST_SIZE = 0.15
VALID_SIZE = 0.15

MAX_SLICES = 60
SHAPE = (HEIGHT, WIDTH, CHANNELS)

NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
INDUCING_POINTS = 128

# TARGET_LABELS = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
TARGET_LABELS = ['intraparenchymal']

MODEL_PATH = 'results/trained_model.pth'
DEVICE = 'cuda'

In [7]:
# Kaggle and local switch
KAGGLE = os.path.exists('/kaggle')
print("Running on Kaggle" if KAGGLE else "Running locally")
ROOT_DIR = '/kaggle/input/rsna-mil-training/' if KAGGLE else None
# DATA_DIR = ROOT_DIR + 'rsna-mil-training/' if KAGGLE else '../rsna-mil-training/'
DATA_DIR = ROOT_DIR + 'rsna-mil-training/' if KAGGLE else '../rsna-ich-mil/'
DICOM_DIR = DATA_DIR
# CSV_PATH = DICOM_DIR + 'training_1000_scan_subset.csv' if KAGGLE else './data_analyze/training_1000_scan_subset.csv'
CSV_PATH = DICOM_DIR + 'training_1000_scan_subset.csv' if KAGGLE else './data_analyze/training_dataset.csv'
# SLICE_LABEL_PATH = ROOT_DIR + "sorted_training_dataset_with_labels.csv" if KAGGLE else './data_analyze/sorted_training_dataset_with_labels.csv'

dicom_dir = DICOM_DIR if KAGGLE else DATA_DIR
# Load patient scan labels
# patient_scan_labels = pd.read_csv(CSV_PATH)
patient_scan_labels = pd.read_csv(CSV_PATH, nrows=1150)

# patient_slice_labels = pd.read_csv(SLICE_LABEL_PATH)
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

Running locally


In [8]:
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# key = user_secrets.get_secret("Wandb key")
# 
# wandb.login(key=key, relogin=True)

In [9]:
patient_scan_labels.head()

Unnamed: 0,filename,any,epidural,intraparenchymal,intraventricular,subarachnoid,subdural,patient_id,study_instance_uid,series_instance_uid,image_position,samples_per_pixel,pixel_spacing,pixel_representation,window_center,window_width,rescale_intercept,rescale_slope,ID,patient_label
0,ID_45785016b.dcm,0,0,0,0,0,0,ID_0002cd41,ID_66929e09d4,ID_e22a5534e6,"[-125.000, -122.596, 35.968]",1,"[0.488281, 0.488281]",1,30,80,-1024.0,1.0,45785016b,0
1,ID_37f32aed2.dcm,0,0,0,0,0,0,ID_0002cd41,ID_66929e09d4,ID_e22a5534e6,"[-125.000, -122.596, 38.484]",1,"[0.488281, 0.488281]",1,30,80,-1024.0,1.0,37f32aed2,0
2,ID_1b9de2922.dcm,0,0,0,0,0,0,ID_0002cd41,ID_66929e09d4,ID_e22a5534e6,"[-125.000, -122.596, 41.000]",1,"[0.488281, 0.488281]",1,30,80,-1024.0,1.0,1b9de2922,0
3,ID_d61a6a7b9.dcm,0,0,0,0,0,0,ID_0002cd41,ID_66929e09d4,ID_e22a5534e6,"[-125.000, -122.596, 43.517]",1,"[0.488281, 0.488281]",1,30,80,-1024.0,1.0,d61a6a7b9,0
4,ID_406c82112.dcm,0,0,0,0,0,0,ID_0002cd41,ID_66929e09d4,ID_e22a5534e6,"[-125.000, -122.596, 46.033]",1,"[0.488281, 0.488281]",1,30,80,-1024.0,1.0,406c82112,0


## Data Preprocessing

In [10]:
def correct_dcm(dcm):
    x = dcm.pixel_array + 1000
    px_mode = 4096
    x[x>=px_mode] = x[x>=px_mode] - px_mode
    dcm.PixelData = x.tobytes()
    dcm.RescaleIntercept = -1000

def window_image(dcm, window_center, window_width):    
    if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):
        correct_dcm(dcm)
    img = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept
    
    # Resize
    img = cv2.resize(img, SHAPE[:2], interpolation = cv2.INTER_LINEAR)
   
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    img = np.clip(img, img_min, img_max)
    return img

def bsb_window(dcm):
    brain_img = window_image(dcm, 40, 80)
    subdural_img = window_image(dcm, 80, 200)
    soft_img = window_image(dcm, 40, 380)
    
    brain_img = (brain_img - 0) / 80
    subdural_img = (subdural_img - (-20)) / 200
    soft_img = (soft_img - (-150)) / 380

    if CHANNELS == 3:
        bsb_img = np.stack([brain_img, subdural_img, soft_img], axis=-1)
    else:
        bsb_img = brain_img
    return bsb_img.astype(np.float16)

In [11]:
def preprocess_slice(slice, target_size=(HEIGHT, WIDTH)):
    # Check if type of slice is dicom or an empty numpy array
    if (type(slice) == np.ndarray):
        slice = resize(slice, target_size, anti_aliasing=True)
        multichannel_slice = np.stack([slice, slice, slice], axis=-1)
        if CHANNELS == 3:
            return multichannel_slice.astype(np.float16)
        else:
            return slice.astype(np.float16)
    else:
        slice = bsb_window(slice)
        return slice.astype(np.float16)

In [12]:
import zipfile

def read_dicom_folder(folder_path):
    slices = []
    for filename in sorted(os.listdir(folder_path))[:MAX_SLICES]:  # Limit to MAX_SLICES
        if filename.endswith(".dcm"):
            file_path = os.path.join(folder_path, filename)
            ds = pydicom.dcmread(file_path)
            slices.append(ds)
            
    # Sort slices by images position (z-coordinate) in ascending order
    slices = sorted(slices, key=lambda x: float(x.ImagePositionPatient[2]))
    
    # Pad with black images if necessary
    while len(slices) < MAX_SLICES:
        slices.append(np.zeros_like(slices[0].pixel_array))
    
    return slices[:MAX_SLICES]  # Ensure we return exactly MAX_SLICES

def read_dicom_file(file):
    """Read a single DICOM file and return its pixel data."""
    ds = pydicom.dcmread(file)
    return ds

import concurrent.futures
def read_dicom_zip(zip_path, patient_id, study_instance_uid):
    slices = []
    folder_name = f"{patient_id}_{study_instance_uid}/"  # Define the folder structure in the zip

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # List all DICOM files in the specified folder within the zip
        dicom_files = [f for f in zip_ref.namelist() if f.startswith(folder_name) and f.endswith('.dcm')]
        dicom_files = sorted(dicom_files)[:MAX_SLICES]  # Limit to MAX_SLICES

        # Use ThreadPoolExecutor to read DICOM files concurrently
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_file = {executor.submit(read_dicom_file, zip_ref.open(filename)): filename for filename in dicom_files}
            for future in concurrent.futures.as_completed(future_to_file):
                try:
                    dicom_file = future.result()
                    slices.append(dicom_file)
                except Exception as e:
                    print(f"Error reading {future_to_file[future]}: {e}")

    # Sort slices by image position (z-coordinate) if available
    if slices:
        slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))

    # Pad with black images if necessary
    while len(slices) < MAX_SLICES:
        slices.append(np.zeros_like(slices[0].pixel_array))

    return slices[:MAX_SLICES]  # Ensure we return exactly MAX_SLICES

## Dataset and DataLoader

### Splitting the Dataset

In [13]:
def split_dataset(patient_scan_labels, test_size=TEST_SIZE, val_size=VALID_SIZE, random_state=42):
    # Extract the labels from the DataFrame
    labels = patient_scan_labels['patient_label']

    # First, split off the test set
    train_val_labels, test_labels = train_test_split(
        patient_scan_labels, 
        test_size=test_size, 
        stratify=labels, 
        random_state=random_state
    )

    # Calculate the validation size relative to the train_val set
    val_size_adjusted = val_size / (1 - test_size)

    # Split the train_val set into train and validation sets
    train_labels, val_labels = train_test_split(
        train_val_labels, 
        test_size=val_size_adjusted, 
        stratify=train_val_labels['patient_label'], 
        random_state=random_state
    )

    return train_labels, val_labels, test_labels

### Processing the Data

In [14]:
# import zipfile
# 
# def process_patient_data(dicom_dir, row, num_instances=12, depth=5):
#     patient_id = row['patient_id'].replace('ID_', '')
#     study_instance_uid = row['study_instance_uid'].replace('ID_', '')
# 
#     folder_name = f"{patient_id}_{study_instance_uid}"
#     zip_file_path = dicom_dir  # Assuming dicom_dir is the path to the zip file
# 
#     # Check if the zip file exists and is valid
#     if not zipfile.is_zipfile(zip_file_path):
#         print(f"File is not a valid zip file: {zip_file_path}")
#         return None, None
# 
#     with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
#         # Check if the folder exists in the zip file
#         if any(name.startswith(folder_name + '/') for name in zip_ref.namelist()):
#             # Read DICOM slices from the ZIP file
#             slices = read_dicom_zip(zip_file_path, patient_id, study_instance_uid)
# 
#             # Preprocess slices and convert to tensor in one go
#             preprocessed_slices = torch.stack(
#                 [torch.tensor(preprocess_slice(slice), dtype=torch.float32) for slice in slices]
#             )
# 
#             # Convert labels to tensor directly
#             labels = torch.tensor(row['labels'], dtype=torch.long)
# 
#             # Efficiently handle label padding
#             padded_labels = torch.empty(len(preprocessed_slices), dtype=torch.long)
#             padded_labels[:min(len(labels), len(preprocessed_slices))] = labels[:min(len(labels), len(preprocessed_slices))]
# 
#             return preprocessed_slices, padded_labels
# 
#         else:
#             print(f"Folder not found: {folder_name}")
#             return None, None

def process_patient_data(dicom_dir, row, num_instances=12, depth=5):
    patient_id = row['patient_id'].replace('ID_', '')
    study_instance_uid = row['study_instance_uid'].replace('ID_', '')

    folder_name = f"{patient_id}_{study_instance_uid}"
    folder_path = os.path.join(dicom_dir, folder_name)

    if os.path.exists(folder_path):
        slices = read_dicom_folder(folder_path)

        # # Ensure we have enough slices to create the specified instances
        # if len(slices) < depth * num_instances:
        #     print(f"Not enough slices for patient {patient_id}: found {len(slices)}, needed {depth * num_instances}")
        #     return None, None

        preprocessed_slices = [torch.tensor(preprocess_slice(slice), dtype=torch.float32) for slice in slices]  # Convert to tensor

        # Stack preprocessed slices into an array
        preprocessed_slices = torch.stack(preprocessed_slices, dim=0)  # (num_slices, height, width, channels)

        # Labels are already in list form, so just convert them to a tensor
        labels = torch.tensor(row['labels'], dtype=torch.long)

        # Fill labels with 0s if necessary
        if len(preprocessed_slices) > len(labels):
            padded_labels = torch.zeros(len(preprocessed_slices), dtype=torch.long)
            padded_labels[:len(labels)] = labels
        else:
            padded_labels = labels[:len(preprocessed_slices)]

        return preprocessed_slices, padded_labels

    else:
        print(f"Folder not found: {folder_name}")
        return None, None

### Augmentation

In [15]:
class DatasetAugmentor:
    def __init__(self, height, width, levels=2, seed=None):
        self.height = height
        self.width = width
        self.levels = levels  # Dynamic number of levels
        self.seed = seed
        self.params = []

        # Create different levels of transforms based on the number of levels specified
        for i in range(levels):
            factor = (i + 1) / levels
            self.params.append(
                self._create_transform(
                    degrees=int(15 * factor), 
                    translate_range=(0.2 * factor, 0.2 * factor),
                    scale_range=(1 - 0.2 * factor, 1 + 0.2 * factor),
                    brightness_range=0.2 * factor,
                    contrast_range=0.2 * factor,
                    blur_sigma_range=(0.5 * factor, 1.0 * factor),
                    apply_elastic=(i >= levels // 2),
                    level_name=f'level_{i + 1}'
                )
            )

    def _sample_value(self, value_range):
        if isinstance(value_range, tuple):
            random.seed(self.seed)
            return random.uniform(value_range[0], value_range[1])
        return value_range

    def _create_transform(self, degrees, translate_range, scale_range, brightness_range, contrast_range, blur_sigma_range, apply_elastic, level_name):
        print(f"Creating '{level_name}' transform with parameters:")
        sampled_values = {
            "degrees": abs(self._sample_value((-degrees, degrees))),
            "translate": (abs(self._sample_value(translate_range[0])), abs(self._sample_value(translate_range[1]))),
            "scale": self._sample_value(scale_range),
            "brightness": self._sample_value(brightness_range),
            "contrast": self._sample_value(contrast_range),
            "blur_sigma": self._sample_value(blur_sigma_range),
            "apply_elastic": apply_elastic
        }
        
        print(sampled_values)
        return sampled_values

    def apply_transform(self, image, level):
        params = self.params[level]
        transform = self._get_transform(params, channels=image.shape[0])
        return transform(image)

    def _get_transform(self, params, channels=3):
        transform_list = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=params["degrees"], translate=params["translate"], scale=(params["scale"], params["scale"])),
            transforms.ColorJitter(brightness=params["brightness"], contrast=params["contrast"]),
            transforms.GaussianBlur(kernel_size=(3, 3), sigma=params["blur_sigma"]),
            transforms.RandomApply([transforms.ElasticTransform()] if params["apply_elastic"] else [], p=0.3),
            transforms.Resize(256),
            transforms.CenterCrop(self.height),
            transforms.ToTensor(),
        ]

        if channels == 3:
            transform_list.extend([
                transforms.Normalize(mean=[0.16774411, 0.1360026, 0.19076315], std=[0.3101935, 0.27605791, 0.30469988]),
                transforms.RandomApply([self._channel_shuffle], p=0.3)
            ])
        elif channels == 1:
            transform_list.append(transforms.Normalize(mean=[0.16774411], std=[0.3101935]))

        return transforms.Compose(transform_list)

    def _channel_shuffle(self, tensor):
        torch.manual_seed(self.seed)
        channels = tensor.shape[0]
        indices = torch.randperm(channels)
        return tensor[indices]

In [16]:
augmentor = DatasetAugmentor(224, 224, seed=42)

Creating 'level_1' transform with parameters:
{'degrees': 1.9519751784103718, 'translate': (0.1, 0.1), 'scale': 1.027885359691577, 'brightness': 0.1, 'contrast': 0.1, 'blur_sigma': 0.4098566996144709, 'apply_elastic': False}
Creating 'level_2' transform with parameters:
{'degrees': 4.182803953736514, 'translate': (0.2, 0.2), 'scale': 1.0557707193831534, 'brightness': 0.2, 'contrast': 0.2, 'blur_sigma': 0.8197133992289418, 'apply_elastic': True}


### Dataset Generator

In [17]:
# class MedicalScanDataset(Dataset):
#     def __init__(self, data_dir, patient_scan_labels, augmentor=None):
#         self.data_dir = data_dir
#         self.patient_scan_labels = self._parse_patient_scan_labels(patient_scan_labels)
#         self.augmentor = augmentor
# 
#     def _parse_patient_scan_labels(self, patient_scan_labels):
#         """Parse and validate patient scan labels."""
#         patient_scan_labels['images'] = patient_scan_labels['images'].apply(
#             lambda x: eval(x) if isinstance(x, str) else x
#         )
#         patient_scan_labels['labels'] = patient_scan_labels['labels'].apply(
#             lambda x: eval(x) if isinstance(x, str) else x
#         )
#         patient_scan_labels['patient_label'] = patient_scan_labels['patient_label'].astype(bool)
#         return patient_scan_labels
# 
#     def _process_patient_data(self, row):
#         """Process patient data to get preprocessed slices and labels."""
#         return process_patient_data(self.data_dir, row)
# 
#     def __len__(self):
#         return len(self.patient_scan_labels) * (self.augmentor.levels if self.augmentor else 1)
# 
#     def __getitem__(self, idx):
#         patient_idx = idx // (self.augmentor.levels if self.augmentor else 1)
#         aug_level = idx % (self.augmentor.levels if self.augmentor else 1)
# 
#         row = self.patient_scan_labels.iloc[patient_idx]
#         preprocessed_slices, labels = self._process_patient_data(row)
# 
#         if preprocessed_slices is None:
#             return None, None, None
# 
#         preprocessed_slices = self._prepare_tensor(preprocessed_slices, aug_level if self.augmentor else None)
#         patient_label = torch.tensor(bool(row['patient_label']), dtype=torch.uint8)
# 
#         return preprocessed_slices, labels, patient_label
# 
#     def _prepare_tensor(self, preprocessed_slices, aug_level):
#         # Convert to numpy array and then to torch tensor
#         preprocessed_slices = np.array(preprocessed_slices)
#         preprocessed_slices = torch.tensor(preprocessed_slices, dtype=torch.float32)
# 
#         # Add an additional dimension for channel if it's missing (no augmentor)
#         if preprocessed_slices.ndim == 3:
#             preprocessed_slices = preprocessed_slices.unsqueeze(1)  # shape: [slices, 1, H, W]
# 
#         # Apply augmentation if augmentor is specified
#         if self.augmentor and aug_level is not None:
#             if preprocessed_slices.ndim == 4:  # Ensure it has the [slices, channels, H, W] format
#                 return torch.stack([self.augmentor.apply_transform(img, aug_level) for img in preprocessed_slices])
# 
#         return preprocessed_slices  # Return without augmentation if augmentor is None

In [18]:
class MedicalScanDataset:
    def __init__(self, data_dir, dataset, augmentor=None):
        self.data_dir = data_dir
        self.dataset = self._parse_dataset(dataset)
        self.augmentor = augmentor

    def _parse_dataset(self, dataset):
        """Parse and validate the new dataset."""
        # Generate labels based on specified columns
        dataset['labels'] = dataset[['any', 'epidural', 'intraparenchymal', 
                                      'intraventricular', 'subarachnoid', 
                                      'subdural']].apply(
            lambda row: [1 if any(row) else 0], axis=1
        )

        # Generate images as lists containing filenames
        dataset['images'] = dataset['filename'].apply(lambda x: [x])

        return dataset

    def _process_patient_data(self, row):
        """Process patient data to get preprocessed slices and labels."""
        return process_patient_data(self.data_dir, row)

    def __len__(self):
        return len(self.dataset) * (self.augmentor.levels if self.augmentor else 1)

    def __getitem__(self, idx):
        patient_idx = idx // (self.augmentor.levels if self.augmentor else 1)
        aug_level = idx % (self.augmentor.levels if self.augmentor else 1)

        row = self.dataset.iloc[patient_idx]
        preprocessed_slices, labels = self._process_patient_data(row)

        preprocessed_slices = self._prepare_tensor(preprocessed_slices, aug_level if self.augmentor else None)
        # labels = row['labels']
        patient_label = torch.tensor(bool(row['patient_label']), dtype=torch.uint8)

        return preprocessed_slices, labels, patient_label

    def _prepare_tensor(self, preprocessed_slices, aug_level):
        # Convert to numpy array and then to torch tensor
        preprocessed_slices = np.array(preprocessed_slices)
        preprocessed_slices = torch.tensor(preprocessed_slices, dtype=torch.float32)

        # Add an additional dimension for channel if it's missing (no augmentor)
        if preprocessed_slices.ndim == 3:
            preprocessed_slices = preprocessed_slices.unsqueeze(1)  # shape: [slices, 1, H, W]

        # Apply augmentation if augmentor is specified
        if self.augmentor and aug_level is not None:
            if preprocessed_slices.ndim == 4:  # Ensure it has the [slices, channels, H, W] format
                return torch.stack([self.augmentor.apply_transform(img, aug_level) for img in preprocessed_slices])

        return preprocessed_slices  # Return without augmentation if augmentor is None


In [19]:
class TrainDatasetGenerator(MedicalScanDataset):
    """Dataset class for training medical scan data."""
    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        super().__init__(data_dir, patient_scan_labels, augmentor)

class TestDatasetGenerator(MedicalScanDataset):
    """Dataset class for testing medical scan data."""
    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        super().__init__(data_dir, patient_scan_labels, augmentor)

In [20]:
original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)

In [21]:
len(original_dataset)

1150

In [22]:
x,y,z = original_dataset[0]
print(x.shape, y.shape, z.shape)

torch.Size([60, 224, 224, 3]) torch.Size([60]) torch.Size([])


In [23]:
# # Check if the returned data is valid
# if x is not None:
#     # Convert the tensor to a numpy array
#     x_np = x.numpy()
#
#     # Check the number of dimensions and squeeze if necessary
#     if x_np.ndim == 4:  # RGB images
#         # Plot each slice
#         fig, axes = plt.subplots(1, x_np.shape[0], figsize=(15, 5))
#         for i, ax in enumerate(axes):
#             ax.imshow(x_np[i].transpose(1, 2, 0))  # Convert CHW to HWC
#             ax.axis('off')
#         plt.show()
#     elif x_np.ndim == 3:  # Grayscale images
#         # Plot each slice
#         fig, axes = plt.subplots(1, x_np.shape[0], figsize=(15, 5))
#         for i, ax in enumerate(axes):
#             ax.imshow(x_np[i], cmap='gray')
#             ax.axis('off')
#         plt.show()
#     else:
#         raise ValueError(f"Unexpected number of dimensions: {x_np.ndim}")
# else:
#     print("No data available for this patient.")


In [24]:
def get_train_loader(dicom_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE):
    # original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=augmentor)
    original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)
    return DataLoader(original_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

def get_test_loader(dicom_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE):
    test_dataset = TestDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

## NTXent Loss

In [25]:
# Define NTXentLoss (provided by you)
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        batch_size = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)
        z = F.normalize(z, dim=1)
        similarity_matrix = torch.mm(z, z.T) / self.temperature
        mask = torch.eye(2 * batch_size, device=z.device).bool()
        similarity_matrix.masked_fill_(mask, -float('inf'))
        exp_sim = torch.exp(similarity_matrix)
        denominator = exp_sim.sum(dim=1)
        positive_samples = torch.cat(
            [torch.arange(batch_size, 2 * batch_size), torch.arange(batch_size)], dim=0
        ).to(z.device)
        positives = similarity_matrix[torch.arange(2 * batch_size), positive_samples]
        loss = -torch.log(torch.exp(positives) / denominator)
        return loss.mean()

## Augmentation For Contrastive Learning

In [26]:
# # Augmentation function
# Version 1: Avg time taken: 0.14 for 1 augmentation (w/o ResizedCrop)
# def augment_batch(batch_images):
#     batch_size, num_instances, channels, height, width = batch_images.shape
#     aug_transform = transforms.Compose([
#         # transforms.Resize((112, 112)),
#         transforms.RandomApply([transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25)], p=0.6),
#         transforms.RandomGrayscale(p=0.2),
#         transforms.Compose([transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)])
#     ])
# 
#     # Apply transformation to each image instance in the batch
#     augmented_batch = []
#     for i in range(batch_size):
#         augmented_instances = [aug_transform(transforms.ToPILImage()(img.cpu())) for img in batch_images[i]]
#         augmented_batch.append(torch.stack(augmented_instances))
# 
#     return torch.stack(augmented_batch).cuda()  # Move the augmented batch to GPU

# Version 2: Avg time taken: 0.05 seconds for 1 augmentation (w ResizedCrop)
def augment_batch(batch_images):
    batch_size, num_instances, channels, height, width = batch_images.shape

    # Define augmentation transformations using GPU-compatible operations
    aug_transform = transforms.Compose([
        transforms.RandomResizedCrop((224, 224), scale=(0.75, 1.15)),
        transforms.RandomApply([transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25)], p=0.6),
        # transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4)], p=0.6),
    ])

    # Apply transformations directly on the tensor without converting to PIL
    augmented_batch = torch.empty_like(batch_images)  # Preallocate memory for augmented images

    for i in range(batch_size):
        for j in range(num_instances):
            # Apply the transformation directly to the tensor
            if CHANNELS == 1:
                augmented_batch[i, j] = aug_transform(batch_images[i, j])
            else:
                augmented_batch[i, j] = aug_transform(batch_images[i, j].permute(2, 0, 1)).permute(1, 2, 0)

    return augmented_batch.cuda()  # Move the augmented batch to GPU

## CNN Feature Extractor

### Attention Layer

In [27]:
class AttentionLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.PReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x shape: (batch_size, num_instances, feature_dim)
        attention_weights = self.attention(x)
        weights = F.softmax(attention_weights, dim=1)

        return (x * weights).sum(dim=1), weights.squeeze(-1)

In [28]:
class GatedAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GatedAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            # nn.Tanh(),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.gate = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            # nn.Tanh(),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def forward(self, x):
        # x shape: (batch_size, num_instances, input_dim)
        attention_weights = self.attention(x)
        gate_weights = torch.sigmoid(self.gate(x))
        
        weights = attention_weights * gate_weights
        weights = F.softmax(weights, dim=1)
        
        return (x * weights).sum(dim=1), weights.squeeze(-1)

### Gaussian Process

In [29]:
class GPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

### Dual-Stream MIL Model (DSMIL)

In [30]:
class FCLayer(nn.Module):
    def __init__(self, input_dim, output_dim=1):
        super(FCLayer, self).__init__()
        self.fc = nn.Sequential(nn.Linear(input_dim, output_dim))
        
    def forward(self, x):
        return self.fc(x)
    
class InstanceClassifier(nn.Module):
    def __init__(self, input_dim, output_dim=1):
        super(InstanceClassifier, self).__init__()
        self.features_extractor = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.features_extractor.conv1 = nn.Conv2d(CHANNELS, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.features_extractor.fc = nn.Identity()
        
        self.fc = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        if CHANNELS == 1:
            batch_size, num_instances, C, H, W = x.shape
        else:
            batch_size, num_instances, H, W, C = x.shape
        x = x.view(batch_size * num_instances, C, H, W)
        
        instance_features = nn.Dropout(0.35)(self.features_extractor(x)).view(batch_size, num_instances, -1)
        classes = self.fc(instance_features)
        
        return instance_features, classes
    
class BagClassifier(nn.Module):
    def __init__(self, input_dim, output_dim=1, hidden_dim=128, dropout_v=0.2, non_linear=True, passing_v=False):
        super(BagClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        
        if non_linear:
            self.q = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Tanh()
            )
        else:
            self.q = nn.Linear(input_dim, hidden_dim)
        
        if passing_v:
            self.v = nn.Sequential(
                nn.Dropout(dropout_v),
                nn.Linear(input_dim, input_dim),
                nn.ReLU()
            )
        else:
            self.v = nn.Identity()
            
        self.fc = FCLayer(input_dim, output_dim)
        
    def forward(self, features, classes):
        batch_size = features.size(0)
        num_instances = features.size(1)
        features_dim = features.size(2)
        
        combine_features = features.view(features.shape[0] * features.shape[1], -1)
        V = self.v(combine_features)
        Q = self.q(combine_features)
        assert V.shape[0] == Q.shape[0] == batch_size * num_instances, f'V: {V.shape}, Q: {Q.shape}'
        assert V.shape[1] == features_dim, f'V: {V.shape} should be [{batch_size * num_instances}, {features_dim}]'
        assert Q.shape[1] == self.hidden_dim, f'Q: {Q.shape} should be [{batch_size * num_instances}, {self.hidden_dim}]'
        
        # Get critical instance indices by squeezing classes
        critical_indices = classes.squeeze(-1).argmax(dim=1)
        assert critical_indices.shape[0] == batch_size, f'Critical indices: {critical_indices.shape}'

        # Gather features for each batch using critical instance indices
        m_features = features[torch.arange(batch_size).unsqueeze(1), critical_indices.unsqueeze(1)].squeeze()
        m_features = m_features.view(batch_size, -1)
        assert m_features.shape[0] == batch_size, f'M features: {m_features.shape} should be [{batch_size}, {features_dim}]'
        q_max = self.q(m_features)
        assert q_max.shape[0] == batch_size and q_max.shape[1] == self.hidden_dim, f'Q max: {q_max.shape} should be [{batch_size}, {self.hidden_dim}]'
        
        A = torch.mm(Q, q_max.mT)
        A = F.softmax(A / torch.sqrt(torch.tensor(Q.shape[-1]).float()), dim=0)
        assert A.shape[0] == batch_size * num_instances and A.shape[1] == batch_size, f'A: {A.shape} should be [{batch_size * num_instances}, {batch_size}]'
        
        B = torch.mm(A.T, V)
        assert B.shape[0] == batch_size and B.shape[1] == features_dim, f'B: {B.shape} should be [{batch_size}, {features_dim}]'
        
        B = B.view(1, B.shape[0], B.shape[1]) # Shape of B: [1, batch_size, features_dim]
        C = self.fc(B)
        C = C.view(1, -1) # Shape of C: [1, batch_size, 1] -> [1, batch_size]
        return C, A, B

### ResNet2D Model

In [31]:
class MILResNet18(nn.Module):
    def __init__(self):
        super(MILResNet18, self).__init__()
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.resnet.conv1 = nn.Conv2d(in_channels=CHANNELS, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        
        self.resnet.fc = nn.Identity()
        self.attention = AttentionLayer(input_dim=512, hidden_dim=512)
        
        self.classifier = nn.Linear(512, 1) 
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, bags):
        if CHANNELS == 1:
            batch_size, num_instances, c, h, w = bags.size()
        else:
            batch_size, num_instances, h, w, c = bags.size()

        bags_flattened = bags.view(batch_size * num_instances, c, h, w)
        
        # # Version 1: CNN-ResNet + Att + GP
        features = self.resnet(bags_flattened)
        features = self.dropout(features)
        features = features.view(batch_size, num_instances, -1)

        attended_features, attended_weights = self.attention(features)
        # attended_features_reshaped = attended_features.view(-1, 512)
        attended_features_reshaped = attended_features.view(batch_size, -1)

        # CNN_ATT_GP
        gp_output = self.gp_layer(attended_features_reshaped)
        gp_mean = gp_output.mean.view(batch_size, -1)

        combine_features = torch.cat((attended_features, gp_mean), dim=1)
        combine_features = self.dropout(combine_features)

        outputs = torch.sigmoid(self.classifier_v0(combine_features))
        return outputs, attended_weights, gp_output

## Encoder Model

In [32]:
class Encoder(nn.Module):
    def __init__(self, projection_dim=128):
        super(Encoder, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim)
        )
        
        self.instance_classifier = InstanceClassifier(512)
        self.bag_classifier = BagClassifier(512)
        
        inducing_points = torch.randn(32, 512) 
        self.gp_model = GPModel(inducing_points=inducing_points)
        self.fc = nn.Linear(512 + 1, 1)
        
    def forward(self, x):
        batch_size, num_instances, channels, height, width = x.size()
        
        instances_features, classes = self.instance_classifier(x)
        
        features = instances_features.view(batch_size * num_instances, -1)  # Flatten to (batch_size * num_instances, feature_dim)
        
        projection_features = self.projection(features)
        
        predicted_bags, A, B = self.bag_classifier(instances_features, classes)
        gp_output = self.gp_model(B.squeeze())
        gp_mean = gp_output.mean.view(batch_size, -1)
        combined_features = torch.cat((B.squeeze(), gp_mean), dim=1) # [batch_size, 512 + 1]
        
        # # Assuming gp_output.mean has shape [batch_size, 1] and B has shape [batch_size, num_instances, 512]
        # gp_mean_expanded = gp_output.mean.unsqueeze(dim=1).unsqueeze(dim=0)  # Shape: [batch_size, 1] -> [1, batch_size, 1]
        # gp_mean_broadcasted = gp_mean_expanded.expand(-1, -1, B.shape[-1])  # Shape: [1, batch_size, 1] -> [1, batch_size, feature_dim]
        # combined_features = B + gp_mean_broadcasted  # Element-wise addition
        
        predicted_bags = torch.sigmoid(predicted_bags)
        
        combined_features = self.fc(combined_features.squeeze())
        combined_features = torch.sigmoid(combined_features)
        
        return projection_features, classes, predicted_bags, A, B, combined_features

## Training and Evaluation

#### Loss Function

In [33]:
def combined_loss(outputs, gp_distribution, target, alpha=0.5):
    # Cross-Entropy Loss for CNN outputs
    bce_loss = nn.BCELoss()(outputs.squeeze(), target.float())
    kl_divergence = gp_distribution.variational_strategy.kl_divergence()
    total_loss = (1 - alpha) * bce_loss + alpha * kl_divergence
    
    return total_loss

### Training

In [34]:
# def train_epoch(model, data_loader, criterion_cl, criterion_bce, optimizer, scheduler, device):
#     total_loss = 0.0
#     alpha = 0.5
#     predictions = []
#     labels = []
#     
#     correct_predictions = 0
#     total_samples = 0
#     
#     model.train()
# 
#     for batch_data, batch_labels, batch_patient_labels in data_loader:
#         batch_data = batch_data.to(device)
#         batch_patient_labels = batch_patient_labels.float().to(device)
#         optimizer.zero_grad()
#         
#         # Forward pass
#         aug1 = augment_batch(batch_data).cuda()
#         aug2 = augment_batch(batch_data).cuda()
# 
#         z_i, outputs_1, predicted_bags_1, _, _, gp_combine_1 = model(aug1)
#         z_j, outputs_2, predicted_bags_2, _, _, gp_combine_2 = model(aug2)
# 
#         NTXLoss = criterion_cl(z_i, z_j)
#         max_agg_1 = torch.max(outputs_1, dim=1).values.squeeze()
#         max_agg_2 = torch.max(outputs_2, dim=1).values.squeeze()
# 
#         loss_max_1 = criterion_bce(max_agg_1, batch_patient_labels)
#         loss_max_2 = criterion_bce(max_agg_2, batch_patient_labels)
#         # loss_bag_1 = criterion_bce(predicted_bags_1.squeeze(), batch_patient_labels)
#         # loss_bag_2 = criterion_bce(predicted_bags_2.squeeze(), batch_patient_labels)
#         loss_bag_1 = criterion_bce(gp_combine_1.squeeze(), batch_patient_labels)
#         loss_bag_2 = criterion_bce(gp_combine_2.squeeze(), batch_patient_labels)
# 
#         loss = NTXLoss * 0.4 + loss_max_1 * 0.15 + loss_max_2 * 0.15 + loss_bag_1 * 0.15 + loss_bag_2 * 0.15
#         loss = loss * 0.5 + model.gp_model.variational_strategy.kl_divergence() * 0.5
#         loss = loss.mean()
#         
#         # z, outputs, predicted_bags_1, A, B, gp_combine_1 = model(batch_data)
#         # max_agg_1 = torch.max(outputs, dim=1).values.squeeze()
#         # loss_max = criterion_bce(max_agg_1, batch_patient_labels)
#         # # loss_ss = criterion_bce(predicted_bags_1.squeeze(), batch_patient_labels)
#         # loss_bag = criterion_bce(gp_combine_1.squeeze(), batch_patient_labels)
#         # loss = (loss_max * 0.5 + loss_bag * 0.5) * 0.5 + model.gp_model.variational_strategy.kl_divergence() * 0.5
#         # loss = loss.mean()
#         
#         total_loss += loss.item()
#         
#         # Backward pass
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#         
#         # predicted = (predicted_bags_1.squeeze() > 0.5).cpu().detach().numpy()
#         predictions.extend((gp_combine_1.squeeze() > 0.5).cpu().detach().numpy())
#         labels.extend(batch_patient_labels.cpu().numpy())
#     
#     return total_loss / len(data_loader), predictions, labels


def train_epoch(model, data_loader, criterion_cl, criterion_bce, optimizer, scheduler, device):
    total_loss = 0.0
    alpha = 0.5
    predictions = []
    labels = []
    
    model.train()

    # Wrap data_loader with tqdm
    with tqdm(total=len(data_loader), desc="Training", unit="batch") as pbar:
        for batch_data, batch_labels, batch_patient_labels in data_loader:
            batch_data = batch_data.to(device)
            batch_patient_labels = batch_patient_labels.float().to(device)
            optimizer.zero_grad()
            
            # Forward pass with augmentations
            aug1 = augment_batch(batch_data).cuda()
            aug2 = augment_batch(batch_data).cuda()

            z_i, outputs_1, predicted_bags_1, _, _, gp_combine_1 = model(aug1)
            z_j, outputs_2, predicted_bags_2, _, _, gp_combine_2 = model(aug2)

            NTXLoss = criterion_cl(z_i, z_j)
            max_agg_1 = torch.max(outputs_1, dim=1).values.squeeze()
            max_agg_2 = torch.max(outputs_2, dim=1).values.squeeze()

            loss_max_1 = criterion_bce(max_agg_1, batch_patient_labels)
            loss_max_2 = criterion_bce(max_agg_2, batch_patient_labels)
            loss_bag_1 = criterion_bce(gp_combine_1.squeeze(), batch_patient_labels)
            loss_bag_2 = criterion_bce(gp_combine_2.squeeze(), batch_patient_labels)

            loss = (NTXLoss * 0.4 + loss_max_1 * 0.15 + 
                    loss_max_2 * 0.15 + loss_bag_1 * 0.15 + 
                    loss_bag_2 * 0.15) * 0.7 + model.gp_model.variational_strategy.kl_divergence() * 0.3
            
            total_loss += loss.item()
            
            # Backward pass
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Update predictions and labels
            predictions.extend((gp_combine_1.squeeze() > 0.5).cpu().detach().numpy())
            labels.extend(batch_patient_labels.cpu().numpy())

            # Update progress bar with the current loss
            pbar.set_postfix(loss=total_loss / (pbar.n + 1))  # Average loss so far
            pbar.update(1)  # Move the progress bar forward

    return total_loss / len(data_loader), predictions, labels

def validate(model, data_loader, criterion_cl, criterion_bce, device):
    """Validate the model."""
    model.eval()
    total_loss = 0.0
    alpha = 0.5
    predictions = []
    labels = []

    with torch.no_grad():
        for batch_data, batch_labels, batch_patient_labels in data_loader:
            batch_data = batch_data.to(device)
            batch_patient_labels = batch_patient_labels.float().to(device)
            
            z_i, output, predicted_bags, _, _, gp_combine = model(batch_data)
            max_agg = torch.max(output, dim=1).values.squeeze()
            
            # Store predictions and labels
            # predictions.extend((predicted_bags.squeeze() > 0.5).cpu().detach().numpy())
            predictions.extend((gp_combine.squeeze() > 0.5).cpu().detach().numpy())
            labels.extend(batch_patient_labels.cpu().numpy())

    return total_loss / len(data_loader), predictions, labels

def calculate_metrics(predictions, labels):
    """Calculate and return performance metrics."""
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions),
        "recall": recall_score(labels, predictions),
        "f1": f1_score(labels, predictions)
    }

def print_epoch_stats(epoch, num_epochs, phase, loss, metrics):
    """Print statistics for an epoch."""
    print(f"Epoch {epoch+1}/{num_epochs} - {phase.capitalize()}:")
    print(f"Loss: {loss:.4f}, Accuracy: {metrics['accuracy']:.4f}, "
          f"Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, "
          f"F1: {metrics['f1']:.4f}")

def train_model(model, train_loader, val_loader, criterion_cl, criterion_bce, optimizer, num_epochs, learning_rate, device='cuda'):
    """Train the model and return the best model based on validation accuracy."""
    model = model.to(device)

    model.train()

    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, 
                                              steps_per_epoch=len(train_loader), epochs=num_epochs)
    best_val_accuracy = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        # Training phase
        train_loss, train_predictions, train_labels = train_epoch(model, train_loader, criterion_cl, criterion_bce, optimizer, scheduler, device)
        train_metrics = calculate_metrics(train_predictions, train_labels)
        print_epoch_stats(epoch, num_epochs, "train", train_loss, train_metrics)
        # Log training metrics to W&B
        wandb.log({"train/loss": train_loss / len(train_loader), **train_metrics})

        # Validation phase
        val_loss, val_predictions, val_labels = validate(model, val_loader, criterion_cl, criterion_bce, device)
        val_metrics = calculate_metrics(val_predictions, val_labels)
        print_epoch_stats(epoch, num_epochs, "validation", val_loss, val_metrics)
        # Log validation metrics to W&B
        wandb.log({"val/loss": val_loss / len(val_loader), **val_metrics})

        # Save best model
        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            best_model_state = model.state_dict()

    # Load best model
    model.load_state_dict(best_model_state)
    # Optionally log the best model to W&B (if desired)
    wandb.log_artifact(wandb.Artifact("best_model", type="model", metadata={"accuracy": best_val_accuracy}))
    
    return model

### Evaluation Functions

In [35]:
## Model Evaluation Functions
def evaluate_model(model, data_loader, device='cuda'):
    """Evaluate the model on the given data loader."""
    model = model.to(device)

    model.eval()

    predictions = []
    labels = []

    with torch.inference_mode(): 
        for batch_data, batch_labels, batch_patient_labels in data_loader:
            batch_data = batch_data.to(device)
            batch_patient_labels = batch_patient_labels.float().to(device)
            
            z_i, outputs, predicted_bags, _, _, gp_combine = model(batch_data)
            
            # predictions.extend((predicted_bags.squeeze() > 0.5).cpu().detach().numpy())
            predictions.extend((gp_combine.squeeze() > 0.5).cpu().detach().numpy())
            labels.extend(batch_patient_labels.cpu().numpy())

    return np.array(predictions), np.array(labels)

def print_metrics(metrics):
    """Print the calculated metrics."""
    print(f"Test Accuracy: {metrics['accuracy']:.4f}, "
          f"Precision: {metrics['precision']:.4f}, "
          f"Recall: {metrics['recall']:.4f}, "
          f"F1: {metrics['f1']:.4f}")

### Visualization Functions

In [36]:
## Visualization Functions
def plot_roc_curve(model, data_loader, device):
    """Plot the ROC curve for the model predictions."""
    model.eval()
    labels = []
    predictions = []
    
    with torch.no_grad():
        for batch_data, batch_labels, batch_patient_labels in data_loader:
            batch_data = batch_data.to(device)
            batch_patient_labels = batch_patient_labels.float().to(device)

            z_i, outputs, predicted_bags, _, _, gp_combine = model(batch_data)
            
            # predictions.extend(predicted_bags.squeeze().cpu().numpy())
            predictions.extend(gp_combine.squeeze().cpu().numpy())
            labels.extend(batch_patient_labels.cpu().numpy())
            
    fpr, tpr, _ = roc_curve(labels, predictions)
    roc_auc = auc(fpr, tpr)
    
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.show()

def plot_confusion_matrix(model, data_loader, device):
    """Plot the confusion matrix for the model predictions."""
    predictions, labels = evaluate_model(model, data_loader, device)
    
    cm = confusion_matrix(labels, predictions)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()
    plt.title('Confusion Matrix')
    plt.show()

### Data Processing Functions

In [37]:
## Data Processing Functions
def load_model(model_class, model_path):
    """Load a trained model from a file."""
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    model = model_class()
    try:
        state_dict = torch.load(model_path, map_location=torch.device('cuda'), weights_only=True)
        if not state_dict:
            raise ValueError(f"The state dictionary loaded from {model_path} is empty")
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading model from {model_path}: {str(e)}")
        print("Initializing model with random weights instead.")
        return model  # Return the model with random initialization

    return model.eval()


def get_test_results(model, test_loader, test_labels, device=DEVICE):
    """Get test results including patient information."""
    predictions, _ = evaluate_model(model, test_loader, device)
    
    results = []
    for i, row in enumerate(test_labels.itertuples(index=False)):
        result = {col: getattr(row, col) for col in test_labels.columns}
        result['prediction'] = predictions[i]
        results.append(result)
    
    return pd.DataFrame(results)

## Visualizing Attention Weights and Images

In [38]:
def plot_label_attention_weights(model, data_loader, device='cuda'):
    """
    Plot images with their labels and attention values in a single large plot.

    Parameters:
    - model: The trained model
    - data_loader: DataLoader containing test dataset
    - device: Device to run the model on ('cuda' or 'cpu')
    - CHANNELS: Number of channels in the image (e.g., 1 for grayscale, 3 for RGB)

    Expected shapes:
    - 1-channel images: (batch_size, num_images, 224, 224)
    - 3-channel images: (batch_size, num_images, 3, 224, 224)
    - attention: float value per image indicating attention weight
    """
    model = model.to(device)
    model.eval()
    num_images = MAX_SLICES
    rows, cols = 10, 6  # Adjust to fit 60 images in a single plot

    with torch.no_grad():
        for batch_data, batch_labels, batch_patients_label in data_loader:
            # Move data to the appropriate device
            batch_data = batch_data.to(device)
            outputs, attention_weight_batch, _, _ = model(batch_data)

            # Process each patient in the batch
            for patient_idx in range(batch_data.size(0)):
                if batch_patients_label[patient_idx].item() == 1:  # Check if patient has positive label
                    # Create a new figure for this patient
                    fig = plt.figure(figsize=(cols * 4, rows * 4 + 2))  # Increased height for suptitle

                    for img_idx in range(num_images):
                        # Get the image and its label
                        img = batch_data[patient_idx, img_idx].cpu().numpy()
                        img_label = batch_labels[patient_idx, img_idx].cpu().numpy()
                        
                        # Get attention value
                        if attention_weight_batch.size(1) == batch_data.size(1):
                            attention_value = attention_weight_batch[patient_idx, img_idx].cpu().item()
                        else:
                            attention_value = attention_weight_batch[patient_idx].cpu().item()
                        
                        # Plot image
                        plt.subplot(rows, cols, img_idx + 1)
                        if CHANNELS == 3:  # RGB image
                            plt.imshow(img)
                        else:  # Grayscale image
                            if img.ndim == 3:  # If shape is (1, H, W)
                                img = np.squeeze(img)  # Convert to (H, W)
                            plt.imshow(img, cmap='gray')
                        
                        plt.title(f'Label: {img_label}\nAttention: {attention_value:.4f}', fontsize=12)
                        plt.axis('off')

                    # Add overall title for the patient
                    plt.suptitle(f'Patient Images (Patient Label: {batch_patients_label[patient_idx].cpu().numpy()})', fontsize=16)
                    plt.tight_layout(rect=[0, 0, 1, 0.97])  # Adjust rect to make space for suptitle
                    plt.show()
                                      
                    # Since we are plotting only for one patient, return after the first plot
                    return

## Visualization Augmented Bags

In [39]:
def visualize_augmented_bags(original_bags, augmented_bags, num_bags=12):
    """
    Visualizes all instances of the first bag of original and augmented images.

    Parameters:
    - original_bags: A tensor of shape (batch_size, num_instances, channels, height, width)
    - augmented_bags: A tensor of shape (batch_size, num_instances, channels, height, width)
    - num_bags: Number of bags to visualize (only the first bag will be shown).
    """
    # Only visualize the first bag
    first_bag_index = 0
    
    # Get number of instances
    num_instances = original_bags.size(1)
    
    print(f'Num instances: {num_instances}')
    
    # Limit the number of bags to visualize (but we only show the first one)
    num_bags = min(num_bags, 1)  # We only want to visualize the first bag

    fig, axes = plt.subplots(num_instances, 2, figsize=(10, 2 * num_instances))
    
    # Original images
    for j in range(num_instances):  # Iterate over instances in the first bag
        img = original_bags[first_bag_index][j].cpu().numpy().squeeze()  # Remove channel dimension
        axes[j, 0].imshow(img, cmap='gray')  # Use gray colormap for single channel images
        axes[j, 0].axis('off')  # Hide axes for better visualization
        axes[j, 0].set_title(f'Original Instance {j + 1}')
        
    # Augmented images
    for j in range(num_instances):
        img = augmented_bags[first_bag_index][j].cpu().numpy().squeeze()  # Remove channel dimension
        axes[j, 1].imshow(img.squeeze(), cmap='gray')  # Use gray colormap for single channel images
        axes[j, 1].axis('off')  # Hide axes for better visualization
        axes[j, 1].set_title(f'Augmented Instance {j + 1}')

    plt.tight_layout()
    plt.show()


## Main

In [40]:
def set_seed(seed=42):
    """Set seed for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

In [None]:
import time

def main(mode='train'):
    os.environ["WANDB_DISABLED"] = "true"

    # Initialize W&B
    wandb.init(project="your_project_name")
    
    # Log hyperparameters
    config = wandb.config
    config.learning_rate = LEARNING_RATE
    config.batch_size = TRAIN_BATCH_SIZE
    config.num_epochs = NUM_EPOCHS

    set_seed()
    train_labels, val_labels, test_labels = split_dataset(patient_scan_labels, test_size=TEST_SIZE)
    # train_labels = train_labels[:200]
    train_loader = get_train_loader(dicom_dir, train_labels, batch_size=TRAIN_BATCH_SIZE)
    val_loader = get_train_loader(dicom_dir, val_labels, batch_size=VALID_BATCH_SIZE)
    test_loader = get_test_loader(dicom_dir, test_labels, batch_size=TEST_BATCH_SIZE)

    set_seed()
    
    # Initialize model, criterion, and optimizer
    # model = MILResNet18()
    model = Encoder()
    
    criterion_cl = NTXentLoss()
    criterion_bce = torch.nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    if mode == 'train':
        # Watch the model to log gradients and parameters
        wandb.watch(model)

        # Train model
        trained_model = train_model(model, train_loader, val_loader, criterion_cl, criterion_bce, optimizer, config.num_epochs, config.learning_rate, DEVICE)

        # Save model
        torch.save(trained_model.state_dict(), MODEL_PATH)

    # Load best model
    trained_model = load_model(Encoder, MODEL_PATH)

    # Evaluate model
    predictions, labels = evaluate_model(trained_model, test_loader, DEVICE)
    metrics = calculate_metrics(predictions, labels)

    # Log metrics to W&B
    wandb.log(metrics)

    print_metrics(metrics)

    # Visualizations
    plot_roc_curve(trained_model, test_loader, DEVICE)
    plot_confusion_matrix(trained_model, test_loader, DEVICE)

    if mode == 'train':
        required_columns = ['patient_id', 'study_instance_uid', 'patient_label']
        temp_test_labels = test_labels[required_columns]
        
        # Save results
        results_df = get_test_results(trained_model, test_loader, temp_test_labels, device)
        results_df.to_csv('results/results.csv', index=False)
        print(results_df.head())

        # Log results DataFrame as a table in W&B (optional)
        wandb.log({"results": wandb.Table(dataframe=results_df)})

    # Call the function with the test_loader
    test_loader = get_test_loader(dicom_dir, test_labels, batch_size=TEST_BATCH_SIZE)
    # plot_label_attention_weights(trained_model, test_loader, device)
    
    # Get the first batch of images from the evaluation loader
    images, _, _ = next(iter(train_loader))
    print(f'Original batch shape: {images.shape}')
    
    # Augment the batch of images
    # Start timer 
    start = time.time()
    augmented_images = augment_batch(images)
    end = time.time()
    taken_time = end - start
    print(f'Augmented batch shape: {augmented_images.shape} | Time: {taken_time:.4f}')
    
    # Visualize the original and augmented bags
    visualize_augmented_bags(images, augmented_images)

if __name__ == "__main__":
    main(mode='train')

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Training: 100%|██████████| 201/201 [01:58<00:00,  1.70batch/s, loss=1.99]

Epoch 1/10 - Train:
Loss: 1.9872, Accuracy: 0.6741, Precision: 0.0000, Recall: 0.0000, F1: 0.0000





Epoch 1/10 - Validation:
Loss: 0.0000, Accuracy: 0.6802, Precision: 0.0000, Recall: 0.0000, F1: 0.0000


Training: 100%|██████████| 201/201 [01:56<00:00,  1.72batch/s, loss=1.69]

Epoch 2/10 - Train:
Loss: 1.6917, Accuracy: 0.6741, Precision: 0.0000, Recall: 0.0000, F1: 0.0000





Epoch 2/10 - Validation:
Loss: 0.0000, Accuracy: 0.6802, Precision: 0.0000, Recall: 0.0000, F1: 0.0000


Training:  40%|███▉      | 80/201 [00:46<01:09,  1.74batch/s, loss=1.59]

## Results