# DO NOT CHANGE, RAN TO COMPLETION AS PART OF SUBMISSION


In this Jupyter notebook, we train a machine-learned discontinuous Galerkin (DG) solver for the 1D advection equation at reduced resolution. Our objective is to measure how often the learned solver becomes unstable, and to demonstrate that global stabilization eliminates this instability.

In [None]:
# Set up paths: make the project's core/, simulate/ and ml/ directories importable.
import sys

basedir = '/Users/nickm/thesis/icml2023paper/1d_advection'
readwritedir = '/Users/nickm/thesis/icml2023paper/1d_advection'

for _subdir in ('core', 'simulate', 'ml'):
    sys.path.append(f'{basedir}/{_subdir}')

In [None]:
# import external packages
import jax
import jax.numpy as jnp
import numpy as onp
from jax import config, vmap
config.update("jax_enable_x64", True)
import xarray
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# import internal packages
from flux import Flux
from initialconditions import get_a0, get_initial_condition_fn, get_a
from simparams import CoreParams, CoreParamsDG, SimulationParams
from legendre import generate_legendre
from simulations import AdvectionDGSim
from trajectory import get_trajectory_fn, get_inner_fn
from trainingutils import save_training_data
from mlparams import TrainingParams, StencilParams
from model import LearnedStencil
from trainingutils import (get_loss_fn, get_batch_fn, get_idx_gen, train_model, 
                           compute_losses_no_model, init_params, save_training_params, load_training_params)
from helper import convert_DG_representation

In [None]:
# helper functions

def plot_fv(a, core_params, color="blue"):
    """Plot a finite-volume solution by handing it to ``plot_dg``.

    A trailing length-1 axis is appended so each cell carries exactly one
    basis coefficient.
    """
    single_coeff = a[..., None]
    plot_dg(single_coeff, core_params, color=color)
    
def plot_fv_trajectory(trajectory, core_params, t_inner, color='blue'):
    """Plot a finite-volume trajectory by handing it to ``plot_dg_trajectory``.

    A trailing length-1 axis is appended so each cell carries exactly one
    basis coefficient.
    """
    single_coeff_traj = trajectory[..., None]
    plot_dg_trajectory(single_coeff_traj, core_params, t_inner, color=color)
    
def plot_dg(a, core_params, color='blue'):
    """Plot one DG solution as a piecewise polynomial over ``[0, Lx]``.

    Each cell's coefficients are evaluated at evenly spaced points across the
    cell (2, 2, 5 or 7 points for 1-4 basis functions, respectively). The
    samples from all cells are concatenated into a single curve.

    Args:
        a: coefficient array whose axis 0 indexes cells and whose last axis
            indexes basis coefficients.
        core_params: object exposing ``order`` (None means one basis function)
            and the domain length ``Lx``.
        color: matplotlib line color.
    """
    num_basis = 1 if core_params.order is None else core_params.order + 1

    def eval_cell(x, coeffs, cell_index, width, legendre_polys):
        # Map x into the cell's reference coordinate xi in [-1, 1].
        center = width * (0.5 + cell_index)
        xi = (x - center) / (0.5 * width)
        # Evaluate every basis polynomial at every xi; basis index is the last axis.
        basis_values = vmap(jnp.polyval, (0, None), -1)(legendre_polys, xi)
        return jnp.sum(basis_values * coeffs, axis=-1)

    points_per_cell = (2, 2, 5, 7)[num_basis - 1]
    num_cells = a.shape[0]
    cell_width = core_params.Lx / num_cells
    left_edges = jnp.arange(num_cells) * core_params.Lx / num_cells
    # Sample grid of shape (points_per_cell, num_cells).
    x_grid = left_edges[None, :] + jnp.linspace(0.0, cell_width, points_per_cell)[:, None]
    eval_all_cells = vmap(eval_cell, (1, 0, 0, None, None), 1)

    values = eval_all_cells(x_grid, a, jnp.arange(num_cells), cell_width, generate_legendre(num_basis))
    flat_values = values.T.reshape(-1)
    flat_x = x_grid.T.reshape(-1)
    xarray.DataArray(flat_values, coords={'x': flat_x}).plot(color=color)

