<a href="https://colab.research.google.com/github/jacopozattoni/BME630_Project/blob/main/Project_BME_630.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Imports

In [None]:
# Imports:
import nibabel as nib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import os
import numpy as np
from sklearn.model_selection import train_test_split
import json
from sklearn.metrics import balanced_accuracy_score, adjusted_rand_score, roc_auc_score
import warnings
import requests
import os
from PIL import Image, ImageFilter
from torchvision.utils import save_image

### Miseval functions
Taken from: https://github.com/frankkramer-lab/miseval . Miseval is a Python package that implements evaluation metrics for image segmentation tasks; it relies on NumPy and scikit-learn. In this project it is used to calculate the Dice and Jaccard scores of the segmentations produced by the network, as well as traditional metrics such as accuracy, precision, sensitivity and specificity.

Its functions are imported manually instead of importing the package because of compatibility issues between numpy and one of the required packages for miseval, numba.

In [None]:
#==============================================================================#
#  Author:       Dominik Müller                                                #
#  Copyright:    2022 IT-Infrastructure for Translational Medical Research,    #
#                University of Augsburg                                        #
#                                                                              #
#  This program is free software: you can redistribute it and/or modify        #
#  it under the terms of the GNU General Public License as published by        #
#  the Free Software Foundation, either version 3 of the License, or           #
#  (at your option) any later version.                                         #
#                                                                              #
#  This program is distributed in the hope that it will be useful,             #
#  but WITHOUT ANY WARRANTY; without even the implied warranty of              #
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
#  GNU General Public License for more details.                                #
#                                                                              #
#  You should have received a copy of the GNU General Public License           #
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#==============================================================================#

#-----------------------------------------------------#
#            Calculate : Precision via Sets           #
#-----------------------------------------------------#
def calc_Precision_Sets(truth, pred, c=1, **kwargs):
    """
    Precision of class ``c`` computed directly on boolean masks.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: (# elements where pred and truth both equal c) / (# elements
             where pred equals c), or 0.0 if no prediction equals ``c``.
    """
    truth_mask = np.equal(truth, c)
    pred_mask = np.equal(pred, c)
    n_predicted = pred_mask.sum()
    # No predicted positives: the ratio is undefined, report 0.0 instead.
    if n_predicted == 0:
        return 0.0
    return np.logical_and(pred_mask, truth_mask).sum() / n_predicted

#-----------------------------------------------------#
#             Calculate : Precision via CM            #
#-----------------------------------------------------#
def calc_Precision_CM(truth, pred, c=1, **kwargs):
    """
    Precision of class ``c`` derived from the confusion matrix.

    :param truth: Ground-truth labels.
    :param pred: Predicted labels.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: TP / (TP + FP), or 0.0 when TP + FP is zero.
    """
    tp, _tn, fp, _fn = calc_ConfusionMatrix(truth, pred, c)
    predicted_positive = tp + fp
    # Guard against a zero denominator (nothing predicted as class c).
    if predicted_positive == 0:
        return 0.0
    return (tp) / predicted_positive

#-----------------------------------------------------#
#           Calculate : Sensitivity via Sets          #
#-----------------------------------------------------#
def calc_Sensitivity_Sets(truth, pred, c=1, **kwargs):
    """
    Sensitivity (recall) of class ``c`` computed directly on boolean masks.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: (# elements where pred and truth both equal c) / (# elements
             where truth equals c), or 0.0 if the class is absent in truth.
    """
    truth_mask = np.equal(truth, c)
    pred_mask = np.equal(pred, c)
    n_actual = truth_mask.sum()
    # Class absent from the ground truth: report 0.0 instead of dividing.
    if n_actual == 0:
        return 0.0
    hits = np.logical_and(pred_mask, truth_mask).sum()
    return hits / n_actual

#-----------------------------------------------------#
#            Calculate : Sensitivity via CM           #
#-----------------------------------------------------#
def calc_Sensitivity_CM(truth, pred, c=1, **kwargs):
    """
    Sensitivity (recall) of class ``c`` derived from the confusion matrix.

    :param truth: Ground-truth labels.
    :param pred: Predicted labels.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: TP / (TP + FN), or 0.0 when TP + FN is zero.
    """
    tp, _tn, _fp, fn = calc_ConfusionMatrix(truth, pred, c)
    actual_positive = tp + fn
    # Guard against a zero denominator (class c absent from truth).
    if actual_positive == 0:
        return 0.0
    return (tp) / actual_positive

#-----------------------------------------------------#
#           Calculate : Specificity via Sets          #
#-----------------------------------------------------#
def calc_Specificity_Sets(truth, pred, c=1, **kwargs):
    """
    Specificity of class ``c`` computed directly on boolean masks.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: (# elements where neither pred nor truth equals c) /
             (# elements where truth differs from c), or 0.0 if every
             truth element equals ``c``.
    """
    truth_other = np.logical_not(np.equal(truth, c))
    pred_other = np.logical_not(np.equal(pred, c))
    n_negative = truth_other.sum()
    # No negatives in the ground truth: report 0.0 instead of dividing.
    if n_negative == 0:
        return 0.0
    return np.logical_and(pred_other, truth_other).sum() / n_negative

#-----------------------------------------------------#
#            Calculate : Specificity via CM           #
#-----------------------------------------------------#
def calc_Specificity_CM(truth, pred, c=1, **kwargs):
    """
    Specificity of class ``c`` derived from the confusion matrix.

    :param truth: Ground-truth labels.
    :param pred: Predicted labels.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: TN / (TN + FP), or 0.0 when TN + FP is zero.
    """
    _tp, tn, fp, _fn = calc_ConfusionMatrix(truth, pred, c)
    actual_negative = tn + fp
    # Guard against a zero denominator (no negatives in truth).
    if actual_negative == 0:
        return 0.0
    return (tn) / actual_negative


#-----------------------------------------------------#
#            Calculate : Accuracy via Sets            #
#-----------------------------------------------------#
def calc_Accuracy_Sets(truth, pred, c=1, **kwargs):
    """
    Accuracy of class ``c`` (one-vs-rest) computed on boolean masks.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: (TP + TN) / number of elements, or 0.0 for empty input
             (same zero-denominator convention as the other metrics here).
    """
    # Obtain sets with associated class
    gt = np.equal(truth, c)
    pd = np.equal(pred, c)
    # Empty input would divide by zero (nan + RuntimeWarning); return 0.0.
    if gt.size == 0:
        return 0.0
    not_gt = np.logical_not(gt)
    not_pd = np.logical_not(pd)
    # Agreement on positives plus agreement on negatives, over all elements
    acc = (np.logical_and(pd, gt).sum() +
           np.logical_and(not_pd, not_gt).sum()) / gt.size
    # Return computed Accuracy
    return acc

#-----------------------------------------------------#
#           Calculate : Accuracy via ConfMat          #
#-----------------------------------------------------#
def calc_Accuracy_CM(truth, pred, c=1, **kwargs):
    """
    Accuracy of class ``c`` (one-vs-rest) derived from the confusion matrix.

    :param truth: Ground-truth labels.
    :param pred: Predicted labels.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: (TP + TN) / (TP + TN + FP + FN), or 0.0 when the confusion
             matrix is empty (same convention as the other metrics here).
    """
    # Obtain confusion mat
    tp, tn, fp, fn = calc_ConfusionMatrix(truth, pred, c)
    total = tp + tn + fp + fn
    # Empty input would divide by zero; return 0.0 like the sibling metrics.
    if total == 0:
        return 0.0
    # Return computed Accuracy
    return (tp + tn) / total

#-----------------------------------------------------#
#            Calculate : AUC via trapezoid            #
#-----------------------------------------------------#
"""
Formula:
    AUC = 1 - 1/2 * (FP/(FP+TN) + FN/(FN+TP))

References:
    Powers DMW. Evaluation: from precision, recall and F-measure to ROC, informedness, markedness and correlation.
    2020 Oct 10 [cited 2022 Jan 8]; Available from: http://arxiv.org/abs/2010.16061
"""
def calc_AUC_trapezoid(truth, pred, c=1, **kwargs):
    """
    Single-threshold AUC estimate: 1 - (FPR + FNR) / 2.

    :param truth: Ground-truth labels.
    :param pred: Predicted labels.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: AUC estimate; a rate whose denominator is zero counts as 0.0.
    """
    tp, tn, fp, fn = calc_ConfusionMatrix(truth, pred, c)
    negatives = fp + tn
    positives = fn + tp
    # False positive rate and false negative rate, each guarded against /0
    false_pos_rate = fp / negatives if negatives != 0 else 0.0
    false_neg_rate = fn / positives if positives != 0 else 0.0
    return 1 - (1/2)*(false_pos_rate + false_neg_rate)

