## Library

In [1]:
import torch
import torch.optim as optim
from torch.optim import Optimizer
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import pydicom
import cv2
import os
import pandas as pd
from skimage.transform import resize

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

## Init GPU

In [2]:
# Initialize GPU Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

print(device)

GPU: NVIDIA GeForce RTX 4070 SUPER is available.
cuda


In [3]:
%load_ext autoreload
%autoreload 2

## Config Info

In [4]:
# Seed
SEED = 202408
np.random.seed(SEED)
torch.manual_seed(SEED)

# Constants
TEST_SIZE = 0.02
HEIGHT = 256
WIDTH = 256
CHANNELS = 3
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16
TEST_BATCH_SIZE = 16
SHAPE = (HEIGHT, WIDTH, CHANNELS)

# Folders
DATA_DIR = './rsna-mil-training'

In [5]:
def correct_dcm(dcm):
    x = dcm.pixel_array + 1000
    px_mode = 4096
    x[x>=px_mode] = x[x>=px_mode] - px_mode
    dcm.PixelData = x.tobytes()
    dcm.RescaleIntercept = -1000

def window_image(dcm, window_center, window_width):    
    if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):
        correct_dcm(dcm)
    img = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept
    
    # Resize
    img = cv2.resize(img, SHAPE[:2], interpolation = cv2.INTER_LINEAR)
   
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    img = np.clip(img, img_min, img_max)
    return img

def bsb_window(dcm):
    brain_img = window_image(dcm, 40, 80)
    subdural_img = window_image(dcm, 80, 200)
    soft_img = window_image(dcm, 40, 380)
    
    brain_img = (brain_img - 0) / 80
    subdural_img = (subdural_img - (-20)) / 200
    soft_img = (soft_img - (-150)) / 380
    # bsb_img = np.array([brain_img, subdural_img, soft_img])
    bsb_img = np.array([brain_img, subdural_img, soft_img]).transpose(1,2,0)

    return bsb_img

def _read(path, SHAPE):
    dcm = pydicom.dcmread(path)
    try:
        img = bsb_window(dcm)
    except:
        img = np.zeros(SHAPE)
    return img

In [6]:
def preprocess_slice(slice, target_size=(224, 224)):
    slice = resize(slice, target_size, anti_aliasing=True)
    
    slice = apply_windowing(slice)
    
    return slice

def apply_windowing(slice):
    brain_window = (40, 80, -100, 200)
    slice = np.clip(slice, brain_window[2], brain_window[3])
    slice = (slice - brain_window[2]) / (brain_window[3] - brain_window[2])
    return slice

In [7]:
def read_dicom_folder(folder_path):
    slices = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".dcm"):
            file_path = os.path.join(folder_path, filename)
            ds = pydicom.dcmread(file_path)
            slices.append(ds.pixel_array)
    
    # Add black images if the number of slices is less than 60
    num_slices = len(slices)
    if num_slices < 60:
        black_slice = np.zeros_like(slices[0])
        for _ in range(60 - num_slices):
            slices.append(black_slice)
    
    return slices

In [8]:
def process_patient_data(data_dir, row):
    """
    Process data for a single patient based on the row from the DataFrame.
    
    Args:
        data_dir (str): The directory containing DICOM folders.
        row (pd.Series): A row from the patient_scan_labels DataFrame.

    Returns:
        Tuple: Preprocessed slices and label.
    """
    patient_id = row['patient_id'].replace('ID_', '')  # Remove 'ID_' prefix
    study_instance_uid = row['study_instance_uid'].replace('ID_', '')  # Remove 'ID_' prefix
    
    # Construct folder path based on patient_id and study_instance_uid
    folder_name = f"{patient_id}_{study_instance_uid}"
    folder_path = os.path.join(data_dir, folder_name)
    
    # Read and preprocess DICOM slices
    if os.path.exists(folder_path):
        slices = read_dicom_folder(folder_path)
        preprocessed_slices = [preprocess_slice(slice) for slice in slices]
        
        # Determine label based on any of the hemorrhage indicators
        label = 1 if row[['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']].any() else 0
        
        return preprocessed_slices, label
    else:
        print(f"Folder not found: {folder_path}")
        return None, None  # Handle the case where the folder is not found

