### FCNNs

In [4]:
import numpy as np
import tensorflow as tf

In [5]:
# Load the Fashion-MNIST train and test splits

def load_data():
    """Load Fashion-MNIST with pixels scaled to [0, 1] and one-hot labels.

    Images are converted to float32 and divided by 255. Integer class labels
    are one-hot encoded over 10 classes with ``tf.keras.utils.to_categorical``.
    The data comes from ``tf.keras.datasets.fashion_mnist.load_data()``, which
    may fetch the dataset on first use.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. Per the printed shapes,
        images are ``(N, 28, 28)`` and labels are ``(N, 10)``.
    """
    import tensorflow as tf
    print('Using tensorflow version {}.'.format(tf.__version__))

    keras = tf.keras
    train_split, test_split = keras.datasets.fashion_mnist.load_data()

    def _prepare(split):
        # Normalise pixel intensities and one-hot encode the class ids.
        images, labels = split
        return (images.astype('float32') / 255,
                keras.utils.to_categorical(labels, num_classes=10))

    (x_train, y_train), (x_test, y_test) = _prepare(train_split), _prepare(test_split)

    print('Loaded Fashion-MNIST into x_train, y_train, x_test, y_test.')
    shapes = (x_train.shape, y_train.shape, x_test.shape, y_test.shape)
    print('Shapes: x_train: {}, y_train: {}, x_test: {}, y_test: {}'.format(*shapes))
    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_data()

### loading data ###
batch_size = 32
shuffle_buffer = 1024  # number of samples held in the shuffle buffer
shuffle_seed = 42      # fixed so Restart & Run All reproduces the same batch order

# The model's Input is (H, W, 1), so add the trailing channel axis explicitly
# instead of relying on Keras to expand rank-3 batches at call time.

# Prepare the training dataset.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train[..., None], y_train))
    .shuffle(buffer_size=shuffle_buffer, seed=shuffle_seed)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Prepare the validation dataset (kept in order; no shuffling).
val_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test[..., None], y_test))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

Using tensorflow version 2.7.1.
Loaded Fashion-MNIST into x_train, y_train, x_test, y_test.
Shapes: x_train: (60000, 28, 28), y_train: (60000, 10), x_test: (10000, 28, 28), y_test: (10000, 10)


2022-08-26 02:15:22.653675: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


##### Fourier transform layers: spatial ↔ spectral conversion and spectral pooling (TODO: add visualisations)

In [11]:
class SpectralConvert(tf.keras.layers.Layer):
    """Convert an image batch between the spatial and the (centred) Fourier basis.

    ``basis='spectral'``: apply a 2-D FFT over the height and width axes and move
    the zero frequency to the centre (``fftshift``). The output is ``complex128``
    and has the same layout as the input.

    ``basis='spatial'``: undo the centring (``ifftshift``), apply the inverse
    2-D FFT, and return the real part (float64).

    Inputs may be channels-last ``(batch, H, W, C)``, as produced by ``Conv2D``,
    or ``(batch, H, W)``.

    Args:
        basis: target basis, either ``'spectral'`` or ``'spatial'``.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: if ``basis`` is not one of the supported values.
    """

    def __init__(self, basis, **kwargs):
        if basis not in ('spectral', 'spatial'):
            raise ValueError(
                "basis must be 'spectral' or 'spatial', got {!r}".format(basis))
        super(SpectralConvert, self).__init__(**kwargs)
        self.__name__ = 'spatial_spectral_converter'
        self.basis = basis

    def call(self, inputs):
        x = tf.cast(inputs, dtype='complex128')
        # fft2d/ifft2d transform the two innermost axes. For channels-last
        # (batch, H, W, C) input that would be W and C, so move channels to
        # axis 1 first. The transform then runs over H and W.
        channels_last = x.shape.rank == 4
        if channels_last:
            x = tf.transpose(x, perm=[0, 3, 1, 2])
        # Shift only the two spatial-frequency axes. fftshift's default
        # (axes=None) would also roll the batch and channel axes.
        freq_axes = [x.shape.rank - 2, x.shape.rank - 1]
        if self.basis == 'spectral':
            x = tf.signal.fftshift(tf.signal.fft2d(x), axes=freq_axes)
        else:
            x = tf.signal.ifft2d(tf.signal.ifftshift(x, axes=freq_axes))
        if channels_last:
            x = tf.transpose(x, perm=[0, 2, 3, 1])  # back to (batch, H, W, C)
        return x if self.basis == 'spectral' else tf.math.real(x)

    def get_config(self):
        # Lets models containing this layer be saved and re-loaded.
        config = super(SpectralConvert, self).get_config()
        config.update({'basis': self.basis})
        return config