#-----------------------------------------------------#
#           Calculate : AUC via probability           #
#-----------------------------------------------------#
def calc_AUC_probability(truth, pred_prob, c=1, rounding_precision=5, **kwargs):
    """
    ROC AUC of class ``c`` from per-class probabilities (via scikit-learn).

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred_prob: Probabilities with the class index on the LAST axis.
                      The original 3-D layout ``[:, :, c]`` still works;
                      other ranks (e.g. volumes or batches) are accepted too.
    :param c: Class label to score (index into the last axis).
    :param rounding_precision: Decimals kept before computing the ROC curve.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: Value returned by ``roc_auc_score`` on the flattened data.
    """
    # Round probability to reduce unnecessary thresholds.
    # ``[..., c]`` equals ``[:, :, c]`` for 3-D input but is rank-agnostic.
    prob = np.round(np.asarray(pred_prob)[..., c], rounding_precision)
    # Obtain ground truth set with associated class
    gt = np.equal(truth, c).astype(int)
    auc = roc_auc_score(gt.flatten(), prob.flatten())
    # Return AUC
    return auc

#-----------------------------------------------------#
#            Calculate : Balanced Accuracy            #
#-----------------------------------------------------#
"""
Formula:
    BACC = (Sensitivity + Specificity) / 2

References:
[1] Brodersen, K.H.; Ong, C.S.; Stephan, K.E.; Buhmann, J.M. (2010).
    The balanced accuracy and its posterior distribution.
    Proceedings of the 20th International Conference on Pattern Recognition, 3121-24.

[2] John. D. Kelleher, Brian Mac Namee, Aoife D’Arcy, (2015).
    Fundamentals of Machine Learning for Predictive Data Analytics: Algorithms, Worked Examples, and Case Studies.
    https://mitpress.mit.edu/books/fundamentals-machine-learning-predictive-data-analytics
"""
def calc_BalancedAccuracy(truth, pred, c=1, **kwargs):
    """
    Balanced accuracy of class ``c`` (one-vs-rest) via scikit-learn.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: ``balanced_accuracy_score`` of the flattened masks as np.float64.
    """
    # Flattened boolean masks, in (truth, pred) order
    flat_masks = [np.equal(labels, c).flatten() for labels in (truth, pred)]
    # Warnings emitted by scikit-learn during scoring are suppressed
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        score = balanced_accuracy_score(*flat_masks)
    return np.float64(score)

#-----------------------------------------------------#
#           Calculate : Adjusted Rand Index           #
#-----------------------------------------------------#
"""
Formula:
    ARI = (RI - Expected_RI) / (max(RI) - Expected_RI)

References:
[1] L. Hubert and P. Arabie, Comparing Partitions, Journal of Classification 1985
    https://link.springer.com/article/10.1007%2FBF01908075


[2] D. Steinley, Properties of the Hubert-Arabie adjusted Rand index,
    Psychological Methods 2004

[3] https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index
"""
def calc_AdjustedRandIndex(truth, pred, c=1, **kwargs):
    """
    Adjusted Rand Index of class ``c`` (one-vs-rest) via scikit-learn.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: ``adjusted_rand_score`` of the flattened masks as np.float64.
    """
    # Flattened boolean masks, in (truth, pred) order
    flat_masks = [np.equal(labels, c).flatten() for labels in (truth, pred)]
    # Warnings emitted by scikit-learn during scoring are suppressed
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        score = adjusted_rand_score(*flat_masks)
    return np.float64(score)

#-----------------------------------------------------#
#            Calculate : Confusion Matrix             #
#-----------------------------------------------------#
def calc_ConfusionMatrix(truth, pred, c=1, dtype=np.int64, **kwargs):
    """
    Binary (one-vs-rest) confusion matrix for class ``c``.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label treated as "positive".
    :param dtype: NumPy type the four counts are cast to (avoids overflow).
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: Tuple ``(tp, tn, fp, fn)``.
    """
    truth_pos = np.equal(truth, c)
    pred_pos = np.equal(pred, c)
    truth_neg = np.logical_not(truth_pos)
    pred_neg = np.logical_not(pred_pos)
    # (predicted condition, actual condition) pairs in tp, tn, fp, fn order
    cells = (
        (pred_pos, truth_pos),
        (pred_neg, truth_neg),
        (pred_pos, truth_neg),
        (pred_neg, truth_pos),
    )
    return tuple(
        np.logical_and(predicted, actual).sum().astype(dtype)
        for predicted, actual in cells
    )

#-----------------------------------------------------#
#              Calculate : True Positive              #
#-----------------------------------------------------#
def calc_TruePositive(truth, pred, c=1, **kwargs):
    """
    Count the true positives of class ``c``.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label treated as "positive".
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: Number of elements where both pred and truth equal ``c``.
    """
    # Obtain predicted and actual condition (the negated masks are not
    # needed for this count, so they are no longer computed)
    gt = np.equal(truth, c)
    pd = np.equal(pred, c)
    # Compute and return true positive
    return np.logical_and(pd, gt).sum()

#-----------------------------------------------------#
#              Calculate : True Negative              #
#-----------------------------------------------------#
def calc_TrueNegative(truth, pred, c=1, **kwargs):
    """
    Count the true negatives of class ``c``.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label treated as "positive".
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: Number of elements where neither pred nor truth equals ``c``.
    """
    truth_neg = np.logical_not(np.equal(truth, c))
    pred_neg = np.logical_not(np.equal(pred, c))
    return np.logical_and(pred_neg, truth_neg).sum()

#-----------------------------------------------------#
#              Calculate : False Positive             #
#-----------------------------------------------------#
def calc_FalsePositive(truth, pred, c=1, **kwargs):
    """
    Count the false positives of class ``c``.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label treated as "positive".
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: Number of elements where pred equals ``c`` but truth does not.
    """
    # Only the predicted-positive and actual-negative masks are needed
    pd = np.equal(pred, c)
    not_gt = np.logical_not(np.equal(truth, c))
    # Compute and return false positive
    return np.logical_and(pd, not_gt).sum()

#-----------------------------------------------------#
#              Calculate : False Negative             #
#-----------------------------------------------------#
def calc_FalseNegative(truth, pred, c=1, **kwargs):
    """
    Count the false negatives of class ``c``.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label treated as "positive".
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: Number of elements where truth equals ``c`` but pred does not.
    """
    # Only the actual-positive and predicted-negative masks are needed
    gt = np.equal(truth, c)
    not_pd = np.logical_not(np.equal(pred, c))
    # Compute and return false negative
    return np.logical_and(not_pd, gt).sum()


# Dice score calculator: -------------------------------------------------------
#-----------------------------------------------------#
#              Calculate : DSC Enhanced               #
#-----------------------------------------------------#
"""
    Reference:  Dominik Müller, Adrian Pfleiderer & Frank Kramer. (2022).
                miseval: a metric library for Medical Image Segmentation EVALuation.
                https://github.com/frankkramer-lab/miseval

    Custom Dice Similarity Coefficient implementation which returns 1.0,
    if two empty masks are compared.
    This allow rewarding models which correctly predict empty masks.
"""
def calc_DSC_Enhanced(truth, pred, c=1, **kwargs):
    """
    Dice Similarity Coefficient of class ``c`` that rewards empty matches.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: 1.0 when class ``c`` is absent from both masks, otherwise
             2 * |pred AND truth| / (|pred| + |truth|).
    """
    truth_mask = np.equal(truth, c)
    pred_mask = np.equal(pred, c)
    n_truth = truth_mask.sum()
    n_pred = pred_mask.sum()
    # Two empty masks agree perfectly.
    if n_truth == 0 and n_pred == 0:
        return 1.0
    # At least one count is non-zero here, so the denominator is positive.
    return 2*np.logical_and(pred_mask, truth_mask).sum() / (n_pred + n_truth)

