In [1]:
import shutil
import tensorflow as tf

In [2]:
# Fashion-MNIST loader bundled with Keras; load_data() yields a (train, test)
# pair, each of which is an (images, labels) tuple.
fashion_mnist = tf.keras.datasets.fashion_mnist

train_split, test_split = fashion_mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

In [3]:
def _add_channel_and_rescale(images):
    """Append a trailing axis, cast to float32 and divide by 255.0."""
    with_channel = tf.expand_dims(images, axis=-1)
    return tf.cast(with_channel, tf.float32) / 255.0


# NOTE: this cell overwrites its own inputs, so re-running it without first
# re-running the data-loading cell appends a second axis and rescales again.
train_images = _add_channel_and_rescale(train_images)
test_images = _add_channel_and_rescale(test_images)

## Create a strategy to distribute the variables and the graph

How does `tf.distribute.MirroredStrategy` strategy work?

- All the variables and the model graph are replicated on the replicas.
- Input is evenly distributed across the replicas.
- Each replica calculates the loss and gradients for the input it received.
- The gradients are synced across all the replicas by summing them.
- After the sync, the same update is made to the copies of the variables on each replica.

In [4]:
# Distribution strategy shared by every `strategy.scope()`, `strategy.run`
# and `strategy.reduce` call in the rest of the notebook.
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


## Setup input pipeline

In [5]:
# Shuffle buffer size: one slot per training example.
BUFFER_SIZE = len(train_images)

# BATCH_SIZE is the GLOBAL batch size: the per-replica size multiplied by the
# number of replicas the strategy keeps in sync.
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Number of epochs run by each of the training loops below.
EPOCHS = 10

In [6]:
# Training pipeline: slice -> shuffle over the full buffer -> global batches.
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

# Test pipeline: same batching, no shuffling.
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
test_dataset = test_dataset.batch(BATCH_SIZE)

# Hand both datasets to the strategy so each global batch is split across replicas.
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

## Create the model

In [7]:
def create_model():
    """Build a fresh, uncompiled convolutional classifier.

    Two Conv2D (3x3, ReLU) + MaxPool2D stages with 32 and then 64 filters feed
    a 64-unit ReLU Dense layer followed by a 10-unit Dense layer. The last
    layer has no activation, so the model emits raw logits.

    Returns:
        A new `tf.keras.Sequential` instance on every call.
    """
    conv_stages = [
        tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu'),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Conv2D(filters=64, kernel_size=3, activation='relu'),
        tf.keras.layers.MaxPool2D(),
    ]
    dense_head = [
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=64, activation='relu'),
        tf.keras.layers.Dense(units=10),
    ]
    return tf.keras.Sequential(conv_stages + dense_head)

## Define the loss function

Normally, on a single machine with 1 GPU/CPU, loss is divided by the number of examples in the batch of input.

So, how should the loss be calculated when using a `tf.distribute.Strategy`?

- For example, let's say you have 4 GPUs and a global batch size of 64. One batch of input is distributed across the replicas (4 GPUs), each replica getting an input of size 16.
- The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by the GLOBAL_BATCH_SIZE (64).

Why do this?

- This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by summing them.

How to do this in TensorFlow?

- If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the GLOBAL_BATCH_SIZE: `scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)`. Alternatively, you can use `tf.nn.compute_average_loss`, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.
- If you are using regularization losses in your model then you need to scale the loss value by number of replicas. You can do this by using the `tf.nn.scale_regularization_loss` function.
- Using `tf.reduce_mean` is __not recommended__. Doing so divides the loss by actual per replica batch size which may vary step to step.
- This reduction and scaling is done automatically in Keras `model.compile` and `model.fit`.
- If using `tf.keras.losses` classes (as in the example below), the loss reduction needs to be explicitly specified to be one of `NONE` or `SUM`. `AUTO` and `SUM_OVER_BATCH_SIZE` are disallowed when used with `tf.distribute.Strategy`. `AUTO` is disallowed because the user should explicitly think about what reduction they want to make sure it is correct in the distributed case. `SUM_OVER_BATCH_SIZE` is disallowed because currently it would only divide by the per-replica batch size and leave the division by the number of replicas to the user, which might be easy to miss. So instead we ask the user to do the reduction themselves explicitly.

