# Emulators

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directory for results; presumably cluster-specific (absolute GPFS path) -- adjust per machine.
res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'
# Root directory of the simulated trajectories that the data cell below loads from / saves to.
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'

### load / simulate data

In [None]:
from L96sim.L96_base import f1, f2, J1, J1_init, f1_juliadef, f2_juliadef
from L96_emulator.util import predictor_corrector
from L96_emulator.run import sel_dataset_class

# Parameters handed to the tendency functions f1 / f2
# (presumably forcing F and coupling constants h, b, c -- confirm against L96sim).
F = 10
h = 1
b = 10
c = 10

# The simulated state vector has K*(J+1) entries; T and dt define the time grid.
K = 36
J = 10
T = 605
dt = 0.001

# Time discarded at the start of the data, and fraction of samples used for training.
spin_up_time = 5.
train_frac = 0.8

# File stem of the stored simulation; the leading '0.' of dt is stripped.
fn_data = f'out_K{K}_J{J}_T{T}_dt0_{str(dt)[2:]}'

# Select the right-hand side once, depending on whether J is positive.
# Both variants read the global scratch buffer dX_dt, which only exists
# after the simulation branch has allocated it.
if J <= 0:
    def fun(t, x):
        """Tendency via f1 (used when J is not positive); t is unused."""
        return f1(x, F, dX_dt, K)
else:
    def fun(t, x):
        """Tendency via f2 (used when J > 0); t is unused."""
        return f2(x, F, h, b, c, dX_dt, K, J)

# resimulate: integrate the system from a random initial state instead of loading it.
# save_sim:   additionally store the freshly simulated trajectory under data_dir.
resimulate, save_sim = False, False
if resimulate:
    print('simulating data')
    # Random initial state with K*(J+1) entries, cast to the project's numpy dtype.
    X_init = F * (0.5 + np.random.randn(K*(J+1)) * 1.0).astype(dtype=dtype_np) / np.maximum(J,10)
    # Scratch buffer that fun() picks up through the global name dX_dt.
    dX_dt = np.empty(X_init.size, dtype=X_init.dtype)
    # np.floor returns a float, but np.linspace requires an integer sample
    # count (a float raises TypeError on current NumPy) -> cast explicitly.
    times = np.linspace(0, T, int(np.floor(T/dt))+1)

    out = predictor_corrector(fun=fun, y0=X_init.copy(), times=times, alpha=0.5)

    if save_sim:
        # np.save appends '.npy', which matches the path used for loading below.
        np.save(data_dir + fn_data, out.astype(dtype=dtype_np))
else:
    print('loading data')
    out = np.load(data_dir + fn_data + '.npy')

# Image of the whole trajectory: first axis of `out` on the horizontal (time) axis,
# second axis of `out` on the vertical (location) axis.
plt.figure(figsize=(8, 4))
plt.imshow(np.transpose(out), aspect='auto')
plt.ylabel('location')
plt.xlabel('time')
plt.show()


# Prediction target and number of steps between dataset input and target.
prediction_task, lead_time = 'state', 1

# Dataset class matching the prediction task.
DatasetClass = sel_dataset_class(prediction_task=prediction_task)

# Training split: starts after the spin-up period and ends after the first
# train_frac share of the samples; normalisation is switched off.
dg_train = DatasetClass(
    data=out,
    J=J,
    offset=lead_time,
    normalize=False,
    start=int(spin_up_time / dt),
    end=int(np.floor(train_frac * out.shape[0])),
)

### load model

In [None]:
from L96_emulator.networks import MinimalNetL96
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch 

# Network built directly from the system parameters; no weights are loaded from disk here.
model = MinimalNetL96(K,J,F,b,c,h,skip_conn=True,loc=1e3)
# Normalisation statistics of the training dataset as tensors on the project's device/dtype;
# model_forward uses them to de-normalise inputs and re-normalise outputs.
std_out = torch.as_tensor(dg_train.std, device=device, dtype=dtype)
mean_out = torch.as_tensor(dg_train.mean, device=device, dtype=dtype)

def model_forward(x):
    """Advance a state by one time step dt using two evaluations of `model`.

    The model is evaluated at the current state and at the Euler-predicted
    state ``state + dt*f``; the state is then advanced by dt times the
    alpha-weighted average of both evaluations (alpha = 0.5).

    A 3-dimensional input is de-normalised with mean_out / std_out and
    converted with sortL96fromChannels before stepping; any other input is
    used as given. The stepped state is converted with sortL96intoChannels
    and normalised; for a 2-dimensional input it is converted back with
    sortL96fromChannels before being returned.

    NOTE(review): a 2-dimensional input is not de-normalised on the way in
    but is normalised on the way out -- harmless only if mean_out / std_out
    are 0 / 1; confirm.
    """
    alpha = 0.5
    n_dims_in = x.ndim

    if n_dims_in == 3:
        state = sortL96fromChannels(x * std_out + mean_out)
    else:
        state = x

    # Model output at the current state and at the Euler-predicted state.
    slope_start = model.forward(state)
    slope_end = model.forward(state + dt*slope_start)

    state = state + dt * (alpha*slope_start + (1-alpha)*slope_end)
    normalized = (sortL96intoChannels(state, J=J) - mean_out) / std_out

    if n_dims_in == 2:
        return sortL96fromChannels(normalized)
    return normalized

