# Pitch Range & Timbre Coverage Evaluation


This notebook sweeps through defined vocal register ranges and synthetic timbre profiles to evaluate the latest pitch extraction model. It reports per-range metrics, highlights octave confusions at the extreme ends of each range, and visualizes how accuracy varies with pitch.


In [None]:
# !pip install soundfile torchaudio torch pyyaml matplotlib librosa pandas tqdm


In [None]:
import sys
import math
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torchaudio
import soundfile as sf
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Treat the parent of the current working directory as the repository root
# (presumably the notebook is launched from a direct subdirectory such as
# notebooks/ -- confirm if it is run from elsewhere).
REPO_ROOT = Path.cwd().resolve().parents[0]
# Put the repository root on sys.path once so the project-local imports that
# follow can be resolved.
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from meldataset import DEFAULT_MEL_PARAMS
from model import JDCNet
from Utils.dynamic_pitch_tools import sample_reference_f0, hz_to_cents, circular_cents_distance


In [None]:
# Plot defaults shared by every figure in this notebook.
plt.style.use("seaborn-v0_8")
plt.rcParams["figure.figsize"] = (12, 4)
plt.rcParams["axes.grid"] = True

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)
print(f"Using device: {DEVICE}")


In [None]:
# Central settings for the whole notebook. Several entries (checkpoint_dir,
# checkpoint_path, eval_list_path) may be filled in or replaced later by
# _load_training_config().
CONFIG: Dict[str, Any] = {
    # Training YAML; its mel parameters and paths are merged in when present.
    "config_path": REPO_ROOT / "Configs" / "config.yml",
    # Directory searched for "*.pth" files when no explicit checkpoint is set.
    "checkpoint_dir": REPO_ROOT / "Checkpoint",
    # Explicit checkpoint file; None means "pick the latest one automatically".
    "checkpoint_path": None,
    # Mel frames fed to the model per forward pass, and the number of frames
    # by which consecutive chunks overlap (see predict_f0).
    "chunk_size": 192,
    "chunk_overlap": 48,
    # Log-mel normalisation constants applied in waveform_to_mel.
    "mel_mean": -4.0,
    "mel_std": 4.0,
    # Predicted F0 values above this are counted as voiced in compute_metrics.
    "voicing_threshold_hz": 10.0,
    # Destination folder for the exported metrics CSV.
    "output_dir": REPO_ROOT / "notebooks" / "artifacts",
    "stimulus": {
        # Length of every synthetic tone.
        "duration_seconds": 2.5,
        # Number of test frequencies generated inside each range.
        "frequencies_per_range": 15,
        # Fraction of a range's width, at each end, labelled "low"/"high".
        "edge_band_fraction": 0.15,
        # Seed for the noise generator used by profiles that set "snr_db".
        "random_seed": 1337,
    },
    # Frequency ranges swept by the evaluation; each is sampled from min_hz
    # to max_hz inclusive.
    "ranges": [
        {"name": "Bass", "min_hz": 70.0, "max_hz": 120.0},
        {"name": "Baritone/Tenor", "min_hz": 120.0, "max_hz": 220.0},
        {"name": "Alto", "min_hz": 220.0, "max_hz": 350.0},
        {"name": "Child/Falsetto", "min_hz": 350.0, "max_hz": 1000.0},
    ],
    # Synthetic timbres: "partials" maps harmonic number -> linear amplitude;
    # an optional "snr_db" adds white noise at that signal-to-noise ratio.
    "timbre_profiles": {
        "Pure Sine": {"partials": {1: 1.0}},
        "Warm Vocal": {"partials": {1: 1.0, 2: 0.45, 3: 0.2}},
        "Bright Belt": {"partials": {1: 1.0, 2: 0.9, 3: 0.75, 4: 0.5, 5: 0.35}},
        "Breathy Head": {"partials": {1: 1.0, 2: 0.5, 3: 0.35}, "snr_db": 25.0},
    },
}

# Create the artifact directory up front so the final CSV export cannot fail
# on a missing folder.
Path(CONFIG["output_dir"]).mkdir(parents=True, exist_ok=True)


