In [147]:
from functools import partial
from itertools import chain

import librosa.effects
import numpy as np
import pandas as pd
import soundfile as sf

import os

from ipywidgets import Audio
from matplotlib import pyplot as plt
from tqdm.contrib.concurrent import process_map
from tqdm.notebook import tqdm

import librosa

In [41]:
INFERENCE_ROOT = "/home/kwatchar3/Documents/data/moisesdb/"
STEM_SETUP = "vdbgp"
GROUND_TRUTH_ROOT = "/home/kwatchar3/Documents/data/moisesdb"

In [42]:
variants = [
    "vdbgp-d-pre",
    "vdbgp-d-prefz",
    "vdbgp-d-pre-aug",
    "vdbgp-d-pre-bal",
    "vdbgp-d-prefz-bal",
    "vdbgp-d-pre-aug-bal",
]



In [43]:
gt_files = os.listdir(os.path.join(GROUND_TRUTH_ROOT, "npy2"))

In [44]:
def snr(gt, est):
    return 10 * np.log10(np.sum(np.square(gt)) / np.sum(np.square(gt - est)))

In [45]:
allowed_stems = [
    "lead_female_singer",
    "lead_male_singer",
    "drums",
    "bass_guitar",
    "acoustic_guitar",
    "clean_electric_guitar",
    "distorted_electric_guitar",
    "grand_piano",
    "electric_piano",
]

In [46]:
def get_results_for_song(inputs):

    song_name, inference_mode, variant = inputs

    stems = os.listdir(os.path.join(INFERENCE_ROOT, inference_mode, STEM_SETUP, variant, "audio", song_name))
    stems = [s.replace(".wav", "") for s in stems]

    results = []

    for stem in stems:
        if stem not in allowed_stems:
            continue

        audio_est, sr = sf.read(os.path.join(INFERENCE_ROOT, inference_mode, STEM_SETUP, variant, "audio", song_name, f"{stem}.wav"))
        audio_est = audio_est.T

        npy_path = os.path.join(GROUND_TRUTH_ROOT, "npy2", song_name, f"{stem}.npy")
        if os.path.exists(npy_path):
            audio = np.load(npy_path, mmap_mode="r")
        else:
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            print(f"Ground truth not found for {song_name}/{stem}. Using zeros.")
            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            audio = np.zeros_like(audio_est)

        snr_full = snr(audio, audio_est)

        results.append({
            "song": song_name,
            "stem": stem,
            "snr": snr_full,
            "variant": variant,
            "inference_mode": inference_mode,
        })

    return results



In [59]:

df = []

for inference_mode in ["inference-d", "inference-o"]:

    for v in variants:
        print(f"Processing {v}...")

        test_files = os.listdir(os.path.join(INFERENCE_ROOT, inference_mode, STEM_SETUP, v, "audio"))

        # for song in tqdm(test_files):
        #     results = get_results_for_song(song, inference_mode, v)
        #     df.extend(results)

        inputs = [(song, inference_mode, v) for song in test_files]

        results = process_map(get_results_for_song, inputs, max_workers=16)
        results = list(chain(*results))

        df.extend(results)


df = pd.DataFrame(df)

Processing vdbgp-d-pre...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-prefz...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-pre-aug...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-pre-bal...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-prefz-bal...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-pre-aug-bal...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-pre...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-prefz...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-pre-aug...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-pre-bal...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-prefz-bal...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdbgp-d-pre-aug-bal...


  0%|          | 0/48 [00:00<?, ?it/s]

In [60]:
df

