### Compute metrics based on the reverberant (late) part of the audio signals

In [2]:
################## IMPORT LIBRARIES ##################
import sys
import importlib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import soundfile as sf
from IPython.display import Audio, display, HTML
import torch
from os.path import join as pjoin
import random

In [None]:
################## IMPORT MY MODULES ##################
sys.path.append('../src')

import helpers as hlp
import evaluation
import dataset as ds
import trainer
import models
import loss_mel, loss_stft, loss_waveform, loss_embedd

# Reload the project modules (same order as they were imported above) so that
# edits made in ../src are picked up without restarting the kernel.
for _module in (evaluation, hlp, ds, trainer, models,
                loss_mel, loss_stft, loss_waveform, loss_embedd):
    importlib.reload(_module)


#### 1. Retrieving early and late part of the ground truth signals

In [None]:
# RETRIEVING EARLY AND LATE PART OF THE GROUND TRUTH SIGNAL

# load the base configuration and switch it to the test split
config = hlp.load_config("/home/ubuntu/joanna/reverb-match-cond-u-net/config/basic.yaml")
config["split"] = "test"
dataset = ds.DatasetReverbTransfer(config)

# Fetch an "extended" data point: every ground-truth signal split into early and late parts,
# together with the matching room impulse responses. Seeds are fixed first for reproducibility.
hlp.init_random_seeds(0)
INDX = 46
sigs, rirs = dataset.get_item_test(INDX, truncate_rirs=True)

print(f"{sigs.keys()=}")
print(f"{rirs.keys()=}")

#### 1.1. Impulse responses 
- cut the initial silence (detected with a 20 dB threshold)
- scale so that the max peak = 1 
- We divide the full RIR into an early part and a late part: 
- *rir<sub>full</sub> = rir<sub>early</sub> + rir<sub>late</sub>* 
- *rir<sub>early</sub>*  -> only the first peak (i.e. the direct sound)
- *rir<sub>late</sub>*   ->  the rest of the RIR (all reflections)

Below, we have 3 RIRs: 
- r1 (rir of the content sound)
- r2 (rir of the style and target sound)
- r2b (cloned rir, same room as r2)

In [None]:
def plot_rirs(early, late, together,suptitle):
    """Plot the direct sound, the reflections and the full RIR side by side.

    Each panel shows samples 0-4000 with the y-axis fixed to [-0.1, 1] (the RIRs are
    peak-normalised to 1). Note: also sets the global matplotlib font size to 8.
    """
    plt.figure(figsize=(10,2))
    plt.rcParams.update({'font.size': 8})
    panels = ((early, "direct sound", {"color": "red"}),
              (late, "all reflections", {"color": "blue"}),
              (together, "full rir", {}))
    for position, (rir, panel_title, line_style) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(rir.T, **line_style)
        plt.xlim([0, 4000])
        plt.ylim([-0.1, 1])
        plt.title(panel_title)
    plt.suptitle(suptitle)
    plt.tight_layout()
    plt.show()

# one figure per RIR: (early key, late key, full key, figure title)
for early_key, late_key, full_key, title in (
        ("rirContent_early", "rirContent_late", "rirContent", "RIR of the content sound (r1)"),
        ("rirTarget_early", "rirTarget_late", "rirTarget", "RIR of the target sound (r2)"),
        ("rirTargetClone_early", "rirTargetClone_late", "rirTargetClone", "RIR of the cloned target sound (r2b)")):
    plot_rirs(rirs[early_key].numpy(), rirs[late_key].numpy(), rirs[full_key].numpy(), title)

#### 1.2. Audio signals

- We divide full reverberant signal into early part and late part
- full signal: *s<sub>full</sub>  = s<sub>early</sub>  + s<sub>late</sub>* 
- We scale the output so that max peak = 1 (scaling factor sc) -> the network expects a waveform in the range [-1,1]
- *sc * s<sub>full</sub>   = sc * s<sub>early</sub> + sc * s<sub>late</sub>* 


Below we have the following signals:
- s1r1 (sContent)
- s1r2 (sTarget)
- s1r2b (sTargetClone)

In [None]:
def plot_sigs(early, late, together, suptitle):
    """Plot the early part, the late part and the full signal side by side (y-axis in [-1, 1])."""
    plt.figure(figsize=(10,2))
    panels = ((early, "direct sound", {"color": "red"}),
              (late, "all reflections", {"color": "blue"}),
              (together, "full signal", {}))
    for position, (wave, panel_title, line_style) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(wave.T, **line_style)
        plt.ylim(-1, 1)
        plt.title(panel_title)
    plt.suptitle(suptitle)
    plt.tight_layout()
    plt.show()

