In [None]:
#default_exp callback.training

In [None]:
#export
from fastai2.basics import *
from fastai2.callback.progress import *

In [None]:
#hide
from nbdev.showdoc import *
from fastai2.test_utils import *

# Training callbacks

> Callbacks that customize the training loop, such as shortening epochs or accumulating gradients over several batches

## ShortEpochCallback -

In [None]:
#export
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop (validation is cut short too when `short_valid`)"
    def __init__(self, pct=0.01, short_valid=True):
        self.pct = pct                  # fraction of batches (`iter/n_iter`) to run before cancelling
        self.short_valid = short_valid  # whether to also cancel the validation phase early

    def after_batch(self):
        frac_done = self.iter/self.n_iter
        # Not far enough into the phase yet: let it continue
        if frac_done < self.pct:
            return
        if self.training:
            raise CancelTrainException
        elif self.short_valid:
            raise CancelValidException

In [None]:
learn = synth_learner()
# Default settings: both the training and the validation phases are cut short
short_cb = ShortEpochCallback()
learn.fit(1, cbs=short_cb)

epoch,train_loss,valid_loss,time
0,00:00,,


In [None]:
learn = synth_learner()
# Only training is cut short; with short_valid=False the callback never cancels validation
short_cb = ShortEpochCallback(short_valid=False)
learn.fit(1, cbs=short_cb)

epoch,train_loss,valid_loss,time
0,5.195858,00:00,


## GradientAccumulation -

In [None]:
# export
class GradientAccumulation(Callback):
    "Accumulate gradients before updating weights"
    # Callback-ordering attributes (presumably: schedule late, and before `MixedPrecision`).
    # NOTE(review): `MixedPrecision` is not imported by the `#export` cells visible in this
    # notebook (only `fastai2.basics` and `fastai2.callback.progress`); confirm it is in scope
    # for the exported module, e.g. via `from fastai2.callback.fp16 import *`.
    toward_end = True
    run_before = MixedPrecision

    def __init__(self, n_items=32):
        # n_items: number of samples (summed `find_bs` of each batch) to see between weight updates
        store_attr(self, 'n_items')

    def begin_fit(self): self.count = 0  # samples seen since the last weight update

    def after_backward(self):
        # Count this batch's samples; while fewer than `n_items` have been seen, raise
        # CancelBatchException so the weight update is skipped and gradients keep accumulating.
        self.count += find_bs(self.learn.yb)
        if self.count < self.n_items:
            raise CancelBatchException()
        # Enough samples seen: allow this batch's weight update and start counting again
        self.count = 0

    _docs = dict(
        begin_fit="Set counter to 0",
        after_backward="Skip weight update if we have not seen enough items",
    )

In [None]:
learn = synth_learner()

# Update the weights once per two batches' worth of items: training must still make progress
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_items=2*learn.dls.bs))
vals = learn.recorder.values
assert vals[-1][0] < vals[0][0]  # train_loss decreased

# Presumably more items than two epochs of the synthetic data provide, so no update ever happens
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_items=1e6))
vals = learn.recorder.values
assert vals[-1][1] == vals[0][1]  # valid_loss unchanged (same weights)