In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directories for stored results and data. The defaults are the original
# cluster paths; set L96_RES_DIR / L96_DATA_DIR to run elsewhere.
# os.path.join(..., '') guarantees a trailing separator, because downstream
# code builds paths by plain string concatenation (e.g. res_dir + 'figs/...').
res_dir = os.path.join(os.environ.get('L96_RES_DIR', '/gpfs/work/nonnenma/results/emulators/L96/'), '')
data_dir = os.path.join(os.environ.get('L96_DATA_DIR', '/gpfs/work/nonnenma/data/emulators/L96/'), '')

# Emulator evaluation results

# 4D-Var results

In [None]:
from L96_emulator.run import setup
from L96_emulator.run_DA import setup_4DVar
from L96_emulator.likelihood import ObsOp_identity, ObsOp_subsampleGaussian, ObsOp_rotsampleGaussian
from L96_emulator.data_assimilation import GenModel, get_model, as_tensor
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch

# Line colours, indexed by the experiment's position in the list being plotted.
clrs = ['w', 'b', 'c', 'g', 'y', 'r', 'm', 'k']
# Module-level list of legend labels (starts empty).
lgnd = []

def get_analysis_rmses_4DVar_exp(exp_ids, ifplot=False):
    """Analysis (initial-state) RMSE for a list of 4D-Var experiments.

    For each experiment id, the matching config in ``experiments_DA/`` is
    loaded, the stored results ``out.npy`` are read from ``res_dir``, and for
    every assimilation window ``i`` the estimate ``x_sols[i]`` is compared with
    the stored true state at the window start ``i * T_shift``.

    Parameters
    ----------
    exp_ids : sequence of str or int
        Experiment ids; each is matched against the first two characters of
        the config filenames in ``experiments_DA/``.
    ifplot : bool
        If True, plot per-trial and trial-averaged MSEs over time.

    Returns
    -------
    win_lens : np.ndarray, shape (len(exp_ids),)
        Integration window length ``T_win`` (model steps) of each experiment.
    rmses_total : np.ndarray, shape (len(exp_ids),)
        Square root of the NaN-mean MSE over all windows and trials.

    Raises
    ------
    FileNotFoundError
        If no config file matches an experiment id.
    """
    rmses_total = np.zeros(len(exp_ids))
    win_lens = np.zeros(len(exp_ids))
    # The directory listing does not change between experiments; read it once.
    exp_names = os.listdir('experiments_DA/')

    if ifplot:
        plt.figure(figsize=(16,6))
        plt.subplot(1,3,1)
        for clr in clrs:
            plt.plot(-100, -1, 'o-', color=clr, linewidth=2.5)

    for eid, exp_id in enumerate(exp_ids):

        # First config whose filename starts with exp_id; [:-4] drops the
        # extension, which is re-added as '.yml' for setup_4DVar.
        matches = [name for name in exp_names if name[:2] == str(exp_id)]
        if not matches:
            raise FileNotFoundError(f'no config for exp_id {exp_id!r} in experiments_DA/')
        conf_exp = matches[0][:-4]
        args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
        args.pop('conf_exp')

        fn = 'results/data_assimilation/' + args['exp_id'] + '/out.npy'
        out = np.load(res_dir + fn, allow_pickle=True)[()]

        T_win = args['T_win']
        # a negative T_shift falls back to T_win (windows shifted by a full window)
        T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win

        data = out['out']
        x_sols = out['x_sols']
        assert T_win == out['T_win']

        n_windows = (data.shape[0] - T_win) // T_shift + 1
        # Start time of each window; this is also the plot x-axis, so it has
        # exactly one entry per row of mses.
        t_starts = T_shift * np.arange(n_windows)
        mses = np.zeros((n_windows, data.shape[1]))
        for i, t in enumerate(t_starts):
            # squared error of the window-i estimate at its start time t,
            # averaged over the last two (state) axes
            mses[i] = np.nanmean((x_sols[i] - data[t])**2, axis=(-2, -1))

        if ifplot:
            # wrap around so more experiments than colours do not crash
            clr = clrs[eid % len(clrs)]
            plt.subplot(1,3,1)
            plt.plot(t_starts, mses, 'o-', color=clr, linewidth=2.5)
            plt.xlim(0, len(data))
            plt.subplot(1,3,2)
            plt.plot(t_starts, np.nanmean(mses, axis=1), 'o-', color=clr, linewidth=2.5,
                     label='window length='+str(T_win))
            plt.xlim(0, len(data))
            plt.subplot(1,3,3)
            plt.plot(t_starts, mses, 'o-', color=clr, linewidth=2.5)
            plt.axis([0, len(data)-1, 0, 2])
            # mean MSE excluding the first window
            print(np.nanmean(mses[1:]))

        rmses_total[eid] = np.sqrt(np.nanmean(mses))
        win_lens[eid] = T_win

    if ifplot:

        plt.subplot(1,3,1)
        plt.title('individual trials')
        plt.ylabel('initial state MSE')
        plt.subplot(1,3,2)
        plt.title('averages over trials')
        # labels are attached to this call's lines only (no state shared across calls)
        plt.legend()
        plt.xlabel('time t')
        plt.subplot(1,3,3)
        plt.title('individual trials, zoom-in on small MSEs')
        plt.show()

    return win_lens, rmses_total


