In [27]:
import os
from functools import partial
from itertools import chain

import numpy as np
import pandas as pd
import soundfile as sf
from tqdm.contrib.concurrent import process_map


In [28]:
INFERENCE_ROOT = "/home/kwatchar3/Documents/data/moisesdb/"
STEM_SETUP = "everything"
GROUND_TRUTH_ROOT = "/home/kwatchar3/Documents/data/moisesdb"

In [29]:
variants = [
    "ev-d-pre",
    "ev-d-prefz",
    "ev-d-pre-aug",
    # "ev-d-pre-bal",
    "ev-d-prefz-bal",
    # "ev-d-pre-aug-bal",
]



In [30]:
gt_files = os.listdir(os.path.join(GROUND_TRUTH_ROOT, "npy2"))

In [31]:
def snr(gt, est):
    return 10 * np.log10(np.sum(np.square(gt)) / np.sum(np.square(gt - est)))

In [45]:
allowed_stems = COARSE_TO_FINE = {
    "vocals": [
        "lead_female_singer",
        "lead_male_singer",
        # "human_choir",
        "background_vocals",
        # "other_vocals",
    ],
    "drums": [
        "drums"
    ],
    "bass": [
        "bass_guitar",
        "bass_synthesizer",
        # "contrabass_double_bass",
        # "tuba",
        # "bassoon",
    ],
    "guitar": [
        "clean_electric_guitar",
        "distorted_electric_guitar",
        # "lap_steel_guitar_or_slide_guitar",
        "acoustic_guitar",
    ],
    "piano": ["grand_piano", "electric_piano"],
    "other_keys": ["organ_electric_organ", "synth_pad", "synth_lead",
                   # "other_sounds"
                   ],
    "bowed_strings": [
        # "violin",
        # "viola",
        # "cello",
        # "violin_section",
        # "viola_section",
        # "cello_section",
        "string_section",
        "other_strings",
    ],
    "wind": ["brass",
             # "flutes",
             "reeds",
             "other_wind"],
    "other_plucked": ["other_plucked"],
    "percussion": [
        # "atonal_percussion",
        "pitched_percussion"
    ],
    "other": ["fx",
              # "click_track"
              ],
}


In [33]:
def get_results_for_song(inputs):

    song_name, inference_mode, variant = inputs

    stems = os.listdir(os.path.join(INFERENCE_ROOT, inference_mode, STEM_SETUP, variant, "audio", song_name))
    stems = [s.replace(".wav", "") for s in stems]

    results = []

    for coarse_stem in allowed_stems:

        coarse_pred = []
        coarse_true = []

        has_gt = False
        for stem in allowed_stems[coarse_stem]:
            stem_has_gt = False
            npy_path = os.path.join(GROUND_TRUTH_ROOT, "npy2", song_name, f"{stem}.npy")
            if os.path.exists(npy_path):
                audio = np.load(npy_path, mmap_mode="r")
                coarse_true.append(audio)
                has_gt = True
                stem_has_gt = True
            else:
                # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                # print(f"Ground truth not found for {song_name}/{stem}. Using zeros.")
                # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                pass

            audio_path = os.path.join(
                INFERENCE_ROOT,
                inference_mode,
                STEM_SETUP,
                variant,
                "audio",
                song_name,
                f"{stem}.wav"
                )
            if os.path.exists(audio_path):
                audio, sr = sf.read(audio_path)
                audio = audio.T
                coarse_pred.append(audio)
            else:
                if stem_has_gt:
                    print("******************************************************")
                    # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                    print(f"Prediction not found for {song_name}/{stem}. Using zeros.")
                    print("******************************************************")
                # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

        if not has_gt:
            # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            # print(f"No ground truth found for {song_name} - {coarse_stem}. Skipping.")
            # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            continue

        coarse_true = sum(coarse_true)
        coarse_pred = sum(coarse_pred)

        snr_full = snr(coarse_true, coarse_pred)

        results.append(
            {
                "song": song_name,
                "stem": coarse_stem,
                "snr": snr_full,
                "variant": variant,
                "inference_mode": inference_mode,
            }
        )

    return results


In [34]:

df = []

