# Monitoring

> Monitor different aspects of the model and training run

In [1]:
#| default_exp monitoring

In [2]:
# |export
import re
from argparse import Namespace
from collections import Counter

import lightning as L
import matplotlib.pyplot as plt
from glom import glom
from lightning.pytorch.loggers import WandbLogger
from torch import nn

import wandb
from slow_diffusion.fashionmnist import FashionMNISTDataModule
from slow_diffusion.training import get_tiny_unet_lightning

  from .autonotebook import tqdm as notebook_tqdm


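We want to make sure the model can be inspected while it trains. `test_run` fits the tiny U-Net on FashionMNIST for a single epoch with the given callback attached, logging to Weights & Biases.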

In [3]:
def test_run(callback):
    """Smoke-test `callback` with a one-epoch fit of the tiny U-Net on FashionMNIST.

    The fit runs inside a fresh `wandb` run, and the trainer is given a
    `WandbLogger`, so whatever the callback logs is sent to Weights & Biases.
    """
    with wandb.init():
        datamodule = FashionMNISTDataModule(256, n_workers=0)
        datamodule.setup()
        unet = get_tiny_unet_lightning()
        trainer_options = dict(
            max_epochs=1,
            callbacks=[callback],
            logger=WandbLogger(),
            precision="bf16-mixed",
            # Log on every step so that per-batch callbacks show up immediately
            log_every_n_steps=1,
        )
        L.Trainer(**trainer_options).fit(model=unet, datamodule=datamodule)

In [4]:
# |export
class MonitorCallback(L.Callback):
    """Log arbitrary properties of the training run, such as the LR.

    `gloms` maps a metric name to a `glom` spec. After every training batch
    each spec is resolved against a namespace holding the hook arguments
    (`trainer`, `pl_module`, `outputs`, `batch`, `batch_idx`) and the result
    is logged under the metric name, e.g.
    `{"lr": "trainer.optimizers.0.param_groups.0.lr"}`.
    """

    def __init__(self, gloms: dict[str, str]):
        super().__init__()
        self.gloms = gloms

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Bundle the hook arguments so specs can address them by name
        hook_args = Namespace(
            trainer=trainer,
            pl_module=pl_module,
            outputs=outputs,
            batch=batch,
            batch_idx=batch_idx,
        )
        for name, spec in self.gloms.items():
            # Log through the LightningModule: this is Lightning's documented
            # way of logging from inside a callback hook.
            pl_module.log(name, glom(hook_args, spec), on_step=True)

In [None]:
%%time
test_run(MonitorCallback({"lr": "trainer.optimizers.0.param_groups.0.lr"}))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjfisher40[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/jeremiahfisher/miniforge3/envs/slow_diffusion/lib/python3.12/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/jeremiahfisher/miniforge3/envs/slow_diffusion/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
Loading `train_dataloader` to estimate number of stepping batches.
/Users/jeremiahfisher/miniforge3/envs/slow_diffusion/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the

Sanity Checking: |                                                                      | 0/? [00:00<?, ?it/s]

/Users/jeremiahfisher/miniforge3/envs/slow_diffusion/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|                                                     | 0/2 [00:00<?, ?it/s]

In [None]:
# |exports
class CountDeadUnitsCallback(L.Callback):
    """Check for numeric underflow or overflow in the model parameters.

    After every training batch, counts how many parameter values are NaN and
    how many are exactly zero across all of `pl_module`'s parameters, and logs
    the two totals as `nans` and `zeros` (reduced with `max`).
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        nans = 0
        zeros = 0
        for _, params in pl_module.named_parameters():
            nans += params.isnan().int().sum().item()
            zeros += (params == 0).sum().item()
        # Log through the LightningModule: this is Lightning's documented way
        # of logging from inside a callback hook.
        pl_module.log("nans", nans, reduce_fx=max)
        pl_module.log("zeros", zeros, reduce_fx=max)

In [None]:
%%time
test_run(CountDeadUnitsCallback())

Check activation distribution metrics: the mean and standard deviation of each monitored module's outputs.

In [None]:
# |exports
class Stats:
    """Track the mean and standard deviation of one module's outputs.

    A forward hook is registered on `module`. While the module is in training
    mode, every forward pass yields one mean and one standard deviation of the
    output. With `live` set they are passed straight to `log` under
    `"<label>:mean"` / `"<label>:std"`; otherwise they are accumulated in
    `means` / `stds` for plotting later.
    """

    def __init__(self, label, module, log, live):
        self.label = label
        self.log = log
        self.live = live
        self.means = []
        self.stds = []
        # Register the hook last so it can never fire on a half-built object
        self.hook = module.register_forward_hook(self.append)

    def append(self, module, _, activations):
        """Forward hook: record or log the statistics of `activations`."""
        if not module.training:
            return
        # Only summary statistics are needed, so drop the autograd graph and
        # reduce on the tensor's own device; `.item()` then transfers just the
        # two scalars instead of the whole activation tensor.
        activations = activations.detach()
        mean = activations.mean().item()
        std = activations.std().item()
        if self.live:
            self.log(f"{self.label}:mean", mean)
            self.log(f"{self.label}:std", std)
        else:
            self.means.append(mean)
            self.stds.append(std)

    def plot(self, ax0, ax1):
        """Draw the recorded means on `ax0` and the labelled stds on `ax1`."""
        ax0.plot(self.means)
        ax1.plot(self.stds, label=self.label)

    def cleanup(self):
        """Remove the forward hook from the module."""
        self.hook.remove()


class StatsCallback(L.Callback):
    """Collect activation statistics for selected modules during a fit.

    Modules are selected either explicitly (`mods`, module instances that get
    a forward hook) or by name (`mod_filter`, a regular expression matched
    with `re.match` against the names from `pl_module.named_modules()`). At
    least one of the two must be given.

    With `live=False` the statistics are accumulated and plotted; the figure is
    sent to `wandb` at the end of every training epoch, at the end of the fit,
    and on an exception. With `live=True` each value is logged as soon as it is
    produced and no figure is uploaded.
    """

    def __init__(
        self,
        mods: list[nn.Module] | None = None,
        mod_filter: str | None = None,
        live=False,
    ):
        super().__init__()
        assert mods or mod_filter, "provide `mods` and/or `mod_filter`"
        self.mods = []
        if mods is not None:
            self.mods.extend(mods)
        self.mod_filter = mod_filter
        self.mod_stats = []
        self.live = live

    def on_fit_start(self, trainer, pl_module):
        # Live values are logged through the LightningModule, which is
        # Lightning's documented way of logging from callback code.
        log = pl_module.log

        # Explicit modules are labelled "<ClassName>:<index>", where the index
        # counts earlier modules of the same class and starts at 0.
        seen = Counter()
        for mod in self.mods:
            cls_name = mod.__class__.__name__
            label = f"{cls_name}:{seen[cls_name]}"
            self.mod_stats.append(Stats(label, mod, log, self.live))
            seen[cls_name] += 1

        if self.mod_filter is not None:
            # `re.match` anchors the pattern at the start of the module name
            for name, mod in pl_module.named_modules():
                if re.match(self.mod_filter, name):
                    self.mod_stats.append(Stats(name, mod, log, self.live))

    def plot(self, log=True):
        """Return a figure with the recorded means (left) and stds (right).

        `log` is currently unused; it is kept so existing callers keep working.
        """
        with plt.style.context("ggplot"):
            fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(8, 3))
            ax0.set(title="Means", xlabel="Time Step", ylabel="Activation")
            ax1.set(title="STDs", xlabel="Time Step")
            for mod_stat in self.mod_stats:
                mod_stat.plot(ax0, ax1)
            fig.legend(loc=7)
            # Leave room on the right-hand side for the figure-level legend
            fig.subplots_adjust(right=0.75)
            return fig

    def log_stats(self):
        """Upload the statistics figure to `wandb` (no-op in live mode)."""
        if self.live:
            return
        fig = self.plot()
        try:
            wandb.log({"stats": wandb.Image(fig)})
        finally:
            # Always release the figure, even if the upload fails
            plt.close(fig)

    def on_train_epoch_end(self, trainer, pl_module):
        self.log_stats()

    def cleanup(self):
        """Remove every forward hook; the recorded statistics are kept."""
        for s in self.mod_stats:
            s.cleanup()

    def on_fit_end(self, trainer, pl_module):
        try:
            self.log_stats()
        finally:
            # Hooks must come off the model even if logging the figure fails
            self.cleanup()

    def on_exception(self, trainer, pl_module, exception):
        try:
            self.log_stats()
        finally:
            self.cleanup()

In [None]:
%%time
cb = StatsCallback(mod_filter=r"unet.(((down|up)blocks.\d+)|start|middle|end)(?!\.)")
test_run(cb)

You can see how _bad_ the training dynamics are initially.

In [2]:
#| hide
import nbdev

nbdev.nbdev_export()