In [1]:
# Setup 

In [2]:
import tensorflow as tf
from tensorflow import keras

## Single-host, multi-device synchronous training

#### Get compiled Model

In [3]:
import math
math.sqrt(784)

28.0

In [4]:
def get_compiled_model():
    # make a simple 2-layer densely-connected neural network
    inputs = keras.Input(shape=(784,)) # INput_shape 28 X 28 size image
    x = keras.layers.Dense(256, activation = "relu")(inputs)
    x = keras.layers.Dense(256, activation = "relu")(x)
    outputs = keras.layers.Dense(10)(x)  # 10 Output 0 to 10 
    model = keras.Model(inputs, outputs)
    model.compile(
    optimizer = keras.optimizers.Adam(),
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True), # For Multiple Classification 
    metrics=[keras.metrics.SparseCategoricalAccuracy()],)
    return model
    

In [9]:
def get_dataset():
    batch_size = 32
    num_val_samples = 10000
    
    # Return the MNIST dataset in the form of a " tf.data.Dataset"
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() # LOad Mnist Data set
    
    # Check data size
    print(x_train.shape, x_test.shape, y_train.shape)
    print(type(y_train))
    
    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32")/255. # Change 28 X 28 Data sets to input size, Flatten 
    x_test = x_test.reshape(-1, 784).astype("float32") / 255.
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")
    
    # Reserve num_val_samples samples for validataion 
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    
    return (
    tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )

In [10]:
get_dataset()

(60000, 28, 28) (10000, 28, 28) (60000,)
<class 'numpy.ndarray'>


(<BatchDataset shapes: ((None, 784), (None,)), types: (tf.float32, tf.float32)>,
 <BatchDataset shapes: ((None, 784), (None,)), types: (tf.float32, tf.float32)>,
 <BatchDataset shapes: ((None, 784), (None,)), types: (tf.float32, tf.float32)>)

In [7]:
## Create a Mirrored Strategy
strategy = tf.distribute.MirroredStrategy()
print("Number of devices : {}".format(strategy.num_replicas_in_sync))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Number of devices : 1


In [8]:
# Open a strategy scope
with strategy.scope():
    # Everything that creates variable should be under the strategy scope
    # in general this is only model construction & compille
    
    
    model = get_compiled_model()
    
    
# Train the model on all available devices
train_dataset, val_dataset , test_dataset = get_dataset()
model.fit(train_dataset,epochs = 5, validation_data = val_dataset)

# Test the model on all available devices.
model.evaluate(test_dataset)

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Redu

[0.08927737921476364, 0.9758999943733215]