# MLP in BQN on JAX — Polynomial Regression

This notebook is an interactive version of `scripts/simple_mlp_jaxbqn.py`.  
It trains a small MLP (1 → 16 → 1) expressed entirely in BQN, compiled to
JAX, and optimized with an Adam optimizer also written in BQN.

**All math lives in nine BQN expressions** — JAX supplies the
differentiation, JIT compilation, and vectorization.

In [1]:
from pathlib import Path
import sys

def _find_repo_root(start: Path) -> Path:
    """Return the nearest directory, at or above ``start``, that looks like the repo root.

    A directory qualifies when it contains both a ``pyproject.toml`` entry and a
    ``src/bqn_jax`` entry. ``start`` itself is checked before any of its parents.

    Raises:
        RuntimeError: if neither ``start`` nor any of its parents qualifies.
    """
    def _looks_like_root(folder: Path) -> bool:
        has_pyproject = (folder / "pyproject.toml").exists()
        return has_pyproject and (folder / "src" / "bqn_jax").exists()

    search_path = [start, *start.parents]
    match = next((folder for folder in search_path if _looks_like_root(folder)), None)
    if match is None:
        raise RuntimeError("Could not locate the bqn-jax repo root.")
    return match

# Put the in-repo ``src`` directory at the front of sys.path (once) so that the
# ``bqn_jax`` import further down resolves against this checkout.
_repo_root = _find_repo_root(Path.cwd())
_src = _repo_root / "src"
_src_entry = str(_src)
if _src_entry not in sys.path:
    sys.path.insert(0, _src_entry)

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from bqn_jax import ShapePolicy, compile_expression

%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['figure.dpi'] = 100
print(f"JAX version: {jax.__version__}")

JAX version: 0.9.0.1


## The Nine BQN Expressions

| # | Purpose | BQN |
|---|---------|-----|
| E0 | Forward pass (dense + softsign activation + output) | `+´(v×÷⟜(1˙⊸+∘\|)(w(+´∘×)x+b))+bo` |
| E1 | Target polynomial | `((0.2+(0.9×(+´x)))-(0.3×((+´x)⋆2)))+(0.4×((+´x)⋆3))` |
| E2 | MSE loss | `+´×˜(p-y)÷≠y` |
| E3 | RMSE | `√(+´×˜(p-y)÷≠y)` |
| E4 | Adam momentum update | `(beta1×m)+((1-beta1)×g)` |
| E5 | Adam velocity update | `(beta2×v)+((1-beta2)××˜g)` |
| E6 | Bias-corrected momentum | `m÷(1-(beta1⋆t))` |
| E7 | Bias-corrected velocity | `v÷(1-(beta2⋆t))` |
| E8 | Parameter update | `p-(lr×mhat÷((√vhat)+eps))` |

In [2]:
# ── BQN source expressions ──
# Each constant holds one BQN source string; E0..E8 are listed in the order
# they appear in the table above.
# NOTE(review): BQN evaluates right to left, so in E2/E3 the `÷≠y` may bind to
# `(p-y)` before the squaring and the sum — confirm the compiled result is the
# intended mean of squared errors.
E0 = "+´(v×÷⟜(1˙⊸+∘|)(w(+´∘×)x+b))+bo"  # forward pass
E1 = "((0.2 + (0.9 × (+´x))) - (0.3 × ((+´x) ⋆ 2))) + (0.4 × ((+´x) ⋆ 3))"  # target
E2 = "+´×˜(p-y)÷≠y"  # MSE
E3 = "√(+´×˜(p-y)÷≠y)"  # RMSE
E4 = "(beta1×m)+((1-beta1)×g)"  # Adam m-update
E5 = "(beta2×v)+((1-beta2)××˜g)"  # Adam v-update
E6 = "m÷(1-(beta1⋆t))"  # bias-correct m
E7 = "v÷(1-(beta2⋆t))"  # bias-correct v
E8 = "p-(lr×mhat÷((√vhat)+eps))"  # param update

# Echo every expression with its label so the listing sits next to the code.
print("Nine BQN expressions define the entire MLP + optimizer:")
print("\n".join(
    f"  E{number} = {source}"
    for number, source in enumerate((E0, E1, E2, E3, E4, E5, E6, E7, E8))
))

