
# Room and Microphone Stress Evaluation Notebook

This notebook probes where a pitch extraction model trained with the JDC-PitchExtractor repository starts to fail under room reverberation and microphone coloration.
It loads the latest checkpoint, prepares a clean evaluation cache, and then sweeps:

* **Room impulse responses** for small-room, office, and hall environments, selecting examples across a configurable T60 decay-time grid.
* **Microphone coloration curves** approximating smartphone, headset, and studio large-diaphragm condenser responses via cascaded biquad filters.

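For every condition the notebook reports standard melody metrics (RPA, RCA, voicing accuracy, octave-error rate). The room sweep plots these metrics against T60, and the microphone sweep plots their deltas relative to the clean baseline, highlighting robustness boundaries.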



## Environment Setup

Uncomment and execute the next cell if any of the required dependencies are missing in your environment.


In [None]:
# !pip install soundfile torchaudio torch pyyaml matplotlib librosa pyworld torchcrepe praat-parselmouth pandas tqdm


## Imports and Global Configuration

The repository root is added to `sys.path` so we can reuse the dataset and model helpers that ship with the project.


In [None]:
import math
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

REPO_ROOT = Path.cwd().resolve().parents[0]
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torchaudio
from torchaudio.functional import equalizer_biquad
from Utils.f0_notebook_utils import (
    build_notebook_f0_extractor,
    compute_f0_for_notebook,
)
from tqdm.auto import tqdm

from meldataset import DEFAULT_MEL_PARAMS
from model import JDCNet

# Shared figure styling for every plot in the notebook.
plt.style.use("seaborn-v0_8")
plt.rcParams.update({"axes.grid": True, "figure.figsize": (12, 4)})

# Run inference on the GPU whenever PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")



## User Configuration

Update the configuration dictionary with paths that exist on your machine.
Relative paths read from the training config (`val_data`, `log_dir`) and from `rir_library` are resolved against the repository root first, then against the config file's directory.
If `checkpoint_path` is left as `None`, the notebook falls back to the `*.pth` file in `checkpoint_dir` with the highest trailing number in its name (ties are broken by modification time).


In [None]:
CONFIG: Dict[str, Any] = {
    "config_path": REPO_ROOT / "Configs" / "config.yml",
    "checkpoint_dir": REPO_ROOT / "Checkpoint",
    "checkpoint_path": None,
    "eval_list_path": None,
    "output_dir": REPO_ROOT / "notebooks" / "artifacts",
    "chunk_size": 192,
    "chunk_overlap": 48,
    "mel_mean": -4.0,
    "mel_std": 4.0,
    "voicing_threshold_hz": 10.0,
    "rir_library": {
        "small_room": [
            {"path": REPO_ROOT / "ImpulseResponses" / "small_room_t60_0.30.wav", "t60": 0.30},
            {"path": REPO_ROOT / "ImpulseResponses" / "small_room_t60_0.60.wav", "t60": 0.60},
            {"path": REPO_ROOT / "ImpulseResponses" / "small_room_t60_0.90.wav", "t60": 0.90},
        ],
        "office": [
            {"path": REPO_ROOT / "ImpulseResponses" / "office_t60_0.45.wav", "t60": 0.45},
            {"path": REPO_ROOT / "ImpulseResponses" / "office_t60_0.80.wav", "t60": 0.80},
            {"path": REPO_ROOT / "ImpulseResponses" / "office_t60_1.20.wav", "t60": 1.20},
        ],
        "hall": [
            {"path": REPO_ROOT / "ImpulseResponses" / "hall_t60_0.80.wav", "t60": 0.80},
            {"path": REPO_ROOT / "ImpulseResponses" / "hall_t60_1.10.wav", "t60": 1.10},
            {"path": REPO_ROOT / "ImpulseResponses" / "hall_t60_1.50.wav", "t60": 1.50},
        ],
    },
    "t60_sweep": [round(x, 2) for x in np.linspace(0.2, 1.5, 14)],
    # Peaking-EQ stages; each centre frequency should lie below Nyquist (sample_rate / 2).
    # NOTE: the 12 kHz studio stage sits exactly at Nyquist for 24 kHz audio.
    "microphone_eq": {
        "smartphone": [
            {"freq": 180.0, "gain_db": -6.0, "Q": 0.8},
            {"freq": 3500.0, "gain_db": 5.0, "Q": 1.2},
            {"freq": 9000.0, "gain_db": 3.0, "Q": 1.0},
        ],
        "headset": [
            {"freq": 120.0, "gain_db": -2.0, "Q": 0.7},
            {"freq": 2400.0, "gain_db": 3.0, "Q": 1.4},
            {"freq": 6000.0, "gain_db": 2.5, "Q": 1.1},
        ],
        "studio_ldc": [
            {"freq": 80.0, "gain_db": 2.0, "Q": 0.9},
            {"freq": 4500.0, "gain_db": -1.5, "Q": 1.3},
            {"freq": 12000.0, "gain_db": 1.5, "Q": 0.9},
        ],
    },
}

# Coerce the top-level path entries to Path objects so plain strings also work
# (later cells call `.mkdir()` and use the `/` operator on them). Relative entries
# are resolved against the repository root; None entries are left for auto-discovery.
CONFIG.update({
    key: Path(value) if Path(value).is_absolute() else REPO_ROOT / value
    for key, value in CONFIG.items()
    if key in ("config_path", "checkpoint_dir", "checkpoint_path", "eval_list_path", "output_dir")
    and value is not None
})
CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
CONFIG



## Helper Utilities

The next cell defines shared helpers for path resolution, checkpoint discovery, audio preprocessing, model inference,
and evaluation. These utilities closely mirror the training-time preprocessing performed in `meldataset.MelDataset`.


In [None]:
# Working copy of the project's default mel-spectrogram settings. The notebook mutates
# this dict (legacy-key normalisation below, training-config overrides in
# _load_training_config), so DEFAULT_MEL_PARAMS itself is not modified. The copy is shallow.
MEL_PARAMS = DEFAULT_MEL_PARAMS.copy()

def _normalize_mel_params(params: Dict[str, Any]) -> Dict[str, Any]:
    if "win_len" in params:
        win_len = params.pop("win_len")
        params.setdefault("win_length", win_len)
    return params


def _deep_merge_dict(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    merged = base.copy()
    for key, value in overrides.items():
        if isinstance(value, dict):
            existing = merged.get(key)
            if isinstance(existing, dict):
                merged[key] = _deep_merge_dict(existing, value)
            else:
                merged[key] = value.copy()
        else:
            merged[key] = value
    return merged

# Rename any legacy ``win_len`` key in the working mel parameters (in place).
MEL_PARAMS = _normalize_mel_params(MEL_PARAMS)

# Fallbacks handed to the F0 extractor builder. The hop length falls back to a literal
# 300 samples rather than the window length: a window size is not a hop size.
FALLBACK_SAMPLE_RATE = int(MEL_PARAMS.get("sample_rate", 24000))
FALLBACK_HOP_LENGTH = int(MEL_PARAMS.get("hop_length", 300))
# F0 extraction settings; replaced from the training config by _load_training_config.
F0_PARAMS: Dict[str, Any] = {}
# Value passed as zero_fill_value to the reference F0 computation (overridden from f0_params).
F0_ZERO_FILL: float = 0.0
# Backend names already announced by compute_reference_f0, so each is printed only once.
REFERENCE_F0_BACKENDS_USED: set[str] = set()
# Initial extractor built from the default parameters; rebuilt once the training config is read.
NOTEBOOK_F0_EXTRACTOR = build_notebook_f0_extractor(
    MEL_PARAMS,
    F0_PARAMS,
    fallback_sr=FALLBACK_SAMPLE_RATE,
    fallback_hop=FALLBACK_HOP_LENGTH,
    verbose=False,
)

AUDIO_EXTENSIONS = {".wav", ".flac", ".ogg", ".mp3", ".m4a"}

def _resolve_relative_path(base: Path, candidate: str | Path) -> Path:
    base_dir = base if base.is_dir() else base.parent
    candidate_path = Path(candidate)
    if candidate_path.is_absolute():
        return candidate_path
    repo_candidate = (REPO_ROOT / candidate_path).resolve()
    config_candidate = (base_dir / candidate_path).resolve()
    if repo_candidate.exists():
        return repo_candidate
    if config_candidate.exists():
        return config_candidate
    return repo_candidate

def _latest_checkpoint(path: Path) -> Optional[Path]:
    if not path.is_dir():
        return None
    def _sort_key(p: Path) -> tuple[int, float]:
        numbers = [int(match) for match in re.findall(r"\d+", p.stem)]
        last = numbers[-1] if numbers else -1
        return last, p.stat().st_mtime
    candidates = sorted(path.glob("*.pth"), key=_sort_key)
    return candidates[-1] if candidates else None

def _load_training_config():
    """Read the training YAML config and propagate its settings into the notebook globals.

    When ``CONFIG["config_path"]`` points at an existing file, this function:
      * merges ``dataset_params.mel_params`` into ``MEL_PARAMS`` and applies
        ``dataset_params.sr`` as ``MEL_PARAMS["sample_rate"]``;
      * replaces ``F0_PARAMS`` with ``dataset_params.f0_params``, rebuilds
        ``NOTEBOOK_F0_EXTRACTOR`` and updates ``F0_ZERO_FILL``;
      * sets ``CONFIG["eval_list_path"]`` from ``val_data`` when it is unset;
      * sets ``CONFIG["checkpoint_dir"]`` from ``log_dir`` when the configured directory is missing.

    Independently of the config file, when ``CONFIG["checkpoint_path"]`` is ``None`` it is
    filled with the highest-numbered ``*.pth`` in ``CONFIG["checkpoint_dir"]``.

    Returns:
        The parsed YAML mapping, or ``{}`` when the config file is missing.
    """
    global NOTEBOOK_F0_EXTRACTOR, F0_ZERO_FILL
    data: Dict[str, Any] = {}
    config_path = CONFIG.get("config_path")
    if config_path and Path(config_path).is_file():
        import yaml

        config_path = Path(config_path)
        with open(config_path, "r", encoding="utf-8") as handle:
            data = yaml.safe_load(handle) or {}
        # `or {}` also covers sections written as an explicit YAML null.
        dataset_params = data.get("dataset_params") or {}

        MEL_PARAMS.update(dataset_params.get("mel_params") or {})
        _normalize_mel_params(MEL_PARAMS)
        dataset_sr = dataset_params.get("sr")
        if dataset_sr is not None:
            MEL_PARAMS["sample_rate"] = dataset_sr

        F0_PARAMS.clear()
        F0_PARAMS.update(dataset_params.get("f0_params") or {})
        # Rebuilt after the sample-rate override so the extractor receives the final MEL_PARAMS.
        NOTEBOOK_F0_EXTRACTOR = build_notebook_f0_extractor(
            MEL_PARAMS,
            F0_PARAMS,
            fallback_sr=FALLBACK_SAMPLE_RATE,
            fallback_hop=FALLBACK_HOP_LENGTH,
            verbose=True,
        )
        F0_ZERO_FILL = float(F0_PARAMS.get("zero_fill_value", 0.0))

        val_list = data.get("val_data")
        if val_list and CONFIG.get("eval_list_path") is None:
            CONFIG["eval_list_path"] = _resolve_relative_path(config_path, val_list)
            if not CONFIG["eval_list_path"].is_file():
                print(f"Warning: evaluation list not found at {CONFIG['eval_list_path']}")
        log_dir = data.get("log_dir")
        if log_dir and (not CONFIG.get("checkpoint_dir") or not Path(CONFIG["checkpoint_dir"]).is_dir()):
            CONFIG["checkpoint_dir"] = _resolve_relative_path(config_path, log_dir)
    else:
        print("Warning: config file not found; using DEFAULT_MEL_PARAMS.")

    # Checkpoint auto-discovery runs even without a config file, so load_model() still
    # finds a checkpoint when CONFIG["checkpoint_dir"] is valid.
    if CONFIG.get("checkpoint_path") is None:
        checkpoint_dir = CONFIG.get("checkpoint_dir")
        latest = _latest_checkpoint(Path(checkpoint_dir)) if checkpoint_dir else None
        if latest is not None:
            CONFIG["checkpoint_path"] = latest
        else:
            print("Warning: no checkpoints found; set CONFIG['checkpoint_path'] manually.")
    return data

# Read the training config once; as a side effect this updates MEL_PARAMS, F0_PARAMS and CONFIG.
TRAINING_CONFIG = _load_training_config()
# Default target rate for load_waveform: audio is resampled to this rate before feature extraction.
TARGET_SAMPLE_RATE = MEL_PARAMS["sample_rate"]
# Mel hop size in samples; predict_f0 emits one F0 value per mel frame.
HOP_LENGTH = MEL_PARAMS["hop_length"]
# Duration of one mel frame in milliseconds.
FRAME_PERIOD_MS = HOP_LENGTH * 1000.0 / TARGET_SAMPLE_RATE
# Mel front end built from the (possibly config-overridden) parameters and placed on DEVICE.
mel_transform = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS).to(DEVICE)


def _collect_model_configuration(checkpoint_state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Merge model hyper-parameters from the training config and the checkpoint.

    Later sources override earlier ones, in this order:
      1. ``TRAINING_CONFIG["model_params"]``
      2. ``checkpoint_state["model_params"]``
      3. ``checkpoint_state["config"]["model_params"]``

    Returns:
        ``(top_level, sequence_config)``. ``top_level`` holds every merged key except
        ``sequence_model`` and ``num_class``. ``sequence_config`` is a copy of the merged
        ``sequence_model`` dict, or ``{}`` if it is absent or not a dict.
    """
    layers: List[Any] = [TRAINING_CONFIG.get("model_params")]
    if isinstance(checkpoint_state, dict):
        layers.append(checkpoint_state.get("model_params"))
        nested_config = checkpoint_state.get("config")
        if isinstance(nested_config, dict):
            layers.append(nested_config.get("model_params"))

    merged: Dict[str, Any] = {}
    for layer in layers:
        if isinstance(layer, dict):
            merged = _deep_merge_dict(merged, layer)

    raw_sequence = merged.get("sequence_model")
    sequence_config = raw_sequence.copy() if isinstance(raw_sequence, dict) else {}
    top_level = {key: value for key, value in merged.items() if key not in ("sequence_model", "num_class")}
    return top_level, sequence_config

def _ensure_mono(audio: np.ndarray) -> np.ndarray:
    if audio.ndim == 1:
        return audio
    return audio.mean(axis=1)

def load_waveform(path: Path, target_sr: int = TARGET_SAMPLE_RATE) -> tuple[np.ndarray, int] | None:
    """Read ``path`` as mono float32 audio at ``target_sr``.

    Unreadable or empty files are reported with a warning and yield ``None``, so a
    single bad file does not abort a whole sweep.

    Returns:
        ``(samples, sample_rate)`` on success, otherwise ``None``.
    """
    try:
        samples, native_sr = sf.read(str(path), dtype="float32")
    except Exception as exc:  # best effort; KeyboardInterrupt is not an Exception and propagates
        print(f"Warning: skipping unreadable file '{path}': {exc}")
        return None
    samples = _ensure_mono(samples)
    if not samples.size:
        print(f"Warning: skipping empty file '{path}'.")
        return None
    if native_sr == target_sr:
        return samples.astype(np.float32), native_sr
    resampled = torchaudio.functional.resample(torch.from_numpy(samples)[None, :], native_sr, target_sr)
    return resampled[0].cpu().numpy().astype(np.float32), target_sr

def compute_reference_f0(audio: np.ndarray, sr: int) -> np.ndarray:
    """Compute the reference F0 track for ``audio`` with the notebook's configured extractor.

    The first time a given backend is used, its name is printed.
    """
    outcome = compute_f0_for_notebook(audio, sr, NOTEBOOK_F0_EXTRACTOR, zero_fill_value=F0_ZERO_FILL)
    backend = outcome.backend_name
    if backend and backend not in REFERENCE_F0_BACKENDS_USED:
        REFERENCE_F0_BACKENDS_USED.add(backend)
        print(f"Using F0 backend '{backend}' for reference computation.")
    return outcome.f0


def waveform_to_mel(audio: np.ndarray) -> torch.Tensor:
    """Convert a mono waveform into a normalised log-mel spectrogram on DEVICE.

    The log is taken with a 1e-5 floor, then standardised with ``CONFIG["mel_mean"]`` and
    ``CONFIG["mel_std"]``. The leading batch axis added for ``mel_transform`` is removed again.
    """
    batch = torch.from_numpy(audio).float().unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        power = mel_transform(batch)
        log_mel = torch.log(power + 1e-5)
    standardised = (log_mel - CONFIG["mel_mean"]) / CONFIG["mel_std"]
    return standardised.squeeze(0)

def predict_f0(model: JDCNet, audio: np.ndarray) -> np.ndarray:
    """Run the model over ``audio`` in overlapping mel chunks; return one F0 value per mel frame.

    The normalised log-mel spectrogram is split into windows of ``CONFIG["chunk_size"]``
    frames. Windows advance by ``chunk_size - chunk_overlap`` frames, and each one is
    right-padded to the full chunk size before inference. Where two windows overlap, the
    first half of the overlap comes from the earlier window and the rest from the later
    one. Every mel frame is therefore written exactly once, and the output stays aligned
    with the input frame grid. (Concatenating whole chunks would repeat the overlapped
    frames and shift every later prediction relative to the reference.)

    Args:
        model: Loaded JDCNet in eval mode. It is called as ``model(mel_chunk)`` and is
            assumed to return a pair whose first element is the per-frame F0.
        audio: Mono waveform at ``TARGET_SAMPLE_RATE``.

    Returns:
        float32 array with one entry per mel frame (empty for zero-frame input).
    """
    mel = waveform_to_mel(audio)
    total_frames = int(mel.shape[-1])
    chunk_size = int(CONFIG["chunk_size"])
    # Clamp so the hop is at least one frame (same guard as max(chunk - overlap, 1)).
    overlap = min(max(int(CONFIG["chunk_overlap"]), 0), chunk_size - 1)
    step = chunk_size - overlap
    half_overlap = overlap // 2
    output = np.zeros((total_frames,), dtype=np.float32)
    for start in range(0, total_frames, step):
        end = min(start + chunk_size, total_frames)
        mel_chunk = mel[:, start:end]
        pad = chunk_size - mel_chunk.shape[-1]
        if pad > 0:
            mel_chunk = torch.nn.functional.pad(mel_chunk, (0, pad))
        mel_chunk = mel_chunk.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
        with torch.no_grad():
            f0_chunk, _ = model(mel_chunk)
        f0_chunk = np.atleast_1d(f0_chunk.squeeze().detach().cpu().numpy())
        # Skip the leading half of the overlap (already covered by the previous window),
        # so the frames at each window's edge come from the window with more context.
        keep_from = start if start == 0 else min(start + half_overlap, end)
        output[keep_from:end] = f0_chunk[keep_from - start : end - start]
        if end >= total_frames:
            # Later windows would lie entirely inside frames already written.
            break
    return output

def hz_to_cents(f0: np.ndarray) -> np.ndarray:
    """Convert frequencies in Hz to cents relative to 55 Hz; non-positive entries map to 0.

    The result is always floating point (integer input is promoted), so fractional cents
    are never truncated.
    """
    f0 = np.asarray(f0)
    cents = np.zeros(f0.shape, dtype=np.result_type(f0.dtype, np.float32))
    positive = f0 > 0
    cents[positive] = 1200.0 * np.log2(f0[positive] / 55.0)
    return cents

def circular_cents_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Signed difference ``a - b`` wrapped into one octave, i.e. the range [-600, 600)."""
    diff = a - b
    diff = np.mod(diff + 600.0, 1200.0) - 600.0
    return diff

def compute_metrics(
    reference: np.ndarray,
    prediction: np.ndarray,
    voicing_threshold_hz: Optional[float] = None,
) -> Dict[str, float]:
    """Frame-level melody metrics of ``prediction`` against ``reference`` (both in Hz).

    Both tracks are truncated to the shorter length. A reference frame is voiced when it
    is > 0 Hz. A predicted frame is voiced when it exceeds ``voicing_threshold_hz``, which
    defaults to ``CONFIG["voicing_threshold_hz"]``.

    Returns:
        Dict with
          * ``RPA``: fraction of reference-voiced frames predicted within 50 cents;
          * ``RCA``: as RPA, but octave-agnostic (chroma within 50 cents);
          * ``VUV``: fraction of all frames whose voiced/unvoiced decision matches;
          * ``OctaveError``: fraction of reference-voiced frames off by a non-zero
            whole number of octaves (within 50 cents).
        RPA, RCA and OctaveError are NaN when the reference has no voiced frame.

    RCA and OctaveError only credit frames the model marks as voiced. Unvoiced predictions
    are clipped to 1e-5 Hz before the cents conversion. For reference pitches in roughly
    1/12 of the octave, that arbitrary value lands within 50 cents chroma of the reference,
    and it would otherwise be counted as a chroma hit or an octave error.
    """
    threshold = CONFIG["voicing_threshold_hz"] if voicing_threshold_hz is None else voicing_threshold_hz
    length = min(reference.shape[0], prediction.shape[0])
    reference = reference[:length]
    prediction = prediction[:length]
    ref_voiced = reference > 0
    pred_voiced = prediction > threshold
    voiced_frames = int(np.count_nonzero(ref_voiced))
    vuv_accuracy = float(np.count_nonzero(ref_voiced == pred_voiced) / max(length, 1))
    if voiced_frames == 0:
        return {
            "RPA": float("nan"),
            "RCA": float("nan"),
            "VUV": vuv_accuracy,
            "OctaveError": float("nan"),
        }
    ref_cents = hz_to_cents(reference[ref_voiced])
    pred_cents = hz_to_cents(np.clip(prediction[ref_voiced], a_min=1e-5, a_max=None))
    pred_voiced_on_ref = pred_voiced[ref_voiced]
    cents_diff = pred_cents - ref_cents
    rpa_hits = np.abs(cents_diff) <= 50.0
    rca_hits = pred_voiced_on_ref & (np.abs(circular_cents_distance(pred_cents, ref_cents)) <= 50.0)
    # A non-zero octave count with a residual within 50 cents already implies |diff| > 50.
    octave_numbers = np.round(cents_diff / 1200.0)
    octave_errors = (
        pred_voiced_on_ref
        & (octave_numbers != 0)
        & (np.abs(cents_diff - octave_numbers * 1200.0) <= 50.0)
    )
    return {
        "RPA": float(np.count_nonzero(rpa_hits) / voiced_frames),
        "RCA": float(np.count_nonzero(rca_hits) / voiced_frames),
        "VUV": vuv_accuracy,
        "OctaveError": float(np.count_nonzero(octave_errors) / voiced_frames),
    }

DATASET_CACHE: Optional[List[Dict[str, Any]]] = None

def _resolve_eval_entry(list_dir: Path, raw_path: str) -> Path:
    """Resolve one evaluation-list path.

    Absolute paths are returned as-is. Relative paths resolve against the list's directory,
    falling back to the repository root only when the list-relative file does not exist.
    """
    candidate = Path(raw_path)
    if candidate.is_absolute():
        return candidate
    list_relative = (list_dir / candidate).resolve()
    if list_relative.exists():
        return list_relative
    repo_relative = (REPO_ROOT / candidate).resolve()
    return repo_relative if repo_relative.exists() else list_relative

def prepare_dataset_cache(force: bool = False) -> List[Dict[str, Any]]:
    """Load every file in the evaluation list together with its reference F0 and cache it.

    Each list line is ``path|...``; only the first field is used. Unreadable files are
    skipped with a warning by ``load_waveform``.

    Args:
        force: Rebuild the cache even if one already exists.

    Returns:
        The list of cached entries (``path``, ``audio``, ``sample_rate``, ``reference_f0``).

    Raises:
        FileNotFoundError: ``CONFIG["eval_list_path"]`` is unset or missing.
        RuntimeError: the list has no entries, or none of its files could be loaded.
    """
    global DATASET_CACHE
    if DATASET_CACHE is not None and not force:
        return DATASET_CACHE
    eval_list = CONFIG.get("eval_list_path")
    if eval_list is None or not Path(eval_list).is_file():
        raise FileNotFoundError(f"Evaluation list not found: {eval_list}")
    list_dir = Path(eval_list).parent
    entries: List[Path] = []
    with open(eval_list, "r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            entries.append(_resolve_eval_entry(list_dir, line.split("|")[0].strip()))
    if not entries:
        raise RuntimeError(f"No evaluation files located in {eval_list}")
    cache: List[Dict[str, Any]] = []
    for path in tqdm(entries, desc="Preparing evaluation cache"):
        result = load_waveform(path)
        if result is None:
            continue
        audio, sr = result
        reference_f0 = compute_reference_f0(audio, sr)
        cache.append({
            "path": path,
            "audio": audio,
            "sample_rate": sr,
            "reference_f0": reference_f0,
        })
    if not cache:
        # Fail here with a clear message instead of a KeyError in the baseline cell.
        raise RuntimeError(f"None of the {len(entries)} files listed in {eval_list} could be loaded")
    DATASET_CACHE = cache
    print(f"Cached {len(DATASET_CACHE)} evaluation utterances.")
    return DATASET_CACHE


def load_model(checkpoint_path: Optional[Path] = None) -> JDCNet:
    """Build a JDCNet from a checkpoint and move it to DEVICE in eval mode.

    Args:
        checkpoint_path: Checkpoint file; defaults to ``CONFIG["checkpoint_path"]``.

    The weights are taken from ``checkpoint["model"]``, then ``checkpoint["state_dict"]``,
    then the checkpoint dict itself. The class count is inferred from
    ``classifier.weight`` when present. Loading is non-strict: skipped and missing
    parameter names are printed as warnings.

    Raises:
        FileNotFoundError: no checkpoint is configured or the file does not exist.
        RuntimeError: the checkpoint is not a dict or has no usable state dict.
    """
    resolved_path = checkpoint_path or CONFIG.get("checkpoint_path")
    if resolved_path is None:
        raise FileNotFoundError(
            "Checkpoint not found: pass checkpoint_path or set CONFIG['checkpoint_path'] "
            "(or a CONFIG['checkpoint_dir'] containing *.pth files)."
        )
    checkpoint_path = Path(resolved_path)
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    if not isinstance(checkpoint, dict):
        raise RuntimeError("Unexpected checkpoint format")
    model_state = checkpoint.get("model")
    if model_state is None:
        model_state = checkpoint.get("state_dict", checkpoint)
    if not isinstance(model_state, dict):
        raise RuntimeError("Checkpoint is missing a valid model state")
    model_state = dict(model_state)
    # Checkpoints saved from a DataParallel-wrapped model prefix every key with "module.".
    # With strict=False below, none of those keys would match and the model would keep
    # its random initialisation, so strip the prefix when every key carries it.
    if model_state and all(isinstance(key, str) and key.startswith("module.") for key in model_state):
        model_state = {key.removeprefix("module."): value for key, value in model_state.items()}
    classifier_weight = model_state.get("classifier.weight")
    if classifier_weight is not None:
        inferred_classes = int(classifier_weight.shape[0])
    else:
        inferred_classes = int(checkpoint.get("num_class", CONFIG.get("num_class", 722)))
    if inferred_classes <= 0:
        inferred_classes = 722
    model_params, sequence_model_config = _collect_model_configuration(checkpoint)
    model_kwargs: Dict[str, Any] = {}
    if sequence_model_config:
        model_kwargs["sequence_model_config"] = sequence_model_config
    leaky_relu_slope = model_params.get("leaky_relu_slope")
    if isinstance(leaky_relu_slope, (int, float)):
        model_kwargs["leaky_relu_slope"] = float(leaky_relu_slope)
    print(f"Instantiating JDCNet with {inferred_classes} classes")
    model = JDCNet(num_class=inferred_classes, **model_kwargs)
    incompatible = model.load_state_dict(model_state, strict=False)
    if incompatible.unexpected_keys:
        skipped = ", ".join(sorted(incompatible.unexpected_keys))
        print(f"Warning: skipped unexpected parameters: {skipped}")
    if incompatible.missing_keys:
        missing = ", ".join(sorted(incompatible.missing_keys))
        print(f"Warning: missing parameters not found in checkpoint: {missing}")
    model.to(DEVICE).eval()
    print(f"Loaded checkpoint from {checkpoint_path}")
    return model


def _load_rir_waveform(path: Path) -> np.ndarray | None:
    """Load an impulse response at TARGET_SAMPLE_RATE and scale its peak to about 1.

    Returns ``None`` when the file cannot be read or is empty.
    """
    loaded = load_waveform(path, TARGET_SAMPLE_RATE)
    if loaded is None:
        return None
    impulse = loaded[0]
    if not impulse.size:
        print(f"Warning: impulse response at '{path}' is empty; skipping.")
        return None
    peak = np.max(np.abs(impulse))
    return (impulse / (peak + 1e-6)).astype(np.float32)

def _infer_t60_from_name(path: Path) -> Optional[float]:
    match = re.search(r"t60[_=]?([0-9]+(?:\.[0-9]+)?)", path.stem, re.IGNORECASE)
    if match:
        return float(match.group(1))
    match = re.search(r"([0-9]+(?:\.[0-9]+)?)s", path.stem, re.IGNORECASE)
    if match:
        return float(match.group(1))
    return None

RIR_CACHE: Dict[str, List[Dict[str, Any]]] = {}

def resolve_rir_library(force: bool = False) -> Dict[str, List[Dict[str, Any]]]:
    """Load every impulse response listed in ``CONFIG["rir_library"]`` and cache the result.

    Entries may be dicts with ``path`` and optional ``t60``, or bare paths. Relative paths
    are resolved against the repository root, then against the config file's directory.
    A missing ``t60`` is parsed from the file name when possible. Missing or unreadable
    files are skipped with a warning.

    Args:
        force: Reload even if a non-empty cache exists. An empty cache is always rebuilt.

    Returns:
        Mapping ``room -> [{"path", "t60", "audio"}, ...]`` for rooms with at least one valid RIR.
    """
    global RIR_CACHE
    if RIR_CACHE and not force:
        return RIR_CACHE
    resolved: Dict[str, List[Dict[str, Any]]] = {}
    # `or` also covers an explicit `config_path: None`, which .get(key, default) passes through.
    base_config = Path(CONFIG.get("config_path") or REPO_ROOT)
    for room, entries in (CONFIG.get("rir_library") or {}).items():
        room_entries: List[Dict[str, Any]] = []
        for entry in entries:
            if isinstance(entry, dict):
                candidate = entry.get("path")
                t60 = entry.get("t60")
            else:
                candidate = entry
                t60 = None
            if candidate is None:
                continue
            path = _resolve_relative_path(base_config, candidate)
            if not path.is_file():
                print(f"Warning: RIR file not found ({path})")
                continue
            rir_audio = _load_rir_waveform(path)
            if rir_audio is None:
                continue
            if t60 is None:
                t60 = _infer_t60_from_name(path)
            room_entries.append({
                "path": path,
                "t60": t60,
                "audio": rir_audio,
            })
        if room_entries:
            resolved[room] = room_entries
        else:
            print(f"Warning: no valid RIRs configured for room '{room}'.")
    RIR_CACHE = resolved
    return RIR_CACHE

def _select_rir(room: str, target_t60: float) -> Optional[Dict[str, Any]]:
    """Return the RIR entry for ``room`` whose T60 is nearest ``target_t60``.

    Entries without a known T60 are ignored. Returns ``None`` if nothing qualifies. On a
    tie the entry listed first wins.
    """
    with_t60 = (entry for entry in resolve_rir_library().get(room, []) if entry.get("t60") is not None)
    return min(with_t60, key=lambda entry: abs(float(entry["t60"]) - target_t60), default=None)

def apply_rir(audio: np.ndarray, rir: np.ndarray) -> np.ndarray:
    """Reverberate ``audio`` by causal convolution with ``rir``.

    Computes ``y[n] = sum_k rir[k] * audio[n - k]`` for the first ``len(audio)`` samples,
    so the output stays sample-aligned with the input and the reverb tail past the end is
    dropped. The convolution is done in the frequency domain (real FFT), which costs
    O((N + L) log(N + L)) rather than O(N * L). This matters for hall RIRs that are tens of
    thousands of taps long. If the result peaks above 0.99 it is peak-normalised to avoid
    clipping.

    Returns:
        float32 array with the same length as ``audio`` (empty input gives empty output).
    """
    signal = np.asarray(audio, dtype=np.float64).reshape(-1)
    kernel = np.asarray(rir, dtype=np.float64).reshape(-1)
    num_samples = signal.shape[0]
    if num_samples == 0 or kernel.shape[0] == 0:
        return np.zeros(num_samples, dtype=np.float32)
    full_length = num_samples + kernel.shape[0] - 1
    # Zero-pad to a power of two >= the full linear-convolution length to avoid wrap-around.
    fft_size = 1 << (full_length - 1).bit_length()
    spectrum = np.fft.rfft(signal, fft_size) * np.fft.rfft(kernel, fft_size)
    result = np.fft.irfft(spectrum, fft_size)[:num_samples]
    peak = float(np.max(np.abs(result)))
    if peak > 0.99:
        result = result / (peak + 1e-6)
    return result.astype(np.float32)

def apply_microphone_eq(audio: np.ndarray, sample_rate: int, curve: Iterable[Dict[str, float]]) -> np.ndarray:
    """Colour ``audio`` with a cascade of peaking-EQ biquads described by ``curve``.

    Each stage is a dict with ``freq`` (centre frequency in Hz, default 1000), ``gain_db``
    (default 0) and ``Q`` (default 0.707). Stages are applied in order with
    ``torchaudio.functional.equalizer_biquad``.

    A stage is skipped with a warning when its centre frequency is not strictly between
    0 Hz and Nyquist, or its Q is not positive. The biquad design degenerates at Nyquist
    and can become unstable above it.

    ``equalizer_biquad`` clamps each stage's output to [-1, 1], which would add hard
    clipping on top of the intended coloration. The signal is therefore attenuated by a
    fixed headroom factor before filtering and restored afterwards; the filters are linear,
    so the response is unchanged. If the coloured signal then peaks above 0.99, it is
    peak-normalised in the same way as ``apply_rir``.

    Returns:
        float32 array with the same number of samples as ``audio``.
    """
    import warnings

    # About 18 dB of headroom: a stage only clips if the EQ raises the peak more than 8x.
    headroom = 8.0
    nyquist = 0.5 * float(sample_rate)
    result = torch.from_numpy(np.ascontiguousarray(audio, dtype=np.float32)).unsqueeze(0) / headroom
    for stage in curve:
        freq = float(stage.get("freq", 1000.0))
        gain = float(stage.get("gain_db", 0.0))
        q = float(stage.get("Q", 0.707))
        if not (0.0 < freq < nyquist) or q <= 0.0:
            # warnings de-duplicates repeated messages, so this is not printed per utterance.
            warnings.warn(
                f"Skipping EQ stage at {freq:.1f} Hz (Q={q:g}): the centre frequency must lie "
                f"strictly between 0 Hz and Nyquist ({nyquist:.1f} Hz) and Q must be positive.",
                stacklevel=2,
            )
            continue
        result = equalizer_biquad(result, sample_rate, freq, gain, q)
    output = (result * headroom).squeeze(0).cpu().numpy().astype(np.float32)
    peak = float(np.max(np.abs(output))) if output.size else 0.0
    if peak > 0.99:
        output = output / (peak + 1e-6)
    return output.astype(np.float32)

def evaluate_condition(model: JDCNet, transform_fn, label: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Score every cached utterance after applying ``transform_fn`` to it.

    Args:
        model: The loaded pitch model.
        transform_fn: Callable taking a cache entry and returning the degraded waveform.
        label: Condition descriptors copied into every record (e.g. room, microphone).

    Returns:
        One record per utterance: the ``label`` keys, then ``path``, then the metric keys.
    """
    def _score(entry: Dict[str, Any]) -> Dict[str, Any]:
        degraded_audio = transform_fn(entry)
        estimate = predict_f0(model, degraded_audio)
        scores = compute_metrics(entry["reference_f0"], estimate)
        return {**label, "path": str(entry["path"]), **scores}

    return [_score(entry) for entry in (DATASET_CACHE or [])]


## Load Model and Evaluation Set

Run the next cell after updating the configuration to load the trained checkpoint and prepare the clean evaluation cache.


In [None]:
# Load the checkpoint (CONFIG["checkpoint_path"] unless overridden) and rebuild the clean
# evaluation cache from scratch. The final expression displays the number of cached utterances.
model = load_model()
DATASET_CACHE = prepare_dataset_cache(force=True)
len(DATASET_CACHE)



## Baseline (Clean) Evaluation

Compute clean-mix metrics to use as the reference point for subsequent delta calculations.


In [None]:
# Score the untouched audio once; the sweeps below are compared against these numbers.
baseline_records: List[Dict[str, Any]] = [
    {
        "path": str(entry["path"]),
        **compute_metrics(entry["reference_f0"], predict_f0(model, entry["audio"])),
    }
    for entry in tqdm(DATASET_CACHE, desc="Evaluating clean baseline")
]
baseline_df = pd.DataFrame(baseline_records)
# Mean over utterances (pandas skips NaN, e.g. utterances without voiced reference frames).
baseline_summary = baseline_df.loc[:, ["RPA", "RCA", "VUV", "OctaveError"]].mean()
display(baseline_summary.to_frame(name="Baseline"))



## Room Impulse Response Sweep

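For each room category and each target in `CONFIG['t60_sweep']`, the notebook selects the impulse response whose T60 is closest to the target. The T60 comes from the `rir_library` entry or is parsed from the file name. Each room provides only a few impulse responses, so neighbouring targets often map to the same file and the curves look like steps.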
Metrics are averaged per condition and plotted against the T60 grid to expose decay-time failure thresholds.


In [None]:
rir_results: List[Dict[str, Any]] = []
t60_targets = CONFIG.get("t60_sweep", [])
for room in CONFIG.get("rir_library", {}):
    # Neighbouring targets usually snap to the same impulse response. Each RIR is therefore
    # evaluated once per room, and its per-utterance records are reused for every target
    # that selects it. This gives the same table without repeating identical work.
    records_by_rir: Dict[Tuple[str, Any], List[Dict[str, Any]]] = {}
    for target_t60 in tqdm(t60_targets, desc=f"RIR sweep ({room})"):
        rir_entry = _select_rir(room, float(target_t60))
        if rir_entry is None:
            continue
        rir_key = str(rir_entry.get("path"))
        cache_key = (rir_key, rir_entry.get("t60"))
        if cache_key not in records_by_rir:
            actual_t60 = rir_entry.get("t60")
            def _transform(sample, rir_waveform=rir_entry["audio"]):
                return apply_rir(sample["audio"], rir_waveform)
            records_by_rir[cache_key] = evaluate_condition(model, _transform, {
                "room": room,
                "rir_t60": float(actual_t60) if actual_t60 is not None else np.nan,
                "rir_path": rir_key,
            })
        # Column order: room, target_t60, rir_t60, rir_path, path, metrics.
        rir_results.extend(
            {"room": record["room"], "target_t60": float(target_t60), **record}
            for record in records_by_rir[cache_key]
        )
rir_df = pd.DataFrame(rir_results)
if not rir_df.empty:
    rir_summary = (
        rir_df
        .groupby(["room", "target_t60", "rir_t60"])
        [["RPA", "RCA", "VUV", "OctaveError"]]
        .mean()
        .reset_index()
    )
    display(rir_summary.head())
else:
    # Reset so the plotting and saving cells cannot reuse a summary from an earlier run.
    rir_summary = pd.DataFrame()
    print("No RIR results computed. Check CONFIG['rir_library'].")



### Plot: Metric vs T60

The plot below shows averaged RPA and RCA across the T60 sweep for each room category.


In [None]:
if 'rir_summary' not in globals() or rir_summary.empty:
    print("No RIR summary available to plot.")
else:
    # One panel per metric, one line per room, sharing the y axis for direct comparison.
    metrics_to_plot = ['RPA', 'RCA']
    fig, axes = plt.subplots(1, 2, figsize=(16, 5), sharey=True)
    for ax, metric in zip(axes, metrics_to_plot):
        for room, group in rir_summary.groupby('room'):
            ax.plot(group['target_t60'], group[metric], marker='o', label=room)
        ax.set_title(f'{metric} vs T60')
        ax.set_xlabel('Target T60 (s)')
        ax.set_ylabel(metric)
        ax.set_xlim(min(CONFIG['t60_sweep']), max(CONFIG['t60_sweep']))
        ax.legend()
    plt.tight_layout()
    plt.show()



## Microphone Coloration Sweep

Each microphone profile is implemented as a cascade of peaking-EQ biquad filters (`torchaudio.functional.equalizer_biquad`), one per entry in the curve.
We report deltas relative to the clean baseline for RPA and RCA to emphasise degradation due to coloration.


In [None]:
mic_results: List[Dict[str, Any]] = []
for mic_name, curve in tqdm(list(CONFIG.get("microphone_eq", {}).items()), desc="Microphone sweep"):
    # Bind the curve as a default argument so each closure keeps its own EQ profile.
    def _transform(sample, eq_curve=curve):
        return apply_microphone_eq(sample["audio"], sample["sample_rate"], eq_curve)
    mic_results.extend(evaluate_condition(model, _transform, {"microphone": mic_name}))
mic_df = pd.DataFrame(mic_results)
if not mic_df.empty:
    mic_summary = (
        mic_df
        .groupby('microphone')
        [["RPA", "RCA", "VUV", "OctaveError"]]
        .mean()
    )
    # Positive values mean the coloured audio scored better than the clean baseline.
    delta_df = pd.DataFrame({
        "RPA_delta": mic_summary['RPA'] - baseline_summary['RPA'],
        "RCA_delta": mic_summary['RCA'] - baseline_summary['RCA'],
    })
    display(delta_df)
else:
    # Reset so the plotting and saving cells cannot reuse deltas from an earlier run.
    mic_summary = pd.DataFrame()
    delta_df = pd.DataFrame()
    print("No microphone results computed. Check CONFIG['microphone_eq'].")



### Plot: Microphone-Induced Metric Delta

Bar plots showing deviation from the clean baseline for RPA and RCA.


In [None]:
if 'delta_df' not in globals() or delta_df.empty:
    print("No microphone delta data available for plotting.")
else:
    ax = delta_df.loc[:, ['RPA_delta', 'RCA_delta']].plot(kind='bar', figsize=(10, 5))
    ax.set_title('Microphone coloration impact on RPA / RCA')
    ax.set_ylabel('Delta relative to clean')
    # Reference line: bars above/below it beat/trail the clean baseline.
    ax.axhline(0.0, color='black', linewidth=1.0)
    plt.tight_layout()
    plt.show()



## Saving Results

Optional cell to persist the aggregated DataFrames for later analysis.


In [None]:
# Persist the aggregated tables. Each CSV is written only when the corresponding summary
# exists in the notebook namespace and is non-empty.
OUTPUT_DIR = CONFIG["output_dir"]
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
if 'rir_summary' in globals() and not rir_summary.empty:
    # One row per (room, target_t60, rir_t60) with metrics averaged over utterances.
    rir_summary.to_csv(OUTPUT_DIR / "rir_sweep_metrics.csv", index=False)
if 'delta_df' in globals() and not delta_df.empty:
    # The index (microphone name) is written so each delta row stays labelled.
    delta_df.to_csv(OUTPUT_DIR / "microphone_delta_metrics.csv")
print(f"Artifacts written to {OUTPUT_DIR}")
