## Import Libraries

In [1]:
!pip install gpytorch



In [2]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
import cv2

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import torchvision.models as models

import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy

from torchvision import transforms
from scipy.ndimage import rotate
from skimage.transform import resize
from scipy.ndimage import gaussian_filter

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [3]:
import warnings
from sklearn.exceptions import UndefinedMetricWarning

warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

## GPU Check

In [4]:
# Initialize GPU Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")
    
%load_ext autoreload
%autoreload 2

print(device)

GPU: NVIDIA GeForce RTX 4070 SUPER is available.
cuda


## Configurations Initialization

In [5]:
# Constants
HEIGHT = 224
WIDTH = 224
CHANNELS = 1

TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4
TEST_BATCH_SIZE = 4
TEST_SIZE = 0.15
VALID_SIZE = 0.15

MAX_SLICES = 60
SHAPE = (HEIGHT, WIDTH, CHANNELS)

NUM_EPOCHS = 20
LEARNING_RATE = 5e-4
INDUCING_POINTS = 128

TARGET_LABELS = ['intraparenchymal']

MODEL_PATH = 'results/trained_model.pth'
DEVICE = 'cuda'

## Configurations for ViT Model
IMG_SIZE = 224
NUM_CLASSES = 1
PATCH_SIZE = 32
EMBEDDING_DIM = PATCH_SIZE * PATCH_SIZE * CHANNELS # (P^2 * C) -> (16^2 * 3) 
NUM_HEADS = 16
NUM_LAYERS = 12
MLP_SIZE = 3072

In [6]:
# Kaggle and local switch
KAGGLE = os.path.exists('/kaggle')
print("Running on Kaggle" if KAGGLE else "Running locally")
ROOT_DIR = '/kaggle/input/' if KAGGLE else None
DATA_DIR = ROOT_DIR + 'rsna-mil-training/' if KAGGLE else '../rsna-mil-training/'
DICOM_DIR = DATA_DIR + 'rsna-mil-training/'
CSV_PATH = DATA_DIR + 'training_1000_scan_subset.csv' if KAGGLE else './data_analyze/training_1000_scan_subset.csv'
SLICE_LABEL_PATH = ROOT_DIR + "sorted_training_dataset_with_labels.csv" if KAGGLE else './data_analyze/sorted_training_dataset_with_labels.csv'

dicom_dir = DICOM_DIR if KAGGLE else DATA_DIR
# Load patient scan labels
patient_scan_labels = pd.read_csv(CSV_PATH)
patient_slice_labels = pd.read_csv(SLICE_LABEL_PATH)

Running locally


## Data Preprocessing

### Windowing and Resizing

In [7]:
def correct_dcm(dcm):
    x = dcm.pixel_array + 1000
    px_mode = 4096
    x[x>=px_mode] = x[x>=px_mode] - px_mode
    dcm.PixelData = x.tobytes()
    dcm.RescaleIntercept = -1000

def window_image(dcm, window_center, window_width):    
    if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):
        correct_dcm(dcm)
    img = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept
    
    # Resize
    img = cv2.resize(img, SHAPE[:2], interpolation = cv2.INTER_LINEAR)
   
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    img = np.clip(img, img_min, img_max)
    return img

def bsb_window(dcm):
    brain_img = window_image(dcm, 40, 80)
    subdural_img = window_image(dcm, 80, 200)
    soft_img = window_image(dcm, 40, 380)
    
    brain_img = (brain_img - 0) / 80
    subdural_img = (subdural_img - (-20)) / 200
    soft_img = (soft_img - (-150)) / 380
    
    # bsb_img = np.stack([brain_img, subdural_img, soft_img], axis=-1)
    bsb_img = brain_img
    # bsb_img = subdural_img
    return bsb_img.astype(np.float16)

### Preprocess Slice

In [8]:
def preprocess_slice(slice, target_size=(HEIGHT, WIDTH)):
    # Check if type of slice is dicom or an empty numpy array
    if (type(slice) == np.ndarray):
        slice = resize(slice, target_size, anti_aliasing=True)
        multichannel_slice = np.stack([slice, slice, slice], axis=-1)
        # return multichannel_slice.astype(np.float16)
        return slice.astype(np.float16)
    else:
        slice = bsb_window(slice)
        return slice.astype(np.float16)

