In [1]:
import jax
jax.config.update("jax_enable_x64", True)

In [2]:
from dataclasses import dataclass, field
from typing import Union, List, Optional

def _wrap(x):
    if isinstance(x, (int, float)): return Constant(x)
    return x

@dataclass
class Expression:
    # --- Arithmetic Overloads ---
    def __add__(self, other): return BinaryOp(self, _wrap(other), "+")
    def __radd__(self, other): return BinaryOp(_wrap(other), self, "+")
    
    def __sub__(self, other): return BinaryOp(self, _wrap(other), "-")
    def __rsub__(self, other): return BinaryOp(_wrap(other), self, "-")
    
    def __mul__(self, other): return BinaryOp(self, _wrap(other), "*")
    def __rmul__(self, other): return BinaryOp(_wrap(other), self, "*")
    
    def __truediv__(self, other): return BinaryOp(self, _wrap(other), "/")
    def __rtruediv__(self, other): return BinaryOp(_wrap(other), self, "/")
    
    def __pow__(self, other): return BinaryOp(self, _wrap(other), "**")
    
    def __neg__(self): return UnaryOp(self, "-")

@dataclass
class Symbol(Expression): name: str
@dataclass
class Field(Symbol): rank: int = 0
@dataclass
class Parameter(Symbol): pass
@dataclass
class Constant(Expression): value: float

@dataclass
class UnaryOp(Expression):
    operand: Expression; op: str

@dataclass
class BinaryOp(Expression):
    left: Expression; right: Expression; op: str

@dataclass
class Equation:
    lhs: Expression; rhs: Expression
    def __repr__(self): return f"{self.lhs} == {self.rhs}"

# Functional Helpers
def dt(expr): return UnaryOp(expr, 'dt')
def grad(expr): return UnaryOp(expr, 'grad')
def div(expr): return UnaryOp(expr, 'div')
def laplacian(expr): return UnaryOp(expr, 'laplacian')
def Eq(lhs, rhs): return Equation(lhs, rhs)
def vec(expr): return UnaryOp(expr, 'vec') # Explicit vector construction if needed

In [3]:
import jax.numpy as jnp
#from dsl_core import *

