# CNN

In [1]:
import os, random, time
import numpy as np
import nibabel as nib
import cv2
import matplotlib.pyplot as plt

from utilities import *

# Pytorch functions
import torch
# Neural network layers
import torch.nn as nn
import torch.nn.functional as F
# Optimizer
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
# Torchvision library
from torchvision import transforms

# For results
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

In [2]:
# Device configuration
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

mps


In [3]:
def set_seed(seed, use_cuda = True, use_mps = False):
    """
    Set SEED for PyTorch reproducibility
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    if use_mps:
        torch.mps.manual_seed(seed)

SEED = 44

USE_SEED = True

if USE_SEED:
    set_seed(SEED, torch.cuda.is_available(), torch.backends.mps.is_available())

# Old Transformations

These are some transformations I have written but need to be updated for our dataset.

In [None]:
class Crop(object):
    def __init__(self, output_ind):
        self.output_ind = output_ind
    def __call__(self, sample):
        image, label = sample
        new_image = []
        output_ind = self.output_ind
        for i in range(len(image)): # 4
            new_image.append(image[i][output_ind[0][0]:output_ind[0][1], output_ind[1][0]:output_ind[1][1],:])
            new_label = label[output_ind[0][0]:output_ind[0][1], output_ind[1][0]:output_ind[1][1],:]
        return new_image, new_label

class Flatten(object):
    def __call__(self, sample):
        image, label = sample # images have 4 image
        new_image = []
        for i in range(len(image)):
            new_image.append(image[i].reshape(180, -1, order = 'F'))
        new_label = label.reshape(-1)
        return new_image, new_label
    
class ScanNormalize(object):
    def __call__(self, sample):
        image, label = sample
        new_image = []
        for i in range(len(image)):
            img = image[i]
            new_scan = (img-np.min(img))/(np.max(img)-np.min(img))
            new_image.append(new_scan)
        return new_image, label

class StackScans(object):
    def __call__(self, sample):
        image, label = sample
        new_image = np.stack(image, axis=-1)
        return new_image, label
    
class BinaryLabel(object):
    def __call__(self, sample):
        image, label = sample
        new_label = np.sign(label)
        return image, new_label
    
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return torch.from_numpy(image), torch.from_numpy(landmarks)

# 3D MRI Scans Inputs

We load flair, t1ce, t2 and define any pre-processing methods here. Without preprocessing, the training data has size `(3, 240, 240, 155)` and the target has size `(240, 240, 155)`.

In [None]:
class BraTSDataset(Dataset):
    def __init__(self, image_path = r'./BraTS/BraTS2021_Training_Data', transform=None):
        'Initialisation'
        self.image_path = image_path
        self.folders_name = [folder for folder in os.listdir(self.image_path) if folder != '.DS_Store']
        self.transform = transform

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.folders_name)

    def __getitem__(self, index):
        'Generates one sample of data'

        # Select sample
        fld_name = self.folders_name[index]
        image = []
        for scan_type in ['flair', 't1ce', 't2']:
            path_img = os.path.join(self.image_path, fld_name, fld_name + '_' + scan_type + '.nii.gz')
            img = nib.load(path_img).get_fdata()
            image.append(img)
        
        image = np.array(image)

        path_label = os.path.join(self.image_path, fld_name, fld_name + '_seg.nii.gz')

        label = nib.load(path_label).get_fdata()
        if self.transform:
            image, label = self.transform([image, label])
        return image, label

# 2D MRI Scans Inputs

In [4]:
class BraTSDataset(Dataset):
    def __init__(self, image_path = r'./BraTS/BraTS2021_Training_Data', transform=None):
        'Initialisation'
        self.image_path = image_path
        self.folders_name = [folder for folder in os.listdir(self.image_path) if folder != '.DS_Store']
        self.transform = transform

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.folders_name) * 155

    def __getitem__(self, index):
        'Generates one sample of data'

        # Determine the image index and the RGB layer
        image_idx = index // 155
        layer_idx = index % 155

        # Select sample
        fld_name = self.folders_name[image_idx]
        image = []
        for scan_type in ['flair', 't1ce', 't2']:
            path_img = os.path.join(self.image_path, fld_name, fld_name + '_' + scan_type + '.nii.gz')
            img = nib.load(path_img).get_fdata()
            # Need to apply standardisation here...
            image.append(img[:, :, layer_idx])
        
        image = np.array(image)

        path_label = os.path.join(self.image_path, fld_name, fld_name + '_seg.nii.gz')

        label = nib.load(path_label).get_fdata()[:, :, layer_idx]
        
        if self.transform:
            image, label = self.transform([image, label])
        return image, label

In [5]:
class everythirdlayer(object):
    def __call__(self, sample):
        image, label = sample
        new_image = image[:,:,:,np.arange(3, 152, 3)]
        new_label = label[:,:,np.arange(3, 152, 3)]
        return new_image, new_label

class Flair(object):
    def __call__(self, sample):
        image, label = sample
        new_image = image[0]
        return new_image, label
    
class IntLabel(object):
    def __call__(self, sample):
        image, label = sample
        new_label = label.astype(int)
        return image, new_label
    
class BinariseLabel(object):
    def __call__(self, sample):
        image, label = sample
        new_label = np.sign(label)
        return image, new_label

class CropAndResize(object):
    def __call__(self, sample):
        image, label = sample
        rows = np.any(image, axis=1)
        cols = np.any(image, axis=0)

        # Find the bounding box of the non-zero regions
        rows_indices = np.where(rows)[0]
        cols_indices = np.where(cols)[0]
        top_row = np.min(rows_indices)
        bottom_row = np.max(rows_indices)
        left_col = np.min(cols_indices)
        right_col = np.max(cols_indices)

        square_size = max(bottom_row - top_row, right_col - left_col) + 1

        # Crop the image
        cropped_image = image[top_row:top_row + square_size, left_col:left_col + square_size]
        cropped_label = label[top_row:top_row + square_size, left_col:left_col + square_size]

        # Resize the image
        dim = [64,64]
        resized_image = cv2.resize(cropped_image, dim)
        resized_label = cv2.resize(cropped_label, dim)

        return resized_image, resized_label

class Standardise(object) :
    def __call__(self, sample):
        image, label = sample
        image_no_zeros = np.where(image == 0, np.nan, image)
        new_image = (image_no_zeros - np.nanmean(image_no_zeros)) / np.nanstd(image_no_zeros)
        new_image = np.nan_to_num(new_image)
        return new_image, label

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        image, label = sample

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        # image = image.transpose((2, 0, 1))
        return torch.from_numpy(image), torch.from_numpy(label)

In [6]:
dataset = BraTSDataset(image_path = r'./BraTS/BraTS2021_Training_Data',
                                    transform=transforms.Compose([
                                        Flair(),
                                        CropAndResize(),
                                        Standardise(),
                                        IntLabel(),
                                        BinariseLabel(),
                                        ToTensor()
                                    ]))

https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/
https://github.com/facebookresearch/detectron2

In [21]:
batch_size = 128

train_val_test_split = [0.7, 0.2, 0.1]

generator = torch.Generator().manual_seed(SEED)

dataset_size = int(len(dataset)/155)
dataset_indices = list(range(dataset_size))

train_sampler, val_sampler, test_sampler = random_split(dataset_indices, train_val_test_split, generator=generator)

In [22]:
train_sampler.indices = [i * 155 + j for i in train_sampler.indices for j in range(155)]
val_sampler.indices = [i * 155 + j for i in val_sampler.indices for j in range(155)]
test_sampler.indices = [i * 155 + j for i in test_sampler.indices for j in range(155)]

random.seed(SEED)
random.shuffle(train_sampler.indices)
random.shuffle(val_sampler.indices)
random.shuffle(test_sampler.indices)

In [24]:
train_iterator = DataLoader(dataset, batch_size=batch_size,
                            sampler=train_sampler)
validation_iterator = DataLoader(dataset, batch_size=batch_size,
                            sampler=val_sampler)
test_iterator = DataLoader(dataset, batch_size=batch_size, sampler = test_sampler)

In [25]:
class CNN(nn.Module):
  def __init__(self):
    super().__init__()

    self.features = nn.Sequential(
      ## encoder layers ##
      # conv layer (depth from 1 --> 4), 3x3 kernels
      nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),
      nn.ReLU(),
      # pooling layer to reduce x-y dims by two; kernel and stride of 2
      nn.MaxPool2d(2),
      # conv layer (depth from 4 --> 8), 4x4 kernels
      nn.Conv2d(in_channels=4, out_channels=8, kernel_size=4),
      nn.ReLU(),
      nn.MaxPool2d(2),
      # conv layer (depth from 8 --> 12), 5x5 kernels
      nn.Conv2d(in_channels=8, out_channels=12, kernel_size=5),
      nn.ReLU(),
      nn.MaxPool2d(3),
      ## decoder layers ##
      # add transpose conv layers, with relu activation function
      ## a kernel of 2 and a stride of 2 will increase the spatial dims by 2
      nn.ConvTranspose2d(12, 8, 2, stride=3),
      nn.ReLU(),
      nn.ConvTranspose2d(8, 4, 2, stride=2),
      nn.ReLU(),
      nn.ConvTranspose2d(4, 1, 2, stride=2),
      # output layer (with sigmoid for scaling from 0 to 1)
      nn.Sigmoid()
    )
    
    # self.linear = nn.Sequential(
    #   nn.Linear(256 * 5 * 5 * 5, 383625),
    #   nn.ReLU(),
    #   nn.Linear(383625, 1534500),
    #   nn.ReLU(),
    #   nn.Linear(1534500, output_dim)
    # )
    
  def forward(self, x):
    x = self.features(x)
    # x = x.view(x.shape[0], -1)
    # x = self.linear(x)
    return x

In [26]:
model = CNN().to(device)

print(f"The model has {count_parameters(model):,} trainable parameters.")

The model has 3,513 trainable parameters.


In [27]:
# Loss
# criterion = nn.CrossEntropyLoss() # Softmax + CrossEntropy
criterion = nn.BCELoss()
criterion = criterion.to(device)

# Optim
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [28]:
N_EPOCHS = 30
train_losses, train_accs, valid_losses, valid_accs = model_training(N_EPOCHS,
                                                                    model,
                                                                    train_iterator,
                                                                    validation_iterator,
                                                                    optimizer,
                                                                    criterion,
                                                                    device,
                                                                    'CNN.pt')


Epoch: 1/30 -- Epoch Time: 0.01 s
---------------------------------
Train -- Loss: 0.000, Acc: 0.00%
Val -- Loss: 0.000, Acc: 0.00%

Epoch: 2/30 -- Epoch Time: 0.00 s
---------------------------------
Train -- Loss: 0.000, Acc: 0.00%
Val -- Loss: 0.000, Acc: 0.00%

Epoch: 3/30 -- Epoch Time: 0.00 s
---------------------------------
Train -- Loss: 0.000, Acc: 0.00%
Val -- Loss: 0.000, Acc: 0.00%

Epoch: 4/30 -- Epoch Time: 0.00 s
---------------------------------
Train -- Loss: 0.000, Acc: 0.00%
Val -- Loss: 0.000, Acc: 0.00%

Epoch: 5/30 -- Epoch Time: 0.00 s
---------------------------------
Train -- Loss: 0.000, Acc: 0.00%
Val -- Loss: 0.000, Acc: 0.00%

Epoch: 6/30 -- Epoch Time: 0.00 s
---------------------------------
Train -- Loss: 0.000, Acc: 0.00%
Val -- Loss: 0.000, Acc: 0.00%

Epoch: 7/30 -- Epoch Time: 0.00 s
---------------------------------
Train -- Loss: 0.000, Acc: 0.00%
Val -- Loss: 0.000, Acc: 0.00%

Epoch: 8/30 -- Epoch Time: 0.00 s
---------------------------------


In [None]:
model_testing(model, test_iterator, criterion, device, 'CNN.pt')

In [None]:
print_report(model, test_iterator, device)