def audio_3_sigs(early,late, together, rate=48000):
    """Display three audio players in one horizontal row (e.g. early part, late part, full signal).

    Args:
        early, late, together: waveforms accepted by ``IPython.display.Audio``.
        rate: sampling rate in Hz used for every player (default 48000, as before).
    """
    players = "".join(f"<div>{Audio(sig, rate=rate)._repr_html_()}</div>"
                      for sig in (early, late, together))
    # "justify-content" is the CSS property; the former bare "space-between;" was not a valid
    # declaration and was silently ignored by the browser.
    display(HTML(f'<div style="display: flex; justify-content: space-between;">{players}</div>'))

# for each signal: plot early/late/full and show the three audio players
for early_key, late_key, full_key, title in (
        ("sContent_early", "sContent_late", "sContent", "Content sound (s1r1)"),
        ("sTarget_early", "sTarget_late", "sTarget", "Target sound (s1r2)"),
        ("sTargetClone_early", "sTargetClone_late", "sTargetClone", "Target clone sound (s1r2b)")):
    plot_sigs(sigs[early_key].numpy(), sigs[late_key].numpy(), sigs[full_key].numpy(), title)
    audio_3_sigs(sigs[early_key].numpy(), sigs[late_key].numpy(), sigs[full_key].numpy())

#### 1.3. Subtracting early part to estimate the late part (non-processed signals)

- Knowing early part of the target signal (*s1r2<sub>early</sub>*), we can subtract it from a full reverberant signal to get the estimate of the late part.
- For example, if we have an estimate of target, we can calculate:

$$\widehat{s1r2}_{late}= \widehat{s1r2} - s1r2_{early}$$

- ...and we get the estimate of the late part of the target.


- Below we do this for sTargetClone and sContent (we estimate the late part): 

$$\widehat{s1r1}_{late}= s1r1 - s1r2_{early}$$
$$\widehat{s1r2b}_{late}= s1r2b - s1r2_{early}$$

- ...and we compare the estimate of the late part to the actual late part (we know the actual part because we have ground truth knowledge for these signals)

In [None]:
def plot_gt_est(gt, est, suptitle, colorchoice):
    """Plot a ground-truth signal next to its estimate, both in `colorchoice`, y-axis in [-1, 1]."""
    plt.figure(figsize=(10,2))
    for position, (wave, panel_title) in enumerate(((gt, "ground truth"), (est, "estimate")), start=1):
        plt.subplot(1, 2, position)
        plt.plot(wave.T, color=colorchoice)
        plt.ylim(-1, 1)
        plt.title(panel_title)
    plt.suptitle(suptitle)
    plt.tight_layout()
    plt.show()

def audio_2_sigs(s1,s2, rate=48000):
    """Display two audio players in one horizontal row (e.g. ground truth vs estimate).

    Args:
        s1, s2: waveforms accepted by ``IPython.display.Audio``.
        rate: sampling rate in Hz used for both players (default 48000, as before).
    """
    players = "".join(f"<div>{Audio(sig, rate=rate)._repr_html_()}</div>" for sig in (s1, s2))
    # "justify-content" is the CSS property; the former bare "space-between;" was not a valid
    # declaration and was silently ignored by the browser.
    display(HTML(f'<div style="display: flex; justify-content: space-between;">{players}</div>'))

# The early part of the TARGET is subtracted from each full signal to estimate its late part.
target_early = sigs["sTarget_early"].numpy()

# estimate of the late part of sTargetClone
sTargetClone_late_gt = sigs["sTargetClone_late"].numpy()
sTargetClone_late_estim = sigs["sTargetClone"].numpy() - target_early
plot_gt_est(sTargetClone_late_gt, sTargetClone_late_estim, "Estimating late part of sTargetClone", "blue")
audio_2_sigs(sTargetClone_late_gt, sTargetClone_late_estim)

# estimate of the late part of sContent
sContent_late_gt = sigs["sContent_late"].numpy()
sContent_late_estim = sigs["sContent"].numpy() - target_early
plot_gt_est(sContent_late_gt, sContent_late_estim, "Estimating late part of sContent", "blue")
audio_2_sigs(sContent_late_gt, sContent_late_estim)


