In [None]:
import sys
#sys.path.append('/home/mcgreivy/InvariantPreservingMLSolvers/2d_incompressible_euler/ml')
#sys.path.append('/home/mcgreivy/InvariantPreservingMLSolvers/2d_incompressible_euler/baselines')
#sys.path.append('/home/mcgreivy/InvariantPreservingMLSolvers/2d_incompressible_euler/simulate')
sys.path.append('/Users/nickm/thesis/InvariantPreservingMLSolvers/2d_incompressible_euler/ml')
sys.path.append('/Users/nickm/thesis/InvariantPreservingMLSolvers/2d_incompressible_euler/baselines')
sys.path.append('/Users/nickm/thesis/InvariantPreservingMLSolvers/2d_incompressible_euler/simulate')

import jax
import jax.numpy as jnp
import numpy as onp
from jax import config
config.update("jax_enable_x64", True)
import xarray
import seaborn as sns
import matplotlib.pyplot as plt

from initialconditions import init_fn_FNO, init_fn_jax_cfd
from simulations import KolmogorovFiniteVolumeSimulation
from simparams import FiniteVolumeSimulationParams

from helper import convert_FV_representation
from trajectory import get_trajectory_fn, get_inner_fn
from flux import Flux

from model import LearnedFlux2D
from mlparams import ModelParams, TrainingParams
from trainingutils import init_params, save_training_data, save_training_params, load_training_params
from trainingutils import get_loss_fn, get_batch_fn, get_idx_gen, train_model, compute_losses_no_model

#########################
# HYPERPARAMS
#########################

# --- run identifiers ---
simname = "ten_burnin"
train_id = "ten_burnin"

# --- solver / physics ---
cfl_safety = 0.3
Lx = Ly = 2 * jnp.pi    # domain side lengths (x and y)
viscosity = 0.0         # alternative value kept from earlier runs: 1/1000
forcing_coeff = 0.0     # alternative value kept from earlier runs: 1.0
drag = 0.0              # alternative value kept from earlier runs: 0.1

# --- initial-condition parameters ---
max_velocity = 7.0
ic_wavenumber = 2

# --- ML model / training ---
batch_size = 100
learning_rate = 1e-4
num_epochs = 100
kernel_size = 5
depth = 3
width = 4

# --- grids, rollout length and timing ---
# NOTE: a later cell reassigns t_inner and outer_steps for its demo rollout.
nx_exact = ny_exact = 128
outer_steps = 100
n_runs = 1
t_inner = 0.01
t_burnin = 40.0
nxs = [32, 64]

# --- directories (cluster paths kept commented out for reference) ---
#basedir = "/home/mcgreivy/InvariantPreservingMLSolvers/2d_incompressible_euler"
#readwritedir = "/scratch/gpfs/mcgreivy/InvariantPreservingMLSolvers/2d_incompressible_euler"
basedir = '/Users/nickm/thesis/InvariantPreservingMLSolvers/2d_incompressible_euler'
readwritedir = '/Users/nickm/thesis/InvariantPreservingMLSolvers/2d_incompressible_euler'

#########################
# END HYPERPARAMS
#########################

plot_dir = f"{readwritedir}/data/plots"



In [None]:
def plot_trajectory_fv(trajectory, sim_params, t_inner):
    """Plot every snapshot of a finite-volume trajectory as a faceted image grid.

    Args:
        trajectory: array laid out as (time, x, y) -- one 2D field per saved step.
        sim_params: not read here; accepted so all plotting helpers share a signature.
        t_inner: simulated time between consecutive saved snapshots.
    """
    n_cells = trajectory.shape[1]
    # Cell positions along one axis; the same coordinates are used for x and y.
    cell_positions = jnp.arange(n_cells) * Lx / n_cells
    snapshot_times = t_inner * jnp.arange(trajectory.shape[0])
    field = xarray.DataArray(
        trajectory,
        dims=["time", "x", "y"],
        coords={'x': cell_positions, 'y': cell_positions, 'time': snapshot_times},
    )
    field.plot.imshow(col='time', col_wrap=5, cmap=sns.cm.icefire, robust=True)

