In this Jupyter notebook, we train a machine-learned finite-volume (FV) solver for the 1D advection equation at reduced resolution. The objective is to train several ML models and plot their accuracy relative to other solvers.

In [1]:
# setup paths: locate the project checkout and make its source folders importable
import sys

basedir = '/Users/nickm/thesis/InvariantPreservingMLSolvers/1d_advection'
# Both directories are handed to SimulationParams by get_sim_params below.
readwritedir = '/Users/nickm/thesis/InvariantPreservingMLSolvers/1d_advection'

# Appended in this order: core, simulate, ml.
sys.path.extend([
    '{}/core'.format(basedir),
    '{}/simulate'.format(basedir),
    '{}/ml'.format(basedir),
])

In [2]:
# import external packages
import jax
import jax.numpy as jnp
import numpy as onp
from jax import config, vmap
config.update("jax_enable_x64", True)
import xarray
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
# import internal packages
from flux import Flux
from initialconditions import get_a0, get_initial_condition_fn, get_a
from simparams import CoreParams, SimulationParams
from helper import generate_legendre
from simulations import AdvectionFVSim
from trajectory import get_trajectory_fn, get_inner_fn
from trainingutils import save_training_data
from mlparams import TrainingParams, ModelParams
from model import LearnedFluxOutput
from trainingutils import (get_loss_fn, get_batch_fn, get_idx_gen, train_model, 
                           compute_losses_no_model, init_params, save_training_params, load_training_params)

In [4]:
# helper functions

def plot_fv(a, core_params, color="blue"):
    """Plot a finite-volume state by delegating to ``plot_dg``.

    A trailing axis of length one is appended to ``a`` so that ``plot_dg``
    sees exactly one coefficient per cell.
    """
    a_with_coeff_axis = a[..., None]
    plot_dg(a_with_coeff_axis, core_params, color=color)
    
def plot_fv_trajectory(trajectory, core_params, t_inner, color='blue'):
    """Plot a finite-volume trajectory by delegating to ``plot_dg_trajectory``.

    A trailing axis of length one is appended to ``trajectory`` so that each
    cell carries exactly one coefficient.
    """
    trajectory_with_coeff_axis = trajectory[..., None]
    plot_dg_trajectory(trajectory_with_coeff_axis, core_params, t_inner, color=color)
    
def plot_dg(a, core_params, color='blue'):
    """Plot a single state ``a`` as a piecewise polynomial over the domain.

    ``a`` is indexed as (cell, coefficient). Each cell is sampled at a fixed
    number of points between its left and right edge, the polynomial basis
    returned by ``generate_legendre`` is evaluated there, and the flattened
    samples are drawn as one xarray line.

    The number of samples per cell comes from a 4-entry table indexed by the
    number of coefficients minus one.
    """
    n_coeffs = a.shape[-1]
    samples_per_cell = [2, 2, 5, 7][n_coeffs - 1]
    n_cells = a.shape[0]
    cell_width = core_params.Lx / n_cells

    def eval_cell(x, coeffs, j, dx, leg_poly):
        # Map physical x to the cell-local coordinate centred on cell j.
        centre = dx * (0.5 + j)
        xi = (x - centre) / (0.5 * dx)
        # One polyval per row of leg_poly; results are stacked on the last axis.
        basis = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)
        return jnp.sum(basis * coeffs, axis=-1)

    left_edges = jnp.arange(n_cells) * core_params.Lx / n_cells
    offsets = jnp.linspace(0.0, cell_width, samples_per_cell)
    sample_x = left_edges[None, :] + offsets[:, None]

    # Map over cells: column j of sample_x, row j of a, and the index j itself.
    eval_all_cells = vmap(eval_cell, (1, 0, 0, None, None), 1)
    values = eval_all_cells(
        sample_x, a, jnp.arange(n_cells), cell_width, generate_legendre(n_coeffs)
    )

    # Transpose before flattening so the samples of one cell stay adjacent.
    flat_values = values.T.reshape(-1)
    flat_x = sample_x.T.reshape(-1)
    line = xarray.DataArray(flat_values, coords={'x': flat_x})
    line.plot(color=color)

