# Amplitude Pathology Sweep

This notebook evaluates how a trained pitch extraction model handles amplitude-related pathologies. We synthesise clean reference tones, apply sample clipping and automatic gain control (AGC) pumping at increasing severity levels, and measure the resulting degradation via voiced/unvoiced (V/UV) flips relative to the clean prediction and raw chroma accuracy (RCA).


## Environment Setup

Uncomment and run the cell below if you need to install the required dependencies inside your runtime environment.


In [None]:
# !pip install soundfile torchaudio torch pyyaml matplotlib librosa pyworld torchcrepe praat-parselmouth pandas tqdm


## Imports and Global Configuration

The repository root (taken to be the parent of the current working directory) is added to `sys.path` so that we can reuse the training utilities when running the evaluation.


In [None]:
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torchaudio
import soundfile as sf
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# The repository root is taken to be the parent of the current working
# directory, i.e. the notebook is expected to be launched from a direct
# sub-directory of the repository.
REPO_ROOT = Path.cwd().resolve().parents[0]
# Make the project modules imported below resolvable; the membership check
# avoids appending the same entry again when the cell is re-run.
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from meldataset import DEFAULT_MEL_PARAMS
from model import JDCNet
from Utils.dynamic_pitch_tools import (
    synthesize_from_f0_curve,
    sample_reference_f0,
    hz_to_cents,
    circular_cents_distance,
)

# Global plotting defaults shared by every figure in this notebook.
plt.style.use("seaborn-v0_8")
plt.rcParams.update({"figure.figsize": (12, 4), "axes.grid": True})

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")


## User Configuration

Update the configuration below to point at your training config, checkpoint directory, and evaluation preferences.


In [None]:
# Central, mutable settings dictionary. `_load_training_config` may fill in
# `checkpoint_dir`, `checkpoint_path` and `eval_list_path` from the training
# YAML config.
CONFIG = {
    # Training YAML config read by `_load_training_config`.
    "config_path": REPO_ROOT / "Configs" / "config.yml",
    # Directory searched for `*.pth` files when no explicit checkpoint is set.
    "checkpoint_dir": REPO_ROOT / "Checkpoint",
    # Explicit checkpoint file; None lets `_load_training_config` pick one.
    "checkpoint_path": None,
    # Number of mel frames per inference chunk and the frame overlap between
    # consecutive chunks (see `predict_f0`).
    "chunk_size": 192,
    "chunk_overlap": 48,
    # Normalisation applied to the log-mel spectrogram in `waveform_to_mel`.
    "mel_mean": -4.0,
    "mel_std": 4.0,
    # Predictions above this value are treated as voiced.
    "voicing_threshold_hz": 10.0,
    # Created below; location for artefacts produced by this notebook.
    "output_dir": REPO_ROOT / "notebooks" / "artifacts" / "amplitude_pathologies",
    # Constant-pitch reference tones built by `generate_reference_stimuli`.
    "stimulus": {
        "frequencies_hz": [110.0, 220.0, 440.0],
        "duration_seconds": 3.0,
        "amplitude": 0.85,
    },
    # Severity levels handed to `apply_sample_clipping`.
    "clipping": {
        "levels_percent": [0, 2, 4, 6, 8, 10],
    },
    # Severity levels and target level handed to `apply_agc_pumping`.
    "agc": {
        "levels_db": [0, 2, 4, 6, 8, 10],
        "target_rms": 0.15,
    },
}

# Make sure the output directory exists before anything is written to it.
CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
# Bare expression: shows the resolved configuration as the cell output.
CONFIG


## Helper Functions

Utility functions for loading checkpoints, running inference, and computing melody extraction metrics.


In [None]:
# Work on a copy so the project-wide defaults are not mutated by the updates
# applied later in this notebook.
MEL_PARAMS = DEFAULT_MEL_PARAMS.copy()


def _normalize_mel_params(params: Dict[str, Any]) -> Dict[str, Any]:
    if "win_len" in params:
        win_len = params.pop("win_len")
        params.setdefault("win_length", win_len)
    return params

