In [1]:
import os
import json
import tensorflow as tf
import tensorflow_datasets as tfds

The input data is sharded by worker (input pipeline) index so that each worker processes its own distinct `1/num_workers` portion of the dataset.

In [6]:
BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
    """Build the MNIST ``tf.data`` pipeline consumed by the Estimator.

    The 'train' split is used in TRAIN mode and the 'test' split in every
    other mode. When the distribution strategy passes an ``input_context``,
    the split is sharded so each input pipeline reads only its own slice.

    Returns:
        A dataset of ``(image, label)`` batches of ``BATCH_SIZE`` elements.
        Images are cast to float32 and divided by 255. Elements are cached
        after scaling and shuffled with a ``BUFFER_SIZE`` buffer.
    """
    splits, _info = tfds.load(name='mnist', with_info=True, as_supervised=True)
    split_name = 'train' if mode == tf.estimator.ModeKeys.TRAIN else 'test'
    dataset = splits[split_name]

    if input_context:
        # One shard per input pipeline; this pipeline keeps only its own shard.
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)

    def _to_unit_float(image, label):
        return tf.cast(image, tf.float32) / 255.0, label

    pipeline = dataset.map(_to_unit_float)
    pipeline = pipeline.cache()
    pipeline = pipeline.shuffle(BUFFER_SIZE)
    return pipeline.batch(BATCH_SIZE)

## Multi-worker configuration

In [3]:
# TF_CONFIG tells TensorFlow about the cluster: every worker's address and
# which task *this* process is. Every worker uses the same 'cluster' entry
# but its own 'task.index' (0 here, i.e. the first address).
# NOTE(review): TensorFlow presumably reads this when the distribution
# strategy is constructed, so run this cell before creating `strategy`.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        # Addresses must be "host:port" with no embedded whitespace.
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {
        'type': 'worker',
        'index': 0
    }
})

## Define the model

In [3]:
LEARNING_RATE = 1e-4