def plot_dg_trajectory(trajectory, core_params, t_inner, color='blue'):
    """Plot every saved frame of a DG trajectory in its own panel (5 per row).

    Args:
        trajectory: coefficient array with axis 0 indexing saved frames,
            axis 1 indexing cells, and the last axis indexing basis
            coefficients.
        core_params: object exposing ``order`` (None means one basis function)
            and the domain length ``Lx``.
        t_inner: time between consecutive frames; frame ``k`` is labelled
            ``t_inner * k``.
        color: matplotlib line color used in every panel.
    """
    num_basis = 1 if core_params.order is None else core_params.order + 1
    points_per_cell = (2, 2, 5, 7)[num_basis - 1]
    num_cells = trajectory.shape[1]
    cell_width = core_params.Lx / num_cells
    left_edges = jnp.arange(num_cells) * core_params.Lx / num_cells
    # Sample grid of shape (points_per_cell, num_cells).
    x_grid = left_edges[None, :] + jnp.linspace(0.0, cell_width, points_per_cell)[:, None]

    def eval_cell(x, coeffs, cell_index, width, legendre_polys):
        # Map x into the cell's reference coordinate xi in [-1, 1].
        center = width * (0.5 + cell_index)
        xi = (x - center) / (0.5 * width)
        basis_values = vmap(jnp.polyval, (0, None), -1)(legendre_polys, xi)
        return jnp.sum(basis_values * coeffs, axis=-1)

    def frame_to_samples(frame):
        eval_all_cells = vmap(eval_cell, (1, 0, 0, None, None), 1)
        cell_samples = eval_all_cells(x_grid, frame, jnp.arange(num_cells), cell_width,
                                      generate_legendre(num_basis))
        return cell_samples.T

    num_frames = trajectory.shape[0]
    samples = vmap(frame_to_samples)(trajectory).reshape(num_frames, -1)
    coords = {
        'x': x_grid.T.reshape(-1),
        'time': t_inner * jnp.arange(num_frames),
    }
    xarray.DataArray(samples, dims=["time", "x"], coords=coords).plot(
        col='time', col_wrap=5, color=color)
    
def plot_multiple_fv_trajectories(trajectories, core_params, t_inner):
    """Overlay several finite-volume trajectories via ``plot_multiple_dg_trajectories``.

    Each trajectory gets a trailing length-1 axis so it has one basis
    coefficient per cell.
    """
    lifted = [traj[..., None] for traj in trajectories]
    plot_multiple_dg_trajectories(lifted, core_params, t_inner)

def plot_multiple_dg_trajectories(trajectories, core_params, t_inner):
    """Overlay several DG trajectories frame by frame (one panel per frame, 5 per row).

    Trajectories are distinguished by the ``stack`` hue, in list order. The
    number of frames and cells is read from ``trajectories[0]``.

    Args:
        trajectories: sequence of coefficient arrays, each with axis 0
            indexing saved frames, axis 1 indexing cells, and the last axis
            indexing basis coefficients.
        core_params: object exposing ``order`` (None means one basis function)
            and the domain length ``Lx``.
        t_inner: time between consecutive frames.
    """
    first = trajectories[0]
    num_frames, num_cells = first.shape[0], first.shape[1]

    num_basis = 1 if core_params.order is None else core_params.order + 1
    points_per_cell = (2, 2, 5, 7)[num_basis - 1]
    cell_width = core_params.Lx / num_cells
    left_edges = jnp.arange(num_cells) * core_params.Lx / num_cells
    # Sample grid of shape (points_per_cell, num_cells).
    x_grid = left_edges[None, :] + jnp.linspace(0.0, cell_width, points_per_cell)[:, None]

    def eval_cell(x, coeffs, cell_index, width, legendre_polys):
        # Map x into the cell's reference coordinate xi in [-1, 1].
        center = width * (0.5 + cell_index)
        xi = (x - center) / (0.5 * width)
        basis_values = vmap(jnp.polyval, (0, None), -1)(legendre_polys, xi)
        return jnp.sum(basis_values * coeffs, axis=-1)

    def frame_to_samples(frame):
        eval_all_cells = vmap(eval_cell, (1, 0, 0, None, None), 1)
        cell_samples = eval_all_cells(x_grid, frame, jnp.arange(num_cells), cell_width,
                                      generate_legendre(num_basis))
        return cell_samples.T

    frames_to_samples = vmap(frame_to_samples)
    stacked = [frames_to_samples(traj).reshape(num_frames, -1) for traj in trajectories]

    coords = {
        'x': x_grid.T.reshape(-1),
        'time': t_inner * jnp.arange(num_frames),
    }
    xarray.DataArray(stacked, dims=["stack", "time", "x"], coords=coords).plot.line(
        col='time', hue="stack", col_wrap=5)
    

