# 1.3 Solve a fully-observed inverse problem

Given $x_T$, estimate $x_0$ by matching $G^{(T)}(x_0)$ to $x_T$. Use autodiff on $G$ to calculate gradients of an error metric w.r.t. $x_0$. Compare the resulting rollout to the original 'true' simulation.

Compare three approaches:
- $\operatorname{argmin}_{x_0} \| x_T - G^{(T)}(x_0) \|$, i.e. $T$ steps in one go
- $\operatorname{argmin}_{x_i} \| x_{i+T_i} - G^{(T_i)}(x_i) \|$, i.e. $T_i$ steps at a time, with $\sum_i T_i = T$. In the extreme case of $T_i=1$, this becomes very similar to implicit numerical methods. Can invertible neural networks help beyond providing better initializations for $x_i$?
- solving backwards: this is again the extreme case $\forall i: T_i=1$. However, only for some forward numerical solvers can we simply reverse time [1] and expect to return to the initial conditions. Leap-frog works, but e.g. time-reversed forward-Euler is backward-Euler.

Generally, how do these approaches differ around & beyond the horizon of predictability? Which solutions do they pick, and how easy is it to get uncertainties from them?

[1] https://scicomp.stackexchange.com/questions/32736/forward-and-backward-integration-cause-of-errors?noredirect=1&lq=1

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directories for stored results and data. The cluster paths below stay
# the defaults; set the environment variables L96_RES_DIR / L96_DATA_DIR to
# run the notebook on another machine without editing it.
# os.path.join(..., '') guarantees the trailing separator that later string
# concatenations such as `res_dir + fn` rely on.
res_dir = os.path.join(os.environ.get('L96_RES_DIR', '/gpfs/work/nonnenma/results/emulators/L96/'), '')
data_dir = os.path.join(os.environ.get('L96_DATA_DIR', '/gpfs/work/nonnenma/data/emulators/L96/'), '')

# static case (setup_DA, run_DA)

In [None]:
from L96_emulator.run_DA import setup_DA
# Pick the experiment configuration whose file name starts with the two-character
# experiment id and load its settings.
exp_id = '05'
exp_names = os.listdir('experiments_DA/')   
matching_names = [name for name in exp_names if name[:2] == str(exp_id)]
# first match, with the last four characters (the '.yml' suffix) stripped
conf_exp = matching_names[0][:-4]
args = setup_DA(conf_exp=f'experiments_DA/{conf_exp}.yml')
args.pop('conf_exp')

In [None]:
from L96_emulator.run import setup

# Location of the stored data-assimilation results for the selected experiment.
save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
fn = save_dir + 'res.npy'

# The results are stored as a pickled object; indexing with [()] unwraps it.
res = np.load(res_dir + fn, allow_pickle=True)[()]

targets, initial_states = res['targets'], res['initial_states']

# Experiment settings that were saved alongside the results.
(J, n_steps, n_chunks, n_chunks_recursive,
 T_rollout, dt, back_solve_dt_fac, n_starts) = (
    res[key] for key in ('J', 'n_steps', 'n_chunks', 'n_chunks_recursive',
                         'T_rollout', 'dt', 'back_solve_dt_fac', 'n_starts'))

# dict-like: scheme name -> whether that scheme was run (spelling as stored)
optimiziation_schemes = res['optimiziation_schemes']

In [None]:
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
from L96_emulator.data_assimilation import get_init
from L96_emulator.data_assimilation import get_model
from L96_emulator.likelihood import GenModel, ObsOp_identity, ObsOp_subsampleGaussian
import torch 

# get model
model_pars = {
    'exp_id' : args['model_exp_id'],
    'model_forwarder' : res['model_forwarder'],
    'K_net' : res['K'],
    'J_net' : res['J'],
    'dt_net' : res['dt']
}
model, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

# ### instantiate observation operator
if res['obs_operator']=='ObsOp_subsampleGaussian':
    obs_operator = ObsOp_subsampleGaussian
elif res['obs_operator']=='ObsOp_identity':
    obs_operator = ObsOp_identity
else:
    # fail loudly instead of hitting a NameError (or a stale operator) below
    raise NotImplementedError(f"unsupported observation operator: {res['obs_operator']}")

model_observer = obs_operator(**res['obs_operator_args'])

# ### define prior over initial states
prior = torch.distributions.normal.Normal(loc=torch.zeros((1,res['J']+1,res['K'])), 
                                          scale=1.*torch.ones((1,res['J']+1,res['K'])))

# ### define generative model for observed data
gen = GenModel(model_forwarder, model_observer, prior, T=T_rollout, x_init=None)

# Schemes to analyse: exactly those flagged in the stored results.
appr_sel = optimiziation_schemes

appr_names = {
    #'LBFGS_chunks' : 'optim over single chunk (current chunk error)',
    'LBFGS_chunks' : r'implicit backward solver t=1$\rightarrow$t=0',
    'LBFGS_full_backsolve' : 'optim. initialized from forward solve in reverse', 
    'LBFGS_full_chunks' : 'full optim, init from chunks',
    'LBFGS_full_persistence' : 'optim. initialized from persistence',
    'LBFGS_recurse_chunks' : r'recursive optimization through rollout t-k$\rightarrow$t',
    'backsolve' : r'explicit forward solve in reverse t=1$\rightarrow$t=0'
}

