## Import Lib

In [1]:
!pip install gpytorch
!pip install torchsummary
!pip install iterative-stratification

Collecting gpytorch
  Downloading gpytorch-1.13-py3-none-any.whl.metadata (8.0 kB)
Collecting jaxtyping==0.2.19 (from gpytorch)
  Downloading jaxtyping-0.2.19-py3-none-any.whl.metadata (5.7 kB)
Collecting linear-operator>=0.5.3 (from gpytorch)
  Downloading linear_operator-0.5.3-py3-none-any.whl.metadata (15 kB)
Downloading gpytorch-1.13-py3-none-any.whl (277 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m277.8/277.8 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxtyping-0.2.19-py3-none-any.whl (24 kB)
Downloading linear_operator-0.5.3-py3-none-any.whl (176 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.4/176.4 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jaxtyping, linear-operator, gpytorch
Successfully installed gpytorch-1.13 jaxtyping-0.2.19 linear-operator-0.5.3
Collecting iterative-stratification
  Downloading iterative_stratification-0.1.9-py3-none-any.whl.metadata (1.3 kB)
Downloa

In [2]:
import os.path

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
import pandas as pd
import random
import numpy as np
import cv2
import pydicom
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.nn import functional as F
import gpytorch
from gpytorch.kernels import ScaleKernel, RBFKernel, MaternKernel
from torchsummary import summary

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Device: {device}')

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

def seed_everything(seed=39):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

import warnings
from sklearn.exceptions import UndefinedMetricWarning

warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

Device: cuda
GPU: Tesla P100-PCIE-16GB is available.


In [3]:
# Kaggle and local switch
KAGGLE = os.path.exists('/kaggle')
HPC = os.path.exists('/media02/tdhoang01')

if KAGGLE:
    print("Running on Kaggle")
    ROOT_DIR = '/kaggle/input/'
    DATA_DIR = ROOT_DIR + 'rsna-ich-mil/'
    DICOM_DIR = DATA_DIR + 'rsna-ich-mil/'
    CSV_PATH = DICOM_DIR + 'training_dataset_2_redundancy_1150_for_kaggle.csv'
elif HPC:
    print("Running on HPC")
    DATA_DIR = '/media02/tdhoang01/21127112-21127734/data/'
    DICOM_DIR = DATA_DIR + 'rsna-ich-mil/'
    CSV_PATH = DATA_DIR + 'training_dataset_1150_redundancy.csv'
else: # Local
    print("Running locally")
    DATA_DIR = '../rsna-mil-training/'
    DICOM_DIR = DATA_DIR + 'rsna-mil-training/'
    CSV_PATH = DATA_DIR + 'training_1000_scan_subset.csv'

dicom_dir = DICOM_DIR if KAGGLE or HPC else DATA_DIR
# Load patient scan labels
patient_scan_labels = pd.read_csv(CSV_PATH)

Running on Kaggle


In [4]:
# Accessing constants from config
HEIGHT = 224
WIDTH = 224
CHANNELS = 1

TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4
TEST_BATCH_SIZE = 2
TEST_SIZE = 0.15
VALID_SIZE = 0.25

TRAINING_TYPE = 'end_to_end'
GP_MODEL = 'multi_task' # or 'single_task'
GP_KERNEL = 'rbf'

MAX_SLICES = 28
SHAPE = (HEIGHT, WIDTH, CHANNELS)

NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
INDUCING_POINTS = 50
THRESHOLD = 0.5

NUM_CLASSES = 6

TARGET_LABELS = 'intraparenchymal'

MODEL_PATH = 'results/trained_model.pth'
DEVICE = 'cuda'

PROJECTION_LOCATION = 'after_gp'
PROJECTION_HIDDEN_DIM = 256
PROJECTION_OUTPUT_DIM = 128

ATTENTION_HIDDEN_DIM = 128

## Hyperparameters definition

In [5]:
# # Hyperparameters
batch_size = 64
learning_rate = 1e-4
num_epochs = 10
img_size = 224  # MNIST images are 28x28 pixels
num_classes = 6  # Digits from 0 to 9
patch_size = 14  # Size of each patch (7x7)
embedding_dim = 64  # Dimensionality of the embeddings
num_heads = 4  # Number of attention heads
num_layers = 6  # Number of transformer layers
dropout_rate = 0.5  # Dropout rate for regularization

## Prepare the Dataset

In [6]:
def split_dataset(patient_scan_labels, test_size=TEST_SIZE, val_size=VALID_SIZE, random_state=42):
    # Extract the labels from the DataFrame
    labels = patient_scan_labels['patient_label']
    if test_size > 0:
        # First, split off the test set
        train_val_labels, test_labels = train_test_split(
            patient_scan_labels,
            test_size=test_size,
            stratify=labels,
            random_state=random_state
        )

        # Calculate the validation size relative to the train_val set
        val_size_adjusted = val_size / (1 - test_size)

        # Split the train_val set into train and validation sets
        train_labels, val_labels = train_test_split(
            train_val_labels,
            test_size=val_size_adjusted,
            stratify=train_val_labels['patient_label'],
            random_state=random_state
        )
    else:
        train_labels, val_labels = train_test_split(
            patient_scan_labels,
            test_size=val_size,
            stratify=labels,
            random_state=random_state
        )
        test_labels = None

    return train_labels, val_labels, test_labels

def split_dataset_for_multilabel(patient_scan_labels, test_size=0.15, val_size=0.25, random_state=42):
    # Extract the labels from the DataFrame
    labels = patient_scan_labels[['patient_any', 'patient_epidural', 'patient_intraparenchymal',
                                  'patient_intraventricular', 'patient_subarachnoid', 'patient_subdural']].values

    if test_size != 0:
        # First split: train + test
        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
        train_idx, test_idx = next(msss.split(patient_scan_labels, labels))

        train_labels = patient_scan_labels.iloc[train_idx]
        test_labels = patient_scan_labels.iloc[test_idx]

        # Second split: train + validation
        msss_val = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=random_state)
        train_idx, val_idx = next(msss_val.split(train_labels, labels[train_idx]))

        train_labels_final = train_labels.iloc[train_idx]
        val_labels = train_labels.iloc[val_idx]

    else:
        # Only split into train and validation if test_size is 0
        msss_val = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=random_state)
        train_idx, val_idx = next(msss_val.split(patient_scan_labels, labels))

        train_labels_final = patient_scan_labels.iloc[train_idx]
        val_labels = patient_scan_labels.iloc[val_idx]
        test_labels = patient_scan_labels_test

    return train_labels_final, val_labels, test_labels

