In [None]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.collections import LineCollection


from src.sdes import sde_kunita, sde_utils
import src.sdes.time as time

In [None]:
def x0(num_landmarks):
    """Return landmarks evenly spaced on the segment from (0, 0) to (1, 0).

    Args:
        num_landmarks: Number of landmarks to place along the x-axis.

    Returns:
        A flat array of length ``2 * num_landmarks`` laid out as
        ``[x_0, y_0, x_1, y_1, ...]`` with every y-coordinate equal to zero.
    """
    xs = jnp.linspace(0, 1, num_landmarks)
    ys = jnp.zeros(num_landmarks)
    # Pair up the coordinates as (num_landmarks, 2) and interleave them.
    return jnp.stack((xs, ys), axis=-1).reshape(-1)

def sample_circle(num_landmarks: int, radius=1., centre=jnp.asarray([0, 0])) -> jnp.ndarray:
    """Return landmarks equally spaced (by angle) on a circle.

    Angles start at 0 and cover ``[0, 2*pi)``; the end point is excluded so
    the first landmark is not duplicated.

    Args:
        num_landmarks: Number of points to place on the circle.
        radius: Circle radius.
        centre: Offset added to every point (broadcast against the
            ``(num_landmarks, 2)`` array of points).

    Returns:
        A flat array laid out as ``[x_0, y_0, x_1, y_1, ...]``.
    """
    angles = jnp.linspace(0, 2 * jnp.pi, num_landmarks, endpoint=False)
    unit_points = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)
    scaled = radius * unit_points
    return (scaled + centre).flatten()

Setup for both forward and reverse data generation

In [None]:
# Simulation settings: end time, number of time steps, number of trajectories.
T = 1.
N = 100
num_trajs = 2

# Two landmark configurations: a sparse one and a dense one.
less_landmarks = 5
lots_of_landmarks = 200

# Both initial shapes are circles of radius 0.5 around the origin.
x0_lots = sample_circle(lots_of_landmarks, radius=0.5)
x0_less = sample_circle(less_landmarks, radius=0.5)

# One PRNG key per trajectory, derived from a fixed seed.
keys = jax.random.split(jax.random.PRNGKey(0), num=num_trajs)

Forward data generation

In [None]:
# Build a forward sampler for the dense landmark set and run it on `keys`.
# The sampler returns three values; only the middle one (the trajectories)
# is kept here.
# NOTE(review): presumably one trajectory is produced per key — confirm
# against sde_kunita.data_forward.
forward_lots = sde_kunita.data_forward(x0_lots, T, N)
_, fw_trajs_lots, _ = forward_lots(keys)

# Same again for the sparse landmark set, using the same keys.
forward_less = sde_kunita.data_forward(x0_less, T, N)
_, fw_trajs_less, _ = forward_less(keys)


In [None]:
# Split the flat state into per-landmark (x, y) pairs:
# (trajectory, time step, landmark, coordinate).
fw_trajs_lots = fw_trajs_lots.reshape(num_trajs, N, -1, 2)
fw_trajs_less = fw_trajs_less.reshape(num_trajs, N, -1, 2)

# One figure per landmark set, showing the first trajectory only.
# Plain `range` is used for the landmark index: iterating over a jnp array
# yields 0-d device arrays, which is slow and unnecessary for indexing.
for trajs, num_shown, title in (
    (fw_trajs_less, less_landmarks, "Forward trajectories (few landmarks)"),
    (fw_trajs_lots, lots_of_landmarks, "Forward trajectories (many landmarks)"),
):
    fig, ax = plt.subplots()
    for landmark in range(num_shown):
        ax.plot(trajs[0, :, landmark, 0], trajs[0, :, landmark, 1])
    ax.set(title=title, xlabel="x", ylabel="y")
    plt.show()


Reverse data generation (without correction computation)

In [None]:
# Reverse-time vector fields, simulated with the generic integrator from
# sde_utils (this variant does not compute the correction term).
vector_fields = sde_kunita.vector_fields_reverse()

# Brownian-motion shape shared by both runs below.
# NOTE(review): 2 * 5 ** 2 = 50 — presumably 2 components on a 5x5 grid;
# confirm against the noise model in sde_kunita.
reverse_bm_shape = (2 * 5 ** 2,)

backward_lots = sde_utils.data_forward(
    x0_lots, T, N, vector_fields, bm_shape=reverse_bm_shape
)
_, bw_trajs_lots, _ = backward_lots(keys)

backward_less = sde_utils.data_forward(
    x0_less, T, N, vector_fields, bm_shape=reverse_bm_shape
)
_, bw_trajs_less, _ = backward_less(keys)

In [None]:
# Split the flat state into per-landmark (x, y) pairs:
# (trajectory, time step, landmark, coordinate).
bw_trajs_lots = bw_trajs_lots.reshape(num_trajs, N, -1, 2)
bw_trajs_less = bw_trajs_less.reshape(num_trajs, N, -1, 2)


