In [1]:
import sys
import time
import tensorflow as tf
from utils_model import *
from tensorflow.keras.layers import Conv2D, Add, Multiply
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Softmax, Input
from tensorflow.keras import Model

In [29]:
def PAM(inp_feature, layer_name, kernel_initializer='glorot_uniform', acti='relu'):
    '''
    Position attention module
    by default input shape => [w,h,c],[240, 240, 128] hence c/8 = 16
    :param layer_name: List of layer names
    [1st conv block, 2nd conv block, softmax output, 3rd conv block, position coefficient, Add output]
    :param inp_feature: feature maps of res block after up sampling [w,h,c]
    :return: PAM features [w/4,h/4,c]
    '''
    # dimensions
    b,w,h,c = inp_feature.shape
    # scale down ratio
    c_8 = c//8
    #
    assert len(layer_name)>=5, 'Layer list length should be 5!'
    # Branch01 Dimension: [w,h,c/8] => [(wxh),c/8]
    query = conv_2d(inp_feature, filters=c_8, layer_name=layer_name[0], batch_norm=False, kernel_size=(1, 1), acti=acti,
            kernel_initializer=kernel_initializer, dropout_rate=None)
    query = tf.reshape(query,[-1,(w*h),c_8 ])
    # Branch02 Dimension: [w,h,c/8] => [c/8,(wxh)]
    key = conv_2d(inp_feature, filters=c_8, layer_name=layer_name[1], batch_norm=False, kernel_size=(1, 1), acti=acti,
        kernel_initializer=kernel_initializer, dropout_rate=None)
    key = tf.reshape(key, [-1,(w*h),c_8 ])
    key = tf.einsum('bij->bji', key) # transpose/permutation
    # matmul pipeline 01 & 02
    matmul_0102 = tf.einsum('bij,bjk->bik', query, key) # [(wxh),(wxh)]
    #attention coefficient
    alpha_p = Softmax(name=layer_name[2])(matmul_0102) # [(wxh),(wxh)]
    # Branch03
    value = conv_2d(inp_feature, filters=c, layer_name=layer_name[3], batch_norm=False, kernel_size=(1, 1), acti=acti,
        kernel_initializer=kernel_initializer, dropout_rate=None)
    value = tf.reshape(value,[-1,(w*h),c]) # [(wxh),c]
    matmul_all = tf.einsum('bij,bjk->bik',alpha_p,value) # [(wxh),c]
    # Output
    output = tf.reshape(matmul_all, [-1,w,h,c]) # [w,h,c]
    # learnable coefficient to control the importance of CAM
    lambda_p = Conv2D(filters=1,kernel_size=1, padding='same',activation='sigmoid', name=layer_name[4])(inp_feature)
    output = Multiply()([output, lambda_p])
    output_add = Add(name = layer_name[-1])([output, inp_feature])
    return output_add

def CAM(inp_feature, layer_name):
    '''
    Channel attention module
    by default input shape => [w,h,c],[240, 240, 128] hence c/8 = 16
    :param inp_feature: feature maps of res block after up sampling [w,h,c]k
    :param layer_name: List of layer names
        [softmax output, channel attention coefficients, Add output]
    :return: CAM features [w/4,h/4,c]
    '''
    # dimensions
    b,w,h,c = inp_feature.shape
    # learnable coefficient to control the importance of CAM
    assert len(layer_name)>=2, 'Layer list length should be 2!'
    # Branch01 Dimension: [w,h,c] => [(wxh),c]
    query = tf.reshape(inp_feature, [-1,(w*h),c])
    # Branch02 Dimension: [w,h,c] => [c,(wxh)]
    key = tf.reshape(inp_feature, [-1,(w*h),c]) # [(wxh),c]
    key = tf.einsum('ijk->ikj', key) # Permute:[c,(wxh)]
    # matmul pipeline 01 & 02
    matmul_0201 = tf.einsum('ijk,ikl->ijl', key, query) # [c,c]
    #attention coefficient
    alpha_c = Softmax(name=layer_name[0])(matmul_0201) # [c,c]
    # Branch03 Dimension: [w,h,c] => [c,(wxh)]
    value = tf.reshape(inp_feature,[-1,(w*h),c]) # [(wxh),c]
    matmul_all = tf.einsum('ijk,ikl->ijl', value, alpha_c) # [(wxh),c]
    # output
    output = tf.reshape(matmul_all,[-1,w,h,c])# [w,h,c]
    #
    lambda_c = tf.keras.backend.variable(tf.zeros([1]), dtype='float32')
    output = Multiply()([output, lambda_c])
    output_add = Add(name=layer_name[-1])([output, inp_feature])
    return output_add

In [18]:
from tensorflow.keras import layers
class att_var(layers.Layer):
    '''
    Attention variable
    '''
    def __init__(self, initial_val):
        super(att_var, self).__init__()
        self.initial_val = initial_val
    def __call__(self):
        lambda_ = tf.Variable(initial_value=self.initial_val, trainable=True)
        return lambda_

lambda_c = att_var(tf.zeros([1]))
tst = lambda_c()

<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>


In [3]:
layer_name_p01 = ['pam01_conv01', 'pam01_conv02', 'pam01_softmax', 'pam01_conv03',
                  'pam01_alpha','pam01_add']
layer_name_c01 = ['cam01_softmax', 'cam01_alpha','cam01_add']
layer_name_p02 = ['pam02_conv01', 'pam02_conv02', 'pam02_softmax', 'pam02_conv03',
                  'pam02_alpha', 'pam02_add']
layer_name_c02 = ['cam02_softmax', 'cam02_alpha','cam02_add']
layer_name_template = [layer_name_p01, layer_name_c01, layer_name_p02, layer_name_c02]

layer_name_ga = []
for b in range(1,4):
    layer_block = []
    for layer in layer_name_template:
        layer_internal = [i+'block0{}'.format(b) for i in layer]
        layer_block.append(layer_internal)
    layer_name_ga.append(layer_block)

In [30]:
hn = 'he_normal' #kernel initializer
lambda_ = tf.keras.backend.variable(tf.zeros([1]), dtype='float32')
input_layer = Input(shape=(200,200,128))
model = Model(input_layer, CAM(input_layer, layer_name_ga[0][1]))

In [31]:
model.summary()

Model: "functional_13"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_9 (InputLayer)            [(None, 200, 200, 12 0                                            
__________________________________________________________________________________________________
tf_op_layer_Reshape_25 (TensorF [(None, 40000, 128)] 0           input_9[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_Einsum_18 (TensorFl [(None, 128, 40000)] 0           tf_op_layer_Reshape_25[0][0]     
__________________________________________________________________________________________________
tf_op_layer_Reshape_24 (TensorF [(None, 40000, 128)] 0           input_9[0][0]                    
______________________________________________________________________________________