In [None]:
import os
import re
import sys
import glob
import pickle
import numpy as np
from scipy.fft import fft, fftfreq
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FixedLocator, NullLocator, FixedFormatter
import seaborn as sns

# Shared typography and line settings used by every figure of this notebook.
fontsize = 8
lw = 0.75

matplotlib.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': fontsize,
    'axes.linewidth': 0.75,
    'axes.labelsize': fontsize,
    'xtick.labelsize': fontsize,
    'ytick.labelsize': fontsize,
    'xtick.major.width': lw,
    'xtick.major.size': 3,
    'ytick.major.width': lw,
    'ytick.major.size': 3,
    'ytick.minor.width': lw,
    'ytick.minor.size': 1.5,
})

%matplotlib inline

if '..' not in sys.path:
    sys.path.append('..')
from dlml.data import load_data_files, load_data_areas

In [None]:
def make_axes(rows, cols, x_offset, y_offset, x_space, y_space, squeeze=True):
    """Create a rows x cols grid of equally sized axes on the current figure.

    Parameters
    ----------
    rows, cols : int
        Grid size.
    x_offset, y_offset : sequence of float
        Figure-fraction margins. Element 0 is the left (bottom) margin; the sum of
        all elements is the total horizontal (vertical) margin, presumably
        [left, right] / [bottom, top].
    x_space, y_space : float
        Figure-fraction gap between neighbouring columns / rows.
    squeeze : bool
        If True, drop the singleton dimensions of the returned grid.

    Returns
    -------
    list of lists of Axes, top row first, with the right and top spines hidden.
    With squeeze=True a single Axes (1x1 grid) or a flat list (1 row or 1 column)
    is returned instead.
    """
    width = (1 - np.sum(x_offset) - x_space * (cols - 1)) / cols
    height = (1 - np.sum(y_offset) - y_space * (rows - 1)) / rows

    grid = []
    # iterate from the highest row position down so that grid[0] is the top row
    for row_pos in reversed(range(rows)):
        bottom = y_offset[0] + (height + y_space) * row_pos
        axes_row = []
        for col_pos in range(cols):
            left = x_offset[0] + (width + x_space) * col_pos
            new_ax = plt.axes([left, bottom, width, height])
            new_ax.spines['right'].set_visible(False)
            new_ax.spines['top'].set_visible(False)
            axes_row.append(new_ax)
        grid.append(axes_row)

    if not squeeze:
        return grid
    if rows == 1 and cols == 1:
        return grid[0][0]
    if rows == 1:
        return grid[0]
    if cols == 1:
        return [axes_row[0] for axes_row in grid]
    return grid