def model_fn(features, labels, mode):
    """Estimator ``model_fn`` for a small convolutional MNIST classifier.

    Args:
        features: Batch of images from ``input_fn``. Presumably float32 with
            shape ``(batch, 28, 28, 1)``, matching the first layer's
            ``input_shape`` (TODO confirm).
        labels: Integer class ids for the batch; ``None`` in PREDICT mode.
        mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
        A ``tf.estimator.EstimatorSpec`` that depends on the mode:
        - PREDICT: raw logits under the ``'logits'`` key.
        - EVAL: the scaled cross-entropy loss.
        - TRAIN: the loss plus a plain SGD ``train_op`` that increments the
          global step.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=[28, 28, 1]),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    # Run layers in training mode only while training. This stack has no
    # dropout/batch-norm today, but the flag must be right once such layers are added.
    logits = model(features, training=(mode == tf.estimator.ModeKeys.TRAIN))

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'logits': logits}
        # EstimatorSpec requires `mode` and accepts no `labels` argument
        # (labels are None when predicting anyway).
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    per_example_loss = loss_object(labels, logits)
    # Sum, then divide by the fixed per-worker BATCH_SIZE instead of averaging.
    # NOTE(review): presumably done so losses combine correctly across replicas
    # under the distribution strategy; confirm whether the global batch size
    # should be used. Also note that a final partial batch is still divided by
    # BATCH_SIZE, so its loss reads low (the logged final-step loss of ~1.10
    # vs ~2.2 is consistent with a half-size last batch). Consider
    # drop_remainder=True in input_fn.
    loss = tf.reduce_sum(per_example_loss) * (1.0 / BATCH_SIZE)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.compat.v1.train.get_or_create_global_step())
    )

## MultiWorkerMirroredStrategy

In [4]:
# Multi-worker mirrored strategy; the RunConfig log below reports it as a
# CollectiveAllReduceStrategy on the local GPU.
# NOTE(review): the worker set presumably comes from TF_CONFIG. The recorded
# output reports a single worker and an empty ClusterSpec, so confirm that
# TF_CONFIG was set before this cell ran when using a real cluster.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO


## Train and evaluate the model

In [7]:
# Checkpoints are written to a fixed model_dir, so re-running this cell
# resumes from the newest checkpoint there. In the log below, training
# restores model.ckpt-938 and continues to step 1876. Clear the directory
# for a fresh run.
config = tf.estimator.RunConfig(train_distribute=strategy)
classifier = tf.estimator.Estimator(model_fn=model_fn, model_dir='/tmp/multiworker', config=config)

# Training reads the 'train' split via input_fn; evaluation reads the 'test'
# split (100 eval steps, per the log).
train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn)

# Left as the cell's last expression so the (metrics, export results) tuple is displayed.
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)

INFO:tensorflow:Initializing RunConfig with distribution strategies.


INFO:tensorflow:Initializing RunConfig with distribution strategies.


INFO:tensorflow:Not using Distribute Coordinator.


INFO:tensorflow:Not using Distribute Coordinator.


INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x7f3c5b85be50>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}


INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x7f3c5b85be50>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}


INFO:tensorflow:Not using Distribute Coordinator.


INFO:tensorflow:Not using Distribute Coordinator.


INFO:tensorflow:Running training and evaluation locally (non-distributed).


INFO:tensorflow:Running training and evaluation locally (non-distributed).


INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.


INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.


INFO:tensorflow:The `input_fn` accepts an `input_context` which will be given by DistributionStrategy


INFO:tensorflow:The `input_fn` accepts an `input_context` which will be given by DistributionStrategy


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.



Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.



Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-938


INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-938


Instructions for updating:
Use standard file utilities to get mtimes.


Instructions for updating:
Use standard file utilities to get mtimes.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 938...


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 938...


INFO:tensorflow:Saving checkpoints for 938 into /tmp/multiworker/model.ckpt.


INFO:tensorflow:Saving checkpoints for 938 into /tmp/multiworker/model.ckpt.


INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 938...


INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 938...


INFO:tensorflow:loss = 2.2821038, step = 938


INFO:tensorflow:loss = 2.2821038, step = 938


INFO:tensorflow:global_step/sec: 195.464


INFO:tensorflow:global_step/sec: 195.464


INFO:tensorflow:loss = 2.259119, step = 1038 (0.517 sec)


INFO:tensorflow:loss = 2.259119, step = 1038 (0.517 sec)


INFO:tensorflow:global_step/sec: 204.638


INFO:tensorflow:global_step/sec: 204.638


INFO:tensorflow:loss = 2.2728596, step = 1138 (0.487 sec)


INFO:tensorflow:loss = 2.2728596, step = 1138 (0.487 sec)


INFO:tensorflow:global_step/sec: 203.954


INFO:tensorflow:global_step/sec: 203.954


INFO:tensorflow:loss = 2.256188, step = 1238 (0.487 sec)


INFO:tensorflow:loss = 2.256188, step = 1238 (0.487 sec)


INFO:tensorflow:global_step/sec: 193.051


INFO:tensorflow:global_step/sec: 193.051


INFO:tensorflow:loss = 2.2298408, step = 1338 (0.520 sec)


INFO:tensorflow:loss = 2.2298408, step = 1338 (0.520 sec)


INFO:tensorflow:global_step/sec: 200.895


INFO:tensorflow:global_step/sec: 200.895


INFO:tensorflow:loss = 2.2400339, step = 1438 (0.498 sec)


INFO:tensorflow:loss = 2.2400339, step = 1438 (0.498 sec)


INFO:tensorflow:global_step/sec: 183.182


INFO:tensorflow:global_step/sec: 183.182


INFO:tensorflow:loss = 2.2527704, step = 1538 (0.544 sec)


INFO:tensorflow:loss = 2.2527704, step = 1538 (0.544 sec)


INFO:tensorflow:global_step/sec: 213.745


INFO:tensorflow:global_step/sec: 213.745


INFO:tensorflow:loss = 2.2769291, step = 1638 (0.471 sec)


INFO:tensorflow:loss = 2.2769291, step = 1638 (0.471 sec)


INFO:tensorflow:global_step/sec: 247.864


INFO:tensorflow:global_step/sec: 247.864


INFO:tensorflow:loss = 2.2505112, step = 1738 (0.400 sec)


INFO:tensorflow:loss = 2.2505112, step = 1738 (0.400 sec)


INFO:tensorflow:global_step/sec: 551.882


INFO:tensorflow:global_step/sec: 551.882


INFO:tensorflow:loss = 2.223503, step = 1838 (0.183 sec)


INFO:tensorflow:loss = 2.223503, step = 1838 (0.183 sec)


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1876...


INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1876...


INFO:tensorflow:Saving checkpoints for 1876 into /tmp/multiworker/model.ckpt.


INFO:tensorflow:Saving checkpoints for 1876 into /tmp/multiworker/model.ckpt.


INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1876...


INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1876...


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Starting evaluation at 2020-05-22T02:30:57Z


INFO:tensorflow:Starting evaluation at 2020-05-22T02:30:57Z


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-1876


INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-1876


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Evaluation [10/100]


INFO:tensorflow:Evaluation [10/100]


INFO:tensorflow:Evaluation [20/100]


INFO:tensorflow:Evaluation [20/100]


INFO:tensorflow:Evaluation [30/100]


INFO:tensorflow:Evaluation [30/100]


INFO:tensorflow:Evaluation [40/100]


INFO:tensorflow:Evaluation [40/100]


INFO:tensorflow:Evaluation [50/100]


INFO:tensorflow:Evaluation [50/100]


INFO:tensorflow:Evaluation [60/100]


INFO:tensorflow:Evaluation [60/100]


INFO:tensorflow:Evaluation [70/100]


INFO:tensorflow:Evaluation [70/100]


INFO:tensorflow:Evaluation [80/100]


INFO:tensorflow:Evaluation [80/100]


INFO:tensorflow:Evaluation [90/100]


INFO:tensorflow:Evaluation [90/100]


INFO:tensorflow:Evaluation [100/100]


INFO:tensorflow:Evaluation [100/100]


INFO:tensorflow:Inference Time : 1.11871s


INFO:tensorflow:Inference Time : 1.11871s


INFO:tensorflow:Finished evaluation at 2020-05-22-02:30:58


INFO:tensorflow:Finished evaluation at 2020-05-22-02:30:58


INFO:tensorflow:Saving dict for global step 1876: global_step = 1876, loss = 2.2254546


INFO:tensorflow:Saving dict for global step 1876: global_step = 1876, loss = 2.2254546


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1876: /tmp/multiworker/model.ckpt-1876


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1876: /tmp/multiworker/model.ckpt-1876


INFO:tensorflow:Loss for final step: 1.1002737.


INFO:tensorflow:Loss for final step: 1.1002737.


({'loss': 2.2254546, 'global_step': 1876}, [])