## Import Lib

In [1]:
import os.path

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score
import random
import numpy as np

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Device: {device}')

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

def seed_everything(seed=39):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

import warnings
from sklearn.exceptions import UndefinedMetricWarning

warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

Device: cuda
GPU: NVIDIA GeForce RTX 4070 SUPER is available.


In [2]:
KAGGLE = os.path.exists(('kaggle/input'))
ROOT_DIR = None
DATA_DIR = ROOT_DIR + 'rsna-mil-training/' if KAGGLE else '../rsna-ich-mil/'
DICOM_DIR = DATA_DIR
CSV_PATH = DICOM_DIR + 'training_1000_scan_subset.csv' if KAGGLE else './data_analyze/training_dataset_2_redundancy.csv'
# patient_scan_labels = pd.read_csv(CSV_PATH, nrows=1150)
patient_scan_labels = pd.read_csv(CSV_PATH)
dicom_dir = DICOM_DIR if KAGGLE else DATA_DIR

In [3]:
import yaml
if os.path.exists('/media02/tdhoang01/python-debugging/config.yaml'):
    path = '/media02/tdhoang01/python-debugging/config.yaml'
else:
    path = '../config.yaml'
with open(path) as file:
    config = yaml.safe_load(file)

# Accessing constants from config
HEIGHT = config['height']
WIDTH = config['width']
CHANNELS = config['channels']

TRAIN_BATCH_SIZE = config['train_batch_size']
VALID_BATCH_SIZE = config['valid_batch_size']
TEST_BATCH_SIZE = config['test_batch_size']
TEST_SIZE = config['test_size']
VALID_SIZE = config['valid_size']

TRAINING_TYPE = config['training_type']
GP_MODEL = config['gp_model']
GP_KERNEL = config['kernel_type']

MAX_SLICES = config['max_slices']
SHAPE = tuple(config['shape'])

NUM_EPOCHS = config['num_epochs']
LEARNING_RATE = config['learning_rate']
INDUCING_POINTS = config['inducing_points']
THRESHOLD = config['threshold']

NUM_CLASSES = config['num_classes']

TARGET_LABELS = config['target_labels']

MODEL_PATH = config['model_path']
DEVICE = config['device']

PROJECTION_LOCATION = config['projection_location']
PROJECTION_HIDDEN_DIM = config['projection_hidden_dim']
PROJECTION_OUTPUT_DIM = config['projection_output_dim']

ATTENTION_HIDDEN_DIM = config['attention_hidden_dim']

## Hyperparameters definition

In [4]:
# # Hyperparameters
batch_size = 64
learning_rate = 1e-4
num_epochs = 10
img_size = 224  # MNIST images are 28x28 pixels
num_classes = 6  # Digits from 0 to 9
patch_size = 14  # Size of each patch (7x7)
embedding_dim = 64  # Dimensionality of the embeddings
num_heads = 4  # Number of attention heads
num_layers = 6  # Number of transformer layers
dropout_rate = 0.5  # Dropout rate for regularization

## Prepare the Dataset

In [5]:
def split_dataset(patient_scan_labels, test_size=TEST_SIZE, val_size=VALID_SIZE, random_state=42):
    # Extract the labels from the DataFrame
    labels = patient_scan_labels['patient_label']
    if test_size > 0:
        # First, split off the test set
        train_val_labels, test_labels = train_test_split(
            patient_scan_labels,
            test_size=test_size,
            stratify=labels,
            random_state=random_state
        )

        # Calculate the validation size relative to the train_val set
        val_size_adjusted = val_size / (1 - test_size)

        # Split the train_val set into train and validation sets
        train_labels, val_labels = train_test_split(
            train_val_labels,
            test_size=val_size_adjusted,
            stratify=train_val_labels['patient_label'],
            random_state=random_state
        )
    else:
        train_labels, val_labels = train_test_split(
            patient_scan_labels,
            test_size=val_size,
            stratify=labels,
            random_state=random_state
        )
        test_labels = None

    return train_labels, val_labels, test_labels

from sklearn.model_selection import train_test_split
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