plt_styles = {
    'LBFGS_chunks' : 'r--',
    'LBFGS_full_backsolve' : 'b-', 
    'LBFGS_full_chunks' : 'm-',
    'LBFGS_full_persistence' : 'g-',
    'LBFGS_recurse_chunks' : 'k-',
    'backsolve' : 'c--'        
}


def _stepwise_state_diff(x_sols_flat):
    """Mean-squared change between consecutive solutions along axis 0.

    The first row compares the first solution with the first target
    (`res['targets'][0]`); the remaining rows are squared differences of
    successive entries of `x_sols_flat`, averaged over the last axis.
    """
    first = ((sortL96fromChannels(res['targets'][0]) - x_sols_flat[0])**2).mean(axis=1).reshape(1,-1)
    steps = (np.diff(x_sols_flat, axis=0)**2).mean(axis=2)
    return np.concatenate([first, steps])


def _init_state_diff(x_init, x_sols_flat):
    """Mean-squared distance (over the last axis) between a solution and its initialization."""
    return ((x_init - x_sols_flat)**2).mean(axis=-1)


def _prediction_mse(x_sols_flat):
    """Forward each estimated state with `gen` and compare against `res['test_state']`.

    For every index along axis 1 the generative model is set to the estimated
    states and `gen.forward` is queried at `T_obs=[res['T_pred']-1]`; the
    squared error to `res['test_state']` is averaged over the last two axes.
    """
    x_sols_ch = sortL96intoChannels(x_sols_flat, J=res['J'])
    x_pred = np.zeros_like(x_sols_ch)
    for n_sol in range(x_sols_ch.shape[1]):
        gen.set_state(x_sols_ch[:,n_sol])
        x_pred[:,n_sol] = gen.forward(T_obs=[res['T_pred']-1])[0].detach().cpu().numpy()
    return ((res['test_state'] - x_pred)**2).mean(axis=(-2,-1))


# compute state differences

# schemes whose solutions are compared step by step (each one against its predecessor)
stepwise_schemes = ('LBFGS_chunks', 'backsolve', 'LBFGS_recurse_chunks')
# full optimizations initialized from the solutions of another scheme
init_sources = {
    'LBFGS_full_chunks' : 'x_sols_LBFGS_chunks',
    'LBFGS_full_backsolve' : 'x_sols_backsolve'
}
# processing order (kept fixed so `gen` is driven in a reproducible sequence)
scheme_order = ('LBFGS_chunks', 'LBFGS_full_chunks', 'backsolve',
                'LBFGS_full_backsolve', 'LBFGS_full_persistence', 'LBFGS_recurse_chunks')

state_diff, pred_mses = {},{}
for scheme_str in scheme_order:
    if not (scheme_str in appr_sel.keys() and appr_sel[scheme_str]):
        continue

    x_sols_scheme = res['x_sols_' + scheme_str]
    if scheme_str in stepwise_schemes:
        state_diff[scheme_str] = _stepwise_state_diff(x_sols_scheme)
    elif scheme_str in init_sources:
        # initialization = every `recursions_per_chunks`-th solution of the source scheme
        rpc = res['recursions_per_chunks']
        x_init = res[init_sources[scheme_str]][rpc-1::rpc]
        state_diff[scheme_str] = _init_state_diff(x_init, x_sols_scheme)
    else:
        # 'LBFGS_full_persistence': initialized from the targets themselves
        x_init = sortL96fromChannels(res['targets'])
        state_diff[scheme_str] = _init_state_diff(x_init, x_sols_scheme)

    pred_mses[scheme_str] = _prediction_mse(x_sols_scheme)

# persistence baseline: error made when using the target as the initial-state estimate
state_MSE_peristence = ((res['initial_states'] - res['targets'])**2).mean(axis=(-2,-1))

# same baseline, but starting from the interpolated first observation
x_init = get_init(sortL96intoChannels(res['targets_obs'][0],J=J), res['loss_mask'][0], method='interpolate')
state_MSE_peristence_obs = ((x_init - res['initial_states'])**2).mean(axis=(-2,-1))

# collect everything needed for plotting, in matching order across the lists
all_appr_names, all_losses, all_times, all_mses, all_stdfs, all_plt_styles = [], [], [], [], [], []
all_pred_mses = []
for scheme_str in list(appr_sel):
    if appr_sel[scheme_str]:

        all_appr_names.append(appr_names[scheme_str])
        all_losses.append(res['loss_vals_'+scheme_str])
        all_times.append(res['time_vals_'+scheme_str])
        all_mses.append(res['state_mses_'+scheme_str])
        all_stdfs.append(state_diff[scheme_str])
        all_plt_styles.append(plt_styles[scheme_str])
        all_pred_mses.append(pred_mses[scheme_str])


## plot and compare results

In [None]:
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels

# If True, average every quantity over axis 1 and plot a single curve per scheme.
plot_avg = True
if plot_avg:
    all_losses = [np.nanmean(l, axis=1).reshape(-1,1) for l in all_losses]
    all_times = [np.nanmean(l, axis=1).reshape(-1,1) for l in all_times]
    all_mses = [np.nanmean(l, axis=1).reshape(-1,1) for l in all_mses]


