In this Jupyter notebook, we will train a machine learned FV solver and machine learned DG solver to solve the 1D advection equation at reduced resolution. Our objective is to study the frequency of instability, and to demonstrate that global stabilization eliminates this instability.

In [None]:
# setup paths
import sys

# Base directory holding the core/, simulate/ and ml/ source folders, and the
# directory handed to SimulationParams for reading/writing data.
basedir = '/Users/nickm/thesis/icml2023paper/1d_advection'
readwritedir = '/Users/nickm/thesis/icml2023paper/1d_advection'

# Make the internal packages importable (same order as the folders are listed).
sys.path.extend('{}/{}'.format(basedir, subdir) for subdir in ('core', 'simulate', 'ml'))

In [None]:
# import external packages
import jax
import jax.numpy as jnp
import numpy as onp
from jax import config, vmap
config.update("jax_enable_x64", True)
import xarray
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# import internal packages
from flux import Flux
from initialconditions import get_a0, get_initial_condition_fn, get_a
from simparams import CoreParams, CoreParamsDG, SimulationParams
from legendre import generate_legendre
from simulations import AdvectionFVSim, AdvectionDGSim
from trajectory import get_trajectory_fn, get_inner_fn
from trainingutils import save_training_data
from mlparams import TrainingParams, StencilParams
from model import LearnedStencil
from trainingutils import (get_loss_fn, get_batch_fn, get_idx_gen, train_model, 
                           compute_losses_no_model, init_params, save_training_params, load_training_params)

In [None]:
# helper functions

def plot_fv(a, core_params, color="blue"):
    """Plot a finite-volume solution.

    Appends a trailing axis of length one to ``a`` and delegates to
    ``plot_dg``, forwarding ``core_params`` and ``color`` unchanged.
    """
    a_with_basis_axis = a[..., None]
    plot_dg(a_with_basis_axis, core_params, color=color)
    
def plot_fv_trajectory(trajectory, core_params, t_inner, color='blue'):
    """Plot a finite-volume trajectory, one panel per stored time.

    Appends a trailing axis of length one to ``trajectory`` and delegates to
    ``plot_dg_trajectory`` with the remaining arguments unchanged.
    """
    trajectory_with_basis_axis = trajectory[..., None]
    plot_dg_trajectory(trajectory_with_basis_axis, core_params, t_inner, color=color)
    
def plot_dg(a, core_params, color='blue'):
    """Plot a single piecewise-polynomial solution on ``[0, core_params.Lx]``.

    Args:
        a: coefficient array; its first axis indexes the ``nx`` cells and its
            last axis holds the per-cell coefficients that multiply the rows
            of ``generate_legendre(p)`` (presumably ``p`` entries -- confirm
            against ``generate_legendre``).
        core_params: object providing ``order`` (``None`` for finite volume)
            and the domain length ``Lx``.
        color: line colour forwarded to the xarray plot call.

    Raises:
        IndexError: if ``p`` exceeds the four entries of the sample-count
            table below.
    """
    # One coefficient per cell for FV (order is None), order + 1 otherwise.
    p = 1 if core_params.order is None else core_params.order + 1

    def evalf(x, a, j, dx, leg_poly):
        # Map the sample points of cell j onto xi in [-1, 1], measured from
        # the cell centre x_j.
        x_j = dx * (0.5 + j)
        xi = (x - x_j) / (0.5 * dx)
        vmap_polyval = vmap(jnp.polyval, (0, None), -1)
        # One row per sample point, one column per row of leg_poly.
        poly_eval = vmap_polyval(leg_poly, xi)
        return jnp.sum(poly_eval * a, axis=-1)

    # Number of sample points drawn inside each cell, indexed by p - 1.
    NPLOT = [2,2,5,7][p-1]
    nx = a.shape[0]
    dx = core_params.Lx / nx
    xjs = jnp.arange(nx) * core_params.Lx / nx  # left edge of every cell
    xs = xjs[None, :] + jnp.linspace(0.0, dx, NPLOT)[:, None]  # (NPLOT, nx)
    vmap_eval = vmap(evalf, (1, 0, 0, None, None), 1)

    a_plot = vmap_eval(xs, a, jnp.arange(nx), dx, generate_legendre(p))
    # Flatten cell-by-cell so the samples are ordered left to right.
    a_plot = a_plot.T.reshape(-1)
    xs = xs.T.reshape(-1)
    # Name the dimension explicitly (as the trajectory plotters do) instead of
    # relying on xarray inferring it from a dict of coords.
    data = xarray.DataArray(a_plot, dims=["x"], coords={"x": xs})
    data.plot(color=color)