In [None]:
def plot_correlations(R, p, R_ctrl, p_ctrl, edges, idx, ax, sort_F=1.0, vmin=None, vmax=None,
                      p_thresh=0.05):
    """Draw group-averaged correlation maps of a trained and a control network.

    For every group of indices in ``idx`` the matrices ``R[group]`` and
    ``R_ctrl[group]`` are averaged over axis 0. The columns of each average are
    sorted by their value in the row whose edge is closest to ``sort_F``. The
    results are drawn with ``pcolormesh`` in column 0 (``R``) and column 1
    (``R_ctrl``) of the corresponding row of ``ax``, with one colour bar per row.

    Parameters
    ----------
    R, R_ctrl : numpy array
        Correlations of the trained / control network, indexed as ``R[group]``.
        Each group average is presumably (n_bands, n_filters) with
        n_bands == len(edges) - 1 -- TODO confirm.
    p, p_ctrl : numpy array or None
        P-values with the same shape as R / R_ctrl. Entries with p > p_thresh are
        set to 0 (on a copy) before averaging. None disables the masking.
    edges : numpy array
        Frequency band edges; must be positive because a log y-scale is used.
    idx : sequence of index arrays
        One group per row of ``ax``.
    ax : 2-D numpy array of Axes
        Shape (len(idx), cols) with cols >= 2.
    sort_F : float
        Value (compared against ``edges``) selecting the row used to sort filters.
    vmin, vmax : float or None
        Colour limits. A missing vmin/vmax is taken from the averaged trained-network
        correlations; if both are missing the limits are made symmetric about 0.
    p_thresh : float
        Significance threshold used for the masking (default 0.05).

    Returns
    -------
    (vmin, vmax) : the colour limits that were used.

    Raises
    ------
    ValueError
        If ax does not have one row per group or has fewer than two columns.
    """
    rows, cols = ax.shape
    if rows != len(idx):
        raise ValueError('Number of rows of ax does not match len(idx)')
    if cols < 2:
        raise ValueError('ax must have at least two columns (trained and control network)')

    # zero out the non-significant correlations
    if p is not None:
        R = R.copy()
        R[p > p_thresh] = 0
    if p_ctrl is not None:
        R_ctrl = R_ctrl.copy()
        R_ctrl[p_ctrl > p_thresh] = 0
    R_mean = [R[jdx].mean(axis=0) for jdx in idx]
    R_ctrl_mean = [R_ctrl[jdx].mean(axis=0) for jdx in idx]

    # sort the columns (filters) by their correlation in the row closest to sort_F
    edge = np.abs(edges - sort_F).argmin()
    R_mean = [r[:, np.argsort(r[edge, :])] for r in R_mean]
    R_ctrl_mean = [r[:, np.argsort(r[edge, :])] for r in R_ctrl_mean]

    make_symmetric = False
    if vmin is None:
        vmin = min([r.min() for r in R_mean])
        make_symmetric = True
    if vmax is None:
        vmax = max([r.max() for r in R_mean])
        if make_symmetric:
            if vmax > np.abs(vmin):
                vmin = -vmax
            else:
                vmax = -vmin
    print(f'Color bar bounds: ({vmin:.2f},{vmax:.2f}).')
    ticks = np.linspace(vmin, vmax, 7)
    ticklabels = [f'{tick:.2f}' for tick in ticks]

    cmap = plt.get_cmap('bwr')
    # centre of each band
    y = edges[:-1] + np.diff(edges) / 2
    for i in range(rows):
        for j, R_plot in enumerate((R_mean[i], R_ctrl_mean[i])):
            x = np.arange(R_plot.shape[-1])
            im = ax[i][j].pcolormesh(x, y, R_plot, vmin=vmin, vmax=vmax, shading='auto', cmap=cmap)
            for side in 'right', 'top':
                ax[i][j].spines[side].set_visible(False)
            ax[i][j].set_xticks(np.linspace(0, x[-1], 3, dtype=np.int32))
        cbar = plt.colorbar(im, fraction=0.1, shrink=1, aspect=20, label='Correlation',
                            orientation='vertical', ax=ax[i][-1], ticks=ticks)
        cbar.ax.set_yticklabels(ticklabels, fontsize=fontsize-1)

    for i in range(rows):
        for j in range(cols):
            ax[i][j].set_ylim(edges[[0, -2]])
            ax[i][j].set_yscale('log')
            # only the first column keeps its frequency tick labels
            if j > 0:
                ax[i][j].set_yticklabels([])
            # only the bottom row keeps its filter tick labels
            if i < rows - 1:
                ax[i][j].set_xticklabels([])
    for bottom_ax in ax[-1]:
        bottom_ax.set_xlabel('Filter #')

    return vmin, vmax

In [None]:
# Per-data-set containers, keyed by set name (e.g. 'training' / 'test' and 'var_G2_G3').
X = {}
y = {}
Xf = {}
group_index = {}
n_mom_groups = {}

In [None]:
# Pre-processed traces and spectra are cached in `data_file`; set force = True to rebuild the cache.
# data_file = 'traces_hist_spectra.npz'
data_file = 'traces_hist_spectra_comp_grid.npz'
force = False
# the cache file name determines which simulation set it is built from
set_name = 'test' if data_file == 'traces_hist_spectra.npz' else 'training'

