In [116]:
import librosa.effects
import numpy as np
import pandas as pd
import soundfile as sf

import os

from ipywidgets import Audio
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

import librosa

In [117]:
from norbert import wiener

In [118]:
# Dataset locations. The defaults are the original author's machine; set the
# environment variables to run this notebook elsewhere.
GROUND_TRUTH_ROOT = os.environ.get("MOISESDB_ROOT", "/home/kwatchar3/Documents/data/moisesdb")
SPLITS_CSV = os.environ.get(
    "QUERY_BANDIT_SPLITS_CSV",
    "/home/kwatchar3/projects/query-bandit/reproducibility/splits.csv",
)
# Fold of splits.csv evaluated in this notebook
# (NOTE(review): presumably the held-out test fold — confirm against splits.csv).
EVAL_SPLIT = 5

splits = pd.read_csv(SPLITS_CSV)

# Song ids (directory names under GROUND_TRUTH_ROOT/npy2) belonging to the evaluated fold.
gt_files = splits.loc[splits["split"] == EVAL_SPLIT, "song_id"].values

In [119]:
# Target stems to evaluate, in the order used for reporting.
# Stems missing from a song's directory are skipped by compute_irm.
stems = [
    # vocals
    "lead_female_singer",
    "lead_male_singer",
    # rhythm section
    "drums",
    "bass_guitar",
    # guitars
    "acoustic_guitar",
    "clean_electric_guitar",
    "distorted_electric_guitar",
    # keys
    "grand_piano",
    "electric_piano",
]

In [120]:
def snr(gt, est):
    """Signal-to-noise ratio of an estimate against a reference, in dB.

    Computed as ``10 * log10(sum(gt**2) / sum((gt - est)**2))``, summing over
    every element of the arrays (all channels and samples together).
    No epsilon is added: a perfect estimate yields ``inf`` (with a numpy
    divide warning) and an all-zero reference yields ``-inf`` or ``nan``.
    """
    signal_energy = np.sum(np.square(gt))
    error_energy = np.sum(np.square(gt - est))
    return 10 * np.log10(signal_energy / error_energy)

In [121]:
def compute_irm(song_id, n_fft=2048, hop_length=512, pad_mode="constant", eps=1e-8):
    """Score oracle ideal ratio (IRM) and ideal binary (IBM) masks for one song.

    For every name in the global ``stems`` list that has a ``<stem>.npy`` file in
    ``GROUND_TRUTH_ROOT/npy2/<song_id>/``, the stem STFT ``S`` and the residual
    ``X - S`` (mixture STFT minus stem STFT) are used to build two masks:

    * ``oracle_irm``: ``(|S| + eps) / (|S| + |X - S| + eps)``
    * ``oracle_ibm``: ``|S| > |X - S|``

    Each mask is applied to the mixture STFT ``X``, inverted with ``istft``
    (trimmed to the stem's length) and scored against the stem with ``snr``.

    Args:
        song_id: song directory name under ``GROUND_TRUTH_ROOT/npy2``.
        n_fft: STFT/ISTFT window length (default 2048).
        hop_length: STFT/ISTFT hop length (default 512).
        pad_mode: padding mode used for BOTH the mixture and stem STFTs.
        eps: small constant keeping the IRM finite in bins where both
            magnitudes are zero (the mask evaluates to 1 there).

    Returns:
        List of dicts with keys ``song_id``, ``stem``, ``snr`` and ``method``.
        For each present stem, an ``oracle_irm`` record is followed by an
        ``oracle_ibm`` record.
    """
    song_dir = os.path.join(GROUND_TRUTH_ROOT, "npy2", song_id)

    mixture = np.load(os.path.join(song_dir, "mixture.npy"))
    # Use the same explicit pad_mode for mixture and stems: X - S is only the
    # "rest of the mix" if both STFTs are framed/padded identically. librosa's
    # default pad_mode was "reflect" before 0.10, so relying on it for X (but not
    # for S) made the residual version-dependent.
    X = librosa.stft(mixture, n_fft=n_fft, hop_length=hop_length, pad_mode=pad_mode)

    res = []

    for stem_name in stems:
        stem_path = os.path.join(song_dir, f"{stem_name}.npy")
        if not os.path.exists(stem_path):
            continue

        target = np.load(stem_path)
        # Last axis is time (samples); istft output is trimmed to this length.
        n_samples = target.shape[-1]

        S = librosa.stft(target, n_fft=n_fft, hop_length=hop_length, pad_mode=pad_mode)
        target_mag = np.abs(S)
        residual_mag = np.abs(X - S)

        # Insertion order fixes the record order: IRM first, then IBM.
        masks = {
            "oracle_irm": (target_mag + eps) / (target_mag + residual_mag + eps),
            "oracle_ibm": target_mag > residual_mag,
        }

        for method, mask in masks.items():
            y = librosa.istft(mask * X, n_fft=n_fft, hop_length=hop_length, length=n_samples)
            res.append({
                "song_id": song_id,
                "stem": stem_name,
                "snr": snr(target, y),
                "method": method,
            })

    return res