def plot_dg_trajectory(trajectory, core_params, t_inner, color='blue'):
    """Plot every snapshot of a trajectory as its own panel.

    Args:
        trajectory: array whose axis 0 indexes stored snapshots, axis 1 the
            cells, and whose last axis holds the per-cell coefficients.
        core_params: object providing ``order`` (``None`` for finite volume)
            and the domain length ``Lx``.
        t_inner: time between consecutive snapshots; panel ``k`` is labelled
            with time ``k * t_inner``.
        color: line colour forwarded to the xarray plot call.
    """
    n_basis = 1 if core_params.order is None else core_params.order + 1
    # Number of sample points drawn inside each cell, indexed by n_basis - 1.
    n_plot = [2,2,5,7][n_basis-1]
    n_steps, n_cells = trajectory.shape[0], trajectory.shape[1]
    cell_width = core_params.Lx / n_cells
    left_edges = jnp.arange(n_cells) * core_params.Lx / n_cells
    sample_xs = left_edges[None, :] + jnp.linspace(0.0, cell_width, n_plot)[:, None]
    cell_ids = jnp.arange(n_cells)
    legendre = generate_legendre(n_basis)

    def eval_cell(x, coeffs, j, dx, leg_poly):
        # Reference coordinate in [-1, 1] relative to the centre of cell j.
        centre = dx * (0.5 + j)
        xi = (x - centre) / (0.5 * dx)
        basis_vals = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)
        return jnp.sum(basis_vals * coeffs, axis=-1)

    eval_all_cells = vmap(eval_cell, (1, 0, 0, None, None), 1)

    def snapshot_to_samples(coeffs):
        # Transposed so that the cell index comes first.
        return eval_all_cells(sample_xs, coeffs, cell_ids, cell_width, legendre).T

    samples = vmap(snapshot_to_samples)(trajectory).reshape(n_steps, -1)

    panel_coords = {
        'x': sample_xs.T.reshape(-1),
        'time': t_inner * jnp.arange(n_steps),
    }
    panels = xarray.DataArray(samples, dims=["time", "x"], coords=panel_coords)
    panels.plot(col='time', col_wrap=5, color=color)
    
def plot_multiple_fv_trajectories(trajectories, core_params, t_inner):
    """Overlay several finite-volume trajectories in one set of panels.

    Each trajectory gets a trailing axis of length one before the list is
    handed to ``plot_multiple_dg_trajectories``.
    """
    expanded = []
    for single_trajectory in trajectories:
        expanded.append(single_trajectory[..., None])
    plot_multiple_dg_trajectories(expanded, core_params, t_inner)

def plot_multiple_dg_trajectories(trajectories, core_params, t_inner):
    """Overlay several trajectories, one panel per time and one line per trajectory.

    Args:
        trajectories: sequence of arrays; the number of snapshots and cells is
            read from the first one, and each is reshaped using that snapshot
            count.
        core_params: object providing ``order`` (``None`` for finite volume)
            and the domain length ``Lx``.
        t_inner: time between consecutive snapshots; panel ``k`` is labelled
            with time ``k * t_inner``.
    """
    n_steps, n_cells = trajectories[0].shape[0], trajectories[0].shape[1]

    n_basis = 1 if core_params.order is None else core_params.order + 1
    # Number of sample points drawn inside each cell, indexed by n_basis - 1.
    n_plot = [2,2,5,7][n_basis-1]
    cell_width = core_params.Lx / n_cells
    left_edges = jnp.arange(n_cells) * core_params.Lx / n_cells
    sample_xs = left_edges[None, :] + jnp.linspace(0.0, cell_width, n_plot)[:, None]
    cell_ids = jnp.arange(n_cells)
    legendre = generate_legendre(n_basis)

    def eval_cell(x, coeffs, j, dx, leg_poly):
        # Reference coordinate in [-1, 1] relative to the centre of cell j.
        centre = dx * (0.5 + j)
        xi = (x - centre) / (0.5 * dx)
        basis_vals = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)
        return jnp.sum(basis_vals * coeffs, axis=-1)

    eval_all_cells = vmap(eval_cell, (1, 0, 0, None, None), 1)

    def snapshot_to_samples(coeffs):
        # Transposed so that the cell index comes first.
        return eval_all_cells(sample_xs, coeffs, cell_ids, cell_width, legendre).T

    trajectory_to_samples = vmap(snapshot_to_samples)
    stacked = [
        trajectory_to_samples(one_trajectory).reshape(n_steps, -1)
        for one_trajectory in trajectories
    ]

    panel_coords = {
        'x': sample_xs.T.reshape(-1),
        'time': t_inner * jnp.arange(n_steps),
    }
    overlay = xarray.DataArray(stacked, dims=["stack", "time", "x"], coords=panel_coords)
    overlay.plot.line(col='time', hue="stack", col_wrap=5)
    

