In [2]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
import optax

# ==============================================================================
# 1. SETUP
# ==============================================================================
GRID_SIZE = 100
DX = 0.1
DT = 0.0001
C_SOUND = 343.0
DURATION_STEPS = 1200 # Increased to 0.12s to help velocity tracking

IDX_P = 0
IDX_VX = 1
IDX_VY = 2

# ==============================================================================
# 2. STABLE MACCORMACK KERNEL
# ==============================================================================

@jit
def get_gradients_fwd(u):
    u_pad = jnp.pad(u, ((1, 1), (1, 1), (0, 0)), mode='constant')
    u_center = u_pad[1:-1, 1:-1]
    
    # Forward Diff
    d_dy = (u_pad[2:, 1:-1] - u_center) / DX
    d_dx = (u_pad[1:-1, 2:] - u_center) / DX
    
    lap = (u_pad[2:, 1:-1] + u_pad[0:-2, 1:-1] + \
           u_pad[1:-1, 2:] + u_pad[1:-1, 0:-2] - 4*u_center) / (DX**2)
    return d_dx, d_dy, lap

@jit
def get_gradients_bwd(u):
    u_pad = jnp.pad(u, ((1, 1), (1, 1), (0, 0)), mode='constant')
    u_center = u_pad[1:-1, 1:-1]
    
    # Backward Diff
    d_dy = (u_center - u_pad[0:-2, 1:-1]) / DX
    d_dx = (u_center - u_pad[1:-1, 0:-2]) / DX
        
    lap = (u_pad[2:, 1:-1] + u_pad[0:-2, 1:-1] + \
           u_pad[1:-1, 2:] + u_pad[1:-1, 0:-2] - 4*u_center) / (DX**2)
    return d_dx, d_dy, lap

@jit
def physics_update(u, d_dx, d_dy, lap, epsilon):
    delta = jnp.zeros_like(u)
    div_v = d_dx[..., IDX_VX] + d_dy[..., IDX_VY]
    
    rate_p = -C_SOUND * div_v + epsilon * lap[..., IDX_P]
    rate_vx = -C_SOUND * d_dx[..., IDX_P] + epsilon * lap[..., IDX_VX]
    rate_vy = -C_SOUND * d_dy[..., IDX_P] + epsilon * lap[..., IDX_VY]
    
    delta = delta.at[..., IDX_P].set(rate_p)
    delta = delta.at[..., IDX_VX].set(rate_vx)
    delta = delta.at[..., IDX_VY].set(rate_vy)
    return delta

@jit
def full_step_maccormack(u, epsilon):
    d_dx_f, d_dy_f, lap = get_gradients_fwd(u)
    rate_pred = physics_update(u, d_dx_f, d_dy_f, lap, epsilon)
    u_pred = u + rate_pred * DT
    
    d_dx_b, d_dy_b, lap_b = get_gradients_bwd(u_pred)
    rate_corr = physics_update(u_pred, d_dx_b, d_dy_b, lap_b, epsilon)
    
    u_next = 0.5 * (u + u_pred + rate_corr * DT)
    
    # Hard Wall Boundaries
    u_next = u_next.at[0, :, :].set(0)
    u_next = u_next.at[-1, :, :].set(0)
    u_next = u_next.at[:, 0, :].set(0)
    u_next = u_next.at[:, -1, :].set(0)
    
    return u_next

# ==============================================================================
# 3. SIMULATION LOOP (With Split Params)
# ==============================================================================

@jit
def run_simulation(params_dict, epsilon):
    # Unpack Dictionary
    pos = params_dict['pos']
    vel = params_dict['vel']
    
    x0, y0 = pos[0], pos[1]
    vx, vy = vel[0], vel[1]
    
    xs = jnp.linspace(0, GRID_SIZE*DX, GRID_SIZE)
    ys = jnp.linspace(0, GRID_SIZE*DX, GRID_SIZE)
    X, Y = jnp.meshgrid(xs, ys, indexing='xy')
    
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, 3))
    
    # Blob Size
    blob_width = 0.5 + epsilon * 2.0
    
    def body_fn(carry, step_idx):
        u_curr = carry
        current_time = step_idx * DT
        
        # Physics
        u_next = full_step_maccormack(u_curr, epsilon)
        
        # Moving Source
        pos_x = x0 + vx * current_time
        pos_y = y0 + vy * current_time
        
        # Injection
        dist_sq = (X - pos_x)**2 + (Y - pos_y)**2
        spatial = jnp.exp(-dist_sq / (2 * blob_width**2))
        
        # Temporal "Ping" (Gaussian in time)
        # Peak at step 50, width 20
        amplitude = jnp.exp(-(step_idx - 50)**2 / (2 * 20.0**2)) * 100.0
        
        source = spatial * amplitude * DT
        u_next = u_next.at[..., IDX_P].add(source)
        
        return u_next, u_next[..., IDX_P]

    final_u, history_p = lax.scan(body_fn, u, jnp.arange(DURATION_STEPS))
    return history_p

# ==============================================================================
# 4. INVERSE SOLVER (Differential Learning Rates)
# ==============================================================================

# True State: Moves [3, 7] -> [40, -20]
TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0])}

SENSORS_RC = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])

print(f"--- Generating Ground Truth ---")
true_history = run_simulation(TRUE_PARAMS, epsilon=0.0)
observed_data = true_history[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]

def loss_fn(est_params_dict, epsilon):
    sim_hist = run_simulation(est_params_dict, epsilon)
    sim_data = sim_hist[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]
    
    # Normalize (Correlation)
    safe_eps = 1e-6
    sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
    obs_norm = (observed_data - jnp.mean(observed_data, 0)) / (jnp.std(observed_data, 0) + safe_eps)
    
    corr = jnp.mean(sim_norm * obs_norm)
    
    # Boundary Penalty (Position Only)
    pos = est_params_dict['pos']
    bounds = jnp.sum(jnp.maximum(0, -pos)) + jnp.sum(jnp.maximum(0, pos - GRID_SIZE*DX))
    
    return (1.0 - corr) + bounds

def solve():
    # Initial Guess: Static at Center
    guess = {
        'pos': jnp.array([5.0, 5.0]),
        'vel': jnp.array([0.0, 0.0])
    }
    
    # DIFFERENTIAL LEARNING RATES
    # Velocity gets a 50x multiplier to match Position gradients
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {
                'pos': optax.adam(0.1),
                'vel': optax.adam(5.0)  # The Fix: 5.0 LR for Velocity
            },
            param_labels=lambda p: 'vel' if 'vel' in str(p) else 'pos' # Simple mapping won't work on dict keys directly in newer optax
            # We map manually below
        )
    )
    
    # Manual Mapping for optax.multi_transform on Dicts
    # We define the transform structure to match the params dict structure
    partition_optimizers = {
        'pos': optax.adam(0.1),
        'vel': optax.adam(5.0)
    }
    
    # Correct way to apply multi_transform to a PyTree (Dict)
    # We assign labels matching the keys
    param_spec = {'pos': 'pos', 'vel': 'vel'}
    
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(partition_optimizers, param_spec)
    )

    opt_state = optimizer.init(guess)
    
    epsilon_schedule = [2.0, 0.5, 0.0]
    
    print(f"\n--- Starting Search ---")
    
    for stage, eps in enumerate(epsilon_schedule):
        print(f"\n>>> ENTERING STAGE {stage} (Slack Epsilon = {eps}) <<<")
        for i in range(100): 
            loss, grads = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(grads, opt_state, guess)
            guess = optax.apply_updates(guess, updates)
            
            if i % 10 == 0:
                p = guess['pos']
                v = guess['vel']
                print(f"   Iter {i}: Loss {loss:.5f}")
                print(f"      Est: Pos=[{p[0]:.2f}, {p[1]:.2f}] Vel=[{v[0]:.2f}, {v[1]:.2f}]")
                
    print(f"\n--- FINAL RESULT ---")
    p_final = guess['pos']
    v_final = guess['vel']
    
    print(f"Estimated: Pos={p_final} Vel={v_final}")
    print(f"True     : Pos={TRUE_PARAMS['pos']} Vel={TRUE_PARAMS['vel']}")
    
    pos_err = jnp.linalg.norm(p_final - TRUE_PARAMS['pos'])
    vel_err = jnp.linalg.norm(v_final - TRUE_PARAMS['vel'])
    
    print(f"Pos Error: {pos_err:.4f} m")
    print(f"Vel Error: {vel_err:.4f} m/s")

if __name__ == "__main__":
    solve()

--- Generating Ground Truth ---

--- Starting Search ---

>>> ENTERING STAGE 0 (Slack Epsilon = 2.0) <<<
   Iter 0: Loss 0.70650
      Est: Pos=[4.90, 5.10] Vel=[-5.00, 5.00]
   Iter 10: Loss 0.65549
      Est: Pos=[3.91, 6.09] Vel=[-54.23, 54.30]
   Iter 20: Loss 0.62788
      Est: Pos=[3.04, 6.97] Vel=[-97.33, 98.10]
   Iter 30: Loss 0.62065
      Est: Pos=[2.44, 7.61] Vel=[-125.68, 128.66]
   Iter 40: Loss 0.62066
      Est: Pos=[2.19, 7.93] Vel=[-136.09, 142.57]
   Iter 50: Loss 0.62081
      Est: Pos=[2.18, 7.99] Vel=[-133.45, 143.25]
   Iter 60: Loss 0.62048
      Est: Pos=[2.28, 7.92] Vel=[-125.47, 137.19]
   Iter 70: Loss 0.62031
      Est: Pos=[2.36, 7.84] Vel=[-117.99, 130.03]
   Iter 80: Loss 0.62029
      Est: Pos=[2.38, 7.80] Vel=[-113.27, 124.75]
   Iter 90: Loss 0.62027
      Est: Pos=[2.35, 7.81] Vel=[-110.62, 121.62]

>>> ENTERING STAGE 1 (Slack Epsilon = 0.5) <<<
   Iter 0: Loss 0.66657
      Est: Pos=[2.41, 7.75] Vel=[-104.04, 115.20]
   Iter 10: Loss 0.30874
      E

In [3]:
import jax
import jax.numpy as jnp
from jax import jit
from typing import NamedTuple, Callable

# ==============================================================================
# 1. THE MODEL ABSTRACTION (The "Local" Physics)
# ==============================================================================

class PhysParams(NamedTuple):
    c: float
    epsilon: float

def acoustic_constitutive_law(u_state, grads, laplacian, params: PhysParams):
    """
    PURE PHYSICS. No grids, no loops, no padding.
    
    Inputs:
      u_state:   [P, Vx, Vy] at a specific point (or vectorized)
      grads:     ([dP/dx, dVx/dx, dVy/dx], [dP/dy, dVx/dy, dVy/dy])
      laplacian: [Lap_P, Lap_Vx, Lap_Vy]
      params:    Physical constants (Speed of sound, Slack)
      
    Returns:
      d_state/dt: [Rate_P, Rate_Vx, Rate_Vy]
    """
    # Unpack State (for readability)
    # 0: Pressure, 1: Vx, 2: Vy
    
    # Unpack Gradients
    # d_dx[0] is dP/dx, d_dx[1] is dVx/dx, etc.
    d_dx, d_dy = grads
    
    # --- THE PHYSICAL EQUATIONS ---
    # 1. Mass Conservation: dP/dt = -c * Div(V) + diffusion
    div_v = d_dx[1] + d_dy[2] # d(Vx)/dx + d(Vy)/dy
    rate_p = -params.c * div_v + params.epsilon * laplacian[0]
    
    # 2. Momentum Conservation X: dVx/dt = -c * Grad(P)_x + diffusion
    grad_p_x = d_dx[0]
    rate_vx = -params.c * grad_p_x + params.epsilon * laplacian[1]
    
    # 3. Momentum Conservation Y: dVy/dt = -c * Grad(P)_y + diffusion
    grad_p_y = d_dy[0]
    rate_vy = -params.c * grad_p_y + params.epsilon * laplacian[2]
    
    # Pack Result
    return jnp.stack([rate_p, rate_vx, rate_vy], axis=-1)

# ==============================================================================
# 2. THE TOPOLOGY ENGINE (The Grid)
# ==============================================================================

