In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import L96sim

from L96_emulator.util import dtype, dtype_np, device

# Root directories for trained results (res_dir) and simulated data (data_dir).
# Defaults are the original cluster paths; set L96_RES_DIR / L96_DATA_DIR to run
# elsewhere. Both always end in a path separator because callers build paths by
# plain string concatenation (e.g. res_dir + save_dir + 'out.npy').
res_dir = os.path.join(os.environ.get('L96_RES_DIR', '/gpfs/work/nonnenma/results/emulators/L96/'), '')
data_dir = os.path.join(os.environ.get('L96_DATA_DIR', '/gpfs/work/nonnenma/data/emulators/L96/'), '')

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout
from L96_emulator.parametrization import Parametrized_twoLevel_L96, Parametrization_lin, Parametrization_nn
from L96_emulator.networks import Model_forwarder_rk4default
from L96_emulator.run_parametrization import setup_parametrization
from L96_emulator.data_assimilation import get_model
from L96_emulator.util import as_tensor, dtype_np, sortL96intoChannels, sortL96fromChannels
from L96sim.L96_base import f1, f2, pf2
import numpy as np
import torch
import os

def get_eval_parametrization(exp_ids, data=None, 
                             n_start_rollout=10, T_dur=5000, 
                             T_data=1000, N_trials=1000,
                             ifplot=False):

    T_start = None 
    i_trial = None     

    for exp_id in exp_ids:
        exp_names = os.listdir('experiments_parametrization/')   
        conf_exp = exp_names[np.where(np.array([name[:2] for name in exp_names])==str(exp_id))[0][0]][:-4]
        args = setup_parametrization(conf_exp=f'experiments_parametrization/{conf_exp}.yml')
        args.pop('conf_exp')

        save_dir = 'results/parametrization/' + args['exp_id'] + '/'
        out = np.load(res_dir + save_dir + 'out.npy', allow_pickle=True)[()]

        X_init = out['X_init']

        model_pars = {
            'exp_id' : args['model_exp_id'],
            'model_forwarder' : args['model_forwarder'],
            'K_net' : args['K'],
            'J_net' : 0,
            'dt_net' : args['dt']
        }

        ##########################
        #       models           #
        ##########################

        # trained parametrized model
        model, model_forwarder, _ = get_model(model_pars, res_dir=res_dir, exp_dir='')

        if args['parametrization'] == 'linear':
            param_train = Parametrization_lin(a=as_tensor(out['param_train_state_dict']['a']), 
                                              b=as_tensor(out['param_train_state_dict']['b']))
        elif args['parametrization'] == 'nn':
            param_train = Parametrization_nn(n_hiddens=args['n_hiddens'], n_in=1,n_out=1)
            for key, value in out['param_train_state_dict'].items():
                out['param_train_state_dict'][key] = as_tensor(value)
            param_train.load_state_dict(out['param_train_state_dict'])
        model_parametrized_train = Parametrized_twoLevel_L96(emulator=model, parametrization=param_train)
        model_forwarder_parametrized_train = Model_forwarder_rk4default(model=model_parametrized_train, dt=args['dt'])

        # initial and reference parametrized models
        param_ref = Parametrization_lin(a=as_tensor(np.array([-0.31])), b=as_tensor(np.array([-0.2])))
        param_init = Parametrization_lin(a=as_tensor(np.array([-0.75])), b=as_tensor(np.array([-0.4])))
        model_parametrized_init = Parametrized_twoLevel_L96(emulator=model, parametrization=param_init)
        model_forwarder_parametrized_init = Model_forwarder_rk4default(model=model_parametrized_init, dt=args['dt'])
        model_parametrized_ref = Parametrized_twoLevel_L96(emulator=model, parametrization=param_ref)
        model_forwarder_parametrized_ref = Model_forwarder_rk4default(model=model_parametrized_ref, dt=args['dt'])

        # ground-truth high-res model
        dX_dt_oneLevel = np.empty(args['K'], dtype=dtype_np)
        dX_dt_twoLevel = np.empty(args['K']*(args['J']+1), dtype=dtype_np)
        def fun_oneLevel(t, x):
            return f1(x, args['l96_F'], dX_dt_oneLevel, args['K']).copy()
        def fun_twoLevel(t, x):
            return f2(x, args['l96_F'], args['l96_h'], args['l96_b'], args['l96_c'], dX_dt_twoLevel, args['K'], args['J']).copy()

        class Torch_solver(torch.nn.Module):
            # numerical solver (from numpy/numba/Julia)
            def __init__(self, fun):
                self.fun = fun
            def forward(self, x):
                J = x.shape[-2] - 1
                x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
                return as_tensor(sortL96intoChannels(np.atleast_2d(self.fun(0., x)), J=J))


        model_forwarder_np_twoLevel = Model_forwarder_rk4default(Torch_solver(fun_twoLevel), dt=args['dt'])
        model_forwarder_np_oneLevel = Model_forwarder_rk4default(Torch_solver(fun_oneLevel), dt=args['dt'])

        ##########################
        #       test data        #
        ##########################

        from L96_emulator.util import rk4_default, get_data

        spin_up_time = 5.
        T = T_data*args['dt'] + spin_up_time

        try:
            assert data.shape == (N_trials, int(T/args['dt'])+1, args['K']*(args['J']+1))
            print('found data with matching specs (shape)')
        except:
            print('generating test data')
            data, _ = get_data(K=args['K'], J=args['J'], T=T, dt=args['dt'], N_trials=N_trials, 
                              F=args['l96_F'], h=args['l96_h'], b=args['l96_b'], c=args['l96_c'], 
                              resimulate=True, solver=rk4_default,
                              save_sim=False, data_dir=data_dir)

        ##########################
        #       rollouts         #
        ##########################

        model_forwarders = [model_forwarder_np_oneLevel,
                            model_forwarder_parametrized_init, 
                            model_forwarder_parametrized_train,
                            model_forwarder_parametrized_ref,
                            model_forwarder_np_twoLevel]


        Js = [0, 0, 0, 0, args['J']]
        panel_titles=['one-level L96', 
                      'initial param.', 
                      'learned param.', 
                      'reference param.', 
                      'two-level L96']
        n_start = n_start_rollout
        if T_start is None:
            T_start = np.linspace(int(spin_up_time/args['dt']), int(T/args['dt']), n_start).astype(np.int)
            i_trial = np.random.choice(N_trials, size=T_start.shape, replace=False)
        print('T_start, i_tria', (T_start, i_trial))

        sols = np.nan * np.ones((len(model_forwarders), n_start, T_dur+1, args['K']))
        
        for i_model in range(len(model_forwarders[:-1])): 

            model_forwarder_i, J_i = model_forwarders[i_model], Js[i_model]

            def model_simulate(y0, dy0, n_steps):
                x = np.empty((n_steps+1, *y0.shape[1:]))
                x[0] = y0.copy()
                xx = as_tensor(x[0])
                for i in range(1,n_steps+1):
                    xx = model_forwarder_i(xx.reshape(-1,J_i+1,args['K']))
                    x[i] = xx.detach().cpu().numpy().copy()
                return x
            
            print('forwarding model ' + panel_titles[i_model])
            X_init = []
            for i in range(n_start):
                X_init.append(data[i_trial[i], T_start[i]] if N_trials > 1 else data[T_start[i]])
                X_init[-1] = sortL96intoChannels(X_init[-1][:args['K']*(J_i+1)],J=J_i)
            X_init = np.stack(X_init)
            X_init = X_init.reshape(1, *X_init.shape)
            with torch.no_grad():
                sol = model_simulate(y0=X_init, dy0=None, n_steps=T_dur)
            sols[i_model,:,:,:] = sol[:,:,0,:].transpose(1,0,2)

        # not parallelising Numba model over initial values for rollouts
        def model_simulate(y0, dy0, n_steps):
            x = np.empty((n_steps+1, *y0.shape[1:]))
            x[0] = y0.copy()
            xx = as_tensor(x[0]).reshape(1,1,-1)
            for i in range(1,n_steps+1):
                xx = model_forwarder_np_twoLevel(xx.reshape(-1,args['J']+1,args['K']))
                x[i] = xx.detach().cpu().numpy().copy()
            return x

        print('forwarding model ' + panel_titles[-1])
        X_init = []
        for i in range(n_start):
            X_init = data[i_trial[i], T_start[i]] if N_trials > 1 else data[T_start[i]]
            X_init = sortL96intoChannels(X_init,J=args['J'])
            X_init = X_init.reshape(1, *X_init.shape)

            with torch.no_grad():
                sol = model_simulate(y0=X_init, dy0=None, n_steps=T_dur)
            sols[-1,i,:,:] = sol[:,0,:]

        ##########################
        #  one-step predictions  #
        ##########################


        model_forwarders_eval = [model_forwarder_np_oneLevel,
                                 model_forwarder_parametrized_init, 
                                 model_forwarder_parametrized_train,
                                 model_forwarder_parametrized_ref]


        MFWDs = [Model_forwarder_rk4default]
        dts = {Model_forwarder_rk4default : args['dt']}
        RMSEs_states = np.zeros((len(model_forwarders_eval), len(T_start)))

        print('\n')
        print('MSEs are on system state !')
        print('\n')


        for m_i, model_forwarder in enumerate(model_forwarders_eval):

            for i in range(len(T_start)):
                inputs = data[i_trial[i], T_start[i]] if N_trials > 1 else data[T_start[i]]
                inputs_torch = as_tensor(sortL96intoChannels(np.atleast_2d(inputs.copy()),J=args['J']))

                out_np = model_forwarder_np_twoLevel(inputs_torch)[:,:1,:]
                out_model = model_forwarder(inputs_torch[:,:1,:])

                RMSEs_states[m_i, i] = np.sqrt(((out_np - out_model)**2).mean().detach().cpu().numpy())


        ##########################
        #         plot           #
        ##########################

        clrs = ['blue', 'orange', 'green', 'red']
        if ifplot:
            fontsize=14
            plt.figure(figsize=(16,9))

            model_forwarders_plot = [0,1,2,4]
            for i,i_model in enumerate(model_forwarders_plot): 

                plt.subplot(2,len(model_forwarders_plot),len(model_forwarders_plot)-i) #plot right to left
                plt.imshow(sols[i_model,0].T, aspect='auto')
                plt.colorbar()
                plt.title(panel_titles[i_model], fontsize=fontsize)

                if i == 0:
                    plt.ylabel('location k', fontsize=fontsize)
                #if i == 2:
                plt.xlabel('time [au]', fontsize=fontsize)

                plt.yticks([], fontsize=fontsize)
                plt.xticks([0, T_dur/2, T_dur], [0, args['dt']*T_dur/2, args['dt']*T_dur], fontsize=fontsize)


            plt.subplot(2,3,6)
            xx = np.linspace(-7.5, 15, 20)
            parametrizations, labels = [param_train], ['trained']
            clrs_plot = ['g', 'r', 'orange']
            if args['parametrization'] == 'linear':
                parametrizations += [param_ref, param_init]
                labels += ['ref.', 'init.']
            for i, parametrization in enumerate(parametrizations):

                plt.plot(xx, 
                         parametrization(as_tensor(xx.reshape(1,1,-1))).detach().cpu().numpy().flatten(),
                         color=clrs_plot[i],
                         linewidth=2.5,
                         label=labels[i])
                plt.xlabel(r'$x_k$', fontsize=fontsize)
                plt.ylabel(r'$\mathcal{B}(x_k)$', fontsize=fontsize)
                plt.legend(fontsize=fontsize)
                plt.axis([-7.5, 12.5, -6, 5])
                plt.xticks([-5, 0, 5, 10])
                plt.yticks([-6, -4, -2, 0, 2, 4])
                plt.grid(True)

                
            ax = plt.subplot(2,3,4)
            model_labels = ['one-level', 
                          'init.', 
                          'trained', 
                          'ref.']

            RMSEs = RMSEs_states
            for i in range(len(model_labels)):
                plt.semilogy(i*np.ones(2)+np.array([-0.5,0.5]), 
                         RMSEs.mean(axis=1)[i]*np.ones(2),
                         color=clrs[i], linewidth=1.5)
                plt.semilogy(i*np.ones(2), 
                         RMSEs.mean(axis=1)[i]+RMSEs.std(axis=1)[i]*np.array([-1,1]),
                         color='k', linewidth=1.5)
            #plt.title(r'state prediction (RK4, $\Delta=0.001$)', y=1.05, fontsize=fontsize)
            plt.yticks(fontsize=fontsize)
            plt.xticks(np.arange(len(model_labels)), model_labels, fontsize=fontsize)
            plt.ylabel('RMSE', fontsize=fontsize)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)


            rmses = np.zeros((len(model_forwarders), n_start, T_dur+1))
            ax = plt.subplot(2,3,5)
            for i_model in range(len(model_forwarders)-1):
                for i in range(n_start):
                    rmses[i_model,i,:] = np.sqrt(np.mean((sols[i_model,i] - sols[-1,i])**2, axis=1))
                #plt.semilogy(rmses[i_model,0].T, clrs[i_model], label=panel_titles[i_model])
                #plt.semilogy(rmses[i_model,1:].T, clrs[i_model])
                plt.semilogy(rmses[i_model].mean(axis=0), clrs[i_model], label=panel_titles[i_model])
                
            plt.legend(fontsize=fontsize, frameon=False)
            plt.ylabel('RMSE', fontsize=fontsize)
            plt.xlabel('time [au]', fontsize=fontsize)
            plt.xticks([0, T_dur/2, T_dur], [0, args['dt']*T_dur/2, args['dt']*T_dur], fontsize=fontsize)
            #plt.yticks([0, 3, 6], fontsize=fontsize)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            #plt.savefig(res_dir + 'figs/param.pdf', bbox_inches='tight', pad_inches=0, frameon=False)
            plt.show()

    return data

