In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab

In [None]:
#default_exp callback.training

In [None]:
#export
from fastai.basics import *
from fastai.callback.progress import *
from fastai.callback.fp16 import *

In [None]:
#hide
from nbdev.showdoc import *
from fastai.test_utils import *
from fastai.vision.all import *

# Training callbacks

> Various callbacks to customize training behavior

## ShortEpochCallback -

In [None]:
#export
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"
    def __init__(self, pct=0.01, short_valid=True):
        # `pct`: fraction of the batches to run; `short_valid`: also cut the validation phase short
        self.pct = pct
        self.short_valid = short_valid

    def after_batch(self):
        "Cancel the current phase once `iter/n_iter` has reached `pct`"
        progress = self.iter/self.n_iter
        if progress < self.pct: return
        # Threshold reached: training is always cancelled, validation only when `short_valid` is set
        if self.training: raise CancelTrainException
        elif self.short_valid: raise CancelValidException

In [None]:
# Defaults (`pct=0.01`, `short_valid=True`): both the training and the validation phase are cancelled early
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())

epoch,train_loss,valid_loss,time
0,00:00,,


In [None]:
# With `short_valid=False` the callback only cancels training; it never cancels the validation phase
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))

epoch,train_loss,valid_loss,time
0,27.506021,00:00,


## GradientAccumulation -

In [None]:
# export
class GradientAccumulation(Callback):
    "Accumulate gradients before updating weights"
    toward_end = True
    run_before = MixedPrecision

    def __init__(self, n_acc=32):
        # `n_acc` is the threshold `count` must reach before a weight update is allowed
        store_attr('n_acc')

    def before_fit(self):
        self.count = 0

    def after_loss(self):
        # Only the training loss is rescaled (divided in place by n_acc/bs) before backward
        if not self.training: return
        self.learn.loss /= self.n_acc/find_bs(self.learn.yb)

    def after_backward(self):
        bs = find_bs(self.learn.yb)
        self.learn.loss *= self.n_acc/bs #so correct loss is logged
        self.count += bs
        # Below the threshold: cancel the rest of the batch so the weights are not updated
        if self.count < self.n_acc: raise CancelBatchException() #skip weight update
        self.count = 0

    _docs = dict(before_fit="Set counter to 0",
                 after_backward="Skip weight update if we have not seen enough items")

In [None]:
#hide
class GetGrads(Callback):
    "Test helper: keep a detached copy of every parameter's gradient when `after_backward` runs in training"
    run_after=GradientAccumulation
    def after_backward(self):
        if not self.training: return
        # Clone first so the stored values are independent of the live `.grad` tensors
        snapshot = [param.grad.clone() for param in self.model.parameters()]
        self.grads = to_detach(snapshot)

# One seed shared by both runs: each run is executed inside `no_random(seed)` with this same value
seed=random.randint(0,2**32-1)
# Reference run: bs=8, n_train=1, no accumulation callback
with no_random(seed): 
    db=synth_dbunch(bs=8,n_train=1,n_valid=1)
    learn = synth_learner(data=db,cbs=[GetGrads()])
    learn.fit(1, lr=0.01)
    grads=learn.get_grads.grads
# Accumulated run: bs=1, n_train=8, with `GradientAccumulation(n_acc=8)` placed before `GetGrads`
with no_random(seed): 
    db=synth_dbunch(bs=1,n_train=8,n_valid=8)
    learn = synth_learner(data=db,cbs=[GradientAccumulation(n_acc=8),GetGrads()])
    learn.fit(1, lr=0.01)
    grads_accum=learn.get_grads.grads
#grads should be the same, valid loss the same,train loss should be different
test_close(grads_accum,grads) 

epoch,train_loss,valid_loss,time
0,6.158221,4.506574,00:00


epoch,train_loss,valid_loss,time
0,6.223377,4.506574,00:00


In [None]:
learn = synth_learner()

