In [None]:
import os
import re
import sys
import glob
import pickle
import shelve
import numpy as np
from scipy.fft import fft, fftfreq
from scipy.signal import find_peaks
from scipy.optimize import curve_fit
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
from matplotlib.colors import LogNorm, FuncNorm
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FixedLocator, NullLocator, FixedFormatter
from matplotlib.patches import Polygon
import seaborn as sns

# Typography and stroke widths shared by all figures in this notebook.
fontsize = 8
lw = 0.75

# Same rc settings as before, written with plain keyword arguments.
matplotlib.rc('font', family='Times New Roman', size=fontsize)
matplotlib.rc('axes', linewidth=0.75, labelsize=fontsize)
matplotlib.rc('xtick', labelsize=fontsize)
matplotlib.rc('ytick', labelsize=fontsize)
# Major ticks on both axes share width/length; only y minor ticks are customised.
matplotlib.rc('xtick.major', width=lw, size=3)
matplotlib.rc('ytick.major', width=lw, size=3)
matplotlib.rc('ytick.minor', width=lw, size=1.5)

%matplotlib inline

if '..' not in sys.path:
    sys.path.append('..')
from dlml.data import load_data_files, load_data_areas

In [None]:
def make_axes(rows, cols, x_offset, y_offset, x_space, y_space, squeeze=True):
    """Create a rows x cols grid of equally sized axes in the current figure.

    Parameters
    ----------
    rows, cols : int
        Grid dimensions.
    x_offset, y_offset : sequence of float
        Margins in figure coordinates. Their sum is subtracted from the available
        extent and the first element is used as the left / bottom origin.
    x_space, y_space : float
        Horizontal / vertical gap between neighbouring axes (figure coordinates).
    squeeze : bool
        If True, drop singleton dimensions from the returned container.

    Returns
    -------
    A single Axes (1x1), a flat list (one row or one column) or a list of lists,
    ordered from the top row to the bottom row. Right and top spines are hidden.
    """
    panel_w = (1 - np.sum(x_offset) - x_space * (cols - 1)) / cols
    panel_h = (1 - np.sum(y_offset) - y_space * (rows - 1)) / rows

    grid = []
    # Walk the rows from the top of the figure (largest bottom coordinate) downwards
    for i in reversed(range(rows)):
        bottom = y_offset[0] + (panel_h + y_space) * i
        current_row = []
        for j in range(cols):
            left = x_offset[0] + (panel_w + x_space) * j
            current_row.append(plt.axes([left, bottom, panel_w, panel_h]))
        grid.append(current_row)

    for current_row in grid:
        for axes in current_row:
            axes.spines['right'].set_visible(False)
            axes.spines['top'].set_visible(False)

    if not squeeze:
        return grid
    if rows == 1:
        return grid[0][0] if cols == 1 else grid[0]
    if cols == 1:
        return [current_row[0] for current_row in grid]
    return grid

In [None]:
def plot_correlations(R, p, R_ctrl, p_ctrl, edges, idx, ax, sort_freq=1.0,
                      vmin=None, vmax=None, legend_bbox=[0.4, -0.05]):
    """Draw correlation maps (frequency band vs. column) for one or more groups.

    For each entry of ``idx`` the selected slices of ``R`` are averaged along
    their first axis (NaNs ignored), the columns are sorted by their value in
    the band closest to ``sort_freq``, and the map is drawn with ``pcolormesh``
    in column 0 of the corresponding row of ``ax``. When control data are given
    they are drawn, with the same column order, in column 1.

    Parameters
    ----------
    R, p : ndarray
        Correlation coefficients and p-values; entries with ``p > 0.05`` are
        masked (set to NaN) on a copy. ``p`` may be None to disable masking.
        Assumes ``R`` is indexed as [trial, band, column] -- inferred from how
        it is reduced and plotted here; confirm against the producer.
    R_ctrl, p_ctrl : ndarray or None
        Same as ``R``/``p`` for the control case.
    edges : ndarray
        Band edges; band centres are used as the y coordinate.
    idx : sequence
        One index (array) per row of ``ax`` selecting entries along R's first axis.
    ax : 2-D ndarray of Axes
        Row ``i`` receives the maps of ``idx[i]``. If it has more than one
        column a colorbar is attached to column 1 and, when control data are
        present, the mean absolute correlation profiles are drawn in the last column.
    sort_freq : float or sequence of float
        Frequency whose nearest band determines the column ordering (per row).
    vmin, vmax : float or None
        Colour limits. If both are None, symmetric limits are derived from the data.
    legend_bbox : sequence of 2 floats
        ``bbox_to_anchor`` of the legend drawn in the last column.

    Returns
    -------
    (vmin, vmax) actually used for the colour scale.

    Raises
    ------
    ValueError
        If the number of rows of ``ax`` differs from ``len(idx)``.
    """
    if p is not None:
        R = R.copy()
        R[p > 0.05] = np.nan
    if R_ctrl is not None and p_ctrl is not None:
        R_ctrl = R_ctrl.copy()
        R_ctrl[p_ctrl > 0.05] = np.nan

    rows, cols = ax.shape
    if rows != len(idx):
        raise ValueError('Number of rows of ax does not match len(idx)')

    R_mean = [np.nanmean(R[jdx], axis=0) for jdx in idx]
    # NOTE: np.mean propagates NaNs, i.e. bands containing a fully masked cell
    # show up as gaps in the profile plotted in the last column.
    R_abs_mean = [np.mean(np.abs(r), axis=1) for r in R_mean]
    if R_ctrl is not None:
        R_ctrl_mean = [np.nanmean(R_ctrl[jdx], axis=0) for jdx in idx]
        R_ctrl_abs_mean = [np.mean(np.abs(r), axis=1) for r in R_ctrl_mean]
    else:
        R_ctrl_mean = [None for _ in range(rows)]
        R_ctrl_abs_mean = None

    if np.isscalar(sort_freq):
        sort_freq = sort_freq + np.zeros(rows)
    for i in range(rows):
        # The nearest *edge* may be the last one, which has no matching band
        # row: clamp the index so that it always addresses a valid row.
        band = int(np.abs(edges - sort_freq[i]).argmin())
        band = min(band, R_mean[i].shape[0] - 1)
        order = np.argsort(R_mean[i][band, :])
        R_mean[i] = R_mean[i][:, order]
        if R_ctrl is not None:
            R_ctrl_mean[i] = R_ctrl_mean[i][:, order]

    # Masked cells that survive the averaging are NaN: use NaN-aware reductions,
    # otherwise the colour limits (and the colorbar ticks) would become NaN.
    make_symmetric = False
    if vmin is None:
        vmin = np.nanmin([np.nanmin(r) for r in R_mean])
        make_symmetric = True
    if vmax is None:
        vmax = np.nanmax([np.nanmax(r) for r in R_mean])
        if make_symmetric:
            if vmax > np.abs(vmin):
                vmin = -vmax
            else:
                vmax = -vmin
    print(f'Color bar bounds: ({vmin:.2f},{vmax:.2f}).')
    ticks = np.linspace(vmin, vmax, 7)
    ticklabels = [f'{tick:.2f}' for tick in ticks]

    cmap = plt.get_cmap('bwr')
    y = edges[:-1] + np.diff(edges) / 2
    for i in range(rows):
        # a dedicated loop variable avoids rebinding the parameter R
        for j, R_map in enumerate((R_mean[i], R_ctrl_mean[i])):
            if R_map is not None:
                x = np.arange(R_map.shape[-1])
                im = ax[i][j].pcolormesh(x, y, R_map, vmin=vmin, vmax=vmax,
                                         shading='auto', cmap=cmap)
                for side in 'right','top':
                    ax[i][j].spines[side].set_visible(False)
                ax[i][j].set_xticks(np.linspace(0, x[-1], 3, dtype=np.int32))
        if cols > 1:
            cbar = plt.colorbar(im, fraction=0.1, shrink=1, aspect=20, label='Correlation',
                                orientation='vertical', ax=ax[i][1], ticks=ticks)
            cbar.ax.set_yticklabels(ticklabels, fontsize=fontsize-1)
            if R_ctrl_abs_mean is not None:
                ax[i][-1].plot(R_abs_mean[i], y, 'r', lw=1, label='Tr.')
                ax[i][-1].plot(R_ctrl_abs_mean[i], y, 'g--', lw=1, label='Untr.')
                ax[i][-1].plot(R_abs_mean[i] - R_ctrl_abs_mean[i], y, 'k', lw=1, label='Diff.')
                ax[i][-1].legend(loc='lower left', bbox_to_anchor=legend_bbox,
                                 frameon=False, fontsize=fontsize-1)

    for i in range(rows):
        for j in range(cols):
            ax[i][j].set_ylim(edges[[0,-2]])
            ax[i][j].set_yscale('log')

    return vmin, vmax

