In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device, as_tensor

# Output / data roots. Override with the L96_RES_DIR / L96_DATA_DIR environment variables to run
# on another machine; os.path.join(..., '') guarantees a trailing separator so that string
# concatenation such as res_dir + 'figs/...' yields a valid path.
res_dir = os.path.join(os.environ.get('L96_RES_DIR', '/gpfs/work/nonnenma/results/emulators/L96/'), '')
data_dir = os.path.join(os.environ.get('L96_DATA_DIR', '/gpfs/work/nonnenma/data/emulators/L96/'), '')

### pick a (trained) emulator

In [None]:
import torch 
import numpy as np
from L96_emulator.eval import load_model_from_exp_conf
from L96_emulator.networks import named_network
from L96_emulator.run import setup

exp_id = 77  # 77: deepNet trained on 9,600 datapoints, best of 3 random network initializations

# Configs live in experiments/<name>.yml; the experiment ID is matched against the first two
# characters of the file name. Sorting makes the pick deterministic (os.listdir order is arbitrary).
exp_names = sorted(os.listdir('experiments/'))
exp_matches = [name for name in exp_names if name[:2] == str(exp_id)]
if not exp_matches:
    raise FileNotFoundError(f'no config for experiment {exp_id} found in experiments/')
conf_exp = os.path.splitext(exp_matches[0])[0]  # drop the file extension

args = setup(conf_exp=f'experiments/{conf_exp}.yml')
args.pop('conf_exp')

K, J = args['K'], args['J']
# Evaluate with the RK4 forwarder regardless of the forwarder stored in the config.
args['model_forwarder'] = 'rk4_default'
if args['padding_mode'] == 'valid':
    # A network trained on local patches ('valid' padding) is evaluated on the full periodic state.
    print('switching from local training to global evaluation')
    args['padding_mode'] = 'circular'
model, model_forwarder, training_outputs = load_model_from_exp_conf(res_dir, args)

if training_outputs is not None:
    training_loss, validation_loss = training_outputs['training_loss'], training_outputs['validation_loss']

    seq_length = args['seq_length']
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.semilogy(validation_loss, label=conf_exp + f' ({seq_length * (J+1)}-dim)')
    ax.set_title('training')
    ax.set_ylabel('validation error')
    ax.legend()
    fig.patch.set_facecolor('xkcd:white')
    plt.show()

### pick the reference simulator

In [None]:
exp_id_gt = 35  # 35: bilinear network with analytical weights

# Same lookup rule as for the emulator: match the ID against the first two characters of the
# config file name; sort so the pick does not depend on directory listing order.
gt_matches = [name for name in sorted(exp_names) if name[:2] == str(exp_id_gt)]
if not gt_matches:
    raise FileNotFoundError(f'no config for experiment {exp_id_gt} found in experiments/')
conf_exp_gt = os.path.splitext(gt_matches[0])[0]  # drop the file extension

args_gt = setup(conf_exp=f'experiments/{conf_exp_gt}.yml')
args_gt.pop('conf_exp')
# The reference must step with the same integrator and time step as the emulator config.
assert args_gt['model_forwarder'] == 'rk4_default', \
    f"reference forwarder is {args_gt['model_forwarder']!r}, expected 'rk4_default'"
assert args_gt['dt'] == args['dt'], \
    f"reference dt={args_gt['dt']} differs from emulator dt={args['dt']}"
model_gt, model_forwarder_gt, _ = load_model_from_exp_conf(res_dir, args_gt)

# reference simulation (also to get initial point for rollouts)

In [None]:
from L96_emulator.run import sel_dataset_class
from L96_emulator.util import predictor_corrector, rk4_default, get_data

# Global legacy NumPy seed; presumably consumed by get_data when drawing initial conditions.
np.random.seed(42)

dt = args['dt']
T, spinup = 5, 10  # kept / discarded durations; converted to step counts via /dt below

N_trials = 1
F, h, b, c = 8., 1., 10., 10.  # Lorenz-96 parameters passed straight through to the simulator

sim_kwargs = dict(K=K, J=J, T=T+spinup, dt=dt, N_trials=N_trials, F=F, h=h, b=b, c=c)
out, _ = get_data(**sim_kwargs,
                  resimulate=True, solver=rk4_default,
                  save_sim=False, data_dir=data_dir)
if out.ndim == 2:
    # A single trial comes back without a trial axis; add one so later code can index out[0, ...].
    out = out.reshape(1, *out.shape)

