# Deep Legendre Transform vs Classical Grid Methods

This notebook implements and compares the **Deep Legendre Transform (DLT)** algorithm with classical grid-based methods for computing convex conjugates (Legendre-Fenchel transforms).

## What this code does:

### Methods Compared:
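- **DLT (Deep Legendre Transform)**: a neural network `g` is trained to approximate the conjugate f* through the implicit formulation `g(∇f(x)) ≈ ⟨x, ∇f(x)⟩ - f(x)`. By the Fenchel-Young equality, the right-hand side equals f*(∇f(x)) exactly, so the regression target needs no inner maximization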
  - Architectures: ResNet, MLP, ICNN (Input Convex Neural Networks)
- **Classical Methods**:
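  - **Lucet's algorithm**: nested 1D linear-time Legendre transforms (lower convex hull plus a sweep over sorted slopes), applied one axis at a time, with O(dN^d) complexity for N points per axis
  - **Direct method**: brute-force evaluation of every slope against every grid point, with O(N^(2d)) complexity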
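
Both classical baselines compute the discrete conjugate `u*(s_j) = max_i (s_j x_i - u_i)`. The NumPy sketch below illustrates the idea. It is not the script's code, and `direct_1d_sketch`, `llt_1d_sketch` and `nested_lucet_2d_sketch` are illustrative names. The 1D transform builds the lower convex hull of the samples and sweeps the sorted slopes across the hull edges, following the same idea as the script's `llt_1d`. The 2D version applies it one axis at a time. Every pass after the first runs on the negated partial result, because the remaining `max over x1 of (s1 x1 + G)` is the conjugate of `-G`.

```python
import numpy as np

def direct_1d_sketch(x, u, s):
    """Brute-force discrete conjugate: u*(s_j) = max_i (s_j * x_i - u_i), O(N * M) work."""
    return np.max(s[:, None] * x[None, :] - u[None, :], axis=1)

def llt_1d_sketch(x, u, s):
    """Linear-time Legendre transform; x and s must both be sorted ascending."""
    # 1) lower convex hull of the points (x_i, u_i): only hull vertices can attain the max
    hx, hu = [], []
    for xi, ui in zip(x, u):
        while len(hx) >= 2 and (hu[-1] - hu[-2]) * (xi - hx[-1]) >= (ui - hu[-1]) * (hx[-1] - hx[-2]):
            hx.pop()
            hu.pop()
        hx.append(xi)
        hu.append(ui)
    hx, hu = np.asarray(hx), np.asarray(hu)
    # 2) hull edge slopes bound the range of s for which each vertex is the maximizer
    edges = np.concatenate(([-np.inf], np.diff(hu) / np.diff(hx), [np.inf]))
    # 3) a single merge-style sweep over the sorted slopes
    out, k = np.empty(len(s)), 0
    for j, sj in enumerate(s):
        while sj > edges[k + 1]:
            k += 1
        out[j] = sj * hx[k] - hu[k]
    return out

def nested_lucet_2d_sketch(x1, x2, U, s1, s2):
    """2D conjugate via two passes of the 1D transform, with U[i, j] = u(x1[i], x2[j])."""
    # pass 1 (axis 1): G[i, l] = max_j (s2[l] * x2[j] - U[i, j])
    G = np.stack([llt_1d_sketch(x2, U[i], s2) for i in range(len(x1))])
    # pass 2 (axis 0): max_i (s1[k] * x1[i] + G[i, l]) is the conjugate of -G, hence the sign flip
    return np.stack([llt_1d_sketch(x1, -G[:, l], s1) for l in range(len(s2))], axis=1)

x = np.linspace(-3.0, 3.0, 10)
s = np.linspace(-3.0, 3.0, 10)
u = 0.5 * x**2
print(np.max(np.abs(llt_1d_sketch(x, u, s) - direct_1d_sketch(x, u, s))))

X1, X2 = np.meshgrid(x, x, indexing="ij")
U = 0.5 * (X1**2 + X2**2)
brute = np.max(s[:, None, None, None] * X1 + s[None, :, None, None] * X2 - U, axis=(2, 3))
print(np.max(np.abs(nested_lucet_2d_sketch(x, x, U, s, s) - brute)))
```

Both printed differences should be at floating-point round-off level, because the hull sweep returns the exact discrete maximum rather than an approximation of it.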

### Test Functions:
- **Quadratic**: `f(x) = 0.5 ∑x_i²` (closed-form conjugate available)
- **Negative Log**: `f(x) = -∑log(x_i)` (closed-form conjugate available)
- **Negative Entropy**: `f(x) = ∑x_i log(x_i)` (closed-form conjugate available)



| **Function** | **$u(x)$** | **Domain $C$** | **$u^*(y)$** | **$D = \nabla u(C)$** |
|--------------|------------|----------------|--------------|----------------------|
| Quadratic | $\frac{1}{2} \sum x_i^2$ | $\mathcal{N}(0,1)^d$ | $\frac{1}{2} \sum y_i^2$ | $\mathcal{N}(0,1)^d$ |
| Neg-Log | $-\sum \log(x_i)$ | $\exp(U[-2.3,2.3])^d$ | $-d - \sum \log(-y_i)$ | $\{-1/x : x\in C\} \approx [-10,-0.1]^d$ |
| Neg-Entropy | $\sum x_i \log x_i$ | $\exp(U[-2.3,2.3])^d$ | $\sum \exp(y_i - 1)$ | $\{\log x + 1 : x\in C\} \approx [-1.3,3.3]^d$ |
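
The closed forms in the table can be checked numerically through the Fenchel-Young equality `u*(∇u(x)) = ⟨x, ∇u(x)⟩ - u(x)`, which is also the regression target of the implicit DLT loss (`loss_impl` in the script). The following is a small NumPy sketch that assumes d = 5 and the sampling distributions from the Domain column. `TABLE_FUNS` and `fenchel_young_gap` are illustrative names.

```python
import numpy as np

# (u, grad u, closed-form u*) for the three rows of the table
TABLE_FUNS = {
    "quadratic": (
        lambda x: 0.5 * np.sum(x**2, -1),
        lambda x: x,
        lambda y: 0.5 * np.sum(y**2, -1),
    ),
    "neg_log": (
        lambda x: -np.sum(np.log(x), -1),
        lambda x: -1.0 / x,
        lambda y: -y.shape[-1] - np.sum(np.log(-y), -1),
    ),
    "neg_entropy": (
        lambda x: np.sum(x * np.log(x), -1),
        lambda x: np.log(x) + 1.0,
        lambda y: np.sum(np.exp(y - 1.0), -1),
    ),
}

def fenchel_young_gap(name, x):
    """Max |(<x, y> - u(x)) - u*(y)| over the rows of x, with y = grad u(x)."""
    u, grad_u, u_star = TABLE_FUNS[name]
    y = grad_u(x)
    target = np.sum(x * y, -1) - u(x)  # implicit DLT target <x, grad u(x)> - u(x)
    return float(np.max(np.abs(target - u_star(y))))

rng = np.random.default_rng(0)
d, n = 5, 1000
x_quad = rng.standard_normal((n, d))                 # C = N(0,1)^d
x_pos = np.exp(rng.uniform(-2.3, 2.3, size=(n, d)))  # C = exp(U[-2.3, 2.3])^d
for name, x in [("quadratic", x_quad), ("neg_log", x_pos), ("neg_entropy", x_pos)]:
    print(f"{name:<12} max gap = {fenchel_young_gap(name, x):.2e}")
```

Each gap should be at round-off level. A sign or constant error in a closed-form u*, such as dropping the `-d` in the Neg-Log row, would instead show up as an O(1) gap.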


### Key Features:
- Early stopping with patience for DLT training
- Optional log-uniform sampling for positive-domain functions
- Memory and compute feasibility checks for classical methods
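- Evaluation metrics (RMSE, relative L2 error, max absolute error) against the closed-form conjugate on a shared evaluation set drawn from D
- Performance comparison across dimensions (default d ∈ {2, 3, 4, 5, 6, 8, 10, 20, 50}; larger values such as 100 or 200 can be passed via `--dims`)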
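
Log-uniform sampling helps because, for the negative entropy, the gradient map is y = log x + 1. If log x is uniform on [-2.3, 2.3], then y is exactly uniform on D = [-1.3, 3.3], whereas uniform x piles most training targets near the top of D. The sketch below measures how the gradient samples cover D under both schemes. `sample_x` and `coverage_of_D` are illustrative helpers that mimic the idea of the script's `make_samplers`, not its code.

```python
import numpy as np

C_LO, C_HI = np.exp(-2.3), np.exp(2.3)  # neg_entropy domain C from the table
D_LO, D_HI = -1.3, 3.3                   # D = grad u(C) = log(C) + 1

def sample_x(rng, n, loguniform):
    """Uniform x on [C_LO, C_HI], or log-uniform x (uniform in log x)."""
    if loguniform:
        return np.exp(rng.uniform(np.log(C_LO), np.log(C_HI), n))
    return rng.uniform(C_LO, C_HI, n)

def coverage_of_D(y, bins=4):
    """Fraction of gradient samples falling in each of `bins` equal-width slices of D."""
    counts, _ = np.histogram(y, bins=bins, range=(D_LO, D_HI))
    return counts / len(y)

rng = np.random.default_rng(0)
for loguniform in (False, True):
    y = np.log(sample_x(rng, 100_000, loguniform)) + 1.0  # y = grad u(x) for neg_entropy
    label = "log-uniform x" if loguniform else "uniform x    "
    print(label, np.round(coverage_of_D(y), 3))
```

With uniform x, roughly 70% of the samples land in the top quarter of D and only about 2% in the bottom quarter. Log-uniform x gives about 25% per quarter. For `neg_log` the gradient is y = -1/x, so log-uniform x makes -y log-uniform rather than uniform. Coverage of D therefore improves but is not flat.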

### Output:
- CSV table with detailed performance metrics
- Log-log plots comparing accuracy vs computation time and memory usage
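- Shows that DLT keeps running in dimensions where the classical grids exceed their memory and compute caps (Lucet is attempted only for d ≤ 12, Direct only for d ≤ `--def_dim_max`)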
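
The scaling claim comes down to simple arithmetic on the script's own memory estimates. The following rough sketch assumes 10 grid points per axis (the script's default `Nx = Ns = 10`) and the default `RESNET:128,128` architecture. `lucet_peak_mib`, `resnet_param_count` and `dlt_mib` are illustrative helpers that re-derive the formulas in `lucet_mem_mb_est` and `estimate_dlt_mem_mb`.

```python
def lucet_peak_mib(d, n=10):
    """Peak-memory estimate in the style of lucet_mem_mb_est: two float64 arrays with n**d entries."""
    return 2 * n**d * 8 / 2**20

def resnet_param_count(d, width=128, n_blocks=2):
    """Parameters of the script's ResNet: Dense(d->w), n_blocks ResBlocks of two Dense(w->w), Dense(w->1)."""
    return (d * width + width) + n_blocks * 2 * (width * width + width) + (width + 1)

def dlt_mib(d, batch=64, width=128, n_blocks=2):
    """Estimate in the style of estimate_dlt_mem_mb: float32 params, grads, two Adam moments, one batch of activations."""
    p_bytes = 4 * resnet_param_count(d, width, n_blocks)
    act_bytes = 4 * batch * (d + width * n_blocks + 1)
    return (4 * p_bytes + act_bytes) / 2**20

for d in (2, 4, 6, 8, 10, 12, 20, 50, 200):
    print(f"d={d:>3}  Lucet grid ~ {lucet_peak_mib(d):.3g} MiB   DLT ResNet ~ {dlt_mib(d):.2f} MiB")
```

At d = 2 with batch 60, the DLT estimate is about 1.07 MiB, in line with the 1.1 MiB in the sample output below. The grid estimate fits the default 3 GiB `--mem_gb_cap` at d = 8 (about 1.5 GiB) but not at d = 10 (about 149 GiB). Meanwhile, the network grows only by 128 parameters per extra input dimension.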



#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DLT (implicit) vs Classical (Nested Lucet + Direct) on SAME D = ∇u(C)

What’s included
---------------
- Fix for the original math-domain error (no scalar-boolean `jnp.where`).
- Optional **log-uniform x-sampling** for positive-domain functions
  (neg_log, neg_entropy) via --x_loguniform / --no_x_loguniform.
- Fast Nested Lucet (`lucet_nd_fast`) with preallocated buffers (no nditer).
- DLT training with streaming minibatches + early stopping (Stopper).
- Evaluation enabled for quadratic, neg_log, **and neg_entropy**.
- CSV output and simple log-log plots.

Examples
--------
# Paper-like domains, 10-per-dim classical grids, progress for Lucet
python dlt_vs_classical.py --dims 2 3 4 5 6 --domain_profile paper --lucet_progress

# Disable log-uniform x-sampling (use uniform x on positive domains)
python dlt_vs_classical.py --dims 8 10 20 --no_x_loguniform

# Force a specific DLT architecture and steps
python dlt_vs_classical.py --dims 10 20 --arch RESNET:256,256 --steps 40000
"""

import os, sys, time, math, argparse, subprocess
import numpy as np
from typing import Callable, List, Tuple, Dict, Sequence
from functools import partial
from itertools import product

# ---- minimal dependency check (Colab-safe) ----
def _ensure(pkgs):
    import importlib
    miss=[]
    for p in pkgs:
        try: importlib.import_module(p)
        except Exception: miss.append(p)
    if miss:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + miss)

_ensure(["flax","optax","pandas","matplotlib"])

import jax, jax.numpy as jnp, optax
from jax import random
from flax import linen as nn
from flax.training import train_state
import pandas as pd
import matplotlib.pyplot as plt

# ========================= 1) Test functions & domains =========================
f_quad    = lambda x: 0.5*jnp.sum(x**2, -1)
grad_quad = lambda x: x
fst_quad  = lambda y: 0.5*jnp.sum(y**2, -1)

f_nlog    = lambda x:-jnp.sum(jnp.log(x), -1)                # dom x>0
grad_nlog = lambda x:-1.0/x
fst_nlog  = lambda y:-jnp.sum(jnp.log(-y), -1) - y.shape[-1] # dom y<0

f_nent    = lambda x:jnp.sum(x*jnp.log(x), -1)               # dom x>0
grad_nent = lambda x:jnp.log(x)+1.0
fst_nent  = lambda y:jnp.sum(jnp.exp(y-1.0), -1)

# ---- Domain profiles ----------------------------------------------------------
DOMAINS_PAPER = {
    "quadratic":   dict(C=(-3.0,  3.0), D=(-3.0,  3.0)),
    "neg_log":     dict(C=( 0.2, 10.0), D=(-5.0, -0.1)),  # D = {-1/x : x in C}
    "neg_entropy": dict(C=(math.exp(-2.3), math.exp(2.3)), D=(-1.3, 3.3)),
}
DOMAINS_WIDE = {
    "quadratic":   dict(C=(-3.0,  3.0), D=(-3.0,  3.0)),
    "neg_log":     dict(C=( 0.1, 10.0), D=(-10.0,-0.1)),
    "neg_entropy": dict(C=(math.exp(-2.3), math.exp(2.3)), D=(-1.3, 3.3)),
}

FUNS = {
    "quadratic": dict(f=f_quad, g=grad_quad, fst=fst_quad, printable="Quadratic"),
    "neg_log": dict(f=f_nlog, g=grad_nlog, fst=fst_nlog, printable="Neg. Log"),
    "neg_entropy": dict(f=f_nent, g=grad_nent, fst=fst_nent, printable="Neg. Entropy"),
}
# Evaluate DLT error wherever u* is closed-form (all three here)
EVAL_DLT_FUNS = {"quadratic", "neg_log", "neg_entropy"}

# ========================= 2) DLT model & training ============================
def _act(name: str) -> Callable:
    n=name.lower()
    if n=="relu": return nn.relu
    if n=="gelu": return jax.nn.gelu
    if n=="softplus": return jax.nn.softplus
    raise ValueError(f"Unknown activation {name}")

class DensePos(nn.Module):
    features:int
    use_bias:bool=True
    @nn.compact
    def __call__(self,x):
        W=nn.softplus(self.param("rawW", nn.initializers.lecun_normal(),
                                 (x.shape[-1], self.features)))
        y=x@W
        if self.use_bias: y += self.param("b", nn.initializers.zeros, (self.features,))
        return y

class ICNN(nn.Module):
    hidden:Sequence[int]
    act: Callable = nn.softplus
    @nn.compact
    def __call__(self,x):
        z=jnp.zeros((x.shape[0],1))
        for h in self.hidden:
            z=self.act(DensePos(h)(z) + nn.Dense(h)(x))
        out=DensePos(1, use_bias=False)(z) + nn.Dense(1, use_bias=False)(x)
        return jnp.squeeze(out, -1)

class MLP(nn.Module):
    hidden:Sequence[int]
    act: Callable = nn.relu
    @nn.compact
    def __call__(self,x):
        for h in self.hidden: x=self.act(nn.Dense(h)(x))
        return jnp.squeeze(nn.Dense(1)(x), -1)

class ResBlock(nn.Module):
    features:int
    act: Callable
    @nn.compact
    def __call__(self, x):
        h = self.act(nn.Dense(self.features)(x))
        h = nn.Dense(self.features)(h)
        if x.shape[-1] != self.features:
            x = nn.Dense(self.features, use_bias=False)(x)
        return self.act(h + x)

class ResNet(nn.Module):
    hidden:Sequence[int]
    act: Callable = jax.nn.gelu
    @nn.compact
    def __call__(self, x):
        assert len(self.hidden) >= 1
        h = self.act(nn.Dense(self.hidden[0])(x))
        for w in self.hidden:
            h = ResBlock(w, act=self.act)(h)
        out = nn.Dense(1)(h)
        return jnp.squeeze(out, -1)

def parse_hidden(s): return tuple(int(v) for v in s.split(",") if v)

class State(train_state.TrainState): ...
def schedule(lr): return optax.exponential_decay(lr, 20_000, 0.5, staircase=True)
def new_state(rng, model, d, lr):
    params = model.init(rng, jnp.zeros((1,d), jnp.float32))["params"]
    return State.create(apply_fn=model.apply, params=params, tx=optax.adam(schedule(lr)))

# implicit DLT loss: g(∇u(x)) ≈ <x,∇u(x)> − u(x)
def loss_impl(p,af,x,f,g):
    y=g(x); target=jnp.sum(x*y, -1) - f(x); pred=af({"params":p}, y)
    return jnp.mean((pred - target)**2)

@partial(jax.jit, static_argnums=(2,3))
def step_impl(st, x, f, g):
    l,gr=jax.value_and_grad(loss_impl)(st.params, st.apply_fn, x, f, g)
    return st.apply_gradients(grads=gr), l

def count_param_bytes(params):
    leaves = jax.tree_util.tree_leaves(params)
    return sum(np.prod(l.shape) * np.dtype(l.dtype).itemsize for l in leaves)

def estimate_dlt_mem_mb(params, d, bs, model_name, hidden):
    pbytes = count_param_bytes(params); adam=2*pbytes; grads=pbytes
    act_dim = d + sum(hidden) + 1
    act_bytes = bs * act_dim * 4  # float32
    return (pbytes + adam + grads + act_bytes) / (1024**2)

def clamp(v, lo, hi): return max(lo, min(hi, v))
def auto_steps_from_dim(d: int) -> int:
    base = max(10_000, int(round(1000 * d * math.log(max(d, 2)))))
    if d <= 10:   return clamp(base, 5_000, 20_000)
    if d <= 100:  return clamp(base, 20_000, 100_000)
    return max(base, 100_000)
def auto_patience_for_steps(steps: int, d: int) -> int:
    if d <= 10:   return max(3000, min(steps//5, 10_000))
    if d <= 100:  return max(5000, min(steps//4, 20_000))
    return max(10_000, min(steps//3, 30_000))
def auto_samples_and_batch(d: int, k_train: int) -> Tuple[int,int]:
    k = int(clamp(k_train, 100, 1000))
    N = max(d, k * d)
    B = max(16, min(64, max(1, N // 10)))
    return N, B

class Stopper:
    def __init__(self, pat:int, tol:float=1e-6):
        self.best=float("inf"); self.pat=int(pat); self.tol=float(tol)
        self.cnt=0; self.bp=None
    def update(self, loss, params):
        lv=float(loss)
        if lv + self.tol < self.best:
            self.best, self.cnt, self.bp = lv, 0, params
        else:
            self.cnt += 1
        return self.cnt >= self.pat or self.best < self.tol
    def res(self): return self.best, self.bp

def train_dlt_dataset(model_maker, d, f, g, samp_x, steps, lr, patience, seed, N, bs):
    key0 = random.PRNGKey(seed)
    st   = new_state(key0, model_maker(), d, lr)
    stop = Stopper(patience)
    t0   = time.perf_counter()
    for i in range(int(steps)):
        mb = samp_x(random.fold_in(key0, i), (int(bs), d))
        # `loss` is evaluated at the params *before* this update, so record those params
        new_st, loss = step_impl(st, mb, f, g)
        if stop.update(loss, st.params): break
        st = new_st
    best_loss, best_params = stop.res()
    if best_params is None: best_params = st.params
    return best_params, time.perf_counter() - t0, int(bs)

# ========================= 3) Classical: fast Lucet & Direct ==================
def llt_1d(x,u,s):
    x,u,s = map(np.asarray,(x,u,s))
    hx,hu=np.empty_like(x),np.empty_like(u); h=0
    for xi,ui in zip(x,u):
        while h>=2 and (hu[h-1]-hu[h-2])*(xi-hx[h-1]) >= (ui-hu[h-1])*(hx[h-1]-hx[h-2]):
            h-=1
        hx[h],hu[h]=xi,ui; h+=1
    hx,hu=hx[:h],hu[:h]
    edge=np.concatenate(([-np.inf],np.diff(hu)/np.diff(hx),[np.inf]))
    out,k=np.empty_like(s),0
    for j,sj in enumerate(s):
        while sj>edge[k+1]: k+=1
        out[j]=sj*hx[k]-hu[k]
    return out

def _llt_1d_inplace_signed(x, u, s, hx, hu, edge, out, negate_u: bool):
    n = x.shape[0]; m = s.shape[0]
    h = 0
    for i in range(n):
        xi = x[i]
        ui = -u[i] if negate_u else u[i]
        while h >= 2 and (hu[h-1]-hu[h-2])*(xi-hx[h-1]) >= (ui-hu[h-1])*(hx[h-1]-hx[h-2]):
            h -= 1
        hx[h] = xi
        hu[h] = ui
        h += 1
    edge[0] = -np.inf
    for i in range(h-1):
        edge[i+1] = (hu[i+1] - hu[i]) / (hx[i+1] - hx[i])
    edge[h] = np.inf
    k = 0
    for j in range(m):
        sj = s[j]
        while sj > edge[k+1]:
            k += 1
        out[j] = sj * hx[k] - hu[k]

def lucet_nd_fast(x_arrs, f, s_arrs, progress=False):
    d = len(x_arrs)
    X = np.meshgrid(*x_arrs, indexing='ij', sparse=False)
    V = f(*X)  # shape (Nx,)*d
    flip = False
    for ax in reversed(range(d)):
        x = x_arrs[ax]; s = s_arrs[ax]
        V = np.moveaxis(V, ax, 0)
        L = len(x)
        rest_shape = V.shape[1:]
        M = int(np.prod(rest_shape, dtype=np.int64)) if rest_shape else 1
        V2 = V.reshape(L, M)
        out2 = np.empty((len(s), M), dtype=float)
        hx = np.empty(L, dtype=float)
        hu = np.empty(L, dtype=float)
        edge = np.empty(L+1, dtype=float)
        wout = np.empty(len(s), dtype=float)
        if progress:
            print(f"[Lucet] axis={ax} lines={M:,} L={L} Ns={len(s)} flip={flip}")
        for j in range(M):
            uline = V2[:, j]
            _llt_1d_inplace_signed(x, uline, s, hx, hu, edge, wout, negate_u=flip)
            out2[:, j] = wout
        V = out2.reshape((len(s),) + rest_shape)
        V = np.moveaxis(V, 0, ax)
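        # after the first pass every remaining axis is transformed on -V
        # (sup over x_k of s_k*x_k + V is the conjugate of -V), so the flip stays on
        flip = True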
    return V

def lucet_nd(x_arrs,f,s_arrs):
    # reference (unused – fast version is called)
    V=f(*np.meshgrid(*x_arrs,indexing='ij',sparse=False))
    flip=False
    for axis in reversed(range(len(x_arrs))):
        x,s=x_arrs[axis],s_arrs[axis]
        V=np.moveaxis(V,axis,0)
        out=np.empty((len(s),)+V.shape[1:],float)
        it=np.nditer(V[0],flags=['multi_index'])
        while not it.finished:
            idx=it.multi_index
            line=V[(slice(None),)+idx]
            out[(slice(None),)+idx]=llt_1d(x,-line if flip else line,s)
            it.iternext()
        V=np.moveaxis(out,0,axis); flip=True
    return V

def direct_nd(x_arrs,f,s_arrs):
    X=np.meshgrid(*x_arrs,indexing='ij',sparse=False)
    U=f(*X)
    out=np.empty(tuple(len(sa) for sa in s_arrs))
    it=np.nditer(out,flags=['multi_index'],op_flags=['writeonly'])
    while not it.finished:
        slopes=[s_arrs[k][it.multi_index[k]] for k in range(len(x_arrs))]
        it[0]=np.max(sum(s*Xk for s,Xk in zip(slopes,X))-U)
        it.iternext()
    return out

def interp_nd(grid, axes, pt):
    idx_low, t = [], []
    for p,ax in zip(pt,axes):
        j=np.searchsorted(ax,p)-1
        j=np.clip(j,0,len(ax)-2)
        idx_low.append(j)
        t.append((p-ax[j])/(ax[j+1]-ax[j]))
    val=0.0
    for corners in product((0,1), repeat=len(pt)):
        w, idx = 1.0, []
        for c,tl,jl in zip(corners,t,idx_low):
            w*=tl if c else 1-tl
            idx.append(jl+c)
        val+=w*grid[tuple(idx)]
    return val

def _prod_len(arrs): return int(np.prod([len(a) for a in arrs], dtype=int))
def lucet_mem_mb_est(x_arrs, s_arrs):
    prod_x = _prod_len(x_arrs); prod_s = _prod_len(s_arrs)
    peak = 2 * max(prod_x, prod_s) * 8.0
    return peak/(1024**2)
def direct_mem_mb_est(x_arrs, s_arrs, d):
    prod_x = _prod_len(x_arrs); prod_s = _prod_len(s_arrs)
    bytes_total = ((d + 2) * prod_x + prod_s) * 8.0
    return bytes_total/(1024**2)
def feasible_lucet(x_arrs, s_arrs, mem_gb_cap, flop_cap):
    mem_mb = lucet_mem_mb_est(x_arrs, s_arrs)
    mem_ok = (mem_mb/1024.0) <= mem_gb_cap
    d=len(x_arrs); N=max(max(len(a) for a in x_arrs), max(len(a) for a in s_arrs))
    flops = float(d) * float(N**d) * 20.0
    return (mem_ok and flops <= flop_cap), mem_mb, flops
def feasible_direct(x_arrs, s_arrs, mem_gb_cap, flop_cap, d):
    prod_x = _prod_len(x_arrs); prod_s = _prod_len(s_arrs)
    mem_mb = direct_mem_mb_est(x_arrs, s_arrs, d); mem_ok = (mem_mb/1024.0) <= mem_gb_cap
    flops  = float(prod_x) * float(prod_s)
    return (mem_ok and flops <= flop_cap), mem_mb, flops

# ========================= 4) Build grids & shared Y_eval =====================
def make_samplers(domain, fn_key, x_loguniform: bool):
    """
    Optional log-uniform x sampling for positive-domain functions.
    - If fn_key in {neg_log, neg_entropy} and x_loguniform is True:
        x ~ log-uniform on [C_lo, C_hi]  (so y=g(x) is more uniform on D)
    - Else:
        x ~ uniform on [C_lo, C_hi]
    Slopes y=s are always sampled uniform on D for classical eval.
    """
    (x_lo,x_hi) = domain[fn_key]["C"]
    (s_lo,s_hi) = domain[fn_key]["D"]

    if fn_key in ("neg_log", "neg_entropy") and x_loguniform:
        if not (x_lo > 0 and x_hi > 0):
            raise ValueError(f"{fn_key} requires x>0, got C=({x_lo}, {x_hi})")
        log_x_lo = float(math.log(x_lo))
        log_x_hi = float(math.log(x_hi))
        def samp_x(k, sh):
            return jnp.exp(
                random.uniform(k, sh, minval=log_x_lo, maxval=log_x_hi, dtype=jnp.float32)
            )
    else:
        def samp_x(k, sh):
            return random.uniform(k, sh, minval=float(x_lo), maxval=float(x_hi), dtype=jnp.float32)

    def samp_y(k, sh):
        return random.uniform(k, sh, minval=float(s_lo), maxval=float(s_hi), dtype=jnp.float32)

    return samp_x, samp_y

def build_grids_and_eval(fn_key, d, Nx, Ns, n_eval, domain):
    (x_lo,x_hi)=domain[fn_key]["C"]; (s_lo,s_hi)=domain[fn_key]["D"]
    x_axes=[np.linspace(x_lo, x_hi, Nx)]*d
    s_axes=[np.linspace(s_lo, s_hi, Ns)]*d
    eps = 1e-6
    lo = s_lo + eps*(s_hi - s_lo)
    hi = s_hi - eps*(s_hi - s_lo)
    Y_eval = np.array(
        random.uniform(random.PRNGKey(1234+d), (n_eval, d),
                       minval=lo, maxval=hi, dtype=jnp.float32)
    )
    def f_np(*Xs):
        V=np.stack(Xs, axis=-1)
        if fn_key=="quadratic":       return 0.5*np.sum(V*V, axis=-1)
        if fn_key=="neg_log":         return -np.sum(np.log(np.maximum(V,1e-10)), axis=-1)
        if fn_key=="neg_entropy":
            Xs_=np.maximum(V,1e-10)
            return np.sum(Xs_*np.log(Xs_), axis=-1)
        raise ValueError
    def fst_np(Y):
        if fn_key=="quadratic":   return 0.5*np.sum(Y*Y, axis=-1)
        if fn_key=="neg_log":     return -np.sum(np.log(-Y), axis=-1) - d
        if fn_key=="neg_entropy": return np.sum(np.exp(Y-1.0), axis=-1)
        raise ValueError
    return x_axes, s_axes, Y_eval, f_np, fst_np

# ========================= 5) Metrics helpers =================================
def relative_l2(pred: np.ndarray, true: np.ndarray) -> float:
    denom = float(np.linalg.norm(true))
    if denom == 0.0: return float("nan")
    return float(np.linalg.norm(pred - true) / denom)

# ========================= 6) One (fn, d, Ns, arch) run =======================
def dlt_model_factory(model_name: str, hidden: Tuple[int, ...], act_name: str):
    act = _act(act_name)
    name = model_name.upper()
    if name == "ICNN": return lambda: ICNN(hidden=hidden, act=act)
    if name == "MLP":  return lambda: MLP(hidden=hidden, act=act)
    return lambda: ResNet(hidden=hidden, act=act)  # default: RESNET

def run_case_classical(fn_key, d, Nx, Ns, args, Y_eval, f_np, fst_np, domain):
    rows=[]
    (x_lo,x_hi)=domain[fn_key]["C"]; (s_lo,s_hi)=domain[fn_key]["D"]
    x_axes=[np.linspace(x_lo, x_hi, Nx)]*d
    s_axes=[np.linspace(s_lo, s_hi, Ns)]*d

    lo_ok = np.all(Y_eval >= np.array([ax[0] for ax in s_axes]), axis=1)
    hi_ok = np.all(Y_eval <= np.array([ax[-1] for ax in s_axes]), axis=1)
    mask  = np.logical_and(lo_ok, hi_ok)
    true  = fst_np(Y_eval[mask]) if np.any(mask) else None

    # ----- Lucet (fast) -----
    t_luc = None; mb_luc = None; rel_l2_l=None; rmse_l=None; maxe_l=None; t_eval_l=None
    if d <= 12:
        feas_L, mem_mb_L, _ = feasible_lucet(x_axes, s_axes, args.mem_gb_cap, args.flop_cap)
        if feas_L:
            t0=time.perf_counter()
            U_luc = lucet_nd_fast(x_axes, f_np, s_axes, progress=args.lucet_progress)
            t_luc = time.perf_counter() - t0
            if np.any(mask):
                t0=time.perf_counter()
                luc_vals = np.array([interp_nd(U_luc, s_axes, y) for y in Y_eval[mask]])
                t_eval_l = time.perf_counter() - t0
                diff = luc_vals - true
                maxe_l   = float(np.max(np.abs(diff)))
                rmse_l   = float(np.sqrt(np.mean(diff**2)))
                rel_l2_l = relative_l2(luc_vals, true)
            mb_luc = mem_mb_L
            rows.append(dict(
                Function=FUNS[fn_key]["printable"], fn_key=fn_key, d=d,
                method="Lucet", model="", hidden="",
                Nx=Nx, Ns=Ns, N_train=None, batch=None, steps=None, patience=None,
                t_solve=t_luc, t_eval=t_eval_l, mem_MB=mb_luc,
                max_err=maxe_l, rmse=rmse_l, relL2=rel_l2_l
            ))

    # ----- Direct -----
    t_def=None; mb_def=None; rel_l2_d=None; rmse_d=None; maxe_d=None; t_eval_d=None
    if d <= args.def_dim_max:
        feas_D, mem_mb_D, _ = feasible_direct(x_axes, s_axes, args.mem_gb_cap_def, args.flop_cap_def, d)
        if feas_D:
            t0=time.perf_counter()
            U_def = direct_nd(x_axes, f_np, s_axes)
            t_def = time.perf_counter() - t0
            if np.any(mask):
                t0=time.perf_counter()
                def_vals = np.array([interp_nd(U_def, s_axes, y) for y in Y_eval[mask]])
                t_eval_d = time.perf_counter() - t0
                diff = def_vals - true
                maxe_d   = float(np.max(np.abs(diff)))
                rmse_d   = float(np.sqrt(np.mean(diff**2)))
                rel_l2_d = relative_l2(def_vals, true)
            mb_def = mem_mb_D
            rows.append(dict(
                Function=FUNS[fn_key]["printable"], fn_key=fn_key, d=d,
                method="Direct", model="", hidden="",
                Nx=Nx, Ns=Ns, N_train=None, batch=None, steps=None, patience=None,
                t_solve=t_def, t_eval=t_eval_d, mem_MB=mb_def,
                max_err=maxe_d, rmse=rmse_d, relL2=rel_l2_d
            ))
    return rows

def run_case_dlt(fn_key, d, arch_name, hidden, act_name, args, Y_eval, f, g, fst, samp_x):
    steps = auto_steps_from_dim(d) if args.steps == "auto" else int(args.steps)
    N_train, batch = auto_samples_and_batch(d, args.k_train) if args.batch == "auto" else (max(d, args.k_train*d), int(args.batch))
    patience = auto_patience_for_steps(steps, d) if args.patience == "auto" else int(args.patience)

    maker = dlt_model_factory(arch_name, hidden, args.act)
    params, t_dlt, bs = train_dlt_dataset(
        maker, d, f, g, samp_x, steps=steps, lr=args.lr, patience=patience,
        # str hash() is salted per process (PYTHONHASHSEED); use a deterministic offset instead
        seed=7000+d*17+sum(map(ord, arch_name))%1000, N=N_train, bs=batch
    )
    mb_dlt = estimate_dlt_mem_mb(params, d, bs, arch_name, hidden)

    t_eval=None; maxe=None; rmse=None; rel=None
    if fn_key in EVAL_DLT_FUNS:
        yj = jnp.asarray(Y_eval, jnp.float32)
        t0=time.perf_counter()
        pred = np.array(maker().apply({"params":params}, yj))
        t_eval = time.perf_counter() - t0
        true   = np.array(fst(Y_eval))
        diff   = pred - true
        maxe   = float(np.max(np.abs(diff)))
        rmse   = float(np.sqrt(np.mean(diff**2)))
        rel    = relative_l2(pred, true)

    return dict(
        Function=FUNS[fn_key]["printable"], fn_key=fn_key, d=d,
        method="DLT", model=arch_name.upper(), hidden=",".join(map(str,hidden)),
        Nx=None, Ns=None, N_train=N_train, batch=bs, steps=steps, patience=patience,
        t_solve=t_dlt, t_eval=t_eval, mem_MB=mb_dlt,
        max_err=maxe, rmse=rmse, relL2=rel
    )

# ========================= 7) CLI & orchestration =============================
def parse_dim_map(tokens, default_map) -> Dict[int,int]:
    m = dict(default_map)
    for tok in tokens or []:
        d,v = tok.split(":"); m[int(d)] = int(v)
    return m
def parse_ns_sweep(tokens: List[str], default_map: Dict[int, List[int]]) -> Dict[int, List[int]]:
    m = {k:list(v) for k,v in default_map.items()}
    for tok in tokens or []:
        d, vs = tok.split(":")
        m[int(d)] = [int(x) for x in vs.split(",") if x]
    return m
def parse_arch_list(tokens: List[str], default: List[Tuple[str,Tuple[int,...]]]) -> List[Tuple[str,Tuple[int,...]]]:
    if not tokens: return default
    out=[]
    for t in tokens:
        name, widths = t.split(":")
        out.append( (name.upper(), parse_hidden(widths)) )
    return out

def build_parser():
    P = argparse.ArgumentParser(add_help=True)
    P.add_argument("--dims", nargs="+", type=int, default=[2,3,4,5,6,8,10,20,50],
                   help="Lucet runs only for d <= 12, Direct only for d <= --def_dim_max (both subject to the caps below); larger d is DLT-only.")
    P.add_argument("--n_eval", type=int, default=5000, help="size of shared Y_eval ⊂ D")
    P.add_argument("--domain_profile", choices=["paper","wide"], default="paper")

    # Classical feasibility gates
    P.add_argument("--mem_gb_cap", type=float, default=3.0, help="peak GiB cap for Lucet")
    P.add_argument("--flop_cap",   type=float, default=3e10, help="flop cap proxy for Lucet")
    # Direct feasibility gates
    P.add_argument("--mem_gb_cap_def", type=float, default=1.0, help="peak GiB cap for Direct")
    P.add_argument("--flop_cap_def",   type=float, default=1e12, help="flop cap proxy for Direct")
    P.add_argument("--def_dim_max", type=int, default=6, help="Attempt Direct only up to this dimension")

    # DLT training (auto by default)
    P.add_argument("--act", default="softplus", choices=["relu","gelu","softplus"])
    P.add_argument("--arch", nargs="*", default=[], help='DLT architectures, e.g. "RESNET:128,128" "ICNN:128,128"')
    P.add_argument("--steps", default="auto", help='int or "auto"')
    P.add_argument("--k_train", type=int, default=300, help="N ≈ k * d with k clamped to [100,1000]; N is only reported, training streams fresh minibatches")
    P.add_argument("--batch", default="auto", help='int or "auto" (min 16, max 64, ≈ N/10)')
    P.add_argument("--patience", default="auto", help='int or "auto" (early stopping)')
    P.add_argument("--lr", type=float, default=1e-3)

    # Grid overrides / sweeps
    P.add_argument("--nx", nargs="*", default=[], help='override x-grid per d, e.g. "2:10" "5:10"')
    P.add_argument("--ns", nargs="*", default=[], help='override slope-grid per d, e.g. "2:10" "5:10"')
    P.add_argument("--ns_sweep", nargs="*", default=[], help='sweep Ns per d, e.g. "2:9,11" "5:9,11"')

    # Lucet fast progress
    P.add_argument("--lucet_progress", action="store_true", help="print per-axis progress for Lucet")

    # Optional: log-uniform x sampling for positive-domain functions
    P.add_argument("--x_loguniform", dest="x_loguniform", action="store_true", default=True,
                   help="Use log-uniform x for neg_log and neg_entropy (default: on)")
    P.add_argument("--no_x_loguniform", dest="x_loguniform", action="store_false",
                   help="Disable log-uniform x; use uniform x on [C_lo, C_hi]")

    # Output
    P.add_argument("--csv", default="results_dlt_vs_classical.csv")
    P.add_argument("--outdir", default="figs")
    return P

def default_grids():
    NX_default = {d:10 for d in [2,3,4,5,6,8,10,12,20,50,200]}
    NS_default = {d:10 for d in [2,3,4,5,6,8,10,12,20,50,200]}
    NS_SWEEP_DEFAULT = {d:[10] for d in [2,3,4,5,6,8,10,12]}
    return NX_default, NS_default, NS_SWEEP_DEFAULT

# ========================= 8) Plotting ========================================
def make_plots(df: pd.DataFrame, outdir: str):
    os.makedirs(outdir, exist_ok=True)
    if df.empty: return
    for (fn_key, d), g in df.groupby(["fn_key","d"]):
        g = g.dropna(subset=["relL2","t_solve"])
        if g.empty: continue

        plt.figure()
        for meth, gm in g.groupby("method"):
            plt.scatter(gm["t_solve"], gm["relL2"], label=meth)
        plt.xscale("log"); plt.yscale("log")
        plt.xlabel("Solve time (s)"); plt.ylabel("Relative L2 error")
        title = f"{fn_key} (d={d}) — relL2 vs time"
        plt.title(title); plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, f"{fn_key}_d{d}_relL2_vs_time.png"), dpi=150)
        plt.close()

        g_mem = g.dropna(subset=["mem_MB"])
        if not g_mem.empty:
            plt.figure()
            for meth, gm in g_mem.groupby("method"):
                plt.scatter(gm["mem_MB"], gm["relL2"], label=meth)
            plt.xscale("log"); plt.yscale("log")
            plt.xlabel("Active memory (MiB)"); plt.ylabel("Relative L2 error")
            plt.title(f"{fn_key} (d={d}) — relL2 vs memory")
            plt.legend(); plt.tight_layout()
            plt.savefig(os.path.join(outdir, f"{fn_key}_d{d}_relL2_vs_mem.png"), dpi=150)
            plt.close()

# ========================= 9) Main ============================================
def main(argv=None):
    parser = build_parser()
    if argv is None: argv = sys.argv[1:]
    args, _ = parser.parse_known_args(argv)

    domain = DOMAINS_PAPER if args.domain_profile == "paper" else DOMAINS_WIDE
    NX_default, NS_default, NS_SWEEP_DEFAULT = default_grids()
    NX_map = parse_dim_map(args.nx, NX_default)
    NS_map = parse_dim_map(args.ns, NS_default)
    NS_sweep_map = parse_ns_sweep(args.ns_sweep, NS_SWEEP_DEFAULT)
    archs = parse_arch_list(args.arch, default=[("RESNET", (128,128))])

    rows = []
    hdr = (
        f"{'Function':<12} {'d':>4} | "
        f"{'method':<8} | {'model':<8} | {'hidden':<12} | "
        f"{'Nx':>4} | {'Ns':>4} | {'Ntr':>6} | {'B':>4} | {'steps':>7} | "
        f"{'tSolve(s)':>10} | {'tEval(s)':>9} | {'MB(act)':>9} | "
        f"{'max_err':>10} | {'RMSE':>10} | {'relL2':>10}"
    )
    print(hdr); print("-"*len(hdr))

    for fn_key in ["quadratic","neg_log","neg_entropy"]:
        f = FUNS[fn_key]["f"]; g = FUNS[fn_key]["g"]; fst = FUNS[fn_key]["fst"]
        samp_x, samp_y = make_samplers(domain, fn_key, args.x_loguniform)

        for d in args.dims:
            Nx_default = NX_map.get(d, 10)
            Ns_default = NS_map.get(d, 10)
            _, _, Y_eval, f_np, fst_np = build_grids_and_eval(fn_key, d, Nx_default, Ns_default, args.n_eval, domain)

            # ---- Classical sweeps over Ns ----
            Ns_list = NS_sweep_map.get(d, [Ns_default])
            for Ns in Ns_list:
                Nx = Nx_default
                try:
                    c_rows = run_case_classical(fn_key, d, Nx, Ns, args, Y_eval, f_np, fst_np, domain)
                    for r in c_rows:
                        rows.append(r)
                        print(f"{r['Function']:<12} {r['d']:>4} | {r['method']:<8} | "
                              f"{r['model']:<8} | {r['hidden']:<12} | "
                              f"{str(r['Nx'] or ''):>4} | {str(r['Ns'] or ''):>4} | "
                              f"{str(r['N_train'] or ''):>6} | {str(r['batch'] or ''):>4} | "
                              f"{str(r['steps'] or ''):>7} | "
                              f"{(r['t_solve'] or float('nan')):>10.2f} | {(r['t_eval'] or float('nan')):>9.2f} | "
                              f"{(r['mem_MB'] or float('nan')):>9.1f} | "
                              f"{(r['max_err'] or float('nan')):>10.2e} | {(r['rmse'] or float('nan')):>10.2e} | "
                              f"{(r['relL2'] or float('nan')):>10.2e}")
                except Exception as e:
                    print(f"[skip classical {fn_key} d={d} Ns={Ns}] {e}")

            # ---- DLT architectures ----
            for (arch_name, hidden) in archs:
                try:
                    r = run_case_dlt(fn_key, d, arch_name, hidden, args.act, args, Y_eval, f, g, fst, samp_x)
                    rows.append(r)
                    print(f"{r['Function']:<12} {r['d']:>4} | {r['method']:<8} | "
                          f"{r['model']:<8} | {r['hidden']:<12} | "
                          f"{str(r['Nx'] or ''):>4} | {str(r['Ns'] or ''):>4} | "
                          f"{str(r['N_train'] or ''):>6} | {str(r['batch'] or ''):>4} | "
                          f"{str(r['steps'] or ''):>7} | "
                          f"{(r['t_solve'] or float('nan')):>10.2f} | {(r['t_eval'] or float('nan')):>9.2f} | "
                          f"{(r['mem_MB'] or float('nan')):>9.1f} | "
                          f"{(r['max_err'] or float('nan')):>10.2e} | {(r['rmse'] or float('nan')):>10.2e} | "
                          f"{(r['relL2'] or float('nan')):>10.2e}")
                except Exception as e:
                    print(f"[skip DLT {arch_name} {fn_key} d={d}] {e}")

    df = pd.DataFrame(rows)
    df.to_csv(args.csv, index=False)
    print(f"\nWrote CSV: {args.csv}  ({len(df)} rows)")
    try:
        make_plots(df, args.outdir)
        print(f"Saved figures to: {args.outdir}")
    except Exception as e:
        print(f"[plotting skipped] {e}")

if __name__=="__main__":
    if "ipykernel" in sys.modules or "google.colab" in sys.modules:
        main([])
    else:
        main(None)


Function        d | method   | model    | hidden       |   Nx |   Ns |    Ntr |    B |   steps |  tSolve(s) |  tEval(s) |   MB(act) |    max_err |       RMSE |      relL2
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Quadratic       2 | Lucet    |          |              |   10 |   10 |        |      |         |       0.00 |      0.12 |       0.0 |   1.11e-01 |   7.71e-02 |   2.17e-02
Quadratic       2 | Direct   |          |              |   10 |   10 |        |      |         |       0.00 |      0.12 |       0.0 |   1.11e-01 |   7.71e-02 |   2.17e-02
Quadratic       2 | DLT      | RESNET   | 128,128      |      |      |    600 |   60 |   10000 |      21.48 |      2.78 |       1.1 |   6.73e-02 |   6.11e-03 |   1.72e-03
Quadratic       3 | Lucet    |          |              |   10 |   10 |        |      |         |       0.01 |      0.20 |       0.0 |   1.16e+01 