- This works quite well. The estimate of the late part is perceptually very close to the ground truth of the late part.

#### 1.4. Subtracting early part to estimate the late part (processed signals)

- so far we have only looked at signals that are not processed by any network
- now we will generate estimates of the target using different models
- next, we will estimate the late part according to the formula: 

$\widehat{s1r2}_{late}= \widehat{s1r2} - s1r2_{early}$

- and we will compare by listening:  
- ${s1r2}$ (target ground truth)   
- $\widehat{s1r2}$ (target estimate)  
- $s1r2_{late}$ (target late part ground truth)  
- $\widehat{s1r2}_{late}$  (target late part estimate)  

In [None]:
# INIT EVALUATION OBJECT (CONTAINS FUNCTION THAT GETS PREDICTIONS FROM DIFFERENT MODELS & TO COMPUTE METRICS)
# Built from the config loaded in section 1 (split set to "test"); its save_audios_sample_ext
# method is used in the next cell to collect ground truth, baselines and checkpoint predictions.

myeval = evaluation.Evaluator(config)

In [None]:
# COMPUTE AND SAVE TO WAV ALL VERSIONS OF A DATAPOINT (GROUND TRUTH, BASELINES, OUR MODEL PREDICTIONS)
# The data point is the one selected by INDX.

# output directory for the wav files
wavsavedir="../sounds/report_151024/"
# forwarded as savefiles= to save_audios_sample_ext (presumably: (re)write the wav files)
resave=True

# make a list of checkpoints to compare (in addition to ground truth sounds and baselines)
RUNS_DIR="/home/ubuntu/Data/RESULTS-reverb-match-cond-u-net/runs-exp-20-05-2024"
checkpoint_paths=[
    pjoin(RUNS_DIR, "10-06-2024--15-02_c_wunet_stft+wave_0.8_0.2", "checkpointbest.pt"),
    pjoin(RUNS_DIR, "20-05-2024--22-48_c_wunet_logmel+wave_0.8_0.2", "checkpointbest.pt"),
    pjoin(RUNS_DIR, "29-05-2024--05-47_c_wunet_logmel_1", "checkpointbest.pt"),
    pjoin(RUNS_DIR, "18-06-2024--18-37_c_wunet_stft_1", "checkpointbest.pt"),
    pjoin(RUNS_DIR, "18-06-2024--18-37_c_wunet_stft_1", "checkpoint50.pt"),
    pjoin(RUNS_DIR, "18-06-2024--18-37_c_wunet_stft_1", "checkpoint10.pt"),
    pjoin(RUNS_DIR, "18-06-2024--18-37_c_wunet_stft_1", "checkpoint0.pt"),
]

# NOTE: this rebinds `sigs` (in section 1 it held the dataset's decomposed signals);
# from here on it holds every version of sample INDX, keyed like `filenames`.
sigs={}
filenames={}

# ground truth first, then baselines, then each checkpoint; dict.update means a later
# source overwrites an earlier one if they ever return the same key
for source in ["groundtruth", "baselines", *checkpoint_paths]:
    sigs_src, filenames_src = myeval.save_audios_sample_ext(source, INDX, wavsavedir, savefiles=resave)
    sigs.update(sigs_src)
    filenames.update(filenames_src)

In [None]:
# DISPLAY THE OUTPUT OF THE ABOVE FUNCTION (SIGNALS AND CORRESPONDING FILE PATHS)
# one line per signal: "<signal key> --->  <file path>"
for key in sigs:
    print(" --->  ".join((key, filenames[key])))

In [None]:
# PLOT AND PLAY BACK TARGET AND ESTIMATES

def audio_4_sigs(s1,s2,s3,s4, rate=48000):
    """Display four audio players in one horizontal row.

    Args:
        s1, s2, s3, s4: waveforms accepted by ``IPython.display.Audio``.
        rate: sampling rate in Hz used for every player (default 48000, as before).
    """
    players = "".join(f"<div>{Audio(sig, rate=rate)._repr_html_()}</div>" for sig in (s1, s2, s3, s4))
    # "justify-content" is the CSS property; the former bare "space-between;" was not a valid
    # declaration and was silently ignored by the browser.
    display(HTML(f'<div style="display: flex; justify-content: space-between;">{players}</div>'))

# ground-truth target (full signal and late part): the reference for every comparison below
audio1 = sigs["sTarget"].numpy()
audio2 = sigs["sTarget_late"].numpy()