# Keep only the post-spin-up window as the (unnormalized) state-prediction dataset.
first_step = int(spinup/dt)
last_step = int(np.floor((T+spinup)/dt))
DatasetClass = sel_dataset_class(prediction_task='state', N_trials=N_trials, local=False)
dg_train = DatasetClass(data=out, J=J, offset=1, normalize=False,
                        start=first_step,
                        end=last_step)


## multi-step integration (rollout) for different solvers: simulator and emulator

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout
from L96_emulator.networks import Model_forwarder_rk4default
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels


# Rollout horizon in steps, and the index of the first post-spin-up state.
T_dur = int(T/dt)
t = int(spinup/dt)

# Step the emulator with RK4 and roll it out for T_dur steps, starting from the reference state at t.
model_forwarder = Model_forwarder_rk4default(model=model, dt=dt)
model_simulate = get_rollout_fun(dg_train, model_forwarder, prediction_task='state')
y_init = dg_train[t].copy()
rollout = model_simulate(y0=y_init, dy0=None, n_steps=T_dur)

# Map back through the dataset's std/mean scaling, then undo the channel sorting.
out_model = sortL96fromChannels(rollout * dg_train.std + dg_train.mean)


### input-output Jacobians of next-state predictions

In [None]:
from L96_emulator.networks import Model_forwarder_predictorCorrector, Model_forwarder_rk4default
from L96_emulator.util import calc_jakobian_onelevelL96_tendencies, calc_jakobian_rk4, get_jacobian_torch
from L96sim.L96_base import f1
import torch 

# Preallocated output buffer handed to L96sim's f1 on every call.
dX_dt = np.empty(K*(J+1), dtype=dtype_np)

def fun(t, x):
    """Tendency of the one-level L96 system at state ``x`` (``t`` is unused)."""
    return f1(x, F, dX_dt, K)

def model_np(inputs):
    """NumPy reference tendency; copied so callers never alias the shared dX_dt buffer."""
    tendency = fun(0., inputs)
    return tendency.copy()

# Freeze the emulator weights: only derivatives w.r.t. the input state are wanted.
for param in model.parameters():
    param.requires_grad_(False)

model_forwarder = Model_forwarder_rk4default(model=model, dt=dt)

# Linearization point: the reference state at index t of the first trial.
inputs = out[0, t]
inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()), J=J))
inputs_torch.requires_grad = True

# One-step RK4 Jacobians: emulator via torch autograd, reference from the L96 tendency Jacobian.
J_model = get_jacobian_torch(model_forwarder, inputs=inputs_torch, n=K)
J_np = calc_jakobian_rk4(inputs,
                         calc_f=model_np,
                         calc_J_f=calc_jakobian_onelevelL96_tendencies,
                         dt=dt,
                         n=K)


# some partial derivatives evaluated along an example direction in state space

In [None]:
"""
for p in model.parameters():
    p.requires_grad = False

def get_partials_torch(model, inputs, i,j):
    inputs.grad = None
    L = model(inputs).flatten()[i] # f_i(x)
    L.backward()
    dfdx = inputs.grad.detach().cpu().numpy().flatten()[j] #df_i/dx_j ? 
    return dfdx


inputs_base = out[0,t]
inputs_end = out[0, -1]

nbh = 2
offsets  = np.concatenate(((np.arange(-nbh,0)), np.arange(nbh)+1))

locations = 0 * np.ones(len(offsets), dtype=np.int)
clrs = ['b', 'r', 'g', 'k', 'c', 'orange']

xx = np.linspace(-0.5, 1.5, 51)
dfdx = np.zeros((len(offsets), len(xx)))
dfdx_gt = np.zeros((len(offsets), len(xx)))

for i in range(len(offsets)):

    j = locations[i] + offsets[i]
                     

    for n in range(len(xx)):
        inputs = (1-xx[n]) * inputs_base  + xx[n] * inputs_end
        inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=J))
        inputs_torch.requires_grad = True
                     
        dfdx[i, n] = get_partials_torch(model_forwarder, inputs_torch, 
                                     i=locations[i], 
                                     j=locations[i] + offsets[i])
        dfdx_gt[i, n] = get_partials_torch(model_forwarder_gt, inputs_torch,
                                        i=locations[i], 
                                        j=locations[i] + offsets[i])
                     
    plt.plot(xx, dfdx[i], color=clrs[i])
    plt.plot(xx, dfdx_gt[i], '--', color=clrs[i])
"""

