In [None]:
import os
import sys
import glob
import pickle
from scipy.fft import fft, fftfreq
from scipy.signal import butter, filtfilt
from sklearn.metrics import r2_score
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
from tensorflow import keras

if '..' not in sys.path:
    sys.path.append('..')
from dlml.utils import collect_experiments
from dlml.data import load_data_files, load_data_areas

In [None]:
def compute_score(X, y, dt, model, band=None, order=8, btype='bandstop'):
    """Evaluate `model` on (optionally Butterworth-filtered) inputs.

    Parameters
    ----------
    X : ndarray
        Model inputs; filtering is applied along the last axis.
    y : array-like
        Ground-truth targets passed to ``r2_score``.
    dt : float
        Sampling interval of the last axis of `X`; the filter uses ``fs = 1 / dt``.
    model : object
        Anything exposing ``predict(X)`` (e.g. a Keras model).
    band : float or pair of floats, optional
        Critical frequency (low/high-pass) or band edges (band-pass/band-stop),
        in the units of ``1 / dt``. If None, `X` is used unfiltered.
    order : int
        Total order of the Butterworth filter.
    btype : str
        Filter type forwarded to ``scipy.signal.butter`` (default 'bandstop').

    Returns
    -------
    score : float
        ``r2_score(y, y_pred)``.
    y_pred : ndarray
        Model predictions with singleton dimensions removed.
    Xfilt : ndarray
        The (possibly filtered) inputs that were fed to the model.
    """
    # Local import: the notebook's import cell only brings in `butter` and `filtfilt`.
    from scipy.signal import butter, sosfiltfilt

    if band is None:
        Xfilt = X
    else:
        # butter() returns a filter of order 2*N for two-edge (band-pass/band-stop)
        # designs, so only half of the requested total order is passed in that case.
        n = order // 2 if np.size(band) == 2 else order
        # Second-order sections stay numerically stable for critical frequencies
        # that are tiny compared to fs (e.g. a 0.02 edge), where the (b, a)
        # transfer-function form can lose precision or become unstable.
        sos = butter(n, band, btype, fs=1 / dt, output='sos')
        # Zero-phase (forward-backward) filtering along the last axis.
        Xfilt = sosfiltfilt(sos, X)
    y_pred = model.predict(Xfilt)
    return r2_score(y, y_pred), y_pred.squeeze(), Xfilt

#### Select the experiment (trained network) to analyze

In [None]:
# Trained networks that can be analyzed, keyed by a short description of how they
# were trained. Only time-domain networks are accepted by the data-loading cell
# below (frequency-domain ones make it raise).
experiment_ID = {
    'frequency data, 2 output values': '9ea493c789b542bf979c51a6031f4044',
    'frequency data, 4 output values': 'f6d9a03f1cfe450288e9cb86da94235f',
    'time data, 2 output values': '034a1edb0797475b985f0e1335dab383',
    'time data, 4 output values': 'b346a89d384c4db2ba4058a2c83c4f12',
    'time data, 2 output values, 8 input values': '98475b819ecb4d569646d7e1467d7c9c',
}['time data, 2 output values, 8 input values']

#### Load the model

In [None]:
experiments_path = '../experiments/neural_network/'
experiment_dir = os.path.join(experiments_path, experiment_ID)

# Training parameters (normalization statistics, variable names, data layout).
# NOTE: pickle.load can execute arbitrary code; only load files produced by our
# own training runs.
with open(os.path.join(experiment_dir, 'parameters.pkl'), 'rb') as fid:
    network_parameters = pickle.load(fid)

checkpoint_path = os.path.join(experiment_dir, 'checkpoints')
checkpoint_files = glob.glob(os.path.join(checkpoint_path, '*.h5'))
if not checkpoint_files:
    raise FileNotFoundError(f'No *.h5 checkpoint files found in {checkpoint_path}')


def _checkpoint_epoch(path):
    """Epoch number parsed from a checkpoint name of the form '<prefix>.<epoch>-<...>.h5'."""
    return int(os.path.split(path)[-1].split('.')[1].split('-')[0])


try:
    checkpoint_files = sorted(checkpoint_files, key=_checkpoint_epoch)
except (IndexError, ValueError):
    # Unexpected naming scheme: fall back to a deterministic lexicographic order
    # instead of the arbitrary order returned by glob.
    checkpoint_files = sorted(checkpoint_files)

