
# Phonetic Confusability Evaluation

This notebook evaluates how a trained pitch extraction model handles phonetic voicing contrasts. It focuses on stops and fricatives that differ only in voicing (e.g., /b/ vs. /p/) and measures how often the model incorrectly predicts voicing or unvoicing for each phoneme. By sweeping across phonemes and plotting false-voicing and false-unvoicing rates, we can identify where the model begins to fail.



## Environment Setup

Uncomment and run the following cell if the dependencies used by this notebook are not installed in your environment.


In [None]:
# !pip install soundfile torchaudio torch pyyaml matplotlib librosa pyworld pandas tqdm seaborn


## Imports and Global Configuration

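`REPO_ROOT` is taken to be the parent of the current working directory, so launch the notebook from a directory directly beneath the repository root (for example `notebooks/`). The repository root is appended to `sys.path` so the notebook can reuse utilities from the training codebase. Update the configuration dictionary with paths that are valid on your machine before running the analysis.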


In [None]:

import os
import sys
from dataclasses import dataclass
from pathlib import Path
import math
import re
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torchaudio
import soundfile as sf
import pyworld as pw
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# The repository root is assumed to be the parent of the current working
# directory (i.e. the notebook is launched from a sub-folder of the repo).
REPO_ROOT = Path.cwd().resolve().parents[0]
# Make the training codebase importable without adding a duplicate entry.
if (_repo_root_entry := str(REPO_ROOT)) not in sys.path:
    sys.path.append(_repo_root_entry)

from meldataset import DEFAULT_MEL_PARAMS
from model import JDCNet

# Shared matplotlib look for every figure produced by this notebook.
plt.style.use("seaborn-v0_8")
plt.rcParams.update({
    "figure.figsize": (12, 4),
    "axes.grid": True,
    "axes.spines.top": False,
    "axes.spines.right": False,
})

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")


In [None]:

# Central settings for the evaluation; edit the paths to match your machine.
CONFIG: Dict[str, Any] = {
    # Training YAML; its dataset_params.mel_params override the default mel settings.
    "config_path": REPO_ROOT / "Configs" / "config.yml",
    # Directory searched for the newest *.pth file when no explicit checkpoint is given.
    "checkpoint_dir": REPO_ROOT / "Checkpoint",
    # Explicit checkpoint file; None means "pick the latest one in checkpoint_dir".
    "checkpoint_path": None,
    # CSV listing the phoneme segments to evaluate.
    "phoneme_metadata_csv": REPO_ROOT / "data" / "phoneme_segments.csv",
    # Base directory used when resolving relative audio paths.
    "audio_root": REPO_ROOT / "data",
    # Inference window and overlap, both counted in mel frames.
    "chunk_size": 192,
    "chunk_overlap": 48,
    # Normalisation applied to the log-mel spectrogram before inference.
    # NOTE(review): presumably these must match the training pipeline - confirm.
    "mel_mean": -4.0,
    "mel_std": 4.0,
    # Predicted F0 values above this threshold are counted as voiced.
    "voicing_threshold_hz": 10.0,
    # Destination folder for the CSV tables written at the end of the notebook.
    "output_dir": REPO_ROOT / "notebooks" / "artifacts",
    # Column names expected in the phoneme metadata CSV.
    "phoneme_column": "phoneme",
    "start_column": "start_time",
    "end_column": "end_time",
    "audio_column": "audio_path",
}

# Create the output folder up front so later cells can write to it directly.
CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
CONFIG



## Helper Utilities

The following helpers load the training configuration, instantiate the mel-spectrogram transform, resolve checkpoint paths, and handle model inference over arbitrary-length audio. The utilities mirror the approach used by the training pipeline so that evaluation is consistent with how the model was optimized.


In [None]:
# Working copy of the default mel settings; _load_training_config() may update it in place.
MEL_PARAMS = DEFAULT_MEL_PARAMS.copy()


def _resolve_relative_path(base: Path, candidate: str | Path) -> Path:
    base_dir = base if base.is_dir() else base.parent
    candidate_path = Path(candidate)
    if candidate_path.is_absolute():
        return candidate_path
    repo_candidate = (REPO_ROOT / candidate_path).resolve()
    config_candidate = (base_dir / candidate_path).resolve()
    if repo_candidate.exists():
        return repo_candidate
    if config_candidate.exists():
        return config_candidate
    return repo_candidate


def _latest_checkpoint(path: Path) -> Optional[Path]:
    if not path.is_dir():
        return None

    def _sort_key(p: Path) -> Tuple[int, float]:
        numbers = [int(match) for match in re.findall(r"\d+", p.stem)]
        last = numbers[-1] if numbers else -1
        return last, p.stat().st_mtime

    candidates = sorted(path.glob("*.pth"), key=_sort_key)
    return candidates[-1] if candidates else None


def _load_training_config() -> Dict[str, Any]:
    """Read the training YAML and merge its mel settings into ``MEL_PARAMS``.

    Returns:
        The parsed YAML document, or an empty dict when the file named by
        ``CONFIG["config_path"]`` is missing (a warning is printed and
        ``MEL_PARAMS`` is left untouched).

    Side effects:
        Updates the module-level ``MEL_PARAMS`` in place with
        ``dataset_params.mel_params`` and guarantees a ``sample_rate`` entry.
    """
    config_path = CONFIG.get("config_path")
    if not config_path or not Path(config_path).is_file():
        print("Warning: config file not found; using DEFAULT_MEL_PARAMS.")
        return {}

    # Imported lazily so PyYAML is only needed when a config file exists.
    import yaml

    with open(config_path, "r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle) or {}

    # A key that is present but empty in YAML parses to None, so fall back to
    # an empty mapping rather than calling .get()/.update() on None.
    dataset_params = data.get("dataset_params") or {}
    MEL_PARAMS.update(dataset_params.get("mel_params") or {})
    MEL_PARAMS.setdefault("sample_rate", DEFAULT_MEL_PARAMS["sample_rate"])
    return data


# Parse the training YAML (if present); this also merges its mel settings into MEL_PARAMS.
TRAINING_CONFIG = _load_training_config()
# All audio is resampled to this rate before feature extraction.
TARGET_SAMPLE_RATE = int(MEL_PARAMS.get("sample_rate", DEFAULT_MEL_PARAMS["sample_rate"]))
print(f"Target sample rate: {TARGET_SAMPLE_RATE} Hz")

# Mel-spectrogram front end built from the (possibly overridden) MEL_PARAMS.
mel_transform = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS).to(DEVICE)
mel_transform.eval()

In [None]:

def load_model(checkpoint_path: Optional[Path] = None) -> JDCNet:
    """Build a ``JDCNet`` and restore its weights from a checkpoint.

    The checkpoint is chosen in this order: the ``checkpoint_path`` argument,
    ``CONFIG["checkpoint_path"]``, then the latest ``*.pth`` file found in
    ``CONFIG["checkpoint_dir"]``.

    Args:
        checkpoint_path: Optional explicit checkpoint file.

    Returns:
        The model moved to ``DEVICE`` and switched to eval mode.

    Raises:
        FileNotFoundError: If no checkpoint file can be located.
    """
    if checkpoint_path is not None:
        ckpt_file = Path(checkpoint_path)
    elif CONFIG.get("checkpoint_path"):
        ckpt_file = Path(CONFIG.get("checkpoint_path"))
    else:
        ckpt_file = _latest_checkpoint(Path(CONFIG.get("checkpoint_dir", REPO_ROOT / "Checkpoint")))

    if ckpt_file is None or not ckpt_file.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    state = torch.load(ckpt_file, map_location=DEVICE)
    # Checkpoints may wrap the weights under a "model" key or be a bare state dict.
    model_state = state.get("model", state)
    has_state_dict = isinstance(model_state, dict)

    # Prefer the class count implied by the classifier weights; otherwise use
    # the value stored in the checkpoint, then CONFIG, then the 722 default.
    head_weight = model_state.get("classifier.weight") if has_state_dict else None
    if head_weight is None:
        num_class = int(state.get("num_class", CONFIG.get("num_class", 722)))
    else:
        num_class = int(head_weight.shape[0])
    if num_class <= 0:
        num_class = 722

    model = JDCNet(num_class=num_class)
    model.load_state_dict(model_state if has_state_dict else state)
    model.to(DEVICE).eval()
    print(f"Loaded checkpoint: {ckpt_file}")
    return model


# Load the checkpoint selected by CONFIG (explicit path, else the latest in checkpoint_dir).
model = load_model()



## Phoneme Metadata Loading

Evaluation relies on a metadata table that lists the audio files and phoneme-aligned segments to analyze. The CSV file is expected to include at least the following columns:

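* `audio_path`: an absolute path to the waveform, or a path relative to the repository root or to `audio_root` (the repository root is tried first).
* `phoneme`: the symbol for the segment. Only rows whose symbol exactly matches an entry in `TARGET_PHONEMES` (IPA by default) are kept, so convert ARPAbet or custom labels beforehand or extend `TARGET_PHONEMES`.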
* `start_time` and `end_time`: segment bounds in seconds (optional; leave blank to use the full file).

Additional columns are ignored but preserved in the returned DataFrame. Segments are filtered to the phoneme set defined in `TARGET_PHONEMES` below.


In [None]:

# Obstruents paired by voicing; the order within each list is the plotting order.
TARGET_PHONEMES: Dict[str, List[str]] = dict(
    voiced=["b", "d", "g", "v", "z", "ʒ"],
    unvoiced=["p", "t", "k", "f", "s", "ʃ"],
)

# Flat, de-duplicated, sorted list of every symbol kept by the metadata filter.
ALL_TARGET_PHONEMES = sorted(set().union(*TARGET_PHONEMES.values()))


def load_phoneme_metadata(path: Path) -> pd.DataFrame:
    """Load the segment table and keep only rows for the target phonemes.

    Args:
        path: CSV location; relative paths are resolved via
            ``_resolve_relative_path`` with ``REPO_ROOT`` as base.

    Returns:
        A re-indexed DataFrame whose phoneme column (named by
        ``CONFIG["phoneme_column"]``) only holds symbols from
        ``ALL_TARGET_PHONEMES``. All other columns are kept.

    Raises:
        FileNotFoundError: If the CSV does not exist.
        ValueError: If no row matches the target phoneme set.
    """
    csv_path = _resolve_relative_path(REPO_ROOT, path)
    if not csv_path.is_file():
        raise FileNotFoundError(f"Phoneme metadata CSV not found: {csv_path}")

    table = pd.read_csv(csv_path)
    column = CONFIG.get("phoneme_column", "phoneme")
    keep = table[column].isin(ALL_TARGET_PHONEMES)
    selected = table.loc[keep].reset_index(drop=True)
    if selected.empty:
        raise ValueError("No rows match the target phoneme set. Check the metadata file and configuration.")
    return selected


# Segment table restricted to the target phonemes; previewed below.
phoneme_metadata = load_phoneme_metadata(CONFIG["phoneme_metadata_csv"])
phoneme_metadata.head()



## Audio and F0 Extraction Helpers

To ensure consistent frame alignment between the ground-truth F0 (computed with WORLD) and the model predictions, the helpers below cache waveforms and F0 tracks, slice segments based on the metadata, and convert waveforms to mel spectrograms before running the network. Reference and predicted F0 sequences are trimmed to the same length prior to metric computation.


In [None]:
# Per-file caches keyed by resolved path, so files referenced by many segments
# are decoded and analysed only once.
_AUDIO_CACHE: Dict[Path, Tuple[np.ndarray, int]] = {}  # path -> (mono waveform, sample rate)
_F0_CACHE: Dict[Path, np.ndarray] = {}  # path -> reference F0 track of the whole file


def load_waveform(path: str | Path) -> Tuple[np.ndarray, int]:
    """Read an audio file as mono float32 at ``TARGET_SAMPLE_RATE``.

    Results are cached in ``_AUDIO_CACHE`` under the resolved path.

    Args:
        path: Audio file location; relative paths are resolved via
            ``_resolve_relative_path`` with ``CONFIG["audio_root"]`` as base.

    Returns:
        ``(samples, sample_rate)``.

    Raises:
        FileNotFoundError: If the resolved file does not exist.
    """
    resolved = _resolve_relative_path(CONFIG.get("audio_root", REPO_ROOT), path)
    cached = _AUDIO_CACHE.get(resolved)
    if cached is not None:
        return cached
    if not resolved.is_file():
        raise FileNotFoundError(f"Audio file not found: {resolved}")

    samples, rate = sf.read(resolved)
    if samples.ndim > 1:
        # Down-mix multi-channel audio by averaging across axis 1.
        samples = samples.mean(axis=1)
    samples = samples.astype(np.float32)
    if rate != TARGET_SAMPLE_RATE:
        as_tensor = torch.from_numpy(samples)
        samples = torchaudio.functional.resample(as_tensor, rate, TARGET_SAMPLE_RATE).numpy()
        rate = TARGET_SAMPLE_RATE

    entry = (samples, rate)
    _AUDIO_CACHE[resolved] = entry
    return entry


def load_audio_and_f0(path: str | Path) -> Tuple[np.ndarray, int, np.ndarray]:
    """Return the waveform, its sample rate, and the reference F0 track of a file.

    The F0 track is computed once per file with ``compute_world_f0`` and then
    served from ``_F0_CACHE``.
    """
    resolved = _resolve_relative_path(CONFIG.get("audio_root", REPO_ROOT), path)
    audio, sr = load_waveform(resolved)
    try:
        f0 = _F0_CACHE[resolved]
    except KeyError:
        f0 = compute_world_f0(audio, sr)
        _F0_CACHE[resolved] = f0
    return audio, sr, f0


def compute_world_f0(audio: np.ndarray, sr: int) -> np.ndarray:
    """Estimate a reference F0 track with pyworld at the mel hop rate.

    ``pw.harvest`` is tried first. If it yields fewer than five non-zero
    frames, ``pw.dio`` is used instead. The estimate is then refined with
    ``pw.stonemask`` and returned as float32.
    """
    # pyworld is fed float64 samples; convert once and reuse.
    signal = audio.astype(np.float64)
    # One analysis frame per mel hop, expressed in milliseconds.
    period_ms = MEL_PARAMS["hop_length"] * 1000.0 / sr

    f0, times = pw.harvest(signal, sr, frame_period=period_ms)
    if np.count_nonzero(f0) < 5:
        f0, times = pw.dio(signal, sr, frame_period=period_ms)
    return pw.stonemask(signal, f0, times, sr).astype(np.float32)


def waveform_to_mel(audio: np.ndarray) -> torch.Tensor:
    """Convert a waveform into a normalised log-mel spectrogram on ``DEVICE``.

    The mel output is log-compressed with a 1e-5 floor and standardised with
    ``CONFIG["mel_mean"]`` / ``CONFIG["mel_std"]``. The batch axis that was
    added for the transform is removed before returning.
    """
    batch = torch.from_numpy(audio).float().to(DEVICE).unsqueeze(0)
    with torch.no_grad():
        mel_energy = mel_transform(batch)
    log_mel = torch.log(mel_energy + 1e-5)
    normalised = (log_mel - CONFIG["mel_mean"]) / CONFIG["mel_std"]
    return normalised.squeeze(0)


def predict_f0(model: JDCNet, audio: np.ndarray) -> np.ndarray:
    """Run the model over ``audio`` in overlapping mel chunks.

    The mel spectrogram is processed in windows of ``CONFIG["chunk_size"]``
    frames that advance by ``chunk_size - chunk_overlap`` frames. Every mel
    frame receives exactly one prediction, so the result has one value per
    mel frame and stays aligned with a reference track at the same hop.
    Inside an overlap, the first half of the frames is taken from the earlier
    chunk and the remainder from the later chunk.

    Args:
        model: Network called as ``model(mel_chunk)``; the first element of
            its return value is used as the F0 prediction.
            NOTE(review): assumed to hold one value per input frame - confirm.
        audio: Waveform passed to ``waveform_to_mel``.

    Returns:
        float32 array of length ``mel.shape[-1]`` (empty if there are no
        frames). Frames for which the model returned no value stay at 0.
    """
    mel = waveform_to_mel(audio)
    total_frames = mel.shape[-1]
    if total_frames <= 0:
        return np.zeros((0,), dtype=np.float32)

    chunk_size = int(CONFIG.get("chunk_size", 192))
    overlap = int(CONFIG.get("chunk_overlap", 48))
    step = max(chunk_size - overlap, 1)
    # Leading frames of every chunk after the first that are kept from the
    # previous chunk (half of the effective overlap).
    head_skip = max(chunk_size - step, 0) // 2

    output = np.zeros((total_frames,), dtype=np.float32)
    for start in range(0, total_frames, step):
        end = min(start + chunk_size, total_frames)
        mel_chunk = mel[:, start:end]
        pad = chunk_size - mel_chunk.shape[-1]
        if pad > 0:
            # Zero-pad the final, shorter chunk up to the fixed window size.
            mel_chunk = torch.nn.functional.pad(mel_chunk, (0, pad))
        mel_chunk = mel_chunk.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
        with torch.no_grad():
            f0_chunk, _ = model(mel_chunk)
        f0_chunk = np.atleast_1d(f0_chunk.squeeze().detach().cpu().numpy())
        # Drop predictions that belong to the zero padding.
        f0_chunk = f0_chunk[: end - start]

        # Write by absolute frame index so overlapping frames are not
        # duplicated; later chunks overwrite the tail of the previous one.
        skip = head_skip if start > 0 else 0
        fresh = f0_chunk[skip:]
        output[start + skip : start + skip + len(fresh)] = fresh

        if end >= total_frames:
            # Every frame is covered; further windows would only add padding.
            break
    return output


def slice_segment(audio: np.ndarray, sr: int, start: Optional[float], end: Optional[float]) -> np.ndarray:
    """Cut the samples between ``start`` and ``end`` (seconds) out of ``audio``.

    A bound that is ``None`` or NaN falls back to the start or end of the
    signal. When both bounds are ``None`` the input array itself is returned.
    Bounds are clamped to the signal. If the clamped range is empty, a short
    fallback window of ``sr // hop_length`` samples is used instead.
    """
    if start is None and end is None:
        return audio

    n_samples = len(audio)
    if start is None or math.isnan(start):
        first = 0
    else:
        first = int(max(start * sr, 0))
    if end is None or math.isnan(end):
        last = n_samples
    else:
        last = int(min(end * sr, n_samples))

    if last <= first:
        # NOTE(review): the fallback length is sr // hop_length samples, which
        # is less than one hop - confirm this is the intended window size.
        last = min(first + sr // MEL_PARAMS.get("hop_length", 300), n_samples)
    return audio[first:last]


def extract_reference_segment(f0: np.ndarray, sr: int, start: Optional[float], end: Optional[float]) -> np.ndarray:
    """Return the frames of ``f0`` that cover ``start``..``end`` (seconds).

    Times are converted to frame indices using the mel hop length. The start
    is rounded down and the end rounded up, and both are clamped to the
    track. A ``None`` or NaN bound selects the start or end of the track. An
    empty range is widened to one frame where possible.
    """
    seconds_per_frame = MEL_PARAMS.get("hop_length", 300) / sr
    n_frames = len(f0)

    first = 0
    if start is not None and not math.isnan(start):
        first = int(max(math.floor(start / seconds_per_frame), 0))

    last = n_frames
    if end is not None and not math.isnan(end):
        last = int(min(math.ceil(end / seconds_per_frame), n_frames))

    if last <= first:
        last = min(first + 1, n_frames)
    return f0[first:last]


def align_reference_to_prediction(reference: np.ndarray, prediction: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Trim both sequences to their common length.

    Args:
        reference: Reference F0 frames.
        prediction: Predicted F0 frames.

    Returns:
        ``(reference, prediction)`` cut to the shorter of the two lengths.
        When either input is empty, both outputs are empty, so the returned
        arrays always have equal length.
    """
    length = min(len(reference), len(prediction))
    return reference[:length], prediction[:length]


## Metric Computation

False-voicing and false-unvoicing rates are computed frame by frame for each phoneme segment. A "false-voicing" event is a frame that is unvoiced in the reference (F0 == 0 Hz) but whose predicted F0 exceeds `voicing_threshold_hz`. Conversely, a "false-unvoicing" event is a frame that is voiced in the reference but whose predicted F0 is at or below that threshold. Each rate is normalized by the number of reference frames in the corresponding class (unvoiced frames for false-voicing, voiced frames for false-unvoicing) and is undefined (NaN) when a segment contains no frames of that class.


In [None]:

@dataclass
class SegmentResult:
    """Voicing-error counts and rates for one evaluated phoneme segment."""

    # Phoneme symbol of the segment and its group from TARGET_PHONEMES ("other" if none).
    phoneme: str
    category: str
    # Number of aligned frames, split by reference voicing.
    total_frames: int
    voiced_frames: int
    unvoiced_frames: int
    # Frames predicted voiced on an unvoiced reference, and vice versa.
    false_voicing: int
    false_unvoicing: int
    # Error counts divided by the matching reference class size; NaN when that class is empty.
    false_voicing_rate: float
    false_unvoicing_rate: float


def compute_segment_metrics(phoneme: str, reference: np.ndarray, prediction: np.ndarray) -> SegmentResult:
    """Count frame-level voicing errors for one segment.

    A reference frame is voiced when its F0 is above zero. A predicted frame
    is voiced when it exceeds ``CONFIG["voicing_threshold_hz"]``. Both tracks
    are first trimmed to a common length.

    Returns:
        A ``SegmentResult``. All counts are zero and both rates NaN when no
        frames are available. A rate is also NaN when its reference class
        has no frames.
    """
    reference, prediction = align_reference_to_prediction(reference, prediction)
    n_frames = int(min(len(reference), len(prediction)))

    # Defaults describe the "nothing to evaluate" case.
    n_voiced = n_unvoiced = n_false_voiced = n_false_unvoiced = 0
    fv_rate = float("nan")
    fu_rate = float("nan")

    if n_frames > 0:
        ref_is_voiced = reference > 0
        pred_is_voiced = prediction > CONFIG.get("voicing_threshold_hz", 10.0)

        n_voiced = int(np.count_nonzero(ref_is_voiced))
        n_unvoiced = n_frames - n_voiced
        n_false_voiced = int(np.count_nonzero(~ref_is_voiced & pred_is_voiced))
        n_false_unvoiced = int(np.count_nonzero(ref_is_voiced & ~pred_is_voiced))

        if n_unvoiced > 0:
            fv_rate = float(n_false_voiced / n_unvoiced)
        if n_voiced > 0:
            fu_rate = float(n_false_unvoiced / n_voiced)

    return SegmentResult(
        phoneme=phoneme,
        category=_categorize_phoneme(phoneme),
        total_frames=n_frames,
        voiced_frames=n_voiced,
        unvoiced_frames=n_unvoiced,
        false_voicing=n_false_voiced,
        false_unvoicing=n_false_unvoiced,
        false_voicing_rate=fv_rate,
        false_unvoicing_rate=fu_rate,
    )


def _categorize_phoneme(phoneme: str) -> str:
    """Return the ``TARGET_PHONEMES`` group containing ``phoneme``, else ``"other"``."""
    matching_groups = (
        label for label, members in TARGET_PHONEMES.items() if phoneme in members
    )
    return next(matching_groups, "other")



## Run Phoneme Sweep

The loop below iterates over every segment listed in the metadata file, runs the model, and aggregates the per-segment metrics into a DataFrame. The resulting table contains one row per segment, which can be grouped or pivoted to analyze error patterns.


In [None]:
# Column names are configurable so the notebook can read differently-labelled CSVs.
phoneme_col = CONFIG.get("phoneme_column", "phoneme")
start_col = CONFIG.get("start_column", "start_time")
end_col = CONFIG.get("end_column", "end_time")
audio_col = CONFIG.get("audio_column", "audio_path")

results: List[SegmentResult] = []
progress = tqdm(phoneme_metadata.iterrows(), total=len(phoneme_metadata), desc="Evaluating segments")
for _, row in progress:
    phoneme = str(row[phoneme_col])

    # Missing or blank bounds become None, which selects the whole file.
    start_time = row.get(start_col)
    start_time = float(start_time) if not pd.isna(start_time) else None
    end_time = row.get(end_col)
    end_time = float(end_time) if not pd.isna(end_time) else None

    # Waveform and whole-file reference F0 (both cached per file).
    audio, sr, f0_track = load_audio_and_f0(row[audio_col])

    segment_audio = slice_segment(audio, sr, start_time, end_time)
    reference_segment = extract_reference_segment(f0_track, sr, start_time, end_time)

    # The model only sees the audio of this segment.
    prediction = predict_f0(model, segment_audio)

    results.append(compute_segment_metrics(phoneme, reference_segment, prediction))

# One row per evaluated segment.
results_df = pd.DataFrame([vars(item) for item in results])
results_df.head()


## Aggregate Metrics by Phoneme

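Aggregating the per-segment results reveals systematic voicing errors for each phoneme. In the summary below, frame and error counts are summed over all segments of a phoneme, while the false-voicing and false-unvoicing rates are unweighted means of the per-segment rates (segments where a rate is undefined are skipped), so short and long segments contribute equally.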


In [None]:

# Per-phoneme roll-up: counts are summed, rates are averaged across segments.
phoneme_summary = (
    results_df.groupby(["phoneme", "category"])
    .agg(
        total_frames=("total_frames", "sum"),
        voiced_frames=("voiced_frames", "sum"),
        unvoiced_frames=("unvoiced_frames", "sum"),
        false_voicing=("false_voicing", "sum"),
        false_unvoicing=("false_unvoicing", "sum"),
        false_voicing_rate=("false_voicing_rate", "mean"),
        false_unvoicing_rate=("false_unvoicing_rate", "mean"),
    )
    .reset_index()
    .sort_values("phoneme")
)
phoneme_summary



## Visualization

False-voicing and false-unvoicing rates are plotted for each target phoneme, with bars coloured by voicing category. This view highlights which phonemes are most prone to voicing confusions.


In [None]:
# Show the phonemes in the order they are listed in TARGET_PHONEMES; phonemes
# without any evaluated segment appear as empty slots.
ordered_phonemes = [p for group in TARGET_PHONEMES.values() for p in group]
plot_df = phoneme_summary.set_index("phoneme").reindex(ordered_phonemes).reset_index()
plot_df.rename(columns={"index": "phoneme"}, inplace=True)

fig, axes = plt.subplots(1, 2, figsize=(16, 5), sharey=False)

# Largest observed rate per metric, used to scale each y-axis.
valid_false_voicing = plot_df["false_voicing_rate"].dropna()
max_false_voicing = float(valid_false_voicing.max()) if not valid_false_voicing.empty else 0.0
valid_false_unvoicing = plot_df["false_unvoicing_rate"].dropna()
max_false_unvoicing = float(valid_false_unvoicing.max()) if not valid_false_unvoicing.empty else 0.0

# (column, axis label, colour palette, peak value) for the two panels.
panels = [
    ("false_voicing_rate", "False-Voicing Rate", "Blues", max_false_voicing),
    ("false_unvoicing_rate", "False-Unvoicing Rate", "Reds", max_false_unvoicing),
]

for ax, (column, label, palette, peak) in zip(axes, panels):
    sns.barplot(
        data=plot_df,
        x="phoneme",
        y=column,
        hue="category",
        palette=palette,
        # Each phoneme belongs to a single category, so the hue is used for
        # colour only; without this, bars are narrowed and shifted off-centre
        # to make room for the other category.
        dodge=False,
        ax=ax,
    )
    ax.set_title(f"{label} by Phoneme")
    ax.set_ylabel(label)
    ax.set_xlabel("Phoneme")
    ax.legend(title="Category")
    # 10% headroom above the tallest bar, at least 0.01 and at most 1.0;
    # a fixed 0.1 range is used when every rate is zero or missing.
    ax.set_ylim(0, min(1.0, max(0.01, peak * 1.1 if peak > 0 else 0.1)))

plt.tight_layout()
plt.show()


## Category-Level Summary

The table below aggregates metrics by voiced vs. unvoiced classes to expose broader trends across the phoneme groups.


In [None]:

# Pool the frame and error counts over every segment of a category ...
category_summary = (
    results_df.groupby("category")[
        ["total_frames", "voiced_frames", "unvoiced_frames", "false_voicing", "false_unvoicing"]
    ]
    .sum()
    .reset_index()
)
# ... and derive the rates from the pooled counts. A zero denominator is
# replaced with NaN so the rate is reported as missing.
category_summary = category_summary.assign(
    false_voicing_rate=lambda t: t["false_voicing"] / t["unvoiced_frames"].replace({0: np.nan}),
    false_unvoicing_rate=lambda t: t["false_unvoicing"] / t["voiced_frames"].replace({0: np.nan}),
)
category_summary



## Save Results

All intermediate tables are saved to the configured `output_dir` so they can be compared across checkpoints or combined with other evaluations.


In [None]:

# Destination files inside the configured output directory.
output_dir = CONFIG["output_dir"]
phoneme_summary_path = output_dir / "phoneme_confusability_summary.csv"
segment_results_path = output_dir / "phoneme_confusability_segments.csv"
category_summary_path = output_dir / "phoneme_confusability_category.csv"

# Write each table without its index column.
for table, destination in (
    (phoneme_summary, phoneme_summary_path),
    (results_df, segment_results_path),
    (category_summary, category_summary_path),
):
    table.to_csv(destination, index=False)

phoneme_summary_path, segment_results_path, category_summary_path



## Next Steps

* Compare multiple checkpoints by pointing `checkpoint_path` to different models and re-running the notebook.
* Expand the phoneme list to include additional minimal pairs or language-specific contrasts.
* Join the per-segment metrics with acoustic metadata (e.g., SNR, speaker identity) to understand when voicing errors are most likely.
