## Clone the heart-segmentation-monai repository

In [1]:
!git clone https://github.com/luhtookyaw/heart-segmentation-monai.git

## Install the necessary packages

In [2]:
!pip3 install monai
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

## Import packages

In [3]:
import sys

# Make the cloned repository importable (the import cell pulls `prepare` and
# `train` from its `utilities` module). The membership check keeps the cell
# idempotent: re-running it no longer appends duplicate sys.path entries.
REPO_DIR = "/content/heart-segmentation-monai/"
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.data import DataLoader, Dataset
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.utils import first
from monai.transforms import (
    Compose,
    AddChanneld,
    LoadImaged,
    Resized,
    ToTensord,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
    Activations
)

from glob import glob
from utilities import prepare, train

## Configurations and Training

### Define input data directory and saved model directory

In [None]:
# Locations inside the cloned repository (Colab layout): the dataset root, and
# the directory that later cells read checkpoints and metric logs from.
data_dir = "/content/heart-segmentation-monai/Task_Heart"
model_dir = "/content/heart-segmentation-monai/results"

# exist_ok=True keeps the cell re-runnable; os.mkdir raised FileExistsError
# when the results directory already existed.
os.makedirs(model_dir, exist_ok=True)

# Intensity window [0, 1435.2] passed to `prepare`; presumably rescaled to
# [0, 1] like the test transforms do — confirm in utilities.prepare.
data_in = prepare(data_dir, a_min=0, a_max=1435.2, cache=True)

### Select the compute device and build the 3D UNet

In [None]:
# Use the first CUDA GPU when available; fall back to CPU so the notebook still
# runs (slowly) on a machine without a GPU instead of crashing here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 3D UNet: 1 input channel, 2 output channels (presumably background + heart),
# five feature levels with stride-2 downsampling, 2 residual units per level,
# and batch normalization.
model = UNet(
    spatial_dims=3,  # replaces the deprecated `dimensions` alias
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

### Define the loss (Dice loss; Dice + cross-entropy is an alternative) and the optimizer

In [None]:
# Soft Dice loss on sigmoid-activated predictions; `to_onehot_y` expands the
# integer label map to the prediction's channel count. (`DiceCELoss`, imported
# above, is the Dice + cross-entropy alternative.)
loss_function = DiceLoss(
    to_onehot_y=True,
    sigmoid=True,
    squared_pred=True,
)

LEARNING_RATE = 1e-5
WEIGHT_DECAY = 1e-5
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    amsgrad=True,
)

### Train the model

In [None]:
# Fifth positional argument of utilities.train — presumably the number of
# epochs; confirm against the helper's signature.
MAX_EPOCHS = 20
train(model, data_in, loss_function, optimizer, MAX_EPOCHS, model_dir)

### Load the metric log files

In [None]:
def _load_log(file_name):
    """Read one per-epoch ``.npy`` log file from ``model_dir``."""
    log_path = os.path.join(model_dir, file_name)
    return np.load(log_path)


train_loss = _load_log("loss_train.npy")
train_metric = _load_log("metric_train.npy")
test_loss = _load_log("loss_test.npy")
test_metric = _load_log("metric_test.npy")

### Plot the metrics

In [None]:
# Two panels, loss and Dice metric, each comparing train vs test per epoch.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 6))

# Epochs are numbered from 1; the test curves are assumed to have one value
# per training epoch, as in the original plotting code.
epochs = list(range(1, len(train_loss) + 1))

panels = [
    (ax1, train_loss, test_loss, "Loss", "Dice Loss (Train vs Test)"),
    (ax2, train_metric, test_metric, "Metric", "Dice Metric (Train vs Test)"),
]
for ax, train_curve, test_curve, y_label, title in panels:
    ax.plot(epochs, train_curve, '-o', epochs, test_curve, '-o')
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend(["Train", "Test"])
    # Shade the gap between the curves to make the train/test divergence visible.
    ax.fill_between(epochs, train_curve, test_curve, facecolor='grey', alpha=0.14)

# Render explicitly so the cell does not also echo a PolyCollection repr.
plt.show()

## Testing and Visualization

### Load test data directory