for inference_mode in ["inference-d", "inference-o"]:

    for v in variants:
        print(f"Processing {v}...")

        test_files = os.listdir(os.path.join(INFERENCE_ROOT, inference_mode, STEM_SETUP, v, "audio"))

        # for song in tqdm(test_files):
        #     results = get_results_for_song(song, inference_mode, v)
        #     df.extend(results)

        inputs = [(song, inference_mode, v) for song in test_files]

        results = process_map(get_results_for_song, inputs, max_workers=16)
        results = list(chain(*results))

        df.extend(results)

df = pd.DataFrame(df)

Processing ev-d-pre...


  0%|          | 0/48 [00:00<?, ?it/s]

******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_wind. Using zeros.
******************************************************
******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_plucked. Using zeros.
******************************************************
******************************************************
Prediction not found for c70471f9-9c4a-41c9-b8f8-20ac38847a8e/other_strings. Using zeros.
******************************************************
Processing ev-d-prefz...


  0%|          | 0/48 [00:00<?, ?it/s]

******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_wind. Using zeros.
******************************************************
******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_plucked. Using zeros.
******************************************************
******************************************************
Prediction not found for c70471f9-9c4a-41c9-b8f8-20ac38847a8e/other_strings. Using zeros.
******************************************************
Processing ev-d-pre-aug...


  0%|          | 0/48 [00:00<?, ?it/s]

******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_wind. Using zeros.
******************************************************
******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_plucked. Using zeros.
******************************************************
******************************************************
Prediction not found for c70471f9-9c4a-41c9-b8f8-20ac38847a8e/other_strings. Using zeros.
******************************************************
Processing ev-d-prefz-bal...


  0%|          | 0/48 [00:00<?, ?it/s]

******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_wind. Using zeros.
******************************************************
******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_plucked. Using zeros.
******************************************************
******************************************************
Prediction not found for c70471f9-9c4a-41c9-b8f8-20ac38847a8e/other_strings. Using zeros.
******************************************************
Processing ev-d-pre...


  0%|          | 0/48 [00:00<?, ?it/s]

******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_wind. Using zeros.
******************************************************
******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_plucked. Using zeros.
******************************************************
******************************************************
Prediction not found for c70471f9-9c4a-41c9-b8f8-20ac38847a8e/other_strings. Using zeros.
******************************************************
Processing ev-d-prefz...


  0%|          | 0/48 [00:00<?, ?it/s]

******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_wind. Using zeros.
******************************************************
******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_plucked. Using zeros.
******************************************************
******************************************************
Prediction not found for c70471f9-9c4a-41c9-b8f8-20ac38847a8e/other_strings. Using zeros.
******************************************************
Processing ev-d-pre-aug...


  0%|          | 0/48 [00:00<?, ?it/s]

******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_wind. Using zeros.
******************************************************
******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_plucked. Using zeros.
******************************************************
******************************************************
Prediction not found for c70471f9-9c4a-41c9-b8f8-20ac38847a8e/other_strings. Using zeros.
******************************************************
Processing ev-d-prefz-bal...


  0%|          | 0/48 [00:00<?, ?it/s]

******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_wind. Using zeros.
******************************************************
******************************************************
Prediction not found for 5a6df5c7-a58a-479e-bdfb-c5946c221933/other_plucked. Using zeros.
******************************************************
******************************************************
Prediction not found for c70471f9-9c4a-41c9-b8f8-20ac38847a8e/other_strings. Using zeros.
******************************************************


In [35]:
df

Unnamed: 0,song,stem,snr,variant,inference_mode
0,704f1de9-1d02-4c2b-af05-107a7700a51d,vocals,8.348032,ev-d-pre,inference-d
1,704f1de9-1d02-4c2b-af05-107a7700a51d,drums,11.773505,ev-d-pre,inference-d
2,704f1de9-1d02-4c2b-af05-107a7700a51d,bass,8.722601,ev-d-pre,inference-d
3,704f1de9-1d02-4c2b-af05-107a7700a51d,guitar,1.481413,ev-d-pre,inference-d
4,8a6c9c1f-4865-404f-a805-1949de36a33c,vocals,13.111771,ev-d-pre,inference-d
...,...,...,...,...,...
1995,a56d9450-3a26-485c-8ac3-24b6b54e2c1d,guitar,8.346391,ev-d-prefz-bal,inference-o
1996,1f98fe4d-26c7-460f-9f68-33964bc4d8d3,vocals,4.019988,ev-d-prefz-bal,inference-o
1997,1f98fe4d-26c7-460f-9f68-33964bc4d8d3,drums,10.185577,ev-d-prefz-bal,inference-o
1998,1f98fe4d-26c7-460f-9f68-33964bc4d8d3,bass,8.265368,ev-d-prefz-bal,inference-o


