In [87]:
from functools import partial
from itertools import chain

import librosa.effects
import numpy as np
import pandas as pd
import soundfile as sf

import os

from ipywidgets import Audio
from matplotlib import pyplot as plt
from tqdm.contrib.concurrent import process_map
from tqdm.notebook import tqdm

import librosa

In [88]:
# Root directories for model outputs and MoisesDB ground truth. The defaults are
# the original author's local paths; set the environment variables below to run
# the notebook on another machine without editing code.
INFERENCE_ROOT = os.environ.get("MOISESDB_INFERENCE_ROOT", "/home/kwatchar3/Documents/data/moisesdb/")
# Stem-setup sub-directory used under <INFERENCE_ROOT>/<inference mode>/.
STEM_SETUP = "vdb"
GROUND_TRUTH_ROOT = os.environ.get("MOISESDB_GROUND_TRUTH_ROOT", "/home/kwatchar3/Documents/data/moisesdb")

In [89]:
# Model variants to evaluate; each name is a sub-directory under
# <INFERENCE_ROOT>/<inference mode>/<STEM_SETUP>/.
variants = [
    "vdb-d-nopre",
    "vdb-d-prefz",
    "vdb-d-pre",
    "vdb-d-pre-aug",
]


In [90]:
# Entries of <GROUND_TRUTH_ROOT>/npy2 (presumably one directory per song, matching
# the layout read in get_results_for_song — TODO confirm). Not referenced by any
# other cell shown here.
gt_files = os.listdir(os.path.join(GROUND_TRUTH_ROOT, "npy2"))

In [91]:
def snr(gt, est):
    """Signal-to-noise ratio of an estimate against a reference, in dB.

    Computes ``10 * log10(sum(gt**2) / sum((gt - est)**2))`` over all elements.

    Args:
        gt: Reference signal (array-like).
        est: Estimated signal; anything that broadcasts against ``gt``
            (a scalar ``0`` scores the reference against silence, i.e. 0 dB).

    Returns:
        The SNR in dB. It is ``-inf`` when ``gt`` is all zeros but the error
        is not, and numpy's division-by-zero result when the error is zero.
    """
    signal_energy = np.sum(np.square(gt))
    error_energy = np.sum(np.square(gt - est))
    ratio = signal_energy / error_energy
    return 10 * np.log10(ratio)

In [92]:
# Coarse stem -> fine-grained stems whose ground truth / predictions are summed
# to form that coarse stem. Key order also fixes the stem column order later.
allowed_stems = dict(
    vocals=["lead_female_singer", "lead_male_singer"],
    drums=["drums"],
    # Synth-bass stems ("bass_synthesizer", "bass_synth") are currently excluded.
    bass=["bass_guitar"],
)


