# Example: HMNIST (colorectal histology) classification with a BW-CNN

In [None]:
import numpy as np
import tensorflow_datasets as tfds
import cv2 as cv
import tensorflow as tf
import FourierNetworks as fn

## Auxiliary functions

In [None]:
# import HMNIST dataset and resize
def get_HMNIST(resize_dim=None):
    """Load the 'colorectal_histology' train split as NumPy arrays.

    Args:
        resize_dim: If not None, every image is resized to
            ``(resize_dim, resize_dim)`` with ``cv.resize``. If None the
            images are kept at the size delivered by the dataset.

    Returns:
        A tuple ``(Xs, ys)`` of float32 arrays: the stacked images and the
        corresponding labels (one label per image).
    """
    # NOTE(review): shuffle_files=True is forwarded to tfds.load, so the
    # example order may differ between runs - confirm this is intended if a
    # reproducible train/val/test split is required downstream.
    data = tfds.load('colorectal_histology', split='train', shuffle_files=True)

    Xs = []
    ys = []
    for di in data:
        img = np.array(di['image'], dtype=np.uint8)
        label = np.array(di['label'])

        # Identity check: `!= None` is the wrong idiom for an optional arg.
        if resize_dim is not None:
            img = cv.resize(img, (resize_dim, resize_dim))

        Xs.append(img)
        ys.append(label)

    Xs = np.array(Xs, dtype=np.float32)
    ys = np.array(ys, dtype=np.float32)

    return Xs, ys

# performs the centered Fourier transform of images
def freq_data_3d(data, shift=True, out_dim=None):
    """Compute the per-channel 2-D Fourier transform of a batch of images.

    Args:
        data: Batch of images with layout ``(n_images, height, width,
            channels)`` (array or sequence of equally shaped images).
            Every channel of the last axis is transformed.
        shift: If True, ``np.fft.fftshift`` is applied over the two spatial
            axes so the zero-frequency term sits at
            ``(height // 2, width // 2)``.
        out_dim: If not None, the spectrum is cropped to an
            ``out_dim x out_dim`` window centred on
            ``(height // 2, width // 2)``.

    Returns:
        float32 array of shape ``(n_images, 2, H, W, channels)`` where index
        0 of the second axis holds the real part and index 1 the imaginary
        part; ``H = W = out_dim`` when cropping, otherwise the input size.
        An empty input that is not 4-D returns an empty float32 array.

    Raises:
        ValueError: If ``data`` is non-empty but not 4-D, or if ``out_dim``
            is negative or larger than the smallest spatial dimension.
    """
    images = np.asarray(data)
    if images.ndim != 4:
        if images.size == 0:
            return np.zeros((0,), dtype=np.float32)
        raise ValueError(
            "data must have shape (n_images, height, width, channels), "
            f"got {images.shape}"
        )

    n_images, height, width, channels = images.shape

    if out_dim is None:
        rows = slice(0, height)
        cols = slice(0, width)
        out_h, out_w = height, width
    else:
        out_dim = int(out_dim)
        if not 0 <= out_dim <= min(height, width):
            raise ValueError(
                f"out_dim must be in [0, {min(height, width)}], got {out_dim}"
            )
        # Start at centre - out_dim // 2 and take exactly out_dim samples.
        # For even out_dim this equals the old [mid-dx, mid+dx) window; for
        # odd out_dim it no longer drops one sample. Each axis uses its own
        # centre so non-square images are cropped correctly.
        top = height // 2 - out_dim // 2
        left = width // 2 - out_dim // 2
        rows = slice(top, top + out_dim)
        cols = slice(left, left + out_dim)
        out_h = out_w = out_dim

    # Fill a preallocated float32 buffer image by image instead of keeping
    # the whole complex128 dataset alive in a list before the final cast.
    out = np.empty((n_images, 2, out_h, out_w, channels), dtype=np.float32)
    for idx, img in enumerate(images):
        # Transform all channels at once over the two spatial axes.
        spectrum = np.fft.fft2(img, axes=(0, 1))
        if shift:
            spectrum = np.fft.fftshift(spectrum, axes=(0, 1))
        spectrum = spectrum[rows, cols]
        out[idx, 0] = spectrum.real
        out[idx, 1] = spectrum.imag

    return out


