# Emulators

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim
from L96_emulator.util import dtype, dtype_np, device

# Cluster paths: res_dir holds experiment results (model subfolders under 'models/'),
# data_dir holds the simulated L96 trajectories saved/loaded with np.save/np.load.
# Both are hard-coded to one user's workspace; adjust when running elsewhere.
res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'

### load specific experiment

In [None]:
from L96_emulator.run import setup

exp_id = 20

# Experiment configs are 'experiments/<id>_<description>.yml'. Compare the whole
# numeric prefix (not just its first two characters) so that ids of any width
# match exactly and zero-padded names such as '01_...' are found for exp_id=1.
# Sorting makes the choice deterministic instead of depending on listdir order.
exp_names = sorted(name for name in os.listdir('experiments/') if name.endswith('.yml'))
matches = []
for name in exp_names:
    prefix = name.split('_', 1)[0]
    if prefix.isdigit() and int(prefix) == exp_id:
        matches.append(name)
if not matches:
    raise FileNotFoundError(f'no config for experiment id {exp_id} in experiments/')
conf_exp = os.path.splitext(matches[0])[0]

args = setup(conf_exp=f'experiments/{conf_exp}.yml')
args.pop('conf_exp')

### load / simulate data

In [None]:
from L96sim.L96_base import f1, f2, J1, J1_init, f1_juliadef, f2_juliadef
from L96_emulator.util import predictor_corrector

# Model parameters passed to f1/f2 and to the solvers in later cells. They are not
# read from the experiment config, so define them unconditionally: later cells
# (analytic model, solver comparison) need them even when the config path is taken.
F, h, b, c = 10, 1, 10, 10

try:
    K, J, T, dt = args['K'], args['J'], args['T'], args['dt']
    spin_up_time = args['spin_up_time']
except (NameError, KeyError):
    # no (complete) experiment config loaded: fall back to the default setup
    K, J, T, dt = 36, 10, 605, 0.001
    spin_up_time = 5.

# e.g. dt=0.001 -> '..._dt0_001'; assumes str(dt) has the form '0.xxx'
fn_data = f'out_K{K}_J{J}_T{T}_dt0_{str(dt)[2:]}'

resimulate = False
if resimulate:
    print('simulating data')
    X_init = F * (0.5 + np.random.randn(K*(J+1)) * 1.0).astype(dtype=dtype_np) / np.maximum(J,10)
    # buffer handed to f1/f2 (presumably filled in place with dX/dt — confirm in L96sim)
    dX_dt = np.empty(X_init.size, dtype=X_init.dtype)
    times = np.linspace(0, T, int(np.floor(T/dt))+1)

    if J > 0:
        def fun(t, x):
            return f2(x, F, h, b, c, dX_dt, K, J)
    else:
        def fun(t, x):
            return f1(x, F, dX_dt, K)

    out = predictor_corrector(fun=fun, y0=X_init.copy(), times=times, alpha=0.5)

    # `dtype` is the torch dtype used for tensors in this notebook; numpy arrays
    # must be cast with the matching numpy dtype `dtype_np`.
    np.save(data_dir + fn_data, out.astype(dtype=dtype_np))
else:
    print('loading data')
    out = np.load(data_dir + fn_data + '.npy')

plt.figure(figsize=(8,4))
plt.imshow(out.T, aspect='auto')
plt.xlabel('time')
plt.ylabel('location')
plt.show()

In [None]:
# quick check of the dtype of the loaded / simulated trajectory `out`
out.dtype

### optional: create short comparison solution with different step size for numerical solver

In [None]:
from L96_emulator.eval import solve_from_init

solver_comparison = True

if solver_comparison:
    # F, h, b, c may be undefined if the data cell took the config path;
    # fall back to the default parameters in that case.
    try:
        print(F, h, b, c)
    except NameError:
        F, h, b, c = 10, 1, 10, 10

    # T_burnin = number of dt-steps in the spin-up period
    T_burnin, T_comparison = int(spin_up_time/dt), 5000
    # NOTE(review): dilation=2 presumably coarsens the solver step size (see the
    # section heading) — confirm in L96_emulator.eval.solve_from_init
    out2 = solve_from_init(K, J,
                           T_burnin=T_burnin, T_=T_comparison, dt=dt,
                           F=F, h=h, b=b, c=c,
                           data=out, dilation=2, norm_mean=0., norm_std=1.)

