In [None]:
import os
# import re
import sys
import glob
import pickle
from scipy.fft import fft, fftfreq
from scipy.signal import butter, filtfilt

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
from tensorflow import keras

if '..' not in sys.path:
    sys.path.append('..')
from dlml.utils import collect_experiments
from dlml.data import read_area_values, load_data_areas, load_data_slide
from dlml.nn import predict

#### Select the experiment to analyze

In [None]:
# Pre-trained networks, keyed by (input domain, number of output values).
# Only time-domain networks can be used with this notebook (see the data-loading cell).
experiment_IDs = {
    ('frequency', 2): '9ea493c789b542bf979c51a6031f4044',
    ('frequency', 4): 'f6d9a03f1cfe450288e9cb86da94235f',
    ('time', 2): '034a1edb0797475b985f0e1335dab383',
    ('time', 4): 'b346a89d384c4db2ba4058a2c83c4f12',
}
experiment_ID = experiment_IDs[('time', 2)]

#### Load the model

In [None]:
experiments_path = '../experiments/neural_network/'
experiment_dir = os.path.join(experiments_path, experiment_ID)

# NOTE: pickle.load can execute arbitrary code; only load experiment folders you trust.
with open(os.path.join(experiment_dir, 'parameters.pkl'), 'rb') as fid:
    network_parameters = pickle.load(fid)

checkpoint_path = os.path.join(experiment_dir, 'checkpoints')
checkpoint_files = glob.glob(os.path.join(checkpoint_path, '*.h5'))
if not checkpoint_files:
    raise FileNotFoundError(f'No *.h5 checkpoint files found in {checkpoint_path}.')

# Pick the checkpoint deterministically. The epoch is parsed from names of the form
# `<prefix>.<epoch>-<rest>.h5`. glob.glob returns files in arbitrary order, so the
# highest epoch is selected explicitly (or, if the names do not follow that pattern,
# the last file in lexicographic order).
# Choosing the epoch with the lowest validation loss needs the per-epoch `val_loss`
# values, which this notebook never loads: `np.argmin(val_loss)` could only raise a
# NameError, which a bare `except` silently turned into an arbitrary choice.
# TODO: once `val_loss` is available, use
#       checkpoint_files[epochs.index(np.argmin(val_loss) + 1)]   # epochs are 1-based
try:
    epochs = [int(os.path.basename(file).split('.')[1].split('-')[0])
              for file in checkpoint_files]
except (IndexError, ValueError):
    epochs = None
if epochs is None:
    best_checkpoint = sorted(checkpoint_files)[-1]
else:
    best_checkpoint = checkpoint_files[int(np.argmax(epochs))]

model = keras.models.load_model(best_checkpoint)
# Training-set statistics, used below to normalize the inputs exactly as during training.
x_train_mean = network_parameters['x_train_mean']
x_train_std  = network_parameters['x_train_std']
x_train_min = network_parameters['x_train_min']
x_train_max = network_parameters['x_train_max']
var_names = network_parameters['var_names']
print(f'Loaded network from {best_checkpoint}.')
print(f'Variable names: {var_names}')

#### Print the model summary

In [None]:
# Print a text summary of the layers of the loaded Keras model.
model.summary()

#### Load the test set and filter it in several frequency bands

In [None]:
set_name = 'test'
use_fft = network_parameters['use_fft'] if 'use_fft' in network_parameters else False
if use_fft:
    raise Exception('This notebook must be used on a network that uses time-domain inputs')
data_dirs = []
for area_ID,data_dir in zip(network_parameters['area_IDs'], network_parameters['data_dirs']):
    data_dirs.append(data_dir.format(area_ID))
data_dir = os.path.join('..', data_dirs[0])
data_files = sorted(glob.glob(data_dir + os.path.sep + f'*_{set_name}_set.h5'))[:3]

ret = load_data_areas({set_name: data_files}, network_parameters['var_names'],
                        network_parameters['generators_areas_map'][:1],
                        network_parameters['generators_Pnom'],
                        network_parameters['area_measure'],
                        trial_dur=network_parameters['trial_duration'],
                        max_block_size=100,
                        use_tf=False, add_omega_ref=True,
                        use_fft=True)
F = ret[0]
Xf = [ret[1][set_name][i] for i in range(len(var_names))]

ret = load_data_areas({set_name: data_files}, network_parameters['var_names'],
                        network_parameters['generators_areas_map'][:1],
                        network_parameters['generators_Pnom'],
                        network_parameters['area_measure'],
                        trial_dur=network_parameters['trial_duration'],
                        max_block_size=100,
                        use_tf=False, add_omega_ref=True,
                        use_fft=use_fft)