def _deep_merge_dict(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    merged = base.copy()
    for key, value in overrides.items():
        if isinstance(value, dict):
            existing = merged.get(key)
            if isinstance(existing, dict):
                merged[key] = _deep_merge_dict(existing, value)
            else:
                merged[key] = value.copy()
        else:
            merged[key] = value
    return merged

# Translate a legacy `win_len` entry (if any) into `win_length`.
MEL_PARAMS = _normalize_mel_params(MEL_PARAMS)


# File suffixes treated as audio. NOTE(review): not referenced anywhere in the
# visible part of this notebook — confirm whether it is still needed.
AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"}


def _resolve_relative_path(base: Path, candidate: str | Path) -> Path:
    base_dir = base if base.is_dir() else base.parent
    candidate_path = Path(candidate)
    if candidate_path.is_absolute():
        return candidate_path
    repo_candidate = (REPO_ROOT / candidate_path).resolve()
    config_candidate = (base_dir / candidate_path).resolve()
    if repo_candidate.exists():
        return repo_candidate
    if config_candidate.exists():
        return config_candidate
    return repo_candidate


def _latest_checkpoint(path: Path) -> Optional[Path]:
    if not path.is_dir():
        return None

    def _sort_key(p: Path) -> tuple[int, float]:
        numbers = [int(match) for match in re.findall(r"\d+", p.stem)]
        last = numbers[-1] if numbers else -1
        return last, p.stat().st_mtime

    candidates = sorted(path.glob("*.pth"), key=_sort_key)
    return candidates[-1] if candidates else None


def _load_training_config() -> Dict[str, Any]:
    """Load the training YAML config and align notebook globals with it.

    Side effects:
        * ``MEL_PARAMS`` is updated with ``dataset_params.mel_params`` and
          ``dataset_params.sr`` from the config.
        * ``CONFIG["eval_list_path"]`` is filled from ``val_data`` when unset.
        * ``CONFIG["checkpoint_dir"]`` falls back to the config's ``log_dir``
          when the configured directory is missing.
        * ``CONFIG["checkpoint_path"]`` is set to the latest checkpoint found
          in ``CONFIG["checkpoint_dir"]`` when no explicit path is given.
          This also happens when the config file itself is missing, so a
          valid checkpoint directory is still usable without a config.

    Returns:
        The parsed config mapping, or an empty dict when the file is missing,
        empty, or does not contain a mapping.
    """
    data: Dict[str, Any] = {}
    config_path = CONFIG.get("config_path")
    if not config_path or not Path(config_path).is_file():
        print("Warning: config file not found; using DEFAULT_MEL_PARAMS.")
    else:
        # Accept plain strings as well as Path objects from the user config.
        config_path = Path(config_path)

        import yaml

        with open(config_path, "r", encoding="utf-8") as handle:
            loaded = yaml.safe_load(handle) or {}
        if isinstance(loaded, dict):
            data = loaded
        else:
            print("Warning: config file does not contain a mapping; ignoring its contents.")

    # `or {}` guards against sections that are present but empty (None).
    dataset_params = data.get("dataset_params") or {}
    MEL_PARAMS.update(dataset_params.get("mel_params") or {})
    _normalize_mel_params(MEL_PARAMS)

    dataset_sr = dataset_params.get("sr")
    if dataset_sr is not None:
        MEL_PARAMS["sample_rate"] = dataset_sr

    val_list = data.get("val_data")
    if val_list and CONFIG.get("eval_list_path") is None:
        CONFIG["eval_list_path"] = _resolve_relative_path(config_path, val_list)
        if not CONFIG["eval_list_path"].is_file():
            print(f"Warning: evaluation list not found at {CONFIG['eval_list_path']}")

    log_dir = data.get("log_dir")
    if log_dir and (not CONFIG.get("checkpoint_dir") or not Path(CONFIG["checkpoint_dir"]).is_dir()):
        CONFIG["checkpoint_dir"] = _resolve_relative_path(config_path, log_dir)

    if CONFIG.get("checkpoint_path") is None:
        checkpoint_dir = CONFIG.get("checkpoint_dir")
        # A missing directory setting simply means there is nothing to search.
        latest = _latest_checkpoint(Path(checkpoint_dir)) if checkpoint_dir else None
        if latest is not None:
            CONFIG["checkpoint_path"] = latest
        else:
            print("Warning: no checkpoints found; set CONFIG['checkpoint_path'] manually.")

    return data


import re

# Loading the training config also updates MEL_PARAMS and CONFIG as a side
# effect, so it must run before the values below are derived.
TRAINING_CONFIG = _load_training_config()
TARGET_SAMPLE_RATE = MEL_PARAMS["sample_rate"]
HOP_LENGTH = MEL_PARAMS["hop_length"]
# Duration of one mel frame hop in milliseconds.
FRAME_PERIOD_MS = HOP_LENGTH * 1000.0 / TARGET_SAMPLE_RATE
# Shared mel front-end, placed on the same device as the model inputs.
mel_transform = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS).to(DEVICE)


