# Slakh Audio Source Separation

This notebook demonstrates working with the Slakh dataset:
1. Define utility classes for working with tracks
2. Select a random track and verify mix correlation
3. Find and play the track with the fewest stems
4. **Oracle mask computation (performance upper bounds)**
5. **Evaluate oracle performance**

## 1. Imports and Utility Classes

In [52]:
import soundfile as sf
import sounddevice as sd
import yaml
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Union
from dataclasses import dataclass
import random
import matplotlib.pyplot as plt
from scipy import signal

In [53]:
@dataclass
class StemInfo:
    """Information about a single instrument stem."""
    stem_id: str
    inst_class: str
    midi_program_name: str
    program_num: int
    is_drum: bool
    integrated_loudness: Optional[float]
    plugin_name: str
    audio_rendered: bool
    midi_saved: bool

In [54]:
class SlakhTrack:
    """
    A simple tool to work with Slakh dataset tracks.
    
    Usage:
        track = SlakhTrack('/path/to/track/directory')
        info = track.get_track_info()
        drums = track.get_stems_by_class('Drums')
    """
    
    def __init__(self, track_dir: Union[str, Path]):
        self.track_dir = Path(track_dir)
        self.metadata_path = self.track_dir / 'metadata.yaml'
        
        if not self.metadata_path.exists():
            raise FileNotFoundError(f"metadata.yaml not found in {track_dir}")
        
        with open(self.metadata_path, 'r') as f:
            self.metadata = yaml.safe_load(f)
    
    def get_track_info(self) -> Dict:
        """Get general track information."""
        return {
            'uuid': self.metadata.get('UUID'),
            'normalized': self.metadata.get('normalized'),
            'overall_gain': self.metadata.get('overall_gain'),
            'normalization_factor': self.metadata.get('normalization_factor'),
            'target_peak': self.metadata.get('target_peak'),
            'num_stems': len(self.metadata.get('stems', {}))
        }
    
    def get_stem_info(self, stem_id: str) -> Optional[StemInfo]:
        stems = self.metadata.get('stems', {})
        if stem_id not in stems:
            return None
        
        stem_data = stems[stem_id]
        return StemInfo(
            stem_id=stem_id,
            inst_class=stem_data.get('inst_class'),
            midi_program_name=stem_data.get('midi_program_name'),
            program_num=stem_data.get('program_num'),
            is_drum=stem_data.get('is_drum', False),
            integrated_loudness=stem_data.get('integrated_loudness'),
            plugin_name=stem_data.get('plugin_name'),
            audio_rendered=stem_data.get('audio_rendered', False),
            midi_saved=stem_data.get('midi_saved', False)
        )
    
    def get_all_stems(self) -> List[StemInfo]:
        stems = self.metadata.get('stems', {})
        result = []
        for stem_id in stems.keys():
            stem_info = self.get_stem_info(stem_id)
            if stem_info is not None:
                result.append(stem_info)
        return result
    
    def get_stems_by_class(self, inst_class: str) -> List[StemInfo]:
        return [stem for stem in self.get_all_stems() if stem.inst_class == inst_class]
    
    def get_drum_stems(self) -> List[StemInfo]:
        return [stem for stem in self.get_all_stems() if stem.is_drum]
    
    def get_stem_paths(self, stem_id: str) -> Dict[str, Path]:
        return {
            'audio': self.track_dir / 'stems' / f'{stem_id}.wav',
            'midi': self.track_dir / 'MIDI' / f'{stem_id}.mid'
        }
    
    def get_mix_path(self) -> Path:
        return self.track_dir / 'mix.wav'
    
    def list_available_instruments(self) -> List[str]:
        return list(set(stem.inst_class for stem in self.get_all_stems()))
    
    def __repr__(self) -> str:
        info = self.get_track_info()
        return f"SlakhTrack(uuid='{info['uuid']}', stems={info['num_stems']})"

## 2. Random Track Selection and Mix Correlation Verification

Select a random track and verify that the correlation between the reconstructed mix (sum of stems) matches the official mix.

In [55]:
# Get all available tracks
dataset_root = Path("./data/babyslakh_16k")
all_tracks = sorted([d for d in dataset_root.iterdir() if d.is_dir() and d.name.startswith('Track')])

print(f"Found {len(all_tracks)} tracks in the dataset")

# Select a random track
random_track_dir = random.choice(all_tracks)
print(f"\nRandomly selected: {random_track_dir.name}")