In [None]:
# Volumes and segmentations are paired by sorted filename order.
path_test_volumes = sorted(glob(os.path.join(data_dir, "TestVolumes", "*.nii.gz")))
path_test_segmentation = sorted(glob(os.path.join(data_dir, "TestSegmentation", "*.nii.gz")))

# zip() silently drops unmatched trailing files, which would leave volumes
# paired with the wrong labels; fail loudly instead.
assert len(path_test_volumes) == len(path_test_segmentation), (
    f"{len(path_test_volumes)} test volumes vs "
    f"{len(path_test_segmentation)} test segmentations"
)

test_files = [
    {"vol": image_name, "seg": label_name}
    for image_name, label_name in zip(path_test_volumes, path_test_segmentation)
]
# Keep indices 6, 7 and 8 (the 7th to 9th samples in sorted order).
test_files = test_files[6:9]
# An empty selection would make `first(test_loader)` return None later on.
assert test_files, "no test samples fall in the selected slice [6:9]"

### Define the test-time transforms

In [None]:
# Both the image ("vol") and the label ("seg") go through the same spatial
# steps; only the image gets intensity windowing (same a_min/a_max as the
# `prepare` call used for training).
_BOTH_KEYS = ("vol", "seg")

_test_steps = [
    LoadImaged(keys=_BOTH_KEYS),
    AddChanneld(keys=_BOTH_KEYS),
    # Bilinear resampling for the image, nearest for the label map.
    Spacingd(keys=_BOTH_KEYS, pixdim=(1.5, 1.5, 1.0), mode=("bilinear", "nearest")),
    Orientationd(keys=_BOTH_KEYS, axcodes="RAS"),
    ScaleIntensityRanged(keys=["vol"], a_min=0, a_max=1435.2, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=_BOTH_KEYS, source_key="vol"),
    Resized(keys=_BOTH_KEYS, spatial_size=[128, 128, 64]),
    ToTensord(keys=_BOTH_KEYS),
]
test_transforms = Compose(_test_steps)

### Build the test dataset and data loader

In [None]:
# Apply the test-time transforms lazily per item; one patient per batch.
test_ds = Dataset(transform=test_transforms, data=test_files)
test_loader = DataLoader(dataset=test_ds, batch_size=1)

### Load the trained model

In [None]:
# Restore the best checkpoint written during training. map_location lets a
# checkpoint saved on GPU load onto whichever device is active (e.g. CPU).
checkpoint_path = os.path.join(model_dir, "best_metric_model.pth")
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Inference mode (fixed batch-norm statistics); `;` suppresses the model repr.
model.eval();

### Visualize the prediction for the first test sample

In [None]:
sw_batch_size = 4
roi_size = (128, 128, 64)
threshold = 0.53   # probability cut-off applied after the sigmoid
max_slices = 32    # show at most this many slices along the last spatial axis

with torch.no_grad():
    test_patient = first(test_loader)
    # `first` returns None when the loader is empty.
    assert test_patient is not None, "test_loader yielded no samples"
    t_volume = test_patient['vol']

    # Patch-wise inference over the whole volume, then sigmoid + threshold.
    test_outputs = sliding_window_inference(t_volume.to(device), roi_size, sw_batch_size, model)
    sigmoid_activation = Activations(sigmoid=True)
    test_outputs = sigmoid_activation(test_outputs)
    test_outputs = test_outputs > threshold

# Move the prediction to CPU once, not on every slice.
pred_cpu = test_outputs.detach().cpu()
# Never index past the volume's depth (the loop was hardcoded to 32 before).
n_slices = min(max_slices, t_volume.shape[-1])

for i in range(n_slices):
    fig, (ax_image, ax_label, ax_output) = plt.subplots(1, 3, figsize=(18, 6))

    ax_image.set_title(f"image {i}")
    ax_image.imshow(test_patient["vol"][0, 0, :, :, i], cmap="gray")

    ax_label.set_title(f"label {i}")
    ax_label.imshow(test_patient["seg"][0, 0, :, :, i] != 0)

    # Channel 1 of the thresholded prediction is displayed, as before.
    ax_output.set_title(f"output {i}")
    ax_output.imshow(pred_cpu[0, 1, :, :, i])

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop