In [None]:
# default_exp core.nn.utils

In [None]:
%load_ext nb_black
%load_ext autoreload
%autoreload 2

<IPython.core.display.Javascript object>

In [None]:
# hide
import warnings

from nbdev.export import *
from nbdev.showdoc import *

warnings.filterwarnings("ignore")

<IPython.core.display.Javascript object>

# Utilities
> Custom `Torch` utilities

In [None]:
# export
from functools import partial
from typing import *

import numpy as np
import torch
from torch import nn

<IPython.core.display.Javascript object>

In [None]:
from fastcore.test import *

<IPython.core.display.Javascript object>

## Model Init

In [None]:
# export
# Batch-normalization layer classes; `set_bn_eval` switches instances of these to eval mode.
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

# Normalization layer classes that `cond_init` leaves untouched:
# the batch-norm classes above plus instance norm and layer norm.
norm_types = bn_types + (
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
    nn.LayerNorm,
)

<IPython.core.display.Javascript object>

In [None]:
# export
def init_default(m: nn.Module, func: Callable = nn.init.kaiming_normal_):
    """
    Initialize the weight of `m` with `func` and reset its bias to 0.

    Args:
        m: module to initialize in place.
        func: in-place initializer, called as `func(m.weight)`. A falsy value
            (e.g. `None`) leaves `m` completely untouched, bias included.

    Returns:
        `m` itself.

    A `weight` or `bias` attribute that is missing or set to `None`
    (e.g. `nn.BatchNorm1d(3, affine=False)`, `nn.Linear(4, 5, bias=False)`)
    is skipped rather than being handed to `func`.

    Source: https://github.com/fastai/fastai/blob/master/fastai/torch_core.py
    """
    if not func:
        return m

    # `hasattr` alone is not enough: modules register disabled parameters as `None`.
    weight = getattr(m, "weight", None)
    if weight is not None:
        func(weight)

    bias = getattr(m, "bias", None)
    if hasattr(bias, "data"):
        bias.data.fill_(0.0)
    return m

<IPython.core.display.Javascript object>

In [None]:
# Check `init_default`: the custom `func` must be applied to the weight and the bias zeroed.
with torch.no_grad():
    tst = nn.Linear(4, 5)
    # Scramble both parameters first so the assertions below cannot pass by accident.
    tst.weight.data.uniform_(-1, 1)
    tst.bias.data.uniform_(-1, 1)
    tst = init_default(tst, func=lambda x: x.data.fill_(1.0))
    test_eq(tst.weight, torch.ones(5, 4))
    test_eq(tst.bias, torch.zeros(5))

<IPython.core.display.Javascript object>

In [None]:
# export
def cond_init(m: nn.Module, func: Callable):
    """
    Run `init_default(m, func)` unless `m` is a normalization layer.

    "Normalization layer" means an instance of any class in `norm_types`
    (not only batchnorm); such modules are left exactly as they are.
    Source: https://github.com/fastai/fastai/blob/master/fastai/torch_core.py
    """
    if isinstance(m, norm_types):
        return
    init_default(m, func)

<IPython.core.display.Javascript object>

In [None]:
with torch.no_grad():
    # A regular layer is initialized: weight filled by `func`, bias zeroed.
    tst = nn.Linear(4, 5)
    tst.weight.data.uniform_(-1, 1)
    tst.bias.data.uniform_(-1, 1)
    cond_init(tst, func=lambda x: x.data.fill_(1.0))
    test_eq(tst.weight, torch.ones(5, 4))
    test_eq(tst.bias, torch.zeros(5))

    # A normalization layer must be skipped. A fresh BatchNorm already has
    # weight == 1 and bias == 0, which is exactly what the initializer below
    # would write, so move the parameters away from those values first;
    # otherwise this check could never fail.
    tst = nn.BatchNorm2d(5)
    tst.weight.data.fill_(0.5)
    tst.bias.data.fill_(0.5)
    init = [tst.weight.clone(), tst.bias.clone()]
    cond_init(tst, func=lambda x: x.data.fill_(1.0))
    test_eq(tst.weight, init[0])
    test_eq(tst.bias, init[1])

<IPython.core.display.Javascript object>

