In [92]:
import numpy as np
import pandas as pd
import seaborn as sns

from tqdm import tqdm

from src.plot_utils import dafx_from_name
from sklearn.feature_selection import mutual_info_regression
from sklearn.cross_decomposition import CCA

In [93]:
# Apply seaborn's "white" style globally (set_theme is the preferred name for set).
sns.set_theme(style="white")

In [94]:
# Common prefix of every evaluation directory used by this notebook.
_EVAL_ROOT = "/home/kieran/Level5ProjectAudioVAE/src/evaluation"

# Inputs: "<fx>_data.npy" and "<fx>_settings.npy" are loaded from DATA_DIR.
DATA_DIR = f"{_EVAL_ROOT}/data/param_extraction"
# Directory reserved for figures.
FIG_DIR = f"{_EVAL_ROOT}/figures/param_cca_mmi"
# Outputs: per-effect CSVs and the exported summary tables.
OUT_DATA_DIR = f"{_EVAL_ROOT}/data/param_cca_mmi"

# Effects analysed, in processing order.
DAFX = [
    "overdrive",
    "combo",
    "delay",
    "ambience",
    "dynamics",
    "ringmod",
]

# Short effect key -> plugin name handed to dafx_from_name().
DAFX_TO_MDA_NAME = dict(
    delay="mda Delay",
    combo="mda Combo",
    overdrive="mda Overdrive",
    ambience="mda Ambience",
    dynamics="mda Dynamics",
    ringmod="mda RingMod",
)

In [95]:
# Per-effect results of the CCA -> mutual-information analysis.
dataframes = {}   # fx key -> DataFrame of MI scores (rows: parameters, columns: CCA components)
param_sizes = {}  # fx key -> number of entries in the effect's idx_to_param_map

# Number of canonical components the embeddings are projected onto.
N_CCA_COMPONENTS = 2
# mutual_info_regression is a randomised estimator; pinning its seed makes a
# Restart & Run All reproduce the same CSVs and tables.
MI_RANDOM_STATE = 0

for fx in tqdm(DAFX):
    emb = np.load(f"{DATA_DIR}/{fx}_data.npy")
    params = np.load(f"{DATA_DIR}/{fx}_settings.npy")  # indexed as 2-D below: one column per parameter

    idx_param_map = dafx_from_name(DAFX_TO_MDA_NAME[fx]).idx_to_param_map
    param_sizes[fx] = len(idx_param_map)
    n_params = params.shape[1]

    # Fit CCA between embeddings and parameter settings, then project the embeddings.
    cca = CCA(n_components=N_CCA_COMPONENTS)
    cca.fit(emb, params)
    X_c = cca.transform(emb)

    # One row per parameter: MI between each canonical component and that parameter.
    mi_matrix = np.array([
        mutual_info_regression(X_c, params[:, i], random_state=MI_RANDOM_STATE)
        for i in range(n_params)
    ])
    df = pd.DataFrame(mi_matrix, index=[idx_param_map[i] for i in range(n_params)])

    dataframes[fx] = df

    df.to_csv(f"{OUT_DATA_DIR}/{fx}.csv")

100%|██████████| 6/6 [00:04<00:00,  1.43it/s]


In [96]:
# Effects ordered from fewest to most parameters; sorted() is stable, so ties
# keep the order in which they were inserted into param_sizes.
sorted_fx = sorted(param_sizes, key=lambda fx_name: param_sizes[fx_name])

In [97]:
# Show the resulting effect order.
sorted_fx

['overdrive', 'ringmod', 'ambience', 'combo', 'delay', 'dynamics']

In [98]:
# Build one two-column (Param, MMI) frame per effect and track the largest
# score seen, which later serves as the upper bound of the colour gradient.
subframes = []
vmax = 0

for fx in sorted_fx:
    # Best MI score of each parameter across the CCA components, highest first.
    # reset_index() turns the parameter names into a regular column.
    ranked = (
        dataframes[fx]
        .max(axis='columns')
        .sort_values(ascending=False)
        .reset_index()
    )

    vmax = max(ranked[0].max(), vmax)

    # "mda Overdrive" -> "Overdrive": top level of the column header.
    fx_name = DAFX_TO_MDA_NAME[fx].split()[-1]
    ranked.columns = pd.MultiIndex.from_tuples(
        [(fx_name, 'Param'), (fx_name, 'MMI')]
    )

    subframes.append(ranked)

In [100]:
# Place the per-effect (Param, MMI) column pairs side by side, aligned on row position.
full_df = pd.concat(subframes, axis="columns")

In [101]:
# Column keys of every 'MMI' sub-column; used as the subset for the colour gradient.
# NOTE(review): `slice` shadows the Python builtin; the name is kept because the
# styling cell references it.
slice = [(fx_name, kind) for fx_name, kind in full_df.columns if kind == 'MMI']

In [102]:
# Show the selected 'MMI' column keys.
slice

[('Overdrive', 'MMI'),
 ('RingMod', 'MMI'),
 ('Ambience', 'MMI'),
 ('Combo', 'MMI'),
 ('Delay', 'MMI'),
 ('Dynamics', 'MMI')]