### Load Dicom Images

In [9]:
def read_dicom_folder(folder_path):
    slices = []
    for filename in sorted(os.listdir(folder_path))[:MAX_SLICES]:  # Limit to MAX_SLICES
        if filename.endswith(".dcm"):
            file_path = os.path.join(folder_path, filename)
            ds = pydicom.dcmread(file_path)
            slices.append(ds)
            
    # Sort slices by images position (z-coordinate) in ascending order
    slices = sorted(slices, key=lambda x: float(x.ImagePositionPatient[2]))
    
    # Pad with black images if necessary
    while len(slices) < MAX_SLICES:
        slices.append(np.zeros_like(slices[0].pixel_array))
    
    return slices[:MAX_SLICES]  # Ensure we return exactly MAX_SLICES

## Dataset and DataLoader

### Splitting Dataset

In [10]:
def split_dataset(patient_scan_labels, test_size=TEST_SIZE, val_size=VALID_SIZE, random_state=42):
    # If any of the hemorrhage indicators is 1, the label is 1, otherwise 0
    patient_scan_labels['label'] = patient_scan_labels[['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']].any(axis=1).astype(int)

    # Extract the labels from the DataFrame
    labels = patient_scan_labels['label']

    # First, split off the test set
    train_val_labels, test_labels = train_test_split(
        patient_scan_labels, 
        test_size=test_size, 
        stratify=labels, 
        random_state=random_state
    )

    # Calculate the validation size relative to the train_val set
    val_size_adjusted = val_size / (1 - test_size)

    # Split the train_val set into train and validation sets
    train_labels, val_labels = train_test_split(
        train_val_labels, 
        test_size=val_size_adjusted, 
        stratify=train_val_labels['label'], 
        random_state=random_state
    )

    return train_labels, val_labels, test_labels

### MIL Data Preparation

In [11]:
def process_patient_data(dicom_dir, row, num_instances=12, depth=5):
    patient_id = row['patient_id'].replace('ID_', '')
    study_instance_uid = row['study_instance_uid'].replace('ID_', '')
    
    folder_name = f"{patient_id}_{study_instance_uid}"
    folder_path = os.path.join(dicom_dir, folder_name)
    
    if os.path.exists(folder_path):
        slices = read_dicom_folder(folder_path)
        
        # Ensure we have enough slices to create the specified instances
        if len(slices) < depth * num_instances:
            print(f"Not enough slices for patient {patient_id}: found {len(slices)}, needed {depth * num_instances}")
            return None, None
        
        preprocessed_slices = [preprocess_slice(slice) for slice in slices]
        
        # Stack preprocessed slices into an array
        preprocessed_slices = np.stack(preprocessed_slices, axis=0)  # (num_slices, height, width, channels)
        
        # Reshape to (num_instances, depth, height, width, channels)
        # reshaped_slices = preprocessed_slices[:num_instances * depth].reshape(num_instances, depth, *preprocessed_slices.shape[1:])  # (num_instances, depth, height, width, channels)
        
        # Labeling remains consistent  
        label = 1 if row[['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']].any() else 0
        
        # return reshaped_slices, label
        return preprocessed_slices, label
    
    else:
        print(f"Folder not found: {folder_path}")
        return None, None

### Dataset Generator

In [12]:
class TrainDatasetGenerator(Dataset):
    """
    A custom dataset class for training data.
    """
    def __init__(self, data_dir, patient_scan_labels):
        self.data_dir = data_dir
        self.patient_scan_labels = patient_scan_labels

    def __len__(self):
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        row = self.patient_scan_labels.iloc[idx]
        patient_study_instance = row['patient_id'] + '_' + row['study_instance_uid']
        preprocessed_slices, label = process_patient_data(self.data_dir, row)
        
        if preprocessed_slices is not None:
            # Convert the list of numpy arrays to a single numpy array
            preprocessed_slices = np.array(preprocessed_slices)  # Convert to numpy array
            # Convert label to float32
            label = np.array(label, dtype=np.float32)
            return torch.tensor(preprocessed_slices, dtype=torch.float32), torch.tensor(label, dtype=torch.float32), patient_study_instance
        else:
            return None, None, None  # Handle the case where the folder is not found