In [122]:
from tqdm.contrib.concurrent import process_map

# Evaluate every song of the selected fold and flatten all
# per-(song, stem, method) records into one DataFrame built in a single call.
df = pd.DataFrame(
    [record for song_id in tqdm(gt_files) for record in compute_irm(song_id)]
)


  0%|          | 0/48 [00:00<?, ?it/s]

In [123]:
# Persist the raw per-song records (without the RangeIndex) for the summary cells.
df.to_csv(path_or_buf="oracle_irm_ibm.csv", index=False)

In [124]:
# Ordered categorical over `stems`, so sorting by stem follows the list order
# rather than alphabetical order.
stem_dtype = pd.CategoricalDtype(stems, ordered=True)

In [125]:
# Display the raw per-song, per-stem, per-method SNR records.
df

Unnamed: 0,song_id,stem,snr,method
0,0358fd1e-244a-4422-9a42-29b5d68f6e4b,lead_male_singer,8.228316,oracle_irm
1,0358fd1e-244a-4422-9a42-29b5d68f6e4b,lead_male_singer,8.718332,oracle_ibm
2,0358fd1e-244a-4422-9a42-29b5d68f6e4b,drums,9.494342,oracle_irm
3,0358fd1e-244a-4422-9a42-29b5d68f6e4b,drums,10.163924,oracle_ibm
4,0358fd1e-244a-4422-9a42-29b5d68f6e4b,acoustic_guitar,6.922390,oracle_irm
...,...,...,...,...
437,fc3c9e48-e2ac-4088-af65-68404baa7f12,bass_guitar,6.883285,oracle_ibm
438,fc3c9e48-e2ac-4088-af65-68404baa7f12,clean_electric_guitar,3.328224,oracle_irm
439,fc3c9e48-e2ac-4088-af65-68404baa7f12,clean_electric_guitar,3.678336,oracle_ibm
440,fc3c9e48-e2ac-4088-af65-68404baa7f12,grand_piano,5.070795,oracle_irm


In [133]:
# Inter-quartile summary of SNR (dB) per (stem, method), reloaded from the
# saved CSV. Rows are ordered by the `stems` list through the ordered
# categorical `stem_dtype`.
df = (
    pd.read_csv("oracle_irm_ibm.csv")
    .groupby(["stem", "method"])["snr"]
    .describe()
    .loc[:, ["25%", "50%", "75%"]]
    .reset_index()
    .assign(stem=lambda frame: frame["stem"].astype(stem_dtype))
    .sort_values("stem")
)

df