In [None]:
# Work on a (shallow) copy so later updates do not add or replace keys on the
# imported DEFAULT_MEL_PARAMS object itself.
MEL_PARAMS = DEFAULT_MEL_PARAMS.copy()

def _normalize_mel_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Rename the ``"win_len"`` key to ``"win_length"`` in place.

    The ``"win_len"`` entry is always removed. Its value is stored under
    ``"win_length"`` only when that key is not already present, so an
    existing ``"win_length"`` wins.

    Args:
        params: Mel-spectrogram keyword arguments; mutated in place.

    Returns:
        The same ``params`` object, for call-chaining convenience.
    """
    try:
        legacy_value = params.pop("win_len")
    except KeyError:
        # Nothing to rename.
        return params
    if "win_length" not in params:
        params["win_length"] = legacy_value
    return params


def _deep_merge_dict(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``base`` with ``overrides`` layered on top, merging nested dicts.

    Neither argument is modified. When both sides hold a dict under the same
    key the two are merged recursively; a dict override that replaces a
    non-dict (or missing) value is stored as a shallow copy; every other
    override value replaces the base value directly.

    Args:
        base: Lower-priority mapping.
        overrides: Higher-priority mapping.

    Returns:
        A new dict containing the merged result.
    """
    result = base.copy()
    for name, incoming in overrides.items():
        if not isinstance(incoming, dict):
            result[name] = incoming
            continue
        current = result.get(name)
        if isinstance(current, dict):
            result[name] = _deep_merge_dict(current, incoming)
        else:
            result[name] = incoming.copy()
    return result

# Rename a legacy "win_len" entry in the copied defaults to "win_length"
# before the dict is later expanded into MelSpectrogram keyword arguments.
MEL_PARAMS = _normalize_mel_params(MEL_PARAMS)

def _resolve_relative_path(base: Path, candidate: str | Path) -> Path:
    base_dir = base if base.is_dir() else base.parent
    candidate_path = Path(candidate)
    if candidate_path.is_absolute():
        return candidate_path
    repo_candidate = (REPO_ROOT / candidate_path).resolve()
    config_candidate = (base_dir / candidate_path).resolve()
    if repo_candidate.exists():
        return repo_candidate
    if config_candidate.exists():
        return config_candidate
    return repo_candidate


import re


def _latest_checkpoint(path: Path) -> Optional[Path]:
    """Pick the most recent ``*.pth`` file directly inside ``path``.

    Files are ranked by the last integer found in their stem (``-1`` when the
    stem has no digits) and, for equal numbers, by modification time.

    Args:
        path: Directory to search (not recursive).

    Returns:
        The highest-ranked checkpoint, or ``None`` when ``path`` is not a
        directory or holds no ``*.pth`` files.
    """
    if not path.is_dir():
        return None

    def _rank(checkpoint: Path) -> Tuple[int, float]:
        digit_runs = re.findall(r"\d+", checkpoint.stem)
        trailing_number = int(digit_runs[-1]) if digit_runs else -1
        return trailing_number, checkpoint.stat().st_mtime

    ordered = sorted(path.glob("*.pth"), key=_rank)
    if not ordered:
        return None
    return ordered[-1]


# Placeholder for the parsed training YAML; reassigned from
# _load_training_config() further down, which returns {} when no config file
# is found.
TRAINING_CONFIG: Dict[str, Any] = {}