def get_core_params(order, flux='upwind'):
    """Build the core parameters for a unit-length domain.

    Returns ``CoreParams(Lx, flux)`` when ``order == 0`` and
    ``CoreParamsDG(Lx, flux, order)`` for any other ``order``.
    """
    Lx = 1.0
    if order != 0:
        return CoreParamsDG(Lx, flux, order)
    return CoreParams(Lx, flux)

def get_sim_params(name="test", cfl_safety=0.3, rk='ssp_rk3'):
    """Build ``SimulationParams`` using the notebook-level ``basedir`` and ``readwritedir``."""
    return SimulationParams(
        name,
        basedir,
        readwritedir,
        cfl_safety,
        rk,
    )

def get_training_params(n_data, train_id="test", batch_size=4, learning_rate=1e-3, num_epochs=10, optimizer='sgd'):
    """Build ``TrainingParams``.

    Note that ``TrainingParams`` takes ``num_epochs`` as its second positional
    argument, ahead of ``train_id``.
    """
    return TrainingParams(
        n_data,
        num_epochs,
        train_id,
        batch_size,
        learning_rate,
        optimizer,
    )

def get_stencil_params(kernel_size=3, kernel_out=4, stencil_width=4, depth=3, width=16):
    """Build ``StencilParams`` from the given network/stencil hyperparameters."""
    return StencilParams(
        kernel_size,
        kernel_out,
        stencil_width,
        depth,
        width,
    )


def l2_norm_trajectory(trajectory):
    """Return the mean of the squared entries of ``trajectory`` along axis 1.

    Despite the name, no square root is taken: the result is the mean square
    along axis 1, not the L2 norm itself.
    """
    squared = jnp.square(trajectory)
    return jnp.mean(squared, axis=1)
    
def get_model(core_params, stencil_params):
    """Construct the ``LearnedStencil`` model for the given discretisation.

    The feature list repeats ``stencil_params.width`` ``depth - 1`` times; the
    last constructor argument is 1 when ``core_params.order`` is ``None`` and
    ``order + 1`` otherwise.
    """
    n_basis = 1 if core_params.order is None else core_params.order + 1
    hidden_features = [stencil_params.width] * (stencil_params.depth - 1)
    return LearnedStencil(
        hidden_features,
        stencil_params.kernel_size,
        stencil_params.kernel_out,
        stencil_params.stencil_width,
        n_basis,
    )

### Finite Volume

##### Training Loop

First, we will generate the data.

In [None]:
# training hyperparameters
# Initial-condition family and its keyword arguments; both are forwarded to
# get_initial_condition_fn.
init_description = 'sum_sin'
kwargs_init = {'min_num_modes': 1, 'max_num_modes': 6, 'min_k': 1, 'max_k': 4, 'amplitude_max': 1.0}
# Keyword arguments for get_sim_params / get_training_params / get_stencil_params.
kwargs_sim = {'name' : "diff_lrs", 'cfl_safety' : 0.3, 'rk' : 'ssp_rk3'}
kwargs_train_FV = {'train_id': "diff_lrs", 'batch_size' : 32, 'optimizer': 'adam', 'num_epochs' : 100}
kwargs_stencil = {'kernel_size' : 3, 'kernel_out' : 4, 'stencil_width' : 4, 'depth' : 3, 'width' : 16}
# Number of training runs, time between stored snapshots, and snapshots per run
# (chosen so that each run spans a total time of 1.0).
n_runs = 100
t_inner_train = 0.02
outer_steps_train = int(1.0/t_inner_train)
fv_flux_baseline = 'muscl' # learning a correction to the MUSCL scheme
# Resolution passed to save_training_data as the "exact" grid, and the reduced
# resolutions at which models are trained.
nx_exact = 256
nxs = [8, 16, 32, 64]
# One learning rate per entry of nxs (same order); the assert enforces the pairing.
learning_rate_list = [1e-2, 1e-2, 1e-4, 1e-5]
assert len(nxs) == len(learning_rate_list)
key = jax.random.PRNGKey(12)

