### Import required modules


In [None]:
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.losses import binary_crossentropy, mean_squared_error
from tensorflow.keras.optimizers import Adam 
from tensorflow.keras.backend import cast, flatten, sum
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow import Tensor
from typing import List

### Utility functions to load data

In [None]:
 def DataGenerator():
     pass

### Define the building blocks of the network

In [None]:
def DownSample(input_tensor: Tensor, filters: int, name: str) -> Tensor:
    """Encoder stage: two (3x3 conv -> ReLU) pairs followed by 2x2 max pooling.

    Args:
        input_tensor: Feature map entering this stage.
        filters: Number of filters used by both 3x3 convolutions.
        name: Prefix that makes every layer name unique within the model.

    Returns:
        The pooled feature map (spatial size halved by the stride-2 pool).
        Note that the pre-pool activations are not returned.
    """
    features = input_tensor
    for stage in (1, 2):
        features = Conv2D(
            filters=filters,
            kernel_size=(3, 3),
            strides=1,
            padding='same',
            name=f"{name}Conv{stage}",
        )(features)
        features = ReLU(name=f"{name}ReLU{stage}")(features)
    pooling = MaxPool2D(pool_size=(2, 2), strides=2, name=name + "MaxPool")
    return pooling(features)

def UpSample(input_tensor: Tensor, tensor_list: List[Tensor], filters: int, name: str) -> Tensor:
    """Build one decoder node of the nested U-Net.

    Processing steps:
    1. Upsample the coarser ``input_tensor`` by a factor of 2.
    2. Apply a 2x2 convolution + ReLU.
    3. Concatenate along the channel axis with every skip tensor in ``tensor_list``.
    4. Fuse with a 3x3 convolution + ReLU.

    Args:
        input_tensor: Feature map from the next-coarser level.
        tensor_list: Skip-connection tensors. After upsampling, ``input_tensor``
            must have the same spatial size as each of them. The list itself is
            not modified.
        filters: Number of filters for both convolutions.
        name: Prefix that makes every layer name unique within the model.

    Returns:
        The fused feature map with ``filters`` channels.
    """
    x: Tensor = UpSampling2D(size=(2, 2), name=name + "up_sample")(input_tensor)
    x = Conv2D(filters=filters, kernel_size=(2, 2), padding='same', name=name + "Conv1")(x)
    x = ReLU(name=name + "ReLU1")(x)
    # Concatenate is a layer: configure it first, then call it on the list of
    # tensors. axis=-1 joins along channels (the model input is channels-last);
    # a fresh list is built so the caller's list is left untouched.
    x = Concatenate(axis=-1, name=name + "concatenated_block")([*tensor_list, x])
    x = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', name=name + "Conv2")(x)
    # ReLU must be instantiated and then applied; ReLU(x) would treat x as max_value.
    x = ReLU(name=name + "ReLU2")(x)
    return x

def AttnGate(x, g, channels) -> Tensor:
    """Additive attention gate that re-weights ``x`` using the gating signal ``g``.

    Processing steps:
    1. Project ``x`` and ``g`` separately to ``channels`` feature maps with a
       1x1 convolution followed by batch normalisation.
    2. Sum the two projections and apply ReLU.
    3. Reduce to a single channel with a 1x1 convolution and apply a sigmoid.
    4. Multiply ``x`` by the resulting attention map.

    Assumes ``x`` and ``g`` have the same spatial size (required by the Add
    layer) -- TODO confirm at call sites.
    """
    projected_x = BatchNormalization()(
        Conv2D(channels, kernel_size=(1, 1), padding='same')(x)
    )
    projected_g = BatchNormalization()(
        Conv2D(channels, kernel_size=(1, 1), padding='same')(g)
    )
    summed = Add()([projected_x, projected_g])
    rectified = ReLU()(summed)
    logits = Conv2D(1, kernel_size=(1, 1), padding='same')(rectified)
    attention_map = Activation('sigmoid')(logits)
    return Multiply()([attention_map, x])

### Define the loss functions and optimizer

In [None]:
def dice_coef(y_true, y_pred, smooth=1)->float:
    """Smoothed Dice coefficient between ground truth and prediction.

    Both inputs are flattened to 1-D (so the whole batch is pooled together)
    and cast to float64. The function then evaluates
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    The ``smooth`` term keeps the ratio finite when both inputs are all zeros.

    Returns:
        A scalar tensor, despite the ``float`` annotation.
    """
    truth = cast(flatten(y_true), dtype='float64')
    prediction = cast(flatten(y_pred), dtype='float64')
    numerator = 2. * sum(truth * prediction) + smooth
    denominator = sum(truth) + sum(prediction) + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred)->float:
    """Dice loss: one minus ``dice_coef`` evaluated with its default smoothing."""
    coefficient = dice_coef(y_true=y_true, y_pred=y_pred)
    return 1. - coefficient