if not os.path.isfile(data_file) or force:
    ########## simulations with random inertia in area 1 ##########
    data_dir = '../data/IEEE39/converted_from_PowerFactory/all_stoch_loads/var_H_area_1'
    data_files = sorted(glob.glob(data_dir + os.path.sep + f'*_{set_name}_set.h5'))
    var_names = ['Vd_bus3']
    generators_areas_map = [['G02', 'G03'], ['G04', 'G05', 'G06', 'G07'], ['G08', 'G09', 'G10'], ['G01']]
    generators_Pnom = {'G01': 10e9, 'G02': 700e6, 'G03': 800e6, 'G04': 800e6, 'G05': 300e6,
                       'G06': 800e6, 'G07': 700e6, 'G08': 700e6, 'G09': 1000e6, 'G10': 1000e6}
    area_measure = 'momentum'
    ret = load_data_areas({set_name: data_files}, var_names,
                          generators_areas_map[:1],
                          generators_Pnom,
                          area_measure,
                          trial_dur=60,
                          max_block_size=1000,
                          use_tf=False,
                          add_omega_ref=True,
                          use_fft=False)
    t = ret[0]
    X_raw = ret[1][set_name]
    y[set_name] = ret[2][set_name]
    # one array of trial indices per distinct momentum value
    group_index[set_name] = [np.where(y[set_name] == mom)[0] for mom in np.unique(y[set_name])]
    n_mom_groups[set_name] = len(group_index[set_name])
    # z-score over axes 1 and 2; keepdims=True keeps the statistics aligned with axis 0
    # (presumably one entry per variable in var_names -- TODO confirm), so the normalisation
    # stays correct when more than one variable is loaded
    X_mean = X_raw.mean(axis=(1, 2), keepdims=True)
    X_std = X_raw.std(axis=(1, 2), keepdims=True)
    X[set_name] = ((X_raw - X_mean) / X_std).squeeze()
    y[set_name] = y[set_name].squeeze()

    dt = np.diff(t[:2])[0]
    N_samples = t.size
    # single-sided amplitude spectrum of every trial
    Xf[set_name] = 2.0 / N_samples * np.abs(fft(X[set_name])[:, :N_samples // 2])
    F = fftfreq(N_samples, dt)[:N_samples // 2]

    ########## simulations on a grid of inertia values of two generators ##########
    data_dir = '../data/IEEE39/converted_from_PowerFactory/all_stoch_loads/low_momentum_test_var_G1_G2'
    # keep only files named *_<integer>.h5; N is the number of distinct integers
    numbered_file = re.compile(r'_([0-9]+)\.h5$')
    data_files = [f for f in sorted(glob.glob(data_dir + os.path.sep + '*.h5')) if numbered_file.search(f)]
    if not data_files:
        raise FileNotFoundError(f'No files matching *_<number>.h5 in {data_dir}')
    N = len({int(numbered_file.search(f).group(1)) for f in data_files})
    M = len(data_files) // N
    # presumably the sorted files form M momentum groups of N consecutive files -- TODO confirm
    group_index['var_G2_G3'] = np.reshape(np.arange(len(data_files)), (M, N))

    ret = load_data_files(data_files,
                          var_names,
                          generators_areas_map[:1],
                          generators_Pnom,
                          area_measure)

    # drop the last sample along axis 2 (presumably time, to match N_samples -- TODO confirm)
    X_raw = ret[1][:, :, :-1]
    # normalise with the statistics of the first set
    X['var_G2_G3'] = ((X_raw - X_mean) / X_std).squeeze()
    y['var_G2_G3'] = ret[2].squeeze()
    Xf['var_G2_G3'] = 2.0 / N_samples * np.abs(fft(X['var_G2_G3'])[:, :N_samples // 2])
    n_mom_groups['var_G2_G3'] = len(group_index['var_G2_G3'])

    np.savez_compressed(data_file, t=t, X=X, y=y, Xf=Xf, F=F,
                        group_index=group_index, n_mom_groups=n_mom_groups)
else:
    # the dictionaries are stored as 0-d object arrays, hence allow_pickle=True and .item();
    # only load cache files produced by this notebook (pickle can execute arbitrary code)
    with np.load(data_file, allow_pickle=True) as data:
        t = data['t']
        X = data['X'].item()
        y = data['y'].item()
        F = data['F']
        Xf = data['Xf'].item()
        group_index = data['group_index'].item()
        n_mom_groups = data['n_mom_groups'].item()

## Figure 2: momentum grid, voltage traces, amplitude distributions and spectra

In [None]:
fig, ax = plt.subplot_mosaic(
    '''
    AAAAAAAA
    BBBBBCCC
    DDDDEEEE
    ''',
    figsize=(3.5,4)
)
cmap = plt.get_cmap('Set1')
N_div_cmap = 10
div_cmap = plt.get_cmap('bwr', N_div_cmap)

########### Panel A: momentum over a grid of inertia constants ###########
def momentum(H, S, fn):
    """Return 2 * H.S / fn * 1e-3.

    H holds inertia constants [s] along its last axis (extra leading axes are kept),
    S the matching generator ratings and fn the nominal frequency [Hz].
    """
    return 2 * H @ S / fn * 1e-3

grid_shape = 11, 11
h_G1_0, h_G2_0 = 4.33, 4.47
h_G1 = h_G1_0 + np.linspace(-1, 1, grid_shape[0])
h_G2 = h_G2_0 + np.linspace(-1, 1, grid_shape[1])
H_G1, H_G2 = np.meshgrid(h_G1, h_G2)

# presumably the ratings in MVA of the two varied generators (cf. generators_Pnom) -- TODO confirm
S = np.array([700, 800])
fn = 60
# momentum at every grid point at once: the last axis holds (H_G1, H_G2)
M_grid = momentum(np.stack([H_G1, H_G2], axis=-1), S, fn)

gray_cmap = plt.get_cmap('gray')
cont = ax['A'].contourf(H_G1, H_G2, M_grid, levels=100, cmap=gray_cmap)
cbar = plt.colorbar(cont, ax=ax['A'])
white = [1,1,1]
magenta = [1,0,1]
green = [0,1,0]
yellow = [1,1,0]
blue = [0,.5,1]
red = [1,.333,.333]
orange = [1, .5, 0]
gray = [.6, .6, .6]
# every other grid point
ax['A'].scatter(H_G1[::2, ::2], H_G2[::2, ::2], s=5, c='w', edgecolors='k', lw=0.5, marker='o')
# squares on the two 2x2 corner patches, with the same div_cmap indices as the curves of panel E
for i in range(2):
    for j in range(2):
        ax['A'].plot(H_G1[i,j], H_G2[i,j], 's', markersize=4, lw=1,
                     color=div_cmap(i*2+j), markerfacecolor='none')
        ax['A'].plot(H_G1[-1-i,-1-j], H_G2[-1-i,-1-j], 's', markersize=4, lw=1,
                     color=div_cmap(N_div_cmap-1-i*2-j), markerfacecolor='none')
# circles on the diagonal, with the same Set1 indices as the curves of panels B-D
for i in range(0, grid_shape[0], 2):
    ax['A'].plot(H_G1[i,i], H_G2[i,i], 'o', color=cmap(i//2), markerfacecolor='none', markersize=3, lw=1)
# crosses at the centres of the two corner patches
ax['A'].plot(H_G1[:2,:2].mean(), H_G2[:2,:2].mean(), 'x', color=magenta, markersize=4, markeredgewidth=1)
ax['A'].plot(H_G1[-2:,-2:].mean(), H_G2[-2:,-2:].mean(), 'x', color=magenta, markersize=4, markeredgewidth=1)
ax['A'].set_xlabel(r'$H_{G_1}$ [s]')
ax['A'].set_ylabel(r'$H_{G_2}$ [s]')
ax['A'].set_xlim([h_G1_0 - 1.1, h_G1_0 + 1.1])
ax['A'].set_ylim([h_G2_0 - 1.1, h_G2_0 + 1.1])
ax['A'].set_xticks([h_G1_0 - 1, h_G1_0, h_G1_0 + 1])
ax['A'].set_yticks([h_G2_0 - 1, h_G2_0, h_G2_0 + 1])
# 2*H*S/fn has units of VA*s^2 (same unit as the momentum axes of Figure 4)
cbar.set_label(r'Momentum [GW$\cdot$s$^2$]')
cbar.set_ticks(np.r_[0.17 : 0.28 : 0.02])

########### Panels B-E: traces, distributions and spectra ###########
tend = 10
jdx, = np.where(t < tend)
for i, idx in enumerate(group_index[set_name]):
    # representative trial: the second trial of this group (the first if the group has only one);
    # indexing through idx keeps trace and label in this group even if groups are not contiguous
    trial = idx[1] if len(idx) > 1 else idx[0]
    n, edges = np.histogram(X[set_name][idx,:], bins=50, density=True)
    ax['B'].plot(t[jdx], X[set_name][trial, jdx], lw=1, color=cmap(i))
    ax['C'].plot(n, edges[1:], lw=1, color=cmap(i))
    ax['D'].plot(F, 20 * np.log10(Xf[set_name][idx,:].mean(axis=0)), lw=1,
                 color=cmap(i), label=r'{:.3f} GW$\cdot$s$^2$'.format(y[set_name][trial]))

for i, idx in enumerate(group_index['var_G2_G3']):
    # first four groups: low end of the diverging colormap; the others: its high end
    color = div_cmap(i) if i < 4 else div_cmap(N_div_cmap - (i - 3))
    ax['E'].plot(F, 20 * np.log10(Xf['var_G2_G3'][idx,:].mean(axis=0)), lw=1, color=color)

for key in 'BC':
    ax[key].grid(which='major', axis='y', lw=0.5, ls=':', color=[.6,.6,.6])
for key in 'DE':
    ax[key].grid(which='major', axis='x', lw=0.5, ls=':', color=[.6,.6,.6])
for a in ax.values():
    for side in 'right','top':
        a.spines[side].set_visible(False)
for key in 'BC':
    ax[key].set_ylim([-3.5, 3.5])
    ax[key].set_yticks(np.r_[-3 : 4 : 1.5])
for key in 'DE':
    ax[key].set_ylim([-55, -10])
    ax[key].set_yticks(np.r_[-50 : -9 : 10])
    ax[key].set_xscale('log')
    ax[key].set_xlabel('Frequency [Hz]')
    ticks = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20])
    ax[key].set_xlim(ticks[[0,-1]] + np.array([0,1]))
    ax[key].xaxis.set_major_locator(FixedLocator(ticks))
    ax[key].xaxis.set_minor_locator(NullLocator())
    ax[key].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))
ax['C'].set_yticklabels([])
ax['E'].set_yticklabels([])

ax['B'].set_xlabel('Time [s]')
ax['B'].set_ylabel('Norm. V')
ax['C'].set_xlabel('Distr.')
ax['D'].set_ylabel('Power [dB]')

ticks = np.r_[0 : 10.5 : 2]
ax['B'].set_xlim(ticks[[0,-1]])
ax['B'].xaxis.set_major_locator(FixedLocator(ticks))
ax['B'].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))

ticks = np.r_[0 : 0.51 : 0.25]
ax['C'].set_xlim(ticks[[0,-1]])
ax['C'].xaxis.set_major_locator(FixedLocator(ticks))
ax['C'].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))