In [None]:
# Keep test data left in the kernel by a previous run (regenerating it is
# expensive); on a fresh kernel, or if `data` has no shape, start from None.
try:
    data.shape
except (NameError, AttributeError):
    data = None

In [None]:
# Evaluate the selected parametrization experiment(s).
# Earlier selections (nn parametrization for long, intermediate and short data;
# direct comparison of single- vs multi-step training):
#   ['02', '04', '06', '08', '10', '12']
#   ['10', '13', '15', '12', '14', '16']
exp_ids = ['05']
data = get_eval_parametrization(
    exp_ids=exp_ids,
    data=data,              # reused if its shape matches, otherwise regenerated
    ifplot=True,
    N_trials=100,           # independent test trajectories
    T_data=1,               # test-data steps after spin-up
    n_start_rollout=100,    # start points for rollouts / one-step errors
    T_dur=5000,             # rollout length in time steps
)

In [None]:
from L96_emulator.eval import get_rollout_fun, plot_rollout
from L96_emulator.parametrization import Parametrized_twoLevel_L96, Parametrization_lin, Parametrization_nn
from L96_emulator.networks import Model_forwarder_rk4default
from L96_emulator.run_parametrization import setup_parametrization
from L96_emulator.data_assimilation import get_model
from L96_emulator.util import as_tensor, dtype_np, sortL96intoChannels, sortL96fromChannels
from L96sim.L96_base import f1, f2, pf2
import numpy as np
import torch
import os