In [7]:
def correct_dcm(dcm):
    x = dcm.pixel_array + 1000
    px_mode = 4096
    x[x>=px_mode] = x[x>=px_mode] - px_mode
    dcm.PixelData = x.tobytes()
    dcm.RescaleIntercept = -1000

def create_bone_mask(dcm):
    # Assuming dcm.pixel_array contains the HU values
    hu_values = dcm.pixel_array

    # Create a mask for bone regions
    # bone_mask = (hu_values >= 1000) & (hu_values <= 1200)
    bone_mask = (hu_values >= 1000) & (hu_values <= 1200)
    return bone_mask

def extract_bone_mask(dcm):
    # Create the bone mask
    bone_mask = create_bone_mask(dcm)

    # Extract the bone mask from the image
    hu_values = dcm.pixel_array.copy()
    # hu_values[bone_mask] = 0
    hu_values[~bone_mask] = 0

    # Update the DICOM pixel data
    dcm.PixelData = hu_values.tobytes()

def window_image(dcm, window_center, window_width):
    if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):
        correct_dcm(dcm)
    # extract_bone_mask(dcm)
    img = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept

    # Resize
    img = cv2.resize(img, SHAPE[:2], interpolation = cv2.INTER_LINEAR)

    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    img = np.clip(img, img_min, img_max)
    return img

def bsb_window(dcm):
    brain_img = window_image(dcm, 40, 80)
    subdural_img = window_image(dcm, 80, 200)
    soft_img = window_image(dcm, 40, 380)

    brain_img = (brain_img - 0) / 80
    subdural_img = (subdural_img - (-20)) / 200
    soft_img = (soft_img - (-150)) / 380

    if CHANNELS == 3:
        bsb_img = np.stack([brain_img, subdural_img, soft_img], axis=-1)
    else:
        bsb_img = brain_img
    return bsb_img.astype(np.float16)

def preprocess_slice(slice, target_size=(HEIGHT, WIDTH)):
    # Check if type of slice is dicom or an empty numpy array
    if type(slice) == np.ndarray:
        slice = resize(slice, target_size, anti_aliasing=True)
        multichannel_slice = np.stack([slice, slice, slice], axis=-1)
        if CHANNELS == 3:
            return multichannel_slice.astype(np.float16)
        else:
            return slice.astype(np.float16)
    else:
        slice = bsb_window(slice)
        return slice.astype(np.float16)

