**Current note:**
* No augmentation

In [None]:
# Mount Google Drive at /content/drive; every dataset, library and checkpoint
# path used later in this notebook lives under that mount point (Colab only).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Extend the module search path with the Drive folders that hold the
# pip-installed packages, the project code and the patching utilities.
import sys

for _extra_path in (
    '/content/drive/MyDrive/library',
    '/content/drive/MyDrive/Colab Notebooks/DNS_PanoramicDentalSeg/DNS_SMP',
    '/content/drive/MyDrive/teeth_instance/utils',  # to import runtime_patch
):
    sys.path.append(_extra_path)

In [None]:
# !pip install --target='/content/drive/MyDrive/library' timm==0.9.2
# !pip install --target='/content/drive/MyDrive/library' pretrainedmodels==0.7.4
# !pip install --target='/content/drive/MyDrive/library' efficientnet-pytorch==0.7.1
# !pip install --target='/content/drive/MyDrive/library' git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils import metrics, losses, base
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import albumentations as A
import cv2
import numpy as np
import random
import matplotlib.pyplot as plt
import os
from datetime import datetime
from copy import deepcopy
import pickle
import torch.nn.functional as F

### Data generator

In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset that reads image/mask pairs sharing a file name.

    Each item is read with OpenCV (image converted BGR -> RGB, mask read as a
    single-channel image with a trailing channel axis added), optionally
    patched, augmented, preprocessed and one-hot encoded.

    Args:
        list_IDs: file names present in both ``images_dir`` and ``masks_dir``.
        images_dir: directory containing the images.
        masks_dir: directory containing the masks.
        load_from_dir: accepted for backward compatibility; not used here.
        prepatch: ``True`` patches before augmentation, ``False`` patches
            after augmentation, ``None`` disables patching.
        patch_shape: patch size forwarded to ``runtime_patch``.
        overlap: patch overlap forwarded to ``runtime_patch``.
        fg_prob: forwarded to ``runtime_patch`` as ``FG_PROB``.
        max_roi: forwarded to ``runtime_patch`` as ``MAX_ROI``.
        augmentation: callable invoked as ``augmentation(image=..., mask=...)``
            returning a dict with ``'image'`` and ``'mask'``; skipped if falsy.
        preprocessing: callable with the same calling convention; skipped if
            falsy.
        to_categorical: when true the mask is one-hot encoded over
            ``n_classes`` and returned channels-first.
        n_classes: number of classes used for one-hot encoding.

    Notes:
        * ``runtime_patch`` and the fallback arrays ``DEFAULT_IMAGE`` /
          ``DEFAULT_MASK`` are module-level names that must exist before the
          first item is requested.
        * NOTE(review): ``runtime_patch`` is not imported in the visible
          cells - confirm it is in scope before enabling ``prepatch``.
    """

    def __init__(
            self,
            list_IDs,
            images_dir,
            masks_dir,
            load_from_dir:bool = True,
            prepatch:bool=None, # if true, patching is performed before augmentation, if false, patching after augmentation, if None, no patching
            patch_shape:tuple = (256,256),
            overlap:tuple = (0,0),
            fg_prob:float = 0.9,
            max_roi:bool = True,
            augmentation=None,
            preprocessing=None,
            to_categorical:bool=False,
            n_classes:int=33,
    ):
        self.ids = list_IDs
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
        self.prepatch = prepatch
        self.patch_shape = patch_shape
        self.overlap = overlap
        self.fg_prob = fg_prob
        self.max_roi = max_roi
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.to_categorical = to_categorical
        self.n_classes = n_classes

        # Paths of the first sample, kept for reference. The fallback actually
        # used by __getitem__ is the pre-loaded DEFAULT_IMAGE / DEFAULT_MASK
        # pair, which does not need another Drive read.
        self.default_image = os.path.join(images_dir, self.ids[0])
        self.default_mask = os.path.join(masks_dir, self.ids[0])

    def _patch(self, image, mask):
        """Extract a patch using this dataset's own patch configuration."""
        return runtime_patch(
            image,
            mask,
            patch_shape=self.patch_shape,
            overlap=self.overlap,
            FG_PROB=self.fg_prob,
            MAX_ROI=self.max_roi,
        )

    def __getitem__(self, i):
        try:
            image = cv2.imread(self.images_fps[i])
            mask = cv2.imread(self.masks_fps[i], 0)  # single-channel read
            # cv2.imread signals failure by returning None instead of raising;
            # turn that into an exception so both cases share one fallback.
            if image is None or mask is None:
                raise FileNotFoundError(self.images_fps[i])
            image = image[:, :, ::-1]  # BGR -> RGB
        except Exception:
            # Colab occasionally fails to read a file from a crowded Drive
            # folder; substitute the default sample instead of aborting.
            print(f'*************Exception occurred at: {self.images_fps[i]}', '\n')
            image = DEFAULT_IMAGE
            mask = DEFAULT_MASK

        mask = np.expand_dims(mask, axis=-1)  # adding channel axis

        # Patching before augmentation.
        if self.prepatch is not None and self.prepatch:
            image, mask = self._patch(image, mask)

        # Augmentation.
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # Patching after augmentation.
        if self.prepatch is not None and not self.prepatch:
            image, mask = self._patch(image, mask)

        # Preprocessing.
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.to_categorical:
            mask = torch.from_numpy(mask)
            mask = F.one_hot(mask.long(), num_classes=self.n_classes)
            mask = mask.type(torch.float32)
            mask = mask.numpy()
            mask = np.squeeze(mask)  # drop the singleton channel axis

            mask = np.moveaxis(mask, -1, 0)  # classes first, as expected by smp

        return image, mask

    def __len__(self):
        return len(self.ids)

