# Notebook Setup

In [None]:
import tensorflow as tf
import tensorflow.keras as K

In [None]:
# Notebook expression cell: echoes the installed TensorFlow version string.
tf.__version__

In [None]:
import numpy as np

In [None]:
# Make the sibling `functional_api` directory importable, then load the
# encoder-building helpers from it under the alias `enet`.
import sys
sys.path.append('../functional_api')  # relative path: resolved against the current working directory
import encoders_fun as enet

# Load Fashion MNIST Dataset

In [None]:
# Fetch the Fashion-MNIST train/test splits through the Keras dataset helper.
mnist = K.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Images: divide every pixel by 255.0 (presumably mapping byte intensities
# into [0, 1]) and append a trailing channel axis so each sample is 28x28x1.
x_train = (x_train / 255.0).reshape([-1, 28, 28, 1])
x_test = (x_test / 255.0).reshape([-1, 28, 28, 1])

# Labels: convert the integer class ids into one-hot rows, as required by the
# 'categorical_crossentropy' loss used when the models are compiled.
y_train = K.utils.to_categorical(y_train)
y_test = K.utils.to_categorical(y_test)

### Data generator for data augmentation

In [None]:
# Augmentation pipeline: small random rotations and shifts, no mirroring.
# (Units per the Keras ImageDataGenerator docs: rotation in degrees, shifts as
# a fraction of the image size when the value is below 1.)
datagen = K.preprocessing.image.ImageDataGenerator(
    rotation_range=5,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=False)

In [None]:
# Let the generator compute any data-dependent statistics from the training set.
# NOTE(review): per the Keras docs this is only required when featurewise_center,
# featurewise_std_normalization or zca_whitening is enabled; none of those are
# set on `datagen`, so this call is presumably redundant - confirm before removing.
datagen.fit(x_train)

# Fashion MNIST Baseline Model

In [None]:
# Wire up the baseline classifier with the functional API: a 28x28x1 image
# input is passed through the project helper `enet.fashion_mnist_baseline`,
# and the resulting tensor becomes the model output.
input_layer = K.layers.Input(shape=(28, 28, 1))
output_layer = enet.fashion_mnist_baseline(input_layer)
fashion_mnist_baseline_model = K.Model(
    inputs=input_layer,
    outputs=output_layer)

In [None]:
# Training configuration for the baseline: Adam with its default settings,
# categorical cross-entropy on the one-hot labels, accuracy as the metric.
fashion_mnist_baseline_model.compile(
    optimizer=tf.train.AdamOptimizer(),
    loss='categorical_crossentropy',
    metrics=['accuracy'])

In [None]:
# Print the layer-by-layer summary of the baseline model.
fashion_mnist_baseline_model.summary()

### Run without Data Augmentation

In [None]:
# Baseline run on the raw (non-augmented) training arrays; the test split is
# passed as validation data.
fashion_mnist_baseline_model.fit(
    x=x_train,
    y=y_train,
    batch_size=256,
    epochs=50,
    validation_data=(x_test, y_test))

### Run with Data Augmentation

In [None]:
# Rebuild the baseline model for the augmented run, rebinding the same names
# so the run does not reuse the model object trained above.
# NOTE(review): presumably this yields freshly initialised weights - that
# depends on `enet.fashion_mnist_baseline` creating new layers on each call.
input_layer = K.layers.Input(shape=(28, 28, 1))
output_layer = enet.fashion_mnist_baseline(input_layer)
fashion_mnist_baseline_model = K.Model(
    inputs=input_layer,
    outputs=output_layer)

In [None]:
# Same training configuration as the non-augmented baseline run: default Adam,
# categorical cross-entropy loss, accuracy metric.
fashion_mnist_baseline_model.compile(
    optimizer=tf.train.AdamOptimizer(),
    loss='categorical_crossentropy',
    metrics=['accuracy'])

In [None]:
# Train the baseline on batches produced by the augmentation generator; the
# untouched test split is passed as validation data.
fashion_mnist_baseline_model.fit_generator(
    datagen.flow(x_train, y_train, batch_size=1024),
    epochs=50,
    validation_data=(x_test, y_test))

# Fashion MNIST ENet Model

In [None]:
# Symbolic input for the ENet-style models: 28x28 images with one channel,
# matching the reshape applied to x_train / x_test.
input_layer_enet = tf.keras.layers.Input(shape=(28,28,1))

In [None]:
# ENet-style classifier: feed the shared input tensor through the project
# helper `enet.fashion_mnist_enc` (called with dropout=0.25) and wrap the
# result in a Keras model.
output_layer_enet = enet.fashion_mnist_enc(input_layer_enet, dropout=0.25)
fashion_mnist_model_enet = tf.keras.Model(
    inputs=input_layer_enet,
    outputs=output_layer_enet)

In [None]:
# Print the layer-by-layer summary of the ENet-style model.
fashion_mnist_model_enet.summary()

### Run without Data Augmentation

In [None]:
# Training callbacks: only TensorBoard logging to './logs' is configured here
# (despite the old "lr_scheduler" label, no learning-rate scheduler is set up).
# NOTE(review): the augmented ENet run also logs to './logs', so the two runs
# will share one log directory - consider a per-run subdirectory.
callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs',batch_size=1024)]

In [None]:
# Training configuration for the ENet-style model: Adam with a learning rate
# of 3e-4, categorical cross-entropy loss, accuracy metric.
fashion_mnist_model_enet.compile(
    optimizer=tf.train.AdamOptimizer(learning_rate=3e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'])

In [None]:
# Summarise the ENet-style model that was just compiled. The original cell
# printed the baseline model's summary here, which does not belong to this
# section of the notebook.
fashion_mnist_model_enet.summary()

In [None]:
# ENet-style run on the raw (non-augmented) training arrays, validating on
# the test split and reporting through the callbacks list defined earlier.
fashion_mnist_model_enet.fit(
    x=x_train,
    y=y_train,
    batch_size=1024,
    epochs=50,
    validation_data=(x_test, y_test),
    callbacks=callbacks)

### Run with Data Augmentation

In [None]:
# Build a second ENet-style classifier for the augmented run, reusing the
# input tensor `input_layer_enet`.
# The original cell called `encoders.enet_encoder_mnist`, but no `encoders`
# module is imported in this notebook (NameError). The encoder helpers come
# from `encoders_fun`, imported as `enet`, and the Fashion-MNIST encoder used
# for the non-augmented run is `enet.fashion_mnist_enc`.
output_layer_enet = enet.fashion_mnist_enc(input_layer_enet, dropout=0.20)
mnist_model_enet = tf.keras.Model(inputs=input_layer_enet, outputs=output_layer_enet)

In [None]:
# Print the layer-by-layer summary of the model built for the augmented run.
mnist_model_enet.summary()

In [None]:
# Training callbacks: only TensorBoard logging to './logs' is configured here
# (despite the old "lr_scheduler" label, no learning-rate scheduler is set up).
# NOTE(review): the non-augmented ENet run logs to the same './logs'
# directory - consider a per-run subdirectory to keep the curves separate.
callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs',batch_size=1024)]

In [None]:
# Training configuration for the augmented-run model: Adam with a learning
# rate of 3e-4, categorical cross-entropy loss, accuracy metric.
mnist_model_enet.compile(
    optimizer=tf.train.AdamOptimizer(learning_rate=3e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'])

In [None]:
# Train on batches produced by the augmentation generator for 200 epochs,
# validating on the untouched test split and logging through `callbacks`.
mnist_model_enet.fit_generator(
    datagen.flow(x_train, y_train, batch_size=1024),
    epochs=200,
    validation_data=(x_test, y_test),
    callbacks=callbacks)