# predictions that have a late-part version, keyed "<full key>_late..."
late_pred_keys = [k for k in sigs if "_late" in k and "sPred" in k]
for key in late_pred_keys:
    key_full = key.split("_late")[0]
    audio3, audio4 = sigs[key_full].numpy(), sigs[key].numpy()

    # full signals: target vs prediction
    plot_gt_est(audio1, audio3, "Comparing:  sTarget to " + key_full + " (FULL SIG)", "steelblue")
    # optional: hlp.plot_2_spectrograms(audio1, audio3,  48000)
    audio_2_sigs(audio1, audio3)

    # late parts: target vs prediction
    plot_gt_est(audio2, audio4, "Comparing:  sTarget to "  + key_full + " (LATE PART)", "blue")
    # optional: hlp.plot_2_spectrograms(audio2, audio4,  48000)
    audio_2_sigs(audio2, audio4)
    print("----------------------------------------------------------------------------------------------------------------")


- For natural (unprocessed) signals, this way of estimating the late reverb worked reasonably well
- For some processed signals (especially logmel-trained ones) we get something somewhat similar to the late part
- But for stft-trained signals, subtracting the early part gives a signal with a very different dynamic range from the ground-truth late reverb

#### 1.5. Using metrics on the late part

- the same metrics as the ones I applied to the full signal:

- Type 1: similarity measured using symmetric metric : **M = m(a,b) = m(b,a)**
    - '1L_multi-stft-mag'
    - '1L_stft'
    - '1L_stft-mag'
    - '1L_multi-wave'
    - '1L_wave'
    - '1L_logmel'
    - '1L_multi-mel'
    - '1S_sisdr'
    - '1L_emb_euc'
- Type 2: similarity measured using non-symmetric metric: **M = (m(a,b)+m(b,a))/2**
    - '2L_lsd' 
    - '2L_mcd' 
    - '2S_fwsnr'
    - '2L_multi-stft'
    - '2L_stft'
    - '2S_pesq'
    - '2S_stoi'
- Type 3: similarity measured as distance in intrusive metric: **M = abs(m(a,ref) - m(b,ref))**
    - '3D_pesq'
    - '3D_stoi'
    - '3D_sisdr'
    - '3D_mos_nidiff'
    - '3D_pesq_nidiff'
    - '3D_stoi_nidiff'
    - '3D_sisdr_nidiff'

- for late parts, I scaled all the signals before they enter the metrics

In [None]:
################## LOAD CSV WITH EVALUATION RESULTS ##################
RESULTS_CSV = "/home/ubuntu/Data/RESULTS-reverb-match-cond-u-net/runs-exp-20-05-2024/151024_compare_percept_100testset_revpart_scaled.csv"
df = pd.read_csv(RESULTS_CSV)
display(df.head(5))
print(len(df))

In [5]:
################## GIVE MORE CONCISE TAGS FOR EACH CATEGORY ##################

def impove_categories_tags(df): 
    """Return a copy of `df` with a concise, ordered ``tag`` column, sorted by that tag.

    The tag is "<short label> -> <compared>", where the short label is ``label`` without its
    first "_"-separated prefix and with "checkpoint" shortened to "ch"; in the whole tag,
    "target" becomes "tar" and "prediction" becomes "pred". ``tag`` is an ordered Categorical
    following ``custom_order`` so that plots keep a fixed category order.

    The input DataFrame is not modified (previously ``short_label``/``tag`` were written into it).
    Tags that are not in ``custom_order`` become NaN in the Categorical; a warning lists them,
    because NaN tags make later ``df["tag"].str.contains(...)`` masks fail.
    """
    import warnings

    short_label = df['label'].apply(lambda x: x.split('_', 1)[1] if "_" in x else x)
    short_label = short_label.str.replace("checkpoint", "ch", regex=False)
    tag = (short_label + ' -> ' + df['compared'])
    tag = tag.str.replace("target", "tar", regex=False).str.replace("prediction", "pred", regex=False)

    # create a custom order of the files so that the plots have similar order as before
    custom_order=["oracle -> tar:tarclone" , "oracle -> r(tar):r(tarclone)", 
                  "oracle -> tar:content", "oracle -> r(tar):r(content)",  
                  "anecho+fins -> pred:tar", "anecho+fins -> r(pred):r(tar)", 
                  "dfnet+fins -> pred:tar",  "dfnet+fins -> r(pred):r(tar)",
                  "wpe+fins -> pred:tar", "wpe+fins -> r(pred):r(tar)",
                  "c_wunet_stft+wave_0.8_0.2_chbest -> pred:tar", "c_wunet_stft+wave_0.8_0.2_chbest -> r(pred):r(tar)", 
                  "c_wunet_logmel+wave_0.8_0.2_chbest -> pred:tar", "c_wunet_logmel+wave_0.8_0.2_chbest -> r(pred):r(tar)", 
                  "c_wunet_logmel_1_chbest -> pred:tar", "c_wunet_logmel_1_chbest -> r(pred):r(tar)", 
                  "c_wunet_stft_1_chbest -> pred:tar","c_wunet_stft_1_chbest -> r(pred):r(tar)",
                  "c_wunet_stft_1_ch50 -> pred:tar", "c_wunet_stft_1_ch50 -> r(pred):r(tar)", 
                  "c_wunet_stft_1_ch10 -> pred:tar", "c_wunet_stft_1_ch10 -> r(pred):r(tar)", 
                  "c_wunet_stft_1_ch0 -> pred:tar", "c_wunet_stft_1_ch0 -> r(pred):r(tar)"]

    unknown = sorted(set(tag) - set(custom_order))
    if unknown:
        warnings.warn(f"tags not in custom_order (they become NaN): {unknown}")

    out = df.drop(columns=['short_label'], errors="ignore")
    out = out.assign(tag=pd.Categorical(tag, categories=custom_order, ordered=True))
    # stable sort: rows sharing a tag keep their input order
    return out.sort_values("tag", kind="stable")


# correctly spelled alias; the original name is kept for existing callers
improve_categories_tags = impove_categories_tags

df = impove_categories_tags(df=df)


In [None]:
################## ADD COLUMN TO KNOW IF THE SAMPLE IS REV2DRY OR DRY2REV ##################

# divide into re-reverberation, de-reverberation (fresh config, independent of section 1)
config = hlp.load_config("/home/ubuntu/joanna/reverb-match-cond-u-net/config/basic.yaml")

def get_reverb_ind(config, df, split):
    """Return a copy of `df` with a "rev_delta" column labelling each sample by RT60 difference.

    A dataset for `split` is built with ``p_noise = 0`` and queried with
    ``get_idx_with_rt60diff`` for three ranges:
      (-3, -0.3) -> "dry2rev", (0.3, 3) -> "rev2dry", (-0.3, 0.3) -> "smalldiff".
    NOTE(review): the sign convention of the RT60 difference and the inclusiveness of the bounds
    are defined by the dataset class; not visible here.

    Rows whose "idx" is in none of the returned index lists keep a missing "rev_delta".
    If an idx is returned for several ranges, the last assignment ("smalldiff") wins.

    Neither `config` nor `df` is modified: both are copied first (previously the caller's
    config had "p_noise"/"split" overwritten and `df` was edited in place).
    """
    import copy

    # shallow copy is enough: only top-level keys are changed here
    config = copy.copy(config)
    config["p_noise"]=0
    config["split"]=split
    dataset=ds.DatasetReverbTransfer(config)
    indices_dry2rev=dataset.get_idx_with_rt60diff(-3,-0.3)
    indices_rev2dry=dataset.get_idx_with_rt60diff(0.3,3)
    indices_smalldiff=dataset.get_idx_with_rt60diff(-0.3,0.3)

    df = df.copy()
    for label, indices in (("dry2rev", indices_dry2rev),
                           ("rev2dry", indices_rev2dry),
                           ("smalldiff", indices_smalldiff)):
        df.loc[df["idx"].isin(indices), "rev_delta"] = label
    return df

df = get_reverb_ind(config, df, split="test")

display(df.head(5))

In [None]:
# filter results into full signal and late reverb
# Late-reverb comparisons are tagged like "r(pred):r(tar)". Match the literal "r(" without a
# regex: the previous pattern "r\(" in a non-raw string is an invalid escape sequence
# (DeprecationWarning / SyntaxWarning on Python 3.12+).
is_late = df["tag"].str.contains("r(", regex=False)
df_late = df[is_late]
df_full = df[~is_late]
print(len(df_late))
print(len(df_full))


In [15]:
################### FOR EACH SAMPLE INDEX, CHECK IF FOR THAT SAMPLE THE PERCEPTUAL EXPECTATIONS ARE CONFIRMED ##################

