In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directories for results and data. The defaults are the original cluster paths;
# set L96_RES_DIR / L96_DATA_DIR to run the notebook elsewhere. A trailing separator is
# enforced because later cells build paths by string concatenation (res_dir + 'results/...').
res_dir = os.path.join(os.environ.get('L96_RES_DIR', '/gpfs/work/nonnenma/results/emulators/L96/'), '')
data_dir = os.path.join(os.environ.get('L96_DATA_DIR', '/gpfs/work/nonnenma/data/emulators/L96/'), '')

In [None]:
from L96_emulator.run import setup
from L96_emulator.run_DA import setup_4DVar
from L96_emulator.likelihood import ObsOp_identity, ObsOp_subsampleGaussian, ObsOp_rotsampleGaussian
from L96_emulator.data_assimilation import GenModel, get_model, as_tensor
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch

# Module-level plotting state: a colour palette indexed per experiment and a mutable
# list of legend labels.
clrs = ['w', 'b', 'c', 'g', 'y', 'r', 'm', 'k']
lgnd = []

def get_analysis_rmses_4DVar_exp(exp_ids, ifplot=False):
    """Analysis RMSE of a series of 4D-Var experiments.

    For each experiment, the config is read from ``experiments_DA/`` (the file whose name
    starts with ``<exp_id>_``). The results are read from
    ``res_dir + 'results/data_assimilation/<args["exp_id"]>/out.npy'``. The analysis of
    window ``i`` (``out['x_sols'][i]``) is compared with the true state ``out['out']`` at
    the window start ``i * T_shift``.

    Parameters
    ----------
    exp_ids : sequence of str or int
        Experiment ids to evaluate.
    ifplot : bool, default False
        If True, plot per-window MSEs in three panels. Also print, for each experiment,
        the mean MSE over all windows except the first.

    Returns
    -------
    win_lens : np.ndarray, shape (len(exp_ids),)
        Integration window length ``T_win`` of each experiment.
    rmses_total : np.ndarray, shape (len(exp_ids),)
        Square root of the NaN-ignoring mean MSE over all windows and over the second
        axis of ``out['out']`` (presumably independent trials).

    Raises
    ------
    IndexError
        If no config file in ``experiments_DA/`` matches an experiment id.
    """
    rmses_total = np.zeros(len(exp_ids))
    win_lens = np.zeros(len(exp_ids))
    # Local list: appending to the module-level `lgnd` made labels pile up across calls,
    # so a second call's legend showed the first call's labels.
    legend_labels = []

    if ifplot:
        plt.figure(figsize=(16,6))
        plt.subplot(1,3,1)
        # off-screen markers, one per palette colour (presumably legacy legend proxies)
        for clr in clrs:
            plt.plot(-100, -1, 'o-', color=clr, linewidth=2.5)

    # the config directory is the same for every experiment, so list it only once
    exp_names = os.listdir('experiments_DA/')
    exp_prefixes = np.array([name.split('_')[0] for name in exp_names])

    for eid, exp_id in enumerate(exp_ids):

        matches = np.where(exp_prefixes == str(exp_id))[0]
        if len(matches) == 0:
            raise IndexError(f'no config in experiments_DA/ for experiment id {exp_id!r}')
        conf_exp = exp_names[matches[0]][:-4]  # strip the '.yml' extension
        args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
        args.pop('conf_exp')

        save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
        fn = save_dir + 'out.npy'
        # allow_pickle: out.npy holds a pickled object -- only load trusted result files
        out = np.load(res_dir + fn, allow_pickle=True)[()]

        T_win = args['T_win']
        # a negative T_shift in the config means "shift by a full window"
        T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win
        assert T_win == out['T_win']

        data = out['out']
        x_sols = out['x_sols']

        n_windows = (data.shape[0] - T_win) // T_shift + 1
        mses = np.zeros((n_windows, data.shape[1]))
        for i in range(n_windows):
            # Compare only against the window-start state. The previous version computed
            # the MSE against every time step of `data` and then kept a single row.
            mses[i] = np.nanmean((x_sols[i] - data[i * T_shift])**2, axis=(-2, -1))

        if ifplot:
            # One x position per window start. np.arange(0, T - T_win, T_shift) was one
            # entry short whenever (T - T_win) is a multiple of T_shift.
            xx = T_shift * np.arange(n_windows)
            clr = clrs[eid % len(clrs)]  # cycle: there can be more experiments than colours
            plt.subplot(1,3,1)
            plt.plot(xx, mses, 'o-', color=clr, linewidth=2.5)
            plt.xlim(0, len(data))
            plt.subplot(1,3,2)
            plt.plot(xx, np.nanmean(mses, axis=1), 'o-', color=clr, linewidth=2.5)
            plt.xlim(0, len(data))
            plt.subplot(1,3,3)
            plt.plot(xx, mses, 'o-', color=clr, linewidth=2.5)
            plt.axis([0, len(data)-1, 0, 2])
            print(np.nanmean(mses[1:]))
            legend_labels.append('window length='+str(T_win))

        rmses_total[eid] = np.sqrt(np.nanmean(mses))
        win_lens[eid] = T_win

    if ifplot:

        plt.subplot(1,3,1)
        plt.title('individual trials')
        plt.ylabel('initial state MSE')
        plt.subplot(1,3,2)
        plt.title('averages over trials')
        plt.legend(legend_labels[:3])
        plt.xlabel('time t')
        plt.subplot(1,3,3)
        plt.title('individual trials, zoom-in on small MSEs')
        plt.show()

    return win_lens, rmses_total