def read_dicom_folder(folder_path, max_slices=MAX_SLICES):
    # Filter and sort DICOM files directly based on ImagePositionPatient
    dicom_files = sorted(
        [f for f in os.listdir(folder_path) if f.endswith(".dcm")],
        key=lambda f: float(pydicom.dcmread(os.path.join(folder_path, f)).ImagePositionPatient[2])
    )[:max_slices]

    # Read and store slices
    slices = [pydicom.dcmread(os.path.join(folder_path, f)) for f in dicom_files]

    # Pad with black images if necessary
    if len(slices) < max_slices:
        black_image = np.zeros_like(slices[0].pixel_array)
        slices += [black_image] * (max_slices - len(slices))

    return slices[:max_slices]

def process_patient_data(dicom_dir, row):
    patient_id = row['patient_id'].replace('ID_', '')
    study_instance_uid = row['study_instance_uid'].replace('ID_', '')

    folder_name = f"{patient_id}_{study_instance_uid}"
    folder_path = os.path.join(dicom_dir, folder_name)

    if os.path.exists(folder_path):
        try:
            slices = read_dicom_folder(folder_path)

            preprocessed_slices = [torch.tensor(preprocess_slice(slice), dtype=torch.float32) for slice in slices]  # Convert to tensor

            # Stack preprocessed slices into an array
            preprocessed_slices = torch.stack(preprocessed_slices, dim=0)  # (num_slices, height, width, channels)

            # Labels are already in list form, so just convert them to a tensor
            labels = torch.tensor(row['labels'], dtype=torch.long)

            # Fill labels with 0s if necessary
            if len(preprocessed_slices) > len(labels):
                padded_labels = torch.zeros(len(preprocessed_slices), dtype=torch.long)
                padded_labels[:len(labels)] = labels
            else:
                padded_labels = labels[:len(preprocessed_slices)]

            return preprocessed_slices, padded_labels
        except Exception as e:
            print(f"Error processing patient data: {e} ; dicom_dir: {dicom_dir}, folder_name: {folder_name}")
            return None, None


    else:
        print(f"Folder not found: {folder_name}")
        return None, None

In [8]:
# Full Dataset
class MedicalScanDataset:
    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        self.data_dir = data_dir
        self.dataset = self._parse_patient_scan_labels(patient_scan_labels)
        self.augmentor = augmentor

    def _parse_patient_scan_labels(self, patient_scan_labels):
        """Parse and validate patient scan labels."""
        patient_scan_labels['filename'] = patient_scan_labels['filename'].apply(
            lambda x: eval(x) if isinstance(x, str) else x
        )
        patient_scan_labels['labels'] = patient_scan_labels['labels'].apply(
            lambda x: eval(x) if isinstance(x, str) else x
        )
        # Convert multi-label columns from string representation to actual lists
        multi_label_columns = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']

        for column in multi_label_columns:
            patient_scan_labels[column] = patient_scan_labels[column].apply(
                lambda x: eval(x) if isinstance(x, str) else x
            )

        patient_scan_labels['patient_label'] = patient_scan_labels['patient_label'].astype(bool)
        return patient_scan_labels

    def _process_patient_data(self, row):
        """Process patient data to get preprocessed slices and labels."""
        return process_patient_data(self.data_dir, row)

    def __len__(self):
        return len(self.dataset) * (self.augmentor.levels if self.augmentor else 1)

    def __getitem__(self, idx):
        patient_idx = idx // (self.augmentor.levels if self.augmentor else 1)
        aug_level = idx % (self.augmentor.levels if self.augmentor else 1)

        row = self.dataset.iloc[patient_idx]
        preprocessed_slices, labels = self._process_patient_data(row)

        preprocessed_slices = self._prepare_tensor(preprocessed_slices, aug_level if self.augmentor else None)
        patient_label = torch.tensor(bool(row['patient_label']), dtype=torch.uint8)

        multi_class_labels = torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.uint8)
        if any(row['any']):
            multi_class_labels[0] = 1
        if any(row['epidural']):
            multi_class_labels[1] = 1
        if any(row['intraparenchymal']):
            multi_class_labels[2] = 1
        if any(row['intraventricular']):
            multi_class_labels[3] = 1
        if any(row['subarachnoid']):
            multi_class_labels[4] = 1
        if any(row['subdural']):
            multi_class_labels[5] = 1

        return preprocessed_slices, labels, patient_label, multi_class_labels

    def _prepare_tensor(self, preprocessed_slices, aug_level):
        # Convert to numpy array and then to torch tensor
        preprocessed_slices = np.asarray(preprocessed_slices, dtype=np.float32)
        preprocessed_slices = torch.tensor(preprocessed_slices, dtype=torch.float32)

        # Add an additional dimension for channel if it's missing (no augmentor)
        if preprocessed_slices.ndim == 3:
            preprocessed_slices = preprocessed_slices.unsqueeze(1)  # shape: [slices, 1, H, W]

        # Apply augmentation if augmentor is specified
        if self.augmentor and aug_level is not None:
            if preprocessed_slices.ndim == 4:  # Ensure it has the [slices, channels, H, W] format
                return torch.stack([self.augmentor.apply_transform(img, aug_level) for img in preprocessed_slices])

        return preprocessed_slices  # Return without augmentation if augmentor is None


