# Experimenting with implicit solvers

- mostly forward Euler vs. backward (implicit) Euler

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device, as_tensor

# Absolute root directories used by this notebook: `res_dir` is handed to
# `load_model_from_exp_conf` and `data_dir` to `get_data` further down.
# NOTE(review): these are machine-specific paths; adjust before running elsewhere.
res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'

### pick a (trained) emulator

In [None]:
from L96_emulator.run import setup

# experiments to use:
# dt=0.05  : exp_id = 26 for minimal, exp_id=27 for bilinear net
# dt=0.0125: exp_id = 28 for minimal, exp_id=29 for bilinear net
# (exp_id = 92 was the previously selected alternative)
exp_id = 70  # best Bilinear Network

# Config files in 'experiments/' are identified by their leading experiment id.
# Match on the full id string (not a fixed two-character prefix) so that ids
# with any number of digits can be selected.
exp_prefix = str(exp_id)
exp_names = os.listdir('experiments/')
matches = [name for name in exp_names if name.startswith(exp_prefix)]
if not matches:
    # fail with a readable message instead of an IndexError from an empty lookup
    raise FileNotFoundError(
        f"no experiment config starting with '{exp_prefix}' found in 'experiments/'"
    )
# first match wins (directory-listing order); drop the 4-character extension ('.yml')
conf_exp = matches[0][:-4]

args = setup(conf_exp=f'experiments/{conf_exp}.yml')
args.pop('conf_exp')  # the config path itself is not needed further down

# system size, read from the experiment configuration
K, J = args['K'], args['J']

In [None]:
# The emulator must have been trained at the same step size we simulate with.
assert args['dt_net'] == args['dt']

# L96 parameters: only the forcing F depends on whether J > 0.
F = 10. if J > 0 else 8.
h, b, c = 1., 10., 10.


### load and instantiate the emulator model and an 'optimal' comparison available for L96

In [None]:
import torch 
import numpy as np
from L96_emulator.eval import load_model_from_exp_conf
from L96_emulator.networks import named_network


# Request the 'rk4_default' forwarder for the loaded model.
args['model_forwarder'] = 'rk4_default'
if args['padding_mode'] == 'valid':
    print('switching from local training to global evaluation')
    args['padding_mode'] = 'circular'
model, model_forwarder, training_outputs = load_model_from_exp_conf(res_dir, args)

# Show the validation-loss curve when training records were returned.
if training_outputs is not None:
    training_loss = training_outputs['training_loss']
    validation_loss = training_outputs['validation_loss']
    seq_length = args['seq_length']

    fig = plt.figure(figsize=(8, 8))
    fig.patch.set_facecolor('xkcd:white')
    plt.semilogy(validation_loss, label=f'{conf_exp} ({seq_length * (J+1)}-dim)')
    plt.title('training')
    plt.ylabel('validation error')
    plt.legend()
    plt.show()

# upper bound: model re-implementation of L96 in torch (conv1d->pointwise_square->conv1d)
# The reference network reuses the trained model's architecture name when it is
# one of the two listed ones, and falls back to 'MinimalConvNetL96' otherwise.
model_ubo, model_forwarder_ubo = named_network(
    model_name=(
        args['model_name']
        if args['model_name'] in ('MinimalConvNetL96', 'BilinearConvNetL96')
        else 'MinimalConvNetL96'
    ),
    n_input_channels=J + 1,
    n_output_channels=J + 1,
    seq_length=1,
    filters=[0],
    kernel_sizes=[4],
    init_net='analytical',
    K_net=K,
    J_net=J,
    dt_net=args['dt'],
    F_net=F,
    h_net=h,
    b_net=b,
    c_net=c,
    model_forwarder='rk4_default',
    padding_mode='circular',
)

# evaluate errors

In [None]:
from L96_emulator.eval import sortL96fromChannels, sortL96intoChannels
# Preallocated (uninitialised) buffer of length K*(J+1); it is passed to the
# L96 tendency functions f1/f2 on every call of `fun`.
dX_dt = np.empty(K*(J+1), dtype=dtype_np)

### settings for different solvers

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default, Model_forwarder_forwardEuler

