## Library

In [1]:
import torch
import torch.optim as optim
from torch.optim import Optimizer
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import numpy as np
import pydicom
import cv2
import os
import pandas as pd
from skimage.transform import resize

from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast

from tqdm import tqdm

import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

## Init GPU

In [2]:
# Initialize GPU Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

print(device)

GPU: Tesla P100-PCIE-16GB is available.
cuda


In [3]:
%load_ext autoreload
%autoreload 2

## Config Info

In [4]:
# Constants
TEST_SIZE = 0.02
HEIGHT = 512
WIDTH = 512
CHANNELS = 3
TRAIN_BATCH_SIZE = 2
VALID_BATCH_SIZE = 1
TEST_BATCH_SIZE = 1
MAX_SLICES = 60
SHAPE = (HEIGHT, WIDTH, CHANNELS)

# Folders
DATA_DIR = '/kaggle/input/rsna-mil-training/'

In [5]:
DICOM_DIR = DATA_DIR + 'rsna-mil-training'
CSV_PATH = DATA_DIR + 'training_1000_scan_subset.csv'
patient_scan_labels = pd.read_csv(CSV_PATH)

In [6]:
def preprocess_slice(slice, target_size=(HEIGHT, WIDTH)):
    slice = resize(slice, target_size, anti_aliasing=True)
    brain_channel = apply_windowing(slice, window=(40, 80))
    subdural_channel = apply_windowing(slice, window=(80, 200))
    bone_channel = apply_windowing(slice, window=(600, 2800))
    
    multichannel_slice = np.stack([brain_channel, subdural_channel, bone_channel], axis=-1)
    return multichannel_slice.astype(np.float16)  # Use float16 for reduced memory usage

def apply_windowing(slice, window):
    window_width, window_level = window
    lower_bound = window_level - window_width // 2
    upper_bound = window_level + window_width // 2
    
    windowed_slice = np.clip(slice, lower_bound, upper_bound)
    windowed_slice = (windowed_slice - lower_bound) / (upper_bound - lower_bound)
    return windowed_slice

In [7]:
def read_dicom_folder(folder_path):
    slices = []
    for filename in sorted(os.listdir(folder_path))[:MAX_SLICES]:  # Limit to MAX_SLICES
        if filename.endswith(".dcm"):
            file_path = os.path.join(folder_path, filename)
            ds = pydicom.dcmread(file_path)
            slices.append(ds.pixel_array)
    
    # Pad with black images if necessary
    while len(slices) < MAX_SLICES:
        slices.append(np.zeros_like(slices[0]))
    
    return slices[:MAX_SLICES]  # Ensure we return exactly MAX_SLICES

In [8]:
def process_patient_data(dicom_dir, row):
    """
    Process data for a single patient based on the row from the DataFrame.
    
    Args:
        dicom_dir (str): The directory containing DICOM folders.
        row (pd.Series): A row from the patient_scan_labels DataFrame.

    Returns:
        Tuple: Preprocessed slices and label.
    """
    patient_id = row['patient_id'].replace('ID_', '')  # Remove 'ID_' prefix
    study_instance_uid = row['study_instance_uid'].replace('ID_', '')  # Remove 'ID_' prefix
    
    # Construct folder path based on patient_id and study_instance_uid
    folder_name = f"{patient_id}_{study_instance_uid}"
    folder_path = os.path.join(dicom_dir, folder_name)
    
    # Read and preprocess DICOM slices
    if os.path.exists(folder_path):
        slices = read_dicom_folder(folder_path)
        preprocessed_slices = [preprocess_slice(slice) for slice in slices]
        
        # Determine label based on any of the hemorrhage indicators
        label = 1 if row[['any', 'epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']].any() else 0
        
        return preprocessed_slices, label
    else:
        print(f"Folder not found: {folder_path}")
        return None, None  # Handle the case where the folder is not found

