In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directory of stored experiment results; passed to load_model_from_exp_conf below.
# NOTE(review): absolute, cluster-specific path -- adjust when running elsewhere.
res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'
# Directory handed to get_data as `data_dir`.
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'

# Emulator evaluation results

In [None]:
import torch 
from L96_emulator.run import setup, sel_dataset_class
from L96_emulator.eval import sortL96fromChannels, sortL96intoChannels, load_model_from_exp_conf
from L96_emulator.networks import named_network, Model_forwarder_predictorCorrector, Model_forwarder_rk4default
from L96_emulator.util import predictor_corrector, rk4_default, get_data, as_tensor
from L96sim.L96_base import f1, f2, pf2

# experiments to use: 

# for one-level L96, K=40, F=8

# reference (analytical) emulators:
# dt=0.05  : exp_id=34 for minimal, exp_id=35 for bilinear net
# dt=0.0125: exp_id=37 for minimal, exp_id=36 for bilinear net
# full domain training: 
# dt=0.05  : exp_id=26 for minimal, exp_id=27 for bilinear net
# dt=0.0125: exp_id=28 for minimal, exp_id=29 for bilinear net
# local training:
# K_local = 10, batch-size = 32
# dt=0.05  : exp_id=30 for minimal, exp_id=31 for bilinear net
# dt=0.0125: exp_id=32 for minimal, exp_id=33 for bilinear net
# K_local = 1, batch-size = 32
# dt=0.05  : exp_id=38 for minimal, exp_id=39 for bilinear net
# K_local = 1, batch-size = 1
# dt=0.05  : exp_id=40 for minimal, exp_id=41 for bilinear net

# for one-level L96, K=36, F=10

# full domain training: 
# dt=0.01  : exp_id=42 for minimal, exp_id=43 for bilinear net


# Experiment ids evaluated in this notebook (one model is loaded per id).
# Previously used selection: [77, 90, 71, 81, 65]
exp_ids = [
    77,
    95,
    97,
    90,
    35,
]

# A single model group covering every listed experiment: each entry of
# exp_id_model_sorted holds indices into exp_ids, and model_names titles it.
model_names = ['emulator training']
exp_id_model_sorted = [np.arange(0, len(exp_ids), 1)]


# initial one-level L96, K=40, F=8 with various degrees of locality: 
#exp_ids = [34, 26, 30, 38, 40, 35, 27, 31, 39, 41]
#exp_id_model_sorted = [np.arange(0,5), np.arange(5,10)]
#model_names = ['quadratic nonlinearity network', 'bilinear layer network']

# 4x4 one-level L96, K=40, F=8 with various degrees of locality and batch-size
#exp_ids = np.concatenate((np.arange(49, 61), np.arange(61, 65)))
#exp_id_model_sorted=[np.arange(4), np.arange(4,8), np.arange(8,12), np.arange(12,16), ]
#model_names=['batch-size 1', 'batch-size 4', 'batch-size 16', 'batch-size 64']

#exp_ids = [42, 43, 44]
#exp_id_model_sorted = [np.arange(len(exp_ids))]
#model_names=['emulator training']


# Load every selected experiment, collect model / forwarder / training log,
# and plot the validation-loss curves in one figure.
#
# NOTE: args, K, J, F, h, b, c keep the values of the LAST experiment after
# this loop; the cells below rely on them.
all_lgnd = []
all_models, all_model_forwarders, all_training_outputs = [], [], []

# The directory listing does not change between iterations, so read it once.
exp_names = os.listdir('experiments/')

