In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.layers import *
from local.data.pipeline import *
from local.data.source import *
from local.data.core import *
from local.data.external import *
from local.notebook.showdoc import show_doc
from local.optimizer import *
from local.learner import *

In [None]:
#default_exp callback.hook

# Model hooks

> Callback and helper function to add hooks in models

## Synthetic data

We'll use the following for testing purposes (a basic linear regression problem):

In [None]:
from torch.utils.data import TensorDataset

def synth_data(a=2, b=3, bs=16, n_trn=10, n_val=2):
    "Noisy samples of `y = a*x + b` as a `DataBunch` with `n_trn` training and `n_val` validation batches of size `bs`."
    def _noisy_line(n_batches):
        # Draw the inputs first, then add Gaussian noise (std 0.1) on top of the exact line.
        xs = torch.randn(bs*n_batches, 1)
        ys = a*xs + b + 0.1*torch.randn(bs*n_batches, 1)
        return TensorDataset(xs, ys)
    # Training set is drawn before the validation set, so a fixed seed gives the same data as before.
    train_ds,valid_ds = _noisy_line(n_trn),_noisy_line(n_val)
    # Only the training set is shuffled.
    return DataBunch(TfmdDL(train_ds, bs=bs, shuffle=True), TfmdDL(valid_ds, bs=bs))

In [None]:
class RegModel(nn.Module):
    "Scalar linear model `y = a*x + b` with randomly initialised learnable `a` and `b`."
    def __init__(self):
        super().__init__()
        # `a` is drawn before `b`, so seeding gives the same parameters as before.
        self.a = nn.Parameter(torch.randn(1))
        self.b = nn.Parameter(torch.randn(1))

    def forward(self, x):
        "Apply the affine map elementwise to `x`."
        return self.a*x + self.b

In [None]:
def synth_learner(n_trn=10, n_val=2, **kwargs):
    "`Learner` fitting a `RegModel` on `synth_data` with `MSELossFlat` and SGD with momentum 0.9; `kwargs` go to `Learner`."
    # Keep the original evaluation order (model, then data, then loss) so seeded runs are unchanged.
    model = RegModel()
    data = synth_data(n_trn=n_trn, n_val=n_val)
    loss_func = MSELossFlat()
    return Learner(model, data, loss_func, opt_func=partial(SGD, mom=0.9), **kwargs)

## What are hooks?

Hooks are functions you can attach to a particular layer of your model, and that will be executed in the forward pass (for forward hooks) or in the backward pass (for backward hooks).

Forward hooks are functions that take three arguments: the layer they're applied to, the input of that layer (passed as a tuple), and the output of that layer.

In [None]:
tst_model = nn.Linear(5,3)

def example_forward_hook(m,i,o):
    "Print the layer `m`, its inputs `i` (a tuple) and its output `o`."
    print(m, i, o)

x = torch.randn(4,5)
hook = tst_model.register_forward_hook(example_forward_hook)
y = tst_model(x)    # triggers the hook, which prints
hook.remove()       # the handle returned by `register_forward_hook` detaches the hook again

Linear(in_features=5, out_features=3, bias=True) (tensor([[ 1.8137, -0.3405, -1.8038, -1.2162,  3.0698],
        [-0.3961, -0.1697,  0.1833, -0.9861, -0.0687],
        [ 2.2365,  2.4447, -1.3131,  0.4555,  1.7607],
        [ 0.7440, -0.1006, -1.2119,  0.6596,  1.0360]]),) tensor([[ 2.5172,  0.7680, -0.0755],
        [-0.1359, -0.2236,  0.1336],
        [ 1.6839,  0.9853, -0.0560],
        [ 1.2432,  0.6146,  0.3685]], grad_fn=<AddmmBackward>)


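Backward hooks are functions that take three arguments: the layer they're applied to, the gradients of the loss with respect to the input, and the gradients with respect to the output. Both gradient arguments are tuples. Note that with `register_backward_hook`, the "input" gradients of a layer like `nn.Linear` are actually the gradients of the inputs of the last operation the layer performs. In the example below (`addmm`), they are the gradient of the bias, `None` for the input (since `x` doesn't require grad), and the gradient of the transposed weight.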

In [None]:
def example_backward_hook(m,gi,go):
    "Print the layer `m`, then the gradient tuples `gi` and `go` that PyTorch passes for its inputs and outputs."
    print(m, gi, go)
hook = tst_model.register_backward_hook(example_backward_hook)

x = torch.randn(4,5)
y = tst_model(x)
loss = (y**2).mean()
loss.backward()     # the hook fires during this backward pass
hook.remove()

Linear(in_features=5, out_features=3, bias=True) (tensor([-0.0441,  0.3203, -0.1670]), None, tensor([[ 0.0263,  0.1888, -0.0427],
        [-0.2727,  0.2828, -0.1494],
        [-0.3649,  0.3190, -0.2619],
        [-0.3129,  0.0943,  0.0154],
        [ 0.4231, -0.1366, -0.0128]])) (tensor([[-0.0960,  0.1345, -0.0605],
        [-0.1369,  0.0697, -0.0723],
        [ 0.1337,  0.0700, -0.0641],
        [ 0.0551,  0.0461,  0.0299]]),)


Hooks can change the input/output of a layer or its gradients, or print values or shapes. If you want to store something related to these inputs/outputs, it's best to associate your hook with a class, so that it can save it in the state of an instance of that class.

## Hook -

In [None]:
@docs
class Hook():
    "Create a hook on `m` with `hook_func`."
    def __init__(self, m, hook_func, is_forward=True, detach=True, cpu=False):
        self.hook_func,self.detach,self.cpu,self.stored = hook_func,detach,cpu,None
        # The hook is registered right here, so it's active as soon as the `Hook` exists
        # (not only inside a `with` block).
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module, input, output):
        "Applies `hook_func` to `module`, `input`, `output`."
        # `cpu` only has an effect when `detach=True`: both go through `to_detach`.
        if self.detach: input,output = to_detach(input, cpu=self.cpu),to_detach(output, cpu=self.cpu)
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Remove the hook from the model."
        # Safe to call several times: the handle is only removed once.
        if not self.removed:
            self.hook.remove()
            self.removed=True

    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()

    _docs = dict(__enter__="Return the hook (it is already registered by `__init__`)",
                 __exit__="Remove the hook")

`hook_func` will be called during the forward pass if `is_forward=True`, and during the backward pass otherwise. Before it is called, the input/output of the module (or their gradients, for a backward hook) are optionally detached (`detach=True`). When they are detached, they are also put on the CPU if `cpu=True`. The result of `hook_func` is stored in the `stored` attribute of the `Hook`.

In [None]:
tst_model = nn.Linear(5,3)
# Define the input here rather than relying on the `x` left over from the backward-hook cell.
x = torch.randn(4,5)
hook = Hook(tst_model, lambda m,i,o: o)
y = tst_model(x)
test_eq(hook.stored, y)
# Don't leave the hook registered on `tst_model` once the check is done.
hook.remove()

In [None]:
show_doc(Hook.hook_fn)

<h4 id="<code>Hook.hook_fn</code>" class="doc_header"><code>Hook.hook_fn</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L10" class="source_link" style="float:right">[source]</a></h4>

> <code>Hook.hook_fn</code>(**`module`**, **`input`**, **`output`**)

Applies `hook_func` to `module`, `input`, `output`.

In [None]:
show_doc(Hook.remove)

<h4 id="<code>Hook.remove</code>" class="doc_header"><code>Hook.remove</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L15" class="source_link" style="float:right">[source]</a></h4>

> <code>Hook.remove</code>()

Remove the hook from the model.

> Note: It's important to properly remove your hooks from your model when you're done. Otherwise they will be called again the next time your model is applied to some inputs, and the memory that goes with their state won't be freed.

In [None]:
tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
y = tst_model(x)
expected = f"{tst_model} [{x}] {y.detach()}"
hook = Hook(tst_model, example_forward_hook)
# While registered, every forward pass prints...
test_stdout(lambda: tst_model(x), expected)
hook.remove()
# ...and once removed, nothing is printed anymore.
test_stdout(lambda: tst_model(x), "")

### Context Manager

Since it's very important to remove your `Hook` even if your code is interrupted by some bug, `Hook` can be used as a context manager. The hook is registered when the `Hook` is created and removed when the `with` block exits.

In [None]:
show_doc(Hook.__enter__)

<h4 id="<code>Hook.__enter__</code>" class="doc_header"><code>Hook.__enter__</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L21" class="source_link" style="float:right">[source]</a></h4>

> <code>Hook.__enter__</code>(**\*`args`**)

Register the hook

In [None]:
show_doc(Hook.__exit__)

<h4 id="<code>Hook.__exit__</code>" class="doc_header"><code>Hook.__exit__</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L22" class="source_link" style="float:right">[source]</a></h4>

> <code>Hook.__exit__</code>(**\*`args`**)

Remove the hook

In [None]:
tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
y = tst_model(x)
printed = f"{tst_model} [{x}] {y.detach()}"
with Hook(tst_model, example_forward_hook) as h:
    test_stdout(lambda: tst_model(x), printed)
# Leaving the `with` block removed the hook.
test_stdout(lambda: tst_model(x), "")

In [None]:
def _hook_inner(m,i,o):
    "Return `o` unchanged if it's a tensor or list-like, otherwise `list(o)`."
    if isinstance(o,Tensor) or is_listy(o): return o
    return list(o)

def hook_output(module, detach=True, cpu=False, grad=False):
    "Return a `Hook` that stores activations of `module` in `self.stored`"
    # With `grad=True` the hook is a backward hook, so it stores gradients instead of outputs.
    return Hook(module, _hook_inner, is_forward=not grad, detach=detach, cpu=cpu)

The activations stored are the gradients if `grad=True`, otherwise the output of `module`. If `detach=True` they are detached from their history, and if `cpu=True`, they're put on the CPU.

In [None]:
tst_model = nn.Linear(5,10)
x = torch.randn(4,5)

# Forward: `stored` is the layer's output, detached from the graph.
with hook_output(tst_model) as h:
    y = tst_model(x)
    test_eq(y, h.stored)
    assert not h.stored.requires_grad

# Backward: `stored[0]` is the gradient w.r.t. the output; for mean(y**2) that's 2*y/numel.
with hook_output(tst_model, grad=True) as h:
    y = tst_model(x)
    loss = y.pow(2).mean()
    loss.backward()
    expected_grad = 2*y / y.numel()
    test_close(expected_grad, h.stored[0])

In [None]:
#cuda
with hook_output(tst_model, cpu=True) as h:
    y = tst_model.cuda()(x.cuda())
    # The model runs on the GPU but, with `cpu=True`, the stored activation is put on the CPU.
    stored_device = h.stored.device
    test_eq(stored_device, torch.device('cpu'))

## Hooks -

In [None]:
@docs
class Hooks():
    "Create several hooks on the modules in `ms` with `hook_func`."
    def __init__(self, ms, hook_func, is_forward=True, detach=True, cpu=False):
        # Each `Hook` is registered as soon as it's created. If creating one fails (e.g. an element
        # of `ms` without `register_forward_hook`), remove those already registered before re-raising.
        self.hooks = []
        try:
            for m in ms: self.hooks.append(Hook(m, hook_func, is_forward, detach, cpu))
        except BaseException:
            self.remove()
            raise

    def __getitem__(self,i): return self.hooks[i]
    def __len__(self):       return len(self.hooks)
    def __iter__(self):      return iter(self.hooks)
    @property
    def stored(self):        return [o.stored for o in self]

    def remove(self):
        "Remove the hooks from the model."
        for h in self.hooks: h.remove()

    def __enter__(self, *args): return self
    def __exit__ (self, *args): self.remove()

    _docs = dict(stored = "The states saved in each hook.",
                 __enter__="Return the hooks (they are already registered by `__init__`)",
                 __exit__="Remove the hooks")

In [None]:
layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
# Iterating over an `nn.Sequential` yields its layers, so this registers one hook per layer.
hooks = Hooks(tst_model, lambda m,i,o: o)
y = tst_model(x)
lin_out = layers[0](x)
test_eq(hooks.stored[0], lin_out)
test_eq(hooks.stored[1], F.relu(lin_out))
test_eq(hooks.stored[2], y)
hooks.remove()

In [None]:
show_doc(Hooks.stored, name='Hooks.stored')

<h4 id="<code>Hooks.stored</code>" class="doc_header"><code>Hooks.stored</code><a href="" class="source_link" style="float:right">[source]</a></h4>

The states saved in each hook.

In [None]:
show_doc(Hooks.remove)

<h4 id="<code>Hooks.remove</code>" class="doc_header"><code>Hooks.remove</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L13" class="source_link" style="float:right">[source]</a></h4>

> <code>Hooks.remove</code>()

Remove the hooks from the model.

### Context Manager

Like `Hook`, `Hooks` can be used as a context manager.

In [None]:
show_doc(Hooks.__enter__)

<h4 id="<code>Hooks.__enter__</code>" class="doc_header"><code>Hooks.__enter__</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L17" class="source_link" style="float:right">[source]</a></h4>

> <code>Hooks.__enter__</code>(**\*`args`**)

Register the hooks

In [None]:
show_doc(Hooks.__exit__)

<h4 id="<code>Hooks.__exit__</code>" class="doc_header"><code>Hooks.__exit__</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L18" class="source_link" style="float:right">[source]</a></h4>

> <code>Hooks.__exit__</code>(**\*`args`**)

Remove the hooks

In [None]:
layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
# The hooks are active inside the `with` block and removed when it exits.
with Hooks(layers, lambda m,i,o: o) as h:
    y = tst_model(x)
    lin_out = layers[0](x)
    test_eq(h.stored[0], lin_out)
    test_eq(h.stored[1], F.relu(lin_out))
    test_eq(h.stored[2], y)

In [None]:
def hook_outputs(modules, detach=True, cpu=False, grad=False)->Hooks:
    "Return `Hooks` that store activations of all `modules` in `self.stored`"
    return Hooks(modules, _hook_inner, detach=detach, cpu=cpu, is_forward=not grad)

The activations stored are the gradients if `grad=True`, otherwise the output of `modules`. If `detach=True` they are detached from their history, and if `cpu=True`, they're put on the CPU.

In [None]:
layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
x = torch.randn(4,5)

# Forward: one detached output per layer.
with hook_outputs(layers) as h:
    y = tst_model(x)
    test_eq(h.stored[0], layers[0](x))
    test_eq(h.stored[1], F.relu(layers[0](x)))
    test_eq(h.stored[2], y)
    assert not any(s.requires_grad for s in h.stored)

# Backward: check each stored output gradient by back-propagating by hand, last layer first.
with hook_outputs(layers, grad=True) as h:
    y = tst_model(x)
    loss = y.pow(2).mean()
    loss.backward()
    g = 2*y / y.numel()                    # grad w.r.t. the output of layers[2]
    test_close(g, h.stored[2][0])
    g = g @ layers[2].weight.data          # through layers[2]: grad w.r.t. the ReLU output
    test_close(g, h.stored[1][0])
    g = g * (layers[0](x) > 0).float()     # through the ReLU: grad w.r.t. the output of layers[0]
    test_close(g, h.stored[0][0])

In [None]:
#cuda
with hook_outputs(tst_model, cpu=True) as h:
    y = tst_model.cuda()(x.cuda())
    # The model runs on the GPU, but every stored activation is put on the CPU.
    for stored in h.stored:
        test_eq(stored.device, torch.device('cpu'))

## HookCallback -

To make hooks easy to use, we wrapped them in a `Callback` where you just have to implement a `hook` method (plus anything else you might need).

In [None]:
class HookCallback(Callback):
    "`Callback` that can be used to register hooks on `modules`"
    def __init__(self, modules=None, do_remove=True, is_forward=True, detach=True, cpu=False):
        self.modules,self.do_remove = modules,do_remove
        self.is_forward,self.detach,self.cpu = is_forward,detach,cpu

    def begin_fit(self):
        "Register the `Hooks` on `self.modules`."
        # Default (when `modules` is None or empty): every layer of the model with a `weight` attribute.
        if not self.modules:
            self.modules = [m for m in flatten_model(self.model) if hasattr(m, 'weight')]
        # Hooks left from a previous fit (e.g. with `do_remove=False`) would otherwise stay registered
        # and keep firing after `self.hooks` is replaced below.
        self._remove()
        self.hooks = Hooks(self.modules, self.hook, self.is_forward, self.detach, self.cpu)

    def after_fit(self):
        "Remove the `Hooks`."
        if self.do_remove: self._remove()

    def _remove(self):
        # No-op if no hooks were registered yet; removing already-removed hooks is harmless.
        if getattr(self, 'hooks', None): self.hooks.remove()

    def __del__(self): self._remove()

If `modules` is not provided (or is empty), it defaults to the layers of `self.model` that have a `weight` attribute. Depending on `do_remove`, the hooks will be properly removed at the end of training (or in case of error). `is_forward`, `detach` and `cpu` are passed to `Hooks`.

The function called at each forward (or backward) pass is `self.hook` and must be implemented when subclassing this callback.

In [None]:
class TstCallback(HookCallback):
    "Check that the first hooked module's stored output equals the model's prediction."
    def hook(self, m, i, o): return o
    def after_batch(self):
        first_stored = self.hooks.stored[0]
        test_eq(first_stored, self.pred)

learn = synth_learner(n_trn=5, cbs = TstCallback())
# Swap in a single `nn.Linear` so the hooked module's output is exactly the prediction.
learn.model = nn.Linear(1,1)
learn.fit(1)

[tensor(5.5230), tensor(3.3197), '00:00']


In [None]:
class TstCallback(HookCallback):
    # Backward-hook variant of the previous test: `is_forward=False` is always passed to `HookCallback`.
    def __init__(self, modules=None, do_remove=True, detach=True, cpu=False):
        super().__init__(modules, do_remove, False, detach, cpu)
    # For a backward hook, `o` is the tuple of gradients with respect to the module's output.
    def hook(self, m, i, o): return o
    def after_batch(self): 
        pass#TODO: fix
        # NOTE(review): the check below is disabled. A possible cause: `after_batch` presumably also runs
        # on validation batches (`ActivationStats.after_batch` guards with `self.training`), where no
        # backward pass happens and `stored` would still hold the last training gradients -- confirm.
        #test_eq(self.hooks.stored[0][0], 2*(self.pred-self.yb)/self.pred.shape[0])
        
learn = synth_learner(n_trn=5, cbs = TstCallback())
learn.model = nn.Linear(1,1)
learn.fit(1)

[tensor(14.1785), tensor(6.8552), '00:00']


In [None]:
show_doc(HookCallback.begin_fit)

<h4 id="<code>HookCallback.begin_fit</code>" class="doc_header"><code>HookCallback.begin_fit</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L7" class="source_link" style="float:right">[source]</a></h4>

> <code>HookCallback.begin_fit</code>()

Register the `Hooks` on `self.modules`.

In [None]:
show_doc(HookCallback.after_fit)

<h4 id="<code>HookCallback.after_fit</code>" class="doc_header"><code>HookCallback.after_fit</code><a href="https://github.com/fastai/fastai_docs/tree/master/dev/__main__.py#L13" class="source_link" style="float:right">[source]</a></h4>

> <code>HookCallback.after_fit</code>()

Remove the `Hooks`.

An example of such a `HookCallback` is the following one, which stores the mean and standard deviation of the activations that go through the network.

In [None]:
@docs
class ActivationStats(HookCallback):
    "Callback that records the mean and std of activations."

    def begin_fit(self):
        "Initialize stats."
        super().begin_fit()
        self.stats = []

    def hook(self, m, i, o): return o.mean().item(),o.std().item()

    def after_batch(self):
        "Append the stored results of each training batch to `self.stats`"
        # Only training batches are recorded.
        if self.training: self.stats.append(self.hooks.stored)

    def after_fit(self):
        "Remove the hooks, then reshape `self.stats` to (2, n_modules, n_batches)."
        # Remove the hooks first so a failure while polishing can't leave them registered.
        super().after_fit()
        # `self.stats` is a list (per training batch) of lists (per module) of (mean, std) pairs, i.e.
        # (n_batches, n_modules, 2); after the permute, `stats[0]` holds means and `stats[1]` stds.
        # With no training batch recorded it stays `[]` (`permute(2,1,0)` would fail on an empty 1d tensor).
        if len(self.stats) > 0: self.stats = tensor(self.stats).permute(2,1,0)

    _docs = dict(hook="Take the mean and std of the output")