In [1]:
import numpy as np
import soundfile as sf
import scipy.signal as ss
import librosa
import IPython.display as ipd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from ssspy.bss.mnmf import GaussMNMF as GaussMNMFBase

# -------------------------------
# ✅ Custom GaussMNMF Class (With Progress Bar)
# -------------------------------
class GaussMNMF(GaussMNMFBase):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.progress_bar = None

    def __call__(self, *args, n_iter: int = 200, **kwargs) -> np.ndarray:
        self.n_iter = n_iter
        return super().__call__(*args, n_iter=n_iter, **kwargs)

    def update_once(self) -> None:
        if self.progress_bar is None:
            self.progress_bar = tqdm(total=self.n_iter)
        super().update_once()
        self.progress_bar.update(1)

# -------------------------------
# ✅ Parameters
# -------------------------------
n_sources = 5  # Number of instruments
n_basis = 30   # Number of basis components for MNMF
n_fft = 4096   # Larger FFT for better separation
hop_length = n_fft // 2
window = "hann"
extract_duration = 20  # Extract the last 20 seconds



In [2]:
# -------------------------------
# 1️⃣ Load Mixture Recording (Extract Last 20 sec)
# -------------------------------
mixture_file = "../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2.wav"
mixture_audio, sr = sf.read(mixture_file)
mixture_audio = mixture_audio.T  # Convert to (Channels, Samples)

# ✅ Extract last 20 seconds
num_samples = sr * extract_duration
mixture_audio = mixture_audio[:, -num_samples:]  # Take last `extract_duration` seconds

print(f"✅ Loaded mixture: {mixture_file}, Shape: {mixture_audio.shape}, Sample Rate: {sr}")

# -------------------------------
# 2️⃣ Load Individual Instrument Recordings (Extract Last 20 sec)
# -------------------------------
instrument_files = {
    "basse": "../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_BASSE.wav",
    "percus": "../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_PERCU.wav",
    "piano": "../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_PIANO.wav",
    "voix": "../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_VOIX.wav",
    "saxo": "../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_SAXO.wav"
}

individual_tracks = {}

for instrument, file in instrument_files.items():
    audio, sr_instr = sf.read(file)

    # ✅ Fix shape: Ensure (Channels, Samples)
    if audio.ndim == 1:  
        audio = np.expand_dims(audio, axis=0)  
    elif audio.shape[1] == 2:  
        audio = audio.T  

    # ✅ Extract last 20 seconds
    audio = audio[:, -num_samples:]
    
    individual_tracks[instrument] = audio
    print(f"✅ Loaded {instrument}: {file}, Shape: {audio.shape}")


✅ Loaded mixture: ../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2.wav, Shape: (2, 882000), Sample Rate: 44100
✅ Loaded basse: ../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_BASSE.wav, Shape: (2, 882000)
✅ Loaded percus: ../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_PERCU.wav, Shape: (2, 882000)
✅ Loaded piano: ../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_PIANO.wav, Shape: (2, 882000)
✅ Loaded voix: ../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_VOIX.wav, Shape: (2, 882000)
✅ Loaded saxo: ../data_PAM/real/Fly_me_v2_indiv/Fly_me_v2_SAXO.wav, Shape: (2, 882000)


In [3]:
# -------------------------------
# 3️⃣ Compute STFT (Using Hann Window)
# -------------------------------
print(f"✅ STFT parameters: n_fft={n_fft}, hop_length={hop_length}")

# Compute STFT for mixture
_, _, X_mixture = ss.stft(mixture_audio, window=window, nperseg=n_fft, noverlap=n_fft - hop_length)

# ✅ Fix shape: Swap Freq & Time to match (Channels, Time, Freq)
X_mixture = np.moveaxis(X_mixture, 1, 2)
print(f"✅ X_mixture shape after fix: {X_mixture.shape}")  # Should be (Channels, Time, Freq)

# Compute STFT for each instrument recording
X_individual = {
    instr: np.moveaxis(
        ss.stft(individual_tracks[instr], window=window, nperseg=n_fft, noverlap=n_fft - hop_length)[2],
        1, 2  # ✅ Fix shape to match X_mixture
    )
    for instr in individual_tracks.keys()
}

print("✅ Computed STFTs.")


✅ STFT parameters: n_fft=4096, hop_length=2048
✅ X_mixture shape after fix: (2, 432, 2049)
✅ Computed STFTs.


In [None]:
# -------------------------------
# 4️⃣ Apply Gaussian MNMF with Partitioning
# -------------------------------
try:
    nmf = GaussMNMF(n_basis=n_basis, n_sources=n_sources, partitioning=True, rng=np.random.default_rng(42))
    print("✅ GaussMNMF initialized successfully.")
    
    Y = nmf(X_mixture, reference=X_individual)
    print("✅ Gaussian MNMF separation completed.")
except Exception as e:
    print(f"❌ Error during separation: {e}")
    exit()

✅ GaussMNMF initialized successfully.


  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
# -------------------------------
# 5️⃣ Reconstruct Audio, Save & Listen
# -------------------------------
try:
    _, waveform_est = ss.istft(Y, window=window, nperseg=n_fft, noverlap=n_fft - hop_length)

    for i, (instrument, _) in enumerate(instrument_files.items()):
        output_file = f"separated_{instrument}.wav"
        sf.write(output_file, waveform_est[i], sr)
        print(f"✅ Saved: {output_file}")
        ipd.display(ipd.Audio(waveform_est[i], rate=sr))

    print("✅ Separation complete!")

    # ✅ Plot Loss Curve
    plt.figure()
    plt.plot(nmf.loss[10:])
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.title("GaussMNMF Loss Curve")
    plt.show()
except Exception as e:
    print(f"❌ Error during ISTFT: {e}")
    exit()
