In [1]:
import jax
jax.config.update("jax_enable_x64", True)

In [2]:
from dataclasses import dataclass, field
from typing import Union, List, Optional

def _wrap(x):
    if isinstance(x, (int, float)): return Constant(x)
    return x

@dataclass
class Expression:
    # --- Arithmetic Overloads ---
    def __add__(self, other): return BinaryOp(self, _wrap(other), "+")
    def __radd__(self, other): return BinaryOp(_wrap(other), self, "+")
    
    def __sub__(self, other): return BinaryOp(self, _wrap(other), "-")
    def __rsub__(self, other): return BinaryOp(_wrap(other), self, "-")
    
    def __mul__(self, other): return BinaryOp(self, _wrap(other), "*")
    def __rmul__(self, other): return BinaryOp(_wrap(other), self, "*")
    
    def __truediv__(self, other): return BinaryOp(self, _wrap(other), "/")
    def __rtruediv__(self, other): return BinaryOp(_wrap(other), self, "/")
    
    def __pow__(self, other): return BinaryOp(self, _wrap(other), "**")
    
    def __neg__(self): return UnaryOp(self, "-")

@dataclass
class Symbol(Expression): name: str
@dataclass
class Field(Symbol): rank: int = 0
@dataclass
class Parameter(Symbol): pass
@dataclass
class Constant(Expression): value: float

@dataclass
class UnaryOp(Expression):
    operand: Expression; op: str

@dataclass
class BinaryOp(Expression):
    left: Expression; right: Expression; op: str

@dataclass
class Equation:
    lhs: Expression; rhs: Expression
    def __repr__(self): return f"{self.lhs} == {self.rhs}"

# Functional Helpers
def dt(expr): return UnaryOp(expr, 'dt')
def grad(expr): return UnaryOp(expr, 'grad')
def div(expr): return UnaryOp(expr, 'div')
def laplacian(expr): return UnaryOp(expr, 'laplacian')
def Eq(lhs, rhs): return Equation(lhs, rhs)
def vec(expr): return UnaryOp(expr, 'vec') # Explicit vector construction if needed

In [3]:
from dataclasses import dataclass
from typing import Union, List
#from dsl_core import Field, Parameter, Constant

# --- 1. TRAJECTORIES (How sources move) ---
@dataclass
class Trajectory:
    pass

@dataclass
class LinearTrajectory(Trajectory):
    start_pos: Union[Parameter, List[float]]
    velocity: Union[Parameter, List[float]]

@dataclass
class StaticPosition(Trajectory):
    pos: Union[Parameter, List[float]]

# --- 2. ENVELOPES (How sources look in time/space) ---
@dataclass
class GaussianPulse:
    center_step: int
    width_step: float
    amplitude: float

# --- 3. SOURCES (The Drivers) ---
@dataclass
class SourceTerm:
    field: Field  # Which field does this affect?
    
@dataclass
class MovingSpotlight(SourceTerm):
    trajectory: Trajectory
    spatial_sigma: float
    temporal_profile: GaussianPulse

# --- 4. BOUNDARIES (The Constraints) ---
@dataclass
class BoundaryCondition:
    field: Field
    
@dataclass
class DirichletBC(BoundaryCondition):
    value: float = 0.0  # Homogeneous Dirichlet (Walls)

# --- 5. THE PROBLEM CONTAINER ---
@dataclass
class Problem:
    equation: 'Equation'
    bcs: List[BoundaryCondition]
    sources: List[SourceTerm]

In [4]:
import jax
import jax.numpy as jnp
import jax.scipy.sparse.linalg as spla
#from dsl_core import *
#from dsl_problem import *
#from sta_inverse_precision import MatrixAlgebra, CartesianBox 