### Augmentation

In [None]:
def get_training_augmentation():
    """Build the training-time albumentations pipeline.

    The pipeline chains four ``A.OneOf`` groups (flips, scale/rotate/shift,
    noise/blur/sharpen/perspective, intensity/colour) and is wrapped in an
    ``A.Compose`` created with ``p=0.9`` and ``is_check_shapes=False``.

    Returns:
        albumentations.Compose: the composed transform.
    """
    flip_group = A.OneOf(
        [
            A.HorizontalFlip(p=0.8),
            A.VerticalFlip(p=0.4),
        ],
        p=0.5,
    )

    # Every geometric variant is always applied once picked (p=1) and uses
    # border_mode=0 for pixels exposed by the transform.
    always_const_border = dict(p=1, border_mode=0)
    geometric_group = A.OneOf(
        [
            A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0, **always_const_border),     # scale only
            A.ShiftScaleRotate(scale_limit=0, rotate_limit=30, shift_limit=0, **always_const_border),      # rotate only
            A.ShiftScaleRotate(scale_limit=0, rotate_limit=0, shift_limit=0.1, **always_const_border),     # shift only
            A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=30, shift_limit=0.1, **always_const_border),  # affine transform
        ],
        p=0.9,
    )

    degradation_group = A.OneOf(
        [
            A.Perspective(p=1),
            A.GaussNoise(p=1),
            A.Sharpen(p=1),
            A.Blur(blur_limit=3, p=1),
            A.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.2,
    )

    intensity_group = A.OneOf(
        [
            A.CLAHE(p=1),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
            A.RandomGamma(p=1),
            A.HueSaturationValue(p=1),
        ],
        p=0.2,
    )

    pipeline = [flip_group, geometric_group, degradation_group, intensity_group]

    # 90% augmentation probability
    return A.Compose(pipeline, p=0.9, is_check_shapes=False)


def get_validation_augmentation():
    """Return the validation-time pipeline, which is an empty ``A.Compose``.

    The padding step ``A.PadIfNeeded(512, 512)`` is intentionally left
    disabled, so no transform is added to the pipeline.
    """
    no_transforms = []  # A.PadIfNeeded(512, 512) is disabled
    return A.Compose(no_transforms)


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored so the function can be
    used as an ``A.Lambda`` callback.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct the preprocessing transform.

    Args:
        preprocessing_fn (callable): data normalization function applied to
            the image only (can be specific to each pretrained network).

    Returns:
        albumentations.Compose: normalization followed by ``to_tensor`` on
        both image and mask.
    """
    normalize_image = A.Lambda(image=preprocessing_fn)
    to_channels_first = A.Lambda(image=to_tensor, mask=to_tensor)
    return A.Compose([normalize_image, to_channels_first])


### Data directory

In [None]:
# Dataset locations (pre-cut 512x512x3 patches) for training and validation.
x_train_dir = '/content/drive/MyDrive/panoramicDNS/patchDNS_512/train/images' # patch images 512x512x3
x_valid_dir = '/content/drive/MyDrive/panoramicDNS/patchDNS_512/val/images'
y_train_dir = '/content/drive/MyDrive/panoramicDNS/patchDNS_512/train/mask'
y_valid_dir = '/content/drive/MyDrive/panoramicDNS/patchDNS_512/val/mask'

# Fallback sample used by Dataset.__getitem__ when a file cannot be read.
# Image and mask must come from the SAME file name (masks share the image's
# name), so pick one name and use it for both directories; sorting makes the
# choice deterministic because os.listdir order is arbitrary.
_default_name = sorted(os.listdir(x_train_dir))[0]
DEFAULT_IMAGE = cv2.imread(os.path.join(x_train_dir, _default_name))
DEFAULT_MASK = cv2.imread(os.path.join(y_train_dir, _default_name), 0)

# cv2.imread returns None on failure; fail here with a clear message rather
# than later inside a DataLoader worker.
if DEFAULT_IMAGE is None or DEFAULT_MASK is None:
    raise FileNotFoundError(f'Could not read default image/mask pair: {_default_name}')

DEFAULT_IMAGE = DEFAULT_IMAGE[:,:,::-1]  # BGR -> RGB

In [None]:
# os.listdir returns entries in arbitrary order, so sort before the seeded
# shuffle; otherwise the same seed can yield different orders between runs.
list_IDs_train = sorted(os.listdir(x_train_dir))
list_IDs_val = sorted(os.listdir(x_valid_dir))

random.seed(42) # seed for random number generator

random.shuffle(list_IDs_train) # shuffle names
random.shuffle(list_IDs_val) # shuffle names

print('No. of training images: ', len(list_IDs_train))
print('No. of validation images: ', len(list_IDs_val))

No. of training images:  4356
No. of validation images:  828


### Parameters

In [None]:
# ---- Model / optimisation settings ----
BASE_MODEL = 'FUSegNet'  # only used as a prefix of the run name
ENCODER = 'efficientnet-b7'
ENCODER_WEIGHTS = 'imagenet'
BATCH_SIZE = 2
n_classes = 33  # number of output channels / one-hot classes (label 0 is treated as background at inference)
ACTIVATION = 'softmax' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LR = 0.0001 # learning rate
EPOCHS = 50
WEIGHT_DECAY = 1e-8 #1e-5
SAVE_WEIGHTS_ONLY = True  # NOTE(review): not referenced in the visible cells
TO_CATEGORICAL = True  # one-hot encode masks in the Dataset; argmax predictions at inference
SAVE_BEST_MODEL = True  # when False, the training loop writes periodic checkpoints every PERIOD epochs
SAVE_LAST_MODEL = False  # only honoured when EARLY_STOP is False
PERIOD = 20 # periodically save checkpoints
RAW_PREDICTION = False # if true, then stores raw predictions (i.e. before applying threshold)

# ---- Early stopping ----
PATIENCE = 30 # for early stopping
EARLY_STOP = True

# Patch info
PREPATCH = None # no patching
PATCH_SHAPE = (512,512)
OVERLAP = (0,0)
FG_PROB = 0.5 # probability of selecting a foreground patch
MAX_ROI = False # If true, then get the foreground patch which has max roi in the patch

### Build model

In [None]:
# NOTE(review): this replaces the default HTTPS context with an unverified
# one for the whole process, i.e. TLS certificate checks are disabled for
# every later urllib-based download. Presumably a workaround for downloading
# encoder weights - confirm it is still needed and keep it as narrow as possible.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# Unique run name: <base model>_<encoder>_<timestamp>. It names the checkpoint,
# plot and prediction outputs written by later cells.
model_name = BASE_MODEL + '_' + ENCODER + '_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
print(model_name)

**Checkpoint directory**

In [None]:
# Checkpoint directory for this run
checkpoint_loc = '/content/drive/MyDrive/Colab Notebooks/DNS_PanoramicDentalSeg/DNS_SMP/checkpoints/' + model_name

# Create the directory (and parents) if missing; exist_ok avoids the
# check-then-create race and makes the cell safe to re-run.
os.makedirs(checkpoint_loc, exist_ok=True)

**Helper function**

In [None]:
# Helper function: save a model
def save(model_path, epoch, model_state_dict, optimizer_state_dict):
    """Write a training checkpoint to ``model_path`` with ``torch.save``.

    Args:
        model_path: destination file path.
        epoch: epoch number to record. It is stored as given; callers in this
            notebook already pass ``epoch + 1`` (number of completed epochs),
            so adding one again here recorded the wrong epoch.
        model_state_dict: the model's ``state_dict()``.
        optimizer_state_dict: the optimizer's ``state_dict()``.

    The checkpoint is a dict with keys ``'epoch'``, ``'state_dict'`` and
    ``'optimizer'``.
    """
    # torch.save serialises synchronously, so the state dicts can be written
    # directly; deep-copying them first only duplicates the tensors in memory.
    state = {
        'epoch': epoch,
        'state_dict': model_state_dict,
        'optimizer': optimizer_state_dict,
        }

    torch.save(state, model_path)

**Loss function**

In [None]:
# Training objective: sum of a Dice loss and a focal loss, both taken from the
# `losses` module imported from segmentation_models_pytorch.utils.
dice_loss = losses.DiceLoss()
focal_loss = losses.FocalLoss()

# SumOfLosses combines the two losses into the single loss object passed to the epoch runners.
total_loss = base.SumOfLosses(dice_loss, focal_loss)

**Metrics**

In [None]:
# Metrics reported by the epoch runners. The classes are referenced through
# `smp.utils.metrics` because this cell rebinds the name `metrics` (imported
# above as a module) to a list; using the bare name made the cell fail on a
# second run with "'list' object has no attribute 'IoU'".
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
    smp.utils.metrics.Fscore(threshold=0.5),
]

**Model**

In [None]:
# Build the U-Net-style segmentation model with the configured encoder.
# NOTE(review): decoder_attention_type='pscse' is presumably provided by the
# project's own copy of segmentation_models_pytorch on sys.path - confirm.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=n_classes,
    activation=ACTIVATION,
    decoder_attention_type = 'pscse',
)

# Normalisation function matching the chosen encoder and pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Last expression of the cell, so the notebook also displays the module repr.
model.to(DEVICE)

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np


def summary(model, input_size, batch_size=-1, device="cuda"):
    """Print a per-layer table of output shapes and parameter counts.

    A forward pass on random input is run with forward hooks attached to every
    sub-module (except ``nn.Sequential``, ``nn.ModuleList`` and ``model``
    itself). The pass runs in eval mode under ``torch.no_grad()`` so that it
    neither updates BatchNorm running statistics nor builds an autograd graph;
    each module's original train/eval flag is restored afterwards and the
    hooks are removed even if the forward pass raises.

    Args:
        model: the ``nn.Module`` to inspect. It must already live on the
            device that the random input is created on.
        input_size: one shape tuple (without the batch axis), or a list of
            tuples for models taking several inputs.
        batch_size: value written into the batch position of reported shapes.
        device: ``"cuda"`` or ``"cpu"`` (case-insensitive). ``"cuda"`` falls
            back to CPU tensors when CUDA is unavailable.

    Returns:
        None. The table is printed to stdout.
    """

    def shape_numel(shape):
        # Number of elements described by a shape, or the total over several
        # shapes when a layer returned a list/tuple of tensors.
        if len(shape) > 0 and isinstance(shape[0], (list, tuple)):
            return sum(shape_numel(s) for s in shape)
        return abs(int(np.prod(shape)))

    def register_hook(module):

        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(layer_info)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            entry = OrderedDict()
            layer_info[m_key] = entry

            # following part is moderated by https://github.com/graykode/modelsummary/issues/1
            if len(input) != 0:
                entry["input_shape"] = list(input[0].size())
                entry["input_shape"][0] = batch_size
            else:
                entry["input_shape"] = input

            if isinstance(output, (list, tuple)):
                entry["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                entry["output_shape"] = list(output.size())
                entry["output_shape"][0] = batch_size

            # Count only the module's own weight/bias as plain Python ints.
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += module.weight.numel()
                entry["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += module.bias.numel()
            entry["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and module is not model
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    use_cuda = device == "cuda" and torch.cuda.is_available()
    tensor_device = torch.device("cuda" if use_cuda else "cpu")

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # Random float32 input with a batch of 2 per declared input.
    x = [
        torch.rand(2, *in_size, dtype=torch.float32, device=tensor_device)
        for in_size in input_size
    ]

    # create properties
    layer_info = OrderedDict()
    hooks = []

    # Remember each module's mode so mixed train/eval setups are restored exactly.
    previous_modes = [(m, m.training) for m in model.modules()]

    try:
        model.apply(register_hook)
        model.eval()
        with torch.no_grad():
            model(*x)
    finally:
        for h in hooks:
            h.remove()
        for m, was_training in previous_modes:
            m.training = was_training

    print("----------------------------------------------------------------")
    line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, entry in layer_info.items():
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(entry["output_shape"]),
            "{0:,}".format(entry["nb_params"]),
        )
        total_params += entry["nb_params"]
        total_output += shape_numel(entry["output_shape"])
        if entry.get("trainable", False) == True:
            trainable_params += entry["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
    # return summary


In [None]:
# Print the layer table for one 3x512x512 input; summary() defaults to
# device="cuda" and uses CPU tensors when CUDA is not available.
summary(model, (3, 512, 512))

**Optimizer**

In [None]:
# Adam over all model parameters as a single parameter group that carries the
# configured learning rate and weight decay.
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY),
])

**Learning rate**

In [None]:
# Plateau scheduler in 'min' mode; the training loop steps it with the
# validation loss. With LR = 0.0001, factor = 0.1 and min_lr = 0.00001 a
# single reduction already reaches the floor.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                              factor=0.1,
                              mode='min',
                              patience=10,
                              min_lr=0.00001,
                              verbose=True,
                              )


**Uncomment if you want to load a pretrained model**

In [None]:
# pretrained_model_name = 'FUSegNet_gated_attn_efficientnet-b7_2023-07-08_14-39-11'
# pretrained_cp_loc = '/content/drive/MyDrive/Colab Notebooks/DNS_PanoramicDentalSeg/DNS_SMP/checkpoints/' + pretrained_model_name

# checkpoint = torch.load(os.path.join(pretrained_cp_loc, 'best_model.pth'))
# model.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])

**Dataloader**

In [None]:
# Set selected_img to True to work with a random subset of the images
selected_img = False
if selected_img:
  import random
  for _id_list in (list_IDs_train, list_IDs_val):
    random.shuffle(_id_list)
  list_IDs_train = list_IDs_train[:1000]
  list_IDs_val = list_IDs_val[:100]
  print(list_IDs_train)


def _build_dataset(ids, images_dir, masks_dir):
    """Create a Dataset with the notebook-wide patch/encoding settings.

    Training and validation use the same configuration, including the empty
    validation augmentation (this run trains without augmentation). Fresh
    augmentation and preprocessing objects are created on every call.
    """
    return Dataset(
        ids,
        images_dir,
        masks_dir,
        prepatch=PREPATCH,
        patch_shape=PATCH_SHAPE,
        overlap=OVERLAP,
        fg_prob=FG_PROB,
        max_roi=MAX_ROI,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        to_categorical=TO_CATEGORICAL,
        n_classes=n_classes,
    )


train_dataset = _build_dataset(list_IDs_train, x_train_dir, y_train_dir)
valid_dataset = _build_dataset(list_IDs_val, x_valid_dir, y_valid_dir)

# Only the training loader shuffles; both use 4 worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

### Training

In [None]:
# Epoch runners: each .run(loader) iterates once over the loader and returns
# a dict of logs whose first key is the loss.
train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=total_loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=total_loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# Best-so-far trackers and early-stopping state
best_viou = 0.0
best_vloss = 1_000_000.
cnt_patience = 0
best_model_epoch = None  # stays None if no epoch ever improves (e.g. NaN losses)

store_train_loss, store_val_loss = [], []
store_train_iou, store_val_iou = [], []
store_train_dice, store_val_dice = [], []

for epoch in range(EPOCHS):

    print('\nEpoch: {}'.format(epoch))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)

    # Store losses and metrics
    train_loss_key = list(train_logs.keys())[0] # first key is for loss
    val_loss_key = list(valid_logs.keys())[0] # first key is for loss
    val_loss = valid_logs[val_loss_key]
    val_iou = valid_logs['iou_score']

    store_train_loss.append(train_logs[train_loss_key])
    store_val_loss.append(val_loss)
    store_train_iou.append(train_logs["iou_score"])
    store_val_iou.append(val_iou)
    store_train_dice.append(train_logs["fscore"])
    store_val_dice.append(valid_logs["fscore"])

    # Update BOTH trackers every epoch. Previously the best IoU was only
    # updated on epochs where the loss had not improved, so later epochs were
    # compared against a stale IoU and could overwrite a better checkpoint.
    loss_improved = val_loss < best_vloss
    iou_improved = val_iou > best_viou

    if loss_improved:
        best_vloss = val_loss
        print(f'Validation loss reduced. Saving the model at epoch: {epoch:04d}')
    if iou_improved:
        best_viou = val_iou
        if not loss_improved:
            print(f'Validation IoU increased. Saving the model at epoch: {epoch:04d}.')

    save_model = loss_improved or iou_improved
    if save_model:
        cnt_patience = 0 # reset patience
        best_model_epoch = epoch
    else:
        cnt_patience += 1

    # Learning rate scheduler: monitor the validation loss explicitly instead
    # of relying on the loss key happening to sort first alphabetically.
    scheduler.step(val_loss)

    # Save the model
    if save_model:
        save(os.path.join(checkpoint_loc, 'best_model' + '.pth'),
             epoch+1, model.state_dict(), optimizer.state_dict())

    # Early stopping
    if EARLY_STOP and cnt_patience >= PATIENCE:
      print(f"Early stopping at epoch: {epoch:04d}")
      break

    # Periodic checkpoint save (only when best-model mode is switched off)
    if not SAVE_BEST_MODEL:
      if (epoch+1) % PERIOD == 0:
        save(os.path.join(checkpoint_loc, f"cp-{epoch+1:04d}.pth"),
             epoch+1, model.state_dict(), optimizer.state_dict())
        print(f'Checkpoint saved for epoch {epoch:04d}')

# The final weights are only written when early stopping is disabled.
if not EARLY_STOP and SAVE_LAST_MODEL:
    print('Saving last model')
    save(os.path.join(checkpoint_loc, 'last_model' + '.pth'),
         epoch+1, model.state_dict(), optimizer.state_dict())

In [None]:
# 0-based epoch index at which best_model.pth was last written by the training loop.
best_model_epoch

**Plotting**

In [None]:
"""## Plotting """
# Training (red) vs validation (blue) curves, one panel per quantity.
fig, ax = plt.subplots(3,1, figsize=(7, 14))

curves = [
    ('Loss curve', store_train_loss, store_val_loss),
    ('IoU curve', store_train_iou, store_val_iou),
    # The Dice panel previously re-plotted the IoU series.
    ('Dice curve', store_train_dice, store_val_dice),
]

for axis, (title, train_values, val_values) in zip(ax, curves):
    axis.plot(train_values, 'r')
    axis.plot(val_values, 'b')
    axis.set_title(title)
    axis.legend(['training', 'validation'])

fig.tight_layout()

# Save before showing so the file is written from the fully drawn figure
# regardless of how the active backend treats figures after plt.show().
save_fig_dir = '/content/drive/MyDrive/Colab Notebooks/DNS_PanoramicDentalSeg/DNS_SMP/plots/'
os.makedirs(save_fig_dir, exist_ok=True)

fig.savefig(os.path.join(save_fig_dir, model_name + '.png'))

plt.show()

### Inference

In [None]:
from jenti.patch import Patch, Merge
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import confusion_matrix
import scipy.io as sio

import warnings
warnings.filterwarnings("ignore")

**Test image directory**

In [None]:
# Directories holding the held-out test images and their ground-truth masks
x_test_dir = '/content/drive/MyDrive/panoramicDNS/fold5/images'
y_test_dir = '/content/drive/MyDrive/panoramicDNS/fold5/masks'

# Sorted so the evaluation log and the result spreadsheet have a stable row
# order (os.listdir order is arbitrary).
list_IDs_test = sorted(os.listdir(x_test_dir))
print('No. of test images: ', len(list_IDs_test))

No. of test images:  111


**Test dataloader**

In [None]:
# Image preprocessing function for inference; obtained with the same encoder
# name and weights as the training-time preprocessing_fn.
im_preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

**Load model**

In [None]:
# Restore the best checkpoint of this run. map_location lets a checkpoint
# written on a GPU runtime be loaded when only a CPU is available.
checkpoint = torch.load(os.path.join(checkpoint_loc, 'best_model.pth'), map_location=DEVICE)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

# To evaluate a periodically saved checkpoint instead (same dict layout):
# checkpoint = torch.load(os.path.join(checkpoint_loc, 'cp-0050.pth'), map_location=DEVICE)
# model.load_state_dict(checkpoint['state_dict'])

**Parameters**

In [None]:
# Inference settings
save_pred = True  # write each merged prediction as a PNG
threshold = 0.5  # NOTE(review): not referenced in the visible evaluation loop
ep = 1e-6  # epsilon added to metric denominators to avoid division by zero
raw_pred = []  # filled with per-image predictions only when RAW_PREDICTION is True
patch_shape = [512, 512]  # patch size used to tile each test image
overlap = [0, 0]  # overlap between neighbouring patches

HARD_LINE = True  # NOTE(review): not referenced in the visible cells

**Set directory to store predictions**

In [None]:
# Directory that receives the predicted masks and the metric reports of this run
save_dir_pred = '/content/drive/MyDrive/Colab Notebooks/DNS_PanoramicDentalSeg/DNS_SMP/predictions/' + model_name
os.makedirs(save_dir_pred, exist_ok=True)  # idempotent; no exists/create race

In [None]:
# Create dataframe to store per-image records. The IoU column is spelled
# 'IoU' to match the key written and read by the evaluation loop; the former
# 'iou' spelling produced a second, permanently empty column.
# NOTE(review): 'Accuracy' and 'Specificity' are never filled by the visible code.
df = pd.DataFrame(index=[], columns = [
    'Name', 'Accuracy', 'Specificity', 'IoU', 'Precision', 'Recall', 'Dice'], dtype='object')

# Create a dictionary to store metrics
metric = {} # Nested metric format: metric[image_name][label] = [precision, recall, dice, iou]

In [None]:
import json


def _label_counts(gt_mask, pred, label):
    """Return pixel-wise (tp, fp, fn) for one label from two same-shape masks."""
    gt_hit = gt_mask == label
    pred_hit = pred == label
    tp = int(np.count_nonzero(gt_hit & pred_hit))
    fp = int(np.count_nonzero(pred_hit & ~gt_hit))
    fn = int(np.count_nonzero(gt_hit & ~pred_hit))
    return tp, fp, fn


def _mean_or_nan(values):
    """Mean of a list, or NaN for an empty list (no label shared by gt and prediction)."""
    return float(np.mean(values)) if len(values) > 0 else float('nan')


# Dataset-level pixel totals (label 0 / background is excluded throughout).
stp, sfp, sfn = 0, 0, 0
records = []  # one dict per image; turned into a DataFrame after the loop

for i, name_ext in enumerate(list_IDs_test):

    name = os.path.splitext(name_ext)[0] # remove extension

    metric[name] = {} # Creating nested dictionary

    # Per-label metrics of the current image (labels present in both masks)
    i_mp, i_mr, i_mdice, i_miou = [], [], [], []

    # Load image and mask. The image is read under its listed file name; the
    # mask is expected as <name>.png.
    image = cv2.imread(os.path.join(x_test_dir, name_ext))
    gt_mask = cv2.imread(os.path.join(y_test_dir, name + '.png'), 0) # gt is grayscale
    if image is None or gt_mask is None:
        # cv2.imread returns None instead of raising on an unreadable file.
        raise FileNotFoundError(f'Could not read test image/mask for: {name_ext}')
    image = image[:,:,::-1] # BGR -> RGB

    # Preprocess image
    image = im_preprocessing_fn(image)

    # Create patches from the image
    patcher = Patch(patch_shape, overlap, patch_name=name, csv_output=False)
    patches, info, org_shape_im = patcher.patch2d(image)
    org_shape_mask = (org_shape_im[0], org_shape_im[1], 1) # mask is a grayscale image

    # Iterate over patches
    patchwise_pred = [] # store patch-wise predictions for each test sample
    for patch_arr in patches:
      batch = np.expand_dims(patch_arr, axis=0) # add batch axis
      batch = np.moveaxis(batch, -1, 1).astype('float32') # channels first
      batch = torch.from_numpy(batch) # convert to tensor
      pr_mask = model.predict(batch.to(DEVICE)) # run on the configured device

      if TO_CATEGORICAL:
        pr_mask = torch.argmax(pr_mask, dim=1) # class index per pixel
      pr_mask = pr_mask.squeeze().cpu().numpy() # move to cpu
      pr_mask = np.expand_dims(pr_mask, axis=-1) # trailing channel axis
      patchwise_pred.append(pr_mask)

    # Merge patches
    merge = Merge(info, org_shape_mask, dtype='int8') # create object
    merged = merge.merge2d(patchwise_pred)

    gt_mask = np.squeeze(gt_mask)
    pred = np.squeeze(merged)

    if gt_mask.shape != pred.shape:
        raise ValueError(
            f'Shape mismatch for {name}: gt {gt_mask.shape} vs prediction {pred.shape}')

    # Save raw prediction
    if RAW_PREDICTION: raw_pred.append(pred)

    # Save prediction as png
    if save_pred:
        cv2.imwrite(os.path.join(save_dir_pred, name + '.png'), np.squeeze(pred).astype(np.uint8))

    # Labels in gt and prediction as plain ints; discard() tolerates masks
    # without any background pixel (remove() raised KeyError).
    lbl_gt = set(np.unique(gt_mask).tolist())
    lbl_gt.discard(0) # 0 is background
    lbl_pred = set(np.unique(pred).tolist())
    lbl_pred.discard(0) # 0 is background

    only_gt = lbl_gt - lbl_pred     # missed entirely -> all their pixels are fn
    only_pred = lbl_pred - lbl_gt   # hallucinated -> all their pixels are fp

    # Each pixel is added to the dataset totals exactly once (the running
    # image totals used to be re-added on every iteration).
    for lbl in only_gt:
        sfn += int(np.count_nonzero(gt_mask == lbl))

    for lbl in only_pred:
        sfp += int(np.count_nonzero(pred == lbl))

    # Labels present on one side only score 0 for every metric.
    # NOTE(review): as before, these zeros are recorded in `metric` but are
    # not included in the image-wise means below - confirm this is intended.
    for lbl in sorted(only_gt | only_pred):
        p, r, dice, iou = 0, 0, 0, 0
        metric[name][str(lbl)] = [p, r, dice, iou]
        print("%d %s: label: %s; Precision: %3.2f; Recall: %3.2f; Dice: %3.2f; IoU: %3.2f"%(i+1, name, lbl, p, r, dice, iou))

    # Labels present in both gt and prediction: metrics from that label's
    # own counts (they were previously computed from counts accumulated over
    # all labels processed so far).
    for lbl in sorted(lbl_gt & lbl_pred):
        tp, fp, fn = _label_counts(gt_mask, pred, lbl)

        stp += tp
        sfp += fp
        sfn += fn

        p = (tp/(tp + fp + ep)) * 100
        r = (tp/(tp + fn + ep)) * 100
        dice = (2 * tp / (2 * tp + fp + fn + ep)) * 100
        iou = (tp/(tp + fp + fn + ep)) * 100

        print("%d %s: label: %s; Precision: %3.2f; Recall: %3.2f; Dice: %3.2f; IoU: %3.2f"%(i+1, name, lbl, p, r, dice, iou))

        metric[name][str(lbl)] = [p, r, dice, iou]

        # Keep appending metrics for all labels for the current image
        i_mp.append(p)
        i_mr.append(r)
        i_mdice.append(dice)
        i_miou.append(iou)

    # Mean of the per-label metrics for the current image
    records.append({
        'Name': name,
        'Precision': _mean_or_nan(i_mp),
        'Recall': _mean_or_nan(i_mr),
        'Dice': _mean_or_nan(i_mdice),
        'IoU': _mean_or_nan(i_miou),
    })

# Build the result table once; DataFrame.append no longer exists in pandas 2
# and growing a frame row by row is quadratic.
df = pd.DataFrame(records, columns=['Name', 'Precision', 'Recall', 'Dice', 'IoU'])

# Print overall mean of metrics
print("Image-based all IoU: %3.2f" % df["IoU"].mean())
print("Image-based precision: %3.2f" % df["Precision"].mean())
print("Image-based all Recall: %3.2f" % df["Recall"].mean())
print("Image-based all dice: %3.2f" % df["Dice"].mean())

df.to_excel(os.path.join(save_dir_pred, 'result_image_based.xlsx'), index=False)

# Write the nested per-image / per-label metrics as JSON
with open(os.path.join(save_dir_pred, "metric.json"), "w") as f:
    f.write(json.dumps(metric))

# Data-based evaluation (pixel totals over the whole test set)
siou = (stp/(stp + sfp + sfn + ep))*100
sprecision = (stp/(stp + sfp + ep))*100
srecall = (stp/(stp + sfn + ep))*100
sdice = (2 * stp / (2 * stp + sfp + sfn + ep))*100 # ep added, as in the other ratios

print('siou:', siou)
print('sprecision:', sprecision)
print('srecall:', srecall)
print('sdice:', sdice)

# Save data-based result in a text file
with open(os.path.join(save_dir_pred, 'result_data_based_best_model.txt'), 'w') as f:
    print(f'siou = {siou}', file=f)
    print(f'sprecision = {sprecision}', file=f)
    print(f'srecall = {srecall}', file=f)
    print(f'sdice = {sdice}', file=f)
    print(f'best model epoch = {best_model_epoch}', file=f)
    print(f'model name = {model_name}', file=f)