# Learn local emulator

In [None]:
#%run -i 'main_train.py -c experiments/template.yml'

# Evaluate model fit

In [None]:
"""
import torch 
import numpy as np
from L96_emulator.eval import load_model_from_exp_conf

model, model_forward, training_outputs = load_model_from_exp_conf(res_dir, args)

if not training_outputs is None:
    training_loss, validation_loss = training_outputs['training_loss'], training_outputs['validation_loss']

    fig = plt.figure(figsize=(8,8))
    seq_length = args['seq_length']
    plt.semilogy(validation_loss, label=conf_exp+ f' ({seq_length * (J+1)}-dim)')
    plt.title('training')
    plt.ylabel('validation error')
    plt.legend()
    fig.patch.set_facecolor('xkcd:white')
    plt.show()

model.layers_ks1, model.layers3x3
"""

In [None]:
from L96_emulator.networks import AnalyticModel_twoLevel
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
import torch

# Analytic (hand-constructed, not trained) two-level model from the known parameters.
# NOTE(review): `loc` appears to be the data-shift parameter swept in the
# 'identity through nonlinearity' check further down — confirm in L96_emulator.networks.
model = AnalyticModel_twoLevel(K=K, J=J, F=F, b=b, c=c, h=h, loc=1e3)

### example rollout

In [None]:
from L96_emulator.run import sel_dataset_class
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels, init_torch_device
from L96_emulator.eval import get_rollout_fun, plot_rollout
from L96_emulator.eval import solve_from_init

# Training split of the trajectory `out`: begins after the spin-up steps and ends at
# the train_frac fraction of all time steps. normalize=True gives the dataset the
# mean/std statistics used by model_forward below; offset is presumably the lead
# time in steps between input and target — confirm in L96_emulator.
DatasetClass = sel_dataset_class(prediction_task=args['prediction_task'])
dg_train = DatasetClass(data=out, J=J, offset=args['lead_time'], normalize=True, 
                   start=int(args['spin_up_time']/args['dt']), 
                   end=int(np.floor(out.shape[0]*args['train_frac'])))


def model_forward(x):
    """Advance a normalized, channel-sorted state by one time step dt.

    The state is de-normalized with dg_train's statistics and reordered via
    sortL96fromChannels. It is then stepped with a two-stage predictor-corrector
    scheme: an explicit Euler predictor state + dt*f(state), followed by
    state + dt*(w*f(state) + (1-w)*f(predictor)) with w = 0.5. The result is
    re-sorted into channels, re-normalized, and returned as a torch tensor
    on `device`.
    """
    weight = 0.5
    state = sortL96fromChannels(x.detach().cpu().numpy() * dg_train.std + dg_train.mean)
    slope_start = model.forward(state.T).T
    predictor = state + dt*slope_start
    slope_end = model.forward(predictor.T).T
    stepped = state + dt * (weight*slope_start + (1-weight)*slope_end)
    normalized = (sortL96intoChannels(stepped, J=J) - dg_train.mean) / dg_train.std
    return torch.as_tensor(normalized, device=device, dtype=dtype)

model_simulate = get_rollout_fun(dg_train, model_forward, args['prediction_task'])

# start the rollout at 100x the number of spin-up steps, run T_dur steps
T_start, T_dur = 100*int(spin_up_time/dt), 200
out_model = model_simulate(y0=dg_train[T_start].copy(),
                           dy0=dg_train[T_start]-dg_train[T_start-dg_train.offset],
                           T=T_dur)
# back to physical units and the flat (time, state) layout
out_model = sortL96fromChannels(out_model * dg_train.std + dg_train.mean)

solver_comparison = True
if solver_comparison:
    # F, h, b, c may be undefined if the data cell took the config path
    try:
        print(F, h, b, c)
    except NameError:
        F, h, b, c = 10, 1, 10, 10

    out2 = solve_from_init(K, J,
                           T_burnin=T_start, T_=T_dur, dt=dt,
                           F=F, h=h, b=b, c=c,
                           data=out, dilation=2, norm_mean=0., norm_std=1.)

fig = plot_rollout(out, out_model, out_comparison=out2, T_start=T_start, T_dur=T_dur, K=K)


In [None]:
from L96_emulator.train import calc_val_loss, loss_function
from L96_emulator.util import init_torch_device

# NOTE: rebinds the notebook-wide `device`, which model_forward reads at call time.
device = init_torch_device()