Unnamed: 0,song,stem,snr,variant,inference_mode
0,704f1de9-1d02-4c2b-af05-107a7700a51d,bass_guitar,9.033754,vdbgp-d-pre,inference-d
1,704f1de9-1d02-4c2b-af05-107a7700a51d,drums,12.501521,vdbgp-d-pre,inference-d
2,704f1de9-1d02-4c2b-af05-107a7700a51d,acoustic_guitar,1.717476,vdbgp-d-pre,inference-d
3,704f1de9-1d02-4c2b-af05-107a7700a51d,lead_male_singer,7.361787,vdbgp-d-pre,inference-d
4,8a6c9c1f-4865-404f-a805-1949de36a33c,lead_female_singer,13.685509,vdbgp-d-pre,inference-d
...,...,...,...,...,...
2647,a56d9450-3a26-485c-8ac3-24b6b54e2c1d,acoustic_guitar,10.415203,vdbgp-d-pre-aug-bal,inference-o
2648,1f98fe4d-26c7-460f-9f68-33964bc4d8d3,distorted_electric_guitar,4.498278,vdbgp-d-pre-aug-bal,inference-o
2649,1f98fe4d-26c7-460f-9f68-33964bc4d8d3,bass_guitar,8.909531,vdbgp-d-pre-aug-bal,inference-o
2650,1f98fe4d-26c7-460f-9f68-33964bc4d8d3,drums,10.670884,vdbgp-d-pre-aug-bal,inference-o


In [61]:
df["snr"] = df["snr"].replace(-np.inf, np.nan)

In [62]:
df.to_csv(os.path.join(INFERENCE_ROOT, "bandit_vdbgp.csv"), index=False)

In [194]:
df = pd.read_csv(os.path.join(INFERENCE_ROOT, "bandit_vdbgp.csv"))

stem_dtype = pd.CategoricalDtype(categories=allowed_stems, ordered=True)
df["stem"] = df["stem"].astype(stem_dtype)

bool_dtype = pd.CategoricalDtype(categories=["Y", "N"], ordered=True)
ibool_dtype = pd.CategoricalDtype(categories=["N", "Y"], ordered=True)

df["is_frozen"] = df["variant"].str.contains("prefz").apply(lambda x: "Y" if x else "N").astype(bool_dtype)
df["is_balanced"] = df["variant"].str.contains("bal").apply(lambda x: "Y" if x else "N").astype(ibool_dtype)
df["is_augmented"] = df["variant"].str.contains("aug").apply(lambda x: "Y" if x else "N").astype(ibool_dtype)
df["query_same"] = df["inference_mode"].str.contains("-o").apply(lambda x: "same" if x else "diff.")

In [195]:
def q25(x):
    return x.quantile(0.25)

def q75(x):
    return x.quantile(0.75)

def q50(x):
    return x.quantile(0.5)

