# 1.1 Build & evaluate an Emulator

Given a dynamical system $F: x_t \mapsto x_{t+1} := F(x_t)$, we want to learn an emulator $G \approx F$, parameterized as a neural network, from trajectories sampled from $F$.

Evaluate the quality of $G$ using:
- L2 or other error metric on 1-step updates
- L2 or other metric on longer rollouts
- Compare Lyapunov spectrum to original model
- Compare PSD etc. to original model
For each of these metrics, there are baselines one can compare to from the literature.


In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device, as_tensor

# Hard-coded absolute paths under /gpfs (machine-specific); adjust when running elsewhere.
# res_dir is passed to load_model_from_exp_conf (presumably where trained runs are stored).
res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'
# data_dir is passed to get_data (called with save_sim=False in this notebook).
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'

### pick a (trained) emulator

In [None]:
from L96_emulator.run import setup

# experiments to use:
# dt=0.05  : exp_id = 26 for minimal, exp_id=27 for bilinear net
# dt=0.0125: exp_id = 28 for minimal, exp_id=29 for bilinear net
# NOTE(review): exp_id below is not in the list above -- confirm which run is intended.


exp_id = 30

# Experiment configs live in 'experiments/<name>.yml'. Select the file whose name starts
# with the two-character prefix str(exp_id) (so only two-digit ids are supported).
# Only .yml files are considered (setup() below loads '<name>.yml'), and names are
# sorted so the choice is deterministic if several files share the prefix.
exp_names = sorted(name for name in os.listdir('experiments/') if name.endswith('.yml'))
matching_exps = [name for name in exp_names if name[:2] == str(exp_id)]
if not matching_exps:
    raise FileNotFoundError(f'no experiment config for exp_id={exp_id} in experiments/')
conf_exp = os.path.splitext(matching_exps[0])[0]

args = setup(conf_exp=f'experiments/{conf_exp}.yml')
args.pop('conf_exp')

K, J = args['K'], args['J']

In [None]:
# The analytic reference network below is built with args['dt'], so the trained
# network must use the same step.
assert args['dt_net'] == args['dt']
# L96 parameters passed to the numpy RHS and the analytic network: F is 10 when fast
# variables are present (J > 0) and 8 otherwise; h, b, c are fixed.
F, h, b, c = (10. if J > 0 else 8.), 1., 10., 10.


### load and instantiate the emulator model and an 'optimal' comparison available for L96

In [None]:
import torch 
import numpy as np
from L96_emulator.eval import load_model_from_exp_conf
from L96_emulator.networks import named_network


# Evaluate with the RK4 forwarder; a network trained with 'valid' (local) padding is
# applied globally with circular padding instead.
args['model_forwarder'] = 'rk4_default'
if args['padding_mode'] == 'valid':
    print('switching from local training to global evaluation')
    args['padding_mode'] = 'circular'
model, model_forwarder, training_outputs = load_model_from_exp_conf(res_dir, args)

# Learning curve of the loaded run, when load_model_from_exp_conf returns one.
if training_outputs is not None:
    training_loss = training_outputs['training_loss']
    validation_loss = training_outputs['validation_loss']

    seq_length = args['seq_length']
    curve_label = f'{conf_exp} ({seq_length * (J+1)}-dim)'

    fig = plt.figure(figsize=(8,8))
    plt.semilogy(validation_loss, label=curve_label)
    plt.title('training')
    plt.ylabel('validation error')
    plt.legend()
    fig.patch.set_facecolor('xkcd:white')
    plt.show()

# Upper bound: re-implementation of L96 as a torch network
# (conv1d -> pointwise square -> conv1d) with analytically initialized weights.
ubo_kwargs = dict(
    filters=[0],
    kernel_sizes=[4],
    init_net='analytical',
    K_net=K,
    J_net=J,
    dt_net=args['dt'],
    l96_F=F,
    l96_h=h,
    l96_b=b,
    l96_c=c,
    model_forwarder='rk4_default',
    padding_mode='circular',
)
model_ubo, model_forwarder_ubo = named_network(
    model_name='MinimalConvNetL96',
    n_input_channels=J+1,
    n_output_channels=J+1,
    seq_length=1,
    **ubo_kwargs,
)

