# Polyphonic Leakage Stress Test

This notebook evaluates how the trained Joint Detection and Classification (JDC) model handles polyphonic mixtures of vocals and accompaniment. We sweep vocal-to-accompaniment gain ratios and measure vocal activity detection quality together with raw pitch accuracy (RPA) on vocal-active frames. The experiment probes whether the model maintains melody tracking when the accompaniment becomes dominant, revealing failure boundaries for karaoke and singing use cases.

**Data expectations**

* The evaluation list referenced by `CONFIG['eval_list_path']` should contain one item per line.
* Each line must provide at least the vocal stem path. Optionally, supply an accompaniment stem as the second `|`-separated field (e.g., `vocal.wav|accompaniment.wav`).
* Paths can be absolute or relative to the directory that contains the evaluation list file.

If no accompaniment is provided for an entry, it will be skipped with a warning because the leakage experiment requires both stems.

In [None]:
# !pip install soundfile torchaudio torch pyyaml matplotlib librosa pyworld pandas tqdm


In [None]:
import sys
import math
import warnings
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import pandas as pd
import torch
import torchaudio
import soundfile as sf
import pyworld as pw
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# The repository root is taken to be the parent of the current working
# directory (assumes the notebook is launched from a first-level sub-folder of
# the repo — TODO confirm). It is appended to sys.path so that the
# project-local imports that follow can be resolved.
REPO_ROOT = Path.cwd().resolve().parents[0]
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from meldataset import DEFAULT_MEL_PARAMS
from model import JDCNet

# Plot appearance: seaborn-like theme, wide default figures, grid enabled.
plt.style.use("seaborn-v0_8")
plt.rcParams["figure.figsize"] = (12, 4)
plt.rcParams["axes.grid"] = True

# Prefer the GPU whenever CUDA is available; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")


## Configuration

Update the configuration below to point at your training config, checkpoint directory, and evaluation list. The `vocal_to_accompaniment_db` sweep defines the gain, in dB, applied to the vocal stem before it is summed with the unscaled accompaniment (positive values boost the vocals; negative values attenuate them so that the accompaniment dominates).

In [None]:
# Notebook-wide settings. Edit the values here before running the cells below;
# _load_training_config() may later fill in the entries that are left as None.
CONFIG: Dict[str, Any] = {
    # Training YAML read by _load_training_config().
    "config_path": REPO_ROOT / "Configs" / "config.yml",
    # Directory searched for *.pth files when "checkpoint_path" is None.
    "checkpoint_dir": REPO_ROOT / "Checkpoint",
    # Explicit checkpoint file; None means "use the one picked by _latest_checkpoint()".
    "checkpoint_path": None,
    # Evaluation list; None means "use the training config's val_data entry".
    "eval_list_path": None,
    # Destination of the CSV files written by the export cell.
    "output_dir": REPO_ROOT / "notebooks" / "artifacts" / "polyphonic_leakage",
    # Mel frames fed to the model per call, and frames shared by consecutive calls.
    "chunk_size": 192,
    "chunk_overlap": 48,
    # Log-mel normalisation: (log_mel - mel_mean) / mel_std.
    "mel_mean": -4.0,
    "mel_std": 4.0,
    # A predicted F0 above this value (Hz) counts as a voiced frame.
    "voicing_threshold_hz": 10.0,
    # Vocal gains to sweep, in dB: +12 down to -12 in steps of 3.
    "vocal_to_accompaniment_db": list(np.arange(12, -12.1, -3.0)),
    # When True, a mixture whose absolute peak exceeds 0.99 is rescaled.
    "normalize_mix_peak": True,
}

# Create the artefact directory up front; the trailing expression displays the
# configuration as the cell output.
CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
CONFIG


In [None]:
# Working copy of the default mel parameters; _load_training_config() updates
# it in place with values from the training YAML.
MEL_PARAMS = DEFAULT_MEL_PARAMS.copy()
# Audio file suffixes. NOTE(review): not referenced by any other visible code.
AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"}

def _resolve_relative_path(base: Path, candidate: str | Path) -> Path:
    base_dir = base if base.is_dir() else base.parent
    candidate_path = Path(candidate)
    if candidate_path.is_absolute():
        return candidate_path
    repo_candidate = (REPO_ROOT / candidate_path).resolve()
    config_candidate = (base_dir / candidate_path).resolve()
    if config_candidate.exists():
        return config_candidate
    return repo_candidate

