The `tf.distribute.Strategy` API provides an abstraction for distributing training across multiple processing units. In this tutorial, you will use `tf.distribute.MirroredStrategy`, which performs in-graph replication with synchronous training on multiple GPUs on one machine. **It copies all of the model's variables to each accelerator, then combines the gradients from all of them and applies the combined value to every copy of the model.**
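
The copy, combine, and apply cycle described above can be illustrated with a small pure-Python sketch. It is a toy simulation of the idea, not TensorFlow's implementation: the one-parameter model `y = w * x`, the helper names (`local_gradient`, `all_reduce_mean`, `train_step`), and the choice to combine gradients by averaging are all assumptions made for illustration.

```python
# Toy simulation of synchronous data-parallel training (no TensorFlow needed).
# Model: y = w * x with a mean-squared-error loss. All names are illustrative.

def local_gradient(w, shard):
    """Gradient of the mean squared error over one replica's shard of data."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Combine per-replica gradients into one value shared by every replica."""
    return sum(values) / len(values)

def train_step(replica_weights, shards, learning_rate):
    # Each replica computes a gradient on its own shard of the batch.
    grads = [local_gradient(w, s) for w, s in zip(replica_weights, shards)]
    # The gradients are combined, and the same update is applied to every copy.
    combined = all_reduce_mean(grads)
    return [w - learning_rate * combined for w in replica_weights]

# Data generated from y = 3x, split evenly across two simulated replicas.
data = [(float(x), 3.0 * x) for x in range(1, 9)]
shards = [data[0::2], data[1::2]]
weights = [0.0, 0.0]  # every replica starts from an identical copy

for _ in range(200):
    weights = train_step(weights, shards, learning_rate=0.01)

print(weights)
```

Because every replica applies the same combined gradient, the copies of `w` never drift apart, which is the property that makes the training synchronous.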

You will train on the Fashion-MNIST dataset using the `tf.keras` APIs to build a distributed training scenario.

In [0]:
!pip install -q tf-nightly

In [28]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.client import device_lib
import os

print("Tensorflow Version: {}".format(tf.__version__))
print("Eager Mode: {}".format(tf.executing_eagerly()))
print("GPU {} available.".format("is" if tf.config.experimental.list_physical_devices("GPU") else "is not"))
print("Devices: {}".format(device_lib.list_local_devices()))

Tensorflow Version: 2.2.0-dev20200119
Eager Mode: True
GPU is available.
Devices: [name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 2933834386269868962
, name: "/device:XLA_CPU:0"
device_type: "XLA_CPU"
memory_limit: 17179869184
locality {
}
incarnation: 2173686386421350107
physical_device_desc: "device: XLA_CPU device"
, name: "/device:XLA_GPU:0"
device_type: "XLA_GPU"
memory_limit: 17179869184
locality {
}
incarnation: 14877219504377459244
physical_device_desc: "device: XLA_GPU device"
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 15956161332
locality {
  bus_id: 1
  links {
  }
}
incarnation: 18018003714542427558
physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0"
]


# Data Preprocessing

In [0]:
datasets, info = tfds.load(name='fashion_mnist', with_info=True, as_supervised=True)

In [12]:
info

tfds.core.DatasetInfo(
    name='fashion_mnist',
    version=1.0.0,
    description='Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.',
    homepage='https://github.com/zalandoresearch/fashion-mnist',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{DBLP:journals/corr/abs-1708-07747,
      author    = {Han Xiao and
                   Kashif Rasul and
                   Roland Vollgraf},
      title     = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
                   Algorithms},
      journal   = {CoRR},
      volume

In [0]:
mnist_train, mnist_test = datasets["train"], datasets["test"]

In [14]:
for _img, _label in mnist_train.take(1):
  print(_img.shape, _label.numpy())

(28, 28, 1) 3


# Define the Distributed Strategy

Create a `MirroredStrategy` object, which handles the distribution and provides a context manager (`tf.distribute.MirroredStrategy.scope`) to build your model inside.

In [15]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [16]:
print("Number of GPUs: {}".format(strategy.num_replicas_in_sync))

Number of GPUs: 1


# Set Up the Input Pipeline

When training on multiple accelerators, you can use the extra computing power effectively by increasing the batch size. The global batch size is `batch_size_per_replica * number of replicas`.
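
The batch-size arithmetic can be sketched in plain Python. The helper names `global_batch_size` and `steps_per_epoch` are hypothetical; the numbers (64 examples per replica, 60,000 training and 10,000 test examples) are the ones used in this tutorial, and the replica counts other than 1 are only examples.

```python
import math

def global_batch_size(batch_size_per_replica, num_replicas):
    """Total number of examples consumed per training step across all replicas."""
    return batch_size_per_replica * num_replicas

def steps_per_epoch(num_examples, batch_size):
    """Number of batches in one pass; the last batch may be smaller, so round up."""
    return math.ceil(num_examples / batch_size)

for replicas in (1, 2, 4):
    gbs = global_batch_size(64, replicas)
    print("replicas={} global_batch={} train_steps={} eval_steps={}".format(
        replicas, gbs, steps_per_epoch(60000, gbs), steps_per_epoch(10000, gbs)))
```

With a single replica this gives 938 training steps and 157 evaluation steps, which matches the step counts shown in the training and evaluation logs of this notebook.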

In [18]:
num_train_examples = info.splits["train"].num_examples
num_test_examples = info.splits["test"].num_examples

num_train_examples, num_test_examples

(60000, 10000)

In [0]:
BUFFER_SIZE = int(1e4)

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

Preprocess the data by normalizing the image pixel values from the 0-255 range to the 0-1 range.

In [0]:
def normalize(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255.0
  return image, label

Build a pipeline and apply the normalization function to the data source.

In [0]:
train_dataset = mnist_train.map(normalize).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(normalize).batch(BATCH_SIZE)
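
Note that `BUFFER_SIZE` (10,000) is smaller than the training set (60,000), so `shuffle` mixes elements within a sliding buffer rather than across the whole dataset. The following pure-Python generator is a minimal sketch of that buffered-shuffle idea; the function name `buffered_shuffle` and its `seed` parameter are assumptions for illustration, not part of `tf.data`.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a randomized order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)          # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]                # emit a random buffered element
        buffer[idx] = item               # and put the new element in its slot
    rng.shuffle(buffer)                  # drain whatever is left at the end
    yield from buffer

shuffled = list(buffered_shuffle(range(100), buffer_size=10, seed=0))
print(shuffled[:10])
```

In this sketch the first element emitted always comes from the first `buffer_size` inputs, and a buffer of size 1 leaves the order unchanged. A buffer at least as large as the dataset gives a full shuffle.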

# Create the Model

Create a `tf.keras` model and compile it inside `strategy.scope()`.

In [0]:
def build_model(inputs):
  x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same', 
                             activation='elu', input_shape=(28, 28, 1))(inputs)
  x = tf.keras.layers.MaxPooling2D()(x)
  x = tf.keras.layers.Flatten()(x)
  x = tf.keras.layers.Dense(units=64, activation='elu')(x)
  y = tf.keras.layers.Dense(units=10, activation='softmax', name='classes')(x)
  return y

In [40]:
with strategy.scope():
  inputs = tf.keras.Input(shape=(28, 28, 1))
  outputs = build_model(inputs)
  model = tf.keras.Model(inputs, outputs)

  model.compile(loss='sparse_categorical_crossentropy', 
                optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), 
                metrics=['accuracy'])

  model.summary()

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 32)        320       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 6272)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 64)                401472    
_________________________________________________________________
dense_5 (Dense)              (None, 10)                650       
Total params: 402,442
Trainable params: 402,442
Non-trainable params: 0
_____________________________________________________
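
The parameter counts in the summary above can be checked by hand. This is a small sketch with hypothetical helper names (`conv2d_params`, `dense_params`); it assumes each layer has one bias per filter or unit, which is the Keras default used in `build_model`.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Weights of every kernel plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + filters

def dense_params(in_features, units):
    """One weight per input-output pair plus one bias per unit."""
    return in_features * units + units

conv = conv2d_params(3, 3, 1, 32)   # 3x3 kernels, 1 input channel, 32 filters
flat = 14 * 14 * 32                 # 28x28x32 halved by 2x2 max pooling, then flattened
hidden = dense_params(flat, 64)     # Flatten -> Dense(64)
classes = dense_params(64, 10)      # Dense(64) -> Dense(10)
total = conv + hidden + classes

print(conv, flat, hidden, classes, total)
```

The results (320, 6272, 401472, 650, and 402442 in total) agree with the `Param #` column and the `Total params` line.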

# Define the Callbacks

After creating the model, define several callbacks:
* TensorBoard: monitors the training progress and visualizes the graphs
* Model checkpoint: saves the model weights at the end of every epoch
* Learning rate scheduler: decreases the learning rate at specified epochs

In [0]:
ckpt_dir = './ckpt'
ckpt_prefix = os.path.join(ckpt_dir, "ckpt_{epoch}")
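
The `{epoch}` placeholder in `ckpt_prefix` is filled in with the epoch number each time a checkpoint is saved. A quick sketch of the resulting names, assuming 1-based epoch numbers (as the checkpoint listing later in this notebook shows) and the 12 epochs used for training:

```python
import os

ckpt_dir = './ckpt'
ckpt_prefix = os.path.join(ckpt_dir, "ckpt_{epoch}")

# Expand the template the same way str.format would for each epoch number.
names = [ckpt_prefix.format(epoch=epoch) for epoch in range(1, 13)]

print(names[0])
print(names[-1])
```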

In [0]:
!rm -rf {ckpt_dir}

Define a simple decay function to decrease the learning rate.

In [0]:
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5
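
To see what this schedule produces over the 12 training epochs, you can evaluate it directly. The sketch repeats the `decay` function so that it runs on its own, and assumes the 0-based epoch index that Keras passes to a `LearningRateScheduler` function.

```python
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5

# Epoch indices are 0-based; add 1 when printing to match the "Epoch n/12" logs.
schedule = [decay(epoch) for epoch in range(12)]
for epoch, lr in enumerate(schedule):
  print("epoch {:2d}: lr = {:g}".format(epoch + 1, lr))
```

Epochs 1 to 3 use `1e-3`, epochs 4 to 7 use `1e-4`, and epochs 8 to 12 use `1e-5`.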

Define a callback that prints the learning rate at the end of each epoch.

In [0]:
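class PrintLR(tf.keras.callbacks.Callback):
  # Print the optimizer's current learning rate after every epoch.
  def on_epoch_end(self, epoch, logs=None):
    # Use self.model (set by Keras) instead of relying on a global `model`.
    print("Learning Rate on epoch {} is {}.".format(
        epoch + 1, self.model.optimizer.lr.numpy()))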

In [0]:
callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs'),
             tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_prefix, 
                                                save_weights_only=True),
             tf.keras.callbacks.LearningRateScheduler(decay),
             PrintLR()
            ]

# Train and Evaluate the Model

Call `.fit()` to start training. The call is the same whether or not the model is distributed.

In [46]:
model.fit(train_dataset, epochs=12, callbacks=callbacks)

Epoch 1/12
    938/Unknown - 14s 15ms/step - loss: 0.4177 - accuracy: 0.8533
Learning Rate on epoch 1 is 0.0010000000474974513.
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


<tensorflow.python.keras.callbacks.History at 0x7f78be3dff28>

During training, checkpoint files are written to the checkpoint directory.

In [34]:
!ls -al {ckpt_dir}

total 56748
drwxr-xr-x 2 root root    4096 Jan 21 08:08 .
drwxr-xr-x 1 root root    4096 Jan 21 08:07 ..
-rw-r--r-- 1 root root      71 Jan 21 08:08 checkpoint
-rw-r--r-- 1 root root    2486 Jan 21 08:08 ckpt_10.data-00000-of-00002
-rw-r--r-- 1 root root 4829328 Jan 21 08:08 ckpt_10.data-00001-of-00002
-rw-r--r-- 1 root root    1695 Jan 21 08:08 ckpt_10.index
-rw-r--r-- 1 root root    2486 Jan 21 08:08 ckpt_11.data-00000-of-00002
-rw-r--r-- 1 root root 4829328 Jan 21 08:08 ckpt_11.data-00001-of-00002
-rw-r--r-- 1 root root    1695 Jan 21 08:08 ckpt_11.index
-rw-r--r-- 1 root root    2486 Jan 21 08:08 ckpt_12.data-00000-of-00002
-rw-r--r-- 1 root root 4829328 Jan 21 08:08 ckpt_12.data-00001-of-00002
-rw-r--r-- 1 root root    1695 Jan 21 08:08 ckpt_12.index
-rw-r--r-- 1 root root    2486 Jan 21 08:07 ckpt_1.data-00000-of-00002
-rw-r--r-- 1 root root 4829328 Jan 21 08:07 ckpt_1.data-00001-of-00002
-rw-r--r-- 1 root root    1695 Jan 21 08:07 ckpt_1.index
-rw-r--r-- 1 root root    2486 Jan 

You can also get the latest checkpoint name.

In [47]:
tf.train.latest_checkpoint(ckpt_dir)

'./ckpt/ckpt_12'
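
The directory listing above is ordered by name as text rather than by epoch number, which is why `ckpt_10` through `ckpt_12` appear before `ckpt_1`. If you ever pick a checkpoint from file names yourself, compare the epoch numbers rather than the strings. A minimal sketch, where `epoch_of` is a hypothetical helper and the names follow the `ckpt_{epoch}` pattern used in this tutorial:

```python
import re

def epoch_of(name):
    """Extract the trailing epoch number from a name such as 'ckpt_12'."""
    match = re.search(r"ckpt_(\d+)$", name)
    return int(match.group(1))

names = ["ckpt_{}".format(epoch) for epoch in range(1, 13)]

naive_latest = max(names)                # string comparison
numeric_latest = max(names, key=epoch_of)  # comparison by epoch number

print(naive_latest, numeric_latest)
```

String comparison picks `ckpt_9`, while comparing by epoch number picks `ckpt_12`, the same checkpoint that `tf.train.latest_checkpoint` returned.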

You can copy the checkpoint files elsewhere and reload the weights into a model that has the same architecture.

In [48]:
model.load_weights(tf.train.latest_checkpoint(ckpt_dir))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f78be37af60>

In [50]:
eval_loss, eval_acc = model.evaluate(eval_dataset)
print("Eval loss: {}, acc: {:.3%}".format(eval_loss, eval_acc))

    157/Unknown - 3s 19ms/step - loss: 0.2502 - accuracy: 0.9086
Eval loss: 0.25023063561718933, acc: 90.860%


You can also inspect the training progress with TensorBoard.

In [0]:
%reload_ext tensorboard
%tensorboard --logdir ./logs

# Export to the SavedModel

After training the model, you can export the graph and the weights to the platform-agnostic SavedModel format.

In [52]:
path = './savedmodel'

model.save(path, save_format='tf')

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


INFO:tensorflow:Assets written to: ./savedmodel/assets


# Reload a Model from a SavedModel

Load the model **without** `strategy.scope()`.

In [53]:
unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print("Eval Loss: {}, Acc: {:.3%}".format(eval_loss, eval_acc))

    157/Unknown - 2s 12ms/step - loss: 0.2502 - accuracy: 0.9086
Eval Loss: 0.25023063561718933, Acc: 90.860%


Load the model **with** `strategy.scope()`.

In [54]:
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)

  replicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print("Eval Loss: {}, Acc: {:.3%}".format(eval_loss, eval_acc))

    157/Unknown - 3s 19ms/step - loss: 0.2502 - accuracy: 0.9086
Eval Loss: 0.25023063561718933, Acc: 90.860%
