# Monitoring

> Monitor different aspects of the model and training run

In [1]:
#| default_exp monitoring

In [2]:
# |export
import re
from argparse import Namespace

import lightning as L
from glom import glom
from lightning.pytorch.loggers import WandbLogger
from torch import nn

import wandb
from slow_diffusion.fashionmnist import FashionMNISTDataModule
from slow_diffusion.training import get_tiny_unet

  from .autonotebook import tqdm as notebook_tqdm


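We want to make sure the model can be inspected while it trains, so we first define a small helper that runs a single FashionMNIST epoch for a given callback.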

In [3]:
def fashion_test_run(callback):
    """Fit the tiny U-Net on FashionMNIST for one epoch with `callback` installed.

    The whole run happens inside a `wandb.init()` context and is logged through
    a single `WandbLogger`.

    Args:
        callback: One Lightning callback, a list/tuple of callbacks, or `None`
            for a run without extra callbacks.
    """
    # Normalise the argument so both a single callback and a collection work.
    if callback is None:
        callbacks = []
    elif isinstance(callback, (list, tuple)):
        callbacks = list(callback)
    else:
        callbacks = [callback]

    with wandb.init():
        wandb_logger = WandbLogger()
        dm = FashionMNISTDataModule(256, n_workers=0)
        dm.setup()
        model = get_tiny_unet()
        trainer = L.Trainer(
            max_epochs=1,
            # Previously an empty list was passed here, so the callback under
            # test was silently dropped.
            callbacks=callbacks,
            # Reuse the logger created above instead of building a second one.
            logger=wandb_logger,
            precision="bf16-mixed",
        )
        trainer.fit(model=model, datamodule=dm)

Log arbitrary properties in the training run, such as LR.

In [4]:
# |export
class MonitorCallback(L.Callback):
    """Log values extracted from the training state after every training batch.

    `gloms` maps a metric name to a glom spec. Each spec is resolved against a
    namespace exposing `trainer`, `pl_module`, `outputs`, `batch` and
    `batch_idx`, e.g. `{"lr": "trainer.optimizers.0.param_groups.0.lr"}`.

    Raises:
        ValueError: If `gloms` is empty or `None`.
    """

    def __init__(self, gloms: dict[str, str]):
        super().__init__()
        if not gloms:
            raise ValueError(
                "`gloms` must contain at least one metric name -> glom spec entry"
            )
        self.gloms = gloms

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Resolve every configured spec and log the result under its name."""
        args = Namespace(
            trainer=trainer,
            pl_module=pl_module,
            outputs=outputs,
            batch=batch,
            batch_idx=batch_idx,
        )
        for name, spec in self.gloms.items():
            # `L.Callback` has no `log` method of its own; metrics logged from a
            # callback hook have to go through the LightningModule.
            pl_module.log(name, glom(args, spec), on_step=True)

In [None]:
fashion_test_run(MonitorCallback({"lr": "trainer.optimizers.0.param_groups.0.lr"}))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjfisher40[0m. Use [1m`wandb login --relogin`[0m to force relogin


Map: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 6954.47 examples/s]
Saving the dataset (1/1 shards): 100%|███████████████████████████████████████████████████████████████████████████████████| 60000/60000 [00:00<00:00, 229322.46 examples/s]
Saving the dataset (1/1 shards): 100%|███████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 190447.61 examples/s]
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/jeremiahfisher/miniforge3/envs/slow_diffusion/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
Loading `train

Sanity Checking: |                                                                                                                                  | 0/? [00:00<?, ?it/s]

/Users/jeremiahfisher/miniforge3/envs/slow_diffusion/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|                                                                                                                 | 0/2 [00:00<?, ?it/s]



Epoch 0:   0%|                                                                                                                                    | 0/235 [00:00<?, ?it/s]

Check for overflow by counting the NaN values in the model's parameters after each training batch.

In [None]:
# |exports
class CountDeadUnitsCallback(L.Callback):
    """Log the number of NaN parameter entries after every training batch.

    The count is logged under the key `dead_units` with a sum reduction.
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Count NaN entries across all parameters of `pl_module` and log them."""
        nans = sum(int(param.isnan().sum()) for param in pl_module.parameters())
        # `L.Callback` has no `log` method of its own; metrics logged from a
        # callback hook have to go through the LightningModule.
        pl_module.log("dead_units", nans, reduce_fx=sum)

In [None]:
# Smoke check: the callback can be constructed on its own. `StatsCallback` is
# only defined in a later cell, so referencing it here would fail on a fresh
# kernel; it is exercised in the combined run further down instead.
CountDeadUnitsCallback();

Check activation distribution metrics: log the mean and standard deviation of the activations of selected modules during training.

In [None]:
# |exports
class Stats:
    """Forward hook that reports the mean and std of a module's activations.

    Args:
        label: Prefix for the metric names (`"<label>:mean"`, `"<label>:std"`).
        module: Module whose forward output is observed.
        log: Optional callable `log(name, value)` that receives each metric.
            It is stored as the plain attribute `self.log`, so it can also be
            attached after construction. While it is `None` the hook is a no-op.
    """

    def __init__(self, label, module, log=None):
        self.label = label
        # Without this attribute the hook used to fail with AttributeError,
        # because the class never defined `log`.
        self.log = log
        self.hook = module.register_forward_hook(self.append)

    def append(self, module, _, activations):
        """Forward-hook body: log activation statistics while training."""
        if not module.training or self.log is None:
            return
        # Move to the CPU once; the reductions below then already live there.
        activations = activations.detach().cpu()
        self.log(f"{self.label}:mean", activations.mean().item())
        self.log(f"{self.label}:std", activations.std().item())

    def cleanup(self):
        """Remove the forward hook from the observed module."""
        self.hook.remove()


class StatsCallback(L.Callback):
    """Attach activation-statistics hooks to selected modules for one fit.

    Args:
        mods: Module instances to observe.
        mod_filter: Regular expression applied with `re.match` to the names
            from `pl_module.named_modules()`. `re.match` anchors at the start
            of the name, so use e.g. `r".*convs"` to match a nested submodule.

    At least one of `mods` and `mod_filter` must be given.
    """

    def __init__(
        self,
        mods: list[nn.Module] | None = None,
        mod_filter: str | None = None,
    ):
        super().__init__()
        assert mods or mod_filter, "provide `mods`, `mod_filter`, or both"
        self.mods = []
        if mods is not None:
            self.mods.extend(mods)
        self.mod_filter = mod_filter
        self.mod_stats = []

    def on_fit_start(self, trainer, pl_module):
        """Register one `Stats` hook per selected module, labelled `layer_<i>`."""
        # Drop hooks left over from an earlier fit so nothing is registered twice.
        self.cleanup()

        # Work on a copy so modules matched by the filter do not pile up in
        # `self.mods` when the callback is reused for another fit.
        targets = list(self.mods)
        if self.mod_filter is not None:
            for name, mod in pl_module.named_modules():
                if re.match(self.mod_filter, name):
                    targets.append(mod)

        # `enumerate` supplies the index; the list holds bare modules, so
        # unpacking its items directly cannot work.
        for i, mod in enumerate(targets):
            stats = Stats(f"layer_{i}", mod)
            # Route the hook's metrics through the LightningModule's logger.
            stats.log = pl_module.log
            self.mod_stats.append(stats)

    def cleanup(self):
        """Remove every registered hook and forget the `Stats` objects."""
        for stats in self.mod_stats:
            stats.cleanup()
        self.mod_stats.clear()

    def on_fit_end(self, trainer, pl_module):
        """Remove the hooks once fitting has finished."""
        self.cleanup()

    def on_exception(self, trainer, pl_module, exception):
        """Remove the hooks when fitting is interrupted by an exception."""
        self.cleanup()

In [None]:
# End-to-end check: one epoch with all three monitoring callbacks attached.
with wandb.init():
    wandb_logger = WandbLogger()
    dm = FashionMNISTDataModule(256, n_workers=0)
    dm.setup()
    model = get_tiny_unet()
    trainer = L.Trainer(
        max_epochs=1,
        callbacks=[
            MonitorCallback({"lr": "trainer.optimizers.0.param_groups.0.lr"}),
            CountDeadUnitsCallback(),
            StatsCallback(mod_filter=r"convs"),
        ],
        # Reuse the logger created above instead of building a second one.
        logger=wandb_logger,
        precision="bf16-mixed",
    )
    trainer.fit(model=model, datamodule=dm)

In [None]:
#| hide
# Run nbdev's export step; this notebook declares its export target with
# `default_exp monitoring` in the first cell.
import nbdev

nbdev.nbdev_export()