fig = plt.figure(figsize=(8,8))
for exp_id in exp_ids:

    # A config file belongs to an experiment if its first two characters equal
    # the experiment id; the first match in listing order is used.
    matches = [name for name in exp_names if name[:2] == str(exp_id)]
    if not matches:
        raise FileNotFoundError(
            f"no config file in 'experiments/' starts with experiment id {exp_id}")
    conf_exp = matches[0][:-4]   # drop the 4-character extension ('.yml')

    args = setup(conf_exp=f'experiments/{conf_exp}.yml')
    args.pop('conf_exp')
    args['model_forwarder'] = 'rk4_default'

    K,J = args['K'], args['J']
    assert args['dt_net'] == args['dt']

    # L96 system parameters are not read from the config but fixed here.
    if J > 0:
        F, h, b, c = 10., 1., 10., 10.
    else:
        h, b, c = 1., 10., 10.
        F = 10. if K==36 else 8.

    # Legend entry: number of post-spin-up time steps summed over all trials.
    exp_str = 'N=' + str(args['N_trials'] * int(args['T']/args['dt'] - args['spin_up_time']/args['dt']))

    all_lgnd.append(exp_str)

    if args['padding_mode'] == 'valid':
        print('switching from local training to global evaluation')
        args['padding_mode'] = 'circular'
    model, model_forwarder, training_outputs = load_model_from_exp_conf(res_dir, args)
    all_models.append(model)
    all_model_forwarders.append(model_forwarder)
    all_training_outputs.append(training_outputs)

    # training_outputs may be None; only plot when a log is available.
    if training_outputs is not None:
        plt.semilogy(training_outputs['validation_loss'], label=all_lgnd[-1])
        print('final loss', np.min(training_outputs['validation_loss']))

# ndarray so that legend entries can be fancy-indexed by model group later.
all_lgnd = np.array(all_lgnd)
plt.title('training')
plt.ylabel('validation error')
plt.legend()
fig.patch.set_facecolor('xkcd:white')
plt.show()

# Settings taken from the config of the last experiment loaded above.
dt = args['dt']
train_frac = args['train_frac']
normalize_data = bool(args['normalize_data'])

# Reference simulation: total time T covers the spin-up plus 1000 steps of size dt.
N_trials, spin_up_time = 1000, 50
T = spin_up_time + 1000 * dt

# Preallocated output buffer handed to the L96 tendency functions f1 / f2.
dX_dt = np.empty(shape=K * (J + 1), dtype=dtype_np)

# Freshly simulate (resimulate=True) without writing the result to disk.
out, _ = get_data(
    K=K,
    J=J,
    T=T,
    dt=dt,
    N_trials=N_trials,
    F=F,
    h=h,
    b=b,
    c=c,
    resimulate=True,
    solver=rk4_default,
    save_sim=False,
    data_dir=data_dir,
)

# Tendency function of the reference system, chosen once from the value of J
# at the time this cell runs.  F, h, b, c, K, J and the buffer dX_dt are module
# globals and are looked up when `fun` is CALLED, not when it is defined.
if J <= 0:
    def fun(t, x):
        """Evaluate f1 for state x (one-level case); t is accepted but unused."""
        # presumably f1 writes into and returns dX_dt -- callers copy the result
        return f1(x, F, dX_dt, K)
else:
    def fun(t, x):
        """Evaluate f2 for state x (J > 0 case); t is accepted but unused."""
        return f2(x, F, h, b, c, dX_dt, K, J)


In [None]:
# Time indices after the spin-up period: initial states for the one-step
# tests are grabbed from the long-running simulations at these steps.
T_start = np.arange(start=int(spin_up_time/dt), stop=int(T/dt))

# One trial index per start time, drawn with replacement from the global
# (unseeded) numpy RNG.
i_trial = np.random.choice(N_trials, size=T_start.shape)

