# Event Detection: Predictions vs Ground Truth

This notebook overlays predicted events against ground truth annotations. It is designed for DCASE-style CSVs but also works with raw training CSVs (per-class columns). Use the config cells to point to your prediction CSV and the correct GT root, then plot a smaller, readable subset.

**Tip**
- If you only see ground truth, your `GT_ROOT` likely does not match the prediction set. The summary cell below will show the intersection size so you can correct the paths.


In [None]:
from pathlib import Path
import sys
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

# The notebook is expected to run from notebooks/, so the repository root is
# the parent directory. Put it at the front of sys.path (once) so that
# project-local packages can be imported.
repo_root = Path('..').resolve()
if all(entry != str(repo_root) for entry in sys.path):
    sys.path.insert(0, str(repo_root))

from utils.evaluation import POS_VALUE, N_SHOTS


## Configuration

In [None]:
# ------------------------------------------------------------------
# Paths
# ------------------------------------------------------------------
# Prediction CSV with columns: Audiofilename, Starttime, Endtime
PRED_CSV = Path('../outputs/mlflow_experiments/v1-pcen-f/val_eval/epoch_003/0.75/Eval_raw.csv')

# Root directory containing GT CSV files (recursively searched for *.csv)
# Example for evaluation: Path('/data/msc-proj/Validation_Set')
GT_ROOT = Path('/data/msc-proj/Training_Set')

# ------------------------------------------------------------------
# GT parsing options
# ------------------------------------------------------------------
# Drop the first N_SHOTS support shots (N_SHOTS comes from utils.evaluation)
# from the ground truth before plotting (matches evaluation protocol).
# Files with fewer than N_SHOTS positive events are left untouched.
DROP_SUPPORT_SHOTS = True

# Optional: focus on one class label if GT CSVs have per-class columns
# Set to None to accept any POS label.
# A class name that is not a column of a given CSV is ignored for that file.
TARGET_CLASS = None  # e.g., 'AMRE'

# ------------------------------------------------------------------
# Plot selection options
# ------------------------------------------------------------------
# Which audio files to show: 'intersection', 'union', 'gt_only', or 'pred_only'
# ('gt_only' / 'pred_only' select every file known to that side, including
# files that also appear on the other side).
AUDIO_MODE = 'intersection'

# Filter to specific audio filenames (use .wav names) or None
AUDIO_FILTER = None  # e.g., ['ME1.wav', 'ME2.wav']

# Page through many files: each figure shows at most MAX_FILES_PER_FIG files,
# and PAGE selects which slice of the (optionally shuffled) list is plotted.
MAX_FILES_PER_FIG = 25
PAGE = 0  # 0-based page index
# Shuffling happens before paging; RANDOM_SEED makes the order reproducible.
SHUFFLE = False
RANDOM_SEED = 0


## Helpers

In [None]:
# Columns that are not class labels; _label_columns() treats every other
# column of a CSV as a label column.
BASE_COLS = {"Audiofilename", "Starttime", "Endtime"}


def normalize_audio_name(name):
    """Return the bare file name of *name*, guaranteed to end in ``.wav``.

    *name* is converted with ``str()`` first, any directory part is dropped,
    and ``.wav`` is appended unless the name already ends with it. The suffix
    check is case-insensitive, so ``X.WAV`` is returned unchanged.
    """
    base = Path(str(name)).name
    already_wav = base.lower().endswith('.wav')
    return base if already_wav else f'{base}.wav'


def _label_columns(df):
    """List, in DataFrame order, the columns of *df* that are not in ``BASE_COLS``."""
    return [column for column in df.columns if column not in BASE_COLS]


def _find_qe_column(df):
    for col in df.columns:
        if 'Q' in col or col.startswith('E_'):
            return col
    return None


def pos_mask(df, target_class=None):
    """Build a boolean Series marking the rows of *df* that are POS events.

    The column(s) compared against ``POS_VALUE`` are chosen in this order:

    1. ``target_class``, when it is truthy and names a column of *df*;
    2. the column returned by ``_find_qe_column``;
    3. every column returned by ``_label_columns`` -- a row is positive when
       any of them equals ``POS_VALUE``.

    When none of these applies, the mask is ``False`` for every row.
    """
    if target_class and target_class in df.columns:
        chosen = target_class
    else:
        chosen = _find_qe_column(df)

    # Single-column case: either the requested class or the Q/E_ column.
    if chosen:
        return df[chosen].eq(POS_VALUE)

    candidates = _label_columns(df)
    if not candidates:
        # Nothing that could carry a label: no row is positive.
        return pd.Series([False] * len(df), index=df.index)

    return df[candidates].eq(POS_VALUE).any(axis=1)