#-----------------------------------------------------#
#              Calculate : DSC via Sets               #
#-----------------------------------------------------#
def calc_DSC_Sets(truth, pred, c=1, **kwargs):
    """
    Dice Similarity Coefficient of class ``c`` computed on boolean masks.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: 2 * |pred AND truth| / (|pred| + |truth|), or 0.0 when class
             ``c`` is absent from both masks.
    """
    truth_mask = np.equal(truth, c)
    pred_mask = np.equal(pred, c)
    denominator = pred_mask.sum() + truth_mask.sum()
    # Both masks empty: the ratio is undefined, report 0.0.
    if denominator == 0:
        return 0.0
    return 2*np.logical_and(pred_mask, truth_mask).sum() / denominator

#-----------------------------------------------------#
#             Calculate : DSC via ConfMat             #
#-----------------------------------------------------#
def calc_DSC_CM(truth, pred, c=1, **kwargs):
    """
    Dice Similarity Coefficient of class ``c`` from the confusion matrix.

    :param truth: Ground-truth labels.
    :param pred: Predicted labels.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: 2*TP / (2*TP + FP + FN), or 0.0 when that denominator is zero.
    """
    tp, _tn, fp, fn = calc_ConfusionMatrix(truth, pred, c)
    denominator = 2*tp + fp + fn
    # Guard against a zero denominator (class c absent everywhere).
    if denominator == 0:
        return 0.0
    return 2*tp / denominator

# Jaccard score calculator: -----------------------------------------------------
#-----------------------------------------------------#
#              Calculate : IoU via Sets               #
#-----------------------------------------------------#
def calc_IoU_Sets(truth, pred, c=1, **kwargs):
    """
    Intersection over Union (Jaccard) of class ``c`` on boolean masks.

    :param truth: Ground-truth labels, compared element-wise against ``c``.
    :param pred: Predicted labels, compared element-wise against ``c``.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: |pred AND truth| / (|pred| + |truth| - |pred AND truth|), or
             0.0 when the union is empty.
    """
    truth_mask = np.equal(truth, c)
    pred_mask = np.equal(pred, c)
    intersection = np.logical_and(pred_mask, truth_mask).sum()
    # Inclusion-exclusion: |A or B| = |A| + |B| - |A and B|
    union = pred_mask.sum() + truth_mask.sum() - intersection
    if union == 0:
        return 0.0
    return intersection / union

#-----------------------------------------------------#
#             Calculate : IoU via ConfMat             #
#-----------------------------------------------------#
def calc_IoU_CM(truth, pred, c=1, **kwargs):
    """
    Intersection over Union (Jaccard) of class ``c`` from the confusion matrix.

    :param truth: Ground-truth labels.
    :param pred: Predicted labels.
    :param c: Class label to score.
    :param kwargs: Extra keyword arguments are accepted and ignored.
    :return: TP / (TP + FP + FN), or 0.0 when that denominator is zero.
    """
    tp, _tn, fp, fn = calc_ConfusionMatrix(truth, pred, c)
    union = tp + fp + fn
    # Guard against a zero denominator (class c absent everywhere).
    if union == 0:
        return 0.0
    return tp / union

# Miseval Evaluate function: ---------------------------------------------------------
def mis_evaluate(truth, pred, metric, multi_class=False, n_classes=2, **kwargs):
    """
    Score a predicted segmentation against the ground truth.

    :param truth: Ground-truth label array.
    :param pred: Predicted label array.
    :param metric: Either a metric name looked up in ``metric_dict`` (first
                   as given, then upper-cased) or a callable invoked as
                   ``metric(truth, pred, c=<class>, **kwargs)``.
                   NOTE(review): ``metric_dict`` is not defined in this
                   section; it must exist at module level for string metrics.
    :param multi_class: Force per-class scoring even when ``n_classes`` is 2.
    :param n_classes: Number of classes; classes are ``0 .. n_classes-1``.
    :param kwargs: Forwarded unchanged to the metric function.
    :return: A single score for class 1 in binary mode, otherwise a NumPy
             array of length ``n_classes`` with one score per class.
    :raises KeyError: Metric name not found in ``metric_dict``.
    :raises ValueError: ``metric`` is neither a string nor callable, or a
                        binary evaluation receives a mask with >2 labels.
    """
    # Resolve the metric function
    if isinstance(metric, str):
        for candidate in (metric, metric.upper()):
            if candidate in metric_dict:
                eval_metric = metric_dict[candidate]
                break
        else:
            raise KeyError("Provided metric string not in metric_dict!" + \
                           " : " + metric)
    elif callable(metric):
        eval_metric = metric
    else:
        raise ValueError("Provided metric is neither a function nor a " + \
                         "string!" + " : " + str(metric))

    # Sanity checks that only apply to the two-class setting
    if n_classes == 2:
        if len(np.unique(truth)) > 2:
            raise ValueError("Segmentation mask (truth) contains more than 2 classes!")
        if len(np.unique(pred)) > 2:
            raise ValueError("Segmentation mask (pred) contains more than 2 classes!")

    # Multi-class mode -> one score per class label
    if multi_class or n_classes != 2:
        per_class = np.zeros((n_classes,))
        for label in range(n_classes):
            per_class[label] = eval_metric(truth, pred, c=label, **kwargs)
        return per_class

    # Binary mode -> score only the main class
    return eval_metric(truth, pred, c=1, **kwargs)

### My functions
This cell contains all the original functions used in the script.

In [None]:
# My functions:

def create_dir(healthyName, tumorName, parentDir):
    """
    Create the two class directories under ``parentDir`` if they are missing.

    Existing directories are left untouched. Missing intermediate
    directories (including ``parentDir`` itself) are created as well.

    :param healthyName: Healthy class directory name
    :param tumorName: Tumor class directory name
    :param parentDir: Path to the parent directory, class names excluded.
    :return: None
    """
    for class_name in (healthyName, tumorName):
        # exist_ok avoids the check-then-create race of exists() + mkdir()
        os.makedirs(os.path.join(parentDir, class_name), exist_ok=True)

def loadImage(name):
    """
    Function to load a Nifti image from a nii.gz zipped file using NiBabel.
    :param name: Path to the image file (e.g. "path/to/image/file.nii.gz").
    :return: Data = Image data extracted from the nii file, as returned by
             NiBabel's ``get_fdata()``.
    """
    # NOTE: the previous docstring used backslashes in a non-raw string,
    # which produced invalid escape sequences and turned "\t" into a tab.
    # Loading the nifti file using NiBabel:
    img = nib.load(name)
    # Getting the data of the image as numpy arrays:
    Data = img.get_fdata()
    return Data

def train(net, x, y, loss_function):
    """
    Run one optimisation step of ``net`` on a single batch.

    NOTE: the optimizer is read from the module-level name ``optimizer``
    (not defined in this function); it must exist before calling.

    :param net: Network to be trained.
    :param x: Input batch; its first dimension is the number of slices.
    :param y: Segmentation masks with the slice index on the LAST axis
              (height x width x n_slices); the first ``x.shape[0]`` slices
              are used as targets.
    :param loss_function: Callable ``loss_function(prediction, target)``.
    :return loss: loss value for the input batch (after the optimizer step).
    """
    # clearing the gradients in the network:
    optimizer.zero_grad()
    # predicting the output for the input data x:
    prediction = net(x)
    # Rearranging the masks from (H, W, N) to (N, 1, H, W). Batch size and
    # image size come from the data instead of a global / hard-coded 240,
    # and the target follows the prediction's device and dtype.
    n_items = x.shape[0]
    Y = (
        torch.as_tensor(y)[:, :, :n_items]
        .permute(2, 0, 1)
        .unsqueeze(1)
        .to(device=prediction.device, dtype=prediction.dtype)
    )
    # Evaluating the loss:
    loss = loss_function(prediction, Y)
    # Calculating the gradient:
    loss.backward()
    # Adjusting the weights with the computed gradients:
    optimizer.step()
    return loss