def _latest_checkpoint(path: Path) -> Optional[Path]:
    """Pick the most advanced ``*.pth`` checkpoint directly inside *path*.

    Files are ranked by the last integer found in their stem (files without
    any digit rank lowest, as -1); the modification time breaks ties.

    Args:
        path: Directory to search. A plain string is accepted as well.

    Returns:
        The highest-ranked checkpoint, or ``None`` when *path* is not a
        directory or contains no ``*.pth`` file.
    """
    # Local import: the notebook's import cell does not bring ``re`` into scope.
    import re

    path = Path(path)
    if not path.is_dir():
        return None

    def _sort_key(p: Path) -> tuple[int, float]:
        numbers = re.findall(r"\d+", p.stem)
        last = int(numbers[-1]) if numbers else -1
        return last, p.stat().st_mtime

    candidates = sorted(path.glob("*.pth"), key=_sort_key)
    return candidates[-1] if candidates else None

def _load_training_config() -> Dict[str, Any]:
    """Read the training YAML and fill unset ``CONFIG`` entries from it.

    Side effects:
        * ``MEL_PARAMS`` is updated with ``dataset_params.mel_params`` and, if
          present, ``dataset_params.sr`` (stored as ``sample_rate``).
        * ``CONFIG['eval_list_path']`` is set from ``val_data`` when unset.
        * ``CONFIG['checkpoint_dir']`` is replaced by ``log_dir`` when the
          configured directory is missing.
        * ``CONFIG['checkpoint_path']`` is set to the latest checkpoint found
          in the checkpoint directory when unset. This step also runs when the
          YAML file is missing, so the default checkpoint directory is still
          honoured.

    Returns:
        The parsed YAML mapping, or an empty dict when no config file exists.
    """
    data: Dict[str, Any] = {}
    raw_config_path = CONFIG.get("config_path")
    # Normalise to Path so a user-supplied string works with the helpers below.
    config_path = Path(raw_config_path) if raw_config_path else None

    if config_path is None or not config_path.is_file():
        print("Warning: config file not found; using DEFAULT_MEL_PARAMS.")
    else:
        import yaml

        with open(config_path, "r", encoding="utf-8") as handle:
            data = yaml.safe_load(handle) or {}

        # ``or {}`` also covers keys that are present but explicitly null.
        dataset_params = data.get("dataset_params") or {}
        MEL_PARAMS.update(dataset_params.get("mel_params") or {})
        dataset_sr = dataset_params.get("sr")
        if dataset_sr is not None:
            MEL_PARAMS["sample_rate"] = dataset_sr

        val_list = data.get("val_data")
        if val_list and CONFIG.get("eval_list_path") is None:
            CONFIG["eval_list_path"] = _resolve_relative_path(config_path, val_list)
            if not CONFIG["eval_list_path"].is_file():
                print(f"Warning: evaluation list not found at {CONFIG['eval_list_path']}")

        log_dir = data.get("log_dir")
        configured_dir = CONFIG.get("checkpoint_dir")
        if log_dir and (not configured_dir or not Path(configured_dir).is_dir()):
            CONFIG["checkpoint_dir"] = _resolve_relative_path(config_path, log_dir)

    if CONFIG.get("checkpoint_path") is None:
        checkpoint_dir = CONFIG.get("checkpoint_dir")
        # Guard against an unset directory instead of calling Path(None).
        latest = _latest_checkpoint(Path(checkpoint_dir)) if checkpoint_dir else None
        if latest is not None:
            CONFIG["checkpoint_path"] = latest
        else:
            print("Warning: no checkpoints found; set CONFIG['checkpoint_path'] manually.")
    return data

# Load the training config first: it may rewrite MEL_PARAMS and CONFIG in
# place, and every constant below is derived from MEL_PARAMS.
TRAINING_CONFIG = _load_training_config()
TARGET_SAMPLE_RATE = int(MEL_PARAMS["sample_rate"])
HOP_LENGTH = int(MEL_PARAMS["hop_length"])
# Duration of one hop in milliseconds; passed to pyworld as frame_period so
# that the reference F0 is sampled at the mel hop.
FRAME_PERIOD_MS = HOP_LENGTH * 1000.0 / TARGET_SAMPLE_RATE
# Mel front-end built from the (possibly updated) parameters, moved to DEVICE.
mel_transform = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS).to(DEVICE)