def get_core_params(order, flux='upwind'):
    """Build core parameters on a domain of length ``Lx = 1.0``.

    ``order == 0`` returns a ``CoreParams`` (no order field passed). Any other
    order returns a ``CoreParamsDG`` carrying that order.
    """
    Lx = 1.0
    if order != 0:
        return CoreParamsDG(Lx, flux, order)
    return CoreParams(Lx, flux)

def get_sim_params(name="test", cfl_safety=0.3, rk='ssp_rk3'):
    """Build ``SimulationParams`` using the notebook-level ``basedir`` and ``readwritedir``."""
    directories = (basedir, readwritedir)
    return SimulationParams(name, *directories, cfl_safety, rk)

def get_training_params(n_data, train_id="test", batch_size=4, learning_rate=1e-3, num_epochs=10, optimizer='sgd'):
    """Build ``TrainingParams``; note it takes ``num_epochs`` second, positionally."""
    positional = (n_data, num_epochs, train_id, batch_size, learning_rate, optimizer)
    return TrainingParams(*positional)

def get_stencil_params(kernel_size=3, kernel_out=4, stencil_width=4, depth=3, width=16):
    """Build ``StencilParams`` from the stencil-network hyperparameters."""
    positional = (kernel_size, kernel_out, stencil_width, depth, width)
    return StencilParams(*positional)


def l2_norm_trajectory_fv(trajectory):
    """Return the mean of ``trajectory**2`` over axis 1 (the cell axis).

    Despite the name, no square root is taken: this is the cell-averaged
    squared value for each frame.
    """
    squared = trajectory**2
    return squared.mean(axis=1)

def l2_norm_trajectory_dg(trajectory, p=None):
    """Return the cell-averaged squared L2 norm of a DG trajectory at each frame.

    For coefficients a_k in the standard Legendre basis on the reference cell
    [-1, 1], the cell average of u**2 is sum_k a_k**2 / (2k + 1). The number
    of coefficients is taken from the last axis of ``trajectory``. This keeps
    the weights consistent whether callers pass the polynomial order or the
    number of coefficients as ``p``. (Previously ``p`` = order produced a
    length-1 weight vector that broadcast silently and weighted every
    coefficient by 1.)

    Args:
        trajectory: array with axis 0 indexing frames, axis 1 indexing cells,
            and the last axis indexing Legendre coefficients.
        p: accepted for backward compatibility and not used; the coefficient
            count is read from ``trajectory.shape[-1]``.

    Returns:
        Array with one value per frame (no square root taken).
    """
    num_coeffs = trajectory.shape[-1]
    twokplusone = 2 * jnp.arange(num_coeffs) + 1
    return jnp.mean(jnp.sum(trajectory**2 / twokplusone, axis=-1), axis=1)
    
