In this Jupyter notebook, we train a machine-learned finite-volume (FV) solver to solve the 1D Euler equations at reduced resolution. Our objectives are, first, to study whether and how well an ML model can learn to solve these equations, and then to study whether global stabilization ensures that entropy increases.

In [None]:
# setup paths
import sys
basedir = '/Users/nickm/thesis/icml2023paper/1d_euler'
readwritedir = '/Users/nickm/thesis/icml2023paper/1d_euler'

# Make the project's core / simulate / ml module directories importable.
for _subdir in ('core', 'simulate', 'ml'):
    sys.path.append('{}/{}'.format(basedir, _subdir))

In [None]:
# import external packages
import jax
import jax.numpy as jnp
import numpy as onp
from jax import config, vmap
config.update("jax_enable_x64", True)
import xarray
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# import internal packages
from flux import Flux
from initialconditions import get_a0, f_init_sum_of_amplitudes, shock_tube_problem_1
from simparams import CoreParams, SimulationParams
from simulations import EulerFVSim
from trajectory import get_trajectory_fn, get_inner_fn
from helper import get_rho, get_u, get_p, get_c
from trainingutils import save_training_data
from mlparams import TrainingParams, StencilParams
from model import LearnedStencil # todo
from trainingutils import (get_loss_fn, get_batch_fn, get_idx_gen, train_model, 
                           compute_losses_no_model, init_params, save_training_params, 
                           load_training_params)

In [None]:
# helper functions
def plot_a(a, core_params, mins=(0.0 - 2e-2,) * 3, maxs=(1.0 + 5e-2,) * 3, title=""):
    """Plot density, velocity and pressure of a single state side by side.

    Args:
        a: state array; its second axis (``a.shape[1]``) is the spatial grid.
        core_params: provides ``Lx`` and is forwarded to get_rho / get_u / get_p.
        mins, maxs: y-axis limits for (rho, u, p); any indexable of length >= 3.
            Tuples are used as defaults to avoid shared mutable default arguments.
        title: figure suptitle.
    """
    x = jnp.linspace(0.0, core_params.Lx, a.shape[1])
    panels = ((get_rho, r'$\rho$'), (get_u, r'$u$'), (get_p, r'$p$'))

    fig, axs = plt.subplots(1, 3, figsize=(11, 3))
    for k, (getter, ylabel) in enumerate(panels):
        axs[k].plot(x, getter(a, core_params))
        axs[k].set_ylabel(ylabel)
        axs[k].set_ylim([mins[k], maxs[k]])

    fig.suptitle(title)
    fig.tight_layout()

def plot_ac(a, core_params, mins=(0.0 - 2e-2,) * 3, maxs=(1.0 + 5e-2,) * 3, title=""):
    """Plot density, velocity, pressure and sound speed of a single state.

    Args:
        a: state array; its second axis (``a.shape[1]``) is the spatial grid.
        core_params: provides ``Lx`` and is forwarded to get_rho / get_u / get_p / get_c.
        mins, maxs: y-axis limits for (rho, u, p[, c]). A panel whose index has no
            limit (e.g. c when only three limits are given) is autoscaled instead of
            borrowing another variable's limits. Previously c reused the velocity
            limits, which clip c (> 0, about 1 for the default background state).
        title: figure suptitle (new, optional; matches plot_a).
    """
    x = jnp.linspace(0.0, core_params.Lx, a.shape[1])
    panels = (
        (get_rho, r'$\rho$'),
        (get_u, r'$u$'),
        (get_p, r'$p$'),
        (get_c, r'$c$'),
    )

    fig, axs = plt.subplots(1, 4, figsize=(11, 3))
    for k, (getter, ylabel) in enumerate(panels):
        axs[k].plot(x, getter(a, core_params))
        # ylabel for every panel (rho previously used set_title, unlike the others)
        axs[k].set_ylabel(ylabel)
        if k < len(mins) and k < len(maxs):
            axs[k].set_ylim([mins[k], maxs[k]])

    fig.suptitle(title)
    fig.tight_layout()

def _primitive_variables(trajectory, core_params):
    """Return (rho, u, p) from a conserved-variable trajectory indexed [time, var, x]."""
    rho = trajectory[:, 0, :]
    u = trajectory[:, 1, :] / trajectory[:, 0, :]
    p = (core_params.gamma - 1) * (trajectory[:, 2, :] - 0.5 * trajectory[:, 1, :]**2 / trajectory[:, 0, :])
    return rho, u, p


def plot_trajectory(trajectory, core_params, mins=(0.0 - 2e-2,) * 3, maxs=(1.0 + 5e-2,) * 3, t_inner=None):
    """Facet-plot rho, u and p of ``trajectory`` at every saved time.

    Args:
        trajectory: conserved variables indexed ``[time, var, x]`` with var = (rho, rho*u, E).
        core_params: provides ``Lx`` and ``gamma``.
        mins, maxs: y-axis limits for (rho, u, p).
        t_inner: time between saved snapshots. If omitted, the notebook-global
            ``t_inner`` is used (the previous behaviour).

    The number of time facets is taken from ``trajectory.shape[0]``. Previously the
    notebook-global ``outer_steps`` was used, which broke whenever the two disagreed.
    """
    n_times, _, nx = trajectory.shape
    if t_inner is None:
        # Backward compatibility: older calls relied on the notebook-global value.
        t_inner = globals()['t_inner']
    coords = {
        'x': jnp.arange(nx) * core_params.Lx / nx,
        'time': t_inner * jnp.arange(n_times),
    }
    ylabels = (r'$\rho$', r'$u$', r'$p$')
    fields = _primitive_variables(trajectory, core_params)

    for k, (field, ylabel) in enumerate(zip(fields, ylabels)):
        g = xarray.DataArray(field, dims=["time", "x"], coords=coords).plot(
            col='time', col_wrap=5)
        plt.ylim(mins[k], maxs[k])
        g.axes[0][0].set_ylabel(ylabel, fontsize=18)


def plot_multiple_fv_trajectories(trajectories, core_params, t_inner, labels=None, mins=None, maxs=None):
    """Overlay rho, u and p of several trajectories at every saved time.

    Args:
        trajectories: list of conserved-variable arrays indexed ``[time, var, x]``. All
            must share the same number of saved times; resolutions may differ.
        core_params: provides ``Lx`` and ``gamma``.
        t_inner: time between saved snapshots (used for the column titles).
        labels: legend labels, one per trajectory (defaults to "0", "1", ...).
        mins, maxs: optional y-axis limits for (rho, u, p); autoscaled when None.
    """
    n_times = trajectories[0].shape[0]
    if labels is None:
        labels = [str(k) for k in range(len(trajectories))]

    fig, axs = plt.subplots(3, n_times, figsize=(2.5 * n_times, 6), squeeze=False)
    for traj, label in zip(trajectories, labels):
        nx = traj.shape[2]
        xs = jnp.arange(nx) * core_params.Lx / nx
        for row, field in enumerate(_primitive_variables(traj, core_params)):
            for j in range(n_times):
                axs[row, j].plot(xs, field[j], label=label)

    for row, ylabel in enumerate((r'$\rho$', r'$u$', r'$p$')):
        axs[row, 0].set_ylabel(ylabel, fontsize=14)
        if mins is not None and maxs is not None:
            for j in range(n_times):
                axs[row, j].set_ylim(mins[row], maxs[row])
    for j in range(n_times):
        axs[0, j].set_title('t = {:g}'.format(j * t_inner))
    axs[0, -1].legend()
    fig.tight_layout()

def get_core_params(Lx = 1.0, gamma = 5/3, bc = 'periodic', fluxstr = 'laxfriedrichs'):
    """Build a CoreParams from domain length, adiabatic index, boundary condition and flux name."""
    ordered_args = (Lx, gamma, bc, fluxstr)
    return CoreParams(*ordered_args)

def get_sim_params(name = "test", cfl_safety=0.3, rk='ssp_rk3'):
    """Build SimulationParams, filling the directories from the notebook-level basedir/readwritedir."""
    ordered_args = (name, basedir, readwritedir, cfl_safety, rk)
    return SimulationParams(*ordered_args)
    
def get_training_params(n_data, train_id="test", batch_size=4, learning_rate=1e-3, num_epochs = 10, optimizer='adam'):
    """Build TrainingParams; note TrainingParams takes num_epochs before train_id."""
    ordered_args = (n_data, num_epochs, train_id, batch_size, learning_rate, optimizer)
    return TrainingParams(*ordered_args)

def get_stencil_params(kernel_size = 3, kernel_out = 4, stencil_width=4, depth = 3, width = 16):
    """Build StencilParams describing the learned-stencil network architecture."""
    ordered_args = (kernel_size, kernel_out, stencil_width, depth, width)
    return StencilParams(*ordered_args)

def l2_norm_trajectory(trajectory):
    """Return the mean square of each saved state in ``trajectory``.

    ``trajectory`` is indexed ``[time, ...]``. All remaining axes (for the Euler solver:
    conserved variable and cell) are averaged, giving one value per saved time.
    Previously only axis 1 was reduced, which for ``(time, var, x)`` data averaged over
    the variables and returned a ``(time, nx)`` array. For 2-D ``(time, x)`` input the
    result is unchanged. Note: this is the mean *square* (no square root).
    """
    reduce_axes = tuple(range(1, trajectory.ndim))
    return jnp.mean(trajectory**2, axis=reduce_axes)

def get_model(core_params, stencil_params):
    """Build a LearnedStencil whose features list is ``depth - 1`` copies of ``width``.

    ``core_params`` is accepted for interface symmetry but not used here.
    """
    sp = stencil_params
    hidden_features = [sp.width] * (sp.depth - 1)
    return LearnedStencil(hidden_features, sp.kernel_size, sp.kernel_out, sp.stencil_width)

### Finite Volume

##### Training Loop

First, we will generate the data.

In [None]:
# training hyperparameters
Lx = 1.0
gamma = 1.4

# keyword arguments for f_init_sum_of_amplitudes (random initial conditions)
kwargs_init = {
    'min_num_modes': 1,
    'max_num_modes': 4,
    'min_k': 0,
    'max_k': 3,
    'amplitude_max': 1.0,
    'background_rho': 1.0,
    'min_rho': 0.5,
    'background_p': 1.0,
    'min_p': 0.1,
}
kwargs_core = {'Lx': Lx, 'gamma': gamma, 'bc': 'periodic', 'fluxstr': 'musclcharacteristic'}
kwargs_sim = {'name': "euler_test", 'cfl_safety': 0.3, 'rk': 'ssp_rk3'}
kwargs_train_FV = {'train_id': "euler_train_test", 'batch_size': 8, 'optimizer': 'adam', 'num_epochs': 3}
kwargs_stencil = {'kernel_size': 3, 'kernel_out': 4, 'stencil_width': 4, 'depth': 3, 'width': 16}

# data generation and resolutions
n_runs = 5
t_inner_train = 0.01
Tf = 1.0
outer_steps_train = int(Tf / t_inner_train)
fv_flux_baseline = 'musclcharacteristic'
nx_exact = 400
nxs = [32, 64]
# NOTE(review): with Adam, a learning rate of 0.0 should leave the parameters at their
# initialization (no training). The commented values look like the ones meant for real runs.
learning_rate_list = [0.0, 0.0]  # [1e-6, 1e-6]
assert len(nxs) == len(learning_rate_list)

key = jax.random.PRNGKey(12)

# setup
core_params = get_core_params(**kwargs_core)
sim_params = get_sim_params(**kwargs_sim)
n_data = n_runs * outer_steps_train
stencil_params = get_stencil_params(**kwargs_stencil)
training_params_list = [
    get_training_params(n_data, learning_rate=lr, **kwargs_train_FV)
    for lr in learning_rate_list
]
sim = EulerFVSim(core_params, sim_params)
model = get_model(core_params, stencil_params)

### Test Initial Conditions

In [None]:
# Sample one random initial condition and evolve it with the baseline solver `sim`.
nx = 100
t_inner = 0.25
outer_steps = 5
key = jax.random.PRNGKey(30)

f_init = f_init_sum_of_amplitudes(core_params, key, **kwargs_init)
a0 = get_a0(f_init, core_params, nx)

inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
trajectory = trajectory_fn(a0)

In [None]:
# y-axis limits for (rho, u, p)
mins = [-0.05, -2.0, -0.05]
maxs = [4.0, 2.0, 3.0]
#plot_a(a0, core_params, maxs=maxs, mins=mins)
plot_trajectory(trajectory, core_params, mins=mins, maxs=maxs)

### Save Training Data

In [None]:
# Generate training data with an explicit seed (the one chosen in the hyperparameter cell)
# rather than whatever `key` was last assigned: the test-IC cell above reassigns `key`,
# so otherwise the dataset depends on which cells have been run.
data_key = jax.random.PRNGKey(12)


def init_fn(key):
    """Sample a random initial-condition function using kwargs_init."""
    return f_init_sum_of_amplitudes(core_params, key, **kwargs_init)


save_training_data(data_key, init_fn, core_params, sim_params, sim, t_inner_train, outer_steps_train, n_runs, nx_exact, nxs)

### Train

Next, we initialize the model parameters.

In [None]:
# Seed for model-parameter initialization. The same `key` is also passed to
# train_model in the training-loop cell below.
key = jax.random.PRNGKey(42)
i_params = init_params(key, model)

Next, we run a training loop for each value of nx. The learning rate undergoes a prespecified decay.

In [None]:
# Train one model per resolution; every run starts from the same initial parameters i_params.
for nx, training_params in zip(nxs, training_params_list):
    print(nx)

    def idx_fn(idx_key):
        return get_idx_gen(idx_key, training_params)

    batch_fn = get_batch_fn(core_params, sim_params, training_params, nx)
    loss_fn = get_loss_fn(model, core_params)
    losses, params = train_model(model, i_params, training_params, key, idx_fn, batch_fn, loss_fn)
    save_training_params(nx, sim_params, training_params, params, losses)

Next, we load and plot the losses for each nx to check that the model trained properly.

In [None]:
# Reload the saved loss history for each resolution and overlay the curves.
for nx, training_params_nx in zip(nxs, training_params_list):
    losses, _ = load_training_params(nx, sim_params, training_params_nx, model)
    plt.plot(losses, label=nx)
    print(losses)
#plt.ylim([0,1])
plt.legend()
plt.show()

Next, we plot the accuracy of the trained model on a few simple test cases to qualitatively evaluate the success of the training. We will eventually quantify the accuracy of the trained model.

In [None]:
# pick a key that gives something nice
key = jax.random.PRNGKey(18)

for i, nx in enumerate(nxs):
    print("nx is {}".format(nx))

    _, params = load_training_params(nx, sim_params, training_params_list[i], model)

    f_init = f_init_sum_of_amplitudes(core_params, key, **kwargs_init)
    a0 = get_a0(f_init, core_params, nx)
    t_inner = 1.0
    outer_steps = 10

    # learned model (no global stabilization)
    sim_model = EulerFVSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_model = trajectory_fn_model(a0)

    # Learned model with global stabilization: currently disabled. To enable it, uncomment
    # these lines and add trajectory_model_gs to the plots below.
    # sim_model_gs = EulerFVSim(core_params, sim_params, global_stabilization=True, model=model, params=params)
    # inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    # trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)
    # trajectory_model_gs = trajectory_fn_model_gs(a0)

    # baseline solver without learned parameters
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory = trajectory_fn(a0)

    plot_multiple_fv_trajectories([trajectory, trajectory_model], core_params, t_inner)
    plt.show()

    # Label the curves so baseline and ML can be told apart.
    plt.plot(l2_norm_trajectory(trajectory), label="baseline")
    plt.plot(l2_norm_trajectory(trajectory_model), label="ML")
    plt.xlabel("outer step")
    plt.ylabel("mean square of state")
    plt.legend()
    plt.show()

We see from above that the baseline (red) has a large amount of numerical diffusion for a small number of grid points, while it is more accurate for more grid points. We also see that the machine-learned model learns to accurately evolve the solution for nx > 8. So far, so good. Let's now look at a different initial condition.

In [None]:
key = jax.random.PRNGKey(20)  # pick a key that gives something nice

for nx, training_params_nx in zip(nxs, training_params_list):
    print("nx is {}".format(nx))
    _, params = load_training_params(nx, sim_params, training_params_nx, model)

    t_inner = 1.0
    outer_steps = 5
    f_init = f_init_sum_of_amplitudes(core_params, key, **kwargs_init)
    a0 = get_a0(f_init, core_params, nx)

    # baseline solver without learned parameters
    inner_fn = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn = get_trajectory_fn(inner_fn, outer_steps)
    trajectory = trajectory_fn(a0)

    # learned model
    sim_model = EulerFVSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_model = trajectory_fn_model(a0)

    plot_multiple_fv_trajectories([trajectory, trajectory_model], core_params, t_inner)
    plt.show()

Oh no! We can see that for nx=8 and nx=32, the solution goes unstable between 1.0 < t < 2.0 and 0.0 < t < 1.0 respectively.

This is not good. Even though the machine-learned PDE solver gives accurate solutions for some initial conditions, the solution blows up unexpectedly for others. Let's next ask: how frequently does an instability arise?

In [None]:
N = 25

for i, nx in enumerate(nxs):

    key = jax.random.PRNGKey(10)  # new key, same initial key for each nx
    _, params = load_training_params(nx, sim_params, training_params_list[i], model)
    t_inner = 10.0
    outer_steps = 10
    sim_model = EulerFVSim(core_params, sim_params, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    # Count runs that became unstable at any saved time. Checking the whole trajectory
    # for non-finite values catches inf as well as NaN, not just NaN in the last snapshot.
    num_nan = 0
    for n in range(N):
        f_init = f_init_sum_of_amplitudes(core_params, key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        trajectory_model = trajectory_fn_model(a0)
        num_nan += int(not jnp.all(jnp.isfinite(trajectory_model)))
        key, _ = jax.random.split(key)

    print("nx is {}, num_nan is {} out of {}".format(nx, num_nan, N))
    
    

### Demonstrate that Global Stabilization Doesn't Degrade Accuracy

We want to compare three different numerical algorithms for solving the 1D Euler equations: (a) MUSCL with characteristic reconstruction, (b) machine learned (ML), and (c) machine learned with global stabilization.

In [None]:
N = 100

mse_muscl = []
mse_ml = []
mse_mlgs = []


def normalized_mse(traj, traj_exact):
    """Mean squared error of ``traj`` vs ``traj_exact``, normalized per saved time.

    Both arrays are indexed ``[time, ...]`` and have identical shapes. Each snapshot's
    squared error is divided by the mean square of the exact snapshot over all remaining
    axes (variables and cells), then everything is averaged.
    """
    reduce_axes = tuple(range(1, traj_exact.ndim))
    scale = jnp.mean(traj_exact**2, axis=reduce_axes, keepdims=True)
    return jnp.mean((traj - traj_exact)**2 / scale)


def downsample_trajectory(traj_fine, nx_coarse, Lx):
    """Conservatively average a fine FV trajectory onto ``nx_coarse`` uniform cells.

    Assumes ``traj_fine`` holds cell averages on a uniform grid over [0, Lx] along its
    last axis (TODO confirm against get_a0). The cumulative integral of a piecewise-
    constant field is piecewise linear, so interpolating it at the coarse cell edges is
    exact. This works for non-integer ratios such as 400 -> 32.
    """
    nx_fine = traj_fine.shape[-1]
    edges_fine = jnp.linspace(0.0, Lx, nx_fine + 1)
    edges_coarse = jnp.linspace(0.0, Lx, nx_coarse + 1)
    zeros = jnp.zeros(traj_fine.shape[:-1] + (1,))
    cumulative = jnp.concatenate([zeros, jnp.cumsum(traj_fine, axis=-1) * (Lx / nx_fine)], axis=-1)
    flat = cumulative.reshape(-1, nx_fine + 1)
    cum_coarse = vmap(lambda c: jnp.interp(edges_coarse, edges_fine, c))(flat)
    cum_coarse = cum_coarse.reshape(traj_fine.shape[:-1] + (nx_coarse + 1,))
    return jnp.diff(cum_coarse, axis=-1) / (Lx / nx_coarse)


for i, nx in enumerate(nxs):

    key = jax.random.PRNGKey(10)

    _, params = load_training_params(nx, sim_params, training_params_list[i], model)
    t_inner = 0.1
    outer_steps = 10

    mse_muscl_nx = 0.0
    mse_ml_nx = 0.0
    mse_mlgs_nx = 0.0

    # MUSCL
    inner_fn_muscl = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    trajectory_fn_muscl = get_trajectory_fn(inner_fn_muscl, outer_steps)

    # Model without GS
    sim_model = EulerFVSim(core_params, sim_params, global_stabilization=False, model=model, params=params)
    inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)

    # Model with GS
    sim_model_gs = EulerFVSim(core_params, sim_params, global_stabilization=True, model=model, params=params)
    inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)

    for n in range(N):

        f_init = f_init_sum_of_amplitudes(core_params, key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        a0_exact = get_a0(f_init, core_params, nx_exact)

        trajectory_muscl = trajectory_fn_muscl(a0)
        trajectory_model = trajectory_fn_model(a0)
        trajectory_model_gs = trajectory_fn_model_gs(a0)

        # "Exact" reference: MUSCL at nx_exact, averaged onto the nx grid so the
        # shapes match the coarse trajectories.
        exact_trajectory = trajectory_fn_muscl(a0_exact)
        exact_trajectory_ds = downsample_trajectory(exact_trajectory, nx, core_params.Lx)

        mse_muscl_nx += normalized_mse(trajectory_muscl, exact_trajectory_ds) / N
        mse_ml_nx += normalized_mse(trajectory_model, exact_trajectory_ds) / N
        mse_mlgs_nx += normalized_mse(trajectory_model_gs, exact_trajectory_ds) / N

        key, _ = jax.random.split(key)

    mse_muscl.append(mse_muscl_nx)
    mse_ml.append(mse_ml_nx)
    mse_mlgs.append(mse_mlgs_nx)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
print(mse_muscl)
print(mse_ml)
print(mse_mlgs)
# (mse_mlmc is not printed: ML + MC-limiter results are never computed in this notebook.)
fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))
axs.spines['top'].set_visible(False)
axs.spines['right'].set_visible(False)
linewidth = 3

mses = [mse_ml, mse_mlgs, mse_muscl]
labels = ["ML", "ML (Stabilized)", "MUSCL"]
colors = ["blue", "red", "purple", "green"]
linestyles = ["solid", "dashed", "solid", "solid"]

for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):
    axs.loglog(nxs, mse, label=label, color=color, linewidth=linewidth, linestyle=linestyle)

# Tick exactly the resolutions that were run. Fixed ticks at resolutions with no data
# would stretch the x-axis to include them.
axs.set_xticks(nxs)
axs.set_xticklabels(["N={}".format(n) for n in nxs], fontsize=18)
axs.set_yticks([1e-3, 1e-2, 1e-1, 1e0])
axs.set_yticklabels(["$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$"], fontsize=18)
axs.minorticks_off()
axs.set_ylabel("Normalized MSE", fontsize=18)
axs.text(0.3, 0.95, '$t=1$', transform=axs.transAxes, fontsize=18, verticalalignment='top')

handles = [
    mlines.Line2D([], [], color=color, linewidth=linewidth, label=label, linestyle=linestyle)
    for label, color, linestyle in zip(labels, colors, linestyles)
]
axs.legend(handles=handles, loc=(0.655, 0.45), prop={'size': 15}, frameon=False)
plt.ylim([2.5e-4, 1e0+1e-1])
fig.tight_layout()


#plt.savefig('mse_vs_nx_euler.png')
#plt.savefig('mse_vs_nx_euler.eps')
plt.show()

### Demonstrate that Global Stabilization Improves Accuracy over Time

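For a single trained resolution nx, plot the normalized MSE (y-axis) of the ML model with global stabilization, the ML model without it, and the MUSCL baseline against the final simulation time (x-axis). (An ML + MC-limiter variant is not computed in this notebook.)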

In [None]:
# Parameters only exist for resolutions in nxs. Prefer nx = 16 when it was trained,
# otherwise fall back to the first trained resolution, and use that resolution's
# training params.
nx = 16 if 16 in nxs else nxs[0]
training_params_nx = training_params_list[nxs.index(nx)]
N = 25

Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]
t_inner = 0.1


def normalized_mse(traj, traj_exact):
    """Mean squared error of ``traj`` vs ``traj_exact``, normalized per saved time.

    Both arrays are indexed ``[time, ...]`` and have identical shapes. Each snapshot's
    squared error is divided by the mean square of the exact snapshot over all remaining
    axes (variables and cells), then everything is averaged.
    """
    reduce_axes = tuple(range(1, traj_exact.ndim))
    scale = jnp.mean(traj_exact**2, axis=reduce_axes, keepdims=True)
    return jnp.mean((traj - traj_exact)**2 / scale)


def downsample_trajectory(traj_fine, nx_coarse, Lx):
    """Conservatively average a fine FV trajectory onto ``nx_coarse`` uniform cells.

    Assumes ``traj_fine`` holds cell averages on a uniform grid over [0, Lx] along its
    last axis (TODO confirm against get_a0). The cumulative integral of a piecewise-
    constant field is piecewise linear, so interpolating it at the coarse cell edges is
    exact. This works for non-integer ratios such as 400 -> 32.
    """
    nx_fine = traj_fine.shape[-1]
    edges_fine = jnp.linspace(0.0, Lx, nx_fine + 1)
    edges_coarse = jnp.linspace(0.0, Lx, nx_coarse + 1)
    zeros = jnp.zeros(traj_fine.shape[:-1] + (1,))
    cumulative = jnp.concatenate([zeros, jnp.cumsum(traj_fine, axis=-1) * (Lx / nx_fine)], axis=-1)
    flat = cumulative.reshape(-1, nx_fine + 1)
    cum_coarse = vmap(lambda c: jnp.interp(edges_coarse, edges_fine, c))(flat)
    cum_coarse = cum_coarse.reshape(traj_fine.shape[:-1] + (nx_coarse + 1,))
    return jnp.diff(cum_coarse, axis=-1) / (Lx / nx_coarse)


mse_muscl = []
mse_ml = []
mse_mlgs = []

_, params = load_training_params(nx, sim_params, training_params_nx, model)

# The solvers and their t_inner steps do not depend on T, so build them once.
inner_fn_muscl = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
sim_model = EulerFVSim(core_params, sim_params, global_stabilization=False, model=model, params=params)
inner_fn_model = get_inner_fn(sim_model.step_fn, sim_model.dt_fn, t_inner)
sim_model_gs = EulerFVSim(core_params, sim_params, global_stabilization=True, model=model, params=params)
inner_fn_model_gs = get_inner_fn(sim_model_gs.step_fn, sim_model_gs.dt_fn, t_inner)

for T in Ts:

    key = jax.random.PRNGKey(20)  # same N initial conditions for every T
    outer_steps = int(round(T / t_inner))  # round guards against float division error

    trajectory_fn_muscl = get_trajectory_fn(inner_fn_muscl, outer_steps)
    trajectory_fn_model = get_trajectory_fn(inner_fn_model, outer_steps)
    trajectory_fn_model_gs = get_trajectory_fn(inner_fn_model_gs, outer_steps)

    mse_muscl_T = 0.0
    mse_ml_T = 0.0
    mse_mlgs_T = 0.0

    for n in range(N):

        f_init = f_init_sum_of_amplitudes(core_params, key, **kwargs_init)
        a0 = get_a0(f_init, core_params, nx)
        a0_exact = get_a0(f_init, core_params, nx_exact)

        trajectory_muscl = trajectory_fn_muscl(a0)
        trajectory_model = trajectory_fn_model(a0)
        trajectory_model_gs = trajectory_fn_model_gs(a0)

        # "Exact" reference: MUSCL at nx_exact, averaged onto the nx grid.
        exact_trajectory = trajectory_fn_muscl(a0_exact)
        exact_trajectory_ds = downsample_trajectory(exact_trajectory, nx, core_params.Lx)

        mse_muscl_T += normalized_mse(trajectory_muscl, exact_trajectory_ds) / N
        mse_ml_T += normalized_mse(trajectory_model, exact_trajectory_ds) / N
        mse_mlgs_T += normalized_mse(trajectory_model_gs, exact_trajectory_ds) / N

        key, _ = jax.random.split(key)

    mse_muscl.append(mse_muscl_T)
    mse_ml.append(mse_ml_T)
    mse_mlgs.append(mse_mlgs_T)

In [None]:
fig, axs = plt.subplots(1, 1, figsize=(7, 3.25))
axs.spines['top'].set_visible(False)
axs.spines['right'].set_visible(False)
linewidth = 3

Ts = [1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]

mses = [mse_ml, mse_mlgs, mse_muscl]
labels = ["ML", "ML (Stabilized)", "MUSCL"]
colors = ["blue", "red", "purple", "green"]
linestyles = ["solid", "dashed", "solid", "solid"]

# Unstable (NaN) runs are drawn at 1e7, i.e. far above the plotted y-range.
for mse, label, color, linestyle in zip(mses, labels, colors, linestyles):
    axs.loglog(Ts, [jnp.nan_to_num(error, nan=1e7) for error in mse], label=label, color=color, linewidth=linewidth, linestyle=linestyle)

axs.set_xticks(Ts)
axs.set_xticklabels(["t=1", "2", "5", "10", "20", "50", "t=100"], fontsize=18)
axs.set_yticks([1e-3, 1e-2, 1e-1, 1e0])
axs.set_yticklabels(["$10^{-3}$", "$10^{-2}$", "$10^{-1}$", "$10^0$"], fontsize=18)
axs.minorticks_off()
axs.set_ylabel("Normalized MSE", fontsize=18)
# Label with the resolution actually used by the previous cell instead of a hard-coded value.
axs.text(0.15, 0.8, '$N={}$'.format(nx), transform=axs.transAxes, fontsize=18, verticalalignment='top')

handles = [
    mlines.Line2D([], [], color=color, linewidth=linewidth, label=label, linestyle=linestyle)
    for label, color, linestyle in zip(labels, colors, linestyles)
]
axs.legend(handles=handles, loc=(0.6, 0.1), prop={'size': 15}, frameon=False)
plt.ylim([2.5e-4, 1e0+1e-1])
fig.tight_layout()


#plt.savefig('mse_vs_time_euler.png')
#plt.savefig('mse_vs_time_euler.eps')
plt.show()