In [10]:
from functools import partial

import jax.numpy as jnp
import jax.random as jr
import jax
import matplotlib.pyplot as plt

from src.sdes import sde_kunita, sde_utils
import src.sdes.time as time
from src import plotting

In [11]:
def x0(num_landmarks):
    """Return `num_landmarks` points evenly spaced on the segment [0, 1] x {0}.

    The points are flattened into a 1-D array laid out as
    [x_0, y_0, x_1, y_1, ...], i.e. shape (2 * num_landmarks,).
    """
    xs = jnp.linspace(0, 1, num_landmarks)
    points = jnp.column_stack((xs, jnp.zeros_like(xs)))
    return points.ravel()

def sample_circle(num_landmarks: int, radius=1., centre=jnp.asarray([0, 0])) -> jnp.ndarray:
    """Sample `num_landmarks` equally spaced points on a circle.

    Angles start at 0 and increase counter-clockwise. The endpoint 2*pi is
    excluded, so the first point is not duplicated at the end.

    Args:
        num_landmarks: number of points on the circle.
        radius: circle radius.
        centre: circle centre. Any length-2 array-like (tuple, list or array)
            is accepted; it is converted with `jnp.asarray`.

    Returns:
        Flat array of shape (2 * num_landmarks,) laid out as
        [x_0, y_0, x_1, y_1, ...].
    """
    theta = jnp.linspace(0, 2 * jnp.pi, num_landmarks, endpoint=False)
    unit_points = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1)
    # JAX arrays refuse implicit arithmetic with Python lists/tuples, so
    # normalise `centre` before broadcasting it over the (num_landmarks, 2) points.
    return (radius * unit_points + jnp.asarray(centre)).flatten()

# Forward-SDE sample paths for a coarse and a dense circle of landmarks.
# NOTE(review): `x0_10` / `trajs_10` keep their names because later cells use
# them, but in this cell they hold a circle of `num_few_landmarks` (= 5) points,
# not 10.
num_few_landmarks = 5
lots_of_landmarks = 1000

x0_10 = sample_circle(num_few_landmarks)
x0_lots = sample_circle(lots_of_landmarks)

num_trajs = 2

T = 1.   # time horizon; later cells build `time.grid(0, T, N)`
N = 100  # number of time points per trajectory (axis 1 after reshaping)

SEED = 0
keys = jax.random.split(jax.random.PRNGKey(SEED), num_trajs)  # one key per trajectory

data_fn_lots = sde_kunita.data_forward(x0_lots, T, N)
_, trajs_lots, _ = data_fn_lots(keys)

data_fn_less = sde_kunita.data_forward(x0_10, T, N)
_, trajs_10, _ = data_fn_less(keys)



In [15]:
# NOTE(review): this import rebinds `sample_circle`, shadowing the version
# defined in an earlier cell. Every later cell that calls `sample_circle`
# (including the reverse-SDE cells) therefore uses the `src.data_boundary_pts`
# implementation. Confirm the two agree, or import it under an alias and move
# the import to the top import cell.
from src.data_boundary_pts import sample_circle

# The first returned field (presumably the drift) is discarded; only the
# diffusion is inspected in this cell.
_, diffusion = sde_kunita.vector_fields()
x = sample_circle(14, radius=1.0)

# Flattened landmark vector; the recorded output shows shape (28,) for 14 points.
print(x)
print(x.shape)

# Evaluate the diffusion for the 14-landmark circle. Presumably the signature
# is (t, x); the recorded output is a 2-D matrix. TODO confirm its shape.
diffusion_ = diffusion(0.0, x)
print(diffusion_)
# print(diffusion_@diffusion_.T)


[ 1.0000000e+00  0.0000000e+00  9.0096885e-01  4.3388376e-01
  6.2348974e-01  7.8183150e-01  2.2252086e-01  9.7492790e-01
 -2.2252107e-01  9.7492790e-01 -6.2348998e-01  7.8183138e-01
 -9.0096891e-01  4.3388361e-01 -1.0000000e+00 -3.2584137e-07
 -9.0096873e-01 -4.3388399e-01 -6.2348962e-01 -7.8183162e-01
 -2.2252055e-01 -9.7492802e-01  2.2252150e-01 -9.7492778e-01
  6.2349004e-01 -7.8183132e-01  9.0096909e-01 -4.3388331e-01]
