### 3. Which metrics correlate with perception? (multiple audio samples)

In [1]:
################## IMPORT LIBRARIES ##################
import sys
import importlib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import soundfile as sf
from IPython.display import Audio, display, HTML
import torch
from os.path import join as pjoin


In [None]:
################## IMPORT MY MODULES ##################
# Make the project sources importable; the relative path assumes the notebook
# is executed from its own directory (TODO confirm the working directory).
sys.path.append('../src')

# Project-local modules (resolved through the path appended above).
import helpers as hlp
import evaluation
import dataset as ds
import trainer
import models
import loss_mel, loss_stft, loss_waveform, loss_embedd

# Re-import every project module so that edits made to the .py files after the
# kernel started are picked up when this cell is re-run.
# NOTE: the last reload is the final expression of the cell, so the notebook
# echoes the repr of that module as the cell output.
importlib.reload(evaluation)
importlib.reload(hlp)
importlib.reload(ds)
importlib.reload(trainer)
importlib.reload(models)
importlib.reload(loss_mel)
importlib.reload(loss_stft)
importlib.reload(loss_waveform)
importlib.reload(loss_embedd)


In [None]:
################## LOAD CSV WITH EVALUATION RESULTS ##################
# Location of the CSV holding the evaluation results; kept in one named
# constant so the path is easy to find and change.
RESULTS_CSV = "/home/ubuntu/Data/RESULTS-reverb-match-cond-u-net/runs-exp-20-05-2024/100924_compare_percept_5000testset.csv"

df = pd.read_csv(RESULTS_CSV)

# Preview the first rows only. The previous `head(-20)` selected every row
# except the last 20, i.e. almost the whole table, not a short preview.
display(df.head(20))

In [4]:
################## GIVE MORE CONCISE TAGS FOR EACH CATEGORY ##################

def impove_categories_tags(df):
    """Return a copy of ``df`` with a concise, ordered ``tag`` column, sorted by it.

    The tag is built as ``"<short label> -> <compared>"`` where

    * the short label is ``label`` without everything up to (and including)
      its first underscore, with ``"checkpoint"`` abbreviated to ``"ch"``;
    * ``"target"`` / ``"prediction"`` are abbreviated to ``"tar"`` / ``"pred"``.

    ``tag`` is stored as an ordered categorical that follows ``custom_order``,
    so sorting and plotting keep a fixed order. Tags that are not listed in
    ``custom_order`` become missing values (NaN) and are sorted last.

    The input frame is left untouched (no helper columns are added to it).
    The function keeps its historical spelling because callers use that name.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the string columns ``label`` and ``compared``.

    Returns
    -------
    pandas.DataFrame
        New frame with the ``tag`` column added and rows sorted by ``tag``.
    """
    # create a custom order of the files so that the plots have similar order as before
    custom_order = [
        "oracle -> tar:anecho",
        "oracle -> tar:tarclone",
        "oracle -> tar:content",
        "oracle -> tar:style",
        "anecho+fins -> pred:tar",
        "dfnet+fins -> pred:tar",
        "wpe+fins -> pred:tar",
        "c_wunet_stft+wave_0.8_0.2_chbest -> pred:tar",
        "c_wunet_logmel+wave_0.8_0.2_chbest -> pred:tar",
        "c_wunet_logmel_1_chbest -> pred:tar",
        "c_wunet_stft_1_chbest -> pred:tar",
        "c_wunet_stft_1_ch50 -> pred:tar",
        "c_wunet_stft_1_ch10 -> pred:tar",
        "c_wunet_stft_1_ch0 -> pred:tar",
    ]

    # Drop the prefix before the first "_" (labels without "_" stay as they
    # are: the split then yields a single element, which `.str[-1]` returns).
    short_label = (
        df["label"]
        .str.split("_", n=1)
        .str[-1]
        .str.replace("checkpoint", "ch", regex=False)
    )
    tag = (
        (short_label + " -> " + df["compared"])
        .str.replace("target", "tar", regex=False)
        .str.replace("prediction", "pred", regex=False)
    )

    tagged = (
        df
        # a stale helper column from an earlier run is not part of the result
        .drop(columns=["short_label"], errors="ignore")
        .assign(tag=pd.Categorical(tag, categories=custom_order, ordered=True))
    )
    # The stable sort makes the order of rows sharing a tag deterministic.
    return tagged.sort_values("compared").sort_values("tag", kind="stable")

# Replace the results table with its tagged version, sorted by tag.
df=impove_categories_tags(df)

In [None]:
################## ADD COLUMN TO KNOW IF THE SAMPLE IS REV2DRY OR DRY2REV ##################

# divide into re-reverberation, de-reverberation
# Load the basic experiment configuration (pjoin with a single argument
# returns the path unchanged).
config=hlp.load_config(pjoin("/home/ubuntu/joanna/reverb-match-cond-u-net/config/basic.yaml"))

def get_reverb_ind(config, df, split):
    """Write a ``rev_delta`` class into ``df`` for every row, based on its ``idx``.

    A ``ds.DatasetReverbTransfer`` is built from ``config`` and queried with
    ``get_idx_with_rt60diff`` for three ranges; rows whose ``idx`` is among
    the returned indices receive the matching label:

    * ``(-3, -0.3)``  -> ``"dry2rev"``
    * ``(0.3, 3)``    -> ``"rev2dry"``
    * ``(-0.3, 0.3)`` -> ``"smalldiff"``

    Labels are written in that order, so an index returned for several ranges
    ends up with the last matching label. Rows matching no range keep a
    missing value.

    Side effects: ``config["p_noise"]`` is set to 0 and ``config["split"]`` to
    ``split`` on the caller's object, and ``df`` is modified in place (the
    same object is returned).
    """
    config["p_noise"] = 0
    config["split"] = split
    dataset = ds.DatasetReverbTransfer(config)

    # (label, lower bound, upper bound) handed to get_idx_with_rt60diff;
    # presumably bounds on an rt60 difference in seconds - TODO confirm.
    rt60diff_ranges = (
        ("dry2rev", -3, -0.3),
        ("rev2dry", 0.3, 3),
        ("smalldiff", -0.3, 0.3),
    )

    # Query the dataset for all ranges first, then label the rows.
    indices_per_label = [
        (label, dataset.get_idx_with_rt60diff(low, high))
        for label, low, high in rt60diff_ranges
    ]
    for label, sample_indices in indices_per_label:
        row_selected = df["idx"].isin(sample_indices)
        df.loc[row_selected, "rev_delta"] = label

    return df

# Label every row of the results table using the "test" split of the dataset
# (this also sets config["p_noise"] and config["split"] as a side effect).
df=get_reverb_ind(config, df, "test")


# Preview the first rows, now including the rev_delta column.
display(df.head(10))


In [None]:
################### FOR EACH SAMPLE INDEX, CHECK IF FOR THAT SAMPLE THE PERCEPTUAL EXPECTATIONS ARE CONFIRMED ##################

# Columns that identify a row of the results table instead of holding a metric.
_NON_METRIC_COLUMNS = ("label", "idx", "compared", "short_label", "tag", "rev_delta", "dataset")

# Tag of the row comparing the target with its clone.
_TARCLONE_TAG = "oracle -> tar:tarclone"

# Perceptual expectations checked for every metric, in output-column order.
_CONDITION_KEYS = (
    "sTargetClone > sContent",
    "sTargetClone most similar",
    "sContent < sPred_anecho_fins & our best models",
    "sPred_anecho_fins > all our models",
    "sPred_anecho_fins > our logmel model",
    "sPred_anecho_fins > our half-trained models",
    "sPred_anecho_fins > sPred_dfnet_fins & sPred_wpe_fins",
    "stft+wave_0.8_0.2_checkpointbest > logmel+wave_0.8_0.2_checkpointbest",
    "earlier checkpoints < later checkpoints",
)


def _lower_is_better(metric_name):
    """Return True for loss/distance metrics, False for similarities, None if unknown.

    The direction is read from the column name: ``"L_"`` or ``"D_"`` marks a
    loss/difference (lower is better), ``"S_"`` a similarity (higher is better).
    """
    if "L_" in metric_name or "D_" in metric_name:
        return True
    if "S_" in metric_name:
        return False
    return None


def _beats(better, worse, lower_is_better, strict):
    """Return True when every value of ``better`` is better than ``worse``.

    Either argument may be a scalar or a pandas Series (compared elementwise).
    With ``strict=False`` equal values also count as "better".
    """
    if lower_is_better:
        outcome = better < worse if strict else better <= worse
    else:
        outcome = better > worse if strict else better >= worse
    return bool(np.all(outcome))