Nine BQN expressions define the entire MLP + optimizer:
  E0 = +´(v×÷⟜(1˙⊸+∘|)(w(+´∘×)x+b))+bo
  E1 = ((0.2 + (0.9 × (+´x))) - (0.3 × ((+´x) ⋆ 2))) + (0.4 × ((+´x) ⋆ 3))
  E2 = +´×˜(p-y)÷≠y
  E3 = √(+´×˜(p-y)÷≠y)
  E4 = (beta1×m)+((1-beta1)×g)
  E5 = (beta2×v)+((1-beta2)××˜g)
  E6 = m÷(1-(beta1⋆t))
  E7 = v÷(1-(beta2⋆t))
  E8 = p-(lr×mhat÷((√vhat)+eps))


## Compile BQN → JAX

In [3]:
# NOTE(review): the saved output of this cell is
# "BQNUnsupportedError: Unsupported in JAX IR backend: call forms beyond
# fold-call lowering" — confirm the installed bqn_jax supports every expression
# below before relying on the later cells.
_STATIC  = ShapePolicy(kind="static")
_DYNAMIC = ShapePolicy(kind="dynamic")


def _ce_s(expr, names):
    """Compile BQN source ``expr`` with argument names ``names`` under the static shape policy."""
    return compile_expression(expr, arg_names=names, shape_policy=_STATIC)


def _ce_d(expr, names):
    """Compile BQN source ``expr`` with argument names ``names`` under the dynamic shape policy."""
    return compile_expression(expr, arg_names=names, shape_policy=_DYNAMIC)


# Model, target and metrics.
F   = _ce_d(E0, ("x", "w", "b", "v", "bo"))          # forward
YF  = _ce_d(E1, ("x",))                               # target
LF  = _ce_s(E2, ("p", "y"))                           # loss
AF  = _ce_d(E3, ("p", "y"))                           # rmse

# Adam building blocks.
MUF = _ce_s(E4, ("m", "g", "beta1"))                  # Adam momentum
VUF = _ce_s(E5, ("v", "g", "beta2"))                  # Adam velocity
MHF = _ce_s(E6, ("m", "beta1", "t"))                  # bias-correct m
VHF = _ce_s(E7, ("v", "beta2", "t"))                  # bias-correct v
PUF = _ce_s(E8, ("p", "mhat", "vhat", "lr", "eps"))   # param update

# Vectorize over batch dimension: only x is mapped for the forward pass, the
# four parameter arguments are shared across the batch (in_axes=None).
Y_MAP    = YF.vmap(in_axes=0, out_axes=0)
PRED_MAP = F.vmap(in_axes=(0, None, None, None, None), out_axes=0)

print("All BQN expressions compiled to JAX successfully.")

BQNUnsupportedError: Unsupported in JAX IR backend: call forms beyond fold-call lowering

## Generate Data

The target polynomial is:

$$y = 0.2 + 0.9x - 0.3x^2 + 0.4x^3$$

expressed in BQN as `E1`.

In [None]:
SEED = 0
N_SAMPLES = 2048

# Draw inputs uniformly from [-1.5, 1.5) and evaluate the target expression on
# each row via the batched target function.
k = jax.random.PRNGKey(SEED)
x_all = jax.random.uniform(k, (N_SAMPLES, 1), minval=-1.5, maxval=1.5, dtype=jnp.float32)
y_all = Y_MAP(x_all).astype(jnp.float32)

# 80/20 split by position: the first `n` rows train, the remainder validate.
n = int(0.8 * N_SAMPLES)
xt, yt = x_all[:n], y_all[:n]
xv, yv = x_all[n:], y_all[n:]

print(f"Training: {xt.shape[0]} samples  |  Validation: {xv.shape[0]} samples")

# Plot the data
fig, ax = plt.subplots()
ax.scatter(np.array(xt[:, 0]), np.array(yt), s=4, alpha=0.3, label='Train')
ax.scatter(np.array(xv[:, 0]), np.array(yv), s=4, alpha=0.3, color='orange', label='Val')

