# Grid-based Methods for Legendre–Fenchel Transforms

## The Direct Method

The direct method implements the Legendre–Fenchel transform by its definition.  
For a function \(u(x)\) in \(d\) dimensions:

$$
u^*(s) = \sup_{x}\;\bigl\langle s, x\bigr\rangle \;-\; u(x)
$$

Written out in components, with the supremum replaced by a maximum over grid points:

$$
u^*(s_1,\dots,s_d)
=
\max_{x_1,\dots,x_d}\;\Bigl(s_1x_1 + \dots + s_d x_d \;-\; u(x_1,\dots,x_d)\Bigr)
$$

### Key Characteristics

1. **Brute-force**  
   Evaluate every coordinate combination for each slope vector.

2. **Computational complexity**  
   \(\displaystyle \mathcal{O}(N^{2d})\) for \(N\) grid points per dimension, because:  
   - \(\mathcal{O}(N^d)\) slope combinations  
   - \(\mathcal{O}(N^d)\) coordinate sweep per slope

3. **Memory usage**  
   \(\displaystyle \mathcal{O}(N^d)\) to store the output array (plus minimal temp storage).
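
As a minimal sketch of the brute-force idea in one dimension (the helper name `direct_1d` and the quadratic test data are illustrative assumptions, not part of the benchmark code), the table of candidates below has one entry per (slope, grid point) pair, which is the \(d=1\) case of the \(N^{2d}\) count:

```python
import numpy as np

def direct_1d(x, u, s):
    """Brute-force conjugate: for each slope s_j, scan every grid point x_i."""
    candidates = s[:, None] * x[None, :] - u[None, :]   # shape (len(s), len(x))
    return candidates.max(axis=1), candidates.shape

x = np.linspace(-3.0, 3.0, 7)       # grid points -3, -2, ..., 3
s = x.copy()                        # slope grid chosen equal to the x grid
u_star, shape = direct_1d(x, 0.5 * x**2, s)
```

With this choice the maximizer \(x = s\) lies on the grid, so `u_star` should coincide with \(\tfrac12 s^2\); for slopes between grid points the discrete maximum only approximates the continuous one.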

---

## The Nested Approach

Exploit a nested-max decomposition:

$$
u^*(s_1,\dots,s_d)
=
\max_{x_1}\!\Bigl\{\,s_1x_1
  +\max_{x_2}\!\{\,s_2x_2 + \dots
     +\max_{x_d}\{\,s_d x_d - u(x_1,\dots,x_d)\}\dots\}\Bigr\}
$$

Compute one dimension at a time, starting with the innermost maximization over \(x_d\) and working outward to \(x_1\).
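
A short justification for \(d=2\), given here as a sketch: the term \(s_1x_1\) does not depend on \(x_2\), so it can be moved outside the inner maximization.

$$
\max_{x_1,x_2}\bigl(s_1x_1 + s_2x_2 - u(x_1,x_2)\bigr)
=
\max_{x_1}\Bigl\{\,s_1x_1 + \max_{x_2}\bigl\{\,s_2x_2 - u(x_1,x_2)\bigr\}\Bigr\}
$$

Repeating the same step for each further coordinate yields the general nested formula.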

---

# Algorithm Structure

## 1D Linear-time Legendre Transform (LLT)

Computes \(u^*(s)\) in \(\mathcal{O}(n + m)\) time, where \(n\) is the number of data points (sorted by \(x\)) and \(m\) is the number of target slopes (sorted in increasing order):

1. **Lower convex hull** of the points \((x_i, u(x_i))\); the indices below refer to hull vertices.  
2. **Slopes** between hull vertices:
   $$
   c_i = \frac{u(x_{i+1}) - u(x_i)}{x_{i+1} - x_i}
   $$
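3. **Merge** for each target \(s_j\): find the vertex index \(i\) with \(c_{i-1} \le s_j \le c_i\) (taking \(c_0=-\infty\) and the slope after the last vertex as \(+\infty\)), then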
   $$
   u^*(s_j) = s_j\,x_i \;-\; u(x_i).
   $$
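
The three steps can be traced on a tiny example. This is a sketch under stated assumptions: the sample data are hypothetical and already convex, so the hull step changes nothing, and `np.searchsorted` (a binary search) stands in for the linear merge purely for brevity.

```python
import numpy as np

# Hypothetical sample data: u(x) = x^2 on four sorted points (already convex,
# so every point is a vertex of the lower hull and step 1 is a no-op here).
x = np.array([-1.0, 0.0, 1.0, 2.0])
u = x**2

# Step 2: slopes between consecutive hull vertices -> [-1, 1, 3].
c = np.diff(u) / np.diff(x)

# Step 3: for each target slope, find the vertex whose adjacent slopes bracket it.
s = np.array([-2.0, 0.5, 2.0, 4.0])
i = np.searchsorted(c, s)            # number of hull slopes strictly below s_j
u_star = s * x[i] - u[i]

# Cross-check against the brute-force maximum over all data points.
brute = np.max(s[:, None] * x[None, :] - u[None, :], axis=1)
```

Slopes below the first hull slope map to the first vertex and slopes above the last one map to the last vertex, which is what the infinite end slopes express.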

---

## \(d\)-dimensional Nested Algorithm

1. **Innermost** (\(i=d\))  
   $$
   V_d(x_1,\dots,x_{d-1},s_d)
   = \max_{x_d}\{\,s_d x_d - u(x_1,\dots,x_d)\}.
   $$
2. **Work outward** for \(i=d-1,\dots,1\):
   $$
     V_i(x_1,\dots,x_{i-1},s_i,\dots,s_d)
     = \max_{x_i}\{\,s_i x_i + V_{i+1}(x_1,\dots,x_i,s_{i+1},\dots,s_d)\}.
   $$
3. **Final**  
   $$
   u^*(s_1,\dots,s_d) = V_1(s_1,\dots,s_d).
   $$
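
The stages can be checked on a small 2D grid. In this sketch the helper `conj_1d_bruteforce` is a hypothetical stand-in for a 1D LLT call, and the quadratic \(u\) and the grid sizes are chosen only for illustration:

```python
import numpy as np

def conj_1d_bruteforce(x, vals, s):
    """max_i (s_j * x_i - vals_i) for every s_j; stand-in for a 1D LLT call."""
    return np.max(s[:, None] * x[None, :] - vals[None, :], axis=1)

n = 6
x1 = x2 = np.linspace(-2.0, 2.0, n)
s1 = s2 = np.linspace(-2.0, 2.0, n)
X1, X2 = np.meshgrid(x1, x2, indexing="ij")
U = 0.5 * (X1**2 + X2**2)            # example u(x1, x2)

# Stage i = 2: V2[a, q] = max_{x2} s2[q]*x2 - u(x1[a], x2)
V2 = np.stack([conj_1d_bruteforce(x2, U[a, :], s2) for a in range(n)], axis=0)

# Stage i = 1: V1[p, q] = max_{x1} s1[p]*x1 + V2(x1, s2[q]),
# i.e. the 1D conjugate of -V2 along x1.
V1 = np.stack([conj_1d_bruteforce(x1, -V2[:, q], s1) for q in range(n)], axis=1)

# Direct definition on the same grids, for comparison.
direct = np.max(
    s1[:, None, None, None] * X1[None, None]
    + s2[None, :, None, None] * X2[None, None]
    - U[None, None],
    axis=(2, 3),
)
```

Note the sign flip in the second stage: the 1D routine computes \(\max(sx - f)\), so it is applied to \(f = -V_2\). Up to round-off, `V1` and `direct` should agree.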

---

# Comparative Characteristics

_Assume equal grid resolution \(N\) in primal and dual:_

| Feature                   | Direct Method               | Nested Method               |
|:--------------------------|:----------------------------|:----------------------------|
| Decomposition             | none                         | \(d\) sequential stages     |
| Complexity                | \(\mathcal{O}(N^{2d})\)      | \(\mathcal{O}(d\,N^{d+1})\)  |
| Memory                    | \(\mathcal{O}(N^d)\)         | \(\mathcal{O}(N^d)\)         |
| Efficiency trade-off      | simple but slow             | faster, same memory order   |
| Accuracy (grid-dependent) | improves as \(N\) increases | improves as \(N\) increases |

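> **Note:** Both methods evaluate the same discrete maximum, so they agree up to round-off; their error relative to the continuous transform depends on the grid resolution \(N\) and on the chosen \(x\) and \(s\) ranges. The \(\mathcal{O}(d\,N^{d+1})\) entry is a conservative bound: each stage runs \(N^{d-1}\) one-dimensional transforms, and with the linear-time LLT each costs \(\mathcal{O}(N)\), which gives \(\mathcal{O}(d\,N^{d})\).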


In [None]:
import time
from itertools import product

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors  # used by the plotting helpers (LogNorm)

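# ---------- 1-D Lucet (linear-time Legendre transform) ------------------
def llt_1d(x, u, s):
    """Discrete conjugate max_i (s*x_i - u_i); x and s must be sorted ascending."""
    x, u, s = (np.asarray(a, dtype=float) for a in (x, u, s))
    # Step 1: lower convex hull of the points (x_i, u_i), built in one pass.
    hx, hu = np.empty_like(x), np.empty_like(u); h = 0
    for xi, ui in zip(x, u):
        # Pop the last vertex while it lies on or above the chord to the new point.
        while h>=2 and (hu[h-1]-hu[h-2])*(xi-hx[h-1]) >= (ui-hu[h-1])*(hx[h-1]-hx[h-2]):
            h -= 1
        hx[h], hu[h] = xi, ui; h += 1
    hx, hu = hx[:h], hu[:h]
    # Step 2: hull edge slopes, padded with -inf/+inf so every slope has an interval.
    edge = np.concatenate(([-np.inf], np.diff(hu)/np.diff(hx), [np.inf]))
    out, k = np.empty_like(s), 0
    # Step 3: merge; vertex k is optimal for slopes in [edge[k], edge[k+1]].
    for j, sj in enumerate(s):
        while sj > edge[k+1]: k += 1
        out[j] = sj*hx[k] - hu[k]
    return out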

# ---------- nested Lucet any-d ----------------------------------------
def lucet_nd(x_arrs, f, s_arrs):
    """Nested transform: apply llt_1d along each axis, from the last to the first."""
    V = f(*np.meshgrid(*x_arrs, indexing='ij', sparse=False))
    flip = False  # after the first stage the 1-D input is -V_{i+1}, not u
    for axis in reversed(range(len(x_arrs))):
        x, s = x_arrs[axis], s_arrs[axis]
        V = np.moveaxis(V, axis, 0)  # bring the active axis to the front
        out = np.empty((len(s),)+V.shape[1:], float)
        it = np.nditer(V[0], flags=['multi_index'])  # one iteration per 1-D line
        while not it.finished:
            idx = it.multi_index
            line = V[(slice(None),)+idx]
            out[(slice(None),)+idx] = llt_1d(x, -line if flip else line, s)
            it.iternext()
        V = np.moveaxis(out, 0, axis); flip = True
    return V

# ---------- Direct on the same slope grid -----------------------------
def direct_nd(x_arrs, f, s_arrs):
    """Brute force: for every slope vector, maximize <s, x> - u(x) over the whole x grid."""
    X = np.meshgrid(*x_arrs, indexing='ij', sparse=False)
    U = f(*X)
    out = np.empty(tuple(len(sa) for sa in s_arrs))
    it = np.nditer(out, flags=['multi_index'], op_flags=['writeonly'])
    while not it.finished:
        # Slope vector addressed by the current output index.
        slopes = [s_arrs[k][it.multi_index[k]] for k in range(len(x_arrs))]
        it[0] = np.max(sum(s*Xk for s, Xk in zip(slopes, X)) - U)
        it.iternext()
    return out

# ---------- multilinear interpolation ---------------------------------
def interp_nd(grid, axes, pt):
    idx_low, t = [], []
    for p,ax in zip(pt,axes):
        j=np.searchsorted(ax,p)-1
        j=np.clip(j,0,len(ax)-2)
        idx_low.append(j)
        t.append((p-ax[j])/(ax[j+1]-ax[j]))
    val=0.0
    for corners in product((0,1), repeat=len(pt)):
        w, idx = 1.0, []
        for c,tl,jl in zip(corners,t,idx_low):
            w*=tl if c else 1-tl
            idx.append(jl+c)
        val+=w*grid[tuple(idx)]
    return val

# ---------- test functions -------------------------------------------
def quadratic(d):
    """d-dimensional quadratic function: u(x) = 0.5 * sum(x_i^2)"""
    return lambda *x: 0.5*sum(np.asarray(xi)**2 for xi in x)

def neg_log(d):
    """d-dimensional negative logarithm function: u(x) = -sum(log(x_i))"""
    def func(*args):
        # Handle domain constraints - inputs should be positive
        result = 0
        for x in args:
            x_safe = np.maximum(x, 1e-10)  # Avoid log(0)
            result -= np.log(x_safe)
        return result
    return func

def ql_func(d):
    """d-dimensional quadratic-over-linear function: u(x) = (sum(x_i^2) + 1) / (sum(x_i) + 1)"""
    def func(*args):
        numerator = sum(x**2 for x in args) + 1
        denominator = sum(args) + 1
        return numerator / denominator
    return func

def mem_mb(n_elems): return n_elems*8/1024/1024

# ---------- plotting functions ----------------------------------------
def plot_1d_results(x_arrays, f, s_arrays, u_star_direct, u_star_nested, func_name):
    """Plot results for 1D transform."""
    plt.figure(figsize=(15, 5))

    # Plot original function
    plt.subplot(1, 3, 1)
    x = x_arrays[0]
    u_x = np.array([f(xi) for xi in x])
    plt.plot(x, u_x)
    plt.xlabel('x')
    plt.ylabel('u(x)')
    plt.title(f'Original {func_name.capitalize()} Function')
    plt.grid(True)

    # Plot transforms
    plt.subplot(1, 3, 2)
    s = s_arrays[0]
    plt.plot(s, u_star_direct, 'b-', label='Direct')
    plt.plot(s, u_star_nested, 'r--', label='Nested')

    # Add analytical solution if available
    if func_name == 'quadratic':
        u_star_analytical = 0.5 * s**2
        plt.plot(s, u_star_analytical, 'g:', label='Analytical')
    elif func_name == 'neg_log':
        u_star_analytical = np.array([-1 - np.log(-si) for si in s])
        plt.plot(s, u_star_analytical, 'g:', label='Analytical')

    plt.xlabel('s')
    plt.ylabel('u*(s)')
    plt.title('Legendre Transform')
    plt.legend()
    plt.grid(True)

    # Plot error
    plt.subplot(1, 3, 3)
    if func_name == 'quadratic' or func_name == 'neg_log':
        # For functions with analytical solutions
        if func_name == 'quadratic':
            u_star_analytical = 0.5 * s**2
        else:  # neg_log
            u_star_analytical = np.array([-1 - np.log(-si) for si in s])

        plt.semilogy(s, np.abs(u_star_direct - u_star_analytical), 'b-', label='Direct Error')
        plt.semilogy(s, np.abs(u_star_nested - u_star_analytical), 'r--', label='Nested Error')
        plt.title('Error vs Analytical')
    else:
        plt.semilogy(s, np.abs(u_star_direct - u_star_nested), 'k-', label='Method Difference')
        plt.title('Method Difference')

    plt.xlabel('s')
    plt.ylabel('Absolute Error/Difference')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()

def plot_2d_results(s_arrays, u_star_direct, u_star_nested, func_name):
    """Plot results for 2D transform."""
    fig = plt.figure(figsize=(15, 10))

    # Create meshgrid for visualization
    S1, S2 = np.meshgrid(s_arrays[0], s_arrays[1], indexing='ij')

    # Generate analytical solution if available
    if func_name == 'quadratic':
        u_star_analytical = 0.5 * (S1**2 + S2**2)
    elif func_name == 'neg_log':
        u_star_analytical = np.zeros_like(S1)
        for i in range(u_star_analytical.shape[0]):
            for j in range(u_star_analytical.shape[1]):
                s1, s2 = S1[i,j], S2[i,j]
                u_star_analytical[i,j] = -2 - np.log(-s1) - np.log(-s2)
    else:
        u_star_analytical = None

    # Plot direct transform
    ax1 = fig.add_subplot(2, 2, 1, projection='3d')
    surf1 = ax1.plot_surface(S1, S2, u_star_direct, cmap='viridis', alpha=0.8)
    ax1.set_xlabel('s₁')
    ax1.set_ylabel('s₂')
    ax1.set_zlabel('u*(s₁,s₂)')
    ax1.set_title(f'Direct Method - {func_name.capitalize()}')

    # Plot nested transform
    ax2 = fig.add_subplot(2, 2, 2, projection='3d')
    surf2 = ax2.plot_surface(S1, S2, u_star_nested, cmap='viridis', alpha=0.8)
    ax2.set_xlabel('s₁')
    ax2.set_ylabel('s₂')
    ax2.set_zlabel('u*(s₁,s₂)')
    ax2.set_title(f'Nested Method - {func_name.capitalize()}')

    # Plot errors or differences
    if u_star_analytical is not None:
        # Mask out invalid values for neg_log
        if func_name == 'neg_log':
            valid_mask = ~np.isnan(u_star_analytical)
            if not np.any(valid_mask):
                u_star_analytical = None

        if u_star_analytical is not None:
            # Create error plots
            error_direct = np.abs(u_star_direct - u_star_analytical)
            error_nested = np.abs(u_star_nested - u_star_analytical)

            # Mask invalid values if needed
            if func_name == 'neg_log':
                error_direct = np.where(valid_mask, error_direct, np.nan)
                error_nested = np.where(valid_mask, error_nested, np.nan)

            ax3 = fig.add_subplot(2, 2, 3, projection='3d')
            surf3 = ax3.plot_surface(S1, S2, error_direct, cmap='plasma', alpha=0.8,
                                    norm=colors.LogNorm() if np.max(error_direct[~np.isnan(error_direct)]) > 0 else None)
            ax3.set_xlabel('s₁')
            ax3.set_ylabel('s₂')
            ax3.set_zlabel('Error')
            ax3.set_title('Direct Method Error')

            ax4 = fig.add_subplot(2, 2, 4, projection='3d')
            surf4 = ax4.plot_surface(S1, S2, error_nested, cmap='plasma', alpha=0.8,
                                    norm=colors.LogNorm() if np.max(error_nested[~np.isnan(error_nested)]) > 0 else None)
            ax4.set_xlabel('s₁')
            ax4.set_ylabel('s₂')
            ax4.set_zlabel('Error')
            ax4.set_title('Nested Method Error')

    # Show difference between methods if no analytical solution
    else:
        difference = np.abs(u_star_direct - u_star_nested)
        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        surf3 = ax3.plot_surface(S1, S2, difference, cmap='plasma', alpha=0.8,
                                norm=colors.LogNorm() if np.max(difference) > 0 else None)
        ax3.set_xlabel('s₁')
        ax3.set_ylabel('s₂')
        ax3.set_zlabel('Difference')
        ax3.set_title('Method Difference')

        # Cross-section comparison
        mid_idx = len(s_arrays[1]) // 2
        ax4 = fig.add_subplot(2, 2, 4)
        ax4.plot(s_arrays[0], u_star_direct[:, mid_idx], 'b-', label='Direct')
        ax4.plot(s_arrays[0], u_star_nested[:, mid_idx], 'r--', label='Nested')
        ax4.set_xlabel('s₁')
        ax4.set_ylabel(f'u*(s₁,{s_arrays[1][mid_idx]:.2f})')
        ax4.set_title('Cross-section Comparison')
        ax4.legend()
        ax4.grid(True)

    plt.tight_layout()

def plot_high_dim_results(d, s_arrays, u_star_direct, u_star_nested, func_name):
    """Plot 2D slices for higher dimensional transforms (d ≥ 3)."""
    fig = plt.figure(figsize=(15, 10))

    # Create 2D slice indices (middle of each dimension > 2)
    slice_indices = [len(s) // 2 for s in s_arrays[2:]]

    # Create meshgrid for first two dimensions
    S1, S2 = np.meshgrid(s_arrays[0], s_arrays[1], indexing='ij')

    # Create indexing tuple for slicing higher dimensions
    idx_slice = (slice(None), slice(None)) + tuple(slice_indices)

    # Plot direct transform (slice)
    ax1 = fig.add_subplot(2, 2, 1, projection='3d')
    slice_title = f"Direct Method - {func_name.capitalize()} (s₃"
    for dim in range(3, d):
        slice_title += f",s{dim+1}"
    slice_title += " slice)"

    surf1 = ax1.plot_surface(S1, S2, u_star_direct[idx_slice], cmap='viridis', alpha=0.8)
    ax1.set_xlabel('s₁')
    ax1.set_ylabel('s₂')
    ax1.set_zlabel('u*(s₁,s₂,...)')
    ax1.set_title(slice_title)

    # Plot nested transform (slice)
    ax2 = fig.add_subplot(2, 2, 2, projection='3d')
    slice_title = slice_title.replace("Direct", "Nested")
    surf2 = ax2.plot_surface(S1, S2, u_star_nested[idx_slice], cmap='viridis', alpha=0.8)
    ax2.set_xlabel('s₁')
    ax2.set_ylabel('s₂')
    ax2.set_zlabel('u*(s₁,s₂,...)')
    ax2.set_title(slice_title)

    # Generate analytical solution for the slice (if available)
    if func_name == 'quadratic':
        u_star_analytical_slice = 0.5 * (S1**2 + S2**2 + sum(s_arrays[i][slice_indices[i-2]]**2 for i in range(2, d)))

        # Plot errors
        error_direct = np.abs(u_star_direct[idx_slice] - u_star_analytical_slice)
        error_nested = np.abs(u_star_nested[idx_slice] - u_star_analytical_slice)

        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        surf3 = ax3.plot_surface(S1, S2, error_direct, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(error_direct) > 0 else None)
        ax3.set_xlabel('s₁')
        ax3.set_ylabel('s₂')
        ax3.set_zlabel('Error')
        ax3.set_title('Direct Method Error (slice)')

        ax4 = fig.add_subplot(2, 2, 4, projection='3d')
        surf4 = ax4.plot_surface(S1, S2, error_nested, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(error_nested) > 0 else None)
        ax4.set_xlabel('s₁')
        ax4.set_ylabel('s₂')
        ax4.set_zlabel('Error')
        ax4.set_title('Nested Method Error (slice)')

    elif func_name == 'neg_log':
        # Create analytical solution for the slice
        u_star_analytical_slice = np.zeros((len(s_arrays[0]), len(s_arrays[1])))
        fixed_slopes = [s_arrays[i][slice_indices[i-2]] for i in range(2, d)]

        for i, s1 in enumerate(s_arrays[0]):
            for j, s2 in enumerate(s_arrays[1]):
                # Calculate -d - sum(log(-s_i))
                u_star_analytical_slice[i, j] = -d - np.log(-s1) - np.log(-s2) - sum(np.log(-s) for s in fixed_slopes)

        # Plot errors
        error_direct = np.abs(u_star_direct[idx_slice] - u_star_analytical_slice)
        error_nested = np.abs(u_star_nested[idx_slice] - u_star_analytical_slice)

        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        surf3 = ax3.plot_surface(S1, S2, error_direct, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(error_direct) > 0 else None)
        ax3.set_xlabel('s₁')
        ax3.set_ylabel('s₂')
        ax3.set_zlabel('Error')
        ax3.set_title('Direct Method Error (slice)')

        ax4 = fig.add_subplot(2, 2, 4, projection='3d')
        surf4 = ax4.plot_surface(S1, S2, error_nested, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(error_nested) > 0 else None)
        ax4.set_xlabel('s₁')
        ax4.set_ylabel('s₂')
        ax4.set_zlabel('Error')
        ax4.set_title('Nested Method Error (slice)')

    else:
        # For functions without analytical solution, show difference between methods
        difference = np.abs(u_star_direct[idx_slice] - u_star_nested[idx_slice])
        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        surf3 = ax3.plot_surface(S1, S2, difference, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(difference) > 0 else None)
        ax3.set_xlabel('s₁')
        ax3.set_ylabel('s₂')
        ax3.set_zlabel('Difference')
        ax3.set_title('Method Difference (slice)')

        # One more view - cross section at middle of second dimension
        mid_idx = len(s_arrays[1]) // 2

        if d == 3:
            # For 3D, create a second slice view (s1 vs s3, fixing s2)
            S1, S3 = np.meshgrid(s_arrays[0], s_arrays[2], indexing='ij')
            cross_section = u_star_direct[:, mid_idx, :]
            ax4 = fig.add_subplot(2, 2, 4, projection='3d')
            surf4 = ax4.plot_surface(S1, S3, cross_section, cmap='viridis', alpha=0.8)
            ax4.set_xlabel('s₁')
            ax4.set_ylabel('s₃')
            ax4.set_zlabel(f'u*(s₁,{s_arrays[1][mid_idx]:.2f},s₃)')
            ax4.set_title('Direct Method (s₂ slice)')
        else:
            # For dimension > 3, show a 1D cross-section
            cross_idx = [slice(None), mid_idx] + slice_indices
            cross_idx = tuple(cross_idx)

            ax4 = fig.add_subplot(2, 2, 4)
            ax4.plot(s_arrays[0], u_star_direct[cross_idx], 'b-', label='Direct')
            ax4.plot(s_arrays[0], u_star_nested[cross_idx], 'r--', label='Nested')
            # No analytical curve is drawn: this branch is only reached for
            # functions without a closed-form transform.

            ax4.set_xlabel('s₁')
            ax4.set_ylabel(f'u*(s₁,{s_arrays[1][mid_idx]:.2f},...)')
            ax4.set_title('Cross-section Comparison')
            ax4.legend()
            ax4.grid(True)

    plt.tight_layout()

# ---------- run benchmark --------------------------------------------
def run_benchmark(func_name='quadratic', n_pts=10, max_dim=5, d_L=6):
    """Run benchmark for a specific function type.

    Parameters:
    func_name : str
        Name of the function to test ('quadratic', 'neg_log', or 'ql')
    n_pts : int
        Number of points per dimension
    max_dim : int
        Maximum dimension to test
    d_L : int
        Threshold dimension: for d >= d_L, only run the Lucet algorithm (skip direct method)
    """

    # Select function based on name
    if func_name == 'quadratic':
        f_creator = quadratic
        x_range = (-3.0, 3.0)
        s_range = (-3.0, 3.0)
    elif func_name == 'neg_log':
        f_creator = neg_log
        x_range = (0.1, 5.0)  # Domain constraint: x > 0
        s_range = (-5.0, -0.1)  # Range constraint for neg_log
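    elif func_name == 'ql':
        f_creator = ql_func
        # NOTE: on this range sum(x_i) + 1 can be zero or negative, where the
        # quotient divides by zero and u is not convex; NumPy may emit
        # RuntimeWarnings and the two methods can then disagree.
        x_range = (-3.0, 3.0)
        s_range = (-3.0, 3.0)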
    else:
        raise ValueError(f"Unknown function type: {func_name}")

    print(f"\n===== BENCHMARKING {func_name.upper()} FUNCTION =====\n")
    print(f"{'d':>2} | {'t_Lucet(s)':>10} | {'t_Direct(s)':>12} | {'MB_L(act)':>9} | {'MB_D(act)':>9} | {'MB_L(min)':>9} | {'MB_D(min)':>9} | "
          f"{'max|L-D|':>10} | {'err interp':>11}")
    print("-"*110)

    rng = np.random.default_rng(0)

    for d in range(1, max_dim+1):
        x_arrs = [np.linspace(x_range[0], x_range[1], n_pts)]*d
        s_arrs = [np.linspace(s_range[0], s_range[1], n_pts)]*d
        total  = n_pts**d
        f      = f_creator(d)

        # Lucet algorithm (always run)
        t0=time.perf_counter(); U_luc=lucet_nd(x_arrs,f,s_arrs)
        t_luc=time.perf_counter()-t0

        # Run direct method only if below the dimension threshold
        run_direct = d < d_L
        if run_direct:
            t0=time.perf_counter(); U_dir=direct_nd(x_arrs,f,s_arrs)
            t_dir=time.perf_counter()-t0
            disc=np.max(np.abs(U_luc-U_dir))
        else:
            t_dir = float('inf')  # Indicate direct method not run
            U_dir = None
            disc = float('nan')  # No comparison possible

        # Actual memory usage in the implementation
        mem_luc_act = mem_mb(total) * 2       # Two N^d arrays: V and out
        mem_dir_act = mem_mb(total) * (d + 2) if run_direct else float('inf')  # Indicate not calculated

        # Theoretical minimum memory usage
        mem_luc_min = mem_mb(total)  # One N^d array (reused)
        mem_dir_min = mem_mb(total)  # One N^d array (output only)

        # Calculate error against analytical solution if available
        err_str = "N/A"

        # Calculate error based on function type
        try:
            if func_name == 'quadratic':
                # For quadratic function: u*(s) = 0.5*sum(s_i^2)
                err = 0.0
                for _ in range(100):
                    s_vec = rng.uniform(s_range[0], s_range[1], size=d)
                    v_luc = interp_nd(U_luc, s_arrs, s_vec)
                    v_true = 0.5 * np.sum(s_vec**2)
                    err = max(err, abs(v_luc - v_true))
                err_str = f"{err:.3e}"

            elif func_name == 'neg_log':
                if d == 1:
                    # For 1D neg_log, directly compare with analytical (-1-log(-s))
                    analytical = np.array([-1 - np.log(-si) for si in s_arrs[0]])
                    error = np.max(np.abs(U_luc - analytical))
                    err_str = f"{error:.3e}"
                else:
                    # For higher dimensions, calculate analytical at sample points
                    err = 0.0
                    for _ in range(100):
                        s_vec = rng.uniform(s_range[0], s_range[1], size=d)
                        v_luc = interp_nd(U_luc, s_arrs, s_vec)

                        # Calculate analytical value: -d - sum(log(-s_i))
                        v_true = -d
                        for s in s_vec:
                            v_true -= np.log(-s)

                        err = max(err, abs(v_luc - v_true))
                    err_str = f"{err:.3e}"
        except Exception as e:
            print(f"Error calculating error for {func_name} in dimension {d}: {e}")

        # Format for display
        if run_direct:
            print(f"{d:2d} | {t_luc:10.3f} | {t_dir:12.3f} | "
                f"{mem_luc_act:9.1f} | {mem_dir_act:9.1f} | {mem_luc_min:9.1f} | {mem_dir_min:9.1f} | "
                f"{disc:10.3e} | {err_str:>11}")
        else:
            print(f"{d:2d} | {t_luc:10.3f} | {'N/A':>12} | "
                f"{mem_luc_act:9.1f} | {'N/A':>9} | {mem_luc_min:9.1f} | {mem_dir_min:9.1f} | "
                f"{'N/A':>10} | {err_str:>11}")

        # Create visualizations when both methods are run
        if run_direct:
            if d == 1:
                plot_1d_results(x_arrs, f, s_arrs, U_dir, U_luc, func_name)
            elif d == 2:
                plot_2d_results(s_arrs, U_dir, U_luc, func_name)
            else:
                plot_high_dim_results(d, s_arrs, U_dir, U_luc, func_name)

            plt.savefig(f'{func_name}_{d}d_transform_results.png', dpi=300)
            plt.close()
        else:
            print(f"  Direct method skipped for d={d} (d ≥ {d_L}).")

    print(f"\nBenchmark complete for {func_name} function.")

# ---------- run combined benchmarks ----------------------------------
if __name__ == "__main__":
    # Fix matplotlib imports
    import matplotlib.colors as colors

    # Run benchmarks for all three function types
    for func in ['quadratic', 'neg_log', 'ql']:
        try:
            # Run both methods for d < d_L (dimensions 1-5), then only Lucet for dimensions 6-8
            run_benchmark(func_name=func, n_pts=10, max_dim=8, d_L=6)
        except Exception as e:
            print(f"Error running benchmark for {func}: {e}")

    # Summary table
    print("\n===== FUNCTION COMPARISON SUMMARY =====\n")
    print("Function     | Characteristics                        | Legendre Transform")
    print("-------------|----------------------------------------|--------------------")
    print("Quadratic    | u(x) = 0.5∑x_i²                        | u*(s) = 0.5∑s_i² (analytical)")
    print("Neg-Log      | u(x) = -∑log(x_i), domain: x_i > 0     | u*(s) = -d-∑log(-s_i), domain: s_i < 0")
    print("Quad-Linear  | u(x) = (∑x_i² + 1)/(∑x_i + 1)          | No simple analytical form")

    print("\nThe Nested Lucet method provides consistent computational advantages for all functions")
    print("as dimensionality increases, while maintaining numerical stability and accuracy.")

    # Estimated timing prediction table
    print("\n===== PREDICTED DIRECT METHOD TIMING FOR HIGHER DIMENSIONS =====\n")
    print("(Based on O(N^(2d)) complexity for Direct vs. O(dN^(d+1)) for Lucet)\n")
    print("Dimension | Points/dim | Direct Method (est.) | Lucet Method (est.) | Speedup Factor")
    print("----------|------------|----------------------|---------------------|---------------")

    n_pts = 10
    # Reference timing from dimension 4 (adjust if needed)
    ref_d = 4
    ref_direct_time = 0.5  # seconds, approximate
    ref_lucet_time = 0.05  # seconds, approximate

    for d in range(ref_d+1, 15):
        # Calculate scaling factors
        direct_scale = (n_pts**(2*d)) / (n_pts**(2*ref_d))
        lucet_scale = (d * n_pts**(d+1)) / (ref_d * n_pts**(ref_d+1))

        direct_est = ref_direct_time * direct_scale
        lucet_est = ref_lucet_time * lucet_scale
        speedup = direct_est / lucet_est

        # Format time with appropriate units
        if direct_est < 60:
            direct_str = f"{direct_est:.2f} seconds"
        elif direct_est < 3600:
            direct_str = f"{direct_est/60:.2f} minutes"
        elif direct_est < 86400:
            direct_str = f"{direct_est/3600:.2f} hours"
        elif direct_est < 31536000:
            direct_str = f"{direct_est/86400:.2f} days"
        else:
            direct_str = f"{direct_est/31536000:.2f} years"

        if lucet_est < 60:
            lucet_str = f"{lucet_est:.2f} seconds"
        elif lucet_est < 3600:
            lucet_str = f"{lucet_est/60:.2f} minutes"
        else:
            lucet_str = f"{lucet_est/3600:.2f} hours"

        print(f"{d:10} | {n_pts:10} | {direct_str:20} | {lucet_str:19} | {speedup:.2e}")

    print("\nNote: These are theoretical estimates. Actual performance may vary based on hardware and implementation.")
    print("The estimates demonstrate why the direct method becomes infeasible for higher dimensions.")


===== BENCHMARKING QUADRATIC FUNCTION =====

 d | t_Lucet(s) |  t_Direct(s) | MB_L(act) | MB_D(act) | MB_L(min) | MB_D(min) |   max|L-D| |  err interp
--------------------------------------------------------------------------------------------------------------
 1 |      0.000 |        0.001 |       0.0 |       0.0 |       0.0 |       0.0 |  0.000e+00 |   5.556e-02
 2 |      0.002 |        0.029 |       0.0 |       0.0 |       0.0 |       0.0 |  8.882e-16 |   1.107e-01
 3 |      0.054 |        0.113 |       0.0 |       0.0 |       0.0 |       0.0 |  1.776e-15 |   1.655e-01
 4 |      0.172 |        0.631 |       0.2 |       0.5 |       0.1 |       0.1 |  3.553e-15 |   2.111e-01
 5 |      2.069 |       89.035 |       1.5 |       5.3 |       0.8 |       0.8 |  7.105e-15 |   2.666e-01
 6 |     26.613 |          N/A |      15.3 |       N/A |       7.6 |       7.6 |        N/A |   3.271e-01
  Direct method skipped for d=6 (d ≥ 6).
 7 |    311.651 |          N/A |     152.6 |       N/A |      76.3 |      76.3 |        N/A |   3.379e-01
  Direct method skipped for d=7 (d ≥ 6).
 8 |   3554.039 |          N/A |    1525.9 |       N/A |     762.9 |     762.9 |        N/A |   3.865e-01
  Direct method skipped for d=8 (d ≥ 6).

Benchmark complete for quadratic function.

===== BENCHMARKING NEG_LOG FUNCTION =====

 d | t_Lucet(s) |  t_Direct(s) | MB_L(act) | MB

  return numerator / denominator
  return numerator / denominator


 2 |      0.001 |        0.001 |       0.0 |       0.0 |       0.0 |       0.0 |  7.105e-15 |         N/A
 3 |      0.015 |        0.020 |       0.0 |       0.0 |       0.0 |       0.0 |  1.600e+01 |         N/A
 4 |      0.288 |        0.979 |       0.2 |       0.5 |       0.1 |       0.1 |  1.421e-14 |         N/A
 5 |      2.203 |       79.730 |       1.5 |       5.3 |       0.8 |       0.8 |  3.200e+01 |         N/A
 6 |     28.108 |          N/A |      15.3 |       N/A |       7.6 |       7.6 |        N/A |         N/A
  Direct method skipped for d=6 (d ≥ 6).
 7 |    336.400 |          N/A |     152.6 |       N/A |      76.3 |      76.3 |        N/A |         N/A
  Direct method skipped for d=7 (d ≥ 6).
 8 |   3717.800 |          N/A |    1525.9 |       N/A |     762.9 |     762.9 |        N/A |         N/A
  Direct method skipped for d=8 (d ≥ 6).

Benchmark complete for ql function.

===== FUNCTION COMPARISON SUMMARY =====

Function     | Characteristics                        | 

In [1]:
import time, numpy as np
import matplotlib.pyplot as plt
from itertools import product

# ---------- 1‑D Lucet  -------------------------------------------------
def llt_1d(x,u,s):
    x,u,s = map(np.asarray,(x,u,s))
    hx,hu=np.empty_like(x),np.empty_like(u); h=0
    for xi,ui in zip(x,u):
        while h>=2 and (hu[h-1]-hu[h-2])*(xi-hx[h-1]) >= (ui-hu[h-1])*(hx[h-1]-hx[h-2]):
            h-=1
        hx[h],hu[h]=xi,ui; h+=1
    hx,hu=hx[:h],hu[:h]
    edge=np.concatenate(([-np.inf],np.diff(hu)/np.diff(hx),[np.inf]))
    out,k=np.empty_like(s),0
    for j,sj in enumerate(s):
        while sj>edge[k+1]: k+=1
        out[j]=sj*hx[k]-hu[k]
    return out

# ---------- nested Lucet any‑d ----------------------------------------
def lucet_nd(x_arrs,f,s_arrs):
    V=f(*np.meshgrid(*x_arrs,indexing='ij',sparse=False))
    flip=False
    for axis in reversed(range(len(x_arrs))):
        x,s=x_arrs[axis],s_arrs[axis]
        V=np.moveaxis(V,axis,0)
        out=np.empty((len(s),)+V.shape[1:],float)
        it=np.nditer(V[0],flags=['multi_index'])
        while not it.finished:
            idx=it.multi_index
            line=V[(slice(None),)+idx]
            out[(slice(None),)+idx]=llt_1d(x,-line if flip else line,s)
            it.iternext()
        V=np.moveaxis(out,0,axis); flip=True
    return V

# ---------- Direct on the same slope grid -----------------------------
def direct_nd(x_arrs,f,s_arrs):
    X=np.meshgrid(*x_arrs,indexing='ij',sparse=False)
    U=f(*X)
    out=np.empty(tuple(len(sa) for sa in s_arrs))
    it=np.nditer(out,flags=['multi_index'],op_flags=['writeonly'])
    while not it.finished:
        slopes=[s_arrs[k][it.multi_index[k]] for k in range(len(x_arrs))]
        it[0]=np.max(sum(s*Xk for s,Xk in zip(slopes,X))-U)
        it.iternext()
    return out

# ---------- multilinear interpolation ---------------------------------
def interp_nd(grid, axes, pt):
    idx_low, t = [], []
    for p,ax in zip(pt,axes):
        j=np.searchsorted(ax,p)-1
        j=np.clip(j,0,len(ax)-2)
        idx_low.append(j)
        t.append((p-ax[j])/(ax[j+1]-ax[j]))
    val=0.0
    for corners in product((0,1), repeat=len(pt)):
        w, idx = 1.0, []
        for c,tl,jl in zip(corners,t,idx_low):
            w*=tl if c else 1-tl
            idx.append(jl+c)
        val+=w*grid[tuple(idx)]
    return val

# ---------- test functions -------------------------------------------
# Each factory takes the dimension d for a uniform interface; the returned
# callables accept any number of coordinate arrays, so d itself is unused.
def quadratic(d):
    """d-dimensional quadratic function: u(x) = 0.5 * sum(x_i^2)"""
    return lambda *x: 0.5*sum(np.asarray(xi)**2 for xi in x)

def neg_log(d):
    """d-dimensional negative logarithm function: u(x) = -sum(log(x_i))"""
    def func(*args):
        # Domain is x_i > 0; clamp inputs so that log never sees 0 or a negative value
        result = 0
        for x in args:
            x_safe = np.maximum(x, 1e-10)
            result -= np.log(x_safe)
        return result
    return func

def ql_func(d):
    """d-dimensional quadratic-over-linear function: u(x) = (sum(x_i^2) + 1) / (sum(x_i) + 1)

    Note: the denominator is zero or negative wherever sum(x_i) <= -1, so a grid
    that includes such points can produce division warnings and inf/nan values.
    """
    def func(*args):
        numerator = sum(x**2 for x in args) + 1
        denominator = sum(args) + 1
        return numerator / denominator
    return func
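
# ---------- illustrative check: closed-form conjugates ----------------
# Sketch only. The closed forms used later as references, u*(s) = 0.5*s^2 for
# the quadratic and u*(s) = -1 - log(-s) for the negative log (s < 0), can be
# checked in 1D by a brute-force maximum on a dense grid. The helper name
# `demo_conjugates_1d`, the grid sizes and the sample slopes are assumptions
# made for this example.
import numpy as np

def demo_conjugates_1d(s_quad=1.5, s_log=-2.0):
    x = np.linspace(-3.0, 3.0, 6001)
    quad = np.max(s_quad * x - 0.5 * x**2)    # sup_x {s*x - 0.5*x^2}
    xp = np.linspace(0.01, 5.0, 49901)
    nlog = np.max(s_log * xp + np.log(xp))    # sup_x {s*x - (-log x)}
    # Numerical and closed-form values, pairwise
    return quad, 0.5 * s_quad**2, nlog, -1 - np.log(-s_log)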

def mem_mb(n_elems):
    """Size in MB (1024**2 bytes) of n_elems float64 values at 8 bytes each."""
    return n_elems*8/1024/1024
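
# Worked arithmetic for the estimate above (sketch; the helper name
# `demo_memory_table` is an assumption made for this example). One N^d array
# of 8-byte floats with N = 10 takes about 0.08 MB at d = 4, 7.6 MB at d = 6
# and 762.9 MB at d = 8.
def demo_memory_table(n_pts=10, dims=(4, 6, 8)):
    # Same formula as mem_mb, repeated so the example runs on its own
    return {d: n_pts**d * 8 / 1024 / 1024 for d in dims}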

# ---------- plotting functions ----------------------------------------
def plot_1d_results(x_arrays, f, s_arrays, u_star_direct, u_star_nested, func_name):
    """Plot results for 1D transform."""
    plt.figure(figsize=(15, 5))

    # Plot original function
    plt.subplot(1, 3, 1)
    x = x_arrays[0]
    u_x = np.array([f(xi) for xi in x])
    plt.plot(x, u_x)
    plt.xlabel('x')
    plt.ylabel('u(x)')
    plt.title(f'Original {func_name.capitalize()} Function')
    plt.grid(True)

    # Analytical transform, where a closed form is known
    s = s_arrays[0]
    if func_name == 'quadratic':
        u_star_analytical = 0.5 * s**2
    elif func_name == 'neg_log':
        u_star_analytical = -1 - np.log(-s)   # valid for s < 0
    else:
        u_star_analytical = None

    # Plot transforms
    plt.subplot(1, 3, 2)
    plt.plot(s, u_star_direct, 'b-', label='Direct')
    plt.plot(s, u_star_nested, 'r--', label='Nested')
    if u_star_analytical is not None:
        plt.plot(s, u_star_analytical, 'g:', label='Analytical')

    plt.xlabel('s')
    plt.ylabel('u*(s)')
    plt.title('Legendre Transform')
    plt.legend()
    plt.grid(True)

    # Plot the error against the analytical solution, or the method difference otherwise
    plt.subplot(1, 3, 3)
    if u_star_analytical is not None:
        plt.semilogy(s, np.abs(u_star_direct - u_star_analytical), 'b-', label='Direct Error')
        plt.semilogy(s, np.abs(u_star_nested - u_star_analytical), 'r--', label='Nested Error')
        plt.title('Error vs Analytical')
    else:
        plt.semilogy(s, np.abs(u_star_direct - u_star_nested), 'k-', label='Method Difference')
        plt.title('Method Difference')

    plt.xlabel('s')
    plt.ylabel('Absolute Error/Difference')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()

def plot_2d_results(s_arrays, u_star_direct, u_star_nested, func_name):
    """Plot results for 2D transform."""
    fig = plt.figure(figsize=(15, 10))

    # Create meshgrid for visualization
    S1, S2 = np.meshgrid(s_arrays[0], s_arrays[1], indexing='ij')

    # Generate analytical solution if available
    if func_name == 'quadratic':
        u_star_analytical = 0.5 * (S1**2 + S2**2)
    elif func_name == 'neg_log':
        # u*(s) = -2 - log(-s1) - log(-s2), evaluated on the whole grid at once
        u_star_analytical = -2 - np.log(-S1) - np.log(-S2)
    else:
        u_star_analytical = None

    # Plot direct transform
    ax1 = fig.add_subplot(2, 2, 1, projection='3d')
    surf1 = ax1.plot_surface(S1, S2, u_star_direct, cmap='viridis', alpha=0.8)
    ax1.set_xlabel('s₁')
    ax1.set_ylabel('s₂')
    ax1.set_zlabel('u*(s₁,s₂)')
    ax1.set_title(f'Direct Method - {func_name.capitalize()}')

    # Plot nested transform
    ax2 = fig.add_subplot(2, 2, 2, projection='3d')
    surf2 = ax2.plot_surface(S1, S2, u_star_nested, cmap='viridis', alpha=0.8)
    ax2.set_xlabel('s₁')
    ax2.set_ylabel('s₂')
    ax2.set_zlabel('u*(s₁,s₂)')
    ax2.set_title(f'Nested Method - {func_name.capitalize()}')

    # Plot errors or differences
    if u_star_analytical is not None:
        # Mask out invalid values for neg_log
        if func_name == 'neg_log':
            valid_mask = ~np.isnan(u_star_analytical)
            if not np.any(valid_mask):
                u_star_analytical = None

        if u_star_analytical is not None:
            # Create error plots
            error_direct = np.abs(u_star_direct - u_star_analytical)
            error_nested = np.abs(u_star_nested - u_star_analytical)

            # Mask invalid values if needed
            if func_name == 'neg_log':
                error_direct = np.where(valid_mask, error_direct, np.nan)
                error_nested = np.where(valid_mask, error_nested, np.nan)

            ax3 = fig.add_subplot(2, 2, 3, projection='3d')
            surf3 = ax3.plot_surface(S1, S2, error_direct, cmap='plasma', alpha=0.8,
                                    norm=colors.LogNorm() if np.max(error_direct[~np.isnan(error_direct)]) > 0 else None)
            ax3.set_xlabel('s₁')
            ax3.set_ylabel('s₂')
            ax3.set_zlabel('Error')
            ax3.set_title('Direct Method Error')

            ax4 = fig.add_subplot(2, 2, 4, projection='3d')
            surf4 = ax4.plot_surface(S1, S2, error_nested, cmap='plasma', alpha=0.8,
                                    norm=colors.LogNorm() if np.max(error_nested[~np.isnan(error_nested)]) > 0 else None)
            ax4.set_xlabel('s₁')
            ax4.set_ylabel('s₂')
            ax4.set_zlabel('Error')
            ax4.set_title('Nested Method Error')

    # Show difference between methods if no analytical solution
    else:
        difference = np.abs(u_star_direct - u_star_nested)
        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        surf3 = ax3.plot_surface(S1, S2, difference, cmap='plasma', alpha=0.8,
                                norm=colors.LogNorm() if np.max(difference) > 0 else None)
        ax3.set_xlabel('s₁')
        ax3.set_ylabel('s₂')
        ax3.set_zlabel('Difference')
        ax3.set_title('Method Difference')

        # Cross-section comparison
        mid_idx = len(s_arrays[1]) // 2
        ax4 = fig.add_subplot(2, 2, 4)
        ax4.plot(s_arrays[0], u_star_direct[:, mid_idx], 'b-', label='Direct')
        ax4.plot(s_arrays[0], u_star_nested[:, mid_idx], 'r--', label='Nested')
        ax4.set_xlabel('s₁')
        ax4.set_ylabel(f'u*(s₁,{s_arrays[1][mid_idx]:.2f})')
        ax4.set_title('Cross-section Comparison')
        ax4.legend()
        ax4.grid(True)

    plt.tight_layout()

def plot_high_dim_results(d, s_arrays, u_star_direct, u_star_nested, func_name):
    """Plot 2D slices for higher dimensional transforms (d ≥ 3)."""
    fig = plt.figure(figsize=(15, 10))

    # Create 2D slice indices (middle of each dimension > 2)
    slice_indices = [len(s) // 2 for s in s_arrays[2:]]

    # Create meshgrid for first two dimensions
    S1, S2 = np.meshgrid(s_arrays[0], s_arrays[1], indexing='ij')

    # Create indexing tuple for slicing higher dimensions
    idx_slice = (slice(None), slice(None)) + tuple(slice_indices)

    # Plot direct transform (slice)
    ax1 = fig.add_subplot(2, 2, 1, projection='3d')
    slice_title = f"Direct Method - {func_name.capitalize()} (s₃"
    for dim in range(3, d):
        slice_title += f",s{dim+1}"
    slice_title += " slice)"

    surf1 = ax1.plot_surface(S1, S2, u_star_direct[idx_slice], cmap='viridis', alpha=0.8)
    ax1.set_xlabel('s₁')
    ax1.set_ylabel('s₂')
    ax1.set_zlabel('u*(s₁,s₂,...)')
    ax1.set_title(slice_title)

    # Plot nested transform (slice)
    ax2 = fig.add_subplot(2, 2, 2, projection='3d')
    slice_title = slice_title.replace("Direct", "Nested")
    surf2 = ax2.plot_surface(S1, S2, u_star_nested[idx_slice], cmap='viridis', alpha=0.8)
    ax2.set_xlabel('s₁')
    ax2.set_ylabel('s₂')
    ax2.set_zlabel('u*(s₁,s₂,...)')
    ax2.set_title(slice_title)

    # Generate analytical solution for the slice (if available)
    if func_name == 'quadratic':
        u_star_analytical_slice = 0.5 * (S1**2 + S2**2 + sum(s_arrays[i][slice_indices[i-2]]**2 for i in range(2, d)))

        # Plot errors
        error_direct = np.abs(u_star_direct[idx_slice] - u_star_analytical_slice)
        error_nested = np.abs(u_star_nested[idx_slice] - u_star_analytical_slice)

        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        surf3 = ax3.plot_surface(S1, S2, error_direct, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(error_direct) > 0 else None)
        ax3.set_xlabel('s₁')
        ax3.set_ylabel('s₂')
        ax3.set_zlabel('Error')
        ax3.set_title('Direct Method Error (slice)')

        ax4 = fig.add_subplot(2, 2, 4, projection='3d')
        surf4 = ax4.plot_surface(S1, S2, error_nested, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(error_nested) > 0 else None)
        ax4.set_xlabel('s₁')
        ax4.set_ylabel('s₂')
        ax4.set_zlabel('Error')
        ax4.set_title('Nested Method Error (slice)')

    elif func_name == 'neg_log':
        # Analytical solution on the slice: -d - sum(log(-s_i)), with s_3..s_d held fixed
        fixed_slopes = [s_arrays[i][slice_indices[i-2]] for i in range(2, d)]
        u_star_analytical_slice = (-d - np.log(-S1) - np.log(-S2)
                                   - sum(np.log(-s) for s in fixed_slopes))

        # Plot errors
        error_direct = np.abs(u_star_direct[idx_slice] - u_star_analytical_slice)
        error_nested = np.abs(u_star_nested[idx_slice] - u_star_analytical_slice)

        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        surf3 = ax3.plot_surface(S1, S2, error_direct, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(error_direct) > 0 else None)
        ax3.set_xlabel('s₁')
        ax3.set_ylabel('s₂')
        ax3.set_zlabel('Error')
        ax3.set_title('Direct Method Error (slice)')

        ax4 = fig.add_subplot(2, 2, 4, projection='3d')
        surf4 = ax4.plot_surface(S1, S2, error_nested, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(error_nested) > 0 else None)
        ax4.set_xlabel('s₁')
        ax4.set_ylabel('s₂')
        ax4.set_zlabel('Error')
        ax4.set_title('Nested Method Error (slice)')

    else:
        # For functions without analytical solution, show difference between methods
        difference = np.abs(u_star_direct[idx_slice] - u_star_nested[idx_slice])
        ax3 = fig.add_subplot(2, 2, 3, projection='3d')
        surf3 = ax3.plot_surface(S1, S2, difference, cmap='plasma', alpha=0.8,
                               norm=colors.LogNorm() if np.max(difference) > 0 else None)
        ax3.set_xlabel('s₁')
        ax3.set_ylabel('s₂')
        ax3.set_zlabel('Difference')
        ax3.set_title('Method Difference (slice)')

        # One more view - cross section at middle of second dimension
        mid_idx = len(s_arrays[1]) // 2

        if d == 3:
            # For 3D, create a second slice view (s1 vs s3, fixing s2)
            S1, S3 = np.meshgrid(s_arrays[0], s_arrays[2], indexing='ij')
            cross_section = u_star_direct[:, mid_idx, :]
            ax4 = fig.add_subplot(2, 2, 4, projection='3d')
            surf4 = ax4.plot_surface(S1, S3, cross_section, cmap='viridis', alpha=0.8)
            ax4.set_xlabel('s₁')
            ax4.set_ylabel('s₃')
            ax4.set_zlabel(f'u*(s₁,{s_arrays[1][mid_idx]:.2f},s₃)')
            ax4.set_title('Direct Method (s₂ slice)')
        else:
            # For dimension > 3, show a 1D cross-section
            cross_idx = [slice(None), mid_idx] + slice_indices
            cross_idx = tuple(cross_idx)

            ax4 = fig.add_subplot(2, 2, 4)
            ax4.plot(s_arrays[0], u_star_direct[cross_idx], 'b-', label='Direct')
            ax4.plot(s_arrays[0], u_star_nested[cross_idx], 'r--', label='Nested')
            # No analytical curve here: this branch is only reached for
            # functions without a closed-form transform.

            ax4.set_xlabel('s₁')
            ax4.set_ylabel(f'u*(s₁,{s_arrays[1][mid_idx]:.2f},...)')
            ax4.set_title('Cross-section Comparison')
            ax4.legend()
            ax4.grid(True)

    plt.tight_layout()

# ---------- run benchmark --------------------------------------------
def run_benchmark(func_name='quadratic', n_pts=10, max_dim=5, d_L=6):
    """Run benchmark for a specific function type.

    Parameters:
    func_name : str
        Name of the function to test ('quadratic', 'neg_log', or 'ql')
    n_pts : int
        Number of points per dimension
    max_dim : int
        Maximum dimension to test
    d_L : int
        Threshold dimension: for d >= d_L, only run the Lucet algorithm (skip direct method)
    """

    # Select function based on name
    if func_name == 'quadratic':
        f_creator = quadratic
        x_range = (-3.0, 3.0)
        s_range = (-3.0, 3.0)
    elif func_name == 'neg_log':
        f_creator = neg_log
        x_range = (0.1, 5.0)  # Domain constraint: x > 0
        s_range = (-5.0, -0.1)  # Range constraint for neg_log
    elif func_name == 'ql':
        f_creator = ql_func
        x_range = (-3.0, 3.0)
        s_range = (-3.0, 3.0)
    else:
        raise ValueError(f"Unknown function type: {func_name}")

    print(f"\n===== BENCHMARKING {func_name.upper()} FUNCTION =====\n")
    print(f"{'d':>2} | {'t_Lucet(s)':>10} | {'t_Direct(s)':>12} | {'MB_L(act)':>9} | {'MB_D(act)':>9} | {'MB_L(min)':>9} | {'MB_D(min)':>9} | "
          f"{'max|L-D|':>10} | {'err interp':>11} | {'RMSE':>11}")
    print("-"*124)

    rng = np.random.default_rng(0)

    for d in range(1, max_dim+1):
        x_arrs = [np.linspace(x_range[0], x_range[1], n_pts)]*d
        s_arrs = [np.linspace(s_range[0], s_range[1], n_pts)]*d
        total  = n_pts**d
        f      = f_creator(d)

        # Lucet algorithm (always run)
        t0=time.perf_counter(); U_luc=lucet_nd(x_arrs,f,s_arrs)
        t_luc=time.perf_counter()-t0

        # Run direct method only if below the dimension threshold
        run_direct = d < d_L
        if run_direct:
            t0=time.perf_counter(); U_dir=direct_nd(x_arrs,f,s_arrs)
            t_dir=time.perf_counter()-t0
            disc=np.max(np.abs(U_luc-U_dir))
        else:
            t_dir = float('inf')  # Indicate direct method not run
            U_dir = None
            disc = float('nan')  # No comparison possible

        # Actual memory usage in the implementation
        mem_luc_act = mem_mb(total) * 2       # Two N^d arrays: V and out
        mem_dir_act = mem_mb(total) * (d + 2) if run_direct else float('inf')  # Indicate not calculated

        # Theoretical minimum memory usage
        mem_luc_min = mem_mb(total)  # One N^d array (reused)
        mem_dir_min = mem_mb(total)  # One N^d array (output only)

        # Calculate error against analytical solution if available
        err_str = "N/A"
        rmse_str = "N/A"  # Added RMSE string

        # Calculate error based on function type
        try:
            if func_name == 'quadratic':
                # For quadratic function: u*(s) = 0.5*sum(s_i^2)
                err = 0.0
                errors = []  # Store all errors for RMSE
                for _ in range(100):
                    s_vec = rng.uniform(s_range[0], s_range[1], size=d)
                    v_luc = interp_nd(U_luc, s_arrs, s_vec)
                    v_true = 0.5 * np.sum(s_vec**2)
                    error = abs(v_luc - v_true)
                    errors.append(error)  # Store error for RMSE
                    err = max(err, error)
                err_str = f"{err:.3e}"

                # Calculate RMSE
                rmse = np.sqrt(np.mean(np.array(errors)**2))
                rmse_str = f"{rmse:.3e}"

            elif func_name == 'neg_log':
                if d == 1:
                    # For 1D neg_log, directly compare with analytical (-1-log(-s))
                    analytical = np.array([-1 - np.log(-si) for si in s_arrs[0]])
                    errors = np.abs(U_luc - analytical)
                    error = np.max(errors)
                    err_str = f"{error:.3e}"

                    # Calculate RMSE
                    rmse = np.sqrt(np.mean(errors**2))
                    rmse_str = f"{rmse:.3e}"
                else:
                    # For higher dimensions, calculate analytical at sample points
                    err = 0.0
                    errors = []  # Store all errors for RMSE
                    for _ in range(100):
                        s_vec = rng.uniform(s_range[0], s_range[1], size=d)
                        v_luc = interp_nd(U_luc, s_arrs, s_vec)

                        # Calculate analytical value: -d - sum(log(-s_i))
                        v_true = -d
                        for s in s_vec:
                            v_true -= np.log(-s)

                        error = abs(v_luc - v_true)
                        errors.append(error)  # Store error for RMSE
                        err = max(err, error)
                    err_str = f"{err:.3e}"

                    # Calculate RMSE
                    rmse = np.sqrt(np.mean(np.array(errors)**2))
                    rmse_str = f"{rmse:.3e}"
        except Exception as e:
            print(f"Error calculating error for {func_name} in dimension {d}: {e}")

        # Format for display
        if run_direct:
            print(f"{d:2d} | {t_luc:10.3f} | {t_dir:12.3f} | "
                f"{mem_luc_act:9.1f} | {mem_dir_act:9.1f} | {mem_luc_min:9.1f} | {mem_dir_min:9.1f} | "
                f"{disc:10.3e} | {err_str:>11} | {rmse_str:>11}")
        else:
            print(f"{d:2d} | {t_luc:10.3f} | {'N/A':>12} | "
                f"{mem_luc_act:9.1f} | {'N/A':>9} | {mem_luc_min:9.1f} | {mem_dir_min:9.1f} | "
                f"{'N/A':>10} | {err_str:>11} | {rmse_str:>11}")

        # Create visualizations when both methods are run
        if run_direct:
            if d == 1:
                plot_1d_results(x_arrs, f, s_arrs, U_dir, U_luc, func_name)
            elif d == 2:
                plot_2d_results(s_arrs, U_dir, U_luc, func_name)
            else:
                plot_high_dim_results(d, s_arrs, U_dir, U_luc, func_name)

            plt.savefig(f'{func_name}_{d}d_transform_results.png', dpi=300)
            plt.close()
        else:
            print(f"  Direct method skipped for d={d} (d ≥ {d_L}).")

    print(f"\nBenchmark complete for {func_name} function.")

# ---------- run combined benchmarks ----------------------------------
if __name__ == "__main__":
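    # The plotting helpers use `colors.LogNorm`; importing here makes `colors`
    # a module-level name when the file is run as a script
    import matplotlib.colors as colors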

    # Run benchmarks for all three function types
    for func in ['quadratic', 'neg_log', 'ql']:
        try:
            # Both methods run for d < d_L (here d = 1-5); only Lucet runs for d = 6-8
            run_benchmark(func_name=func, n_pts=10, max_dim=8, d_L=6)
        except Exception as e:
            print(f"Error running benchmark for {func}: {e}")

    # Summary table
    print("\n===== FUNCTION COMPARISON SUMMARY =====\n")
    print("Function     | Characteristics                        | Legendre Transform")
    print("-------------|----------------------------------------|--------------------")
    print("Quadratic    | u(x) = 0.5∑x_i²                        | u*(s) = 0.5∑s_i² (analytical)")
    print("Neg-Log      | u(x) = -∑log(x_i), domain: x_i > 0     | u*(s) = -d-∑log(-s_i), domain: s_i < 0")
    print("Quad-Linear  | u(x) = (∑x_i² + 1)/(∑x_i + 1)          | No simple analytical form")

    print("\nThe nested Lucet method becomes faster than the direct method as dimensionality increases.")
    print("Check the max|L-D| column for each function to confirm that the two methods agree.")

    # Estimated timing prediction table
    print("\n===== PREDICTED DIRECT METHOD TIMING FOR HIGHER DIMENSIONS =====\n")
    print("(Based on O(N^(2d)) complexity for Direct vs. O(dN^(d+1)) for Lucet)\n")
    print("Dimension | Points/dim | Direct Method (est.) | Lucet Method (est.) | Speedup Factor")
    print("----------|------------|----------------------|---------------------|---------------")

    n_pts = 10
    # Reference timing from dimension 4 (adjust if needed)
    ref_d = 4
    ref_direct_time = 0.5  # seconds, approximate
    ref_lucet_time = 0.05  # seconds, approximate

    def fmt_time(t):
        """Format a duration in seconds using the largest suitable unit."""
        for limit, div, unit in ((60, 1, "seconds"), (3600, 60, "minutes"),
                                 (86400, 3600, "hours"), (31536000, 86400, "days")):
            if t < limit:
                return f"{t/div:.2f} {unit}"
        return f"{t/31536000:.2f} years"

    for d in range(ref_d+1, 15):
        # Calculate scaling factors
        direct_scale = (n_pts**(2*d)) / (n_pts**(2*ref_d))
        lucet_scale = (d * n_pts**(d+1)) / (ref_d * n_pts**(ref_d+1))

        direct_est = ref_direct_time * direct_scale
        lucet_est = ref_lucet_time * lucet_scale
        speedup = direct_est / lucet_est

        # Use the same unit formatting for both columns
        direct_str = fmt_time(direct_est)
        lucet_str = fmt_time(lucet_est)

        print(f"{d:10} | {n_pts:10} | {direct_str:20} | {lucet_str:19} | {speedup:.2e}")

    print("\nNote: These are theoretical estimates. Actual performance may vary based on hardware and implementation.")
    print("The estimates demonstrate why the direct method becomes infeasible for higher dimensions.")


===== BENCHMARKING QUADRATIC FUNCTION =====

 d | t_Lucet(s) |  t_Direct(s) | MB_L(act) | MB_D(act) | MB_L(min) | MB_D(min) |   max|L-D| |  err interp |        RMSE
----------------------------------------------------------------------------------------------------------------------------
 1 |      0.001 |        0.000 |       0.0 |       0.0 |       0.0 |       0.0 |  0.000e+00 |   5.556e-02 |   4.106e-02
 2 |      0.002 |        0.002 |       0.0 |       0.0 |       0.0 |       0.0 |  8.882e-16 |   1.107e-01 |   8.006e-02
 3 |      0.012 |        0.018 |       0.0 |       0.0 |       0.0 |       0.0 |  1.776e-15 |   1.655e-01 |   1.165e-01
 4 |      0.158 |        0.614 |       0.2 |       0.5 |       0.1 |       0.1 |  3.553e-15 |   2.111e-01 |   1.495e-01
 5 |      1.981 |       83.838 |       1.5 |       5.3 |       0.8 |       0.8 |  7.105e-15 |   2.666e-01 |   1.858e-01
 6 |     25.979 |          N/A |      15.3 |       N/A |       7.6 |       7.6 |        N/A |   3.271e-01 |   2.201e-01
  Direct method skipped for d=6 (d ≥ 6).
 7 |    299.697 |          N/A |     152.6 |       N/A |      76.3 |      76.3 |        N/A |   3.379e-01 |   2.663e-01
  Direct method skipped for d=7 (d ≥ 6).
 8 |   3419.780 |          N/A |    1525.9 |       N/A |     762.9 |     762.9 |        N/A |   3.865e-01 |   2.897e-01
  Direct method skipped for d=8 (d ≥ 6).

Benchmark complete for quadratic function.

[...]

 2 |      0.001 |        0.001 |       0.0 |       0.0 |       0.0 |       0.0 |  7.105e-15 |         N/A |         N/A
 3 |      0.013 |        0.018 |       0.0 |       0.0 |       0.0 |       0.0 |  1.600e+01 |         N/A |         N/A
 4 |      0.291 |        0.931 |       0.2 |       0.5 |       0.1 |       0.1 |  1.421e-14 |         N/A |         N/A
 5 |      2.177 |       84.829 |       1.5 |       5.3 |       0.8 |       0.8 |  3.200e+01 |         N/A |         N/A
 6 |     26.545 |          N/A |      15.3 |       N/A |       7.6 |       7.6 |        N/A |         N/A |         N/A
  Direct method skipped for d=6 (d ≥ 6).
 7 |    333.159 |          N/A |     152.6 |       N/A |      76.3 |      76.3 |        N/A |         N/A |         N/A
  Direct method skipped for d=7 (d ≥ 6).
 8 |   3604.530 |          N/A |    1525.9 |       N/A |     762.9 |     762.9 |        N/A |         N/A |         N/A
  Direct method skipped for d=8 (d ≥ 6).

Benchmark complete for ql function.
