<a href="https://colab.research.google.com/github/equiphysics/education/blob/main/analysis_templates/Standing_Music_HorseName_Year_Month_Day_Time.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## EquiPhysics Data Analysis Notebook - Horses Standing and Listening to Music

This notebook provides a comprehensive workflow for synchronizing and analyzing physiological and movement data from Polar devices with video recordings of horses. It integrates heart rate (HR), accelerometer (ACC) data, and audio/video features to offer insights into horse performance, welfare, and training.

**Workflow Overview:**
1.  **Cell 1: Data Loading & Synchronization:** Mounts Google Drive or accepts direct uploads, loads Polar HR/ACC data and video, aligns the video to the data by correlating video motion with an accelerometer movement proxy inside user-chosen sync windows, and extracts audio.
2.  **Cell 2: Feature Extraction & Export:** Computes various metrics (Instant HR, RMSSD, ACC movement, audio loudness, spectral flux, spectral centroid) and saves all processed data and metadata to Google Drive.
3.  **Cell 3: Side-by-Side Video Renderer:** Generates a video with synchronized graphs of key metrics displayed alongside the original video.
4.  **Cell 4: Portrait Reel Renderer:** Creates a vertical video suitable for social media, featuring the video (cropped to fill) and stacked graphs of metrics with a moving cursor.
5.  **Cell 5: Correlation & Causality Analysis:** Performs statistical analysis to explore relationships (correlations, time lags, and predictive causality) between music features, movement, and physiological states.
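Cell 2 is not reproduced here, but the two heart-rate metrics it lists have standard definitions. As a minimal sketch, assuming the Polar `RR` column holds beat-to-beat intervals in milliseconds (an assumption; check your export), instantaneous HR is `60000 / RR` and RMSSD is the root mean square of successive RR differences. The helper names and window sizes below are hypothetical, and no artifact or ectopic-beat correction is applied.

```python
import numpy as np


def instant_hr_bpm(rr_ms):
    """Instantaneous heart rate (beats/min) from RR intervals in milliseconds."""
    rr = np.asarray(rr_ms, dtype=float)
    return 60000.0 / rr


def rmssd_ms(rr_ms):
    """RMSSD: root mean square of successive RR-interval differences (ms).

    Non-finite or non-positive intervals are discarded before differencing;
    no other artifact correction is done (a simplifying assumption).
    """
    rr = np.asarray(rr_ms, dtype=float)
    rr = rr[np.isfinite(rr) & (rr > 0)]
    if rr.size < 2:
        return float("nan")
    d = np.diff(rr)
    return float(np.sqrt(np.mean(d * d)))


def windowed_rmssd(t_s, rr_ms, win_s=30.0, hop_s=5.0):
    """RMSSD in sliding windows of DATA time (hypothetical window/hop sizes)."""
    t = np.asarray(t_s, dtype=float)
    rr = np.asarray(rr_ms, dtype=float)
    starts = np.arange(t.min(), t.max() - win_s + 1e-9, hop_s)
    centers, values = [], []
    for s in starts:
        m = (t >= s) & (t < s + win_s)  # beats whose timestamp falls in this window
        centers.append(s + win_s / 2)
        values.append(rmssd_ms(rr[m]))
    return np.array(centers), np.array(values)


rr = [1000, 1020, 980, 1010]
print(instant_hr_bpm(rr))   # beats/min for each interval
print(rmssd_ms(rr))         # sqrt((20^2 + 40^2 + 30^2) / 3)
```

For a resting horse, a short window may contain only a handful of beats, so the window length is a trade-off between time resolution and a stable RMSSD estimate.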

## Session Notes

Use this section to record important details about the current analysis session. This information will be saved along with your data exports in a `meta.json` file, providing context for your results.

---

**Date of Session:** [e.g., 2023-10-27]

**Horse's Name:** [e.g., Spirit, Blackjack]

**Brief Session Description:** [e.g., Dressage training, trail ride, lunging session]

**Environmental Conditions:** [e.g., weather (temperature, humidity, wind) and ground surface (arena sand, grass, trail conditions)]

**Horse's Subjective State:** [e.g., horse's energy levels, behavior, or apparent comfort before, during, and after the session]

**Music Used (if any):** [e.g., Classical, Pop, None]

**Notable Events/Interruptions/Anomalies:** [e.g., Dog barked at 1:30 in video, sensor slipped between 3:00-3:15]

**Key Observations During Session:** [e.g., horse looked out of the arena as a person walked by]

**Additional Notes:**

---
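The notes above end up in `meta.json`, which is written by a later cell not shown here. As a hedged sketch of what such an export could look like, the hypothetical helper below stores the notes as a plain dictionary together with a few of the synchronization values Cell 1 keeps in `EQUIPHY`. The field names and JSON layout are assumptions, not the notebook's actual schema.

```python
import json
import tempfile
from pathlib import Path


def save_session_meta(out_dir, notes, equiphy):
    """Write session notes plus selected EQUIPHY sync values to <out_dir>/meta.json.

    The layout ({"session_notes": ..., "sync": ...}) is a hypothetical schema.
    """
    sync_keys = ("ms0", "video_to_data_offset_s", "sync_corr", "t_start_s", "t_stop_s")
    meta = {
        "session_notes": dict(notes),
        # Copy only keys Cell 1 has actually set; cast to plain floats for JSON
        "sync": {k: float(equiphy[k]) for k in sync_keys if k in equiphy},
    }
    out_path = Path(out_dir) / "meta.json"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")
    return out_path


notes = {
    "date": "2023-10-27",
    "horse_name": "Spirit",
    "music_used": "Classical",
    "notable_events": "Dog barked at 1:30 in video",
}
# Demo: write to a throwaway folder and show the result
demo = save_session_meta(tempfile.mkdtemp(), notes, {"video_to_data_offset_s": 12.5})
print(demo.read_text(encoding="utf-8"))
```

In the notebook you would pass your export folder and the real `EQUIPHY` dictionary once Cell 1 has saved the analysis window.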

## Cell 1: Data Loading, Synchronization, and Setup

**Purpose:** This cell is the starting point of the analysis. It provides an interactive interface to load your raw data files (Polar HR.txt, ACC.txt, and a video file), synchronize them, and define the overall analysis window.

**What it does:**
*   **File Input:** Allows you to either mount your Google Drive and select files or upload them directly from your computer.
*   **Data Processing:** Reads Polar HR and ACC data, calculates an accelerometer-based movement proxy (the windowed RMS of the high-pass-filtered acceleration magnitude), and determines the video duration with `ffprobe`.
*   **Automatic Synchronization:** Within user-defined video and data 'sync windows', it searches for the time offset that best aligns the video with your physiological data by maximizing the correlation between the video's frame-to-frame motion energy and the ACC movement proxy.
*   **Audio Extraction:** Extracts the full audio track from the video as a mono 44.1 kHz WAV file (`extracted_audio.wav`).
*   **Analysis Window Setup:** Sets a default analysis time window in 'data time' for subsequent processing cells.
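The synchronization idea can be checked on synthetic data. The sketch below is a simplified version of what `corr_at_offset` and `search_best_offset` in the code cell do: it shifts video times by each candidate offset (`data_time = video_time + offset`), interpolates the ACC movement proxy at those times, and keeps the offset with the highest Pearson correlation. It uses a single fine grid instead of the cell's coarse-then-fine search, and the "video motion" here is simply a noisy copy of the ACC signal, so real footage will usually give lower, broader correlation peaks.

```python
import numpy as np


def corr_at(t_acc, x_acc, t_vid, y_vid, offset):
    """Pearson r between ACC proxy and video motion for one candidate offset."""
    td = t_vid + offset                          # video time -> data time
    m = (td >= t_acc[0]) & (td <= t_acc[-1])     # keep overlapping samples only
    if m.sum() < 10:
        return np.nan
    xa = np.interp(td[m], t_acc, x_acc)          # ACC resampled at video frame times
    return float(np.corrcoef(xa, y_vid[m])[0, 1])


def best_offset(t_acc, x_acc, t_vid, y_vid, offsets):
    """Return (offset, r) with the highest correlation over a grid of offsets."""
    r = np.array([corr_at(t_acc, x_acc, t_vid, y_vid, o) for o in offsets])
    i = int(np.nanargmax(r))
    return float(offsets[i]), float(r[i])


# Synthetic ACC proxy sampled every 0.1 s (like hop_sec=0.1) for 120 s
rng = np.random.default_rng(0)
t_acc = np.arange(0.0, 120.0, 0.1)
x_acc = np.convolve(rng.standard_normal(t_acc.size), np.ones(20) / 20, mode="same")

# Synthetic "video motion" at 5 fps over a 30 s sync window, shifted by a known offset
true_offset = 37.3
t_vid = np.arange(10.0, 40.0, 0.2)
y_vid = np.interp(t_vid + true_offset, t_acc, x_acc) + 0.05 * rng.standard_normal(t_vid.size)

offsets = np.arange(20.0, 60.0 + 1e-9, 0.05)
off, r = best_offset(t_acc, x_acc, t_vid, y_vid, offsets)
print(f"recovered offset {off:.2f} s (true {true_offset} s), r = {r:.3f}")
```

Restricting `offsets` to a range around your window guess, as the cell does with its slack setting, keeps the search from locking onto a spurious peak elsewhere in a long recording.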

**How to use it:**
1.  **Choose Source:** Select whether your files are on Google Drive or if you want to upload them.
2.  **Mount/Upload:** If using Google Drive, click 'Mount Google Drive' and then 'Scan Drive Folder' after specifying the folder path. If uploading, click 'Upload HR/ACC/Video' and select your files.
3.  **Load Files:** Once HR, ACC, and Video files are selected (or auto-detected), click 'Load selected files'. This will display a plot of the ACC movement proxy.
4.  **Define Sync Windows:** Using the ACC movement proxy plot, identify a segment of clear horse movement (e.g., a canter transition) and enter the corresponding start and duration times for both the 'Video' and 'Data' streams. 'Slack (s)' widens the offset search around the offset implied by your two windows (the search always spans at least ±5 s), and 'Tweak (s)' adds a manual correction to the computed offset.
5.  **Compute Offset:** Click 'Compute offset from windows' to run the synchronization algorithm. A diagnostic plot will be displayed to help you verify the alignment.
6.  **Save Analysis Window:** Adjust the 'Analysis start (data s)' and 'Analysis stop (data s)' values to define the final time segment for your analysis, then click 'Save analysis window'. This prepares the `EQUIPHY` dictionary for the next cell.
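Three time bases meet in this cell: Polar `MS` timestamps, DATA time in seconds since `ms0`, and video time. A minimal sketch of the conversions, using the conventions the cell itself prints (`t_s = (MS - EQUIPHY['ms0'])/1000.0` and `data_time = video_time + offset`). The function names are hypothetical, and treating the analysis window as half-open `[start, stop)` is an assumption.

```python
import numpy as np


def ms_to_data_s(ms, ms0):
    """Polar MS timestamps -> DATA time in seconds since ms0."""
    return (np.asarray(ms, dtype=float) - ms0) / 1000.0


def video_to_data_s(t_video_s, offset_s):
    """Video time -> DATA time, using data_time = video_time + offset."""
    return np.asarray(t_video_s, dtype=float) + offset_s


def data_to_video_s(t_data_s, offset_s):
    """DATA time -> video time (inverse of video_to_data_s)."""
    return np.asarray(t_data_s, dtype=float) - offset_s


def in_window(t_data_s, t_start_s, t_stop_s):
    """Mask for samples inside [t_start_s, t_stop_s) (half-open: an assumption)."""
    t = np.asarray(t_data_s, dtype=float)
    return (t >= t_start_s) & (t < t_stop_s)


# Illustrative numbers; in the notebook these would come from EQUIPHY
ms0, offset = 5000.0, 12.5
t = ms_to_data_s([5000, 5500, 6000], ms0)      # 0.0, 0.5, 1.0 s
print(t, video_to_data_s(30.0, offset), in_window(t, 0.5, 1.0))
```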

# @title
# ==========================================================
# Cell 1 (COLAB WIDGETS): Drive or Upload + USER-SEGMENT AUTO-SYNC (NO input())
#
# This cell provides a user interface to:
# - Mount Google Drive (optional)
# - Choose to load files from Google Drive OR upload them from your computer.
# - Read Polar HR.txt and ACC.txt data, and load a video file.
# - Compute and plot an ACC movement proxy against data time.
# - Let the user define a video sync window [v0, v1] and a matching data sync window [d0, d1].
# - Automatically search for the offset that best satisfies data_time = video_time + offset,
#   constraining the search to the range implied by the sync windows (plus slack).
# - Display a diagnostic overlay plot after alignment for verification.
# - Extract the full audio track from the video, saving it as `extracted_audio.wav`.
# - Set a default analysis window in DATA time for subsequent cells.
# ==========================================================

import os, sys, io, json, subprocess, shutil
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --- Colab widgets setup ---
# Check if running in Google Colab environment
IN_COLAB = "google.colab" in sys.modules
if not IN_COLAB:
    raise RuntimeError("This cell is written for Google Colab (widgets + Drive mount).")

from IPython.display import display, clear_output
from google.colab import output
output.enable_custom_widget_manager()

import ipywidgets as widgets

# Global dictionary to store state and processed data across cells
EQUIPHY = {}

# Supported video file extensions
VIDEO_EXTS = (".mp4", ".mov", ".m4v", ".avi", ".webm", ".mkv")

# ---------------- Helper Functions ----------------

def _run(cmd, quiet=True):
    """Helper to run shell commands and capture output."""
    if quiet:
        return subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8", errors="replace")
    return subprocess.check_output(cmd).decode("utf-8", errors="replace")

def auto_detect_names(names):
    """Automatically detects HR, ACC, and video files from a list of filenames."""
    names = list(names)
    # Pick the first .txt file whose name contains 'hr' (HR) or 'acc' (ACC)
    hr  = next((f for f in names if f.lower().endswith(".txt") and "hr"  in f.lower()), None)
    acc = next((f for f in names if f.lower().endswith(".txt") and "acc" in f.lower()), None)
    # Detect video files based on common extensions
    vid = next((f for f in names if f.lower().endswith(VIDEO_EXTS)), None)
    return hr, acc, vid

def read_polar_txt_bytes(file_bytes: bytes) -> pd.DataFrame:
    """Reads Polar HR/ACC data from byte content of a .txt file."""
    raw = file_bytes.replace(b"\x00", b"") # Remove null bytes if present
    lines = raw.decode("utf-8", errors="replace").splitlines()

    # Find the start of data rows (after header/comments)
    data_start = None
    for i, line in enumerate(lines):
        if line.startswith("#") or line.strip() == "":
            continue
        data_start = i
        break
    if data_start is None:
        raise ValueError("Could not find data rows (file looks empty or only comments).")

    # Read data into DataFrame, convert relevant columns to numeric
    df = pd.read_csv(io.StringIO("\n".join(lines[data_start:])), skipinitialspace=True)
    df.columns = [c.strip() for c in df.columns] # Clean column names
    for col in ["MS","RR","SC","ACCX","ACCY","ACCZ","HR"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")
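    if "MS" not in df.columns:
        raise ValueError(f"No 'MS' timestamp column found; columns are: {list(df.columns)}")
    df = df.dropna(subset=["MS"]).reset_index(drop=True) # Drop rows with missing MS values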
    return df

def ffprobe_duration(path: str):
    """Uses ffprobe to get video duration in seconds."""
    cmd = ["ffprobe","-v","error","-show_entries","format=duration","-of","json", path]
    js = json.loads(_run(cmd))
    try:
        return float(js.get("format", {}).get("duration", None))
    except Exception:
        return None

def ffprobe_wh(path: str):
    """Uses ffprobe to get video width and height."""
    cmd = ["ffprobe","-v","error","-select_streams","v:0","-show_entries","stream=width,height","-of","json", path]
    js = json.loads(_run(cmd))
    st = (js.get("streams") or [{}])[0]
    w = int(st.get("width", 0) or 0)
    h = int(st.get("height", 0) or 0)
    return w, h

def extract_audio_wav(video_path: str, out_wav: str = "extracted_audio.wav", sr=44100):
    """Extracts audio from a video file using ffmpeg and saves it as a WAV."""
    cmd = ["ffmpeg","-y","-i",video_path,"-vn","-ac","1","-ar",str(sr),"-c:a","pcm_s16le",out_wav]
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8", errors="replace")
        return out_wav
    except subprocess.CalledProcessError as e:
        print("\n⚠️ ffmpeg audio extraction failed.")
        print(e.output.decode("utf-8", errors="replace")[:1200])
        return None

# ---------- ACC Movement Index Calculation ----------
from scipy.signal import butter, filtfilt
def acc_movement_index(df_acc: pd.DataFrame, ms0: float,
                       hp_hz=0.3, win_sec=0.5, hop_sec=0.1):
    """Calculates a movement index from accelerometer data.
    It high-pass filters each axis, takes the vector magnitude, and computes its RMS
    over sliding windows. Returns (window-center times in DATA s, RMS values in g, fs in Hz).
    """
    req = {"ACCX","ACCY","ACCZ","MS"}
    missing = req - set(df_acc.columns)
    if missing:
        raise ValueError(f"ACC missing columns: {sorted(missing)}")

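    # Drop rows with non-numeric ACC values; NaNs would otherwise propagate through filtfilt
    df_acc = df_acc.dropna(subset=["ACCX","ACCY","ACCZ"])

    # Convert MS to seconds relative to a common 'ms0' start time
    t = (df_acc["MS"].to_numpy(float) - ms0) / 1000.0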
    ax = df_acc["ACCX"].to_numpy(float) / 1000.0 # Convert to g
    ay = df_acc["ACCY"].to_numpy(float) / 1000.0
    az = df_acc["ACCZ"].to_numpy(float) / 1000.0

    # Estimate sampling frequency from MS timestamps
    dms = np.diff(df_acc["MS"].to_numpy(float))
    dms = dms[np.isfinite(dms) & (dms > 0)]
    fs = 100.0 if len(dms) == 0 else 1000.0 / np.median(dms)

    # Apply a high-pass filter to remove gravity and slow movements
    b, a = butter(2, hp_hz/(fs/2), btype="highpass")
    axf = filtfilt(b, a, ax)
    ayf = filtfilt(b, a, ay)
    azf = filtfilt(b, a, az)
    mag = np.sqrt(axf*axf + ayf*ayf + azf*azf) # Magnitude of filtered acceleration

    # Compute RMS of magnitude in sliding windows
    win = max(1, int(fs*win_sec))
    hop = max(1, int(fs*hop_sec))
    n = len(mag)
    starts = np.arange(0, n-win+1, hop, dtype=int)
    if len(starts) < 3:
        return np.array([]), np.array([]), fs

    idx = np.array([np.sqrt(np.mean(mag[s:s+win]**2)) for s in starts], dtype=float) # RMS
    tt  = t[starts + win//2] # Time for the center of each window
    return tt, idx, fs

# ---------- Video Motion Energy via FFmpeg ----------
def video_motion_energy_ffmpeg(video_path: str, t_start=0.0, t_end=None, fps=5, resize_w=160):
    """Extracts video motion energy using ffmpeg.
    It extracts frames, converts to grayscale, and computes mean absolute frame difference.
    """
    t_start = float(max(0.0, t_start))
    dur_total = ffprobe_duration(video_path)
    if dur_total is not None:
        if t_end is None:
            t_end = dur_total
        t_end = float(min(t_end, dur_total))
    else:
        if t_end is None:
            raise RuntimeError("Could not determine video duration; please provide t_end.")
        t_end = float(t_end)

    if t_end <= t_start + (2.0/fps):
        return np.array([]), np.array([])

    # Compute scaled height preserving aspect ratio for consistent processing
    W0, H0 = ffprobe_wh(video_path)
    if W0 <= 0 or H0 <= 0:
        # Fallback if video dimensions cannot be determined
        Hs = 90
    else:
        Hs = int(round(H0 * (resize_w / W0)))
        Hs = max(2, 2*(Hs//2))  # Ensure even height for ffmpeg compatibility

    duration = t_end - t_start

    # FFmpeg command to extract grayscale frames at a specific FPS
    cmd = [
        "ffmpeg","-v","error",
        "-ss", f"{t_start:.6f}",
        "-t",  f"{duration:.6f}",
        "-i",  video_path,
        "-vf", f"fps={fps},scale={resize_w}:{Hs},format=gray",
        "-f","rawvideo","-pix_fmt","gray","pipe:1" # Output raw grayscale to stdout
    ]
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # communicate() drains stdout and stderr together, so a full stderr pipe cannot block ffmpeg
    raw, err_bytes = p.communicate()
    err = err_bytes.decode("utf-8", errors="replace")
    rc = p.returncode

    if rc != 0 or len(raw) == 0:
        raise RuntimeError("ffmpeg frame extraction failed.\n" + err[:1200])

    frame_size = resize_w * Hs  # Size of one grayscale frame (bytes)
    n_frames = len(raw) // frame_size
    if n_frames < 3:
        return np.array([]), np.array([])

    buf = np.frombuffer(raw[:n_frames*frame_size], dtype=np.uint8)
    frames = buf.reshape((n_frames, Hs, resize_w)).astype(np.float32)

    # Compute mean absolute frame difference as motion energy
    diffs = np.mean(np.abs(frames[1:] - frames[:-1]), axis=(1,2))
    times = t_start + (np.arange(1, n_frames) / float(fps))
    return times.astype(float), diffs.astype(float)

def zscore(x):
    """Standardizes data to have a mean of 0 and std dev of 1."""
    x = np.asarray(x, float)
    m = np.isfinite(x)
    if m.sum() < 5:
        return x
    mu = np.nanmean(x[m])
    sd = np.nanstd(x[m]) + 1e-9 # Add small epsilon to prevent division by zero
    return (x - mu) / sd

def corr_at_offset(t_acc, x_acc, t_vid, y_vid, offset):
    """Calculates correlation between ACC movement and video motion at a given offset.
    data_time = video_time + offset
    """
    # Map video times to data times using the offset
    td = t_vid + offset
    # Select data points where video time overlaps with ACC data range
    m = (td >= np.nanmin(t_acc)) & (td <= np.nanmax(t_acc)) & np.isfinite(td)
    if m.sum() < 10:
        return np.nan
    # Interpolate ACC movement onto the shifted video times
    xa = np.interp(td[m], t_acc, x_acc)
    ya = y_vid[m]
    # Z-score normalize for correlation calculation
    xa = zscore(xa); ya = zscore(ya)
    mm = np.isfinite(xa) & np.isfinite(ya)
    if mm.sum() < 10:
        return np.nan
    return float(np.corrcoef(xa[mm], ya[mm])[0,1])

def search_best_offset(t_acc, x_acc, t_vid, y_vid,
                       offset_min, offset_max,
                       coarse_step=0.5, fine_step=0.05, fine_window=3.0):
    """Searches for the best offset (highest correlation) between two time series.
    Performs a coarse search, then a fine search around the best coarse result.
    """
    # Coarse search
    offsets = np.arange(offset_min, offset_max + 1e-9, coarse_step, dtype=float)
    best = (None, -np.inf)
    for off in offsets:
        r = corr_at_offset(t_acc, x_acc, t_vid, y_vid, off)
        if np.isfinite(r) and r > best[1]:
            best = (off, r)

    if best[0] is None:
        return None, np.nan

    # Fine search around the best coarse offset
    c0 = best[0]
    offsets2 = np.arange(c0 - fine_window, c0 + fine_window + 1e-9, fine_step, dtype=float)
    best2 = (c0, best[1])
    for off in offsets2:
        if off < offset_min or off > offset_max:
            continue
        r = corr_at_offset(t_acc, x_acc, t_vid, y_vid, off)
        if np.isfinite(r) and r > best2[1]:
            best2 = (off, r)

    return float(best2[0]), float(best2[1])

# ---------------- UI Widget Definitions ----------------
ui_out = widgets.Output() # Widget for displaying output messages and plots

# Radio buttons to choose source (Google Drive or Upload)
mode = widgets.RadioButtons(
    options=[("Google Drive", "drive"), ("Upload from computer", "upload")],
    value="drive",
    description="Source:"
)

# Google Drive related buttons and input
mount_btn = widgets.Button(description="Mount Google Drive", button_style="info")
scan_btn  = widgets.Button(description="Scan Drive Folder", button_style="primary")

drive_folder = widgets.Text(
    value="/content/drive/MyDrive/",
    description="Drive folder:",
    layout=widgets.Layout(width="700px")
)
# Option to copy video locally for faster processing
copy_local = widgets.Checkbox(value=True, description="Copy video to /content for speed")

# Dropdowns for selecting HR, ACC, and Video files (shared for both modes)
hr_dd  = widgets.Dropdown(options=[], description="HR file:", layout=widgets.Layout(width="700px"))
acc_dd = widgets.Dropdown(options=[], description="ACC file:", layout=widgets.Layout(width="700px"))
vid_dd = widgets.Dropdown(options=[], description="Video file:", layout=widgets.Layout(width="700px"))

# Upload button for computer upload mode
upload_btn = widgets.Button(description="Upload HR/ACC/Video", button_style="primary")

# Button to load selected files
load_btn  = widgets.Button(description="Load selected files", button_style="success")

# Sync parameter widgets
sync_box = widgets.VBox([]) # Container for sync-related input widgets
sync_btn = widgets.Button(description="Compute offset from windows", button_style="warning")
win_box = widgets.VBox([]) # Container for analysis window input widgets
savewin_btn = widgets.Button(description="Save analysis window", button_style="success")

# Numeric input fields for sync windows
v0_w = widgets.FloatText(value=0.0, description="Video start (s):")
vd_w = widgets.FloatText(value=20.0, description="Video dur (s):")
d0_w = widgets.FloatText(value=0.0, description="Data start (s):")
dd_w = widgets.FloatText(value=20.0, description="Data dur (s):")
slack_w = widgets.FloatText(value=10.0, description="Slack (s):") # Search range for offset
tweak_w = widgets.FloatText(value=0.0, description="Tweak (s):") # Manual adjustment to offset

# Analysis window input fields
t0_w = widgets.FloatText(value=0.0, description="Analysis start (data s):")
t1_w = widgets.FloatText(value=300.0, description="Analysis stop (data s):")

# ---------------- UI Action Callbacks ----------------

def mount_drive(_):
    """Callback to mount Google Drive."""
    with ui_out:
        print("Mounting Google Drive...")
    from google.colab import drive
    drive.mount("/content/drive", force_remount=False)
    with ui_out:
        print("Drive mounted at /content/drive")

def scan_drive(_):
    """Callback to scan the selected Google Drive folder for HR, ACC, and video files."""
    folder = drive_folder.value.strip()
    p = Path(folder)
    with ui_out:
        clear_output()
        if not p.exists():
            print("Folder not found:", folder)
            return

        # Scan for likely candidate files based on naming conventions and extensions
        names = [str(x) for x in sorted(p.glob("*")) if x.is_file()]
        hr_cands  = [n for n in names if n.lower().endswith(".txt") and "hr" in Path(n).name.lower()]
        acc_cands = [n for n in names if n.lower().endswith(".txt") and "acc" in Path(n).name.lower()]
        vid_cands = [n for n in names if Path(n).suffix.lower() in VIDEO_EXTS]

        # If strict matching fails, offer broader txt list for HR/ACC
        if not hr_cands:
            hr_cands = [n for n in names if n.lower().endswith(".txt")]
        if not acc_cands:
            acc_cands = [n for n in names if n.lower().endswith(".txt")]

        # Update dropdown options
        hr_dd.options  = hr_cands
        acc_dd.options = acc_cands
        vid_dd.options = vid_cands

        print(f"Scanned: {folder}")
        print(f"  HR candidates:  {len(hr_cands)}")
        print(f"  ACC candidates: {len(acc_cands)}")
        print(f"  Video candidates: {len(vid_cands)}")
        if len(hr_cands)==0 or len(acc_cands)==0 or len(vid_cands)==0:
            print("If lists are empty, point 'Drive folder' to the directory containing the files.")

def do_upload(_):
    """Callback for uploading files from the local computer."""
    with ui_out:
        clear_output()
        print("Upload dialog should appear below. Select HR*.txt, ACC*.txt, and the video file.")
    from google.colab import files
    up = files.upload()
    if not up:
        with ui_out:
            print("No files uploaded.")
        return

    names = list(up.keys())
    hr_name, acc_name, vid_name = auto_detect_names(names) # Try to auto-detect files

    # Fallback if auto-detection is not perfect, populate dropdowns for manual selection
    if hr_name is None or acc_name is None or vid_name is None:
        txts = [n for n in names if n.lower().endswith(".txt")]
        vids = [n for n in names if n.lower().endswith(VIDEO_EXTS)]
        hr_dd.options  = txts
        acc_dd.options = txts
        vid_dd.options = vids
        with ui_out:
            print("Auto-detect failed. Pick files in the dropdowns, then click 'Load selected files'.")
        EQUIPHY["_uploaded_dict"] = up # Store uploaded bytes temporarily
        return

    # If auto-detection succeeded, pre-select in dropdowns
    EQUIPHY["_uploaded_dict"] = up
    hr_dd.options  = [hr_name]
    acc_dd.options = [acc_name]
    vid_dd.options = [vid_name]
    hr_dd.value = hr_name
    acc_dd.value = acc_name
    vid_dd.value = vid_name

    with ui_out:
        print("Uploaded and auto-detected:")
        print("  HR :", hr_name)
        print("  ACC:", acc_name)
        print("  VID:", vid_name)
        print("Now click 'Load selected files'.")

def load_files(_):
    """Callback to load the selected HR, ACC, and video files, process them, and plot ACC movement."""
    with ui_out:
        clear_output()
        print("Loading selected files...")

    src = mode.value

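    if src == "drive":
        if not (hr_dd.value and acc_dd.value and vid_dd.value):
            with ui_out:
                print("Select HR, ACC, and video files first (click 'Scan Drive Folder').")
            return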
        hr_path = Path(hr_dd.value)
        acc_path = Path(acc_dd.value)
        vid_path = Path(vid_dd.value)

        if not hr_path.exists() or not acc_path.exists() or not vid_path.exists():
            with ui_out:
                print("One or more selected paths do not exist.")
            return

        EQUIPHY["hr_name"] = hr_path.name
        EQUIPHY["acc_name"] = acc_path.name
        EQUIPHY["video_name"] = vid_path.name

        EQUIPHY["hr_bytes"]  = hr_path.read_bytes()
        EQUIPHY["acc_bytes"] = acc_path.read_bytes()

        # Copy video to /content for potentially faster access in Colab
        if copy_local.value:
            local_vid = Path("/content") / vid_path.name
            if not local_vid.exists():
                shutil.copy2(str(vid_path), str(local_vid))
            EQUIPHY["video_path"] = str(local_vid)
            EQUIPHY["video_path_original"] = str(vid_path)
        else:
            EQUIPHY["video_path"] = str(vid_path)
            EQUIPHY["video_path_original"] = str(vid_path)

    else:  # Upload mode
        up = EQUIPHY.get("_uploaded_dict", None)
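        if up is None:
            with ui_out:
                print("No uploaded files found. Click 'Upload HR/ACC/Video' first.")
            return
        if not (hr_dd.value and acc_dd.value and vid_dd.value):
            with ui_out:
                print("Pick HR, ACC, and video files in the dropdowns first.")
            return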

        EQUIPHY["hr_name"] = Path(hr_dd.value).name
        EQUIPHY["acc_name"] = Path(acc_dd.value).name
        EQUIPHY["video_name"] = Path(vid_dd.value).name

        EQUIPHY["hr_bytes"]  = up[hr_dd.value]
        EQUIPHY["acc_bytes"] = up[acc_dd.value]
        vid_bytes = up[vid_dd.value]

        # Write uploaded video bytes to a local file for ffmpeg/ffprobe
        local_vid = Path("/content") / Path(vid_dd.value).name
        with open(local_vid, "wb") as f:
            f.write(vid_bytes)
        EQUIPHY["video_path"] = str(local_vid)
        EQUIPHY["video_path_original"] = None # Original path not relevant for uploaded files

    # Parse HR and ACC dataframes
    df_hr  = read_polar_txt_bytes(EQUIPHY["hr_bytes"])
    df_acc = read_polar_txt_bytes(EQUIPHY["acc_bytes"])

    # Determine common starting MS timestamp (ms0) and total data duration
    ms0 = float(np.nanmin([df_hr["MS"].min(), df_acc["MS"].min()]))
    data_end = float(np.nanmax([df_hr["MS"].max(), df_acc["MS"].max()]) - ms0) / 1000.0

    EQUIPHY["ms0"] = ms0
    EQUIPHY["data_duration_s"] = data_end

    # Get video duration using ffprobe
    vid_dur = ffprobe_duration(EQUIPHY["video_path"])
    EQUIPHY["video_duration_s"] = vid_dur

    # Compute ACC movement proxy
    t_acc, mov_acc, fs_acc = acc_movement_index(df_acc, ms0, hp_hz=0.3, win_sec=0.5, hop_sec=0.1)
    if len(t_acc) < 20:
        raise RuntimeError("ACC movement series too short to sync. Check ACC columns/data.")

    # Store important metrics in EQUIPHY
    EQUIPHY["_df_hr_rows"] = len(df_hr)
    EQUIPHY["_df_acc_rows"] = len(df_acc)
    EQUIPHY["_t_acc"] = t_acc
    EQUIPHY["_mov_acc"] = mov_acc
    EQUIPHY["_fs_acc"] = fs_acc

    # Plot ACC movement proxy for user to estimate sync window
    with ui_out:
        clear_output()
        print("Loaded ✅")
        print("  HR   :", EQUIPHY["hr_name"], f"(rows={len(df_hr)})")
        print("  ACC  :", EQUIPHY["acc_name"], f"(rows={len(df_acc)}, fs≈{fs_acc:.1f} Hz)")
        print("  VIDEO:", EQUIPHY["video_name"])
        print("  Data duration (s):", f"{data_end:.2f}")
        print("  Video duration (s):", "unknown" if vid_dur is None else f"{vid_dur:.2f}")
        print("\nMovement proxy plot (use this to estimate a matching DATA window):")

        plt.figure(figsize=(12,4))
        plt.plot(t_acc, mov_acc, linewidth=1.0)
        plt.xlabel("DATA time (s since Polar start)")
        plt.ylabel("Movement proxy (g RMS, high-pass)")
        plt.title("ACC movement proxy vs time")
        plt.grid(True, alpha=0.3)
        plt.show()

    # Set default values for sync window inputs
    v0_w.value = 0.0
    vd_w.value = 20.0
    d0_w.value = 0.0
    dd_w.value = 20.0

    # Display sync window input widgets
    sync_box.children = [
        widgets.HTML("<h3>Step 2: Enter sync windows</h3>"
                     "<p><b>VIDEO</b>: pick a segment with clear horse motion + minimal camera motion.<br>"
                     "<b>DATA</b>: use the movement proxy plot to guess the matching time window.</p>"
                     "<p><b>Slack (s)</b>: how many seconds beyond the offset implied by your two windows the automatic search may look (it always covers at least ±5 s). Increase it if your DATA-window guess may be further off.</p>"
                     "<p><b>Tweak (s)</b>: A manual adjustment to the automatically computed offset. Use this if you visually inspect the result and feel a small correction is needed.</p>"),
        widgets.HBox([v0_w, vd_w]),
        widgets.HBox([d0_w, dd_w]),
        widgets.HBox([slack_w, tweak_w]),
        sync_btn
    ]

def compute_offset(_):
    """Callback to compute the optimal time offset between video and data based on user-defined windows."""
    with ui_out:
        print("\nComputing offset...")

    video_path = EQUIPHY["video_path"]
    t_acc = EQUIPHY["_t_acc"]
    mov_acc = EQUIPHY["_mov_acc"]
    data_end = float(EQUIPHY["data_duration_s"])
    vid_dur = EQUIPHY.get("video_duration_s", None)

    # Get user-defined sync window values
    v_sync_start = float(v0_w.value)
    v_sync_end   = float(v0_w.value + vd_w.value)

    # Clamp video sync window to video duration if known
    if vid_dur is not None:
        v_sync_start = max(0.0, min(v_sync_start, vid_dur-0.1))
        v_sync_end   = max(v_sync_start+0.1, min(v_sync_end, vid_dur))

    d_sync_start = float(d0_w.value)
    d_sync_end   = float(d0_w.value + dd_w.value)
    # Clamp data sync window to data duration
    d_sync_start = max(0.0, min(d_sync_start, data_end-0.1))
    d_sync_end   = max(d_sync_start+0.1, min(d_sync_end, data_end))

    # Extract video motion energy for the specified video sync window
    t_vid, mot = video_motion_energy_ffmpeg(
        video_path, t_start=v_sync_start, t_end=v_sync_end, fps=5, resize_w=160
    )
    if len(t_vid) < 20:
        raise RuntimeError("Video motion series too short. Increase duration or pick a different segment.")
    mot_z = zscore(mot) # Z-score normalize video motion

    # Define the search range for the offset based on the sync windows and slack
    center_off = 0.5*((d_sync_start - v_sync_start) + (d_sync_end - v_sync_end))
    span_off   = max(abs((d_sync_start - v_sync_start) - center_off),
                     abs((d_sync_end   - v_sync_end) - center_off))
    slack = float(slack_w.value)
    guess = center_off
    rng   = max(5.0, span_off + slack)

    # Feasibility bounds for the offset (ensuring data overlap)
    offset_min_feas = 0.0 - float(np.max(t_vid))
    offset_max_feas = float(data_end) - float(np.min(t_vid))
    o_min = max(offset_min_feas, guess - rng)
    o_max = min(offset_max_feas, guess + rng)

    # Search for the best offset
    best_off, best_r = search_best_offset(
        t_acc, mov_acc, t_vid, mot_z,
        offset_min=o_min, offset_max=o_max,
        coarse_step=0.5, fine_step=0.05, fine_window=3.0
    )
    if best_off is None or not np.isfinite(best_r):
        raise RuntimeError("Could not find a stable offset. Try a different sync segment or widen slack.")

    # Apply manual tweak if provided
    best_off = float(best_off + float(tweak_w.value))

    # Store computed offset and sync details in EQUIPHY
    EQUIPHY["video_to_data_offset_s"] = best_off
    EQUIPHY["audio_offset_s"] = best_off # Audio offset is same as video
    EQUIPHY["sync_video_segment"] = (float(v_sync_start), float(v_sync_end))
    EQUIPHY["sync_data_segment_guess"] = (float(d_sync_start), float(d_sync_end))
    EQUIPHY["sync_corr"] = float(best_r)

    # Diagnostic overlay plot to visualize alignment
    td = t_vid + best_off
    m = (td >= np.nanmin(t_acc)) & (td <= np.nanmax(t_acc))
    acc_on_vid = np.interp(td[m], t_acc, mov_acc)

    with ui_out:
        print(f"\n✅ Best offset: {best_off:+.3f} s  (corr={best_r:.3f})")
        print("Offset convention: data_time = video_time + offset")
        print("\nDiagnostic overlay (z-scored):")

        plt.figure(figsize=(12,4))
        plt.plot(td[m], zscore(acc_on_vid), label="ACC movement (z)", linewidth=1.2)
        plt.plot(td[m], zscore(mot_z[m]),   label="Video motion (z)", linewidth=1.2)
        plt.axvspan(d_sync_start, d_sync_end, alpha=0.15, label="your DATA-window guess")
        plt.xlabel("DATA time (s)")
        plt.title(f"Overlay after alignment (corr={best_r:.3f})")
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.show()

    # Extract the FULL audio track (whole video, not just the sync window)
    with ui_out:
        print("\nExtracting full audio -> extracted_audio.wav ...")
    wav_path = extract_audio_wav(video_path, out_wav="extracted_audio.wav", sr=44100)
    if wav_path is None or (not os.path.exists(wav_path)):
        with ui_out:
            print("⚠️ No extracted audio WAV produced.")
        EQUIPHY["aud_name"] = None
        EQUIPHY["aud_path"] = None
        EQUIPHY["aud_bytes"] = None
        EQUIPHY["aud_sr"] = None
    else:
        EQUIPHY["aud_name"] = Path(wav_path).name
        EQUIPHY["aud_path"] = str(Path(wav_path).resolve())
        EQUIPHY["aud_bytes"] = Path(wav_path).read_bytes()
        EQUIPHY["aud_sr"] = 44100
        with ui_out:
            print("✅ Extracted audio:", EQUIPHY["aud_name"], f"(sr={EQUIPHY['aud_sr']})")

    # Set default analysis window based on computed offset and durations
    if vid_dur is None:
        t_start_default = max(0.0, best_off)
        t_stop_default  = min(data_end, t_start_default + 300.0)
    else:
        t_start_default = float(np.clip(best_off, 0.0, data_end))
        t_stop_default  = min(data_end, t_start_default + 300.0, best_off + vid_dur)

    t0_w.value = float(t_start_default)
    t1_w.value = float(t_stop_default)

    # Display analysis window input widgets
    win_box.children = [
        widgets.HTML("<h3>Step 3: Choose analysis window (DATA time)</h3>"),
        widgets.HBox([t0_w, t1_w]),
        savewin_btn
    ]

def save_window(_):
    """Callback to save the final analysis window and mark EQUIPHY as ready."""
    t0 = float(t0_w.value)
    t1 = float(t1_w.value)
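    if t1 <= t0:
        # Exceptions raised inside widget callbacks are easy to miss, so report the problem in the UI
        with ui_out:
            print("⚠️ Stop time must be greater than start time; analysis window not saved.")
        return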
    EQUIPHY["t_start_s"] = t0
    EQUIPHY["t_stop_s"] = t1
    EQUIPHY["ready"] = True # Indicate that EQUIPHY is prepared for downstream cells
    with ui_out:
        print("\nSaved ✅ EQUIPHY is ready for Cell 2")
        print(f"  video_to_data_offset_s: {EQUIPHY.get('video_to_data_offset_s', np.nan):+.3f}")
        print(f"  Analysis window (data s): [{t0:.2f}, {t1:.2f}]")
        print("\nNext: Run Cell 2 (timestamp-free / ms0-based) and compute series.")
        print("Reminder: in Cell 2, use t_s = (MS - EQUIPHY['ms0'])/1000.0")

# --- Wire Buttons to Callbacks ---
mount_btn.on_click(mount_drive)
scan_btn.on_click(scan_drive)
upload_btn.on_click(do_upload)
load_btn.on_click(load_files)
sync_btn.on_click(compute_offset)
savewin_btn.on_click(save_window)

# --- UI elements specific to Google Drive mode ---
drive_controls_ui = widgets.VBox([
    widgets.HBox([mount_btn]),
    drive_folder,
    widgets.HBox([scan_btn, copy_local]),
])

# --- UI elements specific to Upload from computer mode ---
upload_controls_ui = widgets.VBox([
    upload_btn,
])

source_box = widgets.VBox([mode])

def _toggle_ui(*args):
    """Toggles the visibility of Drive UI vs. Upload UI based on selected mode."""
    if mode.value == "drive":
        drive_controls_ui.layout.display = "flex" # Show Drive UI
        upload_controls_ui.layout.display = "none" # Hide Upload UI
    else: # mode.value == "upload"
        drive_controls_ui.layout.display = "none" # Hide Drive UI
        upload_controls_ui.layout.display = "flex" # Show Upload UI

# Set initial visibility and observe changes to the mode radio buttons
_toggle_ui() # Apply initial state
mode.observe(_toggle_ui, names="value") # Update on radio button change

# Define the header widget
header = widgets.HTML("<h2>EquiPhysics Data Sync</h2>")

# Main UI composition
ui = widgets.VBox([
    header,
    source_box,
    widgets.HTML("<hr>"),
    widgets.HTML("<h3>Choose files</h3>"),
    # Mode-specific controls (visibility toggled)
    drive_controls_ui,
    upload_controls_ui,
    # Shared file selection dropdowns
    hr_dd, acc_dd, vid_dd,
    widgets.HBox([load_btn]),
    widgets.HTML("<hr>"),
    sync_box,
    widgets.HTML("<hr>"),
    win_box,
    widgets.HTML("<hr>"),
    ui_out
])

display(ui)


## Alignment Notes

Use this section to record your observations and assessments of the video and data alignment. This context will be crucial for interpreting subsequent analysis results.

---

**Overall Assessment of Alignment:** [e.g., Good, acceptable, fair, poor]

**Visual Inspection Feedback:** [e.g., Did the movement in the video visually match the peaks in the ACC movement proxy? Was the synchronization diagnostic plot clear?]

**Any Concerns or Anomalies:** [e.g., Were there any periods where the alignment seemed off? Did the 'slack' parameter need to be adjusted significantly?]

**Suggestions for Re-alignment (if any):** [e.g., Should a different sync window be chosen for the video/data?]

---
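For reference when filling in these notes: Cell 1 reports the offset that maximizes the correlation between the ACC movement proxy and the video motion signal (`corr` in the diagnostic plot), using the convention `data_time = video_time + offset`. The sketch below illustrates that idea as a brute-force search over a fixed grid of candidate offsets on synthetic signals. It is a simplified illustration, not Cell 1's actual search: the function `best_offset`, the candidate grid, the minimum-overlap rule, and the synthetic data are assumptions, and Cell 1's own procedure (including its 'slack' setting) may differ.

```python
import numpy as np

def zscore(x):
    """Standardize to zero mean and unit variance (NaN-safe)."""
    x = np.asarray(x, float)
    return (x - np.nanmean(x)) / (np.nanstd(x) + 1e-12)

def best_offset(t_vid, mot_vid, t_acc, mov_acc, offsets, min_overlap=10):
    """Hypothetical brute-force lag search.

    For each candidate offset (data_time = video_time + offset), interpolate the
    ACC movement proxy onto the shifted video timestamps and compute the Pearson
    correlation with the video motion signal. Returns (best_offset, best_r).
    """
    best_off, best_r = np.nan, -np.inf
    for off in offsets:
        td = t_vid + off
        m = (td >= t_acc[0]) & (td <= t_acc[-1])  # keep only overlapping samples
        if m.sum() < min_overlap:                 # too little overlap to trust r
            continue
        acc_on_vid = np.interp(td[m], t_acc, mov_acc)
        r = np.corrcoef(zscore(acc_on_vid), zscore(mot_vid[m]))[0, 1]
        if np.isfinite(r) and r > best_r:
            best_off, best_r = float(off), float(r)
    return best_off, best_r

# Synthetic check: video time v corresponds to data time v + 12.5 s
rng = np.random.default_rng(0)
t_acc = np.arange(0, 300, 0.2)             # data time (s)
mov_acc = rng.random(t_acc.size)           # stand-in ACC movement proxy
true_off = 12.5
t_vid = np.arange(0, 120, 0.2)             # video time (s)
mot_vid = np.interp(t_vid + true_off, t_acc, mov_acc)

off, r = best_offset(t_vid, mot_vid, t_acc, mov_acc, np.arange(-20, 60, 0.1))
print(f"offset={off:+.2f} s, r={r:.3f}")
```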

## Cell 2: Feature Extraction & Export

**Purpose:** This cell processes the loaded HR, ACC, and audio data to compute various physiological and movement metrics. It then prepares these metrics for export, along with metadata and a diagnostic plot.

**What it does:**
*   **Reads Data:** Accesses the raw HR and ACC data loaded in Cell 1, as well as any extracted audio.
*   **Computes Metrics:** Calculates:
    *   Instantaneous Heart Rate (HR)
    *   Rolling Root Mean Square of Successive Differences (RMSSD) as a Heart Rate Variability (HRV) metric.
    *   Accelerometer (ACC) Movement Index.
    *   Audio Loudness, Spectral Flux (a beat proxy), and Spectral Centroid (a frequency proxy) if audio is available.
*   **Applies Analysis Window:** Filters all computed metrics to the `Analysis start` and `Analysis stop` times defined in Cell 1.
*   **Generates Diagnostic Plot:** Creates a multi-panel plot visualizing all computed metrics for quick review.
*   **Exports Data:** Saves all processed time series data as a compressed NumPy `.npz` file, individual `.csv` files for easy inspection, a `meta.json` file containing run details and settings, and a `.png` image of the diagnostic plot to a structured folder in your Google Drive.
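The two beat-based metrics above reduce to short formulas: instantaneous HR is `60000 / RR` (RR in ms), and RMSSD over a window of N beats is the root mean square of the N - 1 successive RR differences in that window. The sketch below mirrors Cell 2's `rolling_rmssd` helper on a handful of made-up RR values. It omits the optional EWM detrending (`DETREND_RR`) and uses `window_beats=4` only to keep the example small (Cell 2 defaults to `RMSSD_WINDOW_BEATS = 30`); `instant_hr_bpm` is an illustrative helper name.

```python
import numpy as np
import pandas as pd

def instant_hr_bpm(rr_ms):
    """Instantaneous heart rate (bpm) from RR intervals in ms: 60000 / RR."""
    return 60000.0 / np.asarray(rr_ms, dtype=float)

def rolling_rmssd(rr_ms, window_beats=30):
    """Rolling RMSSD (ms) over the last `window_beats` beats, as in Cell 2.

    A window of N beats contains N - 1 successive RR differences; values stay
    NaN until that many differences are available.
    """
    rr = pd.Series(rr_ms, dtype="float64")
    drr2 = rr.diff().pow(2)                      # squared successive differences
    w = max(3, int(window_beats))
    return np.sqrt(drr2.rolling(window=w - 1, min_periods=w - 1).mean()).to_numpy()

rr = np.array([1500, 1520, 1480, 1510, 1505, 1545], dtype=float)  # made-up RR intervals (ms)
hr = instant_hr_bpm(rr)
rmssd = rolling_rmssd(rr, window_beats=4)

print(np.round(hr, 1))
print(np.round(rmssd, 1))
```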

**How to use it:**
1.  **Ensure Cell 1 is complete:** Make sure you have successfully loaded your files, computed the offset, and saved the analysis window in Cell 1. The `EQUIPHY['ready']` flag should be `True`.
2.  **Review `USER SETTINGS`:** Adjust parameters like `SAVE_TO_DRIVE`, `DRIVE_BASE_DIR`, `RMSSD_WINDOW_BEATS`, etc., at the beginning of the cell to match your preferences.
3.  **Run the cell:** Execute this cell. It will process the data, display the diagnostic plot, and print messages indicating where your exported files are saved in Google Drive.
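The audio proxies listed above (spectral flux as a beat proxy, spectral centroid as a frequency proxy) are both computed from an STFT magnitude spectrogram. The sketch below reproduces the core of Cell 2's `audio_beat_and_freq_proxies` without its downsampling step, on a synthetic signal that switches from a 440 Hz tone to a 2000 Hz tone at t = 1 s. The helper name `flux_and_centroid` and the test signal are illustrative assumptions. The centroid should sit near each tone's frequency, and the flux should peak around the switch.

```python
import numpy as np
from scipy.signal import stft

def flux_and_centroid(y, sr, nperseg=2048, hop=512, eps=1e-12):
    """Spectral flux (beat proxy) and spectral centroid in Hz, per STFT frame."""
    f, t, Z = stft(y, fs=sr, nperseg=nperseg, noverlap=nperseg - hop, boundary=None)
    mag = np.abs(Z)
    # Centroid: magnitude-weighted mean frequency of each frame
    centroid_hz = (f[:, None] * mag).sum(axis=0) / (mag.sum(axis=0) + eps)
    # Flux: sum of positive magnitude increases between consecutive frames,
    # time-stamped at the later frame of each pair
    flux = np.maximum(np.diff(mag, axis=1), 0).sum(axis=0)
    return t[1:], flux, t, centroid_hz

sr = 11025
n = 2048 + 512 * 41                       # length fits a whole number of frames (no end padding)
tt = np.arange(n) / sr
# 440 Hz tone for the first second, then a 2000 Hz tone
y = np.where(tt < 1.0, np.sin(2 * np.pi * 440 * tt), np.sin(2 * np.pi * 2000 * tt))

t_flux, flux, t_cent, centroid_hz = flux_and_centroid(y, sr)
print("centroid before switch ~", round(centroid_hz[t_cent < 0.8].mean()), "Hz")
print("centroid after switch  ~", round(centroid_hz[t_cent > 1.2].mean()), "Hz")
print("flux peaks at t ~", round(t_flux[np.argmax(flux)], 2), "s")
```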

In [None]:
# @title
# ==========================================================
# Cell 2 (EXPORT + SAVE TO GOOGLE DRIVE)
# Instant HR + Rolling RMSSD + Movement + Audio/music proxies
#
# This cell processes the loaded HR, ACC, and (optionally) audio data.
# It computes various physiological and movement metrics over the defined
# analysis window and prepares them for export. It also generates a
# diagnostic plot and saves all processed data and metadata to Google Drive.
#
# Requires:
#   - Cell 1 already ran and set EQUIPHY["ready"]=True and:
#   - EQUIPHY["hr_bytes"], ["acc_bytes"], ["ms0"] (raw data and common start time)
#   - EQUIPHY["t_start_s"], ["t_stop_s"] (the analysis window in data time)
#   - EQUIPHY["audio_offset_s"] (the video/audio to data alignment offset)
#   - (optional) EQUIPHY["aud_bytes"] or extracted_audio.wav present for audio analysis
#
# Outputs:
#   - EQUIPHY["series"] populated (a dictionary containing all computed time series)
#   - Saves a folder to Google Drive containing:
#       meta.json  (small metadata + settings of the run)
#       series_arrays.npz (all computed time series as NumPy arrays)
#       beats.csv / movement.csv / audio_*.csv (optional, for quick checks)
#       preview.png (the 6-panel diagnostic plot)
# ==========================================================

import os, io, json, time, platform
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, stft, resample_poly

# ---------------- USER SETTINGS ----------------
# Configure saving options and output directories
SAVE_TO_DRIVE = True # Set to False to prevent saving anything to Drive
DRIVE_BASE_DIR = "/content/drive/MyDrive/EquiPhysics/exports"  # Base directory for exports on Drive
EXPORT_CSVS = True       # Set to True to export data into individual CSV files
EXPORT_PLOT_PNG = True   # Set to True to save the diagnostic plot as a PNG image
RUN_TAG = None           # Custom tag for the export folder. If None, an auto-generated name is used.

# --------- (optional) force inline plotting ----------
# Ensures matplotlib plots are displayed within the notebook output.
try:
    from IPython import get_ipython
    get_ipython().run_line_magic("matplotlib", "inline")
except Exception:
    pass

# --------- processing knobs ----------
# These parameters control how the various metrics are calculated.
RMSSD_WINDOW_BEATS = 30        # Window size in beats for calculating rolling RMSSD (e.g., 20–50 beats)
DETREND_RR = True              # If True, removes slow trends from RR intervals before RMSSD calculation
RR_TREND_ALPHA = 0.02          # Alpha parameter for exponential moving average detrending of RR intervals

AUDIO_DOWNSAMPLE_TO = 11025    # Target sample rate for audio processing (e.g., for beat/frequency proxies)
LOUD_WIN_SEC, LOUD_HOP_SEC = 0.5, 0.1 # Window and hop size for audio loudness calculation (in seconds)

ACC_HP_HZ = 0.3                # High-pass filter frequency for accelerometer data (removes gravity/slow motion)
ACC_WIN_SEC, ACC_HOP_SEC = 1.0, 0.2 # Window and hop size for ACC movement index calculation (in seconds)

# --------- helper functions ----------
# These functions are internal to Cell 2 and assist with data processing.

def _read_polar_txt_bytes(file_bytes: bytes) -> pd.DataFrame:
    """Reads Polar HR/ACC data from byte content, handling common Polar file formats."""
    raw = file_bytes.replace(b"\x00", b"")
    lines = raw.decode("utf-8", errors="replace").splitlines()

    data_start = None
    for i, line in enumerate(lines):
        if line.startswith("#") or line.strip() == "":
            continue
        data_start = i
        break
    if data_start is None:
        raise ValueError("Could not find data rows (file looks empty or only comments).")

    df = pd.read_csv(io.StringIO("\n".join(lines[data_start:])), skipinitialspace=True)
    df.columns = [c.strip() for c in df.columns]

    for col in ["MS","HR","RR","SC","ACCX","ACCY","ACCZ"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")

    df = df.dropna(subset=["MS"]).reset_index(drop=True)
    return df

def _read_audio_bytes(file_bytes: bytes, filename: str):
    """Reads audio data from byte content, attempting with soundfile, falling back to scipy."""
    # Prefer soundfile (reads many formats); fall back to scipy's WAV reader if it fails
    try:
        import soundfile as sf
        y, sr = sf.read(io.BytesIO(file_bytes))
        if y.ndim == 2:
            y = y.mean(axis=1)
        return y.astype(np.float32), int(sr)
    except Exception:
        from scipy.io import wavfile
        sr, y = wavfile.read(io.BytesIO(file_bytes))
        y = y.astype(np.float32)
        if y.ndim == 2:
            y = y.mean(axis=1)
        m = np.max(np.abs(y)) if len(y) else 1.0
        if m > 1.5:
            y = y / m
        return y, int(sr)

def _moving_average(x, k=7):
    """Calculates a simple moving average of an array."""
    k = int(max(1, k))
    if k == 1:
        return x
    w = np.ones(k) / k
    return np.convolve(x, w, mode="same")

def ew_mean(x, alpha=0.02):
    """Calculates an exponential weighted moving average (EWM) for detrending."""
    x = np.asarray(x, float)
    out = np.full_like(x, np.nan, dtype=float)
    m = np.isfinite(x)
    if not m.any():
        return out
    y = x[m]
    ew = np.empty_like(y, dtype=float)
    ew[0] = y[0]
    for i in range(1, len(y)):
        ew[i] = alpha*y[i] + (1-alpha)*ew[i-1]
    out[m] = ew
    return out

def rolling_rmssd(rr_ms, window_beats=30):
    """Computes a rolling RMSSD (ms) over the last `window_beats` beats, i.e. window_beats - 1 successive RR differences."""
    rr = pd.Series(rr_ms, dtype="float64")
    drr2 = rr.diff().pow(2) # Squared differences between successive RR intervals
    w = max(3, int(window_beats))
    rmssd = np.sqrt(drr2.rolling(window=w-1, min_periods=w-1).mean())
    return rmssd.to_numpy()

def audio_loudness_db(y, sr, win_sec=0.5, hop_sec=0.1, eps=1e-8):
    """Calculates audio loudness in dB over sliding windows."""
    win = max(1, int(sr * win_sec))
    hop = max(1, int(sr * hop_sec))
    n = len(y)
    n_frames = 1 + max(0, (n - win) // hop)

    t = np.zeros(n_frames, dtype=float)
    db = np.zeros(n_frames, dtype=float)
    for i in range(n_frames):
        s = i * hop
        frame = y[s:s+win]
        rms = np.sqrt(np.mean(frame * frame) + eps)
        db[i] = 20 * np.log10(rms + eps)
        t[i] = (s + win/2) / sr
    return t, db

def audio_beat_and_freq_proxies(y, sr, downsample_to=11025, nperseg=2048, hop=512, eps=1e-12):
    """Calculates audio beat proxy (spectral flux) and frequency proxy (spectral centroid)."""
    # Beat proxy: spectral flux. Frequency proxy: spectral centroid (Hz).
    y = np.asarray(y, np.float32)
    if y.ndim == 2:
        y = y.mean(axis=1)

    # Downsample audio to reduce computational load for STFT
    if sr > downsample_to:
        if sr % downsample_to == 0:
            down = sr // downsample_to
            y_ds = resample_poly(y, up=1, down=down)
            sr_ds = sr // down
        else:
            g = np.gcd(sr, downsample_to)
            up = downsample_to // g
            down = sr // g
            y_ds = resample_poly(y, up=up, down=down)
            sr_ds = downsample_to
    else:
        y_ds, sr_ds = y, sr

    noverlap = max(0, nperseg - hop)
    f, t, Z = stft(y_ds, fs=sr_ds, nperseg=nperseg, noverlap=noverlap, boundary=None)
    mag = np.abs(Z).astype(np.float32)

    # Spectral Centroid: weighted average of frequencies present in a frame
    denom = mag.sum(axis=0) + eps
    centroid_hz = (f[:, None] * mag).sum(axis=0) / denom

    # Spectral Flux: a measure of how quickly the power spectrum of a signal is changing
    dmag = np.diff(mag, axis=1)
    flux = np.maximum(dmag, 0).sum(axis=0) # Only positive changes contribute to flux
    return t[1:], flux, t, centroid_hz

def movement_index_from_arrays(ax, ay, az, t_s, fs, hp_hz=0.3, win_sec=1.0, hop_sec=0.2):
    """Calculates a movement index from high-pass filtered accelerometer data arrays."""
    # Apply a high-pass filter to remove gravity and slow movements
    b, a = butter(2, hp_hz / (fs / 2), btype="highpass")
    axf = filtfilt(b, a, ax)
    ayf = filtfilt(b, a, ay)
    azf = filtfilt(b, a, az)
    mag = np.sqrt(axf*axf + ayf*ayf + azf*azf) # Magnitude of filtered acceleration

    # Compute RMS of magnitude in sliding windows to get movement index
    win = max(1, int(fs * win_sec))
    hop = max(1, int(fs * hop_sec))
    n = len(mag)
    n_frames = 1 + max(0, (n - win) // hop)

    tt = np.zeros(n_frames, dtype=float)
    idx = np.zeros(n_frames, dtype=float)
    for i in range(n_frames):
        s = i * hop
        frame = mag[s:s+win]
        idx[i] = np.sqrt(np.mean(frame * frame)) # RMS value
        tt[i] = t_s[min(s + win//2, n-1)] # Time at the center of the window
    return tt, idx

def _safe_np(a):
    """Converts input to a float NumPy array, handling None by returning empty array."""
    if a is None:
        return np.array([], dtype=float)
    return np.asarray(a, dtype=float)

def _now_tag():
    """Generates a timestamp string for naming output folders."""
    return time.strftime("%Y%m%d_%H%M%S")

# --------- sanity checks ----------
# Ensures Cell 1 has been run and EQUIPHY is ready for processing.
if not globals().get("EQUIPHY", {}).get("ready", False):
    raise RuntimeError("Run Cell 1 first (EQUIPHY must be ready).")

# --------- mount drive if saving ----------
# Automatically mounts Google Drive if SAVE_TO_DRIVE is True.
if SAVE_TO_DRIVE:
    from google.colab import drive
    if not Path("/content/drive").exists():
        drive.mount("/content/drive", force_remount=False)

# --------- pull necessary data from EQUIPHY ----------
# Retrieves raw data bytes and synchronization parameters from EQUIPHY.
df_hr  = _read_polar_txt_bytes(EQUIPHY["hr_bytes"])
df_acc = _read_polar_txt_bytes(EQUIPHY["acc_bytes"])

ms0 = float(EQUIPHY.get("ms0", np.nanmin([df_hr["MS"].min(), df_acc["MS"].min()]))) # Common millisecond start offset
t0 = float(EQUIPHY.get("t_start_s", 0.0)) # Start of analysis window (data time)
t1 = float(EQUIPHY.get("t_stop_s", np.inf)) # End of analysis window (data time)
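# Offset from video/audio time to data time. The audio is extracted from the same video,
# so fall back to the video->data offset saved by Cell 1 if no separate audio offset exists.
audio_offset = float(EQUIPHY.get("audio_offset_s", EQUIPHY.get("video_to_data_offset_s", 0.0)))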

hr_name  = EQUIPHY.get("hr_name", "HR")
acc_name = EQUIPHY.get("acc_name", "ACC")
vid_name = EQUIPHY.get("video_name", "VIDEO")
aud_name = EQUIPHY.get("aud_name", None)

# Convert MS timestamps to seconds relative to 'ms0' for consistent time axes.
df_hr["t_s"]  = (df_hr["MS"].to_numpy(float)  - ms0) / 1000.0
df_acc["t_s"] = (df_acc["MS"].to_numpy(float) - ms0) / 1000.0

# =========================
# 1) Beats: Instant HR + flags
# Processes HR data to extract instant heart rate and flag invalid beats.
# =========================
if "RR" not in df_hr.columns:
    raise ValueError("HR file must contain RR (ms) to compute instantaneous HR.")

t_hr_all  = df_hr["t_s"].to_numpy(float)
rr_all    = df_hr["RR"].to_numpy(float)
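# SC (signal quality) column; 1 = good. Read as float so missing values do not break an int
# conversion (NaN != 1, so such beats end up flagged). Defaults to 1 (good) if the column is absent.
sc_all    = df_hr["SC"].to_numpy(float) if "SC" in df_hr.columns else np.ones(len(df_hr), dtype=float)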

# Filter for data within the analysis window and valid RR intervals
in_win = np.isfinite(t_hr_all) & (t_hr_all >= t0) & (t_hr_all <= t1)
valid_rr = np.isfinite(rr_all) & (rr_all > 0)

good = in_win & valid_rr & (sc_all == 1) # Good quality beats
flag = in_win & valid_rr & (sc_all != 1) # Flagged beats (e.g., due to poor signal quality)

t_beats = t_hr_all[good]
rr_ms   = rr_all[good]
hr_inst = 60000.0 / rr_ms # Calculate instant HR from RR intervals (60,000 ms/min / RR_ms)

t_flag = t_hr_all[flag]
hr_flag = (60000.0 / rr_all[flag]) if flag.any() else np.array([], dtype=float)

# =========================
# 2) Rolling RMSSD (beats)
# Computes RMSSD, a common HRV metric, over a rolling window.
# =========================
rr_use = rr_ms.copy()
if DETREND_RR and len(rr_use) > 0:
    # Detrend RR intervals to remove slow variations if configured
    rr_use = rr_use - ew_mean(rr_use, alpha=RR_TREND_ALPHA)
rmssd_roll = rolling_rmssd(rr_use, window_beats=RMSSD_WINDOW_BEATS)

# =========================
# 3) Movement proxy (ACC)
# Derives a movement index from accelerometer data.
# =========================
ms_acc = df_acc["MS"].to_numpy(float)
# Estimate accelerometer sampling frequency
dms = np.diff(ms_acc)
dms = dms[np.isfinite(dms) & (dms > 0)]
fs_acc = 100.0 if len(dms) == 0 else 1000.0 / np.median(dms)

if not all(c in df_acc.columns for c in ["ACCX","ACCY","ACCZ"]):
    raise ValueError("ACC file missing ACCX/ACCY/ACCZ columns.")

# Convert raw accelerometer values to 'g' (gravity units)
ax = df_acc["ACCX"].to_numpy(float) / 1000.0
ay = df_acc["ACCY"].to_numpy(float) / 1000.0
az = df_acc["ACCZ"].to_numpy(float) / 1000.0
t_acc = df_acc["t_s"].to_numpy(float)

# Calculate the movement index over the full data range first
t_mov_full, mov_idx_full = movement_index_from_arrays(
    ax, ay, az, t_acc, fs_acc,
    hp_hz=ACC_HP_HZ, win_sec=ACC_WIN_SEC, hop_sec=ACC_HOP_SEC
)
# Then filter for the desired analysis window
m_mov = np.isfinite(t_mov_full) & (t_mov_full >= t0) & (t_mov_full <= t1)
t_mov = t_mov_full[m_mov]
mov_idx = mov_idx_full[m_mov]

# =========================
# 4) Audio/music metrics (optional)
# Extracts loudness, beat, and frequency proxies from audio data.
# audio_time 'a' -> data_time = 'a' + audio_offset
# =========================
have_audio = False
y = sr = None

# Attempt to load audio bytes from EQUIPHY (if extracted by Cell 1)
if (EQUIPHY.get("aud_bytes", None) is not None) and (EQUIPHY.get("aud_name", None) is not None):
    y, sr = _read_audio_bytes(EQUIPHY["aud_bytes"], EQUIPHY["aud_name"])
    have_audio = True
else:
    # Fallback: check if 'extracted_audio.wav' exists in the current working directory
    if Path("extracted_audio.wav").exists():
        try:
            # Reuse the byte reader so the scipy WAV fallback also applies here
            y, sr = _read_audio_bytes(Path("extracted_audio.wav").read_bytes(), "extracted_audio.wav")
            have_audio = True
            aud_name = "extracted_audio.wav"
        except Exception:
            have_audio = False

if have_audio:
    # Compute audio features if audio data is available
    t_loud_a, loud_db = audio_loudness_db(y, sr, win_sec=LOUD_WIN_SEC, hop_sec=LOUD_HOP_SEC)
    t_flux_a, flux, t_cent_a, centroid_hz = audio_beat_and_freq_proxies(y, sr, downsample_to=AUDIO_DOWNSAMPLE_TO)

    # Shift audio times to data time using the determined offset
    t_loud = t_loud_a + audio_offset
    t_flux = t_flux_a + audio_offset
    t_cent = t_cent_a + audio_offset

    # Filter audio features for the analysis window
    m_loud = np.isfinite(t_loud) & (t_loud >= t0) & (t_loud <= t1)
    m_flux = np.isfinite(t_flux) & (t_flux >= t0) & (t_flux <= t1)
    m_cent = np.isfinite(t_cent) & (t_cent >= t0) & (t_cent <= t1)

    t_loud, loud_db = t_loud[m_loud], loud_db[m_loud]
    t_flux, flux = t_flux[m_flux], flux[m_flux]
    t_cent, centroid_hz = t_cent[m_cent], centroid_hz[m_cent]

    # Post-process flux for better visualization (z-score and smoothing)
    if len(flux) > 0:
        flux_z = (flux - np.mean(flux)) / (np.std(flux) + 1e-9)
        flux_z = _moving_average(flux_z, k=9)
    else:
        flux_z = flux

    centroid_sm = _moving_average(centroid_hz, k=7) if len(centroid_hz) else centroid_hz
else:
    # Set audio related variables to None if no audio is present
    t_loud = loud_db = t_flux = flux_z = t_cent = centroid_sm = None

# =========================
# Export series (arrays)
# Stores all computed time series into the EQUIPHY["series"] dictionary.
# =========================
EQUIPHY["series"] = {
    "t0": float(t0), "t1": float(t1),
    "ms0": float(ms0),
    "audio_offset_s": float(audio_offset),
    "fs_acc": float(fs_acc),

    "t_beats": _safe_np(t_beats),
    "rr_ms": _safe_np(rr_ms),
    "hr_inst": _safe_np(hr_inst),
    "rmssd_roll": _safe_np(rmssd_roll),

    "t_flag": _safe_np(t_flag) if len(t_flag) else np.array([], dtype=float), # Flagged HR points
    "hr_flag": _safe_np(hr_flag) if len(t_flag) else np.array([], dtype=float), # HR values for flagged points

    "t_mov": _safe_np(t_mov),
    "mov_idx": _safe_np(mov_idx),

    "have_audio": bool(have_audio),
    "t_loud": _safe_np(t_loud) if have_audio and t_loud is not None else None,
    "loud_db": _safe_np(loud_db) if have_audio and t_loud is not None else None,
    "t_flux": _safe_np(t_flux) if have_audio and t_flux is not None else None,
    "flux_z": _safe_np(flux_z) if have_audio and t_flux is not None else None,
    "t_cent": _safe_np(t_cent) if have_audio and t_cent is not None else None,
    "centroid_sm": _safe_np(centroid_sm) if have_audio and t_cent is not None else None,
}
print("✅ Exported EQUIPHY['series']")

# =========================
# Plot (preview)
# Generates a multi-panel plot to visualize the computed metrics.
# This plot helps in quick verification of the data and alignment.
# =========================
print("\nLoaded:")
print("  HR   :", hr_name, f"(rows={len(df_hr)})")
print("  ACC  :", acc_name, f"(rows={len(df_acc)}, fs≈{fs_acc:.1f} Hz)")
print("  VIDEO:", vid_name)
print("  ms0  :", f"{ms0:.0f} ms  (shared origin)")
print(f"Analysis window (DATA time): [{t0:.2f}, {t1:.2f}] s")
print(f"RMSSD window: {RMSSD_WINDOW_BEATS} beats | detrend RR: {DETREND_RR} (alpha={RR_TREND_ALPHA})")
print(f"Audio: {aud_name} | audio_offset_s (audio→data): {audio_offset:+.3f} | have_audio={have_audio}")

fig = plt.figure(figsize=(12, 12)) # Create a figure for the 6-panel plot

# Plot 1: Instantaneous Heart Rate
ax1 = plt.subplot(6, 1, 1)
ax1.plot(t_beats, hr_inst, linewidth=1.5)
if len(t_flag):
    ax1.scatter(t_flag, hr_flag, s=18, marker="x", label="flagged (SC!=1)")
ax1.set_ylabel("Instant HR\n(bpm)")
ax1.grid(True, alpha=0.3)
if len(t_flag):
    ax1.legend(loc="upper right", fontsize=8)

# Plot 2: Rolling RMSSD
ax2 = plt.subplot(6, 1, 2, sharex=ax1) # Share X-axis with the HR plot
ax2.plot(t_beats, rmssd_roll, linewidth=1.5)
ax2.set_ylabel("RMSSD (rolling)\n(ms)")
ax2.grid(True, alpha=0.3)

# Plot 3: ACC Movement Index
ax3 = plt.subplot(6, 1, 3, sharex=ax1)
ax3.plot(t_mov, mov_idx, linewidth=1.5)
ax3.set_ylabel("Movement\n(g RMS)")
ax3.grid(True, alpha=0.3)

# Plots 4, 5, 6: Audio features (if available)
ax4 = plt.subplot(6, 1, 4, sharex=ax1)
ax5 = plt.subplot(6, 1, 5, sharex=ax1)
ax6 = plt.subplot(6, 1, 6, sharex=ax1)

if have_audio and t_loud is not None:
    ax4.plot(t_loud, loud_db, linewidth=1.5)
    ax4.set_ylabel("Loudness\n(RMS dB)")
    ax4.grid(True, alpha=0.3)

    ax5.plot(t_flux, flux_z, linewidth=1.5)
    ax5.set_ylabel("Beat proxy\n(flux z)")
    ax5.grid(True, alpha=0.3)

    ax6.plot(t_cent, centroid_sm, linewidth=1.5)
    ax6.set_ylabel("Freq proxy\n(centroid Hz)")
    ax6.set_xlabel("Time (data s)")
    ax6.grid(True, alpha=0.3)
else:
    # Display a message if audio data is not available
    ax4.text(0.02, 0.5, "No extracted audio available\n(skipping music metrics)", transform=ax4.transAxes)
    ax4.set_axis_off() # Hide axes for empty plots
    ax5.set_axis_off()
    ax6.set_axis_off()
    ax3.set_xlabel("Time (data s)") # Movement panel is the lowest visible axis without audio

# Set X-axis limits for all shared plots
xmax = t1 if np.isfinite(t1) else (t_beats[-1] if len(t_beats) else t0 + 10)
ax1.set_xlim(t0, xmax)

plt.tight_layout() # Adjust layout to prevent overlapping elements
plt.show()

# =========================
# SAVE TO DRIVE (metadata + arrays)
# Saves processed data and metadata to a structured folder on Google Drive.
# =========================
if SAVE_TO_DRIVE:
    run_tag = RUN_TAG
    if run_tag is None:
        # Generate a unique run tag based on video name and current timestamp
        stem = Path(vid_name).stem[:40].replace(" ", "_")
        run_tag = f"{stem}__{_now_tag()}"
    out_dir = Path(DRIVE_BASE_DIR) / run_tag # Create a dedicated output folder
    out_dir.mkdir(parents=True, exist_ok=True) # Ensure directory exists

    # --- metadata (small json) ---
    # Stores key information about the run, settings, and file details.
    meta = {
        "run_tag": run_tag,
        "created_localtime": time.strftime("%Y-%m-%d %H:%M:%S"),
        "platform": {
            "python": platform.python_version(),
            "system": platform.platform(),
        },
        "files": {
            "hr_name": hr_name,
            "acc_name": acc_name,
            "video_name": vid_name,
            "audio_name": aud_name,
        },
        "alignment": {
            "video_to_data_offset_s": float(EQUIPHY.get("video_to_data_offset_s", np.nan)),
            "audio_offset_s": float(audio_offset),
            "sync_corr": float(EQUIPHY.get("sync_corr", np.nan)),
            "sync_video_segment": EQUIPHY.get("sync_video_segment", None),
            "sync_data_segment_guess": EQUIPHY.get("sync_data_segment_guess", None),
        },
        "window": {
            "t0": float(t0),
            "t1": float(t1),
            "ms0": float(ms0),
        },
        "settings": {
            "RMSSD_WINDOW_BEATS": int(RMSSD_WINDOW_BEATS),
            "DETREND_RR": bool(DETREND_RR),
            "RR_TREND_ALPHA": float(RR_TREND_ALPHA),
            "ACC_HP_HZ": float(ACC_HP_HZ),
            "ACC_WIN_SEC": float(ACC_WIN_SEC),
            "ACC_HOP_SEC": float(ACC_HOP_SEC),
            "AUDIO_DOWNSAMPLE_TO": int(AUDIO_DOWNSAMPLE_TO),
            "LOUD_WIN_SEC": float(LOUD_WIN_SEC),
            "LOUD_HOP_SEC": float(LOUD_HOP_SEC),
        }
    }
    meta_path = out_dir / "meta.json"
    meta_path.write_text(json.dumps(meta, indent=2)) # Save metadata as a pretty-printed JSON file

    # --- arrays (npz) ---
    # Saves all computed time series arrays into a single compressed NumPy .npz file.
    arrays = {}
    for k, v in EQUIPHY["series"].items():
        if isinstance(v, np.ndarray):
            arrays[k] = v
        elif v is None:
            continue
        elif isinstance(v, (float, int, bool)):
            arrays[k] = np.array([v]) # Store single values as arrays for consistency
        # Strings/dicts from EQUIPHY["series"] live in meta.json, not here.
    npz_path = out_dir / "series_arrays.npz"
    np.savez_compressed(npz_path, **arrays) # Save compressed NumPy arrays

    # --- optional CSV exports ---
    # Exports individual data streams to CSV files for easy viewing/sharing.
    if EXPORT_CSVS:
        beats_df = pd.DataFrame({
            "t_beats_s": t_beats,
            "rr_ms": rr_ms,
            "hr_bpm": hr_inst,
            "rmssd_roll_ms": rmssd_roll
        })
        beats_df.to_csv(out_dir / "beats.csv", index=False)

        mov_df = pd.DataFrame({"t_mov_s": t_mov, "mov_idx": mov_idx})
        mov_df.to_csv(out_dir / "movement.csv", index=False)

        if have_audio and (t_loud is not None):
            pd.DataFrame({"t_loud_s": t_loud, "loud_db": loud_db}).to_csv(out_dir / "audio_loudness.csv", index=False)
        if have_audio and (t_flux is not None):
            pd.DataFrame({"t_flux_s": t_flux, "flux_z": flux_z}).to_csv(out_dir / "audio_flux.csv", index=False)
        if have_audio and (t_cent is not None):
            pd.DataFrame({"t_cent_s": t_cent, "centroid_hz_sm": centroid_sm}).to_csv(out_dir / "audio_centroid.csv", index=False)

    # --- optional plot png ---
    # Saves the diagnostic plot as a PNG image.
    if EXPORT_PLOT_PNG:
        fig_path = out_dir / "preview.png"
        fig.savefig(fig_path, dpi=150)

    # Store paths to exported files in EQUIPHY for potential downstream use
    EQUIPHY["export_dir"] = str(out_dir)
    EQUIPHY["export_meta_json"] = str(meta_path)
    EQUIPHY["export_series_npz"] = str(npz_path)

    print("\n✅ Saved exports to Drive:")
    print("  folder:", out_dir)
    print("  meta  :", meta_path.name)
    print("  npz   :", npz_path.name)
    if EXPORT_CSVS:
        print("  csv   : beats.csv, movement.csv (+ audio_*.csv if audio present)")
    if EXPORT_PLOT_PNG:
        print("  plot  : preview.png")


## Graph Observations (Cell 2)

Use this section to record your observations and initial interpretations of the physiological and movement graphs generated in Cell 2. This helps to connect the numerical data with potential real-world events or horse states.

---

**Overall Impression of Graphs:** [e.g., Clear and easy to read, some noise in HR, audio data looks consistent]

**Heart Rate (HR) Trends:** [e.g., Steady HR around 30-35 bpm, a few spikes, slight increase towards the end]

**Heart Rate Variability (RMSSD) Trends:** [e.g., RMSSD generally stable, decreased during HR spikes, increased during calm periods]

**Movement (ACC) Patterns:** [e.g., Low overall movement, a few brief periods of increased movement (shifting weight, head movement)]

**Audio Feature Observations:** [e.g., Music started at ~50s mark, loudness consistent, spectral centroid showed changes with different music parts]

**Noteworthy Correlations/Relationships:** [e.g., HR slightly increased when movement occurred, no clear link between music and HR/movement]

**Any Anomalies or Unexpected Patterns:** [e.g., A sudden drop in HR that doesn't seem to align with movement, unexpected high-frequency audio noise]

**Additional Notes on Interpretation:**

---
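To check these observations against the underlying numbers later (outside this session), the files Cell 2 writes can be read back directly. Per Cell 2's `EQUIPHY["series"]` layout, `None` entries are skipped and scalar values such as `t0` and `have_audio` are stored as one-element arrays in `series_arrays.npz`, so the sketch below unwraps that list of scalar keys. `load_export` and `SCALAR_KEYS` are hypothetical helpers, and the demo writes a tiny stand-in export to a temporary folder instead of using real data.

```python
import json
import tempfile
from pathlib import Path
import numpy as np

# Keys that Cell 2 stores as one-element arrays in series_arrays.npz
SCALAR_KEYS = ("t0", "t1", "ms0", "audio_offset_s", "fs_acc", "have_audio")

def load_export(out_dir):
    """Hypothetical loader: returns (meta dict, series dict) for one export folder."""
    out_dir = Path(out_dir)
    meta = json.loads((out_dir / "meta.json").read_text())
    with np.load(out_dir / "series_arrays.npz") as z:
        series = {k: z[k] for k in z.files}     # read every array into memory
    for k in SCALAR_KEYS:
        if k in series:
            series[k] = series[k].item()         # unwrap 1-element arrays to Python scalars
    return meta, series

# Demo with a tiny stand-in export (not real data)
with tempfile.TemporaryDirectory() as d:
    d = Path(d)
    (d / "meta.json").write_text(json.dumps({"run_tag": "demo", "window": {"t0": 10.0, "t1": 70.0}}))
    np.savez_compressed(d / "series_arrays.npz",
                        t0=np.array([10.0]), have_audio=np.array([False]),
                        t_beats=np.array([10.5, 12.0]), hr_inst=np.array([40.0, 38.7]))
    meta, series = load_export(d)

print(meta["window"], series["t0"], series["have_audio"], series["hr_inst"])
```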

## Cell 3: Side-by-Side Video Renderer

**Purpose:** This cell generates a video with synchronized graphs of key metrics displayed alongside the original video. It allows for a direct visual correlation between the horse's behavior and the physiological/movement data.

**What it does:**
*   **Video Processing:** Cuts the original video to the defined analysis window.
*   **Graph Synchronization:** Displays synchronized plots of Heart Rate, RMSSD, Movement, and the audio frequency proxy (spectral centroid) beside the video, with a moving cursor marking the current time.
*   **Rendering:** Combines the video and the dynamic dashboard into a single MP4 file.

**How to use it:**
1.  **Run the Cell:** Execute the code to start the rendering process.
2.  **Monitor Progress:** A progress bar will be displayed to show the rendering status.
3.  **View Output:** The final video will be saved to the `renders` folder in your defined Google Drive directory.

**Warning:** Video rendering is computationally intensive. Depending on the video length and resolution, this process can take several minutes. Please keep the browser tab open and active while the progress bar updates.
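Keeping the cursor and the video frame in step comes down to mapping each output frame's time onto data time and video time. The sketch below assumes that the rendered clip starts at the analysis start `t0` (data time), uses Cell 1's convention `data_time = video_time + offset` (so the source subclip starts at video time `t0 - offset`), and interprets `SCROLL_WINDOW_S` as a fixed-width window that fills from `t0` and then scrolls so the cursor stays at its right edge. The helper `render_frame_times` is hypothetical and ignores `FINE_TUNE_OFFSET_S`; the rendering code in this cell may handle these details differently.

```python
import numpy as np

def render_frame_times(tau, t0, t1, offset_s, scroll_window_s=60.0):
    """Hypothetical mapping for one output frame.

    tau: seconds since the start of the rendered clip.
    Returns (t_data, t_video, (x_left, x_right)) using data_time = video_time + offset_s.
    """
    t_data = min(t0 + tau, t1)            # cursor position in DATA time
    t_video = t_data - offset_s           # matching moment in the source video
    if scroll_window_s is None:
        xlim = (t0, t1)                   # whole analysis window, cursor moves across it
    else:
        # Assumed behavior: fixed-width window that fills from t0, then scrolls
        left = max(t0, t_data - scroll_window_s)
        xlim = (left, min(t1, left + scroll_window_s))
    return t_data, t_video, xlim

OUT_FPS = 15
t0, t1, offset = 100.0, 400.0, 25.0        # example window (data s) and offset (s)
for k in (0, 450, 2000):                   # a few output frame indices
    print(k, render_frame_times(k / OUT_FPS, t0, t1, offset))
```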



In [None]:
# @title
# ==========================================================
# Cell 3: Movie renderer (SAVE TO GOOGLE DRIVE):
# Video (left) + 4 stacked graphs (right) + moving cursor
# Graphs: HR, rolling RMSSD, Movement, Frequency proxy (spectral centroid)
# Audio: extracted from SAME video subclip via ffmpeg
#
# Requires:
#   - Cell 1: EQUIPHY["video_path"], ["video_to_data_offset_s"], ["t_start_s"], ["t_stop_s"], EQUIPHY["ready"]=True
#   - Cell 2: EQUIPHY["series"] populated (and ideally saved export_dir)
#
# Output:
#   - Writes MP4 directly to Google Drive
# ==========================================================

import os, subprocess, shutil
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

# Ensure tqdm is available for progress bars
try:
    from tqdm import tqdm
except ImportError:
    !pip install -q tqdm
    from tqdm import tqdm

from moviepy.editor import VideoFileClip, VideoClip, clips_array, AudioFileClip  # moviepy 1.x API (moviepy.editor was removed in moviepy 2.0)

# ---------------- knobs ----------------
OUT_FPS = 15                 # 15 looks nicer but is slower; 10–12 is a good compromise
H = 720                      # output video height
PANEL_WIDTH_PX = 720
PANEL_HEIGHT_IN = 7.5
SCROLL_WINDOW_S = 60         # None = full window, else last N seconds
FINE_TUNE_OFFSET_S = 0.0     # optional micro-adjust if you see lag
AUDIO_SR = 22050

# save location
SAVE_TO_DRIVE = True
DRIVE_FALLBACK_DIR = "/content/drive/MyDrive/EquiPhysics/renders"
ALSO_COPY_TO_CONTENT = False  # if True, also copy mp4 to /content

# -------------- checks ----------------
if not globals().get("EQUIPHY", {}).get("ready", False):
    raise RuntimeError("Run Cell 1 first (EQUIPHY must be ready).")
if "video_path" not in EQUIPHY:
    raise RuntimeError("EQUIPHY['video_path'] missing. Run Cell 1 to load the video.")
if "series" not in EQUIPHY:
    raise RuntimeError("EQUIPHY['series'] missing. Run Cell 2 first.")

S = EQUIPHY["series"]
video_path = EQUIPHY["video_path"]

t0_req = float(EQUIPHY.get("t_start_s", S.get("t0", 0.0)))
t1_req = float(EQUIPHY.get("t_stop_s",  S.get("t1", 0.0)))
if t1_req <= t0_req:
    raise ValueError("Bad t0/t1: t_stop_s must be > t_start_s.")

vid_to_data = float(EQUIPHY.get("video_to_data_offset_s", 0.0)) + float(FINE_TUNE_OFFSET_S)

# Pull aligned series (already in DATA time)
t_beats = np.asarray(S.get("t_beats", []), float)
hr_inst = np.asarray(S.get("hr_inst", []), float)
rmssd   = np.asarray(S.get("rmssd_roll", []), float)
t_mov   = np.asarray(S.get("t_mov", []), float)
mov_idx = np.asarray(S.get("mov_idx", []), float)

t_cent   = S.get("t_cent", None)
centroid = S.get("centroid_sm", None)
have_centroid = (t_cent is not None) and (centroid is not None)
if have_centroid:
    t_cent = np.asarray(t_cent, float)
    centroid = np.asarray(centroid, float)

def _clip_xy(t, y, a, b):
    t = np.asarray(t, float); y = np.asarray(y, float)
    m = np.isfinite(t) & (t >= a) & (t <= b)
    return t[m], y[m]

# ---------------- mount Drive + choose output dir ----------------
if SAVE_TO_DRIVE:
    from google.colab import drive
    if not Path("/content/drive").exists():
        drive.mount("/content/drive", force_remount=False)

export_dir = EQUIPHY.get("export_dir", None)
if SAVE_TO_DRIVE:
    if export_dir is not None and str(export_dir).startswith("/content/drive"):
        out_dir = Path(export_dir) / "renders"
    else:
        out_dir = Path(DRIVE_FALLBACK_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)
else:
    out_dir = Path(".")

# ---------------- choose video subclip from DATA window ----------------
# video_time = data_time - offset
# Keep base open until AFTER write_videofile finishes
base = VideoFileClip(video_path, audio=False)

v_start = t0_req - vid_to_data
v_end   = t1_req - vid_to_data

v_start = max(0.0, float(v_start))
v_end   = min(float(base.duration), float(v_end))

if v_end <= v_start:
    base.close()
    raise RuntimeError("Invalid video subclip after clamping. Check offset or t0/t1.")

clip = base.subclip(v_start, v_end)  # no audio
duration = float(clip.duration)

# This is exact mapping for this rendered clip
out_data_start = v_start + vid_to_data
out_data_end   = out_data_start + duration

print("Render mapping:")
print(f"  requested DATA window: [{t0_req:.2f}, {t1_req:.2f}] s")
print(f"  video subclip:         [{v_start:.2f}, {v_start+duration:.2f}] s  (dur={duration:.2f}s)")
print(f"  offset (data=video+off): {vid_to_data:+.3f} s")
print(f"  output DATA window:    [{out_data_start:.2f}, {out_data_end:.2f}] s")

# Clip plot series to actual output data range
t_beats_c, hr_c    = _clip_xy(t_beats, hr_inst, out_data_start, out_data_end)
_,         rmssd_c = _clip_xy(t_beats, rmssd,   out_data_start, out_data_end)
t_mov_c,   mov_c   = _clip_xy(t_mov,   mov_idx, out_data_start, out_data_end)

if have_centroid:
    t_cent_c, cent_c = _clip_xy(t_cent, centroid, out_data_start, out_data_end)
else:
    t_cent_c = np.array([out_data_start, out_data_end], float)
    cent_c   = np.array([np.nan, np.nan], float)

if len(t_beats_c) < 2 or len(t_mov_c) < 2:
    clip.close(); base.close()
    raise RuntimeError("Not enough beat/movement data in the output window. Widen t0/t1 or check Cell 2 export.")

# ---------------- extract audio from SAME subclip via ffmpeg ----------------
wav_path = "_subclip_audio.wav"
audio_clip = None

ff_cmd = [
    "ffmpeg", "-y",
    "-i",  video_path,
    "-ss", f"{v_start:.6f}",
    "-t",  f"{duration:.6f}",
    "-vn",
    "-ac", "1",
    "-ar", str(AUDIO_SR),
    "-c:a", "pcm_s16le",
    wav_path
]
try:
    subprocess.run(ff_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
        audio_clip = AudioFileClip(wav_path).set_duration(duration)
except Exception as e:
    print("WARNING: could not extract/attach audio:", e)
    audio_clip = None

# ---------------- panel renderer ----------------
class PanelRenderer:
    def __init__(self):
        self.fig = plt.figure(figsize=(PANEL_WIDTH_PX/100, PANEL_HEIGHT_IN), dpi=100)
        self.canvas = FigureCanvas(self.fig)
        self.axes = [self.fig.add_subplot(4,1,i+1) for i in range(4)]

        self.axes[0].plot(t_beats_c, hr_c, linewidth=1.6)
        self.axes[0].set_ylabel("HR (bpm)")
        self.axes[0].grid(True, alpha=0.25)

        self.axes[1].plot(t_beats_c, rmssd_c, linewidth=1.6)
        self.axes[1].set_ylabel("RMSSD (ms)")
        self.axes[1].grid(True, alpha=0.25)

        self.axes[2].plot(t_mov_c, mov_c, linewidth=1.6)
        self.axes[2].set_ylabel("Move (RMS)")
        self.axes[2].grid(True, alpha=0.25)

        self.axes[3].plot(t_cent_c, cent_c, linewidth=1.6)
        self.axes[3].set_ylabel("Freq (Hz)")
        self.axes[3].set_xlabel("DATA time (s)")
        self.axes[3].grid(True, alpha=0.25)

        self.vlines = [ax.axvline(out_data_start, linewidth=2.5) for ax in self.axes]
        for ax in self.axes:
            ax.set_xlim(out_data_start, out_data_end)

        self.fig.tight_layout()

    def __call__(self, t_sub):
        data_t = out_data_start + float(t_sub)

        if SCROLL_WINDOW_S is None:
            xmin, xmax = out_data_start, out_data_end
        else:
            xmin = max(out_data_start, data_t - SCROLL_WINDOW_S)
            xmax = min(out_data_end, xmin + SCROLL_WINDOW_S)
            if xmax <= xmin:
                xmax = xmin + 1.0

        for ax in self.axes:
            ax.set_xlim(xmin, xmax)
        for vl in self.vlines:
            vl.set_xdata([data_t, data_t])

        self.canvas.draw()
        rgba = np.asarray(self.canvas.buffer_rgba())
        return rgba[..., :3].copy()

panel_clip = VideoClip(PanelRenderer(), duration=duration).set_fps(OUT_FPS)

# ---------------- compose side-by-side ----------------
clip_l = clip.resize(height=H)
panel_r = panel_clip.resize(height=H)

final = clips_array([[clip_l, panel_r]]).set_duration(duration)
if audio_clip is not None:
    final = final.set_audio(audio_clip)

# ---------------- write to drive ----------------
stem = Path(video_path).stem[:40].replace(" ", "_")
out_name = f"{stem}__dash__data_{out_data_start:.1f}_{out_data_end:.1f}.mp4"  # name by the window actually rendered
out_path = out_dir / out_name

print("\nWriting MP4 to:", out_path)

final.write_videofile(
    str(out_path),
    fps=OUT_FPS,
    codec="libx264",
    audio_codec="aac",
    threads=4,
    preset="medium",
    verbose=False,
    logger="bar"
)

# optional copy to /content for easy download button
if ALSO_COPY_TO_CONTENT:
    shutil.copy2(out_path, Path("/content") / out_name)

# cleanup + close resources
try:
    if os.path.exists(wav_path):
        os.remove(wav_path)
except Exception:
    pass

try:
    if audio_clip is not None:
        audio_clip.close()
except Exception:
    pass

try:
    final.close()
except Exception:
    pass

try:
    panel_clip.close()
except Exception:
    pass

try:
    clip.close()
except Exception:
    pass

try:
    base.close()
except Exception:
    pass

print("\n✅ Saved movie to Drive:")
print("  ", out_path)

## Cell 4: Portrait Reel Renderer

**Purpose:** This cell creates a vertical video (9:16 aspect ratio, 1080 × 1920 px) for social media platforms such as Instagram Reels, TikTok, or YouTube Shorts. It presents your data in a mobile-friendly format.

**What it does:**
*   **Vertical Layout:** Scales and center-crops the original video to fill the top 1080 × 1080 px of the frame.
*   **Stacked Dashboard:** Places synchronized graphs of Heart Rate, RMSSD, Movement, Loudness, and Frequency (spectral centroid) in the bottom 1080 × 840 px, together with a progress bar and a moving cursor.
*   **Social Ready:** Exports a vertical MP4 (H.264 video, AAC audio) that is ready for sharing.
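"Crop to fill" scales the frame until it covers the target box, then trims the overflow equally from both sides. The dependency-free helper below restates the geometry inside Cell 4's `fit_to_box()`. The name `cover_crop` is hypothetical. Real frame sizes are whole pixels, so moviepy's results may differ slightly from these float values.

```python
def cover_crop(src_w, src_h, box_w, box_h):
    """Scale (src_w, src_h) to cover (box_w, box_h), then center-crop.

    Returns (scaled_w, scaled_h, x1, y1); (x1, y1) is the crop's top-left corner
    in the scaled frame. Mirrors the order used by fit_to_box() in Cell 4:
    match the box height first, and match the width instead if that is too narrow.
    """
    scale = box_h / src_h
    if src_w * scale < box_w:
        scale = box_w / src_w
    scaled_w, scaled_h = src_w * scale, src_h * scale
    x1 = (scaled_w - box_w) / 2      # equal trim left and right
    y1 = (scaled_h - box_h) / 2      # equal trim top and bottom
    return scaled_w, scaled_h, x1, y1

# 1080 x 1080 video area at the top of the 1080 x 1920 reel
print(cover_crop(1920, 1080, 1080, 1080))  # landscape: (1920.0, 1080.0, 420.0, 0.0)
print(cover_crop(1080, 1920, 1080, 1080))  # portrait:  (1080.0, 1920.0, 0.0, 420.0)
```

For a 1920 × 1080 landscape clip, 420 px is trimmed from each side, removing about 44% of the width. The horse therefore needs to stay near the center of the shot to remain visible in the reel.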

**How to use it:**
1.  **Run the Cell:** Execute the code.
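2.  **Define Render Window:** When prompted, enter a start and a stop time in data seconds. Press Enter to accept the defaults, which are Cell 1's analysis window (`t_start_s` to `t_stop_s`). If no stop time was set, the render is capped at `DEFAULT_DUR_S` (20 s) after the start. Input that cannot be read as a number also falls back to the defaults.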
3.  **Wait for Processing:** Similar to Cell 3, this renders frame-by-frame and may take some time.
4.  **Locate File:** The finished video will be saved in your Drive's `renders` folder with `__reel__` in the filename.
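Cell 4 recomputes its Loudness and Freq graphs from the audio of the rendered segment. Both measures are frame-by-frame statistics:

*   **Loudness** is the RMS of each 0.5 s frame, in dB relative to full scale.
*   **Spectral centroid** is the magnitude-weighted mean frequency of each short-time spectrum.

The numpy-only sketch below checks both on a synthetic 440 Hz tone. It uses `np.fft.rfft` with a Hann window in place of `scipy.signal.stft`. Its centroid values should therefore track Cell 4's `spectral_centroid()` closely but are not guaranteed to be identical. The names `rms_db` and `centroid_hz` are hypothetical.

```python
import numpy as np

def rms_db(y, sr, win_sec=0.5, hop_sec=0.1, eps=1e-8):
    """Frame-wise RMS level in dB relative to full scale (amplitude 1.0)."""
    win, hop = int(sr * win_sec), int(sr * hop_sec)
    starts = np.arange(0, len(y) - win + 1, hop)
    rms = np.array([np.sqrt(np.mean(y[s:s + win] ** 2)) for s in starts])
    return (starts + win / 2) / sr, 20 * np.log10(rms + eps)   # frame centres (s), dB

def centroid_hz(y, sr, nperseg=2048, hop=512, eps=1e-12):
    """Spectral centroid: magnitude-weighted mean frequency of each Hann-windowed frame."""
    window = np.hanning(nperseg)
    starts = np.arange(0, len(y) - nperseg + 1, hop)
    frames = np.stack([window * y[s:s + nperseg] for s in starts])
    mags = np.abs(np.fft.rfft(frames, axis=1))              # magnitude spectrum per frame
    freqs = np.fft.rfftfreq(nperseg, d=1.0 / sr)            # bin frequencies (Hz)
    cent = (mags * freqs).sum(axis=1) / (mags.sum(axis=1) + eps)
    return (starts + nperseg / 2) / sr, cent

sr = 22050                                   # same rate as AUDIO_SR in Cell 4
t = np.arange(2 * sr) / sr                   # 2 s of samples
y = 0.5 * np.sin(2 * np.pi * 440 * t)        # 440 Hz tone at half full scale
tL, loud = rms_db(y, sr)
tC, cent = centroid_hz(y, sr)
# RMS of a 0.5-amplitude sine is 0.5/sqrt(2), i.e. about -9.03 dB; the centroid sits near 440 Hz
print(f"loudness ~ {np.median(loud):.2f} dB, centroid ~ {np.median(cent):.0f} Hz")
```

These dB values are relative to digital full scale, not calibrated sound pressure. They are useful for comparing parts of one recording, but not for comparing sessions recorded at different gain settings.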


In [None]:
# @title
# ==========================================================
# Cell 4: Reel renderer (portrait 9:16, 1080x1920) — SAVE TO DRIVE
# Top: video (cropped to fill)
# Bottom: 5 graphs (HR, RMSSD, Move, Loudness, Freq centroid) + progress bar + cursor
#
# Requires:
#   - Cell 1: EQUIPHY["video_path"], EQUIPHY["video_to_data_offset_s"], EQUIPHY["ready"]=True
#   - Cell 2: EQUIPHY["series"] populated (recommended)
#
# Output:
#   - Saved to Google Drive (folder configurable below)
# ==========================================================

import os, io, subprocess
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from scipy.signal import stft

# ---------- moviepy import (install if needed) ----------
try:
    from moviepy.editor import (
        VideoFileClip, VideoClip, CompositeVideoClip, AudioFileClip
    )
    import moviepy.video.fx.all as vfx
except Exception as e:
    print("Installing moviepy...")
    !pip -q install "moviepy<2"  # this cell uses the moviepy 1.x API (moviepy.editor, .resize, .subclip)
    from moviepy.editor import (
        VideoFileClip, VideoClip, CompositeVideoClip, AudioFileClip
    )
    import moviepy.video.fx.all as vfx

# ------------------ knobs ------------------
OUT_W, OUT_H = 1080, 1920
VIDEO_H = 1080
PANEL_H = OUT_H - VIDEO_H

FPS = 15  # 20 is smoother but slower; 12–15 is a good compromise

SCROLL_WINDOW_S = None   # None = full window; e.g. 30 = scrolling last 30s

# audio features
AUDIO_SR = 22050
LOUD_WIN_SEC, LOUD_HOP_SEC = 0.5, 0.1
CENTROID_NPERSEG = 2048
CENTROID_HOP = 512

# If you need a tiny sync nudge:
FINE_TUNE_OFFSET_S = 0.0

# Render window defaults (DATA time)
DEFAULT_DUR_S = 20.0

# Drive save
SAVE_TO_DRIVE = True
DRIVE_OUT_DIR = "/content/drive/MyDrive/EquiPhysics/renders"

# ------------------ checks ------------------
if not globals().get("EQUIPHY", {}).get("ready", False):
    raise RuntimeError("Run Cell 1 first (EQUIPHY must be ready).")
if "video_path" not in EQUIPHY:
    raise RuntimeError("EQUIPHY['video_path'] missing.")

video_path = EQUIPHY["video_path"]
vid_to_data = float(EQUIPHY.get("video_to_data_offset_s", 0.0)) + float(FINE_TUNE_OFFSET_S)

if "series" not in EQUIPHY:
    raise RuntimeError("EQUIPHY['series'] missing. Run Cell 2 first (recommended).")

S = EQUIPHY["series"]

# ------------------ mount drive + output dir ------------------
if SAVE_TO_DRIVE:
    from google.colab import drive
    if not Path("/content/drive").exists():
        drive.mount("/content/drive", force_remount=False)
    export_dir = EQUIPHY.get("export_dir", None)
    if export_dir is not None and str(export_dir).startswith("/content/drive"):
        out_dir = Path(export_dir) / "renders"   # same renders folder as Cell 3
    else:
        out_dir = Path(DRIVE_OUT_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)
else:
    out_dir = Path(".")

# ------------------ helpers ------------------
def clip_by_t(t, y, t0, t1):
    t = np.asarray(t, float); y = np.asarray(y, float)
    m = np.isfinite(t) & (t >= t0) & (t <= t1)
    return t[m], y[m]

def fit_to_box(clip_in, box_w, box_h):
    """Resize to cover the box then center-crop (fill)."""
    c = clip_in.resize(height=box_h)
    if c.w < box_w:
        c = clip_in.resize(width=box_w)
    c = c.fx(vfx.crop, width=box_w, height=box_h, x_center=c.w/2, y_center=c.h/2)
    return c

def loudness_db(y, sr, win_sec=0.5, hop_sec=0.1, eps=1e-8):
    y = np.asarray(y, float)
    if y.ndim == 2:
        y = y.mean(axis=1)
    win = max(1, int(sr*win_sec))
    hop = max(1, int(sr*hop_sec))
    n = len(y)
    starts = np.arange(0, max(0, n - win + 1), hop, dtype=int)
    t = (starts + win/2) / sr
    db = np.empty(len(starts), dtype=float)
    for i, s in enumerate(starts):
        frame = y[s:s+win]
        rms = np.sqrt(np.mean(frame*frame) + eps)
        db[i] = 20*np.log10(rms + eps)
    return t, db

def spectral_centroid(y, sr, nperseg=2048, hop=512, eps=1e-12):
    y = np.asarray(y, float)
    if y.ndim == 2:
        y = y.mean(axis=1)
    noverlap = max(0, nperseg - hop)
    f, t, Z = stft(y, fs=sr, nperseg=nperseg, noverlap=noverlap, boundary=None)
    mag = np.abs(Z).astype(np.float32)
    denom = mag.sum(axis=0) + eps
    cent = (f[:, None] * mag).sum(axis=0) / denom
    return t, cent

# ------------------ choose render window (DATA time) ------------------
t0_default = float(EQUIPHY.get("t_start_s", S.get("t0", 0.0)))
t1_default = float(EQUIPHY.get("t_stop_s",  min(t0_default + DEFAULT_DUR_S, S.get("t1", t0_default + DEFAULT_DUR_S))))

print(f"Suggested render window (DATA time): start={t0_default:.2f}, stop={t1_default:.2f}")
try:
    render_t0 = float(input(f"Render start (data s) [default {t0_default:.2f}]: ") or f"{t0_default:.2f}")
    render_t1 = float(input(f"Render stop  (data s) [default {t1_default:.2f}]: ") or f"{t1_default:.2f}")
except Exception:
    # input() unavailable, or the text was not a number: fall back to the defaults
    print("Using default render window.")
    render_t0, render_t1 = t0_default, t1_default

if render_t1 <= render_t0:
    raise ValueError("Render stop must be > render start.")

# ------------------ map DATA window to VIDEO subclip ------------------
# data = video + offset  => video = data - offset
base = VideoFileClip(video_path, audio=False)

v0 = max(0.0, render_t0 - vid_to_data)
v1 = min(float(base.duration), render_t1 - vid_to_data)

if v1 <= v0:
    base.close()
    raise RuntimeError("Invalid video window after mapping. Check video_to_data_offset_s and render window.")

clip = base.subclip(v0, v1)     # keep base open until end
duration = float(clip.duration)

out_data_start = v0 + vid_to_data
out_data_end = out_data_start + duration

print("\nRender mapping:")
print(f"  DATA window requested: [{render_t0:.2f}, {render_t1:.2f}]")
print(f"  VIDEO subclip:         [{v0:.2f}, {v0+duration:.2f}] (dur={duration:.2f}s)")
print(f"  offset (data=video+off): {vid_to_data:+.3f}")
print(f"  OUTPUT DATA window:    [{out_data_start:.2f}, {out_data_end:.2f}]")

# ------------------ pull series from EQUIPHY['series'] and clip to output ------------------
t_beats = np.asarray(S.get("t_beats", []), float)
hr_inst = np.asarray(S.get("hr_inst", []), float)
rmssd   = np.asarray(S.get("rmssd_roll", []), float)

t_mov   = np.asarray(S.get("t_mov", []), float)
mov_idx = np.asarray(S.get("mov_idx", []), float)

t_beats_c, hr_c    = clip_by_t(t_beats, hr_inst, out_data_start, out_data_end)
_,         rmssd_c = clip_by_t(t_beats, rmssd,   out_data_start, out_data_end)
t_mov_c,   mov_c   = clip_by_t(t_mov,   mov_idx, out_data_start, out_data_end)

if len(t_beats_c) < 2 or len(t_mov_c) < 2:
    clip.close(); base.close()
    raise RuntimeError("Not enough HR/movement points in this window. Pick a wider window or check Cell 2 export.")

# ------------------ extract subclip audio via ffmpeg (robust) ------------------
wav_path = "_reel_audio.wav"
audio_clip = None
t_loud_data = np.array([out_data_start, out_data_end], float)
loud_db = np.array([np.nan, np.nan], float)
t_cent_data = np.array([out_data_start, out_data_end], float)
cent_hz = np.array([np.nan, np.nan], float)

ff_cmd = [
    "ffmpeg", "-y",
    "-i", video_path,
    "-ss", f"{v0:.6f}",
    "-t",  f"{duration:.6f}",
    "-vn",
    "-ac", "1",
    "-ar", str(AUDIO_SR),
    "-c:a", "pcm_s16le",
    wav_path
]
try:
    subprocess.run(ff_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
        audio_clip = AudioFileClip(wav_path).set_duration(duration)
except Exception as e:
    print("WARNING: could not extract/attach audio:", e)
    audio_clip = None

# Compute audio features separately, so a feature error does not also drop the soundtrack
if audio_clip is not None:
    try:
        import soundfile as sf
        y_aud, sr_aud = sf.read(wav_path, always_2d=False)
        if y_aud.ndim == 2:
            y_aud = y_aud.mean(axis=1)

        # features in subclip time (0..duration), then map -> DATA time
        tL, dB = loudness_db(y_aud.astype(np.float32), sr_aud, win_sec=LOUD_WIN_SEC, hop_sec=LOUD_HOP_SEC)
        t_loud_data, loud_db = clip_by_t((v0 + tL) + vid_to_data, dB, out_data_start, out_data_end)

        tC, cHz = spectral_centroid(y_aud.astype(np.float32), sr_aud, nperseg=CENTROID_NPERSEG, hop=CENTROID_HOP)
        t_cent_data, cent_hz = clip_by_t((v0 + tC) + vid_to_data, cHz, out_data_start, out_data_end)
    except Exception as e:
        print("WARNING: could not compute audio features (Loud/Freq graphs left blank):", e)

# ------------------ prep top video (crop to fill) ------------------
clip_top = fit_to_box(clip, OUT_W, VIDEO_H)

# ------------------ dashboard renderer ------------------
class DashboardRenderer:
    def __init__(self):
        self.fig = plt.figure(figsize=(OUT_W/100, PANEL_H/100), dpi=100)
        self.canvas = FigureCanvas(self.fig)

        # 6 rows: HR, RMSSD, Move, Loud, Centroid, Progress
        gs = self.fig.add_gridspec(6, 1, height_ratios=[1,1,1,1,1,0.32], hspace=0.28)

        self.ax_hr   = self.fig.add_subplot(gs[0,0])
        self.ax_rms  = self.fig.add_subplot(gs[1,0], sharex=self.ax_hr)
        self.ax_mov  = self.fig.add_subplot(gs[2,0], sharex=self.ax_hr)
        self.ax_loud = self.fig.add_subplot(gs[3,0], sharex=self.ax_hr)
        self.ax_cent = self.fig.add_subplot(gs[4,0], sharex=self.ax_hr)
        self.ax_bar  = self.fig.add_subplot(gs[5,0], sharex=self.ax_hr)

        # static plots
        self.ax_hr.plot(t_beats_c, hr_c, linewidth=2)
        self.ax_hr.set_ylabel("HR")
        self.ax_hr.grid(True, alpha=0.25)

        self.ax_rms.plot(t_beats_c, rmssd_c, linewidth=2)
        self.ax_rms.set_ylabel("RMSSD")
        self.ax_rms.grid(True, alpha=0.25)

        self.ax_mov.plot(t_mov_c, mov_c, linewidth=2)
        self.ax_mov.set_ylabel("Move")
        self.ax_mov.grid(True, alpha=0.25)

        self.ax_loud.plot(t_loud_data, loud_db, linewidth=2)
        self.ax_loud.set_ylabel("Loud dB")
        self.ax_loud.grid(True, alpha=0.25)

        self.ax_cent.plot(t_cent_data, cent_hz, linewidth=2)
        self.ax_cent.set_ylabel("Freq Hz")
        self.ax_cent.grid(True, alpha=0.25)
        self.ax_cent.set_xlabel("DATA time (s)")

        # cursor lines
        self.vlines = [
            self.ax_hr.axvline(out_data_start, linewidth=3),
            self.ax_rms.axvline(out_data_start, linewidth=3),
            self.ax_mov.axvline(out_data_start, linewidth=3),
            self.ax_loud.axvline(out_data_start, linewidth=3),
            self.ax_cent.axvline(out_data_start, linewidth=3),
        ]

        # progress bar
        self.ax_bar.set_ylim(0, 1)
        self.ax_bar.set_yticks([])
        self.ax_bar.grid(False)
        self.bar_line = self.ax_bar.axvline(out_data_start, linewidth=3)
        self.bar_fill = self.ax_bar.axvspan(out_data_start, out_data_start, ymin=0, ymax=1, alpha=0.35)

        # initial xlim
        for ax in [self.ax_hr, self.ax_rms, self.ax_mov, self.ax_loud, self.ax_cent, self.ax_bar]:
            ax.set_xlim(out_data_start, out_data_end)

        self.fig.tight_layout()

    def __call__(self, t_sub):
        data_t = out_data_start + float(t_sub)

        if SCROLL_WINDOW_S is None:
            xmin, xmax = out_data_start, out_data_end
        else:
            xmin = max(out_data_start, data_t - SCROLL_WINDOW_S)
            xmax = min(out_data_end, xmin + SCROLL_WINDOW_S)
            if xmax <= xmin:
                xmax = xmin + 1.0

        for ax in [self.ax_hr, self.ax_rms, self.ax_mov, self.ax_loud, self.ax_cent, self.ax_bar]:
            ax.set_xlim(xmin, xmax)

        for vl in self.vlines:
            vl.set_xdata([data_t, data_t])

        self.bar_line.set_xdata([data_t, data_t])
        self.bar_fill.remove()
        self.bar_fill = self.ax_bar.axvspan(out_data_start, min(data_t, out_data_end), ymin=0, ymax=1, alpha=0.35)

        self.canvas.draw()
        rgba = np.asarray(self.canvas.buffer_rgba())
        return rgba[..., :3].copy()

panel_clip = VideoClip(DashboardRenderer(), duration=duration).set_fps(FPS)
panel_clip = panel_clip.resize((OUT_W, PANEL_H))

# ------------------ compose portrait reel ------------------
final = CompositeVideoClip(
    [
        clip_top.set_position((0, 0)),
        panel_clip.set_position((0, VIDEO_H)),
    ],
    size=(OUT_W, OUT_H)
).set_duration(duration)

if audio_clip is not None:
    final = final.set_audio(audio_clip)

# ------------------ render + save ------------------
stem = Path(video_path).stem[:30].replace(" ", "_")
out_name = f"{stem}__reel__data_{out_data_start:.1f}_{out_data_end:.1f}.mp4"
out_path = out_dir / out_name

print("\nWriting:", out_path)
final.write_videofile(
    str(out_path),
    fps=FPS,
    codec="libx264",
    audio_codec="aac",
    threads=4,
    preset="medium",
    verbose=False,
    logger=None
)

# ------------------ cleanup ------------------
for _res in (audio_clip, final, panel_clip, clip, base):
    try:
        if _res is not None:
            _res.close()
    except Exception:
        pass
try:
    if os.path.exists(wav_path):
        os.remove(wav_path)
except Exception:
    pass

print("\n✅ Saved reel to:", out_path)



## Cell 5: Correlation & Causality Analysis

**Purpose:** This cell performs advanced statistical analysis to explore relationships between the horse's physiological state (HR, HRV), movement, and the music features.

**What it does:**
*   **Aligns Data:** Resamples all time series onto a common time grid (defined by `DT`).
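*   **Correlation Analysis:** Calculates Pearson and Spearman correlations to find linear and monotonic relationships. It also checks for time-lagged correlations (e.g., does HR rise 5 seconds *after* the music gets loud?). For HR and HRV, it reports a partial correlation that controls for movement, so that a music effect is not confused with the horse simply moving more.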
*   **Granger Causality:** Tests for predictive relationships. If "Music Granger-causes HR", it means past values of music features help predict future HR values better than HR history alone.
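The sign convention for lagged correlation is easy to get backwards. The sketch below is a simplified, NaN-free version of this cell's `_best_lag_corr` (renamed `best_lag` here). It runs on synthetic data in which a "loudness" series drives "HR" 5 samples later. With a 1 s grid, the best lag should come out as +5 s, meaning x leads y.

```python
import numpy as np

def best_lag(x, y, dt, max_lag_s):
    """Return (lag_s, r) maximizing |corr(x(t), y(t + lag))|; positive lag = x leads y."""
    max_k = int(round(max_lag_s / dt))
    best_lag_s, best_r = np.nan, np.nan
    for k in range(-max_k, max_k + 1):
        if k >= 0:
            xx, yy = x[:len(x) - k], y[k:]        # pair x[i] with y[i + k]
        else:
            xx, yy = x[-k:], y[:len(y) + k]       # pair x[i + |k|] with y[i]
        r = np.corrcoef(xx, yy)[0, 1]
        if np.isnan(best_r) or abs(r) > abs(best_r):
            best_lag_s, best_r = k * dt, float(r)
    return best_lag_s, best_r

rng = np.random.default_rng(0)
n = 300
loud = rng.normal(size=n)              # synthetic "loudness" on a 1 s grid
hr = np.zeros(n)
hr[5:] = loud[:-5]                     # "HR" copies loudness 5 samples later
hr += 0.3 * rng.normal(size=n)         # plus measurement noise

lag_s, r = best_lag(loud, hr, dt=1.0, max_lag_s=30)
print(f"best lag = {lag_s:+.0f} s, r = {r:.2f}")   # expect a lag of +5 s
```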

**How to use it:**
1.  **Run the cell:** Just execute the code. It uses the data processed in Cell 2.
2.  **Review the Output:** Look at the printed tables and plots.
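    *   **Correlation Table:** Look for high absolute `pearson_r` values (close to 1 or -1) with low `pearson_p` values (< 0.05). `best_lag_s (x leads y)` gives the delay in seconds at which |r| is largest, where a positive value means x leads y. `best_lag_r` is the correlation at that delay.
    *   **Granger Table:** Look for low `best_p` values (< 0.05), which suggest a potential predictive link. `best_p` is the smallest p-value across the tested lags.
    *   Treat p < 0.05 as a lead to follow up, not as confirmation. These series are strongly autocorrelated, which makes p-values optimistic. In addition, both `best_lag_r` and `best_p` are picked as the best of many tested lags.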
3.  **Adjust Parameters (Optional):**
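    *   `DT`: Resampling interval in seconds (default `1.0`). Smaller values (e.g., `0.5`) give finer time resolution, but the series may be noisier and there are more lags to test.
    *   `MAX_LAG_S`: Maximum lag searched in the lagged correlations (default `30` seconds).
    *   `MAX_GRANGER_LAG_S`: Maximum lag for the Granger tests (default `15` seconds).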
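`DT` sets both the common grid and the number of lag steps searched. The sketch below mirrors this cell's `_interp_to_grid` helper, renamed `to_grid` here. Beat-by-beat HR at irregular times is linearly interpolated onto a regular grid. Grid points outside the recorded span become NaN rather than being extrapolated. The beat times and HR values are made up for illustration.

```python
import numpy as np

def to_grid(t, y, t_grid):
    """Linear interpolation onto t_grid; NaN outside the observed time span."""
    t, y = np.asarray(t, float), np.asarray(y, float)
    ok = np.isfinite(t) & np.isfinite(y)
    order = np.argsort(t[ok])                   # np.interp needs increasing times
    return np.interp(t_grid, t[ok][order], y[ok][order], left=np.nan, right=np.nan)

DT, MAX_LAG_S = 0.5, 30
t_grid = np.arange(0.0, 5.0 + 1e-12, DT)       # 0.0, 0.5, ..., 5.0 s
t_beats = np.array([0.8, 1.9, 3.1, 4.2])       # made-up, irregular beat times (s)
hr = np.array([36.0, 38.0, 40.0, 38.0])        # made-up HR at each beat (bpm)

hr_grid = to_grid(t_beats, hr, t_grid)
print(np.round(hr_grid, 2))                    # NaN before 0.8 s and after 4.2 s
print("lag steps per direction:", int(round(MAX_LAG_S / DT)))   # 60
```

When `DT` is finer than the spacing of the original samples, the extra grid points are interpolated values rather than new measurements. At 30 to 40 bpm, a resting horse's beats arrive roughly every 1.5 to 2 s.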

**Note:** "Granger causality" is a statistical concept of predictability, not necessarily proof of real-world cause-and-effect.
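A Granger test fits two regressions for y. The first uses only y's own past values. The second also includes past values of x. An F-test then asks whether the extra terms reduce the residual error by more than chance would. The sketch below writes this out with plain least squares on synthetic data. It is intended to mirror the `ssr_ftest` statistic that `statsmodels.tsa.stattools.grangercausalitytests` reports. It skips the differencing the cell applies when `USE_DIFF_FOR_GRANGER = True`, and the name `granger_f` is hypothetical.

```python
import numpy as np
from scipy import stats

def _rss(X, Y):
    """Residual sum of squares of an ordinary least-squares fit of Y on X."""
    beta, *_ = np.linalg.lstsq(X, Y, rcond=None)
    resid = Y - X @ beta
    return float(resid @ resid)

def granger_f(x, y, p):
    """F-test of 'x Granger-causes y' using lags 1..p (hypothetical helper).

    Restricted:   y_t ~ 1 + y_{t-1} ... y_{t-p}
    Unrestricted: y_t ~ 1 + y_{t-1} ... y_{t-p} + x_{t-1} ... x_{t-p}
    """
    x, y = np.asarray(x, float), np.asarray(y, float)
    n = len(y)
    Y = y[p:]                                                  # targets y_p .. y_{n-1}
    lag_y = np.column_stack([y[p - k:n - k] for k in range(1, p + 1)])
    lag_x = np.column_stack([x[p - k:n - k] for k in range(1, p + 1)])
    const = np.ones((n - p, 1))
    X_r = np.hstack([const, lag_y])
    X_u = np.hstack([const, lag_y, lag_x])
    rss_r, rss_u = _rss(X_r, Y), _rss(X_u, Y)
    df_num, df_den = p, (n - p) - X_u.shape[1]
    F = ((rss_r - rss_u) / df_num) / (rss_u / df_den)
    return F, float(stats.f.sf(F, df_num, df_den))           # large F -> small p

# Synthetic example: y depends on x two steps earlier
rng = np.random.default_rng(1)
n = 400
x = rng.normal(size=n)
y = np.zeros(n)
for t in range(2, n):
    y[t] = 0.3 * y[t - 1] + 0.8 * x[t - 2] + 0.5 * rng.normal()

F, p_val = granger_f(x, y, p=3)
print(f"x -> y: F = {F:.1f}, p = {p_val:.2g}")
```

Differencing first, as the cell does by default, removes slow shared trends. Without it, two series that merely drift together over a session can appear to "predict" each other.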

In [None]:
# @title
# ==========================================================
# Cell 5: Correlations + lag analysis + Granger Causality
#   music metrics -> (HR, HRV, movement)
#   movement -> (HR, HRV)
#
# Assumes Cell 2 ran and created EQUIPHY["series"].
# Notes:
# - Correlation is not causation.
# - Granger causality tests whether X helps predict future Y (linear, predictive).
# ==========================================================

import numpy as np
import pandas as pd
from pathlib import Path
from scipy import stats
import matplotlib.pyplot as plt
from IPython.display import display

# statsmodels is used for Granger tests
from statsmodels.tsa.stattools import grangercausalitytests

S = EQUIPHY.get("series", None)
if S is None:
    raise RuntimeError("EQUIPHY['series'] not found. Run Cell 2 first.")

# ----------------- user-tunable parameters -----------------
DT = 1.0               # seconds, common resample grid (try 0.5 or 0.2 for higher-res)
MAX_LAG_S = 30         # seconds, lag search window for the lagged-correlation analysis
MAX_GRANGER_LAG_S = 15 # seconds, max Granger lag (kept smaller than MAX_LAG_S for stability)
USE_DIFF_FOR_GRANGER = True  # difference series before Granger to reduce trend-driven false positives
MIN_OBS_FOR_CORR = 20  # Minimum number of overlapping observations for a valid correlation
MIN_OBS_FOR_GRANGER = 50 # Minimum number of observations for Granger test (higher for stability)
# -----------------------------------------------------------

def _interp_to_grid(t, y, t_grid):
    """Linear interpolation onto t_grid. Returns NaNs where interpolation is impossible."""
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    m = np.isfinite(t) & np.isfinite(y)
    if m.sum() < 2:
        return np.full_like(t_grid, np.nan, dtype=float)
    tt = t[m]
    yy = y[m]

    # Sort based on time
    idx = np.argsort(tt)
    tt = tt[idx]
    yy = yy[idx]

    # np.interp requires increasing tt
    return np.interp(t_grid, tt, yy, left=np.nan, right=np.nan)

def _zscore(x):
    x = np.asarray(x, dtype=float)
    m = np.isfinite(x)
    if m.sum() < 3:
        return x
    mu = np.nanmean(x[m])
    sd = np.nanstd(x[m])
    if sd <= 0:
        return x - mu
    return (x - mu) / sd

def _pairwise_corr(x, y):
    """Pearson + Spearman, with p-values, ignoring NaNs."""
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    m = np.isfinite(x) & np.isfinite(y)
    if m.sum() < MIN_OBS_FOR_CORR:
        return dict(n=int(m.sum()), pearson_r=np.nan, pearson_p=np.nan, spearman_rho=np.nan, spearman_p=np.nan)
    pr, pp = stats.pearsonr(x[m], y[m])
    sr, sp = stats.spearmanr(x[m], y[m])
    return dict(n=int(m.sum()), pearson_r=float(pr), pearson_p=float(pp), spearman_rho=float(sr), spearman_p=float(sp))

def _partial_corr(x, y, covars):
    """Partial Pearson correlation between x and y controlling for covars (linear residualization)."""
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    C = np.asarray(covars, float)
    if C.ndim == 1:
        C = C[:, None]
    m = np.isfinite(x) & np.isfinite(y) & np.all(np.isfinite(C), axis=1)
    if m.sum() < (MIN_OBS_FOR_CORR + 10): # Need more observations for partial corr
        return dict(n=int(m.sum()), r=np.nan, p=np.nan)
    X = np.column_stack([np.ones(m.sum()), C[m]])
    # residualize x and y
    bx, *_ = np.linalg.lstsq(X, x[m], rcond=None)
    by, *_ = np.linalg.lstsq(X, y[m], rcond=None)
    rx = x[m] - X @ bx
    ry = y[m] - X @ by
    r, p = stats.pearsonr(rx, ry)
    return dict(n=int(m.sum()), r=float(r), p=float(p))

def _best_lag_corr(x, y, dt, max_lag_s):
    """
    Find the lag (in seconds) that maximizes |corr(x(t), y(t+lag))|,
    searching lag in [-max_lag_s, +max_lag_s].
    Positive lag means x leads y; negative lag means y leads x.
    """
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    max_k = int(round(max_lag_s / dt))
    lags = np.arange(-max_k, max_k + 1, dtype=int)
    best = dict(lag_s=np.nan, r=np.nan, n=0)

    for k in lags:
        if k == 0:
            xx, yy = x, y
        elif k > 0:
            xx, yy = x[:-k], y[k:]
        else:
            kk = -k
            xx, yy = x[kk:], y[:-kk]
        m = np.isfinite(xx) & np.isfinite(yy)
        if m.sum() < MIN_OBS_FOR_CORR:
            continue
        r = np.corrcoef(xx[m], yy[m])[0, 1]
        if not np.isfinite(r):
            continue
        if (not np.isfinite(best["r"])) or (abs(r) > abs(best["r"])): # Maximize absolute correlation
            best = dict(lag_s=float(k * dt), r=float(r), n=int(m.sum()))
    return best

def _granger_best_p(x, y, maxlag_steps):
    """
    Granger test x -> y using statsmodels.
    Returns best p-value (ssr_ftest) across lags and the lag where it occurs.
    """
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    m = np.isfinite(x) & np.isfinite(y)
    if m.sum() < MIN_OBS_FOR_GRANGER:
        return dict(n=int(m.sum()), best_p=np.nan, best_lag=np.nan)
    xx = x[m]
    yy = y[m]
    if USE_DIFF_FOR_GRANGER:
        xx = np.diff(xx)
        yy = np.diff(yy)
        if len(xx) < (maxlag_steps + 20):
            return dict(n=len(xx), best_p=np.nan, best_lag=np.nan)
    data = np.column_stack([yy, xx])  # [y, x]
    try:
        res = grangercausalitytests(data, maxlag=maxlag_steps, verbose=False)
        pvals = []
        for lag, out in res.items():
            p = out[0]["ssr_ftest"][1]
            pvals.append((lag, p))
        if not pvals:
             return dict(n=int(len(xx)), best_p=np.nan, best_lag=np.nan)
        lag_best, p_best = min(pvals, key=lambda t: t[1])
        return dict(n=int(len(xx)), best_p=float(p_best), best_lag=float(lag_best))
    except Exception:
        return dict(n=int(len(xx)), best_p=np.nan, best_lag=np.nan)

# ----------------- Build a common, aligned dataframe -----------------
t0 = float(S["t0"]); t1 = float(S["t1"])
t_grid = np.arange(t0, t1 + 1e-12, DT)

# Core physiology/movement
HR = _interp_to_grid(S["t_beats"], S["hr_inst"], t_grid)
RMSSD = _interp_to_grid(S["t_beats"], S["rmssd_roll"], t_grid)
MOV = _interp_to_grid(S["t_mov"], S["mov_idx"], t_grid)

# Optional: SDNN rolling (computed on beats, then interpolated)
SDNN = np.full_like(t_grid, np.nan, dtype=float)
try:
    rr = np.asarray(S["rr_ms"], float)
    tb = np.asarray(S["t_beats"], float)
    m_rr = np.isfinite(rr) & np.isfinite(tb)
    rr = rr[m_rr]; tb = tb[m_rr]
    if len(rr) > (RMSSD_WINDOW_BEATS + 10): # Ensure enough beats for SDNN calculation
        win = RMSSD_WINDOW_BEATS  # beats, match RMSSD_WINDOW_BEATS from Cell 2
        sdnn_roll = pd.Series(rr).rolling(win, min_periods=max(10, win//2)).std(ddof=1).to_numpy()
        SDNN = _interp_to_grid(tb, sdnn_roll, t_grid)
except Exception:
    pass

df = pd.DataFrame({
    "t_s": t_grid,
    "hr_bpm": HR,
    "rmssd_ms": RMSSD,
    "sdnn_ms": SDNN,
    "mov_idx": MOV,
})

# Audio/music metrics, if present
music_cols = []
if bool(S.get("have_audio", False)) and (S.get("t_loud", None) is not None):
    df["loud_db"] = _interp_to_grid(S["t_loud"], S["loud_db"], t_grid); music_cols.append("loud_db")
if bool(S.get("have_audio", False)) and (S.get("t_flux", None) is not None):
    df["flux_z"] = _interp_to_grid(S["t_flux"], S["flux_z"], t_grid); music_cols.append("flux_z")
if bool(S.get("have_audio", False)) and (S.get("t_cent", None) is not None):
    df["centroid_hz"] = _interp_to_grid(S["t_cent"], S["centroid_sm"], t_grid); music_cols.append("centroid_hz")

# Report
print(f"Aligned dataframe: {len(df)} samples @ DT={DT:.3f}s over [{t0:.1f}, {t1:.1f}] s")
print("Columns:", list(df.columns))

# ----------------- Correlations -----------------
print("\n--- Zero-lag & Lagged Correlations ---")
targets_phys = ["hr_bpm", "rmssd_ms", "sdnn_ms", "mov_idx"]
targets_hrv = ["rmssd_ms", "sdnn_ms"]

rows = []

# music -> (HR, HRV, movement)
for x in music_cols:
    for y in targets_phys:
        out = _pairwise_corr(df[x], df[y])
        lagbest = _best_lag_corr(df[x], df[y], DT, MAX_LAG_S)
        pc = None
        # partial corr controlling for movement (only makes sense for y=HR/HRV)
        if y in (["hr_bpm"] + targets_hrv):
            pc = _partial_corr(df[x], df[y], covars=df["mov_idx"])
        rows.append({
            "x": x, "y": y,
            **out,
            "best_lag_s (x leads y)": lagbest["lag_s"],
            "best_lag_r": lagbest["r"],
            "best_lag_n": lagbest["n"],
            "partial_r | mov": (pc["r"] if pc else np.nan),
            "partial_p | mov": (pc["p"] if pc else np.nan),
            "partial_n | mov": (pc["n"] if pc else np.nan),
        })

# movement -> (HR, HRV)
for y in ["hr_bpm"] + targets_hrv:
    out = _pairwise_corr(df["mov_idx"], df[y])
    lagbest = _best_lag_corr(df["mov_idx"], df[y], DT, MAX_LAG_S)
    rows.append({
        "x": "mov_idx", "y": y,
        **out,
        "best_lag_s (x leads y)": lagbest["lag_s"],
        "best_lag_r": lagbest["r"],
        "best_lag_n": lagbest["n"],
        "partial_r | mov": np.nan,  # Not applicable: x is the movement covariate itself
        "partial_p | mov": np.nan,
        "partial_n | mov": np.nan,
    })
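`_partial_corr` is defined earlier and not shown here. One standard way to compute a partial correlation that controls for movement is to regress both x and y on the covariate and then correlate the residuals. The sketch below does that for a single covariate with NumPy and SciPy; the name `partial_corr_sketch` is hypothetical, and the notebook's helper may compute its p-value differently.

```python
import numpy as np
from scipy import stats

def partial_corr_sketch(x, y, z):
    """Pearson partial correlation of x and y controlling for one covariate z.

    Rows with any non-finite value are dropped. Returns dict(r, p, n).
    """
    x, y, z = (np.asarray(a, dtype=float) for a in (x, y, z))
    m = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
    x, y, z = x[m], y[m], z[m]
    n = len(x)
    k = 1  # number of covariates
    if n < k + 4:
        return dict(r=np.nan, p=np.nan, n=n)
    Z = np.column_stack([np.ones(n), z])  # intercept + covariate
    # Residuals of x and y after removing the linear effect of z
    rx = x - Z @ np.linalg.lstsq(Z, x, rcond=None)[0]
    ry = y - Z @ np.linalg.lstsq(Z, y, rcond=None)[0]
    r = float(np.corrcoef(rx, ry)[0, 1])
    # t-test for a partial correlation uses n - 2 - k degrees of freedom
    dof = n - 2 - k
    t = r * np.sqrt(dof / max(1e-12, 1.0 - r ** 2))
    p = float(2 * stats.t.sf(abs(t), dof))
    return dict(r=r, p=p, n=n)

# Example: x and y are both driven by z (think "movement"), but not by each other
rng = np.random.default_rng(1)
z = rng.normal(size=2000)
x = z + 0.3 * rng.normal(size=2000)
y = z + 0.3 * rng.normal(size=2000)
raw_r = np.corrcoef(x, y)[0, 1]
res_pc = partial_corr_sketch(x, y, z)
print(f"raw r = {raw_r:.2f}, partial r | z = {res_pc['r']:.2f}, n = {res_pc['n']}")
```

The raw correlation is high only because both series follow z; once z is regressed out, the partial correlation drops to roughly zero. This is the pattern to look for when asking whether a music-HR link is really a movement effect.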

corr_tbl = pd.DataFrame(rows)
# Filter out rows with insufficient observations or NaN correlations
corr_tbl = corr_tbl[ (corr_tbl['n'] >= MIN_OBS_FOR_CORR) &
                     (corr_tbl['pearson_r'].notna()) &
                     (corr_tbl['best_lag_n'] >= MIN_OBS_FOR_CORR) &
                     (corr_tbl['best_lag_r'].notna())].copy()

# Sort by strongest absolute zero-lag Pearson correlation
corr_tbl["abs_pearson"] = np.abs(corr_tbl["pearson_r"])
corr_tbl = corr_tbl.sort_values(["abs_pearson"], ascending=False).drop(columns=["abs_pearson"])

# Format for display
corr_tbl_display = corr_tbl.round({
    "n": 0, "pearson_r": 3, "pearson_p": 4,
    "spearman_rho": 3, "spearman_p": 4,
    "best_lag_s (x leads y)": 1, "best_lag_r": 3, "best_lag_n": 0,
    "partial_r | mov": 3, "partial_p | mov": 4, "partial_n | mov": 0
})
# Use Int64 (nullable int) to handle NaNs in integer columns without error
corr_tbl_display["n"] = corr_tbl_display["n"].astype("Int64")
corr_tbl_display["best_lag_n"] = corr_tbl_display["best_lag_n"].astype("Int64")
corr_tbl_display["partial_n | mov"] = corr_tbl_display["partial_n | mov"].astype("Int64")

print("Interpretation of Correlation Table:")
print("  - 'x' vs 'y': The two time series being compared.")
print("  - 'n': Number of overlapping, non-NaN samples used for the zero-lag correlation.")
print("  - 'pearson_r': Linear correlation coefficient (range -1 to 1).")
print("  - 'pearson_p': P-value for the Pearson correlation (smaller = stronger evidence against r = 0).")
print("  - 'spearman_rho': Rank correlation coefficient; captures monotonic relationships, including non-linear ones.")
print("  - 'spearman_p': P-value for the Spearman correlation.")
print("  - 'best_lag_s (x leads y)': Lag (in seconds) at which |r| between 'x' and 'y' is largest. A positive lag means changes in 'x' precede changes in 'y'.")
print("  - 'best_lag_r' / 'best_lag_n': Correlation coefficient and number of samples at that best lag.")
print("  - 'partial_r | mov' (with 'partial_p' and 'partial_n'): Partial Pearson correlation controlling for 'mov_idx'. Shows whether the 'x'-'y' relationship persists after accounting for movement.")
print("Note: consecutive samples on the DT grid are autocorrelated, so these p-values are optimistic; treat them as a rough guide.")
display(corr_tbl_display)
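The "best lag" columns follow the convention that a positive lag means x leads y. The sketch below is a hypothetical stand-in for `_best_lag_corr` (defined earlier, not shown): it pairs `x[t]` with `y[t + k]` for each shift `k` and keeps the shift with the largest |r|, returning the same `lag_s`, `r`, `n` keys used in the table. It mirrors the pairing in `_plot_lag_curve` further down; `min_obs` stands in for a threshold like `MIN_OBS_FOR_CORR` and its default is an assumption.

```python
import numpy as np

def best_lag_corr_sketch(x, y, dt, max_lag_s, min_obs=30):
    """Return dict(lag_s, r, n) for the shift maximizing |Pearson r|.

    Positive lag k pairs x[t] with y[t + k], i.e. x leads y by k*dt seconds.
    Assumes x and y are sampled on the same uniform grid.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    max_k = min(int(round(max_lag_s / dt)), len(x) - 1)
    best = dict(lag_s=np.nan, r=np.nan, n=0)
    for k in range(-max_k, max_k + 1):
        if k >= 0:
            xx, yy = x[:len(x) - k], y[k:]
        else:
            xx, yy = x[-k:], y[:len(y) + k]
        m = np.isfinite(xx) & np.isfinite(yy)
        if m.sum() < min_obs:
            continue
        r = np.corrcoef(xx[m], yy[m])[0, 1]
        if np.isfinite(r) and (np.isnan(best["r"]) or abs(r) > abs(best["r"])):
            best = dict(lag_s=k * dt, r=float(r), n=int(m.sum()))
    return best

# Example: y is x delayed by 5 samples, so x leads y by 5 * 0.5 s = 2.5 s
rng = np.random.default_rng(2)
x = rng.normal(size=600)
y = np.roll(x, 5)  # y[t] = x[t - 5]
res_lag = best_lag_corr_sketch(x, y, dt=0.5, max_lag_s=5.0)
print(res_lag)
```

A positive `lag_s` here means the response (y) trails the driver (x), which is the reading you want for "HR reacts to loudness after a few seconds".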

# ----------------- Granger causality (predictive directionality) -----------------
print("\n--- Granger Causality (Predictive Directionality) ---")
# Granger causality checks if past values of X help predict Y better than just past values of Y.
# We prep by z-scoring (standardizing) the data first.
def _prep_for_granger(a, b):
    aa = _zscore(a)
    bb = _zscore(b)
    return aa, bb

maxlag_steps = max(1, int(round(MAX_GRANGER_LAG_S / DT))) # Convert seconds to steps
info_rows = []

def _do_granger_pair(xname, yname):
    x, y = _prep_for_granger(df[xname].to_numpy(), df[yname].to_numpy())
    g = _granger_best_p(x, y, maxlag_steps=maxlag_steps)
    return g

# music -> (HR, HRV, movement)
for x in music_cols:
    for y in targets_phys:
        g = _do_granger_pair(x, y)
        info_rows.append({"method":"Granger", "x":x, "y":y, "best_p":g["best_p"], "best_lag_steps":g["best_lag"], "n":g["n"]})

# movement -> (HR, HRV)
for y in ["hr_bpm"] + targets_hrv:
    g = _do_granger_pair("mov_idx", y)
    info_rows.append({"method":"Granger", "x":"mov_idx", "y":y, "best_p":g["best_p"], "best_lag_steps":g["best_lag"], "n":g["n"]})

g_tbl = pd.DataFrame(info_rows)
# Filter out rows with insufficient observations or NaN p-values
g_tbl = g_tbl[ (g_tbl['n'] >= MIN_OBS_FOR_GRANGER) & (g_tbl['best_p'].notna()) ].copy()

g_tbl_display = g_tbl.sort_values(["best_p"], ascending=True).round({"best_p": 4, "best_lag_steps": 0, "n": 0})
g_tbl_display["n"] = g_tbl_display["n"].astype("Int64")

print("Interpretation of Granger Causality Table:")
print("  - 'x' -> 'y': Tests whether past values of 'x' improve prediction of 'y' beyond the past of 'y' alone.")
print("  - 'n': Number of observations used for the test (after differencing, if enabled).")
print("  - 'best_p': Smallest p-value across all tested lags. A p-value below your significance level (e.g., 0.05) suggests 'x' Granger-causes 'y'. Because this is a minimum over several lags, it is optimistic; a Bonferroni correction multiplies it by the number of lags tested.")
print("  - 'best_lag_steps': The lag (in DT steps) corresponding to 'best_p'; multiply by DT to get seconds.")
print("Note: Granger causality indicates a predictive relationship, not necessarily direct causation.")
display(g_tbl_display)
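`_granger_best_p` is defined earlier and not shown; it reports the smallest p-value across the lags it tests. To make the test itself concrete, the sketch below implements the classic Granger F-test for a fixed lag order with ordinary least squares: it compares a restricted model (y on its own p lags) with an unrestricted model that also includes p lags of x. A helper then takes the minimum p over lags 1..maxlag and applies a simple Bonferroni adjustment. All names here are hypothetical and this is an illustration, not the notebook's helper; statsmodels' `grangercausalitytests` is a common library alternative.

```python
import numpy as np
from scipy import stats

def granger_f_test_sketch(x, y, p):
    """F-test of H0: lags 1..p of x add nothing to y's own lags 1..p.

    Returns (F, p_value, n_used). Assumes finite, equal-length, roughly
    stationary series.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(y) - p  # rows available after building p lags
    Y = y[p:]
    # Column j holds the series shifted back by j + 1 steps
    y_lags = np.column_stack([y[p - j - 1:len(y) - j - 1] for j in range(p)])
    x_lags = np.column_stack([x[p - j - 1:len(x) - j - 1] for j in range(p)])
    ones = np.ones((n, 1))
    X_r = np.hstack([ones, y_lags])          # restricted: y's own past only
    X_u = np.hstack([ones, y_lags, x_lags])  # unrestricted: add x's past
    rss_r = np.sum((Y - X_r @ np.linalg.lstsq(X_r, Y, rcond=None)[0]) ** 2)
    rss_u = np.sum((Y - X_u @ np.linalg.lstsq(X_u, Y, rcond=None)[0]) ** 2)
    df_num, df_den = p, n - X_u.shape[1]
    F = ((rss_r - rss_u) / df_num) / (rss_u / df_den)
    return float(F), float(stats.f.sf(F, df_num, df_den)), n

def best_granger_p_sketch(x, y, maxlag):
    """Minimum p over lags 1..maxlag, plus a Bonferroni-adjusted version."""
    pvals = [(lag, granger_f_test_sketch(x, y, lag)[1]) for lag in range(1, maxlag + 1)]
    lag_best, p_best = min(pvals, key=lambda t: t[1])
    p_adj = min(1.0, p_best * maxlag)  # Bonferroni over the lags tried
    return lag_best, p_best, p_adj

# Example: y depends on its own past and on x two steps earlier
rng = np.random.default_rng(0)
n_obs = 500
x = rng.normal(size=n_obs)
y = np.zeros(n_obs)
for t in range(2, n_obs):
    y[t] = 0.5 * y[t - 1] + 0.8 * x[t - 2] + 0.5 * rng.normal()
lag_g, p_min, p_adj = best_granger_p_sketch(x, y, maxlag=4)
print(f"x -> y: best lag {lag_g} steps, min p = {p_min:.2e}, Bonferroni p = {p_adj:.2e}")
```

Note that lag 1 alone cannot detect the effect, because x only enters y after two steps; this is why the notebook scans several lags rather than testing one.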

# ----------------- Quick plots -----------------
# 1) Zero-lag Pearson r heatmap (music + mov) vs outputs
heat_x = music_cols + ["mov_idx"]
heat_y = ["hr_bpm", "rmssd_ms", "sdnn_ms", "mov_idx"]
H = np.full((len(heat_x), len(heat_y)), np.nan, dtype=float)

for i, x_var in enumerate(heat_x):
    for j, y_var in enumerate(heat_y):
        m = np.isfinite(df[x_var]) & np.isfinite(df[y_var])
        if m.sum() >= MIN_OBS_FOR_CORR:
            H[i, j] = np.corrcoef(df.loc[m, x_var], df.loc[m, y_var])[0, 1]

fig, ax = plt.subplots(figsize=(1.2*len(heat_y)+2, 0.6*len(heat_x)+2))
im = ax.imshow(H, aspect="auto", vmin=-1, vmax=1, cmap="RdBu")
ax.set_xticks(range(len(heat_y))); ax.set_xticklabels(heat_y, rotation=45, ha="right")
ax.set_yticks(range(len(heat_x))); ax.set_yticklabels(heat_x)
ax.set_title("Zero-lag Pearson correlation (r)")
fig.colorbar(im, ax=ax, shrink=0.8)
plt.show()

# 2) Lag-correlation curves for representative pairs (edit the calls below to choose other pairs)
def _plot_lag_curve(xname, yname):
    x = df[xname].to_numpy(float)
    y = df[yname].to_numpy(float)
    max_k = int(round(MAX_LAG_S / DT))
    lags = np.arange(-max_k, max_k+1, dtype=int)
    rs = np.full_like(lags, np.nan, dtype=float)
    ns = np.zeros_like(lags, dtype=int)
    for ii, k in enumerate(lags):
        if k == 0:
            xx, yy = x, y
        elif k > 0:
            xx, yy = x[:-k], y[k:]
        else:
            kk = -k
            xx, yy = x[kk:], y[:-kk]
        m = np.isfinite(xx) & np.isfinite(yy)
        ns[ii] = int(m.sum())
        if m.sum() >= MIN_OBS_FOR_CORR:
            rs[ii] = np.corrcoef(xx[m], yy[m])[0, 1]
    fig, ax = plt.subplots(figsize=(7,3))
    ax.plot(lags*DT, rs)
    ax.axvline(0, linestyle="--", color='gray')
    ax.axhline(0, linestyle="--", color='gray')
    ax.set_xlabel("Lag (s)  [positive => x leads y]")
    ax.set_ylabel("Correlation (r)")
    ax.set_title(f"Lag correlation: {xname} → {yname}")
    ax.grid(True, alpha=0.3)
    plt.show()

# Plot a few representative pairs (customize as needed)
if len(music_cols) > 0:
    print(f"\nLagged correlation plot for {music_cols[0]} -> hr_bpm:")
    _plot_lag_curve(music_cols[0], "hr_bpm")
else:
    print("\nNo music columns available; skipping the music -> hr_bpm lag plot.")

print("\nLagged correlation plot for mov_idx -> hr_bpm:")
_plot_lag_curve("mov_idx", "hr_bpm")

# ----------------- Save outputs next to export_dir if available -----------------
out_dir = Path(EQUIPHY.get("export_dir") or ".")  # fall back to CWD if unset or None
try:
    analysis_dir = out_dir / "analysis"
    analysis_dir.mkdir(parents=True, exist_ok=True)
    corr_tbl_display.to_csv(analysis_dir / "corr_summary.csv", index=False)
    g_tbl_display.to_csv(analysis_dir / "granger_summary.csv", index=False)
    print(f"\n✅ Saved filtered and formatted analysis CSVs to: {analysis_dir}")
except Exception as e:
    print(f"\n⚠️ Could not save analysis outputs: {e}")


## Analysis Interpretation & Notes

Use this section to synthesize the statistical results from Cell 5. Do the numbers support your visual observations?

---

**Correlation Findings:** [e.g., Strong negative correlation between Music Flux and HRV, suggesting HRV tends to be lower during intense passages.]

**Time Lag Observations:** [e.g., HR seems to react to Loudness with a lag of ~5 seconds.]

**Causality (Granger) Results:** [e.g., Movement significantly predicts HRV changes (p<0.05). Did music features show any predictive power?]

**Overall Conclusion:** [e.g., While there are some correlations, movement appears to be the primary driver of physiological changes, with music playing a secondary role.]

---