In [34]:
import tensorflow as tf
from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.applications.xception import Xception
import segmentation_models as sm
import tensorflow_advanced_segmentation_models as tasm
import numpy as np
import visualkeras

## Dual Attention

This notebook will conduct tests using activations from the VGG Encoder to develop a dual attention module. Firstly the channel attention module will be built followed by the spatial attention module.

In [2]:
vgg19 = VGG19(weights="imagenet", include_top=False, input_shape=(512,512,3))

In [3]:
#Getting all the blocks from the VGG network
vgg_blocks = {
    f"block{n}" : [layer for layer in vgg19.layers if f"block{n}_conv" in layer.name] for n in range(1, 6)
}

In [4]:
def vgg_encoder_block(x, layers):
    """
    This function passes an input through a set of conv layers from VGG19, returning the downsampled and convolved activation
    """
    for layer in layers:
        x = layer(x)
    
    addition = x
    x = tf.keras.layers.MaxPooling2D((2,2), strides = 2)(x)
    return (x, addition)

def last_vgg_block(x, layers):

    for layer in layers:
        x = layer(x)
    
    return x

In [5]:
def vgg_encoder_full(input, layer_dict):

    """
    This function creates the full encoder given a dictionary of layers from the VGG network, it returns the final activation 
    and a list of intermediate activations
    """

    activations = []
    x = input
    for layer_name in list(layer_dict.keys())[:-1]:
        x, a = vgg_encoder_block(x, layer_dict[layer_name])
        activations.append(a)
    
    x = last_vgg_block(x, layer_dict[list(layer_dict.keys())[-1]])
    
    return x, activations

In [6]:
inp = vgg19.input

In [7]:
x, a = vgg_encoder_full(inp, vgg_blocks)

In [8]:
a

[<KerasTensor: shape=(None, 512, 512, 64) dtype=float32 (created by layer 'block1_conv2')>,
 <KerasTensor: shape=(None, 256, 256, 128) dtype=float32 (created by layer 'block2_conv2')>,
 <KerasTensor: shape=(None, 128, 128, 256) dtype=float32 (created by layer 'block3_conv4')>,
 <KerasTensor: shape=(None, 64, 64, 512) dtype=float32 (created by layer 'block4_conv4')>]

In [9]:
test_activation = a[0]

In [10]:
test_activation

<KerasTensor: shape=(None, 512, 512, 64) dtype=float32 (created by layer 'block1_conv2')>

### Channel Attention

![](channel_attention.jpeg)

#### Duplication of the input

In [11]:
a1=a2=a3=a4= test_activation
H = a2.shape[1]
W = a2.shape[2]
C = a2.shape[3]
a2 = tf.keras.layers.Reshape((H*W, C))(a2)
a3 = tf.keras.layers.Reshape((H*W, C))(a3)
a4 = tf.transpose(tf.keras.layers.Reshape((H*W, C))(a4), perm=[0,2,1])


#### Producing the Softmax Output X

In [12]:
a_T_a = tf.linalg.matmul(a4, a3)
x = tf.keras.layers.Softmax()(a_T_a)
x = tf.transpose(x, perm=[0,2,1])


#### Producing E - the final output

In [13]:
a2_pass = tf.linalg.matmul(a2, x)
a2_pass = tf.keras.layers.Reshape((H,W,C))(a2_pass)

In [14]:
E = tf.keras.layers.Add()([a1, a2_pass])

In [15]:
E

<KerasTensor: shape=(None, 512, 512, 64) dtype=float32 (created by layer 'add')>

#### Function of the Channel-Attention-Module

In [16]:
def cam(inputs):

    a1=a2=a3=a4= inputs
    H = a2.shape[1]
    W = a2.shape[2]
    C = a2.shape[3]
    a2 = tf.keras.layers.Reshape((H*W, C))(a2)
    a3 = tf.keras.layers.Reshape((H*W, C))(a3)
    a4 = tf.transpose(tf.keras.layers.Reshape((H*W, C))(a4), perm=[0,2,1])


    #Creating X, the softmax on the matrix product of A_T_A
    a_T_a = tf.linalg.matmul(a4, a3)
    x = tf.keras.layers.Softmax()(a_T_a)
    x = tf.transpose(x, perm=[0,2,1])

    a2_pass = tf.linalg.matmul(a2, x)
    a2_pass = tf.keras.layers.Reshape((H,W,C))(a2_pass)

    E = tf.keras.layers.Add()([a1, a2_pass])

    return E


        