In [None]:
# export
def apply_leaf(m: nn.Module, f: Callable):
    """
    Call `f` on `m` itself and then, recursively, on every module nested under it.

    Each module is visited before its own children, and children are visited
    in the order `m.children()` yields them.
    Source: https://github.com/fastai/fastai/blob/master/fastai/torch_core.py
    """
    submodules = m.children()
    if isinstance(m, nn.Module):
        f(m)
    for submodule in submodules:
        apply_leaf(submodule, f)

<IPython.core.display.Javascript object>

In [None]:
# Check `apply_leaf`: `init_default` must reach the nested linear layers, not just the top level.
tst = nn.Sequential(nn.Linear(4, 5), nn.Sequential(nn.Linear(4, 5), nn.Linear(4, 5)))
apply_leaf(tst, partial(init_default, func=lambda x: x.data.fill_(1.0)))


with torch.no_grad():
    # Every linear layer got its weight filled with ones ...
    for l in [tst[0], *tst[1]]:
        test_eq(l.weight, torch.ones(5, 4))

    # ... and its bias reset to zero.
    for l in [tst[0], *tst[1]]:
        test_eq(l.bias, torch.zeros(5))

<IPython.core.display.Javascript object>

In [None]:
# export
def apply_init(m: nn.Module, func: Callable = nn.init.kaiming_normal_):
    """
    Walk `m` with `apply_leaf` and run `cond_init` with `func` on every module visited.

    Modules that `cond_init` recognises as normalization layers keep their current parameters.
    Source: https://github.com/fastai/fastai/blob/master/fastai/torch_core.py
    """
    apply_leaf(m, lambda module: cond_init(module, func=func))

<IPython.core.display.Javascript object>

In [None]:
tst = nn.Sequential(nn.Linear(4, 5), nn.Sequential(nn.Linear(4, 5), nn.BatchNorm1d(5)))
# A fresh BatchNorm already has weight == 1 and bias == 0, which is exactly what
# the initializer below would write. Move the parameters away from those values
# so the "batchnorm is untouched" check can actually fail.
tst[1][1].weight.data.fill_(0.5)
tst[1][1].bias.data.fill_(0.5)
init = [tst[1][1].weight.clone(), tst[1][1].bias.clone()]
apply_init(tst, func=lambda x: x.data.fill_(1.0))

with torch.no_grad():
    # Both linear layers are initialized: weight from `func`, bias zeroed.
    for l in [tst[0], tst[1][0]]:
        test_eq(l.weight, torch.ones(5, 4))
    for l in [tst[0], tst[1][0]]:
        test_eq(l.bias, torch.zeros(5))
    # The batchnorm layer keeps the parameters it had before `apply_init`.
    test_eq(tst[1][1].weight, init[0])
    test_eq(tst[1][1].bias, init[1])

<IPython.core.display.Javascript object>

## Miscellaneous Functions

In [None]:
# export
def set_bn_eval(m: nn.Module):
    """
    Put every batchnorm layer nested anywhere below `m` into eval mode.

    Only descendants are inspected (instances of `bn_types`); `m` itself is never switched.
    Source: https://github.com/fastai/fastai/blob/master/fastai/callback/training.py#L43
    """
    for layer in m.modules():
        # `modules()` yields `m` first; skip it so only descendants are affected.
        if layer is m:
            continue
        if isinstance(layer, bn_types):
            layer.eval()

<IPython.core.display.Javascript object>

In [None]:
model = nn.Sequential(nn.Linear(4, 5), nn.BatchNorm1d(5), nn.Linear(5, 1))

<IPython.core.display.Javascript object>

Grab the first `BatchNorm` layer and store its running mean:

In [None]:
m = model[1].running_mean.clone()

<IPython.core.display.Javascript object>

After a forward pass, you can see that the running mean has changed:

In [None]:
i = torch.randn(32, 4)
o = model(i)
test_ne(m, model[1].running_mean.detach())

<IPython.core.display.Javascript object>

When we use the `set_bn_eval` function, the running statistics are no longer changed during training:

In [None]:
model = nn.Sequential(nn.Linear(4, 5), nn.BatchNorm1d(5))
# Keep the model in training mode. Calling `model.eval()` here would freeze the
# running statistics by itself, and the check below would pass even if
# `set_bn_eval` did nothing.
model.train()
m = model[1].running_mean.clone()

set_bn_eval(model)

i = torch.randn(32, 4)
o = model(i)

# The batchnorm layer did not update its running mean during the forward pass.
test_eq(m, model[1].running_mean.detach())

<IPython.core.display.Javascript object>