# Step size used with each forwarder class: the predictor-corrector scheme gets
# a tenth of the configured dt, forward Euler and RK4 use dt unchanged.
dts = {Model_forwarder_predictorCorrector : args['dt']/10,
       Model_forwarder_forwardEuler : args['dt'],
       Model_forwarder_rk4default : args['dt']}

### load some data to get sensible test system state

In [None]:
from L96_emulator.run import sel_dataset_class
from L96_emulator.util import predictor_corrector, rk4_default, get_data

# Simulation / dataset settings taken from the experiment configuration.
spin_up_time = args['spin_up_time']
train_frac = args['train_frac']
normalize_data = bool(args['normalize_data'])
T = args['T']
N_trials = args['N_trials']
dt = args['dt']

# Re-simulate the system with RK4 (nothing is written to disk).
out, _ = get_data(
    K=K, J=J, T=T, dt=dt, N_trials=N_trials,
    F=F, h=h, b=b, c=c,
    resimulate=True, solver=rk4_default,
    save_sim=False, data_dir=data_dir,
)
# A 2-D result gets a leading axis of length one so that `out` is always 3-D.
if out.ndim == 2:
    out = out.reshape(1, *out.shape)

# Dataset over the training part of the simulation, starting after spin-up.
prediction_task = 'state'
lead_time = 1
DatasetClass = sel_dataset_class(
    prediction_task=prediction_task, N_trials=N_trials, local=False
)
dg_train = DatasetClass(
    data=out,
    J=J,
    offset=lead_time,
    normalize=normalize_data,
    start=int(spin_up_time/dt),
    end=int(np.floor(T/dt*train_frac)),
)


### set up ground-truth simulator code

In [None]:
from L96sim.L96_base import f1, f2, pf2

# Right-hand side `fun(t, x)` of the ground-truth system. The variant is chosen
# once, here, from J: f2 (with F, h, b, c) when J > 0, f1 (with F only) otherwise.
# The time argument `t` is accepted but not used by either variant.
# NOTE(review): both variants hand the same module-level buffer `dX_dt` to the
# simulator; presumably the result is written into (and returned as) that
# buffer, so successive calls may overwrite earlier results — confirm in L96sim.
if J > 0:
    def fun(t, x):
        return f2(x, F, h, b, c, dX_dt, K, J)
else:
    def fun(t, x):
        return f1(x, F, dX_dt, K)

## implicit solver

In [None]:
def backwards_euler(model, x, x_next, dt):
    """Evaluate the backward-Euler update ``x + dt * model.forward(x_next)``.

    Args:
        model: object exposing ``forward``; it is evaluated at ``x_next``.
        x: current state.
        x_next: candidate for the next state (the tendency is taken here).
        dt: step size.

    Returns:
        The right-hand side of the implicit equation ``x_next = x + dt * f(x_next)``.
    """
    increment = dt * model.forward(x_next)
    return x + increment

def loss_fun(model, x, x_next, dt):
    """Mean squared residual of the backward-Euler equation at ``x_next``.

    The residual is ``x_next - backwards_euler(model, x, x_next, dt)``; it is
    zero exactly when ``x_next`` solves the implicit update.
    """
    residual = x_next - backwards_euler(model, x, x_next, dt)
    return torch.mean(residual**2)

def implicit_forwarder(x_next, f_loss, n_steps=1):
    """Minimise ``f_loss`` over ``x_next`` in place with L-BFGS.

    Args:
        x_next: tensor to optimise (must require gradients, e.g. a
            ``torch.nn.Parameter``); it is updated in place.
        f_loss: callable mapping ``x_next`` to a scalar loss tensor.
        n_steps: number of outer ``optimizer.step`` calls; each call may run
            up to ``max_iter`` L-BFGS iterations. ``0`` leaves ``x_next``
            untouched.

    Returns:
        The same ``x_next`` object after optimisation.
    """
    optimizer = torch.optim.LBFGS(params=[x_next],
                                  lr=0.1,
                                  max_iter=1000,
                                  max_eval=None,
                                  tolerance_grad=1e-12,
                                  tolerance_change=1e-15,
                                  history_size=100,
                                  line_search_fn='strong_wolfe')

    # The closure does not depend on the step index, so build it once.
    def closure():
        loss = f_loss(x_next)
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        if loss.requires_grad:
            loss.backward()
        return loss

    loss_vals = np.zeros(n_steps)
    for i_n in range(n_steps):
        optimizer.step(closure)
        # re-evaluate without building a graph, only to record the loss value
        with torch.no_grad():
            loss_vals[i_n] = f_loss(x_next).item()

    # with n_steps == 0 there is no loss to report (indexing [-1] would fail)
    if n_steps > 0:
        print('final loss:',  loss_vals[-1])

    return x_next

