In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directories for results and simulation data. Override them via the environment
# variables L96_RES_DIR / L96_DATA_DIR when not running on the original cluster.
# Both must end with '/', because paths are built below by plain string concatenation.
res_dir = os.environ.get('L96_RES_DIR', '/gpfs/work/nonnenma/results/emulators/L96/')
data_dir = os.environ.get('L96_DATA_DIR', '/gpfs/work/nonnenma/data/emulators/L96/')

# Emulator evaluation results

In [None]:
import torch 
from L96_emulator.run import setup, sel_dataset_class
from L96_emulator.eval import sortL96fromChannels, sortL96intoChannels, load_model_from_exp_conf
from L96_emulator.networks import named_network, Model_forwarder_predictorCorrector, Model_forwarder_rk4default
from L96_emulator.util import predictor_corrector, rk4_default, get_data, as_tensor
from L96sim.L96_base import f1, f2, pf2

# Emulator experiments (config files experiments/<exp_id>*.yml).
#
# One-level L96, K=40, F=8:
#   reference (analytical) emulators
#     dt=0.05   : exp_id=34 minimal net,  exp_id=35 bilinear net
#     dt=0.0125 : exp_id=37 minimal net,  exp_id=36 bilinear net
#   full-domain training
#     dt=0.05   : exp_id=26 minimal net,  exp_id=27 bilinear net
#     dt=0.0125 : exp_id=28 minimal net,  exp_id=29 bilinear net
#   local training, K_local=10, batch-size 32
#     dt=0.05   : exp_id=30 minimal net,  exp_id=31 bilinear net
#     dt=0.0125 : exp_id=32 minimal net,  exp_id=33 bilinear net
#   local training, K_local=1, batch-size 32
#     dt=0.05   : exp_id=38 minimal net,  exp_id=39 bilinear net
#   local training, K_local=1, batch-size 1
#     dt=0.05   : exp_id=40 minimal net,  exp_id=41 bilinear net
#
# One-level L96, K=36, F=10:
#   full-domain training
#     dt=0.01   : exp_id=42 minimal net,  exp_id=43 bilinear net
#
# NOTE(review): exp_id=65 is not catalogued above. The figures further down label the fifth
# model 'deepNet'. They label the third entry (exp_id=37) as an analytic bilinNet at
# Delta=0.05, whereas the catalogue lists 37 as the minimal net at dt=0.0125. One of the
# two is stale; confirm.

# Active selection: a single group that holds every experiment.
exp_ids = [34, 26, 37, 27, 65]
model_names = ['emulator training']
exp_id_model_sorted = [np.arange(len(exp_ids))]

# Alternative selections. Uncomment one set to use it instead.
#
# initial one-level L96, K=40, F=8 with various degrees of locality:
#exp_ids = [34, 26, 30, 38, 40, 35, 27, 31, 39, 41]
#exp_id_model_sorted = [np.arange(0,5), np.arange(5,10)]
#model_names = ['quadratic nonlinearity network', 'bilinear layer network']
#
# 4x4 one-level L96, K=40, F=8 with various degrees of locality and batch-size
# (needed by the locality / batch-size figure further down):
#exp_ids = np.concatenate((np.arange(49, 61), np.arange(61, 65)))
#exp_id_model_sorted=[np.arange(4), np.arange(4,8), np.arange(8,12), np.arange(12,16), ]
#model_names=['batch-size 1', 'batch-size 4', 'batch-size 16', 'batch-size 64']
#
# one-level L96, K=36, F=10:
#exp_ids = [42, 43, 44]
#exp_id_model_sorted = [np.arange(len(exp_ids))]
#model_names=['emulator training']


all_lgnd = []
all_models, all_model_forwarders, all_training_outputs = [], [], []

# Each experiment config is found by the two-character id prefix of its file name.
# List the directory once, not once per experiment.
exp_names = os.listdir('experiments/')
exp_name_prefixes = np.array([name[:2] for name in exp_names])

fig = plt.figure(figsize=(8,8))
for exp_id in exp_ids:

    matches = np.where(exp_name_prefixes == str(exp_id))[0]
    if len(matches) == 0:
        raise FileNotFoundError(f'no config file for exp_id={exp_id} in experiments/')
    conf_exp = exp_names[matches[0]][:-4]  # drop the '.yml' extension

    args = setup(conf_exp=f'experiments/{conf_exp}.yml')
    args.pop('conf_exp')
    # evaluate every emulator through the RK4 forwarder
    args['model_forwarder'] = 'rk4_default'

    K,J = args['K'], args['J']
    assert args['dt_net'] == args['dt']

    # L96 forcing / coupling constants (hard-coded here, not read from the config)
    if J > 0:
        F, h, b, c = 10., 1., 10., 10.
    else:
        h, b, c = 1., 10., 10.
        F = 10. if K==36 else 8.

    # legend label: 'analytic', or training locality plus batch size
    if args['init_net']=='analytical':
        exp_str = 'analytic'
    else:
        exp_str = 'localK'+str(args['K_local']) if args['loss_fun']=='local_mse' else 'fullDomain'
        exp_str += '_bs'+str(args['batch_size'])

    all_lgnd.append(exp_str)

    if args['padding_mode'] == 'valid':
        print('switching from local training to global evaluation')
        args['padding_mode'] = 'circular'
    model, model_forwarder, training_outputs = load_model_from_exp_conf(res_dir, args)
    all_models.append(model)
    all_model_forwarders.append(model_forwarder)
    all_training_outputs.append(training_outputs)

    if training_outputs is not None:
        seq_length = args['seq_length']
        plt.semilogy(training_outputs['validation_loss'], label=all_lgnd[-1])

# NOTE: args, K, J, F, h, b, c keep the values of the LAST experiment in exp_ids;
# the data-generation code below relies on that.
all_lgnd = np.array(all_lgnd)
plt.title('training')
plt.ylabel('validation error')
plt.legend()
fig.patch.set_facecolor('xkcd:white')
plt.show()

# Buffer handed to the numpy tendency functions f1/f2 (presumably filled in place by them).
dX_dt = np.empty(K*(J+1), dtype=dtype_np)
# Step size used with each forwarder class: predictor-corrector at dt/10, RK4 at dt.
# args is the config of the last experiment loaded above.
dts = {Model_forwarder_predictorCorrector : args['dt']/10,
       Model_forwarder_rk4default : args['dt']}