def _beats(a, b, lower_is_better, strict):
    """Return True if `a` is better than `b` (strict) or at least as good as `b` (non-strict).

    Lower values win when `lower_is_better`, higher values otherwise. If the comparison is
    element-wise (either side a pandas Series), every element must win; an empty Series
    therefore yields True. Comparisons involving NaN are False.
    """
    if lower_is_better:
        outcome = a < b if strict else a <= b
    else:
        outcome = a > b if strict else a >= b
    if isinstance(outcome, pd.Series):
        return bool(outcome.all())
    return bool(outcome)


def check_requirements(df,idx):
    """Check, for one sample, which perceptual expectations each metric confirms.

    The direction of "better" is read from the metric name: columns containing "L_" or "D_"
    are losses/distances (lower is better), columns containing "S_" are similarities
    (higher is better). Any other metric column stays False for every expectation.

    Expectations ("a > b" = a is closer to the target than b; ">=" allows ties):
      - sTargetClone >= sContent
      - sTargetClone is ranked best among all rows of this sample
      - sPred_anecho_fins and every "best" checkpoint of ours > sContent
      - sPred_anecho_fins >= all our checkpoints / our logmel bests / our stft bests /
        our half-trained (ch0, ch10) checkpoints
      - sPred_anecho_fins > sPred_dfnet_fins and sPred_wpe_fins
      - stft+wave_0.8_0.2 best > logmel+wave_0.8_0.2 best
      - ch50 > ch10 > ch0

    Args:
        df: results with columns "idx", "tag", "rev_delta" and one column per metric. Rows are
            selected by regex matches on "tag"; single-row selections use ``.item()`` and
            raise ValueError unless exactly one row matches.
        idx: sample index to evaluate.

    Returns:
        dict of equal-length lists, one entry per metric: "metric", "idx", "rev_delta" and one
        list of booleans per expectation (keys as in ``compute_metrics_table``).
    """
    # every column that is not bookkeeping is treated as a metric
    non_metric_columns = ["label","idx","compared","short_label","tag",'rev_delta','dataset']
    metrics = [col for col in df.columns if col not in non_metric_columns]
    n_metrics = len(metrics)

    # all rows of this sample (anecho, content, target, processed, etc.)
    df_idx = df[df["idx"] == idx]
    rev_delta = df_idx['rev_delta'].iloc[0]

    # one entry per metric; every expectation starts False and is set when confirmed
    conditions_dict={"metric": list(metrics),
                     "idx" : [idx] * n_metrics,
                     "rev_delta" : [rev_delta] * n_metrics,
                     "sTargetClone > sContent" : [False] * n_metrics,
                     "sTargetClone most similar" : [False] * n_metrics,
                     "sContent < sPred_anecho_fins & our best models" : [False] * n_metrics,
                     "sPred_anecho_fins > all our models" : [False] * n_metrics,
                     "sPred_anecho_fins > our logmel models" : [False] * n_metrics,
                     "sPred_anecho_fins > our stft models" : [False] * n_metrics,
                     "sPred_anecho_fins > our half-trained models" : [False] * n_metrics,
                     "sPred_anecho_fins > sPred_dfnet_fins & sPred_wpe_fins" :[False] * n_metrics,
                     "stft+wave_0.8_0.2_checkpointbest > logmel+wave_0.8_0.2_checkpointbest" : [False] * n_metrics,
                     "earlier checkpoints < later checkpoints" : [False] * n_metrics}

    # NOTE(review): tags are built as "<label> -> <compared>", so these exact matches may never
    # remove anything; kept unchanged.
    df_idx=df_idx[df_idx["tag"]!= "tar:anecho"]
    df_idx=df_idx[df_idx["tag"]!= "tar:style"]

    tags = df_idx["tag"]

    def rows(pattern):
        """Rows of this sample whose tag matches the regex `pattern`."""
        return df_idx[tags.str.contains(pattern)]

    row_content = rows("content")
    row_targetclone = rows("tarclone")
    row_anechofins = rows("anecho\\+fins")
    rows_fins_notanecho = pd.concat([rows("dfnet\\+fins"), rows("wpe\\+fins")])

    row_ourlogmelwave = rows("logmel\\+wave_0.8_0.2_chbest")
    row_ourstftwave = rows("stft\\+wave_0.8_0.2_chbest")
    rows_ourbests = pd.concat([row_ourlogmelwave, row_ourstftwave,
                               rows("logmel_1_chbest"), rows("stft_1_chbest")])

    row_ourstft0 = rows("ch0")
    row_ourstft10 = rows("ch10")
    row_ourstft50 = rows("ch50")
    rows_ourhalftrained = pd.concat([row_ourstft0, row_ourstft10])
    rows_ourmodels = rows("_ch")
    rows_ourlogmels = df_idx[tags.str.contains("logmel") & tags.str.contains("best")]
    rows_ourstfts = df_idx[tags.str.contains("stft") & tags.str.contains("best")]
    rows_bestpreds = pd.concat([rows_ourbests, row_anechofins])

    for j, col in enumerate(metrics):
        if "L_" in col or "D_" in col:
            lower = True    # loss / distance metric
        elif "S_" in col:
            lower = False   # similarity metric
        else:
            continue        # unknown direction: nothing can be confirmed

        content = row_content[col].item()
        anecho = row_anechofins[col].item()
        # best-first ranking; NaN rows go last in both directions (default na_position="last"),
        # so a NaN can no longer be mistaken for the most similar row of a similarity metric
        best_tag = df_idx.sort_values(by=col, ascending=lower)["tag"].iloc[0]

        confirmed = {
            "sTargetClone > sContent":
                _beats(row_targetclone[col].item(), content, lower, strict=False),
            "sTargetClone most similar": "tarclone" in best_tag,
            # i.e. the transformation helps
            "sContent < sPred_anecho_fins & our best models":
                _beats(rows_bestpreds[col], content, lower, strict=True),
            "sPred_anecho_fins > all our models":
                _beats(anecho, rows_ourmodels[col], lower, strict=False),
            "sPred_anecho_fins > our logmel models":
                _beats(anecho, rows_ourlogmels[col], lower, strict=False),
            "sPred_anecho_fins > our stft models":
                _beats(anecho, rows_ourstfts[col], lower, strict=False),
            "sPred_anecho_fins > our half-trained models":
                _beats(anecho, rows_ourhalftrained[col], lower, strict=False),
            "sPred_anecho_fins > sPred_dfnet_fins & sPred_wpe_fins":
                _beats(anecho, rows_fins_notanecho[col], lower, strict=True),
            # compares the +wave best checkpoints named in the key (previously the *_1
            # single-loss checkpoints were compared by mistake)
            "stft+wave_0.8_0.2_checkpointbest > logmel+wave_0.8_0.2_checkpointbest":
                _beats(row_ourstftwave[col].item(), row_ourlogmelwave[col].item(), lower, strict=True),
            # checkpoint0 < checkpoint10 < checkpoint50
            "earlier checkpoints < later checkpoints":
                _beats(row_ourstft10[col].item(), row_ourstft0[col].item(), lower, strict=True)
                and _beats(row_ourstft50[col].item(), row_ourstft10[col].item(), lower, strict=True),
        }
        for key, ok in confirmed.items():
            conditions_dict[key][j] = bool(ok)

    return conditions_dict
    # end of function "check_requirements(df,idx)"

