In [1]:
import numpy as np
import tensorflow as tf
import keras
import keras.backend as K

Using TensorFlow backend.


In [2]:
# Load the CIFAR-10 train/test split that ships with Keras.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.cifar10.load_data()

# Divide by 255 (presumably mapping 8-bit pixel intensities into [0, 1]).
X_train = X_train / 255
X_test = X_test / 255

# The layers below index their input as (batch, dim1, dim2, channels); if the
# arrays arrive without a channel axis, append one of length 1.
if X_train.ndim != 4:
    X_train = X_train[:, :, :, np.newaxis]
    X_test = X_test[:, :, :, np.newaxis]

In [3]:
class FrequencyMultiplicative(keras.layers.Layer):
    """Channel mixing by element-wise multiplication in the 2-D DCT domain.

    The layer takes a 4-D input indexed as (batch, dim1, dim2, channels),
    applies an orthonormal DCT along both middle axes, multiplies every
    coefficient by a learned weight that is specific to its position and to
    the (input channel, output filter) pair, sums over the input channels,
    transforms back with the inverse DCT and finally adds a per-filter bias.

    Because the weights are stored per position, both middle dimensions of
    the input must be known when the layer is built.

    Arguments:
        filters: number of output channels.
        freq_filters: optional. When given, the input channels are first
            projected to `freq_filters` channels with a learned,
            position-dependent weight (applied before the DCT), and the
            frequency-domain kernel then mixes those projected channels
            instead of the raw input channels.

    Output shape: same as the input with the last axis replaced by `filters`.
    """

    def __init__(self, filters, freq_filters=None, **kwargs):
        self.filters = filters
        self.freq_filters = freq_filters
        super(FrequencyMultiplicative, self).__init__(**kwargs)

    def build(self, input_shape):
        if self.freq_filters is not None:
            # Position-dependent projection (dim1, dim2, freq_filters, channels);
            # the leading 1 broadcasts over the batch axis.
            self.kernel_freq = self.add_weight(
                shape=(1, input_shape[1], input_shape[2], self.freq_filters, input_shape[3]),
                initializer='he_uniform', name='kernel_freq')
        # Frequency-domain kernel laid out as
        # (1, filters, mixed channels, dim2, dim1) to match the axis order the
        # data has after the two DCT passes in `call`.
        self.kernel = self.add_weight(
            shape=(1, self.filters,
                   input_shape[3] if self.freq_filters is None else self.freq_filters,
                   input_shape[2], input_shape[1]),
            initializer='he_uniform', name='kernel')
        self.bias = self.add_weight(
            shape=(self.filters,),
            initializer='zeros', name='bias')
        super(FrequencyMultiplicative, self).build(input_shape)

    def call(self, inputs):
        # `tf.spectral` is the TF 1.x namespace for dct/idct; it is deprecated
        # in favour of `tf.signal` and no longer present in TF 2.x. Prefer
        # `tf.signal` when it exists and fall back to `tf.spectral` otherwise.
        signal = getattr(tf, 'signal', None) or tf.spectral

        x = inputs
        if self.freq_filters is not None:
            # (B, d1, d2, 1, C) * (1, d1, d2, F, C) -> sum over C -> (B, d1, d2, F)
            x = K.expand_dims(x, axis=-2)
            x = x * self.kernel_freq
            x = K.sum(x, axis=-1, keepdims=False)

        # dct/idct act on the last axis only, so each spatial axis is rotated
        # into last position in turn.
        x = K.permute_dimensions(x, (0, 3, 1, 2))   # (B, C, d1, d2)
        x = signal.dct(x, norm='ortho')             # transform along d2
        x = K.permute_dimensions(x, (0, 1, 3, 2))   # (B, C, d2, d1)
        x = signal.dct(x, norm='ortho')             # transform along d1

        # (B, 1, C, d2, d1) * (1, filters, C, d2, d1) -> sum over C
        x = K.expand_dims(x, axis=1)
        x = x * self.kernel
        x = K.sum(x, axis=2, keepdims=False)        # (B, filters, d2, d1)

        x = signal.idct(x, norm='ortho')            # inverse along d1
        x = K.permute_dimensions(x, (0, 1, 3, 2))   # (B, filters, d1, d2)
        x = signal.idct(x, norm='ortho')            # inverse along d2
        x = K.permute_dimensions(x, (0, 2, 3, 1))   # (B, d1, d2, filters)
        x = x + self.bias
        return x

    def compute_output_shape(self, input_shape):
        # tuple() keeps the concatenation valid even if a list-like shape is passed.
        return tuple(input_shape[:-1]) + (self.filters,)

    def get_config(self):
        # Include the constructor arguments so the layer can be re-created
        # from its config (e.g. when a saved model is loaded).
        return {
            **super(FrequencyMultiplicative, self).get_config(),
            'filters': self.filters,
            'freq_filters': self.freq_filters,
        }