def _load_training_config() -> Dict[str, Any]:
    """Load the training YAML and fold its settings into the notebook state.

    Side effects:
        * ``MEL_PARAMS`` is updated with ``dataset_params.mel_params`` and the
          dataset sample rate (``dataset_params.sr``).
        * ``CONFIG["eval_list_path"]`` is set from ``val_data`` unless already
          present.
        * ``CONFIG["checkpoint_dir"]`` falls back to ``log_dir`` when the
          configured directory is unset or does not exist.
        * ``CONFIG["checkpoint_path"]`` is auto-filled with the latest
          checkpoint when left as ``None``. This happens even when the YAML
          file is missing.

    Returns:
        The parsed YAML mapping, or ``{}`` when the file is missing or its
        top level is not a mapping.
    """

    def _ensure_checkpoint_path() -> None:
        # Respect an explicit user choice; otherwise look for the newest file.
        if CONFIG.get("checkpoint_path") is not None:
            return
        checkpoint_dir = CONFIG.get("checkpoint_dir")
        # Guard against checkpoint_dir=None, which Path() would reject.
        latest = _latest_checkpoint(Path(checkpoint_dir)) if checkpoint_dir else None
        if latest is not None:
            CONFIG["checkpoint_path"] = latest
        else:
            print("Warning: no checkpoints found; set CONFIG['checkpoint_path'] manually.")

    config_path = CONFIG.get("config_path")
    if not config_path or not Path(config_path).is_file():
        print("Warning: config file not found; using DEFAULT_MEL_PARAMS.")
        # A missing YAML must not prevent checkpoint discovery.
        _ensure_checkpoint_path()
        return {}
    # Accept a plain string as well as a Path; the helpers below need a Path.
    config_path = Path(config_path)

    import yaml

    with open(config_path, "r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle) or {}
    if not isinstance(data, dict):
        print("Warning: config file does not contain a mapping; ignoring its contents.")
        data = {}

    # "or {}" also covers sections that are present but null in the YAML.
    dataset_params = data.get("dataset_params") or {}
    mel_overrides = dataset_params.get("mel_params") or {}
    # Normalise a copy of the overrides first: MEL_PARAMS may already hold
    # "win_length", and normalising after the merge would then silently drop
    # a "win_len" value supplied by the YAML.
    MEL_PARAMS.update(_normalize_mel_params(dict(mel_overrides)))
    _normalize_mel_params(MEL_PARAMS)

    dataset_sr = dataset_params.get("sr")
    if dataset_sr is not None:
        MEL_PARAMS["sample_rate"] = dataset_sr

    val_list = data.get("val_data")
    if val_list and CONFIG.get("eval_list_path") is None:
        CONFIG["eval_list_path"] = _resolve_relative_path(config_path, val_list)
        if not CONFIG["eval_list_path"].is_file():
            print(f"Warning: evaluation list not found at {CONFIG['eval_list_path']}")

    log_dir = data.get("log_dir")
    if log_dir and (not CONFIG.get("checkpoint_dir") or not Path(CONFIG["checkpoint_dir"]).is_dir()):
        CONFIG["checkpoint_dir"] = _resolve_relative_path(config_path, log_dir)

    _ensure_checkpoint_path()
    return data


# Loading the training config also updates MEL_PARAMS and CONFIG as a side
# effect, so everything below must be derived after this call.
TRAINING_CONFIG = _load_training_config()
TARGET_SAMPLE_RATE = MEL_PARAMS["sample_rate"]
HOP_LENGTH = MEL_PARAMS["hop_length"]
# Hop between consecutive mel frames, expressed in milliseconds.
FRAME_PERIOD_MS = HOP_LENGTH * 1000.0 / TARGET_SAMPLE_RATE
# Shared mel front-end, built from the merged parameters and placed on the
# same device as the model inputs.
mel_transform = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS).to(DEVICE)


In [None]:
# Seeded generator used for the additive noise in synthesize_timbre_waveform;
# results are reproducible only for a fixed order of synthesis calls.
rng = np.random.default_rng(CONFIG["stimulus"]["random_seed"])


