In [None]:
import logging
import os
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np

# Set to True to acquire a GPU lock through the project-local `manage_gpus`
# helper before any project / torch-based module below is imported.
USE_GPU = False

# Keep this block ABOVE the imports that follow: the lock is taken before
# `pytorch_nmf` and the learner modules are loaded (presumably so that GPU
# selection happens before torch touches a device — TODO confirm with
# manage_gpus).
if USE_GPU:
    import manage_gpus as gpl

    gpl.get_gpu_lock()
import activation_learner
from pytorch_nmf import BetaDivergence
import param_estimator
import plot
from mixes.unmixdb import UnmixDB
from mixes.synthetic import SyntheticDB, ManualMix
import seaborn as sns

# Seaborn "paper" plotting context for every figure in this notebook.
sns.set_theme("paper")

In [None]:
# --- Hyperparameters ---------------------------------------------------------
FS = 22050  # sample rate (Hz) used for the learner and for audio export
# Hop sizes forwarded to activation_learner.multistage (decreasing values,
# presumably one per stage, coarse to fine). The LAST entry defines the time
# grid `tau` of the final estimates.
HOP_SIZES = [8, 4, 1, 0.5]
OVERLAP = 8
NMELS = 256  # number of mel bands
SPEC_POWER = 2  # power applied to the spectrogram; also given to the gain estimator
# beta = 0 — the Itakura-Saito case if BetaDivergence follows the standard
# beta-divergence family (TODO confirm in pytorch_nmf).
DIVERGENCE = BetaDivergence(0)
# Estimator callables applied to the learned activation matrix H.
GAIN_ESTOR = param_estimator.GainEstimator.SUM
WARP_ESTOR = param_estimator.WarpEstimator.ARGMAX
# Thresholds / shape constraints forwarded unchanged to the multistage learner.
LOW_POWER_THRESHOLD = 1e-2
CARVE_THRESHOLD = 1e-5
CARVE_BLUR_SIZE = 3
CARVE_MIN_DURATION = 1
CARVE_MAX_SLOPE = 2
NOISE_DIM = 0

# --- Stop conditions ---------------------------------------------------------
DLOSS_MIN = 1e-8  # minimum loss improvement before stopping
ITER_MAX = 8000  # hard iteration cap

# --- Paths -------------------------------------------------------------------
# Overridable via environment variables so the notebook runs on other machines
# without editing; defaults are the original local paths.
# Alternative locations used in earlier setups:
#   results: /data5/anasynth_nonbp/andre/reverse-dj-mix/results
#   UnmixDB: /data2/anasynth_nonbp/schwarz/abc-dj/data/unmixdb-zenodo
#   other dataset: /home/etiandre/stage/datasets/dj_mix_ground_truth_extractor_dataset
RESULTS_DIR = Path(
    os.environ.get("REVERSE_DJ_MIX_RESULTS_DIR", "/home/etiandre/stage/results/")
)
DATASET_PATH = Path(
    os.environ.get(
        "REVERSE_DJ_MIX_DATASET_PATH", "/home/etiandre/stage/datasets/unmixdb/"
    )
)

# --- Logging -----------------------------------------------------------------
logger = logging.getLogger()
logging.basicConfig()
logger.setLevel(logging.INFO)

In [None]:
# Pick one UnmixDB mix by file name.
# For a synthetic run instead: dataset = SyntheticDB(); mix = dataset.mixes[0]
dataset = UnmixDB(DATASET_PATH)
mix = dataset.get_mix("set281mix3-none-none-77.mp3")
logger.info(mix.name)

# Learner inputs: every source track's audio, then the mix audio as the
# final entry.
inputs = list(track.audio for track in mix.tracks)
inputs.append(mix.audio)

In [None]:
# Fit the activations with the multistage learner. The keyword configuration is
# collected in one dict first, so the exact settings of this run stay
# inspectable after the call.
multistage_params = {
    "hops": HOP_SIZES,
    "overlap": OVERLAP,
    "nmels": NMELS,
    "low_power_threshold": LOW_POWER_THRESHOLD,
    "spec_power": SPEC_POWER,
    "divergence": DIVERGENCE,
    "iter_max": ITER_MAX,
    "dloss_min": DLOSS_MIN,
    "carve_threshold": CARVE_THRESHOLD,
    "carve_blur_size": CARVE_BLUR_SIZE,
    "carve_min_duration": CARVE_MIN_DURATION,
    "carve_max_slope": CARVE_MAX_SLOPE,
    "noise_dim": NOISE_DIM,
    "doplot": False,
}
learner, loss_history = activation_learner.multistage(inputs, FS, **multistage_params)

In [None]:
# Ground truth: evaluate the mix's true gain / warp curves on the frame grid
# of the final stage. `tau` is in the unit of HOP_SIZES (presumably seconds —
# TODO confirm).
tau = np.arange(0, learner.V.shape[1]) * HOP_SIZES[-1]

# For mixes without ground truth, these were previously set to None instead.
real_gain = mix.gain(tau)
real_warp = mix.warp(tau)

# Estimate per-track gain from the learned activations.
logger.info(f"Estimating gain with method {GAIN_ESTOR}")
est_gain = GAIN_ESTOR(learner.H, learner.split_idx, SPEC_POWER)

# Estimate per-track time warp from the learned activations.
logger.info(f"Estimating warp with method {WARP_ESTOR}")
est_warp = WARP_ESTOR(learner.H, learner.split_idx, HOP_SIZES[-1])