def plot_trajectory_dg(trajectory, old_sim_params, t_inner=1.0):
    """Plot a discontinuous-Galerkin trajectory after converting it to order-0 cells.

    NOTE(review): this helper looks carried over from a DG notebook.
    `make_dg_sim_params` and `batch_convert_DG_representation` are neither
    defined nor imported in this notebook. Calling it raises NameError unless
    they are brought into scope first.

    Args:
        trajectory: DG trajectory; axis 1 is taken as the element count along x.
        old_sim_params: parameters of `trajectory`; only `.order` is read.
        t_inner: time between saved snapshots, forwarded to `plot_trajectory_fv`.
            Defaults to 1.0, so the time axis shows the snapshot index when
            no spacing is given. The original call omitted this argument,
            which `plot_trajectory_fv` requires (it always raised TypeError).
    """
    nx = trajectory.shape[1]
    new_order = 0
    # Target grid has nx * (order + 1) cells per axis -- presumably one order-0
    # cell per DG degree of freedom; confirm against the DG conversion helper.
    n_fine = nx * (old_sim_params.order + 1)
    sim_params = make_dg_sim_params(n_fine, n_fine, new_order)
    trajectory_fv = batch_convert_DG_representation(trajectory, old_sim_params.order, sim_params)[..., 0]
    plot_trajectory_fv(trajectory_fv, sim_params, t_inner)
    
def get_sim_params(nx, ny, global_stabilization=False, energy_conserving=False):
    """Build finite-volume simulation parameters for an nx-by-ny grid.

    The time stepper is fixed to 'ssp_rk3' and the flux to Flux.VANLEER.
    Names, directories, domain size and CFL safety factor come from the
    hyperparameter cell. The two boolean flags are passed straight through.
    """
    return FiniteVolumeSimulationParams(
        simname,
        basedir,
        readwritedir,
        nx,
        ny,
        Lx,
        Ly,
        cfl_safety,
        'ssp_rk3',
        Flux.VANLEER,
        global_stabilization,
        energy_conserving,
    )

def get_simulation(sim_params, model=None, params=None):
    """Create a KolmogorovFiniteVolumeSimulation for `sim_params`.

    Viscosity, forcing coefficient and drag are read from the hyperparameter
    cell. `model`/`params` are optional; this notebook passes a LearnedFlux2D
    and its initialised parameters.
    """
    simulation = KolmogorovFiniteVolumeSimulation(
        sim_params, viscosity, forcing_coeff, drag, model=model, params=params
    )
    return simulation

def get_trajectory(sim_params, v0):
    """Roll out a model-free simulation from `v0` for `outer_steps` calls of `step_fn`.

    `v0` is first converted to the finite-volume representation of
    `sim_params`. `outer_steps` is the notebook-global value at call time.
    """
    fv_init = convert_FV_representation(v0, sim_params)
    rollout = get_trajectory_fn(get_simulation(sim_params).step_fn, outer_steps)
    return rollout(fv_init)

def get_ml_params():
    """Bundle the notebook's ML hyperparameters into a ModelParams object."""
    hyperparams = (train_id, batch_size, learning_rate, num_epochs, kernel_size, depth, width)
    return ModelParams(*hyperparams)

def get_model():
    """Construct a LearnedFlux2D from the notebook's ML hyperparameters.

    Parameters are initialised separately (see `init_params`).
    """
    return LearnedFlux2D(get_ml_params())

In [None]:
# Demo rollout settings. These REASSIGN the hyperparameter-cell globals
# (t_inner 0.01 -> 0.1, outer_steps 100 -> 11). Helpers such as
# get_trajectory read them at call time, so they are affected after this cell runs.
t_inner = 0.1
outer_steps = 11

nx = 32

# One learned-flux model with parameters from init_params (not trained in this
# notebook), shared by all three solver variants so only the solver flags differ.
model = get_model()
key_init = jax.random.PRNGKey(42)
i_params = init_params(key_init, model)


def _build_rollout(variant_params):
    """Return (simulation, inner_fn, jitted rollout_fn) for `variant_params` using the shared model."""
    sim = get_simulation(variant_params, model=model, params=i_params)
    inner = get_inner_fn(sim.step_fn, sim.dt_fn, t_inner)
    rollout = jax.jit(get_trajectory_fn(inner, outer_steps))
    return sim, inner, rollout


# Variants: baseline learned flux; + global stabilization; + global stabilization & energy conservation.
sim_params = get_sim_params(nx, nx)
simulation, inner_fn, rollout_fn = _build_rollout(sim_params)