# setup
# order 0 selects the finite-volume CoreParams branch of get_core_params.
core_params = get_core_params(0, flux=fv_flux_baseline)
sim_params = get_sim_params(**kwargs_sim)
# Total number of stored snapshots across all runs.
n_data = n_runs * outer_steps_train
# training_params_list[i] belongs to nxs[i].
training_params_list = [get_training_params(n_data, **kwargs_train_FV, learning_rate = lr) for lr in learning_rate_list]
stencil_params = get_stencil_params(**kwargs_stencil)
# Baseline simulator: constructed without a model or params.
sim = AdvectionFVSim(core_params, sim_params)
# Maps a PRNG key to an initial-condition function.
init_fn = lambda key: get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
model = get_model(core_params, stencil_params)

In [None]:
# save training data
save_training_data(
    key,
    init_fn,
    core_params,
    sim_params,
    sim,
    t_inner_train,
    outer_steps_train,
    n_runs,
    nx_exact,
    nxs,
)

Next, we initialize the model parameters.

In [None]:
# Key used to initialise the model parameters. NOTE: this rebinds the global
# `key`, so the training loop that follows also receives PRNGKey(42).
key = jax.random.PRNGKey(42)
# Initial parameters; the same `i_params` is the starting point for every nx.
i_params = init_params(key, model)

Next, we run a training loop for each value of nx. The learning rate undergoes a prespecified decay.

In [None]:
# Train one set of parameters per resolution, pairing each nx with its own
# training parameters. Every run starts from i_params and uses the same key.
for nx, training_params in zip(nxs, training_params_list):
    print(nx)
    idx_fn = lambda key: get_idx_gen(key, training_params)
    batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)
    loss_fn = get_loss_fn(model, core_params)
    losses, params = train_model(
        model, i_params, training_params, key, idx_fn, batch_fn, loss_fn
    )
    save_training_params(nx, sim_params, training_params, params, losses)

Next, we load and plot the losses for each nx to check that the model trained properly.

In [None]:
# One loss curve per resolution, labelled by nx.
for nx, trained_with in zip(nxs, training_params_list):
    losses, _ = load_training_params(nx, sim_params, trained_with, model)
    plt.plot(losses, label=nx)
plt.ylim([0,1])
plt.legend()
plt.show()

Next, we plot the accuracy of the trained model on a few simple test cases to qualitatively evaluate the success of the training. We will eventually quantify the accuracy of the trained model.

In [None]:
# pick a key that gives something nice
key = jax.random.PRNGKey(18)

# For every resolution: evolve the same initial condition with the learned
# solver and with the baseline `sim`, then overlay both trajectories and plot
# the mean square of each over time.
for i, nx in enumerate(nxs):
    print("nx is {}".format(nx))
    
    # Parameters saved by the training loop for this nx.
    _, params = load_training_params(nx, sim_params, training_params_list[i], model)
    
    # The key is fixed, so every nx sees the same initial-condition function.
    f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
    a0 = get_a0(f_init, core_params, nx)
    t_inner = 1.0
    outer_steps = 10
    # with params
    sim_model = AdvectionFVSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_model = trajectory_fn_model(a0)
    #plot_fv_trajectory(trajectory_model, core_params, t_inner, color = plot_colors[i])
    
    # NOTE: the triple-quoted string below is a no-op expression that keeps the
    # globally-stabilized variant disabled.
    """
    # with global stabilization
    sim_model_gs = AdvectionFVSim(core_params, sim_params, global_stabilization=True, epsilon_gs=0.0, model=model, params=params)
    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)
    trajectory_model_gs = trajectory_fn_model_gs(a0)
    #plot_fv_trajectory(trajectory_model, core_params, t_inner, color = plot_colors[i])
    """

    # without params
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory = trajectory_fn(a0)
    #plot_fv_trajectory(trajectory, core_params, t_inner, color = 'red')
    

    # Baseline is passed first, learned model second.
    plot_multiple_fv_trajectories([trajectory, trajectory_model], core_params, t_inner)
    
    plt.show()
    # Mean square over the grid at each stored time: baseline, then learned model.
    plt.plot(l2_norm_trajectory(trajectory))
    plt.plot(l2_norm_trajectory(trajectory_model))
    #plt.plot(l2_norm_trajectory(trajectory_model_gs))
    plt.show()

