In [None]:
import os
import re
import sys
import glob
import pickle
import tables
from scipy.fft import fft, fftfreq
from scipy.signal import butter, filtfilt
# from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# %matplotlib notebook

import tensorflow as tf
from tensorflow import keras

# import dtw
# import umap
# from sklearn.preprocessing import StandardScaler
# from sklearn.decomposition import PCA

if '..' not in sys.path:
    sys.path.append('..')
from dlml.utils import collect_experiments
from dlml.data import read_area_values, load_data_areas, load_data_slide
from dlml.nn import predict

#### Select the experiment to analyze (the IDs of the candidate experiments are listed in the next cell)

In [None]:
# IDs of the trained experiments, keyed by (input domain, number of output values).
# Choose the experiment by setting INPUT_DOMAIN and N_OUTPUTS instead of
# commenting/uncommenting hard-coded IDs.
EXPERIMENT_IDS = {
    ('frequency', 2): '9ea493c789b542bf979c51a6031f4044',
    ('frequency', 4): 'f6d9a03f1cfe450288e9cb86da94235f',
    ('time', 2): '034a1edb0797475b985f0e1335dab383',
    ('time', 4): 'b346a89d384c4db2ba4058a2c83c4f12',
}
INPUT_DOMAIN = 'frequency'  # 'frequency' or 'time'
N_OUTPUTS = 4               # 2 or 4
experiment_ID = EXPERIMENT_IDS[(INPUT_DOMAIN, N_OUTPUTS)]

#### Load the model

In [None]:
experiments_path = '../experiments/neural_network/'
experiment_dir = os.path.join(experiments_path, experiment_ID)

# Parameters saved together with the trained network (normalization constants,
# variable names, data locations, ...).
# NOTE: pickle.load can execute arbitrary code: only open parameter files you trust.
with open(os.path.join(experiment_dir, 'parameters.pkl'), 'rb') as fid:
    network_parameters = pickle.load(fid)

checkpoint_path = os.path.join(experiment_dir, 'checkpoints', '')
checkpoint_files = glob.glob(checkpoint_path + '*.h5')
if not checkpoint_files:
    raise FileNotFoundError(f'No checkpoint files found in {checkpoint_path}.')


def _checkpoint_epoch(path):
    """Return the epoch number encoded in a checkpoint file name.

    The name is parsed as `<prefix>.<epoch>-<rest>.h5`; raises IndexError or
    ValueError when the name does not follow this pattern.
    """
    return int(os.path.basename(path).split('.')[1].split('-')[0])


# The previous selection used an undefined `val_loss` array, so it always fell back to
# `checkpoint_files[-1]`, whose identity depended on the (unspecified) order of glob().
# Here the checkpoint with the largest epoch is picked deterministically: if the training
# script saved checkpoints only when the validation loss improved (assumption - confirm),
# this is also the checkpoint with the lowest validation loss.
try:
    best_checkpoint = max(checkpoint_files, key=_checkpoint_epoch)
except (IndexError, ValueError):
    # file names do not follow the expected pattern: use the last one in name order
    best_checkpoint = sorted(checkpoint_files)[-1]
model = keras.models.load_model(best_checkpoint)

x_train_mean = network_parameters['x_train_mean']
x_train_std  = network_parameters['x_train_std']
# min/max are not saved by every experiment: they are set to None when missing
x_train_min = network_parameters.get('x_train_min')
x_train_max = network_parameters.get('x_train_max')
if x_train_min is None or x_train_max is None:
    print('Warning: x_train_min/x_train_max not found in the network parameters.')
var_names = network_parameters['var_names']
print(f'Loaded network from {best_checkpoint}.')
print(f'Variable names: {var_names}')

#### Plot the model topology

In [None]:
# Draw the graph of the loaded model at 96 dpi, without layer shapes.
keras.utils.plot_model(model, dpi=96, show_shapes=False)

In [None]:
# Show the Keras summary of the loaded model.
model.summary()

#### Load the data set (unfiltered and filtered in different frequency bands)

In [None]:
set_name = 'test'
# maximum number of data files of the chosen set that are loaded (None loads all of them)
MAX_N_FILES = 3
use_fft = network_parameters.get('use_fft', False)
data_dirs = [data_dir.format(area_ID) for area_ID, data_dir in
             zip(network_parameters['area_IDs'], network_parameters['data_dirs'])]
data_dir = os.path.join('..', data_dirs[0])
data_files = sorted(glob.glob(os.path.join(data_dir, f'*_{set_name}_set.h5')))[:MAX_N_FILES]
if not data_files:
    raise FileNotFoundError(f'No {set_name} set files found in {data_dir}.')


def _load_normalized(**filter_kwargs):
    """Load `data_files` with `load_data_areas` and min-max normalize every input variable.

    Keyword arguments (e.g. `btype`, `Wn`, `filter_order`) are forwarded unchanged to
    `load_data_areas`, to load filtered versions of the same data. Normalization uses
    the training-set extrema `x_train_min`/`x_train_max` loaded with the network.

    Returns:
        (ret, normalized): the value returned by `load_data_areas` and a list with
        one normalized array per input variable.

    Raises:
        ValueError: if the training-set extrema are not available.
    """
    if x_train_min is None or x_train_max is None:
        raise ValueError('x_train_min/x_train_max are required to normalize the inputs '
                         'but were not saved with this network.')
    ret = load_data_areas({set_name: data_files}, network_parameters['var_names'],
                          network_parameters['generators_areas_map'][:1],
                          network_parameters['generators_Pnom'],
                          network_parameters['area_measure'],
                          trial_dur=network_parameters['trial_duration'],
                          max_block_size=100,
                          use_tf=False, add_omega_ref=True,
                          use_fft=use_fft, **filter_kwargs)
    normalized = [(ret[1][set_name][i] - m) / (M - m)
                  for i, (m, M) in enumerate(zip(x_train_min, x_train_max))]
    return ret, normalized


