## Library

In [None]:
import torch
import torch.optim as optim
from torch.optim import Optimizer
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pydicom
import cv2
import os
import pandas as pd
from skimage.transform import resize

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

## Init GPU

In [None]:
# Select the compute device: the default CUDA device when one is available,
# otherwise the CPU. Report which one was chosen.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    device = torch.device('cpu')
    print("No GPU available. Training will run on CPU.")

print(device)

In [None]:
# IPython magics: load the autoreload extension and set mode 2, which reloads
# imported modules before executing each cell (handy while editing local code).
%load_ext autoreload
%autoreload 2

## Config Info

In [None]:
# Reproducibility: one seed shared by NumPy's global RNG and PyTorch's RNGs.
SEED = 202408
np.random.seed(SEED)
torch.manual_seed(SEED)

# Image geometry. Slices are resized to HEIGHT x WIDTH by preprocess_slice.
HEIGHT, WIDTH, CHANNELS = 224, 224, 3
SHAPE = (HEIGHT, WIDTH, CHANNELS)

# Presumably the fraction of data to hold out for testing.
TEST_SIZE = 0.02

# Per-loader batch sizes (one study per batch everywhere).
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = TEST_BATCH_SIZE = 1

# Folders
DATA_DIR = './rsna-mil-training'

In [None]:
def preprocess_slice(slice, target_size=(HEIGHT, WIDTH)):
    """Resize one 2-D slice and map its intensities to [0, 1] via apply_windowing.

    Args:
        slice (np.ndarray): 2-D pixel array. apply_windowing clips to fixed
            bounds that look like Hounsfield units, so the input is presumably
            expected in HU -- confirm against the reader.
        target_size (tuple[int, int]): Output shape as (rows, cols), i.e.
            (height, width). This is the order skimage.transform.resize expects.

    Returns:
        np.ndarray: Float array of shape ``target_size`` with values in [0, 1].
    """
    # preserve_range=True keeps the original intensity scale. Without it,
    # skimage rescales integer inputs to roughly [0, 1] / [-1, 1], and the
    # fixed clip bounds in apply_windowing would then squash every pixel to
    # nearly the same value.
    resized = resize(slice, target_size, anti_aliasing=True, preserve_range=True)

    return apply_windowing(resized)

def apply_windowing(slice, window_min=-100, window_max=200):
    """Clip intensities to [window_min, window_max] and rescale linearly to [0, 1].

    Args:
        slice (np.ndarray): Pixel array. The default bounds look like
            Hounsfield units, so HU input is presumably expected -- confirm.
        window_min: Lower bound; values at or below it map to 0.
        window_max: Upper bound; values at or above it map to 1.

    Returns:
        np.ndarray: Float array with the same shape as ``slice``.

    Raises:
        ValueError: If ``window_max`` is not greater than ``window_min``
            (the rescale would divide by zero or invert the image).
    """
    # NOTE(review): a conventional brain window of level 40 / width 80 spans
    # [0, 80]; the bounds used here (-100, 200) are much wider. Confirm which
    # window is intended.
    if window_max <= window_min:
        raise ValueError(
            f"window_max ({window_max}) must be greater than window_min ({window_min})"
        )
    clipped = np.clip(slice, window_min, window_max)
    return (clipped - window_min) / (window_max - window_min)