class SpectralPooling(tf.keras.layers.Layer):
    """Low-pass filter a centred Fourier-basis tensor by zeroing high frequencies.

    Expects channels-last input ``(batch, H, W, C)`` whose frequency axes are
    already centred, e.g. the output of ``SpectralConvert('spectral')``. A
    frequency is kept if it lies inside an ellipse around the centre. The
    ellipse's horizontal radius is ``(W // 2) * thresh``. Row offsets are
    rescaled by ``(W // 2) / (H // 2)``, so for square inputs the ellipse is a
    circle. All other frequencies are multiplied by zero.

    The spatial size is unchanged: this layer masks the spectrum, it does not
    crop it.

    Modelling caveat: this limits what can be learnt. High-frequency global
    information could be an important prior; here it is assumed not to be.

    Args:
        thresh: cutoff radius as a fraction of half the input width.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, thresh, **kwargs):
        super(SpectralPooling, self).__init__(**kwargs)
        self.__name__ = 'spectral_pooler'
        self.thresh = thresh

    def build(self, input_shape):
        height, width = int(input_shape[-3]), int(input_shape[-2])
        cy, cx = height // 2, width // 2
        cols, rows = tf.meshgrid(tf.range(width, dtype=tf.float32),
                                 tf.range(height, dtype=tf.float32))
        # Rescale row offsets so non-square inputs get an elliptical cutoff.
        # For square inputs the scale is exactly 1.0, giving a circular mask.
        row_scale = cx / cy if cy > 0 else 1.0
        radius = tf.math.sqrt((cols - cx) ** 2 + ((rows - cy) * row_scale) ** 2)
        low_pass_filter = radius <= cx * self.thresh
        # Shape (H, W, 1): broadcasts over the batch and channel axes.
        self.low_pass_filter = tf.expand_dims(
            tf.cast(low_pass_filter, dtype='complex128'), axis=-1)
        super(SpectralPooling, self).build(input_shape)

    def call(self, inputs):
        # Returns, in the Fourier basis, the input with out-of-band frequencies
        # zeroed. The mask is cast so complex64 inputs also work.
        return inputs * tf.cast(self.low_pass_filter, inputs.dtype)

    def get_config(self):
        # Lets models containing this layer be saved and re-loaded.
        config = super(SpectralPooling, self).get_config()
        config.update({'thresh': self.thresh})
        return config

# TODO: class SpectralConv — not yet implemented

##### Paper 1: Spectral pooling — example usage

In [13]:
tf.keras.backend.clear_session()  # start a fresh graph for the functional model

N_CLASSES = 10
CONV_FILTERS = 24
SPECTRAL_THRESH = 0.8   # low-pass cutoff as a fraction of half the image width
N_SPECTRAL_BLOCKS = 3


def spectral_block(x, filters=CONV_FILTERS, thresh=SPECTRAL_THRESH, leaky_alpha=0.01):
    """Conv2D -> FFT -> spectral low-pass mask -> inverse FFT -> LeakyReLU."""
    x = tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3), padding='same')(x)
    x = SpectralConvert('spectral')(x)
    x = SpectralPooling(thresh=thresh)(x)
    x = SpectralConvert('spatial')(x)
    # Conv, FFT, masking, IFFT and taking the real part are all affine. Without a
    # non-linearity, stacked blocks would collapse into a single affine map.
    return tf.keras.layers.LeakyReLU(alpha=leaky_alpha)(x)


img_h, img_w = x_train[0].shape[0], x_train[0].shape[1]
inp0 = tf.keras.Input(shape=(img_h, img_w, 1))

features = inp0
for _ in range(N_SPECTRAL_BLOCKS):
    features = spectral_block(features)

# feed into dense head
f1 = tf.keras.layers.Flatten()(features)
d1 = tf.keras.layers.Dense(64, activation='relu')(f1)
d2 = tf.keras.layers.Dense(64, activation='relu')(d1)
# Linear logits: a ReLU here would clip negative class scores before the softmax.
d3 = tf.keras.layers.Dense(N_CLASSES)(d2)

# end
outp = tf.keras.layers.Softmax()(d3)

model = tf.keras.Model([inp0], [outp])
# The model already outputs probabilities (Softmax layer). from_logits=True
# would apply a second softmax inside the loss.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

opt = tf.keras.optimizers.SGD(learning_rate=1e-3)
model.compile(optimizer=opt, loss=loss_fn, metrics=['accuracy'])

In [14]:
# Print each layer's name, output shape and parameter count.
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 28, 28, 24)        240       
                                                                 
 spectral_convert (SpectralC  (None, 28, 28, 24)       0         
 onvert)                                                         
                                                                 
 spectral_pooling (SpectralP  (None, 28, 28, 24)       0         
 ooling)                                                         
                                                                 
 spectral_convert_1 (Spectra  (None, 28, 28, 24)       0         
 lConvert)                                                       
                                                             

In [None]:
# TODO: train, e.g. model.fit(train_dataset, validation_data=val_dataset, epochs=...)

In [297]:
# TODO: consult the original papers to choose the weight-initialisation scheme.

##### Learnable filter passes