In [1]:
import os
import torch
import torchaudio
import sys

notebook_dir = os.getcwd()
sys.path.append(os.path.join(notebook_dir, "../src_torch"))

import soundfile as sf

from separation.FastMNMF2 import FastMNMF2
from Base import MultiSTFT


## FastMNMF2

In [11]:
audio_src_dir = "classic"
audio_src_name = "classic_all_1_4chann.wav"
audio_src_path = os.path.join(notebook_dir, "../", audio_src_dir, audio_src_name)

audio_save_dir = os.path.join(notebook_dir, "..", "result")
if not os.path.exists(audio_save_dir):
    os.makedirs(audio_save_dir)

n_source = 3
n_basis = 32
device = "cuda:1" if torch.cuda.is_available() else "cpu"
init_SCM = "circular"
n_bit = 32
algo = "IP"
n_iter_init = 200
g_eps = 5e-5

n_mic = 4
n_fft = 2048
n_iter = 1000

# load audio
wav, sample_rate = torchaudio.load(audio_src_path, channels_first=False)
wav /= torch.abs(wav).max() * 1.2
M = min(len(wav), n_mic)
spec_FTM = MultiSTFT(wav[:, :M], n_fft=n_fft)

separater = FastMNMF2(
    n_source=n_source,
    n_basis=n_basis,
    device=device,
    init_SCM=init_SCM,
    n_bit=n_bit,
    algo=algo,
    n_iter_init=n_iter_init,
    g_eps=g_eps,
)

separater.file_id = audio_src_path.split("/")[-1].split(".")[0]
separater.load_spectrogram(spec_FTM, sample_rate)
separater.solve(
    n_iter=n_iter,
    save_dir=audio_save_dir,
    save_likelihood=False,
    save_param=False,
    save_wav=True,
    interval_save=5,
)
torch.cuda.empty_cache()

Update FastMNMF2_IP-M=4-S=3-F=1025-K=32-init=circular-g=5e-05-bit=32-intv_norm=10-ID=classic_all_1_4chann  1000 times ...


100%|██████████| 1000/1000 [03:55<00:00,  4.25it/s]


## GAUSSIAN MNMF (Sawada)

Tests

In [31]:
import numpy as np
import soundfile as sf
import scipy.signal as ss
import librosa
import IPython.display as ipd
from ssspy.bss.mnmf import GaussMNMF as GaussMNMFBase
from tqdm import tqdm  # Ensure progress bar works

# -------------------------------
# ✅ Fix: Ensure GaussMNMF works without issues
# -------------------------------
class GaussMNMF(GaussMNMFBase):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.progress_bar = None

    def __call__(self, *args, n_iter: int = 100, **kwargs) -> np.ndarray:
        self.n_iter = n_iter
        return super().__call__(*args, n_iter=n_iter, **kwargs)

    def update_once(self) -> None:
        if self.progress_bar is None:
            self.progress_bar = tqdm(total=self.n_iter)
        super().update_once()
        self.progress_bar.update(1)

# -------------------------------
# ✅ Fix: Ensure STFT Parameters Are Correct
# -------------------------------
n_sources = 5
max_duration = 140
n_fft = 2048
hop_length = n_fft // 2  # Ensure noverlap < nperseg
win_length = 2048

# -------------------------------
# 1️⃣ LOAD MULTICHANNEL MIXTURE RECORDING
# -------------------------------
mixture_file = "../data_PAM/real/Fly_me_v2/MIX_v2.wav"
mixture_audio, sr = sf.read(mixture_file)
mixture_audio = mixture_audio.T  # Convert to (Channels, Samples)

print(f"✅ Loaded mixture: {mixture_file}, Shape: {mixture_audio.shape}, Sample Rate: {sr}")

# -------------------------------
# 2️⃣ LOAD INDIVIDUAL INSTRUMENT RECORDINGS (MICS D’APPOINT)
# -------------------------------
instrument_files = {
    "basse": "../data_PAM/real/Fly_me_v2/Basse_v2.wav",
    "percus": "../data_PAM/real/Fly_me_v2/Percus_v2.wav",
    "piano": "../data_PAM/real/Fly_me_v2/Piano_v2.wav",
    "voix": "../data_PAM/real/Fly_me_v2/Voix_v2.wav",
    "saxo": "../data_PAM/real/Fly_me_v2/Saxo_v2.wav"
}

individual_tracks = {}

for instrument, file in instrument_files.items():
    audio, sr_instr = sf.read(file)

    # ✅ Fix: Ensure shape is (Channels, Samples)
    if audio.ndim == 1:  # Convert mono to (1, Samples)
        audio = np.expand_dims(audio, axis=0)
    elif audio.shape[1] == 2:  # Convert (Samples, 2) → (2, Samples)
        audio = audio.T  

    individual_tracks[instrument] = audio
    print(f"✅ Loaded {instrument}: {file}, Shape: {audio.shape}")

