In [5]:
import tensorflow as tf
from tensorflow import keras

In [1]:
def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def get_dataset():
    batch_size = 32
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a `tf.data.Dataset`.
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )



In [2]:
import os
from tensorflow import keras

# Prepare a directory to store all the checkpoints
checkpoint_dir = "./ckpt"

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
    
def make_or_restore_model():
    # Either restore the latest model or create a fresh one
    # if there is no checkpoint availabel
    
    checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key = os.path.getctime)
        print("Restoring from ", latest_checkpoint)
        return keras.models.load_model(latest_checkpoint)
    print("Creating a new Model ")
    return get_compiled_model()

In [3]:
def run_training(epochs = 1):
    # Create a Mirrored Strategy
    strategy = tf.distribute.MirroredStrategy()
    
    # Open a strategy scope and create / restore the model
    with strategy.scope():
        model = make_or_restore_model()
    callbacks = [
        # This callback saves a SavedModel every epochs
        # We include the current epoch in the folder name
        keras.callbacks.ModelCheckpoint(
        filepath = checkpoint_dir + '/ckpt-{epoch}', save_freq = "epoch")
    ]
    model.fit(
    train_dataset,
    epochs = epochs,
    callbacks = callbacks,
    validation_data = val_dataset,
    verbose = 2)

In [7]:
# Train the model on all available devices.
train_dataset, val_dataset, test_dataset = get_dataset()

In [8]:
# Running the First time Create the Model
run_training(epochs=1)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Creating a new Model 
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
1563/1563 - 7s - loss: 0.2293 - sparse_categorical_accuracy: 0.9317 - val_loss: 0.1374 - val_sparse_categorical_accuracy: 0.9578
INFO:tensorflow:Assets written to: ./ckpt\ckpt-1\assets


In [10]:
# Calling the same function again will resume from where we left off
run_training(epochs = 1)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Restoring from  ./ckpt/ckpt-1
1563/1563 - 6s - loss: 0.0966 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.0914 - val_sparse_categorical_accuracy: 0.9719
INFO:tensorflow:Assets written to: ./ckpt\ckpt-1\assets


In [13]:
os.listdir('./ckpt/ckpt-1/') # Check where model saved

['assets', 'saved_model.pb', 'variables']