In [None]:
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np

from music_nextsim_tuning import read_and_scale

In [None]:
# --- Configuration -----------------------------------------------------------
# Cut-off date handed to read_and_scale (its meaning is defined in
# music_nextsim_tuning).
max_date = datetime(2007, 4, 12)
# A descriptor is marked "included" when its median scaled RGPS value lies
# strictly inside (-max_std, +max_std).
max_std = 1.5

rdir = './music_matrix/cfg01_m20'  # NOTE(review): unused in the cells shown here
odir = '../tuning_paper_figures'   # figure output directory; savefig does not create it
idir = './music_matrix/cfg01_m20'  # input directory read by read_and_scale

# --- Per-descriptor statistics of the scaled RGPS inputs ---------------------
param_names, inp_rgps, inp_ftrs = read_and_scale(idir, max_date)

# Central value is the median (drawn as \hat\mu_R); spread is np.std (ddof=0).
# On a pandas DataFrame np.std dispatches to DataFrame.std and returns a
# Series; convert to a plain array because indexing a Series with an integer
# position array (below) is deprecated in pandas.
rgps_avg = np.median(inp_rgps, axis=0)
rgps_std = np.asarray(np.std(inp_rgps, axis=0))

# Order descriptors by their median so the bar chart reads left to right.
arg_idx = np.argsort(rgps_avg)
rgps_avg = rgps_avg[arg_idx]
rgps_std = rgps_std[arg_idx]
columns = inp_rgps.columns[arg_idx]
gpi = (rgps_avg > -max_std) & (rgps_avg < max_std)

# --- Figure: median +/- std per descriptor, included ones highlighted --------
fig, ax = plt.subplots(figsize=(10, 2))
ax.errorbar(columns, rgps_avg, rgps_std, capsize=2, fmt='none', color='k')
ax.bar(columns, rgps_avg, label='Excluded descriptors')
ax.bar(columns[gpi], rgps_avg[gpi], color='green', label='Included descriptors')
# Categorical bars sit at x = 0..n-1; span the reference lines over exactly
# that region (0..n started mid-bar and padded an empty slot on the right).
ax.hlines([-max_std, 0, max_std], -0.5, len(columns) - 0.5, color='k', alpha=0.5)
ax.tick_params(axis='x', labelrotation=90)
# Raw string: the LaTeX backslashes are not Python escape sequences.
ax.set_ylabel(r'$\hat\mu_R$ and $\hat\sigma_R$')
ax.legend()
ofile = f'{odir}/fig00_filter_features_mean_std_{idir.split("/")[-1]}.png'
fig.savefig(ofile, dpi=150, bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
# --- Load per-descriptor autoencoder RMSE ------------------------------------
# Derive the path from `odir` (set in the configuration cell) so the two cannot
# drift apart; the resulting string equals '../tuning_paper_figures/...'.
rmse_file = f'{odir}/filter_autoencoder.npz'
# NOTE(review): allow_pickle=True is presumably required because one of the
# arrays (likely good_columns1) was saved with object dtype; confirm. Unpickling
# can execute arbitrary code, so only load files produced by this project.
with np.load(rmse_file, allow_pickle=True) as f:
    rmse_n1 = f['rmse_n1']            # RMSE values later labelled 'neXtSIM'
    rmse_r1 = f['rmse_r1']            # RMSE values used to rank descriptors
    good_columns1 = f['good_columns1']  # descriptor names aligned with the RMSE arrays

In [None]:
# --- Figure: autoencoder RMSE per descriptor ---------------------------------
# The `n_included` descriptors with the lowest rmse_r1 are drawn as included.
# NOTE(review): this count is hard-coded, not derived from `rmse_threshold`;
# confirm the two stay consistent if the inputs change.
n_included = 36
rmse_threshold = 0.95

# Sort once and reuse the reordered arrays for every bar series.
sort_idx = np.argsort(rmse_r1)
sorted_columns = good_columns1[sort_idx]
sorted_rmse_r1 = rmse_r1[sort_idx]
sorted_rmse_n1 = rmse_n1[sort_idx]

fig, ax = plt.subplots(figsize=(10, 2))
ax.bar(sorted_columns, sorted_rmse_r1, alpha=0.7, label='Excluded descriptors')
ax.bar(sorted_columns[:n_included], sorted_rmse_r1[:n_included],
       alpha=0.7, label='Included descriptors', color='green')
ax.bar(sorted_columns, sorted_rmse_n1, alpha=0.7, label='neXtSIM')
# Categorical bars sit at x = 0..n-1; span the threshold line over exactly
# that region (0..n started mid-bar and padded an empty slot on the right).
ax.hlines([rmse_threshold], -0.5, len(sort_idx) - 0.5, color='k', alpha=0.5)
ax.set_ylabel('Autoencoder RMSE')
ax.tick_params(axis='x', labelrotation=90)
ax.legend()
ofile = f'{odir}/fig00_filter_features_autoencoder_{idir.split("/")[-1]}.png'
fig.savefig(ofile, dpi=150, bbox_inches='tight', pad_inches=0.1)
plt.show()
print(ofile)