class TestDatasetGenerator(Dataset):
    """
    A custom dataset class for testing data.
    """
    def __init__(self, data_dir, patient_scan_labels):
        self.data_dir = data_dir
        self.patient_scan_labels = patient_scan_labels

    def __len__(self):
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        row = self.patient_scan_labels.iloc[idx]
        patient_study_instance = row['patient_id'] + '_' + row['study_instance_uid']
        preprocessed_slices, label = process_patient_data(self.data_dir, row)
        
        if preprocessed_slices is not None:
            # Convert the list of numpy arrays to a single numpy array
            preprocessed_slices = np.array(preprocessed_slices)  # Convert to numpy array
            label = np.array(label, dtype=np.float32)
            return torch.tensor(preprocessed_slices, dtype=torch.float32), torch.tensor(label, dtype=torch.float32), patient_study_instance
        else:
            return None, None, None  # Handle the case where the folder is not found

In [13]:
def get_train_loader(dicom_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE):
    original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels)
    return DataLoader(original_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

def get_test_loader(dicom_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE):
    test_dataset = TestDatasetGenerator(dicom_dir, patient_scan_labels)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

## Feature Extraction using Vision Transformer

### Patch Embedding Layer

In [14]:
# class PatchEmbedding(nn.Module):
#     def __init__(self, img_size, in_channels, patch_size, embed_dim):
#         super(PatchEmbedding, self).__init__()
#         self.patch_size = patch_size
#         self.embed_dim = embed_dim
#         
#         self.conv = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
#                               kernel_size=patch_size, stride=patch_size, padding=0)
# 
#     def forward(self, x):
#         """
#         Forward pass to create patch embeddings.
# 
#         Args:
#             x (torch.Tensor): Input tensor of shape (N, C, H, W).
# 
#         Returns:
#             torch.Tensor: Output tensor of shape (N, num_patches, embed_dim).
#         """
#         x = self.conv(x)  # Apply convolution to create patches 
#         # After convolution: shape (N, embed_dim, H/patch_size, W/patch_size)
#         
#         x = x.flatten(2)  # Flatten patches into a sequence 
#         # After flattening: shape (N, embed_dim, num_patches)
# 
#         return x.transpose(1, 2)  # Rearrange dimensions for transformer input 
#         # Final output shape: (N, num_patches, embed_dim)

class PatchEmbedding(nn.Module):
    def __init__(self, img_size:int, in_channels:int, patch_size:int, embed_dim:int):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

        self.img_size = img_size
    
    def forward(self, x):
        # x shape: (batch_size * num_instances, channels, height, width)
        if len(x.shape) == 3: # If in_channels = 1
            x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv(x)  # Apply convolution to create patches
        
        # Flatten patches into a sequence of embeddings
        batch_size_num_instances = x.shape[0]
        num_patches = (self.img_size // self.patch_size) ** 2
        
        return x.view(batch_size_num_instances, num_patches, self.embed_dim)  # Shape: (batch * instances, num_patches, embed_dim)

### Multi-Head Self-Attention Layer

In [15]:
# 1. Create a class that inherits from nn.Module
class MultiheadSelfAttentionBlock(nn.Module):
    """Creates a multi-head self-attention block ("MSA block" for short).
    """
    # 2. Initialize the class with hyperparameters from Table 1
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0): # doesn't look like the paper uses any dropout in MSABlocks
        super().__init__()

        # 3. Create the Norm layer (LN)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # 4. Create the Multi-Head Attention (MSA) layer
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True) # does our batch dimension come first?

    # 5. Create a forward() method to pass the data through the layers
    def forward(self, x):
        x = self.layer_norm(x)
        assert type(x) == torch.Tensor, f"Input to MSA block is not a PyTorch tensor but a {type(x)}"
        attn_output, attn_weights = self.multihead_attn(query=x, # query embeddings
                                             key=x, # key embeddings
                                             value=x, # value embeddings
                                             need_weights=False) # do we need the weights or just the layer outputs?
        assert type(attn_output) == torch.Tensor, f"Output of MSA block is not a PyTorch tensor but a {type(attn_output)}"
        return attn_output

### Multi-Layer Perceptron Layer