sim_params_gs = get_sim_params(nx, nx, global_stabilization=True)
simulation_gs, inner_fn_gs, rollout_fn_gs = _build_rollout(sim_params_gs)

sim_params_ec = get_sim_params(nx, nx, global_stabilization=True, energy_conserving=True)
simulation_ec, inner_fn_ec, rollout_fn_ec = _build_rollout(sim_params_ec)

# Initial vorticity. Uses the configured max_velocity / ic_wavenumber
# (7.0 / 2) instead of duplicating those literals here.
key_data = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key_data)
vorticity0 = init_fn_jax_cfd(subkey, sim_params, max_velocity, ic_wavenumber)

In [None]:
# Roll out all three solver variants from the same initial vorticity.
trajectory, trajectory_gs, trajectory_ec = (
    rollout_fn(vorticity0),
    rollout_fn_gs(vorticity0),
    rollout_fn_ec(vorticity0),
)

In [None]:
# One snapshot grid per variant: baseline, global stabilization, energy conserving.
for _traj, _variant_params in (
    (trajectory, sim_params),
    (trajectory_gs, sim_params_gs),
    (trajectory_ec, sim_params_ec),
):
    plot_trajectory_fv(_traj, _variant_params, t_inner)

In [None]:
def l2_norm(a):
    """Return the mean of the squared entries of `a`.

    Despite the name, no square root is taken: this is a mean-square value.
    """
    mean_square = jnp.mean(a ** 2)
    return mean_square

# Per-snapshot mean-square value, vmapped over the leading (time) axis.
traj_l2_norm = jax.vmap(l2_norm)

# Explicit axes plus labels/legend: the three colored curves were otherwise
# unidentifiable, and the trailing ylim repr leaked into the cell output.
fig, ax = plt.subplots()
ax.plot(traj_l2_norm(trajectory), color="blue", label="learned flux")
ax.plot(traj_l2_norm(trajectory_gs), color="green", label="+ global stabilization")
ax.plot(traj_l2_norm(trajectory_ec), color="red", label="+ global stabilization + energy conserving")
ax.set_ylim([0, 30])
ax.set(
    xlabel=f"snapshot index (spacing t_inner = {t_inner})",
    ylabel="mean(ω²)",
    title="Mean-square vorticity",
)
ax.legend();

In [None]:
from poissonsolver import get_poisson_solve_fn_fv
from jax import jit, vmap

@vmap
def get_energy(H):
    """Domain-integrated squared velocity, batched over the leading axis by `vmap`.

    Each (unbatched) `H` must be rank 3 (nx, ny, k). Components 0, 1 and 3 of
    the last axis are read. NOTE(review): presumably streamfunction values,
    with component 1 offset in x and component 3 offset in y from component 0
    -- confirm against get_poisson_solve_fn_fv.

    Returns Lx * Ly * mean(u_x**2 + u_y**2). No factor of 1/2 is applied.
    """
    nx, ny, _ = H.shape
    base = H[:, :, 0]
    vel_y = -(H[:, :, 1] - base) / (Lx / nx)
    vel_x = (H[:, :, 3] - base) / (Ly / ny)
    # Keep the original multiplication order (mean * Lx * Ly) for identical rounding.
    return jnp.mean(vel_x**2 + vel_y**2) * Lx * Ly

    

# Poisson solve applied per snapshot (vmapped over the leading time axis).
# NOTE(review): built from the baseline `sim_params` and reused for the _gs/_ec
# trajectories. All three share the same nx x nx grid; presumably the solve
# depends only on the grid and not on the stabilization flags -- confirm.
f_poisson = vmap(jit(get_poisson_solve_fn_fv(sim_params)))

# Same pipeline for every solver variant: Poisson solve, then per-snapshot energy.
H, H_gs, H_ec = (f_poisson(traj) for traj in (trajectory, trajectory_gs, trajectory_ec))
energy, energy_gs, energy_ec = (get_energy(h) for h in (H, H_gs, H_ec))

fig, ax = plt.subplots()
ax.plot(energy, color="blue", label="learned flux")
ax.plot(energy_gs, color="green", label="+ global stabilization")
ax.plot(energy_ec, color="red", label="+ global stabilization + energy conserving")
ax.set_ylim([0, 250])
ax.set(
    xlabel=f"snapshot index (spacing t_inner = {t_inner})",
    ylabel="Lx·Ly·mean(|u|²)",
    title="Energy",
)
ax.legend();