# panel letters, offset in inches from the top-left corner of each axes
trans = mtransforms.ScaledTranslation(-0.45, -0.05, fig.dpi_scale_trans)
for label in 'ABD':
    ax[label].text(0.0, 1.0, label, transform=ax[label].transAxes + trans, fontsize=10, va='bottom')
trans = mtransforms.ScaledTranslation(-0.15, -0.05, fig.dpi_scale_trans)
for label in 'CE':
    ax[label].text(0.0, 1.0, label, transform=ax[label].transAxes + trans, fontsize=10, va='bottom')

fig.tight_layout(pad=0)
fig.savefig('traces_hist_spectra_comp_grid.pdf')

## Figure 3: low vs. high momentum, network training and filter–frequency correlations

In [None]:
experiment_ID = '98475b819ecb4d569646d7e1467d7c9c'
# experiment_ID = 'ed79ae2784274401a9dba5f5ccee98d8'
experiments_path = '../experiments/neural_network/'
experiment_dir = os.path.join(experiments_path, experiment_ID)
# NOTE: pickle can execute arbitrary code -- only load experiment files produced by this project
with open(os.path.join(experiment_dir, 'history.pkl'), 'rb') as fid:
    history = pickle.load(fid)
with open(os.path.join(experiment_dir, 'test_results.pkl'), 'rb') as fid:
    test_results = pickle.load(fid)