@jit
def calculate_stencil(u, dx):
    """
    Calculates Forward and Backward spatial derivatives using Hard Walls.
    Returns a NamedTuple-like structure or just raw tuples of gradients.
    """
    # Pad for Hard Walls: ((BeforeY, AfterY), (BeforeX, AfterX), (Ch))
    u_pad = jnp.pad(u, ((1, 1), (1, 1), (0, 0)), mode='constant')
    u_center = u_pad[1:-1, 1:-1]
    
    # --- Forward Gradients (Predictor) ---
    # d/dy (Rows, Axis 0)
    dy_fwd = (u_pad[2:, 1:-1] - u_center) / dx
    # d/dx (Cols, Axis 1)
    dx_fwd = (u_pad[1:-1, 2:] - u_center) / dx
    
    # --- Backward Gradients (Corrector) ---
    dy_bwd = (u_center - u_pad[0:-2, 1:-1]) / dx
    dx_bwd = (u_center - u_pad[1:-1, 0:-2]) / dx
    
    # --- Laplacian (Central) ---
    lap = (u_pad[2:, 1:-1] + u_pad[0:-2, 1:-1] + \
           u_pad[1:-1, 2:] + u_pad[1:-1, 0:-2] - 4*u_center) / (dx**2)
           
    return (dx_fwd, dy_fwd), (dx_bwd, dy_bwd), lap

# ==============================================================================
# 3. THE INTEGRATOR (The Solver)
# ==============================================================================

@jit
def maccormack_integrator(u, dt, dx, params: PhysParams, physics_fn: Callable):
    """
    Generic Predictor-Corrector Integrator.
    It doesn't know it's solving acoustics. It just calls 'physics_fn'.
    """
    # 1. Get Geometry
    grads_fwd, grads_bwd, lap = calculate_stencil(u, dx)
    
    # 2. Predictor Step (Using Forward Gradients)
    # physics_fn(u, (d_dx, d_dy), lap, params)
    k1 = physics_fn(u, grads_fwd, lap, params)
    u_pred = u + k1 * dt
    
    # 3. Corrector Step (Using Backward Gradients on Predicted State)
    # We must re-calculate gradients on the PREDICTED state
    # Note: We only need Backward gradients here, but our helper gets both.
    # Optimization: In pure JIT, unused outputs are pruned.
    _, grads_bwd_pred, lap_pred = calculate_stencil(u_pred, dx)
    
    k2 = physics_fn(u_pred, grads_bwd_pred, lap_pred, params)
    
    # 4. Update
    u_next = 0.5 * (u + u_pred + k2 * dt)
    
    # 5. Boundary Enforcement (Hard Walls)
    u_next = u_next.at[0, :, :].set(0)
    u_next = u_next.at[-1, :, :].set(0)
    u_next = u_next.at[:, 0, :].set(0)
    u_next = u_next.at[:, -1, :].set(0)
    
    return u_next

In [4]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
from typing import NamedTuple, Callable, Tuple
import optax
import functools # Import needed for partial/wraps if manual, but static_argnames is easier

# ==============================================================================
# 1. THE PHYSICS MODEL
# ==============================================================================
IDX_P, IDX_VX, IDX_VY = 0, 1, 2

class PhysParams(NamedTuple):
    c: float
    epsilon: float

@jit
def acoustic_constitutive_law(u, grads: Tuple, laplacian, params: PhysParams):
    d_dx, d_dy = grads
    
    # 1. Mass: dP/dt = -c * Div(V)
    div_v = d_dx[..., IDX_VX] + d_dy[..., IDX_VY]
    rate_p = -params.c * div_v + params.epsilon * laplacian[..., IDX_P]
    
    # 2. Momentum X: d(Vx)/dt = -c * dP/dx
    grad_p_x = d_dx[..., IDX_P]
    rate_vx = -params.c * grad_p_x + params.epsilon * laplacian[..., IDX_VX]
    
    # 3. Momentum Y: d(Vy)/dt = -c * dP/dy
    grad_p_y = d_dy[..., IDX_P]
    rate_vy = -params.c * grad_p_y + params.epsilon * laplacian[..., IDX_VY]
    
    return jnp.stack([rate_p, rate_vx, rate_vy], axis=-1)

# ==============================================================================
# 2. THE TOPOLOGY ENGINE
# ==============================================================================

@jit
def calculate_stencil(u, dx):
    """ Calculates Forward/Backward gradients + Laplacian with Hard Wall padding """
    u_pad = jnp.pad(u, ((1, 1), (1, 1), (0, 0)), mode='constant')
    u_center = u_pad[1:-1, 1:-1]
    
    # Forward Gradients
    dx_fwd = (u_pad[1:-1, 2:] - u_center) / dx
    dy_fwd = (u_pad[2:, 1:-1] - u_center) / dx
    
    # Backward Gradients
    dx_bwd = (u_center - u_pad[1:-1, 0:-2]) / dx
    dy_bwd = (u_center - u_pad[0:-2, 1:-1]) / dx
    
    # Laplacian
    lap = (u_pad[2:, 1:-1] + u_pad[0:-2, 1:-1] + \
           u_pad[1:-1, 2:] + u_pad[1:-1, 0:-2] - 4*u_center) / (dx**2)
           
    return (dx_fwd, dy_fwd), (dx_bwd, dy_bwd), lap

# FIX: Added static_argnames=['model_fn']
@functools.partial(jit, static_argnames=['model_fn'])
def maccormack_step(u, dt, dx, params: PhysParams, model_fn: Callable):
    """
    Generic Predictor-Corrector.
    model_fn is marked STATIC so JAX doesn't try to trace it as an array.
    """
    # 1. Geometry
    grads_fwd, grads_bwd, lap = calculate_stencil(u, dx)
    
    # 2. Predictor
    k1 = model_fn(u, grads_fwd, lap, params)
    u_pred = u + k1 * dt
    
    # 3. Corrector (re-calc gradients on prediction)
    _, grads_bwd_pred, lap_pred = calculate_stencil(u_pred, dx)
    k2 = model_fn(u_pred, grads_bwd_pred, lap_pred, params)
    
    # 4. Update
    u_next = 0.5 * (u + u_pred + k2 * dt)
    
    # 5. Boundaries (Hard Walls)
    u_next = u_next.at[0, :, :].set(0)
    u_next = u_next.at[-1, :, :].set(0)
    u_next = u_next.at[:, 0, :].set(0)
    u_next = u_next.at[:, -1, :].set(0)
    
    return u_next

# ==============================================================================
# 3. SIMULATION RUNNER
# ==============================================================================

GRID_SIZE = 100
DX = 0.1
DT = 0.0001
DURATION_STEPS = 1200 

@jit
def run_simulation(params_dict, epsilon):
    x0, y0 = params_dict['pos'][0], params_dict['pos'][1]
    vx, vy = params_dict['vel'][0], params_dict['vel'][1]
    
    phys_params = PhysParams(c=343.0, epsilon=epsilon)
    
    xs = jnp.linspace(0, GRID_SIZE*DX, GRID_SIZE)
    ys = jnp.linspace(0, GRID_SIZE*DX, GRID_SIZE)
    X, Y = jnp.meshgrid(xs, ys, indexing='xy')
    
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, 3))
    
    blob_width = 0.5 + epsilon * 2.0
    
    def body_fn(carry, step_idx):
        u_curr = carry
        
        # A. INTEGRATE (Pass the function!)
        u_next = maccormack_step(u_curr, DT, DX, phys_params, acoustic_constitutive_law)
        
        # B. FORCING
        current_time = step_idx * DT
        pos_x = x0 + vx * current_time
        pos_y = y0 + vy * current_time
        
        dist_sq = (X - pos_x)**2 + (Y - pos_y)**2
        spatial = jnp.exp(-dist_sq / (2 * blob_width**2))
        amplitude = jnp.exp(-(step_idx - 50)**2 / (2 * 20.0**2)) * 100.0
        
        source = spatial * amplitude * DT
        u_next = u_next.at[..., IDX_P].add(source)
        
        return u_next, u_next[..., IDX_P]

    final_u, history_p = lax.scan(body_fn, u, jnp.arange(DURATION_STEPS))
    return history_p

# ==============================================================================
# 4. INVERSE SOLVER
# ==============================================================================

TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0])}
SENSORS_RC = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])

print(f"--- Generating Ground Truth ---")
true_history = run_simulation(TRUE_PARAMS, epsilon=0.0)
observed_data = true_history[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]

def loss_fn(est_params_dict, epsilon):
    sim_hist = run_simulation(est_params_dict, epsilon)
    sim_data = sim_hist[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]
    
    safe_eps = 1e-6
    sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
    obs_norm = (observed_data - jnp.mean(observed_data, 0)) / (jnp.std(observed_data, 0) + safe_eps)
    corr = jnp.mean(sim_norm * obs_norm)
    
    pos = est_params_dict['pos']
    bounds = jnp.sum(jnp.maximum(0, -pos)) + jnp.sum(jnp.maximum(0, pos - GRID_SIZE*DX))
    return (1.0 - corr) + bounds

def solve():
    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0])}
    
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {'pos': optax.adam(0.1), 'vel': optax.adam(5.0)},
            {'pos': 'pos', 'vel': 'vel'}
        )
    )
    opt_state = optimizer.init(guess)
    
    epsilon_schedule = [2.0, 0.5, 0.0]
    
    print(f"\n--- Starting Search ---")
    for stage, eps in enumerate(epsilon_schedule):
        print(f"\n>>> ENTERING STAGE {stage} (Slack Epsilon = {eps}) <<<")
        for i in range(100): 
            loss, grads = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(grads, opt_state, guess)
            guess = optax.apply_updates(guess, updates)
            
            if i % 10 == 0:
                p = guess['pos']
                v = guess['vel']
                print(f"   Iter {i}: Loss {loss:.5f} | Pos=[{p[0]:.2f}, {p[1]:.2f}] Vel=[{v[0]:.2f}, {v[1]:.2f}]")

    print(f"\n--- FINAL RESULT ---")
    print(f"Estimated: Pos={guess['pos']} Vel={guess['vel']}")
    print(f"True     : Pos={TRUE_PARAMS['pos']} Vel={TRUE_PARAMS['vel']}")

if __name__ == "__main__":
    solve()

--- Generating Ground Truth ---

--- Starting Search ---

>>> ENTERING STAGE 0 (Slack Epsilon = 2.0) <<<
   Iter 0: Loss 0.70650 | Pos=[4.90, 5.10] Vel=[-5.00, 5.00]
   Iter 10: Loss 0.65549 | Pos=[3.91, 6.09] Vel=[-54.23, 54.30]
   Iter 20: Loss 0.62788 | Pos=[3.04, 6.97] Vel=[-97.33, 98.10]
   Iter 30: Loss 0.62065 | Pos=[2.44, 7.61] Vel=[-125.68, 128.66]
   Iter 40: Loss 0.62066 | Pos=[2.19, 7.93] Vel=[-136.09, 142.57]
   Iter 50: Loss 0.62081 | Pos=[2.18, 7.99] Vel=[-133.45, 143.25]
   Iter 60: Loss 0.62048 | Pos=[2.28, 7.92] Vel=[-125.47, 137.19]
   Iter 70: Loss 0.62031 | Pos=[2.36, 7.84] Vel=[-117.99, 130.03]
   Iter 80: Loss 0.62029 | Pos=[2.38, 7.80] Vel=[-113.27, 124.75]
   Iter 90: Loss 0.62027 | Pos=[2.35, 7.81] Vel=[-110.62, 121.62]

>>> ENTERING STAGE 1 (Slack Epsilon = 0.5) <<<
   Iter 0: Loss 0.66657 | Pos=[2.41, 7.75] Vel=[-104.04, 115.20]
   Iter 10: Loss 0.30874 | Pos=[3.89, 6.21] Vel=[-27.42, 35.90]
   Iter 20: Loss 0.25850 | Pos=[3.34, 6.67] Vel=[-50.58, 54.04]
   

In [5]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
from typing import NamedTuple, Callable, List
import optax
import functools
from kingdon import Algebra

# ==============================================================================
# 1. THE MATRIX-ALGEBRA BRIDGE (Robust & Fast)
# ==============================================================================