def split_dataset_for_multilabel(patient_scan_labels, test_size=0.15, val_size=0.25, random_state=42):
    # Extract the labels from the DataFrame
    labels = patient_scan_labels[['patient_any', 'patient_epidural', 'patient_intraparenchymal',
                                  'patient_intraventricular', 'patient_subarachnoid', 'patient_subdural']].values

    if test_size > 0:
        # First split: train + test
        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
        train_idx, test_idx = next(msss.split(patient_scan_labels, labels))

        train_labels = patient_scan_labels.iloc[train_idx]
        test_labels = patient_scan_labels.iloc[test_idx]

        # Second split: train + validation
        msss_val = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=random_state)
        train_idx, val_idx = next(msss_val.split(train_labels, labels[train_idx]))

        train_labels_final = train_labels.iloc[train_idx]
        val_labels = train_labels.iloc[val_idx]

    else:
        # Only split into train and validation if test_size is 0
        msss_val = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=random_state)
        train_idx, val_idx = next(msss_val.split(patient_scan_labels, labels))

        train_labels_final = patient_scan_labels.iloc[train_idx]
        val_labels = patient_scan_labels.iloc[val_idx]
        test_labels = None

    return train_labels_final, val_labels, test_labels

In [6]:
from dataset_generators.RSNA_Dataset import MedicalScanDataset
class TrainDatasetGenerator(MedicalScanDataset):
    """Dataset class for training medical scan data."""
    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        super().__init__(data_dir, patient_scan_labels, augmentor)

class TestDatasetGenerator(MedicalScanDataset):
    """Dataset class for testing medical scan data."""
    def __init__(self, data_dir, patient_scan_labels, augmentor=None):
        super().__init__(data_dir, patient_scan_labels, augmentor)

def get_train_loader(dicom_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE):
    # original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=augmentor)
    original_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)
    return DataLoader(original_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

def get_test_loader(dicom_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE):
    test_dataset = TestDatasetGenerator(dicom_dir, patient_scan_labels, augmentor=None)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

In [7]:
def calculate_metrics(predictions, labels):
    """Calculate and return performance metrics."""
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, average='macro'),
        "recall": recall_score(labels, predictions, average='macro'),
        "f1": f1_score(labels, predictions, average='macro')
    }

def print_epoch_stats(epoch, num_epochs, phase, loss, metrics):
    """Print statistics for an epoch."""
    print(f"Epoch {epoch+1}/{num_epochs} - {phase.capitalize()}:")
    print(f"Loss: {loss:.4f}, Accuracy: {metrics['accuracy']:.4f}, "
          f"Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, "
          f"F1: {metrics['f1']:.4f}")

In [8]:
weights = torchvision.models.ViT_B_16_Weights.DEFAULT
transform = weights.transforms()
# print(transform)