In [None]:
def read_dicom_folder(folder_path, num_slices=60, rescale=True):
    """
    Read every ``.dcm`` file in a folder as a list of 2-D pixel arrays.

    Files are read in sorted filename order. They are then re-ordered by the
    z coordinate of ``ImagePositionPatient`` when every file carries that tag,
    so the slice order never depends on ``os.listdir`` (which is arbitrary).

    When ``rescale`` is true, each slice is converted with
    ``pixel_array * RescaleSlope + RescaleIntercept``; missing tags default to
    1 and 0. For CT this yields Hounsfield units, the scale that the fixed
    clip bounds in ``apply_windowing`` appear to assume.

    Series with fewer than ``num_slices`` slices are padded at the end with
    blank slices: raw value 0, passed through the first slice's rescale.
    Longer series are returned untruncated.

    Args:
        folder_path (str): Folder containing the DICOM files of one study.
        num_slices (int): Minimum number of slices to return.
        rescale (bool): Apply RescaleSlope/RescaleIntercept. Pass False to get
            the raw stored pixel values instead.

    Returns:
        list[np.ndarray]: At least ``num_slices`` 2-D arrays.

    Raises:
        ValueError: If the folder contains no ``.dcm`` files.
    """
    filenames = sorted(name for name in os.listdir(folder_path) if name.endswith(".dcm"))
    if not filenames:
        raise ValueError(f"No .dcm files found in {folder_path}")

    datasets = [pydicom.dcmread(os.path.join(folder_path, name)) for name in filenames]

    # Order along the scan axis when positions are known. The sort is stable,
    # so slices with equal positions keep filename order.
    if all(hasattr(ds, "ImagePositionPatient") for ds in datasets):
        datasets.sort(key=lambda ds: float(ds.ImagePositionPatient[2]))

    def _rescale_params(ds):
        # (slope, intercept) of the modality rescale, defaulting to identity.
        return float(getattr(ds, "RescaleSlope", 1)), float(getattr(ds, "RescaleIntercept", 0))

    slices = []
    for ds in datasets:
        pixels = ds.pixel_array
        if rescale:
            slope, intercept = _rescale_params(ds)
            pixels = pixels.astype(np.float32) * slope + intercept
        slices.append(pixels)

    if len(slices) < num_slices:
        if rescale:
            # Raw 0 maps to the intercept, matching what unrescaled zero
            # padding would become after conversion.
            _, intercept = _rescale_params(datasets[0])
            blank = np.full_like(slices[0], intercept)
        else:
            blank = np.zeros_like(slices[0])
        slices.extend([blank] * (num_slices - len(slices)))

    return slices

In [None]:
def process_patient_data(data_dir, row):
    """
    Load and preprocess the DICOM series referenced by one label row.

    The series is looked up in ``<data_dir>/<patient_id>_<study_instance_uid>``,
    with every ``'ID_'`` substring removed from both ids.

    Args:
        data_dir (str): Directory holding one sub-folder of DICOM files per study.
        row (pd.Series): One row of the patient_scan_labels DataFrame. It must
            provide 'patient_id', 'study_instance_uid' and the six hemorrhage
            columns.

    Returns:
        Tuple: ``(preprocessed_slices, label)``. ``label`` is 1 when any
        hemorrhage column is truthy, otherwise 0. Returns ``(None, None)``
        (after printing a message) when the study folder does not exist.
    """
    hemorrhage_columns = ['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']

    pid = row['patient_id'].replace('ID_', '')
    study_uid = row['study_instance_uid'].replace('ID_', '')
    folder_path = os.path.join(data_dir, f"{pid}_{study_uid}")

    # Guard clause: a missing study folder is reported, not raised.
    if not os.path.exists(folder_path):
        print(f"Folder not found: {folder_path}")
        return None, None

    preprocessed_slices = list(map(preprocess_slice, read_dicom_folder(folder_path)))
    label = int(bool(row[hemorrhage_columns].any()))
    return preprocessed_slices, label