spin_up_time, train_frac = args['spin_up_time'], args['train_frac']
normalize_data = bool(args['normalize_data'])
N_trials, dt = args['N_trials'], args['dt']

# Evaluation data: a fresh simulation of n_test_steps steps after spin-up
# (deliberately not the training duration args['T']).
n_test_steps = 1000
T = n_test_steps*dt + spin_up_time

out, _ = get_data(K=K, J=J, T=T, dt=dt, N_trials=N_trials, F=F, h=h, b=b, c=c, 
                  resimulate=True, solver=rk4_default,
                  save_sim=False, data_dir=data_dir)

# Reference tendency function fun(t, x) used as ground truth below; the time t is unused.
if J <= 0:
    def fun(t, x):
        """Tendencies of the one-level system (J == 0) at state x."""
        return f1(x, F, dX_dt, K)
else:
    def fun(t, x):
        """Tendencies of the two-level system (J > 0) at state x."""
        return f2(x, F, h, b, c, dX_dt, K, J)


In [None]:
# Show the architectures of the models at positions 2, 0 and 4 of exp_ids,
# each preceded by a few blank lines.
for idx in (2, 0, 4):
    print('\n\n\n', all_models[idx], sep='\n')

In [None]:
# Initial states for evaluation: every time step after spin-up of the simulated trajectories.
T_start = np.arange(int(spin_up_time/dt), int(T/dt))
# Trial index per test state. Seeded, so the sampled states (and every error reported
# below) are reproducible across kernel restarts.
rng_eval = np.random.default_rng(seed=0)
i_trial = rng_eval.choice(N_trials, size=T_start.shape)
# three roughly evenly spaced indices into T_start, e.g. for example plots
idx_show = np.arange(0,len(T_start)-1, len(T_start)//3)

In [None]:
# RMSE between emulator tendencies and numpy reference tendencies, per (model, test state).
RMSEs_tendencies = np.zeros((len(exp_ids), len(T_start)))

print('\n')
print('MSEs are on differential equation (tendencies) !')
print('\n')

# Inference only: no autograd graph is needed for the error evaluation.
with torch.no_grad():
    for m_i, model in enumerate(all_models):
        # One state at a time: the numpy diff.eq. implementation cannot necessarily handle parallel solving.
        for i in range(len(T_start)):
            inputs = out[i_trial[i], T_start[i]] if N_trials > 1 else out[T_start[i]]
            inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))

            out_np = fun(0., inputs)
            out_model = model.forward(inputs_torch).detach().cpu().numpy()

            RMSEs_tendencies[m_i,i] = np.sqrt(((out_np - sortL96fromChannels(out_model))**2).mean())


# Per model group: sorted RMSE curves (left) and box plots (right).
plt.figure(figsize=(16,12))
for i in range(len(exp_id_model_sorted)):
    plt.subplot(len(exp_id_model_sorted),2,1+2*i)
    plt.semilogy(np.sort(RMSEs_tendencies[exp_id_model_sorted[i]],axis=1).T)
    plt.ylabel('RMSE')
    plt.xlabel('test data point (sorted)')
    plt.title(model_names[i])
    plt.legend(all_lgnd[exp_id_model_sorted[i]])
    plt.subplot(len(exp_id_model_sorted),2,2+2*i)
    plt.ylabel('RMSE')
    plt.boxplot(RMSEs_tendencies[exp_id_model_sorted[i]].T, labels=all_lgnd[exp_id_model_sorted[i]])
    plt.title(model_names[i])
plt.show()

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default
import torch 

for line in ('\n', 'MSEs are on system state !', '\n'):
    print(line)

# Forwarder classes to compare; their order defines the first axis of RMSEs_states.
MFWDs = [Model_forwarder_predictorCorrector, Model_forwarder_rk4default]
# RMSE per (forwarder, experiment, test state)
RMSEs_states = np.zeros((len(MFWDs), len(exp_ids), len(T_start)))

class Torch_solver(torch.nn.Module):
    """Wrap a numpy tendency function ``fun(t, x)`` as a torch module.

    This lets the reference numerical solver (numpy/numba/Julia) be driven by the
    same model forwarders (RK4, predictor-corrector) as the neural emulators.
    ``forward`` reads the module-level ``J`` to convert between layouts.
    """
    def __init__(self, fun):
        # nn.Module.__init__ must run first. Without it the module's internal registries
        # (_parameters, _modules, hooks) are missing, and calling the module or
        # iterating its parameters/children fails.
        super().__init__()
        self.fun = fun

    def forward(self, x):
        # torch channels layout -> flat numpy state
        x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
        # Copy the tendencies. fun presumably returns the shared dX_dt buffer, which the
        # next call (e.g. the next RK4 stage) would overwrite while the tensor may still alias it.
        dx = self.fun(0., x).copy()
        return as_tensor(sortL96intoChannels(np.atleast_2d(dx), J=J))

    
for mf_i, MFWD in enumerate(MFWDs):

    print('\n')
    print(f'solver {MFWD}, dt = {dts[MFWD]}')
    print('\n')

    # Reference: the numerical solver driven by the same forwarder as the emulators.
    model_forwarder_np = MFWD(Torch_solver(fun), dt=dts[MFWD])

    # Inference only: skip autograd graph construction for all the forward passes.
    with torch.no_grad():
        for m_i, model in enumerate(all_models):

            model_forwarder = MFWD(model=model, dt=dts[MFWD])

            for i in range(len(T_start)):
                inputs = out[i_trial[i], T_start[i]] if N_trials > 1 else out[T_start[i]]
                inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))

                out_np = model_forwarder_np(inputs_torch)
                out_model = model_forwarder(inputs_torch)

                RMSEs_states[mf_i, m_i, i] = np.sqrt(((out_np - out_model)**2).mean().detach().cpu().numpy())

    # Per model group: sorted RMSE curves (left) and box plots (right).
    plt.figure(figsize=(16,12))
    for i in range(len(exp_id_model_sorted)):
        plt.subplot(len(exp_id_model_sorted),2,1+2*i)
        plt.semilogy(np.sort(RMSEs_states[mf_i][exp_id_model_sorted[i]],axis=1).T)
        plt.ylabel('RMSE')
        plt.xlabel('test data point (sorted)')
        plt.title(model_names[i])
        plt.legend(all_lgnd[exp_id_model_sorted[i]])
        plt.subplot(len(exp_id_model_sorted),2,2+2*i)
        plt.ylabel('RMSE')
        plt.boxplot(RMSEs_states[mf_i][exp_id_model_sorted[i]].T, labels=all_lgnd[exp_id_model_sorted[i]])
        plt.title(model_names[i])
    plt.show()