class MatrixAlgebra:
    """
    Pre-computes the Geometric Product matrices for the basis vectors.
    Allows performing 'e_i * Field' using pure linear algebra (matmul).
    """
    def __init__(self, p, q, r=0):
        self.p, self.q, self.r = p, q, r
        self.alg = Algebra(p, q, r)
        self.dim = len(self.alg) # Should be 2^(p+q+r)
        
        # 1. Identify Basis Vectors (e1, e2...) to iterate over in the physics loop
        # We need the 1-vectors (generators).
        # We access them safely via the blades dict.
        self.basis_names = [f'e{i+1}' for i in range(p + q + r)]
        
        # 2. Generate Multiplication Matrices
        # For each generator e_k, we compute a matrix M_k.
        
        matrices = []
        for name in self.basis_names:
            e_k = self.alg.blades[name] # The generator (e.g., e1)
            
            # Build matrix column-by-column
            # Column j corresponds to the result of (e_k * basis_blade_j)
            cols = []
            for i in range(self.dim):
                # Robustly create the i-th canonical basis blade (1, e1, e2, e12...)
                # Index i maps to binary representation of blades.
                b_i = self.alg.multivector({i: 1})
                
                # Geometric Product: e_k * b_i
                res = e_k * b_i
                
                # Extract dense coefficients into a JAX array column
                dense_col = jnp.zeros(self.dim)
                
                # Kingdon multivectors are sparse dictionaries (key: value)
                # We iterate the keys (indices) in the result
                for bin_key, val in res.items():
                    # bin_key is the integer index of the blade (e.g., 3 for e12)
                    # We map this directly to the dense array index
                    dense_col = dense_col.at[bin_key].set(val)
                        
                cols.append(dense_col)
            
            # Stack columns to form (dim, dim) matrix for generator e_k
            M = jnp.stack(cols, axis=1)
            matrices.append(M)
            
        # Store as a JAX array stack: Shape (Spatial_Dim, Algebra_Dim, Algebra_Dim)
        self.basis_matrices = jnp.stack(matrices)

    def __hash__(self):
        return hash((self.p, self.q, self.r))

    def __eq__(self, other):
        return (self.p, self.q, self.r) == (other.p, other.q, other.r)

# ==============================================================================
# 2. THE PHYSICS KERNEL (Matrix-Based)
# ==============================================================================

IDX_P, IDX_VX, IDX_VY = 0, 1, 2

class PhysParams(NamedTuple):
    c: float
    epsilon: float

@functools.partial(jit, static_argnames=['algebra'])
def geometric_constitutive_law(u_coeffs, grads_list: List, lap_coeffs, params: PhysParams, algebra: MatrixAlgebra):
    """
    GENERIC STA PHYSICS (Matrix Implementation).
    Equation: d(Psi)/dt = -c * sum( M_i @ d_i(Psi) )
    """
    
    # 1. Initialize Rate (same shape as u)
    rate_coeffs = jnp.zeros_like(u_coeffs)
    
    # 2. Geometric Derivative Loop
    # We iterate over the spatial dimensions
    for i, grad_coeffs in enumerate(grads_list):
        if i >= len(algebra.basis_matrices): break
        
        # Get the pre-computed multiplication matrix for e_i
        # Shape: (Alg_Dim, Alg_Dim)
        M_i = algebra.basis_matrices[i]
        
        # Apply Matrix to Field: basis * gradient
        # grad_coeffs shape: (Grid, Grid, Alg_Dim)
        # We want to contract the last dimension
        # term = jnp.dot(grad_coeffs, M_i.T) 
        # But let's be explicit with einsum for clarity:
        # "...j" is input field component
        # "kj" is matrix (row k, col j) -> This is M @ v standard? 
        # No, M @ v means M_kj * v_j -> result_k.
        term = jnp.einsum('kj,...j->...k', M_i, grad_coeffs)
        
        rate_coeffs = rate_coeffs - params.c * term
        
    # 3. Geometric Slack (Scalar multiplication, no matrix needed)
    rate_coeffs = rate_coeffs + params.epsilon * lap_coeffs
    
    return rate_coeffs

# ==============================================================================
# 3. THE TOPOLOGY ENGINE (Grid)
# ==============================================================================

@functools.partial(jit, static_argnames=['ndim'])
def get_gradients_general(u, dx, ndim):
    """ N-dim Finite Difference """
    pad_width = tuple([(1, 1)] * ndim + [(0, 0)])
    u_pad = jnp.pad(u, pad_width, mode='constant')
    
    center_slice = tuple([slice(1, -1)] * ndim + [slice(None)])
    u_center = u_pad[center_slice]
    
    raw_grads_fwd = []
    raw_grads_bwd = []
    lap_sum = jnp.zeros_like(u_center)
    
    for axis in range(ndim):
        slice_fwd = [slice(1, -1)] * ndim; slice_fwd[axis] = slice(2, None); slice_fwd.append(slice(None))
        slice_bwd = [slice(1, -1)] * ndim; slice_bwd[axis] = slice(0, -2); slice_bwd.append(slice(None))
        
        u_fwd = u_pad[tuple(slice_fwd)]
        u_bwd = u_pad[tuple(slice_bwd)]
        
        d_fwd = (u_fwd - u_center) / dx
        d_bwd = (u_center - u_bwd) / dx
        
        raw_grads_fwd.append(d_fwd)
        raw_grads_bwd.append(d_bwd)
        lap_sum += (u_fwd + u_bwd - 2*u_center) / (dx**2)
    
    # Reorder [d/dY, d/dX] -> [d/dX, d/dY]
    if ndim == 2:
        grads_fwd = [raw_grads_fwd[1], raw_grads_fwd[0]]
        grads_bwd = [raw_grads_bwd[1], raw_grads_bwd[0]]
    else:
        grads_fwd = raw_grads_fwd
        grads_bwd = raw_grads_bwd
        
    return grads_fwd, grads_bwd, lap_sum

@functools.partial(jit, static_argnames=['model_fn', 'algebra', 'ndim'])
def maccormack_step(u, dt, dx, params: PhysParams, model_fn: Callable, algebra: MatrixAlgebra, ndim: int):
    # Pass 'algebra' (the wrapper) down
    grads_fwd, grads_bwd, lap = get_gradients_general(u, dx, ndim)
    
    k1 = model_fn(u, grads_fwd, lap, params, algebra)
    u_pred = u + k1 * dt
    
    _, grads_bwd_pred, lap_pred = get_gradients_general(u_pred, dx, ndim)
    k2 = model_fn(u_pred, grads_bwd_pred, lap_pred, params, algebra)
    
    u_next = 0.5 * (u + u_pred + k2 * dt)
    
    for axis in range(ndim):
        s0 = [slice(None)] * (ndim + 1); s0[axis] = 0; u_next = u_next.at[tuple(s0)].set(0)
        s1 = [slice(None)] * (ndim + 1); s1[axis] = -1; u_next = u_next.at[tuple(s1)].set(0)
    
    return u_next

# ==============================================================================
# 4. SIMULATION RUNNER
# ==============================================================================

GRID_SIZE = 100
DX = 0.1
DT = 0.0001
DURATION_STEPS = 1200 

@functools.partial(jit, static_argnames=['algebra'])
def run_simulation(params_dict, epsilon, algebra: MatrixAlgebra):
    xs = jnp.linspace(0, GRID_SIZE*DX, GRID_SIZE)
    ys = jnp.linspace(0, GRID_SIZE*DX, GRID_SIZE)
    X, Y = jnp.meshgrid(xs, ys, indexing='xy')
    
    # Use algebra.dim (e.g., 4 for 2D, 8 for 3D)
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    
    phys_params = PhysParams(c=343.0, epsilon=epsilon)
    x0, y0 = params_dict['pos'][0], params_dict['pos'][1]
    vx, vy = params_dict['vel'][0], params_dict['vel'][1]
    blob_width = 0.5 + epsilon * 2.0
    
    def body_fn(carry, step_idx):
        u_curr = carry
        
        u_next = maccormack_step(
            u_curr, DT, DX, phys_params, 
            geometric_constitutive_law, 
            algebra, # Passing matrix wrapper
            ndim=2
        )
        
        current_time = step_idx * DT
        pos_x = x0 + vx * current_time
        pos_y = y0 + vy * current_time
        dist_sq = (X - pos_x)**2 + (Y - pos_y)**2
        source = jnp.exp(-dist_sq / (2 * blob_width**2)) * jnp.exp(-(step_idx - 50)**2 / (2 * 20.0**2)) * 100.0 * DT
        
        u_next = u_next.at[..., IDX_P].add(source)
        return u_next, u_next[..., IDX_P]

    final_u, history_p = lax.scan(body_fn, u, jnp.arange(DURATION_STEPS))
    return history_p

# ==============================================================================
# 5. EXECUTION
# ==============================================================================

def solve():
    # 1. Define World (Matrix Algebra Wrapper)
    world_2d = MatrixAlgebra(2, 0)
    
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0])}
    SENSORS_RC = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])

    print(f"--- Generating Ground Truth (Matrix Algebra) ---")
    true_history = run_simulation(TRUE_PARAMS, epsilon=0.0, algebra=world_2d)
    observed_data = true_history[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]

    def loss_fn(est_params, epsilon):
        sim_hist = run_simulation(est_params, epsilon, algebra=world_2d)
        sim_data = sim_hist[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]
        
        safe_eps = 1e-6
        sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
        obs_norm = (observed_data - jnp.mean(observed_data, 0)) / (jnp.std(observed_data, 0) + safe_eps)
        
        corr = jnp.mean(sim_norm * obs_norm)
        pos = est_params['pos']
        bounds = jnp.sum(jnp.maximum(0, -pos)) + jnp.sum(jnp.maximum(0, pos - GRID_SIZE*DX))
        return (1.0 - corr) + bounds

    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0])}
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {'pos': optax.adam(0.1), 'vel': optax.adam(5.0)},
            {'pos': 'pos', 'vel': 'vel'}
        )
    )
    opt_state = optimizer.init(guess)
    
    print(f"\n--- Starting Search ---")
    for stage, eps in enumerate([2.0, 0.5, 0.0]):
        print(f"\n>>> ENTERING STAGE {stage} (Slack Epsilon = {eps}) <<<")
        for i in range(100): 
            loss, grads = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(grads, opt_state, guess)
            guess = optax.apply_updates(guess, updates)
            if i % 10 == 0:
                p, v = guess['pos'], guess['vel']
                print(f"   Iter {i}: Loss {loss:.5f} | Pos=[{p[0]:.2f}, {p[1]:.2f}] Vel=[{v[0]:.2f}, {v[1]:.2f}]")

    print(f"\n--- FINAL RESULT ---")
    print(f"Estimated: Pos={guess['pos']} Vel={guess['vel']}")
    print(f"True     : Pos={TRUE_PARAMS['pos']} Vel={TRUE_PARAMS['vel']}")

if __name__ == "__main__":
    solve()

--- Generating Ground Truth (Matrix Algebra) ---

--- Starting Search ---

>>> ENTERING STAGE 0 (Slack Epsilon = 2.0) <<<
   Iter 0: Loss 0.74853 | Pos=[4.90, 5.10] Vel=[-5.00, 5.00]
   Iter 10: Loss 0.70204 | Pos=[3.91, 6.09] Vel=[-54.20, 54.27]
   Iter 20: Loss 0.67788 | Pos=[3.05, 6.96] Vel=[-96.74, 97.62]
   Iter 30: Loss 0.67248 | Pos=[2.48, 7.58] Vel=[-123.16, 126.72]
   Iter 40: Loss 0.67281 | Pos=[2.27, 7.87] Vel=[-130.56, 138.31]
   Iter 50: Loss 0.67279 | Pos=[2.31, 7.90] Vel=[-125.14, 136.69]
   Iter 60: Loss 0.67242 | Pos=[2.42, 7.82] Vel=[-115.51, 129.00]
   Iter 70: Loss 0.67230 | Pos=[2.50, 7.74] Vel=[-107.61, 121.06]
   Iter 80: Loss 0.67228 | Pos=[2.50, 7.72] Vel=[-103.03, 115.49]
   Iter 90: Loss 0.67225 | Pos=[2.45, 7.74] Vel=[-100.28, 112.00]

