# Packages

In [1]:
import numpy as np

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Flatten, Dense, Activation
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy

# Utils

In [2]:
def get_mnist_ds(train_ratio=0.8):
    """Load MNIST and split the official train split into train/validation parts.

    The split uses TFDS sub-split slicing (``train[:N]`` / ``train[N:]``), which
    selects examples by their position inside the split. ``take``/``skip`` on a
    dataset loaded with ``shuffle_files=True`` instead depends on the order in
    which (possibly re-shuffled) shard files are read. When a split spans
    several shards, the "train" and "validation" parts could then overlap from
    one epoch to the next.

    Args:
        train_ratio: fraction of the official train split used for training;
            the remainder becomes the validation set. Must be in (0, 1).
            The default 0.8 gives the same sizes as before.

    Returns:
        (train_ds, validation_ds, test_ds): unbatched ``(image, label)``
        datasets (``as_supervised=True``).

    Raises:
        ValueError: if ``train_ratio`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")

    builder = tfds.builder("mnist")
    builder.download_and_prepare()

    n_train_validation = builder.info.splits["train"].num_examples
    n_train = int(n_train_validation * train_ratio)

    # Everything after the first n_train examples is validation, so no
    # separate n_validation count is needed.
    train_ds, validation_ds, test_ds = builder.as_dataset(
        split=[f"train[:{n_train}]", f"train[{n_train}:]", "test"],
        shuffle_files=True,
        as_supervised=True,
    )
    return train_ds, validation_ds, test_ds

def standardization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE):
    """Scale images to [0, 1] and batch the global train/validation/test datasets.

    Rebinds the module-level ``train_ds``, ``validation_ds`` and ``test_ds``.
    It must therefore be called exactly once on freshly loaded datasets
    (e.g. right after ``get_mnist_ds()``). Calling it a second time would
    rescale and re-batch already-processed data.

    Only the training set is shuffled. Shuffling the validation/test sets does
    not change which examples are evaluated. It only costs a 1000-element
    buffer, and because the loss metric averages per-batch values, it can make
    the reported evaluation loss vary between runs.

    Args:
        TRAIN_BATCH_SIZE: batch size for the training dataset.
        TEST_BATCH_SIZE: batch size for the validation and test datasets.
    """
    global train_ds, validation_ds, test_ds

    def stnd(images, labels):
        # Assumes 8-bit pixel values (0..255) -- TODO confirm for other datasets.
        images = tf.cast(images, tf.float32) / 255.
        return images, labels

    train_ds = train_ds.map(stnd).shuffle(1000).batch(TRAIN_BATCH_SIZE)
    validation_ds = validation_ds.map(stnd).batch(TEST_BATCH_SIZE)
    test_ds = test_ds.map(stnd).batch(TEST_BATCH_SIZE)


In [3]:
class MNIST_Classifier(Model):
    """Fully connected classifier: Flatten -> Dense(relu) -> Dense(softmax).

    The output layer applies softmax, so the model returns class
    probabilities. Pair it with a loss that expects probabilities, such as
    ``SparseCategoricalCrossentropy()`` with ``from_logits=False``.

    Args:
        hidden_units: width of the hidden ReLU layer (default 64).
        num_classes: number of output classes (default 10).
        **kwargs: forwarded unchanged to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, hidden_units=64, num_classes=10, **kwargs):
        super().__init__(**kwargs)

        self.flatten = Flatten()
        self.d1 = Dense(hidden_units, activation="relu")
        self.d2 = Dense(num_classes, activation="softmax")

    def call(self, x):
        """Flatten each example, apply the hidden layer, return class probabilities."""
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

In [4]:
def load_metrics():
    """(Re)create the loss and accuracy accumulators for every data split.

    Binds six module-level names:

    * ``train_loss``, ``validation_loss``, ``test_loss``: ``Mean`` trackers
      fed with per-batch loss values.
    * ``train_acc``, ``validation_acc``, ``test_acc``:
      ``SparseCategoricalAccuracy`` trackers.

    Calling it again replaces the objects, discarding accumulated state.
    """
    global train_loss, validation_loss, test_loss
    global train_acc, validation_acc, test_acc

    # Loss trackers are built first, then accuracy trackers, each in
    # train / validation / test order.
    train_loss, validation_loss, test_loss = Mean(), Mean(), Mean()
    train_acc, validation_acc, test_acc = (
        SparseCategoricalAccuracy(),
        SparseCategoricalAccuracy(),
        SparseCategoricalAccuracy(),
    )