# A few evenly spaced positions within T_start.
idx_show = np.arange(start=0, stop=len(T_start)-1, step=len(T_start)//3)

# state-prediction RMSEs

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default
from L96_emulator.networks import Model_forwarder_forwardEuler
import torch 

# Banner; the '\n' entries produce the blank lines around the message.
for banner_line in ('\n', 'MSEs are on system state !', '\n'):
    print(banner_line)

# Forwarder classes to evaluate, and the step size used with each class.
MFWDs = [Model_forwarder_rk4default]
dts = {}
dts[Model_forwarder_predictorCorrector] = args['dt']/10
dts[Model_forwarder_rk4default] = args['dt']

# One RMSE per (forwarder, experiment, test point); filled by the loop below.
RMSEs_states = np.zeros(shape=(len(MFWDs), len(exp_ids), len(T_start)))

class Torch_solver(torch.nn.Module):
    """Expose a numpy tendency function (numpy/numba/Julia) as a torch module.

    Parameters
    ----------
    fun : callable
        Called as ``fun(0., x)`` with a flat numpy state vector; expected to
        return the tendencies as a numpy array.
    """
    def __init__(self, fun):
        # nn.Module keeps its bookkeeping (_modules, hooks, ...) in attributes
        # created by Module.__init__; it has to run before the instance is
        # called or registered as a submodule.
        super().__init__()
        self.fun = fun
    def forward(self, x):
        # x is laid out as (..., J+1, K) by sortL96intoChannels, so J can be
        # read from the input instead of relying on a notebook global.
        J = x.shape[-2] - 1
        x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
        # Copy: `fun` is given a preallocated output buffer, so its return
        # value may be overwritten by the next call (model_np copies as well).
        dx = np.array(self.fun(0., x), copy=True)
        return as_tensor(sortL96intoChannels(np.atleast_2d(dx), J=J))
    
# For every forwarder class: compare the one-step update of each emulator
# against the same forwarder wrapped around the numerical tendency function.
for mf_i, MFWD in enumerate(MFWDs):

    print('\n')
    print(f'solver {MFWD}, dt = {dts[MFWD]}')
    print('\n')

    model_forwarder_np = MFWD(Torch_solver(fun), dt=dts[MFWD])

    # Test inputs and reference updates do not depend on the emulator, so they
    # are computed once per forwarder instead of once per (emulator, point).
    test_inputs, test_targets = [], []
    for i in range(len(T_start)):
        inputs = out[i_trial[i], T_start[i]] if N_trials > 1 else out[T_start[i]]
        inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))
        test_inputs.append(inputs_torch)
        test_targets.append(model_forwarder_np(inputs_torch))

    for m_i, model in enumerate(all_models):

        model_forwarder = MFWD(model=model, dt=dts[MFWD])

        for i, (inputs_torch, out_np) in enumerate(zip(test_inputs, test_targets)):
            out_model = model_forwarder(inputs_torch)

            RMSEs_states[mf_i, m_i, i] = np.sqrt(((out_np - out_model)**2).mean().detach().cpu().numpy())

    # One row per model group: sorted per-point errors (left), box plot (right).
    plt.figure(figsize=(16,12))
    for i in range(len(exp_id_model_sorted)):
        plt.subplot(len(exp_id_model_sorted),2,1+2*i)
        plt.semilogy(np.sort(RMSEs_states[mf_i][exp_id_model_sorted[i]],axis=1).T)
        plt.ylabel('RMSE')
        plt.xlabel('test data point (sorted)')
        plt.title(model_names[i])
        plt.legend(all_lgnd[exp_id_model_sorted[i]])
        plt.subplot(len(exp_id_model_sorted),2,2+2*i)
        plt.ylabel('RMSE')
        plt.boxplot(RMSEs_states[mf_i][exp_id_model_sorted[i]].T, labels=all_lgnd[exp_id_model_sorted[i]])
        plt.title(model_names[i])
    plt.show()

# Jacobian error Frobenius norms

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default
from L96_emulator.util import calc_jakobian_onelevelL96_tendencies, calc_jakobian_rk4, get_jacobian_torch
import torch 

def model_np(inputs):
    """Return a private copy of the tendencies `fun` computes for `inputs`.

    The time argument of `fun` is fixed to 0.  The copy decouples the result
    from whatever array `fun` returns.
    """
    tendencies = fun(0., inputs)
    return tendencies.copy()

# Banner.  The earlier text ('MSEs are on system state !') was copied from the
# state-RMSE cell; this cell reports Frobenius norms of Jacobian errors.
print('\n')
print('Errors are Frobenius norms of Jacobian differences !')
print('\n')