def plot_dg_trajectory(trajectory, core_params, t_inner, color='blue'):
    """Plot every snapshot of ``trajectory`` in its own panel.

    ``trajectory`` is indexed as (snapshot, cell, coefficient). Each snapshot
    is sampled inside every cell, and the result is drawn as an xarray facet
    grid with one column per time value (wrapped after 5 panels). The time
    coordinate of snapshot ``k`` is ``k * t_inner``.
    """
    n_coeffs = trajectory.shape[-1]
    samples_per_cell = [2, 2, 5, 7][n_coeffs - 1]
    n_cells = trajectory.shape[1]
    cell_width = core_params.Lx / n_cells
    left_edges = jnp.arange(n_cells) * core_params.Lx / n_cells
    sample_x = left_edges[None, :] + jnp.linspace(0.0, cell_width, samples_per_cell)[:, None]

    def eval_cell(x, coeffs, j, dx, leg_poly):
        # Cell-local coordinate centred on cell j.
        centre = dx * (0.5 + j)
        xi = (x - centre) / (0.5 * dx)
        basis = vmap(jnp.polyval, (0, None), -1)(leg_poly, xi)
        return jnp.sum(basis * coeffs, axis=-1)

    def snapshot_to_samples(state):
        # Map over cells, then transpose so the cell index comes first.
        eval_all_cells = vmap(eval_cell, (1, 0, 0, None, None), 1)
        sampled = eval_all_cells(
            sample_x, state, jnp.arange(n_cells), cell_width, generate_legendre(n_coeffs)
        )
        return sampled.T

    sampled_trajectory = vmap(snapshot_to_samples)(trajectory)

    n_snapshots = trajectory.shape[0]
    sampled_trajectory = sampled_trajectory.reshape(n_snapshots, -1)
    times = t_inner * jnp.arange(n_snapshots)
    coords = {'x': sample_x.T.reshape(-1), 'time': times}

    panels = xarray.DataArray(sampled_trajectory, dims=["time", "x"], coords=coords)
    panels.plot(col='time', col_wrap=5, color=color)
    
def plot_multiple_fv_trajectories(trajectories, core_params, t_inner, ylim=(-1.5, 1.5)):
    """Overlay several finite-volume trajectories, one panel per snapshot.

    Each trajectory gets a trailing axis of length one and the list is then
    forwarded to ``plot_multiple_dg_trajectories``.

    ``ylim`` is forwarded unchanged. Its default is an immutable tuple so the
    default object cannot be shared and mutated across calls.
    """
    expanded = [trajectory[..., None] for trajectory in trajectories]
    plot_multiple_dg_trajectories(expanded, core_params, t_inner, ylim=ylim)

def plot_multiple_dg_trajectories(trajectories, core_params, t_inner, ylim=(-1.5, 1.5)):
    """Overlay several trajectories in a facet grid with one panel per snapshot.

    ``trajectories`` is a sequence of arrays indexed as
    (snapshot, cell, coefficient). The snapshot count, cell count and
    coefficient count are all read from the first entry, so every trajectory
    is reshaped using the first one's snapshot count.

    Each trajectory becomes one line per panel (hue ``"stack"``); panels are
    wrapped after 5 columns and share the y-limits ``ylim``. The default
    ``ylim`` is a tuple rather than a list so it cannot be mutated between
    calls.
    """
    outer_steps = trajectories[0].shape[0]
    nx = trajectories[0].shape[1]
    p = trajectories[0].shape[-1]
    NPLOT = [2, 2, 5, 7][p - 1]
    dx = core_params.Lx / nx
    xjs = jnp.arange(nx) * core_params.Lx / nx
    xs = xjs[None, :] + jnp.linspace(0.0, dx, NPLOT)[:, None]

    def get_plot_repr(a):
        def evalf(x, a, j, dx, leg_poly):
            # Cell-local coordinate centred on cell j.
            x_j = dx * (0.5 + j)
            xi = (x - x_j) / (0.5 * dx)
            # One polyval per row of leg_poly, stacked on the last axis.
            vmap_polyval = vmap(jnp.polyval, (0, None), -1)
            poly_eval = vmap_polyval(leg_poly, xi)
            return jnp.sum(poly_eval * a, axis=-1)

        vmap_eval = vmap(evalf, (1, 0, 0, None, None), 1)
        return vmap_eval(xs, a, jnp.arange(nx), dx, generate_legendre(p)).T

    get_trajectory_plot_repr = vmap(get_plot_repr)
    trajectory_plots = [
        get_trajectory_plot_repr(trajectory).reshape(outer_steps, -1)
        for trajectory in trajectories
    ]

    xs = xs.T.reshape(-1)
    coords = {
        'x': xs,
        'time': t_inner * jnp.arange(outer_steps)
    }
    xarray.DataArray(trajectory_plots, dims=["stack", "time", "x"], coords=coords).plot.line(
        col='time', hue="stack", col_wrap=5, ylim=ylim)
    