def get_pred_rmses_4DVar_exp(exp_id, forecast_len=120, hours_per_step=1.5, eval_every_hours=6):
    """Forecast RMSE of a 4D-Var experiment as a function of lead time.

    Each window's analysis ``out['x_sols'][i]`` is the state at time ``i * T_shift``. It is
    integrated forward with the experiment's model to the end of its integration window
    and then ``forecast_len`` hours further. The forecast is compared with the true
    state ``out['out']`` every ``eval_every_hours``.

    Parameters
    ----------
    exp_id : str or int
        Experiment id; matched against the prefix before the first '_' of the config
        file names in ``experiments_DA/``.
    forecast_len : float, default 120
        Forecast horizon in hours (120 h = 5 days).
    hours_per_step : float, default 1.5
        Hours per model time step (previously hard-coded as 1.5).
    eval_every_hours : float, default 6
        Spacing of the evaluated lead times in hours (previously hard-coded as 6).

    Returns
    -------
    pred_lens : np.ndarray
        Evaluated lead times in days.
    rmses : np.ndarray, shape (n_windows, n_leads, out['y'].shape[1])
        RMSE per window, lead time and entry of the second data axis (presumably trials).

    Raises
    ------
    IndexError
        If no config in ``experiments_DA/`` matches ``exp_id``.
    NotImplementedError
        If the experiment uses an unsupported observation operator.
    ValueError
        If ``eval_every_hours`` is shorter than one model step.
    """
    exp_names = os.listdir('experiments_DA/')
    matches = np.where(np.array([name.split('_')[0] for name in exp_names]) == str(exp_id))[0]
    if len(matches) == 0:
        raise IndexError(f'no config in experiments_DA/ for experiment id {exp_id!r}')
    conf_exp = exp_names[matches[0]][:-4]  # strip the '.yml' extension
    args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
    args.pop('conf_exp')

    K, J = args['K'], args['J']
    T_win = args['T_win']
    # a negative T_shift in the config means "shift by a full window"
    T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win

    model_pars = {
        'exp_id' : args['model_exp_id'],
        'model_forwarder' : 'rk4_default',
        'K_net' : K,
        'J_net' : J,
        'dt_net' : args['dt']
    }
    _, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

    obs_operator = args['obs_operator']
    if obs_operator == 'ObsOp_subsampleGaussian':
        model_observer = ObsOp_subsampleGaussian(r=args['obs_operator_r'],
                                                 sigma2=args['obs_operator_sig2'])
    elif obs_operator == 'ObsOp_identity':
        model_observer = ObsOp_identity()
    elif obs_operator == 'ObsOp_rotsampleGaussian':
        model_observer = ObsOp_rotsampleGaussian(frq=args['obs_operator_frq'],
                                                 sigma2=args['obs_operator_sig2'])
    else:
        raise NotImplementedError(f'unsupported obs_operator {obs_operator!r}')

    prior = torch.distributions.normal.Normal(loc=torch.zeros((1,J+1,K)),
                                              scale=1.*torch.ones((1,J+1,K)))

    # generative model used here only to integrate analyses forward in time
    gen = GenModel(model_forwarder, model_observer, prior, T=T_win, x_init=None)

    forecast_win = int(forecast_len / hours_per_step)    # forecast horizon in model steps
    eval_every = int(eval_every_hours / hours_per_step)  # lead-time spacing in model steps
    if eval_every < 1:
        raise ValueError('eval_every_hours must be at least hours_per_step')
    lead_steps = np.arange(0, forecast_win + 1, eval_every)  # steps after the window end

    fn = 'results/data_assimilation/' + args['exp_id'] + '/out.npy'
    # allow_pickle: out.npy holds a pickled object -- only load trusted result files
    out = np.load(res_dir + fn, allow_pickle=True)[()]
    assert T_win == out['T_win']

    data, y, x_sols = out['out'], out['y'], out['x_sols']
    # np.mean(isnan) is a fraction; scale it so the printed value matches the label
    print('percent of NaN sols', str(100 * np.mean(np.isnan(x_sols))))

    # The largest time index read is i*T_shift + T_win + lead_steps[-1]; it must stay
    # below data.shape[0]. The old count overran by one when (T - forecast_win - T_win)
    # was a multiple of T_shift.
    n_windows = max(0, (data.shape[0] - 1 - T_win - lead_steps[-1]) // T_shift + 1)
    mses = np.zeros((n_windows, len(lead_steps), y.shape[1]))
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for i in range(n_windows):
            forecasts = gen._forward(x=as_tensor(x_sols[i]), T_obs=T_win + lead_steps)
            n = i * T_shift + T_win  # absolute time index of lead 0 (end of window i)
            for j, lead in enumerate(lead_steps):
                forecast = forecasts[j].detach().cpu().numpy()
                if np.any(np.isnan(forecast)):
                    print('warning - had NaN in forecasts!')
                x_true = data[n + lead]
                mses[i, j] = np.nanmean((forecast - x_true)**2, axis=(-2, -1))

    pred_lens = hours_per_step / 24 * lead_steps  # lead times in days
    return pred_lens, np.sqrt(mses)

In [None]:

# Experiment-id series evaluated below (one id per 4D-Var run of each model type).
# Earlier series, kept for provenance but currently not evaluated:
#   minimalNet:          74, 70, 66, 62, 58, 54, 50, 46, 78   (forecast run: 46)
#   bilinNet:            75, 71, 67, 63, 59, 55, 51, 47, 79   (forecast run: 47)
#   deepNet (old series): 77, 73, 69, 65, 61, 57, 53, 49, 81
exp_ids_analyticNet = ['76', '72', '68', '64', '60', '56', '52', '48', '80']
exp_ids_deepNet = ['92', '91', '90', '89', '88', '87', '86', '84', '85']

# analysis RMSE per experiment, together with each experiment's window length
(win_lens_analyticNet,
 rmses_analysis_analyticNet) = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_analyticNet)