# Forwarder classes to evaluate; their step sizes are looked up in `dts`,
# which is defined in the state-RMSE cell above.
MFWDs = [Model_forwarder_rk4default]
# One error norm per (forwarder, experiment, test point); filled below.
L2_jakobians = np.zeros((len(MFWDs), len(exp_ids), len(T_start)))

  
# For every forwarder class and emulator: Frobenius norm of the difference
# between the Jacobian of the emulator step and the Jacobian of the
# reference RK4 step, evaluated at each test point.
for mf_i, MFWD in enumerate(MFWDs):

    print('\n')
    print(f'solver {MFWD}, dt = {dts[MFWD]}')
    print('\n')

    for m_i, model in enumerate(all_models):

        # Only gradients w.r.t. the inputs are needed here.
        for p in model.parameters():
            p.requires_grad = False

        model_forwarder = MFWD(model=model, dt=dts[MFWD])
        print('\n')
        print(f'model forwarder for {model}')
        print('\n')

        for i in range(len(T_start)):
            inputs = out[i_trial[i], T_start[i]] if N_trials > 1 else out[T_start[i]]
            inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))
            inputs_torch.requires_grad = True

            #J_np = calc_jakobian_onelevelL96_tendencies(inputs, n=K)
            # NOTE(review): uses the one-level tendency Jacobian -- presumably
            # only valid for J == 0; confirm before using two-level configs.
            J_np = calc_jakobian_rk4(inputs, calc_f=model_np,
                         calc_J_f=calc_jakobian_onelevelL96_tendencies, dt=dt, n=K)

            J_torch = get_jacobian_torch(model_forwarder, inputs=inputs_torch, n=K)

            L2_jakobians[mf_i, m_i, i] = np.sqrt(((J_np - J_torch)**2).sum())

    # One row per model group.  The row count follows the number of groups
    # (as in the state-RMSE cell) instead of a hard-coded 2, which raised for
    # more than two groups.
    n_groups = len(exp_id_model_sorted)
    plt.figure(figsize=(16,12))
    for i in range(n_groups):
        plt.subplot(n_groups,2,1+2*i)
        plt.semilogy(np.sort(L2_jakobians[mf_i][exp_id_model_sorted[i]],axis=1).T)
        plt.ylabel('Frobenius norm of error')
        plt.xlabel('test data point (sorted)')
        plt.title(model_names[i])
        plt.legend(all_lgnd[exp_id_model_sorted[i]])
        plt.subplot(n_groups,2,2+2*i)
        plt.ylabel('Frobenius norm of error')
        plt.boxplot(L2_jakobians[mf_i][exp_id_model_sorted[i]].T, labels=all_lgnd[exp_id_model_sorted[i]])
        plt.title(model_names[i])
    plt.show()

# rollout RMSEs

In [None]:
# Lorenz-96 forcing parameter.
# NOTE(review): `fun` reads the global F when it is called, so this assignment
# also changes the forcing of the numerical reference solver used below, while
# the test data is generated with args['F_net'] -- confirm the two agree.
F = 8.0

# Base simulations that provide the initial states for the rollouts.
N_trials = 1000   # number of base simulations
T_data = 1        # length of base simulation after spin-up, in steps (it is multiplied by dt below)

# Rollouts.
n_start = 1000    # number of rollouts
T_dur = 100       # length of rollouts (in steps!)

##########################
#       test data        #
##########################

from L96_emulator.util import rk4_default, get_data

spin_up_time = 5.
T = T_data*args['dt'] + spin_up_time

# Reuse `data` left over from an earlier run of this cell only if its shape
# matches the requested simulation.  Only "data does not exist yet" is treated
# as a cache miss; any other error is no longer silently swallowed.
expected_shape = (N_trials, int(T/args['dt'])+1, args['K']*(args['J']+1))
try:
    data_shape = data.shape
except (NameError, AttributeError):
    data_shape = None

