In [1]:
import os
import ipcmagic

In [2]:
# Start one IPython parallel engine per SLURM node (count read from SLURM_NNODES),
# launched with the --mpi option; the %%px cells below execute on these engines.
%ipcluster start -n {int(os.environ['SLURM_NNODES'])} --mpi

IPCluster is ready! (5 seconds)


In [3]:
%%px
import os
import glob
import types
import tensorflow as tf
import tensorflow_addons as tfa
from datetime import datetime
from tb_cscs import tensorboard

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    cluster_resolver=tf.distribute.cluster_resolver.SlurmClusterResolver(),
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL,
)

num_workers = int(os.environ['SLURM_NNODES'])
node_id = int(os.environ['SLURM_NODEID'])

node_id, num_workers

[0;31mOut[0:1]: [0m(0, 2)

[0;31mOut[1:1]: [0m(1, 2)

In [4]:
%%px
image_shape = (224, 224)
batch_size = 128 * num_workers

def decode(serialized_example):
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/class/label': tf.io.FixedLenFeature([], tf.int64),
        })
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    image = tf.image.resize(image, image_shape, method='bicubic')
    label = tf.cast(features['image/class/label'], tf.int64) - 1  # [0-999]
    return image, label

list_of_files = glob.glob('/scratch/snx3000/stud50/imagenet/train*')

AUTO = tf.data.experimental.AUTOTUNE
dataset = (tf.data.TFRecordDataset(list_of_files, num_parallel_reads=AUTO)
           .map(decode, num_parallel_calls=AUTO)
           .batch(batch_size)
           .prefetch(AUTO)
          )

In [5]:
%%px
with strategy.scope():
    model = tf.keras.applications.InceptionV3(weights=None,
                                              input_shape=(*image_shape, 3),
                                              classes=1000)

    optimizer = tfa.optimizers.LAMB(lr=1e-3 * (num_workers ** 0.5))

    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

tb_callback = tf.keras.callbacks.TensorBoard(log_dir=os.path.join('inceptionv3_logs',
                                                                  datetime.now().strftime("%d-%H%M")),
                                             histogram_freq=1,
                                             profile_batch='80,100')

In [6]:
%%px
fit = model.fit(dataset,
                steps_per_epoch=100,
                epochs=1,
                callbacks=[tb_callback])



In [7]:
# Load (or reload) the TensorBoard notebook extension in this local kernel;
# it provides the %tensorboard magic used in the next cell.
%reload_ext tensorboard

In [8]:
# Display the TensorBoard logs written under inceptionv3_logs/.
# NOTE(review): the path is resolved relative to this local kernel's working
# directory; this assumes the engines wrote there too (shared filesystem) — confirm.
%tensorboard --logdir=inceptionv3_logs

In [9]:
# Shut down the parallel engines started at the top of the notebook.
%ipcluster stop