y = ret[2][set_name]
X = [[(ret[1][set_name][i] - m) / s
      for i,(m,s) in enumerate(zip(x_train_mean, x_train_std))]]
t = ret[0]
dt = np.diff(t[:2])[0]
idx = [np.where(y == mom)[0] for mom in np.unique(y)]
n_mom_values = len(idx)

# Band-pass filter the standardized broadband inputs X[0] in several frequency bands
# (edges in Hz) and store the results in X[1:], so that X[k] (k >= 1) holds the inputs
# filtered in bands[k - 1]. The data already loaded are filtered directly instead of
# reloading the same files once per band. Assigning to X[1:] (rather than appending)
# also makes the cell safe to re-run.
bands = [[0.1, 0.4],
         [0.4, 0.9],
         [1, 2],
         [3, 10],
         [0.1, 2],
         [0.05, 10]]
N_bands = len(bands)
btypes = ['bp' for _ in range(N_bands)]       # 'bp' = band-pass in scipy.signal.butter
filter_orders = [8 for _ in range(N_bands)]
# Only time-domain inputs reach this point (use_fft is rejected above), so no
# frequency-domain branch is needed here.
X_filtered = []
for band, btype, order in zip(bands, btypes, filter_orders):
    # For band-pass designs butter(N) yields a filter of order 2N, hence order // 2.
    # filtfilt applies it forwards and backwards along the last axis (zero phase).
    b, a = butter(order // 2, band, btype, fs=1/dt)
    X_filtered.append([filtfilt(b, a, x) for x in X[0]])
X[1:] = X_filtered

#### Read the exact values of momentum

In [None]:
# Exact momentum of each data file. read_area_values returns a 4-tuple; its third item
# holds the values and only element [0] is kept (presumably the first area, matching
# the generators_areas_map[:1] used when loading the data; verify).
exact_momentum = np.array([
    area_values[0]
    for _, _, area_values, _ in (
        read_area_values(path,
                         network_parameters['generators_areas_map'],
                         network_parameters['generators_Pnom'],
                         'momentum')
        for path in data_files
    )
])

#### Plot the inputs in the time and frequency domains

In [None]:
# Left: normalized broadband input of the first variable; right: its frequency-domain
# counterpart Xf. One color per momentum value.
cmap = plt.get_cmap('tab10', n_mom_values)
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# Frequencies at or below this value (Hz) are ignored when choosing the y-limit of the
# spectrum panel (presumably so that the lowest-frequency bins do not set the scale).
f_min = 0.1
spectrum_max = -np.inf
for j, jdx in enumerate(idx):
    # Shaded band: mean +- 1.96 standard errors (normal-approximation 95% CI of the
    # mean) over the trials that share the j-th momentum value.
    n_trials = jdx.size
    mean = X[0][0][jdx].mean(axis=0)
    ci = 1.96 * X[0][0][jdx].std(axis=0) / np.sqrt(n_trials)
    ax[0].fill_between(t, mean + ci, mean - ci, color=cmap(j))
    mean = Xf[0][jdx].mean(axis=0)
    ci = 1.96 * Xf[0][jdx].std(axis=0) / np.sqrt(n_trials)
    ax[1].fill_between(F, mean + ci, mean - ci, color=cmap(j))
    # Keep the maximum over ALL groups: overwriting the limit at each iteration kept
    # only the last group's and could clip the spectra of the others.
    spectrum_max = max(spectrum_max, np.max((mean + ci)[F > f_min]))
ylim = [0, spectrum_max * 1.1]

for a in ax:
    for side in 'right', 'top':
        a.spines[side].set_visible(False)
    a.grid(which='major', axis='both', ls=':', lw=0.5, color=[.6, .6, .6])
ax[1].set_xscale('log')
# t is in seconds: the band-pass filters are designed with fs = 1/dt and edges in Hz.
ax[0].set_xlabel('Time [s]')
ax[0].set_ylabel(f'Normalized {var_names[0]}')
ax[1].set_xlabel('Frequency [Hz]')
ax[1].set_ylabel('FFT')
ax[1].set_ylim(ylim)
fig.tight_layout()

#### Plot one input trace filtered in all bands

In [None]:
trace = 100  # index of the trial (row of X[k][0]) to display
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
# N_bands + 1 colors: band i is drawn with cmap(i + 1), leaving index 0 for the broadband
# trace (drawn in black). This matches the band -> color mapping of the figure of
# predicted momentum, where band i - 1 is drawn with cmap(i).
cmap = plt.get_cmap('jet', N_bands + 1)
ax.plot(t, X[0][0][trace, :], 'k', lw=1, label='Broadband')
for i, band in enumerate(bands):
    ax.plot(t, X[i + 1][0][trace, :], color=cmap(i + 1), lw=1,
            label=f'[{band[0]:g}-{band[1]:g}] Hz')
ax.legend(loc='best')
for side in 'right', 'top':
    ax.spines[side].set_visible(False)
ax.set_xlabel('Time [s]')
# The traces are standardized with the training-set statistics.
ax.set_ylabel(f'Normalized {var_names[0]}')
ax.grid(which='major', axis='both', ls=':', lw=0.5, color=[.6, .6, .6])
fig.tight_layout()

#### Predict the momentum using the model

In [None]:
# Group the trials by target value: idx[j] holds the indices where y equals the j-th
# distinct value (np.unique returns the values sorted).
idx = [np.where(y == value)[0] for value in np.unique(y)]
n_mom_values = len(idx)

# momentum[k][j]: predictions for input set X[k] (0 = broadband, k >= 1 = band k - 1)
# on the trials in idx[j]. Only the first variable, x_set[0], is fed to the model.
# NOTE(review): confirm the network has a single input.
momentum = []
for x_set in X:
    momentum.append([np.squeeze(model.predict(x_set[0][group])) for group in idx])
mean_momentum = [[np.mean(pred) for pred in per_set] for per_set in momentum]
stddev_momentum = [[np.std(pred) for pred in per_set] for per_set in momentum]

In [None]:
# Table of results: one column per momentum value, one row per input set (exact value,
# broadband prediction, then each band), with predictions shown as mean +- std over the
# trials. Every column is printed (the number of momentum values is not hard-coded), and
# band edges use the same `:g` format as the figure legends, so 0.05 Hz is not shown as 0.1.
rows = [('Exact momentum', [f'{m:.3f}' for m in exact_momentum]),
        ('Broadband', [f'{mu:.3f} +- {sd:.3f}'
                       for mu, sd in zip(mean_momentum[0], stddev_momentum[0])])]
for m, s, band in zip(mean_momentum[1:], stddev_momentum[1:], bands):
    rows.append((f'[{band[0]:g}-{band[1]:g}] Hz',
                 [f'{mu:.3f} +- {sd:.3f}' for mu, sd in zip(m, s)]))

label_width = max(len(label) for label, _ in rows)
cell_width = max((len(cell) for _, cells in rows for cell in cells), default=0)
for label, cells in rows:
    print(f'{label:>{label_width}}: ' +
          '   '.join(cell.ljust(cell_width) for cell in cells).rstrip())

#### Plot the predicted values of momentum

In [None]:
# Predicted vs. exact momentum for the broadband input and for each band.
# NOTE(review): exact_momentum follows the sorted file order while the predictions follow
# the sorted unique targets; the two are assumed to correspond.
if len(exact_momentum) != len(mean_momentum[0]):
    raise ValueError(f'{len(exact_momentum)} exact momentum values but predictions for '
                     f'{len(mean_momentum[0])} momentum values.')
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ms = 8
# Colormap defined here so the cell does not depend on a `cmap` left over from another
# cell. Band i - 1 uses cmap(i); index 0 is unused because broadband is drawn in black.
cmap = plt.get_cmap('jet', N_bands + 1)
line_style = dict(lw=2, markersize=ms, markerfacecolor='w', markeredgewidth=2)
ax.plot(exact_momentum, exact_momentum, '--', color=[1, .5, .5], lw=10, label='Exact')
ax.plot(exact_momentum, mean_momentum[0], 'ko-', label='Broadband', **line_style)
for i, (mom, band) in enumerate(zip(mean_momentum[1:], bands), start=1):
    ax.plot(exact_momentum, mom, 'o-', color=cmap(i),
            label=f'[{band[0]:g}-{band[1]:g}] Hz', **line_style)
for side in 'right', 'top':
    ax.spines[side].set_visible(False)
ax.legend(loc='best')
ax.set_xlabel(r'Exact momentum [GW$\cdot$s$^2$]')
ax.set_ylabel(r'Estimated momentum [GW$\cdot$s$^2$]')
ax.grid(which='major', axis='both', lw=0.5, ls=':', color=[.6, .6, .6])
fig.tight_layout()