if data_shape == expected_shape:
    print('found data with matching specs (shape)')
else:
    print('generating test data')
    data, _ = get_data(K=args['K'], J=args['J'], T=T, dt=args['dt'], N_trials=N_trials,
                      F=args['F_net'],
                      resimulate=True, solver=rk4_default,
                      save_sim=False, data_dir=data_dir)

##########################
#       rollouts         #
##########################

from L96sim.L96_base import f1

# n_start start indices between the end of the spin-up and the end of the
# base simulation.  `np.int` was only an alias of the builtin `int` and has
# been removed from NumPy, so the builtin is used directly.
T_start = np.linspace(int(spin_up_time/args['dt']), int(T/args['dt']), n_start).astype(int)
# One distinct trial per rollout (sampling without replacement).
i_trial = np.random.choice(N_trials, size=T_start.shape, replace=False)
print('T_start, i_tria', (T_start, i_trial))

class Torch_solver(torch.nn.Module):
    """Expose a numpy tendency function (numpy/numba/Julia) as a torch module.

    Parameters
    ----------
    fun : callable
        Called as ``fun(0., x)`` with a flat numpy state vector; expected to
        return the tendencies as a numpy array.
    """
    def __init__(self, fun):
        # nn.Module keeps its bookkeeping (_modules, hooks, ...) in attributes
        # created by Module.__init__; it has to run before the instance is
        # called or registered as a submodule.
        super().__init__()
        self.fun = fun
    def forward(self, x):
        # Number of extra channels, read from the second-to-last input axis.
        J = x.shape[-2] - 1
        x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
        # Copy: `fun` is given a preallocated output buffer, so its return
        # value may be overwritten by the next call.
        dx = np.array(self.fun(0., x), copy=True)
        return as_tensor(sortL96intoChannels(np.atleast_2d(dx), J=J))


# Rollout storage: one slot per emulator plus a final slot (index -1) for the
# numerical reference; NaN marks entries that have not been filled.
sols = np.nan * np.ones((len(all_model_forwarders)+1, n_start, T_dur+1, args['K']))

# All emulators start from the same initial states, so the batch is built once
# instead of once per model.  args['J'] is used throughout rather than mixing
# it with the notebook-global J.
X_init = []
for i in range(n_start):
    X_init.append(data[i_trial[i], T_start[i]] if N_trials > 1 else data[T_start[i]])
    X_init[-1] = sortL96intoChannels(X_init[-1][:args['K']*(args['J']+1)],J=args['J'])
X_init = np.stack(X_init)
X_init = X_init.reshape(1, *X_init.shape)   # leading axis of length 1 is dropped inside model_simulate

for i_model in range(len(all_model_forwarders)):

    model_forwarder_i = all_model_forwarders[i_model]

    def model_simulate(y0, dy0, n_steps):
        """Roll the whole batch y0 forward n_steps with model_forwarder_i.

        Returns an array whose first axis indexes time (n_steps+1 entries,
        entry 0 being the initial state).  dy0 is accepted but unused.
        """
        x = np.empty((n_steps+1, *y0.shape[1:]))
        x[0] = y0.copy()
        xx = as_tensor(x[0])
        for i in range(1,n_steps+1):
            xx = model_forwarder_i(xx.reshape(-1,args['J']+1,args['K']))
            x[i] = xx.detach().cpu().numpy().copy()
        return x

    print('forwarding model ' + str(model_forwarder_i))
    with torch.no_grad():
        sol = model_simulate(y0=X_init, dy0=None, n_steps=T_dur)
    # Keep channel 0 only and reorder (time, rollout, K) -> (rollout, time, K).
    sols[i_model,:,:,:] = sol[:,:,0,:].transpose(1,0,2)

# not parallelising Numba model over initial values for rollouts
# Reference rollouts: RK4 forwarder wrapped around the numerical tendency
# function.  Unlike the emulators, initial states are processed one at a time.
model_forwarder_np = Model_forwarder_rk4default(Torch_solver(fun), dt=args['dt'])