In [16]:
# 1. Create a class that inherits from nn.Module
class MLPBlock(nn.Module):
    """Creates a layer normalized multilayer perceptron block ("MLP block" for short)."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden Size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 dropout:float=0.1): # Dropout from Table 3 for ViT-Base
        super().__init__()

        # 3. Create the Norm layer (LN)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # 4. Create the Multilayer perceptron (MLP) layer(s)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim,
                      out_features=mlp_size),
            nn.GELU(), # "The MLP contains two layers with a GELU non-linearity (section 3.1)."
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, # needs to take same in_features as out_features of layer above
                      out_features=embedding_dim), # take back to embedding_dim
            nn.Dropout(p=dropout) # "Dropout, when used, is applied after every dense layer.."
        )

    # 5. Create a forward() method to pass the data through the layers
    def forward(self, x):
        assert type(x) == torch.Tensor, f"Input to MLP block is not a PyTorch tensor but a {type(x)}"
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x

### Transformer Encoder Layer

In [17]:
# 1. Create a class that inherits from nn.Module
class TransformerEncoderBlock(nn.Module):
    """Creates a Transformer Encoder block."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 mlp_dropout:float=0.1, # Amount of dropout for dense layers from Table 3 for ViT-Base
                 attn_dropout:float=0): # Amount of dropout for attention layers
        super().__init__()

        # 3. Create MSA block (equation 2)
        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)

        # 4. Create MLP block (equation 3)
        self.mlp_block =  MLPBlock(embedding_dim=embedding_dim,
                                   mlp_size=mlp_size,
                                   dropout=mlp_dropout)

    # 5. Create a forward() method
    def forward(self, x):
        # 6. Create residual connection for MSA block (add the input to the output)
        msa = self.msa_block(x)
        assert type(msa) == torch.Tensor, f"Output of MSA block is not a PyTorch tensor but a {type(msa)}"
    
        x =  msa + x
        # 7. Create residual connection for MLP block (add the input to the output)
        x = self.mlp_block(x) + x

        return x

### Gaussian Process Layer

In [18]:
class GPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

### Attention Layer

In [19]:
class AttentionLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            # nn.Tanh(),
            # nn.ReLU(),
            nn.PReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x shape: (batch_size, num_instances, feature_dim)
        attention_weights = self.attention(x)
        weights = F.softmax(attention_weights, dim=1)

        return (x * weights).sum(dim=1), weights.squeeze(-1)

### Vision Transformer Model

In [20]:
# 1. Create a ViT class that inherits from nn.Module
class ViT(nn.Module):
    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 img_size:int=224, # Training resolution from Table 3 in ViT paper
                 in_channels:int=3, # Number of channels in input image
                 patch_size:int=16, # Patch size
                 num_transformer_layers:int=12, # Layers from Table 1 for ViT-Base
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0, # Dropout for attention projection
                 mlp_dropout:float=0.1, # Dropout for dense/MLP layers
                 embedding_dropout:float=0.1, # Dropout for patch and position embeddings
                 num_classes:int=1000): # Default for ImageNet but can customize this
        super().__init__() # don't forget the super().__init__()!

        # 3. Make the image size is divisible by the patch size
        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

        # 4. Calculate number of patches (height * width/patch^2)
        self.num_patches = (img_size * img_size) // patch_size**2
        
        self.embedding_dim = embedding_dim

        # 5. Create learnable class embedding (needs to go at front of sequence of patch embeddings)
        self.class_embedding = nn.Parameter(data=torch.randn(4 * 60, 1, embedding_dim),
                                            requires_grad=True)

        # 6. Create learnable position embedding
        self.position_embedding = nn.Parameter(data=torch.randn(4 * 60, self.num_patches+1, embedding_dim),
                                               requires_grad=True)

        # 7. Create embedding dropout value
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        # 8. Create patch embedding layer
        self.patch_embedding = PatchEmbedding(img_size=img_size,
                                              in_channels=in_channels,
                                              patch_size=patch_size,
                                              embed_dim=embedding_dim)

        # 9. Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential())
        # Note: The "*" means "all"
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                            num_heads=num_heads,
                                                                            mlp_size=mlp_size,
                                                                            mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])

        # 10. Create classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )
        
    def forward(self, x):
        # Get batch size and number of instances
        batch_size, num_instances = x.shape[0], x.shape[1]
    
        # Create class token embedding and expand it to match the batch size and number of instances
        class_token = self.class_embedding.expand(batch_size * num_instances, 1, self.embedding_dim)
    
        # Process each instance through the patch embedding layer
        x = self.patch_embedding(x.view(-1, *x.shape[2:]))  # Flatten instances for patch embedding
        assert x.shape[0] == batch_size * num_instances, f"Number of instances is incorrect: {x.shape[0]} instances, expected {batch_size * num_instances}"
        assert x.shape[1] == self.num_patches, f"Number of patches is incorrect: {x.shape[1]} patches, expected {self.num_patches}"
        
        
        # Concatenate class token with patch embeddings
        x = torch.cat((class_token, x), dim=1)  # Concatenate along the last dimension
        # Shape of x: (batch_size * num_instances, num_patches + 1, embed_dim)
    
        # Add position embedding
        x = self.position_embedding + x
    
        # Apply dropout
        x = self.embedding_dropout(x)
    
        # Pass through transformer encoder layers
        x = self.transformer_encoder(x)
    
        # Aggregate outputs from all instances (e.g., mean pooling)
        x = x.mean(dim=1)
        # Shape of x: (batch_size, num_instances, embed_dim)
        x = x.view(batch_size, num_instances, -1).mean(dim=1)
    
        # Run through classifier head
        x = self.classifier(x)
    
        return x.squeeze(-1)