In [4]:
class DCT(keras.layers.Layer):
    """Weight-free layer applying an orthonormal 2-D DCT to each channel.

    The input is indexed as (batch, dim1, dim2, channels); the transform is
    taken along dim1 and dim2 and the output has the same shape as the input.
    """

    def __init__(self, **kwargs):
        super(DCT, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create.
        super(DCT, self).build(input_shape)

    def call(self, inputs):
        # `tf.spectral` is the TF 1.x namespace for dct; it is deprecated in
        # favour of `tf.signal` and no longer present in TF 2.x. Prefer
        # `tf.signal` when it exists and fall back to `tf.spectral` otherwise.
        signal = getattr(tf, 'signal', None) or tf.spectral

        x = inputs
        # dct acts on the last axis only, so each spatial axis is rotated into
        # last position in turn.
        x = K.permute_dimensions(x, (0, 3, 1, 2))   # (B, C, d1, d2)
        x = signal.dct(x, norm='ortho')             # transform along d2
        x = K.permute_dimensions(x, (0, 1, 3, 2))   # (B, C, d2, d1)
        x = signal.dct(x, norm='ortho')             # transform along d1
        x = K.permute_dimensions(x, (0, 3, 2, 1))   # back to (B, d1, d2, C)
        return x

    def compute_output_shape(self, input_shape):
        return input_shape

In [5]:
class AttentionalPooling(keras.layers.Layer):
    """2x2 pooling whose window weights come from a second, attention tensor.

    The layer is called on a pair `[data, attention]` of 4-D tensors indexed
    as (batch, dim1, dim2, channels). Both tensors are cut into
    non-overlapping 2x2 windows; within every window (and separately for each
    channel) the attention values are turned into weights with a softmax, and
    the output is the weighted sum of the four data values.

    Output shape: (batch, dim1 // 2, dim2 // 2, channels).

    Raises:
        ValueError: at build time if the two inputs differ in channel count
            or spatial size, or if a spatial size is unknown or odd.
    """

    def __init__(self, **kwargs):
        super(AttentionalPooling, self).__init__(**kwargs)

    def build(self, input_shape):
        data_shape, att_shape = input_shape
        if data_shape[-1] != att_shape[-1]:
            raise ValueError('channel count of data and attention required to be equal')
        # The weights are applied element-wise, so the windows of both inputs
        # have to line up exactly.
        if tuple(data_shape[1:3]) != tuple(att_shape[1:3]):
            raise ValueError(
                'spatial size of data and attention required to be equal, got '
                '%s and %s' % (tuple(data_shape[1:3]), tuple(att_shape[1:3])))
        # `call` reshapes with a free (-1) batch axis; an odd or unknown size
        # would either fail late or regroup elements across samples, so
        # reject it here with a clear message.
        for size in data_shape[1:3]:
            if size is None or size % 2 != 0:
                raise ValueError(
                    'spatial dimensions required to be known and even, got '
                    '%s' % (tuple(data_shape[1:3]),))
        super(AttentionalPooling, self).build(input_shape)

    @staticmethod
    def _to_windows(x):
        """Regroup (B, d1, d2, C) into (B, d1/2, d2/2, 4, C): one axis of length 4 per 2x2 window."""
        half_d1, half_d2, channels = x.shape[1] // 2, x.shape[2] // 2, x.shape[-1]
        x = K.reshape(x, (-1, half_d1, 2, half_d2, 2, channels))
        # Bring the two within-window axes next to each other before merging them.
        x = K.permute_dimensions(x, (0, 1, 3, 2, 4, 5))
        return K.reshape(x, (-1, half_d1, half_d2, 4, channels))

    def call(self, inputs):
        data, att = inputs
        data = self._to_windows(data)
        att = self._to_windows(att)
        # Normalise over the four positions of each window.
        att = K.softmax(att, axis=-2)
        data = data * att
        data = K.sum(data, axis=-2, keepdims=False)
        return data

    def compute_output_shape(self, input_shape):
        data_shape, _ = input_shape
        return (data_shape[0], data_shape[1]//2, data_shape[2]//2, data_shape[3])

In [7]:
# Two parallel towers with identical structure: one is fed the raw input, the
# other its 2-D DCT. Their flattened features are concatenated and classified
# by a small dense head.
X = X_input = keras.layers.Input(X_train.shape[1:])
X_spatial = X
X_freq = DCT()(X)
X_inceptions = []
for X in [X_spatial, X_freq]:
    X = keras.layers.BatchNormalization()(X)

    # Stage 1: four parallel branches (12 + 4 + 14 + 2 = 32 channels) plus a
    # 5x5 convolution XA with the same 32 channels, used as the attention
    # input of the pooling layer.
    X1 = FrequencyMultiplicative(12)(X)
    X2 = FrequencyMultiplicative(4, freq_filters=4)(X)
    X3 = keras.layers.Conv2D(14, (3,3), padding='same', kernel_initializer='he_uniform')(X)
    X4 = keras.layers.Conv2D(2, (1,1), padding='same', kernel_initializer='he_uniform')(X)
    XA = keras.layers.Conv2D(32, (5,5), padding='same', kernel_initializer='he_uniform')(X)
    X = keras.layers.Concatenate()([X1,X2,X3,X4])
    X = keras.layers.BatchNormalization()(X)
    X = keras.layers.Activation('relu')(X)
    X = AttentionalPooling()([X,XA])

    # Stage 2: same layout with 24 + 8 + 28 + 4 = 64 channels and a
    # 64-channel attention branch.
    X1 = FrequencyMultiplicative(24)(X)
    X2 = FrequencyMultiplicative(8, freq_filters=4)(X)
    X3 = keras.layers.Conv2D(28, (3,3), padding='same', kernel_initializer='he_uniform')(X)
    X4 = keras.layers.Conv2D(4, (1,1), padding='same', kernel_initializer='he_uniform')(X)
    XA = keras.layers.Conv2D(64, (5,5), padding='same', kernel_initializer='he_uniform')(X)
    X = keras.layers.Concatenate()([X1,X2,X3,X4])
    X = keras.layers.BatchNormalization()(X)
    X = keras.layers.Activation('relu')(X)
    X = AttentionalPooling()([X,XA])

    # Stage 3: 64 channels again, no pooling; the result is flattened.
    X1 = FrequencyMultiplicative(24)(X)
    X2 = FrequencyMultiplicative(8, freq_filters=8)(X)
    X3 = keras.layers.Conv2D(28, (3,3), padding='same', kernel_initializer='he_uniform')(X)
    X4 = keras.layers.Conv2D(4, (1,1), padding='same', kernel_initializer='he_uniform')(X)
    X = keras.layers.Concatenate()([X1,X2,X3,X4])
    X = keras.layers.BatchNormalization()(X)
    X = keras.layers.Activation('relu')(X)
    X = keras.layers.Flatten()(X)
    X_inceptions.append(X)

# Dense classification head on the concatenated tower features.
X = keras.layers.Concatenate()(X_inceptions)
X = keras.layers.Dense(128, kernel_initializer='he_uniform')(X)
X = keras.layers.BatchNormalization()(X)
X = keras.layers.Activation('relu')(X)
X = keras.layers.Dense(32, kernel_initializer='he_uniform')(X)
X = keras.layers.BatchNormalization()(X)
X = keras.layers.Activation('relu')(X)

# Number of classes = largest label + 1 (labels are used as integer class ids
# by the sparse cross-entropy loss below). np.max reduces over the whole array
# and int() yields a plain Python int; the builtin max() would iterate the
# array row by row and, for labels shaped (num_samples, 1), hand Dense a
# one-element array instead of an integer. Converting before adding 1 also
# avoids wrap-around in a small integer dtype.
n_classes = int(np.max(Y_train)) + 1
X = keras.layers.Dense(n_classes, activation='softmax')(X)

M = keras.Model(X_input, X)
M.compile('nadam', 'sparse_categorical_crossentropy', ['acc'])
M.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
dct_2 (DCT)                     (None, 32, 32, 3)    0           input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 32, 32, 3)    12          input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 32, 32, 3)    12          dct_2[0][0]                      
__________________________________________________________________________________________________
frequency_