In [None]:
def plot_layer_output_hist(Y, group_index, N_bins, cols=8, w=2, h=1.5, cmap=None, ax=None, labels=None):
    """Plot, for every filter, the histograms of its outputs for each group of trials.

    Parameters
    ----------
    Y : ndarray, shape (N_trials, N_samples, N_filters)
        Layer outputs.
    group_index : sequence
        One index (array) per group, selecting trials along the first axis of ``Y``.
    N_bins : int
        Number of histogram bins (computed independently per filter and group).
    cols : int
        Number of columns of the grid created when ``ax`` is None.
    w, h : float
        Width and height of each panel of the created figure.
    cmap : callable or None
        Maps a group number to an RGBA colour; defaults to ``tab10``.
    ax : array of Axes or None
        Axes to draw into. If given, only the first ``ax.size`` filters are plotted.
    labels : sequence of str or None
        Optional per-filter label written in the top-left corner of each panel.

    Returns
    -------
    (fig, ax) : ``fig`` is None when ``ax`` was supplied by the caller;
    ``ax`` is the flattened array of axes.
    """
    N_trials, N_samples, N_filters = Y.shape
    N_groups = len(group_index)
    N = np.zeros((N_filters, N_groups, N_bins))
    edges = np.zeros((N_filters, N_groups, N_bins+1))
    for i in range(N_filters):
        for j,jdx in enumerate(group_index):
            N[i,j,:],edges[i,j,:] = np.histogram(Y[jdx, :, i], N_bins)

    if cmap is None:
        cmap = plt.get_cmap('tab10', N_groups)
    if ax is None:
        # Ceiling division: floor division left the last filters without a
        # panel whenever N_filters was not a multiple of cols.
        rows = max(1, -(-N_filters // cols))
        # squeeze=False guarantees an array even for a 1x1 grid
        fig,ax = plt.subplots(rows, cols, figsize=(w*cols, h*rows), squeeze=False)
        N_plots = N_filters
    else:
        fig = None
        ax = np.asarray(ax)
        # never index past the available histograms or the available axes
        N_plots = min(N_filters, ax.size)
    ax = ax.flatten()

    for i in range(N_plots):
        for j in range(N_groups):
            de = np.diff(edges[i, j, :])[0]
            # edge colour: the face colour darkened by 1/3, clipped at black
            col = np.max([[0,0,0], cmap(j)[:3] - 1/3 * np.ones(3)], axis=0)
            ax[i].bar(edges[i, j, :-1], N[i, j, :], width=de*0.8, align='edge',
                     facecolor=cmap(j), edgecolor=col, linewidth=0.5, alpha=0.85)
        if labels is not None:
            # Horizontal anchor of the label: range of the inner bin edges over
            # ALL groups (the upper bound used to look at the last group only).
            inner = edges[i, :, 2:-3]
            if inner.size == 0:
                # too few bins for the trimmed range: fall back to all edges
                inner = edges[i]
            x_lo, x_hi = inner.min(), inner.max()
            y_top = ax[i].get_ylim()[1]
            ax[i].text(x_lo - 0.1 * (x_hi - x_lo), y_top,
                       labels[i], fontsize=fontsize-1, verticalalignment='top',
                       horizontalalignment='left')
        ax[i].set_xticklabels([])
        ax[i].set_yticks(ax[i].get_ylim())
        ax[i].set_yticklabels([])
        for side in 'right','top':
            ax[i].spines[side].set_visible(False)
    if fig is not None:
        # hide the panels of the last row that have no filter to show
        for unused in ax[N_plots:]:
            unused.set_visible(False)
        fig.tight_layout()
    return fig,ax

In [None]:
# Empty containers filled by the data-loading cell; each one is a separate dict.
X = {}              # normalized time-domain traces
y = {}              # target values
Xf = {}             # single-sided amplitude spectra of X
group_index = {}    # trial indices grouped by unique target value
n_mom_groups = {}   # number of such groups

In [None]:
# Build (or reload from an .npz cache) the normalized traces, their amplitude spectra
# and the grouping of the trials by target value, for two data sets:
#   - set_name ('training'): files from the '.../diagonal' directory
#   - 'var_G2_G3': files from the '.../subset_8_H_comp11_0.1' directory
data_file = 'traces_hist_spectra_comp_grid.npz'
# set force to True to ignore an existing cache file and recompute everything
force = False
set_name = 'training'
if not os.path.isfile(data_file) or force:
    data_dir = '../data/IEEE39/converted_from_PowerFactory/all_stoch_loads/' + \
        'var_H_area_1_comp_grid/coarse_H_comp11_0.1/diagonal'
    data_files = sorted(glob.glob(data_dir + os.path.sep + f'*_{set_name}_set.h5'))
    var_names = ['Vd_bus3']
    generators_areas_map = [['G02', 'G03', 'Comp11'],
                            ['G04', 'G05', 'G06', 'G07', 'Comp21'],
                            ['G08', 'G09', 'G10', 'Comp31'],
                            ['G01']]
    generators_Pnom = {'G01': 10e9, 'G02': 700e6, 'G03': 800e6, 'G04': 800e6, 'G05': 300e6,
                       'G06': 800e6, 'G07': 700e6, 'G08': 700e6, 'G09': 1000e6, 'G10': 1000e6,
                       'Comp11': 100e6, 'Comp21': 100e6, 'Comp31': 100e6}
    area_measure = 'momentum'
    # only the first area of generators_areas_map is passed to the loader
    ret = load_data_areas({set_name: data_files}, var_names,
                          generators_areas_map[:1],
                          generators_Pnom,
                          area_measure,
                          trial_dur=60,
                          max_block_size=1000,
                          use_tf=False,
                          add_omega_ref=True,
                          use_fft=False)
    # NOTE(review): ret is used as (time vector, dict of inputs, dict of targets);
    # confirm against the definition of load_data_areas
    t = ret[0]
    X_raw = ret[1][set_name]
    y[set_name] = ret[2][set_name]
    # one array of trial indices per unique target value
    group_index[set_name] = [np.where(y[set_name] == mom)[0] for mom in np.unique(y[set_name])]
    n_mom_groups[set_name] = len(group_index[set_name])
    # normalization statistics of this set; they are reused below for 'var_G2_G3'
    # NOTE(review): the reduction over axes (1,2) implies a 3-D X_raw; the result
    # broadcasts against X_raw without keepdims only for a leading axis of length 1
    # (a single entry in var_names) -- confirm before adding variables
    X_mean, X_std = X_raw.mean(axis=(1,2)), X_raw.std(axis=(1,2))
    X[set_name] = (X_raw - X_mean) / X_std
    X[set_name] = X[set_name].squeeze()
    y[set_name] = y[set_name].squeeze()

    # single-sided amplitude spectrum of each normalized trace
    dt = np.diff(t[:2])[0]
    N_samples = t.size
    Xf[set_name] = fft(X[set_name])
    Xf[set_name] = 2.0 / N_samples * np.abs(Xf[set_name][:, :N_samples//2])
    F = fftfreq(N_samples, dt)[:N_samples//2]

    # second data set, loaded with the same loader settings as above
    data_dir = '../data/IEEE39/converted_from_PowerFactory/all_stoch_loads/' + \
        'var_H_area_1_comp_grid/subset_8_H_comp11_0.1'
    data_files = sorted(glob.glob(data_dir + os.path.sep + f'*_{set_name}_set.h5'))
    var_names = ['Vd_bus3']
    ret = load_data_areas({set_name: data_files}, var_names,
                          generators_areas_map[:1],
                          generators_Pnom,
                          area_measure,
                          trial_dur=60,
                          max_block_size=1000,
                          use_tf=False,
                          add_omega_ref=True,
                          use_fft=False)    
    X_raw = ret[1][set_name]
    y['var_G2_G3'] = ret[2][set_name]
    group_index['var_G2_G3'] = [np.where(y['var_G2_G3'] == mom)[0] for mom in np.unique(y['var_G2_G3'])]
    n_mom_groups['var_G2_G3'] = len(group_index['var_G2_G3'])
    # normalized with X_mean and X_std computed on the first data set;
    # N_samples (and F) of the first set are reused for the spectrum as well
    X['var_G2_G3'] = (X_raw - X_mean) / X_std
    X['var_G2_G3'] = X['var_G2_G3'].squeeze()
    y['var_G2_G3'] = y['var_G2_G3'].squeeze()
    Xf['var_G2_G3'] = fft(X['var_G2_G3'])
    Xf['var_G2_G3'] = 2.0 / N_samples * np.abs(Xf['var_G2_G3'][:, :N_samples//2])

    # the dictionaries are stored as pickled 0-d object arrays
    np.savez_compressed(data_file, t=t, X=X, y=y, Xf=Xf, F=F,
                        group_index=group_index, n_mom_groups=n_mom_groups)
else:
    # Reload from the cache: allow_pickle is required for the pickled dictionaries and
    # .item() unwraps them. dt, N_samples, X_mean and X_std are NOT restored in this branch.
    data = np.load(data_file, allow_pickle=True)
    t = data['t']
    X = data['X'].item()
    y = data['y'].item()
    F = data['F']
    Xf = data['Xf'].item()
    group_index = data['group_index'].item()
    n_mom_groups = data['n_mom_groups'].item()

## Figure 2

In [None]:
# Figure 2: (A) map of the momentum over the (H_G2, H_G3) grid, (B) example traces,
# (C) amplitude distributions, (D, E) average spectra of the two data sets.
# Uses t, F, X, Xf, y, group_index and set_name, none of which is defined in this cell.
fig,ax = plt.subplot_mosaic(
    '''
    AAAAAAAA
    BBBBBCCC
    DDDDEEEE
    ''',
    figsize=(3.5,4)
)
cmap = plt.get_cmap('Set1')
N_div_cmap = 10
div_cmap = plt.get_cmap('bwr', N_div_cmap)

###########################################
# momentum = 2 * (H . S) / fn, scaled by 1e-3
momentum = lambda H, S, fn: 2 * H@S / fn * 1e-3

# 11 x 11 grid of inertia values spanning +/- 1 around the reference values
N = 11, 11
h_G2_0, h_G3_0 = 4.33, 4.47
h_G2 = h_G2_0 + np.linspace(-1, 1, N[0])
h_G3 = h_G3_0 + np.linspace(-1, 1, N[1])
H_G2, H_G3 = np.meshgrid(*[h_G2, h_G3])

M = np.zeros(N)
S = np.array([700, 800])
fn = 60
for i in range(N[0]):
    for j in range(N[1]):
        H = np.array([H_G2[i,j], H_G3[i,j]])
        M[i,j] = momentum(H, S, fn)
        
gray_cmap = plt.get_cmap('gray')
cont = ax['A'].contourf(H_G2, H_G3, M, levels=100, cmap=gray_cmap, zorder=-1)
cbar = plt.colorbar(cont, ax=ax['A'])
# color definitions; only magenta is used in the remainder of this cell
white = [1,1,1]
magenta = [1,0,1]
green = [0,1,0]
yellow = [1,1,0]
blue = [0,.5,1]
red = [1,.333,.333]
orange = [1, .5, 0]
gray = [.6, .6, .6]
zord = 0
# square markers on the 2 x 2 grid points in two opposite corners of the grid,
# colored with the two ends of the diverging colormap
for i in range(2):
    for j in range(2):
        ax['A'].plot(H_G2[i,j], H_G3[i,j], 's', markersize=4, lw=1, color=div_cmap(i*2+j), zorder=zord)
        ax['A'].plot(H_G2[-1-i,-1-j], H_G3[-1-i,-1-j], 's', markersize=4, lw=1,
                     color=div_cmap(N_div_cmap-1-i*2-j), zorder=zord)
        zord += 2
# every other point along the diagonal of the grid
for i in range(0, N[0], 2):
    ax['A'].plot(H_G2[i,i], H_G3[i,i], 'o', color=cmap(i//2), markersize=3, lw=1, zorder=zord)
    zord += 1
# crosses at the centroids of the two corner clusters
ax['A'].plot(H_G2[:2,:2].mean(), H_G3[:2,:2].mean(), 'x', color=magenta, markersize=4,
             markeredgewidth=1, zorder=zord)
zord += 1
ax['A'].plot(H_G2[-2:,-2:].mean(), H_G3[-2:,-2:].mean(), 'x', color=magenta, markersize=4,
             markeredgewidth=1, zorder=zord)
zord += 1
# white dots on every other grid point, drawn on top of everything else
ax['A'].scatter(H_G2[::2, ::2], H_G3[::2, ::2], s=5, c='w', edgecolors='k', lw=0.5, marker='o', zorder=zord)
ax['A'].set_xlabel(r'$H_{G_2}$ [s]')
ax['A'].set_ylabel(r'$H_{G_3}$ [s]')
ax['A'].set_xlim([h_G2_0 - 1.1, h_G2_0 + 1.1])
ax['A'].set_ylim([h_G3_0 - 1.1, h_G3_0 + 1.1])
ax['A'].set_xticks([h_G2_0 - 1, h_G2_0, h_G2_0 + 1])
ax['A'].set_yticks([h_G3_0 - 1, h_G3_0, h_G3_0 + 1])
cbar.set_label(r'Momentum [GW$\cdot$s$^2$]')
cbar.set_ticks(np.r_[0.17 : 0.28 : 0.02])
###########################################

# tend is also read by a later plotting cell
tend = 10
jdx, = np.where(t < tend)
for i,idx in enumerate(group_index[set_name]):
    n,edges = np.histogram(X[set_name][idx,:], bins=50, density=True)
    # the trial at index idx[0]+1 is used as the example trace and for the legend label
    # NOTE(review): presumably each group holds consecutive trial indices, so that
    # idx[0]+1 still belongs to the same group -- verify
    ax['B'].plot(t[jdx], X[set_name][idx[0]+1, jdx], lw=1, color=cmap(i))
    ax['C'].plot(n, edges[1:], lw=1, color=cmap(i))
    ax['D'].plot(F, 20 * np.log10(Xf[set_name][idx,:].mean(axis=0)), lw=1,
               color=cmap(i), label=r'{:.3f} GW$\cdot$s$^2$'.format(y[set_name][idx[0]+1]))

# first four groups: colors 0..3 of the diverging map; following groups: colors 9, 8, ...
for i,idx in enumerate(group_index['var_G2_G3']):
    if i < 4:
        ax['E'].plot(F, 20 * np.log10(Xf['var_G2_G3'][idx,:].mean(axis=0)), lw=1,
                   color=div_cmap(i))
    else:
        ax['E'].plot(F, 20 * np.log10(Xf['var_G2_G3'][idx,:].mean(axis=0)), lw=1,
                   color=div_cmap(N_div_cmap - (i - 3)))

for key in 'BC':
    ax[key].grid(which='major', axis='y', lw=0.5, ls=':', color=[.6,.6,.6])
for key in 'DE':
    ax[key].grid(which='major', axis='x', lw=0.5, ls=':', color=[.6,.6,.6])
for a in ax.values():
    for side in 'right','top':
        a.spines[side].set_visible(False)
for key in 'BC':
    ax[key].set_ylim([-3.5, 3.5])
    ax[key].set_yticks(np.r_[-3 : 4 : 1.5])
for key in 'DE':
    ax[key].set_ylim([-55, -10])
    ax[key].set_yticks(np.r_[-50 : -9 : 10])
    ax[key].set_xscale('log')
    ax[key].set_xlabel('Frequency [Hz]')
    ticks = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20])
    # the upper limit is one unit past the last tick
    ax[key].set_xlim(ticks[[0,-1]] + np.array([0,1]))
    ax[key].xaxis.set_major_locator(FixedLocator(ticks))
    ax[key].xaxis.set_minor_locator(NullLocator())
    ax[key].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))
# C and E show the same y quantity as the panel on their left: no tick labels
ax['C'].set_yticklabels([])
ax['E'].set_yticklabels([])

ax['B'].set_xlabel('Time [s]')
ax['B'].set_ylabel('Norm. V')
ax['C'].set_xlabel('Distr.')
ax['D'].set_ylabel('Power [dB]')

ticks = np.r_[0 : 10.5 : 2]
ax['B'].set_xlim(ticks[[0,-1]])
ax['B'].xaxis.set_major_locator(FixedLocator(ticks))
ax['B'].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))

ticks = np.r_[0 : 0.51 : 0.25]
ax['C'].set_xlim(ticks[[0,-1]])
ax['C'].xaxis.set_major_locator(FixedLocator(ticks))
ax['C'].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))

# panel letters, shifted (in inches) relative to the top-left corner of each panel
trans = mtransforms.ScaledTranslation(-0.45, -0.05, fig.dpi_scale_trans)
for label in 'ABD':
    ax[label].text(0.0, 1.0, label, transform=ax[label].transAxes + trans, fontsize=10, va='bottom')
trans = mtransforms.ScaledTranslation(-0.15, -0.05, fig.dpi_scale_trans)
for label in 'CE':
    ax[label].text(0.0, 1.0, label, transform=ax[label].transAxes + trans, fontsize=10, va='bottom')

fig.tight_layout(pad=0)
plt.savefig('traces_hist_spectra_comp_grid.pdf')

## Figure 3

In [None]:
# Load the training history, the test results and the correlation data of one experiment.
experiment_ID = '474d2016e33b441889ce8b17531487cb' # replaces '98475b819ecb4d569646d7e1467d7c9c'
# experiment_ID = 'd0e4cb94211c4190828fd8cd856cdd94' # replaces 'ed79ae2784274401a9dba5f5ccee98d8'
experiments_path = '../experiments/neural_network/'
# NOTE: pickle can execute arbitrary code, only load files coming from a trusted source.
# The context managers close the files as soon as their content has been read.
with open(os.path.join(experiments_path, experiment_ID, 'history.pkl'), 'rb') as fid:
    history = pickle.load(fid)
with open(os.path.join(experiments_path, experiment_ID, 'test_results.pkl'), 'rb') as fid:
    test_results = pickle.load(fid)
N_bands = 40
filter_order = 6
# the number of trials is part of the name of the correlations file
if experiment_ID == '474d2016e33b441889ce8b17531487cb': # replaces '98475b819ecb4d569646d7e1467d7c9c'
    N_trials = 4000
elif experiment_ID == 'd0e4cb94211c4190828fd8cd856cdd94': # replaces 'ed79ae2784274401a9dba5f5ccee98d8'
    N_trials = 1000
else:
    raise Exception(f'Unknown number of trials for experiment {experiment_ID[:6]}')
correlations_file = f'correlations_{experiment_ID[:6]}_{N_bands}-bands_64-filters_' + \
    f'36-neurons_{N_trials}-trials_{filter_order}-butter_Vd_bus3_pool_1_3.npz'
correlations = np.load(os.path.join(experiments_path, experiment_ID, correlations_file))
# Expose every array of the archive as a notebook-level variable (later cells read
# them by name). Assigning through globals() avoids exec() on strings built from
# the names stored in the file.
for key in correlations.files:
    globals()[key] = correlations[key]
# Keep dedicated references: the generic names idx and edges are rebound by later cells.
corr_idx = correlations['idx']
corr_edges = correlations['edges']

In [None]:
# Figure 3: example traces (A), their distributions (B), training history (C), predicted
# vs. exact momentum (D), average spectra (E) and correlation maps and profiles (F, G, H).
# Uses t, tend, X, Xf, F, group_index, history, R, p, R_ctrl, p_ctrl, corr_idx, corr_edges,
# exact_momentum, pred_momentum and pred_momentum_ctrl, none of which is defined in this cell.
key = 'var_G2_G3'

fig = plt.figure(figsize=(3.5 * 2, 4))

# Manual layout, in figure coordinates:
#   offset: left and bottom margins; border: right and top margins
#   space[0][i]: horizontal gaps between the panels of row i; space[1]: vertical gap
offset = 0.07, 0.1
border = 0.01, 0.02
space = ((0.05, 0.1, 0.1), (0.075, 0.075, 0.13)), 0.17
rows = 2
h = (1 - offset[1] - border[1] - space[1] * (rows-1)) / rows
# relative widths of the panels of each row
w_rel = [
    [3, 2, 3, 2],
    [2, 3, 3, 2]
]

w_rel_sum = np.sum(w_rel, axis=1)
cols = list(map(len, w_rel))
w_total = [1 - offset[0] - border[0] - np.sum(sp) for sp in space[0]]
w = []
for i in range(rows):
    w.append([])
    for j in range(cols[i]):
        w[-1].append(w_total[i] * w_rel[i][j] / w_rel_sum[i])

labels = ['ABCD','EFGH']
ax = {}
for i in range(rows):
    for j in range(cols[i]):
        ax[labels[i][j]] = fig.add_axes([offset[0] + np.sum(space[0][i][:j]) + np.sum(w[i][:j]),
                                         1 - border[1] - h * (i+1) - space[1] * i,
                                         w[i][j],
                                         h])

# ############# Panel A #############
cmap = plt.get_cmap('tab10')
green, magenta = cmap(2), cmap(6)
jdx, = np.where(t < tend)
# first five trials in green, five trials close to the end of the set in magenta
ax['A'].plot(t[jdx], X[key][:5, jdx].T, color=green, lw=0.5)
ax['A'].plot(t[jdx], X[key][-10:-5, jdx].T, color=magenta, lw=0.5)

# ############# Panel B #############
# (the average spectra of panel E are drawn here as well)
# first four groups pooled together
idx = np.concatenate(group_index[key][:4])
n,edges = np.histogram(X[key][idx,:], bins=50, density=True)
ax['B'].plot(n, edges[1:], lw=1, color=green)
m = Xf[key][idx,:].mean(axis=0)
# ci is computed but not used in this cell
ci = 1.96 * Xf[key][idx,:].std(axis=0) / np.sqrt(idx.size)
ax['E'].plot(20 * np.log10(m), F, color=green, lw=1)
# remaining groups pooled together
idx = np.concatenate(group_index[key][4:])
n,edges = np.histogram(X[key][idx,:], bins=50, density=True)
ax['B'].plot(n, edges[1:], lw=1, color=magenta)
m = Xf[key][idx,:].mean(axis=0)
ci = 1.96 * Xf[key][idx,:].std(axis=0) / np.sqrt(idx.size)
ax['E'].plot(20 * np.log10(m), F, color=magenta, lw=1)
ax['B'].set_xlim([0, 0.5])
ticks = np.r_[0 : 0.51 : 0.25]
ax['B'].xaxis.set_major_locator(FixedLocator(ticks))
ax['B'].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))
ax['B'].set_yticklabels([])

