## Library

In [None]:
import torch
import torch.optim as optim
from torch.optim import Optimizer
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import numpy as np
import pydicom
import cv2
import os
import pandas as pd
from skimage.transform import resize

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

## Init GPU

In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Report which device was selected so the notebook output records it.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

print(device)

In [None]:
# IPython magics: load the autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2

## Config Info

In [None]:
# Constants
TEST_SIZE = 0.02

# Spatial size and channel count of one preprocessed slice.
HEIGHT, WIDTH, CHANNELS = 224, 224, 3
SHAPE = (HEIGHT, WIDTH, CHANNELS)

# Batch sizes per data split.
TRAIN_BATCH_SIZE = VALID_BATCH_SIZE = 8
TEST_BATCH_SIZE = 4

# Folders
DATA_DIR = '/kaggle/input/rsna-mil-training/'

In [None]:
# Locations inside DATA_DIR: the folder holding the per-scan DICOM
# directories and the CSV with one label row per scan.
DICOM_DIR = f"{DATA_DIR}rsna-mil-training"
CSV_PATH = f"{DATA_DIR}training_1000_scan_subset.csv"

# Load the scan-level label table used by the datasets below.
patient_scan_labels = pd.read_csv(CSV_PATH)

In [None]:
def preprocess_slice(slice, target_size=(HEIGHT, WIDTH)):
    """
    Resize one 2-D slice and turn it into a three-channel windowed image.

    Args:
        slice (np.ndarray): 2-D pixel array of a single slice.
        target_size (tuple): Output (height, width) passed to ``resize``.

    Returns:
        np.ndarray: Array of shape ``target_size + (3,)`` whose channels are
        the brain, subdural and bone windows produced by ``apply_windowing``.
    """
    # preserve_range=True keeps the input intensity scale. Without it,
    # skimage converts integer images with img_as_float, squeezing the values
    # into a unit range so that the intensity windows below clip everything.
    resized = resize(slice, target_size, anti_aliasing=True, preserve_range=True)

    # One window tuple per output channel, in channel order:
    # brain, subdural, bone. The tuples are interpreted by apply_windowing.
    windows = ((40, 80), (80, 200), (600, 2800))
    channels = [apply_windowing(resized, window=window) for window in windows]

    return np.stack(channels, axis=-1)

def apply_windowing(slice, window):
    """
    Clip an image to an intensity window and rescale it to [0, 1].

    Args:
        slice (np.ndarray): Image intensities.
        window (tuple): ``(window_level, window_width)``, i.e. the window
            centre followed by its full width. This matches how
            ``preprocess_slice`` calls it, e.g. ``(40, 80)`` for the brain
            window (centre 40, width 80).

    Returns:
        np.ndarray: Float array of the same shape with values in [0, 1];
        intensities at or below the lower bound map to 0 and those at or
        above the upper bound map to 1.

    Raises:
        ValueError: If the window width is not positive (the rescaling would
            divide by zero).
    """
    # The tuple is (level, width); unpacking it the other way round shifts
    # every window to the wrong intensity range.
    window_level, window_width = window
    if window_width <= 0:
        raise ValueError(f"window width must be positive, got {window_width}")

    # True division keeps the window symmetric for odd widths as well.
    half_width = window_width / 2
    lower_bound = window_level - half_width
    upper_bound = window_level + half_width

    windowed_slice = np.clip(slice, lower_bound, upper_bound)
    return (windowed_slice - lower_bound) / (upper_bound - lower_bound)

In [None]:
def _slice_sort_key(ds):
    """Return a numeric position used to order the slices of one series."""
    # Prefer the z coordinate of ImagePositionPatient; fall back to
    # InstanceNumber, and finally to 0 (which keeps the incoming order,
    # because list.sort is stable).
    position = getattr(ds, "ImagePositionPatient", None)
    if position is not None and len(position) >= 3:
        return float(position[2])
    return float(getattr(ds, "InstanceNumber", 0) or 0)


def _to_hounsfield(ds):
    """Convert stored pixel values to rescaled units (HU for CT) as float32."""
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    return ds.pixel_array.astype(np.float32) * slope + intercept