class ProblemCompiler:
    def __init__(self, algebra, manifold):
        self.algebra = algebra
        self.manifold = manifold
        self.bases = tuple(self.algebra.basis_matrices)
        # Re-use the Auto-Split logic for the physics part
        self.physics_compiler = None 

    def compile(self, problem: Problem):
        # 1. Compile the Physics Kernel (using previous logic)
        # We instantiate the sub-compiler here to keep state clean
        # from dsl_compiler_autosplit_fixed import AutoSplitCompiler
        self.physics_compiler = AutoSplitCompiler(self.algebra, self.manifold)
        physics_step = self.physics_compiler.compile(problem.equation)
        
        # 2. Compile Source Kernels
        source_fns = [self._compile_source(s) for s in problem.sources]
        
        # 3. Compile Boundary Kernels
        bc_fns = [self._compile_bc(b) for b in problem.bcs]
        
        # 4. Generate the Master Stepper
        def master_step_fn(u_curr, params, dt, step_idx):
            # A. Physics Step (Split Hybrid)
            u_next = physics_step(u_curr, params, dt)
            
            # B. Apply Sources
            # We assume sources are additive (Forcing terms)
            t = step_idx * dt
            for src_fn in source_fns:
                contribution = src_fn(params, t, step_idx, self.manifold.coordinates)
                u_next = u_next + contribution
            
            # C. Apply Boundaries
            for bc_fn in bc_fns:
                u_next = bc_fn(u_next)
                
            return u_next

        return master_step_fn

    def _compile_source(self, src):
        if isinstance(src, MovingSpotlight):
            # Pre-compute constants
            amp = src.temporal_profile.amplitude
            t_center = src.temporal_profile.center_step
            t_width = src.temporal_profile.width_step
            sigma_sq = src.spatial_sigma
            
            # Extract parameter names if they are Parameters
            pos_name = src.trajectory.start_pos.name if isinstance(src.trajectory.start_pos, Parameter) else None
            vel_name = src.trajectory.velocity.name if isinstance(src.trajectory.velocity, Parameter) else None
            
            def source_kernel(params, t, i, coords):
                X, Y = coords
                
                # Resolve Position
                if isinstance(src.trajectory, LinearTrajectory):
                    # Dynamic Parameter Lookup
                    p0 = params[pos_name] if pos_name else src.trajectory.start_pos
                    v = params[vel_name] if vel_name else src.trajectory.velocity
                    
                    x_c = p0[0] + v[0] * t
                    y_c = p0[1] + v[1] * t
                else:
                    # Static
                    p0 = params[pos_name] if pos_name else src.trajectory.pos
                    x_c, y_c = p0[0], p0[1]
                
                # Spatial Profile
                dist_sq = (X - x_c)**2 + (Y - y_c)**2
                spatial = jnp.exp(-dist_sq / sigma_sq)
                
                # Temporal Envelope
                temporal = jnp.exp(-(i - t_center)**2 / t_width)
                
                # Construct Field Update (Scalar component only)
                # We return a full field array (N, N, 4) with only component 0 active
                return jnp.zeros_like(X)[..., None] * 0 # Dummy shape match
                # Correct way: Create the full shape
                val = spatial * temporal * amp * dt # Note: Source is Rate * dt
                
                # Place into component 0 (Scalar)
                # We need a clean way to map Field -> Component Index.
                # Assuming Scalar field Psi is index 0 for this algebra.
                
                # Create (N, N, 1) -> (N, N, 4) padding
                zeros = jnp.zeros_like(val)
                # [Val, 0, 0, 0]
                return jnp.stack([val, zeros, zeros, zeros], axis=-1)

            return source_kernel
        
        raise NotImplementedError(f"Unknown Source: {type(src)}")

    def _compile_bc(self, bc):
        if isinstance(bc, DirichletBC):
            if bc.value == 0.0:
                # Optimized 'Wall'
                return self.manifold.enforce_boundaries
            
            def custom_bc(u):
                # Set edges to value
                return u.at[0].set(bc.value).at[-1].set(bc.value) \
                        .at[:,0].set(bc.value).at[:,-1].set(bc.value)
            return custom_bc
            
        raise NotImplementedError

In [5]:
import jax
import jax.numpy as jnp
from jax import jit, lax
import functools
from kingdon import Algebra
from typing import NamedTuple, List

# ==============================================================================
# 1. REUSED ABSTRACTIONS (The Framework)
# ==============================================================================