# -------------------------------
# 3️⃣ COMPUTE STFT FOR MIXTURE & SEPARATE RECORDINGS
# -------------------------------
print(f"✅ STFT parameters: n_fft={n_fft}, hop_length={hop_length}")

# Compute STFT for mixture
X_mixture = ss.stft(mixture_audio, nperseg=n_fft, noverlap=hop_length - 1)[2]  # STFT returns (Freq, Time, Channels)

# ✅ Fix: Ensure X_mixture is (Channels, Time, Freq)
X_mixture = np.moveaxis(X_mixture, 2, 0)  # Convert (Freq, Time, Channels) → (Channels, Time, Freq)
print(f"✅ X_mixture shape: {X_mixture.shape}")  # Should be (2, Time, Freq)

# Compute STFT for each instrument recording
X_individual = {
    instr: np.moveaxis(
        ss.stft(individual_tracks[instr], nperseg=n_fft, noverlap=hop_length - 1)[2],
        2, 0  # Convert (Freq, Time, Channels) → (Channels, Time, Freq)
    )
    for instr in individual_tracks.keys()
}

print("✅ Computed STFTs.")

# -------------------------------
# 4️⃣ APPLY GAUSSIAN MNMF WITH SUPERVISION
# -------------------------------
n_sources = len(instrument_files)  # Number of expected sources
n_basis = 30  # Number of basis components for MNMF

# Initialize MNMF
nmf = GaussMNMF(n_basis=n_basis, n_sources=n_sources)

# ✅ Fix: Use correct `fit()` and `transform()` functions
nmf.fit(X_mixture, reference=X_individual)
Y = nmf.transform(X_mixture)

print("✅ Gaussian MNMF separation completed.")

# -------------------------------
# 5️⃣ RECONSTRUCT AUDIO, SAVE & LISTEN
# -------------------------------
sources = np.array([
    ss.istft(Y[src], nperseg=n_fft, noverlap=hop_length - 1)[1]
    for src in range(Y.shape[0])
])

# Save separated sources and play them
for i, (instrument, _) in enumerate(instrument_files.items()):
    output_file = f"separated_{instrument}.wav"
    sf.write(output_file, sources[i], sr)
    print(f"✅ Saved: {output_file}")
    ipd.display(ipd.Audio(sources[i], rate=sr))

print("✅ Separation complete!")


✅ Loaded mixture: ../data_PAM/real/Fly_me_v2/MIX_v2.wav, Shape: (2, 5659324), Sample Rate: 44100
✅ Loaded basse: ../data_PAM/real/Fly_me_v2/Basse_v2.wav, Shape: (2, 5659324)
✅ Loaded percus: ../data_PAM/real/Fly_me_v2/Percus_v2.wav, Shape: (2, 5659324)
✅ Loaded piano: ../data_PAM/real/Fly_me_v2/Piano_v2.wav, Shape: (2, 5659324)
✅ Loaded voix: ../data_PAM/real/Fly_me_v2/Voix_v2.wav, Shape: (2, 5659324)
✅ Loaded saxo: ../data_PAM/real/Fly_me_v2/Saxo_v2.wav, Shape: (2, 5659324)
✅ STFT parameters: n_fft=2048, hop_length=1024
✅ X_mixture shape: (5523, 2, 1025)
✅ Computed STFTs.


AttributeError: 'GaussMNMF' object has no attribute 'fit'

In [32]:
import numpy as np
import soundfile as sf
import scipy.signal as ss
import librosa
import IPython.display as ipd
from ssspy.bss.mnmf import GaussMNMF as GaussMNMFBase
from tqdm import tqdm  # Ensure progress bar works

# -------------------------------
# ✅ Fix: Ensure GaussMNMF works without issues
# -------------------------------
class GaussMNMF(GaussMNMFBase):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.progress_bar = None

    def __call__(self, *args, n_iter: int = 100, **kwargs) -> np.ndarray:
        self.n_iter = n_iter
        return super().__call__(*args, n_iter=n_iter, **kwargs)

    def update_once(self) -> None:
        if self.progress_bar is None:
            self.progress_bar = tqdm(total=self.n_iter)
        super().update_once()
        self.progress_bar.update(1)

# -------------------------------
# ✅ Fix: Ensure STFT Parameters Are Correct
# -------------------------------
n_sources = 5
max_duration = 140
n_fft = 2048
hop_length = n_fft // 2  # Ensure noverlap < nperseg
win_length = 2048

