In [None]:
# default_exp torch

# Pytorch Errors

> All the possible errors that fastdebug can support and verbosify involving Pytorch

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.test import test_eq

In [None]:
#export
import torch
import re
from fastai.callback.hook import Hook
from fastai.torch_core import to_detach
from fastai.layers import flatten_model

from fastcore.basics import store_attr

  return torch._C._cuda_getDeviceCount() > 0


## Errors

While some errors are specifically designed for the [fastai](https://docs.fast.ai) library, the general idea still holds true in raw `Pytorch` as well.

In [None]:
#export
def device_error(e:Exception, a:str, b:str) -> Exception:
    """
    Verbose error for if `a` and `b` are on different devices
    Should be used when checking if a model is on the same device, or two tensors

    Parses a PyTorch message of the form
    "Input type (X) and weight type (Y) should be the same..." and re-raises `e`
    with a message that names `a` (the input side) and `b` (the weight side)
    together with their types.

    Always raises `e`. If its message does not have the form above, `e` is
    re-raised unchanged, so the original error is never lost.
    """
    msg = str(e.args[0]) if e.args else ''
    found = re.search(
        r'Input type\s*\(\s*([^()]*?)\s*\)\s*and weight type\s*\(\s*([^()]*?)\s*\)', msg)
    if found is None:
        # Not a message we know how to reformat: keep the original error intact
        # instead of failing here with an unrelated unpacking error.
        raise e
    inp_type, weight_type = found.groups()
    err = (f'Mismatch between weight types\n\n'
           f'{a} has type: \t\t ({inp_type})\n'
           f'{b} have type: \t ({weight_type})\n\n'
           'Both should be the same.')
    e.args = [err]
    raise e

The device error provides a much more readable error when `a` and `b` are on two different devices. An example situation is below:
```python
inp = torch.rand().cuda()
model = model.cpu()
try:
    _ = model(inp)
except Exception as e:
    device_error(e, 'Input type', 'Model weights')
```
And our new log:
```bash
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-28-981e0ace9c38> in <module>()
      2     model(x)
      3 except Exception as e:
----> 4     device_error(e, 'Input type', 'Model weights')

10 frames
/usr/local/lib/python3.7/dist-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
    993 
    994         with _C.DisableTorchFunction():
--> 995             ret = func(*args, **kwargs)
    996             return _convert(ret, cls)
    997 

RuntimeError: Mismatch between weight types

Input type has type: 		 (torch.cuda.FloatTensor)
Model weights have type: 	 (torch.FloatTensor)

Both should be the same.
```

In [None]:
#export
def hook_fn(m, i):
    "Hook function that ignores the hooked input `i` and hands back the layer `m` itself"
    layer = m
    return layer

In [None]:
#export
class PreHook(Hook):
    """
    Creates and registers a pre-hook on `m` that stores `hook_func(module, inp)` in `self.stored`.

    With `is_forward=True` it is registered as a forward pre-hook (`inp` is the module's input).
    Otherwise it is registered as a full backward pre-hook (`inp` is the gradient with respect
    to the module's output). `super().__init__` is intentionally skipped: this constructor does
    the registration itself and sets `hook`, `stored` and `removed` (presumably what
    `Hook.remove` relies on -- confirm against fastai's `Hook`).
    """
    def __init__(self, m, hook_func, is_forward=True, detach=True, cpu=False, gather=False):
        store_attr('hook_func,detach,cpu,gather')
        # `nn.Module` has no `register_backward_pre_hook`; the backward pre-hook API is
        # `register_full_backward_pre_hook`.
        f = m.register_forward_pre_hook if is_forward else m.register_full_backward_pre_hook
        self.hook = f(self.hook_fn)
        self.stored,self.removed = None, False

    def hook_fn(self, module, inp):
        "Optionally detach `inp`, then store `hook_func(module, inp)` in `self.stored`"
        if self.detach:
            inp = to_detach(inp, cpu=self.cpu, gather=self.gather)
        # Returning None leaves the module's input (or gradient) unchanged.
        self.stored = self.hook_func(module, inp)

In [None]:
#export
class ForwardHooks():
    """
    Create a `PreHook` with `hook_func` on every module returned by `flatten_model(ms)`.

    The hooks are kept, in that order, in `self.hooks`. Call `remove()` (or use the object
    as a context manager) to detach them all from their modules.
    """
    def __init__(self, ms, hook_func, is_forward=True, detach=True, cpu=False, gather=False):
        self.hooks = [PreHook(m, hook_func, is_forward, detach, cpu, gather)
                      for m in flatten_model(ms)]

    def __getitem__(self, i): return self.hooks[i]
    def __len__(self): return len(self.hooks)
    def __iter__(self): return iter(self.hooks)

    @property
    def stored(self):
        "The `stored` value of every hook, in the same order as `self.hooks`"
        return [h.stored for h in self.hooks]

    def remove(self):
        "Remove every hook from the module it was registered on"
        for h in self.hooks:
            h.remove()

    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()

In [None]:
#export
def hook_outputs(modules, detach=True, cpu=False, grad=False):
    """
    Return a `ForwardHooks` that pre-hooks every layer of `modules` with `hook_fn`.

    After a forward pass (or a backward pass when `grad=True`), each hook's `stored`
    holds the layer it is attached to if that layer was reached, otherwise `None`.
    """
    forward = not grad
    return ForwardHooks(modules, hook_fn, is_forward=forward, detach=detach, cpu=cpu)

By using forward pre-hooks, we can locate the problem layer as the input reaches it, rather than trying to work out which one it is from a confusing error message.

For this tutorial and testing we'll purposefully write a broken model:

In [None]:
from torch import nn

# Small Conv2d -> ReLU -> Linear stack, used below to trigger shape errors on purpose
m = nn.Sequential(nn.Conv2d(3, 3, 1), nn.ReLU(), nn.Linear(3, 2))

In [None]:
#export
def layer_error(e:Exception, model, *inp) -> Exception:
    """
    Verbose error for when there is a size mismatch between some input and the model.
    `model` should be any torch model
    `inp` is the input that went to the model

    Re-runs `model(*inp)` with pre-hooks (via `hook_outputs`) on every module returned
    by `flatten_model`. It then re-raises `e` with a message naming the last hooked module
    that was reached, in `flatten_model` order, together with its index in that order.
    If no hooked module was reached, `e` is re-raised unchanged. Always raises.
    """
    args = str(e.args[0]).replace("Expected", "Model expected") if e.args else str(e)
    hooks = hook_outputs(model)
    try:
        _ = model(*inp)
    except Exception:
        pass  # the failure is expected; the hooks have recorded how far the pass got
    finally:
        # Always detach the hooks, even if something unexpected escapes the re-run.
        for hook in hooks.hooks:
            hook.remove()
    # (index, layer) for every hooked layer whose pre-hook fired during the re-run
    reached = [(idx, hook.stored) for idx, hook in enumerate(hooks.hooks)
               if hook.stored is not None]
    if not reached:
        # Nothing to point at: keep the original error instead of an IndexError.
        raise e
    idx, layer = reached[-1]
    e.args = [f'Size mismatch between input tensors and what the model expects\n{"-"*76}\nLayer: {idx}, {layer}\nError: {args}']
    raise e

`layer_error` can be used anywhere that you want to check that the inputs are right for some model.

Let's use our `m` model from earlier to show an example:

In [None]:
#failing
inp = torch.rand(5, 2, 3)
try:
    m(inp)
except Exception as err:
    # Re-raise with the layer the input was stopped at
    layer_error(err, m, inp)

RuntimeError: Size mismatch between input tensors and what the model expects
----------------------------------------------------------------------------
Layer: 2, Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
Error: Model expected 4-dimensional input for 4-dimensional weight [3, 3, 1, 1], but got 3-dimensional input of size [5, 2, 3] instead

This will also work with multi-input and multi-output models:

In [None]:
class DoubleInputModel(nn.Module):
    "Toy model that runs two inputs through one shared Conv2d -> ReLU -> Linear stack"
    def __init__(self):
        super().__init__()
        # A plain `nn.Module` is the right base: this model has its own two-input
        # `forward`, so `nn.Sequential`'s container semantics do not describe it.
        self.layers = nn.Sequential(
            nn.Conv2d(3,3,1),
            nn.ReLU(),
            nn.Linear(3,2),
        )

    def forward(self, a, b):
        "Apply the shared `layers` to `a` and `b` separately and return both outputs"
        return self.layers(a), self.layers(b)

In [None]:
# Instance of the two-input demo model, used by the failing cell below
model = DoubleInputModel()

In [None]:
#failing
inp = torch.rand(5, 2, 3)
try:
    model(inp, inp)
except Exception as err:
    # The same tensor is fed to both inputs; re-raise with the layer it was stopped at
    layer_error(err, model, inp, inp)

RuntimeError: Size mismatch between input tensors and what the model expects
----------------------------------------------------------------------------
Layer: 2, Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
Error: Model expected 4-dimensional input for 4-dimensional weight [3, 3, 1, 1], but got 3-dimensional input of size [5, 2, 3] instead

Much more readable!