# Residual Attention Network
Paper: https://arxiv.org/abs/1704.06904
- The network is a stack of Attention Modules.
- Each Attention Module has two branches:
    - Trunk Branch
    - Soft Mask Branch

### Helpful Links
- https://towardsdatascience.com/residual-blocks-building-blocks-of-resnet-fd90ca15d6ec
- https://towardsdatascience.com/understanding-and-coding-a-resnet-in-keras-446d7ff84d33
- https://towardsdatascience.com/review-residual-attention-network-attention-aware-features-image-classification-7ae44c4f4b8
- https://sebastianwallkoetter.wordpress.com/2018/04/08/layered-layers-residual-blocks-in-the-sequential-keras-api/

In [None]:
from tensorflow.keras.layers import Input, Conv2D, Lambda, MaxPool2D, UpSampling2D, AveragePooling2D, ZeroPadding2D
from tensorflow.keras.layers import Activation, Flatten, Dense, Add, Multiply, BatchNormalization

from tensorflow.keras.models import Model
import keras

import os

# Residual Attention Network Model

In [None]:
# TODO: make the architecture configurable (number of attention stages, filter widths).
class ResidualAttentionNetwork:
    """Residual Attention Network (https://arxiv.org/abs/1704.06904) built with the Keras functional API.

    The constructed graph is stored on ``self.model``. Layout:

        Conv (3x3, 32 filters) -> MaxPool (2x2, stride 2)
        Residual Unit -> Attention Module          (x3)
        Residual Unit                              (x4)
        AvgPool (2x2, stride 2) -> Flatten -> Dense(n_classes, softmax)

    Every residual unit and the mask branch end with 32 channels, so in the
    default configuration all skip connections and the mask * trunk product
    line up without projections.
    """

    def __init__(self, input_shape, n_classes, p=1, t=2, r=1):
        """Build the network and store it on ``self.model``.

        Params:
        - input_shape: shape of a single sample, passed to ``Input``. Indices 0
          and 1 are treated as the spatial dims (assumes channels-last
          (H, W, C) -- TODO confirm for non-default data formats).
        - n_classes: number of units in the final softmax Dense layer.
        - p: residual units before and after the trunk/mask branches in each
          attention module.
        - t: residual units in the trunk branch.
        - r: residual units between adjacent pooling / up-sampling layers in
          the mask branch.
        """
        input_data = Input(shape=input_shape)

        # The mask branch pools twice and up-samples twice, so its output only
        # matches the trunk when the spatial size reaching the attention
        # modules survives that round trip. Padding each spatial dim to a
        # multiple of 8 (initial max pool x2, mask pools x4) guarantees this.
        # Small images are also padded up to at least 32 px, as before.
        # Each dim is handled independently, so e.g. 28x64 no longer produces
        # a negative padding.
        pad_rows = self._symmetric_padding(input_shape[0])
        pad_cols = self._symmetric_padding(input_shape[1])
        if any(pad_rows) or any(pad_cols):
            conv_input = ZeroPadding2D((pad_rows, pad_cols))(input_data)
        else:
            conv_input = input_data

        # Initial layers before the attention modules.
        conv_layer_1 = self.convolution_layer(conv_input_data=conv_input)
        max_pool_layer_1 = self.max_pool_layer(conv_layer_1)

        # Three stages of Residual Unit -> Attention Module.
        stage_output = max_pool_layer_1
        for _ in range(3):
            res_unit = self.residual_unit(stage_output)
            stage_output = self.attention_module(res_unit, p, t, r)

        # Four closing residual units.
        res_unit_end = stage_output
        for _ in range(4):
            res_unit_end = self.residual_unit(res_unit_end)

        avg_pool_layer = self.avg_pool_layer(res_unit_end)
        flatten_op = Flatten()(avg_pool_layer)

        # Output layer: one softmax unit per class.
        predictions = Dense(n_classes, activation='softmax')(flatten_op)

        self.model = Model(inputs=input_data, outputs=predictions)

    @staticmethod
    def _symmetric_padding(dim, multiple=8, min_size=32):
        """Return the ``(before, after)`` zero-padding for one spatial dim.

        The padded size is ``max(dim, min_size)`` rounded up to a multiple of
        ``multiple``. An odd amount of padding puts the extra pixel on the
        ``after`` side. Unknown (``None``) dims are left unpadded.
        """
        if dim is None:
            return (0, 0)
        target = max(dim, min_size)
        target += (-target) % multiple
        total = target - dim
        before = total // 2
        return (before, total - before)

    def convolution_layer(self, conv_input_data, filters=32, kernel_size=(3, 3), strides=(1, 1)):
        """Conv2D ('same' padding) -> BatchNormalization -> ReLU."""
        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(conv_input_data)

        batch_op = BatchNormalization()(conv_op)

        return Activation('relu')(batch_op)

    def max_pool_layer(self, pool_input_data, pool_size=(2, 2), strides=(2, 2)):
        """2x2 max pooling with stride 2 ('same' padding rounds odd sizes up)."""
        return MaxPool2D(pool_size=pool_size,
                         strides=strides,
                         padding='same')(pool_input_data)

    def avg_pool_layer(self, pool_input_data, pool_size=(2, 2), strides=(2, 2)):
        """2x2 average pooling with stride 2 ('same' padding)."""
        return AveragePooling2D(pool_size=pool_size,
                                strides=strides,
                                padding='same')(pool_input_data)

    def upsampling_layer(self, upsampling_input_data, size=(2, 2), interpolation='bilinear'):
        """Up-sample spatially by ``size`` (bilinear by default)."""
        return UpSampling2D(size=size,
                            interpolation=interpolation)(upsampling_input_data)

    def residual_unit(self, residual_input_data):
        """Residual unit: conv3x3(32) -> conv3x3(64) -> conv5x5(32)+BN, add input, ReLU."""
        # Keep the input for the skip connection.
        skipped_x = residual_input_data

        res_conv_1 = self.convolution_layer(conv_input_data=residual_input_data, filters=32)
        res_conv_2 = self.convolution_layer(conv_input_data=res_conv_1, filters=64)

        return self.connecting_residual_layer(conn_input_data=res_conv_2, skipped_x=skipped_x)

    def connecting_residual_layer(self, conn_input_data, skipped_x, filters=32, kernel_size=(5, 5), strides=(1, 1)):
        """Close a residual unit: Conv -> BN, add the shortcut, then ReLU.

        If the shortcut's shape differs from the main path's (different
        channel count, or spatial size because of ``strides``), the shortcut
        is first projected with a 1x1 Conv (same ``strides``) + BN so that the
        two tensors can be added.
        """
        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(conn_input_data)

        batch_op = BatchNormalization()(conv_op)

        # Projection shortcut; never triggered in the default 32-channel graph.
        if tuple(skipped_x.shape[1:]) != tuple(batch_op.shape[1:]):
            skipped_x = Conv2D(filters=filters,
                               kernel_size=(1, 1),
                               strides=strides,
                               padding='same')(skipped_x)
            skipped_x = BatchNormalization()(skipped_x)

        add_op = Add()([batch_op, skipped_x])

        return Activation('relu')(add_op)

    def attention_module(self, attention_input_data, p, t, r):
        """Attention module: p residual units, trunk/mask branches, (1 + M) * T, p residual units."""
        # p residual units before splitting into branches.
        p_res_unit_op_1 = attention_input_data
        for _ in range(p):
            p_res_unit_op_1 = self.residual_unit(p_res_unit_op_1)

        trunk_branch_op = self.trunk_branch(trunk_input_data=p_res_unit_op_1, t=t)
        mask_branch_op = self.mask_branch(mask_input_data=p_res_unit_op_1, r=r)

        # Attention residual learning: combine trunk and mask branch results.
        ar_learning_op = self.attention_residual_learning(mask_input=mask_branch_op, trunk_input=trunk_branch_op)

        # p residual units after merging the branches.
        p_res_unit_op_2 = ar_learning_op
        for _ in range(p):
            p_res_unit_op_2 = self.residual_unit(p_res_unit_op_2)

        return p_res_unit_op_2

    def trunk_branch(self, trunk_input_data, t):
        """Trunk branch: a sequence of ``t`` residual units."""
        t_res_unit_op = trunk_input_data
        for _ in range(t):
            t_res_unit_op = self.residual_unit(t_res_unit_op)

        return t_res_unit_op

    def mask_branch(self, mask_input_data, r, m=3):
        """Soft mask branch.

        Sequence: pool -> r residual units -> pool -> 2*r residual units ->
        up-sample -> r residual units -> up-sample -> 1x1 conv -> 1x1 conv ->
        sigmoid. The two pools shrink the spatial size by 4 and the two
        up-samplings restore it.

        - r: residual units between adjacent pooling / up-sampling layers.
        - m: NOTE(review): currently unused. The branch always performs two
          pooling and two up-sampling steps regardless of ``m``.
        """
        # Down-sampling, first level.
        downsampling = self.max_pool_layer(pool_input_data=mask_input_data)
        for _ in range(r):
            downsampling = self.residual_unit(residual_input_data=downsampling)

        # Down-sampling, second (bottom) level.
        downsampling = self.max_pool_layer(pool_input_data=downsampling)

        # 2*r residual units at the bottom before up-sampling.
        middleware = downsampling
        for _ in range(2 * r):
            middleware = self.residual_unit(residual_input_data=middleware)

        # Up-sampling, first level.
        upsampling = self.upsampling_layer(upsampling_input_data=middleware)
        for _ in range(r):
            upsampling = self.residual_unit(residual_input_data=upsampling)

        # Up-sampling back to the input resolution.
        upsampling = self.upsampling_layer(upsampling_input_data=upsampling)

        conv1 = self.convolution_layer(conv_input_data=upsampling, kernel_size=(1, 1))
        conv2 = self.convolution_layer(conv_input_data=conv1, kernel_size=(1, 1))

        # NOTE(review): convolution_layer ends with ReLU, so the sigmoid input
        # is >= 0 and the mask lies in [0.5, 1). Confirm this is intended.
        return Activation('sigmoid')(conv2)

    def attention_residual_learning(self, mask_input, trunk_input):
        """Attention residual learning: return (1 + M(x)) * T(x).

        Requires the mask and trunk tensors to have identical shapes.
        """
        # Element-wise 1 + mask via a Lambda layer (https://stackoverflow.com/a/53361303/9221241).
        m = Lambda(lambda x: 1 + x)(mask_input)

        return Multiply()([m, trunk_input])