class JAXCompiler:
    def __init__(self, algebra, manifold):
        self.algebra = algebra
        self.manifold = manifold
        self.state_registry = {}
        self.param_registry = {}

    def compile(self, equation: Equation):
        # 1. Residual Form (LHS - RHS)
        if not (isinstance(equation.lhs, UnaryOp) and equation.lhs.op == 'dt'):
             return self._compile_residual(equation)
        
        # 2. Update Form (dt(u) = RHS)
        rhs_expr = equation.rhs
        def kernel(u_dict, params_dict):
            self.state_registry = u_dict
            self.param_registry = params_dict
            return self._eval(rhs_expr)
        return kernel

    def _compile_residual(self, equation):
        def residual_fn(u_dict, params_dict):
            self.state_registry = u_dict
            self.param_registry = params_dict
            return self._eval(equation.lhs) - self._eval(equation.rhs)
        return residual_fn

    def _contract_list(self, val_list):
        """ Converts [v_x, v_y] -> Sum(e_i * v_i) for Geometric Algebra """
        if self.algebra is None: return val_list[0]
        acc = jnp.zeros_like(val_list[0])
        for i, comp in enumerate(val_list):
            if i >= len(self.algebra.basis_matrices): break
            M = self.algebra.basis_matrices[i]
            term = jnp.einsum('kj,...j->...k', M, comp)
            acc = acc + term
        return acc

    def _eval(self, expr: Expression):
        # --- ATOMS ---
        if isinstance(expr, Field): return self.state_registry[expr.name]
        if isinstance(expr, Parameter): return self.param_registry[expr.name]
        if isinstance(expr, Constant): return expr.value

        # --- BINARY OPS ---
        if isinstance(expr, BinaryOp):
            left = self._eval(expr.left)
            right = self._eval(expr.right)
            is_list_l, is_list_r = isinstance(left, list), isinstance(right, list)

            if expr.op == '+': 
                if is_list_l and not is_list_r: return self._contract_list(left) + right
                if not is_list_l and is_list_r: return left + self._contract_list(right)
                if is_list_l and is_list_r: return [l+r for l,r in zip(left, right)]
                return left + right
            
            if expr.op == '-':
                if is_list_l and not is_list_r: return self._contract_list(left) - right
                if not is_list_l and is_list_r: return left - self._contract_list(right)
                if is_list_l and is_list_r: return [l-r for l,r in zip(left, right)]
                return left - right

            if expr.op == '*': return self._apply_product(left, right)
            if expr.op == '/': return left / right if not is_list_l else [l/right for l in left]
            if expr.op == '**': return left ** right

        # --- UNARY OPS ---
        if isinstance(expr, UnaryOp):
            val = self._eval(expr.operand)
            
            if expr.op == '-': 
                return [-v for v in val] if isinstance(val, list) else -val
            
            # --- GRADIENT (Central Difference) ---
            if expr.op == 'grad':
                # Manifold returns (fwd, bwd, lap)
                gf, gb, _ = self.manifold.gradients(val)
                # Average to get Central Difference: (Next - Prev) / 2dx
                return [0.5 * (f + b) for f, b in zip(gf, gb)]
            
            # --- LAPLACIAN ---
            if expr.op == 'laplacian':
                _, _, lap = self.manifold.gradients(val)
                return lap

            # --- DIVERGENCE (Restored!) ---
            if expr.op == 'div':
                if isinstance(val, list):
                    # val is a list of components [Fx, Fy]
                    # div = d(Fx)/dx + d(Fy)/dy
                    div_sum = 0
                    for i, comp in enumerate(val):
                        # Get Central Gradient of component i
                        gf, gb, _ = self.manifold.gradients(comp)
                        grad_central = 0.5 * (gf[i] + gb[i]) # Select derivative along axis i
                        div_sum += grad_central
                    return div_sum
                else:
                    # Fallback if someone asks for div(scalar) -> Laplacian
                    _, _, lap = self.manifold.gradients(val)
                    return lap

        raise NotImplementedError(f"Op {expr} not implemented")

    def _apply_product(self, left, right):
        is_list_l, is_list_r = isinstance(left, list), isinstance(right, list)
        if is_list_r and not is_list_l: return [left * r for r in right]
        if is_list_l and not is_list_r: return [l * right for l in left]
        if is_list_l and is_list_r:
            # Dot Product accumulation
            acc = 0
            for l, r in zip(left, right): acc += l * r
            return acc
        return left * right

In [4]:
class NonUniformMesh1D:
    def __init__(self, L_total, N_points):
        # Create non-uniform grid using tanh or geometric progression
        # Dense in middle (x=0), sparse at ends (-L/2, L/2)
        xi = jnp.linspace(-1, 1, N_points)
        # Mapping function: x = L/2 * sign(xi) * |xi|^power
        # Or Tanh stretching
        k = 3.0 # Stretch factor
        self.nodes = (L_total / 2) * (jnp.tanh(k * xi) / jnp.tanh(k)) + (L_total/2)
        
        # Pre-compute metric (dx)
        # Central difference weights for non-uniform grid are complex.
        # Simplification: Map to uniform Computational Space (xi) + Jacobian
        
        # Simple Finite Difference on irregular grid:
        # dx_fwd[i] = x[i+1] - x[i]
        self.dx_fwd = self.nodes[1:] - self.nodes[:-1]
        self.dx_bwd = jnp.concatenate([self.dx_fwd[0:1], self.dx_fwd]) # Shifted
        self.dx_fwd = jnp.concatenate([self.dx_fwd, self.dx_fwd[-1:]]) # Pad
        
    def gradients(self, u):
        # u: (N,) or (N, 1)
        
        # Forward Diff
        u_next = jnp.roll(u, -1, axis=0)
        u_next = u_next.at[-1].set(u[-1]) # Clamp boundary
        grad_f = (u_next - u) / self.dx_fwd
        
        # Backward Diff
        u_prev = jnp.roll(u, 1, axis=0)
        u_prev = u_prev.at[0].set(u[0])
        grad_b = (u - u_prev) / self.dx_bwd
        
        # Laplacian (Non-uniform 3-point stencil)
        # 2 * ( (u+ - u)/h1 - (u - u-)/h2 ) / (h1 + h2)
        numer = (u_next - u)/self.dx_fwd - (u - u_prev)/self.dx_bwd
        denom = 0.5 * (self.dx_fwd + self.dx_bwd)
        lap = numer / denom
        
        return [grad_f], [grad_b], lap