def check_requirements(df, idx):
    """Check, for one sample index, which metrics confirm the perceptual expectations.

    All rows of ``df`` with ``df["idx"] == idx`` are used (one row per compared
    signal pair, identified by ``tag``). Every column not listed in
    ``_NON_METRIC_COLUMNS`` is treated as a metric. Metrics whose name contains
    neither ``"L_"``/``"D_"`` nor ``"S_"`` have no known direction and stay
    ``False`` for every expectation.

    Two defects of the earlier implementation are fixed here:

    * the rows ``tar:anecho`` and ``tar:style`` are now really excluded (the
      old comparison used tags without the ``"oracle -> "`` prefix and never
      matched); this only influences "sTargetClone most similar";
    * the "stft+wave ... > logmel+wave ..." expectation now compares the two
      ``+wave_0.8_0.2`` models instead of the ``stft_1`` / ``logmel_1`` models.

    Parameters
    ----------
    df : pandas.DataFrame
        Results table with the columns ``idx``, ``tag``, ``rev_delta`` and the
        metric columns.
    idx : scalar
        Sample index to evaluate.

    Returns
    -------
    dict
        Keys ``"metric"``, ``"idx"``, ``"rev_delta"`` and one key per entry of
        ``_CONDITION_KEYS``; every value is a list with one element per metric
        (booleans for the expectations).

    Raises
    ------
    ValueError
        If a row that must be unique for this index is missing or duplicated
        (raised by ``Series.item()``).
    """
    metric_names = [name for name in df.columns if name not in _NON_METRIC_COLUMNS]
    n_metrics = len(metric_names)

    # data for one index (several audios per index: anecho, content, target, processed, ...)
    df_idx = df[df["idx"] == idx]
    rev_delta = df_idx["rev_delta"].iloc[0]

    conditions_dict = {
        "metric": list(metric_names),
        "idx": [idx] * n_metrics,
        "rev_delta": [rev_delta] * n_metrics,
    }
    for key in _CONDITION_KEYS:
        conditions_dict[key] = [False] * n_metrics

    # Plain-string view of the tags so substring matching also works for a
    # categorical column and for missing tags.
    tags = df_idx["tag"].astype(str)

    # filtering: these two oracle rows take no part in the checks
    is_excluded = (
        tags.str.contains("tar:anecho", regex=False)
        | tags.str.contains("tar:style", regex=False)
    )
    df_idx = df_idx[~is_excluded]
    tags = tags[~is_excluded]

    def rows_tagged(fragment):
        """Rows of this index whose tag contains ``fragment`` literally."""
        return df_idx[tags.str.contains(fragment, regex=False)]

    row_content = rows_tagged("content")
    row_targetclone = rows_tagged("tarclone")

    row_anechofins = rows_tagged("anecho+fins")
    rows_fins_notanecho = pd.concat([rows_tagged("dfnet+fins"), rows_tagged("wpe+fins")])

    row_ourlogmelwave = rows_tagged("logmel+wave_0.8_0.2_chbest")
    row_ourstftwave = rows_tagged("stft+wave_0.8_0.2_chbest")
    rows_ourbests = pd.concat([
        row_ourlogmelwave,
        row_ourstftwave,
        rows_tagged("logmel_1_chbest"),
        rows_tagged("stft_1_chbest"),
    ])

    row_ourstft0 = rows_tagged("ch0")
    row_ourstft10 = rows_tagged("ch10")
    row_ourstft50 = rows_tagged("ch50")

    rows_ourhalftrained = pd.concat([row_ourstft0, row_ourstft10])
    rows_ourmodels = rows_tagged("_ch")
    rows_bestpreds = pd.concat([rows_ourbests, row_anechofins])

    for position, col in enumerate(metric_names):
        lower = _lower_is_better(col)
        if lower is None:
            continue  # unknown direction: nothing can be confirmed

        content = row_content[col].item()
        tarclone = row_targetclone[col].item()
        anechofins = row_anechofins[col].item()
        checkpoint0 = row_ourstft0[col].item()
        checkpoint10 = row_ourstft10[col].item()
        checkpoint50 = row_ourstft50[col].item()

        # Best-ranked signal for this metric: the lowest value for
        # losses/distances, the highest value for similarities.
        ranked_tags = df_idx.sort_values(by=col)["tag"]
        best_tag = ranked_tags.iloc[0] if lower else ranked_tags.iloc[-1]

        confirmed = {
            # sTargetClone is more similar to sTarget than sContent (change in reverb)
            "sTargetClone > sContent":
                _beats(tarclone, content, lower, strict=False),
            # sTargetClone is most similar to sTarget from all the signals
            "sTargetClone most similar":
                best_tag == _TARCLONE_TAG,
            # transformation helps: anecho+fins and our well-trained models beat sContent
            "sContent < sPred_anecho_fins & our best models":
                _beats(rows_bestpreds[col], content, lower, strict=True),
            # sPred_anecho_fins is better than all of our models
            "sPred_anecho_fins > all our models":
                _beats(anechofins, rows_ourmodels[col], lower, strict=False),
            # sPred_anecho_fins is at least better than the logmel+wave model
            "sPred_anecho_fins > our logmel model":
                _beats(anechofins, row_ourlogmelwave[col], lower, strict=False),
            # sPred_anecho_fins is at least better than our half-trained models
            "sPred_anecho_fins > our half-trained models":
                _beats(anechofins, rows_ourhalftrained[col], lower, strict=False),
            # sPred_anecho_fins is better than sPred_dfnet_fins and sPred_wpe_fins
            "sPred_anecho_fins > sPred_dfnet_fins & sPred_wpe_fins":
                _beats(anechofins, rows_fins_notanecho[col], lower, strict=True),
            # stft+wave model is better than the logmel+wave model
            "stft+wave_0.8_0.2_checkpointbest > logmel+wave_0.8_0.2_checkpointbest":
                _beats(row_ourstftwave[col].item(), row_ourlogmelwave[col].item(), lower, strict=True),
            # checkpoint0 < checkpoint10 < checkpoint50
            "earlier checkpoints < later checkpoints":
                _beats(checkpoint10, checkpoint0, lower, strict=True)
                and _beats(checkpoint50, checkpoint10, lower, strict=True),
        }
        for key, is_confirmed in confirmed.items():
            if is_confirmed:
                conditions_dict[key][position] = True

    return conditions_dict
    # end of function "check_requirements(df,idx)"