# Validation split: the slice between train_frac and train_frac+validation_frac of
# the trajectory, with a one-step offset (matching the single dt step of model_forward).
batch_size, train_frac, validation_frac = 32, 0.8, 0.1
dg_val   = DatasetClass(data=out, J=J, offset=1, normalize=True, 
                   start=int(np.ceil(out.shape[0]*train_frac)), 
                   end=int(np.floor(out.shape[0]*(train_frac+validation_frac))))

validation_loader = torch.utils.data.DataLoader(
    dg_val, batch_size=batch_size, drop_last=False, num_workers=0 
)

loss_fun = loss_function(loss_fun='mse', extra_args={})


# analytic model with a different loc/scale setting than the rollout cell above
model = AnalyticModel_twoLevel(K=K, J=J, F=F, b=b, c=c, h=h, loc=1e0, scale=1e6)
def model_forward(x):
    """One predictor-corrector step (alpha = 0.5) of `model`, in normalized channel space.

    The input is mapped to physical units with dg_train.mean / dg_train.std and
    reordered by sortL96fromChannels. The step combines f at the current state with
    f at the Euler predictor: x + dt*(alpha*f(x) + (1-alpha)*f(x + dt*f(x))).
    The stepped state is re-sorted into channels, re-normalized, and returned as
    a torch tensor on `device`.
    """
    alpha_weight = 0.5
    physical = sortL96fromChannels(x.detach().cpu().numpy() * dg_train.std + dg_train.mean)
    f_now = model.forward(physical.T).T
    f_pred = model.forward((physical + dt*f_now).T).T
    increment = dt * (alpha_weight*f_now + (1-alpha_weight)*f_pred)
    result = (sortL96intoChannels(physical + increment, J=J) - dg_train.mean) / dg_train.std
    return torch.as_tensor(result, device=device, dtype=dtype)

# validation loss of the analytic model (through model_forward) on dg_val
calc_val_loss(validation_loader, model_forward, device, loss_fun)

In [None]:
from L96_emulator.networks import AnalyticModel_oneLevel, AnalyticModel_twoLevel
         
# Fresh float64 simulation with fixed default parameters. This overwrites K, J, T,
# dt, F, h, b, c, spin_up_time, dX_dt, fun and `out` from earlier cells.
J = 10
F, h, b, c = 10, 1, 10, 10
K, T, dt = 36, 605, 0.001
spin_up_time = 5.


print('simulating data')
X_init = F * (0.5 + np.random.randn(K*(J+1)) * 1.0) / np.maximum(J,10)
# buffer passed to f1/f2 on every call of `fun` (presumably written in place)
dX_dt = np.empty(X_init.size, dtype=X_init.dtype)
times = np.linspace(0, T, int(np.floor(T/dt))+1)

# right-hand side: two-level model when fast variables exist, one-level otherwise
if J > 0:
    def fun(t, x):
        return f2(x, F, h, b, c, dX_dt, K, J)
else:
    def fun(t, x):
        return f1(x, F, dX_dt, K)

out = predictor_corrector(fun=fun, y0=X_init.copy(), times=times, alpha=0.5)    
    

def heun_increment(model, x):
    """Return the Heun (explicit trapezoidal) increment for state `x`.

    Uses `model.forward` as the right-hand side f and the global step `dt`:
    dt/2 * (f(x + dt*f(x)) + f(x)).
    """
    y1 = x + dt*model.forward(x)
    return dt/2.*(model.forward(y1) + model.forward(x))


