In [None]:
#install dependencies
!pip install albumentations==0.4.6

In [None]:
# Select which architectures to train (True = train it in the cells below).
# On a CUDA out-of-memory error, reduce the batch size and/or enable fewer models.
unet = atten_unet = True
uR2net = uA2net = u3net = False  # uA2net is the attention variant of R2U-Net

In [None]:
#imports 
import numpy as np
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2, ToTensor
from sklearn.model_selection import train_test_split

plt.style.use("fivethirtyeight")

# Seed the RNGs this notebook uses (python `random`, numpy, torch) so weight
# init and DataLoader shuffling are repeatable across runs. GPU kernels and
# DataLoader worker processes may still introduce some nondeterminism.
SEED = 44
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

#set device: first CUDA GPU when available, otherwise CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Dataset location and the offsets used to slice the slice number out of a path.
# Files are laid out as  BASE_PATH/<patient>/<patient>_<slice>.tif  and
# BASE_PATH/<patient>/<patient>_<slice>_mask.tif. <patient> is assumed to be a
# fixed-width id like the example below; the constant offset relies on that.
BASE_PATH = "/kaggle/input/lgg-mri-segmentation/kaggle_3m"
_PATIENT_ID_EXAMPLE = "TCGA_CS_4941_19960909"
# Index where <slice> starts: BASE_PATH + "/" + id + "/" + id + "_".
# It is derived from BASE_PATH (89 for the default path) so it tracks path edits.
BASE_LEN = len(BASE_PATH) + 2 * (len("/") + len(_PATIENT_ID_EXAMPLE)) + len("_")
END_LEN = len(".tif")            # characters after <slice> in an image path
END_MASK_LEN = len("_mask.tif")  # characters after <slice> in a mask path
IMG_SIZE = 512  # NOTE(review): not referenced elsewhere here; resizing uses SIZE

# Utility functions

In [None]:
#display a grid of samples from a batch (first channel of each sample)
def show_aug(inputs, nrows=5, ncols=5):
    """Plot up to ``nrows * ncols`` samples from a batch tensor.

    Args:
        inputs: indexable batch whose items are tensors laid out channel-first.
            Only channel 0 of each item is drawn.
        nrows, ncols: grid shape. At most ``nrows * ncols`` items are shown.

    Returns:
        Whatever ``plt.show()`` returns (the figure is displayed).
    """
    plt.figure(figsize=(10, 10))
    plt.subplots_adjust(wspace=0., hspace=0.)

    # The grid only has nrows*ncols slots. A fixed cap of 25 would overflow
    # (subplot index out of range) for smaller grids.
    n_show = min(len(inputs), nrows * ncols)
    for idx in range(n_show):
        img = inputs[idx].detach().cpu().numpy().astype(np.float32)

        plt.subplot(nrows, ncols, idx + 1)
        plt.imshow(img[0, :, :])
        plt.axis('off')

    return plt.show()

#segmentation metric: Dice coefficient between a binary prediction and target
def dice_coef_metric(inputs, target):
    """Dice coefficient 2*|A∩B| / (|A| + |B|), computed over all elements.

    Both arguments are array-likes supporting ``*`` and ``.sum()`` (numpy
    arrays in this notebook). Returns 1.0 when both are entirely zero, so an
    empty prediction of an empty mask counts as a perfect match.
    """
    target_total = target.sum()
    inputs_total = inputs.sum()
    if target_total == 0 and inputs_total == 0:
        return 1.0
    overlap = (target * inputs).sum()
    return 2.0 * overlap / (target_total + inputs_total)

#label each record 1 if its mask has any nonzero pixel, else 0; the label is
# used to stratify the splits so each one contains both classes
def pos_neg_diagnosis(mask_path):
    """Return 1 if the mask image at ``mask_path`` has a nonzero pixel, else 0.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``mask_path``.
    """
    mask = cv2.imread(mask_path)
    if mask is None:
        # cv2.imread returns None instead of raising on unreadable paths.
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")
    return 1 if np.max(mask) > 0 else 0
    