def get_pred_rmses_4DVar_exp(exp_id):
    """Forecast RMSE of a 4D-Var experiment as a function of lead time.

    Loads the config matching ``exp_id`` (first two characters of a filename
    in ``experiments_DA/``) and rebuilds the emulator and observation operator
    used for assimilation. It then rolls each window's analysis ``x_sols[i]``
    forward and compares it with the stored true states every ``eval_every``
    steps, up to ``forecast_win`` steps past the end of the integration window.

    Returns
    -------
    pred_lens : np.ndarray
        Lead times in days, assuming 1.5 h per model step.
    rmses : np.ndarray, shape (n_windows, n_leads, y.shape[1])
        RMSE per window, lead time and (presumably) trial.

    Raises
    ------
    FileNotFoundError
        If no config file matches ``exp_id``.
    NotImplementedError
        For an unsupported ``obs_operator`` in the config.
    """
    exp_names = os.listdir('experiments_DA/')
    matches = [name for name in exp_names if name[:2] == str(exp_id)]
    if not matches:
        raise FileNotFoundError(f'no config for exp_id {exp_id!r} in experiments_DA/')
    conf_exp = matches[0][:-4]  # drop extension; re-added as '.yml' below
    args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
    args.pop('conf_exp')

    assert args['T_win'] == 64 # we want 4d integration window here (64 steps * 1.5h)

    K, J = args['K'], args['J']
    T_win = args['T_win']
    # a negative T_shift falls back to T_win (windows shifted by a full window)
    T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win

    hours_per_step = 1.5
    forecast_win = int(120/hours_per_step) # 5d forecast
    eval_every = int(6/hours_per_step) # every 6h

    model_pars = {
        'exp_id' : args['model_exp_id'],
        'model_forwarder' : 'rk4_default',
        'K_net' : args['K'],
        'J_net' : args['J'],
        'dt_net' : args['dt']
    }
    _, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

    obs_operator = args['obs_operator']
    if obs_operator=='ObsOp_subsampleGaussian':
        model_observer = ObsOp_subsampleGaussian(r=args['obs_operator_r'],
                                                 sigma2=args['obs_operator_sig2'])
    elif obs_operator=='ObsOp_identity':
        model_observer = ObsOp_identity()
    elif obs_operator=='ObsOp_rotsampleGaussian':
        model_observer = ObsOp_rotsampleGaussian(frq=args['obs_operator_frq'],
                                                 sigma2=args['obs_operator_sig2'])
    else:
        raise NotImplementedError()

    prior = torch.distributions.normal.Normal(loc=torch.zeros((1,J+1,K)),
                                              scale=1.*torch.ones((1,J+1,K)))

    # generative model used only to roll analyses forward
    gen = GenModel(model_forwarder, model_observer, prior, T=T_win, x_init=None)

    fn = 'results/data_assimilation/' + args['exp_id'] + '/out.npy'
    out = np.load(res_dir + fn, allow_pickle=True)[()]

    data = out['out']
    y = out['y']
    x_sols = out['x_sols']
    assert T_win == out['T_win']

    lead_steps = np.arange(0, forecast_win+1, eval_every)
    # The last window i must satisfy i*T_shift + T_win + forecast_win <= len(data) - 1;
    # without the "- 1" the final window reads past the end of data whenever
    # (len(data) - forecast_win - T_win) is a multiple of T_shift.
    n_windows = (data.shape[0] - 1 - forecast_win - T_win) // T_shift + 1
    mses = np.zeros((n_windows, len(lead_steps), y.shape[1]))
    # evaluation only: no autograd graph through the emulator parameters
    with torch.no_grad():
        for i in range(n_windows):
            forecasts = gen._forward(x=as_tensor(x_sols[i]), T_obs=T_win + lead_steps)
            n = i * T_shift + T_win  # window-i start plus T_win
            for j, lead in enumerate(lead_steps):
                forecast = forecasts[j].detach().cpu().numpy()
                mses[i,j] = np.nanmean((forecast - data[n + lead])**2, axis=(-2, -1))

    pred_lens = hours_per_step/24 * lead_steps

    return pred_lens, np.sqrt(mses)

