In [1]:
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, jit, lax, vmap, grad
import torch
import torch.nn as nn

In [2]:
class MLP2(nn.Module):
    """Fully connected tanh network: input layer, `num_layers` hidden layers, linear output.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the output.
        hidden_dim: width of every hidden layer.
        num_layers: number of hidden-to-hidden linear layers after the input layer.
        batch_norm: when True, one `nn.BatchNorm1d(hidden_dim)` instance is applied
            before every activation (the same module is reused by all layers);
            otherwise `nn.Identity` is used in its place.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64, num_layers=1, batch_norm=False):
        super().__init__()
        # Attribute names and construction order are kept fixed: they determine the
        # state_dict keys and the order in which the random initialisation is drawn.
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        hidden_stack = []
        for _ in range(num_layers):
            hidden_stack.append(nn.Linear(hidden_dim, hidden_dim))
        self.layers = nn.ModuleList(hidden_stack)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.act = nn.Tanh()
        self.bn = nn.BatchNorm1d(hidden_dim) if batch_norm else nn.Identity()

    def forward(self, x):
        """Apply (linear -> bn -> tanh) for the input and hidden layers, then the output layer."""
        hidden = self.act(self.bn(self.fc_in(x)))
        for linear in self.layers:
            hidden = self.act(self.bn(linear(hidden)))
        return self.fc_out(hidden)

def load_model_from_torch(d):
    """Load the trained PyTorch Z-network for dimension `d` and return its weights as JAX arrays.

    Reads the checkpoint `./trained_bsde_model_lq/d{d}_z_net_1.pth` into an
    `MLP2(d+1, d)` built with its default arguments (one hidden-to-hidden layer).

    Args:
        d: problem dimension; the network takes d+1 inputs and produces d outputs.

    Returns:
        dict with keys "w1", "b1", "w2", "b2", "w3", "b3" holding the weights and
        biases of `fc_in`, `layers.0` and `fc_out` respectively, as JAX arrays.
        Weights keep the `nn.Linear` layout, i.e. they are not transposed here.
    """
    z_net = MLP2(d+1, d)
    # map_location="cpu": without it a checkpoint saved from a GPU run cannot be
    # deserialized on a machine without CUDA. The model itself lives on the CPU anyway.
    checkpoint = torch.load(f'./trained_bsde_model_lq/d{d}_z_net_1.pth', map_location="cpu")
    z_net.load_state_dict(checkpoint)
    params_pt = {n: p.detach().numpy() for n, p in z_net.state_dict().items()}
    params_jax = {
        "w1": jnp.array(params_pt["fc_in.weight"]), "b1": jnp.array(params_pt["fc_in.bias"]),
        "w2": jnp.array(params_pt["layers.0.weight"]), "b2": jnp.array(params_pt["layers.0.bias"]),
        "w3": jnp.array(params_pt["fc_out.weight"]), "b3": jnp.array(params_pt["fc_out.bias"]),
    }   # Convert to JAX params
    return params_jax

In [3]:
def z_net_jax(params, x):
    """Evaluate the two-hidden-layer tanh network in JAX.

    Args:
        params: dict with weights "w1", "w2", "w3" and biases "b1", "b2", "b3".
            Each weight is multiplied from the right after transposition, so it is
            stored as (out_features, in_features).
        x: network input whose last axis matches the second axis of "w1".

    Returns:
        The network output (no activation on the last layer).
    """
    hidden = x
    for idx in ("1", "2"):
        pre_activation = jnp.dot(hidden, params["w" + idx].T) + params["b" + idx]
        hidden = jnp.tanh(pre_activation)
    return jnp.dot(hidden, params["w3"].T) + params["b3"]

def div_z(params, tx):
    """Divergence of the network output with respect to the entries tx[1:].

    The full Jacobian is taken with forward-mode autodiff; its first column
    (derivative w.r.t. tx[0], the time entry) is dropped before taking the trace.
    """
    def net(point):
        return z_net_jax(params, point)

    full_jacobian = jacfwd(net)(tx)
    spatial_jacobian = full_jacobian[:, 1:]
    return jnp.trace(spatial_jacobian)

@jit
def get_dH_dx_and_div(params, t, x, dw_t):
    """Return (dH/dx, Z, div_x Z) for the network Z evaluated at the point (t, x).

    Args:
        params: network parameters as consumed by `z_net_jax`.
        t: scalar array holding the time (it is indexed with `t[None]`).
        x: 1-D state vector; it is concatenated after `t` to form the network input.
        dw_t: noise vector contracted against the rows of the spatial Jacobian.

    Returns:
        dh_dx: -(dw_t @ dZ/dx) + (sqrt(2)/2) * grad_x(div_x Z).
        z_val: network output at (t, x).
        div_val: trace of dZ/dx.
    """
    point = jnp.concatenate([t[None], x])

    def net(y):
        return z_net_jax(params, y)

    def divergence(y):
        return div_z(params, y)

    z_val = net(point)

    # Spatial Jacobian: drop the first column, which is the derivative w.r.t. time.
    dz_dx = jacfwd(net)(point)[:, 1:]
    div_val = jnp.trace(dz_dx)

    # Gradient of the divergence is a second derivative of the network; reverse mode
    # is used on the scalar-valued divergence and the time entry is dropped.
    grad_div = jacrev(divergence)(point)[1:]

    noise_contraction = dw_t @ dz_dx
    dh_dx = -noise_contraction + (jnp.sqrt(2.0)/2.0) * grad_div
    return dh_dx, z_val, div_val

def b_tilde(x, p, dw_t):
    """Return -2*p + sqrt(2)*dw_t; the argument `x` is accepted but not used."""
    costate_part = -2*p
    noise_scale = jnp.sqrt(2)
    return costate_part + noise_scale * dw_t

def bm_kl_basis(t_grid, n_terms):
    """Evaluate the cosine basis sqrt(2) * cos(pi * (j + 1/2) * t) on a time grid.

    Args:
        t_grid: 1-D array of times.
        n_terms: number of modes j = 0, ..., n_terms - 1.

    Returns:
        Array of shape (len(t_grid), n_terms); row i holds every mode at t_grid[i].
    """
    half_integer_freqs = 0.5 + jnp.arange(0, n_terms)
    phases = jnp.pi * jnp.outer(t_grid, half_integer_freqs)
    return jnp.sqrt(2) * jnp.cos(phases)

In [4]:
def get_w_dot_at_t(t, xi):
    """Evaluate the cosine-series noise derivative at an arbitrary time `t`.

    Computes sum_j sqrt(2) * cos(pi * t * (j + 1/2)) * xi[j], with j running over
    the first axis of `xi`. Because `t` need not lie on a grid, RK4 can query the
    noise at intermediate times such as t + dt/2.

    Args:
        t: time at which to evaluate the series.
        xi: coefficients; the first axis indexes the modes.

    Returns:
        The series contracted over the mode axis of `xi`.
    """
    mode_count = xi.shape[0]
    freqs = 0.5 + jnp.arange(mode_count)
    cos_modes = jnp.cos(jnp.pi * t * freqs)
    return jnp.dot(jnp.sqrt(2) * cos_modes, xi)

def get_system_derivs(t, state, xi, params):
    """
    Returns [dx/dt, dp/dt, dCost/dt] given current state (x, p, cost).

    Args:
        t: scalar array holding the current time.
        state: tuple (x, p, cost); the accumulated cost is not needed for the derivatives.
        xi: noise series coefficients passed to `get_w_dot_at_t`.
        params: network parameters as consumed by `z_net_jax`.

    Returns:
        dx_dt: -2*p + sqrt(2)*dw_t.
        dp_dt: -dH/dx.
        dcost_dt: H - x . dH/dx, where
            H = -|p|^2 + sqrt(2) p.dw_t - Z.dw_t + (sqrt(2)/2) div_x Z.
    """
    x, p, _ = state  # We don't need the current 'accumulated cost' to calculate derivatives

    dw_t = get_w_dot_at_t(t, xi)

    # Reuse the shared helpers rather than repeating their bodies here, so the
    # Jacobian / divergence formulas and the drift live in exactly one place.
    dh_dx, z_val, div_val = get_dH_dx_and_div(params, t, x, dw_t)

    dx_dt = b_tilde(x, p, dw_t)     # dx/dt = b_tilde
    dp_dt = -dh_dx

    h_part1 = -jnp.sum(p**2) + jnp.sqrt(2) * jnp.dot(p, dw_t)
    h_part2 = -jnp.dot(z_val, dw_t) + (jnp.sqrt(2)/2.0) * div_val
    H_val = h_part1 + h_part2

    dcost_dt = H_val - jnp.dot(x, dh_dx)

    return dx_dt, dp_dt, dcost_dt

# ---- RK4 ----
def rk4_step(state, t, dt, xi, params):
    """
    Advance the coupled system (x, p, cost) by one classical RK4 step of size dt.

    The four slopes are evaluated at t, t + dt/2 (twice) and t + dt, and combined
    with weights 1, 2, 2, 1 scaled by dt/6. Returns the new (x, p, cost) tuple.
    """
    half_dt = 0.5*dt

    def slopes(time, at_state):
        return get_system_derivs(time, at_state, xi, params)

    def shifted(scale, ks):
        # Move every component of the starting state along its slope by `scale`.
        return tuple(s + scale*k for s, k in zip(state, ks))

    k1 = slopes(t, state)
    k2 = slopes(t + half_dt, shifted(half_dt, k1))
    k3 = slopes(t + half_dt, shifted(half_dt, k2))
    k4 = slopes(t + dt, shifted(dt, k3))

    weight = dt/6.0
    return tuple(
        s + weight*(a + 2*b + 2*c + e)
        for s, a, b, c, e in zip(state, k1, k2, k3, k4)
    )

def loss_fn_single_rk4(p0, xi, params, t_grid, dt, param_g=None):
    """
    Computes objective using RK4 solver.

    Integrates (x, p, cost) from x = 0, p = p0, cost = 0 with one RK4 step per
    entry of `t_grid[:-1]`, then returns the negated value
        x0 . p0 - g_star(p_T) + integral,
    with g_star(p) = 0.5 * (sum_i param_g[i] * p[i]^2 - 1).

    Args:
        p0: initial p, 1-D of length d.
        xi: noise series coefficients for one sample.
        params: network parameters as consumed by `z_net_jax`.
        t_grid: time grid; every entry except the last starts one RK4 step.
        dt: step size used for every RK4 step.
        param_g: optional length-d weights of the terminal term. Defaults to
            [1, 25/4], which is only valid for d = 2.

    Returns:
        Scalar loss (the negative of the objective value).

    Raises:
        ValueError: if `param_g` does not have shape (d,).
    """
    d = p0.shape[0]

    if param_g is None:
        param_g = jnp.array([1.0, 25.0/4.0])
    else:
        param_g = jnp.asarray(param_g)
    # Shapes are static, so this check runs before the scan is traced: a dimension
    # mismatch is reported clearly and immediately rather than as a broadcasting
    # error after the whole integration has been traced.
    if param_g.shape != (d,):
        raise ValueError(
            f"param_g has shape {param_g.shape} but p0 has dimension {d}; "
            f"pass a length-{d} param_g"
        )

    x0 = jnp.zeros(d)

    init_state = (x0, p0, 0.0)

    times = t_grid[:-1] # 0 to T-dt

    def scan_fn(state, t):
        next_state = rk4_step(state, t, dt, xi, params)
        return next_state, None

    final_state, _ = lax.scan(scan_fn, init_state, times)

    x_T, p_T, total_integral = final_state

    # Terminal Cost
    g_star = 0.5 * (jnp.dot(p_T * param_g, p_T) - 1)

    hopf_val = jnp.dot(x0, p0) - g_star + total_integral
    return -hopf_val

In [5]:
# ---- Gradient Ascent ---- 
@jit
def optimize_single_sample_rk4(xi, params, t_grid, dt):
    """Run fixed-step gradient descent on `loss_fn_single_rk4` over p for one noise sample.

    Starts from p = 0 (length `xi.shape[1]`) and performs 100 updates with step
    size 0.1. The returned value is the negated loss recorded at the last
    iteration, i.e. the objective at the p held before the final update is applied.
    """
    step_size = 0.1
    n_steps = 100
    start_p = jnp.zeros(xi.shape[1])

    loss_and_grad = jax.value_and_grad(loss_fn_single_rk4, argnums=0)

    def descend(p, _):
        loss, grad_p = loss_and_grad(p, xi, params, t_grid, dt)
        return p - step_size * grad_p, loss

    _, loss_trace = lax.scan(descend, start_p, None, length=n_steps)
    return -loss_trace[-1]

In [6]:
def run_batched_optimization():
    """Run the batched experiment and print timing and summary statistics.

    Loads the network weights for d = 2, draws `n_samples` standard-normal
    coefficient arrays of shape (n_terms, d) from a fixed PRNG key, optimizes
    every sample independently (vmapped over the first axis of the batch), and
    prints the sample mean together with a 1.96-standard-error half-width.
    The measured time includes JIT compilation. Returns None.
    """
    # Settings
    Nt = 200
    T = 1.0
    dt = T / Nt
    n_terms = 32
    n_samples = 500
    d = 2             # param_g in loss_fn_single_rk4 has two entries and must be changed accordingly.

    params_jax_d = load_model_from_torch(d)

    # Only the grid itself is needed: the optimizer receives xi, params, t_grid and dt,
    # so no cosine basis is precomputed here.
    t_grid = jnp.linspace(0, T, Nt + 1)

    # Generate all random noise at once
    key = jax.random.PRNGKey(42)
    xi_batch = jax.random.normal(key, (n_samples, n_terms, d))

    print(f"Compiling and running for {n_samples} samples...")
    # perf_counter is a monotonic, high-resolution clock intended for measuring intervals.
    start_time = time.perf_counter()

    # Vectorize the optimizer over the batch of xi
    batch_optimizer = vmap(optimize_single_sample_rk4, in_axes=(0, None, None, None))

    final_values = batch_optimizer(xi_batch, params_jax_d, t_grid, dt)

    # JAX dispatches asynchronously: wait for the result before stopping the clock
    final_values.block_until_ready()
    end_time = time.perf_counter()

    print(f"Total Time: {end_time - start_time:.4f}s")

    # ---- Statistics ----
    lb = jnp.mean(final_values)
    # Standard error of the mean uses the sample standard deviation (ddof=1),
    # not the population one.
    se = jnp.std(final_values, ddof=1) / jnp.sqrt(n_samples)

    print("-" * 30)
    print(f"Dual lower bound: {lb:.4f} +/- {1.96 * se:.4f}")
    print(f"95% lower confidence limit: {lb - 1.96 * se:.4f}")
    print("-" * 30)

# Execute the experiment defined in this cell; it prints the elapsed time and the bound statistics.
run_batched_optimization()

Compiling and running for 500 samples...
Total Time: 317.9576s
------------------------------
Dual lower bound: 1.1889 +/- 0.0016
95% lower confidence limit: 1.1873
------------------------------