In [5]:
import jax
import jax.numpy as jnp
from jax.experimental import sparse

class UnstructuredMesh2D:
    def __init__(self, vertices, faces):
        """
        vertices: (N, 2) array of [x, y]
        faces: (M, 3) array of [n1, n2, n3] indices
        """
        self.nodes = vertices
        self.faces = faces
        self.N = len(vertices)
        
        print(f"Building Unstructured Operators for {self.N} nodes...")
        self.Gx, self.Gy, self.L = self._build_operators(vertices, faces)

    def gradients(self, u):
        # Maps generic DSL call 'gradients(u)' to our matrix ops
        # u shape: (N,)
        
        # Gradient is a vector [d/dx, d/dy]
        # In JAX sparse, matmul is distinct
        dx = self.Gx @ u
        dy = self.Gy @ u
        
        # Laplacian
        lap = self.L @ u
        
        # Return format: ([grad_components...], [grad_components_bwd...], lap)
        # For unstructured, we don't distinguish fwd/bwd in this simple reconstruction
        grads = [dx, dy]
        return (grads, grads, lap)

    def _build_operators(self, V, F):
        # 1. Compute Triangle Gradients
        # For a triangle with coords (x1,y1), (x2,y2), (x3,y3)
        # and values u1, u2, u3.
        # The gradient is constant. 
        # Area = 0.5 * det(...)
        
        # Vectorized implementation for M triangles
        v1 = V[F[:, 0]]; v2 = V[F[:, 1]]; v3 = V[F[:, 2]]
        
        x1, y1 = v1.T; x2, y2 = v2.T; x3, y3 = v3.T
        
        # 2A = (x2-x1)(y3-y1) - (x3-x1)(y2-y1)
        two_area = (x2 - x1)*(y3 - y1) - (x3 - x1)*(y2 - y1)
        area = 0.5 * two_area
        
        # Gradients of Basis functions (Linear Elements)
        # b1_x = (y2 - y3) / 2A
        # b1_y = (x3 - x2) / 2A
        b1_x = (y2 - y3) / two_area; b1_y = (x3 - x2) / two_area
        b2_x = (y3 - y1) / two_area; b2_y = (x1 - x3) / two_area
        b3_x = (y1 - y2) / two_area; b3_y = (x2 - x1) / two_area
        
        # 2. Distribute Triangle Gradients to Nodes (Weighted Average by Area)
        # This is the "Lumped Mass" approach.
        # Node Gradient = sum(Area_tri * Grad_tri) / sum(Area_tri)
        # Actually Area cancels out: sum(Area * (1/2A * ...)) = 1/2 * sum(...)
        
        # We build the Sparse Matrix 'G' directly.
        # Rows = Nodes, Cols = Nodes.
        # G_ij contributes to gradient at node i from neighbor j.
        
        # Easier approach for JAX: 
        # G maps (N_nodes) -> (N_nodes).
        # We iterate triangles and accumulate contributions.
        
        # Init dense for simplicity (N < 500), convert to sparse later
        # (For PhD scale, you'd build Coordinate list (row, col, val))
        N = len(V)
        Gx = jnp.zeros((N, N))
        Gy = jnp.zeros((N, N))
        Mass = jnp.zeros(N) # Sum of areas touching a node
        
        # It's cleaner to define the operators in pure numpy first, then move to JAX
        # to avoid slow JAX loops during setup.
        import numpy as np
        Gx_np = np.zeros((N, N)); Gy_np = np.zeros((N, N)); L_np = np.zeros((N, N))
        Mass_np = np.zeros(N)
        
        # Loop over faces (slow in Python, but done once)
        # Optimization: This is standard FEM assembly.
        for i in range(len(F)):
            idx = F[i] # [n1, n2, n3]
            a = area[i]
            
            # Gradients of shape functions
            # dN/dx = [b1_x, b2_x, b3_x]
            # Grad_Tri = dN/dx . U_element
            bx = [b1_x[i], b2_x[i], b3_x[i]]
            by = [b1_y[i], b2_y[i], b3_y[i]]
            
            # Accumulate to all 3 nodes of this triangle
            for row_node in idx:
                Mass_np[row_node] += a
                for k, col_node in enumerate(idx):
                    # Contribution of u[col_node] to Grad at row_node
                    # Weighted by area: Area * (dN_k/dx)
                    # Note: Area * b = 0.5 * (coord diff)
                    Gx_np[row_node, col_node] += a * bx[k]
                    Gy_np[row_node, col_node] += a * by[k]
                    
            # Laplacian Stiffnes Matrix (Cotangent)
            # Standard FEM: K_ij = Integral(grad_Ni . grad_Nj)
            # Local K = Area * (bx_i*bx_j + by_i*by_j)
            for r in range(3):
                row = idx[r]
                for c in range(3):
                    col = idx[c]
                    val = a * (bx[r]*bx[c] + by[r]*by[c])
                    L_np[row, col] += val

        # Normalize Gradients by Mass (Area)
        # G = inv(M) * G_accum
        inv_mass = 1.0 / (Mass_np + 1e-9)
        Gx_np = (Gx_np.T * inv_mass).T
        Gy_np = (Gy_np.T * inv_mass).T
        
        # Laplacian: FEM solves K u = F. 
        # But our code expects 'laplacian(u)' to return the value div(grad u).
        # In Lumped mass approximation: L u = M^-1 K u
        L_np = -1 * (L_np.T * inv_mass).T # Sign convention: laplacian is negative definite
        
        return jnp.array(Gx_np), jnp.array(Gy_np), jnp.array(L_np)

