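### Distributed Training
Distributed training can reduce wall-clock training time by splitting each batch across several devices. This notebook trains the same small MNIST classifier twice — once with `tf.distribute.MirroredStrategy` and once without any distribution strategy — and compares how long each takes.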

In [8]:
# Imports
import time
import numpy as np
import tensorflow as tf

#### Preprocess Data

In [9]:
# Load the MNIST train/test splits, then unpack each into (images, labels).
train_split, test_split = tf.keras.datasets.mnist.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [10]:
# Normalise data: scale raw pixel intensities (0-255) into [0, 1].
# Cast to float32 before dividing. Dividing the integer pixel arrays directly
# would produce float64 copies (twice the memory), and Keras layers compute in
# float32 by default, so that extra precision would be thrown away anyway.
# NOTE: this cell is not idempotent — re-running it rescales the already-scaled
# arrays again. Re-run the load cell first if you need to execute it twice.
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

X_train.shape

(60000, 28, 28)

In [11]:
# Reshape data: flatten every 28x28 image into a single 784-element row so it
# matches the Dense input layer's input_shape=(784,).
n_pixels = 28 * 28
X_train = X_train.reshape((-1, n_pixels))
X_test = X_test.reshape((-1, n_pixels))

X_train.shape

(60000, 784)

#### Define a normal (non-distributed) fully connected network

In [12]:
# Baseline model: built and compiled outside any distribution strategy.
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units=128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(units=10, activation='softmax'),  # Output: one probability per digit class
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # sparse loss: labels are integer class ids
    metrics=['sparse_categorical_accuracy'],
)

model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 128)               100480    
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


#### Mirrored Strategy

In [13]:
# MirroredStrategy mirrors the model's variables across the devices it finds
# (the log printed below lists them). The model is built and compiled inside
# the strategy's scope so its variables are created as mirrored variables.
distribute = tf.distribute.MirroredStrategy()

with distribute.scope():
    model_distributed = tf.keras.models.Sequential([
        tf.keras.layers.Dense(units=128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(units=10, activation='softmax'),  # Output
    ])

    model_distributed.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'],
    )

    model_distributed.summary()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_6 (Dense)              (None, 128)               100480    
_________________________________________________________________
dropout_3 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_7 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
______________________________________________________________

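Note: with fewer than two GPUs, `MirroredStrategy` has only a single replica (the log above lists only `GPU:0`). It therefore adds overhead without any parallel speed-up, so expect it to train *slower* than the plain model. The timings below confirm this: ~52 s vs ~43 s.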

In [14]:
# Distributed Training
# Keras splits each global batch across the strategy's replicas. Scaling the
# batch by the replica count keeps 25 samples per device per step, matching the
# baseline run. With a single replica (as in the logged run) this is still 25.
BATCH_SIZE_PER_REPLICA = 25
global_batch_size = BATCH_SIZE_PER_REPLICA * distribute.num_replicas_in_sync

# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time. time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
model_distributed.fit(X_train, y_train, epochs=10, batch_size=global_batch_size)
end = time.perf_counter()

Train on 60000 samples
Epoch 1/10
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [15]:
# Elapsed seconds from the distributed-training cell (relies on its start/end).
distributed_seconds = end - start
print(f'Distributed training took: {distributed_seconds}s')

Distributed training took: 51.92981576919556s


In [16]:
# Normal (non-distributed) Training
# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time. time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
model.fit(X_train, y_train, epochs=10, batch_size=25)
end = time.perf_counter()

Train on 60000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [17]:
# Elapsed seconds from the normal-training cell (relies on its start/end).
normal_seconds = end - start
print(f'Normal (non-distributed) training took: {normal_seconds}s')

Normal (non-distributed) training took: 43.18038249015808s