# Dense, evenly spaced grid so the target is drawn as a smooth curve.
xs = np.linspace(-1.5, 1.5, 300).reshape(-1, 1)
ys = np.array(Y_MAP(jnp.array(xs, dtype=jnp.float32)))
ax.plot(xs[:, 0], ys, 'k-', lw=2, label='Target polynomial')
ax.set_xlabel('x'); ax.set_ylabel('y')
ax.set_title('Target: $y = 0.2 + 0.9x - 0.3x^2 + 0.4x^3$')
ax.legend(); ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Initialize Weights & Adam State

In [None]:
# Separate PRNG keys for the two randomly initialised weight arrays.
k1, k2 = jax.random.split(jax.random.PRNGKey(SEED + 1))

HIDDEN = 16

# Parameter tuple, in the order the compiled forward pass takes them after x.
# w: (1, 16), b: (16,), v: (16,), bo: scalar
w = (
    0.2 * jax.random.normal(k1, (1, HIDDEN), dtype=jnp.float32),   # input weights
    jnp.zeros((HIDDEN,), dtype=jnp.float32),                       # hidden bias
    0.2 * jax.random.normal(k2, (HIDDEN,), dtype=jnp.float32),     # output weights
    jnp.asarray(0.0, dtype=jnp.float32),                           # output bias
)

# Adam state: zero first/second moments shaped like each parameter leaf, and a
# step counter that starts at 1 so the bias-correction denominators are non-zero.
m_state = jax.tree_util.tree_map(jnp.zeros_like, w)
v_state = jax.tree_util.tree_map(jnp.zeros_like, w)
t = jnp.asarray(1.0, dtype=jnp.float32)

# Hyperparameters (kept as float32 scalars so they match the parameter dtype)
LR    = jnp.asarray(0.02, dtype=jnp.float32)
BETA1 = jnp.asarray(0.9,  dtype=jnp.float32)
BETA2 = jnp.asarray(0.999, dtype=jnp.float32)
EPS   = jnp.asarray(1e-8, dtype=jnp.float32)

# The hidden activation ÷⟜(1˙⊸+∘|) is x / (1 + |x|), i.e. softsign, so label
# it as such rather than "ReLU-like".
print(f"Architecture: 1 → {HIDDEN} (BQN softsign) → 1")
n_params = sum(p.size for p in jax.tree_util.tree_leaves(w))
print(f"Parameters: {n_params}")
print(f"Optimizer: Adam (lr={float(LR)}, β₁={float(BETA1)}, β₂={float(BETA2)})")

## Training Loop

In [None]:
def pred(q, xb):
    """Evaluate the batched forward pass on ``xb`` using the parameter tuple ``q``.

    ``q`` is indexed positionally: input weights, hidden bias, output weights,
    output bias.
    """
    w_in, b_hidden, w_out, b_out = q[0], q[1], q[2], q[3]
    return PRED_MAP(xb, w_in, b_hidden, w_out, b_out)

def loss(q, xb, yb):
    """Return the compiled E2 loss between the predictions for ``xb`` and the targets ``yb``."""
    predictions = pred(q, xb)
    return LF(predictions, yb)

@jax.jit
def step(q, m_acc, v_acc, step_t, xb, yb):
    """Apply one Adam update to the parameter tuple ``q`` on the batch ``(xb, yb)``.

    Returns ``(params, m, v, step_t + 1.0, loss)``, where ``loss`` is the value
    at the incoming parameters (before the update is applied).
    """
    tmap = jax.tree_util.tree_map
    loss_val, grads = jax.value_and_grad(loss)(q, xb, yb)

    # Moment estimates (E4, E5), applied leaf by leaf.
    new_m = tmap(lambda m_leaf, g_leaf: MUF(m_leaf, g_leaf, BETA1), m_acc, grads)
    new_v = tmap(lambda v_leaf, g_leaf: VUF(v_leaf, g_leaf, BETA2), v_acc, grads)

    # Bias correction (E6, E7) uses the step counter as passed in.
    m_corrected = tmap(lambda m_leaf: MHF(m_leaf, BETA1, step_t), new_m)
    v_corrected = tmap(lambda v_leaf: VHF(v_leaf, BETA2, step_t), new_v)

    # Parameter update (E8).
    new_q = tmap(
        lambda p_leaf, mh_leaf, vh_leaf: PUF(p_leaf, mh_leaf, vh_leaf, LR, EPS),
        q, m_corrected, v_corrected,
    )
    return new_q, new_m, new_v, step_t + 1.0, loss_val

