In [1]:
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, vmap, jit
from jax import lax
import torch
import torch.nn as nn

In [2]:
# LOAD MODEL 
class MLP2(nn.Module):
    """Fully connected tanh network: input layer, `num_layers` hidden layers, linear output.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: size of the last axis of the output.
        hidden_dim: width of every hidden layer.
        num_layers: number of hidden-to-hidden linear layers after the input layer.
        batch_norm: when True, a single `BatchNorm1d(hidden_dim)` instance is applied
            after the input layer and after every hidden layer (the same module is
            reused each time); otherwise `nn.Identity` is used in its place.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64, num_layers=1, batch_norm=False):
        super().__init__()
        # Submodules are created in the order fc_in -> layers -> fc_out -> bn so that
        # state_dict keys and parameter initialisation order stay stable.
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        hidden_stack = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        self.layers = nn.ModuleList(hidden_stack)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.act = nn.Tanh()
        self.bn = nn.BatchNorm1d(hidden_dim) if batch_norm else nn.Identity()

    def forward(self, x):
        """Apply (linear -> bn -> tanh) for the input and hidden layers, then the linear head."""
        hidden = self.act(self.bn(self.fc_in(x)))
        for hidden_layer in self.layers:
            hidden = self.act(self.bn(hidden_layer(hidden)))
        return self.fc_out(hidden)

# State dimension; the network input is (t, x) and therefore has d + 1 entries.
d = 2
z_net = MLP2(d+1, d)
# NOTE(security): torch.load unpickles the file, so only load checkpoints that come from a
# trusted source (recent PyTorch releases offer `weights_only=True` to restrict unpickling).
# map_location="cpu" lets a checkpoint that was saved from a GPU be read on a CPU-only host;
# the tensors are converted to NumPy below, which requires them to live on the CPU anyway.
z_net.load_state_dict(
    torch.load(f'./trained_bsde_model_lq/d{d}_z_net_1.pth', map_location="cpu")
)
params_pt = {n: p.detach().numpy() for n, p in z_net.state_dict().items()}
# --------------------------------------------------------------------------

# Re-export the PyTorch weights as JAX arrays. The key layout below matches an MLP2 built
# with the default num_layers=1 (a single hidden layer, stored under "layers.0").
params_jax = {
    "w1": jnp.array(params_pt["fc_in.weight"]), "b1": jnp.array(params_pt["fc_in.bias"]),
    "w2": jnp.array(params_pt["layers.0.weight"]), "b2": jnp.array(params_pt["layers.0.bias"]),
    "w3": jnp.array(params_pt["fc_out.weight"]), "b3": jnp.array(params_pt["fc_out.bias"]),
}

In [3]:
def z_net_jax(params, x):
    """Evaluate the three-layer tanh network in JAX.

    Args:
        params: dict with weight/bias entries "w1", "b1", "w2", "b2", "w3", "b3";
            each weight is applied as `x @ w.T`, i.e. stored as (out_features, in_features).
        x: network input whose last axis matches the column count of `params["w1"]`.

    Returns:
        The output of the final linear layer (no activation applied to it).
    """
    # The parameters are an explicit argument (not a captured global) so that JAX
    # transformations can differentiate with respect to the input alone.
    hidden = x
    for w_key, b_key in (("w1", "b1"), ("w2", "b2")):
        hidden = jnp.tanh(jnp.dot(hidden, params[w_key].T) + params[b_key])
    return jnp.dot(hidden, params["w3"].T) + params["b3"]

def jac_z(params, tx):
    """Jacobian of the network output with respect to the spatial part of `tx`.

    `tx` is the network input; its first entry is skipped (callers build `tx` as
    the time followed by the state), so the returned matrix holds the derivatives
    with respect to entries 1.. of the input only.
    """
    def net_of_input(point):
        return z_net_jax(params, point)

    full_jacobian = jacfwd(net_of_input)(tx)
    # Drop column 0: the derivative with respect to the leading (time) entry.
    return full_jacobian[:, 1:]

def div_z(params, tx):
    """Trace of the spatial Jacobian returned by `jac_z`, i.e. the divergence of the network in x."""
    spatial_jacobian = jac_z(params, tx)
    return jnp.trace(spatial_jacobian)

def dH_dx(params, t, x, p, dw_t):
    """Right-hand side of the backward (costate) equation at a single time point.

    Computes `-(dw_t @ J) + (sqrt(2)/2) * grad_x(div z)`, where `J` is the spatial
    Jacobian of the network at `(t, x)` and `div z` is its trace.

    Args:
        params: network parameters forwarded to `jac_z` / `div_z`.
        t: scalar array holding the time (it is indexed with `[None]`).
        x: 1-D state vector.
        p: costate; accepted for interface symmetry but not used in the formula.
        dw_t: noise derivative at time `t`, contracted with the rows of `J`.
    """
    point = jnp.concatenate([t[None], x])
    noise_term = dw_t @ jac_z(params, point)

    def divergence_at(y):
        return div_z(params, y)

    # Gradient of the divergence w.r.t. the full input; entry 0 (time) is discarded.
    grad_divergence = jacrev(divergence_at)(point)[1:]
    return -noise_term + (jnp.sqrt(2.0)/2.0) * grad_divergence

def b_tilde(x, p, dw_t):
    """Drift of the forward state equation: `-2 * p + sqrt(2) * dw_t`.

    `x` is accepted so the signature matches a general drift, but the value does
    not depend on it.
    """
    control_part = -2*p
    noise_part = jnp.sqrt(2) * dw_t
    return control_part + noise_part

def u_star(p):
    """Control associated with the costate `p`: its negation."""
    control = -p
    return control

# KL Expansion in JAX
def bm_kl_jax(xi, t_grid):
    """Evaluate the truncated cosine series `sqrt(2) * sum_j xi_j * cos(pi * (j + 1/2) * t)` on a grid.

    Args:
        xi: coefficients of shape (n_terms, d).
        t_grid: 1-D array of time points, shape (N+1,).

    Returns:
        Array of shape (N+1, d): the series value (\\dot{w}_n(t)) at every grid point.
    """
    num_modes = xi.shape[0]
    half_integers = jnp.arange(0, num_modes) + 0.5
    # Broadcasted product t_k * (j + 1/2), shape (N+1, n_terms); same values as an outer product.
    phase = jnp.pi * (t_grid[:, None] * half_integers[None, :])
    cosine_basis = jnp.sqrt(2) * jnp.cos(phase)
    return jnp.matmul(cosine_basis, xi)

In [4]:
# Settings
# --- time discretisation of [0, T] ---
T = 1.0
N = 200                                # number of time steps
t_grid = jnp.linspace(0, T, N + 1)     # N + 1 grid points, both endpoints included
dt = T / N                             # step size
# --- costate fixed-point iteration ---
iter_max = 50                          # cap on the number of iterations
alpha = 0.18                           # weight of the new iterate in the damped update
# --- initial state ---
x0 = jnp.zeros(d)

@jit
def solve_single_path_p_iteration(xi, params):
    """Solve the forward-state / backward-costate system for one noise path by damped fixed-point iteration.

    Each iteration (1) integrates the state forward with RK4 using the current costate path,
    (2) integrates the costate backward with RK4 from the terminal condition, and (3) blends
    the new costate path with the previous one using weight `alpha`. Iteration stops after
    `iter_max` sweeps or once the mean squared change of the (damped) costate path drops
    to `tol` or below.

    The module-level settings `t_grid`, `N`, `dt`, `d`, `x0`, `iter_max` and `alpha` are read
    when the function is traced by `jit`.

    Args:
        xi: series coefficients of shape (n_terms, d), passed to `bm_kl_jax`.
        params: network parameters forwarded to `dH_dx`.

    Returns:
        Tuple `(x_res, p_opt, final_diff, final_iter)`:
        the state path (N+1, d) integrated with the final costate, the final costate
        path (N+1, d), the last mean squared update, and the number of iterations performed.
    """
    dw = bm_kl_jax(xi, t_grid)
    p_init = jnp.zeros((N + 1, d))
    init_val = (0, p_init, 1.0)
    tol = 1e-6
    step_ids = jnp.arange(N)

    def integrate_state(p_path):
        """RK4-integrate the state forward from x0 along costate path `p_path`; returns (N+1, d)."""
        def rk4_forward_step(x_curr, i):
            # Costate and noise at t_i, t_{i+1}; midpoint values by linear interpolation.
            p_i = p_path[i]
            p_next = p_path[i+1]
            p_mid = 0.5 * (p_i + p_next)

            dw_i = dw[i]
            dw_next = dw[i+1]
            dw_mid = 0.5 * (dw_i + dw_next)

            k1 = b_tilde(x_curr, p_i, dw_i)
            k2 = b_tilde(x_curr + 0.5 * dt * k1, p_mid, dw_mid)
            k3 = b_tilde(x_curr + 0.5 * dt * k2, p_mid, dw_mid)
            k4 = b_tilde(x_curr + dt * k3, p_next, dw_next)

            x_next_step = x_curr + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)
            # Emit the state at t_i; the state at t_N is delivered as the final carry.
            return x_next_step, x_curr

        x_final, x_head = lax.scan(rk4_forward_step, x0, step_ids)
        return jnp.concatenate([x_head, x_final[None, :]], axis=0)

    def integrate_costate(x_full):
        """RK4-integrate the costate backward from its terminal value; returns (N+1, d)."""
        # Terminal condition: p_T = coefficients * x_T (coefficients truncated to d entries).
        p_coef = jnp.array([1.0, 4.0/25.0])
        p_T = p_coef[:d] * x_full[-1]

        def rk4_backward_step(p_next_val, i):
            t_next_time = t_grid[i+1]
            t_curr_time = t_grid[i]
            t_mid_time = 0.5 * (t_next_time + t_curr_time)

            x_next_val = x_full[i+1]
            x_curr_val = x_full[i]
            x_mid_val = 0.5 * (x_next_val + x_curr_val)

            dw_next_val = dw[i+1]
            dw_curr_val = dw[i]
            dw_mid_val = 0.5 * (dw_next_val + dw_curr_val)

            # Stages run from t_{i+1} down to t_i: p_i = p_{i+1} + integral of dH_dx.
            k1 = dH_dx(params, t_next_time, x_next_val, p_next_val, dw_next_val)

            p_k2 = p_next_val + 0.5 * dt * k1
            k2 = dH_dx(params, t_mid_time, x_mid_val, p_k2, dw_mid_val)

            p_k3 = p_next_val + 0.5 * dt * k2
            k3 = dH_dx(params, t_mid_time, x_mid_val, p_k3, dw_mid_val)

            p_k4 = p_next_val + dt * k3
            k4 = dH_dx(params, t_curr_time, x_curr_val, p_k4, dw_curr_val)
            p_curr_val = p_next_val + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)

            return p_curr_val, p_curr_val

        # Scan visits i = N-1, ..., 0, so the stacked output is in reverse time order.
        _, p_rev_traj = lax.scan(rk4_backward_step, p_T, step_ids[::-1])
        return jnp.concatenate([p_rev_traj[::-1], p_T[None, :]], axis=0)

    def cond_fun(val):
        i, _, diff = val
        return (i < iter_max) & (diff > tol)

    def step_body(val):
        iter_count, p_prev, _ = val

        x_full = integrate_state(p_prev)
        p_new_calculated = integrate_costate(x_full)

        # Damped update; the convergence measure is the mean squared change of the damped iterate.
        p_updated = alpha * p_new_calculated + (1 - alpha) * p_prev
        diff_norm = jnp.mean((p_updated - p_prev)**2)

        return (iter_count + 1, p_updated, diff_norm)

    final_iter, p_opt, final_diff = lax.while_loop(cond_fun, step_body, init_val)

    # One more forward sweep so the returned state path is consistent with the returned costate.
    x_res = integrate_state(p_opt)

    return x_res, p_opt, final_diff, final_iter

In [5]:
@jit
def compute_cost_single(x, p, xi, params):
    """Pathwise cost for one sample: left-endpoint sum of the running cost plus the terminal cost.

    The running cost at grid index i is
    `|u_i|^2 - dw_i . z(t_i, x_i) + (sqrt(2)/2) * tr(J_x z(t_i, x_i))` with `u_i = u_star(p_i)`,
    accumulated over i = 0..N-1 with weight `dt`. The terminal cost is
    `0.5 * sum(c * x_N**2) + 0.5` with `c = [1, 4/25]` truncated to the state dimension.

    The module-level settings `t_grid`, `N` and `dt` are read when the function is traced by `jit`.

    Args:
        x: state path, shape (N+1, d).
        p: costate path, shape (N+1, d).
        xi: series coefficients of shape (n_terms, d), passed to `bm_kl_jax`.
        params: network parameters forwarded to `z_net_jax` / `jac_z`.

    Returns:
        Scalar cost for this path.
    """
    dw = bm_kl_jax(xi, t_grid)

    def step_cost(carry, i):
        t_k = t_grid[i]
        x_k = x[i]
        dw_k = dw[i]

        u_k = u_star(p[i])
        r_val = jnp.sum(u_k**2)
        tx = jnp.concatenate([t_k[None], x_k])

        z_val = z_net_jax(params, tx)
        term1 = jnp.dot(dw_k, z_val)
        term2 = (jnp.sqrt(2.0)/2.0) * jnp.trace(jac_z(params, tx))

        running = r_val - term1 + term2
        return carry + running * dt, None

    running_total, _ = lax.scan(step_cost, 0.0, jnp.arange(N))
    # Truncate the coefficients to the state dimension so the terminal cost matches the
    # terminal costate condition used by the solver (which slices the same vector to d entries);
    # without the slice a 1-D state would silently broadcast against both coefficients.
    term_coef = jnp.array([1.0, 4.0/25.0])[:x.shape[-1]]
    term_cost = 0.5 * jnp.sum(term_coef * x[-1]**2) + 0.5
    return running_total + term_cost

In [6]:
# Vectorise over the sample axis: xi (and x, p for the cost) are batched, params are shared.
batch_solver = jit(vmap(solve_single_path_p_iteration, in_axes=(0, None)))
batch_cost = jit(vmap(compute_cost_single, in_axes=(0, 0, 0, None)))

np.random.seed(42)
n_samples = 500
n_terms = 32
xi_batch = np.random.randn(n_samples, n_terms, d)
# Convert to a JAX array once and reuse it for both the solver and the cost evaluation.
xi_batch_jax = jnp.array(xi_batch)

# perf_counter is a monotonic clock, so the measurement is immune to wall-clock adjustments.
# The measured time includes JIT compilation because this is the first call.
start_time = time.perf_counter()
print("Compiling and Running...")
x_batch, p_batch, diff_batch, iter_batch = batch_solver(xi_batch_jax, params_jax)
J_scores = batch_cost(x_batch, p_batch, xi_batch_jax, params_jax)
# JAX dispatches asynchronously; wait for the result before stopping the clock.
J_scores.block_until_ready()
elapsed = time.perf_counter() - start_time

print(f"Elapsed: {elapsed:.4f}s")
lb = jnp.mean(J_scores)
# Standard error of the mean from the sample standard deviation (ddof=1).
se = jnp.std(J_scores, ddof=1) / jnp.sqrt(n_samples)
print("-" * 30)
print(f"Dual lower bound: {lb:.4f} +/- {1.96 * se:.4f}")
print(f"95% lower confidence limit: {lb - 1.96 * se:.4f}")
print("-" * 30)
print("Convergence Stats:")
print(f"  Avg Iterations: {jnp.mean(iter_batch):.1f} / {iter_max}")
print(f"  Max Iterations: {jnp.max(iter_batch)}")
print(f"  Avg Final Diff: {jnp.mean(diff_batch):.2e} (Target < 1e-6)")
print("-" * 30)

Compiling and Running...
Elapsed: 20.5807s
------------------------------
Dual lower bound: 1.1858 +/- 0.0018
95% lower confidence limit: 1.1841
------------------------------
Convergence Stats:
  Avg Iterations: 19.7 / 50
  Max Iterations: 25
  Avg Final Diff: 8.17e-07 (Target < 1e-6)
------------------------------
