In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from neuralbridge.stochastic_processes.examples import OUProcess, OUAuxProcess
from neuralbridge.networks.score_net import ScoreNetSmall
from neuralbridge.stochastic_processes.conds import GuidedBridgeProcess, NeuralBridgeProcess
from neuralbridge.solvers.sde import WienerProcess, Euler

In [3]:
dim = 1
T = 1.0
dt = 1. / 200
dtype = jnp.float32

gamma = 1.0
sigma = 1.0

seed = 42
u = jnp.array([0.0], dtype=dtype)
v = jnp.array([0.0], dtype=dtype)

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/Caskroom/miniconda/base/envs/neuralbridge/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)


In [4]:
t_scheme = "linear"

ou_proc = OUProcess(gamma, sigma, T, dim, dtype)
ou_aux_proc = OUAuxProcess(gamma, sigma, T, dim, dtype)
ou_guided_proc = GuidedBridgeProcess(
    ou_proc, 
    ou_aux_proc,
    u=u,
    v=v,
    L0=jnp.eye(dim, dtype=dtype),
    Sigma0=jnp.eye(dim, dtype=dtype) * 1e-10,
    ts=jnp.arange(0, T + dt, dt, dtype=dtype),
)
wiener_proc = WienerProcess(T, dt, dim, dtype, t_scheme)

neural_net = ScoreNetSmall(
    out_dim=dim,
    hidden_dims=[10, 20],
    activation="tanh",
    norm="layer"
)

In [5]:
# Compute gradient w.r.t gamma using jax.grad
grad_gamma = jax.grad(lambda g: jnp.mean(Euler(OUProcess(g, sigma, T, dim, dtype), wiener_proc).solve(x0=u, rng_key=jax.random.PRNGKey(42), batch_size=16).xs))(gamma)

print(f"Gradient of mean w.r.t. gamma: {grad_gamma}")


Gradient of mean w.r.t. gamma: -0.09112265706062317


In [15]:
variables = neural_net.init(jax.random.PRNGKey(42), jnp.zeros((1, 1)), jnp.zeros((1, dim))) 
neural_bridge = NeuralBridgeProcess(ou_guided_proc, neural_net)


# Define a function that simulates trajectories and returns mean
def simulate_and_mean(variables):
    solver = Euler(neural_bridge, wiener_proc)
    dWs = wiener_proc.sample_path(jax.random.PRNGKey(42), batch_size=16).dxs
    trajectories = solver.solve_with_variables(
        x0=u,
        dWs=dWs,
        variables=variables,
        batch_size=16,
        training=True,
        mutable=["batch_stats"]
    )
    return jnp.mean(trajectories.xs)

# Compute gradients with respect to parameters only
grad_params = jax.grad(lambda p: simulate_and_mean({'params': p, 'batch_stats': {}}))(variables['params'])

# Print the shape of gradients for each parameter
print("Gradient shapes:")
jax.tree.map(lambda x: print(x.shape), grad_params)

Gradient shapes:
(10,)
(2, 10)
(20,)
(10, 20)
(1,)
(20, 1)
(10,)
(10,)
(20,)
(20,)


{'Dense_0': {'bias': None, 'kernel': None},
 'Dense_1': {'bias': None, 'kernel': None},
 'Dense_2': {'bias': None, 'kernel': None},
 'LayerNorm_0': {'bias': None, 'scale': None},
 'LayerNorm_1': {'bias': None, 'scale': None}}

In [19]:
import einops

In [18]:
solver = Euler(neural_bridge, wiener_proc)

def loss_fn(variables, seed):
    dWs = wiener_proc.sample_path(jax.random.PRNGKey(seed), batch_size=16).dxs
    solver_paths = solver.solve_with_variables(
        x0=u,
        dWs=dWs,
        variables=variables,
        batch_size=16,
        training=True   ,
        mutable=["batch_stats"]
    )
    xs, ts, log_lls = solver_paths.xs, solver_paths.ts, solver_paths.log_ll
    ts = einops.repeat(ts, "t -> b t 1", b=xs.shape[0])
    nus, *_ = neural_net.apply(
        variables,
        xs,
        ts,
        training=True,
        mutable=["batch_stats"]
    )
    loss = jnp.sum(jnp.sum(nus ** 2, axis=-1)) * dt - log_lls
    return jnp.mean(loss)

# Compute gradients with respect to parameters
grad_params_1 = jax.grad(loss_fn)(variables, seed=42)

# Print the shape of gradients for each parameter
print("Gradient shapes:")
jax.tree.map(lambda x: print(x.shape), grad_params_1)



Gradient shapes:
(10,)
(2, 10)
(20,)
(10, 20)
(1,)
(20, 1)
(10,)
(10,)
(20,)
(20,)


{'params': {'Dense_0': {'bias': None, 'kernel': None},
  'Dense_1': {'bias': None, 'kernel': None},
  'Dense_2': {'bias': None, 'kernel': None},
  'LayerNorm_0': {'bias': None, 'scale': None},
  'LayerNorm_1': {'bias': None, 'scale': None}}}

In [26]:
def loss_fn_batch_input(variables, batch):
    xs, ts, log_lls = batch
    ts = einops.repeat(ts, "t -> b t 1", b=xs.shape[0])
    nus, *_ = neural_net.apply(
        variables={
            "params": variables['params'],
            "batch_stats": variables['batch_stats']
        },
        x=xs,
        t=ts,
        training=True,
        mutable=["batch_stats"]
    )
    loss = jnp.sum(jnp.sum(nus ** 2, axis=-1)) * dt - log_lls
    return jnp.mean(loss)

path = solver.solve_with_variables(
    x0=u, 
    dWs=wiener_proc.sample_path(jax.random.PRNGKey(42), batch_size=16).dxs, 
    variables=variables,
    batch_size=16,
    training=True,
    mutable=["batch_stats"]
)
batch = (path.xs, path.ts, path.log_ll)

grad_params = jax.grad(lambda p: loss_fn_batch_input({'params': p, 'batch_stats': {}}, batch))(variables['params'])

# Print the shape of gradients for each parameter
print("Gradient shapes:")
jax.tree.map(lambda x: print(x.shape), grad_params)

Gradient shapes:
(10,)
(2, 10)
(20,)
(10, 20)
(1,)
(20, 1)
(10,)
(10,)
(20,)
(20,)


{'Dense_0': {'bias': None, 'kernel': None},
 'Dense_1': {'bias': None, 'kernel': None},
 'Dense_2': {'bias': None, 'kernel': None},
 'LayerNorm_0': {'bias': None, 'scale': None},
 'LayerNorm_1': {'bias': None, 'scale': None}}