dfagg = df.groupby([
    "is_frozen",
    "is_augmented",
    "is_balanced",
    "query_same",
    "stem"
]).agg({"snr": [q25, q50, q75]})
dfagg.columns = ["q25", "q50", "q75"]
dfagg = dfagg.reset_index()
dfagg = dfagg.reset_index().pivot_table(
    index=["is_frozen", "is_augmented", "is_balanced", "query_same"],
    columns="stem",
    values=["q25", "q50", "q75"]
)


  dfagg = df.groupby([


In [196]:
dfagg = dfagg.swaplevel(axis=1).sort_index(axis=1)

dfagg



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,stem,lead_female_singer,lead_female_singer,lead_female_singer,lead_male_singer,lead_male_singer,lead_male_singer,drums,drums,drums,bass_guitar,...,clean_electric_guitar,distorted_electric_guitar,distorted_electric_guitar,distorted_electric_guitar,grand_piano,grand_piano,grand_piano,electric_piano,electric_piano,electric_piano
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,q25,q50,q75,q25,q50,q75,q25,q50,q75,q25,...,q75,q25,q50,q75,q25,q50,q75,q25,q50,q75
is_frozen,is_augmented,is_balanced,query_same,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2
Y,N,N,diff.,7.198664,9.723279,12.798156,6.314435,7.357547,9.797291,8.157859,9.804378,11.801135,8.015215,...,1.722306,0.604794,2.584025,4.775978,0.095643,0.94394,2.09996,-0.005586,0.215652,0.590281
Y,N,N,same,7.369198,9.691747,12.805657,6.613824,7.426128,9.785307,8.153814,9.805333,11.800478,8.012354,...,1.812555,0.503696,2.538738,4.548766,-0.103905,0.947414,2.103304,0.247375,0.391507,0.559184
Y,N,Y,diff.,6.161379,9.097359,12.302063,6.114836,6.884451,8.748042,7.779492,8.966419,11.010345,7.595034,...,1.463174,0.322619,2.324578,4.031529,-1.498889,0.780164,1.989419,-0.054843,0.401561,0.757468
Y,N,Y,same,6.25017,9.161142,12.315944,5.442377,7.065612,8.892028,7.777564,8.954748,11.022204,7.557049,...,1.567352,0.384457,2.235533,4.021377,-1.435887,0.767682,1.977984,0.31292,0.480738,0.698899
N,N,N,diff.,5.952838,9.630895,12.874142,6.463644,7.947806,9.908751,7.881636,9.339234,11.766748,7.897048,...,2.941832,0.903565,2.742352,5.189143,0.740644,2.365721,3.059876,0.136738,0.631473,0.744
N,N,N,same,6.257276,9.636968,12.874389,6.461676,7.945579,9.902304,7.845629,9.3438,11.762347,7.902327,...,2.423461,0.653814,2.747163,4.717952,0.744294,2.370211,3.079071,-0.778946,0.823196,1.677319
N,Y,N,diff.,5.008355,9.851802,13.350157,6.362659,8.032614,10.098088,8.560939,9.983489,12.399178,8.709943,...,2.095459,0.543903,2.276117,4.393632,0.4791,1.570971,2.857273,2.7e-05,0.000667,0.202467
N,Y,N,same,5.147925,9.868163,13.35022,6.323971,8.017485,10.071638,8.559176,9.985436,12.39927,8.710634,...,1.868485,0.555051,2.258218,4.542683,0.463992,1.423379,2.961855,0.021866,0.10785,0.135602


In [191]:
dfagg_max = dfagg.max(axis=0)

In [192]:

def bold_formatter(x, val):
    if round(x, 1) == round(val, 1):
        return r"\bfseries " + f"{x:.1f}"
    return f"{x:.1f}"

formatters = {
    (c, d): partial(bold_formatter, val=dfagg_max.loc[c, d])
    for c, d in dfagg.columns
}

In [193]:
str_ = dfagg.to_latex(
    formatters=formatters,
    sparsify=True,
    multirow=False,
)

print(str_)

\begin{tabular}{llllrrrrrrrrrrrrrrrrrrrrrrrrrrr}
\toprule
 &  &  & stem & \multicolumn{3}{r}{lead_female_singer} & \multicolumn{3}{r}{lead_male_singer} & \multicolumn{3}{r}{drums} & \multicolumn{3}{r}{bass_guitar} & \multicolumn{3}{r}{acoustic_guitar} & \multicolumn{3}{r}{clean_electric_guitar} & \multicolumn{3}{r}{distorted_electric_guitar} & \multicolumn{3}{r}{grand_piano} & \multicolumn{3}{r}{electric_piano} \\
 &  &  &  & q25 & q50 & q75 & q25 & q50 & q75 & q25 & q50 & q75 & q25 & q50 & q75 & q25 & q50 & q75 & q25 & q50 & q75 & q25 & q50 & q75 & q25 & q50 & q75 & q25 & q50 & q75 \\
is_frozen & is_augmented & is_balanced & query_same &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  \\
\midrule
Y & N & N & diff. & 7.2 & 9.6 & 12.8 & 6.4 & 7.6 & 9.9 & 8.2 & 9.8 & 11.8 & 8.1 & 10.1 & 12.1 & 0.4 & 1.5 & 2.4 & 0.1 & 0.5 & 1.6 & 0.6 & 2.7 & 4.9 & -0.3 & 0.8 & 2.4 & 0.2 & 0.5 & 0.9 \\
 &  &  & same & \bfseries 7.3 & 9.7 & 12.8 & 6.7 & 7.6 & 9.9 & 8.2 & 9.8 &