In [1]:
import os
import logging
from multiprocessing import cpu_count
import tensorflow as tf
from utils.dist_utils import is_sm_dist
from models import resnet, darknet, hrnet
from engine.schedulers import WarmupScheduler
from engine.optimizers import MomentumOptimizer
from datasets import create_dataset, parse
from engine.trainer import Trainer
# Choose the collective-communication backend, then initialise it. The
# SageMaker data-parallel library is used when `is_sm_dist()` reports a
# SageMaker distributed launch (presumably); Horovod is the fallback. Both are
# bound to the same name `dist` and expose the calls used later in this
# notebook (init, rank, size, local_rank, local_size, allreduce).
if not is_sm_dist():
    import horovod.tensorflow as dist
else:
    import smdistributed.dataparallel.tensorflow as dist
dist.init()

In [2]:
# Performance switches read by the runtime-configuration cell (TF32 math,
# XLA JIT, auto mixed precision) and by the optimizer/Trainer cells (fp16).
tf32, xla, fp16 = True, True, True

In [3]:
# Global TensorFlow runtime switches driven by the flags cell.
tf.config.experimental.enable_tensor_float_32_execution(tf32)

# Thread pools must be configured through the setter functions: assigning to
# `tf.config.threading.<name>` merely rebinds a module attribute and leaves the
# runtime untouched. The setters must run before the TF runtime is initialised
# (the `set_memory_growth` calls in this same cell carry the same requirement).
tf.config.threading.set_intra_op_parallelism_threads(1)  # Avoid pool of Eigen threads
# Each local rank gets its share of the host's cores minus 2, with a floor of 2.
tf.config.threading.set_inter_op_parallelism_threads(
    max(2, cpu_count() // dist.local_size() - 2))

tf.config.optimizer.set_jit(xla)
# Grappler graph-rewrite automatic mixed precision, toggled by `fp16`.
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": fp16})

# Enable on-demand memory allocation on every GPU, then restrict this process
# to the single GPU whose index equals its local rank.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], 'GPU')

In [4]:
# Locations: overridable through environment variables so the notebook runs on
# machines without this exact directory layout; defaults match the original run.
train_data_dir = os.environ.get('TRAIN_DATA_DIR', '/home/ubuntu/data/imagenet/tfrecord/train/')
validation_data_dir = os.environ.get('VALIDATION_DATA_DIR', '/home/ubuntu/data/imagenet/tfrecord/validation/')
model_dir = os.environ.get('MODEL_DIR', '/home/ubuntu/models')

# Dataset / model shape.
train_dataset_size = 1281167  # ImageNet-1k training-set image count
num_classes = 1000

# Optimisation hyper-parameters.
batch_size = 128  # per replica: the global batch is batch_size * dist.size()
num_epochs = 125
# NOTE(review): `schedule` is not referenced in the visible cells; the cosine
# schedule is constructed explicitly in the model cell.
schedule = 'cosine'
learning_rate = 0.01
momentum = 0.9
label_smoothing = 0.1
l2_weight_decay = 1e-5
mixup_alpha = 0.2

# Derived step counts (integer division drops the final partial global batch).
steps_per_epoch = train_dataset_size // (batch_size * dist.size())
assert steps_per_epoch > 0, (
    "global batch (batch_size * dist.size()) exceeds train_dataset_size; "
    "iterations would be 0 and the LR schedule would be degenerate")
iterations = steps_per_epoch * num_epochs

In [5]:
# ResNet-152 V1-d with L2 weight decay; no pretrained weights requested.
model = resnet.ResNet152V1_d(weights=None, weight_decay=l2_weight_decay, classes=num_classes)

# Cosine decay with restarts where t_mul = m_mul = 1, i.e. a constant restart
# period of `iterations` steps, wrapped in a warmup that starts at
# learning_rate / 10 for the first 500 steps (exact ramp defined by
# WarmupScheduler).
scheduler = WarmupScheduler(
    scheduler=tf.keras.experimental.CosineDecayRestarts(
        initial_learning_rate=learning_rate,
        first_decay_steps=iterations,
        t_mul=1,
        m_mul=1,
    ),
    initial_learning_rate=learning_rate / 10,
    warmup_steps=500,
)

opt = MomentumOptimizer(learning_rate=scheduler, momentum=momentum, nesterov=True)
if fp16:
    # Loss scaling keeps small fp16 gradients from underflowing.
    opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)

# Label-smoothed cross-entropy on raw logits, averaged over the batch.
loss_func = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True,
    label_smoothing=label_smoothing,
    reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
)

In [6]:
# Only rank 0 writes the training log (appended to <model_dir>/log.csv).
if dist.rank() == 0:
    path_logs = os.path.join(os.getcwd(), model_dir, 'log.csv')
    os.makedirs(model_dir, exist_ok=True)
    logging.basicConfig(filename=path_logs,
                        filemode='a',
                        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                        datefmt='%H:%M:%S',
                        level=logging.DEBUG)
    logging.info("Training Logs")

# Bind `logger` on every rank: the Trainer cell passes it unconditionally, so
# defining it only on rank 0 raised NameError on all other ranks. Non-zero
# ranks get no file handler from the block above.
logger = logging.getLogger('logger')

# Barrier: every rank must join this collective, presumably so no rank runs
# ahead before rank 0 has finished the setup above.
_ = dist.allreduce(tf.constant(0))

In [7]:
# Build both input pipelines; they differ only in source directory and the
# `validation` flag. The training pipeline is created first, as before.
train_data, validation_data = (
    create_dataset(data_dir, batch_size, preprocessing='resnet', validation=is_validation)
    for data_dir, is_validation in ((train_data_dir, False), (validation_data_dir, True))
)

In [8]:
# Pass the configured `model_dir` rather than a separate literal '~/models/'.
# The literal silently diverged from the config cell (where log.csv is
# written), and unless Trainer calls os.path.expanduser itself, '~' would be
# taken as a literal directory name.
trainer = Trainer(model, opt, loss_func, scheduler, logging=logger, fp16=fp16,
                  mixup_alpha=mixup_alpha, model_dir=model_dir)

In [9]:
# One pass over the training pipeline per epoch; the epoch index itself is
# not needed. `iterations` in the config cell assumes num_epochs such passes.
for _ in range(num_epochs):
    trainer.train_epoch(train_data)

Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.
step: 0, step time: 1.9200, train_loss: 8.4951, top_1_accuracy: 0.0000, learning_rate: 0.0010
step: 50, step time: 1.3894, train_loss: 8.5024, top_1_accuracy: 0.0000, learning_rate: 0.0019
step: 100, step time: 0.1898, train_loss: 8.4982, top_1_accuracy: 0.0000, learning_rate: 0.0028
step: 150, step time: 0.1892, train_loss: 8.4859, top_1_accuracy: 0.0000, learning_rate: 0.0037
step: 200, step time: 0.1888, train_loss: 8.5289, top_1_accuracy: 0.0078, learning_rate: 0.0046
step: 250, step time: 0.1880, train_loss: 8.4823, t

KeyboardInterrupt: 

In [None]:
# Evaluate the current weights on the validation pipeline. `output_name` is
# presumably used by Trainer to label this evaluation's outputs.
# NOTE(review): the label is hard-coded to 'epoch_1' regardless of how many
# epochs the training cell completed; confirm this is intended.
trainer.validation_epoch(validation_data, output_name='epoch_1')