# Emulators

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Storage locations on the cluster file system:
# res_dir is handed to load_model_from_exp_conf when a trained model is loaded,
# data_dir is where the simulated trajectory is saved to / loaded from.
res_dir, data_dir = (
    '/gpfs/work/nonnenma/results/emulators/L96/',
    '/gpfs/work/nonnenma/data/emulators/L96/',
)

### load / simulate some toy data
- note we're simulating with the numba simulator, so that one will potentially have an edge over the pytorch emulator in predicting its own output

In [None]:
from L96sim.L96_base import f1, f2, pf2
from L96_emulator.util import predictor_corrector, rk4_default
from L96_emulator.run import sel_dataset_class

# Take the simulation settings from an `args` mapping when one exists in the
# session (e.g. after an experiment configuration has been loaded); otherwise
# fall back to the notebook defaults.
try:
    K, J, T, dt = args['K'], args['J'], args['T'], args['dt']
    spin_up_time, train_frac = args['spin_up_time'], args['train_frac']
    normalize_data = bool(args['normalize_data'])
except (NameError, KeyError, TypeError):
    # NameError: no `args` defined yet; KeyError: a setting is missing;
    # TypeError: `args` is not subscriptable. Anything else (including
    # KeyboardInterrupt) is deliberately left to propagate.
    K, J, T, dt = 36, 10, 605, 0.01
    spin_up_time, train_frac = 5., 0.8
    normalize_data = False

# Model constants forwarded unchanged to the L96sim right-hand sides below
# (presumably forcing F, coupling h and scale parameters b, c -- confirm
# against L96sim.L96_base).
F, h, b, c = 10, 1, 10, 10

# File stem under which the simulated trajectory is stored. str(dt)[2:] drops
# the first two characters of the printed step size, i.e. it assumes dt
# prints as '0.<digits>'.
fn_data = f'out_K{K}_J{J}_T{T}_dt0_{str(dt)[2:]}'

# Right-hand side handed to the integrator. `dX_dt` is looked up in module
# scope at call time, so it has to be assigned before `fun` is first called.
if J > 0:
    def fun(t, x):
        """Evaluate f2 for state x; the time argument t is not used."""
        return f2(x, F, h, b, c, dX_dt, K, J)
else:
    def fun(t, x):
        """Evaluate f1 for state x; the time argument t is not used."""
        return f1(x, F, dX_dt, K)

# Either integrate a fresh trajectory (optionally storing it on disk) or read
# a previously stored one back in.
resimulate, save_sim = True, True
if not resimulate:
    print('loading data')
    out = np.load(data_dir + fn_data + '.npy')
else:
    print('simulating data')
    # random initial state with K*(J+1) entries
    init_noise = np.random.randn(K*(J+1)) * 1.0
    X_init = F * (0.5 + init_noise).astype(dtype=dtype_np) / np.maximum(J,50)
    # buffer passed to f1/f2 via `fun` (presumably their output buffer -- TODO confirm)
    dX_dt = np.empty(X_init.size, dtype=X_init.dtype)
    n_time_points = int(np.floor(T/dt)+1)
    times = np.linspace(0, T, n_time_points)

    out = rk4_default(fun=fun, y0=X_init.copy(), times=times)

    # np.save appends the '.npy' suffix that the loading branch expects
    if save_sim:
        np.save(data_dir + fn_data, out.astype(dtype=dtype_np))

# Overview image of the whole trajectory: `out` is transposed so that its
# first axis runs along the horizontal ('time') axis of the image.
plt.figure(figsize=(8, 4))
plt.imshow(np.transpose(out), aspect='auto')
plt.xlabel('time')
plt.ylabel('location')
plt.show()

# Wrap the trajectory in the dataset class selected for the 'state'
# prediction task, with an offset of one step.
prediction_task = 'state'
lead_time = 1
DatasetClass = sel_dataset_class(prediction_task=prediction_task)