# evaluate errors

In [None]:
from L96_emulator.eval import sortL96fromChannels, sortL96intoChannels
# Uninitialized 1-D array of length K*(J+1), passed to the numpy L96 right-hand sides f1/f2.
dX_dt = np.empty((K * (J + 1),), dtype=dtype_np)

### settings for different solvers

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default

# Integration step per solver: predictor-corrector uses a 10x smaller step than RK4,
# which steps at args['dt'].
dts = {}
dts[Model_forwarder_predictorCorrector] = args['dt'] / 10
dts[Model_forwarder_rk4default] = args['dt']

### load some data to obtain realistic test system states

In [None]:
from L96_emulator.run import sel_dataset_class
from L96_emulator.util import predictor_corrector, rk4_default, get_data

# Simulation / dataset settings from the experiment config.
spin_up_time = args['spin_up_time']
train_frac = args['train_frac']
normalize_data = bool(args['normalize_data'])
T, N_trials, dt = args['T'], args['N_trials'], args['dt']

# Re-simulate L96 with RK4 (save_sim=False).
out, _ = get_data(K=K, J=J, T=T, dt=dt, N_trials=N_trials, F=F, h=h, b=b, c=c,
                  resimulate=True, solver=rk4_default,
                  save_sim=False, data_dir=data_dir)

# 'state' prediction dataset with offset lead_time=1 over the index window
# [spin-up steps, train_frac * total steps).
prediction_task = 'state'
lead_time = 1
DatasetClass = sel_dataset_class(prediction_task=prediction_task, N_trials=N_trials, local=False)
first_idx = int(spin_up_time / dt)
last_idx = int(np.floor(T / dt * train_frac))
dg_train = DatasetClass(data=out, J=J, offset=lead_time, normalize=normalize_data,
                        start=first_idx, end=last_idx)


### set up ground-truth simulator code

In [None]:
from L96sim.L96_base import f1, f2, pf2

def _l96_rhs_two_level(t, x):
    """Ground-truth L96 tendencies with fast variables (J > 0); ``t`` is ignored."""
    return f2(x, F, h, b, c, dX_dt, K, J)


def _l96_rhs_one_level(t, x):
    """Ground-truth L96 tendencies without fast variables (J == 0); ``t`` is ignored."""
    return f1(x, F, dX_dt, K)


# Ground-truth right-hand side fun(t, x), chosen once from J.
fun = _l96_rhs_two_level if J > 0 else _l96_rhs_one_level

## direct error on tendencies (rhs of diff.eq.)

In [None]:
# Candidate initial states: every time index of the simulation, each paired with a
# randomly drawn trial.
# NOTE(review): indices start at 0, i.e. they include the spin-up period -- confirm intended.
T_start = np.arange(int(T/dt)) # grab initial states for rollout from long-running simulations
i_trial = np.random.choice(N_trials, size=T_start.shape)
# About 3 evenly spaced start indices for printing/plotting. The step is clamped to >= 1
# because np.arange raises on a zero step (fewer than 3 start points).
idx_show = np.arange(0, len(T_start)-1, max(1, len(T_start)//3))

In [None]:
MSEs = np.zeros(len(T_start))
MSEs_ubo = np.zeros(len(T_start))

print('\n')
print('MSEs are on differential equation (tendencies) !')
print('\n')

# One state at a time: the numpy diff.eq. implementation cannot necessarily handle
# batched (parallel) evaluation. Evaluation only, so no autograd graphs are needed.
with torch.no_grad():
    for i in range(len(T_start)):
        inputs = out[i_trial[i], T_start[i]] if N_trials > 1 else out[T_start[i]]
        inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))

        out_np = fun(0., inputs)
        out_model = model.forward(inputs_torch).detach().cpu().numpy()
        out_ubo = model_ubo.forward(inputs_torch).detach().cpu().numpy()

        MSEs[i] = ((out_np - sortL96fromChannels(out_model))**2).mean()
        MSEs_ubo[i] = ((out_np - sortL96fromChannels(out_ubo))**2).mean()


