In [None]:
#default_exp callback.core

In [None]:
#hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#export
from warnings import warn

from fastcore.basics import mk_class, class2attr, GetAttr, Stateful, store_attr, detuplify
from fastcore.foundation import L
from fastcore.meta import funcs_kwargs

from fastai_annotated.foundations import DocumentedEnum, noop

In [None]:
#export
class Events(DocumentedEnum):
    "All possible Callback events"
    after_create = "Called after the `Learner` is created"
    before_fit = "Called before starting training or inference, ideal for initial setup"
    before_epoch = "Called at the beginning of each epoch, useful for any behavior you need to reset at each epoch"
    before_train = "Called at the beginning of each training part of an epoch"
    before_batch = "Called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance)."
    after_pred = "Called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss."
    after_loss = "Called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance)."
    before_backward = "Called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used)"
    before_step = "Called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance)."
    after_step = "Called after the step and before the gradients are zeroed."
    after_batch = "Called at the end of a batch, for any clean-up before the next one."
    after_cancel_train = "Reached immediately after a `CancelTrainException` before proceeding to `after_epoch`"
    after_train = "Called at the end of the training phase of an epoch."
    before_validate = "Called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation."
    after_cancel_validate = "Reached immediately after a `CancelValidException` before proceeding to `after_epoch`"
    after_validate = "Called at the end of the validation part of an epoch."
    after_cancel_epoch = "Reached immediately after a `CancelEpochException` before proceeding to `after_epoch`"
    after_epoch = "Called at the end of an epoch, for any clean-up before the next one."
    after_cancel_fit = "Reached immediately after a `CancelFitException` before proceeding to `after_fit`"
    after_fit = "Called at the end of training, for final clean-up."

In [None]:
# Render the documentation header for the `Events` enum itself
show_doc(Events)

<h2 id="Events" class="doc_header"><code>class</code> <code>Events</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>Events</code>(**`value`**, **`names`**=*`None`*, **`module`**=*`None`*, **`qualname`**=*`None`*, **`type`**=*`None`*, **`start`**=*`1`*) :: [`DocumentedEnum`](/fastai_annotated/core.html#DocumentedEnum)

All possible Callback events

In [None]:
#hide_input
# Render every event member's documentation as a level-4 heading, in declaration order
for event in list(Events):
    show_doc(event, title_level=4)

<h4 id="after_create" class="doc_header"><code>after_create</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called after the `Learner` is created

<h4 id="before_fit" class="doc_header"><code>before_fit</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called before starting training or inference, ideal for initial setup

<h4 id="before_epoch" class="doc_header"><code>before_epoch</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the beginning of each epoch, useful for any behavior you need to reset at each epoch

<h4 id="before_train" class="doc_header"><code>before_train</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the beginning of each training part of an epoch

<h4 id="before_batch" class="doc_header"><code>before_batch</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).

<h4 id="after_pred" class="doc_header"><code>after_pred</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.

<h4 id="after_loss" class="doc_header"><code>after_loss</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).

<h4 id="before_backward" class="doc_header"><code>before_backward</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used)

<h4 id="before_step" class="doc_header"><code>before_step</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).

<h4 id="after_step" class="doc_header"><code>after_step</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called after the step and before the gradients are zeroed.

<h4 id="after_batch" class="doc_header"><code>after_batch</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the end of a batch, for any clean-up before the next one.

<h4 id="after_cancel_train" class="doc_header"><code>after_cancel_train</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Reached immediately after a `CancelTrainException` before proceeding to `after_epoch`

<h4 id="after_train" class="doc_header"><code>after_train</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the end of the training phase of an epoch.

<h4 id="before_validate" class="doc_header"><code>before_validate</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.

<h4 id="after_cancel_validate" class="doc_header"><code>after_cancel_validate</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Reached immediately after a `CancelValidException` before proceeding to `after_epoch`

<h4 id="after_validate" class="doc_header"><code>after_validate</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the end of the validation part of an epoch.

<h4 id="after_cancel_epoch" class="doc_header"><code>after_cancel_epoch</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Reached immediately after a `CancelEpochException` before proceeding to `after_epoch`

<h4 id="after_epoch" class="doc_header"><code>after_epoch</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the end of an epoch, for any clean-up before the next one.

<h4 id="after_cancel_fit" class="doc_header"><code>after_cancel_fit</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Reached immediately after a `CancelFitException` before proceeding to `after_fit`

<h4 id="after_fit" class="doc_header"><code>after_fit</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Called at the end of training, for final clean-up.

## Callback -

In [None]:
#export
_inner_loop = "before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch".split()

In [None]:
#export
@funcs_kwargs(as_method=True)
class Callback(Stateful,GetAttr):
    """
    Basic class handling tweaks of the training loop by changing a `Learner` in various events.

    To use, implement any supported event in `Events` that should be called, either by
    subclassing or by passing the function as a keyword argument named after the event
    (e.g. `Callback(before_fit=f)`), in which case `f` is bound as a method and takes `self`.
    An `order` can be passed to dictate its call priority.

    Attribute reads that fail on the callback are forwarded to `self.learn` (via `GetAttr`);
    attribute writes are not forwarded, and warn when they shadow an attribute of `self.learn`.
    """
    # `_default` names the attribute `GetAttr` forwards lookups to; `learn` stays None until a
    # `Learner` is attached; `run`/`run_train`/`run_valid` switch the callback on or off.
    order,_default,learn,run,run_train,run_valid = 0,'learn',None,True,True,True
    # Event names that `funcs_kwargs` accepts as constructor keyword arguments
    _methods = [e.name for e in Events]

    def __init__(self, **kwargs): assert not kwargs, f'Passed unknown events: {kwargs}'
    def __repr__(self): return type(self).__name__

    def __call__(self, event_name):
        """
        Call `self.{event_name}` if it's defined and return its result (None otherwise).

        `event_name` can be a string or an `Events` member. Events listed in `_inner_loop` only
        run when enabled for the current phase (`run_train` while training, `run_valid` otherwise).
        Exceptions raised by the event are re-raised with a message naming the event and callback.
        """
        # Normalize `Events` members to their name so the lookup and comparisons below use strings;
        # plain strings have no `.name` and pass through unchanged
        event_name = getattr(event_name, 'name', event_name)
        # When `training` cannot be resolved (e.g. no learner attached), either flag enables the event
        _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
               (self.run_valid and not getattr(self, 'training', False)))
        res = None
        if self.run and _run:
            try:
                res = getattr(self, event_name, noop)()
            except Exception as e:
                # `e.args` may be empty (e.g. `raise ValueError()`); indexing it blindly would raise
                # an IndexError that hides the original exception
                msg = e.args[0] if e.args else ''
                e.args = [f'Exception occured when calling event `{event_name}` in `{self.name}`:\n\t{msg}']
                raise
        if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
        return res

    def __setattr__(self, name, value):
        "Set `name` on the callback itself, warning first if it shadows an attribute of `self.learn`"
        if hasattr(self.learn,name):
            warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
        super().__setattr__(name, value)

    @property
    def name(self):
        "Name of the `Callback`: fastcore's `class2attr` of the class with '*Callback*' removed, then title-cased"
        return class2attr(self, 'Callback').title()

One way to define callbacks is through subclassing:

In [None]:
class _T(Callback):
    "Minimal subclass whose custom `call_me` event returns a fixed value"
    def call_me(self):
        return "maybe"

test_eq(_T()("call_me"), "maybe")

Another way is to pass the callback function to the constructor, as a keyword argument named after the event:
> Note: `cb` still takes `self` as its first parameter: it is bound to the `Callback` instance as a method.

In [None]:
# Passed as a constructor keyword, `cb` is bound to the instance, hence the `self` parameter
def cb(self):
    return "maybe"

_t = Callback(before_fit=cb)
test_eq(_t(Events.before_fit), "maybe")

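`Callback` provides a shortcut so you do not have to write `self.learn.bla` for every `bla` attribute you need from the `Learner`; just write `self.bla` instead.
> Note: This only works for **getting** attributes, *not* for setting them. Assigning `self.bla = ...` creates a new attribute on the callback itself, and warns if that shadows an attribute of the `Learner`. Use `self.learn.bla = ...` to modify the `Learner`.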

In [None]:
# `TstLearner` is a minimal stand-in for a `Learner`, with a single attribute `a`
mk_class('TstLearner', 'a')

class TstCallback(Callback):
    # Use a real event from `Events`; `self.a` is not defined on the callback, so the
    # lookup is forwarded to `self.learn.a`
    def before_batch(self): print(self.a)

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
test_stdout(lambda: cb('before_batch'), "1")