# train_labels, val_labels, test_labels = split_dataset(patient_scan_labels, test_size=TEST_SIZE)
train_labels, val_labels, test_labels = split_dataset_for_multilabel(patient_scan_labels, test_size=TEST_SIZE)
# test_labels = pd.read_csv('./data_analyze/testing_dataset_150.csv')
train_loader = get_train_loader(dicom_dir, train_labels, batch_size=TRAIN_BATCH_SIZE)
val_loader = get_train_loader(dicom_dir, val_labels, batch_size=VALID_BATCH_SIZE)
test_loader = get_test_loader(dicom_dir, test_labels, batch_size=TEST_BATCH_SIZE)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  patient_scan_labels['filename'] = patient_scan_labels['filename'].apply(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  patient_scan_labels['labels'] = patient_scan_labels['labels'].apply(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  patient_scan_labels[column] = patient_scan_labels[column].apply

## Patch Embedding Layer
* A module that converts input images into patch embeddings suitable for a Vision Transformer.

* This module takes an input image, divides it into patches, and then embeds each patch into a vector of 
specified dimensionality using a convolutional layer.

* Attributes:
    patch_size (int): The size of each square patch.
    embed_dim (int): The dimensionality of the output embedding vector for each patch.
    conv (nn.Conv2d): A convolutional layer that extracts patches from the input image.

* Input Shape:
    - Input tensor `x`: Shape (N, C, H, W)
        - N: Batch size
        - C: Number of channels (1 for grayscale images like MNIST)
        - H: Height of the input image
        - W: Width of the input image

* Output Shape:
    - Output tensor: Shape (N, num_patches, embed_dim)
        - N: Batch size
        - num_patches: The number of patches extracted from the image, calculated as:
          $$ \text{num\_patches} = \left(\frac{H}{\text{patch\_size}}\right) \times \left(\frac{W}{\text{patch\_size}}\right) $$
        - embed_dim: The dimensionality of the embedding for each patch.

* Methods:
    forward(x): Forward pass to compute the patch embeddings from input images.


In [9]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size, in_channels, patch_size, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                              kernel_size=patch_size, stride=patch_size, padding=0)

    def forward(self, x):
        """
        Forward pass to create patch embeddings.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: Output tensor of shape (N, num_patches, embed_dim).
        """
        x = self.conv(x)  # Apply convolution to create patches 
        # After convolution: shape (N, embed_dim, H/patch_size, W/patch_size)
        
        x = x.flatten(2)  # Flatten patches into a sequence 
        # After flattening: shape (N, embed_dim, num_patches)

        return x.transpose(1, 2)  # Rearrange dimensions for transformer input 
        # Final output shape: (N, num_patches, embed_dim)

## Multihead Self-Attention Layer

In [10]:
# 1. Create a class that inherits from nn.Module
class MultiheadSelfAttentionBlock(nn.Module):
    """Creates a multi-head self-attention block ("MSA block" for short).
    """
    # 2. Initialize the class with hyperparameters from Table 1
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0): # doesn't look like the paper uses any dropout in MSABlocks
        super().__init__()

        # 3. Create the Norm layer (LN)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # 4. Create the Multi-Head Attention (MSA) layer
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                    num_heads=num_heads,
                                                    dropout=attn_dropout,
                                                    batch_first=True) # does our batch dimension come first?

    # 5. Create a forward() method to pass the data through the layers
    def forward(self, x):
        x = self.layer_norm(x)
        attn_output, attn_weights = self.multihead_attn(query=x, # query embeddings
                                             key=x, # key embeddings
                                             value=x, # value embeddings
                                             need_weights=True) # do we need the weights or just the layer outputs?
        return attn_output

## MLP Block Layer

In [11]:
# 1. Create a class that inherits from nn.Module
class MLPBlock(nn.Module):
    """Creates a layer normalized multilayer perceptron block ("MLP block" for short)."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden Size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 dropout:float=0.1): # Dropout from Table 3 for ViT-Base
        super().__init__()

        # 3. Create the Norm layer (LN)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

        # 4. Create the Multilayer perceptron (MLP) layer(s)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim,
                      out_features=mlp_size),
            nn.GELU(), # "The MLP contains two layers with a GELU non-linearity (section 3.1)."
            nn.Dropout(p=dropout),
            nn.Linear(in_features=mlp_size, # needs to take same in_features as out_features of layer above
                      out_features=embedding_dim), # take back to embedding_dim
            nn.Dropout(p=dropout) # "Dropout, when used, is applied after every dense layer.."
        )

    # 5. Create a forward() method to pass the data through the layers
    def forward(self, x):
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x

## Transformer Encoder Block

In [12]:
# 1. Create a class that inherits from nn.Module
class TransformerEncoderBlock(nn.Module):
    """Creates a Transformer Encoder block."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 mlp_dropout:float=0.1, # Amount of dropout for dense layers from Table 3 for ViT-Base
                 attn_dropout:float=0): # Amount of dropout for attention layers
        super().__init__()

        # 3. Create MSA block (equation 2)
        self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     attn_dropout=attn_dropout)

        # 4. Create MLP block (equation 3)
        self.mlp_block =  MLPBlock(embedding_dim=embedding_dim,
                                   mlp_size=mlp_size,
                                   dropout=mlp_dropout)

    # 5. Create a forward() method
    def forward(self, x):

        # 6. Create residual connection for MSA block (add the input to the output)
        # x =  self.msa_block(x) + x
        attn_output = self.msa_block(x)
        
        assert type(attn_output) == torch.Tensor, "The MSA block output should be a PyTorch tensor."
        x = attn_output + x

        # 7. Create residual connection for MLP block (add the input to the output)
        x = self.mlp_block(x) + x

        return x