# Threshold of two batches' worth of items (`2*learn.dls.bs`) between weight updates
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
# ensure train_loss decreased
assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]

# `n_acc=1e6` is presumably never reached on this synthetic data, so every weight update is skipped
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
# ensure valid_loss didn't change (same weights)
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]

epoch,train_loss,valid_loss,time
0,15.826974,15.079698,00:00
1,10.71912,3.257349,00:00


epoch,train_loss,valid_loss,time
0,2.219906,3.257349,00:00
1,2.218323,3.257349,00:00


## BnFreeze

In [None]:
#export
# Layer classes treated as batchnorm by `set_bn_eval`
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(m:nn.Module, use_eval=True)->None:
    """Set the mode of every frozen bn layer among the recursive children of `m`.

    A bn layer counts as frozen when its first parameter has `requires_grad=False`.
    With `use_eval=True` those layers are put in eval mode, with `use_eval=False` they
    are put back in train mode. All other layers are left untouched.
    """
    for l in m.children():
        # NOTE(review): `next(l.parameters())` raises `StopIteration` for a bn layer that has
        # no parameters (e.g. built with `affine=False`) -- confirm whether that case matters.
        if isinstance(l, bn_types) and not next(l.parameters()).requires_grad:
            if use_eval: l.eval()
            else:        l.train()
        # Forward `use_eval`; without it nested layers were always handled as `use_eval=True`
        set_bn_eval(l, use_eval)

class BnFreeze(Callback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    # The docstring must be the first statement of the class body to become `BnFreeze.__doc__`
    run_after=TrainEvalCallback
    def before_train(self):
        # Ordered after `TrainEvalCallback` (see `run_after`), presumably so the eval mode set here is not undone by it
        set_bn_eval(self.model)

`BnFreeze` is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head that you attach for transfer learning. <br>

`Learner.freeze()` doesn't suffice here as the `BatchNorm` layers are trainable by default, and the running mean and std of batches are tracked. For feature extractors to fully match, you need to set `train_bn=False` and these stats need to be frozen as well, which is precisely the function of `BnFreeze`.

In [None]:
#slow
# Get the MNIST_TINY data through `untar_data` and build `ImageDataLoaders` from its folder with `valid_pct=0.2`
path = untar_data(URLs.MNIST_TINY)
dls  = ImageDataLoaders.from_folder(path, valid_pct=0.2)

We first demonstrate the mismatch of the running stats when using only `train_bn=False`, by creating a `Learner`...:

In [None]:
#slow
# `deepcopy(dls)` gives this learner its own copy of the DataLoaders; no `BnFreeze` callback is attached here
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)

...and grab the first `BatchNorm` layer, and store its running mean: 

In [None]:
#slow
# Clone so that `m` keeps the pre-training value even if the layer's buffer is updated later
m = learn1.model[0][1].running_mean.clone()

After training for one epoch, you can see that the running mean has changed:

In [None]:
#slow
learn1.fit(1, lr=0.02)
# Without `BnFreeze`, the layer's running mean must differ from the value `m` stored before training
test_ne(to_detach(learn1.model[0][1].running_mean), m)

epoch,train_loss,valid_loss,time
0,1.131516,2.997253,00:03


When we use the `BnFreeze` callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning.

In [None]:
#slow
# Same setup as before, but with the `BnFreeze` callback attached
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
# Snapshot of the running mean taken before training
m = learn1.model[0][1].running_mean.detach().clone()
learn1.fit(1, lr=0.02)
# With `BnFreeze`, the running mean after training must equal the snapshot
test_eq(to_detach(learn1.model[0][1].running_mean), m)

epoch,train_loss,valid_loss,time
0,0.49108,0.315896,00:02


## Export -

In [None]:
#hide
# Run nbdev's `notebook2script`; the output below lists the notebooks it converted
from nbdev.export import notebook2script
notebook2script()

Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 