for skip_conn in [False, True]:
    # model factory for the current skip_conn setting (one- or two-level L96)
    if J > 0:
        def get_model(loc=1.):
            return AnalyticModel_twoLevel(K, J, F=F, b=b, c=c, h=h, loc=loc, skip_conn=skip_conn)
    else:
        def get_model(loc=1.):
            return AnalyticModel_oneLevel(K=K, F=F, loc=loc, skip_conn=skip_conn)

    t = 500000
    model = get_model()

    # (1) network output vs. right-hand side of the ODE at one example state
    plt.figure(figsize=(12,8))
    plt.subplot(1,2,1)
    plt.plot(fun(t, out[t]) - model.forward(out[t]).flatten())
    plt.title('difference between network and rhs of diff.eq.')
    plt.xlabel('location k')
    plt.subplot(1,2,2)
    # (2) sweep the `loc` parameter and record the rhs mismatch for each value
    locs = 10.**np.arange(-5, 10)
    errs = np.zeros(len(locs))
    for i, loc in enumerate(locs):
        model = get_model(loc=loc)
        errs[i] = np.mean((fun(t, out[t]) - model.forward(out[t]))**2)
    plt.loglog(locs, errs)
    plt.title('identity through nonlinearity: location for shifting data')
    plt.xlabel('loc parameter')
    plt.ylabel('MSE between network and rhs of diff.eq.')
    plt.show()
    print('one-step error:', np.min(errs))

    model = get_model()

    # (3) one Heun step of the network vs. the simulated step out[t] -> out[t+1]
    est_step = heun_increment(model, out[t])
    true_step = out[t+1] - out[t]
    plt.figure(figsize=(12,8))
    plt.subplot(1,2,1)
    plt.plot(est_step, label='est')
    plt.plot(true_step, '--', label='finite-diff')
    plt.legend()
    plt.title(f'est. vs numerical temporal difference, example step t={t}')
    plt.xlabel('location k')
    plt.ylabel('state X_k')
    plt.subplot(1,2,2)
    plt.plot(true_step - est_step)
    plt.title(f'error in temporal difference, example step t={t}')
    plt.xlabel('location k')
    plt.ylabel(r'state difference X_k - \hat{X_k}')
    plt.show()
    print('example MSE :', np.mean( (out[t] + est_step - out[t+1])**2 ))

    # (4) average one-step error over many time steps. The Heun increment depends
    # on the state, so it is recomputed at every sampled t.
    MSE, t_range = 0., np.arange(5000, out.shape[0]-1, 100)
    for t in t_range:
        MSE += np.mean( (out[t] + heun_increment(model, out[t]) - out[t+1])**2 )
    MSE /= len(t_range)
    print('subsampled MSE :', MSE)

In [None]:
# Per-time-step Euclidean norm of the reference trajectory over the rollout window
# [T_start, T_start + T_dur]; used below to normalize the forecast errors.
data_std = np.linalg.norm(out[np.arange(T_dur+1)+T_start], axis=1)
def prediction_horizon(data, est, threshold=0.1, data_std=1.):
    """Return the first time index at which the normalized error exceeds `threshold`.

    The error per time step (row) is the Euclidean distance between `data` and
    `est`, divided by `data_std` (a scalar or one value per time step).

    Returns len(data) if the threshold is never exceeded. This keeps a forecast
    that stays accurate over the whole window from being reported as horizon 0,
    which np.argmax alone would return for an all-False mask.

    Raises:
        ValueError: if `data` or `est` is not 2-D (time, state).
    """
    data, est = np.asarray(data), np.asarray(est)
    if data.ndim != 2 or est.ndim != 2:
        raise ValueError('data and est must be 2-D arrays (time, state)')
    L2 = np.sqrt(np.sum((data - est)**2, axis=1))
    exceeded = L2/data_std > threshold
    if not exceeded.any():
        return len(L2)
    return int(np.argmax(exceeded))

def prediction_error(data, est, dt, T_horizon=None, data_std=1.):
    """Return the accumulated normalized forecast error over the first T_horizon steps.

    The error per time step (row) is the Euclidean distance between `data` and
    `est`, divided by `data_std` (a scalar, or one value per row of `data`). The
    normalized errors of the first T_horizon steps are summed and divided by
    T_horizon*dt. If fewer than T_horizon rows exist, only the available rows are
    summed; the divisor is unchanged.

    Args:
        T_horizon: number of steps; defaults to ceil(0.5/dt).

    Raises:
        ValueError: if T_horizon is not integral, or data/est are not 2-D.
    """
    T_horizon = int(np.ceil(0.5/dt)) if T_horizon is None else T_horizon
    if T_horizon != T_horizon//1:
        raise ValueError('T_horizon has to be integer')
    T_horizon = int(T_horizon)

    data, est = np.asarray(data), np.asarray(est)
    if data.ndim != 2 or est.ndim != 2:
        raise ValueError('data and est must be 2-D arrays (time, state)')
    L2 = np.sqrt(np.sum((data - est)**2, axis=1))

    # Normalize before truncating, so a per-step data_std (one value per row)
    # broadcasts against the full error series even when T_horizon < len(data).
    normalized = L2/data_std
    return np.sum(normalized[:T_horizon]) / (T_horizon*dt)

