In [1]:
import ipcmagic

In [2]:
%ipcluster start -n 2 --mpi

IPCluster is ready! (6 seconds)


In [3]:
%%px
import os
import math
import tensorflow as tf
from datetime import datetime
from tensorflow import keras

In [4]:
%%px
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    cluster_resolver=tf.distribute.cluster_resolver.SlurmClusterResolver(),
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL,
)

@tf.function(experimental_compile=True)
def decode(serialized_example):
    """Parses an image and label from the given `serialized_example`."""
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64),
        })
    label = tf.cast(features['label'], tf.int32)
    image = tf.io.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, (28, 28, 1))
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    return image, label

@tf.function(experimental_compile=True)
def normalize(image, label):
    """Convert `image` from [0, 255] -> [-0.5, 0.5] floats."""
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    return image, label


def get_train_set(filename, batch_size):
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA

    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(decode)
    # dataset = dataset.shuffle(128)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # dataset = dataset.map(normalize)
    dataset = dataset.with_options(options)
    return dataset

def get_val_set(filename, batch_size):
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA

    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(decode)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # dataset = dataset.map(normalize)
    dataset = dataset.with_options(options)
    return dataset

with strategy.scope():
    model = keras.Sequential([keras.layers.Conv2D(32, kernel_size=(3, 3),
                                                 activation='relu',
                                                 input_shape=(28, 28, 1)),
                              keras.layers.Conv2D(64, (3, 3), activation='relu'),
                              keras.layers.MaxPooling2D(pool_size=(2, 2)),
                              keras.layers.Dropout(0.25),
                              keras.layers.Flatten(),
                              keras.layers.Dense(128, activation='relu'),
                              keras.layers.Dropout(0.5),
                              keras.layers.Dense(10, activation='softmax')])

    opt = keras.optimizers.Adam(0.001)

    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

tb_callback = tf.keras.callbacks.TensorBoard(
    log_dir = os.path.join('cnn_tfdistr_logs',
                           datetime.now().strftime("%d-%H%M")),
    histogram_freq = 1,
    profile_batch = '100,130')

In [5]:
%%px
train_size = 60000
val_size = 10000
batch_size_per_worker = 64
num_workers = int(os.environ['SLURM_NNODES'])
batch_size = batch_size_per_worker * num_workers
num_epochs = 2

fit = model.fit(get_train_set('../tfrecords/train.tfrecords', batch_size),
                validation_data=get_val_set('../tfrecords/test.tfrecords', batch_size),
                epochs=num_epochs,
                verbose=2,
                steps_per_epoch=train_size // batch_size // num_workers,
                validation_steps=val_size // batch_size // num_workers,
                callbacks=[tb_callback])

[stdout:0] 
Epoch 1/2
234/234 - 5s - loss: 0.4175 - accuracy: 0.8730 - val_loss: 0.1248 - val_accuracy: 0.9579
Epoch 2/2
234/234 - 3s - loss: 0.1465 - accuracy: 0.9571 - val_loss: 0.0797 - val_accuracy: 0.9736
[stdout:1] 
Epoch 1/2
234/234 - 5s - loss: 0.4175 - accuracy: 0.8730 - val_loss: 0.1248 - val_accuracy: 0.9579
Epoch 2/2
234/234 - 3s - loss: 0.1465 - accuracy: 0.9571 - val_loss: 0.0797 - val_accuracy: 0.9736


(!) `tf.distributed` adds automatically `dataset = strategy.experimental_distribute_dataset(dataset)` to the dataset.

In [6]:
%load_ext tensorboard

In [7]:
%tensorboard --logdir=cnn_tfdistr_logs

In [8]:
%ipcluster stop