(win_lens_deepNet,
 rmses_analysis_deepNet) = get_analysis_rmses_4DVar_exp(exp_ids=exp_ids_deepNet)

# forecast RMSE vs. lead time for one selected experiment per model type
(pred_lens_analyticNet,
 rmses_pred_analyticNet) = get_pred_rmses_4DVar_exp(exp_id='48')
(pred_lens_deepNet,
 rmses_pred_deepNet) = get_pred_rmses_4DVar_exp(exp_id='84')


In [None]:
# Inspect which configuration fields a 4D-Var experiment exposes. On a fresh kernel there
# is no top-level `args` at this point (it is created inside the functions above and in
# the figure cell below), so load the config of the example experiment explicitly.
conf_files = os.listdir('experiments_DA/')
conf_name = next(name[:-4] for name in conf_files if name.split('_')[0] == '89')
example_args = setup_4DVar(conf_exp=f'experiments_DA/{conf_name}.yml')
example_args.pop('conf_exp')
example_args.keys()

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Font size for all text in the figure. NOTE(review): `fontsize` was used below but never
# defined in this notebook (NameError on a fresh kernel); 14 is a placeholder -- adjust.
fontsize = 14


def _draw_window_boxes(t0, t1, t2, t3, K):
    """Outline the integration window (t1..t2, orange) and forecast window (t2..t3, purple).

    x-coordinates are relative to the first displayed time step t0. The integration-window
    edges are shifted by -0.5, presumably to sit on imshow pixel borders.
    """
    plt.plot([t1-t0-0.5, t1-t0-0.5], [0,K], linewidth=2, color='orange', label='integration window')
    plt.plot([t2-t0-0.5, t2-t0-0.5], [0,K], linewidth=2, color='orange')
    plt.plot([t1-t0-0.5, t2-t0-0.5], [0,0], linewidth=2, color='orange')
    plt.plot([t1-t0-0.5, t2-t0-0.5], [K,K], linewidth=2, color='orange')
    plt.plot([t2-t0, t2-t0], [0,K], linewidth=2, color='purple', label='forecast window')
    plt.plot([t3-t0, t3-t0], [0,K], linewidth=2, color='purple')
    plt.plot([t2-t0, t3-t0], [0,0], linewidth=2, color='purple')
    plt.plot([t2-t0, t3-t0], [K,K], linewidth=2, color='purple')