def drop_support_shots(df, drop_support=True, target_class=None):
    """Remove the part of *df* covered by the first ``N_SHOTS`` positive events.

    Positive rows are identified with ``pos_mask(df, target_class)``. The
    ``Endtime`` of the ``N_SHOTS``-th positive row (in row order) is the
    cutoff, and only rows whose ``Endtime`` is strictly greater are kept.

    *df* is returned unchanged when ``drop_support`` is false or when it has
    fewer than ``N_SHOTS`` positive rows.
    """
    if drop_support:
        positives = df.index[pos_mask(df, target_class)].tolist()
        if len(positives) >= N_SHOTS:
            last_shot_end = df.loc[positives[N_SHOTS - 1], 'Endtime']
            return df.loc[df['Endtime'] > last_shot_end]
    return df


def load_ground_truth(gt_root: Path, drop_support=True, target_class=None):
    """Return {audio_filename: [(start, end), ...]} from all GT CSVs under gt_root.

    Every ``*.csv`` below ``gt_root`` is read, optionally stripped of its
    support shots (see ``drop_support_shots``), and reduced to its positive
    rows (see ``pos_mask``). CSVs without any positive event are skipped.

    Args:
        gt_root: Directory searched recursively for ``*.csv`` files.
        drop_support: Forwarded to ``drop_support_shots``.
        target_class: Forwarded to ``drop_support_shots`` and ``pos_mask``.

    Returns:
        A dict keyed by normalized audio file name. The name is the first
        non-null ``Audiofilename`` value of the CSV, or ``<csv stem>.wav``
        when that column is missing or entirely null. Events from several
        CSVs that resolve to the same audio name are concatenated.
    """
    gt_events = {}
    # Directory traversal order depends on the filesystem; sorting keeps the
    # result (dict order and merge order) reproducible between runs.
    for csv_path in sorted(gt_root.rglob('*.csv')):
        df = pd.read_csv(csv_path, dtype={'Starttime': float, 'Endtime': float})
        df = drop_support_shots(df, drop_support=drop_support, target_class=target_class)
        mask = pos_mask(df, target_class=target_class)
        events = list(zip(df.loc[mask, 'Starttime'], df.loc[mask, 'Endtime']))
        if not events:
            continue

        audio_name = csv_path.stem + '.wav'
        if 'Audiofilename' in df.columns:
            # Use the first *non-null* value: taking iloc[0] blindly would turn
            # a leading NaN into the bogus key 'nan.wav'.
            named = df['Audiofilename'].dropna()
            if not named.empty:
                audio_name = normalize_audio_name(named.iloc[0])

        # Extend rather than assign so that a second CSV for the same audio
        # file does not silently discard the events already collected.
        gt_events.setdefault(audio_name, []).extend(events)
    return gt_events


def load_predictions(pred_csv: Path):
    """Return {audio_filename: [(start, end), ...]} from a prediction CSV.

    The CSV must provide ``Audiofilename``, ``Starttime`` and ``Endtime``
    columns. File names are passed through ``normalize_audio_name`` before
    the rows are grouped, and each group becomes a list of
    ``(Starttime, Endtime)`` float pairs in row order.
    """
    table = pd.read_csv(pred_csv, dtype={'Starttime': float, 'Endtime': float})
    table['Audiofilename'] = table['Audiofilename'].map(normalize_audio_name)
    return {
        audio_name: list(
            zip(rows['Starttime'].astype(float), rows['Endtime'].astype(float))
        )
        for audio_name, rows in table.groupby('Audiofilename')
    }


def select_audio_names(gt_dict, pred_dict, mode='intersection', audio_filter=None,
                       max_files=None, page=0, shuffle=False, seed=0):
    """Choose which audio file names to plot.

    Names are collected according to ``mode``, sorted, optionally filtered,
    optionally shuffled, and finally sliced into pages.

    Args:
        gt_dict: Mapping keyed by audio file name (ground truth).
        pred_dict: Mapping keyed by audio file name (predictions).
        mode: ``'intersection'`` (names in both mappings), ``'union'`` (names
            in either), ``'gt_only'`` (every name in ``gt_dict``) or
            ``'pred_only'`` (every name in ``pred_dict``).
        audio_filter: Optional iterable of names to keep. A single string is
            treated as one name. Falsy values disable filtering.
        max_files: Page size. ``None`` returns every selected name and
            ignores ``page``.
        page: 0-based page index, used only when ``max_files`` is given.
        shuffle: Shuffle the names (after filtering, before paging).
        seed: Seed for ``numpy.random.default_rng`` when shuffling.

    Returns:
        A list of audio file names. It is empty when the page lies beyond
        the end of the selection.

    Raises:
        ValueError: If ``mode`` is unknown, or if ``page`` is negative while
            ``max_files`` is given.
    """
    gt_names, pred_names = set(gt_dict), set(pred_dict)
    if mode == 'intersection':
        candidates = gt_names & pred_names
    elif mode == 'gt_only':
        candidates = gt_names
    elif mode == 'pred_only':
        candidates = pred_names
    elif mode == 'union':
        candidates = gt_names | pred_names
    else:
        raise ValueError(f"Unknown AUDIO_MODE: {mode}")
    names = sorted(candidates)

    if audio_filter:
        # Build the lookup set once. A bare string must be wrapped, otherwise
        # set() would split it into single characters and match nothing.
        if isinstance(audio_filter, str):
            wanted = {audio_filter}
        else:
            wanted = set(audio_filter)
        names = [n for n in names if n in wanted]

    if shuffle:
        rng = np.random.default_rng(seed)
        rng.shuffle(names)

    if max_files is None:
        return names

    # A negative page would turn into negative slice bounds and return an
    # unrelated (or empty) slice, so reject it explicitly.
    if page < 0:
        raise ValueError(f"page must be >= 0, got {page}")

    start = page * max_files
    return names[start:start + max_files]


