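This example uses TensorFlow 2.0 to show how `tf.distribute.MirroredStrategy` splits each batch of a dataset across two replicas on a single machine.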

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Log the device each op is placed on; this is what produces the
# "Executing op ... in device ..." lines in the outputs below.
tf.debugging.set_log_device_placement(enabled=True)

Get the list of physical CPU devices visible to the TensorFlow runtime.

In [3]:
# Physical CPU devices visible to the TensorFlow runtime.
cpus = tf.config.experimental.list_physical_devices(device_type='CPU')

In [4]:
# Display the physical CPU devices found above.
cpus

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]

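Split the physical CPU into two virtual (logical) devices so that `MirroredStrategy` has two replicas to mirror across on a single machine.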

In [5]:
# Split the single physical CPU into two logical (virtual) CPU devices so the
# MirroredStrategy below has two replicas to mirror across on one machine.
try:
    tf.config.experimental.set_virtual_device_configuration(
        cpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(),
         tf.config.experimental.VirtualDeviceConfiguration()])
except RuntimeError as e:
    # Presumably raised when TensorFlow's devices are already initialized
    # (e.g. this cell is re-run); the existing device layout is kept.
    print(e)

# List the logical CPUs outside the `try` so `logical_cpus` is always defined,
# even when the virtual-device configuration above was rejected.
logical_cpus = tf.config.experimental.list_logical_devices('CPU')

In [6]:
# Display the logical CPU devices; after a successful split there are two.
logical_cpus

[LogicalDevice(name='/job:localhost/replica:0/task:0/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/job:localhost/replica:0/task:0/device:CPU:1', device_type='CPU')]

In [7]:
# Global batch size: each global batch is split across the replicas
# (4 elements per replica for a full batch with two replicas).
GLOBAL_BATCH_SIZE = 8
# 14 examples give one full batch of 8 and a final partial batch of 6,
# which shows how the strategy splits an uneven last batch.
NUM_EXAMPLES = 14
x = np.arange(1, NUM_EXAMPLES + 1)
x

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])

In [8]:
# Two replicas, one per virtual CPU configured above.
mirrored_strategy = tf.distribute.MirroredStrategy(devices=["/cpu:0", "/cpu:1"])
# Batch globally first; experimental_distribute_dataset then splits every
# global batch of GLOBAL_BATCH_SIZE elements across the replicas.
dataset = (tf.data.Dataset
           .from_tensor_slices(x)
           .batch(GLOBAL_BATCH_SIZE))
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

W1121 23:25:50.034326 17472 cross_device_ops.py:1205] Some requested devices in `tf.distribute.Strategy` are not visible to TensorFlow: /job:localhost/replica:0/task:0/device:CPU:0,/job:localhost/replica:0/task:0/device:CPU:1


Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RebatchDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AutoShardDataset in device /job:localhost/replica:0/task:0/device:CPU:0


In [9]:
# Walk the distributed dataset; each element is a per-replica value holding
# one replica's share of the global batch.
with mirrored_strategy.scope():
    for dist_batch in dist_dataset:
        print(dist_batch)

Executing op OptimizeDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ModelDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op MultiDeviceIterator in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op MultiDeviceIteratorInit in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op MultiDeviceIteratorToStringHandle in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:CPU:1
Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:1
Executing op AnonymousIteratorV2 in device /jo

In [10]:
def fn(inputs):
    '''
    Identity step executed once on each replica.

    inputs: the component of a per-replica value that belongs to the replica
        currently executing (e.g. that replica's slice of a distributed-dataset
        batch), not the whole PerReplica object.

    Returns the input unchanged; experimental_run_v2 collects the per-replica
    return values into a PerReplica result.
    '''
    return inputs

with mirrored_strategy.scope():
    for d in dist_dataset:
        # `d` holds one slice per replica; experimental_run_v2 calls `fn` on each
        # replica with its own slice. Despite its name, `per_example_inputs` is a
        # per-replica result, and after the loop it holds only the last batch.
        per_example_inputs = mirrored_strategy.experimental_run_v2(fn, args=(d,))
        print(per_example_inputs)

Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:CPU:1


W1121 23:25:50.406548 17472 mirrored_strategy.py:659] Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
W1121 23:25:50.414528 17472 mirrored_strategy.py:659] Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.


PerReplica:{
  0 /job:localhost/replica:0/task:0/device:CPU:0: tf.Tensor([1 2 3 4], shape=(4,), dtype=int32),
  1 /job:localhost/replica:0/task:0/device:CPU:1: tf.Tensor([5 6 7 8], shape=(4,), dtype=int32)
}
PerReplica:{
  0 /job:localhost/replica:0/task:0/device:CPU:0: tf.Tensor([ 9 10 11 12], shape=(4,), dtype=int32),
  1 /job:localhost/replica:0/task:0/device:CPU:1: tf.Tensor([13 14], shape=(2,), dtype=int32)
}


In [11]:
# Tuple of the per-replica tensors from the last distributed batch, one per
# replica in replica order. `experimental_local_results` is the public API for
# unpacking a distributed value, instead of reaching into `.values`.
mirrored_strategy.experimental_local_results(per_example_inputs)

(<tf.Tensor: id=255, shape=(4,), dtype=int32, numpy=array([ 9, 10, 11, 12])>,
 <tf.Tensor: id=257, shape=(2,), dtype=int32, numpy=array([13, 14])>)

In [12]:
# Reduce the last per-replica batch across replicas. axis=0 also collapses the
# batch dimension within each replica, so each result is a scalar.
mean = mirrored_strategy.reduce(reduce_op=tf.distribute.ReduceOp.MEAN,
                                value=per_example_inputs,
                                axis=0)
print('Aggregate mode MEAN: ', mean)
add = mirrored_strategy.reduce(reduce_op=tf.distribute.ReduceOp.SUM,
                               value=per_example_inputs,
                               axis=0)
print('Aggregate mode SUM: ', add)

W1121 23:25:50.633474 17472 mirrored_strategy.py:659] Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.


Executing op Sum in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Sum in device /job:localhost/replica:0/task:0/device:CPU:1
Executing op AddN in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Cast in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RealDiv in device /job:localhost/replica:0/task:0/device:CPU:0


W1121 23:25:50.655410 17472 mirrored_strategy.py:659] Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.


Aggregate mode MEAN:  tf.Tensor(11.5, shape=(), dtype=float64)
Aggregate mode SUM:  tf.Tensor(69, shape=(), dtype=int32)
