In [1]:
# Automatically reload edited modules (e.g. the local `neuralbridge` package)
# before each cell executes, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [16]:
import sys
sys.path.append("..")

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from omegaconf import OmegaConf

from neuralbridge.run_scripts.read_config import read_config
from neuralbridge.configs.neural_bridge_config import get_neural_bridge_landmark_config
from neuralbridge.stochastic_processes.examples import SDEFactory
from neuralbridge.solvers.sde import WienerProcess, Euler
from neuralbridge.models import neurb, pCN
from neuralbridge.utils.sample_path import SamplePath
from neuralbridge.utils.plotting import plot_landmark_sample_path
from neuralbridge.utils.t_grid import TimeGrid

In [3]:
def make_landmark_sde(
    n_landmarks: int = 50,
    m_landmarks: int = 2,
    k_alpha: float = 0.3,
    k_sigma: float = 0.2,
    T: float = 1.0,
    dt: float = 0.01,
    t_scheme: str = "linear",
):
    """Build an unconditioned landmark SDE through ``SDEFactory``.

    Replaces the config that was previously copy-pasted once per process, so the
    state and noise dimensions are always derived as ``n_landmarks * m_landmarks``
    instead of being hard-coded separately.

    Args:
        n_landmarks: Number of landmarks.
        m_landmarks: Dimension of each landmark.
        k_alpha: ``k_alpha`` entry of ``params_X_unc``.
        k_sigma: ``k_sigma`` entry of ``params_X_unc``.
        T: Time horizon.
        dt: Time step.
        t_scheme: Time-grid scheme name.

    Returns:
        Whatever ``SDEFactory(config).get_original_sde()`` returns.
    """
    dim = n_landmarks * m_landmarks  # X_dim and W_dim were both 50 * 2
    sde_config = OmegaConf.create({
        "sde": {
            "name": "landmark",
            "n_landmarks": n_landmarks,
            "X_dim": dim,
            "W_dim": dim,
            "T": T,
            "dt": dt,
            "t_scheme": t_scheme,
            "params_X_unc": {
                "k_alpha": k_alpha,
                "k_sigma": k_sigma,
                "n_landmarks": n_landmarks,
                "m_landmarks": m_landmarks,
            },
        }
    })
    return SDEFactory(config=sde_config).get_original_sde()


X_brownian_unc = make_landmark_sde()
# NOTE(review): this uses exactly the same parameters as X_brownian_unc (as the
# original duplicated config did) — confirm whether the Lagrangian model was
# meant to use different settings.
X_lagrangian_unc = make_landmark_sde()

In [18]:
config = get_neural_bridge_landmark_config()
# Summarize large arrays (e.g. the initial state config["sde"]["u"]) instead of
# dumping every element into the notebook output.
with jnp.printoptions(threshold=10, edgeitems=3):
    print(config)

{'sde': {'name': 'landmark', 'n_landmarks': 50, 'T': 1.0, 'dt': 0.01, 'X_dim': 100, 'W_dim': 100, 't_scheme': 'linear', 'params_X_unc': {'k_alpha': 0.3, 'k_sigma': 0.5, 'n_landmarks': 50, 'm_landmarks': 2}, 'params_X_aux': {'k_alpha': 0.3, 'k_sigma': 0.5, 'n_landmarks': 50, 'm_landmarks': 2}, 'u': Array([ 0.00000000e+00,  5.00000000e-01,  1.25333234e-01,  4.96057351e-01,
        2.48689887e-01,  4.84291581e-01,  3.68124553e-01,  4.64888243e-01,
        4.81753674e-01,  4.38153340e-01,  5.87785252e-01,  4.04508497e-01,
        6.84547106e-01,  3.64484314e-01,  7.70513243e-01,  3.18711995e-01,
        8.44327926e-01,  2.67913397e-01,  9.04827052e-01,  2.12889646e-01,
        9.51056516e-01,  1.54508497e-01,  9.82287251e-01,  9.36906573e-02,
        9.98026728e-01,  3.13952598e-02,  9.98026728e-01, -3.13952598e-02,
        9.82287251e-01, -9.36906573e-02,  9.51056516e-01, -1.54508497e-01,
        9.04827052e-01, -2.12889646e-01,  8.44327926e-01, -2.67913397e-01,
        7.70513243e-01, -3

In [19]:
# `u` was never defined in this notebook (presumably left over from a deleted
# cell), so Restart & Run All raised NameError here. Bind it explicitly to the
# initial landmark configuration from the loaded config.
# NOTE(review): confirm `initialize_g` is meant to receive the initial state `u`.
u = config["sde"]["u"]
X_brownian_unc.initialize_g(u)

In [20]:
# Directory holding the pretrained neural-bridge checkpoints.
# NOTE(review): presumably resolved relative to the notebook's working directory — confirm.
NEURB_CKPT_DIR = "../assets/ckpts/neurb"

neural_bridge_model = neurb.NeuralBridge(config)
# mode="pretrained" loads a saved checkpoint from NEURB_CKPT_DIR. The return value
# is not used; binding it to `_` would shadow IPython's last-output variable.
neural_bridge_model.train(mode="pretrained", load_relative_dir=NEURB_CKPT_DIR)
X_neu = neural_bridge_model.build_neural_bridge()

INFO:root:Checkpoint loaded from /Users/vbd402/Documents/Projects/neuralbridge/assets/ckpts/neurb/landmark_benchmark


Loading pretrained model from the last epoch


In [21]:
# Take the model's `X_gui` process so it can be simulated with the same Euler
# solver as the neural bridge `X_neu`, for a side-by-side timing comparison.
# NOTE(review): presumably the guided proposal process — confirm.
X_gui = neural_bridge_model.X_gui

In [22]:
# Discretization grid and driving Wiener process shared by both solvers below,
# read directly from the `sde` section of the loaded config.
tGrid = TimeGrid(**{name: config["sde"][name] for name in ("T", "dt", "t_scheme")})
W = WienerProcess(config["sde"]["W_dim"])

In [24]:
X_neu_solver = Euler(X_neu, W, tGrid)
X_neu_path = X_neu_solver.solve(x0=config["sde"]["u"], batch_size=1)
%timeit X_neu_path = X_neu_solver.solve(x0=config["sde"]["u"], batch_size=1)

9.81 ms ± 14.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
X_gui_solver = Euler(X_gui, W, tGrid)
X_gui_path = X_gui_solver.solve(x0=config["sde"]["u"], batch_size=1)
%timeit X_gui_path = X_gui_solver.solve(x0=config["sde"]["u"], batch_size=1)

6.85 ms ± 135 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