# Index window handed to the dataset: begins after the spin-up period and
# ends at the `train_frac` fraction of the trajectory length.
train_start = int(spin_up_time/dt)
train_end = int(np.floor(out.shape[0]*train_frac))
dg_train = DatasetClass(data=out, J=J, offset=lead_time, normalize=normalize_data,
                        start=train_start,
                        end=train_end)

### finite differences
- scipy.optimize.approx_fprime

In [None]:
from scipy.optimize import approx_fprime
from L96_emulator.util import predictor_corrector, rk4_default
import time


n_trials = 20   # number of parallel solves
n_rollout = 50  # rollout steps

# Indices into the long simulation used as initial conditions: multiples of
# five spin-up periods, one per trial.
spin_up_steps = int(spin_up_time/dt)
n_starts = 5*spin_up_steps * np.arange(1, n_trials+1)

from L96sim.L96_base import pf2 # parallelized numba simulator

# Buffer with one column per trial; `fun` picks it up from module scope.
batched_shape = (K*(J+1), n_trials)
dX_dt = np.empty(batched_shape, dtype=dtype_np)

# Right-hand side for the batched rollout.
# NOTE(review): only the J > 0 branch switches to the parallelized pf2; the
# J == 0 branch passes the 2-D buffer to f1 -- confirm f1 accepts batched input.
if J > 0:
    def fun(t, x):
        """Evaluate pf2 for the batched state x; t is not used."""
        return pf2(x, F, h, b, c, dX_dt, K, J)
else:
    def fun(t, x):
        """Evaluate f1 for state x; t is not used."""
        return f1(x, F, dX_dt, K)

# time points for numba solver, will only use np.diff(times)
times = dt * np.arange(n_rollout+1)

def pred(x):
    """Roll all trials forward over `times` and return the last solver output.

    `x` is copied and reshaped to one column per trial before it is handed
    to rk4_default, so the caller's array is not modified by the reshape.
    """
    initial_states = x.copy().reshape(-1, n_trials)
    trajectory = rk4_default(fun=fun, y0=initial_states, times=times)
    return trajectory[-1, :]

def loss_np(x, target):
    """Return the mean squared difference between pred(x) and target."""
    residual = pred(x) - target
    return (residual**2).mean()

# Reference states n_rollout steps after each start index, and the states at
# the start indices themselves; both transposed to one column per trial.
target = out[n_starts + n_rollout].T.copy()  # predict after rollout
x = out[n_starts].T.copy()                   # initialization estimate

# potentially corrupt initialization estimate (for non-zero gradients...)
# Noise scale per block of rows: the mean of the first K rows of x for those
# rows, and the mean of the remaining rows for the rest.
eps = np.ones(x.shape)
eps[:K, :] *= np.mean(x[:K, :])
eps[K:, :] *= np.mean(x[K:, :])
# Gaussian noise with standard deviation 0.1 * |block mean|, added in place
perturbation = np.random.normal(size=x.shape) * (0.1 * eps)
x += perturbation

def loss_sp(x):
    """Single-argument loss for scipy: loss_np against the module-level target."""
    return loss_np(x, target)

# finite-difference gradients
# approx_fprime works on a flat vector, so x is flattened for the call and the
# gradient is folded back to one column per trial afterwards.
t = time.time()
fd_gradient_flat = approx_fprime(xk=x.reshape(-1), f=loss_sp, epsilon=1e-5)
fprime_sp = fd_gradient_flat.reshape(-1, len(n_starts))
print(f'took {time.time()-t}s')

plt.plot(fprime_sp)
plt.show()

# emulator gradients
- loss.backward()

In [None]:
import torch 
import numpy as np
from L96_emulator.networks import named_network
from L96_emulator.eval import load_model_from_exp_conf, named_network
from L96_emulator.run import setup

# Experiment whose trained model is loaded; set to None to use the
# analytically initialised network instead.
exp_id = 20

if exp_id is None:
    # loading 'perfect' (up to machine-precision-level quirks) L96 model in pytorch
    # K_net / J_net / dt_net follow the simulation settings K, J, dt so the
    # network stays consistent with the J+1 channels requested here and with
    # the data it is rolled out on.
    model, model_forwarder = named_network(
        model_name='MinimalConvNetL96',
        n_input_channels=J+1,
        n_output_channels=J+1,
        seq_length=1,
        **{'filters': [0],
           'kernel_sizes': [4],
           'init_net': 'analytical',
           'K_net': K,
           'J_net': J,
           'dt_net': dt,
           'model_forwarder': 'rk4_default'}
    )

