MXNet: support Gluon Trainer API #809

Merged
merged 6 commits into from Feb 7, 2019
68 changes: 65 additions & 3 deletions README.md
@@ -216,7 +216,11 @@ See full training [MNIST](examples/mxnet_mnist.py) and [ImageNet](examples/mxnet

**Note**: we recommend that users build MXNet from source following this [guide](https://mxnet.incubator.apache.org/install/build_from_source.html) when running Horovod with MXNet on a Linux OS with GCC version 5.X or above. The MXNet shared library distributed through the MXNet pip package is currently built with GCC 4.8.4. Building and installing Horovod with the MXNet pip package on a Linux OS with GCC 5.X+ will hit a segmentation fault due to the change in the std::function definition between GCC [4.X](https://github.com/gcc-mirror/gcc/blob/gcc-4_8_4-release/libstdc++-v3/include/std/functional#L2069) and GCC [5.X](https://github.com/gcc-mirror/gcc/blob/gcc-5_4_0-release/libstdc++-v3/include/std/functional#L1854).

There are two ways to train a model using MXNet: the [Gluon](http://mxnet.incubator.apache.org/api/python/gluon/gluon.html) API (preferred) and the [Module](http://mxnet.incubator.apache.org/api/python/module/module.html) API. Below we provide the building blocks for training a model with each API using MXNet with Horovod.

###### Gluon API
```python
from mxnet import autograd, gluon
import mxnet as mx
import horovod.mxnet as hvd

# Initialize Horovod
hvd.init()

# Pin GPU to be used to process local rank
context = mx.gpu(hvd.local_rank())
num_workers = hvd.size()

# Build model
model = ...
model.hybridize()

# Define hyper parameters
optimizer_params = ...

# Add Horovod Distributed Optimizer
opt = mx.optimizer.create('sgd', **optimizer_params)
opt = hvd.DistributedOptimizer(opt)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.initialize(initializer, ctx=context)

# Fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Create trainer and loss function
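# (kvstore=None: gradient aggregation is handled by Horovod's
# DistributedOptimizer, not by an MXNet KVStore)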
trainer = gluon.Trainer(params, opt, kvstore=None)
loss_fn = ...

# Train model
for epoch in range(num_epoch):
    train_data.reset()
    for nbatch, batch in enumerate(train_data, start=1):
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=[context],
                                          batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=[context],
                                           batch_axis=0)
        with autograd.record():
            outputs = [model(x.astype(dtype, copy=False)) for x in data]
            loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]
        for l in loss:
            l.backward()
        trainer.step(batch_size)
```
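To extend the Gluon example with checkpointing, a common Horovod pattern is to write files from the root worker only, so that concurrent workers do not overwrite each other's output. A minimal sketch, assuming `model` and `epoch` from the loop above (the file name is hypothetical):

```python
# Save a checkpoint once per epoch from rank 0 only; after
# broadcast_parameters all workers hold identical parameters,
# so one copy suffices.
if hvd.rank() == 0:
    model.save_parameters('model-%04d.params' % epoch)
```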

###### Module API
```python
import mxnet as mx
import horovod.mxnet as hvd

# Initialize Horovod
hvd.init()

# Pin GPU to be used to process local rank
context = mx.gpu(hvd.local_rank())
num_workers = hvd.size()

# Build model
model = ...

# Define hyper parameters
optimizer_params = ...

# Add Horovod Distributed Optimizer
opt = mx.optimizer.create('sgd', **optimizer_params)
opt = hvd.DistributedOptimizer(opt)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.bind(data_shapes=train_data.provide_data,
           label_shapes=train_data.provide_label)
model.init_params(initializer)

# Fetch and broadcast parameters
(arg_params, aux_params) = model.get_params()
if arg_params:
    hvd.broadcast_parameters(arg_params, root_rank=0)
if aux_params:
    hvd.broadcast_parameters(aux_params, root_rank=0)
model.set_params(arg_params=arg_params, aux_params=aux_params)

# Train model
model.fit(train_data,
          kvstore=None,
          optimizer=opt,
          num_epoch=num_epoch)

```
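In both examples, `optimizer_params = ...` is a placeholder. Because each training step now processes `hvd.size()` batches in aggregate, a common Horovod convention is to scale the single-worker learning rate by the number of workers. A minimal sketch with hypothetical hyperparameter values:

```python
# Hypothetical values: scale the base learning rate by the worker count to
# compensate for the larger effective batch size.
base_lr = 0.01
optimizer_params = {'learning_rate': base_lr * num_workers,
                    'momentum': 0.9,
                    'wd': 0.0001}
```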

## PyTorch