In [2]:
import sys
import time
import tensorflow as tf
from utils_model import *
from tensorflow.keras.layers import Conv2D, Add, Multiply
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Softmax, Input
from tensorflow.keras import Model

In [29]:
def PAM(inp_feature, layer_name, kernel_initializer='glorot_uniform', acti='relu'):
    '''
    Position attention module
    by default input shape => [w,h,c],[240, 240, 128] hence c/8 = 16
    :param layer_name: List of layer names
    [1st conv block, 2nd conv block, softmax output, 3rd conv block, position coefficient, Add output]
    :param inp_feature: feature maps of res block after up sampling [w,h,c]
    :return: PAM features [w/4,h/4,c]
    '''
    # dimensions
    b,w,h,c = inp_feature.shape
    # scale down ratio
    c_8 = c//8
    #
    assert len(layer_name)>=5, 'Layer list length should be 5!'
    # Branch01 Dimension: [w,h,c/8] => [(wxh),c/8]
    query = conv_2d(inp_feature, filters=c_8, layer_name=layer_name[0], batch_norm=False, kernel_size=(1, 1), acti=acti,
            kernel_initializer=kernel_initializer, dropout_rate=None)
    query = tf.reshape(query,[-1,(w*h),c_8 ])
    # Branch02 Dimension: [w,h,c/8] => [c/8,(wxh)]
    key = conv_2d(inp_feature, filters=c_8, layer_name=layer_name[1], batch_norm=False, kernel_size=(1, 1), acti=acti,
        kernel_initializer=kernel_initializer, dropout_rate=None)
    key = tf.reshape(key, [-1,(w*h),c_8 ])
    key = tf.einsum('bij->bji', key) # transpose/permutation
    # matmul pipeline 01 & 02
    matmul_0102 = tf.einsum('bij,bjk->bik', query, key) # [(wxh),(wxh)]
    #attention coefficient
    alpha_p = Softmax(name=layer_name[2])(matmul_0102) # [(wxh),(wxh)]
    # Branch03
    value = conv_2d(inp_feature, filters=c, layer_name=layer_name[3], batch_norm=False, kernel_size=(1, 1), acti=acti,
        kernel_initializer=kernel_initializer, dropout_rate=None)
    value = tf.reshape(value,[-1,(w*h),c]) # [(wxh),c]
    matmul_all = tf.einsum('bij,bjk->bik',alpha_p,value) # [(wxh),c]
    # Output
    output = tf.reshape(matmul_all, [-1,w,h,c]) # [w,h,c]
    # learnable coefficient to control the importance of CAM
    lambda_p = Conv2D(filters=1,kernel_size=1, padding='same',activation='sigmoid', name=layer_name[4])(inp_feature)
    output = Multiply()([output, lambda_p])
    output_add = Add(name = layer_name[-1])([output, inp_feature])
    return output_add

def CAM(inp_feature, layer_name):
    '''
    Channel attention module
    by default input shape => [w,h,c],[240, 240, 128] hence c/8 = 16
    :param inp_feature: feature maps of res block after up sampling [w,h,c]k
    :param layer_name: List of layer names
        [softmax output, channel attention coefficients, Add output]
    :return: CAM features [w/4,h/4,c]
    '''
    # dimensions
    b,w,h,c = inp_feature.shape
    # learnable coefficient to control the importance of CAM
    assert len(layer_name)>=2, 'Layer list length should be 2!'
    # Branch01 Dimension: [w,h,c] => [(wxh),c]
    query = tf.reshape(inp_feature, [-1,(w*h),c])
    # Branch02 Dimension: [w,h,c] => [c,(wxh)]
    key = tf.reshape(inp_feature, [-1,(w*h),c]) # [(wxh),c]
    key = tf.einsum('ijk->ikj', key) # Permute:[c,(wxh)]
    # matmul pipeline 01 & 02
    matmul_0201 = tf.einsum('ijk,ikl->ijl', key, query) # [c,c]
    #attention coefficient
    alpha_c = Softmax(name=layer_name[0])(matmul_0201) # [c,c]
    # Branch03 Dimension: [w,h,c] => [c,(wxh)]
    value = tf.reshape(inp_feature,[-1,(w*h),c]) # [(wxh),c]
    matmul_all = tf.einsum('ijk,ikl->ijl', value, alpha_c) # [(wxh),c]
    # output
    output = tf.reshape(matmul_all,[-1,w,h,c])# [w,h,c]
    #
    lambda_c = tf.keras.backend.variable(tf.zeros([1]), dtype='float32')
    output = Multiply()([output, lambda_c])
    output_add = Add(name=layer_name[-1])([output, inp_feature])
    return output_add

In [18]:
from tensorflow.keras import layers
class att_var(layers.Layer):
    '''
    Attention variable
    '''
    def __init__(self, initial_val):
        super(att_var, self).__init__()
        self.initial_val = initial_val
    def __call__(self):
        lambda_ = tf.Variable(initial_value=self.initial_val, trainable=True)
        return lambda_

lambda_c = att_var(tf.zeros([1]))
tst = lambda_c()