# Horizon and accumulated error over the rollout window for the emulator rollout
# (out_model) and the solver comparison run (out2, truncated to the same window),
# both measured against the reference trajectory `out` and normalized by data_std.
pred_horizons = (prediction_horizon(out[np.arange(T_dur+1)+T_start], out_model, threshold=0.1, data_std=data_std), 
                 prediction_horizon(out[np.arange(T_dur+1)+T_start], out2[:T_dur+1], threshold=0.1, data_std=data_std))
pred_errors = (prediction_error(out[np.arange(T_dur+1)+T_start], out_model, dt=args['dt'], data_std=data_std), 
              prediction_error(out[np.arange(T_dur+1)+T_start], out2[:T_dur+1], dt=args['dt'], data_std=data_std))

# normalized error curve of the emulator rollout
data, est = out[np.arange(T_dur+1)+T_start], out_model
L2 = np.sqrt(np.sum( (data - est)**2, axis=1 ))
plt.plot(L2/data_std)

pred_horizons, pred_errors

In [None]:
def adams_bashfort(model_simulate, dg_train, T_start, T_dur):
    """Roll out the emulator with a two-step Adams-Bashforth scheme.

    The one-step increment of `model_simulate` serves as dt*f(y). That increment
    is the difference between the one-step prediction from a state and the
    state itself. The first step is a plain forward step; every later step
    advances by (3*dy_n - dy_{n-1}) / 2.

    Returns an array of T_dur + 1 states, starting with dg_train[T_start]
    (stored with its leading axis dropped).
    """
    def increment(state):
        return model_simulate(y0=state, dy0=None, T=1)[-1] - state

    state = dg_train[T_start].copy()
    previous = increment(state)

    trajectory = np.empty((T_dur + 1,) + tuple(state.shape[1:]))
    trajectory[0] = state.copy()
    state += previous
    trajectory[1] = state.copy()

    for step in range(2, T_dur + 1):
        current = increment(state)
        state += 0.5 * (3 * current - 1. * previous)
        previous = current
        trajectory[step] = state

    return trajectory

# Adams-Bashforth rollout of the emulator, mapped back to physical units
out_ab = adams_bashfort(model_simulate, dg_train, T_start, T_dur=T_dur)
out_ab = sortL96fromChannels(out_ab * dg_train.std + dg_train.mean)

# NOTE(review): unlike the earlier plot_rollout call, K is not passed here — confirm the default
fig = plot_rollout(out, out_model, out_comparison=out_ab, T_start=T_start, T_dur=T_dur)
plt.subplot(1,2,2)
plt.legend(['direct model reconstruction', 'Adams-Bashforth explicit step'])

# debug corner - cells below might not execute anymore

### older fits - comparison of validation errors

In [None]:
import matplotlib.pyplot as plt
import os

# Config names of the older fits '01_...' through '05_...' to compare below.
conf_exps = [
            f'0{i}_resnet_1x1convs_predictState' for i in range(1,6) # initially just gave the network fits version numbers V0 - V9
          ]

def find_weights(fn):
    """Return True if the file name `fn` ends with the '.pt' suffix."""
    return fn.endswith('.pt')


fig = plt.figure(figsize=(16,9))
for conf_name in conf_exps:
    # Separate names so the experiment loaded at the top of the notebook
    # (`args`, `conf_exp`, `exp_id`) is not overwritten by this comparison loop.
    exp_args = setup(conf_exp=f'experiments/{conf_name}.yml')
    exp_args.pop('conf_exp')
    run_id, dim = exp_args['exp_id'], exp_args['J']+1
    label = f'{run_id} ({dim}D)'

    save_dir = f'{res_dir}models/{run_id}/'

    # allow_pickle=True: the file is a pickled dict written by our own training runs
    try:
        training_outputs = np.load(save_dir + '_training_outputs' + '.npy', allow_pickle=True)[()]
        validation_loss = training_outputs['validation_loss']
    except (OSError, KeyError, ValueError) as err:
        # missing/unreadable results: plot a placeholder so the run still shows in the legend
        print(f'could not load training outputs for {run_id}: {err}')
        plt.semilogy(1., label=label)
    else:
        plt.semilogy(validation_loss, label=label)
    plt.title('training')

plt.ylabel('validation error')
plt.legend()
fig.patch.set_facecolor('xkcd:white')
plt.show()

