In [None]:
#default_exp fastai.learner

In [None]:
#hide
from nbdev.showdoc import *

# Learner Errors
> In-place fastai specific errors to ease debugging

In [None]:
#export
from fastdebug.torch import layer_error, device_error

from fastai.data.all import *
from fastai.optimizer import *
from fastai.learner import *
from fastai.callback.core import event
from fastai.callback.training import ShortEpochCallback
from fastai.torch_core import default_device


from fastcore.basics import patch
from fastcore.meta import delegates

This notebook contains a series of errors that can be used when running with `fastai`. Note that there are no other imports or magic needed to use this section of the library other than `from fastdebug import *`. It will automatically load in what's needed.

As a style choice, we use the `.*` notation because it not only loads in all of our errors, but also replaces sections of `fastai`'s code to inject some error handling (as we'll see later).

## Error Types

In [None]:
#export
def loss_func_error(e:Exception, learn) -> Exception:
    """
    Error that should be run when there is an issue when working with the loss function

    Raises with a message stating the shapes of the inputs and targs, and the error

    Args:
        e: The exception caught while calling the loss function. Its `args` are
            replaced in place with the expanded message.
        learn: Object exposing `loss_func`, `pred` and `yb`; every item of `pred`
            (after `listify`) and of `yb` must have a `.shape`.

    Raises:
        Exception: always re-raises `e` itself, so the original exception type is kept.
    """
    # An exception raised without arguments (e.g. `raise ValueError`) has empty `args`;
    # fall back to its repr so the real failure is not masked by an IndexError here.
    detail = e.args[0] if e.args else repr(e)
    err = f'There was an issue with calculating the loss with `{getattr(learn.loss_func, "__name__", learn.loss_func)}`'
    err += f'\n\nPrediction shape(s): {[p.shape for p in listify(learn.pred)]}'
    err += f'\nLabel Shape(s): {[y.shape for y in learn.yb]}'
    err += f'\nError: {detail}'
    e.args = [err]
    raise e

In [None]:
#export
def callback_error(e:Exception, cb:str, event_name:str) -> Exception:
    """
    Raises an error from when a Callback event failed, showing what event, the name of the Callback and the trace

    Args:
        e: The exception raised inside the callback. Its `args` are replaced in place.
        cb: Display name of the callback that failed.
        event_name: Name of the event that was being run.

    Raises:
        Exception: always re-raises `e` itself, so `except` clauses matching its type still apply.
    """
    # Exceptions raised as a bare class (e.g. `raise SomeException`) carry no args;
    # indexing `args[0]` would replace them with an IndexError of a different type.
    detail = e.args[0] if e.args else repr(e)
    e.args = [f"Exception raised in the {cb} Callback during {event_name}:\n\n{detail}"]
    raise e

In [None]:
#export
def catch_pred_errors(e:Exception, model) -> Exception:
    """
    Catches any errors relating to prediction that are either related to the device or model layers. Else raise `e`

    Args:
        e: The exception raised while calling the model.
        model: The model that was being called; forwarded to `layer_error`.

    Raises:
        Exception: `e` is always raised, either by one of the specialised helpers
            or unchanged by this function.
    """
    # `args` may be empty or hold a non-string payload; normalise to text so the
    # substring checks below cannot raise an IndexError/TypeError of their own.
    msg = str(e.args[0]) if e.args else ''
    if "Input type" in msg: device_error(e, 'Input', 'Model weights')
    elif "Expected" in msg: layer_error(e, model)
    # Anything else is re-raised as is. This also guarantees the error is never
    # swallowed should one of the helpers above return instead of raising.
    raise e

In [None]:
#export
def catch_loss_errors(e:Exception, learn):
    """
    Catches any errors that occur with the loss function and its calculation

    Args:
        e: The exception raised while computing the loss (any `Exception` subclass).
        learn: The learner whose loss computation failed; forwarded to `loss_func_error`.
    """
    # Any exception type can arrive here, so `args` may be empty or non-string
    # (e.g. `KeyError(3)`); normalise to text before the substring test.
    msg = str(e.args[0]) if e.args else ''
    if "Input type" in msg: device_error(e, 'Model prediction', 'Truths')
    else: loss_func_error(e, learn)

## Modifications and Enhancements to the fastai Source Code and `Learner`:

In [None]:
#export
@patch
def sanity_check(self:Learner, show_table=False):
    """
    Performs a short epoch and uses all the callbacks in `self.cbs` on the CPU to ensure nothing is broken

    The current weights are saved under the name `'tmp'` before the run and
    loaded back afterwards, and the original device is restored, whether or not
    the short fit succeeds.

    Args:
        show_table: If `True`, the progress bar and logged results of the short
            fit are displayed. If `False` (default), the fit runs inside
            `no_bar()` / `no_logging()`.

    Raises:
        Whatever `self.fit` raises; the exception propagates after clean-up.
    """
    uses_dls_device = hasattr(self.dls, 'device')
    device = getattr(self.dls, 'device', default_device())
    cbs = [ShortEpochCallback(short_valid=False)]
    saved = False
    try:
        if uses_dls_device:
            self.dls.device = 'cpu'
        else:
            # Using raw torch
            self.model.to('cpu')
        self.save('tmp')
        saved = True
        # Output is suppressed unless explicitly requested. The flag was
        # previously applied the wrong way round (`show_table=True` hid the table).
        if show_table:
            self.fit(1, cbs=cbs)
        else:
            with self.no_bar(), self.no_logging():
                self.fit(1, cbs=cbs)
    finally:
        # A failing fit is the very case this check exists for, so the device
        # and weights must be restored on that path too.
        if uses_dls_device:
            self.dls.device = device
        else:
            self.model.to(device)
        # Only reload if the checkpoint was actually written
        if saved: self.load('tmp')