In [9]:
class TrainDatasetGenerator(MedicalScanDataset):
    """Dataset class for training medical scan data."""
    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        super().__init__(data_dir, patient_scan_labels, augmentor)

class TestDatasetGenerator(MedicalScanDataset):
    """Dataset class for testing medical scan data."""
    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        super().__init__(data_dir, patient_scan_labels, augmentor)

def get_train_loader(dicom_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE):
    # original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=augmentor)
    original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)
    return DataLoader(original_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

def get_test_loader(dicom_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE):
    test_dataset = TestDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

In [10]:
def calculate_metrics(predictions, labels):
    """Calculate and return performance metrics."""
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, average='macro'),
        "recall": recall_score(labels, predictions, average='macro'),
        "f1": f1_score(labels, predictions, average='macro')
    }

def print_epoch_stats(epoch, num_epochs, phase, loss, metrics):
    """Print statistics for an epoch."""
    print(f"Epoch {epoch+1}/{num_epochs} - {phase.capitalize()}:")
    print(f"Loss: {loss:.4f}, Accuracy: {metrics['accuracy']:.4f}, "
          f"Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, "
          f"F1: {metrics['f1']:.4f}")

In [11]:
weights = torchvision.models.ViT_B_16_Weights.DEFAULT
transform = weights.transforms()

# train_labels, val_labels, test_labels = split_dataset(patient_scan_labels, test_size=TEST_SIZE)
train_labels, val_labels, test_labels = split_dataset_for_multilabel(patient_scan_labels, test_size=TEST_SIZE)

train_loader = get_train_loader(dicom_dir, train_labels, batch_size=TRAIN_BATCH_SIZE)
val_loader = get_train_loader(dicom_dir, val_labels, batch_size=VALID_BATCH_SIZE)
test_loader = get_test_loader(dicom_dir, test_labels, batch_size=TEST_BATCH_SIZE)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  patient_scan_labels['filename'] = patient_scan_labels['filename'].apply(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  patient_scan_labels['labels'] = patient_scan_labels['labels'].apply(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  patient_scan_labels[column] = patient_scan_labels[column].apply

## Patch Embedding Layer
* A module that converts input images into patch embeddings suitable for a Vision Transformer.

* This module takes an input image, divides it into patches, and then embeds each patch into a vector of 
specified dimensionality using a convolutional layer.

* Attributes:
    patch_size (int): The size of each square patch.
    embed_dim (int): The dimensionality of the output embedding vector for each patch.
    conv (nn.Conv2d): A convolutional layer that extracts patches from the input image.

* Input Shape:
    - Input tensor `x`: Shape (N, C, H, W)
        - N: Batch size
        - C: Number of channels (1 for grayscale images like MNIST)
        - H: Height of the input image
        - W: Width of the input image

* Output Shape:
    - Output tensor: Shape (N, num_patches, embed_dim)
        - N: Batch size
        - num_patches: The number of patches extracted from the image, calculated as:
          $$ \text{num\_patches} = \left(\frac{H}{\text{patch\_size}}\right) \times \left(\frac{W}{\text{patch\_size}}\right) $$
        - embed_dim: The dimensionality of the embedding for each patch.

* Methods:
    forward(x): Forward pass to compute the patch embeddings from input images.


In [12]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, patch_size, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                              kernel_size=patch_size, stride=patch_size, padding=0)

    def forward(self, x):
        """
        Forward pass to create patch embeddings.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: Output tensor of shape (N, num_patches, embed_dim).
        """
        x = self.conv(x)  # Apply convolution to create patches 
        # After convolution: shape (N, embed_dim, H/patch_size, W/patch_size)
        
        x = x.flatten(2)  # Flatten patches into a sequence 
        # After flattening: shape (N, embed_dim, num_patches)

        return x.transpose(1, 2)  # Rearrange dimensions for transformer input 
        # Final output shape: (N, num_patches, embed_dim)

## Multihead Self-Attention Layer

In [13]:
# 1. Create a class that inherits from nn.Module
class MultiheadSelfAttentionBlock(nn.Module):
    """Creates a multi-head self-attention block ("MSA block" for short).
    """
    # 2. Initialize the class with hyperparameters from Table 1
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0): # doesn't look like the paper uses any dropout in MSABlocks
        super().__init__()

        # 3. Create the Norm layer (LN)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # 4. Create the Multi-Head Attention (MSA) layer
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True) # does our batch dimension come first?

    # 5. Create a forward() method to pass the data through the layers
    def forward(self, x):
        x = self.layer_norm(x)
        attn_output, attn_weights = self.multihead_attn(query=x, # query embeddings
                                             key=x, # key embeddings
                                             value=x, # value embeddings
                                             need_weights=True) # do we need the weights or just the layer outputs?
        return attn_output