# Named `results_fig` (not `fig`) so later cells that reuse `fig` don't
# shadow it.
results_fig, axs = plt.subplots(1, 3, figsize=(16, 3))

# `.cpu()` is a no-op for CPU tensors but is required before `.numpy()` when
# the learner ran on a GPU (USE_GPU = True); otherwise torch raises.
im = plot.plot_H(
    learner.H.detach().cpu().numpy(),
    split_idx=learner.split_idx,
    ignored_lines=learner.W_ignored_cols,
    ax=axs[0],
)
results_fig.colorbar(im, ax=axs[0])
axs[0].set_title("Activation matrix")

plot.plot_gain(tau, est_gain, real_gain, ax=axs[1])
axs[1].set_title("Gain")

plot.plot_warp(tau, est_warp, real_warp, ax=axs[2])
axs[2].set_title("Warp")

results_fig.savefig("results.svg")
plt.show()

# Derive high-level DJ-mix parameters for every track from its estimated gain
# and warp curves, and compare each one with the ground-truth metadata.

# Parameter names, in the order param_estimator.estimate_highparams returns
# them (followed by a figure, which is not part of the record).
HIGHPARAM_NAMES = (
    "track_start",
    "fadein_start",
    "fadein_stop",
    "fadeout_start",
    "fadeout_stop",
    "speed",
)


def _real_highparams(track_meta):
    """Return one track's ground-truth highparams, ordered like HIGHPARAM_NAMES.

    Reads the "start", "fadein" (start, stop), "fadeout" (start, stop) and
    "speed" entries of a single element of ``mix.meta``.
    """
    return (
        track_meta["start"],
        track_meta["fadein"][0],
        track_meta["fadein"][1],
        track_meta["fadeout"][0],
        track_meta["fadeout"][1],
        track_meta["speed"],
    )


def _highparam_record(est_values, real_values):
    """Build one track's result dict.

    Keys are ``<name>_est``, then ``<name>_real``, then ``<name>_err``. The
    error is computed with ``param_estimator.error(est, real)``.
    """
    record = {f"{name}_est": est for name, est in zip(HIGHPARAM_NAMES, est_values)}
    record.update(
        {f"{name}_real": real for name, real in zip(HIGHPARAM_NAMES, real_values)}
    )
    record.update(
        {
            f"{name}_err": param_estimator.error(est, real)
            for name, est, real in zip(HIGHPARAM_NAMES, est_values, real_values)
        }
    )
    return record


highparams = []
# Iterate over the mix's actual tracks rather than assuming exactly three.
for i in range(len(mix.tracks)):
    *est_values, highparam_fig = param_estimator.estimate_highparams(
        tau, est_gain[:, i], est_warp[:, i], doplot=True
    )
    # Guard against silently misaligned names if the estimator's return
    # signature ever changes.
    if len(est_values) != len(HIGHPARAM_NAMES):
        raise ValueError(
            f"estimate_highparams returned {len(est_values) + 1} values, "
            f"expected {len(HIGHPARAM_NAMES) + 1}"
        )
    highparams.append(_highparam_record(est_values, _real_highparams(mix.meta[i])))
pprint(highparams)

In [None]:
import scipy.ndimage
import scipy.signal

# Median-filter the frame-wise estimates along time (axis 0) to suppress
# isolated outliers. Kernel sizes are counted in frames of the final hop.
GAIN_MEDIAN_SIZE = 15
WARP_MEDIAN_SIZE = 10

est_gain_smooth = scipy.ndimage.median_filter(
    est_gain, size=GAIN_MEDIAN_SIZE, axes=0, mode="nearest"
)
# Kept for optional inspection (e.g. plot.plot_warp(tau, est_warp_smooth, est_warp)).
est_warp_smooth = scipy.ndimage.median_filter(
    est_warp, size=WARP_MEDIAN_SIZE, axes=0, mode="nearest"
)

# Smoothed vs. raw gain. plot.plot_gain's third argument is normally the
# ground truth (real_gain); here it is the raw estimate, so the title says so.
smooth_fig, smooth_ax = plt.subplots()
plot.plot_gain(tau, est_gain_smooth, est_gain, ax=smooth_ax)
smooth_ax.set_title("Gain: median-smoothed vs. raw estimate")
plt.show()

import json
import soundfile as sf

# Export the mix, its source tracks and the curves for the external
# visualizer. Single constant so the destination is easy to change.
VIZ_STATIC_DIR = Path("/home/etiandre/git/viz/static")

data = {
    "tau": tau.tolist(),
    "split_idx": [int(i) for i in learner.split_idx],
    "mix": mix.name,
    "tracks": [track.name for track in mix.tracks],
    # NOTE(review): these are the ground-truth curves (real_*), not the
    # estimates — confirm this is what the visualizer should show.
    # nan_to_num keeps the output valid JSON (json.dump would emit NaN).
    # .T presumably turns (frames, tracks) into one list per track — verify.
    "gain": np.nan_to_num(real_gain).T.tolist(),
    "warp": np.nan_to_num(real_warp).T.tolist(),
}

# Create the destination if needed instead of failing on a fresh checkout.
VIZ_STATIC_DIR.mkdir(parents=True, exist_ok=True)

sf.write(str(VIZ_STATIC_DIR / "mix.mp3"), mix.audio, samplerate=FS)
for i, track in enumerate(mix.tracks):
    sf.write(str(VIZ_STATIC_DIR / f"track_{i}.mp3"), track.audio, samplerate=FS)

with open(VIZ_STATIC_DIR / "data.json", "w", encoding="utf-8") as f:
    json.dump(data, f, indent=4)