class MatrixAlgebra:
    """ Pre-computes Geometric Product matrices (Works for 1D, 2D, 3D) """
    def __init__(self, p, q, r=0):
        self.p, self.q, self.r = p, q, r
        self.alg = Algebra(p, q, r)
        self.dim = len(self.alg) # 1D=2 (Scalar, Vector), 2D=4
        # Basis generators (e1...)
        self.basis_names = [f'e{i+1}' for i in range(p + q + r)]
        
        matrices = []
        for name in self.basis_names:
            e_k = self.alg.blades[name]
            cols = []
            for i in range(self.dim):
                b_i = self.alg.multivector({i: 1})
                res = e_k * b_i
                dense_col = jnp.zeros(self.dim)
                for bin_key, val in res.items(): dense_col = dense_col.at[bin_key].set(val)
                cols.append(dense_col)
            matrices.append(jnp.stack(cols, axis=1))
        self.basis_matrices = jnp.stack(matrices)

class CartesianBox:
    """ Generic N-Dimensional Grid with Boundary Support """
    def __init__(self, shape, dx):
        self.shape = shape # Tuple (Nx, Ny...)
        self.dx = dx
        self.ndim = len(shape)
        
        # Create coordinates
        coords = [jnp.linspace(0, s*dx, s) for s in shape]
        self.grid = jnp.meshgrid(*coords, indexing='ij')

    @property
    def coordinates(self): return self.grid

    def gradients(self, u):
        # Generic padding for any dimension
        # Pad spatial dims, leave channel dim (last) alone if it exists
        # u shape is (N, N, ..., Channels)
        
        # Determine padding width: ((1,1), (1,1)... (0,0))
        # We assume u has exactly self.ndim spatial dimensions + optional channel dim
        has_channel = (u.ndim > self.ndim)
        pad_width = [(1, 1)] * self.ndim
        if has_channel: pad_width.append((0, 0))
        
        u_pad = jnp.pad(u, tuple(pad_width), mode='edge') 
        
        center_slice = [slice(1, -1)] * self.ndim
        if has_channel: center_slice.append(slice(None))
        u_c = u_pad[tuple(center_slice)]
        
        grads_fwd = []
        grads_bwd = []
        lap_sum = jnp.zeros_like(u_c)
        
        for i in range(self.ndim):
            # Slices for neighbor access along axis i
            # Next neighbor
            s_next = [slice(1, -1)] * self.ndim
            s_next[i] = slice(2, None)
            if has_channel: s_next.append(slice(None))
            
            # Prev neighbor
            s_prev = [slice(1, -1)] * self.ndim
            s_prev[i] = slice(0, -2)
            if has_channel: s_prev.append(slice(None))
            
            u_next = u_pad[tuple(s_next)]
            u_prev = u_pad[tuple(s_prev)]
            
            grads_fwd.append( (u_next - u_c) / self.dx )
            grads_bwd.append( (u_c - u_prev) / self.dx )
            lap_sum += (u_next + u_prev - 2*u_c) / (self.dx**2)
            
        return (grads_fwd, grads_bwd, lap_sum)

    def enforce_boundaries(self, u):
        """ Zero out the spatial boundaries (Hard Walls) """
        # Works for 1D, 2D, 3D...
        res = u
        for i in range(self.ndim):
            # Create a dynamic slice object
            # u.at[..., 0, ...].set(0) where 0 is at axis i
            
            # Left Boundary (Index 0)
            sl_0 = [slice(None)] * u.ndim
            sl_0[i] = 0
            res = res.at[tuple(sl_0)].set(0.0)
            
            # Right Boundary (Index -1)
            sl_1 = [slice(None)] * u.ndim
            sl_1[i] = -1
            res = res.at[tuple(sl_1)].set(0.0)
            
        return res


In [6]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
import jax.scipy.sparse.linalg as spla 
import optax
#from dsl_core import *
#from sta_inverse_precision import MatrixAlgebra, CartesianBox 

