In [90]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.feature_selection import mutual_info_regression

from src.plot_utils import dafx_from_name

In [91]:
# Global seaborn/matplotlib look for every table/figure in this notebook.
# `sns.set` is only a deprecated alias of `sns.set_theme`.
sns.set_theme(style="white")

In [92]:
# Project root. Defaults to the original author's checkout; set the
# AUDIOVAE_PROJECT_ROOT environment variable to run the notebook elsewhere.
PROJECT_ROOT = os.environ.get("AUDIOVAE_PROJECT_ROOT", "/home/kieran/Level5ProjectAudioVAE")
EVAL_DIR = os.path.join(PROJECT_ROOT, "src", "evaluation")

# Inputs: `{fx}_data.npy` and `{fx}_settings.npy` for each effect in DAFX.
DATA_DIR = os.path.join(EVAL_DIR, "data", "param_extraction")
FIG_DIR = os.path.join(EVAL_DIR, "figures", "param_mmi")
# Outputs: per-effect MI CSVs plus the combined HTML / LaTeX / Excel tables.
OUT_DATA_DIR = os.path.join(EVAL_DIR, "data", "param_mmi")

# Effects to analyse. The order is kept deliberately: it breaks ties when the
# effects are later sorted by parameter count.
DAFX = ["overdrive", "combo", "delay", "ambience", "dynamics", "ringmod"]

# Short effect key -> plugin name understood by `dafx_from_name`.
DAFX_TO_MDA_NAME = {
    "delay": "mda Delay",
    "combo": "mda Combo",
    "overdrive": "mda Overdrive",
    "ambience": "mda Ambience",
    "dynamics": "mda Dynamics",
    "ringmod": "mda RingMod",
}

assert set(DAFX) <= DAFX_TO_MDA_NAME.keys(), "every effect in DAFX needs an MDA plugin name"

In [93]:
# Mutual information between each effect parameter and each column of the
# embedding array. mutual_info_regression adds random noise to continuous
# features, so its random_state is pinned to make the MI tables reproducible.
MI_RANDOM_STATE = 0

dataframes = {}   # fx -> DataFrame indexed by parameter name, one column per `emb` column
param_sizes = {}  # fx -> number of entries in the effect's idx_to_param_map

os.makedirs(OUT_DATA_DIR, exist_ok=True)

for fx in DAFX:
    emb = np.load(f"{DATA_DIR}/{fx}_data.npy")
    params = np.load(f"{DATA_DIR}/{fx}_settings.npy")

    idx_param_map = dafx_from_name(DAFX_TO_MDA_NAME[fx]).idx_to_param_map
    param_sizes[fx] = len(idx_param_map)

    # Row i: MI of parameter i (column i of `params`) against every column of `emb`.
    mi_matrix = np.array([
        mutual_info_regression(emb, params[:, i], random_state=MI_RANDOM_STATE)
        for i in range(params.shape[1])
    ])
    df = pd.DataFrame(mi_matrix, index=[idx_param_map[i] for i in range(params.shape[1])])

    dataframes[fx] = df
    df.to_csv(f"{OUT_DATA_DIR}/{fx}.csv")

In [94]:
# Effect keys ordered by parameter count, smallest first (stable: ties keep DAFX order).
sorted_fx = sorted(param_sizes, key=lambda fx_key: param_sizes[fx_key])

In [208]:
# Effects ordered by parameter count, fewest first; this is the column order of the table.
sorted_fx

['overdrive', 'ringmod', 'ambience', 'combo', 'delay', 'dynamics']

In [295]:
# For every effect, collapse its (parameter x embedding column) MI frame to one
# "MMI" value per parameter: the largest MI that parameter reaches with any
# single embedding column. Parameters are listed in descending MMI order.
subframes = []
vmax = 0  # largest MMI over all effects -> shared colour-scale ceiling

for fx in sorted_fx:
    df = dataframes[fx]
    ranked = df.max(axis='columns').sort_values(ascending=False)
    sub_df = ranked.reset_index()  # columns: 'index' (param name) and 0 (MMI)

    vmax = max(sub_df[0].max(), vmax)

    # "mda Overdrive" -> "Overdrive" as the top-level column header.
    fx_name = DAFX_TO_MDA_NAME[fx].split()[-1]
    sub_df.columns = pd.MultiIndex.from_tuples([(fx_name, 'Param'), (fx_name, 'MMI')])

    subframes.append(sub_df)

In [296]:
# Per-parameter MMI table for the last entry of sorted_fx (one with the most parameters).
subframes[-1]

Unnamed: 0_level_0,Dynamics,Dynamics
Unnamed: 0_level_1,Param,MMI
0,gate_thr_db,0.272074
1,release_ms,0.120145
2,mix,0.084259
3,output_db,0.074751
4,gate_rel_ms,0.06604
5,thresh_db,0.030173
6,limiter_db,0.027245
7,ratio,0.024533
8,gate_att_s,0.020577
9,attack_s,0.017268


