In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras.backend as K

Using TensorFlow backend.


In [2]:
# Load MNIST and scale pixels to [0, 1]. Casting to float32 (Keras' default
# floatx) before dividing avoids the float64 arrays that `uint8 / 255` would
# produce, halving the memory held by the training set.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
# The model consumes NHWC input; add a trailing channel axis when the loader
# returns (N, H, W) arrays.
if X_train.ndim == 3:
    X_train = X_train[..., np.newaxis]
    X_test = X_test[..., np.newaxis]
assert X_train.ndim == 4 and X_train.shape[1:] == X_test.shape[1:]

In [3]:
class FrequencyMultiplicative(keras.layers.Layer):
    """Learned filtering in the 2-D DCT domain.

    Pipeline for an input of shape (batch, H, W, C):
      1. If `freq_filters` is set, mix channels per pixel with a spatially
         varying (H, W, freq_filters, C) kernel, giving C' = freq_filters
         channels (otherwise C' = C).
      2. Apply an orthonormal DCT along W and then along H.
      3. Multiply element-wise by a learned (filters, C', W, H) kernel and
         sum over the C' input channels.
      4. Apply the inverse DCTs in reverse order and add a per-filter bias.

    Args:
        filters: number of output channels.
        freq_filters: optional number of channels produced by the per-pixel
            mixing in step 1; None skips that step.

    Output shape: (batch, H, W, filters). H, W and C must be statically
    known because the kernels are sized from them.
    """

    def __init__(self, filters, freq_filters=None, **kwargs):
        self.filters = filters
        self.freq_filters = freq_filters
        super(FrequencyMultiplicative, self).__init__(**kwargs)

    @staticmethod
    def _signal():
        """Return the TensorFlow module providing `dct` / `idct`.

        `tf.spectral` was superseded by `tf.signal` (TF >= 1.13) and removed
        in TF 2; prefer the new module and fall back for older TensorFlow.
        """
        signal = getattr(tf, 'signal', None)
        if signal is not None and hasattr(signal, 'dct'):
            return signal
        return tf.spectral

    def build(self, input_shape):
        height, width, channels = input_shape[1], input_shape[2], input_shape[3]
        if height is None or width is None or channels is None:
            raise ValueError(
                'FrequencyMultiplicative requires fully defined spatial and '
                'channel dimensions, got input shape %s' % (input_shape,))
        if self.freq_filters is not None:
            # Per-pixel channel-mixing weights: (1, H, W, freq_filters, C).
            self.kernel_freq = self.add_weight(
                shape=(1, height, width, self.freq_filters, channels),
                initializer='he_uniform', name='kernel_freq')
        mixed_channels = channels if self.freq_filters is None else self.freq_filters
        # Laid out as (1, filters, C', W, H) to match the axis order of the
        # activations in `call` after both forward DCTs.
        self.kernel = self.add_weight(
            shape=(1, self.filters, mixed_channels, width, height),
            initializer='he_uniform', name='kernel')
        self.bias = self.add_weight(
            shape=(self.filters,),
            initializer='zeros', name='bias')
        super(FrequencyMultiplicative, self).build(input_shape)

    def call(self, inputs):
        signal = self._signal()
        x = inputs  # (batch, H, W, C)
        if self.freq_filters is not None:
            # (batch, H, W, 1, C) * (1, H, W, F, C), summed over C.
            x = K.expand_dims(x, axis=-2)
            x = x * self.kernel_freq
            x = K.sum(x, axis=-1, keepdims=False)  # (batch, H, W, F)
        # Forward 2-D DCT: along W first, then along H after swapping axes.
        x = K.permute_dimensions(x, (0, 3, 1, 2))  # (batch, C', H, W)
        x = signal.dct(x, norm='ortho')
        x = K.permute_dimensions(x, (0, 1, 3, 2))  # (batch, C', W, H)
        x = signal.dct(x, norm='ortho')
        # Per-coefficient weighting, summed over input channels.
        x = K.expand_dims(x, axis=1)  # (batch, 1, C', W, H)
        x = x * self.kernel  # (batch, filters, C', W, H)
        x = K.sum(x, axis=2, keepdims=False)  # (batch, filters, W, H)
        # Inverse 2-D DCT, undoing the forward transforms in reverse order.
        x = signal.idct(x, norm='ortho')
        x = K.permute_dimensions(x, (0, 1, 3, 2))  # (batch, filters, H, W)
        x = signal.idct(x, norm='ortho')
        x = K.permute_dimensions(x, (0, 2, 3, 1))  # (batch, H, W, filters)
        return x + self.bias

    def compute_output_shape(self, input_shape):
        # tuple() so the concatenation also works if a list is passed in.
        return tuple(input_shape[:-1]) + (self.filters,)

    def get_config(self):
        # Records the constructor arguments so models using this layer can be
        # saved and reloaded via `custom_objects`.
        config = super(FrequencyMultiplicative, self).get_config()
        config.update({'filters': self.filters,
                       'freq_filters': self.freq_filters})
        return config

