Any questions or comments should be directed to brabe2@illinois.edu. Data source paths and support are specific to NCSA's HAL system.

# TensorFlow native distributed training

This tutorial covers distributed training using TensorFlow's native tf.keras API, which lets you scale a model to multiple GPUs with little code modification. We'll start with a hands-on demonstration of training on a single HAL node (4 GPUs), then outline how to set up multi-node training, which cannot be run from a Jupyter notebook. As the example workload, we'll implement SqueezeNet v1.0 and train it on ImageNet. The code is fairly standard boilerplate, so you should be able to drop in your own model (defined using tf.keras) and/or dataset (defined using tf.data.Dataset).

First of all, for an optimized, well-maintained, and robust implementation of single- and multi-node training, see the official TF ResNet example (https://github.com/tensorflow/models/tree/master/official/vision/image_classification). However, it is large, and its critical sections can be hard to find and parse. The goal of this tutorial is to give bare-bones boilerplate that you can understand quickly and extend for your own workload.

# Model definition

First, we'll define a model we want to train using tf.keras (https://www.tensorflow.org/guide/keras):

In [1]:
import h5py  # needed by the SVHN input pipeline below
import numpy as np
import os
import tensorflow as tf

In [2]:
def fire_module(inputs, squeeze_depth, expand_depth, weight_decay):
    x = tf.keras.layers.Conv2D(
        filters=squeeze_depth,
        kernel_size=[1, 1],
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay)) (inputs)
    e1x1 = tf.keras.layers.Conv2D(
        filters=expand_depth,
        kernel_size=[1, 1],
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay)) (x)
    e3x3 = tf.keras.layers.Conv2D(
        filters=expand_depth,
        kernel_size=[3, 3],
        padding='same',
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay)) (x)
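    # concatenate the two expand branches along the channel axis, which depends
    # on the data format set via tf.keras.backend.set_image_data_format
    channel_axis = 1 if tf.keras.backend.image_data_format() == 'channels_first' else -1
    x = tf.keras.layers.Concatenate(axis=channel_axis) ([e1x1, e3x3])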
    return x

def squeezenet(input_shape, num_classes, weight_decay):
    inputs = tf.keras.Input(shape=input_shape)

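    # the input pipeline produces NHWC images; transpose them to NCHW
    # when the channels_first data format has been selected
    if tf.keras.backend.image_data_format() == 'channels_first':
        x = tf.keras.layers.Lambda(
            lambda x: tf.keras.backend.permute_dimensions(x, (0, 3, 1, 2))
        ) (inputs)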
    else:
        x = inputs

    x = tf.keras.layers.Conv2D(
        filters=96,
        kernel_size=[7, 7],
        strides=2,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay)) (x)
    x = tf.keras.layers.MaxPool2D(
        pool_size=[3, 3],
        strides=2) (x)
    x = fire_module(x, 16, 64, weight_decay)
    x = fire_module(x, 16, 64, weight_decay)
    x = fire_module(x, 32, 128, weight_decay)
    x = tf.keras.layers.MaxPool2D(
        pool_size=[3, 3],
        strides=2) (x)
    x = fire_module(x, 32, 128, weight_decay)
    x = fire_module(x, 48, 192, weight_decay)
    x = fire_module(x, 48, 192, weight_decay)
    x = fire_module(x, 64, 256, weight_decay)
    x = tf.keras.layers.MaxPool2D(
        pool_size=[3, 3],
        strides=2) (x)
    x = fire_module(x, 64, 256, weight_decay)
    x = tf.keras.layers.Dropout(rate=0.5) (x)
    x = tf.keras.layers.Conv2D(
        filters=num_classes,
        kernel_size=[1, 1],
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.random_normal_initializer(mean=0.0, stddev=0.01),
        kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
        bias_regularizer=tf.keras.regularizers.l2(weight_decay)) (x)
    x = tf.keras.layers.GlobalAveragePooling2D() (x)
    outputs = tf.keras.layers.Activation('softmax', dtype='float32')(x)
    
    return tf.keras.Model(inputs, outputs, name='squeezenet')
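The model definition never spells out its intermediate shapes, so here is a small pure-Python sketch that walks through them. It assumes 227x227 inputs and Keras's default 'valid' padding on the stem convolution and pooling layers, as in the code above; `conv_out` and `squeezenet_shapes` are illustrative helpers, not TensorFlow functions. It also shows why the 3x3 expand convolution uses `padding='same'`: both expand branches must keep the same spatial size to be concatenated along the channel axis.

```python
# Spatial size / channel walk-through of the SqueezeNet v1.0 model above.
# Assumes a square 227x227 input and 'valid' padding on conv1 and the pools.

def conv_out(size, kernel, stride=1):
    """Output length along one spatial dim for a 'valid' conv or pool."""
    return (size - kernel) // stride + 1

def squeezenet_shapes(size=227, num_classes=1000):
    """Return (layer name, spatial size, channels) after each stage."""
    shapes = []
    size = conv_out(size, 7, stride=2)      # conv1: 7x7, stride 2, 96 filters
    channels = 96
    shapes.append(("conv1", size, channels))
    size = conv_out(size, 3, stride=2)      # 3x3 max pool, stride 2
    shapes.append(("maxpool1", size, channels))

    # (squeeze_depth, expand_depth) per fire module; "pool" = 3x3/2 max pool
    stages = [(16, 64), (16, 64), (32, 128), "pool",
              (32, 128), (48, 192), (48, 192), (64, 256), "pool",
              (64, 256)]
    fire_idx, pool_idx = 2, 2
    for stage in stages:
        if stage == "pool":
            size = conv_out(size, 3, stride=2)
            shapes.append((f"maxpool{pool_idx}", size, channels))
            pool_idx += 1
        else:
            _, expand_depth = stage
            # the 1x1 and 'same'-padded 3x3 expand branches keep the spatial
            # size, and concatenating them doubles the expand depth
            channels = 2 * expand_depth
            shapes.append((f"fire{fire_idx}", size, channels))
            fire_idx += 1

    # the final 1x1 conv keeps the spatial size; global average pooling then
    # reduces each of the num_classes feature maps to a single value
    shapes.append(("conv10", size, num_classes))
    return shapes

for name, size, channels in squeezenet_shapes():
    print(f"{name:<9} {size:>3} x {size:<3} x {channels}")
```

For 227x227 inputs this gives 111x111 maps after conv1 and a 13x13x1000 tensor going into global average pooling.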

# Input data pipeline

The next (and possibly most critical) step is defining the input data pipeline. Especially in distributed training, a poorly designed input pipeline can have a massive impact on throughput. Luckily, TensorFlow's tf.data API provides built-ins, such as `tf.data.experimental.AUTOTUNE`, that let the runtime pick reasonable parallelism and buffer sizes automatically. The input pipeline we'll demo was written following the optimization guidelines in this tutorial (https://www.tensorflow.org/guide/data_performance), which is essential reading for understanding and adapting the following code.
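To make the idea behind `prefetch` concrete, here is a toy, pure-Python analogue: a background thread runs a slow producer and keeps up to `buffer_size` items ready in a bounded queue, so the consumer (the training loop) rarely waits. This only sketches the concept of overlapping input processing with training; it is not how tf.data implements prefetching, and the `prefetch`, `slow_batches`, and `train` helpers are hypothetical.

```python
import queue
import threading
import time

_END = object()  # sentinel marking the end of the stream

def prefetch(iterable, buffer_size):
    """Yield items from `iterable`, produced ahead of time by a background thread."""
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        for item in iterable:
            buf.put(item)  # blocks while the buffer is full
        buf.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _END:
            return
        yield item

def slow_batches(n, delay=0.01):
    """Stand-in for an input pipeline that spends `delay` seconds per batch."""
    for i in range(n):
        time.sleep(delay)  # simulated read + decode + preprocessing
        yield i

def train(batches, step_time=0.01):
    """Stand-in training loop that spends `step_time` seconds per batch."""
    seen = []
    for batch in batches:
        time.sleep(step_time)  # simulated training step
        seen.append(batch)
    return seen

start = time.perf_counter()
train(slow_batches(20))
serial = time.perf_counter() - start

start = time.perf_counter()
train(prefetch(slow_batches(20), buffer_size=2))
overlapped = time.perf_counter() - start
print(f"serial: {serial:.2f}s, with prefetch: {overlapped:.2f}s")
```

With equal producer and consumer costs, the prefetched loop should take roughly half as long as the serial one, because preparing the next batch overlaps with the current training step.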

The first step is getting your data into a tf.data.Dataset object. We'll first show a simple dataset (SVHN) already stored in an h5 file, then move on to ImageNet, which is stored in TFRecord format. Once you have a tf.data.Dataset object, both are treated the same way.

An important thing to keep in mind with a custom dataset is the size of your files. You'll need to aggregate (batch) individual images into larger files, as a large number of reads on small files will overwhelm the network filesystem, which hurts performance for every job on the system.

Note that the SqueezeNet model as defined above cannot be used on the SVHN dataset, as multi-digit SVHN needs 5 prediction outputs (one per digit) instead of one. See https://github.com/bendrabe/ddl_training/blob/master/svhn/model_estimator.py for more info. The following reads the SVHN h5 file, generates a dataset object from it, and prepares it for training:

In [3]:
_SHUFFLE_BUFFER = 10000
SVHN_PATH = '/home/shared/svhn/SVHN_multi_digit_norm_grayscale.h5'

def get_svhn_inputfn(is_training, data_dir, batch_size):
    # read from h5 file, must know or figure out layout of custom dataset
    with h5py.File(data_dir,'r') as h5f:
        X_train = h5f['X_train'][:]
        y_train = h5f['y_train'][:]
        X_val = h5f['X_val'][:]
        y_val = h5f['y_val'][:]
        X_test = h5f['X_test'][:]
        y_test = h5f['y_test'][:]
    
    # training data is shuffled and repeated indefinitely
    # shuffle before repeat respects epoch boundaries
    # with keras, don't use num_epochs. instead, use model.fit's steps_per_epoch
    if is_training:
        dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
        dataset = dataset.shuffle(_SHUFFLE_BUFFER).repeat()
    else:
        dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
    
    # both train/val use same batch size and autotuned prefetching
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset
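The comment "shuffle before repeat respects epoch boundaries" can be checked with a small pure-Python simulation. `buffer_shuffle` below mimics the buffered shuffling described in the `Dataset.shuffle` documentation (fill a buffer, emit a random buffered element, replace it with the next input); the other helpers are hypothetical names for this illustration. Shuffling each pass separately, as `shuffle(...).repeat()` does, always yields every element exactly once per epoch, whereas a single shuffle buffer over the repeated stream can mix elements from adjacent epochs.

```python
import itertools
import random

def buffer_shuffle(iterable, buffer_size, rng):
    """Buffered shuffle in the style described for tf.data's Dataset.shuffle."""
    buf = []
    for item in iterable:
        if len(buf) < buffer_size:
            buf.append(item)  # fill the buffer first
            continue
        idx = rng.randrange(buffer_size)
        yield buf[idx]        # emit a random buffered element ...
        buf[idx] = item       # ... and replace it with the next input
    rng.shuffle(buf)          # drain the remaining elements in random order
    yield from buf

def shuffle_then_repeat(data, buffer_size, epochs, rng):
    # like dataset.shuffle(...).repeat(): each epoch is shuffled on its own
    for _ in range(epochs):
        yield from buffer_shuffle(data, buffer_size, rng)

def repeat_then_shuffle(data, buffer_size, epochs, rng):
    # like dataset.repeat(...).shuffle(...): one buffer spans epoch boundaries
    repeated = itertools.chain.from_iterable(itertools.repeat(data, epochs))
    yield from buffer_shuffle(repeated, buffer_size, rng)

def epochs_are_permutations(stream, data):
    """True if every consecutive window of len(data) items is a permutation of data."""
    stream = list(stream)
    n = len(data)
    return all(sorted(stream[i:i + n]) == sorted(data)
               for i in range(0, len(stream), n))

data = list(range(10))
print(epochs_are_permutations(shuffle_then_repeat(data, 4, 3, random.Random(0)), data))  # True
print(epochs_are_permutations(repeat_then_shuffle(data, 4, 3, random.Random(0)), data))
```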

Since the ImageNet dataset is stored in TFRecord format for efficiency (https://www.tensorflow.org/tutorials/load_data/tfrecord), it's a bit more complicated to extract. Note that the following depends on how the TFRecords are generated in the first place (feature mapping, data types, file layout), but the data is stored this way on the HAL system in `/home/shared/`:

In [4]:
_NUM_TRAIN_FILES = 1024
_DEFAULT_IMAGE_SIZE = 227
_NUM_CHANNELS = 3

def get_filenames(is_training, data_dir):
    """Return filenames for dataset."""
    if is_training:
        return [
            os.path.join(data_dir, 'train-%05d-of-01024' % i)
            for i in range(_NUM_TRAIN_FILES)]
    else:
        return [
            os.path.join(data_dir, 'validation-%05d-of-00128' % i)
            for i in range(128)]

def parse_serialized_example(serialized_example):
    # Dense features in Example proto.
    feature_map = {
        'image/encoded': tf.io.FixedLenFeature([], dtype=tf.string,
                                            default_value=''),
        'image/class/label': tf.io.FixedLenFeature([], dtype=tf.int64,
                                                default_value=-1),
        'image/class/text': tf.io.FixedLenFeature([], dtype=tf.string,
                                               default_value=''),
    }
    sparse_float32 = tf.io.VarLenFeature(dtype=tf.float32)
    # Sparse features in Example proto.
    feature_map.update(
    {k: sparse_float32 for k in ['image/object/bbox/xmin',
                                 'image/object/bbox/ymin',
                                 'image/object/bbox/xmax',
                                 'image/object/bbox/ymax']})

    features = tf.io.parse_single_example(serialized_example, feature_map)
    label = tf.cast(features['image/class/label'], dtype=tf.int32)

    xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
    ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
    xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
    ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)

    # Note that we impose an ordering of (y, x) just to make life difficult.
    bbox = tf.concat([ymin, xmin, ymax, xmax], 0)

    # Force the variable number of bounding boxes into the shape
    # [1, num_boxes, coords].
    bbox = tf.expand_dims(bbox, 0)
    bbox = tf.transpose(bbox, [0, 2, 1])

    return features['image/encoded'], label, bbox

# PLACE ANY PREPROCESSING (CROPS, FLIPS, NORMALIZATION, ETC) IN THIS FUNCTION
def preprocess_image(raw_image):
    image = tf.image.decode_jpeg(raw_image, channels=_NUM_CHANNELS)
    image = tf.image.resize(
        image, [_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE]
    )
    image = tf.cast(image, dtype=tf.float32)
    return image

def parse_record(raw_record):
    raw_image, label, _ = parse_serialized_example(raw_record)
    image = preprocess_image(raw_image)
    # the TFRecord labels are 1-based (1-1000); shift them to 0-999. take care with the
    # label range: some code uses 1001 classes, and out-of-range labels cause hard-to-diagnose NaNs
    label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
        dtype=tf.float32)
    return image, label

def get_imagenet_inputfn(is_training, data_dir, batch_size):
    filenames = get_filenames(is_training, data_dir)
    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    if is_training:
        # shuffle input files
        dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

    # convert to individual records
    # 10 files read and deserialized in parallel
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if is_training:
        # shuffle before repeat respects epoch boundaries
        dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
        # with keras, don't use num_epochs. instead, use model.fit's steps_per_epoch
        dataset = dataset.repeat()

    # Parses the raw records into images and labels.
    dataset = dataset.map(
      lambda value: parse_record(value),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size)

    # ops between the final prefetch and the iterator's get_next call run synchronously,
    # so prefetch last to run preprocessing in the background while the model trains.
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset
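The label shift in `parse_record` matters because an out-of-range label does not always fail loudly. The documentation for `tf.nn.sparse_softmax_cross_entropy_with_logits` states that labels outside `[0, num_classes)` raise an exception on CPU but return NaN for the corresponding loss and gradient rows on GPU. The pure-Python sketch below mimics that GPU behaviour and adds a hypothetical `out_of_range` check you could run over a sample of decoded labels; it assumes, as the `- 1` in `parse_record` implies, that the TFRecord labels are 1-based.

```python
import math

def sparse_softmax_xent(logits, label):
    """Cross-entropy of one example given raw logits and an integer label.

    Mimics the documented GPU behaviour of TF's sparse softmax cross-entropy:
    an out-of-range label silently yields NaN instead of raising.
    """
    if not 0 <= label < len(logits):
        return float("nan")
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def out_of_range(labels, num_classes):
    """Return the labels that fall outside [0, num_classes)."""
    return [y for y in labels if not 0 <= y < num_classes]

NUM_CLASSES = 1000
raw_labels = [1, 500, 1000]              # 1-based, as assumed for the TFRecords
shifted = [y - 1 for y in raw_labels]    # what parse_record feeds the model

print(out_of_range(raw_labels, NUM_CLASSES))  # [1000]
print(out_of_range(shifted, NUM_CLASSES))     # []
uniform_logits = [0.0] * NUM_CLASSES
print(sparse_softmax_xent(uniform_logits, raw_labels[-1]))  # nan
```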

# (Optional) Adaptive learning rate

TensorFlow's `tf.keras` API has several learning rate decay schedules available via built-in operations (https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules). However, many users may want to experiment with custom LR decay schedules. The `tf.keras.optimizers.schedules.LearningRateSchedule` class can be extended with a user-defined, step-dependent LR function.

Custom schedules matter most in large-batch distributed training, where LR warmup has been shown to reduce instability and improve generalization when scaling up the global batch size. In the following code snippet, we define a custom LR schedule that adds linear warmup to the built-in polynomial LR decay schedule (and logs it to TensorBoard for debugging):

In [5]:
class PolynomialDecayWithWarmup(
    tf.keras.optimizers.schedules.LearningRateSchedule):
    """Polynomial decay with warmup schedule."""

    def __init__(self, lr0, power, warmup_steps, decay_steps, name=None):
        super(PolynomialDecayWithWarmup, self).__init__()

        self.lr0 = lr0
        self.power = power
        self.post_warmup_lr = lr0 * np.power(1 - warmup_steps/decay_steps, power)
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        self.name = name
        
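        # cache the LR op per graph so it isn't rebuilt on every call (mem leak).
        # this assumes graph execution (run_eagerly=False, as in the driver below);
        # the schedule is not designed to run eagerly.
        self.learning_rate_ops_cache = {}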

    def __call__(self, step):
        
        graph = tf.compat.v1.get_default_graph()
        if graph not in self.learning_rate_ops_cache:
            self.learning_rate_ops_cache[graph] = self._get_lr(step)
        return self.learning_rate_ops_cache[graph]
    
    def _get_lr(self, step):
        def warmup_lr(step):
            return self.post_warmup_lr * (
                tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32))
        def polynomial_lr(step):
            return tf.compat.v1.train.polynomial_decay(
                learning_rate=self.lr0,
                global_step=step,
                decay_steps=self.decay_steps,
                end_learning_rate=0.0,
                power=self.power
            )
        lr = tf.cond(step < self.warmup_steps,
            lambda: warmup_lr(step),
            lambda: polynomial_lr(step))
        # plot in TB for debugging purposes
        tf.summary.scalar('learning_rate', lr)
        return lr

    def get_config(self):
        return {
            'lr0': self.lr0,
            'power': self.power,
            'warmup_steps': self.warmup_steps,
            'decay_steps': self.decay_steps,
            'name': self.name
        }
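Because the schedule above only runs inside a TF graph, a pure-Python mirror can be handy for plotting or sanity-checking the curve. This sketch assumes `end_learning_rate=0.0` and `cycle=False`, matching the `polynomial_decay` call in `_get_lr`, and reuses the driver's values (2502 steps per epoch, 3 warmup epochs, 68 total epochs, `LR0 = 0.04`). Warmup ramps linearly from 0 up to the already-decayed rate at `warmup_steps`, which is what `post_warmup_lr` computes, so the curve is continuous at the hand-off.

```python
def polynomial_decay(lr0, step, decay_steps, power, end_lr=0.0):
    """Polynomial decay as computed by tf.compat.v1.train.polynomial_decay (cycle=False)."""
    step = min(step, decay_steps)  # the LR stays at end_lr after decay_steps
    return (lr0 - end_lr) * (1 - step / decay_steps) ** power + end_lr

def lr_with_warmup(step, lr0, power, warmup_steps, decay_steps):
    """Pure-Python mirror of PolynomialDecayWithWarmup._get_lr."""
    post_warmup_lr = polynomial_decay(lr0, warmup_steps, decay_steps, power)
    if step < warmup_steps:
        # linear ramp from 0 up to the decayed LR reached at warmup_steps
        return post_warmup_lr * step / warmup_steps
    return polynomial_decay(lr0, step, decay_steps, power)

# values from the driver cell
LR0, POWER = 0.04, 1.0
TRAIN_STEPS = 1281167 // 512            # 2502 steps per epoch
WARMUP_STEPS = 3 * TRAIN_STEPS          # 3 warmup epochs
DECAY_STEPS = 68 * TRAIN_STEPS          # 68 training epochs

for step in (0, WARMUP_STEPS // 2, WARMUP_STEPS, DECAY_STEPS // 2, DECAY_STEPS):
    lr = lr_with_warmup(step, LR0, POWER, WARMUP_STEPS, DECAY_STEPS)
    print(f"step {step:>6}: lr = {lr:.6f}")
```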

# Training driver code

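Now that we've defined our model and input data pipeline, we'll write the main driver code. Most of it is self-explanatory; the parts specific to distributed training are the global batch size and the `MirroredStrategy` scope around model construction and compilation.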

In [None]:
IMAGENET_PATH = '/home/shared/imagenet/tfrecord'
IMAGENET_NUM_CLASSES = 1000
IMAGENET_NUM_TRAIN_IMAGES = 1281167
IMAGENET_NUM_EVAL_IMAGES = 50000

CHECKPTS = False
MODEL_DIR = 'native_tf_dist/final'
N_GPUS = 4
NUM_EPOCHS = 68
GLOBAL_BATCH_SIZE = 512
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0002
LR0 = 0.04
LR_DECAY_POWER = 1.0
WARMUP_EPOCHS = 3

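# per-GPU batch size, for reference only: the datasets below are batched with
# GLOBAL_BATCH_SIZE and MirroredStrategy splits each global batch across the GPUs
local_batch_size = GLOBAL_BATCH_SIZE // N_GPUS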
train_steps = IMAGENET_NUM_TRAIN_IMAGES // GLOBAL_BATCH_SIZE
eval_steps = IMAGENET_NUM_EVAL_IMAGES // GLOBAL_BATCH_SIZE
decay_steps = NUM_EPOCHS*train_steps
warmup_steps = WARMUP_EPOCHS*train_steps

imagenet_inshape = (_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS)

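# use NCHW ('channels_first') on CUDA builds, which is typically faster with cuDNN,
# and NHWC ('channels_last') otherwise; then set it globally in tf.keras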
if tf.test.is_built_with_cuda():
    data_format = 'channels_first'
else:
    data_format = 'channels_last'
tf.keras.backend.set_image_data_format(data_format)

# create train and input datasets
train_input_dataset = get_imagenet_inputfn( # get_svhn_inputfn
    is_training=True,
    data_dir=IMAGENET_PATH, # SVHN_PATH
    batch_size=GLOBAL_BATCH_SIZE)
eval_input_dataset = get_imagenet_inputfn(
    is_training=False,
    data_dir=IMAGENET_PATH,
    batch_size=GLOBAL_BATCH_SIZE)

# our custom LR decay schedule
lr_sched = PolynomialDecayWithWarmup(LR0, LR_DECAY_POWER, warmup_steps, decay_steps)
# could also use one of the pre-defined schedules, e.g.:
# lr_sched = tf.keras.optimizers.schedules.PolynomialDecay(
#     initial_learning_rate=LR0,
#     decay_steps=decay_steps,
#     end_learning_rate=0.0,
#     power=1.0)

# setting distribute strategy and compiling model inside scope is
# essentially all you need to go from one device (GPU) to all on the node
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = squeezenet(imagenet_inshape, IMAGENET_NUM_CLASSES, WEIGHT_DECAY)

    optimizer = tf.keras.optimizers.SGD(
        learning_rate=lr_sched,
        momentum=MOMENTUM)

    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
        run_eagerly=False
    )

callbacks = []
# tensorboard callback, can add more metrics with tf.summary
callbacks.append(
    tf.keras.callbacks.TensorBoard(
        log_dir=MODEL_DIR,
        update_freq=100,
        profile_batch=0
    )
)

# checkpointing can be helpful for long-running jobs; the code below saves weights every 5 epochs.
# an integer save_freq counts batches rather than epochs, hence 5 * train_steps. see docs for more options.
if CHECKPTS:
    ckpt_full_path = os.path.join(MODEL_DIR, 'model.ckpt-{epoch:04d}')
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(
        ckpt_full_path,
        save_weights_only=True,
        save_freq=5 * train_steps))

model.fit(
    train_input_dataset,
    epochs=NUM_EPOCHS,
    steps_per_epoch=train_steps,
    callbacks=callbacks,
    validation_data=eval_input_dataset,
    validation_steps=eval_steps, # see github issue 28995
    validation_freq=1,
    verbose=0 # see github issue 28995, TensorBoard log anyway
)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 52 all-reduces with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/re
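The step counts in the driver come from integer division, and it can help to see the resulting numbers. The sketch below recomputes them for this run; `step_plan` is a hypothetical helper, not part of the notebook. It shows that each GPU processes 128 images per step, that each Keras "epoch" of 2502 steps covers 143 fewer images than a full pass over the training set, and that validation with `validation_steps=97` scores 49,664 of the 50,000 validation images.

```python
def step_plan(num_train, num_eval, global_batch, n_replicas, epochs, warmup_epochs):
    """Recompute the driver's batch/step arithmetic (hypothetical helper)."""
    train_steps = num_train // global_batch
    eval_steps = num_eval // global_batch
    return {
        # MirroredStrategy splits each global batch evenly across replicas
        "per_replica_batch": global_batch // n_replicas,
        "train_steps": train_steps,
        "eval_steps": eval_steps,
        # a Keras "epoch" of train_steps batches falls this short of a full pass
        "train_images_short_per_epoch": num_train - train_steps * global_batch,
        # the eval set is not repeated, so these images are never scored
        "eval_images_not_scored": num_eval - eval_steps * global_batch,
        "warmup_steps": warmup_epochs * train_steps,
        "decay_steps": epochs * train_steps,
    }

plan = step_plan(num_train=1281167, num_eval=50000, global_batch=512,
                 n_replicas=4, epochs=68, warmup_epochs=3)
for key, value in plan.items():
    print(f"{key:>28}: {value}")
```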

# (Optional) Extension to multi-node training (more than 4 GPUs on HAL)

To reiterate, because of how Jupyter scheduling is configured, you cannot easily perform multi-node training within a Jupyter notebook. Here we provide the code modifications, as well as a SLURM run script for running multi-node training as a batch job.

An official tutorial can be found at https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras. There are essentially only two necessary changes to the code we've written so far:

1) switching the distribute strategy from `tf.distribute.MirroredStrategy` to `tf.distribute.experimental.MultiWorkerMirroredStrategy` and 

2) setting the `TF_CONFIG` environment variable as per the API spec (https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)

As (1) is a one-line change, we only need to discuss (2). When training on multiple nodes, each node runs its own copy of your script. TensorFlow handles the communication, but you need to provide a `TF_CONFIG` containing the list of worker hostnames as well as the index of the current worker in that list (different for each node). This must be set __before__ initializing your distribute strategy. We can get all this information from SLURM and serialize it into the JSON format TF expects using the following code:

In [6]:
import json
import socket
import subprocess

PORT_NUMBER=8888

# get nodelist from slurm and parse into list
slurm_job_nodelist = os.environ["SLURM_JOB_NODELIST"]
cmd = ["scontrol", "show", "hostnames", slurm_job_nodelist]
cmd_result = subprocess.run(cmd, check=True, text=True, stdout=subprocess.PIPE)
node_list = cmd_result.stdout.splitlines()  # one hostname per line

# sorting the node list makes each worker index unique / easy to compute
node_list.sort()

# get the current node's position in list for worker index
node = socket.gethostname()
node_idx = node_list.index(node)

# set env var
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["{}:{}".format(x, PORT_NUMBER) for x in node_list],
    },
    "task": {"type": "worker", "index": node_idx}
})

# print it for demo purposes
print(os.environ["TF_CONFIG"])

{"cluster": {"worker": ["hal09:8888"]}, "task": {"type": "worker", "index": 0}}
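To test the TF_CONFIG logic off the cluster, the SLURM and socket calls can be factored out of it. The `build_tf_config` function below is a hypothetical refactoring of the cell above, and the node names passed to it are made up for illustration.

```python
import json

def build_tf_config(node_list, hostname, port=8888):
    """Build the TF_CONFIG JSON string for the worker running on `hostname`."""
    nodes = sorted(node_list)      # every worker sorts identically ...
    index = nodes.index(hostname)  # ... so each gets a unique, stable index
    return json.dumps({
        "cluster": {"worker": ["{}:{}".format(n, port) for n in nodes]},
        "task": {"type": "worker", "index": index},
    })

print(build_tf_config(["hal10", "hal09"], "hal10"))
# {"cluster": {"worker": ["hal09:8888", "hal10:8888"]}, "task": {"type": "worker", "index": 1}}
```

If `socket.gethostname()` returns a fully qualified name on your system while `scontrol` returns short names, `list.index` raises `ValueError`; normalizing both to the same form (for example, the part before the first dot) is one possible fix.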


Finally, we need a SLURM batch script that will allocate the number of nodes we want and run the script once on each node. If you clone the official ResNet example (https://github.com/tensorflow/models/tree/master/official/vision/image_classification) into `~/models_r2.1.0`, a functional batch script would look like the following:

    #!/bin/bash

    #SBATCH --job-name="native-tf-dist"
    #SBATCH --output="native-tf-dist.%j.%N.out"
    #SBATCH --error="native-tf-dist.%j.%N.err" 
    #SBATCH --partition=gpu
    #SBATCH --nodes=2
    #SBATCH --ntasks-per-node=1
    #SBATCH --cpus-per-task=144
    #SBATCH --mem-per-cpu=1200
    #SBATCH --gres=gpu:v100:4
    #SBATCH --export=ALL
    #SBATCH -t 24:00:00

    source ~/.bashrc
    conda activate wmlce-v1.7.0-py3.7
    export PYTHONPATH="$PYTHONPATH:$HOME/models_r2.1.0/"

    cd ~/models_r2.1.0/official/vision/image_classification/
    srun python resnet_imagenet_main.py -bs 1024 -dd /home/shared/imagenet/tfrecord -ds multi_worker_mirrored -md ~/native-tf-dist-bs1024_summary -ng 4 -te 90 -dt fp16 -ara nccl --enable_tensorboard
