In [49]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

In [50]:
inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)

In [51]:
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [52]:
optimizer = keras.optimizers.SGD(learning_rate=1e-3)

In [53]:
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [54]:
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

In [55]:
x_train.shape

(60000, 28, 28)

In [56]:
x_train = np.reshape (x_train, (-1, 784))

In [57]:
x_train.shape

(60000, 784)

In [58]:
x_test = np.reshape(x_test, (-1, 784))

In [59]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

In [60]:
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

In [61]:
# prepare validation dataset
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

In [62]:
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)

In [63]:
x_train.shape

(50000, 784)

In [64]:
x_val.shape

(10000, 784)

In [65]:
import time

In [66]:
epochs = 2
for epoch in range(epochs):
    print("start of epoch %d" % epoch)
    start_time = time.time()
    
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # with statement does not create a scope
        with tf.GradientTape() as tape:
            # forward pass
            logits = model(x_batch_train, training=True)
            # loss
            loss_value = loss_fn(y_batch_train, logits)
            
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        
        # update train metric
        train_acc_metric.update_state(y_batch_train, logits)
        
        if step % 100 == 0:
            print("train loss: %.4f at step: %d" % (float(loss_value), step))
        
        
    # display train metric at end of each epoch
    train_acc = train_acc_metric.result()
    train_acc_metric.reset_states()
    print("train acc over epoch: %.2f" % float(train_acc))
        
    # run a validation loop at the end of each epoch
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("val acc: %.4f" % float(val_acc))
    print("time taken: %.2fs" % (time.time() - start_time))

start of epoch 0
train loss: 154.9669 at step: 0
train loss: 1.2318 at step: 100
train loss: 1.3028 at step: 200
train loss: 0.5447 at step: 300
train loss: 1.6573 at step: 400
train loss: 0.8065 at step: 500
train loss: 0.7988 at step: 600
train loss: 0.6013 at step: 700
train loss: 0.8471 at step: 800
train loss: 0.5211 at step: 900
train acc over epoch: 0.73
val acc: 0.8423
time taken: 13.06s
start of epoch 1
train loss: 0.7701 at step: 0
train loss: 0.6285 at step: 100
train loss: 0.7396 at step: 200
train loss: 0.5205 at step: 300
train loss: 0.5758 at step: 400
train loss: 0.4211 at step: 500
train loss: 0.5970 at step: 600
train loss: 0.6718 at step: 700
train loss: 0.6328 at step: 800
train loss: 0.4723 at step: 900
train acc over epoch: 0.85
val acc: 0.8730
time taken: 13.07s