def plot_events(gt_dict, pred_dict, names, title='Event Detection: Predictions vs Ground Truth'):
    """Draw ground-truth and predicted events as horizontal bars, one row per file.

    Each entry of ``names`` gets one row on the y axis. Its ground-truth
    events are drawn slightly above the row position and its predicted events
    slightly below, so both stay visible. Names missing from either mapping
    simply contribute no bars for that side.

    Args:
        gt_dict: ``{audio_filename: [(start, end), ...]}`` for ground truth.
        pred_dict: ``{audio_filename: [(start, end), ...]}`` for predictions.
        names: Audio file names to plot, in row order.
        title: Figure title.

    Returns:
        The created matplotlib figure.

    Raises:
        ValueError: If ``names`` is empty.
    """
    if not names:
        raise ValueError('No audio files to plot. Check filters and paths.')

    gt_style = dict(color='seagreen', linewidth=4, alpha=0.85)
    pred_style = dict(color='royalblue', linewidth=3, alpha=0.85)

    n_rows = len(names)
    # Grow the figure with the number of rows, but never below 5 inches tall.
    fig, ax = plt.subplots(figsize=(18, max(5, n_rows * 0.6)))

    for row, audio_name in enumerate(names):
        gt_level = row + 0.18
        pred_level = row - 0.18

        for start, end in gt_dict.get(audio_name, []):
            ax.hlines(gt_level, start, end, **gt_style)

        for start, end in pred_dict.get(audio_name, []):
            ax.hlines(pred_level, start, end, **pred_style)

    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(names)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Audio Files')
    ax.set_title(title)
    ax.grid(True, axis='x', linestyle='--', alpha=0.3)

    # Proxy artists: the legend needs one entry per series, not per event.
    legend_entries = [
        Line2D([0], [0], color='seagreen', lw=4, label='Ground Truth'),
        Line2D([0], [0], color='royalblue', lw=3, label='Predictions'),
    ]
    ax.legend(handles=legend_entries, loc='upper right', frameon=True)
    plt.tight_layout()
    return fig


## Load data and sanity checks

In [None]:
# Parse both sources into {audio_filename: [(start, end), ...]} dictionaries.
pred_events = load_predictions(PRED_CSV)
gt_events = load_ground_truth(
    GT_ROOT,
    drop_support=DROP_SUPPORT_SHOTS,
    target_class=TARGET_CLASS,
)

# Audio files known to both sides; used below to warn when nothing overlaps.
intersection = gt_events.keys() & pred_events.keys()
print(f"GT files: {len(gt_events)}")
print(f"Pred files: {len(pred_events)}")
print(f"Intersection: {len(intersection)}")

# Show a few keys from each side so naming mismatches are easy to spot.
print("GT sample:", list(gt_events)[:5])
print("Pred sample:", list(pred_events)[:5])

if AUDIO_MODE == 'intersection' and not intersection:
    print("WARNING: No overlap between GT and predictions. Check GT_ROOT or PRED_CSV.")


## Choose which files to plot

In [None]:
# Pick the audio files for this figure from the configuration values above:
# AUDIO_MODE chooses the candidate set, AUDIO_FILTER narrows it, and
# MAX_FILES_PER_FIG / PAGE select one page of the (optionally shuffled) list.
names = select_audio_names(
    gt_events,
    pred_events,
    mode=AUDIO_MODE,
    audio_filter=AUDIO_FILTER,
    max_files=MAX_FILES_PER_FIG,
    page=PAGE,
    shuffle=SHUFFLE,
    seed=RANDOM_SEED,
)

# Report how many files were selected and preview the first ten names.
print(f"Plotting {len(names)} files (page {PAGE})")
print(names[:10])


## Plot

In [None]:
# Draw the selected files; plot_events raises ValueError when `names` is empty.
fig = plot_events(gt_events, pred_events, names)
plt.show()
