In [None]:
import os
import sys

# Root of the 1d_advection project. Defaults to the original author's checkout;
# set the ADVECTION_1D_DIR environment variable to point at another copy instead
# of editing absolute paths in this cell.
basedir = os.environ.get('ADVECTION_1D_DIR', '/Users/nickm/thesis/icml2023paper/1d_advection')
# Directory handed to SimulationParams (see get_sim_params); defaults to basedir,
# override with ADVECTION_1D_DATA_DIR.
readwritedir = os.environ.get('ADVECTION_1D_DATA_DIR', basedir)

# Make the project's core / simulate / ml modules importable (same order as before).
sys.path.extend(os.path.join(basedir, sub) for sub in ('core', 'simulate', 'ml'))

In [None]:
import jax
import jax.numpy as jnp
from jax import config, vmap
config.update("jax_enable_x64", True)
import xarray
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
from initialconditions import get_a0, get_initial_condition_fn
from simparams import CoreParams, CoreParamsDG, SimulationParams
from legendre import generate_legendre
from simulations import AdvectionFVSim, AdvectionDGSim
from trajectory import get_trajectory_fn, get_inner_fn
from trainingutils import save_training_data
from mlparams import TrainingParams, StencilParams
from model import LearnedStencil
from trainingutils import get_loss_fn, get_batch_fn, get_idx_gen, train_model, compute_losses_no_model, init_params

In [None]:
def plot_fv(a, core_params):
    """Plot a single finite-volume state as a 1-D line over [0, Lx).

    Value ``a[i]`` is drawn at ``x = i * Lx / nx``, i.e. at the left edge of
    cell ``i`` on a uniform grid of ``nx = a.shape[0]`` cells.
    """
    num_cells = a.shape[0]
    x_left = jnp.arange(num_cells) * core_params.Lx / num_cells
    field = xarray.DataArray(a, dims=['x'], coords={'x': x_left})
    field.plot(aspect=3, size=1.5)

def plot_fv_trajectory(trajectory, core_params, t_inner):
    """Facet-plot a finite-volume trajectory, one panel per saved time (5 per row).

    ``trajectory[t, i]`` is drawn at ``x = i * Lx / nx`` (left cell edges), and
    the ``time`` coordinate of snapshot ``t`` is ``t * t_inner``.
    """
    n_frames = trajectory.shape[0]
    num_cells = trajectory.shape[1]
    x_left = jnp.arange(num_cells) * core_params.Lx / num_cells
    frame_times = t_inner * jnp.arange(n_frames)
    history = xarray.DataArray(
        trajectory,
        dims=["time", "x"],
        coords={'x': x_left, 'time': frame_times},
    )
    history.plot(col='time', col_wrap=5)
    
def plot_dg(a, core_params, color='blue'):
    """Plot one DG state as a piecewise polynomial over [0, Lx).

    Each cell ``j`` is sampled at ``n_plot`` equally spaced points from its left
    edge ``j * dx`` to its right edge ``(j + 1) * dx``. At those points the
    expansion ``sum_k a[j, k] * leg_poly[k](xi)`` is evaluated, where ``xi`` is
    the reference coordinate in [-1, 1] and ``leg_poly`` comes from
    ``generate_legendre``.

    Args:
        a: per-cell coefficients indexed ``a[j, k]`` (cell j, basis k); presumably
            shape ``(nx, order + 1)`` -- TODO confirm against ``get_a0``.
        core_params: DG core parameters; ``Lx`` and ``order`` are read.
        color: matplotlib color for the line.
    """
    order = core_params.order
    # Samples per cell: hand-picked for orders 0-3. Higher orders fall back to
    # 2 * order + 1 (which agrees with the table at orders 0, 2 and 3) instead of
    # raising IndexError.
    n_plot = (1, 2, 5, 7)[order] if order < 4 else 2 * order + 1
    nx = a.shape[0]
    dx = core_params.Lx / nx
    leg_poly = generate_legendre(order + 1)

    def eval_cell(x, a_j, j):
        # Map the sample points of cell j (centre x_j) onto xi in [-1, 1].
        x_j = dx * (0.5 + j)
        xi = (x - x_j) / (0.5 * dx)
        # poly_eval[i, k] = polynomial k of leg_poly at xi[i]; shape (n_plot, order + 1).
        poly_eval = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)
        return jnp.sum(poly_eval * a_j, axis=-1)

    xjs = jnp.arange(nx) * core_params.Lx / nx                   # left cell edges
    xs = xjs[None, :] + jnp.linspace(0.0, dx, n_plot)[:, None]   # (n_plot, nx)
    a_plot = vmap(eval_cell, (1, 0, 0), 1)(xs, a, jnp.arange(nx))  # (n_plot, nx)

    # Transpose before flattening so each cell's samples are contiguous along x.
    x_flat = xs.T.reshape(-1)
    a_flat = a_plot.T.reshape(-1)
    # Name the dimension explicitly instead of relying on xarray inferring it
    # from the keys of a dict-valued ``coords``.
    xarray.DataArray(a_flat, dims=['x'], coords={'x': x_flat}).plot(color=color)