def get_core_params(flux='upwind'):
    """Build a ``CoreParams`` with domain length 1.0 and the given ``flux``."""
    domain_length = 1.0
    core = CoreParams(domain_length, flux)
    return core

def get_sim_params(name="test", cfl_safety=0.3, rk='ssp_rk3'):
    """Build a ``SimulationParams`` from the arguments and the notebook's directories.

    The module-level ``basedir`` and ``readwritedir`` are passed positionally
    between ``name`` and ``cfl_safety``.
    """
    sim_settings = SimulationParams(name, basedir, readwritedir, cfl_safety, rk)
    return sim_settings

def get_training_params(n_data, train_id="test", batch_size=4, learning_rate=1e-3, num_epochs=10, optimizer='sgd'):
    """Build a ``TrainingParams`` from the given settings.

    The constructor receives the arguments positionally in the order
    ``n_data, num_epochs, train_id, batch_size, learning_rate, optimizer``,
    which differs from this function's own parameter order.
    """
    training_settings = TrainingParams(
        n_data,
        num_epochs,
        train_id,
        batch_size,
        learning_rate,
        optimizer,
    )
    return training_settings

def get_model_params(kernel_size=3, kernel_out=4, depth=3, width=16):
    """Build a ``ModelParams`` from kernel size, kernel output count, depth and width."""
    model_settings = ModelParams(kernel_size, kernel_out, depth, width)
    return model_settings


def l2_norm_trajectory(trajectory):
    """Return the mean of the squared entries along axis 1.

    For a (snapshot, cell) array this gives one value per snapshot. Note that
    no square root is taken, so the result is the mean square rather than a
    norm in the strict sense.
    """
    squared = jnp.square(trajectory)
    return jnp.mean(squared, axis=1)
    
def get_model(core_params, model_params):
    """Construct the ``LearnedFluxOutput`` network described by ``model_params``.

    The feature list holds ``model_params.width`` repeated ``depth - 1``
    times. ``core_params`` is accepted but not used here.
    """
    hidden_features = [model_params.width] * (model_params.depth - 1)
    network = LearnedFluxOutput(
        hidden_features, model_params.kernel_size, model_params.kernel_out
    )
    return network

### Finite Volume

##### Training Loop

First, we will generate the data.

In [5]:
# training hyperparameters

# Initial-condition family and the keyword arguments forwarded with it.
init_description = 'sum_sin'
kwargs_init = dict(min_num_modes=1, max_num_modes=6, min_k=1, max_k=4, amplitude_max=1.0)

# Keyword arguments for get_sim_params and get_model_params.
kwargs_sim = dict(name="paper_test", cfl_safety=0.3, rk='ssp_rk3')
kwargs_model = dict(kernel_size=3, kernel_out=4, depth=3, width=16)

# Data generation: n_runs trajectories, each saved every t_inner_train up to t = 1.
n_runs = 100
t_inner_train = 0.02
outer_steps_train = int(1.0 / t_inner_train)
nx_exact = 512

# Optimisation settings.
BS = 32   # batch size
NE = 100  # num epochs

# One learning rate per coarse resolution; the two lists must line up.
nxs = [8, 16, 32, 64]
learning_rate_list = [1e-3, 1e-3, 1e-4, 1e-4]
assert len(nxs) == len(learning_rate_list)

key = jax.random.PRNGKey(12)


In [6]:
##### Setup for Generating Training Data
# The training data is generated with the MUSCL flux.
core_params_muscl = get_core_params(flux='muscl')
sim_params = get_sim_params(**kwargs_sim)
sim = AdvectionFVSim(core_params_muscl, sim_params)

