### TODOs

Add Augmentation

#### Imports

In [1]:
import os
import random
import glob
import re

import pandas as pd

import numpy as np

import matplotlib.pyplot as plt
import matplotlib

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, utils

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

import cv2

import warnings
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)

#### Seed

In [2]:
def set_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True


set_seed(42)

#### Config

In [3]:
PATH = '..'

# Pre-processing
IMG_SIZE = 256
SLICE_NUMBER = 50

# Training
N_EPOCHS = 1
BATCH_SIZE = 64

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### 1. Load Data

For each person, we were given *four* different MRI types: 
* FLAIR
* T1w
* T1wCE, and 
* T2w. 

We will create 2D "images" which are composed of each MRI type (4 sequences) where each sequence is composed of (middle) 50 (variable; slice_number) slices.

Shape:
Channel x Width x Height


### 1.1 Utilities

In [4]:
def get_slices(mri_type, slice_number):
    #print(f"Length of folder: {len(mri_type)}")
    # Take slice_number slices from the middle
    threshold = slice_number // 2
    minimum_idx = len(mri_type)//2 - threshold if (len(mri_type)//2 - threshold) > 0 else 0
    maximum_idx = len(mri_type)//2 + threshold  # maximum can exceed the index
    #print(f"Minimum {minimum_idx}")
    #print(f"Maximum {maximum_idx}")
    # Create array which contains the images
    mri_img = np.array([cv2.resize(cv2.imread(a, cv2.IMREAD_GRAYSCALE), (IMG_SIZE, IMG_SIZE)) for a in mri_type[minimum_idx:maximum_idx]]).T
    # If less than slice_number slices, add slice_number - mri_img.shape[-1] images with only zero values
    if mri_img.shape[-1] < slice_number:
        #print(f"Current slices: {mri_img.shape[-1]}")
        n_zero = slice_number - mri_img.shape[-1]
        mri_img = np.concatenate((mri_img, np.zeros((IMG_SIZE, IMG_SIZE, n_zero))), axis = -1)
    return mri_img
    

def load_images(scan_id, slice_number=SLICE_NUMBER):
    # Ascending sort
    flair = sorted(glob.glob(f"{PATH}/train/{scan_id}/FLAIR/*.png"), key=lambda f: int(re.sub('\D', '', f)))
    t1w = sorted(glob.glob(f"{PATH}/train/{scan_id}/T1w/*.png"), key=lambda f: int(re.sub('\D', '', f)))
    t1wce = sorted(glob.glob(f"{PATH}/train/{scan_id}/T1wCE/*.png"), key=lambda f: int(re.sub('\D', '', f)))
    t2w = sorted(glob.glob(f"{PATH}/train/{scan_id}/T2w/*.png"), key=lambda f: int(re.sub('\D', '', f)))
    
    #print(f"Scan id {scan_id}")
    flair_img = get_slices(flair, slice_number)
    t1w_img = get_slices(t1w, slice_number)
    t1wce_img = get_slices(t1wce, slice_number)
    t2w_img = get_slices(t2w, slice_number)
    
    # Return "3d" image
    # 0:50: Flair 2D images; 51:100: T1w 2D images; 101-150 T1wCE 2D images; 151-200 T2w 2D images
    return np.concatenate((flair_img, t1w_img, t1wce_img, t2w_img), axis = -1)

### 1.2 Dataset and Dataloader

In [5]:
class RSNADataset(Dataset):
    def __init__(self, path='../', split = "train", validation_split = 0.2):
        train_data = pd.read_csv(os.path.join(path, 'train_labels.csv'))
        self.labels = {}
        brats = list(train_data["BraTS21ID"])
        mgmt = list(train_data["MGMT_value"])
        for b, m in zip(brats, mgmt):
            self.labels[str(b).zfill(5)] = m
            
        remove_ids = ["00709", "00109", "00123"]
            
        if split == "valid":
            self.split = split
            self.ids = [a.split("/")[-1] for a in sorted(glob.glob((path + f"/train/" + "/*")), key=lambda f: int(re.sub('\D', '', f)))]
            self.ids = self.ids[:int(len(self.ids) * validation_split)] # first 20% as validation
        elif split == "train":
            self.split = split
            self.ids = [a.split("/")[-1] for a in sorted(glob.glob((path + f"/train/" + "/*")), key=lambda f: int(re.sub('\D', '', f)))]
            self.ids = self.ids[int(len(self.ids) * validation_split):] # last 80% as train
        else:
            self.split = split
            self.ids = [a.split("/")[-1] for a in sorted(glob.glob((path + f"/train/" + "/*")), key=lambda f: int(re.sub('\D', '', f)))]
        
        self.ids = [id_ for id_ in self.ids if id_ not in remove_ids]            
    
    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, idx):
        imgs = load_images(self.ids[idx])
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,) * (SLICE_NUMBER*4), (0.5,) * (SLICE_NUMBER*4))])
        imgs = transform(imgs)

        if self.split != "test":
            label = self.labels[self.ids[idx]]
            return torch.tensor(imgs, dtype = torch.float32), torch.tensor(label, dtype = torch.long)
        else:
            return torch.tensor(imgs, dtype = torch.float32)