In [36]:
df["snr"] = df["snr"].replace(-np.inf, np.nan)

In [37]:
df.to_csv(os.path.join(INFERENCE_ROOT, "bandit_ev_merged.csv"), index=False)

In [78]:
df = pd.read_csv(os.path.join(INFERENCE_ROOT, "bandit_ev_merged.csv"))

In [79]:
stem_dtype = pd.CategoricalDtype(categories=allowed_stems, ordered=True)
df["stem"] = df["stem"].astype(stem_dtype)

bool_dtype = pd.CategoricalDtype(categories=["Y", "N"], ordered=True)
ibool_dtype = pd.CategoricalDtype(categories=["N", "Y"], ordered=True)

df["is_frozen"] = df["variant"].str.contains("prefz").apply(lambda x: "Y" if x else "N").astype(bool_dtype)
df["is_balanced"] = df["variant"].str.contains("bal").apply(lambda x: "Y" if x else "N").astype(ibool_dtype)
df["is_augmented"] = df["variant"].str.contains("aug").apply(lambda x: "Y" if x else "N").astype(ibool_dtype)
df["query_same"] = df["inference_mode"].str.contains("-o").apply(lambda x: "same" if x else "diff.")

In [80]:
def q25(x):
    return x.quantile(0.25)


def q75(x):
    return x.quantile(0.75)


def q50(x):
    return x.quantile(0.5)


dfagg = df.groupby(
    [
        "is_frozen",
        "is_augmented",
        "is_balanced",
        "query_same",
        "stem"
    ]
).describe()["snr"]

# dfagg = dfagg[dfagg.query_same == "diff."]

