In [8]:
import tensorflow as tf
import keras

In [14]:
from tensorflow.keras.datasets import mnist
  
def get_mnist_model():
    """Build a small fully connected classifier for flattened MNIST images.

    Architecture: Input(784) -> Dense(512, relu) -> Dropout(0.5) -> Dense(10, softmax).

    Returns:
        An uncompiled functional ``keras.Model`` taking a batch of 28*28 = 784
        features per example and producing 10 softmax probabilities.
    """
    flat_pixels = keras.Input(shape=(28 * 28,))
    # Layers are listed (and therefore constructed) in forward order, then applied in turn.
    layer_stack = [
        keras.layers.Dense(512, activation="relu"),
        keras.layers.Dropout(0.5),  # only active when the model is called with training=True
        keras.layers.Dense(10, activation="softmax"),
    ]
    x = flat_pixels
    for layer in layer_stack:
        x = layer(x)
    return keras.Model(flat_pixels, x)
    
# Load MNIST, flatten each 28x28 image into a 784-long vector and divide by 255
# so the 8-bit pixel values land in [0, 1]. Using -1 lets reshape infer the
# sample count instead of hard-coding 60000 / 10000.
(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((-1, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((-1, 28 * 28)).astype("float32") / 255

# Hold out the first NUM_VAL_SAMPLES training examples as a validation split.
NUM_VAL_SAMPLES = 10000
train_images, val_images = images[NUM_VAL_SAMPLES:], images[:NUM_VAL_SAMPLES]
train_labels, val_labels = labels[NUM_VAL_SAMPLES:], labels[:NUM_VAL_SAMPLES]
assert len(train_images) == len(train_labels), "train images/labels misaligned"
assert len(val_images) == len(val_labels), "val images/labels misaligned"

In [15]:
# Global training state consumed by train_step() and reset_metrics() below.
model = get_mnist_model()

# Sparse variants: targets are integer class ids, not one-hot vectors.
loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.RMSprop()
metrics = [keras.metrics.SparseCategoricalAccuracy()]
# Running mean of the per-batch loss, reported alongside the metrics.
loss_tracking_metric = keras.metrics.Mean()
 
def train_step(inputs, targets):
    """Run one gradient update on a single batch and return running metrics.

    Relies on the module-level ``model``, ``loss_fn``, ``optimizer``,
    ``metrics`` and ``loss_tracking_metric`` objects.

    Args:
        inputs: One batch of model inputs.
        targets: The matching batch of integer class labels (as required by
            the sparse categorical loss/accuracy).

    Returns:
        dict mapping each metric's ``name`` and ``"loss"`` to its current
        result, accumulated since the last ``reset_metrics()`` call.
    """
    # Record the forward pass so the loss can be differentiated w.r.t. the weights.
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)  # training=True turns Dropout on
        loss = loss_fn(targets, predictions)
    weights = model.trainable_weights
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    # Fold this batch into the stateful metrics, then read their running values.
    for metric in metrics:
        metric.update_state(targets, predictions)
    loss_tracking_metric.update_state(loss)

    logs = {metric.name: metric.result() for metric in metrics}
    logs["loss"] = loss_tracking_metric.result()
    return logs

def reset_metrics():
    """Clear every tracked metric, including the running loss mean.

    Called at the start of each epoch so reported values cover only that epoch.
    """
    for tracked in (*metrics, loss_tracking_metric):
        tracked.reset_state()



In [None]:
# Custom training loop: batch the training split and call train_step per batch.
BATCH_SIZE = 32
epochs = 3

# NOTE(review): the dataset is not shuffled, so every epoch sees batches in the
# same order; consider .shuffle(buffer_size, seed=...) before .batch().
training_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .batch(BATCH_SIZE)
)
for epoch in range(epochs):
    reset_metrics()
    # Initialise per epoch: with an empty dataset the original raised NameError
    # on a fresh kernel, or printed stale `logs` left over from an earlier run.
    logs = {}
    for inputs_batch, targets_batch in training_dataset:
        logs = train_step(inputs_batch, targets_batch)
    print(f"Results at the end of epoch {epoch}")
    for key, value in logs.items():
        print(f"...{key}: {value:.4f}")

Results at the end of epoch 0
...sparse_categorical_accuracy: 0.9201
...loss: 0.2731