class AutoSplitCompiler:
    def __init__(self, algebra, manifold):
        self.algebra = algebra
        self.manifold = manifold
        self.bases = tuple(self.algebra.basis_matrices)
        
        # State
        self.u_curr = None
        self.params = None
        self.grads = None # (gf, gb, lap)

    def compile(self, equation: Equation):
        
        # 1. KERNEL GENERATORS (The Evaluators)
        
        def kernel_eval(u, grads, p):
            self.u_curr = u; self.params = p; self.grads = grads
            # Just evaluate the whole tree. 
            # The _eval method will handle missing data (None) by returning Zeros.
            return self._eval(equation.rhs)

        # 2. STEPPER (Matrix-Free Hybrid)
        
        def step_fn(u_curr, params, dt):
            # --- A. EXPLICIT PHASE (Advection) ---
            # We want to compute terms dependent on 'grad'
            # We pass 'None' for laplacian to signal "Ignore Diffusion"
            
            gf, gb, lap = self.manifold.gradients(u_curr)
            
            # Predict
            self.direction = 'fwd'
            # Pass (gf, gb, None) -> Evaluates Grad terms, ignores Lap terms
            k1 = kernel_eval(u_curr, (gf, gb, None), params)
            u_p = u_curr + k1 * dt
            
            # Correct
            gf_p, gb_p, lap_p = self.manifold.gradients(u_p)
            self.direction = 'bwd'
            k2 = kernel_eval(u_p, (gf_p, gb_p, None), params)
            self.direction = 'fwd' # Reset
            
            u_adv = 0.5 * (u_curr + u_p + k2 * dt)
            
            # --- B. IMPLICIT PHASE (Diffusion/Decay) ---
            # Solve (I - dt * ImpOp) x = u_adv
            
            def linear_op(x_flat):
                x = x_flat.reshape(u_curr.shape)
                _, _, lap_x = self.manifold.gradients(x)
                
                # Evaluate Diffusion/Decay
                # Pass (None, None, lap_x) -> Evaluates Lap terms, ignores Grad terms
                rate = kernel_eval(x, (None, None, lap_x), params)
                
                return (x - dt * rate).ravel()

            u_flat = u_adv.ravel()
            u_next_flat, _ = spla.cg(linear_op, u_flat, x0=u_flat, tol=1e-5, maxiter=20)
            return u_next_flat.reshape(u_curr.shape)

        return step_fn

    # --- ROBUST EVALUATOR ---

    def _eval(self, expr):
        if isinstance(expr, Field): return self.u_curr
        if isinstance(expr, Parameter): return self.params[expr.name]
        if isinstance(expr, Constant): return expr.value
        
        if isinstance(expr, UnaryOp):
            if expr.op == 'grad':
                gf, gb, _ = self.grads
                # SAFETY: If grads are None (Implicit Phase), return Zero Vector
                if gf is None: 
                    return [jnp.zeros_like(self.u_curr) for _ in range(self.algebra.dim)]
                
                use_bwd = getattr(self, 'direction', 'fwd') == 'bwd'
                return gb if use_bwd else gf
            
            if expr.op == 'laplacian':
                _, _, lap = self.grads
                # SAFETY: If lap is None (Explicit Phase), return Zero Scalar
                if lap is None:
                    return jnp.zeros_like(self.u_curr)
                return lap
            
            val = self._eval(expr.operand)
            if expr.op == '-': return -val if not isinstance(val, list) else [-x for x in val]

        if isinstance(expr, BinaryOp):
            l = self._eval(expr.left)
            r = self._eval(expr.right)
            
            # Fused Contraction (Optimization)
            if expr.op == '*':
                if isinstance(l, list) and not isinstance(r, list): return self._contract(l, r) 
                if isinstance(r, list) and not isinstance(l, list): return self._contract(r, l)
            
            # Standard Math
            if expr.op == '+': return self._add(l, r)
            if expr.op == '-': return self._add(l, r, sub=True)
            if expr.op == '*': return l * r

        raise NotImplementedError(f"{expr}")

    def _contract(self, vec, scale):
        acc = jnp.zeros_like(vec[0])
        for i, b in enumerate(self.bases):
            # Check for zero-vector placeholder
            if i >= len(vec): break
            acc += jnp.einsum('kj,...j->...k', b, vec[i]) * scale
        return acc

    def _add(self, l, r, sub=False):
        if isinstance(l, list): l = self._contract(l, 1.0)
        if isinstance(r, list): r = self._contract(r, 1.0)
        return (l - r) if sub else (l + r)