# Load the track using our utility class
track = SlakhTrack(random_track_dir)
print(f"Track info: {track}")

# Display instruments in this track
instruments = [stem.inst_class for stem in track.get_all_stems()]
print(f"Instruments: {instruments}")

Found 20 tracks in the dataset

Randomly selected: Track00018
Track info: SlakhTrack(uuid='3cccac3e7cd68df8f01a1a61e6ef3172', stems=18)
Instruments: ['Synth Lead', 'Synth Lead', 'Organ', 'Strings (continued)', 'Organ', 'Strings (continued)', 'Strings (continued)', 'Piano', 'Piano', 'Piano', 'Synth Pad', 'Guitar', 'Sound Effects', 'Sound Effects', 'Guitar', 'Bass', 'Bass', 'Drums']


## 3. Find Track with Fewest Stems

In [56]:
# Count stems in each track
track_stem_counts = []
for track_dir in all_tracks:
    try:
        t = SlakhTrack(track_dir)
        num_stems = t.get_track_info()['num_stems']
        track_stem_counts.append((track_dir, num_stems))
    except Exception as e:
        print(f"Warning: Could not load {track_dir.name}: {e}")

# Sort by number of stems (ascending)
track_stem_counts.sort(key=lambda x: x[1])

print("Tracks sorted by number of stems:")
for track_dir, num_stems in track_stem_counts[:10]:  # Show first 10
    print(f"  {track_dir.name}: {num_stems} stems")

# Select the track with fewest stems
minimal_track_dir, minimal_stem_count = track_stem_counts[1]
print(f"\n{'='*50}")
print(f"Track with fewest stems: {minimal_track_dir.name}")
print(f"Number of stems: {minimal_stem_count}")
print(f"{'='*50}")

Tracks sorted by number of stems:
  Track00008: 7 stems
  Track00010: 7 stems
  Track00015: 7 stems
  Track00003: 9 stems
  Track00004: 9 stems
  Track00011: 9 stems
  Track00013: 9 stems
  Track00007: 10 stems
  Track00020: 10 stems
  Track00001: 11 stems

Track with fewest stems: Track00010
Number of stems: 7


In [57]:
# Load and display information about this track
minimal_track = SlakhTrack(minimal_track_dir)
print(f"\nTrack information:")
print(minimal_track)

print(f"\nInstruments in this track:")
for stem in minimal_track.get_all_stems():
    print(f"  {stem.stem_id}: {stem.inst_class} - {stem.midi_program_name}")


Track information:
SlakhTrack(uuid='763348b79ed855fa294a7c980392178b', stems=7)

Instruments in this track:
  S00: Bass - Electric Bass (finger)
  S01: Guitar - Acoustic Guitar (steel)
  S02: Guitar - Acoustic Guitar (steel)
  S03: Synth Pad - Pad 7 (halo)
  S04: Synth Pad - Pad 7 (halo)
  S05: Piano - Honky-tonk Piano
  S06: Drums - Drums


## 4. Oracle Mask Implementation (Performance Upper Bounds)

Oracle masks use ground-truth stem information to establish theoretical performance ceilings.

### 4.1 STFT Configuration

In [58]:
# STFT parameters (optimized for 16kHz audio)
WINDOW_SIZE = 1024  # 64ms @ 16kHz
HOP_LENGTH = 256    # 75% overlap
FFT_SIZE = 1024     # 513 frequency bins up to 8kHz
WINDOW = 'hann'     # Hann window for smooth reconstruction

# Load minimal track mix for testing
mix_audio, sr = sf.read(minimal_track.get_mix_path())
trim_duration = 10  # seconds
trim_samples = int(sr * trim_duration)
mix_audio = mix_audio[:trim_samples]

print("STFT Configuration:")
print(f"  Window size: {WINDOW_SIZE} samples ({WINDOW_SIZE/sr*1000:.1f} ms)")
print(f"  Hop length: {HOP_LENGTH} samples ({HOP_LENGTH/sr*1000:.1f} ms)")
print(f"  FFT size: {FFT_SIZE}")
print(f"  Frequency bins: {FFT_SIZE//2 + 1} (0 to {sr/2:.0f} Hz)")
print(f"  Overlap: {(1 - HOP_LENGTH/WINDOW_SIZE)*100:.0f}%")
print(f"\nMix: {len(mix_audio) / sr:.2f} seconds @ {sr} Hz")