# Total number of stored snapshots: one per outer step of every run.
n_data = n_runs * outer_steps_train


def init_fn(key):
    # Draw an initial condition of the configured family from the given key.
    return get_initial_condition_fn(core_params_muscl, init_description, key=key, **kwargs_init)

In [7]:
# save training data
# Arguments are passed positionally, in the order save_training_data expects.
save_training_data(
    key,
    init_fn,
    core_params_muscl,
    sim_params,
    sim,
    t_inner_train,
    outer_steps_train,
    n_runs,
    nx_exact,
    nxs,
)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [8]:
##### Setup for training models

# One network architecture and one set of initial weights, shared by all variants.
model_params = get_model_params(**kwargs_model)
model = get_model(core_params_muscl, model_params)
key = jax.random.PRNGKey(42)
i_params = init_params(key, model)

# Variant 1: flux='learned', saved under train_id "flux_predicting".
core_params_learned = get_core_params(flux='learned')
kwargs_train_FV = dict(train_id="flux_predicting", batch_size=BS, optimizer='adam', num_epochs=NE)
training_params_list_learned = [
    get_training_params(n_data, learning_rate=lr, **kwargs_train_FV)
    for lr in learning_rate_list
]

# Variant 2: flux='learnedlimiter', saved under train_id "flux_limited".
core_params_limiter = get_core_params(flux='learnedlimiter')
kwargs_train_FV = dict(train_id="flux_limited", batch_size=BS, optimizer='adam', num_epochs=NE)
training_params_list_limited = [
    get_training_params(n_data, learning_rate=lr, **kwargs_train_FV)
    for lr in learning_rate_list
]

# Variant 3: flux='combination_learned', saved under train_id "combo_learned".
core_params_combo = get_core_params(flux='combination_learned')
kwargs_train_FV = dict(train_id="combo_learned", batch_size=BS, optimizer='adam', num_epochs=NE)
training_params_list_combo = [
    get_training_params(n_data, learning_rate=lr, **kwargs_train_FV)
    for lr in learning_rate_list
]


Next, we run a training loop for each value of nx. The learning rate undergoes a prespecified decay.

In [None]:
#### First, train original ML Model

# Every resolution starts from the same initial weights and the same key.
for i, nx in enumerate(nxs):
    print(nx)
    training_params = training_params_list_learned[i]

    def idx_fn(key):
        # Index generator for the current resolution's training settings.
        return get_idx_gen(key, training_params)

    batch_fn = get_batch_fn(core_params_learned, sim_params, training_params, nx)
    loss_fn = get_loss_fn(model, core_params_learned)
    losses, params = train_model(
        model, i_params, training_params, key, idx_fn, batch_fn, loss_fn
    )
    save_training_params(nx, sim_params, training_params, params, losses)

8
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85


In [None]:
#### Second, train flux-limited model

# Every resolution starts from the same initial weights and the same key.
for i, nx in enumerate(nxs):
    print(nx)
    training_params = training_params_list_limited[i]

    def idx_fn(key):
        # Index generator for the current resolution's training settings.
        return get_idx_gen(key, training_params)

    batch_fn = get_batch_fn(core_params_limiter, sim_params, training_params, nx)
    loss_fn = get_loss_fn(model, core_params_limiter)
    losses, params = train_model(
        model, i_params, training_params, key, idx_fn, batch_fn, loss_fn
    )
    save_training_params(nx, sim_params, training_params, params, losses)

In [None]:
#### Third, train combination model

# Every resolution starts from the same initial weights and the same key.
for i, nx in enumerate(nxs):
    print(nx)
    training_params = training_params_list_combo[i]

    def idx_fn(key):
        # Index generator for the current resolution's training settings.
        return get_idx_gen(key, training_params)

    batch_fn = get_batch_fn(core_params_combo, sim_params, training_params, nx)
    loss_fn = get_loss_fn(model, core_params_combo)
    losses, params = train_model(
        model, i_params, training_params, key, idx_fn, batch_fn, loss_fn
    )
    save_training_params(nx, sim_params, training_params, params, losses)

Next, we load and plot the losses for each nx to check that the models trained properly.