def get_model(core_params, stencil_params):
    """Build a ``LearnedStencil`` sized for the basis implied by ``core_params.order``.

    ``features`` is a list of ``depth - 1`` entries, each equal to
    ``stencil_params.width``. The basis count is 1 when ``order`` is None,
    otherwise ``order + 1``.
    """
    num_basis = 1 if core_params.order is None else core_params.order + 1
    features = [stencil_params.width] * (stencil_params.depth - 1)
    return LearnedStencil(
        features,
        stencil_params.kernel_size,
        stencil_params.kernel_out,
        stencil_params.stencil_width,
        num_basis,
    )

### Discontinuous Galerkin

##### Training Loop

First, we generate the training data.

In [None]:
# training hyperparameters
init_description = 'sum_sin'
kwargs_init = {
    'min_num_modes': 1,
    'max_num_modes': 6,
    'min_k': 0,
    'max_k': 3,
    'amplitude_max': 1.0,
}
kwargs_sim = {'name': "dg_paper_data", 'cfl_safety': 0.3, 'rk': 'ssp_rk3'}
kwargs_train_DG = {
    'train_id': "dg_paper_train",
    'batch_size': 8,
    'optimizer': 'adam',
    'num_epochs': 50,
}
kwargs_stencil = {
    'kernel_size': 3,
    'kernel_out': 4,
    'stencil_width': 4,
    'depth': 3,
    'width': 16,
}

# data generation
n_runs = 100
t_inner_train = 0.02
outer_steps_train = int(1.0/t_inner_train)  # outer steps of t_inner_train covering t = 1
dg_flux_baseline = 'upwind'
nx_exact = 128  # resolution used for the high-resolution reference solution
nxs = [8, 16, 32]
learning_rate_list = [1e-3, 1e-3, 1e-3]  # one learning rate per entry of nxs
assert len(nxs) == len(learning_rate_list)
key = jax.random.PRNGKey(12)

p = 1  # DG polynomial order; the model carries p + 1 coefficients per cell

# setup
core_params = get_core_params(p, flux=dg_flux_baseline)
sim_params = get_sim_params(**kwargs_sim)
n_data = n_runs * outer_steps_train
training_params_list = [
    get_training_params(n_data, learning_rate=lr, **kwargs_train_DG)
    for lr in learning_rate_list
]
stencil_params = get_stencil_params(**kwargs_stencil)
sim = AdvectionDGSim(core_params, sim_params)


def init_fn(key):
    """Return the initial-condition function drawn with PRNG ``key``."""
    return get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)


model = get_model(core_params, stencil_params)

In [None]:
# save training data
# NOTE(review): `key` here is PRNGKey(12) from the hyperparameter cell above.
save_training_data(
    key, init_fn, core_params, sim_params, sim,
    t_inner_train, outer_steps_train, n_runs,
    nx_exact, nxs,
)

Next, we initialize the model parameters.

In [None]:
# Initialize the model parameters from a fixed seed.
# NOTE(review): `key` keeps this value and is also passed to train_model in
# the training-loop cell below.
key = jax.random.PRNGKey(seed=42)
i_params = init_params(key, model)

Next, we run a training loop for each value of nx. The learning rate undergoes a prespecified decay.

In [None]:
# Train one model per resolution; every run starts from the same `i_params`.
for i, (nx, training_params) in enumerate(zip(nxs, training_params_list)):
    print(nx)

    def idx_fn(key):
        return get_idx_gen(key, training_params)

    batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)
    loss_fn = get_loss_fn(model, core_params)
    losses, params = train_model(model, i_params, training_params, key, idx_fn, batch_fn, loss_fn)
    save_training_params(nx, sim_params, training_params, params, losses)

Next, we load and plot the losses for each nx to check that each model trained properly.

In [None]:
# Reload each resolution's saved loss history and overlay the curves.
for nx, training_params in zip(nxs, training_params_list):
    losses, _ = load_training_params(nx, sim_params, training_params, model)
    plt.plot(losses, label=nx)
    print(losses)

plt.ylim([0, 5])
plt.legend()
plt.show()

Next, we plot the trained model's solutions on a few simple test cases to qualitatively evaluate the training. We quantify the accuracy of the trained model later in this notebook.

