# 02 -  Teacher Calibration and Sanity Checks

The purpose of this notebook is to calibrate all SMAD teachers (AST, Whisper-AT, Laion CLAP Large, M2D-CLAP) against BLOCS-SMAD-GOLD, then export thresholds for the pipeline.

## How to use this notebook

This notebook assumes you have:

1. Synced audio and segments with `python scripts/sync_b2_data.py`.
2. Built the acoustic-stats manifest: `python -m data_processing.build_acoustic_stats`.
3. Run all four teacher scripts: e.g., `python -m data_processing.teachers.apply_ast`.
4. Obtained the BLOCS-SMAD-GOLD labels (loaded below from a CSV file) from the B2 sync in `data/metadata/`.

If you want to tweak thresholds, make a personal copy in `notebooks/local` (e.g., `02_teacher_calibration_<yourname>.ipynb`).

## 1. Load configuration, gold labels, and teacher outputs

+ Configure paths and load all relevant datasets.

In [None]:
from pathlib import Path
import sys

import pandas as pd
from datasets import load_from_disk

from utils.config_utils import load_env, add_project_root_to_path

# Bootstrap before importing any other project module: load the environment
# and put the project root on sys.path first, so the `config` and
# `utils.dams_types` imports below can benefit from it on a fresh kernel.
load_env()
add_project_root_to_path()

from config import get_settings
from utils.dams_types import (
    BLOCS_SMAD_V1,
    BLOCS_SMAD_V2_AST,
    BLOCS_SMAD_V2_CLAP,
    BLOCS_SMAD_V2_M2D,
    BLOCS_SMAD_V2_WHISPER,
    CSV_BLOCS_SMAD_GOLD_ANNOTATIONS
)

# Resolve the metadata directory from the project settings.
settings = get_settings()
metadata_dir: Path = settings.metadata_path
print(f"Metadata dir: {metadata_dir}")

# BLOCS SMAD v1 (acoustic stats and QC flags): keep the on-disk dataset handle
# as well as its pandas view.
acoustic_ds = load_from_disk(metadata_dir / BLOCS_SMAD_V1)
acoustic_df = acoustic_ds.to_pandas()

# BLOCS SMAD v2: one dataset per teacher, converted straight to pandas.
ast_smad_v2 = load_from_disk(metadata_dir / BLOCS_SMAD_V2_AST).to_pandas()
clap_smad_v2 = load_from_disk(metadata_dir / BLOCS_SMAD_V2_CLAP).to_pandas()
m2d_smad_v2 = load_from_disk(metadata_dir / BLOCS_SMAD_V2_M2D).to_pandas()
whisper_smad_v2 = load_from_disk(metadata_dir / BLOCS_SMAD_V2_WHISPER).to_pandas()

# Gold annotations are read from a CSV file in the same directory.
gold_blocs = pd.read_csv(metadata_dir / CSV_BLOCS_SMAD_GOLD_ANNOTATIONS)

# Display name -> frame, in reporting order; reused for the join-key check.
loaded_frames = {
    "acoustic_df": acoustic_df,
    "ast_smad_v2": ast_smad_v2,
    "clap_smad_v2": clap_smad_v2,
    "m2d_smad_v2": m2d_smad_v2,
    "whisper_smad_v2": whisper_smad_v2,
    "gold_smad_v2": gold_blocs,
}

print("Loaded datasets:")
for name, df in loaded_frames.items():
    tag = f"{name}:"
    # Pad "<name>:" to 16 characters so the shapes line up in one column.
    print(f"  {tag:<16}{df.shape}")

# Every frame is joined on segment_path later, so fail early if it is absent.
for name, df in loaded_frames.items():
    assert "segment_path" in df.columns, f"{name} is missing 'segment_path'"

In [None]:
# Preview the first gold annotation rows to check which label columns exist.
gold_blocs.head()

In [None]:
# Gold label column aliases: '<label>_gold' <-> bare '<label>'.
label_map = {
    'speech_gold': 'speech',
    'music_gold': 'music',
    'noise_gold': 'noise',
}

# Later cells read '<label>_gold' from the joined frame, while the gold schema
# check expects the bare '<label>' names. Alias in BOTH directions so either
# CSV layout works; columns that already exist are never overwritten.
for gold_col, base_col in label_map.items():
    has_gold = gold_col in gold_blocs.columns
    has_base = base_col in gold_blocs.columns
    if has_gold and not has_base:
        gold_blocs[base_col] = gold_blocs[gold_col].astype(int)
    elif has_base and not has_gold:
        gold_blocs[gold_col] = gold_blocs[base_col].astype(int)