## MLP Block Layer

In [14]:
# 1. Create a class that inherits from nn.Module
class MLPBlock(nn.Module):
    """Creates a layer normalized multilayer perceptron block ("MLP block" for short)."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden Size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 dropout:float=0.1): # Dropout from Table 3 for ViT-Base
        super().__init__()

        # 3. Create the Norm layer (LN)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # 4. Create the Multilayer perceptron (MLP) layer(s)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim,
                      out_features=mlp_size),
            nn.GELU(), # "The MLP contains two layers with a GELU non-linearity (section 3.1)."
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, # needs to take same in_features as out_features of layer above
                      out_features=embedding_dim), # take back to embedding_dim
            nn.Dropout(p=dropout) # "Dropout, when used, is applied after every dense layer.."
        )

    # 5. Create a forward() method to pass the data through the layers
    def forward(self, x):
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x

## Transformer Encoder Block

In [15]:
# 1. Create a class that inherits from nn.Module
class TransformerEncoderBlock(nn.Module):
    """Creates a Transformer Encoder block."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 mlp_dropout:float=0.1, # Amount of dropout for dense layers from Table 3 for ViT-Base
                 attn_dropout:float=0): # Amount of dropout for attention layers
        super().__init__()

        # 3. Create MSA block (equation 2)
        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)

        # 4. Create MLP block (equation 3)
        self.mlp_block =  MLPBlock(embedding_dim=embedding_dim,
                                   mlp_size=mlp_size,
                                   dropout=mlp_dropout)

    # 5. Create a forward() method
    def forward(self, x):

        # 6. Create residual connection for MSA block (add the input to the output)
        # x =  self.msa_block(x) + x
        attn_output = self.msa_block(x)
        
        assert type(attn_output) == torch.Tensor, "The MSA block output should be a PyTorch tensor."
        x = attn_output + x

        # 7. Create residual connection for MLP block (add the input to the output)
        x = self.mlp_block(x) + x

        return x

## Vision Transformer Model

In [16]:
class AttentionLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x shape: (batch_size, num_instances, feature_dim)
        attention_weights = self.attention(x)
        weights = F.softmax(attention_weights, dim=1)

        return (x * weights).sum(dim=1), weights.squeeze(-1)