### example rollout

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout
from L96_emulator.eval import solve_from_init

# Rollout function that iterates model_forward starting from a dataset sample.
model_simulate = get_rollout_fun(dg_train, model_forward, prediction_task)

# Start index and number of steps of the example rollout.
T_start, T_dur = 100*int(spin_up_time/dt), 10000
out_model = model_simulate(y0=dg_train[T_start].copy(), 
                           dy0=dg_train[T_start]-dg_train[T_start-dg_train.offset],
                           T=T_dur)
# Undo the dataset normalisation and convert back from the channel layout.
out_model = sortL96fromChannels(out_model * dg_train.std + dg_train.mean)

solver_comparison = True 
# Bound up front so the plot call below never hits an undefined (or stale)
# name when the comparison is switched off.
# NOTE(review): assumes plot_rollout accepts out_comparison=None -- confirm.
out2 = None
if solver_comparison:
    # Fall back to default system parameters only if they are not defined yet;
    # a bare `except:` would also swallow KeyboardInterrupt and real errors.
    try: 
        print(F, h, b, c)
    except NameError: 
        F, h, b, c = 10, 1, 10, 10

    out2 = solve_from_init(K, J, 
                           T_burnin=T_start, T_=T_dur, dt=dt, 
                           F=F, h=h, b=b, c=c, 
                           data=out, dilation=2, norm_mean=0., norm_std=1.)

fig = plot_rollout(out, out_model, out_comparison=out2, T_start=T_start, T_dur=T_dur, K=None)


# Solving a fully-observed inverse problem

In [None]:
"""
from L96_emulator.eval import Rollout

T_start = np.array([5000, 10000, 150000])
T, N = 10, len(T_start)

roller_outer = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N)
x_init = roller_outer.X.detach().numpy().copy()

target = torch.as_tensor(out[T_start+T], dtype=dtype, device=device)

n_steps, lr, weight_decay = 1000, 5e-2, 0.
roller_outer.train()
optimizer = torch.optim.Adam(roller_outer.parameters(), lr=lr, weight_decay=weight_decay)
loss_vals = np.zeros(n_steps)
for i in range(n_steps):
        optimizer.zero_grad()
        loss = ((roller_outer.forward(T=T) - target)**2).mean()
        loss.backward()
        optimizer.step()
        loss_vals[i] = loss.detach().numpy()

plt.figure(figsize=(8,2))
plt.semilogy(loss_vals, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE)')
plt.xlabel('gradient step')
plt.show()
"""

In [None]:

from L96_emulator.eval import Rollout

# Indices into `out` of the initial states that are to be recovered.
T_start = np.array([5000, 10000, 150000])
# Number of emulator steps per rollout. It is bound to T_rollout because the
# plotting cell reads that name, which no other active cell defines.
T_rollout = 100
# NOTE: T is kept as an alias because the original cell assigned it; be aware
# that this overwrites the simulation length T set in the data cell.
T, N = T_rollout, len(T_start)

# roller_outer.X holds the initial states that are optimised; keep a copy of
# their values before optimisation for later comparison.
roller_outer = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N)
x_init = roller_outer.X.detach().cpu().numpy().copy()

# True states T_rollout steps after each start index.
target = torch.as_tensor(out[T_start+T_rollout], dtype=dtype, device=device)

n_steps, lr, weight_decay = 100, 1e-2, 0.
roller_outer.train()

# Alternative first-order optimiser (the only user of weight_decay):
#optimizer = torch.optim.Adam(roller_outer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer = torch.optim.LBFGS(params=roller_outer.parameters(), 
                              lr=lr, 
                              max_iter=20, 
                              max_eval=None, 
                              tolerance_grad=1e-07, 
                              tolerance_change=1e-09, 
                              history_size=100, 
                              line_search_fn=None)

loss_vals = np.zeros(n_steps)


def closure():
    """Re-evaluate the rollout MSE and its gradients; LBFGS calls this repeatedly per step."""
    optimizer.zero_grad()
    loss = ((roller_outer.forward(T=T_rollout) - target)**2).mean()
    loss.backward()
    return loss


for i in range(n_steps):
    optimizer.step(closure)
    # Loss after the step, for logging only: no autograd graph is needed.
    with torch.no_grad():
        loss = ((roller_outer.forward(T=T_rollout) - target)**2).mean()
    loss_vals[i] = loss.item()
        
