In [1]:
import ipcmagic

In [2]:
# Launch 2 ipyparallel engines with MPI. The `%%px` cells below execute on every
# engine, so each engine acts as one Horovod worker (hvd.rank() 0 and 1).
%ipcluster start -n 2 --mpi

IPCluster is ready! (6 seconds)


In [3]:
%%px
import os
import tensorflow as tf
import horovod.tensorflow.keras as hvd
from datetime import datetime

In [4]:
%%px
# Horovod has to be initialised before anything below asks for hvd.rank()/hvd.size().
hvd.init()

# Input resolution (height, width) fed to the network, and the number of
# examples each worker groups into one batch of its own input pipeline.
image_shape, batch_size = (224, 224), 128

def decode(serialized_example):
    """Parse one serialized TFRecord example into an ``(image, label)`` pair.

    The record is expected to carry a JPEG under ``image/encoded`` and an
    integer class id under ``image/class/label``.

    Returns:
        image: RGB image (3 channels) resized bicubically to the module-level
            ``image_shape``; ``tf.image.resize`` yields a float tensor.
        label: int64 class id shifted down by one so it falls in [0-999].
    """
    feature_spec = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, features=feature_spec)

    rgb = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    pixels = tf.image.resize(rgb, image_shape, method='bicubic')

    # Stored labels are presumably 1-based; shift to 0-based ids [0-999].
    class_id = tf.cast(parsed['image/class/label'], tf.int64) - 1
    return pixels, class_id


# Directory holding the TFRecord files. It can be overridden through the
# IMAGENET_DIR environment variable; the default is the original scratch path.
data_dir = os.environ.get('IMAGENET_DIR', '/scratch/snx3000/stud50/imagenet/')

# We split the files by hand before handing them to the workers: each Horovod
# rank takes its own contiguous, non-overlapping group of
# `num_files_per_worker` files, then we interleave records within that group.
num_files_per_worker = 80
# Sort so every rank sees the same order: os.listdir order is arbitrary and
# each rank lists the directory independently.
# NOTE(review): assumes every entry in data_dir is a TFRecord file — confirm.
all_files = sorted(os.path.join(data_dir, f) for f in os.listdir(data_dir))
num_shards = len(all_files) // num_files_per_worker
if hvd.rank() >= num_shards:
    raise ValueError(
        'rank {} has no shard: {} files in {} only fill {} groups of {}'.format(
            hvd.rank(), len(all_files), data_dir, num_shards, num_files_per_worker))
# Chunk k is files[k*n:(k+1)*n]. Stepping the chunk start by 1 (as before)
# made neighbouring ranks share all but one of their files.
shard_start = hvd.rank() * num_files_per_worker
list_of_files = all_files[shard_start:shard_start + num_files_per_worker]

dataset = tf.data.Dataset.list_files(list_of_files)
# cycle_length covers every file of this worker's shard, and block_length=1
# takes one record from each file in turn.
dataset = dataset.interleave(tf.data.TFRecordDataset,
                             cycle_length=num_files_per_worker,
                             block_length=1,
                             num_parallel_calls=12)
# Decode and resize in parallel instead of one element at a time; map() keeps
# element order by default, so the produced sequence is unchanged.
dataset = dataset.map(decode, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

# decode() shifts labels into [0-999], i.e. 1000 classes. A 1001-way head
# would leave one output unit that no label ever targets.
num_classes = 1000
model = tf.keras.applications.InceptionV3(weights=None,
                                          input_shape=(*image_shape, 3),
                                          classes=num_classes)

# DistributedOptimizer averages gradients across hvd.size() workers, so the
# effective batch is batch_size * hvd.size(). Scale the base learning rate by
# the worker count, as the Horovod usage guide recommends.
# `learning_rate` replaces the deprecated `lr` keyword.
base_learning_rate = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate * hvd.size(),
                                    momentum=0.9)
optimizer = hvd.DistributedOptimizer(optimizer)

model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Logs go to inceptionv3_logs/<day>-<HHMM>. Batches 85-95 are profiled.
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=os.path.join('inceptionv3_logs',
                                                                  datetime.now().strftime("%d-%H%M")),
                                             histogram_freq=1,
                                             profile_batch='85,95')

# Broadcast rank 0's initial variables so every worker starts from identical weights.
hvd_callback = hvd.callbacks.BroadcastGlobalVariablesCallback(0)

In [5]:
%%px
# Every rank must run the broadcast callback so all workers start from rank 0's
# weights. TensorBoard logging and progress output are limited to rank 0;
# otherwise each worker writes its own copy of the same run under
# inceptionv3_logs and prints its own progress bar.
fit_callbacks = [hvd_callback]
if hvd.rank() == 0:
    fit_callbacks.append(tb_callback)

# Short run: a single epoch over the first 100 batches of each worker's dataset.
fit = model.fit(dataset.take(100),
                epochs=1,
                callbacks=fit_callbacks,
                verbose=1 if hvd.rank() == 0 else 0)



In [6]:
%load_ext tensorboard

In [7]:
# Serve the logs written by the TensorBoard callback under 'inceptionv3_logs'.
# NOTE(review): this cell runs on the local kernel, not on the %%px engines, so it
# assumes the engines' working directory is the notebook's directory — confirm.
%tensorboard --logdir=inceptionv3_logs

Reusing TensorBoard on port 6006 (pid 31146), started 0:05:06 ago. (Use '!kill 31146' to kill it.)

In [8]:
%ipcluster stop