def plot_dg_trajectory(trajectory, core_params, t_inner, color='blue'):
    """Facet-plot a DG trajectory, one panel per saved time (5 panels per row).

    Every snapshot is drawn as a piecewise polynomial. Cell ``j`` is sampled at
    ``n_plot`` equally spaced points from ``j * dx`` to ``(j + 1) * dx``, and the
    expansion ``sum_k trajectory[t, j, k] * leg_poly[k](xi)`` is evaluated there.

    Args:
        trajectory: coefficients indexed ``trajectory[t, j, k]`` (snapshot t,
            cell j, basis k); presumably shape ``(outer_steps, nx, order + 1)``
            -- TODO confirm.
        core_params: DG core parameters; ``Lx`` and ``order`` are read.
        t_inner: time between consecutive snapshots (sets the ``time`` coordinate).
        color: matplotlib color, forwarded to the line in every panel (it was
            previously accepted but ignored).
    """
    order = core_params.order
    # Samples per cell: hand-picked for orders 0-3; 2 * order + 1 beyond that
    # instead of raising IndexError.
    n_plot = (1, 2, 5, 7)[order] if order < 4 else 2 * order + 1
    outer_steps, nx = trajectory.shape[0:2]
    dx = core_params.Lx / nx
    # Built once here rather than inside the vmapped per-snapshot function.
    leg_poly = generate_legendre(order + 1)
    cells = jnp.arange(nx)
    xjs = jnp.arange(nx) * core_params.Lx / nx                   # left cell edges
    xs = xjs[None, :] + jnp.linspace(0.0, dx, n_plot)[:, None]   # (n_plot, nx)

    def eval_cell(x, a_j, j):
        # Map the sample points of cell j (centre x_j) onto xi in [-1, 1].
        x_j = dx * (0.5 + j)
        xi = (x - x_j) / (0.5 * dx)
        # poly_eval[i, k] = polynomial k of leg_poly at xi[i]; shape (n_plot, order + 1).
        poly_eval = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)
        return jnp.sum(poly_eval * a_j, axis=-1)

    eval_snapshot = vmap(eval_cell, (1, 0, 0), 1)

    def snapshot_samples(a):
        # (nx, n_plot): one row of samples per cell.
        return eval_snapshot(xs, a, cells).T

    # (outer_steps, nx * n_plot), cell-major so it lines up with the flattened xs.
    trajectory_plot = vmap(snapshot_samples)(trajectory).reshape(outer_steps, -1)
    coords = {
        'x': xs.T.reshape(-1),
        'time': t_inner * jnp.arange(outer_steps),
    }
    xarray.DataArray(trajectory_plot, dims=["time", "x"], coords=coords).plot(
        col='time', col_wrap=5, color=color)
    

def get_core_params(order, flux='upwind'):
    """Build core parameters on a domain of length 1.0.

    ``order == 0`` yields finite-volume ``CoreParams(Lx, flux)``; any other order
    yields ``CoreParamsDG(Lx, flux, order)``.
    """
    Lx = 1.0
    if order != 0:
        return CoreParamsDG(Lx, flux, order)
    return CoreParams(Lx, flux)

def get_sim_params(name = "test", cfl_safety=0.3, rk='ssp_rk3', gs=False):
    """Build SimulationParams using the notebook-level ``basedir`` and ``readwritedir``."""
    return SimulationParams(
        name,
        basedir,
        readwritedir,
        cfl_safety,
        rk,
        gs,
    )

def get_training_params(n_data, unique_id="test", batch_size=4, learning_rate=1e-3, num_epochs = 10):
    """Build TrainingParams; note the positional order is (n_data, num_epochs, unique_id, batch_size, learning_rate)."""
    return TrainingParams(
        n_data,
        num_epochs,
        unique_id,
        batch_size,
        learning_rate,
    )

def get_stencil_params(kernel_size = 3, kernel_out = 4, stencil_width=4, depth = 3, width = 16):
    """Build StencilParams from the network hyperparameters, passed positionally in signature order."""
    return StencilParams(
        kernel_size,
        kernel_out,
        stencil_width,
        depth,
        width,
    )

