# Imporation

In [4]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [5]:
import numpy as np 
import pandas as pd 

import SimpleITK as sitk
from ipywidgets import interact, fixed
from tqdm import tqdm 
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from torch.autograd import Variable

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup
import albumentations as A 
from collections import OrderedDict
import random
import gc


from loss.ssim import * 
from models.UNet import *
from datasets.merging_dataset import * 

In [6]:
# SEED Everything 

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Model Description

## Generative adversarial network (GAN)

The generative adversarial network, consists of two components: a generator G and a discriminator D. 
It's a deep learning framework that trains the generative model and the discriminative model alternately. The general idea of the generative adversarial network is an adversarial process amongst the models pitting against each other to improve the performance of the networks, where the generator counterfeits the sample images to deceive the discriminator and the discriminative model determines whether the images are
fake or not. 

In our case, the generative model will reconstitute a 3T HR like image of the hippocampus area in the brain. The adversarial model will try to discriminate the generative images from the ground truth 3T HR.

![](https://malekmechergui.github.io/work/model%20gan.png)

## Generative network

Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox: https://arxiv.org/abs/1505.04597

![](https://malekmechergui.github.io/work/Image1.png)

UNet was first model designed especially for medical image segmentation. It showed such good results that it used in many other fields after.

The main idea behind CNN is to learn the feature mapping of an image and exploit it to make more nuanced feature mapping. This works well in classification problems as the image is converted into a vector which used further for classification. But in image segmentation, we not only need to convert feature map into a vector but also reconstruct an image from this vector. This is a mammoth task because it’s a lot tougher to convert a vector into an image than vice versa. The whole idea of UNet is revolved around this problem.

While converting an image into a vector, we already learned the feature mapping of the image so why not use the same mapping to convert it again to image. This is the recipe behind UNet. Use the same feature maps that are used for contraction to expand a vector to a segmented image. This would preserve the structural integrity of the image which would reduce distortion enormously. 

We will be using this UNet architecture as an Generative network in our Model.

## Adversarial network 

The adversarial model takes the 3T HR hippocampal image or 3T HR like produced by the generative model as input. A binary class probability
is produced to determine whether the image is the ground truth 3T HR or not.

![](https://malekmechergui.github.io/work/Adversial.png)

## Loss Function 

The Loss function formula of the generative adversarial network is as follows:


$$ l(\:\theta_g ,\:\theta_a\:) = \sum_{n=1}^{N} \: \:\:l_{mce}(g(x_n),y_n) - [ \: l_{bce}(\:a(x_n,y_n),1\:) + l_{bce}(\:a(x_n,g(x_n)),0\:)\:]$$

Where :
* $x_n$ denotes input 3T LR image with the size of H×W
* $y_n$ corresponding 3T HR image of $x_n$ 
* $g(x) $  denotes the class probability map over M classes
* $a(x, y)∈[0,1]$  denotes the probability predicted by the adversarial model
*  $\theta_g$, $\theta_a$ are the parameters of the generative model and of the adversarial model respectively.
* $ l_{mce}( \hat{y} , y ) = - \sum_{i=1}^{H \times W }\: \sum_{m=1}^{M} \: y_{im} ln (\hat{y_{im}}) $ : denotes the multiclass cross-entropy loss 

* $ l_{bce}(\hat{a} ,a ) = = − a \: ln (\hat{a}) + (1 − a) \:ln( 
1 − \hat{a}) $ : denotes the binary cross entropy loss . 

When training the generative model, we minimize the loss with respect to $θ_g$, while maximizing it with respect to $θ_a$ when training the adversarial model.

# Model implementation 

## Genrative Network (UNet)


In [7]:
#unet

## Adversarial 

In [8]:
class Adeversarial(nn.Module) : 
    def __init__(self) : 
        super(Adeversarial, self).__init__()
        
        self.image_encoder = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 128 , kernel_size = 5), 
            nn.ReLU(inplace=True), 
            nn.MaxPool2d(kernel_size = 2 , stride=2),
        
            nn.Conv2d(128, 128, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 256, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(256, 512, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 2, kernel_size=3),
            nn.MaxPool2d(kernel_size=3, stride=3),
            nn.Sigmoid()
        
        )
        
    def forward(self, image):
        y = self.image_encoder(image)
        return(y[:,:,0,0])           

In [9]:
gc.collect()

68

# Loss function

In [10]:
#ssim

In [11]:
def cross_entropy_loss(y_hat, y) : 
    return nn.CrossEntropyLoss()(y_hat, y)

In [12]:
def loss_fn (img1, img2):
    return 1-SSIM()(img1, img2)

# Train function

In [13]:
def train_fn( data_loader , netG , netD , optimizerG , optimizerD ):
    
    netG.train()
    netD.train()
   
    D_loss = 0 
    G_loss = 0 
    counter = 0 
    train_ssim = 0 
    
    if verbose : 
        tk0 = tqdm(enumerate(data_loader), total=len(data_loader))
    else : 
        tk0 = enumerate(data_loader)
    for bi, d in tk0 :   
 
        
        y = d["HR"].to(device, dtype=torch.float)
        x = d["LR"].to(device, dtype=torch.float) 
        BATCH_SIZE = y.shape[0]
        
        ### Train The Genrative model 
        netG.zero_grad()    

        label = Variable( (torch.ones(BATCH_SIZE)).to(device , dtype = torch.long) ) 
        generated_reconstruction = netG(x.unsqueeze(1)) 
        output_g = netD(generated_reconstruction.detach())

        lossG_seg = loss_fn(generated_reconstruction, y.unsqueeze(1))
        train_ssim += lossG_seg.item()
        lossG_class = cross_entropy_loss(output_g, label)
        lossG = lossG_seg*0.6 + lossG_class*0.4
        G_loss += lossG.item() 
        lossG.backward(retain_graph=True)
        optimizerG.step()
        
        ### Train the Adversial model 
        netD.zero_grad()    
      
        label  = Variable((torch.ones(BATCH_SIZE)).to(device , dtype = torch.long))
        output = netD(y.unsqueeze(1))
        lossD_real  = cross_entropy_loss(output, label)
        
        label = Variable((torch.zeros(BATCH_SIZE)).to(device, dtype = torch.long))
        lossD_fake  = cross_entropy_loss(output_g, label)
        
        lossD = (lossD_real + lossD_fake)/2
        D_loss += lossD.item()
        lossD.backward(retain_graph=True)
        optimizerD.step()

       
      
        counter += 1 
        
    return G_loss/counter , D_loss/counter  , train_ssim/counter

## Evaluation function

In [14]:
def eval_fn(data_loader, model):
    model.eval()
    tr_loss = 0
    counter = 0

    if verbose : 
        tk0 = tqdm(enumerate(data_loader), total=len(data_loader))
    else : 
        tk0 = enumerate(data_loader)
    with torch.no_grad():
        
        for bi, d in tk0 :
       
            y = d["HR"].to(device, dtype=torch.float)
            x = d["LR"].to(device, dtype=torch.float) 
            x = x.unsqueeze(1)
            
            y_hat = model(x) #forward prop
            y = y.unsqueeze(1)
            
            loss = loss_fn(y_hat, y) # Loss calaculation of batch i 
            
            
            tr_loss += loss.item()
            counter +=1 

        
            
        return tr_loss/counter 

# Run fucntion

In [15]:
def run(netD,netG,TRAIN_BATCH_SIZE, VALID_BATCH_SIZE):
    
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle = True , 
        batch_size=TRAIN_BATCH_SIZE,
        num_workers=8
    )
    
    
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=VALID_BATCH_SIZE,
        num_workers=4
    )

    

    # Setup Optimizer for G and D
    optimizerD = optim.Adam(netD.parameters(), lr=lr)
    optimizerG = optim.Adam(netG.parameters(), lr=lr)

    Dloss = []
    Gloss = []
    trainloss = []
    valoss = []
    
    best_validation_dsc = 1.0
    patience = 0 
    
    for epoch in range(EPOCHS):
        if verbose : 
            print(f'--------- Epoch {epoch} ---------')
        elif epoch % 10 == 0 : 
            print(f'--------- Epoch {epoch} ---------')
       
        G_loss, D_loss, train_ssim  = train_fn(train_data_loader, netG , netD , optimizerG , optimizerD)
        if verbose : 
            print(f"D_loss  = {D_loss}       ,     G_loss  = {G_loss}   ,        train_ssim = {train_ssim}")
        elif epoch % 10 == 0 : 
            print(f"D_loss  = {D_loss}       ,     G_loss  = {G_loss}   ,        train_ssim = {train_ssim}")
    
        
        val  = eval_fn(valid_data_loader, netG)
        if verbose : 
            print(f" val_loss  = {val}")
        elif epoch % 10 == 0 :
            print(f" val_loss  = {val}")
            
        
        if val < best_validation_dsc : 
            best_validation_dsc =val 
            patience = 0 
            torch.save(netG.state_dict(), 'netG.pt')
            torch.save(netD.state_dict(), 'netD.pt')

        else : 
            patience +=1
        
        if patience>20 : 
            print(f'Eraly Stopping on Epoch {epoch}')
            print(f'Best Loss =  {best_validation_dsc}')
            break
        
        Dloss.append(D_loss)
        Gloss.append(G_loss)
        trainloss.append(train_ssim)
        valoss.append(val)
        
    netG.load_state_dict(torch.load('netG.pt'), strict=False)
    netD.load_state_dict(torch.load('netD.pt'), strict=False)
    
    return Dloss, Gloss, trainloss, valoss