## Vision Transformer Model

In [13]:
from torch.nn import functional as F

class AttentionLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x shape: (batch_size, num_instances, feature_dim)
        attention_weights = self.attention(x)
        weights = F.softmax(attention_weights, dim=1)

        return (x * weights).sum(dim=1), weights.squeeze(-1)

# 1. Create a ViT class that inherits from nn.Module
class ViT(nn.Module):
    """Creates a Vision Transformer architecture with ViT-Base hyperparameters by default."""
    # 2. Initialize the class with hyperparameters from Table 1 and Table 3
    def __init__(self,
                 img_size:int=224, # Training resolution from Table 3 in ViT paper
                 in_channels:int=3, # Number of channels in input image
                 patch_size:int=16, # Patch size
                 num_transformer_layers:int=12, # Layers from Table 1 for ViT-Base
                 embedding_dim:int=768, # Hidden size D from Table 1 for ViT-Base
                 mlp_size:int=3072, # MLP size from Table 1 for ViT-Base
                 num_heads:int=12, # Heads from Table 1 for ViT-Base
                 attn_dropout:float=0, # Dropout for attention projection
                 mlp_dropout:float=0.1, # Dropout for dense/MLP layers
                 embedding_dropout:float=0.1, # Dropout for patch and position embeddings
                 num_classes:int=1000): # Default for ImageNet but can customize this
        super().__init__() # don't forget the super().__init__()!

        # 3. Make the image size is divisible by the patch size
        assert img_size % patch_size == 0, f"Image size must be divisible by patch size, image size: {img_size}, patch size: {patch_size}."

        # 4. Calculate number of patches (height * width/patch^2)
        self.num_patches = (img_size * img_size) // patch_size**2

        # 5. Create learnable class embedding (needs to go at front of sequence of patch embeddings)
        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
                                            requires_grad=True)

        # 6. Create learnable position embedding
        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
                                               requires_grad=True)

        # 7. Create embedding dropout value
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        # 8. Create patch embedding layer
        self.patch_embedding = PatchEmbedding(img_size=img_size,
                                              in_channels=in_channels,
                                              patch_size=patch_size,
                                              embed_dim=embedding_dim)

        # 9. Create Transformer Encoder blocks (we can stack Transformer Encoder blocks using nn.Sequential())
        # Note: The "*" means "all"
        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
                                                                            num_heads=num_heads,
                                                                            mlp_size=mlp_size,
                                                                            mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])

        # 10. Create classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim,
                      out_features=num_classes)
        )

        # Attention layer
        self.attention_layer = AttentionLayer(embedding_dim * (self.num_patches + 1),
                                                hidden_dim=128)
        self.cls_attention = nn.Linear(embedding_dim * (self.num_patches + 1), num_classes)

    # 11. Create a forward() method
    def forward(self, x):
        # 12. Get batch size
        batch_size, num_instances, channels, height, width = x.size()
        # 13. Create class token embedding and expand it to match the batch size (equation 1)
        class_token = self.class_embedding.expand(batch_size * num_instances, -1, -1) # "-1" means to infer the dimension (try this line on its own)
        # 14. Create patch embedding (equation 1)
        x = x.view(batch_size * num_instances, channels, height, width)
        x = self.patch_embedding(x)
        # 15. Concat class embedding and patch embedding (equation 1)
        x = torch.cat((class_token, x), dim=1)
        # 16. Add position embedding to patch embedding (equation 1)
        x = self.position_embedding + x
        # 17. Run embedding dropout (Appendix B.1)
        x = self.embedding_dropout(x)
        
        # attn_weights_list = [] # List to store the attention weights from each layer
        # for layer in self.transformer_encoder:
        #     x, attn_weights = layer(x)
        #     attn_weights_list.append(attn_weights)
        
        # x = self.classifier(x[:, 0]) # run on each sample in a batch at 0 index
        # return x, attn_weights_list
        
        # 18. Pass patch, position and class embedding through transformer encoder layers (equations 2 & 3)
        x = self.transformer_encoder(x) # Shape: (batch_size * num_instances, num_patches + 1, hidden_size)
        # 19. Put 0 index logit through classifier (equation 4)
        # x = self.classifier(x[:, 0]) # run on each sample in a batch at 0 index
        # 19.1
        # x = x[:, 0] # Get the class token

        # Flatten the sequence of patches
        x = x.view(batch_size, num_instances, -1)
        # x, _ = self.attention_layer(x)
        # Apply max pooling
        x = x.max(dim=1).values
        x = self.cls_attention(x)
        return x.squeeze(-1)

