In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Download the dataset

In [2]:
# as_supervised=True: each element is an (image, label) pair, the shape `scale` unpacks.
# with_info=True: also return the dataset metadata (split sizes are read from it later).
datasets, info = tfds.load('mnist', as_supervised=True, with_info=True)

local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.



[1mDownloading and preparing dataset mnist (11.06 MiB) to /home/kaimo/tensorflow_datasets/mnist/3.0.0...[0m


HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…



[1mDataset mnist downloaded and prepared to /home/kaimo/tensorflow_datasets/mnist/3.0.0. Subsequent calls will reuse this data.[0m


In [3]:
# Pull the two splits out of the mapping returned by tfds.load.
mnist_train = datasets['train']
mnist_test = datasets['test']

## Define the distribution strategy

In [4]:
# With no arguments, MirroredStrategy uses the locally visible devices
# (the INFO log lists the ones it picked).
strategy = tf.distribute.MirroredStrategy()
# Surface the replica count: the global BATCH_SIZE is scaled by it.
print(f'Number of devices: {strategy.num_replicas_in_sync}')

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


## Set up the input pipeline

In [8]:
# Split sizes, taken from the metadata returned by tfds.load.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10_000  # shuffle buffer for the training pipeline
BATCH_SIZE_PER_REPLICA = 64
# Global batch size = per-replica batch size x number of replicas in sync.
BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

In [9]:
def scale(image, label):
    """Cast an image to float32 and divide its pixel values by 255.

    Args:
        image: image tensor (presumably uint8 pixels in [0, 255] for MNIST — confirm).
        label: returned unchanged.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    scaled_image = tf.cast(image, tf.float32) / 255
    return scaled_image, label

In [10]:
# Training pipeline: scale the images, cache the scaled examples, shuffle, batch with
# the global batch size, and prefetch so input preparation overlaps training steps.
train_dataset = (
    mnist_train
    .map(scale, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Evaluation pipeline: same scaling and batching, no shuffling.
eval_dataset = (
    mnist_test
    .map(scale, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

## Create the model

In [19]:
# Build and compile inside the strategy scope so the model's variables are
# created under MirroredStrategy.
with strategy.scope():
    model = tf.keras.Sequential()
    # conv -> pool -> flatten -> hidden dense -> 10-way output
    model.add(tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=[28, 28, 1]))
    model.add(tf.keras.layers.MaxPooling2D())
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(64, activation='relu'))
    # No activation: the last layer emits logits, matching from_logits=True below.
    model.add(tf.keras.layers.Dense(10))

    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer='adam',
        metrics=['accuracy'],
    )

## Define the callbacks

In [20]:
def decay(epoch):
    """Piecewise-constant learning-rate schedule keyed on the epoch index.

    Epochs below 3 -> 1e-3, epochs 3-6 -> 1e-4, epoch 7 onwards -> 1e-5.

    NOTE(review): LearningRateScheduler may call the schedule as
    ``schedule(epoch, lr)``; keep this single-argument signature (or make any
    new parameter keyword-only) rather than adding a positional default.
    """
    # (exclusive upper epoch bound, rate) pairs, checked in order.
    for upper_bound, rate in ((3, 1e-3), (7, 1e-4)):
        if epoch < upper_bound:
            return rate
    return 1e-5

## Train and evaluate

In [21]:
EPOCHS = 12  # the decay schedule lowers the rate at epochs 3 and 7

# Keep the returned History (per-epoch loss/metrics) instead of letting the
# cell echo its bare object repr.
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    callbacks=[tf.keras.callbacks.LearningRateScheduler(decay)],
)

Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


<tensorflow.python.keras.callbacks.History at 0x7f2cc14173d0>

In [22]:
# Test-split results: the loss followed by the 'accuracy' metric set in compile.
(eval_loss, eval_acc) = model.evaluate(eval_dataset)
(eval_loss, eval_acc)



(0.04120798781514168, 0.987500011920929)

## Save and load

In [None]:
# 'path/to/save' is a placeholder: point it at a real, writable location
# before running this cell.
model.save(filepath='path/to/save')

In [None]:
# Without scope: load and recompile outside strategy.scope().
# NOTE(review): the restored model presumably runs without MirroredStrategy
# replication here — confirm against the tf.distribute docs.
unreplicated_model = tf.keras.models.load_model(filepath='path/to/save')
unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# With scope: load and recompile inside strategy.scope(), the same way the
# original `model` was built and compiled.
with strategy.scope():
    replicated_model = tf.keras.models.load_model(filepath='path/to/save')
    replicated_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer='adam',
        metrics=['accuracy'],
    )