# sensitivity analysis

In [None]:
# Freeze both networks: only gradients w.r.t. the input state are needed. Freezing model_gt
# (presumably the network inside model_forwarder_gt) stops its weights accumulating .grad.
for p in model.parameters():
    p.requires_grad = False
for p in model_gt.parameters():
    p.requires_grad = False


def get_partials_torch(model, inputs, i, T):
    """Gradient of component ``i`` of the state reached after ``T`` model steps, w.r.t. ``inputs``.

    Parameters
    ----------
    model : callable
        Differentiable one-step forwarder (state -> next state).
    inputs : torch.Tensor
        Initial state with ``requires_grad=True``; its ``.grad`` is reset before use.
    i : int
        Index into the flattened final state.
    T : int
        Number of steps to roll out; ``T=0`` returns the unit vector e_i.

    Returns
    -------
    numpy.ndarray
        Flattened ``d state_T[i] / d inputs``.
    """
    inputs.grad = None
    state = inputs
    for _ in range(T):
        state = model(state)
    state.flatten()[i].backward()
    return inputs.grad.detach().cpu().numpy().flatten()


def _to_input_tensor(state):
    """Channel-sorted torch copy of a flat L96 state that tracks gradients."""
    x = as_tensor(sortL96intoChannels(np.atleast_2d(state.copy()), J=J))
    x.requires_grad = True
    return x


T_start = 20  # step (relative to t) of the target value x_k
dT = T_start  # look-back window in steps
k = 19        # target state component

dfdx = np.zeros((dT+1, K))
dfdx_gt = np.zeros((dT+1, K))

# Column s holds d x_k(T_start) / d x(T_start - dT + s): the input state at step T_start - dT + s
# is rolled forward dT - s steps, so every column differentiates the SAME target x_k(T_start).
# (Previously the input index ran backwards, T_start - s, so each column had a different target,
# and an outer loop over an undefined `offsets` repeated the whole computation.)
for s in range(dT+1):
    idx = T_start - dT + s
    n_steps = dT - s
    dfdx[s] = get_partials_torch(model_forwarder, _to_input_tensor(out_model[idx]),
                                 i=k, T=n_steps)
    dfdx_gt[s] = get_partials_torch(model_forwarder_gt, _to_input_tensor(out[0, t+idx]),
                                    i=k, T=n_steps)

# Quick look: reference, emulator, and emulator - reference (symmetric limits so 0 is white).
diff = dfdx - dfdx_gt
diff_lim = np.max(np.abs(diff))
fig, axs = plt.subplots(1, 3, figsize=(16, 5))
panels = [(dfdx_gt, 'model', {}),
          (dfdx, 'emulator', {}),
          (diff, 'difference', dict(cmap='bwr', vmin=-diff_lim, vmax=diff_lim))]
for ax, (data, title, style) in zip(axs, panels):
    im = ax.imshow(data.T, aspect='auto', **style)
    ax.set_title(title)
    ax.set_xlabel('input time step')
    ax.set_ylabel('state component')
    fig.colorbar(im, ax=ax)
plt.show()


In [None]:
# Reference-model sensitivity profiles for the shortest look-back horizons: row -1 of dfdx_gt is
# the 0-step derivative (unit vector at k), row -2 spans one model step, row -5 four steps.
fig, ax = plt.subplots()
for row in (-1, -2, -5):
    ax.plot(dfdx_gt[row], label=f'{-row - 1} steps')
ax.set_xlabel('state component')
ax.set_ylabel('partial derivative')
ax.legend()
plt.show()


# compose figure

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

fontsize = 18


def _panel_clims(ref, emu):
    """Colour limits for a (reference, emulator, difference) panel triple.

    The first two panels share the joint [min, max] of ``ref`` and ``emu``; the difference
    panel gets limits symmetric about zero so the diverging colormap is centred on 0.
    """
    both = np.stack((ref, emu))
    vmin, vmax = np.min(both), np.max(both)
    dmax = np.max(np.abs(ref - emu))
    return [[vmin, vmax], [vmin, vmax], [-dmax, dmax]]


def _add_inset_colorbar(ax):
    """Draw a slim colorbar just right of ``ax`` for the current image (``plt.gci()``)."""
    axins = inset_axes(ax, width="5%", height="100%", loc='lower left',
                       bbox_to_anchor=(1.05, 0., 1, 1), bbox_transform=ax.transAxes, borderpad=0)
    plt.colorbar(cax=axins)