def _semilogy_per_scheme(curves, x_axis, **line_kwargs):
    """Draw one labeled semilog-y curve per scheme from column 0 of each array.

    `x_axis(curve)` supplies the x-values for a given array. When `plot_avg`
    is off, the remaining columns are drawn afterwards in the same style,
    without legend entries.
    """
    for i, curve in enumerate(curves):
        plt.semilogy(x_axis(curve), curve[:,0], all_plt_styles[i], label=all_appr_names[i], **line_kwargs)
    if not plot_avg:
        for i, curve in enumerate(curves):
            plt.semilogy(x_axis(curve), curve[:,1:], all_plt_styles[i], **line_kwargs)


def _semilogy_persistence():
    """Overlay the persistence baseline `state_MSE_peristence` in orange."""
    xx = np.linspace(0, T_rollout, state_MSE_peristence.shape[0]+1)[1:]
    plt.semilogy(xx, state_MSE_peristence[:,0], ':', color='orange', marker='x', label='persistence')
    if not plot_avg:
        plt.semilogy(xx, state_MSE_peristence[:,1:], ':', color='orange', marker='x')


def _axis_from_zero(curve):
    """x-values spanning [0, T_rollout], one per row."""
    return np.linspace(0, T_rollout, curve.shape[0])


def _axis_after_zero(curve):
    """x-values spanning (0, T_rollout], one per row (the point at 0 is dropped)."""
    return np.linspace(0, T_rollout, curve.shape[0]+1)[1:]


def _axis_from_one(curve):
    """x-values spanning [1, T_rollout], one per row."""
    return np.linspace(1, T_rollout, curve.shape[0])


plt.figure(figsize=(16,20))

# panel 1: loss, shifted by its minimum so it can be shown on a log axis
plt.subplot(4,1,1)
# alternative offset: (res['K']*(res['J']+1) * np.log(2.*np.pi))/2.
shifted_losses = [loss - np.nanmin(loss) for loss in all_losses]
_semilogy_per_scheme(shifted_losses, _axis_from_zero)
plt.legend()
plt.xlabel('rollout step')
plt.ylabel('neg. log-likelihood (up to constant)')
plt.title('loss during optimization for different initialization methods')

# panel 2: error of the estimated initial state
plt.subplot(4,1,2)
_semilogy_per_scheme(all_mses, _axis_after_zero, marker='.')
_semilogy_persistence()
plt.legend()
plt.xlabel('rollout steps')
plt.ylabel('initial state MSE')
plt.title('initial state error for different initialization methods')

# panel 3: how far each solution moved away from its initialization
plt.subplot(4,1,3)
_semilogy_per_scheme(all_stdfs, _axis_after_zero, marker='.')
_semilogy_persistence()
plt.legend()
plt.xlabel('rollout steps')
plt.ylabel('mean-squared distance to initialization')
plt.title('difference to initial state estimate initialization')

# panel 4: computation time
plt.subplot(4,1,4)
# Hide implausibly large timings (> 1e6). This masks entry by entry on a copy:
# the earlier `tms[np.where(tms > 1e6)[0]] = np.nan` indexed by row only, so it
# blanked complete rows, and it modified the stored arrays in place.
times_masked = [np.where(tms > 1e6, np.nan, tms) for tms in all_times]
_semilogy_per_scheme(times_masked, _axis_from_one)
plt.legend()
plt.xlabel('rollout steps')
plt.ylabel('computation time [s]')
plt.title('full computation time for different initialization methods')

plt.suptitle('exp_id : ' + conf_exp)

plt.show()

In [None]:
# Quick numeric check: stored state MSEs of the 'LBFGS_full_backsolve' scheme, averaged over axis 1.
# Requires that this scheme's results are present in `res`.
res['state_mses_LBFGS_full_backsolve'].mean(axis=1)

# backsolve figure

In [None]:
# Disabled figure code, kept as a string literal so that it does not execute.
# It would plot the state MSEs with a reversed time axis ('time step T-k').
"""
plt.figure(figsize=(6,4))
for i,mse in enumerate(all_mses):
    mse = np.flipud(mse)
    xx = np.linspace(0, T_rollout, mse.shape[0]+1)[1:]
    plt.semilogy(xx, mse[:,0], all_plt_styles[i], marker='.',label=all_appr_names[i])        
if not plot_avg:
    for i,mse in enumerate(all_mses):
        mse = np.flipud(mse)
        xx = np.linspace(0, T_rollout, mse.shape[0]+1)[1:]
        plt.semilogy(xx, mse[:,1:], all_plt_styles[i], marker='.')        

xx = np.linspace(0, T_rollout, state_MSE_peristence.shape[0]+1)[1:]
plt.semilogy(xx, np.flipud(state_MSE_peristence[:,:1]), ':', color='orange', label='persistence')        
if not plot_avg:
    plt.semilogy(xx, np.flipud(state_MSE_peristence[:,1:]), ':', color='orange')        
plt.legend()
plt.xlabel('time step T-k')
plt.ylabel('state MSE')
#plt.title('initial state error for different initialization methods')
plt.xticks([1, 10, 20, 30, 40])
#plt.savefig()
plt.show()
"""

# static 4D-Var figure

In [None]:
# Two-panel summary: error of the estimated initial state (left) and error of
# the prediction rolled out from it (right), one curve per optimization scheme.
plt.figure(figsize=(12,4))