## 2. Inspect teacher output schema

+ For each teacher, we check that the expected columns are present in the SMAD v2 outputs.
+ Verify score fields (probabilities / logits / binary flags) and any existing thresholds.

In [None]:
# All teacher frames keyed by a short teacher name.
teacher_dfs = dict(
    ast=ast_smad_v2,
    clap=clap_smad_v2,
    m2d=m2d_smad_v2,
    whisper=whisper_smad_v2,
)

# Columns expected in every teacher v2 dataset: the join key, one score and
# one label column per class. Each teacher gets its own list copy so a single
# entry can be adjusted if that teacher's column names differ.
_teacher_col_template = (
    ['segment_path']
    + [f'{cls}_score' for cls in ('speech', 'music', 'noise')]
    + [f'{cls}_label' for cls in ('speech', 'music', 'noise')]
)
teacher_expected_cols = {
    teacher_key: list(_teacher_col_template) for teacher_key in teacher_dfs
}

# Expected schema for GOLD v2 (HF-style manifest from compute_irr)
gold_expected_cols = [
    'segment_path',
    'raw_file', 'start_time', 'end_time',
    'speech', 'music', 'noise',
    'n_annotators', 'is_irr_segment',
]

print('--- GOLD schema check ---')
missing_gold = [c for c in gold_expected_cols if c not in gold_blocs.columns]
if not missing_gold:
    # Full schema present: preview the expected columns and their dtypes.
    display(gold_blocs[gold_expected_cols].head())
    print(gold_blocs[gold_expected_cols].dtypes)
else:
    print('Missing in gold_smad_v2:', missing_gold)

print('\n--- Teacher schema checks ---')
for name, df in teacher_dfs.items():
    print(f'\n[{name}] shape={df.shape}')

    expected = teacher_expected_cols[name]
    missing = [c for c in expected if c not in df.columns]
    if not missing:
        # Small preview of key columns
        display(df[expected].head())
        print(df[expected].dtypes)
    else:
        print('  Missing expected columns:', missing)

## 3. Join teachers with gold IRR subset

+ Here we restrict labels to `is_irr_segment = True` for calibration on our double-labeled gold subset.
+ Quick sanity tables: counts per class in gold vs each teacher’s positives at current thresholds.


In [None]:
# 3.1. Build a unified teacher score table with prefixed columns.

def make_prefixed_teacher_df(df: pd.DataFrame, prefix: str) -> pd.DataFrame:
    """Return segment_path plus the three SMN score columns, prefixed per teacher.

    Args:
        df: Teacher output frame; must contain ``segment_path`` and the
            ``speech_score`` / ``music_score`` / ``noise_score`` columns.
        prefix: Teacher short name prepended to each score column
            (``speech_score`` becomes ``<prefix>_speech_score``).

    Returns:
        A new frame with ``segment_path`` and the renamed score columns;
        ``df`` itself is left untouched.

    Raises:
        ValueError: If any required column is absent from ``df``.
    """
    score_names = [f'{cls}_score' for cls in ('speech', 'music', 'noise')]
    required = ['segment_path', *score_names]

    absent = [name for name in required if name not in df.columns]
    if absent:
        raise ValueError(f"{prefix}: missing columns {absent}")

    prefixed = {name: f'{prefix}_{name}' for name in score_names}
    selected = df[required].copy()
    return selected.rename(columns=prefixed)

# One prefixed score frame per teacher (the M2D frame uses the 'm2d_clap' prefix).
ast_scores, clap_scores, m2d_scores, whisper_scores = (
    make_prefixed_teacher_df(frame, prefix)
    for frame, prefix in (
        (ast_smad_v2, 'ast'),
        (clap_smad_v2, 'clap'),
        (m2d_smad_v2, 'm2d_clap'),
        (whisper_smad_v2, 'whisper'),
    )
)

# Start from AST and inner-join every other teacher on segment_path, so only
# segments scored by all four teachers remain.
teacher_blocs = ast_scores
for other_scores in (clap_scores, m2d_scores, whisper_scores):
    teacher_blocs = teacher_blocs.merge(other_scores, on='segment_path', how='inner')

