In [1]:
import os
import json
import tensorflow as tf
import numpy as np

In [5]:
def mnist_dataset(batch_size):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # Divide by a float32 scalar so the images stay float32 rather than float64.
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset

## Build the Keras model

In [6]:
def create_model():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=[28, 28]),
        tf.keras.layers.Reshape(target_shape=[28, 28, 1]),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    return model

In [4]:
per_worker_batch_size = 64
single_worker_dataset = mnist_dataset(per_worker_batch_size)
single_worker_model = create_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fe0932eac90>

## Multi-worker Configuration

Now let's enter the world of multi-worker training. In TensorFlow, the `TF_CONFIG` environment variable is required for training on multiple machines, each of which may have a different role. `TF_CONFIG` is a JSON string that specifies the cluster configuration on each worker that is part of the cluster.

There are two components of `TF_CONFIG`: `cluster` and `task`. `cluster` provides information about the training cluster; it is a dict consisting of different types of jobs, such as `worker`. In multi-worker training, there is usually one `worker` that takes on a little more responsibility, such as saving checkpoints and writing summary files for TensorBoard, in addition to what a regular worker does. Such a worker is referred to as the _'chief'_ worker, and it is customary for the `worker` with `index` 0 to be appointed as the chief `worker` (in fact, this is how `tf.distribute.Strategy` is implemented). `task`, on the other hand, provides information about the current task. The first component, `cluster`, is the same for all workers, and the second component, `task`, is different on each worker and specifies the `type` and `index` of that worker.

In this example, we set the task `type` to `"worker"` and the task `index` to `0`. This means the machine with this setting is the first worker, which will be appointed as the chief worker and do more work than the other workers. Note that the other machines need the `TF_CONFIG` environment variable set as well, with the same `cluster` dict but a different task `type` or task `index`, depending on their roles.

For illustration purposes, this tutorial shows how one may set a `TF_CONFIG` with 2 workers on `localhost`. In practice, users would create multiple workers on external IP addresses/ports, and set `TF_CONFIG` on each worker appropriately.
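To make the "same `cluster`, different `task`" rule concrete, here is a small sketch that builds one `TF_CONFIG` string per worker from a list of addresses. `make_tf_configs` is a hypothetical helper, not a TensorFlow API, and it assumes a cluster made only of `worker` jobs, so worker 0 acts as the chief.

```python
import json


def make_tf_configs(worker_addresses):
    """Hypothetical helper: return one TF_CONFIG JSON string per worker.

    Every string carries the same `cluster` dict; only `task.index` differs.
    Assumes a cluster made only of `worker` jobs, so worker 0 acts as chief.
    """
    cluster = {'worker': list(worker_addresses)}
    return [
        json.dumps({'cluster': cluster, 'task': {'type': 'worker', 'index': i}})
        for i in range(len(worker_addresses))
    ]


# Addresses are "host:port" with no space after the colon.
configs = make_tf_configs(['localhost:12345', 'localhost:23456'])
# On machine i, set os.environ['TF_CONFIG'] = configs[i] before creating the strategy.
```

Generating all configs from one list keeps the `cluster` dict identical on every machine, which is easy to get wrong when each config is written by hand.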

In [3]:
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {
        'type': 'worker',
        'index': 0
    }
})

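## Choose strategy

Create the strategy only after `TF_CONFIG` has been set, because `MultiWorkerMirroredStrategy` reads the cluster configuration from it when the strategy is constructed. In the run captured below, the strategy cell (`In [2]`) was executed before the `TF_CONFIG` cell (`In [3]`), which is why the log reports a single-worker strategy.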

In [2]:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO


## Train the model with MultiWorkerMirroredStrategy

Note: Always pass the `steps_per_epoch` argument to `model.fit()`, since `MultiWorkerMirroredStrategy` does not support handling of the last partial batch. When `steps_per_epoch` is set, `model.fit()` does not create a new iterator from the input every epoch, but continues from wherever the last epoch ended. Hence, make sure to call `.repeat()` on the dataset so it has enough examples for N epochs.
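Because one iterator is shared across epochs, the dataset must supply `epochs * steps_per_epoch` global batches in total. The plain-Python check below is a back-of-the-envelope sketch; `needs_repeat` is a hypothetical helper, and it counts only full batches because the last partial batch is not supported.

```python
def needs_repeat(num_examples, global_batch_size, epochs, steps_per_epoch):
    """Hypothetical check: would a dataset without .repeat() run out of batches?"""
    # Only full batches count, since the last partial batch is not supported.
    full_batches = num_examples // global_batch_size
    # model.fit() keeps one iterator across epochs, so this many batches are drawn.
    batches_drawn = epochs * steps_per_epoch
    return batches_drawn > full_batches


# 60,000 MNIST training images, 64 examples per worker, 2 workers (global batch 128).
print(needs_repeat(60000, 64 * 2, epochs=3, steps_per_epoch=70))   # False: 210 <= 468
print(needs_repeat(60000, 64 * 2, epochs=3, steps_per_epoch=200))  # True: 600 > 468
```

Calling `.repeat()` with no argument, as `mnist_dataset` does, makes the dataset infinite, so this check never fails regardless of `epochs` and `steps_per_epoch`.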

In [15]:
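# Derive the number of workers from the cluster spec in TF_CONFIG
# (2 in this example) rather than hard-coding it.
num_workers = len(json.loads(os.environ['TF_CONFIG'])['cluster']['worker'])

# Here the batch size scales up by the number of workers,
# since `tf.data.Dataset.batch` expects the global batch size.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)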

with strategy.scope():
    multi_worker_model = create_model()
    
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fcee87d2490>

## Dataset sharding and batch size

In multi-worker training, sharding the data into multiple parts is needed to ensure convergence and performance. However, note that in the code snippet above, the datasets are passed directly to `model.fit()` without manual sharding; this is because the `tf.distribute.Strategy` API takes care of dataset sharding automatically in multi-worker training.

If you prefer to shard the data for your training manually, automatic sharding can be turned off via the `tf.data.experimental.DistributeOptions` API.

In [18]:
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
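With auto-sharding off, each worker has to select its own slice of the data. One way to do that is `tf.data.Dataset.shard`, using the worker count and task index from `TF_CONFIG`. The sketch below is one possible wiring, not the strategy's own mechanism: `shard_for_this_worker` is a hypothetical helper, it assumes a cluster made only of `worker` jobs, and it should be applied to the unbatched dataset, before shuffling and batching, so that workers get disjoint examples.

```python
import json
import os

import tensorflow as tf


def shard_for_this_worker(dataset, tf_config=None):
    """Hypothetical helper: turn off auto-sharding and shard by the task index."""
    if tf_config is None:
        tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
    # Assumes only `worker` jobs; without TF_CONFIG this is a single worker 0.
    num_workers = len(tf_config.get('cluster', {}).get('worker', [None]))
    worker_index = tf_config.get('task', {}).get('index', 0)

    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    # Keep the elements at positions i where i % num_workers == worker_index.
    return dataset.with_options(options).shard(num_workers, worker_index)


# Simulate the second of two workers without touching the real TF_CONFIG.
second_worker = {
    'cluster': {'worker': ['localhost:12345', 'localhost:23456']},
    'task': {'type': 'worker', 'index': 1},
}
sharded = shard_for_this_worker(tf.data.Dataset.range(8), second_worker)
print([int(x) for x in sharded.as_numpy_iterator()])  # [1, 3, 5, 7]
```

Sharding before `shuffle` matters: if each worker shuffled first with a different seed and then sharded, the workers could see overlapping examples and miss others.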

## Performance

- `MultiWorkerMirroredStrategy` provides multiple collective communication implementations. `RING` implements ring-based collectives using gRPC as the cross-host communication layer. `NCCL` uses NVIDIA's NCCL library to implement collectives. `AUTO` defers the choice to the runtime. The best choice of collective implementation depends on the number and kind of GPUs and the network interconnect in the cluster. To override the automatic choice, pass a valid value to the `communication` parameter of the `MultiWorkerMirroredStrategy` constructor, e.g. `communication=tf.distribute.experimental.CollectiveCommunication.NCCL`.
- Cast the variables to `tf.float32` where possible.

## Fault tolerance

Using `Keras` with `tf.distribute.Strategy` comes with the advantage of fault tolerance in cases where workers die or are otherwise unstable. This works by preserving the training state in the distributed file system of your choice, so that when an instance that previously failed or was preempted restarts, the training state is recovered.

If a worker gets preempted, the whole cluster pauses until the preempted worker is restarted. Once the worker rejoins the cluster, other workers will also restart. Now, every worker reads the checkpoint file that was previously saved and picks up its former state, thereby allowing the cluster to get back in sync. Then the training continues.
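With Keras and `tf.distribute.Strategy` this recovery happens inside the library. The single-process sketch below only illustrates the underlying idea, restoring the latest checkpoint on start-up and continuing from there, using `tf.train.Checkpoint` and `tf.train.CheckpointManager`. The step counter standing in for the model state, `run_with_resume`, and its `preempt_at` parameter are all hypothetical.

```python
import tempfile

import tensorflow as tf


def run_with_resume(ckpt_dir, total_steps, preempt_at=None):
    """Hypothetical loop that resumes from the latest checkpoint in `ckpt_dir`.

    A step counter stands in for the training state. Returns
    (step restored at start-up, step reached when the loop stops).
    """
    step = tf.Variable(0, dtype=tf.int64)
    checkpoint = tf.train.Checkpoint(step=step)
    manager = tf.train.CheckpointManager(checkpoint, ckpt_dir, max_to_keep=1)
    # On a fresh start latest_checkpoint is None and restore() loads nothing.
    checkpoint.restore(manager.latest_checkpoint)
    start = int(step.numpy())
    while int(step.numpy()) < total_steps:
        if preempt_at is not None and int(step.numpy()) == preempt_at:
            break  # simulated preemption: the process stops here
        step.assign_add(1)  # stand-in for one training step
        manager.save()      # persist the state after every step
    return start, int(step.numpy())


ckpt_dir = tempfile.mkdtemp()
print(run_with_resume(ckpt_dir, total_steps=5, preempt_at=3))  # (0, 3)
print(run_with_resume(ckpt_dir, total_steps=5))                # (3, 5)
```

The second call starts at step 3 instead of 0 because it finds the checkpoint written before the simulated preemption.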

In [20]:
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./training/ckpt')]
with strategy.scope():
    multi_worker_model = create_model()
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70, callbacks=callbacks)

Epoch 1/3
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: ./training/ckpt/assets
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fced02f2e50>

__Save/Restore outside the `ModelCheckpoint` callback__

In [24]:
# Saving a model
# `is_chief` inspects the cluster spec and current task in TF_CONFIG and
# returns True if this worker is the chief and False otherwise. Without an
# explicit 'chief' job, worker 0 is the chief.
def is_chief():
    tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
    task = tf_config.get('task', {})
    task_type, task_index = task.get('type', 'worker'), task.get('index', 0)
    if 'chief' in tf_config.get('cluster', {}):
        return task_type == 'chief'
    return task_type == 'worker' and task_index == 0

if is_chief():
    # The chief's model directory; ideally this is a cloud storage bucket.
    path = '/tmp/model_dir'
else:
    # Save to a path that is unique across workers.
    worker_id = json.loads(os.environ['TF_CONFIG'])['task']['index']
    path = '/tmp/model_dir/worker_tmp_' + str(worker_id)

multi_worker_model.save(path)

# Restoring a checkpoint
# On the chief
checkpoint = tf.train.Checkpoint(model=multi_worker_model)
manager = tf.train.CheckpointManager(
    checkpoint,
    directory=path,
    max_to_keep=5
)
# `latest_checkpoint` is None until `manager.save()` has written a checkpoint
# to `path`; restoring None loads nothing.
status = checkpoint.restore(manager.latest_checkpoint)

# On the workers
# This is the path that the chief saves the model to
model_dir_path = '/tmp/model_dir'
checkpoint = tf.train.Checkpoint(model=multi_worker_model)
latest_checkpoint = tf.train.latest_checkpoint(model_dir_path)
status = checkpoint.restore(latest_checkpoint)


INFO:tensorflow:Assets written to: /tmp/model_dir/assets