We see from above that the baseline (red) has a large amount of numerical diffusion for a small number of gridpoints, while it is more accurate for more gridpoints. We also see that the machine learned model learns to accurately evolve the solution for nx > 8. So far, so good. Let's now look at a different initial condition.

In [None]:
# pick a key that gives something nice
key = jax.random.PRNGKey(20)

# Snapshot spacing and number of snapshots are the same for every resolution.
t_inner = 1.0
outer_steps = 5

for nx, trained_with in zip(nxs, training_params_list):
    print("nx is {}".format(nx))
    _, params = load_training_params(nx, sim_params, trained_with, model)

    # Same key for every nx, hence the same initial-condition function.
    f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
    a0 = get_a0(f_init, core_params, nx)

    # learned solver (with params)
    sim_model = AdvectionFVSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_model = trajectory_fn_model(a0)

    # baseline solver (without params)
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory = trajectory_fn(a0)

    # Baseline is passed first, learned model second.
    plot_multiple_fv_trajectories(
        [trajectory, trajectory_model], core_params, t_inner
    )
    plt.show()

Oh no! We can see that for nx=8 and nx=32, the solution goes unstable between 1.0 < t < 2.0 and 0.0 < t < 1.0 respectively.

This is not good. Even though the machine learned PDE solver gives accurate solutions for certain initial conditions, the solution blows up unexpectedly for others. Let's next ask: how frequently does an instability arise?

In [None]:
# Count, for each resolution, how many of N random initial conditions end in
# a NaN when evolved with the (unstabilized) learned solver.
N = 25
# Snapshot spacing and number of snapshots; identical for every nx.
t_inner = 10.0
outer_steps = 10

for nx, trained_with in zip(nxs, training_params_list):

    key = jax.random.PRNGKey(10) # new key, same initial key for each nx
    _, params = load_training_params(nx, sim_params, trained_with, model)
    sim_model = AdvectionFVSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    # The baseline trajectory function was never evaluated in this cell, so it
    # is no longer built here.

    num_nan = 0

    for _run in range(N):
        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        trajectory_model = trajectory_fn_model(a0)
        # A run counts as unstable when its last snapshot contains any NaN.
        # int(...) keeps the counter a plain Python integer.
        num_nan += int(jnp.isnan(trajectory_model[-1]).any())
        # Advance the key so the next run draws a different initial condition.
        key, _ = jax.random.split(key)

    print("nx is {}, num_nan is {} out of {}".format(nx, num_nan, N))
    
    

So we see that a large percentage of the simulations go unstable. We could use various tips and tricks to decrease the number of simulations that go unstable, such as to increase the size of the training set or the duration of training or use a different loss function. But we are interested in something else entirely: eliminating the ability of the solution to go unstable. What happens when we set global stabilization to True?

In [None]:
# Same experiment as the unstabilized count, but with global stabilization on.
N = 25
t_inner = 10.0
outer_steps = 10
# Forward difference of u with wrap-around indexing; the first argument is unused.
G = lambda f, u: (jnp.roll(u, -1) - u)

for nx, trained_with in zip(nxs, training_params_list):

    key = jax.random.PRNGKey(10)
    _, params = load_training_params(nx, sim_params, trained_with, model)
    sim_model = AdvectionFVSim(
        core_params,
        sim_params,
        global_stabilization=True,
        epsilon_gs=0.0,
        G=G,
        model=model,
        params=params,
    )
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    num_nan = 0

    for _run in range(N):
        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        trajectory_model = trajectory_fn_model(a0)
        # Unstable when the last snapshot contains any NaN.
        final_has_nan = jnp.isnan(trajectory_model[-1]).any()
        num_nan += final_has_nan
        key, _ = jax.random.split(key)

    print("nx is {}, num_nan is {} out of {}".format(nx, num_nan, N))

