In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directories for stored results and data. The defaults are the original
# cluster locations; set L96_RES_DIR / L96_DATA_DIR in the environment to run
# the notebook on another machine without editing it.
# os.path.join(path, '') guarantees the trailing separator that the plain
# string concatenations further down (res_dir + ...) rely on.
res_dir = os.path.join(
    os.environ.get('L96_RES_DIR', '/gpfs/work/nonnenma/results/emulators/L96/'), '')
data_dir = os.path.join(
    os.environ.get('L96_DATA_DIR', '/gpfs/work/nonnenma/data/emulators/L96/'), '')

In [None]:
from L96_emulator.run import setup
from L96_emulator.run_DA import setup_4DVar
from L96_emulator.likelihood import ObsOp_identity, ObsOp_subsampleGaussian, ObsOp_rotsampleGaussian
from L96_emulator.data_assimilation import GenModel, get_model, as_tensor
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch

# Colour cycle used by the per-experiment diagnostic plots ...
clrs = ['w', 'b', 'c', 'g', 'y', 'r', 'm', 'k']
# ... and the (initially empty) list collecting their legend entries.
lgnd = []

def get_analysis_rmses_4DVar_exp(exp_ids, ifplot=False):
    """Compute the total analysis (initial-state) RMSE of several 4D-Var experiments.

    For every experiment id the matching configuration in ``experiments_DA/``
    is read with ``setup_4DVar`` and the stored result ``out.npy`` is loaded
    from below ``res_dir``. For each analysis window ``i`` the squared error
    between the estimated initial state ``x_sols[i]`` and the reference state
    ``data[i * T_shift]`` is averaged over the last two axes.

    Parameters
    ----------
    exp_ids : sequence
        Experiment identifiers. Each one is compared (as ``str``) with the
        first two characters of the file names in ``experiments_DA/``.
    ifplot : bool, optional
        If True, additionally plot the per-window MSEs of all experiments in a
        three-panel figure and print the mean MSE (first window excluded) of
        each experiment. Legend entries are appended to the module-level
        list ``lgnd``.

    Returns
    -------
    win_lens : np.ndarray, shape (len(exp_ids),)
        ``T_win`` of every experiment.
    rmses_total : np.ndarray, shape (len(exp_ids),)
        Square root of the nan-mean of all per-window MSEs of every experiment.

    Raises
    ------
    IndexError
        If no configuration file matches an experiment id.
    AssertionError
        If ``T_win`` of the configuration and of the stored result disagree.
    """
    rmses_total = np.zeros(len(exp_ids))
    win_lens = np.zeros(len(exp_ids))

    # Loop-invariant: list the available experiment configurations only once.
    exp_names = os.listdir('experiments_DA/')

    if ifplot:
        # `lgnd` is shared between calls; remember where this call starts so
        # the legend below never shows labels left over from earlier calls.
        lgnd_start = len(lgnd)
        plt.figure(figsize=(16,6))
        plt.subplot(1,3,1)
        for clr in clrs:
            plt.plot(-100, -1, 'o-', color=clr, linewidth=2.5)

    for eid, exp_id in enumerate(exp_ids):

        matches = [name for name in exp_names if name[:2] == str(exp_id)]
        if not matches:
            raise IndexError(
                f"no experiment configuration for exp_id {exp_id!r} in 'experiments_DA/'")
        conf_exp = matches[0][:-4]  # drop the 4-character extension ('.yml')
        args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
        args.pop('conf_exp')

        save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
        fn = save_dir + 'out.npy'

        # NOTE: allow_pickle=True unpickles arbitrary objects - only load
        # result files that come from a trusted source.
        out = np.load(res_dir + fn, allow_pickle=True)[()]

        T_win = args['T_win']
        # A negative T_shift in the configuration means non-overlapping windows.
        T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win

        data = out['out']
        x_sols = out['x_sols']

        assert T_win == out['T_win']

        n_windows = (data.shape[0] - T_win) // T_shift + 1
        mses = np.zeros((n_windows, data.shape[1]))
        for i in range(n_windows):
            # Compare the estimate only with the state at the start of window i
            # (instead of with every time step and discarding all but one row).
            mses[i] = np.nanmean((x_sols[i] - data[i * T_shift])**2, axis=(-2, -1))

        if ifplot:

            # One x position per window; built from n_windows so that it always
            # has the same length as `mses`.
            xx = T_shift * np.arange(n_windows)
            # Cycle through the colours so that more experiments than colours
            # do not raise an IndexError.
            clr = clrs[eid % len(clrs)]
            plt.subplot(1,3,1)
            plt.plot(xx, mses, 'o-', color=clr, linewidth=2.5)
            plt.xlim(0, len(data))
            plt.subplot(1,3,2)
            plt.plot(xx, np.nanmean(mses, axis=1), 'o-', color=clr, linewidth=2.5)
            plt.xlim(0, len(data))
            plt.subplot(1,3,3)
            plt.plot(xx, mses, 'o-', color=clr, linewidth=2.5)
            plt.axis([0, len(data)-1, 0, 2])
            print(np.nanmean(mses[1:]))
            lgnd.append('window length='+str(T_win))

        rmses_total[eid] = np.sqrt(np.nanmean(mses))
        win_lens[eid] = T_win

    if ifplot:

        plt.subplot(1,3,1)
        plt.title('individial trials')
        plt.ylabel('initial state MSE')
        plt.subplot(1,3,2)
        plt.title('averages over trials')
        # At most three entries, taken from the labels added by this call.
        plt.legend(lgnd[lgnd_start:lgnd_start+3])
        plt.xlabel('time t')
        plt.subplot(1,3,3)
        plt.title('inidividual trials, zoom-in on small MSEs')
        plt.show()

    return win_lens, rmses_total


