<a href="https://colab.research.google.com/github/rajdeepd/tensorflow_2.0_book_code/blob/master/ch03/custom_TF_training_loops.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Primary imports
import tensorflow as tf
import numpy as np

In [2]:
# Print the TensorFlow version in use so the run can be reproduced
print(tf.__version__)

2.6.0


In [3]:
# Fetch the FashionMNIST train/test splits, then scale the pixel values by 1/255
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

# Show the shapes of all four arrays as a quick sanity check
X_train.shape, X_test.shape, y_train.shape, y_test.shape

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))

In [4]:
# Human-readable class names of the dataset; the list index is the label id
CLASSES = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [5]:
# Cast both image arrays to float32 and give them the (N, 28, 28, 1) layout
# expected by the first Conv2D layer
X_train, X_test = (
    images.astype("float32").reshape(-1, 28, 28, 1)
    for images in (X_train, X_test)
)

In [6]:
# Display the shapes of the train and test label arrays
y_train.shape, y_test.shape

((60000,), (10000,))

In [7]:
# TensorFlow imports
from tensorflow.keras.models import *
from tensorflow.keras.layers import *

In [8]:
# Utility that assembles the basic shallow Convnet used in this notebook
def get_training_model():
    """Build the (uncompiled) shallow Convnet.

    The network takes inputs of shape (28, 28, 1) and stacks two
    Conv2D + MaxPooling2D stages, Dropout(0.2), Flatten, a 128-unit ReLU
    Dense layer and a softmax Dense layer with one unit per entry in CLASSES.

    Returns:
        The freshly constructed Keras Sequential model.
    """
    keras_layers = tf.keras.layers
    return tf.keras.Sequential([
        keras_layers.Conv2D(16, (5, 5), activation="relu",
                            input_shape=(28, 28, 1)),
        keras_layers.MaxPooling2D(pool_size=(2, 2)),
        keras_layers.Conv2D(32, (5, 5), activation="relu"),
        keras_layers.MaxPooling2D(pool_size=(2, 2)),
        keras_layers.Dropout(0.2),
        keras_layers.Flatten(),
        keras_layers.Dense(128, activation="relu"),
        # One output probability per class name
        keras_layers.Dense(len(CLASSES), activation="softmax"),
    ])

In [9]:
# Define the loss function and optimizer
# from_logits=False (the default) because the model's last layer is a softmax
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
# Adam with its default learning rate, written out explicitly
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

In [10]:
# Running means of the per-batch loss, accumulated over an epoch
train_loss = tf.keras.metrics.Mean(name="train_loss")
# Named "valid_loss" so the metric name matches its variable and "valid_acc"
valid_loss = tf.keras.metrics.Mean(name="valid_loss")

# Performance metric: accuracy computed from integer labels and class probabilities
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")
valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="valid_acc")

In [11]:
# Input pipelines, in batches of BATCH_SIZE
BATCH_SIZE = 64
# Number of examples held in the shuffle buffer; a larger buffer mixes better
# at the cost of memory
SHUFFLE_BUFFER = 10000

# Shuffle the training examples so that batches differ from epoch to epoch
# (Dataset.shuffle reshuffles on every iteration by default)
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
)
# The evaluation split keeps its order; shuffling would not change the metrics
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(BATCH_SIZE)

In [12]:
# Train the model: one optimisation step on a single batch
@tf.function
def model_train(features, labels):
    """Run one gradient step on a batch and update the training metrics.

    Args:
        features: batch of inputs passed to the global `model`.
        labels: targets for the batch, passed to `loss_func` and `train_acc`.

    Side effects:
        Applies the gradients to `model.trainable_variables` through
        `optimizer` and accumulates the batch into `train_loss` / `train_acc`.
    """
    # Record the forward pass so gradients can be taken afterwards
    with tf.GradientTape() as tape:
        # training=True puts the model in training mode; without it the
        # Dropout layer is a no-op and provides no regularisation
        predictions = model(features, training=True)
        # Calculate the loss
        loss = loss_func(labels, predictions)
    # Get the gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    # Update the weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Update the loss and accuracy
    train_loss(loss)
    train_acc(labels, predictions)