# Loss recorded after each optimiser step, on a logarithmic axis.
plt.figure(figsize=(8,2))
plt.semilogy(loss_vals, label='initialization')
plt.title('rollout final state loss across gradient descent steps')
# Axis label previously carried a stray closing parenthesis ('MSE)').
plt.ylabel('MSE')
plt.xlabel('gradient step')
plt.show()

# Echo the raw loss values as the cell output.
loss_vals

In [None]:
"""
import torch 
import numpy as np
from L96_emulator.eval import Rollout

N, T_rollout = 100, 10 # number of rollout starting points, number of rollout steps
exp_id = 'minimalnet_fullyconn_skipconn_J10' # trained network for which initstate_train.py was run
lead_time = 1 # number of steps predicted

save_dir = res_dir + 'models/' + exp_id + '/'
model_fn = f'{exp_id}_dt{lead_time}.pt'
results_fn = f'_rollout_outputs_K{K}_J{J}_T{T}_N{N}_TR{T_rollout}'
output_fn = f'_rollout_training_outputs_K{K}_J{J}_T{T}_N{N}_TR{T_rollout}'

training_outputs = np.load(save_dir + output_fn + '.npy', allow_pickle=True)[()]
roller_outer = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N)
roller_outer.load_state_dict(torch.load(save_dir + results_fn, map_location=torch.device(device)))

x_init, T_start = training_outputs['x_init'],  training_outputs['T_start']

plt.figure(figsize=(16,4))
plt.semilogy(training_outputs['training_loss'], label='initialization')
plt.title('rollout final state loss across gradient descent steps')
plt.ylabel('MSE)')
plt.xlabel('gradient step')
plt.show()
"""

## plotting

In [None]:
N_max = 3 # choose N_max << N if N is very large and you don't want hundreds of subplots

# Optimised initial states, converted once instead of once per subplot.
X_est = roller_outer.X.detach().cpu().numpy()

plt.figure(figsize=(16,2*N_max))
for n in range(N_max):
    # np.ceil returns a float; plt.subplot needs integer grid dimensions.
    plt.subplot(int(np.ceil(N_max/2)),2,n+1)
    plt.plot(x_init[n], 'k', label='init', alpha=0.2)
    # roller_outer.X is the optimised quantity, i.e. the estimate; the
    # simulated data at T_start is the true initial state. The original
    # legend labels had these two swapped.
    plt.plot(X_est[n,:], color='orange', linewidth=1.5, label='est.')
    plt.plot(out[T_start[n]].T, 'b--', linewidth=0.5, label='true')
    plt.xlabel('state dimension')
    plt.ylabel(f'initial state at T = {T_start[n]}')
    plt.legend()
plt.suptitle('estimated initial state')
plt.show()

# Error of the estimated initial state, with the change of the true state over
# the rollout window as a reference scale.
plt.figure(figsize=(16,2*N_max))
for n in range(N_max):
    # np.ceil returns a float; plt.subplot needs integer grid dimensions.
    plt.subplot(int(np.ceil(N_max/2)),2,n+1)
    # Legend labels now state the sign of the plotted difference
    # (true minus estimate, initial minus future); the plotted data is unchanged.
    plt.plot(out[T_start[n]].flatten() - roller_outer.X.detach().cpu().numpy()[n,:], 'orange', 
             label='true - est. initial state')
    plt.plot(out[T_start[n]].flatten() - out[T_start[n]+T_rollout].flatten(), 'k', alpha=0.3, 
             label='initial - future state')
    plt.legend()
    plt.xlabel('state dimension')
    plt.ylabel(f'initial state error at T = {T_start[n]}')
plt.suptitle('error of estimated initial state')
plt.show()

# Error of the rolled-out final state, starting from the estimated initial
# states (roller_outer) and from the true initial states (roller_outer2).
plt.figure(figsize=(16,2*N_max))
roller_outer2 = Rollout(model_forward, prediction_task='state', K=K, J=J, N=N, x_init=out[T_start])

# The rollouts do not depend on the subplot index, so run each once instead of
# once per subplot (assumes Rollout.forward is deterministic -- confirm).
final_from_est = roller_outer.forward(T=T_rollout).detach().cpu().numpy()
final_from_true = roller_outer2.forward(T=T_rollout).detach().cpu().numpy()

for n in range(N_max):
    # np.ceil returns a float; plt.subplot needs integer grid dimensions.
    plt.subplot(int(np.ceil(N_max/2)),2,n+1)
    true_final = out[T_start[n]+T_rollout].flatten()
    plt.plot(true_final - final_from_est[n], 
             'orange', label='true final state - rollout from est. init. state')
    plt.plot(true_final - final_from_true[n], 
             'k', alpha=0.3, label='true final state - rollout from true init. state')
    plt.legend()
    plt.xlabel('state dimension')
    plt.ylabel(f'final state error at T = {T_start[n]}+{T_rollout}')
plt.suptitle('error of estimated final state (under the learned model)')
plt.show()