## Training

### Model Training

In [21]:
def train_epoch(model, train_loader, criterion, optimizer, scheduler, device):
    model.train()
    total_loss = 0.0 
    predictions = []
    labels = []
    
    for batch_data, batch_labels, _ in train_loader:
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(batch_data)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        predictions.extend((outputs > 0.5).cpu().detach().numpy())
        labels.extend(batch_labels.cpu().numpy())
    return total_loss / len(train_loader), predictions, labels

### Model Validation

In [22]:
def validate(model, valid_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    predictions = []
    labels = []
    
    with torch.inference_mode():
        for batch_data, batch_labels, _ in valid_loader:
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)
            
            total_loss += loss.item()
            predictions.extend((outputs > 0.5).cpu().detach().numpy())
            labels.extend(batch_labels.cpu().numpy())
    return total_loss / len(valid_loader), predictions, labels

### Helper Functions for Training

In [23]:
def calculate_metrics(predictions, labels):
    """Calculate and return performance metrics."""
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions),
        "recall": recall_score(labels, predictions),
        "f1": f1_score(labels, predictions)
    }

def print_metrics(metrics):
    """Print performance metrics."""
    print(f"Accuracy: {metrics['accuracy']:.4f}, Precision: {metrics['precision']:.4f}, "
          f"Recall: {metrics['recall']:.4f}, F1: {metrics['f1']:.4f}")

def print_epoch_stats(epoch, num_epochs, phase, loss, metrics):
    """Print statistics for an epoch."""
    print(f"Epoch {epoch+1}/{num_epochs} - {phase.capitalize()}:")
    print(f"Loss: {loss:.4f}, Accuracy: {metrics['accuracy']:.4f}, "
          f"Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, "
          f"F1: {metrics['f1']:.4f}")

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, learning_rate, device='cuda'):
    """Train the model and return the best model based on validation accuracy."""
    model = model.to(device)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, 
                                              steps_per_epoch=len(train_loader), epochs=num_epochs)
    best_val_accuracy = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        # Training phase
        train_loss, train_predictions, train_labels = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)
        train_metrics = calculate_metrics(train_predictions, train_labels)
        print_epoch_stats(epoch, num_epochs, "train", train_loss, train_metrics)

        # Validation phase
        val_loss, val_predictions, val_labels = validate(model, val_loader, criterion, device)
        val_metrics = calculate_metrics(val_predictions, val_labels)
        print_epoch_stats(epoch, num_epochs, "validation", val_loss, val_metrics)

        # Save best model
        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            best_model_state = model.state_dict()

    # Load best model
    model.load_state_dict(best_model_state)
    return model

### Model Evaluation