# `lr` was deprecated and is rejected by the TF >= 2.11 / Keras 3 optimizers;
# `learning_rate` is the supported keyword.
optimizer = Adam(learning_rate=0.0001)

### Build the network


In [None]:
def _output_head(features: Tensor, name: str) -> Tensor:
    """Map a finest-level decoder node to a full-resolution, 1-channel probability map."""
    # DownSample pools before returning, so the finest decoder nodes sit at half
    # the input resolution; upsample once more to match the input/mask size.
    x = UpSampling2D(size=(2, 2), name=name + "output_up_sample")(features)
    # 1x1 conv to a single channel with a sigmoid so Dice loss sees values in [0, 1].
    return Conv2D(filters=1, kernel_size=(1, 1), activation='sigmoid', name=name + "output")(x)


def UnetPlus(input_shape: tuple = (512, 512, 1), batch_size: int = 64) -> Model:
    """Build the nested U-Net (UNet++) with four deeply-supervised outputs.

    Args:
        input_shape: Per-sample input shape (height, width, channels),
            channels-last. Height and width must be divisible by 32, because
            the encoder pools five times.
        batch_size: Fixed batch dimension of the input; pass ``None`` for a
            variable batch size.

    Returns:
        A Keras ``Model`` whose outputs are ``[x04, x03, x02, x01]``. Each output
        is passed through a sigmoid head with shape (bs, height, width, 1).
    """
    ##################################  Down sampling block ########################################
    # Shape comments assume the default 512x512x1 input.
    input_img = Input(shape=input_shape, batch_size=batch_size, name="input")
    x00:Tensor = DownSample(input_tensor=input_img,filters=64,name='00')                   # shape: (bs,256,256,64)
    x10:Tensor = DownSample(input_tensor=x00,filters=128,name="10")                        # shape: (bs,128,128,128)
    x20:Tensor = DownSample(input_tensor=x10,filters=256,name='20')                        # shape: (bs,64,64,256)
    x30:Tensor = DownSample(input_tensor=x20,filters=512,name='30')                        # shape: (bs,32,32,512)
    x40:Tensor = DownSample(input_tensor=x30,filters=1024,name='40')                       # shape: (bs,16,16,1024)

    ##################################  Single Skip Connection Nodes #########################################
    x01:Tensor = UpSample(input_tensor=x10,tensor_list=[x00],filters=1, name='01')         # shape: (bs,256,256,1)
    x11:Tensor = UpSample(input_tensor=x20,tensor_list=[x10],filters=64, name='11')        # shape: (bs,128,128,64)
    x21:Tensor = UpSample(input_tensor=x30,tensor_list=[x20],filters=128, name='21')       # shape: (bs,64,64,128)
    x31:Tensor = UpSample(input_tensor=x40,tensor_list=[x30],filters=256, name='31')       # shape: (bs,32,32,256)

    ################################## Double Skip Connection Nodes ##########################################
    x02:Tensor = UpSample(input_tensor=x11,tensor_list=[x00,x01],filters=1,name='02')      # shape: (bs,256,256,1)
    x12:Tensor = UpSample(input_tensor=x21,tensor_list=[x10,x11],filters=64,name='12')     # shape: (bs,128,128,64)
    x22:Tensor = UpSample(input_tensor=x31,tensor_list=[x20,x21],filters=256, name='22')   # shape: (bs,64,64,256)

    ################################# Triple Skip Connection Nodes ##########################################
    x03:Tensor = UpSample(input_tensor=x12,tensor_list=[x00,x01,x02],filters=1,name='03')  # shape: (bs,256,256,1)
    # NOTE(review): x13 is an inner node but uses filters=1 like the output
    # nodes -- confirm this bottleneck is intended.
    x13:Tensor = UpSample(input_tensor=x22,tensor_list=[x10,x11,x12],filters=1,name='13')  # shape: (bs,128,128,1)

    ################################### FINAL NODE #########################################
    x04:Tensor = UpSample(input_tensor=x13,tensor_list=[x00,x01,x02,x03],filters=512,name='04') # shape: (bs,256,256,512)

    ################################### OUTPUT HEADS #######################################
    # Each deeply-supervised output gets its own head -> (bs,512,512,1) in [0, 1].
    outputs = [
        _output_head(node, node_name)
        for node, node_name in ((x04, '04'), (x03, '03'), (x02, '02'), (x01, '01'))
    ]

    Unetplus:Model = Model(input_img, outputs=outputs)
    return Unetplus

### Compile the model

In [None]:
SegmentationModel: Model = UnetPlus()
# One Dice loss per model output (deep supervision). The count is taken from the
# model itself so the loss list cannot drift out of sync with the architecture.
losses = [dice_loss] * len(SegmentationModel.outputs)
SegmentationModel.compile(optimizer=optimizer, loss=losses)