N_bands = 20
# number of trials used to compute the correlations of each experiment
if experiment_ID == '98475b819ecb4d569646d7e1467d7c9c':
    N_trials = 4000
elif experiment_ID == 'ed79ae2784274401a9dba5f5ccee98d8':
    N_trials = 1000
else:
    raise Exception(f'Unknown number of trials for experiment {experiment_ID[:6]}')
correlations_file = f'correlations_{experiment_ID[:6]}_{N_bands}-bands_64-filters_' + \
    f'36-neurons_{N_trials}-trials_8-butter_Vd_bus3_pool_1_3.npz'
# bind every array of the file to a global of the same name; the Figure 3 plotting cell
# needs at least R, p, R_ctrl and p_ctrl, and idx / edges are copied below
with np.load(os.path.join(experiment_dir, correlations_file)) as correlations:
    for key in correlations.files:
        globals()[key] = correlations[key]
# keep copies under distinct names: the plotting cell reuses `idx` and `edges`
corr_idx = idx
corr_edges = edges

In [None]:
key = 'var_G2_G3'

fig, ax = plt.subplot_mosaic(
    '''
    AAABBCCCDDD
    EEEFFFFGGGG
    HHHIIIIJJJJ
    ''',
    figsize=(3.5 * 1.75, 5.01)
)