def rmse(q, xb, yb):
    """Return the compiled E3 metric for parameters ``q`` on ``(xb, yb)`` as a Python float."""
    metric = AF(pred(q, xb), yb)
    return float(metric)

EPOCHS = 1000
LOG_EVERY = 50

train_losses, val_losses = [], []
train_rmses, val_rmses = [], []
snapshots = {}  # epoch -> prediction on grid

xs_grid = jnp.linspace(-1.5, 1.5, 300).reshape(-1, 1).astype(jnp.float32)

# The validation loss is evaluated on every epoch; compile it once instead of
# re-running the un-jitted vmapped forward pass eagerly 1000 times.
val_loss_fn = jax.jit(loss)

# NOTE: the loop rebinds w / m_state / v_state / t, so re-running this cell
# resumes from the already-trained state (while the history lists above are
# reset). Re-run the initialisation cell first for a fresh training run.
for epoch in range(EPOCHS):
    w, m_state, v_state, t, step_loss = step(w, m_state, v_state, t, xt, yt)
    tl = float(step_loss)                 # training loss before this update
    vl = float(val_loss_fn(w, xv, yv))    # validation loss after this update
    train_losses.append(tl)
    val_losses.append(vl)

    # RMSE and a prediction snapshot are only recorded on logging epochs
    # (every LOG_EVERY epochs, plus the final one).
    if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
        tr = rmse(w, xt, yt)
        vr = rmse(w, xv, yv)
        train_rmses.append(tr)
        val_rmses.append(vr)
        snapshots[epoch] = np.array(pred(w, xs_grid))
        print(f"epoch {epoch:4d}  loss={tl:.6f}  train_rmse={tr:.4f}  val_rmse={vr:.4f}")

print("\nTraining complete!")

## Loss & RMSE Curves

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: per-epoch MSE for both splits, on a log scale.
ax1.plot(train_losses, label='Train MSE', color='#2563eb', lw=1.2)
ax1.plot(val_losses, label='Val MSE', color='#dc2626', lw=1.2, alpha=0.7)
ax1.set(xlabel='Epoch', ylabel='MSE', title='Loss Curves', yscale='log')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right panel: RMSE at the logged epochs. The snapshot keys are exactly the
# epochs at which the RMSE lists were appended to.
rmse_epochs = sorted(snapshots)
ax2.plot(rmse_epochs, train_rmses, 'o-', ms=4, label='Train RMSE', color='#2563eb')
ax2.plot(rmse_epochs, val_rmses, 's-', ms=4, label='Val RMSE', color='#dc2626')
ax2.set(xlabel='Epoch', ylabel='RMSE', title='RMSE Over Training')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## MLP Fit Progression

Watch the network learn the polynomial shape.

In [None]:
# Target values on the evaluation grid.
ys_true = np.array(Y_MAP(xs_grid))
grid_x = np.array(xs_grid[:, 0])

fig, ax = plt.subplots(figsize=(12, 7))
ax.scatter(np.array(xt[:, 0]), np.array(yt), s=4, alpha=0.15, color='gray', label='Train data')
ax.plot(grid_x, ys_true, 'k--', lw=2, alpha=0.6, label='Target polynomial')

# One curve per stored snapshot, coloured from early to late along the plasma map.
snapshot_epochs = sorted(snapshots)
colors = plt.cm.plasma(np.linspace(0.1, 0.9, len(snapshots)))
for ep, color in zip(snapshot_epochs, colors):
    ax.plot(grid_x, snapshots[ep], color=color, lw=1.5, alpha=0.8, label=f'Epoch {ep}')

ax.set(xlabel='x', ylabel='y',
       title='MLP Fit Progression — BQN Forward Pass + Adam Optimizer')
