# Example 3a: MLP Training — PyTorch-Style (Imperative)

Nabla supports **two distinct training paradigms**:

| Paradigm | Gradient API | Optimizer API |
|----------|--------------|---------------|
| **PyTorch-style** (this notebook) | `loss.backward()` + `.grad` | `AdamW(model)` → `optimizer.step()` |
| **JAX-style** ([03b](03b_mlp_training_jax)) | `nb.value_and_grad(fn)(args)` | `adamw_init` + `adamw_update` |

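Here we demonstrate the **PyTorch-style** imperative API end-to-end.
The training loop follows PyTorch's sequence, `zero_grad → forward → backward → step`,
with one difference: `optimizer.step()` returns the updated model, which must be
assigned back to `model`.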

In [1]:
import numpy as np

import nabla as nb

print("Nabla MLP Training — PyTorch-style")

Nabla MLP Training — PyTorch-style


## 1. Define the Model

Subclass `nb.nn.Module` and define layers in `__init__`. The `forward()`
method specifies the computation. Parameters (from `nb.nn.Linear`, etc.)
are automatically registered and tracked.

In [2]:
class MLP(nb.nn.Module):
    """Two-layer MLP with ReLU activation."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.fc1 = nb.nn.Linear(in_dim, hidden_dim)
        self.fc2 = nb.nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = nb.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = MLP(4, 32, 1)
print(f"Model: fc1 {model.fc1.weight.shape}, fc2 {model.fc2.weight.shape}")
print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters())}")

Model: fc1 [Dim(4), Dim(32)], fc2 [Dim(32), Dim(1)]
Total trainable parameters: 193
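
The reported total of 193 can be checked by hand. The sketch below assumes each `nb.nn.Linear` holds one weight matrix of shape `(in_dim, out_dim)` and one bias with `out_dim` entries, which is consistent with the shapes printed in this notebook. `linear_param_count` is a hypothetical helper for illustration, not part of Nabla.

```python
def linear_param_count(in_dim: int, out_dim: int) -> int:
    """Weights (in_dim * out_dim) plus one bias per output unit."""
    return in_dim * out_dim + out_dim


fc1_params = linear_param_count(4, 32)  # 4*32 + 32 = 160
fc2_params = linear_param_count(32, 1)  # 32*1 + 1  = 33
total_params = fc1_params + fc2_params
print(total_params)  # 193
```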


## 2. Create Synthetic Data

We create a synthetic regression dataset of 200 samples with four standard-normal
features and the scalar target `y = sin(x0) + cos(x1) + 0.5*x2 - x3`.

In [3]:
np.random.seed(42)
n_samples = 200
X_np = np.random.randn(n_samples, 4).astype(np.float32)
y_np = (
    np.sin(X_np[:, 0])
    + np.cos(X_np[:, 1])
    + 0.5 * X_np[:, 2]
    - X_np[:, 3]
).reshape(-1, 1).astype(np.float32)

X = nb.Tensor.from_dlpack(X_np)
y = nb.Tensor.from_dlpack(y_np)
print(f"Dataset: X {X.shape}, y {y.shape}")

Dataset: X [Dim(200), Dim(4)], y [Dim(200), Dim(1)]
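
To put the losses reported in Section 4 in context, it helps to know what a constant predictor would score. The sketch below is a back-of-the-envelope calculation rather than part of the original notebook. It assumes ideal, independent standard-normal features (the finite 200-sample dataset will deviate somewhat) and uses the identity E[cos(tX)] = exp(-t²/2) for X ~ N(0, 1).

```python
import math

# Variance of each term of y = sin(x0) + cos(x1) + 0.5*x2 - x3
# for independent x_i ~ N(0, 1), using E[cos(t*X)] = exp(-t**2 / 2).
var_sin = (1 - math.exp(-2)) / 2                        # E[sin^2] - E[sin]^2, with E[sin] = 0
var_cos = (1 + math.exp(-2)) / 2 - math.exp(-0.5) ** 2  # E[cos^2] - E[cos]^2
var_half_x2 = 0.5 ** 2                                  # Var(0.5 * x2)
var_x3 = 1.0                                            # Var(-x3)

# The terms are independent, so their variances add. The sum is the
# expected MSE of always predicting the mean of y.
baseline_mse = var_sin + var_cos + var_half_x2 + var_x3
print(f"{baseline_mse:.3f}")  # 1.882
```

Under these assumptions, the final MSE of about 0.037 reported later is roughly 2% of the target variance.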


## 3. Set Up the Stateful Optimizer

`nb.nn.optim.AdamW` is a **stateful optimizer** — it holds references to
the model parameters and maintains its own moment estimates (m, v).
This is Nabla's counterpart to `torch.optim.AdamW`.

> **JAX-style note:** Nabla's *functional* optimizer (`nb.nn.optim.adamw_init`
> + `nb.nn.optim.adamw_update`) takes params and optimizer state as explicit
> arguments and returns new values — no internal state at all. See [03b](03b_mlp_training_jax).

In [4]:
optimizer = nb.nn.optim.AdamW(model, lr=1e-2)
print(f"Optimizer: AdamW (lr={optimizer.lr}, β=({optimizer.beta1}, {optimizer.beta2}))")

Optimizer: AdamW (lr=0.01, β=(0.9, 0.999))
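
For readers who want to see what the moment estimates do, here is a minimal scalar sketch of the textbook AdamW rule (Adam with decoupled weight decay). It is not Nabla's implementation: `adamw_scalar_update`, `ToyStatefulAdamW`, and the `eps` and `weight_decay` defaults are hypothetical choices for illustration. The pure function mirrors the functional style (state in, state out), and the small class wraps it to mimic the stateful style.

```python
import math


def adamw_scalar_update(p, g, state, lr=1e-2, beta1=0.9, beta2=0.999,
                        eps=1e-8, weight_decay=0.01):
    """One textbook AdamW step on a single float; returns (new_p, new_state)."""
    m, v, t = state
    t += 1
    m = beta1 * m + (1 - beta1) * g      # first moment: running mean of gradients
    v = beta2 * v + (1 - beta2) * g * g  # second moment: running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)         # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    new_p = p - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * p)
    return new_p, (m, v, t)


class ToyStatefulAdamW:
    """Stateful wrapper: keeps (m, v, t) internally between calls."""

    def __init__(self, lr=1e-2):
        self.lr = lr
        self.state = (0.0, 0.0, 0)

    def step(self, p, g):
        new_p, self.state = adamw_scalar_update(p, g, self.state, lr=self.lr)
        return new_p


# Minimize f(p) = (p - 3)^2, whose gradient is 2 * (p - 3).
opt = ToyStatefulAdamW(lr=0.1)
p = 0.0
for _ in range(500):
    p = opt.step(p, 2 * (p - 3))
print(round(p, 2))
```

The iterate should end close to 3. With weight decay active it tends to settle slightly below the minimizer, because the decay term keeps pulling the parameter toward zero.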


## 4. PyTorch-Style Training Loop

The four-step rhythm matches PyTorch's, apart from the return value of `step()`:

1. **`model.zero_grad()`**: clear the `.grad` tensors left over from the previous iteration
2. **Forward pass**: build the lazy computation graph
3. **`loss.backward()`**: propagate gradients; every parameter with `requires_grad=True` gets its `.grad` populated, and all gradients are realized together (no longer lazy) before the call returns
4. **`optimizer.step()`**: read `.grad` from each parameter, apply the AdamW update, and return the updated model

> **Lazy execution note:** Because Nabla cannot mutate tensor data in-place
> without breaking the lazy graph, `optimizer.step()` returns the new model.
> Assign it back to `model` each iteration.
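
The four steps can be seen in isolation on a one-parameter problem. The sketch below is plain Python with a hand-derived gradient rather than Nabla's autodiff. `ToyModel`, `backward`, and `ToySGD` are hypothetical stand-ins, and plain SGD replaces AdamW for brevity. The toy assumes PyTorch-like accumulation into `.grad` (which is why `zero_grad()` comes first) and mimics a `step()` that returns a new model instead of mutating the old one.

```python
class ToyModel:
    """y_hat = w * x with a single weight and a PyTorch-like .grad slot."""

    def __init__(self, w, grad=0.0):
        self.w = w
        self.grad = grad

    def zero_grad(self):
        self.grad = 0.0

    def __call__(self, xs):
        return [self.w * x for x in xs]


def mse(preds, ys):
    return sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(ys)


def backward(model, xs, preds, ys):
    """Accumulate d(MSE)/dw into model.grad (derived by hand)."""
    model.grad += sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(ys)


class ToySGD:
    def __init__(self, model, lr):
        self.model, self.lr = model, lr

    def step(self):
        old = self.model
        # Build a fresh model rather than overwriting old.w in place.
        self.model = ToyModel(old.w - self.lr * old.grad, grad=old.grad)
        return self.model


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0 * x for x in xs]  # the true weight is 2

model = ToyModel(w=0.0)
optimizer = ToySGD(model, lr=0.05)
for _ in range(100):
    model.zero_grad()               # 1. clear the old gradient
    preds = model(xs)               # 2. forward pass
    loss = mse(preds, ys)
    backward(model, xs, preds, ys)  # 3. backward pass fills model.grad
    model = optimizer.step()        # 4. update and rebind the returned model
print(round(model.w, 4))  # 2.0
```

Forgetting the rebinding in step 4 would leave `model` pointing at the stale weights, which is the failure mode the note above warns about.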

In [5]:
num_epochs = 100
print(f"\n{'Epoch':<8} {'Loss':<12}")
print("-" * 22)

for epoch in range(num_epochs):
    # 1. Clear gradients from the previous iteration
    model.zero_grad()

    # 2. Forward pass
    predictions = model(X)
    loss = nb.nn.functional.mse_loss(predictions, y)

    # 3. Backward pass — gradients stored in p.grad for each parameter
    loss.backward()

    # 4. Optimizer step — reads .grad, applies AdamW update
    model = optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"{epoch + 1:<8} {loss.item():<12.6f}")


Epoch    Loss        
----------------------
10       0.625402    
20       0.242037    
30       0.159796    
40       0.118573    
50       0.082741    
60       0.064866    
70       0.054502    
80       0.047352    
90       0.041529    
100      0.037468    


## 5. Inspecting Gradients via `.grad`

After `loss.backward()`, every trainable parameter exposes its gradient
via `.grad` — exactly like PyTorch. The gradients are already realized
(not lazy) when `.backward()` returns.

In [6]:
# One more backward pass to show .grad access
model.train()
model.zero_grad()
b_loss = nb.nn.functional.mse_loss(model(X), y)
b_loss.backward()

print("Parameter gradients after backward():")
for name, param in model.named_parameters():
    g = param.grad
    if g is not None:
        g_np = np.from_dlpack(g)  # np was imported in the first cell
        g_norm = float(np.linalg.norm(g_np))
        print(f"  {name:30s}  shape {str(g.shape):<16}  |grad|={g_norm:.4f}")

Parameter gradients after backward():
  fc1.bias                        shape [Dim(1), Dim(32)]  |grad|=0.0201
  fc1.weight                      shape [Dim(4), Dim(32)]  |grad|=0.0227
  fc2.bias                        shape [Dim(1), Dim(1)]  |grad|=0.0086
  fc2.weight                      shape [Dim(32), Dim(1)]  |grad|=0.0203
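
The per-parameter norms can be combined into a single global gradient norm, a common scalar for monitoring training. The sketch below only post-processes the rounded values printed above, so the result is accurate to about three significant digits, and it does not call Nabla.

```python
import math

# Per-parameter L2 norms copied from the printed output (rounded to 4 decimals).
grad_norms = {
    "fc1.bias": 0.0201,
    "fc1.weight": 0.0227,
    "fc2.bias": 0.0086,
    "fc2.weight": 0.0203,
}

# The L2 norm of all gradients concatenated equals the square root of the
# sum of the squared per-parameter norms.
global_norm = math.sqrt(sum(n ** 2 for n in grad_norms.values()))
print(f"{global_norm:.4f}")  # 0.0375
```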


## 6. Evaluate the Trained Model

In [7]:
model.eval()
predictions = model(X)  # run the forward pass once and reuse it below
final_loss = nb.nn.functional.mse_loss(predictions, y)
print(f"\nFinal MSE loss: {final_loss.item():.6f}")

print(f"\n{'Prediction':>12}  {'Target':>12}")
print("-" * 28)
for i in range(5):
    pred_i = nb.gather(predictions, nb.constant(np.array([i], dtype=np.int64)), axis=0)
    true_i = nb.gather(y, nb.constant(np.array([i], dtype=np.int64)), axis=0)
    print(f"{pred_i.item():>12.4f}  {true_i.item():>12.4f}")


Final MSE loss: 0.037132

  Prediction        Target
----------------------------
      0.1029        0.2678
      0.7215        0.7629
      0.7124        0.6380
     -0.3226       -0.3964
      1.2237        1.0610


## 7. Contrast: JAX-Style API (for reference)

The JAX-style equivalent of the same training step — note the absence of
`.backward()`, `.grad`, and stateful optimizer mutations:

```python
# JAX-style (functional) — see 03b_mlp_training_jax.py
def loss_fn(model, X, y):
    return nb.nn.functional.mse_loss(model(X), y)

# Single call computes both the loss value and all gradients
loss, grads = nb.value_and_grad(loss_fn, argnums=0)(model, X, y)

# Functional optimizer: returns new model + new state (no mutation).
# opt_state is created once, before the loop, with nb.nn.optim.adamw_init (see 03b).
model, opt_state = nb.nn.optim.adamw_update(model, grads, opt_state, lr=1e-2)
```

Both paradigms are fully supported in Nabla:
- **PyTorch-style**: familiar to PyTorch users, great for interactive
  debugging and stateful training loops
- **JAX-style**: composable with `nb.vmap`, `@nb.compile`, `nb.jacrev`, etc.;
  required when nesting transforms or writing pure-functional pipelines

In [8]:
print("\n✅ Example 03a completed!")


✅ Example 03a completed!


## Summary

### PyTorch-Style API (this notebook)

| Concept | API |
|---------|-----|
| Define model | `class MyModel(nb.nn.Module)` |
| Linear layer | `nb.nn.Linear(in_dim, out_dim)` |
| Loss functions | `nb.nn.functional.mse_loss`, `cross_entropy_loss` |
| Clear gradients | `model.zero_grad()` |
| Compute gradients | `loss.backward()` |
| Inspect gradients | `param.grad` |
| Create optimizer | `optimizer = nb.nn.optim.AdamW(model, lr=...)` |
| Update parameters | `model = optimizer.step()` |

### JAX-Style API (see [03b](03b_mlp_training_jax))

| Concept | API |
|---------|-----|
| Compute loss + grads | `loss, grads = nb.value_and_grad(fn, argnums=0)(model, ...)` |
| Optimizer init | `opt_state = nb.nn.optim.adamw_init(params)` |
| Optimizer update | `model, opt_state = nb.nn.optim.adamw_update(...)` |