# S-R2F2U-Net: A single-stage model for teeth segmentation

This implementation builds on [Yingkai Sha’s](https://github.com/yingkaisha/keras-unet-collection) `keras-unet-collection` repository; the base recurrent, residual, and attention models are taken from that repository.

In [None]:
# Mount Google Drive at /content/drive; the library and dataset paths used
# further down in this notebook all live under that mount point.
from google.colab import drive  
drive.mount('/content/drive')  

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Add the Drive `library` folder to the module search path so that packages
# stored in that folder can be imported.
import sys
sys.path.append('/content/drive/MyDrive/library')

In [None]:
# Uncomment to install keras unet collection
# !pip install --target='/content/drive/MyDrive/library' keras-unet-collection

In [None]:
# Print the installed TensorFlow and Keras versions.
import tensorflow as tf
import keras
print(tf.__version__)
print(keras.__version__)

2.8.0
2.8.0


In [None]:
# Make sure a GPU is available: stop with a SystemError unless TensorFlow
# reports the device name '/device:GPU:0', otherwise print the device found.
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Found GPU at: /device:GPU:0


In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.optimizers import Adam
from datetime import datetime 
import cv2
import json
from PIL import Image
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TerminateOnNaN

In [None]:
# Directories on the mounted Drive that hold the training / validation
# images and their corresponding masks.
img_dir_train = '/content/drive/MyDrive/panoramicDentalSegmentation/patch_512/train/images'
img_dir_val = '/content/drive/MyDrive/panoramicDentalSegmentation/patch_512/val/images'
mask_dir_train = '/content/drive/MyDrive/panoramicDentalSegmentation/patch_512/train/mask'
mask_dir_val = '/content/drive/MyDrive/panoramicDentalSegmentation/patch_512/val/mask'


In [None]:
from keras_unet_collection import models, losses

### Basic

In [None]:
from __future__ import absolute_import

from tensorflow import expand_dims
from tensorflow.compat.v1 import image
from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, AveragePooling2D, AveragePooling3D, UpSampling2D, UpSampling3D, Conv2DTranspose, Conv3DTranspose, GlobalAveragePooling2D
from tensorflow.keras.layers import Conv2D, Conv3D, Lambda
from tensorflow.keras.layers import BatchNormalization, Activation, concatenate, multiply, add
from tensorflow.keras.layers import ReLU, LeakyReLU, PReLU, ELU, Softmax

def decode_layer(X, channel, conv, pool_size, unpool, kernel_size=3, 
                 activation='ReLU', batch_norm=False, name='decode'):
    '''
    An overall decode layer, based on either upsampling or trans conv.
    
    decode_layer(X, channel, conv, pool_size, unpool, kernel_size=3,
                 activation='ReLU', batch_norm=False, name='decode')
    
    Input
    ----------
        X: input tensor.
        channel: (for trans conv only) number of convolution filters.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        pool_size: the decoding factor.
        unpool: True or 'bilinear' for upsampling with bilinear interpolation.
                'nearest' for upsampling with nearest interpolation.
                False for transposed convolution + batch norm + activation.
                With conv='3d' the upsampling branch uses UpSampling3D, which
                takes no `interpolation` argument, so 'bilinear' and 'nearest'
                build the same layer there.
        kernel_size: size of convolution kernels. 
                     If kernel_size='auto', then it equals to the `pool_size`.
        activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU.
                    None skips the activation (trans conv only).
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
        
    Raises
    ----------
        ValueError: if `conv` or `unpool` is not a supported keyword.
    
    * The default: `kernel_size=3`, is suitable for `pool_size=2`.
    
    '''
    # Validate the keywords first, so a bad call fails before any layer is built
    if conv not in ('2d', '3d'):
        raise ValueError('Wrong keyword for conv')

    # parsers: normalise `unpool` to a boolean plus an interpolation mode
    if unpool is False:
        interp = None
        
    elif unpool == 'nearest':
        unpool = True
        interp = 'nearest'
    
    elif (unpool is True) or (unpool == 'bilinear'):
        unpool = True
        interp = 'bilinear'
    
    else:
        raise ValueError('Invalid unpool keyword')
        
    if unpool:
        if conv == '2d':
            X = UpSampling2D(size=pool_size, interpolation=interp, name='{}_unpool'.format(name))(X)
        else:
            # 3d input needs UpSampling3D (UpSampling2D only accepts 4-d tensors);
            # that layer has no `interpolation` argument, so `interp` is not used here.
            X = UpSampling3D(size=pool_size, name='{}_unpool'.format(name))(X)

    else:
        ConvTranspose_nd = Conv2DTranspose if conv == '2d' else Conv3DTranspose
        
        if kernel_size == 'auto':
            kernel_size = pool_size
            
        # `use_bias` is left at the layer default here, exactly as before,
        # so the weights created by this layer keep the same layout.
        X = ConvTranspose_nd(channel, kernel_size, strides=pool_size, 
                             padding='same', name='{}_trans_conv'.format(name))(X)
        
        # batch normalization
        if batch_norm:
            X = BatchNormalization(axis=-1, name='{}_bn'.format(name))(X)
            
        # activation
        if activation is not None:
            # NOTE(review): `activation` is resolved with eval(); only pass
            # trusted layer names such as 'ReLU'.
            activation_func = eval(activation)
            X = activation_func(name='{}_activation'.format(name))(X)
        
    return X

def encode_layer(X, channel, conv, pool_size, pool, kernel_size='auto', 
                 activation='ReLU', batch_norm=False, name='encode'):
    '''
    An overall encode layer, based on one of the:
    (1) max-pooling, (2) average-pooling, (3) strided convolution.
    
    encode_layer(X, channel, conv, pool_size, pool, kernel_size='auto', 
                 activation='ReLU', batch_norm=False, name='encode')
    
    Input
    ----------
        X: input tensor.
        channel: (for strided conv only) number of convolution filters.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        pool_size: the encoding factor.
        pool: True or 'max' for max-pooling (MaxPooling2D / MaxPooling3D).
              'ave' for average-pooling (AveragePooling2D / AveragePooling3D).
              False for strided conv + batch norm + activation.
        kernel_size: size of convolution kernels. 
                     If kernel_size='auto', then it equals to the `pool_size`.
        activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU.
                    None skips the activation (strided conv only).
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
        
    Raises
    ----------
        ValueError: if `conv` or `pool` is not a supported keyword.
        
    '''
    # Validate the keywords first, so a bad call fails before any layer is built
    if conv not in ('2d', '3d'):
        raise ValueError('Wrong keyword for conv or max-pool or avg-pool')

    # parsers
    if (pool in [False, True, 'max', 'ave']) is not True:
        raise ValueError('Invalid pool keyword')
        
    # Booleans select between max-pooling (the default) and strided conv.
    # 0 and 1 compare equal to False/True and so pass the check above;
    # normalising by truthiness keeps them from falling through unhandled.
    if pool not in ('max', 'ave'):
        pool = 'max' if pool else False
    
    if pool == 'max':
        MaxPooling_nd = MaxPooling2D if conv == '2d' else MaxPooling3D
        X = MaxPooling_nd(pool_size=pool_size, name='{}_maxpool'.format(name))(X)
        
    elif pool == 'ave':
        AveragePooling_nd = AveragePooling2D if conv == '2d' else AveragePooling3D
        X = AveragePooling_nd(pool_size=pool_size, name='{}_avepool'.format(name))(X)
        
    else:
        conv_nd = Conv2D if conv == '2d' else Conv3D
        
        if kernel_size == 'auto':
            kernel_size = pool_size
        
        # the convolution bias is dropped when batch normalization follows
        bias_flag = not batch_norm
        
        # linear convolution with strides
        X = conv_nd(channel, kernel_size, strides=pool_size, padding='valid', use_bias=bias_flag, name='{}_stride_conv'.format(name))(X)
        
        # batch normalization
        if batch_norm:
            X = BatchNormalization(axis=-1, name='{}_bn'.format(name))(X)
            
        # activation
        if activation is not None:
            # NOTE(review): `activation` is resolved with eval(); only pass
            # trusted layer names such as 'ReLU'.
            activation_func = eval(activation)
            X = activation_func(name='{}_activation'.format(name))(X)
            
    return X

def attention_gate(X, g, channel, conv,  
                   activation='ReLU', 
                   attention='add', name='att'):
    '''
    Self-attention gate modified from Oktay et al. 2018: the input tensor is
    rescaled by sigmoid coefficients computed from the input and the gate.
    
    attention_gate(X, g, channel, conv, activation='ReLU', attention='add', name='att')
    
    Input
    ----------
        X: input tensor, i.e., key and value.
        g: gated tensor, i.e., query.
        channel: number of intermediate channel.
                 Oktay et al. (2018) did not specify (denoted as F_int).
                 intermediate channel is expected to be smaller than the input channel.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        activation: a nonlinear attention activation.
                    The `sigma_1` in Oktay et al. 2018. Default is 'ReLU'.
        attention: 'add' for additive attention; 'multiply' for multiplicative attention.
                   Oktay et al. 2018 applied additive attention.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X_att: output tensor.
        
    Raises
    ----------
        ValueError: if `conv` is neither '2d' nor '3d'.
    
    '''
    # Pick the convolution class for the requested dimensionality
    if conv == '2d':
        conv_layer = Conv2D
    elif conv == '3d':
        conv_layer = Conv3D
    else:
        raise ValueError('Wrong keyword for conv')

    # Both keywords are looked up by name with eval(), so they must refer to
    # callables visible from this scope (e.g. ReLU, add, multiply).
    nonlinearity = eval(activation)
    merge_op = eval(attention)
    
    # 1-by-1 projections of the input and of the gate onto `channel` filters
    theta_x = conv_layer(channel, 1, use_bias=True, name='{}_theta_x'.format(name))(X)
    gate_proj = conv_layer(channel, 1, use_bias=True, name='{}_phi_g'.format(name))(g)
    
    # merge the two projections; the layer keeps the '_add' suffix
    # whichever merge operation was requested
    merged = merge_op([theta_x, gate_proj], name='{}_add'.format(name))
    activated = nonlinearity(name='{}_activation'.format(name))(merged)
    
    # collapse to a single channel and squash into attention coefficients
    psi = conv_layer(1, 1, use_bias=True, name='{}_psi_f'.format(name))(activated)
    coefficients = Activation('sigmoid', name='{}_sigmoid'.format(name))(psi)
    
    # mask the input with the coefficients
    return multiply([X, coefficients], name='{}_masking'.format(name))

def CONV_stack(X, channel, conv, kernel_size=3, stack_num=2, 
               dilation_rate=1, activation='ReLU', 
               batch_norm=False, name='conv_stack'):
    '''
    Apply `stack_num` repetitions of
    (Convolutional layer --> batch normalization --> Activation)
    
    CONV_stack(X, channel, conv, kernel_size=3, stack_num=2, dilation_rate=1, 
               activation='ReLU', batch_norm=False, name='conv_stack')
    
    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        kernel_size: size of the convolution kernels.
        stack_num: number of stacked Conv-BN-Activation layers.
        dilation_rate: optional dilated convolution kernel.
        activation: one of the `tensorflow.keras.layers` interface, e.g., ReLU.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor
        
    Raises
    ----------
        ValueError: if `conv` is neither '2d' nor '3d'.
        
    '''
    # Pick the convolution class for the requested dimensionality
    if conv == '2d':
        conv_layer = Conv2D
    elif conv == '3d':
        conv_layer = Conv3D
    else:
        raise ValueError('Wrong keyword for conv')

    # the convolution bias is dropped when batch normalization follows
    with_bias = not batch_norm
    
    out = X
    for level in range(stack_num):
        # the activation class is looked up by name with eval()
        act_layer = eval(activation)
        
        # linear convolution
        out = conv_layer(channel, kernel_size, padding='same', use_bias=with_bias, 
                         dilation_rate=dilation_rate, name='{}_{}'.format(name, level))(out)
        
        # optional batch normalization over the last axis
        if batch_norm:
            out = BatchNormalization(axis=-1, name='{}_{}_bn'.format(name, level))(out)
        
        # activation
        out = act_layer(name='{}_{}_activation'.format(name, level))(out)
        
    return out

def Res_CONV_stack(X, X_skip, channel, conv, res_num, activation='ReLU', batch_norm=False, name='res_conv'):
    '''
    Stacked convolutional layers whose output is added to a residual tensor
    and then passed through the activation.
     
    Res_CONV_stack(X, X_skip, channel, conv, res_num, activation='ReLU', batch_norm=False, name='res_conv')
     
    Input
    ----------
        X: input tensor.
        X_skip: the tensor that goes into the residual path;
                can be a copy of X (e.g., the identity block of ResNet).
        channel: number of convolution filters.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d' 
        res_num: number of convolutional layers within the residual path.
        activation: one of the `tensorflow.keras.layers` interface, e.g., 'ReLU'.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
        
    '''
    # main path: `res_num` 3-by-3, non-dilated Conv-BN-Activation layers
    main_path = CONV_stack(X, channel, conv=conv, kernel_size=3, stack_num=res_num, 
                           dilation_rate=1, activation=activation, 
                           batch_norm=batch_norm, name=name)

    # merge with the residual tensor
    merged = add([X_skip, main_path], name='{}_add'.format(name))
    
    # activation after the merge; the class is looked up by name with eval()
    post_activation = eval(activation)
    return post_activation(name='{}_add_activation'.format(name))(merged)


def CONV_output(X, conv, n_labels, kernel_size=1, activation='Softmax', name='conv_output'):
    '''
    Output head: one convolutional layer followed by an optional activation.
    
    CONV_output(X, conv, n_labels, kernel_size=1, activation='Softmax', name='conv_output')
    
    Input
    ----------
        X: input tensor.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        n_labels: number of classification label(s).
        kernel_size: size of the convolution kernels. Default is 1-by-1.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                    Default option is 'Softmax'.
                    if None is received, then linear activation is applied.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
        
    Raises
    ----------
        ValueError: if `conv` is neither '2d' nor '3d'.
        
    '''
    # Pick the convolution class for the requested dimensionality
    if conv == '2d':
        conv_layer = Conv2D
    elif conv == '3d':
        conv_layer = Conv3D
    else:
        raise ValueError('Wrong keyword for conv')

    logits = conv_layer(n_labels, kernel_size, padding='same', use_bias=True, name=name)(X)
    
    # a falsy activation (e.g. None) leaves the output linear
    if not activation:
        return logits
    
    # 'Sigmoid' is handled through the generic Activation layer
    if activation == 'Sigmoid':
        return Activation('sigmoid', name='{}_activation'.format(name))(logits)
    
    # any other keyword is looked up by name with eval()
    output_layer = eval(activation)
    return output_layer(name='{}_activation'.format(name))(logits)



### Hybrid U-Net (Combines recurrent, residual, and attention)

In [None]:
from __future__ import absolute_import

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

def any_recur(val):
    """
    Tell a recurrent U-Net from a basic U-Net by inspecting the per-stack flags.

    Input
    -----------
    val: A tuple or list of flags. e.g. val=(True, False)

    Output
    -----------
    result: Boolean type. 'True' when the flags sum to more than zero
            (recurrent U-Net), 'False' otherwise (basic U-Net).

    """
    # Summing the flags counts how many stacks are recurrent
    enabled_count = sum(val)
    return bool(enabled_count > 0)

def RR_CONV(X, channel, conv, kernel_size=3, stack_num=2,  
            dilation_rate=1, 
            filter_double=False, 
            recur_status=(True, True), recur_num=2, 
            is_residual=True, 
            activation='ReLU', batch_norm=False, name='rr'):
    '''
    Recurrent convolutional layers with skip connection.
    
    RR_CONV(X, channel, conv, kernel_size=3, stack_num=2, dilation_rate=1,
            filter_double=False, recur_status=(True, True), recur_num=2,
            is_residual=True, activation='ReLU', batch_norm=False, name='rr')
    
    Input
    ----------
        X: input tensor.
        channel: number of convolution filters (of the first stack when
                 `filter_double` is True).
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'.
        kernel_size: size of the convolution kernels.
        stack_num: number of stacked recurrent convolutional layers.
        dilation_rate: dilation rate passed to every convolution.
        filter_double: True to double the number of filters after each stack.
        recur_status: one flag per stack; True makes that stack recurrent.
                      Its length must equal `stack_num`.
        recur_num: number of recurrent iterations.
        is_residual: True to add the 1-by-1 skip branch to the output.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
        
    Raises
    ----------
        AssertionError: if len(recur_status) != stack_num.
        ValueError: if `conv` is neither '2d' nor '3d'.
        
    '''
    assert len(recur_status) == stack_num, "Length of recur_status should be equal to stack_num"
    
    # NOTE(review): `activation` is resolved with eval(); only pass trusted layer names.
    activation_func = eval(activation)    
    
    # the block is "recurrent" as soon as one stack has its flag set
    is_recurrent = any_recur(recur_status)
    
    print('Recurrent Status: ', recur_status)
    print('Recurrent: ', is_recurrent)
    print('Residual: ', is_residual)
    print('Filter double: ', filter_double)

    # Set 2d or 3d convolution
    if conv == '2d': conv_nd = Conv2D
    elif conv == '3d': conv_nd = Conv3D
    else: raise ValueError('Wrong keyword for conv')

    if filter_double: 
        # Let's say input channel=3, base filter (or channel)=32, stack_num=3, and filter_double=True. 
        # Then filter numbers will be 3 -> 32 -> 64 -> 128. Filters will be doubled in each stack.
        # In this case, the skip-connection layer should have 128 filters, otherwise we can't add them.
        # So, the filter number of skip-connection layer = base filter x 2^(stack_num-1) = 32 x 2^(3-1) = 128
        layer_skip = conv_nd(channel*(2**(stack_num-1)), 1, dilation_rate=dilation_rate, name='{}_layer_skip'.format(name))(X)

        # Without any recurrent stack this is a basic U-Net block and the main
        # branch starts from X itself; otherwise X is first projected to `channel`
        # filters so that it can be added to the convolution output.
        if is_recurrent: layer_main = conv_nd(channel, 1, dilation_rate=dilation_rate, name='{}_layer_main'.format(name))(X)
        else: layer_main = X
        
    else: # no filter doubling
        layer_skip = conv_nd(channel, 1, dilation_rate=dilation_rate, name='{}_conv'.format(name))(X)
        if is_recurrent: layer_main = layer_skip        
        else: layer_main = X     
    
    # The first projection keeps the historical name '<name>_conv'; any further
    # projection gets an indexed name so that layer names stay unique.
    legacy_proj_name_used = False
    
    for i in range(stack_num):

        # After filter doubling the main branch has half the filters of this
        # stack; project it so the recurrent additions below are shape-compatible.
        if i>0 and recur_status[i] and filter_double: 
            if legacy_proj_name_used:
                proj_name = '{}_conv_proj{}'.format(name, i)
            else:
                proj_name = '{}_conv'.format(name)
                legacy_proj_name_used = True
            layer_main = conv_nd(channel, 1, dilation_rate=dilation_rate, name=proj_name)(layer_main)

        layer_res = conv_nd(channel, kernel_size, padding='same', dilation_rate=dilation_rate, name='{}_conv{}'.format(name, i))(layer_main)
        
        if batch_norm:
            layer_res = BatchNormalization(name='{}_bn{}'.format(name, i))(layer_res)
            
        layer_res = activation_func(name='{}_activation{}'.format(name, i))(layer_res)

        # Recurrent
        if recur_status[i]:
            
            for j in range(recur_num):
                
                layer_add = add([layer_res, layer_main], name='{}_add{}_{}'.format(name, i, j))
                
                layer_res = conv_nd(channel, kernel_size, padding='same', dilation_rate=dilation_rate, name='{}_conv{}_{}'.format(name, i, j))(layer_add)
                
                if batch_norm:
                    layer_res = BatchNormalization(name='{}_bn{}_{}'.format(name, i, j))(layer_res)
                    
                layer_res = activation_func(name='{}_activation{}_{}'.format(name, i, j))(layer_res)
            
        layer_main = layer_res

        if filter_double: channel = channel * 2 # doubling filter numbers

    # Residual; the layer is named after the last stack index (stack_num-1),
    # spelled out explicitly instead of relying on the leaked loop variable.
    if is_residual: out_layer = add([layer_main, layer_skip], name='{}_add{}'.format(name, stack_num-1))
    else: out_layer = layer_main

    return out_layer


def UNET_RR_left(X, channel, conv, kernel_size=3, 
                 stack_num=2, 
                 dilation_rate=1, 
                 filter_double=False, 
                 recur_status=(True, True), recur_num=2, 
                 is_residual=True, 
                 is_attention=False, 
                 atten_activation='ReLU', attention='add',
                 activation='ReLU', output_activation='Softmax',
                 pool=True, batch_norm=False, name='left0'):
    '''
    The encoder block of R2U-Net: downsampling followed by RR_CONV.
    
    UNET_RR_left(X, channel, conv, kernel_size=3, 
                 stack_num=2, recur_num=2, activation='ReLU', 
                 pool=True, batch_norm=False, name='left0')
    
    Input
    ----------
        X: input tensor.
        channel: number of convolution filters.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        kernel_size: size of the convolution kernels used inside RR_CONV.
        stack_num: number of stacked recurrent convolutional layers.
        dilation_rate: dilation rate forwarded to RR_CONV.
        filter_double: forwarded to RR_CONV (double the filters per stack).
        recur_status: per-stack recurrent flags forwarded to RR_CONV.
        recur_num: number of recurrent iterations.
        is_residual: forwarded to RR_CONV (add the skip branch).
        is_attention, atten_activation, attention, output_activation:
            accepted but not used by this block.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
    
    *downsampling is fixed to 2-by-2, e.g., reducing feature map sizes from 64-by-64 to 32-by-32
    '''
    pool_size = 2
    
    # maxpooling layer vs strided convolutional layers
    X = encode_layer(X, channel, conv, pool_size, pool, activation=activation, 
                     batch_norm=batch_norm, name='{}_encode'.format(name))

    # stacked recurrent/residual convolutions; `kernel_size` is now forwarded
    # (it used to be dropped here, so RR_CONV always fell back to its default of 3)
    X = RR_CONV(X, channel, conv=conv, kernel_size=kernel_size, stack_num=stack_num, 
                dilation_rate=dilation_rate, filter_double=filter_double,
                recur_status=recur_status, recur_num=recur_num, 
                is_residual=is_residual,
                activation=activation, batch_norm=batch_norm, name=name)    
    return X


def UNET_RR_right(X, X_list, channel, conv, kernel_size=3, 
                  stack_num=2, 
                  dilation_rate=1, 
                  filter_double=False, 
                  recur_status=(True, True), recur_num=2, 
                  is_residual=True, 
                  is_attention=False, 
                  atten_activation='ReLU', attention='add', 
                  activation='ReLU',
                  unpool=True, batch_norm=False, name='right0'):
    '''
    The decoder block of R2U-Net: upsampling, one convolution, (optionally
    attention-gated) concatenation with the skip tensor, then RR_CONV.
    
    UNET_RR_right(X, X_list, channel, conv, kernel_size=3, 
                  stack_num=2, recur_num=2, activation='ReLU',
                  unpool=True, batch_norm=False, name='right0')
    
    Input
    ----------
        X: input tensor.
        X_list: the skip tensor that is concatenated with the upsampled input
                (used here as a single tensor, not as a list).
        channel: number of convolution filters.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        kernel_size: size of the convolution kernels, used both before the
                     concatenation and inside RR_CONV.
        stack_num: number of stacked recurrent convolutional layers.
        dilation_rate: dilation rate forwarded to RR_CONV.
        filter_double: forwarded to RR_CONV (double the filters per stack).
        recur_status: per-stack recurrent flags forwarded to RR_CONV.
        recur_num: number of recurrent iterations.
        is_residual: forwarded to RR_CONV (add the skip branch).
        is_attention: True to pass the skip tensor through an attention gate
                      (with channel//2 intermediate filters) before concatenation.
        atten_activation: activation used inside the attention gate.
        attention: 'add' or 'multiply', the attention type of the gate.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.
        batch_norm: True for batch normalization, False otherwise.
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor
    
    '''
    
    pool_size = 2
    
    X = decode_layer(X, channel, conv, pool_size, unpool, 
                     activation=activation, batch_norm=batch_norm, name='{}_decode'.format(name))
    
    # one convolutional layer before concatenation
    X = CONV_stack(X, channel, conv, kernel_size, stack_num=1, activation=activation, 
                   batch_norm=batch_norm, name='{}_conv_before_concat'.format(name))
    
    # Attention gate
    print('Attention: ', is_attention)
    if is_attention:
        # the upsampled tensor acts as the gate for the skip tensor
        X_left = attention_gate(X=X_list, g=X, channel=channel//2, conv=conv, activation=atten_activation, 
                                attention=attention, name='{}_att'.format(name))
        # Tensor concatenation
        H = concatenate([X, X_left], axis=-1, name='{}_att_concat'.format(name))
    else:
        # Tensor concatenation
        H = concatenate([X, X_list], axis=-1, name='{}_concat'.format(name))
        
    # stacked recurrent/residual convolutions after concatenation; `kernel_size`
    # is now forwarded (it used to be dropped here, so RR_CONV always used 3)
    H = RR_CONV(H, channel, conv, kernel_size=kernel_size, stack_num=stack_num, 
                dilation_rate=dilation_rate, filter_double=filter_double,
                recur_status=recur_status, recur_num=recur_num, 
                is_residual=is_residual,
                activation=activation, batch_norm=batch_norm, name=name)

    return H

def hybrid_unet_base(input_tensor, filter_num, conv, kernel_size=3, stack_num_down=2, stack_num_up=2, 
                        dilation_rate=1, 
                        filter_double=False, 
                        recur_status=(True, True), recur_num=2, 
                        is_residual=True, 
                        is_attention=False, 
                        atten_activation='ReLU', attention='add',
                        activation='ReLU', batch_norm=False, pool=True, unpool=True, name='res_unet'):
    
    '''
    The base of Recurrent Residual (R2) U-Net.
    
    hybrid_unet_base(input_tensor, filter_num, conv, stack_num_down=2, stack_num_up=2, recur_num=2,
                    activation='ReLU', batch_norm=False, pool=True, unpool=True, name='res_unet')
    
    ----------
    Alom, M.Z., Hasan, M., Yakopcic, C., Taha, T.M. and Asari, V.K., 2018. Recurrent residual convolutional neural network 
    based on u-net (r2u-net) for medical image segmentation. arXiv preprint arXiv:1802.06955.
    
    Input
    ----------
        input_tensor: the input tensor of the base, e.g., `keras.layers.Input((None, None, 3))`.
        filter_num: a list that defines the number of filters for each \
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        kernel_size: size of the convolution kernels, forwarded to every block.
        stack_num_down: number of stacked recurrent convolutional layers per downsampling level/block.
        stack_num_up: number of stacked recurrent convolutional layers per upsampling level/block.
        dilation_rate: dilation rate forwarded to every block.
        filter_double: forwarded to every block (double the filters per stack).
        recur_status: per-stack recurrent flags forwarded to every block.
        recur_num: number of recurrent iterations.
        is_residual: forwarded to every block (add the skip branch).
        is_attention: True to gate the skip tensors with attention in the decoder.
        atten_activation: activation used inside the attention gates.
        attention: 'add' or 'multiply', the attention type of the gates.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.                 
        name: prefix of the created keras layers.
        
    Output
    ----------
        X: output tensor.
    
    '''
    
    X = input_tensor
    X_skip = []

    # first level: no downsampling; `kernel_size` is now forwarded here as well,
    # so that this level agrees with the ones below
    X = RR_CONV(X, filter_num[0], conv=conv, kernel_size=kernel_size, stack_num=stack_num_down,
                dilation_rate=dilation_rate, filter_double=filter_double,                
                recur_status=recur_status, recur_num=recur_num, 
                is_residual=is_residual,
                activation=activation, batch_norm=batch_norm, name='{}_down0'.format(name))
    X_skip.append(X)
    
    # downsampling blocks
    for i, f in enumerate(filter_num[1:]):
        X = UNET_RR_left(X, f, conv=conv, kernel_size=kernel_size,
                          stack_num=stack_num_down, 
                          dilation_rate=dilation_rate, 
                          filter_double=filter_double, 
                          recur_status=recur_status, recur_num=recur_num, 
                          is_residual=is_residual, 
                          is_attention=is_attention, 
                          atten_activation=atten_activation, attention=attention,
                          activation=activation, pool=pool, batch_norm=batch_norm, name='{}_down{}'.format(name, i+1))        
        X_skip.append(X)
    
    # upsampling blocks: drop the deepest output (it is the decoder input)
    # and walk the remaining skip tensors from deepest to shallowest
    X_skip = X_skip[:-1][::-1]
    for i, f in enumerate(filter_num[:-1][::-1]):
        X = UNET_RR_right(X, X_skip[i], f, conv=conv, kernel_size=kernel_size, 
                          stack_num=stack_num_up, 
                          dilation_rate=dilation_rate, 
                          filter_double=filter_double, 
                          recur_status=recur_status, recur_num=recur_num, 
                          is_residual=is_residual, 
                          is_attention=is_attention, 
                          atten_activation=atten_activation, attention=attention,
                          activation=activation, unpool=unpool, batch_norm=batch_norm, name='{}_up{}'.format(name, i+1))
    
    return X

def hybrid_unet(input_size, filter_num, n_labels, conv,
               stack_num_down=2, stack_num_up=2, 
               dilation_rate=1,
               filter_double=False,
               recur_status=(True, True), recur_num=2,
               is_residual=True,
               is_attention=True,
               atten_activation='ReLU', attention='add',
               activation='ReLU', output_activation='Softmax', 
               batch_norm=False, pool=True, unpool=True, name='hybrid_unet'):

    '''
    Hybrid (recurrent / residual / attention) U-Net model.
    
    hybrid_unet(input_size, filter_num, n_labels, conv,
               stack_num_down=2, stack_num_up=2,
               dilation_rate=1, filter_double=False,
               recur_status=(True, True), recur_num=2,
               is_residual=True, is_attention=True,
               atten_activation='ReLU', attention='add',
               activation='ReLU', output_activation='Softmax', 
               batch_norm=False, pool=True, unpool=True, name='hybrid_unet')
    
    The function creates the input layer, builds the encoder/decoder with
    `hybrid_unet_base` (always with `kernel_size=3`), adds a 1x1 output
    convolution through `CONV_output` and wraps everything in a keras `Model`.
    
    ----------
    Alom, M.Z., Hasan, M., Yakopcic, C., Taha, T.M. and Asari, V.K., 2018. Recurrent residual convolutional neural network 
    based on u-net (r2u-net) for medical image segmentation. arXiv preprint arXiv:1802.06955.
    
    Input
    ----------
        input_size: the size/shape of network input, e.g., `(128, 128, 3)`.
        filter_num: a list that defines the number of filters for each \
                    down- and upsampling levels. e.g., `[64, 128, 256, 512]`.
                    The depth is expected as `len(filter_num)`.
        n_labels: number of output labels.
        conv: (str) 2d or 3d convolution. e.g. '2d' or '3d'
        stack_num_down: number of stacked recurrent convolutional layers per downsampling level/block.
        stack_num_up: number of stacked recurrent convolutional layers per upsampling level/block.
        dilation_rate: forwarded unchanged to `hybrid_unet_base`.
        filter_double: forwarded unchanged to `hybrid_unet_base`.
        recur_status: forwarded unchanged to `hybrid_unet_base`.
        recur_num: number of recurrent iterations.
        is_residual: forwarded unchanged to `hybrid_unet_base`.
        is_attention: forwarded unchanged to `hybrid_unet_base`.
        atten_activation: forwarded unchanged to `hybrid_unet_base`.
        attention: forwarded unchanged to `hybrid_unet_base`.
        activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interfaces, e.g., 'ReLU'.
        output_activation: one of the `tensorflow.keras.layers` or `keras_unet_collection.activations` interface or 'Sigmoid'.
                           Default option is 'Softmax'.
                           if None is received, then linear activation is applied.     
        batch_norm: True for batch normalization.
        pool: True or 'max' for MaxPooling2D.
              'ave' for AveragePooling2D.
              False for strided conv + batch norm + activation.
        unpool: True or 'bilinear' for Upsampling2D with bilinear interpolation.
                'nearest' for Upsampling2D with nearest interpolation.
                False for Conv2DTranspose + batch norm + activation.                  
        name: prefix of the created keras layers.
        
    Output
    ----------
        model: a keras model named `'{name}_model'`.
    
    '''

    IN = Input(input_size, name='{}_input'.format(name))

    # base (encoder + decoder); the activation name is resolved by the
    # building blocks themselves, so it is passed on as a string.
    X = hybrid_unet_base(IN, filter_num, conv=conv, kernel_size=3,
                        stack_num_down=stack_num_down, stack_num_up=stack_num_up, 
                        dilation_rate=dilation_rate, 
                        filter_double=filter_double, 
                        recur_status=recur_status, recur_num=recur_num, 
                        is_residual=is_residual, 
                        is_attention=is_attention, 
                        atten_activation=atten_activation, attention=attention,
                        activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, name=name)
    # output layer: 1x1 convolution mapping the decoder features to `n_labels` channels
    OUT = CONV_output(X, conv, n_labels, kernel_size=1, activation=output_activation, name='{}_output'.format(name))
    
    # functional API model
    model = Model(inputs=[IN], outputs=[OUT], name='{}_model'.format(name))
    
    return model 

### Model parameters

In [None]:
# Hyper-parameters

# -- Input geometry ----------------------------------------------------------
IMG_HEIGHT, IMG_WIDTH = 512, 512  # img_train.shape[1], img_train.shape[2]
IMG_DEPTH = 512                   # only used for 3D inputs
IMG_CHANNELS = 3                  # img_train.shape[3]
CONV = '2d'
NUM_LABELS = 2                    # binary segmentation
input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
# input_shape = (IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS)  # uncomment for 3D
batch_size = 2

# -- Network architecture (passed to hybrid_unet) ----------------------------
FILTER_NUM = [32, 64, 128, 256, 512]
STACK_NUM_DOWN = 2
STACK_NUM_UP = 2
DILATION_RATE = 1
FILTER_DOUBLE = True
RECUR_STATUS = (False, True)
RECUR_NUM = 2
IS_RESIDUAL = True
IS_ATTENTION = True
ATTENTION_ACTIVATION = 'ReLU'
ATTENTION = 'add'
ACTIVATION = 'ReLU'
OUTPUT_ACTIVATION = 'Softmax'
BATCH_NORM = True
POOL = False
UNPOOL = False

# -- Training mode -----------------------------------------------------------
RETRAIN = False  # True: reload a saved checkpoint before training

### Model

In [None]:
# Current version works for "stack_num_down = stack_num_up" only
# Build the network from the hyper-parameters defined in the previous cell.
model = hybrid_unet(
    input_shape,
    filter_num=FILTER_NUM,
    n_labels=NUM_LABELS,
    conv=CONV,
    stack_num_down=STACK_NUM_DOWN,
    stack_num_up=STACK_NUM_UP,
    dilation_rate=DILATION_RATE,
    filter_double=FILTER_DOUBLE,
    recur_status=RECUR_STATUS,
    recur_num=RECUR_NUM,
    is_residual=IS_RESIDUAL,
    is_attention=IS_ATTENTION,
    atten_activation=ATTENTION_ACTIVATION,
    attention=ATTENTION,
    activation=ACTIVATION,
    output_activation=OUTPUT_ACTIVATION,
    batch_norm=BATCH_NORM,
    pool=POOL,
    unpool=UNPOOL,
    name='hybrid_unet',
)


In [None]:
# Print the layer-by-layer summary of the freshly built model.
model.summary()

### Loss function

A hybrid loss function is implemented consisting of cross-entropy loss and dice coefficient loss.

In [None]:
def loss_function(y_true, y_pred):
    '''
    Hybrid segmentation loss: categorical cross-entropy plus half of the
    value returned by `losses.dice_coef`.

    Input
    ----------
        y_true: ground-truth labels (one-hot encoded by the data generator).
        y_pred: network output.

    Output
    ----------
        loss_ce + 0.5 * loss_dice_coef
    '''
    # Plain casts are used on purpose: `tf.image.convert_image_dtype` rescales
    # integer inputs by 1/dtype.max when converting to float, which would turn
    # the int8 one-hot labels (0/1) into 0 and 1/127 instead of 0.0 and 1.0.
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    loss_ce = keras.losses.categorical_crossentropy(y_true, y_pred)
    loss_dice_coef = losses.dice_coef(y_true, y_pred)

    return loss_ce + (0.5*loss_dice_coef)

### Compile

In [None]:
# Configure training: hybrid loss, Adam optimizer, and the metrics reported
# after every epoch.
model.compile(
    optimizer=Adam(learning_rate=1e-3),
    loss=loss_function,
    metrics=['accuracy', losses.dice_coef],
)

# print(model.summary())

### Callbacks and checkpoints

In [None]:
# Create checkpoint
# Checkpoints of one run go to <base>/<network name>/<start time of the run>.
checkpoint_folder = "recur_F_T_residual_filter_double"
checkpoint_subfolder = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
checkpoint_base_dir = "/content/drive/MyDrive/panoramicDentalSegmentation/checkpoints/hybrid_unet_2d"
checkpoint_loc = os.path.join(checkpoint_base_dir, checkpoint_folder, checkpoint_subfolder)
log_path = "/content/drive/MyDrive/panoramicDentalSegmentation/logs/"

# Create the checkpoint and log directories if they do not exist yet.
# The log directory must exist before CSVLogger opens its file in it.
os.makedirs(checkpoint_loc, exist_ok=True)
os.makedirs(log_path, exist_ok=True)
    
# "{epoch:04d}" is filled in by ModelCheckpoint with the epoch number.
checkpoint_path = os.path.join(checkpoint_loc, "cp-{epoch:04d}.ckpt")
checkpoint_dir = os.path.dirname(checkpoint_path)

callbacks = [
    # EarlyStopping(monitor='', patience=400, verbose=1),
    # Multiply the learning rate by 0.1 after 10 epochs without improvement
    # of the validation loss, never going below 1e-5.
    ReduceLROnPlateau(factor=0.1,
                      monitor='val_loss',
                      patience=10,
                      min_lr=0.00001,
                      verbose=1,
                      mode='auto'),
    # Save the full model (not only the weights) every 5 epochs.
    # NOTE(review): `period` is deprecated in favour of `save_freq`, which is
    # counted in batches; switch once the number of steps per epoch is known here.
    ModelCheckpoint(checkpoint_path,
                      monitor = 'val_loss',
                      verbose = 1,
                      save_best_only=False,
                      save_weights_only=False,
                      period=5),
    # Per-epoch metrics are written to a time-stamped CSV file.
    CSVLogger(os.path.join(log_path, datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv'), separator=',', append=True),
    # TerminateOnNaN()
]

print(checkpoint_folder)
print(checkpoint_subfolder)

### Save info

Create a dictionary file to store model information

In [None]:
# Everything needed to identify and rebuild this training run.
info_data = {
    'network_name': checkpoint_folder,
    'checkpoint_subfolder': checkpoint_subfolder, 
    'loss': "loss_ce + (0.5*loss_dice_coef)", 
    'conv': CONV,
    'num_labels': NUM_LABELS,  
    'input_shape': (IMG_HEIGHT,IMG_WIDTH,IMG_CHANNELS),
    'batch_size': batch_size,
    'filters': FILTER_NUM,
    'stack_num_down': STACK_NUM_DOWN,
    'stack_num_up': STACK_NUM_UP,
    'dilation_rate': DILATION_RATE,
    'filter_double': FILTER_DOUBLE,
    'recur_status': RECUR_STATUS,
    'recur_num': RECUR_NUM,
    'is_residual': IS_RESIDUAL,
    'is_attention': IS_ATTENTION,
    'attention_activation': ATTENTION_ACTIVATION,
    'attention': ATTENTION,
    'activation': ACTIVATION,
    'output_activation': OUTPUT_ACTIVATION,
    'batch_norm': BATCH_NORM,
    'pool': POOL,
    'unpool': UNPOOL,
    'retrain': RETRAIN,
}

# Save in json file
# The file is named "<network name>_<run time stamp>" (no extension) and is
# stored next to the checkpoint folders. The `with` block closes the file
# even if serialization fails.
json_name = checkpoint_folder + "_" + checkpoint_subfolder
with open(os.path.join(checkpoint_base_dir, json_name), "w") as info_file:
    json.dump(info_data, info_file)

### Data loader

The data loader reads images and their masks batch-wise from the given directories. <br>

Loading the entire training and validation datasets into memory at once is too expensive for Colab. Instead, a data loader class called `DataGenerator` is implemented that loads images on the fly: it takes a list of image file names together with the directories that contain the images and masks, and reads them from disk batch by batch while the model is training. <br>

Reference: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

In [None]:
import tensorflow
from tensorflow.keras.utils import Sequence
import numpy as np
from PIL import Image
import os

class DataGenerator(Sequence):
    '''
    Keras `Sequence` that reads image/mask pairs from disk one batch at a time.

    Input
    ----------
        list_IDs: list of file names; only the name without extension is used,
                  '.png' is appended when the files are opened.
        dir_image: directory that contains the input images.
        dir_mask: directory that contains the masks (same base names as the images).
        batch_size: number of samples per batch.
        dim: spatial size of one sample, e.g. `(512, 512)`.
        n_channels_image: number of channels of the input images.
        n_channels_mask: number of channels of the masks.
        n_classes: number of classes used for the one-hot encoding of the masks.
        shuffle: True to reshuffle the sample order at the end of every epoch.

    Each item is a tuple `(X, y)`:
        X: images divided by 255, shape `(batch, *dim, n_channels_image)`.
        y: masks one-hot encoded with `to_categorical` (dtype int8).
    '''

    def __init__(self,
                 list_IDs,
                 dir_image,
                 dir_mask,
                 batch_size=1,
                 dim=(512, 512),
                 n_channels_image=3,
                 n_channels_mask=1,
                 n_classes=2,
                 shuffle=True):
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.dir_image = dir_image
        self.dir_mask = dir_mask
        self.n_channels_image = n_channels_image
        self.n_channels_mask = n_channels_mask
        self.n_classes = n_classes
        self.shuffle = shuffle
        # Builds (and optionally shuffles) the initial sample order.
        self.on_epoch_end()

    def __len__(self):
        '''Number of complete batches per epoch.'''
        # Only full batches are counted: if len(list_IDs) is not a multiple of
        # batch_size, the leftover samples are not used in that epoch.
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        '''Return batch number `index` as a tuple `(X, y)`.'''
        # Take the `index`-th slice of `batch_size` entries from the current
        # sample order and translate the positions back to file names.
        indexes = self.indexes[index * self.batch_size:(index + 1) *
                               self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def on_epoch_end(self):
        '''Reset the sample order; shuffle it when `shuffle` is True.'''
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        '''Load the images and masks named in `list_IDs_temp` into arrays.'''
        # Size the arrays by the number of IDs actually received, so that no
        # uninitialized rows of `np.empty` can end up in a batch.
        n_samples = len(list_IDs_temp)
        X = np.empty((n_samples, *self.dim, self.n_channels_image)) # 3 for color image
        y = np.empty((n_samples, *self.dim, self.n_channels_mask)) # 1 for binary/grayscale image

        for i, ID in enumerate(list_IDs_temp):
            ID_only = os.path.splitext(ID)[0]

            # Context managers close the image files as soon as the pixel
            # data has been copied into numpy arrays.
            with Image.open(os.path.join(self.dir_image, ID_only + '.png')) as img:
                # Do normalization
                xt = np.array(img)/255.0
            with Image.open(os.path.join(self.dir_mask, ID_only + '.png')) as msk:
                # Here, no normalization is needed for the mask image. Because it's a binary image.
                # Multiplying by 1 turns a boolean array into an integer one.
                yt = np.expand_dims(np.array(msk), axis=-1) * 1

            X[i,] = xt
            y[i,] = yt

        y = tensorflow.keras.utils.to_categorical(y, num_classes=self.n_classes, dtype ="int8") # num_classes may vary 
        return X, y

### Return to main code

In [None]:
# File names found in the training and validation image folders
list_IDs_train = os.listdir(img_dir_train)
list_IDs_val = os.listdir(img_dir_val)

# Set to True if you want to work with a random subset of the images
selected_img = False
if selected_img:
    import random
    random.shuffle(list_IDs_train)
    random.shuffle(list_IDs_val)
    list_IDs_train = list_IDs_train[:1000]
    list_IDs_val = list_IDs_val[:100]
    print(list_IDs_train)

# Batch generator for the training set
train_gen = DataGenerator(
    list_IDs=list_IDs_train,
    dir_image=img_dir_train,
    dir_mask=mask_dir_train,
    batch_size=batch_size,
    dim=(IMG_HEIGHT, IMG_WIDTH),
    n_channels_image=3,
    n_channels_mask=1,
    n_classes=NUM_LABELS,
    shuffle=True,
)

# Batch generator for the validation set
val_gen = DataGenerator(
    list_IDs=list_IDs_val,
    dir_image=img_dir_val,
    dir_mask=mask_dir_val,
    batch_size=batch_size,
    dim=(IMG_HEIGHT, IMG_WIDTH),
    n_channels_image=3,
    n_channels_mask=1,
    n_classes=NUM_LABELS,
    shuffle=True,
)



### Run the cell to retrain

Run the following cell if you want to start training from a given checkpoint. 

In [None]:
if RETRAIN:
    # Replace the freshly built model with one restored from a checkpoint.
    # It is loaded uncompiled and then compiled again with the same loss,
    # optimizer settings and metrics as the new model.
    from keras.models import load_model
    model = load_model(
        '/content/drive/MyDrive/panoramicDentalSegmentation/checkpoints/hybrid_unet_2d/recur_F_T_residual_filter_double/2022-02-24_06-24-34/cp-0025.ckpt',
        custom_objects={'dice_coef': losses.dice_coef},
        compile=False,
    )
    model.compile(
        optimizer=Adam(learning_rate=1e-3),
        loss=loss_function,
        metrics=['accuracy', losses.dice_coef],
    )

### Train

In [None]:
# `Model.fit` accepts a keras `Sequence` directly; `fit_generator` is
# deprecated and only forwards to `fit`.
hist = model.fit(train_gen,
                 steps_per_epoch=len(train_gen),
                 validation_data = val_gen,
                 epochs=50,
                 verbose=1,
                 callbacks=callbacks,
                 shuffle = False,  # already shuffled in data generator 
                 # workers=20,
                 )

### Loss curve

In [None]:
# Plot the training and validation loss at each epoch
dir_plot_save = '/content/drive/MyDrive/panoramicDentalSegmentation/'
loss = hist.history['loss']
val_loss = hist.history['val_loss']
epochs = range(1, len(loss) + 1)
plt.plot(epochs, loss, 'y', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
# Save before showing: once plt.show() has displayed the figure it may be
# closed, and a later savefig() would then write an empty image.
plt.savefig(os.path.join(dir_plot_save,'model_loss.png'))
plt.show()

### Inference

In [None]:
# Uncomment to install jenti. It is required to create (or merge) patches.
# !pip install --target='/content/drive/MyDrive/library' jenti

In [None]:
from jenti.patch import Patch, Merge
import pandas as pd
from sklearn.metrics import confusion_matrix

In [None]:
# Directory that holds the checkpoint folders and the run-info files.
checkpoint_base_dir = "/content/drive/MyDrive/panoramicDentalSegmentation/checkpoints/hybrid_unet_2d"

In [None]:
# Show what is stored in the checkpoint base directory.
os.listdir(checkpoint_base_dir)

In [None]:
# Load json file. It contains model info
# The content is parsed into a dict so that entries such as
# output["network_name"] can be looked up; the `with` block closes the file
# even if parsing fails.
with open(os.path.join(checkpoint_base_dir, json_name), "r") as info_file:
    output = json.load(info_file)

In [None]:
# Load model from checkpoint
from keras.models import load_model
# checkpoint_folder = output["network_name"] + "//" + output["checkpoint_subfolder"]
# "<network name>//<run time stamp>" of the run to evaluate, and the
# checkpoint (without extension) to load from it.
checkpoint_folder = 'recur_F_T_residual_filter_double//2022-02-24_06-24-34'
checkpoint_name = 'cp-0025'
# The model is only used for prediction, so it is loaded uncompiled.
model = load_model(
    f'/content/drive/MyDrive/panoramicDentalSegmentation/checkpoints/hybrid_unet_2d/{checkpoint_folder}//{checkpoint_name}.ckpt',
    custom_objects={'dice_coef': losses.dice_coef},
    compile=False,
)

In [None]:
# Print the layer-by-layer summary of the model restored from the checkpoint.
model.summary()

In [None]:
# Directories 
# Test images and their ground-truth masks
load_dir_test_img = '/content/drive/MyDrive/panoramicDentalSegmentation/test_dataset/images'
load_dir_test_mask = '/content/drive/MyDrive/panoramicDentalSegmentation/test_dataset/mask'

# Predictions and the result table are stored per checkpoint.
save_dir_pred = '/content/drive/MyDrive/panoramicDentalSegmentation/prediction/hybrid_unet_2d/' + checkpoint_folder + '//' + checkpoint_name
# exist_ok avoids the separate existence check (and the race that comes with it).
os.makedirs(save_dir_pred, exist_ok=True)

In [None]:
# File names of all entries in the test image directory
names_test = os.listdir(load_dir_test_img)

In [None]:
# Empty result table; the evaluation loop adds one row of metrics per test image.
df = pd.DataFrame(
    index=[],
    columns=['Name', 'Accuracy', 'Specificity', 'Precision', 'Recall', 'Dice'],
    dtype='object',
)

In [None]:
# Evaluation
# Every test image is cut into overlapping patches, each patch is segmented
# by the model, the patch predictions are merged back to the original image
# size and the merged mask is compared with the ground truth.
patch_shape = [512, 512]
overlap = [10, 10]
save_pred = True
# Iterate over test samples
for i, name in enumerate(names_test):
    name_only = os.path.splitext(name)[0]
    # Load image and mask
    im = Image.open(os.path.join(load_dir_test_img, name))
    mask = Image.open(os.path.join(load_dir_test_mask, name))
    im = np.array(im) # convert to array
    im = im/255.0 # normalization
    mask = np.array(mask)*1 # convert to array. Multiplied by 1 to convert from boolean to int
    mask = mask.astype('int8')
    # Create patches from the image
    patch = Patch(patch_shape, overlap, patch_name=name_only, csv_output=False)
    patches, info, org_shape_im = patch.patch2d(im)
    org_shape_mask = (org_shape_im[0], org_shape_im[1], 1) # mask is a binary image
    # Iterate over patches
    patchwise_pred = [] # store patch-wise predictions for each test sample
    for patch_im in patches:
        patch_im = np.expand_dims(patch_im, axis=0) # shape: 1 x 512 x 512 x 3
        # One forward pass per patch; the class with the highest score is the label.
        pred = np.argmax(model.predict(patch_im), axis=-1) # shape: 1 x 512 x 512
        pred = np.expand_dims(np.squeeze(pred), axis=-1) # shape: 512 x 512 x 1
        patchwise_pred.append(pred)
    # Merge patches
    merge = Merge(info, org_shape_mask, dtype='int8') # create object
    merged = merge.merge2d(patchwise_pred)
    # Save prediction as png
    if save_pred:
        merged_im = Image.fromarray((np.squeeze(merged)*255 ).astype(np.uint8))
        merged_im.save(os.path.join(save_dir_pred, 'pred_' + name_only + '.png'))
    # Calculate accuracy, specificity, precision, recall, and dice
    # labels=[0, 1] forces a 2 x 2 matrix even if the mask and the prediction
    # contain a single class, so the unpacking below always gets four values.
    tn, fp, fn, tp = confusion_matrix(np.squeeze(mask).flatten(),
                                      np.squeeze(merged).flatten(),
                                      labels=[0, 1]).ravel()
    acc = ((tp + tn)/(tp + tn + fn + fp))*100  
    sp = (tn/(tn + fp))*100
    p = (tp/(tp+fp))*100
    r = (tp/(tp+fn))*100
    # f1 = ((2 * p * r)/(p + r))*100
    dice = (2 * tp / (2 * tp + fp + fn))*100
    print("Img # {:1s}, Image {:1s}: acc: {:3f}, sp: {:3f}, p: {:3f}, r: {:3f}, dice: {:3f}".format(str(i+1), name_only, acc, sp, p, r, dice))
    # Add to dataframe
    # (DataFrame.append was removed in pandas 2.0; pd.concat works in all versions.)
    tmp = pd.Series([name_only, acc, sp, p, r, dice], index=['Name', 'Accuracy', 'Specificity', 'Precision', 'Recall', 'Dice'])
    df = pd.concat([df, tmp.to_frame().T], ignore_index=True)
    # The table is rewritten after every image so that partial results
    # are on disk if the loop is interrupted.
    df.to_csv(os.path.join(save_dir_pred, 'result.csv'), index=False)

print("Mean Accuracy: ", df["Accuracy"].mean())
print("Mean Specificity: ", df["Specificity"].mean())
print("Mean precision: ", df["Precision"].mean())
print("Mean recall: ", df["Recall"].mean())
print("Mean Dice: ", df["Dice"].mean())