In [1]:
import os
import ipcmagic
from ipcmagic import utilities
# Start one IPython engine per allocated SLURM node; `--mpi` launches the engines
# under MPI (presumably required by the hvd.init() call in the next cell).
%ipcluster start -n {int(os.environ['SLURM_NNODES'])} --mpi

IPCluster is ready! (7 seconds)


In [2]:
%%px
import os
import math
import glob
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_addons as tfa
import albumentations as alb
import horovod.tensorflow.keras as hvd

AUTO = tf.data.experimental.AUTOTUNE  # let tf.data pick parallelism / prefetch depth

# --- Training-run configuration ----------------------------------------------
epochs = 1  # 50 for the full run
batch_size = 128  # per worker; the global batch is batch_size * num_nodes
valid_samples = 50000
train_samples = 1281167

hvd.init()
num_nodes = hvd.size()  # number of Horovod workers
node_id = hvd.rank()  # this worker's index

# Only cache the training records once they are split over at least 4 workers
# (presumably so each worker's shard fits in memory).
cache_train = num_nodes >= 4

## Large Batch Optimization for Deep Learning https://arxiv.org/abs/1904.00962 (tfa.optimizers.LAMB)
## - square-root LR scaling: base LR 0.1 multiplied by sqrt(number of workers)
lr = 0.1 * num_nodes**0.5

global_batch_size = batch_size * num_nodes
validation_steps = math.ceil(valid_samples / global_batch_size)  # enough steps to cover all samples
steps_per_epoch = round(train_samples / global_batch_size)

print(num_nodes, node_id, cache_train, lr, validation_steps, steps_per_epoch)

[stdout:0] 2 0 False 0.2 196 5005
[stdout:1] 2 1 False 0.2 196 5005


In [3]:
%%px
## ImageNet TFRecord shards; override the location with the IMAGENET_DIR environment variable.
imagenet_dir = os.environ.get('IMAGENET_DIR', '/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k')
train_files = sorted(glob.glob(os.path.join(imagenet_dir, 'train', '*')))
valid_files = sorted(glob.glob(os.path.join(imagenet_dir, 'validation', '*')))
# Fail early with a clear message instead of failing later inside tf.data.
assert train_files and valid_files, f"no TFRecord files found under {imagenet_dir}"
print(len(train_files), len(valid_files))

[stdout:0] 1024 128
[stdout:1] 1024 128


In [4]:
%%px
# (height, width) of the model input; both augmentation pipelines below crop to this size.
image_shape = (224, 224)

def process_image(serialized_example, transforms):
    '''Decode one serialized TFRecord example and augment its image.

    Args:
        serialized_example: scalar string tensor holding one serialized
            example with 'image/encoded' (JPEG bytes) and
            'image/class/label' (int64) features.
        transforms: albumentations pipeline, called as
            ``transforms(image=array)["image"]``. It must return a float32
            array because the numpy_function below declares Tout=tf.float32.

    Returns:
        (image, label): the augmented float32 image with a static shape of
        (None, None, 3), and the int64 label shifted down by one.
    '''
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/class/label': tf.io.FixedLenFeature([], tf.int64),
        })
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    # The feature is already int64; the records' labels presumably start at 1,
    # so shift them to [0-999].
    label = features['image/class/label'] - 1

    def image_aug(img):
        return transforms(image=img)["image"]

    # albumentations works on numpy arrays, so run it through numpy_function.
    aug_img = tf.numpy_function(func=image_aug, inp=[image], Tout=tf.float32)
    # numpy_function erases all static shape information (even the rank).
    # Restore rank and channel count (decode_jpeg requested 3 channels) so
    # downstream batching and Keras layers see a well-defined shape.
    aug_img.set_shape([None, None, 3])
    return aug_img, label