In [None]:
# Figure: RK4 state-prediction error of emulators trained with different local-region sizes
# and batch sizes. It needs the 16 experiments of the 4x4 locality/batch-size sweep
# (exp_ids 49-64), ordered batch-size-major; with any other selection reshape(4,4) fails.
assert RMSEs_states.shape[1] == 16, \
    'this figure needs the 4x4 locality/batch-size experiment set (exp_ids 49-64) to be loaded'

fontsize=14
plt.figure(figsize=(16,4))

# left: placeholder for a sketch
ax = plt.subplot(1,3,1)
plt.text(0.5, 0.5, 'sketch', fontsize=fontsize)
ax.axis('off')

# middle: mean RMSE on the (batch-size x local region size) grid
plt.subplot(1,3,2)
plt.imshow(RMSEs_states[-1].mean(axis=1).reshape(4,4), vmax=3e-6, cmap='cividis')
plt.xlabel('local region size', fontsize=fontsize)
plt.xticks(range(4), ['640', '160', '40', '10'], fontsize=fontsize)
plt.ylabel('batch-size', fontsize=fontsize)
plt.yticks(range(4), ['1', '4', '16', '64'], fontsize=fontsize)
plt.colorbar()
plt.title(r'RMSE on predicted $x_{t+\Delta}$')
# Dashed outline of the upper-right triangle: configurations with fewer than 640 locations
# per minibatch (the diagonal cells have exactly 640).
for i in range(4):
    plt.plot(np.array([-.5, .5])+i, np.array([-0.5, -0.5])+i, 'k--', linewidth=2)
    plt.plot(np.array([.5, .5])+i, np.array([-.5, .5])+i, 'k--', linewidth=2)
plt.plot([0.5, 3.5], [-0.5, -0.5], 'k--', linewidth=2)
plt.plot([3.5, 3.5], [-0.5, 2.5], 'k--', linewidth=2)

# right: mean RMSE vs. number of locations per minibatch (batch size x region size)
ax = plt.subplot(1,3,3)
batch_sizes = [1, 1, 1, 1, 4, 4, 4, 4, 16, 16, 16, 16, 64, 64, 64, 64]
region_sizes = [640, 160, 40, 10, 640, 160, 40, 10, 640, 160, 40, 10, 640, 160, 40, 10]
plt.semilogx([640, 640], [0.1e-6, 2e-6], 'k--', linewidth=2.)
for i in range(len(RMSEs_states[-1])):
    plt.semilogx(batch_sizes[i]*region_sizes[i], RMSEs_states[-1][i].mean(), 'b.', markersize=8.0)
plt.xlabel('locations per minibatch', fontsize=fontsize)
plt.xticks(fontsize=fontsize)
plt.yticks(np.array([0.5, 1., 1.5, 2.])*1e-6, fontsize=fontsize)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.title(r'RMSE on predicted $x_{t+\Delta}$', fontsize=fontsize)
plt.ylabel('RMSE', fontsize=fontsize)

# `frameon` is no longer accepted by savefig in current matplotlib (TypeError).
plt.savefig(res_dir + 'figs/emulator_local.pdf', bbox_inches='tight', pad_inches=0)

plt.show()

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default
from L96_emulator.util import calc_jakobian_onelevelL96_tendencies, calc_jakobian_rk4, get_jacobian_torch
import torch 

def model_np(inputs):
    """Reference tendencies at ``inputs``, copied so later calls to ``fun`` cannot overwrite them."""
    return fun(0., inputs).copy()

print('\n')
print('Frobenius norms are on Jacobians of the RK4 state update !')
print('\n')

MFWDs = [Model_forwarder_rk4default]
# Frobenius norm of (reference Jacobian - emulator Jacobian) per (forwarder, experiment, test state)
L2_jakobians = np.zeros((len(MFWDs), len(exp_ids), len(T_start)))


for mf_i, MFWD in enumerate(MFWDs):

    print('\n')
    print(f'solver {MFWD}, dt = {dts[MFWD]}')
    print('\n')

    for m_i, model in enumerate(all_models):

        # Freeze the weights: only the Jacobian w.r.t. the input state is needed.
        # NOTE: this permanently modifies the models stored in all_models.
        for p in model.parameters():
            p.requires_grad = False

        model_forwarder = MFWD(model=model, dt=dts[MFWD])
        print('\n')
        print(f'model forwarder for {model}')
        print('\n')

        for i in range(len(T_start)):
            inputs = out[i_trial[i], T_start[i]] if N_trials > 1 else out[T_start[i]]
            inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))
            inputs_torch.requires_grad = True

            # Jacobian of one RK4 step of the reference dynamics. Uses the one-level
            # tendency Jacobian, so presumably only valid for J == 0 — TODO confirm.
            J_np = calc_jakobian_rk4(inputs, calc_f=model_np, 
                         calc_J_f=calc_jakobian_onelevelL96_tendencies, dt=dt, n=K)

            # autograd Jacobian of one emulator forwarder step
            J_torch = get_jacobian_torch(model_forwarder, inputs=inputs_torch, n=K)

            L2_jakobians[mf_i, m_i, i] = np.sqrt(((J_np - J_torch)**2).sum())

    # One row per model group (consistent with the RMSE cells above).
    n_rows = len(exp_id_model_sorted)
    plt.figure(figsize=(16,12))
    for i in range(n_rows):
        plt.subplot(n_rows,2,1+2*i)
        plt.semilogy(np.sort(L2_jakobians[mf_i][exp_id_model_sorted[i]],axis=1).T)
        plt.ylabel('Frobenius norm')
        plt.xlabel('test data point (sorted)')
        plt.title(model_names[i])
        plt.legend(all_lgnd[exp_id_model_sorted[i]])
        plt.subplot(n_rows,2,2+2*i)
        plt.ylabel('Frobenius norm')
        plt.boxplot(L2_jakobians[mf_i][exp_id_model_sorted[i]].T, labels=all_lgnd[exp_id_model_sorted[i]])
        plt.title(model_names[i])
    plt.show()

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default
from L96_emulator.util import calc_jakobian_onelevelL96_tendencies, calc_jakobian_rk4, get_jacobian_torch
import torch 