In [None]:
# One loss curve per resolution for the flux-predicting model.
for i, nx in enumerate(nxs):
    learned_training_params = training_params_list_learned[i]
    losses, _ = load_training_params(nx, sim_params, learned_training_params, model)
    plt.plot(losses, label="learned {}".format(nx))

    # Uncomment to overlay the flux-limited and combination models.
    #losses, _ = load_training_params(nx, sim_params, training_params_list_limited[i], model)
    #plt.plot(losses, label="limited {}".format(nx))

    #losses, _ = load_training_params(nx, sim_params, training_params_list_combo[i], model)
    #plt.plot(losses, label="combination {}".format(nx))

# Clip the y-axis to the range 0..100.
plt.ylim([0, 100])
plt.legend()
plt.show()

Next, we plot the accuracy of the trained model on a few simple test cases to qualitatively evaluate the success of the training. We will eventually quantify the accuracy of the trained model.

In [None]:
# Qualitative check: for every resolution, integrate one initial condition with
# several fluxes and plot the results against the exact solution.
# The key is not split inside the loop, so each nx is given the same key.
key = jax.random.PRNGKey(19)
sim_params = get_sim_params(**kwargs_sim)

for i, nx in enumerate(nxs):
    print("nx is {}".format(nx))
    
    f_init = get_initial_condition_fn(core_params_muscl, init_description, key=key, **kwargs_init)
    a0 = get_a0(f_init, core_params_muscl, nx)
    # Snapshots are taken at t = 0, 1, ..., 4.
    t_inner = 1.0
    outer_steps = 5
    
    
    ########
    # Exact trajectory
    ########
    
    # Filled in a NumPy buffer first, then converted to a JAX array.
    trajectory_exact = onp.zeros((outer_steps, nx))
    for k in range(outer_steps):
        t = k * t_inner
        trajectory_exact[k] = get_a(f_init, t, core_params_muscl, nx)
    trajectory_exact = jnp.asarray(trajectory_exact)
    
    # NOTE: trajectory_centered, trajectory_upwind and trajectory_muscl are
    # computed below but are not part of the `trajectories` list that is plotted.
    
    ########
    # Flux 1: Centered
    ########
    core_params = get_core_params(flux='centered')
    sim = AdvectionFVSim(core_params, sim_params)
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory_centered = trajectory_fn(a0)
    
    ########
    # Flux 2: Upwind
    ########
    core_params = get_core_params(flux='upwind')
    sim = AdvectionFVSim(core_params, sim_params)
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory_upwind = trajectory_fn(a0)
    
    ########
    # Flux 3: MUSCL
    ########
    core_params = get_core_params(flux='muscl')
    sim = AdvectionFVSim(core_params, sim_params)
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory_muscl = trajectory_fn(a0)
    
    ########
    # Flux 4: Learned
    ########
    # Uses the parameters saved for this nx by the flux-predicting training run.
    core_params = get_core_params(flux='learned')
    _, params = load_training_params(nx, sim_params, training_params_list_learned[i], model)
    sim = AdvectionFVSim(core_params, sim_params, model=model, params=params)
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory_learned = trajectory_fn(a0)
    
    # The triple-quoted string below is disabled code for fluxes 5 and 6.
    # NOTE(review): the disabled Flux 6 block loads training_params_list_learned
    # rather than training_params_list_combo; confirm before re-enabling it.
    """
    
    ########
    # Flux 5: Learned Limiter
    ########
    core_params = get_core_params(flux='learnedlimiter')
    _, params = load_training_params(nx, sim_params, training_params_list_limited[i], model)
    sim = AdvectionFVSim(core_params, sim_params, model=model, params=params)
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory_limiter = trajectory_fn(a0)
    
    ########
    # Flux 6: Upwind + Centered
    ########
    core_params = get_core_params(flux='combination_learned')
    _, params = load_training_params(nx, sim_params, training_params_list_learned[i], model)
    sim = AdvectionFVSim(core_params, sim_params, model=model, params=params)
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory_combination = trajectory_fn(a0)
    """
    
    ########
    # Flux 7: Invariant-Preserving Learned
    ########
    # Same flux and saved parameters as Flux 4, but with global_stabilization=True.
    core_params = get_core_params(flux='learned')
    _, params = load_training_params(nx, sim_params, training_params_list_learned[i], model)
    sim = AdvectionFVSim(core_params, sim_params, model=model, params=params, global_stabilization=True)
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory_invariant_learned = trajectory_fn(a0)
    
    
    
    # Only these three trajectories are plotted; core_params is the one assigned
    # last above (flux='learned').
    trajectories = [trajectory_exact, trajectory_learned, trajectory_invariant_learned]  #, trajectory_upwind]#, trajectory_centered, trajectory_muscl]#, trajectory_limiter, trajectory_combination, trajectory_invariant_learned]
    plot_multiple_fv_trajectories(trajectories, core_params, t_inner)
    

    
    plt.show()
    
    # Second figure: per-snapshot mean square of each plotted trajectory.
    for trajectory in trajectories:
        plt.plot(l2_norm_trajectory(trajectory))
    plt.show()