def validate(net, x, y, loss_function, loss_list):
    """
    Evaluate a trained network on validation data without tracking gradients.

    :param net: Trained network.
    :param x: Input batch; its first dimension is the number of slices.
    :param y: Segmentation masks with the slice index on the LAST axis
              (height x width x n_slices); the first ``x.shape[0]`` slices
              are used as targets.
    :param loss_function: Callable ``loss_function(prediction, target)``.
    :param loss_list: List to which the computed loss is appended in place.
    :return: None
    """
    with torch.no_grad():
        # Prediction on the validation data:
        prediction = net(x)
        # Rearranging the masks from (H, W, N) to (N, 1, H, W). The number
        # of slices is the batch size x.shape[0] (the previous loop used the
        # unrelated global ``n_batches``), and the image size comes from y.
        n_items = x.shape[0]
        Y = (
            torch.as_tensor(y)[:, :, :n_items]
            .permute(2, 0, 1)
            .unsqueeze(1)
            .to(device=prediction.device, dtype=prediction.dtype)
        )
        # loss evaluation:
        loss = loss_function(prediction, Y)
        loss_list.append(loss)

def test(net, x, y, n_classes):
    """
    Using the package 'miseval' (https://github.com/frankkramer-lab/miseval) to evaluate the segmentation goodness.
    Metrics implemented:
    - Dice score: calculated as 2*(# of overlapping pixels)/(# pixels real segmentation + # pixels predicted segmentation).
    - Jaccard score: calculated as segmentation overlap / segmentation union.
    - Accuracy, precision, sensitivity and specificity over the whole input, one value per class.
    :param net: Trained network.
    :param x: Input data; its first dimension is the number of slices.
    :param y: Input segmentation masks with the slice index on the last axis (height x width x n_slices).
    :param n_classes: # of classes to consider. It is either 2 (simplified segmentation) or 4 (complete segmentation).
    :return: (dice_list, Jaccard_list, Accuracy, Precision, Sensitivity, Specificity, pred, Y) where dice_list and
             Jaccard_list hold one entry per slice, pred is the network output and Y the rearranged reference masks.
    """
    # Initializing the lists to save the per-slice scores:
    dice_list = []
    Jaccard_list = []
    n_items = x.shape[0]
    # Inference only: no gradient tracking is needed for evaluation.
    with torch.no_grad():
        pred = net(x)
    # Rearrangement of the reference segmentation masks into the same
    # (N, 1, H, W) layout as the network output; the image size is taken
    # from the masks instead of being hard-coded to 240x240.
    masks = torch.as_tensor(y)
    Y = torch.zeros(n_items, 1, masks.shape[0], masks.shape[1])
    Y[:, 0] = masks[:, :, :n_items].permute(2, 0, 1)
    # The miseval functions are NumPy-based: hand them plain arrays.
    truth_np = Y.numpy()
    pred_np = pred.detach().cpu().numpy()
    # Since n_classes is checked to be either 2 or 4 earlier in the code, no
    # additional control on n_classes is made here.
    if n_classes == 2:
        per_slice_kwargs = {}
    else:
        per_slice_kwargs = {"multi_class": True, "n_classes": n_classes}
    # Calculates the dice and jaccard scores for each brain slice. Slices
    # are indexed along dimension 0 (the previous loop used x.shape[2],
    # which is the image height).
    for i in range(n_items):
        dice_list.append(mis_evaluate(truth_np[i, 0], pred_np[i, 0], metric='DSC', **per_slice_kwargs))
        Jaccard_list.append(mis_evaluate(truth_np[i, 0], pred_np[i, 0], metric='IoU', **per_slice_kwargs))

    # Whole dataset metrics:
    Accuracy = mis_evaluate(truth_np, pred_np, metric="Accuracy", multi_class=True, n_classes=n_classes)
    Precision = mis_evaluate(truth_np, pred_np, metric="Precision", multi_class=True, n_classes=n_classes)
    Sensitivity = mis_evaluate(truth_np, pred_np, metric="Sensitivity", multi_class=True, n_classes=n_classes)
    Specificity = mis_evaluate(truth_np, pred_np, metric="Specificity", multi_class=True, n_classes=n_classes)

    return dice_list, Jaccard_list, Accuracy, Precision, Sensitivity, Specificity, pred, Y


def drawContour(m,s,c,RGB):
    """
    Draw the edges of region ``c`` of segmented image ``s`` onto ``m``.

    :param m: Array that is painted in place and also returned.
    :param s: Segmented image offering ``point`` and ``filter`` methods
              (presumably a PIL image - confirm against the caller).
    :param c: Pixel value identifying the region whose outline is drawn.
    :param RGB: Colour assigned to the edge locations in ``m``.
    :return: ``m`` with the edge locations set to ``RGB``.
    """
    # Region "c" becomes 255, every other pixel a falsy value (black)
    region_only = s.point(lambda pixel: pixel == c and 255)

    # Edge detection on the isolated region, converted to a NumPy array
    edge_pixels = np.array(region_only.filter(ImageFilter.FIND_EDGES))

    # Colour every non-zero edge location of the main image
    edge_locations = np.nonzero(edge_pixels)
    m[edge_locations] = RGB
    return m

### U-Net architecture
SimpleUnet implements a simple U-net architecture made of 5 groups of layers, organized as follows:
- 2 groups in the encoding part;
- 1 group in the bridging part;
- 2 groups in the decoding part.

BiggerUnet is a bigger version of SimpleUnet, where the number of layer groups is increased from 2 to 4 for the encoding and the decoding part. This results in a deeper net that is able to generate, at the 5th group of layers (the bridging group between the encoder and the decoder), an array with 1024 'channels' instead of the 256 of the simpler architecture. The idea is to compare the performance of the two to understand what is the benefit of a deeper network in this segmentation context.

Both networks take as input an array of dimension [batch_dimension x n_channels x 240 x 240], where batch_dimension = 46 and n_channels = 4, because the four scan types of the same slice are merged into a single 3D tensor. Both output a single-channel segmentation mask (hence the number of outputs = 1 in the output layer) that is either binary or a 4-class mask, according to the user's choice.

In [None]:
class SimpleUnet(nn.Module):
    """
    Small U-Net: two encoder groups, one bridge group and two decoder groups.

    The input spatial size must be divisible by 4 (two 2x2 poolings) so that
    the up-sampled feature maps match the skip connections.

    :param n_channels: Number of channels of the input tensor.
    :param n_classes: 2 for a binary mask (output values 0/1); any other
                      value scales the output to the range 0..3.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()

        self.n_classes = n_classes
        # Encoder:
        # Layer 1:
        self.l11 = nn.Conv2d(n_channels, 64, kernel_size=3, padding=1)  # 240x240
        self.l12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Layer 2:
        self.l21 = nn.Conv2d(64, 128, kernel_size=3, padding=1) # 120x120
        self.l22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # Layer 3 - bridge between encoder and decoder:
        self.l31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.l32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Decoder:
        # Layer 1:
        self.upConv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) # undoes the down-sampling of the max pooling layer, hence the same kernel_size and stride
        self.l41 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.l42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # Layer 2:
        self.upConv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.l51 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.l52 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # Layer 3: output layer
        self.outConv = nn.Conv2d(64, 1, kernel_size=1)

        # Activation function:
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Compute the segmentation mask for a batch.

        :param x: Tensor of shape (batch, n_channels, H, W).
        :return: Float tensor of shape (batch, 1, H, W) with integer-valued
                 entries (0/1 for n_classes == 2, otherwise 0..3).
        """
        # if batch normalization is needed, the related layer can be added in the __init__ and called to process the output that is obtained
        # after the maxpooling to normalize the input that is provided to the net layer.
        # the batch normalization layer would have to be defined for every layer group as torch.nn.BatchNorm2d(n_features), where
        # n_features is the number of channels of the input tensor
        x1 = self.relu(self.l12(self.relu(self.l11(x)))) # features of the first group, before pooling; kept for the skip connection to the decoder
        x2 = self.relu(self.l22(self.relu(self.l21(self.pool1(x1))))) # features extracted from group 2
        x3 = self.relu(self.l32(self.relu(self.l31(self.pool1(x2))))) # bridge
        x4 = self.upConv1(x3)
        x4 = torch.cat((x4, x2), 1) # skip connection: concatenation along the channel dimension
        x5 = self.upConv2(self.relu(self.l42(self.relu(self.l41(x4)))))
        x5 = torch.cat((x5, x1), 1) # skip connection: concatenation along the channel dimension
        out = self.sigmoid(self.outConv(self.relu(self.l52(self.relu(self.l51(x5))))))
        if self.n_classes != 2:
            out = out*3
        # torch.round has a zero gradient, so returning it directly would
        # give every weight a zero gradient and the network would never
        # learn. Straight-through estimator: the forward value is exactly
        # round(out), while the backward pass uses the gradient of ``out``.
        rounded = out + (torch.round(out) - out).detach()
        return rounded.float() # just to be sure that a tensor with float values is returned


