# First CNN model

In [1]:
import os, random, time
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

# Pytorch functions
import torch
# Neural network layers
import torch.nn as nn
import torch.nn.functional as F
# Optimizer
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
# Torchvision library
from torchvision import transforms
# Handling dataset
import torch.utils.data as data

# For results
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

In [2]:
# Device configuration: prefer CUDA, then Apple MPS, and fall back to the CPU
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Show which device was actually selected
print(device)

mps


In [3]:
def set_seed(seed, use_cuda = True, use_mps = False):
    """
    Seed every random number generator used in this notebook.

    Parameters
    ----------
    seed : int
        Value passed to Python's `random`, NumPy and PyTorch.
    use_cuda : bool
        When true, also seed the CUDA generators and switch cuDNN to
        deterministic mode with benchmarking disabled.
    use_mps : bool
        When true, also seed the MPS generator.
    """
    # CPU-side generators, seeded in a fixed order
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if use_cuda:
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    if use_mps:
        torch.mps.manual_seed(seed)

# Reproducibility settings: the seed value and a switch to disable seeding
SEED = 44
USE_SEED = True

if USE_SEED:
    set_seed(
        SEED,
        use_cuda=torch.cuda.is_available(),
        use_mps=torch.backends.mps.is_available(),
    )

# Transformation

In [4]:
class Crop(object):
    """Crop the first two axes of every scan and of the label to a fixed window.

    Parameters
    ----------
    output_ind : sequence
        `[[row_start, row_stop], [col_start, col_stop]]`; the third axis is
        kept whole.
    """
    def __init__(self, output_ind):
        self.output_ind = output_ind
    def __call__(self, sample):
        """Return `(list_of_cropped_scans, cropped_label)` for an `(image, label)` sample."""
        image, label = sample
        output_ind = self.output_ind
        rows = slice(output_ind[0][0], output_ind[0][1])
        cols = slice(output_ind[1][0], output_ind[1][1])
        new_image = [scan[rows, cols, :] for scan in image]
        # The label is cropped once, outside the per-scan loop, so it is
        # defined even when `image` holds no scans.
        new_label = label[rows, cols, :]
        return new_image, new_label

class Flatten(object):
    """Reshape every scan to 2D (column-major) and the label to 1D.

    Parameters
    ----------
    n_rows : int or None
        Number of rows of each reshaped scan. Defaults to 180, the value that
        was previously hard-coded. Pass `None` to use each scan's own first
        dimension.
    """
    def __init__(self, n_rows=180):
        self.n_rows = n_rows
    def __call__(self, sample):
        """Return `(list_of_2d_scans, flat_label)` for an `(image, label)` sample."""
        image, label = sample # image holds one array per scan
        new_image = []
        for scan in image:
            rows = scan.shape[0] if self.n_rows is None else self.n_rows
            # order='F' makes the first axis vary fastest while reshaping
            new_image.append(scan.reshape(rows, -1, order = 'F'))
        new_label = label.reshape(-1)
        return new_image, new_label
    
class ScanNormalize(object):
    """Min-max scale every scan independently to the range [0, 1]."""
    def __call__(self, sample):
        """Return `(list_of_scaled_scans, label)`; the label is passed through unchanged."""
        image, label = sample
        new_image = []
        for scan in image:
            lowest = np.min(scan)
            span = np.max(scan) - lowest
            if span == 0:
                # A constant scan would otherwise divide by zero and turn into
                # NaN; dividing by 1 maps it to all zeros instead.
                span = 1
            new_image.append((scan - lowest) / span)
        return new_image, label

class StackScans(object):
    """Combine the list of scans into one array by stacking along a new last axis."""
    def __call__(self, sample):
        scans, label = sample
        # The label is returned untouched
        return np.stack(scans, axis=-1), label
    
class BinaryLabel(object):
    """Replace every label value by its sign, leaving the image untouched."""
    def __call__(self, sample):
        scans, mask = sample
        # np.sign maps 0 to 0 and every positive value to 1
        return scans, np.sign(mask)
    
