# Dynamic Pitch Behavior Evaluation


This notebook evaluates how the most recent pitch-extraction (JDCNet) checkpoint responds to dynamic pitch behaviors such as vibrato and portamento (glide) transitions. Synthetic stimuli are generated with controlled parameters — vibrato rate and depth, and glide duration — so that we can quantify accuracy, tracking latency, and overshoot across a range of settings.


In [None]:
# !pip install soundfile torchaudio torch pyyaml matplotlib librosa pyworld torchcrepe praat-parselmouth pandas tqdm


In [None]:
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torchaudio
import soundfile as sf
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Assumes the notebook is run from a direct subdirectory of the repository
# (e.g. notebooks/), so the repository root is the working directory's parent.
REPO_ROOT = Path.cwd().resolve().parents[0]
# Make repo-local modules (meldataset, model, Utils) importable.
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from meldataset import DEFAULT_MEL_PARAMS
from model import JDCNet

from Utils.dynamic_pitch_tools import (
    generate_vibrato_waveform,
    generate_glide_waveform,
    sample_reference_f0,
    hz_to_cents,
    circular_cents_distance,
    rms_cents_error,
    estimate_tracking_delay_ms,
    compute_overshoot_cents,
)


# Shared plot styling for every figure in this notebook.
# NOTE(review): "seaborn-v0_8" is the style name used by matplotlib >= 3.6;
# older releases only know "seaborn" — confirm the installed version.
plt.style.use("seaborn-v0_8")
plt.rcParams.update({"figure.figsize": (12, 4), "axes.grid": True})

# Run mel extraction and model inference on the GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

In [None]:
# Evaluation settings. Filesystem paths are built from REPO_ROOT.
CONFIG = {
    "config_path": REPO_ROOT / "Configs" / "config.yml",
    "checkpoint_dir": REPO_ROOT / "Checkpoint",
    # None -> _load_training_config() picks the newest *.pth in checkpoint_dir.
    "checkpoint_path": None,
    # Inference window length and overlap, in mel frames (used by predict_f0).
    "chunk_size": 192,
    "chunk_overlap": 48,
    # Log-mel normalization: (log(mel + 1e-5) - mel_mean) / mel_std.
    "mel_mean": -4.0,
    "mel_std": 4.0,
    # A predicted F0 above this value counts as voiced in compute_metrics.
    "voicing_threshold_hz": 10.0,
    "output_dir": REPO_ROOT / "notebooks" / "artifacts",
    # Vibrato sweep: every rate is evaluated with every depth.
    "vibrato": {
        "base_frequency_hz": 220.0,
        "duration_seconds": 3.0,
        "rates_hz": [4.0, 6.0, 8.0],
        "depth_cents": [20, 60, 120, 200],
    },
    # Glide sweep: one start_hz/end_hz glide is generated per duration.
    "glide": {
        "start_hz": 60.0,
        "end_hz": 500.0,
        "durations_seconds": [0.4, 0.8, 1.6, 3.2],
    },
}

# Make sure the CSV export location exists up front.
CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
CONFIG


In [None]:
# Working copy of the default mel settings, so the in-place updates made by
# _load_training_config() do not touch the imported defaults.
# NOTE(review): .copy() is shallow if DEFAULT_MEL_PARAMS is a plain dict.
MEL_PARAMS = DEFAULT_MEL_PARAMS.copy()

def _normalize_mel_params(params: Dict[str, Any]) -> Dict[str, Any]:
    if "win_len" in params:
        win_len = params.pop("win_len")
        params.setdefault("win_length", win_len)
    return params


