# Example 13: CNN Training

This notebook trains a small 2D CNN on a synthetic regression task
and shows the same training step written **twice** — eager and compiled —
then benchmarks them head-to-head.

| Version | Code | When to use |
|---------|------|-------------|
| **Eager** (baseline) | plain function | debugging, single shots |
| **Compiled** | `@nb.compile` | repeated training loops |

**CNN architecture:**
- Stage 1: `conv2d (1→8 ch, 3×3) + ReLU + avg_pool2d 2×2`
- Stage 2: `conv2d (8→16 ch, 3×3) + ReLU + max_pool2d 2×2`
- Head: flatten → `matmul + bias + ReLU`

## 1. Imports

In [5]:
from __future__ import annotations

import time

import numpy as np

import nabla as nb

print("Nabla CNN — Eager vs. Compiled")

Nabla CNN — Eager vs. Compiled


## 2. Synthetic Dataset

Task: predict the **mean squared activation of the center 8×8 patch** of a 16×16 grayscale image.
This gives a cheap, differentiable regression target that depends on a spatial crop.

In [6]:
def make_dataset(seed: int = 0, batch_size: int = 64):
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(batch_size, 16, 16, 1)).astype(np.float32)
    center = x[:, 4:12, 4:12, :]
    y = np.mean(center ** 2, axis=(1, 2, 3), keepdims=True).astype(np.float32)
    return (
        nb.Tensor.from_dlpack(x),
        nb.Tensor.from_dlpack(y),
    )


X, Y = make_dataset(seed=0)
print(f"Inputs:  {X.shape}")
print(f"Targets: {Y.shape}")

Inputs:  [Dim(64), Dim(16), Dim(16), Dim(1)]
Targets: [Dim(64), Dim(1), Dim(1), Dim(1)]


## 3. CNN Architecture

The model is a **pure function over a flat parameter list** — no module classes, no hidden state.
This is Nabla's functional API, analogous to JAX.

In [7]:
def cnn(x: nb.Tensor, params: list[nb.Tensor]) -> nb.Tensor:
    w1, b1, w2, b2, wh, bh = params

    # Stage 1: conv + ReLU + avg pool
    y = nb.relu(nb.conv2d(x, w1, bias=b1, stride=(1, 1), padding=(1, 1, 1, 1)))
    y = nb.avg_pool2d(y, kernel_size=(2, 2), stride=(2, 2), padding=0)

    # Stage 2: conv + ReLU + max pool
    y = nb.relu(nb.conv2d(y, w2, bias=b2, stride=(1, 1), padding=(1, 1, 1, 1)))
    y = nb.max_pool2d(y, kernel_size=(2, 2), stride=(2, 2), padding=0)

    # Head: flatten → linear
    y = nb.reshape(y, (int(y.shape[0]), int(y.shape[1] * y.shape[2] * y.shape[3])))
    return nb.relu(nb.matmul(y, wh) + bh)


print("cnn(x, params) defined")

cnn(x, params) defined


## 4. Parameter Initialization

In [8]:
def init_params(seed: int = 1) -> list[nb.Tensor]:
    rng = np.random.default_rng(seed)

    return [
        nb.Tensor.from_dlpack((0.10 * rng.normal(size=(3, 3, 1, 8))).astype(np.float32)),
        nb.Tensor.from_dlpack(np.zeros((8,), dtype=np.float32)),
        nb.Tensor.from_dlpack((0.08 * rng.normal(size=(3, 3, 8, 16))).astype(np.float32)),
        nb.Tensor.from_dlpack(np.zeros((16,), dtype=np.float32)),
        nb.Tensor.from_dlpack((0.10 * rng.normal(size=(16 * 4 * 4, 1))).astype(np.float32)),
        nb.Tensor.from_dlpack(np.zeros((1,), dtype=np.float32)),
    ]


demo_params = init_params()
print("Parameter shapes:")
for p in demo_params:
    print(f"  {p.shape}")

Parameter shapes:
  [Dim(3), Dim(3), Dim(1), Dim(8)]
  [Dim(8)]
  [Dim(3), Dim(3), Dim(8), Dim(16)]
  [Dim(16)]
  [Dim(256), Dim(1)]
  [Dim(1)]


## 5. Loss Function

Plain MSE loss. Both training variants (eager and compiled) share this function unchanged.

In [9]:
def loss_fn(params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor) -> nb.Tensor:
    diff = cnn(x, params) - y
    return nb.mean(diff * diff)

## 6. Eager Training Step (Baseline)