In [7]:
# --- 1. SETUP ---
GRID_SIZE = 100
DX = 0.1
DT = 0.0001
DURATION_STEPS = 1200

algebra = MatrixAlgebra(2, 0)
manifold = CartesianBox((GRID_SIZE, GRID_SIZE), DX)

# --- 2. THE DSL DEFINITION (RESTORED!) ---
Psi = Field("Psi")
c   = Parameter("c")
V   = Parameter("V")
eps = Parameter("epsilon")

# The User writes Physics:
# "Rate is negative divergence of flux (-c grad) plus diffusion (eps lap) minus decay (V psi)"
eq_physics = Eq(dt(Psi), -(c * grad(Psi)) + eps * laplacian(Psi) - V * Psi)

# --- 3. THE COMPILER MAGIC ---
# The compiler analyzes 'eq_physics', sees 'grad' (Explicit) and 'lap' (Implicit),
# and generates the Matrix-Free Hybrid MacCormack solver automatically.
compiler = AutoSplitCompiler(algebra, manifold)
step_fn = compiler.compile(eq_physics)

print("--- DSL Compiled: Auto-Split Matrix-Free Hybrid Solver ---")

# --- 4. THE RUNNER (Standard) ---
@jax.jit
def run_simulation(params_dict, hyper_params):
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    X, Y = manifold.coordinates
    p = {**params_dict, **hyper_params}
    x0, y0 = p['pos']; vx, vy = p['vel']
    
    def body(carry, i):
        u_curr = carry
        u_next = step_fn(u_curr, p, DT)
        u_next = manifold.enforce_boundaries(u_next)
        
        # Source
        t = i * DT
        src = jnp.exp(-((X-(x0+vx*t))**2 + (Y-(y0+vy*t))**2)/4.5) * \
              jnp.exp(-(i-50)**2/800.0) * 100.0 * DT
        u_next = u_next.at[..., 0].add(src)
        return u_next, u_next[..., 0]

    _, history = lax.scan(body, u, jnp.arange(DURATION_STEPS))
    return history

# --- 5. PRECISION SOLVE ---
def solve_restored_dsl():
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0]), 'c': 343.0, 'V': 0.5}
    print("Generating Ground Truth...")
    true_hist = run_simulation(TRUE_PARAMS, {'epsilon': 0.0})
    
    SENSORS = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])
    obs_data = true_hist[:, SENSORS[:,0], SENSORS[:,1]]
    
    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0]), 'c': 300.0, 'V': 0.0}
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.multi_transform(
        {'pos': optax.adam(0.1), 'vel': optax.adam(5.0), 'c': optax.adam(1.0), 'V': optax.adam(0.5)},
        {'pos':'pos', 'vel':'vel', 'c':'c', 'V':'V'}))
    opt_state = optimizer.init(guess)
    
    @jax.jit
    def update(state, guess, eps):
        def loss(p):
            sim = run_simulation(p, {'epsilon': eps})
            dat = sim[:, SENSORS[:,0], SENSORS[:,1]]
            safe = 1e-6
            sim_n = (dat - jnp.mean(dat, 0)) / (jnp.std(dat, 0) + safe)
            obs_n = (obs_data - jnp.mean(obs_data, 0)) / (jnp.std(obs_data, 0) + safe)
            pb = jnp.sum(jnp.maximum(0, -p['pos'])) + jnp.sum(jnp.maximum(0, p['pos'] - 10.0))
            return (1.0 - jnp.mean(sim_n * obs_n)) + pb
        l, g = value_and_grad(loss)(guess)
        u, s = optimizer.update(g, state, guess)
        return l, optax.apply_updates(guess, u), s

    print("\n--- Starting DSL Precision Search ---")
    for stage, eps in enumerate([2.0, 0.5, 0.0]):
        print(f"\n>>> STAGE {stage} (Epsilon={eps}) <<<")
        for i in range(201):
            loss, new_guess, opt_state = update(opt_state, guess, eps)
            d_vel = jnp.linalg.norm(new_guess['vel'] - guess['vel'])
            guess = new_guess
            
            if i > 20 and d_vel < 0.01:
                print(f"  [CONVERGED] Iter {i} | Vel Delta={d_vel:.4f}")
                print(f"  Current: Vel={guess['vel']} | V={guess['V']:.3f}")
                break
            if i % 20 == 0:
                print(f"  Iter {i:3d}: Loss={loss:.4f} | Vel={guess['vel']} | V={guess['V']:.3f}")

    return guess