print('teacher_blocs shape:', teacher_blocs.shape)
teacher_blocs.head()

In [None]:
# Keep only the gold rows flagged as part of the double-labeled IRR subset.

irr_mask = gold_blocs['is_irr_segment'].eq(True)
irr_gold = gold_blocs.loc[irr_mask]
print(f"IRR subset shape: {irr_gold.shape}")

In [None]:
# 3c. Join IRR gold with all teacher scores.

# Inner join: keeps only IRR segments that have scores from every teacher.
irr_join = irr_gold.merge(teacher_blocs, on='segment_path', how='inner')
print(f"irr_join shape: {irr_join.shape}")

# Check coverage directly instead of comparing row counts: duplicated keys can
# inflate the join and mask (or fake) a length mismatch.
missing = set(irr_gold['segment_path']) - set(teacher_blocs['segment_path'])
if missing:
    print(f"Warning: {len(missing)} IRR segments missing teacher scores.")

# A repeated segment_path means one side of the join had duplicate keys, which
# would double-count those segments in every rate computed from irr_join.
n_duplicated = int(irr_join['segment_path'].duplicated().sum())
if n_duplicated:
    print(f"Warning: {n_duplicated} duplicated segment_path rows in irr_join.")

In [None]:
# Preview the joined frame: gold columns plus the prefixed teacher score columns.
irr_join.head()

In [None]:
# Set thresholds for teacher comparison based on prior knowledge or defaults.

labels = ['speech', 'music', 'noise']

per_teacher_thresholds = {
    # AudioSet teachers (AST, Whisper-AT)
    'ast': {
        'speech': 0.6,
        'music': 0.40,
        'noise': 0.40,
    },
    'whisper': {
        'speech': 0.60,
        'music': 0.40,
        'noise': 0.40,
    },
    # Laion-CLAP zero-shot SMN
    'clap': {
        'speech': 0.45,
        'music': 0.35,
        'noise': 0.35,
    },
    # M2D-CLAP cosine thresholds
    'm2d_clap': {
        'speech': 0.24,
        'music': 0.26,
        'noise': 0.24,
    },
}

# Auto discover teacher score columns of the form <teacher>_<label>_score
# (gold columns such as speech_gold are skipped). Kept for inspection only.
teacher_score_cols = [
    col for col in irr_join.columns
    if not col.endswith('_gold')
    and any(col.endswith(f'_{lab}_score') for lab in labels)
]

n_segments = len(irr_join)

# Gold positives depend only on the label, so count them once per label
# rather than once per (teacher, label) pair.
gold_positive_counts = {
    label: int((irr_join[f'{label}_gold'] == 1).sum())
    for label in labels
    if f'{label}_gold' in irr_join.columns
}

rows = []
for teacher, thresh_map in per_teacher_thresholds.items():
    for label in labels:
        score_col = f'{teacher}_{label}_score'
        if score_col not in irr_join.columns:
            print(f"Warning: {score_col} not found in irr_join.")
            continue

        if label not in gold_positive_counts:
            raise ValueError(f'Missing gold column: {label}_gold')

        threshold = thresh_map[label]
        # A segment counts as teacher-positive when its score reaches the threshold.
        n_teacher_positive = int((irr_join[score_col] >= threshold).sum())
        n_gold_positive = gold_positive_counts[label]

        rows.append({
            'teacher': teacher,
            'label': label,
            'threshold': threshold,
            'n_segments': n_segments,
            'n_teacher_positive': n_teacher_positive,
            'n_gold_positive': n_gold_positive,
            'teacher_pos_rate': n_teacher_positive / n_segments if n_segments > 0 else float('nan'),
            'gold_pos_rate': n_gold_positive / n_segments if n_segments > 0 else float('nan'),
        })

# Passing the columns explicitly keeps sort_values working (on an empty table)
# when no teacher score column was found at all.
count_columns = [
    'teacher', 'label', 'threshold', 'n_segments',
    'n_teacher_positive', 'n_gold_positive',
    'teacher_pos_rate', 'gold_pos_rate',
]
gold_vs_teacher_counts = (
    pd.DataFrame(rows, columns=count_columns)
    .sort_values(by=['teacher', 'label'])
    .reset_index(drop=True)
)