STFT Configuration:
  Window size: 1024 samples (64.0 ms)
  Hop length: 256 samples (16.0 ms)
  FFT size: 1024
  Frequency bins: 513 (0 to 8000 Hz)
  Overlap: 75%

Mix: 10.00 seconds @ 16000 Hz


### 4.2 STFT/ISTFT Functions

In [59]:
def compute_stft(audio, sr, window_size=WINDOW_SIZE, hop_length=HOP_LENGTH, fft_size=FFT_SIZE):
    """Compute STFT of audio signal."""
    f, t, Zxx = signal.stft(
        audio, fs=sr, window=WINDOW,
        nperseg=window_size,
        noverlap=window_size - hop_length,
        nfft=fft_size
    )
    return f, t, Zxx

def compute_istft(Zxx, sr, window_size=WINDOW_SIZE, hop_length=HOP_LENGTH, fft_size=FFT_SIZE):
    """Compute inverse STFT to reconstruct audio."""
    _, audio = signal.istft(
        Zxx, fs=sr, window=WINDOW,
        nperseg=window_size,
        noverlap=window_size - hop_length,
        nfft=fft_size
    )
    return audio

# Test on mix
f, t, mix_stft = compute_stft(mix_audio, sr)
print(f"✓ STFT computed: {mix_stft.shape} (freq x time)")
print(f"  Frequency range: {f[0]:.0f} - {f[-1]:.0f} Hz")
print(f"  Time frames: {len(t)}")
print(f"  Duration: {t[-1]:.2f} seconds")

✓ STFT computed: (513, 626) (freq x time)
  Frequency range: 0 - 8000 Hz
  Time frames: 626
  Duration: 10.00 seconds


### 4.3 Oracle Mask Functions

In [65]:
def ideal_binary_mask(stem_stfts):
    """
    Ideal Binary Mask (IBM): Winner-take-all per T-F bin.
    Assigns 1 to stem with largest magnitude, 0 to others.
    """
    magnitudes = np.array([np.abs(stft) for stft in stem_stfts])
    max_indices = np.argmax(magnitudes, axis=0)
    
    masks = []
    for i in range(len(stem_stfts)):
        mask = (max_indices == i).astype(float)
        masks.append(mask)
    
    return masks

def ideal_ratio_mask(stem_stfts, epsilon=1e-10):
    """
    Ideal Ratio Mask (IRM): Soft mask based on magnitude ratios.
    IRM_i = |S_i| / sum_j(|S_j|)
    """
    magnitudes = np.array([np.abs(stft) for stft in stem_stfts])
    total_magnitude = np.sum(magnitudes, axis=0) + epsilon
    
    masks = []
    for mag in magnitudes:
        mask = mag / total_magnitude
        masks.append(mask)
    
    return masks

def wiener_mask(stem_stfts, power=2, epsilon=1e-10):
    """
    Wiener-style mask (generalized power mask).
    Mask_i = |S_i|^p / sum_j(|S_j|^p)
    p=1: Same as IRM, p=2: Traditional Wiener filter
    """
    magnitudes = np.array([np.abs(stft) for stft in stem_stfts])
    powers = magnitudes ** power
    total_power = np.sum(powers, axis=0) + epsilon
    
    masks = []
    for mag_power in powers:
        mask = mag_power / total_power
        masks.append(mask)
    
    return masks

print("Oracle mask functions defined:\n\
  - Ideal Binary Mask (IBM)\n\
  - Ideal Ratio Mask (IRM)\n\
  - Wiener Mask (p=1,2)")

Oracle mask functions defined:
  - Ideal Binary Mask (IBM)
  - Ideal Ratio Mask (IRM)
  - Wiener Mask (p=1,2)


### 4.4 Evaluation Metrics

In [67]:
def si_sdr(reference, estimate, epsilon=1e-10):
    """
    Scale-Invariant Signal-to-Distortion Ratio (SI-SDR).
    More robust than SDR to gain differences.
    """
    # Ensure same length
    min_len = min(len(reference), len(estimate))
    reference = reference[:min_len]
    estimate = estimate[:min_len]
    
    # Scale-invariant projection
    alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + epsilon)
    target = alpha * reference
    
    # Distortion
    residual = estimate - target
    
    # SI-SDR in dB
    si_sdr_value = 10 * np.log10(
        (np.sum(target**2) + epsilon) / (np.sum(residual**2) + epsilon)
    )
    
    return si_sdr_value