# unfiltered data: also provides the targets and the abscissa of the inputs
ret, x_full_band = _load_normalized()
y = ret[2][set_name]
F = ret[0]
X = [x_full_band]

# load the same file(s) filtered in different bands
bands = [[0.1, 0.4], [0.4, 0.9], [0.9, 2], [2, 10]]
btypes = ['bp', 'bp', 'bp', 'bp']
filter_orders = [6, 10, 10, 10]
for band, btype, order in zip(bands, btypes, filter_orders):
    _, x_band = _load_normalized(btype=btype, Wn=band, filter_order=order)
    X.append(x_band)

#### Read the exact values of momentum

In [None]:
def _exact_first_area_momentum(path):
    """Return the first 'momentum' value reported by `read_area_values` for `path`.

    All areas in `generators_areas_map` are passed; only the first returned value is
    kept (presumably the area the network was trained on - confirm).
    """
    _, _, values, _ = read_area_values(path,
                                       network_parameters['generators_areas_map'],
                                       network_parameters['generators_Pnom'],
                                       'momentum')
    return values[0]


# one exact value per data file, in the order of `data_files`
exact_momentum = np.array([_exact_first_area_momentum(path) for path in data_files])

#### Predict the momentum using the model

In [None]:
# Group the trials by target value: idx[k] holds the indices of the trials whose
# target equals the k-th distinct value of y (in ascending order).
idx = [np.where(y == value)[0] for value in np.unique(y)]
n_mom_values = len(idx)

# For every version of the inputs (unfiltered first, then each filtered band), predict
# each group of trials separately and summarize the predictions by mean and std.
# NOTE(review): only the first input variable (x[0]) is fed to the model - confirm the
# network has a single input.
momentum, mean_momentum, stddev_momentum = [], [], []
for inputs in X:
    group_predictions = [np.squeeze(model.predict(inputs[0][trials])) for trials in idx]
    momentum.append(group_predictions)
    mean_momentum.append([pred.mean() for pred in group_predictions])
    stddev_momentum.append([pred.std() for pred in group_predictions])
print('Mean momentum:', mean_momentum)

#### Plot the inputs for each group of trials (spectra when the network was trained on FFT data)

In [None]:
# One colour per version of the inputs (unfiltered + each filtered band).
cmap = plt.get_cmap('jet', len(X))
# distinct target values, in the same (ascending) order as the groups in `idx`
target_values = np.unique(y)
# squeeze=False always returns a 2D array of axes, so one group needs no special case
fig, axs = plt.subplots(1, n_mom_values, figsize=(5*n_mom_values, 4),
                        sharex=True, sharey=True, squeeze=False)
ax = axs[0]
for i in range(n_mom_values):
    for j in range(len(X)):
        # only the first input variable is shown
        trials = X[j][0][idx[i]]
        mean = trials.mean(axis=0)
        stddev = trials.std(axis=0)
        # 95% confidence interval of the mean across the trials of this group
        ci = 1.96 * stddev / np.sqrt(idx[i].size)
        if j == 0:
            col, label = 'k', 'Full spectrum'
        else:
            col, label = cmap(j), 'Band {}-{}'.format(*bands[j - 1])
        # semi-transparent so that overlapping intervals remain visible
        ax[i].fill_between(F, mean + ci, mean - ci, color=col, alpha=0.5, label=label)
    for side in 'right', 'top':
        ax[i].spines[side].set_visible(False)
    ax[i].grid(which='major', axis='both', ls=':', lw=0.5, color=[.6, .6, .6])
    # time-domain inputs: presumably sampled in seconds - confirm
    ax[i].set_xlabel('Frequency [Hz]' if use_fft else 'Time [s]')
    ax[i].set_title(f'Target value: {target_values[i]:g}')
ax[0].set_ylabel('Normalized FFT' if use_fft else 'Normalized signal')
ax[0].legend(loc='best', fontsize=8)
fig.tight_layout()

In [None]:
# One colour per version of the inputs (unfiltered + each filtered band); redefined
# here so this cell does not depend on the spectra plot having been run.
cmap = plt.get_cmap('jet', len(X))
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ms = 8
ax.plot(exact_momentum, exact_momentum, '--', color=[1, .5, .5], lw=10, label='Exact')
# NOTE(review): exact_momentum follows the order of `data_files`, while each entry of
# mean_momentum follows np.unique(y) (ascending target value): the two orders are
# assumed to coincide - confirm.
for i, mom in enumerate(mean_momentum):
    if i == 0:
        ax.plot(exact_momentum, mom, 'ko-', lw=2, markersize=ms,
                markerfacecolor='w', markeredgewidth=2, label='Full spectrum')
    else:
        # label every filtered band so that the legend identifies all curves
        ax.plot(exact_momentum, mom, 'o-', color=cmap(i), lw=2, markersize=ms,
                markerfacecolor='w', markeredgewidth=2,
                label='Band {}-{}'.format(*bands[i - 1]))
for side in 'right', 'top':
    ax.spines[side].set_visible(False)
ax.legend(loc='upper left')
ax.set_xlabel(r'Exact momentum [GW$\cdot$s$^2$]')
ax.set_ylabel(r'Estimated momentum [GW$\cdot$s$^2$]')
ax.set_title('Estimated vs. exact momentum')
ax.grid(which='major', axis='both', lw=0.5, ls=':', color=[.6, .6, .6])
fig.tight_layout()