# Using AMP (Automatic Mixed Precision) in MXNet

Training deep learning networks is a computationally intensive task. Novel model architectures tend to have an increasing number of layers and parameters, which slows down training. Fortunately, new generations of training hardware, as well as software optimizations, make it a feasible task.

However, most of the optimization opportunities (in both hardware and software) lie in exploiting lower precision such as FP16, for example to use the Tensor Cores available on the new Volta and Turing GPUs. While training in FP16 has shown great success in image classification tasks, other, more complicated neural networks have typically stayed in FP32 because of the difficulty of applying the FP16 training guidelines.

That is where AMP (Automatic Mixed Precision) comes into play. It automatically applies the guidelines of FP16 training, using FP16 precision where it provides the most benefit, while conservatively keeping operations that are unsafe in FP16 in full FP32 precision.
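To see why some operations are unsafe in FP16, the short sketch below rounds Python floats to IEEE 754 half precision with the standard library's `struct` module (format `'e'`) and shows three typical failure modes: overflow, underflow of a small gradient value, and lost increments in a long accumulation. The helper `to_fp16` is hypothetical and exists only for this illustration; it is not part of MXNet.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    struct's 'e' format packs binary16. Values too large for FP16 raise
    OverflowError, which we map to +/-inf, as a hardware FP16 cast would.
    """
    try:
        return struct.unpack('<e', struct.pack('<e', x))[0]
    except OverflowError:
        return float('inf') if x > 0 else float('-inf')

# 1. Overflow: the largest finite FP16 value is 65504.
big = to_fp16(70000.0)

# 2. Underflow: a tiny gradient flushes to zero
#    (the smallest positive FP16 subnormal is about 6e-8).
tiny_grad = to_fp16(1e-8)

# 3. Lost precision in accumulation: at 2048 the FP16 spacing is 2,
#    so adding 1 (as inside a long sum reduction) is rounded away.
acc = to_fp16(2048.0)
acc = to_fp16(acc + 1.0)

print(big, tiny_grad, acc)  # inf 0.0 2048.0
```

Underflow is the hazard that loss scaling addresses, while overflow and accumulation error are why AMP keeps ops such as large reductions in FP32.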

This tutorial shows how to get started with mixed precision training using AMP for MXNet. As an example network, we will use the SSD network from GluonCV.

## Data loader and helper functions

For demonstration purposes, we will use a synthetic data loader.

```python
import logging
import time

import mxnet as mx
import mxnet.gluon as gluon
from mxnet import autograd
import gluoncv as gcv

data_shape = 512
batch_size = 8

# set up logger
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.info('Start training')

# running averages of the two SSD loss terms
ce_metric = mx.metric.Loss('CrossEntropy')
smoothl1_metric = mx.metric.Loss('SmoothL1')
```

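```python
class SyntheticDataLoader(object):
    def __init__(self, data_shape, batch_size):
        super(SyntheticDataLoader, self).__init__()
        self.counter = 0
        self.epoch_size = 200
        # SSD-512 with a ResNet-50 backbone produces 6132 anchors per image
        shape = (batch_size, 3, data_shape, data_shape)
        cls_targets_shape = (batch_size, 6132)
        box_targets_shape = (batch_size, 6132, 4)
        # one random batch in pinned host memory, returned at every step
        self.data = mx.nd.random.uniform(-1, 1, shape=shape, ctx=mx.cpu_pinned())
        self.cls_targets = mx.nd.random.uniform(0, 1, shape=cls_targets_shape, ctx=mx.cpu_pinned())
        self.box_targets = mx.nd.random.uniform(0, 1, shape=box_targets_shape, ctx=mx.cpu_pinned())

    def __iter__(self):
        return self

    def __next__(self):
        if self.counter >= self.epoch_size:
            # wrap the counter back to 0 so the loader can be reused next epoch
            self.counter = self.counter % self.epoch_size
            raise StopIteration
        self.counter += 1
        return [self.data, self.cls_targets, self.box_targets]

train_data = SyntheticDataLoader(data_shape, batch_size)
```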
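The loader relies on Python's iterator protocol: raising `StopIteration` ends the `for` loop, and resetting `counter` just before that lets the same object be iterated again in the next epoch. The sketch below mirrors that logic with plain lists instead of NDArrays so the behavior can be checked without a GPU; the class name `ToyLoader` and the tiny `epoch_size` are hypothetical, chosen for illustration.

```python
class ToyLoader(object):
    """Pure-Python stand-in for SyntheticDataLoader (hypothetical)."""

    def __init__(self, epoch_size):
        self.counter = 0
        self.epoch_size = epoch_size
        self.batch = [[0.0] * 4, [1.0] * 4]  # fixed fake (data, target) pair

    def __iter__(self):
        return self

    def __next__(self):
        if self.counter >= self.epoch_size:
            # reset before signalling the end, so the next epoch starts fresh
            self.counter = 0
            raise StopIteration
        self.counter += 1
        return self.batch

loader = ToyLoader(epoch_size=3)
epoch1 = sum(1 for _ in loader)  # batches seen in the first epoch
epoch2 = sum(1 for _ in loader)  # the same object yields a full second epoch
print(epoch1, epoch2)  # 3 3
```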

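```python
import warnings

from gluoncv.model_zoo import get_model

# training contexts: a single GPU (also used by the training loops below)
ctx = [mx.gpu(0)]

def get_network():
    # SSD with a ResNet-50 backbone (ImageNet-pretrained base network)
    net_name = 'ssd_512_resnet50_v1_coco'
    net = get_model(net_name, pretrained_base=True, norm_layer=gluon.nn.BatchNorm)

    # initialize the SSD heads; the pretrained base is already initialized,
    # so silence the resulting "already initialized" warnings
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        net.initialize()
    # move all parameters to the training context(s)
    net.collect_params().reset_ctx(ctx)
    return net
```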

## Training in FP32

```python
net = get_network()
net.hybridize(static_alloc=True, static_shape=True)

# example SGD hyperparameters
trainer = gluon.Trainer(
    net.collect_params(), 'sgd',
    {'learning_rate': 0.001, 'wd': 0.0005, 'momentum': 0.9})

mbox_loss = gcv.loss.SSDMultiBoxLoss()

for epoch in range(1):
    ce_metric.reset()
    smoothl1_metric.reset()
    tic = time.time()
    btic = time.time()

    for i, batch in enumerate(train_data):
        batch_size = batch[0].shape[0]
        # split each batch across the training contexts
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
        with autograd.record():
            cls_preds = []
            box_preds = []
            for x in data:
                cls_pred, box_pred, _ = net(x)
                cls_preds.append(cls_pred)
                box_preds.append(box_pred)
            sum_loss, cls_loss, box_loss = mbox_loss(
                cls_preds, box_preds, cls_targets, box_targets)
            autograd.backward(sum_loss)
        trainer.step(1)
        ce_metric.update(0, [l * batch_size for l in cls_loss])
        smoothl1_metric.update(0, [l * batch_size for l in box_loss])
        if not (i + 1) % 50:
            name1, loss1 = ce_metric.get()
            name2, loss2 = smoothl1_metric.get()
            logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))
        btic = time.time()

    name1, loss1 = ce_metric.get()
    name2, loss2 = smoothl1_metric.get()
    logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.format(
        epoch, (time.time()-tic), name1, loss1, name2, loss2))
```

## Training with AMP

### AMP initialization

In order to start using AMP, we need to import and initialize it. This has to happen before we create the network.

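```python
try:
    from mxnet import amp            # location in newer MXNet versions
except ImportError:
    from mxnet.contrib import amp    # location in MXNet 1.x releases

# must run before the network is created
amp.init()
```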

After that, we can create the network exactly the same way we did in FP32 training.

```python
net = get_network()
net.hybridize(static_alloc=True, static_shape=True)
```

For some models that may be enough to start training in mixed precision, but the full FP16 recipe recommends using dynamic loss scaling to guard against overflow and underflow of FP16 values. Therefore, as a next step, we create a trainer and initialize it with support for AMP's dynamic loss scaling. Currently, dynamic loss scaling is supported only for trainers created with the `update_on_kvstore=False` option, so we add it to our trainer initialization.

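```python
# example SGD hyperparameters; update_on_kvstore=False is required by AMP
trainer = gluon.Trainer(
    net.collect_params(), 'sgd',
    {'learning_rate': 0.001, 'wd': 0.0005, 'momentum': 0.9},
    update_on_kvstore=False)

# attach AMP's dynamic loss scaling to the trainer
amp.init_trainer(trainer)
```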
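Conceptually, dynamic loss scaling multiplies the loss by a scale factor before backpropagation, so that small FP16 gradients do not flush to zero, and divides the gradients by the same factor before the update. The factor is adjusted on the fly: if any gradient overflows to inf or NaN, the step is skipped and the scale is reduced; after a run of overflow-free steps, the scale is increased again. The class below is a generic plain-Python sketch of that policy; its name and default values (start at 2**16, factor 2, growth every 2000 clean steps) are assumptions for illustration, not MXNet's internals.

```python
import math

class ToyDynamicLossScaler(object):
    """Illustrative dynamic loss scaling policy (hypothetical, not MXNet's)."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000, factor=2.0):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.factor = factor
        self.good_steps = 0  # consecutive steps without overflow

    def has_overflow(self, grads):
        return any(math.isinf(g) or math.isnan(g) for g in grads)

    def update(self, grads):
        """Adjust the scale; return True if the optimizer step should run."""
        if self.has_overflow(grads):
            self.scale /= self.factor  # back off and skip this step
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            self.scale *= self.factor  # probe a larger scale again
            self.good_steps = 0
        return True

scaler = ToyDynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.update([1.0, float('inf')]), scaler.scale)  # False 4.0
print(scaler.update([0.5]), scaler.scale)                # True 4.0
print(scaler.update([0.5]), scaler.scale)                # True 8.0
```

Halving on overflow and growing only slowly keeps the scale near the largest value that does not overflow, which is what preserves the most small gradient values.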

### Dynamic loss scaling in the training loop

The last step is to apply dynamic loss scaling in the training loop. We can achieve that by wrapping the backward pass in the `amp.scale_loss` context manager.

```python
mbox_loss = gcv.loss.SSDMultiBoxLoss()

for epoch in range(1):
    ce_metric.reset()
    smoothl1_metric.reset()
    tic = time.time()
    btic = time.time()

    for i, batch in enumerate(train_data):
        batch_size = batch[0].shape[0]
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
        cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
        box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
        with autograd.record():
            cls_preds = []
            box_preds = []
            for x in data:
                cls_pred, box_pred, _ = net(x)
                cls_preds.append(cls_pred)
                box_preds.append(box_pred)
            sum_loss, cls_loss, box_loss = mbox_loss(
                cls_preds, box_preds, cls_targets, box_targets)
            # backpropagate the scaled loss instead of the raw loss
            with amp.scale_loss(sum_loss, trainer) as scaled_loss:
                autograd.backward(scaled_loss)
        trainer.step(1)
        ce_metric.update(0, [l * batch_size for l in cls_loss])
        smoothl1_metric.update(0, [l * batch_size for l in box_loss])
        if not (i + 1) % 50:
            name1, loss1 = ce_metric.get()
            name2, loss2 = smoothl1_metric.get()
            logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
                epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))
        btic = time.time()

    name1, loss1 = ce_metric.get()
    name2, loss2 = smoothl1_metric.get()
    logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.format(
        epoch, (time.time()-tic), name1, loss1, name2, loss2))
```
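Roughly, `amp.scale_loss` yields the loss multiplied by the current loss scale, so the backward pass produces gradients multiplied by that scale, and the AMP-initialized trainer divides them by the same factor before applying the update. The toy below reproduces that bookkeeping in plain Python; `ToyTrainer`, `toy_scale_loss`, the fixed scale of 1024 and the hand-computed gradient standing in for `autograd.backward` are all assumptions for illustration. It leaves out the overflow check and scale adjustment that the real AMP trainer also performs.

```python
import contextlib

class ToyTrainer(object):
    """Hypothetical stand-in for gluon.Trainer: plain SGD on a list of floats."""

    def __init__(self, params, lr, loss_scale=1024.0):
        self.params = params
        self.lr = lr
        self.loss_scale = loss_scale  # assumed current dynamic loss scale
        self.rescale = 1.0            # multiplier applied to raw gradients

    def step(self, grads):
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g * self.rescale

@contextlib.contextmanager
def toy_scale_loss(loss, trainer):
    """Yield the scaled loss and make the trainer undo the scale in step()."""
    trainer.rescale = 1.0 / trainer.loss_scale
    yield loss * trainer.loss_scale

trainer = ToyTrainer(params=[3.0], lr=0.5)
w = trainer.params[0]
with toy_scale_loss(0.5 * w ** 2, trainer) as scaled_loss:
    # stand-in for autograd.backward(scaled_loss):
    # for loss = w**2 / 2, d(scale * loss)/dw = scale * w
    scaled_grads = [trainer.loss_scale * w]
trainer.step(scaled_grads)

print(scaled_loss)     # 4608.0
print(trainer.params)  # [1.5], same as plain SGD: 3.0 - 0.5 * 3.0
```

Because the unscaling happens inside the update, the parameter change matches unscaled training whenever no overflow occurs.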