# Initialization

> We need to make sure that our gradients neither explode nor vanish. This notebook works out the appropriate initialization for each torch module.

In [None]:
# |default_exp init

In [None]:
# |export
from functools import singledispatch

import matplotlib.pyplot as plt
import torch
from torch import nn

from slow_diffusion.fashionmnist import TinyFashionMNISTDataModule
from slow_diffusion.model import ConvBlock, TimeEmbeddingMixer

In [None]:
# |hide
# Apply matplotlib's "ggplot" style to the figures drawn in this notebook.
plt.style.use("ggplot")

In [None]:
# Build the data module (bs=32, presumably the batch size — confirm against
# TinyFashionMNISTDataModule) and run its setup step before requesting dataloaders.
dm = TinyFashionMNISTDataModule(bs=32)
dm.setup()

In [None]:
# Pull a single batch from the training dataloader and unpack it.
# NOTE(review): presumably x_t is the noised image, t the timestep and ε the
# target noise — confirm against the data module.
((x_t, t), ε) = next(iter(dm.train_dataloader()))
x_t.shape  # last expression: shown as the cell output

Let's verify that the distribution remains normal after the transformation is applied.

In [None]:
@singledispatch
def kaiming(module):
    """Apply Kaiming initialization to `module`, dispatching on its type.

    This is the generic fallback used for every type that has no registered
    implementation: it leaves `module` untouched and returns `None`.
    Type-specific initializers are attached with `kaiming.register`.
    """
    return None

In [None]:
@kaiming.register(ConvBlock)
def _(c: ConvBlock):
    """Kaiming-initialize a `ConvBlock` in place, based on its activation.

    The convolution weight is drawn with `torch.nn.init.kaiming_normal_`,
    whose `a` argument is chosen from the block's activation, and the
    convolution bias (when present) is set to zero.

    Args:
        c: Block whose `c.conv` is re-initialized and whose `c.act` selects
            the slope passed to `kaiming_normal_`.

    Raises:
        ValueError: If `c.act` is neither an `nn.ReLU` nor an `nn.SiLU`.
            The weights are not modified in that case.
    """
    if isinstance(c.act, nn.ReLU):
        negative_slope = 0.0
    elif isinstance(c.act, nn.SiLU):
        # NOTE(review): 0.2 treats SiLU like a leaky ReLU for the purpose of
        # the gain computation; this looks like a heuristic — confirm.
        negative_slope = 0.2
    else:
        raise ValueError(
            f"No Kaiming initialization rule for activation "
            f"{type(c.act).__name__!r}; expected nn.ReLU or nn.SiLU"
        )
    torch.nn.init.kaiming_normal_(c.conv.weight, a=negative_slope)
    if c.conv.bias is not None:
        torch.nn.init.constant_(c.conv.bias, 0)

In [None]:
# Build two blocks per activation: the "*_default" one keeps whatever
# initialization ConvBlock applies on construction (the baseline), while its
# twin is re-initialized with `kaiming` below.
conv_relu_default = ConvBlock(1, 1, act=nn.ReLU)
conv_relu = ConvBlock(1, 1, act=nn.ReLU)
conv_silu_default = ConvBlock(1, 1, act=nn.SiLU)
conv_silu = ConvBlock(1, 1, act=nn.SiLU)
# Only the non-default blocks are re-initialized; `kaiming` mutates them in place.
for m in [conv_relu, conv_silu]:
    kaiming(m)

In [None]:
def plot_distribution_variance(xb, args, modules: list[tuple[str, nn.Module]]):
    """Plot the input histogram against each module's output histogram.

    One subplot is drawn per entry of `modules`. Every subplot overlays the
    flattened values of `xb` with the flattened output of the module called
    as `module(*args)` under `torch.no_grad()`.

    Args:
        xb: Tensor whose flattened values are drawn as the "input" histogram.
        args: Positional arguments forwarded to every module.
        modules: `(label, module)` pairs; the label is used both as the
            subplot title and as the legend entry of the output histogram.

    Raises:
        ValueError: If `modules` is empty.
    """
    if not modules:
        raise ValueError("`modules` must contain at least one (label, module) pair")
    # squeeze=False keeps `axes` two-dimensional even for a single module, so
    # the row can always be iterated.
    _, axes = plt.subplots(
        1, len(modules), figsize=(4 * len(modules), 4), squeeze=False
    )
    for ax, (label, module) in zip(axes[0], modules):
        # Reuse the input's bin edges so both histograms are directly comparable.
        _, bins, _ = ax.hist(xb.reshape(-1), bins=30, alpha=0.5, label="input")
        with torch.no_grad():
            yb = module(*args)
        ax.hist(yb.reshape(-1), bins, alpha=0.33, label=label)
        ax.set(xlabel="Logit magnitude", ylabel="Frequency", title=label)
        ax.legend()

In [None]:
# (label, module) pairs, ordered so that each default-initialized block sits
# next to its Kaiming-initialized counterpart in the figure.
cs = [
    ("default relu", conv_relu_default),
    ("kaiming relu", conv_relu),
    ("default silu", conv_silu_default),
    ("kaiming silu", conv_silu),
]
# Histogram x_t itself, and feed x_t as the only positional argument to every block.
plot_distribution_variance(x_t, (x_t,), cs)

In [None]:
@kaiming.register(TimeEmbeddingMixer)
def _(t: TimeEmbeddingMixer):
    """Placeholder `kaiming` handler for `TimeEmbeddingMixer`.

    Currently a no-op: the module is left unchanged and `None` is returned.
    """
    return None

Good! Kaiming initialization preserves the variance of the distribution, unlike the default initialization.

In [None]:
#| hide
import nbdev  # imported in this hidden cell, where it is used

# Export the notebook's `# |export` cells to the Python module.
nbdev.nbdev_export()