>>> ENTERING STAGE 1 (Slack Epsilon = 0.5) <<<
   Iter 0: Loss 0.62616 | Pos=[2.50, 7.68] Vel=[-93.08, 104.82]
   Iter 10: Loss 0.33514 | Pos=[3.87, 6.21] Vel=[-22.23, 28.81]
   Iter 20: Loss 0.27620 | Pos=[3.22, 6.77] Vel=[-4

In [6]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
from typing import NamedTuple, Callable, List, Tuple
import optax
import functools
from kingdon import Algebra

# ==============================================================================
# 1. THE MANIFOLD ABSTRACTION (Topology)
# ==============================================================================

class Manifold:
    """ Base class for Grid Topology """
    def __hash__(self): raise NotImplementedError
    def __eq__(self, other): raise NotImplementedError
    
    def gradients(self, u): 
        """ Returns (fwd_grads_list, bwd_grads_list, laplacian) """
        raise NotImplementedError

    def enforce_boundaries(self, u):
        """ Applies boundary conditions (walls, wrapping, etc.) """
        return u # Default: do nothing
        
    @property
    def coordinates(self):
        """ Returns coordinate meshgrids [X, Y, ...] """
        raise NotImplementedError

class CartesianBox(Manifold):
    """ Standard 2D Grid with Hard Wall Boundaries """
    def __init__(self, size, dx):
        self.size = size
        self.dx = dx
        # Cache coordinates
        xs = jnp.linspace(0, size*dx, size)
        ys = jnp.linspace(0, size*dx, size)
        self.X, self.Y = jnp.meshgrid(xs, ys, indexing='xy')
        
    def __hash__(self): return hash((self.size, self.dx, 'box'))
    def __eq__(self, o): return (self.size, self.dx) == (o.size, o.dx) and isinstance(o, CartesianBox)
    
    @property
    def coordinates(self): return (self.X, self.Y)

    def gradients(self, u):
        # Hard Wall Padding: ((Y_pre, Y_post), (X_pre, X_post), (Ch))
        u_pad = jnp.pad(u, ((1, 1), (1, 1), (0, 0)), mode='constant')
        u_center = u_pad[1:-1, 1:-1]
        dx = self.dx
        
        # d/dx (Axis 1) and d/dy (Axis 0)
        # Forward
        dx_fwd = (u_pad[1:-1, 2:] - u_center) / dx
        dy_fwd = (u_pad[2:, 1:-1] - u_center) / dx
        # Backward
        dx_bwd = (u_center - u_pad[1:-1, 0:-2]) / dx
        dy_bwd = (u_center - u_pad[0:-2, 1:-1]) / dx
        # Laplacian
        lap = (u_pad[2:, 1:-1] + u_pad[0:-2, 1:-1] + \
               u_pad[1:-1, 2:] + u_pad[1:-1, 0:-2] - 4*u_center) / (dx**2)
               
        # Return matched lists: [d_x, d_y]
        return ([dx_fwd, dy_fwd], [dx_bwd, dy_bwd], lap)

    def enforce_boundaries(self, u):
        # Zero out edges
        u = u.at[0, :, :].set(0)
        u = u.at[-1, :, :].set(0)
        u = u.at[:, 0, :].set(0)
        u = u.at[:, -1, :].set(0)
        return u

class CartesianTorus(Manifold):
    """ 2D Grid with Periodic Boundaries (Pac-Man) """
    def __init__(self, size, dx):
        self.size = size
        self.dx = dx
        xs = jnp.linspace(0, size*dx, size)
        ys = jnp.linspace(0, size*dx, size)
        self.X, self.Y = jnp.meshgrid(xs, ys, indexing='xy')

    def __hash__(self): return hash((self.size, self.dx, 'torus'))
    def __eq__(self, o): return (self.size, self.dx) == (o.size, o.dx) and isinstance(o, CartesianTorus)

    @property
    def coordinates(self): return (self.X, self.Y)

    def gradients(self, u):
        # Periodic Wrapping via jnp.roll (No padding needed!)
        dx = self.dx
        
        # Forward
        dx_fwd = (jnp.roll(u, -1, axis=1) - u) / dx
        dy_fwd = (jnp.roll(u, -1, axis=0) - u) / dx
        # Backward
        dx_bwd = (u - jnp.roll(u, 1, axis=1)) / dx
        dy_bwd = (u - jnp.roll(u, 1, axis=0)) / dx
        # Laplacian
        lap = (jnp.roll(u, -1, axis=1) + jnp.roll(u, 1, axis=1) + \
               jnp.roll(u, -1, axis=0) + jnp.roll(u, 1, axis=0) - 4*u) / (dx**2)
               
        return ([dx_fwd, dy_fwd], [dx_bwd, dy_bwd], lap)
        
    def enforce_boundaries(self, u):
        return u # No edges to enforce!

# ==============================================================================
# 2. MATRIX ALGEBRA (Unchanged)
# ==============================================================================
class MatrixAlgebra:
    def __init__(self, p, q, r=0):
        self.p, self.q, self.r = p, q, r
        self.alg = Algebra(p, q, r)
        self.dim = len(self.alg)
        self.basis_names = [f'e{i+1}' for i in range(p + q + r)]
        
        matrices = []
        for name in self.basis_names:
            e_k = self.alg.blades[name]
            cols = []
            for i in range(self.dim):
                b_i = self.alg.multivector({i: 1})
                res = e_k * b_i
                dense_col = jnp.zeros(self.dim)
                for bin_key, val in res.items():
                    dense_col = dense_col.at[bin_key].set(val)
                cols.append(dense_col)
            M = jnp.stack(cols, axis=1)
            matrices.append(M)
        self.basis_matrices = jnp.stack(matrices)

    def __hash__(self): return hash((self.p, self.q, self.r))
    def __eq__(self, o): return (self.p, self.q, self.r) == (o.p, o.q, o.r)

# ==============================================================================
# 3. PHYSICS KERNEL (Unchanged)
# ==============================================================================
IDX_P = 0
class PhysParams(NamedTuple):
    c: float
    epsilon: float

@functools.partial(jit, static_argnames=['algebra'])
def geometric_constitutive_law(u_coeffs, grads_list: List, lap_coeffs, params: PhysParams, algebra: MatrixAlgebra):
    rate_coeffs = jnp.zeros_like(u_coeffs)
    for i, grad_coeffs in enumerate(grads_list):
        if i >= len(algebra.basis_matrices): break
        M_i = algebra.basis_matrices[i]
        term = jnp.einsum('kj,...j->...k', M_i, grad_coeffs)
        rate_coeffs = rate_coeffs - params.c * term
    rate_coeffs = rate_coeffs + params.epsilon * lap_coeffs
    return rate_coeffs

# ==============================================================================
# 4. UPDATED INTEGRATOR (Uses Manifold)
# ==============================================================================

# Now 'manifold' replaces 'dx' and 'ndim'
@functools.partial(jit, static_argnames=['model_fn', 'algebra', 'manifold'])
def maccormack_step(u, dt, params: PhysParams, model_fn: Callable, algebra: MatrixAlgebra, manifold: Manifold):
    
    # 1. Geometry (Delegated to Manifold)
    grads_fwd, grads_bwd, lap = manifold.gradients(u)
    
    # 2. Predictor
    k1 = model_fn(u, grads_fwd, lap, params, algebra)
    u_pred = u + k1 * dt
    
    # 3. Corrector (Recalculate on predicted state)
    _, grads_bwd_pred, lap_pred = manifold.gradients(u_pred)
    k2 = model_fn(u_pred, grads_bwd_pred, lap_pred, params, algebra)
    
    # 4. Average
    u_next = 0.5 * (u + u_pred + k2 * dt)
    
    # 5. Boundaries (Delegated to Manifold)
    u_next = manifold.enforce_boundaries(u_next)
    
    return u_next

# ==============================================================================
# 5. SIMULATION RUNNER (Abstracted)
# ==============================================================================

GRID_SIZE = 100
DX = 0.1
DT = 0.0001
DURATION_STEPS = 1200 

@functools.partial(jit, static_argnames=['algebra', 'manifold'])
def run_simulation(params_dict, epsilon, algebra: MatrixAlgebra, manifold: Manifold):
    # Retrieve Coordinates from Manifold
    X, Y = manifold.coordinates
    
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    
    phys_params = PhysParams(c=343.0, epsilon=epsilon)
    x0, y0 = params_dict['pos'][0], params_dict['pos'][1]
    vx, vy = params_dict['vel'][0], params_dict['vel'][1]
    blob_width = 0.5 + epsilon * 2.0
    
    def body_fn(carry, step_idx):
        u_curr = carry
        
        # Call step with Manifold
        u_next = maccormack_step(
            u_curr, DT, phys_params, 
            geometric_constitutive_law, 
            algebra, 
            manifold # <--- The abstraction!
        )
        
        current_time = step_idx * DT
        pos_x = x0 + vx * current_time
        pos_y = y0 + vy * current_time
        dist_sq = (X - pos_x)**2 + (Y - pos_y)**2
        source = jnp.exp(-dist_sq / (2 * blob_width**2)) * jnp.exp(-(step_idx - 50)**2 / (2 * 20.0**2)) * 100.0 * DT
        
        u_next = u_next.at[..., IDX_P].add(source)
        return u_next, u_next[..., IDX_P]

    final_u, history_p = lax.scan(body_fn, u, jnp.arange(DURATION_STEPS))
    return history_p

# ==============================================================================
# 6. EXECUTION
# ==============================================================================

def solve():
    # 1. Define Objects
    algebra = MatrixAlgebra(2, 0)
    
    # WE CHOOSE THE MANIFOLD HERE
    # Change to CartesianTorus(GRID_SIZE, DX) to see wrapping effects!
    manifold = CartesianBox(GRID_SIZE, DX) 
    
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0])}
    SENSORS_RC = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])

    print(f"--- Generating Ground Truth ({manifold.__class__.__name__}) ---")
    true_history = run_simulation(TRUE_PARAMS, epsilon=0.0, algebra=algebra, manifold=manifold)
    observed_data = true_history[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]

    def loss_fn(est_params, epsilon):
        sim_hist = run_simulation(est_params, epsilon, algebra=algebra, manifold=manifold)
        sim_data = sim_hist[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]
        
        safe_eps = 1e-6
        sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
        obs_norm = (observed_data - jnp.mean(observed_data, 0)) / (jnp.std(observed_data, 0) + safe_eps)
        
        corr = jnp.mean(sim_norm * obs_norm)
        pos = est_params['pos']
        # Boundary penalty should probably check Manifold bounds, 
        # but simplistic global bounds are fine for now.
        bounds = jnp.sum(jnp.maximum(0, -pos)) + jnp.sum(jnp.maximum(0, pos - GRID_SIZE*DX))
        return (1.0 - corr) + bounds

    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0])}
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {'pos': optax.adam(0.1), 'vel': optax.adam(5.0)},
            {'pos': 'pos', 'vel': 'vel'}
        )
    )
    opt_state = optimizer.init(guess)
    
    print(f"\n--- Starting Search ---")
    for stage, eps in enumerate([2.0, 0.5, 0.0]):
        print(f"\n>>> ENTERING STAGE {stage} (Slack Epsilon = {eps}) <<<")
        for i in range(100): 
            loss, grads = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(grads, opt_state, guess)
            guess = optax.apply_updates(guess, updates)
            if i % 10 == 0:
                p, v = guess['pos'], guess['vel']
                print(f"   Iter {i}: Loss {loss:.5f} | Pos=[{p[0]:.2f}, {p[1]:.2f}] Vel=[{v[0]:.2f}, {v[1]:.2f}]")

    print(f"\n--- FINAL RESULT ---")
    print(f"Estimated: Pos={guess['pos']} Vel={guess['vel']}")
    print(f"True     : Pos={TRUE_PARAMS['pos']} Vel={TRUE_PARAMS['vel']}")

if __name__ == "__main__":
    solve()

--- Generating Ground Truth (CartesianBox) ---

--- Starting Search ---

>>> ENTERING STAGE 0 (Slack Epsilon = 2.0) <<<
   Iter 0: Loss 0.74853 | Pos=[4.90, 5.10] Vel=[-5.00, 5.00]
   Iter 10: Loss 0.70204 | Pos=[3.91, 6.09] Vel=[-54.20, 54.27]
   Iter 20: Loss 0.67788 | Pos=[3.05, 6.96] Vel=[-96.74, 97.62]
   Iter 30: Loss 0.67248 | Pos=[2.48, 7.58] Vel=[-123.16, 126.72]
   Iter 40: Loss 0.67281 | Pos=[2.27, 7.87] Vel=[-130.56, 138.31]
   Iter 50: Loss 0.67279 | Pos=[2.31, 7.90] Vel=[-125.14, 136.69]
   Iter 60: Loss 0.67242 | Pos=[2.42, 7.82] Vel=[-115.51, 129.00]
   Iter 70: Loss 0.67230 | Pos=[2.50, 7.74] Vel=[-107.61, 121.06]
   Iter 80: Loss 0.67228 | Pos=[2.50, 7.72] Vel=[-103.03, 115.49]
   Iter 90: Loss 0.67225 | Pos=[2.45, 7.74] Vel=[-100.28, 112.00]