# ############# Panel C #############
ax['C'].plot(history['loss'], 'k', lw=1, label='Training')
ax['C'].plot(history['val_loss'], 'r', lw=1, label='Validation')
ax['C'].legend(loc='upper right', frameon=False, fontsize=fontsize)

# ############# Panel D #############
target_values = np.unique(exact_momentum)
df = pd.DataFrame(data={'exact': exact_momentum, 'pred': np.concatenate(pred_momentum)})
# df_ctrl is only needed by the commented-out violin plot below
df_ctrl = pd.DataFrame(data={'exact': exact_momentum, 'pred': np.concatenate(pred_momentum_ctrl)})
# sns.violinplot(x='exact', y='pred', data=df_ctrl, cut=0, inner='quartile',
#                palette='gray', ax=ax['D'], linewidth=0.5)
sns.violinplot(x='exact', y='pred', data=df, cut=0, inner='quartile',
               palette=[green, magenta], ax=ax['D'], linewidth=0.5)
# the violins sit at the categorical positions 0 and 1, labelled with the target values
ax['D'].xaxis.set_major_locator(FixedLocator([0, 1]))
ax['D'].xaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in target_values]))
ax['D'].yaxis.set_major_locator(FixedLocator(target_values))
ax['D'].yaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in target_values]))

