# Codec and Bandwidth Torture Evaluation


This notebook evaluates how the latest pitch extraction model handles aggressive resampling and lossy compression. We sweep over a range of sample rates and codec/bitrate combinations, compare the resulting pitch metrics against those obtained on the reference (clean) audio, and visualize where performance begins to degrade.


In [None]:
# !pip install soundfile torchaudio torch pyyaml matplotlib librosa pyworld torchcrepe praat-parselmouth pandas tqdm

In [None]:

import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import math
import re
import shutil
import subprocess
import tempfile

import numpy as np
import pandas as pd
import torch
import torchaudio
import soundfile as sf

# Assumes the notebook is run from a direct subdirectory of the repository
# (e.g. ``notebooks/``), so the parent of the working directory is the repo root.
REPO_ROOT = Path.cwd().resolve().parents[0]
# Make repo-local modules (``Utils``, ``meldataset``, ``model``) importable.
# NOTE(review): ``append`` puts the repo *last* on sys.path, so an installed
# package with the same name as a local module would shadow it.
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from Utils.f0_notebook_utils import (
    build_notebook_f0_extractor,
    compute_f0_for_notebook,
)
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Shared figure styling for every plot in this notebook.
plt.style.use("seaborn-v0_8")
plt.rcParams.update({"figure.figsize": (12, 4), "axes.grid": True})

# Project-local modules, presumably resolved through REPO_ROOT on sys.path.
from meldataset import DEFAULT_MEL_PARAMS
from model import JDCNet

# Run the mel transform and the model on the GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")


In [None]:

# Central configuration for the sweep. Entries left as ``None`` are filled in
# by ``_load_training_config`` from the training YAML when possible.
CONFIG: Dict[str, Any] = {
    "config_path": REPO_ROOT / "Configs" / "config.yml",
    "checkpoint_dir": REPO_ROOT / "Checkpoint",
    # None -> the newest ``*.pth`` under checkpoint_dir is picked automatically.
    "checkpoint_path": None,
    # None -> taken from ``val_data`` in the training config.
    "eval_list_path": None,
    # Inference window and overlap, measured in mel frames.
    "chunk_size": 192,
    "chunk_overlap": 48,
    # Normalization applied to log-mel features before they reach the model.
    "mel_mean": -4.0,
    "mel_std": 4.0,
    # A predicted F0 above this value (Hz) counts as voiced.
    "voicing_threshold_hz": 10.0,
    "output_dir": REPO_ROOT / "notebooks" / "artifacts",
    # Intermediate sample rates for the resample-and-back sweep.
    "resample_rates_hz": [8000, 16000, 22050, 24000, 44100],
    # Codec name -> ffmpeg encoder, container extension, and bitrates to test.
    "codecs": {
        "opus": {"ffmpeg_codec": "libopus", "extension": ".opus", "bitrates_kbps": [16, 32, 64, 128]},
        "mp3": {"ffmpeg_codec": "libmp3lame", "extension": ".mp3", "bitrates_kbps": [16, 32, 64, 128]},
        "aac": {"ffmpeg_codec": "aac", "extension": ".m4a", "bitrates_kbps": [16, 32, 64, 128]},
    },
}

CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
# Last expression: displayed as the cell output.
CONFIG


In [None]:
# Working copy of the project's default mel settings; the training config may
# override entries later. Note that ``.copy()`` is shallow.
MEL_PARAMS = DEFAULT_MEL_PARAMS.copy()

def _normalize_mel_params(params: Dict[str, Any]) -> Dict[str, Any]:
    if "win_len" in params:
        win_len = params.pop("win_len")
        params.setdefault("win_length", win_len)
    return params


