# Parametrization learning with emulators

- first attempt, for now assuming a 'perfect' emulator (how to get that starting out from a suboptimal parametrization is an open research question)

- heavily leaning on the setup of Rasp (2020), where Rasp generated training data to refine a parametrization 'within the loop' of a simulator. To do that, Rasp nudges a high-res model (two-level L96) to the parametrized low-res model (one-level L96), using the high-res model outputs as training targets for the parametrization. The nudging is annoying and allegedly fiddly, and the usefulness of the training data depends on the validity of the nudging.
- here we use a non-nudged, free high-res simulation to generate training data for the parametrization, which we optimize directly with gradients through the differentiable emulator

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim
from L96_emulator.util import dtype, dtype_np, device, as_tensor

# Hard-coded absolute locations for results and data.
# NOTE(review): neither name is referenced again in the visible part of this
# file; the paths look cluster-specific, so adjust them before running elsewhere.
res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'

# setup from Rasp (2020)
https://arxiv.org/abs/1907.01351

In [None]:
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
from L96_emulator.networks import named_network

# K and J size the two-level L96 state (K*(J+1) values in total, see the
# simulation code further down); J_net sets the emulator's channel count.
K, J, J_net = 36, 10, 0
# Step size, shared by the emulator ('dt_net') and all model forwarders.
dt = 0.001

model_name = 'MinimalConvNetL96'
# Keyword arguments forwarded verbatim to named_network(); the 'l96_*' entries
# are also read by the NumPy ground-truth right-hand side defined later.
model_kwargs = dict(
    K_net=K,
    J_net=J_net,
    init_net='analytical',
    dt_net=dt,
    l96_F=10.,
    l96_h=1.,
    l96_b=10.,
    l96_c=10.,
    model_forwarder='rk4_default',
    padding_mode='circular',
)

# Input and output use the same number of channels.
n_channels = J_net + 1
model, model_forwarder = named_network(
    model_name,
    n_input_channels=n_channels,
    n_output_channels=n_channels,
    seq_length=1,
    **model_kwargs,
)

# define simple linear parametrization
- grab two linear parametrizations from Rasp's paper (cf. figure 3): 
    - the 'ideal' one ($a=-0.31$, $b=-0.2$) trained on the real two-level L96 model (as above) and 
    - the 'bad' one ($a=-3/4$, $b=-0.4$) trained on different two-level L96 parameters $F=7, h=2, b=c=5$ 

In [None]:
from L96_emulator.parametrization import Parametrization_lin, Parametrization_nn

def _linear_parametrization(a_value, b_value):
    """Build a Parametrization_lin whose a and b are one-element tensors."""
    return Parametrization_lin(a=as_tensor(np.array([a_value])),
                               b=as_tensor(np.array([b_value])))


# Coefficient pairs quoted from Rasp (2020), cf. figure 3 of that paper.
param_lin_good = _linear_parametrization(-0.31, -0.2)
param_lin_bad = _linear_parametrization(-0.75, -0.4)

# define ground-truth and parameterized models in Pytorch

In [None]:
from L96_emulator.parametrization import Parametrized_twoLevel_L96
from L96_emulator.networks import Model_forwarder_rk4default
from L96sim.L96_base import f1, f2, pf2
import torch

def _wrap_parametrization(parametrization):
    """Pair the shared emulator with `parametrization` and build its forwarder.

    Returns the Parametrized_twoLevel_L96 model together with a
    Model_forwarder_rk4default that steps it with the notebook-wide `dt`.
    """
    parametrized = Parametrized_twoLevel_L96(emulator=model,
                                             parametrization=parametrization)
    return parametrized, Model_forwarder_rk4default(model=parametrized, dt=dt)


# Both parametrized models wrap the very same emulator instance `model`.
model_parametrized_bad, model_forwarder_parametrized_bad = _wrap_parametrization(param_lin_bad)
model_parametrized_good, model_forwarder_parametrized_good = _wrap_parametrization(param_lin_good)


# Ground-truth L96 right-hand side, evaluated by the Numba kernels of L96sim.

# Array handed to the kernels on every call.
# NOTE(review): presumably their in-place output buffer -- confirm in L96sim.
dX_dt = np.empty(K*(J+1), dtype=dtype_np)


def _rhs_two_level(t, x):
    """Evaluate f2 at state `x` using the 'l96_*' entries of model_kwargs; `t` is unused."""
    return f2(x, model_kwargs['l96_F'], model_kwargs['l96_h'], model_kwargs['l96_b'], model_kwargs['l96_c'], dX_dt, K, J)


def _rhs_one_level(t, x):
    """Evaluate f1 at state `x` using model_kwargs['l96_F']; `t` is unused."""
    return f1(x, model_kwargs['l96_F'], dX_dt, K)


# With J > 0 the two-level kernel is used, otherwise the one-level kernel.
fun = _rhs_two_level if J > 0 else _rhs_one_level