# -------------------------------
# 1️⃣ LOAD MULTICHANNEL MIXTURE RECORDING
# -------------------------------
mixture_file = "../data_PAM/real/Fly_me_v2/MIX_v2.wav"
mixture_audio, sr = sf.read(mixture_file)
mixture_audio = mixture_audio.T  # Convert to (Channels, Samples)

print(f"✅ Loaded mixture: {mixture_file}, Shape: {mixture_audio.shape}, Sample Rate: {sr}")

# -------------------------------
# 2️⃣ LOAD INDIVIDUAL INSTRUMENT RECORDINGS (MICS D’APPOINT)
# -------------------------------
instrument_files = {
    "basse": "../data_PAM/real/Fly_me_v2/Basse_v2.wav",
    "percus": "../data_PAM/real/Fly_me_v2/Percus_v2.wav",
    "piano": "../data_PAM/real/Fly_me_v2/Piano_v2.wav",
    "voix": "../data_PAM/real/Fly_me_v2/Voix_v2.wav",
    "saxo": "../data_PAM/real/Fly_me_v2/Saxo_v2.wav"
}

individual_tracks = {}

for instrument, file in instrument_files.items():
    audio, sr_instr = sf.read(file)

    # ✅ Fix: Ensure shape is (Channels, Samples)
    if audio.ndim == 1:  # Convert mono to (1, Samples)
        audio = np.expand_dims(audio, axis=0)
    elif audio.shape[1] == 2:  # Convert (Samples, 2) → (2, Samples)
        audio = audio.T  

    individual_tracks[instrument] = audio
    print(f"✅ Loaded {instrument}: {file}, Shape: {audio.shape}")

# -------------------------------
# 3️⃣ COMPUTE STFT FOR MIXTURE & SEPARATE RECORDINGS
# -------------------------------
print(f"✅ STFT parameters: n_fft={n_fft}, hop_length={hop_length}")

# Compute STFT for mixture
X_mixture = ss.stft(mixture_audio, nperseg=n_fft, noverlap=hop_length - 1)[2]  # STFT returns (Freq, Time, Channels)

# ✅ Fix: Ensure X_mixture is (Channels, Time, Freq)
X_mixture = np.moveaxis(X_mixture, 2, 0)  # Convert (Freq, Time, Channels) → (Channels, Time, Freq)
print(f"✅ X_mixture shape: {X_mixture.shape}")  # Should be (2, Time, Freq)

# Compute STFT for each instrument recording
X_individual = {
    instr: np.moveaxis(
        ss.stft(individual_tracks[instr], nperseg=n_fft, noverlap=hop_length - 1)[2],
        2, 0  # Convert (Freq, Time, Channels) → (Channels, Time, Freq)
    )
    for instr in individual_tracks.keys()
}

print("✅ Computed STFTs.")

# -------------------------------
# 4️⃣ APPLY GAUSSIAN MNMF WITH SUPERVISION
# -------------------------------
n_sources = len(instrument_files)  # Number of expected sources
n_basis = 30  # Number of basis components for MNMF

# Initialize MNMF
nmf = GaussMNMF(n_basis=n_basis, n_sources=n_sources)

# ✅ Fix: Use __call__() instead of fit()
Y = nmf(X_mixture, reference=X_individual)

print("✅ Gaussian MNMF separation completed.")

# -------------------------------
# 5️⃣ RECONSTRUCT AUDIO, SAVE & LISTEN
# -------------------------------
sources = np.array([
    ss.istft(Y[src], nperseg=n_fft, noverlap=hop_length - 1)[1]
    for src in range(Y.shape[0])
])

# Save separated sources and play them
for i, (instrument, _) in enumerate(instrument_files.items()):
    output_file = f"separated_{instrument}.wav"
    sf.write(output_file, sources[i], sr)
    print(f"✅ Saved: {output_file}")
    ipd.display(ipd.Audio(sources[i], rate=sr))

print("✅ Separation complete!")


✅ Loaded mixture: ../data_PAM/real/Fly_me_v2/MIX_v2.wav, Shape: (2, 5659324), Sample Rate: 44100
✅ Loaded basse: ../data_PAM/real/Fly_me_v2/Basse_v2.wav, Shape: (2, 5659324)
✅ Loaded percus: ../data_PAM/real/Fly_me_v2/Percus_v2.wav, Shape: (2, 5659324)
✅ Loaded piano: ../data_PAM/real/Fly_me_v2/Piano_v2.wav, Shape: (2, 5659324)
✅ Loaded voix: ../data_PAM/real/Fly_me_v2/Voix_v2.wav, Shape: (2, 5659324)
✅ Loaded saxo: ../data_PAM/real/Fly_me_v2/Saxo_v2.wav, Shape: (2, 5659324)
✅ STFT parameters: n_fft=2048, hop_length=1024
✅ X_mixture shape: (5523, 2, 1025)
✅ Computed STFTs.


: 