<a href="https://colab.research.google.com/github/lexmar07/Deep-Legendre-Transform/blob/main/main_part/Table_2__DLT_and_direct_learning_of_conjugate_funcitons.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


### Comparing Different NN Architectures

#### Comparison with Direct Learning (When the Dual Is Known)

To validate our Deep Legendre Transform (DLT) approach, we compare it to *direct learning* of the convex conjugate in cases where the analytical form of $f^*$ is known.

We benchmark DLT against direct learning on three convex functions (quadratic, negative log, negative entropy), several input dimensions, and four network architectures.

#### Benchmark Setup

We compare two methods:

* **DLT (implicit)**: Train $g_\theta$ such that
  $g_\theta(\nabla f(x)) \approx \langle x, \nabla f(x) \rangle - f(x)$
* **Direct**: Train $h_\theta$ to approximate $f^*$ directly,
  $h_\theta(y) \approx f^*(y)$
  Only possible when $f^*$ is known.

> Direct learning serves as a "gold standard", but it only works when $f^*$ has a closed-form expression.
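The following check confirms that the two targets agree. This NumPy sketch uses NumPy as a stand-in for JAX; `implicit_target` and `direct_target` are illustrative names and are not functions from the benchmark script. It evaluates both targets on matched samples of the negative-log function. The DLT target needs only $f$ and $\nabla f$, yet it matches the closed-form $f^*(\nabla f(x))$ to floating-point precision. This holds because $f^*(\nabla f(x)) = \langle x, \nabla f(x) \rangle - f(x)$ for differentiable convex $f$.

```python
import numpy as np

# Negative log: f(x) = -sum(log x), grad f(x) = -1/x, f*(y) = -sum(log(-y)) - d
f = lambda x: -np.sum(np.log(x), axis=-1)
grad_f = lambda x: -1.0 / x
f_star = lambda y: -np.sum(np.log(-y), axis=-1) - y.shape[-1]

def implicit_target(x):
    """DLT target at input y = grad f(x): <x, grad f(x)> - f(x). Uses only f and grad f."""
    return np.sum(x * grad_f(x), axis=-1) - f(x)

def direct_target(y):
    """Direct-learning target: the closed-form conjugate f*(y)."""
    return f_star(y)

rng = np.random.default_rng(0)
x = np.exp(rng.uniform(-2.3, 2.3, size=(128, 5)))  # primal samples
y = grad_f(x)                                       # matched dual samples
gap = np.max(np.abs(implicit_target(x) - direct_target(y)))
print(f"max |implicit - direct| = {gap:.2e}")
```

Both networks therefore regress onto the same function $f^*$. They differ only in whether the label is computed from $f$ (DLT) or from a known $f^*$ (direct).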

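We use the same neural architectures for both methods. The hidden sizes and activations below are the script defaults:

* **MLP**: 2 hidden layers of 128 units, ReLU
* **MLP\_ICNN**: convex MLP (positive, softplus-reparametrized hidden weights plus a linear skip from the input to the output), 2 hidden layers of 128 units, softplus
* **ResNet**: 2 residual blocks of 128 units, ReLU
* **ICNN**: input-convex NN with input skip connections at every layer (Amos et al. 2017), 2 hidden layers of 128 units, softplus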
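The ICNN variants are convex for two reasons:

* the weights on the hidden-state path are kept nonnegative (the script's `DensePos` layer passes raw weights through softplus);
* the activation is convex and nondecreasing.

Below is a rough NumPy sketch of a forward pass with the same structure as the script's `ICNN` class, followed by a midpoint-convexity check on random pairs. It assumes a plain scaled-normal initialization instead of Flax's `lecun_normal`. The names `init_icnn` and `icnn_forward` are hypothetical.

```python
import numpy as np

def softplus(z):
    # numerically stable log(1 + exp(z))
    return np.logaddexp(0.0, z)

def init_icnn(rng, d, hidden=(128, 128)):
    """Hypothetical init: raw z-path weights, unconstrained x-path (skip) weights, zero biases."""
    layers, prev = [], 1              # z starts as a single zero feature, as in the script's ICNN
    for h in hidden:
        layers.append({"raw_Wz": rng.normal(size=(prev, h)) / np.sqrt(prev),
                       "Wx": rng.normal(size=(d, h)) / np.sqrt(d),
                       "b": np.zeros(h)})
        prev = h
    head = {"raw_wz": rng.normal(size=(prev, 1)) / np.sqrt(prev),
            "wx": rng.normal(size=(d, 1)) / np.sqrt(d)}
    return layers, head

def icnn_forward(layers, head, x):
    """x: (batch, d) -> (batch,). Convex in x: positive z-weights, convex nondecreasing activation."""
    z = np.zeros((x.shape[0], 1))
    for p in layers:
        z = softplus(z @ softplus(p["raw_Wz"]) + x @ p["Wx"] + p["b"])
    out = z @ softplus(head["raw_wz"]) + x @ head["wx"]   # positive z-head + linear skip
    return out[:, 0]

rng = np.random.default_rng(1)
layers, head = init_icnn(rng, d=3)
a, b = rng.normal(size=(256, 3)), rng.normal(size=(256, 3))
mid = icnn_forward(layers, head, 0.5 * (a + b))
avg = 0.5 * (icnn_forward(layers, head, a) + icnn_forward(layers, head, b))
# f((a+b)/2) <= (f(a)+f(b))/2, up to a small relative tolerance for rounding
print("midpoint convexity holds:", bool(np.all(mid <= avg + 1e-8 * (1 + np.abs(avg)))))
```

GELU is neither convex nor monotone. Using it in the convex models, which the script's `_act` helper would allow, would void this guarantee.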

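Optimization:

* Adam optimizer, base LR = $10^{-3}$ ($3 \times 10^{-3}$ for MLP\_ICNN and ICNN), halved every 20k steps (staircase schedule); batch size = $64d$
* Training stops at the first of three conditions: the training MSE drops below $10^{-6}$, 10k steps pass without improvement, or 50k iterations are reached
* Each configuration is repeated 10 times. We report the mean $L^2$ (MSE) errors and the mean and standard deviation of the ratio $\rho = \text{err}_{\text{DLT}} / \text{err}_{\text{direct}}$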
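The optimization settings translate into a few lines of plain Python. This sketch reproduces the staircase rule $\eta_t = \eta_0 \cdot 0.5^{\lfloor t / 20000 \rfloor}$, which `optax.exponential_decay(lr, 20_000, 0.5, staircase=True)` implements, together with the $64d$ batch size. The names `learning_rate` and `scaled_batch` are illustrative and are not the script's own. Pass `base_lr=3e-3` to get the MLP\_ICNN and ICNN schedule.

```python
def learning_rate(step, base_lr=1e-3, every=20_000, factor=0.5):
    """Staircase decay: base_lr * factor ** floor(step / every)."""
    return base_lr * factor ** (step // every)

def scaled_batch(d, per_dim=64):
    """Batch size growing linearly with the input dimension (the script's "scale" mode)."""
    return per_dim * d

for step in (0, 19_999, 20_000, 40_000, 49_999):
    print(f"step {step:>6}: lr = {learning_rate(step):.2e}")
for d in (2, 5, 10, 20):
    print(f"d = {d:>2}: batch = {scaled_batch(d)}")
```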

#### Sampling Strategy

We match sampling between the primal and dual spaces via the gradient map. The dual training points are $y = \nabla f(x)$ for primal samples $x$, so both methods see the same distribution of $y$.

* **Quadratic**:
  $f(x) = \frac{1}{2} \|x\|^2 \quad\Rightarrow\quad \nabla f(x) = x,\quad f^*(y) = \frac{1}{2} \|y\|^2$
  Sample $x \sim \mathcal{N}(0, I)$ ⇒ $y = x$

* **Neg. Log**:
  $f(x) = -\sum_{i=1}^d \log(x_i) \quad\Rightarrow\quad (\nabla f(x))_i = -\frac{1}{x_i},\quad f^*(y) = -\sum_{i=1}^d \log(-y_i) - d$
  Sample $x$ log-uniformly in $[0.1, 10]^d$ (i.e. $\log x_i \sim \mathcal{U}[-2.3, 2.3]$) ⇒ $y \in [-10, -0.1]^d$

* **Neg. Entropy**:
  $f(x) = \sum_{i=1}^d x_i \log x_i \quad\Rightarrow\quad (\nabla f(x))_i = \log x_i + 1,\quad f^*(y) = \sum_{i=1}^d \exp(y_i - 1)$
  Sample $x$ log-uniformly in $[0.1, 10]^d$ as above ⇒ $y \in [-1.3, 3.3]^d$
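The matched sampling can be reproduced in a few lines. This NumPy sketch draws log-uniform primal samples and pushes them through each gradient map. The name `sample_log_uniform` is illustrative; the script does the same thing with `jnp.exp(_u(k, s, -2.3, 2.3))`.

For the negative log, $\log|y_i| = -\log x_i$, so the dual samples are log-uniform as well. For the negative entropy, $y_i = \log x_i + 1$ is uniform on $[-1.3, 3.3]$.

```python
import numpy as np

def sample_log_uniform(rng, shape, lo=-2.3, hi=2.3):
    """x = exp(u), u ~ U[lo, hi): log-uniform on [e^-2.3, e^2.3] (about [0.1, 10])."""
    return np.exp(rng.uniform(lo, hi, size=shape))

rng = np.random.default_rng(0)
x = sample_log_uniform(rng, (10_000, 5))
y_nlog = -1.0 / x           # dual samples for the negative log
y_nent = np.log(x) + 1.0    # dual samples for the negative entropy

print(f"x      in [{x.min():.3f}, {x.max():.3f}]")
print(f"y_nlog in [{y_nlog.min():.3f}, {y_nlog.max():.3f}]")
print(f"y_nent in [{y_nent.min():.3f}, {y_nent.max():.3f}]")
```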

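#### Results and Analysis

We test dimensions $d \in \{2, 5, 10, 20\}$ for the functions below. The script defaults to `--dims 2 5 10`; pass `--dims 2 5 10 20` to include $d = 20$.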

$$
\begin{aligned}
\text{Quadratic:} \quad & f(x) = \frac{\|x\|^2}{2}, \quad f^*(y) = \frac{\|y\|^2}{2} \\
\text{Neg. Log:} \quad & f(x) = -\sum \log(x_i), \quad f^*(y) = -\sum \log(-y_i) - d \\
\text{Neg. Entropy:} \quad & f(x) = \sum x_i \log x_i, \quad f^*(y) = \sum \exp(y_i - 1)
\end{aligned}
$$
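Each closed-form conjugate above can be checked numerically with the Fenchel-Young inequality $f(x) + f^*(y) \ge \langle x, y \rangle$, which holds with equality exactly when $y = \nabla f(x)$. The NumPy sketch below runs this check. Its `FUNCS` dictionary mirrors the script's `FUNCTIONS` table but is a stand-alone illustration, not the script's object.

```python
import numpy as np

rng = np.random.default_rng(0)
log_uniform = lambda n, d: np.exp(rng.uniform(-2.3, 2.3, size=(n, d)))

# name: (f, grad f, f*, primal sampler)
FUNCS = {
    "quadratic": (lambda x: 0.5 * np.sum(x**2, -1), lambda x: x,
                  lambda y: 0.5 * np.sum(y**2, -1), lambda n, d: rng.normal(size=(n, d))),
    "neg_log": (lambda x: -np.sum(np.log(x), -1), lambda x: -1.0 / x,
                lambda y: -np.sum(np.log(-y), -1) - y.shape[-1], log_uniform),
    "neg_entropy": (lambda x: np.sum(x * np.log(x), -1), lambda x: np.log(x) + 1.0,
                    lambda y: np.sum(np.exp(y - 1.0), -1), log_uniform),
}

results = {}
for name, (f, grad, f_star, sample) in FUNCS.items():
    x = sample(1000, 5)
    y = grad(x)                      # matched dual point: equality case
    y_other = grad(sample(1000, 5))  # unrelated dual point: inequality case
    eq_gap = np.max(np.abs(f(x) + f_star(y) - np.sum(x * y, -1)))
    min_slack = np.min(f(x) + f_star(y_other) - np.sum(x * y_other, -1))
    results[name] = (eq_gap, min_slack)
    print(f"{name:<12} equality gap {eq_gap:.1e}   min slack {min_slack:.1e}")
```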




In [None]:
#!/usr/bin/env python3
# benchmark.py – implicit vs explicit convex-conjugate learning
# ν-sampling • staircase LR • per-model activations (relu|gelu|softplus)
# batch size: "scale" (d×64) or constant • repeats with σ
# prints mean L2 errors and saves LaTeX tables with σ column
# --------------------------------------------------------------------
from __future__ import annotations
import time, os, sys, argparse
from functools import partial
from typing import Sequence, Callable, Dict
import jax, jax.numpy as jnp, optax
from flax import linen as nn
from flax.training import train_state
from jax import random
import numpy as np

# ═════════ 1. convex test functions ═════════════════════════════════
f_quad,  grad_quad  = lambda x:0.5*jnp.sum(x**2,-1),      lambda x:x
fst_quad            = lambda y:0.5*jnp.sum(y**2,-1)

f_nlog,  grad_nlog  = lambda x:-jnp.sum(jnp.log(x),-1),   lambda x:-1./x
fst_nlog            = lambda y:-jnp.sum(jnp.log(-y),-1)-y.shape[-1]

f_nent,  grad_nent  = lambda x:jnp.sum(x*jnp.log(x),-1),  lambda x:jnp.log(x)+1
fst_nent            = lambda y:jnp.sum(jnp.exp(y-1.),-1)

def _u(rng, sh, lo, hi):
    return random.uniform(rng, sh, minval=lo, maxval=hi, dtype=jnp.float32)

FUNCTIONS = {
    "quadratic":   (f_quad, grad_quad, fst_quad,
                    lambda k,s: random.normal(k, s, jnp.float32)),
    "neg_log":     (f_nlog, grad_nlog, fst_nlog,
                    lambda k,s: jnp.exp(_u(k, s, -2.3, 2.3))),
    "neg_entropy": (f_nent, grad_nent, fst_nent,
                    lambda k,s: jnp.exp(_u(k, s, -2.3, 2.3))),
}
FUNCPRINT = {"quadratic":"Quadratic",
             "neg_log":r"Neg.\ Log",          # raw strings keep the LaTeX "\ " space
             "neg_entropy":r"Neg.\ Entropy"}

# ═════════ 2. activation helper (relu | gelu | softplus) ════════════
def _act(name:str)->Callable:
    n=name.lower()
    if n=="relu":     return nn.relu
    if n=="gelu":     return jax.nn.gelu
    if n=="softplus": return jax.nn.softplus
    raise ValueError(f"unknown activation {name}")

# ═════════ 3. model zoo (activation pluggable) ══════════════════════
class DensePos(nn.Module):
    features:int; use_bias:bool=True
    @nn.compact
    def __call__(self,x):
        W=nn.softplus(self.param("raw_W",nn.initializers.lecun_normal(),
                                 (x.shape[-1],self.features)))
        y=x@W
        if self.use_bias:
            y+=self.param("b",nn.initializers.zeros,(self.features,))
        return y

class MLP(nn.Module):
    hidden:Sequence[int]; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        for h in self.hidden: x=self.act(nn.Dense(h)(x))
        return jnp.squeeze(nn.Dense(1)(x),-1)

class MLP_ICNN(nn.Module):
    hidden:Sequence[int]; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        z=x
        for h in self.hidden: z=self.act(DensePos(h)(z))
        out=DensePos(1,use_bias=False)(z)+nn.Dense(1,use_bias=False)(x)
        return jnp.squeeze(out,-1)

class ICNN(nn.Module):
    hidden:Sequence[int]; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        z=jnp.zeros((x.shape[0],1))
        for h in self.hidden:
            z=self.act(DensePos(h)(z)+nn.Dense(h)(x))
        out=DensePos(1,use_bias=False)(z)+nn.Dense(1,use_bias=False)(x)
        return jnp.squeeze(out,-1)

class ResBlock(nn.Module):
    f:int; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        y=self.act(nn.Dense(self.f)(x)); y=nn.Dense(self.f)(y)
        if x.shape[-1]!=self.f: x=nn.Dense(self.f,use_bias=False)(x)
        return self.act(x+y)

class ResNet(nn.Module):
    hidden:Sequence[int]; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        for h in self.hidden: x=ResBlock(h,act=self.act)(x)
        return jnp.squeeze(nn.Dense(1)(x),-1)

def parse_hidden(s:str)->tuple[int,...]:
    return tuple(int(v) for v in s.split(",") if v)

# ═════════ 4. optimiser / loss / jit helpers ════════════════════════
class State(train_state.TrainState): ...
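def schedule(lr:float):
    # halve the LR every 20k steps; staircase must be passed by keyword
    # (the 4th positional argument of exponential_decay is transition_begin)
    return optax.exponential_decay(lr,20_000,0.5,staircase=True)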
def new_state(rng,model,d,lr):
    p=model.init(rng,jnp.zeros((1,d),jnp.float32))["params"]
    return State.create(apply_fn=model.apply,params=p,tx=optax.adam(schedule(lr)))

loss_impl=lambda p,af,x,f,g:jnp.mean((af({"params":p},g(x))-
                                      (jnp.sum(x*g(x),-1)-f(x)))**2)
loss_expl=lambda p,af,y,fst:jnp.mean((af({"params":p},y)-fst(y))**2)

@partial(jax.jit, static_argnums=(2,3))
def step_impl(st,b,f,g):
    l,gr=jax.value_and_grad(loss_impl)(st.params,st.apply_fn,b,f,g)
    return st.apply_gradients(grads=gr),l
@partial(jax.jit, static_argnums=(2,))
def step_expl(st,b,fst):
    l,gr=jax.value_and_grad(loss_expl)(st.params,st.apply_fn,b,fst)
    return st.apply_gradients(grads=gr),l

@partial(jax.jit, static_argnums=(1,3,4))
def _ei(p,af,x,f,g): return loss_impl(p,af,x,f,g)
def eval_impl(p,af,x,f,g): return float(_ei(p,af,x,f,g))
@partial(jax.jit, static_argnums=(1,3))
def _ee(p,af,y,fst): return loss_expl(p,af,y,fst)
def eval_expl(p,af,y,fst): return float(_ee(p,af,y,fst))

# ═════════ 5. early stopping ════════════════════════════════════════
class Stopper:
    def __init__(self,pat,tol=1e-6):
        self.best=float("inf"); self.pat=pat; self.tol=tol
        self.cnt=0; self.bp=None
    def update(self,loss,p):
        loss=float(loss)
        if loss+self.tol<self.best:
            self.best,self.cnt=loss,0; self.bp=p
        else: self.cnt+=1
        return self.cnt>=self.pat or self.best<self.tol
    def res(self): return self.best,self.bp

# ═════════ 6. helper: batch size ════════════════════════════════════
def batch_size(d:int,arg:str)->int:
    return d*64 if arg=="scale" else int(arg)

# ═════════ 7. training routine (returns final L2 error) ═════════════
def train(model_fn,d,f,g,samp,steps,lr,pat,seed,
          implicit,batch,verb=False):
    st=new_state(random.PRNGKey(seed),model_fn(),d,lr)
    stop=Stopper(pat); stepf=step_impl if implicit else step_expl
    tag="impl" if implicit else "expl"; bar=max(steps//20,1)
    i=0
    while i<steps:
        mb=samp(random.fold_in(random.PRNGKey(seed+999),i),(batch,d))
        st,loss = stepf(st,mb,f,g) if implicit else stepf(st,mb,f)
        if stop.update(loss,st.params): break
        if verb and i%bar==0:
            pct=100*i/steps
            print(f"[{tag}] {i:6d}/{steps} ({pct:5.1f}%) loss {float(loss):.3e}")
        elif (not verb) and i%bar==0:
            pct=i/steps; br=int(20*pct)
            sys.stdout.write(f"\r[{tag}] [{'#'*br}{'.'*(20-br)}] {pct*100:5.1f}%")
            sys.stdout.flush()
        i+=1
    if not verb: sys.stdout.write("\n")
    _,bp=stop.res()
    if implicit:
        return eval_impl(bp,st.apply_fn,
                         samp(random.PRNGKey(0),(batch,d)),f,g)
    else:
        return eval_expl(bp,st.apply_fn,
                         samp(random.PRNGKey(0),(batch,d)),f)

# ═════════ 8. benchmark (L2 means/σ + ratio) ════════════════════════
def bench(fn,d,steps,pat,models,runs,batch_arg,verb):
    f,g,fst,sampx=FUNCTIONS[fn]; sampy=lambda k,sh:g(sampx(k,sh))
    bs=batch_size(d,batch_arg)
    rows=[]
    for nm,sp in models.items():
        l2I,l2E,ratios=[],[],[]
        for r in range(runs):
            if verb: print(f"\n▶ {nm} ({fn},d={d}) run {r+1}/{runs}")
            errI=train(sp["make"],d,f,g,sampx,steps,sp["lr"],pat,
                       7000+d*11+r*5,True, bs,verb)
            errE=train(sp["make"],d,fst,None,sampy,steps,sp["lr"],pat,
                       7100+d*13+r*5,False,bs,verb)
            l2I.append(errI); l2E.append(errE)
            ratios.append(errI/errE if errE else 1.)
        rows.append(dict(model=nm,d=d,
                         l2I=float(np.mean(l2I)),
                         l2E=float(np.mean(l2E)),
                         ratio_mean=float(np.mean(ratios)),
                         ratio_sd=float(np.std(ratios))))
    return rows

# ═════════ 9. LaTeX helper (σ only) ═════════════════════════════════
def tex_tables(res:Dict[str,list],dims)->Dict[str,str]:
    out={}
    for fn,rows in res.items():
        tex=["\\begin{table}[h]","\\centering",
             f"\\caption{{Results for {FUNCPRINT[fn]}}}",
             "\\begin{tabular}{cc|cc}",
             "\\toprule",
             "$d$ & Model & $\\mathbb E[\\rho]$ & $\\sigma(\\rho)$ \\\\ \\midrule"]
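        for d in dims:
            dr=[r for r in rows if r["d"]==d]
            if not dr: continue
            for i,r in enumerate(dr):
                # 4 cells per row (needs \usepackage{booktabs,multirow}):
                # column 1 holds \multirow over the d-block on its first row, empty after
                dcol=f"\\multirow{{{len(dr)}}}{{*}}{{{d}}}" if i==0 else ""
                line=" & ".join([dcol,
                                 r["model"].replace("_","\\_"),   # escape "_" for LaTeX
                                 f"{r['ratio_mean']:.2f}",
                                 f"{r['ratio_sd']:.2f}"])+" \\\\"
                tex.append(line)
            if d!=dims[-1]: tex.append("\\cmidrule{2-4}")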
        tex+=["\\bottomrule","\\end{tabular}","\\end{table}"]
        out[fn]="\n".join(tex)
    return out

# ═════════ 10. CLI & main ═══════════════════════════════════════════
def build_parser():
    P=argparse.ArgumentParser()
    P.add_argument("--steps",type=int,default=50000)
    P.add_argument("--patience",type=int,default=10000)
    P.add_argument("--lr",type=float,default=1e-3)
    P.add_argument("--runs",type=int,default=10)
    P.add_argument("--batch",default="scale")
    P.add_argument("--dims",nargs="+",type=int,default=[2,5,10])
    P.add_argument("--verbose",action="store_true")
    # hidden sizes
    P.add_argument("--mlp_hidden",default="128,128")
    P.add_argument("--mlp_icnn_hidden",default="128,128")
    P.add_argument("--resnet_hidden",default="128,128")
    P.add_argument("--icnn_hidden",default="128,128")
    # learning-rates
    P.add_argument("--mlp_lr",type=float);  P.add_argument("--mlp_icnn_lr",type=float)
    P.add_argument("--resnet_lr",type=float); P.add_argument("--icnn_lr",type=float)
    # activations
    P.add_argument("--mlp_act",default="relu")
    P.add_argument("--mlp_icnn_act",default="softplus")
    P.add_argument("--resnet_act",default="relu")
    P.add_argument("--icnn_act",default="softplus")
    return P

def main(argv=None):
    args,_=build_parser().parse_known_args(argv or sys.argv[1:])
    os.makedirs("results",exist_ok=True)
    base_lr=args.lr
    models={
        "MLP":{"make":lambda:MLP(parse_hidden(args.mlp_hidden),
                                 act=_act(args.mlp_act)),
               "lr":args.mlp_lr or base_lr},
        "MLP_ICNN":{"make":lambda:MLP_ICNN(parse_hidden(args.mlp_icnn_hidden),
                                           act=_act(args.mlp_icnn_act)),
               "lr":args.mlp_icnn_lr or base_lr*3},
        "ResNet":{"make":lambda:ResNet(parse_hidden(args.resnet_hidden),
                                       act=_act(args.resnet_act)),
               "lr":args.resnet_lr or base_lr},
        "ICNN":{"make":lambda:ICNN(parse_hidden(args.icnn_hidden),
                                   act=_act(args.icnn_act)),
               "lr":args.icnn_lr or base_lr*3},
    }

    all_res={}
    for fn in FUNCTIONS:
        print("\n"+"="*78+f"\n{FUNCPRINT[fn]} benchmark\n"+"="*78)
        all_res[fn]=[]
        for d in args.dims:
            rows=bench(fn,d,args.steps,args.patience,
                       models,args.runs,args.batch,args.verbose)
            all_res[fn].extend(rows)
            for r in rows:
                print(f"{r['model']:<10}"
                      f"L2impl {r['l2I']:.2e}  L2expl {r['l2E']:.2e}  "
                      f"ratio {r['ratio_mean']:.2f} σ={r['ratio_sd']:.2f}")
    # ➜ LaTeX
    tables=tex_tables(all_res,args.dims)
    for fn,tex in tables.items():
        fname=f"results/{fn}_table.tex"
        with open(fname,"w") as f: f.write(tex)
        print(f"LaTeX saved → {fname}")

if __name__=="__main__":
    main()



Quadratic benchmark
[impl] [###################.]  95.0%
[expl] [#############.......]  65.0%
[impl] [###################.]  95.0%
[expl] [###################.]  95.0%
[impl] [##############......]  70.0%
[expl] [###################.]  95.0%
[impl] [##################..]  90.0%
[expl] [################....]  80.0%
[impl] [##############......]  70.0%
[expl] [################....]  80.0%
[impl] [##################..]  90.0%
[expl] [##############......]  70.0%
[impl] [###################.]  95.0%
[expl] [#################...]  85.0%
[impl] [##################..]  90.0%
[expl] [#################...]  85.0%
[impl] [###############.....]  75.0%
[expl] [############........]  60.0%
[impl] [##################..]  90.0%
[expl] [##############......]  70.0%
[impl] [#############.......]  65.0%
[expl] [###############.....]  75.0%
[impl] [##############......]  70.0%
[expl] [##############......]  70.0%
[impl] [#############.......]  65.0%
[expl] [#############.......]  65.0%
[impl] [#########

In [None]:
#!/usr/bin/env python3
# benchmark.py – implicit vs explicit convex-conjugate learning
# ν‑sampling • staircase LR • per‑model activations (relu|gelu|softplus)
# batch size: "scale" (d×64) or constant • repeats with σ
# prints rows + one combined LaTeX table
# --------------------------------------------------------------------
from __future__ import annotations
import os, sys, time, argparse
from functools import partial
from typing import Sequence, Callable, Dict

import jax, jax.numpy as jnp, optax
from jax import random
from flax import linen as nn
from flax.training import train_state
import numpy as np

# ═════ 1. convex test functions ═════════════════════════════════════
f_quad,  grad_quad  = lambda x: 0.5*jnp.sum(x**2, -1),        lambda x: x
fst_quad            = lambda y: 0.5*jnp.sum(y**2, -1)

f_nlog,  grad_nlog  = lambda x:-jnp.sum(jnp.log(x), -1),      lambda x:-1./x
fst_nlog            = lambda y:-jnp.sum(jnp.log(-y), -1) - y.shape[-1]

f_nent,  grad_nent  = lambda x:jnp.sum(x*jnp.log(x), -1),     lambda x:jnp.log(x)+1
fst_nent            = lambda y:jnp.sum(jnp.exp(y-1.), -1)

def _u(rng, sh, lo, hi):
    return random.uniform(rng, shape=sh, minval=lo, maxval=hi,
                          dtype=jnp.float32)

FUNCTIONS = {
    "quadratic":   (f_quad, grad_quad, fst_quad,
                    lambda k,s: random.normal(k, s, dtype=jnp.float32)),
    "neg_log":     (f_nlog, grad_nlog, fst_nlog,
                    lambda k,s: jnp.exp(_u(k, s, -2.3,  2.3))),
    "neg_entropy": (f_nent, grad_nent, fst_nent,
                    lambda k,s: jnp.exp(_u(k, s, -2.3,  2.3))),
}
FUNCPRINT = {"quadratic": "Quadratic",
             "neg_log":   r"Neg.\ Log",       # raw strings keep the LaTeX "\ " space
             "neg_entropy":r"Neg.\ Entropy"}

# ═════ 2. activations ═══════════════════════════════════════════════
def _act(name:str)->Callable:
    n=name.lower()
    if n=="relu":     return nn.relu
    if n=="gelu":     return jax.nn.gelu
    if n=="softplus": return jax.nn.softplus
    raise ValueError(f"unknown activation {name}")

# ═════ 3. model zoo ═════════════════════════════════════════════════
class DensePos(nn.Module):
    features:int; use_bias:bool=True
    @nn.compact
    def __call__(self,x):
        W = nn.softplus(self.param("rawW", nn.initializers.lecun_normal(),
                                   (x.shape[-1], self.features)))
        y = x @ W
        if self.use_bias:
            y += self.param("b", nn.initializers.zeros, (self.features,))
        return y

class MLP(nn.Module):
    hidden:Sequence[int]; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        for h in self.hidden: x = self.act(nn.Dense(h)(x))
        return jnp.squeeze(nn.Dense(1)(x), -1)

class MLP_ICNN(nn.Module):
    hidden:Sequence[int]; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        z=x
        for h in self.hidden: z = self.act(DensePos(h)(z))
        out = DensePos(1, use_bias=False)(z) + nn.Dense(1, use_bias=False)(x)
        return jnp.squeeze(out, -1)

class ICNN(nn.Module):
    hidden:Sequence[int]; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        z=jnp.zeros((x.shape[0],1))
        for h in self.hidden:
            z = self.act(DensePos(h)(z) + nn.Dense(h)(x))
        out = DensePos(1, use_bias=False)(z) + nn.Dense(1, use_bias=False)(x)
        return jnp.squeeze(out, -1)

class ResBlock(nn.Module):
    f:int; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        y=self.act(nn.Dense(self.f)(x)); y=nn.Dense(self.f)(y)
        if x.shape[-1]!=self.f:
            x = nn.Dense(self.f, use_bias=False)(x)
        return self.act(x+y)

class ResNet(nn.Module):
    hidden:Sequence[int]; act:Callable=nn.relu
    @nn.compact
    def __call__(self,x):
        for h in self.hidden: x = ResBlock(h, act=self.act)(x)
        return jnp.squeeze(nn.Dense(1)(x), -1)

def parse_hidden(s:str)->tuple[int,...]:
    return tuple(int(v) for v in s.split(",") if v)

# ═════ 4. optimiser / losses / jit helpers ═════════════════════════
class State(train_state.TrainState): ...

def schedule(lr:float):
    return optax.exponential_decay(lr, 20_000, 0.5, staircase=True)

def new_state(rng, model, d, lr):
    params = model.init(rng, jnp.zeros((1, d), jnp.float32))["params"]
    return State.create(apply_fn=model.apply, params=params,
                        tx=optax.adam(schedule(lr)))

loss_impl = lambda p,af,x,f,g: jnp.mean(
    (af({"params":p}, g(x)) - (jnp.sum(x*g(x), -1) - f(x)))**2)
loss_expl = lambda p,af,y,fst: jnp.mean((af({"params":p}, y) - fst(y))**2)

@partial(jax.jit, static_argnums=(2,3))
def step_impl(st,b,f,g):
    l,gr = jax.value_and_grad(loss_impl)(st.params, st.apply_fn, b, f, g)
    return st.apply_gradients(grads=gr), l

@partial(jax.jit, static_argnums=(2,))
def step_expl(st,b,fst):
    l,gr = jax.value_and_grad(loss_expl)(st.params, st.apply_fn, b, fst)
    return st.apply_gradients(grads=gr), l

@partial(jax.jit, static_argnums=(1,3,4))
def _ei(p,af,x,f,g): return loss_impl(p,af,x,f,g)
def eval_impl(p,af,x,f,g): return float(_ei(p,af,x,f,g))

@partial(jax.jit, static_argnums=(1,3))
def _ee(p,af,y,fst): return loss_expl(p,af,y,fst)
def eval_expl(p,af,y,fst): return float(_ee(p,af,y,fst))

# ═════ 5. early stopping ═══════════════════════════════════════════
class Stopper:
    def __init__(self, pat:int, tol:float=1e-6):
        self.best=float("inf"); self.pat=pat; self.tol=tol
        self.cnt=0; self.bp=None
    def update(self, loss, params):
        loss=float(loss)
        if loss+self.tol < self.best:
            self.best, self.cnt = loss, 0; self.bp = params
        else:
            self.cnt += 1
        return self.cnt >= self.pat or self.best < self.tol
    def res(self): return self.best, self.bp

# ═════ 6. utilities ════════════════════════════════════════════════
def batch_size(d:int, arg:str)->int:
    return d*64 if arg=="scale" else int(arg)

# ═════ 7. training routine (returns err & time) ════════════════════
def train(model_fn, d, f, g, samp, steps, lr, pat, seed,
          implicit:bool, batch:int, verb=False):
    st   = new_state(random.PRNGKey(seed), model_fn(), d, lr)
    stop = Stopper(pat)
    step = step_impl if implicit else step_expl
    tag  = "impl" if implicit else "expl"
    bar  = max(steps//20, 1)
    t0   = time.perf_counter()

    for i in range(steps):
        mb = samp(random.fold_in(random.PRNGKey(seed+999), i), (batch, d))
        st, loss = step(st, mb, f, g) if implicit else step(st, mb, f)
        if stop.update(loss, st.params):
            break
        if not verb and i%bar==0:
            pct = i/steps; br = int(20*pct)
            sys.stdout.write(f"\r[{tag}] [{'#'*br}{'.'*(20-br)}] {pct*100:5.1f}%")
            sys.stdout.flush()
        elif verb and i%bar==0:
            print(f"[{tag}] {i:6d}/{steps} "
                  f"({100*i/steps:5.1f}%) loss {float(loss):.3e}")
    if not verb:
        sys.stdout.write("\n")

    elapsed = time.perf_counter() - t0   # training wall-clock time, incl. JIT compilation
    _, bp = stop.res()
    # fixed evaluation batch: key 0 for both methods, so they are scored on matching points
    xe  = samp(random.PRNGKey(0), (batch, d))
    err = (eval_impl(bp, st.apply_fn, xe, f, g) if implicit
           else eval_expl(bp, st.apply_fn, xe, f))
    return err, elapsed

# ═════ 8. benchmark (means, σ, times) ══════════════════════════════
def bench(fn, d, steps, pat, models, runs, batch_arg, verb):
    f, g, fst, sampx = FUNCTIONS[fn]
    sampy = lambda k, sh: g(sampx(k, sh))
    bs = batch_size(d, batch_arg)
    rows = []

    for nm, sp in models.items():
        l2I, l2E, tI, tE, ratios = [], [], [], [], []
        for r in range(runs):
            if verb: print(f"\n▶ {nm} ({fn}, d={d}) run {r+1}/{runs}")
            errI, timeI = train(sp["make"], d, f, g, sampx,
                                steps, sp["lr"], pat,
                                7000+d*11+r*5, True,  bs, verb)
            errE, timeE = train(sp["make"], d, fst, None, sampy,
                                steps, sp["lr"], pat,
                                7100+d*13+r*5, False, bs, verb)
            l2I.append(errI); l2E.append(errE)
            tI.append(timeI);  tE.append(timeE)
            ratios.append(errI/errE if errE else 1.)
        rows.append(dict(model=nm, d=d,
                         l2I=float(np.mean(l2I)), l2E=float(np.mean(l2E)),
                         tI=float(np.mean(tI)),   tE=float(np.mean(tE)),
                         rho_mu=float(np.mean(ratios)),
                         rho_sigma=float(np.std(ratios))))
    return rows

# ═════ 9. combined LaTeX helper (one table, keeps Time + σ; needs booktabs + multirow) ═════
def tex_tables(res: Dict[str, list], dims):
    fun_order = ["quadratic", "neg_log", "neg_entropy"]
    model_order = ["MLP", "MLP_ICNN", "ResNet", "ICNN"]

    tex = [
        "\\begin{table}[h]",
        "  \\centering",
        "  \\caption{Benchmark results comparing implicit DLT against direct learning with known duals}",
        "  \\label{tab:combined_benchmark}",
        "  \\begin{tabular}{ccc|cc|cc|cc}",
        "    \\toprule",
        "    \\multirow{2}{*}{Function} & \\multirow{2}{*}{$d$} & \\multirow{2}{*}{Model}"
        " & \\multicolumn{2}{c|}{$L^2$ Error} & \\multicolumn{2}{c|}{Time (s)} & \\multicolumn{2}{c}{Ratio} \\\\",
        "    & & & Impl. & Dir. & Impl. & Dir. & $\\mu$ & $\\sigma$ \\\\",
        "    \\midrule"
    ]

    for fn in fun_order:
        rows_fn = sorted(res.get(fn, []),
                         key=lambda r: (r["d"], model_order.index(r["model"])))
        if not rows_fn:
            continue
        total_rows_fn = len(rows_fn)
        fn_first_row_written = False

        for d in dims:
            rows_dim = [r for r in rows_fn if r["d"] == d]
            if not rows_dim:
                continue
            rows_dim = sorted(rows_dim,
                              key=lambda r: model_order.index(r["model"]))
            dim_first_row_written = False

            for r in rows_dim:
                line_parts = []
                # Function column
                if not fn_first_row_written:
                    line_parts.append(
                        f"\\multirow{{{total_rows_fn}}}{{*}}{{{FUNCPRINT[fn]}}}")
                    fn_first_row_written = True
                else:
                    line_parts.append(" ")

                # Dimension column
                if not dim_first_row_written:
                    line_parts.append(
                        f"\\multirow{{{len(rows_dim)}}}{{*}}{{{d}}}")
                    dim_first_row_written = True
                else:
                    line_parts.append(" ")

                # Model + metrics
                line_parts.extend([
                    r["model"].replace("_", "\\_"),   # escape "_" for LaTeX
                    f"{r['l2I']:.2e}", f"{r['l2E']:.2e}",
                    f"{r['tI']:.1f}",  f"{r['tE']:.1f}",
                    f"{r['rho_mu']:.2f}", f"{r['rho_sigma']:.2f}"
                ])
                tex.append(" & ".join(line_parts) + " \\\\")
            # horizontal line between different d‑blocks
            if d != dims[-1]:
                tex.append("    \\cmidrule{2-9}")
        # mid‑rule between functions
        if fn != fun_order[-1]:
            tex.append("    \\midrule")

    tex += [
        "    \\bottomrule",
        "  \\end{tabular}",
        "\\end{table}"
    ]

    table = "\n".join(tex)
    os.makedirs("results", exist_ok=True)
    with open("results/combined_table.tex", "w") as f:
        f.write(table)
    print("\nCombined LaTeX table (also saved to results/combined_table.tex):\n")
    print(table + "\n")
    return table


# ═════ 10. CLI & main ═══════════════════════════════════════════════
def build_parser():
    P = argparse.ArgumentParser()
    P.add_argument("--steps", type=int, default=50_000)
    P.add_argument("--patience", type=int, default=10_000)
    P.add_argument("--lr", type=float, default=1e-3)
    P.add_argument("--runs", type=int, default=10)
    P.add_argument("--batch", default="scale")
    P.add_argument("--dims", nargs="+", type=int, default=[2, 5, 10])
    P.add_argument("--verbose", action="store_true")
    # hidden sizes
    P.add_argument("--mlp_hidden", default="128,128")
    P.add_argument("--mlp_icnn_hidden", default="128,128")
    P.add_argument("--resnet_hidden", default="128,128")
    P.add_argument("--icnn_hidden", default="128,128")
    # learning rates
    P.add_argument("--mlp_lr", type=float)
    P.add_argument("--mlp_icnn_lr", type=float)
    P.add_argument("--resnet_lr", type=float)
    P.add_argument("--icnn_lr", type=float)
    # activations
    P.add_argument("--mlp_act", default="relu")
    P.add_argument("--mlp_icnn_act", default="softplus")
    P.add_argument("--resnet_act", default="relu")
    P.add_argument("--icnn_act", default="softplus")
    return P

def main(argv=None):
    args, _ = build_parser().parse_known_args(argv or sys.argv[1:])
    base_lr = args.lr
    models = {
        "MLP": {
            "make": lambda: MLP(parse_hidden(args.mlp_hidden),
                                act=_act(args.mlp_act)),
            "lr": args.mlp_lr or base_lr},
        "MLP_ICNN": {
            "make": lambda: MLP_ICNN(parse_hidden(args.mlp_icnn_hidden),
                                     act=_act(args.mlp_icnn_act)),
            "lr": args.mlp_icnn_lr or base_lr*3},
        "ResNet": {
            "make": lambda: ResNet(parse_hidden(args.resnet_hidden),
                                   act=_act(args.resnet_act)),
            "lr": args.resnet_lr or base_lr},
        "ICNN": {
            "make": lambda: ICNN(parse_hidden(args.icnn_hidden),
                                 act=_act(args.icnn_act)),
            "lr": args.icnn_lr or base_lr*3},
    }

    all_res = {}
    for fn in FUNCTIONS:
        print("\n" + "="*78 + f"\n{FUNCPRINT[fn]} benchmark\n" + "="*78)
        all_res[fn] = []
        for d in args.dims:
            rows = bench(fn, d, args.steps, args.patience,
                         models, args.runs, args.batch, args.verbose)
            all_res[fn].extend(rows)
            for r in rows:
                print(f"{r['model']:<10}"
                      f"L2impl {r['l2I']:.2e}  L2dir {r['l2E']:.2e}  "
                      f"time {r['tI']:.1f}/{r['tE']:.1f}s  "
                      f"ρ {r['rho_mu']:.2f} σ={r['rho_sigma']:.2f}")

    tex_tables(all_res, args.dims)

if __name__ == "__main__":
    main()



Quadratic benchmark
[impl] [##############......]  70.0%
[expl] [###################.]  95.0%
[impl] [##################..]  90.0%
[expl] [###################.]  95.0%
[impl] [###################.]  95.0%
[expl] [###################.]  95.0%
[impl] [##############......]  70.0%
[expl] [###################.]  95.0%
[impl] [##############......]  70.0%
[expl] [##############......]  70.0%
[impl] [###################.]  95.0%
[expl] [###############.....]  75.0%
[impl] [###############.....]  75.0%
[expl] [###################.]  95.0%
[impl] [##############......]  70.0%
[expl] [###############.....]  75.0%
[impl] [###############.....]  75.0%
[expl] [################....]  80.0%
[impl] [############........]  60.0%
[expl] [###################.]  95.0%
[impl] [#############.......]  65.0%
[expl] [###################.]  95.0%
[impl] [############........]  60.0%
[expl] [##############......]  70.0%
[impl] [#############.......]  65.0%
[expl] [##############......]  70.0%
[impl] [#########