# Monitoring

> Monitor different aspects of the model and training run

In [None]:
#| default_exp monitoring

In [None]:
#| hide
import re
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass

import lightning as L
from glom import glom
from torch import nn

import wandb
from slow_diffusion.fashionmnist import TinyFashionMNISTDataModule
from slow_diffusion.training import get_tiny_unet

In [None]:
# |export
class MonitorCallback(L.Callback):
    """Log arbitrary values pulled out of the training state after every batch.

    Each keyword argument maps a metric name to a glom spec. After every
    training batch the spec is evaluated against a namespace exposing
    `trainer`, `pl_module`, `outputs`, `batch` and `batch_idx`, and the result
    is logged under the metric name, e.g.
    `MonitorCallback(lr="trainer.optimizers.0.param_groups.0.lr")`.

    Raises:
        ValueError: if no `name=spec` pair is given.
    """

    def __init__(self, /, **gloms):
        super().__init__()
        if not gloms:
            raise ValueError(
                "MonitorCallback requires at least one `name=spec` keyword argument"
            )
        self.gloms = gloms

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Evaluate every spec against the hook arguments and log the results."""
        args = Namespace(
            trainer=trainer,
            pl_module=pl_module,
            outputs=outputs,
            batch=batch,
            batch_idx=batch_idx,
        )
        for name, spec in self.gloms.items():
            # Callbacks have no `log` method of their own; metrics are logged
            # through the LightningModule.
            pl_module.log(name, glom(args, spec), on_step=True)

In [None]:
# |exports
class CountDeadUnitsCallback(L.Callback):
    """Log, after every training batch, how many parameter values are NaN.

    The count is summed over all parameters of the module and logged as
    `dead_units`.
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Count NaN entries across all parameters and log the total."""
        nans = sum(
            params.isnan().sum().item() for params in pl_module.parameters()
        )
        # Callbacks have no `log` method of their own; metrics are logged
        # through the LightningModule.
        pl_module.log("dead_units", nans, reduce_fx=sum)

In [None]:
# |exports
class Hook:
    """Wrapper for a PyTorch hook, facilitating adding instance state.

    `f` is registered as a forward hook on `m` with this `Hook` instance bound
    as its first argument, so `f` can stash state on the instance. It is
    therefore called as `f(hook, *forward_hook_args)`.
    """

    def __init__(self, m, f):
        # Imported locally so the class does not rely on a module-level
        # import of `partial`.
        from functools import partial

        self.hook = m.register_forward_hook(partial(f, self))

    def remove(self):
        """Detach the hook from its module, if it was ever registered."""
        # `hook` is missing when registration raised inside `__init__`;
        # `__del__` still runs in that case and must not fail.
        hook = getattr(self, "hook", None)
        if hook is not None:
            hook.remove()

    def __del__(self):
        self.remove()

In [None]:
# |exports
class ModuleStats:
    """Log the mean and standard deviation of a module's output activations.

    A forward hook is registered on `module`; on every forward pass made while
    the module is in training mode, the output's mean and std are reported via
    `log` under the keys `"{label}:mean"` and `"{label}:std"`.

    Args:
        label: prefix for the logged metric names.
        module: module whose outputs are monitored.
        log: callable taking `(name, value)`. When `None`, nothing is logged.
    """

    def __init__(self, label, module, log=None):
        self.label = label
        self.log = log
        self.hook = module.register_forward_hook(self.append)

    def append(self, module, _, activations):
        """Forward hook: report summary statistics of `activations`."""
        if not module.training or self.log is None:
            return
        # Assumes the module returns a single tensor — TODO confirm for
        # modules that return tuples.
        # Reduce before `.item()` so only two scalars are copied off the
        # device instead of the whole activation tensor.
        activations = activations.detach()
        self.log(f"{self.label}:mean", activations.mean().item())
        self.log(f"{self.label}:std", activations.std().item())

    def cleanup(self):
        """Remove the forward hook."""
        self.hook.remove()


class StatsCallback(L.Callback):
    """Log activation statistics for selected submodules during `fit`.

    Args:
        mods: modules to monitor. An entry may be a module instance, which is
            hooked directly, or a module class, in which case every submodule
            of that class is hooked.
        mod_filter: a regex, or a list of regexes, applied with `re.match` to
            the names from `pl_module.named_modules()`. `re.match` anchors at
            the start of the name. Matching submodules are hooked.

    Raises:
        ValueError: if neither `mods` nor `mod_filter` selects anything.
    """

    def __init__(
        self,
        mods: list[nn.Module | type[nn.Module]] | None = None,
        mod_filter: str | list[str] | None = None,
    ):
        super().__init__()
        if not (mods or mod_filter):
            raise ValueError(
                "StatsCallback requires at least one of `mods` or `mod_filter`"
            )
        self.mods = list(mods) if mods is not None else []
        self.mod_filter = mod_filter
        self.mod_stats = []

    def _patterns(self):
        """Return `mod_filter` as a list of compiled regexes."""
        if self.mod_filter is None:
            return []
        if isinstance(self.mod_filter, str):
            return [re.compile(self.mod_filter)]
        return [re.compile(pattern) for pattern in self.mod_filter]

    def _select_modules(self, pl_module):
        """Return the distinct modules to hook: explicit instances first."""
        selected = []
        types = []
        for entry in self.mods:
            if isinstance(entry, nn.Module):
                selected.append(entry)
            elif isinstance(entry, type) and issubclass(entry, nn.Module):
                types.append(entry)
            else:
                raise TypeError(
                    f"`mods` entries must be nn.Module instances or classes, got {entry!r}"
                )
        types = tuple(types)
        patterns = self._patterns()

        if types or patterns:
            for name, mod in pl_module.named_modules():
                by_type = bool(types) and isinstance(mod, types)
                if by_type or any(p.match(name) for p in patterns):
                    selected.append(mod)

        # A module selected more than once must only be hooked once.
        seen = set()
        unique = []
        for mod in selected:
            if id(mod) not in seen:
                seen.add(id(mod))
                unique.append(mod)
        return unique

    def on_fit_start(self, trainer, pl_module):
        """Register one `ModuleStats` hook per selected module."""
        # Drop hooks left over from a previous fit so they are not duplicated.
        self.cleanup()
        for i, mod in enumerate(self._select_modules(pl_module)):
            # Metrics are logged through the LightningModule; callbacks have
            # no `log` method of their own.
            self.mod_stats.append(
                ModuleStats(f"layer_{i}", mod, log=pl_module.log)
            )

    def cleanup(self):
        """Remove every registered hook. Safe to call more than once."""
        for stats in self.mod_stats:
            stats.cleanup()
        self.mod_stats.clear()

    def on_fit_end(self, trainer, pl_module):
        self.cleanup()

    def on_exception(self, trainer, pl_module, exception):
        self.cleanup()

In [None]:
class PrintLogger(L.pytorch.loggers.logger.DummyLogger):
    """Logger that prints every metrics dict to stdout instead of storing it."""

    def log_metrics(self, metrics, step=None):
        """Print `metrics`; `step` is accepted for API compatibility and ignored."""
        print(metrics)

In [None]:
# Smoke test: fit the tiny UNet for a single epoch with all three callbacks
# attached, sending logged metrics to PrintLogger.
dm = TinyFashionMNISTDataModule(32, n_workers=0)  # presumably 32 is the batch size — TODO confirm
dm.setup()
model = get_tiny_unet()
trainer = L.Trainer(
    max_epochs=1,
    callbacks=[
        # Logs the learning rate of the first optimizer's first param group as "lr".
        MonitorCallback(lr="trainer.optimizers.0.param_groups.0.lr"),
        CountDeadUnitsCallback(),
        # NOTE(review): StatsCallback applies this pattern with `re.match`, which
        # anchors at the start of the module name — confirm that the intended
        # modules have names beginning with "convs".
        StatsCallback(mod_filter=r"convs"),
    ],
    logger=PrintLogger(),
)
trainer.fit(model=model, datamodule=dm)

In [None]:
#| hide
# Run nbdev's export step for this project.
import nbdev

nbdev.nbdev_export()