# 1.3 Solve a fully-observed inverse problem

Given $x_T$, estimate $x_0$ by matching $G^{(T)}(x_0)$ to $x_T$. Use autodiff through $G$ to compute gradients of an error metric with respect to $x_0$. Compare the resulting rollout to the original 'true' simulation.

Compare three approaches:
- $\operatorname{argmin}_{x_0} \lVert x_T - G^{(T)}(x_0) \rVert$, i.e. all $T$ steps in one go
- $\operatorname{argmin}_{x_i} \lVert x_{i+T_i} - G^{(T_i)}(x_i) \rVert$, i.e. $T_i$ steps at a time, with $\sum_i T_i = T$. In the extreme case $T_i=1$, this becomes very similar to implicit numerical methods. Can invertible neural networks help beyond providing better initializations for $x_i$?
- solving backwards: essentially the extreme case $T_i=1$ for all $i$. However, only for some forward numerical solvers can we simply reverse time [1] and expect to return to the initial conditions. Leap-frog works, but e.g. time-reversed forward Euler becomes backward Euler.

Generally, how do these approaches differ around and beyond the horizon of predictability? Which solutions do they pick, and how easily can uncertainties be obtained from them?

[1] https://scicomp.stackexchange.com/questions/32736/forward-and-backward-integration-cause-of-errors?noredirect=1&lq=1

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directories for L96 emulator results and data. The cluster paths are the
# defaults; set L96_RES_DIR / L96_DATA_DIR to run elsewhere. os.path.join(..., '')
# guarantees a trailing separator, since paths may be built from these roots by
# plain string concatenation (e.g. res_dir + 'results/...').
res_dir = os.path.join(os.environ.get('L96_RES_DIR') or '/gpfs/work/nonnenma/results/emulators/L96/', '')
data_dir = os.path.join(os.environ.get('L96_DATA_DIR') or '/gpfs/work/nonnenma/data/emulators/L96/', '')

### pick a (trained) emulator

In [None]:
from L96_emulator.run import setup

# Experiment to evaluate: the numeric prefix of a config file in experiments/
# (e.g. 24 -> 'experiments/24_<name>.yml'), or None for the analytical L96 model.
exp_id = 24
# Only consulted when exp_id is None: J > 0 selects the two-level analytical
# model. J is re-read from the results file in a later cell.
J = 0


def _exp_prefix(name):
    """Return the integer before the first '_' of a config file name, or None."""
    head = name.split('_', 1)[0]
    # isdecimal (not isdigit) so that int() below cannot fail on e.g. '²'
    return int(head) if head.isdecimal() else None


if exp_id is None: 
    # loading 'perfect' (up to machine-precision-level quirks) L96 model in pytorch    
    if J > 0:
        conf_exp = '00_analyticalMinimalConvNet'
    else:
        conf_exp = '00_analyticalMinimalConvNet_oneLevel'
else:
    # Match the whole numeric prefix instead of the first two characters, so that
    # exp_id=5 finds '05_...' and exp_id=24 does not also match '240_...'. Sorting
    # makes the pick deterministic regardless of os.listdir order.
    candidates = sorted(
        name[:-len('.yml')] for name in os.listdir('experiments/')
        if name.endswith('.yml') and _exp_prefix(name) == exp_id
    )
    if not candidates:
        raise FileNotFoundError(f'no experiments/*.yml config with numeric prefix {exp_id}')
    conf_exp = candidates[0]

    args = setup(conf_exp=f'experiments/{conf_exp}.yml')
    args.pop('conf_exp')
    
print('conf_exp', conf_exp)


model_forwarder_str = 'rk4_default'

optimizer_str = 'LBFGS' #'LBFGS', 'SGD'
obs_operator_str = 'ObsOp_identity' #'ObsOp_subsampleGaussian'

In [None]:
# Name of the saved initial-state estimation results; the observation-operator
# suffix is only part of the name when an operator is configured.
fn = f'fullyobs_initstate_tests_exp{exp_id}_{model_forwarder_str}_{optimizer_str}'
if obs_operator_str is not None:
    fn += f'_{obs_operator_str}'
fn += '.npy'

# The file presumably stores a dict wrapped in a 0-d object array; [()] unwraps it.
# allow_pickle=True can execute arbitrary code on load: only load trusted result files.
res = np.load(os.path.join(res_dir, 'results', 'data_assimilation', fn), allow_pickle=True)[()]

targets=res['targets']
initial_states=res['initial_states']

# NOTE: overrides the J set in the experiment-selection cell with the value used for the results.
J = res['J']
n_steps = res['n_steps']
n_chunks = res['n_chunks']
n_chunks_recursive = res['n_chunks_recursive']
T_rollout = res['T_rollout']
dt = res['dt']
back_solve_dt_fac=res['back_solve_dt_fac']
n_starts = res['n_starts']

# (the key's spelling matches what is stored in the results file)
optimiziation_schemes = res['optimiziation_schemes']

In [None]:
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels

# Every optimization scheme stored in the results is considered (disabled ones are
# skipped below). To plot only a subset, overwrite appr_sel with a list of keys, e.g.
# appr_sel = ['LBFGS_chunks', 'backsolve']
appr_sel = optimiziation_schemes.keys()

appr_names = {
    'LBFGS_chunks' : 'optim over single chunk (current chunk error)',
    'LBFGS_full_backsolve' : 'full optim, init from backsolve', 
    'LBFGS_full_chunks' : 'full optim, init from chunks',
    'LBFGS_full_persistence' : 'full optim, init from persistence',
    'LBFGS_recurse_chunks' : 'full optim, recursive from backsolve',
    'backsolve' : 'forward solve in reverse'
}

plt_styles = {
    'LBFGS_chunks' : 'r--',
    'LBFGS_full_backsolve' : 'b-', 
    'LBFGS_full_chunks' : 'm-',
    'LBFGS_full_persistence' : 'g-',
    'LBFGS_recurse_chunks' : 'k-',
    'backsolve' : 'c--'        
}


# compute state differences

def _stepwise_state_diff(x_sols, x_start):
    """Mean-squared change between successive estimates of a sequential scheme.

    Row 0 compares x_sols[0] against x_start (mean over axis 1); row k > 0
    compares x_sols[k] against x_sols[k-1] (mean over axis 2).
    """
    first = ((x_start - x_sols[0])**2).mean(axis=1).reshape(1, -1)
    steps = (np.diff(x_sols, axis=0)**2).mean(axis=2)
    return np.concatenate([first, steps])


def _final_state_diff(x_init, x_sols):
    """Mean-squared distance (over the last axis) between initializations and solutions."""
    return ((x_init - x_sols)**2).mean(axis=-1)


x_targets = sortL96fromChannels(res['targets'])
rpc = res['recursions_per_chunks']

# Evaluated lazily, and only for enabled schemes, so that results of disabled
# schemes need not be present in `res`. The '[rpc-1::rpc]' slices keep every
# rpc-th estimate of a chunked scheme, which served as init of the full optimization.
state_diff_fns = {
    'LBFGS_chunks': lambda: _stepwise_state_diff(res['x_sols_LBFGS_chunks'], x_targets),
    'LBFGS_full_chunks': lambda: _final_state_diff(
        res['x_sols_LBFGS_chunks'][rpc - 1::rpc], res['x_sols_LBFGS_full_chunks']),
    'backsolve': lambda: _stepwise_state_diff(res['x_sols_backsolve'], x_targets),
    'LBFGS_full_backsolve': lambda: _final_state_diff(
        res['x_sols_backsolve'][rpc - 1::rpc], res['x_sols_LBFGS_full_backsolve']),
    'LBFGS_full_persistence': lambda: _final_state_diff(
        x_targets, res['x_sols_LBFGS_full_persistence']),
    'LBFGS_recurse_chunks': lambda: _stepwise_state_diff(
        res['x_sols_LBFGS_recurse_chunks'], x_targets),
}

# baseline: targets reshaped to (1, -1, J+1, K) compared with the stored initial states
state_MSE_peristence = ((res['targets'].reshape(1,-1,res['J']+1, res['K']) - res['initial_states'])**2).mean(axis=(-2,-1))

state_diff = {}
all_appr_names, all_losses, all_times, all_mses, all_stdfs, all_plt_styles = [], [], [], [], [], []
for scheme_str in appr_sel:
    if not optimiziation_schemes[scheme_str]:
        continue
    state_diff[scheme_str] = state_diff_fns[scheme_str]()

    all_appr_names.append(appr_names[scheme_str])
    all_losses.append(res['loss_vals_'+scheme_str])
    all_times.append(res['time_vals_'+scheme_str])
    all_mses.append(res['state_mses_'+scheme_str])
    all_stdfs.append(state_diff[scheme_str])
    all_plt_styles.append(plt_styles[scheme_str])


## plot and compare results

In [None]:
# With plot_avg, each array is reduced to its mean over axis 1, so column 0 is the
# only curve per scheme; with plot_avg = False, column 0 is drawn with a label and
# the remaining columns are drawn unlabeled.
plot_avg = True
if plot_avg:
    all_losses = [vals.mean(axis=1).reshape(-1,1) for vals in all_losses]
    all_times = [vals.mean(axis=1).reshape(-1,1) for vals in all_times]
    all_mses = [vals.mean(axis=1).reshape(-1,1) for vals in all_mses]

# (K*(J+1)/2) * log(2*pi): presumably the Gaussian normalizing constant of the
# loss, subtracted before plotting
neg_log_loss_offset = (res['K']*(res['J']+1) * np.log(2.*np.pi))/2.

plt.figure(figsize=(16,20))
plt.subplot(4,1,1)
for i,loss in enumerate(all_losses):
    xx = np.linspace(0, T_rollout, loss.shape[0])
    plt.semilogy(xx, loss[:,0] - neg_log_loss_offset, all_plt_styles[i], label=all_appr_names[i])
if not plot_avg:
    for i,loss in enumerate(all_losses):
        xx = np.linspace(0, T_rollout, loss.shape[0])
        plt.semilogy(xx, loss[:,1:] - neg_log_loss_offset, all_plt_styles[i])

plt.legend()
plt.xlabel('rollout step')
plt.ylabel('neg. log-likelihood (up to constant)')
plt.title('loss during optimization for different initialization methods')


from L96_emulator.util import sortL96fromChannels, sortL96intoChannels

plt.subplot(4,1,2)
for i,mse in enumerate(all_mses):
    xx = np.linspace(0, T_rollout, mse.shape[0]+1)[1:]
    plt.semilogy(xx, mse[:,0], all_plt_styles[i], marker='.',label=all_appr_names[i])
if not plot_avg:
    for i,mse in enumerate(all_mses):
        xx = np.linspace(0, T_rollout, mse.shape[0]+1)[1:]
        plt.semilogy(xx, mse[:,1:], all_plt_styles[i], marker='.')

xx = np.linspace(0, T_rollout, state_MSE_peristence.shape[0]+1)[1:]
plt.semilogy(xx, state_MSE_peristence[:,0], ':', color='orange', marker='x', label='persistence')
if not plot_avg:
    plt.semilogy(xx, state_MSE_peristence[:,1:], ':', color='orange', marker='x')
plt.legend()
plt.xlabel('rollout steps')
plt.ylabel('initial state MSE')
plt.title('initial state error for different initialization methods')

plt.subplot(4,1,3)
for i,stdf in enumerate(all_stdfs):
    xx = np.linspace(0, T_rollout, stdf.shape[0]+1)[1:]
    plt.semilogy(xx, stdf[:,0], all_plt_styles[i], marker='.',label=all_appr_names[i])
if not plot_avg:
    for i,stdf in enumerate(all_stdfs):
        xx = np.linspace(0, T_rollout, stdf.shape[0]+1)[1:]
        plt.semilogy(xx, stdf[:,1:], all_plt_styles[i], marker='.')
xx = np.linspace(0, T_rollout, state_MSE_peristence.shape[0]+1)[1:]
plt.semilogy(xx, state_MSE_peristence[:,0], ':', color='orange', marker='x', label='persistence')
if not plot_avg:
    plt.semilogy(xx, state_MSE_peristence[:,1:], ':', color='orange', marker='x')
plt.legend()
plt.xlabel('rollout steps')
plt.ylabel('mean-squared distance to initialization')
plt.title('difference to initial state estimate initialization')

plt.subplot(4,1,4)
# Hide times above 1e6 (presumably a sentinel for runs without a valid timing;
# TODO confirm). Mask element-wise so that one large entry does not blank the
# other runs of the same step. np.where returns new arrays, so the arrays stored
# in `res` are not modified.
all_times = [np.where(tms > 1e6, np.nan, tms) for tms in all_times]
for i, tms in enumerate(all_times):
    xx = np.linspace(1, T_rollout, tms.shape[0])
    plt.semilogy(xx, tms[:,0], all_plt_styles[i], label=all_appr_names[i])
if not plot_avg:
    for i, tms in enumerate(all_times):
        xx = np.linspace(1, T_rollout, tms.shape[0])
        plt.semilogy(xx, tms[:,1:], all_plt_styles[i])

plt.legend()
plt.xlabel('rollout steps')
plt.ylabel('computation time [s]')
plt.title('full computation time for different initialization methods')

plt.suptitle('exp_id : ' + conf_exp)

plt.show()

In [None]:
# Disabled: re-simulate the L96 ground truth with the parameters stored in `res`.
# Kept as comments rather than a bare string literal, because Jupyter echoes a
# cell's trailing string expression as cell output.
#
# from L96_emulator.util import predictor_corrector, rk4_default, get_data
#
# out, datagen_dict = get_data(K=res['K'], J=res['J'], T=res['T'], dt=res['dt'], N_trials=1,
#                              F=res['F'], h=res['h'], b=res['b'], c=res['c'],
#                              resimulate=True, solver=rk4_default,
#                              save_sim=False, data_dir=data_dir)

# share notebook results via html file

In [None]:
# Export this notebook (data_assimilation.ipynb) as HTML into the wiki directory
# below, via an IPython shell escape.
!jupyter nbconvert --output-dir='/gpfs/home/nonnenma/projects/lab_coord/mdml_wiki/marcel/emulators' --to html data_assimilation.ipynb