In [1]:
# https://www.tensorflow.org/tutorials/distribute/custom_training

``` text
This tutorial demonstrates how to use tf.distribute.Strategy with custom training loops. We will train a simple CNN model on the fashion MNIST dataset. The fashion MNIST dataset contains 60000 train images of size 28 x 28 and 10000 test images of size 28 x 28.

We are using custom training loops to train our model because they give us flexibility and a greater control on training. Moreover, it is easier to debug the model and the training loop.
```

In [2]:
# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)

2.7.0


# Download the fashion MNIST dataset

In [3]:
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None] # adding channel dimension for a convolutional layer.
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

## Create a strategy to distribute the variables and the graph


How does tf.distribute.MirroredStrategy strategy work?

* All the variables and the model graph is replicated on the replicas.
* Input is evenly distributed across the replicas.
* Each replica calculates the loss and gradients for the input it received.
* The gradients are synced across all the replicas by summing them.
* After the sync, the same update is made to the copies of the variables on each replica.

In [4]:
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

2021-09-15 11:42:50.987663: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 47220 MB memory:  -> device: 0, name: Quadro RTX 8000, pci bus id: 0000:67:00.0, compute capability: 7.5
2021-09-15 11:42:50.989091: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 46874 MB memory:  -> device: 1, name: Quadro RTX 8000, pci bus id: 0000:68:00.0, compute capability: 7.5


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


In [5]:
print('Numer of devices: {}'.format(strategy.num_replicas_in_sync))

Numer of devices: 2


## Set up the input pipeline

In [6]:
BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

Create the datasets and distribute them:

In [7]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

2021-09-15 11:45:37.649095: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:705] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}

2021-09-15 11:45:37.670862: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:705] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source

## Create the model

Create a model using tf.keras.Sequential. You can also use the Model Subclassing API to do this.

In [8]:
def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model

In [9]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function

In [12]:
with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

## Define the metrics to track loss and accuracy

In [13]:
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Redu

## Training loop

In [14]:
# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [15]:
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)

In [16]:
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()

INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1


2021-09-15 12:12:43.807069: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8202
2021-09-15 12:12:44.798571: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8202


INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
Epoch 1, Loss: 0.5525131821632385, Accuracy: 80.1866683959961, Test Loss: 0.4275776445865631, Test Accuracy: 84.41999816894531
Epoch 2, Loss: 0.3616889417171478, Accuracy: 87.1383285522461, Test Loss: 0.3542369604110718, Test Accuracy: 87.2699966430664
Epoch 3, Loss: 0.31681355834007263, Accuracy: 88.538330078125, Test Loss: 0.3333788812160492, Test Accuracy: 88.13999938964844
Epoch 4, Loss: 0.2883014380931854, Accuracy: 89.60833740234375, Test Loss: 0.3262891471385956, Test Accuracy: 88.36000061035156
Epoch 5, Loss: 0.26733988523483276, Accuracy: 90.2683334350586, Test Loss: 0.3004698157310486, Test Accuracy: 89.1300048828125
Epoch 6, Loss: 0.2474144548177719, Accuracy: 90.94499969482422, Test Loss: 0.30530864000320435, Test Accuracy: 89.0999984741211
Epoch 7, Loss: 0.23352253437042236, Accuracy: 91.38166809082031, Test Loss: 0.27879011631011963, Test Accuracy: 89.95000457763672
Epoch 8, Loss: 0.21598

Things to note in the example above:

* We are iterating over the train_dist_dataset and test_dist_dataset using a for x in ... construct.

* The scaled loss is the return value of the distributed_train_step. This value is aggregated across replicas using the tf.distribute.Strategy.reduce call and then across batches by summing the return value of the tf.distribute.Strategy.reduce calls.

* tf.keras.Metrics should be updated inside train_step and test_step that gets executed by tf.distribute.Strategy.run. *tf.distribute.Strategy.run returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can do tf.distribute.Strategy.reduce to get an aggregated value. You can also do tf.distribute.Strategy.experimental_local_results to get the list of values contained in the result, one per local replica.


## Restore the latest checkpoint and test

In [18]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='eval_accuracy')
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

In [19]:
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

In [20]:
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))

Accuracy after restoring the saved model without strategy: 90.55000305175781


## Alternate ways of iterating over a dataset

In [25]:
for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()


Epoch 10, Loss: 0.16488070785999298, Accuracy: 93.828125
Epoch 10, Loss: 0.1774645745754242, Accuracy: 94.0625
Epoch 10, Loss: 0.14581023156642914, Accuracy: 94.21875
Epoch 10, Loss: 0.15926451981067657, Accuracy: 94.140625
Epoch 10, Loss: 0.16859349608421326, Accuracy: 93.515625
Epoch 10, Loss: 0.1727086901664734, Accuracy: 94.453125
Epoch 10, Loss: 0.17819590866565704, Accuracy: 94.21875
Epoch 10, Loss: 0.17876367270946503, Accuracy: 94.0625
Epoch 10, Loss: 0.1698833405971527, Accuracy: 94.140625
Epoch 10, Loss: 0.18342843651771545, Accuracy: 93.359375


### Iterating inside a tf.function

```
You can also iterate over the entire input train_dist_dataset inside a tf.function using the for x in ... construct or by creating iterators like we did above. The example below demonstrates wrapping one epoch of training in a tf.function and iterating over train_dist_dataset inside the function.
```

In [27]:
for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()

Epoch 10, Loss: 0.1833042949438095, Accuracy: 94.140625
Epoch 10, Loss: 0.16555854678153992, Accuracy: 94.53125
Epoch 10, Loss: 0.1660725176334381, Accuracy: 93.828125
Epoch 10, Loss: 0.15906807780265808, Accuracy: 93.90625
Epoch 10, Loss: 0.15899965167045593, Accuracy: 94.6875
Epoch 10, Loss: 0.1722952276468277, Accuracy: 94.53125
Epoch 10, Loss: 0.1721753180027008, Accuracy: 93.359375
Epoch 10, Loss: 0.16018742322921753, Accuracy: 94.453125
Epoch 10, Loss: 0.1922616958618164, Accuracy: 93.046875
Epoch 10, Loss: 0.19559378921985626, Accuracy: 92.578125