# ############# Panels F, G ########
# a single row of three axes: maps in F and G, mean absolute correlation profiles in H
plot_correlations(R, p, R_ctrl, p_ctrl, corr_edges,
                  [np.concatenate(corr_idx)], sort_freq=[1.1],
                  ax=np.array([[ax['F'], ax['G'], ax['H']]]))
# vertical reference line at zero correlation
ax['H'].plot(np.zeros(2), ax['H'].get_ylim(), ':', lw=1, color=[.6,.6,.6])
ax['F'].set_title('Trained network', fontsize=fontsize+1)
ax['G'].set_title('Untrained network', fontsize=fontsize+1)

for label in 'AB':
    ax[label].set_ylim([-3.5, 3.5])
    ax[label].set_yticks(np.r_[-3 : 3.1])

for label in 'ABE':
    ax[label].grid(which='major', axis='y', lw=0.5, ls=':', color=[.6,.6,.6])

for a in ax.values():
    for side in 'right','top':
        a.spines[side].set_visible(False)

# these y limits are overwritten by the loop over 'EFGH' further down
for label in 'EFG':
    ax[label].set_yscale('log')
    ax[label].set_ylim([0.05, 20])

# ax['E'].invert_xaxis()
# ax['E'].yaxis.tick_right()
# ax['E'].spines['right'].set_visible(True)
# ax['E'].spines['left'].set_visible(False)
# decreasing limits: the power axis of panel E is reversed
ax['E'].set_xlim([-10, -55])
ax['E'].set_xticks(np.r_[-50 : -9 : 10])

for label in 'EFGH':
    ticks = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20])
    ax[label].set_ylim(ticks[[0,-1]])
    ax[label].yaxis.set_major_locator(FixedLocator(ticks))
    ax[label].yaxis.set_minor_locator(NullLocator())
    # 'or True' makes this condition always true: every panel gets the formatter
    if label == 'F' or True:
        ax[label].yaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))

ax['A'].set_xlabel('Time [s]')
ax['A'].set_ylabel('Norm. V')
ax['B'].set_xlabel('Distribution')
ax['C'].set_xlabel('Epoch')
ax['C'].set_ylabel('Loss')
ax['D'].set_xlabel(r'Exact M [GW$\cdot$s$^2$]')
ax['D'].set_ylabel(r'Predicted M [GW$\cdot$s$^2$]')
ax['E'].set_xlabel('Power [dB]')
ax['E'].set_ylabel('Frequency [Hz]')
ax['F'].set_xlabel('Filter #')
ax['G'].set_xlabel('Filter #')
ax['H'].set_xlabel('Correlation')
# relabel the three x ticks placed by plot_correlations (first, middle and last column)
ax['F'].set_xticklabels([1, 32, 64])
ax['G'].set_xticklabels([1, 32, 64])