class ToTensor(object):
    """Convert the ndarrays of a sample to Tensors.

    The sample is normally an `(image, label)` tuple, the format produced by
    the other transforms in this notebook. A dict with the keys `'image'` and
    `'landmarks'` is still accepted for backward compatibility.

    Returns
    -------
    tuple of torch.Tensor
        `(image, label)` with the image axes reordered from H x W x C to
        C x H x W.
    """
    def __call__(self, sample):
        if isinstance(sample, dict):
            image, label = sample['image'], sample['landmarks']
        else:
            image, label = sample

        # swap channel axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return torch.from_numpy(image), torch.from_numpy(label)

# Define custom data class

Each 3D scan is flattened into a 2D array, and the four flattened scans are then stacked back into a single 3D tensor.

In [5]:
class BraTSDataset(Dataset):
    """Dataset that loads every case found under `image_path` into memory.

    Each sub-folder is one case and must contain the files
    `<folder>_flair.nii.gz`, `<folder>_t1.nii.gz`, `<folder>_t1ce.nii.gz`,
    `<folder>_t2.nii.gz` and `<folder>_seg.nii.gz`.

    Parameters
    ----------
    image_path : str
        Directory holding one sub-folder per case.
    transform : callable or None
        Called with a single `(image, label)` tuple; its return value is what
        `__getitem__` yields.
    """
    def __init__(self, image_path = r'./BraTS/BraTS2021_Training_Data', transform=None):
        'Initialisation'
        self.image_path = image_path
        # Sorted so that index -> case is the same on every run and machine;
        # os.listdir gives no ordering guarantee, which would defeat the seeded split.
        self.folders_name = sorted(folder for folder in os.listdir(self.image_path) if folder != '.DS_Store')
        self.transform = transform
        self.images, self.labels = self.get_images()

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.images)

    def __getitem__(self, index):
        'Generates one sample of data'
        sample = (self.images[index], self.labels[index])
        if self.transform:
            # Transforms (e.g. transforms.Compose) take the sample as ONE argument
            sample = self.transform(sample)
        return sample
    
    def get_images(self):
        """Read all scans and segmentations from disk.

        Returns
        -------
        images : np.ndarray
            One entry per case, each holding the four scans in the order
            flair, t1, t1ce, t2.
        labels : np.ndarray
            One segmentation volume per case.
        """
        images = []
        labels = []
        for fld_name in self.folders_name:
            case_dir = os.path.join(self.image_path, fld_name)
            image = [
                nib.load(os.path.join(case_dir, fld_name + '_' + scan_type + '.nii.gz')).get_fdata()
                for scan_type in ['flair', 't1', 't1ce', 't2']
            ]
            label = nib.load(os.path.join(case_dir, fld_name + '_seg.nii.gz')).get_fdata()

            images.append(image)
            labels.append(label)

        # NOTE(review): the uint8 cast cannot represent intensities outside
        # 0-255 -- confirm the scans fit that range or widen the dtype.
        images = np.array(images, dtype=np.uint8)
        # Build the label array from ALL cases, not only the last one loaded
        labels = np.array(labels, dtype=np.uint8)

        return images, labels

In [6]:
# Crop window as [[row_start, row_stop], [col_start, col_stop]]
crop_ind = [[35, 215], [10, 230]]
# Extent of the window along each of the two cropped axes
crop_len = [stop - start for start, stop in crop_ind]
# Depth of each scan volume (the axis that is not cropped)
scan_depth = 155

In [7]:
# Pre-processing pipeline applied to every (image, label) sample.
# The crop window comes from `crop_ind` so that it cannot drift away from
# `crop_len`, which sizes the model output.
dataset = BraTSDataset(
    image_path = r'./BraTS/BraTS2021_Training_Data',
    transform=transforms.Compose([
        Crop(crop_ind),
        Flatten(),
        ScanNormalize(),
        StackScans(),
        BinaryLabel(),
        ToTensor()
    ]))

In [8]:
batch_size = 1

# Fractions of the cases used for training, validation and testing
train_val_test_split = [0.7, 0.2, 0.1]