In [9]:
class TrainDatasetGenerator(Dataset):
    """
    A custom dataset class for training data.
    """
    def __init__(self, data_dir, patient_scan_labels):
        self.data_dir = data_dir
        self.patient_scan_labels = patient_scan_labels

    def __len__(self):
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        row = self.patient_scan_labels.iloc[idx]
        preprocessed_slices, label = process_patient_data(self.data_dir, row)
        
        if preprocessed_slices is not None:
            # Convert the list of numpy arrays to a single numpy array
            preprocessed_slices = np.array(preprocessed_slices)  # Convert to numpy array
            return torch.tensor(preprocessed_slices, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
        else:
            return None, None  # Handle the case where the folder is not found

class TestDatasetGenerator(Dataset):
    """
    A custom dataset class for testing data.
    """
    def __init__(self, data_dir, patient_scan_labels):
        self.data_dir = data_dir
        self.patient_scan_labels = patient_scan_labels

    def __len__(self):
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        row = self.patient_scan_labels.iloc[idx]
        preprocessed_slices, label = process_patient_data(self.data_dir, row)
        
        if preprocessed_slices is not None:
            # Convert the list of numpy arrays to a single numpy array
            preprocessed_slices = np.array(preprocessed_slices)  # Convert to numpy array
            return torch.tensor(preprocessed_slices, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
        else:
            return None, None  # Handle the case where the folder is not found
        

# Function to create DataLoader for training
def get_train_loader(data_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE, shuffle=True):
    train_dataset = TrainDatasetGenerator(data_dir, patient_scan_labels)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)

# Function to create DataLoader for testing
def get_test_loader(data_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE):
    test_dataset = TestDatasetGenerator(data_dir, patient_scan_labels)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [10]:
# class CNNFeatureExtractor(nn.Module):
#     """
#     CNN backbone to extract features from each slice
#     """
#     def __init__(self):
#         super().__init__()
#         self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
#         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

#         # Calculate the size after convolutions and pooling
#         self.output_size = 64 * (224 // 2 // 2) * (224 // 2 // 2)  # Output size after conv2 and pooling

#         self.fc1 = nn.Linear(self.output_size, 128)  # Adjust based on input size
#         self.fc2 = nn.Linear(128, 64)  # Output size for attention

#     def forward(self, x):
#         # Pass slice through CNN layers
#         x = self.pool(F.relu(self.conv1(x)))  # Shape: (batch_size, 32, 112, 112)
#         x = self.pool(F.relu(self.conv2(x)))  # Shape: (batch_size, 64, 56, 56)
#         x = x.view(x.size(0), -1)  # Flatten
#         x = F.relu(self.fc1(x))
#         features = self.fc2(x)  # Output features for attention
#         return features

In [None]:
class CNNFeatureExtractor(nn.Module):
    """
    CNN backbone to extract features from each slice
    """
    def __init__(self, slice_size=(224, 224), num_slices=60):
        super().__init__()
        self.slice_size = slice_size
        self.num_slices = num_slices
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # Calculate the size after convolutions and pooling
        self.output_size = 64 * (slice_size[0] // 2 // 2) * (slice_size[1] // 2 // 2)  # Output size after conv2 and pooling

        self.fc1 = nn.Linear(self.output_size, 128)  # Adjust based on input size
        self.fc2 = nn.Linear(128, 64)  # Output size for attention

    def forward(self, x):
        # Ensure input has the expected number of slices
        if x.size(1) < self.num_slices:
            padding = torch.zeros(x.size(0), self.num_slices - x.size(1), *self.slice_size)
            x = torch.cat([x, padding], dim=1)
        elif x.size(1) > self.num_slices:
            x = x[:, :self.num_slices]

        # Pass slices through CNN layers
        slice_features = []
        for i in range(self.num_slices):
            slice_input = x[:, i].unsqueeze(1)  # Shape: (batch_size, 1, height, width)
            slice_feature = self.cnn(slice_input)
            slice_features.append(slice_feature)

        slice_features = torch.stack(slice_features, dim=1)  # Shape: (batch_size, num_slices, feature_dim)
        
        x = slice_features.view(slice_features.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        features = self.fc2(x)  # Output features for attention
        return features

In [11]:
class AttentionLayer(nn.Module):
    """
    Attention mechanism to aggregate slice features into bag features
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, output_dim)
        self.v = nn.Parameter(torch.rand(output_dim))
        
    def forward(self, features):
        # Compute attention weights for each slice
        att_weights = torch.softmax(torch.tanh(self.W(features)) @ self.v, dim=1)
        return att_weights

In [12]:
class E2EAttGP(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = CNNFeatureExtractor()
        self.attention = AttentionLayer(input_dim=64, output_dim=32)  # Adjust input_dim based on ResNet output
        self.classifier = nn.Linear(32, 1)  # Classifier for bag prediction
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        # x is expected to be of shape (batch_size, num_slices, height, width)
        slice_features = []

        for i in range(x.size(1)):  # Loop over each slice in the batch
            slice_input = x[:, i, :, :].unsqueeze(1)  # Shape: (batch_size, 1, height, width)
            slice_feature = self.cnn(slice_input)  # Process each slice through ResNet
            slice_features.append(slice_feature)

        slice_features = torch.stack(slice_features, dim=1)  # Shape: (batch_size, num_slices, feature_dim)
        
        # Compute attention weights for each slice
        att_weights = self.attention(slice_features)
        
        # Aggregate slice features into bag feature using attention
        bag_feature = (slice_features * att_weights.unsqueeze(-1)).sum(dim=1)  # Shape: (batch_size, feature_dim)

        # Pass bag feature through the classifier
        logits = self.classifier(bag_feature).squeeze(-1)  # Shape: (batch_size,)
        probs = self.sigmoid(logits)
        
        return probs

In [13]:
data_dir = './rsna-mil-training'
patient_scan_labels = pd.read_csv('training_1000_scan_subset.csv')

In [14]:
train_loader = get_train_loader(data_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE)
test_loader = get_test_loader(data_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE)

In [15]:
num_epochs = 1
model = E2EAttGP().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


# Training loop
for epoch in range(num_epochs):
    print('############################################')
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    model.train()
    total_loss = 0  # To track total loss for the epoch
    correct = 0
    total = 0
    
    train_loader_tqdm = tqdm(train_loader, leave=False)
    for batch_slices, batch_labels in train_loader_tqdm:
        batch_slices = batch_slices.to(device)  # Move batch to GPU
        batch_labels = batch_labels.to(device)  # Move labels to GPU

        optimizer.zero_grad()
        outputs = model(batch_slices)
        loss = criterion(outputs, batch_labels.float())
        loss.backward()
        optimizer.step()

        total_loss += loss.item()  # Accumulate loss
        predicted = (outputs >= 0.5).float()  # Convert outputs to binary predictions
        correct += (predicted.squeeze() == batch_labels.float()).sum().item()  # Count correct predictions
        total += batch_labels.size(0)  # Update total count

        # Calculate accuracy for the batch
        batch_accuracy = (predicted.squeeze() == batch_labels.float()).float().mean().item()
        
        # Update progress bar
        train_loader_tqdm.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
        train_loader_tqdm.set_postfix(loss=loss.item(), accuracy=batch_accuracy)

    # Calculate average loss and accuracy for the epoch
    avg_loss = total_loss / len(train_loader)
    accuracy = correct / total

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')

############################################
Epoch [1/1]


                                      

OutOfMemoryError: CUDA out of memory. Tried to allocate 98.00 MiB. GPU 0 has a total capacity of 11.71 GiB of which 28.50 MiB is free. Including non-PyTorch memory, this process has 11.23 GiB memory in use. Of the allocated memory 10.97 GiB is allocated by PyTorch, and 44.77 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch_slices, batch_labels in test_loader:
        outputs = model(batch_slices)
        
        # Convert outputs to binary predictions
        predicted = (outputs >= 0.5).float()  # Assuming outputs are probabilities

        # Update total and correct counts
        total += batch_labels.size(0)  # Number of samples in the batch
        correct += (predicted.squeeze() == batch_labels.float()).sum().item()  # Count correct predictions

# Calculate accuracy
accuracy = correct / total
print(f'Accuracy: {accuracy:.4f}')