# Prep train

In [16]:
data = pd.read_csv('data_5fold.csv')
subjects = data[data['slice']==0]

## Training on Left hippocampus 

In [17]:
# hyper param
lr = 2e-4
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 16
EPOCHS = 200
device = torch.device('cuda')
verbose = False

In [None]:
# 5 folds cross validation
for f in range(5) : 
    
    df_train = data[data['kfold'] !=f]
    df_valid = data[data['kfold'] ==f]
    train_dataset = Merging_data_set(df_train, subjects, Left = True , is_train = True)
    valid_dataset = Merging_data_set(df_valid, subjects, Left = True , is_train  = False)
    netG = UNet(1,1)
    netD = Adeversarial()
    netG = netG.to(device)
    netD = netD.to(device)
    Dloss, Gloss, trainloss, valoss = run(netD,netG, TRAIN_BATCH_SIZE, VALID_BATCH_SIZE)
    torch.save(netG.state_dict(), f'trained_model/GAN_reconstruction/GAN Left fold {f}.pt')

[get_training_augmentation]  resize_to: (160, 160)
--------- Epoch 0 ---------
D_loss  = 0.6576804381608963       ,     G_loss  = 0.5379263031482696   ,        train_ssim = 0.38653976798057554
 val_loss  = 0.28347673296928405