The NaNs are eliminated for nx=8 and nx=16, but not for nx=32. Why not? The reason is that if the ML solver predicts a large flux value, this modifies the CFL condition. To eliminate these NaNs, we'll need to reduce the CFL number.

In [None]:
# Simulation parameters with a CFL safety factor of 0.1 (the earlier runs use 0.3).
kwargs_sim_low_cfl = dict(name="larger", cfl_safety=0.1, rk='ssp_rk3')
sim_params_low_cfl = get_sim_params(**kwargs_sim_low_cfl)

In [None]:
# Global stabilization again, this time stepping with sim_params_low_cfl.
# The trained parameters are still loaded via the original sim_params.
N = 25
t_inner = 10.0
outer_steps = 10
# Forward difference of u with wrap-around indexing; the first argument is unused.
G = lambda f, u: (jnp.roll(u, -1) - u)

for nx, trained_with in zip(nxs, training_params_list):

    key = jax.random.PRNGKey(10)
    _, params = load_training_params(nx, sim_params, trained_with, model)
    sim_model = AdvectionFVSim(
        core_params,
        sim_params_low_cfl,
        global_stabilization=True,
        epsilon_gs=0.0,
        G=G,
        model=model,
        params=params,
    )
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    num_nan = 0

    for _run in range(N):
        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        trajectory_model = trajectory_fn_model(a0)
        # Unstable when the last snapshot contains any NaN.
        final_has_nan = jnp.isnan(trajectory_model[-1]).any()
        num_nan += final_has_nan
        key, _ = jax.random.split(key)

    print("With low cfl, nx is {}, num_nan is {} out of {}".format(nx, num_nan, N))

Now that the timestep is sufficiently small, as expected the global stabilization method eliminates NaNs from the final solution. This is a demonstration of our claim that the solution is provably stable (in the time-continuous limit).  

### Demonstrate that Global Stabilization Doesn't Degrade Accuracy

We want to compare four different numerical algorithms for solving the 1D advection equation. We compare: (a) MUSCL (b) Machine Learned (ML) (c) Machine learned with global stabilization and (d) Machine learned with MC limiter. 

Make sure to use "diff_lrs" for the params.

In [None]:
# Average normalized MSE of four solvers over N random initial conditions,
# for every resolution in nxs.
N = 100

mse_muscl = []
mse_ml = []
mse_mlgs = []
mse_mlmc = []

def normalized_mse(traj, traj_exact):
    """Mean over all entries of the squared error, with each row divided by
    the mean square of the matching row of ``traj_exact``."""
    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])


# Snapshot spacing and count; the same for every resolution.
t_inner = 0.1
outer_steps = 10
# Core parameters for the learned-limiter variant do not depend on nx.
core_params_mc = get_core_params(0, flux=Flux.LEARNEDLIMITER)

for i, nx in enumerate(nxs):

    key = jax.random.PRNGKey(10)

    _, params = load_training_params(nx, sim_params, training_params_list[i], model)

    mse_muscl_nx = 0.0
    mse_ml_nx = 0.0
    mse_mlgs_nx = 0.0
    mse_mlmc_nx = 0.0

    # MUSCL
    inner_fn_muscl = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn_muscl = get_trajectory_fn(inner_fn_muscl, outer_steps)

    # Model without GS
    sim_model = AdvectionFVSim(core_params, sim_params, global_stabilization = False, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    # Model with GS
    sim_model_gs = AdvectionFVSim(core_params, sim_params, global_stabilization = True, model=model, params=params)
    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)

    # ML with MC Limiter
    sim_model_mc = AdvectionFVSim(core_params_mc, sim_params, global_stabilization = False, model=model, params=params)
    inner_fn_model_mc = get_inner_fn(sim_model_mc.step_fn, sim_model_mc.dt_fn, t_inner)
    trajectory_fn_model_mc = get_trajectory_fn(inner_fn_model_mc, outer_steps)

    for _run in range(N):

        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)

        trajectory_muscl = trajectory_fn_muscl(a0)
        trajectory_model = trajectory_fn_model(a0)
        trajectory_model_gs = trajectory_fn_model_gs(a0)
        trajectory_model_mc = trajectory_fn_model_mc(a0)

        # Exact trajectory, evaluated at times step * t_inner. The time index
        # has its own name so it no longer reuses the sample-loop variable.
        exact_trajectory = onp.zeros((trajectory_muscl.shape[0],nx))
        for step in range(outer_steps):
            t = step * t_inner
            exact_trajectory[step] = get_a(f_init, t, core_params, nx)

        # Running average over the N samples.
        mse_muscl_nx += normalized_mse(trajectory_muscl, exact_trajectory) / N
        mse_ml_nx += normalized_mse(trajectory_model, exact_trajectory) / N
        mse_mlgs_nx += normalized_mse(trajectory_model_gs, exact_trajectory) / N
        mse_mlmc_nx += normalized_mse(trajectory_model_mc, exact_trajectory) / N

        # Advance the key so the next sample uses a different initial condition.
        key, _ = jax.random.split(key)

    mse_muscl.append(mse_muscl_nx)
    mse_ml.append(mse_ml_nx)
    mse_mlgs.append(mse_mlgs_nx)
    mse_mlmc.append(mse_mlmc_nx)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