In [17]:
e = cam(test_activation)
e

<KerasTensor: shape=(None, 512, 512, 64) dtype=float32 (created by layer 'add_1')>

#### Layer for Channel Attention

In [18]:
class ChannelAttention(tf.keras.layers.Layer):

    def __init__(self):
        super(ChannelAttention, self).__init__()
        self.beta = tf.Variable(initial_value=0.0, name="beta", trainable=True)
    
    def build(self, input_shape):
        self.C = input_shape[-1]
        self.H = input_shape[1]
        self.W = input_shape[2]
    
    def call(self, inputs):

        a1=a2=a3=a4= inputs
        n_shape = self.H * self.W
        a2 = tf.keras.layers.Reshape((n_shape, self.C))(a2)
        a3 = tf.keras.layers.Reshape((n_shape, self.C))(a3)
        a4 = tf.transpose(tf.keras.layers.Reshape((n_shape, self.C))(a4), perm=[0,2,1])


        #Creating X, the softmax on the matrix product of A_T_A
        a_T_a = tf.linalg.matmul(a4, a3)
        x = tf.keras.layers.Softmax()(a_T_a)
        x = tf.transpose(x, perm=[0,2,1])

        a2_pass = self.beta * tf.linalg.matmul(a2, x)
        a2_pass = tf.keras.layers.Reshape((self.H,self.W,self.C))(a2_pass)

        E = tf.keras.layers.Add()([a1, a2_pass])

        return E
            

In [19]:
e = ChannelAttention()(test_activation)

## Spatial Attention Module

The spatial attention module will be constructed in the same spirit as the channel attention module.

![](spatial_attention.jpeg)

In [20]:
class SpatialAttention(tf.keras.layers.Layer):

    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.alpha = tf.Variable(initial_value=0.0, trainable=True)
    
    def build(self, input_shape):
        
        self.C = input_shape[-1]
        self.H = input_shape[1]
        self.W = input_shape[2]

        #Defining the convolutions
        self.conv1 = tf.keras.layers.Conv2D(self.C, 1)
        self.conv2 = tf.keras.layers.Conv2D(self.C, 1)
        self.conv3 = tf.keras.layers.Conv2D(self.C, 1)

    def call(self, inputs):

        n_shape = self.H * self.W

        a = inputs
        b = self.conv1(inputs)
        c = self.conv2(inputs)
        d = self.conv3(inputs)

        b = tf.transpose(tf.keras.layers.Reshape((n_shape, self.C))(b), perm=[0,2,1])
        c = tf.keras.layers.Reshape((n_shape, self.C))(c)
        d = tf.keras.layers.Reshape((n_shape, self.C))(d)

        c = tf.linalg.matmul(c, b)
        S = tf.keras.layers.Softmax()(c)
        S = tf.transpose(S, perm=[0,2,1])

        d = self.alpha * tf.linalg.matmul(S, d)
        d = tf.keras.layers.Reshape((self.H, self.W, self.C))(d)
        E = tf.keras.layers.Add()([a, d])        

        return E


In [21]:
e = SpatialAttention()(test_activation)

In [22]:
e

<KerasTensor: shape=(None, 512, 512, 64) dtype=float32 (created by layer 'spatial_attention')>

## Dual Attention Module

The dual attention module applies a convolution to the outputs of the spatial and channel attention modules, then applies an elementwise sum.

In [23]:
class DualAttention(tf.keras.layers.Layer):

    def __init__(self):
        super(DualAttention, self).__init__()
    
    def build(self, input_shape):
        self.C = input_shape[-1]
        self.conv1 = tf.keras.layers.Conv2D(self.C, 1)
        self.conv2 = tf.keras.layers.Conv2D(self.C, 1)
        self.sam = SpatialAttention()
        self.cam = ChannelAttention()
    
    def call(self, inputs):

        e1 = self.sam(inputs)
        e2 = self.cam(inputs)

        e1 = self.conv1(e1)
        e2 = self.conv2(e2)

        F = tf.keras.layers.Add()([e1, e2])
        return F

In [24]:
test_activation

<KerasTensor: shape=(None, 512, 512, 64) dtype=float32 (created by layer 'block1_conv2')>

In [25]:
f = DualAttention()(test_activation)

