In [17]:
import dill
import os
import numpy as np
import jax
import jax.numpy as jnp

# Load true funcitons
ex_name = 'ex4-3D_spd'
file_path = os.path.join("true_functions", f"{ex_name}.pkl")

with open(file_path, "rb") as f:
    loaded_functions = dill.load(f)

true_drift = loaded_functions["drift"]
true_diffusion = loaded_functions["diffusion"]

diff_type = "symmetric"

In [18]:
# training data parameters
xlim = np.array([[-1, 1], [-1, 1], [-1, 1]])

n_trajectories = 10000
trajectory_time = 0.01
step_size = 0.01
grid_resolution = int(trajectory_time/step_size) + 1 # +1 to account for staring point


In [19]:
def simulate_euler_maruyama(key, drift_fn, diffusion_fn, x0, step_size, grid_resolution):
    """
    Simulate trajectories using Euler-Maruyama.
    - x0: (N, D) array of initial positions
    - Returns: x: (N, D, T)
    """
    n_trajectories, D = x0.shape
    T = grid_resolution

    # Generate dW: shape (T-1, N, D)
    key, subkey = jax.random.split(key)
    dW = jax.random.normal(subkey, shape=(T-1, n_trajectories, D)) * jnp.sqrt(step_size)

    def euler_step(x_t, dW_t):
        drift = drift_fn(x_t)                 # (N, D)
        diffusion = diffusion_fn(x_t).reshape(x_t.shape[0], D, D)         # (N, D, D)
        diffusion_term = jnp.einsum('nij,nj->ni', diffusion, dW_t)  # (N, D)
        #diffusion_term = jnp.matmul(diffusion, dW_t[..., None]).squeeze(-1)
        x_next = x_t + drift * step_size + diffusion_term
        return x_next, x_next

    # Run lax.scan
    _, xs = jax.lax.scan(euler_step, x0, dW)
    
    # Add initial condition to beginning
    x_full = jnp.concatenate([x0[:, :, None], xs.transpose(1, 2, 0)], axis=-1)  # (N, D, T)
    return x_full

In [20]:
key = jax.random.PRNGKey(0)
x0 = jax.random.uniform(key, shape=(n_trajectories, xlim.shape[0]), minval=xlim[:,0], maxval=xlim[:,1])

x = simulate_euler_maruyama(key,
                            drift_fn=true_drift,
                            diffusion_fn=true_diffusion,
                            x0=x0,
                            step_size=step_size,
                            grid_resolution=grid_resolution)


In [21]:
x_data = x[:, :, :-1].reshape(-1, xlim.shape[0])
y_data = x[:, :, 1:].reshape(-1, xlim.shape[0])


f = y_data - (x_data + step_size * true_drift(x_data))
print(jnp.mean((f ** 2), axis=0))
print(true_diffusion(x_data[0,:]))

[2.3624081e-05 6.1316387e-05 1.7342680e-04]
[[ 0.02045836 -0.03502162 -0.02678421]
 [-0.03502162  0.06355898  0.02981971]
 [-0.02678421  0.02981971  0.12453806]]


In [22]:

np.savez(f"training_data/{ex_name}_time{trajectory_time}_SS{step_size}_ntraj{n_trajectories}.npz", diff_type=diff_type, step_size=step_size, trajectories=np.array(x))