## error on resolvent (1-step integration error) for different solvers

In [None]:
# 100 start indices, spaced 100 steps apart, beginning right after spin-up.
t_first = int(spin_up_time/dt)
T_start = np.arange(t_first, t_first + 10000, 100)
# one randomly drawn trial index per start point
i_trial = np.random.choice(N_trials, size=T_start.shape)

# a few evenly spaced start points whose errors are printed / rolled out
idx_show = np.arange(0, T_start.size - 1, T_start.size // 3)

In [None]:
import torch 

# Banner: two blank lines, the notice, two blank lines (same output as three prints).
print('\n', 'MSEs are on system state !', '\n', sep='\n')

class Torch_solver(torch.nn.Module):
    """Wrap a numerical right-hand side ``fun(t, x)`` as a ``torch.nn.Module``.

    # numerical solver (from numpy/numba/Julia)

    ``forward`` converts the incoming tensor to numpy, evaluates ``fun`` at
    time ``0.`` and converts the result back to a tensor, so the wrapped
    function can be used wherever a torch model is expected. No gradients
    flow through this module (the input is detached).
    """

    def __init__(self, fun):
        # Module bookkeeping (_parameters, _modules, hooks, ...) only exists
        # after the base initialiser ran; without it parameters(), to(),
        # state_dict() or calling the module fail with AttributeError.
        super().__init__()
        self.fun = fun

    def forward(self, x):
        x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
        # Copy the result: `fun` may hand back a buffer it reuses between calls
        # (the module-level `fun` passes one preallocated array each time), and
        # the returned tensor must not alias it.
        dx = np.atleast_2d(self.fun(0., x)).copy()
        return as_tensor(sortL96intoChannels(dx, J=J))

# One-step (resolvent) errors of the emulator under different integrators,
# measured against an RK4 step of the numerical simulator.
# NOTE(review): `model_forwarder_ubo` wraps the *trained* `model` with RK4, not
# the analytical `model_ubo`, although its results are reported as
# 'upper bound' / 'analytic' — confirm which comparison is intended.
# NOTE(review): only the implicit result is passed through sortL96fromChannels
# before the comparison with `out_np`; check that the layouts agree for J > 0.


def _mse(a, b):
    # mean squared difference of two tensors, returned as a numpy scalar
    return ((a - b)**2).mean().detach().cpu().numpy()


# other solver choices: [Model_forwarder_predictorCorrector, Model_forwarder_rk4default]
for MFWD in [Model_forwarder_forwardEuler]:
    dt_solver = dts[MFWD]
    model_forwarder_np = Model_forwarder_rk4default(Torch_solver(fun), dt=dt_solver)
    model_forwarder = MFWD(model=model, dt=dt_solver)
    model_forwarder_ubo = Model_forwarder_rk4default(model=model, dt=dt_solver)

    n_starts = len(T_start)
    MSEs = np.zeros(n_starts)
    MSEs_ubo = np.zeros(n_starts)
    MSEs_imp = np.zeros(n_starts)
    for i, (trial, t0) in enumerate(zip(i_trial, T_start)):
        inputs = out[trial, t0]
        inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()), J=J))

        # explicit one-step predictions
        out_np = model_forwarder_np(inputs_torch)
        out_model = model_forwarder(inputs_torch)
        out_ubo = model_forwarder_ubo(inputs_torch)

        # implicit (backward-Euler) step, initialised at the current state
        def f_loss(x_next):
            return loss_fun(model=model, x=inputs_torch, x_next=x_next, dt=dts[MFWD])
        out_imp = implicit_forwarder(
            x_next=torch.nn.Parameter(1.*inputs_torch), f_loss=f_loss, n_steps=1
        )

        MSEs[i] = _mse(out_np, out_model)
        MSEs_imp[i] = _mse(out_np, sortL96fromChannels(out_imp))
        MSEs_ubo[i] = _mse(out_np, out_ubo)

    print('\n')
    print(f'solver {MFWD}, dt = {dts[MFWD]}')
    print('\n')

    print('MSEs                  ', MSEs[idx_show])
    print('MSEs - implicit method', MSEs_imp[idx_show])
    print('MSEs - upper bound    ', MSEs_ubo[idx_show])

    # sorted error curves, one line per method
    plt.figure(figsize=(8, 5))
    for errs, curve_label in (
        (MSEs, 'learned'),
        (MSEs_imp, 'learned - implicit'),
        (MSEs_ubo, 'analytic'),
    ):
        plt.plot(np.sort(errs), label=curve_label)
    plt.title('comparison of MSEs (sorted), learned and analyical')
    plt.legend()
    plt.show()