--------- Epoch 10 ---------
D_loss  = 0.37061293005943297       ,     G_loss  = 0.6677623766660691   ,        train_ssim = 0.28270873069763186
 val_loss  = 0.2696079325675964
--------- Epoch 20 ---------
D_loss  = 0.6930881154537201       ,     G_loss  = 0.43756081342697145   ,        train_ssim = 0.2670405423641205
 val_loss  = 0.26286670446395877
--------- Epoch 30 ---------
D_loss  = 0.6782529121637344       ,     G_loss  = 0.47150147438049317   ,        train_ssim = 0.2603526383638382
 val_loss  = 0.2538441264629364
--------- Epoch 40 ---------
D_loss  = 0.6929100465774536       ,     G_loss  = 0.4285150998830795   ,        train_ssim = 0.2515767413377762
 val_loss  = 0.24993059039115906
--------- Epoch 50 ---------
D_loss  = 0.6930287629365921       ,     G_

In [None]:
plt.plot(Dloss)
plt.plot(Gloss)
plt.plot(trainloss)
plt.plot(valoss)
plt.title('Training on Left Hippocampus')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Discriminator','Generator','Train', 'Validation'], loc='upper right')
plt.show()

In [None]:
gc.collect()

## Training on Right hippocampus 

In [None]:
# 5 folds cross validation
#for f in range(5) : 
#    
#    df_train = data[data['kfold'] !=f]
#    df_valid = data[data['kfold'] ==f]
#    train_dataset = Merging_data_set(df_train  , subjects, Left = False , is_train = True)
#    valid_dataset = Merging_data_set(df_valid  , subjects,  Left = False , is_train  = False)
#    netG = UNet(1,1)
#    netD = Adeversarial()
#    netG = netG.to(device)
#    netD = netD.to(device)
#    Dloss, Gloss, trainloss, valoss = run(netD,netG,TRAIN_BATCH_SIZE, VALID_BATCH_SIZE)
#    torch.save(netG.state_dict(), f'GAN Right fold {f}.pt')