## **Setup**

In [6]:
# Import the necessary packages
import numpy as np

import tensorflow as tf

from sklearn.model_selection import train_test_split

In [16]:
# Every tunable value for the notebook lives here, so a reader can change
# behaviour without hunting through later cells.
config = {
    # Basic information
    "AUTHOR": "Kiernan",

    # Reproducibility: random seed to pass to any stochastic step
    # (data splits, weight init) so Restart & Run All gives the same results.
    "SEED": 42,

    # Training params
    "BATCH_SIZE": 32,
    "EPOCHS": 30,

    # Model params
    "EMBEDDING_SIZE": 32,  # output width of the base (embedding) network
}

## **Loading Data**

In [17]:
# MNIST arrives pre-split into a training pool and a held-out test set.
(X, y), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale raw pixel values from [0, 255] to [-1, 1].
X = (X.astype(np.float32) - 127.5) / 127.5
X_test = (X_test.astype(np.float32) - 127.5) / 127.5

# Append a trailing channel axis so each image is (H, W, 1) for Conv2D.
X = X[..., np.newaxis]
X_test = X_test[..., np.newaxis]

# Hold out a validation set the same size as the test set.
# A fixed random_state makes the split reproducible across kernel restarts;
# stratify=y keeps the per-digit class proportions equal in train and val.
split_seed = config.get("SEED", 42)
X_train, X_val, y_train, y_val = train_test_split(
    X, y,
    test_size=X_test.shape[0],
    shuffle=True,
    random_state=split_seed,
    stratify=y,
)
print(f"Train data shape: {X_train.shape} Val data shape: {X_val.shape} Test data shape: {X_test.shape}")

Train data shape: (50000, 28, 28, 1) Val data shape: (10000, 28, 28, 1) Test data shape: (10000, 28, 28, 1)


## **Create Model**

In [19]:
def create_base(image_shape, embedding_size=None, filters=16, n_conv_blocks=3):
    """Build the convolutional embedding ("base") network.

    The network stacks ``n_conv_blocks`` blocks of
    Conv2D(3x3, padding="same") -> BatchNormalization -> ReLU.
    A 1x1 convolution then projects to ``embedding_size`` channels, and
    global average pooling collapses the spatial dimensions into one vector
    per image.

    Args:
        image_shape: Shape of a single input image without the batch axis,
            e.g. ``X_train.shape[1:]``.
        embedding_size: Length of the output embedding. When None, it is
            read from ``config["EMBEDDING_SIZE"]``.
        filters: Number of filters in each 3x3 convolution.
        n_conv_blocks: Number of Conv-BN-ReLU blocks to stack.

    Returns:
        A ``tf.keras.models.Model`` that maps a batch of images to a batch
        of ``embedding_size``-length vectors.
    """
    if embedding_size is None:
        embedding_size = config["EMBEDDING_SIZE"]

    inputs = tf.keras.layers.Input(shape=image_shape)

    # Repeated feature-extraction blocks; "same" padding keeps H and W unchanged.
    x = inputs
    for _ in range(n_conv_blocks):
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)

    # The 1x1 conv sets the embedding width; pooling averages over space.
    x = tf.keras.layers.Conv2D(embedding_size, (1, 1), padding="same")(x)
    outputs = tf.keras.layers.GlobalAveragePooling2D()(x)
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)

# Build the embedding network from the per-image shape (batch axis dropped)
# and print its layer-by-layer summary.
base = create_base(image_shape=X_train.shape[1:])
base.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 28, 28, 16)        160       
_________________________________________________________________
batch_normalization_6 (Batch (None, 28, 28, 16)        64        
_________________________________________________________________
re_lu_6 (ReLU)               (None, 28, 28, 16)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 28, 28, 16)        2320      
_________________________________________________________________
batch_normalization_7 (Batch (None, 28, 28, 16)        64        
_________________________________________________________________
re_lu_7 (ReLU)               (None, 28, 28, 16)        0   

In [None]:
def create_head(n_classes):
    inputs = tf.keras.layers.Input(shape=(config["EMBEDDING_SIZE"]))
    
    x = tf.keras.layers.Dense(n_classes, activation='softmax')