def _ensure_mono(audio: np.ndarray) -> np.ndarray:
    """Collapse multi-channel audio to one channel by averaging over axis 1.

    One-dimensional input is handed back unchanged (the very same object).
    """
    return audio if audio.ndim == 1 else audio.mean(axis=1)

def load_waveform(path: Path, target_sr: int = TARGET_SAMPLE_RATE) -> tuple[np.ndarray, int]:
    """Read an audio file as mono float32 samples at *target_sr*.

    The file is decoded with soundfile, down-mixed with ``_ensure_mono`` and,
    when its native rate differs from *target_sr*, resampled with torchaudio.

    Returns:
        ``(samples, sample_rate)``; the rate equals *target_sr* whenever
        resampling took place and the file's native rate otherwise (the two
        coincide in that case).
    """
    samples, native_sr = sf.read(str(path), dtype="float32")
    samples = _ensure_mono(samples)
    if native_sr == target_sr:
        return samples.astype(np.float32), native_sr

    # torchaudio resamples tensors, so add a leading channel axis for the call.
    waveform = torch.from_numpy(samples).unsqueeze(0)
    waveform = torchaudio.functional.resample(waveform, native_sr, target_sr)
    samples = waveform.squeeze(0).cpu().numpy()
    return samples.astype(np.float32), target_sr

def compute_reference_f0(audio: np.ndarray, sr: int) -> np.ndarray:
    """Estimate a reference F0 track for *audio* with pyworld.

    Harvest is tried first; if it reports fewer than five non-zero frames the
    estimate is replaced by DIO. Either result is refined with StoneMask. The
    analysis hop is ``FRAME_PERIOD_MS``.

    Returns:
        The refined F0 track as float32; an empty array for empty input.
    """
    if audio.size == 0:
        return np.zeros((0,), dtype=np.float32)

    # pyworld is called with double-precision samples.
    signal = audio.astype(np.float64)
    pitch, times = pw.harvest(signal, sr, frame_period=FRAME_PERIOD_MS)
    harvest_too_sparse = np.count_nonzero(pitch) < 5
    if harvest_too_sparse:
        pitch, times = pw.dio(signal, sr, frame_period=FRAME_PERIOD_MS)
    return pw.stonemask(signal, pitch, times, sr).astype(np.float32)

def waveform_to_mel(audio: np.ndarray) -> torch.Tensor:
    """Convert a waveform into the normalised log-mel input of the model.

    The samples are moved to ``DEVICE``, passed through ``mel_transform``,
    log-compressed with a 1e-5 floor, and standardised with
    ``CONFIG['mel_mean']`` / ``CONFIG['mel_std']``. The leading batch axis that
    was added for the transform is removed again before returning.
    """
    batch = torch.from_numpy(audio).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        mel_spec = mel_transform(batch)
    log_mel = torch.log(mel_spec + 1e-5)
    normalised = (log_mel - CONFIG["mel_mean"]) / CONFIG["mel_std"]
    return normalised.squeeze(0)

