##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Save and load a model using a distribution strategy

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/distribute/save_and_load"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/save_and_load.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/save_and_load.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/distribute/save_and_load.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>

</table>

## Overview

It's common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API (`tf.keras`) and a low-level API (`tf.saved_model`). This tutorial demonstrates how you can use the SavedModel APIs when using `tf.distribute.Strategy`. To learn about SavedModel and serialization in general, read the [SavedModel guide](../../guide/saved_model.ipynb) and the [Keras model serialization guide](https://www.tensorflow.org/guide/keras/save_and_serialize). Let's start with a simple example:

Import dependencies:

In [2]:
import tensorflow_datasets as tfds

import tensorflow as tf


Prepare the data and model using `tf.distribute.Strategy`:

In [3]:
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
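`get_data` multiplies the per-replica batch size by `mirrored_strategy.num_replicas_in_sync`, and `get_model` builds the model inside `mirrored_strategy.scope()` so that its variables are created under the strategy. The sketch below shows both ideas in isolation. It assumes TensorFlow 2.x. The variable `v` is a hypothetical stand-in for a model weight and is not part of the tutorial:

```python
import tensorflow as tf

# Uses every visible GPU, or falls back to the CPU when no GPU is available.
strategy = tf.distribute.MirroredStrategy()

BATCH_SIZE_PER_REPLICA = 64
# Each replica processes BATCH_SIZE_PER_REPLICA examples per step,
# so the dataset is batched with the global (summed) batch size.
global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

with strategy.scope():
    # Created inside the scope, so the strategy controls how the variable
    # is placed (one mirrored copy per replica device).
    v = tf.Variable(1.0)

print("replicas:", strategy.num_replicas_in_sync)
print("global batch size:", global_batch_size)
print("variable type:", type(v).__name__)
```

With a single GPU, as in the log above, the global batch size equals the per-replica batch size of 64.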


Train the model: 

In [4]:
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


Epoch 1/2


2022-01-26 05:41:11.916000: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


  1/938 [..............................] - ETA: 1:45:34 - loss: 2.3057 - sparse_categorical_accuracy: 0.0938

 12/938 [..............................] - ETA: 4s - loss: 1.8381 - sparse_categorical_accuracy: 0.4870     

 21/938 [..............................] - ETA: 4s - loss: 1.4643 - sparse_categorical_accuracy: 0.6131

 31/938 [..............................] - ETA: 4s - loss: 1.1783 - sparse_categorical_accuracy: 0.6825

 40/938 [>.............................] - ETA: 4s - loss: 1.0227 - sparse_categorical_accuracy: 0.7184

 50/938 [>.............................] - ETA: 4s - loss: 0.9100 - sparse_categorical_accuracy: 0.7466

 59/938 [>.............................] - ETA: 4s - loss: 0.8454 - sparse_categorical_accuracy: 0.7632

 68/938 [=>............................] - ETA: 4s - loss: 0.7794 - sparse_categorical_accuracy: 0.7796

 77/938 [=>............................] - ETA: 4s - loss: 0.7246 - sparse_categorical_accuracy: 0.7950

 87/938 [=>............................] - ETA: 4s - loss: 0.6758 - sparse_categorical_accuracy: 0.8080

 97/938 [==>...........................] - ETA: 4s - loss: 0.6419 - sparse_categorical_accuracy: 0.8175

107/938 [==>...........................] - ETA: 4s - loss: 0.6146 - sparse_categorical_accuracy: 0.8246

117/938 [==>...........................] - ETA: 4s - loss: 0.5873 - sparse_categorical_accuracy: 0.8317

127/938 [===>..........................] - ETA: 4s - loss: 0.5613 - sparse_categorical_accuracy: 0.8394

137/938 [===>..........................] - ETA: 4s - loss: 0.5348 - sparse_categorical_accuracy: 0.8466

147/938 [===>..........................] - ETA: 4s - loss: 0.5157 - sparse_categorical_accuracy: 0.8510

157/938 [====>.........................] - ETA: 4s - loss: 0.4955 - sparse_categorical_accuracy: 0.8562

166/938 [====>.........................] - ETA: 4s - loss: 0.4808 - sparse_categorical_accuracy: 0.8604

176/938 [====>.........................] - ETA: 4s - loss: 0.4667 - sparse_categorical_accuracy: 0.8643

187/938 [====>.........................] - ETA: 4s - loss: 0.4542 - sparse_categorical_accuracy: 0.8681

199/938 [=====>........................] - ETA: 3s - loss: 0.4391 - sparse_categorical_accuracy: 0.8729

210/938 [=====>........................] - ETA: 3s - loss: 0.4262 - sparse_categorical_accuracy: 0.8760



























































































































Epoch 2/2


  1/938 [..............................] - ETA: 26s - loss: 0.2521 - sparse_categorical_accuracy: 0.9375

 16/938 [..............................] - ETA: 3s - loss: 0.1280 - sparse_categorical_accuracy: 0.9697 

 31/938 [..............................] - ETA: 3s - loss: 0.0967 - sparse_categorical_accuracy: 0.9763

 47/938 [>.............................] - ETA: 2s - loss: 0.0910 - sparse_categorical_accuracy: 0.9771

 62/938 [>.............................] - ETA: 2s - loss: 0.0819 - sparse_categorical_accuracy: 0.9786

 77/938 [=>............................] - ETA: 2s - loss: 0.0762 - sparse_categorical_accuracy: 0.9795

 91/938 [=>............................] - ETA: 2s - loss: 0.0761 - sparse_categorical_accuracy: 0.9792

106/938 [==>...........................] - ETA: 2s - loss: 0.0759 - sparse_categorical_accuracy: 0.9794

121/938 [==>...........................] - ETA: 2s - loss: 0.0763 - sparse_categorical_accuracy: 0.9790

136/938 [===>..........................] - ETA: 2s - loss: 0.0739 - sparse_categorical_accuracy: 0.9798

151/938 [===>..........................] - ETA: 2s - loss: 0.0723 - sparse_categorical_accuracy: 0.9797

165/938 [====>.........................] - ETA: 2s - loss: 0.0707 - sparse_categorical_accuracy: 0.9795

179/938 [====>.........................] - ETA: 2s - loss: 0.0718 - sparse_categorical_accuracy: 0.9791

193/938 [=====>........................] - ETA: 2s - loss: 0.0720 - sparse_categorical_accuracy: 0.9791

207/938 [=====>........................] - ETA: 2s - loss: 0.0718 - sparse_categorical_accuracy: 0.9794



































































































<keras.callbacks.History at 0x7f3b900396d0>
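The progress bar reports 938 steps per epoch. The log above lists a single GPU, so this sketch assumes one replica. It also assumes the 60,000-image MNIST training split. Under those assumptions, the count follows from batching without `drop_remainder`. The number of steps is the ceiling of the number of examples divided by the global batch size, and the final batch is partial. Here is a small worked check in plain Python. The helper names are illustrative only:

```python
import math

MNIST_TRAIN_EXAMPLES = 60_000  # size of the MNIST training split


def steps_per_epoch(num_examples, per_replica_batch, num_replicas):
    """Batches produced by dataset.batch(global_batch) when drop_remainder=False."""
    global_batch = per_replica_batch * num_replicas
    return math.ceil(num_examples / global_batch)


def last_batch_size(num_examples, per_replica_batch, num_replicas):
    """Size of the final (possibly partial) batch."""
    global_batch = per_replica_batch * num_replicas
    remainder = num_examples % global_batch
    return remainder or global_batch


print(steps_per_epoch(MNIST_TRAIN_EXAMPLES, 64, 1))  # 938
print(last_batch_size(MNIST_TRAIN_EXAMPLES, 64, 1))  # 32
print(steps_per_epoch(MNIST_TRAIN_EXAMPLES, 64, 2))  # 469
```

With two replicas, the global batch doubles to 128, so the same data takes roughly half as many steps.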

## Save and load the model

Now that you have a simple model to work with, let's take a look at the saving and loading APIs.
There are two sets of APIs available:

*   High-level (Keras): `model.save` and `tf.keras.models.load_model`
*   Low-level: `tf.saved_model.save` and `tf.saved_model.load`
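Here is a minimal sketch of the low-level pair, assuming TensorFlow 2.x. It saves a small `tf.Module` with `tf.saved_model.save` and restores it with `tf.saved_model.load`. The `Scaler` class is a hypothetical stand-in for a trained model. SavedModel stores traced `tf.function`s and tracked variables, not Python code. As a result, the restored object is a generic trackable object rather than a `Scaler` instance, but its traced `__call__` and its variable are still available:

```python
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical toy module: multiplies its input by a trainable factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    # The input signature lets SavedModel trace and export this function.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Scaler(), export_dir)

# The restored object is not a Scaler, but the traced __call__ and the
# tracked variable are restored from the SavedModel directory.
restored = tf.saved_model.load(export_dir)
print(restored(tf.constant([1.0, 2.0])).numpy())  # [2. 4.]
print(restored.factor.numpy())                    # 2.0
```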


### The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

In [5]:
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)

2022-01-26 05:41:26.593570: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: /tmp/keras_save/assets


Restore the model without `tf.distribute.Strategy`:

In [6]:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)

Epoch 1/2


  1/938 [..............................] - ETA: 4:54 - loss: 0.0342 - sparse_categorical_accuracy: 1.0000

 19/938 [..............................] - ETA: 2s - loss: 0.0677 - sparse_categorical_accuracy: 0.9737  

 38/938 [>.............................] - ETA: 2s - loss: 0.0661 - sparse_categorical_accuracy: 0.9762

 57/938 [>.............................] - ETA: 2s - loss: 0.0637 - sparse_categorical_accuracy: 0.9789

 77/938 [=>............................] - ETA: 2s - loss: 0.0585 - sparse_categorical_accuracy: 0.9807

 97/938 [==>...........................] - ETA: 2s - loss: 0.0567 - sparse_categorical_accuracy: 0.9818

117/938 [==>...........................] - ETA: 2s - loss: 0.0547 - sparse_categorical_accuracy: 0.9828

137/938 [===>..........................] - ETA: 2s - loss: 0.0564 - sparse_categorical_accuracy: 0.9823

156/938 [===>..........................] - ETA: 2s - loss: 0.0567 - sparse_categorical_accuracy: 0.9824

175/938 [====>.........................] - ETA: 2s - loss: 0.0568 - sparse_categorical_accuracy: 0.9823

194/938 [=====>........................] - ETA: 1s - loss: 0.0554 - sparse_categorical_accuracy: 0.9825

214/938 [=====>........................] - ETA: 1s - loss: 0.0547 - sparse_categorical_accuracy: 0.9828















































































Epoch 2/2


  1/938 [..............................] - ETA: 24s - loss: 0.0060 - sparse_categorical_accuracy: 1.0000

 19/938 [..............................] - ETA: 2s - loss: 0.0343 - sparse_categorical_accuracy: 0.9885 

 37/938 [>.............................] - ETA: 2s - loss: 0.0433 - sparse_categorical_accuracy: 0.9852

 55/938 [>.............................] - ETA: 2s - loss: 0.0465 - sparse_categorical_accuracy: 0.9852

 72/938 [=>............................] - ETA: 2s - loss: 0.0450 - sparse_categorical_accuracy: 0.9859

 90/938 [=>............................] - ETA: 2s - loss: 0.0421 - sparse_categorical_accuracy: 0.9863

108/938 [==>...........................] - ETA: 2s - loss: 0.0422 - sparse_categorical_accuracy: 0.9863

126/938 [===>..........................] - ETA: 2s - loss: 0.0412 - sparse_categorical_accuracy: 0.9862

145/938 [===>..........................] - ETA: 2s - loss: 0.0398 - sparse_categorical_accuracy: 0.9870

163/938 [====>.........................] - ETA: 2s - loss: 0.0390 - sparse_categorical_accuracy: 0.9871

181/938 [====>.........................] - ETA: 2s - loss: 0.0385 - sparse_categorical_accuracy: 0.9874

199/938 [=====>........................] - ETA: 2s - loss: 0.0380 - sparse_categorical_accuracy: 0.9875

217/938 [=====>........................] - ETA: 2s - loss: 0.0390 - sparse_categorical_accuracy: 0.9871

















































































<keras.callbacks.History at 0x7f3b187b7150>

After restoring the model, you can continue training on it without calling `compile()` again, since it was already compiled before saving. The model is saved in TensorFlow's standard `SavedModel` proto format. For more information, refer to [the guide to `saved_model` format](../../guide/saved_model.ipynb).
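A SavedModel is a directory rather than a single file. The sketch below inspects what `tf.saved_model.save` writes for a tiny hypothetical `tf.Module` (`Tiny`), assuming TensorFlow 2.x. You should see a `saved_model.pb` protocol buffer holding the traced graphs and signatures, plus a `variables/` directory holding the checkpointed weights. The exact set of extra entries can vary by TensorFlow version. Examples include `assets/` or a fingerprint file in newer releases. For that reason, the sketch prints the full listing and checks only those two entries:

```python
import os
import tempfile

import tensorflow as tf


class Tiny(tf.Module):
    """Hypothetical module with a single variable, used only to produce a SavedModel."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([1.0, 2.0, 3.0])

    @tf.function(input_signature=[tf.TensorSpec(shape=[3], dtype=tf.float32)])
    def __call__(self, x):
        return tf.reduce_sum(x * self.w)


export_dir = os.path.join(tempfile.mkdtemp(), "tiny_model")
tf.saved_model.save(Tiny(), export_dir)  # creates export_dir if needed

# Top-level contents of the SavedModel directory.
contents = sorted(os.listdir(export_dir))
print(contents)

# The serialized graph/metadata file and the variable checkpoint directory.
has_pb = os.path.isfile(os.path.join(export_dir, "saved_model.pb"))
has_variables = os.path.isdir(os.path.join(export_dir, "variables"))
print(has_pb, has_variables)
```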

Now, load the model and train it using a `tf.distribute.Strategy`:

In [7]:
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)

Epoch 1/2


2022-01-26 05:41:33.036733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-01-26 05:41:33.083001: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


  1/938 [..............................] - ETA: 6:00 - loss: 0.0377 - sparse_categorical_accuracy: 0.9844

  6/938 [..............................] - ETA: 10s - loss: 0.0842 - sparse_categorical_accuracy: 0.9870 

 11/938 [..............................] - ETA: 10s - loss: 0.0768 - sparse_categorical_accuracy: 0.9815

 16/938 [..............................] - ETA: 10s - loss: 0.0760 - sparse_categorical_accuracy: 0.9824

 21/938 [..............................] - ETA: 9s - loss: 0.0776 - sparse_categorical_accuracy: 0.9807 

 26/938 [..............................] - ETA: 9s - loss: 0.0791 - sparse_categorical_accuracy: 0.9802

 32/938 [>.............................] - ETA: 9s - loss: 0.0757 - sparse_categorical_accuracy: 0.9785

 37/938 [>.............................] - ETA: 9s - loss: 0.0706 - sparse_categorical_accuracy: 0.9806

 42/938 [>.............................] - ETA: 9s - loss: 0.0666 - sparse_categorical_accuracy: 0.9818

 47/938 [>.............................] - ETA: 9s - loss: 0.0658 - sparse_categorical_accuracy: 0.9820

 53/938 [>.............................] - ETA: 9s - loss: 0.0651 - sparse_categorical_accuracy: 0.9814

 58/938 [>.............................] - ETA: 9s - loss: 0.0623 - sparse_categorical_accuracy: 0.9820

 63/938 [=>............................] - ETA: 9s - loss: 0.0613 - sparse_categorical_accuracy: 0.9819

 68/938 [=>............................] - ETA: 9s - loss: 0.0637 - sparse_categorical_accuracy: 0.9814

 74/938 [=>............................] - ETA: 8s - loss: 0.0615 - sparse_categorical_accuracy: 0.9823

 80/938 [=>............................] - ETA: 8s - loss: 0.0629 - sparse_categorical_accuracy: 0.9822

 85/938 [=>............................] - ETA: 8s - loss: 0.0613 - sparse_categorical_accuracy: 0.9825

 90/938 [=>............................] - ETA: 8s - loss: 0.0607 - sparse_categorical_accuracy: 0.9826

 95/938 [==>...........................] - ETA: 8s - loss: 0.0606 - sparse_categorical_accuracy: 0.9824

100/938 [==>...........................] - ETA: 8s - loss: 0.0624 - sparse_categorical_accuracy: 0.9819

105/938 [==>...........................] - ETA: 8s - loss: 0.0620 - sparse_categorical_accuracy: 0.9818

110/938 [==>...........................] - ETA: 8s - loss: 0.0629 - sparse_categorical_accuracy: 0.9820

115/938 [==>...........................] - ETA: 8s - loss: 0.0626 - sparse_categorical_accuracy: 0.9819

120/938 [==>...........................] - ETA: 8s - loss: 0.0636 - sparse_categorical_accuracy: 0.9815

126/938 [===>..........................] - ETA: 8s - loss: 0.0629 - sparse_categorical_accuracy: 0.9815

132/938 [===>..........................] - ETA: 8s - loss: 0.0624 - sparse_categorical_accuracy: 0.9815

138/938 [===>..........................] - ETA: 8s - loss: 0.0623 - sparse_categorical_accuracy: 0.9817

143/938 [===>..........................] - ETA: 8s - loss: 0.0612 - sparse_categorical_accuracy: 0.9821

149/938 [===>..........................] - ETA: 8s - loss: 0.0618 - sparse_categorical_accuracy: 0.9820

155/938 [===>..........................] - ETA: 7s - loss: 0.0612 - sparse_categorical_accuracy: 0.9819

161/938 [====>.........................] - ETA: 7s - loss: 0.0605 - sparse_categorical_accuracy: 0.9821

167/938 [====>.........................] - ETA: 7s - loss: 0.0603 - sparse_categorical_accuracy: 0.9819

173/938 [====>.........................] - ETA: 7s - loss: 0.0598 - sparse_categorical_accuracy: 0.9821

178/938 [====>.........................] - ETA: 7s - loss: 0.0597 - sparse_categorical_accuracy: 0.9823

184/938 [====>.........................] - ETA: 7s - loss: 0.0593 - sparse_categorical_accuracy: 0.9822

190/938 [=====>........................] - ETA: 7s - loss: 0.0582 - sparse_categorical_accuracy: 0.9826

196/938 [=====>........................] - ETA: 7s - loss: 0.0574 - sparse_categorical_accuracy: 0.9829

202/938 [=====>........................] - ETA: 7s - loss: 0.0563 - sparse_categorical_accuracy: 0.9832

208/938 [=====>........................] - ETA: 7s - loss: 0.0558 - sparse_categorical_accuracy: 0.9834

214/938 [=====>........................] - ETA: 7s - loss: 0.0557 - sparse_categorical_accuracy: 0.9835













































































































































































































































































Epoch 2/2


  1/938 [..............................] - ETA: 34s - loss: 0.0175 - sparse_categorical_accuracy: 1.0000

  6/938 [..............................] - ETA: 9s - loss: 0.0336 - sparse_categorical_accuracy: 0.9896 

 11/938 [..............................] - ETA: 9s - loss: 0.0264 - sparse_categorical_accuracy: 0.9915

 16/938 [..............................] - ETA: 9s - loss: 0.0280 - sparse_categorical_accuracy: 0.9922

 21/938 [..............................] - ETA: 9s - loss: 0.0279 - sparse_categorical_accuracy: 0.9918

 26/938 [..............................] - ETA: 9s - loss: 0.0307 - sparse_categorical_accuracy: 0.9910

 31/938 [..............................] - ETA: 9s - loss: 0.0274 - sparse_categorical_accuracy: 0.9924

 36/938 [>.............................] - ETA: 9s - loss: 0.0271 - sparse_categorical_accuracy: 0.9922

 41/938 [>.............................] - ETA: 9s - loss: 0.0254 - sparse_categorical_accuracy: 0.9924

 47/938 [>.............................] - ETA: 9s - loss: 0.0245 - sparse_categorical_accuracy: 0.9927

 52/938 [>.............................] - ETA: 9s - loss: 0.0263 - sparse_categorical_accuracy: 0.9925

 57/938 [>.............................] - ETA: 9s - loss: 0.0269 - sparse_categorical_accuracy: 0.9929

 62/938 [>.............................] - ETA: 9s - loss: 0.0280 - sparse_categorical_accuracy: 0.9927

 67/938 [=>............................] - ETA: 8s - loss: 0.0312 - sparse_categorical_accuracy: 0.9914

 72/938 [=>............................] - ETA: 8s - loss: 0.0306 - sparse_categorical_accuracy: 0.9915

 77/938 [=>............................] - ETA: 8s - loss: 0.0326 - sparse_categorical_accuracy: 0.9909

 82/938 [=>............................] - ETA: 8s - loss: 0.0315 - sparse_categorical_accuracy: 0.9912

 87/938 [=>............................] - ETA: 8s - loss: 0.0311 - sparse_categorical_accuracy: 0.9914

 92/938 [=>............................] - ETA: 8s - loss: 0.0303 - sparse_categorical_accuracy: 0.9915

 97/938 [==>...........................] - ETA: 8s - loss: 0.0311 - sparse_categorical_accuracy: 0.9907

102/938 [==>...........................] - ETA: 8s - loss: 0.0321 - sparse_categorical_accuracy: 0.9907

107/938 [==>...........................] - ETA: 8s - loss: 0.0322 - sparse_categorical_accuracy: 0.9905

112/938 [==>...........................] - ETA: 8s - loss: 0.0326 - sparse_categorical_accuracy: 0.9905

117/938 [==>...........................] - ETA: 8s - loss: 0.0323 - sparse_categorical_accuracy: 0.9908

122/938 [==>...........................] - ETA: 8s - loss: 0.0330 - sparse_categorical_accuracy: 0.9904

127/938 [===>..........................] - ETA: 8s - loss: 0.0329 - sparse_categorical_accuracy: 0.9904

132/938 [===>..........................] - ETA: 8s - loss: 0.0334 - sparse_categorical_accuracy: 0.9903

137/938 [===>..........................] - ETA: 8s - loss: 0.0330 - sparse_categorical_accuracy: 0.9904

142/938 [===>..........................] - ETA: 8s - loss: 0.0329 - sparse_categorical_accuracy: 0.9904

147/938 [===>..........................] - ETA: 8s - loss: 0.0328 - sparse_categorical_accuracy: 0.9905

152/938 [===>..........................] - ETA: 8s - loss: 0.0329 - sparse_categorical_accuracy: 0.9904

157/938 [====>.........................] - ETA: 8s - loss: 0.0327 - sparse_categorical_accuracy: 0.9904

162/938 [====>.........................] - ETA: 8s - loss: 0.0333 - sparse_categorical_accuracy: 0.9902

167/938 [====>.........................] - ETA: 8s - loss: 0.0334 - sparse_categorical_accuracy: 0.9901

172/938 [====>.........................] - ETA: 7s - loss: 0.0334 - sparse_categorical_accuracy: 0.9902

177/938 [====>.........................] - ETA: 7s - loss: 0.0341 - sparse_categorical_accuracy: 0.9901

182/938 [====>.........................] - ETA: 7s - loss: 0.0342 - sparse_categorical_accuracy: 0.9900

187/938 [====>.........................] - ETA: 7s - loss: 0.0342 - sparse_categorical_accuracy: 0.9900

192/938 [=====>........................] - ETA: 7s - loss: 0.0342 - sparse_categorical_accuracy: 0.9900

197/938 [=====>........................] - ETA: 7s - loss: 0.0337 - sparse_categorical_accuracy: 0.9901

202/938 [=====>........................] - ETA: 7s - loss: 0.0337 - sparse_categorical_accuracy: 0.9900

207/938 [=====>........................] - ETA: 7s - loss: 0.0338 - sparse_categorical_accuracy: 0.9899

212/938 [=====>........................] - ETA: 7s - loss: 0.0342 - sparse_categorical_accuracy: 0.9897

217/938 [=====>........................] - ETA: 7s - loss: 0.0347 - sparse_categorical_accuracy: 0.9896























































































































































































































































































As you can see, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy used before saving. 

### The `tf.saved_model` APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

In [8]:
model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets
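To see what a save call like the one above writes to disk, here is a minimal sketch. It saves a small `tf.Module` (the `Toy` class is hypothetical, standing in for the tutorial's `get_model()` so the snippet runs on its own) and lists the export directory. It assumes a TensorFlow 2.x install. The exact set of top-level entries (for example `assets/`) can vary between versions, so the sketch only relies on `saved_model.pb` and the `variables/` directory.

```python
import os
import tempfile

import tensorflow as tf


# Hypothetical toy module standing in for the tutorial's get_model().
class Toy(tf.Module):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable([1.0, 2.0, 3.0])

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, x):
        return tf.reduce_sum(self.w * x, axis=-1)


export_dir = os.path.join(tempfile.mkdtemp(), "toy_save")
tf.saved_model.save(Toy(), export_dir)

# saved_model.pb holds the serialized functions and object graph;
# variables/ holds the checkpointed variable values.
contents = sorted(os.listdir(export_dir))
print(contents)
print(sorted(os.listdir(os.path.join(export_dir, "variables"))))
```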


Loading can be done with `tf.saved_model.load()`. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

In [9]:
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. `"serving_default"` is the default key for the inference function of a saved Keras model. To do inference with this function:

In [10]:
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))

{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-1.18789300e-01, -1.78404614e-01,  4.92432676e-02,
        -9.37875658e-02,  1.14302970e-01, -8.99422392e-02,
         9.47709680e-02, -7.75382966e-02,  4.04430032e-02,
         2.41404288e-02],
       [-2.35370561e-01, -3.39397341e-02,  2.73427293e-02,
        -1.08200148e-01,  5.10682352e-02,  1.36142194e-01,
         9.28785652e-02, -5.35808355e-02,  2.56292164e-01,
         1.05301209e-01],
       [-1.91031799e-01, -7.72745535e-02, -7.23153427e-02,
        -1.99329913e-01, -7.45072216e-02,  2.42738128e-02,
         2.07733169e-01, -3.15396488e-03,  4.95976806e-02,
         2.14848563e-01],
       [-9.82482210e-02, -6.13910556e-02,  1.00815810e-01,
        -1.87558904e-01,  1.14685424e-01,  1.53835595e-01,
         1.85714245e-01, -8.74890238e-02,  1.07493028e-01,
         1.57510787e-02],
       [-8.56257528e-02,  3.23683321e-02, -3.66768315e-02,
        -1.47201523e-01, -5.31517603e-02,  1.52744055e-02,
        

2022-01-26 05:41:53.590742: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
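Because a loaded object can expose several functions, it can help to list the available keys before picking one. The following is a minimal sketch, assuming TensorFlow 2.x; the `Calculator` module and its `scale_input`/`add_one` methods are hypothetical and exist only to produce a SavedModel with two signatures. Passing a dictionary to the `signatures` argument of `tf.saved_model.save` registers each `tf.function` (which must have an `input_signature`) under its own key.

```python
import tempfile

import tensorflow as tf


class Calculator(tf.Module):
    """Hypothetical module with two exported functions."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def scale_input(self, x):
        return {"scaled": self.scale * x}

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def add_one(self, x):
        return {"plus_one": x + 1.0}


module = Calculator()
export_dir = tempfile.mkdtemp()
# Register each function under its own signature key.
tf.saved_model.save(
    module,
    export_dir,
    signatures={"serving_default": module.scale_input, "add_one": module.add_one},
)

loaded = tf.saved_model.load(export_dir)
print(sorted(loaded.signatures.keys()))  # ['add_one', 'serving_default']

serving_fn = loaded.signatures["serving_default"]
# Signature functions are called with keyword arguments and return a dict.
outputs = serving_fn(x=tf.constant([1.0, 3.0]))
print(outputs["scaled"].numpy())  # [2. 6.]
```

Returning a dictionary from each function makes the output names explicit, so callers do not have to rely on auto-generated output keys.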


You can also load and do inference in a distributed manner:

In [11]:
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


2022-01-26 05:41:53.931428: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
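The cell above discards the values returned by `another_strategy.run`. If you want to keep the predictions, one option is to unpack the per-replica results with `strategy.experimental_local_results` and concatenate them. This is a minimal sketch under stated assumptions: TensorFlow 2.x, a hypothetical `Doubler` module in place of the tutorial's model, and a toy dataset of ten all-ones rows. With no GPU visible, `MirroredStrategy` runs a single CPU replica, and `experimental_local_results` then returns a 1-tuple.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical stand-in for the saved model: doubles its input."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32, name="x")])
    def __call__(self, x):
        return {"doubled": self.w * x}


export_dir = tempfile.mkdtemp()
module = Doubler()
# A single tf.function with an input_signature becomes "serving_default".
tf.saved_model.save(module, export_dir, signatures=module.__call__)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    loaded = tf.saved_model.load(export_dir)
    serving_fn = loaded.signatures["serving_default"]

dataset = tf.data.Dataset.from_tensor_slices(tf.ones([10, 4])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


def predict_step(x):
    return serving_fn(x=x)["doubled"]


collected = []
for batch in dist_dataset:
    per_replica = strategy.run(predict_step, args=(batch,))
    # One tensor per local replica for this global batch.
    collected.extend(strategy.experimental_local_results(per_replica))

predictions = tf.concat(collected, axis=0)
print(predictions.shape)  # (10, 4)
```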

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, [TF Hub](https://www.tensorflow.org/hub) has [hub.KerasLayer](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, shown here:

In [12]:
import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


Epoch 1/2


2022-01-26 05:41:55.594317: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


213/938 [=====>........................] - ETA: 2s - loss: 0.4218 - sparse_categorical_accuracy: 0.8813

Epoch 2/2


216/938 [=====>........................] - ETA: 2s - loss: 0.0751 - sparse_categorical_accuracy: 0.9786

As you can see, `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load()` into a Keras layer that can be used to build another model. This is very useful for transfer learning. 

### Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras `model.save()` API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use `tf.saved_model.load()`. Otherwise, use `tf.keras.models.load_model()`. Note that you can get a Keras model back only if you saved a Keras model. 

It is possible to mix and match the APIs. You can save a Keras model with `model.save`, and load it back as a non-Keras object with the lower-level API, `tf.saved_model.load`.

In [13]:
model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)

INFO:tensorflow:Assets written to: /tmp/keras_save/assets


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


### Saving/Loading from local device

When saving to and loading from a local I/O device while running remotely (for example, on a Cloud TPU), you must use the `experimental_io_device` option to set the I/O device to `/job:localhost`.

In [14]:
model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
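The same options also apply to the lower-level APIs: `tf.saved_model.save` and `tf.saved_model.load` both accept an `options` argument. The following is a minimal sketch, assuming TensorFlow 2.x; the `Counter` module is hypothetical and holds a single variable so you can check that it round-trips through a local path while loading under a `MirroredStrategy` scope.

```python
import tempfile

import tensorflow as tf


class Counter(tf.Module):
    """Hypothetical module with a single variable to round-trip."""

    def __init__(self):
        super().__init__()
        self.count = tf.Variable(42.0)


export_dir = tempfile.mkdtemp()

# Route file reads and writes through the local host.
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(Counter(), export_dir, options=save_options)

load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    restored = tf.saved_model.load(export_dir, options=load_options)

print(float(restored.count.read_value().numpy()))  # 42.0
```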


### Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (`Sequential([Dense(3), ...])`). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check if your model has well-defined inputs, just check if `model.inputs` is `None`. If it is not `None`, you are all good. Input shapes are automatically defined when the model is used in `.fit`, `.evaluate`, `.predict`, or when calling the model (`model(inputs)`). 

Here is an example:

In [15]:
class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets
