In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2 --mpi

IPCluster is ready! (5 seconds)


In [3]:
%%px
import os
import tensorflow as tf
import horovod.tensorflow.keras as hvd
from datetime import datetime
from tensorflow import keras

In [4]:
%%px
# Initialize Horovod on every engine (this cell runs under %%px).
hvd.init()

# Number of samples per batch; default `batch_size` of the dataset builders.
BATCH_SIZE = 64
# Number of training epochs passed to `model.fit`. The spelling (sic) is
# kept as-is because other cells reference this name.
NUM_EPOCS = 4

def decode(serialized_example):
    """Turn one serialized TFRecord example into an `(image, label)` pair.

    The record is expected to carry two scalar features: `image_raw`
    (a byte string) and `label` (an int64).

    Returns:
        image: uint8 tensor reshaped to (28, 28, 1).
        label: the label cast to int32.
    """
    feature_spec = {
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example,
                                        features=feature_spec)

    # The raw bytes are interpreted as uint8 pixels of a 28x28x1 image.
    pixels = tf.io.decode_raw(parsed['image_raw'], tf.uint8)
    image = tf.reshape(pixels, (28, 28, 1))
    label = tf.cast(parsed['label'], tf.int32)
    return image, label


def normalize(image, label):
    """Rescale `image` from [0, 255] to floats in [-0.5, 0.5].

    The label is passed through untouched.
    """
    as_float = tf.cast(image, tf.float32)
    centered = as_float * (1. / 255) - 0.5
    return centered, label


def get_train_set(filename, batch_size=BATCH_SIZE, epochs=NUM_EPOCS):
    """Build this Horovod worker's shard of the training dataset.

    Args:
        filename: path of the TFRecord file to read.
        batch_size: number of examples per batch.
        epochs: currently unused; kept so existing callers keep working.

    Returns:
        A `tf.data.Dataset` of normalized `(image, label)` batches that
        contains only the records assigned to this worker.
    """
    dataset = tf.data.TFRecordDataset(filename)
    # Shard right after the source so that each worker only decodes,
    # batches and normalizes its own records instead of processing the
    # whole file and then discarding the other workers' batches.
    dataset = dataset.shard(hvd.size(), hvd.rank())
    dataset = dataset.map(decode)
    # Shuffling is currently disabled (a `shuffle(128)` step was commented
    # out here); re-enable it after `decode` if needed.
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(normalize)
    return dataset

# The validation data uses a separate function only so that it is not
# sharded. This way every worker sees the same batches and we can check
# that the validation loss is the same on all of them.
def get_val_set(filename, batch_size=BATCH_SIZE, epochs=NUM_EPOCS):
    """Build the validation dataset, identical on every worker (no sharding).

    Args:
        filename: path of the TFRecord file to read.
        batch_size: number of examples per batch.
        epochs: not used by this function.

    Returns:
        A `tf.data.Dataset` of normalized `(image, label)` batches.
    """
    return (tf.data.TFRecordDataset(filename)
            .map(decode)
            .batch(batch_size)
            .map(normalize))


# CNN for 28x28 single-channel images with a 10-way softmax output.
model = keras.Sequential()
model.add(keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu',
                              input_shape=(28, 28, 1)))
model.add(keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(128, activation='relu'))
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(10, activation='softmax'))

# Adam optimizer wrapped in Horovod's distributed optimizer.
opt = hvd.DistributedOptimizer(keras.optimizers.Adam(0.001))

model.compile(
    optimizer=opt,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# TensorBoard logs go to a run directory named after the current
# day-of-month and time; batches 700 to 730 are profiled.
tb_log_dir = os.path.join('cnn_hvd_logs',
                          datetime.now().strftime("%d-%H%M"))
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tb_log_dir,
                                             histogram_freq=1,
                                             profile_batch='700,730')

# Broadcast the global variables from rank 0.
hvd_callback = hvd.callbacks.BroadcastGlobalVariablesCallback(0)

In [5]:
%%px
# Training data is sharded per worker; validation data is the same on
# every worker.
train_ds = get_train_set('../tfrecords/train.tfrecords')
val_ds = get_val_set('../tfrecords/test.tfrecords', epochs=1)

fit = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCS,
    verbose=2,
    callbacks=[hvd_callback, tb_callback],
)

[stdout:0] 
Epoch 1/4
469/469 - 7s - loss: 0.2875 - accuracy: 0.9126 - val_loss: 0.0543 - val_accuracy: 0.9836
Epoch 2/4
469/469 - 8s - loss: 0.0986 - accuracy: 0.9710 - val_loss: 0.0400 - val_accuracy: 0.9869
Epoch 3/4
469/469 - 8s - loss: 0.0734 - accuracy: 0.9775 - val_loss: 0.0349 - val_accuracy: 0.9892
Epoch 4/4
469/469 - 8s - loss: 0.0588 - accuracy: 0.9825 - val_loss: 0.0328 - val_accuracy: 0.9892
[stdout:1] 
Epoch 1/4
469/469 - 7s - loss: 0.2851 - accuracy: 0.9133 - val_loss: 0.0543 - val_accuracy: 0.9836
Epoch 2/4
469/469 - 8s - loss: 0.0991 - accuracy: 0.9693 - val_loss: 0.0400 - val_accuracy: 0.9869
Epoch 3/4
469/469 - 8s - loss: 0.0724 - accuracy: 0.9773 - val_loss: 0.0349 - val_accuracy: 0.9892
Epoch 4/4
469/469 - 8s - loss: 0.0600 - accuracy: 0.9813 - val_loss: 0.0328 - val_accuracy: 0.9892


(!) Note that the validation losses and accuracies must be the same for both ranks, because the validation set is not sharded and every rank evaluates the same data.

In [6]:
%load_ext tensorboard

In [7]:
%tensorboard --logdir=cnn_hvd_logs

In [8]:
%ipcluster stop