# There is no validation-loss history in this notebook, so the checkpoint with
# the highest epoch is used.
# NOTE(review): this is the best model only if checkpoints were written on
# improvement (e.g. save_best_only=True) - confirm for the experiment analyzed.
best_checkpoint = checkpoint_files[-1]
model = keras.models.load_model(best_checkpoint)

# Input normalization statistics computed on the training set.
x_train_mean = network_parameters['x_train_mean']
x_train_std = network_parameters['x_train_std']
x_train_min = network_parameters['x_train_min']
x_train_max = network_parameters['x_train_max']
var_names = network_parameters['var_names']
print(f'Loaded network from {best_checkpoint}.')
print(f'Variable names: {var_names}')

#### Print a summary of the model topology

In [None]:
# Print Keras' text summary of the loaded model's layers.
model.summary()

#### Load the test set

In [None]:
use_fft = network_parameters['use_fft'] if 'use_fft' in network_parameters else False
if use_fft:
    # The stop-band analysis below filters the inputs in the time domain.
    raise ValueError('This notebook must be used on a network that uses time-domain inputs')

# Source of the test set:
#  - True: '<low|high>_momentum_test' folders next to the first data directory,
#    read with load_data_files; each folder is one momentum group.
#  - False: the '*test_set.h5' files of the first area, read with
#    load_data_areas; trials are split into two groups around the mean momentum.
LOAD_FROM_MOMENTUM_FOLDERS = False
set_name = 'test'
if LOAD_FROM_MOMENTUM_FOLDERS:
    # Strip the path components that still contain a '{}' placeholder.
    base_folder = network_parameters['data_dirs'][0]
    while '{}' in base_folder:
        base_folder, _ = os.path.split(base_folder)
    data_files = []
    group_index = []
    for prefix in ('low', 'high'):
        folder = os.path.join('..', base_folder, prefix + '_momentum_' + set_name)
        files = sorted(glob.glob(os.path.join(folder, '*.h5')))
        # Positions of this folder's files within the concatenated file list.
        group_index.append(np.arange(len(files)) + len(data_files))
        data_files += files
    ret = load_data_files(data_files,
                          network_parameters['var_names'],
                          network_parameters['generators_areas_map'][:1],
                          network_parameters['generators_Pnom'],
                          'momentum')
    # Drop the last time sample, both from the time vector and from the inputs.
    t = ret[0][:-1]
    X_raw = ret[1][:, :, :-1]
    # Normalize along the first axis with the training statistics
    # (presumably one row per input variable - confirm against load_data_files).
    X = np.zeros(X_raw.shape)
    for i, (m, s) in enumerate(zip(x_train_mean, x_train_std)):
        X[i, :, :] = (X_raw[i, :, :] - m) / s
    y = ret[2]
    X_raw = X_raw.squeeze()
else:
    base_folder = os.path.join('..', network_parameters['data_dirs'][0].format(network_parameters['area_IDs'][0]))
    data_files = sorted(glob.glob(os.path.join(base_folder, f'*{set_name}_set.h5')))
    if not data_files:
        raise FileNotFoundError(f'No *{set_name}_set.h5 files found in {base_folder}')
    ret = load_data_areas({set_name: data_files}, network_parameters['var_names'],
                          network_parameters['generators_areas_map'][:1],
                          network_parameters['generators_Pnom'],
                          network_parameters['area_measure'],
                          trial_dur=network_parameters['trial_duration'],
                          max_block_size=1000,
                          use_tf=False, add_omega_ref=True,
                          use_fft=False)
    t = ret[0]
    # Only the first input variable is used: normalize it with its own training statistics.
    X = (ret[1][set_name][0] - x_train_mean[0]) / x_train_std[0]
    y = ret[2][set_name]
    # Two momentum groups split at the mean; trials exactly at the mean are excluded.
    group_index = [np.where(y < y.mean())[0], np.where(y > y.mean())[0]]

# Every per-group statistic below needs at least one trial per group.
if any(np.size(idx) == 0 for idx in group_index):
    raise ValueError('Every momentum group must contain at least one trial')

