In [None]:
from fastai.gen_doc.gen_notebooks import update_module_metadata
import fastai.callbacks.mixup
# For updating jekyll metadata. You MUST reload notebook immediately after executing this cell for changes to save
# Leave blank to autopopulate from mod.__doc__
update_module_metadata(fastai.callbacks.mixup, title='callbacks.mixup', summary="Implementation of mixup")

# Mixup data augmentation

In [None]:
from fastai.gen_doc.nbdoc import *
from fastai.callbacks.mixup import * 
from fastai.docs import *

## What is Mixup?

This module contains the implementation of a data augmentation technique called [Mixup](https://arxiv.org/abs/1710.09412). It is extremely efficient at regularizing models in computer vision (we used it to bring our time to train CIFAR10 to 94% accuracy on one GPU down to 6 minutes).

As the name suggests, the authors of the mixup article propose training the model on a mix of the pictures of the training set. Say we’re working on CIFAR10: instead of feeding the model the raw images, we take two of them (which may or may not be in the same class) and take a linear combination of them. In terms of tensors, it’s

`new_image = t * image1 + (1-t) * image2`

where `t` is a float between 0 and 1. Then the target we assign to that image is the same combination of the original targets:

`new_target = t * target1 + (1-t) * target2`

assuming your targets are one-hot encoded (which usually isn’t the case in PyTorch). And it’s as simple as that.
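As a minimal sketch of the two formulas above, here they are in NumPy on two made-up 2×2 "images" with one-hot targets for a two-class problem. The arrays, the helper name `mixup_pair` and the value `t = 0.7` are illustrative assumptions, not part of the library:

```python
import numpy as np

def mixup_pair(image1, image2, target1, target2, t):
    """Linearly combine two images and their one-hot targets with weight t."""
    new_image = t * image1 + (1 - t) * image2
    new_target = t * target1 + (1 - t) * target2
    return new_image, new_target

# Two toy 2x2 "images" and one-hot targets (class 0 = dog, class 1 = cat).
dog = np.ones((2, 2))
cat = np.zeros((2, 2))
dog_target = np.array([1.0, 0.0])
cat_target = np.array([0.0, 1.0])

mixed_image, mixed_target = mixup_pair(dog, cat, dog_target, cat_target, t=0.7)
print(mixed_image)   # every pixel is 0.7
print(mixed_target)  # 70% dog, 30% cat
```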

![mixup](imgs/mixup.png)

Dog or cat? The right answer here is 70% dog and 30% cat!

As the picture above shows, it’s a bit hard for the human eye to make sense of the resulting pictures (although we can see the shapes of a dog and a cat), but somehow they make a lot of sense to the model, which trains more efficiently. The final loss (training or validation) will be higher than when training without mixup, even if the accuracy is far better, which means that a model trained this way makes predictions that are a bit less confident.

## Basic Training

To test this method, we will first build a `simple_cnn` and train it like we did with `basic_train` so we can compare its results with a network trained with Mixup.

In [None]:
data = get_mnist()
model = simple_cnn((3,16,16,2))
learn = Learner(data, model, metrics=[accuracy])

In [None]:
learn.fit(20)


Total time: 00:47
epoch  train loss  valid loss  accuracy
0      0.133583    0.102363    0.961236  (00:02)
1      0.082058    0.082410    0.972031  (00:02)
2      0.069833    0.055543    0.981354  (00:02)
3      0.066950    0.053367    0.981845  (00:02)
4      0.048239    0.040258    0.984789  (00:02)
5      0.044019    0.032995    0.987733  (00:02)
6      0.033873    0.028625    0.988714  (00:02)
7      0.032364    0.039206    0.981845  (00:02)
8      0.031153    0.027577    0.989205  (00:02)
9      0.029979    0.023575    0.990677  (00:02)
10     0.023657    0.022094    0.991168  (00:02)
11     0.021728    0.023140    0.990186  (00:02)
12     0.021997    0.023602    0.989696  (00:02)
13     0.024850    0.027446    0.991168  (00:02)
14     0.017646    0.018759    0.991659  (00:02)
15     0.018809    0.018319    0.991659  (00:02)
16     0.018132    0.025562    0.990677  (00:02)
17     0.015993    0.017342    0.991659  (00:02)
18     0.011576    0.017541    0.991659  (00:02)
19     0.01

## Mixup implementation in the library

In the original article, the authors suggested four things:

1. Create two separate dataloaders and draw a batch from each at every iteration to mix them up.
2. Draw a `t` value following a beta distribution with a parameter alpha (0.4 is suggested in their article).
3. Mix up the two batches with the same value `t`.
4. Use one-hot encoded targets.

The implementation in this module is based on these suggestions, but departs from them where our experiments showed that a change improved performance.

The authors suggest using a beta distribution with both parameters equal to alpha. Why do they suggest this? Well, it looks like this:

![betadist](imgs/betadist-mixup.png)

This means there is a very high probability of picking values close to 0 or 1 (in which case the image comes almost entirely from one category), and a roughly constant probability of picking something in the middle (0.33 is about as likely as 0.5, for instance).
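To make the shape of this distribution concrete, the sketch below evaluates the Beta(alpha, alpha) density with `alpha = 0.4` using only the standard library, and estimates how often a draw lands near the edges. The helper `beta_pdf` is a hypothetical name written for this illustration; `random.Random.betavariate` is the standard-library sampler:

```python
import math
import random

def beta_pdf(x, a, b):
    """Density of the Beta(a, b) distribution at x, for 0 < x < 1."""
    norm = math.gamma(a + b) / (math.gamma(a) * math.gamma(b))
    return norm * x ** (a - 1) * (1 - x) ** (b - 1)

alpha = 0.4  # the value suggested in the mixup article

# High density near 0 and 1, a fairly flat plateau in the middle.
for x in (0.01, 0.1, 0.33, 0.5):
    print(f"density at {x}: {beta_pdf(x, alpha, alpha):.3f}")

# Empirical share of draws in [0, 0.1) or (0.9, 1]; a uniform draw would give 0.2.
rng = random.Random(0)
samples = [rng.betavariate(alpha, alpha) for _ in range(10_000)]
near_edges = sum(s < 0.1 or s > 0.9 for s in samples) / len(samples)
print(f"share of draws near 0 or 1: {near_edges:.2f}")
```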

While this works very well, it’s not the fastest way to do it, and this is the first suggestion we adjust. What slows the process down is needing two different batches at every iteration, which means loading twice as many images and applying the other data augmentation functions to all of them. To avoid this slowdown, we can be a little smarter and mix up a batch with a shuffled version of itself (this way, the images being mixed are still different).
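A minimal NumPy sketch of mixing a batch with a shuffled copy of itself, assuming images of shape `(batch, channels, height, width)`, one-hot targets, and a single `t` drawn for the whole batch (per-image draws are discussed next). The function name `mixup_batch` is hypothetical and not the fastai API:

```python
import numpy as np

def mixup_batch(x, y_onehot, alpha=0.4, rng=None):
    """Mix a batch with a shuffled version of itself, using one t for the whole batch."""
    rng = np.random.default_rng() if rng is None else rng
    t = rng.beta(alpha, alpha)              # one mixing coefficient for the batch
    perm = rng.permutation(x.shape[0])      # shuffled batch indices
    new_x = t * x + (1 - t) * x[perm]
    new_y = t * y_onehot + (1 - t) * y_onehot[perm]
    return new_x, new_y, t, perm

rng = np.random.default_rng(0)
x = rng.random((8, 3, 16, 16))              # a fake batch of 8 RGB 16x16 images
y = np.eye(2)[rng.integers(0, 2, size=8)]   # one-hot targets for 2 classes
new_x, new_y, t, perm = mixup_batch(x, y, rng=rng)
print(new_x.shape, new_y.sum(axis=1))       # mixed targets still sum to 1
```

Only one batch is loaded and augmented per iteration; the "second batch" is just an index permutation of the first.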

Using the same parameter `t` for the whole batch is another suggestion we modify. In our experiments, we noticed that the model can train faster if we draw a different `t` for every image in the batch (both options reach the same accuracy; one just gets there more slowly).

The last trick is needed because this strategy can produce duplicates: say we mix `image0` with `image1` and then `image1` with `image0`, and we draw `t=0.1` for the first and `t=0.9` for the second. Then

`image0 * 0.1 + shuffle0 * (1-0.1) = image0 * 0.1 + image1 * 0.9`

and

`image1 * 0.9 + shuffle1 * (1-0.9) = image1 * 0.9 + image0 * 0.1`

are the same. Of course, we have to be a bit unlucky for this to happen, but in practice we saw a drop in accuracy when we didn’t remove those duplicates. To avoid them, the trick is to replace the vector of parameters `t` we drew with:

`t = max(t, 1-t)`

The beta distribution with two equal parameters is symmetric anyway, so this only decides which image gets the larger weight: it ensures that the biggest coefficient always goes to the first image (the non-shuffled batch).
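The sketch below reproduces the duplicate case from the text in NumPy (two one-pixel "images", a shuffle that swaps them, and `t = 0.1` and `t = 0.9`), then applies `t = max(t, 1-t)` with one `t` per image. The helper names `draw_mixup_coefficients` and `mixup_per_image` are hypothetical, written for this illustration only:

```python
import numpy as np

def draw_mixup_coefficients(batch_size, alpha=0.4, rng=None):
    """Draw one t per image from Beta(alpha, alpha), then fold it into [0.5, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    t = rng.beta(alpha, alpha, size=batch_size)
    return np.maximum(t, 1 - t)  # the larger weight always goes to the non-shuffled image

def mixup_per_image(x, t, perm):
    """Mix each image with its shuffled partner, using its own coefficient."""
    # Reshape t so it broadcasts over every non-batch dimension of x.
    t = t.reshape(-1, *([1] * (x.ndim - 1)))
    return t * x + (1 - t) * x[perm]

x = np.array([[0.0], [1.0]])   # image0 and image1, one pixel each
perm = np.array([1, 0])        # image0 is mixed with image1 and vice versa
raw_t = np.array([0.1, 0.9])

dup = mixup_per_image(x, raw_t, perm)
print(dup.ravel())             # both rows equal 0.9: a duplicate

fixed_t = np.maximum(raw_t, 1 - raw_t)   # both coefficients become 0.9
print(mixup_per_image(x, fixed_t, perm).ravel())  # two different mixes
```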

## Adding Mixup to the Mix

Now we add `MixUpCallback` to our Learner so that it modifies our inputs and targets accordingly. The `mixup` function does that for us behind the scenes, with a few other tweaks detailed below.

In [None]:
model = simple_cnn((3,16,16,2))
learner = Learner(data, model, metrics=[accuracy]).mixup()
learner.fit(20)


Total time: 00:49
epoch  train loss  valid loss  accuracy
0      0.380133    0.202769    0.947498  (00:02)
1      0.357177    0.171234    0.962218  (00:02)
2      0.329518    0.115624    0.978901  (00:02)
3      0.322647    0.114372    0.984789  (00:02)
4      0.317570    0.100772    0.985280  (00:02)
5      0.309182    0.086424    0.988714  (00:02)
6      0.306099    0.089540    0.988714  (00:02)
7      0.315323    0.096520    0.989696  (00:02)
8      0.306009    0.091482    0.990677  (00:02)
9      0.302448    0.087128    0.991168  (00:02)
10     0.307314    0.082438    0.991659  (00:02)
11     0.307111    0.095313    0.991168  (00:02)
12     0.304113    0.083817    0.994112  (00:02)
13     0.308068    0.089792    0.990677  (00:02)
14     0.300224    0.069017    0.993621  (00:02)
15     0.302404    0.080377    0.993621  (00:02)
16     0.301102    0.077373    0.995093  (00:02)
17     0.302036    0.094481    0.993621  (00:02)
18     0.291388    0.083959    0.991168  (00:02)
19     0.29

Training the net with Mixup improves the best accuracy from 99.3% to 99.5% (a relative reduction in error rate of about 29%!)
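The error-rate figure follows from the two accuracies: the error drops from 0.7% to 0.5%, which is a relative reduction of 0.2 / 0.7. A quick check:

```python
acc_without, acc_with = 0.993, 0.995
err_without, err_with = 1 - acc_without, 1 - acc_with  # 0.7% and 0.5% error
reduction = (err_without - err_with) / err_without
print(f"{reduction:.0%}")  # 29%
```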

Note that the validation loss is higher than without MixUp because the model makes less confident predictions: without mixup, most predictions are very close to 0 or 1 (in terms of probability), whereas the model trained with MixUp gives more nuanced predictions. Make sure you know what you want to optimize (lower loss or better accuracy) before using it.

In [None]:
show_doc(MixUpCallback, doc_string=False)

## <a id=MixUpCallback></a>`class` `MixUpCallback`
> `MixUpCallback`(`learner`:`Learner`, `alpha`:`float`=`0.4`, `stack_x`:`bool`=`False`, `stack_y`:`bool`=`True`) :: `Callback`
<a href="https://github.com/fastai/fastai/blob/master/fastai/callbacks/mixup.py#L7">[source]</a>

Create a `Callback` for mixup on `learner` with a parameter `alpha` for the beta distribution. `stack_x` and `stack_y` determine whether we stack our inputs/targets with the drawn vector lambda or take the linear combination (in general, we stack the inputs or outputs when they correspond to categories or classes, and take the linear combination otherwise).

In [None]:
show_doc(MixUpCallback.on_batch_begin, doc_string=False)

#### <a id=on_batch_begin></a>`on_batch_begin`
> `on_batch_begin`(`last_input`, `last_target`, `kwargs`)
<a href="https://github.com/fastai/fastai/blob/master/fastai/callbacks/mixup.py#L14">[source]</a>

Draws a vector of lambdas from a beta distribution with parameter `self.alpha` and applies mixup to `last_input` and `last_target` according to `self.stack_x` and `self.stack_y`.
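To illustrate what "stacking" the targets means, here is a NumPy sketch that keeps integer class labels and stores, for each sample, its own class, its shuffled partner’s class and the mixing weight. The three-column layout `[target, shuffled target, lambda]` and the helper name `stack_targets` are assumptions made for this illustration; check the linked source for the exact format the library uses:

```python
import numpy as np

def stack_targets(y, perm, lambd):
    """Stack integer targets, their shuffled partners and the mixing weights column-wise."""
    return np.stack([y.astype(float), y[perm].astype(float), lambd], axis=1)

y = np.array([3, 7, 1])             # integer class labels, no one-hot encoding
perm = np.array([2, 0, 1])          # shuffled batch indices
lambd = np.array([0.9, 0.6, 0.75])  # one weight per sample, already >= 0.5
stacked = stack_targets(y, perm, lambd)
print(stacked)
# each row is [own class, partner class, weight of own class]
```

Stacking avoids building one-hot vectors at all: the loss can recover both classes and the weight from each row.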

## Dealing with the loss

We often have to modify the loss so that it is compatible with Mixup: PyTorch is careful to avoid one-hot encoding targets where it can, so it would be a bit of a drag to undo this. Fortunately for us, if the loss is a classic [cross-entropy](https://pytorch.org/docs/stable/nn.html#torch.nn.functional.cross_entropy), we have

`loss(output, new_target) = t * loss(output, target1) + (1-t) * loss(output, target2)`

because cross-entropy is linear in the target probabilities. So we don’t need to one-hot encode anything: we just compute those two losses and take their linear combination.
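This identity can be checked numerically. The NumPy sketch below computes the left-hand side with an explicit mixed one-hot target and the right-hand side as a combination of two ordinary integer-target losses; the helper functions are written for this illustration and are not PyTorch or fastai calls:

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the class dimension."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def cross_entropy_int(logits, targets):
    """Mean cross-entropy with integer class targets (no one-hot encoding)."""
    return -log_softmax(logits)[np.arange(len(targets)), targets].mean()

def cross_entropy_soft(logits, soft_targets):
    """Mean cross-entropy with probability-vector targets."""
    return -(soft_targets * log_softmax(logits)).sum(axis=1).mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
target1 = rng.integers(0, 10, size=4)
target2 = rng.integers(0, 10, size=4)
t = 0.7

# Left-hand side: build the mixed one-hot target explicitly.
onehot = np.eye(10)
lhs = cross_entropy_soft(logits, t * onehot[target1] + (1 - t) * onehot[target2])
# Right-hand side: combine the two ordinary losses, no one-hot needed.
rhs = t * cross_entropy_int(logits, target1) + (1 - t) * cross_entropy_int(logits, target2)
print(np.isclose(lhs, rhs))  # True
```

With one `t` per image, the same identity holds per sample, so the two losses should be combined before averaging over the batch.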

The following class is used to adapt the loss to mixup. Note that the `mixup` function will use it to change the `Learner.loss_fn` if necessary.

In [None]:
show_doc(MixUpLoss, doc_string=False, title_level=3)

### <a id=MixUpLoss></a>`class` `MixUpLoss`
> `MixUpLoss`(`crit`) :: `Module`
<a href="https://github.com/fastai/fastai/blob/master/fastai/callbacks/mixup.py#L30">[source]</a>

Create a loss function from `crit` that is compatible with MixUp.
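As a sketch of how such a wrapper could work, the NumPy function below takes an unreduced criterion (one loss per sample) and, when the targets are stacked as `[target, shuffled target, lambda]`, combines the two per-sample losses before averaging; plain integer targets (for instance at validation time) are passed straight through. The column layout and the names `mixup_loss` and `per_sample_cross_entropy` are assumptions for this illustration, not the library’s code:

```python
import numpy as np

def per_sample_cross_entropy(logits, targets):
    """Unreduced cross-entropy: one loss value per sample."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets]

def mixup_loss(crit, logits, target):
    """Wrap an unreduced criterion so it also accepts stacked mixup targets."""
    if target.ndim == 2:  # stacked targets: [target, shuffled target, lambda]
        y1 = target[:, 0].astype(int)
        y2 = target[:, 1].astype(int)
        lambd = target[:, 2]
        loss = lambd * crit(logits, y1) + (1 - lambd) * crit(logits, y2)
    else:                 # plain integer targets
        loss = crit(logits, target)
    return loss.mean()

rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 5))
stacked = np.array([[0, 2, 0.9], [4, 0, 0.6], [2, 4, 0.75]])
print(mixup_loss(per_sample_cross_entropy, logits, stacked))
```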

## Undocumented Methods - Methods moved below this line will intentionally be hidden

In [None]:
show_doc(MixUpLoss.forward)