#compute the mean Dice score of a model over a loader (despite the name, the
# per-batch score is dice_coef_metric, not IoU)
def compute_iou(model, loader, threshold=0.5):
    """Return the mean per-batch Dice coefficient of ``model`` over ``loader``.

    The models in this notebook end in a plain 1x1 conv, and BCEwithDiceLoss
    applies the sigmoid itself, so ``model`` outputs logits. A sigmoid is
    applied here before thresholding.

    Args:
        model: segmentation network called as ``model(data)``.
        loader: iterable of ``(data, target)`` batches.
        threshold: probability cut-off used to binarize the prediction.

    Returns:
        Mean of the per-batch Dice scores.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    # In eval mode BatchNorm uses its running statistics and stops updating
    # them, so validation/test data does not leak into the model.
    model.eval()
    total_dice = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for data, target in loader:
                data = data.to(device)
                target = target.to(device)

                probs = torch.sigmoid(model(data))
                out_cut = (probs >= threshold).float().cpu().numpy()
                total_dice += dice_coef_metric(out_cut, target.cpu().numpy())
                n_batches += 1
    finally:
        # Restore the caller's mode (e.g. train mode inside train_model).
        model.train(was_training)

    if n_batches == 0:
        raise ValueError("compute_iou: loader yielded no batches")
    # Divide by the batch count. The last enumerate index is one less than it.
    return total_dice / n_batches

In [None]:
def train_model(model_name, model, train_loader, val_loader, train_loss, optimizer, num_epochs):
    """Train ``model`` and record the per-epoch loss and Dice history.

    Each epoch makes one pass over ``train_loader``, taking one optimizer step
    per batch. It then scores ``val_loader`` with ``compute_iou``, which
    reports Dice.

    Args:
        model_name: label printed when training starts.
        model: network returning raw logits. ``train_loss`` receives the
            logits; a sigmoid is applied only for the thresholded train Dice.
        train_loader, val_loader: iterables of ``(data, target)`` batches.
        train_loss: callable ``(outputs, target) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        tuple of three lists: per-epoch mean training loss, mean training
        Dice, and validation Dice.
    """
    print(f"[INFO] Model is initializing... {model_name}")

    loss_history = []
    train_history = []
    val_history = []

    for epoch in range(num_epochs):
        model.train()

        losses = []
        train_iou = []  # per-batch Dice scores

        for data, target in tqdm(train_loader):
            data = data.to(device)
            target = target.to(device)

            outputs = model(data)

            # Outputs are logits, so convert them to probabilities before the
            # 0.5 cut. Thresholding the logits at 0.5 would mean p >= ~0.62.
            with torch.no_grad():
                out_cut = (torch.sigmoid(outputs) >= 0.5).float().cpu().numpy()
            train_dice = dice_coef_metric(out_cut, target.cpu().numpy())

            loss = train_loss(outputs, target)

            losses.append(loss.item())
            train_iou.append(train_dice)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_mean_iou = compute_iou(model, val_loader)

        mean_loss = np.mean(losses)
        mean_train_dice = np.mean(train_iou)
        loss_history.append(mean_loss)
        train_history.append(mean_train_dice)
        val_history.append(val_mean_iou)

        print("Epoch [%d]" % (epoch))
        print("Mean loss on train:", mean_loss,
              "\nMean DICE on train:", mean_train_dice,
              "\nMean DICE on validation:", val_mean_iou)

    return loss_history, train_history, val_history

def plot_model_history(model_name,
                        train_history, val_history, 
                        num_epochs):
    """Plot the per-epoch train and validation Dice curves of one model.

    Args:
        model_name: figure title.
        train_history, val_history: per-epoch Dice values, each of length
            ``num_epochs``.
        num_epochs: number of points on the x axis.
    """
    epochs = np.arange(num_epochs)

    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs, train_history, label='train dice', lw=3, c="springgreen")
    ax.plot(epochs, val_history, label='validation dice', lw=3, c="deeppink")

    ax.set_title(f"{model_name}", fontsize=15)
    ax.legend(fontsize=12)
    ax.set_xlabel("Epoch", fontsize=15)
    ax.set_ylabel("DICE", fontsize=15)

    # The figure is only displayed, not saved to disk. The unused timestamped
    # filename `fn` was removed.
    plt.show()
    
#segmentation loss: soft Dice loss + binary cross-entropy, both from raw logits
class BCEwithDiceLoss(nn.Module):
    """Sum of a soft Dice loss and binary cross-entropy.

    ``forward(inputs, targets, smooth=1)`` expects ``inputs`` to be raw logits
    and ``targets`` to be a float tensor of 0/1 values with the same number of
    elements. Both are flattened, so the loss is computed over the whole batch.
    """
    def __init__(self):
        super(BCEwithDiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``dice_loss + bce``.

        Args:
            inputs: logits (no sigmoid applied yet).
            targets: binary targets, same number of elements as ``inputs``.
            smooth: additive smoothing in the Dice ratio, avoiding 0/0.
        """
        # reshape (unlike view) also works on non-contiguous tensors.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # binary_cross_entropy_with_logits applies the sigmoid internally, so
        # it must receive the raw logits, not probabilities. Feeding it
        # sigmoid outputs applies the sigmoid twice.
        bce = F.binary_cross_entropy_with_logits(logits, targets)

        # Soft Dice is computed on probabilities.
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        dice = 1 - (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

        return dice + bce
    


# Preprocess Data

In [None]:
# Collect every file under BASE_PATH. Each sub-directory holds one patient's
# image slices together with their "_mask" counterparts.
data = []

for dir_ in os.listdir(BASE_PATH):
    dir_path = os.path.join(BASE_PATH, dir_)
    if os.path.isdir(dir_path):
        for filename in os.listdir(dir_path):
            img_path = os.path.join(dir_path, filename)
            data.append([dir_, img_path])
    else:
        print(f"[INFO] This is not a dir --> {dir_path}")

df = pd.DataFrame(data, columns=["dir_name", "image_path"])
df_imgs = df[~df["image_path"].str.contains("mask")]
df_masks = df[df["image_path"].str.contains("mask")]


def _patient_slice_key(path):
    """Return (patient directory, slice number) for an image or mask path.

    Handles both "<dir>/<id>_<n>.tif" and "<dir>/<id>_<n>_mask.tif".
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    if stem.endswith("_mask"):
        stem = stem[:-len("_mask")]
    return os.path.basename(os.path.dirname(path)), int(stem.rsplit("_", 1)[-1])


# Sort images and masks by the same (patient, slice) key so that position i in
# both lists refers to the same slice of the same patient, then verify it.
imgs = sorted(df_imgs["image_path"].values, key=_patient_slice_key)
masks = sorted(df_masks["image_path"].values, key=_patient_slice_key)
img_keys = [_patient_slice_key(p) for p in imgs]
if img_keys != [_patient_slice_key(p) for p in masks]:
    raise ValueError("Image and mask files do not pair up one-to-one")

#create new dataset based on patients. The patient id is taken from each
# sorted image path; df_imgs.dir_name is in directory-listing order and would
# be misaligned with the sorted paths.
dff = pd.DataFrame({"patient": [patient for patient, _ in img_keys],
                   "image_path": imgs,
                   "mask_path": masks})
dff["diagnosis"] = dff["mask_path"].apply(lambda x: pos_neg_diagnosis(x))
print("Amount of patients: ", len(set(dff.patient)))
print("Amount of records: ", len(dff))

# Data Augmentation

In [None]:
SIZE = 128  # side length (pixels) that every image and mask is resized to

# Training pipeline: resize, random horizontal/vertical flips (applied to the
# image and its mask together), normalization, then conversion to tensors.
transforms = A.Compose(
    [
        A.Resize(height=SIZE, width=SIZE, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Normalize(p=1.0),
        ToTensor(),
    ]
)

In [None]:
#Dataset: one (image, mask) pair per DataFrame row, passed through `transforms`
class BrainMRIDataset(Dataset):
    """Map-style dataset of brain MRI slices and their segmentation masks.

    Args:
        df: DataFrame whose column 1 holds image paths and column 2 mask paths
            (the patient / image_path / mask_path / diagnosis layout of dff).
        transforms: albumentations-style callable invoked as
            ``transforms(image=..., mask=...)`` that returns a dict with
            ``"image"`` and ``"mask"`` entries.
    """
    def __init__(self, df, transforms):
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_path = self.df.iloc[idx, 1]
        mask_path = self.df.iloc[idx, 2]
        image = cv2.imread(image_path)   # default flag: 3-channel BGR
        mask = cv2.imread(mask_path, 0)  # flag 0: single-channel grayscale
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        augmented = self.transforms(image=image, mask=mask)
        return augmented["image"], augmented["mask"]

# Split Data and DataLoaders

In [None]:
# Stratify on `diagnosis` so every split keeps the positive/negative ratio.
# First hold out 10% as test, then 10% of the remainder as validation.
train_df, test_df = [
    part.reset_index(drop=True)
    for part in train_test_split(dff, stratify=dff.diagnosis, test_size=0.1, random_state=44)
]
train_df, val_df = [
    part.reset_index(drop=True)
    for part in train_test_split(train_df, stratify=train_df.diagnosis, test_size=0.1, random_state=44)
]

print(f"Train: {train_df.shape} \nVal: {val_df.shape} \nTest: {test_df.shape}")

In [None]:
b_size=16

# Validation and test must be deterministic. They use the same resize and
# normalization as training, but without the random flips.
eval_transforms = A.Compose([
    A.Resize(width=SIZE, height=SIZE, p=1.0),
    A.Normalize(p=1.0),
    ToTensor(),
])

train_dataset = BrainMRIDataset(train_df, transforms=transforms)
train_dataloader = DataLoader(train_dataset, batch_size=b_size, num_workers=2, shuffle=True)

val_dataset = BrainMRIDataset(val_df, transforms=eval_transforms)
val_dataloader = DataLoader(val_dataset, batch_size=b_size, num_workers=2, shuffle=False)

test_dataset = BrainMRIDataset(test_df, transforms=eval_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=b_size, num_workers=2, shuffle=False)

In [None]:
# Display one augmented training batch: the first channel of each image, then
# the masks. Distinct names keep the `masks` path list from preprocessing intact.
sample_images, sample_masks = next(iter(train_dataloader))
print(sample_images.shape, sample_masks.shape)

show_aug(sample_images)
show_aug(sample_masks)

# Model Blocks

In [None]:
class conv_block(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Maps ``ch_in`` channels to ``ch_out`` channels. Padding 1 with stride 1
    keeps the spatial size unchanged.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        # The first stage reads ch_in channels, the second reads ch_out.
        for stage_in in (ch_in, ch_out):
            layers += [
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

    
    
class up_conv(nn.Module):
    """Upsample by ``scalefactor``, then apply 3x3 conv -> BatchNorm -> ReLU.

    ``mode_`` and ``align_corners_`` are forwarded to ``nn.Upsample``. The
    conv maps ``ch_in`` channels to ``ch_out``.
    """

    def __init__(self, ch_in, ch_out, scalefactor=2, mode_='nearest', align_corners_=None):
        super().__init__()
        upsample = nn.Upsample(scale_factor=scalefactor, mode=mode_, align_corners=align_corners_)
        project = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.up = nn.Sequential(upsample, project, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.up(x)
    
 

 #for r2unet     
class Rec_block(nn.Module):
    """Recurrent (3x3 conv -> BatchNorm -> ReLU) unit with shared weights (R2U-Net).

    ``forward`` applies the same conv stack ``t + 1`` times: once to ``x``,
    then ``t`` times to ``x + previous_output``. Channel count and spatial
    size are preserved.

    Raises:
        ValueError: if ``t < 1``. forward needs at least one iteration to
            produce an output; without this check it failed later with
            UnboundLocalError.
    """
    def __init__(self,ch_out,t=2):
        super(Rec_block,self).__init__()
        if t < 1:
            raise ValueError(f"Rec_block requires t >= 1, got t={t}")
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )
    def forward(self,x):
        for i in range(self.t):

            if i==0:
                # Prime the recurrent state from the input alone.
                x1 = self.conv(x)
            
            x1 = self.conv(x+x1)
        return x1
        
class RRCNN_block(nn.Module):
    """Recurrent residual conv block used by R2U-Net.

    A 1x1 conv first projects ``ch_in`` to ``ch_out`` channels. Two stacked
    Rec_blocks (each with recurrence ``t``) then run on the projection, and
    their output is added back to it as a residual.
    """

    def __init__(self, ch_in, ch_out, t=2):
        super().__init__()
        self.RCNN = nn.Sequential(*(Rec_block(ch_out, t=t) for _ in range(2)))
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        projected = self.Conv_1x1(x)
        return projected + self.RCNN(projected)
    
class Attention_block(nn.Module):
    """Additive attention gate (Attention U-Net).

    Args:
        F_g: channels of the gating signal ``g`` (the upsampled decoder feature).
        F_l: channels of the skip feature ``x``.
        F_int: channels of the intermediate projection.

    ``forward(g, x)`` returns ``x`` scaled per pixel by a single-channel map in
    [0, 1], computed as sigmoid(BN(conv(relu(W_g g + W_x x)))).
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        attention = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        return x * attention
    
# Full-scale decoder node of UNet 3+ (used by UNet3).
class ThreePlusDecoder(torch.nn.Module):
    """One decoder level of UNet 3+ with full-scale skip connections.

    The node at ``level`` (1 = shallowest, ``unet_depth - 1`` = deepest
    decoder) builds one operation per input. Each operation maps its input to
    64 channels at this level's resolution, in this order:

    1. the bottleneck (``center``), upsampled by ``2 ** (unet_depth - level)``;
    2. each deeper decoder output (deepest first), upsampled to this level.
       Their in_channels are hard-coded to 320 (= 64 * 5), which is only
       correct for ``unet_depth=5``;
    3. the encoder output at this same level, through ``conv_block`` only;
    4. each shallower encoder output (nearest level first), max-pooled down by
       ``2 ** i`` before a ``conv_block``.

    That is ``unet_depth`` inputs in total. Their outputs are concatenated
    (``64 * unet_depth`` channels) and passed through 3x3 conv, BatchNorm and
    ReLU.

    Encoder channel counts are derived by halving ``center_out_channels`` once
    per level (1024 -> 512 -> 256 -> 128 -> 64 in UNet3).

    Args:
        center_out_channels: channels of the bottleneck feature map.
        level: decoder level, expected in ``1 .. unet_depth - 1``.
        unet_depth: number of resolution levels, bottleneck included.
    """
    def __init__(self, center_out_channels, level, unet_depth=5):
        super(ThreePlusDecoder, self).__init__()

        self.unet_depth = unet_depth 
        # Stores the MaxPool2d class itself; it is instantiated per input below.
        self.max_pool = torch.nn.MaxPool2d
        self.operation_list = torch.nn.ModuleList()
        
        #from the center (bottleneck) to the decoder in consideration
        center_up_factor = int(2 ** (self.unet_depth - level))
        self.operation_list.append(
            up_conv(center_out_channels,64,scalefactor=center_up_factor,mode_='bilinear',align_corners_=True))

        #from previous (deeper) decoders to the decoder in consideration
        for i in range(1, self.unet_depth - level):
            up_scaling_factor = int(2 ** (self.unet_depth - level - i))
            # NOTE(review): 320 == 64 * 5 decoder channels; assumes unet_depth == 5.
            in_channels = 320
            self.operation_list.append(
                up_conv(in_channels,64,scalefactor=up_scaling_factor,mode_='bilinear',align_corners_=True))

        #from the same encoder level to the decoder in consideration
        current_scale = int(2 ** (self.unet_depth - level))
        current_in_channels = center_out_channels // current_scale
        # The same module object is also kept in operation_list.
        self.current_level_operation = conv_block(current_in_channels, 64)
        self.operation_list.append(self.current_level_operation)
        
        #from shallower encoders (above) to the decoder in consideration
        for i in range(1, level):
            kernel_size = int(2 ** i)
            in_channels = current_in_channels // int(2 ** i)
            self.operation_list.append(
                torch.nn.Sequential(
                    self.max_pool(kernel_size),
                    conv_block(in_channels, 64),
                )
            )
            

        self.final =  torch.nn.Conv2d(64 * self.unet_depth, 64 * self.unet_depth,kernel_size=(3, 3), padding=(1, 1))
        self.relu = torch.nn.ReLU(inplace=True)
        self.batchnorm=torch.nn.BatchNorm2d(64 * self.unet_depth)

    def forward(self, *args):
        # args must follow operation_list order: center, deeper decoders
        # (deepest first), same-level encoder, shallower encoders (nearest first).
        out_list = []
        for idx, element in enumerate(args):
            out_list.append(self.operation_list[idx](element))

        x = torch.cat(out_list, dim=1)
        x= self.final(x)
        return self.relu(self.batchnorm(x))

# U-NET Model

In [None]:
class U_Net(nn.Module):
    """Classic 5-level U-Net.

    The encoder uses conv_blocks of 64 -> 1024 channels with 2x max-pooling;
    the decoder uses up_conv followed by a conv_block on the skip concatenation.
    Returns ``output_ch`` raw logits at the input resolution. Input height and
    width should be divisible by 16, since four 2x pools must be undone
    exactly for the concatenations to line up.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: channel width doubles at each level.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder: upsample, concatenate with the equal-width skip, then fuse.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: pool before every level except the first and keep every output.
        skips = []
        feat = x
        for level, encode in enumerate((self.Conv1, self.Conv2, self.Conv3, self.Conv4, self.Conv5)):
            if level:
                feat = self.Maxpool(feat)
            feat = encode(feat)
            skips.append(feat)

        # Decoder: start from the bottleneck and merge skips x4, x3, x2, x1.
        stages = (
            (self.Up5, self.Up_conv5),
            (self.Up4, self.Up_conv4),
            (self.Up3, self.Up_conv3),
            (self.Up2, self.Up_conv2),
        )
        for (upsample, fuse), skip in zip(stages, reversed(skips[:-1])):
            feat = fuse(torch.cat((skip, upsample(feat)), dim=1))

        return self.Conv_1x1(feat)

# Attention U-NET Model

In [None]:
class AttentionUNet(nn.Module):
    """U-Net whose skip connections pass through additive attention gates.

    Same encoder/decoder layout as U_Net. Before each concatenation the skip
    feature is re-weighted by an Attention_block, gated by the upsampled
    decoder feature. Returns ``output_ch`` raw logits at the input resolution
    (input height/width should be divisible by 16).
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        # Decoder: upsample, gate the skip, concatenate, fuse.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: pool before every level except the first and keep every output.
        skips = []
        feat = x
        for level, encode in enumerate((self.Conv1, self.Conv2, self.Conv3, self.Conv4, self.Conv5)):
            if level:
                feat = self.Maxpool(feat)
            feat = encode(feat)
            skips.append(feat)

        # Decoder: the upsampled feature gates the matching skip (x4..x1).
        stages = (
            (self.Up5, self.Att5, self.Up_conv5),
            (self.Up4, self.Att4, self.Up_conv4),
            (self.Up3, self.Att3, self.Up_conv3),
            (self.Up2, self.Att2, self.Up_conv2),
        )
        for (upsample, gate, fuse), skip in zip(stages, reversed(skips[:-1])):
            up = upsample(feat)
            feat = fuse(torch.cat((gate(g=up, x=skip), up), dim=1))

        return self.Conv_1x1(feat)

# R2U-NET

In [None]:
class R2U_Net(nn.Module):
    """Recurrent residual U-Net (R2U-Net).

    This is the U_Net layout with every conv_block replaced by an RRCNN_block
    (recurrence ``t``). Returns ``output_ch`` raw logits at the input
    resolution (input height/width should be divisible by 16).
    """

    def __init__(self, img_ch=3, output_ch=1, t=2):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Registered but not used in forward (up_conv does its own upsampling).
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: pool before every level except the first and keep every output.
        skips = []
        feat = x
        for level, encode in enumerate((self.RRCNN1, self.RRCNN2, self.RRCNN3, self.RRCNN4, self.RRCNN5)):
            if level:
                feat = self.Maxpool(feat)
            feat = encode(feat)
            skips.append(feat)

        # Decoder: start from the bottleneck and merge skips x4, x3, x2, x1.
        stages = (
            (self.Up5, self.Up_RRCNN5),
            (self.Up4, self.Up_RRCNN4),
            (self.Up3, self.Up_RRCNN3),
            (self.Up2, self.Up_RRCNN2),
        )
        for (upsample, fuse), skip in zip(stages, reversed(skips[:-1])):
            feat = fuse(torch.cat((skip, upsample(feat)), dim=1))

        return self.Conv_1x1(feat)

# Attention R2U-NET

In [None]:
class AttentionR2U_Net(nn.Module):
    """R2U-Net with additive attention gates on the skip connections.

    Returns ``output_ch`` raw logits at the input resolution (input
    height/width should be divisible by 16).
    """

    def __init__(self, img_ch=3, output_ch=1, t=2):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Registered but not used in forward (up_conv does its own upsampling).
        self.Upsample = nn.Upsample(scale_factor=2)

        # Encoder
        self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
        self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
        self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
        self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
        self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)

        # Decoder: upsample, gate the skip, concatenate, fuse.
        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)

        self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: pool before every level except the first and keep every output.
        skips = []
        feat = x
        for level, encode in enumerate((self.RRCNN1, self.RRCNN2, self.RRCNN3, self.RRCNN4, self.RRCNN5)):
            if level:
                feat = self.Maxpool(feat)
            feat = encode(feat)
            skips.append(feat)

        # Decoder: the upsampled feature gates the matching skip (x4..x1).
        stages = (
            (self.Up5, self.Att5, self.Up_RRCNN5),
            (self.Up4, self.Att4, self.Up_RRCNN4),
            (self.Up3, self.Att3, self.Up_RRCNN3),
            (self.Up2, self.Att2, self.Up_RRCNN2),
        )
        for (upsample, gate, fuse), skip in zip(stages, reversed(skips[:-1])):
            up = upsample(feat)
            feat = fuse(torch.cat((gate(g=up, x=skip), up), dim=1))

        return self.Conv_1x1(feat)