### Demonstrate that Global Stabilization Doesn't Degrade Accuracy

We compare four numerical algorithms for solving the 1D advection equation: (a) MUSCL, (b) machine learned (ML), (c) machine learned with global stabilization, and (d) machine learned with the MC limiter.

Make sure to use "diff_lrs" for the params.

In [None]:
# Normalized MSE at t in [0, 1) for each resolution, averaged over N random
# initial conditions, for MUSCL, ML, ML with global stabilization, and ML with
# the limiter flux. The results are collected in mse_muscl, mse_ml, mse_mlgs
# and mse_mlmc (one entry per nx), which the plotting cell below reads.
N = 100


def normalized_mse(traj, traj_exact):
    """Mean squared error, with each snapshot scaled by the exact mean square.

    ``traj_exact`` must be 2-D (snapshot, cell); the normalisation is the mean
    of ``traj_exact**2`` along axis 1.
    """
    assert len(traj_exact.shape) == 2
    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])


mse_muscl = []
mse_ml = []
mse_mlgs = []
mse_mlmc = []

t_inner = 0.1
outer_steps = 10

# These do not depend on nx, so they are built once.
# NOTE(review): flux strings mirror the earlier cells of this notebook
# ('muscl', 'learned', 'learnedlimiter'); confirm they are the intended ones.
sim_muscl = AdvectionFVSim(core_params_muscl, sim_params)
core_params_mc = get_core_params(flux='learnedlimiter')

for i, nx in enumerate(nxs):

    # Same key for every nx, so all resolutions see the same initial conditions.
    key = jax.random.PRNGKey(10)

    # Parameters trained for this resolution by the flux-predicting run.
    _, params = load_training_params(nx, sim_params, training_params_list_learned[i], model)

    # MUSCL
    inner_fn_muscl = get_inner_fn(sim_muscl.step_fn, sim_muscl.dt_fn, t_inner)
    trajectory_fn_muscl = get_trajectory_fn(inner_fn_muscl, outer_steps)

    # Model without GS
    sim_model = AdvectionFVSim(core_params_learned, sim_params, global_stabilization=False, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    # Model with GS
    sim_model_gs = AdvectionFVSim(core_params_learned, sim_params, global_stabilization=True, model=model, params=params)
    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)

    # ML with MC Limiter
    sim_model_mc = AdvectionFVSim(core_params_mc, sim_params, global_stabilization=False, model=model, params=params)
    inner_fn_model_mc = get_inner_fn(sim_model_mc.step_fn, sim_model_mc.dt_fn, t_inner)
    trajectory_fn_model_mc = get_trajectory_fn(inner_fn_model_mc, outer_steps)

    # Running averages over the N initial conditions for this nx.
    mse_muscl_nx = 0.0
    mse_ml_nx = 0.0
    mse_mlgs_nx = 0.0
    mse_mlmc_nx = 0.0

    for _ in range(N):

        f_init = get_initial_condition_fn(core_params_muscl, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params_muscl, nx)

        trajectory_muscl = trajectory_fn_muscl(a0)
        trajectory_model = trajectory_fn_model(a0)
        trajectory_model_gs = trajectory_fn_model_gs(a0)
        trajectory_model_mc = trajectory_fn_model_mc(a0)

        # Exact trajectory for this initial condition, one row per snapshot.
        exact_trajectory = onp.zeros((outer_steps, nx))
        for step in range(outer_steps):
            exact_trajectory[step] = get_a(f_init, step * t_inner, core_params_muscl, nx)

        mse_muscl_nx += normalized_mse(trajectory_muscl, exact_trajectory) / N
        mse_ml_nx += normalized_mse(trajectory_model, exact_trajectory) / N
        mse_mlgs_nx += normalized_mse(trajectory_model_gs, exact_trajectory) / N
        mse_mlmc_nx += normalized_mse(trajectory_model_mc, exact_trajectory) / N

        # Advance the key so the next sample gets a different initial condition.
        key, _ = jax.random.split(key)

    mse_muscl.append(mse_muscl_nx)
    mse_ml.append(mse_ml_nx)
    mse_mlgs.append(mse_mlgs_nx)
    mse_mlmc.append(mse_mlmc_nx)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

