# 02 — PyTorch training loop (synthetic)

This notebook profiles a small training loop on synthetic data.

It demonstrates a useful pattern:

- profile one **epoch** at a time, so each measurement covers a stable, comparable window
- log/print the summary per epoch

If PyTorch is not installed or no CUDA device is available, the training cell is skipped (and the summary-table cell then has nothing to show).


In [None]:
# Optional dependency: the rest of the notebook checks `torch is None`
# instead of failing, so any import-time error just disables the demo.
try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
except Exception as err:
    torch = None
    print("torch not available:", err)

# The training cell additionally needs a CUDA device; announce up front
# when it will be skipped.
if not (torch is not None and torch.cuda.is_available()):
    print("CUDA not available; skipping training demo.")

In [None]:
if torch is not None and torch.cuda.is_available():
    from profgpu import GpuMonitor

    class SmallMLP(nn.Module):
        """Two-layer perceptron: Linear(d_in, d_hidden) -> ReLU -> Linear(d_hidden, d_out)."""

        def __init__(self, d_in=1024, d_hidden=2048, d_out=10):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(d_in, d_hidden),
                nn.ReLU(),
                nn.Linear(d_hidden, d_out),
            )

        def forward(self, x):
            """Return raw logits of shape (batch, d_out) for inputs of shape (batch, d_in)."""
            return self.net(x)

    # Problem size. Both the model and the synthetic batches below are built
    # from these values, so inputs always match the model's input width and
    # labels always fall in [0, n_classes).
    d_in = 1024
    n_classes = 10

    device = torch.device("cuda")
    model = SmallMLP(d_in=d_in, d_out=n_classes).to(device)
    model.train()
    opt = optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    summaries = []  # one GpuMonitor summary per epoch; read by the table cell
    epochs = 3
    batches_per_epoch = 200
    batch_size = 256

    for epoch in range(epochs):
        # A fresh monitor per epoch gives each summary its own window.
        # NOTE(review): presumably GpuMonitor calls sync_fn at the window
        # edges so queued CUDA kernels are attributed to this epoch, and
        # warmup_s discards the first samples — confirm in profgpu docs.
        # Epoch 0 may also include one-time CUDA/cuBLAS initialization.
        with GpuMonitor(interval_s=0.2, sync_fn=torch.cuda.synchronize, warmup_s=0.2) as mon:
            for _ in range(batches_per_epoch):
                # Synthetic data is generated directly on the GPU, so the
                # profiled window contains no host-to-device copies.
                x = torch.randn(batch_size, d_in, device=device)
                y = torch.randint(0, n_classes, (batch_size,), device=device)

                opt.zero_grad(set_to_none=True)
                logits = model(x)
                loss = loss_fn(logits, y)
                loss.backward()
                opt.step()

        summaries.append(mon.summary)
        print(f"epoch {epoch}\n{mon.summary.format()}\n")

In [None]:
# Optional: turn the per-epoch summaries into a simple table.


def summaries_to_rows(summaries):
    """Flatten per-epoch GPU summaries into one dict per epoch.

    Each element of *summaries* must expose the attributes read below
    (``duration_s``, ``util_gpu_mean``, ``util_gpu_p95``,
    ``mem_used_max_mb``, ``power_mean_w``). The ``"epoch"`` value is the
    element's position in *summaries*. An empty input yields an empty list.
    """
    return [
        {
            "epoch": i,
            "duration_s": s.duration_s,
            "util_mean": s.util_gpu_mean,
            "util_p95": s.util_gpu_p95,
            "mem_max_mb": s.mem_used_max_mb,
            "power_mean_w": s.power_mean_w,
        }
        for i, s in enumerate(summaries)
    ]


# `summaries` only exists if the training cell actually ran (CUDA present).
if "summaries" in globals() and summaries:
    rows = summaries_to_rows(summaries)

    try:
        import pandas as pd
    except ImportError:
        # pandas is optional: fall back to one printed dict per epoch.
        for r in rows:
            print(r)
    else:
        df = pd.DataFrame(rows)
        try:
            display(df)  # IPython/Jupyter builtin for rich output
        except NameError:
            # Not running under IPython (e.g. notebook exported to a script).
            print(df.to_string(index=False))