(28,)
[[1.5034391e-04 0.0000000e+00 1.8315640e-03 ... 0.0000000e+00
  8.2085002e-03 0.0000000e+00]
 [0.0000000e+00 1.5034391e-04 0.0000000e+00 ... 1.3533528e-02
  0.0000000e+00 8.2085002e-03]
 [7.6955017e-05 0.0000000e+00 8.4911071e-04 ... 0.0000000e+00
  1.6036628e-02 0.0000000e+00]
 ...
 [0.0000000e+00 1.5247837e-03 0.0000000e+00 ... 1.9444861e-03
  0.0000000e+00 8.0935948e-04]
 [4.3648490e-04 0.0000000e+00 4.8161163e-03 ... 0.0000000e+00
  2.8273556e-03 0.0000000e+00]
 [0.0000000e+00 4.3648490e-04 0.0000000e+00 ... 5.1467880e-03
  0.0000000e+00 2.8273556e-03]]


In [None]:
# Unflatten to (num_trajs, N, num_landmarks, 2). The landmark axis is inferred
# with -1 for two reasons. It does not repeat the hard-coded 5-landmark count
# used to build `x0_10`. It also keeps this cell valid if re-run after a later
# cell rebinds `lots_of_landmarks` (it is set to 100 further down).
trajs_lots = trajs_lots.reshape(num_trajs, N, -1, 2)
trajs_less = trajs_10.reshape(num_trajs, N, -1, 2)

In [None]:
# Forward-SDE sample paths of the first trajectory (index 0 on axis 0).
# Landmark counts are read from the arrays themselves rather than from
# `lots_of_landmarks`. A later cell rebinds that global to 100, and re-running
# this cell afterwards would otherwise plot only one dense-circle landmark.
for landmark in range(trajs_less.shape[2]):
    plt.plot(trajs_less[0, :, landmark, 0], trajs_less[0, :, landmark, 1])
    # Dot marks the landmark's starting position.
    plt.scatter(trajs_less[0, 0, landmark, 0], trajs_less[0, 0, landmark, 1])
plt.title("Forward SDE paths (coarse circle)")
plt.xlabel("x")
plt.ylabel("y")
plt.show()

for landmark in range(0, trajs_lots.shape[2], 100):  # every 100th landmark
    plt.plot(trajs_lots[0, :, landmark, 0], trajs_lots[0, :, landmark, 1])
plt.title("Forward SDE paths (dense circle, every 100th landmark)")
plt.xlabel("x")
plt.ylabel("y")
plt.show()


In [None]:
def x0(num_landmarks):
    """Initial shape: `num_landmarks` landmarks on the x-axis segment [0, 1].

    Returns the points flattened as [x_0, y_0, x_1, y_1, ...].
    NOTE(review): identical to the `x0` defined in an earlier cell. This
    redefinition only rebinds the same behaviour, and `x0` is not called in
    the visible cells.
    """
    coords = jnp.stack(
        [jnp.linspace(0, 1, num_landmarks), jnp.zeros(num_landmarks)],
        axis=-1,
    )
    return coords.reshape(-1)

# NOTE(review): `lots_of_landmarks` and `time_grid` are not used in this cell,
# but they are global rebindings that later cells read, so they are kept.
lots_of_landmarks = 100
x0_10 = sample_circle(10)

time_grid = time.grid(0, T, N)

data_fn_less = sde_kunita.data_reverse(x0_10, T, N)

ts, reverse, corr = data_fn_less(keys)

# Unflatten to (num_trajs, N, num_landmarks, 2); -1 infers the landmark count
# instead of repeating the literal 10 used to build `x0_10`.
reverse = reverse.reshape(num_trajs, N, -1, 2)

# Paths of the first sampled trajectory; dots mark each landmark's start.
for landmark in range(reverse.shape[2]):
    plt.plot(reverse[0, :, landmark, 0], reverse[0, :, landmark, 1])
    plt.scatter(reverse[0, 0, landmark, 0], reverse[0, 0, landmark, 1])
plt.title("Reverse data paths (10-landmark circle)")
plt.xlabel("x")
plt.ylabel("y")
plt.show()