In [None]:
# export
def trainable_params(m: nn.Module):
    "List the parameters of `m` that have `requires_grad` set, in `m.parameters()` order"
    return list(filter(lambda param: param.requires_grad, m.parameters()))

<IPython.core.display.Javascript object>

In [None]:
# export
def params(m):
    "List every parameter of `m`, trainable or not, in `m.parameters()` order"
    return list(m.parameters())

<IPython.core.display.Javascript object>

In [None]:
# Check `trainable_params` / `params` on a single linear layer.
with torch.no_grad():
    m = nn.Linear(4, 5)
    # Freshly created: both parameters require grad.
    test_eq(trainable_params(m), [m.weight, m.bias])

    # Freezing the weight removes it from the trainable list only.
    m.weight.requires_grad_(False)
    test_eq(trainable_params(m), [m.bias])
    test_eq(params(m), [m.weight, m.bias])

<IPython.core.display.Javascript object>

In [None]:
# export
def maybe_convert_to_onehot(
    target: torch.Tensor, output: torch.Tensor
) -> torch.LongTensor:
    """
    Infer from its shape whether `target` holds class indices and, if so,
    convert it to a `one_hot` encoding with `output.shape[1]` classes.

    A `target` of shape $(N,)$ or $(N, 1)$ is treated as class indices and
    converted; any other shape is assumed to be `one_hot` already and is
    returned unchanged.

    Shape:
    - Target : $(N,)$, $(N, 1)$ or $(N, C)$
    - Output : $(N, C)$ where N is the mini-batch size and $C$ is the total number of classes.
    - Returns: $(N, C)$
    """
    if target.dim() == 2 and target.shape[1] == 1:
        # Drop the singleton axis first: `one_hot` appends the class axis,
        # so an (N, 1) input would otherwise come back as (N, 1, C).
        target = target.squeeze(1)
    if target.dim() == 1:
        target = torch.nn.functional.one_hot(target, output.shape[1])
    return target

<IPython.core.display.Javascript object>

In [None]:
output = torch.randn(10, 10)

# `t0` is the one-hot encoding of the class indices in `t1`.
t0 = torch.nn.functional.one_hot(torch.arange(0, 10) % 3, num_classes=10)
t1 = torch.arange(0, 10) % 3

o0 = maybe_convert_to_onehot(t0, output)
o1 = maybe_convert_to_onehot(t1, output)

test_eq(o0.shape, output.shape)
test_eq(o1.shape, output.shape)
test_eq(t0, o0)
# Check the converted values too, not just the shape.
test_eq(o1, t0)

<IPython.core.display.Javascript object>

We can see that `maybe_convert_to_onehot` converted `t1` to a `one_hot` encoded tensor but did not change `t0`, because it already had the `one_hot` encoded shape.

In [None]:
# export
def worker_init_fn(worker_id):
    """
    Re-seed the global `NumPy` RNG with a seed offset by `worker_id`.

    The new seed is the first word of the current `NumPy` RNG state plus
    `worker_id`, wrapped into the range `np.random.seed` accepts
    (`0 .. 2**32 - 1`).

    Args:
        worker_id: integer id of the worker this function is called for.

    For more information see:
    https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
    """
    # Convert to a Python int first: adding to the NumPy uint32 scalar may
    # promote or overflow, depending on the NumPy version.
    base_seed = int(np.random.get_state()[1][0])
    # Wrap the sum so a base seed close to 2**32 cannot push it out of range.
    np.random.seed((base_seed + worker_id) % 2**32)

<IPython.core.display.Javascript object>

## Export -

In [None]:
notebook2script()

Converted 00_core.logging.ipynb.
Converted 00a_core.structures.ipynb.
Converted 00b_core.visualize.ipynb.
Converted 01_core.nn.utils.ipynb.
Converted 01a_core.nn.losses.ipynb.
Converted 01b_core.nn.optim.optimizers.ipynb.
Converted 01c_core.nn.optim.lr_schedulers.ipynb.
Converted 02_core.classes.ipynb.
Converted 03_config.optimizers.ipynb.
Converted 03a_config.schedulers.ipynb.
Converted 03b_config.common.ipynb.
Converted 04_classification.modelling.backbones.ipynb.
Converted 05_collections.pandas.ipynb.
Converted 06a_collections.callbacks.notebook.ipynb.
Converted 06b_collections.callbacks.ema.ipynb.
Converted index.ipynb.


<IPython.core.display.Javascript object>