# Dump the raw numbers before plotting them.
for series in (mse_muscl, mse_ml, mse_mlgs, mse_mlmc):
    print(series)

fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))
for side in ('top', 'right'):
    axs.spines[side].set_visible(False)
linewidth = 3

# Parallel lists: entry k of each describes curve k.
mses = [mse_ml, mse_mlgs, mse_muscl, mse_mlmc]
labels = ["ML", "ML (Stabilized)", "MUSCL",    "ML MC\nLimiter"]
colors = ["blue", "red", "purple", "green"]
linestyles = ["solid", "dashed", "solid", "solid"]

for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):
    plt.loglog(nxs, mse, label=label, color=color, linewidth=linewidth, linestyle=linestyle)

axs.set_xticks([64, 32, 16, 8])
axs.set_xticklabels(["N=64", "N=32", "N=16", "N=8"], fontsize=18)
axs.set_yticks([1e-3, 1e-2, 1e-1, 1e0])
axs.set_yticklabels(["$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$"], fontsize=18)
axs.minorticks_off()
axs.set_ylabel("Normalized MSE", fontsize=18)
axs.text(0.3, 0.95, '$t=1$', transform=axs.transAxes, fontsize=18, verticalalignment='top')

# Legend entries are built by hand (empty proxy lines) rather than taken from
# the plotted curves.
handles = [
    mlines.Line2D(
        [],
        [],
        color=proxy_color,
        linewidth=linewidth,
        label=proxy_label,
        linestyle=proxy_style,
    )
    for proxy_label, proxy_color, proxy_style in zip(labels, colors, linestyles)
]
axs.legend(handles=handles, loc=(0.655,0.45), prop={'size': 15}, frameon=False)
plt.ylim([2.5e-4, 1e0+1e-1])
fig.tight_layout()

for figure_path in ('mse_vs_nx.png', 'mse_vs_nx.eps'):
    plt.savefig(figure_path)
plt.show()

### Demonstrate that Global Stabilization Improves Accuracy over Time

For nx = 16, plot the accuracy of global stabilization versus ML MC limiter vs ML on the y-axis, with time on the x-axis.

In [None]:
# Average normalized MSE of four solvers at a single resolution, as a function
# of the total simulated time T.
nx = 16
N = 25

Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]

def normalized_mse(traj, traj_exact):
    """Mean over all entries of the squared error, with each row divided by
    the mean square of the matching row of ``traj_exact``."""
    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])

mse_muscl = []
mse_ml = []
mse_mlgs = []
mse_mlmc = []

# Look up the training parameters that belong to this nx instead of
# hard-coding index 0 (which is the nx=8 entry of training_params_list).
_, params = load_training_params(nx, sim_params, training_params_list[nxs.index(nx)], model)

t_inner = 0.1

# The simulators and their inner step functions do not depend on T, so they
# are built once here rather than once per T.

# MUSCL
inner_fn_muscl = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)

# Model without GS
sim_model = AdvectionFVSim(core_params, sim_params, global_stabilization = False, model=model, params=params)
inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)

# Model with GS
sim_model_gs = AdvectionFVSim(core_params, sim_params, global_stabilization = True, model=model, params=params)
inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)