In [8]:
with strategy.scope():
    # reduction=NONE keeps one loss value per example so that the scaling by
    # the global batch size in compute_loss is explicit.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE
    )

    def compute_loss(labels, predictions):
        """Return this replica's contribution to the global-batch mean loss.

        Args:
            labels: class labels for the examples seen by this replica.
            predictions: the model's logits for the same examples.

        Returns:
            The sum of the per-example losses divided by the GLOBAL batch
            size (`BATCH_SIZE`).
        """
        per_example_loss = loss_object(labels, predictions)
        # Divide by the global batch size, not BATCH_SIZE_PER_REPLICA: the
        # per-replica results are later combined with a SUM, so dividing by
        # the per-replica size would inflate the loss (and the gradients) by
        # a factor of strategy.num_replicas_in_sync on multi-replica setups.
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=BATCH_SIZE)

## Define the metrics to track loss and accuracy

In [9]:
# Metrics are created inside the strategy scope, like the loss, model and
# optimizer. Training loss has no metric object here: the training loops
# accumulate it by hand.
with strategy.scope():
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    # Metric name fixed (was misspelled 'train_accuuracy').
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

## Training loop

In [10]:
# The model and optimizer are created inside the strategy scope, as were the
# loss and the metrics in the cells above.
with strategy.scope():
    model = create_model()
    optimizer = tf.keras.optimizers.Adam()
    # One checkpoint object tracks both, so model and optimizer are saved together.
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [11]:
def train_step(inputs):
    """Run one optimisation step on the batch handed to this replica.

    Args:
        inputs: an `(images, labels)` pair.

    Returns:
        The value `compute_loss` produced for this batch. The caller is
        responsible for combining the values returned by the replicas.
    """
    batch_images, batch_labels = inputs

    # Record the forward pass so gradients can be taken w.r.t. the weights.
    with tf.GradientTape() as tape:
        logits = model(batch_images, training=True)
        step_loss = compute_loss(batch_labels, logits)

    weights = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(step_loss, weights), weights))

    train_accuracy.update_state(batch_labels, logits)
    return step_loss

def test_step(inputs):
    """Evaluate one batch and fold the results into the test metrics.

    Args:
        inputs: an `(images, labels)` pair.

    Returns:
        None. `test_loss` and `test_accuracy` are updated as a side effect.
    """
    batch_images, batch_labels = inputs
    logits = model(batch_images, training=False)

    # loss_object is called directly here (not compute_loss); its output is
    # handed to the `test_loss` Mean metric.
    test_loss.update_state(loss_object(batch_labels, logits))
    test_accuracy.update_state(batch_labels, logits)

In [12]:
# 'run' replicates the provided computation and runs it with the distributed input.

@tf.function
def distributed_train_step(dataset_inputs):
    """Run `train_step` through the strategy and SUM-reduce the returned losses.

    Args:
        dataset_inputs: one element of the distributed training dataset.

    Returns:
        The losses returned by the replicas, combined with `ReduceOp.SUM`.
    """
    replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    summed_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM,
        replica_losses,
        axis=None,
    )
    return summed_loss

@tf.function
def distributed_test_step(dataset_inputs):
    """Run `test_step` through the strategy on one distributed batch.

    Args:
        dataset_inputs: one element of the distributed test dataset.

    Returns:
        Whatever `strategy.run` returns for `test_step`.
    """
    step_result = strategy.run(test_step, args=(dataset_inputs,))
    return step_result

template = ("Epoch {:2d}, Loss: {:.4f}, Accuracy: {:.2%}, Test Loss: {:.4f}, Test Accuracy: {:.2%}")

for epoch in range(EPOCHS):
    # Training pass: add up the per-step losses and count the steps.
    total_loss = 0.0
    num_batches = 0
    for train_batch in train_dist_dataset:
        total_loss += distributed_train_step(train_batch)
        num_batches += 1
    train_loss = total_loss / num_batches

    # Evaluation pass: results accumulate in test_loss / test_accuracy.
    for test_batch in test_dist_dataset:
        distributed_test_step(test_batch)

    # Save a checkpoint on every even-numbered (0-based) epoch.
    if not epoch % 2:
        checkpoint.save('./training_checkpoints/ckpt')

    print(template.format(
        epoch + 1,
        train_loss,
        train_accuracy.result(),
        test_loss.result(),
        test_accuracy.result(),
    ))

    # Clear the metrics so the next epoch starts from scratch.
    for metric in (train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch  1, Loss: 0.5006, Accuracy: 82.23%, Test Loss: 0.3768, Test Accuracy: 86.72%
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then b

## Restore the latest checkpoint and test

In [34]:
# Accuracy metric for the restored model; created outside any strategy scope.
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='eval_accuracy')

# Fresh model and optimizer (also outside strategy.scope()) that will receive
# the checkpointed state.
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

# Plain, non-distributed test dataset for the evaluation loop.
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(BATCH_SIZE)