# Dedicated generator so the split does not depend on the global RNG state
generator = torch.Generator()
generator.manual_seed(SEED)

dataset_size = len(dataset)
dataset_indices = [*range(dataset_size)]

# Each "sampler" is a subset of the index list
train_sampler, val_sampler, test_sampler = random_split(
    dataset_indices,
    train_val_test_split,
    generator=generator,
)



In [10]:
# Training indices are re-shuffled every epoch; passing the subset directly
# would replay the cases in the same fixed order each time.
train_loader = DataLoader(dataset, batch_size=batch_size,
                            sampler=data.SubsetRandomSampler(list(train_sampler), generator=generator))
# Validation and test keep a fixed order so their metrics are comparable across runs
validation_loader = DataLoader(dataset, batch_size=batch_size,
                            sampler=val_sampler)
test_loader = DataLoader(dataset, batch_size=batch_size, sampler = test_sampler)

# Model Input

Input: four 3D scans, i.e. a tensor of size 4 × 155 × 180 × 220. Each scan is flattened to 2D, giving 4 × 180 × (155 · 220), i.e. 4 × `crop_len[0]` × (`scan_depth` · `crop_len[1]`).

Output: one volume of size 155 × 180 × 220, i.e. `scan_depth` × `crop_len[0]` × `crop_len[1]`.

In [11]:
class FirstCNN(nn.Module):
  """Three convolution blocks followed by a three-layer fully connected head.

  Parameters
  ----------
  output_dim : int
    Number of values produced per sample.
  hidden_dims : tuple of two ints
    Widths of the two hidden linear layers. The defaults are the widths that
    were previously hard-coded; with them the second linear layer alone holds
    383625 x 1534500 weights, so pass smaller values on ordinary hardware.

  Input
  -----
  x: torch.Tensor [BATCH_SIZE, 4, H, W]

  Output
  ------
  torch.Tensor [BATCH_SIZE, output_dim]
  """
  def __init__(self, output_dim, hidden_dims=(383625, 1534500)):
    super().__init__()

    self.features = nn.Sequential(
      nn.Conv2d(in_channels=4, out_channels=16, kernel_size=5),
      nn.MaxPool2d(kernel_size=2),
      nn.ReLU(),
      nn.Conv2d(in_channels=16, out_channels=64, kernel_size=5),
      nn.MaxPool2d(kernel_size=2),
      nn.ReLU(),
      nn.Conv2d(in_channels=64, out_channels=256, kernel_size=5),
      nn.MaxPool2d(kernel_size=2),
      nn.ReLU(),
      # Pin the feature map to 5 x 25 so that the flattened size is always
      # 256 * 5 * 5 * 5, whatever H and W are. Without it the size depends on
      # the input and does not match the first linear layer.
      # NOTE(review): some MPS builds may require the input size to be a
      # multiple of the output size for adaptive pooling -- confirm on device.
      nn.AdaptiveAvgPool2d((5, 25))
    )

    first_hidden, second_hidden = hidden_dims

    self.linear = nn.Sequential(
      nn.Linear(256 * 5 * 5 * 5, first_hidden),
      nn.ReLU(),
      nn.Linear(first_hidden, second_hidden),
      nn.ReLU(),
      # int() because callers may pass a NumPy integer (e.g. from np.prod)
      nn.Linear(second_hidden, int(output_dim))
    )
    
  def forward(self, x):
    x = self.features(x)
    # Flatten everything except the batch dimension
    x = x.view(x.shape[0], -1)
    x = self.linear(x)
    return x

In [12]:
# One output value per voxel of the cropped volume
OUTPUT_DIM = np.prod(crop_len) * scan_depth
model = FirstCNN(output_dim=OUTPUT_DIM)

: 

In [None]:
def count_parameters(model):
  """Return the total number of elements in the parameters that require gradients."""
  total = 0
  for param in model.parameters():
    if param.requires_grad:
      total += param.numel()
  return total

# Report the number of trainable parameters of the instantiated model
print(f"The model has {count_parameters(model):,} trainable parameters.")

6138000