In [93]:
def get_results_for_song(inputs):
    """Compute per-stem SNRs (dB) for one song of one model variant.

    Args:
        inputs: ``(song_name, inference_mode, variant)``. Packed as a single
            tuple so the function can be mapped directly with ``process_map``.

    Returns:
        A list of dicts with keys ``song``, ``stem``, ``snr``, ``variant`` and
        ``inference_mode``:

        * one record per coarse stem in ``allowed_stems`` that has at least one
          fine-stem ground-truth file (coarse stems without any are skipped);
        * an extra ``"lead_vocals"`` record right after ``"vocals"``, scored
          against the sum of the lead-singer ground truths only;
        * a final ``"residual"`` record: mixture minus all evaluated coarse
          predictions, scored against ``vdbo_others.npy``.

    A missing prediction file contributes nothing to its coarse stem, which is
    therefore scored against silence when no fine prediction exists at all.
    """
    song_name, inference_mode, variant = inputs

    gt_dir = os.path.join(GROUND_TRUTH_ROOT, "npy2", song_name)
    pred_dir = os.path.join(INFERENCE_ROOT, inference_mode, STEM_SETUP, variant, "audio", song_name)

    results = []
    outputs = []  # summed prediction of every evaluated coarse stem

    def add_result(stem, value):
        """Append one result record for this song/variant/mode."""
        results.append({
            "song": song_name,
            "stem": stem,
            "snr": value,
            "variant": variant,
            "inference_mode": inference_mode,
        })

    for coarse_stem, fine_stems in allowed_stems.items():
        coarse_pred = []
        coarse_true = []

        for stem in fine_stems:
            npy_path = os.path.join(gt_dir, f"{stem}.npy")
            stem_has_gt = os.path.exists(npy_path)
            if stem_has_gt:
                coarse_true.append(np.load(npy_path, mmap_mode="r"))

            audio_path = os.path.join(pred_dir, f"{stem}.wav")
            if os.path.exists(audio_path):
                audio, _ = sf.read(audio_path)
                # soundfile returns (frames, channels); transposed presumably to
                # match the (channels, frames) .npy ground truth — TODO confirm.
                coarse_pred.append(audio.T)
            elif stem_has_gt:
                print("******************************************************")
                print(f"Prediction not found for {song_name}/{stem}. Using zeros.")
                print("******************************************************")

        # No fine-stem ground truth at all: this coarse stem is not evaluated.
        if not coarse_true:
            continue

        # Prefer a pre-summed coarse ground truth when one is stored on disk.
        coarse_true_path = os.path.join(gt_dir, f"{coarse_stem}.npy")
        if os.path.exists(coarse_true_path):
            coarse_reference = np.load(coarse_true_path, mmap_mode="r")
        else:
            coarse_reference = sum(coarse_true)

        # sum([]) == 0, so a coarse stem with no predictions broadcasts as silence.
        coarse_pred = sum(coarse_pred)

        add_result(coarse_stem, snr(coarse_reference, coarse_pred))

        if coarse_stem == "vocals":
            # Score the same prediction against the lead singers only.
            add_result("lead_vocals", snr(sum(coarse_true), coarse_pred))

        outputs.append(coarse_pred)

    mixture = np.load(os.path.join(gt_dir, "mixture.npy"), mmap_mode="r")
    residual_true = np.load(os.path.join(gt_dir, "vdbo_others.npy"), mmap_mode="r")
    residual_pred = mixture - sum(outputs)

    add_result("residual", snr(residual_true, residual_pred))

    return results


In [94]:

# Evaluate every variant under both inference modes and collect one row per
# (song, stem, variant, inference mode) into a single DataFrame.
records = []

for inference_mode in ["inference-d", "inference-o"]:
    for v in variants:
        print(f"Processing {inference_mode}/{v}...")

        # os.listdir returns entries in arbitrary order; sort so the row order
        # of `df` (and of the CSV written from it) is reproducible.
        test_files = sorted(os.listdir(os.path.join(INFERENCE_ROOT, inference_mode, STEM_SETUP, v, "audio")))

        inputs = [(song, inference_mode, v) for song in test_files]

        # Songs are evaluated in parallel; each call returns a list of result dicts.
        results = process_map(get_results_for_song, inputs, max_workers=16)
        results = list(chain.from_iterable(results))

        records.extend(results)

df = pd.DataFrame(records)

Processing vdb-d-nopre...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdb-d-prefz...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdb-d-pre...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdb-d-pre-aug...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdb-d-nopre...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdb-d-prefz...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdb-d-pre...


  0%|          | 0/48 [00:00<?, ?it/s]

Processing vdb-d-pre-aug...


  0%|          | 0/48 [00:00<?, ?it/s]

In [95]:
df["snr"] = df["snr"].replace(-np.inf, np.nan)

In [96]:
# Persist the per-song results so the analysis below can restart from the CSV.
df.to_csv(
    path_or_buf=os.path.join(INFERENCE_ROOT, "bandit_vdb_merged.csv"),
    index=False,
)

In [97]:
df = pd.read_csv(os.path.join(INFERENCE_ROOT, "bandit_vdb_merged.csv"))

# Stems ordered as the keys of `allowed_stems`. Stem labels that are not keys
# (e.g. "lead_vocals", "residual") are not categories and become NaN here.
stem_dtype = pd.CategoricalDtype(categories=list(allowed_stems), ordered=True)
df["stem"] = df["stem"].astype(stem_dtype)