def read_dicom_folder(folder_path, num_slices=60):
    """
    Read every ``.dcm`` file of a folder as an ordered list of slices.

    Args:
        folder_path (str): Directory containing the DICOM files of one scan.
        num_slices (int): Minimum number of slices to return; shorter series
            are padded at the end. Defaults to 60.

    Returns:
        list[np.ndarray]: float32 slices with RescaleSlope/RescaleIntercept
        applied, ordered by slice position. Series that already have more
        than ``num_slices`` slices are returned unchanged (not truncated).

    Raises:
        ValueError: If the folder contains no ``.dcm`` file, since there is
            then no slice shape to build padding from.
    """
    # os.listdir returns entries in arbitrary order; sort the names first so
    # the result is deterministic even when no position tag is available.
    datasets = [
        pydicom.dcmread(os.path.join(folder_path, filename))
        for filename in sorted(os.listdir(folder_path))
        if filename.endswith(".dcm")
    ]
    if not datasets:
        raise ValueError(f"No .dcm files found in {folder_path}")

    datasets.sort(key=_slice_sort_key)

    # The windows applied downstream are expressed in rescaled units, so the
    # stored values are converted here.
    slices = [_to_hounsfield(ds) for ds in datasets]

    # Pad short series with "black" images: the rescaled equivalent of an
    # image whose stored pixel values are all zero.
    missing = num_slices - len(slices)
    if missing > 0:
        pad_value = float(getattr(datasets[0], "RescaleIntercept", 0.0))
        black_slice = np.full_like(slices[0], pad_value)
        slices.extend([black_slice] * missing)

    return slices

In [None]:
def process_patient_data(dicom_dir, row):
    """
    Load and preprocess the scan described by one row of the label table.

    Args:
        dicom_dir (str): Directory that contains one sub-folder per scan.
        row (pd.Series): Row of the patient_scan_labels DataFrame; must hold
            'patient_id', 'study_instance_uid' and the hemorrhage columns.

    Returns:
        Tuple: ``(preprocessed_slices, label)`` where ``label`` is 1 if any
        hemorrhage column of the row is truthy and 0 otherwise, or
        ``(None, None)`` when the scan folder does not exist.
    """
    # Scan folders are named "<patient>_<study>" without the 'ID_' prefixes.
    patient = row['patient_id'].replace('ID_', '')
    study = row['study_instance_uid'].replace('ID_', '')
    folder_path = os.path.join(dicom_dir, f"{patient}_{study}")

    # Guard clause: report the missing folder and signal it to the caller.
    if not os.path.exists(folder_path):
        print(f"Folder not found: {folder_path}")
        return None, None

    raw_slices = read_dicom_folder(folder_path)
    preprocessed_slices = [preprocess_slice(raw) for raw in raw_slices]

    # The scan is positive as soon as one hemorrhage indicator is set.
    hemorrhage_columns = [
        'any', 'epidural', 'intraparenchymal',
        'intraventricular', 'subarachnoid', 'subdural',
    ]
    label = 1 if row[hemorrhage_columns].any() else 0

    return preprocessed_slices, label

In [None]:
class TrainDatasetGenerator(Dataset):
    """
    Training dataset yielding one preprocessed scan per label-table row.

    Each item is a ``(slices, label)`` pair of tensors, or ``(None, None)``
    when ``process_patient_data`` reports that the scan folder is missing.
    """

    def __init__(self, dicom_dir, patient_scan_labels):
        # Keep the label table and the DICOM root; scans are read lazily.
        self.patient_scan_labels = patient_scan_labels
        self.dicom_dir = dicom_dir

    def __len__(self):
        # One sample per row of the label table.
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        scan_row = self.patient_scan_labels.iloc[idx]
        slices, label = process_patient_data(self.dicom_dir, scan_row)

        # Missing folder: propagate the (None, None) marker unchanged.
        if slices is None:
            return None, None

        # Merge the per-slice arrays into one array before building tensors.
        volume = torch.tensor(np.array(slices), dtype=torch.float32)
        target = torch.tensor(label, dtype=torch.long)
        return volume, target