In [None]:
# Loss: CrossEntropyLoss combines softmax and cross-entropy, so the model
# outputs raw scores. The criterion is created and moved to the device in one step.
# NOTE(review): the labels are per-voxel 0/1 values after BinaryLabel --
# confirm that a cross-entropy over the voxel axis is the intended objective.
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
# Move the model to the target device, then build the Adam optimizer over its parameters
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
def calculate_accuracy(y_pred, y):
  '''
  Fraction of samples whose highest-scoring class equals the true label.

  Input
  ------
  y_pred: torch.Tensor [BATCH_SIZE, N_LABELS]
    Raw scores, one column per label.
  y: torch.Tensor [BATCH_SIZE]
    Ground-truth labels.

  Output
  ------
  acc: torch.Tensor (scalar, float)
    Number of correct predictions divided by y.shape[0].
  '''
  probabilities = F.softmax(y_pred, dim = -1)
  # Index of the most probable label, kept as a column vector
  top_class = probabilities.argmax(dim=1, keepdim = True)
  n_correct = (top_class == y.view_as(top_class)).sum()
  return n_correct.float() / y.shape[0]

In [None]:
def train(model, iterator, optimizer, criterion, device):
  '''
  Run one training epoch.

  Input
  ------
  model: network to optimise (switched to train mode)
  iterator: yields (x, y) batches and supports len()
  optimizer: optimizer stepping the model parameters
  criterion: loss called as criterion(y_pred, y)
  device: device each batch is moved to

  Output
  ------
  (loss, acc): floats
    Loss and accuracy averaged over the number of batches.
  '''
  model.train()

  running_loss, running_acc = 0, 0

  for inputs, targets in iterator:
    inputs, targets = inputs.to(device), targets.to(device)

    # Clear gradients left over from the previous batch
    optimizer.zero_grad()

    # Forward pass, loss and accuracy
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    acc = calculate_accuracy(outputs, targets)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    # Accumulate plain Python numbers
    running_loss += loss.item()
    running_acc += acc.item()

  n_batches = len(iterator)
  return running_loss/n_batches, running_acc/n_batches

In [None]:
def evaluate(model, iterator, criterion, device):
  '''
  Compute loss and accuracy over an iterator without updating the model.

  Input
  ------
  model: network to evaluate (switched to eval mode)
  iterator: yields (x, y) batches and supports len()
  criterion: loss called as criterion(y_pred, y)
  device: device each batch is moved to

  Output
  ------
  (loss, acc): floats
    Loss and accuracy averaged over the number of batches.
  '''
  model.eval()

  running_loss, running_acc = 0, 0

  # Gradients are not needed for evaluation
  with torch.no_grad():
    for inputs, targets in iterator:
      inputs, targets = inputs.to(device), targets.to(device)

      outputs = model(inputs)
      loss = criterion(outputs, targets)
      acc = calculate_accuracy(outputs, targets)

      running_loss += loss.item()
      running_acc += acc.item()

  n_batches = len(iterator)
  return running_loss/n_batches, running_acc/n_batches

In [None]:
def model_training(n_epochs, model, train_iterator, valid_iterator, optimizer, criterion, device, model_name='best_model.pt'):
  '''
  Train for n_epochs, validating after each epoch and saving the weights
  whenever the validation loss reaches a new minimum.

  Output
  ------
  train_losses, train_accs, valid_losses, valid_accs: lists
    One entry per epoch.
  '''
  best_valid_loss = float('inf')

  # Per-epoch history, in the order in which it is returned
  train_losses, train_accs, valid_losses, valid_accs = [], [], [], []
  history = (train_losses, train_accs, valid_losses, valid_accs)

  for epoch in range(n_epochs):
    start_time = time.time()

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)

    # Checkpoint only when validation improves
    if valid_loss < best_valid_loss:
      best_valid_loss = valid_loss
      torch.save(model.state_dict(), model_name)

    end_time = time.time()

    print(f"\nEpoch: {epoch+1}/{n_epochs} -- Epoch Time: {end_time-start_time:.2f} s")
    print("---------------------------------")
    print(f"Train -- Loss: {train_loss:.3f}, Acc: {train_acc * 100:.2f}%")
    print(f"Val -- Loss: {valid_loss:.3f}, Acc: {valid_acc * 100:.2f}%")

    for series, value in zip(history, (train_loss, train_acc, valid_loss, valid_acc)):
      series.append(value)

  return train_losses, train_accs, valid_losses, valid_accs