# 1. Create a ViT class that inherits from nn.Module
class ViT(nn.Module):
    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 img_size:int=224, # Training resolution from Table 3 in ViT paper
                 in_channels:int=3, # Number of channels in input image
                 patch_size:int=16, # Patch size
                 num_transformer_layers:int=12, # Layers from Table 1 for ViT-Base
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0, # Dropout for attention projection
                 mlp_dropout:float=0.1, # Dropout for dense/MLP layers
                 embedding_dropout:float=0.1, # Dropout for patch and position embeddings
                 num_classes:int=1000): # Default for ImageNet but can customize this
        super().__init__() # don't forget the super().__init__()!

        # 3. Make the image size is divisible by the patch size
        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

        # 4. Calculate number of patches (height * width/patch^2)
        self.num_patches = (img_size * img_size) // patch_size**2

        # 5. Create learnable class embedding (needs to go at front of sequence of patch embeddings)
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)

        # 6. Create learnable position embedding
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
                                               requires_grad=True)

        # 7. Create embedding dropout value
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        # 8. Create patch embedding layer
        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                              patch_size=patch_size,
                                              embed_dim=embedding_dim)

        # 9. Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential())
        # Note: The "*" means "all"
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                            num_heads=num_heads,
                                                                            mlp_size=mlp_size,
                                                                            mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])

        # 10. Create classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

        # Attention layer
        self.attention_layer = AttentionLayer(embedding_dim * (self.num_patches + 1),
                                                hidden_dim=128)
        self.cls_attention = nn.Linear(embedding_dim * (self.num_patches + 1), num_classes)

    # 11. Create a forward() method
    def forward(self, x):
        batch_size, num_instances, channels, height, width = x.size()
        class_token = self.class_embedding.expand(batch_size * num_instances, -1, -1)
        x = x.view(batch_size * num_instances, channels, height, width)
        x = self.patch_embedding(x)
        x = torch.cat((class_token, x), dim=1)
        x = self.position_embedding + x
        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x) # Shape: (batch_size * num_instances, num_patches + 1, hidden_size)
        x = x.view(batch_size, num_instances, -1)
        x = x.max(dim=1).values
        x = self.cls_attention(x)
        return x.squeeze(-1)

In [17]:
class PGLikelihood(gpytorch.likelihoods._OneDimensionalLikelihood):
    # contribution to Eqn (10) in Reference [1].
    def expected_log_prob(self, target, input, *args, **kwargs):
        mean, variance = input.mean, input.variance
        # Compute the expectation E[f_i^2]
        raw_second_moment = variance + mean.pow(2)

        # Translate targets to be -1, 1
        target = target.to(mean.dtype).mul(2.).sub(1.)

        # We detach the following variable since we do not want
        # to differentiate through the closed-form PG update.
        c = raw_second_moment.detach().sqrt()
        half_omega = 0.25 * torch.tanh(0.5 * c) / c
        res = 0.5 * target * mean - half_omega * raw_second_moment
        res = res.sum(dim=-1)

        return res

    # define the likelihood
    def forward(self, function_samples):
        return torch.distributions.Bernoulli(logits=function_samples)

    # define the marginal likelihood using Gauss Hermite quadrature
    def marginal(self, function_dist):
        prob_lambda = lambda function_samples: self.forward(function_samples).probs
        probs = self.quadrature(prob_lambda, function_dist)
        return torch.distributions.Bernoulli(probs=probs)


class SingletaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points, kernel_type='rbf', nu=2.5):
        """
        Args:
            inducing_points (torch.Tensor):
            kernel_type (str):
            nu (float):
        """
        variational_distribution = gpytorch.variational.NaturalVariationalDistribution(inducing_points.size(0))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super(SingletaskGPModel, self).__init__(variational_strategy)

        self.mean_module = gpytorch.means.ConstantMean()

        if kernel_type == 'rbf':
            self.covar_module = ScaleKernel(RBFKernel())
        elif kernel_type == 'matern_kernel':
            self.covar_module = ScaleKernel(MaternKernel(nu=nu))
        else:
            raise ValueError("kernel_type must be either 'rbf' or 'matern_kernel'")

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, num_latents, num_tasks, hidden_dim=512):
        # Let's use a different set of inducing points for each latent function
        inducing_points = torch.rand(num_latents, hidden_dim, 1)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [18]:
class BiGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidirectional=True, dropout=0.6):
        super(BiGRU, self).__init__()
        self.bigru = nn.GRU(input_size, hidden_size, num_layers,
                             batch_first=True, bidirectional=bidirectional, dropout=dropout)
        self.num_layers = num_layers

    def forward(self, x, hidden_states=None):
        if hidden_states is None:
            num_directions = 2 if self.bigru.bidirectional else 1
            hidden_states = torch.zeros(num_directions * self.bigru.num_layers,
                                        x.size(0), self.bigru.hidden_size,
                                        device=x.device)
        gru_out, _ = self.bigru(x, hidden_states)
        return gru_out

class InstanceAttention(nn.Module):
    def __init__(self, input_dim, num_classes=1):
        super(InstanceAttention, self).__init__()
        # Feature-level attention
        self.feature_attention = nn.Linear(input_dim, input_dim)
        # Slice-level attention
        self.slice_attention = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # x shape: (batch_size, num_instances, input_dim)

        # Feature-level attention
        feature_scores = self.feature_attention(x)  # (batch_size, num_instances, input_dim)
        feature_attention_weights = torch.nn.functional.softmax(feature_scores, dim=2)
        # Apply feature-level attention
        x_weighted_features = x * feature_attention_weights  # (batch_size, num_instances, input_dim)

        # Slice-level attention
        slice_scores = self.slice_attention(x_weighted_features).squeeze(-1)  # (batch_size, num_instances)
        slice_attention_weights = torch.nn.functional.softmax(slice_scores, dim=1)
        slice_attention_weights = slice_attention_weights.unsqueeze(-1)  # (batch_size, num_instances, 1)

        return x_weighted_features, slice_attention_weights

