In [1]:
%%javascript
utils.load_extension('collapsible_headings/main')
utils.load_extension('hide_input/main')
utils.load_extension('autosavetime/main')
utils.load_extension('execute_time/ExecuteTime')
utils.load_extension('code_prettify/code_prettify')
utils.load_extension('scroll_down/main')
utils.load_extension('jupyter-js-widgets/extension')

In [2]:
from IPython.display import display, HTML, clear_output
display(HTML("<style>.container { width:100% !important; }</style>"))

Ignacio Oguiza - email: oguiza@gmail.com

## BatchLossFilter

I’d like to share with you a new callback I have created that has worked very well on some of my datasets.

Last week @Redknight wrote a [post](https://forums.fast.ai/t/interesting-accelerating-deep-learning-by-focusing-on-the-biggest-losers/56091?u=oguiza) about this [paper](https://arxiv.org/pdf/1910.00762):

Accelerating Deep Learning by Focusing on the Biggest Losers.

The paper describes **Selective-Backprop**, a technique that accelerates the training of deep neural networks (DNNs) by **prioritizing examples with high loss at each iteration**. 

In parallel I also read a tweet by David Page:

<img src="./images/tweet_blf.jpg">

The idea really resonated with me. I’ve always thought it would be good to spend most of the training time learning from the most difficult examples. This seems like a good way to do it, so I decided to try it.

The paper’s code base in PyTorch is publicly available [here](https://anonymous.4open.science/r/c6d4060d-bdac-4d31-839e-8579650255b3/).

However, I preferred to implement the idea with a different, simpler approach: identify the items within each batch that are responsible for a high % of the total batch loss (I chose 90%), and remove the rest of the samples. In this way, your model will **dynamically focus on the high-loss (most difficult) samples**. As you’ll see, the percentage of samples remaining varies per batch and over the course of training.
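
The selection rule can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the callback itself: `select_high_loss_idxs` is a hypothetical helper, and it uses `np.searchsorted` on the cumulative loss share instead of the `np.argmax` trick used in the callback code further down (the two give the same count whenever the threshold is actually reached).

```python
import numpy as np

def select_high_loss_idxs(losses: np.ndarray, loss_perc: float = 0.9) -> np.ndarray:
    """Indices of the highest-loss samples that together account for at least
    `loss_perc` of the batch's total loss (hypothetical helper)."""
    order = np.argsort(losses)[::-1]                  # sample indices, highest loss first
    cum_share = np.cumsum(losses[order]) / losses.sum()
    # first position where the running loss share reaches the target; +1 turns it into a count
    n_keep = int(np.searchsorted(cum_share, loss_perc)) + 1
    n_keep = min(n_keep, len(losses))                 # guard against rounding when loss_perc is 1.0
    return order[:n_keep]

batch_losses = np.array([1.5, 6.0, 0.5, 2.0])
print(select_high_loss_idxs(batch_losses, 0.9))       # [1 3 0]
```

Here the top two samples hold 80% of the loss and the top three hold 95%, so three of the four samples survive a 90% threshold and only the easiest one is dropped.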

I’ve run a test on CIFAR10. These are the results:

1) **Time to train** (100 epochs): 15.2% less time to train (despite the additional overhead)

<img src="./images/time_blf.jpg">

2) **Accuracy**: same as the baseline model (at least in 100 epochs)

<img src="./images/accuracy_blf.jpg"> 

However, training is smoother, and there’s a significant difference in validation loss. I believe that with longer training there could be a difference in accuracy, but I have not confirmed this yet.

3) **Validation loss**: much lower and smoother.

<img src="./images/valid_loss_blf.jpg">

4) **Selected samples per batch**: In my opinion this is the most interesting chart, as it shows the % of samples that make up 90% of the total batch loss. Initially, a large share of each batch is needed to reach 90% of the loss, but as training progresses the model dynamically focuses on the most difficult samples. These samples are not necessarily the same all the time, as they are chosen separately for each batch. By the end of training, the model focuses on just 12% of the samples, the most difficult ones. This is why training takes less time.

<img src="./images/sel_samples_blf.jpg"> 

Note: 
There are actually 2 hyperparameters: min_loss_perc selects the highest-loss samples that together make up at least that % of the batch loss, and min_sample_perc selects at least that % of the samples with the highest losses. Both can be used at the same time, in which case the larger selection is kept. In my case I just used min_loss_perc.
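
To see how the two hyperparameters interact, here is a small sketch that mirrors the `max(sample_max, loss_max)` step in the callback’s `get_loss_idxs`. `n_samples_to_keep` is a hypothetical helper introduced only for illustration; whichever threshold demands more samples decides how many are kept.

```python
import math
import numpy as np

def n_samples_to_keep(losses: np.ndarray, min_sample_perc: float = 0.0,
                      min_loss_perc: float = 0.9) -> int:
    """Number of highest-loss samples kept when both thresholds are set
    (hypothetical helper for illustration)."""
    n = len(losses)
    # lower bound from min_sample_perc: a fixed fraction of the batch
    by_count = math.ceil(n * min_sample_perc)
    # lower bound from min_loss_perc: enough top samples to cover that share of the loss
    cum_share = np.cumsum(np.sort(losses)[::-1]) / losses.sum()
    by_loss = min(int(np.searchsorted(cum_share, min_loss_perc)) + 1, n)
    # both constraints must hold, so the larger count wins
    return max(by_count, by_loss)

losses = np.array([8.0, 1.5, 0.25, 0.25])
print(n_samples_to_keep(losses, min_sample_perc=0.0))   # 2: the top two hold 95% of the loss
print(n_samples_to_keep(losses, min_sample_perc=0.75))  # 3: at least 75% of the 4 samples
```

With min_sample_perc left at 0, only the loss threshold matters, so the number of selected samples can shrink freely as the model learns the easy examples.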

## Import libraries

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai_extensions import *

In [3]:
import math
from functools import partial

import numpy as np
import torch
# Learner, LearnerCallback, add_metrics and drop_cb_fn come from the
# `from fastai_extensions import *` cell above (fastai v1).

class BatchLossFilterCallback(LearnerCallback):
    _order = -20  # runs before callbacks with the default order of 0

    def __init__(self, learn:Learner, min_sample_perc:float=0., min_loss_perc:float=0.):
        super().__init__(learn)
        assert min_sample_perc > 0. or min_loss_perc > 0., 'min_sample_perc <= 0 and min_loss_perc <= 0'
        self.min_sample_perc, self.min_loss_perc = min_sample_perc, min_loss_perc
        self.learn = learn
        self.model = learn.model
        self.crit = learn.loss_func
        # remember the original reduction so it can be restored after each batch
        if hasattr(self.crit, 'reduction'):  self.red = self.crit.reduction
        self.sel_losses_sum, self.losses_sum = 0., 0.
        self.sel_samples, self.samples = 0., 0.
        self.recorder.add_metric_names(["loss_perc", "samp_perc"])

    def on_epoch_begin(self, **kwargs):
        "Reset the per-epoch counters."
        self.sel_losses_sum, self.losses_sum = 0., 0.
        self.sel_samples, self.samples = 0., 0.

    def on_batch_begin(self, last_input, last_target, train, epoch, **kwargs):
        # no filtering during validation or during the first epoch
        if not train or epoch == 0: return
        # per-sample losses (reduction='none'), computed without tracking gradients
        if hasattr(self.crit, 'reduction'):  setattr(self.crit, 'reduction', 'none')
        with torch.no_grad():
            # .cpu() is needed before .numpy() when the model runs on a GPU
            self.losses = self.crit(self.model(last_input), last_target).cpu().numpy()
        if hasattr(self.crit, 'reduction'):  setattr(self.crit, 'reduction', self.red)
        self.get_loss_idxs()
        # get_loss_idxs normalises self.losses in place, so these sums are loss shares
        self.sel_losses_sum += self.losses[self.idxs].sum()
        self.losses_sum += self.losses.sum()
        self.sel_samples += len(self.idxs)
        self.samples += len(self.losses)
        # replace the batch with the selected high-loss samples only
        return {"last_input": last_input[self.idxs], "last_target": last_target[self.idxs]}

    def on_epoch_end(self, epoch, last_metrics, **kwargs):
        loss_perc = self.sel_losses_sum / self.losses_sum if epoch > 0 else 1.
        sample_perc = self.sel_samples / self.samples if epoch > 0 else 1.
        return add_metrics(last_metrics, [loss_perc, sample_perc])

    def on_train_end(self, **kwargs):
        """At the end of training this callback will be removed"""
        if hasattr(self.learn.loss_func, 'reduction'):  setattr(self.learn.loss_func, 'reduction', self.red)
        drop_cb_fn(self.learn, 'BatchLossFilterCallback')

    def get_loss_idxs(self):
        # sample indices sorted by loss, highest first
        idxs = np.argsort(self.losses)[::-1]
        sample_max = math.ceil(len(idxs) * self.min_sample_perc)
        self.losses /= self.losses.sum()
        # smallest number of top samples whose cumulative loss share reaches min_loss_perc
        loss_max = np.argmax(self.losses[idxs].cumsum() >= self.min_loss_perc) + 1
        self.idxs = list(idxs[:max(sample_max, loss_max)])


def batch_loss_filter(learn:Learner, min_sample_perc:float=0., min_loss_perc:float=.9)->Learner:
    "Attach a BatchLossFilterCallback to `learn`."
    learn.callback_fns.append(partial(BatchLossFilterCallback, min_sample_perc=min_sample_perc,
                                      min_loss_perc=min_loss_perc))
    return learn

Learner.batch_loss_filter = batch_loss_filter
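
The `loss_perc` and `samp_perc` columns in the training log are aggregated differently. Because `get_loss_idxs` normalises `self.losses` in place before the sums are updated, `losses_sum` grows by 1 per batch, so `loss_perc` is the average per-batch share of loss that was kept (which is why it sits just above 0.9), while `samp_perc` is the pooled fraction of samples kept over the epoch. The NumPy sketch below reproduces that bookkeeping outside fastai; `epoch_filter_metrics` and its list-of-arrays input are hypothetical.

```python
import numpy as np

def epoch_filter_metrics(batches, min_loss_perc=0.9):
    """Recompute loss_perc / samp_perc for one epoch (hypothetical helper).
    `batches` is an iterable of 1-D arrays of per-sample losses."""
    sel_share_sum, n_batches = 0.0, 0
    sel_samples, samples = 0, 0
    for losses in batches:
        share = losses / losses.sum()                  # per-batch normalisation
        order = np.argsort(share)[::-1]                # highest loss first
        cum = np.cumsum(share[order])
        n_keep = min(int(np.searchsorted(cum, min_loss_perc)) + 1, len(share))
        sel_share_sum += share[order[:n_keep]].sum()   # kept share of this batch's loss
        n_batches += 1                                 # each normalised batch sums to 1
        sel_samples += n_keep
        samples += len(share)
    return sel_share_sum / n_batches, sel_samples / samples

batches = [np.array([6.0, 2.0, 1.5, 0.5]), np.array([1.0, 1.0, 1.0, 1.0])]
loss_perc, samp_perc = epoch_filter_metrics(batches)
print(round(loss_perc, 3), samp_perc)   # 0.975 0.875
```

In this toy epoch the first batch keeps 3 samples (95% of its loss) and the uniform second batch needs all 4, so 7 of 8 samples are used.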

## Prepare data

In [4]:
bs = 128
path = untar_data(URLs.CIFAR)
tfms = get_transforms()
data = (ItemLists('.',
                  ImageList.from_folder(path / 'train'),
                  ImageList.from_folder(path / 'test'))
        .label_from_folder()
        .transform(tfms)
        .databunch(bs=bs, val_bs=bs * 2)
        .normalize(cifar_stats))
data

ImageDataBunch;

Train: LabelList (50000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
ship,ship,ship,ship,ship
Path: /home/oguizadl/.fastai/data/cifar10/train;

Valid: LabelList (10000 items)
x: ImageList
Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32),Image (3, 32, 32)
y: CategoryList
ship,ship,ship,ship,ship
Path: /home/oguizadl/.fastai/data/cifar10/test;

Test: None

## Speed test

In [5]:
model = models.WideResNet(num_groups=3, N=4, num_classes=10, k=2, start_nf=32).to(device)
xb,yb=next(iter(data.train_dl))
with torch.no_grad():
    # per-sample losses for one batch; .cpu() is needed before .numpy() on a GPU
    losses = nn.CrossEntropyLoss(reduction='none')(model(xb), yb).cpu().numpy()

In [6]:
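def get_loss_idxs(losses, min_sample_perc=0., min_loss_perc=0.):
    idxs = np.argsort(losses)[::-1]                        # sample indices, highest loss first
    sample_max = math.ceil(len(idxs) * min_sample_perc)    # count required by min_sample_perc
    losses /= losses.sum()                                 # note: normalises `losses` in place
    loss_max = np.argmax(losses[idxs].cumsum() >= min_loss_perc) + 1  # count required by min_loss_perc
    return list(idxs[:max(sample_max, loss_max)])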

In [7]:
%timeit get_loss_idxs(losses, min_sample_perc=0., min_loss_perc=0.9)

31.2 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Build learner

In [8]:
model = models.WideResNet(num_groups=3, N=4, num_classes=10, k=2, start_nf=32).to(device)
learn = Learner(data, model, metrics=accuracy).batch_loss_filter(min_loss_perc=.9)
learn.save('stage-0')

## Train model

In [52]:
learn.load('stage-0')
learn.fit_one_cycle(100)
learn.save('stage-1')

| epoch | train_loss | valid_loss | accuracy | loss_perc | samp_perc | time |
|---|---|---|---|---|---|---|
| 0 | 1.453295 | 1.385631 | 0.484 | 1.0 | 1.0 | 00:45 |
| 1 | 1.523884 | 1.208875 | 0.5736 | 0.90221 | 0.754207 | 00:51 |
| 2 | 1.420627 | 1.114552 | 0.6079 | 0.901955 | 0.731991 | 00:49 |
| 3 | 1.338659 | 0.988463 | 0.6575 | 0.902135 | 0.696995 | 00:49 |
| 4 | 1.314952 | 0.919686 | 0.6923 | 0.901948 | 0.658974 | 00:47 |
| 5 | 1.234776 | 0.808278 | 0.728 | 0.90189 | 0.622937 | 00:46 |
| 6 | 1.199178 | 0.729833 | 0.7459 | 0.901855 | 0.580669 | 00:45 |
| 7 | 1.150283 | 0.626935 | 0.7875 | 0.901864 | 0.547095 | 00:45 |
| 8 | 1.092753 | 0.67828 | 0.7681 | 0.901998 | 0.51897 | 00:44 |
| 9 | 1.06925 | 0.584593 | 0.792 | 0.901962 | 0.485557 | 00:43 |


## How to use it?

You only need to clone the repo. I have added BatchLossFilter to fastai_extensions, so all you need to do is:

In [11]:
from fastai_extensions import *

Then prepare your data as you would normally do.

In [12]:
bs = 128
path = untar_data(URLs.CIFAR)
tfms = get_transforms()
data = (ItemLists('.',
                  ImageList.from_folder(path / 'train'),
                  ImageList.from_folder(path / 'test'))
        .label_from_folder()
        .transform(tfms)
        .databunch(bs=bs, val_bs=bs * 2)
        .normalize(cifar_stats))
data

Build your learner and add batch_loss_filter. You can modify the min_loss_perc and/or min_sample_perc hyperparameters if you wish, or leave them at their default values (min_loss_perc=.9, min_sample_perc=0.). The defaults select the top items responsible for 90% of the loss in each batch, regardless of how many items that is. You may add a constraint by setting min_sample_perc to .1, .2, etc., but my current view is that this doesn't bring any additional value.

In [13]:
model = models.WideResNet(num_groups=3, N=4, num_classes=10, k=2, start_nf=32).to(device)
learn = Learner(data, model, metrics=accuracy).batch_loss_filter(min_loss_perc=.9)

And now you are ready to train!!

In [None]:
learn.fit_one_cycle(100)

Good luck with your experiments!!