In [13]:
# Validating the model: evaluate a single batch without touching the weights
@tf.function
def model_validate(features, labels):
    """Evaluate one batch and update the validation metrics.

    Args:
        features: batch of inputs passed to the global `model`.
        labels: targets for the batch, passed to `loss_func` and `valid_acc`.

    Side effects:
        Accumulates the batch into `valid_loss` / `valid_acc`.
    """
    # training=False makes inference mode explicit (Dropout disabled) instead
    # of relying on the implicit default
    predictions = model(features, training=False)
    v_loss = loss_func(labels, predictions)

    valid_loss(v_loss)
    valid_acc(labels, predictions)

In [14]:
# A shallow Convnet
# Bound to the global name `model`, which model_train, model_validate and
# get_sample_predictions look up when they are called
model = get_training_model()

In [15]:
# Grab random images from the test set and predict their classes with the
# model in its current state (the training loop calls this after every epoch)
def get_sample_predictions(num_samples=25):
    """Predict class names for a random sample of test images.

    Args:
        num_samples: number of distinct test images to draw; capped at the
            number of images in X_test.

    Returns:
        A tuple (images, predictions): `images` holds the sampled inputs
        stacked along the first axis as (k, 28, 28, 1), and `predictions` is
        the list of k predicted class names in the same order.
    """
    sample_count = min(num_samples, X_test.shape[0])
    if sample_count <= 0:
        # Nothing to sample; skip the forward pass
        return X_test[:0], []

    # replace=False so the same image is never drawn twice
    random_indices = np.random.choice(X_test.shape[0], sample_count, replace=False)
    images = X_test[random_indices].reshape(-1, 28, 28, 1)

    # One batched forward pass in inference mode instead of one call per image
    probabilities = model(images, training=False).numpy()
    # argmax over the class axis yields one label id per image; each id is a
    # NumPy scalar, so int() is a plain scalar conversion
    predictions = [CLASSES[int(class_id)]
                   for class_id in np.argmax(probabilities, axis=1)]

    # Logging to Weights & Biases is disabled; to enable it, import wandb,
    # start a run and uncomment:
    #wandb.log({"predictions": [wandb.Image(image, caption=prediction)
    #                           for (image, prediction) in zip(images, predictions)]})

    return images, predictions

In [16]:
# Train the model for EPOCHS epochs
EPOCHS = 5

for epoch in range(EPOCHS):
    # Start every epoch from a clean metric state. Resetting here, rather than
    # after logging, means an interrupted and re-run cell cannot leak stale
    # statistics into the next epoch. reset_state() replaces the deprecated
    # reset_states() alias.
    for metric in (train_loss, train_acc, valid_loss, valid_acc):
        metric.reset_state()

    # Run the model through train and test sets respectively
    for (features, labels) in train_ds:
        model_train(features, labels)

    for test_features, test_labels in test_ds:
        model_validate(test_features, test_labels)

    # Grab the results accumulated over this epoch
    (loss, acc) = train_loss.result(), train_acc.result()
    (val_loss, val_acc) = valid_loss.result(), valid_acc.result()

    # Local logging
    template = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"
    print(template.format(epoch + 1, loss, acc, val_loss, val_acc))

    # Sample test predictions with the model as it stands after this epoch
    get_sample_predictions()

Epoch 1, loss: 0.543, acc: 0.805, val_loss: 0.422, val_acc: 0.849
Epoch 2, loss: 0.355, acc: 0.873, val_loss: 0.373, val_acc: 0.865
Epoch 3, loss: 0.305, acc: 0.889, val_loss: 0.346, val_acc: 0.876
Epoch 4, loss: 0.275, acc: 0.899, val_loss: 0.332, val_acc: 0.880
Epoch 5, loss: 0.252, acc: 0.907, val_loss: 0.319, val_acc: 0.884