In [103]:
# Column-wise mean of the numeric columns only (the 'Param' name columns are skipped).
res = full_df.mean(axis=0, numeric_only=True)

In [104]:
# Flat row matching full_df's column layout: a 'mean' label followed by the
# mean value, repeated once per entry of res.
mean_data = [cell for col_mean in res for cell in ('mean', col_mean)]

In [105]:
# Append the summary row to the bottom of the table.
# NOTE: this rebinds full_df to an extended copy of itself, so re-running the
# cell appends one more 'mean' row each time.
mean_row = pd.DataFrame([mean_data], columns=full_df.columns)
full_df = pd.concat([full_df, mean_row], ignore_index=True)

In [106]:
# Save the combined table as CSV.
# NOTE(review): this path resolves to a file *beside* OUT_DATA_DIR
# ("..._full.csv"), whereas the other exports are written inside the
# directory — confirm this is intentional.
full_csv_path = f"{OUT_DATA_DIR}_full.csv"
full_df.to_csv(full_csv_path)

In [107]:
# Light-to-blue colormap used to shade the MMI cells.
blue_cm = sns.light_palette("blue", as_cmap=True)

# Table-level CSS rules: horizontal padding for data cells, header font.
padding_css = [{'selector': 'td',
                'props': [('padding-left', '15px'), ('padding-right', '15px')]}]
header_css = [{'selector': '.col_heading',
               'props': [('font-family', 'Calibri')]}]
css = padding_css + header_css

# Cell-level styling: gradient on the MMI columns (scaled 0..vmax), centred
# Calibri text, two-decimal numbers, blanks for missing values, no row index.
cell_properties = {'text-align': 'center',
                   'font-family': "Calibri"}
styler = (
    full_df.style
    .background_gradient(cmap=blue_cm, subset=slice, vmin=0, vmax=vmax)
    .set_properties(**cell_properties)
    .format(na_rep='', precision=2)
    .hide(axis='index')
)

# set_table_styles() modifies `styler` in place and returns it, so the styler
# keeps the table CSS after this call. The index is already hidden above.
html = styler.set_table_styles(css).to_html(border=True)

In [108]:
# html = styler.to_html(index=False)

In [109]:
# Write the rendered HTML table into the output directory.
html_path = f'{OUT_DATA_DIR}/table.html'
with open(html_path, 'w') as html_file:
    html_file.write(html)

In [110]:
# LaTeX export.
# The HTML-oriented `styler` is not reused here: Styler.to_latex() interprets
# every table-style selector and every CSS (attribute, value) pair as a LaTeX
# (command, options) pair, so the 'td' / '.col_heading' rules and the raw CSS
# properties would be written into the .tex file as bogus commands.
# A LaTeX-specific styler is built instead:
#   * no CSS table styles are attached;
#   * convert_css=True translates the gradient colours into LaTeX colour
#     commands (the document needs \usepackage[table]{xcolor}) and drops CSS
#     that has no LaTeX equivalent;
#   * escape="latex" protects '_' and other special characters in the
#     parameter names.
latex_styler = (
    full_df.style
    .background_gradient(cmap=blue_cm, subset=slice, vmin=0, vmax=vmax)
    .format(na_rep='', precision=2, escape="latex")
    .hide(axis='index')
)
latex = latex_styler.to_latex(convert_css=True)

In [111]:
# Write the LaTeX table into the output directory.
tex_path = f'{OUT_DATA_DIR}/table.tex'
with open(tex_path, 'w') as tex_file:
    tex_file.write(latex)

In [112]:
# Excel export of the styled table (table CSS and index hiding are re-applied first).
excel_path = f"{OUT_DATA_DIR}/full.xlsx"
styler.set_table_styles(css).hide(axis='index').to_excel(excel_path)

In [113]:
# Render the styled table inline in the notebook.
display(styler)

Overdrive,Overdrive,RingMod,RingMod,Ambience,Ambience,Combo,Combo,Delay,Delay,Dynamics,Dynamics
Param,MMI,Param,MMI,Param,MMI,Param,MMI,Param,MMI,Param,MMI
muffle,2.34,freq_hz,0.74,mix,0.53,hpf_freq,1.23,fb_mix,0.49,output_db,0.12
drive,0.96,feedback,0.54,size_m,0.28,hpf_reso,0.46,l_delay_ms,0.34,gate_thr_db,0.12
output_db,0.0,fine_hz,0.0,hf_damp,0.08,drive_s_h,0.02,r_delay,0.06,mix,0.1
,,,,output_db,0.0,bias,0.01,feedback,0.03,release_ms,0.08
,,,,,,output_db,0.0,fb_tone_lo_hi,0.01,gate_rel_ms,0.04
,,,,,,,,,,limiter_db,0.03
,,,,,,,,,,ratio,0.02
,,,,,,,,,,thresh_db,0.02
,,,,,,,,,,gate_att_s,0.01
,,,,,,,,,,attack_s,0.0