In [26]:
f

<KerasTensor: shape=(None, 512, 512, 64) dtype=float32 (created by layer 'dual_attention')>

In [27]:
ar = np.array([1 , 3])

In [28]:
len(ar[ar>3])

0

In [29]:
not 0

True

## Testing The Incorporation of the Dual Attention Module

In [30]:
def decoder_block(a, x, f, attention=False):

    if attention:
        a = DualAttention()(a)

    x = tf.keras.layers.Conv2DTranspose(filters=f, kernel_size=2, strides=2, padding="same", activation="relu")(x)
    if a is not  None:
        x = tf.concat([a, x], axis=-1)
    x = tf.keras.layers.Conv2D(f, 3, padding="same", activation="relu")(x)
    x = tf.keras.layers.Conv2D(f, 3, padding="same", activation="relu")(x) 

    return x

def decoder_full(activations, x, filters, num_classes, attention_indices):

    #Looping over the activations and filters from bottom to top
    #Activation are reversed for this effect


        #Flag to indicate that there is no need for attention
    
    ai = None
    #Flag to indicate whether the point of attention is found
    found = True
    #Flag to pass to the decoder block, whether dual attention should be applied
    att = False
    for i,(a,f) in enumerate(zip(activations[::-1],filters)):
        there = len(attention_indices)
        if found and there:
            ai = attention_indices.pop()
        
        #Check if the current activation needs attention

        att=found = (i+1 == ai)
        print(att)
        x = decoder_block(a, x, f, att)
    
    output = tf.keras.layers.Conv2D(num_classes, 1, padding="same", activation="softmax")(x)

    return output

In [31]:
def vgg_encoder_block(x, layers):
    """
    This function passes an input through a set of conv layers from VGG19, returning the downsampled and convolved activation
    """
    for layer in layers:
        x = layer(x)
    
    addition = x
    x = tf.keras.layers.MaxPooling2D((2,2), strides = 2)(x)
    return (x, addition)

def last_vgg_block(x, layers):

    for layer in layers:
        x = layer(x)
    
    return x

def vgg_encoder_full(input, layer_dict):

    """
    This function creates the full encoder given a dictionary of layers from the VGG network, it returns the final activation 
    and a list of intermediate activations
    """

    activations = []
    x = input
    for layer_name in list(layer_dict.keys())[:-1]:
        x, a = vgg_encoder_block(x, layer_dict[layer_name])
        activations.append(a)
    
    x = last_vgg_block(x, layer_dict[list(layer_dict.keys())[-1]])
    
    return x, activations



def vgg_unet(num_classes, input_size, input_dim, att_indices=[], last_attention=False):

    #Downloading the VGG network
    vgg19 = VGG19(weights="imagenet", include_top=False, input_shape=(input_size, input_size,input_dim))
    vgg19.trainable = False
    #Getting all the blocks from the VGG network
    vgg_blocks = {
        f"block{n}" : [layer for layer in vgg19.layers if f"block{n}_conv" in layer.name] for n in range(1, 6)
    }
    
    #Filters for the Decoder
    filters = [512, 256, 128, 64]

    l = len(att_indices)
    assert l >= 0, "Attention indices should be 0 or greater"
    assert l <= len(filters) + 1, "Number of layers for attetention can not exceed 5"

    #assert len(att_indices[att_indices > 5]) == 0, "Attention indices must be from 1 to 5"
    

    vgg_input = vgg19.input

    #Defining the encoder

    #First Preprocess the input
    x = preprocess_input(x=vgg_input)
    
    x, a = vgg_encoder_full(x, vgg_blocks)

    if last_attention:
        x = DualAttention()(x)

    output = decoder_full(a, x, filters, num_classes, att_indices)

    vgg_unet_model = tf.keras.Model(vgg_input, output)

    return vgg_unet_model

In [36]:
m = vgg_unet(5, 512, 3, [2], True)

False
True
False
False


In [37]:
visualkeras.layered_view(m, scale_xy=1, legend=True).show()

: 

In [33]:
m.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 tf.__operators__.getitem (Slic  (None, 512, 512, 3)  0          ['input_2[0][0]']                
 ingOpLambda)                                                                                     
                                                                                                  
 tf.nn.bias_add (TFOpLambda)    (None, 512, 512, 3)  0           ['tf.__operators__.getitem[0][0]'
                                                                 ]                            