In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
mirrored_strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [3]:
def get_data(batch_size_per_replica=64, buffer_size=10000):
    """Load MNIST through TFDS and return batched train and eval datasets.

    Args:
        batch_size_per_replica: Number of examples per replica in each batch.
            The global batch size is this value multiplied by
            ``mirrored_strategy.num_replicas_in_sync``.
        buffer_size: Shuffle buffer size applied to the training split.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple. Both yield batched
        ``(image, label)`` pairs whose images were cast to float32 and
        divided by 255. Only the training dataset is cached and shuffled.
    """
    # The DatasetInfo object was never used, so it is no longer requested.
    datasets = tfds.load(name='mnist', as_supervised=True)
    mnist_train, mnist_test = datasets['train'], datasets['test']

    # Scale the batch with the replica count so each replica keeps the same
    # per-step workload regardless of how many devices are in sync.
    batch_size = batch_size_per_replica * mirrored_strategy.num_replicas_in_sync

    def scale(image, label):
        """Cast the image to float32 and divide by 255; pass the label through."""
        image = tf.cast(image, tf.float32)
        image /= 255.
        return image, label

    train_dataset = mnist_train.map(scale).cache().shuffle(buffer_size).batch(batch_size)
    eval_dataset = mnist_test.map(scale).batch(batch_size)

    return train_dataset, eval_dataset

def get_model():
    """Build and compile a small convolutional classifier under ``mirrored_strategy``.

    Returns:
        A compiled ``tf.keras.Sequential`` model that takes 28x28x1 inputs and
        emits 10 raw logits (the loss is configured with ``from_logits=True``).
    """
    # Layers, model and optimizer are all created inside the strategy scope.
    with mirrored_strategy.scope():
        layer_stack = [
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10),
        ]
        model = tf.keras.Sequential(layer_stack)

        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'],
        )

    return model

In [4]:
# Prepare the input pipelines and the compiled model, then train for two epochs.
train_dataset, eval_dataset = get_data()
model = get_model()
model.fit(x=train_dataset, epochs=2)

Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f522db3a990>

## Save and load model

- High-level Keras `model.save` and `tf.keras.models.load_model`
- Low-level `tf.saved_model.save` and `tf.saved_model.load`

### The Keras APIs

In [5]:
# Export the trained model through the high-level Keras API.
keras_model_path = '/tmp/keras_save'
model.save(filepath=keras_model_path)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.


INFO:tensorflow:Assets written to: /tmp/keras_save/assets


INFO:tensorflow:Assets written to: /tmp/keras_save/assets


In [6]:
# Restore the Keras model without entering any strategy scope.
restored_keras_model = tf.keras.models.load_model(filepath=keras_model_path)

In [28]:
# Restore the same Keras export under a freshly created strategy and keep
# training it for two more epochs inside that strategy's scope.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    restored_keras_model_ds = tf.keras.models.load_model(filepath=keras_model_path)
    restored_keras_model_ds.fit(x=train_dataset, epochs=2)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


Epoch 1/2
Epoch 2/2


In [31]:
# Build the image-only dataset here instead of relying on a cell further down
# having been executed first; otherwise a fresh "Restart & Run All" fails with
# NameError on `predict_dataset`.
predict_dataset = eval_dataset.map(lambda image, label: image)
restored_keras_model_ds.predict(predict_dataset.take(2))

array([[ -8.380849  ,  -2.9549682 ,  19.51616   , ...,  -5.288287  ,
          2.152981  ,   1.2952065 ],
       [ 15.341793  , -17.153084  ,  -0.700944  , ...,  -9.770292  ,
          4.5389743 ,  -2.9369178 ],
       [-10.871754  ,  -2.4758656 ,  -6.406628  , ...,   3.7767417 ,
          0.84727967,   2.8655677 ],
       ...,
       [ -2.196619  , -11.504096  ,  -7.55236   , ...,  -9.367491  ,
         -0.25092593,   4.6905107 ],
       [ -6.5191507 ,  -1.3269238 ,  -1.8378897 , ...,  -3.230547  ,
         -3.5324705 ,   2.2022908 ],
       [ -7.56697   ,  -4.6730614 ,  -0.84442943, ...,  -4.1343274 ,
         -2.3198633 ,   5.8161902 ]], dtype=float32)