In [None]:
N_EPOCHS = 30
# The loaders built earlier are named train_loader / validation_loader;
# train_iterator / valid_iterator do not exist in this notebook.
train_losses, train_accs, valid_losses, valid_accs = model_training(N_EPOCHS,
                                                                    model,
                                                                    train_loader,
                                                                    validation_loader,
                                                                    optimizer,
                                                                    criterion,
                                                                    device,
                                                                    'lenet.pt')

In [None]:
def plot_results(n_epochs, train_losses, train_accs, valid_losses, valid_accs):
  '''
  Plot training and validation curves side by side: loss on the left,
  accuracy on the right.

  Input
  ------
  n_epochs: int
    Number of epochs; every series must hold this many values.
  train_losses, train_accs, valid_losses, valid_accs: sequences
    One value per epoch.
  '''
  epochs = np.arange(n_epochs) + 1
  panels = (
    ('Loss', train_losses, valid_losses),
    ('Accuracy', train_accs, valid_accs),
  )

  plt.figure(figsize=(20, 6))
  for position, (ylabel, train_values, valid_values) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(epochs, train_values, linewidth=3)
    plt.plot(epochs, valid_values, linewidth=3)
    plt.legend(['Train', 'Validation'])
    # grid() takes a boolean flag
    plt.grid(True)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)

In [None]:
def model_testing(model, test_iterator, criterion, device, model_name='best_model.pt'):
  '''
  Load saved weights into the model and print its loss and accuracy on the test set.

  Input
  ------
  model: network whose architecture matches the saved state dict
  test_iterator: yields (x, y) batches and supports len()
  criterion: loss called as criterion(y_pred, y)
  device: device used for loading and evaluation
  model_name: str
    Path of the file written by torch.save.
  '''
  # map_location lets a checkpoint saved on one device (e.g. CUDA) be loaded
  # on another (e.g. CPU or MPS)
  model.load_state_dict(torch.load(model_name, map_location=device))
  test_loss, test_acc = evaluate(model, test_iterator, criterion, device)
  print(f"Test -- Loss: {test_loss:.3f}, Acc: {test_acc * 100:.2f} %")

In [None]:
# Evaluate the checkpoint written during training on the held-out test loader
model_testing(model, test_loader, criterion, device, model_name='lenet.pt')

In [None]:
def predict(model, iterator, device):
  '''
  Collect the ground-truth labels and the top predicted label of every batch.

  Input
  ------
  model: network to run (switched to eval mode)
  iterator: yields (x, y) batches
  device: device each input batch is moved to

  Output
  ------
  labels: torch.Tensor
    All y batches concatenated along dim 0, on the CPU.
  pred: torch.Tensor
    Index of the highest-scoring label per sample, concatenated along dim 0,
    on the CPU.
  '''
  model.eval()

  all_labels, all_preds = [], []

  with torch.no_grad():
    for inputs, targets in iterator:
      scores = model(inputs.to(device))

      # Label with the highest probability, kept as a column
      probabilities = F.softmax(scores, dim = -1)
      best = probabilities.argmax(1, keepdim=True)

      all_labels.append(targets.cpu())
      all_preds.append(best.cpu())

  return torch.cat(all_labels, dim=0), torch.cat(all_preds, dim=0)


In [None]:
def print_report(model, test_iterator, device):
  '''
  Print the confusion matrix and the classification report of the model's
  predictions on test_iterator.
  '''
  y_true, y_hat = predict(model, test_iterator, device)

  matrix = confusion_matrix(y_true, y_hat)
  print(matrix)
  print("\n")

  report = classification_report(y_true, y_hat)
  print(report)

In [None]:
# Summarise the per-class performance on the test loader
print_report(model=model, test_iterator=test_loader, device=device)