def _deep_merge_dict(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    merged = base.copy()
    for key, value in overrides.items():
        if isinstance(value, dict):
            existing = merged.get(key)
            if isinstance(existing, dict):
                merged[key] = _deep_merge_dict(existing, value)
            else:
                merged[key] = value.copy()
        else:
            merged[key] = value
    return merged

# Rename the legacy ``win_len`` key so MEL_PARAMS uses ``win_length``.
MEL_PARAMS = _normalize_mel_params(MEL_PARAMS)

# Passed as ``fallback_sr`` / ``fallback_hop`` to the F0 extractor builder.
FALLBACK_SAMPLE_RATE = int(DEFAULT_MEL_PARAMS.get("sample_rate", 24000))
# NOTE(review): when ``hop_length`` is absent this falls back to ``win_len`` (a
# window size), which is a different quantity from the hop — confirm intended.
FALLBACK_HOP_LENGTH = int(DEFAULT_MEL_PARAMS.get("hop_length", DEFAULT_MEL_PARAMS.get("win_len", 300)))
# Filled from ``dataset_params.f0_params`` once the training config is loaded.
F0_PARAMS: Dict[str, Any] = {}
# Passed as ``zero_fill_value`` to the reference F0 computation; overridden by
# ``f0_params.zero_fill_value`` from the training config.
F0_ZERO_FILL: float = 0.0
# Backend names already announced, so each is printed only once.
REFERENCE_F0_BACKENDS_USED: set[str] = set()
# Initial extractor built from defaults; rebuilt when the training config loads.
NOTEBOOK_F0_EXTRACTOR = build_notebook_f0_extractor(
    MEL_PARAMS,
    F0_PARAMS,
    fallback_sr=FALLBACK_SAMPLE_RATE,
    fallback_hop=FALLBACK_HOP_LENGTH,
    verbose=False,
)

# NOTE(review): not referenced by the code visible in this notebook.
AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"}


def _resolve_relative_path(base: Path, candidate: str | Path) -> Path:
    """Resolve *candidate* against the repository root or the location of *base*.

    Absolute paths are returned unchanged. Otherwise the repo-root
    interpretation is preferred if it exists, then the one relative to *base*
    (or to its parent when *base* is a file). When neither exists, the
    repo-root interpretation is returned.
    """
    anchor = base if base.is_dir() else base.parent
    raw = Path(candidate)
    if raw.is_absolute():
        return raw
    options = ((REPO_ROOT / raw).resolve(), (anchor / raw).resolve())
    for option in options:
        if option.exists():
            return option
    return options[0]


def _latest_checkpoint(path: Path) -> Optional[Path]:
    if not path.is_dir():
        return None

    def _sort_key(p: Path) -> tuple[int, float]:
        numbers = [int(match) for match in re.findall(r"\\d+", p.stem)]
        last = numbers[-1] if numbers else -1
        return last, p.stat().st_mtime

    candidates = sorted(path.glob("*.pth"), key=_sort_key)
    return candidates[-1] if candidates else None


def _load_training_config() -> Dict[str, Any]:
    """Load the training YAML and synchronize notebook globals with it.

    Side effects, all driven by ``CONFIG["config_path"]``:

    * ``MEL_PARAMS`` takes ``dataset_params.mel_params``, and
      ``dataset_params.sr`` (when set) as ``sample_rate``.
    * ``F0_PARAMS`` is replaced by ``dataset_params.f0_params``; the F0
      extractor and ``F0_ZERO_FILL`` are rebuilt from it.
    * ``CONFIG["eval_list_path"]``, ``CONFIG["checkpoint_dir"]`` and
      ``CONFIG["checkpoint_path"]`` are filled in when not already set.

    Returns:
        The parsed YAML mapping, or ``{}`` when the config file is missing.
    """
    global NOTEBOOK_F0_EXTRACTOR, F0_ZERO_FILL

    config_path = CONFIG.get("config_path")
    if not config_path or not Path(config_path).is_file():
        print("Warning: config file not found; using DEFAULT_MEL_PARAMS.")
        return {}
    config_path = Path(config_path)

    import yaml

    with open(config_path, "r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle) or {}

    # ``or {}`` tolerates sections that are present but empty/null in the YAML.
    dataset_params = data.get("dataset_params") or {}

    MEL_PARAMS.update(dataset_params.get("mel_params") or {})
    _normalize_mel_params(MEL_PARAMS)
    # Apply the sample-rate override *before* building the F0 extractor, so the
    # extractor sees the same rate that TARGET_SAMPLE_RATE (and therefore every
    # loaded waveform) uses. Previously the override came afterwards, leaving
    # the extractor configured with the stale mel_params rate.
    dataset_sr = dataset_params.get("sr")
    if dataset_sr is not None:
        MEL_PARAMS["sample_rate"] = dataset_sr

    F0_PARAMS.clear()
    F0_PARAMS.update(dataset_params.get("f0_params") or {})
    NOTEBOOK_F0_EXTRACTOR = build_notebook_f0_extractor(
        MEL_PARAMS,
        F0_PARAMS,
        fallback_sr=FALLBACK_SAMPLE_RATE,
        fallback_hop=FALLBACK_HOP_LENGTH,
        verbose=True,
    )
    F0_ZERO_FILL = float(F0_PARAMS.get("zero_fill_value", 0.0))

    val_list = data.get("val_data")
    if val_list and CONFIG.get("eval_list_path") is None:
        CONFIG["eval_list_path"] = _resolve_relative_path(config_path, val_list)
        if not CONFIG["eval_list_path"].is_file():
            print(f"Warning: evaluation list not found at {CONFIG['eval_list_path']}")

    log_dir = data.get("log_dir")
    checkpoint_dir = CONFIG.get("checkpoint_dir")
    if log_dir and (not checkpoint_dir or not Path(checkpoint_dir).is_dir()):
        CONFIG["checkpoint_dir"] = _resolve_relative_path(config_path, log_dir)

    if CONFIG.get("checkpoint_path") is None:
        checkpoint_dir = CONFIG.get("checkpoint_dir")
        # Guard an unset directory: Path(None) would raise a TypeError.
        latest = _latest_checkpoint(Path(checkpoint_dir)) if checkpoint_dir else None
        if latest is not None:
            CONFIG["checkpoint_path"] = latest
        else:
            print("Warning: no checkpoints found; set CONFIG['checkpoint_path'] manually.")

    return data


# Parse the training config once; this also updates MEL_PARAMS / CONFIG.
TRAINING_CONFIG = _load_training_config()
# Rate every waveform is brought to before feature extraction and reference F0.
TARGET_SAMPLE_RATE = int(MEL_PARAMS["sample_rate"])
HOP_LENGTH = int(MEL_PARAMS["hop_length"])
# Duration of one mel frame in milliseconds.
FRAME_PERIOD_MS = HOP_LENGTH * 1000.0 / TARGET_SAMPLE_RATE
# NOTE(review): every MEL_PARAMS key is forwarded to MelSpectrogram, so the
# config must only contain keyword arguments that transform accepts.
mel_transform = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS).to(DEVICE)



In [None]:


def _collect_model_configuration(checkpoint_state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Gather JDCNet constructor settings from the training config and checkpoint.

    Sources are deep-merged in increasing priority:
    ``TRAINING_CONFIG["model_params"]``, ``checkpoint_state["model_params"]``,
    then ``checkpoint_state["config"]["model_params"]``. Non-dict sources are
    ignored.

    Returns:
        ``(top_level, sequence_config)``. ``sequence_config`` is a copy of the
        merged ``sequence_model`` section (``{}`` if absent or not a dict), and
        ``top_level`` holds every other merged key except ``num_class``.
    """
    layers: List[Any] = [TRAINING_CONFIG.get("model_params")]
    if isinstance(checkpoint_state, dict):
        layers.append(checkpoint_state.get("model_params"))
        nested = checkpoint_state.get("config")
        if isinstance(nested, dict):
            layers.append(nested.get("model_params"))

    merged: Dict[str, Any] = {}
    for layer in layers:
        if isinstance(layer, dict):
            merged = _deep_merge_dict(merged, layer)

    raw_sequence = merged.get("sequence_model")
    sequence_config = raw_sequence.copy() if isinstance(raw_sequence, dict) else {}
    top_level = {
        key: value
        for key, value in merged.items()
        if key not in ("sequence_model", "num_class")
    }
    return top_level, sequence_config

def _ensure_mono(audio: np.ndarray) -> np.ndarray:
    if audio.ndim == 1:
        return audio
    return audio.mean(axis=1)


def load_waveform(path: Path, target_sr: int = TARGET_SAMPLE_RATE) -> tuple[np.ndarray, int] | None:
    """Read an audio file as a mono float32 waveform at *target_sr*.

    Returns:
        ``(audio, sample_rate)`` on success, or ``None`` (after printing a
        warning) when the file cannot be read or holds no samples. A
        ``KeyboardInterrupt`` during reading is always propagated.
    """
    try:
        samples, rate = sf.read(str(path), dtype="float32")
    except KeyboardInterrupt:
        raise
    except Exception as exc:
        print(f"Warning: skipping unreadable file '{path}': {exc}")
        return None

    samples = _ensure_mono(samples)
    if not samples.size:
        print(f"Warning: skipping empty file '{path}'.")
        return None

    if rate == target_sr:
        return samples.astype(np.float32), rate

    batch = torch.from_numpy(samples).unsqueeze(0)
    samples = torchaudio.functional.resample(batch, rate, target_sr).squeeze(0).cpu().numpy()
    return samples.astype(np.float32), target_sr

def compute_reference_f0(audio: np.ndarray, sr: int) -> np.ndarray:
    """Compute the reference F0 track for *audio* with the notebook's F0 extractor.

    ``F0_ZERO_FILL`` is forwarded as ``zero_fill_value``. The first time a
    given backend name is reported, it is recorded in
    ``REFERENCE_F0_BACKENDS_USED`` and announced once.
    """
    outcome = compute_f0_for_notebook(
        audio,
        sr,
        NOTEBOOK_F0_EXTRACTOR,
        zero_fill_value=F0_ZERO_FILL,
    )
    backend = outcome.backend_name
    is_new_backend = bool(backend) and backend not in REFERENCE_F0_BACKENDS_USED
    if is_new_backend:
        REFERENCE_F0_BACKENDS_USED.add(backend)
        print(f"Using F0 backend '{backend}' for reference computation.")
    return outcome.f0


def waveform_to_mel(audio: np.ndarray) -> torch.Tensor:
    """Turn a mono waveform into the normalized log-mel features the model expects.

    Computes ``log(mel + 1e-5)``, standardizes it with ``CONFIG["mel_mean"]``
    and ``CONFIG["mel_std"]``, and drops the leading batch axis. The result
    lives on ``DEVICE``.
    """
    batch = torch.from_numpy(audio).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        mel_spec = mel_transform(batch)
    log_mel = torch.log(mel_spec + 1e-5)
    normalized = (log_mel - CONFIG["mel_mean"]) / CONFIG["mel_std"]
    return normalized.squeeze(0)


def predict_f0(model: JDCNet, audio: np.ndarray) -> np.ndarray:
    """Run *model* over *audio* in overlapping chunks; return one F0 per mel frame.

    The mel spectrogram is processed in windows of ``CONFIG["chunk_size"]``
    frames that overlap by ``CONFIG["chunk_overlap"]`` frames. The overlap is
    clamped to ``[0, chunk_size - 1]``, and the last window is zero-padded to
    full size.

    Overlapping predictions are stitched at the midpoint of each overlap.
    Every frame is therefore taken from the window in which it sits further
    from the edge, and the output has exactly as many frames as the mel
    spectrogram, keeping it frame-aligned with the reference F0 track.

    Returns:
        float32 array of shape ``(num_mel_frames,)``; empty when there are no
        frames.
    """
    mel = waveform_to_mel(audio)
    total_frames = int(mel.shape[-1])
    output = np.zeros((total_frames,), dtype=np.float32)
    if total_frames == 0:
        return output

    chunk_size = max(int(CONFIG["chunk_size"]), 1)
    overlap = min(max(int(CONFIG["chunk_overlap"]), 0), chunk_size - 1)
    step = chunk_size - overlap
    # Leading frames of each later window that are left to the previous window,
    # which saw them with more right-hand context.
    lead = overlap // 2

    for start in range(0, total_frames, step):
        end = min(start + chunk_size, total_frames)
        mel_chunk = mel[:, start:end]
        pad = chunk_size - mel_chunk.shape[-1]
        if pad > 0:
            mel_chunk = torch.nn.functional.pad(mel_chunk, (0, pad))
        # (n_mels, T) -> (1, 1, T, n_mels)
        mel_chunk = mel_chunk.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
        with torch.no_grad():
            f0_chunk, _ = model(mel_chunk)
        f0_chunk = f0_chunk.detach().cpu().numpy().reshape(-1)[: end - start]
        # Previously every window was appended in full, duplicating the overlap
        # frames and shifting all later predictions out of alignment with the
        # reference track. Writing by absolute frame index keeps them aligned.
        skip = 0 if start == 0 else min(lead, end - start)
        output[start + skip : end] = f0_chunk[skip:]
    return output



def load_model(checkpoint_path: Optional[Path] = None) -> JDCNet:
    """Build a JDCNet and load its weights from a checkpoint.

    The class count comes from ``classifier.weight`` when present, otherwise
    from ``checkpoint["num_class"]`` / ``CONFIG["num_class"]``, falling back
    to 722. Optional ``sequence_model`` and ``leaky_relu_slope`` settings come
    from ``_collect_model_configuration``. Weights are loaded non-strictly;
    unexpected and missing parameters are reported but not fatal.

    Args:
        checkpoint_path: Checkpoint file; defaults to ``CONFIG["checkpoint_path"]``.

    Returns:
        The model on ``DEVICE``, in eval mode.

    Raises:
        FileNotFoundError: No checkpoint is configured, or the path is not a file.
        RuntimeError: The checkpoint is not a dict or lacks a usable state dict.
    """
    resolved = checkpoint_path or CONFIG.get("checkpoint_path")
    if resolved is None:
        # Path(None) would raise an opaque TypeError; report what is missing.
        raise FileNotFoundError(
            "Checkpoint not found: no checkpoint_path given and CONFIG['checkpoint_path'] is unset"
        )
    checkpoint_path = Path(resolved)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    if not isinstance(checkpoint, dict):
        raise RuntimeError("Unexpected checkpoint format")
    model_state = checkpoint.get("model")
    if model_state is None:
        model_state = checkpoint.get("state_dict", checkpoint)
    if not isinstance(model_state, dict):
        raise RuntimeError("Checkpoint is missing a valid model state")
    model_state = dict(model_state)

    classifier_weight = model_state.get("classifier.weight")
    if classifier_weight is not None:
        inferred_classes = int(classifier_weight.shape[0])
    else:
        inferred_classes = int(checkpoint.get("num_class", CONFIG.get("num_class", 722)))
    if inferred_classes <= 0:
        inferred_classes = 722

    model_params, sequence_model_config = _collect_model_configuration(checkpoint)
    model_kwargs: Dict[str, Any] = {}
    if sequence_model_config:
        model_kwargs["sequence_model_config"] = sequence_model_config
    leaky_relu_slope = model_params.get("leaky_relu_slope")
    # bool is a subclass of int; don't turn a stray True/False into a slope.
    if isinstance(leaky_relu_slope, (int, float)) and not isinstance(leaky_relu_slope, bool):
        model_kwargs["leaky_relu_slope"] = float(leaky_relu_slope)

    print(f"Instantiating JDCNet with {inferred_classes} classes")
    model = JDCNet(num_class=inferred_classes, **model_kwargs)
    incompatible = model.load_state_dict(model_state, strict=False)
    if incompatible.unexpected_keys:
        skipped = ", ".join(sorted(incompatible.unexpected_keys))
        print(f"Warning: skipped unexpected parameters: {skipped}")
    if incompatible.missing_keys:
        missing = ", ".join(sorted(incompatible.missing_keys))
        print(f"Warning: missing parameters not found in checkpoint: {missing}")
    model.to(DEVICE).eval()
    print(f"Loaded checkpoint from {checkpoint_path}")
    return model


In [None]:

def hz_to_cents(f0: np.ndarray) -> np.ndarray:
    """Convert frequencies in Hz to cents relative to 55 Hz.

    Non-positive entries (unvoiced frames) map to 0. The output is always
    floating point. Float inputs keep their precision, while integer inputs
    are promoted: ``np.zeros_like`` on an integer array used to truncate the
    cents to whole numbers. Array-likes such as lists are also accepted.
    """
    f0 = np.asarray(f0)
    cents = np.zeros(f0.shape, dtype=np.result_type(f0.dtype, np.float32))
    positive = f0 > 0
    cents[positive] = 1200.0 * np.log2(f0[positive] / 55.0)
    return cents


def circular_cents_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return ``a - b`` wrapped into the half-open interval ``[-600, 600)`` cents.

    Differences of a whole number of octaves collapse to 0, which is what the
    octave-insensitive chroma accuracy needs.
    """
    return np.mod((a - b) + 600.0, 1200.0) - 600.0


def compute_metrics(reference: np.ndarray, prediction: np.ndarray) -> Dict[str, float]:
    """Score a predicted F0 track against the reference track.

    Both tracks are truncated to their common length. A reference frame is
    voiced when > 0; a predicted frame is voiced when above
    ``CONFIG["voicing_threshold_hz"]``.

    Returns a dict with:
        RPA: fraction of reference-voiced frames predicted within 50 cents.
        RCA: like RPA, but with differences wrapped to one octave (chroma).
        VUV: fraction of all frames whose voiced/unvoiced decision agrees.
        OctaveError: fraction of reference-voiced frames off by a non-zero
            whole number of octaves (within 50 cents).
    RPA, RCA and OctaveError are NaN when no reference frame is voiced.
    """
    n = min(reference.shape[0], prediction.shape[0])
    ref = reference[:n]
    est = prediction[:n]

    ref_voiced = ref > 0
    est_voiced = est > CONFIG["voicing_threshold_hz"]
    vuv = np.count_nonzero(ref_voiced == est_voiced) / max(n, 1)

    n_voiced = np.count_nonzero(ref_voiced)
    if not n_voiced:
        return {
            "RPA": float("nan"),
            "RCA": float("nan"),
            "VUV": vuv,
            "OctaveError": float("nan"),
        }

    # Clipping makes a 0 Hz (or negative) prediction a huge negative cents
    # value, i.e. a miss, instead of hz_to_cents mapping it to 0 cents (55 Hz).
    ref_c = hz_to_cents(ref[ref_voiced])
    est_c = hz_to_cents(np.clip(est[ref_voiced], a_min=1e-5, a_max=None))
    delta = est_c - ref_c
    abs_delta = np.abs(delta)

    octaves = np.round(delta / 1200.0)
    is_octave_error = (
        (abs_delta > 50.0)
        & (octaves != 0)
        & (np.abs(delta - octaves * 1200.0) <= 50.0)
    )

    def _rate(mask: np.ndarray) -> float:
        # Fraction of reference-voiced frames where *mask* holds.
        return np.count_nonzero(mask) / n_voiced

    return {
        "RPA": _rate(abs_delta <= 50.0),
        "RCA": _rate(np.abs(circular_cents_distance(est_c, ref_c)) <= 50.0),
        "VUV": vuv,
        "OctaveError": _rate(is_octave_error),
    }


In [None]:

DATASET_CACHE: List[Dict[str, Any]] | None = None


def prepare_dataset_cache(force: bool = False) -> List[Dict[str, Any]]:
    """Load every evaluation utterance once and memoize it in ``DATASET_CACHE``.

    Each non-blank line of ``CONFIG["eval_list_path"]`` has the form
    ``path|...``; only the first ``|``-separated field is used. Relative paths
    get the same precedence as the other config paths in this notebook:
    repository root first, then the list file's directory.

    Every cached entry holds the mono waveform at ``TARGET_SAMPLE_RATE``, its
    sample rate, and the reference F0 track. Unreadable or empty files are
    skipped with a warning.

    Args:
        force: Rebuild the cache even if it is already populated.

    Raises:
        FileNotFoundError: The evaluation list is missing.
        RuntimeError: The list has no entries, or none of them could be loaded.
    """
    global DATASET_CACHE
    if DATASET_CACHE is not None and not force:
        return DATASET_CACHE

    eval_list = CONFIG.get("eval_list_path")
    if eval_list is None or not Path(eval_list).is_file():
        raise FileNotFoundError(f"Evaluation list not found: {eval_list}")
    eval_list_path = Path(eval_list)

    entries: List[Path] = []
    with open(eval_list_path, "r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            relative_path = line.split("|")[0].strip()
            # Resolving only against the list's directory broke lists written
            # relative to the repository root (the convention used for
            # ``val_data`` itself); reuse the shared resolver instead.
            entries.append(_resolve_relative_path(eval_list_path, relative_path))

    if not entries:
        raise RuntimeError(f"No evaluation files listed in {eval_list}")

    cache: List[Dict[str, Any]] = []
    for path in tqdm(entries, desc="Preparing evaluation cache"):
        result = load_waveform(path, TARGET_SAMPLE_RATE)
        if result is None:
            continue
        audio, sr = result
        reference_f0 = compute_reference_f0(audio, sr)
        cache.append(
            {
                "path": path,
                "audio": audio,
                "sample_rate": sr,
                "reference_f0": reference_f0,
            }
        )

    # Fail here with a clear message rather than later on an empty DataFrame.
    if not cache:
        raise RuntimeError(f"None of the files listed in {eval_list} could be loaded")

    DATASET_CACHE = cache
    print(f"Cached {len(DATASET_CACHE)} evaluation utterances.")
    return DATASET_CACHE


def evaluate_condition(
    model: JDCNet,
    dataset: List[Dict[str, Any]],
    transform_fn,
    label: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Score *model* on every cached utterance after applying a degradation.

    Args:
        model: Pitch model passed to ``predict_f0``.
        dataset: Entries produced by ``prepare_dataset_cache``.
        transform_fn: Maps a dataset entry to the waveform to evaluate.
        label: Condition descriptors copied into every output record.

    Returns:
        One record per entry: ``{"path": ..., **label, **metrics}``.
    """

    def _score(entry: Dict[str, Any]) -> Dict[str, Any]:
        degraded = transform_fn(entry)
        estimate = predict_f0(model, degraded)
        scores = compute_metrics(entry["reference_f0"], estimate)
        return {"path": str(entry["path"]), **label, **scores}

    return [_score(entry) for entry in dataset]


In [None]:

# Codec round-trips shell out to ffmpeg; fail fast if it is not installed.
# NOTE(review): this also aborts the resample-only parts of the notebook.
FFMPEG_PATH = shutil.which("ffmpeg")
if FFMPEG_PATH is None:
    raise EnvironmentError("ffmpeg executable not found. Install ffmpeg to run codec evaluations.")


# Resample transforms keyed by (source_rate, target_rate), reused across calls
# so each kernel is built only once.
TORCHAUDIO_RESAMPLER_CACHE: Dict[tuple[int, int], torchaudio.transforms.Resample] = {}


def _resample_audio(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample a mono waveform from *src_sr* to *dst_sr* Hz.

    Always returns a new float32 array (a copy when the rates already match).
    Input of any real dtype or memory layout is accepted. It is first
    converted to a contiguous float32 buffer, because the cached torchaudio
    kernels are float32 by default and ``torch.from_numpy`` rejects some
    layouts, such as negative strides. Rates are normalized to ``int`` so
    ``22050`` and ``22050.0`` share one cache entry.
    """
    src_sr, dst_sr = int(src_sr), int(dst_sr)
    if src_sr == dst_sr:
        return np.array(audio, dtype=np.float32, copy=True)
    samples = np.ascontiguousarray(audio, dtype=np.float32)
    key = (src_sr, dst_sr)
    transform = TORCHAUDIO_RESAMPLER_CACHE.get(key)
    if transform is None:
        transform = torchaudio.transforms.Resample(src_sr, dst_sr)
        TORCHAUDIO_RESAMPLER_CACHE[key] = transform
    with torch.no_grad():
        resampled = transform(torch.from_numpy(samples).unsqueeze(0))
    return resampled.squeeze(0).cpu().numpy().astype(np.float32)


def apply_resample_condition(entry: Dict[str, Any], target_rate: int) -> np.ndarray:
    """Simulate a bandwidth bottleneck by routing audio through *target_rate*.

    The entry's waveform is resampled to *target_rate* and, unless that is
    already ``TARGET_SAMPLE_RATE``, back to ``TARGET_SAMPLE_RATE`` so the model
    sees its usual input rate. Returns float32 samples.
    """
    bottleneck = _resample_audio(entry["audio"], entry["sample_rate"], target_rate)
    restored = (
        bottleneck
        if target_rate == TARGET_SAMPLE_RATE
        else _resample_audio(bottleneck, target_rate, TARGET_SAMPLE_RATE)
    )
    return restored.astype(np.float32)


def _ffmpeg_encode_decode(
    audio: np.ndarray,
    sample_rate: int,
    codec_key: str,
    ffmpeg_codec: str,
    extension: str,
    bitrate_kbps: int,
) -> np.ndarray:
    """Round-trip *audio* through an ffmpeg encoder and decode it back.

    Steps:

    1. Write the waveform to a temporary 32-bit float WAV.
    2. Encode it with *ffmpeg_codec* at *bitrate_kbps* into ``encoded<extension>``.
    3. Decode to a float WAV at ``TARGET_SAMPLE_RATE``.
    4. Read the result back as mono float32.

    All intermediate files live in a temporary directory that is removed on exit.

    Args:
        audio: Mono waveform to degrade.
        sample_rate: Sample rate of *audio* in Hz.
        codec_key: Human-readable codec name (used in error messages).
        ffmpeg_codec: ffmpeg encoder name, e.g. ``libopus``.
        extension: Container extension for the encoded file, e.g. ``.opus``.
        bitrate_kbps: Target bitrate in kbit/s.

    Raises:
        RuntimeError: ffmpeg exits non-zero while encoding or decoding; its
            stderr is included in the message.
    """

    def _run_ffmpeg(cmd: List[str], stage: str) -> None:
        # Run one ffmpeg invocation and raise with its stderr on failure.
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode != 0:
            # errors="replace": stderr is not guaranteed to be valid UTF-8, and a
            # UnicodeDecodeError here would mask the real ffmpeg failure.
            stderr = result.stderr.decode(errors="replace")
            raise RuntimeError(
                f"ffmpeg {stage} failed for {codec_key} @ {bitrate_kbps} kbps:\n{stderr}"
            )

    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir_path = Path(tmpdir)
        input_path = tmpdir_path / "input.wav"
        encoded_path = tmpdir_path / f"encoded{extension}"
        output_path = tmpdir_path / "decoded.wav"

        # Float WAV on both ends. soundfile's default WAV subtype (PCM_16) would
        # add 16-bit quantization and hard-clip samples outside [-1, 1], mixing
        # a second degradation into what should be a codec-only measurement.
        sf.write(str(input_path), audio, sample_rate, subtype="FLOAT")

        bitrate_arg = f"{int(bitrate_kbps)}k"
        _run_ffmpeg(
            [
                FFMPEG_PATH,
                "-y",
                "-loglevel",
                "error",
                "-i",
                str(input_path),
                "-c:a",
                ffmpeg_codec,
                "-b:a",
                bitrate_arg,
                str(encoded_path),
            ],
            "encoding",
        )
        _run_ffmpeg(
            [
                FFMPEG_PATH,
                "-y",
                "-loglevel",
                "error",
                "-i",
                str(encoded_path),
                "-ar",
                str(TARGET_SAMPLE_RATE),
                "-c:a",
                "pcm_f32le",
                str(output_path),
            ],
            "decoding",
        )

        degraded, sr = sf.read(str(output_path), dtype="float32")
        degraded = _ensure_mono(degraded)
        if sr != TARGET_SAMPLE_RATE:
            degraded = _resample_audio(degraded, sr, TARGET_SAMPLE_RATE)
        return degraded.astype(np.float32)


def apply_codec_condition(
    entry: Dict[str, Any],
    codec_key: str,
    codec_config: Dict[str, Any],
    bitrate_kbps: int,
) -> np.ndarray:
    """Return the entry's waveform after a lossy round-trip through *codec_key*.

    *codec_config* must provide ``"ffmpeg_codec"`` and ``"extension"``, as the
    entries of ``CONFIG["codecs"]`` do. The work is delegated to
    ``_ffmpeg_encode_decode``.
    """
    waveform = entry["audio"]
    rate = entry["sample_rate"]
    return _ffmpeg_encode_decode(
        audio=waveform,
        sample_rate=rate,
        codec_key=codec_key,
        ffmpeg_codec=codec_config["ffmpeg_codec"],
        extension=codec_config["extension"],
        bitrate_kbps=bitrate_kbps,
    )


In [None]:

# Load the checkpoint and cache every evaluation utterance (audio + reference
# F0) once; all sweeps below reuse them.
model = load_model()
dataset_cache = prepare_dataset_cache()
print(f"Loaded dataset cache with {len(dataset_cache)} entries at {TARGET_SAMPLE_RATE} Hz")


In [None]:

# Metrics reported for every condition, and whether larger values are better.
METRICS = ["RPA", "RCA", "VUV", "OctaveError"]
METRIC_DIRECTIONS = {"RPA": "higher", "RCA": "higher", "VUV": "higher", "OctaveError": "lower"}

# Baseline: the cached audio, unmodified. Every "drop" below is measured against it.
baseline_records = evaluate_condition(
    model,
    dataset_cache,
    lambda entry: entry["audio"],
    {"condition_type": "baseline"},
)
baseline_df = pd.DataFrame(baseline_records)
# Mean / std across utterances; pandas skips NaN metrics (utterances with no
# voiced reference frames) by default.
baseline_summary = baseline_df[METRICS].agg(["mean", "std"])
baseline_df


In [None]:
# Per-metric mean / std on the clean baseline.
baseline_summary

In [None]:

# Resampling sweep: each utterance is resampled to every rate in
# CONFIG["resample_rates_hz"] and back before inference.
resample_records: List[Dict[str, Any]] = []
for rate in CONFIG["resample_rates_hz"]:
    resample_records.extend(
        evaluate_condition(
            model,
            dataset_cache,
            # ``rate=rate`` binds the current loop value (avoids late binding).
            lambda entry, rate=rate: apply_resample_condition(entry, rate),
            {"condition_type": "resample", "resample_hz": rate},
        )
    )

resample_df = pd.DataFrame(resample_records)
resample_df.head()


In [None]:

def _summarize_with_drop(df: pd.DataFrame, group_cols: List[str]) -> pd.DataFrame:
    """Aggregate per-condition metrics and measure degradation versus the baseline.

    For every group in *group_cols*, the result has ``<metric>_mean`` and
    ``<metric>_std`` columns, plus ``<metric>_drop``. The drop is oriented by
    ``METRIC_DIRECTIONS`` so that a positive value always means the condition
    is worse than the baseline mean.
    """
    summary = df.groupby(group_cols)[METRICS].agg(["mean", "std"])
    summary.columns = ["_".join((metric, stat)) for metric, stat in summary.columns]
    summary = summary.reset_index()
    for metric in METRICS:
        reference = baseline_summary.loc["mean", metric]
        current = summary[f"{metric}_mean"]
        higher_is_better = METRIC_DIRECTIONS[metric] == "higher"
        summary[f"{metric}_drop"] = (reference - current) if higher_is_better else (current - reference)
    return summary


# One row per intermediate sample rate, in ascending order.
resample_summary = _summarize_with_drop(resample_df, ["resample_hz"])
resample_summary.sort_values("resample_hz", inplace=True)
resample_summary


In [None]:

from matplotlib.ticker import ScalarFormatter

# Absolute metric values per intermediate sample rate, with the clean baseline
# mean drawn as a dashed reference line.
fig, axes = plt.subplots(1, len(METRICS), figsize=(4 * len(METRICS), 4), sharex=False)
for metric, ax in zip(METRICS, axes):
    ax.plot(
        resample_summary["resample_hz"],
        resample_summary[f"{metric}_mean"],
        marker="o",
        label="Resampled",
    )
    ax.axhline(baseline_summary.loc["mean", metric], color="k", linestyle="--", label="Baseline")
    ax.set_title(metric)
    ax.set_xlabel("Sample rate (Hz)")
    if metric == "OctaveError":
        ax.set_ylabel("Error rate")
    else:
        ax.set_ylabel("Score")
    # Log axis with ticks exactly at the swept rates, labelled as plain numbers.
    ax.set_xscale("log")
    ax.set_xticks(CONFIG["resample_rates_hz"])
    ax.get_xaxis().set_major_formatter(ScalarFormatter())
axes[-1].legend()
plt.tight_layout()
plt.show()


In [None]:

# Degradation versus the baseline per sample rate; positive means worse
# (see METRIC_DIRECTIONS).
fig, axes = plt.subplots(1, len(METRICS), figsize=(4 * len(METRICS), 4))
for metric, ax in zip(METRICS, axes):
    ax.plot(
        resample_summary["resample_hz"],
        resample_summary[f"{metric}_drop"],
        marker="o",
        label="Drop",
    )
    ax.axhline(0.0, color="k", linestyle="--")
    ax.set_title(f"{metric} drop vs. baseline")
    ax.set_xlabel("Sample rate (Hz)")
    ax.set_xscale("log")
    ax.set_xticks(CONFIG["resample_rates_hz"])
    ax.get_xaxis().set_major_formatter(ScalarFormatter())
    ax.set_ylabel("Drop" if metric != "OctaveError" else "Increase")
axes[-1].legend()
plt.tight_layout()
plt.show()


In [None]:

# Codec sweep: every utterance is encoded and decoded through ffmpeg for each
# codec/bitrate pair. That is two subprocesses per utterance per condition, so
# this is likely the slowest cell.
codec_records: List[Dict[str, Any]] = []
for codec_key, codec_config in CONFIG["codecs"].items():
    for bitrate in codec_config["bitrates_kbps"]:
        codec_records.extend(
            evaluate_condition(
                model,
                dataset_cache,
                # Default arguments freeze the current loop values (avoids late binding).
                lambda entry, ck=codec_key, cfg=codec_config, br=bitrate: apply_codec_condition(
                    entry, ck, cfg, br
                ),
                {
                    "condition_type": "codec",
                    "codec": codec_key,
                    "bitrate_kbps": bitrate,
                },
            )
        )

codec_df = pd.DataFrame(codec_records)
codec_df.head()


In [None]:

# One row per (codec, bitrate), sorted so each codec's bitrates are ascending.
codec_summary = _summarize_with_drop(codec_df, ["codec", "bitrate_kbps"])
codec_summary.sort_values(["codec", "bitrate_kbps"], inplace=True)
codec_summary


In [None]:

# Absolute metric values versus bitrate, one line per codec, with the clean
# baseline mean as a dashed reference.
fig, axes = plt.subplots(len(METRICS), 1, figsize=(6, 3 * len(METRICS)), sharex=True)
for metric, ax in zip(METRICS, axes):
    for codec_key in codec_summary["codec"].unique():
        subset = codec_summary[codec_summary["codec"] == codec_key]
        ax.plot(
            subset["bitrate_kbps"],
            subset[f"{metric}_mean"],
            marker="o",
            label=codec_key.upper(),
        )
    ax.axhline(baseline_summary.loc["mean", metric], color="k", linestyle="--", label="Baseline")
    ax.set_title(metric)
    ax.set_ylabel("Score" if metric != "OctaveError" else "Error rate")
    ax.set_xlabel("Bitrate (kbps)")
# Legend placed outside the top axes so it does not cover the curves.
axes[0].legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
plt.tight_layout()
plt.show()


In [None]:

# Degradation versus the baseline for each codec and bitrate; positive means
# worse (see METRIC_DIRECTIONS).
fig, axes = plt.subplots(len(METRICS), 1, figsize=(6, 3 * len(METRICS)), sharex=True)
for metric, ax in zip(METRICS, axes):
    for codec_key in codec_summary["codec"].unique():
        subset = codec_summary[codec_summary["codec"] == codec_key]
        ax.plot(
            subset["bitrate_kbps"],
            subset[f"{metric}_drop"],
            marker="o",
            label=codec_key.upper(),
        )
    ax.axhline(0.0, color="k", linestyle="--")
    ax.set_title(f"{metric} drop vs. baseline")
    ax.set_ylabel("Drop" if metric != "OctaveError" else "Increase")
    ax.set_xlabel("Bitrate (kbps)")
axes[0].legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
plt.tight_layout()
plt.show()


In [None]:

# Persist per-utterance metrics and per-condition summaries as CSV files.
output_base = CONFIG["output_dir"]
output_base.mkdir(parents=True, exist_ok=True)

baseline_path = output_base / "codec_bandwidth_baseline_metrics.csv"
resample_path = output_base / "codec_bandwidth_resample_metrics.csv"
codec_path = output_base / "codec_bandwidth_codec_metrics.csv"
resample_summary_path = output_base / "codec_bandwidth_resample_summary.csv"
codec_summary_path = output_base / "codec_bandwidth_codec_summary.csv"

baseline_df.to_csv(baseline_path, index=False)
resample_df.to_csv(resample_path, index=False)
codec_df.to_csv(codec_path, index=False)
resample_summary.to_csv(resample_summary_path, index=False)
codec_summary.to_csv(codec_summary_path, index=False)

# NOTE(review): this message calls the baseline file "detailed metrics";
# "baseline metrics" would match the file name.
print(f"Saved detailed metrics to {baseline_path}")
print(f"Saved resample sweep metrics to {resample_path}")
print(f"Saved codec sweep metrics to {codec_path}")
print(f"Saved resample summary to {resample_summary_path}")
print(f"Saved codec summary to {codec_summary_path}")



## Next steps

* Inspect the CSV artifacts for deeper analysis or downstream reporting.
* Adjust `CONFIG["resample_rates_hz"]` and `CONFIG["codecs"]` to explore additional stress conditions.
* Consider combining these degradations (e.g., resampling **and** compressing) to probe compounded failure modes.
