In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import numpy as np
import os

from L96_emulator.run import setup, sel_dataset_class
from L96_emulator.util import sortL96fromChannels, sortL96intoChannels
from L96_emulator.eval import get_rollout_fun, plot_rollout, load_model_from_exp_conf

# Evaluate trained L96 emulators: for every numbered experiment config, plot its
# validation-loss curve (if training outputs were saved) and an example rollout
# of the emulator started from the training data.

res_dir = '/gpfs/work/nonnenma/results/emulators/L96/'
data_dir = '/gpfs/work/nonnenma/data/emulators/L96/'

exps = range(1,20) # experiment configs are named with a two-digit prefix '01', '02', ... in experiments/; numbers 1-19 are scanned
exp_names = os.listdir('experiments/')
Js = (10,) # filter experiments by this list of values for number of fast variables

# Map each two-digit prefix to its config file name. This is built once rather than
# re-scanned per experiment. The listing is sorted so that, if two files share a
# prefix, the choice is deterministic (first in sorted order wins).
conf_by_prefix = {}
for name in sorted(exp_names):
    conf_by_prefix.setdefault(name[:2], name)

fig_loss = plt.figure(figsize=(8,8))
ax_loss = fig_loss.add_subplot(1, 1, 1)  # explicit axes: no need to switch the current figure inside the loop

for i in exps:

    # load setup
    i_ = "{:02d}".format(i)
    conf_exp = conf_by_prefix.get(i_)
    if conf_exp is None:
        # experiment numbers need not be contiguous up to max(exps): skip gaps instead of crashing
        print(f'\n no config with prefix {i_} in experiments/, skipping')
        continue
    args = setup(conf_exp=f'experiments/{conf_exp}')
    args.pop('conf_exp')

    if args['J'] not in Js:
        continue

    print('\n experiment ' + args['exp_id'])
    K, J, T, dt = args['K'], args['J'], args['T'], args['dt']
    spin_up_time = args['spin_up_time']

    # NOTE(review): the file name encodes dt by stripping the leading '0.' from str(dt),
    # e.g. 0.001 -> 'dt0_001'; this assumes 0 < dt < 1 with a plain decimal repr
    # (a repr like '1e-05' would yield a wrong name) -- confirm against the data generator.
    fn_data = f'out_K{K}_J{J}_T{T}_dt0_{str(dt)[2:]}'
    out = np.load(os.path.join(data_dir, fn_data + '.npy'))
    print('data.shape', out.shape)
    print('data.dtype', out.dtype)

    # load trained model

    model, model_forward, training_outputs = load_model_from_exp_conf(res_dir, args)

    if training_outputs is not None:
        training_loss, validation_loss = training_outputs['training_loss'], training_outputs['validation_loss']

        seq_length = args['seq_length']
        ax_loss.semilogy(validation_loss, label=conf_exp + f' ({seq_length * (J+1)}-dim)')

    # compute example rollout

    DatasetClass = sel_dataset_class(prediction_task=args['prediction_task'])
    dg_train = DatasetClass(data=out, J=J, offset=args['lead_time'], normalize=bool(args['normalize_data']),
                            start=int(spin_up_time/dt), end=int(np.floor(out.shape[0]*args['train_frac'])))

    model_simulate = get_rollout_fun(dg_train, model_forward, args['prediction_task'])

    # start the rollout right after the spin-up period, simulate T_dur steps
    T_start, T_dur = int(spin_up_time/dt), 200
    out_model = model_simulate(y0=dg_train[T_start].copy(),
                               dy0=dg_train[T_start]-dg_train[T_start-dg_train.offset],
                               T=T_dur)
    # undo the dataset normalization before comparing against the raw simulation
    out_model = sortL96fromChannels(out_model * dg_train.std + dg_train.mean)

    fig = plt.figure(figsize=(16,9))
    fig = plot_rollout(out, out_model, out_comparison=None, T_start=T_start, T_dur=T_dur, K=K, fig=fig)
    fig.suptitle(args['exp_id'])
    fig.show()

# decorate the loss figure once, after all curves have been added
ax_loss.set_title('training')
ax_loss.set_ylabel('validation error')
if ax_loss.lines:  # avoid matplotlib's "no artists with labels" warning when nothing was plotted
    ax_loss.legend()
fig_loss.patch.set_facecolor('xkcd:white')
plt.show()
