# Residual Attention Network

- A stack of Attention Modules
- Each Attention Module has 2 branches:
    - Trunk Branch
    - Soft Mask Branch

In [6]:
from tensorflow.keras.layers import Input, Conv2D, Lambda, MaxPool2D, UpSampling2D, AveragePooling2D
from tensorflow.keras.layers import Activation, Flatten, Dense, Add, Multiply, BatchNormalization

from tensorflow.keras.models import Model
import keras

import os
from keras.datasets import cifar10

In [7]:
# TODO: Make the architecture fully configurable (filter counts, number of closing residual units, etc.)
class ResidualAttentionNetwork():
    '''Build a Residual Attention Network with the Keras functional API.

    The network is a stack of Attention Modules, each with a trunk branch and
    a soft mask branch. The finished model is stored in ``self.model``:

        Conv Layer (32 filters, 5x5, BN + ReLU)
        Max Pooling (2x2, stride 2)

        n_attention_modules x [Residual Unit -> Attention Module]

        4 x Residual Unit

        Average Pooling (2x2, stride 2)
        Flatten
        Output Dense Layer (n_classes units, activation='softmax')
    '''

    def __init__(self, input_shape, n_classes, p=1, t=2, r=1, n_attention_modules=3):
        '''Build the network and store the Keras ``Model`` in ``self.model``.

        Params:
        - input_shape: per-sample shape given to ``Input`` (channels-last
          assumed, e.g. ``(32, 32, 3)`` for CIFAR-10)
        - n_classes: number of units in the softmax output layer
        - p: residual units before and after the trunk/mask split in each
          Attention Module
        - t: residual units in each trunk branch
        - r: residual units between adjacent pooling/upsampling steps in each
          soft mask branch
        - n_attention_modules: number of [Residual Unit -> Attention Module]
          stages; the default of 3 reproduces the original network
        '''
        # Initialize a Keras Tensor of input_shape
        input_data = Input(shape=input_shape)

        # Initial layers before the first Attention Module
        x = self.convolution_layer(conv_input_data=input_data)
        x = self.max_pool_layer(x)

        # Residual Unit then Attention Module, once per stage
        for _ in range(n_attention_modules):
            x = self.residual_unit(x)
            x = self.attention_module(x, p, t, r)

        # Closing residual units
        for _ in range(4):
            x = self.residual_unit(x)

        # Average pooling, flatten, then the softmax classifier
        x = self.avg_pool_layer(x)
        x = Flatten()(x)
        predictions = Dense(n_classes, activation='softmax')(x)

        # Fully constructed model
        self.model = Model(inputs=input_data, outputs=predictions)

    def convolution_layer(self, conv_input_data, filters=32, kernel_size=(5, 5), strides=(1, 1)):
        '''Conv2D ('same' padding) -> BatchNormalization -> ReLU.'''
        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(conv_input_data)

        batch_op = BatchNormalization()(conv_op)

        activation_op = Activation('relu')(batch_op)

        return activation_op

    def max_pool_layer(self, pool_input_data, pool_size=(2, 2), strides=(2, 2)):
        '''Max pooling with 'same' padding (odd spatial sizes are rounded up).'''
        return MaxPool2D(pool_size=pool_size,
                         strides=strides,
                         padding='same')(pool_input_data)

    def avg_pool_layer(self, pool_input_data, pool_size=(2, 2), strides=(2, 2)):
        '''Average pooling with 'same' padding.'''
        return AveragePooling2D(pool_size=pool_size,
                                strides=strides,
                                padding='same')(pool_input_data)

    def upsampling_layer(self, upsampling_input_data, size=(2, 2), interpolation='bilinear'):
        '''Upsample by ``size`` using ``interpolation`` (bilinear by default).'''
        return UpSampling2D(size=size,
                            interpolation=interpolation)(upsampling_input_data)

    def residual_unit(self, residual_input_data):
        '''Two conv/BN/ReLU layers (32 then 64 filters) plus a shortcut.

        The connecting layer brings the result back to 32 channels, adds the
        (possibly projected) input and applies ReLU, so the output always has
        32 channels and the input's spatial size.
        '''
        # Hold the input here for the shortcut connection
        skipped_x = residual_input_data

        # Layer 1
        res_conv_1 = self.convolution_layer(conv_input_data=residual_input_data, filters=32)

        # Layer 2
        res_conv_2 = self.convolution_layer(conv_input_data=res_conv_1, filters=64)

        # Connecting Layer
        output = self.connecting_residual_layer(conn_input_data=res_conv_2, skipped_x=skipped_x)

        return output

    def connecting_residual_layer(self, conn_input_data, skipped_x, filters=32, kernel_size=(5, 5), strides=(1, 1)):
        '''Final Conv2D + BN of a residual unit, add the shortcut, then ReLU.

        If the shortcut's shape differs from the processed branch (for
        example, an input whose channel count is not ``filters``), the
        shortcut is projected with a 1x1 Conv2D (same ``strides``) followed by
        BatchNormalization so that ``Add`` receives matching shapes. When the
        shapes already match, the shortcut is used unchanged.
        '''
        from tensorflow.keras import backend as K

        # Connecting Layer
        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(conn_input_data)

        batch_op = BatchNormalization()(conv_op)

        # Projection shortcut: only needed when the identity shortcut cannot
        # be added element-wise to the processed branch.
        if K.int_shape(skipped_x)[1:] != K.int_shape(batch_op)[1:]:
            skipped_x = Conv2D(filters=filters,
                               kernel_size=(1, 1),
                               strides=strides,
                               padding='same')(skipped_x)
            skipped_x = BatchNormalization()(skipped_x)

        # Combine processed_x with skipped_x
        add_op = Add()([batch_op, skipped_x])

        activation_op = Activation('relu')(add_op)

        return activation_op

    def attention_module(self, attention_input_data, p, t, r):
        '''p residual units, trunk/mask split, (1 + M) * T, p residual units.'''
        # Send the input through p residual units
        p_res_unit_op_1 = attention_input_data
        for _ in range(p):
            p_res_unit_op_1 = self.residual_unit(p_res_unit_op_1)

        # Perform Trunk Branch Operation
        trunk_branch_op = self.trunk_branch(trunk_input_data=p_res_unit_op_1, t=t)

        # Perform Mask Branch Operation
        mask_branch_op = self.mask_branch(mask_input_data=p_res_unit_op_1, r=r)

        # Perform Attention Residual Learning: Combine Trunk and Mask branch results
        ar_learning_op = self.attention_residual_learning(mask_input=mask_branch_op, trunk_input=trunk_branch_op)

        # Send branch results through p residual units
        p_res_unit_op_2 = ar_learning_op
        for _ in range(p):
            p_res_unit_op_2 = self.residual_unit(p_res_unit_op_2)

        return p_res_unit_op_2

    def trunk_branch(self, trunk_input_data, t):
        '''Trunk branch: a sequence of ``t`` residual units.'''
        t_res_unit_op = trunk_input_data
        for _ in range(t):
            t_res_unit_op = self.residual_unit(t_res_unit_op)

        return t_res_unit_op

    def mask_branch(self, mask_input_data, r, m=3):
        '''Soft mask branch: two max-poolings, two upsamplings, then sigmoid.

        Params:
        - mask_input_data: input tensor (the same tensor fed to the trunk)
        - r: residual units between adjacent pooling/upsampling layers; the
          bottom of the branch gets 2 * r residual units
        - m: currently unused; the branch always performs two pooling and
          two upsampling steps (kept for interface compatibility)

        Returns a 32-channel tensor with values in (0, 1). Because pooling
        uses 'same' padding, its spatial size is ``4 * ceil(S / 4)`` for an
        input of size ``S``, which may exceed ``S``.
        '''
        # Downsampling Step Initialization - Top
        downsampling = self.max_pool_layer(pool_input_data=mask_input_data)

        # Perform residual units ops r times between adjacent pooling layers
        for _ in range(r):
            downsampling = self.residual_unit(residual_input_data=downsampling)

        # Last pooling step before middle step - Bottom
        downsampling = self.max_pool_layer(pool_input_data=downsampling)

        # Middle Residuals - Perform 2*r residual units steps before upsampling
        middleware = downsampling
        for _ in range(2 * r):
            middleware = self.residual_unit(residual_input_data=middleware)

        # Upsampling Step Initialization - Top
        upsampling = self.upsampling_layer(upsampling_input_data=middleware)

        # Perform residual units ops r times between adjacent pooling layers
        for _ in range(r):
            upsampling = self.residual_unit(residual_input_data=upsampling)

        # Last interpolation step - Bottom
        upsampling = self.upsampling_layer(upsampling_input_data=upsampling)

        # Two consecutive 1x1 convolutions, then a sigmoid. The second conv has
        # no BN/ReLU: a ReLU right before the sigmoid would confine the mask
        # to [0.5, 1), so it could never suppress trunk features. 32 filters
        # match the channel count of the trunk (residual unit output).
        conv1 = self.convolution_layer(conv_input_data=upsampling, kernel_size=(1, 1))
        conv2 = Conv2D(filters=32,
                       kernel_size=(1, 1),
                       strides=(1, 1),
                       padding='same')(conv1)

        sigmoid = Activation('sigmoid')(conv2)

        return sigmoid

    def _crop_to_match(self, tensor, reference):
        '''Center-crop ``tensor`` to ``reference``'s height/width (channels-last assumed).

        Returns ``tensor`` unchanged when the spatial sizes already match or
        are not statically known. Raises ValueError if ``tensor`` is smaller
        than ``reference`` in either spatial dimension.
        '''
        from tensorflow.keras import backend as K
        from tensorflow.keras.layers import Cropping2D

        tensor_hw = K.int_shape(tensor)[1:3]
        reference_hw = K.int_shape(reference)[1:3]
        if tensor_hw == reference_hw or None in tensor_hw + reference_hw:
            return tensor

        dh = tensor_hw[0] - reference_hw[0]
        dw = tensor_hw[1] - reference_hw[1]
        if dh < 0 or dw < 0:
            raise ValueError('Cannot crop tensor of spatial size %s to larger size %s'
                             % (tensor_hw, reference_hw))

        return Cropping2D(cropping=((dh // 2, dh - dh // 2),
                                    (dw // 2, dw - dw // 2)))(tensor)

    def attention_residual_learning(self, mask_input, trunk_input):
        '''Combine the branches as ``(1 + M(x)) * T(x)``.

        When the mask's spatial size exceeds the trunk's (spatial sizes not
        divisible by 4, see ``mask_branch``), the mask is center-cropped to
        the trunk's size first. When the sizes match, nothing is cropped.
        '''
        mask_input = self._crop_to_match(mask_input, trunk_input)

        # https://stackoverflow.com/a/53361303/9221241
        m = Lambda(lambda x: 1 + x)(mask_input)  # 1 + M(x)

        return Multiply()([m, trunk_input])  # (1 + M(x)) * T(x)

In [8]:
# Training configuration, adapted from https://keras.io/examples/cifar10_cnn/
num_classes = 10
batch_size = 32
epochs = 100

# NOTE(review): the next two settings are not read by any visible cell
data_augmentation = True
num_predictions = 20

# Where a trained model would be written
model_name = 'keras_cifar10_trained_model.h5'
save_dir = os.path.join(os.getcwd(), 'saved_models')

In [9]:
# Load CIFAR-10 as separate train / test (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

print('x_train shape:', x_train.shape)
print(len(x_train), 'train samples')
print(len(x_test), 'test samples')

x_train shape: (50000, 32, 32, 3)
50000 train samples
10000 test samples


In [10]:
# One-hot encode the integer class labels (num_classes columns).
y_train, y_test = (
    keras.utils.to_categorical(y_train, num_classes),
    keras.utils.to_categorical(y_test, num_classes),
)

In [11]:
# Per-sample shape (drop the leading sample axis) for the model's Input layer.
input_shape = tuple(x_train.shape)[1:]

In [13]:
# Build the network graph and keep only the resulting Keras Model.
ran_model = ResidualAttentionNetwork(
    input_shape=input_shape,
    n_classes=num_classes,
).model

In [14]:
# Adam + categorical cross-entropy on the one-hot labels; track accuracy.
ran_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [15]:
# Cast to float32 and scale pixel intensities (assumed 0-255) into [0, 1].
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

In [None]:
# Train on the in-memory NumPy arrays, validating on the test split each epoch.
# `workers` / `use_multiprocessing` were dropped: they apply only to
# generator / keras.utils.Sequence inputs, so tf.keras ignores them for NumPy
# arrays, and Keras 3 rejects them in Model.fit.
ran_model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)

Train on 50000 samples, validate on 10000 samples
Instructions for updating:
Use tf.cast instead.
Epoch 1/100
 4736/50000 [=>............................] - ETA: 1:11:56 - loss: 13.3113 - acc: 0.1626