In [None]:
import torchnmf
import torch
import librosa
import numpy as np
import matplotlib.pyplot as plt
from mixes.synthetic import SyntheticDB

In [None]:
def _transform_melspec(input, fs, n_mels, stft_win_func, win_len, hop_len):
    """Return the mel-scaled power spectrogram of ``input``.

    The signal is transformed with a centred STFT whose FFT size and window
    length are both ``win_len`` samples, advanced by ``hop_len`` samples, and
    tapered with ``stft_win_func``. The squared magnitudes are then projected
    onto a mel filterbank built for sample rate ``fs`` with ``n_mels`` bands.

    Args:
        input: audio signal handed to ``librosa.stft``.
        fs: sample rate used to build the mel filterbank.
        n_mels: number of mel bands.
        stft_win_func: window spec forwarded to ``librosa.stft`` (e.g. ``np.hanning``).
        win_len: STFT window length / FFT size, in samples.
        hop_len: STFT hop size, in samples.

    Returns:
        ``mel_basis.dot(|STFT|**2)`` -- for a 1-D signal an array with one row
        per mel band and one column per STFT frame.
    """
    stft_options = dict(
        n_fft=win_len,
        hop_length=hop_len,
        win_length=win_len,
        center=True,
        window=stft_win_func,
    )
    power_spec = np.abs(librosa.stft(input, **stft_options)) ** 2
    mel_basis = librosa.filters.mel(sr=fs, n_fft=win_len, n_mels=n_mels)
    # np.dot (not @) to keep the original contraction semantics for N-D input.
    return np.dot(mel_basis, power_spec)


# --- Analysis configuration ---
FS = 22050  # sample rate (Hz) used for the mel filterbank; assumed to match the SyntheticDB signals -- verify
HSIZE = 0.01  # STFT hop size in seconds
WSIZE = 8 * HSIZE  # STFT window length in seconds (8 hops per window)
NMELS = 256  # number of mel bands
SEED = 0  # seed for the random activation initialisation below

# Seed torch so the random H initialisation (and hence the fit) is reproducible.
torch.manual_seed(SEED)

db = SyntheticDB()
mix = db.get_mix("linear-mix")
inputs = mix.as_activation_learner_input()
specs = [
    _transform_melspec(i, FS, NMELS, np.hanning, int(WSIZE * FS), int(HSIZE * FS))
    for i in inputs
]

# Every input except the last (presumably the template signals) contributes its
# spectrogram frames as columns of the dictionary W: (n_mels, n_templates).
W = torch.tensor(np.concatenate(specs[:-1], axis=1), dtype=torch.float32)
# PLCA interprets each column of W as a distribution over mel bands, and W is kept
# fixed during fitting, so normalise columns to sum to one here. The clamp leaves
# all-zero (silent) columns at zero instead of producing 0/0 = NaN.
W = W / W.sum(dim=0, keepdim=True).clamp_min(torch.finfo(W.dtype).tiny)

# The last input is the mixture to explain: (n_mels, n_frames).
V = torch.tensor(specs[-1], dtype=torch.float32)

# Random non-negative initial activations: one row per template, one column per frame of V.
H = torch.randn(W.shape[1], V.shape[1]).abs()
# One component weight per template (presumably P(z) in PLCA terms -- confirm torchnmf's normalisation expectations).
Z = torch.ones(H.shape[0])

In [None]:
# Fail fast on layout mistakes: templates x activations must reconstruct a
# spectrogram shaped like V, and Z needs exactly one weight per template column of W.
assert (W @ H).shape == V.shape, "W @ H must have the same shape as V"
assert Z.shape == (W.shape[1],), "Z must have one entry per template (column) of W"
print(V.shape, W.shape, H.shape, Z.shape, (W @ H).shape)

In [None]:
# torchnmf factorises a target V of shape (N, C) as H @ W.T (weighted by Z), with
# W shaped (C, R), H shaped (N, R) and Z shaped (R,). The cells above use the
# textbook layout instead: V is (n_mels, frames), W is (n_mels, templates) and
# H is (templates, frames). W therefore already matches torchnmf's (C, R), but H
# and V must be transposed so that N = frames and C = n_mels.
net = torchnmf.plca.PLCA(W=W, H=H.T.contiguous(), Z=Z, trainable_W=False)
net.fit(V.T.contiguous())