def _set_time_xticks(win_length, dt):
    """Time-axis ticks for an imshow panel: x_0 at the analysis time, then 0.5/1.0/1.5."""
    plt.xticks([win_length//2, int(0.5/dt), int(1.0/dt), int(1.5/dt)],
               [r'$x_0$', 0.5, 1.0, 1.5], fontsize=fontsize)
    plt.xlabel('time [au]', fontsize=fontsize)


def _add_side_colorbar(ax):
    """Attach a colorbar in a thin inset axis just to the right of `ax`."""
    axins = inset_axes(ax, width="5%", height="100%", loc='lower left',
                       bbox_to_anchor=(1.05, 0., 1, 1), bbox_transform=ax.transAxes, borderpad=0)
    plt.colorbar(cax=axins)


# Load the example experiment shown in the left-hand panels.
exp_id = '89'
exp_names = os.listdir('experiments_DA/')
conf_exp = exp_names[np.where(np.array([name.split('_')[0] for name in exp_names])==str(exp_id))[0][0]][:-4]
args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
args.pop('conf_exp')

K,J = args['K'], args['J']
T_win = args['T_win']
# same convention as the evaluation functions: a negative T_shift means a full window
T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win

model_pars = {
    'exp_id' : args['model_exp_id'],
    'model_forwarder' : 'rk4_default',
    'K_net' : args['K'],
    'J_net' : args['J'],
    'dt_net' : args['dt']
}
model, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')
# NOTE(review): the observation operator is hard-coded here, while the evaluation
# functions read args['obs_operator']; this assumes the example experiment uses it.
obs_pars = {'obs_operator' : ObsOp_rotsampleGaussian,
            'obs_operator_args' : {'frq' : args['obs_operator_frq'],
                                   'sigma2' : args['obs_operator_sig2']}}
model_observer = obs_pars['obs_operator'](**obs_pars['obs_operator_args'])
prior = torch.distributions.normal.Normal(loc=torch.zeros((1,J+1,K)),
                                          scale=1.*torch.ones((1,J+1,K)))
gen = GenModel(model_forwarder, model_observer, prior, T=T_win, x_init=None)

save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
fn = save_dir + 'out.npy'
out = np.load(res_dir + fn, allow_pickle=True)[()]

nc = 4  # index along the second axis of out['y'] / out['x_sols'] (presumably the trial)
win_length, forecast_length = args['T_win'], 80
n = 50  # index of the integration window to display
t0 = n*T_shift - win_length//2  # first displayed time step (half a window before t1)
t1 = n*T_shift                  # start of integration window n (time of x_sols[n])
t2 = t1 + win_length            # end of integration window, start of forecast
t3 = t2 + forecast_length       # end of forecast

# .copy(): the slice is a view, so the NaN masking would otherwise write into out['y']
data = out['y'][t0:t3,nc].T.copy()
data[data==0] = np.nan  # display missing values as missing
data[:,-t3+t2:] = np.nan  # mask out future

recon = np.nan * np.zeros_like(data)
with torch.no_grad():  # plotting only: no autograd graph needed
    background = torch.stack(gen._forward(as_tensor(out['x_sols'][n-1,nc]),
                              T_obs=np.arange(win_length//2,win_length))).squeeze().detach().cpu().numpy().T
    forecast = torch.stack(gen._forward(as_tensor(out['x_sols'][n,nc]),
                              T_obs=np.arange(win_length+forecast_length))).squeeze().detach().cpu().numpy().T
# The first win_length//2 columns of `recon` stay NaN on purpose (background not shown);
# `background` is only used in the state-comparison panel.
recon[:,win_length//2:] = forecast

plt.figure(figsize=(16,8))

# panel a, top: out['y'] with missing values and the future masked
ax = plt.subplot(3,2,1)
plt.imshow(data, aspect='auto')
plt.yticks([], fontsize=fontsize)
_draw_window_boxes(t0, t1, t2, t3, K)
plt.legend(fontsize=fontsize, frameon=False)
_set_time_xticks(win_length, args['dt'])
plt.axis([0, t3-t0, 0, K+0.5])
_add_side_colorbar(ax)

# panel a, middle: true state vs. background vs. analysis at time t1
ax = plt.subplot(3,2,3)
truth_t1 = sortL96fromChannels(out['out'][t1,nc])
analysis_t1 = sortL96fromChannels(out['x_sols'][n,nc])
plt.plot(truth_t1, color='blue', label = 'true state', linewidth=1.5)
plt.plot(background[:,-1], '--', color='black', label = 'background')
plt.plot(analysis_t1, color='orange', label = 'analysis', linewidth=1.5)
plt.legend(fontsize=fontsize, bbox_to_anchor=(1.0, 1.0), handlelength=0.7, frameon=False)
plt.xlabel(r'position $k$', fontsize=fontsize)
plt.ylabel(r'state $x_0$', fontsize=fontsize)
plt.xticks([10, 20, 30, 40], fontsize=fontsize)
plt.yticks([-5, 0, 5, 10], fontsize=fontsize)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

box = ax.get_position()
box.x0 += 0.1 * (box.x1-box.x0)
box.x1 = box.x0 + 0.8 * (box.x1-box.x0)
box.y0 += 0.15 * (box.y1-box.y0)
box.y1 = box.y0 + 0.8 * (box.y1-box.y0)
ax.set_position(box)

# A third print used to compare the channel-sorted truth with the *unsorted* analysis,
# which is either redundant or misaligned; only the two consistent comparisons remain.
print('background MSE:', np.mean((truth_t1 - background[:,-1])**2))
print('analysis MSE:', np.mean((truth_t1 - analysis_t1)**2))

# panel a, bottom: forecast from the analysis at t1
ax = plt.subplot(3,2,5)
plt.imshow(recon, aspect='auto')
plt.yticks([], fontsize=fontsize)
_draw_window_boxes(t0, t1, t2, t3, K)
plt.axis([0, t3-t0, 0, K])
_set_time_xticks(win_length, args['dt'])
_add_side_colorbar(ax)

# panel b, top: analysis RMSE vs. integration window length
ax = plt.subplot(2,4,3)
plt.plot(win_lens_analyticNet*1.5/24, rmses_analysis_analyticNet,
         '-', color='blue', linewidth=2.5, label='true model')
plt.plot(win_lens_deepNet*1.5/24, rmses_analysis_deepNet,
         '-', color='orange', linewidth=2.5, label='emulator')
plt.xlabel('integration window length [au]', fontsize=fontsize)
plt.ylabel('RMSE', fontsize=fontsize)
plt.yticks([0.4, 0.5, 0.6, 0.7], fontsize=fontsize)
plt.xticks(win_lens_deepNet[::4]*1.5/24, win_lens_deepNet[::4]*args['dt'], fontsize=fontsize)
plt.legend(fontsize=fontsize, frameon=False, handlelength=0.7)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
box = ax.get_position()
box.x0 += 0.3 * (box.x1-box.x0)
ax.set_position(box)

# panel b, bottom: forecast RMSE vs. lead time (mean +- 1 std over axes 0 and -1)
ax = plt.subplot(2,4,7)
rmses_preds = [rmses_pred_analyticNet, rmses_pred_deepNet]
pred_lens = [pred_lens_analyticNet, pred_lens_deepNet]
# own name: assigning to `clrs` here clobbered the palette used by get_analysis_rmses_4DVar_exp
pred_clrs = ['b', 'orange']
labels = ['true model', 'emulator']

for rmses_pred, pred_len, clr, lbl in zip(rmses_preds, pred_lens, pred_clrs, labels):
    nanmean, nanstd = np.nanmean(rmses_pred, axis=(0,-1)), np.nanstd(rmses_pred, axis=(0,-1))
    plt.plot(pred_len, nanmean, '-', color=clr, linewidth=2.5, label=lbl)
    plt.plot(pred_len, nanmean-nanstd, '--', color=clr, linewidth=1.0)
    plt.plot(pred_len, nanmean+nanstd, '--', color=clr, linewidth=1.0)

plt.xlabel('forecast time [au]', fontsize=fontsize)
plt.ylabel('RMSE', fontsize=fontsize)
plt.yticks([0.5, 1.0, 1.5, 2.0, 2.5], fontsize=fontsize)
# tick positions are in days; 16 steps of 1.5 h per day times dt gives model time units
plt.xticks(np.linspace(0,5,3), np.linspace(0,5,3)*16*args['dt'], fontsize=fontsize)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.gcf().text(0.1,   0.9, 'a)', fontsize=fontsize, weight='bold')
plt.gcf().text(0.53, 0.9, 'b)', fontsize=fontsize, weight='bold')
box = ax.get_position()
box.x0 += 0.3 * (box.x1-box.x0)
ax.set_position(box)

fig_dir = res_dir + 'figs/'
os.makedirs(fig_dir, exist_ok=True)
# `frameon` was removed from savefig in Matplotlib 3.3 and is rejected by newer versions;
# facecolor='none' likewise leaves the figure background undrawn.
plt.savefig(fig_dir + '4DVar.pdf', bbox_inches='tight', pad_inches=0, facecolor='none')


plt.show()