# panel letters, shifted (in inches) relative to the top-left corner of each panel
trans = mtransforms.ScaledTranslation(-0.2, -0.05, fig.dpi_scale_trans)
ax['B'].text(0.0, 1.0, 'B', transform=ax['B'].transAxes + trans, fontsize=10, va='bottom')
trans = mtransforms.ScaledTranslation(-0.55, -0.05, fig.dpi_scale_trans)
ax['C'].text(0.0, 1.0, 'C', transform=ax['C'].transAxes + trans, fontsize=10, va='bottom')
trans = mtransforms.ScaledTranslation(-0.4, -0.05, fig.dpi_scale_trans)
for label in 'ADEFGH':
    ax[label].text(0.0, 1.0, label, transform=ax[label].transAxes + trans, fontsize=10, va='bottom')

plt.savefig(f'low_high_momentum_{experiment_ID[:6]}.pdf')

### Figure 4

In [None]:
# Load the results of the stop-band momentum estimation of the current experiment.
data = np.load(os.path.join(experiments_path, experiment_ID, 'stopband_momentum_estimation.npz'))
# Expose every array of the archive as a notebook-level variable. Assigning through
# globals() avoids exec() on strings built from the names stored in the file.
# NOTE: this rebinds any existing variable having the same name as an array of the file.
for key in data.files:
    globals()[key] = data[key]
N_bands = len(bands)
# mean spectrum of each group and half-width of its 95% confidence interval
Xfm = np.zeros((len(group_index), F.size))
Xfci = np.zeros((len(group_index), F.size))
for i,idx in enumerate(group_index):
    Xfm[i] = Xf[idx].mean(axis=0)
    Xfci[i] = 1.96 * Xf[idx].std(axis=0) / np.sqrt(idx.size)

In [None]:
# Figure 4: (A) predicted vs. exact momentum for the broadband input and for each band,
# (B) average spectra of the groups with the R^2 scores superimposed on a twin axis.
# Uses y, y_pred, group_index, bands, scores, N_bands, F, Xfm, Xfci, green, magenta and
# experiment_ID, none of which is defined in this cell.
fig,ax = plt.subplot_mosaic(
    '''
    AAAA.
    BBBB.
    ''',
    figsize=(3.5,4)
)

cmap2 = plt.get_cmap('Set1')
# mean and standard deviation of the exact and predicted values of each group
y_m = [y[idx].mean() for idx in group_index]
y_s = [y[idx].std() for idx in group_index]
y_pred_m = np.array([[pred[jdx].mean() for jdx in group_index] for pred in y_pred])
y_pred_s = np.array([[pred[jdx].std() for jdx in group_index] for pred in y_pred])
add_band = False
# NOTE(review): both branches of this conditional expression are identical, hence
# add_band has no effect on last_band -- confirm which value was intended
last_band = N_bands + 1 if add_band else N_bands + 1
# index 0 is labelled 'Broadband', index i > 0 refers to bands[i-1]
for i in range(last_band):
    if i == 0:
        lbl = 'Broadband'
    else:
        lbl = f'[{bands[i-1][0]}-{bands[i-1][1]:g}] Hz'
    ax['A'].plot(y_m, y_pred_m[i], color=cmap2(i), lw=1, label=lbl)
    for j in range(len(group_index)):
        # vertical bar: +/- one standard deviation of the predictions
        ax['A'].plot(y_m[j] + np.zeros(2),
                   y_pred_m[i,j] + y_pred_s[i,j] * np.array([-1,1]),
                   color=cmap2(i), lw=1)
        # horizontal bar: +/- one standard deviation of the exact values
        ax['A'].plot(y_m[j] + y_s[j] * np.array([-1,1]),
                   y_pred_m[i,j] + np.zeros(2),
                   color=cmap2(i), lw=1)
        ax['A'].plot(y_m[j], y_pred_m[i,j], 'o', color=cmap2(i),
                   markerfacecolor='w', markersize=4, markeredgewidth=1)
for side in 'right','top':
    ax['A'].spines[side].set_visible(False)
ax['A'].legend(loc='center left', bbox_to_anchor=[0.875, 0.5], frameon=False, fontsize=fontsize-1)
ax['A'].set_xlabel(r'Exact M [GW$\cdot$s$^2$]')
ax['A'].set_ylabel(r'Predicted M [GW$\cdot$s$^2$]')
# the two-element offsets below require exactly two groups in y_m
ax['A'].set_xlim(y_m + np.diff(y_m) * np.array([-1/10, 1/5]))
ax['A'].set_ylim(y_m + np.diff(y_m) / 3 * np.array([-1,1]))
ax['A'].xaxis.set_major_locator(FixedLocator(y_m))
ax['A'].xaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in y_m]))
ax['A'].yaxis.set_major_locator(FixedLocator(y_m))
ax['A'].yaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in y_m]))

# twin y axis of panel B, used for the R^2 scores
axr = ax['B'].twinx()
# ci (taken from Xfci) is unpacked but not plotted
for i,(m,ci,col) in enumerate(zip(Xfm, Xfci, (green,magenta))):
    ax['B'].plot(F, 20*np.log10(m), color=col,
                 label=r'M = {:.2f} GW$\cdot$s$^2$'.format(y[group_index[i]].mean()))
ax['B'].legend(loc='lower left', frameon=False, fontsize=fontsize-1, bbox_to_anchor=[0.0, 0.13])
# scores[0] (drawn with the color of the 'Broadband' curve) spans the whole x range
axr.plot(ax['B'].get_xlim(), scores[0] + np.zeros(2), '--', color=cmap2(0), lw=2)
# one horizontal segment per band, spanning the band itself, at the height of its score
for i,band in enumerate(bands):
    if i >= last_band-1:
        break
    axr.axvline(band[0], color=[.6,.6,.6], ls=':', lw=0.5)
    axr.plot(band, scores[i+1] + np.zeros(2), color=cmap2(i+1), lw=2)
ax['B'].set_xlabel('Frequency [Hz]')
ax['B'].set_ylabel('Power [dB]')
axr.set_ylabel(r'R$^2$ score')
ax['B'].set_xscale('log')

ticks = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20])
ax['B'].set_xlim(ticks[[0,-1]] + np.array([0,1]))
ax['B'].xaxis.set_major_locator(FixedLocator(ticks))
ax['B'].xaxis.set_minor_locator(NullLocator())
ax['B'].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))
axr.set_ylim((-0.25, 1.05))
axr.set_yticks(np.r_[-0.2 : 1.05 : 0.2])

# panel letters, shifted (in inches) relative to the top-left corner of each panel
trans = mtransforms.ScaledTranslation(-0.45, -0.05, fig.dpi_scale_trans)
for label in 'AB':
    ax[label].text(0.0, 1.0, label, transform=ax[label].transAxes + trans, fontsize=10, va='bottom')

fig.tight_layout(pad=0.2)
# NOTE(review): the first branch uses the full experiment ID in the file name,
# the second one only its first six characters
if add_band:
    fig.savefig(f'stopband_momentum_estimation_{experiment_ID}_{last_band}.pdf')
else:
    fig.savefig(f'stopband_momentum_estimation_{experiment_ID[:6]}.pdf')

### Figure 5

### Figure 6

In [None]:
# the following experiment IDs replace the previously used ones:
# '98475b819ecb4d569646d7e1467d7c9c' and '302a21340f354ac2949184be98d8e907'
experiment_IDs = '474d2016e33b441889ce8b17531487cb', 'c6f72abb5e364c4cb7770250e135bd73'
experiments_path = '../experiments/neural_network/'

# Both dictionaries are keyed by the first six characters of the experiment ID.
data, correlations = {}, {}
N_bands = 40
filter_order = 6
# NOTE: experiment_ID is rebound by this loop and keeps the last ID afterwards
for experiment_ID,N_trials in zip(experiment_IDs, (4000,4000)):
    key = experiment_ID[:6]
    # NOTE: allow_pickle=True unpickles objects stored in the file, which can execute
    # arbitrary code: only load files coming from a trusted source.
    tmp = np.load(os.path.join(experiments_path, experiment_ID, f'variable_inertia_{key}.npz'),
                  allow_pickle=True)
    data[key] = {}
    for fname in tmp.files:
        value = tmp[fname]
        # ndarray.item() only works on arrays with exactly one element: unwrap those
        # (e.g. a dictionary saved as a 0-d object array), keep all other arrays as
        # they are. This replaces an exec() wrapped in a bare try/except.
        data[key][fname] = value.item() if value.size == 1 else value
    correlations_file = f'correlations_{key}_{N_bands}-bands_64-filters_' + \
        f'36-neurons_{N_trials}-trials_{filter_order}-butter_Vd_bus3_pool_1_3.npz'
    correlations[key] = np.load(os.path.join(experiments_path, experiment_ID, correlations_file))

In [None]:
fig = plt.figure(figsize=(3.5*1.8,4))