In [5]:
@tf.function
def trainer():
    """Run one epoch of mini-batch gradient descent over the global ``train_ds``.

    Reads the module-level ``model``, ``optimizer``, ``loss_object``,
    ``train_loss`` and ``train_acc``. Per-batch loss and accuracy are
    accumulated into the two metrics; the caller is responsible for resetting
    them between epochs.

    NOTE: this is a ``tf.function``, so those globals are captured when the
    function is first traced. Rebinding them later (e.g. re-running the model
    or initialization cells) is not seen by the existing trace. Re-run the
    cell that defines this function, or restart the kernel, after doing so.
    """
    for x, y in train_ds:
        with tf.GradientTape() as tape:
            # training=True is the Keras convention for custom training loops;
            # it starts to matter once Dropout/BatchNorm-style layers are added.
            predictions = model(x, training=True)
            loss = loss_object(y, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)
        train_acc(y, predictions)

def _evaluate(dataset, loss_metric, acc_metric):
    """Accumulate loss/accuracy of the global ``model`` over ``dataset``; no weight updates.

    Shared by ``validation`` and ``tester``. It is meant to be called from
    inside a ``tf.function``, where AutoGraph converts the dataset loop.
    """
    for x, y in dataset:
        predictions = model(x, training=False)
        loss = loss_object(y, predictions)

        loss_metric(loss)
        acc_metric(y, predictions)


@tf.function
def validation():
    """Evaluate ``model`` on ``validation_ds`` into ``validation_loss`` / ``validation_acc``.

    Globals are captured at first trace: rebinding ``model`` or the dataset
    later needs this cell to be re-run. The caller resets the metrics.
    """
    _evaluate(validation_ds, validation_loss, validation_acc)


@tf.function
def tester():
    """Evaluate ``model`` on ``test_ds`` into ``test_loss`` / ``test_acc``.

    Globals are captured at first trace: rebinding ``model`` or the dataset
    later needs this cell to be re-run. The caller resets the metrics.
    """
    _evaluate(test_ds, test_loss, test_acc)

# Data Preparation

In [6]:
TRAIN_BATCH_SIZE = 16   # examples per training step
TEST_BATCH_SIZE = 32    # examples per validation/test step

# Seed TensorFlow's global RNG before the datasets and the model are created,
# so shuffle order and weight initialization are more repeatable across
# "Restart & Run All". TFDS file shuffling uses its own seed and is not covered.
SEED = 42
tf.random.set_seed(SEED)

In [7]:
# Load the raw splits, then scale and batch them. `standardization` rebinds
# the three dataset globals, so run this cell as a unit.
(train_ds, validation_ds, test_ds) = get_mnist_ds()
standardization(TRAIN_BATCH_SIZE=TRAIN_BATCH_SIZE, TEST_BATCH_SIZE=TEST_BATCH_SIZE)

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


# Model

In [None]:
# Global model instance read by trainer(), validation() and tester().
model = MNIST_Classifier()

# Initialization

In [None]:
# Training hyper-parameters.
EPOCHS, LR = 10, 0.001

# The model ends in softmax, so the loss consumes probabilities (not logits).
loss_object = SparseCategoricalCrossentropy(from_logits=False)
optimizer = SGD(learning_rate=LR)

# Fresh train/validation/test metric accumulators (bound as globals).
load_metrics()

# Train & Test

In [None]:
for epoch in range(EPOCHS):
    # Reset at the start of every epoch, so a previously interrupted run of
    # this cell cannot leak partial statistics into the first epoch shown.
    for metric in (train_loss, train_acc, validation_loss, validation_acc):
        metric.reset_states()

    trainer()
    validation()

    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print(f'train_loss: {(train_loss.result()):.3f}, train_acc: {(train_acc.result() * 100):.3f}')
    print(f'validation_loss: {(validation_loss.result()):.3f}, validation_acc: {(validation_acc.result() * 100):.3f}')

train_loss: 0.251, train_acc: 92.852
validation_loss: 0.262, validation_acc: 92.725
train_loss: 0.246, train_acc: 93.004
validation_loss: 0.260, validation_acc: 92.783
train_loss: 0.243, train_acc: 93.094
validation_loss: 0.257, validation_acc: 92.817
train_loss: 0.240, train_acc: 93.165
validation_loss: 0.254, validation_acc: 92.900
train_loss: 0.237, train_acc: 93.262
validation_loss: 0.252, validation_acc: 92.942
train_loss: 0.234, train_acc: 93.298
validation_loss: 0.250, validation_acc: 93.058
train_loss: 0.231, train_acc: 93.388
validation_loss: 0.248, validation_acc: 93.092
train_loss: 0.229, train_acc: 93.463
validation_loss: 0.245, validation_acc: 93.167
train_loss: 0.226, train_acc: 93.569
validation_loss: 0.243, validation_acc: 93.208
train_loss: 0.224, train_acc: 93.650
validation_loss: 0.241, validation_acc: 93.267


In [None]:
# Clear anything accumulated by an earlier run of this cell (e.g. before more
# training), so the reported numbers describe the current model only.
test_loss.reset_states()
test_acc.reset_states()

tester()
print(f'Test_loss: {(test_loss.result()):.3f}, test_acc: {(test_acc.result() * 100):.3f}')

Test_loss: 0.218, test_acc: 93.760
