### Imports

In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader

from glob import glob as glob
import nibabel as nib
import numpy as np
import datetime
import random

import monai
from monai.networks.nets import UNet
from monai.transforms import Compose
from monai.visualize import plot_2d_or_3d_image
from monai.metrics import CumulativeIterationMetric
from monai.losses import DiceLoss
from monai.metrics.meandice import DiceMetric

from utilities import save_checkpoint


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
my_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
my_device

device(type='cuda', index=0)

# Dataset

In [3]:
local_dir = 'D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset'

In [4]:
train_volumes_path = sorted(glob(os.path.join(local_dir, 'TrainVolumes', '*.nii.gz')))
train_segmentations_path = sorted(glob(os.path.join(local_dir, 'TrainSegmentation', '*.nii.gz')))

In [5]:
list(zip(train_volumes_path,train_segmentations_path))[:2]

[('D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TrainVolumes\\BRATS_001.nii.gz',
  'D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TrainSegmentation\\BRATS_001.nii.gz'),
 ('D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TrainVolumes\\BRATS_002.nii.gz',
  'D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TrainSegmentation\\BRATS_002.nii.gz')]

In [6]:
val_volumes_path = sorted(glob(os.path.join(local_dir, 'TestVolumes', '*.nii.gz')))
val_segmentations_path = sorted(glob(os.path.join(local_dir, 'TestSegmentation', '*.nii.gz')))


In [7]:
list(zip(val_volumes_path,val_segmentations_path))[:4]

[('D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TestVolumes\\BRATS_011.nii.gz',
  'D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TestSegmentation\\BRATS_011.nii.gz'),
 ('D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TestVolumes\\BRATS_012.nii.gz',
  'D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TestSegmentation\\BRATS_012.nii.gz'),
 ('D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TestVolumes\\BRATS_013.nii.gz',
  'D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TestSegmentation\\BRATS_013.nii.gz'),
 ('D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TestVolumes\\BRATS_014.nii.gz',
  'D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset\\TestSegmentation\\BRATS_014.nii.gz')]

In [8]:
class permute_and_add_axis_to_mask(object):
    def __call__(self, sample):
        # Previous: (240, 240, 155, 4) , need to change to (4, 155, 240, 240) i.e. (channel, depth, height, width)
        image, mask = sample['image'], sample['mask']

        image = image.transpose((3, 2, 0, 1))
        mask = mask.transpose((2, 0, 1))

        mask= mask[np.newaxis, ...]
        return {'image':image,
                'mask':mask}