In [6]:
train_ds = RSNADataset()
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_ds = RSNADataset(split='valid')
valid_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
images, labels = next(iter(train_dl))
print(f"Shape of the batch {images.shape}")
print(f"Number of images/labels in the batch: {images.shape[0]}")
print(f"Number of channels each image has: {images.shape[1]}")
print(f"Size of each image is: {images.shape[2]}x{images.shape[3]}")

Shape of the batch torch.Size([64, 200, 256, 256])
Number of images/labels in the batch: 64
Number of channels each image has: 200
Size of each image is: 256x256


### 1.3 Model

In [8]:
class GliobCNN(nn.Module): 
    def __init__(self):
        super(GliobCNN, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels=(SLICE_NUMBER*4), out_channels=64, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, stride=1)
        self.conv4 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, stride=1)
        
        # Pooling layers
        self.pooling1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.pooling2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.pooling3 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # Fully connected layers
        self.fc1 = nn.Linear(5832, 4096)
        self.fc2 = nn.Linear(4096, 2)

    def forward(self, X):
        
        X = F.relu(self.conv1(X))    # 64x254x254 (first dimension depends on SLICE_NUMBER; here, SLICE_NUMBER=50)
        X = self.pooling1(X)         # 64x126x126
        X = F.relu(self.conv2(X))    # 32x124x124
        X = self.pooling2(X)         # 32x61x61
        X = F.relu(self.conv3(X))    # 16x59x59
        X = self.pooling3(X)         # 16x29x29
        X = F.relu(self.conv4(X))    # 8x27x27
        X = X.flatten(1)             # 5832
        X = F.relu(self.fc1(X))      # 4096
        X = F.relu(self.fc2(X))      # 2

        return X

In [9]:
model = GliobCNN()

### 1.4 Training

In [10]:
def training(net, n_epochs, optimizer, loss_function, verbose=True):
    # Store the losses for each epoch
    loss_train_list = []
    loss_valid_list = []

    # Store the accuracy for each epoch
    acc_train_list = []
    acc_valid_list = []

    # Iterate over the dataset n_epochs times
    for epoch in range(n_epochs):
        net.train()  # net.train() will notify all your layers that you are in training mode

        train_loss = 0  # Training loss in epoch
        num_train_correct  = 0
        num_train_examples = 0

        # For each batch, pass the training examples, calculate loss and gradients and optimize the parameters
        for xb, yb in train_dl:
            optimizer.zero_grad()  # zero_grad clears old gradients from the last step

            xb = xb.to(device)
            yb = yb.to(device)

            y_hat = net(xb)  # Forward pass
            loss = loss_function(y_hat, yb)  # Calculate Loss

            loss.backward()  # Calculate the gradients (using backpropagation)
            optimizer.step()  # # Optimize the parameters: opt.step() causes the optimizer to take a step based on the gradients of the parameters.

            train_loss += loss.item()
            num_train_correct += (torch.max(y_hat, 1)[1] == yb).sum().item()
            num_train_examples += xb.shape[0]

        train_acc = num_train_correct / num_train_examples

        valid_loss = 0  # Validation loss in epoch
        num_val_correct  = 0
        num_val_examples = 0

        net.eval()  # net.eval() will notify all your layers that you are in evaluation mode
        with torch.no_grad():
            # Perform a prediction on the validation set  
            for xb_valid, yb_valid in valid_dl:
                xb_valid = xb_valid.to(device)
                yb_valid = yb_valid.to(device)

                y_hat = net(xb_valid)  # Forward pass
                loss = loss_function(y_hat, yb_valid)  # Calculate Loss

                valid_loss += loss.item()
                num_val_correct += (torch.max(y_hat, 1)[1] == yb_valid).sum().item()
                num_val_examples += xb_valid.shape[0]

        val_acc = num_val_correct / num_val_examples

        if verbose:
            print(f"Train Loss (Negative Loss Likelihood) in epoch {epoch}: {train_loss:.2f}")
            print(f"Validation Loss (Negative Loss Likelihood) in epoch {epoch}: {valid_loss:.2f}")
            print(f"Train Accuracy in epoch {epoch}: {100 * (train_acc):.2f}")
            print(f"Validation Accuracy in epoch {epoch}: {100 * (val_acc):.2f}\n")
            print("\n")

        loss_train_list.append(train_loss)
        loss_valid_list.append(valid_loss)
        acc_train_list.append(100 * (train_acc))
        acc_valid_list.append(100 * (val_acc))

    return acc_train_list, acc_valid_list, loss_train_list, loss_valid_list

In [11]:
# Create Loss Function
loss_function = nn.CrossEntropyLoss()

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(),lr = 0.01)

model.to(device)

acc_train, acc_valid, loss_train, loss_valid = training(net=model, n_epochs=N_EPOCHS, optimizer=optimizer, loss_function=loss_function)

Train Loss (Negative Loss Likelihood) in epoch 0: 5.55
Validation Loss (Negative Loss Likelihood) in epoch 0: 1.39
Train Accuracy in epoch 0: 46.90
Validation Accuracy in epoch 0: 47.83