print('\n')
print('Frobenius norms are on Jacobians of the tendencies !')
print('\n')

# Only one entry; kept as a list so the array layout matches L2_jakobians.
MFWDs = [Model_forwarder_rk4default]
# Frobenius norm of (reference - emulator) tendency Jacobian per (forwarder, experiment, test state)
L2_jakobians_tendencies = np.zeros((len(MFWDs), len(exp_ids), len(T_start)))


for mf_i, MFWD in enumerate(MFWDs):

    print('\n')
    print(f'solver {MFWD}, dt = {dts[MFWD]}')
    print('\n')

    for m_i, model in enumerate(all_models):

        # Freeze the weights: only the Jacobian w.r.t. the input state is needed.
        # NOTE: this permanently modifies the models stored in all_models.
        for p in model.parameters():
            p.requires_grad = False

        print('\n')
        print(f'tendency network {model}')
        print('\n')

        for i in range(len(T_start)):
            inputs = out[i_trial[i], T_start[i]] if N_trials > 1 else out[T_start[i]]
            inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))
            inputs_torch.requires_grad = True

            # one-level tendency Jacobian (presumably only valid for J == 0 — TODO confirm)
            J_np = calc_jakobian_onelevelL96_tendencies(inputs, n=K)
            # autograd Jacobian of the network itself (no time stepping)
            J_torch = get_jacobian_torch(model, inputs=inputs_torch, n=K)

            L2_jakobians_tendencies[mf_i, m_i, i] = np.sqrt(((J_np - J_torch)**2).sum())

    # One row per model group (consistent with the RMSE cells above).
    n_rows = len(exp_id_model_sorted)
    plt.figure(figsize=(16,12))
    for i in range(n_rows):
        plt.subplot(n_rows,2,1+2*i)
        plt.semilogy(np.sort(L2_jakobians_tendencies[mf_i][exp_id_model_sorted[i]],axis=1).T)
        plt.ylabel('Frobenius norm')
        plt.xlabel('test data point (sorted)')
        plt.title(model_names[i])
        plt.legend(all_lgnd[exp_id_model_sorted[i]])
        plt.subplot(n_rows,2,2+2*i)
        plt.ylabel('Frobenius norm')
        plt.boxplot(L2_jakobians_tendencies[mf_i][exp_id_model_sorted[i]].T, labels=all_lgnd[exp_id_model_sorted[i]])
        plt.title(model_names[i])
    plt.show()

### Final figures

In [None]:
# Box plots of the emulator errors: tendencies, RK4 and predictor-corrector state prediction,
# and the Jacobian of one RK4 step. Labels assume the 5-experiment selection
# (sqrNet, bilinNet, ... in exp_ids order).
n_panels = 4
plt.figure(figsize=(16,4))

model_labels = ['analytic', 'trained', 'analytic',  'trained', 'trained']

plt.subplot(1,n_panels,1)
plt.ylabel('RMSE')
plt.boxplot(RMSEs_tendencies.T, labels=model_labels)
plt.gcf().text(0.15, 0.02, 'sqrNet')
plt.gcf().text(0.23, 0.02, 'bilinNet')
plt.title('tendencies', y=1.05)
plt.yticks(np.array([0, 2, 4, 6])*1e-6)

# RMSEs_states[1]: RK4 forwarder, RMSEs_states[0]: predictor-corrector (order of MFWDs)
plt.subplot(1,n_panels,2)
plt.ylabel('RMSE')
plt.boxplot(RMSEs_states[1].T, labels=model_labels)
plt.gcf().text(0.35, 0.02, 'sqrNet')
plt.gcf().text(0.43, 0.02, 'bilinNet')
plt.title(r'state prediction (RK4, $\Delta=0.05$)', y=1.05)
plt.yticks(np.array([0, 1, 2, 3])*1e-7)

plt.subplot(1,n_panels,3)
plt.ylabel('RMSE')
plt.boxplot(RMSEs_states[0].T, labels=model_labels)
plt.gcf().text(0.56, 0.02, 'sqrNet')
plt.gcf().text(0.64, 0.02, 'bilinNet')
plt.title(r'state prediction (pred-corr, $\Delta=0.005$)', y=1.05)
plt.yticks(np.array([0, 0.5, 1, 1.5])*1e-7)

plt.subplot(1,n_panels,4)
plt.ylabel('Frobenius Norm')
plt.boxplot(L2_jakobians[0].T, labels=model_labels)
plt.gcf().text(0.75, 0.02, 'sqrNet')
plt.gcf().text(0.85, 0.02, 'bilinNet')
plt.title(r'Jacobian (RK4, $\Delta=0.05$)', y=1.05)
plt.yticks(np.array([2, 3, 4])*1e-7)
# `frameon` is no longer accepted by savefig in current matplotlib (TypeError).
plt.savefig(res_dir + 'figs/emulator_eval.pdf', bbox_inches='tight', pad_inches=0)

plt.show()

In [None]:
# Bar-chart version of the emulator evaluation: mean error per model with +-1 std whiskers.
n_panels = 4
fontsize = 14

plt.figure(figsize=((10,8)))

model_labels = ['analytic', 'trained', 'analytic',  'trained', 'trained']
clrs = ['cyan', 'blue', 'red', 'orange', 'green']


