<a href="https://colab.research.google.com/github/girishsenthil/Medical-Imaging-Models/blob/main/3DResidualUNetwithAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import Input, Model, models, layers, losses, metrics, optimizers
from tensorflow.keras import backend as K

### Building Blocks for Model

In [2]:
# Keyword arguments shared by the Conv3D / Conv3DTranspose helper wrappers
# in this notebook: 3x3x3 kernels, 'same' padding, He-normal initialisation.
kwargs = dict(
    kernel_size=(3, 3, 3),
    padding='same',
    kernel_initializer='he_normal',
)

In [8]:
def conv(x, filters, strides):
  '''Apply a newly created Conv3D layer to x.

     params: x        (input tensor)
             filters  (number of output feature maps)
             strides  (convolution stride, int or 3-tuple)

     Remaining layer options come from the module-level `kwargs` dict.
  '''
  layer = layers.Conv3D(filters=filters, strides=strides, **kwargs)
  return layer(x)


def tran(x, filters, strides):
  '''Apply a newly created Conv3DTranspose layer to x.

     params: x        (input tensor)
             filters  (number of output feature maps)
             strides  (transposed-convolution stride, int or 3-tuple)

     Remaining layer options come from the module-level `kwargs` dict.
  '''
  layer = layers.Conv3DTranspose(filters=filters, strides=strides, **kwargs)
  return layer(x)

In [9]:
def norm(x):
  '''Pass x through a newly created BatchNormalization layer.'''
  return layers.BatchNormalization()(x)


def relu(x):
  '''Pass x through a newly created ReLU activation layer.'''
  return layers.ReLU()(x)


def sigmoid(x):
  '''Pass x through a newly created sigmoid Activation layer.'''
  return layers.Activation('sigmoid')(x)

In [10]:
def conv1(filters, x):
  '''Conv3D (stride 1) -> BatchNorm -> ReLU.

     params: filters (number of output feature maps)
             x       (input tensor)
  '''
  return relu(norm(conv(x, filters, strides = 1)))


def conv2(filters, x):
  '''Conv3D (stride 2 on every spatial axis) -> BatchNorm -> ReLU.

     Used as the down-sampling step.

     params: filters (number of output feature maps)
             x       (input tensor)
  '''
  return relu(norm(conv(x, filters, strides = (2, 2, 2))))


def tran2(filters, x):
  '''Conv3DTranspose (stride 2 on every spatial axis) -> BatchNorm -> ReLU.

     Used as the up-sampling step.

     params: filters (number of output feature maps)
             x       (input tensor)
  '''
  # Must use the transposed convolution `tran`: calling `conv` here (as a
  # copy of conv2 would) shrinks the feature map instead of enlarging it.
  return relu(norm(tran(x, filters, strides = (2, 2, 2))))

In [11]:
def conv1_block(filters, x):
  '''Two consecutive conv1 stages with the same number of filters.

     params: filters (number of output feature maps of each stage)
             x       (input tensor)
  '''
  first_stage = conv1(filters, x)
  second_stage = conv1(filters, first_stage)
  return second_stage

In [6]:
def residual_block(filters, x):
  '''
  Residual block: a two-stage convolution path added to a 1x1x1 projection
     of the input, followed by a ReLU.

     params: filters (number of feature maps produced by the block)
             x       (input tensor; also the source of the skip connection)
  '''

  # Main path: two conv -> norm -> relu stages.
  main_path = conv1_block(filters, x)

  # Shortcut path: a 1x1x1 convolution gives x the same number of channels
  # as the main path so the two can be summed; then batch normalisation.
  projection = layers.Conv3D(filters = filters,
                             kernel_size = (1, 1, 1),
                             padding = 'same')
  shortcut = norm(projection(x))

  return relu(shortcut + main_path)


In [7]:
def gating_signal(x, out_size):
  '''
  Builds the gating signal fed to the attention module.

     params: x         (tensor used as the gate)
             out_size  (number of feature maps the 1x1x1 convolution
                        projects x to)

     return: x after Conv3D(1x1x1, stride 1) -> BatchNorm -> ReLU
  '''

  projection = layers.Conv3D(filters = out_size,
                             kernel_size = (1, 1, 1),
                             strides = (1, 1, 1),
                             padding = 'same')
  gate = projection(x)
  gate = norm(gate)
  gate = relu(gate)
  return gate