# get unique indices
unique_idx = df.idx.unique()

# Column names of the result table, in display order; they must match the
# keys of the dictionary returned by check_requirements.
keys = [
    "metric",
    "idx",
    "rev_delta",
    "sTargetClone > sContent",
    "sTargetClone most similar",
    "sContent < sPred_anecho_fins & our best models",
    "sPred_anecho_fins > all our models",
    "sPred_anecho_fins > our logmel model",
    "sPred_anecho_fins > our half-trained models",
    "sPred_anecho_fins > sPred_dfnet_fins & sPred_wpe_fins",
    "stft+wave_0.8_0.2_checkpointbest > logmel+wave_0.8_0.2_checkpointbest",
    "earlier checkpoints < later checkpoints",
]

# One growing list per column. Extending in place keeps the loop linear;
# rebuilding every list with `+` on each iteration copied all previously
# collected results again and again.
combined_dict = {key: [] for key in keys}
for idx in unique_idx:
    tmp_dict = check_requirements(df, idx)
    for key in keys:
        combined_dict[key].extend(tmp_dict[key])

# go from dictionary to df (one row per metric and sample index)
df_table = pd.DataFrame(combined_dict)
display(df_table.head(10))

In [None]:
####################  CREATE A TABLE INDICATING WHICH PERCEPTUAL OBSERVATIONS ARE CONFIRMED BY EACH METRIC ################### 
# CASE 1 -> ALLDATA (all samples; no filtering on rev_delta)

# keep only the metric name and the boolean expectation columns
needed_columns = df_table.drop(columns=["idx", "rev_delta"])

