In [None]:
import numpy as np

# from tqdm.notebook import tqdm
from tqdm import tqdm
import matplotlib.pyplot as plt
import librosa
from pprint import pprint
import scipy.ndimage
from IPython.display import display, Audio
from estimator import ActivationLearner
import scipy.signal
import logging
from unmixdb import UnmixDB

# Draw every figure with matplotlib's "dark_background" style.
plt.style.use("dark_background")
# Root logger at INFO; the "activation_learner" logger is lowered to DEBUG
# so its more detailed records are shown as well.
logging.basicConfig(level=logging.INFO)
logging.getLogger("activation_learner").setLevel(logging.DEBUG)
from pathlib import Path
import scipy.sparse
import skimage.transform
import multiprocessing

In [None]:
# Load the audio files

# Sampling rate handed to librosa.load (sr=FS) for every input file.
FS = 22050
# Files to load; the decoded signals are passed to ActivationLearner in this order.
# NOTE(review): the commented-out loaders in this cell list the reference tracks
# first and the mix last — presumably the same convention applies here; confirm.
input_paths = ["linear-mix-1.wav", "linear-mix-2.wav", "linear-mix.wav"]
# input_paths = ["linear-mix.wav"]
# input_paths = ["original.wav", "original.wav"]
# input_paths = ["original.wav", "boucled.wav"]
# input_paths = ["amen.wav", "high.wav", "nuttah.wav"]
# input_paths = ["nuttah.wav", "nuttah-timestretch.wav"]


## load trackidnet
# mix_id = 131733
# track_ids = [637380, 276567, 219825, 284725, 15102, 23080, 9579, 598444, 24212, 249590]
# TRACKIDNET_PATH = Path("/data5/anasynth_nonbp/andre/trackidnet/")
# input_paths = [TRACKIDNET_PATH / f"tracks/{i}" for i in track_ids] + [
#     TRACKIDNET_PATH / f"mixes/{mix_id}"
# ]


def load_audio(path):
    """Load the audio file at ``path`` with librosa, requesting the rate ``FS``.

    ``librosa.load`` returns a ``(samples, rate)`` pair; only the samples are
    returned, the rate being the notebook-wide ``FS`` that was requested.
    """
    samples, _rate = librosa.load(path, sr=FS)
    return samples


# Decode the files in a pool of worker processes. ``imap`` yields results in
# input order, so inputs[k] is the signal loaded from input_paths[k].
with multiprocessing.Pool() as pool:
    inputs = [
        audio
        for audio in tqdm(
            pool.imap(load_audio, input_paths),
            total=len(input_paths),
            desc="loading audios",
        )
    ]

# ## load unmixdb
# unmixdb = UnmixDB("/data2/anasynth_nonbp/schwarz/abc-dj/data/unmixdb-zenodo")
# print(unmixdb.timestretches)
# print(unmixdb.fxes)
# mixes = dict(
#     filter(
#         lambda i: i[1].timestretch == "none" and i[1].fx == "none",
#         unmixdb.mixes.items(),
#     )
# )
# mix = list(mixes.values())[5]
# pprint(mix)
# inputs = [unmixdb.refsongs[track["name"]].audio(sr=FS) for track in mix.tracks]
# inputs.append(np.tile(mix.audio(sr=FS), 1))

In [None]:
import itertools


def carve(H: np.ndarray, split_idx, threshold):
    """Zero the row segments of ``H`` that hold too small a share of a column.

    ``split_idx`` is a sequence of row boundaries; each consecutive pair
    ``(left, right)`` delimits the segment ``H[left:right]``. For every segment
    and every column, the segment's summed value is divided by the column's
    total over all rows. Wherever that ratio is below ``threshold``, the
    segment's entries in that column are set to 0.

    The column totals are computed once, before any carving, so zeroing one
    segment does not change the ratios of the following ones.

    Columns whose total is exactly 0 are left untouched: their ratio is
    undefined, and dividing would only produce NaN and a RuntimeWarning.

    Args:
        H: 2-D array, modified in place.
        split_idx: increasing row indices marking segment boundaries.
        threshold: minimum share of the column total a segment must hold.

    Returns:
        The same array object ``H``, after carving.
    """
    column_total = H.sum(axis=0)
    has_mass = column_total != 0
    for left, right in zip(split_idx, split_idx[1:]):
        # Divide only where the total is non-zero; elsewhere `share` keeps its
        # initial 0 and is masked out by `has_mass` below.
        share = np.zeros(column_total.shape, dtype=float)
        np.divide(
            H[left:right, :].sum(axis=0),
            column_total,
            out=share,
            where=has_mass,
        )
        H[left:right, has_mass & (share < threshold)] = 0
    return H


