In [None]:
import os
!pip install tensorflow-gpu>=2.0.0 tqdm
try:
    from checkmate.tf2 import get_keras_model
except:
   !git clone https://github.com/parasj/checkmate.git
   os.chdir('./checkmate')
   !pip install -e .


In [None]:
import logging
import numpy as np
import tensorflow as tf
from checkmate.tf2 import get_keras_model
from tqdm import tqdm
# Configure the root logger at DEBUG verbosity (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="DEBUG")

# Checkmate getting started guide
Checkmate is a system for training large neural networks on memory-constrained hardware. State-of-the-art models require
increasing amounts of GPU memory. Checkmate traces your TensorFlow application and efficiently reschedules the TF graph so that
total memory requirements stay under the memory budget of your GPU.

In this tutorial, we walk through how to train a computer vision model with a basic application of Checkmate. While this 
application would likely fit within the limits of most GPUs, it serves to illustrate the mechanics of using Checkmate.

## Loading CIFAR10 using keras
Checkmate optimizes any TensorFlow 2.0 graph. In this example, we load the CIFAR10 dataset and train a basic few-layer neural network on it.

In [None]:
# load cifar10 dataset
batch_size = 1024
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Cast to float32 *before* scaling: dividing the uint8 arrays by 255.0 directly first
# materializes float64 copies (twice the memory) that would immediately be cast down again.
x_train, x_test = x_train.astype(np.float32) / 255.0, x_test.astype(np.float32) / 255.0
y_train, y_test = y_train.astype(np.float32), y_test.astype(np.float32)
# drop_remainder=True keeps every training batch at exactly `batch_size` examples.
# Checkmate compiles the training step from the concrete shape of the first batch, so a
# smaller final batch (50000 training images % 1024 = 848) would not match that shape.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size, drop_remainder=True)
# The evaluation loop calls the plain Keras model, so the last partial test batch is fine.
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

# build Checkmate's small "test" Keras model, plus the loss function and optimizer
model = get_keras_model("test", input_shape=x_train[0].shape, num_classes=10)
loss = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()
# Not required by the custom training loop below (it never calls fit/evaluate), but harmless.
model.compile(optimizer=optimizer, loss=loss)

## Recompiling the TensorFlow test model using Checkmate
Checkmate exposes a convenience function `checkmate.tf2.compile_tf2` that takes a Keras model and returns
a `tf.function` that runs a single training iteration over a batch. In order to accurately measure memory
consumption per operation, Checkmate needs to know the full size of the inputs to your model. The training
dataset describes its elements under `train_ds.element_spec`; here we simply take the first batch of `train_ds`.
Note that a dataset element contains both the inputs and the labels (the expected outputs), which Checkmate
takes separately as `input_spec` and `label_spec`.

In [None]:
from checkmate.tf2.wrapper import compile_tf2

# One concrete (images, labels) batch from the training set. Checkmate traces the training
# step with these tensors so it knows the full input and label shapes.
element_spec = next(iter(train_ds))
first_images, first_labels = element_spec
train_iteration = compile_tf2(
    model,
    loss=loss,
    optimizer=optimizer,
    input_spec=first_images,
    label_spec=first_labels,
)

## Training the large neural network
Checkmate has now recompiled our training function. We can continue to use existing TensorFlow functionality for training neural networks, but we substitute the call to the model with Checkmate's version of the training iteration.

In [None]:
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
test_loss = tf.keras.metrics.Mean(name="test_loss")
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")

NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    # Reset the metrics at the start of each epoch so the progress bars show per-epoch values.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

    # Training: each step goes through Checkmate's compiled `train_iteration`, which runs a
    # full training iteration on the batch and returns its predictions and loss.
    with tqdm(total=x_train.shape[0]) as pbar:
        for images, labels in train_ds:
            predictions, loss_value = train_iteration(images, labels)
            train_loss(loss_value)
            train_accuracy(labels, predictions)
            pbar.update(images.shape[0])
            pbar.set_description('Train epoch {}; loss={:0.4f}, acc={:0.4f}'.format(epoch + 1, train_loss.result(), train_accuracy.result()))

    # Evaluation: plain forward passes with the updated weights. training=False makes layers
    # such as dropout / batch normalization (if the model has any) run in inference mode.
    with tqdm(total=x_test.shape[0]) as pbar:
        for images, labels in test_ds:
            predictions = model(images, training=False)
            test_loss_value = loss(labels, predictions)
            test_loss(test_loss_value)
            test_accuracy(labels, predictions)
            pbar.update(images.shape[0])
            pbar.set_description('Valid epoch {}, loss={:0.4f}, acc={:0.4f}'.format(epoch + 1, test_loss.result(), test_accuracy.result()))
