# Example 6a: Transformer Training — PyTorch-Style (Imperative)

Nabla supports two training paradigms; this notebook demonstrates the **PyTorch-style** imperative API:

| Paradigm | Gradient API | Optimizer API |
|----------|--------------|---------------|
| **PyTorch-style** (this notebook) | `loss.backward()` + `.grad` | `AdamW(model)` → `optimizer.step()` |
| **JAX-style** ([6b](06b_transformer_jax)) | `nb.value_and_grad(fn)(args)` | `adamw_init` + `adamw_update` |

We build a small Transformer encoder for a synthetic **sequence classification** task
using `nb.nn.TransformerEncoderLayer`, `Embedding`, and `MultiHeadAttention`.

In [1]:
import time

import numpy as np

import nabla as nb

print("Nabla Transformer Training — PyTorch-style")

Nabla Transformer Training — PyTorch-style


## 1. Positional Encoding

We'll use sinusoidal positional encoding, computed as a fixed buffer.

In [2]:
def make_positional_encoding(max_len: int, d_model: int) -> np.ndarray:
    """Sinusoidal positional encoding."""
    pe = np.zeros((max_len, d_model), dtype=np.float32)
    position = np.arange(0, max_len, dtype=np.float32)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, d_model, 2, dtype=np.float32) * -(np.log(10000.0) / d_model)
    )
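    pe[:, 0::2] = np.sin(position * div_term)  # even columns
    # Odd columns: slice div_term so an odd d_model (one fewer odd column) also works
    pe[:, 1::2] = np.cos(position * div_term[: d_model // 2])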
    return pe  # (max_len, d_model)
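As a cross-check on the vectorized function above, the same encoding can be written one entry at a time. This is a minimal sketch in plain Python (no NumPy or Nabla needed); `pe_entry` is a hypothetical helper name, not part of the notebook or of Nabla.

```python
import math


def pe_entry(pos: int, col: int, d_model: int) -> float:
    """Scalar form of the sinusoidal encoding: sin on even columns, cos on odd."""
    pair = col // 2  # columns 2i and 2i+1 share one frequency
    angle = pos * math.exp(-(2 * pair) * math.log(10000.0) / d_model)
    return math.sin(angle) if col % 2 == 0 else math.cos(angle)


d_model = 8
row0 = [pe_entry(0, c, d_model) for c in range(d_model)]
row1 = [pe_entry(1, c, d_model) for c in range(d_model)]

print(row0)  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print(all(-1.0 <= v <= 1.0 for v in row1))  # True
```

Position 0 always encodes as alternating 0 and 1, since sin(0) = 0 and cos(0) = 1, which makes a quick sanity check for any implementation.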

## 2. Define the Model

Our `TransformerClassifier` is an `nb.nn.Module` subclass with these components:

| Component | Purpose |
|-----------|---------|
| `nb.nn.Embedding` | Maps token IDs → dense vectors |
| Sinusoidal PE | Encodes position information (fixed, not learned) |
| `nb.nn.TransformerEncoderLayer` × N | Self-attention + feed-forward blocks |
| `nb.nn.Linear` | Classification head |

The `__init__` method creates these components; `forward` chains them together.

In [3]:
class TransformerClassifier(nb.nn.Module):
    """Transformer encoder for sequence classification."""

    def __init__(self, vocab_size, d_model, num_heads, num_layers,
                 num_classes, max_len=128, dim_feedforward=128):
        super().__init__()
        self.d_model = d_model

        # --- Embeddings ---
        self.embedding = nb.nn.Embedding(vocab_size, d_model)
        pe_np = make_positional_encoding(max_len, d_model)
        self.pe = nb.Tensor.from_dlpack(pe_np)  # fixed, not learned

        # --- Encoder stack ---
        self.layers = []
        for i in range(num_layers):
            layer = nb.nn.TransformerEncoderLayer(
                d_model=d_model, num_heads=num_heads,
                dim_feedforward=dim_feedforward, dropout=0.0,
            )
            setattr(self, f"encoder_{i}", layer)
            self.layers.append(layer)

        # --- Classifier ---
        self.classifier = nb.nn.Linear(d_model, num_classes)

    def forward(self, token_ids):
        # Embed + positional encoding
        x = self.embedding(token_ids)
        seq_len = token_ids.shape[-1]
        pe = nb.slice_tensor(self.pe, start=(0, 0), size=(seq_len, self.d_model))
        x = x + pe

        # Encoder layers
        for layer in self.layers:
            x = layer(x)

        # Mean pool + classify
        return self.classifier(nb.mean(x, axis=-2))
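The last line of `forward` averages over the sequence axis (`axis=-2`), turning a `(batch, seq_len, d_model)` tensor into `(batch, d_model)` before the linear head. The sketch below shows the same reduction on nested Python lists; `mean_pool` is a hypothetical stand-in for illustration, not a Nabla function.

```python
def mean_pool(x):
    """Average over the sequence axis: (batch, seq_len, d_model) -> (batch, d_model)."""
    return [
        [sum(token[j] for token in seq) / len(seq) for j in range(len(seq[0]))]
        for seq in x
    ]


x = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],  # sequence 0: 3 tokens, d_model = 2
    [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]],  # sequence 1
]
pooled = mean_pool(x)
print(pooled)  # [[3.0, 4.0], [2.0, 2.0]]
```

Each sequence collapses to a single vector, so the classifier sees one `d_model`-sized input per example regardless of sequence length.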

## 3. Create Synthetic Data

Generate a simple classification task:
- Sequences of random token IDs
- Labels from a fixed rule: class = (sum of token IDs) mod `num_classes`, then one-hot encoded for the loss

In [4]:
np.random.seed(42)

vocab_size = 20
seq_len = 8
num_classes = 3
n_samples = 150
d_model = 32
num_heads = 4
num_layers = 2

# Generate random token sequences
token_ids_np = np.random.randint(0, vocab_size, (n_samples, seq_len)).astype(np.int64)

# Labels: class = (sum of tokens) mod num_classes
labels_np = (token_ids_np.sum(axis=1) % num_classes).astype(np.int64)

# One-hot encode labels
labels_onehot_np = np.zeros((n_samples, num_classes), dtype=np.float32)
labels_onehot_np[np.arange(n_samples), labels_np] = 1.0

token_ids = nb.Tensor.from_dlpack(token_ids_np)
labels = nb.Tensor.from_dlpack(labels_onehot_np)

print(f"Dataset: {n_samples} sequences of length {seq_len}")
print(f"Vocab size: {vocab_size}, Classes: {num_classes}")
print(f"Sample tokens: {token_ids_np[0]}")
print(f"Sample label:  {labels_np[0]}")

Dataset: 150 sequences of length 8
Vocab size: 20, Classes: 3
Sample tokens: [ 6 19 14 10  7  6 18 10]
Sample label:  0
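The labeling rule can be checked by hand on the sample printed above. This is a small plain-Python sketch; `label_for` and `one_hot` are hypothetical helper names used only for illustration.

```python
def label_for(tokens, num_classes):
    """Class index = (sum of token IDs) mod num_classes."""
    return sum(tokens) % num_classes


def one_hot(label, num_classes):
    """One-hot row matching the layout of labels_onehot_np."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


sample_tokens = [6, 19, 14, 10, 7, 6, 18, 10]  # first sequence printed above
label = label_for(sample_tokens, 3)
target = one_hot(label, 3)
print(sum(sample_tokens), label, target)  # 90 0 [1.0, 0.0, 0.0]
```

The token sum is 90, and 90 mod 3 is 0, which agrees with the printed sample label.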


## 4. Build Model and Optimizer

**Important initialization order for the stateful optimizer:**
Create `AdamW` while the model is in **train mode** (`_training=True`, the default).
Nabla's Module pytree includes `_training` in its metadata, so the optimizer's internal
moment tensors (`m`, `v`) are snapshot-initialized with that training mode.
Calling `model.eval()` *before* creating the optimizer would bake `_training=False` into
those snapshots, causing a pytree metadata mismatch the first time `model.train()` is
called inside the training loop.

Rule of thumb:
- **Stateful optimizer** (`AdamW(model)`) → create in **train mode**, call `model.eval()` only for eval passes.
- **Functional optimizer** (`adamw_init(model)`) → `model.eval()` *before* `adamw_init` so every pass shares the same `_training=False` state.
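The mismatch described above can be pictured with a toy model of pytree flattening. Everything in this sketch (`ToyModule`, `ToyOptimizer`, `flatten`) is hypothetical and only mimics the idea of a training flag stored in structural metadata; it is not how Nabla is implemented.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModule:
    weights: list = field(default_factory=lambda: [0.1, 0.2])
    training: bool = True  # train mode is the default


def flatten(module):
    """Split a module into leaves (data) and static metadata (structure)."""
    return list(module.weights), ("ToyModule", module.training)


class ToyOptimizer:
    def __init__(self, module):
        # The structure, including the training flag, is captured once here.
        _, self.treedef = flatten(module)

    def matches(self, module):
        return flatten(module)[1] == self.treedef


# Problem case: eval mode at creation, train mode inside the loop.
bad = ToyModule()
bad.training = False
bad_opt = ToyOptimizer(bad)
bad.training = True
print(bad_opt.matches(bad))  # False

# Recommended order: create the optimizer in train mode.
good = ToyModule()
good_opt = ToyOptimizer(good)
print(good_opt.matches(good))  # True
```

The flag is part of the structure rather than the data, so flipping it after the snapshot makes the two structures compare unequal even though no weight changed.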

In [5]:
model = TransformerClassifier(
    vocab_size=vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    num_classes=num_classes,
    max_len=seq_len,
    dim_feedforward=64,
)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model: {num_layers} encoder layers, d_model={d_model}, heads={num_heads}")
print(f"Total trainable parameters: {n_params}")

# Create optimizer while model is in train mode (default _training=True)
model.train()
optimizer = nb.nn.optim.AdamW(model, lr=1e-3)
print(f"Optimizer: AdamW (lr={optimizer.lr})")

Model: 2 encoder layers, d_model=32, heads=4
Total trainable parameters: 17827
Optimizer: AdamW (lr=0.001)
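The reported parameter count can be reproduced by hand. The breakdown below assumes a conventional encoder layer layout (four attention projections with biases, a two-layer feed-forward block with biases, and two LayerNorms with scale and shift). That layout is an assumption about `nb.nn.TransformerEncoderLayer`, although the total it gives agrees with the printed figure. The fixed positional encoding is not counted.

```python
def encoder_layer_params(d_model, dim_ff):
    """Parameter count for one encoder layer under the assumed standard layout."""
    attn = 4 * (d_model * d_model + d_model)  # Q, K, V, output projections + biases
    ffn = (d_model * dim_ff + dim_ff) + (dim_ff * d_model + d_model)
    norms = 2 * (2 * d_model)  # two LayerNorms, each with scale and shift
    return attn + ffn + norms


vocab_size, d_model, dim_ff, num_layers, num_classes = 20, 32, 64, 2, 3

embedding = vocab_size * d_model                  # 640
per_layer = encoder_layer_params(d_model, dim_ff)  # 8544
classifier = d_model * num_classes + num_classes  # 99
total = embedding + num_layers * per_layer + classifier
print(total)  # 17827
```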


## 5. PyTorch-Style Training Loop

Imperative four-step loop: `zero_grad → forward → backward → step`

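- `model.train()` at the top of each iteration restores train mode before `optimizer.step()`; this matters because the logging block switches to `model.eval()` every 10 steps, and the optimizer was created with the `_training=True` structure.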
- `loss.backward()` populates `.grad` on every `requires_grad=True` parameter and **batch-realizes** all gradients before returning.
- `optimizer.step()` (no arguments) reads `.grad`, applies the AdamW update, and returns the updated model.
- Assigning `model = optimizer.step()` is necessary because Nabla's lazy execution cannot mutate tensor data truly in-place.
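To make the update in `optimizer.step()` concrete, here is a scalar sketch of one AdamW step written in the same "return new state" style. The hyperparameter defaults (`beta1`, `beta2`, `eps`, `weight_decay`) are commonly used values assumed for illustration, not values read from Nabla, and `adamw_step` is a hypothetical name.

```python
import math


def adamw_step(p, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter; returns new (p, m, v)."""
    m = beta1 * m + (1 - beta1) * g          # first moment (running mean of grads)
    v = beta2 * v + (1 - beta2) * g * g      # second moment (running mean of grad^2)
    m_hat = m / (1 - beta1 ** t)             # bias correction for step t (1-based)
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * weight_decay * p            # decoupled weight decay
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


p, m, v = 1.0, 0.0, 0.0
p, m, v = adamw_step(p, g=0.5, m=m, v=v, t=1)
print(round(p, 5))  # 0.99899
```

Nothing is mutated: the caller rebinds `p`, `m`, and `v` to the returned values, which mirrors why the loop writes `model = optimizer.step()`.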

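For comparability with the compiled run, we use **60 training steps** (each "epoch" is one full-batch step) and record:
- total eager loop time, which includes the evaluation pass every 10 steps
- average milliseconds per eager step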

In [6]:
num_epochs = 60

print(f"\n{'Epoch':<8} {'Loss':<12} {'Accuracy':<10}")
print("-" * 32)

eager_train_start = time.perf_counter()

for epoch in range(num_epochs):
    model.train()        # ensures train mode before optimizer step
    model.zero_grad()    # clear .grad from previous iteration

    # Forward pass
    logits = model(token_ids)
    loss = nb.nn.functional.cross_entropy_loss(logits, labels)

    # Backward pass — fills .grad on all trainable parameters
    loss.backward()

    # Optimizer step — reads .grad, applies AdamW, returns updated model
    model = optimizer.step()

    if (epoch + 1) % 10 == 0:
        model.eval()
        logits_eval = model(token_ids)
        pred_classes = nb.argmax(logits_eval, axis=-1)
        target_classes = nb.Tensor.from_dlpack(labels_np.astype(np.int64))
        correct = nb.equal(pred_classes, target_classes)
        accuracy = nb.mean(nb.cast(correct, nb.DType.float32)).item()
        print(f"{epoch + 1:<8} {loss.item():<12.4f} {accuracy:<10.2%}")

eager_train_elapsed = time.perf_counter() - eager_train_start
eager_train_step_ms = (eager_train_elapsed / max(1, num_epochs)) * 1000.0
print(f"\nEager PyTorch-style training time: {eager_train_elapsed:.4f} s")
print(f"Eager PyTorch-style avg step:     {eager_train_step_ms:.3f} ms/step")


Epoch    Loss         Accuracy  
--------------------------------
10       3.0606       30.67%    
20       2.6096       30.67%    
30       2.2841       30.00%    
40       2.0109       32.00%    
50       1.7453       32.00%    
60       1.5089       32.67%    

Eager PyTorch-style training time: 6.3128 s
Eager PyTorch-style avg step:     105.213 ms/step
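A reference point helps when reading the table above. With 3 classes, a classifier that predicts uniform probabilities has cross-entropy ln 3 ≈ 1.0986, and with roughly balanced classes its accuracy is about 33%. The sketch below computes that baseline in plain Python, assuming the usual softmax cross-entropy against a one-hot target; the exact reduction used by `cross_entropy_loss` is not shown in this notebook, and `cross_entropy` here is a hypothetical helper.

```python
import math


def cross_entropy(logits, onehot):
    """Softmax cross-entropy for one example, using a stable log-sum-exp."""
    mx = max(logits)
    log_z = mx + math.log(sum(math.exp(l - mx) for l in logits))
    return -sum(t * (l - log_z) for l, t in zip(logits, onehot))


chance = cross_entropy([0.0, 0.0, 0.0], [1.0, 0.0, 0.0])
confident = cross_entropy([5.0, 0.0, 0.0], [1.0, 0.0, 0.0])
print(round(chance, 4))  # 1.0986
```

Against this baseline, the final eager loss of 1.5089 is still above the uniform-prediction value and the 32.67% accuracy is near chance, so 60 steps at `lr=1e-3` have not yet fit this task.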


## 6. Compiled Training (Bonus)

`@nb.compile` runs the same training-step function with cached compiled execution when input metadata matches (shape, dtype, sharding, structure).

> **API note:** Inside a compiled function, `value_and_grad` (functional transform) must be used — not `loss.backward()`. The imperative `.backward()` / `.grad` path is for eager execution only.

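How to read the speedup reported in this notebook:
- cached compiled runs replay an already compiled program, which removes most per-step Python overhead
- eager runs keep Python in the step loop
- the baseline is the eager *functional* step (`eager_step`, same math without `@nb.compile`), not the imperative loop from section 5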
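The caching behaviour can be illustrated with a toy decorator that keys on input metadata and counts hits and misses. This is a hypothetical sketch: `toy_compile` only looks at argument type and length and does no compilation at all, whereas `@nb.compile` is described above as keying on shape, dtype, sharding, and structure.

```python
def toy_compile(fn):
    """Cache keyed on input metadata; the 'compiled program' is just fn itself."""
    cache = {}
    stats = {"hits": 0, "misses": 0}

    def wrapper(*args):
        key = tuple((type(a).__name__, len(a)) for a in args)  # stand-in for shape/dtype
        if key in cache:
            stats["hits"] += 1
        else:
            stats["misses"] += 1
            cache[key] = fn  # a real system would trace and compile here
        return cache[key](*args)

    wrapper.stats = stats
    return wrapper


@toy_compile
def total(xs):
    return sum(xs)


for _ in range(60):
    total([1.0, 2.0, 3.0])  # same metadata every call
print(total.stats)  # {'hits': 59, 'misses': 1}

total([1.0, 2.0])  # different length, so a new cache entry is needed
print(total.stats)  # {'hits': 59, 'misses': 2}
```

Sixty calls with identical metadata give one miss followed by 59 hits; changing the input shape forces another miss.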

In [7]:
import time

# Fresh model + functional optimizer state for compiled run
model_c = TransformerClassifier(
    vocab_size=vocab_size, d_model=d_model, num_heads=num_heads,
    num_layers=num_layers, num_classes=num_classes,
    max_len=seq_len, dim_feedforward=64,
)
# eval() BEFORE adamw_init so both share _training=False pytree structure
model_c.eval()
opt_state_c = nb.nn.optim.adamw_init(model_c)


def loss_fn_for_compile(model, tokens, targets):
    logits = model(tokens)
    return nb.nn.functional.cross_entropy_loss(logits, targets)


@nb.compile
def compiled_step(model, opt_state, tokens, targets):
    loss, grads = nb.value_and_grad(loss_fn_for_compile, argnums=0)(
        model, tokens, targets
    )
    model, opt_state = nb.nn.optim.adamw_update(model, grads, opt_state, lr=1e-3)
    return model, opt_state, loss


def eager_step(model, opt_state, tokens, targets):
    loss, grads = nb.value_and_grad(loss_fn_for_compile, argnums=0)(
        model, tokens, targets
    )
    model, opt_state = nb.nn.optim.adamw_update(model, grads, opt_state, lr=1e-3)
    return model, opt_state, loss
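Both step functions follow the functional pattern: compute the loss and gradients with respect to the first argument, then return new state instead of mutating it. The sketch below imitates that pattern in plain Python. Note that this `value_and_grad` is a hypothetical stand-in that uses central finite differences on a flat list of floats, not automatic differentiation, and a plain gradient step replaces `adamw_update`.

```python
def value_and_grad(fn, eps=1e-6):
    """Return a function giving (fn value, numerical gradient w.r.t. params)."""
    def wrapped(params, *args):
        value = fn(params, *args)
        grads = []
        for i in range(len(params)):
            up = params[:i] + [params[i] + eps] + params[i + 1:]
            down = params[:i] + [params[i] - eps] + params[i + 1:]
            grads.append((fn(up, *args) - fn(down, *args)) / (2 * eps))
        return value, grads
    return wrapped


def loss_fn(params, x, y):
    w, b = params
    return (w * x + b - y) ** 2


def step(params, x, y, lr=0.01):
    """Functional update: inputs are left untouched, new params are returned."""
    loss, grads = value_and_grad(loss_fn)(params, x, y)
    new_params = [p - lr * g for p, g in zip(params, grads)]
    return new_params, loss


params = [2.0, 1.0]
new_params, loss = step(params, x=3.0, y=5.0)
print(round(loss, 6), [round(p, 4) for p in new_params])  # 4.0 [1.88, 0.96]
```

The caller rebinds its variables to the returned values, exactly as `model_c, opt_state_c, loss_c = compiled_step(...)` does in the timing loop.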

Run the compiled timing loop.

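For a fair comparison, the total step count matches eager training (`num_epochs = 60`): one first call that includes trace and compile overhead, followed by 59 cached steps that are timed separately.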

In [8]:
# Use the same step count as the eager loop for comparability
n_timed_steps = num_epochs

print("\nCompiled training (functional API inside @nb.compile):")
print(f"{'Step':<8} {'Loss':<12}")
print("-" * 22)

# 1) First compiled call includes trace+compile overhead
compile_start = time.perf_counter()
model_c, opt_state_c, loss_c = compiled_step(model_c, opt_state_c, token_ids, labels)
first_compiled_ms = (time.perf_counter() - compile_start) * 1000.0

# 2) Cached compiled execution timing (same loop length as eager)
cached_start = time.perf_counter()
for step in range(1, n_timed_steps):
    model_c, opt_state_c, loss_c = compiled_step(model_c, opt_state_c, token_ids, labels)
    if (step + 1) % 10 == 0:
        print(f"{step + 1:<8} {loss_c.item():<12.4f}")
cached_elapsed = time.perf_counter() - cached_start
cached_step_ms = (cached_elapsed / max(1, n_timed_steps - 1)) * 1000.0

print("\nCompiled cache stats:", compiled_step.stats)
print(f"First compiled call (trace+compile): {first_compiled_ms:.2f} ms")
print(f"Cached compiled step avg:          {cached_step_ms:.2f} ms")

# 3) Eager functional baseline timing (same math, no @nb.compile)
model_e = TransformerClassifier(
    vocab_size=vocab_size, d_model=d_model, num_heads=num_heads,
    num_layers=num_layers, num_classes=num_classes,
    max_len=seq_len, dim_feedforward=64,
)
model_e.eval()
opt_state_e = nb.nn.optim.adamw_init(model_e)

# one warmup
model_e, opt_state_e, _ = eager_step(model_e, opt_state_e, token_ids, labels)

eager_start = time.perf_counter()
for _ in range(n_timed_steps - 1):
    model_e, opt_state_e, loss_e = eager_step(model_e, opt_state_e, token_ids, labels)
eager_elapsed = time.perf_counter() - eager_start
eager_step_ms = (eager_elapsed / max(1, n_timed_steps - 1)) * 1000.0

speedup = eager_step_ms / max(cached_step_ms, 1e-9)
print(f"Eager functional step avg:         {eager_step_ms:.2f} ms")
print(f"Compiled cached speedup vs eager:  {speedup:.2f}x")


Compiled training (functional API inside @nb.compile):
Step     Loss        
----------------------
10       1.1420      
20       1.0677      
30       1.0248      
40       0.9630      
50       0.8735      
60       0.7395      

Compiled cache stats: CompilationStats(hits=59, misses=1, fallbacks=0, hit_rate=98.3%)
First compiled call (trace+compile): 1047.90 ms
Cached compiled step avg:          10.83 ms
Eager functional step avg:         157.98 ms
Compiled cached speedup vs eager:  14.59x
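The figures above can be combined to see when compilation pays for itself. The sketch below uses only the timings printed in this run (they will differ on other machines) and assumes a simple cost model: one slow first call, then a constant cached step time.

```python
import math

first_call_ms = 1047.90  # trace + compile + first step
cached_ms = 10.83        # average cached compiled step
eager_ms = 157.98        # average eager functional step

speedup = eager_ms / cached_ms

# Compiled total for n steps: first_call_ms + (n - 1) * cached_ms
# Eager total for n steps:    n * eager_ms
# Compiled is ahead once n exceeds (first_call_ms - cached_ms) / (eager_ms - cached_ms).
break_even = math.ceil((first_call_ms - cached_ms) / (eager_ms - cached_ms))

print(f"{speedup:.2f}x")  # 14.59x
print(break_even)         # 8
```

Under these numbers the one-time compile cost is recovered by the eighth step, so the compiled path is ahead for any run longer than a handful of steps.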


## Summary

### PyTorch-Style (Eager) — This Notebook

| Component | API |
|-----------|-----|
| Token embedding | `nb.nn.Embedding(vocab_size, d_model)` |
| Transformer layer | `nb.nn.TransformerEncoderLayer(d_model=..., num_heads=..., dim_feedforward=...)` |
| Fixed buffer | `self.pe = nb.Tensor.from_dlpack(pe_np)` (fixed, not learned) |
| Training mode | `model.train()` / `model.eval()` |
| Clear gradients | `model.zero_grad()` |
| Compute gradients | `loss.backward()` |
| Parameter update | `model = optimizer.step()` |

### JAX-Style (Functional) — See [6b](06b_transformer_jax)

| Concept | API |
|---------|-----|
| Model state | Nested dict pytree |
| Compute loss + grads | `nb.value_and_grad(fn, argnums=0)(params, ...)` |
| Optimizer | `adamw_init` + `adamw_update` |
| Compiled training | `@nb.compile` on functional `value_and_grad` step |