# UNet


## Imports

In [None]:
"""Imports"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
from PIL import Image
import torch
from torch.utils.data.dataset import Dataset  # For custom data-sets
from torch.utils.data import DataLoader
from torch import nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
import re
import glob
import random
%matplotlib inline 
from typing import List, Callable, Tuple
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import copy
from collections import defaultdict
from tqdm import tqdm
import os
import time
from datetime import datetime

from livelossplot import PlotLosses
from livelossplot.outputs import MatplotlibPlot

In [None]:
def makeDir(bpath):
    """Create a new directory without overwriting an existing one.

    Tries ``bpath`` first; if that name is taken, tries ``bpath + "1"``,
    ``bpath + "2"``, ... until a fresh name is found.

    Args:
        bpath: Desired directory path (missing parents are created).

    Returns:
        The path of the directory that was actually created.

    Raises:
        OSError: for any failure other than "already exists" (e.g. permission
            denied, a path component that is a file), instead of retrying forever.
    """
    path = bpath
    suffix = 1
    while True:
        try:
            os.makedirs(path, exist_ok=False)
            return path
        except FileExistsError:
            # Name taken: append an increasing counter and try again.
            path = bpath + str(suffix)
            suffix += 1

### Dataset Method

In [None]:
"""Methods and Object Functions"""
# Creating method to import custom datasets
class CustomDataset(Dataset):
    """Dataset pairing per-slice ROI mask files with slices of per-patient image volumes.

    Each entry of ``mask_dirs`` is a mask ``.mat`` file whose path contains
    ``/P<patient>`` and ``_S<slice>`` (e.g. ``.../P12_S3.mat``). For each mask the
    matching image is read from ``<im_dir>/P<patient as 3 digits>_cmplx.mat``
    (variable ``'cmplx'``) and indexed at that slice along axis 2.

    Args:
        im_dir: Folder holding the ``P###_cmplx.mat`` image files.
        mask_dirs: List of mask file paths; one dataset item per path.
        roi_type: Name of the mask variable inside each mask file
            (e.g. ``'roi_full'`` or ``'roi_subzone'``).
        transform: Optional callable taking ``image=``/``mask=`` keyword arguments
            and returning a dict with ``"image"`` and ``"mask"`` entries.

    Items are ``(image, mask, name)``: a float32 image tensor with the last array
    axis moved to the front (channels first), an int64 mask tensor, and a name
    string ``"P###_S<slice>"``.
    """

    def __init__(self, im_dir, mask_dirs, roi_type, transform=None):
        self.im_dir = im_dir
        self.im_masks_dir = mask_dirs
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long
        self.roi_type = roi_type

    def __len__(self):
        """Number of samples (one per mask file)."""
        return len(self.im_masks_dir)

    @staticmethod
    def _parse_ids(mask_path):
        """Return the ``(patient, slice)`` digit strings parsed from a mask file path."""
        path = mask_path.replace("\\", "/")
        # `\d+` (not `\d*`) skips folder names such as ".../PZ/" that start with
        # "/P" but carry no digits, so no fallback index is needed.
        patients = re.findall(r"(?<=/P)\d+", path)
        slices = re.findall(r"(?<=_S)\d+", path)
        if not patients or not slices:
            raise ValueError(f"Cannot parse patient/slice number from {mask_path!r}")
        return patients[0], slices[0]

    def __getitem__(self, idx):
        # Normalize Windows separators so the "/P" pattern matches.
        file2Load_mask = self.im_masks_dir[idx].replace("\\", "/")
        n_pat, n_slice = self._parse_ids(file2Load_mask)
        # Image files use a 3-digit, zero-padded patient number (P005, P012, P123).
        n_pat = n_pat.zfill(3)
        file2Load_im = self.im_dir.replace("\\", "/") + '/P' + n_pat + '_cmplx.mat'

        # Loading data: the mask, and the requested slice (axis 2) of the image volume.
        mask = loadmat(file2Load_mask)[self.roi_type]
        image = loadmat(file2Load_im)['cmplx'][:, :, int(n_slice), :]
        name = "P" + n_pat + "_S" + n_slice

        # Augmentations / transformations
        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        # Converting to torch tensors of the expected dtypes
        if torch.is_tensor(image):
            # The transform already produced tensors (no permute applied here).
            image = image.type(self.inputs_dtype)
            mask = mask.type(self.targets_dtype)
        else:
            # numpy (H, W, C) -> tensor (C, H, W), the layout Conv2d expects.
            image = torch.from_numpy(image).type(self.inputs_dtype).permute(2, 0, 1)
            mask = torch.from_numpy(mask).type(self.targets_dtype)

        return image, mask, name

    
def splitData(split_train, fileNames):
    """Randomly split the ``.mat`` files of a folder into training and validation lists.

    Whole patients are moved to validation one at a time (chosen at random), so
    slices of the same patient stay together; only the last patient moved may be
    split across the two sets, so that the sizes come out exactly.

    Args:
        split_train: Fraction (0-1) of the files to keep for training.
        fileNames: Folder containing the ``.mat`` files, whose paths contain
            ``/P<patient number>`` (e.g. ``P12_S3.mat``).

    Returns:
        ``(folder_train, folder_valid)`` lists of file paths; ``folder_train`` has
        ``round(total * split_train)`` entries and ``folder_valid`` the rest.

    Raises:
        ValueError: if the folder contains no ``.mat`` files, or a path has no
            patient number.
    """
    folder_train = glob.glob(fileNames + '/*.mat')
    if not folder_train:
        raise ValueError(f"No .mat files found in {fileNames!r}")
    folder_valid = []
    numTrain = int(np.round(len(folder_train) * split_train))

    def patient_of(path):
        # `\d+` skips digit-less matches such as a ".../PZ/" folder; comparing as
        # int avoids substring clashes such as "P1" matching "P12".
        found = re.findall(r"(?<=/P)\d+", path.replace("\\", "/"))
        if not found:
            raise ValueError(f"No patient number in {path!r}")
        return int(found[0])

    while True:
        # Pick a random remaining file; its patient is the next one moved out.
        n = random.randint(0, len(folder_train) - 1)
        pnum = patient_of(folder_train[n])
        indices = [i for i, path in enumerate(folder_train) if patient_of(path) == pnum]

        # Moving the whole patient would overshoot: move only a random subset
        # of its slices so exactly numTrain files remain, then stop.
        is_last = (len(folder_train) - len(indices)) <= numTrain
        if is_last:
            indices = sorted(random.sample(indices, len(folder_train) - numTrain))

        # Pop from the highest index down so the remaining indices stay valid.
        for i in reversed(indices):
            folder_valid.append(folder_train.pop(i))

        if is_last:
            return folder_train, folder_valid
    

### Transformations

In [None]:
# Training-time augmentation pipeline. It is only used when params["transform"] is set
# to the string "train_transform" (the dataset cell resolves that name).
# The fully commented-out pipeline directly below is an earlier, stronger variant
# (larger distortions, plus Resize and Normalize) kept for reference.
# train_transform = A.Compose(
#     [
#         A.Resize(240,240),
#         A.ShiftScaleRotate(shift_limit=0.001, scale_limit=0.01, rotate_limit=2, border_mode=2, p=0.5),
#         A.Normalize(mean=0.5, std=0.25),
#         A.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=1, border_mode=2, value=None, mask_value=None, always_apply=False, p=0.5),
#         A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, interpolation=1, border_mode=2, value=None, mask_value=None, always_apply=False, approximate=False, p=0.5),
#         A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, interpolation=1, border_mode=2, value=None, mask_value=None, always_apply=False, p=0.5),
#         ToTensorV2(),
#     ]
# )
# Active pipeline: mild geometric augmentations, each with probability p=0.5, applied
# to image and mask together. border_mode=2 is cv2.BORDER_REFLECT; interpolation=1 is
# cv2.INTER_LINEAR. ToTensorV2 must stay last, because the dataset expects tensors when
# a transform is used.
# NOTE(review): alpha_affine / value / mask_value are deprecated or rejected by newer
# albumentations releases; confirm against the installed version.
train_transform = A.Compose(
    [
#         A.Resize(240,240),
        A.ShiftScaleRotate(shift_limit=0.001, scale_limit=0.01, rotate_limit=2, border_mode=2, p=0.5),
#         A.Normalize(mean=0.5, std=0.25),
        A.GridDistortion(num_steps=5, distort_limit=0.1, interpolation=1, border_mode=2, value=None, mask_value=None, always_apply=False, p=0.5),
        A.ElasticTransform(alpha=0.5, sigma=10, alpha_affine=10, interpolation=1, border_mode=2, value=None, mask_value=None, always_apply=False, approximate=False, p=0.5),
        A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, interpolation=1, border_mode=2, value=None, mask_value=None, always_apply=False, p=0.5),
        ToTensorV2(),
    ]
)

In [None]:
# Evaluation pipeline: no augmentation at all, only conversion of the numpy
# image/mask pair to tensors.
test_transform = A.Compose([ToTensorV2()])

### Creating the Neural Network (UNet)

In [None]:
def create_model(params):
    """Build a UNET_V2 from ``params`` and place it on ``params["device"]``.

    Args:
        params: dict with at least ``"in_channels"``, ``"out_channels"`` and ``"device"``.

    Returns:
        The freshly initialised model, already moved to the requested device.
    """
    in_ch, out_ch = params["in_channels"], params["out_channels"]
    return UNET_V2(in_ch, out_ch).to(params["device"])

In [None]:
# Neural Network 
# - July 9, 2021
class UNET_V2(nn.Module):
    """Four-level U-Net: a 64-128-256-512 encoder, a 1024-channel bottleneck, a mirrored
    decoder with skip connections, and a 1x1 conv classifier.

    Each encoder level halves H and W (2x2 max-pool) and each decoder level doubles
    them (stride-2 transposed conv). Input height and width therefore need to be
    divisible by 16, otherwise the skip concatenations do not line up.

    Attribute names (e1..e4, b, d1..d4, outputs) are the saved state_dict keys and
    must not be renamed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Contracting path
        self.e1 = encoder_block(in_channels, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)
        # Bottleneck
        self.b = conv_block(512, 1024)
        # Expanding path
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)
        # Per-pixel classifier (no activation: raw logits)
        self.outputs = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        # Down: keep each level's full-resolution features for the skip connections.
        skips = []
        x = inputs
        for encoder in (self.e1, self.e2, self.e3, self.e4):
            skip, x = encoder(x)
            skips.append(skip)

        x = self.b(x)

        # Up: the deepest skip pairs with the first decoder.
        for decoder, skip in zip((self.d1, self.d2, self.d3, self.d4), reversed(skips)):
            x = decoder(x, skip)

        return self.outputs(x)

    