In [6]:
import jax
import jax.numpy as jnp
from jax import jit, lax
import functools
from kingdon import Algebra
from typing import NamedTuple, List

# ==============================================================================
# 1. REUSED ABSTRACTIONS (The Framework)
# ==============================================================================

class MatrixAlgebra:
    """ Pre-computes Geometric Product matrices (Works for 1D, 2D, 3D) """
    def __init__(self, p, q, r=0):
        self.p, self.q, self.r = p, q, r
        self.alg = Algebra(p, q, r)
        self.dim = len(self.alg) # 1D=2 (Scalar, Vector), 2D=4
        # Basis generators (e1...)
        self.basis_names = [f'e{i+1}' for i in range(p + q + r)]
        
        matrices = []
        for name in self.basis_names:
            e_k = self.alg.blades[name]
            cols = []
            for i in range(self.dim):
                b_i = self.alg.multivector({i: 1})
                res = e_k * b_i
                dense_col = jnp.zeros(self.dim)
                for bin_key, val in res.items(): dense_col = dense_col.at[bin_key].set(val)
                cols.append(dense_col)
            matrices.append(jnp.stack(cols, axis=1))
        self.basis_matrices = jnp.stack(matrices)

class CartesianBox:
    """ Generic N-Dimensional Grid with Boundary Support """
    def __init__(self, shape, dx):
        self.shape = shape # Tuple (Nx, Ny...)
        self.dx = dx
        self.ndim = len(shape)
        
        # Create coordinates
        coords = [jnp.linspace(0, s*dx, s) for s in shape]
        self.grid = jnp.meshgrid(*coords, indexing='ij')

    @property
    def coordinates(self): return self.grid

    def gradients(self, u):
        # Generic padding for any dimension
        # Pad spatial dims, leave channel dim (last) alone if it exists
        # u shape is (N, N, ..., Channels)
        
        # Determine padding width: ((1,1), (1,1)... (0,0))
        # We assume u has exactly self.ndim spatial dimensions + optional channel dim
        has_channel = (u.ndim > self.ndim)
        pad_width = [(1, 1)] * self.ndim
        if has_channel: pad_width.append((0, 0))
        
        u_pad = jnp.pad(u, tuple(pad_width), mode='edge') 
        
        center_slice = [slice(1, -1)] * self.ndim
        if has_channel: center_slice.append(slice(None))
        u_c = u_pad[tuple(center_slice)]
        
        grads_fwd = []
        grads_bwd = []
        lap_sum = jnp.zeros_like(u_c)
        
        for i in range(self.ndim):
            # Slices for neighbor access along axis i
            # Next neighbor
            s_next = [slice(1, -1)] * self.ndim
            s_next[i] = slice(2, None)
            if has_channel: s_next.append(slice(None))
            
            # Prev neighbor
            s_prev = [slice(1, -1)] * self.ndim
            s_prev[i] = slice(0, -2)
            if has_channel: s_prev.append(slice(None))
            
            u_next = u_pad[tuple(s_next)]
            u_prev = u_pad[tuple(s_prev)]
            
            grads_fwd.append( (u_next - u_c) / self.dx )
            grads_bwd.append( (u_c - u_prev) / self.dx )
            lap_sum += (u_next + u_prev - 2*u_c) / (self.dx**2)
            
        return (grads_fwd, grads_bwd, lap_sum)

    def enforce_boundaries(self, u):
        """ Zero out the spatial boundaries (Hard Walls) """
        # Works for 1D, 2D, 3D...
        res = u
        for i in range(self.ndim):
            # Create a dynamic slice object
            # u.at[..., 0, ...].set(0) where 0 is at axis i
            
            # Left Boundary (Index 0)
            sl_0 = [slice(None)] * u.ndim
            sl_0[i] = 0
            res = res.at[tuple(sl_0)].set(0.0)
            
            # Right Boundary (Index -1)
            sl_1 = [slice(None)] * u.ndim
            sl_1[i] = -1
            res = res.at[tuple(sl_1)].set(0.0)
            
        return res