def carve_naive(H: np.ndarray, threshold):
    """Set every entry of ``H`` that is strictly below ``threshold`` to 0.

    ``H`` is modified in place and the same array object is returned.
    """
    np.putmask(H, H < threshold, 0)
    return H


def imshow_highlight_zero(X: np.ndarray, **kwargs):
    """Display ``X`` with ``plt.imshow`` after replacing exact zeros by NaN.

    The intent, per the function name, is to make exact zeros stand out from
    small non-zero values (matplotlib is expected to treat NaN cells as
    missing data).

    The NaN-masked image is a new array, so ``X`` itself is never modified.
    It is built with ``np.where``, which promotes to a floating dtype; integer
    and boolean inputs therefore work too, whereas assigning NaN into a copy
    of such an array raises ``ValueError``.

    Args:
        X: array to display.
        **kwargs: forwarded unchanged to ``plt.imshow``.
    """
    X = np.asarray(X)
    X_ = np.where(X == 0, np.nan, X)
    plt.imshow(X_, **kwargs)


# Analysis window sizes, visited in this order (coarse to fine): the H learned
# at one size seeds the model built for the next one.
# NOTE(review): presumably expressed in seconds — confirm against ActivationLearner.
WIN_SIZES = [4, 2, 1, 0.5]
# Minimum share of a column's total that a segment must hold to survive carve().
CARVE_THRESHOLD = 1e-2
# The fit stops once the loss moves by less than this between two iterations.
LOSS_TOLERANCE = 1e-4

previous_H = None
previous_split_idx = None
for win_size in WIN_SIZES:
    model = ActivationLearner(
        inputs,
        fs=FS,
        n_mels=256,
        win_size=win_size,
        overlap=0.75,
        stft_win_func=np.hanning,
    )

    if previous_H is not None:
        # Warm start: carve the previous (coarser) H, resize it to the new
        # model's H shape, and install it in place of the model's own H.
        previous_H = carve(
            previous_H.toarray(), previous_split_idx, CARVE_THRESHOLD
        )
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.title("H carved")
        imshow_highlight_zero(previous_H, cmap="turbo", origin="lower", aspect="auto")
        # order=1 selects linear interpolation; preserve_range keeps the
        # original value scale instead of rescaling the data.
        previous_H = skimage.transform.resize(
            previous_H, model.nmf.H.shape, order=1, preserve_range=True
        )
        plt.subplot(1, 2, 2)
        plt.title("H resized")
        imshow_highlight_zero(previous_H, cmap="turbo", origin="lower", aspect="auto")
        plt.show()
        # The model's H is sparse (it exposes .toarray()), so convert back.
        previous_H = scipy.sparse.bsr_array(previous_H)
        model.nmf.H = previous_H

    # Iterate until the loss stops moving; there is no fixed iteration budget.
    losses = []
    for i in (pbar := tqdm(itertools.count())):
        loss = model.iterate()
        # No previous loss on the first pass: use +inf so the stop test fails.
        dloss = np.inf if i == 0 else abs(losses[-1] - loss)
        losses.append(loss)
        pbar.set_description(f"{loss=:.2e} {dloss=:.2e}")
        # NaN is the only value that differs from itself. dloss is NaN as soon
        # as the loss is NaN or stays infinite; `dloss < LOSS_TOLERANCE` would
        # then be False forever and this unbounded loop would never end.
        if dloss != dloss:
            raise FloatingPointError(
                f"loss became non-finite at iteration {i} (win_size={win_size})"
            )
        if dloss < LOSS_TOLERANCE:
            break

    # Loss history for this window size, on a log scale.
    plt.figure(figsize=(8, 1))
    plt.plot(losses)
    plt.xlabel("iter")
    plt.title("distance")
    plt.yscale("log")
    plt.show()
    model.plot()

    # Keep this pass's result to seed the next, finer window size.
    previous_H = model.nmf.H
    previous_split_idx = model.split_idx

In [None]:
# Dense heat-map of the H left by the last fitted model (rows drawn bottom-up).
plt.figure(figsize=(15, 15))
plt.imshow(X=model.nmf.H.toarray(), origin="lower", cmap="turbo")