class BiggerUnet(nn.Module):
    """
    Deeper U-Net: four encoder groups, one bridge group (1024 channels) and
    four decoder groups.

    The input spatial size must be divisible by 16 (four 2x2 poolings) so
    that the up-sampled feature maps match the skip connections.

    :param n_channels: Number of channels of the input tensor.
    :param n_classes: 2 for a binary mask (output values 0/1); any other
                      value scales the output to the range 0..3.
    """

    def __init__(self, n_channels, n_classes=4):
        # Was super(SimpleUnet, self), which raises TypeError because a
        # BiggerUnet instance is not a SimpleUnet.
        super().__init__()

        self.n_classes = n_classes

        # Encoder:
        # Layer group 1:
        self.l11 = nn.Conv2d(n_channels, 64, kernel_size=3, padding=1) # Img size: 240x240x64
        self.l12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # Layer group 2:
        self.l21 = nn.Conv2d(64, 128, kernel_size=3, padding=1) # 120x120x128
        self.l22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # Layer group 3:
        self.l31 = nn.Conv2d(128, 256, kernel_size=3, padding=1) # 60x60x256
        self.l32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        # Layer group 4:
        self.l41 = nn.Conv2d(256, 512, kernel_size=3, padding=1) # 30x30x512
        self.l42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        # Layer group 5 - bridge between encoder and decoder:
        self.l51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1) # 15x15x1024
        self.l52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)

        # Decoder:
        # Layer group 1:
        self.upConv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.l61 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.l62 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        # Layer group 2:
        self.upConv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.l71 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.l72 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        # Layer group 3:
        self.upConv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.l81 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.l82 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # Layer group 4:
        self.upConv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.l91 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.l92 = nn.Conv2d(64, 64, kernel_size=3, padding=1)

        # Output layer
        self.outConv = nn.Conv2d(64, 1, kernel_size=1)

        # Activation functions (the sigmoid was previously assigned to
        # self.relu, overwriting the ReLU and leaving self.sigmoid undefined):
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        # Pooling layer:
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """
        Compute the segmentation mask for a batch.

        :param x: Tensor of shape (batch, n_channels, H, W).
        :return: Float tensor of shape (batch, 1, H, W) with integer-valued
                 entries (0/1 for n_classes == 2, otherwise 0..3).
        """
        # if batch normalization is needed, the related layer can be added in the __init__ and called to process the output that is obtained
        # after the maxpooling to normalize the input that is provided to the net layer.
        # the batch normalization layer would have to be defined for every layer group as torch.nn.BatchNorm2d(n_features), where
        # n_features is the number of channels of the input tensor
        # Encoder (x1..x4 are kept for the skip connections) and bridge (x5):
        x1 = self.relu(self.l12(self.relu(self.l11(x))))
        x2 = self.relu(self.l22(self.relu(self.l21(self.pool(x1)))))
        x3 = self.relu(self.l32(self.relu(self.l31(self.pool(x2)))))
        x4 = self.relu(self.l42(self.relu(self.l41(self.pool(x3)))))
        x5 = self.relu(self.l52(self.relu(self.l51(self.pool(x4)))))
        # Decoder: each stage up-samples with its OWN transposed convolution
        # (upConv1..upConv4) and concatenates the matching encoder features
        # along the channel dimension.
        x6 = self.upConv1(x5)
        x6 = torch.cat((x6, x4), 1)
        x7 = self.upConv2(self.relu(self.l62(self.relu(self.l61(x6)))))
        x7 = torch.cat((x7, x3), 1)
        x8 = self.upConv3(self.relu(self.l72(self.relu(self.l71(x7)))))
        x8 = torch.cat((x8, x2), 1)
        x9 = self.upConv4(self.relu(self.l82(self.relu(self.l81(x8)))))
        x9 = torch.cat((x9, x1), 1)
        out = self.sigmoid(self.outConv(self.relu(self.l92(self.relu(self.l91(x9))))))
        if self.n_classes != 2:
            out = out*3
        # torch.round has a zero gradient, so returning it directly would
        # give every weight a zero gradient and the network would never
        # learn. Straight-through estimator: the forward value is exactly
        # round(out), while the backward pass uses the gradient of ``out``.
        rounded = out + (torch.round(out) - out).detach()
        return rounded.float() # just to be sure that a tensor with float values is returned

### Local data pre-processing
This snippet of code was used to preprocess the dataset provided by Kaggle to obtain the tensors with the segmentation masks and the four scan types for each patient. Because of memory limitations, only the first 48 patients out of the 1258 in the dataset were considered. This portion of code also implements the calculation of the areas of the tumor lesions per slice for each class, which can be used as an additional performance evaluator.

In [None]:
# PREPROCESSING PART - done in local folders because of the size of the data:
#
# if __name__ == "__main__":
#    voxel_dim = 1 # in mm^3
#    Datapath = "C:\\Users\\jzatt\\Desktop\\PhD\\01_Maral_Course_MachineLearning\\00_Project\\01_Data\\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData\\ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData"
    # Entering in the dir containing the patients scan folders:
#    os.chdir(Datapath)
    # Getting the folders names as a list:
#    IMGFolders = os.listdir(os.getcwd())
    # Each folder corresponds to a different patient, and contains:
    # - segmentation info (i.e. masks with each pixel's class)
    # - T1 whole-brain scan
    # - T1 gadolinium contrasted whole-brain scan
    # - T2 contrast whole-brain scan
    # - T2 FLAIR whole-brain scan

#    volumes_dict = {}
#    for i in range(len(IMGFolders)): # iterates over all the folders in the main directory
#        if IMGFolders[i] == "BraTS-GLI-00049-000":
#            break # breaking the loop on the folders because of memory issues.
        # Entering each patient scans folder:
#       os.chdir(IMGFolders[i])
#        print("Working on folder:\n\t- "+ str(IMGFolders[i]) +":" )
        # getting the list of nifti.gz zipped folders (each contains one .nii with the whole-brain scan):
#        fnames = os.listdir(os.getcwd())
        # Iterating over the patients gz folders:
#        for name in fnames:
#            if "t2w.nii" in name or "t2f.nii" in name or "t1n.nii" in name or "t1c.nii" in name or "seg.nii" in name: # if the nii.gz folder is as expected
#                print("\t Working on file: "+ str(name))
#                imData = loadImage(name)
#-------------- WORKING on the segmentation MASKS:
#                if "seg.nii" in name: # this identifies only the segmentation masks files
#                    if i == 0: # i.e., if str(IMGFolders[i]) == str(IMGFolders[0]): if we are in the folder of the first patient, initialize the scans masks tensor:
#                        SegIM = torch.tensor(imData) # puts all the segmentation masks in the tensor
#                    else: # if it's not the first patient, just append to the existing tensor:
#                        SegIM = torch.cat((SegIM, torch.tensor(imData)), 2) # concatenates over dimension 3 (= # of scans) the segmentation scans
#------------------ Creating the labels array for ease of identification of healthy and tumor scans
#                    SegClass = []
#                    for ind in range(imData.shape[2]): # iterating over the number of scans: each nifti / segmentated file has 240x240 pixels & 155 scans.
#                        Nclasses = np.unique(imData[:,:,ind]) # looks at the single scans segmentation to identify the number of classes in it
#                        if len(Nclasses) > 1: # since the scans either contain healthy tissue or healthy + tumor tissues, if unique returns only one value it's necessary a 0 (== healthy tissue indicator). If > 1, we have at least one of the three tumor classes in it.
#                            SegClass.append(1) # identifies tumor scans
#                            for tClass in range(1,4): # saves the volumes of each tumor class in an array
#                                if (ind+SegIM.shape[2]-imData.shape[2]) not in volumes_dict:
#                                    volumes_dict[ind+SegIM.shape[2]-imData.shape[2]] = {}
#                                Label = "CL_"+str(tClass)
#                                indexes = np.where(imData[:,:,ind]==float(tClass)) # gets the indexes where the segmentation is equal to 1.0, 2.0 or 3.0. It returns an array with two elements: the indexes on the x and y dimensions, so counting one of them equals to counting the total number of pixels of that class in the current scan
#                                volumes_dict[ind+SegIM.shape[2]-imData.shape[2]][Label] = len(indexes[0])*voxel_dim # sum(imData[:,:,i][imData[:,:,i]==tClass])*voxel_dim # double sum bc the == creates a boolean mask; the first sum counts per columns, the second per rows
#                        else:
#                            SegClass.append(0) # identifies healthy scans
#                    if i==0: # if looking at the first patient, initialize the reduced labels tensor inserting the first value:
#                        Labels = torch.tensor(SegClass) # list with the current patient labels
#                    else: # concatenates SegClass to the tensor of the reduced labels
#                        Labels = torch.cat((Labels, torch.tensor(SegClass)), 0)
#-------------- WORKING on the ACTUAL SCANS:
#                else: # if the scans are for T1 or T2 images, contrast and FLAIR comprised:
#                    if i == 0: # if it's the first patient, initialize tensors:
#                        if "t1c.nii" in name:
#                            IMGsT1C = torch.tensor(imData) # T1 contrast tensor
#                        elif "t1n.nii" in name:
#                            IMGsT1N = torch.tensor(imData) # T1 normal tensor
#                        elif "t2f.nii" in name:
#                            IMGsT2F = torch.tensor(imData) # T2 FLAIR tensor
#                        else: # if "t2w-nii" in name:
#                            IMGsT2W = torch.tensor(imData) # T2 contrast tensor
#                    else: # concatenate scans to already existing tensor:
#                        if "t1c.nii" in name:
#                            IMGsT1C = torch.cat((IMGsT1C, torch.tensor(imData)), dim=-1)
#                        elif "t1n.nii" in name:
#                            IMGsT1N = torch.cat((IMGsT1N, torch.tensor(imData)), dim=-1)
#                        elif "t2f.nii" in name:
#                            IMGsT2F = torch.cat((IMGsT2F, torch.tensor(imData)), dim=-1)
#                        else: # if "t2w-nii" in name:
#                            IMGsT2W = torch.cat((IMGsT2W, torch.tensor(imData)), dim=-1)
#                print("\t Done!")
#            else:
#                print("WARNING!\n\tUnable to recognize input nifti file for "+str(name)+".\n\tExpected the gz folders to end
#                with wither t1c.nii, t1n.nii, t2f.nii or t2w.nii. Skipping this folder...")
#        os.chdir("../")

#-- Normalization step:
#    for s in range(len(Labels)):
        # normalization per single scan:
#        if (IMGsT1C[:,:,s].max()-IMGsT1C[:,:,s].min()) != 0:
#            IMGsT1C[:,:,s] = (IMGsT1C[:,:,s]-IMGsT1C[:,:,s].min())/(IMGsT1C[:,:,s].max()-IMGsT1C[:,:,s].min())
#        if (IMGsT1N[:,:,s].max()-IMGsT1N[:,:,s].min()) != 0:
#            IMGsT1N[:,:,s] = (IMGsT1N[:,:,s]-IMGsT1N[:,:,s].min())/(IMGsT1N[:,:,s].max()-IMGsT1N[:,:,s].min())
#        if (IMGsT2W[:,:,s].max()-IMGsT2W[:,:,s].min()) != 0:
#            IMGsT2W[:,:,s] = (IMGsT2W[:,:,s]-IMGsT2W[:,:,s].min())/(IMGsT2W[:,:,s].max()-IMGsT2W[:,:,s].min())
#        if (IMGsT2F[:,:,s].max()-IMGsT2F[:,:,s].min()) != 0:
#            IMGsT2F[:,:,s] = (IMGsT2F[:,:,s]-IMGsT2F[:,:,s].min())/(IMGsT2F[:,:,s].max()-IMGsT2F[:,:,s].min())

#-- Saving step:
#    with open('Lesion_Ref_Volumes_All_Dataset.json', 'w') as outfile:
#        json.dump(volumes_dict, outfile)
#    torch.save(Labels, "Simplified_Labels_All_dataset.pt") # tensor of simplified labels for each scan. The position in this tensor identifies the third dimension (== # of scan) in the segmentation and images tensors, i.e.: label[156] corresponds to entry [:,:,156] for any of the 3D tensors below.
#    torch.save(SegIM.float(), "Ref_Segmentation.pt") # whole dataset tensor of reference segmentation
#    torch.save(IMGsT1C.float(), "T1C_All_dataset.pt") # whole dataset tensor of T1 contrast imgs
#    torch.save(IMGsT1N.float(), "T1N_All_dataset.pt") # whole dataset tensor of T1 non-contrast imgs
#    torch.save(IMGsT2F.float(), "T2F_All_dataset.pt") # whole dataset tensor of T2 FLAIR imgs
#    torch.save(IMGsT2W.float(), "T2W_All_dataset.pt") # whole dataset tensor of T2 contrast imgs
#
# END OF local PREPROCESSING ---------------------------------------------------

### Dataset loading with google drive mounting
The results of the local pre-processing were uploaded to Google Drive. The following lines of code mount the drive on Colab, which allows loading the data directly from the drive without having to implement REST request functions.

Link to drive folder containing the data: https://drive.google.com/drive/folders/13Ne8UOwpfebXg6mvf7ZMPAhSaBnaeugX?usp=drive_link

In [None]:
# Mount Google Drive inside the Colab runtime so that the pre-processed tensors
# can be read from it like regular files. The mount point is /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Load the pre-processed tensors from the drive folder. The data refer to patients 000 to 048.
# All tensors share the same slice axis: the position in Labels identifies the third
# dimension (== # of scan) of the segmentation and image tensors, i.e. Labels[156]
# corresponds to entry [:,:,156] of any of the 3D tensors loaded below.
DataPath = "./gdrive/MyDrive/BME630/"
_tensor_files = (
    "Simplified_Labels_All_dataset.pt",  # simplified labels, one per scan
    "Ref_Segmentation.pt",               # reference segmentation
    "T1C_All_dataset.pt",                # T1 contrast imgs
    "T1N_All_dataset.pt",                # T1 non-contrast imgs
    "T2F_All_dataset.pt",                # T2 FLAIR imgs
    "T2W_All_dataset.pt",                # T2 contrast imgs
)
# The files are loaded one after the other, in the order listed above.
Labels, Seg, T1C, T1N, T2F, T2W = (torch.load(DataPath + _fname) for _fname in _tensor_files)

### Dividing the dataset in training, test and validation

In [None]:
# The network input is made of 4 scans of the same patient, obtained with different methods
# (t1c, t1n, t2w, t2f) but referring to the same brain slice. To separate training, validation
# and test sets while retaining this pairing, the split is performed on the indexes of the scan
# tensors obtained from the preprocessing step, i.e. on an array going from 0 to len(Labels)-1.
# The stratification is performed on the Labels array, which contains the simplified class
# label for each slice.
IDX = np.arange(0, len(Labels))
Idx_train, Idx_test, Labels_train, Labels_test = train_test_split(IDX, Labels, test_size=0.2, stratify=Labels, random_state=0)
# Dividing the training set in train and validation sets.
# IDX_train holds POSITIONS inside the first-stage training subset (0 .. len(Labels_train)-1),
# not dataset indexes: the positions returned by the second split must therefore be mapped
# back through Idx_train before being used to index the full tensors. Using the positions
# directly would select slices that do not match their labels and that may belong to the test set.
IDX_train = np.arange(0, len(Labels_train))
_pos_train, _pos_val, Labels_train, Labels_val = train_test_split(IDX_train, Labels_train, test_size=0.2, stratify=Labels_train, random_state=0)
# Idx_val must be computed first, because Idx_train is overwritten on the next line.
Idx_val = Idx_train[_pos_val]
Idx_train = Idx_train[_pos_train]

# Extract, for each subset, the slices selected by its index array from the four
# image tensors and from the reference segmentation. The labels vector is shared
# by the 4 image types of a subset (Labels_train, Labels_val, Labels_test).
def _select_slices(slice_idx):
    """Return (T1C, T1N, T2F, T2W, Seg) restricted to `slice_idx` along the third axis."""
    return tuple(volume[:, :, slice_idx] for volume in (T1C, T1N, T2F, T2W, Seg))


# Training set:
T1C_train, T1N_train, T2F_train, T2W_train, Seg_train = _select_slices(Idx_train)

