In [None]:
"""
Background Sound Event Timeline with YAMNet (local)
---------------------------------------------------
- Loads a single audio file (any common format) and converts it to 16 kHz mono
- Runs YAMNet to get frame-wise event probabilities (each frame covers ~0.96 s; frames are spaced ~0.48 s apart)
- (Optional) suppresses frames dominated by 'Speech' so you focus on background
- Groups contiguous frames of the same top event into segments
- Exports a CSV: start_sec, end_sec, label, max_prob
- (Optional) plots a simple timeline of top-K events over time

Requirements:
  pip install tensorflow tensorflow-hub librosa soundfile numpy pandas matplotlib imageio-ffmpeg
"""

from pathlib import Path
import subprocess, tempfile, math, csv
import numpy as np
import pandas as pd
import librosa
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import imageio_ffmpeg as ffmpegio

# ----------------------------
# HARD-CODED PARAMETERS
# ----------------------------
# Input recording; it is converted to a 16 kHz mono WAV by FFmpeg before inference.
AUDIO_PATH = Path("../../datasets/audio/Voice 250810_182638.m4a")  # any format (.wav, .mp3, .m4a, .flac, .ogg)
# Destination of the segment table (columns: start_sec, end_sec, label, max_prob).
OUTPUT_CSV = Path("yamnet_event_timeline.csv")

TARGET_SR = 16000         # sample rate (Hz) requested from FFmpeg and from librosa
TOP_K = 3                 # keep top-K classes per frame (for optional plotting)
EVENT_THRESHOLD = 0.35    # minimum probability to accept an event frame
SUPPRESS_SPEECH = True    # when True, frames whose 'Speech' score is >= SPEECH_THRESH have all scores zeroed ("no event")
SPEECH_THRESH = 0.50      # 'Speech' score at or above which a frame counts as speech-dominated
PLOT_TIMELINE = True      # NOTE(review): not read in the code visible here; presumably gates the optional plot - confirm

# ----------------------------
# Utility: convert any format to temp WAV 16k mono (venv FFmpeg)
# ----------------------------
# Path of the FFmpeg executable as reported by the imageio-ffmpeg package.
FFMPEG_BIN = ffmpegio.get_ffmpeg_exe()

def to_temp_wav_16k_mono(src: Path) -> Path:
    """Convert *src* to a mono WAV at ``TARGET_SR`` Hz inside a fresh temp directory.

    The conversion is delegated to the FFmpeg binary at ``FFMPEG_BIN``.

    Args:
        src: Audio file to convert.

    Returns:
        Path of the converted WAV. The file lives in its own temporary
        directory (prefix ``yamnet_``); on success the caller owns both and
        is responsible for deleting them.

    Raises:
        subprocess.CalledProcessError: FFmpeg exited with a non-zero status.
        OSError: The FFmpeg binary could not be launched.
    """
    import shutil  # local import: only needed for the failure-cleanup path

    tmpdir = Path(tempfile.mkdtemp(prefix="yamnet_"))
    dst = tmpdir / (src.stem + "_16k_mono.wav")
    cmd = [
        FFMPEG_BIN, "-y", "-loglevel", "error",
        "-i", str(src),
        "-ar", str(TARGET_SR), "-ac", "1",
        str(dst),
    ]
    try:
        # stdin=DEVNULL keeps FFmpeg from reading (and consuming) the parent
        # process's standard input while it runs.
        subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL)
    except BaseException:
        # Nothing usable was produced, so do not leave the temp directory
        # (and any partial output) behind.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise
    return dst

# ----------------------------
# Load YAMNet model and labels
# ----------------------------
# Model from TF Hub: https://tfhub.dev/google/yamnet/1
yamnet = hub.load("https://tfhub.dev/google/yamnet/1")
# Class map CSV bundled with the model; class_map_path() is the accessor the
# model documents for locating it.
labels_path = yamnet.class_map_path().numpy().decode("utf-8")
# Read labels. The class map's columns are index, mid, display_name, so the
# human-readable name must be taken from the 'display_name' column (the first
# column is just the numeric class index). Rows are in score-column order.
class_names = []
with tf.io.gfile.GFile(labels_path) as f:
    reader = csv.DictReader(f)  # consumes the header row itself
    for row in reader:
        class_names.append(row["display_name"])

# Find the 'Speech' class index (for suppression); None disables suppression
try:
    SPEECH_IDX = class_names.index("Speech")
except ValueError:
    SPEECH_IDX = None