In [None]:
# Qualitative check on one hand-picked initial condition (seed 19).
key = jax.random.PRNGKey(19)
t_inner = 1.0
outer_steps = 10

for i, nx in enumerate(nxs):
    print("nx is {}".format(nx))

    _, params = load_training_params(nx, sim_params, training_params_list[i], model)

    f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
    a0 = get_a0(f_init, core_params, nx)

    # Learned model without stabilization.
    sim_model = AdvectionDGSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_model = trajectory_fn_model(a0)

    # Learned model with global stabilization.
    sim_model_gs = AdvectionDGSim(core_params, sim_params, global_stabilization=True, model=model, params=params)
    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)
    trajectory_model_gs = trajectory_fn_model_gs(a0)

    # Baseline solver without learned parameters.
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory = trajectory_fn(a0)

    # Overlay all three solvers (stack 0: baseline, 1: ML, 2: ML + global stabilization).
    plot_multiple_dg_trajectories([trajectory, trajectory_model, trajectory_model_gs], core_params, t_inner)
    plt.show()

    # Cell-averaged squared L2 norm of each solver over the saved frames.
    plt.plot(l2_norm_trajectory_dg(trajectory, p), label="baseline")
    plt.plot(l2_norm_trajectory_dg(trajectory_model, p), label="ML")
    plt.plot(l2_norm_trajectory_dg(trajectory_model_gs, p), label="ML + global stabilization")
    plt.xlabel("saved frame")
    plt.legend()
    plt.show()

We see from above that the baseline (red) has a large amount of numerical diffusion for a small number of gridpoints, but is more accurate with more gridpoints. We also see that the machine-learned model learns to evolve the solution accurately for nx > 8. So far, so good. Let's now look at a different initial condition.

In [None]:
# Qualitative check on a second hand-picked initial condition (seed 10).
key = jax.random.PRNGKey(10)
t_inner = 1.0
outer_steps = 10

for i, nx in enumerate(nxs):
    print("nx is {}".format(nx))
    _, params = load_training_params(nx, sim_params, training_params_list[i], model)

    f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
    a0 = get_a0(f_init, core_params, nx)

    # Learned model without stabilization.
    sim_model = AdvectionDGSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_model = trajectory_fn_model(a0)

    # Baseline solver without learned parameters.
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory = trajectory_fn(a0)

    # Learned model with global stabilization.
    sim_model_gs = AdvectionDGSim(core_params, sim_params, global_stabilization=True, model=model, params=params)
    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)
    trajectory_model_gs = trajectory_fn_model_gs(a0)

    # Stack 0: baseline, 1: ML, 2: ML + global stabilization.
    plot_multiple_dg_trajectories([trajectory, trajectory_model, trajectory_model_gs], core_params, t_inner)
    plt.show()

    # Cell-averaged squared L2 norm of each solver over the saved frames.
    plt.plot(l2_norm_trajectory_dg(trajectory, p), label="baseline")
    plt.plot(l2_norm_trajectory_dg(trajectory_model, p), label="ML")
    plt.plot(l2_norm_trajectory_dg(trajectory_model_gs, p), label="ML + global stabilization")
    plt.xlabel("saved frame")
    plt.legend()
    plt.show()

In [None]:
# Count how many of N random initial conditions make the unstabilized
# learned solver produce NaNs by t = t_inner * outer_steps.
N = 100
t_inner = 10.0
outer_steps = 10

for i, nx in enumerate(nxs):
    # Restart from the same seed so every nx sees the same initial-condition sequence.
    key = jax.random.PRNGKey(10)
    _, params = load_training_params(nx, sim_params, training_params_list[i], model)
    sim_model = AdvectionDGSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    num_nan = 0

    for n in range(N):
        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        trajectory_model = trajectory_fn_model(a0)
        # A run counts as unstable if its final saved frame contains any NaN.
        num_nan += int(jnp.isnan(trajectory_model[-1]).any())
        key, _ = jax.random.split(key)

    print("nx is {}, num_nan is {} out of {}".format(nx, num_nan, N))
    
    