# ---- left panel: initial-state MSE ----
plt.subplot(1,2,1)
for curve, style, name in zip(all_mses, all_plt_styles, all_appr_names):
    xx = np.linspace(0, T_rollout, curve.shape[0]+1)[1:]
    plt.semilogy(xx, curve[:,0], style, marker='.', label=name)
if not plot_avg:
    # remaining columns, drawn without legend entries
    for curve, style in zip(all_mses, all_plt_styles):
        xx = np.linspace(0, T_rollout, curve.shape[0]+1)[1:]
        plt.semilogy(xx, curve[:,1:], style, marker='.')

# persistence baseline
xx = np.linspace(0, T_rollout, state_MSE_peristence.shape[0]+1)[1:]
plt.semilogy(xx, state_MSE_peristence[:,0], ':', color='orange', marker='x', label='persistence')
if not plot_avg:
    plt.semilogy(xx, state_MSE_peristence[:,1:], ':', color='orange', marker='x')
#plt.legend()
plt.xlabel('time step t')
plt.ylabel('initial state MSE')
plt.xticks([1, 10, 20, 30, 40])

# ---- right panel: prediction MSE ----
plt.subplot(1,2,2)
for curve, style, name in zip(all_pred_mses, all_plt_styles, all_appr_names):
    xx = np.linspace(0, T_rollout, curve.shape[0]+1)[1:]
    plt.semilogy(xx, curve[:,0], style, marker='.', label=name)
if not plot_avg:
    for curve, style in zip(all_pred_mses, all_plt_styles):
        xx = np.linspace(0, T_rollout, curve.shape[0]+1)[1:]
        plt.semilogy(xx, curve[:,1:], style, marker='.')

# persistence baseline: a constant level drawn as a horizontal line across the panel
xx = [1, T_rollout]
state_MSE_peristence_pred = ((res['initial_states'][:1] - res['test_state'])**2).mean(axis=(-2,-1))
flat_line = np.ones((2,1))
plt.semilogy(xx, flat_line*state_MSE_peristence_pred[:,0], ':', color='orange', label='persistence')
if not plot_avg:
    plt.semilogy(xx, flat_line*state_MSE_peristence_pred[:,1:], ':', color='orange')
plt.legend()
plt.xlabel('time step t')
T_pred = res['T_pred']
plt.ylabel(f'state prediction error MSE T={T_pred}')
plt.xticks([1, 10, 20, 30, 40])
plt.show()

In [None]:
# Quick numeric check: row 9 of the stored 'LBFGS_full_backsolve' state MSEs (all entries along axis 1).
res['state_mses_LBFGS_full_backsolve'][9,:].T


# check log-likelihoods and likelihood surface

In [None]:
from L96_emulator.data_assimilation import ObsOp_identity, ObsOp_subsampleGaussian, GenModel, get_model, as_tensor
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch

K,J = res['K'], res['J']

# Model settings for the likelihood checks; the forwarder is fixed to 'rk4_default' here.
# NOTE(review): the earlier model-loading cell passes res['dt'] as 'dt_net' — confirm
# that res['dt_net'] is the intended value here.
model_pars = {
    'exp_id' : args['model_exp_id'],
    'model_forwarder' : 'rk4_default',
    'K_net' : res['K'],
    'J_net' : res['J'],
    'dt_net' : res['dt_net']
}

# Discard the third return value instead of binding it to `args`: overwriting
# the experiment settings loaded at the top of the notebook would make this
# cell (and every earlier cell reading `args`) fail or change on a re-run.
model, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

# Any operator name other than 'ObsOp_subsampleGaussian' falls back to the identity operator.
ObsOp = ObsOp_subsampleGaussian if res['obs_operator']=='ObsOp_subsampleGaussian' else ObsOp_identity

# ### instantiate observation operator
model_observer = ObsOp(**res['obs_operator_args'])

# standard-normal prior, one entry per state variable
prior = torch.distributions.normal.Normal(loc=torch.zeros((1,J+1,K)), 
                                          scale=1.*torch.ones((1,J+1,K)))


In [None]:
# Pick the full-optimization solutions to inspect and remember which
# initialization produced them, so the figure title reports what is shown.
if 'x_sols_LBFGS_full_persistence' in res.keys():
    x_sols = res['x_sols_LBFGS_full_persistence'] 
    init_label = 'persistence'
else:
    x_sols = res['x_sols_LBFGS_full_backsolve'] 
    init_label = 'backsolve'
    print('persistence-initialized not found, using backsolve-initialized results')

n = 4  # index of the trial to inspect

# positions on the line through the true (0) and the estimated (1) initial state
nx = 100
xx = np.arange(-0.1,1.1,1./nx)