def _shift_right(ax, dx=0.025):
    """Move ``ax`` right by ``dx`` figure fractions (presumably to clear the neighbour's colorbar)."""
    box = ax.get_position()
    box.x0 += dx
    box.x1 += dx
    ax.set_position(box)


cmaps = ['viridis', 'viridis', 'bwr']
labels = ['model', 'emulator', 'difference']

fig = plt.figure(figsize=(16, 8))

#################################################
# a) rollouts: left column, one panel per row   #
#################################################
sol_np = sortL96fromChannels(dg_train.data[t:t+T_dur+1]).T
sol_model = out_model.T
sols = [sol_np, sol_model, sol_model - sol_np]
clims = _panel_clims(sol_np, sol_model)

for i, (sol, cmap, clim, label) in enumerate(zip(sols, cmaps, clims, labels)):
    plt.subplot(len(sols), 2, 2*i + 1)
    plt.imshow(sol, aspect='auto', cmap=cmap, vmin=clim[0], vmax=clim[1])
    plt.colorbar()
    plt.yticks([], fontsize=fontsize)
    if i < len(sols) - 1:
        plt.xticks([], fontsize=fontsize)
    else:
        tick_pos = np.arange(0, sol.shape[1], 50)
        plt.xticks(tick_pos, dt * tick_pos, fontsize=fontsize)
        plt.xlabel('time [au]', fontsize=fontsize)
    plt.ylabel(label, fontsize=fontsize)
    plt.plot(dT, k, 'r+', markersize=10, linewidth=3)  # mark (time step dT, component k)

#################################################
# b) one-step Jacobians: top right              #
#################################################
jacobians = [J_np, J_model, J_model - J_np]
clims = _panel_clims(J_np, J_model)

for i, (jac, cmap, clim, label) in enumerate(zip(jacobians, cmaps, clims, labels)):
    ax = plt.subplot(2, 2*len(jacobians), len(jacobians) + i + 1)
    plt.imshow(jac, cmap=cmap, vmin=clim[0], vmax=clim[1])
    plt.xticks([], fontsize=fontsize)
    plt.yticks([], fontsize=fontsize)
    plt.xlabel(label, fontsize=fontsize)
    if i == 0:
        plt.ylabel(r'Jacobians', fontsize=fontsize)
    elif i == 1:
        _add_inset_colorbar(ax)
    elif i == 2:
        _shift_right(ax)
        _add_inset_colorbar(ax)

#################################################
# c) sensitivity analysis: bottom right         #
#################################################
# Difference is emulator - reference, the same sign convention as the panels above.
sens = [dfdx_gt.T, dfdx.T, dfdx.T - dfdx_gt.T]
clims = _panel_clims(dfdx_gt, dfdx)

for i, (sen, cmap, clim, label) in enumerate(zip(sens, cmaps, clims, labels)):
    ax = plt.subplot(2, 2*len(sens), 3*len(sens) + i + 1)
    plt.imshow(sen, aspect='auto', cmap=cmap, vmin=clim[0], vmax=clim[1])
    plt.yticks([], fontsize=fontsize)
    plt.xticks([0, sen.shape[1]-1],
               [dt * (T_start-dT), dt * T_start],
               fontsize=fontsize)
    plt.title(label, fontsize=fontsize)
    plt.plot(dT, k, 'r+', markersize=10, linewidth=3)
    if i == 0:
        plt.ylabel('partial derivatives', fontsize=fontsize)
    elif i == 1:
        plt.xlabel('time [au]', fontsize=fontsize)
        _add_inset_colorbar(ax)
    elif i == 2:
        _shift_right(ax)
        _add_inset_colorbar(ax)

fig.text(0.1,   0.9, 'a)', fontsize=fontsize, weight='bold')
fig.text(0.475, 0.9, 'b)', fontsize=fontsize, weight='bold')
fig.text(0.475, 0.4, 'c)', fontsize=fontsize, weight='bold')

# Create the output directory if needed, otherwise savefig raises FileNotFoundError.
fig_dir = os.path.join(res_dir, 'figs')
os.makedirs(fig_dir, exist_ok=True)
fig.savefig(os.path.join(fig_dir, 'emulator_intro.pdf'), bbox_inches='tight', pad_inches=0)

plt.show()


In [None]:
# Verify that the Jacobian has 12 off-diagonal elements per row (8x locations to the left and 4x to the right): 
#plt.imshow(np.log(np.abs(J_model-np.eye(K))))
#plt.colorbar()