In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import numpy as np

from L96_emulator.run import setup, sel_dataset_class
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
from L96_emulator.eval import get_rollout_fun, plot_rollout, solve_from_init, load_model_from_exp_conf

res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'

# Experiment configs are read from experiments/NN_resnet_1x1convs_predictState.yml
# with NN = 01 .. 10 (zero-padded, numbering starts at 1).
exps = range(1, 11)

# One shared figure collects the validation-loss curves of all experiments.
fig_loss = plt.figure(figsize=(8, 8))
any_loss_plotted = False  # set once at least one labelled loss curve is drawn
for i in exps:

    # --- load experiment setup ---------------------------------------------
    i_ = "{:02d}".format(i)

    conf_exp = f'{i_}_resnet_1x1convs_predictState'
    args = setup(conf_exp=f'experiments/{conf_exp}.yml')
    args.pop('conf_exp')
    print('\n experiment ' + args['exp_id'])
    K, J, T, dt = args['K'], args['J'], args['T'], args['dt']
    spin_up_time = args['spin_up_time']

    # Number of time steps covered by the spin-up period. round() avoids a
    # float quotient like 2.9999999 being truncated one step short by int().
    n_spin_up = int(round(spin_up_time / dt))

    # The data file name encodes dt by the digits after "0.", e.g. dt=0.01 -> 'dt0_01'.
    # NOTE(review): assumes 0 < dt < 1 and that str(dt) is in plain decimal
    # notation (not e.g. '1e-05') -- confirm against the data generator.
    fn_data = f'out_K{K}_J{J}_T{T}_dt0_{str(dt)[2:]}'
    out = np.load(data_dir + fn_data + '.npy')
    print('data.shape', out.shape)

    # --- compute comparison solution ---------------------------------------
    F, h, b, c = 10, 1, 10, 10
    T_burnin, T_comparison = n_spin_up, 5000
    out2 = solve_from_init(K, J,
                           T_burnin=T_burnin, T_=T_comparison, dt=dt,
                           F=F, h=h, b=b, c=c,
                           data=out, dilation=2, norm_mean=0., norm_std=1.)

    # --- load trained model --------------------------------------------------
    model, model_forward, training_outputs = load_model_from_exp_conf(res_dir, args)

    if training_outputs is not None:
        training_loss = training_outputs['training_loss']
        validation_loss = training_outputs['validation_loss']

        seq_length = args['seq_length']
        plt.figure(fig_loss.number)  # draw into the shared loss figure
        plt.semilogy(validation_loss, label=conf_exp + f' ({seq_length * (J + 1)}-dim)')
        any_loss_plotted = True

    # --- compute example rollout ---------------------------------------------
    DatasetClass = sel_dataset_class(prediction_task=args['prediction_task'])
    dg_train = DatasetClass(data=out, J=J, offset=args['lead_time'], normalize=True,
                            start=n_spin_up,
                            end=int(np.floor(out.shape[0] * args['train_frac'])))

    model_simulate = get_rollout_fun(dg_train, model_forward, args['prediction_task'])

    T_start, T_dur = n_spin_up, 200
    out_model = model_simulate(y0=dg_train[T_start].copy(),
                               dy0=dg_train[T_start] - dg_train[T_start - dg_train.offset],
                               T=T_dur)
    # The dataset was built with normalize=True, so map the rollout back with
    # its std/mean before reordering it out of the channel layout.
    out_model = sortL96fromChannels(out_model * dg_train.std + dg_train.mean)

    fig = plt.figure(figsize=(16, 9))
    fig = plot_rollout(out, out_model, out_comparison=out2, T_start=T_start, T_dur=T_dur, K=K, fig=fig)
    fig.suptitle(args['exp_id'])
    fig.show()

# Title, label and legend only need to be set once, after all curves are drawn;
# skip them entirely if no experiment provided training outputs.
plt.figure(fig_loss.number)
if any_loss_plotted:
    plt.title('training')
    plt.ylabel('validation error')
    plt.legend()
fig_loss.patch.set_facecolor('xkcd:white')
plt.show()