# Ordered Y/N categoricals fix the sort order of these flags when grouped:
# bool_dtype sorts "Y" first, ibool_dtype sorts "N" first.
bool_dtype = pd.CategoricalDtype(categories=["Y", "N"], ordered=True)
ibool_dtype = pd.CategoricalDtype(categories=["N", "Y"], ordered=True)

# Configuration flags are decoded from substrings of the variant name.
df["is_pretrained"] = (
    df["variant"].str.contains("nopre").map({True: "N", False: "Y"}).astype(ibool_dtype)
)
df["is_frozen"] = (
    df["variant"].str.contains("prefz").map({True: "Y", False: "N"}).astype(bool_dtype)
)
df["is_balanced"] = (
    df["variant"].str.contains("bal").map({True: "Y", False: "N"}).astype(ibool_dtype)
)
df["is_augmented"] = (
    df["variant"].str.contains("aug").map({True: "Y", False: "N"}).astype(ibool_dtype)
)
df["query_same"] = df["inference_mode"].str.contains("-o").map({True: "same", False: "diff."})

In [98]:
def q25(x):
    """First quartile of ``x`` (pandas' default linear interpolation)."""
    return x.quantile(q=0.25)

def q75(x):
    """Third quartile of ``x`` (pandas' default linear interpolation)."""
    return x.quantile(q=0.75)

def q50(x):
    """Median of ``x``, computed as the 0.5 quantile."""
    return x.quantile(q=0.5)

# Median SNR per model configuration and stem. The group keys are ordered
# categoricals, so observed=True is passed explicitly: only combinations that
# actually occur are built, and pandas' FutureWarning about the changing
# default for categorical group keys is avoided. Rows with a NaN stem (labels
# outside the stem categories) are dropped by groupby.
# is_balanced is not a key: none of the listed variants contain "bal".
dfagg = df.groupby([
    "is_pretrained",
    "is_frozen",
    "is_augmented",
    "query_same",
    "stem",
], observed=True).agg({"snr": [q50]})
dfagg.columns = ["q50"]
dfagg = dfagg.reset_index()

# Keep only the rows labelled "diff." (inference modes without "-o").
dfagg = dfagg[dfagg.query_same == "diff."]

# Rows: model configuration; columns: MultiIndex (q50, stem).
dfagg = dfagg.pivot_table(
    index=["is_pretrained", "is_frozen", "is_augmented"],
    columns="stem",
    values=["q50"],
    observed=True,
)


  dfagg = df.groupby([


In [99]:
# Move stem to the outer column level, i.e. (stem, q50), and sort the columns
# (the stem level follows its category order); then take each column's maximum
# so the best value per column can be highlighted.
dfagg = dfagg.swaplevel(0, 1, axis=1).sort_index(axis=1)
dfagg_max = dfagg.max()


In [100]:

def bold_formatter(x, val):
    """Format ``x`` to one decimal, in LaTeX bold when it matches ``val``.

    The comparison is made after rounding both numbers to one decimal, i.e.
    at the precision that is actually printed.
    """
    text = f"{x:.1f}"
    is_best = round(x, 1) == round(val, 1)
    return r"\bfseries " + text if is_best else text

# One formatter per table column that bolds that column's maximum.
# dfagg_max is indexed by dfagg's columns, so its items give (column, max).
formatters = {
    column: partial(bold_formatter, val=best)
    for column, best in dfagg_max.items()
}

In [101]:
# Render the comparison table as LaTeX: the index values are repeated on
# every row (no sparsified or multirow cells), and the per-column formatters
# handle bolding.
str_ = dfagg.to_latex(
    sparsify=False,
    multirow=False,
    formatters=formatters,
)

print(str_)

\begin{tabular}{lllrrr}
\toprule
 &  & stem & vocals & drums & bass \\
 &  &  & q50 & q50 & q50 \\
is_pretrained & is_frozen & is_augmented &  &  &  \\
\midrule
N & N & N & 6.7 & 9.4 & 9.6 \\
Y & Y & N & \bfseries 8.0 & 9.9 & 10.5 \\
Y & N & N & 7.4 & 9.3 & 10.5 \\
Y & N & Y & 7.9 & \bfseries 10.1 & \bfseries 11.2 \\
\bottomrule
\end{tabular}