class Torch_solver(torch.nn.Module):
    """Expose a NumPy right-hand-side function through the torch.nn.Module API.

    This lets a numerical solver implemented outside of PyTorch
    (numpy/numba/Julia) be driven by the same model-forwarder classes as the
    torch-based emulators.

    Parameters
    ----------
    fun : callable
        Called as ``fun(t, x)`` with ``t = 0.`` and a flat NumPy state vector.
    """

    def __init__(self, fun):
        # torch.nn.Module requires its own __init__ to run before the instance
        # is used: it creates the internal registries (_parameters, _modules,
        # hooks, ...). Without it, parameters(), __call__ and .to() raise
        # AttributeError.
        super().__init__()
        self.fun = fun

    def forward(self, x):
        """Evaluate ``fun`` at the state held in tensor `x`.

        The input is detached and moved to NumPy, so no gradient flows through
        this module. The result is whatever sortL96intoChannels returns for the
        (at least 2-D) output of ``fun``.
        """
        # Channel layout -> flat state vector expected by the NumPy function.
        x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
        # Time argument is fixed to 0.; J is read from the enclosing module.
        return sortL96intoChannels(np.atleast_2d(self.fun(0., x)), J=J)

# Step the NumPy/Numba right-hand side `fun` with the same forwarder class and
# the same dt as the emulator-based models above.
model_forwarder_np = Model_forwarder_rk4default(Torch_solver(fun), dt=dt)


# create some training data from the true two-level L96

In [None]:

# Random initial state with K*(J+1) entries in a single row:
# F * (0.5 + standard normal draw), cast to dtype_np, divided by max(J, 50).
forcing = model_kwargs['l96_F']
gaussian_draw = np.random.randn(1, K*(J+1))
shifted_draw = (0.5 + gaussian_draw * 1.0).astype(dtype=dtype_np)
X_init = forcing * shifted_draw / np.maximum(J, 50)

def model_simulate(y0, dy0, n_steps):
    """Integrate the ground-truth model for `n_steps` steps starting at `y0`.

    Parameters
    ----------
    y0 : np.ndarray
        Initial state with a leading axis of length one; the trailing axes
        define the shape of each stored state.
    dy0 : unused
        Accepted for interface compatibility only.
    n_steps : int
        Number of applications of `model_forwarder_np`.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_steps + 1, *y0.shape[1:])`` and dtype `dtype_np`
        holding the initial state followed by every simulated state.
    """
    trajectory = np.empty((n_steps + 1,) + tuple(y0.shape[1:]), dtype=dtype_np)
    trajectory[0] = y0.copy()
    state = as_tensor(trajectory[0]).reshape(1, 1, -1)
    for step in range(n_steps):
        # The forwarder expects (batch=1, J+1 channels, locations).
        state = model_forwarder_np(state.reshape(1, J + 1, -1))
        trajectory[step + 1] = state.detach().cpu().numpy().copy()
    return trajectory

# Number of integration steps of the training trajectory.
T_dur = 10000
initial_channels = sortL96intoChannels(X_init, J=J)
data_full = model_simulate(y0=initial_channels, dy0=None, n_steps=T_dur)

# The two-level simulation holds both slow and fast variables; only the slow
# ones (channel 0) are kept for training.
data = data_full[:, 0, :]

# create another parametrized model to optimize
- initialized with the 'bad' parametrization from above
- we'll optimize the parametrization directly through the (for now analytically perfect...) emulator

In [None]:
# Trainable copy of the linear parametrization, starting from a=-0.75, b=-0.4.
param_lin_train = Parametrization_lin(
    a=as_tensor(np.array([-0.75])),
    b=as_tensor(np.array([-0.4])),
)
model_parametrized_train = Parametrized_twoLevel_L96(
    emulator=model,
    parametrization=param_lin_train,
)

# Freeze the emulator so that only the parametrization can receive gradients.
# `model` is the emulator instance also used by the other parametrized models.
for emulator_weight in model_parametrized_train.emulator.parameters():
    emulator_weight.requires_grad_(False)

model_forwarder_parametrized_train = Model_forwarder_rk4default(
    model=model_parametrized_train, dt=dt)


def _print_requires_grad(header, module):
    """Print `header`, then the requires_grad flag of every parameter of `module`."""
    print(header)
    for weight in module.parameters():
        print(weight.requires_grad)


wrapped_model = model_forwarder_parametrized_train.model
_print_requires_grad('torch.nn.Parameters of parametrization require grad: ',
                     wrapped_model.param)
_print_requires_grad('torch.nn.Parameters of emulator require grad: ',
                     wrapped_model.emulator)

print('initialized a', model_parametrized_train.param.a)
print('initialized b', model_parametrized_train.param.b)

In [None]:
from L96_emulator.run import sel_dataset_class, loss_function
from L96_emulator.train import train_model

# --- training configuration -------------------------------------------------
lead_time = 1
normalize_data = False
prediction_task = 'state'
N_trials = 1

batch_size = 32
train_frac = 0.8
validation_frac = 0.1
spin_up_time = 0.

DatasetClass = sel_dataset_class(prediction_task, N_trials, local=False)