>>> ENTERING STAGE 1 (Slack Epsilon = 0.5) <<<
   Iter 0: Loss 0.62616 | Pos=[2.50, 7.68] Vel=[-93.08, 104.82]
   Iter 10: Loss 0.33514 | Pos=[3.87, 6.21] Vel=[-22.23, 28.81]
   Iter 20: Loss 0.27620 | Pos=[3.22, 6.77] Vel=[-48.

In [7]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
from typing import NamedTuple, List, Callable
import optax
import functools
from kingdon import Algebra

# ==============================================================================
# 1. CORE MATH: MATRIX ALGEBRA & MANIFOLD
# ==============================================================================
# (Same robust implementations as before)

class MatrixAlgebra:
    def __init__(self, p, q, r=0):
        self.p, self.q, self.r = p, q, r
        self.alg = Algebra(p, q, r)
        self.dim = len(self.alg)
        self.basis_names = [f'e{i+1}' for i in range(p + q + r)]
        matrices = []
        for name in self.basis_names:
            e_k = self.alg.blades[name]
            cols = []
            for i in range(self.dim):
                b_i = self.alg.multivector({i: 1})
                res = e_k * b_i
                dense_col = jnp.zeros(self.dim)
                for bin_key, val in res.items():
                    dense_col = dense_col.at[bin_key].set(val)
                cols.append(dense_col)
            M = jnp.stack(cols, axis=1)
            matrices.append(M)
        self.basis_matrices = jnp.stack(matrices)
    def __hash__(self): return hash((self.p, self.q, self.r))
    def __eq__(self, o): return (self.p, self.q, self.r) == (o.p, o.q, o.r)

class CartesianBox:
    def __init__(self, size, dx):
        self.size = size
        self.dx = dx
        xs = jnp.linspace(0, size*dx, size)
        ys = jnp.linspace(0, size*dx, size)
        self.X, self.Y = jnp.meshgrid(xs, ys, indexing='xy')
        
    def __hash__(self): return hash((self.size, self.dx))
    def __eq__(self, o): return (self.size, self.dx) == (o.size, o.dx)

    @property
    def coordinates(self): return (self.X, self.Y)

    def gradients(self, u):
        u_pad = jnp.pad(u, ((1, 1), (1, 1), (0, 0)), mode='constant')
        u_center = u_pad[1:-1, 1:-1]
        dx = self.dx
        
        # Forward
        dx_fwd = (u_pad[1:-1, 2:] - u_center) / dx
        dy_fwd = (u_pad[2:, 1:-1] - u_center) / dx
        # Backward
        dx_bwd = (u_center - u_pad[1:-1, 0:-2]) / dx
        dy_bwd = (u_center - u_pad[0:-2, 1:-1]) / dx
        # Laplacian
        lap = (u_pad[2:, 1:-1] + u_pad[0:-2, 1:-1] + \
               u_pad[1:-1, 2:] + u_pad[1:-1, 0:-2] - 4*u_center) / (dx**2)
               
        return ([dx_fwd, dy_fwd], [dx_bwd, dy_bwd], lap)

    def enforce_boundaries(self, u):
        u = u.at[0, :, :].set(0); u = u.at[-1, :, :].set(0)
        u = u.at[:, 0, :].set(0); u = u.at[:, -1, :].set(0)
        return u

# ==============================================================================
# 2. THE OPERATOR ABSTRACTION (Lifting the Terms)
# ==============================================================================

class Operator:
    """ Base class for a differential physics term """
    def __hash__(self): raise NotImplementedError
    def __eq__(self, o): raise NotImplementedError
    
    def apply(self, u, grads, lap, algebra):
        """ Returns the rate of change d(Psi)/dt contributed by this term """
        raise NotImplementedError

class GeometricAdvection(Operator):
    """
    First Order: dPsi/dt = -c * (Geometric_Gradient * Psi)
    Unifies: Acoustics, Maxwell, Dirac
    """
    def __init__(self, c):
        self.c = c
        
    def __hash__(self): return hash(('advection', self.c))
    def __eq__(self, o): return isinstance(o, GeometricAdvection) and self.c == o.c
    
    def apply(self, u, grads, lap, algebra):
        rate = jnp.zeros_like(u)
        # Sum over dimensions: -c * (e_i * d_i_u)
        for i, grad_field in enumerate(grads):
            if i >= len(algebra.basis_matrices): break
            M_i = algebra.basis_matrices[i]
            # Matrix-Vector multiply (M_i @ grad)
            term = jnp.einsum('kj,...j->...k', M_i, grad_field)
            rate = rate - self.c * term
        return rate

class GeometricDiffusion(Operator):
    """
    Second Order: dPsi/dt = epsilon * Laplacian(Psi)
    Unifies: Heat, Viscosity, SchrÃ¶dinger Kinetic Term
    """
    def __init__(self, epsilon):
        self.epsilon = epsilon
        
    def __hash__(self): return hash(('diffusion', self.epsilon))
    def __eq__(self, o): return isinstance(o, GeometricDiffusion) and self.epsilon == o.epsilon

    def apply(self, u, grads, lap, algebra):
        # Scalar multiplication of the Laplacian field
        return self.epsilon * lap

class LinearPotential(Operator):
    """
    Zero Order: dPsi/dt = -V * Psi
    Unifies: Mass, Damping, Potential Wells
    """
    def __init__(self, potential_val):
        self.V = potential_val
        
    def __hash__(self): return hash(('potential', self.V))
    def __eq__(self, o): return isinstance(o, LinearPotential) and self.V == o.V

    def apply(self, u, grads, lap, algebra):
        return -self.V * u

class CompositePDE:
    """ The Hamiltonian: A sum of operators """
    def __init__(self, operators: List[Operator]):
        self.operators = tuple(operators) # Tuple is hashable
        
    def __hash__(self): return hash(self.operators)
    def __eq__(self, o): return self.operators == o.operators
    
    def __call__(self, u, grads, lap, algebra):
        total_rate = jnp.zeros_like(u)
        for op in self.operators:
            total_rate = total_rate + op.apply(u, grads, lap, algebra)
        return total_rate

# ==============================================================================
# 3. THE UNIVERSAL INTEGRATOR
# ==============================================================================

@functools.partial(jit, static_argnames=['pde', 'algebra', 'manifold'])
def maccormack_step(u, dt, pde: CompositePDE, algebra: MatrixAlgebra, manifold: CartesianBox):
    # 1. Geometry
    grads_fwd, grads_bwd, lap = manifold.gradients(u)
    
    # 2. Predictor (Using PDE object)
    k1 = pde(u, grads_fwd, lap, algebra)
    u_pred = u + k1 * dt
    
    # 3. Corrector
    _, grads_bwd_pred, lap_pred = manifold.gradients(u_pred)
    k2 = pde(u_pred, grads_bwd_pred, lap_pred, algebra)
    
    # 4. Update
    u_next = 0.5 * (u + u_pred + k2 * dt)
    u_next = manifold.enforce_boundaries(u_next)
    return u_next

# ==============================================================================
# 4. SIMULATION RUNNER
# ==============================================================================

GRID_SIZE = 100
DX = 0.1
DT = 0.0001
DURATION_STEPS = 1200 
IDX_P = 0

@functools.partial(jit, static_argnames=['algebra', 'manifold', 'pde'])
def run_simulation(params_dict, pde: CompositePDE, algebra: MatrixAlgebra, manifold: CartesianBox):
    X, Y = manifold.coordinates
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    
    x0, y0 = params_dict['pos'][0], params_dict['pos'][1]
    vx, vy = params_dict['vel'][0], params_dict['vel'][1]
    
    # Retrieve 'epsilon' from the PDE to size the source blob?
    # For now, we hardcode blob width or inspect the diffusion term if needed.
    # Let's assume generic blob.
    blob_width = 1.5 
    
    def body_fn(carry, step_idx):
        u_curr = carry
        
        # Call Integrator with PDE
        u_next = maccormack_step(u_curr, DT, pde, algebra, manifold)
        
        current_time = step_idx * DT
        pos_x = x0 + vx * current_time
        pos_y = y0 + vy * current_time
        dist_sq = (X - pos_x)**2 + (Y - pos_y)**2
        source = jnp.exp(-dist_sq / (2 * blob_width**2)) * jnp.exp(-(step_idx - 50)**2 / (2 * 20.0**2)) * 100.0 * DT
        
        u_next = u_next.at[..., IDX_P].add(source)
        return u_next, u_next[..., IDX_P]

    final_u, history_p = lax.scan(body_fn, u, jnp.arange(DURATION_STEPS))
    return history_p

# ==============================================================================
# 5. EXECUTION
# ==============================================================================

def solve():
    # 1. Define World
    algebra = MatrixAlgebra(2, 0)
    manifold = CartesianBox(GRID_SIZE, DX)
    
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0])}
    SENSORS_RC = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])

    # 2. Define Physics as LEGO BLOCKS (Target Physics)
    # Target: Pure Acoustics (Epsilon=0)
    target_physics = CompositePDE([
        GeometricAdvection(c=343.0),
        GeometricDiffusion(epsilon=0.0) 
    ])

    print(f"--- Generating Ground Truth (Universal PDE Engine) ---")
    true_history = run_simulation(TRUE_PARAMS, target_physics, algebra, manifold)
    observed_data = true_history[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]

    # 3. Define Loss Function
    # Note: We must construct a NEW PDE object for each epsilon stage!
    def loss_fn(est_params, current_epsilon):
        # Dynamic Physics Construction!
        current_physics = CompositePDE([
            GeometricAdvection(c=343.0),
            GeometricDiffusion(epsilon=current_epsilon)
        ])
        
        sim_hist = run_simulation(est_params, current_physics, algebra, manifold)
        sim_data = sim_hist[:, SENSORS_RC[:,0], SENSORS_RC[:,1]]
        
        safe_eps = 1e-6
        sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
        obs_norm = (observed_data - jnp.mean(observed_data, 0)) / (jnp.std(observed_data, 0) + safe_eps)
        corr = jnp.mean(sim_norm * obs_norm)
        
        pos = est_params['pos']
        bounds = jnp.sum(jnp.maximum(0, -pos)) + jnp.sum(jnp.maximum(0, pos - GRID_SIZE*DX))
        return (1.0 - corr) + bounds

    # 4. Optimization Loop
    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0])}
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {'pos': optax.adam(0.1), 'vel': optax.adam(5.0)},
            {'pos': 'pos', 'vel': 'vel'}
        )
    )
    opt_state = optimizer.init(guess)
    
    print(f"\n--- Starting Search ---")
    # Annealing Schedule: Reduce epsilon (Diffusion term) over time
    for stage, eps in enumerate([2.0, 0.5, 0.0]):
        print(f"\n>>> ENTERING STAGE {stage} (Physics: Advection + Diffusion={eps}) <<<")
        for i in range(100): 
            # Note: We pass 'eps' to loss_fn, which constructs the PDE internally
            loss, grads = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(grads, opt_state, guess)
            guess = optax.apply_updates(guess, updates)
            if i % 10 == 0:
                p, v = guess['pos'], guess['vel']
                print(f"   Iter {i}: Loss {loss:.5f} | Pos=[{p[0]:.2f}, {p[1]:.2f}] Vel=[{v[0]:.2f}, {v[1]:.2f}]")

    print(f"\n--- FINAL RESULT ---")
    print(f"Estimated: Pos={guess['pos']} Vel={guess['vel']}")
    print(f"True     : Pos={TRUE_PARAMS['pos']} Vel={TRUE_PARAMS['vel']}")

if __name__ == "__main__":
    solve()

--- Generating Ground Truth (Universal PDE Engine) ---

--- Starting Search ---