# Validation set:
T1C_val, T1N_val, T2F_val, T2W_val, Seg_val = _select_slices(Idx_val)

# Test set:
T1C_test, T1N_test, T2F_test, T2W_test, Seg_test = _select_slices(Idx_test)

# Summary of the dataset: number of scans per subset, slice size and class balance.
def _count_label(label_vec, value):
    """Number of entries of `label_vec` equal to `value`."""
    return len(np.where(label_vec == value)[0])


print("---- Dataset info ----")
# For each subset: total number of scans (4 image types) followed by scans per image type.
_subset_sizes = []
for _labels in (Labels, Labels_train, Labels_val, Labels_test):
    _n_slices = _labels.shape[0]
    _subset_sizes.extend((_n_slices * 4, _n_slices))
print("Dataset dimensions per type of scan:\n\t- Whole dataset: {} scans ({} per image type);\n\t- Training set: {} scans ({} per image type);\n\t- Validation set: {} scans ({} per image type);\n\t- Test set: {} scans ({} per image type)".format(*_subset_sizes))

_first_slice = Seg[:, :, 0]
print("Each dataset is further divided in 4 subsets, because we have: T1, T1 contrast, T2 an T2 FLAIR scans. Each scan is {} x {} pixels.".format(_first_slice.shape[0], _first_slice.shape[1]))

print("\n---- Classes distribution in the dataset per type of scan ----\nOverall there are:\n\t- {} scans of class 0 (= healthy tissue)\n\t- {} scans of class 1 (= tumor lesions)".format(_count_label(Labels, 0.0), _count_label(Labels, 1.0)))

# Class 0 / class 1 counts for train, validation and test, in this order.
_class_counts = []
for _labels in (Labels_train, Labels_val, Labels_test):
    _class_counts.extend((_count_label(_labels, 0.0), _count_label(_labels, 1.0)))
print("Concerning the datasets, there are:\n\t- Training set: \n\t\t- {} class 0 scans\n\t\t- {} class 1 scans\n\t- Validation set: \n\t\t- {} class 0 scans\n\t\t- {} class 1 scans\n\t- Test set: \n\t\t- {} class 0 scans\n\t\t- {} class 1 scans".format(*_class_counts))

# Visual check: the same slice index applied to the four scan-type tensors must return
# the same brain slice of the same patient. One 2x2 figure is drawn per patient.
print("\nPlotting of the same slice from different scans to prove that the same scan id applied to the 4 scan types tensors returns the exact same brain slice of the same patient:")
# (title tag, tensor) pairs, listed in subplot order (221, 222, 223, 224).
_modalities = (("T1N", T1N), ("T1C", T1C), ("T2", T2W), ("T2F", T2F))
# (index along the slice axis, patient id used in the titles). Each patient contributes
# 155 slices, so index 210 is scan 55 of the second patient in the tensors.
for _slice_idx, _patient in ((55, "000"), (210, "002")):
    plt.figure(figsize=(6,4))
    for _position, (_tag, _volume) in enumerate(_modalities, start=1):
        plt.subplot(2, 2, _position)
        plt.imshow(_volume[:,:,_slice_idx], cmap="gray")
        plt.title("{} - patient {}, scan 55".format(_tag, _patient))
    # plt.show() is called once per figure, after all four panels have been drawn:
    # calling it after every subplot finalises the figure each time and splits the
    # 2x2 grid into four separate single-panel figures.
    plt.show()