def _deep_merge_dict(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    merged = base.copy()
    for key, value in overrides.items():
        if isinstance(value, dict):
            existing = merged.get(key)
            if isinstance(existing, dict):
                merged[key] = _deep_merge_dict(existing, value)
            else:
                merged[key] = value.copy()
        else:
            merged[key] = value
    return merged

# Convert any legacy "win_len" key in the defaults to "win_length" before
# training-config overrides are merged in.
MEL_PARAMS = _normalize_mel_params(MEL_PARAMS)



# NOTE(review): not referenced anywhere in this notebook; presumably kept for
# filtering real recordings passed to load_waveform — confirm or remove.
AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"}


def _resolve_relative_path(base: Path, candidate: str | Path) -> Path:
    base_dir = base if base.is_dir() else base.parent
    candidate_path = Path(candidate)
    if candidate_path.is_absolute():
        return candidate_path
    repo_candidate = (REPO_ROOT / candidate_path).resolve()
    config_candidate = (base_dir / candidate_path).resolve()
    if repo_candidate.exists():
        return repo_candidate
    if config_candidate.exists():
        return config_candidate
    return repo_candidate


def _latest_checkpoint(path: Path) -> Optional[Path]:
    if not path.is_dir():
        return None

    def _sort_key(p: Path) -> tuple[int, float]:
        numbers = [int(match) for match in re.findall(r"\d+", p.stem)]
        last = numbers[-1] if numbers else -1
        return last, p.stat().st_mtime

    candidates = sorted(path.glob("*.pth"), key=_sort_key)
    return candidates[-1] if candidates else None


def _load_training_config() -> Dict[str, Any]:
    """Load the training YAML config and fold its settings into module globals.

    When ``CONFIG["config_path"]`` points to an existing file, it is parsed
    with ``yaml.safe_load`` and, as side effects:

    * ``dataset_params.mel_params`` is merged into ``MEL_PARAMS``. The
      overrides are normalized first, so a config ``win_len`` overrides the
      default ``win_length``.
    * ``dataset_params.sr`` (if given) becomes ``MEL_PARAMS["sample_rate"]``.
    * ``val_data`` sets ``CONFIG["eval_list_path"]`` when that is not set.
    * ``log_dir`` replaces ``CONFIG["checkpoint_dir"]`` when the configured
      directory is missing.

    Whether or not the config file exists, ``CONFIG["checkpoint_path"]`` is
    auto-filled with the newest checkpoint in ``CONFIG["checkpoint_dir"]``
    when it is unset.

    Returns:
        The parsed YAML mapping, or ``{}`` when the file is missing or does not
        contain a mapping.
    """
    data: Dict[str, Any] = {}
    raw_config_path = CONFIG.get("config_path")
    # Accept str as well as Path: _resolve_relative_path() needs a Path.
    config_path = Path(raw_config_path) if raw_config_path else None

    if config_path is None or not config_path.is_file():
        print("Warning: config file not found; using DEFAULT_MEL_PARAMS.")
    else:
        import yaml

        with open(config_path, "r", encoding="utf-8") as handle:
            loaded = yaml.safe_load(handle)
        data = loaded if isinstance(loaded, dict) else {}

        # `or {}` also covers keys that are present but null in the YAML.
        dataset_params = data.get("dataset_params") or {}
        # Normalize the overrides *before* merging: MEL_PARAMS is already
        # normalized, so merging a raw "win_len" and normalizing afterwards
        # would let setdefault() keep the stale default "win_length".
        mel_overrides = dict(dataset_params.get("mel_params") or {})
        MEL_PARAMS.update(_normalize_mel_params(mel_overrides))
        _normalize_mel_params(MEL_PARAMS)

        dataset_sr = dataset_params.get("sr")
        if dataset_sr is not None:
            MEL_PARAMS["sample_rate"] = dataset_sr

        val_list = data.get("val_data")
        if val_list and CONFIG.get("eval_list_path") is None:
            CONFIG["eval_list_path"] = _resolve_relative_path(config_path, val_list)
            if not CONFIG["eval_list_path"].is_file():
                print(f"Warning: evaluation list not found at {CONFIG['eval_list_path']}")

        log_dir = data.get("log_dir")
        if log_dir and (not CONFIG.get("checkpoint_dir") or not Path(CONFIG["checkpoint_dir"]).is_dir()):
            CONFIG["checkpoint_dir"] = _resolve_relative_path(config_path, log_dir)

    # Checkpoint discovery must not depend on the config file being present;
    # otherwise load_model() is left with checkpoint_path=None.
    if CONFIG.get("checkpoint_path") is None:
        checkpoint_dir = CONFIG.get("checkpoint_dir")
        latest = _latest_checkpoint(Path(checkpoint_dir)) if checkpoint_dir else None
        if latest is not None:
            CONFIG["checkpoint_path"] = latest
        else:
            print("Warning: no checkpoints found; set CONFIG['checkpoint_path'] manually.")

    return data


# `re` is used by _latest_checkpoint(), which runs during the call below.
import re

# Parse the training config; this mutates MEL_PARAMS and CONFIG in place.
TRAINING_CONFIG = _load_training_config()
# NOTE(review): KeyError here unless DEFAULT_MEL_PARAMS or dataset_params.sr
# supplies "sample_rate".
TARGET_SAMPLE_RATE = MEL_PARAMS["sample_rate"]
HOP_LENGTH = MEL_PARAMS["hop_length"]
# Duration of one mel frame in milliseconds (hop / sample rate).
FRAME_PERIOD_MS = HOP_LENGTH * 1000.0 / TARGET_SAMPLE_RATE
mel_transform = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS).to(DEVICE)