def _bar_panel(panel, errs, ylabel, title, yticks, tick_labels, title_y=None):
    """Draw one linear-scale panel of the n_panels//2 x n_panels//2 grid.

    Bars show the mean of each row of ``errs`` (one row per model, one column per
    test state); black whiskers span mean +- 1 std. ``panel`` is the 1-based subplot index.
    """
    ax = plt.subplot(n_panels//2, n_panels//2, panel)
    means, stds = errs.mean(axis=1), errs.std(axis=1)
    xpos = np.arange(len(model_labels))
    plt.ylabel(ylabel, fontsize=fontsize)
    plt.bar(xpos, means, color=clrs)
    for i in range(len(model_labels)):
        plt.plot(i*np.ones(2), means[i] + stds[i]*np.array([-1,1]), 'k', linewidth=1.5)
    plt.xticks(xpos, tick_labels, fontsize=fontsize)
    title_kwargs = {} if title_y is None else {'y': title_y}
    plt.title(title, fontsize=fontsize, **title_kwargs)
    plt.yticks(yticks, fontsize=fontsize)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


_bar_panel(1, RMSEs_tendencies, 'RMSE', 'tendencies',
           np.array([0, 2, 4])*1e-6, [], title_y=1.05)
# RMSEs_states[1] is the RK4 forwarder (order of MFWDs). A predictor-corrector panel
# (RMSEs_states[0]) was drafted for this figure but is not shown.
_bar_panel(2, RMSEs_states[1], 'RMSE', r'state prediction (RK4, $\Delta=0.05$)',
           np.array([0, 1, 2])*1e-7, [], title_y=1.05)

_bar_panel(3, L2_jakobians_tendencies[0], 'Frobenius Norm', r'Jacobian (tendencies)',
           np.array([1, 2, 3])*1e-6, model_labels)
plt.gcf().text(0.18, 0.05, 'sqrNet', fontsize=fontsize)
plt.gcf().text(0.35, 0.05, 'bilinNet', fontsize=fontsize)

_bar_panel(4, L2_jakobians[0], 'Frobenius Norm', r'Jacobian (RK4, $\Delta=0.05$)',
           np.array([1, 2, 3])*1e-7, model_labels)
plt.gcf().text(0.60, 0.05, 'sqrNet', fontsize=fontsize)
plt.gcf().text(0.77, 0.05, 'bilinNet', fontsize=fontsize)

plt.show()

In [None]:
# Log-scale version of the emulator evaluation (the figure saved as figs/emulator_eval.pdf).
n_panels = 4
fontsize = 14

plt.figure(figsize=((12,8)))

model_labels = ['analytic', 'trained', 'analytic',  'trained', 'trained']
clrs = ['cyan', 'blue', 'red', 'orange', 'green']


def _log_panel(panel, errs, ylabel, title, tick_labels, title_y=None):
    """Draw one log-scale panel of the n_panels//2 x n_panels//2 grid.

    A short coloured line marks the mean of each row of ``errs`` (one row per model);
    a black whisker spans mean +- 1 std. Where mean - std <= 0, the lower end falls
    outside the log axis. ``panel`` is the 1-based subplot index.
    """
    ax = plt.subplot(n_panels//2, n_panels//2, panel)
    means, stds = errs.mean(axis=1), errs.std(axis=1)
    for i in range(len(model_labels)):
        plt.semilogy(i*np.ones(2)+np.array([-0.5,0.5]), means[i]*np.ones(2),
                     color=clrs[i], linewidth=1.5)
        plt.semilogy(i*np.ones(2), means[i]+stds[i]*np.array([-1,1]),
                     color='k', linewidth=1.5)
    plt.xticks(np.arange(len(model_labels)), tick_labels, fontsize=fontsize)
    title_kwargs = {} if title_y is None else {'y': title_y}
    plt.title(title, fontsize=fontsize, **title_kwargs)
    plt.ylabel(ylabel, fontsize=fontsize)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


_log_panel(1, RMSEs_tendencies, 'RMSE', 'tendencies', [], title_y=1.05)
# RMSEs_states[1] is the RK4 forwarder (order of MFWDs)
_log_panel(2, RMSEs_states[1], 'RMSE', r'state prediction (RK4, $\Delta=0.05$)', [], title_y=1.05)
_log_panel(3, L2_jakobians_tendencies[0], 'Frobenius Norm', r'Jacobian (tendencies)', model_labels)
_log_panel(4, L2_jakobians[0], 'Frobenius Norm', r'Jacobian (RK4, $\Delta=0.05$)', model_labels)

# network-type annotations under both bottom panels
for x_pos, net_name in [(0.18, 'sqrNet'), (0.30, 'bilinNet'), (0.40, 'deepNet'),
                        (0.60, 'sqrNet'), (0.72, 'bilinNet'), (0.82, 'deepNet')]:
    plt.gcf().text(x_pos, 0.05, net_name, fontsize=fontsize)

# `frameon` is no longer accepted by savefig in current matplotlib (TypeError).
plt.savefig(res_dir + 'figs/emulator_eval.pdf', bbox_inches='tight', pad_inches=0)

plt.show()

# 4D-Var results

In [None]:
from L96_emulator.run import setup
from L96_emulator.run_DA import setup_4DVar
from L96_emulator.likelihood import ObsOp_identity, ObsOp_subsampleGaussian, ObsOp_rotsampleGaussian
from L96_emulator.data_assimilation import GenModel, get_model, as_tensor
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch

# Line colours for successive experiments in the 4D-Var plots (the first one is white).
clrs = ['w', 'b', 'c', 'g', 'y', 'r', 'm', 'k']
# Legend entries collected by the 4D-Var plotting code.
lgnd = []

def get_analysis_rmses_4DVar_exp(exp_ids, ifplot=False):
    """Analysis (initial-state) RMSE for a series of 4D-Var experiments.

    For each experiment, loads ``res_dir + 'results/data_assimilation/<exp_id>/out.npy'``.
    It then compares the estimated initial state of every assimilation window
    (``x_sols[i]``) with the true state at that window's start (``out[i * T_shift]``).

    Parameters
    ----------
    exp_ids : sequence of str or int
        Experiment ids, matched against the first two characters of the config
        file names in ``experiments_DA/``.
    ifplot : bool
        If True, plot per-window MSEs (per trial, trial-averaged, and zoomed in).

    Returns
    -------
    win_lens : np.ndarray
        Integration window length ``T_win`` of each experiment.
    rmses_total : np.ndarray
        Square root of the NaN-ignoring mean MSE over all windows and trials.
    """
    rmses_total = np.zeros(len(exp_ids))
    win_lens = np.zeros(len(exp_ids))
    legend_entries = []  # local, so repeated calls do not accumulate stale labels

    exp_names = os.listdir('experiments_DA/')
    exp_name_prefixes = np.array([name[:2] for name in exp_names])

    if ifplot:
        plt.figure(figsize=(16,6))
        plt.subplot(1,3,1)
        # off-screen dummy points, one per colour
        for clr in clrs:
            plt.plot(-100, -1, 'o-', color=clr, linewidth=2.5)    

    for eid, exp_id in enumerate(exp_ids):

        conf_exp = exp_names[np.where(exp_name_prefixes == str(exp_id))[0][0]][:-4]
        args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
        args.pop('conf_exp')

        save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
        out = np.load(res_dir + save_dir + 'out.npy', allow_pickle=True)[()]

        T_win = args['T_win'] 
        # negative T_shift in the config means non-overlapping windows
        T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win
        assert T_win == out['T_win']

        data = out['out']
        x_sols = out['x_sols']

        # Window i starts at time i*T_shift; compare its initial-state estimate only with
        # the true state at that time (not with the whole trajectory).
        n_windows = (data.shape[0] - T_win) // T_shift + 1
        mses = np.zeros((n_windows, data.shape[1]))
        for i in range(n_windows):
            mses[i] = np.nanmean((x_sols[i] - data[i * T_shift])**2, axis=(-2, -1))

        if ifplot:
            # window start times, one per row of mses
            xx = np.arange(n_windows) * T_shift
            plt.subplot(1,3,1)
            plt.plot(xx, mses, 'o-', color=clrs[eid], linewidth=2.5)
            plt.xlim(0, len(data))
            plt.subplot(1,3,2)
            plt.plot(xx, np.nanmean(mses, axis=1), 'o-', color=clrs[eid], linewidth=2.5)
            plt.xlim(0, len(data))
            plt.subplot(1,3,3)
            plt.plot(xx, mses, 'o-', color=clrs[eid], linewidth=2.5)
            plt.axis([0, len(data)-1, 0, 2])
            print(np.nanmean(mses[1:]))
            legend_entries.append('window length='+str(T_win))

        rmses_total[eid] = np.sqrt(np.nanmean(mses))
        win_lens[eid] = T_win

    if ifplot: 

        plt.subplot(1,3,1)
        plt.title('individual trials')
        plt.ylabel('initial state MSE')
        plt.subplot(1,3,2)
        plt.title('averages over trials')
        plt.legend(legend_entries[:3])
        plt.xlabel('time t')
        plt.subplot(1,3,3)
        plt.title('individual trials, zoom-in on small MSEs')
        plt.show()

    return win_lens, rmses_total


def get_pred_rmses_4DVar_exp(exp_id):
    """Forecast RMSE vs. lead time for one 4D-Var experiment with a 64-step window.

    For every assimilation window i, the analysis ``x_sols[i]`` is integrated forward
    with the experiment's emulator. The state is compared with the truth every
    ``eval_every`` steps, up to ``forecast_win`` steps after the end of the window.
    One model step is taken to be 1.5 h.

    Parameters
    ----------
    exp_id : str or int
        Experiment id, matched against the first two characters of the config
        file names in ``experiments_DA/``.

    Returns
    -------
    pred_lens : np.ndarray
        Forecast lead times in days.
    rmses : np.ndarray
        RMSE per (window, lead time, entry of ``out['y'].shape[1]``).
    """
    exp_names = os.listdir('experiments_DA/')   
    conf_exp = exp_names[np.where(np.array([name[:2] for name in exp_names])==str(exp_id))[0][0]][:-4]
    args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
    args.pop('conf_exp')

    assert args['T_win'] == 64 # 64 steps of 1.5 h: 4d integration window

    K,J = args['K'], args['J']
    T_win = args['T_win']
    # negative T_shift in the config means non-overlapping windows
    T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win

    model_pars = {
        'exp_id' : args['model_exp_id'],
        'model_forwarder' : 'rk4_default',
        'K_net' : args['K'],
        'J_net' : args['J'],
        'dt_net' : args['dt']
    }

    _, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

    obs_operator = args['obs_operator']
    obs_pars = {}
    if obs_operator=='ObsOp_subsampleGaussian':
        obs_pars['obs_operator'] = ObsOp_subsampleGaussian
        obs_pars['obs_operator_args'] = {'r' : args['obs_operator_r'], 'sigma2' : args['obs_operator_sig2']}
    elif obs_operator=='ObsOp_identity':
        obs_pars['obs_operator'] = ObsOp_identity
        obs_pars['obs_operator_args'] = {}
    elif obs_operator=='ObsOp_rotsampleGaussian':
        obs_pars['obs_operator'] = ObsOp_rotsampleGaussian
        obs_pars['obs_operator_args'] = {'frq' : args['obs_operator_frq'], 
                                         'sigma2' : args['obs_operator_sig2']}
    else:
        raise NotImplementedError()
    model_observer = obs_pars['obs_operator'](**obs_pars['obs_operator_args'])

    prior = torch.distributions.normal.Normal(loc=torch.zeros((1,J+1,K)), 
                                              scale=1.*torch.ones((1,J+1,K)))

    # generative model used to roll the analysis forward
    gen = GenModel(model_forwarder, model_observer, prior, T=T_win, x_init=None)

    forecast_win = int(120/1.5) # 5d forecast
    eval_every = int(6/1.5) # every 6h
    lead_steps = np.arange(0, forecast_win+1, eval_every)

    save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
    out = np.load(res_dir + save_dir + 'out.npy', allow_pickle=True)[()]

    data = out['out']
    y = out['y']
    x_sols = out['x_sols']

    assert T_win == out['T_win']

    # Window i's longest forecast reads data[i*T_shift + T_win + forecast_win], which must be
    # <= data.shape[0] - 1 (the previous count could index one past the end).
    n_windows = (data.shape[0] - 1 - forecast_win - T_win) // T_shift + 1
    mses = np.zeros((n_windows, len(lead_steps), y.shape[1]))
    with torch.no_grad():  # evaluation only
        for i in range(n_windows):
            forecasts = gen._forward(x=as_tensor(x_sols[i]), T_obs=T_win + lead_steps)
            n = i * T_shift + T_win  # absolute time at the end of window i
            for j, lead in enumerate(lead_steps):
                forecast = forecasts[j].detach().cpu().numpy()
                y_obs = data[n + lead]
                mses[i,j] = np.nanmean((forecast - y_obs)**2, axis=(-2, -1))

    pred_lens = 1.5/24 * lead_steps  # steps -> days

    return pred_lens, np.sqrt(mses)

In [None]:
# 4D-Var experiment ids, eight consecutive ids per emulator type.
exp_ids_minimalNet = [str(i) for i in range(14, 22)]
exp_ids_bilinNet = [str(i) for i in range(22, 30)]
exp_ids_analyticNet = [str(i) for i in range(30, 38)]
exp_ids_deepNet = [str(i) for i in range(38, 46)]

# analysis RMSE vs. integration window length, per emulator
win_lens_minimalNet, rmses_analysis_minimalNet = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_minimalNet)
win_lens_bilinNet, rmses_analysis_bilinNet = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_bilinNet)
win_lens_analyticNet, rmses_analysis_analyticNet = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_analyticNet)
win_lens_deepNet, rmses_analysis_deepNet = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_deepNet)