# Per metric: share of samples (whole percent, rounded down) for which each
# expectation is confirmed. Integer arithmetic is used on purpose: with
# floats, int((29 / 100) * 100) evaluates to 28.
df_table_g = (
    needed_columns
    .groupby("metric")
    .agg(lambda x: 100 * int(x.sum()) // len(x))
    .reset_index()
)
df_table_g = df_table_g.style.background_gradient(cmap='viridis', vmin=0, vmax=100)
display(df_table_g)

In [None]:
####################  CREATE A TABLE INDICATING WHICH PERCEPTUAL OBSERVATIONS ARE CONFIRMED BY EACH METRIC ################### 
# CASE 2 -> REV2DRY (only samples labelled "rev2dry")

df_table_rev2dry = df_table[df_table["rev_delta"] == "rev2dry"]

# keep only the metric name and the boolean expectation columns
needed_columns = df_table_rev2dry.drop(columns=["idx", "rev_delta"])

# Per metric: share of samples (whole percent, rounded down) for which each
# expectation is confirmed. Integer arithmetic is used on purpose: with
# floats, int((29 / 100) * 100) evaluates to 28.
df_table_g = (
    needed_columns
    .groupby("metric")
    .agg(lambda x: 100 * int(x.sum()) // len(x))
    .reset_index()
)
df_table_g = df_table_g.style.background_gradient(cmap='viridis', vmin=0, vmax=100)
display(df_table_g)

In [None]:
####################  CREATE A TABLE INDICATING WHICH PERCEPTUAL OBSERVATIONS ARE CONFIRMED BY EACH METRIC ################### 
# CASE 3 -> DRY2REV (only samples labelled "dry2rev")

df_table_dry2rev = df_table[df_table["rev_delta"] == "dry2rev"]

# keep only the metric name and the boolean expectation columns
needed_columns = df_table_dry2rev.drop(columns=["idx", "rev_delta"])

# Per metric: share of samples (whole percent, rounded down) for which each
# expectation is confirmed. Integer arithmetic is used on purpose: with
# floats, int((29 / 100) * 100) evaluates to 28.
df_table_g = (
    needed_columns
    .groupby("metric")
    .agg(lambda x: 100 * int(x.sum()) // len(x))
    .reset_index()
)
df_table_g = df_table_g.style.background_gradient(cmap='viridis', vmin=0, vmax=100)
display(df_table_g)

In [None]:
####################  CREATE A TABLE INDICATING WHICH PERCEPTUAL OBSERVATIONS ARE CONFIRMED BY EACH METRIC ################### 
# CASE 4 -> SMALLDIFF (only samples labelled "smalldiff")

df_table_smalldiff = df_table[df_table["rev_delta"] == "smalldiff"]

# keep only the metric name and the boolean expectation columns
needed_columns = df_table_smalldiff.drop(columns=["idx", "rev_delta"])

# Per metric: share of samples (whole percent, rounded down) for which each
# expectation is confirmed. Integer arithmetic is used on purpose: with
# floats, int((29 / 100) * 100) evaluates to 28.
df_table_g = (
    needed_columns
    .groupby("metric")
    .agg(lambda x: 100 * int(x.sum()) // len(x))
    .reset_index()
)
df_table_g = df_table_g.style.background_gradient(cmap='viridis', vmin=0, vmax=100)
display(df_table_g)

### Conclusions: 

- Results for multiple audio samples are similar to the results for 1 audio sample
- Most of the metrics "see" the difference in reverb (even when the difference in rt60 is small)
- Metrics get confused when dealing with artificially processed sounds. For example, if the model was trained with the STFT loss, stft_loss(processed, target) is smaller than stft_loss(targetclone, target), which does not reflect perception.
- When the rt60 delta is small, there is not much benefit in processing the data (see column "sContent < sPred_anecho_fins & our best models"), but when the rt60 delta is large, processing helps more.
- When the rt60 delta is small, there is less benefit in training longer (see the difference between early and late checkpoints).
- For a large rt60 delta, the perceptual observation that the baseline sPred_anecho_fins is better than our models is sometimes confirmed by some metrics (see the 3 middle columns). It is confirmed more often than when the rt60 delta is small.

In [None]:
# PLOT AVERAGE METRICS

# create df with average metric across data points (one row per tag);
# observed=False states the grouping behaviour for a categorical tag
# explicitly: categories without rows are kept.
numeric_df = df.select_dtypes(include='number')
df_g = df.groupby("tag", observed=False)[numeric_df.columns].mean().reset_index()

# Only real metrics get a panel. Enumerating the filtered list (instead of all
# columns) keeps the subplot grid free of holes for skipped columns such as
# "tag" and "idx".
metric_columns = [
    column for column in df_g.columns
    if column not in ["label", "idx", "compared", "short_label", "tag"]
]
N_metrics = len(metric_columns)
N_rows = int(np.ceil(N_metrics / 3))  # three panels per row

# plot average metrics
plt.figure(figsize=(20, 60))
for i, column in enumerate(metric_columns):
    # red for losses/distances (lower is better), green otherwise
    metriccolor = "red" if "D_" in column or "L_" in column else "green"
    plt.subplot(N_rows, 3, i + 1)
    bars = plt.bar(df_g["tag"], df_g[column], color=metriccolor)
    plt.xticks(rotation=60, ha='right', fontsize="9")
    plt.title("Metric mean: " + column)
    # write the value on top of every bar
    for j, bar in enumerate(bars):
        bar_height = bar.get_height()
        plt.text(j, bar_height, "%.2f" % bar_height,
                 horizontalalignment='center', verticalalignment='bottom', fontsize="9")
plt.tight_layout()
plt.show()