In [None]:
# One 4D-Var experiment per integration-window length, for each emulator.
exp_ids_minimalNet = [str(i) for i in range(14, 22)]
exp_ids_bilinNet = [str(i) for i in range(22, 30)]
exp_ids_analyticNet = [str(i) for i in range(30, 38)]

# Analysis RMSE as a function of integration-window length.
win_lens_minimalNet, rmses_analysis_minimalNet = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_minimalNet)
win_lens_bilinNet, rmses_analysis_bilinNet = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_bilinNet)
win_lens_analyticNet, rmses_analysis_analyticNet = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_analyticNet)

# Forecast RMSE vs. lead time, from the last experiment of each sweep
# (get_pred_rmses_4DVar_exp asserts that this one uses T_win == 64).
pred_lens_minimalNet, rmses_pred__minimalNet = get_pred_rmses_4DVar_exp(exp_id=exp_ids_minimalNet[-1])
pred_lens_bilinNet, rmses_pred__bilinNet = get_pred_rmses_4DVar_exp(exp_id=exp_ids_bilinNet[-1])
pred_lens_analyticNet, rmses_pred_analyticNet = get_pred_rmses_4DVar_exp(exp_id=exp_ids_analyticNet[-1])

In [None]:
# Left: analysis RMSE vs. integration-window length; right: forecast RMSE vs.
# lead time. "*1.5/24" converts model steps to days (1.5 h per step).
fontsize=14

plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(win_lens_analyticNet*1.5/24, rmses_analysis_analyticNet,
         ':', color='k', linewidth=2.5, label='analyticNet')
plt.plot(win_lens_minimalNet*1.5/24, rmses_analysis_minimalNet,
         '-', color='k', linewidth=2.5, label='minimalNet')
plt.plot(win_lens_bilinNet*1.5/24, rmses_analysis_bilinNet,
         '--', color='k', linewidth=2.5, label='bilinNet')
plt.xlabel('integration window length [d]', fontsize=fontsize)
plt.ylabel('RMSE', fontsize=fontsize)
plt.yticks([0.4, 0.5, 0.6, 0.7], fontsize=fontsize)
plt.xticks(0.5*np.arange(1, 8.1), fontsize=fontsize)
plt.legend(fontsize=fontsize)


plt.subplot(1,2,2)
# average RMSE over windows (axis 0) and the last axis, per lead time
plt.plot(pred_lens_analyticNet,np.nanmean(rmses_pred_analyticNet,axis=(0,-1)),
         ':', color='k', linewidth=2.5, label='analyticNet')
plt.plot(pred_lens_minimalNet,np.nanmean(rmses_pred__minimalNet,axis=(0,-1)),
         '-', color='k', linewidth=2.5, label='minimalNet')
plt.plot(pred_lens_bilinNet,np.nanmean(rmses_pred__bilinNet,axis=(0,-1)),
         '--', color='k', linewidth=2.5, label='bilinNet')
plt.xlabel('forecast time [d]', fontsize=fontsize)
plt.ylabel('RMSE', fontsize=fontsize)
plt.yticks([0.5, 1.0, 1.5], fontsize=fontsize)
plt.xticks(0.5*np.arange(10.1), fontsize=fontsize)
plt.legend(fontsize=fontsize)
os.makedirs(res_dir + 'figs/', exist_ok=True)
# savefig() no longer accepts `frameon`; a fully transparent figure
# background (facecolor='none') is the documented replacement.
plt.savefig(res_dir + 'figs/4DVar.pdf', bbox_inches='tight', pad_inches=0, facecolor='none')
plt.show()