In [297]:
# Place the per-effect (Param, MMI) column pairs side by side. Effects with fewer
# parameters are padded with NaN rows, because the frames align on their row index.
full_df = pd.concat(subframes, axis="columns")

In [298]:
# MultiIndex keys of every MMI column; used as the gradient `subset` when styling.
# NOTE(review): this name shadows the built-in `slice` for the rest of the notebook.
slice = full_df.columns[full_df.columns.get_level_values(1) == 'MMI'].tolist()

In [299]:
# Column keys selected for the MMI colour gradient.
slice

[('Overdrive', 'MMI'),
 ('RingMod', 'MMI'),
 ('Ambience', 'MMI'),
 ('Combo', 'MMI'),
 ('Delay', 'MMI'),
 ('Dynamics', 'MMI')]

In [300]:
# Per-column mean of the numeric (MMI) columns; the string Param columns are excluded
# and NaN padding is skipped.
res = full_df.mean(numeric_only=True)

In [301]:
# Build a summary row that lines up with the alternating (Param, MMI) column pairs:
# a 'mean' label under each Param column, followed by that effect's mean MMI.
mean_data = [cell for mmi_mean in res for cell in ('mean', mmi_mean)]

In [302]:
# Append the summary row to the bottom of the table.
# NOTE(review): not idempotent — re-running this cell appends another 'mean' row.
mean_row = pd.DataFrame([mean_data], columns=full_df.columns)
full_df = pd.concat([full_df, mean_row], ignore_index=True)

In [303]:
# Save the combined table (including the mean row) inside OUT_DATA_DIR, next to the
# per-effect CSVs, as the `{OUT_DATA_DIR}/full.xlsx` export does. The previous path
# `f"{OUT_DATA_DIR}_full.csv"` lacked the '/' and wrote outside the directory.
full_df.to_csv(f"{OUT_DATA_DIR}/full.csv")

In [328]:
# Shade only the MMI columns, on one shared 0..vmax scale so colours are
# comparable across effects.
blue_cm = sns.light_palette("blue", as_cmap=True)
table_font = "Calibri"

styler = (
    full_df.style
    .background_gradient(cmap=blue_cm, subset=slice, vmin=0, vmax=vmax)
    .set_properties(**{'text-align': 'center', 'font-family': table_font})
    # NaN padding (effects with fewer parameters) renders as an empty cell.
    .format(na_rep='', precision=2)
    .hide(axis='index')
)

# Header cells use the same font as the body cells.
header_css = [{'selector': '.col_heading',
               'props': [('font-family', table_font)]}]

# Horizontal padding for body cells.
padding_css = [{'selector': 'td',
                'props': [('padding-left', '15px'), ('padding-right', '15px')]}]

css = padding_css + header_css

# Styler.to_html has no `border` option (unlike DataFrame.to_html). Unknown kwargs
# are only forwarded to the Jinja template, so `border=True` was silently ignored.
# Set the border as a <table> attribute instead.
html = styler.set_table_styles(css).to_html(table_attributes='border="1"')

In [329]:
# html = styler.to_html(index=False)

In [330]:
# Write the styled HTML table. The encoding is explicit so the result does not
# depend on the platform's default text encoding.
with open(f'{OUT_DATA_DIR}/table.html', 'w', encoding='utf-8') as f:
    f.write(html)

In [331]:
# LaTeX export of the same table.
# - The CSS table_styles (`td`, `.col_heading` selectors) mean nothing in LaTeX.
#   Styler.to_latex emits each one as a raw `\<selector><value>` command and wraps the
#   tabular in a table environment, so they are cleared here. NOTE: this mutates
#   `styler`; the Excel cell re-applies `css` before exporting.
# - Without convert_css=True, cell CSS (the gradient's background-color / color) is
#   emitted as invalid `\background-color...` commands. With it, these become
#   \cellcolor / \color, so the including document needs xcolor[table] / colortbl.
# - One centred column per visible column (the index is hidden). This mirrors the
#   HTML 'text-align: center' and replaces the hand-counted spec.
latex = styler.set_table_styles([]).hide(axis='index').to_latex(
    convert_css=True,
    column_format="c" * len(full_df.columns),
)

In [332]:
# Write the LaTeX table. The encoding is explicit so the result does not depend on
# the platform's default text encoding.
with open(f'{OUT_DATA_DIR}/table.tex', 'w', encoding='utf-8') as f:
    f.write(latex)

In [333]:
# Excel export of the styled table. Styler.to_excel writes the file and returns None,
# so the result is not bound to a name (the old `md = ...` always held None).
styler.set_table_styles(css).hide(axis='index').to_excel(f"{OUT_DATA_DIR}/full.xlsx")

In [334]:
# Scratch: width of a hand-written LaTeX column spec, "l" followed by twelve "r"s.
# TODO: remove once the LaTeX export derives its column_format from the data.
len("l" + "r" * 12)

13

In [335]:
# Scratch: print a 13-column, all-centred LaTeX column spec.
# TODO: remove once the LaTeX export derives its column_format from the data.
print("".join("c" for _ in range(13)))

ccccccccccccc