>>> ENTERING STAGE 0 (Physics: Advection + Diffusion=2.0) <<<
   Iter 0: Loss 0.47631 | Pos=[4.90, 5.10] Vel=[-5.00, 5.00]
   Iter 10: Loss 0.08123 | Pos=[3.90, 6.10] Vel=[-55.16, 55.17]
   Iter 20: Loss 0.02477 | Pos=[3.24, 6.81] Vel=[-86.87, 89.53]
   Iter 30: Loss 0.01545 | Pos=[3.41, 6.73] Vel=[-77.00, 84.42]
   Iter 40: Loss 0.00603 | Pos=[3.65, 6.46] Vel=[-62.64, 69.30]
   Iter 50: Loss 0.00391 | Pos=[3.53, 6.50] Vel=[-65.38, 68.82]
   Iter 60: Loss 0.00352 | Pos=[3.45, 6.61] Vel=[-66.34, 71.44]
   Iter 70: Loss 0.00268 | Pos=[3.50, 6.59] Vel=[-60.04, 67.43]
   Iter 80: Loss 0.00260 | Pos=[3.48, 6.58] Vel=[-57.09, 63.80]
   Iter 90: Loss 0.00241 | Pos=[3.44, 6.61] Vel=[-55.17, 62.11]

>>> ENTERING STAGE 1 (Physics: Advection + Diffusion=0.5) <<<
   Iter 0: Loss 0.00190 | Pos=[3.44, 6.62] Vel=[-51.05, 59.28]
   Iter 10: Loss 0.00179 | Pos=[3.43, 6.63] Vel=[-47.65, 56.27]
   Iter 20: Loss 0.00167 | Pos

In [8]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
import optax
import functools
from kingdon import Algebra

# ... [Include MatrixAlgebra, CartesianBox, get_gradients from previous step] ...

# ==============================================================================
# 1. THE UNIVERSAL OPERATORS (With Learnable Parameters)
# ==============================================================================

# We modify the PDE call to accept a parameter dictionary directly
# instead of having fixed classes.

@functools.partial(jit, static_argnames=['algebra'])
def universal_pde_kernel(u, grads, lap, params_dict, algebra):
    """
    The Hamiltonian: H(u) = -c(Advection) + epsilon(Diffusion) - V(Potential)
    All coefficients (c, epsilon, V) are now passed dynamically.
    """
    rate = jnp.zeros_like(u)
    
    # 1. Geometric Advection (Wave Speed c)
    # Term: -c * sum(e_i * d_i u)
    c = params_dict['c']
    for i, grad_field in enumerate(grads):
        if i >= len(algebra.basis_matrices): break
        M_i = algebra.basis_matrices[i]
        term = jnp.einsum('kj,...j->...k', M_i, grad_field)
        rate = rate - c * term

    # 2. Geometric Diffusion (Viscosity epsilon)
    # Term: +epsilon * Laplacian(u)
    # Note: We usually fix epsilon for annealing, but we COULD learn it.
    # Here we treat it as a hyperparameter passed in params_dict.
    eps = params_dict['epsilon']
    rate = rate + eps * lap
    
    # 3. Linear Potential (Damping/Mass V)
    # Term: -V * u
    V = params_dict['V']
    rate = rate - V * u
    
    return rate

# ==============================================================================
# 2. THE SIMULATION (Accepts Physics Params)
# ==============================================================================

@functools.partial(jit, static_argnames=['algebra', 'manifold'])
def run_simulation(learnable_params, hyper_params, algebra, manifold):
    """
    learnable_params: {pos, vel, c, V} -> The things we want to find.
    hyper_params: {epsilon} -> The solver controls.
    """
    X, Y = manifold.coordinates
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    
    # Unpack Logic
    x0, y0 = learnable_params['pos']
    vx, vy = learnable_params['vel']
    
    # Combine all params for the kernel
    # We detach epsilon (it's a control variable, not a physical truth we are searching for yet)
    pde_params = {
        'c': learnable_params['c'], 
        'V': learnable_params['V'],
        'epsilon': hyper_params['epsilon']
    }
    
    blob_width = 1.5
    
    def body_fn(carry, step_idx):
        u_curr = carry
        
        # 1. Geometry
        grads_fwd, grads_bwd, lap = manifold.gradients(u_curr)
        
        # 2. Predictor
        k1 = universal_pde_kernel(u_curr, grads_fwd, lap, pde_params, algebra)
        u_pred = u_curr + k1 * DT
        
        # 3. Corrector
        _, grads_bwd_pred, lap_pred = manifold.gradients(u_pred)
        k2 = universal_pde_kernel(u_pred, grads_bwd_pred, lap_pred, pde_params, algebra)
        
        u_next = 0.5 * (u_curr + u_pred + k2 * DT)
        u_next = manifold.enforce_boundaries(u_next)
        
        # 4. Source
        t = step_idx * DT
        pos_x = x0 + vx * t
        pos_y = y0 + vy * t
        dist_sq = (X - pos_x)**2 + (Y - pos_y)**2
        source = jnp.exp(-dist_sq / (2 * blob_width**2)) * jnp.exp(-(step_idx - 50)**2 / (2 * 20.0**2)) * 100.0 * DT
        
        u_next = u_next.at[..., 0].add(source)
        return u_next, u_next[..., 0]

    final_u, history_p = lax.scan(body_fn, u, jnp.arange(DURATION_STEPS))
    return history_p

# ==============================================================================
# 3. EXECUTION: LEARNING PHYSICS FROM SCRATCH
# ==============================================================================

def solve_universal():
    algebra = MatrixAlgebra(2, 0)
    manifold = CartesianBox(GRID_SIZE, DX)
    
    # --- GROUND TRUTH ---
    # We simulate a world with specific physics:
    # Speed of Sound = 343 m/s
    # Damping (V) = 0.5 (Air resistance?)
    TRUE_PARAMS = {
        'pos': jnp.array([3.0, 7.0]), 
        'vel': jnp.array([40.0, -20.0]),
        'c': 343.0,
        'V': 0.5
    }
    
    print("--- Generating Ground Truth ---")
    # Generate data with epsilon=0 (Real Physics)
    true_hist = run_simulation(TRUE_PARAMS, {'epsilon': 0.0}, algebra, manifold)
    SENSORS = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])
    obs_data = true_hist[:, SENSORS[:,0], SENSORS[:,1]]

    # --- THE BLIND SOLVER ---
    # We guess EVERYTHING.
    # "I don't know where it is, how fast it's moving, or what the speed of sound is."
    guess = {
        'pos': jnp.array([5.0, 5.0]), 
        'vel': jnp.array([0.0, 0.0]),
        'c': 300.0,  # Wrong guess (Standard air is 343)
        'V': 0.0     # Wrong guess (Assume vacuum)
    }

    # Complex Optimizer Strategy
    # We need different learning rates for different physical units!
    # c (~300) needs large updates? No, c is sensitive. Small updates relative to mag.
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {
                'pos': optax.adam(0.1), 
                'vel': optax.adam(5.0),
                'c':   optax.adam(1.0), # Learn speed of sound
                'V':   optax.adam(0.1)  # Learn damping
            },
            {'pos':'pos', 'vel':'vel', 'c':'c', 'V':'V'}
        )
    )
    opt_state = optimizer.init(guess)
    
    def loss_fn(params, eps):
        sim_hist = run_simulation(params, {'epsilon': eps}, algebra, manifold)
        sim_data = sim_hist[:, SENSORS[:,0], SENSORS[:,1]]
        
        # Normalize (Correlation Loss is safer for 'c' estimation than MSE)
        safe_eps = 1e-6
        sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
        obs_norm = (obs_data - jnp.mean(obs_data, 0)) / (jnp.std(obs_data, 0) + safe_eps)
        return 1.0 - jnp.mean(sim_norm * obs_norm)

    print("\n--- Discovering Physics ---")
    # Anneal Epsilon to help find 'c' and 'pos' simultaneously
    for eps in [2.0, 1.0, 0.0]:
        print(f">>> Epsilon {eps}")
        for i in range(50):
            l, g = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(g, opt_state, guess)
            guess = optax.apply_updates(guess, updates)
            
            if i % 10 == 0:
                print(f"Iter {i}: L={l:.4f} | c={guess['c']:.1f} | V={guess['V']:.2f} | Pos={guess['pos']}")

    print("\n--- FINAL DISCOVERY ---")
    print(f"True: c={TRUE_PARAMS['c']}, V={TRUE_PARAMS['V']}")
    print(f"Est : c={guess['c']:.1f}, V={guess['V']:.2f}")

if __name__ == "__main__":
    solve_universal()

--- Generating Ground Truth ---