def evaluate_separation(reference_stems, estimated_stems, stem_names):
    """Evaluate separation quality using SI-SDR."""
    results = {}
    
    for i, name in enumerate(stem_names):
        sdr = si_sdr(reference_stems[i], estimated_stems[i])
        results[name] = sdr
    
    results['mean'] = np.mean(list(results.values()))
    return results

print("Evaluation functions defined:\n\
  - SI-SDR metric\n\
  - Batch evaluation")

Evaluation functions defined:
  - SI-SDR metric
  - Batch evaluation


### 4.5 Load Stems and Compute STFTs

In [68]:
# Load all stems from minimal track
print(f"Loading stems from {minimal_track_dir.name}...\n")

stem_names = []
stems_audio = []
stems_stft = []

for stem_info in minimal_track.get_all_stems():
    stem_path = minimal_track.get_stem_paths(stem_info.stem_id)['audio']
    
    # Skip macOS metadata files
    if stem_path.name.startswith('._'):
        continue
    
    try:
        audio, _ = sf.read(stem_path)
        audio = audio[:trim_samples]  # Trim to 10 seconds
        
        # Compute STFT
        f, t, Zxx = compute_stft(audio, sr)
        
        stem_name = f"{stem_info.stem_id}_{stem_info.inst_class}"
        stem_names.append(stem_name)
        stems_audio.append(audio)
        stems_stft.append(Zxx)
        
        print(f"  ✓ {stem_name}")
        
    except Exception as e:
        print(f"  ✗ {stem_path.name}: {e}")

print(f"\nLoaded {len(stem_names)} stems successfully")

Loading stems from Track00010...

  ✓ S00_Bass
  ✓ S01_Guitar
  ✓ S02_Guitar
  ✓ S03_Synth Pad
  ✓ S04_Synth Pad
  ✓ S05_Piano
  ✓ S06_Drums

Loaded 7 stems successfully


### 4.6 Verify Stems Sum to Mix

In [71]:
# Verify stems sum to mix
reconstructed_mix = np.sum(stems_audio, axis=0)
correlation = np.corrcoef(mix_audio.flatten(), reconstructed_mix.flatten())[0, 1]

print(f"Verification: Correlation = {correlation:.12f}")
if correlation > 0.99:
    print("✓ Stems sum to mix correctly!")
else:
    print("⚠ WARNING: Low correlation detected!")

Verification: Correlation = 0.999999845496
✓ Stems sum to mix correctly!


## 5. Compute and Evaluate Oracle Masks

### 5.1 Compute All Oracle Masks

In [73]:
# Compute all oracle mask types
print("Computing oracle masks...\n")

# IBM
ibm_masks = ideal_binary_mask(stems_stft)
print("Ideal Binary Masks")

# IRM
irm_masks = ideal_ratio_mask(stems_stft)
print("Ideal Ratio Masks")

# Wiener p=2
wiener2_masks = wiener_mask(stems_stft, power=2)
print("Wiener Masks (p=2)")

# Wiener p=1
wiener1_masks = wiener_mask(stems_stft, power=1)
print("Wiener Masks (p=1)")

Computing oracle masks...

Ideal Binary Masks
Ideal Ratio Masks
Wiener Masks (p=2)
Wiener Masks (p=1)


### 5.2 Apply Masks and Reconstruct

In [75]:
def reconstruct_with_masks(mix_stft, masks, sr):
    """Apply masks to mix STFT and reconstruct audio."""
    reconstructed = []
    for mask in masks:
        masked_stft = mix_stft * mask
        audio = compute_istft(masked_stft, sr)
        reconstructed.append(audio)
    return reconstructed

# Reconstruct with each mask type
print("Reconstructing audio with masks...\n")

ibm_reconstructed = reconstruct_with_masks(mix_stft, ibm_masks, sr)
print("IBM reconstruction")

irm_reconstructed = reconstruct_with_masks(mix_stft, irm_masks, sr)
print("IRM reconstruction")

wiener2_reconstructed = reconstruct_with_masks(mix_stft, wiener2_masks, sr)
print("Wiener (p=2) reconstruction")

wiener1_reconstructed = reconstruct_with_masks(mix_stft, wiener1_masks, sr)
print("Wiener (p=1) reconstruction")

Reconstructing audio with masks...

IBM reconstruction
IRM reconstruction
Wiener (p=2) reconstruction
Wiener (p=1) reconstruction


### 5.3 Evaluate Oracle Performance