# ML with MC Limiter
core_params_mc = get_core_params(0, flux=Flux.LEARNEDLIMITER)
sim_model_mc = AdvectionFVSim(core_params_mc, sim_params, global_stabilization = False, model=model, params=params)
inner_fn_model_mc = get_inner_fn(sim_model_mc.step_fn, sim_model_mc.dt_fn, t_inner)


for T in Ts:

    # Same starting key for every T, so each T averages over the same
    # sequence of initial conditions.
    key = jax.random.PRNGKey(20)

    # round() guards against the float quotient landing just below an integer.
    outer_steps = int(round(T / t_inner))

    trajectory_fn_muscl = get_trajectory_fn(inner_fn_muscl, outer_steps)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)
    trajectory_fn_model_mc = get_trajectory_fn(inner_fn_model_mc, outer_steps)

    mse_muscl_T = 0.0
    mse_ml_T = 0.0
    mse_mlgs_T = 0.0
    mse_mlmc_T = 0.0

    for _run in range(N):

        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)

        trajectory_muscl = trajectory_fn_muscl(a0)
        trajectory_model = trajectory_fn_model(a0)
        trajectory_model_gs = trajectory_fn_model_gs(a0)
        trajectory_model_mc = trajectory_fn_model_mc(a0)

        # Exact trajectory, evaluated at times step * t_inner. The time index
        # has its own name so it no longer reuses the sample-loop variable.
        exact_trajectory = onp.zeros((outer_steps,nx))
        for step in range(outer_steps):
            t = step * t_inner
            exact_trajectory[step] = get_a(f_init, t, core_params, nx)

        # Running average over the N samples.
        mse_muscl_T += normalized_mse(trajectory_muscl, exact_trajectory) / N
        mse_ml_T += normalized_mse(trajectory_model, exact_trajectory) / N
        mse_mlgs_T += normalized_mse(trajectory_model_gs, exact_trajectory) / N
        mse_mlmc_T += normalized_mse(trajectory_model_mc, exact_trajectory) / N

        # Advance the key so the next sample uses a different initial condition.
        key, _ = jax.random.split(key)

    mse_muscl.append(mse_muscl_T)
    mse_ml.append(mse_ml_T)
    mse_mlgs.append(mse_mlgs_T)
    mse_mlmc.append(mse_mlmc_T)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))
for side in ('top', 'right'):
    axs.spines[side].set_visible(False)
linewidth = 3

# Total simulated times on the x-axis (also used for the tick positions).
Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]

# Parallel lists: entry k of each describes curve k.
mses = [mse_ml, mse_mlgs, mse_muscl, mse_mlmc]
labels = ["ML", "ML (Stabilized)", "MUSCL",    "ML MC\nLimiter"]
colors = ["blue", "red", "purple", "green"]
linestyles = ["solid", "dashed", "solid", "solid"]

for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):
    # NaN errors are replaced by 1e7, which lies above the y-limit set below.
    finite_errors = [jnp.nan_to_num(error, nan=1e7) for error in mse]
    plt.loglog(Ts, finite_errors, label=label, color=color, linewidth=linewidth, linestyle=linestyle)

axs.set_xticks(Ts)
axs.set_xticklabels(["t=1", "2", "5", "10", "20", "50", "t=100"], fontsize=18)
axs.set_yticks([1e-3, 1e-2, 1e-1, 1e0])
axs.set_yticklabels(["$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$"], fontsize=18)
axs.minorticks_off()
axs.set_ylabel("Normalized MSE", fontsize=18)
axs.text(0.15, 0.8, '$N=16$', transform=axs.transAxes, fontsize=18, verticalalignment='top')

# Legend entries are built by hand (empty proxy lines) rather than taken from
# the plotted curves.
handles = [
    mlines.Line2D(
        [],
        [],
        color=proxy_color,
        linewidth=linewidth,
        label=proxy_label,
        linestyle=proxy_style,
    )
    for proxy_label, proxy_color, proxy_style in zip(labels, colors, linestyles)
]
axs.legend(handles=handles, loc=(0.6,0.1), prop={'size': 15}, frameon=False)
plt.ylim([2.5e-4, 1e0+1e-1])
fig.tight_layout()

for figure_path in ('mse_vs_time.png', 'mse_vs_time.eps'):
    plt.savefig(figure_path)
plt.show()