In [35]:
@tf.function
def eval_step(images, labels):
    """Score one batch with `new_model` and update `eval_accuracy`.

    Args:
        images: a batch of input images.
        labels: the matching class labels.
    """
    logits = new_model(images, training=False)
    eval_accuracy.update_state(labels, logits)

In [37]:
# Resolve the newest checkpoint first. tf.train.latest_checkpoint returns None
# when the directory holds no checkpoint, and restore(None) restores nothing,
# so without this guard the cell would report the accuracy of an untrained
# model while labelling it as "restored".
latest_checkpoint_path = tf.train.latest_checkpoint('./training_checkpoints')
if latest_checkpoint_path is None:
    raise FileNotFoundError(
        "No checkpoint found in './training_checkpoints'; run the training loop first."
    )

checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(latest_checkpoint_path)

# Evaluate the restored model on the plain (non-distributed) test dataset.
for images, labels in test_dataset:
    eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {:.2%}'.format(eval_accuracy.result()))

Accuracy after restoring the saved model without strategy: 91.04%


## Alternate ways of iterating over a dataset

### Use iterator

In [39]:
# Only this many batches are pulled from a fresh iterator in each epoch.
STEPS_PER_EPOCH = 10

template = ("Epoch {:2d}, Loss: {:.4f}, Accuracy: {:.2%}")

for epoch in range(EPOCHS):
    train_iter = iter(train_dist_dataset)

    total_loss = 0.0
    num_batches = 0
    while num_batches < STEPS_PER_EPOCH:
        # Pull batches explicitly with next() instead of a for-loop over the dataset.
        total_loss += distributed_train_step(next(train_iter))
        num_batches += 1
    average_train_loss = total_loss / num_batches

    print(template.format(epoch+1, average_train_loss, train_accuracy.result()))
    train_accuracy.reset_states()

Epoch  1, Loss: 0.1528, Accuracy: 95.16%
Epoch  2, Loss: 0.1205, Accuracy: 95.47%
Epoch  3, Loss: 0.1164, Accuracy: 95.63%
Epoch  4, Loss: 0.0980, Accuracy: 96.25%
Epoch  5, Loss: 0.1277, Accuracy: 95.47%
Epoch  6, Loss: 0.1487, Accuracy: 94.38%
Epoch  7, Loss: 0.1313, Accuracy: 95.31%
Epoch  8, Loss: 0.1460, Accuracy: 94.69%
Epoch  9, Loss: 0.1097, Accuracy: 96.72%
Epoch 10, Loss: 0.0970, Accuracy: 96.09%


### Iterate inside a tf.function

In [15]:
@tf.function
def distributed_train_epoch(dataset):
    """Run one full pass over `dataset` inside a single tf.function.

    Args:
        dataset: the distributed training dataset to iterate over.

    Returns:
        The SUM-reduced step losses averaged over the number of batches seen.
    """
    loss_sum = 0.0
    batch_count = 0
    for batch in dataset:
        replica_losses = strategy.run(train_step, args=(batch,))
        step_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, replica_losses, axis=None)
        loss_sum += step_loss
        batch_count += 1
    # The counter is an integer; cast it before dividing the float loss by it.
    return loss_sum / tf.cast(batch_count, tf.float32)

# Each call runs a whole epoch inside the tf.function defined above.
for epoch in range(EPOCHS):
    train_loss = distributed_train_epoch(train_dist_dataset)

    template = ('Epoch {:2d}, Loss: {:.4f}, Accuracy: {:.2%}')
    # Report 1-based epoch numbers, consistent with the other training loops
    # in this notebook (which print epoch+1).
    print(template.format(epoch + 1, train_loss, train_accuracy.result()))
    train_accuracy.reset_states()

Epoch  0, Loss: 0.0736, Accuracy: 97.17%
Epoch  1, Loss: 0.0679, Accuracy: 97.53%
Epoch  2, Loss: 0.0605, Accuracy: 97.73%
Epoch  3, Loss: 0.0570, Accuracy: 97.91%
Epoch  4, Loss: 0.0542, Accuracy: 97.97%
Epoch  5, Loss: 0.0453, Accuracy: 98.34%
Epoch  6, Loss: 0.0450, Accuracy: 98.34%
Epoch  7, Loss: 0.0398, Accuracy: 98.55%
Epoch  8, Loss: 0.0409, Accuracy: 98.46%
Epoch  9, Loss: 0.0331, Accuracy: 98.78%


In [13]:
# Remove the checkpoints written during training. ignore_errors=True keeps the
# cell re-runnable: without it, rmtree raises FileNotFoundError when the
# directory has already been deleted (or was never created).
shutil.rmtree('./training_checkpoints/', ignore_errors=True)