def get_pred_rmses_4DVar_exp(exp_id):
    """Compute forecast RMSEs started from the 4D-Var analyses of one experiment.

    The experiment configuration is read with ``setup_4DVar``, the forward
    model and observation operator are rebuilt from it, and the stored result
    ``out.npy`` is loaded from below ``res_dir``. Every estimated initial state
    ``x_sols[i]`` is rolled out with ``GenModel._forward`` and the forecasts
    are compared with the reference states ``data`` at the matching time
    indices, averaging the squared error over the last two axes.

    Parameters
    ----------
    exp_id : str or int
        Experiment identifier; compared (as ``str``) with the first two
        characters of the file names in ``experiments_DA/``.

    Returns
    -------
    pred_lens : np.ndarray
        Forecast lead times in days (model steps times 1.5 h / 24 h).
    rmses : np.ndarray, shape (n_windows, len(pred_lens), y.shape[1])
        Square root of the MSE per analysis window, lead time and entry along
        the second axis of the stored observations.

    Raises
    ------
    IndexError
        If no configuration file matches ``exp_id``.
    NotImplementedError
        If the configured observation operator is not one of the three
        operators handled here.
    AssertionError
        If ``T_win`` of the configuration and of the stored result disagree.
    """
    exp_names = os.listdir('experiments_DA/')
    matches = [name for name in exp_names if name[:2] == str(exp_id)]
    if not matches:
        raise IndexError(
            f"no experiment configuration for exp_id {exp_id!r} in 'experiments_DA/'")
    conf_exp = matches[0][:-4]  # drop the 4-character extension ('.yml')
    args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
    args.pop('conf_exp')

    #assert args['T_win'] == 64 # we want 4d integration window here

    K, J = args['K'], args['J']
    T_win = args['T_win']
    # A negative T_shift in the configuration means non-overlapping windows.
    T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win

    model_pars = {
        'exp_id' : args['model_exp_id'],
        'model_forwarder' : 'rk4_default',
        'K_net' : args['K'],
        'J_net' : args['J'],
        'dt_net' : args['dt']
    }

    _, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

    obs_operator = args['obs_operator']
    if obs_operator == 'ObsOp_subsampleGaussian':
        model_observer = ObsOp_subsampleGaussian(r=args['obs_operator_r'],
                                                 sigma2=args['obs_operator_sig2'])
    elif obs_operator == 'ObsOp_identity':
        model_observer = ObsOp_identity()
    elif obs_operator == 'ObsOp_rotsampleGaussian':
        model_observer = ObsOp_rotsampleGaussian(frq=args['obs_operator_frq'],
                                                 sigma2=args['obs_operator_sig2'])
    else:
        raise NotImplementedError(f'unsupported observation operator {obs_operator!r}')

    prior = torch.distributions.normal.Normal(loc=torch.zeros((1,J+1,K)),
                                              scale=1.*torch.ones((1,J+1,K)))

    # ### define generative model for observed data
    gen = GenModel(model_forwarder, model_observer, prior, T=T_win, x_init=None)

    hours_per_step = 1.5                       # time-step length used for all conversions below
    forecast_win = int(120/hours_per_step)     # 5d forecast
    eval_every = int(6/hours_per_step)         # every 6h
    lead_steps = np.arange(0, forecast_win+1, eval_every)

    save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
    fn = save_dir + 'out.npy'

    # NOTE: allow_pickle=True unpickles arbitrary objects - only load result
    # files that come from a trusted source.
    out = np.load(res_dir + fn, allow_pickle=True)[()]

    data = out['out']
    y = out['y']
    x_sols = out['x_sols']

    assert T_win == out['T_win']

    # The last reference index used for window i is i*T_shift + T_win + forecast_win
    # and has to stay below data.shape[0]; hence the "- 1". Without it the final
    # window indexes one past the end of `data` whenever the remaining length
    # is an exact multiple of T_shift.
    n_windows = max((data.shape[0] - 1 - forecast_win - T_win) // T_shift + 1, 0)
    mses = np.zeros((n_windows, len(lead_steps), y.shape[1]))

    # The forecasts are only evaluated, never differentiated, so skip building
    # the autograd graph for the roll-outs.
    with torch.no_grad():
        for i in range(n_windows):
            forecasts = gen._forward(x=as_tensor(x_sols[i]), T_obs=T_win + lead_steps)
            n = i * T_shift + T_win
            for j, lead in enumerate(lead_steps): # loop over forecast lead times
                forecast = forecasts[j].detach().cpu().numpy()
                y_obs = data[n+lead] # sortL96intoChannels(y[n+j*eval_every],J=J)
                mses[i,j] = np.nanmean((forecast - y_obs)**2, axis=(-2, -1))

    pred_lens = hours_per_step/24 * lead_steps

    return pred_lens, np.sqrt(mses)

In [None]:

# Experiment ids of the 4D-Var runs, one list per emulator.
exp_ids_minimalNet  = ['74', '70', '66', '62', '58', '54', '50', '46', '78']
exp_ids_bilinNet    = ['75', '71', '67', '63', '59', '55', '51', '47', '79']
exp_ids_analyticNet = ['76', '72', '68', '64', '60', '56', '52', '48', '80']
exp_ids_deepNet     = ['77', '73', '69', '65', '61', '57', '53', '49', '81']

# Window lengths and total analysis RMSEs per emulator
# (evaluated in the order minimalNet, bilinNet, analyticNet, deepNet).
(
    (win_lens_minimalNet, rmses_analysis_minimalNet),
    (win_lens_bilinNet, rmses_analysis_bilinNet),
    (win_lens_analyticNet, rmses_analysis_analyticNet),
    (win_lens_deepNet, rmses_analysis_deepNet),
) = [
    get_analysis_rmses_4DVar_exp(exp_ids=ids)
    for ids in (exp_ids_minimalNet, exp_ids_bilinNet, exp_ids_analyticNet, exp_ids_deepNet)
]

# Forecast lead times and forecast RMSEs for one selected experiment per emulator
# (same evaluation order as above).
(
    (pred_lens_minimalNet, rmses_pred_minimalNet),
    (pred_lens_bilinNet, rmses_pred_bilinNet),
    (pred_lens_analyticNet, rmses_pred_analyticNet),
    (pred_lens_deepNet, rmses_pred_deepNet),
) = [
    get_pred_rmses_4DVar_exp(exp_id=single_id)
    for single_id in ('46', '47', '48', '49')
]

# Alternative setting, kept for reference: shorter time-series (T=25)

#exp_ids_minimalNet = ['14', '15', '16', '17', '18', '19', '20', '21']
#exp_ids_bilinNet = ['22', '23', '24', '25', '26', '27', '28', '29']
#exp_ids_analyticNet = ['30', '31', '32', '33', '34', '35', '36', '37']
#exp_ids_deepNet = ['38', '39', '40', '41', '42', '43', '44', '45']

#pred_lens_minimalNet, rmses_pred_minimalNet = get_pred_rmses_4DVar_exp(exp_id='21')
#pred_lens_bilinNet, rmses_pred_bilinNet = get_pred_rmses_4DVar_exp(exp_id='29')
#pred_lens_analyticNet, rmses_pred_analyticNet = get_pred_rmses_4DVar_exp(exp_id='37')
#pred_lens_deepNet, rmses_pred_deepNet = get_pred_rmses_4DVar_exp(exp_id='45')


In [None]:
# Summary figure: analysis RMSE vs. integration-window length (left) and
# forecast RMSE vs. lead time (right) for the four models.
fontsize = 14
steps_to_days = 1.5/24  # window lengths are given in model steps of 1.5 h

# Shared curve order, labels and colours of both panels. A dedicated name is
# used for the colours so that the module-level `clrs` read by
# get_analysis_rmses_4DVar_exp is not silently replaced by this cell.
labels = ['analytic', 'sqNet', 'bilinNet', 'deepNet']
net_clrs = ['r', 'b', 'orange', 'g']

fig, (ax_analysis, ax_pred) = plt.subplots(1, 2, figsize=(12,5))

# --- left panel: analysis RMSE as a function of the window length ---
analysis_curves = [
    (win_lens_analyticNet, rmses_analysis_analyticNet),
    (win_lens_minimalNet, rmses_analysis_minimalNet),
    (win_lens_bilinNet, rmses_analysis_bilinNet),
    (win_lens_deepNet, rmses_analysis_deepNet),
]
for (win_lens, rmses_analysis), clr, lbl in zip(analysis_curves, net_clrs, labels):
    ax_analysis.plot(win_lens*steps_to_days, rmses_analysis,
                     '-', color=clr, linewidth=2.5, label=lbl)
ax_analysis.set_xlabel('integration window length [d]', fontsize=fontsize)
ax_analysis.set_ylabel('RMSE', fontsize=fontsize)
ax_analysis.set_yticks([0.4, 0.5, 0.6, 0.7])
ax_analysis.set_xticks(0.5*np.arange(1, 9.1))
ax_analysis.tick_params(axis='both', labelsize=fontsize)
ax_analysis.legend(fontsize=fontsize, frameon=False)
ax_analysis.spines['top'].set_visible(False)
ax_analysis.spines['right'].set_visible(False)

# --- right panel: forecast RMSE (mean +/- one standard deviation) ---
rmses_preds = [rmses_pred_analyticNet, rmses_pred_minimalNet, rmses_pred_bilinNet, rmses_pred_deepNet]
pred_lens = [pred_lens_analyticNet, pred_lens_minimalNet, pred_lens_bilinNet, pred_lens_deepNet]

for rmses_pred, pred_len, clr, lbl in zip(rmses_preds, pred_lens, net_clrs, labels):
    # Statistics over the first and last axis, leaving the lead-time axis.
    rmse_mean = np.nanmean(rmses_pred, axis=(0,-1))
    rmse_std = np.nanstd(rmses_pred, axis=(0,-1))
    #rmse_std /= np.sqrt(rmses_pred.shape[0] * rmses_pred.shape[-1])
    ax_pred.plot(pred_len, rmse_mean, '-', color=clr, linewidth=1.5, label=lbl)
    ax_pred.plot(pred_len, rmse_mean-rmse_std, '--', color=clr, linewidth=0.5)
    ax_pred.plot(pred_len, rmse_mean+rmse_std, '--', color=clr, linewidth=0.5)

ax_pred.set_xlabel('forecast time [d]', fontsize=fontsize)
ax_pred.set_ylabel('RMSE', fontsize=fontsize)
ax_pred.set_yticks([0.5, 1.0, 1.5, 2.0, 2.5])
ax_pred.set_xticks(np.arange(5.1))
ax_pred.tick_params(axis='both', labelsize=fontsize)
ax_pred.legend(fontsize=fontsize, frameon=False)
ax_pred.spines['top'].set_visible(False)
ax_pred.spines['right'].set_visible(False)

# Make sure the output directory exists before writing the figure.
os.makedirs(res_dir + 'figs/', exist_ok=True)
# savefig no longer accepts `frameon` (deprecated in Matplotlib 3.1, removed
# later); a fully transparent figure face colour is the documented replacement.
fig.savefig(res_dir + 'figs/4DVar.pdf', bbox_inches='tight', pad_inches=0, facecolor='none')
plt.show()