In [9]:
class TrainDatasetGenerator(Dataset):
    def __init__(self, dicom_dir, patient_scan_labels):
        self.dicom_dir = dicom_dir
        self.patient_scan_labels = patient_scan_labels

    def __len__(self):
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        row = self.patient_scan_labels.iloc[idx]
        preprocessed_slices, label = process_patient_data(self.dicom_dir, row)
        
        if preprocessed_slices is not None:
            preprocessed_slices = np.array(preprocessed_slices)
            return torch.tensor(preprocessed_slices, dtype=torch.float16), torch.tensor(label, dtype=torch.long)
        else:
            return None, None

class TestDatasetGenerator(Dataset):
    def __init__(self, dicom_dir, patient_scan_labels):
        self.dicom_dir = dicom_dir
        self.patient_scan_labels = patient_scan_labels

    def __len__(self):
        return len(self.patient_scan_labels)

    def __getitem__(self, idx):
        row = self.patient_scan_labels.iloc[idx]
        preprocessed_slices, label = process_patient_data(self.dicom_dir, row)
        
        if preprocessed_slices is not None:
            preprocessed_slices = np.array(preprocessed_slices)
            return torch.tensor(preprocessed_slices, dtype=torch.float16), torch.tensor(label, dtype=torch.long)
        else:
            return None, None

def get_train_loader(dicom_dir, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE, shuffle=True):
    train_dataset = TrainDatasetGenerator(dicom_dir, patient_scan_labels)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)

def get_test_loader(dicom_dir, patient_scan_labels, batch_size=TEST_BATCH_SIZE):
    test_dataset = TestDatasetGenerator(dicom_dir, patient_scan_labels)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

## CNN Feature Extractor

In [10]:
class FeatureExtractor(torch.nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.features = torch.nn.Sequential(*list(resnet18.children())[:-1])
        
    def forward(self, x):
        x = self.features(x)
        return x.view(x.size(0), -1)

In [11]:
class AttentionLayer(nn.Module):
    def __init__(self, input_dim, attention_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, attention_dim),
            nn.Tanh(),
            nn.Linear(attention_dim, 1)
        )
    
    def forward(self, features):
        weights = self.attention(features)
        weights = torch.softmax(weights, dim=1)
        weighted_features = torch.sum(weights * features, dim=1)
        return weighted_features, weights

In [12]:
class MILModel(nn.Module):
    def __init__(self, num_classes):
        super(MILModel, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.attention = AttentionLayer(512, 256)
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        batch_size, num_images, height, width, channels = x.size()
        x = x.view(-1, channels, height, width)
        features = self.feature_extractor(x)
        features = features.view(batch_size, num_images, -1)
        weighted_features, attention_weights = self.attention(features)
        output = self.classifier(weighted_features)
        return output, attention_weights, features

In [13]:
def train_model(model, train_loader, criterion, optimizer, scaler, device, num_epochs, accumulation_steps):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        train_loader_tqdm = tqdm(train_loader, leave=False)
        optimizer.zero_grad()
        for i, (inputs, labels) in enumerate(train_loader_tqdm):
            inputs, labels = inputs.to(device), labels.to(device)

            with autocast(device_type='cuda', dtype=torch.float16):
                outputs, _, _ = model(inputs)
                loss = criterion(outputs.squeeze(), labels.float())
                loss = loss / accumulation_steps

            scaler.scale(loss).backward()

            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            running_loss += loss.item() * accumulation_steps
            predicted = (outputs.squeeze() >= 0.5).float()
            correct += (predicted == labels.float()).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct / total
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}')

    return model

In [14]:
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    patient_scan_labels = pd.read_csv(CSV_PATH)
    
    train_loader = get_train_loader(DICOM_DIR, patient_scan_labels, batch_size=TRAIN_BATCH_SIZE)
    test_loader = get_test_loader(DICOM_DIR, patient_scan_labels, batch_size=TEST_BATCH_SIZE)

    num_epochs = 1
    model = MILModel(num_classes=1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scaler = GradScaler()
    accumulation_steps = 4

    trained_model = train_model(model, train_loader, criterion, optimizer, scaler, device, num_epochs, accumulation_steps)

    # Save the trained model
    torch.save(trained_model.state_dict(), 'optimized_mil_model.pth')

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 187MB/s]
                                                 

Epoch 1/1, Loss: 0.6946, Acc: 0.5780