final = solve_restored_dsl()
print("\n--- FINAL RESULT ---")
print(f"Rec Pos: {final['pos']}")
print(f"Rec Vel: {final['vel']}")

--- DSL Compiled: Auto-Split Matrix-Free Hybrid Solver ---
Generating Ground Truth...

--- Starting DSL Precision Search ---

>>> STAGE 0 (Epsilon=2.0) <<<
  Iter   0: Loss=0.4137 | Vel=[-4.99991962  4.99993124] | V=0.500
  Iter  20: Loss=0.0590 | Vel=[-70.98826726  76.79999785] | V=10.779
  Iter  40: Loss=0.0132 | Vel=[-60.67740417  64.05549699] | V=12.061
  Iter  60: Loss=0.0038 | Vel=[-55.33345614  64.17209336] | V=4.109
  Iter  80: Loss=0.0021 | Vel=[-49.41404511  58.09059796] | V=-1.214
  Iter 100: Loss=0.0018 | Vel=[-40.93813899  51.22133214] | V=-1.542
  Iter 120: Loss=0.0015 | Vel=[-32.18231407  44.42601499] | V=-0.865
  Iter 140: Loss=0.0012 | Vel=[-23.70211689  37.55957809] | V=-0.821
  Iter 160: Loss=0.0010 | Vel=[-15.62255942  30.9099544 ] | V=-0.855
  Iter 180: Loss=0.0008 | Vel=[-7.97006363 24.55144065] | V=-0.827
  Iter 200: Loss=0.0007 | Vel=[-0.81004635 18.53451075] | V=-0.808

>>> STAGE 1 (Epsilon=0.5) <<<
  Iter   0: Loss=0.0006 | Vel=[-0.45961679 18.24462019] | V=-0

In [8]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
import jax.scipy.sparse.linalg as spla
import optax
#from dsl_core import *
#from dsl_problem import *
#from sta_inverse_precision import MatrixAlgebra, CartesianBox 

# --- 1. ROBUST PHYSICS ENGINE (The "Backend") ---
# Instead of auto-splitting, we use the proven Matrix-Free logic directly.