def model_simulate(y0, dy0, n_steps):
    """Roll a single initial state y0 forward n_steps with model_forwarder_np.

    Returns an array whose first axis indexes time (n_steps+1 entries, entry 0
    being the initial state).  dy0 is accepted but unused.
    """
    state_shape = y0.shape[1:]
    x = np.empty((n_steps+1, *state_shape))
    x[0] = y0.copy()
    xx = as_tensor(x[0]).reshape(1,1,-1)
    for step in range(1, n_steps+1):
        xx = model_forwarder_np(xx.reshape(-1, args['J']+1, args['K']))
        x[step] = xx.detach().cpu().numpy().copy()
    return x

print('forwarding np model')
X_init = []
for i in range(n_start):
    # Same initial states as used for the emulator rollouts.
    if N_trials > 1:
        X_init = data[i_trial[i], T_start[i]]
    else:
        X_init = data[T_start[i]]
    X_init = sortL96intoChannels(X_init, J=args['J'])
    X_init = X_init.reshape(1, *X_init.shape)

    with torch.no_grad():
        sol = model_simulate(y0=X_init, dy0=None, n_steps=T_dur)
    # Channel 0 of every time step goes into the last (reference) slot.
    sols[-1,i,:,:] = sol[:,0,:]

### actual figure

In [None]:
# Summary figure with three panels: rollout RMSE over time (left), mean
# one-step state error (middle) and mean Jacobian error (right).
# The code uses the pyplot state machine, so the statement order matters.
fontsize=14

# Indices of the experiments to plot (all of them).
idx_dn = np.arange(len(exp_ids))
# x positions, one per experiment; shown on an axis labelled
# 'training set size N'.  NOTE(review): values are hard-coded -- presumably the
# training set sizes of exp_ids scaled by a train fraction of 0.8; confirm.
xx = np.array([1200, 4000, 8000, 12000, 120000]) * 0.8

plt.figure(figsize=(14,3.7))
# Results of the first forwarder only (index 0 along the forwarder axis).
RMSES_all = [RMSEs_states[0], L2_jakobians[0]]
titles = [r'state updates $x_{t+\Delta}$', r'Jacobians $J_\mathcal{M}$']
yaxes = ['', 'Frobenius norm of error']
for i in range(len(RMSES_all)):

    RMSEs = RMSES_all[i]

    # Panels 2 and 3 of the 1x3 grid; panel 1 is drawn further below.
    ax = plt.subplot(1,3,i+2)    
    # One marker per experiment: error averaged over all test points.
    plt.plot(xx[:len(idx_dn)], RMSEs[idx_dn,:].mean(axis=1), '*', color='black', label='deepNet')
    # Ticks are cleared here and set again by the plt.xticks call below.
    plt.xticks([], fontsize=fontsize)
    plt.ylabel(yaxes[i], fontsize=fontsize)
    #if np.mod(i,2) == 1:
    #    plt.legend(fontsize=fontsize)
    plt.title(titles[i])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Axis limits and y ticks are hard-coded per panel; markers whose x value
    # lies outside [500, 12500] are not visible.
    if i == 0:
        #plt.axis([500, 110000000, 0.02, 0.2])
        plt.axis([500, 12500, 0, 1.05*RMSEs[idx_dn,:].mean(axis=1).max()])
        plt.yticks([0, 0.010], fontsize=fontsize)
    else:
        plt.axis([500, 12500, 0.05, 1.05*RMSEs[idx_dn,:].mean(axis=1).max()])
        plt.yticks([0.1, 0.2], fontsize=fontsize)
    plt.xticks([1e3, 1e4],  
               [r'$10^3$', r'$10^4$'], 
               fontsize=fontsize)
    plt.xlabel('training set size N', fontsize=fontsize)


# Left panel: rollout error over time, one line per emulator.
ax = plt.subplot(1,3,1)
# One colour per emulator, indexed by i_model.
clrs = ['purple','blue', 'orange', 'green', 'k']
fontsize= 14

