In this tutorial, you will train a model with the TensorFlow 2 core APIs, using in-graph replication with synchronous training on multiple GPUs. The model is trained on the Fashion-MNIST dataset.

References:
* Custom training with tf.distribute.Strategy: https://www.tensorflow.org/tutorials/distribute/custom_training#training_loop

In [0]:
!pip install -q tf-nightly

In [3]:
import tensorflow as tf
from tensorflow.python.client import device_lib
import numpy as np
import os

print("Tensorflow Version: {}".format(tf.__version__))
print("Eager Mode: {}".format(tf.executing_eagerly()))
print("GPU {} available.".format("is" if tf.config.experimental.list_physical_devices("GPU") else "not"))
print("List devices:", device_lib.list_local_devices(), sep="\n")

Tensorflow Version: 2.2.0-dev20200129
Eager Mode: True
GPU is available.
List devices:
[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 12662604050648042070
, name: "/device:XLA_CPU:0"
device_type: "XLA_CPU"
memory_limit: 17179869184
locality {
}
incarnation: 9344183571507787081
physical_device_desc: "device: XLA_CPU device"
, name: "/device:XLA_GPU:0"
device_type: "XLA_GPU"
memory_limit: 17179869184
locality {
}
incarnation: 13524239733876520926
physical_device_desc: "device: XLA_GPU device"
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 14912199066
locality {
  bus_id: 1
  links {
  }
}
incarnation: 615689992357562656
physical_device_desc: "device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5"
]


# Data Preparation

In [0]:
fashion_mnist = tf.keras.datasets.fashion_mnist

In [5]:
(train_imgs, train_labels), (test_imgs, test_labels) = fashion_mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [6]:
train_imgs.shape, train_labels.shape

((60000, 28, 28), (60000,))

Because you are going to train a CNN, each image needs an explicit channel axis. Expand the arrays at the last axis so that batches have the shape `(BATCH_SIZE, IMG_SIZE, IMG_SIZE, IMG_CHANNELS)`.

In [7]:
train_imgs = train_imgs[..., None]
test_imgs = test_imgs[..., None]

train_imgs.shape, test_imgs.shape

((60000, 28, 28, 1), (10000, 28, 28, 1))
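Indexing with `[..., None]` appends a new length-1 axis at the end, which is the same as `np.expand_dims(x, axis=-1)`. A quick check on a small all-zeros stand-in array (not the real dataset):

```python
import numpy as np

# Stand-in for the image array: 5 grayscale "images" of 28x28 pixels.
imgs = np.zeros((5, 28, 28), dtype=np.uint8)

# Two equivalent ways to append a channel axis at the end.
with_none = imgs[..., None]
with_expand = np.expand_dims(imgs, axis=-1)

print(with_none.shape)                          # (5, 28, 28, 1)
print(np.array_equal(with_none, with_expand))   # True
```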

Normalize the pixel values from the integer range [0, 255] to floating-point values in [0, 1].

In [0]:
train_imgs = train_imgs / np.float32(255.0)
test_imgs = test_imgs / np.float32(255.0)

# Creating a Distributed Strategy

Next, create a distribution strategy with the `tf.distribute.MirroredStrategy` API. Under this strategy, every variable is mirrored on each replica and the input pipeline is distributed across the replicas. Each replica computes the loss and the gradients on its own share of the batch; the gradients are then synchronized across all replicas by summing them, and the same update is applied to every replica's copy of the variables so that the copies stay identical.
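The synchronous update can be illustrated without TensorFlow. The NumPy sketch below is a conceptual toy, not how `MirroredStrategy` is implemented: a one-weight linear model with a squared-error loss (invented for illustration) and two simulated replicas held in a Python list. Each replica starts from the same weight, computes a gradient on its own shard of the global batch, the gradients are summed as an all-reduce would, and the identical update is applied to every copy. Because each replica's loss is divided by the global batch size, the result matches one single-device step on the full batch.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=8)            # global batch of 8 inputs
y = 3.0 * x                       # targets for the toy model y = w * x
global_batch_size = len(x)
num_replicas = 2
lr = 0.1

def grad(w, xs, ys):
    """d/dw of sum((w * xs - ys) ** 2) / global_batch_size."""
    return np.sum(2.0 * (w * xs - ys) * xs) / global_batch_size

# Mirrored variables: every replica starts with an identical copy of w.
replica_weights = [0.5] * num_replicas
shards = zip(np.array_split(x, num_replicas), np.array_split(y, num_replicas))

# Each replica computes a gradient on its own shard of the global batch.
per_replica_grads = [grad(w, xs, ys) for w, (xs, ys) in zip(replica_weights, shards)]

# All-reduce (sum), then apply the same update on every replica.
summed_grad = sum(per_replica_grads)
replica_weights = [w - lr * summed_grad for w in replica_weights]

# Reference: one single-device step on the full global batch.
single_device_w = 0.5 - lr * grad(0.5, x, y)
print(replica_weights, single_device_w)
```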

In [9]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [10]:
print("Number of devices: {}.".format(strategy.num_replicas_in_sync))

Number of devices: 1.


# Setting up the Input Pipeline

In [0]:
BUFFER_SIZE = len(train_imgs)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

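Because the dataset fits in memory, you can wrap the NumPy arrays with `tf.data.Dataset.from_tensor_slices`. Once you have a `tf.data.Dataset`, `strategy.experimental_distribute_dataset` turns it into a distributed dataset that splits each global batch across the replicas.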

In [0]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_imgs, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_imgs, test_labels)).batch(GLOBAL_BATCH_SIZE)

In [0]:
train_dist_datatset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
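Both datasets are batched with `GLOBAL_BATCH_SIZE`, so each full global batch gives every replica 64 examples, and adding replicas reduces the number of steps per epoch. Since 60,000 is not a multiple of the global batch size, the final batch is smaller than the rest (`Dataset.batch` keeps the remainder unless `drop_remainder=True`). A plain-Python sketch of the arithmetic, assuming 64 examples per replica as in the cell above; `batching_summary` is a hypothetical helper:

```python
import math

NUM_TRAIN_EXAMPLES = 60000   # size of the Fashion-MNIST training set
PER_REPLICA_BATCH = 64       # examples each replica receives per full step

def batching_summary(num_replicas):
    """Return (global batch size, steps per epoch, size of the final batch)."""
    global_batch_size = PER_REPLICA_BATCH * num_replicas
    steps = math.ceil(NUM_TRAIN_EXAMPLES / global_batch_size)
    last = NUM_TRAIN_EXAMPLES - (steps - 1) * global_batch_size
    return global_batch_size, steps, last

for n in (1, 2, 4):
    print(n, batching_summary(n))
# 1 (64, 938, 32)
# 2 (128, 469, 96)
# 4 (256, 235, 96)
```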

# Creating a Model

Create the model with the `tf.keras` functional API.

In [0]:
def create_model():
  def _model_body(inputs):
    x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3,3), padding='same', 
                               activation='elu', name='input')(inputs)  # (None, 28, 28, 32)
    x = tf.keras.layers.MaxPool2D()(x)  # (None, 14, 14, 32)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), padding='same', 
                               activation='elu')(x)  # (None, 14, 14, 64)
    x = tf.keras.layers.MaxPool2D()(x)  # (None, 7, 7, 64)
    x = tf.keras.layers.Flatten()(x)  # (None, 7*7*64)
    x = tf.keras.layers.Dense(units=64, activation='elu')(x)
    output = tf.keras.layers.Dense(units=10, activation='softmax', name="output")(x)
    return output

  inputs = tf.keras.Input(shape=(28, 28, 1))
  outputs = _model_body(inputs)
  model = tf.keras.Model(inputs, outputs)
  return model
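The shape comments in `create_model` follow from the layer settings: `padding='same'` keeps the spatial size, and `MaxPool2D()` defaults to a 2x2 pool with stride 2, halving it. The plain-Python sketch below is a hand calculation (not a call into Keras) that derives the flattened size and the trainable parameter count, which should agree with `model.summary()`:

```python
# Hand calculation of shapes and trainable parameters for create_model().
def conv_params(kernel, in_ch, out_ch):
    return kernel * kernel * in_ch * out_ch + out_ch   # weights + biases

def dense_params(in_units, out_units):
    return in_units * out_units + out_units            # weights + biases

side = 28                              # 'same' padding keeps 28x28
p_conv1 = conv_params(3, 1, 32)        # (None, 28, 28, 32)
side //= 2                             # MaxPool2D -> 14x14
p_conv2 = conv_params(3, 32, 64)       # (None, 14, 14, 64)
side //= 2                             # MaxPool2D -> 7x7
flat = side * side * 64                # Flatten -> 7*7*64 features
p_dense = dense_params(flat, 64)
p_out = dense_params(64, 10)

total = p_conv1 + p_conv2 + p_dense + p_out
print(flat, total)   # 3136 220234
```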

In [0]:
ckpt_dir = "./ckpts"
ckpt_prefix = os.path.join(ckpt_dir, "ckpt")

# Define the Loss Function

On a single CPU or GPU, the loss is usually averaged by dividing by the number of examples in the batch. When training on multiple GPUs, instead divide the summed per-example loss on each replica by the `GLOBAL_BATCH_SIZE` (for example, a per-replica batch of 16 on 4 replicas gives a `GLOBAL_BATCH_SIZE` of 16 * 4 = 64). Summing these scaled losses across the replicas then gives the average loss over the whole global batch.
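Written out with $R$ replicas, a per-replica batch $\mathcal{B}_r$ of $B_{\text{replica}}$ examples, and per-example losses $\ell_i$, scaling by the global batch size recovers exactly the single-device average, whereas averaging on each replica and then summing overcounts by a factor of $R$:

$$
\begin{aligned}
B_{\text{global}} &= R \cdot B_{\text{replica}}, \\
\sum_{r=1}^{R} \frac{1}{B_{\text{global}}} \sum_{i \in \mathcal{B}_r} \ell_i &= \frac{1}{B_{\text{global}}} \sum_{i=1}^{B_{\text{global}}} \ell_i, \\
\sum_{r=1}^{R} \frac{1}{B_{\text{replica}}} \sum_{i \in \mathcal{B}_r} \ell_i &= R \cdot \frac{1}{B_{\text{global}}} \sum_{i=1}^{B_{\text{global}}} \ell_i.
\end{aligned}
$$

With $R = 4$ and $B_{\text{replica}} = 16$, this gives $B_{\text{global}} = 64$.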

In TensorFlow:
* If you define a custom loss function, sum the per-example losses on each replica and then divide by the `GLOBAL_BATCH_SIZE`, for example `scaled_loss = tf.reduce_sum(losses) * (1.0 / GLOBAL_BATCH_SIZE)`. Alternatively, use `tf.nn.compute_average_loss`, which takes the per-example losses and the `GLOBAL_BATCH_SIZE` as arguments.

* If your model has regularization losses, divide them by the number of replicas with `tf.nn.scale_regularization_loss`; otherwise, summing across replicas counts the penalty once per replica.

* If you use the `tf.keras.losses` classes, set the loss reduction explicitly to `NONE` or `SUM`; `AUTO` and `SUM_OVER_BATCH_SIZE` are not allowed. When you train with `Model.compile` and `Model.fit` instead of a custom loop, Keras performs this reduction and scaling automatically.

* Do not use `tf.reduce_mean()` to compute the per-replica loss. It divides by the actual per-replica batch size, which can vary from step to step (for example, on a final, smaller batch).
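The NumPy helpers below, named `average_loss` and `scale_regularization` here for illustration (they are not TensorFlow APIs), reproduce the arithmetic that `tf.nn.compute_average_loss` and `tf.nn.scale_regularization_loss` perform on each replica. With 4 simulated replicas of 16 examples and made-up loss values, the per-replica contributions sum to the global mean loss plus exactly one copy of the regularization penalty:

```python
import numpy as np

# Hypothetical NumPy stand-ins for the per-replica loss arithmetic.
def average_loss(per_example_loss, global_batch_size):
    # Sum this replica's per-example losses, divide by the GLOBAL batch size.
    return np.sum(per_example_loss) / global_batch_size

def scale_regularization(reg_loss, num_replicas):
    # Every replica holds the same weights, so each adds 1/num_replicas of the penalty.
    return reg_loss / num_replicas

num_replicas = 4
per_replica_batch = 16
global_batch_size = num_replicas * per_replica_batch   # 4 * 16 = 64

rng = np.random.default_rng(42)
losses = rng.uniform(size=global_batch_size)   # stand-in per-example losses
shards = np.split(losses, num_replicas)        # one shard per replica
reg_loss = 0.25                                # stand-in regularization penalty

# What each replica contributes, then the total after a SUM across replicas.
per_replica = [average_loss(s, global_batch_size)
               + scale_regularization(reg_loss, num_replicas) for s in shards]
total = sum(per_replica)

print(np.isclose(total, losses.mean() + reg_loss))   # True
```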

 

In [0]:
with strategy.scope():
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE)

  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss=per_example_loss, 
                                      global_batch_size=GLOBAL_BATCH_SIZE)

# Define the Metrics

In [0]:
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')

# Training Loop

The model and optimizer must be created under `strategy.scope()` so that their variables are mirrored on every replica.

In [0]:
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  ckpts = tf.train.Checkpoint(optimizer=optimizer, model=model)

Define the train and test steps.

In [0]:
with strategy.scope():
  def train_step(inputs):
    images, labels = inputs

    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = compute_loss(labels, predictions)
    
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_accuracy.update_state(labels, predictions)
    return loss

  def test_step(inputs):
    images, labels = inputs

    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss.update_state(t_loss)
    test_accuracy.update_state(labels, predictions)

In [24]:
with strategy.scope():
  @tf.function
  def distributed_train_step(dataset_inputs):
    # In TensorFlow 2.2 and later, `experimental_run_v2` is renamed `strategy.run`.
    per_replica_losses = strategy.experimental_run_v2(train_step, 
                                                      args=(dataset_inputs, ))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

  @tf.function
  def distributed_test_step(dataset_inputs):
    return strategy.experimental_run_v2(test_step, args=(dataset_inputs, ))

  for epoch in range(EPOCHS):
    # train loop
    total_loss = 0.0
    num_batches = 0
    for x in train_dist_datatset:
      total_loss += distributed_train_step(x)
      num_batches += 1
    train_loss = total_loss / num_batches

    # test loop
    for x in test_dist_dataset:
      distributed_test_step(x)

    if epoch % 2 == 0:
      ckpts.save(ckpt_prefix)

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    print(template.format(epoch+1, train_loss, 
                          train_accuracy.result() * 100, 
                          test_loss.result(), 
                          test_accuracy.result() * 100))
    
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1, Loss: 0.41324031352996826, Accuracy: 85.2550048828125, Test Loss: 0.3330850601196289, Test Accuracy: 87.69000244140625
INFO:tensorflow:Reduce to /job:l

# Restore the latest checkpoint and test

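A checkpoint saved under a strategy can be restored with or without a strategy. Here, the latest checkpoint is restored into a new model created outside the strategy scope and evaluated on the test set.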

In [0]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_imgs, test_labels)).batch(GLOBAL_BATCH_SIZE)

In [0]:
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

In [28]:
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(ckpt_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print("Restoring the model without a strategy, the accuracy is {}.".format(
  eval_accuracy.result() * 100))

Restoring the model without a strategy, the accuracy is 91.11000061035156.


# Iterating Over a Dataset

## Using Iterators

In [29]:
with strategy.scope():
  for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    # A fresh iterator each epoch; only 10 steps are run per epoch.
    train_iter = iter(train_dist_datatset)

    for _ in range(10):
      total_loss += distributed_train_step(next(train_iter))
      num_batches += 1
    average_train_loss = total_loss / num_batches

    template = "Epoch {}, Loss {}, Accuracy {}"
    print(template.format(epoch+1, average_train_loss, train_accuracy.result()*100))

    train_accuracy.reset_states()

Epoch 1, Loss 0.07053428143262863, Accuracy 97.96875
Epoch 2, Loss 0.05302988365292549, Accuracy 98.28125
Epoch 3, Loss 0.05932199954986572, Accuracy 98.59375
Epoch 4, Loss 0.06402073800563812, Accuracy 96.875
Epoch 5, Loss 0.05695180967450142, Accuracy 98.125
Epoch 6, Loss 0.06418366730213165, Accuracy 96.5625
Epoch 7, Loss 0.03831333667039871, Accuracy 98.125
Epoch 8, Loss 0.060050249099731445, Accuracy 97.96875
Epoch 9, Loss 0.06483839452266693, Accuracy 97.5
Epoch 10, Loss 0.059599876403808594, Accuracy 97.5
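Calling `iter()` inside the epoch loop creates a fresh iterator every epoch, so each epoch consumes only the first 10 batches of a new pass over the (reshuffled) dataset. The plain-Python analogy below uses a list of strings as a hypothetical stand-in for the batched dataset; real `tf.data` iterators additionally reshuffle and prefetch. It contrasts a fresh iterator per epoch with a single iterator kept across epochs:

```python
from itertools import islice

batches = [f"batch-{i}" for i in range(100)]   # stand-in for a batched dataset
STEPS_PER_EPOCH = 10

# Fresh iterator each epoch: every epoch restarts at the first batch.
seen = []
for epoch in range(3):
    it = iter(batches)
    seen.append([next(it) for _ in range(STEPS_PER_EPOCH)])

# One iterator kept across epochs: later epochs continue where the last stopped.
it = iter(batches)
continued = [list(islice(it, STEPS_PER_EPOCH)) for _ in range(3)]

print(seen[1][0], continued[1][0])   # batch-0 batch-10
```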


## Iterating inside a `tf.function`

In [30]:
with strategy.scope():
  @tf.function
  def distributed_train_epoch(dataset):
    total_loss = 0.0
    num_batches = 0
    for x in dataset:
      per_replica_losses = strategy.experimental_run_v2(train_step, args=(x, ))
      total_loss += strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
      num_batches += 1
    return total_loss / tf.cast(num_batches, tf.float32)

  for epoch in range(EPOCHS):
    train_loss = distributed_train_epoch(train_dist_datatset)

    template = ("Epoch {}, Loss: {}, Accuracy: {}")
    print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))

    train_accuracy.reset_states()

Epoch 1, Loss: 0.06267233192920685, Accuracy: 97.68999481201172
Epoch 2, Loss: 0.05286860093474388, Accuracy: 98.04499816894531
Epoch 3, Loss: 0.04505863040685654, Accuracy: 98.35333251953125
Epoch 4, Loss: 0.037907831370830536, Accuracy: 98.63333129882812
Epoch 5, Loss: 0.03564491868019104, Accuracy: 98.66832733154297
Epoch 6, Loss: 0.02910824678838253, Accuracy: 98.94000244140625
Epoch 7, Loss: 0.027507269755005836, Accuracy: 98.99832916259766
Epoch 8, Loss: 0.028171533718705177, Accuracy: 99.05500030517578
Epoch 9, Loss: 0.021189115941524506, Accuracy: 99.27999877929688
Epoch 10, Loss: 0.021836640313267708, Accuracy: 99.24666595458984