In [None]:
# Repeat the NaN count with global stabilization enabled.
N = 100
t_inner = 10.0
outer_steps = 10

for i, nx in enumerate(nxs):
    _, params = load_training_params(nx, sim_params, training_params_list[i], model)
    sim_model = AdvectionDGSim(core_params, sim_params, global_stabilization=True, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    # Same seed for every nx, so each resolution sees the same initial conditions.
    key = jax.random.PRNGKey(10)
    num_nan = 0

    for n in range(N):
        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        trajectory_model = trajectory_fn_model(a0)
        # Flag the run if its final saved frame contains any NaN.
        num_nan += jnp.isnan(trajectory_model[-1]).any()
        key, _ = jax.random.split(key)

    print("nx is {}, num_nan is {} out of {}".format(nx, num_nan, N))

Are the NaNs eliminated by global stabilization? Compare the two counts above: with global stabilization, the number of runs that produce NaNs should be zero.

### Demonstrate that Global Stabilization Doesn't Degrade Accuracy

We compare three numerical algorithms for solving the 1D advection equation: (a) the upwind DG baseline, (b) the machine-learned (ML) solver, and (c) the machine-learned solver with global stabilization.

In [None]:
# Accuracy of each solver at t = t_inner * outer_steps, for every trained resolution.
N = 100

mse_upwind = []
mse_ml = []
mse_mlgs = []


def normalized_mse_fv(traj, traj_exact):
    """MSE over all entries, with each frame normalized by the exact solution's mean square."""
    frame_power = jnp.mean(traj_exact**2, axis=1)
    return jnp.mean((traj - traj_exact)**2 / frame_power[:, None])


def normalized_mse_dg(traj, traj_exact, p):
    """Normalized MSE between two DG trajectories of polynomial order ``p``.

    Squared coefficients are divided by (2k + 1) for k = 0..p and summed over
    the last axis. Each frame's error is then normalized by the cell-averaged
    weighted energy of ``traj_exact`` in that frame, and the result is
    averaged over all frames and cells.
    """
    divisors = (2 * jnp.arange(0, p+1) + 1)[None, None, :]
    exact_energy = jnp.mean(jnp.sum(traj_exact**2 / divisors, axis=-1), axis=-1)
    error = jnp.sum((traj - traj_exact)**2 / divisors, axis=-1)
    return jnp.mean(error / exact_energy[:, None])


vmap_convert_DG = vmap(convert_DG_representation, (0, None, None, None), 0)

t_inner = 0.1
outer_steps = 10

for i, nx in enumerate(nxs):
    # Same seed for every nx, so each resolution is scored on the same initial conditions.
    key = jax.random.PRNGKey(10)
    _, params = load_training_params(nx, sim_params, training_params_list[i], model)

    # Baseline upwind DG solver (no learned parameters).
    inner_fn_upwind = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn_upwind = get_trajectory_fn(inner_fn_upwind, outer_steps)

    # Learned solver, without and with global stabilization.
    sim_model = AdvectionDGSim(core_params, sim_params, global_stabilization=False, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    sim_model_gs = AdvectionDGSim(core_params, sim_params, global_stabilization=True, model=model, params=params)
    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)

    mse_upwind_nx = mse_ml_nx = mse_mlgs_nx = 0.0

    for n in range(N):
        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        a0_exact = get_a0(f_init, core_params, nx_exact)

        trajectory_upwind = trajectory_fn_upwind(a0)
        trajectory_model = trajectory_fn_model(a0)
        trajectory_model_gs = trajectory_fn_model_gs(a0)

        # Reference: the baseline solver run at nx_exact, then converted to nx cells
        # via convert_DG_representation.
        trajectory_exact = trajectory_fn_upwind(a0_exact)
        trajectory_exact_ds = vmap_convert_DG(trajectory_exact, p+1, nx, core_params.Lx)

        mse_upwind_nx += normalized_mse_dg(trajectory_upwind, trajectory_exact_ds, p) / N
        mse_ml_nx += normalized_mse_dg(trajectory_model, trajectory_exact_ds, p) / N
        mse_mlgs_nx += normalized_mse_dg(trajectory_model_gs, trajectory_exact_ds, p) / N

        key, _ = jax.random.split(key)

    mse_upwind.append(mse_upwind_nx)
    mse_ml.append(mse_ml_nx)
    mse_mlgs.append(mse_mlgs_nx)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

print(mse_upwind)
print(mse_ml)
print(mse_mlgs)

fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))
for side in ('top', 'right'):
    axs.spines[side].set_visible(False)