ax.legend(fontsize=7, ncol=2, loc='upper left')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Final Prediction vs Target

In [None]:
# Prediction of the current parameters on the evaluation grid.
y_final = np.array(pred(w, xs_grid))
grid_x = np.array(xs_grid[:, 0])
target_flat = ys_true.ravel()
final_flat = y_final.ravel()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Overlay: validation points, target curve, model curve, and the gap between them.
ax1.scatter(np.array(xv[:, 0]), np.array(yv), s=10, alpha=0.4, color='steelblue', label='Val data')
ax1.plot(grid_x, ys_true, 'k-', lw=2, label='Target')
ax1.plot(grid_x, y_final, 'r-', lw=2.5, label='MLP prediction')
ax1.fill_between(grid_x, target_flat, final_flat, alpha=0.15, color='red')
final_vr = rmse(w, xv, yv)
ax1.set(title=f'Final Fit — Val RMSE = {final_vr:.4f}', xlabel='x', ylabel='y')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Residual plot: signed error of the model relative to the target on the grid.
residuals = final_flat - target_flat
ax2.plot(grid_x, residuals, 'r-', lw=1.5)
ax2.axhline(0, color='k', ls='--', lw=1)
ax2.fill_between(grid_x, residuals, alpha=0.2, color='red')
ax2.set(xlabel='x', ylabel='Residual', title='Residuals (MLP − Target)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Hidden Unit Activations

Peek inside the trained network to see what each of the 16 hidden units learned.

In [None]:
# Compute pre- and post-activation for each hidden unit
W_in, bias, W_out = w[:3]   # the output bias is not needed for this plot
xs_np = np.array(xs_grid)   # (300, 1)

# W_in is stored as (1, 16), so it is multiplied without a transpose:
# (300, 1) @ (1, 16) -> (300, 16). Transposing it to (16, 1) makes the inner
# dimensions disagree and raises a shape error.
# NOTE(review): this reproduces x·w + b. Under BQN's right-to-left evaluation
# the `x+b` in E0 may bind before the weight multiply, i.e. w·(x + b) —
# confirm against the compiled forward pass.
pre_act = xs_np @ np.array(W_in) + np.array(bias)  # (300, 16)
post_act = pre_act / (1 + np.abs(pre_act))  # BQN's ÷⟜(1˙⊸+∘|)

# One panel per hidden unit (4 × 4 grid = 16 units).
fig, axes = plt.subplots(4, 4, figsize=(14, 10), sharex=True)
x_np = np.array(xs_grid[:, 0])
w_out_np = np.array(W_out)

for i, ax in enumerate(axes.flat):
    # Colour encodes the sign of the unit's output weight.
    color = 'tab:blue' if w_out_np[i] >= 0 else 'tab:red'
    ax.plot(x_np, post_act[:, i], color=color, lw=1.5)
    ax.axhline(0, color='gray', ls=':', lw=0.5)
    ax.set_title(f'Unit {i} (w={w_out_np[i]:.2f})', fontsize=8)
    ax.tick_params(labelsize=6)
    ax.grid(True, alpha=0.2)

fig.suptitle('Hidden Unit Activations (blue=positive output weight, red=negative)', fontsize=12)
plt.tight_layout()
plt.show()

## Weight Distribution

In [None]:
# Flatten every parameter leaf into one 1-D array for the histogram.
param_leaves = jax.tree_util.tree_leaves(w)
all_params = np.concatenate([np.array(leaf).ravel() for leaf in param_leaves])

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(all_params, bins=40, color='#6366f1', edgecolor='white', alpha=0.8)
ax.axvline(0, color='k', ls='--', lw=1)
ax.set(xlabel='Weight value', ylabel='Count',
       title=f'Weight Distribution ({len(all_params)} parameters)')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Closing text summary; the validation RMSE is recomputed from the current parameters.
print("\nSummary")
print("  Nine BQN expressions compiled to JAX")
print(f"  {len(all_params)} trainable parameters")
print(f"  Final val RMSE: {rmse(w, xv, yv):.4f}")
print("  All math expressed in BQN, differentiated and JIT-compiled by JAX.")