# rmses[m, r, t]: RMSE over the last axis of sols between rollout r of emulator m
# and the reference rollout stored in sols[-1], at time step t.
rmses = np.zeros((len(all_model_forwarders), n_start, T_dur+1))
# Emulators are drawn in reverse order.
for i_model in range(len(all_model_forwarders))[::-1]:
    for i in range(n_start):
        rmses[i_model,i,:] = np.sqrt(np.mean((sols[i_model,i] - sols[-1,i])**2, axis=1))
    #plt.semilogy(rmses[i_model,0].T, clrs[i_model], label=panel_titles[i_model])
    #plt.semilogy(rmses[i_model,1:].T, clrs[i_model])
    # Plotted curve: mean over all rollouts.
    plt.plot(rmses[i_model].mean(axis=0), clrs[i_model], label='N='+str(int(xx[i_model])), linewidth=2.5)

plt.legend(fontsize=fontsize, frameon=False)
plt.ylabel('RMSE', fontsize=fontsize)
plt.xlabel('time [au]', fontsize=fontsize)
# Tick positions are in steps; tick labels convert them to time via dt.
plt.xticks([0, T_dur/2, T_dur], [0, args['dt']*T_dur/2, args['dt']*T_dur], fontsize=fontsize)
#plt.yticks([0, 3, 6], fontsize=fontsize)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.title('rollouts', fontsize=fontsize)

#plt.savefig(res_dir + 'figs/emulator_eval.pdf', bbox_inches='tight', pad_inches=0, frameon=False)

plt.show()


# in principle, copy relevant code bits for local learning below 

In [None]:
"""
# local emulator evaluation fig 
# exp_ids = np.arange(49, 65)
fontsize=14
plt.figure(figsize=(16,4))
ax = plt.subplot(1,3,1)
plt.text(0.5, 0.5, 'sketch', fontsize=fontsize)
ax.axis('off')
plt.subplot(1,3,2)
plt.imshow(RMSEs_states[-1].mean(axis=1).reshape(4,4), vmax=3e-6, cmap='cividis')
plt.xlabel('local region size', fontsize=fontsize)
plt.xticks(range(4), ['640', '160', '40', '10'], fontsize=fontsize)
plt.ylabel('batch-size', fontsize=fontsize)
plt.yticks(range(4), ['1', '4', '16', '64'], fontsize=fontsize)
plt.colorbar()
plt.title(r'RMSE on predicted $x_{t+\Delta}$')
for i in range(4):
    plt.plot(np.array([-.5, .5])+i, np.array([-0.5, -0.5])+i, 'k--', linewidth=2)
    plt.plot(np.array([.5, .5])+i, np.array([-.5, .5])+i, 'k--', linewidth=2)
plt.plot([0.5, 3.5], [-0.5, -0.5], 'k--', linewidth=2)
plt.plot([3.5, 3.5], [-0.5, 2.5], 'k--', linewidth=2)

ax = plt.subplot(1,3,3)
batch_sizes = [1, 1, 1, 1, 4, 4, 4, 4, 16, 16, 16, 16, 64, 64, 64, 64]
region_sizes = [640, 160, 40, 10, 640, 160, 40, 10, 640, 160, 40, 10, 640, 160, 40, 10]
plt.semilogx([640, 640], [0.1e-6, 2e-6], 'k--', linewidth=2.)
for i in range(len(RMSEs_states[-1])):
    plt.semilogx(batch_sizes[i]*region_sizes[i], RMSEs_states[-1][i].mean(), 'b.', markersize=8.0)
plt.xlabel('locations per minibatch', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(np.array([0.5, 1., 1.5, 2.])*1e-6, fontsize=fontsize)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.title(r'RMSE on predicted $x_{t+\Delta}$', fontsize=fontsize)
plt.ylabel('RMSE', fontsize=fontsize)

#plt.savefig(res_dir + 'figs/emulator_local.pdf', bbox_inches='tight', pad_inches=0, frameon=False)

plt.show()
"""