linewidth = 3

mses = [mse_ml, mse_mlgs, mse_upwind]
labels = ["DG ML", "DG ML\n(Stabilized)", "DG (Upwind)"]
colors = ["blue", "red", "purple", "green"]
linestyles = ["solid", "dashed", "solid", "solid"]

# One log-log curve per solver; zip stops after len(mses) entries.
for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):
    plt.loglog(nxs, mse, label=label, color=color, linewidth=linewidth, linestyle=linestyle)

axs.set_xticks([32, 16, 8])
axs.set_xticklabels(["N=32", "N=16", "N=8"], fontsize=18)
axs.set_yticks([1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1])
axs.set_yticklabels(["$10^{-8}$", "$10^{-7}$", "$10^{-6}$", "$10^{-5}$", "$10^{-4}$", "$10^{-3}$", "$10^{-2}$", "$10^{-1}$"], fontsize=18)
axs.minorticks_off()
axs.set_ylabel("Normalized MSE", fontsize=18)
axs.text(0.2, 0.85, '$t=1$', transform=axs.transAxes, fontsize=18, verticalalignment='top')

# Legend built from proxy artists so its styling matches the curves exactly.
handles = [
    mlines.Line2D([], [], color=color, linewidth=linewidth, label=label, linestyle=linestyle)
    for _, label, color, linestyle in zip(mses, labels, colors, linestyles)
]
axs.legend(handles=handles, loc=(0.63, 0.21), prop={'size': 15}, frameon=False)
plt.ylim([3e-9, 1e-1+6e-2])
fig.tight_layout()

plt.savefig('mse_vs_nx_dg.png')
plt.savefig('mse_vs_nx_dg.eps')
plt.show()

### Demonstrate that Global Stabilization Improves Accuracy over Time

For nx = 8, we plot the normalized MSE (y-axis) of the upwind baseline, the ML solver, and the ML solver with global stabilization against the final simulation time (x-axis).

In [None]:
# Normalized MSE of each solver as a function of the final time T, at fixed nx.
nx = 8
N = 100

Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]


def normalized_mse_fv(traj, traj_exact):
    """MSE over all entries, with each frame normalized by the exact solution's mean square."""
    return jnp.mean((traj - traj_exact)**2 / jnp.mean(traj_exact**2, axis=1)[:, None])


def normalized_mse_dg(traj, traj_exact, p):
    """Normalized MSE between two DG trajectories of polynomial order ``p``.

    Squared coefficients are divided by (2k + 1) for k = 0..p and summed over
    the last axis. Each frame's error is normalized by the cell-averaged
    weighted energy of ``traj_exact`` in that frame, then averaged over all
    frames and cells.
    """
    twokplusone = 2 * jnp.arange(0, p+1) + 1
    l2_normalization = jnp.mean(jnp.sum(traj_exact**2 / twokplusone[None, None, :], axis=-1), axis=-1)
    l2 = jnp.sum((traj - traj_exact)**2 / twokplusone[None, None, :], axis=-1)
    return jnp.mean(l2 / l2_normalization[:, None])


mse_upwind_time = []
mse_ml_time = []
mse_mlgs_time = []

# Look up the training parameters by nx itself, so changing nx or nxs cannot
# silently load another resolution's model (raises ValueError if nx was not trained).
_, params = load_training_params(nx, sim_params, training_params_list[nxs.index(nx)], model)

