# Residual Attention Network

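- A stack of Attention Modules, each preceded by a Residual Unit
- Each Attention Module has 2 branches:
    - Trunk Branch: a sequence of Residual Units that processes the features
    - Soft Mask Branch: a downsample/upsample path ending in a sigmoid, whose output weights the trunk features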

In [24]:
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, UpSampling2D, Activation, AveragePooling2D, Flatten, Dense, Add, Multiply, BatchNormalization
from tensorflow.keras.models import Model

- https://keras.io/layers/merge/
- https://keras.io/layers/convolutional/
- https://keras.io/activations/

In [3]:
class ResidualAttentionNetwork():
    """Builder for a Residual Attention Network Keras model.

    The network is laid out as follows::

        Conv Layer
        Max Pooling Layer

        (Residual Unit -> Attention Module) x N_ATTENTION_MODULES

        Residual Unit x N_FINAL_RESIDUAL_UNITS

        Average Pooling
        Flatten
        Output Dense Layer (n_classes, activation='softmax')

    The constructed ``tensorflow.keras`` ``Model`` is available as
    ``self.model`` after construction, and a fresh one can be obtained from
    ``build_model()``.
    """

    # Number of (Residual Unit -> Attention Module) stages in the network.
    N_ATTENTION_MODULES = 3
    # Number of Residual Units applied after the last Attention Module.
    N_FINAL_RESIDUAL_UNITS = 4

    def __init__(self, input_shape, n_classes, p=1, t=2, r=1):
        '''
        Params:
        - input_shape: shape of one input sample, passed to keras `Input`
        - n_classes: number of units of the softmax output layer
        - p: num of residual units before and after the branches of an
             attention module
        - t: num of residual units in the trunk branch
        - r: num of residual units between adjacent pooling layers in the
             soft mask branch
        '''
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.p = p
        self.t = t
        self.r = r

        # `__init__` must return None, so the model is stored on the instance
        # instead of being returned.
        self.model = self.build_model()

    def build_model(self):
        '''
        Assemble the full network from the stored hyper-parameters.

        Returns: a new, uncompiled keras `Model`.
        '''
        # Initialize a Keras Tensor of input_shape
        input_data = Input(shape=self.input_shape)

        # Initial Layers before the first Attention Module
        net = self.convolution_layer(input_data)
        net = self.max_pool_layer(net)

        # Residual Unit then Attention Module, stages #1 .. #N
        for _ in range(self.N_ATTENTION_MODULES):
            net = self.residual_unit(net)
            net = self.attention_module(net, self.p, self.t, self.r)

        # Ending it all
        for _ in range(self.N_FINAL_RESIDUAL_UNITS):
            net = self.residual_unit(net)

        net = self.avg_pool_layer(net)
        net = Flatten()(net)

        output = Dense(self.n_classes, activation='softmax')(net)

        return Model(inputs=input_data, outputs=output)

    def convolution_layer(self, input_x, filters=32, kernel_size=(5, 5), strides=(1, 1)):
        '''Conv2D (padding='same') -> BatchNormalization -> ReLU.'''
        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(input_x)

        batch_op = BatchNormalization()(conv_op)

        return Activation('relu')(batch_op)

    def max_pool_layer(self, input_data, pool_size=(2, 2), strides=(2, 2)):
        '''MaxPool2D with padding='same'.'''
        return MaxPool2D(pool_size=pool_size,
                         strides=strides,
                         padding='same')(input_data)

    def avg_pool_layer(self, input_data, pool_size=(2, 2), strides=(2, 2)):
        '''AveragePooling2D with padding='same'.'''
        return AveragePooling2D(pool_size=pool_size,
                                strides=strides,
                                padding='same')(input_data)

    def upsampling_layer(self, input_data, size=(2, 2), interpolation='bilinear'):
        '''UpSampling2D; bilinear interpolation by default.'''
        return UpSampling2D(size=size,
                            interpolation=interpolation)(input_data)

    def residual_unit(self, input_x):
        '''
        Two convolution layers (32 then 64 filters) followed by a connecting
        layer that adds the unit's input back in.

        The connecting layer emits 32 channels, so `input_x` must also have
        32 channels for the skip addition to be shape-compatible.
        '''
        # Layer 1
        res_conv_1 = self.convolution_layer(input_x, filters=32)

        # Layer 2
        res_conv_2 = self.convolution_layer(res_conv_1, filters=64)

        # Connecting Layer: input_x itself is the skip connection
        return self.connecting_residual_layer(input_x=res_conv_2, skipped_x=input_x)

    def connecting_residual_layer(self, input_x, skipped_x, filters=32, kernel_size=(5, 5), strides=(1, 1)):
        '''Conv2D -> BatchNormalization -> Add(skipped_x) -> ReLU.'''
        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(input_x)

        batch_op = BatchNormalization()(conv_op)

        # Combine processed_x with the skip connection. Merge layers are
        # *called* on a list of tensors; they cannot be subscripted.
        add_op = Add()([batch_op, skipped_x])

        return Activation('relu')(add_op)

    def attention_module(self, input_x, p, t, r):
        '''
        p residual units -> (trunk branch, soft mask branch) -> attention
        residual learning -> p residual units.
        '''
        # Send input_x through #p residual_units
        pre_branch = input_x
        for _ in range(p):
            pre_branch = self.residual_unit(pre_branch)

        # Perform Trunk Branch Operation
        trunk_branch_op = self.trunk_branch(trunk_input_x=pre_branch, t=t)

        # Perform Mask Branch Operation
        mask_branch_op = self.mask_branch(mask_input_x=pre_branch, r=r)

        # Perform Attention Residual Learning: Combine Trunk and Mask branch results
        post_branch = self.attention_residual_learning(mask_input=mask_branch_op,
                                                       trunk_input=trunk_branch_op)

        # Send branch results through #p residual_units
        for _ in range(p):
            post_branch = self.residual_unit(post_branch)

        return post_branch

    def trunk_branch(self, trunk_input_x, t):
        '''Apply t residual units in sequence; t=0 returns the input as is.'''
        t_res_unit_op = trunk_input_x
        for _ in range(t):
            t_res_unit_op = self.residual_unit(t_res_unit_op)

        return t_res_unit_op

    def mask_branch(self, mask_input_x, r, m=3):
        '''
        Soft mask branch: downsample, process, upsample, then squash with a
        sigmoid.

        r = num of residual units between adjacent pooling layers
        m = num max pooling / linear interpolations to do

        In total there are m + 1 stride-2 poolings and m + 1 x2 upsamplings.
        Pooling with padding='same' rounds odd sizes up, so the output only
        has the same height/width as the input (which the later Multiply with
        the trunk branch requires) when both are divisible by 2 ** (m + 1).
        '''
        # Downsampling Step
        downsampling = mask_input_x

        for _ in range(m):
            downsampling = self.max_pool_layer(input_data=downsampling)

            # Perform residual units ops r times between adjacent pooling layers
            for _ in range(r):
                downsampling = self.residual_unit(input_x=downsampling)

        # Last pooling step before middle step
        downsampling = self.max_pool_layer(input_data=downsampling)

        # Perform 2*r residual units steps before upsampling
        middleware = downsampling
        for _ in range(2 * r):
            middleware = self.residual_unit(input_x=middleware)

        # Upsampling Step
        upsampling = middleware

        for _ in range(m):
            upsampling = self.upsampling_layer(input_data=upsampling)

            # Perform residual units ops r times between adjacent interpolations
            for _ in range(r):
                upsampling = self.residual_unit(input_x=upsampling)

        # Last interpolation step
        upsampling = self.upsampling_layer(input_data=upsampling)

        conv1 = self.convolution_layer(input_x=upsampling, kernel_size=(1, 1))
        conv2 = self.convolution_layer(input_x=conv1, kernel_size=(1, 1))

        # NOTE(review): convolution_layer ends in a ReLU, so the sigmoid input
        # is never negative and the mask stays in [0.5, 1). Confirm that this
        # is intended before relying on the mask to suppress features.
        return Activation('sigmoid')(conv2)

    def attention_residual_learning(self, mask_input, trunk_input):
        '''
        Combine the branches as (1 + M(x)) * T(x).

        Keras merge layers only accept tensors, so the constant 1 is folded
        in algebraically: (1 + M) * T == T + M * T.
        '''
        masked_trunk = Multiply()([mask_input, trunk_input])  # M(x) * T(x)
        return Add()([trunk_input, masked_trunk])  # T(x) + M(x) * T(x)