### The tf.saved_model APIs

In [10]:
# Export the same model through the low-level tf.saved_model API.
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(obj=model, export_dir=saved_model_path)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


In [13]:
# Name of the signature looked up on the loaded SavedModel.
DEFAULT_FUNCTION_KEY = 'serving_default'
# Misspelled original name, kept as an alias so cells that still reference it
# continue to run.
DEFAULT_FUNCTOIN_KEY = DEFAULT_FUNCTION_KEY

loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

In [19]:
# Drop the labels: the restored signature is called with images only.
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(2):
    outputs = inference_func(batch)
    # The signature returns a dict keyed by the output layer's auto-generated
    # name (e.g. 'dense_1'), which changes if the model is rebuilt in the same
    # session. Unpack the single entry instead of hard-coding that name; the
    # unpacking raises if the signature ever has more than one output.
    (logits,) = outputs.values()
    print(tf.argmax(logits, axis=-1))

tf.Tensor(
[2 0 4 8 7 6 0 6 3 1 8 0 7 9 8 4 5 3 4 0 6 6 3 0 2 3 6 6 7 4 9 3 8 2 5 4 2
 5 5 8 5 2 9 2 4 2 7 0 5 1 0 7 9 9 9 6 5 8 8 6 9 9 5 4], shape=(64,), dtype=int64)
tf.Tensor(
[2 6 8 1 0 6 9 5 5 4 1 6 7 5 2 9 0 6 4 4 2 8 7 8 3 0 9 0 1 1 9 4 5 9 7 6 6
 0 7 7 8 4 8 8 1 8 0 2 9 1 0 3 9 7 0 4 9 6 8 9 3 5 4 3], shape=(64,), dtype=int64)


In [25]:
# Run the restored signature under a distribution strategy.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    # Distribute the image-only dataset across the strategy's replicas.
    dist_predict_dataset = another_strategy.experimental_distribute_dataset(predict_dataset)

    # Load the SavedModel inside the scope and pick its serving signature.
    loaded = tf.saved_model.load(export_dir=saved_model_path)
    inference_func = loaded.signatures[DEFAULT_FUNCTOIN_KEY]

    # Each distributed batch is pushed through the signature; the per-replica
    # results are not collected here.
    for batch in dist_predict_dataset:
        another_strategy.run(inference_func, args=(batch,))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)






















Calling the restored function is just a forward pass on the saved model (predict). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub provides `hub.KerasLayer` for this purpose, as shown here:

In [27]:
import tensorflow_hub as hub

def build_model(loaded):
    """Wrap a loaded SavedModel object in a trainable Keras model.

    Args:
        loaded: Object handed to ``hub.KerasLayer``; in this notebook it is
            the result of ``tf.saved_model.load``.

    Returns:
        A ``tf.keras.Model`` mapping a 28x28x1 input named ``input_x`` to the
        output of the wrapped layer.
    """
    inputs = tf.keras.layers.Input(shape=[28, 28, 1], name='input_x')
    # trainable=True is what lets a later fit() continue training the
    # restored weights.
    outputs = hub.KerasLayer(loaded, trainable=True)(inputs)
    return tf.keras.Model(inputs, outputs)

# Reload the SavedModel under a new strategy, wrap it with build_model and
# continue training. NOTE: this rebinds the notebook-level `model` name.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    loaded = tf.saved_model.load(export_dir=saved_model_path)
    model = build_model(loaded)

    model.compile(
        optimizer='Adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    model.fit(x=train_dataset, epochs=2)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


Epoch 1/2
Epoch 2/2