In [None]:
# Show the summary table: one row per (teacher, label) with the threshold used,
# positive counts, and positive rates for teacher vs gold.
gold_vs_teacher_counts

### Summary of gold vs teacher results

On the IRR subset, we have **N = ____** segments. In BLOCS-SMAD-GOLD, the class prevalences are:

- Speech: ____ segments (____%)
- Music:  ____ segments (____%)
- Noise:  ____ segments (____%)

At the current working thresholds, the teachers behave as follows:

- **AST**
  Speech: teacher pos rate ____ vs gold ____
  Music:  teacher pos rate ____ vs gold ____
  Noise:  teacher pos rate ____ vs gold ____

- **Whisper-AT**
  Speech: teacher pos rate ____ vs gold ____
  Music:  teacher pos rate ____ vs gold ____
  Noise:  teacher pos rate ____ vs gold ____

- **Laion-CLAP**
  Speech: teacher pos rate ____ vs gold ____
  Music:  teacher pos rate ____ vs gold ____
  Noise:  teacher pos rate ____ vs gold ____

- **M2D-CLAP**
  Speech: teacher pos rate ____ vs gold ____
  Music:  teacher pos rate ____ vs gold ____
  Noise:  teacher pos rate ____ vs gold ____

Key takeaways in words:

- Speech: ___________________________________________
- Music:  ___________________________________________
- Noise:  ___________________________________________

## 4. Per-class score distributions (IRR subset)

+ For each teacher and each class of interest (speech, music, noise), we plot score distributions (probabilities / logits) on positive and negative gold labels.
+ This helps us visualize how well each teacher separates positives from negatives, and identify good candidate thresholds.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MaxNLocator

def set_plot_style() -> None:
    """Apply the notebook-wide seaborn look: white background, no grid, "talk" sizing."""
    sns.set_style(style="white", rc={"axes.grid": False})
    sns.set_context(context="talk")

def style_generic_plot(ax, title: str, xlabel: str, ylabel: str) -> None:
    """Give one axis the shared look: titles, light y-grid, tidy ticks.

    Args:
        ax: Matplotlib axis to style in place.
        title: Axis title text.
        xlabel: X-axis label text.
        ylabel: Y-axis label text.
    """
    # Text elements.
    ax.set_title(title, fontsize=14)
    for set_axis_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_axis_label(text, fontsize=12)

    # Frame and background: white face, top/right spines removed, and faint
    # dashed horizontal grid lines only.
    ax.set_facecolor("white")
    sns.despine(ax=ax)
    ax.grid(axis='y', linestyle='--', alpha=0.4)

    # Ticks: at most ~6 major x ticks, small labels, and x tick labels forced
    # on (shared-x grids hide them on upper rows otherwise).
    ax.xaxis.set_major_locator(MaxNLocator(nbins=6))
    ax.tick_params(axis="both", which="major", labelsize=10)
    ax.tick_params(axis="x", which="both", labelbottom=True)

set_plot_style()

# Teacher prefixes in per_teacher_thresholds order: ast, whisper, clap, m2d_clap.
teachers = [*per_teacher_thresholds]
labels = ['speech', 'music', 'noise']

