In [None]:
import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.auto import tqdm

import pytorch_nmf

In [None]:
def _transform_melspec(input, fs, n_mels, stft_win_func, win_len, hop_len):
    """Return the mel-scaled power spectrogram of an audio signal.

    Args:
        input: audio samples handed straight to ``librosa.stft``.
            (The name shadows the builtin ``input``; it is kept so keyword
            callers keep working.)
        fs: sample rate in Hz, used to build the mel filterbank.
        n_mels: number of mel bands, i.e. rows of the result.
        stft_win_func: window passed to ``librosa.stft`` (e.g. ``np.hanning``).
        win_len: window length in samples; also used as the FFT size.
        hop_len: hop between successive STFT frames in samples.

    Returns:
        The mel filterbank applied to ``|STFT|**2``. For 1-D (mono) input this
        is an ``n_mels x n_frames`` array.
    """
    # NOTE(review): np.hanning builds a *symmetric* window, whereas librosa's
    # default "hann" is periodic (better suited to STFT analysis) — confirm
    # which one is intended before changing it.
    stft_frames = librosa.stft(
        input,
        window=stft_win_func,
        n_fft=win_len,
        win_length=win_len,
        hop_length=hop_len,
        center=True,
    )
    power_spectrum = np.abs(stft_frames) ** 2
    filterbank = librosa.filters.mel(sr=fs, n_fft=win_len, n_mels=n_mels)
    return filterbank.dot(power_spectrum)


# --- Audio / STFT configuration -------------------------------------------
FS = 22050          # sample rate (Hz) that every file is resampled to on load
HSIZE = 0.01        # hop size in seconds
WSIZE = 8 * HSIZE   # window (and FFT) size in seconds
NMELS = 256         # number of mel bands
SEED = 0            # seeds the random initialisation of H so reruns match

# Sample counts derived once. round() instead of int(): int() truncates, so a
# float product that should be an integer but lands just below it would lose a
# sample. For the current values both give win=1764, hop=220.
WIN_LEN = round(WSIZE * FS)
HOP_LEN = round(HSIZE * FS)

# --- Load audio and compute mel spectrograms --------------------------------
# The last path is the mixture to decompose; the others provide the dictionary.
input_paths = ["linear-mix-1.wav", "linear-mix-2.wav", "linear-mix.wav"]
inputs = [librosa.load(path, sr=FS)[0] for path in input_paths]
specs = [
    _transform_melspec(signal, FS, NMELS, np.hanning, WIN_LEN, HOP_LEN)
    for signal in inputs
]

# Dictionary W: mel frames of every recording except the last, side by side
# along the frame axis. Target V: mel spectrogram of the last recording.
# torch.tensor copies (like the legacy torch.Tensor did), so later in-place
# ops on W / V cannot silently modify `specs`.
W = torch.tensor(np.concatenate(specs[:-1], axis=1), dtype=torch.float32)
V = torch.tensor(specs[-1], dtype=torch.float32)
# TODO: row-normalising W and V (dividing each row by its sum) was tried
# earlier; re-introduce deliberately if needed.

# Non-negative (half-normal) random initialisation of the activations.
torch.manual_seed(SEED)
H = torch.randn((W.shape[1], V.shape[1])).abs()

In [None]:
# Sanity-check the factorisation shapes without materialising W @ H:
# V (n_mels x mixture frames) should match W (n_mels x atoms) @ H (atoms x mixture frames).
assert W.shape[1] == H.shape[0], "W columns must match H rows"
assert V.shape == (W.shape[0], H.shape[1]), "W @ H must have the same shape as V"
print(V.shape, W.shape, H.shape, torch.Size((W.shape[0], H.shape[1])))

In [None]:
# --- NMF settings -----------------------------------------------------------
N_ITER = 1000     # number of model.iterate() calls
L1_WEIGHT = 1     # weight paired with the L1 penalty below

DIVERGENCE = pytorch_nmf.ItakuraSaito()
# NOTE(review): the Itakura-Saito divergence is undefined where V or W @ H is
# zero (e.g. digitally silent frames). Confirm pytorch_nmf.ItakuraSaito guards
# against this, or floor V / W with a small epsilon.
PENALTIES = [
    # presumably (penalty, weight) pairs applied to the activations — confirm
    (pytorch_nmf.L1(), L1_WEIGHT),
]

# trainable_W=False: the dictionary W built from the source recordings is
# presumably kept fixed, so only the activations are fitted.
# NOTE(review): V and H are passed transposed but W is not — confirm this
# matches the orientation pytorch_nmf.NMF expects.
model = pytorch_nmf.NMF(V.T, W, H.T, DIVERGENCE, PENALTIES, trainable_W=False)
for _ in tqdm(range(N_ITER), desc="NMF"):
    model.iterate()

# --- Plot learned activations -----------------------------------------------
# Rows: dictionary atoms, i.e. columns of W = frames of the source recordings
# in concatenation order (assumes model.Ht has the orientation of H.T).
# Columns: frames of the mixture.
activations = model.Ht.T.detach().cpu().numpy()
fig, ax = plt.subplots()
img = ax.imshow(activations, aspect="auto", origin="lower")
# Dashed guides mark where one source recording's frames end and the next begin.
for boundary in np.cumsum([s.shape[1] for s in specs[:-1]])[:-1]:
    ax.axhline(boundary - 0.5, color="white", linewidth=0.8, linestyle="--")
ax.set(
    title="Learned activations H (Itakura-Saito NMF, L1 penalty)",
    xlabel="mixture frame",
    ylabel="dictionary atom (source frame)",
)
fig.colorbar(img, ax=ax, label="activation")
plt.show()