# ----------------------------
# Load audio (16 kHz mono)
# ----------------------------
wav16_path = to_temp_wav_16k_mono(AUDIO_PATH)
try:
    wav, sr = librosa.load(wav16_path, sr=TARGET_SR, mono=True)
finally:
    # The converted WAV is only needed for decoding; once the samples are in
    # memory (or decoding failed) remove the file and its temp directory so
    # repeated runs do not accumulate copies of the audio on disk.
    try:
        wav16_path.unlink()
        wav16_path.parent.rmdir()
    except OSError:
        pass  # best-effort cleanup; never mask a decoding error
wav = wav.astype(np.float32)

# ----------------------------
# Inference
# ----------------------------
# The model is fed the mono float32 16 kHz waveform and hands back the triple
# (scores, embeddings, spectrogram).
scores, embeddings, spectrogram = yamnet(wav)

# Per-frame class scores as a NumPy array
scores_np = scores.numpy()             # shape: [num_frames, 521]
num_frames, num_classes = scores_np.shape
# Frame timing: a new frame every ~0.48 s, each spanning ~0.96 s
frame_hop_seconds = 0.48
frame_win_seconds = 0.96

# Optional speech suppression. Frames are not removed (that would break the
# frame-index -> time mapping); instead every score of a speech-dominated
# frame is multiplied by zero so the frame later reads as 'no event'.
if SUPPRESS_SPEECH and SPEECH_IDX is not None:
    mask = scores_np[:, SPEECH_IDX] < SPEECH_THRESH
    scores_np = np.multiply(scores_np, mask[:, np.newaxis].astype(np.float32))

# ----------------------------
# Per-frame winning class and frame time mapping
# ----------------------------
# Index and score of the highest-scoring class in every frame
top_class_idx = np.argmax(scores_np, axis=1)
top_class_prob = np.max(scores_np, axis=1)

# Frame k starts at k * hop and lasts one full window
frame_starts = frame_hop_seconds * np.arange(num_frames)
frame_ends = frame_win_seconds + frame_starts

# ----------------------------
# Segmenting: group contiguous frames where top event label is the same and >= threshold
# ----------------------------
def _frames_to_segments(label_ids, label_probs, starts, ends, names, threshold):
    """Merge runs of consecutive frames that share one accepted label.

    A frame is rejected ("no event") when its probability is below
    *threshold*; rejected frames end the current run and never start one.

    Args:
        label_ids: Per-frame index of the winning class.
        label_probs: Per-frame probability of that winning class.
        starts: Per-frame start time in seconds.
        ends: Per-frame end time in seconds.
        names: Sequence mapping a class index to its label.
        threshold: Minimum probability for a frame to count as an event.

    Returns:
        Chronologically ordered list of dicts with the keys ``start_sec``,
        ``end_sec``, ``label`` and ``max_prob``. A segment starts at the start
        of its first frame and ends at the end of its last frame.
    """
    def _row(first_frame, last_frame, label_id, peak):
        # Single place where a segment record is built and rounded.
        return {
            "start_sec": round(float(starts[first_frame]), 2),
            "end_sec": round(float(ends[last_frame]), 2),
            "label": names[label_id],
            "max_prob": round(float(peak), 3),
        }

    rows = []
    run_label = None   # class index of the open run; None while idle
    run_first = 0      # frame index at which the open run began
    run_peak = 0.0     # highest probability seen in the open run

    for i, (raw_id, raw_prob) in enumerate(zip(label_ids, label_probs)):
        prob = float(raw_prob)
        label = None if prob < threshold else int(raw_id)

        if label == run_label:
            # Same state as the previous frame: extend the run (if any).
            if label is not None:
                run_peak = max(run_peak, prob)
            continue

        # State changed: the open run (if any) ended on the previous frame.
        if run_label is not None:
            rows.append(_row(run_first, i - 1, run_label, run_peak))
        run_label, run_first, run_peak = label, i, prob

    # Flush a run that is still open when the audio ends.
    if run_label is not None:
        rows.append(_row(run_first, -1, run_label, run_peak))
    return rows


segments = _frames_to_segments(
    top_class_idx,
    top_class_prob,
    frame_starts,
    frame_ends,
    class_names,
    EVENT_THRESHOLD,
)

# ----------------------------
# Write the segment table to disk
# ----------------------------
# Fixing the column list keeps the header stable even when no segment was found.
df = pd.DataFrame(
    data=segments,
    columns=["start_sec", "end_sec", "label", "max_prob"],
)
df.to_csv(OUTPUT_CSV, index=False)
print(f"Saved timeline with {len(df)} segments → {OUTPUT_CSV}")
# Notebook preview of the first rows
df.head(10)