cmap = plt.get_cmap('tab10')
green, magenta = cmap(2), cmap(6)
# the first four momentum groups of the grid are drawn in green, the remaining ones in magenta
idx_green = np.concatenate(group_index[key][:4])
idx_magenta = np.concatenate(group_index[key][4:])

############# Panel A #############
# tend is defined in the Figure 2 cell
jdx, = np.where(t < tend)
ax['A'].plot(t[jdx], X[key][:5, jdx].T, color=green, lw=0.5)
ax['A'].plot(t[jdx], X[key][-10:-5, jdx].T, color=magenta, lw=0.5)

############# Panels B, E, H: distributions and mean spectra #############
for idx, color in ((idx_green, green), (idx_magenta, magenta)):
    n, edges = np.histogram(X[key][idx,:], bins=50, density=True)
    ax['B'].plot(n, edges[1:], lw=1, color=color)
    power_dB = 20 * np.log10(Xf[key][idx,:].mean(axis=0))
    for label in 'EH':
        # frequency on the y axis, as in the correlation panels
        ax[label].plot(power_dB, F, color=color, lw=1)
ax['B'].set_xlim([0, 0.5])
ticks = np.r_[0 : 0.51 : 0.25]
ax['B'].xaxis.set_major_locator(FixedLocator(ticks))
ax['B'].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))

############# Panel C #############
ax['C'].plot(history['loss'], 'k', lw=1, label='Tr. loss')
ax['C'].plot(history['val_loss'], 'r', lw=1, label='Val. loss')
ax['C'].legend(loc='upper right', frameon=False, fontsize=fontsize)

############# Panel D #############
# assumes exactly two target values (the x ticks are placed at the violin positions 0 and 1)
target_values = np.unique(test_results['y_test'])
y_test = np.squeeze(test_results['y_test'])
y_pred = np.squeeze(test_results['y_prediction'])
df = pd.DataFrame(data={'test': y_test, 'pred': y_pred})
sns.violinplot(x='test', y='pred', data=df, cut=0, inner='quartile',
               color=[.6,.6,.6], ax=ax['D'], linewidth=1)
ax['D'].xaxis.set_major_locator(FixedLocator([0, 1]))
ax['D'].xaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in target_values]))
ax['D'].yaxis.set_major_locator(FixedLocator(target_values))
ax['D'].yaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in target_values]))

############# Panels F, G, I, J ########
plot_correlations(R, p, R_ctrl, p_ctrl, corr_edges, corr_idx,
                  ax=np.array([[ax['F'], ax['G']], [ax['I'], ax['J']]]))
ax['F'].set_title('Trained network', fontsize=fontsize+1)
ax['G'].set_title('Untrained network', fontsize=fontsize+1)

############# Axes formatting #############
for label in 'AB':
    ax[label].set_ylim([-3.5, 3.5])
    ax[label].set_yticks(np.r_[-3 : 3.1])

for label in 'ABEH':
    ax[label].grid(which='major', axis='y', lw=0.5, ls=':', color=[.6,.6,.6])

