In [1]:
import os

import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
# Turn off TFDS progress bars so notebook cell output stays short.
tfds.disable_progress_bar()

In [None]:
class MNIST:
    """Small CNN classifier for MNIST together with its train/test input pipelines.

    The model takes raw ``uint8`` images of shape (28, 28, 1) and rescales them to
    [0, 1] inside the graph (see ``preprocess_fn``), so the datasets produced by
    ``_prepare_dataset`` can be fed to it without any host-side preprocessing.

    Args:
        export_path: Destination for an exported model. Stored as
            ``self._export_path``; not read by any method defined here yet.
        buffer_size: Shuffle buffer size for the training dataset.
        batch_size: Batch size for both the training and test datasets.
        learning_rate: Learning rate for the Adam optimizer.
        epochs: Number of training epochs. Stored as ``self._epochs``; not read
            by any method defined here yet.

    Attributes:
        train_dataset: Shuffled, batched ``(image, label)`` training dataset.
        test_dataset: Batched ``(image, label)`` test dataset.
    """

    def __init__(self, export_path, buffer_size=1000, batch_size=32, learning_rate=1e-3, epochs=10):
        self._export_path = export_path
        self._buffer_size = buffer_size
        self._batch_size = batch_size
        self._learning_rate = learning_rate
        self._epochs = epochs

        self._build_model()
        self.train_dataset, self.test_dataset = self._prepare_dataset()

    def preprocess_fn(self, x):
        """Cast ``x`` to ``tf.float32`` and scale pixel values from [0, 255] to [0, 1]."""
        return tf.cast(x, dtype=tf.float32) / 255.0

    def _build_model(self):
        """Build and compile the CNN, storing it as ``self._model``.

        The final layer emits raw logits (no softmax), so the loss is configured
        with ``from_logits=True``.
        """
        self._model = tf.keras.models.Sequential([
            tf.keras.layers.Input(shape=(28, 28, 1), dtype=tf.uint8),
            # Normalization lives inside the model so it accepts raw uint8 pixels.
            tf.keras.layers.Lambda(lambda x: self.preprocess_fn(x)),
            tf.keras.layers.Conv2D(8, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(),
            tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
            tf.keras.layers.MaxPool2D(),
            tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation='relu'),
            # 10 unnormalized class scores (logits).
            tf.keras.layers.Dense(10)
        ])
        optimizer = tf.keras.optimizers.Adam(learning_rate=self._learning_rate)
        # Must match the logits output of the last Dense layer.
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self._model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

    def _prepare_dataset(self):
        """Load MNIST via TFDS and build batched train/test pipelines.

        Data is cached under ``<cwd>/../tmp2``.

        Returns:
            A ``(train_dataset, test_dataset)`` tuple of ``tf.data.Dataset`` objects
            yielding ``(image, label)`` batches. Only the training set is shuffled.
        """
        data_dir = os.path.join(os.getcwd(), "..", "tmp2")

        # Request the two splits separately; ``TRAIN + TEST`` would merge them.
        train_raw, test_raw = tfds.load(
            "mnist",
            split=["train", "test"],
            data_dir=data_dir,
            as_supervised=True,
        )

        train_dataset = train_raw.shuffle(self._buffer_size).batch(self._batch_size)
        test_dataset = test_raw.batch(self._batch_size)
        return train_dataset, test_dataset

In [None]:
Layer (type)                 Output Shape              Param #   
=================================================================
lambda_1 (Lambda)            (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 8)         80        
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 14, 14, 8)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 14, 14, 16)        1168      
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 7, 7, 16)          0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 7, 7, 32)          4640      
_________________________________________________________________
flatten_1 (Flatten)          (None, 1568)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               200832    
_________________________________________________________________
dense_3 (Dense)              (None, 10)                1290      
=================================================================
Total params: 208,010
Trainable params: 208,010
Non-trainable params: 0
_________________________________________________________________