def _apply_fade(audio: np.ndarray, sr: int, fade_time: float = 0.02) -> np.ndarray:
    """Apply a raised-cosine fade-in and fade-out to a 1-D signal.

    The ramp length is ``fade_time * sr`` samples, clamped to half the clip
    length so the two ramps never overlap and clips shorter than the fade
    (including empty ones) are handled instead of raising a shape error.

    Args:
        audio: 1-D waveform.
        sr: Sample rate in Hz.
        fade_time: Ramp duration in seconds at each end.

    Returns:
        The faded waveform as ``float32``. When no fade is applied the input
        is returned as ``float32`` without an extra copy if it already has
        that dtype.
    """
    requested = int(max(fade_time * sr, 0))
    fade_samples = min(requested, len(audio) // 2)
    if fade_samples <= 0:
        return audio.astype(np.float32, copy=False)
    window = np.ones_like(audio, dtype=np.float64)
    # Half-period of a cosine: rises smoothly from 0 to 1.
    ramp = 0.5 - 0.5 * np.cos(np.linspace(0.0, math.pi, fade_samples, dtype=np.float64))
    window[:fade_samples] = ramp
    window[-fade_samples:] = ramp[::-1]
    return (audio * window).astype(np.float32)


def _normalize(audio: np.ndarray) -> np.ndarray:
    """Scale a waveform down only if its peak would exceed 0.99.

    Signals whose absolute peak is at most 0.99 are left at their original
    level. Louder ones are divided by ``peak + 1e-6`` so the result peaks just
    below 1.0. Empty input is returned as an empty ``float32`` array instead
    of raising inside ``np.max``.

    Args:
        audio: Waveform samples.

    Returns:
        The (possibly rescaled) waveform as ``float32``.
    """
    if audio.size == 0:
        return audio.astype(np.float32)
    peak = float(np.max(np.abs(audio)))
    if peak > 0.99:
        audio = audio / (peak + 1e-6)
    return audio.astype(np.float32)


def synthesize_timbre_waveform(frequency: float, sr: int, duration: float, profile: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
    """Synthesize a steady harmonic tone for one timbre profile.

    Args:
        frequency: Fundamental frequency in Hz.
        sr: Sample rate in Hz.
        duration: Requested length in seconds; ``int(duration * sr)`` samples
            are generated.
        profile: Timbre description. ``"partials"`` maps harmonic number to
            linear amplitude (default: fundamental only), ``"envelope"`` may
            be a callable applied to the time axis, and ``"snr_db"`` adds
            white noise at that signal-to-noise ratio.

    Returns:
        ``(waveform, time_axis)``, both ``float32``. ``time_axis`` holds the
        sample instants in seconds.
    """
    total_samples = max(int(duration * sr), 0)
    # Sample instants are exactly k / sr. Spreading ``duration`` over the
    # truncated sample count would stretch the spacing (and so detune the
    # tone) whenever duration * sr is not an integer.
    t = np.arange(total_samples, dtype=np.float64) / float(sr)
    partials = profile.get("partials", {1: 1.0})
    nyquist = sr / 2.0
    waveform = np.zeros_like(t, dtype=np.float64)
    for harmonic, amplitude in partials.items():
        partial_hz = frequency * float(harmonic)
        # Partials at or above Nyquist would alias to inharmonic frequencies
        # and contaminate a stimulus whose reference pitch is ``frequency``.
        if amplitude == 0 or partial_hz >= nyquist:
            continue
        waveform += amplitude * np.sin(2.0 * math.pi * partial_hz * t)

    envelope = profile.get("envelope")
    if callable(envelope):
        waveform *= envelope(t)

    waveform = _apply_fade(waveform.astype(np.float32), sr)
    signal_rms = float(np.sqrt(np.mean(waveform**2))) if waveform.size else 0.0

    snr_db = profile.get("snr_db")
    if snr_db is not None and signal_rms > 0:
        # Noise comes from the shared module-level generator, so the exact
        # samples depend on how many draws preceded this call.
        noise = rng.standard_normal(waveform.shape).astype(np.float32)
        noise_rms = float(np.sqrt(np.mean(noise**2)))
        if noise_rms > 0:
            target_noise_rms = signal_rms / (10.0 ** (snr_db / 20.0))
            noise *= target_noise_rms / noise_rms
            waveform = waveform + noise

    return _normalize(waveform), t.astype(np.float32)


def _ensure_mono(audio: np.ndarray) -> np.ndarray:
    """Collapse multi-channel audio to one channel by averaging axis 1.

    One-dimensional input is returned as the very same object.
    """
    return audio if audio.ndim == 1 else audio.mean(axis=1)


def load_waveform(path: Path, target_sr: int = TARGET_SAMPLE_RATE) -> tuple[np.ndarray, int] | None:
    """Read an audio file as mono ``float32`` at ``target_sr``.

    Args:
        path: Audio file to read with ``soundfile``.
        target_sr: Desired sample rate; the audio is resampled when the
            file's rate differs.

    Returns:
        ``(samples, sample_rate)``, or ``None`` (after printing a warning)
        when the file cannot be read or contains no samples.
    """
    try:
        samples, source_sr = sf.read(str(path), dtype="float32")
    except Exception as exc:
        # KeyboardInterrupt is not an Exception subclass, so it still
        # propagates to the caller untouched.
        print(f"Warning: skipping unreadable file '{path}': {exc}")
        return None

    samples = _ensure_mono(samples)
    if samples.size == 0:
        print(f"Warning: skipping empty file '{path}'.")
        return None

    if source_sr == target_sr:
        return samples.astype(np.float32), source_sr

    batch = torch.from_numpy(samples).unsqueeze(0)
    converted = torchaudio.functional.resample(batch, source_sr, target_sr)
    samples = converted.squeeze(0).cpu().numpy()
    return samples.astype(np.float32), target_sr

def waveform_to_mel(audio: np.ndarray) -> torch.Tensor:
    """Convert a waveform into the normalised log-mel input of the model.

    The samples are moved to ``DEVICE``, passed through the shared
    ``mel_transform``, log-compressed with a ``1e-5`` floor, and standardised
    with ``CONFIG["mel_mean"]`` / ``CONFIG["mel_std"]``.

    Returns:
        The spectrogram with the leading batch axis squeezed away.
    """
    batch = torch.from_numpy(audio).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        power = mel_transform(batch)
    log_mel = torch.log(power + 1e-5)
    standardized = (log_mel - CONFIG["mel_mean"]) / CONFIG["mel_std"]
    return standardized.squeeze(0)


def predict_f0(model: JDCNet, audio: np.ndarray) -> np.ndarray:
    """Run the pitch model over a waveform in overlapping mel chunks.

    The mel spectrogram is cut into windows of ``CONFIG["chunk_size"]`` frames
    that advance by ``chunk_size - CONFIG["chunk_overlap"]``. Each frame is
    emitted exactly once: inside an overlap the first half is taken from the
    earlier chunk and the second half from the later one, so no frame is
    duplicated and the output length equals the number of mel frames.

    Assumes the model returns one F0 value per input frame -- TODO confirm
    against the JDCNet implementation.

    Args:
        model: Pitch network; called as ``model(mel_chunk)`` and expected to
            return ``(f0, other)``.
        audio: Mono waveform at the mel front-end's sample rate.

    Returns:
        1-D array of per-frame F0 predictions (empty when there are no
        frames).
    """
    mel = waveform_to_mel(audio)
    total_frames = mel.shape[-1]
    chunk_size = max(int(CONFIG["chunk_size"]), 1)
    # Keep the overlap strictly smaller than a chunk so the window advances.
    overlap = min(max(int(CONFIG["chunk_overlap"]), 0), chunk_size - 1)
    step = chunk_size - overlap
    # Frames at a chunk's edges have the least context, so split each overlap
    # between the two chunks that share it.
    lead_trim = overlap // 2
    trail_trim = overlap - lead_trim

    preds: List[np.ndarray] = []
    covered = 0  # absolute index of the first frame not yet emitted
    for start in range(0, total_frames, step):
        end = min(start + chunk_size, total_frames)
        is_last = end >= total_frames
        mel_chunk = mel[:, start:end]
        pad = chunk_size - mel_chunk.shape[-1]
        if pad > 0:
            # Right-pad the final, shorter window to the fixed chunk length.
            mel_chunk = torch.nn.functional.pad(mel_chunk, (0, pad))
        # Add batch and channel axes and swap the last two axes before the
        # model call.
        mel_chunk = mel_chunk.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
        with torch.no_grad():
            f0_chunk, _ = model(mel_chunk)
        f0_chunk = f0_chunk.squeeze().detach().cpu().numpy()
        emit_end = end if is_last else end - trail_trim
        preds.append(f0_chunk[covered - start : emit_end - start])
        covered = emit_end
        if is_last:
            # Any further window would lie entirely inside frames already
            # emitted.
            break
    if preds:
        return np.concatenate(preds)
    return np.zeros((0,), dtype=np.float32)


def compute_metrics(reference: np.ndarray, prediction: np.ndarray) -> Dict[str, float]:
    """Score a predicted F0 track against a reference track.

    Both tracks are truncated to the shorter length. A reference frame is
    voiced when it is above 0; a predicted frame is voiced when it is above
    ``CONFIG["voicing_threshold_hz"]``.

    Returns:
        A dict with
        ``"RPA"`` -- share of reference-voiced frames within 50 cents;
        ``"RCA"`` -- the same using the octave-folded (chroma) distance;
        ``"VUV"`` -- share of all frames whose voicing decision matches;
        ``"OctaveError"`` -- share of reference-voiced frames that miss the
        50-cent tolerance but lie within 50 cents of a non-zero whole number
        of octaves from the reference.
        The three pitch metrics are NaN when the reference has no voiced
        frame.
    """
    n_frames = min(reference.shape[0], prediction.shape[0])
    ref = reference[:n_frames]
    pred = prediction[:n_frames]

    voiced_mask = ref > 0
    predicted_voiced = pred > CONFIG["voicing_threshold_hz"]
    n_voiced = int(np.count_nonzero(voiced_mask))
    agreement = np.count_nonzero(voiced_mask == predicted_voiced)

    scores: Dict[str, float] = {
        "RPA": float("nan"),
        "RCA": float("nan"),
        "VUV": float(agreement / max(n_frames, 1)),
        "OctaveError": float("nan"),
    }
    if n_voiced == 0:
        return scores

    ref_cents = hz_to_cents(ref[voiced_mask])
    # Floor the predictions so frames predicted as silence still map to a
    # finite cents value.
    pred_cents = hz_to_cents(np.clip(pred[voiced_mask], a_min=1e-5, a_max=None))
    deviation = pred_cents - ref_cents
    abs_deviation = np.abs(deviation)

    chroma_deviation = circular_cents_distance(pred_cents, ref_cents)
    nearest_octave = np.round(deviation / 1200.0)
    octave_residual = np.abs(deviation - nearest_octave * 1200.0)
    is_octave_error = (abs_deviation > 50.0) & (nearest_octave != 0) & (octave_residual <= 50.0)

    scores["RPA"] = float(np.count_nonzero(abs_deviation <= 50.0) / n_voiced)
    scores["RCA"] = float(np.count_nonzero(np.abs(chroma_deviation) <= 50.0) / n_voiced)
    scores["OctaveError"] = float(np.count_nonzero(is_octave_error) / n_voiced)
    return scores

In [None]:

def _collect_model_configuration(checkpoint_state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Gather model hyper-parameters from the training config and checkpoint.

    Sources are merged in increasing priority: ``TRAINING_CONFIG["model_params"]``,
    the checkpoint's ``"model_params"``, then the checkpoint's
    ``"config"["model_params"]``. Sources that are missing or not dicts are
    skipped.

    Args:
        checkpoint_state: Loaded checkpoint object.

    Returns:
        ``(top_level, sequence_config)``: the merged parameters without
        ``"sequence_model"`` and ``"num_class"``, and a copy of the
        ``"sequence_model"`` section (``{}`` when absent or not a dict).
    """
    layers: List[Any] = [TRAINING_CONFIG.get("model_params")]
    if isinstance(checkpoint_state, dict):
        layers.append(checkpoint_state.get("model_params"))
        embedded_config = checkpoint_state.get("config")
        if isinstance(embedded_config, dict):
            layers.append(embedded_config.get("model_params"))

    merged: Dict[str, Any] = {}
    for layer in layers:
        if isinstance(layer, dict):
            merged = _deep_merge_dict(merged, layer)

    sequence_section = merged.get("sequence_model")
    sequence_config = sequence_section.copy() if isinstance(sequence_section, dict) else {}
    excluded = ("sequence_model", "num_class")
    top_level = {key: value for key, value in merged.items() if key not in excluded}
    return top_level, sequence_config


def load_model(checkpoint_path: Optional[Path] = None) -> JDCNet:
    """Build a ``JDCNet`` and load weights from a checkpoint file.

    The number of output classes is taken from the shape of
    ``classifier.weight`` when that tensor is present, otherwise from the
    checkpoint's or ``CONFIG``'s ``"num_class"`` (default 722). Optional
    constructor arguments come from ``_collect_model_configuration``.

    Args:
        checkpoint_path: File to load; defaults to ``CONFIG["checkpoint_path"]``.

    Returns:
        The model on ``DEVICE`` in eval mode.

    Raises:
        FileNotFoundError: No checkpoint was specified, or the file does not
            exist.
        RuntimeError: The checkpoint is not a dict or has no usable state.
    """
    resolved = checkpoint_path or CONFIG.get("checkpoint_path")
    if not resolved:
        # Fail with an actionable message instead of a TypeError from Path(None).
        raise FileNotFoundError(
            "No checkpoint specified; pass checkpoint_path or set CONFIG['checkpoint_path']."
        )
    checkpoint_path = Path(resolved)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    if not isinstance(checkpoint, dict):
        raise RuntimeError("Unexpected checkpoint format")
    # Accept {"model": ...}, {"state_dict": ...}, or a bare state dict.
    model_state = checkpoint.get("model")
    if model_state is None:
        model_state = checkpoint.get("state_dict", checkpoint)
    if not isinstance(model_state, dict):
        raise RuntimeError("Checkpoint is missing a valid model state")
    model_state = dict(model_state)
    classifier_weight = model_state.get("classifier.weight")
    if classifier_weight is not None:
        inferred_classes = int(classifier_weight.shape[0])
    else:
        inferred_classes = int(checkpoint.get("num_class", CONFIG.get("num_class", 722)))
    if inferred_classes <= 0:
        inferred_classes = 722
    model_params, sequence_model_config = _collect_model_configuration(checkpoint)
    model_kwargs: Dict[str, Any] = {}
    if sequence_model_config:
        model_kwargs["sequence_model_config"] = sequence_model_config
    leaky_relu_slope = model_params.get("leaky_relu_slope")
    if isinstance(leaky_relu_slope, (int, float)):
        model_kwargs["leaky_relu_slope"] = float(leaky_relu_slope)
    print(f"Instantiating JDCNet with {inferred_classes} classes")
    model = JDCNet(num_class=inferred_classes, **model_kwargs)
    # Non-strict loading tolerates architecture drift; mismatches are
    # reported below rather than raised.
    incompatible = model.load_state_dict(model_state, strict=False)
    if incompatible.unexpected_keys:
        skipped = ", ".join(sorted(incompatible.unexpected_keys))
        print(f"Warning: skipped unexpected parameters: {skipped}")
    if incompatible.missing_keys:
        missing = ", ".join(sorted(incompatible.missing_keys))
        print(f"Warning: missing parameters not found in checkpoint: {missing}")
    model.to(DEVICE).eval()
    print(f"Loaded checkpoint from {checkpoint_path}")
    return model


In [None]:
# Load the checkpoint named by CONFIG["checkpoint_path"]; this raises if no
# checkpoint could be found or selected.
model = load_model()


In [None]:
def _frequency_grid(min_hz: float, max_hz: float, count: int) -> np.ndarray:
    """Return ``count`` linearly spaced frequencies covering ``[min_hz, max_hz]``.

    Both end points are included. For ``count <= 1`` a single value, the
    midpoint of the range, is returned instead.
    """
    if count > 1:
        return np.linspace(min_hz, max_hz, count, dtype=np.float64)
    midpoint = (min_hz + max_hz) / 2.0
    return np.array([midpoint], dtype=np.float64)


# One row per (range, frequency, timbre) combination.
RESULTS: List[Dict[str, Any]] = []

_stimulus = CONFIG["stimulus"]
_edge_fraction = _stimulus["edge_band_fraction"]
_tone_seconds = _stimulus["duration_seconds"]

# The nesting order (range -> frequency -> timbre) also fixes the order in
# which the shared noise generator is consumed.
for range_info in tqdm(CONFIG["ranges"], desc="Ranges"):
    name = range_info["name"]
    min_hz, max_hz = float(range_info["min_hz"]), float(range_info["max_hz"])
    frequencies = _frequency_grid(min_hz, max_hz, _stimulus["frequencies_per_range"])
    # Frequencies at or beyond these cut-offs are tagged as range edges.
    low_cut = min_hz + (max_hz - min_hz) * _edge_fraction
    high_cut = max_hz - (max_hz - min_hz) * _edge_fraction

    for frequency in tqdm(frequencies, desc=f"Frequencies: {name}", leave=False):
        target_hz = float(frequency)
        # The edge label depends only on the frequency, not on the timbre.
        if frequency <= low_cut:
            edge = "low"
        elif frequency >= high_cut:
            edge = "high"
        else:
            edge = "mid"

        for timbre_name, profile in CONFIG["timbre_profiles"].items():
            audio, time_axis = synthesize_timbre_waveform(target_hz, TARGET_SAMPLE_RATE, _tone_seconds, profile)
            prediction = predict_f0(model, audio)
            # Constant-pitch reference, resampled to the prediction length.
            reference_curve = np.full(time_axis.shape[0], target_hz, dtype=np.float32)
            reference = sample_reference_f0(time_axis, reference_curve, prediction.shape[0])
            row: Dict[str, Any] = {
                "range": name,
                "frequency_hz": target_hz,
                "timbre": timbre_name,
                "edge_region": edge,
            }
            row.update(compute_metrics(reference, prediction))
            RESULTS.append(row)

results_df = pd.DataFrame(RESULTS)
results_df.sort_values(["range", "frequency_hz", "timbre"], inplace=True)
results_df.head()


In [None]:
# Two stacked panels sharing the frequency axis: accuracy on top, octave
# confusions below.
fig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
rpa_axis, octave_axis = axes

for (range_name, timbre_name), group in results_df.groupby(["range", "timbre"]):
    series_label = f"{range_name} - {timbre_name}"
    for axis, column in ((rpa_axis, "RPA"), (octave_axis, "OctaveError")):
        axis.plot(group["frequency_hz"], group[column], marker="o", label=series_label)

rpa_axis.set(
    ylabel="RPA",
    title="Raw Pitch Accuracy vs. Frequency",
    ylim=(0.0, 1.05),
)
octave_axis.set(
    ylabel="Octave Error Rate",
    xlabel="Fundamental Frequency (Hz)",
    title="Octave Confusions vs. Frequency",
    ylim=(0.0, 1.0),
)

# Both panels carry the same series, so one figure-level legend is enough.
handles, labels = rpa_axis.get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", ncol=2, bbox_to_anchor=(0.5, 1.02))
plt.tight_layout(rect=(0, 0, 1, 0.98))
plt.show()


In [None]:
# Average every metric over all frequencies and timbres within each range.
_metric_columns = ["RPA", "RCA", "VUV", "OctaveError"]
_range_groups = results_df.groupby(["range"])
per_range_metrics = _range_groups.agg({column: "mean" for column in _metric_columns})
per_range_metrics = per_range_metrics.rename(
    columns={column: f"{column}_mean" for column in _metric_columns}
)
per_range_metrics


In [None]:
# Restrict to stimuli tagged as range edges, then summarise octave errors
# (mean and sample count) and mean RPA per range and edge.
_is_edge_row = results_df["edge_region"].isin(["low", "high"])
_edge_groups = results_df[_is_edge_row].groupby(["range", "edge_region"])
extreme_summary = _edge_groups.agg({"OctaveError": ["mean", "count"], "RPA": "mean"})
extreme_summary


In [None]:
# Persist the full per-stimulus table for later analysis.
output_path = Path(CONFIG["output_dir"]).joinpath("pitch_range_timbre_metrics.csv")
results_df.to_csv(output_path, index=False)
print(f"Saved detailed metrics to {output_path.resolve()}")