def get_model(core_params, stencil_params):
    """Construct a LearnedStencil sized for ``core_params``.

    ``p`` (the last constructor argument) is 1 when ``core_params.order`` is None
    (presumably the FV case) and ``order + 1`` otherwise. The hidden layers are
    ``depth - 1`` layers of ``width`` units each.
    """
    order = core_params.order
    p = 1 if order is None else order + 1
    features = [stencil_params.width] * (stencil_params.depth - 1)
    return LearnedStencil(
        features,
        stencil_params.kernel_size,
        stencil_params.kernel_out,
        stencil_params.stencil_width,
        p,
    )

In [None]:
#### test IC for FV
# Draw one random 'sum_sin' initial condition on a 100-cell FV grid and plot it.
key = jax.random.PRNGKey(15)
nx = 100
core_params = get_core_params(order=0)
f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key)
a0 = get_a0(f_init, core_params, nx)
plot_fv(a0, core_params=core_params)

In [None]:
##### test simulate FV
# Self-contained: the initial condition is sampled with the default (upwind) FV
# parameters -- the value `core_params` held here in a top-to-bottom run -- and
# the solver below then switches to the MUSCL flux.
nx = 50
key = jax.random.PRNGKey(2)
core_params = get_core_params(0)
f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key)
a0 = get_a0(f_init, core_params, nx)

t_inner = 0.1
outer_steps = 11
core_params = get_core_params(0, flux='muscl')

sim_params = get_sim_params()
sim = AdvectionFVSim(core_params, sim_params)
inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)

trajectory = trajectory_fn(a0)
plot_fv_trajectory(trajectory, core_params, t_inner)

In [None]:
### generate data FV (1 timestep at a time)

# Reference resolution and the coarse resolutions to save data for.
nx_exact = 128
nxs = (8, 16, 32, 64)
n_runs = 5
t_inner_train = 0.02
outer_steps_train = int(1.0 / t_inner_train)

# Solver: MUSCL finite volume.
kwargs_sim = dict(name="test", cfl_safety=0.3, rk='ssp_rk3', gs=False)
core_params = get_core_params(0, flux='muscl')
sim_params = get_sim_params(**kwargs_sim)
sim = AdvectionFVSim(core_params, sim_params)

# Random 'sum_sin' initial conditions; `init_fn` looks up `core_params` and
# `kwargs_init` when it is called.
kwargs_init = dict(min_num_modes=1, max_num_modes=6, min_k=1, max_k=4, amplitude_max=1.0)
init_fn = lambda key: get_initial_condition_fn(
    core_params, 'sum_sin', key=key, **kwargs_init
)


In [None]:
key = jax.random.PRNGKey(12)
# Forward the FV initial-condition options (`kwargs_init`, defined with `init_fn`
# above). The generic name `kwargs` only exists after the DG data cell has run,
# so it must not be used here.
save_training_data(key, init_fn, core_params, sim_params, sim, t_inner_train, outer_steps_train, n_runs, nx_exact, nxs, **kwargs_init)

In [None]:
### train FV model (1 timestep at a time)

# Optimiser and stencil-network hyperparameters.
kwargs_train = dict(unique_id="test", batch_size=4, learning_rate=1e-4, num_epochs=10)
kwargs_stencil = dict(kernel_size=3, kernel_out=4, stencil_width=4, depth=3, width=16)

key = jax.random.PRNGKey(42)
# Dataset size: n_runs runs times outer_steps_train saved steps each
# (presumably one training sample per saved step).
n_data = n_runs * outer_steps_train

core_params = get_core_params(0, flux='muscl')
sim_params = get_sim_params()
sim = AdvectionFVSim(core_params, sim_params)

training_params = get_training_params(n_data, **kwargs_train)
stencil_params = get_stencil_params(**kwargs_stencil)
model = get_model(core_params, stencil_params)

In [None]:
# Coarse resolution handed to get_batch_fn (32 is one of the saved `nxs`).
nx = 32


def idx_fn(key):
    # Reads the global `training_params` at call time.
    return get_idx_gen(key, training_params)


batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)
loss_fn = get_loss_fn(model, core_params)
params = init_params(key, model)

In [None]:
# `params` is rebound to what train_model returns (presumably the trained
# parameters), so re-running this cell starts from those.
losses, params = train_model(
    model, params, training_params, key, idx_fn, batch_fn, loss_fn
)

In [None]:
# Loss curve returned by train_model for the FV model. `plt` is already imported
# in the imports cell at the top of the notebook, so it is not re-imported here.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(ylabel='loss', title='FV model training loss');

