diff --git a/README.md b/README.md
index 6a9654b17e..10f7650921 100644
--- a/README.md
+++ b/README.md
@@ -216,7 +216,11 @@ See full training [MNIST](examples/mxnet_mnist.py) and [ImageNet](examples/mxnet
 
 **Note**: we recommend that users build MXNet from source following this [guide](https://mxnet.incubator.apache.org/install/build_from_source.html) when running Horovod with MXNet on a Linux OS with GCC version 5.X and above. The MXNet shared library distributed through the MXNet pip package is currently built using GCC 4.8.4. If we build and install Horovod on a Linux OS with GCC 5.X+ together with the MXNet pip package, we will hit a segmentation fault due to the std::function definition change from GCC [4.X](https://github.com/gcc-mirror/gcc/blob/gcc-4_8_4-release/libstdc++-v3/include/std/functional#L2069) to GCC [5.X](https://github.com/gcc-mirror/gcc/blob/gcc-5_4_0-release/libstdc++-v3/include/std/functional#L1854).
 
+There are two ways to train a model using MXNet: the [Gluon](http://mxnet.incubator.apache.org/api/python/gluon/gluon.html) API (preferred) and the [Module](http://mxnet.incubator.apache.org/api/python/module/module.html) API. Here we provide the building blocks for training a model with each API using MXNet and Horovod.
+
+###### Gluon API
 ```python
+from mxnet import autograd, gluon
 import mxnet as mx
 import horovod.mxnet as hvd
 
@@ -229,14 +233,73 @@ num_workers = hvd.size()
 
 # Build model
 model = ...
+model.hybridize()
 
 # Define hyper parameters
 optimizer_params = ...
 
 # Add Horovod Distributed Optimizer
-opt = mx.optimizer.create('sgd', sym=model, **optimizer_params)
+opt = mx.optimizer.create('sgd', **optimizer_params)
 opt = hvd.DistributedOptimizer(opt)
 
+# Initialize parameters
+model.initialize(initializer, ctx=context)
+
+# Fetch and broadcast parameters
+params = model.collect_params()
+if params is not None:
+    hvd.broadcast_parameters(params, root_rank=0)
+
+# Create trainer and loss function
+trainer = gluon.Trainer(params, opt, kvstore=None)
+loss_fn = ...
+
+# Train model
+for epoch in range(num_epoch):
+    train_data.reset()
+    for nbatch, batch in enumerate(train_data, start=1):
+        data = gluon.utils.split_and_load(batch.data[0], ctx_list=[context],
+                                          batch_axis=0)
+        label = gluon.utils.split_and_load(batch.label[0], ctx_list=[context],
+                                           batch_axis=0)
+        with autograd.record():
+            outputs = [model(x.astype(dtype, copy=False)) for x in data]
+            loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]
+        for l in loss:
+            l.backward()
+        trainer.step(batch_size)
+```
+
+###### Module API
+```python
+import mxnet as mx
+import horovod.mxnet as hvd
+
+# Initialize Horovod
+hvd.init()
+
+# Pin GPU to be used to process local rank
+context = mx.gpu(hvd.local_rank())
+num_workers = hvd.size()
+
+# Build model
+model = ...
+
+# Define hyper parameters
+optimizer_params = ...
+ +# Add Horovod Distributed Optimizer +opt = mx.optimizer.create('sgd', **optimizer_params) +opt = hvd.DistributedOptimizer(opt) + +# Initialize parameters +initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", + magnitude=2) +model.bind(data_shapes=train_data.provide_data, + label_shapes=train_data.provide_label) +model.init_params(initializer) + +# Fetch and broadcast parameters (arg_params, aux_params) = model.get_params() if arg_params: hvd.broadcast_parameters(arg_params, root_rank=0) @@ -246,10 +309,9 @@ model.set_params(arg_params=arg_params, aux_params=aux_params) # Train model model.fit(train_data, + kvstore=None, optimizer=opt, - opitmizer_params=optimizer_params, num_epoch=num_epoch) - ``` ## PyTorch diff --git a/examples/mxnet_imagenet_resnet50.py b/examples/mxnet_imagenet_resnet50.py index 1f28f61e43..69fe0aea0a 100644 --- a/examples/mxnet_imagenet_resnet50.py +++ b/examples/mxnet_imagenet_resnet50.py @@ -19,13 +19,13 @@ import logging import math import os +import time from gluoncv.model_zoo import get_model import horovod.mxnet as hvd import mxnet as mx import numpy as np -from mxnet import gluon -from mxnet import lr_scheduler +from mxnet import autograd, gluon, lr_scheduler from mxnet.io import DataBatch, DataIter @@ -35,7 +35,7 @@ parser.add_argument('--use-rec', action='store_true', default=False, help='use image record iter for data input (default: False)') parser.add_argument('--data-nthreads', type=int, default=2, - help='number of threads for data decoding') + help='number of threads for data decoding (default: 2)') parser.add_argument('--rec-train', type=str, default='', help='the training data') parser.add_argument('--rec-train-idx', type=str, default='', @@ -49,7 +49,7 @@ parser.add_argument('--dtype', type=str, default='float32', help='data type for training (default: float32)') parser.add_argument('--num-epochs', type=int, default=90, - help='number of training epochs.') + help='number of training epochs (default: 90)') parser.add_argument('--lr', type=float, default=0.05, help='learning rate for a single GPU (default: 0.05)') parser.add_argument('--momentum', type=float, default=0.9, @@ -58,12 +58,11 @@ help='weight decay rate (default: 0.0001)') parser.add_argument('--lr-mode', type=str, default='poly', help='learning rate scheduler mode. Options are step, \ - poly and cosine. (default: poly)') + poly and cosine (default: poly)') parser.add_argument('--lr-decay', type=float, default=0.1, help='decay rate of learning rate (default: 0.1)') parser.add_argument('--lr-decay-epoch', type=str, default='40,60', - help='epoches at which learning rate decays \ - (default is : 40,60)') + help='epoches at which learning rate decays (default: 40,60)') parser.add_argument('--warmup-lr', type=float, default=0.0, help='starting warmup learning rate (default: 0.0)') parser.add_argument('--warmup-epochs', type=int, default=10, @@ -73,16 +72,23 @@ each bottleneck to 0 (default: False)') parser.add_argument('--model', type=str, default='resnet50_v1', help='type of model to use. see vision_model for options.') +parser.add_argument('--mode', type=str, default='module', + help='mode in which to train the model. 
options are \ + module, gluon (default: module)') parser.add_argument('--use-pretrained', action='store_true', default=False, help='load pretrained model weights (default: False)') -parser.add_argument('--eval-epoch', action='store_true', default=False, - help='evaluate validation accuracy after each epoch (default: False)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training (default: False)') +parser.add_argument('--eval-epoch', action='store_true', default=False, + help='evaluate validation accuracy after each epoch \ + when training in module mode (default: False)') +parser.add_argument('--eval-frequency', type=int, default=0, + help='frequency of evaluating validation accuracy \ + when training with gluon mode (default: 0)') parser.add_argument('--log-interval', type=int, default=0, help='number of batches to wait before logging (default: 0)') -parser.add_argument('--save-frequency', type=int, default=10, - help='frequency of model saving. (default: 10)') +parser.add_argument('--save-frequency', type=int, default=0, + help='frequency of model saving (default: 0)') args = parser.parse_args() @@ -130,13 +136,6 @@ else: raise ValueError('Invalid lr mode') -# Horovod: pin GPU to local rank -context = mx.cpu() if args.no_cuda else mx.gpu(local_rank) -kwargs = {'ctx': context, 'pretrained': args.use_pretrained, - 'classes': num_classes} -if args.last_gamma: - kwargs['last_gamma'] = True - # Function for reading data from record file # For more details about data loading in MXNet, please refer to # https://mxnet.incubator.apache.org/tutorials/basic/data.html?highlight=imagerecorditer @@ -204,7 +203,6 @@ def batch_fn(batch, ctx): return train_data, val_data, batch_fn - # Create data iterator for synthetic data class SyntheticDataIter(DataIter): def __init__(self, num_classes, data_shape, max_iter, dtype, ctx): @@ -249,6 +247,8 @@ def __next__(self): def reset(self): self.cur_iter = 0 +# Horovod: pin GPU to local rank +context = mx.cpu(local_rank) if args.no_cuda else mx.gpu(local_rank) if args.use_rec: # Fetch training and validation data if present @@ -267,12 +267,116 @@ def reset(self): val_data = None -def train(): - # Get model from GluonCV model zoo - # https://gluon-cv.mxnet.io/model_zoo/index.html - net = get_model(args.model, **kwargs) - net.cast(args.dtype) +# Get model from GluonCV model zoo +# https://gluon-cv.mxnet.io/model_zoo/index.html +kwargs = {'ctx': context, + 'pretrained': args.use_pretrained, + 'classes': num_classes} +if args.last_gamma: + kwargs['last_gamma'] = True +net = get_model(args.model, **kwargs) +net.cast(args.dtype) + +# Create initializer +initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", + magnitude=2) + +# Create optimizer +optimizer_params = {'wd': args.wd, + 'momentum': args.momentum, + 'rescale_grad': 1.0 / batch_size, + 'lr_scheduler': lr_sched} +if args.dtype == 'float16': + optimizer_params['multi_precision'] = True +opt = mx.optimizer.create('sgd', **optimizer_params) + +# Horovod: wrap optimizer with DistributedOptimizer +opt = hvd.DistributedOptimizer(opt) + + +def train_gluon(): + def evaluate(epoch): + if not args.use_rec: + return + + val_data.reset() + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + for _, batch in enumerate(val_data): + data, label = batch_fn(batch, [context]) + outputs = [net(x.astype(args.dtype, copy=False)) for x in data] + acc_top1.update(label, outputs) + acc_top5.update(label, outputs) + + top1_name, top1_acc = acc_top1.get() + 
top5_name, top5_acc = acc_top5.get() + logging.info('Epoch[%d] Rank[%d]\tValidation-%s=%f\tValidation-%s=%f', + epoch, rank, top1_name, top1_acc, top5_name, top5_acc) + + # Hybridize and initialize model + net.hybridize() + net.initialize(initializer, ctx=context) + + # Horovod: fetch and broadcast parameters + params = net.collect_params() + if params is not None: + hvd.broadcast_parameters(params, root_rank=0) + + # Create trainer, loss function and train metric + trainer = gluon.Trainer(params, opt, kvstore=None) + loss_fn = gluon.loss.SoftmaxCrossEntropyLoss() + metric = mx.metric.Accuracy() + # Train model + for epoch in range(args.num_epochs): + tic = time.time() + if args.use_rec: + train_data.reset() + metric.reset() + + btic = time.time() + for nbatch, batch in enumerate(train_data, start=1): + data, label = batch_fn(batch, [context]) + with autograd.record(): + outputs = [net(x.astype(args.dtype, copy=False)) for x in data] + loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)] + for l in loss: + l.backward() + trainer.step(batch_size) + + metric.update(label, outputs) + if args.log_interval and nbatch % args.log_interval == 0: + name, acc = metric.get() + logging.info('Epoch[%d] Rank[%d] Batch[%d]\t%s=%f\tlr=%f', + epoch, rank, nbatch, name, acc, trainer.learning_rate) + if rank == 0: + batch_speed = num_workers * batch_size * args.log_interval / (time.time() - btic) + logging.info('Epoch[%d] Batch[%d]\tSpeed: %.2f samples/sec', + epoch, nbatch, batch_speed) + btic = time.time() + + # Report metrics + elapsed = time.time() - tic + _, acc = metric.get() + logging.info('Epoch[%d] Rank[%d] Batch[%d]\tTime cost=%.2f\tTrain-accuracy=%f', + epoch, rank, nbatch, elapsed, acc) + if rank == 0: + epoch_speed = num_workers * batch_size * nbatch / elapsed + logging.info('Epoch[%d]\tSpeed: %.2f samples/sec', epoch, epoch_speed) + + # Evaluate performance + if args.eval_frequency and (epoch + 1) % args.eval_frequency == 0: + evaluate(epoch) + + # Save model + if args.save_frequency and (epoch + 1) % args.save_frequency == 0: + net.export('%s-%d' % (args.model, rank), epoch=epoch) + + # Evaluate performance at the end of training + evaluate(epoch) + + +def train_module(): # Create input symbol data = mx.sym.var('data') if args.dtype == 'float16': @@ -285,6 +389,10 @@ def train(): out = mx.sym.Cast(data=out, dtype=np.float32) softmax = mx.sym.SoftmaxOutput(out, name='softmax') + # Create model + mod = mx.mod.Module(softmax, context=context) + + # Initialize parameters if args.use_pretrained: arg_params = {} for x in net.collect_params().values(): @@ -293,25 +401,6 @@ def train(): else: arg_params = None aux_params = None - - # Create model - mod = mx.mod.Module(softmax, context=context) - - # Create optimizer - optimizer_params = {'wd': args.wd, - 'momentum': args.momentum, - 'rescale_grad': 1.0 / batch_size, - 'lr_scheduler': lr_sched} - if args.dtype == 'float16': - optimizer_params['multi_precision'] = True - opt = mx.optimizer.create('sgd', sym=out, **optimizer_params) - - # Horovod: wrap optimizer with DistributedOptimizer - opt = hvd.DistributedOptimizer(opt) - - # Create initializer and initializer parameters - initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", - magnitude=2) mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) mod.init_params(initializer, arg_params=arg_params, aux_params=aux_params) @@ -329,9 +418,10 @@ def train(): if args.eval_epoch: eval_data = val_data batch_callback = None - if args.log_interval > 0: - 
batch_callback = mx.callback.Speedometer(batch_size,
-                                                 max(1, args.log_interval))
+    if args.log_interval > 0 and rank == 0:
+        batch_callback = mx.callback.Speedometer(batch_size * num_workers,
+                                                 args.log_interval)
+
     epoch_callback = None
     if args.save_frequency > 0:
         epoch_callback = mx.callback.do_checkpoint(
@@ -345,8 +435,7 @@
             kvstore=None,
             batch_end_callback=batch_callback,
             epoch_end_callback=epoch_callback,
-            optimizer=opt,
-            optimizer_params=optimizer_params)
+            optimizer=opt)
 
     # Evaluate performance if not using synthetic data
     if args.use_rec:
@@ -359,4 +448,9 @@
 
 
 if __name__ == '__main__':
-    train()
+    if args.mode == 'module':
+        train_module()
+    elif args.mode == 'gluon':
+        train_gluon()
+    else:
+        raise ValueError('Invalid training mode.')
diff --git a/examples/mxnet_mnist.py b/examples/mxnet_mnist.py
index 2607f17f09..eac8ca3524 100644
--- a/examples/mxnet_mnist.py
+++ b/examples/mxnet_mnist.py
@@ -125,6 +125,7 @@ def conv_net():
 model.set_params(arg_params=arg_params, aux_params=aux_params)
 
 model.fit(train_iter,  # train data
+          kvstore=None,  # no kvstore
           eval_data=val_iter,  # validation data
           optimizer=opt,  # use SGD to train
           eval_metric='acc',  # report accuracy during training
diff --git a/horovod/mxnet/__init__.py b/horovod/mxnet/__init__.py
index 2e9d0fb449..5250d7f95f 100644
--- a/horovod/mxnet/__init__.py
+++ b/horovod/mxnet/__init__.py
@@ -71,30 +71,34 @@ def set_wd_mult(self, args_wd_mult):
 def broadcast_parameters(params, root_rank=0):
     """
     Broadcasts the parameters from root rank to all other processes.
-    Typical usage is to broadcast the `model.get_params()`.
+    Typical usage is to broadcast the `Module.get_params()` or the
+    `Block.collect_params()`.
 
     Arguments:
         params: One of the following:
-            - list of parameters to broadcast
             - dict of parameters to broadcast
+            - ParameterDict to broadcast
         root_rank: The rank of the process from which parameters will be
                    broadcasted to all other processes.
     """
+    tensors = []
     if isinstance(params, dict):
-        params = sorted(params.items())
-    elif isinstance(params, list):
-        # support both named_parameters() and regular parameters()
-        params = [p if isinstance(p, tuple) else (None, p) for p in params]
+        tensors = [p for _, p in sorted(params.items())]
+    elif isinstance(params, mx.gluon.parameter.ParameterDict):
+        for _, p in sorted(params.items()):
+            try:
+                tensors.append(p.data())
+            except mx.gluon.parameter.DeferredInitializationError:
+                # skip broadcasting deferred init param
+                pass
     else:
         raise ValueError('invalid params of type: %s' % type(params))
 
     # Run broadcasts.
-    count = 0
-    for _, p in params:
-        broadcast_(p, root_rank, str(count))
-        count += 1
+    for i, tensor in enumerate(tensors):
+        broadcast_(tensor, root_rank, str(i))
 
     # Make sure tensors pushed to MXNet engine get processed such that all
     # workers are synced before starting training.
-    for _, p in params:
-        p.wait_to_read()
+    for tensor in tensors:
+        tensor.wait_to_read()
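For context, the sketch below shows how the updated `hvd.broadcast_parameters` is intended to be used from Gluon after this change: the `ParameterDict` returned by `Block.collect_params()` is broadcast from rank 0 before the optimizer is wrapped with `hvd.DistributedOptimizer` and handed to a `gluon.Trainer` with `kvstore=None`. The tiny network, hyperparameter values, and random batch are illustrative assumptions only, not part of the patch; see `examples/mxnet_imagenet_resnet50.py` for the full training script.

```python
import mxnet as mx
from mxnet import autograd, gluon
import horovod.mxnet as hvd

# Initialize Horovod and pin this process to its local GPU
# (assumes one GPU per process; use mx.cpu() when running without CUDA)
hvd.init()
context = mx.gpu(hvd.local_rank())

# A small illustrative model; any Gluon Block works the same way.
# in_units is given explicitly so the parameters are materialized before the
# broadcast (deferred-init parameters are skipped by broadcast_parameters).
net = gluon.nn.Dense(10, in_units=20)
net.initialize(mx.init.Xavier(), ctx=context)

# Broadcast the ParameterDict from rank 0 so every worker starts from the
# same weights
params = net.collect_params()
hvd.broadcast_parameters(params, root_rank=0)

# Wrap the optimizer so gradients are averaged across workers, and pass it to
# a Trainer with kvstore=None (Horovod takes the place of the kvstore)
opt = mx.optimizer.create('sgd', learning_rate=0.01)
opt = hvd.DistributedOptimizer(opt)
trainer = gluon.Trainer(params, opt, kvstore=None)

# One illustrative training step on random data
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
x = mx.nd.random.uniform(shape=(32, 20), ctx=context)
y = mx.nd.zeros((32,), ctx=context)  # dummy class labels
with autograd.record():
    loss = loss_fn(net(x), y)
loss.backward()
trainer.step(32)
```

As with the existing examples, each process would run a script like this under MPI, one process per GPU.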