dfagg = dfagg.reset_index().pivot_table(
    index=["is_frozen", "is_augmented", "is_balanced", "query_same"],
    columns="stem",
)


  dfagg = df.groupby(


In [81]:
dfagg

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,25%,25%,25%,25%,25%,25%,25%,25%,25%,25%,...,std,std,std,std,std,std,std,std,std,std
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,stem,vocals,drums,bass,guitar,piano,other_keys,bowed_strings,wind,other_plucked,percussion,...,vocals,drums,bass,guitar,piano,other_keys,bowed_strings,wind,percussion,other
is_frozen,is_augmented,is_balanced,query_same,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
Y,N,N,diff.,6.894812,8.157859,7.565005,1.072846,0.215652,-0.000709,-0.010425,0.000579,0.0,-2.415543,...,3.513647,2.527102,4.214204,2.586086,2.090998,0.380155,0.015852,0.469784,2.274218,5.627593
Y,N,N,same,6.859615,8.153814,8.012354,1.190625,0.328484,0.000201,-2.4e-05,0.003754,0.0,-1.016244,...,3.713512,2.541448,3.991589,2.566749,2.459208,0.172286,0.001835,0.034627,1.276415,4.59219
Y,N,Y,diff.,6.743392,7.779492,7.521674,1.072451,-0.234785,-0.081063,-0.188035,0.035286,0.0,-2.1337,...,2.901677,2.37261,3.526889,2.339386,3.226273,0.625984,0.417858,0.589618,1.966818,1.992873
Y,N,Y,same,6.407937,7.777564,7.557049,1.179914,-0.912198,-0.313551,-0.530249,0.078286,0.0,-0.112063,...,3.228087,2.393993,3.335579,2.296313,3.358125,0.797854,0.454663,0.808579,2.614345,1.62672
N,N,N,diff.,6.859689,7.881636,7.484564,1.446187,0.631473,-0.094925,-0.231063,-0.089766,0.0,-2.455696,...,3.512808,2.46226,4.261725,2.931147,2.753458,2.077147,0.793063,1.584786,2.214774,5.565832
N,N,N,same,6.586905,7.845629,7.878842,1.231723,0.600752,-0.468628,-0.46947,0.124161,0.0,-1.261964,...,3.676778,2.466046,3.997968,2.895431,3.049478,2.145024,0.61509,0.759273,2.689716,4.531642
N,Y,N,diff.,7.047089,8.560939,8.267794,1.235196,0.131989,1.5e-05,3e-06,-0.012499,0.0,-2.513345,...,3.873531,2.516873,4.294079,2.782856,2.943407,1.01411,0.00019,2.111165,2.373379,5.419691
N,Y,N,same,6.269022,8.559176,8.462066,1.350015,0.304006,4e-05,2.1e-05,0.002636,0.0,-1.359048,...,4.076241,2.532753,4.025533,2.719135,3.326416,1.225373,0.000409,1.075197,1.285342,4.217764


In [82]:
dfagg = dfagg.swaplevel(axis=1).sort_index(axis=1)




In [83]:
dfagg_max = dfagg.max(axis=0)

In [84]:

def bold_formatter(x, val):
    if round(x, 1) == round(val, 1):
        return r"\bfseries " + f"{x:.1f}"
    return f"{x:.1f}"


formatters = {
    (c, d): partial(bold_formatter, val=dfagg_max.loc[c, d])
    for c, d in dfagg.columns
}

In [85]:
str_ = dfagg.to_latex(
    formatters=formatters,
    sparsify=True,
    multirow=False,
)

print(str_)

\begin{tabular}{llllrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr}
\toprule
 &  &  & stem & \multicolumn{8}{r}{vocals} & \multicolumn{8}{r}{drums} & \multicolumn{8}{r}{bass} & \multicolumn{8}{r}{guitar} & \multicolumn{8}{r}{piano} & \multicolumn{8}{r}{other_keys} & \multicolumn{8}{r}{bowed_strings} & \multicolumn{8}{r}{wind} & \multicolumn{7}{r}{other_plucked} & \multicolumn{8}{r}{percussion} & \multicolumn{8}{r}{other} \\
 &  &  &  & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75% & count & max & mean & min & 25% & 50% & 75% & count & max & mean & min & std & 25% & 50% & 75%

In [86]:
dfagg.round(1)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,stem,vocals,vocals,vocals,vocals,vocals,vocals,vocals,vocals,drums,drums,...,percussion,percussion,other,other,other,other,other,other,other,other
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,25%,50%,75%,count,max,mean,min,std,25%,50%,...,min,std,25%,50%,75%,count,max,mean,min,std
is_frozen,is_augmented,is_balanced,query_same,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
Y,N,N,diff.,6.9,7.8,10.6,48.0,15.2,8.3,0.8,3.5,8.2,9.8,...,-3.2,2.3,-0.0,-0.0,-0.0,6.0,0.0,-2.3,-13.8,5.6
Y,N,N,same,6.9,8.2,10.6,48.0,15.6,8.2,0.4,3.7,8.2,9.8,...,-1.5,1.3,-0.0,-0.0,-0.0,6.0,0.0,-1.9,-11.3,4.6
Y,N,Y,diff.,6.7,7.3,9.9,48.0,14.4,8.0,1.2,2.9,7.8,9.0,...,-2.8,2.0,-0.2,-0.1,-0.0,6.0,0.0,-0.9,-4.9,2.0
Y,N,Y,same,6.4,7.7,10.2,48.0,14.3,7.9,1.0,3.2,7.8,9.0,...,-1.0,2.6,-0.1,0.0,0.2,6.0,0.5,-0.5,-3.8,1.6
N,N,N,diff.,6.9,8.2,10.7,48.0,15.9,8.5,0.4,3.5,7.9,9.3,...,-3.2,2.2,-0.7,-0.0,0.0,6.0,0.0,-2.4,-13.8,5.6
N,N,N,same,6.6,8.3,10.7,48.0,15.9,8.3,0.3,3.7,7.8,9.3,...,-2.2,2.7,-1.0,-0.2,-0.0,6.0,0.0,-2.1,-11.3,4.5
N,Y,N,diff.,7.0,8.2,10.9,48.0,16.4,8.4,0.3,3.9,8.6,10.0,...,-3.4,2.4,-0.0,-0.0,0.0,6.0,0.0,-2.2,-13.3,5.4
N,Y,N,same,6.3,8.1,10.9,48.0,16.6,8.1,0.3,4.1,8.6,10.0,...,-1.8,1.3,-0.0,0.0,0.0,6.0,0.0,-1.7,-10.3,4.2