def predict_f0(model: JDCNet, audio: np.ndarray) -> np.ndarray:
    """Predict an F0 track for *audio*, one value per mel frame.

    The mel spectrogram is processed in windows of ``CONFIG['chunk_size']``
    frames that advance by ``chunk_size - chunk_overlap`` frames. The last
    window is zero-padded to full size. Each window's output is written at its
    true frame position, so the returned track has exactly as many frames as
    the mel spectrogram and stays aligned with a reference F0 of the same hop.
    Inside an overlap the first half is taken from the earlier window and the
    second half from the later one, which keeps frames away from window edges.

    Assumes the model returns one F0 value per input frame as the first
    element of its output — TODO confirm against ``JDCNet.forward``.

    Args:
        model: Network called as ``model(mel_chunk)`` on a
            ``(1, 1, frames, n_mels)`` tensor.
        audio: Mono waveform samples.

    Returns:
        float32 array of length ``mel.shape[-1]`` (empty for an empty mel).
    """
    mel = waveform_to_mel(audio)
    total_frames = mel.shape[-1]
    if total_frames == 0:
        return np.zeros((0,), dtype=np.float32)

    chunk_size = max(int(CONFIG["chunk_size"]), 1)
    # Clamp the overlap so the window always advances and never leaves gaps.
    overlap = min(max(int(CONFIG["chunk_overlap"]), 0), chunk_size - 1)
    step = chunk_size - overlap
    # Frames at the start of a non-initial window that the previous window keeps.
    lead = overlap // 2

    f0_track = np.zeros((total_frames,), dtype=np.float32)
    for start in range(0, total_frames, step):
        end = min(start + chunk_size, total_frames)
        mel_chunk = mel[:, start:end]
        pad = chunk_size - mel_chunk.shape[-1]
        if pad > 0:
            mel_chunk = torch.nn.functional.pad(mel_chunk, (0, pad))
        # (n_mels, frames) -> (1, 1, frames, n_mels)
        mel_chunk = mel_chunk.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
        with torch.no_grad():
            f0_chunk, _ = model(mel_chunk)
        f0_chunk = f0_chunk.detach().cpu().numpy().reshape(-1)

        skip = lead if start > 0 else 0
        f0_track[start + skip:end] = f0_chunk[skip:end - start]

        # Once a window reaches the final frame, later windows add nothing new.
        if end >= total_frames:
            break
    return f0_track

def hz_to_cents(f0: np.ndarray) -> np.ndarray:
    """Convert frequencies in Hz to cents above 55 Hz.

    Entries that are not strictly positive are mapped to 0. The result is
    float32.
    """
    voiced = f0 > 0
    cents = np.zeros_like(f0)
    cents[voiced] = np.log2(f0[voiced] / 55.0) * 1200.0
    return cents.astype(np.float32)