# Parametrization results

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout
from L96_emulator.parametrization import Parametrized_twoLevel_L96, Parametrization_lin
from L96_emulator.networks import Model_forwarder_rk4default
from L96_emulator.run_parametrization import setup_parametrization
from L96_emulator.data_assimilation import get_model
from L96_emulator.util import as_tensor, dtype_np, sortL96intoChannels, sortL96fromChannels
from L96sim.L96_base import f1, f2, pf2
import numpy as np
import torch
import os

exp_id = '01'

# Pick the first config file whose name starts with exp_id, and drop its last
# four characters (the extension, re-added as '.yml' when calling setup).
exp_names = os.listdir('experiments_parametrization/')
conf_exp = [name for name in exp_names if name[:2] == str(exp_id)][0][:-4]
args = setup_parametrization(conf_exp=f'experiments_parametrization/{conf_exp}.yml')
args.pop('conf_exp')

# Saved outputs of the parametrization run for this experiment.
save_dir = 'results/parametrization/' + args['exp_id'] + '/'
out = np.load(res_dir + save_dir + 'out.npy', allow_pickle=True)[()]

# Simulation length in model steps for every model compared below.
T_dur = 5000
# Initial state stored with the run; the one-level models below use only its
# first K columns, the two-level model uses all of it.
X_init = out['X_init']

# Emulator config: J_net=0, i.e. loaded without fast variables (the
# "one-level" model in the panels below).
model_pars = {
    'exp_id' : args['model_exp_id'],
    'model_forwarder' : args['model_forwarder'],
    'K_net' : args['K'],
    'J_net' : 0,
    'dt_net' : args['dt']
}
# trained parametrized model
# (load the emulator, then wrap it with the linear parametrization whose
# coefficients a, b are read from the saved training state dict)
model, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')
param_train = Parametrization_lin(a=as_tensor(out['param_train_state_dict']['a']), 
                                  b=as_tensor(out['param_train_state_dict']['b']))
model_parametrized_train = Parametrized_twoLevel_L96(emulator=model, parametrization=param_train)
model_forwarder_parametrized_train = Model_forwarder_rk4default(model=model_parametrized_train, dt=args['dt'])

# initial and reference parametrized models
# NOTE(review): the coefficients below are hard-coded here; their origin
# (e.g. whether "init" matches the training start point) is not shown in this
# notebook — TODO confirm against the experiment config.
param_ref = Parametrization_lin(a=as_tensor(np.array([-0.31])), b=as_tensor(np.array([-0.2])))
param_init = Parametrization_lin(a=as_tensor(np.array([-0.75])), b=as_tensor(np.array([-0.4])))
model_parametrized_init = Parametrized_twoLevel_L96(emulator=model, parametrization=param_init)
model_forwarder_parametrized_init = Model_forwarder_rk4default(model=model_parametrized_init, dt=args['dt'])
model_parametrized_ref = Parametrized_twoLevel_L96(emulator=model, parametrization=param_ref)
model_forwarder_parametrized_ref = Model_forwarder_rk4default(model=model_parametrized_ref, dt=args['dt'])

# ground-truth high-res model: two-level L96 right-hand side (f2 from L96sim)
# f2 is passed this preallocated buffer as its output argument.
dX_dt = np.empty(args['K']*(args['J']+1), dtype=dtype_np)
def fun(t, x):
    """Two-level L96 time derivative of the flat state x; t is ignored."""
    return f2(x, args['l96_F'], args['l96_h'], args['l96_b'], args['l96_c'], dX_dt, args['K'], args['J'])
class Torch_solver(torch.nn.Module):
    """Wrap a numpy right-hand side ``fun(t, x)`` as a torch module.

    ``forward`` flattens the channel-sorted input tensor into one numpy state
    vector (so a batch size of 1 is assumed), evaluates ``fun`` at t=0 and
    returns the derivative as a channel-sorted tensor.
    """
    def __init__(self, fun):
        # nn.Module.__init__ must run before attribute assignment and before
        # the module is called; otherwise its hook/submodule registries are
        # missing and __call__ fails with AttributeError.
        super().__init__()
        self.fun = fun
    def forward(self, x):
        x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
        # copy: fun may return the shared dX_dt buffer, which the next call
        # (e.g. the next RK4 stage) would overwrite
        dx = np.atleast_2d(self.fun(0., x)).copy()
        # return a tensor so the RK4 forwarder combines like with like
        return as_tensor(sortL96intoChannels(dx, J=args['J']))