In [14]:
import layers.gaussian_process as GPModel

In [15]:
class BiGRU(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidirectional=True, dropout=0.6):
        super(BiGRU, self).__init__()
        self.bigru = nn.GRU(input_size, hidden_size, num_layers,
                             batch_first=True, bidirectional=bidirectional, dropout=dropout)
        self.num_layers = num_layers

    def forward(self, x, hidden_states=None):
        if hidden_states is None:
            num_directions = 2 if self.bigru.bidirectional else 1
            hidden_states = torch.zeros(num_directions * self.bigru.num_layers,
                                        x.size(0), self.bigru.hidden_size,
                                        device=x.device)
        gru_out, _ = self.bigru(x, hidden_states)
        return gru_out

class InstanceAttention(nn.Module):
    def __init__(self, input_dim, num_classes=1):
        super(InstanceAttention, self).__init__()
        # Feature-level attention
        self.feature_attention = nn.Linear(input_dim, input_dim)
        # Slice-level attention
        self.slice_attention = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # x shape: (batch_size, num_instances, input_dim)

        # Feature-level attention
        feature_scores = self.feature_attention(x)  # (batch_size, num_instances, input_dim)
        feature_attention_weights = torch.nn.functional.softmax(feature_scores, dim=2)
        # Apply feature-level attention
        x_weighted_features = x * feature_attention_weights  # (batch_size, num_instances, input_dim)

        # Slice-level attention
        slice_scores = self.slice_attention(x_weighted_features).squeeze(-1)  # (batch_size, num_instances)
        slice_attention_weights = torch.nn.functional.softmax(slice_scores, dim=1)
        slice_attention_weights = slice_attention_weights.unsqueeze(-1)  # (batch_size, num_instances, 1)

        return x_weighted_features, slice_attention_weights

class CNN_ATT_GRU(nn.Module):
    def __init__(self, num_layers=2, input_channels=1,
                 cnn_feature_size=512, gru_hidden_size=256, num_classes=1, dropout_gru=0.3, dropout_fc=0.5, cnn='resnet'):
        super(CNN_ATT_GRU, self).__init__()
        self.cnn = torchvision.models.resnet18(torchvision.models.ResNet18_Weights.DEFAULT)
        self.cnn.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.cnn.fc = nn.Identity()

        self.attention = AttentionLayer(input_dim=cnn_feature_size, hidden_dim=cnn_feature_size)
        self.bigru = BiGRU(input_size=cnn_feature_size,
                           hidden_size=gru_hidden_size,
                           num_layers=num_layers,
                           dropout=dropout_gru)
        self.fc = nn.Sequential(
            nn.Linear(2 * gru_hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout_fc),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(dropout_fc),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_fc),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        batch_size, num_instances, channels, height, width = x.size()
        x = x.view(batch_size * num_instances, channels, height, width)
        x = self.cnn(x)
        features = x.view(batch_size, num_instances, -1)

        gru_features = self.bigru(features)
        attention_features, _ = self.attention(gru_features)

        # Pass aggregated features through FC layers
        output = self.fc(attention_features)

        return output

## Model & Optimizer Definition

In [16]:
# model = ViT(
#     img_size=img_size,
#     in_channels=1,
#     patch_size=patch_size,
#     embedding_dim=embedding_dim,
#     num_heads=num_heads,
#     num_transformer_layers=num_layers,
#     num_classes=num_classes
# )
import gpytorch