## multi-step integration error (rollout error) for different solvers

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout

print('\n')
print('MSEs are on system state !')
print('\n')


# Rollout length in model time units.
# NOTE(review): the original remark said this should roughly estimate the first
# Lyapunov exponent; presumably the Lyapunov *time* (inverse exponent) is meant — confirm.
MTU = 5 # rollout time in time units, should be rough estimate of first Lyapunov exponent

# Only the last of the selected start points is rolled out.
for i in idx_show[-1:]:
    print(f'integrating for starting point {i+1} / {len(T_start)}')
    for MFWD in [Model_forwarder_forwardEuler]:

        print(f'solver {MFWD}, dt = {dts[MFWD]}')
        # number of integration steps at the solver's step size
        T_dur = int(MTU/dts[MFWD])

        # RK4 on the numerical simulator at the solver step size
        model_forwarder_np = Model_forwarder_rk4default(Torch_solver(fun), 
                                  dt=dts[MFWD])
        # NOTE(review): `model_forwarder` is created here but not used by any
        # rollout below.
        model_forwarder = MFWD(model=model, dt=dts[MFWD])
        # RK4 on the numerical simulator (not a torch model) at a tenth of the step size
        model_forwarder_ubo = Model_forwarder_rk4default(model=Torch_solver(fun), dt=dts[MFWD]/10.)  # 1/10-th step size
        
        # NOTE(review): the class is re-declared on every loop iteration; it
        # does not depend on loop variables.
        class Model_forwarder_implicit(torch.nn.Module):
            """Forwarder advancing a state by one backward-Euler step of size `dt`,
            obtained by minimising `loss_fun` with `implicit_forwarder`."""
            
            def __init__(self, model, dt):
                super(Model_forwarder_implicit, self).__init__()
                self.model = model
                self.dt = dt
                
            def forward(self, x):
                # loss of the implicit equation for the fixed current state x
                def f_loss(x_next):
                    return loss_fun(model=self.model, x=x, x_next=x_next, dt=self.dt)
                # the optimisation starts from a copy of x wrapped as a Parameter
                return implicit_forwarder(x_next= torch.nn.Parameter(1.*x), f_loss=f_loss, n_steps=1)                
        model_forwarder_imp = Model_forwarder_implicit(model=model, dt=dts[MFWD])
        model_forwarder_imp_10 = Model_forwarder_implicit(model=model, dt=dts[MFWD]/10) # 1/10-th step size

        # NOTE: `model_simulate` rolls out the implicit forwarder with 1/10 step
        # size, `imp_simulate` the implicit forwarder with the full step size.
        model_simulate = get_rollout_fun(dg_train, model_forwarder_imp_10, prediction_task)
        imp_simulate = get_rollout_fun(dg_train, model_forwarder_imp, prediction_task)
        ubo_simulate = get_rollout_fun(dg_train, model_forwarder_ubo, prediction_task)
        np_simulate = get_rollout_fun(dg_train, model_forwarder_np, prediction_task)

        # All rollouts start from the same dataset entry; each result is rescaled
        # with dg_train.std / dg_train.mean (presumably undoing the dataset
        # normalisation — confirm) and re-sorted out of channel layout.
        # NOTE(review): `dg_train` is indexed with T_start[i], which was built as
        # an index into `out` including the spin-up offset, while the dataset was
        # itself created with a `start` offset — verify the indices line up.
        out_np = np_simulate(y0=dg_train[T_start[i]].copy(), 
                             dy0=dg_train[T_start[i]]-dg_train[T_start[i]-dg_train.offset],
                             n_steps=T_dur)
        out_np = sortL96fromChannels(out_np * dg_train.std + dg_train.mean)

        # ten times as many steps because this forwarder uses a tenth of the step size
        out_model = model_simulate(y0=dg_train[T_start[i]].copy(), 
                                   dy0=dg_train[T_start[i]]-dg_train[T_start[i]-dg_train.offset],
                                   n_steps=T_dur*10)
        out_model = sortL96fromChannels(out_model * dg_train.std + dg_train.mean)

        out_imp = imp_simulate(y0=dg_train[T_start[i]].copy(), 
                               dy0=dg_train[T_start[i]]-dg_train[T_start[i]-dg_train.offset],
                               n_steps=T_dur)
        out_imp = sortL96fromChannels(out_imp * dg_train.std + dg_train.mean)

        # ten times as many steps because this forwarder uses a tenth of the step size
        out_ubo = ubo_simulate(y0=dg_train[T_start[i]].copy(), 
                               dy0=dg_train[T_start[i]]-dg_train[T_start[i]-dg_train.offset],
                                   n_steps=T_dur*10)
        out_ubo = sortL96fromChannels(out_ubo * dg_train.std + dg_train.mean)

        # NOTE(review): the curves passed here are the implicit rollouts (full and
        # 1/10 step size), whereas the legend speaks of 'trained model' and
        # 'upper-bound model'; `out_model` also has T_dur*10 steps while
        # n_steps=T_dur — confirm labels and step alignment against plot_rollout.
        fig = plot_rollout(out_np, out_imp, out_comparison=out_model, n_start=0, n_steps=T_dur, K=K)
        plt.subplot(1,2,2)
        plt.legend(['trained model', '(only slow vars)', 'upper-bound model', '(only slow vars)'])
        plt.suptitle('integration scheme: ' + str(MFWD))
        plt.show()