class RobustHybridSolver:
    def __init__(self, algebra, manifold):
        self.algebra = algebra
        self.manifold = manifold
        self.bases = tuple(self.algebra.basis_matrices)

    def compile_step(self, problem: Problem):
        """
        Hard-coded compilation of: dt(u) = -c*grad(u) + eps*lap(u) - V*u
        This bypasses the fragile AST walker for the physics core.
        """
        
        # 1. Source & BC Compilers (These work fine)
        # We assume ProblemCompiler logic for sources is reused or inlined here
        source_fns = [self._compile_source(s) for s in problem.sources]
        bc_fns = [self._compile_bc(b) for b in problem.bcs]
        
        # 2. The Proven Matrix-Free Kernels
        def eval_advection(u, grads, params):
            # Explicit Advection: -c * (basis . grad)
            c = params['c']
            rate = jnp.zeros_like(u)
            for i, basis in enumerate(self.bases):
                term = jnp.einsum('kj,...j->...k', basis, grads[i])
                rate = rate - c * term
            return rate

        def eval_implicit(u, lap, params):
            # Implicit Diffusion/Decay: eps*lap - V*u
            rate = jnp.zeros_like(u)
            if 'epsilon' in params: rate += params['epsilon'] * lap
            if 'V' in params: rate -= params['V'] * u
            return rate

        # 3. The Master Step
        def step_fn(u_curr, params, dt, step_idx):
            # --- A. PHYSICS (Matrix-Free Hybrid) ---
            
            # 1. Explicit MacCormack (Advection)
            gf, gb, lap = self.manifold.gradients(u_curr)
            
            # Predict
            k1 = eval_advection(u_curr, gf, params)
            u_p = u_curr + k1 * dt
            
            # Correct
            gf_p, gb_p, _ = self.manifold.gradients(u_p)
            k2 = eval_advection(u_p, gb_p, params)
            
            u_adv = 0.5 * (u_curr + u_p + k2 * dt)
            
            # 2. Implicit CG (Diffusion)
            def linear_op(x_flat):
                x = x_flat.reshape(u_curr.shape)
                _, _, lap_x = self.manifold.gradients(x)
                rate = eval_implicit(x, lap_x, params)
                return (x - dt * rate).ravel()
            
            u_flat = u_adv.ravel()
            # If eps=0, this is Identity solve (instant)
            u_next_flat, _ = spla.cg(linear_op, u_flat, x0=u_flat, tol=1e-5, maxiter=20)
            u_next = u_next_flat.reshape(u_curr.shape)
            
            # --- B. SOURCES ---
            t = step_idx * dt
            for src_fn in source_fns:
                u_next = u_next + src_fn(params, t, step_idx, self.manifold.coordinates)
                
            # --- C. BOUNDARIES ---
            for bc_fn in bc_fns:
                u_next = bc_fn(u_next)
                
            return u_next

        return step_fn

    # (Inlined Source/BC compilers for completeness)
    def _compile_source(self, src):
        # ... (Same logic as before) ...
        amp = src.temporal_profile.amplitude
        t_center = src.temporal_profile.center_step
        t_width = src.temporal_profile.width_step
        sigma_sq = src.spatial_sigma
        
        pos_name = src.trajectory.start_pos.name
        vel_name = src.trajectory.velocity.name
        
        def kernel(params, t, i, coords):
            X, Y = coords
            p0, v = params[pos_name], params[vel_name]
            x_c = p0[0] + v[0] * t
            y_c = p0[1] + v[1] * t
            
            dist_sq = (X - x_c)**2 + (Y - y_c)**2
            val = jnp.exp(-dist_sq / sigma_sq) * \
                  jnp.exp(-(i - t_center)**2 / t_width) * amp * 0.0001 # dt hardcoded or passed?
            
            zeros = jnp.zeros_like(val)
            return jnp.stack([val, zeros, zeros, zeros], axis=-1)
        return kernel

    def _compile_bc(self, bc):
        return self.manifold.enforce_boundaries

# --- 2. CONFIGURATION ---
GRID_SIZE = 100
DX = 0.1
DT = 0.0001
STEPS = 1200

algebra = MatrixAlgebra(2, 0)
manifold = CartesianBox((GRID_SIZE, GRID_SIZE), DX)

# --- 3. DECLARATIVE SETUP (The User's View) ---
Psi = Field("Psi")
c, V, eps = Parameter("c"), Parameter("V"), Parameter("epsilon")
pos, vel = Parameter("pos"), Parameter("vel")

# Equation (Symbolic only now, but guides the architecture)
eq = Eq(dt(Psi), -(c * grad(Psi)) + eps * laplacian(Psi) - V * Psi)

source = MovingSpotlight(
    field=Psi,
    trajectory=LinearTrajectory(start_pos=pos, velocity=vel),
    spatial_sigma=4.5,
    temporal_profile=GaussianPulse(center_step=50, width_step=800.0, amplitude=100.0)
)
problem = Problem(eq, [DirichletBC(Psi)], [source])

# --- 4. COMPILATION & EXECUTION ---
solver = RobustHybridSolver(algebra, manifold)
step_fn = solver.compile_step(problem)

print("--- Robust Hybrid Solver Compiled ---")

@jax.jit
def run_sim(params_dict, hyper_params):
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    p = {**params_dict, **hyper_params}
    
    def body(carry, i):
        u_next = step_fn(carry, p, DT, i)
        return u_next, u_next[..., 0] # Return full state + Scalar for recording

    _, history = lax.scan(body, u, jnp.arange(STEPS))
    return history