# ImageNet per-channel RGB mean / std. Both pipelines share them so the
# train and validation normalizations can never drift apart.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training augmentation: geometric jitter and flips, one noise-type op,
# two colour ops, then a crop to `image_shape` and normalization.
train_transforms = alb.Compose([
    alb.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=30),
    alb.HorizontalFlip(p=0.5),
    alb.OneOf([  # add or remove noise; one of these is picked per image
        alb.GaussNoise(var_limit=(50, 200)),
        alb.ImageCompression(quality_lower=80, quality_upper=95),
        alb.GaussianBlur(sigma_limit=(2, 10)),
    ], p=1),
    alb.SomeOf([  # colour ops; two of these are picked per image
        alb.ToGray(p=0.1),
        alb.Equalize(by_channels=False),
        alb.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10),
        alb.RandomGamma(gamma_limit=(90, 110)),
        alb.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10),
        alb.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.2, 0.2)),
    ], n=2),
    # NOTE(review): `scale` is a fraction of the source image area, so the part
    # of (0.9, 1.1) above 1.0 cannot be met by a crop -- confirm it is intended.
    alb.RandomResizedCrop(*image_shape, scale=(0.9, 1.1), ratio=(0.9, 1.1)),
    alb.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation preprocessing (no randomness): resize the shorter side to
# image_shape[0], center-crop to image_shape, then normalize.
valid_transforms = alb.Compose([
    alb.SmallestMaxSize(max_size=image_shape[0]),
    alb.CenterCrop(*image_shape),
    alb.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [5]:
%%px
def get_ds(tfrecords, transforms, batch_size, num_nodes, node_id,
           seed=42, cache=False, repeat=False, shuffle_buffer=None):
    """Build this worker's input pipeline over a list of TFRecord files.

    Args:
        tfrecords: TFRecord file paths (or glob pattern(s)) for one split.
        transforms: albumentations pipeline applied per image via process_image.
        batch_size: per-worker batch size.
        num_nodes: total number of workers; the file list is sharded into this many parts.
        node_id: index of the shard this worker reads.
        seed: seed for file-order and record shuffling. It must be identical on
            all workers so that they shard the same file order.
        cache: cache the raw serialized records after the first pass.
        repeat: repeat the dataset indefinitely.
        shuffle_buffer: record shuffle-buffer size; a falsy value disables it.

    Returns:
        A batched, prefetched tf.data.Dataset of (images, labels).
    """
    # File order is only shuffled when not caching; a cached dataset replays
    # the same order every pass anyway.
    files = tf.data.Dataset.list_files(tfrecords, shuffle=not cache, seed=seed)
    files = files.shard(num_shards=num_nodes, index=node_id)
    records = files.interleave(
        tf.data.TFRecordDataset,
        cycle_length=4,
        block_length=1,
        num_parallel_calls=AUTO,
    )

    # Order matters: cache the raw records before any repeat/shuffle.
    if cache:
        records = records.cache()
    if repeat:
        records = records.repeat()
    if shuffle_buffer:
        records = records.shuffle(
            buffer_size=shuffle_buffer,
            reshuffle_each_iteration=True,
            seed=seed,
        )

    def _decode_and_augment(serialized):
        return process_image(serialized, transforms)

    return (
        records
        .map(_decode_and_augment, num_parallel_calls=AUTO)
        .batch(batch_size)
        .prefetch(buffer_size=AUTO)
    )

# Validation: a single ordered pass, cached after the first read.
valid_ds = get_ds(valid_files, valid_transforms, batch_size, num_nodes, node_id,
                  cache=True)
# Training: endless (repeat=True) shuffled stream; cached only when cache_train is set.
train_ds = get_ds(train_files, train_transforms, batch_size, num_nodes, node_id,
                  cache=cache_train, repeat=True, shuffle_buffer=1024)

In [6]:
%%px
# Re-import edited local modules automatically on every engine
# (presumably so edits to the local `resnet_v2` module used in the next cell are picked up).
%load_ext autoreload
%autoreload 2

from tensorflow.python.keras.layers import VersionAwareLayers
from horovod.tensorflow.sync_batch_norm import SyncBatchNormalization

class SyncedLayers(VersionAwareLayers):
  """Layer namespace that swaps in Horovod's synchronized batch norm.

  Overloads Keras' internal V1/V2-aware layer lookup so every attribute
  resolves as usual, except ``BatchNormalization``, which returns Horovod's
  ``SyncBatchNormalization``. That makes
  ``SyncedLayers().BatchNormalization is SyncBatchNormalization`` true.
  Pass an instance as ``layers=`` to a model builder such as ResNet50V2.
  """
  def __getattr__(self, name):
    # Compare strings by value. `is` tests object identity: it only matches by
    # accident of string interning, and Python 3.8+ emits a SyntaxWarning for it.
    if name == 'BatchNormalization':
      return SyncBatchNormalization
    return super().__getattr__(name)

## Sanity check: ordinary layers resolve as usual, BatchNormalization is swapped.
## The assert makes the check fail loudly instead of relying on reading the prints.
print(SyncedLayers().Dense)
print(SyncedLayers().BatchNormalization)
assert SyncedLayers().BatchNormalization is SyncBatchNormalization

[stdout:0] 
<class 'tensorflow.python.keras.layers.core.Dense'>
<class 'horovod.tensorflow.sync_batch_norm.SyncBatchNormalization'>
[stdout:1] 
<class 'tensorflow.python.keras.layers.core.Dense'>
<class 'horovod.tensorflow.sync_batch_norm.SyncBatchNormalization'>


In [7]:
%%px
from resnet_v2 import ResNet50V2
# Plain-Keras alternative (no synchronized batch norm):
# model = ResNet50V2(include_top=True, classes=1000, weights=None)
model = ResNet50V2(
    include_top=True,
    classes=1000,
    weights=None,
    # Resolve layers through SyncedLayers so BatchNormalization becomes
    # Horovod's SyncBatchNormalization.
    layers=SyncedLayers(),
)

# Cosine decay of the worker-scaled `lr` over steps_per_epoch * epochs steps,
# bottoming out at alpha * lr.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=lr,
    decay_steps=steps_per_epoch * epochs,
    alpha=0.01,
)

# SGD with momentum, wrapped by Horovod so gradients are combined across workers.
# Alternative: tfa.optimizers.LAMB(learning_rate=lr_schedule, ...)
# NOTE(review): LAMB exposes beta_1 rather than a `momentum` argument -- confirm before enabling.
optimizer = hvd.DistributedOptimizer(
    tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
)

model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name="Top5"),
    ],
)

