# Imports & Directories

In [77]:
import os
import subprocess
import numpy as np
import torch
import cv2
from ultralytics import YOLO
from tqdm import tqdm
from skimage import measure
import shlex
from pathlib import Path

# CUDA setup
# Allocator options are passed to PyTorch through the PYTORCH_CUDA_ALLOC_CONF
# environment variable (format "<option>:<value>"). It is set before any CUDA
# call in this cell so the allocator can pick it up; setdefault keeps a value
# that was already exported in the shell.
# NOTE(review): if CUDA was already initialised earlier in this kernel session,
# the setting presumably only takes effect after a kernel restart — confirm.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

torch.backends.cudnn.benchmark = True  # Optimize for GPUs with a fixed input size
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. This script requires a GPU.")
print("CUDA is available")

device = torch.device("cuda")
# Allow TF32 kernels for matmul and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()  # Clear the CUDA memory before starting

# Display current memory stats
print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"Reserved memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
print(f"Max allocated memory: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
print(f"Max reserved memory: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")

# Weights file produced at the end of YOLO training.
model_path = "/home/ramanlab/Documents/Arshiya/Yolo-Model/YOLOLocustPalps/runs/obb/train72/weights/best.pt"

# Root directory; each immediate sub-directory holds one individual's recordings.
main_directory = "/home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX"

overall_mean_latency_s = 0
OVERALL_MEAN_LATENCY_S = overall_mean_latency_s

# Names of the immediate sub-directories of main_directory.
individual_fly_folders = [
    entry
    for entry in os.listdir(main_directory)
    if os.path.isdir(os.path.join(main_directory, entry))
]

print(f"{individual_fly_folders}")

CUDA is available
Allocated memory: 277.25 MB
Reserved memory: 280.00 MB
Max allocated memory: 8711.36 MB
Max reserved memory: 8816.00 MB
['L5', 'L2', 'L8', 'L6', 'L1', 'L3', 'L4', 'L7']


In [68]:
%run extra_scripts/timestampFunction.py


# Run Yolo Model

In [78]:
import os
import time
import logging
import cv2
import numpy as np
import pandas as pd
from ultralytics import YOLO 
import torch
from collections import deque
from typing import Optional, Dict, Tuple, List
import subprocess

def create_video_writer(output_path: str, width: int, height: int, fps: float):
    """
    Open an OpenCV VideoWriter, preferring MP4 and falling back to AVI.

    The first attempt uses the "mp4v" codec at ``output_path``. If that writer
    does not open, a second attempt uses "XVID" at the same path with its
    extension swapped for ".avi".

    Args:
        output_path: Desired output file path (normally ending in ".mp4").
        width: Frame width in pixels.
        height: Frame height in pixels.
        fps: Frame rate passed to the writer.

    Returns:
        (writer, final_path): the opened writer and the path it writes to.

    Raises:
        RuntimeError: if neither the MP4 nor the AVI writer can be opened.
    """
    frame_size = (width, height)

    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)
    if writer.isOpened():
        return writer, output_path
    # Drop the handle of the writer that failed to open before retrying.
    writer.release()

    # Swap only the real extension. A plain str.replace would also rewrite
    # ".mp4" inside directory names and would leave ".MP4" or other
    # extensions untouched, so the fallback would retry the very same path.
    avi_path = os.path.splitext(output_path)[0] + ".avi"
    writer = cv2.VideoWriter(avi_path, cv2.VideoWriter_fourcc(*"XVID"), fps, frame_size)

    if not writer.isOpened():
        writer.release()
        raise RuntimeError(f"Cannot open video writer for {output_path} or fallback {avi_path}")

    print(f"[INFO] OpenCV MP4 writer failed, using AVI fallback: {avi_path}")
    return writer, avi_path

def convert_avi_to_mp4(avi_path: str, mp4_path: str):
    """
    Re-encode ``avi_path`` into ``mp4_path`` by invoking the ``ffmpeg`` executable.

    The output is overwritten if it exists (``-y``). Because the command is run
    with ``check=True``, a non-zero ffmpeg exit status raises
    ``subprocess.CalledProcessError``.
    """
    encoder_options = ["-c:v", "libx264", "-crf", "18", "-preset", "fast"]
    command = ["ffmpeg", "-y", "-i", avi_path, *encoder_options, mp4_path]

    print(f"[INFO] Converting {avi_path} -> {mp4_path} via ffmpeg")
    subprocess.run(command, check=True)

# Only let WARNING and above through from the "ultralytics" logger.
logging.getLogger("ultralytics").setLevel(logging.WARNING)

# ──────────────────────────────────────────────────────────────────────────────
# USER CONFIG
# 'model_path' and 'main_directory' must already be defined (see the first cell).
model = YOLO(model_path)

# Track NUM_TARGETS instances of class TARGET_CLASS_ID (palps) from an OBB model.
TARGET_CLASS_ID = 1
NUM_TARGETS = 2

# Frame rate of the annotated output video, independent of the input or CSV rate.
OUTPUT_FPS = 30.0

if not torch.cuda.is_available():
    print("CUDA not available. Running on CPU.")
else:
    model.to("cuda")
    print("Model moved to CUDA.")

# ──────────────────────────────────────────────────────────────────────────────
# TEMPORAL / SPATIAL ENHANCEMENT CONFIG
CONF_THRES = 0.40       # detector confidence threshold
IOU_MATCH_THRES = 0.70  # minimum IoU for associating a detection with an existing track
MAX_AGE = 15            # frames a track may go without a fresh detection before it is pruned
EMA_ALPHA = 0.20        # weight of the previous box in the post-Kalman blend (0 = no smoothing)

# Optical flow is only applied to tracks that received no detection in the current frame.
FLOW_ENABLE = True
FLOW_SKIP_EDGE = 10     # margin (px) excluded from the flow window at the image borders
FLOW_PARAMS = {
    "pyr_scale": 0.5,
    "levels": 3,
    "winsize": 15,
    "iterations": 3,
    "poly_n": 5,
    "poly_sigma": 1.2,
    "flags": 0,
}

# ──────────────────────────────────────────────────────────────────────────────
# Geometry helpers
def order_corners(corners):
    """
    Sort four corner points by their angle around the centroid.

    corners: array-like of shape (4, 2), in any order.
    Returns: list of [x, y] pairs sorted by ascending ``arctan2(y - cy, x - cx)``
    (clockwise on screen when y grows downwards).
    """
    quad = np.array(corners, dtype=np.float32)
    centroid = quad.mean(axis=0)
    offsets = quad - centroid
    by_angle = np.argsort(np.arctan2(offsets[:, 1], offsets[:, 0]))
    return quad[by_angle].tolist()

def xyxy_to_cxcywh(b):
    """Convert an (x1, y1, x2, y2) box to float32 (cx, cy, w, h); negative extents clamp to 0."""
    left, top, right, bottom = b
    width = max(0.0, right - left)
    height = max(0.0, bottom - top)
    center = (left + width / 2.0, top + height / 2.0)
    return np.array([center[0], center[1], width, height], dtype=np.float32)

def cxcywh_to_xyxy(s):
    """Convert a (cx, cy, w, h) box to float32 (x1, y1, x2, y2)."""
    cx, cy, w, h = s
    half_w = w / 2.0
    half_h = h / 2.0
    return np.array(
        [cx - half_w, cy - half_h, cx + half_w, cy + half_h],
        dtype=np.float32,
    )

def iou(a, b):
    """
    Pairwise intersection-over-union between two sets of xyxy boxes.

    a: array of shape (N, 4); b: array of shape (M, 4).
    Returns: float32 array of shape (N, M). When either set is empty the
    result is an all-zero array of that shape.
    """
    count_a, count_b = a.shape[0], b.shape[0]
    if count_a == 0 or count_b == 0:
        return np.zeros((count_a, count_b), dtype=np.float32)

    # Broadcast a along columns and b along rows.
    boxes_a = a[:, None, :]
    boxes_b = b[None, :, :]

    overlap_w = np.maximum(0, np.minimum(boxes_a[..., 2], boxes_b[..., 2]) - np.maximum(boxes_a[..., 0], boxes_b[..., 0]))
    overlap_h = np.maximum(0, np.minimum(boxes_a[..., 3], boxes_b[..., 3]) - np.maximum(boxes_a[..., 1], boxes_b[..., 1]))
    overlap = overlap_w * overlap_h

    size_a = (boxes_a[..., 2] - boxes_a[..., 0]) * (boxes_a[..., 3] - boxes_a[..., 1])
    size_b = (boxes_b[..., 2] - boxes_b[..., 0]) * (boxes_b[..., 3] - boxes_b[..., 1])
    # The small epsilon keeps the division defined for zero-area boxes.
    total = size_a + size_b - overlap + 1e-6
    return (overlap / total).astype(np.float32)

# ──────────────────────────────────────────────────────────────────────────────
# Kalman for (cx, cy, w, h) + multi-object tracker (single class)
class KalmanBBox:
    """
    Constant-velocity Kalman filter over a box state.

    State vector (8x1): [cx, cy, w, h] followed by their per-step velocities.
    Measurements are the four position/size components only.
    """

    def __init__(self):
        state_dim, meas_dim = 8, 4

        self.x = np.zeros((state_dim, 1), dtype=np.float32)
        self.P = 10.0 * np.eye(state_dim, dtype=np.float32)

        # Transition: identity plus ones on the 4th super-diagonal, i.e.
        # each of the first four components advances by its velocity.
        self.F = np.eye(state_dim, dtype=np.float32) + np.eye(state_dim, k=4, dtype=np.float32)

        # Measurement matrix picks out the first four state components.
        self.H = np.eye(meas_dim, state_dim, dtype=np.float32)

        self.Q = 0.02 * np.eye(state_dim, dtype=np.float32)
        self.R = 1.0 * np.eye(meas_dim, dtype=np.float32)

    def init(self, cxcywh):
        """Reset the filter to the given box with zero velocity and fresh covariance."""
        self.x[:4, 0] = cxcywh
        self.x[4:, 0] = 0.0
        self.P = 10.0 * np.eye(8, dtype=np.float32)

    def predict(self):
        """Advance the state one step and return a copy of the predicted (cx, cy, w, h)."""
        self.x = self.F @ self.x
        self.P = self.F @ self.P @ self.F.T + self.Q
        return self.x[:4, 0].copy()

    def update(self, z):
        """Fold a (cx, cy, w, h) measurement into the state."""
        measurement = z.reshape(4, 1)
        residual = measurement - (self.H @ self.x)
        residual_cov = self.H @ self.P @ self.H.T + self.R
        gain = self.P @ self.H.T @ np.linalg.inv(residual_cov)
        self.x = self.x + gain @ residual
        identity = np.eye(8, dtype=np.float32)
        self.P = (identity - gain @ self.H) @ self.P

class Track:
    """A single tracked target: a Kalman-filtered box plus bookkeeping counters."""

    # Class-wide counter used to hand out ids to new tracks.
    _next_id = 1

    def __init__(self, cxcywh, score, corners: Optional[List[List[float]]] = None):
        """Start a track at box ``cxcywh`` with detection ``score`` and optional OBB corners."""
        self.id, Track._next_id = Track._next_id, Track._next_id + 1

        filt = KalmanBBox()
        filt.init(cxcywh)
        self.kf = filt

        self.score = float(score)
        self.box_xyxy = cxcywh_to_xyxy(cxcywh)

        self.time_since_update = 0           # predict() calls since the last correct()
        self.hits = 1                        # number of detections merged so far
        self.history = deque(maxlen=30)      # most recent boxes (predicted and corrected)

        # Corners of the most recent detection that supplied them, if any.
        self.corners = corners

    def predict(self):
        """Advance the filter one step, store and return the predicted xyxy box."""
        predicted_box = cxcywh_to_xyxy(self.kf.predict())
        self.box_xyxy = predicted_box
        self.history.append(predicted_box.copy())
        self.time_since_update += 1
        return predicted_box

    def correct(self, cxcywh, score, corners: Optional[List[List[float]]] = None):
        """Merge a matched detection into the track and reset its staleness counter."""
        self.kf.update(cxcywh)
        filtered_box = cxcywh_to_xyxy(self.kf.x[:4, 0])

        # Blend the filtered box with the previously stored box to damp jitter.
        keep_previous = EMA_ALPHA
        self.box_xyxy = (1 - keep_previous) * filtered_box + keep_previous * self.box_xyxy

        self.score = float(score)
        self.hits += 1
        self.time_since_update = 0
        self.history.append(self.box_xyxy.copy())

        # Keep the old corners when the detection did not provide new ones.
        if corners is not None:
            self.corners = corners

class MultiObjectSingleClassTracker:
    """Keeps any number of tracks for a single class and returns them ranked by quality."""

    def __init__(self, iou_thres=0.25, max_age=15):
        self.iou_thres = iou_thres   # minimum IoU for a track/detection match
        self.max_age = max_age       # tracks older than this (in frames) are dropped
        self.tracks: List[Track] = []

    def step(self, det_xyxy: np.ndarray, det_scores: np.ndarray, det_corners: Optional[List[Optional[List[List[float]]]]] = None) -> List[Track]:
        """
        Advance every track by one frame and merge in the new detections.

        det_xyxy: (N, 4) detection boxes; det_scores: (N,) scores;
        det_corners: optional per-detection corner lists (or None entries).
        Returns the surviving tracks, sorted with the freshest first.
        """
        n_dets = len(det_xyxy)
        if det_corners is None:
            det_corners = [None] * n_dets

        # Predict every existing track forward.
        predicted = [trk.predict() for trk in self.tracks]

        # Greedy matching: repeatedly take the best remaining IoU pair.
        matched_tracks, matched_dets = set(), set()
        if len(self.tracks) and n_dets:
            overlap = iou(np.stack(predicted), det_xyxy)
            while True:
                ti, di = np.unravel_index(np.argmax(overlap), overlap.shape)
                if overlap[ti, di] < self.iou_thres:
                    break
                if ti in matched_tracks or di in matched_dets:
                    overlap[ti, di] = -1
                    continue

                self.tracks[ti].correct(
                    xyxy_to_cxcywh(det_xyxy[di]),
                    det_scores[di],
                    corners=det_corners[di],
                )
                matched_tracks.add(ti)
                matched_dets.add(di)
                # Remove the matched row and column from further consideration.
                overlap[ti, :] = -1
                overlap[:, di] = -1

        # Every detection left unmatched starts a new track.
        self.tracks.extend(
            Track(xyxy_to_cxcywh(det_xyxy[di]), det_scores[di], corners=det_corners[di])
            for di in range(n_dets)
            if di not in matched_dets
        )

        # Drop tracks that have gone too long without a detection.
        self.tracks = [trk for trk in self.tracks if trk.time_since_update <= self.max_age]

        # Rank: fewest frames since update, then most hits, then highest score, then lowest id.
        def quality(trk):
            return (trk.time_since_update, -trk.hits, -trk.score, trk.id)

        self.tracks.sort(key=quality)
        return self.tracks

def flow_nudge(prev_gray, gray, box_xyxy):
    """
    Shift ``box_xyxy`` by the median Farneback optical flow measured inside it.

    The flow window is the box clipped to stay FLOW_SKIP_EDGE pixels away from
    the image borders. The input box is returned unchanged when there is no
    previous frame or when the clipped window is empty.
    """
    if prev_gray is None:
        return box_xyxy

    img_h, img_w = gray.shape[0], gray.shape[1]
    left, top, right, bottom = box_xyxy.astype(int)
    left = max(FLOW_SKIP_EDGE, left)
    top = max(FLOW_SKIP_EDGE, top)
    right = min(img_w - FLOW_SKIP_EDGE, right)
    bottom = min(img_h - FLOW_SKIP_EDGE, bottom)
    if right <= left or bottom <= top:
        return box_xyxy

    window = (slice(top, bottom), slice(left, right))
    flow = cv2.calcOpticalFlowFarneback(prev_gray[window], gray[window], None, **FLOW_PARAMS)
    shift_x = np.median(flow[..., 0])
    shift_y = np.median(flow[..., 1])

    shifted = box_xyxy.copy().astype(np.float32)
    shifted[0::2] += shift_x   # x1, x2
    shifted[1::2] += shift_y   # y1, y2
    return shifted

# ──────────────────────────────────────────────────────────────────────────────
# CSV timestamp helpers (unchanged)
TIMESTAMP_CANDIDATES = ["UTC_ISO", "Timestamp", "Number", "MonoNs"]
FRAME_CANDIDATES = ["Frame Number", "FrameNumber"]

def _pick_timestamp_column(df: pd.DataFrame) -> Optional[str]:
    for c in TIMESTAMP_CANDIDATES:
        if c in df.columns:
            return c
    return None

def _pick_frame_column(df: pd.DataFrame) -> Optional[str]:
    for c in FRAME_CANDIDATES:
        if c in df.columns:
            return c
    return None

def _to_seconds_series(df: pd.DataFrame, ts_col: str) -> pd.Series:
    s = df[ts_col]
    if ts_col in ("UTC_ISO", "Timestamp"):
        dt = pd.to_datetime(s, errors="coerce", utc=(ts_col == "UTC_ISO"))
        secs = dt.astype("int64") / 1e9
        t0 = np.nanmin(secs.values)
        return (secs - t0).astype(float)
    if ts_col == "Number":
        vals = pd.to_numeric(s, errors="coerce").astype(float)
        t0 = np.nanmin(vals.values)
        return vals - t0
    if ts_col == "MonoNs":
        vals = pd.to_numeric(s, errors="coerce").astype(float)
        secs = vals / 1e9
        t0 = np.nanmin(secs.values)
        return secs - t0
    raise ValueError(f"Unsupported timestamp column: {ts_col}")

def _estimate_fps_from_seconds(seconds_series: pd.Series) -> Optional[float]:
    mask = seconds_series.notna()
    if mask.sum() < 2:
        return None
    duration = seconds_series[mask].iloc[-1] - seconds_series[mask].iloc[0]
    if duration <= 0:
        return None
    return mask.sum() / duration

# ──────────────────────────────────────────────────────────────────────────────
# Per-frame processing: detect ALL class-1 palps, track, draw boxes+centers, line between two
def process_frame(
    frame,
    frame_number,
    current_timestamp,
    tracker: MultiObjectSingleClassTracker,
    prev_gray
):
    """
    Detect, track and annotate the target class in one video frame.

    Steps: run the YOLO model on ``frame``; keep detections whose class equals
    TARGET_CLASS_ID; feed them to ``tracker``; take the first NUM_TARGETS
    tracks; draw their outlines, centers and labels onto ``frame``; and build
    one flat dict describing the frame.

    Args:
        frame: image passed to the model and to OpenCV drawing calls; it is
            converted with COLOR_BGR2GRAY and annotated in place.
        frame_number: value stored under the "frame" key of the row.
        current_timestamp: value stored under the "timestamp" key of the row.
        tracker: tracker instance that is stepped with this frame's detections.
        prev_gray: grayscale image of the previous frame, or None; used only
            for the optical-flow nudge of tracks without a fresh detection.

    Returns:
        (frame, row, gray): the annotated frame, the row dict, and this
        frame's grayscale image (to be passed as ``prev_gray`` next call).

    Note:
        The row and the center-connecting line index ``palps[0]`` and
        ``palps[1]``, so NUM_TARGETS must be at least 2.
    """
    # 1) YOLO inference
    r = model.predict(source=frame, conf=CONF_THRES, verbose=False)[0]

    # 2) Collect ALL detections for TARGET_CLASS_ID (xyxy + score + OBB corners if present)
    det_boxes = []
    det_scores = []
    det_corners = []  # list of ordered corners (4x2) or None per det

    if hasattr(r, "obb") and r.obb is not None:
        xyxyxyxy = r.obb.xyxyxyxy.cpu().numpy()  # corner coordinates per detection; reshaped to (4, 2) below
        cls_arr  = r.obb.cls.cpu().numpy().astype(int)
        # Fall back to a score of 1.0 per detection when no confidences are exposed.
        conf_arr = (
            r.obb.conf.cpu().numpy().astype(np.float32)
            if hasattr(r.obb, "conf") and r.obb.conf is not None
            else np.ones_like(cls_arr, dtype=np.float32)
        )

        for i, (c, s) in enumerate(zip(cls_arr, conf_arr)):
            if c != TARGET_CLASS_ID:
                continue
            corners = xyxyxyxy[i].reshape(4, 2)
            # Axis-aligned bounding box enclosing the oriented box; this is what the tracker uses.
            x1, y1 = float(corners[:, 0].min()), float(corners[:, 1].min())
            x2, y2 = float(corners[:, 0].max()), float(corners[:, 1].max())
            det_boxes.append([x1, y1, x2, y2])
            det_scores.append(float(s))
            det_corners.append(order_corners(corners))
    else:
        # Fallback if model outputs axis-aligned boxes
        if r.boxes is not None and len(r.boxes) > 0:
            xyxy = r.boxes.xyxy.cpu().numpy()
            cls_arr  = r.boxes.cls.cpu().numpy().astype(int)
            conf_arr = r.boxes.conf.cpu().numpy().astype(np.float32)
            for b, c, s in zip(xyxy, cls_arr, conf_arr):
                if c != TARGET_CLASS_ID:
                    continue
                det_boxes.append([float(b[0]), float(b[1]), float(b[2]), float(b[3])])
                det_scores.append(float(s))
                det_corners.append(None)

    # Use explicitly shaped empty arrays so the tracker always receives (N, 4) / (N,) inputs.
    det_xyxy = np.array(det_boxes, dtype=np.float32) if len(det_boxes) else np.zeros((0, 4), dtype=np.float32)
    det_scores = np.array(det_scores, dtype=np.float32) if len(det_scores) else np.zeros((0,), dtype=np.float32)

    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # 3) Update tracker (multi-object). Select best two tracks.
    tracks = tracker.step(det_xyxy, det_scores, det_corners)
    selected = tracks[:NUM_TARGETS]

    # Optional: if tracks are stale this frame, nudge them by optical flow (display-only)
    # NOTE(review): the nudged box is stored on the track and read back below,
    # so it also ends up in the returned row, not only in the drawing.
    if FLOW_ENABLE and prev_gray is not None:
        for t in selected:
            if t.time_since_update > 0:
                nudged = flow_nudge(prev_gray, gray, t.box_xyxy)
                t.box_xyxy = nudged

    # 4) Build outputs for CSV + draw annotations
    palps = []
    for idx in range(NUM_TARGETS):
        if idx < len(selected):
            t = selected[idx]
            box = t.box_xyxy.astype(np.float32)
            cx, cy, bw, bh = xyxy_to_cxcywh(box)
            palps.append({
                "track_id": t.id,
                "x1": float(box[0]), "y1": float(box[1]), "x2": float(box[2]), "y2": float(box[3]),
                "cx": float(cx), "cy": float(cy),
                "corners": (t.corners if t.corners is not None else np.nan),
                "age": int(t.time_since_update),
                "score": float(t.score),
            })

            # Draw OBB polygon if available, else draw axis-aligned rectangle
            if t.corners is not None and isinstance(t.corners, list) and len(t.corners) == 4:
                pts = np.array(t.corners, dtype=np.int32).reshape((-1, 1, 2))
                cv2.polylines(frame, [pts], isClosed=True, color=(0, 255, 255), thickness=2)
            else:
                cv2.rectangle(
                    frame,
                    (int(box[0]), int(box[1])),
                    (int(box[2]), int(box[3])),
                    (0, 255, 255),
                    2
                )

            # Draw center
            cv2.circle(frame, (int(cx), int(cy)), 1, (0, 255, 255), -1)

            # Label
            cv2.putText(
                frame,
                f"palp#{idx+1} id={t.id} age={t.time_since_update}",
                (max(0, int(box[0])), max(20, int(box[1]) - 8)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.3,
                (0, 255, 255),
                1
            )
        else:
            # No track available for this slot: fill the slot with NaN placeholders.
            palps.append({
                "track_id": np.nan,
                "x1": np.nan, "y1": np.nan, "x2": np.nan, "y2": np.nan,
                "cx": np.nan, "cy": np.nan,
                "corners": np.nan,
                "age": np.nan,
                "score": np.nan,
            })

    # 5) Draw line connecting the two centers (if both exist)
    distance = np.nan
    if not (np.isnan(palps[0]["cx"]) or np.isnan(palps[1]["cx"])):
        p0 = (int(palps[0]["cx"]), int(palps[0]["cy"]))
        p1 = (int(palps[1]["cx"]), int(palps[1]["cy"]))
        cv2.line(frame, p0, p1, (0, 255, 0), 2)
        # Euclidean distance between the two (un-rounded) centers, in pixels.
        distance = float(np.hypot(palps[0]["cx"] - palps[1]["cx"], palps[0]["cy"] - palps[1]["cy"]))
        cv2.putText(
            frame,
            f"dist={distance:.2f}px",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.9,
            (0, 255, 0),
            2
        )

    # 6) CSV row
    # Corners are stored as their string representation ("nan" when absent).
    row = {
        "frame": frame_number,
        "timestamp": current_timestamp,

        "track_id_palp1": palps[0]["track_id"],
        "x1_palp1": palps[0]["x1"], "y1_palp1": palps[0]["y1"], "x2_palp1": palps[0]["x2"], "y2_palp1": palps[0]["y2"],
        "cx_palp1": palps[0]["cx"], "cy_palp1": palps[0]["cy"],
        "corners_palp1": str(palps[0]["corners"]),
        "age_palp1": palps[0]["age"],
        "score_palp1": palps[0]["score"],

        "track_id_palp2": palps[1]["track_id"],
        "x1_palp2": palps[1]["x1"], "y1_palp2": palps[1]["y1"], "x2_palp2": palps[1]["x2"], "y2_palp2": palps[1]["y2"],
        "cx_palp2": palps[1]["cx"], "cy_palp2": palps[1]["cy"],
        "corners_palp2": str(palps[1]["corners"]),
        "age_palp2": palps[1]["age"],
        "score_palp2": palps[1]["score"],

        "distance_palp1_palp2_px": distance,
    }

    return frame, row, gray

# ──────────────────────────────────────────────────────────────────────────────
# MAIN LOOP
# Grayscale image of the previously processed frame (reset for every video below).
prev_gray = None

individual_fly_folders = [
    f for f in os.listdir(main_directory)
    if os.path.isdir(os.path.join(main_directory, f))
]

for fly_folder in individual_fly_folders:
    fly_folder_path = os.path.join(main_directory, fly_folder)

    # Videos directly inside the folder; names with an empty stem (e.g. ".mp4") are ignored.
    video_files = [
        f for f in os.listdir(fly_folder_path)
        if f.lower().endswith((".mp4", ".avi")) and f.split(".")[0]
    ]

    for video_file in video_files:
        video_path = os.path.join(fly_folder_path, video_file)
        video_name = os.path.basename(video_path)
        video_base_name = video_name.split(".")[0]

        # The timestamp CSV shares the video's base name without "_preprocessed".
        csv_file_name = video_base_name.replace("_preprocessed", "") + ".csv"
        csv_file_path = os.path.join(fly_folder_path, csv_file_name)

        # Output folder name: underscore-separated fields 1..6 of the base name.
        folder_name = "_".join(video_base_name.split("_")[1:7])
        output_folder = os.path.join(fly_folder_path, folder_name)
        output_video_path = os.path.join(output_folder, f"{folder_name}_palps_annotated_30fps.mp4")

        # An existing output folder marks the video as already processed.
        if os.path.isdir(output_folder):
            print(f"Skipping {video_path} — found existing folder: {output_folder}")
            continue

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video file {video_path}")
            cap.release()
            continue

        # Create the folder only once the video is known to open; creating it
        # earlier would make the skip check above hide an unreadable video on
        # every later run.
        os.makedirs(output_folder, exist_ok=True)

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        max_frame = total_frames - 1

        timestamps = {}
        inferred_fps_for_timestamp = cap.get(cv2.CAP_PROP_FPS) or OUTPUT_FPS

        if os.path.exists(csv_file_path):
            df_timestamps = pd.read_csv(csv_file_path)
            row_count = len(df_timestamps)
            print(f"Found CSV file {csv_file_path} with {row_count} rows.")

            frame_col = _pick_frame_column(df_timestamps)
            if frame_col is not None:
                ts_col = _pick_timestamp_column(df_timestamps)
                if ts_col is not None:
                    secs = _to_seconds_series(df_timestamps, ts_col)
                    # Keep only rows with both a numeric frame index and a valid time.
                    tmp = pd.DataFrame({
                        "_frame": pd.to_numeric(df_timestamps[frame_col], errors="coerce"),
                        "seconds": secs
                    }).dropna(subset=["_frame", "seconds"])
                    tmp["_frame"] = tmp["_frame"].astype(int)
                    timestamps = tmp.set_index("_frame")["seconds"].to_dict()

                    # The CSV, when usable, decides the last frame index to process.
                    if not tmp["_frame"].empty:
                        max_frame = int(tmp["_frame"].max())

                    fps_from_csv = _estimate_fps_from_seconds(secs)
                    if fps_from_csv and np.isfinite(fps_from_csv) and fps_from_csv > 0:
                        inferred_fps_for_timestamp = float(fps_from_csv)
                        print("Calculated FPS from CSV timestamps:", inferred_fps_for_timestamp)
                    else:
                        inferred_fps_for_timestamp = float(cap.get(cv2.CAP_PROP_FPS) or OUTPUT_FPS)
                        print("FPS from CSV unavailable; falling back to video FPS for timestamps:", inferred_fps_for_timestamp)
                else:
                    print("CSV has a frame column but no recognized timestamp column; timestamps will be synthetic.")
            else:
                print("CSV missing frame column; timestamps will be synthetic.")
        else:
            print(f"CSV file {csv_file_path} not found. Timestamps will be synthetic.")

        # Use OUTPUT_FPS for the *writer* (strict 30 fps output).
        # The writer is opened lazily inside the loop because its frame size
        # is only known once the first frame has been read.
        writer = None
        writer_fps = OUTPUT_FPS
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")

        all_rows = []

        # One tracker for class-1 palps (multi-object)
        tracker = MultiObjectSingleClassTracker(iou_thres=IOU_MATCH_THRES, max_age=MAX_AGE)

        start_time = time.time()
        frame_count = 0
        prev_gray = None  # reset per video

        try:
            while cap.isOpened() and frame_count <= max_frame:
                ret, frame = cap.read()
                if not ret:
                    break

                # timestamp priority:
                # - if CSV provides seconds, use it
                # - else use synthetic timestamp consistent with OUTPUT_FPS
                current_timestamp = timestamps.get(frame_count, frame_count / OUTPUT_FPS)

                frame, row, prev_gray = process_frame(
                    frame=frame,
                    frame_number=frame_count,
                    current_timestamp=current_timestamp,
                    tracker=tracker,
                    prev_gray=prev_gray
                )

                if writer is None:
                    out_h, out_w = frame.shape[:2]
                    writer = cv2.VideoWriter(output_video_path, fourcc, writer_fps, (out_w, out_h))
                    if not writer.isOpened():
                        raise RuntimeError(
                            f"VideoWriter failed to open. "
                            f"Path={output_video_path}, fourcc=mp4v, fps={writer_fps}, size={(out_w, out_h)}"
                        )

                writer.write(frame)
                all_rows.append(row)
                frame_count += 1
        finally:
            # Release the capture and the writer even if processing fails midway.
            cap.release()
            if writer is not None:
                writer.release()

        out_csv_path = os.path.join(output_folder, f"{folder_name}_palps_tracks.csv")
        pd.DataFrame(all_rows).to_csv(out_csv_path, index=False)

        elapsed_time = time.time() - start_time
        print(f"Processed video {video_path} in {elapsed_time:.2f} seconds. Output saved to {output_folder}")

Model moved to CUDA.
CSV file /home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX/L5/Trial_1_Recording.csv not found. Timestamps will be synthetic.
Processed video /home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX/L5/Trial_1_Recording.mp4 in 4.54 seconds. Output saved to /home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX/L5/1_Recording
CSV file /home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX/L2/Trial_1_Recording.csv not found. Timestamps will be synthetic.
Processed video /home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX/L2/Trial_1_Recording.mp4 in 4.45 seconds. Output saved to /home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX/L2/1_Recording
CSV file /home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX/L8/Trial_1_Recording.csv not found. Timestamps will be synthetic.
Processed video /home/ramanlab/Documents/Arshiya/all_vids/08.14.2025/Testing/HEX/L8/Trial_1_Recording.mp4 in 4.47 seconds

# Fruit Fly Analysis Code Below — must be modified to compute the distance between the class 1 (palp) centers

# Plots & Data for Distance Between The Eye and Proboscis

In [62]:
"""
Compute global min/max distance values for each fly folder
while **ignoring** any distances outside user‑defined hard limits.

Workflow
========
1.  Scan every CSV matching ``*merged.csv`` (searched recursively) in each fly folder.
2.  For each file, keep only distances in **[HARD_MIN, HARD_MAX]**.
3.  Find that file’s min/max; aggregate to obtain the folder‑level
    global min/max.
4.  Write results to ``global_distance_stats_class_2.json`` inside the
   corresponding fly folder.

If *all* values in a file are out of range, the file is skipped.
If no files contain in‑range data, the folder is reported and skipped.

Edit ``HARD_MIN_eye`` and ``HARD_MAX_eye`` below.
"""

from __future__ import annotations
import glob
import json
import os
from pathlib import Path
import pandas as pd

# Only distances inside this inclusive range contribute to the statistics.
HARD_MIN_eye: float = 70    # TODO: 105 for manual 70 for OCTo
HARD_MAX_eye: float = 250  # TODO:
# --------------------------------------------------------

main_directory = Path(main_directory).expanduser().resolve()

if not main_directory.is_dir():
    raise NotADirectoryError(f"{main_directory} is not a valid directory")

# Identify each immediate sub‑folder (one per fly)
fly_folders = [p for p in main_directory.iterdir() if p.is_dir()]

for fly_folder in fly_folders:
    pattern = fly_folder / "**" / "*merged.csv"
    csv_files = glob.glob(str(pattern), recursive=True)

    # Per-file extremes of the in-range distances found in this folder.
    file_minima = []
    file_maxima = []

    for csv_file in csv_files:
        df = pd.read_csv(csv_file)
        if "distance_2_6" not in df.columns:
            print(f"Skipping {csv_file} — 'distance_2_6' column missing.")
            continue

        # Keep only values within hard limits (NaN never counts as in range).
        column = df["distance_2_6"]
        distances = column[column.between(HARD_MIN_eye, HARD_MAX_eye, inclusive="both")]

        if distances.empty:
            print(f"All 'distance_2_6' values out of range in {csv_file}; skipping.")
            continue

        file_minima.append(distances.min())
        file_maxima.append(distances.max())

    # Were any valid values found across all files?
    if not file_minima:
        print(f"No in‑range 'distance_2_6' data found in {fly_folder}.")
        continue

    # Cast to built-in float: an integer-typed column yields numpy integers,
    # which json.dump cannot serialise.
    global_min = float(min(file_minima))
    global_max = float(max(file_maxima))

    stats = {"global_min": global_min, "global_max": global_max}
    stats_path = fly_folder / "global_distance_stats_class_2.json"
    with open(stats_path, "w") as f:
        json.dump(stats, f)

    print(
        f"{fly_folder.name}: min = {global_min}, max = {global_max} "
        f"(range {HARD_MIN_eye}–{HARD_MAX_eye}). → {stats_path}"
    )

No in‑range 'distance' data found in /home/ramanlab/Documents/Arshiya/all_vids/04.15.2025/Testing/HEX/L4.
No in‑range 'distance' data found in /home/ramanlab/Documents/Arshiya/all_vids/04.15.2025/Testing/HEX/L3.
No in‑range 'distance' data found in /home/ramanlab/Documents/Arshiya/all_vids/04.15.2025/Testing/HEX/L2.
No in‑range 'distance' data found in /home/ramanlab/Documents/Arshiya/all_vids/04.15.2025/Testing/HEX/L1.


In [63]:
"""
Add distance‑percentage normalisation to every ``*merged.csv``, using the
folder‑level global min/max produced by *Set Hard Min Max Distance*.

Rules
-----
* If a distance > global_max  → distance_percentage = **101**
* If a distance < global_min  → distance_percentage = **-1**
* Otherwise                     distance_percentage = 100 · (d − global_min)/(global_max − global_min)

The script also writes the reference min/max to each row (columns
``min_distance_2_6`` and ``max_distance_2_6``) for reproducibility.
"""

from __future__ import annotations

import glob
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd

main_directory = Path(main_directory).expanduser().resolve()

if not main_directory.is_dir():
    raise NotADirectoryError(f"{main_directory} is not a valid directory")

fly_folders = [p for p in main_directory.iterdir() if p.is_dir()]

for fly_folder in fly_folders:
    # Folder-level reference range written by the previous cell.
    stats_path = fly_folder / "global_distance_stats_class_2.json"
    if not stats_path.exists():
        print(f"No global stats found for {fly_folder.name}; skipping.")
        continue

    with open(stats_path, "r") as f:
        stats = json.load(f)
    global_min, global_max = stats["global_min"], stats["global_max"]

    degenerate_range = global_max == global_min
    if degenerate_range:
        print(
            f"Warning: global_max == global_min in {fly_folder.name}; "
            "distance‑percentage set to 0 for in‑range values."
        )

    csv_files = glob.glob(str(fly_folder / "**" / "*merged.csv"), recursive=True)

    for csv_file in csv_files:
        df = pd.read_csv(csv_file)

        if "distance_2_6" not in df.columns:
            print(f"'distance' column missing in {csv_file}; skipping.")
            continue

        # Record the reference range next to the data.
        df["min_distance_2_6"] = global_min
        df["max_distance_2_6"] = global_max

        d = df["distance_2_6"].to_numpy()

        # Value used wherever the distance is neither above nor below the range.
        if degenerate_range:
            scaled = np.zeros(d.shape, dtype=float)  # arbitrary when range is zero
        else:
            scaled = 100.0 * (d - global_min) / (global_max - global_min)

        # Above the range -> 101, below the range -> -1, otherwise the scaled value.
        df["distance_percentage_2_6"] = np.select(
            [d > global_max, d < global_min],
            [101.0, -1.0],
            default=scaled,
        )

        df.to_csv(csv_file, index=False)
        print(f"Updated {csv_file}")

No global stats found for L4; skipping.
No global stats found for L3; skipping.
No global stats found for L2; skipping.
No global stats found for L1; skipping.


In [70]:
"""
Create time-series plots of ``distance_percentage_2_6`` for each ``*merged.csv``,
adding "_time" to the generated PNG filenames.

Usage
-----
1.  Set ``main_directory`` to the path that contains the individual fly
    folders.
2.  Ensure ``timestamp_to_seconds`` is defined in the same namespace or
    import it from your utilities module.
3.  Run the script (e.g. ``python plot_distance_percentage_time.py``).

Each plot is saved in the same directory as its CSV with the pattern:
``<original-csv-name>_time.png``.
"""

from __future__ import annotations

import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

# --------------------------------------------------------

# Normalise the root path once; note this rebinds the notebook-level
# ``main_directory`` to a resolved ``Path``.
main_directory = Path(main_directory).expanduser().resolve()

# Every immediate sub-directory is treated as one fly folder.
fly_folders = [p for p in main_directory.iterdir() if p.is_dir()]

# Columns a CSV must contain before it can be plotted.
required_columns = {"timestamp", "distance_percentage_2_6"}

for fly_folder in fly_folders:
    # Recursively collect every CSV whose name ends in "merged.csv".
    pattern = fly_folder / "**" / "*merged.csv"
    csv_files = glob.glob(str(pattern), recursive=True)

    for csv_file in csv_files:
        df = pd.read_csv(csv_file)

        if not required_columns.issubset(df.columns):
            # Name the column that is actually checked so the message is actionable.
            print(
                f"Skipping {csv_file}: missing 'timestamp' or "
                "'distance_percentage_2_6' columns."
            )
            continue

        # Convert timestamps → seconds; rows whose conversion yields NaN/None
        # are discarded.  ``timestamp_to_seconds`` must already be defined in
        # the notebook namespace.
        df["time_seconds"] = df["timestamp"].apply(timestamp_to_seconds)
        df = df.dropna(subset=["time_seconds"])

        if df.empty:
            print(f"Skipping {csv_file}: no valid timestamps.")
            continue

        # Normalise start-time to 0
        df["time_seconds"] -= df["time_seconds"].iloc[0]

        # Build output filename with "_time" suffix, next to the CSV.
        csv_path = Path(csv_file)
        plot_file = csv_path.with_suffix("").as_posix() + "_time.png"

        # Plot on an explicit figure so it can be closed reliably.
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(
                df["time_seconds"],
                df["distance_percentage_2_6"],
                label="Normalised Distance %",
                marker="o",
                linestyle="-",
                markersize=3,
            )
            ax.set_xlabel("Time (seconds)")
            ax.set_ylabel("Normalised Distance %")
            ax.set_title(f"Normalised Distance %\n{csv_path.name}")
            ax.legend()
            ax.grid(True)

            fig.savefig(plot_file)
        finally:
            # Close even when saving fails so figures do not pile up in memory.
            plt.close(fig)

        print(f"Plot saved → {plot_file}")

In [71]:
import os
import glob
import pandas as pd

# Identify each fly folder in the main directory
fly_folders = [os.path.join(main_directory, f) for f in os.listdir(main_directory)
               if os.path.isdir(os.path.join(main_directory, f))]

for fly_folder in fly_folders:
    # Recursively collect every CSV whose name ends in "merged.csv".
    pattern = os.path.join(fly_folder, "**", "*merged.csv")

    for csv_file in glob.glob(pattern, recursive=True):
        df = pd.read_csv(csv_file)

        # Strip spaces from column names
        df.columns = df.columns.str.strip()

        # Check if 'frame' column exists
        if 'frame' not in df.columns:
            print(f"Warning: 'frame' column not found in {csv_file}. Skipping file.")
            continue

        # A blank 'frame' cell makes pandas read the column as float, which
        # range() rejects.  Coerce to numbers, ignore rows without a frame
        # number, and work with plain integers from here on.
        frame_numbers = pd.to_numeric(df['frame'], errors='coerce')
        has_frame = frame_numbers.notna()
        if not has_frame.any():
            continue  # Skip if no frame numbers are found
        frame_numbers = frame_numbers[has_frame].astype(int)

        # Frames that are present but whose distance is blank / non-numeric.
        if 'distance_2_6' in df.columns:
            distances = pd.to_numeric(df.loc[has_frame, 'distance_2_6'], errors='coerce')
            dropped_frames = set(frame_numbers[distances.isna()].tolist())
        else:
            dropped_frames = set()

        # Frames absent from the CSV altogether, within [min_frame, max_frame].
        present_frames = set(frame_numbers.tolist())
        min_frame = min(present_frames)
        max_frame = max(present_frames)
        missing_frames = set(range(min_frame, max_frame + 1)) - present_frames

        # Combine missing frames and NaN distance frames
        all_dropped_frames = sorted(missing_frames | dropped_frames)
        total_dropped = len(all_dropped_frames)

        # Swap only the extension: str.replace would also rewrite a ".csv"
        # that happens to appear inside a directory name.
        txt_file = os.path.splitext(csv_file)[0] + "_dropped_frames.txt"

        with open(txt_file, "w") as f:
            if total_dropped == 0:
                f.write("No dropped frames found.\n")
            else:
                f.write("Dropped frames (missing or NaN distance):\n")
                for frame in all_dropped_frames:
                    f.write(f"{frame}\n")
                f.write(f"\nTotal dropped frames: {total_dropped}\n")

        print(f"Dropped frames details saved to {txt_file}")

# Heatmaps of Eye / Proboscis Distance

In [None]:
import os
import glob
import pandas as pd

# Identify each fly folder in the main directory.
fly_folders = [
    path
    for path in (os.path.join(main_directory, f) for f in os.listdir(main_directory))
    if os.path.isdir(path)
]

# Columns carried over into the trimmed copy, in output order.
kept_columns = [
    "frame",
    "timestamp",
    "x_class2",
    "y_class2",
    "x_class6",
    "y_class6",
    "distance_percentage_2_6",
]

for fly_folder in fly_folders:
    # Trimmed copies are written to <fly folder>/Eye_Prob_Dist.
    destination_folder = os.path.join(fly_folder, "Eye_Prob_Dist")
    os.makedirs(destination_folder, exist_ok=True)

    # Recursively collect every CSV whose name ends in "merged.csv".
    # NOTE(review): this also matches "updated_*merged.csv" files sitting in
    # other output folders of the same fly -- confirm that is intended.
    pattern = os.path.join(fly_folder, "**", "*merged.csv")
    csv_files = glob.glob(pattern, recursive=True)

    for csv_file in csv_files:
        source_dir = os.path.dirname(csv_file)
        # Anything already under the destination folder is a previous output.
        if destination_folder in source_dir:
            continue

        try:
            df = pd.read_csv(csv_file)
            df_filtered = df[kept_columns]
            # The copy keeps the source file name, prefixed with "updated_".
            dest_file = os.path.join(destination_folder, "updated_" + os.path.basename(csv_file))
            df_filtered.to_csv(dest_file, index=False)
            print(f"Copied filtered data from {csv_file} to {dest_file}")
        except Exception as e:
            # Best effort: report the failure and carry on with the next file.
            print(f"Failed to process {csv_file}: {e}")


In [72]:
import os
import pandas as pd

# Define the main directory.
main_dir = main_directory


def _ofm_phase(position, first_on, last_on):
    """Classify a position as "before", "during" or "after" the odor window.

    The window is inclusive on both ends: ``first_on <= position <= last_on``
    yields "during".
    """
    if position < first_on:
        return "before"
    if position <= last_on:
        return "during"
    return "after"


# Iterate over each fly folder in the main directory.
for fly_folder in os.listdir(main_dir):
    fly_path = os.path.join(main_dir, fly_folder)
    if not os.path.isdir(fly_path):
        continue

    # Only flies that already have an Eye_Prob_Dist folder are processed.
    cvs_class_dir = os.path.join(fly_path, 'Eye_Prob_Dist')
    if not os.path.isdir(cvs_class_dir):
        continue

    print(f"Processing fly folder: {fly_path}")

    # List the fly folder once instead of once per updated CSV; the loop
    # below only writes inside Eye_Prob_Dist, so this listing stays valid.
    output_csvs = [name for name in os.listdir(fly_path) if name.endswith('.csv')]

    # Process only CSV files in Eye_Prob_Dist that start with "updated_"
    for file in os.listdir(cvs_class_dir):
        if not (file.endswith('.csv') and file.startswith('updated_')):
            continue

        updated_csv_path = os.path.join(cvs_class_dir, file)
        print(f"  Found updated CSV: {file}")

        # The identifier is the first six '_'-separated tokens after "updated_".
        base_name = file[len('updated_'):]
        tokens = base_name.split('_')
        if len(tokens) < 6:
            print(f"    Skipping {file}: not enough tokens to extract identifier.")
            continue
        identifier = '_'.join(tokens[:6])
        print(f"    Identifier extracted: {identifier}")

        # First "output_<identifier>*.csv" in the fly folder, if any.
        output_prefix = "output_" + identifier
        corresponding_output = next(
            (os.path.join(fly_path, name) for name in output_csvs
             if name.startswith(output_prefix)),
            None,
        )

        if corresponding_output is None:
            print(f"    No corresponding output file found for identifier {identifier}.")
            continue

        print(f"    Found corresponding output file: {os.path.basename(corresponding_output)}")

        # Load the output CSV to determine odor on/off rows.
        df_output = pd.read_csv(corresponding_output)
        if "ActiveOFM" not in df_output.columns:
            # Report the column name that is actually looked up.
            print(f"    File {corresponding_output} does not contain 'ActiveOFM' column.")
            continue

        # Odor is considered on wherever the value is anything but "off".
        odor_on_indices = df_output.index[df_output["ActiveOFM"].astype(str) != "off"].tolist()
        if not odor_on_indices:
            print(f"    No odor on detected in {corresponding_output}.")
            continue

        odor_on_first = min(odor_on_indices)
        odor_on_last = max(odor_on_indices)
        print(f"    Odor on from frame {odor_on_first} to {odor_on_last}")

        # Load the updated CSV file to update.
        df_updated = pd.read_csv(updated_csv_path)

        # Use the "Frame" column if available; otherwise, use the DataFrame index.
        # NOTE(review): the filtered CSVs written upstream carry a lower-case
        # "frame" column, so the index fallback is what normally runs --
        # confirm which of the two is intended before changing it.
        if "Frame" in df_updated.columns:
            positions = df_updated["Frame"]
        else:
            positions = df_updated.index
        df_updated["OFM_State"] = [
            _ofm_phase(position, odor_on_first, odor_on_last) for position in positions
        ]

        # Overwrite the updated CSV file with the new column.
        df_updated.to_csv(updated_csv_path, index=False)
        print(f"    Overwritten {file} with updated odor state information.\n")

In [None]:
import os
import shutil

# Define valid month prefixes
MONTHS = [
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december"
]
month_prefixes = tuple(MONTHS)

# Keep only directories whose (case-insensitive) name starts with a month.
fly_folders = []
for folder in os.listdir(main_directory):
    folder_path = os.path.join(main_directory, folder)
    if os.path.isdir(folder_path) and folder.lower().startswith(month_prefixes):
        fly_folders.append(folder_path)

for fly_folder in fly_folders:
    fly_name = os.path.basename(fly_folder)
    print(f"Processing fly: {fly_name}")

    # CSVs are expected in <fly folder>/Eye_Prob_Dist.
    cvs_folder = os.path.join(fly_folder, "Eye_Prob_Dist")
    if not os.path.exists(cvs_folder):
        print(f"  No 'Eye_Prob_Dist' folder found in {fly_name}. Skipping...")
        continue

    # Make sure both target sub-folders exist.
    training_dir = os.path.join(cvs_folder, "training")
    testing_dir = os.path.join(cvs_folder, "testing")
    for target_dir in (training_dir, testing_dir):
        os.makedirs(target_dir, exist_ok=True)

    # Snapshot the CSVs sitting directly in Eye_Prob_Dist before anything moves.
    csv_files = [os.path.join(cvs_folder, f) for f in os.listdir(cvs_folder) if f.endswith('.csv')]

    # Group by the keyword found in the (lower-cased) file name.
    training_files = sorted(f for f in csv_files if "training" in os.path.basename(f).lower())
    testing_files = sorted(f for f in csv_files if "testing" in os.path.basename(f).lower())

    # Move each group into its folder, training first, then testing.
    move_plan = (
        ("  Moving Training CSV Files:", training_dir, training_files),
        ("  Moving Testing CSV Files:", testing_dir, testing_files),
    )
    for heading, target_dir, batch in move_plan:
        print(heading)
        for file in batch:
            destination = os.path.join(target_dir, os.path.basename(file))
            print(f"    Moving {file} to {destination}")
            shutil.move(file, destination)

    print("-" * 40)

In [73]:
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.colors import Normalize
from matplotlib.ticker import FuncFormatter
import os                       # ← already used later; kept for completeness

# --------------------------------------------------------------------------
# 1.  Matplotlib defaults (unchanged)
# --------------------------------------------------------------------------
# Update Matplotlib's global rcParams: serif font, font sizes for labels,
# titles, ticks and legend, line/axes widths, and 300 dpi for both on-screen
# figures and saved files.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'lines.linewidth': 2,
    'axes.linewidth': 1.5,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'legend.fontsize': 12
})

def log_tick_formatter(x, pos):
    """Return the tick label for a log1p-scaled value on its original scale.

    ``x`` is mapped back with ``np.expm1`` and rendered with zero decimal
    places.  ``pos`` is not used; it is accepted so the function fits the
    two-argument ``FuncFormatter`` callback signature.
    """
    return format(np.expm1(x), ".0f")

# --------------------------------------------------------------------------
# 2.  Locate fly folders (unchanged)
# --------------------------------------------------------------------------
import re

fly_folders = [os.path.join(main_directory, f) for f in os.listdir(main_directory)
               if os.path.isdir(os.path.join(main_directory, f))]

# Columns a trial CSV must contain to be plotted.
required_columns = {'timestamp', 'distance_percentage_2_6', 'OFM_State', 'frame'}

# Training trials whose odor duration is taken from the data rather than
# fixed at 30 s.
variable_odor_trials = ["training_5", "training_6", "training_7", "training_8"]


def trial_index(name: str, category: str) -> float:
    """
    Return the numeric trial id that follows '_{category}_' in the filename.

    Falls back to the last number in the name if no match is found, and to
    ``inf`` (sorts last) if the name contains no digits at all.

    The pattern deliberately has no trailing ``\\b``: '_' counts as a word
    character, so a boundary never matches between the trial number and a
    following underscore (e.g. 'training_10_merged'), which used to force the
    fallback and could pick up an unrelated number.
    """
    m = re.search(fr'_{re.escape(category)}_(\d+)', name)
    if m:
        return int(m.group(1))
    # Fallback: use the last number in the string
    nums = re.findall(r'\d+', name)
    return int(nums[-1]) if nums else float('inf')


# --------------------------------------------------------------------------
# 3.  Main loop: one log-scaled heatmap figure per fly and category
# --------------------------------------------------------------------------
for fly_folder in fly_folders:
    fly_name = os.path.basename(fly_folder)
    cvs_folder = os.path.join(fly_folder, "Eye_Prob_Dist")
    if not os.path.exists(cvs_folder):
        continue

    heat_maps_folder = os.path.join(cvs_folder, "heat_maps")
    os.makedirs(heat_maps_folder, exist_ok=True)

    for category in ["training", "testing"]:
        category_folder = os.path.join(cvs_folder, category)
        # A fly may lack one of the category folders; skip it instead of
        # letting os.listdir raise FileNotFoundError.
        if not os.path.isdir(category_folder):
            continue

        # Sort by the trial number, so 1…9 come before 10
        csv_files = sorted(
            (
                os.path.join(category_folder, f) for f in os.listdir(category_folder)
                if f.startswith("updated") and f.endswith(".csv")
            ),
            key=lambda p: trial_index(os.path.basename(p), category),
        )

        if not csv_files:
            continue

        trials = []
        for csv_file in csv_files:
            df = pd.read_csv(csv_file)
            df.columns = df.columns.str.strip()
            if required_columns - set(df.columns):
                continue

            # ------------------------------------------------------------------
            # Timestamp processing
            # ------------------------------------------------------------------
            df['time_seconds'] = df['timestamp'].apply(timestamp_to_seconds)
            df.dropna(subset=['time_seconds'], inplace=True)
            if len(df) < 2:
                continue
            df['relative_time'] = df['time_seconds'] - df['time_seconds'].iloc[0]

            # Odor window = first/last row labelled "during".
            odor_df = df[df["OFM_State"].str.lower() == "during"]
            if odor_df.empty:
                continue
            odor_onset = odor_df['relative_time'].iloc[0]
            odor_offset = odor_df['relative_time'].iloc[-1]

            trial_label = os.path.splitext(os.path.basename(csv_file))[0]
            if category == "testing" or (
                category == "training"
                and not any(x in trial_label for x in variable_odor_trials)
            ):
                odor_duration = 30.0
            else:
                odor_duration = odor_offset - odor_onset

            # Time axis layout: 30 before the odor, the odor itself, 90 after.
            final_total_duration = 30 + odor_duration + 90

            # One slot per frame number between min and max; frames absent
            # from the CSV stay NaN and are drawn with the 'bad' colour.
            total_frames = np.arange(df['frame'].min(), df['frame'].max() + 1)
            full_data = np.full_like(total_frames, np.nan, dtype=float)
            frame_indices = np.searchsorted(total_frames, df['frame'].values)
            full_data[frame_indices] = df['distance_percentage_2_6'].values

            # ------------------------------------------------------------------
            # Flag special values so they render in the out-of-range colour
            # ------------------------------------------------------------------
            data_for_plot = full_data.copy()
            # -1 is moved to -0.5 so log1p stays finite and lands below vmin
            # ('under' colour).  101 needs no change: log1p(101) already
            # exceeds vmax ('over' colour).
            data_for_plot[data_for_plot == -1] = -0.5

            new_time = np.linspace(0, final_total_duration, len(total_frames))

            trials.append({
                'label': trial_label,
                'time': new_time,
                'data': data_for_plot,
                'odor_start': 30,
                'odor_end': 30 + odor_duration
            })

        if not trials:
            continue

        # ------------------------------------------------------------------
        # Copy of viridis with out-of-range values pink and NaNs dark gray
        # ------------------------------------------------------------------
        cmap = mpl.colormaps['viridis'].copy()
        cmap.set_under('pink')
        cmap.set_over('pink')
        cmap.set_bad('dimgray')

        # Normalise on log1p scale (0–100 % before the transform)
        norm = Normalize(vmin=np.log1p(0), vmax=np.log1p(100))

        # ------------------------------------------------------------------
        # Plotting: one single-row heatmap per trial, stacked vertically
        # ------------------------------------------------------------------
        n_trials = len(trials)
        fig, axs = plt.subplots(n_trials, 1, figsize=(18, 2 * n_trials), sharex=True)
        if n_trials == 1:
            axs = [axs]

        legend_handles = [Line2D([0], [0], color='red', linewidth=2.5, linestyle='-', label='Odor Period')]

        for i, trial in enumerate(trials):
            ax = axs[i]
            # pcolormesh needs one more cell edge than there are cells.
            time_edges = np.linspace(trial['time'][0], trial['time'][-1], len(trial['time']) + 1)
            X, Y = np.meshgrid(time_edges, [0, 1])

            data_row = np.log1p(trial['data']).reshape(1, -1)
            pcm = ax.pcolormesh(X, Y, data_row, cmap=cmap, shading='auto', norm=norm)

            ax.set_yticks([])
            ax.tick_params(axis='x', direction='out')
            ax.set_xlim(trial['time'][0], trial['time'][-1])
            ax.set_title(trial['label'], loc='left')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.axvline(trial['odor_start'], color='red', linewidth=2.5)
            ax.axvline(trial['odor_end'], color='red', linewidth=2.5)

        axs[-1].set_xlabel("Time (seconds)")
        fig.suptitle(f"{fly_name} - {category.capitalize()} Trials\nLog-Transformed (log1p) Heatmaps", fontsize=20)
        fig.tight_layout(rect=[0, 0, 1, 0.96])

        # ------------------------------------------------------------------
        # extend='both' adds arrows to the colour-bar for the pink extremes;
        # tick labels are mapped back to the original percentage scale
        # ------------------------------------------------------------------
        cbar = fig.colorbar(pcm, ax=axs, orientation='vertical',
                            fraction=0.02, pad=0.04, extend='both')
        cbar.set_label("Distance Percentage", fontsize=14)
        cbar.ax.yaxis.set_major_formatter(FuncFormatter(log_tick_formatter))
        fig.legend(handles=legend_handles, loc='upper right', frameon=True)

        out_path = os.path.join(heat_maps_folder, f"{fly_name}_{category}_heatmap_log.png")
        # Save and close this specific figure rather than relying on
        # pyplot's notion of the "current" one.
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Heatmap saved for fly {fly_name} in category {category}: {out_path}")

In [None]:
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.colors import Normalize
import os
# --------------------------------------------------------------------------
# Matplotlib defaults
# --------------------------------------------------------------------------
# Update Matplotlib's global rcParams: serif font, font sizes for labels,
# titles, ticks and legend, line/axes widths, and 300 dpi for both on-screen
# figures and saved files.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'lines.linewidth': 2,
    'axes.linewidth': 1.5,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'legend.fontsize': 12
})

# --------------------------------------------------------------------------
# Locate fly folders
# --------------------------------------------------------------------------
import re

fly_folders = [os.path.join(main_directory, f) for f in os.listdir(main_directory)
               if os.path.isdir(os.path.join(main_directory, f))]

# Columns a trial CSV must contain to be plotted.
required_columns = {'timestamp', 'distance_percentage_2_6', 'OFM_State', 'frame'}

# Training trials whose odor duration is taken from the data rather than
# fixed at 30 s.
variable_odor_trials = ["training_5", "training_6", "training_7", "training_8"]


def trial_index(name: str, category: str) -> float:
    """
    Return the numeric trial id that follows '_{category}_' in the filename.

    Falls back to the last number in the name if no match is found, and to
    ``inf`` (sorts last) if the name contains no digits at all.

    The pattern deliberately has no trailing ``\\b``: '_' counts as a word
    character, so a boundary never matches between the trial number and a
    following underscore (e.g. 'training_10_merged'), which used to force the
    fallback and could pick up an unrelated number.
    """
    m = re.search(fr'_{re.escape(category)}_(\d+)', name)
    if m:
        return int(m.group(1))
    # Fallback: use the last number in the string
    nums = re.findall(r'\d+', name)
    return int(nums[-1]) if nums else float('inf')


# --------------------------------------------------------------------------
# Main loop: one linear-scale heatmap figure per fly and category
# --------------------------------------------------------------------------
for fly_folder in fly_folders:
    fly_name = os.path.basename(fly_folder)
    cvs_folder = os.path.join(fly_folder, "Eye_Prob_Dist")
    if not os.path.exists(cvs_folder):
        continue

    heat_maps_folder = os.path.join(cvs_folder, "heat_maps")
    os.makedirs(heat_maps_folder, exist_ok=True)

    for category in ["training", "testing"]:
        category_folder = os.path.join(cvs_folder, category)
        # A fly may lack one of the category folders; skip it instead of
        # letting os.listdir raise FileNotFoundError.
        if not os.path.isdir(category_folder):
            continue

        # Sort by the trial number, so 1…9 come before 10
        csv_files = sorted(
            (
                os.path.join(category_folder, f) for f in os.listdir(category_folder)
                if f.startswith("updated") and f.endswith(".csv")
            ),
            key=lambda p: trial_index(os.path.basename(p), category),
        )

        if not csv_files:
            continue

        trials = []
        for csv_file in csv_files:
            df = pd.read_csv(csv_file)
            df.columns = df.columns.str.strip()
            if required_columns - set(df.columns):
                continue

            # --- timestamp handling ------------------------------------------------
            df['time_seconds'] = df['timestamp'].apply(timestamp_to_seconds)
            df.dropna(subset=['time_seconds'], inplace=True)
            if len(df) < 2:
                continue
            df['relative_time'] = df['time_seconds'] - df['time_seconds'].iloc[0]

            # Odor window = first/last row labelled "during".
            odor_df = df[df["OFM_State"].str.lower() == "during"]
            if odor_df.empty:
                continue
            odor_onset = odor_df['relative_time'].iloc[0]
            odor_offset = odor_df['relative_time'].iloc[-1]

            trial_label = os.path.splitext(os.path.basename(csv_file))[0]
            if category == "testing" or (
                category == "training"
                and not any(x in trial_label for x in variable_odor_trials)
            ):
                odor_duration = 30.0
            else:
                odor_duration = odor_offset - odor_onset

            # Time axis layout: 30 before the odor, the odor itself, 90 after.
            final_total_duration = 30 + odor_duration + 90

            # One slot per frame number between min and max; frames absent
            # from the CSV stay NaN and are drawn with the 'bad' colour.
            total_frames = np.arange(df['frame'].min(), df['frame'].max() + 1)
            full_data = np.full_like(total_frames, np.nan, dtype=float)
            frame_indices = np.searchsorted(total_frames, df['frame'].values)
            full_data[frame_indices] = df['distance_percentage_2_6'].values

            # --- flag sentinel values for the out-of-range colour -----------------
            data_for_plot = full_data.copy()
            # -1 is moved to -0.5, still below vmin ('under' colour).  101
            # needs no change: it already exceeds vmax ('over' colour).
            data_for_plot[data_for_plot == -1] = -0.5

            new_time = np.linspace(0, final_total_duration, len(total_frames))

            trials.append({
                'label': trial_label,
                'time': new_time,
                'data': data_for_plot,
                'odor_start': 30,
                'odor_end': 30 + odor_duration
            })

        if not trials:
            continue

        # --- colormap setup: out-of-range pink, NaNs dark gray --------------------
        cmap = mpl.colormaps['viridis'].copy()
        cmap.set_under('pink')
        cmap.set_over('pink')
        cmap.set_bad('dimgray')

        # linear normalisation 0–100 %
        norm = Normalize(vmin=0, vmax=100)

        # --- plotting: one single-row heatmap per trial ---------------------------
        n_trials = len(trials)
        fig, axs = plt.subplots(n_trials, 1, figsize=(18, 2 * n_trials), sharex=True)
        if n_trials == 1:
            axs = [axs]

        legend_handles = [Line2D([0], [0], color='red', linewidth=2.5, label='Odor Period')]

        for i, trial in enumerate(trials):
            ax = axs[i]
            # pcolormesh needs one more cell edge than there are cells.
            time_edges = np.linspace(trial['time'][0], trial['time'][-1], len(trial['time']) + 1)
            X, Y = np.meshgrid(time_edges, [0, 1])

            data_row = trial['data'].reshape(1, -1)
            pcm = ax.pcolormesh(X, Y, data_row, cmap=cmap, shading='auto', norm=norm)

            ax.set_yticks([])
            ax.tick_params(axis='x', direction='out')
            ax.set_xlim(trial['time'][0], trial['time'][-1])
            ax.set_title(trial['label'], loc='left')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.axvline(trial['odor_start'], color='red', linewidth=2.5)
            ax.axvline(trial['odor_end'], color='red', linewidth=2.5)

        axs[-1].set_xlabel("Time (seconds)")
        fig.suptitle(f"{fly_name} – {category.capitalize()} Trials", fontsize=20)
        fig.tight_layout(rect=[0, 0, 1, 0.96])

        # extend='both' adds arrows to the colour-bar for the pink extremes.
        cbar = fig.colorbar(pcm, ax=axs, orientation='vertical',
                            fraction=0.02, pad=0.04, extend='both')
        cbar.set_label("Distance Percentage (%)", fontsize=14)

        fig.legend(handles=legend_handles, loc='upper right', frameon=True)

        out_path = os.path.join(heat_maps_folder, f"{fly_name}_{category}_heatmap.png")
        # Save and close this specific figure rather than relying on
        # pyplot's notion of the "current" one.
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Heatmap saved for fly {fly_name} in category {category}: {out_path}")

In [74]:
import os
import shutil

# REQUIRED: main_directory must already be defined.
# All heatmap PNGs are gathered under <main_directory>/heat_maps/Eye_Prob_Dist.
global_heatmap_folder = os.path.join(main_directory, "heat_maps")

# Create the global heat_maps folder if missing
os.makedirs(global_heatmap_folder, exist_ok=True)

# Create Eye_Prob_Dist if missing
eye_prob_dist_dir = os.path.join(global_heatmap_folder, "Eye_Prob_Dist")
os.makedirs(eye_prob_dist_dir, exist_ok=True)

abs_dest_root = os.path.abspath(global_heatmap_folder)

for root, _, files in os.walk(main_directory, topdown=True):
    # 1) Skip the destination tree to avoid copying into itself.  Compare on
    #    a path-component boundary: a bare string-prefix test would also skip
    #    sibling folders such as "<main_directory>/heat_maps_old".
    abs_root = os.path.abspath(root)
    if abs_root == abs_dest_root or abs_root.startswith(abs_dest_root + os.sep):
        continue

    # 2) Only gather from source heat_maps folders
    if "heat_maps" not in root:
        continue

    for file in files:
        # Only PNG files with "heatmap" in their name are collected.
        if not (file.endswith(".png") and "heatmap" in file):
            continue

        src_path = os.path.join(root, file)
        dest_path = os.path.join(eye_prob_dist_dir, file)

        # Skip if source already equals destination (extra safety)
        if os.path.abspath(src_path) == os.path.abspath(dest_path):
            continue

        # Overwrite existing files with the same name
        shutil.copy2(src_path, dest_path)
        print(f"Copied (overwritten if existed): {src_path} → {dest_path}")


# Envelope RMS

In [None]:
import os
import glob
import pandas as pd

# Identify each fly folder in the main directory.
fly_folders = [
    path
    for path in (os.path.join(main_directory, f) for f in os.listdir(main_directory))
    if os.path.isdir(path)
]

# Columns carried over into the trimmed copy, in output order.
kept_columns = [
    "frame",
    "timestamp",
    "x_class2",
    "y_class2",
    "x_class6",
    "y_class6",
    "distance_percentage_2_6",
]

for fly_folder in fly_folders:
    # Trimmed copies are written to <fly folder>/RMS_calculations.
    destination_folder = os.path.join(fly_folder, "RMS_calculations")
    os.makedirs(destination_folder, exist_ok=True)

    # Recursively collect every CSV whose name ends in "merged.csv".
    # NOTE(review): this also matches "updated_*merged.csv" files sitting in
    # other output folders of the same fly -- confirm that is intended.
    pattern = os.path.join(fly_folder, "**", "*merged.csv")
    csv_files = glob.glob(pattern, recursive=True)

    for csv_file in csv_files:
        source_dir = os.path.dirname(csv_file)
        # Anything already under the destination folder is a previous output.
        if destination_folder in source_dir:
            continue

        try:
            df = pd.read_csv(csv_file)
            df_filtered = df[kept_columns]
            # The copy keeps the source file name, prefixed with "updated_".
            dest_file = os.path.join(destination_folder, "updated_" + os.path.basename(csv_file))
            df_filtered.to_csv(dest_file, index=False)
            print(f"Copied filtered data from {csv_file} to {dest_file}")
        except Exception as e:
            # Best effort: report the failure and carry on with the next file.
            print(f"Failed to process {csv_file}: {e}")


In [None]:
import os
import pandas as pd

# Define the main directory.
main_dir = main_directory


def _ofm_phase(position, first_on, last_on):
    """Classify a position as "before", "during" or "after" the odor window.

    The window is inclusive on both ends: ``first_on <= position <= last_on``
    yields "during".
    """
    if position < first_on:
        return "before"
    if position <= last_on:
        return "during"
    return "after"


# Iterate over each fly folder in the main directory.
for fly_folder in os.listdir(main_dir):
    fly_path = os.path.join(main_dir, fly_folder)
    if not os.path.isdir(fly_path):
        continue

    # Only flies that already have an RMS_calculations folder are processed.
    cvs_class_dir = os.path.join(fly_path, 'RMS_calculations')
    if not os.path.isdir(cvs_class_dir):
        continue

    print(f"Processing fly folder: {fly_path}")

    # List the fly folder once instead of once per updated CSV; the loop
    # below only writes inside RMS_calculations, so this listing stays valid.
    output_csvs = [name for name in os.listdir(fly_path) if name.endswith('.csv')]

    # Process only CSV files in RMS_calculations that start with "updated_"
    for file in os.listdir(cvs_class_dir):
        if not (file.endswith('.csv') and file.startswith('updated_')):
            continue

        updated_csv_path = os.path.join(cvs_class_dir, file)
        print(f"  Found updated CSV: {file}")

        # The identifier is the first six '_'-separated tokens after "updated_".
        base_name = file[len('updated_'):]
        tokens = base_name.split('_')
        if len(tokens) < 6:
            print(f"    Skipping {file}: not enough tokens to extract identifier.")
            continue
        identifier = '_'.join(tokens[:6])
        print(f"    Identifier extracted: {identifier}")

        # First "output_<identifier>*.csv" in the fly folder, if any.
        output_prefix = "output_" + identifier
        corresponding_output = next(
            (os.path.join(fly_path, name) for name in output_csvs
             if name.startswith(output_prefix)),
            None,
        )

        if corresponding_output is None:
            print(f"    No corresponding output file found for identifier {identifier}.")
            continue

        print(f"    Found corresponding output file: {os.path.basename(corresponding_output)}")

        # Load the output CSV to determine odor on/off rows.
        df_output = pd.read_csv(corresponding_output)
        if "ActiveOFM" not in df_output.columns:
            # Report the column name that is actually looked up.
            print(f"    File {corresponding_output} does not contain 'ActiveOFM' column.")
            continue

        # Odor is considered on wherever the value is anything but "off".
        odor_on_indices = df_output.index[df_output["ActiveOFM"].astype(str) != "off"].tolist()
        if not odor_on_indices:
            print(f"    No odor on detected in {corresponding_output}.")
            continue

        odor_on_first = min(odor_on_indices)
        odor_on_last = max(odor_on_indices)
        print(f"    Odor on from frame {odor_on_first} to {odor_on_last}")

        # Load the updated CSV file to update.
        df_updated = pd.read_csv(updated_csv_path)

        # Use the "Frame" column if available; otherwise, use the DataFrame index.
        # NOTE(review): the filtered CSVs written upstream carry a lower-case
        # "frame" column, so the index fallback is what normally runs --
        # confirm which of the two is intended before changing it.
        if "Frame" in df_updated.columns:
            positions = df_updated["Frame"]
        else:
            positions = df_updated.index
        df_updated["OFM_State"] = [
            _ofm_phase(position, odor_on_first, odor_on_last) for position in positions
        ]

        # Overwrite the updated CSV file with the new column.
        df_updated.to_csv(updated_csv_path, index=False)
        print(f"    Overwritten {file} with updated odor state information.\n")

In [None]:
#!/usr/bin/env python3
# fly_distance_histograms.py — v12b (Notebook-ready): Segment-wise RMS ratio
# Accepts `time_seconds` or `timestamp` (numeric/string, HH:MM:SS(:MS), or ISO datetime)

from __future__ import annotations

import re
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Root folder scanned for fly sub-folders; reuses the notebook-level
# `main_directory` defined in an earlier cell.
DEFAULT_MAIN_DIRECTORY = main_directory

# Maps each accepted (stripped, lower-cased) OFM-state spelling to its
# canonical phase name: "before", "during" or "after".
PHASE_ALIASES: Dict[str, str] = {
    "before": "before", "pre": "before", "baseline": "before",
    "during": "during", "on": "during", "odor_on": "during",
    "after": "after", "post": "after", "odor_off": "after",
}

# Candidate measure-column names, in order of preference.
MEASURE_COLS    = ["distance_percentage_2_6", "distance_percentage"]
# Outline colour applied to every bar.
BAR_EDGE_COLOR  = "black"
# Figure size handed to plt.subplots by both chart functions.
FIGSIZE         = (9, 5)
# Ratios below this value are drawn red, others green; also the dashed reference line.
THRESHOLD       = 1.1
# Frame rate assumed when it cannot be estimated from the time column.
FPS_DEFAULT     = 40
# Default rolling-window length (seconds) for compute_burst_metrics.
WINDOW_SEC      = 0.5
# Figure output folder, relative to each fly folder.
OUT_FIG_DIR     = "RMS_calculations/histograms"
# Finds the "testing_<digits>" trial label inside a CSV file stem.
TESTING_REGEX   = re.compile(r"testing_\d+")

# -----------------------------------------------------------------------------
# HELPERS
# -----------------------------------------------------------------------------

def _resolve_ofm_column(df: pd.DataFrame) -> str:
    """Return the OFM-state column name found in ``df``.

    ``"OFM State"`` is preferred over ``"OFM_State"`` when both exist.

    Raises:
        KeyError: neither spelling is a column of ``df``.
    """
    match = next(
        (name for name in ("OFM State", "OFM_State") if name in df.columns),
        None,
    )
    if match is None:
        raise KeyError(f"OFM State column not found: {list(df.columns)}")
    return match

def _resolve_measure_column(df: pd.DataFrame) -> str:
    """Return the first name from ``MEASURE_COLS`` that is a column of ``df``.

    Raises:
        KeyError: none of the candidate names is present.
    """
    available = [name for name in MEASURE_COLS if name in df.columns]
    if not available:
        raise KeyError(f"Measure column not found. Expected {MEASURE_COLS}, got {list(df.columns)}")
    return available[0]

def _normalize_state(val) -> str:
    """Translate a raw OFM-state string into its canonical phase name.

    The value is stripped and lower-cased before the ``PHASE_ALIASES`` lookup.

    Raises:
        TypeError: ``val`` is not a string.
        ValueError: the normalised value is not a known alias.
    """
    if isinstance(val, str):
        canonical = PHASE_ALIASES.get(val.strip().lower())
        if canonical is None:
            raise ValueError(f"Unknown OFM state '{val}' encountered")
        return canonical
    raise TypeError(f"State value is not a string: {val}")

def _extract_trial_label(stem: str) -> str:
    """Return the ``TESTING_REGEX`` match inside ``stem``, or ``stem`` itself if none."""
    found = TESTING_REGEX.search(stem)
    if found is None:
        return stem
    return found.group(0)

def _resolve_time_seconds(df: pd.DataFrame) -> pd.Series:
    """
    Return elapsed seconds since the first parsable sample as a float Series.

    Accepted inputs, tried in this order:
      - ``time_seconds``: numeric or numeric-looking strings.
      - ``timestamp``: numeric or numeric-looking strings (seconds or
        milliseconds), anything ``pd.to_datetime`` can parse, or clock strings
        ``HH:MM:SS`` / ``HH:MM:SS:MS``.

    Every path measures time from the first *valid* value, so an unparsable
    leading row stays NaN instead of turning the whole result into NaN.

    Raises:
        ValueError: ``time_seconds`` exists but is not (mostly) numeric.
        KeyError: neither column exists, or ``timestamp`` cannot be parsed.
    """
    def _coerce_numeric(series: pd.Series) -> pd.Series | None:
        """Return ``series`` as floats if at least 80% (and >= 3) values are numeric."""
        num = pd.to_numeric(series, errors="coerce")
        if num.notna().sum() >= max(3, int(0.8 * len(num))):
            return num.astype(float)
        return None

    def _from_first_valid(values: pd.Series) -> pd.Series:
        """Shift ``values`` so the first non-NaN entry becomes 0."""
        # Callers guarantee at least one non-NaN value, so iloc[0] is safe.
        return values - values.dropna().iloc[0]

    if "time_seconds" in df.columns:
        ts = _coerce_numeric(df["time_seconds"])
        if ts is None:
            raise ValueError("time_seconds present but not numeric/numeric-like.")
        return _from_first_valid(ts)

    if "timestamp" in df.columns:
        col = df["timestamp"]

        # 1) Numeric-like first (handles "0.0", "25", etc.)
        num = _coerce_numeric(col)
        if num is not None:
            ts = _from_first_valid(num)
            # Infer ms vs s from the median step between samples.
            dt = ts.diff().median()
            if pd.isna(dt) or dt <= 0:
                # Bad or zero steps: assume milliseconds to be safe.
                ts = ts / 1000.0
            elif dt > 5:  # Typical ms steps (e.g., 25 for 40 FPS)
                ts = ts / 1000.0
            return ts

        # 2) pandas datetime parsing (errors=coerce turns failures into NaT)
        dt_series = pd.to_datetime(col, errors="coerce", utc=False)
        if dt_series.notna().any():
            base = dt_series[dt_series.notna()].iloc[0]
            ts = (dt_series - base).dt.total_seconds()
            # Fill NaT rows from their neighbours. ffill()/bfill() replace the
            # deprecated fillna(method=...) spelling.
            return ts.ffill().bfill()

        # 3) Custom clock-format parsing
        def _parse_clock(s) -> float:
            """Parse one clock string into seconds; NaN when it cannot be parsed."""
            s = str(s).strip()
            if s == "" or s.lower() == "nan":
                return np.nan
            if s.count(":") == 3:  # HH:MM:SS:MS
                try:
                    h, m, sec, ms = s.split(":")
                    return int(h)*3600 + int(m)*60 + float(sec) + float(ms)/1000.0
                except Exception:
                    return np.nan
            if s.count(":") == 2:  # HH:MM:SS
                try:
                    h, m, sec = s.split(":")
                    return int(h)*3600 + int(m)*60 + float(sec)
                except Exception:
                    return np.nan
            # Fallback: plain float seconds if it looks like "0.0"
            try:
                return float(s)
            except Exception:
                return np.nan

        ts = col.map(_parse_clock)
        if ts.notna().sum() == 0:
            raise KeyError("Unable to parse 'timestamp' into seconds.")
        return _from_first_valid(ts)

    raise KeyError("Neither 'time_seconds' nor 'timestamp' column found.")

def compute_burst_metrics(series: pd.Series, fps: float, win_s: float = WINDOW_SEC) -> float:
    """Return the largest centred rolling RMS of ``series``.

    The window spans ``win_s`` seconds converted to frames via ``fps``
    (never fewer than one frame).
    """
    window = max(int(round(win_s * fps)), 1)
    mean_square = series.pow(2).rolling(window, center=True).mean()
    rolling_rms = np.sqrt(mean_square)
    return rolling_rms.max(skipna=True)

def compute_segment_rms(series: pd.Series) -> float:
    """Return the root-mean-square of ``series`` as a plain float.

    Raises:
        ValueError: ``series`` has no elements.
    """
    if not len(series):
        raise ValueError("Empty series for segment RMS")
    squared = series ** 2
    mean_square = np.mean(squared)
    return float(np.sqrt(mean_square))

# -----------------------------------------------------------------------------
# PLOTTING
# -----------------------------------------------------------------------------

def plot_grouped_bar_chart(
    fly_name: str,
    data: List[Tuple[str, Dict[str, float]]],
    out_path: Path
):
    """Save a bar chart of per-trial "during" and "after" ratios for one fly.

    ``data`` holds ``(trial_label, {"during": ..., "after": ...})`` pairs.
    Bars below ``THRESHOLD`` are red, the rest green; every bar is annotated
    with its height and a dashed line marks ``THRESHOLD``. The parent folder
    of ``out_path`` is created if needed.
    """
    def _threshold_colors(heights):
        return ["red" if h < THRESHOLD else "green" for h in heights]

    trial_labels = [trial for trial, _ in data]
    during_heights = [ratios["during"] for _, ratios in data]
    after_heights = [ratios["after"] for _, ratios in data]
    positions = np.arange(len(trial_labels))

    fig, ax = plt.subplots(figsize=FIGSIZE)
    # "during" bars sit left of the "after" bars within each trial group.
    ax.bar(
        positions - 0.35,
        during_heights,
        width=0.35,
        edgecolor=BAR_EDGE_COLOR,
        color=_threshold_colors(during_heights),
    )
    ax.bar(
        positions + 0.00,
        after_heights,
        width=0.35,
        edgecolor=BAR_EDGE_COLOR,
        color=_threshold_colors(after_heights),
    )

    # Write each bar's height just above it.
    for patch in ax.patches:
        height = patch.get_height()
        centre = patch.get_x() + patch.get_width()/2
        ax.text(centre, height+0.02, f"{height:.2f}", ha="center", va="bottom")

    ax.axhline(THRESHOLD, color="gray", linestyle="--", linewidth=1)
    ax.set_xticks(positions)
    ax.set_xticklabels(trial_labels, rotation=45, ha="right")
    ax.set_ylabel("Mean ratio (segment ∕ baseline)")
    ax.set_title(f"Fly: {fly_name} — During & After Ratios")

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

def plot_rms_ratio_chart(
    fly_name: str,
    data: List[Tuple[str, float]],
    out_path: Path
):
    """Save a bar chart with one RMS-ratio bar per trial for one fly.

    ``data`` holds ``(trial_label, rms_ratio)`` pairs. Every bar is annotated
    with its height. The parent folder of ``out_path`` is created if needed.
    """
    if data:
        labels, values = zip(*data)
    else:
        labels, values = [], []

    fig, ax = plt.subplots(figsize=FIGSIZE)
    ax.bar(labels, values, edgecolor=BAR_EDGE_COLOR)

    # Write each bar's height just above it.
    for patch in ax.patches:
        height = patch.get_height()
        centre = patch.get_x() + patch.get_width()/2
        ax.text(centre, height+0.02, f"{height:.2f}", ha="center", va="bottom")

    ax.set_ylabel("RMS ratio ((d+a)/b)")
    ax.set_title(f"Fly: {fly_name} — RMS Ratio per Trial")

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

# -----------------------------------------------------------------------------
# MAIN PIPELINE
# -----------------------------------------------------------------------------

def process_fly(fly_folder: Path, *, verbose: bool = False):
    """Compute per-trial odor-response ratios for one fly and save two charts.

    Reads every ``*testing*.csv`` in ``fly_folder / "RMS_calculations"``. For
    each trial the measure column is limited to values in [0, 100], divided by
    the mean of the "before" phase, and summarised as:
      - the mean ratio during and after the odor, and
      - ``(rms_during + rms_after) / rms_before``.
    Charts are written under ``fly_folder / OUT_FIG_DIR``.

    A trial is skipped (with a message when ``verbose``) if it has no valid
    samples, lacks any of the three phases, or has a baseline mean that is not
    positive. Without these guards an empty phase raised ``ValueError`` and
    aborted the whole run, and a zero baseline produced inf/NaN ratios.

    Args:
        fly_folder: Folder of a single fly.
        verbose: Print per-trial results and skip reasons.

    Raises:
        KeyError: a CSV lacks the measure or OFM-state column.
        TypeError, ValueError: an OFM-state value is not a known string alias.
    """
    # Only process flies that actually have RMS_calculations
    testing_dir = fly_folder / "RMS_calculations"
    csv_paths = sorted(testing_dir.glob("*testing*.csv"))
    if not csv_paths:
        return  # silent skip (avoids noisy 'invalid month prefix' messages)

    # FPS is only reported in the verbose log, so an unusable time column in
    # the first CSV falls back to FPS_DEFAULT rather than stopping the run.
    df0 = pd.read_csv(csv_paths[0])
    try:
        dt = _resolve_time_seconds(df0).diff().median()
    except (KeyError, ValueError):
        dt = float("nan")
    fps_est = (1.0 / dt) if (pd.notna(dt) and dt > 0) else FPS_DEFAULT

    ratio_data: List[Tuple[str, Dict[str, float]]] = []
    rms_ratio_data: List[Tuple[str, float]] = []
    for path in csv_paths:
        df = pd.read_csv(path)
        trial = _extract_trial_label(path.stem)
        meas = _resolve_measure_column(df)
        state = _resolve_ofm_column(df)

        # Ensure numeric and drop out-of-range values (>100% or <0)
        df[meas] = pd.to_numeric(df[meas], errors="coerce")
        clean = df.loc[df[meas].between(0, 100, inclusive="both")].copy()
        if clean.empty:
            if verbose:
                print(f"{trial}: no valid {meas} values in [0, 100]; skipped.")
            continue

        clean["phase"] = clean[state].apply(_normalize_state)
        by_phase = {
            name: clean.loc[clean["phase"] == name, meas]
            for name in ("before", "during", "after")
        }

        # Every phase is needed: "before" for the baseline, the others for the
        # ratios. compute_segment_rms raises on an empty segment.
        missing = [name for name, values in by_phase.items() if values.empty]
        if missing:
            if verbose:
                print(f"{trial}: no valid samples for phase(s) {missing}; skipped.")
            continue

        mean_base = by_phase["before"].mean()
        if not mean_base > 0:
            # Dividing by a zero baseline would give inf/NaN ratios.
            if verbose:
                print(f"{trial}: baseline mean is {mean_base}; skipped.")
            continue

        ratios = {name: values / mean_base for name, values in by_phase.items()}

        d_val = ratios["during"].mean()
        a_val = ratios["after"].mean()
        ratio_data.append((trial, {"during": d_val, "after": a_val}))

        rms_b = compute_segment_rms(ratios["before"])
        rms_d = compute_segment_rms(ratios["during"])
        rms_a = compute_segment_rms(ratios["after"])
        ratio_rms = (rms_d + rms_a) / rms_b
        rms_ratio_data.append((trial, ratio_rms))

        if verbose:
            print(f"{trial}: fps≈{fps_est:.2f}, mean_d={d_val:.3f}, mean_a={a_val:.3f}, rms_ratio={ratio_rms:.3f}")

    out_dir = fly_folder / OUT_FIG_DIR
    plot_grouped_bar_chart(fly_folder.name, ratio_data, out_dir / "odor_response_bar.png")
    plot_rms_ratio_chart(fly_folder.name, rms_ratio_data, out_dir / "rms_ratio.png")

# -----------------------------------------------------------------------------
# AUTO-RUN FOR JUPYTER
# -----------------------------------------------------------------------------
# Run the histogram pipeline on every sub-folder of the main directory.
root = Path(DEFAULT_MAIN_DIRECTORY)
for fly in root.iterdir():
    if not fly.is_dir():
        continue
    process_fly(fly, verbose=True)

In [None]:
#!/usr/bin/env python3
# fly_envelope_over_time.py — analytic envelope via Hilbert transform over time per trial

import re
from pathlib import Path
from typing import Optional, Union
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import hilbert

# ───────────────────────────────── CONFIG ───────────────────────────────
# NOTE: several names below (OUT_FIG_DIR, WINDOW_SEC, TESTING_REGEX, ...) are
# also assigned by the histogram cell; running this cell rebinds them.
# Root folder scanned for fly sub-folders (notebook-level `main_directory`).
DEFAULT_MAIN_DIRECTORY = main_directory
# Figure output folder, relative to each fly folder.
OUT_FIG_DIR           = "RMS_calculations/envelope_over_time_plots"
# Frame rate used for the window size and for the fallback time axis.
FPS_DEFAULT           = 40
# Smoothing window length in seconds.
WINDOW_SEC            = 0.25
# The same window expressed in frames (at least 1).
WINDOW_FRAMES         = max(int(WINDOW_SEC * FPS_DEFAULT), 1)
# Candidate measure-column names, in order of preference.
MEASURE_COLS    = ["distance_percentage_2_6", "distance_percentage"]
# Captures the numeric trial index in group 1 so trials can be sorted numerically.
TESTING_REGEX         = re.compile(r"testing_(\d+)")

# Odor timing (seconds)
ODOR_ON_S  = 30.0
ODOR_OFF_S = 60.0

# ─────────────────────────────── HELPERS ────────────────────────────────
def _resolve_measure_column(df: pd.DataFrame) -> Optional[str]:
    """Return the first ``MEASURE_COLS`` entry present in ``df``, or ``None``."""
    for candidate in MEASURE_COLS:
        if candidate in df.columns:
            return candidate
    return None

def _extract_trials(testing_dir: Path):
    """Yield ``(label, csv_path)`` for each ``*testing*.csv``, ordered by trial number.

    The label is the ``TESTING_REGEX`` match in the file stem, or the stem
    itself when there is no match; unmatched files sort last.
    """
    def _trial_order(csv_file: Path):
        hit = TESTING_REGEX.search(csv_file.stem)
        if hit is None:
            return float('inf')  # no trial number: send to the end
        return int(hit.group(1))

    ordered = list(testing_dir.glob("*testing*.csv"))
    ordered.sort(key=_trial_order)

    for csv_path in ordered:
        hit = TESTING_REGEX.search(csv_path.stem)
        yield (hit.group(0) if hit else csv_path.stem), csv_path

def _compute_envelope(series: pd.Series, win_frames: int) -> pd.Series:
    """
    Return the smoothed analytic (Hilbert) envelope of ``series``.

    - Non-numeric values and values outside [0, 100] count as missing.
    - Missing samples are interpolated only so the Hilbert transform has no
      gaps; their envelope values are set back to NaN afterwards.
    - The result is a centred rolling mean over ``win_frames`` samples that
      ignores NaNs (``min_periods=1``).
    - An all-NaN Series is returned when there is no valid sample.
    """
    numeric = pd.to_numeric(series, errors="coerce")
    in_range = numeric.between(0, 100, inclusive="both")

    # Nothing usable at all.
    if not in_range.any():
        return pd.Series(np.nan, index=numeric.index)

    # Gap-free copy for the transform; `in_range` remembers the real samples.
    filled = numeric.where(in_range).interpolate(limit_direction="both")

    # Still all-NaN after interpolation (pathological): give up.
    if filled.isna().all():
        return pd.Series(np.nan, index=numeric.index)

    magnitude = np.abs(hilbert(filled.to_numpy()))

    # Keep envelope values only where the original sample was valid.
    envelope = pd.Series(magnitude, index=numeric.index).where(in_range)

    smoothed = envelope.rolling(window=win_frames, center=True, min_periods=1).mean()
    return smoothed

def _rolling_rms_valid(series: pd.Series, win_frames: int) -> pd.Series:
    """
    Return the centred rolling RMS of ``series`` over ``win_frames`` samples.

    - Non-numeric values and values outside [0, 100] become NaN first.
    - The rolling mean ignores NaNs (``min_periods=1``), so edge positions
      still get a value.
    """
    numeric = pd.to_numeric(series, errors="coerce")
    in_range = numeric.between(0, 100, inclusive="both")
    squared = numeric.where(in_range).pow(2)
    mean_square = squared.rolling(window=win_frames, center=True, min_periods=1).mean()
    return mean_square.pow(0.5)

# -----------------------------------------------------------------------------
# PLOTTING
# -----------------------------------------------------------------------------

def plot_envelope_subplots(
    fly_name: str,
    trials_data: dict[str, tuple[np.ndarray, np.ndarray, float]],
    out_path: Path,
    y_max: float
):
    """
    Plot the envelope over time for each trial as stacked subplots and save the figure.

    Each subplot shows the envelope trace, a red marker at that trial's own
    maximum, a horizontal line at the trial's threshold, and the odor window
    (ODOR_ON_S to ODOR_OFF_S) as dashed lines plus a shaded span.

    Args:
        fly_name: Used in the figure title and log messages.
        trials_data: Maps trial label -> (time_s, env_vals, thr), where thr is
            the y-value of the horizontal threshold line.
        out_path: Destination image; its parent folder is created if missing.
        y_max: Upper y-limit shared by all subplots (padded by 2%); 1.0 is
            used when it is not positive.

    Side effects: updates the global matplotlib rcParams (dpi 300), writes
    the image and prints a status line. Does nothing but warn when
    ``trials_data`` is empty.
    """
    n = len(trials_data)
    if n == 0:
        print(f"[WARN] {fly_name}: no trials to plot.")
        return

    # Small headroom above the tallest envelope so peak markers are not clipped.
    padded_max = y_max * 1.02 if y_max > 0 else 1.0

    plt.rcParams.update({"figure.dpi": 300, "savefig.dpi": 300})
    fig, axes = plt.subplots(n, 1, figsize=(10, 2.5 * n), sharex=True)
    # plt.subplots returns a bare Axes (not an array) for a single row.
    if n == 1:
        axes = [axes]

    for ax, (label, (time_s, env_vals, thr)) in zip(axes, trials_data.items()):
        ax.plot(time_s, env_vals, linewidth=1, clip_on=False)

        # Peak marker only if we have finite values
        if np.any(np.isfinite(env_vals)):
            idx = np.nanargmax(env_vals)
            ax.plot(time_s[idx], env_vals[idx], marker='o', markersize=10, color='red', zorder=5)

        # Per-trial threshold line.
        ax.axhline(thr, linestyle='-', linewidth=1, label='θ = μ_before + 4σ')

        # Odor on/off markers + shaded interval
        ax.axvline(ODOR_ON_S,  linestyle='--', linewidth=1)
        ax.axvline(ODOR_OFF_S, linestyle='--', linewidth=1)
        ax.axvspan(ODOR_ON_S, ODOR_OFF_S, alpha=0.15)

        # Axes cosmetics: fix the shared y-range and switch off autoscaling
        ax.set_ylim(0, padded_max, auto=False)
        ax.autoscale(enable=False, axis="y")
        ax.margins(x=0, y=0)
        ax.set_ylabel("Envelope")
        ax.set_title(label)
        ax.grid(True)

        # Legend (compact): built from proxy artists rather than the plotted ones
        peak_handle = plt.Line2D([0], [0], marker='o', color='red', linestyle='None', markersize=6, label='Peak')
        on_handle   = plt.Line2D([0], [0], linestyle='--', label='Odor on/off')
        span_handle = plt.Rectangle((0,0), 1, 1, alpha=0.15, label='Odor on window')
        thr_handle  = plt.Line2D([0], [0], linestyle='-', label='θ = μ_before + 4σ')
        ax.legend(handles=[peak_handle, thr_handle, on_handle, span_handle], loc='upper right', frameon=True, fontsize=8)

    # The x-axis is shared, so only the bottom subplot is labelled.
    axes[-1].set_xlabel("Time (s)")
    fig.suptitle(
        f"{fly_name}: Analytic Envelope Over Time (window={WINDOW_SEC}s; odor {ODOR_ON_S:.0f}–{ODOR_OFF_S:.0f}s)",
        y=0.98
    )
    # Leave the top 5% of the figure free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)
    print(f"[OK] {out_path}")

# -----------------------------------------------------------------------------
# WORKFLOW
# -----------------------------------------------------------------------------

def process_fly_envelope(fly_folder: Path):
    """
    Compute the analytic envelope of every trial of one fly and plot them.

    For each ``*testing*.csv`` in ``fly_folder / "RMS_calculations"``:
      - build the time axis from ``time_seconds`` (or frame index / FPS_DEFAULT),
      - compute the smoothed envelope of the measure column,
      - add a rolling-RMS column and overwrite the CSV in place,
      - derive the threshold θ = μ + 4σ of the envelope before ODOR_ON_S.
    All trials are then drawn as subplots sharing a y-axis from 0 to the
    fly's largest finite envelope value.

    Trials without a measure column are reported and skipped. A trial whose
    envelope is entirely NaN yields NaN statistics and does not affect the
    shared y-limit.
    """
    fly_name   = fly_folder.name
    testing_dir = fly_folder / "RMS_calculations"
    if not testing_dir.is_dir():
        print(f"[WARN] {fly_name}: no testing directory.")
        return

    trials_data: dict[str, tuple[np.ndarray, np.ndarray, float]] = {}
    all_max = 0.0
    nan = float('nan')

    for label, csv_path in _extract_trials(testing_dir):
        df = pd.read_csv(csv_path)

        # time axis
        if "time_seconds" in df.columns:
            time_s = df["time_seconds"].to_numpy(dtype=float)
        else:
            time_s = np.arange(len(df)) / FPS_DEFAULT

        # measurement series
        meas_col = _resolve_measure_column(df)
        if meas_col is None:
            print(f"[ERROR] {fly_name} {label}: no measure column.")
            continue

        # Envelope; the helper coerces to numeric and excludes out-of-range
        # samples itself, so the raw column is passed straight through.
        env_vals = _compute_envelope(df[meas_col], WINDOW_FRAMES).to_numpy()

        # Rolling RMS column (excludes >100% and <0)
        rms_col = f"{meas_col}_rms_win{WINDOW_FRAMES}"
        df[rms_col] = _rolling_rms_valid(df[meas_col], WINDOW_FRAMES)

        # Persist back to the same CSV (overwrite in place)
        df.to_csv(csv_path, index=False)

        # Threshold from the pre-odor window (time < ODOR_ON_S)
        pre_mask = time_s < ODOR_ON_S
        if np.any(pre_mask):
            pre_vals = env_vals[pre_mask]
        else:
            # Fallback: use the first ODOR_ON_S seconds worth of frames by FPS
            n_pre = max(int(ODOR_ON_S * FPS_DEFAULT), 1)
            pre_vals = env_vals[:n_pre]

        if pre_vals.size == 0:
            mu_before = sd_before = 0.0
        elif np.isnan(pre_vals).all():
            # Same NaN result nanmean/nanstd would give, without the
            # "empty slice" RuntimeWarning.
            mu_before = sd_before = nan
        else:
            mu_before = float(np.nanmean(pre_vals))
            sd_before = float(np.nanstd(pre_vals, ddof=0))
        thr = mu_before + 4.0 * sd_before

        # NaN marks "no usable peak" for an empty or all-NaN envelope.
        if env_vals.size and not np.isnan(env_vals).all():
            trial_max = float(np.nanmax(env_vals))
        else:
            trial_max = nan

        print(f"[DEBUG] {fly_name} {label} peak={0.0 if not np.isfinite(trial_max) else trial_max:.3f}  μ_before={mu_before:.3f}  σ_before={sd_before:.3f}  θ={thr:.3f}")

        trials_data[label] = (time_s, env_vals, thr)
        if np.isfinite(trial_max) and (trial_max > all_max):
            all_max = trial_max

    print(f"[DEBUG] {fly_name} global peak envelope = {all_max:.3f}")

    out_dir  = fly_folder / OUT_FIG_DIR
    out_path = out_dir / f"{fly_name}_envelope_over_time_subplots.png"
    plot_envelope_subplots(fly_name, trials_data, out_path, y_max=all_max)

def run_envelope_over_time(main_directory: Optional[Union[Path, str]] = None):
    """Run ``process_fly_envelope`` on every sub-folder of the main directory.

    Falls back to ``DEFAULT_MAIN_DIRECTORY`` when ``main_directory`` is
    omitted or empty. The path is user-expanded and resolved first.
    """
    chosen = main_directory if main_directory else DEFAULT_MAIN_DIRECTORY
    root = Path(chosen).expanduser().resolve()
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        process_fly_envelope(entry)

# Entry point: process every fly folder under DEFAULT_MAIN_DIRECTORY when this
# code runs as the main module.
if __name__ == "__main__":
    run_envelope_over_time()

In [None]:
import shutil

def collect_all_plots(main_directory: Union[str, Path], dest_folder: str = "all_envelope_plots"):
    """
    Copy each fly's ``<fly>_envelope_over_time_subplots.png`` into one folder.

    The plot is looked up at ``<fly folder>/OUT_FIG_DIR`` and copied, keeping
    its name, into ``main_directory/dest_folder`` (created if needed). The
    number of copied files is printed.
    """
    root = Path(main_directory).expanduser().resolve()
    dest = root / dest_folder
    dest.mkdir(parents=True, exist_ok=True)

    copied = 0
    for fly in root.iterdir():
        if fly.is_dir():
            plot_name = f"{fly.name}_envelope_over_time_subplots.png"
            source = fly / OUT_FIG_DIR / plot_name  # expected location of plot
            if source.is_file():
                shutil.copy2(source, dest / plot_name)
                copied += 1

    print(f"[OK] Collected {copied} plots into {dest}")

# Entry point: regenerate every envelope plot, then copy them all into a
# single folder under DEFAULT_MAIN_DIRECTORY.
if __name__ == "__main__":
    run_envelope_over_time()  # generate plots
    collect_all_plots(DEFAULT_MAIN_DIRECTORY)  # gather them


# Proboscis Angle

In [None]:
# JUPYTER CELL — Angle at point2 with per-fly centering using distance_percentage==0
# - Uncentered plots: fly_dir/angle_plots/
# - Centered plots:   fly_dir/angle_plots/angle_plots_centered/
# Input match: *merged.csv (inside month-named subfolders of each fly folder)
# Output CSV:  the same *merged.csv, overwritten in place
#              (adds/updates angle_deg_c2_26_vs_anchor, angle_centered_deg)

import os, re, glob
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt

# ───────── prerequisites ─────────
# Rebind the notebook-level `main_directory` as a Path and fail early if it
# does not point at a directory.
main_directory = Path(main_directory)
assert main_directory.is_dir(), f"Not a directory: {main_directory}"

# ───────── config ─────────
# Point 1 (arbitrary anchor). If you prefer strict in-bounds indexing use 1079.0.
ANCHOR_X, ANCHOR_Y = 1080.0, 540.0
# Input and output suffixes are identical, so each CSV is overwritten in place.
IN_SUFFIX  = "merged.csv"
OUT_SUFFIX = "merged.csv"
FPS = 40.0  # Frames per second (adjust if needed)

# Month names for folder matching (case-insensitive); kept as a tuple so it
# can be passed directly to str.startswith.
MONTHS = ("january","february","march","april","may","june",
          "july","august","september","october","november","december")

def is_month_folder(p: Path) -> bool:
    """Return True if ``p`` is a directory whose name starts with a month name (any case)."""
    if not p.is_dir():
        return False
    return p.name.lower().startswith(MONTHS)

def compute_angle_deg_at_point2(df: pd.DataFrame) -> pd.Series:
    """
    Return the angle at the class2 point, in degrees within [0, 180], per row.

    The angle lies between the vectors
      u = anchor - class2   (anchor = ANCHOR_X, ANCHOR_Y)
      v = class6 - class2
    and is computed as atan2(|u×v|, u·v). Rows where either vector has zero
    length or the products are not finite yield NaN.

    Returns a Series named "angle_ARB_deg" aligned to ``df.index``.

    Raises:
        ValueError: a required coordinate column is missing.
    """
    needed = ("x_class2", "y_class2", "x_class6", "y_class6")
    absent = [col for col in needed if col not in df.columns]
    if absent:
        raise ValueError(f"Missing columns: {absent}")

    # Vertex (class2) and far point (class6)
    vertex_x = df["x_class2"].astype("float64")
    vertex_y = df["y_class2"].astype("float64")
    far_x = df["x_class6"].astype("float64")
    far_y = df["y_class6"].astype("float64")

    # u points from the vertex to the anchor, v from the vertex to class6
    ux, uy = ANCHOR_X - vertex_x, ANCHOR_Y - vertex_y
    vx, vy = far_x - vertex_x, far_y - vertex_y

    dot = ux*vx + uy*vy
    cross = ux*vy - uy*vx

    nonzero = (np.hypot(ux, uy) > 0) & (np.hypot(vx, vy) > 0)
    usable = nonzero & np.isfinite(dot) & np.isfinite(cross)

    # Start from NaN and fill in only the rows that can be evaluated.
    angle_deg = np.full(len(df), np.nan, dtype="float64")
    angle_deg[usable.to_numpy()] = np.degrees(
        np.arctan2(np.abs(cross[usable]), dot[usable])  # [0, π] before conversion
    )

    return pd.Series(angle_deg, index=df.index, name="angle_ARB_deg")

def get_time_axis(df: pd.DataFrame):
    """Return the time axis in seconds as a NumPy array for plotting.

    Uses ``time_s`` if present, else ``frame`` / FPS, else row number / FPS.
    """
    if "time_s" in df.columns:
        seconds = df["time_s"].astype(float)
    elif "frame" in df.columns:
        seconds = df["frame"].astype(float) / FPS
    else:
        return np.arange(len(df)) / FPS
    return seconds.to_numpy()

def require_distance_col(df: pd.DataFrame) -> str:
    """Return the first distance-percentage column name present in ``df``.

    Raises:
        ValueError: none of the known column names is present.
    """
    candidates = ("distance_percentage", "distance_percentage_2_6", "distance_pct")
    found = next((name for name in candidates if name in df.columns), None)
    if found is None:
        raise ValueError("Could not find a distance percentage column (tried: "
                         "'distance_percentage', 'distance_percentage_2_6', 'distance_pct').")
    return found

def find_fly_reference_angle(csv_paths):
    """
    Pick the reference angle for a fly from all of its CSVs.

    Only rows with a finite angle and a finite distance are considered, so the
    reference can never be NaN (a NaN reference would make every centered
    angle NaN). Among those rows:
    - prefer a row where the distance percentage is exactly 0 (the first such
      row of the file),
    - otherwise take the row with the smallest |distance percentage|.
    Across files the best (priority, |distance|) wins; on a tie the file seen
    first in ``csv_paths`` is kept.

    Returns (ref_angle, info_string).

    Raises:
        RuntimeError: no file contains a usable row.
        ValueError: a CSV lacks the coordinate or distance columns.
    """
    best = None  # tuple (priority, abs_val, angle_deg, file, idx, time_s)

    for p in csv_paths:
        df = pd.read_csv(p)
        angle = compute_angle_deg_at_point2(df).to_numpy(dtype=float)
        dist_col = require_distance_col(df)
        dist = df[dist_col].astype(float).to_numpy()

        usable = np.isfinite(angle) & np.isfinite(dist)
        if not usable.any():
            continue

        # Unusable rows get +inf so argmin can never select them. argmin
        # returns the first minimum, i.e. the first exact zero when one exists.
        absdist = np.where(usable, np.abs(dist), np.inf)
        idx = int(np.argmin(absdist))
        absval = float(absdist[idx])
        priority = 0 if absval == 0 else 1

        tvals = get_time_axis(df)
        candidate = (priority, absval, float(angle[idx]), p, idx, float(tvals[idx]))
        # Compare on (priority, |distance|) only; comparing whole tuples would
        # let ties be decided by the angle value or the file path.
        if best is None or candidate[:2] < best[:2]:
            best = candidate

    if best is None:
        raise RuntimeError("No suitable reference frame found for this fly.")
    priority, absval, ref_angle, ref_file, ref_idx, ref_time = best
    how = "distance_percentage == 0" if priority == 0 else f"min |distance_percentage| = {absval:.6g}"
    msg = f"Ref: {how} at {Path(ref_file).name}[row {ref_idx}] t={ref_time:.3f}s → ref_angle={ref_angle:.3f}°"
    return ref_angle, msg

# ───────── traversal ─────────
processed, plotted_unc, plotted_ctr, failed = 0, 0, 0, 0
errors = []


def _save_angle_plot(t, values, png_path, *, ylabel, title, zero_line=False):
    """Plot ``values`` against ``t`` and save to ``png_path``.

    The figure is closed even when plotting or saving raises, so a failing
    file does not leave an open figure behind.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        ax.plot(t, values, linewidth=1.5)
        if zero_line:
            ax.axhline(0, linestyle="--", linewidth=1.0, alpha=0.6)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.25)
        fig.tight_layout()
        fig.savefig(png_path, dpi=150)
    finally:
        plt.close(fig)


for fly_dir in sorted(p for p in main_directory.iterdir() if p.is_dir()):
    # Gather all candidate CSVs for this fly (under subfolders that start with a month).
    # A set removes duplicates that arise when one month folder is nested in
    # another; sorting makes the processing order reproducible.
    month_folders = [sub for sub in fly_dir.rglob("*") if is_month_folder(sub)]
    found = set()
    for month_folder in month_folders:
        pattern = str(month_folder / "**" / f"*{IN_SUFFIX}")
        found.update(Path(hit) for hit in glob.iglob(pattern, recursive=True))
    csv_paths = sorted(p for p in found if p.name.endswith(IN_SUFFIX))

    if not csv_paths:
        continue

    # Compute per-fly reference angle
    try:
        ref_angle, ref_info = find_fly_reference_angle(csv_paths)
        print(f"[{fly_dir.name}] {ref_info}")
    except Exception as e:
        failed += 1
        errors.append(f"{fly_dir}: {e}")
        continue

    # Ensure per-fly plot folders
    fly_plot_dir_unc = fly_dir / "angle_plots"
    fly_plot_dir_ctr = fly_dir / "angle_plots" / "angle_plots_centered"
    fly_plot_dir_unc.mkdir(parents=True, exist_ok=True)
    fly_plot_dir_ctr.mkdir(parents=True, exist_ok=True)

    # Process each CSV with centering
    for csv_path in csv_paths:
        out_path = csv_path.with_name(csv_path.name.replace(IN_SUFFIX, OUT_SUFFIX, 1))
        try:
            df = pd.read_csv(csv_path)
            df["angle_deg_c2_26_vs_anchor"] = compute_angle_deg_at_point2(df)
            df["angle_centered_deg"] = df["angle_deg_c2_26_vs_anchor"] - ref_angle
            df.to_csv(out_path, index=False)
            processed += 1

            # Time axis
            t = get_time_axis(df)

            # File-safe plot name (flatten relative path)
            relative_parts = csv_path.relative_to(fly_dir).with_suffix("").parts
            safe_name = "_".join(relative_parts)

            # —— Uncentered plot ——
            _save_angle_plot(
                t,
                df["angle_deg_c2_26_vs_anchor"].to_numpy(),
                fly_plot_dir_unc / f"{safe_name}_angle_ARB.png",
                ylabel="Angle (degrees)",
                title=f"{csv_path.stem} — angle@class2",
            )
            plotted_unc += 1

            # —— Centered plot ——
            _save_angle_plot(
                t,
                df["angle_centered_deg"].to_numpy(),
                fly_plot_dir_ctr / f"{safe_name}_angle_ARB_centered.png",
                ylabel="Centered angle (deg)",
                title=f"{csv_path.stem} — centered by fly ref (angle@class2 − {ref_angle:.2f}°)",
                zero_line=True,
            )
            plotted_ctr += 1

        except Exception as e:
            # Best effort: record the failure and carry on with the next file.
            failed += 1
            errors.append(f"{csv_path}: {e}")

print(f"Done. CSVs updated: {processed} | Uncentered plots: {plotted_unc} | Centered plots: {plotted_ctr} | Failed: {failed}")
if errors:
    print("\nErrors:")
    for msg in errors[:12]:
        print(" •", msg)
    if len(errors) > 12:
        print(f" … and {len(errors)-12} more")

In [None]:
# JUPYTER CELL — Combined heatmaps of centered angle percentages per fly
# Produces, for each fly (OUT_HEAT_DIR is relative to the fly folder):
#   fly_dir/angle_plots/angle_plots_centered_percentage_heatmaps/<fly>_training_angle_centered_pct_heatmap.png
#   fly_dir/angle_plots/angle_plots_centered_percentage_heatmaps/<fly>_testing_angle_centered_pct_heatmap.png
#
# Inputs (preferred): *merged.csv (with angle_centered_pct)

import os, glob
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.colors import Normalize

# ───────── prerequisites ─────────
# `main_directory` must already exist in the notebook namespace; it is rebound
# as a Path and checked to be a directory.
assert 'main_directory' in globals(), "Define main_directory = '/path/to/root' before running."
main_directory = Path(main_directory)
assert main_directory.is_dir(), f"Not a directory: {main_directory}"

# ───────── display defaults (similar to your snippet) ─────────
# NOTE: rcParams are global, so these settings also affect later plots.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'lines.linewidth': 2,
    'axes.linewidth': 1.5,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'legend.fontsize': 12
})

# ───────── config ─────────
# Frame rate used to turn frame numbers / row numbers into seconds.
FPS = 40.0
# Glob patterns for the raw and the angle-annotated CSVs (currently identical).
IN_SUFFIX_RAW = "*merged.csv"
IN_SUFFIX_ANG = "*merged.csv"

# Make this RELATIVE (no leading slash) and a Path
OUT_HEAT_DIR = Path("angle_plots/angle_plots_centered_percentage_heatmaps")
ANCHOR_X, ANCHOR_Y = 1080.0, 540.0   # anchor used if we must recompute angles

# Lower-case month names; a tuple so it can be passed to str.startswith.
MONTHS = ("january","february","march","april","may","june",
          "july","august","september","october","november","december")

def is_month_folder(p: Path) -> bool:
    """Tell whether ``p`` is a directory named after a month (prefix match, any case)."""
    folder_name = p.name.lower()
    return p.is_dir() and folder_name.startswith(MONTHS)

def infer_category_from_path(p: Path) -> str | None:
    parts = [s.lower() for s in p.parts]
    if "testing" in parts:
        return "testing"
    if "training" in parts:
        return "training"
    # If not explicit, try filename
    name = p.name.lower()
    if "testing" in name:
        return "testing"
    if "training" in name:
        return "training"
    return None

def get_time_axis(df: pd.DataFrame) -> np.ndarray:
    """Return seconds per row: ``time_s`` if present, else ``frame`` / FPS, else row number / FPS."""
    columns = df.columns
    if "time_s" in columns:
        seconds = df["time_s"].astype(float).to_numpy()
    elif "frame" in columns:
        seconds = (df["frame"].astype(float) / FPS).to_numpy()
    else:
        seconds = np.arange(len(df)) / FPS
    return seconds

def require_distance_col(df: pd.DataFrame) -> str:
    """Return the first known distance-percentage column found in ``df``.

    Raises:
        ValueError: none of the known column names is present.
    """
    known = ("distance_percentage", "distance_percentage_2_6", "distance_pct")
    present = [name for name in known if name in df.columns]
    if present:
        return present[0]
    raise ValueError("No distance percentage column found (expected one of: "
                     "'distance_percentage','distance_percentage_2_6','distance_pct').")

def compute_angle_deg_at_point2(df: pd.DataFrame) -> pd.Series:
    """Unsigned angle (degrees, 0..180) at the class-2 point for every row.

    The angle is measured at (x_class2, y_class2) between the ray towards the
    fixed anchor (ANCHOR_X, ANCHOR_Y) and the ray towards (x_class6, y_class6).
    Rows where either ray has zero length, or where the inputs are not finite,
    yield NaN.

    Returns:
        A Series named ``angle_ARB_deg`` that shares *df*'s index.

    Raises:
        ValueError: if any of the four coordinate columns is missing.
    """
    needed = ("x_class2", "y_class2", "x_class6", "y_class6")
    missing = [name for name in needed if name not in df.columns]
    if missing:
        raise ValueError(f"Missing columns for angle computation: {missing}")

    vertex_x = df["x_class2"].astype(float)
    vertex_y = df["y_class2"].astype(float)
    tip_x = df["x_class6"].astype(float)
    tip_y = df["y_class6"].astype(float)

    # Rays leaving the vertex: one to the anchor, one to the class-6 point.
    to_anchor_x = ANCHOR_X - vertex_x
    to_anchor_y = ANCHOR_Y - vertex_y
    to_tip_x = tip_x - vertex_x
    to_tip_y = tip_y - vertex_y

    dot = to_anchor_x * to_tip_x + to_anchor_y * to_tip_y
    cross = to_anchor_x * to_tip_y - to_anchor_y * to_tip_x
    anchor_len = np.hypot(to_anchor_x, to_anchor_y)
    tip_len = np.hypot(to_tip_x, to_tip_y)

    usable = (anchor_len > 0) & (tip_len > 0) & np.isfinite(dot) & np.isfinite(cross)

    # atan2(|cross|, dot) gives the unsigned angle without dividing by norms.
    radians = np.full(len(df), np.nan)
    radians[usable.to_numpy()] = np.arctan2(np.abs(cross[usable]), dot[usable])
    return pd.Series(np.degrees(radians), index=df.index, name="angle_ARB_deg")

def find_fly_reference_angle(csv_paths_raw: list[Path]) -> float:
    """Pick the per-fly reference angle (degrees).

    Each CSV contributes at most one candidate:

    * priority 0 – the first row whose distance percentage is exactly 0;
    * priority 1 – otherwise, the row whose |distance percentage| is smallest.

    Candidates are compared as ``(priority, abs_distance, angle)`` tuples and
    the smallest wins. Files with no finite distance values contribute
    nothing. Returns NaN when no file yields a candidate; the returned angle
    may itself be NaN if it is not finite at the chosen row.
    """
    def _angle_at(series: pd.Series, pos: int) -> float:
        value = series.iloc[pos]
        return float(value) if np.isfinite(value) else np.nan

    best = None  # (priority, abs_val, angle_deg)
    for csv_path in csv_paths_raw:
        frame = pd.read_csv(csv_path)
        angles = compute_angle_deg_at_point2(frame)
        distances = frame[require_distance_col(frame)].astype(float).to_numpy()

        zero_rows = np.flatnonzero(distances == 0)
        if zero_rows.size > 0:
            candidate = (0, 0.0, _angle_at(angles, int(zero_rows[0])))
        else:
            with np.errstate(invalid="ignore"):
                magnitudes = np.abs(distances)
            if magnitudes.size == 0 or not np.isfinite(magnitudes).any():
                continue
            nearest = int(np.nanargmin(magnitudes))
            candidate = (1, float(magnitudes[nearest]), _angle_at(angles, nearest))

        if best is None or candidate < best:
            best = candidate

    return np.nan if best is None else best[2]

def compute_fly_max_abs_centered(csv_paths_raw: list[Path], ref_angle: float) -> float:
    """Largest finite |angle - ref_angle| (degrees) across all of a fly's raw CSVs.

    Args:
        csv_paths_raw: CSV files to scan; each is read with pandas and passed
            through compute_angle_deg_at_point2().
        ref_angle: Reference angle subtracted from every row.

    Returns:
        The maximum absolute centered angle, or NaN when no file yields a
        finite, strictly positive value (e.g. empty list, NaN ``ref_angle``).

    A file that cannot be read or lacks the coordinate columns is skipped so
    one bad trial does not abort the fly, but the skip is reported instead of
    being swallowed silently.
    """
    fly_max = 0.0
    for p in csv_paths_raw:
        try:
            df = pd.read_csv(p)
            ang = compute_angle_deg_at_point2(df)
            magnitudes = np.abs((ang - ref_angle).to_numpy(dtype=float))
        except Exception as exc:
            print(f"[WARN] compute_fly_max_abs_centered: skipping {p}: {exc}")
            continue
        # Filter first so an all-NaN file does not trigger nanmax's
        # "All-NaN slice" RuntimeWarning.
        finite = magnitudes[np.isfinite(magnitudes)]
        if finite.size:
            fly_max = max(fly_max, float(finite.max()))
    return fly_max if fly_max > 0 else np.nan

def get_odor_window_if_available(df: pd.DataFrame) -> tuple[float, float] | None:
    """
    If 'OFM State' exists, return (onset_time_s, offset_time_s) relative to the file start,
    based on the first/last 'during' segment. Otherwise None.
    """
    cols = {c.lower(): c for c in df.columns}
    if "ofm state" in cols:
        col = cols["ofm state"]
        try:
            # Build time_s
            t = get_time_axis(df)
            mask = df[col].astype(str).str.lower().eq("during").to_numpy()
            if mask.any():
                idx = np.flatnonzero(mask)
                return float(t[idx[0]] - t[0]), float(t[idx[-1]] - t[0])
        except Exception:
            return None
    return None

def _persist_angle_columns(path: Path, df_raw: pd.DataFrame, ang: pd.Series, ref_angle: float, fly_max_abs: float) -> Path:
    """Save *df_raw* plus three derived angle columns as CSV and return the path written.

    Columns added (``df_raw`` itself is not modified):

    * ``angle_ARB_deg``      – the raw angle;
    * ``angle_centered_deg`` – angle minus ``ref_angle`` (all NaN if the
      reference is not finite);
    * ``angle_centered_pct`` – centered angle as a percentage of
      ``fly_max_abs`` (all zeros if that maximum is not finite and positive).

    The destination is ``<stem>.csv`` next to *path*, so a ``.csv`` input is
    overwritten in place.
    """
    n_rows = len(ang)

    if np.isfinite(ref_angle):
        centered_deg = (ang - ref_angle).to_numpy()
    else:
        centered_deg = np.full(n_rows, np.nan)

    scalable = np.isfinite(fly_max_abs) and fly_max_abs > 0
    if scalable:
        centered_pct = (centered_deg / float(fly_max_abs)) * 100.0
    else:
        centered_pct = np.zeros(n_rows, dtype=float)

    augmented = df_raw.copy()
    augmented["angle_ARB_deg"] = ang.to_numpy()
    augmented["angle_centered_deg"] = centered_deg
    augmented["angle_centered_pct"] = centered_pct

    destination = path.with_name(f"{path.stem}.csv")
    augmented.to_csv(destination, index=False)
    return destination

# ───────── collect flies ─────────
# One pass per fly folder: locate the merged CSVs under month-named subfolders,
# load (or compute and persist) the centered-angle percentage for every trial,
# then save one stacked heatmap per category (training / testing).
fly_dirs = [p for p in main_directory.iterdir() if p.is_dir()]

for fly_dir in sorted(fly_dirs):
    fly_name = fly_dir.name

    out_dir = fly_dir / OUT_HEAT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)  # create the full tree if needed
    # Sanity check that the directory really exists before plotting into it.
    assert out_dir.is_dir(), f"Failed to create output directory: {out_dir}"

    # Find candidate CSVs beneath month-named subfolders
    month_folders = [sub for sub in fly_dir.rglob("*") if is_month_folder(sub)]
    raw_csvs, ang_csvs = [], []
    for month_folder in month_folders:
        raw_csvs += [Path(x) for x in glob.iglob(str(month_folder / "**" / f"*{IN_SUFFIX_RAW}"), recursive=True)]
        ang_csvs += [Path(x) for x in glob.iglob(str(month_folder / "**" / f"*{IN_SUFFIX_ANG}"), recursive=True)]

    # Group by category (training/testing)
    grouped = {"training": [], "testing": []}

    # "angle" entries are queued first so a CSV that already contains
    # angle_centered_pct is used as-is; the later "raw" entry for the same
    # label is then dropped by the duplicate check below.
    for p in sorted(ang_csvs):
        cat = infer_category_from_path(p) or "testing"  # default to testing if ambiguous
        grouped[cat].append(("angle", p))

    # Include raw CSVs as well (to compute if the angle column is missing)
    for p in sorted(raw_csvs):
        cat = infer_category_from_path(p) or "testing"
        grouped[cat].append(("raw", p))

    # Compute per-fly reference & max if we need to compute from raw
    need_compute = any(kind == "raw" for pairs in grouped.values() for kind, _ in pairs)
    ref_angle = np.nan
    fly_max_abs = np.nan
    if need_compute:
        ref_angle = find_fly_reference_angle(raw_csvs) if raw_csvs else np.nan
        fly_max_abs = compute_fly_max_abs_centered(raw_csvs, ref_angle) if np.isfinite(ref_angle) else np.nan

    for category in ("training", "testing"):
        # Build trial rows (time, centered %, odor window, label)
        trials = []
        seen_labels = set()

        for kind, path in grouped[category]:
            try:
                if kind == "angle":
                    df = pd.read_csv(path)
                    if "angle_centered_pct" not in df.columns:
                        # Not augmented yet: skip this entry and let the
                        # matching "raw" entry (if any) compute the column.
                        continue
                    # Named time_axis (not `time`) so this notebook-level
                    # variable does not clobber the imported `time` module.
                    time_axis = get_time_axis(df)
                    pct = df["angle_centered_pct"].astype(float).to_numpy()
                    # Try to recover odor window from sibling raw CSV if available
                    odor = None
                    raw_guess = path.with_name(path.name.replace(IN_SUFFIX_ANG, IN_SUFFIX_RAW))
                    if raw_guess.exists():
                        df_raw = pd.read_csv(raw_guess)
                        odor = get_odor_window_if_available(df_raw)
                    label = path.stem
                else:
                    # kind == "raw" → compute centered %
                    df_raw = pd.read_csv(path)
                    ang = compute_angle_deg_at_point2(df_raw)
                    centered = ang - ref_angle if np.isfinite(ref_angle) else ang * 0.0
                    if np.isfinite(fly_max_abs) and fly_max_abs > 0:
                        pct = (centered.to_numpy() / fly_max_abs) * 100.0
                    else:
                        pct = np.zeros(len(centered), dtype=float)
                    time_axis = get_time_axis(df_raw)
                    odor = get_odor_window_if_available(df_raw)

                    # Persist the new columns to disk. This runs before the
                    # duplicate check, so the CSV is refreshed even when the
                    # trial row itself was already taken from the "angle" entry.
                    try:
                        out_csv_path = _persist_angle_columns(path, df_raw, ang, ref_angle, fly_max_abs)
                        print(f"Augmented CSV written → {out_csv_path}")
                    except Exception as write_err:
                        # Non-fatal: keep plotting, but report the failed write.
                        print(f"[WARN] [{fly_name}] could not write augmented CSV for {path.name}: {write_err}")

                    label = path.stem

                # Avoid duplicates if both angle+raw contributed the same trial
                if label in seen_labels:
                    continue
                seen_labels.add(label)

                # Pack trial row
                trials.append({
                    "label": label,
                    "time": time_axis,
                    "pct": pct,
                    "odor": odor
                })
            except Exception as file_err:
                # Skip problematic files so the grid still builds, but say which.
                print(f"[WARN] [{fly_name}] skipping {path.name}: {file_err}")
                continue

        if not trials:
            continue

        # One subplot row per trial; each row keeps its own time edges, so no
        # resampling onto a common grid is needed.
        n_trials = len(trials)
        fig_h = max(2, int(2 * n_trials))  # figure height grows with trial count
        fig, axs = plt.subplots(n_trials, 1, figsize=(18, fig_h), sharex=False)
        if n_trials == 1:
            axs = [axs]

        # Colormap: diverging, centered at 0, with fixed range -100..+100
        cmap = mpl.colormaps['coolwarm'].copy()
        cmap.set_bad('dimgray')
        norm = Normalize(vmin=-100, vmax=100)

        used_odor_line = False
        # Last mesh drawn for THIS figure. Reset per figure so the colorbar
        # never picks up a mesh from a previous category/fly, and so it can be
        # skipped when every row was unusable.
        pcm = None

        for i, tr in enumerate(trials):
            ax = axs[i]
            t = tr["time"]
            data_row = tr["pct"]

            # Guard against empty/malformed sequences
            if len(t) < 2 or len(data_row) != len(t):
                ax.text(0.5, 0.5, f"Skipped: {tr['label']}", transform=ax.transAxes,
                        ha='center', va='center')
                ax.set_yticks([])
                continue

            # Build edges for pcolormesh (N samples → N+1 evenly spaced edges)
            time_edges = np.linspace(t[0], t[-1], len(t) + 1)
            X, Y = np.meshgrid(time_edges, [0, 1])
            row = np.asarray(data_row, dtype=float).reshape(1, -1)

            pcm = ax.pcolormesh(X, Y, row, cmap=cmap, shading='auto', norm=norm)

            # Odor window (if available)
            if tr["odor"] is not None:
                odor_start, odor_end = tr["odor"]
                ax.axvline(odor_start, color='red', linewidth=2.5)
                ax.axvline(odor_end,   color='red', linewidth=2.5)
                used_odor_line = True

            ax.set_yticks([])
            ax.set_xlim(t[0], t[-1])
            ax.set_title(tr["label"], loc='left')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.tick_params(axis='x', direction='out')

        axs[-1].set_xlabel("Time (seconds)")
        fig.suptitle(f"{fly_name} – {category.capitalize()} Trials (Centered Angle %)", fontsize=20)
        fig.tight_layout(rect=[0, 0, 1, 0.96])

        # Colorbar (only when at least one row was actually drawn)
        if pcm is not None:
            cbar = fig.colorbar(pcm, ax=axs, orientation='vertical', fraction=0.02, pad=0.04)
            cbar.set_label("Centered Angle (%)", fontsize=14)

        # Legend (odor period)
        if used_odor_line:
            legend_handles = [Line2D([0], [0], color='red', linewidth=2.5, label='Odor Period')]
            fig.legend(handles=legend_handles, loc='upper right', frameon=True)

        out_png = out_dir / f"{fly_name}_{category}_angle_centered_pct_heatmap.png"
        # Save/close this specific figure rather than pyplot's "current" one.
        fig.savefig(out_png, bbox_inches='tight')
        plt.close(fig)

        print(f"[{fly_name}] {category.capitalize()} heatmap saved → {out_png}")

# Antenna Eye Distance

In [None]:
# --------------------------------------------------------------
# Combine per-pair CSVs into one with only the requested columns.
# Looks for the sfx2 file in: <fly>/Eye_Prob_Dist/testing/
# Saves to: <fly>/Eye_Antenna_Dist/<input name with sfx1 replaced by out_sfx>
# --------------------------------------------------------------
from pathlib import Path
import pandas as pd
import re

# REQUIRED: define main_directory as a Path or string before running
main_directory = Path(main_directory)

# File-name suffixes: sfx1 selects the first input of each pair, sfx2 the file
# matched to it under EPD_SUBPATH, and out_sfx names the merged output.
# All three are currently the same string, so the .replace(sfx1, ...) calls in
# _find_matching_sfx2() and process_pair() leave names unchanged.
sfx1    = "merged.csv"
sfx2    = "merged.csv"
out_sfx = "merged.csv"

# Path components, relative to each fly folder, of the directory searched for
# the sfx2 file.
EPD_SUBPATH = ("Eye_Prob_Dist", "testing")  # where to find sfx2

# Columns read from the first file of the pair (passed as usecols in process_pair).
cols1 = ["frame", "x_class1", "y_class1"]
# Columns requested from the second file (used by _read_df2_robust).
cols2 = ["frame", "x_class2", "y_class2", "x_class6", "y_class6", "distance_percentage", "timestamp", "OFM_State"]

def _find_matching_sfx2(p1_name: str, p2_dir: Path) -> Path | None:
    # 1) Exact same basename with swapped suffix
    candidate = p2_dir / p1_name.replace(sfx1, sfx2)
    if candidate.exists():
        return candidate
    # 2) Match by testing/training index if present
    m = re.search(r"(testing|training)_(\d+)", p1_name)
    if m:
        n = m.group(2)
        matches = list(p2_dir.glob(f"*testing_{n}*{sfx2}"))
        if len(matches) == 1:
            return matches[0]
    # 3) Single-file fallback
    all_sfx2 = list(p2_dir.glob(f"*{sfx2}"))
    if len(all_sfx2) == 1:
        return all_sfx2[0]
    return None

def _read_df2_robust(p2: Path) -> pd.DataFrame:
    """Read the cols2 columns from *p2*, tolerating missing optional columns.

    ``frame``, ``x_class2`` and ``y_class2`` are required. Every other column
    listed in cols2 is optional: it is read when the file has it and added as
    an all-NA column when it does not. The previous fallback re-read only the
    three required columns, which discarded optional columns that *were*
    present as soon as any single one was missing.

    Raises:
        ValueError: if a required column is absent from the file.
    """
    required = ["frame", "x_class2", "y_class2"]

    # Header-only read: find out which columns exist without loading data.
    available = pd.read_csv(p2, nrows=0).columns
    missing_required = [c for c in required if c not in available]
    if missing_required:
        raise ValueError(f"{p2}: missing required columns {missing_required}")

    present = [c for c in cols2 if c in available]
    df2 = pd.read_csv(p2, usecols=present)
    for c in cols2:
        if c not in df2.columns:
            df2[c] = pd.NA
    return df2

def process_pair(p1: Path, fly_root: Path):
    """Merge *p1* with its matching sfx2 file and write the result for one fly.

    The class-1 columns (cols1) are read from *p1*; the class-2 columns come
    from the file that _find_matching_sfx2() locates under
    ``<fly_root>/<EPD_SUBPATH>``. The two are inner-joined on ``frame``,
    sorted by frame, and written to
    ``<fly_root>/Eye_Antenna_Dist/<p1 name with sfx1 replaced by out_sfx>``.

    Every failure (missing directory, no match, unreadable input) is reported
    with a printed warning and the pair is skipped, so one bad file does not
    abort the batch. Returns None.
    """
    p2_dir = fly_root.joinpath(*EPD_SUBPATH)
    if not p2_dir.is_dir():
        print(f"⚠️  {fly_root.name}: missing {p2_dir.relative_to(fly_root)}")
        return

    p2 = _find_matching_sfx2(p1.name, p2_dir)
    if p2 is None:
        print(f"⚠️  {p1.relative_to(main_directory)}: no matching {sfx2} in {p2_dir.relative_to(fly_root)}")
        return

    try:
        df1 = pd.read_csv(p1, usecols=cols1)
    except Exception as e:
        print(f"⚠️  Failed reading {p1}: {e}")
        return

    # Same treatment as df1: a bad second file skips this pair only.
    try:
        df2 = _read_df2_robust(p2)
    except Exception as e:
        print(f"⚠️  Failed reading {p2}: {e}")
        return

    merged = pd.merge(df1, df2, on="frame", how="inner").sort_values("frame")

    out_dir = fly_root / "Eye_Antenna_Dist"
    out_dir.mkdir(exist_ok=True)
    out_csv = out_dir / p1.name.replace(sfx1, out_sfx)

    merged.to_csv(out_csv, index=False)
    print(f"✓ {out_csv.relative_to(fly_root)}  ⇐ merged with {p2.relative_to(fly_root)}")

# Walk each fly folder; find every file ending in sfx1 and merge it with its
# sfx2 counterpart from <fly>/Eye_Prob_Dist/testing.
for fly_folder in [p for p in main_directory.iterdir() if p.is_dir()]:
    # Snapshot (and sort) the matches before processing: process_pair() writes
    # new CSVs inside this same tree, and a lazy rglob could otherwise pick up
    # files created during the walk, in a filesystem-dependent order.
    for p1 in sorted(fly_folder.rglob(f"*{sfx1}")):
        process_pair(p1, fly_folder)

In [None]:
# JUPYTER CELL — Add distance_class1_class2 in place (no testing/training subfolders assumed)
from pathlib import Path
import numpy as np
import pandas as pd

def compute_distance_column(df: pd.DataFrame) -> pd.DataFrame:
    """Add ``distance_class1_class2`` (Euclidean distance between the two points).

    The four coordinate columns are coerced to numeric in place (unparseable
    entries become NaN) and the new column is written onto *df* itself; the
    same object is returned for convenience.

    Raises:
        ValueError: if any coordinate column is missing.
    """
    coord_cols = ["x_class1", "y_class1", "x_class2", "y_class2"]
    absent = [name for name in coord_cols if name not in df.columns]
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    for name in coord_cols:
        df[name] = pd.to_numeric(df[name], errors="coerce")

    delta_x = df["x_class1"] - df["x_class2"]
    delta_y = df["y_class1"] - df["y_class2"]
    df["distance_class1_class2"] = np.sqrt(delta_x * delta_x + delta_y * delta_y)
    return df

def process_main_directory(main_directory: str | Path) -> list[Path]:
    """
    Find CSVs under <fly>/Eye_Antenna_Dist/*combined.csv (no testing/training folders),
    add 'distance_class1_class2' if missing, and overwrite in place.
    """
    root = Path(main_directory)
    assert root.is_dir(), f"Not a directory: {root}"

    # Relative recursive search (avoids absolute glob error)
    pattern = "Eye_Antenna_Dist/*combined.csv"
    csv_paths = sorted(root.rglob(pattern))

    outputs: list[Path] = []
    for csv_path in csv_paths:
        try:
            df = pd.read_csv(csv_path)
            if "distance_class1_class2" not in df.columns:
                df = compute_distance_column(df)
                df.to_csv(csv_path, index=False)  # overwrite in place
                outputs.append(csv_path)
            else:
                # Already present; skip to avoid unnecessary writes
                pass
        except Exception as e:
            print(f"[WARN] Skipped {csv_path}: {e}")
    return outputs

# Run the in-place update over the whole experiment directory and report how
# many CSVs were rewritten (files that already had the column are not counted).
outputs = process_main_directory(main_directory)
print(f"Updated {len(outputs)} files.")

In [None]:
# JUPYTER CELL — Normalize distance_class1_class2 to percentage per fly
# Bottom 5% of values (global) are trimmed; min = smallest remaining; max = global (untrimmed).

from pathlib import Path
import pandas as pd
import numpy as np

# --- Config ---
# Input column holding the class1–class2 distance.
DIST_COL  = "distance_class1_class2"
# Output column holding that distance rescaled to a percentage.
PCT_COL   = "distance_class1_class2_pct"
# Fraction of the lowest values excluded when picking the normalisation
# minimum; it is multiplied by 100 and passed to np.percentile as the cutoff.
TRIM_FRAC = 0.05  # bottom 5%

def _find_fly_dirs(main_directory: Path) -> list[Path]:
    """List every immediate subdirectory of *main_directory* (one per fly)."""
    return list(filter(Path.is_dir, main_directory.iterdir()))

def _gather_csvs_for_fly(fly_dir: Path) -> list[Path]:
    """Return the sorted ``<fly>/Eye_Antenna_Dist/*combined.csv`` files.

    These are the CSVs updated in place by the previous step. An empty list is
    returned when the fly has no ``Eye_Antenna_Dist`` folder.
    """
    source_dir = fly_dir / "Eye_Antenna_Dist"
    return sorted(source_dir.glob("*combined.csv")) if source_dir.is_dir() else []

def _compute_trimmed_min_and_global_max(csv_paths: list[Path]) -> tuple[float, float, float, int, int] | None:
    """
    Collect all DIST_COL values across the fly, drop NaNs.
    Compute p5 = 5th percentile, drop values < p5, then:
      - trimmed_min = min(remaining)
      - global_max  = max(all)
    Returns (trimmed_min, global_max, p5, total_count, kept_count)
    """
    vals = []
    for p in csv_paths:
        try:
            df = pd.read_csv(p, usecols=[DIST_COL])
            vals.append(pd.to_numeric(df[DIST_COL], errors="coerce").to_numpy())
        except ValueError:
            # Column missing; skip file
            continue
        except Exception as e:
            print(f"[WARN] Could not read {p.name}: {e}")
    if not vals:
        return None

    all_vals = np.concatenate(vals)
    all_vals = all_vals[np.isfinite(all_vals)]
    total = all_vals.size
    if total == 0:
        return None

    global_max = float(np.max(all_vals))
    p5 = float(np.percentile(all_vals, 100 * TRIM_FRAC, method="linear"))

    kept = all_vals[all_vals >= p5]
    kept_count = kept.size
    if kept_count == 0:
        trimmed_min = float(np.min(all_vals))  # fallback if everything trimmed
        kept_count = 0
    else:
        trimmed_min = float(np.min(kept))

    return trimmed_min, global_max, p5, total, kept_count

def _write_with_percentage(csv_path: Path, global_min: float, global_max: float) -> Path | None:
    """Add percentage column based on existing DIST_COL and write alongside original."""
    try:
        df = pd.read_csv(csv_path)
        if DIST_COL not in df.columns:
            print(f"[SKIP] {csv_path.name}: missing '{DIST_COL}'")
            return None

        dist = pd.to_numeric(df[DIST_COL], errors="coerce")
        if np.isfinite(global_min) and np.isfinite(global_max) and (global_max > global_min):
            pct = (dist - global_min) / (global_max - global_min) * 100.0
        else:
            pct = pd.Series(np.where(dist.notna(), 0.0, np.nan), index=df.index)

        df[PCT_COL] = pct
        out_path = csv_path.with_name(csv_path.stem + ".csv")
        df.to_csv(out_path, index=False)
        return out_path
    except Exception as e:
        print(f"[WARN] Failed {csv_path}: {e}")
        return None

def normalize_distances_to_percentage(main_directory: str | Path) -> dict[str, dict]:
    """
    For each fly:
      - Use <fly>/Eye_Antenna_Dist/*combined.csv (updated in place by prior script).
      - Compute 5th percentile cutoff across those CSVs.
      - Trim values < p5; use min(remaining) as global_min; global (untrimmed) max as global_max.
      - Write new CSVs with 'distance_class1_class2_pct' appended.
    Returns a summary dict.
    """
    main_directory = Path(main_directory)
    assert main_directory.is_dir(), f"Not a directory: {main_directory}"

    summary = {}
    for fly_dir in _find_fly_dirs(main_directory):
        csvs = _gather_csvs_for_fly(fly_dir)
        if not csvs:
            continue

        mm = _compute_trimmed_min_and_global_max(csvs)
        if mm is None:
            summary[fly_dir.name] = {"files": len(csvs), "written": 0, "note": "No valid distances found"}
            print(f"[INFO] {fly_dir.name}: no valid '{DIST_COL}' values across files.")
            continue

        trimmed_min, global_max, p5, total_cnt, kept_cnt = mm
        written = 0
        for p in csvs:
            out = _write_with_percentage(p, trimmed_min, global_max)
            written += int(out is not None)

        summary[fly_dir.name] = {
            "files": len(csvs),
            "written": written,
            "p5_cutoff": p5,
            "trimmed_min": trimmed_min,
            "global_max": global_max,
            "total_points": total_cnt,
            "kept_points": kept_cnt,
        }
        print(f"[DONE] {fly_dir.name}: p5={p5:.6g}, trimmed_min={trimmed_min:.6g}, "
              f"global_max={global_max:.6g}, kept={kept_cnt}/{total_cnt}, wrote {written}/{len(csvs)}")
    return summary

# --- Example usage ---
# Run the normalisation for every fly under main_directory. The bare `report`
# expression on the last line makes the notebook display the summary dict.
report = normalize_distances_to_percentage(main_directory)
report

In [None]:
# JUPYTER CELL — Heat maps using in-place files from Eye_Antenna_Dist/*combined.csv
# If distance_class1_class2_pct is missing, compute it IN PLACE (no new CSVs).
# Per-fly normalization: trim bottom 5% globally; min = smallest remaining; max = global (untrimmed).

import os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.colors import Normalize
from datetime import datetime

# ───────── SETTINGS ─────────
# When True, dprint() forwards its arguments to print(); otherwise it is silent.
DEBUG = True
# Frame rate returned by _derive_fps() when the CSV has no usable fps column.
FPS_DEFAULT = 40.0
# Per-fly folder holding the input *combined.csv files.
SRC_DIRNAME = "Eye_Antenna_Dist"
OUT_DIRNAME = "heat_maps_distance_pct"     # under <fly>/<SRC_DIRNAME>/
# Column plotted in the heatmaps; added in place when a file lacks it.
PCT_COL = "distance_class1_class2_pct"
# Source distance column that PCT_COL is derived from.
DIST_COL = "distance_class1_class2"
TRIM_FRAC = 0.05  # bottom 5% trim per fly when computing PCT if missing

# Global matplotlib style; applies to every figure created after this point.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'lines.linewidth': 2,
    'axes.linewidth': 1.5,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'legend.fontsize': 12
})

def dprint(*args):
    """Forward *args* to print() only while the module-level DEBUG flag is truthy."""
    if not DEBUG:
        return
    print(*args)

# Robust timestamp parser (supports "HH:MM:SS:MS", "HH:MM:SS", "MM:SS", seconds)
def timestamp_to_seconds(ts):
    """Convert a timestamp to seconds as a float; NaN when it cannot be parsed.

    Accepted inputs: anything ``float()`` understands, or a colon-separated
    string with 1–4 fields (``SS``, ``MM:SS``, ``HH:MM:SS``, ``HH:MM:SS:MS``
    where the 4th field is milliseconds). Missing values give NaN.
    """
    if pd.isna(ts):
        return np.nan

    # Fast path: plain numbers and numeric strings.
    try:
        return float(ts)
    except Exception:
        pass

    fields = str(ts).strip().split(":")
    n_fields = len(fields)
    try:
        if n_fields == 1:
            return float(fields[0])
        if n_fields == 2:
            minutes, seconds = fields
            return int(minutes)*60 + float(seconds)
        if n_fields == 3:
            hours, minutes, seconds = fields
            return int(hours)*3600 + int(minutes)*60 + float(seconds)
        if n_fields == 4:
            hours, minutes, seconds, millis = fields
            return int(hours)*3600 + int(minutes)*60 + int(seconds) + int(millis)/1000.0
    except Exception:
        return np.nan
    # More than four fields: unsupported layout.
    return np.nan

def collect_fly_folders(main_directory: Path) -> list[Path]:
    """Return every immediate subdirectory of *main_directory* (str or Path)."""
    root = Path(main_directory)
    return list(filter(Path.is_dir, root.iterdir()))

def find_col(df, candidates):
    """Return the first name in *candidates* that is a column of *df*, else None.

    Matching is exact (case-sensitive); priority follows the order of
    *candidates*, not the column order of the frame.
    """
    available = set(df.columns)
    return next((name for name in candidates if name in available), None)

def _derive_fps(df) -> float:
    """Return the frame rate recorded in *df*, or FPS_DEFAULT.

    The median of the first fps-like column found is used when it is finite
    and positive. A missing column, an unusable median, or any error while
    computing it all fall back to FPS_DEFAULT.
    """
    fps_col = find_col(df, ["fps","FPS","frame_rate","frameRate"])
    if fps_col is None:
        return FPS_DEFAULT
    try:
        median_fps = pd.to_numeric(df[fps_col], errors="coerce").median()
        usable = np.isfinite(median_fps) and median_fps > 0
        if usable:
            return float(median_fps)
    except Exception:
        pass
    return FPS_DEFAULT

def _ensure_time_series(df, frame_col, ts_col, debug_lines) -> tuple[pd.Series, dict]:
    """Build a per-row time series in seconds, with two levels of fallback.

    1. ``timestamp``      – parse *ts_col* (numeric for ``time_seconds`` /
       ``relative_time``, otherwise via timestamp_to_seconds); accepted when
       at least two values parse and the largest step between consecutive
       parsed values is positive.
    2. ``frame_fallback`` – ``(frame - first_frame) / fps`` when *frame_col*
       has at least two numeric values.
    3. ``index_fallback`` – ``row_position / fps``.

    Returns:
        ``(series, meta)`` where meta records which source was used, how many
        timestamps parsed, the row count and the fps applied (None for
        source 1). A rejected timestamp column is noted in *debug_lines*.
    """
    n_rows = len(df)
    meta = {'used': None, 'valid_ts': 0, 'total': n_rows, 'fps': None}

    if ts_col is not None:
        already_seconds = ts_col in ("time_seconds", "relative_time")
        if already_seconds:
            ts = pd.to_numeric(df[ts_col], errors="coerce")
        else:
            ts = df[ts_col].apply(timestamp_to_seconds)
        valid = ts.notna().sum()
        meta['valid_ts'] = int(valid)
        if valid >= 2 and np.nanmax(np.diff(ts.dropna().values)) > 0:
            meta['used'] = 'timestamp'
            return ts, meta
        debug_lines.append(
            f"Timestamp parse insufficient (valid={valid}/{n_rows} or non-increasing). Falling back."
        )

    if frame_col is not None:
        frames = pd.to_numeric(df[frame_col], errors="coerce")
        if frames.notna().sum() >= 2:
            fps = _derive_fps(df)
            meta.update(used='frame_fallback', fps=fps)
            first_frame = int(np.nanmin(frames.values))
            return (frames - first_frame) / fps, meta

    fps = _derive_fps(df)
    meta.update(used='index_fallback', fps=fps)
    positional = pd.Series(np.arange(n_rows, dtype=float) / fps, index=df.index)
    return positional, meta

# ───────── Per-fly PCT computation (in-place if missing) ─────────

def _gather_inplace_csvs(fly_dir: Path) -> list[Path]:
    """Return the sorted ``*combined.csv`` files in ``<fly_dir>/<SRC_DIRNAME>``.

    Gives an empty list when that folder does not exist.
    """
    source_dir = fly_dir / SRC_DIRNAME
    return sorted(source_dir.glob("*combined.csv")) if source_dir.is_dir() else []

def _compute_fly_trim_min_and_max(csv_paths: list[Path]) -> tuple[float, float] | None:
    vals = []
    for p in csv_paths:
        try:
            df = pd.read_csv(p, usecols=[DIST_COL])
            v = pd.to_numeric(df[DIST_COL], errors="coerce").to_numpy()
            vals.append(v)
        except Exception:
            continue
    if not vals:
        return None
    all_vals = np.concatenate(vals)
    all_vals = all_vals[np.isfinite(all_vals)]
    if all_vals.size == 0:
        return None
    global_max = float(np.max(all_vals))
    p5 = float(np.percentile(all_vals, 100*TRIM_FRAC, method="linear"))
    kept = all_vals[all_vals >= p5]
    trimmed_min = float(np.min(kept)) if kept.size else float(np.min(all_vals))
    return trimmed_min, global_max

def _ensure_pct_inplace(csv_paths: list[Path], fly_min: float, fly_max: float) -> list[Path]:
    """Add PCT_COL in place (no new files) to every CSV that lacks it.

    Args:
        csv_paths: CSV files to check.
        fly_min: Lower normalisation bound (maps to 0 %).
        fly_max: Upper normalisation bound (maps to 100 %).

    Returns:
        The paths that were rewritten.

    Files that already contain PCT_COL are left untouched. Files without
    DIST_COL are skipped with a debug message: there is nothing to normalise,
    and writing a placeholder column would stop the percentage from ever being
    computed once the distance column is added. When the bounds are not finite
    or do not satisfy ``fly_max > fly_min``, every non-missing distance maps
    to 0.0. Any other per-file failure is reported via dprint and skipped.
    """
    bounds_usable = np.isfinite(fly_min) and np.isfinite(fly_max) and (fly_max > fly_min)
    updated = []
    for p in csv_paths:
        try:
            df = pd.read_csv(p)
            if PCT_COL in df.columns:
                continue
            if DIST_COL not in df.columns:
                dprint(f"[SKIP] {p.name}: missing '{DIST_COL}'; {PCT_COL} not added")
                continue
            dist = pd.to_numeric(df[DIST_COL], errors="coerce")
            if bounds_usable:
                pct = (dist - fly_min) / (fly_max - fly_min) * 100.0
            else:
                pct = pd.Series(np.where(dist.notna(), 0.0, np.nan), index=df.index)
            df[PCT_COL] = pct
            df.to_csv(p, index=False)  # overwrite in place
            updated.append(p)
        except Exception as e:
            dprint(f"[WARN] Could not add {PCT_COL} to {p.name}: {e}")
    return updated

# ───────── Heatmap building ─────────

def build_trial_from_csv(csv_path: Path, debug_lines):
    """
    Load one CSV and turn it into a single heatmap row.

    Returns (trial_dict, reason_if_skipped): exactly one of the two is None.
    trial_dict keys: label, time, data, odor_start, odor_end

    - data: PCT_COL values laid out on a contiguous frame range (min..max
      frame); frames absent from the file are NaN.
    - time: a synthetic axis, evenly spaced from 0 to pre + odor + post seconds
      over that frame range. It is NOT the measured timestamps; those are only
      used to estimate the odor duration.
    - odor_start / odor_end: fixed pre-odor offset and that offset plus the
      odor duration.

    A line describing the time source is appended to debug_lines. Any
    exception is caught and returned as the skip reason.
    """
    try:
        df = pd.read_csv(csv_path)
        # Strip header whitespace so the candidate lookups below match.
        df.columns = df.columns.str.strip()

        # First matching candidate wins for each role (None if absent).
        frame_col = find_col(df, ["frame", "Frame", "frame_num", "frame_index"])
        ts_col    = find_col(df, ["timestamp", "Timestamp", "time", "Time", "time_seconds", "relative_time"])
        ofm_col   = find_col(df, ["OFM State", "OFM_State", "ofm_state", "ofm"])

        if PCT_COL not in df.columns:
            return None, f"Missing '{PCT_COL}'"

        # Build time_seconds with robust fallback
        time_seconds, meta = _ensure_time_series(df, frame_col, ts_col, debug_lines)
        # Note: this overwrites any existing "time_seconds" column in df.
        df["time_seconds"] = pd.to_numeric(time_seconds, errors="coerce")
        valid_after = df["time_seconds"].notna().sum()

        debug_lines.append(
            f"{csv_path.name}: time_source={meta['used']}, valid_ts={meta['valid_ts']}/{meta['total']}, fps={meta['fps']}"
        )

        if valid_after < 2:
            return None, "Too few usable time points after fallback"

        df["relative_time"] = df["time_seconds"] - df["time_seconds"].min()

        # Odor window detection (if present)
        # pre/post durations are fixed; only the odor duration is measured,
        # and it keeps the 30 s default when it cannot be measured.
        pre_sec, post_sec = 30.0, 90.0
        odor_duration = 30.0
        if ofm_col is not None:
            ofm_series = df[ofm_col].astype(str).str.lower()
            odor_mask = ofm_series == "during"
            if odor_mask.any():
                rt = df.loc[odor_mask, "relative_time"]
                if len(rt) >= 2:
                    # Span between first and last "during" rows, clamped at 0.
                    odor_duration = max(float(rt.iloc[-1] - rt.iloc[0]), 0.0)

        trial_label = os.path.splitext(os.path.basename(csv_path))[0]
        final_total_duration = pre_sec + odor_duration + post_sec

        if frame_col is not None:
            # Rows with a non-numeric frame are dropped here; frames.index
            # keeps the labels of the surviving rows for the lookup below.
            frames = pd.to_numeric(df[frame_col], errors="coerce").dropna().astype(int)
        else:
            frames = pd.Series(np.arange(len(df), dtype=int), index=df.index)

        if frames.empty:
            return None, "No valid frame information"

        # Contiguous frame range; gaps in the recording stay NaN in full_data.
        total_frames = np.arange(frames.min(), frames.max() + 1, dtype=int)
        full_data = np.full(total_frames.shape, np.nan, dtype=float)

        # Map each frame number to its position in the contiguous range.
        idx_map = {f: i for i, f in enumerate(total_frames)}
        present_idx = [idx_map.get(int(f)) for f in frames if int(f) in idx_map]
        if len(present_idx) == 0:
            return None, "Frame alignment failed (no overlap)."

        vals = pd.to_numeric(df.loc[frames.index, PCT_COL], errors="coerce").to_numpy()
        # NOTE(review): if a frame number repeats, several values target the
        # same slot and only one is kept — confirm duplicates cannot occur.
        full_data[np.array(present_idx, dtype=int)] = vals

        data_for_plot = full_data.copy()
        # NOTE(review): -1 and 101 look like sentinel values from an earlier
        # step — confirm. -1 is remapped to -0.5; the 101 line assigns the
        # same value back and therefore changes nothing.
        data_for_plot[data_for_plot == -1]  = -0.5
        data_for_plot[data_for_plot == 101] = 101

        # Synthetic, evenly spaced axis spanning the nominal trial duration.
        new_time = np.linspace(0, final_total_duration, len(total_frames))

        return {
            "label": trial_label,
            "time": new_time,
            "data": data_for_plot,
            "odor_start": pre_sec,
            "odor_end": pre_sec + odor_duration
        }, None

    except Exception as e:
        return None, f"Exception: {e}"

def write_debug_log(out_dir: Path, lines: list[str]):
    """Append a timestamped block of *lines* to ``<out_dir>/_debug_heatmap_log.txt``.

    The directory is created if needed. Each entry is right-stripped and
    written on its own line under a ``===== <ISO time> =====`` header; the log
    path is echoed through dprint afterwards.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "_debug_heatmap_log.txt"
    with open(log_path, "a", encoding="utf-8") as f:
        stamp = datetime.now().isoformat(timespec='seconds')
        f.write(f"\n===== {stamp} =====\n")
        f.writelines(ln.rstrip() + "\n" for ln in lines)
    dprint(f"[LOG] {log_path}")

def plot_fly(main_directory: Path, fly_dir: Path) -> None:
    """
    Build TWO heatmaps per fly from in-place files <fly>/Eye_Antenna_Dist/*combined.csv:
      - <fly>_testing_heatmap_distance_pct.png  (files with 'testing' in name)
      - <fly>_training_heatmap_distance_pct.png (files with 'training' in name)
    If PCT_COL is missing, compute it in-place first (per-fly trimmed-min / global-max).

    Args:
        main_directory: Root directory of the run. Not used inside this function;
            kept in the signature for callers.
        fly_dir: Folder of the fly to plot. Input CSVs are looked up through
            _gather_inplace_csvs(fly_dir); figures and the debug log are written to
            fly_dir / SRC_DIRNAME / OUT_DIRNAME.

    Returns:
        None. Every outcome (saved figure, skipped file, missing data) is recorded
        in the debug log written to the output directory.
    """
    src_dir = fly_dir / SRC_DIRNAME
    out_dir = src_dir / OUT_DIRNAME
    out_dir.mkdir(parents=True, exist_ok=True)

    # Accumulated log; the nested _plot_category helper appends to this same list.
    debug_lines = [
        f"Fly: {fly_dir.name}",
        f"Input dir: {src_dir}",
        f"Output dir: {out_dir}"
    ]

    # 1) Gather in-place CSVs
    csv_files = _gather_inplace_csvs(fly_dir)
    if not csv_files:
        write_debug_log(out_dir, debug_lines + ["No CSVs to process."])
        return

    # 2) Ensure PCT exists in-place for ALL files (compute per-fly min/max once)
    stats = _compute_fly_trim_min_and_max(csv_files)
    if stats is None:
        write_debug_log(out_dir, debug_lines + [f"No valid '{DIST_COL}' values across files."])
        return
    fly_min, fly_max = stats
    # `updated` is iterated as paths below (p.name); presumably the files that were
    # rewritten to include PCT_COL — confirm against _ensure_pct_inplace.
    updated = _ensure_pct_inplace(csv_files, fly_min, fly_max)
    if updated:
        debug_lines.append(f"Added {PCT_COL} to: {', '.join(p.name for p in updated)}")

    # 3) Split by category using filename
    # Substring match on the lowercased file name; a file matching neither is not plotted.
    testing_files  = [p for p in csv_files if "testing"  in p.name.lower()]
    training_files = [p for p in csv_files if "training" in p.name.lower()]

    # Helper to plot one category
    def _plot_category(cat_name: str, files: list[Path]):
        # Renders one PNG with one heatmap strip per usable trial in `files`.
        if not files:
            debug_lines.append(f"No {cat_name} files; skipping {cat_name} figure.")
            return
        trials, skipped_files = [], 0
        for csv_path in sorted(files):
            # build_trial_from_csv returns (trial_dict, None) on success or (None, reason).
            tr, reason = build_trial_from_csv(csv_path, debug_lines)
            if tr is None:
                skipped_files += 1
                msg = f"{csv_path.name}: {reason}"
                dprint(f"  [SKIP FILE] {msg}")
                debug_lines.append(f"SKIP FILE: {msg}")
            else:
                trials.append(tr)
        if not trials:
            debug_lines.append(f"All {cat_name} files skipped ({skipped_files} skipped). No figure produced.")
            return

        # Colormap and norm (linear 0–100 %; NaN→dimgray, under/over→pink)
        cmap = mpl.colormaps['viridis'].copy()
        cmap.set_under('pink')
        cmap.set_over('pink')
        cmap.set_bad('dimgray')
        norm = Normalize(vmin=0, vmax=100)

        n = len(trials)
        fig_h = max(2, 2 * n)
        fig, axs = plt.subplots(n, 1, figsize=(18, fig_h), sharex=True)
        # plt.subplots returns a bare Axes for a single row; wrap it so indexing works.
        if n == 1:
            axs = [axs]

        legend_handles = [Line2D([0], [0], color='red', linewidth=2.5, label='Odor Period')]
        pcm_last = None

        for i, tr in enumerate(trials):
            ax = axs[i]
            # n samples need n + 1 cell edges; the strip is one cell high (y from 0 to 1).
            time_edges = np.linspace(tr["time"][0], tr["time"][-1], len(tr["time"]) + 1)
            X, Y = np.meshgrid(time_edges, [0, 1])
            row = tr["data"].reshape(1, -1)
            pcm_last = ax.pcolormesh(X, Y, row, cmap=cmap, shading='auto', norm=norm)
            ax.set_yticks([])
            ax.tick_params(axis='x', direction='out')
            ax.set_xlim(tr["time"][0], tr["time"][-1])
            ax.set_title(tr["label"], loc='left')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            # Red vertical lines mark the start and end of the odor period.
            ax.axvline(tr["odor_start"], color='red', linewidth=2.5)
            ax.axvline(tr["odor_end"],   color='red', linewidth=2.5)

        axs[-1].set_xlabel("Time (seconds)")
        fly_name = fly_dir.name
        fig.suptitle(f"{fly_name} – {cat_name.capitalize()} Trials (distance_class1_class2_pct)", fontsize=20)
        # Reserve the top 4% of the figure for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.96])

        # One shared colorbar for all strips; every strip uses the same cmap/norm,
        # so the last mesh is representative.
        cbar = fig.colorbar(pcm_last, ax=axs, orientation='vertical', fraction=0.02, pad=0.04, extend='both')
        cbar.set_label("Distance Percentage (%)", fontsize=14)
        fig.legend(handles=legend_handles, loc='upper right', frameon=True)

        out_path = out_dir / f"{fly_name}_{cat_name}_heatmap_distance_pct.png"
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
        # Close explicitly so figures do not accumulate across flies.
        plt.close(fig)

        dprint(f"[SAVED] {out_path}")
        debug_lines.append(f"SAVED: {out_path}")

    # 4) Emit two PNGs (if files exist)
    _plot_category("testing", testing_files)
    _plot_category("training", training_files)

    # 5) Write combined debug log
    write_debug_log(out_dir, debug_lines)

def plot_all_flies(main_directory: str | Path):
    """
    Run plot_fly for every fly folder whose name begins with a calendar month.

    The match is case-insensitive and accepts full month names as well as their
    3-letter abbreviations, e.g. 'August_19_fly_1', 'aug_12_fly_3', 'Dec-05_fly_2'.
    Candidate folders come from collect_fly_folders(main_directory). When none
    qualifies, a warning is printed and nothing is plotted.
    """
    main_directory = Path(main_directory)
    assert main_directory.is_dir(), f"Not a directory: {main_directory}"
    dprint(f"Scanning main_directory: {main_directory}")

    # Lowercase prefixes: the twelve full names plus their first three letters.
    full_month_names = (
        "january", "february", "march", "april", "may", "june",
        "july", "august", "september", "october", "november", "december",
    )
    month_prefixes = full_month_names + tuple(name[:3] for name in full_month_names)

    # Keep only the folders whose name starts with one of the prefixes.
    fly_dirs = []
    for candidate in collect_fly_folders(main_directory):
        if candidate.name.lower().startswith(month_prefixes):
            fly_dirs.append(candidate)

    if not fly_dirs:
        print(f"[WARN] No month-prefixed fly folders found in {main_directory}")
        return

    for fly_dir in fly_dirs:
        dprint(f"\n== Fly: {fly_dir.name} ==")
        plot_fly(main_directory, fly_dir)


# --- Example usage ---
# Runs as soon as this cell executes. `main_directory` must already be defined
# by an earlier cell; uncomment the line below to point at a different root.
# main_directory = "/path/to/root/with/fly_folders"
plot_all_flies(main_directory)


In [None]:
# JUPYTER CELL — Collect all heatmap PNGs into <main_directory>/heat_maps/Eye_Antenna_Dist/
# Organizes into subfolders: testing/, training/, other/
from pathlib import Path
import shutil

# Adjust if you want to move instead of copy.
# Used as the default of collect_heatmaps(move=...); the default is bound when
# that function is defined, so change this value before (re)running the definition.
MOVE_FILES = False  # False = copy2 (preserve originals), True = move

# Source pattern produced by your plotting code: path components, relative to
# each fly folder, of the directory that collect_heatmaps scans for PNGs.
SRC_SUBPATH = ("Eye_Antenna_Dist", "heat_maps_distance_pct")

# Lowercase month names (full + common abbreviations) accepted as the start of
# a fly-folder name.
_MONTH_PREFIXES = (
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
    "jan", "feb", "mar", "apr", "may", "jun",
    "jul", "aug", "sep", "sept", "oct", "nov", "dec",
)


def _is_month_prefixed(name: str) -> bool:
    """Return True when `name`, compared case-insensitively, starts with a month prefix."""
    lowered = name.lower()
    return any(lowered.startswith(prefix) for prefix in _MONTH_PREFIXES)

def collect_heatmaps(main_directory: str | Path, *, move: bool = MOVE_FILES) -> int:
    """
    Collect the per-fly heatmap PNGs into one shared folder tree.

    Every immediate subfolder of `main_directory` whose name starts with a month
    is searched for <fly>/Eye_Antenna_Dist/heat_maps_distance_pct/*.png. Each image
    is copied (or moved when `move` is true) to
        <main_directory>/heat_maps/Eye_Antenna_Dist/{testing|training|other}/
    where the bucket is chosen from the lowercased file name ('testing' is
    checked before 'training'). An existing destination file is never
    overwritten: the new file gets a _1, _2, ... suffix instead.

    Returns the number of images copied/moved.
    """
    root = Path(main_directory)
    assert root.is_dir(), f"Not a directory: {root}"

    dest_root = root / "heat_maps" / "Eye_Antenna_Dist"
    for bucket in ("testing", "training", "other"):
        (dest_root / bucket).mkdir(parents=True, exist_ok=True)

    # The transfer function and the verb used in the progress message are fixed per call.
    if move:
        transfer, action = shutil.move, "Moved"
    else:
        transfer, action = shutil.copy2, "Copied"

    n = 0
    for fly_dir in root.iterdir():
        if not (fly_dir.is_dir() and _is_month_prefixed(fly_dir.name)):
            continue
        src_dir = fly_dir.joinpath(*SRC_SUBPATH)
        if not src_dir.is_dir():
            continue

        for img in sorted(src_dir.glob("*.png")):
            name_l = img.name.lower()
            if "testing" in name_l:
                category = "testing"
            elif "training" in name_l:
                category = "training"
            else:
                category = "other"
            dest_dir = dest_root / category
            dest_path = dest_dir / img.name

            # Avoid collisions by suffixing _1, _2, ...
            if dest_path.exists():
                stem, ext = dest_path.stem, dest_path.suffix
                k = 1
                candidate = dest_dir / f"{stem}_{k}{ext}"
                while candidate.exists():
                    k += 1
                    candidate = dest_dir / f"{stem}_{k}{ext}"
                dest_path = candidate

            transfer(str(img), dest_path)
            print(f"{action}: {img} -> {dest_path}")
            n += 1

    print(f"Total {'moved' if move else 'copied'}: {n}")
    return n

# --- Example usage ---
# Runs as soon as this cell executes; copies (does not move) the PNGs found
# under the `main_directory` defined by an earlier cell.
collect_heatmaps(main_directory, move=False)

# Mega Heat Maps

In [None]:
# JUPYTER CELL — Mega heatmaps per fly (training & testing), adapted to new layout
# Row 1: Eye_Prob_Dist
# Row 2: Centered angle % (searched anywhere under the fly folder; *_distance_class_2_angle_ARB.csv preferred)
# Row 3: Eye_Antenna_Dist (trim-5% min / global max fallback if pct missing)

import os, re, glob
from pathlib import Path
from typing import Optional, Tuple, Dict, List
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.lines import Line2D
from matplotlib.transforms import Bbox

# Matplotlib defaults applied globally from this cell onward (fonts, line
# widths, on-screen and saved-figure resolution).
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'lines.linewidth': 2,
    'axes.linewidth': 1.25,
    'figure.dpi': 160,
    'savefig.dpi': 300,
    'legend.fontsize': 11
})

# ───────── Config ─────────
# Frame rate returned by derive_fps when a CSV has no usable fps column.
FPS_DEFAULT   = 40.0
# Colour scale for percentage rows (0–100); used for the Eye_Prob_Dist row in plot_category_mega.
VIRIDIS_NORM  = Normalize(vmin=0,   vmax=100)
# Symmetric scale (-100..100); presumably for the centered-angle row — confirm in plot_category_mega.
COOLWARM_NORM = Normalize(vmin=-100, vmax=100)
# Fixed anchor point used by compute_angle_deg_at_point2 (same coordinate system
# as the x_/y_ columns; presumably image pixels — TODO confirm).
ANCHOR_X, ANCHOR_Y = 1080.0, 540.0

# Seconds laid out before / after the odor window on the synthetic time axis of each row.
PRE_SEC, POST_SEC = 30.0, 90.0
TRIM_FRAC = 0.05  # bottom 5% trim (Eye_Antenna_Dist fallback pct normalization)

# Lowercase month names and 3-letter abbreviations used to recognise month-named folders.
MONTHS = ("january","february","march","april","may","june",
          "july","august","september","october","november","december",
          "jan","feb","mar","apr","may","jun","jul","aug","sep","oct","nov","dec")

# ───────── Utilities ─────────
def timestamp_to_seconds(ts) -> float:
    """
    Convert a timestamp value to seconds.

    Accepted inputs:
      - missing values (None / NaN / NaT)     -> NaN
      - anything float() accepts              -> that float
      - 'HH:MM:SS:ms' (integer milliseconds)  -> seconds
      - 'HH:MM:SS[.fff]'                      -> seconds
      - 'MM:SS[.fff]'                         -> seconds
    Any other layout, or a field that fails to parse, yields NaN.
    """
    if pd.isna(ts):
        return np.nan
    try:
        return float(ts)
    except Exception:
        pass  # not a plain number; try the colon-separated layouts below

    fields = str(ts).strip().split(":")
    n_fields = len(fields)
    try:
        if n_fields == 4:
            hours, minutes, seconds, millis = (int(field) for field in fields)
            return hours * 3600 + minutes * 60 + seconds + millis / 1000.0
        if n_fields == 3:
            return int(fields[0]) * 3600 + int(fields[1]) * 60 + float(fields[2])
        if n_fields == 2:
            return int(fields[0]) * 60 + float(fields[1])
        if n_fields == 1:
            return float(fields[0])
    except Exception:
        return np.nan
    return np.nan

def find_col(df: pd.DataFrame, cands: List[str]) -> Optional[str]:
    """Return the first name in `cands` that is a column of `df`, or None when none is present."""
    available = set(df.columns)
    return next((name for name in cands if name in available), None)

def derive_fps(df: pd.DataFrame) -> float:
    """
    Return the frame rate recorded in `df`, falling back to FPS_DEFAULT.

    The first of the columns 'fps', 'FPS', 'frame_rate', 'frameRate' that
    exists is coerced to numbers and its median is returned when it is finite
    and positive. Otherwise FPS_DEFAULT is returned.
    """
    fps_col = find_col(df, ["fps","FPS","frame_rate","frameRate"])
    if fps_col is None:
        return FPS_DEFAULT
    median_fps = pd.to_numeric(df[fps_col], errors="coerce").median()
    if np.isfinite(median_fps) and median_fps > 0:
        return float(median_fps)
    return FPS_DEFAULT

def ensure_time_series(df: pd.DataFrame, frame_col: Optional[str], ts_col: Optional[str]):
    """
    Return (time_series_in_seconds, info) for `df`, trying three sources in order.

      1. Timestamp column `ts_col`: read as numbers when it is one of
         'time_seconds' / 'relative_time' / 'time_s', otherwise parsed value by
         value with timestamp_to_seconds. Accepted when at least two values are
         valid and at least one step between consecutive valid values is
         positive. info = {'used': 'timestamp'}.
      2. Frame column `frame_col`: (frame - first frame) / fps, accepted when at
         least two frame numbers are valid.
         info = {'used': 'frame_fallback', 'fps': fps}.
      3. Row position / fps. info = {'used': 'index_fallback', 'fps': fps}.

    fps comes from derive_fps(df).
    """
    numeric_time_cols = ("time_seconds", "relative_time", "time_s")

    if ts_col is not None:
        column = df[ts_col]
        if ts_col in numeric_time_cols:
            ts = pd.to_numeric(column, errors="coerce")
        else:
            ts = column.apply(timestamp_to_seconds)
        if ts.notna().sum() >= 2:
            steps = np.diff(ts.dropna().values)
            if np.nanmax(steps) > 0:
                return ts, {'used':'timestamp'}

    if frame_col is not None:
        frames = pd.to_numeric(df[frame_col], errors="coerce")
        if frames.notna().sum() >= 2:
            fps = derive_fps(df)
            first_frame = int(np.nanmin(frames.values))
            return (frames - first_frame)/fps, {'used':'frame_fallback','fps':fps}

    fps = derive_fps(df)
    by_position = pd.Series(np.arange(len(df), dtype=float)/fps, index=df.index)
    return by_position, {'used':'index_fallback','fps':fps}

def _extract_trial_fallback(stem: str, category: str | None) -> Optional[int]:
    """
    Robust trial index extraction:
      1) '{category}_(\d+)' or '(training|testing)_(\d+)'.
      2) last numeric group in the stem.
      3) None → caller may auto-assign.
    """
    s = stem.lower()
    if category:
        m = re.search(rf"{category}_(\d+)", s)
        if m:
            try: return int(m.group(1))
            except: pass
    m2 = re.search(r"(training|testing)_(\d+)", s)
    if m2:
        try: return int(m2.group(2))
        except: pass
    m3 = re.findall(r"(\d+)", s)
    if m3:
        try: return int(m3[-1])
        except: pass
    return None
    
def odor_window_from_ofm(df: pd.DataFrame, time_col: str = "relative_time"):
    """
    Return (start, end) of the odor period as read from the OFM-state column.

    The OFM column is the first of 'OFM State', 'OFM_State', 'ofm_state', 'ofm'
    present in `df`. Rows whose state equals 'during' (case-insensitive) mark
    the odor period; the result is the `time_col` value at the first and at the
    last such row.

    Returns None when the OFM column or `time_col` is missing, when fewer than
    two rows are marked 'during', or when anything fails while reading the values.
    """
    ofm_col = find_col(df, ["OFM State","OFM_State","ofm_state","ofm"])
    if ofm_col is None:
        return None
    if time_col not in df.columns:
        return None

    try:
        is_during = df[ofm_col].astype(str).str.lower().eq("during")
        hits = np.flatnonzero(is_during.to_numpy())
        if hits.size < 2:
            return None
        times = df[time_col]
        return float(times.iloc[hits[0]]), float(times.iloc[hits[-1]])
    except Exception:
        return None

def build_row(time_seconds: np.ndarray, values: np.ndarray, odor_start: float, odor_end: float,
              label: str, apply_sentinels=True) -> dict:
    """
    Package one heatmap strip for pcolormesh.

    Returns a dict with:
      'label'      - `label` unchanged
      'time_edges' - len(time_seconds) + 1 evenly spaced edges spanning the
                     first to the last time value
      'data_row'   - `values` as a float array of shape (1, N)
      'odor_start' / 'odor_end' - the odor bounds as floats
    With `apply_sentinels`, values equal to -1 are replaced by -0.5; values
    equal to 101 are left as they are.

    An empty dict is returned when fewer than two time points are given.
    """
    n_samples = len(time_seconds)
    if n_samples < 2:
        return {}

    edges = np.linspace(time_seconds[0], time_seconds[-1], n_samples + 1)
    cells = np.array(values, dtype=float)
    if apply_sentinels:
        # Only the -1 sentinel is remapped; 101 already has its final value.
        cells = np.where(cells == -1, -0.5, cells)

    return {
        "label": label,
        "time_edges": edges,
        "data_row": cells.reshape(1, -1),
        "odor_start": float(odor_start),
        "odor_end": float(odor_end),
    }

# ───────── Angle helpers (per your angle heatmap) ─────────
def compute_angle_deg_at_point2(df: pd.DataFrame) -> pd.Series:
    """
    Per-row unsigned angle, in degrees, at the class-2 point.

    The angle is measured at (x_class2, y_class2) between the vector to the
    fixed anchor (ANCHOR_X, ANCHOR_Y) and the vector to (x_class6, y_class6).
    It lies in [0, 180]. Rows where either vector has zero length or a
    coordinate is not finite get NaN.

    Raises:
        ValueError: if any of the four coordinate columns is missing.

    Returns:
        Series named 'angle_ARB_deg' sharing the index of `df`.
    """
    req = ["x_class2","y_class2","x_class6","y_class6"]
    if not all(c in df.columns for c in req):
        raise ValueError(f"Missing angle columns: {req}")

    def _as_float(column: str) -> np.ndarray:
        return df[column].astype(float).to_numpy()

    vertex_x, vertex_y = _as_float("x_class2"), _as_float("y_class2")
    tip_x, tip_y = _as_float("x_class6"), _as_float("y_class6")

    # Vectors from the vertex: u points to the anchor, v to the class-6 point.
    ux = ANCHOR_X - vertex_x
    uy = ANCHOR_Y - vertex_y
    vx = tip_x - vertex_x
    vy = tip_y - vertex_y

    dot = ux*vx + uy*vy
    cross = ux*vy - uy*vx
    len_u = np.hypot(ux, uy)
    len_v = np.hypot(vx, vy)
    valid = (len_u > 0) & (len_v > 0) & np.isfinite(dot) & np.isfinite(cross)

    ang = np.full(len(vertex_x), np.nan)
    # atan2(|cross|, dot) drops the sign of the rotation.
    ang[valid] = np.arctan2(np.abs(cross[valid]), dot[valid])
    return pd.Series(np.degrees(ang), index=df.index, name="angle_ARB_deg")

def find_fly_reference_angle(csvs_raw: List[Path]) -> float:
    """
    Pick the per-fly reference angle: the angle at the frame whose distance
    percentage is closest to zero.

    For every CSV that has the four coordinate columns and a distance-percentage
    column, the frame with the smallest |distance| is selected among the frames
    where BOTH the angle and the distance are finite (first such frame on ties).
    Restricting to finite angles keeps a missing detection at the minimum-distance
    frame from turning the whole reference into NaN. Across files, a frame with
    distance exactly 0 beats any non-zero one; remaining ties are broken by the
    smaller |distance|, then the smaller angle.

    Args:
        csvs_raw: CSV paths to scan. Files that cannot be read or lack the
            required columns are skipped.

    Returns:
        The reference angle in degrees, or NaN when no file provides a usable frame.
    """
    coord_cols = {"x_class2","y_class2","x_class6","y_class6"}
    dist_names = ["distance_percentage","distance_percent","distance_pct",
                  "distance_class1_class2_pct"]

    best = None  # (rank, |distance|, angle); rank 0 = distance is exactly zero
    for p in csvs_raw:
        try:
            df = pd.read_csv(p)
            if not coord_cols.issubset(df.columns):
                continue
            dist_col = find_col(df, dist_names)
            if dist_col is None:
                continue
            ang = compute_angle_deg_at_point2(df).to_numpy(dtype=float)
            dist = pd.to_numeric(df[dist_col], errors="coerce").to_numpy(dtype=float)
        except Exception:
            # Best-effort scan: one unreadable or malformed file must not abort the fly.
            continue

        usable = np.isfinite(ang) & np.isfinite(dist)
        if not usable.any():
            continue

        # Unusable frames get +inf so argmin can never select them.
        absd = np.where(usable, np.abs(dist), np.inf)
        idx = int(np.argmin(absd))
        rank = 0 if absd[idx] == 0 else 1
        cand = (rank, float(absd[idx]), float(ang[idx]))
        if best is None or cand < best:
            best = cand

    return best[2] if best is not None else np.nan

def compute_fly_max_abs_centered(csvs_raw: List[Path], ref_angle: float) -> float:
    """
    Largest absolute deviation of the angle from `ref_angle` over all given CSVs.

    Files without the four coordinate columns, and files that fail to load or
    process, are ignored. When `ref_angle` is not finite every deviation is
    taken as zero.

    Returns:
        The maximum |angle - ref_angle| in degrees, or NaN when no positive
        deviation was found.
    """
    needed = {"x_class2","y_class2","x_class6","y_class6"}
    largest = 0.0
    for csv_path in csvs_raw:
        try:
            frame = pd.read_csv(csv_path)
            if not needed.issubset(frame.columns):
                continue
            angles = compute_angle_deg_at_point2(frame)
            if np.isfinite(ref_angle):
                deviation = angles - ref_angle
            else:
                deviation = angles*0.0
            file_peak = np.nanmax(np.abs(deviation.to_numpy(dtype=float)))
            if np.isfinite(file_peak):
                largest = max(largest, float(file_peak))
        except Exception:
            pass
    if np.isfinite(largest) and largest > 0:
        return largest
    return np.nan

# ───────── Row collectors (new locations) ─────────
def collect_rows_eye_prob_dist(fly_dir: Path, category: str) -> Dict[int, dict]:
    """
    Eye_Prob_Dist/<category>/*.csv  (or fallback: Eye_Prob_Dist/*category*.csv)
    Expects a *percentage* column (tries common names).

    Args:
        fly_dir: Folder of one fly.
        category: Category label; used as the subfolder name, as a substring
            filter on lowercased file names in the fallback layout, and for
            trial-index extraction.

    Returns:
        Dict mapping trial index -> row dict produced by build_row (sentinels
        applied). Files without a percentage column, with fewer than two valid
        time values, without a recognisable trial index, or that raise while
        being processed are skipped silently. If two files yield the same trial
        index, the one processed later replaces the earlier one.
    """
    out = {}
    base = fly_dir / "Eye_Prob_Dist" / category
    candidates = []
    if base.is_dir():
        candidates = sorted([p for p in base.glob("*.csv") if p.is_file()])
    else:
        # Flat layout: all CSVs sit directly in Eye_Prob_Dist; filter by name.
        root = fly_dir / "Eye_Prob_Dist"
        if root.is_dir():
            candidates = sorted([p for p in root.glob("*.csv") if category in p.name.lower()])

    for p in candidates:
        try:
            df = pd.read_csv(p); df.columns = df.columns.str.strip()
            frame_col = find_col(df, ["frame","Frame","frame_num","frame_index"])
            ts_col    = find_col(df, ["timestamp","Timestamp","time","Time","time_seconds","relative_time","time_s"])
            pct_col   = find_col(df, [
                "distance_percentage","distance_percent","distance_pct",
                "distance_proboscis_eye_pct","proboscis_eye_distance_pct",
                "eye_prob_distance_pct","distance_class1_class2_pct"  # keep broad
            ])
            if pct_col is None:
                continue  # nothing sensible to plot here
            ts, _ = ensure_time_series(df, frame_col, ts_col)
            if ts.notna().sum() < 2:
                continue
            # relative timeline with default odor window if OFM missing
            time_s = pd.to_numeric(ts, errors="coerce"); t0 = float(np.nanmin(time_s))
            rel = time_s - t0
            df["relative_time"] = rel
            odor = odor_window_from_ofm(df, "relative_time")
            if odor is None:
                odor = (PRE_SEC, PRE_SEC + 30.0)
            odor_start, odor_end = odor
            # The plotted axis is synthetic: only the odor DURATION is taken from the
            # data; the odor period is always drawn starting at PRE_SEC.
            total_duration = PRE_SEC + (odor_end - odor_start) + POST_SEC

            # build dense frame-aligned row if frames exist; else use raw length
            if frame_col is not None and frame_col in df.columns:
                frames = pd.to_numeric(df[frame_col], errors="coerce").dropna().astype(int)
                # One slot per frame number from min to max; frames absent from the CSV stay NaN.
                total_frames = np.arange(frames.min(), frames.max()+1, dtype=int)
                idx_map = {f:i for i,f in enumerate(total_frames)}
                present_idx = [idx_map.get(int(f)) for f in frames if int(f) in idx_map]
                # frames.index keeps only rows with a valid frame number, so vals lines up with present_idx.
                vals = pd.to_numeric(df.loc[frames.index, pct_col], errors="coerce").to_numpy()
                full = np.full_like(total_frames, np.nan, dtype=float)
                if present_idx: full[np.array(present_idx, dtype=int)] = vals
                time_axis = np.linspace(0, total_duration, len(total_frames))
            else:
                vals = pd.to_numeric(df[pct_col], errors="coerce").to_numpy()
                time_axis = np.linspace(0, total_duration, len(vals)); full = vals

            tri = _extract_trial_fallback(p.stem, category)
            if tri is None:
                continue
            row = build_row(time_axis, full, PRE_SEC, PRE_SEC+(odor_end-odor_start), p.stem, True)
            if row: out[tri] = row
        except Exception:
            # Best-effort: any failure on one file just drops that file.
            continue
    return out

# ───────── Path helpers (month folders, category inference) ─────────
# Lowercase month names and their 3-letter abbreviations.
MONTHS = (
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december",
    "jan", "feb", "mar", "apr", "may", "jun",
    "jul", "aug", "sep", "oct", "nov", "dec",
)


def _is_month_folder(p: Path) -> bool:
    """Return True when the last path component, lowercased, starts with a month name or abbreviation."""
    return p.name.lower().startswith(MONTHS)

# (Optional) Stronger category inference: recognizes train/test shorthands.
def _infer_category_from_path(p: Path) -> str | None:
    tokens = " ".join([*p.parts, p.stem]).lower()
    if any(t in tokens for t in ("training","train","trn")): return "training"
    if any(t in tokens for t in ("testing","test","tst")):  return "testing"
    return None

# REPLACE this function
def collect_rows_angle_centered_pct(fly_dir: Path, category: str) -> Dict[int, dict]:
    """
    Centered angle % rows for the mega heatmap, keyed by trial index.

    Discovery (anywhere under fly_dir, not limited to month folders):
      - '*_distance_class_2_angle_ARB.csv' is preferred; its 'angle_centered_pct'
        column is used when present.
      - '*_distance_class_2.csv' (raw) is the fallback; centered % is computed
        from the coordinate columns as (angle - ref_angle) / fly_max_abs * 100.
    A file is kept for `category` when its path is inferred as that category or
    cannot be classified at all.

    The per-fly reference angle and maximum absolute deviation are computed from
    all raw files of the fly; when the fly has no raw files, the angle files are
    used instead. If no reference/maximum can be determined, computed rows are
    all zeros.

    The odor window is read from a raw CSV (sibling first, then raw files in the
    same folder, then any raw file of the category); without one, a 30 s window
    starting at PRE_SEC is assumed.

    Args:
        fly_dir: Folder of one fly.
        category: Category label, e.g. 'training' or 'testing'.

    Returns:
        Dict mapping trial index -> row dict from build_row (sentinels not
        applied). Files without any digits in their stem, or that fail while
        being processed, are skipped. A later file with the same trial index
        replaces an earlier one.
    """
    out: Dict[int, dict] = {}
    angle_tail = "_distance_class_2_angle_ARB.csv"
    raw_tail = "_distance_class_2.csv"

    # 1) Discover candidates anywhere under the fly
    angle_csvs = list(fly_dir.rglob(f"*{angle_tail}"))
    # Raw files have their own suffix; the pattern cannot match angle files
    # because those end in '_angle_ARB.csv'.
    raw_csvs   = list(fly_dir.rglob(f"*{raw_tail}"))

    # Keep files that match the target category or are ambiguous
    def _cat_match(p: Path, target: str) -> bool:
        inf = _infer_category_from_path(p)
        return (inf == target) or (inf is None)

    angle_cat = [p for p in angle_csvs if _cat_match(p, category)]
    raw_cat   = [p for p in raw_csvs   if _cat_match(p, category)]
    if not angle_cat and not raw_cat:
        return out

    # 2) Per-fly reference and max from ALL raw files (most stable); fall back to
    #    the angle files when the fly has no raw files at all.
    raw_for_ref = raw_csvs if raw_csvs else angle_csvs
    ref_angle   = find_fly_reference_angle(raw_for_ref)
    fly_max_abs = compute_fly_max_abs_centered(raw_for_ref, ref_angle) if np.isfinite(ref_angle) else np.nan

    # 3) Prefer angle files over raw when both exist for the same logical trial
    def _base_key(p: Path) -> str:
        s = p.stem
        return s.replace("_distance_class_2_angle_ARB","").replace("_distance_class_2","")

    chosen: Dict[str, Path] = {}
    for p in sorted(raw_cat):
        chosen.setdefault(_base_key(p), p)
    for p in sorted(angle_cat):
        chosen[_base_key(p)] = p  # overwrite with angle file

    # 4) Build rows
    for base in sorted(chosen.keys()):
        p = chosen[base]
        try:
            df = pd.read_csv(p); df.columns = df.columns.str.strip()

            # Trial index; the helper already falls back to the last digit run,
            # so None means the stem has no digits at all.
            tri = _extract_trial_fallback(p.stem, category)
            if tri is None:
                continue

            # Centered %: use column if present; else compute from coords/raw
            if "angle_centered_pct" in df.columns:
                pct = pd.to_numeric(df["angle_centered_pct"], errors="coerce").to_numpy()
            else:
                if {"x_class2","y_class2","x_class6","y_class6"}.issubset(df.columns):
                    ang = compute_angle_deg_at_point2(df).to_numpy(dtype=float)
                else:
                    partner = p.with_name(p.name.replace(angle_tail, raw_tail))
                    if not partner.exists():
                        continue
                    dfr = pd.read_csv(partner)
                    ang = compute_angle_deg_at_point2(dfr).to_numpy(dtype=float)

                if np.isfinite(fly_max_abs) and fly_max_abs > 0 and np.isfinite(ref_angle):
                    pct = ((ang - ref_angle) / fly_max_abs) * 100.0
                else:
                    pct = np.zeros_like(ang, dtype=float)  # still renders; flags missing ref/max

            # Odor window: try sibling/nearby raw; else default
            odor = None
            sib_raw = p.with_name(p.name.replace(angle_tail, raw_tail))
            candidates = [sib_raw] if sib_raw.exists() else []
            candidates += [q for q in p.parent.glob(f"*{raw_tail}") if _cat_match(q, category)]
            if not candidates:
                candidates = raw_cat  # last resort

            for q in candidates:
                try:
                    dfr = pd.read_csv(q)
                    fr_col = find_col(dfr, ["frame","Frame","frame_num","frame_index"])
                    ts_col = find_col(dfr, ["timestamp","Timestamp","time","Time","time_seconds","relative_time","time_s"])
                    ts, _  = ensure_time_series(dfr, fr_col, ts_col)
                    dfr["time_seconds"] = pd.to_numeric(ts, errors="coerce")
                    dfr["relative_time"] = dfr["time_seconds"] - dfr["time_seconds"].min()
                    odor = odor_window_from_ofm(dfr, "relative_time")
                    if odor is not None:
                        break
                except Exception:
                    continue
            if odor is None:
                odor = (PRE_SEC, PRE_SEC + 30.0)

            odor_start, odor_end = odor
            # Synthetic axis: only the odor duration comes from the data.
            total_duration = PRE_SEC + (odor_end - odor_start) + POST_SEC
            time_axis = np.linspace(0, total_duration, len(pct))

            row = build_row(time_axis, pct, PRE_SEC, PRE_SEC + (odor_end - odor_start),
                            p.stem, apply_sentinels=False)
            if row:
                out[tri] = row
        except Exception:
            # Best-effort: a file that cannot be processed is left out of the figure.
            continue

    return out

def _ead_compute_trim_min_max(csvs: List[Path], dist_col: str) -> Tuple[float,float] | None:
    """
    Pool `dist_col` over all CSVs and return (trimmed_min, global_max).

    trimmed_min is the smallest pooled value that is at or above the
    TRIM_FRAC percentile (linear interpolation); global_max is the plain
    maximum. Non-finite values are discarded first. Files that cannot be read
    or lack the column are skipped.

    Returns None when no file could be read or no finite value remains.
    """
    chunks = []
    for csv_path in csvs:
        try:
            column = pd.read_csv(csv_path, usecols=[dist_col])[dist_col]
            chunks.append(pd.to_numeric(column, errors="coerce").to_numpy())
        except Exception:
            continue
    if not chunks:
        return None

    pooled = np.concatenate(chunks)
    pooled = pooled[np.isfinite(pooled)]
    if pooled.size == 0:
        return None

    global_max = float(np.max(pooled))
    cutoff = float(np.percentile(pooled, 100*TRIM_FRAC, method="linear"))
    survivors = pooled[pooled >= cutoff]
    if survivors.size:
        trimmed_min = float(np.min(survivors))
    else:
        trimmed_min = float(np.min(pooled))
    return trimmed_min, global_max

def collect_rows_eye_antenna_pct(fly_dir: Path, category: str) -> Dict[int, dict]:
    """
    Eye_Antenna_Dist/*combined.csv (filtered by filename containing 'training'/'testing').
    If distance_class1_class2_pct missing, compute in-memory using per-fly trimmed-min & global max.

    Args:
        fly_dir: Folder of one fly.
        category: Category label; must occur in the lowercased file name for a
            file to be used, and is also used for trial-index extraction.

    Returns:
        Dict mapping trial index -> row dict produced by build_row (sentinels
        applied). Files with fewer than two valid time values, without a
        recognisable trial index, or that raise while being processed are
        skipped silently. A later file with the same trial index replaces an
        earlier one. The CSV files themselves are not modified.
    """
    out = {}
    base = fly_dir / "Eye_Antenna_Dist"
    if not base.is_dir():
        return out
    csvs = sorted([p for p in base.glob("*combined.csv") if category in p.name.lower()])
    if not csvs:
        return out

    # Column names
    PCT_COL  = "distance_class1_class2_pct"
    DIST_COL = "distance_class1_class2"  # source for fallback pct

    # Precompute per-fly stats if needed
    # Only headers are read here (nrows=0); stats are needed as soon as one file lacks PCT_COL.
    need_stats = False
    for p in csvs:
        try:
            hdr = pd.read_csv(p, nrows=0)
            if PCT_COL not in hdr.columns:
                need_stats = True; break
        except Exception:
            continue
    stats = None
    if need_stats:
        # Stats pool ALL combined CSVs of the fly, not only those of this category.
        stats = _ead_compute_trim_min_max(sorted(base.glob("*combined.csv")), DIST_COL)

    for p in csvs:
        try:
            df = pd.read_csv(p); df.columns = df.columns.str.strip()
            frame_col = find_col(df, ["frame","Frame","frame_num","frame_index"])
            ts_col    = find_col(df, ["timestamp","Timestamp","time","Time","time_seconds","relative_time","time_s"])
            ts, _     = ensure_time_series(df, frame_col, ts_col)
            if ts.notna().sum() < 2:
                continue

            # ensure pct in-memory
            if PCT_COL not in df.columns:
                if stats is None:
                    continue
                fly_min, fly_max = stats
                # NOTE(review): if DIST_COL is absent, df.get returns the scalar default
                # (NaN) instead of a Series — confirm this case is handled as intended.
                dist = pd.to_numeric(df.get(DIST_COL, np.nan), errors="coerce")
                if np.isfinite(fly_min) and np.isfinite(fly_max) and (fly_max > fly_min):
                    pct = (dist - fly_min) / (fly_max - fly_min) * 100.0
                else:
                    # Degenerate range: valid distances map to 0, missing ones stay NaN.
                    pct = pd.Series(np.where(dist.notna(), 0.0, np.nan), index=df.index)
                df[PCT_COL] = pct

            # time + odor
            time_s = pd.to_numeric(ts, errors="coerce"); t0 = float(np.nanmin(time_s))
            rel = time_s - t0
            df["relative_time"] = rel
            odor = odor_window_from_ofm(df, "relative_time")
            if odor is None:
                odor = (PRE_SEC, PRE_SEC + 30.0)
            odor_start, odor_end = odor
            # The plotted axis is synthetic: only the odor DURATION is taken from the
            # data; the odor period is always drawn starting at PRE_SEC.
            total_duration = PRE_SEC + (odor_end - odor_start) + POST_SEC

            # dense alignment if frames exist
            if frame_col is not None and frame_col in df.columns:
                frames = pd.to_numeric(df[frame_col], errors="coerce").dropna().astype(int)
                # One slot per frame number from min to max; frames absent from the CSV stay NaN.
                total_frames = np.arange(frames.min(), frames.max()+1, dtype=int)
                idx_map = {f:i for i,f in enumerate(total_frames)}
                present_idx = [idx_map.get(int(f)) for f in frames if int(f) in idx_map]
                # frames.index keeps only rows with a valid frame number, so vals lines up with present_idx.
                vals = pd.to_numeric(df.loc[frames.index, PCT_COL], errors="coerce").to_numpy()
                full = np.full_like(total_frames, np.nan, dtype=float)
                if present_idx: full[np.array(present_idx, dtype=int)] = vals
                time_axis = np.linspace(0, total_duration, len(total_frames))
            else:
                vals = pd.to_numeric(df[PCT_COL], errors="coerce").to_numpy()
                time_axis = np.linspace(0, total_duration, len(vals)); full = vals

            tri = _extract_trial_fallback(p.stem, category)
            if tri is None:
                continue
            row = build_row(time_axis, full, PRE_SEC, PRE_SEC+(odor_end-odor_start), p.stem, True)
            if row: out[tri] = row
        except Exception:
            # Best-effort: any failure on one file just drops that file.
            continue
    return out

# ───────── Plotting ─────────
def union_bbox(ax_list):
    """
    Return the bounding box enclosing the positions of all axes in `ax_list`.

    Positions come from each axes' get_position(). None is returned for an
    empty list; for a single axes its own position object is returned.
    """
    positions = [ax.get_position() for ax in ax_list]
    if not positions:
        return None
    merged, *others = positions
    for box in others:
        merged = Bbox.union([merged, box])
    return merged

def plot_category_mega(fly_dir: Path, category: str):
    """Render and save the 3-row "MEGA" heatmap figure for one fly and category.

    Columns are the trial keys found by any of the three row collectors
    (sorted union). Rows are, top to bottom: Eye_Prob_Dist distance %,
    centered angle %, and Eye_Antenna_Dist robust %. A cell with no data for
    its trial shows the text "Missing". Red vertical lines mark each row's
    ``odor_start`` / ``odor_end``.

    The figure is written to
    ``<fly_dir>/mega_heatmaps/<fly>_<category>_MEGA_heatmaps.png``.
    Nothing is drawn (only a warning is printed) when no trials are found.
    """
    fly_name = fly_dir.name
    out_dir = fly_dir / "mega_heatmaps"
    out_dir.mkdir(exist_ok=True)

    # Per-row data: dicts keyed by trial, one dict per heatmap row.
    rows_top = collect_rows_eye_prob_dist(fly_dir, category)        # Eye_Prob_Dist
    rows_angle = collect_rows_angle_centered_pct(fly_dir, category)  # Centered angle %
    rows_pct = collect_rows_eye_antenna_pct(fly_dir, category)       # Eye_Antenna_Dist robust %

    all_trials = sorted(set(rows_top) | set(rows_pct) | set(rows_angle))
    if not all_trials:
        print(f"[WARN] No trials found for {fly_name} / {category}. Skipping.")
        return

    n_rows = 3
    n_cols = len(all_trials)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4.6*n_cols, 2.0*n_rows + 1.0), squeeze=False)

    # Viridis flags out-of-range values in pink; both maps paint NaN as gray.
    cmap_v = mpl.colormaps['viridis'].copy()
    cmap_v.set_under('pink')
    cmap_v.set_over('pink')
    cmap_v.set_bad('dimgray')
    cmap_a = mpl.colormaps['coolwarm'].copy()
    cmap_a.set_bad('dimgray')

    for col, tri in enumerate(all_trials):
        axs[0, col].set_title(f"{category}_{tri}", fontsize=12, pad=6)

    row_labels = ["Top distance % (Eye_Prob_Dist)", "Centered angle %", "Class1–Class2 % (Eye_Antenna_Dist)"]
    for row_i, label in enumerate(row_labels):
        left_ax = axs[row_i, 0]
        left_ax.text(-0.06, 0.5, label, transform=left_ax.transAxes,
                     rotation=90, va='center', ha='right', fontsize=11)

    # (row data, colormap, norm) for each heatmap row, top to bottom.
    row_specs = (
        (rows_top, cmap_v, VIRIDIS_NORM),
        (rows_angle, cmap_a, COOLWARM_NORM),
        (rows_pct, cmap_v, VIRIDIS_NORM),
    )
    # Last mesh drawn in each row; kept so the colorbars have a mappable.
    last_mesh = [None, None, None]

    for col, tri in enumerate(all_trials):
        for row_i, (rows, cmap, norm) in enumerate(row_specs):
            ax = axs[row_i, col]
            r = rows.get(tri)
            if r:
                last_mesh[row_i] = ax.pcolormesh(r["time_edges"], [0,1], r["data_row"],
                                                 cmap=cmap, norm=norm, shading='auto')
                ax.axvline(r["odor_start"], color='red', linewidth=2.0)
                ax.axvline(r["odor_end"], color='red', linewidth=2.0)
            else:
                ax.text(0.5, 0.5, "Missing", ha='center', va='center', transform=ax.transAxes)
            ax.set_yticks([])
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

    pcm_top, pcm_ang, pcm_pct = last_mesh

    for col in range(n_cols):
        axs[2, col].set_xlabel("Time (s)")

    # Make room on the right for two colorbars
    fig.tight_layout(rect=[0.08, 0.02, 0.86, 0.94])
    fig.canvas.draw_idle()

    # Colorbars span the full vertical extent of the axes grid.
    grid_bb = union_bbox(list(axs.flat))
    if grid_bb is not None:
        y0, h = grid_bb.y0, grid_bb.height
    else:
        y0, h = 0.12, 0.76

    cbar_w = 0.015
    gap    = 0.02
    x_vir  = 0.89
    x_ang  = x_vir + cbar_w + gap

    if (pcm_top or pcm_pct):
        cax_vir = fig.add_axes([x_vir, y0, cbar_w, h])
        fig.colorbar(pcm_pct if pcm_pct else pcm_top, cax=cax_vir, orientation='vertical', extend='both')
        cax_vir.set_ylabel("Distance Percentage (%)")

    if pcm_ang:
        cax_ang = fig.add_axes([x_ang, y0, cbar_w, h])
        fig.colorbar(pcm_ang, cax=cax_ang, orientation='vertical')
        cax_ang.set_ylabel("Centered Angle (%)")

    odor_handle = [Line2D([0],[0], color='red', linewidth=2.0, label='Odor Period')]
    fig.legend(handles=odor_handle, loc='upper left', frameon=True)

    fig.suptitle(f"{fly_name} — {category.capitalize()} (All trials grouped by index)", fontsize=14)

    out_path = out_dir / f"{fly_name}_{category}_MEGA_heatmaps.png"
    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)
    print(f"[SAVED] {out_path}")

def plot_all_flies_mega(main_directory: Path | str):
    md = Path(main_directory)
    assert md.is_dir(), f"Not a directory: {md}"

    # Only month-prefixed fly folders (full or 3-letter), case-insensitive
    fly_dirs = [p for p in md.iterdir() if p.is_dir() and any(p.name.lower().startswith(m) for m in MONTHS)]
    if not fly_dirs:
        print(f"[WARN] No month-prefixed fly folders found in {md}"); return

    for fly_dir in sorted(fly_dirs):
        for category in ("training","testing"):
            plot_category_mega(fly_dir, category)

# Run for every fly folder under the root.
# Expects `main_directory` to have been defined by an earlier notebook cell;
# a NameError here means that cell has not been executed yet.
plot_all_flies_mega(main_directory)

In [None]:
# JUPYTER CELL — Collect all MEGA heatmaps into main_directory/MEGA_Heatmap

from pathlib import Path
import shutil

# --- Config ---
# MAIN_DIR   : root that contains the per-fly folders (taken from `main_directory`).
# DEST_DIR   : flat collection folder created inside MAIN_DIR.
# MOVE_FILES : copy by default; set True to move the PNGs instead.
assert 'main_directory' in globals(), "Define main_directory = '/path/to/root' first."
MAIN_DIR   = Path(main_directory).resolve()
DEST_DIR   = MAIN_DIR / "MEGA_Heatmap"
MOVE_FILES = False  # set True to move instead of copy

assert MAIN_DIR.is_dir(), f"Not a directory: {MAIN_DIR}"
DEST_DIR.mkdir(exist_ok=True)

# Find all mega heatmaps created by the previous code.
# DEST_DIR lives inside MAIN_DIR, so the recursive search would also return
# PNGs collected by an earlier run of this cell; those must not be treated as
# sources, otherwise every re-run copies the collection back into itself.
mega_pngs = sorted(
    p for p in MAIN_DIR.rglob("*_MEGA_heatmaps.png")
    if p.is_file() and DEST_DIR not in p.parents
)

if not mega_pngs:
    print(f"[INFO] No MEGA heatmaps found under {MAIN_DIR}")
else:
    copied, skipped = 0, 0
    for src in mega_pngs:
        dst = DEST_DIR / src.name
        if dst.exists():
            # Avoid collision by appending a numeric suffix: "name (1).png", "name (2).png", ...
            stem, suf = dst.stem, dst.suffix
            k = 1
            while True:
                alt = DEST_DIR / f"{stem} ({k}){suf}"
                if not alt.exists():
                    dst = alt
                    break
                k += 1
        try:
            if MOVE_FILES:
                shutil.move(str(src), str(dst))
            else:
                shutil.copy2(str(src), str(dst))
            copied += 1
        except Exception as e:
            # Best effort: report the failure and keep going with the remaining files.
            skipped += 1
            print(f"[WARN] Could not {'move' if MOVE_FILES else 'copy'} {src} → {dst}: {e}")

    action = "moved" if MOVE_FILES else "copied"
    print(f"[DONE] {copied} file(s) {action} to {DEST_DIR} | {skipped} skipped.")
    print(f"[DEST] {DEST_DIR}")

# Videos With 3 Line Plots

In [None]:
#!/usr/bin/env python3
# Move trial videos into per-fly <DEST_FOLDER>/{testing,training} (DEST_FOLDER is configured below)
# Assumes trial folders are named like: /.../august_06_fly_3/august_06_fly_3_testing_1/

from pathlib import Path
import shutil
import re

# --- Configure this to your root that contains the fly folders ---
# `main_directory` must already be defined by an earlier cell.
ROOT = Path(main_directory).expanduser().resolve()

# Destination folder name under each fly; videos land in
# <fly>/<DEST_FOLDER>/<phase>/ (see move_videos_from_trial).
# NOTE(review): the line-plot cell further down reads its input videos from
# "three_line_videos" (VIDEO_INPUT_DIR), not from this folder name; confirm
# which of the two names is the intended one.
DEST_FOLDER = "videos_with_rms"

# Video extensions to collect (lowercase, with dot)
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".mpg", ".mpeg", ".m4v"}

# Set to True for a dry run (no files moved)
DRY_RUN = False

# Trial folder name: "<fly>_<testing|training>_<index>", matched case-insensitively.
# The greedy "fly" group takes everything before the last "_<phase>_<digits>" suffix.
TRIAL_DIR_RE = re.compile(r"^(?P<fly>.+)_(?P<phase>testing|training)_(?P<idx>\d+)$", re.IGNORECASE)

def is_video(p: Path) -> bool:
    """Return True when *p* is a regular file whose extension (case-insensitive) is in VIDEO_EXTS."""
    if not p.is_file():
        return False
    return p.suffix.lower() in VIDEO_EXTS

def ensure_unique_path(dst: Path) -> Path:
    """Return *dst* if it is free, else the first free ``<stem>_dupN<suffix>`` sibling (N = 1, 2, ...)."""
    candidate = dst
    attempt = 0
    while candidate.exists():
        attempt += 1
        candidate = dst.with_name(f"{dst.stem}_dup{attempt}{dst.suffix}")
    return candidate

def move_videos_from_trial(trial_dir: Path, fly_dir: Path, phase: str):
    """Move every video directly inside *trial_dir* to ``<fly_dir>/<DEST_FOLDER>/<phase>/``.

    Name clashes at the destination are resolved with ``ensure_unique_path``.
    With ``DRY_RUN`` enabled nothing is created or moved, but each planned
    move is still printed and counted.

    Returns:
        Number of videos moved (or that would be moved in a dry run).
    """
    target_root = fly_dir / DEST_FOLDER / phase.lower()
    if not DRY_RUN:
        target_root.mkdir(parents=True, exist_ok=True)

    log_prefix = 'DRY-RUN: ' if DRY_RUN else ''
    count = 0
    for entry in sorted(trial_dir.iterdir()):
        if is_video(entry):
            target = ensure_unique_path(target_root / entry.name)
            print(f"{log_prefix}Moving {entry} -> {target}")
            if not DRY_RUN:
                shutil.move(str(entry), str(target))
            count += 1
    return count

def main():
    """Walk ROOT and move each trial folder's videos into its fly's destination folder.

    A fly folder is a direct child of ROOT whose name contains "_fly_" and
    does not end with "testing"/"training". Inside it, a trial folder is a
    direct child matching TRIAL_DIR_RE whose "fly" prefix equals the fly
    folder's name; folders whose name starts with DEST_FOLDER are ignored.

    Raises:
        SystemExit: if ROOT does not exist.
    """
    if not ROOT.exists():
        raise SystemExit(f"ROOT does not exist: {ROOT}")

    # Fly folders sit directly under ROOT
    for fly_dir in sorted(p for p in ROOT.iterdir() if p.is_dir()):
        fly_name = fly_dir.name

        # Heuristic: real fly folders contain '_fly_' ...
        if "_fly_" not in fly_name:
            continue
        # ... and never look like a phase/destination folder (defensive)
        if fly_name.lower().endswith(("testing", "training")):
            continue

        print(f"\n=== Fly: {fly_name} ===")

        # Trial folders are the immediate sub-directories of the fly folder
        for trial_dir in sorted(p for p in fly_dir.iterdir() if p.is_dir()):
            trial_name = trial_dir.name

            # Never descend into our own output folder
            if trial_name.lower().startswith(DEST_FOLDER.lower()):
                continue

            m = TRIAL_DIR_RE.match(trial_name)
            # Must follow the trial naming pattern AND belong to this fly
            if m is None or m.group("fly") != fly_name:
                continue

            moved = move_videos_from_trial(trial_dir, fly_dir, m.group("phase").lower())
            if moved == 0:
                print(f"  (No videos found in {trial_dir})")

    print("\nDone.")

if __name__ == "__main__":
    main()

In [None]:
# JUPYTER CELL — Line-plot panel under the video (Top %, Centered angle %, Class1–Class2 %)
# Sources & normalization mirror the FIXED MEGA heatmap code.
# Videos are read from: {fly}/three_line_videos/{training,testing}/
# Outputs go to:        {fly}/three_line_videos/with_line_plots/{fly}_{category}_{trial}_LINES_three_series.mp4

import os, re, glob, io
from pathlib import Path
from typing import Optional, Dict, List, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from PIL import Image
from moviepy.editor import VideoFileClip, VideoClip, CompositeVideoClip

# ─────────────────────────────────────────────────────────────
# REQUIRED: set main_directory (Path or str) in an earlier cell.
# ROOT is the absolute, user-expanded form of that path; both asserts fail fast
# (AssertionError) so the rest of the cell never runs against a bad root.
assert 'main_directory' in globals(), "Define main_directory = '/path/to/root' before running."
ROOT = Path(main_directory).expanduser().resolve()
assert ROOT.is_dir(), f"Not a directory: {ROOT}"

# ─────────────────────────────────────────────────────────────
# Visual defaults
# NOTE: rcParams.update changes matplotlib's global defaults, so these settings
# also apply to figures created by cells that run after this one.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'lines.linewidth': 2,
    'axes.linewidth': 1.25,
    'figure.dpi': 160,
    'savefig.dpi': 300,
    'legend.fontsize': 11
})

# ─────────────────────────────────────────────────────────────
# Constants (aligned with MEGA)
# FPS_DEFAULT: frame rate derive_fps() falls back to when the CSV has no usable fps column.
FPS_DEFAULT   = 40.0
# Fixed anchor point used as one arm of the angle in compute_angle_deg_at_point2()
# (presumably pixel coordinates in the video frame; confirm).
ANCHOR_X, ANCHOR_Y = 1080.0, 540.0
# Column names in the Eye_Antenna_Dist "*combined.csv" files: percentage and raw distance.
PCT_COL_ROBUST = "distance_class1_class2_pct"
DIST_COL_ROBUST = "distance_class1_class2"
# Seconds placed before / after the odor window when building the plotted time axis.
PRE_SEC, POST_SEC = 30.0, 90.0
TRIM_FRAC = 0.05  # bottom 5% trimmed-min for Eye_Antenna_Dist fallback
# Month-name prefixes (full and 3-letter); "may" appears in both groups, which is harmless.
MONTHS = (
    "january","february","march","april","may","june",
    "july","august","september","october","november","december",
    "jan","feb","mar","apr","may","jun","jul","aug","sep","oct","nov","dec"
)

# ── Panel rendering options ─────────────────────────────────────
# Panel height as a fraction of the source video's height.
PANEL_HEIGHT_FRACTION = 0.24
YLIM = (-100, 100)  # angle is [-100,100], distance lines are [0,100] → keep unified
VIDEO_INPUT_DIR = "three_line_videos"              # where copied trial videos live
VIDEO_OUTPUT_SUBDIR = "with_line_plots"            # under three_line_videos/

# ─────────────────────────────────────────────────────────────
# Helpers (same behavior as MEGA)
# Delete the source videos after rendering new ones
# NOTE(review): destructive default; presumably consumed by the driver code
# later in this cell, which is not visible here.
DELETE_SOURCE_AFTER_RENDER = True
# Optionally remove the input category folder if it becomes empty
DELETE_EMPTY_INPUT_DIRS = True

# NOTE(review): the same set is assigned again before _find_video_for_trial
# further down; the two definitions must be kept in sync.
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".mpg", ".mpeg", ".m4v"}  # keep as in your script

def _safe_unlink(p: Path):
    """Delete file *p* if it exists; never raise (failures are printed as warnings)."""
    try:
        if not p.exists():
            return
        p.unlink()
        print(f"    Deleted source video: {p.name}")
    except Exception as e:
        print(f"    [warn] Could not delete {p}: {e}")

def _maybe_rmdir_empty(dir_path: Path):
    """Remove *dir_path* when it exists and has no entries; never raise (failures are printed)."""
    try:
        # Peek at the first entry only: "no first entry" means the folder is empty.
        if dir_path.exists() and next(dir_path.iterdir(), None) is None:
            dir_path.rmdir()
            print(f"    Removed empty folder: {dir_path}")
    except Exception as e:
        print(f"    [warn] Could not remove folder {dir_path}: {e}")

def timestamp_to_seconds(ts) -> float:
    """Convert a timestamp to seconds, returning NaN when it cannot be parsed.

    Accepted inputs:
      * missing values (``pd.isna``)          -> NaN
      * anything ``float()`` accepts          -> that float
      * "HH:MM:SS:mmm" (last field = millis)  -> seconds
      * "HH:MM:SS[.fff]" and "MM:SS[.fff]"    -> seconds
    Any other text, or colon-separated text with unparsable fields, gives NaN.
    """
    if pd.isna(ts):
        return np.nan

    try:
        return float(ts)
    except Exception:
        pass  # not a plain number; fall through to the colon-separated formats

    fields = str(ts).strip().split(":")
    count = len(fields)
    try:
        if count == 1:
            return float(fields[0])
        if count == 2:
            minutes, seconds = fields
            return int(minutes)*60 + float(seconds)
        if count == 3:
            hours, minutes, seconds = fields
            return int(hours)*3600 + int(minutes)*60 + float(seconds)
        if count == 4:
            hours, minutes, seconds, millis = fields
            return int(hours)*3600 + int(minutes)*60 + int(seconds) + int(millis)/1000.0
    except Exception:
        return np.nan
    return np.nan

def find_col(df: pd.DataFrame, cands: List[str]) -> Optional[str]:
    """Return the first name in *cands* that is a column of *df* (exact match), or None."""
    available = set(df.columns)
    return next((name for name in cands if name in available), None)

def derive_fps(df: pd.DataFrame) -> float:
    """Return the frame rate for *df*.

    Uses the median of the first available fps-like column ("fps", "FPS",
    "frame_rate", "frameRate") when that median is finite and positive;
    otherwise returns FPS_DEFAULT.
    """
    fps_col = find_col(df, ["fps","FPS","frame_rate","frameRate"])
    if fps_col is None:
        return FPS_DEFAULT

    median_fps = pd.to_numeric(df[fps_col], errors="coerce").median()
    if np.isfinite(median_fps) and median_fps > 0:
        return float(median_fps)
    return FPS_DEFAULT

def ensure_time_series(df: pd.DataFrame, frame_col: Optional[str], ts_col: Optional[str]):
    """Build a per-row time series in seconds, trying three sources in order.

    1. *ts_col*: numeric for "time_seconds"/"relative_time"/"time_s", otherwise
       parsed with ``timestamp_to_seconds``. Accepted when it has at least two
       non-NaN values and at least one positive step between consecutive ones.
    2. *frame_col*: ``(frame - first_frame) / fps`` when at least two frames parse.
    3. Row position divided by fps.

    Returns:
        ``(series, info)`` where ``info['used']`` is 'timestamp',
        'frame_fallback' or 'index_fallback'; the fallbacks also carry 'fps'.
    """
    already_numeric = ["time_seconds","relative_time","time_s"]

    if ts_col is not None:
        if ts_col in already_numeric:
            ts = pd.to_numeric(df[ts_col], errors="coerce")
        else:
            ts = df[ts_col].apply(timestamp_to_seconds)
        if ts.notna().sum() >= 2 and np.nanmax(np.diff(ts.dropna().values)) > 0:
            return ts, {'used':'timestamp'}

    if frame_col is not None:
        frames = pd.to_numeric(df[frame_col], errors="coerce")
        if frames.notna().sum() >= 2:
            fps = derive_fps(df)
            first_frame = int(np.nanmin(frames.values))
            return (frames - first_frame)/fps, {'used':'frame_fallback','fps':fps}

    fps = derive_fps(df)
    seconds = np.arange(len(df), dtype=float)/fps
    return pd.Series(seconds, index=df.index), {'used':'index_fallback','fps':fps}

def extract_trial_index(name: str, category: str) -> Optional[int]:
    """Extract the trial number from a file or folder *name*.

    Lookup order (first hit wins):
      1. ``<category>_<digits>``, case-insensitive, category taken literally;
      2. ``training_<digits>`` or ``testing_<digits>``;
      3. the last run of digits anywhere in *name*.

    Returns:
        The trial number as ``int``, or ``None`` when *name* has no digits.
    """
    lowered = name.lower()

    # The name is lower-cased, so the category must be too; escaping keeps any
    # regex metacharacters in the category from being interpreted as a pattern.
    m = re.search(rf"{re.escape(category.lower())}_(\d+)", lowered)
    if m:
        return int(m.group(1))

    m = re.search(r"(?:training|testing)_(\d+)", lowered)
    if m:
        return int(m.group(1))

    # fallback: last number in stem
    nums = re.findall(r"\d+", name)
    return int(nums[-1]) if nums else None

def odor_window_from_ofm(df: pd.DataFrame, time_col: str = "relative_time"):
    """Return ``(start, end)`` times of the rows whose OFM state is "during".

    The state column is the first of "OFM State", "OFM_State", "ofm_state",
    "ofm" present in *df*; values are compared case-insensitively. The times
    are read from *time_col* at the first and last matching rows.

    Returns ``None`` when the state or time column is missing, when fewer than
    two rows are "during", or when anything goes wrong while reading them.
    """
    ofm_col = find_col(df, ["OFM State","OFM_State","ofm_state","ofm"])
    if ofm_col is None or time_col not in df.columns:
        return None

    try:
        during = df[ofm_col].astype(str).str.lower().eq("during")
        if not during.any():
            return None
        times = df[time_col]
        hits = np.flatnonzero(during.to_numpy())
        if hits.size < 2:
            return None
        return float(times.iloc[hits[0]]), float(times.iloc[hits[-1]])
    except Exception:
        return None

def compute_angle_deg_at_point2(df: pd.DataFrame) -> pd.Series:
    """Unsigned angle (degrees, 0..180) at the class-2 point, per row.

    The angle lies between the vector class2 -> (ANCHOR_X, ANCHOR_Y) and the
    vector class2 -> class6. Rows where either vector has zero length or the
    inputs are not finite get NaN.

    Returns:
        Series named "angle_ARB_deg" sharing *df*'s index.

    Raises:
        ValueError: if any of x_class2, y_class2, x_class6, y_class6 is missing.
    """
    required = ("x_class2", "y_class2", "x_class6", "y_class6")
    if not all(col in df.columns for col in required):
        raise ValueError("Missing angle columns")

    vertex_x, vertex_y, tip_x, tip_y = (df[col].astype(float).to_numpy() for col in required)

    # u: vertex -> anchor, v: vertex -> class-6 point
    ux = ANCHOR_X - vertex_x
    uy = ANCHOR_Y - vertex_y
    vx = tip_x - vertex_x
    vy = tip_y - vertex_y

    dot = ux*vx + uy*vy
    cross = ux*vy - uy*vx
    len_u = np.hypot(ux, uy)
    len_v = np.hypot(vx, vy)
    valid = (len_u > 0) & (len_v > 0) & np.isfinite(dot) & np.isfinite(cross)

    ang = np.full(len(vertex_x), np.nan)
    # atan2(|cross|, dot) gives the unsigned angle without dividing by the lengths.
    ang[valid] = np.degrees(np.arctan2(np.abs(cross[valid]), dot[valid]))
    return pd.Series(ang, index=df.index, name="angle_ARB_deg")

def find_fly_reference_angle(csvs_raw: List[Path]) -> float:
    """Pick the fly's reference angle: the angle at the row where distance % is closest to 0.

    Every CSV that has the four angle columns and a distance-percentage column
    contributes one candidate ``(rank, |distance|, angle)``: rank 0 for the
    first row whose distance is exactly 0, otherwise rank 1 for the row with
    the smallest absolute distance. The smallest candidate (tuple ordering)
    wins. Files that fail to load or process are skipped silently.

    Returns:
        The winning candidate's angle, or NaN when there is no candidate.

    NOTE(review): a candidate's angle may itself be NaN, and tuple comparison
    never lets a later equal-rank candidate replace it, so a NaN reference can
    be returned even though another file had a finite angle; confirm intended.
    """
    angle_cols = {"x_class2", "y_class2", "x_class6", "y_class6"}
    dist_names = ["distance_percentage","distance_percent","distance_pct","distance_class1_class2_pct"]

    best = None
    for csv_path in csvs_raw:
        try:
            df = pd.read_csv(csv_path)
            if not angle_cols.issubset(df.columns):
                continue
            ang = compute_angle_deg_at_point2(df)
            dist_col = find_col(df, dist_names)
            if dist_col is None:
                continue
            dist = pd.to_numeric(df[dist_col], errors="coerce").to_numpy()

            zero_rows = np.flatnonzero(dist == 0)
            if zero_rows.size > 0:
                rank, gap, row = 0, 0.0, int(zero_rows[0])
            else:
                with np.errstate(invalid="ignore"):
                    absd = np.abs(dist)
                if not np.isfinite(absd).any():
                    continue
                row = int(np.nanargmin(absd))
                rank, gap = 1, float(absd[row])

            angle_here = float(ang.iloc[row]) if np.isfinite(ang.iloc[row]) else np.nan
            cand = (rank, gap, angle_here)
            if best is None or cand < best:
                best = cand
        except Exception:
            # Best effort: an unreadable or malformed file must not abort the scan.
            pass

    if best is None:
        return np.nan
    return best[2]

def compute_fly_max_abs_centered(csvs_raw: List[Path], ref_angle: float) -> float:
    """Largest ``|angle - ref_angle|`` over all rows of all readable CSVs.

    CSVs lacking the four angle columns are ignored, and files that fail to
    load or process are skipped silently. When *ref_angle* is not finite the
    centered values are all zero.

    Returns:
        The maximum as a float, or NaN when it is not strictly positive.
    """
    angle_cols = {"x_class2", "y_class2", "x_class6", "y_class6"}
    peak = 0.0
    for csv_path in csvs_raw:
        try:
            df = pd.read_csv(csv_path)
            if angle_cols.issubset(df.columns):
                ang = compute_angle_deg_at_point2(df)
                if np.isfinite(ref_angle):
                    centered = ang - ref_angle
                else:
                    centered = ang*0.0
                local = np.nanmax(np.abs(centered.to_numpy(dtype=float)))
                if np.isfinite(local):
                    peak = max(peak, float(local))
        except Exception:
            # Best effort: skip files that cannot be read or processed.
            pass

    if np.isfinite(peak) and peak > 0:
        return peak
    return np.nan

def _infer_category_from_path(p: Path) -> Optional[str]:
    """Guess "training" / "testing" from the components of path *p*.

    Components are examined from the most specific (the file name) outwards to
    the root, and the first component that names a category decides. An
    ancestor directory (for example a data root called ``Training``) therefore
    cannot override what the file or its own folder says. Within a single
    component "training" is checked before "testing".

    Returns:
        "training", "testing", or ``None`` when no component mentions either.
    """
    training_tokens = ("training", "train", "trn")
    testing_tokens = ("testing", "test", "tst")

    for component in reversed(p.parts):
        text = component.lower()
        if any(tok in text for tok in training_tokens):
            return "training"
        if any(tok in text for tok in testing_tokens):
            return "testing"
    return None

# ─────────────────────────────────────────────────────────────
# Series collectors (mirror MEGA). Return (t, y, odor_on, odor_off).

def _series_top_distance(fly_dir: Path, category: str, trial_index: int):
    """Load the Eye_Prob_Dist distance-% series for one trial.

    CSVs are taken from ``<fly_dir>/Eye_Prob_Dist/<category>/`` when that
    folder exists, otherwise from ``<fly_dir>/Eye_Prob_Dist/`` filtered by
    *category* in the file name. The first CSV (sorted) whose trial index
    equals *trial_index* is used.

    Returns:
        ``(t, y, odor_on, odor_off)`` or ``None`` when no CSV, no percentage
        column, or fewer than two usable timestamps are found. ``t`` is evenly
        spaced over ``PRE_SEC + odor length + POST_SEC``; when a frame column
        exists, ``y`` is laid out on the full frame range with NaN for frames
        that are absent from the CSV.
    """
    category_dir = fly_dir / "Eye_Prob_Dist" / category
    candidates = []
    if category_dir.is_dir():
        candidates = sorted(category_dir.glob("*.csv"))
    else:
        flat_dir = fly_dir / "Eye_Prob_Dist"
        if flat_dir.is_dir():
            candidates = sorted([p for p in flat_dir.glob("*.csv") if category in p.name.lower()])

    chosen = next(
        (p for p in candidates if extract_trial_index(p.stem, category) == trial_index),
        None,
    )
    if chosen is None:
        return None

    df = pd.read_csv(chosen)
    df.columns = df.columns.str.strip()
    frame_col = find_col(df, ["frame","Frame","frame_num","frame_index"])
    ts_col    = find_col(df, ["timestamp","Timestamp","time","Time","time_seconds","relative_time","time_s"])
    pct_col   = find_col(df, [
        "distance_percentage","distance_percent","distance_pct",
        "distance_proboscis_eye_pct","proboscis_eye_distance_pct",
        "eye_prob_distance_pct","distance_class1_class2_pct"
    ])
    if pct_col is None:
        return None

    ts, _ = ensure_time_series(df, frame_col, ts_col)
    if ts.notna().sum() < 2:
        return None

    # Time relative to the first sample; used only to locate the odor window.
    time_s = pd.to_numeric(ts, errors="coerce")
    t0 = float(np.nanmin(time_s))
    df["relative_time"] = time_s - t0
    odor = odor_window_from_ofm(df, "relative_time") or (PRE_SEC, PRE_SEC + 30.0)
    odor_start, odor_end = odor
    odor_len = odor_end - odor_start
    total_duration = PRE_SEC + odor_len + POST_SEC

    if frame_col is not None and frame_col in df.columns:
        # Dense alignment: one slot per frame number between min and max.
        frames = pd.to_numeric(df[frame_col], errors="coerce").dropna().astype(int)
        total_frames = np.arange(frames.min(), frames.max()+1, dtype=int)
        slot_of = {f:i for i,f in enumerate(total_frames)}
        present_idx = [slot_of.get(int(f)) for f in frames if int(f) in slot_of]
        vals = pd.to_numeric(df.loc[frames.index, pct_col], errors="coerce").to_numpy()
        y = np.full_like(total_frames, np.nan, dtype=float)
        if present_idx:
            y[np.array(present_idx, dtype=int)] = vals
        t = np.linspace(0, total_duration, len(total_frames))
    else:
        y = pd.to_numeric(df[pct_col], errors="coerce").to_numpy()
        t = np.linspace(0, total_duration, len(y))

    return t, y.astype(float), float(PRE_SEC), float(PRE_SEC + (odor_end - odor_start))

# Per-fly cache: fly_dir -> (trimmed minimum, global maximum) of the raw
# Eye_Antenna_Dist distance. Only successful computations are stored.
_ead_stats_cache: Dict[Path, Tuple[float,float]] = {}

def _ead_compute_trim_min_max(fly_dir: Path) -> Optional[Tuple[float,float]]:
    """Return ``(trimmed_min, global_max)`` of DIST_COL_ROBUST over a fly's combined CSVs.

    All finite values of the distance column in
    ``<fly_dir>/Eye_Antenna_Dist/*combined.csv`` are pooled. The maximum is the
    plain maximum; the minimum is the smallest value at or above the
    ``100*TRIM_FRAC`` percentile. CSVs that cannot be read are skipped.

    Returns ``None`` when the folder is missing or no finite value is found.
    Results are memoised per *fly_dir* in ``_ead_stats_cache``.
    """
    cached = _ead_stats_cache.get(fly_dir)
    if cached is not None:
        return cached

    base = fly_dir / "Eye_Antenna_Dist"
    if not base.is_dir():
        return None

    chunks = []
    for csv_path in sorted(base.glob("*combined.csv")):
        try:
            column = pd.read_csv(csv_path, usecols=[DIST_COL_ROBUST])[DIST_COL_ROBUST]
            chunks.append(pd.to_numeric(column, errors="coerce").to_numpy())
        except Exception:
            continue
    if not chunks:
        return None

    pooled = np.concatenate(chunks)
    pooled = pooled[np.isfinite(pooled)]
    if pooled.size == 0:
        return None

    gmax = float(np.max(pooled))
    cutoff = float(np.percentile(pooled, 100*TRIM_FRAC, method="linear"))
    kept = pooled[pooled >= cutoff]
    trimmed_min = float(np.min(kept)) if kept.size else float(np.min(pooled))

    stats = (trimmed_min, gmax)
    _ead_stats_cache[fly_dir] = stats
    return stats

def _series_robust_pct(fly_dir: Path, category: str, trial_index: int):
    """Load the Eye_Antenna_Dist percentage series for one trial.

    The source is the first ``*combined.csv`` (sorted) in
    ``<fly_dir>/Eye_Antenna_Dist/`` whose name contains *category* and whose
    trial index equals *trial_index*. If the CSV lacks PCT_COL_ROBUST, the
    percentage is computed in memory from DIST_COL_ROBUST using the per-fly
    trimmed-min / max from ``_ead_compute_trim_min_max``.

    Returns:
        ``(t, y, odor_on, odor_off)`` or ``None`` when the folder, a matching
        CSV, usable timestamps, or the per-fly statistics are unavailable.
    """
    base = fly_dir / "Eye_Antenna_Dist"
    if not base.is_dir():
        return None

    csvs = sorted([p for p in base.glob("*combined.csv") if category in p.name.lower()])
    chosen = next(
        (p for p in csvs if extract_trial_index(p.stem, category) == trial_index),
        None,
    )
    if chosen is None:
        return None

    df = pd.read_csv(chosen)
    df.columns = df.columns.str.strip()
    frame_col = find_col(df, ["frame","Frame","frame_num","frame_index"])
    ts_col    = find_col(df, ["timestamp","Timestamp","time","Time","time_seconds","relative_time","time_s"])
    ts, _     = ensure_time_series(df, frame_col, ts_col)
    if ts.notna().sum() < 2:
        return None

    # Ensure percentage in-memory if missing
    if PCT_COL_ROBUST not in df.columns:
        stats = _ead_compute_trim_min_max(fly_dir)
        if stats is None:
            return None
        fly_min, fly_max = stats
        dist = pd.to_numeric(df.get(DIST_COL_ROBUST, np.nan), errors="coerce")
        usable_range = np.isfinite(fly_min) and np.isfinite(fly_max) and (fly_max > fly_min)
        if usable_range:
            df[PCT_COL_ROBUST] = (dist - fly_min) / (fly_max - fly_min) * 100.0
        else:
            # Degenerate range: every measured row maps to 0 %, missing rows stay NaN.
            df[PCT_COL_ROBUST] = np.where(dist.notna(), 0.0, np.nan)

    df["time_seconds"] = pd.to_numeric(ts, errors="coerce")
    df["relative_time"] = df["time_seconds"] - df["time_seconds"].min()
    odor = odor_window_from_ofm(df, "relative_time") or (PRE_SEC, PRE_SEC + 30.0)
    odor_start, odor_end = odor
    odor_len = odor_end - odor_start
    total_duration = PRE_SEC + odor_len + POST_SEC

    if frame_col is not None and frame_col in df.columns:
        # Dense alignment: one slot per frame number between min and max.
        frames = pd.to_numeric(df[frame_col], errors="coerce").dropna().astype(int)
        total_frames = np.arange(frames.min(), frames.max()+1, dtype=int)
        slot_of = {f:i for i,f in enumerate(total_frames)}
        present_idx = [slot_of.get(int(f)) for f in frames if int(f) in slot_of]
        vals = pd.to_numeric(df.loc[frames.index, PCT_COL_ROBUST], errors="coerce").to_numpy()
        y = np.full_like(total_frames, np.nan, dtype=float)
        if present_idx:
            y[np.array(present_idx, dtype=int)] = vals
        t = np.linspace(0, total_duration, len(total_frames))
    else:
        y = pd.to_numeric(df[PCT_COL_ROBUST], errors="coerce").to_numpy()
        t = np.linspace(0, total_duration, len(y))

    return t, y.astype(float), float(PRE_SEC), float(PRE_SEC + (odor_end - odor_start))

def _series_angle_centered_pct(fly_dir: Path, category: str, trial_index: int):
    """Load (or compute) the centered-angle percentage series for one trial.

    A CSV matching ``*distance_class_2_angle_ARB.csv`` anywhere under *fly_dir*
    is chosen by category (inferred from its path) and trial index. If it has
    an ``angle_centered_pct`` column that column is used directly; otherwise
    the angle is computed from the class-2/class-6 coordinates, centered on a
    per-fly reference angle and scaled by the per-fly maximum deviation.

    Returns:
        ``(t, vals, odor_on, odor_off)`` or ``None`` when no matching CSV,
        fewer than two usable timestamps, or no coordinate columns are found.
        ``t`` is evenly spaced over ``PRE_SEC + odor length + POST_SEC``.
    """
    # Discover angle/raw files across the entire fly folder (mirrors robust MEGA logic)
    # NOTE(review): every name matched by the first pattern is also matched by
    # the second, so raw_csvs is a superset of angle_csvs. Given raw_guess
    # below, "raw" files presumably end in "_distance_class_2.csv"; confirm
    # whether the second pattern was meant to target those.
    angle_csvs = list(fly_dir.rglob("*_distance_class_2_angle_ARB.csv"))
    raw_csvs   = list(fly_dir.rglob("*distance_class_2_angle_ARB.csv"))

    def _cat_match(p: Path) -> bool:
        # Files whose path names no category at all are accepted for any category.
        inf = _infer_category_from_path(p)
        return (inf == category) or (inf is None)

    angle_cat = [p for p in angle_csvs if _cat_match(p)]
    raw_cat   = [p for p in raw_csvs   if _cat_match(p)]

    # Choose the file that corresponds to this trial index (prefer angle over raw)
    chosen = None
    for p in sorted(angle_cat):
        if extract_trial_index(p.stem, category) == trial_index:
            chosen = p; break
    if chosen is None:
        for p in sorted(raw_cat):
            if extract_trial_index(p.stem, category) == trial_index:
                chosen = p; break
    if chosen is None:
        return None

    # Build per-fly reference/max from ALL raw CSVs for stable centering
    # NOTE: this re-reads every CSV on each call (no caching), i.e. once per trial.
    raw_for_ref = raw_csvs if raw_csvs else raw_cat
    ref_angle   = find_fly_reference_angle(raw_for_ref) if raw_for_ref else np.nan
    fly_max_abs = compute_fly_max_abs_centered(raw_for_ref, ref_angle) if np.isfinite(ref_angle) else np.nan

    df = pd.read_csv(chosen); df.columns = df.columns.str.strip()
    frame_col = find_col(df, ["frame","Frame","frame_num","frame_index"])
    ts_col    = find_col(df, ["time_s","timestamp","Timestamp","time","Time","time_seconds","relative_time"])
    # "time_s" is taken as-is (numeric); any other column goes through the
    # timestamp -> frame -> index fallback chain of ensure_time_series.
    if ts_col == "time_s":
        ts = pd.to_numeric(df["time_s"], errors="coerce")
    else:
        ts, _ = ensure_time_series(df, frame_col, ts_col)
    if ts.notna().sum() < 2: return None

    # Odor from sibling/raw if possible
    odor = None
    raw_guess = None
    # Sibling file: same name with the "_angle_ARB" part removed.
    if "_distance_class_2_angle_ARB.csv" in chosen.name:
        raw_guess = chosen.with_name(chosen.name.replace("_distance_class_2_angle_ARB.csv","_distance_class_2.csv"))
    if raw_guess and raw_guess.exists():
        dfr = pd.read_csv(raw_guess)
        fr_col = find_col(dfr, ["frame","Frame","frame_num","frame_index"])
        ts_r, _ = ensure_time_series(dfr, fr_col, find_col(dfr, ["timestamp","Timestamp","time","Time","time_seconds","relative_time","time_s"]))
        dfr["time_seconds"] = pd.to_numeric(ts_r, errors="coerce")
        dfr["relative_time"] = dfr["time_seconds"] - dfr["time_seconds"].min()
        odor = odor_window_from_ofm(dfr, "relative_time")
    if odor is None:
        # Fall back to the OFM column of the chosen file itself (on a copy, so df stays untouched).
        tmp = df.copy()
        ts_series = pd.to_numeric(ts, errors="coerce")
        tmp["relative_time"] = ts_series - ts_series.min()
        odor = odor_window_from_ofm(tmp, "relative_time")
    # Last resort: assume a 30 s odor window starting at PRE_SEC.
    if odor is None: odor = (PRE_SEC, PRE_SEC + 30.0)
    odor_start, odor_end = odor
    total_duration = PRE_SEC + (odor_end - odor_start) + POST_SEC

    # Percent line
    if "angle_centered_pct" in df.columns:
        vals = pd.to_numeric(df["angle_centered_pct"], errors="coerce").to_numpy(dtype=float)
    else:
        # Coordinates come from the chosen file, or from the raw sibling if the chosen file lacks them.
        df_pos = df
        if not {"x_class2","y_class2","x_class6","y_class6"}.issubset(df_pos.columns):
            if raw_guess and raw_guess.exists(): df_pos = pd.read_csv(raw_guess)
        if not {"x_class2","y_class2","x_class6","y_class6"}.issubset(df_pos.columns):
            return None
        ang = compute_angle_deg_at_point2(df_pos).to_numpy(dtype=float)
        if np.isfinite(fly_max_abs) and fly_max_abs > 0 and np.isfinite(ref_angle):
            centered = ang - ref_angle; vals = (centered / fly_max_abs) * 100.0
        else:
            # No usable reference or scale: draw a flat zero line rather than nothing.
            vals = np.zeros_like(ang, dtype=float)

    # The time axis is sized from vals, which may come from df_pos rather than df.
    t = np.linspace(0, total_duration, len(vals))
    return t, vals.astype(float), float(PRE_SEC), float(PRE_SEC + (odor_end - odor_start))

# ─────────────────────────────────────────────────────────────
# Video discovery & matching in three_line_videos/<category>

VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".mpg", ".mpeg", ".m4v"}

def _find_video_for_trial(fly_dir: Path, category: str, tri: int) -> Optional[Path]:
    """Locate the source video for trial *tri* in ``<fly_dir>/<VIDEO_INPUT_DIR>/<category>/``.

    Heuristics, first hit wins (candidates are considered in sorted order so
    the choice is reproducible):
      1. stem contains ``<category>_<tri>`` not followed by another digit
         (so trial 1 does not match ``testing_10``), case-insensitive;
      2. stem contains ``_<tri>`` or ``-<tri>`` followed by a non-digit or the end;
      3. the folder holds exactly one video.

    Returns:
        The matching video path, or ``None`` when the folder is missing or no
        heuristic applies.
    """
    vid_dir = fly_dir / VIDEO_INPUT_DIR / category
    if not vid_dir.is_dir():
        return None

    # iterdir() order is arbitrary; sort so "first match" is deterministic.
    vids = sorted(p for p in vid_dir.iterdir() if p.is_file() and p.suffix.lower() in VIDEO_EXTS)

    # Exact token with a digit boundary; a plain substring test would let
    # "<category>_1" match "<category>_10", "<category>_11", ...
    token_re = re.compile(rf"{re.escape(category.lower())}_{tri}(?!\d)")
    cand = [v for v in vids if token_re.search(v.stem.lower())]
    if cand:
        return cand[0]

    # Looser: the bare index after a separator. This can also hit other numbers
    # in the name (for example a fly number), so it is only a fallback.
    loose_re = re.compile(rf"[_\-]{tri}(\D|$)")
    cand = [v for v in vids if loose_re.search(v.stem)]
    if cand:
        return cand[0]

    if len(vids) == 1:
        return vids[0]
    return None

# ─────────────────────────────────────────────────────────────
# Line-panel rendering & composition

def _render_line_panel_png(series_list: List[dict],
                           width_px: int,
                           height_px: int,
                           xlim: Tuple[float,float],
                           ylim: Tuple[float,float],
                           odor_on: float | None,
                           odor_off: float | None) -> np.ndarray:
    """Render the static line-plot panel as an RGB image array.

    Each entry of *series_list* is a dict with keys "t", "y" and "label" and
    is drawn as one line. When both *odor_on* and *odor_off* are finite, red
    vertical lines mark them. A legend is added if there is at least one
    series; the axes themselves are hidden.

    Returns:
        ``np.ndarray`` of shape ``(height_px, width_px, 3)``; the rendered PNG
        is resized to exactly that size.
    """
    fig, ax = plt.subplots(figsize=(width_px/100, height_px/100), dpi=100)
    buf = io.BytesIO()
    try:
        for s in series_list:
            ax.plot(s["t"], s["y"], label=s["label"], linewidth=1)

        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)

        if odor_on is not None and odor_off is not None and np.isfinite(odor_on) and np.isfinite(odor_off):
            ax.axvline(odor_on, color='red', linewidth=1.0)
            ax.axvline(odor_off, color='red', linewidth=1.0)

        if series_list:
            leg = ax.legend(loc='upper right', frameon=True, framealpha=0.9)
            frame = leg.get_frame()
            frame.set_facecolor('white')
            frame.set_edgecolor('black')
            frame.set_linewidth(0.8)

        ax.axis("off")
        # Operate on this figure explicitly instead of pyplot's "current figure".
        fig.tight_layout(pad=0)
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even when plotting or saving fails;
        # otherwise figures accumulate across the many panels rendered per run.
        plt.close(fig)

    buf.seek(0)
    # bbox_inches="tight" changes the pixel size, so resize back to the requested one.
    panel = Image.open(buf).convert("RGB").resize((width_px, height_px), resample=Image.BILINEAR)
    return np.array(panel)

def _compose_lineplot_video(video_path: Path,
                            series_list: List[dict],
                            xlim: Tuple[float,float],
                            odor_on: float | None,
                            odor_off: float | None,
                            out_mp4: Path,
                            panel_height_fraction: float = PANEL_HEIGHT_FRACTION,
                            ylim: Tuple[float,float] = YLIM) -> bool:
    """Stack a line panel under the video and draw a moving red time cursor.

    The panel is rendered once with ``_render_line_panel_png``; for every
    output frame a copy of it receives a 2-px wide red vertical cursor whose
    position maps the current video time onto ``xlim``.

    Args:
        video_path: Source video opened with ``VideoFileClip``.
        series_list: Series dicts forwarded to the panel renderer.
        xlim: Time range of the panel's x axis, also used for the cursor.
        odor_on, odor_off: Odor window forwarded to the panel renderer.
        out_mp4: Destination file; its parent directory is created if needed.
        panel_height_fraction: Panel height as a fraction of the video height.
        ylim: Y range forwarded to the panel renderer.

    Returns:
        True when ``out_mp4`` exists and is non-empty after writing.
        Exceptions raised while rendering or encoding propagate to the caller,
        but only after all clips have been closed.
    """
    clip = VideoFileClip(str(video_path))
    panel_clip = None
    comp = None
    try:
        vw, vh = clip.size
        ph = max(1, int(vh * panel_height_fraction))

        bg = _render_line_panel_png(series_list, vw, ph, xlim, ylim, odor_on, odor_off)

        def _add_cursor(img: np.ndarray, t_cur: float, xlim: Tuple[float,float]) -> np.ndarray:
            img = img.copy()  # never draw on the shared background
            t0, t1 = float(xlim[0]), float(xlim[1])
            if not (np.isfinite(t0) and np.isfinite(t1)) or t1 <= t0:
                return img
            # Clip just below 1.0 so the cursor column stays inside the image.
            frac = float(np.clip((t_cur - t0) / (t1 - t0), 0.0, 0.9999))
            x = int(frac * (img.shape[1] - 1))
            img[:, x:x+2, 0] = 255
            img[:, x:x+2, 1:] = 0
            return img

        def panel_frame(t_cur: float) -> np.ndarray:
            return _add_cursor(bg, t_cur, xlim)

        panel_clip = VideoClip(panel_frame, duration=clip.duration)
        comp = CompositeVideoClip([clip.set_position(("center", 0)),
                                   panel_clip.set_position(("center", vh))],
                                  size=(vw, vh + ph))

        out_mp4.parent.mkdir(parents=True, exist_ok=True)
        comp.write_videofile(str(out_mp4), fps=clip.fps, codec="libx264", audio=False, preset="ultrafast")
    finally:
        # Best-effort cleanup that also runs when encoding fails, so reader
        # handles are not left open across a batch run.
        for c in (clip, panel_clip, comp):
            if c is None:
                continue
            try:
                c.close()
            except Exception:
                pass

    try:
        return out_mp4.exists() and out_mp4.stat().st_size > 0
    except Exception:
        return False

# ─────────────────────────────────────────────────────────────
# Trial discovery (union across sources, per category), then compose videos

def _discover_trials(fly_dir: Path, category: str) -> List[int]:
    """Return the sorted union of trial indices found for ``category``.

    Sources scanned, in order:
      1. ``Eye_Prob_Dist/<category>/*.csv``; if that folder is missing,
         ``Eye_Prob_Dist/*.csv`` files whose name contains ``category``.
      2. ``Eye_Antenna_Dist/*combined.csv`` files whose name contains ``category``.
      3. ``*_distance_class_2_angle_ARB.csv`` anywhere under ``fly_dir``.
      4. ``*_distance_class_2.csv`` anywhere under ``fly_dir``.

    For sources 3 and 4 a file is accepted when its inferred category equals
    ``category`` or cannot be inferred at all.
    """
    trials = set()

    def _collect(paths, name_must_contain_category: bool = False) -> None:
        # Add the trial index of every accepted path to ``trials``.
        for p in paths:
            if name_must_contain_category and category not in p.name.lower():
                continue
            ti = extract_trial_index(p.stem, category)
            if ti is not None:
                trials.add(ti)

    # Eye_Prob_Dist
    base = fly_dir / "Eye_Prob_Dist" / category
    if base.is_dir():
        _collect(base.glob("*.csv"))
    else:
        root = fly_dir / "Eye_Prob_Dist"
        if root.is_dir():
            _collect(root.glob("*.csv"), name_must_contain_category=True)

    # Eye_Antenna_Dist
    ead = fly_dir / "Eye_Antenna_Dist"
    if ead.is_dir():
        _collect(ead.glob("*combined.csv"), name_must_contain_category=True)

    # Angle files anywhere, then raw files as last resort.
    for pattern in ("*_distance_class_2_angle_ARB.csv", "*_distance_class_2.csv"):
        for p in fly_dir.rglob(pattern):
            # Infer the category from the location inside the fly folder only;
            # an ancestor such as ".../Testing/HEX" must not label every file.
            if _infer_category_from_path(p.relative_to(fly_dir)) in (category, None):
                _collect([p])

    return sorted(trials)

def _is_month_fly(p: Path) -> bool:
    """Tell whether ``p`` is a directory whose lower-cased name starts with an entry of MONTHS."""
    if not p.is_dir():
        return False
    # str.startswith accepts a tuple of prefixes and returns True on any match.
    return p.name.lower().startswith(tuple(MONTHS))

# ─────────────────────────────────────────────────────────────
# Main

# Batch driver: for every fly folder and category, overlay up to three data
# series on each trial video and optionally delete the consumed source video.
for fly_dir in sorted([p for p in ROOT.iterdir() if _is_month_fly(p)]):
    fly_name = fly_dir.name
    print(f"\n=== Fly: {fly_name} ===")

    for category in ("training", "testing"):
        trials = _discover_trials(fly_dir, category)
        if not trials:
            print(f"  [{category}] No trials discovered.")
            continue

        # Rendered videos go into a sub-folder of the input video folder.
        out_root = fly_dir / VIDEO_INPUT_DIR / VIDEO_OUTPUT_SUBDIR
        out_root.mkdir(parents=True, exist_ok=True)

        for tri in trials:
            # Find a video for this trial
            video_path = _find_video_for_trial(fly_dir, category, tri)
            if not video_path:
                print(f"  [{category} {tri}] ⤫ No matching video in {VIDEO_INPUT_DIR}/{category}/")
                continue

            # Skip before loading any CSV when the output is already there.
            out_mp4 = out_root / f"{fly_name}_{category}_{tri}_LINES_three_series.mp4"
            if out_mp4.exists():
                print(f"  [{category} {tri}] ⤫ Exists, skipping: {out_mp4.name}")
                continue

            # Build up to three series; each collector returns
            # (t, y, odor_on, odor_off) or None.
            s_top   = _series_top_distance(fly_dir, category, tri)
            s_angle = _series_angle_centered_pct(fly_dir, category, tri)
            s_pct   = _series_robust_pct(fly_dir, category, tri)

            series_list = []
            xmins, xmaxs = [], []
            odor_candidates = []
            for label, series in (("Proboscis Distance %", s_top),
                                  ("Centered Proboscis Angle %", s_angle),
                                  ("Antenna Distance %", s_pct)):
                if series is None:
                    continue
                t, y, on, off = series
                series_list.append({"t": t, "y": y, "label": label})
                xmins.append(np.nanmin(t))
                xmaxs.append(np.nanmax(t))
                odor_candidates.append((on, off))

            if not series_list:
                print(f"  [{category} {tri}] ⤫ No data series found; skipping.")
                continue

            xlim = (float(np.nanmin(xmins)), float(np.nanmax(xmaxs)))
            # Prefer robust's odor window if present (it is appended last);
            # else the first available one.
            odor_on, odor_off = odor_candidates[-1] if s_pct is not None else odor_candidates[0]

            print(f"  [{category} {tri}] ✓ Video: {video_path.name} → {out_mp4.name}")
            ok = _compose_lineplot_video(video_path, series_list, xlim, odor_on, odor_off, out_mp4,
                                         panel_height_fraction=PANEL_HEIGHT_FRACTION,
                                         ylim=YLIM)

            if ok:
                print(f"  [{category} {tri}] [SAVED] {out_mp4.name}")
                if DELETE_SOURCE_AFTER_RENDER:
                    _safe_unlink(video_path)
            else:
                print(f"  [{category} {tri}] ⤫ Render failed; source retained.")

        # Per-category cleanup. This must live inside the category loop:
        # placed after it, only the last category ("testing") was ever checked.
        if DELETE_SOURCE_AFTER_RENDER and DELETE_EMPTY_INPUT_DIRS:
            cat_input_dir = fly_dir / VIDEO_INPUT_DIR / category
            if cat_input_dir.exists():
                # Only consider known video types to decide emptiness;
                # _maybe_rmdir_empty still refuses to remove a non-empty folder.
                remaining = [p for p in cat_input_dir.iterdir()
                             if p.is_file() and p.suffix.lower() in VIDEO_EXTS]
                if not remaining:
                    _maybe_rmdir_empty(cat_input_dir)

print("\nDone.")

## Just RMS

In [None]:
# JUPYTER CELL — Line-plot panel under the video (rolling RMS of a distance-% column plus a pre-odor threshold line)
# Helper functions mirror the FIXED MEGA heatmap code.
# Videos are read from: {fly}/videos_with_rms/{training,testing}/
# Outputs go to:        {fly}/videos_with_rms/videos_with_rms/{fly}_{category}_{trial}_LINES_rms.mp4

import os, re, glob, io
from pathlib import Path
from typing import Optional, Dict, List, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from PIL import Image
from moviepy.editor import VideoFileClip, VideoClip, CompositeVideoClip

# ─────────────────────────────────────────────────────────────
# REQUIRED: set main_directory (Path or str)
# Validate with explicit raises: assert statements are removed when Python
# runs with -O, which would let a bad root slip through silently.
if 'main_directory' not in globals():
    raise NameError("Define main_directory = '/path/to/root' before running.")
ROOT = Path(main_directory).expanduser().resolve()
if not ROOT.is_dir():
    raise NotADirectoryError(f"Not a directory: {ROOT}")

# ─────────────────────────────────────────────────────────────
# Visual defaults
# Global matplotlib defaults for this notebook session (they persist for later cells).
# NOTE(review): 'savefig.dpi' presumably also governs the savefig call in
# _render_line_panel_png, which passes no explicit dpi — confirm if render speed matters.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'lines.linewidth': 2,
    'axes.linewidth': 1.25,
    'figure.dpi': 160,
    'savefig.dpi': 300,
    'legend.fontsize': 11
})

# ─────────────────────────────────────────────────────────────
# Constants (aligned with MEGA)
FPS_DEFAULT   = 40.0  # frame rate returned by derive_fps when a CSV has no usable fps column
ANCHOR_X, ANCHOR_Y = 1080.0, 540.0  # fixed anchor point used by compute_angle_deg_at_point2
PCT_COL_ROBUST = "distance_class1_class2_pct"  # percentage column read by _series_robust_pct
DIST_COL_ROBUST = "distance_class1_class2"  # raw distance column used to derive the percentage when it is missing
PRE_SEC, POST_SEC = 30.0, 90.0  # seconds assumed before odor onset / after odor offset when building time axes
TRIM_FRAC = 0.05  # percentile fraction (bottom 5%) used by _ead_compute_trim_min_max for the trimmed minimum
# Folder-name prefixes accepted by _is_month_fly ("may" appears twice; the duplicate is harmless).
MONTHS = (
    "january","february","march","april","may","june",
    "july","august","september","october","november","december",
    "jan","feb","mar","apr","may","jun","jul","aug","sep","oct","nov","dec"
)
RMS_WINDOW_S = 1.0   # seconds for rolling RMS window
THRESH_K     = 4.0   # threshold = mean_pre + K * std_pre

# ── Panel rendering options ─────────────────────────────────────
PANEL_HEIGHT_FRACTION = 0.24  # panel height as a fraction of the video height
YLIM = (-100, 100)  # y-range handed to the panel renderer (chosen for angle [-100,100] and distance [0,100] lines)
VIDEO_INPUT_DIR = "videos_with_rms"              # folder under each fly dir holding <category>/ sub-folders of trial videos
VIDEO_OUTPUT_SUBDIR = "videos_with_rms"            # sub-folder of VIDEO_INPUT_DIR that receives the rendered videos

# ─────────────────────────────────────────────────────────────
# Source-cleanup switches and video file types
# Delete the source videos after rendering new ones
DELETE_SOURCE_AFTER_RENDER = True
# Optionally remove the input category folder if it becomes empty
DELETE_EMPTY_INPUT_DIRS = True

VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".mpg", ".mpeg", ".m4v"}  # lower-case suffixes treated as videos

def _safe_unlink(p: Path):
    """Delete the file ``p`` if it exists and report the deletion.

    A missing path is silently ignored. Any ``Exception`` raised while checking
    or deleting is printed as a warning instead of being propagated.
    """
    try:
        if not p.exists():
            return
        p.unlink()
        print(f"    Deleted source video: {p.name}")
    except Exception as err:
        print(f"    [warn] Could not delete {p}: {err}")

def _maybe_rmdir_empty(dir_path: Path):
    """Remove ``dir_path`` when it exists and holds no entries at all.

    Any ``Exception`` raised along the way is printed as a warning rather
    than propagated.
    """
    try:
        if not dir_path.exists():
            return
        # Peek at the first entry only; a single file or sub-folder keeps the folder.
        has_entries = next(dir_path.iterdir(), None) is not None
        if has_entries:
            return
        dir_path.rmdir()
        print(f"    Removed empty folder: {dir_path}")
    except Exception as err:
        print(f"    [warn] Could not remove folder {dir_path}: {err}")

def timestamp_to_seconds(ts) -> float:
    """Convert a timestamp value to seconds, or NaN when it cannot be parsed.

    Missing values give NaN. Anything ``float()`` accepts is returned as is.
    Otherwise the stripped string form is split on ":" and read as
    ``hh:mm:ss:ms`` (4 fields, ms in thousandths), ``hh:mm:ss`` (3 fields),
    ``mm:ss`` (2 fields) or a plain number (1 field).
    """
    if pd.isna(ts):
        return np.nan
    try:
        return float(ts)
    except Exception:
        pass

    fields = str(ts).strip().split(":")
    count = len(fields)
    try:
        if count == 4:
            hours, minutes, seconds, millis = fields
            return int(hours)*3600 + int(minutes)*60 + int(seconds) + int(millis)/1000.0
        if count == 3:
            hours, minutes, seconds = fields
            return int(hours)*3600 + int(minutes)*60 + float(seconds)
        if count == 2:
            minutes, seconds = fields
            return int(minutes)*60 + float(seconds)
        if count == 1:
            return float(fields[0])
    except Exception:
        return np.nan
    # More than four fields: not a recognised layout.
    return np.nan

def find_col(df: pd.DataFrame, cands: List[str]) -> Optional[str]:
    """Return the first name in ``cands`` that is a column of ``df``, else None."""
    available = set(df.columns)
    return next((name for name in cands if name in available), None)

def derive_fps(df: pd.DataFrame) -> float:
    """Return the median of the first available fps column, or FPS_DEFAULT.

    The default is used when no fps column exists or when its median
    (after numeric coercion) is not a finite positive number.
    """
    column = find_col(df, ["fps","FPS","frame_rate","frameRate"])
    if column is None:
        return FPS_DEFAULT
    median_fps = pd.to_numeric(df[column], errors="coerce").median()
    if np.isfinite(median_fps) and median_fps > 0:
        return float(median_fps)
    return FPS_DEFAULT

def ensure_time_series(df: pd.DataFrame, frame_col: Optional[str], ts_col: Optional[str]):
    """Build a per-row time series in seconds plus a dict describing its origin.

    Sources are tried in order:
      1. ``ts_col``: numeric for the known seconds columns, otherwise parsed
         with ``timestamp_to_seconds``. Accepted when it has at least two
         values and its largest step is positive.
      2. ``frame_col``: frames counted from the smallest frame, divided by fps.
      3. Row position divided by fps.
    """
    if ts_col is not None:
        already_seconds = ts_col in ["time_seconds","relative_time","time_s"]
        if already_seconds:
            ts = pd.to_numeric(df[ts_col], errors="coerce")
        else:
            ts = df[ts_col].apply(timestamp_to_seconds)
        enough_values = ts.notna().sum() >= 2
        if enough_values and np.nanmax(np.diff(ts.dropna().values)) > 0:
            return ts, {'used':'timestamp'}

    if frame_col is not None:
        frames = pd.to_numeric(df[frame_col], errors="coerce")
        if frames.notna().sum() >= 2:
            fps = derive_fps(df)
            first_frame = int(np.nanmin(frames.values))
            return (frames - first_frame)/fps, {'used':'frame_fallback','fps':fps}

    fps = derive_fps(df)
    by_position = pd.Series(np.arange(len(df), dtype=float)/fps, index=df.index)
    return by_position, {'used':'index_fallback','fps':fps}

def extract_trial_index(name: str, category: str) -> Optional[int]:
    """Extract a trial number from a file stem.

    Patterns are tried in order against the lower-cased ``name``:
      1. ``<category>_<digits>`` (category matched literally, case-insensitive),
      2. ``training_<digits>`` or ``testing_<digits>``,
      3. the last run of digits anywhere in ``name``.

    Returns the number as ``int``, or ``None`` when no digits are found or
    the digits cannot be converted.
    """
    lowered = name.lower()
    digits = None
    patterns = (
        # Escape and lower-case the category so regex metacharacters or a
        # capitalised category can neither break nor defeat the match.
        rf"{re.escape(category.lower())}_(\d+)",
        r"(?:training|testing)_(\d+)",
    )
    for pattern in patterns:
        match = re.search(pattern, lowered)
        if match:
            digits = match.group(1)
            break
    else:
        # fallback: last number in stem
        numbers = re.findall(r"\d+", name)
        if numbers:
            digits = numbers[-1]

    if digits is None:
        return None
    try:
        return int(digits)
    except ValueError:
        return None

def odor_window_from_ofm(df: pd.DataFrame, time_col: str = "relative_time"):
    """Return ``(first, last)`` times of rows whose OFM state is "during".

    The state column is compared case-insensitively as text. ``None`` is
    returned when the state or time column is missing, when fewer than two
    rows are "during", or when any exception occurs while reading them.
    """
    ofm_col = find_col(df, ["OFM State","OFM_State","ofm_state","ofm"])
    if ofm_col is None or time_col not in df.columns:
        return None
    try:
        is_during = df[ofm_col].astype(str).str.lower().eq("during").to_numpy()
        rows = np.flatnonzero(is_during)
        if rows.size < 2:
            return None
        times = df[time_col]
        return float(times.iloc[rows[0]]), float(times.iloc[rows[-1]])
    except Exception:
        return None

def compute_angle_deg_at_point2(df: pd.DataFrame) -> pd.Series:
    """Unsigned angle in degrees at the class-2 point, per row.

    The angle is measured between the vector from the class-2 point to the
    fixed anchor ``(ANCHOR_X, ANCHOR_Y)`` and the vector from the class-2
    point to the class-6 point. Rows where either vector has zero length or
    the inputs are not finite yield NaN.

    Raises:
        ValueError: if any of the four coordinate columns is missing.
    """
    needed = ("x_class2","y_class2","x_class6","y_class6")
    if not all(col in df.columns for col in needed):
        raise ValueError("Missing angle columns")

    x2, y2, x6, y6 = (df[col].astype(float).to_numpy() for col in needed)

    to_anchor_x = ANCHOR_X - x2
    to_anchor_y = ANCHOR_Y - y2
    to_tip_x = x6 - x2
    to_tip_y = y6 - y2

    dot = to_anchor_x*to_tip_x + to_anchor_y*to_tip_y
    cross = to_anchor_x*to_tip_y - to_anchor_y*to_tip_x
    len_anchor = np.hypot(to_anchor_x, to_anchor_y)
    len_tip = np.hypot(to_tip_x, to_tip_y)

    usable = (len_anchor > 0) & (len_tip > 0) & np.isfinite(dot) & np.isfinite(cross)
    angles = np.full(len(x2), np.nan)
    # atan2(|cross|, dot) gives the unsigned angle in [0, 180] degrees.
    angles[usable] = np.degrees(np.arctan2(np.abs(cross[usable]), dot[usable]))
    return pd.Series(angles, index=df.index, name="angle_ARB_deg")

def find_fly_reference_angle(csvs_raw: List[Path]) -> float:
    """Return the angle at the frame whose distance percentage is closest to zero.

    Each CSV contributes one candidate ``(rank, |distance|, angle)``: rank 0
    for the first row with distance exactly 0, otherwise rank 1 for the row
    with the smallest absolute distance. The smallest candidate by tuple
    ordering wins. Files that lack the needed columns, have no finite
    distance, or raise while being read are skipped. Returns NaN when no
    candidate is found.
    """
    def _finite_angle(angles: pd.Series, row: int) -> float:
        # Angle at ``row`` as float, or NaN when it is not finite.
        value = angles.iloc[row]
        return float(value) if np.isfinite(value) else np.nan

    def _candidate(path: Path):
        # Build the candidate tuple for one CSV, or None when unusable.
        df = pd.read_csv(path)
        if not {"x_class2","y_class2","x_class6","y_class6"}.issubset(df.columns):
            return None
        angles = compute_angle_deg_at_point2(df)
        dist_col = find_col(df, ["distance_percentage","distance_percent","distance_pct","distance_class1_class2_pct"])
        if dist_col is None:
            return None
        dist = pd.to_numeric(df[dist_col], errors="coerce").to_numpy()
        zero_rows = np.flatnonzero(dist == 0)
        if zero_rows.size > 0:
            return (0, 0.0, _finite_angle(angles, int(zero_rows[0])))
        with np.errstate(invalid="ignore"):
            abs_dist = np.abs(dist)
        if not np.isfinite(abs_dist).any():
            return None
        row = int(np.nanargmin(abs_dist))
        return (1, float(abs_dist[row]), _finite_angle(angles, row))

    best = None
    for path in csvs_raw:
        try:
            cand = _candidate(path)
        except Exception:
            continue
        if cand is None:
            continue
        if best is None or cand < best:
            best = cand
    return best[2] if best is not None else np.nan

def compute_fly_max_abs_centered(csvs_raw: List[Path], ref_angle: float) -> float:
    """Largest absolute angle deviation from ``ref_angle`` across all CSVs.

    Files without the four coordinate columns, or that raise while being
    processed, are skipped. When ``ref_angle`` is not finite the deviations
    are taken as zero. Returns NaN unless the overall maximum is finite and
    strictly positive.
    """
    required = {"x_class2","y_class2","x_class6","y_class6"}
    overall = 0.0
    for path in csvs_raw:
        try:
            df = pd.read_csv(path)
            if not required.issubset(df.columns):
                continue
            angles = compute_angle_deg_at_point2(df)
            if np.isfinite(ref_angle):
                deviation = angles - ref_angle
            else:
                deviation = angles*0.0
            file_max = np.nanmax(np.abs(deviation.to_numpy(dtype=float)))
            if np.isfinite(file_max):
                overall = max(overall, float(file_max))
        except Exception:
            pass
    if np.isfinite(overall) and overall > 0:
        return overall
    return np.nan

def _infer_category_from_path(p: Path) -> Optional[str]:
    """Guess "training" or "testing" from substrings of the path, else None.

    All path components plus the stem are joined and lower-cased; training
    markers are checked before testing markers.
    """
    # NOTE(review): every component of ``p`` is searched, so for an absolute
    # path an ancestor folder name (e.g. ".../Testing/...") also counts.
    haystack = " ".join([*p.parts, p.stem]).lower()
    markers = (
        ("training", ("training","train","trn")),
        ("testing", ("testing","test","tst")),
    )
    for label, needles in markers:
        if any(needle in haystack for needle in needles):
            return label
    return None

# ─────────────────────────────────────────────────────────────
# Series collectors (mirror MEGA). Return (t, y, odor_on, odor_off).

def _series_rms_from_rmscalc(fly_dir: Path, category: str, trial_index: int):
    """
    Build the rolling-RMS trace of one trial from <fly>/RMS_calculations/.

    The first ``*merged.csv`` (sorted by path) whose name contains
    ``category`` and whose trial index equals ``trial_index`` is loaded.
    A time base is derived with ``ensure_time_series``, a centered rolling
    RMS (window ``RMS_WINDOW_S`` seconds) is computed over the first
    available distance-% column, and a threshold
    ``mean + THRESH_K * std`` is taken from RMS samples whose relative time
    is below ``PRE_SEC``.

    Returns:
        ``(t, rms, odor_on, odor_off, threshold)`` or ``None`` when the
        folder, a matching CSV, a usable time base or a distance-% column is
        missing. ``t`` is evenly spaced over
        ``[0, PRE_SEC + odor duration + POST_SEC]``; ``threshold`` is NaN when
        the pre-odor region has no finite RMS value.
    """
    rdir = fly_dir / "RMS_calculations"
    if not rdir.is_dir():
        return None

    cands = sorted(p for p in rdir.glob("*merged.csv")
                   if category in p.name.lower()
                   and extract_trial_index(p.stem, category) == trial_index)
    if not cands:
        return None

    df = pd.read_csv(cands[0])
    df.columns = df.columns.str.strip()

    # Time
    frame_col = find_col(df, ["frame","Frame","frame_num","frame_index"])
    ts_col    = find_col(df, ["timestamp","Timestamp","time","Time","time_seconds","relative_time","time_s"])
    ts, _ = ensure_time_series(df, frame_col, ts_col)
    if ts.notna().sum() < 2:
        return None
    time_s = pd.to_numeric(ts, errors="coerce")
    df["relative_time"] = time_s - float(np.nanmin(time_s))

    # Odor window from OFM_State if available; else default [PRE_SEC, PRE_SEC+30]
    # NOTE(review): the odor times returned here are the measured ones, whereas
    # _series_robust_pct returns PRE_SEC-based ones that match the synthetic
    # time axis below — confirm which convention the panel should use.
    odor = odor_window_from_ofm(df, "relative_time") or (PRE_SEC, PRE_SEC + 30.0)
    odor_on, odor_off = map(float, odor)
    total_duration = PRE_SEC + (odor_off - odor_on) + POST_SEC

    # Pick a distance-% column (common names across the pipeline)
    pct_col = find_col(df, [
        "distance_class1_class2_pct",
        "distance_percentage","distance_percent","distance_pct",
        "distance_proboscis_eye_pct","proboscis_eye_distance_pct","eye_prob_distance_pct",
        "distance_percentage_2_6"  # if you prefer this, move it to the top of the list
    ])
    if pct_col is None:
        return None

    vals = pd.to_numeric(df[pct_col], errors="coerce").to_numpy(dtype=float)

    # Rolling RMS (centered, window sized by FPS). raw=True hands each window
    # over as a plain ndarray: same numbers, without building a Series per window.
    fps = derive_fps(df)
    win = max(1, int(round(fps * RMS_WINDOW_S)))
    rms = pd.Series(vals).rolling(win, min_periods=max(1, win // 2), center=True).apply(
        lambda w: float(np.sqrt(np.nanmean(np.square(w)))), raw=True
    ).to_numpy()

    # Time axis sized to the series
    t = np.linspace(0.0, total_duration, len(rms))

    # Threshold from the pre-odor region. The mask and rms both have one entry
    # per CSV row, so they can be combined directly.
    pre_mask = df["relative_time"].to_numpy(dtype=float) < PRE_SEC
    pre_vals = rms[pre_mask]
    if np.isfinite(pre_vals).any():
        mu = float(np.nanmean(pre_vals))
        sd = float(np.nanstd(pre_vals))
    else:
        mu = sd = np.nan
    threshold = mu + THRESH_K * sd if np.isfinite(mu) and np.isfinite(sd) else np.nan

    return t, rms.astype(float), odor_on, odor_off, threshold

# Per-fly cache of (trimmed minimum, global maximum) of the raw distance column
_ead_stats_cache: Dict[Path, Tuple[float,float]] = {}

def _ead_compute_trim_min_max(fly_dir: Path) -> Optional[Tuple[float,float]]:
    """Return ``(trimmed_min, global_max)`` of DIST_COL_ROBUST for one fly.

    Values are pooled from every ``*merged.csv`` in ``<fly>/RMS_calculations``;
    files that cannot be read (or lack the column) are skipped. The trimmed
    minimum is the smallest finite value at or above the ``TRIM_FRAC``
    percentile. Successful results are cached per ``fly_dir``; ``None`` is
    returned (and not cached) when the folder is missing or no finite value
    exists.
    """
    try:
        return _ead_stats_cache[fly_dir]
    except KeyError:
        pass

    base = fly_dir / "RMS_calculations"
    if not base.is_dir():
        return None

    chunks = []
    for csv_path in sorted(base.glob("*merged.csv")):
        try:
            column = pd.read_csv(csv_path, usecols=[DIST_COL_ROBUST])[DIST_COL_ROBUST]
            chunks.append(pd.to_numeric(column, errors="coerce").to_numpy())
        except Exception:
            continue
    if not chunks:
        return None

    pooled = np.concatenate(chunks)
    pooled = pooled[np.isfinite(pooled)]
    if pooled.size == 0:
        return None

    global_max = float(np.max(pooled))
    cutoff = float(np.percentile(pooled, 100*TRIM_FRAC, method="linear"))
    above_cutoff = pooled[pooled >= cutoff]
    if above_cutoff.size:
        trimmed_min = float(np.min(above_cutoff))
    else:
        trimmed_min = float(np.min(pooled))

    result = (trimmed_min, global_max)
    _ead_stats_cache[fly_dir] = result
    return result

def _series_robust_pct(fly_dir: Path, category: str, trial_index: int):
    """Return the distance-% trace of one trial from <fly>/RMS_calculations/.

    The first ``*merged.csv`` (sorted by path) whose name contains
    ``category`` and whose trial index equals ``trial_index`` is loaded. If
    the percentage column PCT_COL_ROBUST is absent it is derived in memory
    from DIST_COL_ROBUST using the per-fly trimmed-min / max statistics.
    When a frame column with valid numbers exists, values are placed on a
    gap-free frame grid (missing frames become NaN).

    Returns:
        ``(t, y, odor_on, odor_off)`` or ``None`` when no matching CSV, time
        base or statistics are available. ``t`` is evenly spaced over
        ``[0, PRE_SEC + odor duration + POST_SEC]`` and the odor window is
        reported on that axis as ``PRE_SEC`` and ``PRE_SEC + odor duration``.
    """
    base = fly_dir / "RMS_calculations"
    if not base.is_dir():
        return None
    chosen = next(
        (p for p in sorted(base.glob("*merged.csv"))
         if category in p.name.lower() and extract_trial_index(p.stem, category) == trial_index),
        None,
    )
    if chosen is None:
        return None

    df = pd.read_csv(chosen)
    df.columns = df.columns.str.strip()
    frame_col = find_col(df, ["frame","Frame","frame_num","frame_index"])
    ts_col    = find_col(df, ["timestamp","Timestamp","time","Time","time_seconds","relative_time","time_s"])
    ts, _     = ensure_time_series(df, frame_col, ts_col)
    if ts.notna().sum() < 2:
        return None

    # Ensure percentage in-memory if missing
    if PCT_COL_ROBUST not in df.columns:
        stats = _ead_compute_trim_min_max(fly_dir)
        if stats is None:
            return None
        fly_min, fly_max = stats
        if DIST_COL_ROBUST in df.columns:
            dist = pd.to_numeric(df[DIST_COL_ROBUST], errors="coerce")
        else:
            # Keep a Series so both branches below work; a bare NaN scalar
            # has no .notna() and used to crash the degenerate-range branch.
            dist = pd.Series(np.nan, index=df.index, dtype=float)
        if np.isfinite(fly_min) and np.isfinite(fly_max) and (fly_max > fly_min):
            df[PCT_COL_ROBUST] = (dist - fly_min) / (fly_max - fly_min) * 100.0
        else:
            df[PCT_COL_ROBUST] = np.where(dist.notna(), 0.0, np.nan)

    df["time_seconds"] = pd.to_numeric(ts, errors="coerce")
    df["relative_time"] = df["time_seconds"] - df["time_seconds"].min()
    odor = odor_window_from_ofm(df, "relative_time") or (PRE_SEC, PRE_SEC + 30.0)
    odor_start, odor_end = odor
    total_duration = PRE_SEC + (odor_end - odor_start) + POST_SEC

    frames = None
    if frame_col is not None and frame_col in df.columns:
        frames = pd.to_numeric(df[frame_col], errors="coerce").dropna().astype(int)

    if frames is not None and not frames.empty:
        # Spread the rows over a contiguous frame grid; frames without a row stay NaN.
        first, last = int(frames.min()), int(frames.max())
        y = np.full(last - first + 1, np.nan, dtype=float)
        vals = pd.to_numeric(df.loc[frames.index, PCT_COL_ROBUST], errors="coerce").to_numpy(dtype=float)
        y[frames.to_numpy() - first] = vals
    else:
        # No usable frame numbers: fall back to row order.
        y = pd.to_numeric(df[PCT_COL_ROBUST], errors="coerce").to_numpy(dtype=float)

    t = np.linspace(0, total_duration, len(y))
    return t, y.astype(float), float(PRE_SEC), float(PRE_SEC + (odor_end - odor_start))

# ─────────────────────────────────────────────────────────────
# Video discovery & matching in three_line_videos/<category>

VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".mpg", ".mpeg", ".m4v"}

def _find_video_for_trial(fly_dir: Path, category: str, tri: int) -> Optional[Path]:
    """Locate the source video of trial ``tri`` in ``<fly>/<VIDEO_INPUT_DIR>/<category>``.

    Candidates are the files with a suffix in VIDEO_EXTS, visited in sorted
    order. Heuristics, in order:
      1. the stem contains ``<category>_<tri>`` (leading zeros allowed) not
         followed by another digit, so trial 1 never matches ``testing_10``;
      2. the stem contains ``_<tri>`` or ``-<tri>`` not followed by a digit;
      3. the folder holds exactly one video.
    Heuristics 2 and 3 ignore videos whose stem carries an explicit
    ``<category>_<digits>`` tag: such a tag names a different trial, and a
    wrong pick would be rendered with the wrong data (and possibly deleted
    afterwards by the caller).

    Returns the matching path, or ``None``.
    """
    vid_dir = fly_dir / VIDEO_INPUT_DIR / category
    if not vid_dir.is_dir():
        return None
    vids = sorted(p for p in vid_dir.iterdir() if p.is_file() and p.suffix.lower() in VIDEO_EXTS)

    cat = re.escape(category.lower())
    exact = re.compile(rf"{cat}_0*{tri}(?!\d)")
    for v in vids:
        if exact.search(v.stem.lower()):
            return v

    tagged = re.compile(rf"{cat}_\d+")
    loose = re.compile(rf"[_\-]{tri}(\D|$)")
    for v in vids:
        if tagged.search(v.stem.lower()):
            continue
        if loose.search(v.stem):
            return v

    if len(vids) == 1 and not tagged.search(vids[0].stem.lower()):
        return vids[0]
    return None

# ─────────────────────────────────────────────────────────────
# Line-panel rendering & composition

def _render_line_panel_png(series_list: List[dict],
                           width_px: int,
                           height_px: int,
                           xlim: Tuple[float,float],
                           ylim: Tuple[float,float],
                           odor_on: float | None,
                           odor_off: float | None,
                           threshold: float | None) -> np.ndarray:
    """Render the static RMS panel and return it as an RGB image array.

    Only the first entry of ``series_list`` is drawn (blue, always labelled
    "RMS"). A finite ``threshold`` adds a red horizontal line, and a finite
    odor window adds two red vertical lines. A legend is drawn when a series
    or a threshold is present. The axes decorations are hidden.

    Returns:
        Array of shape ``(height_px, width_px, 3)`` obtained by resizing the
        rendered PNG with bilinear resampling.
    """
    has_threshold = threshold is not None and np.isfinite(threshold)

    fig, ax = plt.subplots(figsize=(width_px/100, height_px/100), dpi=100)
    buf = io.BytesIO()
    try:
        # Expect one series: RMS
        if series_list:
            s = series_list[0]
            ax.plot(s["t"], s["y"], label="RMS", linewidth=1.2, color="blue")

        # Horizontal threshold in red
        if has_threshold:
            ax.axhline(threshold, color="red", linewidth=1.2, label="Threshold")

        # Odor-on/off markers (vertical red)
        if odor_on is not None and odor_off is not None and np.isfinite(odor_on) and np.isfinite(odor_off):
            ax.axvline(odor_on, color='red', linewidth=1.0)
            ax.axvline(odor_off, color='red', linewidth=1.0)

        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)

        if series_list or has_threshold:
            leg = ax.legend(loc='upper right', frameon=True, framealpha=0.9)
            frame = leg.get_frame()
            frame.set_facecolor('white')
            frame.set_edgecolor('black')
            frame.set_linewidth(0.8)

        ax.axis("off")
        # Use the figure's own methods so another "current" pyplot figure
        # can never be laid out or saved by mistake.
        fig.tight_layout(pad=0)
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure: this runs once per video in a batch loop,
        # and a failed plot would otherwise leave the figure open in pyplot.
        plt.close(fig)

    buf.seek(0)
    arr = np.array(Image.open(buf).convert("RGB"))
    # The tight bounding box changes the pixel size, so resize back to the
    # exact panel dimensions the compositor expects.
    return np.array(Image.fromarray(arr).resize((width_px, height_px), resample=Image.BILINEAR))

def _compose_lineplot_video(video_path: Path,
                            series_list: List[dict],
                            xlim: Tuple[float,float],
                            odor_on: float | None,
                            odor_off: float | None,
                            out_mp4: Path,
                            panel_height_fraction: float = PANEL_HEIGHT_FRACTION,
                            ylim: Tuple[float,float] = YLIM,
                            threshold: float | None = None) -> bool:
    """Stack the RMS panel under the video and draw a moving red time cursor.

    The panel is rendered once with ``_render_line_panel_png``; for every
    output frame a copy of it receives a 2-px wide red vertical cursor whose
    position maps the current video time onto ``xlim``.

    Args:
        video_path: Source video opened with ``VideoFileClip``.
        series_list: Series dicts forwarded to the panel renderer.
        xlim: Time range of the panel's x axis, also used for the cursor.
        odor_on, odor_off: Odor window forwarded to the panel renderer.
        out_mp4: Destination file; its parent directory is created if needed.
        panel_height_fraction: Panel height as a fraction of the video height.
        ylim: Y range forwarded to the panel renderer.
        threshold: Optional horizontal threshold forwarded to the renderer.

    Returns:
        True when ``out_mp4`` exists and is non-empty after writing.
        Exceptions raised while rendering or encoding propagate to the caller,
        but only after all clips have been closed.
    """
    clip = VideoFileClip(str(video_path))
    panel_clip = None
    comp = None
    try:
        vw, vh = clip.size
        ph = max(1, int(vh * panel_height_fraction))

        bg = _render_line_panel_png(series_list, vw, ph, xlim, ylim, odor_on, odor_off, threshold)

        def _add_cursor(img: np.ndarray, t_cur: float, xlim: Tuple[float,float]) -> np.ndarray:
            img = img.copy()  # never draw on the shared background
            t0, t1 = float(xlim[0]), float(xlim[1])
            if not (np.isfinite(t0) and np.isfinite(t1)) or t1 <= t0:
                return img
            # Clip just below 1.0 so the cursor column stays inside the image.
            frac = float(np.clip((t_cur - t0) / (t1 - t0), 0.0, 0.9999))
            x = int(frac * (img.shape[1] - 1))
            img[:, x:x+2, 0] = 255
            img[:, x:x+2, 1:] = 0
            return img

        def panel_frame(t_cur: float) -> np.ndarray:
            return _add_cursor(bg, t_cur, xlim)

        panel_clip = VideoClip(panel_frame, duration=clip.duration)
        comp = CompositeVideoClip([clip.set_position(("center", 0)),
                                   panel_clip.set_position(("center", vh))],
                                  size=(vw, vh + ph))

        out_mp4.parent.mkdir(parents=True, exist_ok=True)
        comp.write_videofile(str(out_mp4), fps=clip.fps, codec="libx264", audio=False, preset="ultrafast")
    finally:
        # Best-effort cleanup that also runs when encoding fails, so reader
        # handles are not left open across a batch run.
        for c in (clip, panel_clip, comp):
            if c is None:
                continue
            try:
                c.close()
            except Exception:
                pass

    try:
        return out_mp4.exists() and out_mp4.stat().st_size > 0
    except Exception:
        return False

# ─────────────────────────────────────────────────────────────
# Trial discovery (union across sources, per category), then compose videos

def _discover_trials(fly_dir: Path, category: str) -> List[int]:
    """Return the sorted trial indices found in ``<fly>/RMS_calculations``.

    Only ``*merged.csv`` files whose lower-cased name contains ``category``
    are considered; files without an extractable trial index are ignored.
    """
    found = set()

    # RMS_calculations (further sources could be merged into ``found`` here)
    rdir = fly_dir / "RMS_calculations"
    if rdir.is_dir():
        indices = (
            extract_trial_index(csv_path.stem, category)
            for csv_path in rdir.glob("*merged.csv")
            if category in csv_path.name.lower()
        )
        found.update(index for index in indices if index is not None)

    return sorted(found)

def _is_month_fly(p: Path) -> bool:
    """Tell whether ``p`` is a directory whose lower-cased name starts with an entry of MONTHS."""
    if not p.is_dir():
        return False
    # str.startswith accepts a tuple of prefixes and returns True on any match.
    return p.name.lower().startswith(tuple(MONTHS))

# ─────────────────────────────────────────────────────────────
# Main

# Batch driver: for every fly folder and category, overlay the rolling-RMS
# panel on each trial video and optionally delete the consumed source video.
for fly_dir in sorted([p for p in ROOT.iterdir() if _is_month_fly(p)]):
    fly_name = fly_dir.name
    print(f"\n=== Fly: {fly_name} ===")

    for category in ("training", "testing"):
        trials = _discover_trials(fly_dir, category)
        if not trials:
            print(f"  [{category}] No trials discovered.")
            continue

        # Rendered videos go into a sub-folder of the input video folder.
        out_root = fly_dir / VIDEO_INPUT_DIR / VIDEO_OUTPUT_SUBDIR
        out_root.mkdir(parents=True, exist_ok=True)

        for tri in trials:
            # Find a video for this trial
            video_path = _find_video_for_trial(fly_dir, category, tri)
            if not video_path:
                print(f"  [{category} {tri}] ⤫ No matching video in {VIDEO_INPUT_DIR}/{category}/")
                continue

            # Skip before loading any CSV when the output is already there.
            out_mp4 = out_root / f"{fly_name}_{category}_{tri}_LINES_rms.mp4"
            if out_mp4.exists():
                print(f"  [{category} {tri}] ⤫ Exists, skipping: {out_mp4.name}")
                continue

            # Only RMS series
            s_rms = _series_rms_from_rmscalc(fly_dir, category, tri)
            if s_rms is None:
                print(f"  [{category} {tri}] ⤫ No RMS series (RMS_calculations); skipping.")
                continue

            t, y, odor_on, odor_off, thr = s_rms
            series_list = [{"t": t, "y": y, "label": "RMS"}]
            xlim = (float(np.nanmin(t)), float(np.nanmax(t)))

            print(f"  [{category} {tri}] ✓ Video: {video_path.name} → {out_mp4.name}")
            ok = _compose_lineplot_video(
                video_path, series_list, xlim, odor_on, odor_off, out_mp4,
                panel_height_fraction=PANEL_HEIGHT_FRACTION, ylim=YLIM, threshold=thr
            )

            if ok:
                print(f"  [{category} {tri}] [SAVED] {out_mp4.name}")
                if DELETE_SOURCE_AFTER_RENDER:
                    _safe_unlink(video_path)
            else:
                print(f"  [{category} {tri}] ⤫ Render failed; source retained.")

        # Per-category cleanup. This must live inside the category loop:
        # placed after it, only the last category ("testing") was ever checked.
        if DELETE_SOURCE_AFTER_RENDER and DELETE_EMPTY_INPUT_DIRS:
            cat_input_dir = fly_dir / VIDEO_INPUT_DIR / category
            if cat_input_dir.exists():
                # Only consider known video types to decide emptiness;
                # _maybe_rmdir_empty still refuses to remove a non-empty folder.
                remaining = [p for p in cat_input_dir.iterdir()
                             if p.is_file() and p.suffix.lower() in VIDEO_EXTS]
                if not remaining:
                    _maybe_rmdir_empty(cat_input_dir)

print("\nDone.")

# Single CSV File --> Matrix

In [None]:
# JUPYTER CELL — Wide CSV of per-frame ENVELOPE across MANY main_directories
from pathlib import Path
import re
import numpy as np
import pandas as pd
from scipy.signal import hilbert

# ───────── INPUT: add as many roots as you need ─────────
# Dataset root folders to scan.
ROOTS = [
    Path("/home/ramanlab/Documents/cole/Data/flys/opto_benz/"),
    Path("/home/ramanlab/Documents/cole/Data/flys/opto_EB/"),
    Path("/home/ramanlab/Documents/cole/Data/flys/opto_benz_1/"),
]

# ───────── CONFIG (mirrors your envelope logic) ─────────
# Candidate measurement columns, in priority order (see _resolve_measure_column).
MEASURE_COLS  = ["distance_percentage", "distance_percentage_2_6"]
FPS_DEFAULT   = 40
WINDOW_SEC    = 0.25
# Window length in frames at FPS_DEFAULT, never below one frame.
WINDOW_FRAMES = max(int(WINDOW_SEC * FPS_DEFAULT), 1)

# Output file (combined for all datasets)
OUT_WIDE_CSV = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/all_envelope_rows_wide.csv")

# Matches "<testing|training>_<number>" case-insensitively; group 1 is the category, group 2 the trial number.
TRIAL_REGEX = re.compile(r"(testing|training)_(\d+)", re.IGNORECASE)

from typing import Optional

# Timestamp and frame columns to look for, in priority order
TIMESTAMP_CANDIDATES = ["UTC_ISO", "Timestamp", "Number", "MonoNs"]
FRAME_CANDIDATES     = ["Frame", "FrameNumber", "Frame Number"]

# Fallback when FPS can't be inferred from timestamps (no video here, so a constant is used).
# NOTE(review): an earlier comment said the prior script used 50, but the value here is 40 — confirm.
FALLBACK_FPS = 40

def _pick_timestamp_column(df: pd.DataFrame) -> Optional[str]:
    """Return the first name in TIMESTAMP_CANDIDATES that is a column of *df*, else None."""
    available = df.columns
    return next((name for name in TIMESTAMP_CANDIDATES if name in available), None)

def _pick_frame_column(df: pd.DataFrame) -> Optional[str]:
    """Return the first name in FRAME_CANDIDATES that is a column of *df*, else None."""
    available = df.columns
    return next((name for name in FRAME_CANDIDATES if name in available), None)

def _to_seconds_series(df: pd.DataFrame, ts_col: str) -> pd.Series:
    """
    Return float seconds aligned to rows (t=0 at first valid).
    - UTC_ISO / Timestamp: ISO-8601 strings → seconds
    - Number: numeric seconds
    - MonoNs: numeric nanoseconds → seconds
    """
    s = df[ts_col]
    if ts_col in ("UTC_ISO", "Timestamp"):
        dt = pd.to_datetime(s, errors="coerce", utc=(ts_col == "UTC_ISO"))
        secs = dt.astype("int64") / 1e9  # NaT -> NaN
        t0 = np.nanmin(secs.values)
        return (secs - t0).astype(float)

    if ts_col == "Number":
        vals = pd.to_numeric(s, errors="coerce").astype(float)
        t0 = np.nanmin(vals.values)
        return vals - t0

    if ts_col == "MonoNs":
        vals = pd.to_numeric(s, errors="coerce").astype(float)
        secs = vals / 1e9
        t0 = np.nanmin(secs.values)
        return secs - t0

    raise ValueError(f"Unsupported timestamp column: {ts_col}")

def _estimate_fps_from_seconds(seconds_series: pd.Series) -> Optional[float]:
    mask = seconds_series.notna()
    if mask.sum() < 2:
        return None
    duration = seconds_series[mask].iloc[-1] - seconds_series[mask].iloc[0]
    if duration <= 0:
        return None
    return mask.sum() / duration

def _resolve_measure_column(df: pd.DataFrame) -> str | None:
    """Return the first entry of MEASURE_COLS that is a column of *df*, else None."""
    for candidate in MEASURE_COLS:
        if candidate in df.columns:
            return candidate
    return None

def _compute_envelope(series: pd.Series, win_frames: int) -> np.ndarray:
    """Clip to [0,100] → Hilbert analytic envelope → centered rolling mean; length preserved."""
    series = pd.to_numeric(series, errors="coerce").fillna(0.0).clip(lower=0, upper=100)
    analytic = hilbert(series.to_numpy())
    env = np.abs(analytic)
    return (
        pd.Series(env, index=series.index)
          .rolling(window=win_frames, center=True, min_periods=1)
          .mean()
          .to_numpy()
    )

def _infer_trial_type(p: Path) -> str:
    s = (p.stem + "/" + "/".join(q.name for q in p.parents)).lower()
    if "testing" in s:  return "testing"
    if "training" in s: return "training"
    return "unknown"

def _trial_label(p: Path) -> str:
    """
    Build a short trial label for a CSV path.

    Yields 'testing_<n>' / 'training_<n>' when TRIAL_REGEX matches the file
    stem or, failing that, the stem joined with the parent directory names.
    Otherwise a trailing integer in the stem is combined with the inferred
    trial type; with no trailing integer the bare stem is returned.
    """
    stem = p.stem
    match = TRIAL_REGEX.search(stem)
    if match is None:
        # The token may live in a directory name rather than in the file name.
        parent_names = "/".join(parent.name for parent in p.parents)
        match = TRIAL_REGEX.search((stem + "/" + parent_names).lower())
    if match is not None:
        return f"{match.group(1).lower()}_{match.group(2)}"

    # No explicit token: fall back to a trailing integer such as *_3.
    trailing = re.search(r"(\d+)$", stem)
    if trailing is None:
        return stem  # last resort
    return f"{_infer_trial_type(p)}_{trailing.group(1)}"

def _find_trial_csvs(fly_dir: Path):
    """Prefer RMS_calculations subtree; fallback to fly root; include testing/training CSVs."""
    search_root = fly_dir / "RMS_calculations"
    if not search_root.is_dir():
        search_root = fly_dir
    patterns = ["**/*testing*.csv", "**/*training*.csv"]
    seen = set()
    for pat in patterns:
        for csv in search_root.glob(pat):
            if csv.is_file():
                rp = csv.resolve()
                if rp not in seen:
                    seen.add(rp)
                    yield rp

# ───────── PASS 1: discover items and determine max row length ─────────
# Walk every root → fly directory → trial CSV, record one metadata dict per
# usable CSV, and track the longest trial so PASS 2 can size the wide table.
items = []   # [{dataset, fly, csv_path, trial_type, trial_label, measure_col, n_frames}]
max_len = 0

for root in ROOTS:
    root = root.expanduser().resolve()
    assert root.is_dir(), f"Not a directory: {root}"
    dataset = root.name  # dataset label = name of the root directory

    # Every immediate sub-directory of a root is treated as one fly.
    for fly_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        fly = fly_dir.name
        for csv_path in _find_trial_csvs(fly_dir):
            # Resolve measure column from header
            try:
                header_df = pd.read_csv(csv_path, nrows=0)
            except Exception as e:
                print(f"[WARN] Skip {csv_path.name}: header read error: {e}")
                continue
            col = _resolve_measure_column(header_df)
            if col is None:
                print(f"[SKIP] {csv_path.name}: none of {MEASURE_COLS} present.")
                continue
            # Count frames by loading only the measure column (one row per frame)
            try:
                n_frames = pd.read_csv(csv_path, usecols=[col]).shape[0]
            except Exception as e:
                print(f"[WARN] Skip {csv_path.name}: count error: {e}")
                continue

            items.append({
                "dataset": dataset,
                "fly": fly,
                "csv_path": csv_path,
                "trial_type": _infer_trial_type(csv_path),
                "trial_label": _trial_label(csv_path),
                "measure_col": col,
                "n_frames": n_frames
            })
            max_len = max(max_len, n_frames)

# Nothing to export: stop before PASS 2 creates an output file.
if not items:
    raise RuntimeError("No eligible testing/training CSVs found in provided roots.")

print(f"[INFO] Datasets: {[r.name for r in ROOTS]}")
print(f"[INFO] Discovered {len(items)} videos. Max frames = {max_len}")

# ───────── PASS 2: compute envelope and write combined wide CSV ─────────
# One output row per trial: 5 metadata columns followed by env_0..env_{max_len-1}.
cols = ["dataset", "fly", "trial_type", "trial_label", "fps"] + [f"env_{i}" for i in range(max_len)]
# Make sure the destination folder exists before the first write.
OUT_WIDE_CSV.parent.mkdir(parents=True, exist_ok=True)
# Write the header once; rows are appended one at a time below.
pd.DataFrame(columns=cols).to_csv(OUT_WIDE_CSV, index=False)

for it in items:
    dataset     = it["dataset"]
    fly         = it["fly"]
    csv_path    = it["csv_path"]
    trial_type  = it["trial_type"]
    label       = it["trial_label"]
    measure_col = it["measure_col"]

    # --- Determine FPS from timestamps, if possible ---
    # Start from the constant fallback and override it only on success.
    fps = float(FALLBACK_FPS)
    try:
        hdr2 = pd.read_csv(csv_path, nrows=0)
    except Exception:
        hdr2 = pd.DataFrame()

    # A header-only read has zero rows, so DataFrame.empty is True for it even
    # when columns exist. Inspect the column names directly instead; the
    # pickers return None when nothing matches.
    frame_col = _pick_frame_column(hdr2)
    ts_col    = _pick_timestamp_column(hdr2)

    if frame_col is not None and ts_col is not None:
        try:
            # Read just the needed columns
            df_ts = pd.read_csv(csv_path, usecols=[frame_col, ts_col])
            secs  = _to_seconds_series(df_ts, ts_col)
            fps_from_csv = _estimate_fps_from_seconds(secs)
            if fps_from_csv and np.isfinite(fps_from_csv) and fps_from_csv > 0:
                fps = float(fps_from_csv)
        except Exception as e:
            print(f"[WARN] FPS inference failed for {csv_path.name}: {e}")
    # --- END FPS BLOCK ---

    # Now read the measure column and compute envelope
    try:
        df = pd.read_csv(csv_path, usecols=[measure_col])
    except Exception as e:
        print(f"[WARN] Read failed {csv_path}: {e}")
        continue

    if df.empty:
        # A header-only file has no samples to transform; skip it rather than
        # let the envelope step fail and abort the remaining trials.
        print(f"[SKIP] {csv_path.name}: no data rows.")
        continue

    env = _compute_envelope(df[measure_col], WINDOW_FRAMES).astype(float)

    # Include fps in the output row
    row = [dataset, fly, trial_type, label, fps] + list(env)

    # pad/truncate to max_len
    if len(env) < max_len:
        row += [np.nan] * (max_len - len(env))
    elif len(env) > max_len:
        row = row[:5 + max_len]  # account for 5 metadata cols now

    pd.DataFrame([row], columns=cols).to_csv(OUT_WIDE_CSV, mode="a", header=False, index=False)

print(f"[OK] Wrote combined envelope table: {OUT_WIDE_CSV}")

In [None]:
# JUPYTER CELL — Convert wide envelope CSV → 16-bit numeric matrix + code key
from pathlib import Path
import numpy as np
import pandas as pd
import json

# ===== INPUT / OUTPUT =====
INPUT_CSV = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/all_envelope_rows_wide.csv")  # change if your file lives elsewhere
OUT_DIR   = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/")                           # change if desired
OUT_DIR.mkdir(parents=True, exist_ok=True)

MATRIX_NPY = OUT_DIR / "envelope_matrix_float16.npy"   # 16-bit floating matrix
CODE_KEY   = OUT_DIR / "code_key.txt"                  # human-readable mapping & schema
CODES_JSON = OUT_DIR / "code_maps.json"                # machine-readable mappings (optional)

# Every integer up to this value is exactly representable in float16; larger
# codes can collide after the matrix is cast.
FLOAT16_MAX_EXACT_INT = 2048

# ===== LOAD =====
df = pd.read_csv(INPUT_CSV)

# Identify metadata columns (present subset)
meta_cols_all = ["dataset", "fly", "trial_type", "trial_label", "fps"]
meta_cols = [c for c in meta_cols_all if c in df.columns]
assert meta_cols, "No metadata columns found. Expected at least one of: dataset, fly, trial_type, trial_label."

# Envelope columns (everything else)
env_cols = [c for c in df.columns if c not in meta_cols]
assert len(env_cols) > 0, "No envelope columns found."


def _meta_labels(series: pd.Series) -> pd.Series:
    """Stringify a metadata column, mapping missing entries to 'UNKNOWN'.

    Missing values must be replaced by position: after ``astype(str)`` a NaN
    is already the literal text "nan", so a later ``fillna`` would never fire.
    """
    return series.astype(str).mask(series.isna(), "UNKNOWN")


# ===== BUILD INTEGER CODES FOR METADATA =====
# Codes start at 1; 0 is reserved for 'unknown'
code_maps = {}
for col in meta_cols:
    uniques = _meta_labels(df[col]).unique().tolist()
    mapping = {"UNKNOWN": 0}
    next_code = 1
    for u in uniques:
        if u not in mapping:
            mapping[u] = next_code
            next_code += 1
    code_maps[col] = mapping
    if next_code - 1 > FLOAT16_MAX_EXACT_INT:
        print(f"[WARN] Column '{col}' needs {next_code - 1} codes; integers above "
              f"{FLOAT16_MAX_EXACT_INT} are not all exact in float16 and may decode to the wrong label.")

# Apply codes to a copy
df_num = df.copy()
for col, mapping in code_maps.items():
    df_num[col] = _meta_labels(df_num[col]).map(mapping).fillna(0).astype(np.int32)

# Ensure envelope columns are numeric and NaN-free
df_num[env_cols] = df_num[env_cols].apply(pd.to_numeric, errors="coerce")
df_num[env_cols] = df_num[env_cols].fillna(0.0)

# ===== BUILD THE MATRIX (float16) =====
# Order: [meta_cols...] + [env_0...env_N]
ordered_cols = meta_cols + env_cols
matrix_f16 = df_num[ordered_cols].to_numpy(dtype=np.float16)

# ===== SAVE ARTIFACTS =====
np.save(MATRIX_NPY, matrix_f16)

# Human-readable key file
with CODE_KEY.open("w", encoding="utf-8") as f:
    f.write("# Envelope matrix schema (float16), row-wise\n")
    f.write("# Columns (in order):\n")
    for i, col in enumerate(ordered_cols):
        f.write(f"{i:>5}: {col}\n")
    f.write("\n# Metadata code maps (string → integer code)\n")
    for col in meta_cols:
        f.write(f"\n[{col}]\n")
        # Sort by numeric code
        inv = sorted(((code, name) for name, code in code_maps[col].items()), key=lambda x: x[0])
        for code, name in inv:
            f.write(f"{code:>5} : {name}\n")
    f.write("\nNotes:\n")
    f.write("- Matrix dtype is float16 (16-bit). Metadata codes are stored as float16 numbers in the matrix.\n")
    f.write("- Envelope NaNs (shorter videos) were replaced with 0.0.\n")
    f.write("- Code '0' means UNKNOWN for the metadata fields.\n")

# Optional: machine-readable mappings
with CODES_JSON.open("w", encoding="utf-8") as jf:
    json.dump({"column_order": ordered_cols, "code_maps": code_maps}, jf, indent=2)

print(f"[OK] Saved 16-bit matrix: {MATRIX_NPY}  (shape={matrix_f16.shape}, dtype={matrix_f16.dtype})")
print(f"[OK] Saved key:           {CODE_KEY}")
print(f"[OK] Saved JSON maps:     {CODES_JSON}")


## Matrices

In [None]:
# JUPYTER CELL — Reaction matrices per odor + fly-category counts (During & After)
from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import gridspec
from matplotlib.patches import Patch

# ───────── USER KNOB: spacing between rows (increase this to add space)
ROW_GAP = 0.6            # GridSpec hspace between the matrix row and the bar-chart row
HEIGHT_PER_GAP_IN = 3.0  # inches of figure height added per 1.0 of ROW_GAP
BOTTOM_SHIFT_IN = 0.50   # inches by which the bottom row of axes is moved down

# ───────── PARAMETERS ─────────
MATRIX_NPY        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/envelope_matrix_float16.npy")
CODES_JSON        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/code_maps.json")
FPS_DEFAULT       = 40     # used when a row has no usable fps value
BEFORE_SEC        = 30.0   # baseline window length
DURING_SEC        = 30.0   # odor window length
AFTER_WINDOW_SEC  = 30.0   # post-odor window length
THRESH_STD_MULT   = 4      # threshold = baseline mean + this many baseline std
MIN_SAMPLES_OVER  = 20     # samples above threshold required to count as a hit

# Shifts the DURING window (start and end) by this many seconds.
# NOTE: overall_mean_latency_s is not defined in this cell; it must already
# exist in the notebook namespace from an earlier cell.
ODOR_TRANSIT_LAT_S = overall_mean_latency_s

OUT_DIR           = Path("/home/ramanlab/Documents/cole/Results/Opto/Matrixs_DIST")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Canon keys for grouping
# NOTE(review): _canon_odor looks names up after lower-casing them, so the
# mixed-case keys below ("Ethyl Butyrate", "Optogenetics ...") cannot match as
# written. "Optogenetics benzaldehyde" also appears twice; in a dict literal
# the later entry ("opto_benz_1") silently replaces the earlier one — confirm
# which mapping is intended.
ODOR_CANON = {
    "acv": "ACV",
    "apple cider vinegar": "ACV",
    "apple-cider-vinegar": "ACV",
    "3-octonol": "3-octonol",
    "3 octonol": "3-octonol",
    "3-octanol": "3-octonol",
    "3 octanol": "3-octonol",
    "benz": "Benz",
    "benzaldehyde": "Benz",
    "benz-ald": "Benz",
    "benzadhyde": "Benz",
    "Ethyl Butyrate": "EB",
    "Optogenetics benzaldehyde": "opto_benz",
    "Optogenetics Ethyl Butyrate": "opto_EB",
    "Optogenetics benzaldehyde": "opto_benz_1",
}
# Canonical dataset key → label shown for the trained odor in figures.
DISPLAY_LABEL = {
    "ACV": "ACV",
    "3-octonol": "3-Octonol",
    "Benz": "Benzaldehyde",
    "10s_Odor_Benz": "Benzaldehyde",
    "EB": "Ethyl Butyrate",
    "opto_benz": "Benzaldehyde",
    "opto_EB": "Ethyl Butyrate",
    "opto_benz_1": "Benzaldehyde",
}
# Preferred processing order; datasets not listed are appended alphabetically.
ODOR_ORDER = ["ACV", "3-octonol", "Benz", "EB", "10s_Odor_Benz", "opto_benz", "opto_EB", "opto_benz_1"]

# ───────── LOAD + DECODE ─────────
# Load the numeric matrix plus the JSON side-car that records the column order
# and the label → integer code maps used to encode the metadata columns.
matrix = np.load(MATRIX_NPY)
with open(CODES_JSON, "r") as f:
    meta = json.load(f)
ordered_cols = meta["column_order"]
code_maps    = meta["code_maps"]
# Invert each map (code → label) so coded columns can be turned back into text.
rev_maps     = {c: {v:k for k, v in m.items()} for c, m in code_maps.items()}

decode_cols = [c for c in ["dataset", "fly", "trial_type", "trial_label"] if c in ordered_cols]
meta_cols   = decode_cols + (["fps"] if "fps" in ordered_cols else [])
df = pd.DataFrame(matrix, columns=ordered_cols)

# Replace integer codes with their original labels; unmapped codes → "UNKNOWN".
for c in decode_cols:
    df[c] = df[c].astype(int).map(rev_maps[c]).fillna("UNKNOWN")

# fps may itself be code-mapped: decode it first when a map exists, then force
# it to numeric (anything unparseable becomes NaN).
if "fps" in df.columns:
    if "fps" in rev_maps:
        df["fps"] = df["fps"].astype(int).map(rev_maps["fps"])
    df["fps"] = pd.to_numeric(df["fps"], errors="coerce")
else:
    df["fps"] = np.nan

# Keep testing trials only.
df = df[df["trial_type"].str.lower() == "testing"].copy()

# Missing or infinite fps values fall back to the default rate.
FPS_FALLBACK = FPS_DEFAULT
df["fps"] = df["fps"].fillna(FPS_FALLBACK).replace([np.inf, -np.inf], FPS_FALLBACK)

# Everything that is not metadata is a per-frame envelope column.
env_cols = [c for c in ordered_cols if c not in meta_cols]

def _canon_odor(s: str) -> str:
    """
    Map a dataset name to its canonical odor key.

    Matching against ODOR_CANON ignores case and surrounding whitespace on
    both sides, so aliases written with capital letters in the table are
    reachable too. Names with no alias are returned stripped but otherwise
    unchanged; non-string input yields "UNKNOWN".
    """
    if not isinstance(s, str):
        return "UNKNOWN"
    name = s.strip()
    key = name.lower()
    if key in ODOR_CANON:
        return ODOR_CANON[key]
    # Fall back to a normalized comparison for aliases that are not stored in
    # lower case.
    for alias, canon in ODOR_CANON.items():
        if alias.strip().lower() == key:
            return canon
    return name
df["dataset_canon"] = df["dataset"].apply(_canon_odor)

def _trial_num(label: str) -> int:
    m = re.search(r"(\d+)", str(label))
    return int(m.group(1)) if m else -1

def display_odor_for_trial(dataset_canon: str, trial_label: str) -> str:
    """
    Return the odor name to display for one testing trial of a dataset.

    Trials 1 and 3 are the hexanol controls and trials 2, 4 and 5 show the
    dataset's trained odor. Later trials follow a per-dataset panel; any
    combination without an entry falls back to *trial_label* itself.
    """
    n = _trial_num(trial_label)
    if n == 1 or n == 3:  # hexanol controls
        return "Hexanol"
    if n == 2 or n == 4 or n == 5:  # trained odor
        return DISPLAY_LABEL.get(dataset_canon, dataset_canon)

    # Panels shared by more than one dataset.
    octonol_first = {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"}
    acv_first = {6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Benzaldehyde",
                 9: "Citral", 10: "Linalool"}
    panels = {
        "ACV": octonol_first,
        "3-octonol": {6: "Benzaldehyde", 7: "Citral", 8: "Linalool"},
        "Benz": {6: "Citral", 7: "Linalool"},
        "EB": acv_first,
        "10s_Odor_Benz": {6: "Benzaldehyde", 7: "Benzaldehyde"},
        "opto_EB": acv_first,
        "opto_benz": octonol_first,
        "opto_benz_1": {6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Ethyl Butyrate",
                        9: "Citral", 10: "Linalool"},
    }
    return panels.get(dataset_canon, {}).get(n, trial_label)

def score_trial_from_env(env_row: pd.Series, fps: float) -> tuple[int, int]:
    """
    Score one trial's envelope as (during_hit, after_hit), each 0 or 1.

    The baseline is the first BEFORE_SEC seconds. DURING is shifted by
    ODOR_TRANSIT_LAT_S at both ends and AFTER is the AFTER_WINDOW_SEC that
    immediately follow it. A window is a hit when at least MIN_SAMPLES_OVER of
    its samples exceed mean + THRESH_STD_MULT * std of the baseline.

    Only trailing non-positive / non-finite samples are trimmed. Interior ones
    are masked as NaN rather than removed, so sample index i always
    corresponds to frame i and the time windows stay aligned.
    """
    env = env_row.to_numpy(dtype=float)
    usable = np.isfinite(env) & (env > 0)
    if not usable.any():
        return (0, 0)

    # Cut after the last usable sample, then blank out unusable interior
    # samples so they are ignored by the statistics without shifting indices.
    last = int(np.flatnonzero(usable)[-1])
    env = np.where(usable[: last + 1], env[: last + 1], np.nan)
    total = env.size

    b_end   = int(round(BEFORE_SEC * fps))
    shift   = int(round(ODOR_TRANSIT_LAT_S * fps))
    d_start = b_end + shift
    d_end   = b_end + int(round(DURING_SEC * fps)) + shift
    a_end   = d_end + int(round(AFTER_WINDOW_SEC * fps))

    # Clamp to the available samples and keep the boundaries ordered.
    b_end   = max(0, min(b_end, total))
    d_start = max(b_end, min(d_start, total))
    d_end   = max(d_start, min(d_end, total))
    a_end   = max(d_end, min(a_end, total))

    before = env[:b_end]
    during = env[d_start:d_end]
    after  = env[d_end:a_end]

    # No usable baseline → no threshold can be formed.
    if before.size == 0 or not np.isfinite(before).any():
        return (0, 0)

    theta = float(np.nanmean(before)) + THRESH_STD_MULT * float(np.nanstd(before))
    # Comparisons against NaN are False, so masked samples never count.
    during_hit = int(np.sum(during > theta) >= MIN_SAMPLES_OVER) if during.size else 0
    after_hit  = int(np.sum(after  > theta) >= MIN_SAMPLES_OVER) if after.size  else 0
    return during_hit, after_hit

# ───────── Score all rows ─────────
# One record per testing trial: canonical dataset, fly, trial label/number and
# the binary During / After hit flags.
scores = []
for _, row in df.iterrows():
    row_fps = float(row.get("fps", FPS_FALLBACK))
    d_hit, a_hit = score_trial_from_env(row[env_cols], row_fps)
    scores.append({
        "dataset": row["dataset_canon"],
        "fly": row["fly"],
        "trial": row["trial_label"],
        "trial_num": _trial_num(row["trial_label"]),
        "during_hit": d_hit,
        "after_hit": a_hit
    })
scores_df = pd.DataFrame(scores)

# ───────── Colormaps and helpers ─────────
# Three discrete cells: -1 (no data) → gray, 0 (no hit) → white, 1 (hit) → black.
cmap = ListedColormap(["0.7", "1.0", "0.0"])  # gray, white, black
norm = BoundaryNorm([-1.5, -0.5, 0.5, 1.5], cmap.N)

def style_trained_xticks_vertical(ax, labels, trained_disp: str, fontsize: int):
    """Draw vertical x tick labels; the trained odor is shown in blue upper case."""
    label_style = dict(rotation=90, ha="center", va="top", fontsize=fontsize)
    ax.set_xticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, **label_style)

    wanted = trained_disp.lower()
    final_texts = []
    for tick_label in ax.get_xticklabels():
        if tick_label.get_text().strip().lower() == wanted:
            tick_label.set_text(trained_disp.upper())
            tick_label.set_color("tab:blue")
        final_texts.append(tick_label.get_text())

    # Re-apply so the upper-cased text is what the axis keeps.
    ax.set_xticklabels(final_texts, **label_style)
    ax.tick_params(axis="x", pad=2)

def compute_fly_category_counts(mat: np.ndarray, labels: list[str], trained_disp: str, include_hexanol: bool = False):
    """
    Count flies (rows of *mat*) by which odor columns they reacted to.

    A cell equal to 1 is a reaction; 0 and -1 (missing) are not. Columns whose
    label matches *trained_disp* (case-insensitive) are "trained"; the rest
    are "others", with hexanol columns left out unless *include_hexanol*.
    Returns a dict with keys "Trained only", "Trained + Others" and
    "Others only"; all zeros when *mat* is empty or no trained column exists.
    """
    tally = {"Trained only": 0, "Trained + Others": 0, "Others only": 0}
    if mat.size == 0:
        return tally

    target = trained_disp.lower()
    normalized = [lab.strip().lower() for lab in labels]
    trained_cols = [j for j, name in enumerate(normalized) if name == target]
    other_cols = [j for j, name in enumerate(normalized)
                  if name != target and (include_hexanol or name != "hexanol")]
    if not trained_cols:
        return tally

    reacted = mat == 1
    trained_hit = reacted[:, trained_cols].any(axis=1)
    if other_cols:
        other_hit = reacted[:, other_cols].any(axis=1)
    else:
        other_hit = np.zeros(mat.shape[0], dtype=bool)

    tally["Trained only"] = int(np.sum(trained_hit & ~other_hit))
    tally["Trained + Others"] = int(np.sum(trained_hit & other_hit))
    tally["Others only"] = int(np.sum(~trained_hit & other_hit))
    return tally

def plot_category_counts(ax, counts: dict, n_flies: int, title: str):
    """Draw the three reaction categories on *ax* as a bar chart of percent of flies."""
    categories = ["Trained only", "Trained + Others", "Others only"]
    tallies = np.array([counts.get(name, 0) for name in categories], dtype=float)
    if n_flies > 0:
        percentages = 100.0 * tallies / float(n_flies)
    else:
        percentages = np.zeros_like(tallies)

    positions = np.arange(len(categories))
    drawn = ax.bar(positions, percentages, width=0.75, edgecolor="black", linewidth=0.8)

    ax.set_xticks(positions)
    ax.set_xticklabels(categories, rotation=15, ha="right")
    ax.set_ylim(0, 100)
    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_ylabel("% of flies")
    ax.set_title(title, fontsize=12, weight="bold")
    ax.margins(x=0.05)

    # Print the rounded percentage just above each bar.
    for bar, value in zip(drawn, percentages):
        x_mid = bar.get_x() + bar.get_width()/2
        ax.text(x_mid, bar.get_height() + 1.5, f"{value:.0f}%", ha="center", va="bottom", fontsize=9)

def shade_latency_on_timeseries(ax, before_sec: float = BEFORE_SEC, latency_s: float = ODOR_TRANSIT_LAT_S):
    """Shade the span from *before_sec* to *before_sec + latency_s* in translucent red on *ax*."""
    ax.axvspan(before_sec, before_sec + latency_s, color="red", alpha=0.30, lw=0)

# ──────── helper: safe dir name
def _safe_dirname(s: str) -> str:
    return re.sub(r'[^A-Za-z0-9._-]+', '_', str(s)).strip('_')

# ───────── Build & save per-odor figures (per-odor subfolders) ─────────
# Process datasets in ODOR_ORDER first, then any remaining ones alphabetically.
present = scores_df["dataset"].unique().tolist()
ordered_present = [o for o in ODOR_ORDER if o in present]
extras = sorted([o for o in present if o not in ODOR_ORDER])
for odor in ordered_present + extras:
    sub = scores_df[scores_df["dataset"] == odor].copy()
    if sub.empty:
        print(f"[WARN] No testing trials for {odor}")
        continue

    # per-odor output directory
    odir = OUT_DIR / _safe_dirname(odor)
    odir.mkdir(parents=True, exist_ok=True)

    # Rows = flies (sorted by name); columns = trials (sorted by trial number).
    flies  = sorted(sub["fly"].unique())
    trials = sorted(sub["trial"].unique(), key=_trial_num)
    pretty_cols = [display_odor_for_trial(odor, t) for t in trials]

    # D / A hold During / After hits; -1 marks a fly × trial cell with no data.
    D = -np.ones((len(flies), len(trials)), dtype=int)
    A = -np.ones((len(flies), len(trials)), dtype=int)
    for i, fly in enumerate(flies):
        fly_rows = sub[sub["fly"] == fly]
        for j, t in enumerate(trials):
            s = fly_rows[fly_rows["trial"] == t]
            if s.empty: continue
            # If several rows share the same fly and trial, the first is used.
            D[i, j] = int(s["during_hit"].iloc[0])
            A[i, j] = int(s["after_hit"].iloc[0])

    odor_label   = DISPLAY_LABEL.get(odor, odor)
    trained_disp = DISPLAY_LABEL.get(odor, odor)
    n_flies = len(flies)
    n_trials = len(trials)

    # Figure size grows with the number of trials (width) and flies (height),
    # plus extra height for the row gap and the lowered bottom row.
    base_fig_w = max(10.0, 0.70 * n_trials + 6.0)
    base_fig_h = max(5.0, n_flies * 0.26 + 3.8)
    fig_w = base_fig_w
    fig_h = base_fig_h + ROW_GAP * HEIGHT_PER_GAP_IN
    fig_h += BOTTOM_SHIFT_IN
    xtick_fs = 9 if n_trials <= 10 else (8 if n_trials <= 16 else 7)

    # Hexanol columns are counted among the "other" odors here.
    during_counts = compute_fly_category_counts(D, pretty_cols, trained_disp, include_hexanol=True)
    after_counts  = compute_fly_category_counts(A, pretty_cols, trained_disp, include_hexanol=True)

    # Layout: top row = During / After matrices, bottom row = category bar charts.
    fig = plt.figure(figsize=(fig_w, fig_h), constrained_layout=False)
    gs  = gridspec.GridSpec(2, 2, height_ratios=[3.0, 1.25], width_ratios=[1, 1], hspace=ROW_GAP, wspace=0.10)

    axD  = fig.add_subplot(gs[0, 0])
    axA  = fig.add_subplot(gs[0, 1])
    axDc = fig.add_subplot(gs[1, 0])
    axAc = fig.add_subplot(gs[1, 1])

    imD = axD.imshow(D, cmap=cmap, norm=norm, aspect="auto", interpolation="nearest")
    axD.set_title(f"{odor_label} — During\n(DURING shifted by +{ODOR_TRANSIT_LAT_S:.2f} s)", fontsize=14, weight="bold", linespacing=1.1)
    style_trained_xticks_vertical(axD, pretty_cols, trained_disp, fontsize=xtick_fs)
    axD.set_yticks([]); axD.set_ylabel(f"{n_flies} Flies", fontsize=11)

    imA = axA.imshow(A, cmap=cmap, norm=norm, aspect="auto", interpolation="nearest")
    axA.set_title(f"{odor_label} — After (first {int(AFTER_WINDOW_SEC)} s)", fontsize=14, weight="bold")
    style_trained_xticks_vertical(axA, pretty_cols, trained_disp, fontsize=xtick_fs)
    axA.set_yticks([]); axA.set_ylabel(f"{n_flies} Flies", fontsize=11)

    plot_category_counts(axDc, during_counts, n_flies, title="During — Fly Reaction Categories")
    plot_category_counts(axAc, after_counts,  n_flies, title=f"After (first {int(AFTER_WINDOW_SEC)} s) — Fly Reaction Categories")

    # Legend entry documenting the latency shift (no span is drawn on the matrix).
    red_patch = Patch(facecolor="red", edgecolor="red", alpha=0.30, label=f"Odor transit {ODOR_TRANSIT_LAT_S:.2f} s (pre-DURING)")
    axD.legend(handles=[red_patch], loc="upper left", frameon=True, fontsize=9)

    # Lower the bottom-row axes by BOTTOM_SHIFT_IN inches (as a figure
    # fraction), never below 5 % of the figure height.
    shift_frac = BOTTOM_SHIFT_IN / fig_h
    for ax in (axDc, axAc):
        pos = ax.get_position()
        new_y0 = max(0.05, pos.y0 - shift_frac)
        ax.set_position([pos.x0, new_y0, pos.width, pos.height])

    # Save into per-odor folder
    out_png = odir / f"reaction_matrix_{odor.replace(' ', '_')}_{AFTER_WINDOW_SEC}_latency_{ODOR_TRANSIT_LAT_S:.3f}s.png"
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] saved {out_png}")

    # Companion text file mapping each matrix row index to its fly name.
    key_path = odir / f"row_key_{odor.replace(' ', '_')}_{AFTER_WINDOW_SEC}.txt"
    with key_path.open("w") as fh:
        for i, fly in enumerate(flies):
            fh.write(f"Row {i}: {fly}\n")
    print(f"[OK] saved {key_path}")

print("[DONE] Per-odor exports saved into subfolders under OUT_DIR.)")


## Matrices not in testing trial order

In [None]:
# JUPYTER CELL — Reaction matrices per odor + fly-category counts (During & After)
from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import gridspec

# ───────── USER KNOB: spacing between rows (increase this to add space)
ROW_GAP = 0.6            # try 0.10 … 0.60 (higher = more space between rows)
HEIGHT_PER_GAP_IN = 3.0  # how many inches of figure height to add per 1.0 ROW_GAP
BOTTOM_SHIFT_IN = 0.50   # inches to lower the bottom row; increase to move further down

# ───────── PARAMETERS ─────────
MATRIX_NPY        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/envelope_matrix_float16.npy")
CODES_JSON        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/code_maps.json")
FPS_DEFAULT       = 40     # used when a row has no usable fps value
BEFORE_SEC        = 30.0   # baseline window length
DURING_SEC        = 30.0   # odor window length
AFTER_WINDOW_SEC  = 30.0   # post-odor window length
THRESH_STD_MULT   = 4      # threshold = baseline mean + this many baseline std
MIN_SAMPLES_OVER  = 20     # samples above threshold required to count as a hit

# Shift DURING window by latency at both start and end:
# e.g., DURING [30,60] → [30+lat, 60+lat]
# NOTE: overall_mean_latency_s is not defined in this cell; it must already
# exist in the notebook namespace from an earlier cell.
ODOR_TRANSIT_LAT_S = overall_mean_latency_s

OUT_DIR           = Path("/home/ramanlab/Documents/cole/Results/Opto/Matrixs_DIST")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Canon keys for grouping
# NOTE(review): _canon_odor looks names up after lower-casing them, so the
# mixed-case keys below ("Ethyl Butyrate", "Optogenetics ...") cannot match as
# written. "Optogenetics benzaldehyde" also appears twice; in a dict literal
# the later entry ("opto_benz_1") silently replaces the earlier one — confirm
# which mapping is intended.
ODOR_CANON = {
    "acv": "ACV",
    "apple cider vinegar": "ACV",
    "apple-cider-vinegar": "ACV",
    "3-octonol": "3-octonol",
    "3 octonol": "3-octonol",
    "3-octanol": "3-octonol",
    "3 octanol": "3-octonol",
    "benz": "Benz",
    "benzaldehyde": "Benz",
    "benz-ald": "Benz",
    "benzadhyde": "Benz",
    "Ethyl Butyrate": "EB",
    "Optogenetics benzaldehyde": "opto_benz",
    "Optogenetics benzaldehyde": "opto_benz_1",
    "Optogenetics Ethyl Butyrate": "opto_EB",
}
# Canonical dataset key → label shown for the trained odor in figures.
DISPLAY_LABEL = {
    "ACV": "ACV",
    "3-octonol": "3-Octonol",
    "Benz": "Benzaldehyde",
    "10s_Odor_Benz": "Benzaldehyde",
    "EB": "Ethyl Butyrate",
    "opto_benz": "Benzaldehyde",
    "opto_EB": "Ethyl Butyrate",
    "opto_benz_1": "Benzaldehyde",
}
# Preferred dataset order for grouping/plotting.
ODOR_ORDER = ["ACV", "3-octonol", "Benz", "EB", "10s_Odor_Benz", "opto_benz", "opto_EB", "opto_benz_1"]

# ───────── LOAD + DECODE ─────────
# Load the numeric matrix plus the JSON side-car that records the column order
# and the label → integer code maps used to encode the metadata columns.
matrix = np.load(MATRIX_NPY)
with open(CODES_JSON, "r") as f:
    meta = json.load(f)
ordered_cols = meta["column_order"]
code_maps    = meta["code_maps"]
# Invert each map (code → label) so coded columns can be turned back into text.
rev_maps     = {c: {v:k for k, v in m.items()} for c, m in code_maps.items()}

decode_cols = [c for c in ["dataset", "fly", "trial_type", "trial_label"] if c in ordered_cols]
meta_cols   = decode_cols + (["fps"] if "fps" in ordered_cols else [])
df = pd.DataFrame(matrix, columns=ordered_cols)

# decode label-coded columns (codes with no entry in the map become "UNKNOWN")
for c in decode_cols:
    df[c] = df[c].astype(int).map(rev_maps[c]).fillna("UNKNOWN")

# ensure fps numeric: decode it first when it was code-mapped, then coerce
# (anything unparseable becomes NaN)
if "fps" in df.columns:
    if "fps" in rev_maps:
        df["fps"] = df["fps"].astype(int).map(rev_maps["fps"])
    df["fps"] = pd.to_numeric(df["fps"], errors="coerce")
else:
    df["fps"] = np.nan

# testing only
df = df[df["trial_type"].str.lower() == "testing"].copy()

# fill missing fps (NaN or ±inf) with the default rate
FPS_FALLBACK = FPS_DEFAULT
df["fps"] = df["fps"].fillna(FPS_FALLBACK).replace([np.inf, -np.inf], FPS_FALLBACK)

# envelope columns exclude meta
env_cols = [c for c in ordered_cols if c not in meta_cols]

def _canon_odor(s: str) -> str:
    """
    Map a dataset name to its canonical odor key.

    Matching against ODOR_CANON ignores case and surrounding whitespace on
    both sides, so aliases written with capital letters in the table are
    reachable too. Names with no alias are returned stripped but otherwise
    unchanged; non-string input yields "UNKNOWN".
    """
    if not isinstance(s, str):
        return "UNKNOWN"
    name = s.strip()
    key = name.lower()
    if key in ODOR_CANON:
        return ODOR_CANON[key]
    # Fall back to a normalized comparison for aliases that are not stored in
    # lower case.
    for alias, canon in ODOR_CANON.items():
        if alias.strip().lower() == key:
            return canon
    return name
df["dataset_canon"] = df["dataset"].apply(_canon_odor)

def _trial_num(label: str) -> int:
    m = re.search(r"(\d+)", str(label))
    return int(m.group(1)) if m else -1

# ───────── Custom trial→display-odor mapping per dataset ─────────
def display_odor_for_trial(dataset_canon: str, trial_label: str) -> str:
    """
    Return the odor name to display for one testing trial of a dataset.

    Trials 1 and 3 are the hexanol controls and trials 2, 4 and 5 show the
    dataset's trained odor. Later trials follow a per-dataset panel; any
    combination without an entry falls back to *trial_label* itself.
    """
    n = _trial_num(trial_label)
    if n == 1 or n == 3:  # hexanol controls
        return "Hexanol"
    if n == 2 or n == 4 or n == 5:  # trained odor
        return DISPLAY_LABEL.get(dataset_canon, dataset_canon)

    # Panels shared by more than one dataset.
    octonol_first = {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"}
    acv_first = {6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Benzaldehyde",
                 9: "Citral", 10: "Linalool"}
    panels = {
        "ACV": octonol_first,
        "3-octonol": {6: "Benzaldehyde", 7: "Citral", 8: "Linalool"},
        "Benz": {6: "Citral", 7: "Linalool"},
        "EB": acv_first,
        "10s_Odor_Benz": {6: "Benzaldehyde", 7: "Benzaldehyde"},
        "opto_EB": acv_first,
        "opto_benz": octonol_first,
        "opto_benz_1": {6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Ethyl Butyrate",
                        9: "Citral", 10: "Linalool"},
    }
    return panels.get(dataset_canon, {}).get(n, trial_label)

# ───────── Scoring on envelope row ─────────
def score_trial_from_env(env_row: pd.Series, fps: float) -> tuple[int, int]:
    """
    Compute During/After hits using a baseline from BEFORE.
    DURING is fully shifted by ODOR_TRANSIT_LAT_S at start and end:
      DURING: [BEFORE_SEC + ODOR_TRANSIT_LAT_S, BEFORE_SEC + DURING_SEC + ODOR_TRANSIT_LAT_S]
      AFTER:  the next AFTER_WINDOW_SEC immediately following DURING.

    Returns (during_hit, after_hit), each 0 or 1. A window is a hit when at
    least MIN_SAMPLES_OVER of its samples exceed
    mean + THRESH_STD_MULT * std of the baseline.

    Only trailing non-positive / non-finite samples are trimmed. Interior ones
    are masked as NaN rather than removed, so sample index i always
    corresponds to frame i and the time windows stay aligned.
    """
    env = env_row.to_numpy(dtype=float)
    usable = np.isfinite(env) & (env > 0)
    if not usable.any():
        return (0, 0)

    # Cut after the last usable sample, then blank out unusable interior
    # samples so they are ignored by the statistics without shifting indices.
    last = int(np.flatnonzero(usable)[-1])
    env = np.where(usable[: last + 1], env[: last + 1], np.nan)
    total = env.size

    # Indices (in samples)
    b_end   = int(round(BEFORE_SEC * fps))  # end of BEFORE
    shift   = int(round(ODOR_TRANSIT_LAT_S * fps))
    d_start = b_end + shift
    d_end   = b_end + int(round(DURING_SEC * fps)) + shift
    a_end   = d_end + int(round(AFTER_WINDOW_SEC * fps))

    # Clip/guard: stay inside the data and keep the boundaries ordered
    b_end   = max(0, min(b_end, total))
    d_start = max(b_end, min(d_start, total))
    d_end   = max(d_start, min(d_end, total))
    a_end   = max(d_end, min(a_end, total))

    # Windows
    before = env[:b_end]
    during = env[d_start:d_end]   # fully shifted DURING window
    after  = env[d_end:a_end]

    # No usable baseline → no threshold can be formed.
    if before.size == 0 or not np.isfinite(before).any():
        return (0, 0)

    # Threshold from BEFORE baseline
    theta = float(np.nanmean(before)) + THRESH_STD_MULT * float(np.nanstd(before))

    # Hits (require at least MIN_SAMPLES_OVER above theta); comparisons
    # against NaN are False, so masked samples never count.
    during_hit = int(np.sum(during > theta) >= MIN_SAMPLES_OVER) if during.size else 0
    after_hit  = int(np.sum(after  > theta) >= MIN_SAMPLES_OVER) if after.size  else 0
    return during_hit, after_hit

# ───────── Score all rows ─────────
# One record per testing trial: canonical dataset, fly, trial label/number and
# the binary During / After hit flags.
scores = []
for _, row in df.iterrows():
    row_fps = float(row.get("fps", FPS_FALLBACK))
    d_hit, a_hit = score_trial_from_env(row[env_cols], row_fps)
    scores.append({
        "dataset": row["dataset_canon"],
        "fly": row["fly"],
        "trial": row["trial_label"],
        "trial_num": _trial_num(row["trial_label"]),
        "during_hit": d_hit,
        "after_hit": a_hit
    })
scores_df = pd.DataFrame(scores)

# ───────── Colormaps and helpers ─────────
# Three discrete cells: -1 (no data) → gray, 0 (no hit) → white, 1 (hit) → black.
cmap = ListedColormap(["0.7", "1.0", "0.0"])  # gray, white, black
norm = BoundaryNorm([-1.5, -0.5, 0.5, 1.5], cmap.N)

def style_trained_xticks_vertical(ax, labels, trained_disp: str, fontsize: int):
    """Put vertical x tick labels on *ax*, highlighting the trained odor.

    Every label whose text equals ``trained_disp`` (case-insensitive, ignoring
    surrounding whitespace) is upper-cased and coloured ``tab:blue``.
    """
    label_style = dict(rotation=90, ha="center", va="top", fontsize=fontsize)

    ax.set_xticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, **label_style)

    wanted = trained_disp.lower()
    final_texts = []
    for tick_label in ax.get_xticklabels():
        if tick_label.get_text().strip().lower() == wanted:
            tick_label.set_text(trained_disp.upper())
            tick_label.set_color("tab:blue")
        final_texts.append(tick_label.get_text())

    # Re-apply the (possibly upper-cased) texts with the same styling.
    ax.set_xticklabels(final_texts, **label_style)
    ax.tick_params(axis="x", pad=2)

def compute_fly_category_counts(mat: np.ndarray, labels: list[str], trained_disp: str, include_hexanol: bool = False):
    """Count flies (rows of *mat*) by which odors they reacted to.

    A cell equal to 1 is a reaction; anything else (0, or the -1 "missing"
    sentinel) is not. Columns are classified through *labels*: those matching
    ``trained_disp`` (case-insensitive) are "trained", the rest are "others";
    "hexanol" columns are left out of "others" unless ``include_hexanol``.

    Returns a dict with keys "Trained only", "Trained + Others", "Others only".
    All counts are zero when *mat* is empty or no trained column exists.
    Rows with no reaction at all are not counted in any category.
    """
    tally = {"Trained only": 0, "Trained + Others": 0, "Others only": 0}
    if mat.size == 0:
        return tally

    target = trained_disp.lower()
    normalized = [lab.strip().lower() for lab in labels]
    trained_cols = [j for j, name in enumerate(normalized) if name == target]
    if not trained_cols:
        return tally
    other_cols = [
        j for j, name in enumerate(normalized)
        if name != target and (include_hexanol or name != "hexanol")
    ]

    reacted = (mat == 1)
    trained_hit = reacted[:, trained_cols].any(axis=1)
    if other_cols:
        other_hit = reacted[:, other_cols].any(axis=1)
    else:
        other_hit = np.zeros(mat.shape[0], dtype=bool)

    tally["Trained only"] = int(np.count_nonzero(trained_hit & ~other_hit))
    tally["Trained + Others"] = int(np.count_nonzero(trained_hit & other_hit))
    tally["Others only"] = int(np.count_nonzero(~trained_hit & other_hit))
    return tally

def plot_category_counts(ax, counts: dict, n_flies: int, title: str):
    """Draw a bar chart on *ax* with the percentage of flies per category.

    *counts* maps category name to a fly count (missing categories count as
    0). Percentages are relative to *n_flies*; all bars are 0 % when
    ``n_flies`` is not positive. Each bar is annotated with its rounded value.
    """
    cats = ["Trained only", "Trained + Others", "Others only"]
    raw = np.array([counts.get(name, 0) for name in cats], dtype=float)
    if n_flies > 0:
        vals_pct = 100.0 * raw / float(n_flies)
    else:
        vals_pct = np.zeros_like(raw)

    positions = np.arange(len(cats))
    bars = ax.bar(positions, vals_pct, width=0.75, edgecolor="black", linewidth=0.8)

    # Axis cosmetics: fixed 0–100 % scale with quarter ticks.
    ax.set_xticks(positions)
    ax.set_xticklabels(cats, rotation=15, ha="right")
    ax.set_ylim(0, 100)
    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_ylabel("% of flies")
    ax.set_title(title, fontsize=12, weight="bold")
    ax.margins(x=0.05)

    # Percentage label just above each bar.
    for rect, pct in zip(bars, vals_pct):
        x_mid = rect.get_x() + rect.get_width()/2
        y_top = rect.get_height() + 1.5
        ax.text(x_mid, y_top, f"{pct:.0f}%", ha="center", va="bottom", fontsize=9)

# ─────── helper: safe dir name
def _safe_dirname(s: str) -> str:
    return re.sub(r'[^A-Za-z0-9._-]+', '_', str(s)).strip('_')

# ───────── Build & save per-odor figures (into per-odor subfolders) ─────────
# Only iterate over odors present; preserve preferred order, then extras
# Datasets that actually occur in the scores table, arranged as: the ones
# listed in ODOR_ORDER (in that order) followed by any others, alphabetically.
present = scores_df["dataset"].unique().tolist()
ordered_present = [o for o in ODOR_ORDER if o in present]
extras = sorted([o for o in present if o not in ODOR_ORDER])

# One iteration per dataset: writes a 2x2 figure (During/After matrices on top,
# category bars below), a row-index → fly text key, and a CSV of binary hits.
for odor in ordered_present + extras:
    sub = scores_df[scores_df["dataset"] == odor].copy()
    if sub.empty:
        print(f"[WARN] No testing trials for {odor}")
        continue

    # per-odor output directory
    odir = OUT_DIR / _safe_dirname(odor)
    odir.mkdir(parents=True, exist_ok=True)

    flies  = sorted(sub["fly"].unique())
    # Build trial order: trained odor first (2,4,5), then 1,3,6,7,8,9
    existing_trials = list(sub["trial"].unique())

    # First integer found in a trial label, or -1 when the label has no digits.
    def _tnum(lbl):
        m = re.search(r"(\d+)", str(lbl))
        return int(m.group(1)) if m else -1

    desired_order = [2, 4, 5, 1, 3, 6, 7, 8, 9]

    # Map trial number → label; when several labels share a number only the
    # first one encountered is kept.
    by_num = {}
    for t in existing_trials:
        n = _tnum(t)
        if n not in by_num:
            by_num[n] = t

    # Preferred numbers first, then remaining non-negative numbers ascending.
    # Labels without a number (n == -1) are therefore left out of the matrix.
    ordered_trials = [by_num[n] for n in desired_order if n in by_num]
    leftovers = sorted([n for n in by_num.keys() if n not in set(desired_order) and n >= 0])
    ordered_trials += [by_num[n] for n in leftovers]

    trials = ordered_trials
    # Column captions: the odor name shown for each trial of this dataset.
    pretty_cols = [display_odor_for_trial(odor, t) for t in trials]

    # Matrices with sentinel -1 for missing, else 0/1
    # Rows follow `flies`, columns follow `trials`.
    D = -np.ones((len(flies), len(trials)), dtype=int)
    A = -np.ones((len(flies), len(trials)), dtype=int)
    for i, fly in enumerate(flies):
        fly_rows = sub[sub["fly"] == fly]
        for j, t in enumerate(trials):
            s = fly_rows[fly_rows["trial"] == t]
            if s.empty: continue
            # If a fly has several rows for the same trial, only the first is used.
            D[i, j] = int(s["during_hit"].iloc[0])
            A[i, j] = int(s["after_hit"].iloc[0])

    # Both names resolve to the same display label; one is used for titles,
    # the other for matching the trained-odor column.
    odor_label   = DISPLAY_LABEL.get(odor, odor)
    trained_disp = DISPLAY_LABEL.get(odor, odor)
    n_flies = len(flies)
    n_trials = len(trials)

    # Figure size (height grows with flies and with ROW_GAP)
    base_fig_w = max(10.0, 0.70 * n_trials + 6.0)
    base_fig_h = max(5.0, n_flies * 0.26 + 3.8)
    fig_w = base_fig_w
    fig_h = base_fig_h + ROW_GAP * HEIGHT_PER_GAP_IN
    fig_h += BOTTOM_SHIFT_IN   # keep layout comfortable while lowering bottom row

    # Smaller tick font as the number of trial columns grows.
    xtick_fs = 9 if n_trials <= 10 else (8 if n_trials <= 16 else 7)

    # NEW: Compute fly-category counts for During & After
    # Hexanol columns are counted among "others" here (include_hexanol=True).
    during_counts = compute_fly_category_counts(D, pretty_cols, trained_disp, include_hexanol=True)
    after_counts  = compute_fly_category_counts(A, pretty_cols, trained_disp, include_hexanol=True)

    # Create figure (manual layout)
    fig = plt.figure(figsize=(fig_w, fig_h), constrained_layout=False)
    gs  = gridspec.GridSpec(
        2, 2,
        height_ratios=[3.0, 1.25],
        width_ratios=[1, 1],
        hspace=ROW_GAP,
        wspace=0.10
    )

    axD  = fig.add_subplot(gs[0, 0])   # top-left  (During matrix)
    axA  = fig.add_subplot(gs[0, 1])   # top-right (After matrix)
    axDc = fig.add_subplot(gs[1, 0])   # bottom-left  (During categories)
    axAc = fig.add_subplot(gs[1, 1])   # bottom-right (After categories)

    # Top: matrices — vertical x labels, trained odor in blue
    # (imD / imA are not used again in the visible code.)
    imD = axD.imshow(D, cmap=cmap, norm=norm, aspect="auto", interpolation="nearest")
    axD.set_title(
        f"{odor_label} — During (shifted +{ODOR_TRANSIT_LAT_S:.2f}s)",
        fontsize=14, weight="bold"
    )
    style_trained_xticks_vertical(axD, pretty_cols, trained_disp, fontsize=xtick_fs)
    axD.set_yticks([]); axD.set_ylabel(f"{n_flies} Flies", fontsize=11)

    imA = axA.imshow(A, cmap=cmap, norm=norm, aspect="auto", interpolation="nearest")
    axA.set_title(f"{odor_label} — After (first {int(AFTER_WINDOW_SEC)} s)", fontsize=14, weight="bold")
    style_trained_xticks_vertical(axA, pretty_cols, trained_disp, fontsize=xtick_fs)
    axA.set_yticks([]); axA.set_ylabel(f"{n_flies} Flies", fontsize=11)

    # Bottom: category count bars
    plot_category_counts(axDc, during_counts, n_flies, title="During — Fly Reaction Categories")
    plot_category_counts(axAc, after_counts,  n_flies, title=f"After (first {int(AFTER_WINDOW_SEC)} s) — Fly Reaction Categories")

    # Lower the bottom row by BOTTOM_SHIFT_IN (inches)
    # Converted to a figure fraction; the bottom edge is never moved below 0.05.
    shift_frac = BOTTOM_SHIFT_IN / fig_h
    for ax in (axDc, axAc):
        pos = ax.get_position()
        new_y0 = max(0.05, pos.y0 - shift_frac)
        ax.set_position([pos.x0, new_y0, pos.width, pos.height])

    # Save — into per-odor folder
    # NOTE(review): the file name ends in "_unordered" although the columns are
    # arranged by desired_order above — confirm the suffix is still intended.
    out_png = odir / f"reaction_matrix_{odor.replace(' ', '_')}_{AFTER_WINDOW_SEC}_latency_{ODOR_TRANSIT_LAT_S:.3f}s_unordered.png"
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] saved {out_png}")

    # Row index → fly key — into per-odor folder
    # (the matrices hide y tick labels, so this file identifies each row)
    key_path = odir / f"row_key_{odor.replace(' ', '_')}_{AFTER_WINDOW_SEC}.txt"
    with key_path.open("w") as fh:
        for i, fly in enumerate(flies):
            fh.write(f"Row {i}: {fly}\n")
    print(f"[OK] saved {key_path}")

    # ───────── CSV per odor with actual odor names — into per-odor folder ─────────
    sub_for_csv = sub.copy()
    sub_for_csv["odor_sent"] = sub_for_csv["trial"].apply(lambda t: display_odor_for_trial(odor, t))
    # Sort rows in the same trial order as the matrix columns; trials that are
    # not in `trials` get a large rank (10**9) and therefore sort last.
    order_map = {t: i for i, t in enumerate(trials)}
    sub_for_csv["trial_ord"] = sub_for_csv["trial"].map(order_map).fillna(10**9).astype(int)
    sub_for_csv = sub_for_csv.sort_values(["fly", "trial_ord", "trial_num", "trial"])

    export_cols = ["dataset", "fly", "trial_num", "odor_sent", "during_hit", "after_hit"]
    out_csv = odir / f"binary_reactions_{odor.replace(' ', '_')}.csv"
    sub_for_csv[export_cols].to_csv(out_csv, index=False)
    print(f"[OK] saved {out_csv}")

# Final status line for the per-odor export cell (stray ")" removed from the message).
print("[DONE] Per-odor exports saved into subfolders under OUT_DIR.")

## Envelope

In [None]:
# JUPYTER CELL — Per-fly envelope plots from MATRIX with trained-odor styling
# (after-period limited to 30 s) + per-trial threshold line

from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt

# ========= PARAMETERS =========
# Inputs: the envelope matrix (one row per trial) and the JSON holding its
# column order and label code maps.
MATRIX_NPY        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/envelope_matrix_float16.npy")
CODES_JSON        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/code_maps.json")

FPS_DEFAULT       = 40.0       # fallback if fps missing/invalid
ODOR_ON_S         = 30.0       # nominal valve-on time (s); also the end of the baseline window
ODOR_OFF_S        = 60.0       # nominal valve-off time (s)
AFTER_SHOW_S      = 30.0       # show only first 30 s after odor OFF

# Threshold params — matches your scoring code design (baseline is [0, ODOR_ON_S))
# Kept as a float: the legend title below calls THRESH_STD_MULT.is_integer().
THRESH_STD_MULT   = 4.0        # θ = μ_before + k·σ

# Odor transit latency (mean time for plume to reach the fly)
# Taken from `overall_mean_latency_s`, which must already exist in the session.
ODOR_TRANSIT_LAT_S = overall_mean_latency_s

# IMPORTANT: extend visible window so AFTER is measured from shifted OFF
X_MAX_LIMIT       = ODOR_OFF_S + ODOR_TRANSIT_LAT_S + AFTER_SHOW_S

# Root output folder; one sub-folder per odor name is created beneath it.
OUT_DIR = Path("/home/ramanlab/Documents/cole/Results/Opto/Envlope_DIST")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ========= LOAD MATRIX + METADATA =========
matrix = np.load(MATRIX_NPY, allow_pickle=False)
with open(CODES_JSON, "r") as f:
    meta = json.load(f)

ordered_cols = meta["column_order"]   # column names, in matrix column order
code_maps    = meta["code_maps"]      # per column: {label: code}
# Inverted maps ({code: label}) used to turn stored codes back into strings.
rev_maps     = {c: {v:k for k, v in m.items()} for c, m in code_maps.items()}

# Label-coded columns that are actually present, plus fps when available.
decode_cols = [c for c in ["dataset","fly","trial_type","trial_label"] if c in ordered_cols]
meta_cols   = decode_cols + (["fps"] if "fps" in ordered_cols else [])

df = pd.DataFrame(matrix, columns=ordered_cols)

# Decode labels -> strings
# Codes with no entry in the reverse map become "UNKNOWN".
for c in decode_cols:
    df[c] = df[c].astype(int).map(rev_maps[c]).fillna("UNKNOWN")

# Ensure fps exists and is numeric
if "fps" in df.columns:
    if "fps" in rev_maps:  # if coded (rare)
        df["fps"] = df["fps"].astype(int).map(rev_maps["fps"])
    df["fps"] = pd.to_numeric(df["fps"], errors="coerce")
else:
    df["fps"] = np.nan

# Keep only testing trials
# NOTE: requires a "trial_type" column; a matrix without it raises KeyError here.
df = df[df["trial_type"].str.lower()=="testing"].copy()

# Fill missing/invalid fps with fallback
df["fps"] = df["fps"].replace([np.inf, -np.inf], np.nan).fillna(FPS_DEFAULT)

# env_* columns exclude meta (including fps)
env_cols  = [c for c in ordered_cols if c not in meta_cols]

# ========= Canon keys & display names =========
ODOR_CANON = {
    "acv": "ACV",
    "apple cider vinegar": "ACV",
    "apple-cider-vinegar": "ACV",
    "3-octonol": "3-octonol",
    "3 octonol": "3-octonol",
    "3-octanol": "3-octonol",
    "3 octanol": "3-octonol",
    "benz": "Benz",
    "benzaldehyde": "Benz",
    "benz-ald": "Benz",
    "benzadhyde": "Benz",
    "Ethyl Butyrate": "EB",
    "Optogenetics benzaldehyde": "opto_benz",
    "Optogenetics benzaldehyde": "opto_benz_1",
    "Optogenetics Ethyl Butyrate": "opto_EB",
}
DISPLAY_LABEL = {
    "ACV": "ACV",
    "3-octonol": "3-Octonol",
    "Benz": "Benzaldehyde",
    "10s_Odor_Benz": "Benzaldehyde",
    "EB": "Ethyl Butyrate",
    "opto_benz": "Benzaldehyde",
    "opto_benz_1": "Benzaldehyde",
    "opto_EB": "Ethyl Butyrate",
}

def _canon_dataset(s: str) -> str:
    if not isinstance(s, str):
        return "UNKNOWN"
    return ODOR_CANON.get(s.strip().lower(), s.strip())

# Add the canonical dataset key for every row as a new column.
df["dataset_canon"] = df["dataset"].apply(_canon_dataset)

# helper: safe dir name
def _safe_dirname(s: str) -> str:
    return re.sub(r'[^A-Za-z0-9._-]+', '_', str(s)).strip('_')

# ========= Helpers =========
def _trial_num(label: str) -> int:
    m = re.search(r"(\d+)", str(label))
    return int(m.group(1)) if m else -1

def display_odor_for_trial(dataset_canon: str, trial_label: str) -> str:
    """Return the odor name to display for one testing trial.

    The trial number is taken from *trial_label*. Trials 1 and 3 are the
    Hexanol controls; trials 2, 4 and 5 present the dataset's trained odor;
    later trials follow a per-dataset panel. Any combination not covered
    falls back to *trial_label* itself.
    """
    trial_no = _trial_num(trial_label)
    if trial_no in (1, 3):  # controls
        return "Hexanol"
    if trial_no in (2, 4, 5):  # trained odor
        return DISPLAY_LABEL.get(dataset_canon, dataset_canon)

    # Panels shared by more than one dataset.
    acv_style_panel = {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"}
    eb_style_panel = {
        6: "Apple Cider Vinegar",
        7: "3-Octonol",
        8: "Benzaldehyde",
        9: "Citral",
        10: "Linalool",
    }
    panels = {
        "ACV": acv_style_panel,
        "3-octonol": {6: "Benzaldehyde", 7: "Citral", 8: "Linalool"},
        "Benz": {6: "Citral", 7: "Linalool"},
        "EB": eb_style_panel,
        "10s_Odor_Benz": {6: "Benzaldehyde", 7: "Benzaldehyde"},
        "opto_EB": eb_style_panel,
        "opto_benz": acv_style_panel,
        "opto_benz_1": {
            6: "Apple Cider Vinegar",
            7: "3-Octonol",
            8: "Ethyl Butyrate",
            9: "Citral",
            10: "Linalool",
        },
    }
    return panels.get(dataset_canon, {}).get(trial_no, trial_label)

def _extract_env(row: pd.Series) -> np.ndarray:
    """Return the row's envelope samples (the ``env_cols`` entries) as floats,
    keeping only values that are finite and strictly positive."""
    samples = row[env_cols].to_numpy(dtype=float)
    usable = np.isfinite(samples) & (samples > 0)
    return samples[usable]

def _compute_theta(env_full: np.ndarray, fps: float) -> float:
    """Threshold θ = mean + THRESH_STD_MULT · std of the baseline samples.

    The baseline is the first ``int(ODOR_ON_S * fps)`` samples of *env_full*
    (fewer if the trace is shorter). Returns NaN when the trace is empty,
    *fps* is not positive, or the baseline has no samples.
    """
    if env_full.size == 0 or fps <= 0:
        return np.nan

    n_baseline = min(int(ODOR_ON_S * fps), env_full.size)
    baseline = env_full[:n_baseline]
    if baseline.size == 0:
        return np.nan

    return float(np.nanmean(baseline)) + THRESH_STD_MULT * float(np.nanstd(baseline))

def _is_trained_odor(dataset_canon: str, odor_name: str) -> bool:
    """True when *odor_name* equals the dataset's display label
    (compared case-insensitively, ignoring surrounding whitespace)."""
    trained_label = str(DISPLAY_LABEL.get(dataset_canon, dataset_canon))
    candidate = str(odor_name)
    return candidate.strip().lower() == trained_label.strip().lower()

def style_trained_title(ax, odor_label: str):
    """Set a left-aligned, bold, blue, upper-cased title on *ax* — the look
    used for trained-odor panels."""
    title_style = dict(loc="left", fontsize=11, weight="bold", pad=2, color="tab:blue")
    ax.set_title(odor_label.upper(), **title_style)

# ========= MAKE FIGURES PER FLY =========
# One figure per fly: a stacked subplot for each usable testing trial showing
# the envelope trace, valve markers, latency shading and the threshold line.
for fly, g in df.groupby("fly"):
    # Order the fly's trials by the number embedded in the trial label.
    g = g.sort_values("trial_label", key=lambda s: s.map(_trial_num))
    # The dataset of the fly's first row is taken as the dataset for all its trials.
    dataset_canon = _canon_dataset(g["dataset"].iloc[0])

    # (odor_name, t_visible, env_visible, theta, is_trained)
    trial_curves = []
    y_max = 0.0   # shared y-limit: largest envelope value or threshold seen for this fly

    for _, row in g.iterrows():
        env_full = _extract_env(row)
        if env_full.size == 0:
            continue

        row_fps = float(row.get("fps", FPS_DEFAULT)) if np.isfinite(row.get("fps", np.nan)) else FPS_DEFAULT
        # Time axis in seconds; max(...) guards against division by zero.
        t_full = np.arange(env_full.size, dtype=float) / max(row_fps, 1e-9)

        # Threshold is computed on the full trace, before clipping for display.
        theta = _compute_theta(env_full, row_fps)

        # Clip to [0, X_MAX_LIMIT] for visualization
        mask = (t_full <= X_MAX_LIMIT + 1e-9)
        t = t_full[mask]
        env = env_full[mask]
        if t.size == 0:
            continue

        odor_name = display_odor_for_trial(dataset_canon, row["trial_label"])
        trial_curves.append((odor_name, t, env, theta, _is_trained_odor(dataset_canon, odor_name)))

        local_max = np.nanmax(env) if np.isfinite(env).any() else 0.0
        if np.isfinite(theta):
            local_max = max(local_max, theta)
        y_max = max(y_max, float(local_max))

    if not trial_curves:
        print(f"[WARN] {fly}: no usable testing trials; skipping.")
        continue

    # Global matplotlib style; re-applied on every iteration with the same values.
    plt.rcParams.update({
        "figure.dpi": 300, "savefig.dpi": 300,
        "axes.spines.top": False, "axes.spines.right": False,
        "axes.linewidth": 0.8, "xtick.direction": "out", "ytick.direction": "out",
        "font.size": 10,
    })

    n = len(trial_curves)
    fig_h = max(3.0, n * 1.6 + 1.5)
    fig, axes = plt.subplots(n, 1, figsize=(10, fig_h), sharex=True)
    if n == 1:
        # plt.subplots returns a bare Axes for a single panel; wrap it for zip().
        axes = [axes]

    for ax, (odor_name, t, env, theta, is_trained) in zip(axes, trial_curves):
        ax.plot(t, env, linewidth=1.2, color='black')

        # Nominal valve timing markers (hardware command times)
        ax.axvline(ODOR_ON_S,  linestyle='--', linewidth=1.0, color='black')
        ax.axvline(ODOR_OFF_S, linestyle='--', linewidth=1.0, color='black')

        # Effective plume windows using latency:
        # both edges are the command times shifted by the transit latency,
        # capped at the visible x range.
        on_lat_end   = min(ODOR_ON_S  + ODOR_TRANSIT_LAT_S, X_MAX_LIMIT)
        off_lat_end  = min(ODOR_OFF_S + ODOR_TRANSIT_LAT_S, X_MAX_LIMIT)
        eff_on_start = min(on_lat_end, X_MAX_LIMIT)
        eff_on_end   = min(off_lat_end, X_MAX_LIMIT)

        # Shade start latency (red) and effective ON (gray)
        if ODOR_TRANSIT_LAT_S > 0:
            ax.axvspan(ODOR_ON_S, on_lat_end, alpha=0.25, color='red')
        if eff_on_end > eff_on_start:
            ax.axvspan(eff_on_start, eff_on_end, alpha=0.15, color='gray')

        # Shade end latency (red)
        if ODOR_TRANSIT_LAT_S > 0:
            ax.axvspan(ODOR_OFF_S, off_lat_end, alpha=0.25, color='red')

        # Threshold line
        if np.isfinite(theta):
            ax.axhline(theta, linestyle='-', linewidth=1.0, color='tab:red', alpha=0.9)

        # Same y range on every panel; unit height when nothing is above zero.
        ax.set_ylim(0, y_max * 1.02 if y_max > 0 else 1.0)
        ax.set_xlim(0, X_MAX_LIMIT)
        ax.margins(x=0, y=0.02)
        ax.set_ylabel("RMS (a.u.)", fontsize=10)

        if is_trained:
            style_trained_title(ax, odor_name)
        else:
            ax.set_title(odor_name, loc="left", fontsize=11, weight="bold", pad=2, color="black")

    axes[-1].set_xlabel("Time (s)", fontsize=11)

    # SINGLE legend on the figure
    # Proxy artists only: they are never added to an axes, just passed to the legend.
    on_handle      = plt.Line2D([0], [0], linestyle='--', linewidth=1.0, color='black', label='Valve on/off (command)')
    transit_handle = plt.Rectangle((0,0), 1, 1, alpha=0.25, color='red',  label=f'Odor transit (~{ODOR_TRANSIT_LAT_S:.2f}s)')
    span_handle    = plt.Rectangle((0,0), 1, 1, alpha=0.15, color='gray', label='Effective odor-on at fly')
    theta_handle   = plt.Line2D([0], [0], linestyle='-',  linewidth=1.0, color='tab:red', label=r'$\theta = \mu_{\mathrm{before}} + k\,\sigma_{\mathrm{before}}$')

    # The explicit `labels` list is what the legend displays for these handles.
    fig.legend(
        handles=[on_handle, transit_handle, span_handle, theta_handle],
        labels=[
            'Valve on/off (command)',
            f'Odor transit (~{ODOR_TRANSIT_LAT_S:.2f}s) — start & end',
            'Effective odor-on at fly',
            r'$\theta = \mu_\mathrm{before} + k\,\sigma_\mathrm{before}$'
        ],
        # Shows k without a trailing ".0" when it is a whole number.
        title=f'Threshold: k = {int(THRESH_STD_MULT) if THRESH_STD_MULT.is_integer() else THRESH_STD_MULT}',
        loc='upper right',
        bbox_to_anchor=(0.98, 0.97),
        frameon=True,
        fontsize=9,
        title_fontsize=9,
    )

    fig.suptitle(f"{fly} RMS of Proboscis - Eye Distance Percentage", y=0.995, fontsize=14, weight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.97])

    # === SAVE: write this fly's figure into each odor-specific folder observed for this fly ===
    # The identical figure is written once per distinct odor name among its trials.
    odors_present = sorted({name for (name, _, _, _, _) in trial_curves})
    for odor_name in odors_present:
        odir = OUT_DIR / _safe_dirname(odor_name)
        odir.mkdir(parents=True, exist_ok=True)
        out_png = odir / f"{fly}_envelope_trials_by_odor_{AFTER_SHOW_S}_shifted.png"
        fig.savefig(out_png)
        print(f"[OK] Saved {out_png}")

    plt.close(fig)

# Training code

In [None]:
#!/usr/bin/env python3
# fly_envelope_over_time.py — analytic envelope via Hilbert transform over time per trial

import re
from pathlib import Path
from typing import Optional, Union
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import hilbert

# ───────────────────────────────── CONFIG ───────────────────────────────
# Root folder scanned for fly sub-folders; `main_directory` must already be
# defined in the session.
DEFAULT_MAIN_DIRECTORY = main_directory  # keep your existing variable/environment
OUT_FIG_DIR           = "RMS_calculations/envelope_over_time_plots"  # relative to each fly folder
FPS_DEFAULT           = 40      # used for the time axis when a CSV has no "time_seconds" column
WINDOW_SEC            = 0.25    # smoothing window length (s)
# Window in frames, computed at FPS_DEFAULT (not per file); never below 1.
WINDOW_FRAMES         = max(int(WINDOW_SEC * FPS_DEFAULT), 1)
# Candidate measurement columns; the first one present in a CSV is used.
MEASURE_COLS    = ["distance_percentage_2_6", "distance_percentage"]
# Captures the trial number from file stems containing "training_<num>" (case-insensitive).
TRAINING_REGEX        = re.compile(r"training_(\d+)", re.I)

# NEW — odor timing (seconds)
ODOR_ON_S  = 30.0
ODOR_OFF_S = 60.0

# ─────────────────────────────── HELPERS ────────────────────────────────
def _resolve_measure_column(df: pd.DataFrame) -> Optional[str]:
    """Return the first name in MEASURE_COLS that is a column of *df*,
    or None when none of them is present."""
    for candidate in MEASURE_COLS:
        if candidate in df.columns:
            return candidate
    return None

def _extract_trials(data_dir: Path):
    """
    Iterate over the ``*training*.csv`` files in *data_dir* in sorted order,
    yielding ``(label, csv_path)`` pairs.

    The label is ``training_<num>`` when the file stem matches TRAINING_REGEX;
    otherwise the stem itself is used.
    """
    for path in sorted(data_dir.glob("*training*.csv")):
        stem = path.stem
        hit = TRAINING_REGEX.search(stem)
        if hit:
            yield f"training_{hit.group(1)}", path
        else:
            yield stem, path

def _compute_envelope(series: pd.Series, win_frames: int) -> pd.Series:
    """
    Clip raw values to [0, 100], then compute the analytic envelope
    and smooth with a centred rolling mean.
    """
    series = series.clip(lower=0, upper=100)           # ← clipping
    analytic = hilbert(series.to_numpy())
    env = np.abs(analytic)
    return (
        pd.Series(env, index=series.index)
          .rolling(window=win_frames, center=True, min_periods=1)
          .mean()
    )

# -----------------------------------------------------------------------------
# PLOTTING
# -----------------------------------------------------------------------------

def plot_envelope_subplots(
    fly_name: str,
    trials_data: dict[str, tuple[np.ndarray, np.ndarray]],
    out_path: Path,
    y_max: float
):
    """
    Plot analytic envelope over time for multiple trials as stacked subplots,
    marking each trial's peak with a red dot and sharing y-limits [0, y_max].
    Adds vertical bars at odor on/off and a shaded region for odor-on interval.

    Parameters
    ----------
    fly_name : used in the figure title and log messages.
    trials_data : maps trial label -> (time_s, env_vals), plotted in dict order.
    out_path : image file to write; parent folders are created as needed.
    y_max : common upper y-limit before 2 % padding. When it is not a
        positive finite number the axis falls back to [0, 1] instead of a
        degenerate zero-height range.

    Prints a warning and returns without writing anything when
    ``trials_data`` is empty.
    """
    n = len(trials_data)
    if n == 0:
        print(f"[WARN] {fly_name}: no training trials to plot.")
        return

    # A zero/invalid maximum would give set_ylim(0, 0); use a unit-height axis.
    padded_max = y_max * 1.02 if (np.isfinite(y_max) and y_max > 0) else 1.0

    plt.rcParams.update({"figure.dpi": 300, "savefig.dpi": 300})
    fig, axes = plt.subplots(n, 1, figsize=(10, 2.5 * n), sharex=True)
    if n == 1:
        # plt.subplots returns a bare Axes for a single panel; wrap it for zip().
        axes = [axes]

    for ax, (label, (time_s, env_vals)) in zip(axes, trials_data.items()):
        # Envelope trace
        ax.plot(time_s, env_vals, linewidth=1, clip_on=False)

        # Mark the trial's peak. np.nanargmax raises on empty or all-NaN
        # input, so the marker is skipped for such traces.
        env_arr = np.asarray(env_vals, dtype=float)
        if env_arr.size and not np.all(np.isnan(env_arr)):
            idx = int(np.nanargmax(env_arr))
            ax.plot(time_s[idx], env_vals[idx], marker='o', markersize=10, color='red', zorder=5)

        # Odor on/off markers + shaded interval
        ax.axvline(ODOR_ON_S,  linestyle='--', linewidth=1)
        ax.axvline(ODOR_OFF_S, linestyle='--', linewidth=1)
        ax.axvspan(ODOR_ON_S, ODOR_OFF_S, alpha=0.15)

        # Axes cosmetics
        ax.set_ylim(0, padded_max, auto=False)
        ax.autoscale(enable=False, axis="y")
        ax.margins(x=0, y=0)
        ax.set_ylabel("Envelope")
        ax.set_title(label)
        ax.grid(True)

        # Legend (simple, non-intrusive) built from proxy artists
        peak_handle = plt.Line2D([0], [0], marker='o', color='red', linestyle='None', markersize=6, label='Peak')
        on_handle   = plt.Line2D([0], [0], linestyle='--', label='Odor on/off')
        span_handle = plt.Rectangle((0,0), 1, 1, alpha=0.15, label='Odor on window')
        ax.legend(handles=[peak_handle, on_handle, span_handle], loc='upper right', frameon=True, fontsize=8)

    axes[-1].set_xlabel("Time (s)")
    fig.suptitle(
        f"{fly_name}: Analytic Envelope Over Time — TRAINING (window={WINDOW_SEC}s; odor {ODOR_ON_S:.0f}–{ODOR_OFF_S:.0f}s)",
        y=0.98
    )
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)
    print(f"[OK] {out_path}")

# -----------------------------------------------------------------------------
# WORKFLOW
# -----------------------------------------------------------------------------

def process_fly_envelope(fly_folder: Path):
    """
    Compute analytic envelope for each TRAINING trial then plot as subplots
    with a common y-axis from 0 to the fly’s global maximum.

    Trial CSVs are read from ``<fly_folder>/RMS_calculations``; the figure is
    written under ``<fly_folder>/OUT_FIG_DIR``. A trial is skipped with a
    message when its CSV is empty (no bytes, or a header with no rows) or has
    none of the MEASURE_COLS columns, so one bad file no longer aborts the
    whole run. Returns None.
    """
    fly_name    = fly_folder.name
    data_dir    = fly_folder / "RMS_calculations"
    if not data_dir.is_dir():
        print(f"[WARN] {fly_name}: no RMS_calculations directory.")
        return

    trials_data: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    all_max = 0.0

    for label, csv_path in _extract_trials(data_dir):
        try:
            df = pd.read_csv(csv_path)
        except pd.errors.EmptyDataError:
            # Zero-byte file: nothing to parse.
            print(f"[WARN] {fly_name} {label}: empty CSV; skipping.")
            continue

        # time axis
        if "time_seconds" in df.columns:
            time_s = df["time_seconds"].to_numpy(dtype=float)
        else:
            time_s = np.arange(len(df)) / FPS_DEFAULT

        # measurement series
        meas_col = _resolve_measure_column(df)
        if meas_col is None:
            print(f"[ERROR] {fly_name} {label}: no measure column.")
            continue

        # Header-only file: the Hilbert transform rejects a zero-length input.
        if df.empty:
            print(f"[WARN] {fly_name} {label}: CSV has no rows; skipping.")
            continue

        # Non-numeric / missing samples are treated as 0 before the envelope.
        series = pd.to_numeric(df[meas_col], errors="coerce").fillna(0.0)
        env_vals = _compute_envelope(series, WINDOW_FRAMES).to_numpy()

        trial_max = np.nanmax(env_vals)
        print(f"[DEBUG] {fly_name} {label} peak envelope = {trial_max:.3f}")

        trials_data[label] = (time_s, env_vals)
        if trial_max > all_max:
            all_max = trial_max

    print(f"[DEBUG] {fly_name} global peak envelope = {all_max:.3f}")

    out_dir  = fly_folder / OUT_FIG_DIR
    out_path = out_dir / f"{fly_name}_TRAINING_envelope_over_time_subplots.png"
    plot_envelope_subplots(fly_name, trials_data, out_path, y_max=all_max)

def run_envelope_over_time(main_directory: Optional[Union[Path, str]] = None):
    """Run process_fly_envelope on every sub-directory of the root folder.

    The root is *main_directory* when given (and truthy), otherwise
    DEFAULT_MAIN_DIRECTORY; it is user-expanded and resolved first.
    """
    base = main_directory if main_directory else DEFAULT_MAIN_DIRECTORY
    root = Path(base).expanduser().resolve()
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        process_fly_envelope(entry)

# Entry point: generate the training envelope plots for every fly folder
# under DEFAULT_MAIN_DIRECTORY (no explicit directory is passed).
if __name__ == "__main__":
    run_envelope_over_time()


In [None]:
import shutil

def collect_all_plots(main_directory: Union[str, Path], dest_folder: str = "all_envelope_plots"):
    """
    Copy each fly's TRAINING envelope figure into one shared folder.

    For every sub-directory ``<fly>`` of *main_directory*, the file
    ``<fly>/OUT_FIG_DIR/<fly>_TRAINING_envelope_over_time_subplots.png`` is
    copied (when it exists) to ``<main_directory>/<dest_folder>/`` under the
    name ``<fly>_envelope_over_time_subplots_training.png``. The number of
    copied files is printed at the end.
    """
    root = Path(main_directory).expanduser().resolve()
    dest = root / dest_folder
    dest.mkdir(parents=True, exist_ok=True)

    count = 0
    for fly in root.iterdir():
        if fly.is_dir():
            # expected location of plot
            source = fly / OUT_FIG_DIR / f"{fly.name}_TRAINING_envelope_over_time_subplots.png"
            if source.is_file():
                target = dest / f"{fly.name}_envelope_over_time_subplots_training.png"
                shutil.copy2(source, target)
                count += 1

    print(f"[OK] Collected {count} plots into {dest}")

# Entry point for this cell: regenerate every fly's plot (the same call the
# previous cell's guard makes), then copy them into one folder under
# DEFAULT_MAIN_DIRECTORY.
if __name__ == "__main__":
    run_envelope_over_time()  # generate plots
    collect_all_plots(DEFAULT_MAIN_DIRECTORY)  # gather them


In [None]:
# JUPYTER CELL — Latency to threshold crossing in training_5/6/7/8 (per fly + per trained-odor means + grand mean by odor)
# Figure style: research-presentation / professional

from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms

# ───────── PARAMETERS ─────────
# Inputs: the envelope matrix (one row per trial) and the JSON holding its
# column order and label code maps.
MATRIX_NPY        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/envelope_matrix_float16.npy")
CODES_JSON        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/code_maps.json")

FPS_DEFAULT       = 40.0        # fallback if fps missing/invalid
BEFORE_SEC        = 30.0        # baseline window length (s), measured from the start of the trace
DURING_SEC        = 35.0        # length (s) of the window searched for a threshold crossing
THRESH_STD_MULT   = 4           # θ = μ_before + k·σ_before
LATENCY_CEILING_S = 9.5         # latencies above this become NaN in the means ("NR")
# NOTE(review): the cell header and the original comment say training_5/6/7/8,
# but the list selects trial numbers 4, 5 and 6 — confirm which is intended.
TRIALS_OF_INTEREST = [4, 5, 6]  # training_5/6/7/8

# Root output folder for the latency figures.
OUT_DIR = Path("/home/ramanlab/Documents/cole/Results/Opto/Training_RESP_Time_DIST")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Titles
# The "{}" placeholders are left for str.format.
TITLE_FLY        = "{} — Time to PER"
TITLE_ODOR       = "{} — Mean Time to PER"
TITLE_ODOR_GRAND = "Grand Mean Time to Reaction by Trained Odor"

# ───────── LOAD MATRIX + MAPS ─────────
matrix = np.load(MATRIX_NPY)
with open(CODES_JSON, "r") as f:
    meta = json.load(f)

ordered_cols = meta["column_order"]   # column names, in matrix column order
code_maps    = meta["code_maps"]      # per column: {label: code}
# Inverted maps ({code: label}) used to turn stored codes back into strings.
rev_maps     = {c: {v:k for k, v in m.items()} for c, m in code_maps.items()}

# Keep fps numeric; only decode label-coded columns
decode_cols = [c for c in ["dataset","fly","trial_type","trial_label"] if c in ordered_cols]
meta_cols   = decode_cols + (["fps"] if "fps" in ordered_cols else [])

df = pd.DataFrame(matrix, columns=ordered_cols)

# Codes with no entry in the reverse map become "UNKNOWN".
for c in decode_cols:
    df[c] = df[c].astype(int).map(rev_maps[c]).fillna("UNKNOWN")

# Ensure fps exists & is numeric
if "fps" in df.columns:
    if "fps" in rev_maps:  # in case it was code-mapped
        df["fps"] = df["fps"].astype(int).map(rev_maps["fps"])
    df["fps"] = pd.to_numeric(df["fps"], errors="coerce")
else:
    df["fps"] = np.nan

# Keep only training rows
# NOTE: requires a "trial_type" column; a matrix without it raises KeyError here.
df = df[df["trial_type"].str.lower() == "training"].copy()
# Missing or infinite fps values fall back to FPS_DEFAULT.
df["fps"] = df["fps"].replace([np.inf, -np.inf], np.nan).fillna(FPS_DEFAULT)

# Envelope columns exclude meta (including fps)
env_cols  = [c for c in ordered_cols if c not in meta_cols]

# ───────── Canonicalize trained odor names (for foldering) ─────────
ODOR_CANON = {
    "acv": "ACV",
    "apple cider vinegar": "ACV",
    "apple-cider-vinegar": "ACV",
    "3-octonol": "3-octonol",
    "3 octonol": "3-octonol",
    "3-octanol": "3-octonol",
    "3 octanol": "3-octonol",
    "benz": "Benz",
    "benzaldehyde": "Benz",
    "benz-ald": "Benz",
    "benzadhyde": "Benz",
    "ethyl butyrate": "EB",
    "optogenetics benzaldehyde": "opto_benz",
    "optogenetics benzaldehyde ": "opto_benz",  # guard
    "optogenetics benzaldehyde 1": "opto_benz_1",
    "optogenetics ethyl butyrate": "opto_EB",
    "10s_odor_benz": "10s_Odor_Benz",
}
def _canon_dataset(s: str) -> str:
    if not isinstance(s, str): return "UNKNOWN"
    key = s.strip().lower()
    return ODOR_CANON.get(key, s.strip())

def _safe_dirname(s: str) -> str:
    return re.sub(r'[^A-Za-z0-9._-]+', '_', str(s)).strip('_')

# Add the canonical dataset key for every training row as a new column.
df["dataset_canon"] = df["dataset"].apply(_canon_dataset)

# Helpers
def _trial_num(label: str) -> int:
    m = re.search(r"(\d+)", str(label))
    return int(m.group(1)) if m else -1

def _extract_env(row: pd.Series) -> np.ndarray:
    """Return the row's ``env_cols`` values as a 1D float array, dropping
    entries that are non-finite or not strictly positive."""
    values = row[env_cols].to_numpy(dtype=float)
    keep = np.isfinite(values) & (values > 0)
    return values[keep]

def latency_to_cross(env: np.ndarray, fps: float) -> float | None:
    """Return the latency in seconds to the first threshold crossing, or None.

    The first ``int(BEFORE_SEC * fps)`` samples form the baseline and the
    following ``int(DURING_SEC * fps)`` samples form the search window (both
    clipped to the length of *env*). The threshold is
    ``mean(baseline) + THRESH_STD_MULT * std(baseline)``. The result is the
    offset of the first window sample above the threshold, divided by *fps*.

    Returns None when *env* is empty, *fps* is not a positive finite number,
    either segment is empty, or no sample crosses the threshold.
    """
    if env.size == 0:
        return None
    if not np.isfinite(fps) or fps <= 0:
        return None

    n_samples = env.size
    n_before = int(BEFORE_SEC * fps)
    during_start = min(n_before, n_samples)
    during_stop = min(n_before + int(DURING_SEC * fps), n_samples)

    baseline = env[:during_start]
    window = env[during_start:during_stop]
    if baseline.size == 0 or window.size == 0:
        return None

    theta = float(np.nanmean(baseline)) + THRESH_STD_MULT * float(np.nanstd(baseline))
    crossings = np.nonzero(window > theta)[0]
    if crossings.size == 0:
        return None
    return float(crossings[0]) / fps

# ───────── Compute latencies table once (for per-fly + group means + grand means) ─────────
# One record per training row whose trial number is in TRIALS_OF_INTEREST.
lat_records = []
for _, row in df.iterrows():
    tr = _trial_num(row["trial_label"])
    if tr not in TRIALS_OF_INTEREST:
        continue
    env = _extract_env(row)
    fps = float(row.get("fps", FPS_DEFAULT))
    # Guard against a non-finite or non-positive fps.
    fps = fps if (np.isfinite(fps) and fps > 0) else FPS_DEFAULT
    lat = latency_to_cross(env, fps)  # None if NR
    # No crossing, or a crossing later than the ceiling, is excluded from the aggregates (NaN).
    lat_for_mean = np.nan if (lat is None or lat > LATENCY_CEILING_S) else lat
    lat_records.append({
        "dataset": row["dataset"],
        "dataset_canon": row["dataset_canon"],
        "fly": row["fly"],
        "trial_num": tr,
        "latency": lat,                # per-fly display
        "lat_for_mean": lat_for_mean   # for means/STD/aggregates
    })

# NOTE: with no records this frame has no columns at all, so column lookups on it raise KeyError.
lat_df = pd.DataFrame(lat_records)

# ───────── Style ─────────
# Shared matplotlib defaults for every figure drawn below (300 dpi, no top/right spines).
plt.rcParams.update({
    "figure.dpi": 300, "savefig.dpi": 300,
    "axes.spines.top": False, "axes.spines.right": False,
    "axes.linewidth": 1.0,
    "xtick.direction": "out", "ytick.direction": "out",
    "font.size": 11,
})

# ───────── PER-FLY LATENCIES — save into that fly's trained-odor folder ─────────
# One PNG per fly: a bar per trial of interest, or a single "NR" panel when the
# fly never responds within LATENCY_CEILING_S.
for fly in sorted(df["fly"].unique()):
    df_fly = df[df["fly"] == fly]
    if df_fly.empty:
        continue
    # assume single trained dataset per fly
    odor_folder = _safe_dirname(_canon_dataset(df_fly["dataset"].iloc[0]))
    odir = OUT_DIR / odor_folder
    odir.mkdir(parents=True, exist_ok=True)

    # lat_df has no columns when no trial matched TRIALS_OF_INTEREST, so only
    # index into it when it actually holds records.
    sub = lat_df[lat_df["fly"] == fly] if not lat_df.empty else None
    labels = [f"Training {n}" for n in TRIALS_OF_INTEREST]

    # Collect per trial (first matching record; None when the trial is absent)
    latencies = []
    for n in TRIALS_OF_INTEREST:
        s = sub[sub["trial_num"] == n]["latency"] if sub is not None else []
        latencies.append(s.iloc[0] if len(s) else None)

    # A missing latency can arrive as None or as NaN: pandas stores None as NaN
    # once the column also holds floats. pd.isna covers both, so a
    # non-responding trial is always classified as NR.
    no_response = [bool(pd.isna(lat)) or lat > LATENCY_CEILING_S for lat in latencies]

    # If this fly NEVER responds → NR panel
    if all(no_response):
        fig, ax = plt.subplots(figsize=(6.5, 3.2))
        ax.set_title(TITLE_FLY.format(fly), pad=10, fontsize=14, weight="bold")
        ax.set_xticks(np.arange(len(labels))); ax.set_xticklabels(labels)
        ax.set_ylim(0, LATENCY_CEILING_S + 2.0)
        ax.text(0.5, 0.55, "NR", transform=ax.transAxes, ha="center", va="center",
                fontsize=18, color="#666666", weight="bold")
        ax.set_ylabel("Time After Odor Sent(s)")
        ax.axhline(LATENCY_CEILING_S, linestyle="--", linewidth=1.1, color="#444444")
        # x in axes fraction, y in data units: pins the note to the right edge at the ceiling line
        trans = mtransforms.blended_transform_factory(ax.transAxes, ax.transData)
        ax.text(0.995, LATENCY_CEILING_S + 0.12, f"NR if > {LATENCY_CEILING_S:.1f} s",
                transform=trans, ha="right", va="bottom", fontsize=10, color="#444444", clip_on=False)
        fig.tight_layout()
        out_png = odir / f"{fly}_training_{'_'.join(map(str,TRIALS_OF_INTEREST))}_latency.png"
        fig.savefig(out_png); plt.close(fig)
        print(f"[OK] saved {out_png} (NR panel)")
        continue

    # Otherwise draw bars per trial: grey bar at the ceiling for NR, black bar at the latency otherwise
    bar_vals, annots, colors = [], [], []
    for lat, is_nr in zip(latencies, no_response):
        if is_nr:
            bar_vals.append(LATENCY_CEILING_S); annots.append("NR"); colors.append("#BDBDBD")
        else:
            bar_vals.append(lat); annots.append(f"{lat:.2f}s"); colors.append("#1A1A1A")

    fig, ax = plt.subplots(figsize=(6.5, 3.6))
    x = np.arange(len(labels))
    bars = ax.bar(x, bar_vals, width=0.6, color=colors, edgecolor="black", linewidth=1.0)

    for b, txt in zip(bars, annots):
        # centre the label in the bar, but never below 0.35 so short bars stay readable
        ytxt = float(max(b.get_height() * 0.5, 0.35))
        ax.text(b.get_x() + b.get_width()/2, ytxt, txt,
                ha="center", va="center", fontsize=10,
                color=("white" if txt != "NR" else "#444444"))

    ax.set_xticks(x); ax.set_xticklabels(labels)
    ax.set_ylabel("Time After Odor Sent (s)")
    ax.set_ylim(0, LATENCY_CEILING_S + 2.5)

    ax.axhline(LATENCY_CEILING_S, linestyle="--", linewidth=1.1, color="#444444")
    trans = mtransforms.blended_transform_factory(ax.transAxes, ax.transData)
    ax.text(0.995, LATENCY_CEILING_S + 0.12, f"NR if > {LATENCY_CEILING_S:.1f} s",
            transform=trans, ha="right", va="bottom", fontsize=10, color="#444444", clip_on=False)

    ax.set_title(TITLE_FLY.format(fly), pad=10, fontsize=14, weight="bold")

    fig.tight_layout()
    out_png = odir / f"{fly}_training_{'_'.join(map(str,TRIALS_OF_INTEREST))}_latency.png"
    fig.savefig(out_png); plt.close(fig)
    print(f"[OK] saved {out_png}")

# ───────── PER-TRAINED-ODOR (dataset) MEAN LATENCIES + UPWARD SEM — into each odor folder ─────────
# One PNG per canonical odor: mean latency per trial of interest with an upward SEM bar.
# The emptiness check is needed because an empty lat_df has no "dataset_canon" column.
for odor in sorted(lat_df["dataset_canon"].unique() if not lat_df.empty else []):
    sub = lat_df[lat_df["dataset_canon"] == odor]
    labels = [f"Training {n}" for n in TRIALS_OF_INTEREST]
    odir = OUT_DIR / _safe_dirname(odor)
    odir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): odor is drawn from lat_df itself, so sub should never be empty here;
    # this branch looks unreachable — confirm before removing.
    if sub.empty:
        fig, ax = plt.subplots(figsize=(6.8, 3.2))
        ax.set_title(TITLE_ODOR.format(odor), pad=10, fontsize=14, weight="bold")
        ax.set_xticks(np.arange(len(labels))); ax.set_xticklabels(labels)
        ax.set_ylim(0, LATENCY_CEILING_S + 2.0)
        ax.text(0.5, 0.55, "NR", transform=ax.transAxes, ha="center", va="center",
                fontsize=18, color="#666666", weight="bold")
        ax.set_ylabel("Time After Odor Sent(s)")
        out_png = odir / f"{odor}_training_{'_'.join(map(str,TRIALS_OF_INTEREST))}_mean_latency.png"
        fig.tight_layout(); fig.savefig(out_png); plt.close(fig)
        print(f"[OK] saved {out_png} (NR panel)")
        continue

    # Per trial: mean, SEM and responder count over the finite "lat_for_mean" values.
    means, sems, ns = [], [], []
    for n in TRIALS_OF_INTEREST:
        s = sub[sub["trial_num"] == n]["lat_for_mean"].to_numpy(dtype=float)
        finite = s[np.isfinite(s)]
        n_resp = int(finite.size)
        mu = float(finite.mean()) if n_resp > 0 else np.nan
        # Sample SD (ddof=1); defined as 0 for a single responder and NaN for none.
        sd = float(finite.std(ddof=1)) if n_resp > 1 else (0.0 if n_resp == 1 else np.nan)
        sem = (sd / np.sqrt(n_resp)) if n_resp > 1 else (0.0 if n_resp == 1 else np.nan)
        means.append(mu); sems.append(sem); ns.append(n_resp)

    # No responder in any trial → single "NR" panel instead of bars.
    if sum(ns) == 0:
        fig, ax = plt.subplots(figsize=(6.8, 3.2))
        ax.set_title(TITLE_ODOR.format(odor), pad=10, fontsize=14, weight="bold")
        ax.set_xticks(np.arange(len(labels))); ax.set_xticklabels(labels)
        ax.set_ylim(0, LATENCY_CEILING_S + 2.0)
        ax.text(0.5, 0.55, "NR", transform=ax.transAxes, ha="center", va="center",
                fontsize=18, color="#666666", weight="bold")
        ax.set_ylabel("Time After Odor Sent(s)")
        out_png = odir / f"{odor}_training_{'_'.join(map(str,TRIALS_OF_INTEREST))}_mean_latency.png"
        fig.tight_layout(); fig.savefig(out_png); plt.close(fig)
        print(f"[OK] saved {out_png} (NR panel)")
        continue

    # Upward-only SEM bars
    # NaN means/SEMs (trials without responders) are drawn as zero-height bars.
    y        = np.nan_to_num(np.array(means, dtype=float), nan=0.0)
    yerr_up  = np.nan_to_num(np.array(sems,  dtype=float), nan=0.0)
    yerr     = np.vstack([np.zeros_like(yerr_up), yerr_up])  # lower=0, upper=SEM

    fig, ax = plt.subplots(figsize=(6.8, 3.8))
    x = np.arange(len(labels))
    bars = ax.bar(x, y, width=0.6, color="#1A1A1A", edgecolor="black", linewidth=1.0)
    ax.errorbar(x, y, yerr=yerr, fmt="none", ecolor="black", elinewidth=1.2, capsize=4)

    # Annotate each bar: mean (white, inside) + SEM and n above
    for xi, b in enumerate(bars):
        n_resp  = ns[xi]
        mu_val  = y[xi]
        sem_val = yerr_up[xi]
        # y was passed through nan_to_num, so in practice only n_resp == 0 triggers this.
        if n_resp == 0 or not np.isfinite(mu_val):
            label_y = max(0.5, mu_val + 0.08)
            ax.text(b.get_x() + b.get_width()/2, label_y, "NR",
                    ha="center", va="bottom", fontsize=9, color="#444444")
            continue
        # Place the mean label inside the bar, between 50% and 90% of its height.
        inner_y = max(mu_val * 0.50, min(mu_val - 0.10, mu_val * 0.90))
        ax.text(b.get_x() + b.get_width()/2, inner_y, f"{mu_val:.2f}s",
                ha="center", va="top", fontsize=10, color="white")
        topper_y = float(mu_val + sem_val + 0.06) if np.isfinite(sem_val) else float(mu_val + 0.06)
        ax.text(b.get_x() + b.get_width()/2, topper_y, f"SEM={sem_val:.2f}s\nn={n_resp}",
                ha="center", va="bottom", fontsize=9, color="#333333")

    ax.set_xticks(x); ax.set_xticklabels(labels)
    ax.set_ylabel("Time After Odor Sent(s)")
    # Leave headroom above the tallest bar + SEM for the two-line annotation.
    ymax = max(LATENCY_CEILING_S + 2.0, float((y + yerr_up).max()) + 1.2)
    ax.set_ylim(0, ymax)

    ax.axhline(LATENCY_CEILING_S, linestyle="--", linewidth=1.0, color="#6f6f6f")
    # x in axes fraction, y in data units: pins the note to the right edge at the ceiling line.
    trans = mtransforms.blended_transform_factory(ax.transAxes, ax.transData)
    ax.text(0.995, LATENCY_CEILING_S + 0.08, f"NR if > {LATENCY_CEILING_S:.1f} s",
            transform=trans, ha="right", va="bottom", fontsize=9, color="#6f6f6f")

    ax.set_title(TITLE_ODOR.format(odor), pad=10, fontsize=14, weight="bold")

    fig.tight_layout()
    # NOTE(review): the file name uses the raw odor string while the folder uses _safe_dirname(odor);
    # an odor containing a path separator would break this path.
    out_png = odir / f"{odor}_training_{'_'.join(map(str,TRIALS_OF_INTEREST))}_mean_latency.png"
    fig.savefig(out_png); plt.close(fig)
    print(f"[OK] saved {out_png}")

# ───────── GRAND MEAN ACROSS TRIALS BY ODOR (one bar per odor; SEM shown) ─────────
# Pool every trial of interest per canonical odor: one bar per odor with an
# upward SEM bar, plus a CSV summary in OUT_DIR.
odors_all = sorted(df["dataset_canon"].unique())
summary_rows = []
grand_means, grand_sems, grand_ns = [], [], []

for odor in odors_all:
    # An empty lat_df has no columns, so column lookups on it would raise
    # KeyError; treat that case as "no latencies for this odor".
    if lat_df.empty:
        s = np.array([], dtype=float)
    else:
        s = lat_df[lat_df["dataset_canon"] == odor]["lat_for_mean"].to_numpy(dtype=float)
    finite = s[np.isfinite(s)]
    n_resp = int(finite.size)
    mu = float(finite.mean()) if n_resp > 0 else np.nan
    # Sample SD (ddof=1); defined as 0 for a single responder and NaN for none.
    sd = float(finite.std(ddof=1)) if n_resp > 1 else (0.0 if n_resp == 1 else np.nan)
    sem = (sd / np.sqrt(n_resp)) if n_resp > 1 else (0.0 if n_resp == 1 else np.nan)

    grand_means.append(mu)
    grand_sems.append(sem)
    grand_ns.append(n_resp)

    summary_rows.append({
        "odor": odor, "n_resp": n_resp,
        "mean_s": mu if np.isfinite(mu) else np.nan,
        "sem_s": sem if np.isfinite(sem) else np.nan
    })

# Save CSV summary (root)
pd.DataFrame(summary_rows).to_csv(OUT_DIR / "grand_mean_by_odor_latency.csv", index=False)

if sum(grand_ns) == 0:
    # No responder for any odor → single "NR" panel.
    fig, ax = plt.subplots(figsize=(7.2, 3.2))
    ax.set_title(TITLE_ODOR_GRAND, pad=10, fontsize=14, weight="bold")
    ax.set_xticks(np.arange(len(odors_all))); ax.set_xticklabels(odors_all, rotation=0)
    ax.set_ylim(0, LATENCY_CEILING_S + 2.0)
    ax.text(0.5, 0.55, "NR", transform=ax.transAxes, ha="center", va="center",
            fontsize=18, color="#666666", weight="bold")
    ax.set_ylabel("Time After Odor Sent(s)")
    ax.axhline(LATENCY_CEILING_S, linestyle="--", linewidth=1.0, color="#6f6f6f")
    out_png = OUT_DIR / "grand_mean_by_odor_latency.png"
    fig.tight_layout(); fig.savefig(out_png); plt.close(fig)
    print(f"[OK] saved {out_png} (NR panel)")
else:
    # NaN means/SEMs (odors without responders) are drawn as zero-height bars.
    y       = np.nan_to_num(np.array(grand_means, dtype=float), nan=0.0)
    yerr_up = np.nan_to_num(np.array(grand_sems,  dtype=float),  nan=0.0)
    yerr    = np.vstack([np.zeros_like(yerr_up), yerr_up])  # lower=0, upper=SEM

    # Widen the figure with the number of odors so tick labels do not collide.
    fig, ax = plt.subplots(figsize=(max(7.2, 1.8*len(odors_all)), 3.8))
    x = np.arange(len(odors_all))
    bars = ax.bar(x, y, width=0.6, color="#1A1A1A", edgecolor="black", linewidth=1.0)
    ax.errorbar(x, y, yerr=yerr, fmt="none", ecolor="black", elinewidth=1.2, capsize=4)

    # Annotate SEM (or NR) per odor
    for xi, b in enumerate(bars):
        n_resp = grand_ns[xi]
        sem_val = yerr_up[xi]
        if n_resp == 0:
            label_y = max(0.5, y[xi] + 0.08)
            ax.text(b.get_x() + b.get_width()/2, label_y, "NR",
                    ha="center", va="bottom", fontsize=9, color="#444444")
        else:
            label_y = float(y[xi] + sem_val + 0.06) if np.isfinite(sem_val) else float(y[xi] + 0.06)
            ax.text(b.get_x() + b.get_width()/2, label_y, f"SEM={sem_val:.2f} s\nn={n_resp}",
                    ha="center", va="bottom", fontsize=9, color="#333333")

    ax.set_xticks(x); ax.set_xticklabels(odors_all, rotation=0)
    ax.set_ylabel("Time After Odor Sent(s)")
    # Leave headroom above the tallest bar + SEM for the two-line annotation.
    ymax = max(LATENCY_CEILING_S + 2.0, float((y + yerr_up).max()) + 1.2)
    ax.set_ylim(0, ymax)

    ax.axhline(LATENCY_CEILING_S, linestyle="--", linewidth=1.0, color="#6f6f6f")
    ax.set_title(TITLE_ODOR_GRAND, pad=10, fontsize=14, weight="bold")

    out_png = OUT_DIR / "grand_mean_by_odor_latency.png"
    fig.tight_layout(); fig.savefig(out_png); plt.close(fig)
    print(f"[OK] saved {out_png}")

print("[DONE] Per-fly, per-trained-odor, and grand-mean-by-odor latency plots exported.")

# Combined Angle / RMS

In [None]:
# JUPYTER CELL — Combine centered angle % with distance %, then RMS + Hilbert envelope
from pathlib import Path
import glob, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import hilbert

# ───────── prerequisites ─────────
# This cell relies on main_directory having been defined by an earlier cell.
assert 'main_directory' in globals(), "Define main_directory = '/path/to/root' before running."
ROOT = Path(main_directory).expanduser().resolve()
assert ROOT.is_dir(), f"Not a directory: {ROOT}"

# ───────── config ─────────
FPS_DEFAULT     = 40.0   # used to build a time axis when a CSV has no time column
WINDOW_SEC      = 0.25   # smoothing window length in seconds
WINDOW_FRAMES   = max(int(WINDOW_SEC * FPS_DEFAULT), 1)  # same window in frames, at least 1

# Raw per-trial CSV search patterns (supports one or two patterns)
IN_SUFFIX_ANG  = ("*merged.csv", "*class_2_6.csv")   # or just ("*merged.csv",)
IN_SUFFIX_DIST = ("*merged.csv", "*class_2_6.csv")   # or just ("*merged.csv",)

# Lower-case month names; kept as a tuple because it is passed to str.startswith.
MONTHS = ("january","february","march","april","may","june",
          "july","august","september","october","november","december")

# Odor timing (seconds)
# Drawn as vertical marker lines on each envelope plot.
ODOR_ON_S  = 30.0
ODOR_OFF_S = 60.0

# Column candidates
# Each list is searched in order; the first name present in a CSV is used.
TIME_CANDS   = ["time_s", "time_seconds", "t_s", "time"]
ANGLE_COLS   = ["angle_centered_pct", "angle_centered_percentage", "angle_pct"]
DIST_COLS    = ["distance_percentage_2_6", "distance_percentage", "distance_pct", "measure", "value"]

# ───────── utils ─────────
# Matches "testing_<digits>" anywhere in a file stem, case-insensitively.
TESTING_REGEX = re.compile(r"testing_(\d+)", re.IGNORECASE)

def is_month_folder(p: Path) -> bool:
    """Return True if *p* is a directory whose name starts with a month name.

    The comparison is case-insensitive against the module-level ``MONTHS``.
    """
    # str.startswith accepts a str or a tuple only. Another cell in this
    # notebook rebinds MONTHS to a list, so coerce here instead of relying on
    # which cell ran last.
    return p.is_dir() and p.name.lower().startswith(tuple(MONTHS))

def infer_category_from_path(p: Path) -> str | None:
    parts = [s.lower() for s in p.parts]
    if "testing" in parts: return "testing"
    if "training" in parts: return "training"
    name = p.name.lower()
    if "testing" in name: return "testing"
    if "training" in name: return "training"
    return None

def _pick_col(df: pd.DataFrame, cands: list[str]) -> str | None:
    for c in cands:
        if c in df.columns:
            return c
    return None

def _time_axis(df: pd.DataFrame) -> np.ndarray:
    """Return a time axis for *df* as a NumPy array.

    Preference order: the first TIME_CANDS column present (values used as
    given, non-numeric entries become NaN); else the "frame" column divided
    by FPS_DEFAULT; else the row position divided by FPS_DEFAULT.
    """
    time_col = _pick_col(df, TIME_CANDS)
    if time_col:
        return pd.to_numeric(df[time_col], errors='coerce').to_numpy()
    if "frame" not in df.columns:
        return np.arange(len(df)) / FPS_DEFAULT
    frames = pd.to_numeric(df["frame"], errors='coerce').to_numpy()
    return frames / FPS_DEFAULT

def _rolling_rms(x: np.ndarray, win_frames: int) -> np.ndarray:
    s = pd.Series(pd.to_numeric(x, errors='coerce')).fillna(0.0)
    return (s.pow(2).rolling(window=win_frames, center=True, min_periods=1).mean().pow(0.5)).to_numpy()

def _hilbert_envelope(x: np.ndarray, win_frames: int) -> np.ndarray:
    env = np.abs(hilbert(np.nan_to_num(x, nan=0.0)))
    # light smoothing to match your previous style
    return pd.Series(env).rolling(window=win_frames, center=True, min_periods=1).mean().to_numpy()

def angle_multiplier(angle_pct: np.ndarray) -> np.ndarray:
    """Map centred angle percentages to a stepwise multiplier.

    Input is clipped to [-100, 100]. Bands and their multipliers:
    below -40 → 0.25, [-40, -25) → 0.50, [-25, -10) → 0.75, [-10, 10] → 1.00,
    (10, 25] → 1.25, (25, 40] → 1.50, (40, 60] → 1.75, (60, 100] → 2.00.
    NaN input yields NaN.
    """
    ap = np.clip(np.asarray(angle_pct, dtype=float), -100.0, 100.0)
    # np.select takes the first condition that holds, so each band only needs
    # its upper bound; NaN fails every comparison and falls to the default.
    upper_bounds = [
        (ap < -40, 0.25),
        (ap < -25, 0.50),
        (ap < -10, 0.75),
        (ap <= 10, 1.00),
        (ap <= 25, 1.25),
        (ap <= 40, 1.50),
        (ap <= 60, 1.75),
        (ap <= 100, 2.00),
    ]
    conditions = [cond for cond, _ in upper_bounds]
    multipliers = [mult for _, mult in upper_bounds]
    return np.select(conditions, multipliers, default=np.nan)

# ───────── discovery for angle & distance trials (raw per-trial CSVs) ─────────
from collections.abc import Iterable

def _locate_trials_with_cols(fly_dir: Path, suffix_globs: Iterable[str] | str, required_cols: list[str]):
    """
    Return [(label, path, category)] for CSVs containing any of required_cols, within month subfolders.
    Accepts one or many suffix patterns (e.g., "*merged.csv", "*combined.csv").

    label is the file stem. category comes from infer_category_from_path and
    defaults to "testing" when the path gives no hint. Files are returned in
    sorted path order without duplicates. A CSV that cannot be read is
    reported with a [WARN] line and skipped.
    """
    # Normalize to iterable
    if isinstance(suffix_globs, str):
        suffixes = [suffix_globs]
    else:
        suffixes = list(suffix_globs)

    # Month folders may sit at any depth below the fly directory.
    month_folders = [sub for sub in fly_dir.rglob("*") if is_month_folder(sub)]
    csvs = []
    for month_folder in month_folders:
        for suff in suffixes:
            csvs.extend(Path(p) for p in glob.iglob(str(month_folder / "**" / suff), recursive=True))

    out = []
    # set() removes files reached through more than one month folder or pattern.
    for p in sorted(set(csvs)):
        try:
            # A few rows are enough: only the column names are inspected.
            df = pd.read_csv(p, nrows=5)
        except Exception as e:
            # Best-effort discovery: report the unreadable file and move on.
            print(f"[WARN] Could not read {p}: {e}")
            continue
        if _pick_col(df, required_cols):
            cat = infer_category_from_path(p) or "testing"
            out.append((p.stem, p, cat))
    return out

def _index_testing_by_id(entries):
    """Split testing entries into two lookup dicts.

    *entries* is an iterable of (label, path, category). Only entries whose
    category is "testing" are used. Returns ``(by_id, by_stem)``: ``by_id``
    maps the lower-cased ``testing_<#>`` found in the label to its path, and
    ``by_stem`` maps the lower-cased label to its path for labels without such
    an id. When two entries share a key, the later one wins.
    """
    by_id, by_stem = {}, {}
    for label, path, cat in entries:
        if cat == "testing":
            hit = TESTING_REGEX.search(label)
            if hit is None:
                by_stem[label.lower()] = path
            else:
                by_id[hit.group(0).lower()] = path
    return by_id, by_stem

# ───────── core compute+plot ─────────
def compute_and_plot_for_fly(fly_dir: Path):
    """Compute and save the angle×distance RMS envelope for each testing trial of one fly.

    For every ``testing_<#>`` distance CSV found under *fly_dir*, the matching
    angle CSV is loaded, the angle is interpolated onto the distance time
    base, and ``distance_pct * angle_multiplier(angle)`` is passed through a
    rolling RMS and a smoothed Hilbert envelope. One CSV is written per trial
    to ``<fly_dir>/angle_distance_rms_envelope`` and one PNG to its ``plots``
    subfolder. A trial that fails is reported with a [WARN] line and counted
    as skipped; nothing is raised and nothing is returned.
    """
    fly_name = fly_dir.name

    # Find raw testing files that have angle and distance signals
    angle_entries    = _locate_trials_with_cols(fly_dir, IN_SUFFIX_ANG, ANGLE_COLS)
    distance_entries = _locate_trials_with_cols(fly_dir, IN_SUFFIX_DIST, DIST_COLS)

    if not distance_entries:
        print(f"[{fly_name}] No testing distance trials found — skipping.")
        return

    ang_idx, ang_fallback = _index_testing_by_id(angle_entries)
    # Only id-indexed distance trials are processed; the stem-keyed fallback is not used.
    dist_idx, _ = _index_testing_by_id(distance_entries)

    out_dir_csv  = fly_dir / "angle_distance_rms_envelope"
    out_dir_figs = out_dir_csv / "plots"
    out_dir_csv.mkdir(parents=True, exist_ok=True)
    out_dir_figs.mkdir(parents=True, exist_ok=True)

    n_done, n_skip = 0, 0

    # Anchor iteration on distance trials (time base)
    for test_id, dist_path in sorted(dist_idx.items()):
        # Match angle file to the same testing_<#>
        angle_path = ang_idx.get(test_id)
        if angle_path is None:
            # try fallback: identical stem
            angle_path = ang_fallback.get(Path(dist_path).stem.lower())
        if angle_path is None:
            print(f"[WARN] {fly_name} {test_id}: no matching angle file — skipped.")
            n_skip += 1
            continue

        try:
            # Load distance
            df_d = pd.read_csv(dist_path)
            t_d = _time_axis(df_d)
            dist_col = _pick_col(df_d, DIST_COLS)
            if not dist_col:
                raise ValueError("No distance_% column in distance trial file.")
            dist_pct = pd.to_numeric(df_d[dist_col], errors='coerce').fillna(0.0).clip(lower=0, upper=100).to_numpy()

            # Load angle and align to distance time base
            df_a = pd.read_csv(angle_path)
            angle_col = _pick_col(df_a, ANGLE_COLS)
            if not angle_col:
                raise ValueError("No angle_centered_% column in angle trial file.")
            t_a = _time_axis(df_a)
            ang_pct = pd.to_numeric(df_a[angle_col], errors='coerce').to_numpy()

            # Interpolate angle onto distance time base
            # (np.interp needs increasing x, so sort and drop non-finite points first)
            order = np.argsort(t_a)
            t_a_sorted = t_a[order]
            ang_sorted = ang_pct[order]
            good = np.isfinite(t_a_sorted) & np.isfinite(ang_sorted)
            if not np.any(good):
                raise ValueError("Angle series has no finite values for interpolation.")
            t_a_g = t_a_sorted[good]
            ang_g = ang_sorted[good]
            ang_on_dist_t = np.interp(t_d, t_a_g, ang_g, left=ang_g[0], right=ang_g[-1])

            # Combine angle % and distance %:
            #   combined_base = distance_pct * angle_multiplier(angle_pct_interp)
            mult = angle_multiplier(ang_on_dist_t)
            combined_base = dist_pct * mult

            # Rolling RMS of the combined signal
            combined_rms = _rolling_rms(combined_base, WINDOW_FRAMES)

            # Hilbert envelope of the RMS (smoothed)
            envelope_rms = _hilbert_envelope(combined_rms, WINDOW_FRAMES)

            # Persist
            out_df = pd.DataFrame({
                "time_s": t_d,
                "angle_centered_pct_interp": ang_on_dist_t,
                "distance_percentage": dist_pct,
                "multiplier": mult,
                "combined_base": combined_base,
                "rolling_rms": combined_rms,
                "envelope_of_rms": envelope_rms
            })
            out_csv = out_dir_csv / f"{test_id}_angle_distance_rms_envelope.csv"
            out_df.to_csv(out_csv, index=False)

            # Plot envelope (primary output) with odor markers
            fig = plt.figure(figsize=(12, 4))
            try:
                plt.plot(t_d, envelope_rms, linewidth=1.5)
                plt.axvline(ODOR_ON_S,  color='red', linewidth=2)
                plt.axvline(ODOR_OFF_S, color='red', linewidth=2)
                plt.title(f"{fly_name} — {test_id}: Envelope( RMS( distance × angle-mult ) )")
                plt.xlabel("Time (s)")
                plt.ylabel("Envelope of RMS (arb.)")
                plt.margins(x=0)
                plt.grid(True, alpha=0.3)
                out_png = out_dir_figs / f"{fly_name}_{test_id}_env_rms_angle_distance.png"
                plt.savefig(out_png, bbox_inches='tight', dpi=300)
            finally:
                # Close even when plotting or saving fails, so a bad trial
                # does not leave an open figure behind.
                plt.close(fig)

            print(f"[OK] {fly_name} {test_id} → CSV: {out_csv.name} | FIG: {out_png.name}")
            n_done += 1

        except Exception as e:
            # Per-trial best effort: report and continue with the next trial.
            print(f"[WARN] {fly_name} {test_id} → {e}")
            n_skip += 1

    print(f"[{fly_name}] completed: {n_done}, skipped: {n_skip}")

def run_all(root: Path = ROOT):
    """Run compute_and_plot_for_fly on every immediate subdirectory of *root*, in sorted order."""
    fly_dirs = sorted(entry for entry in root.iterdir() if entry.is_dir())
    for fly_dir in fly_dirs:
        compute_and_plot_for_fly(fly_dir)

# ───────── execute ─────────
run_all()

# Clean Up

In [None]:
import shutil
from pathlib import Path

# ───────── INPUT: source directories ─────────
# Source dataset folders; each is mirrored under DEST_ROOT and then pruned in Step 2.
ROOTS = [
    Path("/home/ramanlab/Documents/cole/Data/flys/10s_Odor_Benz/"),
    Path("/home/ramanlab/Documents/cole/Data/flys/opto_benz/"),
    Path("/home/ramanlab/Documents/cole/Data/flys/opto_EB/"),
    Path("/home/ramanlab/Documents/cole/Data/flys/opto_benz_1/"),

]

# ───────── Destination directory ─────────
DEST_ROOT = Path("/securedstorage/DATAsec/cole/Data-secured/")

# Valid month prefixes
# (lower-case; compared against lower-cased folder names in Step 2)
MONTHS = [
    "january", "february", "march", "april", "may", "june",
    "july", "august", "september", "october", "november", "december"
]

# ───────── Step 1: Copy all data to secured storage ─────────
for root in ROOTS:
    if not root.exists():
        print(f"Source folder missing: {root}")
        continue

    # Each source root is mirrored into DEST_ROOT/<root folder name>.
    dest_path = DEST_ROOT / root.name
    print(f"\nCopying from {root} → {dest_path}")
    dest_path.mkdir(parents=True, exist_ok=True)

    for item in root.rglob("*"):
        relative_path = item.relative_to(root)
        target = dest_path / relative_path

        if item.is_dir():
            target.mkdir(parents=True, exist_ok=True)
        else:
            # NOTE(review): an existing target is never compared with the source, so a
            # partial file left by an interrupted run is kept as-is — and Step 2 then
            # deletes the source. Consider checking size or checksum first.
            if target.exists():
                print(f"Skipping (already exists): {target}")
                continue
            # copy2 also copies file metadata such as modification times.
            shutil.copy2(item, target)
            print(f"Copied: {target}")

print("\nCopy phase completed successfully.")

# ───────── Step 2: Clean up source folders ─────────
# Prune each source root: non-month folders are removed entirely; inside month
# folders only "RMS_calculations" and top-level CSV files are kept.
# NOTE(review): this deletes source data without checking that the Step 1 copy
# of each item exists and is complete — confirm that is acceptable.
for root in ROOTS:
    # A root that Step 1 reported as missing has nothing to clean, and
    # iterating it would raise FileNotFoundError.
    if not root.is_dir():
        print(f"Source folder missing, nothing to clean: {root}")
        continue

    print(f"\nCleaning up {root}...")

    for fly_folder in root.iterdir():
        if not fly_folder.is_dir():
            continue  # Skip files directly under ACV/Benz/etc.

        # ---- Rule 4: Delete folder if it doesn't start with a valid month ----
        folder_name_lower = fly_folder.name.lower()
        if not any(folder_name_lower.startswith(month) for month in MONTHS):
            shutil.rmtree(fly_folder)
            print(f"Deleted non-month folder: {fly_folder}")
            continue  # Skip further cleanup for this folder since it's gone

        # ---- Rule 2 & 3: Inside a valid month fly folder ----
        for item in fly_folder.iterdir():
            # Preserve RMS_calculations folder
            if item.name == "RMS_calculations":
                print(f"Preserving folder: {item}")
                continue

            # Preserve any CSV files directly inside the fly folder
            if item.is_file() and item.suffix.lower() == ".csv":
                print(f"Preserving CSV file: {item}")
                continue

            # Delete other files
            if item.is_file():
                item.unlink()
                print(f"Deleted file: {item}")

            # Delete directories other than RMS_calculations
            elif item.is_dir():
                shutil.rmtree(item)
                print(f"Deleted folder: {item}")

print("\nCleanup completed successfully.")

# Single CSV / Matrix:

In [None]:
# JUPYTER CELL — Wide CSV of per-frame DIRECTION VALUE across MANY main_directories
from pathlib import Path
import re
import numpy as np
import pandas as pd
from typing import Optional

# ───────── INPUT: add as many roots as you need ─────────
# Dataset roots; each root's folder name becomes the "dataset" value of its rows.
ROOTS = [
    Path("/securedstorage/DATAsec/cole/Data-secured/opto_EB/"),
    Path("/securedstorage/DATAsec/cole/Data-secured/opto_benz/"),
    Path("/securedstorage/DATAsec/cole/Data-secured/opto_benz_1/"),
]

# ───────── CONFIG (measure aggregation) ─────────
# Candidate per-frame measure columns, in priority order. Other comments and messages
# in this cell say "direction_value", but the column actually read is the one listed here.
MEASURE_COLS  = ["envelope_of_rms"]
FPS_DEFAULT   = 40  # not referenced in this cell; FALLBACK_FPS is what the code uses

# Output file (combined for all datasets)
OUT_WIDE_CSV = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/all_direction_values_rows_wide.csv")

# Matches "testing_<digits>" or "training_<digits>", case-insensitively.
TRIAL_REGEX = re.compile(r"(testing|training)_(\d+)", re.IGNORECASE)

# Timestamp + frame columns to look for (to estimate FPS for metadata)
TIMESTAMP_CANDIDATES = ["UTC_ISO", "Timestamp", "Number", "MonoNs"]
FRAME_CANDIDATES     = ["Frame", "FrameNumber", "Frame Number"]
FALLBACK_FPS = 40  # written to the "fps" column when FPS cannot be estimated

def _pick_timestamp_column(df: pd.DataFrame) -> Optional[str]:
    """Return the first TIMESTAMP_CANDIDATES name that is a column of *df*, else None."""
    return next((name for name in TIMESTAMP_CANDIDATES if name in df.columns), None)

def _pick_frame_column(df: pd.DataFrame) -> Optional[str]:
    """Return the first FRAME_CANDIDATES name that is a column of *df*, else None."""
    return next((name for name in FRAME_CANDIDATES if name in df.columns), None)

def _to_seconds_series(df: pd.DataFrame, ts_col: str) -> pd.Series:
    s = df[ts_col]
    if ts_col in ("UTC_ISO", "Timestamp"):
        dt = pd.to_datetime(s, errors="coerce", utc=(ts_col == "UTC_ISO"))
        secs = dt.astype("int64") / 1e9  # NaT -> NaN
        t0 = np.nanmin(secs.values)
        return (secs - t0).astype(float)
    if ts_col == "Number":
        vals = pd.to_numeric(s, errors="coerce").astype(float)
        t0 = np.nanmin(vals.values)
        return vals - t0
    if ts_col == "MonoNs":
        vals = pd.to_numeric(s, errors="coerce").astype(float)
        secs = vals / 1e9
        t0 = np.nanmin(secs.values)
        return secs - t0
    raise ValueError(f"Unsupported timestamp column: {ts_col}")

def _estimate_fps_from_seconds(seconds_series: pd.Series) -> Optional[float]:
    mask = seconds_series.notna()
    if mask.sum() < 2:
        return None
    duration = seconds_series[mask].iloc[-1] - seconds_series[mask].iloc[0]
    if duration <= 0:
        return None
    return mask.sum() / duration

def _resolve_measure_column(df: pd.DataFrame) -> str | None:
    """Return the first MEASURE_COLS name that is a column of *df*, else None."""
    for candidate in MEASURE_COLS:
        if candidate in df.columns:
            return candidate
    return None

def _infer_trial_type(p: Path) -> str:
    s = (p.stem + "/" + "/".join(q.name for q in p.parents)).lower()
    if "testing" in s:  return "testing"
    if "training" in s: return "training"
    return "unknown"

def _trial_label(p: Path) -> str:
    """Build a trial label such as "testing_3" for the CSV at *p*.

    Resolution order:
      1. ``(testing|training)_<digits>`` in the file stem;
      2. the same pattern in the stem plus parent folder names;
      3. trailing digits of the stem, prefixed with the inferred trial type;
      4. the bare stem.
    """
    stem = p.stem
    hit = TRIAL_REGEX.search(stem)
    if hit is None:
        ancestry = (stem + "/" + "/".join(q.name for q in p.parents)).lower()
        hit = TRIAL_REGEX.search(ancestry)
    if hit is not None:
        return f"{hit.group(1).lower()}_{hit.group(2)}"

    trailing_digits = re.search(r"(\d+)$", stem)
    if trailing_digits is None:
        return stem
    return f"{_infer_trial_type(p)}_{trailing_digits.group(1)}"

def _find_trial_csvs(fly_dir: Path):
    search_root = fly_dir / "envelope_of_rms"
    if not search_root.is_dir():
        search_root = fly_dir
    patterns = ["**/*testing*.csv", "**/*training*.csv"]
    seen = set()
    for pat in patterns:
        for csv in search_root.glob(pat):
            if csv.is_file():
                rp = csv.resolve()
                if rp not in seen:
                    seen.add(rp)
                    yield rp

# ───────── PASS 1: discover items and determine max row length ─────────
# Walk every dataset root and fly folder, keep the CSVs that carry one of
# MEASURE_COLS, and record the longest trial so every output row can be
# padded to the same width.
items = []   # [{dataset, fly, csv_path, trial_type, trial_label, measure_col, n_frames}]
max_len = 0

for root in ROOTS:
    root = root.expanduser().resolve()
    assert root.is_dir(), f"Not a directory: {root}"
    dataset = root.name

    for fly_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        fly = fly_dir.name
        for csv_path in _find_trial_csvs(fly_dir):
            # Header only: enough to decide whether the file has a measure column.
            try:
                header_df = pd.read_csv(csv_path, nrows=0)
            except Exception as e:
                print(f"[WARN] Skip {csv_path.name}: header read error: {e}")
                continue
            col = _resolve_measure_column(header_df)
            if col is None:
                print(f"[SKIP] {csv_path.name}: none of {MEASURE_COLS} present.")
                continue
            # Row count of the measure column = number of frames in this trial.
            try:
                n_frames = pd.read_csv(csv_path, usecols=[col]).shape[0]
            except Exception as e:
                print(f"[WARN] Skip {csv_path.name}: count error: {e}")
                continue

            items.append({
                "dataset": dataset,
                "fly": fly,
                "csv_path": csv_path,
                "trial_type": _infer_trial_type(csv_path),
                "trial_label": _trial_label(csv_path),
                "measure_col": col,
                "n_frames": n_frames
            })
            max_len = max(max_len, n_frames)

if not items:
    # Name the columns that were actually searched for.
    raise RuntimeError(f"No eligible testing/training CSVs with any of {MEASURE_COLS} found in provided roots.")

print(f"[INFO] Datasets: {[r.name for r in ROOTS]}")
print(f"[INFO] Discovered {len(items)} videos. Max frames = {max_len}")

# ───────── PASS 2: read direction_value and write combined wide CSV ─────────
# Write one row per trial: five metadata columns followed by max_len measure
# values (NaN-padded). The header is written first; rows are then appended.
cols = ["dataset", "fly", "trial_type", "trial_label", "fps"] + [f"dir_val_{i}" for i in range(max_len)]
# Make sure the output folder exists before the first write.
OUT_WIDE_CSV.parent.mkdir(parents=True, exist_ok=True)
pd.DataFrame(columns=cols).to_csv(OUT_WIDE_CSV, index=False)

for it in items:
    dataset     = it["dataset"]
    fly         = it["fly"]
    csv_path    = it["csv_path"]
    trial_type  = it["trial_type"]
    label       = it["trial_label"]
    measure_col = it["measure_col"]

    # --- Determine FPS from timestamps, if possible ---
    try:
        hdr2 = pd.read_csv(csv_path, nrows=0)
    except Exception:
        hdr2 = pd.DataFrame()

    # Do not gate on hdr2.empty: a header-only read has zero rows, so .empty
    # is always True and FPS inference would never run. The pickers already
    # return None when no candidate column is present.
    frame_col = _pick_frame_column(hdr2)
    ts_col    = _pick_timestamp_column(hdr2)

    fps = np.nan
    if frame_col is not None and ts_col is not None:
        try:
            df_ts = pd.read_csv(csv_path, usecols=[frame_col, ts_col])
            secs  = _to_seconds_series(df_ts, ts_col)
            fps_from_csv = _estimate_fps_from_seconds(secs)
            fps = float(fps_from_csv) if (fps_from_csv and np.isfinite(fps_from_csv) and fps_from_csv > 0) else float(FALLBACK_FPS)
        except Exception as e:
            print(f"[WARN] FPS inference failed for {csv_path.name}: {e}")
            fps = float(FALLBACK_FPS)
    else:
        fps = float(FALLBACK_FPS)
    # --- END FPS BLOCK ---

    # Read the raw measure values
    try:
        df = pd.read_csv(csv_path, usecols=[measure_col])
        vals = pd.to_numeric(df[measure_col], errors="coerce").astype(float).to_numpy()
    except Exception as e:
        print(f"[WARN] Read failed {csv_path}: {e}")
        continue

    # Build output row
    row = [dataset, fly, trial_type, label, fps] + list(vals)

    # pad/truncate to max_len
    if len(vals) < max_len:
        row += [np.nan] * (max_len - len(vals))
    elif len(vals) > max_len:
        # only possible if the file grew since PASS 1
        row = row[:5 + max_len]  # 5 metadata columns

    pd.DataFrame([row], columns=cols).to_csv(OUT_WIDE_CSV, mode="a", header=False, index=False)

print(f"[OK] Wrote combined direction-value table: {OUT_WIDE_CSV}")

In [None]:
# JUPYTER CELL — Convert wide envelope CSV → 16-bit numeric matrix + code key
from pathlib import Path
import numpy as np
import pandas as pd
import json

# ===== INPUT / OUTPUT =====
INPUT_CSV = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/all_direction_values_rows_wide.csv")  # change if your file lives elsewhere
OUT_DIR   = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/")                           # change if desired
OUT_DIR.mkdir(parents=True, exist_ok=True)

MATRIX_NPY = OUT_DIR / "envelope_matrix_float16.npy"   # 16-bit floating matrix
CODE_KEY   = OUT_DIR / "code_key.txt"                  # human-readable mapping & schema
CODES_JSON = OUT_DIR / "code_maps.json"                # machine-readable mappings (optional)

# float16 has an 11-bit significand: integers are exact only up to 2048.
# A larger metadata code would be rounded to a neighbouring code when stored.
F16_MAX_EXACT_INT = 2048

# ===== LOAD =====
df = pd.read_csv(INPUT_CSV)

# Identify metadata columns (present subset)
meta_cols_all = ["dataset", "fly", "trial_type", "trial_label", "fps"]
meta_cols = [c for c in meta_cols_all if c in df.columns]
if not meta_cols:
    # explicit raise: unlike assert, this still runs under `python -O`
    raise ValueError("No metadata columns found. Expected at least one of: dataset, fly, trial_type, trial_label.")

# Envelope columns (everything else)
env_cols = [c for c in df.columns if c not in meta_cols]
if len(env_cols) == 0:
    raise ValueError("No envelope columns found.")


def _labels_or_unknown(series: pd.Series) -> pd.Series:
    """Return *series* as strings, with missing entries replaced by 'UNKNOWN'.

    The missing-value test is made on the original values: after astype(str)
    a NaN is the ordinary string 'nan' and fillna() can no longer see it.
    """
    return series.astype(str).where(series.notna(), "UNKNOWN")


# ===== BUILD INTEGER CODES FOR METADATA =====
# Codes start at 1; 0 is reserved for 'unknown'
code_maps = {}
for col in meta_cols:
    uniques = _labels_or_unknown(df[col]).unique().tolist()
    mapping = {"UNKNOWN": 0}
    next_code = 1
    for u in uniques:
        if u not in mapping:
            mapping[u] = next_code
            next_code += 1
    code_maps[col] = mapping

# Codes run 0..len(mapping)-1; refuse to write a matrix whose codes would collide.
too_many = {col: len(m) - 1 for col, m in code_maps.items() if len(m) - 1 > F16_MAX_EXACT_INT}
if too_many:
    raise ValueError(
        f"Metadata codes exceed the float16 exact-integer limit ({F16_MAX_EXACT_INT}): {too_many}"
    )

# Apply codes to a copy
df_num = df.copy()
for col, mapping in code_maps.items():
    df_num[col] = _labels_or_unknown(df_num[col]).map(mapping).fillna(0).astype(np.int32)

# Ensure envelope columns are numeric and NaN-free
df_num[env_cols] = df_num[env_cols].apply(pd.to_numeric, errors="coerce")

# Values beyond the float16 range become +/-inf on conversion; report them up front.
f16_max = float(np.finfo(np.float16).max)
n_overflow = int((df_num[env_cols].abs() > f16_max).to_numpy().sum())
if n_overflow:
    print(f"[WARN] {n_overflow} envelope value(s) exceed the float16 range and will be stored as inf.")

df_num[env_cols] = df_num[env_cols].fillna(0.0)

# ===== BUILD THE MATRIX (float16) =====
# Order: [meta_cols...] + [env_0...env_N]
ordered_cols = meta_cols + env_cols
matrix_f16 = df_num[ordered_cols].to_numpy(dtype=np.float16)

# ===== SAVE ARTIFACTS =====
np.save(MATRIX_NPY, matrix_f16)

# Human-readable key file
with CODE_KEY.open("w", encoding="utf-8") as f:
    f.write("# Envelope matrix schema (float16), row-wise\n")
    f.write("# Columns (in order):\n")
    for i, col in enumerate(ordered_cols):
        f.write(f"{i:>5}: {col}\n")
    f.write("\n# Metadata code maps (string → integer code)\n")
    for col in meta_cols:
        f.write(f"\n[{col}]\n")
        # Sort by numeric code
        inv = sorted(((code, name) for name, code in code_maps[col].items()), key=lambda x: x[0])
        for code, name in inv:
            f.write(f"{code:>5} : {name}\n")
    f.write("\nNotes:\n")
    f.write("- Matrix dtype is float16 (16-bit). Metadata codes are stored as float16 numbers in the matrix.\n")
    f.write("- Envelope NaNs (shorter videos) were replaced with 0.0.\n")
    f.write("- Code '0' means UNKNOWN for the metadata fields.\n")

# Optional: machine-readable mappings
with CODES_JSON.open("w", encoding="utf-8") as jf:
    json.dump({"column_order": ordered_cols, "code_maps": code_maps}, jf, indent=2)

print(f"[OK] Saved 16-bit matrix: {MATRIX_NPY}  (shape={matrix_f16.shape}, dtype={matrix_f16.dtype})")
print(f"[OK] Saved key:           {CODE_KEY}")
print(f"[OK] Saved JSON maps:     {CODES_JSON}")


## Matrices

In [None]:
# JUPYTER CELL — Reaction matrices per odor + fly-category counts (During & After)
from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import gridspec
from matplotlib.patches import Patch

# ───────── USER KNOB: spacing between rows (increase this to add space)
ROW_GAP = 0.6            # GridSpec hspace between the matrix row and the bar row
HEIGHT_PER_GAP_IN = 3.0  # inches of figure height added per 1.0 of ROW_GAP
BOTTOM_SHIFT_IN = 0.50   # inches by which the bottom row is lowered

# ───────── PARAMETERS ─────────
MATRIX_NPY        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/envelope_matrix_float16.npy")
CODES_JSON        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/code_maps.json")
FPS_DEFAULT       = 40     # used when a row has no usable fps
BEFORE_SEC        = 30.0   # baseline window length
DURING_SEC        = 30.0   # DURING window length
AFTER_WINDOW_SEC  = 30.0   # AFTER window length
THRESH_STD_MULT   = 4      # threshold = baseline mean + this many baseline std
MIN_SAMPLES_OVER  = 20     # samples above threshold needed for a hit

# DURING window is shifted by this latency (seconds) at both ends
ODOR_TRANSIT_LAT_S = overall_mean_latency_s

OUT_DIR           = Path("/home/ramanlab/Documents/cole/Results/Opto/Matrixs_DISTxANGLE")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Canon keys for grouping.
# Keys MUST be lower-case: _canon_odor() lower-cases the dataset name before
# looking it up, so a key containing capitals can never match.
ODOR_CANON = {
    "acv": "ACV",
    "apple cider vinegar": "ACV",
    "apple-cider-vinegar": "ACV",
    "3-octonol": "3-octonol",
    "3 octonol": "3-octonol",
    "3-octanol": "3-octonol",
    "3 octanol": "3-octonol",
    "benz": "Benz",
    "benzaldehyde": "Benz",
    "benz-ald": "Benz",
    "benzadhyde": "Benz",
    "ethyl butyrate": "EB",
    "optogenetics ethyl butyrate": "opto_EB",
    # NOTE(review): this key used to appear twice ("opto_benz", then
    # "opto_benz_1"); in a dict literal the last entry wins, so that value is
    # kept here. Confirm which canonical name is actually intended.
    "optogenetics benzaldehyde": "opto_benz_1",
}
# Canonical dataset key → label shown in figures
DISPLAY_LABEL = {
    "ACV": "ACV",
    "3-octonol": "3-Octonol",
    "Benz": "Benzaldehyde",
    "10s_Odor_Benz": "Benzaldehyde",
    "EB": "Ethyl Butyrate",
    "opto_benz": "Benzaldehyde",
    "opto_EB": "Ethyl Butyrate",
    "opto_benz_1": "Benzaldehyde",
}
# Preferred processing order; datasets not listed are handled afterwards, sorted
ODOR_ORDER = ["ACV", "3-octonol", "Benz", "EB", "10s_Odor_Benz", "opto_benz", "opto_EB", "opto_benz_1"]

# ───────── LOAD + DECODE ─────────
matrix = np.load(MATRIX_NPY)
with open(CODES_JSON, "r") as f:
    meta = json.load(f)
ordered_cols, code_maps = meta["column_order"], meta["code_maps"]

# Invert each name→code map into code→name for decoding.
rev_maps = {}
for col_name, name_to_code in code_maps.items():
    rev_maps[col_name] = {code: name for name, code in name_to_code.items()}

# Label columns that are present, plus fps when the matrix carries it.
decode_cols = [c for c in ("dataset", "fly", "trial_type", "trial_label") if c in ordered_cols]
meta_cols = list(decode_cols)
if "fps" in ordered_cols:
    meta_cols.append("fps")

df = pd.DataFrame(matrix, columns=ordered_cols)

# Turn the stored numeric codes back into their string labels.
for c in decode_cols:
    int_codes = df[c].astype(int)
    df[c] = int_codes.map(rev_maps[c]).fillna("UNKNOWN")

# fps: decode if it was label-coded, then force to numeric (NaN when unparseable).
if "fps" not in df.columns:
    df["fps"] = np.nan
else:
    if "fps" in rev_maps:
        df["fps"] = df["fps"].astype(int).map(rev_maps["fps"])
    df["fps"] = pd.to_numeric(df["fps"], errors="coerce")

# Only rows whose trial_type is "testing" (case-insensitive) are analysed.
is_testing = df["trial_type"].str.lower() == "testing"
df = df[is_testing].copy()

# Missing or infinite fps falls back to the default.
FPS_FALLBACK = FPS_DEFAULT
fps_filled = df["fps"].fillna(FPS_FALLBACK)
df["fps"] = fps_filled.replace([np.inf, -np.inf], FPS_FALLBACK)

# Everything that is not metadata is an envelope sample column.
env_cols = [c for c in ordered_cols if c not in meta_cols]

def _canon_odor(s: str) -> str:
    """Map a raw dataset name to its canonical odor key.

    The lookup in ODOR_CANON ignores case and surrounding whitespace on both
    sides, so table keys written with capitals still match. Names without an
    entry are returned stripped but otherwise unchanged; non-string input
    yields "UNKNOWN".
    """
    if not isinstance(s, str):
        return "UNKNOWN"
    name = s.strip()
    # Normalise the table keys too; the raw name is lower-cased for the lookup.
    lookup = {str(key).strip().lower(): canon for key, canon in ODOR_CANON.items()}
    return lookup.get(name.lower(), name)
# Canonical dataset key for every row; the scoring loop below reports it as "dataset".
df["dataset_canon"] = df["dataset"].apply(_canon_odor)

def _trial_num(label: str) -> int:
    """Return the first run of digits in *label* as an int, or -1 if there is none."""
    found = re.search(r"(\d+)", str(label))
    if found is None:
        return -1
    return int(found.group(1))

def display_odor_for_trial(dataset_canon: str, trial_label: str) -> str:
    """Return the odor name to display for one trial of a dataset.

    Trials 1 and 3 are the hexanol controls and trials 2, 4 and 5 show the
    dataset's own (trained) odor label. Later trials are looked up per
    dataset; when no entry exists the trial label itself is returned.
    """
    # dataset key -> {trial number -> odor shown on that trial}
    later_trials = {
        "ACV": {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"},
        "3-octonol": {6: "Benzaldehyde", 7: "Citral", 8: "Linalool"},
        "Benz": {6: "Citral", 7: "Linalool"},
        "EB": {
            6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Benzaldehyde",
            9: "Citral", 10: "Linalool",
        },
        "10s_Odor_Benz": {6: "Benzaldehyde", 7: "Benzaldehyde"},
        "opto_EB": {
            6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Benzaldehyde",
            9: "Citral", 10: "Linalool",
        },
        "opto_benz": {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"},
        "opto_benz_1": {
            6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Ethyl Butyrate",
            9: "Citral", 10: "Linalool",
        },
    }

    number = _trial_num(trial_label)
    if number in (1, 3):  # hexanol controls
        return "Hexanol"
    if number in (2, 4, 5):  # trained odor
        return DISPLAY_LABEL.get(dataset_canon, dataset_canon)
    return later_trials.get(dataset_canon, {}).get(number, trial_label)

def score_trial_from_env(env_row: pd.Series, fps: float) -> tuple[int, int]:
    """Score one envelope trace for a response DURING and AFTER the odor.

    Windows, in samples at *fps*:
      BEFORE: [0, BEFORE_SEC)
      DURING: BEFORE_SEC .. BEFORE_SEC + DURING_SEC, shifted by ODOR_TRANSIT_LAT_S
      AFTER:  the AFTER_WINDOW_SEC immediately following DURING
    The threshold is mean + THRESH_STD_MULT * std of the BEFORE window. A window
    is a hit when at least MIN_SAMPLES_OVER of its samples exceed the threshold.

    Only trailing invalid samples (non-finite or <= 0) are removed as padding.
    Invalid samples inside the trace keep their slot as NaN: deleting them
    would shift every later sample earlier and misalign the windows.

    Returns:
        (during_hit, after_hit), each 0 or 1. (0, 0) when the trace has no
        valid sample or the BEFORE window has none.
    """
    env = env_row.to_numpy(dtype=float)
    valid = np.isfinite(env) & (env > 0)
    if not valid.any():
        return (0, 0)

    # Cut after the last valid sample; blank out (not delete) interior gaps.
    stop = int(np.flatnonzero(valid)[-1]) + 1
    env = np.where(valid[:stop], env[:stop], np.nan)
    total = env.size

    b_end   = int(round(BEFORE_SEC * fps))
    shift   = int(round(ODOR_TRANSIT_LAT_S * fps))
    d_start = b_end + shift
    d_end   = b_end + int(round(DURING_SEC * fps)) + shift
    a_end   = d_end + int(round(AFTER_WINDOW_SEC * fps))

    # Clamp to the trace and keep the boundaries ordered.
    b_end   = max(0, min(b_end, total))
    d_start = max(b_end, min(d_start, total))
    d_end   = max(d_start, min(d_end, total))
    a_end   = max(d_end, min(a_end, total))

    before = env[:b_end]
    during = env[d_start:d_end]
    after  = env[d_end:a_end]

    baseline = before[np.isfinite(before)]
    if baseline.size == 0:
        return (0, 0)

    theta = float(baseline.mean()) + THRESH_STD_MULT * float(baseline.std())

    def _is_hit(window: np.ndarray) -> int:
        # An empty window never scores, matching the earlier behaviour.
        if not window.size:
            return 0
        n_over = int(np.count_nonzero(window[np.isfinite(window)] > theta))
        return int(n_over >= MIN_SAMPLES_OVER)

    return _is_hit(during), _is_hit(after)

# ───────── Score all rows ─────────
# Columns are fixed up front so scores_df keeps its schema even when df has no
# rows; a column-less frame would make the later scores_df["dataset"] lookup
# fail with a KeyError.
SCORE_COLUMNS = ["dataset", "fly", "trial", "trial_num", "during_hit", "after_hit"]

scores = []
for _, row in df.iterrows():
    row_fps = float(row.get("fps", FPS_FALLBACK))
    d_hit, a_hit = score_trial_from_env(row[env_cols], row_fps)
    scores.append({
        "dataset": row["dataset_canon"],
        "fly": row["fly"],
        "trial": row["trial_label"],
        "trial_num": _trial_num(row["trial_label"]),
        "during_hit": d_hit,
        "after_hit": a_hit
    })
scores_df = pd.DataFrame(scores, columns=SCORE_COLUMNS)

# ───────── Colormaps and helpers ─────────
# Three bins centred on -1 / 0 / 1 (missing / no hit / hit)
cmap = ListedColormap(["0.7", "1.0", "0.0"])  # gray, white, black
norm = BoundaryNorm([-1.5, -0.5, 0.5, 1.5], cmap.N)

def style_trained_xticks_vertical(ax, labels, trained_disp: str, fontsize: int):
    """Label the x axis vertically; the trained odor is upper-cased and drawn in blue."""
    label_style = dict(rotation=90, ha="center", va="top", fontsize=fontsize)
    trained_key = trained_disp.lower()

    ax.set_xticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, **label_style)

    # Restyle the matching tick labels and collect the final text of every tick.
    final_texts = []
    for tick_label in ax.get_xticklabels():
        if tick_label.get_text().strip().lower() == trained_key:
            tick_label.set_text(trained_disp.upper())
            tick_label.set_color("tab:blue")
        final_texts.append(tick_label.get_text())

    # Apply the (possibly upper-cased) texts in one go.
    ax.set_xticklabels(final_texts, **label_style)
    ax.tick_params(axis="x", pad=2)

def compute_fly_category_counts(mat: np.ndarray, labels: list[str], trained_disp: str, include_hexanol: bool = False):
    """Count flies (rows of *mat*) by which odor columns they responded to.

    A column is "trained" when its label equals *trained_disp* ignoring case
    and surrounding whitespace; all other columns are "others", with Hexanol
    columns left out unless *include_hexanol* is true. Only cells equal to 1
    count as a response, so -1 (missing) behaves like 0.

    Returns a dict with keys "Trained only", "Trained + Others" and
    "Others only". All counts are 0 for an empty matrix or when no column
    carries the trained label.
    """
    counts = {"Trained only": 0, "Trained + Others": 0, "Others only": 0}
    if mat.size == 0:
        return counts

    target = trained_disp.lower()
    keys = [lab.strip().lower() for lab in labels]
    trained_idx = [j for j, key in enumerate(keys) if key == target]
    if not trained_idx:
        return counts
    other_idx = [
        j for j, key in enumerate(keys)
        if key != target and (include_hexanol or key != "hexanol")
    ]

    responded = (mat == 1)
    trained_hit = responded[:, trained_idx].any(axis=1)
    if other_idx:
        other_hit = responded[:, other_idx].any(axis=1)
    else:
        other_hit = np.zeros(mat.shape[0], dtype=bool)

    counts["Trained only"] = int(np.count_nonzero(trained_hit & ~other_hit))
    counts["Trained + Others"] = int(np.count_nonzero(trained_hit & other_hit))
    counts["Others only"] = int(np.count_nonzero(~trained_hit & other_hit))
    return counts

def plot_category_counts(ax, counts: dict, n_flies: int, title: str):
    """Draw the three category counts on *ax* as bars showing percent of *n_flies*."""
    cats = ["Trained only", "Trained + Others", "Others only"]
    raw = np.array([counts.get(cat, 0) for cat in cats], dtype=float)
    if n_flies > 0:
        vals_pct = 100.0 * raw / float(n_flies)
    else:
        # no flies: show all-zero bars instead of dividing by zero
        vals_pct = np.zeros_like(raw)

    positions = np.arange(len(cats))
    bars = ax.bar(positions, vals_pct, width=0.75, edgecolor="black", linewidth=0.8)

    ax.set_xticks(positions)
    ax.set_xticklabels(cats, rotation=15, ha="right")
    ax.set_ylim(0, 100)
    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_ylabel("% of flies")
    ax.set_title(title, fontsize=12, weight="bold")
    ax.margins(x=0.05)

    # percentage label just above each bar
    for bar, pct in zip(bars, vals_pct):
        x_mid = bar.get_x() + bar.get_width()/2
        ax.text(x_mid, bar.get_height() + 1.5, f"{pct:.0f}%", ha="center", va="bottom", fontsize=9)

def shade_latency_on_timeseries(ax, before_sec: float = BEFORE_SEC, latency_s: float = ODOR_TRANSIT_LAT_S):
    """Shade the span [before_sec, before_sec + latency_s] on *ax* in translucent red."""
    ax.axvspan(before_sec, before_sec + latency_s, color="red", alpha=0.30, lw=0)

# ──────── helper: safe dir name
def _safe_dirname(s: str) -> str:
    """Collapse every run of characters outside [A-Za-z0-9._-] to "_" and trim edge underscores."""
    collapsed = re.sub(r'[^A-Za-z0-9._-]+', '_', str(s))
    return collapsed.strip('_')

# ───────── Build & save per-odor figures (per-odor subfolders) ─────────
# For every dataset in scores_df: build fly × trial hit matrices for the DURING
# and AFTER windows, draw them above two category bar charts, and save the
# figure plus a row→fly key file into OUT_DIR/<dataset>/.
# Datasets listed in ODOR_ORDER come first (in that order), the rest sorted.
present = scores_df["dataset"].unique().tolist()
ordered_present = [o for o in ODOR_ORDER if o in present]
extras = sorted([o for o in present if o not in ODOR_ORDER])
for odor in ordered_present + extras:
    sub = scores_df[scores_df["dataset"] == odor].copy()
    if sub.empty:
        print(f"[WARN] No testing trials for {odor}")
        continue

    # per-odor output directory
    odir = OUT_DIR / _safe_dirname(odor)
    odir.mkdir(parents=True, exist_ok=True)

    # Rows = flies (sorted), columns = trial labels ordered by their number.
    flies  = sorted(sub["fly"].unique())
    trials = sorted(sub["trial"].unique(), key=_trial_num)
    pretty_cols = [display_odor_for_trial(odor, t) for t in trials]

    # -1 marks a (fly, trial) pair with no scored row; otherwise 0/1.
    D = -np.ones((len(flies), len(trials)), dtype=int)
    A = -np.ones((len(flies), len(trials)), dtype=int)
    for i, fly in enumerate(flies):
        fly_rows = sub[sub["fly"] == fly]
        for j, t in enumerate(trials):
            s = fly_rows[fly_rows["trial"] == t]
            if s.empty: continue
            # if several rows share this fly and trial, the first one is used
            D[i, j] = int(s["during_hit"].iloc[0])
            A[i, j] = int(s["after_hit"].iloc[0])

    # Title label and trained-odor label are the same display string.
    odor_label   = DISPLAY_LABEL.get(odor, odor)
    trained_disp = DISPLAY_LABEL.get(odor, odor)
    n_flies = len(flies)
    n_trials = len(trials)

    # Figure size grows with the number of trials (width) and flies (height),
    # plus extra height for the row gap and the lowered bottom row.
    base_fig_w = max(10.0, 0.70 * n_trials + 6.0)
    base_fig_h = max(5.0, n_flies * 0.26 + 3.8)
    fig_w = base_fig_w
    fig_h = base_fig_h + ROW_GAP * HEIGHT_PER_GAP_IN
    fig_h += BOTTOM_SHIFT_IN
    # smaller tick font when there are many trial columns
    xtick_fs = 9 if n_trials <= 10 else (8 if n_trials <= 16 else 7)

    # Hexanol columns count as "others" here (include_hexanol=True).
    during_counts = compute_fly_category_counts(D, pretty_cols, trained_disp, include_hexanol=True)
    after_counts  = compute_fly_category_counts(A, pretty_cols, trained_disp, include_hexanol=True)

    # 2×2 grid: matrices on top, category bars below.
    fig = plt.figure(figsize=(fig_w, fig_h), constrained_layout=False)
    gs  = gridspec.GridSpec(2, 2, height_ratios=[3.0, 1.25], width_ratios=[1, 1], hspace=ROW_GAP, wspace=0.10)

    axD  = fig.add_subplot(gs[0, 0])
    axA  = fig.add_subplot(gs[0, 1])
    axDc = fig.add_subplot(gs[1, 0])
    axAc = fig.add_subplot(gs[1, 1])

    imD = axD.imshow(D, cmap=cmap, norm=norm, aspect="auto", interpolation="nearest")
    axD.set_title(f"{odor_label} — During\n(DURING shifted by +{ODOR_TRANSIT_LAT_S:.2f} s)", fontsize=14, weight="bold", linespacing=1.1)
    style_trained_xticks_vertical(axD, pretty_cols, trained_disp, fontsize=xtick_fs)
    axD.set_yticks([]); axD.set_ylabel(f"{n_flies} Flies", fontsize=11)

    imA = axA.imshow(A, cmap=cmap, norm=norm, aspect="auto", interpolation="nearest")
    axA.set_title(f"{odor_label} — After (first {int(AFTER_WINDOW_SEC)} s)", fontsize=14, weight="bold")
    style_trained_xticks_vertical(axA, pretty_cols, trained_disp, fontsize=xtick_fs)
    axA.set_yticks([]); axA.set_ylabel(f"{n_flies} Flies", fontsize=11)

    plot_category_counts(axDc, during_counts, n_flies, title="During — Fly Reaction Categories")
    plot_category_counts(axAc, after_counts,  n_flies, title=f"After (first {int(AFTER_WINDOW_SEC)} s) — Fly Reaction Categories")

    # Legend entry stating the transit latency applied to the DURING window.
    red_patch = Patch(facecolor="red", edgecolor="red", alpha=0.30, label=f"Odor transit {ODOR_TRANSIT_LAT_S:.2f} s (pre-DURING)")
    axD.legend(handles=[red_patch], loc="upper left", frameon=True, fontsize=9)

    # Lower the bottom row by BOTTOM_SHIFT_IN inches (as a fraction of the
    # figure height), never below y0 = 0.05.
    shift_frac = BOTTOM_SHIFT_IN / fig_h
    for ax in (axDc, axAc):
        pos = ax.get_position()
        new_y0 = max(0.05, pos.y0 - shift_frac)
        ax.set_position([pos.x0, new_y0, pos.width, pos.height])

    # Save into per-odor folder
    out_png = odir / f"reaction_matrix_{odor.replace(' ', '_')}_{AFTER_WINDOW_SEC}_latency_{ODOR_TRANSIT_LAT_S:.3f}s.png"
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] saved {out_png}")

    # Text key mapping each matrix row index to its fly.
    key_path = odir / f"row_key_{odor.replace(' ', '_')}_{AFTER_WINDOW_SEC}.txt"
    with key_path.open("w") as fh:
        for i, fly in enumerate(flies):
            fh.write(f"Row {i}: {fly}\n")
    print(f"[OK] saved {key_path}")

print("[DONE] Per-odor exports saved into subfolders under OUT_DIR.)")

In [None]:
# JUPYTER CELL — Reaction matrices per odor + fly-category counts (During & After)
from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib import gridspec

# ───────── USER KNOB: spacing between rows (increase this to add space)
ROW_GAP = 0.6            # try 0.10 … 0.60 (higher = more space between rows)
HEIGHT_PER_GAP_IN = 3.0  # how many inches of figure height to add per 1.0 ROW_GAP
BOTTOM_SHIFT_IN = 0.50   # inches to lower the bottom row; increase to move further down

# ───────── PARAMETERS ─────────
MATRIX_NPY        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/envelope_matrix_float16.npy")
CODES_JSON        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/code_maps.json")
FPS_DEFAULT       = 40     # used when a row has no usable fps
BEFORE_SEC        = 30.0   # baseline window length
DURING_SEC        = 30.0   # DURING window length
AFTER_WINDOW_SEC  = 30.0   # AFTER window length
THRESH_STD_MULT   = 4      # threshold = baseline mean + this many baseline std
MIN_SAMPLES_OVER  = 20     # samples above threshold needed for a hit

# Shift DURING window by latency at both start and end:
# e.g., DURING [30,60] → [30+lat, 60+lat]
ODOR_TRANSIT_LAT_S = overall_mean_latency_s

OUT_DIR           = Path("/home/ramanlab/Documents/cole/Results/Opto/Matrixs_DISTxANGLE")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Canon keys for grouping.
# Keys MUST be lower-case: _canon_odor() lower-cases the dataset name before
# looking it up, so a key containing capitals can never match.
ODOR_CANON = {
    "acv": "ACV",
    "apple cider vinegar": "ACV",
    "apple-cider-vinegar": "ACV",
    "3-octonol": "3-octonol",
    "3 octonol": "3-octonol",
    "3-octanol": "3-octonol",
    "3 octanol": "3-octonol",
    "benz": "Benz",
    "benzaldehyde": "Benz",
    "benz-ald": "Benz",
    "benzadhyde": "Benz",
    "ethyl butyrate": "EB",
    # NOTE(review): this key used to appear twice ("opto_benz", then
    # "opto_benz_1"); in a dict literal the last entry wins, so that value is
    # kept here. Confirm which canonical name is actually intended.
    "optogenetics benzaldehyde": "opto_benz_1",
    "optogenetics ethyl butyrate": "opto_EB",
}
# Canonical dataset key → label shown in figures
DISPLAY_LABEL = {
    "ACV": "ACV",
    "3-octonol": "3-Octonol",
    "Benz": "Benzaldehyde",
    "10s_Odor_Benz": "Benzaldehyde",
    "EB": "Ethyl Butyrate",
    "opto_benz": "Benzaldehyde",
    "opto_EB": "Ethyl Butyrate",
    "opto_benz_1": "Benzaldehyde",
}
# Preferred processing order; datasets not listed are handled afterwards, sorted
ODOR_ORDER = ["ACV", "3-octonol", "Benz", "EB", "10s_Odor_Benz", "opto_benz", "opto_EB", "opto_benz_1"]

# ───────── LOAD + DECODE ─────────
matrix = np.load(MATRIX_NPY)
with open(CODES_JSON, "r") as f:
    meta = json.load(f)
ordered_cols, code_maps = meta["column_order"], meta["code_maps"]

# Invert each name→code map into code→name for decoding.
rev_maps = {}
for col_name, name_to_code in code_maps.items():
    rev_maps[col_name] = {code: name for name, code in name_to_code.items()}

# Label columns that are present, plus fps when the matrix carries it.
decode_cols = [c for c in ("dataset", "fly", "trial_type", "trial_label") if c in ordered_cols]
meta_cols = list(decode_cols)
if "fps" in ordered_cols:
    meta_cols.append("fps")

df = pd.DataFrame(matrix, columns=ordered_cols)

# decode label-coded columns back into their strings
for c in decode_cols:
    int_codes = df[c].astype(int)
    df[c] = int_codes.map(rev_maps[c]).fillna("UNKNOWN")

# ensure fps numeric: decode if it was label-coded, NaN when unparseable
if "fps" not in df.columns:
    df["fps"] = np.nan
else:
    if "fps" in rev_maps:
        df["fps"] = df["fps"].astype(int).map(rev_maps["fps"])
    df["fps"] = pd.to_numeric(df["fps"], errors="coerce")

# testing only (case-insensitive match on trial_type)
is_testing = df["trial_type"].str.lower() == "testing"
df = df[is_testing].copy()

# fill missing or infinite fps with the default
FPS_FALLBACK = FPS_DEFAULT
fps_filled = df["fps"].fillna(FPS_FALLBACK)
df["fps"] = fps_filled.replace([np.inf, -np.inf], FPS_FALLBACK)

# envelope columns exclude meta
env_cols = [c for c in ordered_cols if c not in meta_cols]

def _canon_odor(s: str) -> str:
    """Map a raw dataset name to its canonical odor key.

    The lookup in ODOR_CANON ignores case and surrounding whitespace on both
    sides, so table keys written with capitals still match. Names without an
    entry are returned stripped but otherwise unchanged; non-string input
    yields "UNKNOWN".
    """
    if not isinstance(s, str):
        return "UNKNOWN"
    name = s.strip()
    # Normalise the table keys too; the raw name is lower-cased for the lookup.
    lookup = {str(key).strip().lower(): canon for key, canon in ODOR_CANON.items()}
    return lookup.get(name.lower(), name)
# Canonical dataset key for every row; the scoring loop below reports it as "dataset".
df["dataset_canon"] = df["dataset"].apply(_canon_odor)

def _trial_num(label: str) -> int:
    """Return the first run of digits in *label* as an int, or -1 if there is none."""
    found = re.search(r"(\d+)", str(label))
    if found is None:
        return -1
    return int(found.group(1))

# ───────── Custom trial→display-odor mapping per dataset ─────────
def display_odor_for_trial(dataset_canon: str, trial_label: str) -> str:
    """Return the odor name to display for one trial of a dataset.

    Trials 1 and 3 are the hexanol controls and trials 2, 4 and 5 show the
    dataset's own (trained) odor label. Later trials are looked up per
    dataset; when no entry exists the trial label itself is returned.
    """
    # dataset key -> {trial number -> odor shown on that trial}
    later_trials = {
        "ACV": {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"},
        "3-octonol": {6: "Benzaldehyde", 7: "Citral", 8: "Linalool"},
        "Benz": {6: "Citral", 7: "Linalool"},
        "EB": {
            6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Benzaldehyde",
            9: "Citral", 10: "Linalool",
        },
        "10s_Odor_Benz": {6: "Benzaldehyde", 7: "Benzaldehyde"},
        "opto_EB": {
            6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Benzaldehyde",
            9: "Citral", 10: "Linalool",
        },
        "opto_benz": {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"},
        "opto_benz_1": {
            6: "Apple Cider Vinegar", 7: "3-Octonol", 8: "Ethyl Butyrate",
            9: "Citral", 10: "Linalool",
        },
    }

    number = _trial_num(trial_label)
    if number in (1, 3):  # hexanol controls
        return "Hexanol"
    if number in (2, 4, 5):  # trained odor
        return DISPLAY_LABEL.get(dataset_canon, dataset_canon)
    return later_trials.get(dataset_canon, {}).get(number, trial_label)

# ───────── Scoring on envelope row ─────────
def score_trial_from_env(env_row: pd.Series, fps: float) -> tuple[int, int]:
    """
    Compute During/After hits using a baseline from BEFORE.
    DURING is fully shifted by ODOR_TRANSIT_LAT_S at start and end:
      DURING: [BEFORE_SEC + ODOR_TRANSIT_LAT_S, BEFORE_SEC + DURING_SEC + ODOR_TRANSIT_LAT_S]
      AFTER:  the next AFTER_WINDOW_SEC immediately following DURING.

    The threshold is mean + THRESH_STD_MULT * std of the BEFORE window. A window
    is a hit when at least MIN_SAMPLES_OVER of its samples exceed the threshold.

    Only trailing invalid samples (non-finite or <= 0) are removed as padding.
    Invalid samples inside the trace keep their slot as NaN: deleting them
    would shift every later sample earlier and misalign the windows.

    Returns (during_hit, after_hit), each 0 or 1; (0, 0) when the trace has no
    valid sample or the BEFORE window has none.
    """
    env = env_row.to_numpy(dtype=float)
    valid = np.isfinite(env) & (env > 0)
    if not valid.any():
        return (0, 0)

    # Cut after the last valid sample; blank out (not delete) interior gaps.
    stop = int(np.flatnonzero(valid)[-1]) + 1
    env = np.where(valid[:stop], env[:stop], np.nan)
    total = env.size

    # Indices (in samples)
    b_end   = int(round(BEFORE_SEC * fps))  # end of BEFORE
    shift   = int(round(ODOR_TRANSIT_LAT_S * fps))
    d_start = b_end + shift
    d_end   = b_end + int(round(DURING_SEC * fps)) + shift
    a_end   = d_end + int(round(AFTER_WINDOW_SEC * fps))

    # Clip/guard: stay inside the trace and keep the boundaries ordered
    b_end   = max(0, min(b_end, total))
    d_start = max(b_end, min(d_start, total))
    d_end   = max(d_start, min(d_end, total))
    a_end   = max(d_end, min(a_end, total))

    # Windows
    before = env[:b_end]
    during = env[d_start:d_end]   # fully shifted DURING window
    after  = env[d_end:a_end]

    # Threshold from BEFORE baseline (valid samples only)
    baseline = before[np.isfinite(before)]
    if baseline.size == 0:
        return (0, 0)
    theta = float(baseline.mean()) + THRESH_STD_MULT * float(baseline.std())

    def _is_hit(window: np.ndarray) -> int:
        # An empty window never scores, matching the earlier behaviour.
        if not window.size:
            return 0
        n_over = int(np.count_nonzero(window[np.isfinite(window)] > theta))
        return int(n_over >= MIN_SAMPLES_OVER)

    # Hits (require at least MIN_SAMPLES_OVER above theta)
    return _is_hit(during), _is_hit(after)

# ───────── Score all rows ─────────
# Columns are fixed up front so scores_df keeps its schema even when df has no
# rows; a column-less frame would make the later scores_df["dataset"] lookup
# fail with a KeyError.
SCORE_COLUMNS = ["dataset", "fly", "trial", "trial_num", "during_hit", "after_hit"]

scores = []
for _, row in df.iterrows():
    row_fps = float(row.get("fps", FPS_FALLBACK))
    d_hit, a_hit = score_trial_from_env(row[env_cols], row_fps)
    scores.append({
        "dataset": row["dataset_canon"],
        "fly": row["fly"],
        "trial": row["trial_label"],
        "trial_num": _trial_num(row["trial_label"]),
        "during_hit": d_hit,
        "after_hit": a_hit
    })
scores_df = pd.DataFrame(scores, columns=SCORE_COLUMNS)

# ───────── Colormaps and helpers ─────────
# Three bins centred on -1 / 0 / 1 (missing / no hit / hit)
cmap = ListedColormap(["0.7", "1.0", "0.0"])  # gray, white, black
norm = BoundaryNorm([-1.5, -0.5, 0.5, 1.5], cmap.N)

def style_trained_xticks_vertical(ax, labels, trained_disp: str, fontsize: int):
    """Label the x axis vertically; the trained odor is upper-cased and drawn in blue."""
    label_style = dict(rotation=90, ha="center", va="top", fontsize=fontsize)
    trained_key = trained_disp.lower()

    ax.set_xticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, **label_style)

    # Restyle the matching tick labels and collect the final text of every tick.
    final_texts = []
    for tick_label in ax.get_xticklabels():
        if tick_label.get_text().strip().lower() == trained_key:
            tick_label.set_text(trained_disp.upper())
            tick_label.set_color("tab:blue")
        final_texts.append(tick_label.get_text())

    # Apply the (possibly upper-cased) texts in one go.
    ax.set_xticklabels(final_texts, **label_style)
    ax.tick_params(axis="x", pad=2)

def compute_fly_category_counts(mat: np.ndarray, labels: list[str], trained_disp: str, include_hexanol: bool = False):
    """Count flies (rows of *mat*) by which odor columns they responded to.

    A column is "trained" when its label equals *trained_disp* ignoring case
    and surrounding whitespace; all other columns are "others", with Hexanol
    columns left out unless *include_hexanol* is true. Only cells equal to 1
    count as a response, so -1 (missing) behaves like 0.

    Returns a dict with keys "Trained only", "Trained + Others" and
    "Others only". All counts are 0 for an empty matrix or when no column
    carries the trained label.
    """
    counts = {"Trained only": 0, "Trained + Others": 0, "Others only": 0}
    if mat.size == 0:
        return counts

    target = trained_disp.lower()
    keys = [lab.strip().lower() for lab in labels]
    trained_idx = [j for j, key in enumerate(keys) if key == target]
    if not trained_idx:
        return counts
    other_idx = [
        j for j, key in enumerate(keys)
        if key != target and (include_hexanol or key != "hexanol")
    ]

    responded = (mat == 1)
    trained_hit = responded[:, trained_idx].any(axis=1)
    if other_idx:
        other_hit = responded[:, other_idx].any(axis=1)
    else:
        other_hit = np.zeros(mat.shape[0], dtype=bool)

    counts["Trained only"] = int(np.count_nonzero(trained_hit & ~other_hit))
    counts["Trained + Others"] = int(np.count_nonzero(trained_hit & other_hit))
    counts["Others only"] = int(np.count_nonzero(~trained_hit & other_hit))
    return counts

def plot_category_counts(ax, counts: dict, n_flies: int, title: str):
    """Draw the three category counts on *ax* as bars showing percent of *n_flies*."""
    cats = ["Trained only", "Trained + Others", "Others only"]
    raw = np.array([counts.get(cat, 0) for cat in cats], dtype=float)
    if n_flies > 0:
        vals_pct = 100.0 * raw / float(n_flies)
    else:
        # no flies: show all-zero bars instead of dividing by zero
        vals_pct = np.zeros_like(raw)

    positions = np.arange(len(cats))
    bars = ax.bar(positions, vals_pct, width=0.75, edgecolor="black", linewidth=0.8)

    ax.set_xticks(positions)
    ax.set_xticklabels(cats, rotation=15, ha="right")
    ax.set_ylim(0, 100)
    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_ylabel("% of flies")
    ax.set_title(title, fontsize=12, weight="bold")
    ax.margins(x=0.05)

    # percentage label just above each bar
    for bar, pct in zip(bars, vals_pct):
        x_mid = bar.get_x() + bar.get_width()/2
        ax.text(x_mid, bar.get_height() + 1.5, f"{pct:.0f}%", ha="center", va="bottom", fontsize=9)

# ─────── helper: safe dir name
def _safe_dirname(s: str) -> str:
    """Collapse every run of characters outside [A-Za-z0-9._-] to "_" and trim edge underscores."""
    collapsed = re.sub(r'[^A-Za-z0-9._-]+', '_', str(s))
    return collapsed.strip('_')

# ───────── Build & save per-odor figures (into per-odor subfolders) ─────────
# For every dataset in scores_df: build fly × trial hit matrices for the DURING
# and AFTER windows, draw them above two category bar charts, and save the
# figure, a row→fly key file and a per-trial CSV into OUT_DIR/<dataset>/.

# Only iterate over odors present; preserve preferred order, then extras
present = scores_df["dataset"].unique().tolist()
ordered_present = [o for o in ODOR_ORDER if o in present]
extras = sorted([o for o in present if o not in ODOR_ORDER])

# Column order by trial number: trained odor first (2,4,5), then 1,3,6,7,8,9.
# Defined once here instead of on every loop iteration; any other non-negative
# trial number is appended afterwards in ascending order.
desired_order = [2, 4, 5, 1, 3, 6, 7, 8, 9]
desired_set = set(desired_order)

for odor in ordered_present + extras:
    sub = scores_df[scores_df["dataset"] == odor].copy()
    if sub.empty:
        print(f"[WARN] No testing trials for {odor}")
        continue

    # per-odor output directory
    odir = OUT_DIR / _safe_dirname(odor)
    odir.mkdir(parents=True, exist_ok=True)

    flies  = sorted(sub["fly"].unique())
    existing_trials = list(sub["trial"].unique())

    # trial number → first label seen with that number (uses the shared
    # _trial_num helper rather than a per-iteration copy of it)
    by_num = {}
    for t in existing_trials:
        by_num.setdefault(_trial_num(t), t)

    ordered_trials = [by_num[n] for n in desired_order if n in by_num]
    leftovers = sorted(n for n in by_num if n not in desired_set and n >= 0)
    ordered_trials += [by_num[n] for n in leftovers]

    trials = ordered_trials
    pretty_cols = [display_odor_for_trial(odor, t) for t in trials]

    # Matrices with sentinel -1 for missing, else 0/1
    D = -np.ones((len(flies), len(trials)), dtype=int)
    A = -np.ones((len(flies), len(trials)), dtype=int)
    for i, fly in enumerate(flies):
        fly_rows = sub[sub["fly"] == fly]
        for j, t in enumerate(trials):
            s = fly_rows[fly_rows["trial"] == t]
            if s.empty:
                continue
            # if several rows share this fly and trial, the first one is used
            D[i, j] = int(s["during_hit"].iloc[0])
            A[i, j] = int(s["after_hit"].iloc[0])

    # Title label and trained-odor label are the same display string.
    odor_label   = DISPLAY_LABEL.get(odor, odor)
    trained_disp = DISPLAY_LABEL.get(odor, odor)
    n_flies = len(flies)
    n_trials = len(trials)

    # Figure size (height grows with flies and with ROW_GAP)
    base_fig_w = max(10.0, 0.70 * n_trials + 6.0)
    base_fig_h = max(5.0, n_flies * 0.26 + 3.8)
    fig_w = base_fig_w
    fig_h = base_fig_h + ROW_GAP * HEIGHT_PER_GAP_IN
    fig_h += BOTTOM_SHIFT_IN   # keep layout comfortable while lowering bottom row

    # smaller tick font when there are many trial columns
    xtick_fs = 9 if n_trials <= 10 else (8 if n_trials <= 16 else 7)

    # Fly-category counts for During & After (Hexanol counts as "others")
    during_counts = compute_fly_category_counts(D, pretty_cols, trained_disp, include_hexanol=True)
    after_counts  = compute_fly_category_counts(A, pretty_cols, trained_disp, include_hexanol=True)

    # Create figure (manual layout)
    fig = plt.figure(figsize=(fig_w, fig_h), constrained_layout=False)
    gs  = gridspec.GridSpec(
        2, 2,
        height_ratios=[3.0, 1.25],
        width_ratios=[1, 1],
        hspace=ROW_GAP,
        wspace=0.10
    )

    axD  = fig.add_subplot(gs[0, 0])   # top-left  (During matrix)
    axA  = fig.add_subplot(gs[0, 1])   # top-right (After matrix)
    axDc = fig.add_subplot(gs[1, 0])   # bottom-left  (During categories)
    axAc = fig.add_subplot(gs[1, 1])   # bottom-right (After categories)

    # Top: matrices — vertical x labels, trained odor in blue
    axD.imshow(D, cmap=cmap, norm=norm, aspect="auto", interpolation="nearest")
    axD.set_title(
        f"{odor_label} — During (shifted +{ODOR_TRANSIT_LAT_S:.2f}s)",
        fontsize=14, weight="bold"
    )
    style_trained_xticks_vertical(axD, pretty_cols, trained_disp, fontsize=xtick_fs)
    axD.set_yticks([])
    axD.set_ylabel(f"{n_flies} Flies", fontsize=11)

    axA.imshow(A, cmap=cmap, norm=norm, aspect="auto", interpolation="nearest")
    axA.set_title(f"{odor_label} — After (first {int(AFTER_WINDOW_SEC)} s)", fontsize=14, weight="bold")
    style_trained_xticks_vertical(axA, pretty_cols, trained_disp, fontsize=xtick_fs)
    axA.set_yticks([])
    axA.set_ylabel(f"{n_flies} Flies", fontsize=11)

    # Bottom: category count bars
    plot_category_counts(axDc, during_counts, n_flies, title="During — Fly Reaction Categories")
    plot_category_counts(axAc, after_counts,  n_flies, title=f"After (first {int(AFTER_WINDOW_SEC)} s) — Fly Reaction Categories")

    # Lower the bottom row by BOTTOM_SHIFT_IN (inches), never below y0 = 0.05
    shift_frac = BOTTOM_SHIFT_IN / fig_h
    for ax in (axDc, axAc):
        pos = ax.get_position()
        new_y0 = max(0.05, pos.y0 - shift_frac)
        ax.set_position([pos.x0, new_y0, pos.width, pos.height])

    # Save — into per-odor folder
    out_png = odir / f"reaction_matrix_{odor.replace(' ', '_')}_{AFTER_WINDOW_SEC}_latency_{ODOR_TRANSIT_LAT_S:.3f}s_unordered.png"
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] saved {out_png}")

    # Row index → fly key — into per-odor folder
    key_path = odir / f"row_key_{odor.replace(' ', '_')}_{AFTER_WINDOW_SEC}.txt"
    with key_path.open("w") as fh:
        for i, fly in enumerate(flies):
            fh.write(f"Row {i}: {fly}\n")
    print(f"[OK] saved {key_path}")

    # ───────── CSV per odor with actual odor names — into per-odor folder ─────────
    sub_for_csv = sub.copy()
    sub_for_csv["odor_sent"] = sub_for_csv["trial"].apply(lambda t: display_odor_for_trial(odor, t))
    # Rows follow the matrix column order; trials absent from it sort last.
    order_map = {t: i for i, t in enumerate(trials)}
    sub_for_csv["trial_ord"] = sub_for_csv["trial"].map(order_map).fillna(10**9).astype(int)
    sub_for_csv = sub_for_csv.sort_values(["fly", "trial_ord", "trial_num", "trial"])

    export_cols = ["dataset", "fly", "trial_num", "odor_sent", "during_hit", "after_hit"]
    out_csv = odir / f"binary_reactions_{odor.replace(' ', '_')}.csv"
    sub_for_csv[export_cols].to_csv(out_csv, index=False)
    print(f"[OK] saved {out_csv}")

print("[DONE] Per-odor exports saved into subfolders under OUT_DIR.)")

## Envelope

In [None]:
# JUPYTER CELL — Per-fly envelope plots from MATRIX with trained-odor styling
# (after-period limited to 30 s) + per-trial threshold line

from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt

# ========= PARAMETERS =========
# Inputs for this cell: the envelope matrix and the JSON file that carries its
# column order and code maps.
MATRIX_NPY        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/envelope_matrix_float16.npy")
CODES_JSON        = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/code_maps.json")

# Timing constants, all in seconds except FPS_DEFAULT (frames per second).
FPS_DEFAULT       = 40.0       # fallback if fps missing/invalid
ODOR_ON_S         = 30.0       # nominal valve-on time; also the end of the baseline window
ODOR_OFF_S        = 60.0       # nominal valve-off time
AFTER_SHOW_S      = 30.0       # show only first 30 s after odor OFF

# Threshold params — matches your scoring code design (baseline is [0, ODOR_ON_S))
THRESH_STD_MULT   = 4.0        # θ = μ_before + k·σ

# Odor transit latency (mean time for plume to reach the fly)
# NOTE(review): `overall_mean_latency_s` is not defined in this cell; it must
# already exist in the notebook namespace or this line raises NameError.
ODOR_TRANSIT_LAT_S = overall_mean_latency_s

# IMPORTANT: extend visible window so AFTER is measured from shifted OFF
X_MAX_LIMIT       = ODOR_OFF_S + ODOR_TRANSIT_LAT_S + AFTER_SHOW_S

# Root folder for the figures; created (with parents) if it does not exist yet.
OUT_DIR = Path("/home/ramanlab/Documents/cole/Results/Opto/Envlope_DISTxANGLE")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ========= LOAD MATRIX + METADATA =========
# Builds the module-level `df` (one row per trial) and `env_cols` (the envelope
# sample columns) that the helpers and the plotting loop below read.
matrix = np.load(MATRIX_NPY, allow_pickle=False)
with open(CODES_JSON, "r") as f:
    meta = json.load(f)

ordered_cols = meta["column_order"]
code_maps    = meta["code_maps"]
# Invert every code map so that a stored value can be mapped back to its key.
rev_maps     = {c: {v:k for k, v in m.items()} for c, m in code_maps.items()}

# Label columns that are present in this matrix; "fps" is metadata too but is
# handled separately because it is numeric.
decode_cols = [c for c in ["dataset","fly","trial_type","trial_label"] if c in ordered_cols]
meta_cols   = decode_cols + (["fps"] if "fps" in ordered_cols else [])

df = pd.DataFrame(matrix, columns=ordered_cols)

# Decode labels -> strings
# Codes with no entry in the reverse map become the literal "UNKNOWN".
for c in decode_cols:
    df[c] = df[c].astype(int).map(rev_maps[c]).fillna("UNKNOWN")

# Ensure fps exists and is numeric
if "fps" in df.columns:
    if "fps" in rev_maps:  # if coded (rare)
        df["fps"] = df["fps"].astype(int).map(rev_maps["fps"])
    df["fps"] = pd.to_numeric(df["fps"], errors="coerce")
else:
    df["fps"] = np.nan

# Keep only testing trials
df = df[df["trial_type"].str.lower()=="testing"].copy()

# Fill missing/invalid fps with fallback
# (±inf is turned into NaN first so that it is replaced by FPS_DEFAULT as well)
df["fps"] = df["fps"].replace([np.inf, -np.inf], np.nan).fillna(FPS_DEFAULT)

# env_* columns exclude meta (including fps)
env_cols  = [c for c in ordered_cols if c not in meta_cols]

# ========= Canon keys & display names =========
# ODOR_CANON maps a *lower-cased, stripped* dataset name to its canonical key.
# `_canon_dataset` lower-cases its argument before the lookup, so every key here
# must be lower-case; mixed-case keys can never match.
ODOR_CANON = {
    "acv": "ACV",
    "apple cider vinegar": "ACV",
    "apple-cider-vinegar": "ACV",
    "3-octonol": "3-octonol",
    "3 octonol": "3-octonol",
    "3-octanol": "3-octonol",
    "3 octanol": "3-octonol",
    "benz": "Benz",
    "benzaldehyde": "Benz",
    "benz-ald": "Benz",
    "benzadhyde": "Benz",
    "ethyl butyrate": "EB",
    # NOTE(review): this key used to be listed twice ("opto_benz", then
    # "opto_benz_1"). A dict literal keeps the last value, so "opto_benz_1" is
    # what was effectively configured — confirm that this is the intended key.
    "optogenetics benzaldehyde": "opto_benz_1",
    "optogenetics ethyl butyrate": "opto_EB",
}
# DISPLAY_LABEL maps a canonical dataset key to the odor name shown in titles.
DISPLAY_LABEL = {
    "ACV": "ACV",
    "3-octonol": "3-Octonol",
    "Benz": "Benzaldehyde",
    "10s_Odor_Benz": "Benzaldehyde",
    "EB": "Ethyl Butyrate",
    "opto_benz": "Benzaldehyde",
    "opto_benz_1": "Benzaldehyde",
    "opto_EB": "Ethyl Butyrate",
}

def _canon_dataset(s: str) -> str:
    """Return the canonical dataset key for a raw dataset name.

    The name is stripped and looked up case-insensitively in ODOR_CANON; names
    without an entry are returned stripped but otherwise unchanged. Anything
    that is not a string yields "UNKNOWN".
    """
    if isinstance(s, str):
        trimmed = s.strip()
        return ODOR_CANON.get(trimmed.lower(), trimmed)
    return "UNKNOWN"

# Add a column holding the canonical dataset key of every row.
df["dataset_canon"] = df["dataset"].apply(_canon_dataset)

# helper: safe dir name
def _safe_dirname(s: str) -> str:
    """Turn an arbitrary label into a string that is safe to use as a folder name.

    Every run of characters other than ASCII letters, digits, '.', '_' and '-'
    is replaced by a single underscore; leading and trailing underscores are
    then removed.
    """
    text = str(s)
    collapsed = re.sub(r'[^A-Za-z0-9._-]+', '_', text)
    return collapsed.strip('_')

# ========= Helpers =========
def _trial_num(label: str) -> int:
    """Return the first run of digits in ``label`` as an int, or -1 if there is none."""
    found = re.search(r"(\d+)", str(label))
    if found is None:
        return -1
    return int(found.group(1))

def display_odor_for_trial(dataset_canon: str, trial_label: str) -> str:
    """Return the name of the odor presented in a given testing trial.

    The trial number is the first integer found in ``trial_label``:

    * trials 1 and 3 are the controls and are always "Hexanol";
    * trials 2, 4 and 5 present the trained odor, i.e. the display label of
      ``dataset_canon`` (or ``dataset_canon`` itself when it has no label);
    * later trials depend on the dataset and are looked up in the table below.

    When neither the dataset nor the trial number is covered, ``trial_label``
    is returned unchanged.
    """
    trial_no = _trial_num(trial_label)

    if trial_no == 1 or trial_no == 3:
        return "Hexanol"
    if trial_no in {2, 4, 5}:
        return DISPLAY_LABEL.get(dataset_canon, dataset_canon)

    # Odor sequences shared by more than one dataset.
    acv_sequence = {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"}
    eb_sequence = {
        6: "Apple Cider Vinegar",
        7: "3-Octonol",
        8: "Benzaldehyde",
        9: "Citral",
        10: "Linalool",
    }
    later_trials = {
        "ACV": acv_sequence,
        "3-octonol": {6: "Benzaldehyde", 7: "Citral", 8: "Linalool"},
        "Benz": {6: "Citral", 7: "Linalool"},
        "EB": eb_sequence,
        "10s_Odor_Benz": {6: "Benzaldehyde", 7: "Benzaldehyde"},
        "opto_EB": eb_sequence,
        "opto_benz": acv_sequence,
        "opto_benz_1": {
            6: "Apple Cider Vinegar",
            7: "3-Octonol",
            8: "Ethyl Butyrate",
            9: "Citral",
            10: "Linalool",
        },
    }
    return later_trials.get(dataset_canon, {}).get(trial_no, trial_label)

def _extract_env(row: pd.Series) -> np.ndarray:
    """Return the envelope samples of one trial row as a float array.

    Reads the columns listed in the module-level ``env_cols`` and keeps only the
    values that are finite and strictly positive.
    """
    # NOTE(review): rejected values are dropped wherever they occur, so a NaN or
    # non-positive sample in the middle of a trace shortens the array —
    # presumably such values only appear as trailing padding; confirm.
    samples = row[env_cols].to_numpy(dtype=float)
    usable = np.isfinite(samples) & (samples > 0)
    return samples[usable]

def _compute_theta(env_full: np.ndarray, fps: float) -> float:
    """Return the response threshold θ = mean(before) + k*std(before).

    ``before`` is the baseline window: the samples of ``env_full`` that fall in
    [0, ODOR_ON_S) at ``fps`` frames per second; k is THRESH_STD_MULT.

    Returns NaN when no threshold can be computed: the envelope is empty, the
    baseline window is empty, or ``fps`` is NaN, infinite, zero or negative.
    """
    # A NaN or infinite fps passes a plain `fps <= 0` test and would then make
    # int() raise (ValueError / OverflowError), so it is rejected explicitly.
    if env_full.size == 0 or not np.isfinite(fps) or fps <= 0:
        return np.nan
    # Number of baseline samples, capped at the length of the trace.
    b_end = min(int(ODOR_ON_S * fps), env_full.size)
    before = env_full[:b_end]
    if before.size == 0:
        return np.nan
    mu = float(np.nanmean(before))
    sd = float(np.nanstd(before))
    return mu + THRESH_STD_MULT * sd

def _is_trained_odor(dataset_canon: str, odor_name: str) -> bool:
    """Tell whether ``odor_name`` is the trained odor of ``dataset_canon``.

    The trained odor is the dataset's display label (or the dataset key itself
    when it has none); the comparison ignores case and surrounding whitespace.
    """
    trained_label = str(DISPLAY_LABEL.get(dataset_canon, dataset_canon))
    candidate = str(odor_name)
    return candidate.strip().lower() == trained_label.strip().lower()

def style_trained_title(ax, odor_label: str):
    """Set the title of ``ax`` to the upper-cased odor label in bold blue, left-aligned."""
    title_style = {
        "loc": "left",
        "fontsize": 11,
        "weight": "bold",
        "pad": 2,
        "color": "tab:blue",
    }
    ax.set_title(odor_label.upper(), **title_style)

# ========= MAKE FIGURES PER FLY =========
# Figure defaults do not depend on the fly, so they are applied once up front
# instead of on every loop iteration.
plt.rcParams.update({
    "figure.dpi": 300, "savefig.dpi": 300,
    "axes.spines.top": False, "axes.spines.right": False,
    "axes.linewidth": 0.8, "xtick.direction": "out", "ytick.direction": "out",
    "font.size": 10,
})

# One figure per fly: one subplot per testing trial (sorted by trial number),
# each showing the envelope, the valve/latency markers and the per-trial
# threshold. The figure is saved into the folder of every odor it contains.
for fly, g in df.groupby("fly"):
    g = g.sort_values("trial_label", key=lambda s: s.map(_trial_num))
    dataset_canon = _canon_dataset(g["dataset"].iloc[0])

    # (odor_name, t_visible, env_visible, theta, is_trained)
    trial_curves = []
    y_max = 0.0  # shared y-limit: largest envelope value or threshold of this fly

    for _, row in g.iterrows():
        env_full = _extract_env(row)
        if env_full.size == 0:
            continue

        row_fps = float(row.get("fps", FPS_DEFAULT)) if np.isfinite(row.get("fps", np.nan)) else FPS_DEFAULT
        t_full = np.arange(env_full.size, dtype=float) / max(row_fps, 1e-9)

        # The threshold uses the full trace, not the clipped one.
        theta = _compute_theta(env_full, row_fps)

        # Clip to [0, X_MAX_LIMIT] for visualization
        mask = (t_full <= X_MAX_LIMIT + 1e-9)
        t = t_full[mask]
        env = env_full[mask]
        if t.size == 0:
            continue

        odor_name = display_odor_for_trial(dataset_canon, row["trial_label"])
        trial_curves.append((odor_name, t, env, theta, _is_trained_odor(dataset_canon, odor_name)))

        local_max = np.nanmax(env) if np.isfinite(env).any() else 0.0
        if np.isfinite(theta):
            local_max = max(local_max, theta)
        y_max = max(y_max, float(local_max))

    if not trial_curves:
        print(f"[WARN] {fly}: no usable testing trials; skipping.")
        continue

    n = len(trial_curves)
    fig_h = max(3.0, n * 1.6 + 1.5)
    fig, axes = plt.subplots(n, 1, figsize=(10, fig_h), sharex=True)
    if n == 1:
        # plt.subplots returns a bare Axes for a single subplot; wrap it so the
        # code below can always iterate.
        axes = [axes]

    for ax, (odor_name, t, env, theta, is_trained) in zip(axes, trial_curves):
        ax.plot(t, env, linewidth=1.2, color='black')

        # Nominal valve timing markers (hardware command times)
        ax.axvline(ODOR_ON_S,  linestyle='--', linewidth=1.0, color='black')
        ax.axvline(ODOR_OFF_S, linestyle='--', linewidth=1.0, color='black')

        # Effective plume windows using latency:
        on_lat_end   = min(ODOR_ON_S  + ODOR_TRANSIT_LAT_S, X_MAX_LIMIT)
        off_lat_end  = min(ODOR_OFF_S + ODOR_TRANSIT_LAT_S, X_MAX_LIMIT)
        eff_on_start = min(on_lat_end, X_MAX_LIMIT)
        eff_on_end   = min(off_lat_end, X_MAX_LIMIT)

        # Shade start latency (red) and effective ON (gray)
        if ODOR_TRANSIT_LAT_S > 0:
            ax.axvspan(ODOR_ON_S, on_lat_end, alpha=0.25, color='red')
        if eff_on_end > eff_on_start:
            ax.axvspan(eff_on_start, eff_on_end, alpha=0.15, color='gray')

        # Shade end latency (red)
        if ODOR_TRANSIT_LAT_S > 0:
            ax.axvspan(ODOR_OFF_S, off_lat_end, alpha=0.25, color='red')

        # Threshold line
        if np.isfinite(theta):
            ax.axhline(theta, linestyle='-', linewidth=1.0, color='tab:red', alpha=0.9)

        ax.set_ylim(0, y_max * 1.02 if y_max > 0 else 1.0)
        ax.set_xlim(0, X_MAX_LIMIT)
        ax.margins(x=0, y=0.02)
        ax.set_ylabel("DIST x ANGLE RMS", fontsize=10)

        if is_trained:
            style_trained_title(ax, odor_name)
        else:
            ax.set_title(odor_name, loc="left", fontsize=11, weight="bold", pad=2, color="black")

    axes[-1].set_xlabel("Time (s)", fontsize=11)

    # SINGLE legend on the figure. Each proxy handle carries the label that is
    # displayed (previously a separate `labels=` list overrode the handle labels).
    on_handle      = plt.Line2D([0], [0], linestyle='--', linewidth=1.0, color='black', label='Valve on/off (command)')
    transit_handle = plt.Rectangle((0,0), 1, 1, alpha=0.25, color='red',  label=f'Odor transit (~{ODOR_TRANSIT_LAT_S:.2f}s) — start & end')
    span_handle    = plt.Rectangle((0,0), 1, 1, alpha=0.15, color='gray', label='Effective odor-on at fly')
    theta_handle   = plt.Line2D([0], [0], linestyle='-',  linewidth=1.0, color='tab:red', label=r'$\theta = \mu_\mathrm{before} + k\,\sigma_\mathrm{before}$')

    # Show k without a trailing ".0" when it is a whole number. float() makes
    # this work for an int constant as well (int.is_integer needs Python 3.12+).
    k_value = float(THRESH_STD_MULT)
    k_text = int(k_value) if k_value.is_integer() else THRESH_STD_MULT

    fig.legend(
        handles=[on_handle, transit_handle, span_handle, theta_handle],
        title=f'Threshold: k = {k_text}',
        loc='upper right',
        bbox_to_anchor=(0.98, 0.97),
        frameon=True,
        fontsize=9,
        title_fontsize=9,
    )

    fig.suptitle(f"{fly} RMS of Proboscis - Eye Distance Percentage", y=0.995, fontsize=14, weight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.97])

    # === SAVE: write this fly's figure into each odor-specific folder observed for this fly ===
    odors_present = sorted({name for (name, _, _, _, _) in trial_curves})
    for odor_name in odors_present:
        odir = OUT_DIR / _safe_dirname(odor_name)
        odir.mkdir(parents=True, exist_ok=True)
        out_png = odir / f"{fly}_envelope_trials_by_odor_{AFTER_SHOW_S}_shifted.png"
        fig.savefig(out_png)
        print(f"[OK] Saved {out_png}")

    plt.close(fig)

## Combined + RMS Alone Envelope

In [None]:
# JUPYTER CELL — Overlay envelopes per fly & per testing-trial across TWO matrices
# Threshold: θ = μ_global(fly, source, all pre-odor) + k * σ_trial(pre-odor)

from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt

# ========= INPUTS (two sources to overlay) =========
# Keys are the source tags used for styling and legends; each value names the
# envelope matrix and the JSON file with its column order and code maps.
SOURCES = {
    "RMS x Angle": {
        "MATRIX_NPY": Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/envelope_matrix_float16.npy"),
        "CODES_JSON": Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto_combined/code_maps.json"),
    },
    "RMS": {
        "MATRIX_NPY": Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/envelope_matrix_float16.npy"),
        "CODES_JSON": Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/code_maps.json"),
    },
}

# ========= PARAMETERS =========
# Times are in seconds; FPS_DEFAULT is the frame rate used when a row has none.
FPS_DEFAULT         = 40.0
ODOR_ON_S           = 30.0
ODOR_OFF_S          = 60.0
AFTER_SHOW_S        = 30.0
THRESH_STD_MULT     = 4.0  # k

# If overall_mean_latency_s isn't in scope, default to 0.0
# (the same fallback applies when it cannot be converted to a finite float)
try:
    ODOR_TRANSIT_LAT_S = float(overall_mean_latency_s)
    if not np.isfinite(ODOR_TRANSIT_LAT_S): ODOR_TRANSIT_LAT_S = 0.0
except Exception:
    ODOR_TRANSIT_LAT_S = 0.0

# Right edge of the plotted window: valve-off time + latency + the shown "after" span.
X_MAX_LIMIT         = ODOR_OFF_S + ODOR_TRANSIT_LAT_S + AFTER_SHOW_S

# Root folder for the overlay figures; created (with parents) if missing.
OUT_DIR = Path("/home/ramanlab/Documents/cole/Results/Manual/Compare_Envlopes")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ========= CANONICALIZATION =========
# ODOR_CANON maps a lower-cased, stripped dataset name to its canonical key
# (`_canon_dataset` lower-cases its argument before the lookup).
ODOR_CANON = {
    "acv": "ACV",
    "apple cider vinegar": "ACV",
    "apple-cider-vinegar": "ACV",
    "3-octonol": "3-octonol",
    "3 octonol": "3-octonol",
    "3-octanol": "3-octonol",
    "3 octanol": "3-octonol",
    "benz": "Benz",
    "benzaldehyde": "Benz",
    "benz-ald": "Benz",
    "benzadhyde": "Benz",
    "ethyl butyrate": "EB",
    # NOTE(review): this key used to be listed twice ("opto_benz", then
    # "opto_benz_1"). A dict literal keeps the last value, so the dead first
    # entry was removed and the effective mapping is unchanged — confirm that
    # "opto_benz_1" is the intended key.
    "optogenetics benzaldehyde": "opto_benz_1",
    "optogenetics ethyl butyrate": "opto_EB",
}
# DISPLAY_LABEL maps a canonical dataset key to the odor name shown in titles.
DISPLAY_LABEL = {
    "ACV": "ACV",
    "3-octonol": "3-Octonol",
    "Benz": "Benzaldehyde",
    "10s_Odor_Benz": "Benzaldehyde",
    "EB": "Ethyl Butyrate",
    "ret_EB": "Ethyl Butyrate",
    "opto_benz": "Benzaldehyde",
    "opto_benz_1": "Benzaldehyde",
    "opto_EB": "Ethyl Butyrate",
}

def _canon_dataset(s: str) -> str:
    """Map a raw dataset name to its canonical key.

    Non-string input gives "UNKNOWN". Strings are stripped and looked up in
    ODOR_CANON by their lower-cased form; unmapped names come back stripped.
    """
    if not isinstance(s, str):
        return "UNKNOWN"
    name = s.strip()
    lookup_key = name.lower()
    return ODOR_CANON.get(lookup_key, name)

# NEW: safe folder names
def _safe_dirname(s: str) -> str:
    """Return ``s`` reduced to characters that are safe in a folder name.

    Runs of anything other than ASCII letters, digits, '.', '_' or '-' become
    one underscore, and underscores at either end are trimmed.
    """
    replaced = re.sub(r'[^A-Za-z0-9._-]+', '_', str(s))
    trimmed = replaced.strip('_')
    return trimmed

def _trial_num(label: str) -> int:
    """Extract the first integer that appears in ``label``; -1 when it has no digits."""
    hit = re.search(r"(\d+)", str(label))
    if hit:
        return int(hit.group(1))
    return -1

def display_odor_for_trial(dataset_canon: str, trial_label: str) -> str:
    """Return the name of the odor presented in a given testing trial.

    The trial number is the first integer found in ``trial_label``:

    * trials 1 and 3 are the controls and are always "Hexanol";
    * trials 2, 4 and 5 present the trained odor, i.e. the display label of
      ``dataset_canon`` (or ``dataset_canon`` itself when it has no label);
    * later trials depend on the dataset and are looked up in the table below.

    When neither the dataset nor the trial number is covered, ``trial_label``
    is returned unchanged.
    """
    trial_no = _trial_num(trial_label)

    if trial_no == 1 or trial_no == 3:
        return "Hexanol"
    if trial_no in {2, 4, 5}:
        return DISPLAY_LABEL.get(dataset_canon, dataset_canon)

    # Odor sequences shared by more than one dataset.
    acv_sequence = {6: "3-Octonol", 7: "Benzaldehyde", 8: "Citral", 9: "Linalool"}
    eb_sequence = {
        6: "Apple Cider Vinegar",
        7: "3-Octonol",
        8: "Benzaldehyde",
        9: "Citral",
        10: "Linalool",
    }
    later_trials = {
        "ACV": acv_sequence,
        "3-octonol": {6: "Benzaldehyde", 7: "Citral", 8: "Linalool"},
        "Benz": {6: "Citral", 7: "Linalool"},
        "EB": eb_sequence,
        "ret_EB": eb_sequence,
        "10s_Odor_Benz": {6: "Benzaldehyde", 7: "Benzaldehyde"},
        "opto_EB": eb_sequence,
        "opto_benz": acv_sequence,
        "opto_benz_1": {
            6: "Apple Cider Vinegar",
            7: "3-Octonol",
            8: "Ethyl Butyrate",
            9: "Citral",
            10: "Linalool",
        },
    }
    return later_trials.get(dataset_canon, {}).get(trial_no, trial_label)

# ========= LOADING HELPERS =========
def load_source_df(tag: str, paths: dict):
    """Load one envelope matrix and return its testing trials as a DataFrame.

    Parameters
    ----------
    tag : label stored in the ``_source`` column of every returned row.
    paths : mapping with "MATRIX_NPY" (the .npy matrix) and "CODES_JSON" (a JSON
        file providing "column_order" and "code_maps").

    Returns
    -------
    DataFrame restricted to rows whose ``trial_type`` is "testing"
    (case-insensitive), with label columns decoded to strings, a finite ``fps``
    column, and three added columns: ``dataset_canon``, ``_env_cols`` (the list
    of envelope column names, repeated on every row) and ``_source``.
    """
    raw = np.load(paths["MATRIX_NPY"], allow_pickle=False)
    with open(paths["CODES_JSON"], "r") as fh:
        sidecar = json.load(fh)

    columns = sidecar["column_order"]
    # Invert every code map so that stored values can be mapped back to their keys.
    decoders = {
        col: {code: text for text, code in mapping.items()}
        for col, mapping in sidecar["code_maps"].items()
    }

    label_cols = [col for col in ("dataset", "fly", "trial_type", "trial_label") if col in columns]
    non_env_cols = list(label_cols)
    if "fps" in columns:
        non_env_cols.append("fps")

    frame = pd.DataFrame(raw, columns=columns)

    # Decode label columns; codes without a reverse-map entry become "UNKNOWN".
    for col in label_cols:
        codes = frame[col].astype(int)
        frame[col] = codes.map(decoders[col]).fillna("UNKNOWN")

    if "fps" not in frame.columns:
        frame["fps"] = np.nan
    else:
        if "fps" in decoders:  # fps is only rarely stored as a code
            frame["fps"] = frame["fps"].astype(int).map(decoders["fps"])
        frame["fps"] = pd.to_numeric(frame["fps"], errors="coerce")

    # Testing trials only; missing or infinite fps falls back to FPS_DEFAULT.
    is_testing = frame["trial_type"].str.lower() == "testing"
    frame = frame[is_testing].copy()
    frame["fps"] = frame["fps"].replace([np.inf, -np.inf], np.nan).fillna(FPS_DEFAULT)

    frame["dataset_canon"] = frame["dataset"].apply(_canon_dataset)
    envelope_cols = [col for col in columns if col not in non_env_cols]

    # Each row carries its own envelope column list so that rows from different
    # sources can be extracted after concatenation.
    frame["_env_cols"] = [envelope_cols] * len(frame)
    frame["_source"] = tag
    return frame

def _extract_env(row: pd.Series) -> np.ndarray:
    """Return the finite, strictly positive envelope samples of one trial row.

    The columns to read are taken from the row's own ``_env_cols`` entry; a row
    without that entry yields an empty array.
    """
    columns = row.get("_env_cols", [])
    samples = row[columns].to_numpy(dtype=float)
    usable = np.isfinite(samples) & (samples > 0)
    return samples[usable]

def _trial_baseline(env_full: np.ndarray, fps: float) -> np.ndarray:
    """Return the pre-odor samples of a trial, i.e. those in [0, ODOR_ON_S).

    An empty float array is returned when the envelope is empty or ``fps`` is
    not a finite positive number.
    """
    usable = env_full.size != 0 and np.isfinite(fps) and fps > 0
    if not usable:
        return np.array([], dtype=float)
    # Number of baseline samples, capped at the length of the trace.
    stop = min(int(ODOR_ON_S * fps), env_full.size)
    return env_full[:stop]

# ========= LOAD BOTH SOURCES =========
# Best effort: a source that fails to load is reported and skipped; the cell
# only aborts when no source could be loaded at all.
dfs = []
for tag, p in SOURCES.items():
    try:
        dfs.append(load_source_df(tag, p))
    except Exception as e:
        print(f"[WARN] Failed to load {tag}: {e}")

if not dfs:
    raise RuntimeError("No sources loaded.")

# One combined table; the "_source" column tells the rows of each source apart.
all_df = pd.concat(dfs, ignore_index=True)
# Fly identifiers sorted by their string form.
all_flies = sorted(all_df["fly"].unique(), key=lambda x: str(x))

# ========= PLOTTING STYLES PER SOURCE =========
# Keys must match the tags used in SOURCES.
SOURCE_STYLES = {
    "RMS x Angle": dict(color="tab:blue",   label="RMS x Angle"),
    "RMS":          dict(color="tab:orange", label="RMS"),
}
# Used for a source tag that has no entry in SOURCE_STYLES.
default_style = dict(color=None, label=None)

# Global matplotlib defaults for every figure created below.
plt.rcParams.update({
    "figure.dpi": 300, "savefig.dpi": 300,
    "axes.spines.top": False, "axes.spines.right": False,
    "axes.linewidth": 0.8, "xtick.direction": "out", "ytick.direction": "out",
    "font.size": 10,
})

# ========= BUILD & PLOT =========
# One overlay figure per fly: one subplot per testing trial label, one curve
# (plus its threshold line) per source. The figure is saved into the folder of
# every odor that appears in it.
for fly in all_flies:
    df_fly = all_df[all_df["fly"] == fly].copy()
    if df_fly.empty:
        continue

    # ---- compute μ_global per (fly, source) from ALL trials' pre-odor samples
    global_mu_by_source = {}
    for src, gsrc in df_fly.groupby("_source"):
        pooled = []
        for _, row in gsrc.iterrows():
            env_full = _extract_env(row)
            fps = float(row.get("fps", FPS_DEFAULT))
            pre = _trial_baseline(env_full, fps)
            if pre.size:
                pooled.append(pre[np.isfinite(pre)])
        if pooled:
            pooled_all = np.concatenate(pooled)
            pooled_all = pooled_all[np.isfinite(pooled_all)]
            if pooled_all.size:
                global_mu_by_source[src] = float(np.mean(pooled_all))
                continue
        global_mu_by_source[src] = np.nan  # fallback later

    # keys: trial_label -> list of (source, t, env, theta, odor_name, dataset_canon, is_trained)
    trials = {}
    y_max  = 0.0  # shared y-limit: largest envelope value or threshold of this fly

    for _, row in df_fly.iterrows():
        env_full = _extract_env(row)
        if env_full.size == 0:
            continue

        fps = float(row.get("fps", FPS_DEFAULT))
        t_full = np.arange(env_full.size, dtype=float) / max(fps, 1e-9)

        # clip visual window
        mask = (t_full <= X_MAX_LIMIT + 1e-9)
        t    = t_full[mask]
        env  = env_full[mask]
        if t.size == 0:
            continue

        # Trial-specific σ from this trial's pre-odor; μ is global per source
        pre_this = _trial_baseline(env_full, fps)
        if pre_this.size:
            sigma_trial = float(np.nanstd(pre_this))
            mu_trial    = float(np.nanmean(pre_this))  # for fallback only
        else:
            sigma_trial = np.nan
            mu_trial    = np.nan

        src = row["_source"]
        mu_global = global_mu_by_source.get(src, np.nan)
        if not np.isfinite(mu_global):
            mu_global = mu_trial

        theta = (mu_global + THRESH_STD_MULT * sigma_trial) if (np.isfinite(mu_global) and np.isfinite(sigma_trial)) else np.nan

        dsc        = str(row.get("dataset_canon", "UNKNOWN"))
        trial_lab  = str(row.get("trial_label", "UNKNOWN"))
        odor_name  = display_odor_for_trial(dsc, trial_lab)
        is_trained = str(odor_name).strip().lower() == str(DISPLAY_LABEL.get(dsc, dsc)).strip().lower()

        trials.setdefault(trial_lab, []).append(
            (src, t, env, theta, odor_name, dsc, is_trained)
        )

        local_max = np.nanmax(env) if np.isfinite(env).any() else 0.0
        if np.isfinite(theta):
            local_max = max(local_max, theta)
        y_max = max(y_max, float(local_max))

    if not trials:
        print(f"[WARN] {fly}: no usable testing trials across sources; skipping.")
        continue

    # stable trial order
    trial_labels_sorted = sorted(trials.keys(), key=_trial_num)
    n = len(trial_labels_sorted)
    fig_h = max(3.0, n * 1.6 + 1.5)
    fig, axes = plt.subplots(n, 1, figsize=(10, fig_h), sharex=True)
    if n == 1:
        # plt.subplots returns a bare Axes for a single subplot; wrap it so the
        # code below can always iterate.
        axes = [axes]

    for ax, trial_lab in zip(axes, trial_labels_sorted):
        curves = trials[trial_lab]

        # Title odor from first curve
        odor_name  = curves[0][4]
        is_trained = curves[0][6]

        # Valve command lines
        ax.axvline(ODOR_ON_S,  linestyle='--', linewidth=1.0, color='black')
        ax.axvline(ODOR_OFF_S, linestyle='--', linewidth=1.0, color='black')

        # Latency shading
        on_lat_end   = min(ODOR_ON_S  + ODOR_TRANSIT_LAT_S, X_MAX_LIMIT)
        off_lat_end  = min(ODOR_OFF_S + ODOR_TRANSIT_LAT_S, X_MAX_LIMIT)
        eff_on_start = min(on_lat_end, X_MAX_LIMIT)
        eff_on_end   = min(off_lat_end, X_MAX_LIMIT)
        if ODOR_TRANSIT_LAT_S > 0:
            ax.axvspan(ODOR_ON_S, on_lat_end, alpha=0.25, color='red')
        if eff_on_end > eff_on_start:
            ax.axvspan(eff_on_start, eff_on_end, alpha=0.15, color='gray')
        if ODOR_TRANSIT_LAT_S > 0:
            ax.axvspan(ODOR_OFF_S, off_lat_end, alpha=0.25, color='red')

        # Overlay curves from each source on this trial
        for (src, t, env, theta, _odor, _dsc, _is_trained) in curves:
            st = SOURCE_STYLES.get(src, default_style)
            line = ax.plot(t, env, linewidth=1.3, **{k:v for k,v in st.items() if v is not None})
            # Per-trace threshold (global μ + trial σ), drawn in the curve's color
            if np.isfinite(theta):
                ax.axhline(theta, linestyle=':', linewidth=1.0, color=line[0].get_color(), alpha=0.9)

        ax.set_ylim(0, y_max * 1.02 if y_max > 0 else 1.0)
        ax.set_xlim(0, X_MAX_LIMIT)
        ax.margins(x=0, y=0.02)
        ax.set_ylabel("DIST or DISTxANGLE", fontsize=10)

        # Trials that present the trained odor get a blue title.
        title_color = "tab:blue" if is_trained else "black"
        ax.set_title(f"{odor_name} — {trial_lab}", loc="left", fontsize=11, weight="bold", pad=2, color=title_color)

    axes[-1].set_xlabel("Time (s)", fontsize=11)

    # Legend (proxy handles; one entry per configured source plus the markers)
    src_handles = [plt.Line2D([0],[0], linewidth=1.3, color=SOURCE_STYLES[s]['color'], label=SOURCE_STYLES[s]['label'])
                   for s in SOURCES.keys() if s in SOURCE_STYLES]
    on_handle      = plt.Line2D([0], [0], linestyle='--', linewidth=1.0, color='black', label='Valve on/off (command)')
    transit_handle = plt.Rectangle((0,0), 1, 1, alpha=0.25, color='red',  label=f'Odor transit (~{ODOR_TRANSIT_LAT_S:.2f}s)')
    span_handle    = plt.Rectangle((0,0), 1, 1, alpha=0.15, color='gray', label='Effective odor-on at fly')
    theta_handle   = plt.Line2D([0], [0], linestyle=':',  linewidth=1.0, color='black', label=r'$\theta=\mu_{\mathrm{global}}+k\sigma_{\mathrm{trial}}$')

    # The legend is attached to the figure created above for this fly; `fig` is
    # used directly rather than being re-fetched through plt.gcf().
    fig.legend(
        handles=src_handles + [on_handle, transit_handle, span_handle, theta_handle],
        loc='upper right',
        bbox_to_anchor=(0.98, 0.97),
        frameon=True,
        fontsize=9,
        title=f'k = {int(THRESH_STD_MULT) if float(THRESH_STD_MULT).is_integer() else THRESH_STD_MULT}',
        title_fontsize=9,
    )

    fig.suptitle(f"{fly} — Envelope overlay by testing trial (global μ per source, σ per trial)", y=0.995, fontsize=14, weight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.97])

    # === SAVE: write this fly's overlay into each odor-specific folder observed for this fly ===
    odors_present = sorted({curves[0][4] for curves in trials.values()})
    for odor_name in odors_present:
        odir = OUT_DIR / _safe_dirname(odor_name)
        odir.mkdir(parents=True, exist_ok=True)
        out_png = odir / f"{fly}_overlay_envelope_by_trial_{AFTER_SHOW_S}s_shifted.png"
        fig.savefig(out_png)
        print(f"[OK] Saved {out_png}")

    plt.close(fig)


## Training

# Special Plots

In [None]:
# JUPYTER CELL — RMS trace + per-trial heatmap (only the testing trials listed in KEEP_TEST_NUMS; hard 30–60s ON; no latency)
from pathlib import Path
import numpy as np
import pandas as pd
import json, re
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib as mpl
# Font and background settings for saved PDF/PS output.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype']  = 42
mpl.rcParams['savefig.transparent'] = False  # ensure opaque background

FILL_GRAY = "#e6e6e6"  # light gray, opaque (no alpha)
RASTERIZE_HEATMAPS = True  # keep heatmaps raster in vector outputs (clean EPS)

# ========= PARAMETERS =========
# Input matrix, its JSON metadata, and the output folder (created if missing).
MATRIX_NPY  = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/envelope_matrix_float16.npy")
CODES_JSON  = Path("/home/ramanlab/Documents/cole/Data/single_matrix_opto/code_maps.json")
OUT_DIR     = Path("/home/ramanlab/Documents/cole/Results/rms_trace_plus_heatmaps_opto")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Times are in seconds; FPS_DEFAULT is the frame rate used when a row has none.
FPS_DEFAULT = 40.0
ODOR_ON_S   = 30.0
ODOR_OFF_S  = 60.0
AFTER_S     = 30.0                 # show only first 30 s after OFF
X_MAX_LIMIT = ODOR_OFF_S + AFTER_S # = 90 s

THRESH_STD_MULT = 4.0              # θ = μ_before + k·σ
# NOTE(review): comments elsewhere in this cell describe the selection as
# "2/4/5/8", but the set below also contains 6 — confirm which is intended.
KEEP_TEST_NUMS  = {2, 4, 5, 6, 8}     # testing trials to include

# ---- Visual knobs ----
HM_ONOFF_LS    = "--"   # heatmap odor on/off line style
HM_ONOFF_LW    = 1.5    # heatmap odor on/off line width
TRACE_ONOFF_LW = 1.5    # trace odor on/off line width

# ---- Figure / colorbar layout knobs ----
FIG_W            = 16.0  # wider figure
CBAR_COL_RATIO   = 0.04  # dedicated colorbar column width (relative to plot column)
HSPACE           = 0.60  # vertical spacing between rows
WSPACE           = 0.125  # spacing between plot column and colorbar
TITLE_Y          = 0.95  # extra gap above first axes (used with constrained_layout)

# ========= LOAD MATRIX + METADATA =========
# Builds the module-level `df` (one row per trial) and `env_cols` (the envelope
# sample columns) that the helpers and the plotting loop below read.
matrix = np.load(MATRIX_NPY, allow_pickle=False)
with open(CODES_JSON, "r") as f:
    meta = json.load(f)

ordered_cols = meta["column_order"]
code_maps    = meta["code_maps"]
# Invert every code map so that a stored value can be mapped back to its key.
rev_maps     = {c: {v: k for k, v in m.items()} for c, m in code_maps.items()}

# Label columns that are present in this matrix; "fps" is metadata too but is
# handled separately because it is numeric.
decode_cols = [c for c in ["dataset","fly","trial_type","trial_label"] if c in ordered_cols]
meta_cols   = decode_cols + (["fps"] if "fps" in ordered_cols else [])

df = pd.DataFrame(matrix, columns=ordered_cols)

# Decode categorical codes (if present)
# Codes without a reverse-map entry fall back to the string form of the stored value.
for c in decode_cols:
    if df[c].dtype != object:
        df[c] = df[c].astype(int).map(rev_maps.get(c, {})).fillna(df[c].astype(str))

# FPS handling
# Missing or unparseable fps becomes FPS_DEFAULT.
# NOTE(review): unlike the NaN case, an infinite fps is not replaced here.
if "fps" in df.columns:
    if "fps" in rev_maps:  # rare coded fps
        df["fps"] = df["fps"].astype(int).map(rev_maps["fps"])
    df["fps"] = pd.to_numeric(df["fps"], errors="coerce").fillna(FPS_DEFAULT)
else:
    df["fps"] = FPS_DEFAULT

# Keep only testing trials
df = df[df["trial_type"].str.lower() == "testing"].copy()

# Envelope cols (all non-meta)
env_cols = [c for c in ordered_cols if c not in meta_cols]

# ========= ODOR CANON + LABELS =========
# ODOR_CANON: lower-cased dataset name -> canonical dataset key. It is built
# from (canonical key, accepted spellings) groups so that every alias of a
# dataset sits next to its key.
ODOR_CANON = {
    alias: canon
    for canon, aliases in (
        ("ACV", ("acv", "apple cider vinegar", "apple-cider-vinegar")),
        ("3-octonol", ("3-octonol", "3 octonol", "3-octanol", "3 octanol")),
        ("Benz", ("benz", "benzaldehyde", "benz-ald", "benzadhyde")),
        ("EB", ("ethyl butyrate",)),
        ("ret_EB", ("ret_eb",)),
        ("10s_Odor_Benz", ("10s_odor_benz",)),
        ("opto_EB", ("opto eb", "opto_eb")),
        ("opto_benz", ("opto benz", "opto_benz")),
    )
    for alias in aliases
}
# DISPLAY_LABEL: canonical dataset key -> odor name shown in figures.
DISPLAY_LABEL = dict(
    (
        ("ACV", "ACV"),
        ("3-octonol", "3-Octonol"),
        ("Benz", "Benzaldehyde"),
        ("EB", "Ethyl Butyrate"),
        ("ret_EB", "Ethyl Butyrate (ret.)"),
        ("10s_Odor_Benz", "Benzaldehyde"),
        ("opto_benz", "Benzaldehyde"),
        ("opto_EB", "Ethyl Butyrate"),
    )
)

def _canon_dataset(s: str) -> str:
    """Return the canonical dataset key for a raw dataset name.

    Strings are stripped and matched case-insensitively against ODOR_CANON;
    unmatched names are returned stripped. Non-strings give "UNKNOWN".
    """
    if not isinstance(s, str):
        return "UNKNOWN"
    trimmed = s.strip()
    return ODOR_CANON.get(trimmed.lower(), trimmed)

# ========= HELPERS =========
def _trial_num(label: str) -> int:
    """Return the first integer embedded in ``label``, or -1 when there is none."""
    digits = re.search(r"(\d+)", str(label))
    return -1 if digits is None else int(digits.group(1))

def _extract_env(row: pd.Series) -> np.ndarray:
    """Return one trial's envelope as a float array of finite, strictly positive samples.

    The samples are read from the columns named in the module-level ``env_cols``.
    """
    values = row[env_cols].to_numpy(dtype=float)
    is_valid = np.isfinite(values) & (values > 0)
    return values[is_valid]

def _compute_theta(env_full: np.ndarray, fps: float) -> float:
    """Return the response threshold θ = mean(before) + k*std(before).

    ``before`` is the baseline window: the samples of ``env_full`` that fall in
    [0, ODOR_ON_S) at ``fps`` frames per second; k is THRESH_STD_MULT.

    Returns NaN when no threshold can be computed: the envelope is empty, the
    baseline window is empty, or ``fps`` is NaN, infinite, zero or negative.
    """
    # The fps column of this cell is only NaN-filled, so an infinite value can
    # still arrive here; without this guard int() would raise OverflowError
    # (or ValueError for NaN).
    if env_full.size == 0 or not np.isfinite(fps) or fps <= 0:
        return np.nan
    # Number of baseline samples, capped at the length of the trace.
    b_end = min(int(ODOR_ON_S * fps), env_full.size)
    before = env_full[:b_end]
    if before.size == 0:
        return np.nan
    mu = float(np.nanmean(before))
    sd = float(np.nanstd(before))
    return mu + THRESH_STD_MULT * sd

# --- Odor naming per dataset + trial (ADDED BACK) ---
# Odor shown for trials numbered 6 and above, keyed by canonical dataset label
# and then by trial number. Combinations absent from this table fall back to
# the raw trial label.
_LATE_TRIAL_ODORS = {
    "ACV": {
        6: "3-Octonol",
        7: "Benzaldehyde",
        8: "Citral",
        9: "Linalool",
    },
    "3-octonol": {
        6: "Benzaldehyde",
        7: "Citral",
        8: "Linalool",
    },
    "Benz": {
        6: "Citral",
        7: "Linalool",
    },
    "EB": {
        6: "Apple Cider Vinegar",
        7: "3-Octonol",
        8: "Benzaldehyde",
        9: "Citral",
        10: "Linalool",
    },
    "ret_EB": {
        6: "Apple Cider Vinegar",
        7: "3-Octonol",
        8: "Benzaldehyde",
        9: "Citral",
        10: "Linalool",
    },
    "10s_Odor_Benz": {
        6: "Benzaldehyde",
        7: "Benzaldehyde",
    },
    "opto_EB": {
        6: "Apple Cider Vinegar",
        7: "3-Octonol",
        8: "Benzaldehyde",
        9: "Citral",
        10: "Linalool",
    },
    "opto_benz": {
        6: "3-Octonol",
        7: "Benzaldehyde",
        8: "Citral",
        9: "Linalool",
    },
}


def display_odor_for_trial(dataset_canon: str, trial_label: str) -> str:
    """Return the odor name to display for one trial of a dataset.

    Trials 1 and 3 are the "Hexanol" controls; trials 2, 4 and 5 show the
    dataset's own odor (DISPLAY_LABEL, falling back to dataset_canon).
    Other trial numbers are resolved through the per-dataset table; when
    the dataset or trial number is not listed, trial_label is returned as is.
    """
    number = _trial_num(trial_label)

    if number in (1, 3):  # controls
        return "Hexanol"
    if number in (2, 4, 5):  # trained odor
        return DISPLAY_LABEL.get(dataset_canon, dataset_canon)

    per_dataset = _LATE_TRIAL_ODORS.get(dataset_canon, {})
    return per_dataset.get(number, trial_label)

# ========= MAKE FIGURES PER FLY =========
# Legend text is derived from THRESH_STD_MULT so that it always states the
# multiplier _compute_theta actually applies (":g" renders 4 or 4.0 as "4").
THETA_LEGEND_LABEL = rf'$\theta = \mu_\mathrm{{before}} + {THRESH_STD_MULT:g}\sigma$'

# One figure per fly: for every kept trial, an RMS trace stacked on a 1xT
# heatmap of the same envelope, with one shared colorbar in a right-hand column.
for fly, g in df.groupby("fly"):
    # Keep only the trials whose number is listed in KEEP_TEST_NUMS
    g = g[g["trial_label"].map(_trial_num).isin(KEEP_TEST_NUMS)].copy()
    if g.empty:
        print(f"[SKIP] {fly}: no testing trials in {sorted(KEEP_TEST_NUMS)}.")
        continue

    # Order rows by trial number; the dataset label is taken from the first row
    g = g.sort_values("trial_label", key=lambda s: s.map(_trial_num))
    dataset_canon = _canon_dataset(g["dataset"].iloc[0])

    # Collect per-trial data and the limits shared by all panels of this fly
    trials = []
    y_max = 0.0
    heat_vmin, heat_vmax = np.inf, -np.inf

    for _, row in g.iterrows():
        env_full = _extract_env(row)
        if env_full.size == 0:
            continue

        fps = float(row["fps"]) if np.isfinite(row["fps"]) else FPS_DEFAULT
        # max(..., 1e-9) guards the division against a zero/negative fps
        t_full = np.arange(env_full.size, dtype=float) / max(fps, 1e-9)

        # clip to [0, X_MAX_LIMIT]
        mask = (t_full <= X_MAX_LIMIT + 1e-9)
        t = t_full[mask]
        env = env_full[mask]
        if t.size == 0:
            continue

        # The threshold uses the unclipped envelope; only the plot is clipped
        theta = _compute_theta(env_full, fps)
        tn = _trial_num(row["trial_label"])
        odor_name = display_odor_for_trial(dataset_canon, row["trial_label"])

        # Shared y-limit must fit both the trace and its threshold line
        local_max = np.nanmax(env) if np.isfinite(env).any() else 0.0
        if np.isfinite(theta):
            local_max = max(local_max, theta)
        y_max = max(y_max, float(local_max))

        if np.isfinite(env).any():
            heat_vmin = min(heat_vmin, float(np.nanmin(env)))
            heat_vmax = max(heat_vmax, float(np.nanmax(env)))

        trials.append((tn, odor_name, t, env, theta))

    if not trials:
        print(f"[SKIP] {fly}: no usable trials after filtering.")
        continue

    # Safe heat range: fall back when no finite range was found or it is degenerate
    if (not np.isfinite(heat_vmin)) or (not np.isfinite(heat_vmax)) or (heat_vmin == heat_vmax):
        heat_vmin, heat_vmax = 0.0, max(1.0, y_max)

    # Layout: 2 rows per trial (trace, heatmap) and 2 columns ([plots | colorbar])
    n = len(trials)
    height_ratios = []
    for _ in range(n):
        height_ratios.extend([2.2, 1.2])  # trace taller than heatmap

    # >>> Use constrained_layout so colorbar ticks/labels are NEVER clipped
    fig_h = max(4.0, n * 2.2 + 1.5)
    fig = plt.figure(figsize=(FIG_W, fig_h), constrained_layout=True)
    gs = GridSpec(
        nrows=2*n, ncols=2, figure=fig,
        height_ratios=height_ratios,
        width_ratios=[1.0, CBAR_COL_RATIO],
        hspace=HSPACE, wspace=WSPACE
    )

    axes_heat = []
    for idx, (_tn, odor_name, t, env, theta) in enumerate(trials):
        ax_trace = fig.add_subplot(gs[2*idx, 0])
        ax_heat  = fig.add_subplot(gs[2*idx + 1, 0])
        axes_heat.append(ax_heat)

        # ---- top: RMS trace ----
        ax_trace.plot(t, env, linewidth=1.2, color="black")
        ax_trace.axvline(ODOR_ON_S,  linestyle="--", linewidth=TRACE_ONOFF_LW, color="black")
        ax_trace.axvline(ODOR_OFF_S, linestyle="--", linewidth=TRACE_ONOFF_LW, color="black")
        ax_trace.axvspan(ODOR_ON_S, ODOR_OFF_S, color=FILL_GRAY)

        if np.isfinite(theta):
            ax_trace.axhline(theta, linestyle="-", linewidth=1.0, color="tab:red",
                             label=THETA_LEGEND_LABEL)

        ax_trace.set_xlim(0, X_MAX_LIMIT)
        ax_trace.set_ylim(0, y_max * 1.03 if y_max > 0 else 1.0)
        ax_trace.set_ylabel("RMS", fontsize=10)
        ax_trace.set_title(f"{odor_name}", loc="left", fontsize=11, weight="bold", pad=2)

        # The legend is drawn once, on the first trace only
        if idx == 0 and np.isfinite(theta):
            ax_trace.legend(loc="upper right", fontsize=9, frameon=True, framealpha=1.0)

        # ---- bottom: 1×T heatmap ----
        heat = env[np.newaxis, :]
        ax_heat.imshow(
            heat,
            aspect="auto",
            origin="lower",
            extent=(t[0] if t.size else 0, t[-1] if t.size else X_MAX_LIMIT, 0, 1),
            vmin=heat_vmin, vmax=heat_vmax,
            cmap="viridis", interpolation="nearest",
            rasterized=RASTERIZE_HEATMAPS
        )

        # Odor ON/OFF marker lines styled by HM_ONOFF_LS / HM_ONOFF_LW
        # (no shaded window is drawn on the heatmap itself)
        ax_heat.axvline(ODOR_ON_S,  linestyle=HM_ONOFF_LS, linewidth=HM_ONOFF_LW, color="black")
        ax_heat.axvline(ODOR_OFF_S, linestyle=HM_ONOFF_LS, linewidth=HM_ONOFF_LW, color="black")

        ax_heat.set_xlim(0, X_MAX_LIMIT)
        ax_heat.set_yticks([])   # remove ticks
        ax_heat.set_ylabel("")   # remove label

        # Only the bottom-most heatmap carries the time axis label and tick labels
        if idx == (n - 1):
            ax_heat.set_xlabel("Time (s)", fontsize=11)
        else:
            ax_heat.set_xticklabels([])

    # ---- dedicated colorbar column (prevents overlap) ----
    cax = fig.add_subplot(gs[:, 1])  # entire right column
    sm  = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=heat_vmin, vmax=heat_vmax), cmap="viridis")
    cbar = fig.colorbar(sm, cax=cax)
    cbar.set_label("RMS", fontsize=10)
    cbar.ax.tick_params(labelsize=9)

    # Suptitle placed at the explicit height TITLE_Y
    fig.suptitle(f"{fly} — RMS trace (top) + Heatmap (bottom)",
                 y=TITLE_Y, fontsize=14, weight="bold")

    # NOTE: the "2_4_5_8" in the file name is fixed text; it does not follow KEEP_TEST_NUMS.
    out_stem = f"{fly}_rms_trace_plus_heatmaps_testing_2_4_5_8"
    out_base = OUT_DIR / out_stem
    # Append the extension to the stem instead of using Path.with_suffix():
    # with_suffix() replaces everything after the last "." in the name, which
    # would truncate the file name whenever the fly identifier contains a dot.
    png_path = OUT_DIR / f"{out_stem}.png"
    eps_path = OUT_DIR / f"{out_stem}.eps"
    try:
        fig.savefig(png_path, dpi=300, bbox_inches="tight")
        fig.savefig(eps_path, format="eps", bbox_inches="tight")
    finally:
        # Always release the figure, even if saving raised
        plt.close(fig)
    print(f"[OK] Saved {out_base}")