def _collect_model_configuration(checkpoint_state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Assemble model hyper-parameters from the training config and checkpoint.

    Sources are layered in increasing priority: ``TRAINING_CONFIG``'s
    ``model_params``, the checkpoint's own ``model_params``, and finally
    ``model_params`` nested under the checkpoint's ``config`` entry. Sources
    that are absent or not dictionaries are ignored.

    Returns:
        A ``(top_level, sequence_config)`` pair. ``sequence_config`` is a copy
        of the merged ``sequence_model`` entry (empty when missing), and
        ``top_level`` holds every other merged key except ``num_class``.
    """
    layers = [TRAINING_CONFIG.get("model_params")]
    if isinstance(checkpoint_state, dict):
        layers.append(checkpoint_state.get("model_params"))
        embedded_config = checkpoint_state.get("config")
        if isinstance(embedded_config, dict):
            layers.append(embedded_config.get("model_params"))

    combined: Dict[str, Any] = {}
    for layer in layers:
        if isinstance(layer, dict):
            combined = _deep_merge_dict(combined, layer)

    sequence_settings = combined.get("sequence_model")
    sequence_settings = sequence_settings.copy() if isinstance(sequence_settings, dict) else {}

    excluded = ("sequence_model", "num_class")
    remaining = {key: value for key, value in combined.items() if key not in excluded}
    return remaining, sequence_settings


def _ensure_mono(audio: np.ndarray) -> np.ndarray:
    if audio.ndim == 1:
        return audio
    return audio.mean(axis=1)


def load_waveform(path: Path, target_sr: int = TARGET_SAMPLE_RATE) -> tuple[np.ndarray, int] | None:
    """Read an audio file as mono float32 at ``target_sr``.

    Returns ``(samples, sample_rate)``, or ``None`` (after printing a warning)
    when the file cannot be read or contains no samples. Non-``Exception``
    interruptions such as ``KeyboardInterrupt`` are not intercepted.
    """
    try:
        samples, native_sr = sf.read(str(path), dtype="float32")
    except Exception as exc:
        print(f"Warning: skipping unreadable file '{path}': {exc}")
        return None

    samples = _ensure_mono(samples)
    if samples.size == 0:
        print(f"Warning: skipping empty file '{path}'.")
        return None

    if native_sr == target_sr:
        return samples.astype(np.float32), native_sr

    # Resample through torchaudio, which expects a leading channel dimension.
    waveform = torch.from_numpy(samples).unsqueeze(0)
    waveform = torchaudio.functional.resample(waveform, native_sr, target_sr)
    samples = waveform.squeeze(0).cpu().numpy()
    return samples.astype(np.float32), target_sr

def waveform_to_mel(audio: np.ndarray) -> torch.Tensor:
    """Convert a waveform into a normalised log-mel spectrogram on ``DEVICE``.

    The waveform gets a leading batch dimension for ``mel_transform``; the
    result is log-compressed, normalised with ``CONFIG["mel_mean"]`` and
    ``CONFIG["mel_std"]``, and the leading dimension is removed again.
    """
    batch = torch.from_numpy(audio).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        mel_power = mel_transform(batch)
    # Small offset keeps the logarithm finite for silent bins.
    log_mel = torch.log(mel_power + 1e-5)
    normalised = (log_mel - CONFIG["mel_mean"]) / CONFIG["mel_std"]
    return normalised.squeeze(0)


def predict_f0(model: JDCNet, audio: np.ndarray) -> np.ndarray:
    """Run chunked inference and return one F0 value per mel frame.

    The mel spectrogram is processed in windows of ``CONFIG["chunk_size"]``
    frames that advance by ``chunk_size - chunk_overlap`` frames. Each frame
    is emitted exactly once: in the overlapped region the boundary between
    two neighbouring chunks is placed in the middle of the overlap, so both
    chunks contribute frames that have context on either side. Concatenating
    every chunk in full would duplicate the overlapped frames and make the
    output longer than the spectrogram.

    Args:
        model: Network called as ``model(mel_chunk)``; it is expected to
            return a pair whose first element is the F0 track, with one
            value per input frame — TODO confirm against the model code.
        audio: Waveform passed to ``waveform_to_mel``.

    Returns:
        The stitched F0 track with as many entries as there are mel frames,
        or an empty float32 array when there are no frames.
    """
    mel = waveform_to_mel(audio)
    total_frames = mel.shape[-1]
    chunk_size = max(int(CONFIG["chunk_size"]), 1)
    # Clamp the overlap so the window always advances and never leaves gaps.
    overlap = min(max(int(CONFIG["chunk_overlap"]), 0), chunk_size - 1)
    step = chunk_size - overlap
    # Frames dropped from the tail of every non-final chunk; the next chunk
    # supplies them from the head-side half of the shared region.
    trailing_trim = overlap // 2
    preds: List[np.ndarray] = []
    covered = 0  # index of the first frame not yet emitted
    for start in range(0, total_frames, step):
        end = min(start + chunk_size, total_frames)
        is_last = end >= total_frames
        mel_chunk = mel[:, start:end]
        pad = chunk_size - mel_chunk.shape[-1]
        if pad > 0:
            # The model is always fed full-size chunks; padding is cut below.
            mel_chunk = torch.nn.functional.pad(mel_chunk, (0, pad))
        mel_chunk = mel_chunk.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
        with torch.no_grad():
            f0_chunk, _ = model(mel_chunk)
        # atleast_1d keeps slicing valid when squeeze() yields a scalar.
        f0_chunk = np.atleast_1d(f0_chunk.squeeze().detach().cpu().numpy())
        keep_end = end if is_last else end - trailing_trim
        preds.append(f0_chunk[covered - start : keep_end - start])
        covered = keep_end
        if is_last:
            # Later windows would only revisit frames that are already covered.
            break
    if preds:
        return np.concatenate(preds)
    return np.zeros((0,), dtype=np.float32)


def compute_metrics(reference: np.ndarray, prediction: np.ndarray) -> Dict[str, float]:
    """Score a predicted F0 track against a reference track.

    Both tracks are truncated to their common length. A reference frame is
    voiced when it is above zero; a predicted frame is voiced when it is
    above ``CONFIG["voicing_threshold_hz"]``.

    Returns:
        A dict with ``RPA`` (within 50 cents), ``RCA`` (within 50 cents of
        circular distance), ``VUV`` (fraction of frames with matching voicing
        decisions) and ``OctaveError`` (frames off by more than 50 cents but
        within 50 cents of a non-zero whole number of octaves). All except
        ``VUV`` are fractions of reference-voiced frames and are NaN when the
        reference has no voiced frame.
    """
    frame_count = min(reference.shape[0], prediction.shape[0])
    ref_hz = reference[:frame_count]
    est_hz = prediction[:frame_count]

    ref_mask = ref_hz > 0
    est_mask = est_hz > CONFIG["voicing_threshold_hz"]
    voiced_count = int(np.count_nonzero(ref_mask))

    scores = {
        "RPA": float("nan"),
        "RCA": float("nan"),
        "VUV": float(np.count_nonzero(ref_mask == est_mask) / max(frame_count, 1)),
        "OctaveError": float("nan"),
    }
    if voiced_count == 0:
        return scores

    ref_cents = hz_to_cents(ref_hz[ref_mask])
    # Clip so unvoiced (zero) predictions stay finite when converted to cents.
    est_cents = hz_to_cents(np.clip(est_hz[ref_mask], a_min=1e-5, a_max=None))
    deviation = est_cents - ref_cents
    abs_deviation = np.abs(deviation)
    chroma_deviation = np.abs(circular_cents_distance(est_cents, ref_cents))
    nearest_octave = np.round(deviation / 1200.0)
    octave_mask = (
        (abs_deviation > 50.0)
        & (nearest_octave != 0)
        & (np.abs(deviation - nearest_octave * 1200.0) <= 50.0)
    )

    hit_masks = (
        ("RPA", abs_deviation <= 50.0),
        ("RCA", chroma_deviation <= 50.0),
        ("OctaveError", octave_mask),
    )
    for key, hits in hit_masks:
        scores[key] = float(np.count_nonzero(hits) / voiced_count)
    return scores


def load_model(checkpoint_path: Optional[Path] = None) -> JDCNet:
    """Instantiate ``JDCNet`` and load weights from a checkpoint file.

    Args:
        checkpoint_path: Checkpoint to load. When omitted,
            ``CONFIG["checkpoint_path"]`` is used.

    Returns:
        The model moved to ``DEVICE`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: No checkpoint is configured or the file is missing.
        RuntimeError: The checkpoint is not a dict or has no usable state.
    """
    resolved = checkpoint_path or CONFIG.get("checkpoint_path")
    if not resolved:
        # Without this guard Path(None) would fail with an opaque TypeError.
        raise FileNotFoundError(
            "No checkpoint configured: pass checkpoint_path or set CONFIG['checkpoint_path']."
        )
    checkpoint_path = Path(resolved)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    if not isinstance(checkpoint, dict):
        raise RuntimeError("Unexpected checkpoint format")
    # Weights may live under "model", under "state_dict", or be the checkpoint itself.
    model_state = checkpoint.get("model")
    if model_state is None:
        model_state = checkpoint.get("state_dict", checkpoint)
    if not isinstance(model_state, dict):
        raise RuntimeError("Checkpoint is missing a valid model state")
    model_state = dict(model_state)
    # Prefer the class count implied by the stored classifier weights.
    classifier_weight = model_state.get("classifier.weight")
    if classifier_weight is not None:
        inferred_classes = int(classifier_weight.shape[0])
    else:
        inferred_classes = int(checkpoint.get("num_class", CONFIG.get("num_class", 722)))
    if inferred_classes <= 0:
        inferred_classes = 722
    model_params, sequence_model_config = _collect_model_configuration(checkpoint)
    model_kwargs: Dict[str, Any] = {}
    if sequence_model_config:
        model_kwargs["sequence_model_config"] = sequence_model_config
    leaky_relu_slope = model_params.get("leaky_relu_slope")
    if isinstance(leaky_relu_slope, (int, float)):
        model_kwargs["leaky_relu_slope"] = float(leaky_relu_slope)
    print(f"Instantiating JDCNet with {inferred_classes} classes")
    model = JDCNet(num_class=inferred_classes, **model_kwargs)
    # Non-strict loading tolerates key mismatches; they are reported below.
    incompatible = model.load_state_dict(model_state, strict=False)
    if incompatible.unexpected_keys:
        skipped = ", ".join(sorted(incompatible.unexpected_keys))
        print(f"Warning: skipped unexpected parameters: {skipped}")
    if incompatible.missing_keys:
        missing = ", ".join(sorted(incompatible.missing_keys))
        print(f"Warning: missing parameters not found in checkpoint: {missing}")
    model.to(DEVICE).eval()
    print(f"Loaded checkpoint from {checkpoint_path}")
    return model

## Stimulus Generation and Perturbations

Functions for synthesising reference tones and applying amplitude pathologies.


In [None]:
def generate_reference_stimuli() -> List[Dict[str, Any]]:
    """Synthesise one constant-pitch tone per configured frequency.

    Duration, amplitude and frequencies come from ``CONFIG["stimulus"]``; the
    sample rate is ``TARGET_SAMPLE_RATE``. Each returned entry holds the
    identifier, the float32 audio, the sample rate, a per-sample time axis
    and the per-sample F0 curve used for synthesis.
    """
    settings = CONFIG["stimulus"]
    sample_rate = TARGET_SAMPLE_RATE
    seconds = float(settings["duration_seconds"])
    peak = float(settings["amplitude"])
    sample_count = int(seconds * sample_rate)
    timeline = np.arange(sample_count, dtype=np.float32) / float(sample_rate)

    def _build(frequency: float) -> Dict[str, Any]:
        contour = np.full((sample_count,), float(frequency), dtype=np.float32)
        waveform = synthesize_from_f0_curve(contour, sample_rate, amplitude=peak)
        return {
            "id": f"tone_{int(frequency)}Hz",
            "audio": waveform.astype(np.float32, copy=False),
            "sr": sample_rate,
            # Every stimulus gets its own copy of the shared time axis.
            "time_axis": timeline.copy(),
            "f0_curve": contour,
        }

    return [_build(frequency) for frequency in settings["frequencies_hz"]]


def apply_sample_clipping(audio: np.ndarray, percent: float, sr: int, **_: Any) -> np.ndarray:
    """Hard-clip ``audio`` at an amplitude quantile.

    The clipping threshold is the ``1 - percent / 100`` quantile of the
    absolute sample values. A float32 copy of the input is returned unchanged
    when ``percent`` is not positive or the threshold is not positive. ``sr``
    and extra keyword arguments are accepted for a uniform transform
    signature and are not used.
    """
    severity = float(percent)
    leave_untouched = severity <= 0
    if not leave_untouched:
        ceiling = np.quantile(np.abs(audio), max(0.0, 1.0 - severity / 100.0))
        leave_untouched = ceiling <= 0
    if leave_untouched:
        return audio.astype(np.float32, copy=True)
    limited = np.clip(audio, -ceiling, ceiling)
    return limited.astype(np.float32, copy=False)


def apply_agc_pumping(audio: np.ndarray, level_db: float, sr: int, target_rms: float, **_: Any) -> np.ndarray:
    """Simulate automatic-gain-control pumping on a one-dimensional signal.

    An attack/release envelope follower tracks the rectified signal; the gain
    needed to bring that envelope to ``target_rms`` is limited to a range of
    plus or minus ``depth_db``, smoothed with a moving average and applied to
    the audio. The release time, gain depth and smoothing window all grow
    with ``level_db`` (interpolated between 0 and 10).

    Args:
        audio: One-dimensional waveform.
        level_db: Severity; values at or below zero return an unchanged copy.
        sr: Sample rate in Hz, used to convert time constants to samples.
        target_rms: Envelope level the gain computation aims for.

    Returns:
        The processed float32 signal, limited to ``[-1, 1]`` and of the same
        length as the input. Empty input is returned as an empty copy.
    """
    level_db = float(level_db)
    if level_db <= 0 or audio.size == 0:
        return audio.astype(np.float32, copy=True)
    attack = 0.01
    release = np.interp(level_db, [0.0, 10.0], [0.05, 0.4])
    depth_db = np.interp(level_db, [0.0, 10.0], [3.0, 18.0])
    attack_coeff = float(np.exp(-1.0 / (attack * sr)))
    release_coeff = float(np.exp(-1.0 / (release * sr)))
    # The gain limits do not depend on the sample, so compute them once.
    max_gain = float(10 ** (depth_db / 20.0))
    min_gain = 1.0 / max_gain
    env = 0.0
    gain_track: List[float] = []
    # The follower is recursive, so it has to run sample by sample; iterating
    # plain Python floats avoids per-sample NumPy scalar overhead.
    for rectified in np.abs(audio).tolist():
        coeff = attack_coeff if rectified > env else release_coeff
        env = coeff * env + (1.0 - coeff) * rectified
        desired = target_rms / (env + 1e-6)
        gain_track.append(min(max(desired, min_gain), max_gain))
    gains = np.asarray(gain_track, dtype=np.float32)
    smoothing = int(sr * np.interp(level_db, [0.0, 10.0], [0.01, 0.12]))
    # mode="same" returns the longer operand's length, so a kernel longer
    # than the signal would change the gain track's length; cap it.
    smoothing = min(smoothing, gains.shape[0])
    if smoothing > 1:
        kernel = np.ones(smoothing, dtype=np.float32) / smoothing
        gains = np.convolve(gains, kernel, mode="same")
    pumped = audio * gains
    pumped = np.clip(pumped, -1.0, 1.0)
    return pumped.astype(np.float32, copy=False)


def evaluate_pathology(
    model: JDCNet,
    stimuli: List[Dict[str, Any]],
    levels: List[float],
    transform_fn,
    pathology_name: str,
    transform_kwargs: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    """Sweep one pathology over its severity levels for every stimulus.

    A clean baseline row (level ``0.0``, zero flips) is always produced per
    stimulus. For every non-zero level the stimulus audio is passed through
    ``transform_fn(audio, level, sr, **transform_kwargs)``, re-analysed, and
    compared with the baseline voicing decisions to obtain the V/UV flip rate.
    Entries of ``levels`` equal to zero are skipped because the baseline
    already covers them.

    Returns:
        A DataFrame with the columns ``pathology``, ``level``, ``stimulus``,
        ``RCA``, ``VUV_flips`` and ``VUV_accuracy``, sorted by pathology,
        level and stimulus.
    """
    extra_kwargs = transform_kwargs or {}

    def _analyse(stimulus: Dict[str, Any], waveform: np.ndarray):
        """Return (metrics, voicing mask) for one waveform of a stimulus."""
        estimate = predict_f0(model, waveform)
        target = sample_reference_f0(stimulus["time_axis"], stimulus["f0_curve"], estimate.shape[0])
        return compute_metrics(target, estimate), estimate > CONFIG["voicing_threshold_hz"]

    def _row(level: float, stimulus: Dict[str, Any], metrics: Dict[str, float], flips: float) -> Dict[str, Any]:
        return {
            "pathology": pathology_name,
            "level": level,
            "stimulus": stimulus["id"],
            "RCA": metrics["RCA"],
            "VUV_flips": flips,
            "VUV_accuracy": metrics["VUV"],
        }

    rows: List[Dict[str, Any]] = []
    clean_voicing: Dict[str, np.ndarray] = {}
    for stimulus in stimuli:
        metrics, voiced = _analyse(stimulus, stimulus["audio"])
        clean_voicing[stimulus["id"]] = voiced
        rows.append(_row(0.0, stimulus, metrics, 0.0))

    for level in list(levels):
        if level == 0:
            continue
        for stimulus in stimuli:
            degraded = transform_fn(stimulus["audio"], level, stimulus["sr"], **dict(extra_kwargs))
            metrics, voiced = _analyse(stimulus, degraded)
            reference_voicing = clean_voicing[stimulus["id"]]
            # Compare only the frames both voicing tracks have in common.
            shared = min(reference_voicing.shape[0], voiced.shape[0])
            if shared == 0:
                flips = float("nan")
            else:
                changed = np.count_nonzero(reference_voicing[:shared] != voiced[:shared])
                flips = float(changed / shared)
            rows.append(_row(float(level), stimulus, metrics, flips))

    table = pd.DataFrame(rows)
    return table.sort_values(["pathology", "level", "stimulus"]).reset_index(drop=True)


## Load Model and Generate Stimuli

Instantiate the trained model and synthesise the clean reference tones used throughout the sweeps.


In [None]:
# Uses CONFIG["checkpoint_path"] because no explicit path is passed.
model = load_model()
# Clean reference tones reused by every sweep below.
stimuli = generate_reference_stimuli()
print(f"Prepared {len(stimuli)} stimuli at {TARGET_SAMPLE_RATE} Hz")


## Run Amplitude Pathology Sweeps

Evaluate clipping and AGC pumping severity sweeps, collecting RCA and V/UV flip statistics.


In [None]:
# Severity grids for the two pathologies, taken from the user configuration.
clipping_levels = CONFIG["clipping"]["levels_percent"]
agc_levels = CONFIG["agc"]["levels_db"]

# Clipping sweep: the level is the percentage given to apply_sample_clipping.
df_clipping = evaluate_pathology(
    model,
    stimuli,
    clipping_levels,
    apply_sample_clipping,
    pathology_name="Clipping",
)

# AGC sweep: the target RMS is forwarded to apply_agc_pumping as a keyword.
df_agc = evaluate_pathology(
    model,
    stimuli,
    agc_levels,
    apply_agc_pumping,
    pathology_name="AGC pumping",
    transform_kwargs={"target_rms": CONFIG["agc"]["target_rms"]},
)

# One long-format table with a row per (pathology, level, stimulus).
results_df = pd.concat([df_clipping, df_agc], ignore_index=True)
results_df.head()


## Aggregate Metrics

Summarise average RCA and V/UV flip rates across stimuli for each severity level.


In [None]:
# Mean and standard deviation across stimuli for every pathology/level pair.
summary = (
    results_df
    .groupby(["pathology", "level"], as_index=False)
    .agg(
        RCA_mean=("RCA", "mean"),
        RCA_std=("RCA", "std"),
        VUV_flip_mean=("VUV_flips", "mean"),
        VUV_flip_std=("VUV_flips", "std"),
        VUV_accuracy_mean=("VUV_accuracy", "mean"),
        VUV_accuracy_std=("VUV_accuracy", "std"),
    )
    .sort_values(["pathology", "level"])
)
# Bare expression: shows the aggregated table as the cell output.
summary


## Plot Metric vs Condition Level

Visualise how RCA and V/UV flip rates evolve with increasing amplitude-pathology severity.


In [None]:
# One figure per pathology: RCA on the left, V/UV flip rate on the right.
for pathology, group in summary.groupby("pathology"):
    fig, axes = plt.subplots(1, 2, figsize=(14, 4), sharex=True)
    levels = group["level"].to_numpy()
    rca_mean = group["RCA_mean"].to_numpy()
    # A missing standard deviation is drawn as a zero-length error bar.
    rca_std = group["RCA_std"].fillna(0.0).to_numpy()
    vuv_mean = group["VUV_flip_mean"].to_numpy()
    vuv_std = group["VUV_flip_std"].fillna(0.0).to_numpy()

    axes[0].errorbar(levels, rca_mean, yerr=rca_std, fmt="-o", capsize=4)
    axes[0].set_title(f"{pathology} — RCA")
    # Generic label; replaced by the pathology-specific label further down.
    axes[0].set_xlabel("Condition level")
    axes[0].set_ylabel("Raw Chroma Accuracy")
    axes[0].set_ylim(0.0, 1.05)

    axes[1].errorbar(levels, vuv_mean, yerr=vuv_std, fmt="-o", color="tab:red", capsize=4)
    axes[1].set_title(f"{pathology} — V/UV flip rate")
    axes[1].set_xlabel("Condition level")
    axes[1].set_ylabel("Fraction of frames with voicing flips")
    axes[1].set_ylim(0.0, 1.0)

    # Pick the x-axis label from the pathology name; anything that is not
    # clipping is labelled as AGC.
    if pathology.lower().startswith("clipping"):
        axes[0].set_xlabel("Clipped samples (%)")
        axes[1].set_xlabel("Clipped samples (%)")
    else:
        axes[0].set_xlabel("AGC severity (dB)")
        axes[1].set_xlabel("AGC severity (dB)")

    fig.suptitle(f"{pathology} severity sweep")
    plt.tight_layout()
    plt.show()
