In [20]:
import tensorflow_datasets as tfds
import tensorflow as tf
# Disable tensorflow_datasets progress bars for the rest of the notebook.
tfds.disable_progress_bar()

import os

In [21]:
# Record the TensorFlow version this notebook was run with.
print(tf.__version__)

2.2.0


In [22]:
# Load MNIST in supervised form together with the dataset metadata object.
datasets, info = tfds.load(
    name="mnist",
    with_info=True,
    as_supervised=True,
)

# Pull the two splits out of the returned mapping.
mnist_train = datasets["train"]
mnist_test = datasets["test"]

In [23]:
# Create the distribution strategy; its scope and replica count are used by
# the model-building and batch-size cells further down.
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [24]:
# Report how many replicas the strategy keeps in sync.
print(f"Number of devices: {strategy.num_replicas_in_sync}")

Number of devices: 1


In [25]:
# Example counts for each split, read from the dataset metadata.
num_train_examples, num_test_examples = (
    info.splits[split_name].num_examples for split_name in ("train", "test")
)

# Buffer size passed to shuffle() in the training input pipeline.
BUFFER_SIZE = 10000

# Global batch size = per-replica batch size times the number of replicas.
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

In [26]:
def scale(image, label):
  """Cast `image` to float32 and divide it by 255; pass `label` through.

  Returns:
    The `(scaled_image, label)` pair.
  """
  scaled = tf.cast(image, tf.float32) / 255
  return scaled, label

In [27]:
# Training pipeline: scale, cache, shuffle, then batch.
train_dataset = (
    mnist_train
    .map(scale)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Evaluation pipeline: scale and batch only.
eval_dataset = (
    mnist_test
    .map(scale)
    .batch(BATCH_SIZE)
)

In [28]:
# Build and compile the model inside the strategy scope.
with strategy.scope():
  model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(
      filters=32,
      kernel_size=3,
      activation="relu",
      input_shape=(28, 28, 1),
    ),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=64, activation="relu"),
    # Final layer has no activation; the loss below is built with
    # from_logits=True.
    tf.keras.layers.Dense(units=10),
  ])

  model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
  )

In [29]:
# Directory that holds the training checkpoints.
checkpoint_dir = "./training_checkpoints"

# Checkpoint path template handed to ModelCheckpoint; the literal "{epoch}"
# is kept as a placeholder in the file name.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

In [30]:
def decay(epoch):
  """Return the learning rate to use for the given epoch index.

  Epochs below 3 get 1e-3, epochs from 3 up to (but not including) 7 get
  1e-4, and every later epoch gets 1e-5.
  """
  # (exclusive upper epoch bound, learning rate), checked in ascending order.
  schedule = ((3, 1e-3), (7, 1e-4))
  for upper_bound, rate in schedule:
    if epoch < upper_bound:
      return rate
  return 1e-5

In [31]:
class PrintLR(tf.keras.callbacks.Callback):
  """Keras callback that prints the optimizer's learning rate after each epoch."""

  def on_epoch_end(self, epoch, logs=None):
    """Print the learning rate of the model this callback is attached to.

    Args:
      epoch: Epoch index passed in by Keras; shown as `epoch + 1`.
      logs: Metrics passed in by Keras; not used.
    """
    # Use `self.model` (set by Keras on the callback) instead of the
    # notebook-level `model` global, so the callback reports the optimizer of
    # whichever model it is actually passed to.
    print(f"\nLearning rate for epoch {epoch + 1} is {self.model.optimizer.lr.numpy()}")

In [32]:
callbacks = [
  # Write TensorBoard logs under ./logs.
  tf.keras.callbacks.TensorBoard(log_dir="./logs"),
  # Save weights only, using the checkpoint path template.
  tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
  ),
  # Take the learning rate for each epoch from decay().
  tf.keras.callbacks.LearningRateScheduler(schedule=decay),
  # Print the learning rate at the end of every epoch.
  PrintLR(),
]

In [33]:
# Train for 12 epochs on the training pipeline with the callbacks list.
model.fit(train_dataset, epochs=12, callbacks=callbacks)

