# Distributed training of a CNN on MNIST with Horovod

In [1]:
import ipcmagic

In [2]:
# Launch 2 IPython engines via ipcmagic; `--mpi` starts them as MPI ranks so Horovod
# can communicate between them. Cells marked `%%px` below run on every engine.
%ipcluster start -n 2 --mpi

IPCluster is ready! (6 seconds)


In [3]:
%%px
import numpy as np
import tensorflow as tf
import horovod.tensorflow.keras as hvd

In [4]:
%%px
hvd.init()

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Reshape(target_shape=(28, 28, 1)), # Convolutional layers expect a channel dimension
    tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Conv2D(128, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9)
optimizer = hvd.DistributedOptimizer(optimizer)

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    metrics=['accuracy'])

hvd_callback = hvd.callbacks.BroadcastGlobalVariablesCallback(0)

In [5]:
%%px
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

train_size = len(y_train)
valid_size = len(y_test)

# The `x` arrays are in uint8 and have values in the range [0, 255].
# We need to convert them to float32 with values in the range [0, 1]
train_dataset = (tf.data.Dataset
                   .from_tensor_slices((x_train / np.float32(255), y_train.astype(np.int32)))
                   .shuffle(train_size // hvd.size())
                   .batch(128, drop_remainder=True)
                   .shard(hvd.size(), hvd.rank())
                )

valid_dataset = (tf.data.Dataset
                   .from_tensor_slices((x_test / np.float32(255), y_test.astype(np.int32)))
                   .batch(128, drop_remainder=False))

In [6]:
%%px
fit = model.fit(train_dataset,
                epochs=5,
                validation_data=valid_dataset,
                callbacks=[hvd_callback])

[stdout:0] 
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
[stdout:1] 
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [7]:
# Shut down the engines started earlier with `%ipcluster start`.
%ipcluster stop