In [None]:
import modular_nmf
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Penalties visualization

In [None]:
import activation_learner, plot
import librosa
import itertools

# Sample rate used both to load the audio and by the activation learner.
FS = 22050
# Input recordings; librosa.load returns (signal, sr) and only the signal is kept.
input_paths = ["linear-mix-1.wav", "linear-mix-2.wav", "linear-mix.wav"]
inputs = [librosa.load(path, sr=FS)[0] for path in input_paths]

# Framing: the learner's window is HOP_SIZE * OVERLAP_FACTOR.
OVERLAP_FACTOR = 4
HOP_SIZE = 1  # NOTE(review): units are defined by ActivationLearner (samples vs. seconds?) — confirm
# Parameter passed to modular_nmf.BetaDivergence for the data-fit term.
BETA = 0
NMELS = 256  # number of mel bands passed to the learner as n_mels
# stop conditions
ITER_MAX = 500  # iteration budget for each (penalty, lambda) run
# data-fit term; built from BETA so the two constants cannot drift apart
DIVERGENCE = modular_nmf.BetaDivergence(BETA)
# Penalties to compare; each one is paired with every weight in LAMBDAS.
PENALTIES = [
    modular_nmf.L1(),
    modular_nmf.L2(),
    modular_nmf.SmoothGain(),
    modular_nmf.SmoothDiago(),
    modular_nmf.VirtanenTemporalContinuity(),
]
# Regularization weights swept for each penalty (one subplot column each).
LAMBDAS = [1, 10, 100, 1e4, 1e6]
def _fit_activations(penalty, lambda_):
    """Run exactly ITER_MAX NMF iterations with a single weighted penalty.

    Args:
        penalty: one of the modular_nmf penalty instances from PENALTIES.
        lambda_: weight applied to ``penalty``.

    Returns:
        (learner, losses): the ActivationLearner after fitting, and the list of
        per-iteration ``loss_comp`` values returned by ``learner.iterate``.

    Raises:
        Whatever ``learner.iterate`` raises; the failing iteration index is
        printed before the exception propagates.
    """
    learner = activation_learner.ActivationLearner(
        inputs,
        fs=FS,
        n_mels=NMELS,
        win_size=HOP_SIZE * OVERLAP_FACTOR,
        hop_size=HOP_SIZE,
        divergence=DIVERGENCE,
        penalties=[(penalty, lambda_)],
        postprocessors=[],
    )
    losses = []
    loss = None  # stays None only if ITER_MAX <= 0
    for i in range(ITER_MAX):
        try:
            # NOTE(review): the meaning of the positional 0 is defined by
            # ActivationLearner.iterate — confirm and give it a named constant.
            loss, loss_comp = learner.iterate(0)
        except Exception:
            # Report where the fit failed, then propagate. Catch `Exception`
            # rather than using a bare except so that Ctrl-C is not reported
            # as a learner error.
            print(f"Stopped at iter {i} due to error")
            raise
        losses.append(loss_comp)
    print(f"Stopped at NMF iteration={len(losses)} loss={loss}")
    return learner, losses


for penalty in PENALTIES:
    # squeeze=False keeps `axs` 2-D even when LAMBDAS has a single entry,
    # so the `axs[row, col]` indexing below always works.
    fig, axs = plt.subplots(2, len(LAMBDAS), squeeze=False)
    fig.set_size_inches(20, 10)
    fig.suptitle(type(penalty).__name__)

    # One column per lambda: activations H on top, loss history below.
    # NOTE(review): the same penalty object is reused for every lambda; if
    # penalties keep state between fits, copy/reset it per run — TODO confirm.
    for col, lambda_ in enumerate(LAMBDAS):
        learner, losses = _fit_activations(penalty, lambda_)
        plot.plot_H(learner.H, learner.split_idx, ax=axs[0, col])
        axs[0, col].set_title(f"{lambda_:.2e}")
        plot.plot_loss_history(losses, ax=axs[1, col])
    plt.show()