<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>


In [3]:
layer_name_p01 = ['pam01_conv01', 'pam01_conv02', 'pam01_softmax', 'pam01_conv03',
                  'pam01_alpha','pam01_add']
layer_name_c01 = ['cam01_softmax', 'cam01_alpha','cam01_add']
layer_name_p02 = ['pam02_conv01', 'pam02_conv02', 'pam02_softmax', 'pam02_conv03',
                  'pam02_alpha', 'pam02_add']
layer_name_c02 = ['cam02_softmax', 'cam02_alpha','cam02_add']
layer_name_template = [layer_name_p01, layer_name_c01, layer_name_p02, layer_name_c02]

layer_name_ga = []
for b in range(1,4):
    layer_block = []
    for layer in layer_name_template:
        layer_internal = [i+'block0{}'.format(b) for i in layer]
        layer_block.append(layer_internal)
    layer_name_ga.append(layer_block)

In [30]:
hn = 'he_normal' #kernel initializer
lambda_ = tf.keras.backend.variable(tf.zeros([1]), dtype='float32')
input_layer = Input(shape=(200,200,128))
model = Model(input_layer, CAM(input_layer, layer_name_ga[0][1]))

In [31]:
model.summary()

Model: "functional_13"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_9 (InputLayer)            [(None, 200, 200, 12 0                                            
__________________________________________________________________________________________________
tf_op_layer_Reshape_25 (TensorF [(None, 40000, 128)] 0           input_9[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_Einsum_18 (TensorFl [(None, 128, 40000)] 0           tf_op_layer_Reshape_25[0][0]     
__________________________________________________________________________________________________
tf_op_layer_Reshape_24 (TensorF [(None, 40000, 128)] 0           input_9[0][0]                    
______________________________________________________________________________________

# Conv3D

In [7]:
from tensorflow.keras.layers import Conv3D, UpSampling3D, MaxPool3D, GaussianNoise

hn = 'he_normal' #kernel initializer

def conv_block_3D(x, filters, norm_fn='gn', kernel_size=3,
               kernel_initializer=hn, acti_fn='prelu', dropout_rate=None):
    '''
    Dual convolution block with [full pre-activation], Norm -> Acti -> Conv
    :param x: Input features
    :param filters: A list that contains the number of filters for 1st and 2nd convolutional layer
    :param norm_fn: Tensorflow function for normalization, 'bn' for Batch Norm, 'gn' for Group Norm
    :param kernel_size: Kernel size for both convolutional layer with 3x3 as default
    :param kernel_initializer: Initializer for kernel weights with 'glorot uniform' as default
    :param acti_fn: Tensorflow function for activation, 'relu' for ReLU, 'prelu' for PReLU
    :param dropout_rate: Specify dropouts for layers
    :return: Feature maps of same size as input with number of filters equivalent to the last layer
    '''
    assert type(filters)==list, "Please input filters of type list."
    assert acti_fn!= None, 'There should be an activation functino specified'
    #1st convolutional block
    if norm_fn=='bn':
        x = BatchNormalization()(x)
    elif norm_fn=='gn':
        x = GroupNormalization()(x)
    if acti_fn=='relu':
        x = ReLU()(x)
    elif acti_fn=='prelu':
        x = PReLU(shared_axes=[1,2,3])(x)
    if dropout_rate != None:
        x = Dropout(dropout_rate)(x)
    x = Conv3D(filters[0], kernel_size, padding='same', kernel_initializer=kernel_initializer)(x)
    #2nd convolutional block
    if norm_fn=='bn':
        x = BatchNormalization()(x)
    elif norm_fn=='gn':
        x = GroupNormalization()(x)
    if acti_fn=='relu':
        x = ReLU()(x)
    elif acti_fn=='prelu':
        x = PReLU(shared_axes=[1,2,3])(x)
    x = Conv3D(filters[1], kernel_size, padding='same', kernel_initializer=kernel_initializer)(x)
    return x


def down_sampling_3D(x, filters, norm_fn='gn', kernel_size=3, acti_fn='relu',
            kernel_initializer=hn, dropout_rate=None):
    '''
    Down sampling function version 2 with Convolutional layer of stride 2 as downsampling operation, with
    [full pre-activation], Norm -> Acti -> Conv
    :param x: Input features
    :param filters: Number of filters for Convolutional layer of stride 2
    :param norm_fn: Tensorflow function for normalization, 'bn' for Batch Norm, 'gn' for Group Norm
    :param kernel_size: Kernel size for both convolutional layer with 3x3 as default
    :param acti_fn: Tensorflow function for activation, 'relu' for ReLU, 'prelu' for PReLU
    :param kernel_initializer: Initializer for kernel weights with 'glorot uniform' as default
    :param dropout_rate: Specify dropouts for layers
    :return: Feature maps of size scaled down by 2 with number of filters specified
    '''
    assert acti_fn!= None, 'There should be an activation function specified'
    #normalization
    if norm_fn=='bn':
        x = BatchNormalization()(x)
    elif norm_fn=='gn':
        x = GroupNormalization()(x)
    if acti_fn=='relu':
        x = ReLU()(x)
    #activation
    elif acti_fn=='prelu':
        x = PReLU(shared_axes=[1,2,3])(x)
    if dropout_rate != None:
        x = Dropout(dropout_rate)(x)
    #normal mode
    x = Conv3D(filters, kernel_size, strides=(1,2,2), padding='same', kernel_initializer=kernel_initializer)(x)
    return x


def res_block_3D(x_in, filters, norm_fn='gn', kernel_size=3,
               kernel_initializer=hn, acti_fn='prelu', dropout_rate=None):
    '''
    This function construct the residual block in 3D by input->conv_block_3D->concat([input,conv_output])
    :param x: Input features
    :param filters: A list that contains the number of filters for 1st and 2nd convolutional layer
    :param norm_fn: Tensorflow function for normalization, 'bn' for Batch Norm, 'gn' for Group Norm
    :param kernel_size: Kernel size for both convolutional layer with 3x3 as default
    :param kernel_initializer: Initializer for kernel weights with 'glorot uniform' as default
    :param acti_fn: Tensorflow function for activation, 'relu' for ReLU, 'prelu' for PReLU
    :param dropout_rate: Specify dropouts for layers
    :return: Resblock output => concatenating input with 2*convlutional output
    '''
    assert len(filters)==2, "Please assure that there is 2 values for filters."
    output_conv_block = conv_block_3D(x_in, filters, norm_fn=norm_fn, kernel_size=kernel_size,
                                   kernel_initializer = kernel_initializer, acti_fn = acti_fn, dropout_rate=dropout_rate)
    output_add = Add()([output_conv_block, x_in])
    return output_add


def up_3D(x_in, filters, merge, kernel_initializer=hn, size=(1, 2, 2)):
    '''
    This function carry out the operation of deconvolution => upsampling + convolution, and
    concatenating feture maps from the skip connection with the deconv feature maps
    @param x_in: input feature
    @param filters: Number of filters
    @param merge: featrure maps from the skip connection
    @param kernel_initializer: Initializer for kernel weights with 'glorot uniform' as default
    @param size: Upsampling size, by default (1,2,2)
    @return: concatenate feature maps of skip connection output and upsampled feature maps from previous output
    '''
    u = UpSampling3D(size)(x_in)
    conv = Conv3D(filters=filters, kernel_size=3, padding='same', kernel_initializer=kernel_initializer)(u)
    conv = PReLU(shared_axes=[1,2,3])(conv)
    concat = tf.concat([merge, conv], axis=-1)
    return concat

def vnet(x):
    # inject gaussian noise
    gauss1 = GaussianNoise(0.01)(x)
    # -----------down sampling path--------------------------------------
    # 1st block [155, 200, 200, 4]
    conv_01 = Conv3D(16, 3, padding='same', kernel_initializer=hn)(gauss1)
    conv_01 = PReLU(shared_axes=[1,2,3])(conv_01)
    res_block01 = conv_block_3D(conv_01, filters=[32, 16])
    # 2nd block [155, 100, 100, 4]
    down_01 = down_sampling_3D(res_block01,filters=32)
    res_block02 = res_block_3D(down_01, filters=[64, 32])
    # 3rd block [155, 50, 50, 4]
    down_02 = down_sampling_3D(res_block02,filters=64)
    res_block03 = res_block_3D(down_02, filters=[128, 64])
    # 4th block [155, 25, 25, 4] *latent space
    down_03 = down_sampling_3D(res_block03,filters=128)
    res_block04 = res_block_3D(down_03, filters=[256, 128])

    # -----------up sampling path-----------------------------------------
    # 1st up [155, 50, 50, 4]
    up_01 = up_3D(res_block04, 64, res_block03)
    up_conv01 = conv_block_3D(up_01, filters=[128, 128])
    # 2nd up [155, 100, 100, 4]
    up_02 = up_3D(up_conv01, 64, res_block02)
    up_conv02 = conv_block_3D(up_02, filters=[64, 64])
    # 3rd up [155, 200, 200, 4]
    up_03 = up_3D(up_conv02, 64, res_block01)
    up_conv03 = conv_block_3D(up_03, filters=[64, 64])

    #segmentation output
    output = Conv3D(4,kernel_size=1, activation='softmax',
                    kernel_initializer=hn)(up_conv03)
    return output

In [8]:
input_layer = Input(shape=(150,200,200,4))
model = Model(input_layer, vnet(input_layer))

In [9]:
model.summary()

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 150, 200, 20 0                                            
__________________________________________________________________________________________________
gaussian_noise (GaussianNoise)  (None, 150, 200, 200 0           input_2[0][0]                    
__________________________________________________________________________________________________
conv3d_22 (Conv3D)              (None, 150, 200, 200 1744        gaussian_noise[0][0]             
__________________________________________________________________________________________________
p_re_lu_18 (PReLU)              (None, 150, 200, 200 16          conv3d_22[0][0]                  
_______________________________________________________________________________________