In [None]:
import pytorch_nmf
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Penalties visualization

In [None]:
import activation_learner, plot
import librosa
import itertools

FS = 22050
# input_paths = ["linear-mix-1.wav", "linear-mix-2.wav", "linear-mix.wav"]
input_paths = ["original.wav", "boucled.wav"]
inputs = [librosa.load(path, sr=FS)[0] for path in input_paths]

OVERLAP_FACTOR = 4
HOP_SIZE = 0.5
BETA = 0
NMELS = 64
# stop conditions
ITER_MAX = 1000
# logging
LOSS_EVERY = 10  # record the loss components every LOSS_EVERY iterations
DIVERGENCE = pytorch_nmf.ItakuraSaito()
PENALTIES = [
    # pytorch_nmf.L1(),
    # pytorch_nmf.L2(),
    # pytorch_nmf.SmoothGain(),
    # pytorch_nmf.SmoothDiago(),
    # pytorch_nmf.WarpReguA(),
    # pytorch_nmf.ColumnClusterPromotion(),
    pytorch_nmf.Lineness(),
]
LAMBDAS = [0, 1e3, 1e4, 1e5, 1e6]


def _imshow_with_colorbar(fig, ax, data, title):
    """Draw `data` on `ax` as an image, attach a colorbar to it and set the title."""
    image = ax.imshow(data, origin="lower", cmap="turbo", aspect="auto")
    fig.colorbar(image, ax=ax)
    ax.set_title(title)


# One figure per penalty, one column per lambda. Rows:
#   0: learned activations H, 1: penalty.grad_pos(H), 2: penalty.grad_neg(H),
#   3: loss history.
for penalty in PENALTIES:
    # squeeze=False keeps `axs` 2-D even when LAMBDAS has a single entry,
    # so the `axs[row, p]` indexing below always works.
    fig, axs = plt.subplots(4, len(LAMBDAS), sharey="row", squeeze=False)
    fig.set_size_inches(20, 10)
    fig.suptitle(f"{penalty.__class__.__name__}")

    for p, lambda_ in enumerate(LAMBDAS):
        learner = activation_learner.ActivationLearner(
            inputs,
            fs=FS,
            n_mels=NMELS,
            win_size=HOP_SIZE * OVERLAP_FACTOR,
            hop_size=HOP_SIZE,
            divergence=DIVERGENCE,
            penalties=[(penalty, lambda_)],
        )
        losses = []
        for i in tqdm(range(ITER_MAX)):
            try:
                learner.iterate()
            except Exception:
                # Report where the run failed, then propagate the original error.
                # (Catching Exception, not everything, so a kernel interrupt is
                # not misreported as an error.)
                print(f"Stopped at iter {i} due to error")
                raise
            if i % LOSS_EVERY == 0:
                _, loss_comp = learner.loss()
                losses.append(loss_comp)

        # The in-loop loss is only sampled every LOSS_EVERY iterations, so it can
        # be stale; evaluate it once more on the final state before reporting.
        loss, _ = learner.loss()
        print(f"Stopped at NMF iteration={i} loss={loss}")

        H = learner.nmf.H
        plot.plot_H(H.detach().numpy(), learner.split_idx, ax=axs[0, p])
        # NOTE(review): assumes plot.plot_H draws an image on the axis — confirm.
        fig.colorbar(axs[0, p].images[0], ax=axs[0, p])
        axs[0, p].set_title(f"$\\lambda$ = {lambda_:.2e}")
        _imshow_with_colorbar(fig, axs[1, p], penalty.grad_pos(H).detach().numpy(), "grad pos")
        _imshow_with_colorbar(fig, axs[2, p], penalty.grad_neg(H).detach().numpy(), "grad neg")
        plot.plot_loss_history(losses, ax=axs[3, p])
    plt.tight_layout()
    plt.show()
    # Release the figure so the per-penalty loop does not accumulate memory.
    plt.close(fig)