# UNET 3+

In [None]:
class UNet3(torch.nn.Module):
    """UNet 3+ segmentation network with full-scale skip connections.

    Four conv_block encoder levels (64 -> 512 channels) and a 1024-channel
    bottleneck feed four ThreePlusDecoder nodes (320 channels each). A final
    3x3 conv maps the shallowest decoder output to ``output_ch`` raw logits.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        self.center_channels = 1024
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc_1 = conv_block(img_ch, 64)
        self.enc_2 = conv_block(64, 128)
        self.enc_3 = conv_block(128, 256)
        self.enc_4 = conv_block(256, 512)

        self.center = conv_block(512, self.center_channels)

        # Built deepest first, the same order in which forward runs them.
        self.dec_4 = ThreePlusDecoder(self.center_channels, level=4)
        self.dec_3 = ThreePlusDecoder(self.center_channels, level=3)
        self.dec_2 = ThreePlusDecoder(self.center_channels, level=2)
        self.dec_1 = ThreePlusDecoder(self.center_channels, level=1)

        self.final = torch.nn.Conv2d(320, output_ch, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x):
        # Encoder: pool before every level except the first.
        encoded = []
        feat = x
        for depth, encoder in enumerate((self.enc_1, self.enc_2, self.enc_3, self.enc_4)):
            if depth:
                feat = self.Maxpool(feat)
            feat = encoder(feat)
            encoded.append(feat)
        enc_1, enc_2, enc_3, enc_4 = encoded

        center = self.center(self.Maxpool(enc_4))

        # Each decoder sees the bottleneck, every deeper decoder, and the encoders.
        dec_4 = self.dec_4(center, enc_4, enc_3, enc_2, enc_1)
        dec_3 = self.dec_3(center, dec_4, enc_3, enc_2, enc_1)
        dec_2 = self.dec_2(center, dec_4, dec_3, enc_2, enc_1)
        dec_1 = self.dec_1(center, dec_4, dec_3, dec_2, enc_1)

        return self.final(dec_1)

# Training

In [None]:
num_ep: int = 100  # training epochs for each enabled model

## Attention-UNET

In [None]:
%%time
# atten_unet=False
if atten_unet==True:
    attention_unet = AttentionUNet().to(device)
    opt_atUnet = torch.optim.Adam(attention_unet.parameters())
    aun_lh, aun_th, aun_vh = train_model("Attention UNet", attention_unet, train_dataloader, val_dataloader, BCEwithDiceLoss(), opt_atUnet, num_ep)
    test_iou_aunet = compute_iou(attention_unet, test_dataloader)
    print(f"""Attention U-Net\nMean IoU of the test images - {np.around(test_iou_aunet, 2)*100}%""")
    plot_model_history("Attention U-Net", aun_th, aun_vh, num_ep)
    plt.plot(range(num_ep), aun_lh)

## UNET

In [None]:
%%time
# unet=True
if unet==True:
    unet = U_Net().to(device)
    opt_Unet = torch.optim.Adam(unet.parameters())
    un_lh, un_th, un_vh = train_model("UNet", unet, train_dataloader, val_dataloader, BCEwithDiceLoss(), opt_Unet, num_ep)
    test_iou = compute_iou(unet, test_dataloader)
    print(f"""U-Net\nMean IoU of the test images - {np.around(test_iou, 2)*100}%""")
    plot_model_history("U-Net", un_th, un_vh, num_ep)
    plt.plot(range(num_ep), un_lh)

## R2U-NET

In [None]:
%%time
# uR2net=False
if uR2net==True:
    R2unet = R2U_Net().to(device)
    opt_R2Unet = torch.optim.Adam(R2unet.parameters())
    R2un_lh, R2un_th, R2un_vh = train_model("R2U-Net", R2unet, train_dataloader, val_dataloader, BCEwithDiceLoss(), opt_R2Unet, num_ep)
    test_iouR2 = compute_iou(R2unet, test_dataloader)
    print(f"""R2U-Net\nMean IoU of the test images - {np.around(test_iouR2, 2)*100}%""")
    plot_model_history("R2U-Net", R2un_th, R2un_vh, num_ep)
    plt.plot(range(num_ep), R2un_lh)

## Attention R2U-NET

In [None]:
%%time
# uA2net=False
if uA2net==True:
    AR2unet = AttentionR2U_Net().to(device)
    opt_AR2Unet = torch.optim.Adam(AR2unet.parameters())
    AR2un_lh, AR2un_th, AR2un_vh = train_model("UNet", AR2unet, train_dataloader, val_dataloader, BCEwithDiceLoss(), opt_AR2Unet, num_ep)
    test_iouAR2 = compute_iou(AR2unet, test_dataloader)
    print(f"""Attention R2U-Net\nMean IoU of the test images - {np.around(test_iouAR2, 2)*100}%""")
    plot_model_history("Attention R2U-Net", AR2un_th, AR2un_vh, num_ep)
    plt.plot(range(num_ep), AR2un_lh)

## UNET 3+

In [None]:
%%time
# u3net=False
if u3net==True:
    unet3 = UNet3().to(device)
    opt_Unet3 = torch.optim.Adam(unet3.parameters())
    u3_lh,u3_th, u3_vh = train_model("UNet3", unet3, train_dataloader, val_dataloader, BCEwithDiceLoss(), opt_Unet3, num_ep)
    test_iou3 = compute_iou(unet3, test_dataloader)
    print(f"""U-Net 3 \nMean IoU of the test images - {np.around(test_iou3, 2)*100}%""")
    plot_model_history("U-Net 3", u3_th, u3_vh, num_ep)
    plt.plot(range(num_ep),u3_lh)