In [21]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

In [22]:
inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)

In [23]:
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [24]:
optimizer = keras.optimizers.SGD(learning_rate=1e-3)

In [25]:
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [26]:
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

In [27]:
x_train.shape

(60000, 28, 28)

In [28]:
x_train = np.reshape (x_train, (-1, 784))

In [29]:
x_train.shape

(60000, 784)

In [30]:
x_test = np.reshape(x_test, (-1, 784))

In [31]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

In [32]:
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

In [33]:
# prepare validation dataset
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

In [34]:
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(64)

In [35]:
x_train.shape

(50000, 784)

In [36]:
x_val.shape

(10000, 784)

Any function that takes tensors as input can be compiled into a graph.

In [37]:
@tf.function
def train_step(x, y):
    # with statement does not create a scope
    with tf.GradientTape() as tape:
        # forward pass
        logits = model(x, training=True)
        # loss
        loss_value = loss_fn(y, logits)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # update train metric
    train_acc_metric.update_state(y, logits)
    
    return loss_value

In [38]:
@tf.function
def test_step(x,y):
    val_logits = model(x, training=False)
    # update val metrics
    val_acc_metric.update_state(y, val_logits)

In [39]:
import time

In [40]:
epochs = 2
for epoch in range(epochs):
    print("start of epoch %d" % epoch)
    start_time = time.time()
    
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)
        
        if step % 200 == 0:
            print("train loss: %.4f at step: %d" % (float(loss_value), step))
        
    # display train metric at end of each epoch
    train_acc = train_acc_metric.result()
    train_acc_metric.reset_states()
    print("train acc over epoch: %.2f" % float(train_acc))
        
    # run a validation loop at the end of each epoch
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("val acc: %.4f" % float(val_acc))
    print("time taken: %.2fs" % (time.time() - start_time))

start of epoch 0
train loss: 83.8055 at step: 0
train loss: 0.9717 at step: 200
train loss: 0.8052 at step: 400
train loss: 1.1229 at step: 600
train loss: 1.1742 at step: 800
train acc over epoch: 0.72
val acc: 0.8364
time taken: 9.58s
start of epoch 1
train loss: 0.6489 at step: 0
train loss: 0.4961 at step: 200
train loss: 0.7379 at step: 400
train loss: 1.2128 at step: 600
train loss: 0.3753 at step: 800
train acc over epoch: 0.84
val acc: 0.8708
time taken: 5.87s
