# Imports

In [2]:
import os
# Configure GPU visibility before torch is imported further down.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",  # enumerate GPUs in PCI bus order
    CUDA_VISIBLE_DEVICES="3",        # expose only GPU index 3 to this process
)

In [3]:
import gc ,random 

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import seaborn as sns
from sklearn import model_selection

import cv2
import SimpleITK as sitk
from ipywidgets import interact, fixed
from tqdm import tqdm 
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup
import albumentations as A 

from collections import OrderedDict
from sklearn.model_selection import train_test_split

from loss.dice import * 
from loss.ssim import * 
from models.UNet import *
from datasets.merging_dataset import * 

In [4]:
# Seed every RNG in use so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seed every visible CUDA device (not just the current one); silently ignored without CUDA.
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may pick different algorithms between runs; disable it for reproducibility.
torch.backends.cudnn.benchmark = False

# Model Description 

![Cascaded UNet architecture](https://nhoues.github.io/Segmentaion-and-reconstruction-MRI-/images/cunet_v1.PNG)

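# Loss functions

Segmentation outputs are trained with a soft Dice loss (`dice_loss`) and reconstruction outputs with $1 - \mathrm{SSIM}$ (`loss_fn`); both are defined in the Engine section below.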

# Model implementation

In [None]:
class CascadedUNet(nn.Module) : 
    """Two-stage cascade of UNets producing reconstructions and segmentations.

    Stage 1: ``layer_1`` maps the input image to a reconstruction ``hr_1``;
    ``layer_2`` segments ``[hr_1, image]`` into class logits ``seg_1``.
    Stage 2: ``seg_1`` is hardened into a one-hot map, concatenated with the
    image and fed to ``layer_3`` for a second reconstruction ``hr_2``, which
    ``layer_4`` segments (together with the image) into ``seg_2``.

    Args:
        n_classes: number of segmentation channels produced by ``layer_2`` and
            ``layer_4`` (default 4, the value previously hard-coded).

    Assumes ``image`` is (B, 1, H, W), matching ``UNet(1, ...)`` for layer_1
    and the channel-wise concatenations below — TODO confirm against UNet.
    """

    def __init__(self, n_classes=4):
        super(CascadedUNet, self).__init__()
        self.layer_1 = UNet(1, 1, segmentation=False)
        self.layer_2 = UNet(2, n_classes, segmentation=False)
        # one-hot segmentation (n_classes channels) + input image (1 channel)
        self.layer_3 = UNet(n_classes + 1, 1, segmentation=False)
        self.layer_4 = UNet(2, n_classes, segmentation=False)

    def forward(self, image):
        """Return ``(hr_2, seg_2, hr_1, seg_1)`` — final-stage outputs first."""
        # Stage 1: reconstruct, then segment from [reconstruction, input].
        hr_1 = self.layer_1(image)
        seg_1 = self.layer_2(torch.cat([hr_1, image], dim=1))

        # Stage 2: condition the second reconstruction on the hardened stage-1
        # segmentation. argmax is not differentiable, so no gradient reaches
        # layer_2 through this path.
        seg_t = self.segmentation_gen(seg_1)
        hr_2 = self.layer_3(torch.cat([seg_t, image], dim=1))
        seg_2 = self.layer_4(torch.cat([hr_2, image], dim=1))
        return hr_2, seg_2, hr_1, seg_1

    def segmentation_gen(self, x):
        """Turn per-pixel class logits into a float one-hot map of the same shape.

        The class dimension is dim 1 and its size (``x.shape[1]``) sets the
        number of one-hot channels.
        """
        # argmax of the logits equals argmax of their softmax, so no softmax is needed.
        labels = torch.argmax(x, dim=1)
        one_hot = F.one_hot(labels, num_classes=x.shape[1])  # class dim is last
        return one_hot.movedim(-1, 1).contiguous().float()

# Engine 

In [4]:
def loss_fn (img1, img2):
    """Reconstruction loss: ``1 - SSIM(img1, img2)`` (lower is better)."""
    ssim_metric = SSIM()
    similarity = ssim_metric(img1, img2)
    return 1 - similarity

In [5]:
def dice_loss(pred, target, smooth = 1.):
    """Soft Dice loss computed per channel and averaged over the batch.

    Args:
        pred: raw logits shaped (B, C, *spatial); a sigmoid is applied here.
        target: masks with the same shape as ``pred`` (any numeric dtype).
        smooth: additive smoothing term in numerator and denominator.

    Returns:
        ``(mean_loss, per_label)`` where ``mean_loss`` is the average over the
        C channels and ``per_label`` is a tuple of C 0-dim tensors, one per
        channel (each averaged over the batch).
    """
    probs = torch.sigmoid(pred)
    target = target.to(probs.dtype)
    spatial_dims = tuple(range(2, probs.dim()))  # every dim after (B, C)

    intersection = (probs * target).sum(dim=spatial_dims)
    denominator = probs.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    loss = 1 - (2. * intersection + smooth) / (denominator + smooth)  # (B, C)

    per_label = loss.mean(dim=0)  # (C,) — batch-averaged loss for each channel
    return per_label.mean(), tuple(per_label.unbind())

def train_fn(data_loader, model, optimizer, scheduler,device):
    """Train ``model`` for one pass over ``data_loader``; return the mean batch loss.

    Each batch loss is the average of four terms: ``dice_loss`` on both
    segmentation outputs and ``loss_fn`` (1 - SSIM) on both reconstruction
    outputs.

    Args:
        data_loader: iterable of dicts holding "label", "LR" and "HR" tensors.
        model: network whose forward returns four tensors; CascadedUNet
            returns them as (hr_2, seg_2, hr_1, seg_1).
        optimizer: stepped once per batch.
        scheduler: accepted for interface compatibility; not used here.
        device: device the batch tensors are moved to.

    Reads the notebook-global ``verbose`` to decide whether to show a tqdm bar.
    Raises ZeroDivisionError if ``data_loader`` yields no batches.
    """
    model.train()

    loss_sum = 0
    n_batches = 0

    batches = enumerate(data_loader)
    if verbose:
        batches = tqdm(batches, total=len(data_loader))

    for _, batch in batches:
        target_mask = batch["label"].to(device, dtype=torch.long)
        low_res = batch["LR"].to(device, dtype=torch.float)
        high_res = batch["HR"].to(device, dtype=torch.float).unsqueeze(1)

        optimizer.zero_grad()
        # Outputs are ordered final stage first (see CascadedUNet.forward).
        hr_final, seg_final, hr_first, seg_first = model(low_res.unsqueeze(1))

        seg_final_loss, _ = dice_loss(seg_final, target_mask)
        seg_first_loss, _ = dice_loss(seg_first, target_mask)
        rec_final_loss = loss_fn(hr_final, high_res)
        rec_first_loss = loss_fn(hr_first, high_res)
        batch_loss = (seg_final_loss + seg_first_loss + rec_final_loss + rec_first_loss) / 4

        loss_sum += batch_loss.item()
        n_batches += 1

        batch_loss.backward()
        optimizer.step()

    return loss_sum / n_batches

def eval_fn(data_loader, model , device ):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Only the first two model outputs are scored; for CascadedUNet these are
    the final-stage reconstruction and segmentation (hr_2, seg_2).

    Returns:
        ``(rec_loss, seg_loss, per_label)``: mean ``loss_fn`` (1 - SSIM) over
        batches, mean ``dice_loss`` over batches, and a tuple with the mean
        Dice loss of each label.

    Reads the notebook-global ``verbose`` to decide whether to show a tqdm bar.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    rec_loss = 0.0
    seg_loss = 0.0
    label_sums = None  # per-label Dice sums, sized from the first batch, all starting at 0
    counter = 0

    batches = tqdm(data_loader, total=len(data_loader)) if verbose else data_loader

    with torch.no_grad():
        for d in batches:
            real_mask = d["label"].to(device, dtype=torch.long)
            low_res = d["LR"].to(device, dtype=torch.float)
            high_res = d["HR"].to(device, dtype=torch.float)
            hr_out, seg_out, _, _ = model(low_res.unsqueeze(1))

            loss, labels = dice_loss(seg_out, real_mask)
            ssim_loss = loss_fn(hr_out, high_res.unsqueeze(1))

            if label_sums is None:
                label_sums = [0.0] * len(labels)
            for i, label_loss in enumerate(labels):
                label_sums[i] += label_loss.item()

            seg_loss += loss.item()
            rec_loss += ssim_loss.item()
            counter += 1

    if counter == 0:
        raise ValueError("eval_fn: data_loader yielded no batches")
    return rec_loss / counter, seg_loss / counter, tuple(s / counter for s in label_sums)

def run(model, EPOCHS , train_dataset , valid_dataset , device , LR , TRAIN_BATCH_SIZE ,VALID_BATCH_SIZE,
        fold=None, checkpoint_path='CUNet.pt', max_patience=30):
    """Train ``model`` with Adam, a per-epoch linear LR decay and early stopping.

    The validation metric is the reconstruction loss returned by ``eval_fn``
    (1 - SSIM). Whenever it improves, the weights are saved to
    ``checkpoint_path``. Training stops after more than ``max_patience``
    epochs without improvement. On return, ``model`` holds the best saved
    weights.

    Args:
        model: network to train (already on ``device``).
        EPOCHS: maximum number of epochs.
        train_dataset, valid_dataset: datasets wrapped in DataLoaders here.
        device: device passed to ``train_fn`` / ``eval_fn``; also the map_location for reloading.
        LR: initial learning rate.
        TRAIN_BATCH_SIZE, VALID_BATCH_SIZE: loader batch sizes.
        fold: fold index shown in the banner; defaults to the notebook-global ``f``.
        checkpoint_path: where the best weights are written.
        max_patience: epochs without improvement tolerated before stopping.

    Returns:
        ``(val_loss, train_loss)``: per-epoch validation and training losses.

    Reads the notebook-global ``verbose`` to control logging.
    """
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=TRAIN_BATCH_SIZE,
        num_workers=8,
    )
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=VALID_BATCH_SIZE,
        num_workers=4,
    )
    optimizer = optim.Adam(model.parameters(), lr=LR)
    # scheduler.step() is called once per epoch below, so the schedule must span
    # EPOCHS steps; sizing it by len(loader) * EPOCHS left the LR almost constant.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=EPOCHS,
    )

    if fold is None:
        # Backward compatibility: the banner previously read the global loop variable `f`.
        fold = globals().get('f', '?')

    train_loss = []
    val_loss = []
    # Start at +inf so the first finite epoch always writes a checkpoint for this run.
    best_val_loss = float('inf')
    patience = 0
    print(f'*****************************')
    print(f'********* Fold {fold} **********')
    print(f'*****************************')
    for epoch in range(EPOCHS):
        log_epoch = verbose or epoch % 10 == 0
        if log_epoch:
            print(f'--------- Epoch {epoch} ---------')

        tr_loss = train_fn(train_data_loader, model, optimizer, scheduler, device)
        train_loss.append(tr_loss)
        if log_epoch:
            print(f" train_loss  = {tr_loss}")

        rec, seg, _ = eval_fn(valid_data_loader, model, device)
        val = rec
        val_loss.append(val)
        if log_epoch:
            print(f" Segmentation  Dice  = {1-seg} , Reconstruction SSIM = {1-rec}")

        if val < best_val_loss:
            best_val_loss = val
            patience = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            patience += 1

        if patience > max_patience:
            print(f'Early Stopping on Epoch {epoch}')
            print(f'Best Loss =  {best_val_loss}')
            break
        scheduler.step()

    if best_val_loss == float('inf'):
        # No epoch produced a finite improvement, so nothing was saved by this run;
        # avoid loading a stale checkpoint left behind by an earlier run.
        print('No improving epoch; keeping the final weights.')
    else:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return val_loss, train_loss

# Model Training 

In [7]:
# Training hyper-parameters.
TRAIN_BATCH_SIZE = 64
VALID_BATCH_SIZE = 32
LR = 5e-4
EPOCHS = 250
# Use the GPU when one is visible; fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
verbose = False  # True: per-batch tqdm bars and per-epoch logs; False: log every 10th epoch

In [8]:
# Per-slice table; the 'kfold' column drives the train/validation split below.
all_data = pd.read_csv('data_5fold.csv')
# Rows whose 'slice' index is 0 — presumably one row per subject; confirm against the CSV.
subjects = all_data.loc[all_data['slice'].eq(0)]

In [None]:
# Train one CascadedUNet per fold (datasets built with Left=True) and save its weights.
MODEL_DIR = 'trained_model/Casceded UNet'  # directory name kept verbatim so existing paths still resolve
os.makedirs(MODEL_DIR, exist_ok=True)  # torch.save does not create missing directories

train_folds_loss = []
valid_folds_loss = []
for f in range(1):  # only fold 0 is trained here; widen the range to train more folds
    df_train = all_data[all_data['kfold'] != f]
    df_valid = all_data[all_data['kfold'] == f]
    Left_train_dataset = Merging_data_set(df_train, subjects, Left=True, is_train=True)
    Left_valid_dataset = Merging_data_set(df_valid, subjects, Left=True, is_train=False)
    Left_model = CascadedUNet().to(device)
    val_loss, train_loss = run(
        Left_model, EPOCHS, Left_train_dataset, Left_valid_dataset,
        device, LR, TRAIN_BATCH_SIZE, VALID_BATCH_SIZE,
    )
    train_folds_loss.append(train_loss)
    valid_folds_loss.append(val_loss)
    torch.save(Left_model.state_dict(), os.path.join(MODEL_DIR, f'Rec Cascaded Unet Left fold {f}.pt'))

[get_training_augmentation]  resize_to: (160, 160)
*****************************
********* Fold 4 **********
*****************************
--------- Epoch 0 ---------


In [None]:
def plot_(f, train_losses=None, valid_losses=None):
    """Plot the training and validation loss curves for fold ``f``.

    Args:
        f: fold index into the per-fold loss lists.
        train_losses: per-fold lists of training losses; defaults to the
            notebook-global ``train_folds_loss``.
        valid_losses: per-fold lists of validation losses; defaults to the
            notebook-global ``valid_folds_loss``.

    Note: in this notebook the training curve is the combined Dice + (1 - SSIM)
    loss, while the validation curve is the reconstruction loss (1 - SSIM) only.
    Both are losses (lower is better).
    """
    if train_losses is None:
        train_losses = train_folds_loss
    if valid_losses is None:
        valid_losses = valid_folds_loss

    fig, ax = plt.subplots()
    ax.plot(train_losses[f], label='Train')
    ax.plot(valid_losses[f], label='Val')
    ax.set_title(f'Learning curve fold {f}')
    ax.set_ylabel('loss')
    ax.set_xlabel('epoch')
    ax.legend(loc='upper right')
    plt.show()

In [None]:
plot_(f=0)  # learning curves of the single fold trained above