class CNN_ATT_GRU(nn.Module):
    def __init__(self, num_layers=2, input_channels=1,
                 cnn_feature_size=512, gru_hidden_size=256, num_classes=1, dropout_gru=0.3, dropout_fc=0.5, cnn='resnet'):
        super(CNN_ATT_GRU, self).__init__()
        self.cnn = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.DEFAULT)
        self.cnn.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.cnn.fc = nn.Identity()

        self.attention = AttentionLayer(input_dim=cnn_feature_size, hidden_dim=cnn_feature_size)
        self.bigru = BiGRU(input_size=cnn_feature_size,
                           hidden_size=gru_hidden_size,
                           num_layers=num_layers,
                           dropout=dropout_gru)
        self.fc = nn.Sequential(
            nn.Linear(2 * gru_hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout_fc),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(dropout_fc),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_fc),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        batch_size, num_instances, channels, height, width = x.size()
        x = x.view(batch_size * num_instances, channels, height, width)
        x = self.cnn(x)
        features = x.view(batch_size, num_instances, -1)

        gru_features = self.bigru(features)
        attention_features, _ = self.attention(gru_features)

        # Pass aggregated features through FC layers
        output = self.fc(attention_features)

        return output

## Model & Optimizer Definition

In [20]:
model = ViT(
    img_size=img_size,
    in_channels=1,
    patch_size=patch_size,
    embedding_dim=embedding_dim,
    num_heads=num_heads,
    num_transformer_layers=num_layers,
    num_classes=num_classes
).to(device)

# model = CNN_ATT_GRU(num_layers=2, input_channels=1,
#                     cnn_feature_size=512, gru_hidden_size=256, num_classes=num_classes, dropout_gru=0.3, dropout_fc=0.25, cnn='resnet').to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()
likelihood = PGLikelihood().to(device)

# summary(model, input_size=(TRAIN_BATCH_SIZE * MAX_SLICES, 1, 224, 224))

## Training Step

In [21]:
# Function to calculate metrics for multi-label classification
def calculate_metrics(preds, labels):
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average='samples')  # Use 'samples' for multi-label
    recall = recall_score(labels, preds, average='samples')  # Use 'samples' for multi-label
    f1 = f1_score(labels, preds, average='samples')  # Use 'samples' for multi-label
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

# Function to print epoch statistics
def print_epoch_stats(epoch, num_epochs, phase, loss, metrics):
    print(f'Epoch [{epoch+1}/{num_epochs}], {phase.capitalize()}')
    print(f'Loss: {loss:.4f}, '
          f'Accuracy: {metrics["accuracy"]:.4f}, Precision: {metrics["precision"]:.4f}, '
          f'Recall: {metrics["recall"]:.4f}, F1: {metrics["f1"]:.4f}')

In [22]:
# Ensure the device is set to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # Move the model to the appropriate device

# Use BCEWithLogitsLoss for multi-label classification
criterion = nn.BCEWithLogitsLoss()

