## Imports

In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader

from glob import glob as glob
import nibabel as nib
import numpy as np
import datetime
import random

import monai
from monai.networks.nets import UNet
from monai.transforms import Compose
from monai.visualize import plot_2d_or_3d_image
from monai.metrics import CumulativeIterationMetric
from monai.losses import DiceLoss
from monai.metrics.meandice import DiceMetric

from utilities import save_checkpoint
from utilities import permute_and_add_axis_to_mask, spatialpad

import csv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
my_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
my_device

device(type='cuda', index=0)

## Dataset

In [3]:
base_data_dir = "D:\\Neuroscience and Neuroimaging\\CAP5516 Medical Image Computing\\MSD\\brats_subset"

In [4]:
def get_volumes_path(base_data_dir):
    train_volumes_path = sorted(glob(os.path.join(base_data_dir, 'TrainVolumes', '*.nii.gz')))
    train_segmentations_path = sorted(glob(os.path.join(base_data_dir, 'TrainSegmentation', '*.nii.gz')))

    val_volumes_path = sorted(glob(os.path.join(base_data_dir, 'TestVolumes', '*.nii.gz')))
    val_segmentations_path = sorted(glob(os.path.join(base_data_dir, 'TestSegmentation', '*.nii.gz')))

    return train_volumes_path, train_segmentations_path, val_volumes_path, val_segmentations_path