print('MSEs              ', MSEs[idx_show])
print('MSEs - upper bound', MSEs_ubo[idx_show])

# Scaling by dt**2 gives the MSE of dt * tendency (a one-step increment scale).
print('\n')
print('MSEs (* dt)              ', MSEs[idx_show]*dt**2)
print('MSEs (* dt) - upper bound', MSEs_ubo[idx_show]*dt**2)


plt.plot(np.sort(MSEs), label='learned')
plt.plot(np.sort(MSEs_ubo), label='analytic')
plt.title('comparison of MSEs (sorted), learned and analytical')
plt.legend()
plt.show()

## error on resolvent (1-step integration error) for different solvers

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default
import torch 

print('\n')
print('MSEs are on system state !')
print('\n')

# Per-start-point errors of the learned and the analytic network.
n_starts = len(T_start)
MSEs, MSEs_ubo = np.zeros(n_starts), np.zeros(n_starts)

class Torch_solver(torch.nn.Module):
    """Wrap a numerical right-hand side ``fun(t, x)`` (numpy/numba/Julia code) as a torch module.

    ``forward`` takes a channel-sorted tensor, evaluates ``fun`` at t=0 on the flattened
    numpy state and returns the result re-sorted into channels -- as a numpy array, not a
    tensor. The input is detached, so no gradients flow through this wrapper.
    """

    def __init__(self, fun):
        # nn.Module.__init__ must run first: it creates the module's internal state
        # (_parameters, _modules, hooks, ...). Without it, calling the module,
        # .parameters(), .to(), etc. raise AttributeError.
        super().__init__()
        self.fun = fun

    def forward(self, x):
        x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
        return sortL96intoChannels(np.atleast_2d(self.fun(0., x)), J=J)

for MFWD in [Model_forwarder_predictorCorrector, Model_forwarder_rk4default]:
    # Same solver and step size for the numpy ground truth, the learned net and the
    # analytic net, so differences come only from the right-hand sides.
    model_forwarder_np = MFWD(Torch_solver(fun), 
                              dt=dts[MFWD])
    model_forwarder = MFWD(model=model, dt=dts[MFWD])
    model_forwarder_ubo = MFWD(model=model_ubo, dt=dts[MFWD])

    # Reset both error arrays per solver so no values carry over between solvers.
    MSEs = np.zeros(len(T_start))
    MSEs_ubo = np.zeros(len(T_start))
    with torch.no_grad():  # evaluation only: skip autograd graph construction
        for i in range(len(T_start)):
            inputs = out[i_trial[i], T_start[i]] if N_trials > 1 else out[T_start[i]]
            inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))

            out_np = model_forwarder_np(inputs_torch)
            out_model = model_forwarder(inputs_torch)
            out_ubo = model_forwarder_ubo(inputs_torch)

            MSEs[i] = ((out_np - out_model)**2).mean().detach().cpu().numpy()
            MSEs_ubo[i] = ((out_np - out_ubo)**2).mean().detach().cpu().numpy()

    print('\n')
    print(f'solver {MFWD}, dt = {dts[MFWD]}')
    print('\n')

    print('MSEs              ', MSEs)
    print('MSEs - upper bound', MSEs_ubo)

    plt.figure(figsize=(8,5))
    plt.plot(np.sort(MSEs), label='learned')
    plt.plot(np.sort(MSEs_ubo), label='analytic')
    plt.title('comparison of MSEs (sorted), learned and analytical')
    plt.legend()
    plt.show()