In [None]:
### test success of FV model
# NOTE(review): despite the heading, this cell currently only runs the reference
# MUSCL solver on a sine wave; the trained model is not used yet.

# First, simulate sin wave.
# Define `core_params` before building the initial condition so the cell does
# not depend on whatever an earlier cell left in the kernel.
core_params = get_core_params(0, flux='muscl')
nx = 50
f_init = get_initial_condition_fn(core_params, 'sin')
a0 = get_a0(f_init, core_params, nx)

t_inner = 1.0
outer_steps = 5
sim_params = get_sim_params()
sim = AdvectionFVSim(core_params, sim_params)
inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)

trajectory = trajectory_fn(a0)
plot_fv_trajectory(trajectory, core_params, t_inner)


In [None]:
############### NOW DISCONTINUOUS GALERKIN SIMS ################

In [None]:
### test IC for DG
order = 1
nx = 5
key = jax.random.PRNGKey(15)
core_params = get_core_params(order)
f_init = get_initial_condition_fn(core_params, 'sum_sin', key=key)
a0 = get_a0(f_init, core_params, nx)

# Presumably (nx, order + 1) coefficients -- TODO confirm.
print(a0.shape)

plot_dg(a0, core_params, color='blue')

In [None]:
##### test simulate DG

order = 1
nx = 50
key = jax.random.PRNGKey(15)

# Initial condition built from the default (upwind) DG parameters of this order.
core_params = get_core_params(order)
f_init = get_initial_condition_fn(core_params, 'sin', key=key)
a0 = get_a0(f_init, core_params, nx)

# Evolve with a centered-flux DG solver.
t_inner = 0.1
outer_steps = 11
core_params = get_core_params(order, flux='centered')
sim_params = get_sim_params(cfl_safety=0.3)
sim = AdvectionDGSim(core_params, sim_params)
inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)

trajectory = trajectory_fn(a0)
plot_dg_trajectory(trajectory, core_params, t_inner)

In [None]:
### generate data DG (1 timestep at a time)

# Reference DG resolution/order and the coarse resolutions to save data for.
nx_exact = 32
nxs = (4, 6, 8, 16)
order_exact = 2
n_runs = 5
t_inner_train = 0.02
outer_steps_train = int(1.0 / t_inner_train)

core_params = get_core_params(order_exact, flux='upwind')
sim_params = get_sim_params()
sim = AdvectionDGSim(core_params, sim_params)

# Random 'sum_sin' initial conditions; `init_fn` looks up `core_params` and
# `kwargs` when it is called.
kwargs = dict(min_num_modes=1, max_num_modes=6, min_k=1, max_k=4, amplitude_max=1.0)
init_fn = lambda key: get_initial_condition_fn(
    core_params, 'sum_sin', key=key, **kwargs
)

In [None]:
key = jax.random.PRNGKey(12)
# Save DG training data, forwarding the initial-condition options in `kwargs`.
save_training_data(
    key, init_fn, core_params, sim_params, sim,
    t_inner_train, outer_steps_train, n_runs, nx_exact, nxs,
    **kwargs,
)

In [None]:
### train DG model (1 timestep at a time)

# Dataset settings; presumably these must match the DG data-generation cell.
n_runs = 5
t_inner_train = 0.02
outer_steps_train = int(1.0/t_inner_train)
key = jax.random.PRNGKey(42)
order_exact = 2
n_data = n_runs * outer_steps_train
core_params = get_core_params(order_exact, flux='upwind')
sim_params = get_sim_params()
# DG core parameters need the DG simulator (this previously built an
# AdvectionFVSim from CoreParamsDG).
sim = AdvectionDGSim(core_params, sim_params)

training_params = get_training_params(n_data, learning_rate=1e-5)
stencil_params = get_stencil_params()
model = get_model(core_params, stencil_params)

In [None]:
# Coarse DG resolution handed to get_batch_fn (8 is one of the saved `nxs`).
nx = 8


def idx_fn(key):
    # Reads the global `training_params` at call time.
    return get_idx_gen(key, training_params)


batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)
loss_fn = get_loss_fn(model, core_params)
params = init_params(key, model)

In [None]:
# `params` is rebound to what train_model returns (presumably the trained
# parameters), so re-running this cell starts from those.
losses, params = train_model(
    model, params, training_params, key, idx_fn, batch_fn, loss_fn
)

In [None]:
# Loss curve returned by train_model for the DG model. `plt` is already imported
# in the imports cell at the top of the notebook, so it is not re-imported here.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(ylabel='loss', title='DG model training loss');

In [None]:
### test success of DG model