# Dump the raw numbers before plotting them.
for series in (mse_muscl, mse_ml, mse_mlgs, mse_mlmc):
    print(series)

fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))
for side in ('top', 'right'):
    axs.spines[side].set_visible(False)
linewidth = 3

# Parallel lists: entry k of each describes curve k.
mses = [mse_ml, mse_mlgs, mse_muscl, mse_mlmc]
labels = ["ML", "ML (Stabilized)", "MUSCL",    "ML MC\nLimiter"]
colors = ["blue", "red", "purple", "green"]
linestyles = ["solid", "dashed", "solid", "solid"]

for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):
    plt.loglog(nxs, mse, label=label, color=color, linewidth=linewidth, linestyle=linestyle)

# Fixed tick positions and labels on both log axes.
axs.set_xticks([64, 32, 16, 8])
axs.set_xticklabels(["N=64", "N=32", "N=16", "N=8"], fontsize=18)
axs.set_yticks([1e-3, 1e-2, 1e-1, 1e0])
axs.set_yticklabels(["$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$"], fontsize=18)
axs.minorticks_off()
axs.set_ylabel("Normalized MSE", fontsize=18)
axs.text(0.3, 0.95, '$t=1$', transform=axs.transAxes, fontsize=18, verticalalignment='top')

# Legend entries are built by hand from empty proxy lines.
handles = [
    mlines.Line2D([], [], color=color, linewidth=linewidth, label=label, linestyle=linestyle)
    for label, color, linestyle in zip(labels, colors, linestyles)
]
axs.legend(handles=handles, loc=(0.655, 0.45), prop={'size': 15}, frameon=False)
plt.ylim([2.5e-4, 1e0+1e-1])
fig.tight_layout()


#plt.savefig('mse_vs_nx.png')
#plt.savefig('mse_vs_nx.eps')
plt.show()

### Demonstrate that Global Stabilization Improves Accuracy over Time

For nx = 16, plot the normalized MSE (y-axis) against simulation time (x-axis) for MUSCL, the ML solver, the ML solver with global stabilization, and the ML solver with the MC limiter.

In [None]:
# Normalized MSE at nx = 16 as a function of the simulated time T, averaged
# over N random initial conditions. The results are collected in mse_muscl,
# mse_ml, mse_mlgs and mse_mlmc (one entry per T), which the plotting cell
# below reads.
nx = 16
N = 25

Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]


def normalized_mse(traj, traj_exact):
    """Mean squared error, with each snapshot scaled by the exact mean square."""
    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])


mse_muscl = []
mse_ml = []
mse_mlgs = []
mse_mlmc = []

# Parameters trained for this nx; the list is ordered like nxs.
_, params = load_training_params(nx, sim_params, training_params_list_learned[nxs.index(nx)], model)

t_inner = 0.1

# None of the simulators depends on T, so they are built once.
# NOTE(review): flux strings mirror the earlier cells of this notebook
# ('muscl', 'learned', 'learnedlimiter'); confirm they are the intended ones.
sim_muscl = AdvectionFVSim(core_params_muscl, sim_params)
sim_model = AdvectionFVSim(core_params_learned, sim_params, global_stabilization=False, model=model, params=params)
sim_model_gs = AdvectionFVSim(core_params_learned, sim_params, global_stabilization=True, model=model, params=params)
core_params_mc = get_core_params(flux='learnedlimiter')
sim_model_mc = AdvectionFVSim(core_params_mc, sim_params, global_stabilization=False, model=model, params=params)