model = CNN_ATT_GRU(num_layers=2, input_channels=1,
                    cnn_feature_size=512, gru_hidden_size=256, num_classes=num_classes, dropout_gru=0.3, dropout_fc=0.25, cnn='resnet')

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()
likelihood = GPModel.PGLikelihood().to(device)
# mll = gpytorch.mlls.VariationalELBO(likelihood, model.gp_layer, num_data=len(train_loader.dataset))
# variational_ngd_optimizer = gpytorch.optim.NGD(model.gp_layer.variational_parameters(), num_data=len(train_loader.dataset), lr=0.01)

# Print model summary
from torchsummary import summary

summary(model, input_size=(4 * 28, 1, 224, 224))

Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Conv2d: 2-1                       3,136
|    └─BatchNorm2d: 2-2                  128
|    └─ReLU: 2-3                         --
|    └─MaxPool2d: 2-4                    --
|    └─Sequential: 2-5                   --
|    |    └─BasicBlock: 3-1              73,984
|    |    └─BasicBlock: 3-2              73,984
|    └─Sequential: 2-6                   --
|    |    └─BasicBlock: 3-3              230,144
|    |    └─BasicBlock: 3-4              295,424
|    └─Sequential: 2-7                   --
|    |    └─BasicBlock: 3-5              919,040
|    |    └─BasicBlock: 3-6              1,180,672
|    └─Sequential: 2-8                   --
|    |    └─BasicBlock: 3-7              3,673,088
|    |    └─BasicBlock: 3-8              4,720,640
|    └─AdaptiveAvgPool2d: 2-9            --
|    └─Identity: 2-10                    --
├─AttentionLayer: 1-2                    --
|    └─Sequential: 2-11



Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Conv2d: 2-1                       3,136
|    └─BatchNorm2d: 2-2                  128
|    └─ReLU: 2-3                         --
|    └─MaxPool2d: 2-4                    --
|    └─Sequential: 2-5                   --
|    |    └─BasicBlock: 3-1              73,984
|    |    └─BasicBlock: 3-2              73,984
|    └─Sequential: 2-6                   --
|    |    └─BasicBlock: 3-3              230,144
|    |    └─BasicBlock: 3-4              295,424
|    └─Sequential: 2-7                   --
|    |    └─BasicBlock: 3-5              919,040
|    |    └─BasicBlock: 3-6              1,180,672
|    └─Sequential: 2-8                   --
|    |    └─BasicBlock: 3-7              3,673,088
|    |    └─BasicBlock: 3-8              4,720,640
|    └─AdaptiveAvgPool2d: 2-9            --
|    └─Identity: 2-10                    --
├─AttentionLayer: 1-2                    --
|    └─Sequential: 2-11

## Training Step

In [17]:
# Function to calculate metrics for multi-label classification
def calculate_metrics(preds, labels):
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average='samples')  # Use 'samples' for multi-label
    recall = recall_score(labels, preds, average='samples')  # Use 'samples' for multi-label
    f1 = f1_score(labels, preds, average='samples')  # Use 'samples' for multi-label
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

# Function to print epoch statistics
def print_epoch_stats(epoch, num_epochs, phase, loss, metrics):
    print(f'Epoch [{epoch+1}/{num_epochs}], {phase.capitalize()}')
    print(f'Loss: {loss:.4f}, '
          f'Accuracy: {metrics["accuracy"]:.4f}, Precision: {metrics["precision"]:.4f}, '
          f'Recall: {metrics["recall"]:.4f}, F1: {metrics["f1"]:.4f}')

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch
import torch.nn as nn

# Ensure the device is set to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # Move the model to the appropriate device

# Use BCEWithLogitsLoss for multi-label classification
criterion = nn.BCEWithLogitsLoss()