for a in ax.values():
    for side in 'right','top':
        a.spines[side].set_visible(False)

for label in 'EFGHIJ':
    ax[label].set_yscale('log')

for label in 'EH':
    ax[label].yaxis.tick_right()
    ax[label].spines['right'].set_visible(True)
    ax[label].spines['left'].set_visible(False)
    # reversed limits: power increases towards the left
    ax[label].set_xlim([-10, -55])
    ax[label].set_xticks(np.r_[-50 : -9 : 10])

for label in 'EGHJ':
    ax[label].set_yticklabels([])
for label in 'EFG':
    ax[label].set_xticklabels([])
ticks = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20])
for label in 'EFGHIJ':
    ax[label].set_ylim(ticks[[0,-1]])
    ax[label].yaxis.set_major_locator(FixedLocator(ticks))
    ax[label].yaxis.set_minor_locator(NullLocator())
    if label in 'FI':
        ax[label].yaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))

ax['A'].set_xlabel('Time [s]')
ax['A'].set_ylabel('Norm. V')
ax['B'].set_xlabel('Distr.')
ax['C'].set_xlabel('Epoch')
ax['C'].set_ylabel('Loss')
# momentum 2*H*S/fn has units of VA*s^2 (same unit as the momentum axes of Figure 4)
ax['D'].set_xlabel(r'Exact M [GW$\cdot$s$^2$]')
ax['D'].set_ylabel(r'Predicted M [GW$\cdot$s$^2$]')
ax['H'].set_xlabel('Power [dB]')
ax['F'].set_ylabel('Frequency [Hz]')
ax['I'].set_ylabel('Frequency [Hz]')
ax['I'].set_xlabel('Filter #')
ax['J'].set_xlabel('Filter #')

# panel letters, offset in inches from the top-left corner of each axes
trans = mtransforms.ScaledTranslation(-0.45, -0.05, fig.dpi_scale_trans)
for label in 'ABCDEFGHIJ':
    ax[label].text(0.0, 1.0, label, transform=ax[label].transAxes + trans, fontsize=10, va='bottom')

fig.tight_layout(pad=0)
fig.savefig('low_high_momentum.pdf')

## Figure 4: momentum estimation from band-limited inputs

In [None]:
# Bind every array stored in the file to a global of the same name.
# NOTE(review): the code below indexes Xf and y with the entries of group_index, which the
# dict-valued X / y / Xf / group_index from the data-loading cell do not support, so the file
# presumably provides its own y, Xf, group_index (and possibly F, bands, scores, y_pred),
# replacing the earlier globals -- re-run the data-loading cell before Figures 2 and 3 again.
with np.load(os.path.join(experiments_path, experiment_ID, 'stopband_momentum_estimation.npz')) as data:
    for key in data.files:
        globals()[key] = data[key]
N_bands = len(bands)
# mean amplitude spectrum of every momentum group and the half-width of its 95% confidence interval
Xfm = np.zeros((len(group_index), F.size))
Xfci = np.zeros((len(group_index), F.size))
for i, idx in enumerate(group_index):
    Xfm[i] = Xf[idx].mean(axis=0)
    Xfci[i] = 1.96 * Xf[idx].std(axis=0) / np.sqrt(idx.size)

In [None]:
fig, ax = plt.subplot_mosaic(
    '''
    AAAA.
    BBBB.
    ''',
    figsize=(3.5,4)
)

cmap2 = plt.get_cmap('Set1')
# colours of the two momentum groups: tab10 entries 2 and 6 (`green` and `magenta` in the Figure 3 cell)
tab10 = plt.get_cmap('tab10')
group_colors = (tab10(2), tab10(6))

# mean and s.d. per momentum group of the exact momentum and of every row of predictions
y_m = [y[idx].mean() for idx in group_index]
y_s = [y[idx].std() for idx in group_index]
y_pred_m = np.array([[pred[jdx].mean() for jdx in group_index] for pred in y_pred])
y_pred_s = np.array([[pred[jdx].std() for jdx in group_index] for pred in y_pred])