class TestDatasetGenerator(Dataset):
    """
    A custom dataset class for testing data.

    Each item is a ``(slices, label)`` pair of tensors, or ``(None, None)``
    when ``process_patient_data`` reports that the scan folder is missing.
    """
    def __init__(self, dicom_dir, patient_scan_labels):
        # __getitem__ reads self.dicom_dir; storing the directory only under
        # the upper-case name made every item access raise AttributeError.
        self.dicom_dir = dicom_dir
        # Upper-case alias kept for code that reads the old attribute name.
        self.DICOM_DIR = dicom_dir
        self.patient_scan_labels = patient_scan_labels

    def __len__(self):
        # One sample per row of the label table.
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        row = self.patient_scan_labels.iloc[idx]
        preprocessed_slices, label = process_patient_data(self.dicom_dir, row)

        if preprocessed_slices is not None:
            # Convert the list of numpy arrays to a single numpy array
            preprocessed_slices = np.array(preprocessed_slices)
            return torch.tensor(preprocessed_slices, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
        else:
            return None, None  # Handle the case where the folder is not found
        

# Function to create DataLoader for training
def get_train_loader(dicom_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE, shuffle=True):
    """
    Wrap a TrainDatasetGenerator in a DataLoader.

    Args:
        dicom_dir (str): Directory containing the per-scan DICOM folders.
        patient_scan_labels (pd.DataFrame): Label table, one row per scan.
        batch_size (int): Number of scans per batch.
        shuffle (bool): Whether the loader shuffles the scans.

    Returns:
        DataLoader: Loader over the training dataset.
    """
    return DataLoader(
        TrainDatasetGenerator(dicom_dir, patient_scan_labels),
        batch_size=batch_size,
        shuffle=shuffle,
    )

# Function to create DataLoader for testing
def get_test_loader(dicom_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE):
    """
    Wrap a TestDatasetGenerator in a non-shuffling DataLoader.

    Args:
        dicom_dir (str): Directory containing the per-scan DICOM folders.
        patient_scan_labels (pd.DataFrame): Label table, one row per scan.
        batch_size (int): Number of scans per batch.

    Returns:
        DataLoader: Loader over the test dataset, in table order.
    """
    # Use the dicom_dir argument; the global DICOM_DIR was used here before,
    # which silently ignored the directory passed by the caller.
    test_dataset = TestDatasetGenerator(dicom_dir, patient_scan_labels)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## CNN Feature Extractor

In [None]:
class FeatureExtractor(torch.nn.Module):
    """
    ResNet18 (default torchvision weights) with its last child module removed.

    ``forward`` returns one flat feature vector per input image.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Keep every child of the backbone except the final one.
        layers = list(backbone.children())
        layers.pop()
        self.features = torch.nn.Sequential(*layers)

    def forward(self, x):
        feature_maps = self.features(x)
        # Flatten everything after the batch dimension.
        batch = feature_maps.size(0)
        return feature_maps.view(batch, -1)

In [None]:
class AttentionLayer(torch.nn.Module):
    """
    Attention pooling: score each item along dim 1, softmax the scores over
    that dimension and return the weighted sum of the items.
    """

    def __init__(self, input_dim, attention_dim):
        super().__init__()
        # Small scoring network: input_dim -> attention_dim -> 1 score.
        scorer = [
            torch.nn.Linear(input_dim, attention_dim),
            torch.nn.Tanh(),
            torch.nn.Linear(attention_dim, 1),
        ]
        self.attention = torch.nn.Sequential(*scorer)

    def forward(self, features):
        # One score per item, normalised across dim 1.
        scores = self.attention(features)
        weights = scores.softmax(dim=1)
        # Weighted sum over dim 1 collapses the items into one vector.
        pooled = (weights * features).sum(dim=1)
        return pooled, weights

In [None]:
class MILModel(torch.nn.Module):
    """
    Multiple-instance model: per-image CNN features, attention pooling over
    the images of a bag, then a sigmoid classifier on the pooled vector.

    Args:
        num_classes (int): Number of outputs of the classifier head.
    """
    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.attention = AttentionLayer(512, 256)  # 512 is the output dimension of ResNet18
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(256, num_classes),
            torch.nn.Sigmoid()
        )

    def _extract_bag_features(self, x):
        """Run the feature extractor on every image and regroup by bag."""
        # x shape: (batch_size, num_images, height, width, channels)
        batch_size, num_images, height, width, channels = x.size()

        # Merge batch and image dimensions, then move channels first:
        # (batch_size * num_images, channels, height, width).
        # reshape (unlike view) also accepts non-contiguous inputs.
        images = x.reshape(batch_size * num_images, height, width, channels)
        images = images.permute(0, 3, 1, 2)

        features = self.feature_extractor(images)
        return features.reshape(batch_size, num_images, -1)

    def forward(self, x):
        """
        Classify a batch of bags.

        Args:
            x (torch.Tensor): Tensor of shape
                (batch_size, num_images, height, width, channels).

        Returns:
            Tuple: ``(output, attention_weights, features)``: the classifier
            output, the attention weights over the images of each bag, and
            the per-image features of shape (batch_size, num_images, -1).
        """
        features = self._extract_bag_features(x)

        # Apply attention
        weighted_features, attention_weights = self.attention(features)

        # Classify
        output = self.classifier(weighted_features)

        return output, attention_weights, features

    def get_attention_weights(self, x):
        """
        Return only the attention weights for a batch of bags.

        Args:
            x (torch.Tensor): Tensor of shape
                (batch_size, num_images, height, width, channels).

        Returns:
            torch.Tensor: Attention weights produced by the attention layer.
        """
        # Shares the feature path with forward() so both always agree.
        features = self._extract_bag_features(x)
        _, attention_weights = self.attention(features)
        return attention_weights

In [None]:
# Build the data loaders.
# NOTE(review): both loaders receive the same label table, so the "test"
# loader iterates over the training scans; confirm whether a split was intended.
train_loader = get_train_loader(
    dicom_dir=DICOM_DIR,
    patient_scan_labels=patient_scan_labels,
    batch_size=TRAIN_BATCH_SIZE,
)
test_loader = get_test_loader(
    dicom_dir=DICOM_DIR,
    patient_scan_labels=patient_scan_labels,
    batch_size=TEST_BATCH_SIZE,
)

In [None]:
# Training configuration: a single epoch.
num_epochs = 1

# Single-output MIL model, moved to the selected device.
model = MILModel(num_classes=1)
model = model.to(device)

# The model ends in a sigmoid, so binary cross-entropy on probabilities is used.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Training loop
for epoch in range(num_epochs):
    print('############################################')
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    total_loss = 0.0
    correct = 0
    total = 0

    model.train()

    train_loader_tqdm = tqdm(train_loader, leave=False)
    # The description does not change within an epoch, so set it once.
    train_loader_tqdm.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
    for batch_idx, (batch_slices, batch_labels) in enumerate(train_loader_tqdm):
        try:
            batch_slices = batch_slices.to(device)
            batch_labels = batch_labels.to(device)
            targets = batch_labels.float()

            optimizer.zero_grad()
            outputs, attention_weights, features = model(batch_slices)

            # Drop only the class dimension. A bare squeeze() would also
            # remove the batch dimension of a single-sample batch, leaving a
            # 0-d tensor whose shape no longer matches the (1,) target.
            probabilities = outputs.squeeze(1)
            loss = criterion(probabilities, targets)

            loss.backward()
            optimizer.step()

            # Accumulate epoch statistics.
            total_loss += loss.item()
            predicted = (probabilities >= 0.5).float()
            matches = predicted == targets
            correct += matches.sum().item()
            total += batch_labels.size(0)

            batch_accuracy = matches.float().mean().item()

            train_loader_tqdm.set_postfix(loss=loss.item(), accuracy=batch_accuracy)

        except Exception as e:
            # Print the batch context, then re-raise for the full traceback.
            print(f"Error in batch {batch_idx + 1}: {e}")
            print(f"Input shape: {batch_slices.shape}")
            print(f"Labels shape: {batch_labels.shape}")
            raise

    # Guard the averages so an empty loader reports zeros instead of
    # raising ZeroDivisionError.
    num_batches = len(train_loader)
    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / total if total else 0.0

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')