def compute_metrics_table(df):
    """Run ``check_requirements`` for every sample index in `df` and stack the results.

    Returns:
        DataFrame with one row per (idx, metric) pair and the columns listed in ``keys``:
        "metric", "idx", "rev_delta" plus one boolean column per perceptual expectation.
        An input without any idx yields an empty frame with those columns.
    """
    keys=["metric",
        "idx",
        "rev_delta",
        "sTargetClone > sContent",
        "sTargetClone most similar",
        "sContent < sPred_anecho_fins & our best models" ,
        "sPred_anecho_fins > all our models" ,
        "sPred_anecho_fins > our logmel models" ,
        "sPred_anecho_fins > our stft models" ,
        "sPred_anecho_fins > our half-trained models"  ,
        "sPred_anecho_fins > sPred_dfnet_fins & sPred_wpe_fins" ,
        "stft+wave_0.8_0.2_checkpointbest > logmel+wave_0.8_0.2_checkpointbest" ,
        "earlier checkpoints < later checkpoints" ]
    combined_dict = {key: [] for key in keys}
    for idx in df.idx.unique():
        tmp_dict = check_requirements(df, idx)
        # extend in place (the dict used to be rebuilt with list concatenation each iteration)
        for key in keys:
            combined_dict[key].extend(tmp_dict[key])
    # go from dictionary to df
    return pd.DataFrame(combined_dict)





# per-(idx, metric) confirmation tables: full signals first, then late parts
df_table_full, df_table_late = (compute_metrics_table(part) for part in (df_full, df_late))




