In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
import math

(train, test), info = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,  # yield (image, label) tuples instead of feature dicts
    with_info=True,      # also return the dataset's metadata object
)

In [None]:
# Bare last expression: Jupyter displays the dataset metadata as the cell output.
info

1. How many train/test images are there? — 60,000/10,000
2. What is the image shape? — (28, 28, 1)
3. What range are the pixel values in? — [0, 255]


In [None]:
# show_examples draws the grid and also returns the figure (per the tfds docs);
# the trailing ';' keeps Jupyter from rendering that returned figure a second time.
tfds.show_examples(train, info);

In [None]:
def preprocess(data, batch_size=32, shuffle_buffer=1000, seed=42):
    """Turn a raw (image, label) MNIST dataset into training-ready batches.

    Each image is flattened to a 784-vector, cast to float32 and rescaled via
    ``img / 128 - 1``, which maps [0, 255] to [-1, 0.9921875]. Each integer
    label is one-hot encoded over 10 classes. The result is cached, shuffled,
    batched and prefetched.

    Args:
        data: tf.data.Dataset of (image, label) pairs, as returned by
            ``tfds.load(..., as_supervised=True)``. The reshape requires each
            image to hold exactly 28*28 values.
        batch_size: examples per batch (default 32).
        shuffle_buffer: size of the shuffle buffer (default 1000).
        seed: shuffle seed, for a reproducible order (default 42).

    Returns:
        A batched tf.data.Dataset of (float32[batch, 784], float32[batch, 10]).
    """
    def _to_features(img, target):
        # Flatten, cast and rescale in one map: a single pipeline stage
        # instead of four chained ones.
        img = tf.reshape(img, (28**2,))
        img = tf.cast(img, tf.float32)
        img = img / 128. - 1.
        return img, tf.one_hot(target, depth=10)

    data = data.map(_to_features, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache the deterministic per-example work, and shuffle *after* the cache
    # so the order is reshuffled on each pass over the data.
    data = data.cache()
    data = data.shuffle(shuffle_buffer, seed=seed)
    data = data.batch(batch_size)
    # Let tf.data tune how many batches to prepare ahead of the consumer.
    data = data.prefetch(tf.data.AUTOTUNE)
    return data

# Rebind both raw splits to their preprocessed, batched versions.
# NOTE: not idempotent; re-running this cell would apply `preprocess` to the
# already-batched datasets.
train, test = train.apply(preprocess), test.apply(preprocess)

In [None]:
class Affine(tf.keras.layers.Layer):
    """Fully connected layer computing ``activation(x @ W + b)``.

    Weights are created lazily in ``build`` once the input width is known.
    ``W`` uses Glorot/Xavier-uniform initialisation, i.e. U(-limit, limit)
    with ``limit = sqrt(6 / (fan_in + fan_out))``. ``b`` starts at zero.

    Args:
        n_output: number of output units.
        activation: callable applied to the affine output, or None for a
            linear (identity) layer, e.g. to emit logits.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name=``).
    """
    # Deliberately no class-level ``name`` attribute: ``Layer.name`` is a
    # property, and a class attribute would shadow it. That would give every
    # instance the same name and defeat ``name=`` in kwargs.

    def __init__(self, n_output, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.n_output = n_output
        self.activation = activation if activation is not None else tf.identity

    def build(self, n_input):
        """Create ``W`` of shape (n_input[-1], n_output) and ``b`` of shape (1, n_output).

        ``n_input`` is the input *shape* Keras passes on the first call. Only
        its last dimension (the feature width) is used.
        """
        self.n_input = n_input[-1]
        # Registered through add_weight so Keras tracks them as trainable weights.
        self.W = self.add_weight(
            name='W',
            shape=(self.n_input, self.n_output),
            initializer='glorot_uniform',
            trainable=True,
        )
        self.b = self.add_weight(
            name='b',
            shape=(1, self.n_output),
            initializer='zeros',
            trainable=True,
        )
        super().build(n_input)

    def call(self, x):
        # Keras traces ``call`` as part of the caller's graph, so a
        # @tf.function decorator here is unnecessary.
        return self.activation(x @ self.W + self.b)

In [None]:
class Vanilla(tf.keras.Model):
    """Plain multilayer perceptron built from a stack of ``Affine`` layers.

    The model owns its optimizer (Adam), its loss (categorical cross-entropy
    on logits) and two running metrics (mean loss and categorical accuracy).
    Its custom ``train_step`` / ``test_step`` take the input batch ``X`` and
    one-hot targets ``T`` as separate arguments.

    Args:
        sizes: output width of each layer, in order.
        activations: one activation per layer (paired with ``sizes`` via zip).
            Use None for the last layer so the model outputs logits.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, sizes, activations, **kwargs):
        super().__init__(**kwargs)
        self.layers_list = [Affine(size, activation) for size, activation in zip(sizes, activations)]
        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_metric = tf.keras.metrics.Mean(name="loss")
        self.accuracy_metric = tf.keras.metrics.CategoricalAccuracy(name="acc")
        # from_logits=True: the loss applies softmax itself, so call() returns raw logits.
        self.loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

    def reset_metrics(self):
        """Clear the accumulated state of every tracked metric."""
        for metric in self.metrics:
            metric.reset_state()

    def __iter__(self):
        # Iterate the layer stack built in __init__, in construction order,
        # rather than relying on Keras' generic ``self.layers`` view.
        return iter(self.layers_list)

    @tf.function
    def call(self, x):
        for layer in self:
            x = layer(x)
        return x

    def _accumulate_metrics(self, L, T, Y_logit):
        """Feed one batch's loss and predictions into the running metrics."""
        self.loss_metric.update_state(L)
        # Softmax does not change the argmax; it turns logits into probabilities.
        self.accuracy_metric.update_state(T, tf.nn.softmax(Y_logit))

    def _current_metrics(self):
        """Current metric values as ``{metric name: scalar tensor}``."""
        return {metric.name: metric.result() for metric in self.metrics}

    @tf.function
    def train_step(self, X, T):
        """One optimisation step on batch ``(X, T)``; returns the running metrics."""
        with tf.GradientTape() as tape:
            Y_logit = self(X)
            L = self.loss(T, Y_logit)
        # Differentiate and update the same variable list.
        variables = self.trainable_variables
        gradient = tape.gradient(L, variables)
        self.optimizer.apply_gradients(zip(gradient, variables))
        self._accumulate_metrics(L, T, Y_logit)
        # Return tensors, not Python floats: float() cannot convert a symbolic
        # tensor inside tf.function. Callers convert the eager results.
        return self._current_metrics()

    @tf.function
    def test_step(self, X, T):
        """Evaluate batch ``(X, T)`` without updating weights; returns the running metrics."""
        Y_logit = self(X)
        L = self.loss(T, Y_logit)
        self._accumulate_metrics(L, T, Y_logit)
        return self._current_metrics()

In [None]:
# Three ReLU hidden layers, then a linear 10-unit output layer (logits;
# the model's loss is configured with from_logits=True).
model = Vanilla([128, 64, 32, 10], [tf.nn.relu, tf.nn.relu, tf.nn.relu, None])
# Create the weights with one dummy 784-wide batch. It is float32, the same
# dtype `preprocess` feeds the model, so no implicit float64 -> float32 cast
# is involved.
model(np.zeros((1, 28**2), dtype=np.float32))
model.summary()

In [None]:
import datetime


# DEFINE PATHS
# DEFINE PATHS: one timestamped run directory per execution, with separate
# train/ and val/ subdirectories so TensorBoard can overlay the two curves.
current_time = f"{datetime.datetime.now():%Y%m%d-%H%M%S}"
train_log_path = f"logs/{current_time}/train"
val_log_path = f"logs/{current_time}/val"

# CONSTRUCT WRITERS: one summary file writer per log directory.
writer_train, writer_val = (
    tf.summary.create_file_writer(log_path)
    for log_path in (train_log_path, val_log_path)
)

In [None]:
from tqdm.notebook import tqdm


def training(epochs, model, train, test):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test`` after each.

    Each epoch runs one pass of ``model.train_step`` over ``train``, then one
    pass of ``model.test_step`` over ``test``. Metrics accumulate over the
    whole pass, and the inner progress bar shows their running values. Each
    pass's final values are written once to TensorBoard at ``step=epoch``,
    through the module-level ``writer_train`` / ``writer_val`` created in an
    earlier cell.

    Args:
        epochs: number of epochs to run.
        model: object exposing ``train_step(X, T)``, ``test_step(X, T)``
            (each returning a ``{name: value}`` dict), ``metrics`` and
            ``reset_metrics()``, as ``Vanilla`` does.
        train: iterable of ``(X, T)`` training batches.
        test: iterable of ``(X, T)`` evaluation batches.
    """
    with tqdm(range(epochs), leave=True) as out_bar:
        for epoch in out_bar:
            out_bar.set_description('TRAINING')
            _run_epoch(model, model.train_step, train, writer_train, epoch)
            out_bar.set_description('TESTING')
            _run_epoch(model, model.test_step, test, writer_val, epoch)


def _run_epoch(model, step_fn, data, writer, epoch):
    """Run ``step_fn`` over every batch of ``data`` and log the pass's metrics to ``writer``."""
    # Start from clean accumulators so this pass is not mixed with the previous one.
    model.reset_metrics()
    with tqdm(data, leave=False) as in_bar:
        for X, T in in_bar:
            metrics = step_fn(X, T)
            # float() accepts both eager scalar tensors and plain Python numbers.
            in_bar.set_postfix({key: float(value) for key, value in metrics.items()})
    # Write one point per metric per epoch. Logging every batch at the same
    # step=epoch would pile thousands of values onto a single step.
    with writer.as_default():
        for metric in model.metrics:
            tf.summary.scalar(metric.name, metric.result(), step=epoch)
    model.reset_metrics()

In [None]:
# Load the TensorBoard notebook extension and embed a TensorBoard view of
# everything under logs/, the root directory the summary writers log into.
%load_ext tensorboard
%tensorboard --logdir logs/

In [None]:
EPOCHS = 1000  # number of passes over the training split
training(epochs=EPOCHS, model=model, train=train, test=test)

In [None]:
# Save the trained weights in TensorFlow checkpoint format under this path prefix.
WEIGHTS_PATH = 'mnist_vanilla_v1'
model.save_weights(WEIGHTS_PATH, save_format='tf')

In [None]:
# Quick accuracy check on 200 test images.
# `test` is already batched (32 per batch) by `preprocess`, so unbatch it
# before re-batching. Otherwise `.batch(200)` would stack whole batches into a
# rank-3 (200, 32, 784) tensor, and argmax(axis=1) would run over the sample
# axis instead of the class axis.
X, T = batch = next(iter(test.unbatch().batch(200)))
Y_logits = model(X)
Y = tf.nn.softmax(Y_logits)
# axis=-1 is the class axis regardless of the batch's rank.
print(f"{round(np.mean(np.argmax(Y, axis=-1) == np.argmax(T, axis=-1)) * 100, 1)}%")

In [8]:
# Console-style tqdm, imported under its own name so this demo does not
# rebind the notebook `tqdm` that `training` looks up when it is called.
from tqdm import tqdm as console_tqdm
from time import sleep

# Demo of nested progress bars: the inner bars (leave=False) are cleared
# when they finish, while the outer bar (leave=True) stays on screen.
for _ in console_tqdm(range(12), leave=True):
    sleep(0.1)
    for _ in console_tqdm(range(12), leave=False):
        sleep(0.01)

  0%|                                                                            | 0/12 [00:00<?, ?it/s]
  0%|                                                                            | 0/12 [00:00<?, ?it/s][A
 83%|███████████████████████████████████████████████████████▊           | 10/12 [00:00<00:00, 98.79it/s][A
  8%|█████▋                                                              | 1/12 [00:00<00:02,  4.29it/s][A
  0%|                                                                            | 0/12 [00:00<?, ?it/s][A
 83%|███████████████████████████████████████████████████████▊           | 10/12 [00:00<00:00, 98.64it/s][A
 17%|███████████▎                                                        | 2/12 [00:00<00:02,  4.28it/s][A
  0%|                                                                            | 0/12 [00:00<?, ?it/s][A
 83%|███████████████████████████████████████████████████████▊           | 10/12 [00:00<00:00, 98.92it/s][A
 25%|█████████████████         