In [7]:
class TCADSolverComposer:
    def __init__(self, op_factory):
        self.ops = op_factory

    def compile_hybrid_maccormack(self, eq_advection, eq_diffusion):
        # Explicit Advection Operators
        adv_fwd = self.ops.build_functional(eq_advection, mode='fwd')
        adv_bwd = self.ops.build_functional(eq_advection, mode='bwd')
        
        # Implicit Diffusion Operator Factory
        diff_matrix_factory = self.ops.build_matrix(eq_diffusion)
        
        def step_fn(u_curr, params, dt):
            # 1. MacCormack Advection
            k1 = adv_fwd(u_curr, params)
            u_p = u_curr + k1 * dt
            
            k2 = adv_bwd(u_p, params)
            u_adv = 0.5 * (u_curr + u_p + k2 * dt)
            
            # Note: Diffusion is applied via matrix multiply outside this step
            # to keep the JIT loop clean and modular.
            return u_adv 
            
        return step_fn, diff_matrix_factory



In [8]:
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [9]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
import jax.scipy.sparse.linalg as spla 
import optax
#from dsl_core import *
#from sta_inverse_precision import MatrixAlgebra, CartesianBox 

# --- 1. MODULAR COMPONENTS ---

class StencilProvider:
    def __init__(self, manifold):
        self.manifold = manifold
    
    def get_grads(self, u):
        return self.manifold.gradients(u)

class OperatorFactory:
    def __init__(self, algebra, stencil):
        self.algebra = algebra
        self.stencil = stencil
        self.bases = tuple(self.algebra.basis_matrices)

    def build_functional_fused(self, terms):
        """ 
        Returns a kernel that computes specific physics terms.
        terms: list of strings ['adv', 'diff', 'decay']
        """
        # We pre-bake the allowed terms into the closure
        enable_adv = 'adv' in terms
        enable_diff = 'diff' in terms
        enable_decay = 'decay' in terms
        
        def kernel(u, grads, lap, params):
            rate = jnp.zeros_like(u)
            
            # 1. Advection: -c * (Basis . Grad)
            if enable_adv and 'c' in params:
                 c = params['c']
                 # Safety check: if grads is None, we can't compute advection
                 if grads is not None:
                     for i, basis in enumerate(self.bases):
                         if i >= len(grads): break
                         term = jnp.einsum('kj,...j->...k', basis, grads[i])
                         rate = rate - c * term

            # 2. Diffusion: eps * Lap
            if enable_diff and 'epsilon' in params:
                 rate = rate + params['epsilon'] * lap
                 
            # 3. Decay: -V * u
            if enable_decay and 'V' in params:
                 rate = rate - params['V'] * u
                 
            return rate
        return kernel

