In [1]:
# Cell 0
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import soundfile as sf
import json
from pathlib import Path
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

  import pynvml  # type: ignore[import]


PyTorch version: 2.10.0.dev20251107+cu128
CUDA available: True


In [2]:
# Cell 1
def stft(signal, n_fft=512, hop_length=256, window='hann'):
    """Simple STFT using torch (no librosa dependency)"""
    win = torch.hann_window(n_fft, device=signal.device)
    signal = F.pad(signal, (n_fft//2, n_fft//2))
    frames = signal.unfold(1, n_fft, hop_length).transpose(0, 1)  # (T, n_fft)
    frames = frames * win
    spec = torch.fft.rfft(frames, dim=-1)  # (T, n_fft//2+1)
    return spec

def istft(spec, hop_length=256, length=None):
    """Inverse STFT"""
    win = torch.hann_window(spec.shape[-2]*2-1, device=spec.device)
    frames = torch.fft.irfft(spec, dim=-1)
    frames = frames * win
    signal = F.fold(
        frames.transpose(0, 1).unsqueeze(1),
        output_size=(length + hop_length, 1),
        kernel_size=(spec.shape[-2]*2-1, 1),
        stride=(hop_length, 1)
    ).squeeze()
    return signal[..., :length] if length else signal

# Example
mix, sr = sf.read("example_mix.wav")  # Replace with real data
mix = torch.from_numpy(mix).float().T  # (M, T)
spec = stft(mix[0], n_fft=512, hop_length=256)  # (F, T)
print("STFT shape (single channel):", spec.shape)

LibsndfileError: Error opening 'example_mix.wav': System error.

In [3]:
# Cell 2
class NBSSDataset(Dataset):
    def __init__(self, json_file, max_len_sec=8.0, sr=16000, n_fft=512, hop=256):
        with open(json_file) as f:
            self.meta = json.load(f)
        self.sr = sr
        self.n_fft = n_fft
        self.hop = hop
        self.max_frames = int(max_len_sec * sr / hop)
        self.ref_mic = 0  # reference mic

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        item = self.meta[idx]
        mix_path = item['mixture']
        s1_path = item['sources'][0]
        s2_path = item['sources'][1]

        # Load waveforms
        mix, _ = sf.read(mix_path)  # (T, M)
        s1, _ = sf.read(s1_path)    # (T,)
        s2, _ = sf.read(s2_path)    # (T,)

        T = mix.shape[0]
        if T > self.max_frames * self.hop:
            start = np.random.randint(0, T - self.max_frames * self.hop)
            end = start + self.max_frames * self.hop
            mix = mix[start:end]
            s1 = s1[start:end]
            s2 = s2[start:end]
        elif T < self.max_frames * self.hop:
            pad = self.max_frames * self.hop - T
            mix = np.pad(mix, ((0, pad), (0, 0)))
            s1 = np.pad(s1, (0, pad))
            s2 = np.pad(s2, (0, pad))

        mix = torch.from_numpy(mix).float()  # (T, M)
        s1 = torch.from_numpy(s1).float()
        s2 = torch.from_numpy(s2).float()

        # STFT
        X = torch.stft(mix, n_fft=self.n_fft, hop_length=self.hop,
                       window=torch.hann_window(self.n_fft), return_complex=True)
        # X: (M, F, T)
        X = X.permute(1, 2, 0)  # (F, T, M)

        Y1 = torch.stft(s1.unsqueeze(1), n_fft=self.n_fft, hop_length=self.hop,
                        window=torch.hann_window(self.n_fft), return_complex=True)[:, self.ref_mic]
        Y2 = torch.stft(s2.unsqueeze(1), n_fft=self.n_fft, hop_length=self.hop,
                        window=torch.hann_window(self.n_fft), return_complex=True)[:, self.ref_mic]
        Y = torch.stack([Y1, Y2], dim=0)  # (2, F, T)

        # Normalize per frequency using reference mic magnitude
        X_ref = X[:, :, self.ref_mic].abs().mean(dim=1, keepdim=True)  # (F, 1)
        X_norm = X / (X_ref + 1e-8)

        # Convert to real-valued
        X_real = torch.cat([X_norm.real, X_norm.imag], dim=-1)  # (F, T, 2M)
        Y_real = torch.cat([Y.real, Y.imag], dim=-1)            # (2, F, T)

        return X_real, Y_real, X_ref.squeeze(-1)  # for denormalization

In [4]:
# Cell 3
class NarrowBandSeparator(nn.Module):
    def __init__(self, input_dim=16, hidden1=256, hidden2=128):  # 2M=16 for M=8
        super().__init__()
        self.bilstm1 = nn.LSTM(input_dim, hidden1, batch_first=True, bidirectional=True)
        self.bilstm2 = nn.LSTM(hidden1*2, hidden2, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden2*2, 4)  # 2 sources × (real + imag)

    def forward(self, x):
        # x: (B, T, 2M)
        x, _ = self.bilstm1(x)
        x, _ = self.bilstm2(x)
        x = self.fc(x)  # (B, T, 4)
        return x.view(-1, 2, 2, x.size(1))  # (B, 2, 2, T) → real/imag

In [5]:
# Cell 4
def complex_from_real_imag(real_imag):
    # real_imag: (B, 2, 2, T) → (B, 2, T) complex
    return torch.complex(real_imag[:, :, 0], real_imag[:, :, 1])

def si_sdr_loss(est, ref):
    # est, ref: (B, T)
    ref = ref - ref.mean(dim=1, keepdim=True)
    est = est - est.mean(dim=1, keepdim=True)
    alpha = (est * ref).sum(dim=1, keepdim=True) / (ref.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    s_target = alpha * ref
    e_noise = est - s_target
    return -10 * torch.log10(
        s_target.pow(2).sum(dim=1) / (e_noise.pow(2).sum(dim=1) + 1e-8) + 1e-8
    ).mean()

def full_band_pit_loss(pred_complex, target_complex, X_ref):
    """
    pred_complex: (B, N, F, T)
    target_complex: (B, N, F, T)
    X_ref: (B, F) — normalization factors
    """
    B, N, F, T = pred_complex.shape

    # Denormalize
    pred = pred_complex * X_ref.view(B, 1, F, 1)
    target = target_complex

    # Reconstruct waveforms
    pred_wav = torch.zeros(B, N, T * 256, device=pred.device)
    target_wav = torch.zeros(B, N, T * 256, device=target.device)

    for b in range(B):
        for n in range(N):
            pred_wav[b,n] = istft(pred[b,n], hop_length=256, length=T*256)
            target_wav[b,n] = istft(target[b,n], hop_length=256, length=T*256)

    # PIT over utterances
    loss_perm1 = si_sdr_loss(pred_wav[:,0], target_wav[:,0]) + si_sdr_loss(pred_wav[:,1], target_wav[:,1])
    loss_perm2 = si_sdr_loss(pred_wav[:,0], target_wav[:,1]) + si_sdr_loss(pred_wav[:,1], target_wav[:,0])
    return torch.min(loss_perm1, loss_perm2)

In [6]:
# Cell 5
class NBSSTrainer:
    def __init__(self, model, train_loader, val_loader, device='cuda'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.optimizer = optim.Adam(model.parameters(), lr=1e-3)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=5)

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        for X_real, Y_real, X_ref in self.train_loader:
            X_real = X_real.to(self.device)  # (B, F, T, 2M)
            Y_real = Y_real.to(self.device)  # (B, 2, F, T)
            X_ref = X_ref.to(self.device)    # (B, F)

            B, F, T, _ = X_real.shape
            loss = 0
            for f in range(F):
                x_f = X_real[:, f]  # (B, T, 2M)
                y_f = Y_real[:, :, f]  # (B, 2, T)

                pred_f = self.model(x_f)  # (B, 2, 2, T)
                pred_complex = complex_from_real_imag(pred_f)  # (B, 2, T)
                target_complex = torch.complex(y_f[:, 0], y_f[:, 1])  # (B, 2, T)

                # Reshape for fPIT
                pred_complex = pred_complex.unsqueeze(2)  # (B, 2, 1, T)
                target_complex = target_complex.unsqueeze(2)

                if f == 0:
                    pred_all = pred_complex
                    target_all = target_complex
                else:
                    pred_all = torch.cat([pred_all, pred_complex], dim=2)
                    target_all = torch.cat([target_all, target_complex], dim=2)

            # Full-band PIT
            batch_loss = full_band_pit_loss(pred_all, target_all, X_ref)
            loss += batch_loss

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            self.optimizer.step()
            total_loss += loss.item()

        return total_loss / len(self.train_loader)

    def validate(self):
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for X_real, Y_real, X_ref in self.val_loader:
                X_real = X_real.to(self.device)
                Y_real = Y_real.to(self.device)
                X_ref = X_ref.to(self.device)

                B, F, T, _ = X_real.shape
                pred_all = None
                target_all = None
                for f in range(F):
                    x_f = X_real[:, f]
                    y_f = Y_real[:, :, f]
                    pred_f = self.model(x_f)
                    pred_c = complex_from_real_imag(pred_f)
                    targ_c = torch.complex(y_f[:, 0], y_f[:, 1])
                    pred_c = pred_c.unsqueeze(2)
                    targ_c = targ_c.unsqueeze(2)
                    pred_all = pred_c if pred_all is None else torch.cat([pred_all, pred_c], dim=2)
                    target_all = targ_c if target_all is None else torch.cat([target_all, targ_c], dim=2)

                loss = full_band_pit_loss(pred_all, target_all, X_ref)
                total_loss += loss.item()

        return total_loss / len(self.val_loader)

In [None]:
# Cell 6
# Dummy data for demo (replace with real WSJ0-2mix spatialized)
train_meta = [{"mixture": "mix1.wav", "sources": ["s1.wav", "s2.wav"]}] * 100
val_meta = train_meta[:20]

train_set = NBSSDataset(train_meta, max_len_sec=4.0)
val_set = NBSSDataset(val_meta, max_len_sec=4.0)

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4)

model = NarrowBandSeparator(input_dim=16)  # 8 mics → 2*8=16
trainer = NBSSTrainer(model, train_loader, val_loader, device='cpu')

for epoch in range(5):
    train_loss = trainer.train_epoch()
    val_loss = trainer.validate()
    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")