# Data exploration: DN-LM and training pipeline

This notebook lets you **listen to and inspect** the data fed to the models:
1. **Raw DN-LM samples** — vocals, noise, and mixture as stored on disk (no chunking, no augmentations).
2. **After the dataset pipeline** — the same data after chunking, optional loudness, and mixing (what the model sees as input and target).

Run from the project root so `dataset` and `utils` import correctly.

## Setup: paths and config

In [1]:
import os
import json
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import Audio, display

# Project root (adjust if running from elsewhere)
PROJECT_ROOT = Path(os.getcwd())
if (PROJECT_ROOT / "dataset.py").exists():
    os.chdir(PROJECT_ROOT)

# Paths — point to your DN-LM split and config
DATA_PATH = PROJECT_ROOT / "datasets" / "DN-LM" / "train"
VALID_PATH = PROJECT_ROOT / "datasets" / "DN-LM" / "valid"
CONFIG_PATH = PROJECT_ROOT / "configs" / "3_FA_RoPE(64).yaml"

# Where to cache dataset metadata (no training results needed)
EXPLORE_CACHE = PROJECT_ROOT / "results" / "explore"
EXPLORE_CACHE.mkdir(parents=True, exist_ok=True)

from utils import load_config

config = load_config("edge_bs_rof", str(CONFIG_PATH))
sample_rate = getattr(config.audio, "sample_rate", None) or getattr(config.audio, "samplerate", 16000)
chunk_size = config.audio.chunk_size
instruments = list(config.training.instruments)
target_instrument = getattr(config.training, "target_instrument", None) or "vocals"

print(f"Sample rate: {sample_rate} Hz")
print(f"Chunk size: {chunk_size} samples ({chunk_size / sample_rate:.2f} s)")
print(f"Instruments: {instruments}")
print(f"Target: {target_instrument}")

Sample rate: 16000 Hz
Chunk size: 131584 samples (8.22 s)
Instruments: ['vocals', 'noise']
Target: vocals


## 1. Raw DN-LM samples (from disk)

Load full-length **vocals**, **noise**, and **mixture** from a few sample folders. No chunking, no augmentations — exactly as created by `create_dataset.py`.

In [2]:
def load_raw_sample(sample_dir, sr=None):
    """Load vocals, noise, mixture from one DN-LM sample folder. Returns (vocals, noise, mixture), sr."""
    sample_dir = Path(sample_dir)
    vocals, sr_v = sf.read(sample_dir / "vocals.wav", dtype="float32")
    noise, sr_n = sf.read(sample_dir / "noise.wav", dtype="float32")
    mixture, sr_m = sf.read(sample_dir / "mixture.wav", dtype="float32")
    sr_actual = sr_v  # all same in DN-LM
    if sr is not None and sr_actual != sr:
        import librosa
        vocals = librosa.resample(vocals, orig_sr=sr_actual, target_sr=sr)
        noise = librosa.resample(noise, orig_sr=sr_actual, target_sr=sr)
        mixture = librosa.resample(mixture, orig_sr=sr_actual, target_sr=sr)
        sr_actual = sr
    return vocals, noise, mixture, sr_actual


def plot_waveforms(vocals, noise, mixture, sr, title="", axs=None):
    t = np.arange(len(mixture)) / sr
    if axs is None:
        fig, axs = plt.subplots(3, 1, figsize=(10, 6), sharex=True)
    for ax, sig, label in zip(axs, [vocals, noise, mixture], ["Vocals (target)", "Noise", "Mixture (input)"]):
        sig_1d = np.squeeze(sig)
        if sig_1d.ndim > 1:
            sig_1d = sig_1d.mean(axis=1)
        ax.plot(t, sig_1d, color="#2e86ab" if "Vocals" in label else "#a23b72" if "Noise" in label else "#f18f01")
        ax.set_ylabel(label)
        ax.set_xlim(0, t[-1])
    axs[-1].set_xlabel("Time (s)")
    if title:
        axs[0].set_title(title)
    plt.tight_layout()
    return axs


# List sample folders (train or valid)
split_dir = DATA_PATH if DATA_PATH.exists() else VALID_PATH
sample_dirs = sorted([d for d in split_dir.iterdir() if d.is_dir() and not d.name.startswith(".")])
if not sample_dirs:
    raise FileNotFoundError(f"No sample folders in {split_dir}. Create DN-LM with create_dataset.py first.")

# How many raw samples to show
num_raw_samples = 3
indices = np.linspace(0, len(sample_dirs) - 1, num_raw_samples, dtype=int)

for idx in indices:
    sample_dir = sample_dirs[idx]
    vocals, noise, mixture, sr = load_raw_sample(sample_dir, sr=sample_rate)
    fig, axs = plt.subplots(3, 1, figsize=(10, 6), sharex=True)
    plot_waveforms(vocals, noise, mixture, sr, title=f"Raw DN-LM: {sample_dir.name}", axs=axs)
    plt.show()
    print(f"\n{sample_dir.name} — listen:")
    display(Audio(mixture, rate=sr))
    display(Audio(np.squeeze(vocals), rate=sr))
    display(Audio(np.squeeze(noise), rate=sr))