def _plot_time_coloured_paths(ax, paths, total_steps, margin):
    """Draw landmark paths on ``ax`` with colour encoding the time-step index.

    Args:
        ax: Matplotlib axes to draw on.
        paths: Array indexed as ``[step, landmark, coordinate]`` with two
            coordinates (x, y). Every landmark it contains is drawn.
        total_steps: Upper end of the colour scale. Passing the full number
            of steps keeps the step-to-colour mapping the same for truncated
            and complete paths.
        margin: Padding added around the drawn data when setting axis limits.
    """
    norm = plt.Normalize(0, total_steps)
    for landmark in range(paths.shape[1]):
        # Pair consecutive points so each time step becomes one line segment.
        points = paths[:, landmark, :].reshape(-1, 1, 2)
        segments = jnp.concatenate([points[:-1], points[1:]], axis=1)
        line_collection = LineCollection(segments, cmap='viridis', norm=norm)
        # One colour value per segment (one fewer segment than points).
        line_collection.set_array(jnp.arange(segments.shape[0]))
        line_collection.set_linewidth(2)
        ax.add_collection(line_collection)
    # Adding a collection does not autoscale the view, so derive the limits
    # from exactly the data that was drawn.
    xs, ys = paths[..., 0], paths[..., 1]
    ax.set_xlim([float(xs.min()) - margin, float(xs.max()) + margin])
    ax.set_ylim([float(ys.min()) - margin, float(ys.max()) + margin])


# Sparse landmark set: every landmark, all time steps of the first trajectory.
fig, ax = plt.subplots()
_plot_time_coloured_paths(ax, bw_trajs_less[0], N, margin=0.1)
ax.set(title="Reverse trajectories without correction (few landmarks)",
       xlabel="x", ylabel="y")
plt.show()

# Dense landmark set: every 10th landmark, first 30 time steps only.
fig, ax = plt.subplots()
_plot_time_coloured_paths(ax, bw_trajs_lots[0, :30, ::10], N, margin=0.5)
ax.set(title="Reverse trajectories without correction (many landmarks, first 30 steps)",
       xlabel="x", ylabel="y")
plt.show()

Reverse data generation (with correction computation)

In [None]:
# Build reverse samplers via sde_kunita.data_reverse and run them on `keys`.
# Each sampler returns three values; only the middle one (the trajectories)
# is kept. Note that this rebinds backward_lots / bw_trajs_lots and
# backward_less / bw_trajs_less, replacing the results of the earlier
# reverse run that used sde_utils.data_forward.
backward_lots = sde_kunita.data_reverse(x0_lots, T, N)
_, bw_trajs_lots, _ = backward_lots(keys)

backward_less = sde_kunita.data_reverse(x0_less, T, N)
_, bw_trajs_less, _ = backward_less(keys)

# Forward trajectories for the dense set are regenerated with the same
# arguments and keys as in the forward-generation cell, so this cell does
# not depend on the reshaped fw_trajs_lots left behind by the plotting cell.
forward_lots = sde_kunita.data_forward(x0_lots, T, N)
_, fw_trajs_lots, _ = forward_lots(keys)

In [None]:
# Split the flat state into per-landmark (x, y) pairs:
# (trajectory, time step, landmark, coordinate).
bw_trajs_lots = bw_trajs_lots.reshape(num_trajs, N, -1, 2)
bw_trajs_less = bw_trajs_less.reshape(num_trajs, N, -1, 2)

fw_trajs_lots = fw_trajs_lots.reshape(num_trajs, N, -1, 2)

# Sparse landmark set: reverse paths of the first trajectory.
# Plain `range` is used for the landmark index: iterating over a jnp array
# yields 0-d device arrays, which is slow and unnecessary for indexing.
fig, ax = plt.subplots()
for landmark in range(less_landmarks):
    ax.plot(bw_trajs_less[0, :, landmark, 0], bw_trajs_less[0, :, landmark, 1])
ax.set(title="Reverse trajectories with correction (few landmarks)",
       xlabel="x", ylabel="y")
plt.show()

# Dense landmark set: every 40th landmark, reverse and forward paths overlaid.
# Red dots mark the first point of each reverse path, green dots the first
# point of each forward path.
fig, ax = plt.subplots()
for landmark in range(0, lots_of_landmarks, 40):
    # Label only the first pair of markers so the legend has one entry each.
    is_first = landmark == 0
    ax.scatter(bw_trajs_lots[0, 0, landmark, 0], bw_trajs_lots[0, 0, landmark, 1], c='r',
               label="reverse start" if is_first else None)
    ax.plot(bw_trajs_lots[0, :, landmark, 0], bw_trajs_lots[0, :, landmark, 1])
    ax.scatter(fw_trajs_lots[0, 0, landmark, 0], fw_trajs_lots[0, 0, landmark, 1], c='g',
               label="forward start" if is_first else None)
    ax.plot(fw_trajs_lots[0, :, landmark, 0], fw_trajs_lots[0, :, landmark, 1])
ax.set(title="Reverse (with correction) vs forward trajectories (many landmarks)",
       xlabel="x", ylabel="y")
ax.legend()
plt.show()