callbacks = [
    # Average the logged metrics over all workers at the end of every epoch.
    # Keep it ahead of ReduceLROnPlateau, TensorBoard or any other
    # metrics-based callback so they see the averaged values.
    hvd.callbacks.MetricAverageCallback(),

    # Broadcast rank 0's initial variables so every worker starts from the
    # same weights, whether randomly initialized or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Optional TensorBoard logging:
    # tf.keras.callbacks.TensorBoard(
    #     log_dir=os.path.join(os.environ['SCRATCH'], 'imagenet_logs', datetime.now().strftime("%y%m%d-%H%M")),
    #     histogram_freq=1,
    #     # profile_batch='80,100',
    #     update_freq=100,  # batches
    # ),
]

[stderr:0] 
  "Some callbacks may not have access to the averaged metrics, "
[stderr:1] 
  "Some callbacks may not have access to the averaged metrics, "


In [8]:
# %reload_ext tensorboard

# imagenet_logs = os.path.join(os.environ['SCRATCH'], 'imagenet_logs')
# %tensorboard --logdir={imagenet_logs}

In [9]:
%%px --noblock -o training
# The non-blocking execution (`%%px --noblock`) returns an `AsyncResult` object immediately.
# With `-o <variable>` that `AsyncResult` is stored in the client namespace,
# so output can be fetched while the code is still running.

fit = model.fit(
    train_ds,
    # Short smoke test; use `steps_per_epoch` for a full run.
    # NOTE: lr_schedule decays over steps_per_epoch * epochs steps, so a
    # 10-step run barely moves the learning rate.
    steps_per_epoch=10,
#     validation_data=valid_ds,
#     validation_steps=validation_steps,
    epochs=epochs,  # same value the LR schedule was built from
    callbacks=callbacks,  # already a list of Callback instances; pass it as-is
    verbose=1 if hvd.rank() == 0 else 0,  # use verbose=2 in a production script (sbatch)
)

<AsyncResult: execute>

In [10]:
# Watch the engines' output in real time while the non-blocking `training`
# AsyncResult (from the previous cell's `-o training`) is running.
utilities.watch_asyncresult(training)

[ stdout 0 ]



In [11]:
%%px
# Training loss and metrics (accuracy, Top5) recorded by Keras on each engine.
fit.history

