# Save and load a model using a distributed strategy

There are two sets of APIs for saving and loading a Keras model: a high-level set and a low-level set. This tutorial demonstrates how to use the SavedModel APIs when using tf.distribute.Strategy.

In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf 

Prepare the data and the model using a tf.distribute.MirroredStrategy:

In [3]:
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA*mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [4]:
# Train the model.

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)

Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f1e66e9e450>

## Save and load the model 

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available:

- High-level: Keras model.save and tf.keras.models.load_model
- Low-level: tf.saved_model.save and tf.saved_model.load

## The Keras APIs 

Here is an example of saving and loading a model with the Keras APIs:


In [6]:
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)

INFO:tensorflow:Assets written to: /tmp/keras_save/assets



Restore the model without a tf.distribute.Strategy and continue training:

In [7]:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f1e65eb0d90>

After restoring the model, you can continue training it without calling compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format.
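
A SavedModel is a directory rather than a single file. The following minimal sketch saves a small hypothetical tf.Module named Dense1 (not the Keras model above, so that it runs on its own) and lists what was written. You should see saved_model.pb, which holds the graph and signatures, and a variables/ directory holding the checkpointed weights; depending on the TensorFlow version, other entries such as assets/ may also appear.

```python
import os
import tempfile

import tensorflow as tf


class Dense1(tf.Module):
    """Hypothetical stand-in for a model: one weight matrix and one bias."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([3, 1]), name="w")
        self.b = tf.Variable(tf.zeros([1]), name="b")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


export_dir = os.path.join(tempfile.mkdtemp(), "dense1")
tf.saved_model.save(Dense1(), export_dir)

# Top-level entries of the SavedModel directory.
contents = sorted(os.listdir(export_dir))
print(contents)

# The weights are stored as a TensorFlow checkpoint inside variables/.
variable_files = sorted(os.listdir(os.path.join(export_dir, "variables")))
print(variable_files)
```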

Now load the model under a different strategy, here a tf.distribute.OneDeviceStrategy, and train it:

In [None]:
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)

Epoch 1/2
Epoch 2/2
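
The lower-level loader can be used under a strategy scope in the same way. A minimal sketch, assuming a hypothetical one-variable tf.Module named Scale, that saves it, restores it with tf.saved_model.load inside the same OneDeviceStrategy("/cpu:0") scope as above, and checks that the weight and the forward pass survived the round trip:

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical one-variable model: y = v * x."""

    def __init__(self):
        super().__init__()
        self.v = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.v


export_dir = os.path.join(tempfile.mkdtemp(), "scale")
tf.saved_model.save(Scale(), export_dir)

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with strategy.scope():
    # Variables are re-created while this strategy's scope is active.
    restored = tf.saved_model.load(export_dir)

print(restored.v.numpy())
result = restored(tf.constant([1.0, 2.0]))
print(result.numpy())
```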


## tf.saved_model APIs
Now take a look at the lower-level APIs. Saving the model is similar to the Keras API:

In [11]:
model = get_model()
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)


FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.




INFO:tensorflow:Assets written to: /tmp/tf_save/assets



Load the model with tf.saved_model.load(). Because this is a lower-level API (and therefore covers a wider range of use cases), it does not return a Keras model. Instead, it returns an object containing functions that can be used for inference. For example:

In [12]:
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To run inference with this function:

In [13]:
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))

{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 6.33138046e-02, -3.60530466e-02, -1.09878458e-01,
        -1.35206953e-01,  1.17276855e-01, -5.08575886e-02,
         1.64591402e-01,  1.48164248e-03,  3.06673627e-02,
        -1.16027847e-01],
       [ 6.24670349e-02, -5.81023619e-02, -3.23682614e-02,
        -1.79502115e-01, -5.37753571e-03, -4.24278937e-02,
         1.50196269e-01,  2.15451270e-02, -5.42667992e-02,
         2.01064870e-02],
       [ 4.65019681e-02, -1.52240157e-01, -1.21094733e-01,
        -3.47707756e-02, -6.38998896e-02, -8.08315575e-02,
         2.70212144e-01, -1.73431300e-02, -4.92041111e-02,
         4.08092588e-02],
       [ 8.67107660e-02,  9.40914266e-03, -1.97490290e-01,
        -7.65941814e-02, -5.37231565e-02,  2.91546136e-02,
         2.18398899e-01,  7.45067792e-03,  3.97551368e-04,
         2.76662838e-02],
       [-8.11396465e-02, -6.86105266e-02, -1.76406965e-01,
        -1.25270575e-01,  3.70395295e-02, -2.49482617e-02,
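
To see how signature keys and output names are assigned, here is a minimal sketch with a hypothetical tf.Module named Doubler saved through the low-level API. Passing a single tf.function that has an input_signature as signatures exports it under the "serving_default" key. Signature functions are called with keyword arguments named after their inputs and return a dict; because Doubler returns a bare tensor rather than a dict, TensorFlow names its output output_0 (the Keras model above instead names its output after its last layer, dense_3).

```python
import os
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical module with one function and a fixed input signature."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="x")])
    def double(self, x):
        return 2.0 * x


module = Doubler()
export_dir = os.path.join(tempfile.mkdtemp(), "doubler")
# A single tf.function passed as `signatures` is exported as "serving_default".
tf.saved_model.save(module, export_dir, signatures=module.double)

loaded = tf.saved_model.load(export_dir)
signature_keys = list(loaded.signatures.keys())
print(signature_keys)

infer = loaded.signatures["serving_default"]
# Call with a keyword argument named after the input; the result is a dict.
outputs = infer(x=tf.constant([1.0, 2.0, 3.0]))
print({name: tensor.numpy() for name, tensor in outputs.items()})
```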
        

You can also load the model and run inference in a distributed manner:

In [14]:
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset
  )

  # Call the inference function on each replica.

  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))









INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
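
The loop above discards the values returned by another_strategy.run. To keep the predictions, note that each result is a per-replica value, which can be unwrapped with strategy.experimental_local_results and then concatenated. A minimal sketch, assuming a hypothetical row_sums function in place of the restored inference function and a synthetic dataset of ones:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


@tf.function
def row_sums(batch):
    """Hypothetical stand-in for the restored inference function."""
    return tf.reduce_sum(batch, axis=1)


# Eight rows of ones, in global batches of four.
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 3])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

collected = []
for dist_batch in dist_dataset:
    per_replica = strategy.run(row_sums, args=(dist_batch,))
    # Unwrap the per-replica result into one tensor per local replica.
    collected.extend(strategy.experimental_local_results(per_replica))

predictions = tf.concat(collected, axis=0)
print(predictions.numpy())
```

A signature function returns a dict, so with the loaded model you would unwrap the tensor under its output key (for example 'dense_3' in the output above) before concatenating.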



Calling the restored function is just a forward pass on the saved model (like predict). What if you want to continue training the loaded function, or embed it in a bigger model? A common practice is to wrap the loaded object in a Keras layer.

In [16]:
import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap the loaded object in a KerasLayer; trainable=True allows fine-tuning.
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)











INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)



Epoch 1/2
Epoch 2/2


As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
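
If tensorflow_hub is not available, or you only need to fine-tune the restored weights, you can also train the object returned by tf.saved_model.load directly: its variables come back as ordinary tf.Variable attributes, and its restored functions can be differentiated. A minimal sketch, assuming a hypothetical tf.Module named Linear (y = w * x) and a hand-written gradient-descent step in place of a Keras optimizer:

```python
import os
import tempfile

import tensorflow as tf


class Linear(tf.Module):
    """Hypothetical model y = w * x, with w starting at 1.0."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(1.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return self.w * x


export_dir = os.path.join(tempfile.mkdtemp(), "linear")
tf.saved_model.save(Linear(), export_dir)
restored = tf.saved_model.load(export_dir)

x = tf.constant([1.0, 2.0, 3.0])
y_true = 3.0 * x  # targets produced by w = 3.0
learning_rate = 0.05

for _ in range(100):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(restored(x) - y_true))
    # Differentiate through the restored function w.r.t. the restored variable.
    grad = tape.gradient(loss, restored.w)
    restored.w.assign_sub(learning_rate * grad)

print(restored.w.numpy())
```

Here w starts at 1.0 and the targets were generated with w = 3.0, so the loop drives w close to 3.0. hub.KerasLayer remains the more convenient route when you want to compose the restored object with other Keras layers.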

It is possible to mix and match the APIs. You can save a Keras model with model.save and load it as a non-Keras object with the low-level API, tf.saved_model.load.

In [17]:
model = get_model()

# Save the model using Keras' save() API.
model.save(keras_model_path)

another_strategy = tf.distribute.MirroredStrategy()

# Load the model using the lower-level API.

with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)

INFO:tensorflow:Assets written to: /tmp/keras_save/assets



INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)



## Saving/loading from a local device

When saving to or loading from a local I/O device while running remotely (for example, on a Cloud TPU), you must use the experimental_io_device option to set the I/O device to localhost.

In [19]:
model = get_model()

# Save the model to a path on localhost.

saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Load the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets



INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
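
The same options apply to the lower-level API: tf.saved_model.save and tf.saved_model.load accept SaveOptions and LoadOptions objects through their options argument. A minimal sketch with a hypothetical tf.Module named Offset; on a single local machine /job:localhost is simply the local host, so this runs without a TPU:

```python
import os
import tempfile

import tensorflow as tf


class Offset(tf.Module):
    """Hypothetical module: y = x + b."""

    def __init__(self):
        super().__init__()
        self.b = tf.Variable(5.0)

    @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
    def __call__(self, x):
        return x + self.b


export_dir = os.path.join(tempfile.mkdtemp(), "offset")

# Perform the file writes from the local host.
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(Offset(), export_dir, options=save_options)

# Perform the file reads from the local host as well.
load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
restored = tf.saved_model.load(export_dir, options=load_options)

result = restored(tf.constant(1.0))
print(result.numpy())
```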



## Caveats 

A special case is a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, use the lower-level APIs for both saving and loading; otherwise you will get an error.



In [20]:
class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name
    )

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()

tf.saved_model.save(my_model, saved_model_path)










FOR DEVS: If you are overwriting _tracking_metadata in your class, this property has been used to save metadata in the SavedModel. The metadta field will be deprecated soon, so please move the metadata to a different file.




INFO:tensorflow:Assets written to: /tmp/tf_save/assets
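
When a model has no fixed input signature, the lower-level API exports the concrete functions that were traced before saving, so the restored object can only be called with inputs that match one of those traces. A minimal sketch of this behavior, using a hypothetical tf.Module named Scaler rather than the Keras subclass above:

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module whose tf.function has no input_signature."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    @tf.function
    def __call__(self, x):
        return x * self.factor


module = Scaler()
# Trace once so that a concrete function for float32 vectors of length 2 exists.
module(tf.constant([1.0, 2.0]))

export_dir = os.path.join(tempfile.mkdtemp(), "scaler")
tf.saved_model.save(module, export_dir)

restored = tf.saved_model.load(export_dir)
# This input matches the traced shape and dtype, so the call succeeds.
result = restored(tf.constant([3.0, 4.0]))
print(result.numpy())
```

Calling restored with, for example, a float32 tensor of a different length would raise an error, because no matching function was traced before saving.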