In [None]:
# 2x2 overview of the four rollouts computed in the previous cell.
plt.figure(figsize=(16,8))

fontsize=14

# (rollout array, panel title) in subplot order 1..4
panels = (
    (out_np, 'RK4 (true simulator)'),
    (out_ubo, 'RK4 (true simulator), 1/10 step size'),
    (out_imp, 'backwards-Euler (emulator)'),
    (out_model, 'backwards-Euler (emulator), 1/10 step size'),
)
for panel_idx, (rollout, panel_title) in enumerate(panels, start=1):
    plt.subplot(2, 2, panel_idx)
    plt.imshow(rollout.T, aspect='auto')
    plt.title(panel_title, fontsize=fontsize)
    if panel_idx > 2:
        # bottom row carries the x-axis label
        plt.xlabel('integration step n', fontsize=fontsize)
    if panel_idx % 2 == 1:
        # left column carries the y-axis label
        plt.ylabel('location k', fontsize=fontsize)
    plt.colorbar()

#plt.savefig(res_dir + 'figs/implicit.pdf', bbox_inches='tight', pad_inches=0, frameon=False)

plt.show()

In [None]:
# Rollout comparison plot; relies on out_np / out_model / out_ubo, T_dur and
# MFWD left over from the rollout cell.
# NOTE(review): unlike the earlier plot_rollout call, no K is passed here, and
# the legend names ('deepNet', 'sqrNet (analytic)') do not obviously match the
# plotted arrays — confirm.
fig = plot_rollout(
    out_np,
    out_model,
    out_comparison=out_ubo,
    n_start=0,
    n_steps=T_dur,
)
plt.subplot(1, 2, 2)
plt.legend(['deepNet', 'sqrNet (analytic)'])
plt.suptitle(f'integration scheme: {MFWD}')
#plt.savefig(res_dir + 'figs/deepNet_rollout.pdf', bbox_inches='tight', pad_inches=0, frameon=False)