def plot_teacher_grid(irr_df: pd.DataFrame, teacher: str) -> None:
    """
    For a single teacher, plot score distributions for speech, music, noise
    in a 2×2 grid. Each panel overlays gold=0 and gold=1 and marks the working
    threshold taken from ``per_teacher_thresholds``.

    Args:
        irr_df: Frame with ``<teacher>_<label>_score`` and ``<label>_gold`` columns.
        teacher: Teacher prefix used in the score column names and looked up
            in ``per_teacher_thresholds``.

    A label whose score column is absent from ``irr_df`` gets a blank panel.
    """
    n_rows, n_cols = 2, 2
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 7), sharex=True, sharey=False)
    axes_flat = axes.flatten()

    for i, label in enumerate(labels):
        ax = axes_flat[i]
        score_col = f'{teacher}_{label}_score'
        gold_col = f'{label}_gold'

        if score_col not in irr_df.columns:
            ax.axis('off')
            continue

        # Make a tidy frame for seaborn
        tmp = irr_df[[score_col, gold_col]].copy()
        tmp = tmp.rename(columns={score_col: 'score', gold_col: 'gold'})
        tmp['gold'] = tmp['gold'].map({0: 'gold=0', 1: 'gold=1'})

        sns.histplot(
            data=tmp,
            x='score',
            hue='gold',
            bins=20,
            stat='density',
            element='step',   # outlines + fill
            alpha=0.7,
            ax=ax,
        )

        # seaborn builds the hue legend from its own handles. A bare
        # ax.legend() would rebuild it from labelled artists only and drop the
        # gold=0 / gold=1 entries, so grab the handles before adding the line.
        hue_legend = ax.get_legend()

        thr = per_teacher_thresholds[teacher][label]
        thr_line = ax.axvline(thr, linestyle='--', linewidth=1.5, label=f'thr={thr:.2f}')

        style_generic_plot(
            ax,
            title=f'{teacher} – {label} scores on IRR',
            xlabel='score',
            ylabel='density',
        )

        # Avoid duplicate legends in every panel: keep only top left.
        if i == 0:
            hue_handles, hue_texts = [], []
            if hue_legend is not None:
                # `legend_handles` is the current attribute name; older
                # matplotlib releases expose it as `legendHandles`.
                hue_handles = list(
                    getattr(hue_legend, 'legend_handles', None)
                    or getattr(hue_legend, 'legendHandles', [])
                )
                hue_texts = [text.get_text() for text in hue_legend.get_texts()]
            ax.legend(
                handles=[*hue_handles, thr_line],
                labels=[*hue_texts, thr_line.get_label()],
            )
        elif hue_legend is not None:
            hue_legend.remove()

    # Turn off any unused subplot (bottom right, since we only have 3 labels).
    for unused_ax in axes_flat[len(labels):]:
        unused_ax.axis('off')

    fig.suptitle(f'{teacher} score distributions on IRR', fontsize=16)
    plt.tight_layout()
    plt.show()


# Draw one 2×2 score-distribution grid per teacher on the joined IRR frame.
for t in teachers:
    plot_teacher_grid(irr_join, t)

### Summary of score distribution observations

- **AST**
    + Speech: ___________________________________________
    + Music:  ___________________________________________
    + Noise:  ___________________________________________
- **Whisper-AT**
    + Speech: ___________________________________________
    + Music:  ___________________________________________
    + Noise:  ___________________________________________
- **Laion-CLAP**
    + Speech: ___________________________________________
    + Music:  ___________________________________________
    + Noise:  ___________________________________________
- **M2D-CLAP**
    + Speech: ___________________________________________
    + Music:  ___________________________________________
    + Noise:  ___________________________________________

## 5. Threshold sweeps and calibration curves

+ For each teacher and class, we sweep over possible decision thresholds and plot precision-recall and ROC curves on the IRR subset.
+ This helps us pick operating points that balance false positives and false negatives according to our needs.
+ We record candidate thresholds that match desired precision or recall levels. (e.g., 90% precision on speech, 80% recall on music, etc.)

In [None]:
# 5.1. Imports and setup

from sklearn.metrics import precision_recall_curve, roc_curve, auc
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def compute_curves(df, teacher: str, label: str) -> tuple:
    """Return PR and ROC curve arrays for one teacher/label pair.

    Gold labels come from ``<label>_gold`` (cast to int) and scores from
    ``<teacher>_<label>_score`` (cast to float).

    Returns:
        ``(precision, recall, pr_thresholds, fpr, tpr, roc_thresholds, roc_auc)``
        where the first three come from ``precision_recall_curve``, the next
        three from ``roc_curve``, and ``roc_auc`` is ``auc(fpr, tpr)``.
    """
    y_true = df[f'{label}_gold'].astype(int).values
    y_score = df[f'{teacher}_{label}_score'].astype(float).values

    pr_triplet = precision_recall_curve(y_true, y_score)
    roc_triplet = roc_curve(y_true, y_score)
    area = auc(roc_triplet[0], roc_triplet[1])

    return (*pr_triplet, *roc_triplet, area)


