# Residual Attention Network

- A stack of Attention Modules
- Attention Modules have 2 branches
    - Trunk Branch
    - Soft Mask Branch

In [1]:
from tensorflow.keras.layers import Input, Conv2D, Lambda, MaxPool2D, UpSampling2D, AveragePooling2D
from tensorflow.keras.layers import Activation, Flatten, Dense, Add, Multiply, BatchNormalization

from tensorflow.keras.models import Model
import keras

import os

Using TensorFlow backend.


# Residual Attention Network Model

In [2]:
# Todo: Make scalable/all-encompassing
class ResidualAttentionNetwork():

    def __init__(self, input_shape, n_classes, p=1, t=2, r=1):
        '''
        Params:
        - num attention modules
        - p
        - r
        - t


        Conv Layer
        Max Pooling Layer

        Residual Unit
        Attention Module

        Residual Unit
        Attention Module

        Residual Unit
        Attention Module

        Residual Unit

        Average Pooling

        Flatten

        Dense Layer(s)
        Output Dense Layer (Num. Classer, activation='softmax'))
        '''

        # Initialize a Keras Tensor of input_shape
        input_data = Input(shape=input_shape)
        
        # Initial Layers before Attention Module
        conv_layer_1 = self.convolution_layer(conv_input_data=input_data)
        
        max_pool_layer_1 = self.max_pool_layer(conv_layer_1)

        # Residual Unit then Attention Module #1
        res_unit_1 = self.residual_unit(max_pool_layer_1)
        
        att_mod_1 = self.attention_module(res_unit_1, p, t, r)
        
        # Residual Unit then Attention Module #2
        res_unit_2 = self.residual_unit(att_mod_1)
        att_mod_2 = self.attention_module(res_unit_2, p, t, r)

        # Residual Unit then Attention Module #3
        res_unit_3 = self.residual_unit(att_mod_2)
        att_mod_3 = self.attention_module(res_unit_3, p, t, r)

        # Ending it all
        res_unit_end_1 = self.residual_unit(att_mod_3)
        res_unit_end_2 = self.residual_unit(res_unit_end_1)
        res_unit_end_3 = self.residual_unit(res_unit_end_2)
        res_unit_end_4 = self.residual_unit(res_unit_end_3)

        # Avg Pooling
        avg_pool_layer = self.avg_pool_layer(res_unit_end_4)

        # Flatten the data
        flatten_op = Flatten()(avg_pool_layer)

        # FC Layer for prediction
        fully_connected_layers = Dense(n_classes, activation='softmax')(flatten_op)

        # Fully constructed model
        self.model = Model(inputs=input_data, outputs=fully_connected_layers)

    def convolution_layer(self, conv_input_data, filters=32, kernel_size=(5, 5), strides=(1, 1)):

        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(conv_input_data)

        batch_op = BatchNormalization()(conv_op)

        activation_op = Activation('relu')(batch_op)

        return activation_op

    def max_pool_layer(self, pool_input_data, pool_size=(2, 2), strides=(2, 2)):
        return MaxPool2D(pool_size=pool_size,
                         strides=strides,
                         padding='same')(pool_input_data)

    def avg_pool_layer(self, pool_input_data, pool_size=(2, 2), strides=(2, 2)):
        return AveragePooling2D(pool_size=pool_size,
                                strides=strides,
                                padding='same')(pool_input_data)

    def upsampling_layer(self, upsampling_input_data, size=(2, 2), interpolation='bilinear'):
        return UpSampling2D(size=size,
                            interpolation=interpolation)(upsampling_input_data)

    def residual_unit(self, residual_input_data):
        # Hold input_x here for later processing
        skipped_x = residual_input_data

        # Layer 1
        res_conv_1 = self.convolution_layer(conv_input_data=residual_input_data, filters=32)

        # Layer 2
        res_conv_2 = self.convolution_layer(conv_input_data=res_conv_1, filters=64)

        # Connecting Layer
        output = self.connecting_residual_layer(conn_input_data=res_conv_2, skipped_x=skipped_x)

        return output

    def connecting_residual_layer(self, conn_input_data, skipped_x, filters=32, kernel_size=(5, 5), strides=(1, 1)):
        # Connecting Layer
        conv_op = Conv2D(filters=filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same')(conn_input_data)

        batch_op = BatchNormalization()(conv_op)
        
        # Todo: 
            # Do some work if skipped_x.shape is not the same as batch_op.shape
            # Gotta do the convolution + batch_norm work on skipped x

        # Combine processed_x with skipped_x
        add_op = Add()([batch_op, skipped_x])

        activation_op = Activation('relu')(add_op)

        return activation_op

    def attention_module(self, attention_input_data, p, t, r):

        # Send input_x through #p residual_units
        p_res_unit_op_1 = attention_input_data
        for i in range(p):
            p_res_unit_op_1 = self.residual_unit(p_res_unit_op_1)

        # Perform Trunk Branch Operation
        trunk_branch_op = self.trunk_branch(trunk_input_data=p_res_unit_op_1, t=t)

        # Perform Mask Branch Operation
        mask_branch_op = self.mask_branch(mask_input_data=p_res_unit_op_1, r=r)

        # Perform Attention Residual Learning: Combine Trunk and Mask branch results
        ar_learning_op = self.attention_residual_learning(mask_input=mask_branch_op, trunk_input=trunk_branch_op)

        # Send branch results through #p residual_units
        p_res_unit_op_2 = ar_learning_op
        for _ in range(p):
            p_res_unit_op_2 = self.residual_unit(p_res_unit_op_2)

        return p_res_unit_op_2

    def trunk_branch(self, trunk_input_data, t):
        # sequence of residual units
        t_res_unit_op = trunk_input_data
        for _ in range(t):
            t_res_unit_op = self.residual_unit(t_res_unit_op)

        return t_res_unit_op

    def mask_branch(self, mask_input_data, r, m=3):
        # r = num of residual units between adjacent pooling layers
        # m = num max pooling / linear interpolations to do

        # Downsampling Step Initialization - Top
        downsampling = self.max_pool_layer(pool_input_data=mask_input_data)

        # Perform residual units ops r times between adjacent pooling layers
        for j in range(r):
            downsampling = self.residual_unit(residual_input_data=downsampling)

        # Last pooling step before middle step - Bottom
        downsampling = self.max_pool_layer(pool_input_data=downsampling)

        # Middle Residuals - Perform 2*r residual units steps before upsampling
        middleware = downsampling
        for _ in range(2 * r):
            middleware = self.residual_unit(residual_input_data=middleware)

        # Upsampling Step Initialization - Top
        upsampling = self.upsampling_layer(upsampling_input_data=middleware)

        # Perform residual units ops r times between adjacent pooling layers
        for j in range(r):
            upsampling = self.residual_unit(residual_input_data=upsampling)

        # Last interpolation step - Bottom
        upsampling = self.upsampling_layer(upsampling_input_data=upsampling)

        conv1 = self.convolution_layer(conv_input_data=upsampling, kernel_size=(1, 1))
        conv2 = self.convolution_layer(conv_input_data=conv1, kernel_size=(1, 1))

        sigmoid = Activation('sigmoid')(conv2)

        return sigmoid

    def attention_residual_learning(self, mask_input, trunk_input):
        # https://stackoverflow.com/a/53361303/9221241
        m = Lambda(lambda x: 1 + x)(mask_input) # 1 + mask
        
        # https://www.tensorflow.org/api_docs/python/tf/pad
        # https://stackoverflow.com/questions/43928642/how-does-tensorflow-pad-work
        # https://stackoverflow.com/questions/34141430/tensorflow-tensor-reshape-and-pad-with-zeros
        # if m.shape != trunk_input.shape:
        #    print(max(m.shape[1], trunk_input.shape[1]),max(m.shape[2],trunk_input.shape[2]))
            
        return Multiply()([m, trunk_input]) # M(x) * T(x)

# Model Execution

### https://www.kaggle.com/uysimty/keras-cnn-dog-or-cat-classification

In [3]:
import numpy as np
import pandas as pd 
from keras.preprocessing.image import ImageDataGenerator, load_img
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import random

In [4]:
print(os.listdir("./dogs-vs-cats/"))

['.DS_Store', 'sampleSubmission.csv', 'test1', 'train']


In [5]:
FAST_RUN = False
IMAGE_WIDTH=128
IMAGE_HEIGHT=128
IMAGE_SIZE=(IMAGE_WIDTH, IMAGE_HEIGHT)
IMAGE_CHANNELS=3
input_shape=(IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS)
batch_size = 32
num_classes=2

In [6]:
filenames = os.listdir("./dogs-vs-cats/train/")
categories = []
for filename in filenames:
    category = filename.split('.')[0]
    if category == 'dog':
        categories.append(1)
    else:
        categories.append(0)

df = pd.DataFrame({
    'filename': filenames,
    'category': categories
})

In [7]:
ran_model = ResidualAttentionNetwork(input_shape=input_shape, n_classes=num_classes).model

Instructions for updating:
Colocations handled automatically by placer.


In [8]:
ran_model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
# ran_model.summary(line_length=100)

In [9]:
df["category"] = df["category"].replace({0: 'cat', 1: 'dog'}) 

In [10]:
train_df, validate_df = train_test_split(df, test_size=0.20, random_state=42)
train_df = train_df.reset_index(drop=True)
validate_df = validate_df.reset_index(drop=True)

In [11]:
total_train = train_df.shape[0]
total_validate = validate_df.shape[0]

In [12]:
train_datagen = ImageDataGenerator(
    rotation_range=15,
    rescale=1./255,
    shear_range=0.1,
    zoom_range=0.2,
    horizontal_flip=True,
    width_shift_range=0.1,
    height_shift_range=0.1
)

train_generator = train_datagen.flow_from_dataframe(
    train_df, 
    "./dogs-vs-cats/train/", 
    x_col='filename',
    y_col='category',
    target_size=IMAGE_SIZE,
    class_mode='categorical',
    batch_size=batch_size
)

Found 20000 images belonging to 2 classes.


In [13]:
validation_datagen = ImageDataGenerator(rescale=1./255)
validation_generator = validation_datagen.flow_from_dataframe(
    validate_df, 
    "./dogs-vs-cats/train/", 
    x_col='filename',
    y_col='category',
    target_size=IMAGE_SIZE,
    class_mode='categorical',
    batch_size=batch_size
)

Found 5000 images belonging to 2 classes.


In [None]:
epochs=3 if FAST_RUN else 50
history = ran_model.fit_generator(
    train_generator, 
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=total_validate//batch_size,
    steps_per_epoch=total_train//batch_size,    
    workers=6,
    use_multiprocessing=True
)

Instructions for updating:
Use tf.cast instead.
Epoch 1/50
  4/625 [..............................] - ETA: 3:02:33 - loss: 6.2804 - acc: 0.5156

Process ForkPoolWorker-2:
Process ForkPoolWorker-1:
Process ForkPoolWorker-4:
Process ForkPoolWorker-3:
Process ForkPoolWorker-5:
Process ForkPoolWorker-6:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/

  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__

  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/pool.py", line 108, in worker
    task = get()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in ge

Traceback (most recent call last):
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiproces

  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
KeyboardInterrupt
KeyboardInterrupt
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt
KeyboardInterrupt
Process ForkPoolWorker-49:
Process ForkPoolWorker-46:
Process ForkPoolWorker-47:
Process ForkPoolWorker-48:
Process ForkPoolWorker-45:
Process ForkPoolWorker-44:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/anaconda3/envs/MachineLearning/lib/python3.6/mul

  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/anaconda3/envs/MachineLearning/lib/python3.6/multiproce

# Visualize Data 

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12))
ax1.plot(history.history['loss'], color='b', label="Training loss")
ax1.plot(history.history['val_loss'], color='r', label="validation loss")
ax1.set_xticks(np.arange(1, epochs, 1))
ax1.set_yticks(np.arange(0, 1, 0.1))

ax2.plot(history.history['acc'], color='b', label="Training accuracy")
ax2.plot(history.history['val_acc'], color='r',label="Validation accuracy")
ax2.set_xticks(np.arange(1, epochs, 1))

legend = plt.legend(loc='best', shadow=True)
plt.tight_layout()
plt.show()