class encoder_block(nn.Module):
    """One U-Net down step: a conv_block followed by 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``. The full-resolution features serve
    as the skip connection; the pooled map (half H and W) feeds the next level.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = conv_block(in_channels, out_channels)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

class decoder_block(nn.Module):
    """One U-Net up step: 2x transposed-conv upsampling, concatenation with the skip
    features along the channel axis, then a conv_block.

    The skip tensor is expected to carry ``out_channels`` channels, so the conv_block
    sees ``2 * out_channels`` input channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(2 * out_channels, out_channels)

    def forward(self, inputs, skip):
        upsampled = self.up(inputs)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)
    
class conv_block(nn.Module):
    """Two stages of (3x3 conv -> BatchNorm -> ReLU).

    ``padding=1`` keeps H and W unchanged; only the channel count changes, from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        out = inputs
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv(out)))
        return out

In [None]:
# Neural Network

class UNET_V1(nn.Module):
    """Small three-level U-Net variant (32-64-128 channels).

    Each contracting stage ends in a stride-2 max-pool and each expanding stage ends
    in a stride-2 transposed conv. The output therefore has the input's H and W when
    both are divisible by 8; otherwise the skip concatenations do not line up.
    """

    def __init__(self, in_channels, out_channels):
        super(UNET_V1, self).__init__()

        self.conv1 = self.contract_block(in_channels, 32, 7, 3)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)

        # Channel counts of upconv2/upconv1 are doubled by the skip concatenation.
        self.upconv3 = self.expand_block(128, 64, 3, 1)
        self.upconv2 = self.expand_block(64*2, 32, 3, 1)
        self.upconv1 = self.expand_block(32*2, out_channels, 3, 1)

    # Defined as `forward` (not `__call__`) so nn.Module.__call__ still runs the
    # registered hooks; calling `model(x)` works exactly as before.
    def forward(self, x):

        # downsampling part
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        # upsampling part, concatenating encoder features along channels
        upconv3 = self.upconv3(conv3)

        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Two (conv -> BatchNorm -> ReLU) stages, then a 3x3 stride-2 max-pool (halves H, W)."""

        contract = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                                 )

        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Two (conv -> BatchNorm -> ReLU) stages, then a stride-2 transposed conv (doubles H, W)."""

        expand = nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1) 
                            )
        return expand
    

### Training/Validating Model Functions

In [None]:
def train(train_loader, model, criterion, optimizer, epoch, params, logs=None):
    """Run one training epoch and record epoch-level metrics.

    Args:
        train_loader: DataLoader yielding ``(images, target, name)`` batches.
        model: network returning per-pixel logits; assumed shape (B, 1, H, W)
            (out_channels=1), matching how ``predict`` applies ``squeeze(1)``.
        criterion: loss taking ``(logits, float target)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for the progress-bar text.
        params: dict with at least ``"device"``.
        logs: dict to record into; a fresh dict is used when ``None``.

    Returns:
        ``logs``. When the loader yielded at least one batch, it gains
        ``'Training loss'`` (mean batch loss over the epoch) and
        ``'Training accuracy'`` (fraction of all target pixels whose thresholded
        prediction matches).
    """
    if logs is None:
        logs = {}
    metric_monitor = MetricMonitor()
    model.train()
    stream = tqdm(train_loader)
    loss_total, n_batches = 0.0, 0
    n_correct, n_pixels = 0, 0
    for images, target, _name in stream:
        images = images.to(params["device"], non_blocking=True)
        target = target.to(params["device"], non_blocking=True)
        target = target.float()  # the losses need float targets
        output = model(images)

        loss = criterion(output, target)
        batch_loss = loss.item()
        metric_monitor.update("Loss", batch_loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        stream.set_description("Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor))

        loss_total += batch_loss
        n_batches += 1
        with torch.no_grad():
            # The model emits logits: threshold the sigmoid at 0.5, as predict() does.
            predicted = (torch.sigmoid(output.squeeze(1)) >= 0.5).float()
            n_correct += (predicted == target).sum().item()
            n_pixels += target.numel()

    if n_batches:
        logs['Training loss'] = loss_total / n_batches
        logs['Training accuracy'] = n_correct / n_pixels
    return logs
   

In [None]:
def validate(val_loader, model, criterion, epoch, params, logs=None):
    """Evaluate the model on the validation set for one epoch (no gradient updates).

    Args:
        val_loader: DataLoader yielding ``(images, target, name)`` batches.
        model: network returning per-pixel logits; assumed shape (B, 1, H, W).
        criterion: loss taking ``(logits, float target)``.
        epoch: epoch number, used only for the progress-bar text.
        params: dict with at least ``"device"``.
        logs: dict to record into; a fresh dict is used when ``None``.

    Returns:
        ``logs``. When the loader yielded at least one batch, it gains
        ``'Validation loss'`` (mean batch loss) and ``'Validation accuracy'``
        (fraction of all target pixels predicted correctly).
    """
    if logs is None:
        logs = {}
    metric_monitor = MetricMonitor()
    model.eval()
    stream = tqdm(val_loader)
    loss_total, n_batches = 0.0, 0
    n_correct, n_pixels = 0, 0
    with torch.no_grad():
        for images, target, _name in stream:
            images = images.to(params["device"], non_blocking=True)
            target = target.to(params["device"], non_blocking=True)
            target = target.float()
            output = model(images)
            batch_loss = criterion(output, target).item()
            metric_monitor.update("Loss", batch_loss)
            stream.set_description("Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor))

            loss_total += batch_loss
            n_batches += 1
            # The model emits logits: threshold the sigmoid at 0.5, as predict() does.
            predicted = (torch.sigmoid(output.squeeze(1)) >= 0.5).float()
            n_correct += (predicted == target).sum().item()
            n_pixels += target.numel()

    if n_batches:
        logs['Validation loss'] = loss_total / n_batches
        logs['Validation accuracy'] = n_correct / n_pixels
    return logs

In [None]:
def train_and_validate(model, train_dataset, val_dataset, params, savePath=None):
    """Train ``model`` for ``params["epochs"]`` epochs, validating after each one.

    Uses FocalLoss(gamma=params["gamma"], alpha=params["alpha"]) and Adam with
    ``lr=params["lr"]``. When ``params["plotting"]`` is truthy, per-epoch logs are
    sent to a livelossplot ``PlotLosses`` whose figure is saved to ``savePath``.

    Returns:
        ``(model, liveloss)`` when plotting is enabled, otherwise ``model``.
    """
    train_loader = DataLoader(
        train_dataset,
        batch_size=params["batch_size"],
        shuffle=True,
        num_workers=params["num_workers"],
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=params["batch_size"],
        shuffle=False,
        num_workers=params["num_workers"],
        pin_memory=True,
    )
    # Alternatives tried before: nn.BCEWithLogitsLoss(), WeightedFocalLoss(gamma, alpha).
    criterion = FocalLoss(gamma=params["gamma"], alpha=params["alpha"]).to(params["device"])
    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])

    liveloss = None
    if params["plotting"]:
        liveloss = PlotLosses(outputs=[MatplotlibPlot(figpath=savePath)])

    for epoch in range(1, params["epochs"] + 1):
        # Always hand train/validate a dict: both write metric entries into it.
        logs = {}
        logs = train(train_loader, model, criterion, optimizer, epoch, params, logs)
        logs = validate(val_loader, model, criterion, epoch, params, logs)
        if liveloss is not None:
            liveloss.update(logs)
            liveloss.send()

    if liveloss is not None:
        return model, liveloss
    return model

In [None]:
def predict(model, params, test_dataset, batch_size):
    """Return binary (0./1.) float mask predictions for every item of ``test_dataset``.

    Args:
        model: network returning per-pixel logits shaped (B, 1, H, W).
        params: dict with ``"device"`` and ``"num_workers"``.
        test_dataset: dataset yielding ``(image, mask, ...)``; only the image is used.
            Both ``(image, mask)`` and CustomDataset's ``(image, mask, name)`` work.
        batch_size: loader batch size.

    Returns:
        List of numpy arrays (one (H, W) mask per item, in dataset order), where a
        pixel is 1 when ``sigmoid(logit) >= 0.5``.
    """
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=params["num_workers"], pin_memory=True,
    )
    model.eval()
    predictions = []
    with torch.no_grad():
        for images, *_ in test_loader:
            images = images.to(params["device"], non_blocking=True)
            probabilities = torch.sigmoid(model(images).squeeze(1))
            predicted_masks = (probabilities >= 0.5).float().cpu().numpy()
            predictions.extend(predicted_masks)
    return predictions

In [None]:
def evaluateModel(predIms, gtIms):
    """Compute pixel accuracy, Jaccard index and Dice coefficient of a prediction vs a mask.

    Args:
        predIms: prediction array of any shape; pixels equal to 1 are foreground.
        gtIms: ground-truth array of the same (or squeeze-equivalent) shape.

    Returns:
        ``(PA, JI, DC)``. Counts are pooled over every pixel passed in, so a batch
        (B, H, W) is scored as one large image. When neither prediction nor ground
        truth contains any foreground, JI and DC are defined as 1.0 (perfect
        agreement) instead of the undefined 0/0.
    """
    pIm = np.asarray(predIms).squeeze() == 1  # prediction foreground
    mIm = np.asarray(gtIms).squeeze() == 1    # mask foreground

    tp = np.count_nonzero(pIm & mIm)    # true positives
    tn = np.count_nonzero(~pIm & ~mIm)  # true negatives
    fn = np.count_nonzero(~pIm & mIm)   # false negatives
    fp = np.count_nonzero(pIm & ~mIm)   # false positives

    # Pixel Accuracy over all pixels (every image of a batch counts).
    PA = (tp + tn) / mIm.size
    union = tp + fn + fp
    if union == 0:
        JI = DC = 1.0
    else:
        # Jaccard's Index (Intersection over Union, IoU)
        JI = tp / union
        # Dice Coefficient
        DC = 2 * tp / (2 * tp + fn + fp)

    return PA, JI, DC
    

def evaluateModelAndPredict(model, params, test_dataset, batch_size):
    """Predict masks for ``test_dataset`` and score each batch against its ground truth.

    Args:
        model: network returning per-pixel logits shaped (B, 1, H, W).
        params: dict with ``"device"`` and ``"num_workers"``.
        test_dataset: dataset yielding ``(image, mask, name)`` triples.
        batch_size: loader batch size (one metric value is recorded per batch).

    Returns:
        ``(predictions, evaluation)``. ``predictions`` is a list of 0/1 numpy masks.
        ``evaluation`` maps ``"Pixel Accuracy"``, ``"Jaccard Index"`` and
        ``"Dice Coefficient"`` to per-batch value lists.
    """
    loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=params["num_workers"],
        pin_memory=True,
    )
    model.eval()
    predictions = []
    # Key order matches the (PA, JI, DC) order returned by evaluateModel.
    scores = {"Pixel Accuracy": [], "Jaccard Index": [], "Dice Coefficient": []}

    with torch.no_grad():
        for images, mask, name in loader:
            logits = model(images.to(params["device"], non_blocking=True))
            batch_pred = (torch.sigmoid(logits.squeeze(1)) >= 0.5).float().cpu().numpy()
            predictions.extend(batch_pred)
            for metric, value in zip(scores, evaluateModel(batch_pred, mask.cpu().numpy())):
                scores[metric].append(value)

    return predictions, scores



def evaluateModelAndPredict_andPrint(model, params, test_dataset, batch_size, mainPath, kernelSize=0):
    """Predict and score like evaluateModelAndPredict, and also save every predicted mask as a PNG.

    Masks are written to a new folder ``<mainPath>/OutputImages_k<kernelSize>``. It is
    created with makeDir, so an existing folder is never overwritten. Each mask is
    saved as ``<name>.png`` with values 0/255. When ``kernelSize > 0``, each mask is
    first cleaned by a morphological opening (erode, then dilate) with a
    ``kernelSize x kernelSize`` square kernel. Otherwise the raw prediction is saved.

    Args:
        model: network returning per-pixel logits shaped (B, 1, H, W).
        params: dict with ``"device"`` and ``"num_workers"``.
        test_dataset: dataset yielding ``(image, mask, name)`` triples.
        batch_size: loader batch size (one metric value per batch).
        mainPath: folder under which the output image folder is created.
        kernelSize: opening kernel size; 0 (default) disables the opening.

    Returns:
        ``(predictions, evaluation)`` as in evaluateModelAndPredict.
    """
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=params["num_workers"], pin_memory=True,
    )
    model.eval()
    predictions = []

    # getting evaluation metrics
    PA = []  # Pixel Accuracy
    JI = []  # Jaccard's Index (Intersection over Union, IoU)
    DC = []  # Dice Coefficient
    evaluation = dict()  # to store evaluation metrics

    # Making directory to save to
    outputImDir = makeDir(mainPath + '/' + 'OutputImages_k' + str(kernelSize))
    # Built once. No kernel at all for kernelSize <= 0: OpenCV documents that an empty
    # kernel means a default 3x3 element, which would still alter the mask.
    kernel = np.ones((kernelSize, kernelSize), np.uint8) if kernelSize > 0 else None

    with torch.no_grad():
        for images, mask, name in test_loader:
            images = images.to(params["device"], non_blocking=True)
            output = model(images)
            probabilities = torch.sigmoid(output.squeeze(1))
            predicted_masks = (probabilities >= 0.5).float().cpu().numpy()

            tempPA, tempJI, tempDC = evaluateModel(predicted_masks, mask.cpu().numpy())
            PA.append(tempPA)
            JI.append(tempJI)
            DC.append(tempDC)

            # One output image per item, paired with that item's own name.
            for predicted_mask, item_name in zip(predicted_masks, name):
                predictions.append(predicted_mask)
                if kernel is not None:
                    img_mod = cv2.dilate(cv2.erode(predicted_mask, kernel, iterations=1), kernel, iterations=1)
                else:
                    img_mod = predicted_mask
                # PNG needs an 8-bit image: map {0, 1} -> {0, 255}.
                cv2.imwrite(outputImDir + "/" + item_name + ".png", (img_mod * 255).astype(np.uint8))

    evaluation["Pixel Accuracy"] = PA
    evaluation["Jaccard Index"] = JI
    evaluation["Dice Coefficient"] = DC

    return predictions, evaluation

## Loss Function

In [None]:
# class FocalLoss(nn.modules.loss._WeightedLoss):
#     def __init__(self, weight=None, gamma=2,reduction='mean'):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         self.reduction = reduction
#         self.weight = torch.Tensor([weight,1-weight]) #weight parameter will act as the alpha parameter to balance class weights

#     def forward(self, input, target):
    
#         print(self.weight)
#         ce_loss = F.cross_entropy(input, target,reduction=self.reduction,weight=self.weight)
#         pt = torch.exp(-ce_loss)
#         focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
#         return focal_loss

In [None]:
class WeightedFocalLoss(nn.Module):
    """Alpha-weighted binary focal loss on single-channel logits.

    For each pixel i: ``loss_i = a_i * (1 - p_i)**gamma * BCE_i``.
    ``BCE_i`` is the binary cross-entropy of logit i and ``p_i = exp(-BCE_i)`` is the
    probability given to the true class. ``a_i`` is ``alpha`` for target 0 and
    ``1 - alpha`` for target 1. The result is the mean over all pixels.
    """

    def __init__(self, alpha=.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = torch.tensor([alpha, 1 - alpha])
        self.gamma = gamma

    def forward(self, inputs, targets):
        """inputs: (N, 1, H, W) logits (or already (M, 1)); targets: matching 0/1 float values."""
        if inputs.dim() > 2:
            inputs = inputs.view(inputs.size(0), inputs.size(1), -1)  # N,C,H,W => N,C,H*W
            inputs = inputs.transpose(1, 2)    # N,C,H*W => N,H*W,C
            inputs = inputs.contiguous().view(-1, inputs.size(2))   # N,H*W,C => N*H*W,C
        targets = targets.view(-1, 1)

        # Flatten to (M,) so the per-pixel terms multiply element-wise
        # (an (M,) x (M, 1) product would broadcast to (M, M)).
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none').view(-1)
        # Per-pixel class weight, looked up on the same device as the data.
        at = self.alpha.to(inputs.device).gather(0, targets.long().view(-1))
        pt = torch.exp(-BCE_loss)
        F_loss = at * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

In [None]:
class FocalLoss(nn.Module):
    """Focal loss for per-pixel classification.

    For each pixel i: ``loss_i = -a_i * (1 - p_i)**gamma * log(p_i)``. Here ``p_i``
    is the predicted probability of pixel i's target class and
    ``a_i = alpha[target_i]`` (or 1 when ``alpha`` is None).

    * One input channel (C == 1): the channel is a binary logit, so
      ``p = sigmoid(x)`` for target 1 and ``1 - sigmoid(x)`` for target 0.
    * C > 1 channels: ``p = softmax(x over channels)[target]``.

    Args:
        gamma: focusing exponent (0 gives plain, optionally weighted, cross-entropy).
        alpha: None, a float ``a`` (weights ``[a, 1 - a]`` for classes 0/1), or a
            list with one weight per class.
        size_average: return the mean over pixels if True, else the sum.
    """

    def __init__(self, gamma=1, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """input: (N, C, H, W) or (M, C) logits; target: class indices (any float/int dtype)."""
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))   # N,H*W,C => N*H*W,C
        # Class indices on the logits' device (gather requires matching devices).
        target = target.reshape(-1, 1).long().to(input.device)

        if input.size(1) == 1:
            # Binary logit: log p_t = logsigmoid(x) if t == 1 else logsigmoid(-x).
            logits = input.view(-1)
            logpt = F.logsigmoid(torch.where(target.view(-1) == 1, logits, -logits))
        else:
            # Multi-class: normalise across the class (channel) dimension.
            logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        # The modulating factor does not back-propagate (as in the original Variable).
        pt = logpt.detach().exp()

        if self.alpha is not None:
            alpha = self.alpha.to(device=input.device, dtype=logpt.dtype)
            at = alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()


### Visualization Functions

In [None]:
# """For Plotting the image batches"""

def display_image_grid_loader(dataloader_obj):
    """Show the first batch of a DataLoader as two grids: one image channel (top), masks (bottom).

    Args:
        dataloader_obj: DataLoader yielding ``(images, masks, names)``. Masks must be
            (B, H, W); images are presumably (B, C, H, W) — TODO confirm.
    """
    # Channel shown from the image grid. Echo 4 is usually the brightest (~8 ms),
    # though echo 2 was selected here.
    echonum = 2

    batch_imgs, batch_masks, _names = next(iter(dataloader_obj))
    img_grid = torchvision.utils.make_grid(batch_imgs, nrow=4, normalize=True, scale_each=True)
    # make_grid wants (B, C, H, W): give the masks a singleton channel axis.
    mask_grid = torchvision.utils.make_grid(batch_masks.unsqueeze(1), nrow=4)

    fig, (img_ax, mask_ax) = plt.subplots(2, figsize=(10, 10))
    img_ax.imshow(img_grid[echonum, :, :], interpolation=None)
    mask_ax.imshow(mask_grid[0, :, :], interpolation=None)
    for axis in (img_ax, mask_ax):
        axis.set_axis_off()
    plt.tight_layout()
    plt.show()

    
def display_image_grid_dataset(dataset_obj, numIms, predicted_masks=None):
    """Plot the first ``numIms`` items of a dataset, one row each.

    Each row shows one image channel (normalised to max 1), the ground-truth mask
    and, optionally, the predicted mask.

    Args:
        dataset_obj: dataset yielding ``(image, mask, name)`` (as CustomDataset does)
            or ``(image, mask)``; image indexed as (C, H, W), mask as (H, W).
        numIms: number of items (rows) to show.
        predicted_masks: optional sequence of predicted masks aligned with the dataset.
    """
    echonum = 4  # number 4 is usually the brightest -> around 8ms
    # Explicit check: a numpy array of masks has no single truth value.
    show_pred = predicted_masks is not None and len(predicted_masks) > 0

    rows = numIms
    cols = 3 if show_pred else 2
    # squeeze=False keeps `ax` 2-D even for a single row.
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(2*6, numIms*3), squeeze=False)

    for i, item in enumerate(dataset_obj):
        image, mask = item[0], item[1]
        name = item[2] if len(item) > 2 else ""
        echo = image[echonum, :, :]
        image = torch.div(echo, torch.max(echo))  # image is normalized

        ax[i, 0].imshow(image, cmap='gray', interpolation=None)
        ax[i, 1].imshow(mask, interpolation=None)

        ax[i, 0].set_ylabel(name)
        ax[i, 0].xaxis.set_ticks([])
        ax[i, 0].yaxis.set_ticks([])
        ax[i, 1].set_axis_off()

        if show_pred:
            ax[i, 2].imshow(predicted_masks[i], interpolation=None)
            ax[i, 2].set_axis_off()

        # break if reached number of images to print
        if i + 1 == numIms:
            break

    ax[0, 0].set_title("Image, Slice: "+str(echonum))
    ax[0, 1].set_title("Ground truth mask")
    if show_pred:
        ax[0, 2].set_title("Predicted mask")
    plt.tight_layout()



def visualize_augmentations(dataset, idx=0, samples=5):
    """Plot ``samples`` independent random augmentations of dataset item ``idx``.

    Works on a deep copy of ``dataset`` whose transform has Normalize and ToTensorV2
    removed, so images are shown un-normalised. The original dataset is untouched.
    Columns: image channel 4, mask, image channel 2.
    """
    dataset = copy.deepcopy(dataset)
    # Without a transform there is nothing to strip (iterating None would raise).
    if dataset.transform is not None:
        dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])
    # squeeze=False keeps `ax` 2-D even when samples == 1.
    figure, ax = plt.subplots(nrows=samples, ncols=3, figsize=(10, 24), squeeze=False)
    for i in range(samples):
        image, mask, name = dataset[idx]
        ax[i, 0].imshow(image[4, :, :], interpolation=None)
        ax[i, 2].imshow(image[2, :, :], interpolation=None)
        ax[i, 1].imshow(mask, interpolation=None)
        ax[i, 0].set_title("Augmented image")
        ax[i, 1].set_title("Augmented mask")
        ax[i, 0].set_axis_off()
        ax[i, 1].set_axis_off()
        ax[i, 2].set_axis_off()
    plt.tight_layout()
    plt.show()

In [None]:
class MetricMonitor:
    """Keep running averages of named metrics and render them as one status line.

    ``str(monitor)`` produces e.g. ``"Loss: 0.123 | Acc: 0.950"``, listing metrics
    in first-update order, each with ``float_precision`` decimals.
    """

    def __init__(self, float_precision=3):
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        """Forget all metrics."""
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

    def update(self, metric_name, val):
        """Add one observation ``val`` to ``metric_name`` and refresh its average."""
        entry = self.metrics[metric_name]
        entry["val"] += val
        entry["count"] += 1
        entry["avg"] = entry["val"] / entry["count"]

    def __str__(self):
        parts = []
        for name, entry in self.metrics.items():
            parts.append(f"{name}: {entry['avg']:.{self.float_precision}f}")
        return " | ".join(parts)

## Beginning the Modelling [User Input]

In [None]:
"""Initializing to CUDA (gpu framework)"""
# CUDA for PyTorch
# Run Python garbage collection and release unused cached GPU memory (e.g. left
# over from earlier runs in this kernel) before building a new model.
import gc
gc.collect()
torch.cuda.empty_cache()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True
# NOTE(review): this unconditionally overrides the detection above and forces CPU.
# Possibly deliberate (the run name below is 'UNET_cpu_...'); remove it to train on GPU.
device = "cpu"


## Training Model

In [None]:
# train_loader = DataLoader(
#         train_dataset,
#         batch_size=params["batch_size"],
#         shuffle=True,
#         num_workers=params["num_workers"],
#         pin_memory=True,
#     )

# T_im, T_mask = next(iter(train_loader))
# output=model(T_im)
# print(np.shape(output.squeeze()))
# out=output.squeeze().permute(1,0).detach().numpy()

# print(np.shape(out))
# print(np.shape(T_mask))
# plt.imshow(out)
# plt.figure()
# plt.imshow(T_mask.squeeze(0))


In [None]:
"""Getting timestamp"""
now = datetime.now() # current time
# Run identifier, used in the output folder name.
timestamp = now.strftime("%Y%m%d_%H%M")

# Run configuration. It is also written to params_output.txt after training, so keep
# key names short (they are padded to 15 characters there).
params = {
    "Timestamp": timestamp,
    "model": "UNET_V2",
    "segType": "PZ", # "PZ" or "full": selects the mask folders and ROI variable used below
    "in_channels": 64, # must equal the channel count of one image slice (last axis of 'cmplx')
    "out_channels": 1, # one logit per pixel (binary segmentation)
    "device": device,
    "lr": 0.00001,
    "batch_size": 10,
    "num_workers": 0,
    "epochs": 30,
    "gamma": 3, # focal-loss focusing exponent
    "alpha": 0.15, # focal-loss class weight: alpha for class 0, 1-alpha for class 1
    "transform": None, # None or "train_transform"
    "plotting": 1, # truthy: live loss plots via livelossplot
}
# output saving information
folder_saveoutputs = "C:/Users/candi/Documents/Research/1 LWI Project/1 Data/7 Model Outputs/" #need "/" after
modelName = 'UNET_cpu_' + params["segType"] + '_test'
# Folder location of matlab data (per-patient P###_cmplx.mat image volumes)
folder_data = 'C:/Users/candi/Documents/Research/1 LWI Project/1 Data/6 Datasets For ML/3 T2 Decay'
# Mask folders: training/validation masks and held-out test masks, for each segType.
folder_mask_PZ= 'C:/Users/candi/Documents/Research/1 LWI Project/1 Data/6 Datasets For ML/PZ'
folder_mask_full =  'C:/Users/candi/Documents/Research/1 LWI Project/1 Data/6 Datasets For ML/Full'
folder_test_full = 'C:/Users/candi/Documents/Research/1 LWI Project/1 Data/6 Datasets For ML/Full_TEST'
folder_test_PZ = 'C:/Users/candi/Documents/Research/1 LWI Project/1 Data/6 Datasets For ML/PZ_TEST'



### Loading and Splitting Training/Valid Data

In [None]:
# Pick the mask folders and the name of the ROI variable stored in each mask file.
if params["segType"] == 'full':
    folder_trainNvalid = folder_mask_full
    folder_test = folder_test_full
    roiFileStrName = 'roi_full'
elif params["segType"] == 'PZ':
    folder_trainNvalid = folder_mask_PZ
    folder_test = folder_test_PZ
    roiFileStrName = 'roi_subzone'
else:
    # Fail here instead of with a NameError several cells later.
    raise ValueError(f"Unknown segType {params['segType']!r}; expected 'full' or 'PZ'")

# Separating Training and Validation sets (the test set, 20% of the original data,
# is already split off into its own folder) & other training parameters
split_train = 0.80
split_valid = 0.20  # informational: splitData puts every file not kept for training into validation

# Splitting Training / Valid
folder_mask_train, folder_mask_valid = splitData(split_train, folder_trainNvalid)

# Resolve the transform by its variable name (e.g. "train_transform"). A plain
# name lookup instead of eval(), so the config string cannot run arbitrary code.
selected_transform = None if params["transform"] is None else globals()[params["transform"]]

# Creating the datasets for training / validation
train_dataset = CustomDataset(folder_data, folder_mask_train, roi_type=roiFileStrName, transform=selected_transform)
valid_dataset = CustomDataset(folder_data, folder_mask_valid, roiFileStrName, transform=selected_transform)


### Performing Training/Validation

In [None]:
# Notebook display only: preview the output folder name for this run (makeDir in the
# next cell appends a number if that folder already exists).
folder_saveoutputs+'/'+timestamp+"_"+modelName

In [None]:
# Create a fresh output folder for this run (makeDir never overwrites an existing one)
mainPath = makeDir(folder_saveoutputs+'/'+timestamp+"_"+modelName)

t = time.time() # running a timer
model = create_model(params)
# train_and_validate also returns the PlotLosses object when plotting is enabled.
if params["plotting"]:
    model, lossplot = train_and_validate(model, train_dataset, valid_dataset, params, savePath=mainPath+'/lossFig.png')
else:
    model = train_and_validate(model, train_dataset, valid_dataset, params)    
elapsed = time.time() - t # running a timer
minElapsedStr=str(round(elapsed/60,2))


print(" ")
print("L-Rate: ", params["lr"])
print("Epochs: ", params["epochs"])
print("Batches:", params["batch_size"])
print("Gamma:  ", params["gamma"])
print("Alpha:  ", params["alpha"])

# `now` was taken when the params cell ran, not when training started.
print('Start time: ' + now.strftime("%d/%m/%Y %H:%M:%S"))
print('Elapsed training time (mins): '+ minElapsedStr)

# Save Model: the file holds the torch-saved state_dict despite its .py extension
# (the loading cell below expects this same name).
torch.save(model.state_dict(), mainPath+"/model.py")
# Parameter file: "key:" padded to 15 characters, then the value; a blank line ends
# the parameter section (getParamsFromText reads it back up to that blank line).
with open(mainPath+"/params_output.txt", 'w') as f:
    for line in params:
        f.write((line+":").ljust(15)+str(params[line]))
        f.write('\n')
    f.write('\n')
    f.write('Path:'.ljust(25) + mainPath)
    f.write('\n')
    f.write('Train Duration (min):'.ljust(25) + minElapsedStr)
    f.write('\n')



## Loading Model

In [None]:
def getParamsFromText(textfilepath):
    """Read back the parameter section written to ``params_output.txt``.

    The training cell writes one ``name:`` column (left-justified to 15
    characters) followed by ``str(value)`` for each parameter, then a blank
    line. Everything after that first blank line (output path, duration,
    metrics) is ignored.

    Values are converted back to Python objects with ``ast.literal_eval``.
    Ints, floats (including forms such as ``1e-05`` and negatives), ``None``,
    ``True``/``False`` and literal containers therefore round-trip. Anything
    that is not a Python literal, such as ``cuda:0`` or a transform expression,
    is kept as the stripped string.

    Args:
        textfilepath: path of the text file written during training.

    Returns:
        dict mapping each parameter name to its parsed value.
    """
    import ast  # local import: only this helper needs it

    params = dict()
    with open(textfilepath, 'r') as f:
        for line in f:
            if not line.strip():
                break  # blank line ends the parameter section
            # Split on the FIRST colon only. Values such as "cuda:0" stay intact, and
            # names longer than the 15-char column (written without padding) still
            # separate correctly.
            name, sep, raw_value = line.partition(":")
            if not sep:
                continue  # not a "name: value" line
            raw_value = raw_value.strip()
            try:
                value = ast.literal_eval(raw_value)
            except (ValueError, SyntaxError, TypeError):
                value = raw_value  # not a Python literal: keep the text
            params[name.strip()] = value

    return params



In [None]:
# # # Load
# PATH = mainPath + '/model.py'
# textFilePath = mainPath + '/params_output.txt'
# params=getParamsFromText(textFilePath)
# model = UNET_V2(params["in_channels"],params["out_channels"])
# model.load_state_dict(torch.load(PATH))
# model.eval()


## Evaluate and Visualize Test Set

In [None]:
# Build the held-out test set from every .mat mask file in folder_test and iterate
# it one image at a time, in file-list order.
# NOTE(review): glob's ordering is filesystem-dependent; sort the list if a
# reproducible image order is needed.
folder_test_list = glob.glob(folder_test + '/*.mat')
# test_dataset = CustomDataset(folder_data, folder_mask_train, roi_type='roi_full', transform=None)
test_dataset = CustomDataset(
    folder_data, folder_test_list, roi_type=roiFileStrName, transform=None
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Predict every test image and collect the per-metric scores
# predictions = predict(model, params, test_dataset, batch_size=1)
predictions, evaluation = evaluateModelAndPredict(model, params, test_dataset, batch_size=1)

# Mean of each metric, formatted once and reused for both the console and the report file
avg_lines = []
for metric_name, metric_values in evaluation.items():
    metric_mean = np.average(metric_values)
    avg_lines.append(
        ('Avg ' + metric_name + ':').ljust(25)
        + str(metric_mean)
        + ' ('
        + str((metric_mean * 100).round(2))
        + '%)'
    )
for avg_line in avg_lines:
    print(avg_line)

# Append the averages, then the raw per-image values, to this run's params file
with open(mainPath + "/params_output.txt", 'a') as report:
    for avg_line in avg_lines:
        report.write(avg_line + '\n')
    for metric_name, metric_values in evaluation.items():
        report.write((metric_name + ':').ljust(25) + str(metric_values) + '\n')

# Plot all test images (one per loader batch) with their predictions and save the figure
display_image_grid_dataset(dataset_obj=test_dataset, numIms=len(test_loader), predicted_masks=predictions)
plt.savefig(mainPath + "/outputFig.png")
plt.savefig(mainPath+"/outputFig.png")

## Visualize Augmentations

In [None]:
# Only when a transform pipeline is configured: call visualize_augmentations on
# training sample idx=0 (presumably plotting `samples` augmented views of it) and
# save the resulting figure into this run's output folder.
if params["transform"] is not None:
    visualize_augmentations(train_dataset, idx=0, samples=5)
    plt.savefig(mainPath+"/sampleAugmentation.png")
    

## Probability Map Example

In [None]:
# test_loader = DataLoader(
#         test_dataset, batch_size=batch_size, shuffle=False, num_workers=params["num_workers"], pin_memory=True,
#     )
# model.eval()
# predictions = []
# with torch.no_grad():
#     for images, mask in test_loader:
#         images = images.to(params["device"], non_blocking=True)
#         output = model(images)
#         probabilities = torch.sigmoid(output.squeeze(1))
#         predicted_masks = (probabilities >= 0.5).float() * 1
#         predicted_masks = predicted_masks.cpu().numpy()
#         for predicted_mask in predicted_masks:
#             predictions.append(predicted_mask)

# Sanity-check one test batch. Print the raw model-output range, the sigmoid
# probability range and the mean of the thresholded mask (fraction of pixels
# predicted positive). Then plot image / ground-truth mask / probabilities /
# prediction side by side.
T_im, T_mask = next(iter(test_loader))
model.eval()
# Inference only: no_grad skips building an autograd graph for this forward pass.
with torch.no_grad():
    output = model(T_im.to(params["device"], non_blocking=True))
output = output.squeeze(1)
print(np.min(output[0,:,:].to('cpu').detach().numpy()),np.max(output[0,:,:].to('cpu').detach().numpy()))
probabilities = torch.sigmoid(output)
# Threshold at 0.5 to obtain a binary (0.0 / 1.0) float mask
predicted_masks = (probabilities >= 0.5).float() * 1
print(np.min(probabilities[0,:,:].to('cpu').detach().numpy()),np.max(probabilities[0,:,:].to('cpu').detach().numpy()))
# Copy to host first: NumPy cannot convert a tensor that lives on a CUDA device.
print(np.average(predicted_masks.cpu().numpy()))

figure, ax = plt.subplots(1,4, figsize=(15,15))
# NOTE(review): shows input channel 7 only; assumes the input has at least 8 channels - confirm
ax[0].imshow(T_im[0,7,:,:].cpu().detach().numpy())
ax[0].set_title('Original Image')
ax[1].imshow(T_mask[0,:,:].cpu().detach().numpy())
ax[1].set_title('Mask')
ax[2].imshow(probabilities[0,:,:].to('cpu').detach().numpy())
ax[2].set_title('Probabilities')
ax[3].imshow(predicted_masks[0,:,:].cpu().detach().numpy())
ax[3].set_title('Predicted Mask');

# T_im, T_mask = next(iter(train_loader))

In [None]:
# display_image_grid_dataset(train_dataset,4)
# display_image_grid_loader(train_loader)

In [None]:
# # # Testing
# e = 4 # echo number
# # T_im, T_mask = next(iter(train_loader))
# # T_im, T_mask = T_im.to(device,dtype=torch.float), T_mask.to(device,dtype=torch.float) # computation on "floats" are faster than "double" on GPU - hence the cast to float
# # print(T_im.shape, T_mask.shape)
# # pred = unet(T_im.cuda().float())
# # print(pred.shape)

# A = pred.cpu().detach().numpy()

# f, axarr = plt.subplots(1,4)
# plt.figure(figsize=(100,50))
# axarr[0].imshow(A[0,0,:,:])
# axarr[1].imshow(A[0,1,:,:])
# axarr[2].imshow(T_im[0,e,:,:])
# axarr[3].imshow(T_mask[0,:,:])


In [None]:
# # im,mask = next(iter(train_loader))
# # im = im[0,:,:,:]
# print(np.shape(im))
# print('max:', torch.max(im), ',  min:', torch.min(im))

# bounds = [torch.min(im), torch.max(im)]
# bins=10000
# cutoff = 0.95
# cutoff1 =0.999
# sli = 4
# hist = torch.histc(im[sli,:,:].flatten(),bins=bins, min=bounds[0], max=bounds[1])
# # normalize histogram to sum to 1
# hist = hist.div(hist.sum())
# # calculate the bin edges
# bin_edges = torch.linspace(bounds[0], bounds[1], steps=bins)
# # plotting
# plt.plot(bin_edges.cpu().detach().numpy(), hist.cpu().detach().numpy())
# plt.title('Original Hist')

# # plt.figure()
# s = np.cumsum(hist)
# plt.figure()
# plt.plot(bin_edges.cpu().detach().numpy(), s.cpu().detach().numpy())
# plt.title('Cumalitive sum plot (orig)')

# binednp = bin_edges.cpu().detach().numpy()
# histnp = hist.cpu().detach().numpy()
# snp = s.cpu().detach().numpy()
# maxval = binednp[snp<cutoff][-1]
# print(maxval, np.max(snp))


# plt.figure()
# nhist = torch.histc(im[sli,:,:].flatten(),bins=bins, min=bounds[0], max=maxval)
# nhist = nhist.div(nhist.sum())
# plt.plot(bin_edges.cpu().detach().numpy(), nhist)
# plt.title('After hist cutoff')

# # normalized
# image = copy.deepcopy(im)
# image = (image - image.mean())/image.std()

# # # normalized and histogramed
# bounds1 = [torch.min(image), torch.max(image)]
# hist1 = torch.histc(image[sli,:,:].flatten(),bins=bins, min=bounds1[0], max=bounds1[1])
# # normalize histogram to sum to 1
# hist1 = hist1.div(hist1.sum())
# # calculate the bin edges
# bin_edges1 = torch.linspace(bounds1[0], bounds1[1], steps=bins)
# # plotting
# plt.figure()
# plt.plot(bin_edges1.cpu().detach().numpy(), hist1.cpu().detach().numpy())
# plt.title('Normalized Hist')
# # calc cut off
# binednp1 = bin_edges1.cpu().detach().numpy()
# histnp1 = hist1.cpu().detach().numpy()
# s1 = np.cumsum(hist1)
# snp1 = s1.cpu().detach().numpy()
# maxval1 = binednp1[snp1<cutoff][-1]
# print(maxval1, np.max(snp1))
# # plotting
# plt.figure()
# nhist1 = torch.histc(image[sli,:,:].flatten(),bins=bins, min=bounds1[0], max=maxval1)
# nhist1 = nhist1.div(nhist1.sum())
# plt.plot(bin_edges1.cpu().detach().numpy(), nhist1)
# plt.title('Normalized and cut-off histogram')


# im_cpu = im.cpu().detach().numpy()
# ime = copy.deepcopy(im.cpu().detach().numpy())
# ime[im_cpu>maxval] = maxval
# image_cpu = image.cpu().detach().numpy()
# imagee = copy.deepcopy(im.cpu().detach().numpy())
# imagee[image_cpu>maxval1] = maxval1
# f,ax = plt.subplots(1,4,figsize=[15,15]);
# ax[0].imshow(im_cpu[sli,:,:]);
# ax[0].title.set_text('original')
# ax[1].imshow(ime[sli,:,:]);
# ax[1].title.set_text('hist cutoff')
# ax[2].imshow(image[sli,:,:])
# ax[2].title.set_text('normalized')
# ax[3].imshow(imagee[sli,:,:])
# ax[3].title.set_text('hist cutoff & normalized')

In [None]:
# """Testing"""
# all_dataset = CustomDataset(folder_data, glob.glob(folder_mask + '/*.mat'))
# all_loader = DataLoader(all_dataset,batch_size=batchsize, shuffle=True)
# numtotaldata = len(all_dataset)

# train_ds, valid_ds = torch.utils.data.random_split(all_dataset, (int(np.round(numtotaldata*split_train)), int(np.round(numtotaldata*split_valid))))
# train_dl = DataLoader(train_ds, batch_size=batchsize, shuffle=True)
# valid_dl = DataLoader(valid_ds, batch_size=batchsize, shuffle=True)
# print(len(train_ds), len(valid_ds))

# xb, yb = next(iter(train_dl))
# xb.shape, yb.shape

# # Testing
# # T_im, T_mask = next(iter(train_dataset))
# T_im, T_mask = next(iter(train_dl))
# T_im, T_mask = T_im.to(device,dtype=torch.float), T_mask.to(device,dtype=torch.float) # computation on "floats" are faster than "double" on GPU - hence the cast to float
# print(T_im.shape, T_mask.shape)
# pred = unet(T_im)
# print(pred.shape)

# ## to clear up cache in GPU?
# import gc
# gc.collect()
# torch.cuda.empty_cache()