In [0]:
# Select the TensorFlow 2.0 runtime
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [0]:
# Install Weights and Biases (WnB)
!pip install wandb

In [0]:
# Primary imports
import tensorflow as tf
import numpy as np
import wandb

In [0]:
# Authorize Weights and Biases
!wandb login

In [0]:
# Intialize WnB with a project name of your choice
wandb.init(project="custom_training_loops_tf")

W&B Run: https://app.wandb.ai/sayakpaul/custom_training_loops_tf/runs/3c1cfpm9

In [0]:
# Load the FashionMNIST dataset, scale the pixel values
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train/255.
X_test = X_test/255.

X_train.shape, X_test.shape, y_train.shape, y_test.shape

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))

In [0]:
# Define the labels of the dataset
CLASSES=["T-shirt/top","Trouser","Pullover","Dress","Coat",
        "Sandal","Shirt","Sneaker","Bag","Ankle boot"]

In [0]:
# Change the pixel values to float32 and reshape input data
X_train = X_train.astype("float32").reshape(-1, 28, 28, 1)
X_test = X_test.astype("float32").reshape(-1, 28, 28, 1)

In [0]:
y_train.shape, y_test.shape

((60000,), (10000,))

In [0]:
# TensorFlow imports
from tensorflow.keras.models import *
from tensorflow.keras.layers import *

In [0]:
# Define utility function for building a basic shallow Convnet 
def get_training_model():
    model = Sequential()
    model.add(Conv2D(16, (5, 5), activation="relu",
        input_shape=(28, 28,1)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(32, (5, 5), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation="relu"))
    model.add(Dense(len(CLASSES), activation="softmax"))
    
    return model

In [0]:
# Define loass function and optimizer
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

In [0]:
# Average the loss across the batch size within an epoch
train_loss = tf.keras.metrics.Mean(name="train_loss")
valid_loss = tf.keras.metrics.Mean(name="test_loss")

# Specify the performance metric
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="train_acc")
valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name="valid_acc")

In [0]:
# Batches of 64
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(64)
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(64)

In [0]:
# Train the model
@tf.function
def model_train(features, labels):
    # Define the GradientTape context
    with tf.GradientTape() as tape:
        # Get the probabilities
        predictions = model(features)
        # Calculate the loss
        loss = loss_func(labels, predictions)
    # Get the gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    # Update the weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Update the loss and accuracy
    train_loss(loss)
    train_acc(labels, predictions)

In [0]:
# Validating the model
@tf.function
def model_validate(features, labels):
    predictions = model(features)
    v_loss = loss_func(labels, predictions)

    valid_loss(v_loss)
    valid_acc(labels, predictions)

In [0]:
# A shallow Convnet
model = get_training_model()

In [0]:
# Grab random images from the test and make predictions using 
# the model *while it is training* and log them using WnB
def get_sample_predictions():
    predictions = []
    images = []
    random_indices = np.random.choice(X_test.shape[0], 25)
    for index in random_indices:
        image = X_test[index].reshape(1, 28, 28, 1)
        prediction = np.argmax(model(image).numpy(), axis=1)
        prediction = CLASSES[int(prediction)]
        
        images.append(image)
        predictions.append(prediction)
    
    wandb.log({"predictions": [wandb.Image(image, caption=prediction) 
                               for (image, prediction) in zip(images, predictions)]})

In [0]:
# Train the model for 5 epochs
for epoch in range(5):
    # Run the model through train and test sets respectively
    for (features, labels) in train_ds:
        model_train(features, labels)

    for test_features, test_labels in test_ds:
        model_validate(test_features, test_labels)
        
    # Grab the results
    (loss, acc) = train_loss.result(), train_acc.result()
    (val_loss, val_acc) = valid_loss.result(), valid_acc.result()
    
    # Clear the current state of the metrics
    train_loss.reset_states(), train_acc.reset_states()
    valid_loss.reset_states(), valid_acc.reset_states()
    
    # Local logging
    template = "Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}"
    print (template.format(epoch+1,
                         loss,
                         acc,
                         val_loss,
                         val_acc))
    
    # Logging with WnB
    wandb.log({"train_loss": loss.numpy(),
               "train_accuracy": acc.numpy(),
               "val_loss": val_loss.numpy(),
               "val_accuracy": val_acc.numpy()
    })
    get_sample_predictions()

Epoch 1.000, loss: 0.544, acc: 0.802, val_loss: 0.429, val_acc: 0.845
Epoch 2.000, loss: 0.361, acc: 0.871, val_loss: 0.377, val_acc: 0.860
Epoch 3.000, loss: 0.309, acc: 0.888, val_loss: 0.351, val_acc: 0.869
Epoch 4.000, loss: 0.277, acc: 0.899, val_loss: 0.336, val_acc: 0.873
Epoch 5.000, loss: 0.252, acc: 0.908, val_loss: 0.323, val_acc: 0.882
