The `tf.distribute.Strategy` API provides an abstraction for distributing training across multiple processing units. The goal is to let users enable distributed training with their existing models and training code, with minimal changes.

This tutorial uses the tf.distribute.MirroredStrategy, which does in-graph replication with synchronous training on many GPUs on one machine. Essentially, it copies all of the model's variables to each processor. Then, it uses all-reduce to combine the gradients from all processors and applies the combined value to all copies of the model.
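To make the synchronous all-reduce step concrete, here is a small NumPy simulation. It is a sketch under simplifying assumptions, not TensorFlow's implementation. Two hypothetical replicas each compute a gradient on their own slice of the data. The gradients are combined with a plain mean, which stands in for the all-reduce. The same update is then applied to every copy of the weights.

```python
import numpy as np

def all_reduce_mean(per_replica_grads):
    # Stand-in for all-reduce: every replica ends up with the same combined gradient.
    return np.mean(np.stack(per_replica_grads), axis=0)

def synchronous_step(replica_weights, per_replica_grads, lr=0.1):
    # Apply the same combined gradient to every replica's copy of the weights.
    combined = all_reduce_mean(per_replica_grads)
    return [w - lr * combined for w in replica_weights]

# Two hypothetical replicas start from identical copies of the weights...
initial = np.array([1.0, -2.0])
replicas = [initial.copy(), initial.copy()]

# ...but see different slices of the batch, so their local gradients differ.
grads = [np.array([0.2, 0.4]), np.array([0.6, 0.0])]

replicas = synchronous_step(replicas, grads)
print(replicas[0], replicas[1])  # the two copies are still identical
```

Because every copy receives the same combined update, the replicas never drift apart. That is what "synchronous" training means here.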

In [1]:
import tensorflow_datasets as tfds 
import tensorflow as tf
import os 
print(tf.__version__)

2.5.0


Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in `tf.data` format. Setting `with_info` to `True` includes the metadata for the entire dataset, which is saved here to `info`. Among other things, this metadata object includes the number of train and test examples.

In [2]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...








Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


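Create a `MirroredStrategy` object. It handles the distribution and provides a context manager (`strategy.scope()`) inside which the model is built.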

In [3]:
strategy = tf.distribute.MirroredStrategy()









INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [4]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


Training on multiple GPUs lets you use the extra computing power by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly. The global batch size is the per-replica batch size multiplied by the number of replicas.

In [5]:
num_train_examples = info.splits['train'].num_examples 
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64 
BATCH_SIZE = BATCH_SIZE_PER_REPLICA*strategy.num_replicas_in_sync
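The arithmetic behind `BATCH_SIZE` is simple: every replica processes `BATCH_SIZE_PER_REPLICA` examples per step, so the global batch grows with the number of replicas. The sketch below tabulates this for a few hypothetical replica counts. It also applies linear learning-rate scaling, one common heuristic for the "tune the learning rate" step. That heuristic is an assumption here and is not used anywhere in this notebook.

```python
def global_batch_size(per_replica, num_replicas):
    # One synchronous step consumes `per_replica` examples on each replica.
    return per_replica * num_replicas

def linearly_scaled_lr(base_lr, base_batch, batch):
    # Heuristic (assumption, not used in this notebook): grow the learning
    # rate in proportion to the global batch size.
    return base_lr * batch / base_batch

BATCH_SIZE_PER_REPLICA = 64
BASE_LR = 1e-3  # Keras' default learning rate for Adam

for num_replicas in (1, 2, 4, 8):
    batch = global_batch_size(BATCH_SIZE_PER_REPLICA, num_replicas)
    lr = linearly_scaled_lr(BASE_LR, BATCH_SIZE_PER_REPLICA, batch)
    print(f'{num_replicas} replica(s): global batch {batch}, scaled LR {lr:g}')
```

With the single CPU replica used in this run, the global batch stays at 64 and the learning rate is unchanged.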

Pixel values, which range from 0 to 255, have to be normalized to the 0-1 range. Define this scaling in a function.

In [6]:
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255 

  return image, label

Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice we are also keeping an in-memory cache of the training data to improve performance.

In [7]:
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
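As a quick sanity check on the pipeline, the number of steps per epoch is the dataset size divided by the global batch size, rounded up. It rounds up because `batch()` keeps the last partial batch unless `drop_remainder=True`. The sketch assumes the standard MNIST split sizes (60,000 train and 10,000 test examples), which the notebook reads from `info.splits` at runtime.

```python
import math

# Assumed standard MNIST split sizes; the notebook reads them from info.splits.
NUM_TRAIN_EXAMPLES = 60000
NUM_TEST_EXAMPLES = 10000

def steps_per_epoch(num_examples, batch_size):
    # batch() keeps the final partial batch by default, so round up.
    return math.ceil(num_examples / batch_size)

for global_batch in (64, 256):  # 1 replica vs. a hypothetical 4 replicas
    print(global_batch,
          steps_per_epoch(NUM_TRAIN_EXAMPLES, global_batch),
          steps_per_epoch(NUM_TEST_EXAMPLES, global_batch))
```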

Now create and compile the model inside `strategy.scope()`, so that its variables are created under the `MirroredStrategy` defined above and mirrored across the replicas.

In [8]:
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
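To see how data flows through this model, the output shapes and parameter counts can be worked out by hand. The helper functions below are hypothetical. They only do the arithmetic for the Keras defaults used here: `'valid'` padding and stride 1 for `Conv2D`, and a 2x2 pool for `MaxPooling2D`. The total should agree with what `model.summary()` prints.

```python
def conv2d_valid(h, w, c_in, filters, kernel):
    # Conv2D with padding='valid' and stride 1 shrinks each spatial side by kernel - 1.
    out_shape = (h - kernel + 1, w - kernel + 1, filters)
    params = kernel * kernel * c_in * filters + filters  # kernel weights + biases
    return out_shape, params

def dense(n_in, units):
    # Fully connected layer: one weight per input/unit pair plus one bias per unit.
    return units, n_in * units + units

(h, w, c), conv_params = conv2d_valid(28, 28, 1, 32, 3)  # (26, 26, 32), 320 params
h, w = h // 2, w // 2                                      # MaxPooling2D(): (13, 13, 32), no params
flat = h * w * c                                           # Flatten(): 5408 values
hidden_units, hidden_params = dense(flat, 64)              # Dense(64): 346176 params
logit_units, logit_params = dense(hidden_units, 10)        # Dense(10): 650 params, raw logits

total_params = conv_params + hidden_params + logit_params
print(total_params)  # 347146
```

Almost all of the parameters sit in the first `Dense` layer, because it connects every one of the 5408 flattened values to each of its 64 units.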

## Define the callbacks
The callbacks used here are:

- TensorBoard: This callback writes a log for TensorBoard, which allows you to visualize the graphs.
- Model Checkpoint: This callback saves the model (here, only its weights) after every epoch.
- Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change at the beginning of every epoch.

For illustrative purposes, add a print callback to display the learning rate in the notebook.

In [9]:
# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'

# name of the checkpoint files 
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
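`ModelCheckpoint` substitutes the epoch number into the `{epoch}` placeholder of `checkpoint_prefix`, so training produces one set of checkpoint files per epoch. The sketch below mimics that formatting with plain `str.format`. It assumes a 1-based epoch number, which is consistent with the `ckpt_1` to `ckpt_12` files listed after training.

```python
import os

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Assumption: the callback formats the path with a 1-based epoch number.
EPOCHS = 12
paths = [checkpoint_prefix.format(epoch=epoch_index + 1) for epoch_index in range(EPOCHS)]

print(paths[0])   # first checkpoint prefix
print(paths[-1])  # last checkpoint prefix
```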

A single fixed learning rate rarely suits the whole training run. If it is too small, the model learns very slowly; if it is too large, training can become unstable or fail to converge. A common remedy is to start with a relatively large learning rate and decay it as training progresses, which is what the schedule below does.


In [11]:
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5
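`LearningRateScheduler` calls the schedule with a 0-based epoch index, so it helps to tabulate what `decay` returns for each of the 12 epochs trained below. The function is repeated so the snippet runs on its own. Epoch numbers in the table are 1-based to match the `Epoch N/12` lines of the training log.

```python
def decay(epoch):
    # Same step schedule as above; `epoch` is the 0-based index Keras passes in.
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5

EPOCHS = 12
# Key by 1-based epoch number to line up with the training log.
schedule = {epoch_index + 1: decay(epoch_index) for epoch_index in range(EPOCHS)}

for epoch_number, lr in schedule.items():
    print(epoch_number, lr)
```

This reproduces the pattern in the training log: epochs 1-3 at 1e-3, epochs 4-7 at 1e-4, and epochs 8-12 at 1e-5.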

In [12]:
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # self.model is set by Keras to the model being trained.
    print('\nLearning rate for epoch {} in {}'.format(epoch + 1,
                                                      self.model.optimizer.lr.numpy()))

In [13]:
callbacks = [
             tf.keras.callbacks.TensorBoard(log_dir = './logs'),
             tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                                save_weights_only = True),
             tf.keras.callbacks.LearningRateScheduler(decay),
             PrintLR()
             
]

## Train and evaluate 

In [14]:
model.fit(train_dataset, epochs= 12, callbacks=callbacks)

Epoch 1/12

Learning rate for epoch 1 in 0.0010000000474974513
Epoch 2/12

Learning rate for epoch 2 in 0.0010000000474974513
Epoch 3/12

Learning rate for epoch 3 in 0.0010000000474974513
Epoch 4/12

Learning rate for epoch 4 in 9.999999747378752e-05
Epoch 5/12

Learning rate for epoch 5 in 9.999999747378752e-05
Epoch 6/12

Learning rate for epoch 6 in 9.999999747378752e-05
Epoch 7/12

Learning rate for epoch 7 in 9.999999747378752e-05
Epoch 8/12

Learning rate for epoch 8 in 9.999999747378752e-06
Epoch 9/12

Learning rate for epoch 9 in 9.999999747378752e-06
Epoch 10/12

Learning rate for epoch 10 in 9.999999747378752e-06
Epoch 11/12

Learning rate for epoch 11 in 9.999999747378752e-06
Epoch 12/12

Learning rate for epoch 12 in 9.999999747378752e-06


<tensorflow.python.keras.callbacks.History at 0x7f314b028550>

The output shows the learning rate changing according to the schedule, and a checkpoint is saved at the end of every epoch. Callbacks make it easy to add this kind of behavior to the training loop.
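The long numbers in the training log, such as `0.0010000000474974513`, are not a bug in the schedule. They appear to be the scheduled values rounded to 32-bit floats (Keras stores the optimizer's learning rate as a float32 variable) and then printed at full double precision. A quick NumPy check, as a sketch:

```python
import numpy as np

# Round each scheduled value to float32, then show it as a Python (double) float.
for lr in (1e-3, 1e-4, 1e-5):
    print(lr, '->', float(np.float32(lr)))
```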

In [15]:
# checkpoints 

! ls {checkpoint_dir}

checkpoint		     ckpt_4.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_4.index
ckpt_10.index		     ckpt_5.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_5.index
ckpt_11.index		     ckpt_6.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_6.index
ckpt_12.index		     ckpt_7.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_7.index
ckpt_1.index		     ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index		     ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index


To see how the model performs, load the latest checkpoint and call `evaluate` on the test data.

In [16]:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

Eval loss: 0.038853537291288376, Eval Accuracy: 0.9872000217437744


The TensorBoard callback wrote its logs to `./logs`. To view them inside the notebook, load the TensorBoard extension and point it at that directory. From a terminal, run `tensorboard --logdir ./logs` instead.

In [17]:
%load_ext tensorboard
%tensorboard --logdir ./logs

In [18]:
! ls -sh ./logs

total 4.0K
4.0K train


In [19]:
# Exporting the model 

path = 'saved_model/'


In [20]:
model.save(path, save_format = 'tf')

INFO:tensorflow:Assets written to: saved_model/assets


In [21]:
# Now load the model without 'strategy'
unreplicated_model = tf.keras.models.load_model(path)
unreplicated_model.compile(
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer = tf.keras.optimizers.Adam(),
    metrics = ['accuracy']
)

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

Eval loss: 0.038853537291288376, Eval Accuracy: 0.9872000217437744


The saved model can also be loaded inside `strategy.scope()`, so that it can be trained again with distributed training.

In [22]:
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer = tf.keras.optimizers.Adam(),
                           metrics = ['accuracy'])
  
  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

Eval loss: 0.038853537291288376, Eval Accuracy: 0.9872000217437744


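There was not much difference here when training with distributed learning. Note, though, that this run used a single device (`Number of devices: 1`, CPU only), so `MirroredStrategy` had only one replica to work with. I expect the benefit to be much larger with multiple GPUs and bigger datasets.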