# forecast RMSE from the last experiment of each group (the function asserts a 64-step window)
pred_lens_minimalNet, rmses_pred__minimalNet = get_pred_rmses_4DVar_exp(exp_id=exp_ids_minimalNet[-1])
pred_lens_bilinNet, rmses_pred__bilinNet = get_pred_rmses_4DVar_exp(exp_id=exp_ids_bilinNet[-1])
pred_lens_analyticNet, rmses_pred_analyticNet = get_pred_rmses_4DVar_exp(exp_id=exp_ids_analyticNet[-1])
pred_lens_deepNet, rmses_pred_deepNet = get_pred_rmses_4DVar_exp(exp_id=exp_ids_deepNet[-1])

In [None]:
# Figure: 4D-Var analysis RMSE vs. window length (left) and forecast RMSE vs. lead time (right).
fontsize=14
steps_to_days = 1.5/24  # one model time step corresponds to 1.5 h

# (legend label, colour, window lengths, analysis RMSEs, forecast lead times, forecast RMSEs)
# 'sqrNet' matches the name used for the minimal network in the emulator figures.
curves = [
    ('analytic', 'r', win_lens_analyticNet, rmses_analysis_analyticNet,
     pred_lens_analyticNet, rmses_pred_analyticNet),
    ('sqrNet', 'b', win_lens_minimalNet, rmses_analysis_minimalNet,
     pred_lens_minimalNet, rmses_pred__minimalNet),
    ('bilinNet', 'orange', win_lens_bilinNet, rmses_analysis_bilinNet,
     pred_lens_bilinNet, rmses_pred__bilinNet),
    ('deepNet', 'g', win_lens_deepNet, rmses_analysis_deepNet,
     pred_lens_deepNet, rmses_pred_deepNet),
]