def circular_cents_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return ``a - b`` in cents with whole octaves folded away.

    The difference is wrapped into the half-open interval [-600, 600).
    """
    return np.remainder((a - b) + 600.0, 1200.0) - 600.0

def compute_metrics(reference: np.ndarray, prediction: np.ndarray) -> Dict[str, float]:
    """Score a predicted F0 track against a reference track, frame by frame.

    Both tracks are truncated to the shorter length. A reference frame is
    voiced when its F0 is positive; a predicted frame is voiced when its F0
    exceeds ``CONFIG['voicing_threshold_hz']``.

    Returns a dict with:
        RPA: share of reference-voiced frames predicted within 50 cents.
        RCA: the same after folding octaves away.
        VUV: share of all frames on which both voicing decisions agree.
        VAD_Precision / VAD_Recall / VAD_F1: voicing detection scores
            (NaN when undefined).
        OctaveError: share of reference-voiced frames that miss by more than
            50 cents yet lie within 50 cents of a non-zero whole-octave offset.
        VoicedFrames: number of reference-voiced frames.
    RPA, RCA and OctaveError are NaN when the reference has no voiced frame.
    """
    n_frames = min(reference.shape[0], prediction.shape[0])
    ref_hz = reference[:n_frames]
    est_hz = prediction[:n_frames]

    ref_on = ref_hz > 0
    est_on = est_hz > CONFIG["voicing_threshold_hz"]
    n_ref_on = np.count_nonzero(ref_on)

    # Voicing detection counts.
    hits = np.count_nonzero(ref_on & est_on)
    false_alarms = np.count_nonzero(~ref_on & est_on)
    misses = np.count_nonzero(ref_on & ~est_on)

    nan = float("nan")
    precision = float(hits / (hits + false_alarms)) if (hits + false_alarms) > 0 else nan
    recall = float(hits / (hits + misses)) if (hits + misses) > 0 else nan
    # A NaN operand makes the comparison False, so F1 falls back to NaN as well.
    f1 = 2.0 * precision * recall / (precision + recall) if precision + recall > 0 else nan

    scores = {
        "RPA": nan,
        "RCA": nan,
        "VUV": float(np.count_nonzero(ref_on == est_on) / max(n_frames, 1)),
        "VAD_Precision": precision,
        "VAD_Recall": recall,
        "VAD_F1": f1,
        "OctaveError": nan,
        "VoicedFrames": float(n_ref_on),
    }
    if n_ref_on == 0:
        return scores

    # Pitch metrics are evaluated on reference-voiced frames only. Predictions
    # are floored at a tiny positive value so every frame gets a cents value.
    ref_cents = hz_to_cents(ref_hz[ref_on])
    est_cents = hz_to_cents(np.clip(est_hz[ref_on], a_min=1e-5, a_max=None))
    deviation = est_cents - ref_cents
    within_50 = np.abs(deviation) <= 50.0
    chroma_within_50 = np.abs(circular_cents_distance(est_cents, ref_cents)) <= 50.0

    nearest_octave = np.round(deviation / 1200.0)
    octave_slips = (
        (np.abs(deviation) > 50.0)
        & (nearest_octave != 0)
        & (np.abs(deviation - nearest_octave * 1200.0) <= 50.0)
    )

    scores["RPA"] = float(np.count_nonzero(within_50) / n_ref_on)
    scores["RCA"] = float(np.count_nonzero(chroma_within_50) / n_ref_on)
    scores["OctaveError"] = float(np.count_nonzero(octave_slips) / n_ref_on)
    return scores

def _match_length(audio: np.ndarray, target_length: int) -> np.ndarray:
    """Trim or zero-pad *audio* at the end so axis 0 has *target_length* items.

    Input that already has the requested length is returned as is.
    """
    current = audio.shape[0]
    if current > target_length:
        return audio[:target_length]
    if current < target_length:
        return np.pad(audio, (0, target_length - current))
    return audio

def mix_vocals_with_accompaniment(
    vocal: np.ndarray,
    accompaniment: Optional[np.ndarray],
    ratio_db: float,
    normalize_peak: bool = True,
) -> np.ndarray:
    """Sum a gain-scaled vocal stem with an unscaled accompaniment stem.

    Args:
        vocal: Vocal samples.
        accompaniment: Accompaniment samples; ``None`` or an empty array
            triggers a warning and is replaced by silence.
        ratio_db: Gain in dB applied to the vocals only.
        normalize_peak: When true, a mix whose absolute peak exceeds 0.99 is
            divided by that peak (plus 1e-6).

    Returns:
        The float32 mixture, as long as the longer of the two stems; the
        shorter one is zero-padded by ``_match_length``.
    """
    backing = accompaniment
    if backing is None or backing.size == 0:
        warnings.warn("Missing accompaniment; returning scaled vocals only.")
        backing = np.zeros_like(vocal)

    n_samples = max(vocal.shape[0], backing.shape[0])
    vocal_padded = _match_length(vocal, n_samples)
    backing_padded = _match_length(backing, n_samples)

    linear_gain = 10.0 ** (ratio_db / 20.0)
    mix = linear_gain * vocal_padded + backing_padded

    if normalize_peak and mix.size > 0:
        peak = np.max(np.abs(mix))
        if peak > 0.99:
            mix = mix / (peak + 1e-6)
    return mix.astype(np.float32)

DATASET_CACHE: Optional[List[Dict[str, Any]]] = None


def prepare_dataset_cache(force: bool = False) -> List[Dict[str, Any]]:
    """Load every evaluation pair into memory and compute its reference F0.

    The list at ``CONFIG['eval_list_path']`` is parsed line by line as
    ``vocal[|accompaniment]``; relative paths are anchored at the list file's
    directory. Entries whose accompaniment is absent or not a file are skipped
    with a warning. The result is stored in the module-level ``DATASET_CACHE``
    and reused on later calls unless *force* is true.

    Raises:
        FileNotFoundError: the evaluation list is unset or not a file.
        RuntimeError: the list is empty, or no entry has both stems.
    """
    global DATASET_CACHE
    if not force and DATASET_CACHE is not None:
        return DATASET_CACHE

    eval_list = CONFIG.get("eval_list_path")
    if eval_list is None or not Path(eval_list).is_file():
        raise FileNotFoundError(f"Evaluation list not found: {eval_list}")

    def _locate(field: str) -> Path:
        """Anchor a relative list entry at the evaluation list's directory."""
        location = Path(field)
        if location.is_absolute():
            return location
        return (Path(eval_list).parent / location).resolve()

    # Pass 1: parse the list into (vocal, accompaniment) path pairs.
    entries: List[Dict[str, Optional[Path]]] = []
    with open(eval_list, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            fields = [field.strip() for field in line.split("|")]
            vocal_path = _locate(fields[0])
            has_accompaniment = len(fields) > 1 and fields[1]
            accompaniment_path = _locate(fields[1]) if has_accompaniment else None
            entries.append({"vocal": vocal_path, "accompaniment": accompaniment_path})
    if not entries:
        raise RuntimeError(f"No evaluation files located in {eval_list}")

    # Pass 2: load the audio of every complete pair and derive the reference F0.
    loaded: List[Dict[str, Any]] = []
    n_skipped = 0
    for entry in tqdm(entries, desc="Preparing evaluation cache"):
        vocal_path = entry["vocal"]
        accompaniment_path = entry.get("accompaniment")
        if accompaniment_path is None or not accompaniment_path.is_file():
            warnings.warn(f"Skipping {vocal_path} because accompaniment is missing.")
            n_skipped += 1
            continue

        vocal_audio, sr = load_waveform(vocal_path)
        accompaniment_audio, accomp_sr = load_waveform(accompaniment_path)
        if accomp_sr != sr:
            accompaniment_audio, _ = load_waveform(accompaniment_path, sr)

        loaded.append({
            "path": vocal_path,
            "vocal_audio": vocal_audio,
            "accompaniment_audio": accompaniment_audio,
            "sample_rate": sr,
            "reference_f0": compute_reference_f0(vocal_audio, sr),
        })

    if not loaded:
        raise RuntimeError("No evaluation items with both vocal and accompaniment were found.")
    if n_skipped:
        print(f"Warning: skipped {n_skipped} entries without accompaniment stems.")

    DATASET_CACHE = loaded
    print(f"Cached {len(DATASET_CACHE)} evaluation utterances.")
    return DATASET_CACHE

def load_model(checkpoint_path: Optional[Path] = None) -> JDCNet:
    """Build a ``JDCNet`` and load its weights from a checkpoint.

    The number of output classes is taken from the first dimension of
    ``classifier.weight`` in the checkpoint's ``"model"`` entry; failing that,
    from the checkpoint's ``num_class``, then ``CONFIG['num_class']``, then 722.

    Args:
        checkpoint_path: Checkpoint file; defaults to
            ``CONFIG['checkpoint_path']``.

    Returns:
        The model on ``DEVICE`` in evaluation mode.

    Raises:
        FileNotFoundError: no checkpoint is configured, or the file is missing.
    """
    configured = checkpoint_path or CONFIG.get("checkpoint_path")
    if not configured:
        # Without this guard Path(None) would fail with an opaque TypeError.
        raise FileNotFoundError(
            "No checkpoint configured; pass checkpoint_path or set CONFIG['checkpoint_path']."
        )
    checkpoint_path = Path(configured)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(checkpoint_path, map_location=DEVICE)
    model_state = state.get("model", {})
    classifier_weight = model_state.get("classifier.weight") if isinstance(model_state, dict) else None
    if classifier_weight is not None:
        inferred_classes = int(classifier_weight.shape[0])
    else:
        inferred_classes = int(state.get("num_class", CONFIG.get("num_class", 722)))
    if inferred_classes <= 0:
        inferred_classes = 722
    print(f"Instantiating JDCNet with {inferred_classes} classes")

    model = JDCNet(num_class=inferred_classes)
    model.load_state_dict(model_state)
    model.to(DEVICE).eval()
    print(f"Loaded checkpoint from {checkpoint_path}")
    return model


## Load model and evaluation cache

In [None]:
# Instantiate the network from the configured checkpoint, then rebuild the
# in-memory evaluation cache; force=True discards a cache left by an earlier run.
model = load_model()
DATASET_CACHE = prepare_dataset_cache(force=True)
# Cell output: number of cached evaluation items.
len(DATASET_CACHE)


## Baseline: Vocal-only performance

We first evaluate the model on clean vocal stems to establish a reference for RPA and vocal activity detection quality.

In [None]:
# Run the model on every clean vocal stem and score the prediction against the
# reference F0 that was computed from that same stem.
baseline_records: List[Dict[str, Any]] = []
for entry in tqdm(DATASET_CACHE, desc="Evaluating vocal-only baseline"):
    prediction = predict_f0(model, entry["vocal_audio"])
    metrics = compute_metrics(entry["reference_f0"], prediction)
    baseline_records.append({"path": str(entry["path"]), **metrics})
baseline_df = pd.DataFrame(baseline_records)
# Per-metric mean over all stems; reused as the dashed baseline in the plots.
baseline_summary = baseline_df[["RPA", "VAD_F1", "VAD_Recall", "VAD_Precision", "VUV"]].mean()
display(baseline_summary.to_frame(name="Baseline"))


## Polyphonic leakage sweep

Each condition mixes the vocal and accompaniment stems according to the specified gain ratio. We then evaluate RPA and vocal activity detection metrics on the resulting mixtures.

In [None]:
# For every gain in the sweep, mix each cached vocal/accompaniment pair, run
# the model on the mixture and score it against the clean-vocal reference F0.
sweep_results: List[Dict[str, Any]] = []
for ratio_db in CONFIG["vocal_to_accompaniment_db"]:
    for entry in tqdm(DATASET_CACHE, desc=f"Mixing {ratio_db:+.1f} dB", leave=False):
        mixture = mix_vocals_with_accompaniment(
            entry["vocal_audio"],
            entry["accompaniment_audio"],
            ratio_db,
            normalize_peak=CONFIG.get("normalize_mix_peak", True),
        )
        prediction = predict_f0(model, mixture)
        metrics = compute_metrics(entry["reference_f0"], prediction)
        sweep_results.append({
            "ratio_db": float(ratio_db),
            "path": str(entry["path"]),
            **metrics,
        })
# One row per (gain, item); the summary averages each metric per gain and is
# ordered from the loudest vocals to the quietest.
polyphonic_df = pd.DataFrame(sweep_results)
polyphonic_summary = (
    polyphonic_df
    .groupby("ratio_db")[["RPA", "VAD_F1", "VAD_Recall", "VAD_Precision", "VUV"]]
    .mean()
    .reset_index()
    .sort_values("ratio_db", ascending=False)
)
display(polyphonic_summary)


## Visualise failure boundaries

In [None]:
# Plot RPA and the voicing-detection scores against the vocal gain, with the
# clean-vocal baseline drawn as a dashed reference line.
if not polyphonic_summary.empty:
    fig, axes = plt.subplots(1, 2, figsize=(16, 5), sharex=True)

    # Left panel: raw pitch accuracy.
    axes[0].plot(polyphonic_summary["ratio_db"], polyphonic_summary["RPA"], marker="o")
    axes[0].axhline(baseline_summary["RPA"], color="gray", linestyle="--", label="Baseline")
    axes[0].set_xlabel("Vocal : Accompaniment (dB)")
    axes[0].set_ylabel("RPA")
    axes[0].set_title("Raw Pitch Accuracy vs Mix Ratio")
    axes[0].legend()

    # Right panel: vocal activity detection scores.
    axes[1].plot(polyphonic_summary["ratio_db"], polyphonic_summary["VAD_F1"], marker="o", label="F1")
    axes[1].plot(polyphonic_summary["ratio_db"], polyphonic_summary["VAD_Recall"], marker="s", label="Recall")
    axes[1].plot(polyphonic_summary["ratio_db"], polyphonic_summary["VAD_Precision"], marker="^", label="Precision")
    axes[1].axhline(baseline_summary["VAD_F1"], color="gray", linestyle="--", label="Baseline F1")
    axes[1].set_xlabel("Vocal : Accompaniment (dB)")
    axes[1].set_ylabel("Score")
    axes[1].set_title("Vocal Activity Detection Metrics")
    axes[1].legend()

    # The panels share their x-axis, so a single inversion flips both. Calling
    # invert_xaxis() on each panel would toggle the shared limits twice and
    # leave the axis un-inverted. After this call the loudest-vocal condition
    # is on the left and the accompaniment-dominant conditions on the right.
    axes[0].invert_xaxis()

    plt.tight_layout()
    plt.show()
else:
    print("No sweep results available for plotting.")


## Export artefacts

In [None]:
# Write the per-item baseline metrics, the per-item sweep metrics and the
# per-gain summary as CSV files into the configured artefact directory.
OUTPUT_DIR = CONFIG["output_dir"]
# Re-create the directory in case it was removed after the configuration cell ran.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
baseline_df.to_csv(OUTPUT_DIR / "baseline_vocal_metrics.csv", index=False)
polyphonic_df.to_csv(OUTPUT_DIR / "polyphonic_leakage_metrics.csv", index=False)
polyphonic_summary.to_csv(OUTPUT_DIR / "polyphonic_leakage_summary.csv", index=False)
print(f"Artifacts written to {OUTPUT_DIR}")