"""
plt.figure(figsize=(16,4))
cellText = np.hstack((np.array(conf_exps).reshape(-1,1), np.around(RMSEall,2)))
collabel=('experiment', f'RMSE {lead_time}h, z 500', f'RMSE {lead_time}h, t 850')
plt.axis('tight')
plt.axis('off')
plt.table(cellText=cellText,colLabels=collabel,loc='center')
plt.show()
"""


### more plotting

In [None]:
T_burnin = 10000
# NOTE: T_ (rollout length) is only assigned in a later cell of this notebook.
out_model = model_simulate(y0=dg_train[T_burnin].copy(), dy0 = None, T=T_)#.reshape(-1, K*(J+1))

# Squared error per step, averaged over axis 2, leaving one curve per channel
# (J+1 of them, channel 0 drawn as the slow variables). Computed once for both panels.
mse_per_channel = np.mean((out_model[1:] - dg_train[np.arange(1,T_+1)+ T_burnin])**2, axis=(2,))
steps = np.arange(1, T_+1)

plt.figure(figsize=(12,4))
# left panel: first iterations; right panel: full range
for panel, axis_limits in enumerate(([0,20,0.0000001, 0.5], [0,1300,0.0000001, 1000]), start=1):
    plt.subplot(1,2,panel)
    for i in range(J+1):
        plt.semilogy(steps, mse_per_channel[:,i], 'b--')
    plt.semilogy(steps, mse_per_channel[:,0], 'k', linewidth=2,
                 label='slow variables')
    plt.semilogy(-1, 1, 'b--', label=f'fast variables (J={str(J)})')
    plt.axis(axis_limits)
    plt.legend()
    plt.xlabel('iterations')
    plt.ylabel('MSE')
    plt.title('error over iterations, per variable type')

plt.show()

In [None]:
T_burnin = 10000
# NOTE: T_ (rollout length) is only assigned in a later cell of this notebook.
out_model = model_simulate(y0=dg_train[T_burnin].copy(), T=T_)#.reshape(-1, K*(J+1))

# Squared error per step, averaged over axis 1 (the J+1 channels), leaving one curve
# per location k = 0..K-1. Presumably states are (time, J+1, K), as in the
# per-channel plot; confirm. Computed once for both panels.
mse_per_location = np.mean((out_model[1:] - dg_train[np.arange(1,T_+1)+ T_burnin])**2, axis=(1,))
steps = np.arange(1, T_+1)

plt.figure(figsize=(12,4))
# left panel: first iterations; right panel: full range
for panel, axis_limits in enumerate(([0,20,0.0000001, 0.5], [0,1300,0.0000001, 1000]), start=1):
    plt.subplot(1,2,panel)
    for i in range(K):
        plt.semilogy(steps, mse_per_location[:,i], 'b--')
    plt.semilogy(steps, mse_per_location[:,0], 'k', linewidth=2,
                 label='location k=0')
    plt.semilogy(-1, 1, 'b--', label=f'all locations (K={K})')
    plt.axis(axis_limits)
    plt.legend()
    plt.xlabel('iterations')
    plt.ylabel('MSE')
    plt.title('error over iterations, per location')

plt.show()

In [None]:
# De-normalize one dataset sample and compare it with the raw trajectory at step t.
# NOTE(review): uses dg_train.mean_in/std_in, whereas model_forward uses dg_train.mean/std,
# and indexes dataset and raw data with the same t although the dataset may start later — confirm.
t = 10000
plt.plot((dg_train.mean_in + dg_train.std_in * dg_train[t][0,:,:]).reshape(K*(J+1)) - out[t])


In [None]:
t = 5000
plt.figure(figsize=(12,8))
# dataset states at t and t+1 vs. the model's one-step prediction from t
plt.plot(dg_train[t+0].flatten(), label='t=0')
plt.plot(dg_train[t+1].flatten(), label='t=1')
plt.plot(model_simulate(y0=dg_train[t+0].copy(), T=1)[-1,:,:].flatten(), 'k--', label='model, t=1')
plt.legend()
plt.show()

In [None]:
t = 10000
plt.figure(figsize=(12,8))
# one-step temporal difference: simulation vs. model prediction
plt.plot(dg_train[t+1].flatten() - dg_train[t+0].flatten(), label='sim')
plt.plot(model_simulate(y0=dg_train[t+0].copy(), T=1)[-1,:,:].flatten()  - dg_train[t+0].flatten(), label='model')
plt.legend()
plt.show()