# Model Execution

### Reference: https://keras.io/examples/mnist_cnn/

In [None]:
# Build the network for this dataset's input shape and class count; keep only the Keras model.
model = ResidualAttentionNetwork(input_shape=input_shape, n_classes=num_classes).model

In [None]:
# The model is built with tensorflow.keras. Use the string identifier so the
# loss is resolved by that same Keras package, rather than mixing in the
# separately imported standalone `keras` package.
model.compile(loss='categorical_crossentropy',
              optimizer="adam",
              metrics=['accuracy'])

In [None]:
# `Model.fit_generator` is deprecated in TF 2.x and absent from Keras 3;
# `Model.fit` accepts generators / `Sequence` objects directly.
# NOTE(review): `workers` / `use_multiprocessing` are accepted by tf.keras
# (TF 2.x) `fit`. Keras 3 moved them onto the dataset object.
history = model.fit(train_generator,
                    steps_per_epoch=STEP_SIZE_TRAIN,
                    validation_data=validation_generator,
                    validation_steps=STEP_SIZE_VALID,
                    workers=8,
                    use_multiprocessing=True,
                    epochs=epochs)

# Visualize Training History

In [None]:
# Plot loss and accuracy per epoch for the training and validation splits.
# The accuracy key depends on the Keras version: 'accuracy' / 'val_accuracy'
# for metrics=['accuracy'] in TF 2.x, 'acc' / 'val_acc' in older releases.
acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'