def solve_robust():
    # TRUTH
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0]), 'c': 343.0, 'V': 0.5}
    print("Generating Ground Truth...")
    true_hist = run_sim(TRUE_PARAMS, {'epsilon': 0.0})
    
    SENSORS = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])
    obs_data = true_hist[:, SENSORS[:,0], SENSORS[:,1]]
    
    # INVERSE
    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0]), 'c': 300.0, 'V': 0.0}
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.multi_transform(
        {'pos': optax.adam(0.1), 'vel': optax.adam(5.0), 'c': optax.adam(1.0), 'V': optax.adam(0.5)},
        {'pos':'pos', 'vel':'vel', 'c':'c', 'V':'V'}))
    opt_state = optimizer.init(guess)
    
    @jax.jit
    def update(state, guess, eps):
        def loss(p):
            sim = run_sim(p, {'epsilon': eps})
            dat = sim[:, SENSORS[:,0], SENSORS[:,1]]
            safe = 1e-6
            sim_n = (dat - jnp.mean(dat, 0)) / (jnp.std(dat, 0) + safe)
            obs_n = (obs_data - jnp.mean(obs_data, 0)) / (jnp.std(obs_data, 0) + safe)
            pb = jnp.sum(jnp.maximum(0, -p['pos'])) + jnp.sum(jnp.maximum(0, p['pos'] - 10.0))
            return (1.0 - jnp.mean(sim_n * obs_n)) + pb
        l, g = value_and_grad(loss)(guess)
        u, s = optimizer.update(g, state, guess)
        return l, optax.apply_updates(guess, u), s

    print("\n--- Starting Robust Precision Search ---")
    # Using small epsilon steps to guide it
    for stage, eps in enumerate([2.0, 0.5, 0.0]):
        print(f"\n>>> STAGE {stage} (Epsilon={eps}) <<<")
        for i in range(201):
            loss, new_guess, opt_state = update(opt_state, guess, eps)
            
            d_vel = jnp.linalg.norm(new_guess['vel'] - guess['vel'])
            guess = new_guess
            
            if i > 20 and d_vel < 0.01:
                print(f"  [CONVERGED] Iter {i} | Vel Delta={d_vel:.4f}")
                print(f"  Current: Vel={guess['vel']} | V={guess['V']:.3f}")
                break
            if i % 20 == 0:
                print(f"  Iter {i:3d}: Loss={loss:.4f} | Vel={guess['vel']} | V={guess['V']:.3f}")
    return guess

final = solve_robust()
print("\n--- FINAL RESULT ---")
print(f"Rec Pos: {final['pos']}")
print(f"Rec Vel: {final['vel']}")
print(f"Rec V  : {final['V']:.3f}")

--- Robust Hybrid Solver Compiled ---
Generating Ground Truth...

--- Starting Robust Precision Search ---

>>> STAGE 0 (Epsilon=2.0) <<<
  Iter   0: Loss=0.4141 | Vel=[-4.99991965  4.99993128] | V=0.500
  Iter  20: Loss=0.0611 | Vel=[-70.46447249  76.37947397] | V=10.926
  Iter  40: Loss=0.0105 | Vel=[-60.79204744  64.00388918] | V=15.385
  Iter  60: Loss=0.0042 | Vel=[-56.06215855  64.87621874] | V=10.309
  Iter  80: Loss=0.0025 | Vel=[-49.93899401  58.59868025] | V=3.741
  Iter 100: Loss=0.0018 | Vel=[-40.90998369  51.18224977] | V=-0.540
  Iter 120: Loss=0.0015 | Vel=[-32.0567892   44.30185107] | V=-2.176
  Iter 140: Loss=0.0012 | Vel=[-23.687766    37.55173961] | V=-2.394
  Iter 160: Loss=0.0010 | Vel=[-15.62705178  30.92458294] | V=-2.266
  Iter 180: Loss=0.0008 | Vel=[-7.9584033 24.5530991] | V=-2.161
  Iter 200: Loss=0.0007 | Vel=[-0.79946236 18.53976131] | V=-2.105

>>> STAGE 1 (Epsilon=0.5) <<<
  Iter   0: Loss=0.0006 | Vel=[-0.44914677 18.25010379] | V=-2.094
  Iter  20: Los