# All entries of `corr` are drawn at x = 1.0 so their spread shows along y.
# Presumably one value per sampled trajectory. TODO confirm.
for corr_ in corr:
    plt.scatter(1.0, corr_)
plt.ylabel("corr")
plt.show()

In [None]:
def x0(num_landmarks):
    """Flattened landmark vector for points (t, 0), t evenly spaced on [0, 1].

    Output layout is [x_0, y_0, x_1, y_1, ...] with shape (2 * num_landmarks,).
    NOTE(review): same definition as the other `x0` cells; none of the visible
    cells call it.
    """
    t = jnp.linspace(0, 1, num_landmarks)
    zeros = jnp.zeros((num_landmarks, 1))
    return jnp.concatenate([t[:, None], zeros], axis=1).flatten()

lots_of_landmarks = 100
x0_10 = sample_circle(10)
x0_lots = sample_circle(lots_of_landmarks)

time_grid = time.grid(0, T, N)
rev_drift, rev_diffusion = sde_kunita.vector_fields_reverse()

# Batch `sde_utils.solution` over the PRNG keys (axis 0 of the first argument).
# All other arguments are shared across the batch: time grid, initial state,
# drift, diffusion, and the last argument. That last argument is presumably the
# Brownian-motion shape; an older note passed it as `bm_shape=`.
solve_over_keys = jax.vmap(sde_utils.solution, in_axes=(0, None, None, None, None, None))
noise_shape = (2 * sde_kunita.GRID_SIZE ** 2,)

rev_trajs_10 = solve_over_keys(keys, time_grid, x0_10, rev_drift, rev_diffusion, noise_shape)
rev_trajs_lots = solve_over_keys(keys, time_grid, x0_lots, rev_drift, rev_diffusion, noise_shape)

In [None]:
# Unflatten to (num_trajs, N, num_landmarks, 2). The -1 infers the landmark
# count, so this cell neither repeats the hard-coded 10 nor depends on the
# current value of the global `lots_of_landmarks`.
rev_trajs_10 = rev_trajs_10.reshape(num_trajs, N, -1, 2)
rev_trajs_lots = rev_trajs_lots.reshape(num_trajs, N, -1, 2)

In [None]:
def plot_landmark_paths(trajs, title, traj_index=0):
    """Plot every landmark path of one sampled trajectory and show the figure.

    Args:
        trajs: array of shape (num_trajs, num_steps, num_landmarks, 2).
        title: figure title.
        traj_index: which sampled trajectory (axis 0) to draw.

    Each landmark's path is drawn as a line. Its first time step is marked in
    red and its last time step in green. The landmark count is read from
    `trajs` itself, so no global count has to stay in sync with the data.
    """
    fig, ax = plt.subplots()
    paths = trajs[traj_index]
    for landmark in range(paths.shape[1]):
        ax.plot(paths[:, landmark, 0], paths[:, landmark, 1])
        ax.scatter(paths[0, landmark, 0], paths[0, landmark, 1], c='r')
        ax.scatter(paths[-1, landmark, 0], paths[-1, landmark, 1], c='g')
    ax.set(title=title, xlabel="x", ylabel="y")
    plt.show()


plot_landmark_paths(rev_trajs_10, "Reverse SDE paths (coarse circle); red = start, green = end")
plot_landmark_paths(rev_trajs_lots, "Reverse SDE paths (dense circle); red = start, green = end")


In [None]:
grid_range = (-1, 2)
grid_size = 5


def grid(value_range=None, num_points=None):
    """Return a square 2-D grid of points with shape (n, n, 2).

    Args:
        value_range: (low, high) extent used for both axes. Defaults to the
            module-level `grid_range`, read at call time.
        num_points: number of points per axis. Defaults to the module-level
            `grid_size`, read at call time.

    Returns:
        Array `points` with `points[i, j] == (axis[j], axis[i])`, where
        `axis = jnp.linspace(*value_range, num_points)`. Indexing is "xy", so
        the second index runs along x.
    """
    if value_range is None:
        value_range = grid_range
    if num_points is None:
        num_points = grid_size
    axis = jnp.linspace(*value_range, num_points)
    xx, yy = jnp.meshgrid(axis, axis, indexing="xy")
    return jnp.stack([xx, yy], axis=-1)


print(grid())