In [None]:
####################  CREATE A TABLE INDICATING WHICH PERCEPTUAL OBSERVATIONS ARE CONFIRMED BY EACH METRIC ###################
# CASE 1 -> FULL SIGNAL (SAME AS BEFORE)
# Each cell: % of samples (idx) for which the metric confirms the observation, rounded to the
# nearest integer (int() used to truncate, e.g. 2/3 -> 66 instead of 67).

needed_columns = df_table_full.drop(columns=["idx","rev_delta"])
df_table_g = needed_columns.groupby("metric").agg(lambda x: int(round(x.sum() / len(x) * 100))).reset_index()
df_table_g_styled = df_table_g.style.background_gradient(cmap='viridis', vmin=0, vmax=100)
display(df_table_g_styled)

In [None]:
####################  CREATE A TABLE INDICATING WHICH PERCEPTUAL OBSERVATIONS ARE CONFIRMED BY EACH METRIC ###################
# CASE 2 -> LATE PART
# Each cell: % of samples (idx) for which the metric confirms the observation, rounded to the
# nearest integer (int() used to truncate, e.g. 2/3 -> 66 instead of 67).

needed_columns = df_table_late.drop(columns=["idx","rev_delta"])
df_table_g = needed_columns.groupby("metric").agg(lambda x: int(round(x.sum() / len(x) * 100))).reset_index()
df_table_g_styled = df_table_g.style.background_gradient(cmap='viridis', vmin=0, vmax=100)
display(df_table_g_styled)

- In general, the metrics "see" that the late reverb of targetclone is closer to the target than the late reverb of the content sound.
- But (surprisingly) this difference is less clear on the late part than on the full signals.
- Also, fewer metrics show that reverb transfer helps (i.e. that the signal becomes more similar to the target after the transformation).
- Even though we use only the reverb part for the analysis, the baseline is still not rated better than our networks.
- But at least some metrics rate the baseline better than our half-trained models.
- If we cut the direct sound from the signals, there is no difference between early and late checkpoints (which could mean that training mostly improves the direct sound and much less the reverb).
- So evaluating the signals on the late part does not seem to give us better information.

In [None]:
# PLOT AVERAGE METRICS (one row per metric: full signal on the left, late part on the right)

# create df with average metric across data points
numeric_df_late = df_late.select_dtypes(include='number')
numeric_df_full = df_full.select_dtypes(include='number')

# "tag" is an ordered pandas Categorical whose categories cover BOTH full-signal and late-part
# tags. With observed=False (the default before pandas 3), groupby emits a row for every
# category, including those absent from the subset, as all-NaN rows; that was the
# "different list" previously patched with dropna(). observed=True keeps only present tags
# (in category order) without discarding real rows that merely contain a NaN mean.
df_g_late=df_late.groupby("tag", observed=True)[numeric_df_late.columns].mean().reset_index()
df_g_full=df_full.groupby("tag", observed=True)[numeric_df_full.columns].mean().reset_index()

NON_METRIC_COLUMNS = ["label", "idx", "compared", "short_label", "tag"]
metric_columns = [c for c in df_g_late.columns if c not in NON_METRIC_COLUMNS]


def _bar_means(ax, df_g, column, color, title):
    """Bar chart of the per-tag means of `column`, each bar annotated with its value."""
    labels = df_g["tag"].astype(str).tolist()
    positions = list(range(len(labels)))
    bars = ax.bar(positions, df_g[column], color=color)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=60, ha="right", fontsize=8)
    ax.set_title(title)
    for pos, bar in zip(positions, bars):
        height = bar.get_height()
        ax.text(pos, height, "%.2f" % height,
                horizontalalignment="center", verticalalignment="bottom", fontsize=7)


plt.rcParams.update({'font.size': 8})
if metric_columns:
    n_rows = len(metric_columns)
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, 6 * n_rows), squeeze=False)
    for row, column in enumerate(metric_columns):
        # red: loss/distance (lower is better), green: similarity (higher is better)
        metriccolor = "red" if "D_" in column or "L_" in column else "green"
        _bar_means(axes[row, 0], df_g_full, column, metriccolor, "Metric mean: " + column + " (FULL SIG)")
        _bar_means(axes[row, 1], df_g_late, column, metriccolor, "Metric mean: " + column + " (LATE PART)")
    fig.tight_layout()
    plt.show()
