In [None]:
import numpy as np

# from tqdm.notebook import tqdm
from tqdm import tqdm
import matplotlib.pyplot as plt
import librosa
from pprint import pprint
import scipy.ndimage
from IPython.display import display, Audio
from nmf.activation_learner import ActivationLearner
import scipy.signal
import logging
from unmixdb import UnmixDB
from tensorboardX import SummaryWriter
import param_estimator

# Use matplotlib's dark theme for every figure drawn in this notebook.
plt.style.use("dark_background")
# Root logger at INFO; the "activation_learner" logger is raised to DEBUG so
# that its more detailed messages are shown as well.
logging.basicConfig(level=logging.INFO)
logging.getLogger("activation_learner").setLevel(logging.DEBUG)
from pathlib import Path
import scipy.sparse
import multiprocessing
import cv2

In [None]:
# Load audios
#
# Chooses the files to analyse. With UnmixDB the list holds the reference
# audio of every track of one mix, followed by the mix itself (last entry).

FS = 22050  # sample rate requested from librosa when the files are loaded
mix = None  # remains None when a hand-written path list is used instead
# input_paths = ["linear-mix-1.wav", "linear-mix-2.wav", "linear-mix.wav"]
# input_paths = ["original.wav", "original.wav"]
# input_paths = ["original.wav", "boucled.wav"]
# input_paths = ["amen.wav", "high.wav", "nuttah.wav"]
# input_paths = ["nuttah.wav", "nuttah-timestretch.wav"]


## load trackidnet
# mix_id = 131733
# track_ids = [637380, 276567, 219825, 284725, 15102, 23080, 9579, 598444, 24212, 249590]
# TRACKIDNET_PATH = Path("/data5/anasynth_nonbp/andre/trackidnet/")
# input_paths = [TRACKIDNET_PATH / f"tracks/{i}" for i in track_ids] + [
#     TRACKIDNET_PATH / f"mixes/{mix_id}"
# ]

# load unmixdb
unmixdb = UnmixDB("/data2/anasynth_nonbp/schwarz/abc-dj/data/unmixdb-zenodo")
print(unmixdb.timestretches)
print(unmixdb.fxes)
# Keep only the mixes whose timestretch is "stretch" and whose fx is "none".
mixes = {
    mix_name: candidate
    for mix_name, candidate in unmixdb.mixes.items()
    if candidate.timestretch == "stretch" and candidate.fx == "none"
}
# Work on the second mix that passed the filter.
mix = list(mixes.values())[1]
pprint(mix)
# Reference tracks first, the mix recording last.
input_paths = [
    *(unmixdb.refsongs[track["name"]].audio_path for track in mix.tracks),
    mix.audio_path,
]


def load_audio(path):
    """Load the audio file at ``path`` with librosa, requesting sample rate ``FS``.

    ``librosa.load`` returns a pair; only its first element (the samples) is
    returned here.
    """
    samples, _sample_rate = librosa.load(path, sr=FS)
    return samples


# Load every input in a pool of worker processes. ``imap`` yields the results
# in the order of ``input_paths``; tqdm reports progress as they arrive.
with multiprocessing.Pool() as pool:
    loading = pool.imap(load_audio, input_paths)
    inputs = list(tqdm(loading, desc="loading audios", total=len(input_paths)))

# Linear 0 -> 1 ramp of FS samples, multiplied in place onto the beginning of
# the third input.
# NOTE(review): the index 2 is hard-coded — confirm it targets the intended
# signal for the selected mix.
fade = np.linspace(0, 1, FS)
inputs[2][: fade.size] *= fade

# for i in inputs:
# display(Audio(i, rate=FS))

In [None]:
import itertools


