In [1]:
!pip install -q tf-nightly
import tensorflow as tf
import numpy as np

[K     |████████████████████████████████| 322.7MB 52kB/s 
[K     |████████████████████████████████| 460kB 45.5MB/s 
[K     |████████████████████████████████| 6.8MB 48.8MB/s 
[?25h

In [2]:
# Prefer the stable `tf.distribute.MultiWorkerMirroredStrategy`; the
# `experimental` alias is deprecated. Fall back to it on builds without the
# stable name.
# In this notebook TF_CONFIG is only written in a later cell, so this strategy
# is created in single-worker mode (as the INFO log below reports).
_mwms_cls = getattr(tf.distribute, "MultiWorkerMirroredStrategy", None)
if _mwms_cls is None:
  _mwms_cls = tf.distribute.experimental.MultiWorkerMirroredStrategy
strategy = _mwms_cls()

INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO


In [3]:
def mnist_dataset(batch_size, shuffle_buffer_size=60000):
  """Build an infinite, shuffled, batched MNIST training pipeline.

  Args:
    batch_size: number of examples per batch yielded by the dataset.
    shuffle_buffer_size: size of the `shuffle` buffer. The default of 60000
      matches the size of the MNIST training split.

  Returns:
    A `tf.data.Dataset` of `(images, labels)` batches. Images are float32
    scaled to [0, 1]. Labels are int64 class ids. The dataset repeats
    indefinitely, so callers must pass `steps_per_epoch` to `fit`.
  """
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

  # Pixels arrive as uint8 in [0, 255]; scale them to float32 in [0, 1].
  x_train = x_train / np.float32(255)
  # Labels are class indices consumed by SparseCategoricalCrossentropy and must
  # stay integers. Scaling them like pixels would corrupt the targets.
  y_train = y_train.astype(np.int64)
  train_dataset = (
      tf.data.Dataset.from_tensor_slices((x_train, y_train))
      .shuffle(shuffle_buffer_size)
      .repeat()
      .batch(batch_size)
  )
  return train_dataset

In [4]:
def build_and_compile_cnn_model():
  """Create and compile a small CNN classifier for 28x28 single-channel inputs.

  The final Dense layer emits raw logits for 10 classes. The loss is therefore
  built with `from_logits=True`. Later cells call this inside
  `strategy.scope()` when training with the distribution strategy.

  Returns:
    A compiled `tf.keras.Sequential` model (SGD with learning rate 0.001,
    "accuracy" metric).
  """
  layers = tf.keras.layers
  cnn_layers = [
      tf.keras.Input(shape=(28, 28)),
      layers.Reshape(target_shape=(28, 28, 1)),  # add a channel axis for Conv2D
      layers.Conv2D(32, 3, activation="relu"),
      layers.Flatten(),
      layers.Dense(128, activation="relu"),
      layers.Dense(10),  # logits: no activation
  ]
  model = tf.keras.Sequential(cnn_layers)

  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  sgd = tf.keras.optimizers.SGD(learning_rate=0.001)
  model.compile(loss=loss_fn, optimizer=sgd, metrics=["accuracy"])
  return model

In [5]:
# Baseline run on a single worker before distributing.
# `per_worker_batch_size` is reused later to size the global batch.
per_worker_batch_size = 64
single_worker_dataset = mnist_dataset(batch_size=per_worker_batch_size)
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(
    x=single_worker_dataset,
    epochs=3,
    steps_per_epoch=70,  # the dataset repeats forever, so an epoch length is required
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fbe7bd801d0>

In [6]:
import json
import os

# Describe a two-worker cluster on localhost, with this process as worker 0.
# NOTE(review): `strategy` was created in an earlier cell, before TF_CONFIG
# existed. Its output reports single-worker mode, which this cell presumably
# does not change.
tf_config = {
    'cluster': {'worker': ["localhost:12345", "localhost:23456"]},
    'task': {'type': 'worker', 'index': 0},
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)

In [7]:
# Derive the worker count from the TF_CONFIG cluster spec instead of
# hard-coding it, so the global batch matches the cluster actually described.
# Falls back to 1 when TF_CONFIG is unset or lists no workers.
_tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
num_workers = len(_tf_config.get('cluster', {}).get('worker', [])) or 1

# Each worker contributes `per_worker_batch_size` examples per step.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

# Build the model under the strategy's scope so its variables are created by it.
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()

multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fbdc7ff6ac8>

In [8]:
# Disable tf.data automatic sharding for a copy of the multi-worker pipeline.
# NOTE(review): `dataset_no_auto_shard` is not consumed by the cells visible here.
options = tf.data.Options()
shard_policy = tf.data.experimental.AutoShardPolicy.OFF
options.experimental_distribute.auto_shard_policy = shard_policy
dataset_no_auto_shard = multi_worker_dataset.with_options(options)

In [9]:
def _tf_config_task():
  """Return `(task_type, task_index, cluster)` parsed from the TF_CONFIG env var.

  Missing fields (or an unset TF_CONFIG) yield `None` for the type/index and
  `{}` for the cluster.
  """
  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
  task = tf_config.get('task', {})
  return task.get('type'), task.get('index'), tf_config.get('cluster', {})


def is_chief():
  """Return True if this process should write the canonical checkpoint.

  The chief is an explicit 'chief' task or, when the cluster has no 'chief'
  entry, worker 0. A process with no task info in TF_CONFIG is treated as the
  chief (single-process run).
  """
  task_type, task_index, cluster = _tf_config_task()
  if task_type is None or task_type == 'chief':
    return True
  return task_type == 'worker' and task_index == 0 and 'chief' not in cluster


model_dir_path = "/tmp/model_dir"
if is_chief():
  path = model_dir_path
else:
  # Non-chief workers save into a per-worker scratch directory so they never
  # overwrite the chief's checkpoint.
  _, worker_id, _ = _tf_config_task()
  path = f"{model_dir_path}/worker_tmp_{worker_id}"

checkpoint = tf.train.Checkpoint(model=multi_worker_model)
manager = tf.train.CheckpointManager(checkpoint, directory=path, max_to_keep=5)
manager.save()

# Restore this process's own most recent checkpoint through the manager.
status = checkpoint.restore(manager.latest_checkpoint)

if not is_chief():
  # The scratch copy is no longer needed once it has been restored from.
  tf.io.gfile.rmtree(path)

# Restore from the chief's directory, where the canonical checkpoint is written.
latest_checkpoint = tf.train.latest_checkpoint(model_dir_path)
status = checkpoint.restore(latest_checkpoint)

In [10]:
# BackupAndRestore is `tf.keras.callbacks.BackupAndRestore` in newer TF releases
# and `tf.keras.callbacks.experimental.BackupAndRestore` in older ones. Resolve
# whichever exists, because the installed TF build is not pinned.
_backup_and_restore_cls = getattr(tf.keras.callbacks, "BackupAndRestore", None)
if _backup_and_restore_cls is None:
  _backup_and_restore_cls = tf.keras.callbacks.experimental.BackupAndRestore
# Training state is backed up under /tmp/backup so an interrupted fit can resume.
callbacks = [_backup_and_restore_cls(backup_dir="/tmp/backup")]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset,
                       epochs=3,
                       steps_per_epoch=70,
                       callbacks=callbacks)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fbdc6642358>