# Residual Attention Network

- A stack of Attention Modules
- Attention Modules have 2 branches
    - Trunk Branch
    - Soft Mask Branch

In [31]:
from tensorflow.keras.layers import Input, Conv2D, Lambda, MaxPool2D, UpSampling2D, Activation, AveragePooling2D, Flatten, Dense, Add, Multiply, BatchNormalization
from tensorflow.keras.models import Model
from keras.datasets import mnist
import keras

- https://keras.io/layers/merge/
- https://keras.io/layers/convolutional/
- https://keras.io/activations/

In [96]:
# Todo: Make scalable/all-encompassing
class ResidualAttentionNetwork():

    def __init__(self, input_shape, n_classes, p=1, t=2, r=1):
        '''
        Params:
        - num attention modules
        - p
        - r
        - t


        Conv Layer
        Max Pooling Layer

        Residual Unit
        Attention Module

        Residual Unit
        Attention Module

        Residual Unit
        Attention Module

        Residual Unit

        Average Pooling

        Flatten

        Dense Layer(s)
        Output Dense Layer (Num. Classer, activation='softmax'))
        '''

        # Initialize a Keras Tensor of input_shape
        input_data = Input(shape=input_shape)
        
        print("Inp:",input_data.shape)
        # Initial Layers before Attention Module
        conv_layer_1 = self.convolution_layer(conv_input_data=input_data)
        print("CV1:",conv_layer_1.shape)
        
        max_pool_layer_1 = self.max_pool_layer(conv_layer_1)
        print("MP1:",max_pool_layer_1.shape)

        # Residual Unit then Attention Module #1
        res_unit_1 = self.residual_unit(max_pool_layer_1)
        print("R1", res_unit_1.shape)
        
        att_mod_1 = self.attention_module(res_unit_1, p, t, r)
        print("A1:", att_mod_1.shape)
        
        # Residual Unit then Attention Module #2
        res_unit_2 = self.residual_unit(att_mod_1)
        att_mod_2 = self.attention_module(res_unit_2, p, t, r)

        # Residual Unit then Attention Module #3
        res_unit_3 = self.residual_unit(att_mod_2)
        att_mod_3 = self.attention_module(res_unit_3, p, t, r)

        # Ending it all
        res_unit_end_1 = self.residual_unit(att_mod_3)
        res_unit_end_2 = self.residual_unit(res_unit_end_1)
        res_unit_end_3 = self.residual_unit(res_unit_end_2)
        res_unit_end_4 = self.residual_unit(res_unit_end_3)

        # Avg Pooling
        avg_pool_layer = self.avg_pool_layer(res_unit_end_4)

        # Flatten the data
        flatten_op = Flatten()(avg_pool_layer)

        # FC Layer for prediction
        fully_connected_layers = Dense(n_classes, activation='softmax')(flatten_op)

        # Fully constructed model
        self.model = Model(inputs=input_data, outputs=fully_connected_layers)

    def convolution_layer(self, conv_input_data, filters=32, kernel_size=(5, 5), strides=(1, 1)):

        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(conv_input_data)

        batch_op = BatchNormalization()(conv_op)

        activation_op = Activation('relu')(batch_op)

        return activation_op

    def max_pool_layer(self, pool_input_data, pool_size=(2, 2), strides=(2, 2)):
        return MaxPool2D(pool_size=pool_size,
                         strides=strides,
                         padding='same')(pool_input_data)

    def avg_pool_layer(self, pool_input_data, pool_size=(2, 2), strides=(2, 2)):
        return AveragePooling2D(pool_size=pool_size,
                                strides=strides,
                                padding='same')(pool_input_data)

    def upsampling_layer(self, upsampling_input_data, size=(2, 2), interpolation='bilinear'):
        return UpSampling2D(size=size,
                            interpolation=interpolation)(upsampling_input_data)

    def residual_unit(self, residual_input_data):
        # Hold input_x here for later processing
        skipped_x = residual_input_data

        # Layer 1
        res_conv_1 = self.convolution_layer(conv_input_data=residual_input_data, filters=32)

        # Layer 2
        res_conv_2 = self.convolution_layer(conv_input_data=res_conv_1, filters=64)

        # Connecting Layer
        output = self.connecting_residual_layer(conn_input_data=res_conv_2, skipped_x=skipped_x)

        return output

    def connecting_residual_layer(self, conn_input_data, skipped_x, filters=32, kernel_size=(5, 5), strides=(1, 1)):
        # Connecting Layer
        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(conn_input_data)

        batch_op = BatchNormalization()(conv_op)
        
        # Todo: 
            # Do some work if skipped_x.shape is not the same as batch_op.shape
            # Gotta do the convolution + batch_norm work on skipped x

        # Combine processed_x with skipped_x
        add_op = Add()([batch_op, skipped_x])

        activation_op = Activation('relu')(add_op)

        return activation_op

    def attention_module(self, attention_input_data, p, t, r):

        # Send input_x through #p residual_units
        p_res_unit_op_1 = attention_input_data
        for i in range(p):
            print("AM_P{}".format(i), p_res_unit_op_1.shape)
            p_res_unit_op_1 = self.residual_unit(p_res_unit_op_1)

        # Perform Trunk Branch Operation
        trunk_branch_op = self.trunk_branch(trunk_input_data=p_res_unit_op_1, t=t)
        print("Tr:",trunk_branch_op.shape)

        # Perform Mask Branch Operation
        mask_branch_op = self.mask_branch(mask_input_data=p_res_unit_op_1, r=r)
        print("Mask:",mask_branch_op.shape)

        # Perform Attention Residual Learning: Combine Trunk and Mask branch results
        ar_learning_op = self.attention_residual_learning(mask_input=mask_branch_op, trunk_input=trunk_branch_op)
        print("ARL:", ar_learning_op.shape)
        # Send branch results through #p residual_units
        p_res_unit_op_2 = ar_learning_op
        for _ in range(p):
            p_res_unit_op_2 = self.residual_unit(p_res_unit_op_2)

        return p_res_unit_op_2

    def trunk_branch(self, trunk_input_data, t):
        # sequence of residual units
        t_res_unit_op = trunk_input_data
        for _ in range(t):
            t_res_unit_op = self.residual_unit(t_res_unit_op)

        return t_res_unit_op

    def mask_branch(self, mask_input_data, r, m=3):
        # r = num of residual units between adjacent pooling layers
        # m = num max pooling / linear interpolations to do

        # Downsampling Step
        downsampling = mask_input_data

        for i in range(m):
            downsampling = self.max_pool_layer(pool_input_data=downsampling)
            print("Mask_Down{}".format(i), downsampling.shape)

            # Perform residual units ops r times between adjacent pooling layers
            for j in range(r):
                downsampling = self.residual_unit(residual_input_data=downsampling)
                print("Mask_Down_Residual{}".format(j), downsampling.shape)

        # Last pooling step before middle step
        downsampling = self.max_pool_layer(pool_input_data=downsampling)
        print("Mask_Down_last", downsampling.shape)

        # Perform 2*r residual units steps before upsampling
        middleware = downsampling
        for _ in range(2 * r):
            middleware = self.residual_unit(residual_input_data=middleware)
            print("Mask_middle", middleware.shape)

        # Upsampling Step
        upsampling = middleware

        for i in range(m):
            upsampling = self.upsampling_layer(upsampling_input_data=upsampling)
            print("Mask_Up{}".format(i), upsampling.shape)

            # Perform residual units ops r times between adjacent pooling layers
            for j in range(r):
                upsampling = self.residual_unit(residual_input_data=upsampling)
                print("Mask_Up_Residual{}".format(j), upsampling.shape)

        # Last interpolation step
        upsampling = self.upsampling_layer(upsampling_input_data=upsampling)
        print("Mask_up_last", upsampling.shape)

        conv1 = self.convolution_layer(conv_input_data=upsampling, kernel_size=(1, 1))
        print("Mask_conv1", conv1.shape)
        
        conv2 = self.convolution_layer(conv_input_data=conv1, kernel_size=(1, 1))
        print("Mask_conv2", conv2.shape)

        sigmoid = Activation('sigmoid')(conv2)
        print("sigmoid", sigmoid.shape)

        return sigmoid

    def attention_residual_learning(self, mask_input, trunk_input):
        # https://stackoverflow.com/a/53361303/9221241
        m = res5 = Lambda(lambda x: 1 + x)(mask_input) # 1 + mask
        return Multiply()([m, trunk_input]) # M(x) * T(x)