--- Discovering Physics ---
>>> Epsilon 2.0
Iter 0: L=0.4140 | c=299.0 | V=0.10 | Pos=[4.9000006 5.0999994]
Iter 10: L=0.1076 | c=308.0 | V=1.12 | Pos=[3.9292595 6.0769677]
Iter 20: L=0.0642 | c=318.8 | V=2.20 | Pos=[3.5451179 6.5610857]
Iter 30: L=0.0189 | c=329.3 | V=3.13 | Pos=[3.7362127 6.400748 ]
Iter 40: L=0.0087 | c=336.2 | V=3.61 | Pos=[3.596058 6.427496]
>>> Epsilon 1.0
Iter 0: L=0.0040 | c=340.4 | V=3.73 | Pos=[3.4598937 6.5909076]
Iter 10: L=0.0026 | c=343.2 | V=3.66 | Pos=[3.5011408 6.5931168]
Iter 20: L=0.0024 | c=344.1 | V=3.45 | Pos=[3.4383194 6.6044083]
Iter 30: L=0.0022 | c=344.1 | V=3.16 | Pos=[3.4145033 6.6372914]
Iter 40: L=0.0020 | c=343.8 | V=2.86 | Pos=[3.4138277 6.6435943]
>>> Epsilon 0.0
Iter 0: L=0.0016 | c=343.5 | V=2.55 | Pos=[3.382498  6.6572013]
Iter 10: L=0.0014 | c=343.2 | V=2.30 | Pos=[3.3764687 6.671409 ]
Iter 20: L=0.0013 | c=343.1 | V=2.09 | Pos=[3.352902  6.6860533]
Iter 30: L=0.0012 | c=343.0 | V=1.92 | Pos=[3.33520

In [9]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
import optax
import functools
from kingdon import Algebra
from typing import NamedTuple, List, Dict

# ... [Keep MatrixAlgebra, CartesianBox, Operators, maccormack_step from previous] ...
# (Assuming previous definitions of MatrixAlgebra, CartesianBox, Operators are present)
# For brevity, I will re-include the minimal necessary parts or assume they are imported.
# Below is the FULL updated "Control Logic" section.

# ... [Paste previous MatrixAlgebra / CartesianBox / Universal PDE Kernel classes here] ...
# ... [For a standalone run, copy classes from the previous 'universal' script] ...
# ... [I will include them for completeness so this block runs standalone] ...

# --- RE-DEFINING ESSENTIALS FOR STANDALONE EXECUTION ---
class MatrixAlgebra:
    def __init__(self, p, q, r=0):
        self.p, self.q, self.r = p, q, r
        self.alg = Algebra(p, q, r)
        self.dim = len(self.alg)
        self.basis_names = [f'e{i+1}' for i in range(p + q + r)]
        matrices = []
        for name in self.basis_names:
            e_k = self.alg.blades[name]
            cols = []
            for i in range(self.dim):
                b_i = self.alg.multivector({i: 1})
                res = e_k * b_i
                dense_col = jnp.zeros(self.dim)
                for bin_key, val in res.items(): dense_col = dense_col.at[bin_key].set(val)
                cols.append(dense_col)
            matrices.append(jnp.stack(cols, axis=1))
        self.basis_matrices = jnp.stack(matrices)
    def __hash__(self): return hash((self.p, self.q, self.r))
    def __eq__(self, o): return (self.p, self.q, self.r) == (o.p, o.q, o.r)

class CartesianBox:
    def __init__(self, size, dx):
        self.size, self.dx = size, dx
        xs = jnp.linspace(0, size*dx, size)
        self.X, self.Y = jnp.meshgrid(xs, xs, indexing='xy')
    def __hash__(self): return hash((self.size, self.dx))
    def __eq__(self, o): return (self.size, self.dx) == (o.size, o.dx)
    @property
    def coordinates(self): return (self.X, self.Y)
    def gradients(self, u):
        u_pad = jnp.pad(u, ((1, 1), (1, 1), (0, 0)), mode='constant')
        u_c = u_pad[1:-1, 1:-1]; dx = self.dx
        dx_f = (u_pad[1:-1, 2:] - u_c)/dx; dy_f = (u_pad[2:, 1:-1] - u_c)/dx
        dx_b = (u_c - u_pad[1:-1, 0:-2])/dx; dy_b = (u_c - u_pad[0:-2, 1:-1])/dx
        lap = (u_pad[2:, 1:-1] + u_pad[0:-2, 1:-1] + u_pad[1:-1, 2:] + u_pad[1:-1, 0:-2] - 4*u_c)/(dx**2)
        return ([dx_f, dy_f], [dx_b, dy_b], lap)
    def enforce_boundaries(self, u):
        u = u.at[0].set(0); u = u.at[-1].set(0); u = u.at[:, 0].set(0); u = u.at[:, -1].set(0)
        return u

GRID_SIZE, DX, DT, DURATION_STEPS, IDX_P = 100, 0.1, 0.0001, 1200, 0

@functools.partial(jit, static_argnames=['algebra'])
def universal_pde_kernel(u, grads, lap, params_dict, algebra):
    rate = jnp.zeros_like(u)
    # Advection
    for i, grad in enumerate(grads):
        if i >= len(algebra.basis_matrices): break
        rate = rate - params_dict['c'] * jnp.einsum('kj,...j->...k', algebra.basis_matrices[i], grad)
    # Diffusion + Potential
    rate = rate + params_dict['epsilon'] * lap - params_dict['V'] * u
    return rate

@functools.partial(jit, static_argnames=['algebra', 'manifold'])
def maccormack_step(u, dt, params, algebra, manifold):
    g_f, g_b, lap = manifold.gradients(u)
    k1 = universal_pde_kernel(u, g_f, lap, params, algebra)
    u_p = u + k1 * dt
    _, g_b_p, lap_p = manifold.gradients(u_p)
    k2 = universal_pde_kernel(u_p, g_b_p, lap_p, params, algebra)
    return manifold.enforce_boundaries(0.5 * (u + u_p + k2 * dt))

@functools.partial(jit, static_argnames=['algebra', 'manifold'])
def run_simulation(learn_params, hyper_params, algebra, manifold):
    X, Y = manifold.coordinates
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    x0, y0 = learn_params['pos']
    vx, vy = learn_params['vel']
    pde_params = {'c': learn_params['c'], 'V': learn_params['V'], 'epsilon': hyper_params['epsilon']}
    
    def body_fn(carry, step_idx):
        u_next = maccormack_step(carry, DT, pde_params, algebra, manifold)
        t = step_idx * DT
        src = jnp.exp(-((X-(x0+vx*t))**2 + (Y-(y0+vy*t))**2)/4.5) * jnp.exp(-(step_idx-50)**2/800.0) * 100.0 * DT
        u_next = u_next.at[..., 0].add(src)
        return u_next, u_next[..., 0]
    return lax.scan(body_fn, u, jnp.arange(DURATION_STEPS))[1]

# ==============================================================================
# 3. CONTROLLED OPTIMIZATION LOGIC
# ==============================================================================

class SolverConfig(NamedTuple):
    target_precision: float # Stop if loss < this
    max_iter_per_stage: int # Budget per epsilon stage
    min_grad_norm: float    # Stop if gradients vanish (convergence)
    epsilon_schedule: List[float]

def solve_with_controls(config: SolverConfig):
    algebra = MatrixAlgebra(2, 0)
    manifold = CartesianBox(GRID_SIZE, DX)
    
    # --- GROUND TRUTH ---
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0]), 'c': 343.0, 'V': 0.5}
    print("--- Generating Ground Truth ---")
    true_hist = run_simulation(TRUE_PARAMS, {'epsilon': 0.0}, algebra, manifold)
    SENSORS = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])
    obs_data = true_hist[:, SENSORS[:,0], SENSORS[:,1]]

    # --- SETUP OPTIMIZER ---
    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0]), 'c': 300.0, 'V': 0.0}
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {'pos': optax.adam(0.1), 'vel': optax.adam(5.0), 'c': optax.adam(1.0), 'V': optax.adam(0.1)},
            {'pos':'pos', 'vel':'vel', 'c':'c', 'V':'V'}
        )
    )
    opt_state = optimizer.init(guess)
    
    # Loss Function (Correlation)
    def loss_fn(params, eps):
        sim_hist = run_simulation(params, {'epsilon': eps}, algebra, manifold)
        sim_data = sim_hist[:, SENSORS[:,0], SENSORS[:,1]]
        safe_eps = 1e-6
        sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
        obs_norm = (obs_data - jnp.mean(obs_data, 0)) / (jnp.std(obs_data, 0) + safe_eps)
        # Bounds penalty
        pos = params['pos']
        bounds = jnp.sum(jnp.maximum(0, -pos)) + jnp.sum(jnp.maximum(0, pos - GRID_SIZE*DX))
        return (1.0 - jnp.mean(sim_norm * obs_norm)) + bounds

    print(f"\n--- Starting Controlled Search ---")
    print(f"Goal: Loss < {config.target_precision} OR Iter > {config.max_iter_per_stage}")
    
    for stage, eps in enumerate(config.epsilon_schedule):
        print(f"\n>>> STAGE {stage} (Epsilon={eps}) <<<")
        
        for i in range(config.max_iter_per_stage):
            loss, grads = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(grads, opt_state, guess)
            guess = optax.apply_updates(guess, updates)
            
            # --- CONTROL LOGIC ---
            
            # 1. Check Precision (Success)
            # We usually only trust low-loss on the final stage (Real Physics)
            if loss < config.target_precision and eps == 0.0:
                print(f"   [STOP] Target precision reached at Iter {i}: Loss {loss:.6f}")
                break
                
            # 2. Check Gradient Norm (Stagnation)
            # Flatten gradients to compute global norm
            grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads)))
            if grad_norm < config.min_grad_norm:
                print(f"   [SKIP] Converged (Grad Norm {grad_norm:.6f} < {config.min_grad_norm})")
                break
            
            # Log
            if i % 10 == 0:
                p, c_est = guess['pos'], guess['c']
                print(f"   Iter {i}: L={loss:.4f} | Grad={grad_norm:.4f} | c={c_est:.1f} | Pos={p}")
        else:
            # Python's 'for-else': Runs if loop completed WITHOUT break
            print(f"   [NEXT] Max iterations ({config.max_iter_per_stage}) reached.")

    print("\n--- FINAL DISCOVERY ---")
    print(f"True: c={TRUE_PARAMS['c']}, V={TRUE_PARAMS['V']}")
    print(f"Est : c={guess['c']:.1f}, V={guess['V']:.2f}")
    p_final = guess['pos']
    v_final = guess['vel']
    
    print(f"Estimated: Pos={p_final} Vel={v_final}")
    print(f"True     : Pos={TRUE_PARAMS['pos']} Vel={TRUE_PARAMS['vel']}")
    

if __name__ == "__main__":
    # Define our constraints
    config = SolverConfig(
        target_precision=0.0005,  # Stop if correlation error < 0.5%
        max_iter_per_stage=100,   # Don't waste time on high-epsilon stages
        min_grad_norm=0.001,     # Stop if optimization flatlines
        epsilon_schedule=[2.0, 0.5, 0.0]
    )
    solve_with_controls(config)

--- Generating Ground Truth ---

--- Starting Controlled Search ---
Goal: Loss < 0.0005 OR Iter > 100

>>> STAGE 0 (Epsilon=2.0) <<<
   Iter 0: L=0.4140 | Grad=0.1996 | c=299.0 | Pos=[4.9000006 5.0999994]
   Iter 10: L=0.1076 | Grad=0.0792 | c=308.0 | Pos=[3.9292595 6.0769677]
   Iter 20: L=0.0642 | Grad=0.0875 | c=318.8 | Pos=[3.5451179 6.5610857]
   Iter 30: L=0.0189 | Grad=0.0244 | c=329.3 | Pos=[3.7362123 6.400748 ]
   Iter 40: L=0.0087 | Grad=0.0310 | c=336.2 | Pos=[3.5960577 6.4274964]
   Iter 50: L=0.0044 | Grad=0.0163 | c=340.4 | Pos=[3.4597843 6.590959 ]
   Iter 60: L=0.0029 | Grad=0.0076 | c=343.2 | Pos=[3.498338 6.594799]
   Iter 70: L=0.0027 | Grad=0.0044 | c=344.2 | Pos=[3.4360435 6.606053 ]
   Iter 80: L=0.0025 | Grad=0.0027 | c=344.3 | Pos=[3.4100184 6.6401525]
   [SKIP] Converged (Grad Norm 0.000849 < 0.001)

>>> STAGE 1 (Epsilon=0.5) <<<
   [SKIP] Converged (Grad Norm 0.000988 < 0.001)

>>> STAGE 2 (Epsilon=0.0) <<<
   [SKIP] Converged (Grad Norm 0.000967 < 0.001)

---

In [None]:
# ... (Previous imports and class definitions remain the same) ...

class SolverConfig(NamedTuple):
    target_precision: float 
    max_iter_per_stage: int
    min_grad_norm: float
    warmup_steps: int      # NEW: Force minimum steps per stage
    epsilon_schedule: List[float]

def solve_with_controls(config: SolverConfig):
    algebra = MatrixAlgebra(2, 0)
    manifold = CartesianBox(GRID_SIZE, DX)
    
    # --- GROUND TRUTH ---
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0]), 'c': 343.0, 'V': 0.5}
    true_hist = run_simulation(TRUE_PARAMS, {'epsilon': 0.0}, algebra, manifold)
    SENSORS = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])
    obs_data = true_hist[:, SENSORS[:,0], SENSORS[:,1]]

    # --- SETUP OPTIMIZER ---
    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0]), 'c': 300.0, 'V': 0.0}
    
    # We maintain the 50x learning rate for Velocity, 
    # and maybe boost V slightly if it lags?
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {
                'pos': optax.adam(0.1), 
                'vel': optax.adam(5.0), 
                'c':   optax.adam(1.0), 
                'V':   optax.adam(0.5)  # Boosted V learning rate slightly
            },
            {'pos':'pos', 'vel':'vel', 'c':'c', 'V':'V'}
        )
    )
    opt_state = optimizer.init(guess)
    
    def loss_fn(params, eps):
        sim_hist = run_simulation(params, {'epsilon': eps}, algebra, manifold)
        sim_data = sim_hist[:, SENSORS[:,0], SENSORS[:,1]]
        safe_eps = 1e-6
        sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
        obs_norm = (obs_data - jnp.mean(obs_data, 0)) / (jnp.std(obs_data, 0) + safe_eps)
        
        pos = params['pos']
        bounds = jnp.sum(jnp.maximum(0, -pos)) + jnp.sum(jnp.maximum(0, pos - GRID_SIZE*DX))
        return (1.0 - jnp.mean(sim_norm * obs_norm)) + bounds

    print(f"\n--- Starting Controlled Search ---")
    
    for stage, eps in enumerate(config.epsilon_schedule):
        print(f"\n>>> STAGE {stage} (Epsilon={eps}) <<<")
        
        # Reset best loss for this stage
        stage_best_loss = 1e9
        
        for i in range(config.max_iter_per_stage):
            loss, grads = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(grads, opt_state, guess)
            guess = optax.apply_updates(guess, updates)
            
            grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads)))
            
            # --- UPDATED CONTROL LOGIC ---
            
            # 1. Warmup Check: Ignore convergence signals early in the stage
            if i < config.warmup_steps:
                if i % 10 == 0:
                    print(f"   Iter {i} (Warmup): L={loss:.4f} | Grad={grad_norm:.5f}")
                continue

            # 2. Convergence Check
            if grad_norm < config.min_grad_norm:
                print(f"   [SKIP] Converged at Iter {i} (Grad {grad_norm:.6f} < {config.min_grad_norm})")
                break
                
            # 3. Precision Check (Only in final stage)
            if loss < config.target_precision and eps == 0.0:
                print(f"   [STOP] Target precision reached: Loss {loss:.6f}")
                break
            
            # Log
            if i % 10 == 0:
                p, c_est, v_est = guess['pos'], guess['c'], guess['V']
                print(f"   Iter {i}: L={loss:.4f} | Grad={grad_norm:.5f} | c={c_est:.1f} | V={v_est:.2f}")

    print("\n--- FINAL DISCOVERY ---")
    print(f"True: c={TRUE_PARAMS['c']}, V={TRUE_PARAMS['V']}")
    print(f"Est : c={guess['c']:.1f}, V={guess['V']:.2f}")
    p_final = guess['pos']
    v_final = guess['vel']
    print(f"Estimated: Pos={p_final} Vel={v_final}")
    print(f"True     : Pos={TRUE_PARAMS['pos']} Vel={TRUE_PARAMS['vel']}")