In [None]:
# one-step prediction error (model minus dataset at t+1) for a few start indices, one figure each
for t in [0, 100, 1000, 10000]:
    plt.plot(model_simulate(y0=dg_train[t+0].copy(), T=1)[-1,:,:].flatten()  - dg_train[t+1].flatten(), label='model')
    plt.show()


In [None]:

# Std over time of the one-step differences for every flattened variable
# (T_ and T_burnin come from other cells). The first K entries are labeled slow.
n_vars = (J+1)*K
plt.semilogy(np.std(np.diff(dg_train[np.arange(T_+1)+ T_burnin].reshape(-1,n_vars), axis=0), axis=0))
plt.xlabel(f'variable ID (slow: first K={K})')
plt.ylabel('std')
plt.title('variability of 1-step temporal differences')
plt.axis([0, n_vars + 1, 0.001, 0.1])
plt.show()

In [None]:
# NOTE(review): vmin/vmax (color limits for the imshow cells below) are not assigned
# anywhere in this notebook view; define them before running.
vmin, vmax

In [None]:
# Image of T_+1 dataset states starting at T_burnin, flattened to (J+1)*K variables per step.
plt.figure(figsize=(16,12))
T_burnin = 10000
T_ = 10000
plt.imshow(dg_train[np.arange(T_+1)+ T_burnin].reshape(-1,(J+1)*K).T, aspect='auto', vmin=vmin, vmax=vmax)
plt.xlabel('time')
plt.ylabel('location')
plt.title('numerical simulation')
plt.colorbar()
plt.show()

In [None]:
# shape check of the dataset slice used in the plots
dg_train[np.arange(T_+1)+ T_burnin][:,:,:].shape

In [None]:
# [:, :, 0] keeps index 0 of the last axis. If states are (time, J+1, K), this shows
# the J+1 channels at location k=0, so the 'location' y-label may be a misnomer — confirm.
plt.figure(figsize=(16,12))
T_burnin = 10000
T_ = 10000
plt.imshow(dg_train[np.arange(T_+1)+ T_burnin][:,:,0].reshape(-1,J+1).T, aspect='auto', vmin=vmin, vmax=vmax)
plt.xlabel('time')
plt.ylabel('location')
plt.title('numerical simulation')
plt.colorbar()
plt.show()

In [None]:
plt.figure(figsize=(8,9))
plt.subplot(2,1,1)
# deviation of the first 100 states from the initial state y0 = dg_train[T_burnin]
plt.imshow(dg_train[np.arange(100)+T_burnin].reshape(100,-1).T - dg_train[T_burnin].reshape(-1,1), aspect='auto')
plt.xlabel('time')
plt.ylabel('location')
plt.title('numerical simulation, differences to y0')
plt.colorbar()
plt.subplot(2,1,2)
# same for the model rollout started from y0
plt.imshow(out_model[:100].reshape(100,-1).T - dg_train[T_burnin].reshape(-1,1), aspect='auto')
plt.xlabel('time')
plt.ylabel('location')
plt.title('model-reconstructed simulation, differences to y0')
plt.colorbar()

plt.show()

In [None]:
from L96_emulator.dataset import DatasetRelPred
# Train (T_burnin .. 80%) and validation (80% .. 90%) splits for relative prediction.
# NOTE(review): temporal_offset is not defined anywhere in this notebook view; set it first.
dg_train = DatasetRelPred(data=out, J=J, offset=temporal_offset, normalize=True, 
                          start=T_burnin, 
                          end=int(np.floor(out.shape[0]*0.8)))
dg_val   = DatasetRelPred(data=out, J=J, offset=temporal_offset, normalize=True, 
                          start=int(np.ceil(out.shape[0]*0.8)), 
                          end=int(np.floor(out.shape[0]*0.9)))

In [None]:
# Per-channel mean and std of all training targets Y, pooled over samples and the
# last axis. Targets are stacked from the actual dataset length. A fixed-size zero
# buffer would bias the statistics with unused rows whenever dg_train is shorter.
targets = []
for X, Y in dg_train:
    targets.append(np.asarray(Y).reshape(J+1, K))
ct = len(targets)
s = np.stack(targets)
m = np.mean(s, axis=(0,2))
s = np.std(s, axis=(0,2))
print(m, s)