In [76]:
print("=" * 70)
print("ORACLE MASK PERFORMANCE (Upper Bounds)")
print("=" * 70)

# IBM
print("\nIdeal Binary Mask (IBM):")
print("-" * 40)
ibm_results = evaluate_separation(stems_audio, ibm_reconstructed, stem_names)
for name in stem_names:
    print(f"  {name:30s}: {ibm_results[name]:6.2f} dB")
print(f"  {'Mean SI-SDR':30s}: {ibm_results['mean']:6.2f} dB")

# IRM
print("\nIdeal Ratio Mask (IRM):")
print("-" * 40)
irm_results = evaluate_separation(stems_audio, irm_reconstructed, stem_names)
for name in stem_names:
    print(f"  {name:30s}: {irm_results[name]:6.2f} dB")
print(f"  {'Mean SI-SDR':30s}: {irm_results['mean']:6.2f} dB")

# Wiener p=2
print("\nWiener Mask (p=2):")
print("-" * 40)
wiener2_results = evaluate_separation(stems_audio, wiener2_reconstructed, stem_names)
for name in stem_names:
    print(f"  {name:30s}: {wiener2_results[name]:6.2f} dB")
print(f"  {'Mean SI-SDR':30s}: {wiener2_results['mean']:6.2f} dB")

# Wiener p=1
print("\nWiener Mask (p=1):")
print("-" * 40)
wiener1_results = evaluate_separation(stems_audio, wiener1_reconstructed, stem_names)
for name in stem_names:
    print(f"  {name:30s}: {wiener1_results[name]:6.2f} dB")
print(f"  {'Mean SI-SDR':30s}: {wiener1_results['mean']:6.2f} dB")

print("\n" + "=" * 70)

ORACLE MASK PERFORMANCE (Upper Bounds)

Ideal Binary Mask (IBM):
----------------------------------------
  S00_Bass                      :   2.68 dB
  S01_Guitar                    :   4.38 dB
  S02_Guitar                    :   1.57 dB
  S03_Synth Pad                 :   7.39 dB
  S04_Synth Pad                 :   0.00 dB
  S05_Piano                     :   4.23 dB
  S06_Drums                     :   9.05 dB
  Mean SI-SDR                   :   4.18 dB

Ideal Ratio Mask (IRM):
----------------------------------------
  S00_Bass                      :   3.04 dB
  S01_Guitar                    :   5.10 dB
  S02_Guitar                    :   3.02 dB
  S03_Synth Pad                 :   7.00 dB
  S04_Synth Pad                 :   0.00 dB
  S05_Piano                     :   5.05 dB
  S06_Drums                     :   8.03 dB
  Mean SI-SDR                   :   4.46 dB

Wiener Mask (p=2):
----------------------------------------
  S00_Bass                      :   4.07 dB
  S01_Guitar       

### 5.4 Summary Comparison

In [77]:
# Summary comparison
print("\nSUMMARY - Mean SI-SDR by Method:")
print("-" * 40)
print(f"  {'IBM':15s}: {ibm_results['mean']:6.2f} dB")
print(f"  {'IRM':15s}: {irm_results['mean']:6.2f} dB")
print(f"  {'Wiener (p=2)':15s}: {wiener2_results['mean']:6.2f} dB")
print(f"  {'Wiener (p=1)':15s}: {wiener1_results['mean']:6.2f} dB")

print("\nInterpretation:")
print("- These are UPPER BOUNDS (oracle masks use ground truth)")
print("- Any classical method should achieve SI-SDR below these values")
print("- Higher SI-SDR = better separation quality")
print("- Wiener (p=2) typically performs best for oracle masks")


SUMMARY - Mean SI-SDR by Method:
----------------------------------------
  IBM            :   4.18 dB
  IRM            :   4.46 dB
  Wiener (p=2)   :   5.19 dB
  Wiener (p=1)   :   4.46 dB

Interpretation:
- These are UPPER BOUNDS (oracle masks use ground truth)
- Any classical method should achieve SI-SDR below these values
- Higher SI-SDR = better separation quality
- Wiener (p=2) typically performs best for oracle masks


## 6. Next Steps

**Phase 0 Complete:** Oracle masks established performance ceiling

**Next phases:**
1. **HPSS** (Harmonic-Percussive Source Separation via median filtering)
2. **Sub-band filtering** (FIR/IIR filter banks)
3. **NMF** (Non-negative Matrix Factorization)
4. **Full evaluation** on all tracks