model_forwarder_np = Model_forwarder_rk4default(Torch_solver(fun), dt=args['dt'])

# Models compared: one-level emulator, emulator + initial / learned / reference
# parametrization, and the two-level L96 solver (ground truth, last entry).
model_forwarders = [Model_forwarder_rk4default(model, dt=args['dt']),
                    model_forwarder_parametrized_init,
                    model_forwarder_parametrized_train,
                    model_forwarder_parametrized_ref,
                    model_forwarder_np]
# one-level models start from the first K columns only; the two-level model
# gets the full initial state
X_inits = [X_init[:,:args['K']].copy(),
           X_init[:,:args['K']].copy(),
           X_init[:,:args['K']].copy(),
           X_init[:,:args['K']].copy(),
           X_init.copy()]
Js = [0, 0, 0, 0, args['J']]
panel_titles=['one-level L96',
              'initial param.',
              'learned param.',
              'reference param.',
              'two-level L96']
sols = [np.nan for n in range(len(model_forwarders))]
for i_model in range(len(model_forwarders)):

    model_forwarder_i, X_init_i, J_i = model_forwarders[i_model], X_inits[i_model], Js[i_model]

    def model_simulate(y0, dy0, n_steps):
        """Roll model_forwarder_i forward n_steps from y0.

        Returns an array of n_steps+1 states (y0 first). dy0 is unused.
        """
        x = np.empty((n_steps+1, *y0.shape[1:]))
        x[0] = y0.copy()
        # Evaluation only: without no_grad, each step's output stays attached
        # to the autograd graph of all previous steps (when the model has
        # trainable parameters), so memory grows with n_steps.
        with torch.no_grad():
            xx = as_tensor(x[0]).reshape(1,1,-1)
            for i in range(1,n_steps+1):
                xx = model_forwarder_i(xx.reshape(1,J_i+1,-1))
                x[i] = xx.detach().cpu().numpy().copy()
        return x

    print('forwarding model ' + panel_titles[i_model])
    sols[i_model] = model_simulate(y0=sortL96intoChannels(X_init_i,J=J_i), dy0=None, n_steps=T_dur)


In [None]:
fontsize=14
plt.figure(figsize=(12,8))

# Top row: first-channel trajectory of every model. Bottom right: RMSE of each
# model's first channel against the last model (two-level L96, sols[-1]).
rmses = np.zeros((len(model_forwarders), T_dur+1))
for i_model in range(len(model_forwarders)):

    plt.subplot(2,len(model_forwarders),i_model+1)
    plt.imshow(sortL96fromChannels(sols[i_model][:,:1,:]).T, aspect='auto')
    plt.colorbar()
    plt.title(panel_titles[i_model], fontsize=fontsize)

    if i_model == 0:
        plt.ylabel('location k', fontsize=fontsize)
    if i_model == 2:
        plt.xlabel('time [steps]', fontsize=fontsize)

    rmses[i_model,:] = np.sqrt(np.mean((sols[i_model][:,0,:] - sols[-1][:,0,:])**2, axis=1))
    plt.yticks([], fontsize=fontsize)
    plt.xticks([0, T_dur/2, T_dur], fontsize=fontsize)

# placeholder panel (content not implemented yet)
plt.subplot(2,2,3)
plt.text(0.5, 0.5, 'tbd')
plt.ylabel('parametrization parameters', fontsize=fontsize)
plt.xlabel('dataset size', fontsize=fontsize)
plt.xticks([], fontsize=fontsize)
plt.yticks([], fontsize=fontsize)


plt.subplot(2,2,4)
# the last model is the reference itself (RMSE identically zero), so skip it
for i_model in range(len(model_forwarders)-1):
    plt.plot(rmses[i_model], label=panel_titles[i_model])
plt.legend(fontsize=fontsize, frameon=False)
plt.ylabel('RMSE', fontsize=fontsize)
plt.xlabel('time [steps]', fontsize=fontsize)
plt.xticks([0, T_dur/2, T_dur], fontsize=fontsize)
plt.yticks([0, 3, 6], fontsize=fontsize)
os.makedirs(res_dir + 'figs/', exist_ok=True)
# savefig() no longer accepts `frameon`; a fully transparent figure
# background (facecolor='none') is the documented replacement.
plt.savefig(res_dir + 'figs/param.pdf', bbox_inches='tight', pad_inches=0, facecolor='none')
plt.show()