for j_chunks in range(x_sols.shape[0]):

    T_test = (j_chunks+1) * res['recursions_per_chunks']
    j = T_test-1  # index into res['initial_states'] matching this rollout length

    # observation time indices: last step of each chunk up to the current one
    T_obs = [(c+1)*(T_rollout//n_chunks)-1 for c in range(j_chunks+1)]

    gen = GenModel(model_forwarder, model_observer, prior, T=T_test, x_init = None)

    x_est = sortL96intoChannels(x_sols[j_chunks][n:n+1],J=J)
    x_true = res['initial_states'][j][n:n+1]

    # log-likelihood along the line x_true + a * (x_est - x_true)
    xs = [x_true + a * (x_est - x_true) for a in xx]
    lls = [gen.log_prob(
                 y=as_tensor(sortL96intoChannels(res['targets_obs'][:len(T_obs),n:n+1],J=J)), 
                 x=as_tensor(x_line), 
                 m=as_tensor(res['loss_mask'][:len(T_obs),n:n+1]),
                 T_obs=T_obs).detach().cpu().numpy() for x_line in xs]

    plt.figure(figsize=(12, 16))

    # bottom panel: log-likelihood along the line, with both end points marked
    plt.subplot(3,1,3)
    plt.plot(0., lls[np.argmin((xx-0.)**2)], 'gx', markersize=10., markeredgewidth=3, label='true initial state')
    plt.plot(1., lls[np.argmin((xx-1.)**2)], 'mo', markersize=10., markeredgewidth=3, label='est. initial state')
    plt.legend()
    plt.ylabel('log-likelihood')
    plt.xlabel('position on line from true to est. init state')
    plt.plot(xx, np.array(lls))
    plt.title(f'loss-surface for T_rollout={T_test} and estimated initial states (LBFGS, full optim, init with {init_label})')

    # positions flagged in the loss mask at the last observation time
    m = np.where(sortL96fromChannels(res['loss_mask'][len(T_obs)-1, n:n+1]))[1]

    # middle panel: last returned state of the rollouts from estimated / true initial state
    plt.subplot(3,1,2)
    x = gen.forward(as_tensor(x_est), T_obs=T_obs)[-1]
    plt.plot(np.arange(K*(J+1))+1, sortL96fromChannels(x).flatten(), 'm', label='from est. initial state')
    plt.plot(m+1, sortL96fromChannels(x).T[m], 'o', color='m', label='observed value')

    x = gen.forward(as_tensor(x_true), T_obs=T_obs)[-1]
    plt.plot(np.arange(K*(J+1))+1, sortL96fromChannels(x).flatten(), 'g', label='from true initial state')
    plt.plot(m+1, sortL96fromChannels(x).T[m], 'x', color='g', label='observed value')

    plt.plot(m+1, sortL96fromChannels(res['targets_obs'][len(T_obs)-1, n:n+1]).flatten()[m], 'kx', 
             label='observed state')

    plt.xlabel('position k')
    plt.ylabel('value X_k')
    plt.legend()
    plt.title('rollout from retrieved initial state')

    # top panel: estimated vs. true initial state
    plt.subplot(3,1,1)
    x = x_sols[j_chunks][n:n+1].flatten()
    plt.plot(np.arange(K*(J+1))+1,x, 'm', label='est. initial state')
    plt.plot(m+1, x[m], 'o', color='m', label='observed value')

    x = sortL96fromChannels(x_true).flatten()
    plt.plot(np.arange(K*(J+1))+1, x, 'g', label='true initial state')
    plt.plot(m+1, x[m], 'x', color='g', label='observed value')

    plt.xlabel('position k')
    plt.ylabel('value X_k')
    plt.legend()
    plt.title('initial state - true and estimate')
    plt.show()

### check recursively computed results

In [None]:
# Same three-panel diagnostic as in the previous cell, but for the solutions stored
# under 'x_sols_LBFGS_chunks'. Requires that this key is present in `res`.
x_sols = res['x_sols_LBFGS_chunks'] 

n = 1  # index of the trial to inspect
for j_chunks in range(x_sols.shape[0]):

    # here each entry of x_sols is treated as one more rollout step
    T_test = (j_chunks+1)
    j = T_test-1  # index into res['initial_states']

    T_obs = [j_chunks]  # a single observation time index

    print('j_chunks, T_obs', j_chunks, T_obs)
    
    gen = GenModel(model_forwarder, model_observer, prior, T=T_test, x_init = None)

    # positions on the line through the true (0) and the estimated (1) initial state
    nx = 100
    xx = np.arange(-0.1,1.1,1./nx)

    x = sortL96intoChannels(x_sols[j_chunks][n:n+1],J=J)
    # states along the line: true + a * (estimate - true)
    xs = [res['initial_states'][j][n:n+1] + a * ( x- res['initial_states'][j][n:n+1]) for a in xx]
    # log-likelihood at each of these states (the comprehension variable `x` is local to it)
    lls = [gen.log_prob(
                 y=as_tensor(sortL96intoChannels(res['targets_obs'][:len(T_obs),n:n+1],J=J)), 
                 x=as_tensor(x), 
                 m=as_tensor(res['loss_mask'][:len(T_obs),n:n+1]),
                 T_obs=T_obs).detach().cpu().numpy() for x in xs]

    plt.figure(figsize=(12, 16))
    # bottom panel: log-likelihood along the line, with both end points marked
    plt.subplot(3,1,3)
    plt.plot(0., lls[np.argmin((xx-0.)**2)], 'gx', markersize=10., markeredgewidth=3, label='true initial state')
    plt.plot(1., lls[np.argmin((xx-1.)**2)], 'mo', markersize=10., markeredgewidth=3, label='est. initial state')
    plt.legend()
    plt.ylabel('log-likelihood')
    plt.xlabel('position on line from true to est. init state')
    plt.plot(xx, np.array(lls))
    # NOTE(review): the title text mentions 'full optim, init with persistence', but this
    # cell shows the 'x_sols_LBFGS_chunks' solutions — looks copied from the cell above; confirm.
    plt.title(f'loss-surface for T_rollout={T_test} and estimated initial states (LBFGS, full optim, init with persistence)')

    # positions flagged in the loss mask at the last observation time
    m = np.where(sortL96fromChannels(res['loss_mask'][len(T_obs)-1, n:n+1]))[1]

    # middle panel: last returned state of the rollouts from estimated / true initial state
    plt.subplot(3,1,2)
    x = gen.forward(as_tensor(sortL96intoChannels(x_sols[j_chunks][n:n+1],J=J)),
                    T_obs=T_obs)[-1]
    plt.plot(np.arange(K*(J+1))+1, sortL96fromChannels(x).flatten(), 'm', label='from est. initial state')
    plt.plot(m+1, sortL96fromChannels(x).T[m], 'o', color='m', label='observed value')

    # NOTE(review): this rollout uses T_obs=[n_chunks_recursive-1] while the rollout from the
    # estimate above uses T_obs=[j_chunks]; the previous cell uses the same T_obs for both — confirm.
    x = gen.forward(as_tensor(res['initial_states'][j][n:n+1]), 
                    T_obs=[n_chunks_recursive-1])[-1]
    plt.plot(np.arange(K*(J+1))+1, sortL96fromChannels(x).flatten(), 'g', label='from true initial state')
    plt.plot(m+1, sortL96fromChannels(x).T[m], 'x', color='g', label='observed value')

    plt.plot(m+1, sortL96fromChannels(res['targets_obs'][len(T_obs)-1, n:n+1]).flatten()[m], 'kx', 
             label='observed state')

    plt.xlabel('position k')
    plt.ylabel('value X_k')
    plt.legend()
    plt.title('rollout from retrieved initial state')


    # top panel: estimated vs. true initial state
    plt.subplot(3,1,1)
    x = x_sols[j_chunks][n:n+1].flatten()
    plt.plot(np.arange(K*(J+1))+1,x, 'm', label='est. initial state')
    plt.plot(m+1, x[m], 'o', color='m', label='observed value')

    x = sortL96fromChannels(res['initial_states'][j][n:n+1]).flatten()
    plt.plot(np.arange(K*(J+1))+1, x, 'g', label='true initial state')
    plt.plot(m+1, x[m], 'x', color='g', label='observed value')

    plt.xlabel('position k')
    plt.ylabel('value X_k')
    plt.legend()
    plt.title('initial state - true and estimate')
    plt.show()

# 4D-Var

In [None]:
from L96_emulator.run import setup
from L96_emulator.run_DA import setup_4DVar

#exp_id = '10'
#exp_ids = ['14', '15', '16', '17', '18', '19', '20', '21']
#exp_ids = ['22', '23', '24', '25', '26', '27', '28', '29']
exp_ids = ['30', '31', '32', '33',  '35', '36', '37']

plt.figure(figsize=(16,6))
clrs, lgnd = ['w', 'b', 'c', 'g', 'y', 'r', 'm', 'k'], []
# off-screen dummy lines, one per colour
plt.subplot(1,3,1)
for clr in clrs:
    plt.plot(-100, -1, 'o-', color=clr, linewidth=2.5)
    
rmses_total = np.zeros(len(exp_ids))
win_lens = np.zeros(len(exp_ids))

# the directory listing does not change between experiments, so read it once
exp_names = os.listdir('experiments_DA/')   

for eid, exp_id in enumerate(exp_ids):

    # config file whose name starts with the two-character experiment id
    conf_exp = exp_names[np.where(np.array([name[:2] for name in exp_names])==str(exp_id))[0][0]][:-4]
    args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
    args.pop('conf_exp')

    save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
    fn = save_dir + 'out.npy'

    out = np.load(res_dir + fn, allow_pickle=True)[()]

    J = args['J']
    n_steps = args['n_steps']
    T_win = args['T_win'] 
    # a negative T_shift means: shift by a full window
    T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win
    dt = args['dt']

    data = out['out']
    y, m = out['y'], out['m']
    x_sols = out['x_sols']
    losses, times = out['losses'], out['times']

    assert T_win == out['T_win']

    # One MSE per window and trial: window i is compared with the data at its
    # own start time i*T_shift. Only that time slice is evaluated, rather than
    # the error against the whole data array followed by picking a single row.
    n_windows = (data.shape[0] - T_win) // T_shift + 1
    mses = np.zeros((n_windows, data.shape[1]))
    for i in range(n_windows):
        mses[i] = np.nanmean((x_sols[i] - data[i * T_shift])**2, axis=(-2, -1))

    # Window start times. Built from the number of windows so that it always has
    # one entry per row of `mses`; np.arange(0, data.shape[0] - T_win, T_shift)
    # is one element short whenever (data.shape[0] - T_win) is a multiple of T_shift.
    xx = T_shift * np.arange(n_windows)
    
    plt.subplot(1,3,1)
    plt.plot(xx, mses, 'o-', color=clrs[eid], linewidth=2.5)
    plt.xlim(0, len(data))
    plt.subplot(1,3,2)
    plt.plot(xx, np.nanmean(mses, axis=1), 'o-', color=clrs[eid], linewidth=2.5)
    plt.xlim(0, len(data))
    plt.subplot(1,3,3)
    plt.plot(xx, mses, 'o-', color=clrs[eid], linewidth=2.5)
    plt.axis([0, len(data)-1, 0, 2])
    # average MSE excluding the first window
    print(np.nanmean(mses[1:]))
    lgnd.append('window length='+str(T_win))
    
    rmses_total[eid] = np.sqrt(np.nanmean(mses))
    win_lens[eid] = T_win

plt.subplot(1,3,1)
plt.title('individial trials')
plt.ylabel('initial state MSE')
plt.subplot(1,3,2)
plt.title('averages over trials')
plt.legend(lgnd[:3])
plt.xlabel('time t')
plt.subplot(1,3,3)
plt.title('inidividual trials, zoom-in on small MSEs')
plt.show()

# overall RMSE as a function of window length
# NOTE(review): the factor 1.5/24 presumably converts steps of 1.5 h into days — confirm.
plt.figure(figsize=(8,4))
plt.plot(win_lens*1.5/24, rmses_total, '-', color='k', linewidth=2.5)
plt.xlabel('integration window length [d]')
plt.ylabel('RMSE')
plt.show()

In [None]:
from L96_emulator.likelihood import ObsOp_identity, ObsOp_subsampleGaussian, ObsOp_rotsampleGaussian
from L96_emulator.data_assimilation import GenModel, get_model, as_tensor
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch

# Forecast-skill evaluation: roll the emulator forward from each estimated window
# state and compare against `data` at regular lead times.
# NOTE(review): at this point `args` and `T_win` still hold whatever the previous cell
# left behind (the last experiment of its loop); the experiment evaluated below is only
# loaded further down, after `gen` has been built — confirm this ordering is intended.
K,J = args['K'], args['J']

model_pars = {
    'exp_id' : args['model_exp_id'],
    'model_forwarder' : 'rk4_default',
    'K_net' : args['K'],
    'J_net' : args['J'],
    'dt_net' : args['dt']
}

model, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

# map the configured operator name to its class and constructor arguments
obs_operator = args['obs_operator']
obs_pars = {}
if obs_operator=='ObsOp_subsampleGaussian':
    obs_pars['obs_operator'] = ObsOp_subsampleGaussian
    obs_pars['obs_operator_args'] = {'r' : args['obs_operator_r'], 'sigma2' : args['obs_operator_sig2']}
elif obs_operator=='ObsOp_identity':
    obs_pars['obs_operator'] = ObsOp_identity
    obs_pars['obs_operator_args'] = {}
elif obs_operator=='ObsOp_rotsampleGaussian':
    obs_pars['obs_operator'] = ObsOp_rotsampleGaussian
    obs_pars['obs_operator_args'] = {'frq' : args['obs_operator_frq'], 
                                     'sigma2' : args['obs_operator_sig2']}
else:
    raise NotImplementedError()
model_observer = obs_pars['obs_operator'](**obs_pars['obs_operator_args'])

# standard-normal prior, one entry per state variable
prior = torch.distributions.normal.Normal(loc=torch.zeros((1,J+1,K)), 
                                          scale=1.*torch.ones((1,J+1,K)))

# ### define generative model for observed data
gen = GenModel(model_forwarder, model_observer, prior, T=T_win, x_init=None)

# lead-time grid in model steps (the divisions by 1.5 suggest one step = 1.5 h — TODO confirm)
forecast_win = int(120/1.5) # 5d forecast
eval_every = int(6/1.5) # every 6h
# experiment to evaluate; `exp_names` is reused from the previous cell
exp_id = '21'
conf_exp = exp_names[np.where(np.array([name[:2] for name in exp_names])==str(exp_id))[0][0]][:-4]
args = setup_4DVar(conf_exp=f'experiments_DA/{conf_exp}.yml')
args.pop('conf_exp')

#assert args['T_win'] == 64 # we want 4d integration window here

save_dir = 'results/data_assimilation/' + args['exp_id'] + '/'
fn = save_dir + 'out.npy'

out = np.load(res_dir + fn, allow_pickle=True)[()]

J = args['J']
n_steps = args['n_steps']
T_win = args['T_win'] 
# a negative T_shift means: shift by a full window
T_shift = args['T_shift'] if args['T_shift'] >= 0 else T_win
dt = args['dt']

data = out['out']
y, m = out['y'], out['m']
x_sols = out['x_sols']
losses, times = out['losses'], out['times']

assert T_win == out['T_win']

# mses[i, j, :]: window i, j-th lead time, one value per entry along axis 1 of `y`
mses = np.zeros(((data.shape[0] - forecast_win - T_win) // T_shift + 1, forecast_win//eval_every+1, y.shape[1]))
for i in range(len(mses)):
    # emulator states requested at T_win, T_win+eval_every, ..., T_win+forecast_win
    forecasts = gen._forward(x=as_tensor(x_sols[i]), T_obs=T_win + np.arange(0,forecast_win+1,eval_every))
    n = i * T_shift + T_win  # index into `data` used for lead time 0
    # NOTE(review): for the last window, n + forecast_win can equal data.shape[0] when
    # (data.shape[0] - forecast_win - T_win) is a multiple of T_shift — check bounds.
    for j in range(mses.shape[1]): # loop over integration windows
        forecast = forecasts[j].detach().cpu().numpy()
        y_obs = data[n+j*eval_every] # sortL96intoChannels(y[n+j*eval_every],J=J)
        mses[i,j] = np.nanmean((forecast - y_obs)**2, axis=(-2, -1))
# lead times on the x-axis, using the same 1.5/24 conversion as the axis label '[d]'
xx = 1.5/24 * np.arange(0, forecast_win+1, eval_every)

plt.figure(figsize=(6,5))
# RMSE per window and trial, then averaged over windows (axis 0) and trials (last axis)
plt.plot(xx,np.nanmean(np.sqrt(mses),axis=(0,-1)), '-', color='k', linewidth=2.5)
plt.xlabel('forecast time [d]')
plt.ylabel('RMSE')
plt.yticks([0.5, 1.0, 1.5])
plt.xticks(0.5*np.arange(10.1))
plt.show()



In [None]:
from L96_emulator.data_assimilation import ObsOp_identity, ObsOp_subsampleGaussian, GenModel, get_model, as_tensor
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch

# Rebuild model, observation operator and generative model from the settings in `args`.
K,J = args['K'], args['J']

model_pars = {
    'exp_id' : args['model_exp_id'],
    'model_forwarder' : 'rk4_default',
    'K_net' : args['K'],
    'J_net' : args['J'],
    'dt_net' : args['dt']
}

model, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

# ### instantiate observation operator
# Pick the operator together with its own constructor arguments: only the
# sub-sampling operator is given 'r' and 'sigma2'. The identity operator is
# constructed without arguments, as in the preceding cell, instead of being
# handed keyword arguments (and config keys) that belong to a different operator.
if args['obs_operator']=='ObsOp_subsampleGaussian':
    ObsOp = ObsOp_subsampleGaussian
    obs_operator_args = {'r' : args['obs_operator_r'], 'sigma2' : args['obs_operator_sig2']}
else:
    ObsOp = ObsOp_identity
    obs_operator_args = {}

model_observer = ObsOp(**obs_operator_args)

# standard-normal prior, one entry per state variable
prior = torch.distributions.normal.Normal(loc=torch.zeros((1,J+1,K)), 
                                          scale=1.*torch.ones((1,J+1,K)))

# ### define generative model for observed data
gen = GenModel(model_forwarder, model_observer, prior, T=T_win, x_init=None)

In [None]:
def _plot_rollout(traj, title):
    """Show a trajectory (rows = rollout steps) as an image and as line plots."""
    plt.figure(figsize=(12,6))
    plt.subplot(1,2,1)
    plt.imshow(traj, aspect='auto')
    plt.colorbar()
    plt.xlabel('location k')
    plt.ylabel('rollout time step t')
    plt.subplot(1,2,2)
    plt.plot(traj)
    plt.ylabel('state value x_k')
    plt.xlabel('rollout time step t')
    plt.suptitle(title)
    plt.show()


def _emulator_rollout(x_start, T_obs):
    """Run `gen._forward` from `x_start` and return the states requested by `T_obs` as a numpy array."""
    states = torch.cat(gen._forward(as_tensor(x_start), T_obs=T_obs))
    return sortL96fromChannels(states.detach().cpu().numpy())


i,j = 16, 9   # window index and trial index to inspect
T = 40        # number of rollout steps to show

# offset into `data` / `y` / `m` used for window i
idx_off = (i)*T_win-1 if i > 0 else 0
t_idx = np.arange(T)+idx_off

# 1) emulator rollout from the estimated initial state of window i
x_init = x_sols[i,j] #
traj_est = _emulator_rollout(x_init, np.arange(T))
_plot_rollout(traj_est, 'emulator rollout from estimated initial state')

# 2) stored simulation over the same time span
traj_true = sortL96fromChannels(data[t_idx, j])

# error of the rollout from the estimate: against the stored simulation ...
rollout_mses = ((traj_est-traj_true)**2).mean(axis=1)

# ... and against the observations, restricted to the masked-in entries
rollout_mses_masked = (m[t_idx, j, 0]*(traj_est-y[t_idx, j])**2)
rollout_mses_masked = rollout_mses_masked.sum(axis=(-1)) / m[t_idx, j].sum(axis=(1,2))

_plot_rollout(traj_true, 'simulator rollout from true initial state')

# 3) emulator rollout started from the stored simulation state (not from the estimate);
#    the title now says so instead of repeating the title of figure 1
x_init = data[idx_off, j] #
_plot_rollout(_emulator_rollout(x_init, np.arange(T)),
              'emulator rollout from true initial state')

# 4) emulator rollout from the previous window's estimate, forwarded to T_obs=[T-1]
x_init = x_sols[i-1,j] #
x_init = _emulator_rollout(x_init, [T-1])
_plot_rollout(_emulator_rollout(x_init, np.arange(T)),
              'emulator rollout from previous estimated initial state (aka background)')


In [None]:
# Error of the emulator rollout over time, against the ground truth and against the observations.
for curve, curve_label in ((rollout_mses, 'MSE to ground-truth state'),
                           (rollout_mses_masked, 'MSE to observation (masked & noisy)')):
    plt.plot(curve, label=curve_label)
plt.xlabel('rollout time step t')
plt.ylabel('MSE')
plt.legend()
plt.show()

# share notebook results via html file

In [None]:
# Export this notebook (data_assimilation.ipynb) to HTML in the directory given by --output-dir.
# The output directory is an absolute, user-specific path; adjust it before running elsewhere.
!jupyter nbconvert --output-dir='/gpfs/home/nonnenma/projects/lab_coord/mdml_wiki/marcel/emulators' --to html data_assimilation.ipynb