class MatrixFreeComposer:
    def __init__(self, ops):
        self.ops = ops

    def compile_hybrid(self):
        # Explicit Kernel: Only Advection
        kern_adv = self.ops.build_functional_fused(terms=['adv'])
        
        # Implicit Kernel: Diffusion + Decay
        kern_diff = self.ops.build_functional_fused(terms=['diff', 'decay'])
        
        def step_fn(u_curr, params, dt):
            # --- PHASE 1: EXPLICIT MACCORMACK (Advection) ---
            gf, gb, lap = self.ops.stencil.get_grads(u_curr)
            
            # Predict
            k1 = kern_adv(u_curr, gf, lap, params)
            u_p = u_curr + k1 * dt
            
            # Correct
            gf_p, gb_p, lap_p = self.ops.stencil.get_grads(u_p)
            k2 = kern_adv(u_p, gb_p, lap_p, params)
            
            u_adv = 0.5 * (u_curr + u_p + k2 * dt)
            
            # --- PHASE 2: IMPLICIT DIFFUSION (Matrix-Free) ---
            
            def linear_op(x_flat):
                x = x_flat.reshape(u_curr.shape)
                # We need Laplacian of x for diffusion
                _, _, lap_x = self.ops.stencil.get_grads(x)
                
                # Evaluate Diffusion/Decay only (passed grads=None safely)
                rate = kern_diff(x, None, lap_x, params) 
                
                # Implicit Relation: (I - dt * DiffOp) x = b
                res = x - dt * rate 
                return res.ravel()

            # Solve Ax = b
            u_flat_adv = u_adv.ravel()
            
            # Use u_adv as guess. 
            # If eps=0, linear_op is Identity, converges in 1 iter.
            u_next_flat, _ = spla.cg(linear_op, u_flat_adv, x0=u_flat_adv, tol=1e-5, maxiter=20)
            
            return u_next_flat.reshape(u_curr.shape)
            
        return step_fn

# --- 2. SETUP ---
GRID_SIZE = 100
DX = 0.1
DT = 0.0001
DURATION_STEPS = 1200

algebra = MatrixAlgebra(2, 0)
manifold = CartesianBox((GRID_SIZE, GRID_SIZE), DX)
stencil = StencilProvider(manifold)
ops = OperatorFactory(algebra, stencil)
composer = MatrixFreeComposer(ops)

# Compile (No expression needed, we use the 'terms' logic)
step_fn = composer.compile_hybrid()

print("--- Matrix-Free Hybrid Solver Compiled (Fixed) ---")

# --- 3. RUNNER ---
@jax.jit
def run_matrix_free(params_dict, hyper_params):
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    X, Y = manifold.coordinates
    p = {**params_dict, **hyper_params}
    x0, y0 = p['pos']; vx, vy = p['vel']
    
    def body(carry, i):
        u_curr = carry
        u_next = step_fn(u_curr, p, DT)
        u_next = manifold.enforce_boundaries(u_next)
        
        t = i * DT
        src = jnp.exp(-((X-(x0+vx*t))**2 + (Y-(y0+vy*t))**2)/4.5) * \
              jnp.exp(-(i-50)**2/800.0) * 100.0 * DT
        u_next = u_next.at[..., 0].add(src)
        return u_next, u_next[..., 0]

    _, history = lax.scan(body, u, jnp.arange(DURATION_STEPS))
    return history

