# Interactive Cough Detection Model Tester

This notebook provides an interactive Gradio interface for testing XGBoost cough detection models trained on multimodal biosignals (audio + IMU).

## Features

- **Load pre-trained models**: IMU-only, Audio-only, and Multimodal classifiers
- **Test on dataset recordings**: Select from public_dataset with ground truth comparison
- **Upload custom files**: Test on your own audio (WAV, MP3, OGG, M4A, WEBM) and/or CSV IMU data
- **Automatic audio conversion**: Handles various formats and sample rates (auto-converts to 16 kHz)
- **Flexible file uploads**: Upload only the required file(s) based on selected model (audio-only, IMU-only, or both)
- **Audio playback**: Listen to recordings while viewing predictions
- **Interactive Plotly visualizations**: 
  - Waveforms with color-coded TP/FP/FN detections
  - Raw window predictions (all sliding windows with probabilities)
  - Probability timeline showing continuous model confidence
  - Probability distribution histogram
  - Zoom, pan, and hover tooltips for detailed inspection
- **Window-level analysis**: View every individual sliding window prediction before merging
- **Threshold adjustment**: Fine-tune classification threshold in real-time
- **Event-based metrics**: TP/FP/FN counts with sensitivity, precision, F1 scores
- **Comprehensive statistics**: Window counts, merge ratios, probability distributions

## Prerequisites

**IMPORTANT**: You must first run `Model_Training_XGBoost.ipynb` to completion to generate the saved models in `models/`.

The training notebook should create:
- `models/xgb_imu.pkl`
- `models/xgb_audio.pkl`
- `models/xgb_multimodal.pkl`
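
Before launching the interface, it can help to confirm that all three files are present. The following is a minimal sketch: `missing_model_files` is a hypothetical helper that is not part of this notebook, and it assumes the file names listed above.

```python
from pathlib import Path

# File names taken from the list above; the helper itself is illustrative only.
EXPECTED_MODELS = ("xgb_imu.pkl", "xgb_audio.pkl", "xgb_multimodal.pkl")

def missing_model_files(model_dir):
    """Return the expected model files that are absent from model_dir."""
    model_dir = Path(model_dir)
    return [name for name in EXPECTED_MODELS if not (model_dir / name).exists()]

missing = missing_model_files("models")
print("Missing model files:", missing if missing else "none")
```

An empty result suggests the training notebook completed and saved every model.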

## Section 1: Setup & Configuration

In [None]:
# Check for required dependencies
import sys

try:
    import gradio as gr
    import xgboost
    import joblib
    import plotly
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    print("✓ All required dependencies installed")
    print(f"  - gradio version: {gr.__version__}")
    print(f"  - xgboost version: {xgboost.__version__}")
    print(f"  - plotly version: {plotly.__version__}")
except ImportError as e:
    print(f"✗ Missing dependency: {e}")
    print("\nInstall with: uv add gradio xgboost joblib plotly")
    sys.exit(1)

In [None]:
# Import dependencies
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.io import wavfile
from scipy import signal
import librosa
from sklearn.preprocessing import StandardScaler
import pickle
from pathlib import Path
import os
import warnings
warnings.filterwarnings('ignore')

# Import project helpers (Kaggle utility scripts, or the local src/ directory)
if os.path.exists("/kaggle/usr/lib/"):
    # Load from Kaggle as utility scripts
    from edge_ai_cough_count_helpers import * # pyright: ignore[reportMissingImports]
    from edge_ai_cough_count_dataset_gen import * # pyright: ignore[reportMissingImports]
    from edge_ai_cough_count_features import * # pyright: ignore[reportMissingImports]
else:
    # Add src directory to path
    sys.path.append(os.path.abspath('../src'))
    from helpers import *
    from dataset_gen import *
    from features import *

print("✓ All imports successful")

In [None]:
# Set constants
FS_AUDIO_CONST = 16000  # Audio sampling frequency
FS_IMU_CONST = 100      # IMU sampling frequency
WINDOW_LEN = 0.4        # Window length in seconds
HOP_SIZE = 0.05         # Default hop size for sliding window (50ms)

# Locate dataset folder
kaggle_dataset_dir = '/kaggle/input/edge-ai-cough-count'
base_dir = kaggle_dataset_dir if os.path.exists(kaggle_dataset_dir) else ".."
data_folder = base_dir + '/public_dataset/'

if not os.path.exists(data_folder):
    raise FileNotFoundError(
        "Cannot find public_dataset/. Please download from: "
        "https://zenodo.org/record/7562332"
    )

# Locate models directory
kaggle_model_dir = '/kaggle/input/model-training-xgboost'
model_base_dir = kaggle_model_dir if os.path.exists(kaggle_model_dir) else "."
MODEL_DIR = Path(model_base_dir + "/models")

print("Configuration:")
print(f"  Audio FS: {FS_AUDIO_CONST} Hz")
print(f"  IMU FS: {FS_IMU_CONST} Hz")
print(f"  Window length: {WINDOW_LEN}s")
# data_folder is guaranteed to exist here; a missing folder raises above
print(f"  Dataset folder: {data_folder}")
print(f"  Models directory: {MODEL_DIR}")

## Section 2: Model Loading System

In [None]:
def load_trained_models():
    """
    Load all three trained models from disk.
    
    Returns:
        dict: Dictionary with keys 'imu', 'audio', 'multimodal'
              Each value is {'model': XGBClassifier, 'scaler': StandardScaler, 'threshold': float}
    """
    models = {}
    
    for modality in ['imu', 'audio', 'multimodal']:
        model_path = MODEL_DIR / f'xgb_{modality}.pkl'
        
        if not model_path.exists():
            raise FileNotFoundError(
                f"\n{'='*70}\n"
                f"ERROR: Model file not found: {model_path}\n\n"
                f"Please run Model_Training_XGBoost.ipynb first to train and save models.\n"
                f"The training notebook should create the following files:\n"
                f"  - models/xgb_imu.pkl\n"
                f"  - models/xgb_audio.pkl\n"
                f"  - models/xgb_multimodal.pkl\n"
                f"{'='*70}"
            )
        
        with open(model_path, 'rb') as f:
            models[modality] = pickle.load(f)
        
        print(f"✓ Loaded {modality} model from {model_path}")
        print(f"  Threshold: {models[modality]['threshold']:.3f}")
    
    return models

# Load models
try:
    MODELS = load_trained_models()
    print("\n✓ All models loaded successfully")
except FileNotFoundError as e:
    print(e)
    MODELS = None
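
The loader assumes each pickle holds a dictionary with `model`, `scaler`, and `threshold` entries, as its docstring states. A check along the following lines could surface a malformed file early. This is only a sketch: `validate_model_bundle` is a hypothetical helper, and the placeholder objects stand in for the real classifier and scaler.

```python
REQUIRED_KEYS = ("model", "scaler", "threshold")

def validate_model_bundle(bundle):
    """Raise ValueError if a loaded bundle lacks the expected structure."""
    if not isinstance(bundle, dict):
        raise ValueError(f"expected a dict, got {type(bundle).__name__}")
    missing = [key for key in REQUIRED_KEYS if key not in bundle]
    if missing:
        raise ValueError(f"bundle is missing keys: {missing}")
    threshold = float(bundle["threshold"])
    if not 0.0 <= threshold <= 1.0:
        raise ValueError(f"threshold {threshold} is outside [0, 1]")
    return bundle

# Placeholder objects stand in for the real XGBClassifier and StandardScaler.
example = {"model": object(), "scaler": object(), "threshold": 0.5}
validate_model_bundle(example)
```

Calling it on each entry right after `pickle.load` would turn a later `KeyError` into a clearer message.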

## Section 3: Feature Extraction Utilities

In [None]:
def extract_features_for_window(audio_window, imu_window, modality='multimodal'):
    """
    Extract features from a single window of audio and IMU data.
    
    Args:
        audio_window: (N_audio,) audio samples
        imu_window: (N_imu, 6) IMU samples
        modality: 'imu', 'audio', or 'multimodal'
    
    Returns:
        np.array: Feature vector
    """
    features = []
    
    if modality in ['audio', 'multimodal']:
        audio_feat = extract_audio_features(audio_window, fs=FS_AUDIO_CONST)
        # Handle NaN/Inf
        audio_feat = np.nan_to_num(audio_feat, nan=0.0, posinf=0.0, neginf=0.0)
        features.append(audio_feat)
    
    if modality in ['imu', 'multimodal']:
        imu_feat = extract_imu_features(imu_window)
        # Handle NaN/Inf
        imu_feat = np.nan_to_num(imu_feat, nan=0.0, posinf=0.0, neginf=0.0)
        features.append(imu_feat)
    
    return np.concatenate(features)

print("✓ Feature extraction utilities ready")
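
The sanitize-then-concatenate pattern above can be illustrated without the real feature extractors. The sketch below uses plain lists and hypothetical helper names (`sanitize`, `combine_features`). It also rejects an unknown modality explicitly, since in the function above that case would reach `np.concatenate` with an empty list and fail with a less descriptive error.

```python
import math

VALID_MODALITIES = ("imu", "audio", "multimodal")

def sanitize(values):
    """Replace NaN and +/-Inf with 0.0, mirroring the np.nan_to_num call above."""
    return [float(v) if math.isfinite(v) else 0.0 for v in values]

def combine_features(audio_feat, imu_feat, modality="multimodal"):
    """Concatenate sanitized features in the same order as above: audio, then IMU."""
    if modality not in VALID_MODALITIES:
        raise ValueError(f"unknown modality: {modality!r}")
    combined = []
    if modality in ("audio", "multimodal"):
        combined.extend(sanitize(audio_feat))
    if modality in ("imu", "multimodal"):
        combined.extend(sanitize(imu_feat))
    return combined

print(combine_features([0.2, float("nan")], [float("inf"), 1.5]))
# prints [0.2, 0.0, 0.0, 1.5]
```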

## Section 4: Sliding Window Prediction Engine

In [None]:
def sliding_window_predict(audio, imu, model_data, modality='multimodal', 
                          window_len=0.4, hop_size=0.05, threshold=None):
    """
    Apply model to continuous recording using sliding windows.
    
    Args:
        audio: (N_audio,) audio samples
        imu: (N_imu, 6) IMU samples
        model_data: Dict with 'model', 'scaler', 'threshold'
        modality: 'imu', 'audio', or 'multimodal'
        window_len: Window length in seconds
        hop_size: Hop size in seconds
        threshold: Classification threshold (None = use optimal from model)
    
    Returns:
        predictions: List of (start_time, end_time, probability) tuples (only above threshold)
        all_probs: Array of probabilities for each window
        window_times: Array of window center times
        all_windows: List of (start, end, center, prob) for ALL windows
    """
    model = model_data['model']
    scaler = model_data['scaler']
    if threshold is None:
        threshold = model_data['threshold']
    
    # Calculate window and hop in samples
    audio_win_samples = int(window_len * FS_AUDIO_CONST)
    audio_hop_samples = int(hop_size * FS_AUDIO_CONST)
    imu_win_samples = int(window_len * FS_IMU_CONST)
    imu_hop_samples = int(hop_size * FS_IMU_CONST)
    
    # Extract windows
    n_windows = (len(audio) - audio_win_samples) // audio_hop_samples + 1
    features_list = []
    window_times = []
    
    for i in range(n_windows):
        audio_start = i * audio_hop_samples
        audio_end = audio_start + audio_win_samples
        imu_start = i * imu_hop_samples
        imu_end = imu_start + imu_win_samples
        
        if audio_end > len(audio) or imu_end > len(imu):
            break
        
        audio_window = audio[audio_start:audio_end]
        imu_window = imu[imu_start:imu_end, :]
        
        features = extract_features_for_window(audio_window, imu_window, modality)
        features_list.append(features)
        
        # Window center time
        center_time = (audio_start + audio_win_samples / 2) / FS_AUDIO_CONST
        window_times.append(center_time)
    
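    # Batch predict; a recording shorter than one window yields no features,
    # so return empty results instead of passing an empty array to the scaler.
    if not features_list:
        return [], np.array([]), np.array([]), []
    X = np.array(features_list)
    X_scaled = scaler.transform(X)
    probs = model.predict_proba(X_scaled)[:, 1]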
    
    # Create list of ALL windows with their probabilities
    all_windows = []
    # Convert to event-based predictions (only above threshold)
    predictions = []
    for prob, center in zip(probs, window_times):
        start = center - window_len / 2
        end = center + window_len / 2
        all_windows.append((start, end, center, prob))
        if prob >= threshold:
            predictions.append((start, end, prob))
    
    return predictions, probs, np.array(window_times), all_windows

print("✓ Sliding window prediction engine ready")
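
The window arithmetic used above can be checked in isolation. The sketch below reproduces only the index calculation for the audio stream, with a hypothetical helper name (`window_layout`) and the sampling rate copied from the configuration cell. It uses no model or feature extraction.

```python
FS_AUDIO = 16000   # Hz, same value as FS_AUDIO_CONST above

def window_layout(n_audio_samples, window_len=0.4, hop_size=0.05):
    """Return (n_windows, center_times) for a recording of n_audio_samples."""
    win = int(window_len * FS_AUDIO)   # 6400 samples per window
    hop = int(hop_size * FS_AUDIO)     # 800 samples per hop
    if n_audio_samples < win:
        return 0, []                   # recording shorter than one window
    n_windows = (n_audio_samples - win) // hop + 1
    centers = [(i * hop + win / 2) / FS_AUDIO for i in range(n_windows)]
    return n_windows, centers

n, centers = window_layout(FS_AUDIO)   # a 1-second recording
print(n, round(centers[0], 3), round(centers[-1], 3))
```

For a 1-second recording this gives 13 windows with centers running from 0.2 s to 0.8 s. At 100 Hz, the matching IMU window spans 40 samples with a hop of 5.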

In [None]:
def merge_detections(predictions, gap_threshold=0.3):
    """
    Merge consecutive detections that are close together.
    
    Args:
        predictions: List of (start, end, prob) tuples
        gap_threshold: Maximum gap between events to merge (seconds)
    
    Returns:
        merged: List of (start, end, max_prob) tuples
    """
    if not predictions:
        return []
    
    # Sort by start time
    sorted_preds = sorted(predictions, key=lambda x: x[0])
    
    merged = []
    current_start, current_end, current_prob = sorted_preds[0]
    
    for start, end, prob in sorted_preds[1:]:
        # If gap is small, merge
        if start - current_end <= gap_threshold:
            current_end = max(current_end, end)
            current_prob = max(current_prob, prob)
        else:
            # Save current event and start new one
            merged.append((current_start, current_end, current_prob))
            current_start, current_end, current_prob = start, end, prob
    
    # Add last event
    merged.append((current_start, current_end, current_prob))
    
    return merged

def classify_predictions(predictions, ground_truth, tolerance_start=0.25, 
                        tolerance_end=0.25, min_overlap=0.1):
    """
    Classify predictions as TP/FP and identify FN.
    
    Args:
        predictions: List of (start, end, prob) tuples
        ground_truth: List of (start, end) tuples
        tolerance_start: Start tolerance in seconds
        tolerance_end: End tolerance in seconds
        min_overlap: Minimum overlap ratio to count as TP
    
    Returns:
        tp_list: List of TP predictions (start, end, prob)
        fp_list: List of FP predictions (start, end, prob)
        fn_list: List of FN ground truth events (start, end)
    """
    if not ground_truth:
        # No ground truth - all predictions are unknown
        return [], [], []
    
    if not predictions:
        # No predictions - all ground truth are FN
        return [], [], ground_truth
    
    # Convert ground truth to arrays (predictions are iterated directly below)
    gt_starts = np.array([g[0] for g in ground_truth])
    gt_ends = np.array([g[1] for g in ground_truth])
    
    # Track matches
    gt_matched = np.zeros(len(ground_truth), dtype=bool)
    tp_list = []
    fp_list = []
    
    # Classify each prediction
    for pred_start, pred_end, prob in predictions:
        matched = False
        
        for i, (gt_start, gt_end) in enumerate(zip(gt_starts, gt_ends)):
            if gt_matched[i]:
                continue
            
            # Check overlap
            overlap_start = max(pred_start, gt_start - tolerance_start)
            overlap_end = min(pred_end, gt_end + tolerance_end)
            
            if overlap_end > overlap_start:
                overlap_duration = overlap_end - overlap_start
                gt_duration = gt_end - gt_start
                
                if overlap_duration / gt_duration >= min_overlap:
                    # True Positive
                    tp_list.append((pred_start, pred_end, prob))
                    gt_matched[i] = True
                    matched = True
                    break
        
        if not matched:
            # False Positive
            fp_list.append((pred_start, pred_end, prob))
    
    # Unmatched GT events are false negatives
    fn_list = [ground_truth[i] for i in range(len(ground_truth)) if not gt_matched[i]]
    
    return tp_list, fp_list, fn_list

print("✓ Detection merging and classification ready")
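
A small worked example may make the merge rule concrete. With 0.4 s windows and a 0.05 s hop, consecutive above-threshold windows overlap, so their gap is negative and they always fuse into one event. The function below, `merge_close`, is a condensed restatement of `merge_detections()` written for illustration only, and the window values are made up.

```python
def merge_close(predictions, gap_threshold=0.3):
    """Condensed restatement of merge_detections() for illustration."""
    merged = []
    for start, end, prob in sorted(predictions, key=lambda p: p[0]):
        if merged and start - merged[-1][1] <= gap_threshold:
            prev_start, prev_end, prev_prob = merged[-1]
            merged[-1] = (prev_start, max(prev_end, end), max(prev_prob, prob))
        else:
            merged.append((start, end, prob))
    return merged

windows = [
    (1.00, 1.40, 0.71),
    (1.05, 1.45, 0.93),
    (1.10, 1.50, 0.88),
    (3.00, 3.40, 0.64),   # starts 1.5 s after the previous event ends
]
print(merge_close(windows))
# prints [(1.0, 1.5, 0.93), (3.0, 3.4, 0.64)]
```

The first three windows collapse into one event that keeps the highest probability. The fourth stays separate because its gap exceeds `gap_threshold`.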

In [None]:
def refine_cough_events(audio, candidate_segments, fs_audio=16000,
                        t_dedup=0.23, t_bout=0.55):
    """
    Post-process merged candidate segments using Cough-E-style peak refinement.

    Fixes two failure modes of gap-based merging:
      1. Overlapping windows double-count the same cough.
      2. Multiple coughs in a bout collapse into one long merged region.

    Strategy
    --------
    Rather than computing power thresholds *within* each candidate chunk
    (which fails when one loud cough dominates the threshold and masks quieter
    ones in the same segment), we call segment_cough() on the **full audio**.
    Global thresholds are calibrated to the whole recording — mostly silence —
    so the RMS is low and every real cough clears the bar, regardless of whether
    it is the loudest or not.

    We then keep only physical coughs whose peak falls inside a confirmed
    candidate region (from the ML classifier), applying a ±0.3 s grace window
    to account for window-level timing uncertainty.

    After filtering, we apply the paper's two physiology-based rules:
      - Deduplicate peaks < t_dedup=0.23 s apart  (same cough, keep union)
      - Bout-split: if next peak < t_bout=0.55 s away, set end[n] = start[n+1]

    Args:
        audio:              (N,) audio samples at fs_audio Hz
        candidate_segments: List of (start, end, prob) from merge_detections()
        fs_audio:           Audio sampling rate in Hz (default 16000)
        t_dedup:            Min inter-peak gap to count as a new cough (0.23 s)
        t_bout:             Max inter-peak gap within a cough bout (0.55 s)

    Returns:
        refined: List of (start, end, prob) — one tuple per refined cough event
    """
    if not candidate_segments:
        return []

    # ── Step 1: Run segment_cough() on the FULL audio ─────────────────────────
    # Thresholds (th_h = 2*rms, th_l = 0.1*rms) are computed globally, so a
    # single loud spike does NOT raise the bar for quieter coughs elsewhere.
    _, _, starts_idx, ends_idx, _, peak_locs_idx = segment_cough(
        audio, fs_audio,
        cough_padding=0.1,    # 100 ms padding (smaller — we keep our own margins)
        min_cough_len=0.1,    # allow events as short as 100 ms
        th_l_multiplier=0.1,
        th_h_multiplier=2.0,
    )

    if len(starts_idx) == 0:
        # segment_cough found nothing — fall back to coarse candidate list
        return candidate_segments

    # Convert sample indices → seconds
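    # np.asarray keeps the division valid even if segment_cough returns lists
    seg_starts = np.asarray(starts_idx) / fs_audio
    seg_ends   = np.asarray(ends_idx)   / fs_audio
    seg_peaks  = np.asarray(peak_locs_idx) / fs_audio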

    # ── Step 2: Filter to confirmed candidate regions ─────────────────────────
    # Only keep physical coughs whose peak lies within (or within ±0.3 s of)
    # a ML-confirmed candidate region.  This suppresses false alarms from
    # segment_cough in silent / non-cough sections of the recording.
    cand_tol = 0.3
    filtered = []
    for seg_st, seg_et, seg_pt in zip(seg_starts, seg_ends, seg_peaks):
        for cand_st, cand_et, prob in candidate_segments:
            if cand_st - cand_tol <= seg_pt <= cand_et + cand_tol:
                filtered.append([seg_st, seg_et, seg_pt, prob])
                break

    if not filtered:
        # No overlap found — fall back so we don't silently drop everything
        return candidate_segments

    # Sort by peak time before applying the next two rules
    filtered.sort(key=lambda x: x[2])

    # ── Step 3: Deduplicate peaks closer than t_dedup (same physical cough) ──
    deduped = [filtered[0][:]]
    for r in filtered[1:]:
        prev = deduped[-1]
        if r[2] - prev[2] < t_dedup:
            # Same cough: take union of regions, keep max probability
            prev[0] = min(prev[0], r[0])
            prev[1] = max(prev[1], r[1])
            prev[3] = max(prev[3], r[3])
        else:
            deduped.append(r[:])

    # ── Step 4: Bout splitting ────────────────────────────────────────────────
    # If the next cough's peak is within t_bout of the current peak, the two
    # coughs are in a bout.  Hard-split by setting end[n] = start[n+1].
    final = []
    for i, r in enumerate(deduped):
        start_t, end_t, peak_t, prob = r
        if i + 1 < len(deduped):
            next_start = deduped[i + 1][0]
            next_peak  = deduped[i + 1][2]
            if next_peak - peak_t < t_bout:
                end_t = next_start   # hard split at next region start
        final.append((start_t, end_t, prob))

    return final

print("✓ Cough-E post-processing (refine_cough_events) ready")
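
The deduplication and bout-splitting rules (steps 3 and 4) can be traced on a handful of made-up peak times. The sketch below drops the probability field and the `segment_cough()` call. The name `dedup_and_split` and the event values are hypothetical, and only the two timing rules are reproduced.

```python
def dedup_and_split(events, t_dedup=0.23, t_bout=0.55):
    """events: list of (start, end, peak) in seconds, sorted by peak time."""
    if not events:
        return []
    deduped = [list(events[0])]
    for start, end, peak in events[1:]:
        prev = deduped[-1]
        if peak - prev[2] < t_dedup:
            prev[1] = max(prev[1], end)      # same cough: extend the region
        else:
            deduped.append([start, end, peak])
    result = []
    for i, (start, end, peak) in enumerate(deduped):
        if i + 1 < len(deduped) and deduped[i + 1][2] - peak < t_bout:
            end = deduped[i + 1][0]          # bout: cut at the next start
        result.append((start, end))
    return result

events = [
    (1.00, 1.30, 1.10),
    (1.20, 1.45, 1.25),   # peak 0.15 s after the previous one: duplicate
    (1.50, 1.90, 1.60),   # peak 0.50 s after the first: same bout
    (4.00, 4.30, 4.10),   # isolated cough
]
print(dedup_and_split(events))
# prints [(1.0, 1.5), (1.5, 1.9), (4.0, 4.3)]
```

Four physical segments become three cough events. The duplicate is absorbed, and the first event's end is moved to the start of its bout neighbor.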

## Section 5: Visualization Functions

In [None]:
def plot_predictions_plotly(audio, imu, predictions, ground_truth, all_windows, 
                           threshold, window_times, all_probs):
    """
    Create interactive Plotly visualization with multiple views.
    
    Args:
        audio: (N,) audio samples
        imu: (M, 6) IMU samples
        predictions: List of merged (start, end, prob) tuples
        ground_truth: Optional list of (start, end) tuples
        all_windows: List of ALL (start, end, center, prob) tuples
        threshold: Classification threshold used
        window_times: Array of window center times
        all_probs: Array of all window probabilities
    
    Returns:
        plotly.graph_objects.Figure
    """
    # Classify predictions if ground truth available
    if ground_truth:
        tp_list, fp_list, fn_list = classify_predictions(predictions, ground_truth)
    else:
        tp_list, fp_list, fn_list = [], [], []

    # Map (start, end) → global event index so numbers stay consistent across
    # TP/FP/FN classification and match the rows in the events table.
    pred_index_map = {(s, e): i for i, (s, e, _) in enumerate(predictions)}

    # Color pairs [even-index, odd-index] — alternate so adjacent events are
    # always visually distinct even when their boundaries touch.
    if ground_truth:
        tp_colors = ['rgba(40, 167, 69, 0.4)',  'rgba(23, 162, 184, 0.4)']  # green / teal
        fp_colors = ['rgba(220, 53, 69, 0.4)',  'rgba(253, 126, 20, 0.4)']  # red / orange
        fn_color  = 'rgba(255, 193, 7, 0.25)'                                # amber (no alternation)
    else:
        pred_colors = ['rgba(220, 53, 69, 0.4)', 'rgba(52, 120, 219, 0.4)'] # red / steel-blue

    # Create subplots: Audio, Raw Windows, Probability Timeline, IMU, Histogram
    fig = make_subplots(
        rows=5, cols=1,
        row_heights=[0.25, 0.15, 0.15, 0.25, 0.20],
        subplot_titles=(
            f'Audio Waveform + Merged Events ({len(predictions)} detections)',
            f'Raw Window Predictions ({len(all_windows)} windows, {sum(all_probs >= threshold)} above threshold)',
            f'Probability Timeline (threshold = {threshold:.3f})',
            'IMU Accelerometer Z (negated) + Merged Events',
            f'Probability Distribution ({len(all_probs)} windows)'
        ),
        vertical_spacing=0.08,
        specs=[[{"secondary_y": False}],
               [{"secondary_y": False}],
               [{"secondary_y": False}],
               [{"secondary_y": False}],
               [{"secondary_y": False}]]
    )
    
    # Time axes
    audio_time = np.arange(len(audio)) / FS_AUDIO_CONST
    imu_time = np.arange(len(imu)) / FS_IMU_CONST
    
    # ========== Subplot 1: Audio Waveform ==========
    fig.add_trace(
        go.Scatter(x=audio_time, y=audio, mode='lines', 
                  line=dict(color='black', width=0.5),
                  name='Audio', showlegend=True,
                  hovertemplate='Time: %{x:.3f}s<br>Amplitude: %{y:.3f}<extra></extra>'),
        row=1, col=1
    )
    
    # Add merged events — alternating colors + thin border + event number label.
    # The border ensures adjacent events are visually separated even when they
    # share an exact boundary (e.g. event #1 ends at 2.00s, #2 starts at 2.00s).
    if ground_truth:
        for start, end, prob in tp_list:
            idx = pred_index_map.get((start, end), 0)
            fig.add_vrect(
                x0=start, x1=end,
                fillcolor=tp_colors[idx % 2], opacity=1.0,
                layer="below", line_width=1, line_color='rgba(40,167,69,0.8)',
                row=1, col=1,
                annotation_text=f"TP #{idx + 1}",
                annotation_position="top left",
            )
        for start, end, prob in fp_list:
            idx = pred_index_map.get((start, end), 0)
            fig.add_vrect(
                x0=start, x1=end,
                fillcolor=fp_colors[idx % 2], opacity=1.0,
                layer="below", line_width=1, line_color='rgba(180,40,40,0.8)',
                row=1, col=1,
                annotation_text=f"FP #{idx + 1}",
                annotation_position="top left",
            )
        for start, end in fn_list:
            fig.add_vrect(
                x0=start, x1=end,
                fillcolor=fn_color, opacity=1.0,
                layer="below", line_width=1, line_color='rgba(180,140,0,0.6)',
                row=1, col=1,
                annotation_text="FN",
                annotation_position="top left",
            )
    else:
        for i, (start, end, prob) in enumerate(predictions):
            fig.add_vrect(
                x0=start, x1=end,
                fillcolor=pred_colors[i % 2], opacity=1.0,
                layer="below", line_width=1, line_color='rgba(80,80,80,0.5)',
                row=1, col=1,
                annotation_text=f"#{i + 1}",
                annotation_position="top left",
            )
    
    # Add ground truth spans
    if ground_truth:
        gt_legend_added = False
        for start, end in ground_truth:
            fig.add_trace(
                go.Scatter(x=[start, end], y=[0, 0], mode='lines',
                          line=dict(color='green', width=3),
                          name='Ground Truth', showlegend=(not gt_legend_added),
                          legendgroup='ground_truth',
                          hovertemplate=f'GT: {start:.3f}s - {end:.3f}s<extra></extra>'),
                row=1, col=1
            )
            gt_legend_added = True
    
    # ========== Subplot 2: Raw Window Predictions ==========
    # Build one hover label per window; marker color comes from the colorscale below.
    hover_texts = []
    for idx, (start, end, center, prob) in enumerate(all_windows):
        hover_texts.append(
            f'Window #{idx}<br>' +
            f'Center: {center:.3f}s<br>' +
            f'Range: {start:.3f}s - {end:.3f}s<br>' +
            f'Duration: {end-start:.3f}s<br>' +
            f'Probability: {prob:.4f}<br>' +
            f'Above threshold: {"Yes" if prob >= threshold else "No"}'
        )
    
    fig.add_trace(
        go.Scatter(
            x=[w[2] for w in all_windows],  # center times
            y=[w[3] for w in all_windows],  # probabilities
            mode='markers',
            marker=dict(
                size=8,
                color=[w[3] for w in all_windows],
                colorscale='Reds',
                showscale=True,
                colorbar=dict(title="Probability", x=1.12, len=0.2, y=0.75,yanchor='middle'),
                line=dict(width=1, color='black')
            ),
            name='Raw Windows',
            text=hover_texts,
            hovertemplate='%{text}<extra></extra>'
        ),
        row=2, col=1
    )
    
    # Add threshold line
    fig.add_hline(y=threshold, line_dash="dash", line_color="red",
                 annotation_text="Threshold",
                 annotation_position="right", row=2, col=1
                 )
    
    # ========== Subplot 3: Probability Timeline ==========
    # Continuous probability curve
    fig.add_trace(
        go.Scatter(
            x=window_times,
            y=all_probs,
            mode='lines',
            line=dict(color='blue', width=2),
            fill='tozeroy',  # only trace on these axes, so fill down to y = 0
            name='Probability',
            hovertemplate='Time: %{x:.3f}s<br>Probability: %{y:.4f}<extra></extra>'
        ),
        row=3, col=1
    )
    
    # Highlight above-threshold regions
    above_thresh = all_probs >= threshold
    for i in range(len(window_times) - 1):
        if above_thresh[i]:
            fig.add_vrect(
                x0=window_times[i], x1=window_times[i+1],
                fillcolor="red", opacity=0.2,
                layer="below", line_width=0, row=3, col=1
            )
    
    # Add threshold line
    fig.add_hline(y=threshold, line_dash="dash", line_color="red",
                 row=3, col=1)
    
    # ========== Subplot 4: IMU Signal ==========
    fig.add_trace(
        go.Scatter(x=imu_time, y=-imu[:, 2], mode='lines',
                  line=dict(color='blue', width=1),
                  name='IMU Z-axis', showlegend=True,
                  hovertemplate='Time: %{x:.3f}s<br>Accel: %{y:.3f}<extra></extra>'),
        row=4, col=1
    )
    
    # IMU events — same alternating colors and borders as audio, no labels
    # (labels would be redundant and clutter the smaller IMU subplot).
    if ground_truth:
        for start, end, prob in tp_list:
            idx = pred_index_map.get((start, end), 0)
            fig.add_vrect(
                x0=start, x1=end,
                fillcolor=tp_colors[idx % 2], opacity=1.0,
                layer="below", line_width=1, line_color='rgba(40,167,69,0.8)',
                row=4, col=1,
            )
        for start, end, prob in fp_list:
            idx = pred_index_map.get((start, end), 0)
            fig.add_vrect(
                x0=start, x1=end,
                fillcolor=fp_colors[idx % 2], opacity=1.0,
                layer="below", line_width=1, line_color='rgba(180,40,40,0.8)',
                row=4, col=1,
            )
        for start, end in fn_list:
            fig.add_vrect(
                x0=start, x1=end,
                fillcolor=fn_color, opacity=1.0,
                layer="below", line_width=1, line_color='rgba(180,140,0,0.6)',
                row=4, col=1,
            )
    else:
        for i, (start, end, prob) in enumerate(predictions):
            fig.add_vrect(
                x0=start, x1=end,
                fillcolor=pred_colors[i % 2], opacity=1.0,
                layer="below", line_width=1, line_color='rgba(80,80,80,0.5)',
                row=4, col=1,
            )
    
    # ========== Subplot 5: Probability Distribution ==========
    fig.add_trace(
        go.Histogram(
            x=all_probs,
            nbinsx=50,
            name='All Windows',
            marker_color='lightblue',
            hovertemplate='Probability: %{x:.3f}<br>Count: %{y}<extra></extra>'
        ),
        row=5, col=1
    )
    
    # Add threshold line
    fig.add_vline(x=threshold, line_dash="dash", line_color="red",
                 annotation_text="Threshold",
                 annotation_position="top", row=5, col=1)
    
    # Statistics annotation (guarded so an empty probability array cannot divide by zero)
    n_total = len(all_probs)
    n_above = int(np.sum(np.asarray(all_probs) >= threshold))
    mean_prob = float(np.mean(all_probs)) if n_total else 0.0
    stats_text = (
        f"Total Windows: {n_total}<br>" +
        f"Above Threshold: {n_above} ({100 * n_above / max(n_total, 1):.1f}%)<br>" +
        f"Mean Prob: {mean_prob:.3f}<br>" +
        f"Merged Events: {len(predictions)}<br>" +
        f"Merge Ratio: {n_total}/{len(predictions)} = {n_total / max(len(predictions), 1):.1f}x"
    )
    
    fig.add_annotation(
        text=stats_text,
        xref="paper", yref="paper",
        x=0.02, y=0.98, showarrow=False,
        bgcolor="white", bordercolor="black", borderwidth=1,
        align="left", font=dict(size=10)
    )
    
    # Update layout
    fig.update_xaxes(title_text="Time (s)", row=1, col=1)
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_xaxes(title_text="Time (s)", row=3, col=1)
    fig.update_xaxes(title_text="Time (s)", row=4, col=1)
    fig.update_xaxes(title_text="Probability", row=5, col=1)
    
    fig.update_yaxes(title_text="Amplitude", row=1, col=1)
    fig.update_yaxes(title_text="Probability", row=2, col=1)
    fig.update_yaxes(title_text="Probability", row=3, col=1)
    fig.update_yaxes(title_text="Acceleration", row=4, col=1)
    fig.update_yaxes(title_text="Count", row=5, col=1)
    
    # Sync x-axes for time-based plots
    fig.update_xaxes(matches='x', row=1, col=1)
    fig.update_xaxes(matches='x', row=2, col=1)
    fig.update_xaxes(matches='x', row=3, col=1)
    fig.update_xaxes(matches='x', row=4, col=1)
    
    fig.update_layout(
        height=1400,
        showlegend=True,
        hovermode='x unified',
        title_text="Interactive Cough Detection Analysis",
        title_font_size=16
    )
    
    return fig

print("✓ Interactive Plotly visualization ready")
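
The highlight loop in the probability timeline adds one shape per above-threshold window, which can mean many shapes on a long recording. One possible alternative is to group consecutive above-threshold windows into spans first and draw one rectangle per span. This is offered as an assumption about what might help, not as the notebook's method, and `above_threshold_spans` is a hypothetical helper.

```python
def above_threshold_spans(window_times, probs, threshold):
    """Group consecutive above-threshold windows into (start, end) spans.

    A run covering windows i..j maps to (window_times[i], window_times[j + 1]),
    the same extent the per-window loop above would shade.
    """
    spans = []
    run_start = None
    for t, p in zip(window_times, probs):
        if p >= threshold:
            if run_start is None:
                run_start = t
        elif run_start is not None:
            spans.append((run_start, t))
            run_start = None
    # A run that reaches the final window ends at the last center time.
    if run_start is not None and run_start < window_times[-1]:
        spans.append((run_start, window_times[-1]))
    return spans

times = [0.20, 0.25, 0.30, 0.35, 0.40]
probs = [0.1, 0.8, 0.9, 0.2, 0.7]
print(above_threshold_spans(times, probs, threshold=0.5))
# prints [(0.25, 0.35)]
```

Each resulting span could then be passed to a single `fig.add_vrect(x0=..., x1=...)` call in place of the per-window rectangles.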

## Section 6: Metrics Computation

In [None]:
def compute_event_metrics(predictions, ground_truth, tolerance_start=0.25, 
                         tolerance_end=0.25, min_overlap=0.1):
    """
    Compute event-based metrics (TP, FP, FN).
    
    Args:
        predictions: List of (start, end, prob) tuples
        ground_truth: List of (start, end) tuples
        tolerance_start: Start tolerance in seconds
        tolerance_end: End tolerance in seconds
        min_overlap: Minimum overlap ratio to count as TP
    
    Returns:
        dict: Metrics including TP, FP, FN, Sensitivity, Precision, F1
    """
    if not ground_truth:
        # No ground truth - just return detection count
        return {
            'TP': None,
            'FP': None,
            'FN': None,
            'Sensitivity': None,
            'Precision': None,
            'F1': None,
            'Total_Detections': len(predictions),
            'Ground_Truth_Count': 0
        }
    
    # Reuse classify_predictions() so the counts always agree with the
    # TP/FP/FN regions drawn in the plots.
    tp_list, fp_list, fn_list = classify_predictions(
        predictions, ground_truth,
        tolerance_start=tolerance_start,
        tolerance_end=tolerance_end,
        min_overlap=min_overlap,
    )
    tp_count = len(tp_list)
    fp_count = len(fp_list)
    fn_count = len(fn_list)
    
    # Compute metrics
    sensitivity = tp_count / (tp_count + fn_count) if (tp_count + fn_count) > 0 else 0
    precision = tp_count / (tp_count + fp_count) if (tp_count + fp_count) > 0 else 0
    f1 = 2 * precision * sensitivity / (precision + sensitivity) if (precision + sensitivity) > 0 else 0
    
    return {
        'TP': int(tp_count),
        'FP': int(fp_count),
        'FN': int(fn_count),
        'Sensitivity': float(sensitivity),
        'Precision': float(precision),
        'F1': float(f1),
        'Total_Detections': len(predictions),
        'Ground_Truth_Count': len(ground_truth)
    }

print("✓ Metrics computation ready")

## Section 7: Data Loading Utilities

In [None]:
def load_dataset_recording(subject_id, trial, movement, noise, sound):
    """
    Load a recording from the dataset.
    
    Returns:
        audio: (N,) audio samples (outer mic)
        imu_data: (M, 6) IMU samples
        ground_truth: List of (start, end) or None
    """
    # Load audio (outer mic only) - uses peak normalization by default
    audio_air, _ = load_audio(data_folder, subject_id, trial, movement, noise, sound)
    
    # Load IMU
    imu_obj = load_imu(data_folder, subject_id, trial, movement, noise, sound)
    imu_data = imu_obj.make_segment_df().values
    
    # Load ground truth if cough recording
    ground_truth = None
    if sound == Sound.COUGH:
        try:
            start_times, end_times = load_annotation(data_folder, subject_id, trial, movement, noise, sound)
            ground_truth = list(zip(start_times, end_times))
        except Exception:
            # Annotation missing or unreadable: continue without ground truth
            pass
    
    return audio_air, imu_data, ground_truth

def load_uploaded_audio(file_obj):
    """
    Load audio from various formats (WAV, MP3, OGG, M4A, WEBM) and convert to 16 kHz.
    
    Returns:
        audio: (N,) normalized audio samples in [-1, +1] range at 16 kHz
    """
    # librosa.load:
    # - Decodes WAV, FLAC, and OGG directly; MP3, M4A, and WEBM may need an
    #   ffmpeg-backed fallback decoder, depending on the installed backends
    # - Resamples to the target sr (16000 Hz) and downmixes to mono
    # - Returns float32 samples
    audio, _ = librosa.load(file_obj.name, sr=FS_AUDIO_CONST, mono=True)
    
    # Remove DC offset, then peak-normalize (matching training preprocessing)
    audio = audio - np.mean(audio)
    audio = audio / (np.max(np.abs(audio)) + 1e-17)
    
    return audio

def load_uploaded_imu(file_obj):
    """
    Load IMU CSV from Gradio file upload.
    
    Returns:
        imu: (N, 6) IMU samples
    """
    df = pd.read_csv(file_obj.name)
    required_cols = ['Accel x', 'Accel y', 'Accel z', 'Gyro Y', 'Gyro P', 'Gyro R']
    
    if not all(col in df.columns for col in required_cols):
        raise ValueError(f"IMU CSV must contain: {required_cols}")
    
    return df[required_cols].values

def create_dummy_audio(duration_seconds):
    """
    Create dummy audio data for when only IMU is provided.
    
    Args:
        duration_seconds: Duration in seconds
    
    Returns:
        audio: (N,) zero-filled audio samples
    """
    n_samples = int(duration_seconds * FS_AUDIO_CONST)
    return np.zeros(n_samples, dtype=np.float32)

def create_dummy_imu(duration_seconds):
    """
    Create dummy IMU data for when only audio is provided.
    
    Args:
        duration_seconds: Duration in seconds
    
    Returns:
        imu: (N, 6) zero-filled IMU samples
    """
    n_samples = int(duration_seconds * FS_IMU_CONST)
    return np.zeros((n_samples, 6), dtype=np.float32)

print("✓ Data loading utilities ready")

## Section 8: Main Gradio Interface

In [None]:
def run_prediction(data_source, subject_id, trial, movement, noise, sound,
                  audio_file, imu_file, modality, threshold_override, use_refinement):
    """
    Main prediction function called by Gradio interface.
    """
    if MODELS is None:
        return None, None, {"Error": "Models not loaded"}, pd.DataFrame()
    
    try:
        # Map modality to model key
        modality_map = {
            "IMU-only": "imu",
            "Audio-only": "audio",
            "Multimodal": "multimodal"
        }
        model_key = modality_map[modality]
        
        # Load data based on source
        if data_source == "Dataset Selector":
            if not data_folder:
                return None, None, {"Error": "Dataset not found"}, pd.DataFrame()
            
            # Convert dropdown values to Enum
            trial_enum = Trial(trial)
            mov_enum = Movement(movement.lower())
            noise_enum = Noise(noise.lower().replace(' ', '_'))
            sound_enum = Sound(sound.lower().replace(' ', '_'))
            
            audio, imu, ground_truth = load_dataset_recording(
                subject_id, trial_enum, mov_enum, noise_enum, sound_enum
            )
        else:  # Upload Files
            # Check which files are required based on modality
            audio = None
            imu = None
            
            if model_key in ['audio', 'multimodal']:
                if audio_file is None:
                    return None, None, {"Error": f"Please upload audio file for {modality} model"}, pd.DataFrame()
                audio = load_uploaded_audio(audio_file)
            
            if model_key in ['imu', 'multimodal']:
                if imu_file is None:
                    return None, None, {"Error": f"Please upload IMU file for {modality} model"}, pd.DataFrame()
                imu = load_uploaded_imu(imu_file)
            
            # Create dummy data for missing modality if not required
            if audio is None and imu is not None:
                # IMU-only model: create dummy audio matching IMU duration
                duration = len(imu) / FS_IMU_CONST
                audio = create_dummy_audio(duration)
            elif imu is None and audio is not None:
                # Audio-only model: create dummy IMU matching audio duration
                duration = len(audio) / FS_AUDIO_CONST
                imu = create_dummy_imu(duration)
            
            ground_truth = None
        
        model_data = MODELS[model_key]
        
        # Override threshold if specified (0.0 means use optimal)
        threshold = model_data['threshold'] if threshold_override == 0.0 else threshold_override
        
        # Run prediction - now returns all_windows as well
        raw_predictions, all_probs, window_times, all_windows = sliding_window_predict(
            audio, imu, model_data, modality=model_key, threshold=threshold
        )
        
        if use_refinement:
            # Merge into coarse candidate segments with a wider gap so the refiner
            # can find and split bouts that would otherwise be pre-split here.
            candidate_segments = merge_detections(raw_predictions, gap_threshold=0.5)
            # Refine: hysteresis peak extraction → dedup at 0.23s → bout split at 0.55s
            predictions = refine_cough_events(audio, candidate_segments)
        else:
            # Classic gap-based merging only (original behaviour)
            predictions = merge_detections(raw_predictions, gap_threshold=0.3)
        
        # Compute metrics
        metrics = compute_event_metrics(predictions, ground_truth)
        
        # Add threshold info to metrics
        metrics['Threshold_Used'] = float(threshold)
        metrics['Is_Optimal_Threshold'] = (threshold_override == 0.0)
        metrics['Cough_E_Refinement'] = bool(use_refinement)
        
        # Create interactive Plotly visualization
        fig = plot_predictions_plotly(
            audio, imu, predictions, ground_truth, 
            all_windows, threshold, window_times, all_probs
        )
        
        # Create events table
        if predictions:
            events_df = pd.DataFrame([
                {'Start (s)': f'{s:.2f}', 'End (s)': f'{e:.2f}', 'Confidence': f'{p:.3f}'}
                for s, e, p in predictions
            ])
        else:
            events_df = pd.DataFrame({'Message': ['No coughs detected']})
        
        # Prepare audio for playback (only if real audio data exists)
        audio_playback = None
        if data_source == "Dataset Selector" or (data_source == "Upload Files" and audio_file is not None):
            audio_playback = (FS_AUDIO_CONST, audio)
        
        return fig, audio_playback, metrics, events_df
    
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        return None, None, {"Error": error_msg}, pd.DataFrame()

print("✓ Main prediction function ready")

In [None]:
# Get dataset parameters for dropdowns
if data_folder:
    subject_ids = [d for d in os.listdir(data_folder) if os.path.isdir(os.path.join(data_folder, d))]
    subject_ids = sorted(subject_ids)
else:
    subject_ids = []

# Create Gradio interface
with gr.Blocks(title="Interactive Cough Detection Model Tester", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Interactive Cough Detection Model Tester
        
        Test XGBoost cough detection models on multimodal biosignals (audio + IMU).
        
        **Instructions:**
        1. Choose data source: Dataset recordings or upload your own files
        2. Select model: IMU-only, Audio-only, or Multimodal
        3. Upload only the required file(s) based on your selected model
        4. Adjust threshold if needed (0 = use optimal from training)
        5. Click "Run Prediction" to see results and listen to the audio
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            # Data source selector
            data_source = gr.Radio(
                choices=["Dataset Selector", "Upload Files"],
                value="Dataset Selector" if data_folder else "Upload Files",
                label="Data Source"
            )
            
            # Dataset selector (visible when Dataset Selector is chosen)
            with gr.Group(visible=(data_folder is not None)) as dataset_group:
                gr.Markdown("### Dataset Recording")
                subject_dropdown = gr.Dropdown(
                    choices=subject_ids,
                    value=subject_ids[0] if subject_ids else None,
                    label="Subject ID"
                )
                trial_dropdown = gr.Dropdown(
                    choices=["1", "2", "3"],
                    value="1",
                    label="Trial"
                )
                movement_dropdown = gr.Dropdown(
                    choices=["Sit", "Walk"],
                    value="Sit",
                    label="Movement"
                )
                noise_dropdown = gr.Dropdown(
                    choices=["Nothing", "Music", "Someone else cough", "Traffic"],
                    value="Nothing",
                    label="Background Noise"
                )
                sound_dropdown = gr.Dropdown(
                    choices=["Cough", "Laugh", "Deep breathing", "Throat clearing"],
                    value="Cough",
                    label="Sound Type"
                )
            
            # File upload (visible when Upload Files is chosen)
            with gr.Group(visible=(data_folder is None)) as upload_group:
                gr.Markdown("### Upload Files")
                
                # Dynamic hint about required files
                file_requirements_hint = gr.Markdown(
                    "**Required files:** Audio + IMU (for Multimodal model)"
                )
                
                audio_upload = gr.File(
                    label="Audio",
                )
                imu_upload = gr.File(
                    label="IMU CSV (100 Hz)",
                    file_types=[".csv"]
                )
                gr.Markdown(
                    "*CSV must contain: Accel x, Accel y, Accel z, Gyro Y, Gyro P, Gyro R*"
                )
            
            # Toggle visibility based on data source
            def toggle_data_source(choice):
                if choice == "Dataset Selector":
                    return gr.update(visible=True), gr.update(visible=False)
                else:
                    return gr.update(visible=False), gr.update(visible=True)
            
            data_source.change(
                toggle_data_source,
                inputs=[data_source],
                outputs=[dataset_group, upload_group]
            )
            
            # Model selection
            gr.Markdown("### Model Settings")
            modality_radio = gr.Radio(
                choices=["IMU-only", "Audio-only", "Multimodal"],
                value="Multimodal",
                label="Model"
            )
            
            # Display optimal threshold for selected model
            if MODELS is not None:
                optimal_thresh_display = gr.Markdown(
                    f"**Optimal Threshold:** {MODELS['multimodal']['threshold']:.3f}"
                )
            else:
                optimal_thresh_display = gr.Markdown("**Optimal Threshold:** Models not loaded")
            
            # Update optimal threshold display and file requirements when model changes
            def update_optimal_threshold(modality):
                if MODELS is None:
                    return "**Optimal Threshold:** Models not loaded"
                model_key = modality.lower().replace('-only', '')
                thresh = MODELS[model_key]['threshold']
                return f"**Optimal Threshold:** {thresh:.3f}"
            
            def update_file_requirements(modality):
                if modality == "IMU-only":
                    return "**Required files:** IMU only (audio not needed)"
                elif modality == "Audio-only":
                    return "**Required files:** Audio only (IMU not needed)"
                else:  # Multimodal
                    return "**Required files:** Both Audio + IMU"
            
            modality_radio.change(
                update_optimal_threshold,
                inputs=[modality_radio],
                outputs=[optimal_thresh_display]
            )
            
            modality_radio.change(
                update_file_requirements,
                inputs=[modality_radio],
                outputs=[file_requirements_hint]
            )
            
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.05,
                label="Threshold Override",
                info="Set to 0.0 to use optimal threshold above, or override with custom value"
            )
            
            refinement_checkbox = gr.Checkbox(
                value=True,
                label="Use Cough-E Refinement",
                info=(
                    "Post-process merged detections with peak-level analysis: "
                    "hysteresis on signal power, dedup peaks <0.23 s apart, "
                    "split cough bouts at 0.55 s. Improves per-cough counting "
                    "accuracy, especially for rapid or sequential coughs. "
                    "Uncheck to use classic gap-based merging only."
                )
            )
            
            # Run button
            run_btn = gr.Button(
                "Run Prediction",
                variant="primary",
                size="lg"
            )
        
        with gr.Column(scale=2):
            # Outputs
            gr.Markdown("### Results")
            
            # Audio playback widget
            audio_output = gr.Audio(
                label="Audio Playback",
                type="numpy",
                interactive=False
            )
            
            plot_output = gr.Plot(label="Waveform with Detections")
            metrics_output = gr.JSON(label="Metrics")
            events_output = gr.Dataframe(label="Detected Events")
    
    # Connect button to prediction function
    run_btn.click(
        run_prediction,
        inputs=[
            data_source, subject_dropdown, trial_dropdown, movement_dropdown,
            noise_dropdown, sound_dropdown, audio_upload, imu_upload,
            modality_radio, threshold_slider, refinement_checkbox
        ],
        outputs=[plot_output, audio_output, metrics_output, events_output]
    )

print("✓ Gradio interface created")

## Section 9: Launch Application

In [None]:
IS_IN_KAGGLE = os.environ.get('KAGGLE_URL_BASE') is not None

# Launch Gradio app
if MODELS is not None:
    print("\n" + "="*70)
    print("Launching Interactive Cough Detection Model Tester...")
    print("="*70)
    demo.launch(share=IS_IN_KAGGLE, inline=False, debug=True)
else:
    print("\n" + "="*70)
    print("Cannot launch: Models not loaded")
    print("Please run Model_Training_XGBoost.ipynb first to train models")
    print("="*70)

## Section 10: Usage Examples

### Example 1: Test on Dataset Recording

1. Select "Dataset Selector" as data source
2. Choose Subject: `14287`, Trial: `1`, Movement: `Sit`, Noise: `Nothing`, Sound: `Cough`
3. Select Model: `Multimodal`
4. Keep threshold at `0.0` (auto-optimal)
5. Click "Run Prediction"

**Expected output:**
- **Audio Playback**: Interactive player to listen to the recording
- **Waveform plot**: Red prediction spans and green ground truth spans
- **Metrics**: TP/FP/FN counts, Sensitivity ~0.9+, Precision ~0.8+
- **Events table**: Detected cough times with confidence scores

**Tip**: Use the audio player to correlate what you hear with the visual predictions!
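
The sensitivity, precision, and F1 values in the metrics panel follow directly from the event counts. A minimal sketch of that arithmetic, using made-up counts (the helper name `summarize_counts` is hypothetical and not part of the notebook):

```python
def summarize_counts(tp, fp, fn):
    """Derive event-level metrics from TP/FP/FN counts (illustrative helper)."""
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    denom = precision + sensitivity
    f1 = 2 * precision * sensitivity / denom if denom > 0 else 0.0
    return {"Sensitivity": sensitivity, "Precision": precision, "F1": f1}


# Made-up example: 9 matched coughs, 2 spurious detections, 1 missed cough
example = summarize_counts(tp=9, fp=2, fn=1)
print({name: round(value, 3) for name, value in example.items()})
```

With these counts, sensitivity is 0.9, precision is about 0.818, and F1 is about 0.857.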

### Example 2: Compare Models

Run the same recording through all three models:
- IMU-only: Lower sensitivity, may miss some coughs
- Audio-only: Good performance on coughs
- Multimodal: Best overall performance

Listen to the audio while comparing predictions to understand how each modality performs.

### Example 3: Threshold Adjustment

1. Run prediction with threshold `0.0` (optimal)
2. Increase threshold to `0.7`: typically fewer detections and higher precision, at the cost of sensitivity
3. Decrease threshold to `0.3`: typically more detections and higher sensitivity, at the cost of precision

Play the audio to verify which threshold setting matches your perception of coughs.
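
To build intuition for the trade-off, the sketch below applies several thresholds to synthetic window probabilities. It is not the notebook's pipeline: `resolve_threshold` and `count_positive_windows` are hypothetical helpers, the probabilities are random numbers, and the `0.45` optimal threshold is made up. Only the rule that an override of `0.0` falls back to the stored optimal threshold mirrors `run_prediction`.

```python
import numpy as np


def resolve_threshold(override, optimal):
    """Return the stored optimal threshold when the override is 0.0."""
    return optimal if override == 0.0 else override


def count_positive_windows(probs, threshold):
    """Count sliding windows whose probability reaches the threshold."""
    return int(np.sum(np.asarray(probs) >= threshold))


# Synthetic window probabilities (assumption: stand-in for real model output)
rng = np.random.default_rng(seed=0)
window_probs = rng.random(200)

made_up_optimal = 0.45  # hypothetical value, not taken from a trained model
for override in (0.0, 0.3, 0.7):
    thr = resolve_threshold(override, made_up_optimal)
    n_pos = count_positive_windows(window_probs, thr)
    print(f"override={override:.2f} -> threshold={thr:.2f}, positive windows={n_pos}")
```

Raising the threshold can only keep or reduce the number of positive windows; how that translates into event-level precision and sensitivity depends on the recording.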

### Example 4: Test on Non-Cough Sounds

1. Select Sound: `Laugh` or `Throat clearing`
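2. The model should produce few or no detections; any detection here is a false positive
3. No ground truth is shown (annotations exist only for coughs), so TP/FP/FN and the derived scores are left empty and only `Total_Detections` is meaningful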
4. Listen to understand what non-cough sounds are present

### Example 5: Upload Custom Files (Multimodal)

1. Select "Upload Files" as data source
2. Select Model: `Multimodal`
3. Upload both an audio file (WAV, MP3, OGG, M4A, or WEBM) and a CSV IMU file
4. Run prediction (no ground truth comparison available)
5. Use audio playback to verify predictions on your own data

### Example 6: Upload Audio Only

1. Select "Upload Files" as data source
2. Select Model: `Audio-only`
3. Upload only the audio file (no IMU needed)
4. Notice the hint says "Required files: Audio only (IMU not needed)"
5. Run prediction - the system will automatically create dummy IMU data internally
6. Listen to the audio to verify predictions

### Example 7: Upload IMU Only

1. Select "Upload Files" as data source
2. Select Model: `IMU-only`
3. Upload only the CSV IMU file (no audio needed)
4. Notice the hint says "Required files: IMU only (audio not needed)"
5. Run prediction - the system will automatically create dummy audio data internally
6. No audio playback will be available (since no real audio was provided)
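
If no IMU recording is at hand, a CSV in the expected layout can be generated for a smoke test. The sketch below writes random values under the six column names that `load_uploaded_imu` requires and checks them the same way. The 100 Hz rate comes from the upload widget's label, while the helper names and the random data are assumptions for illustration. To produce a file for upload, you would write to disk with `to_csv("my_imu.csv", index=False)` (a hypothetical filename) rather than to an in-memory buffer.

```python
import io

import numpy as np
import pandas as pd

REQUIRED_COLS = ['Accel x', 'Accel y', 'Accel z', 'Gyro Y', 'Gyro P', 'Gyro R']
FS_IMU = 100  # Hz, per the "IMU CSV (100 Hz)" upload label


def make_synthetic_imu(duration_s, seed=0):
    """Build a DataFrame of random IMU-like samples (illustrative only)."""
    rng = np.random.default_rng(seed)
    n_samples = int(duration_s * FS_IMU)
    return pd.DataFrame(rng.normal(size=(n_samples, len(REQUIRED_COLS))),
                        columns=REQUIRED_COLS)


def validate_imu_csv(source):
    """Read a CSV and return an (N, 6) array, or raise if columns are missing."""
    df = pd.read_csv(source)
    missing = [col for col in REQUIRED_COLS if col not in df.columns]
    if missing:
        raise ValueError(f"IMU CSV is missing columns: {missing}")
    return df[REQUIRED_COLS].values


# Round-trip through an in-memory buffer instead of a file on disk
buffer = io.StringIO()
make_synthetic_imu(duration_s=3.0).to_csv(buffer, index=False)
buffer.seek(0)
imu = validate_imu_csv(buffer)
print(imu.shape)
```

Random values will not resemble real motion, so this only confirms that the file is accepted, not that the predictions are meaningful.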

## Performance Notes

- **Processing time**: ~2-5 seconds for 10-second recording (depends on hardware)
- **Window size**: 0.4 seconds (fixed, from training)
- **Hop size**: 0.05 seconds (50 ms step, so consecutive windows overlap by 0.35 seconds)
- **Audio playback**: Full recording available for playback when audio data is provided
- **Multimodal model**: Best performance but requires both audio and IMU
- **IMU-only**: Useful for privacy-preserving scenarios (no audio required)
- **Audio-only**: Strong baseline, works well in quiet environments (no IMU required)
- **File flexibility**: Upload only the required file(s) based on your selected model
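
The window and hop sizes determine how many predictions a recording produces. A rough sketch of the arithmetic, assuming 16 kHz audio and windows that must fit entirely within the recording (the notebook's `sliding_window_predict` may treat the final partial window differently); `count_windows` is a hypothetical helper:

```python
FS_AUDIO = 16_000  # Hz, the notebook's target audio sample rate
WINDOW_S = 0.4     # window length in seconds
HOP_S = 0.05       # hop (step) between window starts in seconds


def count_windows(duration_s, fs=FS_AUDIO, window_s=WINDOW_S, hop_s=HOP_S):
    """Count full windows that fit in a recording (no padding assumed)."""
    n_samples = int(round(duration_s * fs))
    win = int(round(window_s * fs))  # 6400 samples at 16 kHz
    hop = int(round(hop_s * fs))     # 800 samples at 16 kHz
    if n_samples < win:
        return 0
    return (n_samples - win) // hop + 1


# Overlap between consecutive windows, in seconds
overlap_s = (int(round(WINDOW_S * FS_AUDIO)) - int(round(HOP_S * FS_AUDIO))) / FS_AUDIO

print(count_windows(10.0), overlap_s)
```

For a 10-second recording this gives 193 windows, and consecutive windows share 0.35 s of signal.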

## Audio Playback Tips

- **Playback speed**: Use browser controls to slow down audio and identify coughs more easily
- **Loop sections**: Replay specific parts to understand false positives/negatives
- **Volume**: Adjust volume to hear quiet coughs that might be missed
- **Correlation**: Compare what you hear with red (predicted) and green (ground truth) spans on the waveform

## Limitations

1. **Dataset bias**: Models were trained on 15 subjects and may not generalize to all populations
2. **Microphone dependency**: Audio features tuned to specific hardware
3. **Fixed window**: 0.4s windows may miss very long/short coughs
4. **Threshold sensitivity**: Performance varies with threshold choice
5. **No real-time processing**: Batch processing only (not streaming)
6. **Dummy data for single-modality uploads**: When using Audio-only or IMU-only models with uploaded files, the system creates zero-filled placeholder data for the missing modality. This works for prediction, but visualizations show flat lines for the unused modality

## Next Steps

- Test on different subjects to assess generalization
- Use audio playback to understand model errors (listen to false positives/negatives)
- Experiment with threshold values for your use case
- Compare model performance across different noise conditions
- Analyze false positives/negatives to understand model weaknesses
- Try uploading single-modality files to test Audio-only or IMU-only models
- Consider deploying to an edge device for real-time monitoring