In [9]:
class BratsDataset(Dataset):
    def __init__(self, images_path_list, masks_path_list, transform=None):
        """
        Args:
            images_path_list (list of strings): List of paths to input images.
            masks_path_list (list of strings): List of paths to masks.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.images_path_list = images_path_list
        self.masks_path_list = masks_path_list
        self.transform = transform
        self.length = len(images_path_list)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Load image
        image_path = self.images_path_list[idx]
        image = nib.load(image_path).get_fdata()
        image = np.float32(image) # shape of image [240, 240, 155, 4]

        # Load mask
        mask_path = self.masks_path_list[idx]
        mask = nib.load(mask_path).get_fdata()
        mask = np.float32(mask) # shape of mask [240, 240, 155]

        if self.transform:
            transformed_sample = self.transform({'image': image, 'mask': mask})
        
        return transformed_sample


In [10]:
class spatialpad(object): # First dimension should be left untouched of [C, D, H, W]
    def __init__(self, image_target_size=[4, 256, 256, 256], mask_target_size=[1, 256, 256, 256]):
        self.image_target_size = image_target_size
        self.mask_target_size = mask_target_size

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask'] # image: [4, 155, 240, 240], mask[1, 155, 240, 240]

        padded_image = self.pad_input(image, self.image_target_size)

        padded_mask = self.pad_input(mask, self.mask_target_size)

        return {'image': padded_image,
                'mask': padded_mask}
    

    def pad_input(self, input_array, target_size):
        # Ensure the input array is a numpy array
        if not isinstance(input_array, np.ndarray):
            input_array = np.array(input_array)

        # Calculate padding sizes for each dimension
        pad_width = []
        for i in range(len(input_array.shape)):
            total_padding = target_size[i] - input_array.shape[i]
            if total_padding < 0:
                raise ValueError(f"Target shape must be larger than the input shape. Dimension {i} is too small.")
            pad_before = total_padding // 2
            pad_after = total_padding - pad_before
            pad_width.append((pad_before, pad_after))

        # Pad the image
        padded_image = np.pad(input_array, pad_width, mode='constant', constant_values=0)

        return padded_image   

In [11]:
data_transform = Compose([ # input image of shape [240, 240, 155, 4]
    permute_and_add_axis_to_mask(), # image: [4, 155, 240, 240], mask[1, 155, 240, 240] # new channel in the first dimension is added in mask inorder to make compatible with Resize() as Resize takes only 4D tensor
    spatialpad(image_target_size=[4, 256, 256, 256], mask_target_size=[1, 256, 256, 256]),
])

In [12]:
train_ds = BratsDataset(
    train_volumes_path,
    train_segmentations_path,
    transform=data_transform
)

val_ds = BratsDataset(
    val_volumes_path,
    val_segmentations_path,
    transform=data_transform
)

In [13]:
sample_train = train_ds[0]
sample_train['image'].shape, sample_train['mask'].shape # previously numpy array of (240, 240, 155, 4), Now changed to: (4, 155, 240, 240) with first transform, then changed to (4, 256, 256, 256) by second transform

((4, 256, 256, 256), (1, 256, 256, 256))

# DataLoader

In [14]:
# Create dataloader
train_loader = DataLoader(dataset=train_ds,
                          batch_size=1,
                          shuffle=True,
                          drop_last=True)
val_loader = DataLoader(dataset=val_ds,
                        batch_size=1,
                        shuffle=False,
                        drop_last=True)

# Model

In [15]:
# Instantiate a U-Net model
model = UNet(
    spatial_dims=3,        # 3 for using 3D ConvNet and 3D Maxpooling
    in_channels=4,         # since 4 modalities
    out_channels=4,        # 4 sub-regions to segment
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(my_device)
print(model)

UNet(
  (model): Sequential(
    (0): Convolution(
      (conv): Conv3d(4, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (adn): ADN(
        (N): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (D): Dropout(p=0.0, inplace=False)
        (A): PReLU(num_parameters=1)
      )
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): Convolution(
          (conv): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (1): SkipConnection(
          (submodule): Sequential(
            (0): Convolution(
              (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (adn): ADN(
                (N): Inst

# Training

In [16]:
def train_step(model,
               dataloader,
               loss_fn,
               dice_score,
               optimizer):
    # Putting the model in train mode
    model.train()

    # Initialize train_loss list
    train_loss = [] # epoch wise

    # Loop through batches of data
    for batch_num, batch_data in enumerate(dataloader):
        X = batch_data['image'] # torch.Size([batch, 4, 128, 240, 240])
        Y = batch_data['mask'] # torch.Size([batch, 1, 128, 240, 240]) (batch, 1, 128, 240, 240) (multi-class i.e. a pixel_value ~ {0, 1, 2, 3})

        # Send data to target device
        X, Y = X.to(my_device), Y.to(my_device)

        optimizer.zero_grad() # Clear previous gradients

        # Forward pass
        y_pred = model(X) # y_pred shape torch.Size([batch, 4, 128, 240, 240]) # produces raw logits. 4 is due to 4 sub regions, not 4 modalities.
        
        # Compute and accumulate loss
        loss = loss_fn(y_pred, Y) # loss one-hot encodes the y so y will be [batch, 4, 128, 240, 240] and y_pred is [batch, 4, 128, 240, 240], loss is scalar (may be averages across modalities and batch as well)

        # Backpropagation and Optimization
        loss.backward() # Compute gradients
        optimizer.step() # Update weights
        tr_loss = loss.item()

        # Accumulate train_loss for log
        train_loss.append(tr_loss)

        with torch.no_grad():
            # Calculate and accumulate metric across the batch
            predicted_class_labels = torch.argmax(y_pred, dim=1, keepdim=True) # After argmax with keepdim=True: [batch, 1, D, H, W] {0, 1, 2, 3}, since it takes argmax along the channels(or, the #classes)

        print(f'Iteration: {batch_num + 1} ---|---  Loss {tr_loss:.3f}')

    return train_loss

In [17]:
def val_step(model,
              dataloader,
              dice_score:CumulativeIterationMetric):
    
    # Putting model in eval mode
    model.eval()

    # Initialize validation dice score
    val_dice_score = 0

    dice_score.reset()

    # Turn on inference context manager
    with torch.inference_mode(): # Disable gradient computation 

        # Loop through batches of data in dataloader
        for batch_num, batch_data in enumerate(dataloader):
            X = batch_data['image'] # [batch, 4, D, H, W]
            Y = batch_data['mask'] # [B, 1, D, H, W]

            # Send data to target device
            X, Y = X.to(my_device), Y.to(my_device)

            # Forward pass
            test_pred_logits = model(X) # [B, 4, 128, 240, 240]

            # Calculate and accumulate metric across the batch
            predicted_class_labels = torch.argmax(test_pred_logits, dim=1, keepdim=True) # test_pred_logits of shape [batch, 4, D, H, W] {raw logits}, after argmax [batch, D, H, W] {0, 1, 2, 3}, since it takes argmax along the channels(or, the #classes)
            batch_dice_score = dice_score(predicted_class_labels, Y)
            
            # print(f"DSC (batch wise): {batch_dice_score}")

    # Aggregate dice score (epoch wise)
    val_dice_score = dice_score.aggregate().item()

    return val_dice_score

In [18]:
from tqdm.auto import tqdm

# Various parameters required for training and test step
def train(model,
          checkpoint_dir,
          train_loader,
          val_loader,
          optimizer,
          loss_fn,
          dice_score,
          epochs):
    
    # Creating empty list to hold loss and accuracy
    results = {
        'batch_train_loss':[],
        'epoch_val_dice_score':[]
    }

    # Looping through traininig and testing steps for a number of epochs
    for epoch in tqdm(range(epochs)):
        print(f'----------- Epoch: {epoch+1} ----------- \n')
        batch_train_loss_list = train_step(model=model,
                                      dataloader=train_loader,
                                      loss_fn=loss_fn,
                                      dice_score=dice_score,
                                      optimizer=optimizer)
        
        epoch_val_dice_score = val_step(model=model,
                                  dataloader=val_loader,
                                  dice_score=dice_score)

        print(f'\n'
              f'--|-- Epoch {epoch+1} Validation DS: {epoch_val_dice_score:.4f} --|--')
        
        # Save checkpoint
        # save_checkpoint(model, checkpoint_dir, optimizer, epoch, 0.0, 0.0, torch.mean(batch_train_loss_list), epoch_val_dice_score)

        # Append to the list
        results['batch_train_loss'].append(np.mean(batch_train_loss_list))
        results['epoch_val_dice_score'].append(epoch_val_dice_score)

    return results

In [19]:
# Model name
model_name = '3DUNet'

# Generate a timestamp
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# Directory to save the checkpoints
checkpoint_dir = os.path.join('checkpoints', model_name, timestamp)

# Create the directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)

print(f"Checkpoints will be saved in: {checkpoint_dir}")


Checkpoints will be saved in: checkpoints\3DUNet\2024-05-23_17-09-53


In [20]:
# Set random seeds
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
monai.utils.set_determinism(seed=random_seed)

# Set the number of epochs, loss function and optimizer
num_epochs = 2
dice_loss = DiceLoss(include_background=False, to_onehot_y=True, softmax=True)
dice_score = DiceMetric(include_background=False)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

# Start the timer
from timeit import default_timer as timer
start_time = timer()

# Train model
model_results = train(model,
                      checkpoint_dir,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      optimizer=optimizer,
                      loss_fn=dice_loss,
                      dice_score=dice_score,
                      epochs=num_epochs)

# End the timer and print out how long it took
end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

  0%|          | 0/2 [00:00<?, ?it/s]

----------- Epoch: 1 ----------- 



  return F.conv_transpose3d(


Iteration: 1 ---|---  Loss 0.993
Iteration: 2 ---|---  Loss 0.995
Iteration: 3 ---|---  Loss 0.999
Iteration: 4 ---|---  Loss 0.996
Iteration: 5 ---|---  Loss 0.995
Iteration: 6 ---|---  Loss 0.987
Iteration: 7 ---|---  Loss 0.995
Iteration: 8 ---|---  Loss 0.993
Iteration: 9 ---|---  Loss 0.982
Iteration: 10 ---|---  Loss 0.983


 50%|█████     | 1/2 [00:43<00:43, 43.47s/it]


--|-- Epoch 1 Validation DS: 0.0664 --|--
----------- Epoch: 2 ----------- 

Iteration: 1 ---|---  Loss 0.995
Iteration: 2 ---|---  Loss 0.998
Iteration: 3 ---|---  Loss 0.993
Iteration: 4 ---|---  Loss 0.978
Iteration: 5 ---|---  Loss 0.991
Iteration: 6 ---|---  Loss 0.978
Iteration: 7 ---|---  Loss 0.994
Iteration: 8 ---|---  Loss 0.978
Iteration: 9 ---|---  Loss 0.991
Iteration: 10 ---|---  Loss 0.982


100%|██████████| 2/2 [01:25<00:00, 42.73s/it]


--|-- Epoch 2 Validation DS: 0.0695 --|--
Total training time: 85.464 seconds



