# Distributed training with Keras on Gradient

Last updated: Feb 08th 2022

This notebook presents the TensorFlow tutorial _Distributed training with Keras_ (https://www.tensorflow.org/tutorials/distribute/keras), slightly modified to run on Gradient.

The associated GitHub repository for this version is at https://github.com/gradient-ai/TensorFlow-Distribution-Strategies.

This notebook is designed to run on a multi-GPU machine, e.g., an A5000x2 Gradient instance.

If you ran other notebooks before this one and you get a GPU out-of-memory error, restart the Notebook instance to resolve it.

Extra setup for Gradient:

- We add one setup step so the notebook runs better in Gradient: installing `ipywidgets`.  
- Assuming the notebook is running on a multi-GPU instance, we check that the GPUs are visible using `nvidia-smi`.  
- A correct `nvidia-smi` result doesn't *guarantee* that TensorFlow will see the GPUs, but the code below also prints the number of devices that TensorFlow sees.


In [None]:
!pip install ipywidgets


In [None]:
!nvidia-smi

Tue Feb  8 23:24:41 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  RTX A5000           Off  | 00000000:00:05.0 Off |                  Off |
| 30%   33C    P8    18W / 230W |      0MiB / 24256MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  RTX A5000           Off  | 00000000:00:06.0 Off |                  Off |
| 30%   36C    P8    18W / 230W |      0MiB / 24256MiB |      0%      Defaul

# Original TensorFlow Tutorial
_The remaining content follows the original TensorFlow tutorial, unless noted by a comment prefixed by **#PS** (Paperspace)._

## Overview

The `tf.distribute.Strategy` API provides an abstraction for distributing your training across multiple processing units. It allows you to carry out distributed training using existing models and training code with minimal changes.

This tutorial demonstrates how to use the `tf.distribute.MirroredStrategy` to perform in-graph replication with _synchronous training on many GPUs on one machine_. The strategy essentially copies all of the model's variables to each processor. Then, it uses [all-reduce](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/) to combine the gradients from all processors, and applies the combined value to all copies of the model.
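
The synchronous update described above can be pictured with a toy sketch in plain Python. It is not TensorFlow code: lists stand in for tensors, the helper names `all_reduce_sum` and `train_step` are hypothetical, and summing the gradients is an assumption about how the values are combined.

```python
# Toy model of synchronous data-parallel training with all-reduce.
# Plain Python lists stand in for tensors; this is not TensorFlow code.

def all_reduce_sum(per_replica_values):
    """Sum values element-wise across replicas and hand every replica the result."""
    total = [sum(column) for column in zip(*per_replica_values)]
    return [list(total) for _ in per_replica_values]


def train_step(replica_weights, replica_gradients, learning_rate):
    """Apply one synchronous gradient-descent step using the combined gradients."""
    combined = all_reduce_sum(replica_gradients)
    return [
        [w - learning_rate * g for w, g in zip(weights, grads)]
        for weights, grads in zip(replica_weights, combined)
    ]


# Two replicas start from identical copies of a two-parameter model.
weights = [[1.0, 2.0], [1.0, 2.0]]
# Each replica computes gradients on its own slice of the batch.
gradients = [[0.5, -1.0], [1.5, 3.0]]

weights = train_step(weights, gradients, learning_rate=0.5)
print(weights)
```

Because every replica applies the same combined gradient to the same starting weights, the copies stay identical after the step, which is the property that mirrored variables rely on.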

You will use the `tf.keras` APIs to build the model and `Model.fit` for training it. (To learn about distributed training with a custom training loop and the `MirroredStrategy`, check out [this tutorial](custom_training.ipynb).)

`MirroredStrategy` trains your model on multiple GPUs on a single machine. For _synchronous training on many GPUs on multiple workers_, use the `tf.distribute.MultiWorkerMirroredStrategy` [with the Keras Model.fit](multi_worker_with_keras.ipynb) or [a custom training loop](multi_worker_with_ctl.ipynb). For other options, refer to the [Distributed training guide](../../guide/distributed_training.ipynb).

To learn about various other strategies, there is the [Distributed training with TensorFlow](../../guide/distributed_training.ipynb) guide.

## Setup

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

import os

# Load the TensorBoard notebook extension.
%load_ext tensorboard

In [None]:
print(tf.__version__)

2.6.0


## Download the dataset

Load the MNIST dataset from [TensorFlow Datasets](https://www.tensorflow.org/datasets). This returns a dataset in the `tf.data` format.

Setting the `with_info` argument to `True` includes the metadata for the entire dataset, which is being saved here to `info`. Among other things, this metadata object includes the number of train and test examples.

In [None]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

2022-02-08 23:25:24.694740: W tensorflow/core/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "Not found: Could not locate the credentials file.". Retrieving token from GCE failed with "Failed precondition: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata".


Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


2022-02-08 23:25:27.870689: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1050] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero

## Define the distribution strategy

Create a `MirroredStrategy` object. This will handle distribution and provide a context manager (`MirroredStrategy.scope`) to build your model inside.

In [None]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


In [None]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 2


## Set up the input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory and tune the learning rate accordingly.

In [None]:
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
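
As a quick check of the arithmetic, the sketch below recomputes the global batch size and the resulting number of steps per epoch. It assumes two replicas (as on an A5000x2 instance) and uses the 60,000 examples of the MNIST training split; the variable names are illustrative.

```python
import math

NUM_TRAIN_EXAMPLES = 60000   # size of the MNIST training split
BATCH_SIZE_PER_REPLICA = 64  # same value as in the cell above
NUM_REPLICAS = 2             # assumption: two GPUs are visible

# Each step feeds one per-replica batch to every GPU.
global_batch_size = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS

# The last, partial batch still counts as a step, hence the ceiling.
steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / global_batch_size)

print(global_batch_size)  # 128
print(steps_per_epoch)    # 469
```

The value 469 is the step count that the Keras progress bar reports per epoch in the training output of this run.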

Define a function that normalizes the image pixel values from the `[0, 255]` range to the  `[0, 1]` range ([feature scaling](https://en.wikipedia.org/wiki/Feature_scaling)):

In [None]:
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Apply this `scale` function to the training and test data, and then use the `tf.data.Dataset` APIs to shuffle the training data (`Dataset.shuffle`), and batch it (`Dataset.batch`). Notice that you are also keeping an in-memory cache of the training data to improve performance (`Dataset.cache`).

In [None]:
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

## Create the model

Create and compile the Keras model in the context of `Strategy.scope`:

In [None]:
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])



INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
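
To get a feel for how much state is mirrored on each GPU, the following sketch counts the trainable parameters of the model by hand. It assumes the Keras defaults for these layers ('valid' padding for `Conv2D`, a 2x2 window for `MaxPooling2D`); the helper functions are hypothetical.

```python
def conv2d_params(kernel, in_channels, filters):
    """Weights plus biases of a square-kernel 2D convolution."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


height = width = 28
channels = 1

# Conv2D(32, 3): assumed 'valid' padding shrinks each side by 2.
conv = conv2d_params(kernel=3, in_channels=channels, filters=32)
height, width, channels = height - 2, width - 2, 32

# MaxPooling2D(): assumed 2x2 window halves each spatial side.
height, width = height // 2, width // 2

flat = height * width * channels
dense_1 = dense_params(flat, 64)
dense_2 = dense_params(64, 10)

total = conv + dense_1 + dense_2
print(conv, dense_1, dense_2, total)
```

The three parameterized layers each hold a kernel and a bias, giving six trainable variables in total. That appears consistent with the `6 all-reduces` reported in the training log, although the log itself does not spell out the mapping.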


## Define the callbacks


Define the following `tf.keras.callbacks`:

- `tf.keras.callbacks.TensorBoard`: writes a log for TensorBoard, which allows you to visualize the graphs.
- `tf.keras.callbacks.ModelCheckpoint`: saves the model at a certain frequency, such as after every epoch.
- `tf.keras.callbacks.LearningRateScheduler`: schedules the learning rate to change after, for example, every epoch/batch.

For illustrative purposes, add a custom callback called `PrintLR` to display the *learning rate* in the notebook.

In [None]:
# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
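
The `{epoch}` placeholder in the file name is filled in when each checkpoint is written. The sketch below shows the resulting names, assuming the placeholder is expanded with `str.format` and a 1-based epoch number, which matches the file names this run produces.

```python
import os

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Assumption: the placeholder is expanded with str.format and a 1-based epoch.
example_paths = [checkpoint_prefix.format(epoch=epoch) for epoch in (1, 2, 12)]

for path in example_paths:
    print(path)
```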

In [None]:
# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5
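
The schedule is easier to read as a table. The sketch below repeats the same thresholds as `decay` so that it runs on its own, and lists the rate for each of 12 epochs. The epoch argument is 0-based, so the first change takes effect in the fourth epoch.

```python
def decay(epoch):
  # Same thresholds as the cell above; `epoch` is 0-based.
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5


schedule = [decay(epoch) for epoch in range(12)]

for epoch, rate in enumerate(schedule):
  # Print the 1-based epoch number, as Keras does in its progress output.
  print('epoch {:2d}: {}'.format(epoch + 1, rate))
```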

In [None]:
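# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # Keras attaches the model being trained to the callback as `self.model`,
    # so the callback does not depend on a global `model` variable.
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      self.model.optimizer.lr.numpy()))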
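
When this callback runs, it prints values such as `0.0010000000474974513` rather than `0.001`. A likely explanation, assuming the optimizer stores the learning rate as a 32-bit float, is that the scheduled value is rounded to float32 and then widened back to a Python float for printing. The sketch below reproduces that round trip with the standard library; `as_float32` is a hypothetical helper.

```python
import struct


def as_float32(value):
  """Round a Python float to the nearest 32-bit float and widen it back."""
  return struct.unpack('f', struct.pack('f', value))[0]


for rate in (1e-3, 1e-4, 1e-5):
  print(rate, '->', as_float32(rate))
```

The small difference is a representation effect only and does not indicate that the schedule is being applied incorrectly.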

In [None]:
# Put all the callbacks together.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

2022-02-08 23:29:24.659079: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2022-02-08 23:29:24.659118: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.
2022-02-08 23:29:24.659154: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1614] Profiler found 2 GPUs


2022-02-08 23:29:25.038566: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.


2022-02-08 23:29:25.038802: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1749] CUPTI activity buffer flushed


## Train and evaluate

Now, train the model in the usual way by calling `Model.fit` on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.

In [None]:
EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)

2022-02-08 23:29:37.439223: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-02-08 23:29:37.490212: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Epoch 1/12
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


2022-02-08 23:29:41.739994: I tensorflow/stream_executor/cuda/cuda_dnn.cc:381] Loaded cuDNN version 8204


2022-02-08 23:29:43.138135: I tensorflow/stream_executor/cuda/cuda_dnn.cc:381] Loaded cuDNN version 8204


2022-02-08 23:29:44.948287: I tensorflow/stream_executor/cuda/cuda_blas.cc:1760] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


  1/469 [..............................] - ETA: 1:08:40 - loss: 2.3169 - accuracy: 0.0703

2022-02-08 23:29:46.388964: I tensorflow/core/profiler/lib/profiler_session.cc:131] Profiler session initializing.
2022-02-08 23:29:46.389005: I tensorflow/core/profiler/lib/profiler_session.cc:146] Profiler session started.


  2/469 [..............................] - ETA: 16:13 - loss: 2.2653 - accuracy: 0.2070  

2022-02-08 23:29:48.456382: I tensorflow/core/profiler/lib/profiler_session.cc:66] Profiler session collecting data.
2022-02-08 23:29:48.456696: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1749] CUPTI activity buffer flushed
2022-02-08 23:29:48.542141: I tensorflow/core/profiler/internal/gpu/cupti_collector.cc:673]  GpuTracer has collected 658 callback api events and 576 activity events. 
2022-02-08 23:29:48.552115: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session tear down.
2022-02-08 23:29:48.623362: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: ./logs/train/plugins/profile/2022_02_08_23_29_48

2022-02-08 23:29:48.643544: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to ./logs/train/plugins/profile/2022_02_08_23_29_48/n5ppnv4suq.trace.json.gz


  3/469 [..............................] - ETA: 9:12 - loss: 2.2141 - accuracy: 0.2943 



2022-02-08 23:29:48.661507: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: ./logs/train/plugins/profile/2022_02_08_23_29_48

2022-02-08 23:29:48.725226: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to ./logs/train/plugins/profile/2022_02_08_23_29_48/n5ppnv4suq.memory_profile.json.gz
2022-02-08 23:29:48.731305: I tensorflow/core/profiler/rpc/client/capture_profile.cc:251] Creating directory: ./logs/train/plugins/profile/2022_02_08_23_29_48
Dumped tool data for xplane.pb to ./logs/train/plugins/profile/2022_02_08_23_29_48/n5ppnv4suq.xplane.pb
Dumped tool data for overview_page.pb to ./logs/train/plugins/profile/2022_02_08_23_29_48/n5ppnv4suq.overview_page.pb
Dumped tool data for input_pipeline.pb to ./logs/train/plugins/profile/2022_02_08_23_29_48/n5ppnv4suq.input_pipeline.pb
Dumped tool data for tensorflow_stats.pb to ./logs/train/plugins/profile/2022_02_08_23_29_48/n5ppnv4suq.tensorflow_st





 11/469 [..............................] - ETA: 1:50 - loss: 1.7857 - accuracy: 0.5646

 17/469 [>.............................] - ETA: 1:10 - loss: 1.5002 - accuracy: 0.6443

 23/469 [>.............................] - ETA: 51s - loss: 1.2834 - accuracy: 0.6875 

 32/469 [=>............................] - ETA: 36s - loss: 1.0565 - accuracy: 0.7400

 43/469 [=>............................] - ETA: 26s - loss: 0.8827 - accuracy: 0.7782

 53/469 [==>...........................] - ETA: 21s - loss: 0.7896 - accuracy: 0.7973

 62/469 [==>...........................] - ETA: 18s - loss: 0.7217 - accuracy: 0.8131

 72/469 [===>..........................] - ETA: 15s - loss: 0.6662 - accuracy: 0.8261

 83/469 [====>.........................] - ETA: 13s - loss: 0.6168 - accuracy: 0.8365

 93/469 [====>.........................] - ETA: 11s - loss: 0.5772 - accuracy: 0.8458

103/469 [=====>........................] - ETA: 10s - loss: 0.5446 - accuracy: 0.8541


























































Learning rate for epoch 1 is 0.0010000000474974513
Epoch 2/12
  1/469 [..............................] - ETA: 8s - loss: 0.0895 - accuracy: 0.9688

 19/469 [>.............................] - ETA: 1s - loss: 0.1103 - accuracy: 0.9700

 37/469 [=>............................] - ETA: 1s - loss: 0.1055 - accuracy: 0.9713

 55/469 [==>...........................] - ETA: 1s - loss: 0.1062 - accuracy: 0.9706

 73/469 [===>..........................] - ETA: 1s - loss: 0.1046 - accuracy: 0.9713

 91/469 [====>.........................] - ETA: 1s - loss: 0.0984 - accuracy: 0.9730

109/469 [=====>........................] - ETA: 1s - loss: 0.0955 - accuracy: 0.9733












































Learning rate for epoch 2 is 0.0010000000474974513
Epoch 3/12
  1/469 [..............................] - ETA: 8s - loss: 0.0542 - accuracy: 0.9844

 18/469 [>.............................] - ETA: 1s - loss: 0.0638 - accuracy: 0.9852

 36/469 [=>............................] - ETA: 1s - loss: 0.0597 - accuracy: 0.9855

 54/469 [==>...........................] - ETA: 1s - loss: 0.0632 - accuracy: 0.9841

 72/469 [===>..........................] - ETA: 1s - loss: 0.0631 - accuracy: 0.9833

 90/469 [====>.........................] - ETA: 1s - loss: 0.0621 - accuracy: 0.9836

107/469 [=====>........................] - ETA: 1s - loss: 0.0633 - accuracy: 0.9831












































Learning rate for epoch 3 is 0.0010000000474974513
Epoch 4/12
  1/469 [..............................] - ETA: 8s - loss: 0.1162 - accuracy: 0.9531

 19/469 [>.............................] - ETA: 1s - loss: 0.0421 - accuracy: 0.9885

 35/469 [=>............................] - ETA: 1s - loss: 0.0416 - accuracy: 0.9891

 53/469 [==>...........................] - ETA: 1s - loss: 0.0420 - accuracy: 0.9884

 72/469 [===>..........................] - ETA: 1s - loss: 0.0440 - accuracy: 0.9872

 88/469 [====>.........................] - ETA: 1s - loss: 0.0431 - accuracy: 0.9872

105/469 [=====>........................] - ETA: 1s - loss: 0.0440 - accuracy: 0.9876












































Learning rate for epoch 4 is 9.999999747378752e-05
Epoch 5/12
  1/469 [..............................] - ETA: 8s - loss: 0.0058 - accuracy: 1.0000

 18/469 [>.............................] - ETA: 1s - loss: 0.0413 - accuracy: 0.9883

 36/469 [=>............................] - ETA: 1s - loss: 0.0396 - accuracy: 0.9900

 55/469 [==>...........................] - ETA: 1s - loss: 0.0409 - accuracy: 0.9891

 71/469 [===>..........................] - ETA: 1s - loss: 0.0405 - accuracy: 0.9892

 86/469 [====>.........................] - ETA: 1s - loss: 0.0397 - accuracy: 0.9891

102/469 [=====>........................] - ETA: 1s - loss: 0.0378 - accuracy: 0.9894












































Learning rate for epoch 5 is 9.999999747378752e-05
Epoch 6/12
  1/469 [..............................] - ETA: 8s - loss: 0.0469 - accuracy: 0.9766

 19/469 [>.............................] - ETA: 1s - loss: 0.0402 - accuracy: 0.9893

 37/469 [=>............................] - ETA: 1s - loss: 0.0406 - accuracy: 0.9882

 55/469 [==>...........................] - ETA: 1s - loss: 0.0377 - accuracy: 0.9882

 73/469 [===>..........................] - ETA: 1s - loss: 0.0360 - accuracy: 0.9891

 90/469 [====>.........................] - ETA: 1s - loss: 0.0336 - accuracy: 0.9899

108/469 [=====>........................] - ETA: 1s - loss: 0.0351 - accuracy: 0.9898












































Learning rate for epoch 6 is 9.999999747378752e-05
Epoch 7/12
  1/469 [..............................] - ETA: 8s - loss: 0.0660 - accuracy: 0.9766

 19/469 [>.............................] - ETA: 1s - loss: 0.0418 - accuracy: 0.9893

 36/469 [=>............................] - ETA: 1s - loss: 0.0380 - accuracy: 0.9900

 54/469 [==>...........................] - ETA: 1s - loss: 0.0381 - accuracy: 0.9897

 72/469 [===>..........................] - ETA: 1s - loss: 0.0378 - accuracy: 0.9897

 85/469 [====>.........................] - ETA: 1s - loss: 0.0370 - accuracy: 0.9898

104/469 [=====>........................] - ETA: 1s - loss: 0.0363 - accuracy: 0.9899












































Learning rate for epoch 7 is 9.999999747378752e-05
Epoch 8/12
  1/469 [..............................] - ETA: 8s - loss: 0.0505 - accuracy: 0.9766

 19/469 [>.............................] - ETA: 1s - loss: 0.0332 - accuracy: 0.9910

 37/469 [=>............................] - ETA: 1s - loss: 0.0343 - accuracy: 0.9903

 56/469 [==>...........................] - ETA: 1s - loss: 0.0358 - accuracy: 0.9902

 74/469 [===>..........................] - ETA: 1s - loss: 0.0340 - accuracy: 0.9904

 92/469 [====>.........................] - ETA: 1s - loss: 0.0328 - accuracy: 0.9907












































Learning rate for epoch 8 is 9.999999747378752e-06
Epoch 9/12
  1/469 [..............................] - ETA: 9s - loss: 0.0143 - accuracy: 0.9922

 16/469 [>.............................] - ETA: 1s - loss: 0.0421 - accuracy: 0.9888

 35/469 [=>............................] - ETA: 1s - loss: 0.0417 - accuracy: 0.9879

 54/469 [==>...........................] - ETA: 1s - loss: 0.0385 - accuracy: 0.9889

 73/469 [===>..........................] - ETA: 1s - loss: 0.0351 - accuracy: 0.9898

 91/469 [====>.........................] - ETA: 1s - loss: 0.0340 - accuracy: 0.9906

101/469 [=====>........................] - ETA: 1s - loss: 0.0335 - accuracy: 0.9908












































Learning rate for epoch 9 is 9.999999747378752e-06
Epoch 10/12
  1/469 [..............................] - ETA: 8s - loss: 0.0405 - accuracy: 0.9844

 19/469 [>.............................] - ETA: 1s - loss: 0.0319 - accuracy: 0.9905

 37/469 [=>............................] - ETA: 1s - loss: 0.0332 - accuracy: 0.9897

 55/469 [==>...........................] - ETA: 1s - loss: 0.0353 - accuracy: 0.9902

 74/469 [===>..........................] - ETA: 1s - loss: 0.0339 - accuracy: 0.9906

 92/469 [====>.........................] - ETA: 1s - loss: 0.0318 - accuracy: 0.9915














































Learning rate for epoch 10 is 9.999999747378752e-06


Epoch 11/12
  1/469 [..............................] - ETA: 8s - loss: 0.0104 - accuracy: 1.0000

 18/469 [>.............................] - ETA: 1s - loss: 0.0339 - accuracy: 0.9887

 31/469 [>.............................] - ETA: 1s - loss: 0.0364 - accuracy: 0.9889

 48/469 [==>...........................] - ETA: 1s - loss: 0.0337 - accuracy: 0.9897

 67/469 [===>..........................] - ETA: 1s - loss: 0.0324 - accuracy: 0.9902

 85/469 [====>.........................] - ETA: 1s - loss: 0.0312 - accuracy: 0.9903

103/469 [=====>........................] - ETA: 1s - loss: 0.0319 - accuracy: 0.9904












































Learning rate for epoch 11 is 9.999999747378752e-06


Epoch 12/12
  1/469 [..............................] - ETA: 8s - loss: 0.0078 - accuracy: 1.0000

 18/469 [>.............................] - ETA: 1s - loss: 0.0355 - accuracy: 0.9909

 36/469 [=>............................] - ETA: 1s - loss: 0.0343 - accuracy: 0.9898

 55/469 [==>...........................] - ETA: 1s - loss: 0.0339 - accuracy: 0.9905

 73/469 [===>..........................] - ETA: 1s - loss: 0.0329 - accuracy: 0.9904

 89/469 [====>.........................] - ETA: 1s - loss: 0.0311 - accuracy: 0.9911

107/469 [=====>........................] - ETA: 1s - loss: 0.0314 - accuracy: 0.9910












































Learning rate for epoch 12 is 9.999999747378752e-06


<keras.callbacks.History at 0x7f9c8fbd7580>
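
The learning rate logged for the last epochs, `9.999999747378752e-06`, is most likely a scheduled value of `1e-5` stored as a 32-bit float and then printed at double precision. The scheduled value is an assumption here, since the schedule is defined earlier in the notebook. A minimal sketch of that rounding, using only the standard library (the helper name `as_float32` is hypothetical):

```python
import struct


def as_float32(value):
    """Round a Python float to the nearest 32-bit float, returned as a Python float."""
    return struct.unpack("f", struct.pack("f", value))[0]


# Assumed scheduled value; the log prints its float32 representation.
scheduled_lr = 1e-5
logged_lr = as_float32(scheduled_lr)

print(logged_lr)
print(logged_lr == 9.999999747378752e-06)
```

The small discrepancy is a display artifact of float32 storage, not a sign that the schedule drifted.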

Check for saved checkpoints:

In [None]:
# Check the checkpoint directory.
!ls {checkpoint_dir}

checkpoint		     ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index		     ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index		     ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index		     ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index		     ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index		     ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index
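
Note that `ls` sorts the names lexicographically, so `ckpt_10` appears before `ckpt_2`. The sketch below orders checkpoint files by epoch number instead, assuming the `ckpt_{epoch}` naming shown in the listing. The helper `checkpoint_epochs` is hypothetical and is not part of TensorFlow.

```python
import re


def checkpoint_epochs(filenames):
    """Return the sorted epoch numbers found in names like 'ckpt_12.index'."""
    pattern = re.compile(r"^ckpt_(\d+)\.index$")
    epochs = []
    for name in filenames:
        match = pattern.match(name)
        if match:
            epochs.append(int(match.group(1)))
    return sorted(epochs)


# A few names taken from the directory listing above.
listing = [
    "checkpoint",
    "ckpt_1.index",
    "ckpt_10.index",
    "ckpt_12.data-00000-of-00001",
    "ckpt_12.index",
    "ckpt_2.index",
]

epochs = checkpoint_epochs(listing)
print(epochs)
print(f"latest prefix: ckpt_{epochs[-1]}")
```

In the notebook itself, `tf.train.latest_checkpoint(checkpoint_dir)` is the simpler choice. It does not depend on file-name ordering, since it reads the `checkpoint` bookkeeping file.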


To check how well the model performs, load the latest checkpoint and call `Model.evaluate` on the test data:

In [None]:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))

2022-02-08 23:30:27.922901: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


 1/79 [..............................] - ETA: 1:54 - loss: 0.0573 - accuracy: 0.9766

 5/79 [>.............................] - ETA: 0s - loss: 0.0426 - accuracy: 0.9828  

 9/79 [==>...........................] - ETA: 0s - loss: 0.0423 - accuracy: 0.9844

13/79 [===>..........................] - ETA: 0s - loss: 0.0376 - accuracy: 0.9850

17/79 [=====>........................] - ETA: 0s - loss: 0.0384 - accuracy: 0.9862

























Eval loss: 0.04501435533165932, Eval accuracy: 0.9847000241279602


To visualize the output, launch TensorBoard and view the logs:

In [None]:
# %tensorboard --logdir=logs  # Commented out because TensorBoard is not yet set up on Gradient


In [None]:
# !ls -sh ./logs  # Commented out because the directory reports a total size of 0

total 0
0 train


## Export to SavedModel

Export the graph and the variables to the platform-agnostic SavedModel format using `Model.save`. After your model is saved, you can load it with or without the `Strategy.scope`.

In [None]:
path = 'saved_model/'

In [None]:
model.save(path, save_format='tf')

2022-02-08 23:32:44.371448: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: saved_model/assets


INFO:tensorflow:Assets written to: saved_model/assets
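
Before loading, it can help to confirm that the export directory looks complete. The sketch below assumes the usual SavedModel layout, a `saved_model.pb` file next to a `variables/` directory. The helper `looks_like_saved_model` is hypothetical, and the demonstration uses a stand-in directory so that it runs without TensorFlow.

```python
import tempfile
from pathlib import Path


def looks_like_saved_model(export_dir):
    """Return True if export_dir has the files a SavedModel export usually contains."""
    root = Path(export_dir)
    has_graph = (root / "saved_model.pb").is_file()
    has_variables = (root / "variables").is_dir()
    return has_graph and has_variables


# Build a stand-in export directory, then remove the graph file to show a failure.
with tempfile.TemporaryDirectory() as tmp:
    fake_export = Path(tmp) / "saved_model"
    (fake_export / "variables").mkdir(parents=True)
    (fake_export / "saved_model.pb").touch()
    complete = looks_like_saved_model(fake_export)

    (fake_export / "saved_model.pb").unlink()
    incomplete = looks_like_saved_model(fake_export)

print(complete, incomplete)
```

In the notebook, the equivalent check would be `looks_like_saved_model(path)` on the `path` variable defined above.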


Now, load the model without `Strategy.scope`:

In [None]:
unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

 1/79 [..............................] - ETA: 10s - loss: 0.0573 - accuracy: 0.9766

13/79 [===>..........................] - ETA: 0s - loss: 0.0376 - accuracy: 0.9850 











Eval loss: 0.045014359056949615, Eval Accuracy: 0.9847000241279602


Load the model with `Strategy.scope`:

In [None]:
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

2022-02-08 23:32:57.567965: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:461] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


 1/79 [..............................] - ETA: 3:10 - loss: 0.0573 - accuracy: 0.9766

13/79 [===>..........................] - ETA: 0s - loss: 0.0376 - accuracy: 0.9850  











Eval loss: 0.04501435533165932, Eval Accuracy: 0.9847000241279602
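
The loss from evaluation without `Strategy.scope` (0.045014359056949615) differs slightly from the two results obtained inside the scope (0.04501435533165932), while the accuracy is identical. A plausible explanation, though not one verified here, is that the floating-point sums are accumulated in a different order. The sketch below uses only the standard library to check that the gap is on the order of float32 precision (the helper `float32_bits` is hypothetical):

```python
import math
import struct


def float32_bits(value):
    """Return the bit pattern of value rounded to float32, as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


# Eval losses copied from the outputs above.
loss_with_scope = 0.04501435533165932
loss_without_scope = 0.045014359056949615

abs_diff = abs(loss_with_scope - loss_without_scope)
# For positive floats, the difference in bit patterns counts representable steps.
steps_apart = abs(float32_bits(loss_with_scope) - float32_bits(loss_without_scope))
agree = math.isclose(loss_with_scope, loss_without_scope, rel_tol=1e-6)

print(f"absolute difference: {abs_diff:.3e}")
print(f"float32 steps apart: {steps_apart}")
print(f"agree to float32 precision: {agree}")
```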


### Additional resources

More examples that use different distribution strategies with the Keras `Model.fit` API:

1. The [Solve GLUE tasks using BERT on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial uses `tf.distribute.MirroredStrategy` for training on GPUs and `tf.distribute.TPUStrategy` for training on TPUs.
1. The [Save and load a model using a distribution strategy](save_and_load.ipynb) tutorial demonstrates how to use the SavedModel APIs with `tf.distribute.Strategy`.
1. The [official TensorFlow models](https://github.com/tensorflow/models/tree/master/official) can be configured to run with multiple distribution strategies.

To learn more about TensorFlow distribution strategies:

1. The [Custom training with tf.distribute.Strategy](custom_training.ipynb) tutorial shows how to use the `tf.distribute.MirroredStrategy` for single-worker training with a custom training loop.
1. The [Multi-worker training with Keras](multi_worker_with_keras.ipynb) tutorial shows how to use the `MultiWorkerMirroredStrategy` with `Model.fit`.
1. The [Custom training loop with Keras and MultiWorkerMirroredStrategy](multi_worker_with_ctl.ipynb) tutorial shows how to use the `MultiWorkerMirroredStrategy` with Keras and a custom training loop.
1. The [Distributed training in TensorFlow](https://www.tensorflow.org/guide/distributed_training) guide provides an overview of the available distribution strategies.
1. The [Better performance with tf.function](../../guide/function.ipynb) guide provides information about other strategies and tools, such as the [TensorFlow Profiler](../../guide/profiler.md), that you can use to optimize the performance of your TensorFlow models.

Note: `tf.distribute.Strategy` is under active development, and TensorFlow will be adding more examples and tutorials in the near future. Please give it a try. Your feedback is welcome; feel free to submit it via [issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new).

##### Copyright 2019 The TensorFlow Authors.


In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.