In [24]:
def evaluate_model(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    predictions = []
    labels = []
    
    with torch.inference_mode():
        for batch_data, batch_labels, _ in test_loader:
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)
            
            total_loss += loss.item()
            predictions.extend((outputs > 0.5).cpu().detach().numpy())
            labels.extend(batch_labels.cpu().numpy())
    
    return total_loss / len(test_loader), predictions, labels

## Visualize Functions 

### ROC_AUC Curve

In [25]:
## Visualization Functions
def plot_roc_curve(model, data_loader, device):
    """Plot the ROC curve for the model predictions."""
    # predictions, labels = evaluate_model(model, data_loader, device)
    
    model.eval()
    labels = []
    predictions = []
    # with torch.no_grad():
    with torch.inference_mode():
        for batch_data, batch_labels, _ in data_loader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.float().to(device)

            outputs, attention_weights = model(batch_data)
            outputs = outputs.squeeze()
            predictions.extend(outputs.cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())
            
    fpr, tpr, _ = roc_curve(labels, predictions)
    roc_auc = auc(fpr, tpr)
    
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.show()

### Confusion Matrix

In [26]:
def plot_confusion_matrix(model, data_loader, device):
    """Plot the confusion matrix for the model predictions."""
    predictions, labels = evaluate_model(model, data_loader, device)
    
    cm = confusion_matrix(labels, predictions)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()
    plt.title('Confusion Matrix')
    plt.show()

## Main

In [27]:
## Data Processing Functions
def load_model(model_class, model_path):
    """Load a trained model from a file."""
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    model = model_class()
    try:
        state_dict = torch.load(model_path, map_location=torch.device('cuda'), weights_only=True)
        if not state_dict:
            raise ValueError(f"The state dictionary loaded from {model_path} is empty")
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading model from {model_path}: {str(e)}")
        print("Initializing model with random weights instead.")
        return model  # Return the model with random initialization

    return model.eval()


def get_test_results(model, test_loader, test_labels):
    """Get test results including patient information."""
    predictions, _ = evaluate_model(model, test_loader)
    
    results = []
    for i, row in enumerate(test_labels.itertuples(index=False)):
        result = {col: getattr(row, col) for col in test_labels.columns}
        result['prediction'] = predictions[i]
        results.append(result)
    
    return pd.DataFrame(results)

In [28]:
def main(mode='train'):
    train_labels, val_labels, test_labels = split_dataset(patient_scan_labels, test_size=TEST_SIZE)
    # Decrease the number of samples for training
    train_labels = train_labels.sample(n=200, random_state=42)
    train_loader = get_train_loader(dicom_dir, train_labels)
    val_loader = get_train_loader(dicom_dir, val_labels)
    test_loader = get_test_loader(dicom_dir, test_labels)
    
    # Initialize the model, criterion, and optimizer
    model = ViT(img_size=IMG_SIZE, 
                in_channels=CHANNELS,
                patch_size=PATCH_SIZE,
                num_transformer_layers=NUM_LAYERS,
                embedding_dim=EMBEDDING_DIM,
                mlp_size=MLP_SIZE,
                num_heads=NUM_HEADS,
                num_classes=NUM_CLASSES)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    if mode == 'train':
        trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, LEARNING_RATE, device)
        
        torch.save(trained_model.state_dict(), MODEL_PATH)
    
    # Load best model
    trained_model = model.load_state_dict(torch.load(MODEL_PATH))
    
    # Evaluate model 
    test_loss, test_predictions, test_labels = evaluate_model(trained_model, test_loader, criterion, device)
    
    # Calculate performance metrics
    test_metrics = calculate_metrics(test_predictions, test_labels)
    
    # Print performance metrics
    print_metrics(test_metrics)
    # Visualizations 
    plot_roc_curve(trained_model, test_loader, device)
    plot_confusion_matrix(trained_model, test_loader, device)
    
    if mode == 'train':
        # Select only the required columns
        required_columns = ['patient_id', 'study_instance_uid', 'label']
        temp_test_labels = test_labels[required_columns]
        
        # Save results
        results_df = get_test_results(trained_model, test_loader, temp_test_labels)
        results_df.to_csv('results/results.csv', index=False)
        print(results_df.head())   
        
if __name__ == '__main__':
    main(mode='train')

TypeError: layer_norm(): argument 'input' (position 1) must be Tensor, not torch.return_types.max

## Result