[0;31mOut[0:8]: [0m{'loss': [7.1615887], 'accuracy': [0.001171875], 'Top5': [0.005859375]}

[0;31mOut[1:8]: [0m{'loss': [7.1615887], 'accuracy': [0.001171875], 'Top5': [0.005859375]}

### Although only the gradients are synchronized during training, the model weights must end up identical on every node — let's verify

In [12]:
%%px
## For every 1-dimensional weight (biases, BN parameters and statistics), compute the
## mean absolute difference between this worker's copy and its allreduce over all
## workers (Horovod averages by default). Identical replicas give 0.
norms = {w.name : np.mean(abs(w - hvd.allreduce(w))) for w in model.weights if len(w.shape) == 1}
# Allow a small tolerance: averaging identical float32 values across workers
# may not be bit-exact for every worker count.
max_deviation = max(norms.values(), default=0.0)
assert max_deviation < 1e-4, f"model weights differ across workers (max mean abs deviation {max_deviation})"
norms

[0;31mOut[0:9]: [0m
{'conv1_conv/bias:0': 0.0,
 'conv2_block1_preact_bn/gamma:0': 0.0,
 'conv2_block1_preact_bn/beta:0': 0.0,
 'conv2_block1_preact_bn/moving_mean:0': 0.0,
 'conv2_block1_preact_bn/moving_variance:0': 0.0,
 'conv2_block1_1_bn/gamma:0': 0.0,
 'conv2_block1_1_bn/beta:0': 0.0,
 'conv2_block1_1_bn/moving_mean:0': 0.0,
 'conv2_block1_1_bn/moving_variance:0': 0.0,
 'conv2_block1_2_bn/gamma:0': 0.0,
 'conv2_block1_2_bn/beta:0': 0.0,
 'conv2_block1_2_bn/moving_mean:0': 0.0,
 'conv2_block1_2_bn/moving_variance:0': 0.0,
 'conv2_block1_0_conv/bias:0': 0.0,
 'conv2_block1_3_conv/bias:0': 0.0,
 'conv2_block2_preact_bn/gamma:0': 0.0,
 'conv2_block2_preact_bn/beta:0': 0.0,
 'conv2_block2_preact_bn/moving_mean:0': 0.0,
 'conv2_block2_preact_bn/moving_variance:0': 0.0,
 'conv2_block2_1_bn/gamma:0': 0.0,
 'conv2_block2_1_bn/beta:0': 0.0,
 'conv2_block2_1_bn/moving_mean:0': 0.0,
 'conv2_block2_1_bn/moving_variance:0': 0.0,
 'conv2_block2_2_bn/gamma:0': 0.0,
 'conv2_block2_2_bn/beta:0': 

[0;31mOut[1:9]: [0m
{'conv1_conv/bias:0': 0.0,
 'conv2_block1_preact_bn/gamma:0': 0.0,
 'conv2_block1_preact_bn/beta:0': 0.0,
 'conv2_block1_preact_bn/moving_mean:0': 0.0,
 'conv2_block1_preact_bn/moving_variance:0': 0.0,
 'conv2_block1_1_bn/gamma:0': 0.0,
 'conv2_block1_1_bn/beta:0': 0.0,
 'conv2_block1_1_bn/moving_mean:0': 0.0,
 'conv2_block1_1_bn/moving_variance:0': 0.0,
 'conv2_block1_2_bn/gamma:0': 0.0,
 'conv2_block1_2_bn/beta:0': 0.0,
 'conv2_block1_2_bn/moving_mean:0': 0.0,
 'conv2_block1_2_bn/moving_variance:0': 0.0,
 'conv2_block1_0_conv/bias:0': 0.0,
 'conv2_block1_3_conv/bias:0': 0.0,
 'conv2_block2_preact_bn/gamma:0': 0.0,
 'conv2_block2_preact_bn/beta:0': 0.0,
 'conv2_block2_preact_bn/moving_mean:0': 0.0,
 'conv2_block2_preact_bn/moving_variance:0': 0.0,
 'conv2_block2_1_bn/gamma:0': 0.0,
 'conv2_block2_1_bn/beta:0': 0.0,
 'conv2_block2_1_bn/moving_mean:0': 0.0,
 'conv2_block2_1_bn/moving_variance:0': 0.0,
 'conv2_block2_2_bn/gamma:0': 0.0,
 'conv2_block2_2_bn/beta:0': 

In [13]:
# %ipcluster stop