def carve(H: np.ndarray, split_idx, threshold):
    """Zero out, in place, the row bands of ``H`` that weigh little in a column.

    Consecutive entries of ``split_idx`` delimit row bands ``H[left:right]``.
    For every band and every column, the band's share of the column total
    (computed once, before any entry is zeroed) is compared to ``threshold``;
    where the share is smaller, the band's entries in that column are set to 0.

    Args:
        H: 2-D array, modified in place.
        split_idx: sequence of row boundaries; each adjacent pair is one band.
        threshold: minimum share of the column total a band must reach to be
            kept.

    Returns:
        The same array object ``H`` that was passed in.
    """
    # Named ``column_totals`` rather than ``sum`` so the builtin is not shadowed.
    column_totals = H.sum(axis=0)
    for left, right in zip(split_idx, split_idx[1:]):
        # A column whose total is 0 yields 0/0 = NaN. NaN < threshold is False,
        # so such columns are left untouched; the error state is only silenced
        # so that this expected case does not emit a RuntimeWarning.
        with np.errstate(divide="ignore", invalid="ignore"):
            vol = H[left:right, :].sum(axis=0) / column_totals
            too_quiet = vol < threshold
        H[left:right, too_quiet] = 0
    return H


def carve_naive(H: np.ndarray, threshold):
    """Set every entry of ``H`` that is below ``threshold`` to 0, in place.

    Returns the same array object that was passed in.
    """
    below_threshold = np.less(H, threshold)
    np.putmask(H, below_threshold, 0)
    return H


def imshow_highlight_zero(X: np.ndarray, **kwargs):
    """Draw ``X`` with ``plt.imshow``, with exact zeros replaced by NaN.

    matplotlib does not colour NaN cells with the regular colormap range, so
    entries that are exactly 0 stand out from small non-zero values. ``X``
    itself is never modified.

    Args:
        X: array to display.
        **kwargs: forwarded unchanged to ``plt.imshow``.
    """
    X = np.asarray(X)
    # np.where builds a new array and promotes integer input to float, so the
    # NaN substitution also works for integer arrays (an in-place assignment of
    # NaN into an integer copy would raise ValueError).
    X_ = np.where(X == 0, np.nan, X)
    plt.imshow(X_, **kwargs)

# Multi-resolution fit: one ActivationLearner is fitted per entry of HOP_SIZES
# (seconds, per the log message below). Every round after the first starts
# from the previous round's activation matrix, carved and resized to the new
# model's shape.
HOP_SIZES = [5, 1]
OVERLAP_FACTOR = 4  # win_size = OVERLAP_FACTOR * hop_size
previous_H = None
previous_split_idx = None
for hop_size in HOP_SIZES:
    win_size = OVERLAP_FACTOR * hop_size
    logging.info(f"Starting round with {hop_size=}s, {win_size=}s")

    model = ActivationLearner(
        inputs,
        fs=FS,
        n_mels=256,
        beta=0,
        win_size=win_size,
        hop_size=hop_size,
        lambda_variance=0,
    )

    if previous_H is not None:
        # Zero the low-contribution bands of the previous activations, then
        # resize them to the shape of this round's activation matrix.
        previous_H = carve(previous_H.toarray(), previous_split_idx, 1e-2)
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.title("H carved")
        imshow_highlight_zero(previous_H, cmap="turbo", origin="lower", aspect="auto")
        previous_H = cv2.resize(
            previous_H,
            # cv2 takes dsize as (width, height), i.e. (columns, rows).
            dsize=(model.nmf.H.shape[1], model.nmf.H.shape[0]),
            interpolation=cv2.INTER_AREA,
        )
        plt.subplot(1, 2, 2)
        plt.title("H resized")
        imshow_highlight_zero(previous_H, cmap="turbo", origin="lower", aspect="auto")
        plt.show()
        previous_H = scipy.sparse.bsr_array(previous_H)
        model.nmf.H = previous_H

    # Iterate until the summed loss changes by less than 1e-8 between two
    # consecutive iterations. There is no upper bound on the iteration count.
    losses = []
    for i in (pbar := tqdm(itertools.count())):
        loss = model.iterate()
        if i == 0:
            dloss = np.inf
        else:
            dloss = abs(np.sum(losses[-1]) - np.sum(loss))
        losses.append(loss)

        pbar.set_description(f"loss{np.sum(loss):.2e} {dloss=:.2e}")
        if dloss < 1e-8:
            break
    losses = np.array(losses)
    # One row per iteration, one column per loss component (a scalar loss per
    # iteration becomes a single column).
    loss_table = losses.reshape(len(losses), -1)
    fig, ax1 = plt.subplots(figsize=(8, 3))
    ax1.set_xlabel("iter")
    ax1.set_ylabel("losses")
    for i, loss_component in enumerate(loss_table.T):
        ax1.plot(loss_component, label=f"loss {i}")
    # Total loss per iteration: sum across components. (Summing the leftover
    # ``loss_component`` loop variable would collapse the last component into
    # a single number instead of producing a curve.)
    ax1.plot(loss_table.sum(axis=1), label="total")
    ax1.set_yscale("log")
    ax1.legend()
    fig.tight_layout()
    plt.show()
    model.plot()
    est_volumes = model.est_volumes()
    est_positions = model.est_timeremap()
    if mix is not None:
        # NOTE(review): if mix.duration is in seconds, the arange stop below
        # equals the mix length in samples only when hop_size == 1; for larger
        # hop sizes the time grid ends early — confirm the intended grid.
        T = np.arange(0, int(mix.duration / hop_size * FS), int(hop_size * FS)) / FS
        real_volumes = mix.get_track_volumes(T)
        real_positions = mix.get_track_positions(T)
        param_estimator.plot_vol_pos(
            est_volumes, est_positions, hop_size, real_volumes, real_positions
        )
        print(
            f"volume rel error={param_estimator.rel_error(est_volumes, real_volumes):.2e}"
        )
        print(
            f"position rel error={param_estimator.rel_error(est_positions, real_positions):.2e}"
        )
    else:
        param_estimator.plot_vol_pos(est_volumes, est_positions, hop_size, None, None)

    # Hand this round's activations and split indices to the next round.
    previous_H = model.nmf.H
    previous_split_idx = model.split_idx