In [4]:
class AttentionalPooling(keras.layers.Layer):
    """2x2, stride-2 pooling with learned softmax weights.

    Called on `[data, att]`, two tensors of identical shape (batch, H, W, C).
    For every non-overlapping 2x2 window and every channel, the four `att`
    values are softmax-normalised and used to take a weighted sum of the
    four matching `data` values.

    Output shape: (batch, H/2, W/2, C). H and W must be statically known
    and even.
    """

    def __init__(self, **kwargs):
        super(AttentionalPooling, self).__init__(**kwargs)

    def build(self, input_shape):
        data_shape, att_shape = input_shape
        if data_shape[-1] != att_shape[-1]:
            raise ValueError('channel count of data and attention required to be equal')
        if tuple(data_shape[1:3]) != tuple(att_shape[1:3]):
            raise ValueError(
                'spatial dimensions of data and attention required to be equal, '
                'got %s and %s' % (tuple(data_shape[1:3]), tuple(att_shape[1:3])))
        for dim in data_shape[1:3]:
            # 2x2 windows must tile the input exactly.
            if dim is None or dim % 2:
                raise ValueError(
                    'AttentionalPooling requires static, even spatial dimensions, '
                    'got input shape %s' % (tuple(data_shape),))
        super(AttentionalPooling, self).build(input_shape)

    @staticmethod
    def _to_windows(x):
        """Reshape (batch, H, W, C) to (batch, H/2, W/2, 4, C).

        The new axis of size 4 holds the elements of each non-overlapping
        2x2 spatial window.
        """
        _, height, width, channels = K.int_shape(x)
        x = K.reshape(x, (-1, height // 2, 2, width // 2, 2, channels))
        x = K.permute_dimensions(x, (0, 1, 3, 2, 4, 5))
        return K.reshape(x, (-1, height // 2, width // 2, 4, channels))

    def call(self, inputs):
        data, att = inputs
        data = self._to_windows(data)
        # Normalise across the four positions of each window, per channel.
        att = K.softmax(self._to_windows(att), axis=-2)
        return K.sum(data * att, axis=-2, keepdims=False)

    def compute_output_shape(self, input_shape):
        data_shape, _ = input_shape
        return (data_shape[0], data_shape[1] // 2, data_shape[2] // 2, data_shape[3])

In [5]:
# Network: three stages of parallel frequency-domain (FrequencyMultiplicative)
# and spatial (Conv2D) branches. The first two stages end in a learned 2x2
# AttentionalPooling (28x28 -> 14x14 -> 7x7); the last feeds the classifier.
FM_FILTERS = 6        # output channels of FrequencyMultiplicative(FM_FILTERS)
FM_MIXED_FILTERS = 2  # output channels of the freq_filters=4 variant
CONV5_FILTERS = 7     # output channels of the 5x5 convolution branch
CONV1_FILTERS = 1     # output channels of the 1x1 convolution branch
# AttentionalPooling needs one attention channel per concatenated channel.
STAGE_CHANNELS = FM_FILTERS + FM_MIXED_FILTERS + CONV5_FILTERS + CONV1_FILTERS


def mixed_stage(x, pool):
    """Concatenate the parallel branches, then BatchNorm + ReLU.

    Args:
        x: input tensor of the stage.
        pool: if True, halve the spatial size with AttentionalPooling, using
            attention logits computed from `x` by a small conv stack.
    """
    branches = [
        FrequencyMultiplicative(FM_FILTERS)(x),
        FrequencyMultiplicative(FM_MIXED_FILTERS, freq_filters=4)(x),
        keras.layers.Conv2D(CONV5_FILTERS, (5, 5), padding='same',
                            kernel_initializer='he_uniform')(x),
        keras.layers.Conv2D(CONV1_FILTERS, (1, 1), padding='same',
                            kernel_initializer='he_uniform')(x),
    ]
    if pool:
        att = keras.layers.Conv2D(4, (7, 7), padding='same',
                                  kernel_initializer='he_uniform',
                                  activation='relu')(x)
        att = keras.layers.Conv2D(STAGE_CHANNELS, (1, 1), padding='same',
                                  kernel_initializer='he_uniform')(att)
    y = keras.layers.Concatenate()(branches)
    y = keras.layers.BatchNormalization()(y)
    y = keras.layers.Activation('relu')(y)
    if pool:
        y = AttentionalPooling()([y, att])
    return y


X = X_input = keras.layers.Input(X_train.shape[1:])
X = keras.layers.BatchNormalization()(X)
X = mixed_stage(X, pool=True)
X = mixed_stage(X, pool=True)
X = mixed_stage(X, pool=False)
X = keras.layers.Flatten()(X)
X = keras.layers.Dense(10, activation='softmax')(X)
M = keras.Model(X_input, X)
M.compile(optimizer='nadam', loss='sparse_categorical_crossentropy', metrics=['acc'])
M.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 28, 28, 1)    4           input_1[0][0]                    
__________________________________________________________________________________________________
frequency_multiplicative_1 (Fre (None, 28, 28, 6)    4710        batch_normalization_1[0][0]      
__________________________________________________________________________________________________
frequency_multiplicative_2 (Fre (None, 28, 28, 2)    9410        batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_1 (

In [6]:
# Light geometric augmentation for the training stream. No featurewise_* or
# zca_whitening option is enabled, so per the Keras docs `fit` is not needed
# by this generator; it is kept in case those options are switched on.
augmentation = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
)
datagen = keras.preprocessing.image.ImageDataGenerator(**augmentation)
datagen.fit(X_train)

In [7]:
# Train on the augmented stream.
# NOTE(review): the test set is also the validation set, so
# ReduceLROnPlateau (which monitors val_loss by default) adapts the learning
# rate to test-set loss. Hold out part of the training set for an unbiased
# final test score.
BATCH_SIZE = 64
# Integer ceil-division: every sample is seen once per epoch, and the value
# is an int on both Python 2 and 3 (`len / 64` was a float on Python 3).
steps_per_epoch = (len(X_train) + BATCH_SIZE - 1) // BATCH_SIZE
history = M.fit_generator(
    datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE),
    validation_data=(X_test, Y_test),
    steps_per_epoch=steps_per_epoch,
    epochs=30,
    callbacks=[keras.callbacks.ReduceLROnPlateau(patience=3, verbose=1)],
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30

Epoch 00009: ReduceLROnPlateau reducing learning rate to 0.00020000000949949026.
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30

Epoch 00017: ReduceLROnPlateau reducing learning rate to 2.0000000949949027e-05.
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30

Epoch 00022: ReduceLROnPlateau reducing learning rate to 2.0000001313746906e-06.
Epoch 23/30
Epoch 24/30
Epoch 25/30

Epoch 00025: ReduceLROnPlateau reducing learning rate to 2.000000222324161e-07.
Epoch 26/30
Epoch 27/30
Epoch 28/30

Epoch 00028: ReduceLROnPlateau reducing learning rate to 2.000000165480742e-08.
Epoch 29/30
Epoch 30/30


<keras.callbacks.History at 0x7f5f404faf28>