# Initialize variables to track the best validation performance
best_val_acc = 0  # Use infinity as the initial best loss
best_model_state_dict = None  # To store the best model's state dict

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    # Training phase
    for imgs, _, labels, multi_labels in train_loader:
        optimizer.zero_grad()

        imgs, multi_labels = imgs.to(device), multi_labels.to(device)  # Move data to GPU

        outputs = model(imgs)
        loss = criterion(outputs, multi_labels.float())  # Use multi_labels for loss calculation
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        total_loss += loss.item()
        preds = outputs.ge(0.5).float()  # Apply sigmoid and threshold
        all_preds.append(preds.cpu())
        all_labels.append(multi_labels.cpu())

    avg_loss = total_loss / len(train_loader)

    # Concatenate all predictions and labels for metric calculation
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    train_metrics = calculate_metrics(all_preds.numpy(), all_labels.numpy())

    print_epoch_stats(epoch, num_epochs, "train", avg_loss, train_metrics)

    # Validation phase
    model.eval()
    val_total_loss = 0
    val_all_preds = []
    val_all_labels = []

    with torch.no_grad():  # Disable gradient calculation for validation
        for imgs, _, labels, multi_labels in val_loader:
            imgs, multi_labels = imgs.to(device), multi_labels.to(device)  # Move data to GPU
            outputs = model(imgs)  # Forward pass
            loss = criterion(outputs, multi_labels.float())  # Use multi_labels for loss calculation
            val_total_loss += loss.item()
            preds = outputs.ge(0.5).float()  # Apply sigmoid and threshold
            val_all_preds.append(preds.cpu())
            val_all_labels.append(multi_labels.cpu())

    val_avg_loss = val_total_loss / len(val_loader)

    # Concatenate validation predictions and labels for metric calculation
    val_all_preds = torch.cat(val_all_preds)
    val_all_labels = torch.cat(val_all_labels)

    val_metrics = calculate_metrics(val_all_preds.numpy(), val_all_labels.numpy())

    print_epoch_stats(epoch, num_epochs, "validation", val_avg_loss, val_metrics)

    # Check if the current validation accuracy is the best so far
    if val_metrics['accuracy'] > best_val_acc:
        best_val_acc = val_metrics['accuracy']
        best_model_state_dict = model.state_dict()

# After training, load the best model for testing
print(f'Best validation accuracy: {best_val_acc:.4f}')
model.load_state_dict(best_model_state_dict)
model.to(device)

Epoch [1/1], Train
Loss: 0.5693, Accuracy: 0.4008, Precision: 0.0972, Recall: 0.0408, F1: 0.0561
Epoch [1/1], Validation
Loss: 0.5615, Accuracy: 0.5595, Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Best validation accuracy: 0.5595


ViT(
  (embedding_dropout): Dropout(p=0.1, inplace=False)
  (patch_embedding): PatchEmbedding(
    (conv): Conv2d(1, 64, kernel_size=(14, 14), stride=(14, 14))
  )
  (transformer_encoder): Sequential(
    (0): TransformerEncoderBlock(
      (msa_block): MultiheadSelfAttentionBlock(
        (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
      )
      (mlp_block): MLPBlock(
        (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=64, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=3072, out_features=64, bias=True)
          (4): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (1): TransformerEncoderBlock(
      (msa_block): Multih

## Attention Map Visualization

In [23]:
import matplotlib.pyplot as plt

def plot_attention_map(attn_weights, img_shape):
    # Assuming attn_weights is of shape (num_heads, seq_length, seq_length)
    attn_map = attn_weights.mean(dim=0)  # Average over heads

    plt.figure(figsize=(8, 8))
    plt.imshow(attn_map.detach().cpu().numpy(), cmap='viridis')
    plt.colorbar()
    plt.title('Attention Map')
    plt.show()

## Evaluation Step

In [24]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss

model.eval()
all_preds = []
all_labels = []

with torch.inference_mode():
    for imgs, _, _, multi_labels in test_loader:
        imgs, multi_labels = imgs.to(device), multi_labels.to(device)  # Move data to GPU
        outputs = model(imgs)  # Forward pass

        # Apply sigmoid and threshold to get binary predictions
        predicted = outputs.ge(0.5).float()

        # Collect all predictions and labels for computing metrics
        all_preds.append(predicted.cpu())
        all_labels.append(multi_labels.cpu())

# Concatenate all predictions and labels
all_preds = torch.cat(all_preds, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Convert to numpy arrays for metric calculation
all_preds = all_preds.numpy()
all_labels = all_labels.numpy()

# Compute subset accuracy (exact match)
subset_accuracy = accuracy_score(all_labels, all_preds)

# Compute precision, recall, and F1 score (use 'samples' or 'macro' for multi-label)
precision = precision_score(all_labels, all_preds, average='samples')
recall = recall_score(all_labels, all_preds, average='samples')
f1 = f1_score(all_labels, all_preds, average='samples')

# Compute Hamming loss (lower is better)
hamming_loss_value = hamming_loss(all_labels, all_preds)

print(f'Subset Accuracy (Exact Match): {subset_accuracy:.4f}')
print(f'Precision (Samples): {precision:.4f}')
print(f'Recall (Samples): {recall:.4f}')
print(f'F1 Score (Samples): {f1:.4f}')
print(f'Hamming Loss: {hamming_loss_value:.4f}')

Subset Accuracy (Exact Match): 0.5333
Precision (Samples): 0.0000
Recall (Samples): 0.0000
F1 Score (Samples): 0.0000
Hamming Loss: 0.2167