# Row 0 of y_pred / scores is the broadband estimate, row i >= 1 the estimate for bands[i-1].
# To show only the first bands, set add_band = True and lower last_band (<= N_bands + 1);
# the value is then written into the output file name.
# TODO: choose last_band for the add_band = True case.
add_band = False
last_band = N_bands + 1

############# Panel A: predicted vs exact momentum #############
for i in range(last_band):
    color = cmap2(i)
    lbl = 'Broadband' if i == 0 else f'[{bands[i-1][0]:g}-{bands[i-1][1]:g}] Hz'
    ax['A'].plot(y_m, y_pred_m[i], color=color, lw=1, label=lbl)
    for j in range(len(group_index)):
        # +/- 1 s.d. of the prediction (vertical bar) and of the exact momentum (horizontal bar)
        ax['A'].plot(y_m[j] + np.zeros(2),
                     y_pred_m[i,j] + y_pred_s[i,j] * np.array([-1,1]),
                     color=color, lw=1)
        ax['A'].plot(y_m[j] + y_s[j] * np.array([-1,1]),
                     y_pred_m[i,j] + np.zeros(2),
                     color=color, lw=1)
        ax['A'].plot(y_m[j], y_pred_m[i,j], 'o', color=color,
                     markerfacecolor='w', markersize=4, markeredgewidth=1)
for side in 'right','top':
    ax['A'].spines[side].set_visible(False)
ax['A'].legend(loc='center left', bbox_to_anchor=[0.875, 0.5], frameon=False, fontsize=fontsize-1)
ax['A'].set_xlabel(r'Exact M [GW$\cdot$s$^2$]')
ax['A'].set_ylabel(r'Predicted M [GW$\cdot$s$^2$]')
ax['A'].set_xlim(y_m + np.diff(y_m) * np.array([-1/10, 1/5]))
ax['A'].set_ylim(y_m + np.diff(y_m) / 3 * np.array([-1,1]))
ax['A'].xaxis.set_major_locator(FixedLocator(y_m))
ax['A'].xaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in y_m]))
ax['A'].yaxis.set_major_locator(FixedLocator(y_m))
ax['A'].yaxis.set_major_formatter(FixedFormatter([f'{tick:.2f}' for tick in y_m]))

############# Panel B: spectra and R^2 score per band #############
axr = ax['B'].twinx()
for i, (m, col) in enumerate(zip(Xfm, group_colors)):
    ax['B'].plot(F, 20*np.log10(m), color=col,
                 label=r'M = {:.2f} GW$\cdot$s$^2$'.format(y[group_index[i]].mean()))
ax['B'].legend(loc='lower left', frameon=False, fontsize=fontsize-1)
# broadband score: a horizontal line across the full axes width, independent of the x limits set below
axr.axhline(scores[0], ls='--', color=cmap2(0), lw=2)
# score of each band, drawn over the frequency extent of that band
for i, band in enumerate(bands[:last_band-1]):
    axr.axvline(band[0], color=[.6,.6,.6], ls=':', lw=0.5)
    axr.plot(band, scores[i+1] + np.zeros(2), color=cmap2(i+1), lw=2)
ax['B'].set_xlabel('Frequency [Hz]')
ax['B'].set_ylabel('Power [dB]')
axr.set_ylim((-1.1, 1.1))
axr.set_ylabel(r'R$^2$ score')
ax['B'].set_xscale('log')

ticks = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20])
ax['B'].set_xlim(ticks[[0,-1]] + np.array([0,1]))
ax['B'].xaxis.set_major_locator(FixedLocator(ticks))
ax['B'].xaxis.set_minor_locator(NullLocator())
ax['B'].xaxis.set_major_formatter(FixedFormatter([f'{tick:g}' for tick in ticks]))

# panel letters, offset in inches from the top-left corner of each axes
trans = mtransforms.ScaledTranslation(-0.45, -0.05, fig.dpi_scale_trans)
for label in 'AB':
    ax[label].text(0.0, 1.0, label, transform=ax[label].transAxes + trans, fontsize=10, va='bottom')

fig.tight_layout(pad=0.2)
if add_band:
    fig.savefig(f'stopband_momentum_estimation_{experiment_ID}_{last_band}.pdf')
else:
    fig.savefig(f'stopband_momentum_estimation_{experiment_ID[:6]}.pdf')