<a href="https://colab.research.google.com/github/erikwirdemark/deep-mash/blob/main/deep_mash.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

GTZAN Stems DataLoader
==========================================================
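Loads isolated vocal stems from GTZAN Stems together with the corresponding original mixed tracks, and returns them as paired log-mel spectrograms for contrastive vocal-to-track matching. The notebook is organized into four sections: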
1. Audio Processing Utilities
2. Dataset Discovery
3. Dataset Class
4. DataLoader Creation


In [None]:
import os
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import librosa
import numpy as np
import soundfile as sf
import torch
from torch.utils.data import DataLoader, Dataset
warnings.filterwarnings('ignore')

## 1. Audio Processing Utilities

In [None]:
class AudioProcessor:
    """Load, mix, crop and convert audio into normalized log-mel spectrograms.

    Every file is resampled to ``sample_rate`` and down-mixed to mono on load.
    """

    def __init__(
        self,
        sample_rate: int = 22050,
        n_fft: int = 2048,
        hop_length: int = 512,
        n_mels: int = 128,
        fmax: Optional[float] = 8000
    ):
        """
        Args:
            sample_rate: Target sample rate (Hz) for loading and analysis.
            n_fft: FFT window size of the STFT.
            hop_length: Number of samples between successive STFT frames.
            n_mels: Number of mel bands.
            fmax: Highest mel-filter frequency in Hz (None = Nyquist).
        """
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.fmax = fmax

    def load_audio(self, file_path: Path) -> np.ndarray:
        """Load ``file_path`` as a mono waveform resampled to ``self.sample_rate``."""
        audio, _ = librosa.load(file_path, sr=self.sample_rate, mono=True)
        return audio

    def mix_stems(self, stem_paths: List[Path]) -> np.ndarray:
        """Sum several stems into one waveform, peak-normalized to 0.9.

        Stems are truncated to the shortest one so they can be added
        sample by sample with equal weights.

        Raises:
            ValueError: If ``stem_paths`` is empty.
        """
        if not stem_paths:
            raise ValueError("mix_stems() requires at least one stem path")

        stems = [self.load_audio(path) for path in stem_paths]

        # Ensure same length
        min_len = min(len(stem) for stem in stems)
        stems = [stem[:min_len] for stem in stems]

        # Mix with equal weights
        mixed = sum(stems)

        # Normalize to prevent clipping (skip silent mixes to avoid dividing by 0)
        max_val = np.abs(mixed).max()
        if max_val > 0:
            mixed = mixed / max_val * 0.9

        return mixed

    def extract_segment(
        self,
        audio: np.ndarray,
        duration: float,
        offset: Optional[float] = None
    ) -> np.ndarray:
        """Return exactly ``int(duration * sample_rate)`` samples of ``audio``.

        Audio longer than the target is cropped, starting at ``offset`` seconds
        or, when ``offset`` is None, at a uniformly random start. ``offset`` is
        only used when cropping is needed. Anything shorter than the target,
        whether from the start or after cropping near the end, is zero-padded at
        the end. The output length is therefore always fixed, which batching
        relies on.
        """
        target_length = int(duration * self.sample_rate)

        if len(audio) > target_length:
            if offset is not None:
                # Clamp negative offsets; an offset near/after the end yields a
                # short slice that is padded below.
                start = max(0, int(offset * self.sample_rate))
            else:
                # randint's upper bound is exclusive: +1 makes the last valid
                # start (len - target_length) reachable.
                start = np.random.randint(0, len(audio) - target_length + 1)
            audio = audio[start:start + target_length]

        if len(audio) < target_length:
            audio = np.pad(audio, (0, target_length - len(audio)), mode='constant')

        return audio

    def compute_melspectrogram(self, audio: np.ndarray) -> np.ndarray:
        """Compute a log-mel spectrogram scaled to [0, 1].

        Power is converted to dB relative to the spectrogram's own maximum, so
        the loudest bin maps to 1.0. ``power_to_db`` floors values at -80 dB
        (its default ``top_db``), which maps to 0.0.

        Returns:
            Array of shape (n_mels, n_frames).
        """
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            fmax=self.fmax
        )

        # Convert to log scale (dB), 0 dB at the spectrogram's peak
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        # Map [-80, 0] dB onto [0, 1]; clip defensively
        log_mel_spec = (log_mel_spec + 80) / 80
        log_mel_spec = np.clip(log_mel_spec, 0, 1)

        return log_mel_spec


class AudioAugmentor:
    """Random, length-preserving waveform augmentation for training."""

    def __init__(self, sample_rate: int = 22050):
        """
        Args:
            sample_rate: Sample rate (Hz) of the waveforms passed to ``augment``.
        """
        self.sample_rate = sample_rate

    @staticmethod
    def _fix_length(audio: np.ndarray, length: int) -> np.ndarray:
        """Crop or zero-pad ``audio`` at the end to exactly ``length`` samples."""
        if len(audio) > length:
            return audio[:length]
        if len(audio) < length:
            return np.pad(audio, (0, length - len(audio)), mode='constant')
        return audio

    def augment(self, audio: np.ndarray) -> np.ndarray:
        """Apply random augmentations, each with probability 0.5.

        The augmentations are time stretch (rate 0.9-1.1), pitch shift
        (within ±2 semitones) and gain (within ±3 dB). The output has the same
        number of samples as ``audio``, so every segment yields the same number
        of spectrogram frames and batches can be stacked.
        """
        original_length = len(audio)

        # Random time stretch (±10%)
        if np.random.rand() > 0.5:
            rate = np.random.uniform(0.9, 1.1)
            audio = librosa.effects.time_stretch(audio, rate=rate)
            # time_stretch changes the sample count (~len / rate); restore it.
            audio = self._fix_length(audio, original_length)

        # Random pitch shift (±2 semitones)
        if np.random.rand() > 0.5:
            n_steps = np.random.uniform(-2, 2)
            audio = librosa.effects.pitch_shift(
                audio, sr=self.sample_rate, n_steps=n_steps
            )

        # Random gain (±3 dB)
        if np.random.rand() > 0.5:
            gain_db = np.random.uniform(-3, 3)
            audio = audio * (10 ** (gain_db / 20))

        return audio

## 2. Dataset Discovery

In [None]:
class StemDiscovery:
    """Discovers GTZAN stem folders and pairs them with their original mixes."""

    # Extensions tried, in this order, when looking for an original mix.
    ORIGINAL_EXTENSIONS = ('.wav', '.mp3', '.au')

    @staticmethod
    def find_tracks(data_dir: Path, original_tracks_dir: Optional[Path] = None) -> List[Dict]:
        """
        Find all valid tracks with a vocal stem and an original mixed audio file.

        Expected structure:
          - data_dir/genre/track_name/*.wav (stems)
          - original_tracks_dir/genre/track_name.<ext> or
            original_tracks_dir/track_name.<ext> (original mix)

        Genres and tracks are visited in sorted order, so the returned list, and
        thus every track index and any seeded split built on it, is the same on
        every filesystem and run.

        Args:
            data_dir: Path to GTZAN stems directory
            original_tracks_dir: Path to original (unseparated) tracks. If None,
                no original can be found and the result is empty.

        Returns:
            List of dicts with keys 'track_name', 'genre', 'stems' (dict of
            stem name -> Path or None) and 'original' (Path).
        """
        tracks = []
        # iterdir() order is filesystem-dependent; sort for reproducibility.
        genre_dirs = sorted(d for d in data_dir.iterdir() if d.is_dir())

        for genre_dir in genre_dirs:
            track_dirs = sorted(d for d in genre_dir.iterdir() if d.is_dir())

            for track_dir in track_dirs:
                stems = StemDiscovery._find_stems_in_directory(track_dir)

                # Find original mixed track
                original_path = None
                if original_tracks_dir is not None:
                    original_path = StemDiscovery._find_original_track(
                        original_tracks_dir, genre_dir.name, track_dir.name
                    )

                # Only add track if vocal stem is present and original exists
                if stems['vocals'] is not None and original_path is not None:
                    tracks.append({
                        'track_name': track_dir.name,
                        'genre': genre_dir.name,
                        'stems': stems,
                        'original': original_path
                    })

        return tracks

    @staticmethod
    def _find_stems_in_directory(track_dir: Path) -> Dict[str, Optional[Path]]:
        """Map stem names to ``.wav`` files in ``track_dir`` by filename keyword.

        Matching is case-insensitive on the file stem: 'vocal' -> vocals,
        'drum' -> drums, 'bass' -> bass, 'other'/'accomp' -> other. Files are
        scanned in sorted order; if several match one stem, the last in that
        order wins, deterministically.
        """
        stems = {
            'vocals': None,
            'drums': None,
            'bass': None,
            'other': None
        }

        for stem_file in sorted(track_dir.glob('*.wav')):
            stem_name = stem_file.stem.lower()
            if 'vocal' in stem_name:
                stems['vocals'] = stem_file
            elif 'drum' in stem_name:
                stems['drums'] = stem_file
            elif 'bass' in stem_name:
                stems['bass'] = stem_file
            elif 'other' in stem_name or 'accomp' in stem_name:
                stems['other'] = stem_file

        return stems

    @staticmethod
    def _find_original_track(
        original_dir: Path,
        genre: str,
        track_name: str
    ) -> Optional[Path]:
        """Return the first existing original mix for ``track_name``, else None.

        The genre sub-folder is tried before ``original_dir`` itself. Within
        each location, extensions are tried in ``ORIGINAL_EXTENSIONS`` order.
        """
        possible_paths = [
            folder / f"{track_name}{ext}"
            for folder in (original_dir / genre, original_dir)
            for ext in StemDiscovery.ORIGINAL_EXTENSIONS
        ]

        for path in possible_paths:
            if path.exists():
                return path

        return None

## 3. Dataset Class

In [None]:
class GTZANStemsDataset(Dataset):
    """
    Dataset for vocal-instrumental matching from GTZAN Stems.

    Returns:
        - Vocal spectrogram (isolated vocal stem)
        - Full track spectrogram (original mixed track - ground truth)
        - Track index (for positive pair identification)

    NOTE(review): when ``segment_offset`` is None, the vocal and full-track
    segments are cropped (and augmented) with independent random draws, so a
    positive pair is not necessarily time-aligned.
    """

    def __init__(
        self,
        stems_dir: str,
        original_tracks_dir: str,
        sample_rate: int = 22050,
        n_fft: int = 2048,
        hop_length: int = 512,
        n_mels: int = 128,
        duration: float = 10.0,
        segment_offset: Optional[float] = None,
        augment: bool = True,
        cache_spectrograms: bool = False
    ):
        """
        Args:
            stems_dir: Path to GTZAN stems directory
            original_tracks_dir: Path to original (unseparated) tracks directory
            sample_rate: Target sample rate
            n_fft: FFT window size
            hop_length: Hop length for STFT
            n_mels: Number of mel bands
            duration: Duration of audio segments in seconds
            segment_offset: Fixed offset for reproducibility (None = random)
            augment: Whether to apply data augmentation
            cache_spectrograms: Whether to cache computed spectrograms. Items
                are only cached when they are deterministic (``augment`` is
                False and ``segment_offset`` is set). Otherwise one random draw
                would be replayed on every epoch.
        """
        self.stems_dir = Path(stems_dir)
        self.original_tracks_dir = Path(original_tracks_dir)
        self.duration = duration
        self.segment_offset = segment_offset
        self.augment = augment
        self.cache_spectrograms = cache_spectrograms

        # Initialize processors
        self.audio_processor = AudioProcessor(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels
        )
        self.augmentor = AudioAugmentor(sample_rate=sample_rate) if augment else None

        # Discover tracks
        self.tracks = StemDiscovery.find_tracks(
            self.stems_dir,
            self.original_tracks_dir
        )
        print(f"Found {len(self.tracks)} tracks with vocal stems and original audio")

        # Cache for spectrograms. With DataLoader workers, each worker process
        # presumably holds its own copy of this dict.
        self.cache = {} if cache_spectrograms else None

    def __len__(self) -> int:
        return len(self.tracks)

    def _should_cache(self) -> bool:
        """Return True when a cached item would equal a freshly computed one.

        That holds only if caching is enabled, no augmentation is applied and
        the segment offset is fixed.
        """
        return (
            bool(self.cache_spectrograms)
            and not self.augment
            and self.segment_offset is not None
        )

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Returns:
            vocal_spec: (1, n_mels, time_steps) - isolated vocal stem
            full_track_spec: (1, n_mels, time_steps) - original mixed track (ground truth)
            track_idx: int for identifying positive pairs
        """
        use_cache = self._should_cache()
        cache_key = f"{idx}_{self.segment_offset}"
        if use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        track = self.tracks[idx]

        # Process vocal (isolated stem) and original full track (ground truth)
        vocal = self._process_vocal(track)
        full_track = self._process_full_track(track)

        # Compute spectrograms
        vocal_spec = self.audio_processor.compute_melspectrogram(vocal)
        full_track_spec = self.audio_processor.compute_melspectrogram(full_track)

        # Convert to tensors with a leading channel dimension
        vocal_spec = torch.FloatTensor(vocal_spec).unsqueeze(0)
        full_track_spec = torch.FloatTensor(full_track_spec).unsqueeze(0)

        result = (vocal_spec, full_track_spec, idx)

        if use_cache:
            self.cache[cache_key] = result

        return result

    def _load_segment(self, path: Path) -> np.ndarray:
        """Load ``path``, crop/pad it to ``self.duration`` and optionally augment it."""
        audio = self.audio_processor.load_audio(path)
        audio = self.audio_processor.extract_segment(
            audio, self.duration, self.segment_offset
        )
        if self.augment:
            audio = self.augmentor.augment(audio)
        return audio

    def _process_vocal(self, track: Dict) -> np.ndarray:
        """Load and process vocal stem."""
        return self._load_segment(track['stems']['vocals'])

    def _process_full_track(self, track: Dict) -> np.ndarray:
        """Load and process original full mixed track (ground truth)."""
        return self._load_segment(track['original'])

## 4. DataLoader Creation

In [None]:
def collate_contrastive_batch(batch):
    """
    Stack a list of ``(vocal_spec, full_track_spec, track_idx)`` samples.

    Returns:
        vocal_specs: (batch_size, 1, n_mels, time_steps) - vocal stems
        full_track_specs: (batch_size, 1, n_mels, time_steps) - original tracks (ground truth)
        labels: (batch_size,) int64 track indices; equal labels mark positive pairs
    """
    vocal_list = [vocal for vocal, _, _ in batch]
    track_list = [mix for _, mix, _ in batch]
    index_list = [track_idx for _, _, track_idx in batch]

    stacked_vocals = torch.stack(vocal_list)
    stacked_tracks = torch.stack(track_list)
    label_tensor = torch.tensor(index_list, dtype=torch.long)
    return stacked_vocals, stacked_tracks, label_tensor


def create_dataloaders(
    stems_dir: str,
    original_tracks_dir: str,
    batch_size: int = 32,
    train_split: float = 0.8,
    val_split: float = 0.1,
    num_workers: int = 4,
    seed: int = 42,
    **dataset_kwargs
):
    """
    Create train, validation, and test dataloaders.

    The split is drawn from a dedicated ``torch.Generator`` seeded with
    ``seed``. It is reproducible without resetting torch's global RNG.
    Augmentation (if enabled via ``dataset_kwargs``) is applied to the
    training split only; validation and test read the same tracks with
    augmentation disabled.

    Args:
        stems_dir: Path to GTZAN stems directory
        original_tracks_dir: Path to original (unseparated) tracks directory
        batch_size: Batch size for dataloaders
        train_split: Proportion of data for training
        val_split: Proportion of data for validation; the remainder is test
        num_workers: Number of worker processes
        seed: Random seed for reproducibility of the split
        **dataset_kwargs: Additional arguments for GTZANStemsDataset

    Returns:
        train_loader, val_loader, test_loader

    Raises:
        ValueError: If the split fractions are negative or sum to more than 1.
    """
    import copy

    if train_split < 0 or val_split < 0 or train_split + val_split > 1 + 1e-9:
        raise ValueError(
            f"train_split ({train_split}) and val_split ({val_split}) must be "
            f"non-negative and sum to at most 1"
        )

    # Create full dataset
    dataset = GTZANStemsDataset(stems_dir, original_tracks_dir, **dataset_kwargs)

    # Split dataset
    total_size = len(dataset)
    train_size = int(train_split * total_size)
    val_size = int(val_split * total_size)
    test_size = total_size - train_size - val_size

    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_subset, test_subset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], generator=generator
    )

    # All three splits would otherwise share one (possibly augmenting) dataset.
    # Evaluate on a shallow copy with augmentation off and its own cache, so
    # augmented training items are never served to validation/test.
    eval_dataset = copy.copy(dataset)
    eval_dataset.augment = False
    eval_dataset.augmentor = None
    if dataset.cache is not None:
        eval_dataset.cache = {}
    val_dataset = torch.utils.data.Subset(eval_dataset, val_subset.indices)
    test_dataset = torch.utils.data.Subset(eval_dataset, test_subset.indices)

    # Pinned memory only helps host-to-GPU copies.
    pin_memory = torch.cuda.is_available()

    def make_loader(subset, shuffle):
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_contrastive_batch,
            pin_memory=pin_memory
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    print(f"\nDataset splits:")
    print(f"  Train: {len(train_dataset)} samples")
    print(f"  Val: {len(val_dataset)} samples")
    print(f"  Test: {len(test_dataset)} samples")

    return train_loader, val_loader, test_loader


In [None]:
if __name__ == "__main__":
    # Placeholder locations: point these at your local copies before running.
    stems_dir = "/path/to/gtzan-stems"  # separated stems, one folder per track
    original_tracks_dir = "/path/to/gtzan-original"  # original mixed tracks

    loader_options = dict(
        batch_size=16,
        sample_rate=22050,
        duration=10.0,
        augment=True,
        cache_spectrograms=False,
    )
    train_loader, val_loader, test_loader = create_dataloaders(
        stems_dir=stems_dir,
        original_tracks_dir=original_tracks_dir,
        **loader_options,
    )

    # Smoke test: fetch a single training batch and report the tensor shapes.
    print("\nTesting batch loading...")
    vocal_batch, full_track_batch, labels = next(iter(train_loader))
    for caption, tensor in (
        ("Vocal batch shape", vocal_batch),
        ("Full track batch shape", full_track_batch),
        ("Labels shape", labels),
    ):
        print(f"{caption}: {tensor.shape}")