# --- 4. PRECISION SOLVE ---
def solve_final_matrix_free():
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0]), 'c': 343.0, 'V': 0.5}
    
    print("Generating Ground Truth...")
    true_hist = run_matrix_free(TRUE_PARAMS, {'epsilon': 0.0})
    
    SENSORS = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])
    obs_data = true_hist[:, SENSORS[:,0], SENSORS[:,1]]
    
    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0]), 'c': 300.0, 'V': 0.0}
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.multi_transform(
        {'pos': optax.adam(0.1), 'vel': optax.adam(5.0), 'c': optax.adam(1.0), 'V': optax.adam(0.5)},
        {'pos':'pos', 'vel':'vel', 'c':'c', 'V':'V'}))
    opt_state = optimizer.init(guess)
    
    # Compiled Update Step
    @jax.jit
    def update(state, guess, eps):
        def loss(p):
            sim = run_matrix_free(p, {'epsilon': eps})
            dat = sim[:, SENSORS[:,0], SENSORS[:,1]]
            safe = 1e-6
            sim_n = (dat - jnp.mean(dat, 0)) / (jnp.std(dat, 0) + safe)
            obs_n = (obs_data - jnp.mean(obs_data, 0)) / (jnp.std(obs_data, 0) + safe)
            pb = jnp.sum(jnp.maximum(0, -p['pos'])) + jnp.sum(jnp.maximum(0, p['pos'] - 10.0))
            return (1.0 - jnp.mean(sim_n * obs_n)) + pb
        l, g = value_and_grad(loss)(guess)
        u, s = optimizer.update(g, state, guess)
        return l, optax.apply_updates(guess, u), s

    print("\n--- Starting Matrix-Free Precision Search ---")
    
    for stage, eps in enumerate([2.0, 0.5, 0.0]):
        print(f"\n>>> STAGE {stage} (Epsilon={eps}) <<<")
        for i in range(201):
            loss, new_guess, opt_state = update(opt_state, guess, eps)
            
            d_vel = jnp.linalg.norm(new_guess['vel'] - guess['vel'])
            guess = new_guess
            
            if i > 20 and d_vel < 0.01:
                print(f"  [CONVERGED] Iter {i} | Vel Delta={d_vel:.4f}")
                print(f"  Current: Vel={guess['vel']} | V={guess['V']:.3f}")
                break
            if i % 20 == 0:
                print(f"  Iter {i:3d}: Loss={loss:.4f} | Vel={guess['vel']} | V={guess['V']:.3f}")

    return guess

final = solve_final_matrix_free()
print("\n--- FINAL RESULT ---")
print(f"Rec Vel: {final['vel']}")

--- Matrix-Free Hybrid Solver Compiled (Fixed) ---
Generating Ground Truth...

--- Starting Matrix-Free Precision Search ---

>>> STAGE 0 (Epsilon=2.0) <<<
  Iter   0: Loss=0.4141 | Vel=[-4.99991963  4.99993126] | V=0.500
  Iter  20: Loss=0.0612 | Vel=[-70.4819555  76.3940099] | V=10.924
  Iter  40: Loss=0.0105 | Vel=[-60.81338231  64.02255718] | V=15.404
  Iter  60: Loss=0.0042 | Vel=[-56.05321931  64.85965103] | V=10.363
  Iter  80: Loss=0.0025 | Vel=[-49.92660438  58.58254573] | V=3.798
  Iter 100: Loss=0.0018 | Vel=[-40.8963796   51.15803643] | V=-0.503
  Iter 120: Loss=0.0015 | Vel=[-32.03709638  44.26590873] | V=-2.163
  Iter 140: Loss=0.0012 | Vel=[-23.66307176  37.50746745] | V=-2.396
  Iter 160: Loss=0.0010 | Vel=[-15.59953386  30.87466563] | V=-2.272
  Iter 180: Loss=0.0008 | Vel=[-7.92960878 24.49939463] | V=-2.165
  Iter 200: Loss=0.0007 | Vel=[-0.77082548 18.48419117] | V=-2.107

>>> STAGE 1 (Epsilon=0.5) <<<
  Iter   0: Loss=0.0006 | Vel=[-0.42078298 18.19470447] | V=-2.0