# Whatever is left after the training and validation splits must be non-empty.
test_frac = 1. - (train_frac + validation_frac)
assert test_frac > 0.
spin_up = int(spin_up_time/dt)

# Index boundaries of the splits along the time axis of `data`.
train_end = int(np.floor(T_dur*train_frac))
val_start = int(np.ceil(T_dur*train_frac))
val_end = int(np.ceil(T_dur*(train_frac+validation_frac)))

# Settings shared by the training and the validation split.
dataset_kwargs = dict(data=data, J=J_net, offset=lead_time,
                      normalize=bool(normalize_data))
loader_kwargs = dict(batch_size=batch_size, drop_last=True, num_workers=0)

dg_train = DatasetClass(start=spin_up, end=train_end, **dataset_kwargs)
train_loader = torch.utils.data.DataLoader(dg_train, **loader_kwargs)

dg_val = DatasetClass(start=val_start, end=val_end, **dataset_kwargs)
validation_loader = torch.utils.data.DataLoader(dg_val, **loader_kwargs)

# --- optimization -----------------------------------------------------------
loss_fun = loss_function(loss_fun='mse', extra_args={})
# The forwarder is passed both as the model to train and as the forward function.
training_outputs = train_model(
    model=model_forwarder_parametrized_train,
    train_loader=train_loader,
    validation_loader=validation_loader,
    device=device,
    model_forward=model_forwarder_parametrized_train,
    loss_fun=loss_fun,
    max_epochs=10,
)

# check learned parametrization parameter a, b

- initialized at 'bad' values $a=-0.75$, $b=-0.4$
- should approach the 'ideal' values $a = -0.31$, $b=-0.2$

In [None]:
# Report the coefficients of the parametrization that was passed to training.
learned_param = model_parametrized_train.param
print('learned a', learned_param.a)
print('learned b', learned_param.b)

# check quality of simulation
- starting at the final state of the two-level L96 training data, we visually inspect how the solutions of the slow variables look for different (parameterized) models

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout

def _simulate_rollout(model_forwarder, y0, n_steps, n_fast):
    """Roll `model_forwarder` forward for `n_steps` steps starting at `y0`.

    Parameters
    ----------
    model_forwarder : callable
        Maps a tensor of shape ``(1, n_fast + 1, -1)`` to the next state.
    y0 : np.ndarray
        Initial state with a leading axis of length one.
    n_steps : int
        Number of forwarder applications.
    n_fast : int
        Number of fast channels of the model; the state is reshaped to
        ``n_fast + 1`` channels before every step.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_steps + 1, *y0.shape[1:])`` holding the initial
        state followed by every simulated state.
    """
    x = np.empty((n_steps+1, *y0.shape[1:]))
    x[0] = y0
    xx = as_tensor(x[0]).reshape(1, 1, -1)
    # Pure inference: the parametrization parameters require grad, so without
    # no_grad() every step would extend one autograd graph that stays alive
    # for the entire rollout.
    with torch.no_grad():
        for i in range(1, n_steps+1):
            xx = model_forwarder(xx.reshape(1, n_fast+1, -1))
            x[i] = xx.detach().cpu().numpy()
    return x


# Start every rollout from the final state of the two-level training run.
X_init = data_full[-1].reshape(1,-1)

# Rollout length in steps. NOTE: this reassigns the notebook-wide T_dur that
# the training cell also reads.
T_dur = 5000

model_forwarders = [Model_forwarder_rk4default(model, dt=dt),
                    model_forwarder_parametrized_bad,
                    model_forwarder_parametrized_train,
                    model_forwarder_parametrized_good,
                    model_forwarder_np]
# The first four models take only the first K entries of the state;
# the full two-level model takes the complete state.
X_inits = [X_init[:,:K].copy(), X_init[:,:K].copy(), X_init[:,:K].copy(), X_init[:,:K].copy(), X_init.copy()]
Js = [J_net, J_net, J_net, J_net, J]
panel_titles=['one-level L96', 'bad linear parametrization', 'learned parametrization', 'good linear parametrization', 'full two-level L96']

plt.figure(figsize=(16,6))
n_panels = len(model_forwarders)
for i_model, (model_forwarder_i, X_init_i, J_i, title_i) in enumerate(
        zip(model_forwarders, X_inits, Js, panel_titles)):

    out_model = _simulate_rollout(model_forwarder_i,
                                  sortL96intoChannels(X_init_i, J=J_i),
                                  T_dur, J_i)

    plt.subplot(1, n_panels, i_model+1)
    # Only channel 0 is shown so that all panels are comparable.
    plt.imshow(sortL96fromChannels(out_model[:,:1,:]).T, aspect='auto')
    plt.colorbar()
    plt.title(title_i)

    if i_model == 0:
        plt.ylabel('location k')
    if i_model == 2:
        plt.xlabel('time [steps]')
plt.show()

In [None]:
!jupyter nbconvert --output-dir='/gpfs/home/nonnenma/projects/lab_coord/mdml_wiki/marcel/emulators' --to html parametrization.ipynb