In [97]:
ran_model = ResidualAttentionNetwork(input_shape=input_shape, n_classes=num_classes).model

Inp: (?, 28, 28, 1)
CV1: (?, 28, 28, 32)
MP1: (?, 28, 28, 32)
R1 (?, 28, 28, 32)
AM_P0 (?, 28, 28, 32)
Tr: (?, 28, 28, 32)
Mask_Down0 (?, 14, 14, 32)
Mask_Down_Residual0 (?, 14, 14, 32)
Mask_Down1 (?, 7, 7, 32)
Mask_Down_Residual0 (?, 7, 7, 32)
Mask_Down2 (?, 4, 4, 32)
Mask_Down_Residual0 (?, 4, 4, 32)
Mask_Down_last (?, 4, 4, 32)
Mask_middle (?, 4, 4, 32)
Mask_middle (?, 4, 4, 32)
Mask_Up0 (?, 4, 4, 32)
Mask_Up_Residual0 (?, 4, 4, 32)
Mask_Up1 (?, 4, 4, 32)
Mask_Up_Residual0 (?, 4, 4, 32)
Mask_Up2 (?, 4, 4, 32)
Mask_Up_Residual0 (?, 4, 4, 32)
Mask_up_last (?, 4, 4, 32)
Mask_conv1 (?, 4, 4, 32)
Mask_conv2 (?, 4, 4, 32)
sigmoid (?, 4, 4, 32)
Mask: (?, 4, 4, 32)


ValueError: Operands could not be broadcast together with shapes (4, 4, 32) (28, 28, 32)

In [24]:
batch_size = 128
num_classes = 10
epochs = 12

In [25]:
img_rows, img_cols = 28, 28
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

In [26]:
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

In [27]:
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)