In [12]:
def attention_block(x, g):
  '''Attention Module from Oktay Paper, adapted from DigitalSreeni Youtube Channel

     params: x  (feature map to be gated; its channel count sets the width
                 of the projections in this block)
             g  (gating signal)

     return: x multiplied by a single-channel sigmoid attention map, then
             passed through a 1x1x1 convolution and batch normalisation

     Spatial dimensions of x must be statically known: the resampling
     factors are computed with integer division on the static shapes.
  '''

  x_shape = K.int_shape(x)
  channels = x_shape[-1]

  #Reduce feature map size of x

  theta_x = layers.Conv3D(filters = channels,
                          kernel_size = (1, 1, 1),
                          strides = (2, 2, 2),
                          padding = 'same')(x)

  #Project gate to be same filter size as x

  phi_g = layers.Conv3D(filters = channels,
                        kernel_size = (1, 1, 1),
                        strides = (1, 1, 1),
                        padding = 'same')(g)

  #Bring the gate onto theta_x's spatial grid, but only when the two differ.
  #The stride is measured against theta_x (the tensor phi_g is added to),
  #not against x. When the grids already match no extra layer is created.

  theta_spatial = K.int_shape(theta_x)[1:-1]
  phi_spatial = K.int_shape(phi_g)[1:-1]

  if phi_spatial != theta_spatial:
    strides = tuple(t // p for t, p in zip(theta_spatial, phi_spatial))
    phi_g = layers.Conv3DTranspose(filters = channels,
                                   kernel_size = (3, 3, 3),
                                   strides = strides,
                                   padding = 'same')(phi_g)

  xg = relu(phi_g + theta_x)

  #Collapse to a single-channel attention map in [0, 1]

  psi = sigmoid(layers.Conv3D(filters = 1,
                              kernel_size = (1, 1, 1),
                              strides = (1, 1, 1),
                              padding = 'same')(xg))

  #Upsample the attention map back to x's spatial size

  psi_spatial = K.int_shape(psi)[1:-1]
  size = tuple(i // j for i, j in zip(x_shape[1:-1], psi_spatial))

  upsampled_psi = layers.UpSampling3D(size = size)(psi)

  #psi has one channel while x has `channels`; the multiply relies on
  #broadcasting over the channel axis.

  gated = layers.multiply([upsampled_psi, x])

  res = norm(layers.Conv3D(filters = channels,
                           kernel_size = (1, 1, 1),
                           strides = (1, 1, 1),
                           padding = 'same')(gated))

  return res

In [13]:
def ResAttentionUnet(input_shape, input_dtype, filters, logit_filters):

  '''
  UNet with Residual Blocks and Attention Module

    inputs : input_shape    (shape of the 3D data input this model will be trained on)
             input_dtype    (dtype of input)
             filters        (desired initial number of filters. Deepest layer will have
                            filters*16 feature maps)
             logit_filters  (number of filters of the final layer, typically the number
                             of classes in the segmentation task)
    return : Model backbone (Keras Model mapping the input to the output of the
             final Conv3D layer; no activation is applied to that output)
  '''

  x = Input(shape = input_shape, dtype = input_dtype)

  # Encoder: each level is a residual block followed by a strided
  # convolution (conv2) and a two-stage convolution block.

  l1 = residual_block(filters, x)
  l2 = conv1_block(filters, conv2(filters, l1))

  l3 = residual_block(filters*2, l2)
  l4 = conv1_block(filters*2, conv2(filters*2, l3))

  l5 = residual_block(filters*4, l4)
  l6 = conv1_block(filters*4, conv2(filters*4, l5))

  l7 = residual_block(filters*8, l6)
  l8 = conv1_block(filters*8, conv2(filters*8, l7))

  # Bottleneck

  l9 = conv1_block(filters*16, residual_block(filters*16, l8))

  # Decoder: at each level the tran2 output is summed with the
  # attention-gated encoder feature map of the same level.
  # NOTE(review): g2, g3 and g4 are built from encoder features (l7, l5, l3)
  # rather than from the decoder outputs (l12, l15, l18) - confirm intended.

  g1 = gating_signal(l9, filters*8)
  a1 = attention_block(l7, g1)

  l10 = tran2(filters*8, l9)
  l11 = conv1_block(filters*8, l10 + a1)
  l12 = residual_block(filters*8, l11)

  g2 = gating_signal(l7, filters*4)
  a2 = attention_block(l5, g2)

  l13 = tran2(filters*4, l12)
  l14 = conv1_block(filters*4, l13 + a2)
  l15 = residual_block(filters*4, l14)

  g3 = gating_signal(l5, filters*2)
  a3 = attention_block(l3, g3)

  l16 = tran2(filters*2, l15)
  l17 = conv1_block(filters*2, l16 + a3)
  l18 = residual_block(filters*2, l17)

  g4 = gating_signal(l3, filters)
  a4 = attention_block(l1, g4)

  l19 = tran2(filters, l18)
  l20 = conv1_block(filters, l19 + a4)
  l21 = residual_block(filters, l20)

  # Output layer: the channel count comes from the caller via logit_filters
  # instead of a fixed value of 2.

  outputs = layers.Conv3D(filters = logit_filters, **kwargs)(l21)

  backbone = Model(inputs = x, outputs = outputs)

  return backbone