cols = 2
offset = np.array([[0.075, 0.1], [0.14, 0.03]])
space = {
    'AB': 0.125,
    'BC': 0.05,
    'LR': 0.1,
    'DE': 0.125,
    'EF': 0.035
}
width = {'A': 0.4}
height = {'A': 0.4}
width['B'] = (width['A'] - space['BC']) / 2
width['C'] = width['B']
height['B'] = 1 - np.sum(offset[:,1]) - height['A'] - space['AB']
height['C'] = height['B']
width['D'] = 1 - np.sum(offset[:,0]) - space['LR'] - width['A'] + 0.1
width['E'], width['F'] = width['D'] - 0.06, width['D']
height['D'] = (1 - np.sum(offset[:,1]) - space['DE'] - space['EF']) / 2.5
height['E'] = (1 - np.sum(offset[:,1]) - space['DE'] - space['EF'] - height['D']) / 2
height['F'] = height['E']

ax = {
    'A': plt.axes([offset[0,0],
                   offset[0,1] + space['AB'] + height['B'],
                   width['A'], height['A']]),
    'B': plt.axes([offset[0,0],
                   offset[0,1],
                   width['B'],
                   height['B']]),
    'C': plt.axes([offset[0,0] + space['BC'] + width['B'],
                   offset[0,1],
                   width['C'], height['C']]),
    'D': plt.axes([offset[0,0] + width['A'] + space['LR'],
                   offset[0,1] + 2*height['E'] + space['EF'] + space['DE'],
                   width['D'], height['D']]),
    'E': plt.axes([offset[0,0] + width['A'] + space['LR'],
                   offset[0,1] + height['F'] + space['EF'],
                   width['E'], height['E']]),
    'F': plt.axes([offset[0,0] + width['A'] + space['LR'],
                   offset[0,1],
                   width['F'], height['F']]),
}

cmap_name = 'tab10'
cmap = plt.get_cmap(cmap_name)

################### Panels A, B and C ###################
key = experiment_IDs[0][:6]
xticks = {'A': np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20]),
          'B': np.array([0.4, 0.5, 0.7, 1, 1.5]),
          'C': np.array([8, 10, 12, 15])}
yticks = {'A': np.r_[-60 : -5 : 10],
          'B': np.r_[-30 : -9 : 5],
          'C': np.r_[-56 : -44 : 3]}
for a in 'ABC':
    n = 0
    for i,(k,v) in enumerate(data[key]['Xf'].items()):
        for j in range(data[key]['n_mom_groups'][k]):
            lbl = r'{:.3f} GW$\cdot$s$^2$'.format(data[key]['ym'][k][j])
            if k == 'var_G2_G3':
                lbl += ' (GEN)'
            elif k == 'var_Comp11':
                lbl += ' (COMP)'
            m = v[data[key]['group_index'][k][j], :].mean(axis=0)
            s = v[data[key]['group_index'][k][j], :].std(axis=0)
            ci = 1.96 * s / np.sqrt(data[key]['group_index'][k][j].size)
            ax[a].fill_between(data[key]['F'],
                               20*np.log10(m + ci),
                               20*np.log10(m - ci),
                               color=cmap(n), facecolor=cmap(n), edgecolor=cmap(n),
                               alpha=0.5, label=lbl)
            n += 1
    ax[a].set_xscale('log')
    ax[a].set_xlabel('Frequency [Hz]')
    for side in 'right','top':
        ax[a].spines[side].set_visible(False)
    ax[a].set_xlim(xticks[a][[0,-1]] + np.array([0,1]))
    ax[a].xaxis.set_major_locator(FixedLocator(xticks[a]))
    ax[a].xaxis.set_minor_locator(NullLocator())
    ax[a].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in xticks[a]]))
    ax[a].yaxis.set_major_locator(FixedLocator(yticks[a]))
    ax[a].yaxis.set_minor_locator(NullLocator())
    ax[a].yaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in yticks[a]]))

ax['A'].legend(loc='lower left', frameon=False, fontsize=6)
ax['A'].set_ylabel('Power [dB]')
ax['B'].set_ylabel('Power [dB]')
ax['A'].set_ylim([-60, -10])
ax['B'].set_xlim([0.4, 1.5])
ax['C'].set_xlim([8, 15])
ax['B'].set_ylim([-31, -9])
ax['C'].set_ylim([-57, -46])

################### Panel D ###################
key = experiment_IDs[0][:6]
ax['D'].plot(data[key]['ym']['test'], data[key]['ym']['test'], 'k--', lw=2, markerfacecolor='w')
m = 'so'
markers = {ID[:6]: m[i] for i,ID in enumerate(experiment_IDs)}
lw, ms = 1, 6
for ID in experiment_IDs:
    n = 0
    key = ID[:6]
    ym = data[key]['ym']
    ym_pred = data[key]['ym_pred']
    ys_pred = data[key]['ys_pred']
    for cond in ym:
        for i in range(len(ym[cond])):
            ax['D'].plot(ym[cond][i] + np.zeros(2),
                         ym_pred[cond][i] + ys_pred[cond][i] * np.array([-1,1]),
                         color=cmap(n), linewidth=lw)
            ax['D'].plot(ym[cond][i], ym_pred[cond][i], markers[key], color=cmap(n), markersize=ms,
                     markerfacecolor='w', markeredgewidth=lw)
            n += 1
ax['D'].plot(1, 1, 'k'+m[0], markersize=ms, markerfacecolor='w',
             markeredgewidth=lw, label='without var. comp. in tr.')
ax['D'].plot(1, 1, 'k'+m[1], markersize=ms, markerfacecolor='w',
             markeredgewidth=lw, label='with var. comp. in tr.')
ax['D'].legend(loc='lower right', frameon=False, fontsize=fontsize-1, bbox_to_anchor=[1, -0.05])
ax['D'].set_xlabel(r'Exact M [GW$\cdot$s$^2$]')
ax['D'].set_ylabel(r'Predicted M [GW$\cdot$s$^2$]')
ym = ym['test']
ticks = [ym[0], 0.197, ym[1]]
ax['D'].set_xlim(ym + np.diff(ym) * np.array([-1/10, 1/10]))
ax['D'].set_ylim(ym + np.diff(ym) * np.array([-1/10, 1/10]))
ax['D'].xaxis.set_major_locator(FixedLocator(ticks))
ax['D'].xaxis.set_major_formatter(FixedFormatter([f'{tick:.3f}' for tick in ticks]))
ax['D'].yaxis.set_major_locator(FixedLocator(ticks))
ax['D'].yaxis.set_major_formatter(FixedFormatter([f'{tick:.3f}' for tick in ticks]))

################### Panels E and F ###################
# Both panels use the correlations of the second experiment and differ only in
# `sort_freq` and in the axes array handed to plot_correlations (panel F passes
# its axes twice).
key = experiment_IDs[1][:6]
corr = correlations[key]
for freq, panel_axes in ((1.5, [[ax['E']]]),
                         (10, [[ax['F'], ax['F']]])):
    plot_correlations(corr['R'],
                      corr['p'],
                      None,
                      None,
                      corr['edges'],
                      [np.concatenate(corr['idx'])],
                      sort_freq=[freq],
                      ax=np.array(panel_axes))

# Shared y-axis formatting: the frequency ticks of panel A, with the upper limit
# pushed one unit beyond the last tick.
freq_ticks = xticks['A']
freq_tick_labels = [f'{tick:g}' for tick in freq_ticks]
for lbl in ('E', 'F'):
    panel = ax[lbl]
    panel.set_ylim(freq_ticks[[0,-1]] + np.array([0,1]))
    panel.yaxis.set_major_locator(FixedLocator(freq_ticks))
    panel.yaxis.set_minor_locator(NullLocator())
    panel.yaxis.set_major_formatter(FixedFormatter(freq_tick_labels))
    panel.set_ylabel('Frequency [Hz]')
    panel.set_xticks([1, 32, 64])
# Only panel F shows x tick labels and the x-axis title.
ax['E'].set_xticklabels([])
ax['F'].set_xlabel('Filter #')

# Hide the right and top spines of every panel.
for lbl in ax:
    panel_spines = ax[lbl].spines
    for side in ('right', 'top'):
        panel_spines[side].set_visible(False)

# Panel letters, placed relative to the top-left corner of each axes.
# Entries are (axes key, letter shown, horizontal shift passed to
# ScaledTranslation); axes 'D', 'E' and 'F' are labelled B, C and D.
panel_letters = (
    ('A', 'A', -0.4),
    ('D', 'B', -0.45),
    ('E', 'C', -0.45),
    ('F', 'D', -0.45),
)
for key, lbl, x_shift in panel_letters:
    trans = mtransforms.ScaledTranslation(x_shift, -0.04, fig.dpi_scale_trans)
    ax[key].text(0.0, 1.0, lbl, transform=ax[key].transAxes + trans,
                 fontsize=10, va='bottom')

# The output name is built from the first six characters of both experiment IDs.
id_tags = (experiment_IDs[0][:6], experiment_IDs[1][:6])
pdf_file = f'variable_inertia_{id_tags[0]}_{id_tags[1]}.pdf'
fig.savefig(pdf_file)