### Grouping data for correct input size and batch definition
The following lines rearrange the data in a format that is suitable for the network, i.e. in the [tensor_depth, #_channels, img_height, img_width] format.

In [None]:
NumScans = 4
# Height and width of a single slice, read from the first T1C slice.
_slice_height, _slice_width = T1C[:, :, 0].shape[0], T1C[:, :, 0].shape[1]


def _pack_modalities(n_slices, t1n, t1c, t2w, t2f):
    """Group the four scan types of each brain slice into one 4D tensor.

    The result has shape [scan_#, n_channels, height, width], which is the input
    layout required by the network. Channel order: T1, T1 contrast, T2, T2 FLAIR.
    """
    packed = torch.zeros(n_slices, NumScans, _slice_height, _slice_width)
    for channel, volume in enumerate((t1n, t1c, t2w, t2f)):
        # volume is [height, width, scan_#]: move the scan axis to the front so that
        # packed[i, channel] receives volume[:, :, i] for every i < n_slices.
        packed[:, channel, :, :] = volume[:, :, :n_slices].permute(2, 0, 1)
    return packed


# 4D tensors for the three subsets:
X_train = _pack_modalities(len(Labels_train), T1N_train, T1C_train, T2W_train, T2F_train)
X_val = _pack_modalities(len(Labels_val), T1N_val, T1C_val, T2W_val, T2F_val)
X_test = _pack_modalities(len(Labels_test), T1N_test, T1C_test, T2W_test, T2F_test)

### Metric Dictionary for miseval:
Because the miseval functions are imported manually rather than through the package, we need to include this snippet of code, which maps each metric name accepted by the function "evaluate" (which computes the metrics on the segmentation images) to the function that implements it.

In [None]:
# Lookup table from metric name to the function that computes it. Several names
# (long form, abbreviation, synonym) may point to the same function. The groups
# below are listed so that the keys are inserted in the same order as before.
_metric_aliases = (
    (("TruePositive",), calc_TruePositive),
    (("TrueNegative",), calc_TrueNegative),
    (("FalsePositive",), calc_FalsePositive),
    (("FalseNegative",), calc_FalseNegative),
    (("TP",), calc_TruePositive),
    (("TN",), calc_TrueNegative),
    (("FP",), calc_FalsePositive),
    (("FN",), calc_FalseNegative),
    (("DSC", "Dice", "DiceSimilarityCoefficient"), calc_DSC_Enhanced),
    (("IoU", "Jaccard", "IntersectionOverUnion"), calc_IoU_Sets),
    (("ACC", "Accuracy"), calc_Accuracy_Sets),
    (("AUC", "AUC_trapezoid"), calc_AUC_trapezoid),
    (("Sensitivity", "SENS", "TPR", "TruePositiveRate", "Recall"), calc_Sensitivity_Sets),
    (("SPEC", "Specificity", "TNR", "TrueNegativeRate"), calc_Specificity_Sets),
    (("PREC", "Precision"), calc_Precision_Sets),
)
metric_dict = {}
for _names, _metric_fn in _metric_aliases:
    for _name in _names:
        metric_dict[_name] = _metric_fn

### Network run: Training phase
The training phase here is implemented with the validation set regularization, i.e. an early stopping condition is defined when the loss value on the validation data for one epoch is greater than the loss two epochs before. This check is done from the 3rd training epoch on.

The training is performed using three for loops: on the learning rate value, on the epochs and on the training set batches, respectively from outer to inner loop. This is because the idea is to train the same network for different learning rates to identify the most suitable for our architecture. The training is divided in 100 batches, which means that every batch has 46 samples.

In [None]:
#---- Training phase:
# User definition of # classes to identify. The prompt is repeated until the user
# types either 2 or 4.
print("\nPlease selected the number of classes you want to consider between:\n\t- 2 (simplified segmentation)\n\t- 4 (complete segmentation)")
while True:
    uin = input("> ")
    try:
        n_classes = int(uin)
    except ValueError:
        # Non-numeric input (e.g. an empty line or a word): ask again instead of crashing.
        n_classes = None
    if n_classes in (2, 4):
        break
    print("It looks like something went wrong. Please retry and ensure you type either 2 or 4.")
if n_classes == 2: # converts the 4-classes segmentation masks in binary masks (in place).
    Seg_train[Seg_train!=0.0]=1.0
    Seg_val[Seg_val!=0.0]=1.0
    Seg_test[Seg_test!=0.0]=1.0
print("Running the network using {} classes".format(n_classes))

# This will allow the network to run on GPUs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
print("\n---- Run settings info: ----\nRunning using: ", device)

# learning rate: 10 evenly spaced values between 0.005 and 0.02, rounded to 4 decimals.
learning_rate = np.round(np.linspace(0.005, 0.02, 10),4)
# momentum:
mom = 0.9 # momentum applied to the stochastic gradient descent algorithm; helps to direct the descent towards the direction of steeper decrease,
# allowing the use of higher learning rates and the algorithm to converge faster. 0.9 is a commonly used value, but it can be optimized as one of the network parameters.

# Loss function: Cross Entropy loss. Preferred because it allows to evaluate the loss even when we deal with more than 2 classes,
# which might be the case in this application. A weight vector can be provided to the loss function in case the dataset has an unbalanced representation of the classes.
# This method is applied only if the performance without the class weight is suboptimal.
# if n_classes == 2:
#  class_weights = torch.tensor([1, 5])
# else:
#  class_weights = torch.tensor([1, 3, 5, 20]) # the first is for class 0, the second for cl 1, the third for cl 2 and the fourth for cl 3. Could be improved by using a weight that is proportional to each class size.
# For now, the weight increases from class 0 to 3 because the classes are progressively less represented in the dataset. class 1 and class 2 (tumor core and edema) have closer weights because they make the majority of tumor tissues.
loss_function = nn.CrossEntropyLoss() # with weights: loss_function = nn.CrossEntropyLoss(class_weights)
loss_function.to(device)

# Number of batches:
n_batches = 100
elem_in_batch = round(X_train.shape[0]/n_batches) # elements in a single training batch
# Training epochs:
epochs = 300

# Since we are not implementing batches on the validation, we can push the related data to the device already.
# Tensor.to() returns the moved tensor instead of modifying it in place, so the result must be re-assigned.
X_val = X_val.to(device)      # validation data
Seg_val = Seg_val.to(device)  # validation segmentation

# Training of one freshly initialised network per learning rate. For each learning rate:
# train for up to `epochs` epochs with early stopping on the validation loss, plot and save
# the loss curves, then save the model weights.
print("\n---- Network RUN - parameters optimization ----")
for lr in learning_rate:
    # Net creation:
    # Simpler architecture:
    net_name = "SimpleUnet"
    net = SimpleUnet(4, n_classes)
    # More complex architecture: uncomment the next two lines and comment the previous two to implement the second network architecture.
    # net = BiggerUnet(4, n_classes)
    # net_name = "BiggerUnet"

    net.to(device)
    # Number of trainable parameters:
    total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    # Printing network info:
    print("Network info:")
    print("\t- Number of classes: {}\n\t- Number of training batches: {}\n\t- Elements in a single batch: {}\n\t- Loss function: Cross Entropy\n\t- Momentum: {}\n\t- Number of trainable parameters: {}\n\t- Current learning rate: {}".format(n_classes, n_batches, elem_in_batch, mom, total_params,lr))
    print("Network architecture:\n", net)
    # optimizer: SGD
    # NOTE(review): `optimizer` is not passed to train() below; presumably train() reads it
    # as a global name - confirm against the definition of train().
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=mom)
    train_loss = []
    val_loss = []
    for epoch in range(epochs):
        if epoch % 10 == 0:
            print("Epoch # "+str(epoch))
        avg_loss = 0 # sum of the batch losses of the current epoch; divided by n_batches after the loop.
        for batches in range(n_batches):
            print("\tUsing batch # "+str(batches+1))
            # The number of training samples is in general not a multiple of n_batches: every batch
            # takes elem_in_batch samples except the last one, which runs up to the end of the
            # training set so that the remaining samples are not discarded.
            batch_start = batches*elem_in_batch
            batch_stop = batch_start+elem_in_batch if batches < n_batches-1 else X_train.shape[0]
            # Tensor.to() returns the moved tensor (it is not in-place), so it is applied on assignment.
            x = X_train[batch_start:batch_stop, :, :, :].to(device)
            y = Seg_train[:, :, batch_start:batch_stop].to(device)
            # Set network to training mode:
            net.train()
            # Training using the batch data (x) and the associated labels (y):
            singleBatch_loss = train(net, x, y, loss_function)
            # adding the current batch loss to the avg_loss:
            avg_loss += singleBatch_loss
        train_loss.append(avg_loss/n_batches)

        # Setting network in evaluation mode:
        net.eval()
        # Evaluating the loss on the validation set as early stopping condition on training to avoid overfitting.
        # val_loss is indexed by epoch right below, so validate() is expected to append one value per call.
        validate(net, X_val, Seg_val, loss_function, val_loss)
        if epoch % 10 == 0:
            print("\t Train loss = {}\t\t Validation loss = {}".format(train_loss[epoch], val_loss[epoch]))
        if epoch > 2:
            if val_loss[epoch] > val_loss[epoch-2]: # considering 2 epochs before allows to have a bit of flexibility if the loss on the validation has slight fluctuations. This could be improved by introducing a threshold on the difference between validation losses, such that if the val_loss fluctuates for several epochs but the relative difference with previous epochs is small, the training can go on anyway.
                print("\tEarly stopping at iteration # "+str(epoch+1))
                break
    # Loss function on epochs plot. Epochs go on the x axis and losses on the y axis; the x range
    # is derived from the number of recorded losses, which is smaller than `epochs` whenever
    # early stopping was triggered.
    plt.figure(figsize=(10,4))
    plt.plot(range(1, len(train_loss)+1), train_loss, 'r', label="Train loss")
    plt.plot(range(1, len(val_loss)+1), val_loss, 'b', label="Val loss")
    plt.title("Cross Entropy loss - L.R. = "+str(lr))
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    # Saving the figure. This must happen before plt.show(), which finalises the current
    # figure: saving afterwards would write an empty image.
    figname = net_name+"LR_"+str(lr)+".jpg"
    try:
        plt.savefig(figname)
    except (OSError, ValueError):
        # Best effort: a failure to write the image must not abort the training run.
        print("Unable to save "+figname)
    plt.show()
    # Saving the model:
    net_name = net_name+"_LR_"+str(lr)+".pt"
    torch.save(net.state_dict(), net_name)

### Network Run: Testing phase
The testing phase is performed only on the best model that we obtained out of the training phase. The best model is defined as the one having the highest learning rate while providing a reasonable loss on the training and the test data. If the script is run on a cluster, the testing can be commented and evaluated at a later moment since it requires the definition of the best model's name either in the code (in place of the 'None' of net_name) or user provided via input method.

The testing phase implements the miseval functions to evaluate the goodness of the segmentation. The scores calculated are:
- Dice score;
- Jaccard score (or Intersection over Union score);
- the usual accuracy, precision, sensitivity and specificity scores of the related confusion matrix.

In [None]:
#---- Best model evaluation:
# Model loading: set net_name to the file with the saved weights, or leave it to None
# to be asked for the file name interactively.
net_name = None
if net_name is None:
    print("Please specify the name of the file containing the network you want to upload:")
    while True:
        uin2 = input("> ")
        if os.path.exists(str(uin2)):
            # Store the accepted path: it is the one loaded below.
            net_name = str(uin2)
            break
        else:
            print("Unable to find the requested file. Please try again...")
# Reinitialize the model with the same constructor arguments used for training (4 input
# channels, n_classes outputs); the saved state dict is then applied on the next line.
net = SimpleUnet(4, n_classes)
# map_location lets weights saved from a GPU run be loaded on a CPU-only machine as well.
net.load_state_dict(torch.load(net_name, map_location=device))
# Evaluation on test
net.to(device)
net.eval()
# Tensor.to() returns the moved tensor (it is not in-place), so the result is re-assigned.
X_test = X_test.to(device)     # test data
Seg_test = Seg_test.to(device) # test segmentation
with torch.no_grad():
  Dice, Jacc, Acc, Prec, Sens, Rec, pred_seg, real_seg = test(net, X_test, Seg_test, n_classes)

### Segmentation visualization

In [None]:
# Plot of the segmentations for comparison: one figure per test slice, with the reference
# segmentation on the left and the network generated segmentation on the right.
for _slice_idx in (1, 2):
    plt.figure(figsize=(10,4))
    # Real segmentation:
    plt.subplot(121)
    plt.imshow(real_seg[_slice_idx,0,:,:])
    plt.title("Ref. segmentation - Test set slice {}".format(_slice_idx+1))
    # Network generated segmentation:
    plt.subplot(122)
    plt.imshow(pred_seg[_slice_idx,0,:,:])
    plt.title("Pred. segmentation - Test set slice {}".format(_slice_idx+1))
    # plt.show() is called once, after both panels have been drawn: calling it between the
    # two subplots finalises the figure and the second panel ends up in a separate figure.
    plt.show()