In [1]:
import os
import torch
import torchaudio
import sys

# The kernel's working directory is used as the anchor for every relative path
# in this notebook.
notebook_dir = os.getcwd()

# Extend the module search path with ../src_torch before the project-local
# imports that follow (presumably where `separation` and `Base` live — confirm).
src_torch_dir = os.path.join(notebook_dir, "../src_torch")
sys.path.append(src_torch_dir)

import soundfile as sf

from separation.FastMNMF2 import FastMNMF2
from Base import MultiSTFT


## FastMNMF2

In [11]:
# --- Input / output locations ------------------------------------------------
audio_src_dir = "classic"
audio_src_name = "classic_all_1_4chann.wav"
audio_src_path = os.path.join(notebook_dir, "../", audio_src_dir, audio_src_name)

# Results are written to ../result; exist_ok makes re-running the cell safe
# without a separate existence check.
audio_save_dir = os.path.join(notebook_dir, "..", "result")
os.makedirs(audio_save_dir, exist_ok=True)

# --- FastMNMF2 configuration -------------------------------------------------
n_source = 3
n_basis = 32
# Keep the original preference for the second GPU, but fall back to the first
# one on single-GPU machines (where "cuda:1" is an invalid device) and to the
# CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
init_SCM = "circular"
n_bit = 32
algo = "IP"
n_iter_init = 200
g_eps = 5e-5

n_mic = 4
n_fft = 2048
n_iter = 1000

# --- Load audio --------------------------------------------------------------
# channels_first=False makes torchaudio return a (time, channel) tensor.
wav, sample_rate = torchaudio.load(audio_src_path, channels_first=False)

# Scale so the peak magnitude becomes 1/1.2. A completely silent recording has
# a zero peak; skip the division there instead of filling the tensor with NaN.
peak = torch.abs(wav).max()
if peak > 0:
    wav /= peak * 1.2

# Number of microphones actually used: the channel axis is dim 1 because of
# channels_first=False (len(wav) would be the number of samples, not channels).
M = min(wav.shape[1], n_mic)
spec_FTM = MultiSTFT(wav[:, :M], n_fft=n_fft)

# --- Separate ----------------------------------------------------------------
separater = FastMNMF2(
    n_source=n_source,
    n_basis=n_basis,
    device=device,
    init_SCM=init_SCM,
    n_bit=n_bit,
    algo=algo,
    n_iter_init=n_iter_init,
    g_eps=g_eps,
)

# Tag the run with the source file's base name (no directory, no extension),
# derived with os.path so it does not depend on "/" being the path separator.
separater.file_id = os.path.splitext(os.path.basename(audio_src_path))[0]
separater.load_spectrogram(spec_FTM, sample_rate)
separater.solve(
    n_iter=n_iter,
    save_dir=audio_save_dir,
    save_likelihood=False,
    save_param=False,
    save_wav=True,
    interval_save=5,
)

# Release cached GPU memory held by PyTorch's allocator once the run is done.
torch.cuda.empty_cache()

Update FastMNMF2_IP-M=4-S=3-F=1025-K=32-init=circular-g=5e-05-bit=32-intv_norm=10-ID=classic_all_1_4chann  1000 times ...


100%|██████████| 1000/1000 [03:55<00:00,  4.25it/s]


## GAUSSIAN MNMF
### Sawada

In [None]:
import soundfile as sf
import numpy as np
import librosa

# Define your dataset paths
wav_file = "your_multichannel_recording.wav"  # Replace with your main mix
instrument_files = {
    "violin": "violin_isolated.wav",
    "cello": "cello_isolated.wav",
    "piano": "piano_isolated.wav",
}  # For classical, if available

# Load the recording as (N_samples, N_channels).
# always_2d=True keeps that 2-D layout even for a mono file; without it
# sf.read returns a 1-D array and the transpose below would silently do
# nothing, leaving the data in the wrong shape for the steps that follow.
multi_channel_audio, sr = sf.read(wav_file, always_2d=True)

# Transpose to (Channels, Samples) for MNMF
multi_channel_audio = multi_channel_audio.T

print(f"Loaded {wav_file} with shape: {multi_channel_audio.shape}, Sample Rate: {sr}")
