In [2]:
# RSNG for Linear advection with a Gaussian bump
from typing import Callable, Sequence

import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
from flax import linen as nn
import optax

In [11]:
# Right-hand side of the 1-D linear advection equation u_t + c * u_x = 0.
def advection_eq(u, x, t, c):
    """Return du/dt = -c * du/dx for the linear advection equation.

    Args:
        u: callable ``u(x, t)`` returning the solution value at a single point;
            an output of size 1 (e.g. shape ``(1,)``) is also accepted.
            NOTE(review): the ``(x, t)`` argument order is assumed — confirm
            against callers.
        x: spatial point(s) — a float scalar or a float array of any shape.
        t: time at which ``u`` is evaluated; it is not differentiated.
        c: constant advection speed.

    Returns:
        Array with the same shape as ``x`` holding ``-c * du/dx`` at ``(x, t)``.
    """
    # jax.grad's second positional argument is ``argnums`` (an integer index),
    # not the evaluation point: build d/dx of argument 0, then evaluate it at x.
    # grad also requires a scalar output, so squeeze away size-1 output axes.
    du_dx_point = jax.grad(lambda xx, tt: jnp.squeeze(u(xx, tt)), argnums=0)
    x = jnp.asarray(x)
    if x.ndim == 0:
        du_dx = du_dx_point(x, t)
    else:
        # Evaluate the pointwise derivative at every grid point; t is shared.
        du_dx = jax.vmap(du_dx_point, in_axes=(0, None))(x.ravel(), t).reshape(x.shape)
    return -c * du_dx

# Discretization: 101 points on [-2, 2] in space and 50 points on [0, 3] in time.
x = jnp.linspace(-2.0, 2.0, num=101)
t = jnp.linspace(0, 3, num=50)
# Constant advection speed.
c = 1.0

In [None]:
class PDEApproximator(nn.Module):
    """Fully connected network used as the PDE solution ansatz.

    The default configuration reproduces the original architecture exactly:
    three hidden ``Dense(32)`` layers, each followed by ReLU, then a linear
    ``Dense(1)`` output layer. Layers are created in the same order, so the
    parameter tree keys (``Dense_0`` ... ``Dense_3``) are unchanged.

    Attributes:
        features: widths of the hidden layers.
        out_features: width of the (linear) output layer.
        activation: nonlinearity applied after every hidden layer.

    NOTE(review): ReLU is piecewise linear, so the network's derivative with
    respect to its input is piecewise constant (and its second derivative is
    zero almost everywhere). If spatial derivatives of the network are needed,
    a smooth activation such as ``nn.tanh`` may be preferable — pass
    ``activation=nn.tanh``.
    """

    features: Sequence[int] = (32, 32, 32)
    out_features: int = 1
    activation: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        """Map ``x`` (features on the last axis) to ``out_features`` outputs."""
        for width in self.features:
            x = self.activation(nn.Dense(width)(x))
        return nn.Dense(self.out_features)(x)

# Initialize the neural network parameters.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
# Dummy input of shape (batch=1, features=101); only its shape matters for init.
# NOTE(review): 101 equals the number of grid points, so the network consumes the
# whole grid as one feature vector and returns a single scalar. If the intent is a
# pointwise ansatz u(x; theta), the input should probably be (N, 1) — confirm.
input_shape = (1, 101)
model = PDEApproximator()
params = model.init(init_rng, jnp.ones(input_shape))['params']

In [None]:
# Solver settings, adapted from JB's code.

# --- High-speed settings (active) ---
n_x = 1_000       # number of sample points in space
sub_sample = 100  # number of parameters to randomly sample
dt = 1e-3         # time step for the RK4 integrator

# --- High-accuracy settings (uncomment to use instead) ---
# n_x = 10_000      # number of sample points in space
# sub_sample = 800  # number of parameters to randomly sample
# dt = 1e-3         # time step for the RK4 integrator

In [11]:
# Forward (explicit) Euler integrator with a per-step PRNG key.
def odeint_euler(fn, y0, t, key):
    """Integrate dy/dt = fn(t, y, key) with forward Euler on the time grid ``t``.

    Args:
        fn: right-hand side, called as ``fn(t_prev, y, subkey)``; a fresh
            ``subkey`` is split off for every step, so ``fn`` may be stochastic.
        y0: initial state at ``t[0]`` (an array; its dtype must be preserved
            by ``y + h * fn(...)`` because it is a ``lax.scan`` carry).
        t: 1-D sequence of time points; each step size is ``t[i] - t[i-1]``,
            so the grid need not be uniform.
        key: PRNG key from which the per-step subkeys are derived.

    Returns:
        Array of shape ``(len(t),) + y0.shape``; entry ``i`` is the state at
        ``t[i]`` and entry 0 is exactly ``y0``.
    """
    t = jnp.asarray(t)

    def euler(carry, t_next):
        y, t_prev, key = carry
        h = t_next - t_prev
        key, subkey = jax.random.split(key)
        y = y + h * fn(t_prev, y, subkey)
        return (y, t_next, key), y

    # Scanning over t[0] as well would take a zero-length step: a wasted fn
    # evaluation that also turns y[0] into NaN whenever fn returns inf/NaN.
    # Skip it, but still consume one split so every real step receives the
    # same subkey it would have received had that step been taken.
    key, _ = jax.random.split(key)
    _, ys = jax.lax.scan(euler, (y0, t[0], key), t[1:])
    return jnp.concatenate([jnp.asarray(y0)[None], ys], axis=0)