t_inner = 0.1

# None of the solvers depend on the final time T, so build them once.
inner_fn_upwind = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)

sim_model = AdvectionDGSim(core_params, sim_params, global_stabilization=False, model=model, params=params)
inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)

sim_model_gs = AdvectionDGSim(core_params, sim_params, global_stabilization=True, model=model, params=params)
inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)

for T in Ts:

    print(T)

    # Restart the seed so every T is scored on the same N initial conditions.
    key = jax.random.PRNGKey(10)

    # round(), not int(): T / t_inner can land just below an integer in
    # floating point, and int() would then drop a whole outer step.
    outer_steps = round(T / t_inner)

    trajectory_fn_upwind = get_trajectory_fn(inner_fn_upwind, outer_steps)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)

    mse_upwind_nx = 0.0
    mse_ml_nx = 0.0
    mse_mlgs_nx = 0.0

    for n in range(N):

        f_init = get_initial_condition_fn(core_params, init_description, key=key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        a0_exact = get_a0(f_init, core_params, nx_exact)

        trajectory_upwind = trajectory_fn_upwind(a0)
        trajectory_model = trajectory_fn_model(a0)
        trajectory_model_gs = trajectory_fn_model_gs(a0)

        # Reference: baseline solver at nx_exact, converted to nx cells.
        # NOTE(review): relies on vmap_convert_DG defined in the MSE-vs-nx cell.
        trajectory_exact = trajectory_fn_upwind(a0_exact)
        trajectory_exact_ds = vmap_convert_DG(trajectory_exact, p+1, nx, core_params.Lx)

        mse_upwind_nx += normalized_mse_dg(trajectory_upwind, trajectory_exact_ds, p) / N
        mse_ml_nx += normalized_mse_dg(trajectory_model, trajectory_exact_ds, p) / N
        mse_mlgs_nx += normalized_mse_dg(trajectory_model_gs, trajectory_exact_ds, p) / N

        key, _ = jax.random.split(key)

    mse_upwind_time.append(mse_upwind_nx)
    mse_ml_time.append(mse_ml_nx)
    mse_mlgs_time.append(mse_mlgs_nx)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))
for side in ('top', 'right'):
    axs.spines[side].set_visible(False)
linewidth = 3

Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]

mses = [mse_ml_time, mse_mlgs_time, mse_upwind_time]
labels = ["DG ML", "DG ML (Stabilized)", "DG Upwind"]
colors = ["blue", "red", "purple", "green"]
linestyles = ["solid", "dashed", "solid", "solid"]

for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):
    # NaN errors (blown-up runs) become 1e7, far above the y-limit set below.
    finite_mse = [jnp.nan_to_num(error, nan=1e7) for error in mse]
    plt.loglog(Ts, finite_mse, label=label, color=color, linewidth=linewidth, linestyle=linestyle)

axs.set_xticks(Ts)
axs.set_xticklabels(["t=1", "2", "5", "10", "20", "50", "t=100"], fontsize=18)
axs.set_yticks([1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0])
axs.set_yticklabels(["$10^{-7}$", "$10^{-6}$", "$10^{-5}$", "$10^{-4}$", "$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$"], fontsize=18)
axs.minorticks_off()
axs.set_ylabel("Normalized MSE", fontsize=18)
axs.text(0.15, 0.8, '$N=8$', transform=axs.transAxes, fontsize=18, verticalalignment='top')

# Legend built from proxy artists so its styling matches the curves exactly.
handles = [
    mlines.Line2D([], [], color=color, linewidth=linewidth, label=label, linestyle=linestyle)
    for _, label, color, linestyle in zip(mses, labels, colors, linestyles)
]
axs.legend(handles=handles, loc=(0.52, 0.03), prop={'size': 15}, frameon=False)
plt.ylim([2.0e-6, 1e0-1e-1])
fig.tight_layout()

plt.savefig('mse_vs_time_dg.png')
plt.savefig('mse_vs_time_dg.eps')
plt.show()