def _collect_model_configuration(checkpoint_state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Merge model hyper-parameters from every known source.

    Sources are layered with increasing precedence:
    ``TRAINING_CONFIG["model_params"]``, then ``checkpoint_state["model_params"]``,
    then ``checkpoint_state["config"]["model_params"]``. Non-dict sources are
    ignored.

    Returns:
        ``(top_level, sequence_config)``. ``top_level`` holds all merged keys
        except ``"sequence_model"`` and ``"num_class"``. ``sequence_config`` is
        a copy of the ``"sequence_model"`` sub-dict, or ``{}``.
    """
    layers: List[Any] = [TRAINING_CONFIG.get("model_params")]
    if isinstance(checkpoint_state, dict):
        layers.append(checkpoint_state.get("model_params"))
        config_section = checkpoint_state.get("config")
        if isinstance(config_section, dict):
            layers.append(config_section.get("model_params"))

    merged: Dict[str, Any] = {}
    for layer in layers:
        if isinstance(layer, dict):
            merged = _deep_merge_dict(merged, layer)

    raw_sequence = merged.get("sequence_model")
    sequence_config = raw_sequence.copy() if isinstance(raw_sequence, dict) else {}
    top_level = {
        key: value
        for key, value in merged.items()
        if key not in ("sequence_model", "num_class")
    }
    return top_level, sequence_config

def _ensure_mono(audio: np.ndarray) -> np.ndarray:
    if audio.ndim == 1:
        return audio
    return audio.mean(axis=1)


def load_waveform(path: Path, target_sr: int = TARGET_SAMPLE_RATE) -> tuple[np.ndarray, int] | None:
    """Read an audio file as mono float32, resampled to *target_sr*.

    Returns ``(samples, sample_rate)``, or ``None`` (after printing a warning)
    when the file cannot be read or contains no samples.
    """
    try:
        samples, rate = sf.read(str(path), dtype="float32")
    except Exception as exc:  # KeyboardInterrupt is not an Exception, so it still propagates.
        print(f"Warning: skipping unreadable file '{path}': {exc}")
        return None

    samples = _ensure_mono(samples)
    if not samples.size:
        print(f"Warning: skipping empty file '{path}'.")
        return None

    if rate == target_sr:
        return samples.astype(np.float32), rate

    batch = torch.from_numpy(samples).unsqueeze(0)
    resampled = torchaudio.functional.resample(batch, rate, target_sr)
    return resampled.squeeze(0).cpu().numpy().astype(np.float32), target_sr

def waveform_to_mel(audio: np.ndarray) -> torch.Tensor:
    """Convert a waveform to a normalized log-mel spectrogram on ``DEVICE``.

    Computes ``(log(mel + 1e-5) - CONFIG["mel_mean"]) / CONFIG["mel_std"]``
    and drops the leading batch axis. For a 1-D input the result is
    (n_mels, frames).
    """
    batch = torch.from_numpy(audio).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        log_mel = torch.log(mel_transform(batch) + 1e-5)
        normalized = (log_mel - CONFIG["mel_mean"]) / CONFIG["mel_std"]
    return normalized.squeeze(0)


def predict_f0(model: JDCNet, audio: np.ndarray) -> np.ndarray:
    """Run chunked JDCNet inference and return one F0 value per mel frame.

    The normalized mel spectrogram is processed in windows of
    ``CONFIG["chunk_size"]`` frames. Window starts advance by
    ``chunk_size - chunk_overlap`` frames, and the last window is right-padded
    with zeros to ``chunk_size``.

    Overlapping windows are stitched, not concatenated. Each window contributes
    only its central region: half of the overlap is trimmed at each interior
    edge. The output therefore has exactly one value per mel frame and stays
    time-aligned with the spectrogram.

    Args:
        model: JDCNet instance (expected in eval mode on ``DEVICE``).
        audio: waveform passed to ``waveform_to_mel``.

    Returns:
        float32 array of length ``mel.shape[-1]`` (empty when there are no frames).
    """
    mel = waveform_to_mel(audio)
    total_frames = int(mel.shape[-1])
    chunk_size = int(CONFIG["chunk_size"])
    # Clamp so that each window still advances by at least one frame.
    overlap = min(max(int(CONFIG["chunk_overlap"]), 0), chunk_size - 1)
    step = chunk_size - overlap
    lead_trim = overlap // 2
    tail_trim = overlap - lead_trim

    f0 = np.zeros((total_frames,), dtype=np.float32)
    for start in range(0, total_frames, step):
        end = min(start + chunk_size, total_frames)
        mel_chunk = mel[:, start:end]
        pad = chunk_size - mel_chunk.shape[-1]
        if pad > 0:
            mel_chunk = torch.nn.functional.pad(mel_chunk, (0, pad))
        # (n_mels, frames) -> (1, 1, frames, n_mels), the layout fed to the model.
        mel_chunk = mel_chunk.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
        with torch.no_grad():
            f0_chunk, _ = model(mel_chunk)
        f0_chunk = f0_chunk.detach().cpu().numpy().reshape(-1)

        is_last = end >= total_frames
        # Keep only this window's central span so that consecutive windows
        # tile the timeline exactly once (no duplicated overlap frames).
        keep_from = start if start == 0 else start + lead_trim
        keep_to = end if is_last else start + chunk_size - tail_trim
        f0[keep_from:keep_to] = f0_chunk[keep_from - start : keep_to - start]
        if is_last:
            break
    return f0

def compute_metrics(reference: np.ndarray, prediction: np.ndarray) -> Dict[str, float]:
    """Score a predicted F0 track against a reference F0 track.

    Both tracks are truncated to the shorter length. A reference frame is
    voiced when its F0 is > 0. A predicted frame is voiced when its F0 exceeds
    ``CONFIG["voicing_threshold_hz"]``.

    Returns a dict with:
        RPA: share of reference-voiced frames predicted within 50 cents.
        RCA: like RPA, but using ``circular_cents_distance`` (presumably the
            octave-wrapped distance).
        VUV: share of *all* frames whose voiced/unvoiced decision matches.
        OctaveError: share of reference-voiced frames that miss by a nonzero
            whole number of octaves (within 50 cents).
    RPA, RCA and OctaveError are NaN when the reference has no voiced frames.
    """
    n_frames = min(reference.shape[0], prediction.shape[0])
    ref_track = reference[:n_frames]
    est_track = prediction[:n_frames]
    ref_voiced = ref_track > 0
    est_voiced = est_track > CONFIG["voicing_threshold_hz"]

    n_voiced = int(np.count_nonzero(ref_voiced))
    vuv_accuracy = float(np.count_nonzero(ref_voiced == est_voiced) / max(n_frames, 1))
    if not n_voiced:
        return {
            "RPA": float("nan"),
            "RCA": float("nan"),
            "VUV": vuv_accuracy,
            "OctaveError": float("nan"),
        }

    ref_cents = hz_to_cents(ref_track[ref_voiced])
    # Clip zero/negative predictions to a tiny positive value before the cents
    # conversion (hz_to_cents presumably takes a logarithm).
    est_cents = hz_to_cents(np.clip(est_track[ref_voiced], a_min=1e-5, a_max=None))
    cents_error = est_cents - ref_cents
    chroma_error = circular_cents_distance(est_cents, ref_cents)
    octaves_off = np.round(cents_error / 1200.0)
    # Misses whose residual after removing the nearest nonzero octave is <= 50 cents.
    octave_miss = (
        (np.abs(cents_error) > 50.0)
        & (octaves_off != 0)
        & (np.abs(cents_error - octaves_off * 1200.0) <= 50.0)
    )

    def share(mask: np.ndarray) -> float:
        return float(np.count_nonzero(mask) / n_voiced)

    return {
        "RPA": share(np.abs(cents_error) <= 50.0),
        "RCA": share(np.abs(chroma_error) <= 50.0),
        "VUV": vuv_accuracy,
        "OctaveError": share(octave_miss),
    }



def load_model(checkpoint_path: Optional[Path] = None) -> JDCNet:
    """Build a JDCNet and load weights from a checkpoint file.

    The class count comes from ``classifier.weight`` when present. Otherwise it
    falls back to ``checkpoint["num_class"]``, then ``CONFIG["num_class"]``,
    then 722. Extra constructor arguments (``sequence_model_config``,
    ``leaky_relu_slope``) come from ``_collect_model_configuration``. Weights
    are loaded non-strictly, and mismatched keys are reported as warnings.

    Args:
        checkpoint_path: checkpoint to load; defaults to ``CONFIG["checkpoint_path"]``.

    Returns:
        The model moved to ``DEVICE`` and switched to eval mode.

    Raises:
        FileNotFoundError: no checkpoint is configured, or the file does not exist.
        RuntimeError: the checkpoint is not a dict or lacks a usable state dict.
    """
    resolved = checkpoint_path or CONFIG.get("checkpoint_path")
    if not resolved:
        # Fail with an actionable message instead of Path(None)'s TypeError.
        raise FileNotFoundError(
            "No checkpoint configured: pass checkpoint_path or set CONFIG['checkpoint_path']."
        )
    checkpoint_path = Path(resolved)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # NOTE: torch>=2.6 defaults torch.load to weights_only=True; checkpoints
    # holding arbitrary pickled objects would then need weights_only=False
    # (only for trusted files).
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    if not isinstance(checkpoint, dict):
        raise RuntimeError("Unexpected checkpoint format")
    model_state = checkpoint.get("model")
    if model_state is None:
        model_state = checkpoint.get("state_dict", checkpoint)
    if not isinstance(model_state, dict):
        raise RuntimeError("Checkpoint is missing a valid model state")
    model_state = dict(model_state)
    # Checkpoints saved through a DataParallel/DistributedDataParallel wrapper
    # prefix every key with "module."; strip it so keys (including
    # "classifier.weight" used below) match the bare JDCNet.
    if model_state and all(isinstance(key, str) and key.startswith("module.") for key in model_state):
        model_state = {key[len("module."):]: value for key, value in model_state.items()}
    classifier_weight = model_state.get("classifier.weight")
    if classifier_weight is not None:
        inferred_classes = int(classifier_weight.shape[0])
    else:
        inferred_classes = int(checkpoint.get("num_class", CONFIG.get("num_class", 722)))
    if inferred_classes <= 0:
        inferred_classes = 722
    model_params, sequence_model_config = _collect_model_configuration(checkpoint)
    model_kwargs: Dict[str, Any] = {}
    if sequence_model_config:
        model_kwargs["sequence_model_config"] = sequence_model_config
    leaky_relu_slope = model_params.get("leaky_relu_slope")
    if isinstance(leaky_relu_slope, (int, float)):
        model_kwargs["leaky_relu_slope"] = float(leaky_relu_slope)
    print(f"Instantiating JDCNet with {inferred_classes} classes")
    model = JDCNet(num_class=inferred_classes, **model_kwargs)
    incompatible = model.load_state_dict(model_state, strict=False)
    if incompatible.unexpected_keys:
        skipped = ", ".join(sorted(incompatible.unexpected_keys))
        print(f"Warning: skipped unexpected parameters: {skipped}")
    if incompatible.missing_keys:
        missing = ", ".join(sorted(incompatible.missing_keys))
        print(f"Warning: missing parameters not found in checkpoint: {missing}")
    model.to(DEVICE).eval()
    print(f"Loaded checkpoint from {checkpoint_path}")
    return model

In [None]:
# Loads CONFIG["checkpoint_path"], which _load_training_config() fills with the
# newest checkpoint when it was left as None.
model = load_model()


In [None]:
vibrato_cfg = CONFIG["vibrato"]
base_freq = float(vibrato_cfg["base_frequency_hz"])
duration = float(vibrato_cfg["duration_seconds"])
rates = list(map(float, vibrato_cfg["rates_hz"]))
depths = list(map(float, vibrato_cfg["depth_cents"]))

VIBRATO_RESULTS: List[Dict[str, float]] = []

# Full grid: every vibrato rate against every depth (rate-major order).
for rate in rates:
    for depth in depths:
        audio, t, f0_curve = generate_vibrato_waveform(rate, depth, base_freq, duration, TARGET_SAMPLE_RATE)
        prediction = predict_f0(model, audio)
        # Presumably resamples the ground-truth F0 curve to one value per predicted frame.
        reference = sample_reference_f0(t, f0_curve, prediction.shape[0])
        metrics = compute_metrics(reference, prediction)
        rmse = rms_cents_error(reference, prediction)
        # metrics contributes the RPA, RCA, VUV and OctaveError columns, in that order.
        VIBRATO_RESULTS.append(
            {"rate_hz": rate, "depth_cents": depth, **metrics, "RMSE_cents": rmse}
        )

vibrato_df = pd.DataFrame(VIBRATO_RESULTS).sort_values(["rate_hz", "depth_cents"])
vibrato_df


In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharex=True)
rpa_ax, rmse_ax = axes

# One line per vibrato rate on each panel.
for rate in rates:
    subset = vibrato_df[vibrato_df["rate_hz"] == rate]
    for ax, column in ((rpa_ax, "RPA"), (rmse_ax, "RMSE_cents")):
        ax.plot(subset["depth_cents"], subset[column], marker="o", label=f"{rate:.1f} Hz")

for ax, title, ylabel in (
    (rpa_ax, "RPA vs. Vibrato Depth", "RPA"),
    (rmse_ax, "RMSE (cents) vs. Vibrato Depth", "RMSE (cents)"),
):
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Vibrato depth (cents)")
    ax.set_xticks(depths)
    ax.grid(True)

axes[-1].legend(title="Rate")
plt.tight_layout()
plt.show()


In [None]:
glide_cfg = CONFIG["glide"]
start_hz = float(glide_cfg["start_hz"])
end_hz = float(glide_cfg["end_hz"])
durations = list(map(float, glide_cfg["durations_seconds"]))

GLIDE_RESULTS: List[Dict[str, float]] = []

# One synthetic glide per configured duration.
for duration in durations:
    audio, t, f0_curve = generate_glide_waveform(duration, start_hz, end_hz, TARGET_SAMPLE_RATE)
    prediction = predict_f0(model, audio)
    # Presumably resamples the ground-truth F0 curve to one value per predicted frame.
    reference = sample_reference_f0(t, f0_curve, prediction.shape[0])
    metrics = compute_metrics(reference, prediction)
    rmse = rms_cents_error(reference, prediction)
    lag_ms = estimate_tracking_delay_ms(reference, prediction, FRAME_PERIOD_MS)
    overshoot = compute_overshoot_cents(reference, prediction)
    # Cents error on the final frame; NaN when there is no prediction or the
    # reference's last frame is unvoiced.
    if prediction.size and reference[-1] > 0:
        final_error = float(1200.0 * np.log2(max(prediction[-1], 1e-5) / max(reference[-1], 1e-5)))
    else:
        final_error = float("nan")
    GLIDE_RESULTS.append(
        {
            "duration_s": duration,
            **metrics,
            "RMSE_cents": rmse,
            "Lag_ms": lag_ms,
            "Overshoot_cents": overshoot,
            "Final_error_cents": final_error,
        }
    )

glide_df = pd.DataFrame(GLIDE_RESULTS).sort_values("duration_s")
glide_df


In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (metric column, panel title, y-axis label) for each panel, left to right.
panels = (
    ("Lag_ms", "Tracking Delay vs. Glide Duration", "Lag (ms)"),
    ("Overshoot_cents", "Overshoot vs. Glide Duration", "Overshoot (cents)"),
    ("RMSE_cents", "RMSE vs. Glide Duration", "RMSE (cents)"),
)
for ax, (column, title, ylabel) in zip(axes, panels):
    ax.plot(glide_df["duration_s"], glide_df[column], marker="o")
    ax.set_title(title)
    ax.set_xlabel("Duration (s)")
    ax.set_ylabel(ylabel)
    ax.grid(True)

plt.tight_layout()
plt.show()


In [None]:
vibrato_path = CONFIG["output_dir"] / "dynamic_pitch_vibrato_metrics.csv"
glide_path = CONFIG["output_dir"] / "dynamic_pitch_glide_metrics.csv"

# Write each results table to CSV and report the absolute location.
for frame, csv_path, message in (
    (vibrato_df, vibrato_path, "Saved vibrato metrics to"),
    (glide_df, glide_path, "Saved glide metrics to"),
):
    frame.to_csv(csv_path, index=False)
    print(f"{message} {csv_path.resolve()}")