In [None]:
# Large view of the activation matrix of the most recently fitted model.
h_fig, h_ax = plt.subplots(figsize=(16, 8))
h_image = h_ax.imshow(
    model.nmf.H.toarray(), cmap="turbo", aspect="auto", origin="lower"
)
h_fig.colorbar(h_image, ax=h_ax)

In [None]:
import pickle

# Serialise the fitted model object to "model.pickle" (relative path).
with Path("model.pickle").open("wb") as f:
    pickle.dump(model, f)

In [None]:
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt

# Classic straight-line Hough transform of the activation matrix.
# 360 angles over [-90 deg, 90 deg) give an angular step of 0.5 degree.
tested_angles = np.linspace(-np.pi / 2, np.pi / 2, 360, endpoint=False)
# Subtract 1e-4 and clamp at zero so that activations below 1e-4 become 0.
hough_input = np.fmax(model.nmf.H.toarray() - 1e-4, 0)
h, theta, d = skimage.transform.hough_line(hough_input, theta=tested_angles)

# Show the accumulator with angles in degrees on x and distances on y.
theta_deg = np.rad2deg(theta)
plt.imshow(
    h,
    extent=(theta_deg[0], theta_deg[-1], d[-1], d[0]),
    aspect="auto",
    cmap="turbo",
)
plt.xlabel("Theta (degrees)")
plt.ylabel("Distance (pixels)")
plt.colorbar()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

n = 3  # Example value, you can set this to any integer

# Flat indices of the n largest accumulator cells (ascending by value),
# converted back to (row, column) coordinates of ``h``.
flat_h = h.flatten()
sorted_indices = np.argsort(flat_h)[-n:]
peak_rows, peak_cols = np.unravel_index(sorted_indices, h.shape)
highest_points = np.column_stack((peak_rows, peak_cols))

# Overlay the peaks on the accumulator: columns on x, rows on y.
plt.imshow(h, cmap="turbo", aspect="auto")
plt.colorbar()
plt.scatter(peak_cols, peak_rows, c="red")
plt.title(f"Top {n} highest points")
plt.show()