### Figure 7

In [None]:
# Experiment whose training history and test results are plotted in the next cell.
# Alternative run that the plotting cell also knows how to format:
#     'a40658acee3c4e419c0ee34d0c59f4df'
experiment_ID = 'f64bde90cab54d1ea770bb21f33c3ed1'
experiments_path = '../experiments/neural_network/'
experiment_dir = os.path.join(experiments_path, experiment_ID)
# NOTE: unpickling can execute arbitrary code; only load experiment folders you trust.
# The files are opened in `with` blocks so the handles are closed right after loading.
with open(os.path.join(experiment_dir, 'history.pkl'), 'rb') as fid:
    history = pickle.load(fid)
with open(os.path.join(experiment_dir, 'test_results.pkl'), 'rb') as fid:
    test_results = pickle.load(fid)
print('MAPE on the test set: {:.2f}%.'.format(test_results['mape_prediction'][0]))

In [None]:
# Two-panel figure: loss curves (first panel) and predicted vs. exact M on the
# test set (second panel). `vert` stacks the panels vertically instead of side by side.
vert = False
if vert:
    fig,ax = plt.subplots(2, 1, figsize=(2.5,3))
else:
    fig,ax = plt.subplots(1, 2, figsize=(5,2))

ax[0].plot(history['loss'], 'k', lw=1, label='Training')
ax[0].plot(history['val_loss'], 'r', lw=1, label='Validation')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Loss')
ax[0].legend(loc='upper right', frameon=False)

y_test,y_pred = test_results['y_test'].squeeze(), test_results['y_prediction'].squeeze()
# Axis limits and ticks are hand-tuned per experiment (keyed by the first six
# characters of the ID); an unknown experiment must be added here explicitly.
if experiment_ID[:6] == 'a40658':
    limits = [0.17, 0.3]
    ticks = np.r_[0.18 : 0.31 : 0.03]
elif experiment_ID[:6] == 'f64bde':
    limits = [0.17, 0.28]
    ticks = np.r_[0.18 : 0.31 : 0.03]
else:
    raise Exception('set limits and ticks')
# Identity line for reference.
ax[1].plot(limits, limits, '--', lw=1, color=[.6,.6,.6])
# One marker per distinct exact value: mean of the predictions, with a bar of
# +/- one standard deviation.
for y in np.unique(y_test):
    idx = y_test == y
    y_mean = y_pred[idx].mean()
    y_std = y_pred[idx].std()
    ax[1].plot(y+np.zeros(2), y_mean+y_std*np.array([-1,1]), 'k', lw=1)
    ax[1].plot(y, y_mean, 'ko', markersize=4, markerfacecolor='w', markeredgewidth=1)
ax[1].set_xlabel(r'Exact M [GW$\cdot$s$^2$]')
ax[1].set_ylabel(r'Predicted M [GW$\cdot$s$^2$]')
ax[1].set_xticks(ticks)
ax[1].set_yticks(ticks)
ax[1].set_xlim(limits)
ax[1].set_ylim(limits)

for a in ax:
    for side in 'right','top':
        a.spines[side].set_visible(False)

fig.tight_layout(pad=0.1)
# Both layouts are saved in the working directory, like the other figures of this
# notebook; the horizontal layout no longer depends on a user-specific absolute path.
layout_suffix = '' if vert else '_horiz'
fig.savefig(f'training_{experiment_ID[:6]}{layout_suffix}.pdf')

### Figure 8

In [None]:
# Pair of experiments compared in the figure below.
# Alternative pair: '7d55f784f6b64f2caeb866804bda1a8b', '57dfd307a5a945d8b28e5cf501b41f13'
experiment_IDs = 'a40658acee3c4e419c0ee34d0c59f4df', 'f64bde90cab54d1ea770bb21f33c3ed1'
experiments_path = '../experiments/neural_network/'
# Each shelf is opened read-only, copied into a plain dict and closed again, so no
# database handle stays open for the lifetime of the kernel. `dbs[i][key]` works
# exactly as it did on the open shelf.
dbs = []
for ID in experiment_IDs:
    with shelve.open(os.path.join(experiments_path, ID, ID[:6]+'.out'), flag='r') as db:
        dbs.append(dict(db))

In [None]:
# Grid of time series, one panel per entry of db['experiments']; the entries of
# every database in `dbs` are overlaid on the same panels ('A', 'B', ...).
vert = False
if vert:
    fig,ax = plt.subplot_mosaic(
        '''
        A
        B
        C
        D
        ''',
        figsize=(3, 4)
    )
else:
    fig,ax = plt.subplot_mosaic(
        '''
        AB
        CD
        ''',
        figsize=(6.5, 2.5)
    )

# Per-panel y limits and ticks, indexed by the position of the experiment in
# db['experiments'].
ylim = [
    [0.17, 0.27],
    [0.17, 0.27],
    [0.20, 0.25],
    [0.20, 0.25]
]
yticks = [
    np.r_[0.18 : 0.27 : 0.04],
    np.r_[0.18 : 0.27 : 0.04],
    np.r_[0.20 : 0.26 : 0.025],
    np.r_[0.20 : 0.26 : 0.025],
]
# Maps 0, 1, 2, ... to 'A', 'B', 'C', ... (the keys of the mosaic).
ithch = lambda n,start='A': chr(ord(start)+n)

# One colour per database; the exact values are drawn in magenta.
col = [[.2,.8,.4], np.zeros(3)]
magenta = [1,0,1]

area_measure = 'M'
measure_units = r'GW$\cdot$s$^2$'

for i,db in enumerate(dbs):
    for j,expt in enumerate(db['experiments']):
        J = ithch(j)
        t_pred = expt['prediction_time']
        prediction = np.squeeze(expt['prediction'])
        exact = expt['exact']
        mean_prediction = expt['mean_prediction']
        # The time axis is split into N_blocks blocks of equal duration.
        N_blocks = len(expt['H_values'])
        block_dur = np.ceil(expt['data_time'][-1]) / N_blocks
        for k in range(N_blocks):
            # Block edges are divided by 60 to match the 'Time [min]' axis.
            t0,t1 = block_dur*k, block_dur*(k+1)
            ax[J].plot(np.array([t0, t1])/60, mean_prediction[k] + np.zeros(2), color=col[i], lw=1)
            # The exact value is drawn only while plotting the second database.
            if i == 1:
                ax[J].plot(np.array([t0, t1])/60, exact[k] + np.zeros(2), '--', color=magenta, lw=1)
        ax[J].plot(t_pred/60, prediction, color=col[i], lw=0.75)
        # Axes cosmetics are applied once, while plotting the first database.
        if i == 0:
            for side in 'right','top':
                ax[J].spines[side].set_visible(False)
            ax[J].grid(which='major', axis='y', lw=0.5, ls=':', color=[.6,.6,.6])
            ax[J].set_ylim(ylim[j])
            ax[J].set_xticks(np.r_[0 : 181 : 30])
            ax[J].set_yticks(yticks[j])
            ax[J].set_ylabel(f'{area_measure.capitalize()} [{measure_units}]')

# Only the bottom panel(s) keep their x tick labels and axis title.
ax['D'].set_xlabel('Time [min]')
if not vert:
    ax['C'].set_xlabel('Time [min]')
for i in 'AB':
    ax[i].set_xticklabels([])
if vert:
    ax['C'].set_xticklabels([])
trans = mtransforms.ScaledTranslation(-0.5, -0.05, fig.dpi_scale_trans)
for lbl in 'ABCD':
    ax[lbl].text(0.0, 1.0, lbl, transform=ax[lbl].transAxes + trans, fontsize=10, va='bottom')

fig.tight_layout(pad=0.2)
# Both layouts are saved in the working directory, like the other figures of this
# notebook; the horizontal layout no longer depends on a user-specific absolute path.
layout_suffix = '' if vert else '_horiz'
pdf_file = f'area_momentum_estimation_grid_{experiment_IDs[0][:6]}_{experiment_IDs[1][:6]}{layout_suffix}.pdf'
fig.savefig(pdf_file)

### Figure 9

In [None]:
buses = 3, # 14, 17, 39
N_buses = len(buses)
var_names = [f'Vd_bus{bus}' for bus in buses]
area_measure = 'momentum'
max_block_size = 100
cutoff = 0.1
filename = f'spectra_compensators_buses={"-".join(map(str,buses))}_cutoff={cutoff:.02f}_blocks={max_block_size}'
# Every array stored in the archive becomes a notebook-level variable named after
# its key; the code below and the plotting cell rely on Xfm, F, H_comp, IDX, ym
# and Fpeak being among them. The archive is closed once everything is copied out.
with np.load(filename + '.npz') as data:
    for key in data.files:
        globals()[key] = data[key]
# Extremes of the spectra for each variable (first axis of Xfm).
min_fft = Xfm.min(axis=(1,2))
max_fft = Xfm.max(axis=(1,2))
N_H = len(H_comp)
# Number of spectra per value of H_comp.
step = Xfm.shape[1] // H_comp.size

