# Monitoring

> Callbacks that monitor different aspects of the model and the training run.

In [1]:
#| default_exp monitoring

In [2]:
# |export
import re
from argparse import Namespace

import lightning as L
import wandb
from glom import glom
from lightning.pytorch.loggers import WandbLogger
from torch import nn

from slow_diffusion.fashionmnist import FashionMNISTDataModule
from slow_diffusion.training import get_tiny_unet

  _torch_pytree._register_pytree_node(


In [3]:
# |export
class MonitorCallback(L.Callback):
    """Log arbitrary values, located with `glom` specs, after every training batch.

    Each spec is resolved against a namespace exposing the arguments of
    `on_train_batch_end`: `trainer`, `pl_module`, `outputs`, `batch` and
    `batch_idx`. For example ``{"lr": "trainer.optimizers.0.param_groups.0.lr"}``
    logs the learning rate of the first parameter group of the first optimizer.

    Args:
        gloms: mapping from metric name to the glom spec producing its value.

    Raises:
        ValueError: if `gloms` is empty.
    """

    def __init__(self, gloms: dict[str, str]):
        super().__init__()
        if not gloms:
            raise ValueError("MonitorCallback needs at least one {metric name: glom spec} entry")
        # Copy so later mutation of the caller's dict does not change what is logged.
        self.gloms = dict(gloms)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        args = Namespace(
            trainer=trainer,
            pl_module=pl_module,
            outputs=outputs,
            batch=batch,
            batch_idx=batch_idx,
        )
        for name, spec in self.gloms.items():
            # Lightning callbacks have no `log` method; metrics are logged through the
            # LightningModule being trained.
            pl_module.log(name, glom(args, spec), on_step=True)

In [4]:
# |exports
class CountDeadUnitsCallback(L.Callback):
    """Log the number of NaN entries across all model parameters after each training batch.

    The count is logged under the metric name ``"dead_units"`` (name kept for
    compatibility with existing dashboards; what is counted is NaN parameter values).
    """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `int(...)` on each 0-dim count tensor keeps this correct even if parameters
        # live on different devices, and yields 0 for a parameter-less module.
        nans = sum(int(p.isnan().sum()) for p in pl_module.parameters())
        # Lightning callbacks have no `log` method; log via the LightningModule.
        pl_module.log("dead_units", nans, reduce_fx="sum")

In [5]:
# |exports
class Stats:
    """Measure the mean and standard deviation of a module's output on training forward passes.

    A forward hook is registered on `module` at construction time; call `cleanup`
    to remove it. Outputs produced while the module is in eval mode are ignored.

    If `log_fn` is given, each measurement is reported as ``log_fn("{label}:mean", v)``
    and ``log_fn("{label}:std", v)`` (e.g. pass `LightningModule.log`). Otherwise the
    values are accumulated in the `means` and `stds` lists.

    Args:
        label: prefix for the reported metric names.
        module: the module whose outputs are measured.
        log_fn: optional callable taking ``(name, value)``.
    """

    def __init__(self, label, module, log_fn=None):
        self.label = label
        self.log_fn = log_fn
        self.means: list[float] = []
        self.stds: list[float] = []
        self.hook = module.register_forward_hook(self.append)

    def append(self, module, _, activations):
        """Forward hook: record statistics of `activations` (the module's output)."""
        if not module.training:
            return
        # Some modules return a tuple/list of tensors; measure the first one.
        if isinstance(activations, (tuple, list)):
            activations = activations[0]
        # Reduce on the activations' own device (no full-tensor copy to the CPU) and in
        # float32, so bf16 autocast outputs do not lose precision in the reduction.
        activations = activations.detach().float()
        mean = activations.mean().item()
        std = activations.std().item()
        if self.log_fn is None:
            self.means.append(mean)
            self.stds.append(std)
        else:
            self.log_fn(f"{self.label}:mean", mean)
            self.log_fn(f"{self.label}:std", std)

    def cleanup(self):
        """Remove the forward hook from the module."""
        self.hook.remove()


class StatsCallback(L.Callback):
    """Log the output mean/std of selected submodules on every training step.

    Modules are selected with `mods` and/or `mod_filter`:

    - `mods`: module instances to hook, and/or module classes; a class selects every
      submodule of the LightningModule that is an instance of it.
    - `mod_filter`: regular expression matched with `re.match` (anchored at the start)
      against the names yielded by `pl_module.named_modules()`.

    Hooks are attached in `on_fit_start` and removed when the fit ends or raises.
    Metrics are named ``"layer_{i}:mean"`` / ``"layer_{i}:std"``, where `i` is the
    module's position in the selection.

    Raises:
        ValueError: if neither `mods` nor `mod_filter` is given.
    """

    def __init__(
        self,
        mods: list[nn.Module | type[nn.Module]] | None = None,
        mod_filter: str | None = None,
    ):
        super().__init__()
        if not (mods or mod_filter):
            raise ValueError("StatsCallback needs `mods` and/or `mod_filter`")
        self.mods = list(mods) if mods is not None else []
        self.mod_filter = mod_filter
        self.mod_stats: list[Stats] = []

    def _select_modules(self, pl_module):
        """Return the modules to hook, in selection order, without duplicates."""
        classes = tuple(m for m in self.mods if isinstance(m, type))
        selected = [m for m in self.mods if not isinstance(m, type)]
        if classes or self.mod_filter is not None:
            for name, mod in pl_module.named_modules():
                by_class = bool(classes) and isinstance(mod, classes)
                by_name = self.mod_filter is not None and re.match(self.mod_filter, name)
                if by_class or by_name:
                    selected.append(mod)
        # A module picked by several criteria is hooked only once.
        seen = set()
        unique = []
        for mod in selected:
            if id(mod) not in seen:
                seen.add(id(mod))
                unique.append(mod)
        return unique

    def on_fit_start(self, trainer, pl_module):
        # Drop hooks left over from a previous fit so repeated `fit` calls don't stack them.
        self.cleanup()
        for i, mod in enumerate(self._select_modules(pl_module)):
            # Callbacks and plain objects cannot log directly; route through the module.
            self.mod_stats.append(Stats(f"layer_{i}", mod, log_fn=pl_module.log))

    def cleanup(self):
        """Remove every hook registered by this callback."""
        for s in self.mod_stats:
            s.cleanup()
        self.mod_stats.clear()

    def on_fit_end(self, trainer, pl_module):
        self.cleanup()

    def on_exception(self, trainer, pl_module, exception):
        self.cleanup()

In [6]:
#| eval: false
# Smoke-test the monitoring callbacks on a one-epoch bf16 run of the tiny UNet,
# logging to Weights & Biases. Skipped by `nbdev_test` (needs a GPU and W&B).
with wandb.init():
    wandb_logger = WandbLogger()
    # `Trainer.fit` runs the datamodule's prepare_data/setup hooks itself.
    dm = FashionMNISTDataModule(256, n_workers=0)
    model = get_tiny_unet()
    trainer = L.Trainer(
        max_epochs=1,
        callbacks=[
            MonitorCallback({"lr": "trainer.optimizers.0.param_groups.0.lr"}),
            CountDeadUnitsCallback(),
            StatsCallback(mod_filter=r"convs"),
        ],
        logger=wandb_logger,
        precision="bf16-mixed",
    )
    trainer.fit(model=model, datamodule=dm)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/jeremy/micromamba/envs/slowai/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name    | Type    | Params | Mod

Sanity Checking: |                               | 0/? [00:00<?, ?it/s]

/home/jeremy/micromamba/envs/slowai/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |                                      | 0/? [00:00<?, ?it/s]

/home/jeremy/micromamba/envs/slowai/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [3]:
#| hide
# Write the `#| export` cells of this notebook into the `slow_diffusion` package.
from nbdev import nbdev_export

nbdev_export()