# Custom fastai callback for debugging purposes
This is a self-contained notebook that shows how to create your own [fastai callback](https://docs.fast.ai/callback.html). It is deliberately bare-bones; see my [blog post](https://laurenth.me/2019/06/10/custom-fastai-callbacks) for a more thorough walkthrough.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.vision import *
from fastai.basic_data import DataBunch
from fastai.basic_train import LearnerCallback

from IPython.core.debugger import set_trace
from pathlib import Path

import torch
from torch.utils.data import TensorDataset, DataLoader

Let's quickly set up some dummy data so we don't have to download a whole dataset just for demonstration purposes.

In [3]:
x = torch.ones(16,3,128,128)
y = torch.ones(16,1,128,128)

In [4]:
train_set = TensorDataset(x,y)
val_set = TensorDataset(x,y)

Make sure the batch size is greater than 1, or the BatchNorm layers in the model will raise an error in training mode.

In [5]:
train_dl = DataLoader(train_set, 4)
valid_dl = DataLoader(val_set, 4)
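
As a quick sanity check, a loader can be iterated outside of fastai to confirm the batch shapes. The snippet below is a small sketch that rebuilds the same dummy tensors so it runs on its own. With 16 samples and a batch size of 4 it should yield 4 batches.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Same dummy tensors as above, rebuilt so this cell runs on its own
x = torch.ones(16, 3, 128, 128)
y = torch.ones(16, 1, 128, 128)
train_dl = DataLoader(TensorDataset(x, y), batch_size=4)

# Pull a single batch and look at its shape
xb, yb = next(iter(train_dl))
print(len(train_dl), tuple(xb.shape), tuple(yb.shape))
```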

In [6]:
data = DataBunch(train_dl, valid_dl)

# cnn_learner reads data.c (normally the number of classes) to size the custom head.
# The exact value doesn't matter for this demo, it just has to be set.
data.c = 2

Next, set up an example learner. We'll skip downloading the pretrained weights because this is only a demonstration.

In [7]:
learn = cnn_learner(data, models.resnet34, metrics=accuracy, pretrained=False)

The meat of it: setting up the custom callback. The idea is to debug the inputs going into the loss function, so we'll drop a `set_trace()` in `on_loss_begin()`.

In [8]:
class LossDebug(LearnerCallback):
    def __init__(self, learn:Learner):
        super().__init__(learn)
        
    def on_loss_begin(self, last_output, last_target, **kwargs):
        set_trace()
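
To make it clearer why `on_loss_begin` receives `last_output` and `last_target`, here is a toy, pure-Python sketch of the dispatch idea. It is not fastai's actual code: `ToyCallback`, `RecordInputs` and `ToyLoop` are made-up names, and the real callback handler passes many more values, which is why the method signature ends in `**kwargs`.

```python
class ToyCallback:
    "Hypothetical stand-in for a callback base class."
    def on_loss_begin(self, **kwargs): pass

class RecordInputs(ToyCallback):
    "Records what would be passed to the loss function instead of opening a debugger."
    def __init__(self): self.seen = []
    def on_loss_begin(self, last_output, last_target, **kwargs):
        # Only the two named arguments are of interest; everything else lands in kwargs
        self.seen.append((last_output, last_target))

class ToyLoop:
    "Hypothetical, heavily simplified training loop (no gradients, no optimizer)."
    def __init__(self, model, loss_func, callbacks):
        self.model, self.loss_func, self.callbacks = model, loss_func, callbacks

    def one_batch(self, xb, yb, epoch=0):
        out = self.model(xb)
        # Every callback gets the full state as keyword arguments right before the loss
        state = dict(last_input=xb, last_output=out, last_target=yb, epoch=epoch)
        for cb in self.callbacks: cb.on_loss_begin(**state)
        return self.loss_func(out, yb)

recorder = RecordInputs()
loop = ToyLoop(model=lambda v: v * 2,
               loss_func=lambda out, tgt: abs(out - tgt),
               callbacks=[recorder])
loss = loop.one_batch(3, 5)
print(loss, recorder.seen)
```

Swapping the body of `on_loss_begin` for `set_trace()` gives the same behaviour as `LossDebug`: execution pauses right before the loss is computed, with both values in scope.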

In [None]:
learn.fit(1, callbacks=[LossDebug(learn)])

epoch,train_loss,valid_loss,accuracy,time


--Return--
None
> <ipython-input-8-a746fb510396>(6)on_loss_begin()
      2     def __init__(self, learn:Learner):
      3         super().__init__(learn)
      4 
      5     def on_loss_begin(self, last_output, last_target, **kwargs):
----> 6         set_trace()
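
Once the debugger prompt appears, `last_output` and `last_target` are plain tensors in scope, so they can be inspected directly. A small helper like the one below can save some typing. Note that `describe_tensor` is a hypothetical name made up for this notebook, not something fastai provides, and the example tensor merely mimics the dummy targets used here.

```python
import torch

def describe_tensor(t):
    "Summarize a tensor: shape, dtype, value range and whether it contains NaNs."
    t = t.detach()
    return {
        'shape': tuple(t.shape),
        'dtype': t.dtype,
        'min': t.min().item(),
        'max': t.max().item(),
        'has_nan': bool(torch.isnan(t).any()),
    }

# Stand-in for `last_target`: one batch of the dummy masks
target = torch.ones(4, 1, 128, 128)
summary = describe_tensor(target)
print(summary)
```

If the helper is defined in an earlier cell, it can be called from the debugger prompt as `describe_tensor(last_output)` or `describe_tensor(last_target)`.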
