In [None]:
#default_exp callback.training

In [None]:
#export
from fastai2.basics import *
from fastai2.callback.progress import *
from fastai2.callback.fp16 import *

In [None]:
#hide
from nbdev.showdoc import *
from fastai2.test_utils import *

# Training callbacks

> Callbacks that customize how the training loop runs, such as cutting epochs short or accumulating gradients over several batches

## ShortEpochCallback -

In [None]:
#export
class ShortEpochCallback(Callback):
    "Run only the first `pct` of an epoch's batches, then cancel what remains"
    def __init__(self, pct=0.01, short_valid=True):
        # `pct`: fraction of the batches to run before cancelling.
        # `short_valid`: when true, the validation phase is cut short as well.
        self.pct = pct
        self.short_valid = short_valid

    def after_batch(self):
        # `iter`, `n_iter` and `training` are not set in this class; presumably
        # they are provided by the learner this callback is attached to.
        below_threshold = self.iter/self.n_iter < self.pct
        if not below_threshold:
            # Training is always cancelled once the threshold is reached;
            # validation is only cancelled when `short_valid` is set.
            if self.training: raise CancelTrainException
            elif self.short_valid: raise CancelValidException

In [None]:
# Fit one epoch with the default `ShortEpochCallback` (pct=0.01, short_valid=True)
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())

epoch,train_loss,valid_loss,time
0,00:00,,


In [None]:
# Same run with `short_valid=False`: the callback only cancels the training phase
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))

epoch,train_loss,valid_loss,time
0,12.680832,00:00,


## GradientAccumulation -

In [None]:
# export
class GradientAccumulation(Callback):
    "Delay the weight update until gradients from at least `n_acc` items have been accumulated"
    # Presumably ordering hints read by the callback machinery: run late, but
    # ahead of `MixedPrecision` -- confirm against the `Callback` base class.
    toward_end = True
    run_before = MixedPrecision

    def __init__(self, n_acc=32):
        # `n_acc`: the item count that must be reached before a weight update.
        store_attr(self, 'n_acc')

    def begin_fit(self):
        # Start each fit with an empty item counter.
        self.count = 0

    def after_backward(self):
        # Add what `find_bs` reports for the current targets (presumably the
        # batch size) to the running item count.
        self.count = self.count + find_bs(self.learn.yb)
        need_more_items = self.count < self.n_acc
        if need_more_items:
            raise CancelBatchException()  # skip weight update
        # Enough items seen: let the update proceed and restart the count.
        self.count = 0

    _docs = {
        'begin_fit': "Set counter to 0",
        'after_backward': "Skip weight update if we have not seen enough items",
    }

In [None]:
learn = synth_learner()

# Accumulate over twice `learn.dls.bs` items (presumably two batches) per weight update
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
# ensure train_loss (first recorded value) decreased between the first and last epoch
assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]

# With a threshold of 1e6 items the counter should never get there, so every
# weight update is skipped (assumes the synthetic data has far fewer items)
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
# ensure valid_loss (second recorded value) didn't change (same weights)
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]

epoch,train_loss,valid_loss,time
0,14.479012,6.487211,00:00
1,7.784875,0.214303,00:00


epoch,train_loss,valid_loss,time
0,0.224275,0.214303,00:00
1,0.2242,0.214303,00:00


## Export -

In [None]:
#hide
from nbdev.export import notebook2script
# Convert the project's notebooks; each processed notebook is reported as "Converted <name>"
notebook2script()

Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.l