plt.figure(figsize=(12,5))

ax = plt.subplot(1,2,1)
for label, clr, win_lens, rmses_analysis, pred_lens, rmses_pred in curves:
    plt.plot(win_lens*steps_to_days, rmses_analysis, 
             '-', color=clr, linewidth=2.5, label=label)
plt.xlabel('integration window length [d]', fontsize=fontsize)
plt.ylabel('RMSE', fontsize=fontsize)
plt.yticks([0.4, 0.5, 0.6, 0.7], fontsize=fontsize)
plt.xticks(0.5*np.arange(1, 8.1), fontsize=fontsize)
plt.legend(fontsize=fontsize, frameon=False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# forecast RMSE averaged over windows (axis 0) and the last axis
ax = plt.subplot(1,2,2)
for label, clr, win_lens, rmses_analysis, pred_lens, rmses_pred in curves:
    plt.plot(pred_lens, np.nanmean(rmses_pred, axis=(0,-1)), 
             '-', color=clr, linewidth=2.5, label=label)
plt.xlabel('forecast time [d]', fontsize=fontsize)
plt.ylabel('RMSE', fontsize=fontsize)
plt.yticks([0.5, 1.0, 1.5], fontsize=fontsize)
plt.xticks(np.arange(5.1), fontsize=fontsize)
plt.legend(fontsize=fontsize, frameon=False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# `frameon` is no longer accepted by savefig in current matplotlib (TypeError).
plt.savefig(res_dir + 'figs/4DVar.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

# Parametrization results

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout
from L96_emulator.parametrization import Parametrized_twoLevel_L96, Parametrization_lin
from L96_emulator.networks import Model_forwarder_rk4default
from L96_emulator.run_parametrization import setup_parametrization
from L96_emulator.data_assimilation import get_model
from L96_emulator.util import as_tensor, dtype_np, sortL96intoChannels, sortL96fromChannels
from L96sim.L96_base import f1, f2, pf2
import numpy as np
import torch
import os

exp_id = '01'

# Locate the parametrization config whose filename starts with the experiment id
# (e.g. '01_<name>.yml'). Only .yml files are considered, and candidates are sorted
# so the choice does not depend on the arbitrary order returned by os.listdir.
conf_dir = 'experiments_parametrization/'
exp_names = os.listdir(conf_dir)
conf_candidates = sorted(os.path.splitext(name)[0] for name in exp_names
                         if name.startswith(str(exp_id)) and name.endswith('.yml'))
if not conf_candidates:
    raise FileNotFoundError(f'no .yml config starting with {exp_id!r} found in {conf_dir}')
conf_exp = conf_candidates[0]
args = setup_parametrization(conf_exp=f'{conf_dir}{conf_exp}.yml')
args.pop('conf_exp')  # drop the config-path entry from the returned settings

save_dir = 'results/parametrization/' + args['exp_id'] + '/'
# out.npy stores a pickled object in a 0-d array; [()] unwraps it (only load trusted files with allow_pickle)
out = np.load(res_dir + save_dir + 'out.npy', allow_pickle=True)[()]

T_dur = 5000  # number of integration steps per simulated trajectory
X_init = out['X_init']

# one-level emulator that all parametrized models below wrap
model_pars = dict(exp_id=args['model_exp_id'],
                  model_forwarder=args['model_forwarder'],
                  K_net=args['K'],
                  J_net=0,
                  dt_net=args['dt'])
model, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

def _rk4_parametrized(parametrization):
    """Wrap the emulator `model` with `parametrization`; return (wrapped model, RK4 forwarder)."""
    wrapped = Parametrized_twoLevel_L96(emulator=model, parametrization=parametrization)
    return wrapped, Model_forwarder_rk4default(model=wrapped, dt=args['dt'])

# trained parametrized model (coefficients read from the saved state dict)
train_coefs = out['param_train_state_dict']
param_train = Parametrization_lin(a=as_tensor(train_coefs['a']), b=as_tensor(train_coefs['b']))
model_parametrized_train, model_forwarder_parametrized_train = _rk4_parametrized(param_train)

# initial and reference parametrized models with fixed linear coefficients
# NOTE(review): provenance of these constants is not shown in this notebook
param_ref = Parametrization_lin(a=as_tensor(np.array([-0.31])), b=as_tensor(np.array([-0.2])))
param_init = Parametrization_lin(a=as_tensor(np.array([-0.75])), b=as_tensor(np.array([-0.4])))
model_parametrized_init, model_forwarder_parametrized_init = _rk4_parametrized(param_init)
model_parametrized_ref, model_forwarder_parametrized_ref = _rk4_parametrized(param_ref)

# ground-truth high-res model: two-level L96 right-hand side evaluated by f2
dX_dt = np.empty((args['J'] + 1) * args['K'], dtype=dtype_np)  # output buffer passed to f2 on every call
def fun(t, x):
    """Two-level L96 time derivative of the flat state x.

    t is accepted for the usual solver signature but is not used.
    NOTE(review): f2 receives the shared dX_dt buffer, so the returned array may alias it.
    """
    forcing, coupling_h, coupling_b, coupling_c = (args['l96_F'], args['l96_h'],
                                                   args['l96_b'], args['l96_c'])
    return f2(x, forcing, coupling_h, coupling_b, coupling_c, dX_dt, args['K'], args['J'])
class Torch_solver(torch.nn.Module):
    """Expose a numpy right-hand side ``fun(t, x)`` as a torch module.

    This lets the numerical (numpy/numba/Julia) model be driven by the same
    torch-based RK4 forwarder as the emulators.
    """
    def __init__(self, fun):
        # nn.Module bookkeeping (_modules, hooks, ...) must exist before the module
        # is called, registered as a submodule of a forwarder, or moved between devices.
        super().__init__()
        self.fun = fun

    def forward(self, x):
        """Evaluate fun at t=0 on state x and return the derivative as a tensor.

        x is flattened to numpy via sortL96fromChannels. The result is re-split
        into channels with J=args['J'] and returned as a new tensor with x's
        dtype and device.
        """
        x_np = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
        dxdt = sortL96intoChannels(np.atleast_2d(self.fun(0., x_np)), J=args['J'])
        # torch.tensor always copies: fun may hand back (a view of) a reused buffer,
        # which would otherwise alias between successive integrator stages.
        return torch.tensor(np.ascontiguousarray(dxdt), dtype=x.dtype, device=x.device)
# numerical reference: the two-level L96 right-hand side integrated with the same RK4 forwarder
model_forwarder_np = Model_forwarder_rk4default(Torch_solver(fun), dt=args['dt'])

# one entry per simulated model: the first four run on the first args['K'] state
# entries with J=0; the last one runs on the full two-level state
model_forwarders = [
    Model_forwarder_rk4default(model, dt=args['dt']),
    model_forwarder_parametrized_init,
    model_forwarder_parametrized_train,
    model_forwarder_parametrized_ref,
    model_forwarder_np,
]
n_one_level = len(model_forwarders) - 1
X_inits = [X_init[:, :args['K']].copy() for _ in range(n_one_level)] + [X_init.copy()]
Js = [0] * n_one_level + [args['J']]
panel_titles = ['one-level L96',
                'initial param.',
                'learned param.',
                'reference param.',
                'two-level L96']
sols = [np.nan] * len(model_forwarders)  # placeholders, filled by the simulation loop
for i_model in range(len(model_forwarders)):

    model_forwarder_i, X_init_i, J_i = model_forwarders[i_model], X_inits[i_model], Js[i_model]

    def model_simulate(y0, dy0, n_steps):
        """Roll model_forwarder_i forward n_steps from the channel-sorted state y0.

        y0: initial state; only y0.shape[1:] is stored per step, so it is assumed
            to carry a single leading batch entry.
        dy0: unused; kept for call compatibility.
        Returns an array of shape (n_steps+1, *y0.shape[1:]) whose first entry is y0.
        """
        x = np.empty((n_steps+1, *y0.shape[1:]))
        x[0] = y0.copy()
        # Pure inference: without no_grad, feeding xx back into a forwarder with
        # trainable parameters would chain one autograd graph through all n_steps.
        with torch.no_grad():
            xx = as_tensor(x[0]).reshape(1,1,-1)
            for i in range(1, n_steps+1):
                xx = model_forwarder_i(xx.reshape(1,J_i+1,-1))
                x[i] = xx.detach().cpu().numpy()  # assignment copies into x
        return x

    print('forwarding model ' + panel_titles[i_model])
    sols[i_model] = model_simulate(y0=sortL96intoChannels(X_init_i,J=J_i), dy0=None, n_steps=T_dur)


In [None]:
fontsize=14
plt.figure(figsize=(12,8))

# top row: one space-time panel per model (channel 0 only); along the way, the RMSE
# of each model's channel 0 against the last model (two-level L96) is recorded
rmses = np.zeros((len(model_forwarders), T_dur+1))
for i_model in range(len(model_forwarders)):

    plt.subplot(2,len(model_forwarders),i_model+1)
    plt.imshow(sortL96fromChannels(sols[i_model][:,:1,:]).T, aspect='auto')
    plt.colorbar()
    plt.title(panel_titles[i_model], fontsize=fontsize)

    if i_model == 0:
        plt.ylabel('location k', fontsize=fontsize)
    if i_model == 2:
        plt.xlabel('time [steps]', fontsize=fontsize)

    rmses[i_model,:] = np.sqrt(np.mean((sols[i_model][:,0,:] - sols[-1][:,0,:])**2, axis=1))
    plt.yticks([], fontsize=fontsize)
    plt.xticks([0, T_dur/2, T_dur], fontsize=fontsize)

# bottom-left: placeholder panel. TODO: parametrization parameters vs dataset size
plt.subplot(2,2,3)
plt.text(0.5, 0.5, 'tbd')
plt.ylabel('parametrization parameters', fontsize=fontsize)
plt.xlabel('dataset size', fontsize=fontsize)
plt.xticks([], fontsize=fontsize)
plt.yticks([], fontsize=fontsize)

# bottom-right: RMSE over time; the reference (last) model is skipped since its RMSE is zero
plt.subplot(2,2,4)
for i_model in range(len(model_forwarders)-1):
    plt.plot(rmses[i_model], label=panel_titles[i_model])
plt.legend(fontsize=fontsize, frameon=False)
plt.ylabel('RMSE', fontsize=fontsize)
plt.xlabel('time [steps]', fontsize=fontsize)
plt.xticks([0, T_dur/2, T_dur], fontsize=fontsize)
plt.yticks([0, 3, 6], fontsize=fontsize)
# savefig's `frameon` was removed in matplotlib 3.3; a fully transparent facecolor is the documented equivalent
plt.savefig(res_dir + 'figs/param.pdf', bbox_inches='tight', pad_inches=0, facecolor='none')
plt.show()