In [18]:
from pathlib import Path

import jax
import jax.numpy as jnp
from jax import random, vmap
jax.config.update("jax_enable_x64", True)

In [20]:
P = 20        # spatial period: samples are wrapped with `% P`, domain is x in [0, P]
M = 100       # number of spatial intervals -> M + 1 x grid points
N = 100       # number of time intervals    -> N + 1 t grid points
T_FINAL = 5   # end of the time window, t in [0, T_FINAL]
NUMBER_OF_SENSORS = M + 1  # equals the number of spatial grid points

# Note: x includes both endpoints 0 and P, which coincide under the periodic wrap.
x = jnp.linspace(0, P, M + 1)
t = jnp.linspace(0, T_FINAL, N + 1)

In [19]:
def sech(x):
    """Hyperbolic secant, sech(x) = 1 / cosh(x) (NumPy / jax.numpy have no built-in sech)."""
    return 1/jnp.cosh(x)

def u_soliton(x, t, key = random.PRNGKey(0), period=None):
    """Evaluate one randomly drawn travelling sech^2 pulse on a periodic domain.

    The speed ``c ~ U[0.5, 1.5)`` and shift ``d ~ U[0, period)`` are drawn from
    ``key``, so a given key always yields the same sample. The profile
    ``c/2 * sech^2`` moves with speed ``c`` and is wrapped with period ``period``.
    It is therefore of the form ``f(x - c*t)``, i.e. it solves the advection
    equation ``u_t + c u_x = 0`` with periodic boundaries.

    Args:
        x: spatial coordinate(s).
        t: time(s); broadcast against ``x``.
        key: JAX PRNG key that determines ``c`` and ``d``.
        period: spatial period of the domain; ``None`` means the global ``P``.

    Returns:
        Pulse values with the broadcast shape of ``x`` and ``t``.
    """
    if period is None:
        # Resolved at call time so edits to the config cell take effect.
        period = P
    c_key, d_key = random.split(key)
    c = random.uniform(c_key, minval=0.5, maxval=1.5)
    # Draw the shift over exactly one period (previously hard-coded to 20, which
    # only covered the full period while P happened to equal 20).
    d = random.uniform(d_key, minval=0., maxval=period)
    # Wrap the travelling coordinate into [0, period) and centre it at period/2;
    # the peak (value c/2) sits where x - c*t + d == period/2 (mod period).
    return 1/2*c*sech(jnp.abs((x-c*t + d) % period - period/2))**2

In [21]:
NUM_SAMPLES = 500

# Inner vmap: map over the time grid `t` (x and the sample's key are shared),
# so one key, i.e. one (c, d) pair, describes a whole trajectory.
# Outer vmap: map over one PRNG key per sample (x and t are shared).
data = vmap(
    vmap(u_soliton, in_axes=(None, 0, None)),
    in_axes=(None, None, 0),
)(x, t, random.split(random.PRNGKey(0), NUM_SAMPLES))

# Layout produced by the two vmaps: (sample, time, space).
assert data.shape == (NUM_SAMPLES, t.shape[0], x.shape[0])

In [22]:
train_val_test_split = [0.7, 0.15, 0.15]
assert abs(sum(train_val_test_split) - 1) < 1e-9, "split fractions must sum to 1"

# round() rather than int(): int() truncates, so a float product such as
# 100 * 0.29 == 28.999999999999996 would silently shift a sample between splits.
train_split_idx = round(NUM_SAMPLES * train_val_test_split[0])
val_split_idx = round(NUM_SAMPLES * (train_val_test_split[0] + train_val_test_split[1]))

# Consecutive, non-overlapping slices along the sample axis.
train, val, test = jnp.split(data, [train_split_idx, val_split_idx])
assert len(train) + len(val) + len(test) == NUM_SAMPLES

In [23]:
# Input/target pairs per split: `a_*` is the first time slice (t[0] = 0) of each
# trajectory, while `u_*` keeps the full (sample, time, space) trajectory.
a_train, u_train = train[:, 0], train
a_val, u_val = val[:, 0], val
a_test, u_test = test[:, 0], test

In [24]:
# Bundle the grids and all splits into one archive. Create the output folder
# first: savez raises FileNotFoundError when the parent directory is missing.
OUTPUT_PATH = Path("../data/advection.npz")
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
jnp.savez(OUTPUT_PATH, x=x, t=t, a_train=a_train, u_train=u_train, a_val=a_val, u_val=u_val, a_test=a_test, u_test=u_test)