if __name__ == "__main__":
    config = SolverConfig(
        target_precision=0.0005, # Very strict
        max_iter_per_stage=150,  # Give it time
        min_grad_norm=1e-5,      # Much lower threshold (was 1e-3)
        warmup_steps=20,         # Force 20 steps per stage minimum
        epsilon_schedule=[2.0, 0.5, 0.0]
    )
    solve_with_controls(config)


--- Starting Controlled Search ---

>>> STAGE 0 (Epsilon=2.0) <<<
   Iter 0 (Warmup): L=0.4140 | Grad=0.19956
   Iter 10 (Warmup): L=0.1049 | Grad=0.08094
   Iter 20: L=0.0611 | Grad=0.08734 | c=318.8 | V=10.92
   Iter 30: L=0.0186 | Grad=0.02335 | c=329.3 | V=14.92
   Iter 40: L=0.0105 | Grad=0.03067 | c=335.9 | V=15.37
   Iter 50: L=0.0064 | Grad=0.01621 | c=339.7 | V=13.31
   Iter 60: L=0.0042 | Grad=0.00714 | c=342.4 | V=10.27
   Iter 70: L=0.0032 | Grad=0.00462 | c=343.4 | V=6.84
   Iter 80: L=0.0025 | Grad=0.00268 | c=343.6 | V=3.69
   Iter 90: L=0.0021 | Grad=0.00346 | c=343.7 | V=1.20
   Iter 100: L=0.0018 | Grad=0.00093 | c=343.6 | V=-0.58
   Iter 110: L=0.0016 | Grad=0.00138 | c=343.7 | V=-1.66
   Iter 120: L=0.0015 | Grad=0.00105 | c=343.7 | V=-2.21
   Iter 130: L=0.0014 | Grad=0.00092 | c=343.8 | V=-2.40
   Iter 140: L=0.0012 | Grad=0.00099 | c=343.8 | V=-2.42

>>> STAGE 1 (Epsilon=0.5) <<<
   Iter 0 (Warmup): L=0.0010 | Grad=0.00033
   Iter 10 (Warmup): L=0.0009 | Grad=0.

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, lax, value_and_grad
import optax
import functools
from kingdon import Algebra
from typing import NamedTuple, List, Dict

# ... [MatrixAlgebra, CartesianBox, Operators, etc. same as before] ...
# (Copying minimal boilerplate for runnable context)

class MatrixAlgebra:
    def __init__(self, p, q, r=0):
        self.p, self.q, self.r = p, q, r
        self.alg = Algebra(p, q, r)
        self.dim = len(self.alg)
        self.basis_names = [f'e{i+1}' for i in range(p + q + r)]
        matrices = []
        for name in self.basis_names:
            e_k = self.alg.blades[name]
            cols = []
            for i in range(self.dim):
                b_i = self.alg.multivector({i: 1})
                res = e_k * b_i
                dense_col = jnp.zeros(self.dim)
                for bin_key, val in res.items(): dense_col = dense_col.at[bin_key].set(val)
                cols.append(dense_col)
            matrices.append(jnp.stack(cols, axis=1))
        self.basis_matrices = jnp.stack(matrices)
    def __hash__(self): return hash((self.p, self.q, self.r))
    def __eq__(self, o): return (self.p, self.q, self.r) == (o.p, o.q, o.r)

class CartesianBox:
    def __init__(self, size, dx):
        self.size, self.dx = size, dx
        xs = jnp.linspace(0, size*dx, size)
        self.X, self.Y = jnp.meshgrid(xs, xs, indexing='xy')
    def __hash__(self): return hash((self.size, self.dx))
    def __eq__(self, o): return (self.size, self.dx) == (o.size, o.dx)
    @property
    def coordinates(self): return (self.X, self.Y)
    def gradients(self, u):
        u_pad = jnp.pad(u, ((1, 1), (1, 1), (0, 0)), mode='constant')
        u_c = u_pad[1:-1, 1:-1]; dx = self.dx
        dx_f = (u_pad[1:-1, 2:] - u_c)/dx; dy_f = (u_pad[2:, 1:-1] - u_c)/dx
        dx_b = (u_c - u_pad[1:-1, 0:-2])/dx; dy_b = (u_c - u_pad[0:-2, 1:-1])/dx
        lap = (u_pad[2:, 1:-1] + u_pad[0:-2, 1:-1] + u_pad[1:-1, 2:] + u_pad[1:-1, 0:-2] - 4*u_c)/(dx**2)
        return ([dx_f, dy_f], [dx_b, dy_b], lap)
    def enforce_boundaries(self, u):
        u = u.at[0].set(0); u = u.at[-1].set(0); u = u.at[:, 0].set(0); u = u.at[:, -1].set(0)
        return u

GRID_SIZE, DX, DT, DURATION_STEPS, IDX_P = 100, 0.1, 0.0001, 1200, 0

@functools.partial(jit, static_argnames=['algebra'])
def universal_pde_kernel(u, grads, lap, params_dict, algebra):
    rate = jnp.zeros_like(u)
    for i, grad in enumerate(grads):
        if i >= len(algebra.basis_matrices): break
        rate = rate - params_dict['c'] * jnp.einsum('kj,...j->...k', algebra.basis_matrices[i], grad)
    rate = rate + params_dict['epsilon'] * lap - params_dict['V'] * u
    return rate

@functools.partial(jit, static_argnames=['algebra', 'manifold'])
def maccormack_step(u, dt, params, algebra, manifold):
    g_f, g_b, lap = manifold.gradients(u)
    k1 = universal_pde_kernel(u, g_f, lap, params, algebra)
    u_p = u + k1 * dt
    _, g_b_p, lap_p = manifold.gradients(u_p)
    k2 = universal_pde_kernel(u_p, g_b_p, lap_p, params, algebra)
    return manifold.enforce_boundaries(0.5 * (u + u_p + k2 * dt))

@functools.partial(jit, static_argnames=['algebra', 'manifold'])
def run_simulation(learn_params, hyper_params, algebra, manifold):
    X, Y = manifold.coordinates
    u = jnp.zeros((GRID_SIZE, GRID_SIZE, algebra.dim))
    x0, y0 = learn_params['pos']
    vx, vy = learn_params['vel']
    pde_params = {'c': learn_params['c'], 'V': learn_params['V'], 'epsilon': hyper_params['epsilon']}
    
    def body_fn(carry, step_idx):
        u_next = maccormack_step(carry, DT, pde_params, algebra, manifold)
        t = step_idx * DT
        # Source injection
        src = jnp.exp(-((X-(x0+vx*t))**2 + (Y-(y0+vy*t))**2)/4.5) * jnp.exp(-(step_idx-50)**2/800.0) * 100.0 * DT
        u_next = u_next.at[..., 0].add(src)
        return u_next, u_next[..., 0]
    return lax.scan(body_fn, u, jnp.arange(DURATION_STEPS))[1]

# ==============================================================================
# 3. PRECISION CONFIG & LOGIC
# ==============================================================================

class ConvergenceConfig(NamedTuple):
    # Instead of one "tolerance", we have one for each physical type
    tol_pos: float  # Stop if position changes < X meters
    tol_vel: float  # Stop if velocity changes < X m/s
    tol_c: float    # Stop if speed of sound changes < X m/s
    tol_V: float    # Stop if damping changes < X
    
    max_iter: int
    warmup: int
    epsilon_schedule: List[float]

def solve_precision(config: ConvergenceConfig):
    algebra = MatrixAlgebra(2, 0)
    manifold = CartesianBox(GRID_SIZE, DX)
    
    # Ground Truth
    TRUE_PARAMS = {'pos': jnp.array([3.0, 7.0]), 'vel': jnp.array([40.0, -20.0]), 'c': 343.0, 'V': 0.5}
    obs_data = run_simulation(TRUE_PARAMS, {'epsilon': 0.0}, algebra, manifold)[:, 90, 10] # Tiny hack for speed/syntax
    # Re-run properly for full sensors if needed, but keeping it simple for logic demo
    # Let's do it right:
    SENSORS = jnp.array([[90, 10], [90, 90], [10, 10], [10, 90]])
    true_hist = run_simulation(TRUE_PARAMS, {'epsilon': 0.0}, algebra, manifold)
    obs_data = true_hist[:, SENSORS[:,0], SENSORS[:,1]]

    # Optimizer
    guess = {'pos': jnp.array([5.0, 5.0]), 'vel': jnp.array([0.0, 0.0]), 'c': 300.0, 'V': 0.0}
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.multi_transform(
            {'pos': optax.adam(0.1), 'vel': optax.adam(5.0), 'c': optax.adam(1.0), 'V': optax.adam(0.5)},
            {'pos':'pos', 'vel':'vel', 'c':'c', 'V':'V'}
        )
    )
    opt_state = optimizer.init(guess)
    
    def loss_fn(params, eps):
        sim_hist = run_simulation(params, {'epsilon': eps}, algebra, manifold)
        sim_data = sim_hist[:, SENSORS[:,0], SENSORS[:,1]]
        safe_eps = 1e-6
        sim_norm = (sim_data - jnp.mean(sim_data, 0)) / (jnp.std(sim_data, 0) + safe_eps)
        obs_norm = (obs_data - jnp.mean(obs_data, 0)) / (jnp.std(obs_data, 0) + safe_eps)
        
        pos = params['pos']
        bounds = jnp.sum(jnp.maximum(0, -pos)) + jnp.sum(jnp.maximum(0, pos - GRID_SIZE*DX))
        return (1.0 - jnp.mean(sim_norm * obs_norm)) + bounds

    print(f"\n--- Starting Precision Search ---")
    
    for stage, eps in enumerate(config.epsilon_schedule):
        print(f"\n>>> STAGE {stage} (Epsilon={eps}) <<<")
        
        for i in range(config.max_iter):
            loss, grads = value_and_grad(loss_fn)(guess, eps)
            updates, opt_state = optimizer.update(grads, opt_state, guess)
            
            # --- THE MAGIC: Check Update Magnitudes ---
            # We check how much the optimizer *actually moved* the parameters this step
            # This handles learning rates automatically.
            
            # Apply updates to get new guess
            new_guess = optax.apply_updates(guess, updates)
            
            # Calculate shift per parameter
            diffs = {
                'pos': jnp.linalg.norm(new_guess['pos'] - guess['pos']),
                'vel': jnp.linalg.norm(new_guess['vel'] - guess['vel']),
                'c':   jnp.abs(new_guess['c'] - guess['c']),
                'V':   jnp.abs(new_guess['V'] - guess['V'])
            }
            
            guess = new_guess
            
            # --- CONVERGENCE CHECK ---
            is_stable = (
                diffs['pos'] < config.tol_pos and
                diffs['vel'] < config.tol_vel and
                diffs['c']   < config.tol_c and
                diffs['V']   < config.tol_V
            )
            
            # Don't stop during warmup
            if i >= config.warmup and is_stable:
                print(f"   [CONVERGED] All parameters stable at Iter {i}")
                print(f"      Deltas: Pos={diffs['pos']:.5f} Vel={diffs['vel']:.5f} c={diffs['c']:.5f} V={diffs['V']:.5f}")
                break
            
            if i % 10 == 0:
                print(f"   Iter {i}: L={loss:.5f}")
                print(f"      Est: Vel={guess['vel']} (Delta: {diffs['vel']:.4f})")
                print(f"      Est: V={guess['V']:.2f} (Delta: {diffs['V']:.4f})")

    print("\n--- FINAL ---")
    print(f"True: Vel={TRUE_PARAMS['vel']}, V={TRUE_PARAMS['V']}")
    print(f"Est : Vel={guess['vel']}, V={guess['V']:.2f}")

if __name__ == "__main__":
    # Define tolerances based on physics units
    config = ConvergenceConfig(
        tol_pos=0.001,  # 1 mm stability
        tol_vel=0.01,   # 1 cm/s stability (Forces it to grind velocity!)
        tol_c=0.01,     # 1 cm/s stability for sound speed
        tol_V=0.0001,   # Very strict stability for damping
        
        max_iter=200,   # Allow long grind
        warmup=30,      # Don't quit early
        epsilon_schedule=[2.0, 0.5, 0.0]
    )
    solve_precision(config)

In [None]:
jax.config.update("jax_enable_x64", True)