In [6]:
# Random augmentations applied to the training images on the fly.
augmentation_options = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
)
datagen = keras.preprocessing.image.ImageDataGenerator(**augmentation_options)

# NOTE(review): per the Keras documentation, fit() is only required when a
# feature-wise option (featurewise_center, featurewise_std_normalization,
# zca_whitening) is enabled; none is set above, so this call is probably
# unnecessary - confirm before removing.
datagen.fit(X_train)

In [8]:
# Single source of truth for the batch size used by the generator and by the
# steps-per-epoch computation.
BATCH_SIZE = 64

M.fit_generator(
    datagen.flow(X_train, Y_train, batch_size=BATCH_SIZE),
    validation_data=(X_test, Y_test),
    # steps_per_epoch is a count of batches and must be an integer; a plain
    # division yields a float that is generally not a whole number. Round up
    # so the final, partial batch is part of the epoch.
    steps_per_epoch=int(np.ceil(len(X_train) / BATCH_SIZE)),
    epochs=30,
    callbacks=[
        # Lower the learning rate once the monitored quantity has stopped
        # improving for 2 epochs.
        keras.callbacks.ReduceLROnPlateau(patience=2, verbose=1),
        #keras.callbacks.TensorBoard(
        #    log_dir='./models/frequency-multiplicative-convolution/logs/',
        #    batch_size=64, histogram_freq=1),
    ])

Epoch 1/30
Epoch 2/30
  2/781 [..............................] - ETA: 10:06 - loss: 0.6163 - acc: 0.7656

KeyboardInterrupt: 

In [9]:
# Persist the model M to a single HDF5 file.
model_path = './models/frequency-multiplicative-convolution/model.hdf5'
keras.models.save_model(model=M, filepath=model_path)

In [7]:
# The saved model references the notebook's custom layers by class name, so
# the loader has to be told which Python class each name stands for.
custom_layers = {
    'FrequencyMultiplicative': FrequencyMultiplicative,
    'DCT': DCT,
    'AttentionalPooling': AttentionalPooling,
}
M = keras.models.load_model(
    './models/frequency-multiplicative-convolution/model.hdf5',
    custom_objects=custom_layers,
)