def plot_teacher_curves(df, teacher: str):
    """Plot precision-recall (left) and ROC (right) curves for one teacher.

    One curve per entry of the notebook-level ``labels`` list is drawn on each
    panel; labels whose ``<teacher>_<label>_score`` column is absent from
    ``df`` are skipped.

    Args:
        df: Frame with the teacher score columns and ``<label>_gold`` columns.
        teacher: Teacher prefix used in the score column names.
    """
    fig, (pr_ax, roc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    n_plotted = 0
    for label in labels:
        score_col = f"{teacher}_{label}_score"
        if score_col not in df.columns:
            continue

        prec, rec, _pr_thr, fpr, tpr, _roc_thr, roc_auc = compute_curves(df, teacher, label)

        pr_ax.plot(rec, prec, label=f'{label}')
        roc_ax.plot(fpr, tpr, label=f'{label} (AUC={roc_auc:.2f})')
        n_plotted += 1

    # Chance diagonal, drawn once instead of once per label.
    roc_ax.plot([0, 1], [0, 1], 'k--', linewidth=1)

    # Axis decorations do not depend on the label, so set them once.
    pr_ax.set_title(f'{teacher}: Precision–Recall')
    pr_ax.set_xlabel('Recall')
    pr_ax.set_ylabel('Precision')
    pr_ax.grid(False)

    roc_ax.set_title(f'{teacher}: ROC Curve')
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.grid(False)

    # Skip the legends when nothing was plotted, so matplotlib does not warn
    # about an empty legend.
    if n_plotted:
        pr_ax.legend()
        roc_ax.legend()

    fig.suptitle(f'{teacher} — threshold calibration curves', fontsize=15)
    plt.tight_layout()
    plt.show()


In [None]:
# 5.2. Plot curves for each teacher (one PR + ROC figure per teacher, IRR subset).

for t in teachers:
    plot_teacher_curves(irr_join, t)

In [None]:
# 5.3. Record candidate thresholds based on desired precision/recall.

def find_threshold_for_precision(df, teacher, label, target_precision=0.90):
    """Find the lowest threshold whose precision reaches ``target_precision``.

    Args:
        df: Frame with ``<label>_gold`` and ``<teacher>_<label>_score`` columns.
        teacher: Teacher prefix used in the score column name.
        label: Class name (e.g. ``'speech'``).
        target_precision: Minimum precision the threshold must achieve.

    Returns:
        ``(threshold, precision, recall)`` as floats for the first (lowest)
        threshold meeting the target, or ``None`` if no threshold does.
    """
    prec, rec, thr = precision_recall_curve(
        df[f'{label}_gold'].astype(int),
        df[f'{teacher}_{label}_score'],
    )

    # sklearn returns len(thr) + 1 precision/recall values: the final
    # (precision=1, recall=0) point has NO matching threshold. So only
    # prec/rec are trimmed here; `thr` already has the right length and every
    # entry in it is a real threshold.
    candidates = np.flatnonzero(prec[:-1] >= target_precision)
    if candidates.size == 0:
        return None

    # Thresholds are increasing, so the first hit keeps the most recall.
    idx = candidates[0]
    return float(thr[idx]), float(prec[idx]), float(rec[idx])

In [None]:
# 5.4. Candidate thresholds at the target precision for each teacher and label.

TARGET_PRECISION = 0.90

candidate_rows = []
for teacher in teachers:
    for label in labels:
        # Skip teacher/label pairs that have no score column in irr_join.
        if f'{teacher}_{label}_score' not in irr_join.columns:
            continue

        found = find_threshold_for_precision(
            irr_join, teacher, label, target_precision=TARGET_PRECISION
        )
        # None means no threshold reaches the target precision for this pair.
        threshold, precision, recall = found if found is not None else (None, None, None)
        candidate_rows.append({
            'teacher': teacher,
            'label': label,
            'target_precision': TARGET_PRECISION,
            'threshold': threshold,
            'precision': precision,
            'recall': recall,
        })

candidate_thresholds = pd.DataFrame(
    candidate_rows,
    columns=['teacher', 'label', 'target_precision', 'threshold', 'precision', 'recall'],
)
candidate_thresholds

### 5.5. Threshold Calibration Curves

This section sweeps thresholds from 0 to 1 and examines precision, recall, and ROC behavior on the IRR subset.

**Speech**:
+ Teacher achieving highest AUC: ____
+ Threshold achieving ≥90 percent precision: ____
+ Tradeoff comment: ____

**Music**:
+ Teacher with most stable PR curve: ____
+ Threshold achieving ≥80 percent recall: ____

**Noise**:
+ Most reliable teacher: ____
+ Threshold achieving balanced precision and recall (or best F1): ____

**Candidate operating points**

Proposed thresholds derived from PR sweeps:
+ speech: ____
+ music: ____
+ noise: ____

## 6. Compare working vs calibrated thresholds

+ For each teacher and class, we compare the current working thresholds (if any) against the newly calibrated thresholds from the previous section.
+ We analyze how the new thresholds would change positive rates and error patterns on the IRR subset.
+ Then we finalize threshold choices and document them for pipeline integration.

In [None]:
# 6.1 Calibrated thresholds chosen from the Section 5 sweeps, keyed by teacher
# prefix and then by label.
calibrated_thresholds = dict(
    ast=dict(speech=0.57, music=0.43, noise=0.41),
    whisper=dict(speech=0.61, music=0.39, noise=0.38),
    clap=dict(speech=0.46, music=0.34, noise=0.33),
    m2d_clap=dict(speech=0.25, music=0.27, noise=0.25),
)

In [None]:
# 6.2. Compare working vs calibrated thresholds.

def compute_pos_rate(df: pd.DataFrame, teacher: str, label: str, threshold: float):
    """Count how many rows score at or above ``threshold`` for a teacher/label.

    Args:
        df: Frame expected to hold a ``<teacher>_<label>_score`` column.
        teacher: Teacher prefix used in the score column name.
        label: Class name used in the score column name.
        threshold: Inclusive decision threshold.

    Returns:
        ``(n_positive, positive_rate)``; ``(None, None)`` when the score
        column is absent, and a NaN rate when ``df`` has no rows.
    """
    column = f'{teacher}_{label}_score'
    if column in df.columns:
        is_positive = df[column].ge(threshold)
        share = float('nan') if len(df) == 0 else float(is_positive.mean())
        return int(is_positive.sum()), share
    return None, None


# Fail with an actionable message if the calibrated thresholds cell was skipped.
try:
    calibrated_thresholds
except NameError as e:
    raise ValueError(
        "calibrated_thresholds is not defined. "
        "Please create a dict like calibrated_thresholds[teacher][label] = new_threshold "
        "based on your PR/ROC sweeps in Section 5."
    ) from e


def _stats_at(teacher, label, threshold):
    """Positive count and rate on irr_join, or (None, None) without a threshold."""
    if threshold is None:
        return None, None
    return compute_pos_rate(irr_join, teacher, label, threshold)


def _difference(new_value, old_value):
    """Return new - old as a float, or None when either side is unavailable."""
    if old_value is None or new_value is None:
        return None
    return float(new_value - old_value)


n_segments_irr = len(irr_join)
rows = []

for teacher, old_map in per_teacher_thresholds.items():
    # Teachers without calibrated values fall back to an empty mapping.
    new_map = calibrated_thresholds.get(teacher, {})

    for label in labels:
        # Skip if we do not have scores for this teacher/label.
        score_col = f'{teacher}_{label}_score'
        if score_col not in irr_join.columns:
            continue

        old_thr, new_thr = old_map.get(label, None), new_map.get(label, None)
        old_n_pos, old_rate = _stats_at(teacher, label, old_thr)
        new_n_pos, new_rate = _stats_at(teacher, label, new_thr)

        rows.append({
            'teacher': teacher,
            'label': label,
            'old_threshold': old_thr,
            'new_threshold': new_thr,
            'delta_threshold': _difference(new_thr, old_thr),
            'n_segments_irr': n_segments_irr,
            'old_n_positive': old_n_pos,
            'new_n_positive': new_n_pos,
            'old_pos_rate': old_rate,
            'new_pos_rate': new_rate,
            'delta_pos_rate': _difference(new_rate, old_rate),
        })

# One row per (teacher, label), ordered for easy side-by-side reading.
unsorted_deltas = pd.DataFrame(rows)
threshold_delta_table = unsorted_deltas.sort_values(['teacher', 'label']).reset_index(drop=True)

threshold_delta_table

## 7. Qualitative spot checks of teacher-gold disagreements (optional)

+ Sample segments where each teacher strongly disagrees with gold (false positives / false negatives) at the chosen thresholds.
+ Show scores, labels, and play audio; jot down typical failure modes and edge cases for each teacher.

## 8. Export final thresholds and summary (optional)

+ We compile the final chosen thresholds for each teacher and class into a summary table.
+ Export this table to a CSV or JSON file for easy reference during pipeline implementation.
+ Summarize key findings and recommendations for future teacher calibration efforts.