## multi-step integration error (rollout error) for different solvers

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout

print('\n')
print('MSEs are on system state !')
print('\n')


# Rollout duration in model time units.
# NOTE(review): intended as a rough Lyapunov-time scale; a duration corresponds to
# ~1 / (leading Lyapunov exponent), not to the exponent itself -- confirm the value.
MTU = 10


def _denormalized_rollout(simulate, start_idx, n_steps):
    """Roll out ``simulate`` from dataset sample ``start_idx``; return it de-normalized.

    The result is rescaled with dg_train.std / dg_train.mean and converted with
    sortL96fromChannels.
    """
    y0 = dg_train[start_idx].copy()
    dy0 = dg_train[start_idx] - dg_train[start_idx - dg_train.offset]
    trajectory = simulate(y0=y0, dy0=dy0, n_steps=n_steps)
    return sortL96fromChannels(trajectory * dg_train.std + dg_train.mean)


for i in idx_show:
    print(f'integrating for starting point {i+1} / {len(T_start)}')
    for MFWD in [Model_forwarder_predictorCorrector, Model_forwarder_rk4default]:
        step = dts[MFWD]
        print(f'solver {MFWD}, dt = {step}')
        T_dur = int(MTU / step)

        model_forwarder_np = MFWD(Torch_solver(fun), dt=step)
        model_forwarder = MFWD(model=model, dt=step)
        model_forwarder_ubo = MFWD(model=model_ubo, dt=step)

        model_simulate = get_rollout_fun(dg_train, model_forwarder, prediction_task)
        ubo_simulate = get_rollout_fun(dg_train, model_forwarder_ubo, prediction_task)
        np_simulate = get_rollout_fun(dg_train, model_forwarder_np, prediction_task)

        out_np = _denormalized_rollout(np_simulate, T_start[i], T_dur)
        out_model = _denormalized_rollout(model_simulate, T_start[i], T_dur)
        out_ubo = _denormalized_rollout(ubo_simulate, T_start[i], T_dur)

        fig = plot_rollout(out_np, out_model, out_comparison=out_ubo, n_start=0, n_steps=T_dur, K=K)
        plt.subplot(1,2,2)
        plt.legend(['trained model', '(only slow vars)', 'upper-bound model', '(only slow vars)'])
        plt.suptitle('integration scheme: ' + str(MFWD))
        plt.show()


# check long-term stability
- simulate long trajectories using different emulators/simulators
- do so in small segments, checking repeatedly for divergence and keeping memory footprint limited
- also accumulate statistics along trajectories

In [None]:
i = -1 # pick one starting point for a long simulation
n_rep = 100  # number of consecutive segments of the long simulation
T_dur = 10 if J>0 else 100  # segment duration in time units (int(T_dur/dt) steps below)

# Histogram bin edges (50 each) for slow (first K columns) and fast (remaining columns)
# state values, taken from the first trial after spin-up. `out` only has a leading
# trial axis when N_trials > 1 (the MSE loops use `out[T_start[i]]` otherwise), so
# pick the reference trajectory accordingly instead of always indexing out[0, ...].
ref_traj = (out[0] if N_trials > 1 else out)[int(spin_up_time/dt):]

bins_K = np.linspace(ref_traj[:, :K].min(), ref_traj[:, :K].max(), 50)

if J > 0:
    bins_J = np.linspace(ref_traj[:, K:].min(), ref_traj[:, K:].max(), 50)
else: 
    bins_J = np.array([])