# For each spectrum of the selected variable, store the frequencies of the 2nd
# and 3rd peak found below 1.5 Hz.
n = Xfm.shape[1]
peaks = np.zeros((n,2))
var_idx = 0
low_freq = F < 1.5      # loop-invariant mask
F_low = F[low_freq]
for i in range(n):
    x = 20*np.log10(Xfm[var_idx, i, low_freq])
    locs,_ = find_peaks(x, height=10, prominence=1, distance=5)
    if locs.size < 3:
        raise ValueError(f'Spectrum {i}: expected at least 3 peaks below 1.5 Hz, found {locs.size}.')
    # `locs` indexes the masked spectrum, so it must be mapped through the masked
    # frequency axis (identical to F[locs] only when the mask selects a prefix of F).
    peaks[i,:] = F_low[locs[1:3]]

In [None]:
# Manual layout in figure coordinates: one short panel at the top (ax[0]), `rows`
# stacked panels in the middle (ax[1] ... ax[rows], from top to bottom) and one
# short panel at the bottom (ax[-1]).
fig = plt.figure(figsize=(2.75, 5))
x_offset = [0.2, 0.02]
y_offset = [0.1, 0.04]
x_space = 0.1
y_space = 0.05

# Compensator H values shown in the stacked panels (alternative: [0.1, 2.5, 5.0]).
H_comp_sub = [1.0, 3.0, 6.0]
rows = len(H_comp_sub)
# Heights of the top and bottom panels.
height = [0.15, 0.15]
w = 1 - np.sum(x_offset)
# Height left for each stacked panel once the top/bottom panels, the offsets and
# the gaps have been subtracted.
h = (1 - np.sum(height) - y_space*len(height)*1.2 - np.sum(y_offset) - y_space * (rows - 1)) / rows
ax = [plt.axes([x_offset[0], 1 - y_offset[1] - height[0], w, height[0]])]
# Counting down places the first appended axes highest in the stack.
for i in range(rows-1, -1, -1):
    ax.append(plt.axes([x_offset[0],
                        y_offset[0] + height[1] + y_space*1.5 + (h + y_space) * i,
                        w, h]))
ax.append(plt.axes([x_offset[0], y_offset[0], w, height[1]]))

xticks = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20])
var_idx = 0

# Panel A: for each of the first N_H-1 groups of `step` spectra, draw the first
# spectrum of the group in red and the last one in blue (in dB).
reds = plt.get_cmap('Reds', N_H+4)
blues = plt.get_cmap('Blues', N_H+4)
n_pairs = N_H - 1
for i in range(n_pairs):
    # Only the last pair is labelled, so the legend shows a single entry per
    # colour family without drawing any curve twice.
    is_last = i == n_pairs - 1
    ax[0].plot(F, 20*np.log10(Xfm[var_idx, i*step, :]), color=reds(i+2), lw=2,
               label='Low momentum' if is_last else None)
    ax[0].plot(F, 20*np.log10(Xfm[var_idx, step - 1 + i*step, :]), color=blues(i+2), lw=1,
               label='High momentum' if is_last else None)
ax[0].set_xscale('log')
ax[0].set_yticks(np.r_[-20:25:10])
ax[0].set_ylabel('Power [dB]')
ax[0].grid(which='major', axis='both', ls=':', lw=0.5, color=[.6,.6,.6])
# The upper x limit is pushed one unit beyond the last tick.
ax[0].set_xlim(xticks[[0,-1]] + np.array([0,1]))
ax[0].xaxis.set_major_locator(FixedLocator(xticks))
ax[0].xaxis.set_minor_locator(NullLocator())
# Ticks without labels on this panel.
ax[0].xaxis.set_major_formatter(FixedFormatter([]))
ax[0].legend(loc='lower left', frameon=False, fontsize=fontsize-1, bbox_to_anchor=[0,-0.05])
ax[0].arrow(17, 7, -10, 7, shape='left', width=0.2, head_width=2, head_length=1, fc='k', ec='k')
ax[0].text(10, 27, 'Increasing H comp', fontsize=fontsize-2, va='top', ha='center', rotation=-15)
# Panel B: one colour map per value in H_comp_sub, showing the spectra (colour,
# dB scale) as a function of frequency (x) and ym (y), with power-law fits of the
# two tracked low-frequency peaks overlaid.
_forward = lambda x: 20 * np.log10(x)
_inverse = lambda y: 10 ** (y/20)
cmap = plt.cm.jet


def _power_law(x, a, b):
    """Model fitted to peak frequency as a function of ym: a * x**b."""
    return a*x**b


# Colours of the two fitted curves; kept separate from `col` used elsewhere.
fit_colors = [[1,1,1], [0,.75,1]]
# Half-width (in log-frequency) of the triangular marker at the bottom of each map.
dx = np.log(F[-1] - cutoff) / 40

for i,H in enumerate(H_comp_sub):
    # Tolerant match, so that floating-point noise in H_comp cannot hide a value.
    j = np.flatnonzero(np.isclose(H_comp, H))[0]
    idx = IDX[j]
    norm = FuncNorm((_forward, _inverse), vmin=min_fft[var_idx], vmax=max_fft[var_idx])
    X,Y = np.meshgrid(F, ym[idx])
    ylim = [Y.min(), Y.max()]
    y_span = ylim[1] - ylim[0]
    ax[i+1].pcolor(X, Y, Xfm[var_idx, idx, :], cmap=cmap, norm=norm)

    # Fit each tracked peak against ym, with the samples sorted by ym.
    order = np.argsort(ym[idx])
    ydata = ym[idx[order]]
    for k in range(2):
        xdata = peaks[idx[order],k]
        popt,pcov = curve_fit(_power_law, ydata, xdata)
        ax[i+1].plot(_power_law(ydata, *popt), ydata, '--', color=fit_colors[k], lw=1)

    # Black triangle on the bottom edge, pointing at Fpeak for this H.
    dy = y_span / 10
    f_peak = Fpeak[var_idx,j]
    xy = np.array([
        [f_peak, ylim[0] + dy],
        [np.exp(np.log(f_peak) - dx), ylim[0]],
        [np.exp(np.log(f_peak) + dx), ylim[0]]
    ])
    triangle = Polygon(xy, fc='k', ec='k')
    ax[i+1].add_patch(triangle)

    ax[i+1].set_xlim([cutoff, F[-1]])
    ax[i+1].set_xscale('log')
    # Final x limits: same as the top panel.
    ax[i+1].set_xlim(xticks[[0,-1]] + np.array([0,1]))
    ax[i+1].xaxis.set_major_locator(FixedLocator(xticks))
    ax[i+1].xaxis.set_minor_locator(NullLocator())
    # Only the lowest map shows the frequency tick labels.
    if i == rows-1:
        ax[i+1].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in xticks]))
    else:
        ax[i+1].xaxis.set_major_formatter(FixedFormatter([]))

    yticks = np.linspace(ylim[0], ylim[1], 3)
    # The y coordinate is a plain scalar (not a 1-element array from np.diff).
    ax[i+1].text(xticks[-1]-3, ylim[1] - 0.1*y_span, 'H = {:.1f} s'.format(H),
                 fontsize=fontsize+1, color='w', verticalalignment='top', horizontalalignment='right')
    ax[i+1].set_ylim(ylim)
    ax[i+1].yaxis.set_major_locator(FixedLocator(yticks))
    ax[i+1].yaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in yticks]))
    ax[i+1].set_ylabel(r'M [GW$\cdot$s$^2$]')

# The x-axis title of the stacked maps goes on the lowest one.
ax[-2].set_xlabel('Frequency [Hz]')

# Panel C: peak frequency against compensator H, skipping the first H value.
panel_c = ax[-1]
panel_c.plot(H_comp[1:], Fpeak[var_idx,1:], 'ko-', lw=1, markersize=4, markerfacecolor='w')
panel_c.set_xlim([0.5, 6.5])
panel_c.set_ylim([3, 16])
panel_c.set_yticks(np.r_[5 : 20 : 5])
panel_c.set_xlabel('Compensator H [s]')
panel_c.set_ylabel('Frequency peak [Hz]')
panel_c.grid(which='major', axis='y', ls=':', lw=0.5, color=[.6,.6,.6])

# Hide the right and top spines of every axes.
for a in ax:
    for side in ('right', 'top'):
        a.spines[side].set_visible(False)

# Panel letters: (index into ax, letter, vertical shift passed to ScaledTranslation).
for ax_idx, letter, y_shift in ((0, 'A', 0), (-1, 'C', 0), (1, 'B', 0.03)):
    trans = mtransforms.ScaledTranslation(-0.5, y_shift, fig.dpi_scale_trans)
    ax[ax_idx].text(0.0, 1.0, letter, transform=ax[ax_idx].transAxes + trans,
                    fontsize=10, va='bottom')

fig.savefig(f'spectra_compensators_{var_names[var_idx]}.pdf')