for T in Ts:

    # Same key for every T, so all horizons see the same initial conditions.
    key = jax.random.PRNGKey(20)

    # round() guards against T / t_inner landing just below an integer.
    outer_steps = int(round(T / t_inner))

    # MUSCL
    inner_fn_muscl = get_inner_fn(sim_muscl.step_fn, sim_muscl.dt_fn, t_inner)
    trajectory_fn_muscl = get_trajectory_fn(inner_fn_muscl, outer_steps)

    # Model without GS
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    # Model with GS
    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)

    # ML with MC Limiter
    inner_fn_model_mc = get_inner_fn(sim_model_mc.step_fn, sim_model_mc.dt_fn, t_inner)
    trajectory_fn_model_mc = get_trajectory_fn(inner_fn_model_mc, outer_steps)

    # Running averages over the N initial conditions for this T.
    mse_muscl_T = 0.0
    mse_ml_T = 0.0
    mse_mlgs_T = 0.0
    mse_mlmc_T = 0.0

    for _ in range(N):

        f_init = get_initial_condition_fn(core_params_muscl, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params_muscl, nx)

        trajectory_muscl = trajectory_fn_muscl(a0)
        trajectory_model = trajectory_fn_model(a0)
        trajectory_model_gs = trajectory_fn_model_gs(a0)
        trajectory_model_mc = trajectory_fn_model_mc(a0)

        # Exact trajectory for this initial condition, one row per snapshot.
        exact_trajectory = onp.zeros((outer_steps, nx))
        for step in range(outer_steps):
            exact_trajectory[step] = get_a(f_init, step * t_inner, core_params_muscl, nx)

        mse_muscl_T += normalized_mse(trajectory_muscl, exact_trajectory) / N
        mse_ml_T += normalized_mse(trajectory_model, exact_trajectory) / N
        mse_mlgs_T += normalized_mse(trajectory_model_gs, exact_trajectory) / N
        mse_mlmc_T += normalized_mse(trajectory_model_mc, exact_trajectory) / N

        # Advance the key so the next sample gets a different initial condition.
        key, _ = jax.random.split(key)

    mse_muscl.append(mse_muscl_T)
    mse_ml.append(mse_ml_T)
    mse_mlgs.append(mse_mlgs_T)
    mse_mlmc.append(mse_mlmc_T)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))
for side in ('top', 'right'):
    axs.spines[side].set_visible(False)
linewidth = 3

Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]

# Parallel lists: entry k of each describes curve k.
mses = [mse_ml, mse_mlgs, mse_muscl, mse_mlmc]
labels = ["ML", "ML (Stabilized)", "MUSCL",    "ML MC\nLimiter"]
colors = ["blue", "red", "purple", "green"]
linestyles = ["solid", "dashed", "solid", "solid"]

for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):
    # NaN errors are replaced by 1e7, which lies above the y-limit set below.
    finite_errors = [jnp.nan_to_num(error, nan=1e7) for error in mse]
    plt.loglog(Ts, finite_errors, label=label, color=color, linewidth=linewidth, linestyle=linestyle)

# Fixed tick positions and labels on both log axes.
axs.set_xticks(Ts)
axs.set_xticklabels(["t=1", "2", "5", "10", "20", "50", "t=100"], fontsize=18)
axs.set_yticks([1e-3, 1e-2, 1e-1, 1e0])
axs.set_yticklabels(["$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$"], fontsize=18)
axs.minorticks_off()
axs.set_ylabel("Normalized MSE", fontsize=18)
axs.text(0.15, 0.8, '$N=16$', transform=axs.transAxes, fontsize=18, verticalalignment='top')

# Legend entries are built by hand from empty proxy lines.
handles = [
    mlines.Line2D([], [], color=color, linewidth=linewidth, label=label, linestyle=linestyle)
    for label, color, linestyle in zip(labels, colors, linestyles)
]
axs.legend(handles=handles, loc=(0.6, 0.1), prop={'size': 15}, frameon=False)
plt.ylim([2.5e-4, 1e0+1e-1])
fig.tight_layout()


#plt.savefig('mse_vs_time.png')
#plt.savefig('mse_vs_time.eps')
plt.show()