In [5]:
class BratsDataset(Dataset):
    def __init__(self, images_path_list, masks_path_list, transform=None):
        """
        Args:
            images_path_list (list of strings): List of paths to input images.
            masks_path_list (list of strings): List of paths to masks.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.images_path_list = images_path_list
        self.masks_path_list = masks_path_list
        self.transform = transform
        self.length = len(images_path_list)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Load image
        image_path = self.images_path_list[idx]
        image = nib.load(image_path).get_fdata()
        image = np.float32(image) # shape of image [240, 240, 155, 4]

        # Load mask
        mask_path = self.masks_path_list[idx]
        mask = nib.load(mask_path).get_fdata()
        mask = np.float32(mask) # shape of mask [240, 240, 155]

        if self.transform:
            transformed_sample = self.transform({'image': image, 'mask': mask})
        
        return transformed_sample


In [6]:
train_volumes_path, train_segmentations_path, val_volumes_path, val_segmentations_path = get_volumes_path(base_data_dir)

## Data Transform

In [7]:
data_transform = Compose([ # input image of shape [240, 240, 155, 4]
    permute_and_add_axis_to_mask(), # image: [4, 155, 240, 240], mask[1, 155, 240, 240] # new channel in the first dimension is added in mask inorder to make compatible with Resize() as Resize takes only 4D tensor
    spatialpad(image_target_size=[4, 256, 256, 256], mask_target_size=[1, 256, 256, 256]),
])

In [8]:
train_ds = BratsDataset(
    train_volumes_path,
    train_segmentations_path,
    transform=data_transform
)

val_ds = BratsDataset(
    val_volumes_path,
    val_segmentations_path,
    transform=data_transform
)

## DataLoader

In [9]:
# Create dataloaders
train_loader = DataLoader(dataset=train_ds,
                          batch_size=1,
                          shuffle=True,
                          drop_last=True)
val_loader = DataLoader(dataset=val_ds,
                        batch_size=1,
                        shuffle=False,
                        drop_last=True)

## Model

In [10]:
# Instantiate a U-Net model
model = UNet(
    spatial_dims=3,        # 3 for using 3D ConvNet and 3D Maxpooling
    in_channels=4,         # since 4 modalities
    out_channels=4,        # 4 sub-regions to segment
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(my_device)
print(model)

UNet(
  (model): Sequential(
    (0): Convolution(
      (conv): Conv3d(4, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (adn): ADN(
        (N): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (D): Dropout(p=0.0, inplace=False)
        (A): PReLU(num_parameters=1)
      )
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): Convolution(
          (conv): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (1): SkipConnection(
          (submodule): Sequential(
            (0): Convolution(
              (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
              (adn): ADN(
                (N): Inst

# Training

In [11]:
def train_step(model,
               dataloader,
               loss_fn,
               dice_score,
               optimizer,
               batch_train_loss_csv,
               epoch_num):
               
    # Putting the model in train mode
    model.train()

    # Initialize train_loss list
    train_loss = []

    # Loop through batches of data
    for batch_num, batch_data in enumerate(dataloader):
        X = batch_data['image'] # torch.Size([batch, 4, 128, 240, 240])
        Y = batch_data['mask'] # torch.Size([batch, 1, 128, 240, 240]) (batch, 1, 128, 240, 240) (multi-class i.e. a pixel_value ~ {0, 1, 2, 3})

        # Send data to target device
        X, Y = X.to(my_device), Y.to(my_device)

        optimizer.zero_grad() # Clear previous gradients

        # Forward pass
        y_pred = model(X) # y_pred shape torch.Size([batch, 4, 128, 240, 240]) # produces raw logits. 4 is due to 4 sub regions, not 4 modalities.
        
        # Compute and accumulate loss
        loss = loss_fn(y_pred, Y) # loss one-hot encodes the y so y will be [batch, 4, 128, 240, 240] and y_pred is [batch, 4, 128, 240, 240], loss is scalar (may be averages across modalities and batch as well)

        # Backpropagation and Optimization
        loss.backward() # Compute gradients
        optimizer.step() # Update weights
        tr_loss = loss.item()

        # Accumulate train_loss for log
        train_loss.append(tr_loss)

        with torch.no_grad():
            # Calculate and accumulate metric across the batch
            predicted_class_labels = torch.argmax(y_pred, dim=1, keepdim=True) # After argmax with keepdim=True: [batch, 1, D, H, W] {0, 1, 2, 3}, since it takes argmax along the channels(or, the #classes)

        # Write batch wise train loss to csv file
        with open(batch_train_loss_csv, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch_num+1, batch_num+1, tr_loss])

        print(f'Iteration: {batch_num + 1} ---|---  Loss {tr_loss:.3f}')

    # Average the dice loss (i.e. average of batches ~ 1 epoch)
    train_loss_average = np.mean(train_loss) # for checkpoint

    return  train_loss_average

In [12]:
def val_step(model,
              dataloader,
              dice_score:CumulativeIterationMetric):

    # Putting model in eval mode
    model.eval()

    # Initialize validation dice score
    val_dice_score = 0

    dice_score.reset()

    # Turn on inference context manager
    with torch.inference_mode(): # Disable gradient computation 

        # Loop through batches of data in dataloader
        for batch_num, batch_data in enumerate(dataloader):
            X = batch_data['image'] # [batch, 4, D, H, W]
            Y = batch_data['mask'] # [B, 1, D, H, W]

            # Send data to target device
            X, Y = X.to(my_device), Y.to(my_device)

            # Forward pass
            test_pred_logits = model(X) # [B, 4, 128, 240, 240]

            # Calculate and accumulate metric across the batch
            predicted_class_labels = torch.argmax(test_pred_logits, dim=1, keepdim=True) # test_pred_logits of shape [batch, 4, D, H, W] {raw logits}, after argmax [batch, D, H, W] {0, 1, 2, 3}, since it takes argmax along the channels(or, the #classes)
            batch_dice_score = dice_score(predicted_class_labels, Y)
            
            # print(f"DSC (batch wise): {batch_dice_score}")

    # Aggregate dice score (epoch wise)
    val_dice_score = dice_score.aggregate().item()

    return val_dice_score

In [13]:
from tqdm.auto import tqdm

# Various parameters required for training and test step
def train(model,
          checkpoint_dir,
          batch_train_loss_csv,
          epoch_val_dsc_csv,
          train_loader,
          val_loader,
          optimizer,
          loss_fn,
          dice_score,
          epochs):
    
    # # Creating empty list to hold loss and dice_score
    # results = {
    #     'train_loss':[], # This store epoch wise dice loss
    #     'val_dice_score':[] # This stores epoch wise validation dice score
    # }

    # Looping through traininig and testing steps for a number of epochs
    for epoch in tqdm(range(epochs)):
        print(f'----------- Epoch: {epoch+1} ----------- \n')

        train_loss = train_step(model=model,
                                      dataloader=train_loader,
                                      loss_fn=loss_fn,
                                      dice_score=dice_score,
                                      optimizer=optimizer,
                                      batch_train_loss_csv=batch_train_loss_csv,
                                      epoch_num=epoch)
        
        val_dice_score = val_step(model=model,
                                  dataloader=val_loader,
                                  dice_score=dice_score)

        # Write val_dice_score to csv file
        with open(epoch_val_dsc_csv, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, val_dice_score])

        print(f'--|-- Epoch {epoch+1} Validation DS: {val_dice_score:.4f} --|--')

        # Save checkpoint
        save_checkpoint(model, checkpoint_dir, optimizer, epoch+1, np.mean(train_loss), val_dice_score)

        # # Append to the list
        # results['train_loss'].append(np.mean(train_loss))
        # results['val_dice_score'].append(val_dice_score)

    # return results

In [14]:
# Model name
model_name = '3DUNet'

# Generate a timestamp
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') # -%S

# path directory to save the checkpoints and logs
log_dir = os.path.join('logs_and_checkpoints', 'logs', timestamp)
checkpoint_dir = os.path.join('logs_and_checkpoints', 'checkpoints', model_name, timestamp)

# Create the directory if it doesn't exist
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

print(f"Checkpoints will be saved in: {log_dir}")
print(f"Logs will be saved in: {checkpoint_dir}")

Checkpoints will be saved in: logs_and_checkpoints\logs\2024-05-23_20-48
Logs will be saved in: logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48


In [15]:
# Set random seeds
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
monai.utils.set_determinism(seed=random_seed)

# Set the number of epochs, loss function and optimizer
num_epochs = 10
dice_loss = DiceLoss(include_background=False, to_onehot_y=True, softmax=True)
dice_score = DiceMetric(include_background=False)

optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)


# Create files for logs
# batch_train_loss_csv to save batch wise loss
# epoch_val_dsc_csv to save epoch wise loss
batch_train_loss_csv = os.path.join(log_dir, f'batch_train_loss_{log_dir[-16:]}.csv')
epoch_val_dsc_csv = os.path.join(log_dir, f'epoch_val_dsc_{log_dir[-16:]}.csv')

with open(batch_train_loss_csv, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Epoch', 'Iteration', 'Train Loss'])

with open(epoch_val_dsc_csv, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Epoch', 'Val DSC'])


# Start the timer
from timeit import default_timer as timer
start_time = timer()

# Train model
model_results = train(model,
                      checkpoint_dir,
                      batch_train_loss_csv,
                      epoch_val_dsc_csv,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      optimizer=optimizer,
                      loss_fn=dice_loss,
                      dice_score=dice_score,
                      epochs=num_epochs)

# End the timer and print out how long it took
end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

  0%|          | 0/10 [00:00<?, ?it/s]

----------- Epoch: 1 ----------- 



  return F.conv_transpose3d(


Iteration: 1 ---|---  Loss 0.992
Iteration: 2 ---|---  Loss 0.994
Iteration: 3 ---|---  Loss 0.998
Iteration: 4 ---|---  Loss 0.994
Iteration: 5 ---|---  Loss 0.995
Iteration: 6 ---|---  Loss 0.985
Iteration: 7 ---|---  Loss 0.995
Iteration: 8 ---|---  Loss 0.993
Iteration: 9 ---|---  Loss 0.979
Iteration: 10 ---|---  Loss 0.980


 10%|█         | 1/10 [00:42<06:26, 43.00s/it]

--|-- Epoch 1 Validation DS: 0.0776 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_1_checkpoint_2024-05-23_20-48-47.pth
----------- Epoch: 2 ----------- 

Iteration: 1 ---|---  Loss 0.994
Iteration: 2 ---|---  Loss 0.997
Iteration: 3 ---|---  Loss 0.993
Iteration: 4 ---|---  Loss 0.976
Iteration: 5 ---|---  Loss 0.990
Iteration: 6 ---|---  Loss 0.975
Iteration: 7 ---|---  Loss 0.993
Iteration: 8 ---|---  Loss 0.977
Iteration: 9 ---|---  Loss 0.991
Iteration: 10 ---|---  Loss 0.982


 20%|██        | 2/10 [01:24<05:38, 42.27s/it]

--|-- Epoch 2 Validation DS: 0.1007 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_2_checkpoint_2024-05-23_20-49-28.pth
----------- Epoch: 3 ----------- 

Iteration: 1 ---|---  Loss 0.993
Iteration: 2 ---|---  Loss 0.992
Iteration: 3 ---|---  Loss 0.971
Iteration: 4 ---|---  Loss 0.997
Iteration: 5 ---|---  Loss 0.972
Iteration: 6 ---|---  Loss 0.988
Iteration: 7 ---|---  Loss 0.979
Iteration: 8 ---|---  Loss 0.989
Iteration: 9 ---|---  Loss 0.992
Iteration: 10 ---|---  Loss 0.972


 30%|███       | 3/10 [02:06<04:53, 41.91s/it]

--|-- Epoch 3 Validation DS: 0.0929 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_3_checkpoint_2024-05-23_20-50-10.pth
----------- Epoch: 4 ----------- 

Iteration: 1 ---|---  Loss 0.968
Iteration: 2 ---|---  Loss 0.991
Iteration: 3 ---|---  Loss 0.986
Iteration: 4 ---|---  Loss 0.991
Iteration: 5 ---|---  Loss 0.987
Iteration: 6 ---|---  Loss 0.996
Iteration: 7 ---|---  Loss 0.975
Iteration: 8 ---|---  Loss 0.990
Iteration: 9 ---|---  Loss 0.968
Iteration: 10 ---|---  Loss 0.964


 40%|████      | 4/10 [02:48<04:11, 41.95s/it]

--|-- Epoch 4 Validation DS: 0.1280 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_4_checkpoint_2024-05-23_20-50-52.pth
----------- Epoch: 5 ----------- 

Iteration: 1 ---|---  Loss 0.962
Iteration: 2 ---|---  Loss 0.972
Iteration: 3 ---|---  Loss 0.964
Iteration: 4 ---|---  Loss 0.985
Iteration: 5 ---|---  Loss 0.989
Iteration: 6 ---|---  Loss 0.989
Iteration: 7 ---|---  Loss 0.995
Iteration: 8 ---|---  Loss 0.982
Iteration: 9 ---|---  Loss 0.956
Iteration: 10 ---|---  Loss 0.989


 50%|█████     | 5/10 [03:30<03:29, 41.88s/it]

--|-- Epoch 5 Validation DS: 0.1878 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_5_checkpoint_2024-05-23_20-51-34.pth
----------- Epoch: 6 ----------- 

Iteration: 1 ---|---  Loss 0.981
Iteration: 2 ---|---  Loss 0.995
Iteration: 3 ---|---  Loss 0.986
Iteration: 4 ---|---  Loss 0.957
Iteration: 5 ---|---  Loss 0.951
Iteration: 6 ---|---  Loss 0.966
Iteration: 7 ---|---  Loss 0.986
Iteration: 8 ---|---  Loss 0.980
Iteration: 9 ---|---  Loss 0.952
Iteration: 10 ---|---  Loss 0.986


 60%|██████    | 6/10 [04:11<02:47, 41.83s/it]

--|-- Epoch 6 Validation DS: 0.2357 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_6_checkpoint_2024-05-23_20-52-15.pth
----------- Epoch: 7 ----------- 

Iteration: 1 ---|---  Loss 0.948
Iteration: 2 ---|---  Loss 0.961
Iteration: 3 ---|---  Loss 0.986
Iteration: 4 ---|---  Loss 0.983
Iteration: 5 ---|---  Loss 0.986
Iteration: 6 ---|---  Loss 0.942
Iteration: 7 ---|---  Loss 0.994
Iteration: 8 ---|---  Loss 0.975
Iteration: 9 ---|---  Loss 0.975
Iteration: 10 ---|---  Loss 0.942


 70%|███████   | 7/10 [04:53<02:05, 41.87s/it]

--|-- Epoch 7 Validation DS: 0.3107 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_7_checkpoint_2024-05-23_20-52-57.pth
----------- Epoch: 8 ----------- 

Iteration: 1 ---|---  Loss 0.983
Iteration: 2 ---|---  Loss 0.973
Iteration: 3 ---|---  Loss 0.981
Iteration: 4 ---|---  Loss 0.974
Iteration: 5 ---|---  Loss 0.982
Iteration: 6 ---|---  Loss 0.950
Iteration: 7 ---|---  Loss 0.935
Iteration: 8 ---|---  Loss 0.992
Iteration: 9 ---|---  Loss 0.925
Iteration: 10 ---|---  Loss 0.932


 80%|████████  | 8/10 [05:35<01:23, 41.83s/it]

--|-- Epoch 8 Validation DS: 0.3232 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_8_checkpoint_2024-05-23_20-53-39.pth
----------- Epoch: 9 ----------- 

Iteration: 1 ---|---  Loss 0.979
Iteration: 2 ---|---  Loss 0.980
Iteration: 3 ---|---  Loss 0.930
Iteration: 4 ---|---  Loss 0.990
Iteration: 5 ---|---  Loss 0.927
Iteration: 6 ---|---  Loss 0.965
Iteration: 7 ---|---  Loss 0.974
Iteration: 8 ---|---  Loss 0.913
Iteration: 9 ---|---  Loss 0.939
Iteration: 10 ---|---  Loss 0.963


 90%|█████████ | 9/10 [06:17<00:42, 42.03s/it]

--|-- Epoch 9 Validation DS: 0.3286 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_9_checkpoint_2024-05-23_20-54-22.pth
----------- Epoch: 10 ----------- 

Iteration: 1 ---|---  Loss 0.975
Iteration: 2 ---|---  Loss 0.919
Iteration: 3 ---|---  Loss 0.988
Iteration: 4 ---|---  Loss 0.913
Iteration: 5 ---|---  Loss 0.933
Iteration: 6 ---|---  Loss 0.957
Iteration: 7 ---|---  Loss 0.969
Iteration: 8 ---|---  Loss 0.959
Iteration: 9 ---|---  Loss 0.897
Iteration: 10 ---|---  Loss 0.973


100%|██████████| 10/10 [07:00<00:00, 42.02s/it]

--|-- Epoch 10 Validation DS: 0.3508 --|--
Checkpoint saved at logs_and_checkpoints\checkpoints\3DUNet\2024-05-23_20-48\Epoch_10_checkpoint_2024-05-23_20-55-04.pth
Total training time: 420.235 seconds