# identification of the epoch with maximum validation accuracy
class MaxEpoch(tf.keras.callbacks.Callback):
    """Keras callback that tracks the epoch with the best validation accuracy.

    After every epoch the validation accuracy found in ``logs`` is compared
    with the best value seen so far; on improvement the epoch index, the
    accuracy and a snapshot of the model weights are stored. The validation
    loss of every epoch is appended to ``val_loss``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, epochs):
        """Create the callback.

        Args:
            epochs: Number of training epochs (stored only, not used by the
                callback itself).
        """
        super().__init__()
        self.epochs = epochs  # number of epochs
        self.val_loss = []  # validation loss logged at the end of each epoch

        self.max_epoch = 0  # 0-based index of the best epoch so far
        self.max_val_acc = 0.0  # best validation accuracy so far
        self.max_weights = None  # model weights captured at the best epoch

    def on_epoch_end(self, epoch, logs=None):
        """Record the weights when the validation accuracy improves.

        Args:
            epoch: 0-based index of the epoch that just finished.
            logs: Metric dict supplied by Keras; may be None.
        """
        # Keras may call this hook without logs; avoid None.get(...).
        logs = logs or {}

        # The model in this file is compiled with metrics=["acc"]; also
        # accept the 'val_accuracy' key in case the metric is named that way.
        val_acc = logs.get('val_acc')
        if val_acc is None:
            val_acc = logs.get('val_accuracy')

        # Skip the comparison when no validation accuracy was logged, and
        # always take the first available snapshot so max_weights is never
        # left as None just because the accuracy did not exceed 0.0.
        if val_acc is not None and (
            self.max_weights is None or val_acc > self.max_val_acc
        ):
            self.max_epoch = epoch
            self.max_val_acc = val_acc
            self.max_weights = self.model.get_weights()
        self.val_loss.append(logs.get('val_loss'))

        return super().on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        """Delegate to the base implementation (no extra work at train end)."""
        return super().on_train_end(logs)


## Data preparation

In [None]:
# Load the images at 128x128 and move them to the frequency domain
# (centered Fourier transform).
data, labels = get_HMNIST(resize_dim=128)
data = freq_data_3d(data)

# Split boundaries: train (70%), val (10%), test (20%).
n_samples = len(data)
train_end = int(0.7*n_samples)
val_end = int(0.8*n_samples)

# Shuffle the sample indices with a fixed seed, then cut them at the
# two boundaries.
inds = np.arange(0, n_samples, 1, dtype=np.int32)
np.random.seed(1) # seed
np.random.shuffle(inds)
inds_train, inds_val, inds_test = np.split(inds, [train_end, val_end])

X_train, y_train = data[inds_train], labels[inds_train]
X_val, y_val = data[inds_val], labels[inds_val]
X_test, y_test = data[inds_test], labels[inds_test]

# Drop the references to the full arrays so their memory can be freed.
data = labels = None

X_train.shape, X_val.shape, X_test.shape

## Training and evaluation on test set

In [None]:
# hyperparameters
epochs = 10
lr = 0.00001
momentum = 0.9
batch_size = 32
num_classes = 8  # softmax width and one-hot depth used for evaluation

# BW-CNN architecture
# Input layout matches the output of freq_data_3d on 128x128 images:
# (real/imag, height, width, channels).
# Named `inputs` rather than `input` so the builtin is not shadowed.
inputs = tf.keras.layers.Input((2, 128, 128, 3))

c1 = fn.ButterworthLayer(filters=8, norm=1.0, es=0.45, d=2, act='crelu')(inputs)
p1 = fn.Spect_Avg_Pool()(c1)

c2 = fn.ButterworthLayer(filters=32, norm=1.0, es=0.45, d=2, act='crelu')(p1)
p2 = fn.Spect_Avg_Pool()(c2)

c3 = fn.ButterworthLayer(filters=64, norm=1.0, es=0.45, d=2, act='crelu')(p2)
p3 = fn.Spect_Avg_Pool()(c3)

ifft = fn.IFFT()(p3)
flat = tf.keras.layers.Flatten()(ifft)
bn = tf.keras.layers.BatchNormalization()(flat)

dense = tf.keras.layers.Dense(128, 'relu')(bn)
out = tf.keras.layers.Dense(num_classes, 'softmax')(dense)
model = tf.keras.Model(inputs, out)
model.summary()

model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["acc"]
    )
callback = MaxEpoch(epochs=epochs)
# `callbacks` is documented as a list of callback instances.
model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(X_val, y_val),
                        callbacks=[callback], shuffle=True)

# evaluation on test set, using the weights of the best validation epoch
model_max = tf.keras.models.clone_model(model)
best_weights = callback.max_weights
if best_weights is None:
    # No snapshot was recorded; fall back to the final trained weights
    # instead of calling set_weights(None) on the freshly cloned model.
    best_weights = model.get_weights()
model_max.set_weights(best_weights)

m = tf.keras.metrics.CategoricalAccuracy()
# The labels are stored as float32; tf.one_hot needs integer indices.
y_test_onehot = tf.one_hot(tf.cast(y_test, tf.int32), depth=num_classes)
m.update_state(y_test_onehot, model_max.predict(X_test))
test_max_acc = m.result().numpy()

print(f'\n---- MAX EPOCH: {callback.max_epoch+1} TEST_ACC: {test_max_acc} ----')