##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Distributed training with Keras


## Overview

The `tf.distribute.Strategy` API provides an abstraction for distributing your training
across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.

This tutorial uses `tf.distribute.MirroredStrategy`, which
does in-graph replication with synchronous training on multiple GPUs on a single machine.
Essentially, it copies all of the model's variables to each processor.
Then, it uses [all-reduce](http://mpitutorial.com/tutorials/mpi-reduce-and-allreduce/) to combine the gradients from all processors and applies the combined value to all copies of the model.
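The replicate, all-reduce and apply cycle can be illustrated without TensorFlow. The sketch below is a toy simulation that makes several simplifying assumptions:

* the "model" is a single scalar weight `w` fitted to `y = 2x` with squared-error loss;
* there are two replicas;
* the all-reduce takes a plain mean of the per-replica gradients.

The helper names (`local_gradient`, `all_reduce_mean`, `sync_step`) are hypothetical and are not part of the `tf.distribute` API.

```python
# Toy, framework-free simulation of synchronous data-parallel training.
# Assumptions: one scalar weight, y = 2x, squared-error loss, 2 replicas,
# and an all-reduce that averages gradients. Not the tf.distribute implementation.

def local_gradient(w, shard):
    """Gradient of the mean of (w*x - y)^2 over one replica's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Combine one value per replica; every replica receives the same result."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

def sync_step(replica_weights, shards, lr):
    """One synchronous step: local gradients, all-reduce, identical update."""
    grads = [local_gradient(w, s) for w, s in zip(replica_weights, shards)]
    reduced = all_reduce_mean(grads)
    # Each replica applies the same combined gradient to its own copy.
    return [w - lr * g for w, g in zip(replica_weights, reduced)]

data = [(x, 2.0 * x) for x in range(1, 9)]
shards = [data[0::2], data[1::2]]   # split the global batch across 2 replicas
weights = [0.0, 0.0]                # each replica starts with a mirrored copy
for _ in range(50):
    weights = sync_step(weights, shards, lr=0.01)
print(weights)
```

Every replica starts from the same value and applies the same combined gradient, so the copies never drift apart. With equal-sized shards, the mean of the per-replica gradients also equals the gradient over the whole global batch.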

`MirroredStrategy` is one of several distribution strategies available in TensorFlow core. You can read about more strategies in the [distribution strategy guide](../../guide/distributed_training.ipynb).


### Keras API

This example uses the `tf.keras` API to build the model and training loop. For custom training loops, see the [tf.distribute.Strategy with training loops](training_loops.ipynb) tutorial.

## Import dependencies

In [2]:
# Import TensorFlow and TensorFlow Datasets

import tensorflow_datasets as tfds
import tensorflow as tf

import os

In [3]:
print(tf.__version__)

2.3.0


## Download the dataset

Download the MNIST dataset and load it from [TensorFlow Datasets](https://www.tensorflow.org/datasets). This returns a dataset in `tf.data` format.

Setting `with_info` to `True` includes the metadata for the entire dataset, which is saved here to `info`.
Among other things, this metadata object includes the number of train and test examples.


In [4]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


## Define distribution strategy

Create a `MirroredStrategy` object. This object handles distribution and provides a context manager (`tf.distribute.MirroredStrategy.scope`) inside which you build your model.

In [5]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [6]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


## Setup input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits in GPU memory, and tune the learning rate accordingly.

In [7]:
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
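The arithmetic behind `BATCH_SIZE` can be worked through for a few replica counts. The sketch assumes 60,000 training examples, which is the size of the MNIST training split that the cell above reads from `info.splits`. With one replica and a global batch of 64, an epoch takes 938 steps, the step count reported in the training log below.

The sketch also applies linear learning-rate scaling. This is one common heuristic for tuning the learning rate accordingly; it is an assumption, not something this tutorial prescribes. The helpers `global_batch_size`, `steps_per_epoch` and `linearly_scaled_lr` are hypothetical.

```python
import math

NUM_TRAIN_EXAMPLES = 60000   # size of the MNIST training split
BATCH_SIZE_PER_REPLICA = 64
BASE_LEARNING_RATE = 1e-3    # assumption: Adam's default rate as the 1-replica baseline

def global_batch_size(num_replicas):
    """Each replica gets BATCH_SIZE_PER_REPLICA examples per step."""
    return BATCH_SIZE_PER_REPLICA * num_replicas

def steps_per_epoch(num_replicas):
    """Dataset.batch keeps the final partial batch by default, hence ceil."""
    return math.ceil(NUM_TRAIN_EXAMPLES / global_batch_size(num_replicas))

def linearly_scaled_lr(num_replicas):
    """Heuristic (assumption): scale the rate in proportion to the global batch."""
    return BASE_LEARNING_RATE * num_replicas

for n in (1, 2, 4, 8):
    print(n, global_batch_size(n), steps_per_epoch(n), linearly_scaled_lr(n))
```

Adding replicas reduces the number of steps per epoch, because each step consumes a larger global batch.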

Pixel values, which are 0-255, [have to be normalized to the 0-1 range](https://en.wikipedia.org/wiki/Feature_scaling). Define this scale in a function.

In [8]:
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Apply this function to the training and test data, shuffle the training data, and [batch it for training](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch). Notice that the training data is also cached in memory (with `cache()`) to improve performance.


In [9]:
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
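`Dataset.shuffle` does not shuffle the whole dataset at once. It fills a buffer with `BUFFER_SIZE` elements, randomly samples from that buffer, and refills it as elements are drawn.

The pure-Python generators below approximate that behavior. They also approximate `batch` keeping the final partial batch (the default `drop_remainder=False`). Both are illustrative stand-ins, not the `tf.data` implementation.

Placing `cache()` before `shuffle()` means the scaled images are computed only once. The shuffled order can still differ from epoch to epoch.

```python
import random

def buffered_shuffle(iterable, buffer_size, seed=None):
    """Approximate buffer-based shuffling like tf.data's Dataset.shuffle."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit a uniformly random element from it.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer

def batch(iterable, batch_size):
    """Group elements into lists, keeping a final partial batch."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk

batches = list(batch(buffered_shuffle(range(10), buffer_size=4, seed=0), 3))
print(batches)
```

A buffer much smaller than the dataset only mixes nearby elements. For example, the first element emitted always comes from the first `buffer_size` inputs. This is why a fairly large `BUFFER_SIZE` of 10000 is used here.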

## Create the model

Create and compile the Keras model in the context of `strategy.scope`.

In [10]:
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
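`MirroredStrategy` keeps a copy of every model variable on each replica, so it helps to know how many parameters this model has. The arithmetic below works through output shapes and parameter counts layer by layer. It assumes the Keras defaults: `padding='valid'` and stride 1 for `Conv2D`, and a 2x2 pool for `MaxPooling2D`.

The helper functions are hypothetical. In practice, `model.summary()` reports the same counts.

```python
def conv2d_valid(shape, filters, kernel_size):
    """Output shape and parameter count of a stride-1, 'valid' Conv2D."""
    h, w, c_in = shape
    out_shape = (h - kernel_size + 1, w - kernel_size + 1, filters)
    params = kernel_size * kernel_size * c_in * filters + filters  # kernel + biases
    return out_shape, params

def max_pool2d(shape, pool_size=2):
    """Output shape of a non-overlapping 'valid' max pool (no parameters)."""
    h, w, c = shape
    return (h // pool_size, w // pool_size, c)

def dense(n_inputs, units):
    """Parameter count of a Dense layer: weight matrix plus bias vector."""
    return n_inputs * units + units

shape, conv_params = conv2d_valid((28, 28, 1), filters=32, kernel_size=3)
shape = max_pool2d(shape)
flat = shape[0] * shape[1] * shape[2]   # size after Flatten
dense1_params = dense(flat, 64)
dense2_params = dense(64, 10)
total = conv_params + dense1_params + dense2_params
print(shape, flat, total)
```

Nearly all of the parameters sit in the first `Dense` layer, which connects the 5,408 flattened features to 64 units.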

## Define the callbacks


The callbacks used here are:

*   *TensorBoard*: This callback writes a log for TensorBoard, which allows you to visualize the graphs.
*   *Model Checkpoint*: This callback saves the model after every epoch (here, only the weights, because `save_weights_only=True`).
*   *Learning Rate Scheduler*: This callback sets the learning rate at the beginning of every epoch, using the value returned by a schedule function.

For illustrative purposes, add a print callback to display the *learning rate* in the notebook.

In [11]:
# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
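`ModelCheckpoint` fills the `{epoch}` placeholder with Python string formatting, so this prefix yields one checkpoint name per epoch. The sketch assumes that Keras substitutes a 1-based epoch number, matching the "Epoch 1/12" numbering in the training log. `checkpoint_path` is a hypothetical helper, not a Keras function.

```python
import os

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

def checkpoint_path(zero_based_epoch):
    """Hypothetical helper; assumes Keras formats {epoch} as a 1-based number."""
    return checkpoint_prefix.format(epoch=zero_based_epoch + 1)

paths = [checkpoint_path(e) for e in range(12)]
print(paths[0], paths[-1])
```

Because each epoch gets a distinct name, later epochs do not overwrite earlier checkpoints.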

In [12]:
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
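Keras calls the schedule with a 0-based epoch index, while the progress log counts epochs from 1. As a result, `decay` switches to `1e-4` at the fourth logged epoch and to `1e-5` at the eighth.

The sketch below tabulates that mapping. It also shows why the log prints values such as `0.0010000000474974513`: the learning rate is held as a float32, and `1e-3` is not exactly representable in that format. `as_float32` is a hypothetical helper written for this illustration.

```python
import struct

def decay(epoch):
    # Same schedule as above; `epoch >= 3` is already implied by the first branch.
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5

def as_float32(x):
    """Round a Python float to the nearest float32 and widen it back."""
    return struct.unpack('f', struct.pack('f', x))[0]

# Keras passes a 0-based epoch index; the training log labels epochs from 1.
schedule = {epoch + 1: decay(epoch) for epoch in range(12)}
for shown_epoch, lr in schedule.items():
    print(f"epoch {shown_epoch:2d}: lr={lr:g}, as float32: {as_float32(lr)!r}")
```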

In [13]:
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # Keras sets self.model, so the callback does not depend on a global variable.
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      self.model.optimizer.lr.numpy()))

In [14]:
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
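The printed learning rates line up with `decay` because of the order in which the scheduler and `PrintLR` run. The toy loop below imitates that ordering under a few stated assumptions:

* a hypothetical `toy_fit` calls every callback's `on_epoch_begin` before an epoch and `on_epoch_end` after it;
* a dictionary stands in for the optimizer's learning rate;
* `ToyLRScheduler` mirrors the documented behavior of `LearningRateScheduler`, setting the rate at the start of each epoch from a 0-based epoch index.

None of these classes are Keras APIs.

```python
def decay(epoch):
    # Same piecewise schedule as above (epoch is 0-based).
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    return 1e-5

class ToyCallback:
    """Hypothetical stand-in for a callback base class with no-op hooks."""
    def on_epoch_begin(self, epoch, state):
        pass

    def on_epoch_end(self, epoch, state):
        pass

class ToyLRScheduler(ToyCallback):
    def __init__(self, schedule):
        self.schedule = schedule

    def on_epoch_begin(self, epoch, state):
        # Set the rate before this epoch's training steps run.
        state["lr"] = self.schedule(epoch)

class ToyPrintLR(ToyCallback):
    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, state):
        # Report the rate that was in effect during this epoch (1-based label).
        self.seen.append((epoch + 1, state["lr"]))
        print(f"Learning rate for epoch {epoch + 1} is {state['lr']}")

def toy_fit(epochs, callbacks, initial_lr=1e-3):
    """Hypothetical training loop that only dispatches epoch-level hooks."""
    state = {"lr": initial_lr}
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch, state)
        # ... the training steps for this epoch would run here ...
        for cb in callbacks:
            cb.on_epoch_end(epoch, state)
    return state

printer = ToyPrintLR()
toy_fit(12, [ToyLRScheduler(decay), printer])
```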

## Train and evaluate

Now, train the model in the usual way, calling `fit` on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.


In [15]:
model.fit(train_dataset, epochs=12, callbacks=callbacks)

Epoch 1/12
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


  1/938 [..............................] - ETA: 0s - loss: 2.3083 - accuracy: 0.0156

Instructions for updating:
use `tf.profiler.experimental.stop` instead.


Instructions for updating:
use `tf.profiler.experimental.stop` instead.






  5/938 [..............................] - ETA: 9s - loss: 2.1730 - accuracy: 0.2937

 16/938 [..............................] - ETA: 5s - loss: 1.7422 - accuracy: 0.5537

 27/938 [..............................] - ETA: 5s - loss: 1.3694 - accuracy: 0.6447

 38/938 [>.............................] - ETA: 4s - loss: 1.1314 - accuracy: 0.7015

 49/938 [>.............................] - ETA: 4s - loss: 0.9784 - accuracy: 0.7366

 60/938 [>.............................] - ETA: 4s - loss: 0.8691 - accuracy: 0.7630

 71/938 [=>............................] - ETA: 4s - loss: 0.7914 - accuracy: 0.7826

 82/938 [=>............................] - ETA: 4s - loss: 0.7352 - accuracy: 0.7974

 93/938 [=>............................] - ETA: 4s - loss: 0.6858 - accuracy: 0.8108

104/938 [==>...........................] - ETA: 4s - loss: 0.6488 - accuracy: 0.8202

115/938 [==>...........................] - ETA: 4s - loss: 0.6129 - accuracy: 0.8295

126/938 [===>..........................] - ETA: 3s - loss: 0.5834 - accuracy: 0.8371

137/938 [===>..........................] - ETA: 3s - loss: 0.5595 - accuracy: 0.8433

148/938 [===>..........................] - ETA: 3s - loss: 0.5351 - accuracy: 0.8498

160/938 [====>.........................] - ETA: 3s - loss: 0.5117 - accuracy: 0.8562

171/938 [====>.........................] - ETA: 3s - loss: 0.4942 - accuracy: 0.8605

182/938 [====>.........................] - ETA: 3s - loss: 0.4787 - accuracy: 0.8646

193/938 [=====>........................] - ETA: 3s - loss: 0.4635 - accuracy: 0.8688

204/938 [=====>........................] - ETA: 3s - loss: 0.4500 - accuracy: 0.8727

215/938 [=====>........................] - ETA: 3s - loss: 0.4377 - accuracy: 0.8760

Learning rate for epoch 1 is 0.0010000000474974513


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


Epoch 2/12
  1/938 [..............................] - ETA: 0s - loss: 0.2609 - accuracy: 0.9844

 20/938 [..............................] - ETA: 2s - loss: 0.0811 - accuracy: 0.9773

 40/938 [>.............................] - ETA: 2s - loss: 0.0757 - accuracy: 0.9793

 59/938 [>.............................] - ETA: 2s - loss: 0.0790 - accuracy: 0.9783

 79/938 [=>............................] - ETA: 2s - loss: 0.0739 - accuracy: 0.9802

 99/938 [==>...........................] - ETA: 2s - loss: 0.0729 - accuracy: 0.9796

119/938 [==>...........................] - ETA: 2s - loss: 0.0742 - accuracy: 0.9787

140/938 [===>..........................] - ETA: 2s - loss: 0.0720 - accuracy: 0.9795

160/938 [====>.........................] - ETA: 1s - loss: 0.0714 - accuracy: 0.9799

180/938 [====>.........................] - ETA: 1s - loss: 0.0715 - accuracy: 0.9798

200/938 [=====>........................] - ETA: 1s - loss: 0.0711 - accuracy: 0.9798

Learning rate for epoch 2 is 0.0010000000474974513


Epoch 3/12


  1/938 [..............................] - ETA: 0s - loss: 0.0118 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0442 - accuracy: 0.9844

 39/938 [>.............................] - ETA: 2s - loss: 0.0494 - accuracy: 0.9840

 58/938 [>.............................] - ETA: 2s - loss: 0.0468 - accuracy: 0.9855

 77/938 [=>............................] - ETA: 2s - loss: 0.0442 - accuracy: 0.9858

 96/938 [==>...........................] - ETA: 2s - loss: 0.0473 - accuracy: 0.9854

115/938 [==>...........................] - ETA: 2s - loss: 0.0486 - accuracy: 0.9857

135/938 [===>..........................] - ETA: 2s - loss: 0.0493 - accuracy: 0.9858

155/938 [===>..........................] - ETA: 2s - loss: 0.0502 - accuracy: 0.9853

174/938 [====>.........................] - ETA: 2s - loss: 0.0491 - accuracy: 0.9858

194/938 [=====>........................] - ETA: 1s - loss: 0.0479 - accuracy: 0.9859

213/938 [=====>........................] - ETA: 1s - loss: 0.0482 - accuracy: 0.9860

Learning rate for epoch 3 is 0.0010000000474974513


Epoch 4/12
  1/938 [..............................] - ETA: 0s - loss: 0.0047 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0382 - accuracy: 0.9906

 39/938 [>.............................] - ETA: 2s - loss: 0.0350 - accuracy: 0.9892

 59/938 [>.............................] - ETA: 2s - loss: 0.0323 - accuracy: 0.9897

 78/938 [=>............................] - ETA: 2s - loss: 0.0315 - accuracy: 0.9904

 98/938 [==>...........................] - ETA: 2s - loss: 0.0308 - accuracy: 0.9904

117/938 [==>...........................] - ETA: 2s - loss: 0.0302 - accuracy: 0.9909

137/938 [===>..........................] - ETA: 2s - loss: 0.0298 - accuracy: 0.9910

156/938 [===>..........................] - ETA: 2s - loss: 0.0285 - accuracy: 0.9914

176/938 [====>.........................] - ETA: 1s - loss: 0.0292 - accuracy: 0.9911

195/938 [=====>........................] - ETA: 1s - loss: 0.0295 - accuracy: 0.9910

214/938 [=====>........................] - ETA: 1s - loss: 0.0296 - accuracy: 0.9914

Learning rate for epoch 4 is 9.999999747378752e-05


Epoch 5/12
  1/938 [..............................] - ETA: 0s - loss: 0.0408 - accuracy: 0.9844

 21/938 [..............................] - ETA: 2s - loss: 0.0179 - accuracy: 0.9970

 41/938 [>.............................] - ETA: 2s - loss: 0.0185 - accuracy: 0.9962

 60/938 [>.............................] - ETA: 2s - loss: 0.0178 - accuracy: 0.9961

 80/938 [=>............................] - ETA: 2s - loss: 0.0204 - accuracy: 0.9949

 99/938 [==>...........................] - ETA: 2s - loss: 0.0246 - accuracy: 0.9942

118/938 [==>...........................] - ETA: 2s - loss: 0.0244 - accuracy: 0.9936

137/938 [===>..........................] - ETA: 2s - loss: 0.0233 - accuracy: 0.9941

156/938 [===>..........................] - ETA: 2s - loss: 0.0258 - accuracy: 0.9935

175/938 [====>.........................] - ETA: 2s - loss: 0.0252 - accuracy: 0.9935

195/938 [=====>........................] - ETA: 1s - loss: 0.0254 - accuracy: 0.9932

215/938 [=====>........................] - ETA: 1s - loss: 0.0246 - accuracy: 0.9932

Learning rate for epoch 5 is 9.999999747378752e-05


Epoch 6/12


  1/938 [..............................] - ETA: 0s - loss: 0.0206 - accuracy: 0.9844

 20/938 [..............................] - ETA: 2s - loss: 0.0151 - accuracy: 0.9969

 39/938 [>.............................] - ETA: 2s - loss: 0.0272 - accuracy: 0.9936

 58/938 [>.............................] - ETA: 2s - loss: 0.0242 - accuracy: 0.9943

 77/938 [=>............................] - ETA: 2s - loss: 0.0235 - accuracy: 0.9941

 96/938 [==>...........................] - ETA: 2s - loss: 0.0230 - accuracy: 0.9943

116/938 [==>...........................] - ETA: 2s - loss: 0.0215 - accuracy: 0.9946

135/938 [===>..........................] - ETA: 2s - loss: 0.0226 - accuracy: 0.9942

155/938 [===>..........................] - ETA: 2s - loss: 0.0233 - accuracy: 0.9941

175/938 [====>.........................] - ETA: 2s - loss: 0.0226 - accuracy: 0.9940

195/938 [=====>........................] - ETA: 1s - loss: 0.0219 - accuracy: 0.9941

214/938 [=====>........................] - ETA: 1s - loss: 0.0225 - accuracy: 0.9940

Learning rate for epoch 6 is 9.999999747378752e-05


Epoch 7/12
  1/938 [..............................] - ETA: 0s - loss: 0.0252 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0411 - accuracy: 0.9922

 40/938 [>.............................] - ETA: 2s - loss: 0.0313 - accuracy: 0.9934

 60/938 [>.............................] - ETA: 2s - loss: 0.0252 - accuracy: 0.9945

 79/938 [=>............................] - ETA: 2s - loss: 0.0246 - accuracy: 0.9949

 98/938 [==>...........................] - ETA: 2s - loss: 0.0227 - accuracy: 0.9954

118/938 [==>...........................] - ETA: 2s - loss: 0.0225 - accuracy: 0.9954

137/938 [===>..........................] - ETA: 2s - loss: 0.0218 - accuracy: 0.9954

157/938 [====>.........................] - ETA: 2s - loss: 0.0222 - accuracy: 0.9951

177/938 [====>.........................] - ETA: 1s - loss: 0.0210 - accuracy: 0.9955

196/938 [=====>........................] - ETA: 1s - loss: 0.0204 - accuracy: 0.9955

215/938 [=====>........................] - ETA: 1s - loss: 0.0198 - accuracy: 0.9956

Learning rate for epoch 7 is 9.999999747378752e-05


Epoch 8/12
  1/938 [..............................] - ETA: 0s - loss: 0.0450 - accuracy: 0.9844

 21/938 [..............................] - ETA: 2s - loss: 0.0158 - accuracy: 0.9955

 42/938 [>.............................] - ETA: 2s - loss: 0.0180 - accuracy: 0.9952

 62/938 [>.............................] - ETA: 2s - loss: 0.0177 - accuracy: 0.9955

 82/938 [=>............................] - ETA: 2s - loss: 0.0182 - accuracy: 0.9954

102/938 [==>...........................] - ETA: 2s - loss: 0.0177 - accuracy: 0.9956

121/938 [==>...........................] - ETA: 2s - loss: 0.0171 - accuracy: 0.9956

141/938 [===>..........................] - ETA: 2s - loss: 0.0172 - accuracy: 0.9957

161/938 [====>.........................] - ETA: 1s - loss: 0.0171 - accuracy: 0.9955

181/938 [====>.........................] - ETA: 1s - loss: 0.0167 - accuracy: 0.9959

200/938 [=====>........................] - ETA: 1s - loss: 0.0168 - accuracy: 0.9959

Learning rate for epoch 8 is 9.999999747378752e-06


Epoch 9/12


  1/938 [..............................] - ETA: 0s - loss: 0.0020 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0114 - accuracy: 0.9977

 40/938 [>.............................] - ETA: 2s - loss: 0.0206 - accuracy: 0.9945

 59/938 [>.............................] - ETA: 2s - loss: 0.0187 - accuracy: 0.9950

 79/938 [=>............................] - ETA: 2s - loss: 0.0176 - accuracy: 0.9953

 98/938 [==>...........................] - ETA: 2s - loss: 0.0166 - accuracy: 0.9955

118/938 [==>...........................] - ETA: 2s - loss: 0.0164 - accuracy: 0.9959

138/938 [===>..........................] - ETA: 2s - loss: 0.0161 - accuracy: 0.9959

157/938 [====>.........................] - ETA: 2s - loss: 0.0157 - accuracy: 0.9960

176/938 [====>.........................] - ETA: 2s - loss: 0.0165 - accuracy: 0.9958

196/938 [=====>........................] - ETA: 1s - loss: 0.0181 - accuracy: 0.9956

216/938 [=====>........................] - ETA: 1s - loss: 0.0180 - accuracy: 0.9955

Learning rate for epoch 9 is 9.999999747378752e-06


Epoch 10/12


  1/938 [..............................] - ETA: 0s - loss: 0.0065 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0215 - accuracy: 0.9937

 39/938 [>.............................] - ETA: 2s - loss: 0.0220 - accuracy: 0.9960

 58/938 [>.............................] - ETA: 2s - loss: 0.0191 - accuracy: 0.9968

 77/938 [=>............................] - ETA: 2s - loss: 0.0200 - accuracy: 0.9957

 96/938 [==>...........................] - ETA: 2s - loss: 0.0192 - accuracy: 0.9956

115/938 [==>...........................] - ETA: 2s - loss: 0.0194 - accuracy: 0.9955

134/938 [===>..........................] - ETA: 2s - loss: 0.0186 - accuracy: 0.9959

153/938 [===>..........................] - ETA: 2s - loss: 0.0182 - accuracy: 0.9959

172/938 [====>.........................] - ETA: 2s - loss: 0.0185 - accuracy: 0.9955

191/938 [=====>........................] - ETA: 2s - loss: 0.0177 - accuracy: 0.9958

210/938 [=====>........................] - ETA: 1s - loss: 0.0177 - accuracy: 0.9958

Learning rate for epoch 10 is 9.999999747378752e-06


Epoch 11/12
  1/938 [..............................] - ETA: 0s - loss: 0.0050 - accuracy: 1.0000

 21/938 [..............................] - ETA: 2s - loss: 0.0289 - accuracy: 0.9940

 41/938 [>.............................] - ETA: 2s - loss: 0.0232 - accuracy: 0.9950

 61/938 [>.............................] - ETA: 2s - loss: 0.0201 - accuracy: 0.9954

 81/938 [=>............................] - ETA: 2s - loss: 0.0199 - accuracy: 0.9956

101/938 [==>...........................] - ETA: 2s - loss: 0.0196 - accuracy: 0.9957

121/938 [==>...........................] - ETA: 2s - loss: 0.0197 - accuracy: 0.9956

141/938 [===>..........................] - ETA: 2s - loss: 0.0183 - accuracy: 0.9958

162/938 [====>.........................] - ETA: 1s - loss: 0.0175 - accuracy: 0.9959

182/938 [====>.........................] - ETA: 1s - loss: 0.0177 - accuracy: 0.9960

202/938 [=====>........................] - ETA: 1s - loss: 0.0174 - accuracy: 0.9961

Learning rate for epoch 11 is 9.999999747378752e-06


Epoch 12/12
  1/938 [..............................] - ETA: 0s - loss: 0.0141 - accuracy: 1.0000

 20/938 [..............................] - ETA: 2s - loss: 0.0148 - accuracy: 0.9969

 40/938 [>.............................] - ETA: 2s - loss: 0.0203 - accuracy: 0.9949

 60/938 [>.............................] - ETA: 2s - loss: 0.0232 - accuracy: 0.9937

 80/938 [=>............................] - ETA: 2s - loss: 0.0197 - accuracy: 0.9949

100/938 [==>...........................] - ETA: 2s - loss: 0.0204 - accuracy: 0.9950

119/938 [==>...........................] - ETA: 2s - loss: 0.0194 - accuracy: 0.9951

138/938 [===>..........................] - ETA: 2s - loss: 0.0184 - accuracy: 0.9954

158/938 [====>.........................] - ETA: 2s - loss: 0.0177 - accuracy: 0.9957

177/938 [====>.........................] - ETA: 1s - loss: 0.0172 - accuracy: 0.9959

196/938 [=====>........................] - ETA: 1s - loss: 0.0175 - accuracy: 0.9957

215/938 [=====>........................] - ETA: 1s - loss: 0.0188 - accuracy: 0.9957

Learning rate for epoch 12 is 9.999999747378752e-06


<tensorflow.python.keras.callbacks.History at 0x7f50a0d94780>
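The learning-rate lines above print 9.999999747378752e-06 instead of 1e-05. The reason is that the learning rate is held as a 32-bit float (Keras stores optimizer hyperparameters such as the learning rate as `float32` variables by default), and 1e-5 has no exact `float32` representation. When the nearest `float32` value is widened back to a Python double and printed, you get the long decimal. The following standard-library sketch reproduces this. The step schedule in `decay` is a hypothetical stand-in: its thresholds are assumptions chosen only so that the later epochs land on 1e-5, as they do in the log.

```python
import struct

def as_float32(x: float) -> float:
    """Round a Python double to the nearest float32, then widen it back to a double."""
    return struct.unpack("f", struct.pack("f", x))[0]

def decay(epoch: int) -> float:
    # Hypothetical step schedule (assumed thresholds; `epoch` is 0-based).
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    return 1e-5

for epoch in range(8, 12):
    lr = as_float32(decay(epoch))
    # Prints e.g. "Learning rate for epoch 9 is 9.999999747378752e-06"
    print("Learning rate for epoch {} is {}".format(epoch + 1, lr))
```

The value is therefore the learning rate you asked for. It is only displayed at a precision that `float32` does not actually carry.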

As you can see below, a checkpoint is saved for every epoch.

In [16]:
# check the checkpoint directory
!ls {checkpoint_dir}

checkpoint		     ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index		     ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index		     ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index		     ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index		     ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index		     ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index
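Two details about this listing are easy to miss. First, each checkpoint is a prefix such as `ckpt_12` that is shared by an `.index` file and one or more `.data-XXXXX-of-YYYYY` shards. Second, `ls` sorts names lexicographically, so `ckpt_10` through `ckpt_12` appear before `ckpt_2`. The following sketch groups the files by prefix and orders them numerically. `checkpoint_prefixes` is a hypothetical helper, and it assumes the `ckpt_<epoch>` naming visible in the listing.

```python
import re

def checkpoint_prefixes(filenames):
    """Return checkpoint prefixes (e.g. 'ckpt_12') ordered by epoch number.

    Hypothetical helper: assumes the 'ckpt_<epoch>' naming shown in the listing.
    """
    pattern = re.compile(r"^(ckpt_(\d+))\.(index|data-\d{5}-of-\d{5})$")
    epochs = {}
    for name in filenames:
        match = pattern.match(name)
        if match:  # skips the 'checkpoint' state file and unrelated files
            epochs[match.group(1)] = int(match.group(2))
    # Sort prefixes by their integer epoch, not by string order.
    return sorted(epochs, key=epochs.get)

listing = ["checkpoint", "ckpt_1.index", "ckpt_1.data-00000-of-00001",
           "ckpt_10.index", "ckpt_10.data-00000-of-00001",
           "ckpt_2.index", "ckpt_2.data-00000-of-00001"]
print(checkpoint_prefixes(listing))  # ['ckpt_1', 'ckpt_2', 'ckpt_10']
```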


To see how the model performs, load the latest checkpoint and call `evaluate` on the test data, using the evaluation dataset as before.
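The next cell uses `tf.train.latest_checkpoint`. That function does not infer the newest checkpoint from filenames. Instead, it reads the small `checkpoint` state file in the directory, a text-format protocol buffer whose `model_checkpoint_path` field names the most recent prefix. The following plain-Python sketch mimics that lookup to illustrate the idea. It is not TensorFlow's implementation, and `latest_checkpoint_prefix` is a hypothetical name. It also omits steps the real function performs, such as checking that the referenced checkpoint files actually exist.

```python
import os
import re
import tempfile

def latest_checkpoint_prefix(checkpoint_dir):
    """Hypothetical stand-in for tf.train.latest_checkpoint (illustration only)."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            # Assumed line shape: model_checkpoint_path: "ckpt_12"
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                path = match.group(1)
                # Relative prefixes are resolved against the checkpoint directory.
                return path if os.path.isabs(path) else os.path.join(checkpoint_dir, path)
    return None

with tempfile.TemporaryDirectory() as demo_dir:
    with open(os.path.join(demo_dir, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "ckpt_12"\n'
                'all_model_checkpoint_paths: "ckpt_11"\n'
                'all_model_checkpoint_paths: "ckpt_12"\n')
    print(latest_checkpoint_prefix(demo_dir))  # <demo_dir>/ckpt_12
```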

In [17]:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

  1/157 [..............................] - ETA: 0s - loss: 0.0767 - accuracy: 0.9688

  9/157 [>.............................] - ETA: 0s - loss: 0.0436 - accuracy: 0.9878

 18/157 [==>...........................] - ETA: 0s - loss: 0.0382 - accuracy: 0.9870

 27/157 [====>.........................] - ETA: 0s - loss: 0.0327 - accuracy: 0.9890

 36/157 [=====>........................] - ETA: 0s - loss: 0.0341 - accuracy: 0.9887

Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991
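The progress bars count batches (steps), not examples: 938 steps per training epoch and 157 evaluation steps. These numbers are consistent with the standard MNIST split of 60,000 training and 10,000 test images and a global batch size of 64. The batch size is an assumption inferred from the step counts, which would correspond to 64 examples per replica on a single replica. Each count is the number of examples divided by the global batch size and rounded up, because the final batch is partial. The following sketch shows the arithmetic. The two-replica case assumes, as is usual with `MirroredStrategy`, that the global batch is the per-replica batch multiplied by the number of replicas.

```python
import math

def steps_per_epoch(num_examples, batch_size_per_replica, num_replicas):
    """Batches needed for one pass over the data; the last batch may be partial."""
    global_batch_size = batch_size_per_replica * num_replicas
    return math.ceil(num_examples / global_batch_size)

print(steps_per_epoch(60_000, 64, 1))  # 938  (60000 / 64 = 937.5)
print(steps_per_epoch(10_000, 64, 1))  # 157  (10000 / 64 = 156.25)
print(steps_per_epoch(60_000, 64, 2))  # 469  (60000 / 128 = 468.75)
```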


To see the output, download the TensorBoard logs and point TensorBoard at them from a terminal:

```
$ tensorboard --logdir=path/to/log-directory
```

In [18]:
!ls -sh ./logs

total 4.0K
4.0K train


## Export to SavedModel

Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it either inside or outside `strategy.scope()`.
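A SavedModel is a directory, not a single file. It contains a `saved_model.pb` protocol buffer holding the serialized graphs and signatures, a `variables/` subdirectory holding the weights in the same index/data checkpoint format seen above, and an `assets/` subdirectory for auxiliary files (the `Assets written to: saved_model/assets` log line below refers to this). The helper below is hypothetical and only checks for that layout by file name. It does not parse or validate the protobuf. To confirm that an export is actually usable, load it with `tf.keras.models.load_model` as shown later in this notebook.

```python
import os
import tempfile

def looks_like_saved_model(path):
    """Hypothetical check: is `path` laid out like a SavedModel directory?

    Only inspects file names; it does not parse or validate any protobuf.
    """
    has_graph = any(os.path.isfile(os.path.join(path, name))
                    for name in ("saved_model.pb", "saved_model.pbtxt"))
    has_variables = os.path.isfile(os.path.join(path, "variables", "variables.index"))
    return has_graph and has_variables

with tempfile.TemporaryDirectory() as export_dir:
    print(looks_like_saved_model(export_dir))  # False: empty directory
    os.makedirs(os.path.join(export_dir, "variables"))
    for rel in ("saved_model.pb", os.path.join("variables", "variables.index")):
        open(os.path.join(export_dir, rel), "w").close()
    print(looks_like_saved_model(export_dir))  # True
```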


In [19]:
path = 'saved_model/'

In [20]:
model.save(path, save_format='tf')

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: saved_model/assets


Load the model without `strategy.scope`.

In [21]:
unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

  1/157 [..............................] - ETA: 0s - loss: 0.0767 - accuracy: 0.9688

 14/157 [=>............................] - ETA: 0s - loss: 0.0413 - accuracy: 0.9877

 28/157 [====>.........................] - ETA: 0s - loss: 0.0330 - accuracy: 0.9888

Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991


Load the model with `strategy.scope`.

In [22]:
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

  1/157 [..............................] - ETA: 0s - loss: 0.0767 - accuracy: 0.9688

 12/157 [=>............................] - ETA: 0s - loss: 0.0407 - accuracy: 0.9883

 23/157 [===>..........................] - ETA: 0s - loss: 0.0333 - accuracy: 0.9885

 34/157 [=====>........................] - ETA: 0s - loss: 0.0335 - accuracy: 0.9890

Eval loss: 0.03928377106785774, Eval Accuracy: 0.9864000082015991


### Examples and Tutorials
Here are some examples of using distribution strategies with Keras `fit`/`compile`:
1. [Transformer](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer_main.py) example trained using `tf.distribute.MirroredStrategy`.
2. [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) example trained using `tf.distribute.MirroredStrategy`.

More examples are listed in the [Distribution strategy guide](../../guide/distributed_training.ipynb#examples_and_tutorials).

## Next steps

* Read the [distribution strategy guide](../../guide/distributed_training.ipynb).
* Read the [Distributed Training with Custom Training Loops](training_loops.ipynb) tutorial.
* Visit the [Performance section](../../guide/function.ipynb) in the guide to learn more about other strategies and [tools](../../guide/profiler.md) you can use to optimize the performance of your TensorFlow models.

Note: `tf.distribute.Strategy` is actively under development and we will be adding more examples and tutorials in the near future. Please give it a try. We welcome your feedback via [issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new).