def calc_state_pdf(out, T_start=0, T_end=-1, bins_K=None, bins_J=None, n_bins=100):
    """Histogram densities of slow and fast state values over a time window.

    Uses the module-level ``K`` and ``J``.

    Args:
        out: trajectory of shape (T, K*(J+1)); columns [:K] hold the slow variables,
            columns [K:] the fast ones.
        T_start, T_end: time window ``out[T_start:T_end]``. Note that the default
            T_end=-1 excludes the final time step; pass None to include it.
        bins_K, bins_J: histogram bin edges; if None, ``n_bins`` equally spaced edges
            spanning the data in the window.
        n_bins: number of edges used when bins are not given.

    Returns:
        (pdf_K, pdf_J, bins_K, bins_J) from ``np.histogram(..., density=True)``.
        For J == 0, pdf_J is None and bins_J is an empty array.
    """
    window = out[T_start:T_end]
    # Select variables by COLUMN. Chained indexing (out[a:b][:K]) would slice the
    # time axis a second time and mix slow and fast variables.
    out_K = window[:, :K]
    bins_K = np.linspace(out_K.min(), out_K.max(), n_bins) if bins_K is None else bins_K
    pdf_K, bins_K = np.histogram(out_K, bins=bins_K, density=True)

    if J > 0:
        out_J = window[:, K:]
        bins_J = np.linspace(out_J.min(), out_J.max(), n_bins) if bins_J is None else bins_J
        pdf_J, bins_J = np.histogram(out_J, bins=bins_J, density=True)
    else:
        pdf_J, bins_J = None, np.array([])
    return pdf_K, pdf_J, bins_K, bins_J

def iter_solve_and_stats(out, simulate_fun, T_dur, n_rep, bins_K, bins_J):
    """Run a long simulation as ``n_rep`` chained segments and collect state histograms.

    Each segment restarts from the last state of the previous one, so only one segment
    is kept in memory, and every segment is checked for divergence.

    Args:
        out: initial trajectory in the (normalized) dataset space that ``simulate_fun``
            expects; only ``out[-1:]`` is used.
        simulate_fun: rollout function called as simulate_fun(y0=..., dy0=None, n_steps=...).
        T_dur: number of steps per segment.
        n_rep: number of segments.
        bins_K, bins_J: histogram bin edges forwarded to ``calc_state_pdf``.

    Returns:
        (last segment de-normalized with dg_train.std / dg_train.mean,
         list of slow-variable pdfs, list of fast-variable pdfs; the latter stays
         empty when J == 0). For n_rep == 0 the input ``out`` is returned unchanged.
    """
    pdf_Ks, pdf_Js = [], []
    state = out  # normalized trajectory; its last entry seeds the next segment
    for n in range(n_rep):
        print(f'- {n+1} / {n_rep}')
        state = simulate_fun(y0=state[-1:].copy(),
                             dy0=None,
                             n_steps=T_dur)
        # De-normalize only a copy for the statistics. The restart must stay in the
        # normalized space of the dataset samples that simulate_fun is seeded with.
        out = state * dg_train.std + dg_train.mean
        # Divergence check: NaN and +-inf both indicate a blown-up simulation.
        assert np.all(np.isfinite(out)), f'non-finite values in segment {n+1} / {n_rep}'
        pdf_K_n, pdf_J_n, _, _ = calc_state_pdf(sortL96fromChannels(out), bins_K=bins_K, bins_J=bins_J)
        pdf_Ks.append(pdf_K_n)
        if J > 0:
            pdf_Js.append(pdf_J_n)
    return out, pdf_Ks, pdf_Js
            
def _long_run(simulate_fun):
    """Segmented long simulation of ``simulate_fun`` from dataset sample T_start[i]."""
    start_state = dg_train[T_start[i]].copy().reshape(1,J+1,K)
    return iter_solve_and_stats(start_state, simulate_fun, int(T_dur/dt), n_rep, bins_K, bins_J)


print('simulating from simulator')
out_np, pdf_K_np, pdf_J_np = _long_run(np_simulate)

print('simulating from emulator')
out_model, pdf_K_model, pdf_J_model = _long_run(model_simulate)

print('simulating from reference emulator')
out_ubo, pdf_K_ubo, pdf_J_ubo = _long_run(ubo_simulate)

