# Multi-worker training with the Keras API
This tutorial demonstrates multi-worker distributed training with a Keras model using tf.distribute.MultiWorkerMirroredStrategy. With this strategy, a Keras model designed to run on a single worker can run on multiple workers with minimal code changes.

In [1]:
import json 
import os 
import sys 

Before importing TensorFlow, make a few changes to the environment.

Disable all GPUs. This prevents errors caused by the workers all trying to use the same GPU. For a real application each worker would be on a different machine.

In [2]:
# disable all GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

Reset the TF_CONFIG environment variable, which is configured later in this tutorial:

In [3]:
os.environ.pop('TF_CONFIG', None)


Make sure that the current directory is on Python's path. This allows the notebook to import the files written by %%writefile later.

In [4]:
if '.' not in sys.path:
  sys.path.insert(0, '.')

In [5]:
# now import tf 
import tensorflow as tf

## Dataset and model definition

Next, create an mnist.py file with a simple model and dataset setup. This Python file will be used by the worker processes in this tutorial:

In [6]:
%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # You need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset

def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model

Overwriting mnist.py


Try training the model for a small number of epochs and observe the results of a single worker to make sure everything works correctly. As training progresses, the loss should drop and the accuracy should increase.

In [7]:
import mnist

batch_size = 64
single_worker_dataset = mnist.mnist_dataset(batch_size)
single_worker_model = mnist.build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f881ddbc650>

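## Multi-worker configuration

The TF_CONFIG environment variable is required for training on multiple machines. It is a JSON string with two parts: cluster, which lists the address of every worker, and task, which identifies the current process by type and index.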

In [8]:
tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}

In [10]:
# tf_config as a json string 
json.dumps(tf_config)

'{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0}}'
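Every worker receives the same cluster description and differs only in its task index. As a minimal sketch of that idea, the hypothetical helper below (make_tf_configs is not part of TensorFlow) builds one TF_CONFIG JSON string per worker from a list of addresses:

```python
import json

def make_tf_configs(worker_addresses):
    """Return one TF_CONFIG JSON string per worker (hypothetical helper)."""
    cluster = {'worker': list(worker_addresses)}
    return [
        json.dumps({'cluster': cluster,
                    'task': {'type': 'worker', 'index': index}})
        for index in range(len(worker_addresses))
    ]

# Same cluster for everyone; only the task index changes.
configs = make_tf_configs(['localhost:12345', 'localhost:23456'])
for config in configs:
    print(config)
```

Each string would be exported as TF_CONFIG in the environment of the matching worker process before that process starts.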

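Next, create the MultiWorkerMirroredStrategy. Because TF_CONFIG is not yet set in this notebook process, the strategy starts in single-worker mode: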

In [11]:
strategy = tf.distribute.MultiWorkerMirroredStrategy()

INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO


In [12]:
with strategy.scope():
  multi_worker_model = mnist.build_and_compile_cnn_model()



To actually run with MultiWorkerMirroredStrategy you'll need to run worker processes and pass a TF_CONFIG to them.

Like the mnist.py file written earlier, here is the main.py that each of the workers will run:

In [13]:
%%writefile main.py 

import os 
import json 

import tensorflow as tf
import mnist

per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size*num_workers
multi_worker_dataset = mnist.mnist_dataset(global_batch_size)

with strategy.scope():
  multi_worker_model = mnist.build_and_compile_cnn_model()

multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

Writing main.py
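main.py scales the batch size by the number of workers so that each worker is intended to keep processing 64 examples per step. The sketch below reproduces only that arithmetic in plain Python; compute_global_batch_size is a hypothetical name and the cluster is the two-worker example from this tutorial:

```python
import json

def compute_global_batch_size(tf_config_json, per_worker_batch_size):
    """Multiply the per-worker batch size by the number of workers in TF_CONFIG."""
    tf_config = json.loads(tf_config_json)
    num_workers = len(tf_config['cluster']['worker'])
    return per_worker_batch_size * num_workers

example_config = json.dumps({
    'cluster': {'worker': ['localhost:12345', 'localhost:23456']},
    'task': {'type': 'worker', 'index': 0},
})

size = compute_global_batch_size(example_config, per_worker_batch_size=64)
examples_per_epoch = size * 70  # steps_per_epoch=70 as in main.py
print(size)                # 128
print(examples_per_epoch)  # 8960
```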


In [14]:
# json serialize the TF_CONFIG and add it to the environment variables
os.environ['TF_CONFIG'] = json.dumps(tf_config)

Now, you can launch a worker process that will run the main.py and use the TF_CONFIG:

In [15]:
# first kill any previous runs 

%killbgscripts

All background processes were killed.


In [None]:
%%bash --bg
# Run main.py in the background and log its output to job_0.log.
python main.py &> job_0.log

In [None]:
import time
time.sleep(10)

In [None]:
# inspect the log file
! cat job_0.log

The first worker is ready and is waiting for all the other workers to be ready to proceed.
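The fixed ten-second sleep is a guess at how long the first worker needs to start. One possible alternative, sketched here with a hypothetical wait_for_log helper, is to poll until the log file exists and contains some output:

```python
import os
import time

def wait_for_log(path, timeout_s=30.0, poll_s=0.5):
    """Return True once `path` exists and is non-empty, False on timeout."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if os.path.exists(path) and os.path.getsize(path) > 0:
            return True
        time.sleep(poll_s)
    return False
```

Calling wait_for_log('job_0.log') before reading the log would replace the sleep. This only shows that the worker has written something, not that it has finished initializing.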

In [None]:
# update the tf_config for the second worker's process to pick up
tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)

Now launch the second worker. This will start the training since all the workers are active (so there's no need to background this process):

In [None]:
# launching second worker 
! python main.py 

This run is slower than the single-worker test run at the start of this tutorial, because running several workers on one machine only adds overhead. Unset TF_CONFIG and kill the background processes so that they do not affect the following sections:

In [None]:
os.environ.pop('TF_CONFIG', None)
%killbgscripts

## Dataset Sharding 
In multi-worker training, dataset sharding is needed to ensure convergence and performance.

The example in the previous section relies on the default autosharding provided by the tf.distribute.Strategy API. You can control the sharding by setting the tf.data.experimental.AutoShardPolicy of the tf.data.experimental.DistributeOptions. To learn more about auto-sharding see the Distributed input guide.

Here is a quick example of how to turn OFF the auto sharding, so each replica processes every example (not recommended):

In [None]:
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF


global_batch_size = 64
multi_worker_dataset = mnist.mnist_dataset(batch_size=global_batch_size)
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
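To see what sharding changes, the following plain-Python illustration (not TensorFlow's implementation) splits ten examples across two workers by element index, which is similar in spirit to sharding by data, and compares that with the unsharded case where every worker sees every example. The shard function is a hypothetical stand-in:

```python
def shard(examples, num_workers, worker_index):
    """Keep every num_workers-th example, starting at worker_index."""
    return examples[worker_index::num_workers]

examples = list(range(10))
num_workers = 2

# Sharded: each worker gets a disjoint slice of the data.
sharded = [shard(examples, num_workers, i) for i in range(num_workers)]
# Not sharded: every worker processes the full dataset.
unsharded = [list(examples) for _ in range(num_workers)]

print(sharded)    # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
print(unsharded)  # two full copies of 0..9
```

With sharding off, the two workers together process 20 examples per pass instead of 10, which is the duplicated work the "not recommended" note refers to.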

## Model saving 



In [None]:
model_path = '/tmp/keras_model'

def _is_chief(task_type, task_id):
  # This notebook process runs without TF_CONFIG, so task_type is None.
  # Treat that single-worker case as chief, alongside worker 0.

  return (task_type == 'worker' and task_id == 0) or task_type is None

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir 

def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)

write_model_path = write_filepath(model_path, task_type, task_id)
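To make the effect of write_filepath concrete, the sketch below reimplements the same logic under different names so that it runs without TensorFlow. It substitutes os.makedirs for tf.io.gfile.makedirs and a temporary directory for /tmp; both substitutions are assumptions made for illustration:

```python
import os
import tempfile

def demo_is_chief(task_type, task_id):
    # Worker 0 is the chief; task_type None is the single-worker case.
    return task_type is None or (task_type == 'worker' and task_id == 0)

def demo_write_filepath(filepath, task_type, task_id):
    dirpath, base = os.path.split(filepath)
    if not demo_is_chief(task_type, task_id):
        # Non-chief workers write into their own temporary subdirectory.
        dirpath = os.path.join(dirpath, 'workertemp_' + str(task_id))
        os.makedirs(dirpath, exist_ok=True)
    return os.path.join(dirpath, base)

root = tempfile.mkdtemp()
demo_model_path = os.path.join(root, 'keras_model')

chief_path = demo_write_filepath(demo_model_path, 'worker', 0)
worker1_path = demo_write_filepath(demo_model_path, 'worker', 1)
local_path = demo_write_filepath(demo_model_path, None, None)

print(chief_path)    # <root>/keras_model
print(worker1_path)  # <root>/workertemp_1/keras_model
print(local_path)    # <root>/keras_model
```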

In [None]:
# model saving 
multi_worker_model.save(write_model_path)

Every worker takes part in the save, but the model should later be loaded only from the path the chief saved to, so remove the temporary copies that the non-chief workers wrote:

In [None]:
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(os.path.dirname(write_model_path))

When it is time to load, use the tf.keras.models.load_model API and continue with further work. This example assumes that a single worker loads the model and continues training, in which case you do not call tf.keras.models.load_model inside another strategy.scope().

In [None]:
loaded_model = tf.keras.models.load_model(model_path)

loaded_model.fit(single_worker_dataset, epochs = 2, steps_per_epoch = 20)

## Checkpoint saving and restoring

Checkpointing, on the other hand, lets you save the model's weights and restore them without saving the whole model. Here, you create one tf.train.Checkpoint that tracks the model and is managed by a tf.train.CheckpointManager, so that only the latest checkpoint is preserved.

In [None]:
checkpoint_dir = '/tmp/ckpt'

checkpoint = tf.train.Checkpoint(model = multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory = write_checkpoint_dir, max_to_keep = 1
)

Once the CheckpointManager is set up, you are ready to save and then remove the checkpoints that the non-chief workers wrote.

In [None]:
checkpoint_manager.save()
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(write_checkpoint_dir)

Now, when you need to restore, you can find the latest checkpoint saved using the convenient tf.train.latest_checkpoint function. After restoring the checkpoint, you can continue with training.

In [None]:
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs = 2, steps_per_epoch = 20)
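The save, retain, and restore cycle can also be tried in isolation. The sketch below assumes TensorFlow 2.x in eager mode and tracks a single counter variable in place of the model; the step variable and the temporary directory are hypothetical stand-ins, not part of the tutorial's setup:

```python
import tempfile
import tensorflow as tf

# A counter stands in for the model's weights in this sketch.
step = tf.Variable(0, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    ckpt, directory=tempfile.mkdtemp(), max_to_keep=1)

for _ in range(3):
    step.assign_add(1)
    manager.save()

# With max_to_keep=1 only the most recent checkpoint is retained.
kept = manager.checkpoints

# Overwrite the value, then restore it from the latest checkpoint.
step.assign(0)
ckpt.restore(manager.latest_checkpoint)
restored = int(step.numpy())
print(len(kept), restored)
```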

## Backup and restore callback 



In [None]:
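# BackupAndRestore saves the training state under backup_dir so that a job
# restarted after a failure can resume instead of starting over.
callbacks = [tf.keras.callbacks.experimental.BackupAndRestore(backup_dir = '/tmp/backup')]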

with strategy.scope():
  multi_worker_model = mnist.build_and_compile_cnn_model()

multi_worker_model.fit(multi_worker_dataset,
                       epochs = 3,
                       steps_per_epoch = 70,
                       callbacks = callbacks)