# Initialize variables to track the best validation performance
best_val_acc = 0  # Use infinity as the initial best loss
best_model_state_dict = None  # To store the best model's state dict

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    # Training phase
    for imgs, _, labels, multi_labels in train_loader:
        optimizer.zero_grad()

        imgs, multi_labels = imgs.to(device), multi_labels.to(device)  # Move data to GPU

        outputs = model(imgs)
        loss = criterion(outputs, multi_labels.float())  # Use multi_labels for loss calculation
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights

        total_loss += loss.item()
        preds = outputs.ge(0.5).float()  # Apply sigmoid and threshold
        all_preds.append(preds.cpu())
        all_labels.append(multi_labels.cpu())

    avg_loss = total_loss / len(train_loader)

    # Concatenate all predictions and labels for metric calculation
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    train_metrics = calculate_metrics(all_preds.numpy(), all_labels.numpy())

    print_epoch_stats(epoch, num_epochs, "train", avg_loss, train_metrics)

    # Validation phase
    model.eval()
    val_total_loss = 0
    val_all_preds = []
    val_all_labels = []

    with torch.no_grad():  # Disable gradient calculation for validation
        for imgs, _, labels, multi_labels in val_loader:
            imgs, multi_labels = imgs.to(device), multi_labels.to(device)  # Move data to GPU
            outputs = model(imgs)  # Forward pass
            loss = criterion(outputs, multi_labels.float())  # Use multi_labels for loss calculation
            val_total_loss += loss.item()
            preds = outputs.ge(0.5).float()  # Apply sigmoid and threshold
            val_all_preds.append(preds.cpu())
            val_all_labels.append(multi_labels.cpu())

    val_avg_loss = val_total_loss / len(val_loader)

    # Concatenate validation predictions and labels for metric calculation
    val_all_preds = torch.cat(val_all_preds)
    val_all_labels = torch.cat(val_all_labels)

    val_metrics = calculate_metrics(val_all_preds.numpy(), val_all_labels.numpy())

    print_epoch_stats(epoch, num_epochs, "validation", val_avg_loss, val_metrics)

    # Check if the current validation accuracy is the best so far
    if val_metrics['accuracy'] > best_val_acc:
        best_val_acc = val_metrics['accuracy']
        best_model_state_dict = model.state_dict()

# After training, load the best model for testing
print(f'Best validation accuracy: {best_val_acc:.4f}')
model.load_state_dict(best_model_state_dict)
model.to(device)

## Attention Map Visualization

In [19]:
import matplotlib.pyplot as plt

def plot_attention_map(attn_weights, img_shape):
    # Assuming attn_weights is of shape (num_heads, seq_length, seq_length)
    attn_map = attn_weights.mean(dim=0)  # Average over heads

    plt.figure(figsize=(8, 8))
    plt.imshow(attn_map.detach().cpu().numpy(), cmap='viridis')
    plt.colorbar()
    plt.title('Attention Map')
    plt.show()

## Evaluation Step

In [21]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss

model.eval()
all_preds = []
all_labels = []

with torch.inference_mode():
    for imgs, _, _, multi_labels in test_loader:
        imgs, multi_labels = imgs.to(device), multi_labels.to(device)  # Move data to GPU
        outputs = model(imgs)  # Forward pass

        # Apply sigmoid and threshold to get binary predictions
        predicted = outputs.ge(0.5).float()

        # Collect all predictions and labels for computing metrics
        all_preds.append(predicted.cpu())
        all_labels.append(multi_labels.cpu())

# Concatenate all predictions and labels
all_preds = torch.cat(all_preds, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# Convert to numpy arrays for metric calculation
all_preds = all_preds.numpy()
all_labels = all_labels.numpy()

# Compute subset accuracy (exact match)
subset_accuracy = accuracy_score(all_labels, all_preds)

# Compute precision, recall, and F1 score (use 'samples' or 'macro' for multi-label)
precision = precision_score(all_labels, all_preds, average='samples')
recall = recall_score(all_labels, all_preds, average='samples')
f1 = f1_score(all_labels, all_preds, average='samples')

# Compute Hamming loss (lower is better)
hamming_loss_value = hamming_loss(all_labels, all_preds)

print(f'Subset Accuracy (Exact Match): {subset_accuracy:.4f}')
print(f'Precision (Samples): {precision:.4f}')
print(f'Recall (Samples): {recall:.4f}')
print(f'F1 Score (Samples): {f1:.4f}')
print(f'Hamming Loss: {hamming_loss_value:.4f}')

Accuracy: 0.7500
Precision: 0.7488
Recall: 0.7242
F1 Score: 0.7298