`value_and_grad` traces and executes the computation graph on every call.
No caching, no compilation overhead — but also no reuse.

> Use the eager step for **debugging** or when you only call it once.

In [10]:
def eager_train_step(
    params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor, lr: float = 3e-2
):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, x, y)
    new_params = [p - lr * g for p, g in zip(params, grads)]
    return new_params, loss

## 7. Compiled Training Step

Decorating the **exact same logic** with `@nb.compile`:
- **First call** — Nabla traces the Python function and compiles it to a MAX graph.
- **All later calls** with the same input shapes/dtypes hit the cache and skip Python dispatch entirely.

The compiled and eager versions produce **identical numerical results**.

In [11]:
@nb.compile
def compiled_train_step(
    params: list[nb.Tensor], x: nb.Tensor, y: nb.Tensor, lr: float = 3e-2
):
    loss, grads = nb.value_and_grad(loss_fn, argnums=0)(params, x, y)
    new_params = [p - lr * g for p, g in zip(params, grads)]
    return new_params, loss

## 8. Head-to-Head: Eager vs. Compiled

We train two identical models from the same seed for 60 steps.
The **first step is a warmup** (for the compiled version this triggers trace + compile)
and is excluded from the timing measurement.

What to observe:
- Loss curves should be **numerically identical**.
- Compiled should report a lower average `ms/step`.

In [14]:
def run(step_fn, label: str, steps: int = 60, lr: float = 3e-2, seed: int = 0):
    nb._clear_caches()
    x, y = make_dataset(seed=seed)
    params = init_params(seed=seed + 1)

    # Warmup (for compiled: triggers trace + compile)
    params, loss_warmup = step_fn(params, x, y, lr)
    nb.realize_all(loss_warmup, *params)

    print(f"\n{label}")
    print(f"{'Step':<8} {'Loss':<12}")
    print("-" * 22)

    losses = []
    t0 = time.perf_counter()
    for step in range(steps):
        params, loss = step_fn(params, x, y, lr)
        nb.realize_all(loss, *params)
        loss_value = float(loss.item())
        losses.append(loss_value)
        if (step + 1) % 10 == 0:
            print(f"{step + 1:<8} {loss_value:<12.6f}")

    avg_ms = (time.perf_counter() - t0) / steps * 1000.0
    print(f"Avg step: {avg_ms:.1f} ms/step")
    return {
        "avg_ms": avg_ms,
        "initial_loss": losses[0],
        "final_loss": losses[-1],
        "losses": losses,
    }

## 9. Execute Benchmark

Run both variants with identical data and initialization.
This gives an apples-to-apples performance comparison.

In [15]:
eager_result = run(eager_train_step, "Eager baseline", steps=60, lr=3e-2, seed=0)
compiled_result = run(compiled_train_step, "Compiled (@nb.compile)", steps=60, lr=3e-2, seed=0)

# Backward compatibility for partially-run kernels where run() may still return float
if isinstance(eager_result, float):
    eager_result = {"avg_ms": eager_result, "final_loss": float("nan")}
if isinstance(compiled_result, float):
    compiled_result = {"avg_ms": compiled_result, "final_loss": float("nan")}

speedup = eager_result["avg_ms"] / max(compiled_result["avg_ms"], 1e-9)
print("\nSummary")
print("-" * 40)
print(f"Eager avg step:    {eager_result['avg_ms']:.2f} ms")
print(f"Compiled avg step: {compiled_result['avg_ms']:.2f} ms")
print(f"Speedup:           {speedup:.2f}x")
print(
    f"Loss check (eager final / compiled final): "
    f"{eager_result.get('final_loss', float('nan')):.6f} / "
    f"{compiled_result.get('final_loss', float('nan')):.6f}"
)
print(f"Compiled cache stats: {compiled_train_step.stats}")


Eager baseline
Step     Loss        
----------------------
10       0.029387    
20       0.029257    
30       0.029142    
40       0.029040    
50       0.028946    
60       0.028860    
Avg step: 239.4 ms/step

Compiled (@nb.compile)
Step     Loss        
----------------------
10       0.029387    
20       0.029257    
30       0.029142    
40       0.029040    
50       0.028946    
60       0.028860    
Avg step: 24.7 ms/step

Summary
----------------------------------------
Eager avg step:    239.42 ms
Compiled avg step: 24.68 ms
Speedup:           9.70x
Loss check (eager final / compiled final): 0.028860 / 0.028860
Compiled cache stats: CompilationStats(hits=121, misses=1, fallbacks=0, hit_rate=99.2%)
