In [5]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras.backend as K

In [6]:
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()
X_train, X_test = X_train/255, X_test/255
if len(X_train.shape) != 4:
    X_train, X_test = X_train[:,1:-1,:,np.newaxis], X_test[:,1:-1,:,np.newaxis]

In [7]:
class FrequencyMultiplicative(keras.layers.Layer):
    
    def __init__(self, filters, **kwargs):
        self.filters = filters
        super(FrequencyMultiplicative, self).__init__(**kwargs)
    
    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(1, self.filters, input_shape[3], input_shape[2], input_shape[1]),
            initializer='he_uniform', name='kernel')
        self.bias = self.add_weight(
            shape=(self.filters,),
            initializer='zeros', name='bias')
        super(FrequencyMultiplicative, self).build(input_shape)
    
    def call(self, inputs):
        x = inputs
        x = K.permute_dimensions(x, (0,3,1,2))
        x = tf.spectral.dct(x, norm='ortho')
        x = K.permute_dimensions(x, (0,1,3,2))
        x = tf.spectral.dct(x, norm='ortho')
        x = K.expand_dims(x, axis=1)
        x = x * self.kernel
        x = K.sum(x, axis=2, keepdims=False)
        x = tf.spectral.idct(x, norm='ortho')
        x = K.permute_dimensions(x, (0,1,3,2))
        x = tf.spectral.idct(x, norm='ortho')
        x = K.permute_dimensions(x, (0,2,3,1))
        x = x + self.bias
        print(x.shape)
        return x
    
    def compute_output_shape(self, input_shape):
        return input_shape[:-1] + (self.filters,)

In [8]:
X = X_input = keras.layers.Input((26, 28, 1))
X = keras.layers.BatchNormalization()(X)
X = FrequencyMultiplicative(8)(X)
X = keras.layers.BatchNormalization()(X)
X = keras.layers.Activation('relu')(X)
X = keras.layers.AveragePooling2D()(X)
X = FrequencyMultiplicative(16)(X)
X = keras.layers.BatchNormalization()(X)
X = keras.layers.Activation('relu')(X)
X = keras.layers.AveragePooling2D()(X)
X = FrequencyMultiplicative(32)(X)
X = keras.layers.BatchNormalization()(X)
X = keras.layers.Activation('relu')(X)
X = keras.layers.Flatten()(X)
X = keras.layers.Dense(10, activation='softmax')(X)
M = keras.Model(X_input, X)
M.compile('nadam', 'sparse_categorical_crossentropy', ['acc'])
M.summary()

(?, 26, 28, 8)
(?, 13, 14, 16)
(?, 6, 7, 32)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 26, 28, 1)         0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 26, 28, 1)         4         
_________________________________________________________________
frequency_multiplicative_2 ( (None, 26, 28, 8)         5832      
_________________________________________________________________
batch_normalization_3 (Batch (None, 26, 28, 8)         32        
_________________________________________________________________
activation_1 (Activation)    (None, 26, 28, 8)         0         
_________________________________________________________________
average_pooling2d_1 (Average (None, 13, 14, 8)         0         
_________________________________________________________________
frequency_multiplicative_3 ( (N

In [9]:
datagen = keras.preprocessing.image.ImageDataGenerator(
    shear_range=0.1,
    zoom_range=0.1,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1)
datagen.fit(X_train)

In [10]:
M.fit_generator(
    datagen.flow(X_train, Y_train, batch_size=64), 
    validation_data=(X_test, Y_test),
    steps_per_epoch=len(X_train) / 64, epochs=30, callbacks=[
    keras.callbacks.ReduceLROnPlateau(patience=3, verbose=1),
])

Epoch 1/30
Epoch 2/30
 89/937 [=>............................] - ETA: 47s - loss: 0.1221 - acc: 0.9594

KeyboardInterrupt: 