else:
    # loading trained L96 model in pytorch
    exp_names = os.listdir('experiments/')
    # Config files are matched on their first two characters, so this lookup
    # only works for two-character experiment ids.
    matching_names = [name for name in exp_names if name[:2] == str(exp_id)]
    if not matching_names:
        raise FileNotFoundError(
            f"no experiment configuration starting with '{exp_id}' found in 'experiments/'"
        )
    conf_exp = matching_names[0][:-4]  # drop the 4-character file extension
    print('conf_exp', conf_exp)
    args = setup(conf_exp=f'experiments/{conf_exp}.yml')
    args.pop('conf_exp')

    args['model_forwarder'] = 'rk4_default'  # update numerical integration method to RK4
    args['dt_net'] = dt                      # with current step size

    model, model_forwarder, training_outputs = load_model_from_exp_conf(res_dir, args)

    # Show the validation curve when training outputs were stored with the model.
    if training_outputs is not None:
        training_loss, validation_loss = training_outputs['training_loss'], training_outputs['validation_loss']

        fig = plt.figure(figsize=(8,8))
        seq_length = args['seq_length']
        plt.semilogy(validation_loss, label=conf_exp+ f' ({seq_length * (J+1)}-dim)')
        plt.title('training')
        plt.ylabel('validation error')
        plt.legend()
        fig.patch.set_facecolor('xkcd:white')
        plt.show()

In [None]:
import torch 
import numpy as np
from L96_emulator.eval import sortL96fromChannels, sortL96intoChannels, Rollout


def loss_torch(x, target):
    """Mean squared error between the rollout started from x and target.

    Assigns x to the module-level rollout operator's `X` attribute and then
    calls its forward(), so the operator is modified as a side effect.
    """
    roller_outer.X = x
    residual = roller_outer.forward() - target
    return (residual**2).mean()

# rollout operator
roller_outer = Rollout(model_forwarder, prediction_task='state', K=K, J=J,
                       N=n_trials, T=n_rollout)


def _as_channel_tensor(states):
    """Transpose `states`, sort it into channels and wrap it as a torch tensor."""
    channels = sortL96intoChannels(np.atleast_2d(states.T), J=J)
    return torch.as_tensor(channels, dtype=dtype, device=device)


# The initial-state estimate is wrapped in a Parameter so that backward()
# leaves a gradient on it; the target stays a plain tensor.
x_torch = torch.nn.Parameter(_as_channel_tensor(x))
target_torch = _as_channel_tensor(target)

def loss_(x):
    """Single-argument torch loss: loss_torch against the module-level target_torch."""
    return loss_torch(x, target_torch)

# torch loss
# One forward pass and one backward pass; the gradient is read from the
# rollout operator's X, moved to numpy and rearranged to one column per trial.
t = time.time()
loss = loss_(x_torch)
loss.backward()
grad_channels = roller_outer.X.grad.detach().cpu().numpy()
grad_torch = sortL96fromChannels(grad_channels).reshape(n_trials, -1).T
print(f'took {time.time()-t}s')

plt.plot(grad_torch)
plt.show()

## some plotting

In [None]:
# One panel per trial, comparing the autograd gradient with the
# finite-difference estimate.
# np.ceil returns a float, but plt.subplot needs an integer row count.
n_rows = int(np.ceil(n_trials/2))

plt.figure(figsize=(16,24))
for i in range(n_trials):
    plt.subplot(n_rows, 2, i+1)
    plt.plot(grad_torch[:,i], label='torch')
    plt.plot(fprime_sp[:,i], '--', label='scipy fd')
# re-select the first panel so the legend is attached there
plt.subplot(n_rows, 2, 1)
plt.legend()
plt.suptitle(f'rollout through {n_rollout} steps')
plt.show()
