In [1]:
import numpy as np

import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Flatten, Dense

from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy
from tensorflow.keras.optimizers import Adam

In [2]:
def get_mnist_dataset():
    (train_val_dataset, test_dataset), dataset_info = tfds.load(name = 'mnist',
                                                                shuffle_files = True,
                                                                as_supervised = True,
                                                                split = ['train', 'test'],
                                                                with_info = True)
    number_train_val = dataset_info.splits['train'].num_examples
    
    train_ratio = 0.8
    number_train = int(number_train_val*train_ratio)
    number_val = number_train_val - number_train
    
    train_dataset = train_val_dataset.take(number_train)
    val_dataset = train_val_dataset.skip(number_train).take(number_val)
    
    return train_dataset, val_dataset, test_dataset

train_dataset, val_dataset, test_dataset = get_mnist_dataset()

2022-02-17 00:25:40.174131: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
def standardization(train_dataset, val_dataset, test_dataset, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE):
    def standard(images, labels):
        images = tf.cast(images, tf.float32) / 255.0
        return (images, labels)
    
    train_dataset = train_dataset.map(standard).shuffle(100).batch(TRAIN_BATCH_SIZE)
    val_dataset = val_dataset.map(standard).batch(TEST_BATCH_SIZE)
    test_dataset = test_dataset.map(standard).batch(TEST_BATCH_SIZE)
    
    return train_dataset, val_dataset, test_dataset

train_dataset, val_dataset, test_dataset = standardization(train_dataset, val_dataset, test_dataset, 32, 32)

In [4]:
class MNIST_Classifier(Model):
    def __init__(self):
        super(MNIST_Classifier, self).__init__()
        
        self.flatten = Flatten()
        self.l1 = Dense(64, activation = 'relu')
        self.l2 = Dense(10, activation = 'softmax')
        
    def call(self, x):
        x = self.flatten(x)
        x = self.l1(x)
        x = self.l2(x)
        
        return x

In [5]:
train_loss = Mean()
val_loss = Mean()    
test_loss = Mean()

train_acc = SparseCategoricalAccuracy()
val_acc = SparseCategoricalAccuracy()
test_acc = SparseCategoricalAccuracy()

In [6]:
@tf.function
def trainer():
    global train_dataset, model, loss_object, optimizer
    global train_loss, train_acc
    
    for images, labels in train_dataset:
        with tf.GradientTape() as tape:
            predictions = model(images)
            loss = loss_object(labels, predictions)
            
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        train_loss(loss)
        train_acc(labels, predictions)

In [7]:
@tf.function
def validator():
    global val_dataset, model, loss_object, optimizer
    global val_loss, val_acc
    
    for images, labels in val_dataset:
        predictions = model(images)
        loss = loss_object(labels, predictions)
        
        val_loss(loss)
        val_acc(labels, predictions)

In [8]:
def tester():
    global test_dataset, model, loss_object, optimizer
    global test_loss, test_acc
    
    for images, labels in test_dataset:
        predictions = model(images)
        loss = loss_object(labels, predictions)
        
        test_loss(loss)
        test_acc(labels, predictions)
        
    print('Test Loss: {:.4f}\t Test Accuracy: {:.2f}%'.format(test_loss.result(), test_acc.result()*100))
    # Not use @tf.function if function contains print()

In [9]:
def train_reporter(epoch):
    global train_loss, train_acc
    global val_loss, val_acc
    
    print('Epoch: ', epoch+1, 'Train Loss: {:.4f}\t Train Accuracy: {:.2f}%\t\
    Validation Loss: {:.4f}\t Validation Accuracy: {:.2f}%'\
    .format(train_loss.result(), train_acc.result()*100, val_loss.result(), val_acc.result()*100))
    
    train_loss.reset_states()
    train_acc.reset_states()
    val_loss.reset_states()
    val_acc.reset_states()

In [10]:
EPOCHS = 10
LR = 0.01

model = MNIST_Classifier()
loss_object = SparseCategoricalCrossentropy()
optimizer = Adam(learning_rate = LR)

for epoch in range(EPOCHS):
    trainer()
    validator()
    train_reporter(epoch)
    
tester()

2022-02-17 00:25:43.028154: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


Epoch:  1 Train Loss: 0.2532	 Train Accuracy: 92.34%	    Validation Loss: 0.1793	 Validation Accuracy: 94.72%
Epoch:  2 Train Loss: 0.1631	 Train Accuracy: 95.26%	    Validation Loss: 0.1853	 Validation Accuracy: 94.80%
Epoch:  3 Train Loss: 0.1388	 Train Accuracy: 96.03%	    Validation Loss: 0.2046	 Validation Accuracy: 94.58%
Epoch:  4 Train Loss: 0.1316	 Train Accuracy: 96.36%	    Validation Loss: 0.2053	 Validation Accuracy: 94.90%
Epoch:  5 Train Loss: 0.1228	 Train Accuracy: 96.69%	    Validation Loss: 0.2051	 Validation Accuracy: 95.14%
Epoch:  6 Train Loss: 0.1156	 Train Accuracy: 96.88%	    Validation Loss: 0.2223	 Validation Accuracy: 95.32%
Epoch:  7 Train Loss: 0.1105	 Train Accuracy: 97.05%	    Validation Loss: 0.2230	 Validation Accuracy: 95.24%
Epoch:  8 Train Loss: 0.1013	 Train Accuracy: 97.34%	    Validation Loss: 0.2288	 Validation Accuracy: 95.30%
Epoch:  9 Train Loss: 0.1039	 Train Accuracy: 97.39%	    Validation Loss: 0.2568	 Validation Accuracy: 95.24%
Epoch:  10