# Content description
This is a Jupyter notebook that closely follows the MONAI tutorials provided in the [MONAI repository](https://github.com/Project-MONAI/tutorials). The notebook is specifically tuned towards segmentation of data in the [ACDC challenge](https://www.creatis.insa-lyon.fr/Challenge/acdc/) using a 2D U-Net. You will need to first download the ACDC dataset via [this link](https://acdc.creatis.insa-lyon.fr/#challenges), and unzip it in a directory of your choosing. You can define that directory here:



In [None]:
datapath = r'/home/wolterinkjm/ACDC/training'

## Loading and visualizing ACDC data

We will first show how to load and visualize ACDC images and segmentations. First, import the Python packages used in this demo.

In [None]:
!pip install monai nibabel
import torch
import monai
import numpy as np
import matplotlib.pyplot as plt
import glob

We define a [MONAI dictionary transform](https://docs.monai.io/en/latest/transforms.html#dictionary-transforms) that loads a pair of image + reference label based on filenames provided in a Python dictionary. Constructing `LoadImageD` does not yet load any data; it returns an object that will load the data once it is given filenames. Afterwards, we pass a dictionary containing the filenames of one image + reference label pair to this object, and load the image and label into `data_dict`.

In [None]:
transform = monai.transforms.LoadImageD(("image", "label")) 
file_dict = {"image": "{}/patient001/patient001_frame01.nii.gz".format(datapath), 
             "label": "{}/patient001/patient001_frame01_gt.nii.gz".format(datapath)}

data_dict = transform(file_dict)
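
The split between configuring a transform and applying it can be illustrated without MONAI. The following is a minimal sketch, not MONAI's implementation: `ScaleD` is a hypothetical dictionary transform, and the toy lists stand in for loaded image arrays. It stores the keys it should act on at construction time and only touches data when it is called.

```python
class ScaleD:
    """Hypothetical dictionary transform: multiplies the values stored under `keys`."""

    def __init__(self, keys, factor):
        # Construction only stores the configuration; no data is processed yet.
        self.keys = keys
        self.factor = factor

    def __call__(self, data):
        # Copy the dictionary so the input is left untouched.
        out = dict(data)
        for key in self.keys:
            out[key] = [self.factor * value for value in out[key]]
        return out


scale = ScaleD(keys=("image",), factor=2)          # configure once
sample = {"image": [1, 2, 3], "label": [0, 1, 1]}  # toy stand-in for loaded arrays
result = scale(sample)                             # apply to any dictionary later
print(result)
```

Only the entry under `"image"` is changed; the `"label"` entry passes through untouched, which is the same pattern used when a MONAI dictionary transform is given a subset of keys.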

Now, `data_dict` contains a 3D cine MR image and its reference segmentation mask. We use the [Matplotlib library](https://matplotlib.org/) to visualize the images in a rectangular grid. The for-loop iterates over `z`, the slice index. Each slice is shown, and the reference label is shown as a colored overlay (left ventricle in green, myocardium in blue, right ventricle in red).

In [None]:
def visualize_data(pt_dict, batch=False):
    image = pt_dict["image"].squeeze()
    label = pt_dict["label"].squeeze()
    if batch:
        image = image.permute((1, 2, 0))
        label = label.permute((1, 2, 0))
    plt.figure(figsize=(20,20))
    for z in range(image.shape[2]): 
      # plt.subplot expects integer grid dimensions, np.ceil returns a float
      n_side = int(np.ceil(np.sqrt(image.shape[2])))
      plt.subplot(n_side, n_side, 1 + z)
      plt.imshow(image[:, :, z], cmap='gray')
      plt.axis('off')
      plt.imshow(np.ma.masked_where(label[:, :, z]!=2, label[:, :, z]==2), alpha=0.6, cmap='Blues', clim=(0, 1))  
      plt.imshow(np.ma.masked_where(label[:, :, z]!=3, label[:, :, z]==3), alpha=0.6, cmap='Greens', clim=(0, 1))
      plt.imshow(np.ma.masked_where(label[:, :, z]!=1, label[:, :, z]==1), alpha=0.6, cmap='Reds', clim=(0, 1))
      plt.title('Slice {}'.format(z + 1))
    plt.show()
    
visualize_data(data_dict)
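
The grid used in `visualize_data` is the smallest square that fits all slices. A small sketch of that arithmetic follows; `grid_shape` is a hypothetical helper that is not part of the notebook, and it uses `math.ceil`, which returns an `int`, so the result can be passed to `plt.subplot` directly.

```python
import math


def grid_shape(n_slices):
    """Return (rows, cols) of the smallest square grid that fits n_slices panels."""
    side = math.ceil(math.sqrt(n_slices))
    return side, side


for n in (1, 9, 10, 16):
    print(n, grid_shape(n))
```

A volume with 10 slices, for example, needs a 4 x 4 grid, in which the last six panels stay empty.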

# Composing transforms
Multiple PyTorch/MONAI transforms can be composed into one transform to pre-process or post-process data. For example, we can load an image + reference label pair, normalize the image, and add Gaussian noise to the image. To do this, we add a `RandGaussianNoised` transform and apply it only to `image`.

In [None]:
transform = monai.transforms.Compose([monai.transforms.LoadImageD(("image", "label")), 
                                      monai.transforms.ScaleIntensityRangePercentilesd(keys=("image"), lower=5, upper=95, b_min=0, b_max=1, clip=True),
                                      monai.transforms.RandGaussianNoised(("image"), prob=1, std=.3)])

data_dict = transform(file_dict)
visualize_data(data_dict)
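
To make the normalization step more concrete, here is a NumPy sketch of roughly what percentile-based intensity scaling does with the arguments used above (`lower=5`, `upper=95`, `b_min=0`, `b_max=1`, `clip=True`). It is an illustration under those assumptions, not MONAI's implementation, and `scale_percentiles` is a hypothetical name; the random array stands in for an MR slice.

```python
import numpy as np


def scale_percentiles(image, lower=5, upper=95):
    """Map the [lower, upper] percentile range of `image` to [0, 1] and clip."""
    lo, hi = np.percentile(image, [lower, upper])
    scaled = (image - lo) / (hi - lo)
    return np.clip(scaled, 0.0, 1.0)


rng = np.random.default_rng(0)
image = rng.normal(loc=100.0, scale=20.0, size=(64, 64))  # stand-in for an MR slice
scaled = scale_percentiles(image)
print(scaled.min(), scaled.max())
```

Intensities below the 5th percentile are mapped to 0 and intensities above the 95th percentile to 1, which makes the scaling robust to a few very bright or very dark pixels.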

# Datasets and dataloaders
We can also load more than one image and store a full set of images and labels in a dataset. To do this, we make a list of filename dictionaries, one per image + reference label pair, as before. This list and the composed transform are then used to initialize a dataset. In addition to the `LoadImageD` transform, we now also add a channel dimension using `AddChannelD`, normalize the intensities using `ScaleIntensityRangePercentilesd`, randomly crop a 128 x 128 pixel subimage from a single slice using `RandSpatialCropD`, drop the resulting singleton slice dimension using `SqueezeDimd`, and finally convert the image and reference label mask into PyTorch tensors using `ToTensorD`.

This [PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) provides more information about datasets and dataloaders.

In [None]:
file_dict = []
for ptid in range(1, 101):
    gt_filenames = glob.glob(r'{}/patient{}/*_gt.nii.gz'.format(datapath, str(ptid).zfill(3)))
    file_dict.append({'image': gt_filenames[0].replace('_gt', ''), 'label': gt_filenames[0]})
    file_dict.append({'image': gt_filenames[1].replace('_gt', ''), 'label': gt_filenames[1]})
    
transform = monai.transforms.Compose([
    monai.transforms.LoadImageD(("image", "label")),
    monai.transforms.AddChannelD(("image", "label")),
    monai.transforms.ScaleIntensityRangePercentilesd(keys=("image"), lower=5, upper=95, b_min=0, b_max=1, clip=True),
    monai.transforms.RandSpatialCropD(keys=("image", "label"), roi_size=(128, 128, 1), random_center=True, random_size=False),
    monai.transforms.SqueezeDimd(keys=("image", "label"), dim=-1),
    monai.transforms.ToTensorD(("image", "label"))        
])    
    
dataset = monai.data.Dataset(data = file_dict, transform = transform)

In addition, we make a dataloader for the dataset. This is an object that allows us to efficiently sample data from a dataset.

In [None]:
dataloader = monai.data.DataLoader(dataset, batch_size=16, shuffle=False)
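
The division of labor between a dataset and a dataloader can be sketched in plain Python. `ToyDataset` and `iterate_batches` below are hypothetical stand-ins, not the MONAI or PyTorch classes: the dataset applies the transform only when an item is requested, and the loader groups items into mini-batches.

```python
class ToyDataset:
    """Hypothetical map-style dataset: applies `transform` when an item is requested."""

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.transform(self.data[index])


def iterate_batches(dataset, batch_size):
    """Yield consecutive lists of up to `batch_size` transformed items."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        yield [dataset[i] for i in range(start, stop)]


toy = ToyDataset(data=list(range(10)), transform=lambda x: x * x)
batches = list(iterate_batches(toy, batch_size=4))
print(batches)
```

A real dataloader does more than this sketch: it can shuffle the sample order, stack the items of a batch into tensors, and load samples in parallel worker processes.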

We can now inspect this dataset and dataloader. For example, we can print how many images the dataset contains. We can also draw and visualize a mini-batch of samples using the dataloader. Because the crop location is chosen at random every time a sample is loaded, running the cell below a couple of times will give different subimages each time. Note that these samples are subimages of fixed size, which is what we will use to train the U-Net.

In [None]:
print('The dataset contains {} images'.format(len(dataset)))

visualize_data(next(iter(dataloader)), batch=True)

# Defining a training and validation set
When training the U-Net for segmentation, it is important to keep your training and validation sets strictly separate. Using data from the same patient in both the training and validation set could lead to [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)). Hence, we split our dataset at the patient level: we select 80 patients for training and 20 patients for validation. Because ACDC patients are organized in groups of 20 patients with the same pathology, we perform a stratified selection in which we put every fifth patient in the validation set.

In [None]:
val_ids = set(range(1, 101, 5))
train_ids = set(range(1, 101)) - val_ids
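
A quick sanity check of this split is sketched below. It assumes that each pathology group occupies a block of 20 consecutive patient IDs (1 to 20, 21 to 40, and so on); that numbering is an assumption of this sketch.

```python
val_ids = set(range(1, 101, 5))
train_ids = set(range(1, 101)) - val_ids

# No patient may appear in both sets.
assert not (val_ids & train_ids)

# Count validation patients in each block of 20 consecutive patient IDs.
val_per_group = [
    sum(1 for ptid in val_ids if start <= ptid < start + 20)
    for start in range(1, 101, 20)
]
print(len(train_ids), len(val_ids), val_per_group)
```

Under that assumption, every group contributes 4 patients to the validation set and 16 to the training set.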

We then make two separate file dictionaries, one for the training set and one for the validation set. These will be used for separate datasets and separate dataloaders. Note that we also define a separate transform for the training set and validation set. When using data augmentation, the transform for the training set should be altered. 

In [None]:
file_dict_train = []
for ptid in train_ids:
    gt_filenames = glob.glob(r'{}/patient{}/*_gt.nii.gz'.format(datapath, str(ptid).zfill(3)))
    file_dict_train.append({'image': gt_filenames[0].replace('_gt', ''), 'label': gt_filenames[0]})
    file_dict_train.append({'image': gt_filenames[1].replace('_gt', ''), 'label': gt_filenames[1]})
    
file_dict_val = []
for ptid in val_ids:
    gt_filenames = glob.glob(r'{}/patient{}/*_gt.nii.gz'.format(datapath, str(ptid).zfill(3)))
    file_dict_val.append({'image': gt_filenames[0].replace('_gt', ''), 'label': gt_filenames[0]})
    file_dict_val.append({'image': gt_filenames[1].replace('_gt', ''), 'label': gt_filenames[1]})    
       
# This transform should be altered to add data augmentation        
transform_train = monai.transforms.Compose([
    monai.transforms.LoadImageD(("image", "label")),
    monai.transforms.AddChannelD(("image", "label")),
    monai.transforms.ScaleIntensityRangePercentilesd(keys=("image"), lower=5, upper=95, b_min=0, b_max=1, clip=True),
    monai.transforms.RandSpatialCropD(keys=("image", "label"), roi_size=(128, 128, 1), random_center=True, random_size=False),
    monai.transforms.SqueezeDimd(keys=("image", "label"), dim=-1),    
    monai.transforms.ToTensorD(("image", "label")),
])

transform_val = monai.transforms.Compose([
    monai.transforms.LoadImageD(("image", "label")),
    monai.transforms.AddChannelD(("image", "label")),
    monai.transforms.ScaleIntensityRangePercentilesd(keys=("image"), lower=5, upper=95, b_min=0, b_max=1, clip=True),
    monai.transforms.ToTensorD(("image", "label")),
])
        
dataset_train = monai.data.CacheDataset(data = file_dict_train, transform = transform_train, progress=False)
dataset_val = monai.data.Dataset(data = file_dict_val, transform = transform_val)

dataloader_train = monai.data.DataLoader(dataset_train, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
dataloader_val = monai.data.DataLoader(dataset_val, batch_size=1, shuffle=False)

print('The training set contains {} MRI scans.'.format(len(dataset_train)))
print('The validation set contains {} MRI scans.'.format(len(dataset_val)))
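
The loops that build the file lists repeat the same logic, so they could be factored into one helper. The following is a sketch of one way to do that; `build_file_list` is a hypothetical function that is not part of the notebook. It sorts the `glob` results so the frame order is deterministic, and it strips `_gt` from the file name only, so a `_gt` elsewhere in the directory path is left alone. The demonstration runs on empty dummy files in a temporary directory, with made-up frame numbers.

```python
import glob
import os
import tempfile


def build_file_list(datapath, patient_ids):
    """Return one {'image', 'label'} dictionary per annotated frame."""
    files = []
    for ptid in sorted(patient_ids):
        pattern = os.path.join(datapath, "patient{:03d}".format(ptid), "*_gt.nii.gz")
        for label_path in sorted(glob.glob(pattern)):
            folder, name = os.path.split(label_path)
            image_path = os.path.join(folder, name.replace("_gt", ""))
            files.append({"image": image_path, "label": label_path})
    return files


# Demonstration on empty dummy files (the frame numbers are made up).
with tempfile.TemporaryDirectory() as root:
    folder = os.path.join(root, "patient001")
    os.makedirs(folder, exist_ok=True)
    for frame in ("01", "12"):
        dummy = os.path.join(folder, "patient001_frame{}_gt.nii.gz".format(frame))
        open(dummy, "w").close()
    demo = build_file_list(root, [1])

print([os.path.basename(d["image"]) for d in demo])
```

With such a helper, the training and validation lists would become `build_file_list(datapath, train_ids)` and `build_file_list(datapath, val_ids)`.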

Now we can visualize a mini-batch from the training set, which contains randomly selected 2D 128 x 128 pixel images. Similarly, we can visualize all 2D slices for one 3D volume in the validation set. 

In [None]:
visualize_data(next(iter(dataloader_train)), batch=True)
visualize_data(next(iter(dataloader_val)))

# Setting up the neural network, loss function, and optimizer

<img src="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png" alt="drawing" width="600"/>

With data loading all sorted out, we set up the U-Net, a loss function, and an optimizer. If possible, use a GPU to substantially speed up computing.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Setting up a [U-Net](https://arxiv.org/abs/1505.04597) for segmentation in MONAI is as easy as instantiating the `UNet` class and providing it with the number of input channels, output channels, and feature maps/channels in the intermediate layers. The following provides us with a model whose parameters are optimized during training to perform segmentation.

In [None]:
model = monai.networks.nets.UNet(
    dimensions=2,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

The loss function should reflect what we want from the trained model. In this case, a [Dice](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) loss is used, which becomes lower as automatic and reference segmentations overlap more. This loss function first applies a [softmax](https://en.wikipedia.org/wiki/Softmax_function) function to the network outputs so that the probabilities for all classes sum to one. To make sure that the predicted probability mask `y_pred` and the reference segmentation mask `y` have the same shape, a [one-hot encoding](https://en.wikipedia.org/wiki/One-hot) is applied to the reference segmentation mask. The Dice loss is computed over the full mini-batch (`batch=True`) to avoid a poorly defined loss in individual batch samples, for example samples in which a class is absent.

In [None]:
loss_function =  monai.losses.DiceLoss(softmax=True, to_onehot_y=True, batch=True)
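
The steps described above (softmax, one-hot encoding, and a Dice term accumulated over the whole mini-batch) can be sketched in NumPy. This is an illustration, not MONAI's implementation: the function names and the smoothing constant `eps` are assumptions of this sketch, and random arrays stand in for network outputs and reference labels.

```python
import numpy as np


def softmax(logits, axis=1):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def one_hot(labels, n_classes):
    """Turn integer labels of shape (B, H, W) into one-hot masks of shape (B, C, H, W)."""
    return np.moveaxis(np.eye(n_classes)[labels], -1, 1)


def soft_dice_loss(logits, labels, eps=1e-5):
    """1 - Dice per class, summed over batch and pixels, then averaged over classes."""
    probs = softmax(logits, axis=1)
    target = one_hot(labels, logits.shape[1])
    reduce_axes = (0, 2, 3)  # batch and spatial axes, in the spirit of batch=True
    intersection = (probs * target).sum(axis=reduce_axes)
    denominator = probs.sum(axis=reduce_axes) + target.sum(axis=reduce_axes)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return float(1.0 - dice.mean())


rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=(2, 8, 8))
random_logits = rng.normal(size=(2, 4, 8, 8))
confident_logits = 20.0 * one_hot(labels, 4)  # large logit on the correct class

print(soft_dice_loss(random_logits, labels))
print(soft_dice_loss(confident_logits, labels))
```

Random predictions give a loss well above zero, while predictions that put almost all probability on the correct class drive the loss towards zero.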

# Choose an optimizer

An optimizer algorithm is chosen that performs gradient descent on the network parameters to minimize the loss function. In many cases, [Adam](https://arxiv.org/abs/1412.6980) is a good default option. The optimizer operates on the parameters of the previously defined U-Net, i.e. `model`. A learning rate `lr` is provided to the optimizer, which defines how large the changes should be that are made to the network parameters in each iteration.

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
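
To give a feeling for what the learning rate does, here is a plain-Python sketch of the Adam update rule applied to a one-dimensional toy problem. It follows the update equations from the Adam paper and is not the PyTorch implementation; `adam_minimize` and the toy objective are assumptions of this sketch.

```python
import math


def adam_minimize(grad_fn, x0, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    """Minimize a scalar function with the Adam update rule, given its gradient."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g        # running mean of gradients
        v = beta2 * v + (1 - beta2) * g * g    # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)           # bias correction
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Toy problem: f(x) = (x - 3)^2 has gradient 2 * (x - 3) and its minimum at x = 3.
x_small_lr = adam_minimize(lambda x: 2 * (x - 3), x0=0.0, lr=1e-3)
x_large_lr = adam_minimize(lambda x: 2 * (x - 3), x0=0.0, lr=1e-1)
print(x_small_lr, x_large_lr)
```

In this toy problem, 500 steps with `lr=1e-3` leave `x` far from the minimum, because each step changes `x` by roughly the learning rate, whereas `lr=1e-1` gets close to 3. For a real network the best learning rate depends on the problem and is usually found by experiment.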

# Setting up a simple training loop
The cell below provides a minimal training loop. The loop repeatedly iterates over the training set. Each such full pass over the training set is called an epoch, and we perform 500 epochs. Within an epoch, the loop does the following:
* Pick a random mini-batch from the training set
* Reset the gradients of all parameters in the network to zero using `optimizer.zero_grad()`
* Obtain model output for the input batch 
* Compute the Dice loss
* Backpropagate the loss using `loss.backward()`
* Let the optimizer update the network parameters with `optimizer.step()`
* Pick the next random mini-batch

This should take around 5 minutes on a reasonable GPU or 30 minutes on a decent CPU.

In [None]:
from tqdm.notebook import tqdm

training_losses = list()

for epoch in tqdm(range(500)):
    model.train()    
    epoch_loss = 0
    step = 0
    for batch_data in dataloader_train: 
        step += 1
        optimizer.zero_grad()
        outputs = model(batch_data["image"].to(device))
        loss = loss_function(outputs, batch_data["label"].to(device))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    training_losses.append(epoch_loss/step)

# Store the network parameters        
torch.save(model.state_dict(), r'trainedUNet.pt')        
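
The number of parameter updates performed by this loop follows from the numbers used above: 80 training patients with two annotated frames each, a batch size of 64, and 500 epochs. The short calculation below assumes that the dataloader keeps the last, smaller batch of each epoch; the variable names are illustrative.

```python
import math

n_train_patients = 80
frames_per_patient = 2   # the two annotated frames used per patient above
batch_size = 64
n_epochs = 500

n_scans = n_train_patients * frames_per_patient
steps_per_epoch = math.ceil(n_scans / batch_size)  # the last batch is smaller
total_updates = steps_per_epoch * n_epochs
print(n_scans, steps_per_epoch, total_updates)
```

Since the training transform crops a single random 128 x 128 slice per scan, each epoch shows the network one slice from each of the 160 scans, spread over 3 mini-batches.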

After training, we can plot the training loss, with number of epochs on the x-axis.

In [None]:
plt.figure()
plt.plot(np.asarray(training_losses))
plt.xlabel('Epoch')
plt.ylabel('Dice loss')
plt.show()
plt.draw()

We are now ready to segment the full validation set, where some post-processing is applied to each segmentation to obtain a contiguous segmentation mask with discrete values. The following loop also prints the average Dice similarity coefficient and Hausdorff distance for each image.

In [None]:
model.eval()
postprocess = monai.transforms.Compose([
    monai.transforms.AsDiscrete(argmax=True, to_onehot=True, n_classes=4, threshold_values=False),
    monai.transforms.KeepLargestConnectedComponent(applied_labels=(1, 2, 3), independent=False, connectivity=None)
])

for val_batch in dataloader_val:
    outputs_val = monai.inferers.sliding_window_inference(val_batch["image"].squeeze(1).permute(3, 0, 1, 2).to(device), (128, 128), 32, model, overlap = 0.8)
    outputs_val = outputs_val.permute(1, 2, 3, 0).unsqueeze(0)
    print(outputs_val.shape)
    outputs_val = postprocess(outputs_val)
    result = {"image": val_batch["image"].squeeze(), 
              "label": torch.argmax(outputs_val, dim=1).squeeze().cpu()}
    visualize_data(result)     
    
    dice_metric = monai.metrics.DiceMetric()
    dsc, _ = dice_metric(outputs_val.cpu(), monai.networks.utils.one_hot(val_batch["label"].squeeze().unsqueeze(0).unsqueeze(0), 4))
    hd_metric = monai.metrics.HausdorffDistanceMetric()
    hd, _ = hd_metric(outputs_val.cpu(), monai.networks.utils.one_hot(val_batch["label"].squeeze().unsqueeze(0).unsqueeze(0), 4))
    
    print('Average DSC {:.2f}, average Hausdorff distance {:.2f} mm'.format(dsc[0], hd[0]))
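
The `squeeze`, `permute`, and `unsqueeze` calls in this loop only rearrange axes so that a 2D network can process a 3D volume slice by slice. The NumPy sketch below traces the shapes, using `np.transpose` in place of `permute` and `np.expand_dims` in place of `unsqueeze`; the volume size is an arbitrary example and the zero arrays are stand-ins for the image and the network output.

```python
import numpy as np

H, W, Z = 216, 256, 10                       # arbitrary example volume size
image = np.zeros((1, 1, H, W, Z))            # (batch, channel, H, W, slices)

# squeeze(1) then permute(3, 0, 1, 2): slices become the batch axis.
slices = np.transpose(np.squeeze(image, axis=1), (3, 0, 1, 2))
print(slices.shape)                          # (Z, 1, H, W)

# Stand-in for the network output: four class channels per slice.
logits = np.zeros((Z, 4, H, W))

# permute(1, 2, 3, 0) then unsqueeze(0): back to one 3D volume.
volume = np.expand_dims(np.transpose(logits, (1, 2, 3, 0)), axis=0)
print(volume.shape)                          # (1, 4, H, W, Z)
```

After the round trip, the class scores again form a single volume with one channel per class, which is the layout the post-processing and the metrics in the loop operate on.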
    
    

Of course, you can also load pretrained network parameters. For that, run the following after defining your model.

In [None]:
model.load_state_dict(torch.load(r'trainedUNet.pt'))