In [58]:
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
(train_x, train_y), (test_x, test_y) = fashion_mnist.load_data()

In [59]:
print(train_x.shape)
print(train_y.shape)

(60000, 28, 28)
(60000,)


In [60]:
from tensorflow.keras import layers, metrics, optimizers, Model

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = layers.Flatten(input_shape=(28, 28))
        self.dense1 = layers.Dense(32, activation='relu')
        self.dense2 = layers.Dense(10, activation='softmax')
    def call(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)

In [61]:
train_x, test_x = train_x / 255., test_x / 255.

In [62]:
model = MyModel()

In [63]:
from tensorflow.keras import losses

loss_obj = losses.SparseCategoricalCrossentropy()
optimizer = optimizers.Adam()

In [64]:
train_loss = metrics.Mean(name='train_loss')
train_accuracy = metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = metrics.Mean(name='test_loss')
test_accuracy = metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [65]:
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_obj(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

In [66]:
@tf.function
def test_step(images, labels):
    predictions = model(images, training=False)
    loss = loss_obj(labels, predictions)

    test_loss(loss)
    test_accuracy(labels, predictions)

In [67]:
train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((test_x, test_y)).batch(32)

In [68]:
epochs = 5
for epoch in range(epochs):
    print('Epoch ==>', epoch)
    train_accuracy.reset_states()
    train_loss.reset_states()
    test_accuracy.reset_states()
    test_loss.reset_states()

    for images, labels in train_ds:
        train_step(images, labels)
    for images, labels in test_ds:
        test_step(images, labels)
    print('loss:', train_loss.result(), 'acc:', train_accuracy.result(), 'val_loss:', test_loss.result(), 'val_acc:', test_accuracy.result())

Epoch ==> 0
loss: tf.Tensor(0.54599804, shape=(), dtype=float32) acc: tf.Tensor(0.8120833, shape=(), dtype=float32) val_loss: tf.Tensor(0.4647985, shape=(), dtype=float32) val_acc: tf.Tensor(0.8318, shape=(), dtype=float32)
Epoch ==> 1
loss: tf.Tensor(0.4134953, shape=(), dtype=float32) acc: tf.Tensor(0.85405, shape=(), dtype=float32) val_loss: tf.Tensor(0.43390346, shape=(), dtype=float32) val_acc: tf.Tensor(0.8418, shape=(), dtype=float32)
Epoch ==> 2
loss: tf.Tensor(0.38149646, shape=(), dtype=float32) acc: tf.Tensor(0.8622, shape=(), dtype=float32) val_loss: tf.Tensor(0.40960747, shape=(), dtype=float32) val_acc: tf.Tensor(0.853, shape=(), dtype=float32)
Epoch ==> 3
loss: tf.Tensor(0.3575212, shape=(), dtype=float32) acc: tf.Tensor(0.8717333, shape=(), dtype=float32) val_loss: tf.Tensor(0.37783873, shape=(), dtype=float32) val_acc: tf.Tensor(0.8646, shape=(), dtype=float32)
Epoch ==> 4
loss: tf.Tensor(0.33981755, shape=(), dtype=float32) acc: tf.Tensor(0.8767, shape=(), dtype=float