Unnamed: 0,stem,method,25%,50%,75%
15,lead_female_singer,oracle_irm,9.959732,11.296755,11.913341
14,lead_female_singer,oracle_ibm,10.630318,12.033295,12.443081
17,lead_male_singer,oracle_irm,7.992356,9.261748,10.070114
16,lead_male_singer,oracle_ibm,8.597003,9.9281,10.773254
9,drums,oracle_irm,6.958705,8.49988,9.528434
8,drums,oracle_ibm,7.732758,9.383624,10.280637
3,bass_guitar,oracle_irm,6.049543,7.402663,9.265378
2,bass_guitar,oracle_ibm,6.648597,7.91229,9.910038
1,acoustic_guitar,oracle_irm,3.035848,4.191159,5.747699
0,acoustic_guitar,oracle_ibm,3.447459,4.3001,6.205191


In [134]:

# Wide layout: one row per oracle method, columns indexed by (stem, percentile).
# NOTE: this overwrites `df`, so re-running this cell alone fails; re-run the
# summary cell first.
df = (
    df.melt(id_vars=["method", "stem"], var_name="percentile", value_name="snr")
    .pivot_table(index="method", columns=["stem", "percentile"], values="snr")
    .sort_index(axis=1)
)

df

stem,lead_female_singer,lead_female_singer,lead_female_singer,lead_male_singer,lead_male_singer,lead_male_singer,drums,drums,drums,bass_guitar,...,clean_electric_guitar,distorted_electric_guitar,distorted_electric_guitar,distorted_electric_guitar,grand_piano,grand_piano,grand_piano,electric_piano,electric_piano,electric_piano
percentile,25%,50%,75%,25%,50%,75%,25%,50%,75%,25%,...,75%,25%,50%,75%,25%,50%,75%,25%,50%,75%
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
oracle_ibm,10.630318,12.033295,12.443081,8.597003,9.9281,10.773254,7.732758,9.383624,10.280637,6.648597,...,7.275913,4.973454,6.127327,7.521223,3.403751,4.347475,6.052148,4.09523,5.200898,5.520819
oracle_irm,9.959732,11.296755,11.913341,7.992356,9.261748,10.070114,6.958705,8.49988,9.528434,6.049543,...,6.925114,4.584637,5.843509,6.915269,3.09131,4.369795,5.786943,3.918085,4.877171,5.181009


In [135]:
# escape=True is required: since pandas 2.0, to_latex does not escape by default.
# The raw "_" in stem names breaks outside math mode, and the raw "%" in the
# percentile labels starts a LaTeX comment, so the unescaped table does not compile.
print(df.to_latex(float_format="%.1f", escape=True))

\begin{tabular}{lrrrrrrrrrrrrrrrrrrrrrrrrrrr}
\toprule
stem & \multicolumn{3}{r}{lead_female_singer} & \multicolumn{3}{r}{lead_male_singer} & \multicolumn{3}{r}{drums} & \multicolumn{3}{r}{bass_guitar} & \multicolumn{3}{r}{acoustic_guitar} & \multicolumn{3}{r}{clean_electric_guitar} & \multicolumn{3}{r}{distorted_electric_guitar} & \multicolumn{3}{r}{grand_piano} & \multicolumn{3}{r}{electric_piano} \\
percentile & 25% & 50% & 75% & 25% & 50% & 75% & 25% & 50% & 75% & 25% & 50% & 75% & 25% & 50% & 75% & 25% & 50% & 75% & 25% & 50% & 75% & 25% & 50% & 75% & 25% & 50% & 75% \\
method &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  &  \\
\midrule
oracle_ibm & 10.6 & 12.0 & 12.4 & 8.6 & 9.9 & 10.8 & 7.7 & 9.4 & 10.3 & 6.6 & 7.9 & 9.9 & 3.4 & 4.3 & 6.2 & 3.9 & 5.8 & 7.3 & 5.0 & 6.1 & 7.5 & 3.4 & 4.3 & 6.1 & 4.1 & 5.2 & 5.5 \\
oracle_irm & 10.0 & 11.3 & 11.9 & 8.0 & 9.3 & 10.1 & 7.0 & 8.5 & 9.5 & 6.0 & 7.4 & 9.3 & 3.0 & 4.2 & 5.7 & 3.5 & 5.6 & 6.9 & 4.6 & 5.8 & 