# Plot against 1-based epoch numbers so the tick labels match the data points.
epoch_numbers = list(range(1, len(history.history['loss']) + 1))

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12))

# Loss is not bounded by 1, so the y ticks are left to matplotlib.
ax1.plot(epoch_numbers, history.history['loss'], color='b', label="Training loss")
ax1.plot(epoch_numbers, history.history['val_loss'], color='r', label="validation loss")
ax1.set_xticks(epoch_numbers)
ax1.legend(loc='best', shadow=True)

ax2.plot(epoch_numbers, history.history[acc_key], color='b', label="Training accuracy")
ax2.plot(epoch_numbers, history.history['val_' + acc_key], color='r', label="Validation accuracy")
ax2.set_xticks(epoch_numbers)
# plt.legend() would only label the current (last) axes, so each axes gets its own legend.
ax2.legend(loc='best', shadow=True)

plt.tight_layout()
plt.show()

In [None]:
# Evaluate on the test split. With metrics=['accuracy'], index 0 is the loss and index 1 the accuracy.
score = model.evaluate(x=x_test, y=y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

In [None]:
# Echo the generator's `classes` attribute as this cell's output. Presumably
# these are the per-sample class indices if train_generator is a Keras
# directory iterator -- confirm against where it is created.
train_generator.classes

In [None]:
# Train on in-memory arrays, validating on the test split.
# NOTE(review): this continues training the same `model` (including any
# training already done by the generator cell), and its History object is
# only echoed, not stored.
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=epochs,
    batch_size=batch_size,
)