Epoch 1/12
Learning rate for epoch 1 is 0.0010000000474974513
Epoch 2/12
Learning rate for epoch 2 is 0.0010000000474974513
Epoch 3/12
Learning rate for epoch 3 is 0.0010000000474974513
Epoch 4/12
Learning rate for epoch 4 is 9.999999747378752e-05
Epoch 5/12
Learning rate for epoch 5 is 9.999999747378752e-05
Epoch 6/12
Learning rate for epoch 6 is 9.999999747378752e-05
Epoch 7/12
Learning rate for epoch 7 is 9.999999747378752e-05
Epoch 8/12
Learning rate for epoch 8 is 9.999999747378752e-06
Epoch 9/12
Learning rate for epoch 9 is 9.999999747378752e-06
Epoch 10/12
Learning rate for epoch 10 is 9.999999747378752e-06
Epoch 11/12
Learning rate for epoch 11 is 9.999999747378752e-06
Epoch 12/12
Learning rate for epoch 12 is 9.999999747378752e-06


<tensorflow.python.keras.callbacks.History at 0x7f78b7a7dcc0>

In [34]:
# List the contents of the checkpoint directory.
!ls {checkpoint_dir}

checkpoint		     ckpt_4.data-00000-of-00002
ckpt_10.data-00000-of-00002  ckpt_4.data-00001-of-00002
ckpt_10.data-00001-of-00002  ckpt_4.index
ckpt_10.index		     ckpt_5.data-00000-of-00002
ckpt_11.data-00000-of-00002  ckpt_5.data-00001-of-00002
ckpt_11.data-00001-of-00002  ckpt_5.index
ckpt_11.index		     ckpt_6.data-00000-of-00002
ckpt_12.data-00000-of-00002  ckpt_6.data-00001-of-00002
ckpt_12.data-00001-of-00002  ckpt_6.index
ckpt_12.index		     ckpt_7.data-00000-of-00002
ckpt_1.data-00000-of-00002   ckpt_7.data-00001-of-00002
ckpt_1.data-00001-of-00002   ckpt_7.index
ckpt_1.index		     ckpt_8.data-00000-of-00002
ckpt_2.data-00000-of-00002   ckpt_8.data-00001-of-00002
ckpt_2.data-00001-of-00002   ckpt_8.index
ckpt_2.index		     ckpt_9.data-00000-of-00002
ckpt_3.data-00000-of-00002   ckpt_9.data-00001-of-00002
ckpt_3.data-00001-of-00002   ckpt_9.index
ckpt_3.index


In [35]:
# Restore the weights from the most recent checkpoint in checkpoint_dir.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
model.load_weights(latest_checkpoint)

# Evaluate on the evaluation pipeline; the result unpacks to (loss, accuracy).
eval_loss, eval_acc = model.evaluate(eval_dataset)

print(f"Eval loss: {eval_loss}, Eval Accuracy: {eval_acc}")

Eval loss: 0.037436749786138535, Eval Accuracy: 0.9861999750137329


In [36]:
!tensorboard --logdir=path/to/log-directory

2020-07-09 14:15:38.693950: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.2.2 at http://localhost:6006/ (Press CTRL+C to quit)
^C


In [37]:
ls -sh ./logs

total 4.0K
4.0K [0m[01;34mtrain[0m/


In [38]:
# Directory the model is saved to and later reloaded from.
path = 'saved_model/'

In [39]:
# Save the trained model to `path` using the 'tf' save format.
model.save(path, save_format='tf')

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.


INFO:tensorflow:Assets written to: saved_model/assets


INFO:tensorflow:Assets written to: saved_model/assets


In [40]:
# Reload the saved model outside of any strategy scope.
unreplicated_model = tf.keras.models.load_model(path)

# Compile with the same loss, optimizer type and metric used for training.
unreplicated_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Evaluate on the evaluation pipeline; the result unpacks to (loss, accuracy).
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print(f"Eval loss: {eval_loss}, Eval Accuracy: {eval_acc}")

Eval loss: 0.037436749786138535, Eval Accuracy: 0.9861999750137329


In [41]:
# Reload, compile and evaluate the saved model inside the strategy scope.
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)

  # Same loss, optimizer type and metric as the training-time compile call.
  replicated_model.compile(
      optimizer=tf.keras.optimizers.Adam(),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=["accuracy"],
  )

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print(f"Eval loss: {eval_loss}, Eval Accuracy: {eval_acc}")

Eval loss: 0.037436749786138535, Eval Accuracy: 0.9861999750137329
