In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras.backend as K

Using TensorFlow backend.


In [2]:
# Download (or read from the local Keras cache) the MNIST train/test split.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
# Scale pixels by 1/255. Cast to float32 first: dividing the integer image
# arrays directly promotes them to float64, which doubles memory use compared
# with the float32 that Keras uses as its default float type.
X_train, X_test = X_train.astype('float32') / 255, X_test.astype('float32') / 255
# Append a trailing channel axis so the images match the 4-D input the model
# expects. The rank check keeps this step a no-op if the arrays are already 4-D.
if X_train.ndim != 4:
    X_train, X_test = X_train[..., np.newaxis], X_test[..., np.newaxis]

In [3]:
class FrequencyMultiplicative(keras.layers.Layer):
    """Channel-mixing layer that applies its weights in the 2-D DCT domain.

    The input is treated as ``(batch, height, width, channels)``. ``call``
    takes an orthonormal DCT along both spatial axes, multiplies every
    frequency coefficient by a learned weight per (filter, channel) pair,
    sums over channels, transforms back with the inverse DCT and adds a
    per-filter bias. The output is ``(batch, height, width, filters)``.

    When ``freq_filters`` is given, the input channels are first projected to
    ``freq_filters`` channels using a separate weight for every spatial
    position (``kernel_freq``) before the DCT is taken.

    The weight shapes depend on the input height and width, so both must be
    known when the layer is built.

    Arguments:
        filters: number of output channels.
        freq_filters: optional number of intermediate channels produced by the
            position-wise projection; ``None`` skips the projection.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, freq_filters=None, **kwargs):
        super(FrequencyMultiplicative, self).__init__(**kwargs)
        self.filters = filters
        self.freq_filters = freq_filters

    def build(self, input_shape):
        """Create the layer weights for an input of shape ``input_shape``."""
        if self.freq_filters is not None:
            # One weight per (row, column, projected channel, input channel).
            self.kernel_freq = self.add_weight(
                shape=(1, input_shape[1], input_shape[2],
                       self.freq_filters, input_shape[3]),
                initializer='he_uniform', name='kernel_freq')
        # Number of channels that reach the frequency-domain multiplication.
        mixed_channels = (input_shape[3] if self.freq_filters is None
                          else self.freq_filters)
        # The two spatial axes are stored swapped (axis 2 before axis 1)
        # because that is the layout of the tensor after the second DCT in call().
        self.kernel = self.add_weight(
            shape=(1, self.filters, mixed_channels,
                   input_shape[2], input_shape[1]),
            initializer='he_uniform', name='kernel')
        self.bias = self.add_weight(
            shape=(self.filters,),
            initializer='zeros', name='bias')
        super(FrequencyMultiplicative, self).build(input_shape)

    def call(self, inputs):
        """Apply the layer to ``inputs`` and return the transformed tensor."""
        x = inputs
        if self.freq_filters is not None:
            # (B, H, W, C) -> (B, H, W, 1, C), broadcast against kernel_freq to
            # (B, H, W, F, C), then reduce over the input channels -> (B, H, W, F).
            x = K.expand_dims(x, axis=-2)
            x = x * self.kernel_freq
            x = K.sum(x, axis=-1, keepdims=False)
        # The DCT acts on the last axis only, so move each spatial axis there in turn.
        x = K.permute_dimensions(x, (0, 3, 1, 2))   # (B, C, H, W)
        x = tf.spectral.dct(x, norm='ortho')        # transform along W
        x = K.permute_dimensions(x, (0, 1, 3, 2))   # (B, C, W, H)
        x = tf.spectral.dct(x, norm='ortho')        # transform along H
        # (B, 1, C, W, H) * (1, filters, C, W, H), summed over C -> (B, filters, W, H).
        x = K.expand_dims(x, axis=1)
        x = x * self.kernel
        x = K.sum(x, axis=2, keepdims=False)
        # Undo the two transforms in reverse order.
        x = tf.spectral.idct(x, norm='ortho')       # inverse along H
        x = K.permute_dimensions(x, (0, 1, 3, 2))   # (B, filters, H, W)
        x = tf.spectral.idct(x, norm='ortho')       # inverse along W
        x = K.permute_dimensions(x, (0, 2, 3, 1))   # (B, H, W, filters)
        x = x + self.bias
        return x

    def compute_output_shape(self, input_shape):
        """Return the input shape with its last dimension replaced by ``filters``."""
        # tuple() lets list-like shapes be concatenated as well.
        return tuple(input_shape[:-1]) + (self.filters,)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Without this, rebuilding a model that contains the layer from its
        config fails because the required ``filters`` argument is not recorded.
        """
        config = super(FrequencyMultiplicative, self).get_config()
        config.update({
            'filters': self.filters,
            'freq_filters': self.freq_filters,
        })
        return config

In [4]:
def _bn_relu(tensor):
    """Apply batch normalization followed by a ReLU activation."""
    tensor = keras.layers.BatchNormalization()(tensor)
    return keras.layers.Activation('relu')(tensor)


# Stem: normalize the raw input, then the first frequency-domain layer.
X_input = keras.layers.Input(shape=(28, 28, 1))
X = X_input
X = keras.layers.BatchNormalization()(X)
X = FrequencyMultiplicative(filters=8)(X)
X = _bn_relu(X)
X = keras.layers.AveragePooling2D()(X)

# Two parallel branches on the same feature map (without and with the
# position-wise projection), concatenated along the channel axis.
X1 = FrequencyMultiplicative(filters=8)(X)
X2 = FrequencyMultiplicative(filters=8, freq_filters=8)(X)
X = keras.layers.Concatenate(axis=-1)([X1, X2])
X = _bn_relu(X)
X = keras.layers.AveragePooling2D()(X)

# Final frequency-domain layer and the 10-way softmax classifier head.
X = FrequencyMultiplicative(filters=32)(X)
X = _bn_relu(X)
X = keras.layers.Flatten()(X)
X = keras.layers.Dense(units=10, activation='softmax')(X)

M = keras.Model(inputs=X_input, outputs=X)
M.compile(
    optimizer='nadam',
    loss='sparse_categorical_crossentropy',
    metrics=['acc'])
M.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 28, 28, 1)    4           input_1[0][0]                    
__________________________________________________________________________________________________
frequency_multiplicative_1 (Fre (None, 28, 28, 8)    6280        batch_normalization_1[0][0]      
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 28, 28, 8)    32          frequency_multiplicative_1[0][0] 
__________________________________________________________________________________________________
activation

In [5]:
# Random geometric augmentation applied on the fly to each training batch.
datagen = keras.preprocessing.image.ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1)
# No datagen.fit(X_train) here: fitting only computes statistics for
# featurewise centering / std normalization / ZCA whitening, none of which is
# enabled, so the call would just copy the whole training set for nothing.

In [None]:
BATCH_SIZE = 64
# The step count must be a whole number; plain `/` yields a float under
# Python 3. Ceiling division also covers the final, partially filled batch.
steps_per_epoch = (len(X_train) + BATCH_SIZE - 1) // BATCH_SIZE

M.fit_generator(
    datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE),
    steps_per_epoch=steps_per_epoch,
    epochs=30,
    validation_data=(X_test, Y_test),
    callbacks=[
        # Lower the learning rate after 3 epochs without improvement.
        keras.callbacks.ReduceLROnPlateau(patience=3, verbose=1),
    ])

Epoch 1/30
Epoch 2/30
Epoch 3/30