# quick visual inspection: last segment of each run, side by side
panels = [(out_np, 'simulator'),
          (out_model, 'learned emulator'),
          (out_ubo, 'upper-bound emulator')]
plt.figure(figsize=(16,6))
for panel_idx, (trajectory, panel_title) in enumerate(panels, start=1):
    plt.subplot(1,3,panel_idx)
    plt.imshow(sortL96fromChannels(trajectory).T, aspect='auto')
    plt.title(panel_title)
    plt.colorbar()
plt.show()

# compare distribution of state values (PDF)
- cf. figure 9 from Chattopadhyay, Hassanzadeh, Subramanian (2019)

In [None]:

def _step_coords(bins):
    """Return [e0, e1, e1, e2, ..., e_{n-1}, e_n] for a step plot over the bin edges ``bins``.

    Paired with ``np.repeat(pdf, 2)``, bin k is drawn exactly over [bins[k], bins[k+1]].
    """
    return np.stack([bins[:-1], bins[1:]], axis=1).ravel()


# Separate step-plot x-coordinates for slow and fast variables. A single shared array,
# reassigned inside the loop, would give later slow-variable panels the fast-variable axis.
xx_K = _step_coords(bins_K)
xx_J = _step_coords(bins_J) if J > 0 else None

Q = 4 # divide long simulation into 4 quarters, track distributions individually
o = n_rep//Q  # segment histograms per quarter (assumes n_rep >= Q)

# One figure; subplots are created per quarter below (no stray placeholder axes).
if J > 0:
    plt.figure(figsize=(16,12))
else:
    plt.figure(figsize=(12,12))


for q in range(Q):
    quarter = slice(q*o, (q+1)*o)
    if J > 0:
        plt.subplot(2,4,2*q+1)
    else:
        plt.subplot(2,2,q+1)
    # histogram of all data is average of (normalized!) histograms
    pdf_K_np_q = np.stack(pdf_K_np[quarter]).mean(axis=0)
    pdf_K_model_q = np.stack(pdf_K_model[quarter]).mean(axis=0)
    pdf_K_ubo_q = np.stack(pdf_K_ubo[quarter]).mean(axis=0)

    plt.semilogy(xx_K, np.repeat(pdf_K_np_q, 2), color='g', label='simulator')
    plt.semilogy(xx_K, np.repeat(pdf_K_ubo_q, 2), color='k', label='upper-bound model')
    plt.semilogy(xx_K, np.repeat(pdf_K_model_q, 2), color='b', label='model')
    if J > 0:
        plt.xlabel('value of slow variables')
    else:
        plt.xlabel('state value')
    plt.ylabel('relative frequency')
    plt.legend()

    if J > 0:
        plt.subplot(2,4,2*(q+1))
        pdf_J_np_q = np.stack(pdf_J_np[quarter]).mean(axis=0)
        pdf_J_model_q = np.stack(pdf_J_model[quarter]).mean(axis=0)
        pdf_J_ubo_q = np.stack(pdf_J_ubo[quarter]).mean(axis=0)
        plt.semilogy(xx_J, np.repeat(pdf_J_np_q, 2), color='g')
        plt.semilogy(xx_J, np.repeat(pdf_J_ubo_q, 2), color='k')
        plt.semilogy(xx_J, np.repeat(pdf_J_model_q, 2), color='b')
        plt.xlabel('value of fast variables')
        plt.ylabel('relative frequency')
    plt.title(f'Quarter {q+1} / {Q}')
plt.suptitle('distribution of state values for long simulation')
plt.show()

# tbd: compare Lyapunov spectrum

In [None]:
# TBD: not implemented yet -- this would require making the PyTorch model callable from Julia.

# share notebook results via html file

In [None]:
# Export this notebook (eval_emulator.ipynb) as HTML into the --output-dir directory for sharing.
!jupyter nbconvert --output-dir='/gpfs/home/nonnenma/projects/lab_coord/mdml_wiki/marcel/emulators' --to html eval_emulator.ipynb