# Font size for axis labels and legends in plot_parametrization (read as a global there).
fontsize = 14

def plot_parametrization(exp_ids):
    """Overlay the trained parametrizations B(x_k) of several experiments.

    Parameters
    ----------
    exp_ids : iterable
        Experiment ids; each selects the first file in
        ``experiments_parametrization/`` whose first two characters equal
        ``str(exp_id)``.

    Linear parametrizations go in the left panel and 'nn' ones in the right
    panel. Dashed lines with diamond markers mark experiments with more than one
    entry in ``args['offset']``. The colour encodes ``args['T']``: 10 is orange,
    1 is blue, 0.2 is black, and any other value uses matplotlib's colour cycle.
    Uses the module-level ``fontsize``.

    Raises
    ------
    FileNotFoundError
        If no config file matches an id.
    ValueError
        For an unknown parametrization type.
    """
    xx = np.linspace(-7.5, 15, 20)
    colors_by_T = {10: 'orange', 1: 'blue', 0.2: 'black'}
    conf_dir = 'experiments_parametrization/'
    exp_names = os.listdir(conf_dir)

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    for exp_id in exp_ids:
        matches = [name for name in exp_names if name[:2] == str(exp_id)]
        if not matches:
            raise FileNotFoundError(f'no config file starting with {str(exp_id)!r} in {conf_dir!r}')
        conf_exp = matches[0][:-4]  # drop '.yml'
        args = setup_parametrization(conf_exp=f'experiments_parametrization/{conf_exp}.yml')
        args.pop('conf_exp')

        save_dir = 'results/parametrization/' + args['exp_id'] + '/'
        # NOTE: allow_pickle can execute code from the file; only load trusted results.
        out = np.load(res_dir + save_dir + 'out.npy', allow_pickle=True)[()]
        state_dict = out['param_train_state_dict']

        if args['parametrization'] == 'linear':
            param_train = Parametrization_lin(a=as_tensor(state_dict['a']),
                                              b=as_tensor(state_dict['b']))
            ax = axes[0]
        elif args['parametrization'] == 'nn':
            param_train = Parametrization_nn(n_hiddens=args['n_hiddens'], n_in=1, n_out=1)
            param_train.load_state_dict({key: as_tensor(value) for key, value in state_dict.items()})
            ax = axes[1]
        else:
            raise ValueError(f"unknown parametrization type {args['parametrization']!r}")

        multi_step = len(args['offset']) > 1
        # 1000*T*train_frac: the factor 1000 presumably is 1/dt -- TODO confirm.
        # round() avoids truncating float products such as 56.999... to 56.
        n_train = int(round(1000*args['T']*args['train_frac']))
        label = args['parametrization'] + '_N' + str(n_train) + '_n' + str(len(args['offset']))

        with torch.no_grad():
            B = param_train(as_tensor(xx.reshape(1, 1, -1))).detach().cpu().numpy().flatten()
        ax.plot(xx, B,
                '--' if multi_step else '-',
                marker='d' if multi_step else 'x',
                color=colors_by_T.get(args['T']),  # None -> default colour cycle
                linewidth=2.5,
                label=label)

    for ax in axes:
        ax.set_xlabel(r'$x_k$', fontsize=fontsize)
        ax.set_ylabel(r'$\mathcal{B}(x_k)$', fontsize=fontsize)
        if ax.get_legend_handles_labels()[0]:  # skip legend on an empty panel
            ax.legend(fontsize=fontsize)
        ax.axis([-7.5, 12.5, -6, 5])
        ax.set_xticks([-5, 0, 5, 10])
        ax.set_yticks([-6, -4, -2, 0, 2, 4])
        ax.grid(True)
    plt.show()


# Experiments 01..12 as zero-padded two-character ids (matching the config-file prefixes).
plot_parametrization(exp_ids=np.array([f'{i:02d}' for i in range(1, 13)]))