n_mom_groups = len(group_index)
X = X.squeeze()
y = y.squeeze()
dt = t[1] - t[0]
N_samples = t.size
assert X.shape[-1] == N_samples, 'inputs and time vector have different lengths'
# Single-sided amplitude spectrum of every trial along the time axis.
Xf = 2.0 / N_samples * np.abs(fft(X)[:, :N_samples // 2])
F = fftfreq(N_samples, dt)[:N_samples // 2]

#### Plot the inputs in the time and frequency domains

In [None]:
cmap = plt.get_cmap('Accent')
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# Per-group mean amplitude spectrum and half-width of its 95% confidence interval
# (1.96 standard errors); both arrays are reused by the summary figure.
Xfm = np.zeros((len(group_index), F.size))
Xfci = np.zeros((len(group_index), F.size))
for i, idx in enumerate(group_index):
    # Left panel: mean normalized input of the group +/- 95% CI.
    mean = X[idx].mean(axis=0)
    ci = 1.96 * X[idx].std(axis=0) / np.sqrt(idx.size)
    ax[0].fill_between(t, mean + ci, mean - ci, color=cmap(i))
    # Right panel: the same statistics on the amplitude spectra, in dB.
    Xfm[i] = Xf[idx].mean(axis=0)
    Xfci[i] = 1.96 * Xf[idx].std(axis=0) / np.sqrt(idx.size)
    # The lower CI edge can be <= 0, giving NaN/-inf in dB; matplotlib masks those
    # points, so only the RuntimeWarnings are silenced here.
    with np.errstate(divide='ignore', invalid='ignore'):
        ax[1].fill_between(F, 20 * np.log10(Xfm[i] + Xfci[i]), 20 * np.log10(Xfm[i] - Xfci[i]),
                           color=cmap(i))
for a in ax:
    for side in ('right', 'top'):
        a.spines[side].set_visible(False)
    a.grid(which='major', axis='both', ls=':', lw=0.5, color=[.6, .6, .6])
ax[1].set_xscale('log')
# NOTE(review): the filters and the FFT use fs = 1/dt with frequencies labelled in
# Hz, which suggests `t` is in seconds - confirm the unit of this label.
ax[0].set_xlabel('Time [min]')
ax[0].set_ylabel(f'Normalized {var_names[0]}')
ax[1].set_xlabel('Frequency [Hz]')
ax[1].set_ylabel('Power [dB]')
ax[1].set_xlim([1e-2, 2.5])
fig.tight_layout()

In [None]:
# Stop bands removed from the inputs one at a time (units of 1/dt, i.e. Hz).
# The last upper edge stays just below 20 because butter() requires critical
# frequencies strictly below fs/2 (presumably fs = 40 Hz here - confirm).
filter_bands = [[0.02, 0.5], [0.5, 1], [1, 1.5], [1.5, 3], [3, 8], [8, 19.9]]
N_bands = len(filter_bands)
N_trials, N_samples = X.shape
# Row 0: unfiltered (broadband) inputs; row i+1: inputs with filter_bands[i] removed.
scores = np.zeros(N_bands + 1)
X_filt = np.zeros((N_bands + 1, N_trials, N_samples))
y_pred = np.zeros((N_bands + 1, N_trials))
scores[0], y_pred[0], X_filt[0] = compute_score(X, y, dt, model)
for i, band in enumerate(tqdm(filter_bands)):
    scores[i + 1], y_pred[i + 1], X_filt[i + 1] = compute_score(X, y, dt, model, band)
# Band edges used for labels and in the saved results: a copy of filter_bands
# whose last band is reported as ending at 20 Hz. Building a copy leaves
# filter_bands describing exactly what was filtered.
bands = [list(band) for band in filter_bands]
bands[-1][1] = 20

In [None]:
# group_index holds one array of trial indices per momentum group, and the groups
# may have different sizes. np.savez_compressed converts every value with
# np.asanyarray, which raises ValueError for a ragged list on NumPy >= 1.24, so a
# ragged group_index is stored as a 1-D object array. Reading that back requires
# np.load(..., allow_pickle=True).
if len({np.size(idx) for idx in group_index}) == 1:
    group_index_to_save = np.asarray(group_index)
else:
    group_index_to_save = np.empty(len(group_index), dtype=object)
    for k, idx in enumerate(group_index):
        group_index_to_save[k] = idx
np.savez_compressed(os.path.join(experiments_path, experiment_ID, 'stopband_momentum_estimation.npz'),
                    F=F, Xf=Xf, y=y, y_pred=y_pred, scores=scores, bands=bands,
                    group_index=group_index_to_save)

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(7, 7))
cmap2 = plt.get_cmap('tab10')

# Per-group statistics of the exact momentum and of every set of predictions
# (row 0 of y_pred: broadband inputs, row i+1: inputs with bands[i] removed).
y_m = [y[idx].mean() for idx in group_index]
y_s = [y[idx].std() for idx in group_index]
y_pred_m = np.array([[pred[idx].mean() for idx in group_index] for pred in y_pred])
y_pred_s = np.array([[pred[idx].std() for idx in group_index] for pred in y_pred])

# Number of curves drawn: broadband plus the first last_band - 1 stop bands.
# add_band only changes the output file name at the end of the cell; presumably it
# is meant to be combined with a smaller last_band to build the figure one band
# at a time.
add_band = False
last_band = N_bands + 1

# Top panel: estimated vs. exact momentum, mean +/- one standard deviation per group.
ax[0].plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2)
for i in range(last_band):
    color = cmap2(i)
    lbl = 'Broadband' if i == 0 else f'[{bands[i-1][0]}-{bands[i-1][1]:g}] Hz'
    ax[0].plot(y_m, y_pred_m[i], color=color, lw=2, label=lbl)
    for j in range(len(group_index)):
        # Vertical bar: spread of the predictions within group j.
        ax[0].plot(y_m[j] + np.zeros(2),
                   y_pred_m[i, j] + y_pred_s[i, j] * np.array([-1, 1]),
                   color=color, lw=2)
        # Horizontal bar: spread of the exact momentum within group j.
        ax[0].plot(y_m[j] + y_s[j] * np.array([-1, 1]),
                   y_pred_m[i, j] + np.zeros(2),
                   color=color, lw=2)
        ax[0].plot(y_m[j], y_pred_m[i, j], 'o', color=color,
                   markerfacecolor='w', markersize=6.5, markeredgewidth=2)
for side in ('right', 'top'):
    ax[0].spines[side].set_visible(False)
ax[0].legend(loc='upper left', bbox_to_anchor=[0.8, 0.575], frameon=False)
ax[0].set_xlabel(r'Exact momentum [GW$\cdot$s$^2$]')
ax[0].set_ylabel(r'Estimated momentum [GW$\cdot$s$^2$]')
ax[0].grid(which='major', axis='both', lw=0.5, ls=':', color=[.6, .6, .6])
ax[0].set_xlim([0.149, 0.31])
ax[0].set_ylim([0.13, 0.34])
ax[0].set_xticks(np.r_[0.15 : 0.31 : 0.03])
ax[0].set_yticks(np.r_[0.15 : 0.34 : 0.03])

# Bottom panel: group spectra (left axis) and R^2 score of each input (right axis).
axr = ax[1].twinx()
for i, (m, ci) in enumerate(zip(Xfm, Xfci)):
    # A lower CI edge <= 0 gives NaN/-inf in dB; those points are masked by
    # matplotlib, so only the RuntimeWarnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        ax[1].fill_between(F, 20*np.log10(m + ci), 20*np.log10(m - ci), color=cmap(i),
                           label=r'M = {:.3f} GW$\cdot$s$^2$'.format(y_m[i]))
ax[1].legend(loc='lower left', frameon=False, fontsize=8)
# Broadband score spanning the positive frequencies. ax[1].get_xlim() would be read
# while the axis is still linear; because F starts at 0, its left limit is
# negative and would be clipped once the axis becomes logarithmic.
F_pos = F[F > 0]
axr.plot([F_pos[0], F_pos[-1]], scores[0] + np.zeros(2), '--', color=cmap2(0), lw=3)
for i, band in enumerate(bands[:max(last_band - 1, 0)]):
    axr.axvline(band[0], color=[.6, .6, .6], ls=':', lw=0.5)
    axr.plot(band, scores[i + 1] + np.zeros(2), color=cmap2(i + 1), lw=2)
ax[1].set_xlabel('Frequency [Hz]')
ax[1].set_ylabel('Power [dB]')
axr.set_ylim((-1.1, 1.1))
axr.set_ylabel(r'R$^2$ score')
ax[1].set_xscale('log')
fig.tight_layout()
if add_band:
    fig.savefig(f'stopband_{experiment_ID}_{last_band}.pdf')
else:
    fig.savefig(f'stopband_{experiment_ID[:6]}.pdf')