FileNotFoundError: [Errno 2] No such file or directory: '/home/flyingleafe/Research/PhD/projects/Edge-BS-RoFormer-DroneNoise-LibriMix/datasets/DN-LM/valid'

## 2. Metadata (SNR, sources)

If DN-LM was created with `create_dataset.py`, each split has a `metadata.json` with SNR and source file names.

In [None]:
metadata_path = split_dir / "metadata.json"
if metadata_path.exists():
    with open(metadata_path) as f:
        meta = json.load(f)
    # Key is 'train' or 'valid'
    key = "train" if "train" in meta else "valid"
    entries = meta[key]
    print(f"Metadata entries: {len(entries)}")
    for idx in indices[:3]:
        e = entries[idx]
        print(f"  {e.get('id', idx)}: SNR={e.get('input_snr', 'N/A'):.2f} dB, speech_source={e.get('speech_source', '')}")
else:
    print(f"No metadata at {metadata_path}")

## 3. After the dataset pipeline (what the model sees)

The training dataloader uses `MSSDataset`: it loads **chunks** of audio (random offset), optionally applies loudness per track, then sums sources into the mixture. With the default Edge-BS-RoFormer config, **augmentations are disabled**, so you see chunking + loudness only. Output is **(target, mix)** — target is the vocals track, mix is the model input.

Shapes: `(2, chunk_size)` (stereo). For playback we use mono (mean over channels).

In [None]:
from dataset import MSSDataset

data_path = str(DATA_PATH) if DATA_PATH.exists() else str(VALID_PATH)
metadata_path = str(EXPLORE_CACHE / f"metadata_1.pkl")

ds = MSSDataset(
    config,
    data_path,
    metadata_path=metadata_path,
    dataset_type=1,
    batch_size=config.training.batch_size,
    verbose=True,
)

# Get a few samples (target, mix) — each ds[i] is an independent random draw
num_samples_show = 2

for b in range(num_samples_show):
    target, mix = ds[b]  # single sample: (2, chunk_size), (2, chunk_size)
    target_np = target.numpy()  # (2, T)
    mix_np = mix.numpy()
    t = np.arange(chunk_size) / sample_rate
    fig, axs = plt.subplots(2, 1, figsize=(10, 4), sharex=True)
    for ax, sig, label in zip(axs, [target_np, mix_np], ["Target (vocals)", "Mixture (input)"]):
        mono = np.squeeze(sig).mean(axis=0) if np.squeeze(sig).ndim > 1 else np.squeeze(sig)
        ax.plot(t, mono, color="#2e86ab" if "Target" in label else "#f18f01")
        ax.set_ylabel(label)
    axs[-1].set_xlabel("Time (s)")
    axs[0].set_title(f"Dataset output (after chunking/loudness) — sample index {b}")
    plt.tight_layout()
    plt.show()
    print(f"Listen — target then mixture:")
    display(Audio(target_np.mean(axis=0), rate=sample_rate))
    display(Audio(mix_np.mean(axis=0), rate=sample_rate))

## 4. Side-by-side: one raw sample vs one dataset chunk

Compare a **full raw mixture** from disk with a **chunk** from the dataset (different random draw). Duration of the chunk is `chunk_size / sample_rate` seconds.

In [None]:
vocals, noise, mixture, sr = load_raw_sample(sample_dirs[0], sr=sample_rate)
target_chunk, mix_chunk = ds[0]
target_chunk = target_chunk.numpy().mean(axis=0)
mix_chunk = mix_chunk.numpy().mean(axis=0)

fig, axs = plt.subplots(2, 1, figsize=(10, 5), sharex=False)
t_raw = np.arange(len(mixture)) / sr
t_chunk = np.arange(len(mix_chunk)) / sample_rate
axs[0].plot(t_raw, np.squeeze(mixture).mean(axis=0) if np.squeeze(mixture).ndim > 1 else np.squeeze(mixture), color="#f18f01", label="Raw mixture (full length)")
axs[0].set_ylabel("Amplitude")
axs[0].set_title("Raw mixture from disk")
axs[0].legend()
axs[1].plot(t_chunk, mix_chunk, color="#f18f01", label="Dataset mixture (chunk)")
axs[1].set_xlabel("Time (s)")
axs[1].set_ylabel("Amplitude")
axs[1].set_title(f"Dataset mixture chunk ({len(mix_chunk)} samples = {len(mix_chunk)/sample_rate:.2f} s)")
axs[1].legend()
plt.tight_layout()
plt.show()

print("Raw mixture (full):")
display(Audio(np.squeeze(mixture), rate=sr))
print("Dataset mixture (chunk):")
display(Audio(mix_chunk, rate=sample_rate))