In [None]:
#export
@patch
@delegates(Learner.sanity_check)
def __init__(self:Learner, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,
                 metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
                 moms=(0.95,0.85,0.95), sanity_check=False, **kwargs):
    """
    Group together a `model`, some `dls` and a `loss_func` to handle training, potentially run a sanity check

    When `sanity_check` is `True`, `self.sanity_check(**kwargs)` is run once the
    `after_create` event has been called; `kwargs` are forwarded to it unchanged.
    If `loss_func` is `None` it is read from `dls.train_ds.loss_func`, and an
    `AssertionError` is raised when none is found there.
    """
    path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))
    if loss_func is None:
        loss_func = getattr(dls.train_ds, 'loss_func', None)
        assert loss_func is not None, "Could not infer loss function from the data, please pass a loss function."
    self.dls,self.model = dls,model
    # `sanity_check` must be excluded: storing the flag as an instance attribute
    # would shadow the `sanity_check` method and make it uncallable.
    store_attr(but='dls,model,cbs,sanity_check')
    self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()
    self.add_cbs(L(defaults.callbacks)+L(cbs))
    self("after_create")
    if sanity_check: self.sanity_check(**kwargs)

In [None]:
show_doc(Learner.__init__)

<h4 id="Learner.__init__" class="doc_header"><code>Learner.__init__</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.__init__</code>(**`dls`**, **`model`**, **`loss_func`**=*`None`*, **`opt_func`**=*`Adam`*, **`lr`**=*`0.001`*, **`splitter`**=*`trainable_params`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*, **`moms`**=*`(0.95, 0.85, 0.95)`*, **`sanity_check`**=*`False`*, **`show_table`**=*`False`*)

Group together a `model`, some `dls` and a `loss_func` to handle training, potentially run a sanity check

In [None]:
show_doc(Learner.sanity_check)

<h4 id="Learner.sanity_check" class="doc_header"><code>Learner.sanity_check</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>Learner.sanity_check</code>(**`show_table`**=*`False`*)

Performs a short epoch and uses all the callbacks in `self.cbs` on the CPU to ensure nothing is broken

With `sanity_check`, you can make sure that you've set everything up properly and you won't get any issues before pushing to the GPU. This allows you to quickly ensure that you won't get any `CUDA` device-side assert errors, and that the whole training regimen will go well.

In [None]:
#export
@patch
def _do_one_batch(self:Learner):
    "Run the model and loss on the current batch, routing failures through `catch_pred_errors` / `catch_loss_errors`"
    # Forward pass: only `RuntimeError`s are handed to the prediction error handler
    try: self.pred = self.model(*self.xb)
    except RuntimeError as pred_err: catch_pred_errors(pred_err, self.model)
    self('after_pred')
    # The loss is only computed when the batch carries targets
    if len(self.yb):
        try: self.loss_grad = self.loss_func(self.pred, *self.yb)
        except Exception as loss_err: catch_loss_errors(loss_err, self)
        self.loss = self.loss_grad.clone()
    self('after_loss')
    # Backward pass and optimizer step happen only while training on a batch with targets
    if self.training and len(self.yb):
        self('before_backward')
        self.loss_grad.backward()
        self._with_events(self.opt.step, 'step', CancelStepException)
        self.opt.zero_grad()

In [None]:
#export
@patch
def _call_one(self:Learner, event_name):
    "Call `event_name` on every callback in `order`, reporting any failure through `callback_error`"
    if not hasattr(event, event_name):
        raise Exception(f'missing {event_name}')
    for callback in self.cbs.sorted('order'):
        # Any exception from the callback is re-raised with the callback and event named
        try: callback(event_name)
        except Exception as err: callback_error(err, callback.__repr__(), event_name)

In [None]:
#export
def module_error(e:AttributeError) -> AttributeError:
    """
    Raises an error when trying to load in a previous `Learner` and custom functions were not available in the namespace

    Args:
        e: The `AttributeError` caught while loading. Its `args` are replaced in
            place with an explanatory message followed by the original detail.

    Raises:
        AttributeError: always re-raises `e` itself.
    """
    # Guard against empty or non-string args so that building the message can
    # never raise an IndexError/TypeError in place of the original error.
    detail = str(e.args[0]) if e.args else repr(e)
    err = 'Custom classes or functions exported with your `Learner` are not available in the namespace currently.\n'
    err += 'Please re-declare them before calling `load_learner`:\n'
    err += detail
    e.args = [err]
    raise e

In [None]:
#export
def load_learner(fname, cpu=True, pickle_module=pickle):
    """
    Load a `Learner` object in `fname`, optionally putting it on the `cpu`

    An `AttributeError` raised while loading is passed to `module_error`, which
    re-raises it with a hint to re-declare the missing custom code.

    NOTE: the file is unpickled via `pickle_module`, so only load files from a trusted source.
    """
    distrib_barrier()
    try:
        res = torch.load(
            fname,
            map_location='cpu' if cpu else None,
            pickle_module=pickle_module,
        )
    except AttributeError as err:
        module_error(err)
    # Objects exposing `to_fp32` are replaced by the result of calling it
    if hasattr(res, 'to_fp32'):
        res = res.to_fp32()
    if cpu:
        res.dls.cpu()
    return res

We have a custom `load_learner` function here that checks whether everything that was exported is available when bringing the model in. If it is not, an explicit error is raised.