In [None]:
class TrainDatasetGenerator(Dataset):
    """
    Map-style dataset yielding one study (a bag of slices) per label row.

    Each item is ``(slices, label)``. ``slices`` is a float32 tensor built by
    stacking the preprocessed slices of the study, and ``label`` is an int64
    scalar tensor. When the study folder is missing the item is
    ``(None, None)``.
    NOTE(review): the default DataLoader collate presumably cannot batch None.
    """
    def __init__(self, data_dir, patient_scan_labels):
        self.data_dir = data_dir
        self.patient_scan_labels = patient_scan_labels

    def __len__(self):
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        slices, label = process_patient_data(self.data_dir, self.patient_scan_labels.iloc[idx])
        if slices is None:
            return None, None

        # Stack the per-slice 2-D arrays into one (num_slices, H, W) array.
        stacked = np.array(slices)
        return (
            torch.tensor(stacked, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )

class TestDatasetGenerator(TrainDatasetGenerator):
    """
    Dataset for evaluation data.

    Loading is identical to TrainDatasetGenerator, so this class inherits
    ``__init__``, ``__len__`` and ``__getitem__`` from it instead of keeping a
    verbatim copy. Each item is ``(float32 slices tensor, int64 label tensor)``,
    or ``(None, None)`` when the study folder is missing.
    """
        

# Build the DataLoader used for training.
def get_train_loader(data_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE, shuffle=True):
    """Wrap the label table in a TrainDatasetGenerator and return a DataLoader over it.

    Args:
        data_dir (str): Directory with one DICOM sub-folder per study.
        patient_scan_labels (pd.DataFrame): One row per study to load.
        batch_size (int): Studies per batch.
        shuffle (bool): Reshuffle the studies every epoch.
    """
    return DataLoader(
        TrainDatasetGenerator(data_dir, patient_scan_labels),
        batch_size=batch_size,
        shuffle=shuffle,
    )

# Build the DataLoader used for evaluation (fixed order, no shuffling).
def get_test_loader(data_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE):
    """Wrap the label table in a TestDatasetGenerator and return an unshuffled DataLoader.

    Args:
        data_dir (str): Directory with one DICOM sub-folder per study.
        patient_scan_labels (pd.DataFrame): One row per study to load.
        batch_size (int): Studies per batch.
    """
    dataset = TestDatasetGenerator(data_dir, patient_scan_labels)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return loader

In [None]:
class CNNFeatureExtractor(nn.Module):
    """
    Small 2-D CNN that embeds every slice of a scan and averages the slice
    embeddings into one 64-d feature vector per scan.

    Args:
        slice_size (tuple[int, int]): (height, width) of each slice. This fixes
            the input size of the first Linear layer.
        num_slices (int): Inputs are zero-padded or truncated along dim 1 to
            this many slices.
        batch_size (int): Number of slices pushed through ``self.cnn`` per
            chunk. Chunking bounds the size of each individual conv call; with
            gradients enabled, the activations of all chunks are still kept for
            the backward pass.
    """
    def __init__(self, slice_size=(224, 224), num_slices=60, batch_size=10):
        super().__init__()
        self.slice_size = slice_size
        self.num_slices = num_slices
        self.batch_size = batch_size

        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            # Two 2x2 max-pools shrink each spatial dim by 4.
            nn.Linear(64 * (slice_size[0] // 4) * (slice_size[1] // 4), 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (batch, slices, height, width) with
                (height, width) == slice_size.

        Returns:
            torch.Tensor: (batch, 64) mean slice embedding per scan. The
            batch dimension is kept even when batch == 1.
        """
        batch, n = x.size(0), x.size(1)
        if n < self.num_slices:
            # new_zeros matches x's dtype and device.
            padding = x.new_zeros(batch, self.num_slices - n, *self.slice_size)
            x = torch.cat([x, padding], dim=1)
        elif n > self.num_slices:
            x = x[:, :self.num_slices]

        # Flatten (batch, slices) into one batch-major axis first. Chunks then
        # concatenate back in the same order, so the final view restores the
        # correct (scan, slice) pairing. reshape copies if x is non-contiguous.
        flat = x.reshape(batch * self.num_slices, 1, *self.slice_size)

        # No detach(): gradients must flow into self.cnn or it never trains.
        features = torch.cat(
            [self.cnn(chunk) for chunk in flat.split(self.batch_size, dim=0)],
            dim=0,
        )
        features = features.view(batch, self.num_slices, -1)  # (batch, num_slices, 64)

        # Average across slices.
        return features.mean(dim=1)  # (batch, 64)


In [None]:
class AttentionLayer(nn.Module):
    """
    Additive attention scorer: weight_i = softmax_i(v . tanh(W h_i)).

    Produces one weight per slice; the caller uses them to pool slice features
    into a bag feature.

    Args:
        input_dim (int): Size of each slice feature h_i.
        output_dim (int): Hidden size of the attention projection W.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, output_dim)
        # Scoring vector, initialised uniformly in [0, 1).
        self.v = nn.Parameter(torch.rand(output_dim))

    def forward(self, features):
        """
        Args:
            features (torch.Tensor): (..., slices, input_dim), e.g.
                (batch, slices, input_dim) or an unbatched (slices, input_dim).

        Returns:
            torch.Tensor: (..., slices) weights that sum to 1 over the slice axis.
        """
        scores = torch.tanh(self.W(features)) @ self.v  # (..., slices)
        # Normalise over the last axis (slices). Same as dim=1 for batched
        # 3-D input, but also valid for a single unbatched bag.
        return torch.softmax(scores, dim=-1)

In [None]:
class E2EAttGP(nn.Module):
    """
    End-to-end multiple-instance classifier that treats a scan as a bag of slices.

    Each slice is embedded by the convolutional stack of CNNFeatureExtractor.
    Attention weights pool the slice embeddings into one bag embedding, and a
    linear head followed by a sigmoid gives the scan-level probability.
    """
    def __init__(self):
        super().__init__()
        self.cnn = CNNFeatureExtractor()
        self.attention = AttentionLayer(input_dim=64, output_dim=32)  # input_dim = CNN embedding size
        self.classifier = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (batch, num_slices, height, width). Height and
                width must match ``self.cnn.slice_size``.

        Returns:
            torch.Tensor: (batch, 1) probabilities.
        """
        batch, num_slices = x.shape[:2]

        # Embed every slice independently with the extractor's inner conv
        # stack, all slices in one batched pass (chunked like the extractor).
        # CNNFeatureExtractor.forward is not used per slice because it
        # zero-pads its input to num_slices slices and averages in the
        # padding's features.
        flat = x.reshape(batch * num_slices, 1, *x.shape[2:])
        slice_features = torch.cat(
            [self.cnn.cnn(chunk) for chunk in flat.split(self.cnn.batch_size, dim=0)],
            dim=0,
        )
        slice_features = slice_features.view(batch, num_slices, -1)  # (batch, num_slices, 64)

        # One attention weight per slice.
        att_weights = self.attention(slice_features)  # (batch, num_slices)

        # Attention-weighted sum of slice features -> bag feature.
        bag_feature = (slice_features * att_weights.unsqueeze(-1)).sum(dim=1)  # (batch, 64)

        logits = self.classifier(bag_feature)  # (batch, 1)
        return self.sigmoid(logits)

In [None]:
# Reuse the configured data folder instead of repeating the path literal.
data_dir = DATA_DIR
patient_scan_labels = pd.read_csv('training_1000_scan_subset.csv')

In [None]:
# Hold out TEST_SIZE of the patients for evaluation. The original test loader
# read exactly the training rows, so "test" accuracy was training accuracy.
# Splitting by patient (not by scan) keeps all studies of a patient on the
# same side of the split.
test_patients = (
    patient_scan_labels['patient_id']
    .drop_duplicates()
    .sample(frac=TEST_SIZE, random_state=SEED)
)
is_test = patient_scan_labels['patient_id'].isin(test_patients)
train_scan_labels = patient_scan_labels[~is_test].reset_index(drop=True)
test_scan_labels = patient_scan_labels[is_test].reset_index(drop=True)

train_loader = get_train_loader(data_dir, train_scan_labels, batch_size=TRAIN_BATCH_SIZE)
test_loader = get_test_loader(data_dir, test_scan_labels, batch_size=TEST_BATCH_SIZE)

In [None]:
num_epochs = 1
model = E2EAttGP().to(device)
# BCELoss expects probabilities, which matches the model's sigmoid output.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


# Training loop
for epoch in range(num_epochs):
    print('############################################')
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    model.train()
    total_loss = 0.0  # Sum of per-batch losses for the epoch
    correct = 0
    total = 0

    train_loader_tqdm = tqdm(train_loader, leave=False)
    train_loader_tqdm.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
    for batch_slices, batch_labels in train_loader_tqdm:
        batch_slices = batch_slices.to(device)
        targets = batch_labels.to(device).float()

        optimizer.zero_grad()
        # Flatten to (batch,) so the prediction shape matches the targets.
        # BCELoss raises on mismatched shapes, and the model returns (batch, 1).
        outputs = model(batch_slices).view(-1)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predicted = (outputs >= 0.5).float()  # Threshold probabilities at 0.5
        batch_correct = (predicted == targets).sum().item()
        correct += batch_correct
        total += targets.size(0)

        train_loader_tqdm.set_postfix(loss=loss.item(), accuracy=batch_correct / targets.size(0))

    # Epoch averages; guard against an empty loader.
    avg_loss = total_loss / max(len(train_loader), 1)
    accuracy = correct / max(total, 1)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')

In [None]:
# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch_slices, batch_labels in test_loader:
        # Inputs must live on the same device as the model (moved to `device` above).
        outputs = model(batch_slices.to(device)).view(-1)  # (batch,) probabilities
        targets = batch_labels.to(device).float()

        # Convert probabilities to binary predictions
        predicted = (outputs >= 0.5).float()

        # Update total and correct counts
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# Calculate accuracy (NaN if the test loader yielded nothing, instead of ZeroDivisionError)
accuracy = correct / total if total else float('nan')
print(f'Accuracy: {accuracy:.4f}')