In [None]:
# Assumes `all_saccade_collection` already exists in the kernel (built by the magnitude_velocity_degrees_clean.ipynb pipeline).

In [67]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go

def launch_main_sequence_lasso_plotly(
    df: pd.DataFrame,
    x_col: str = "magnitude_raw_angular",
    y_col: str = "peak_velocity",
    *,
    convert_deg_per_frame_to_ms: bool = False,
    frame_rate_fps: float = 60.0,
    title: str = "Main sequence (lasso to collect outliers)",
    point_size: float = 4.0,
    point_opacity: float = 0.45,
    show_regression: bool = True,
):
    """
    Plot amplitude vs. peak velocity as an interactive, lasso-selectable FigureWidget.

    Rows whose x/y values are non-numeric, non-finite or <= 0 are dropped before plotting.
    Both axes get the same numeric range and a 1:1 aspect so slopes are not distorted.

    Args:
        df: events table containing `x_col` and `y_col`; not modified.
        x_col: amplitude column (x axis; renamed to 'A' in view_df).
        y_col: peak-velocity column (y axis; renamed to 'V' in view_df).
        convert_deg_per_frame_to_ms: if True, divide `y_col` by the frame duration in ms
            (1000 / frame_rate_fps), i.e. deg/frame -> deg/ms.
        frame_rate_fps: frame rate used for that conversion; must be > 0 when converting.
        title: figure title (regression stats are appended when a fit is drawn).
        point_size, point_opacity: marker styling.
        show_regression: overlay an OLS line V = slope * A + intercept.

    Returns:
        fig:      plotly.graph_objects.FigureWidget (the scatter is trace 0)
        view_df:  cleaned DataFrame actually plotted; has columns ['A','V','orig_idx'],
                  where 'orig_idx' is the row's index label in `df`; row order matches trace 0
        stats:    dict with slope/intercept/r/R2 if a regression was fitted, else None
                  (also None with < 2 points or when A or V is constant)

    Raises:
        ValueError: a column is missing, frame_rate_fps is not positive (when converting),
            or no valid points remain after filtering.
    """
    data = df.copy()
    for col in (x_col, y_col):
        if col not in data.columns:
            raise ValueError(f"'{col}' not in df. Provide a column or pre-compute it.")
        # Non-numeric entries (strings, None in object columns) become NaN and are removed by
        # the finiteness mask below instead of making np.isfinite raise a TypeError.
        data[col] = pd.to_numeric(data[col], errors="coerce")

    # Optional unit conversion (deg/frame -> deg/ms)
    if convert_deg_per_frame_to_ms:
        if not float(frame_rate_fps) > 0:  # also rejects NaN
            raise ValueError(f"frame_rate_fps must be > 0, got {frame_rate_fps!r}")
        frame_ms = 1000.0 / float(frame_rate_fps)
        data[y_col] = data[y_col] / frame_ms

    # Clean & keep positives
    mask = (
        np.isfinite(data[x_col]) & (data[x_col] > 0) &
        np.isfinite(data[y_col]) & (data[y_col] > 0)
    )
    view_df = data.loc[mask, [x_col, y_col]].copy()
    if view_df.empty:
        raise ValueError("No valid points after filtering.")
    view_df["orig_idx"] = view_df.index
    view_df = view_df.rename(columns={x_col: "A", y_col: "V"}).copy()

    # Equal numeric span on both axes
    lo = float(min(view_df["A"].min(), view_df["V"].min()))
    hi = float(max(view_df["A"].max(), view_df["V"].max()))
    if hi - lo <= 0:
        lo, hi = 0.0, max(1.0, hi)
    pad = 0.05 * (hi - lo if hi > lo else 1.0)
    x_range = [lo - pad, hi + pad]
    y_range = [lo - pad, hi + pad]

    # Base scatter (Scattergl stays responsive with many points)
    sc = go.Scattergl(
        x=view_df["A"], y=view_df["V"],
        mode="markers",
        marker=dict(size=point_size, opacity=point_opacity),
        selected=dict(marker=dict(size=point_size + 2)),
        unselected=dict(marker=dict(opacity=point_opacity)),
        name="events"
    )

    fig = go.FigureWidget(data=[sc])
    fig.update_layout(
        title=title,
        dragmode="lasso",        # start with lasso (you can switch to box in the toolbar)
        xaxis_title="Amplitude [deg]",
        yaxis_title=("Peak Velocity [deg/ms]" if convert_deg_per_frame_to_ms else "Peak Velocity [deg/frame]"),
        xaxis=dict(range=x_range, constrain="range"),
        # scaleratio only takes effect on the axis that carries scaleanchor
        yaxis=dict(range=y_range, scaleanchor="x", scaleratio=1),
        height=420, width=640,
        margin=dict(l=60, r=10, t=40, b=50)
    )

    # Optional regression overlay
    stats = None
    A = view_df["A"].to_numpy(dtype=float)
    V = view_df["V"].to_numpy(dtype=float)
    # A line fit needs >= 2 points and spread in A; r is undefined if V is constant.
    can_fit = len(view_df) >= 2 and np.ptp(A) > 0 and np.ptp(V) > 0
    if show_regression and can_fit:
        slope, intercept = np.polyfit(A, V, 1)
        r = np.corrcoef(A, V)[0, 1]
        r2 = float(r**2)
        stats = dict(slope=float(slope), intercept=float(intercept), r=float(r), R2=r2)

        x_line = np.linspace(x_range[0], x_range[1], 200)
        y_line = slope * x_line + intercept
        fig.add_scatter(x=x_line, y=y_line, mode="lines", name="OLS", line=dict(dash="dash"))

        fig.update_layout(title=f"{title} (r={r:.3f}, R²={r2:.3f}, slope={slope:.3f}, int.={intercept:.3f})")

    display(fig)  # notebook builtin; renders the widget in-place
    return fig, view_df, stats


def get_selected_subset_plotly(
    fig: go.FigureWidget,
    view_df: pd.DataFrame,
    original_df: pd.DataFrame = None
) -> pd.DataFrame:
    """
    Return the rows currently lasso/box-selected in trace 0 of `fig`.

    Args:
        fig: FigureWidget produced by launch_main_sequence_lasso_plotly.
        view_df: the view_df returned by the same call; its row order matches trace 0.
        original_df: optional source table. When given, the full original rows are
            returned, looked up by index label through view_df['orig_idx'].

    Returns:
        The selected rows of `original_df` (if given) or of `view_df`. With no selection,
        an empty frame with the same columns and dtypes is returned.
    """
    source = view_df if original_df is None else original_df
    # selectedpoints holds positional indices into trace 0's data arrays (same order as
    # view_df); it may be None or empty when nothing is selected. Test explicitly rather
    # than with truthiness, which is ambiguous for array-likes.
    sel = fig.data[0].selectedpoints
    if sel is None or len(sel) == 0:
        return source.iloc[0:0].copy()
    picked = view_df.iloc[list(sel)]
    if original_df is None:
        return picked
    return original_df.loc[picked["orig_idx"].to_numpy()]


In [89]:
# 1) Keep only rows whose head_movement flag is False
df_use = all_saccade_collection.query('head_movement==False')

# 2) Launch the interactive main-sequence figure (lasso/box-select candidate outliers)
fig, view_df, stats = launch_main_sequence_lasso_plotly(
    df_use,
    x_col="magnitude_raw_angular",
    y_col="peak_velocity",
    title="Main sequence (lasso to collect outliers)",
    convert_deg_per_frame_to_ms=False,   # plot peak_velocity as-is (axis labelled deg/frame)
    show_regression=True,
    point_size=4.0,
    point_opacity=0.45,
)

FigureWidget({
    'data': [{'marker': {'opacity': 0.45, 'size': 4.0},
              'mode': 'markers',
              'name': 'events',
              'selected': {'marker': {'size': 6.0}},
              'type': 'scattergl',
              'uid': '241d0a02-cfa0-4649-9b00-dd0d4710c47b',
              'unselected': {'marker': {'opacity': 0.45}},
              'x': array([4.75654836, 3.77292389, 1.70140414, ..., 8.97522325, 5.40068903,
                          4.51347842]),
              'y': array([1.89750303, 2.89177602, 0.94773718, ..., 1.72715725, 1.38962263,
                          1.54272811])},
             {'line': {'dash': 'dash'},
              'mode': 'lines',
              'name': 'OLS',
              'type': 'scatter',
              'uid': '7cdd4a5b-95ff-4410-ba07-cf53d8113164',
              'x': array([-2.73367195e+00, -2.34297120e+00, -1.95227044e+00, -1.56156968e+00,
                          -1.17086892e+00, -7.80168161e-01, -3.89467403e-01,  1.23335606e-03,
           

In [90]:
# 3) After lasso/box-selecting points in the figure toolbar, run this cell
#    to pull the full df_use rows behind the current selection.
picked = get_selected_subset_plotly(fig, view_df, original_df=df_use)
(picked.shape, picked.head())


((86, 30),
         Main  Sub  saccade_start_ind  saccade_end_ind  \
 1202  6993.0    L               8616             8618   
 1461     NaN  NaN              10958            10959   
 7333     NaN  NaN              50546            50550   
 7334  1901.0    L              50550            50558   
 9213  4141.0    L              57812            57813   
 
       saccade_start_timestamp  saccade_end_timestamp  saccade_on_ms  \
 1202                3527854.0              3528510.0      176392.70   
 1461                4116178.0              4116511.0      205808.90   
 7333               17420926.0             17422258.0      871046.30   
 7334               17422258.0             17424922.0      871112.90   
 9213               25760039.0             25760378.0     1288001.95   
 
       saccade_off_ms  length  magnitude_raw_pixel  ...  phi_init_pos  \
 1202       176425.50       2            57.580909  ...     -9.301422   
 1461       205825.55       1            29.595568  ...    

In [91]:
# Display the full table of selected events (rows of df_use chosen with the lasso)
picked

Unnamed: 0,Main,Sub,saccade_start_ind,saccade_end_ind,saccade_start_timestamp,saccade_end_timestamp,saccade_on_ms,saccade_off_ms,length,magnitude_raw_pixel,...,phi_init_pos,phi_end_pos,delta_theta,delta_phi,head_movement,eye,block,animal,time_to_peak_v,peak_velocity
1202,6993.0,L,8616,8618,3527854.0,3528510.0,176392.70,176425.50,2,57.580909,...,-9.301422,-9.259457,2.600528,0.041965,False,L,012,PV_126,0,15.903053
1461,,,10958,10959,4116178.0,4116511.0,205808.90,205825.55,1,29.595568,...,-18.482588,-18.263559,-16.900139,0.219028,False,L,011,PV_106,17,16.901558
7333,,,50546,50550,17420926.0,17422258.0,871046.30,871112.90,4,46.121221,...,-7.631453,9.752787,-13.909389,17.384239,False,L,012,PV_106,68,12.892066
7334,1901.0,L,50550,50558,17422258.0,17424922.0,871112.90,871246.10,8,57.756055,...,9.752787,-10.194513,18.314468,-19.947300,False,L,012,PV_106,0,12.892066
9213,4141.0,L,57812,57813,25760039.0,25760378.0,1288001.95,1288018.90,1,44.410016,...,9.083358,8.344794,0.174200,-0.738563,False,L,038,PV_62,0,18.779927
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23708,,,119768,119771,40020259.0,40021249.0,2001012.95,2001062.45,3,62.611191,...,-29.167343,-26.014725,4.994096,3.152618,False,R,009,PV_126,0,37.599724
23755,,,101755,101762,40687880.0,40690258.0,2034394.00,2034512.90,7,56.644470,...,-27.425144,-46.345291,-22.173606,-18.920147,False,R,038,PV_62,34,7.182394
23760,,,101834,101844,40714713.0,40718109.0,2035735.65,2035905.45,10,38.445766,...,-44.025806,-29.428060,31.318878,14.597746,False,R,038,PV_62,34,10.896240
24363,,,131904,131909,50927050.0,50928748.0,2546352.50,2546437.40,5,31.721567,...,-8.571403,-6.413555,11.999031,2.157849,False,R,038,PV_62,0,15.115057


In [92]:
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Tuple, Dict, List

def review_events_multi(
    block_dict: Dict[str, object],
    events_subset_df: pd.DataFrame,
    export_dir: Optional[Path] = None,   # optional; None -> timestamped CSV in the CWD
    # column names in events_subset_df (time columns are auto-detected if not provided)
    animal_col: str = "animal",
    block_col: str = "block",
    eye_col: str = "eye",
    start_ms_col: Optional[str] = None,  # auto-detect among: 'start_ms','saccade_on_ms'
    end_ms_col: Optional[str] = None,    # auto-detect among: 'end_ms','saccade_off_ms'
    window_scale: float = 0.9,
    text_cols: Tuple[str, ...] = ("phi","theta","peak_velocity","magnitude_raw_angular","pupil_diameter"),
    font_scale: float = 0.6,
    thickness: int = 2,
    wait_ms: int = 15,
    flip_mode: str = "vertical",  # "vertical" or "none"
) -> pd.DataFrame:
    """
    Multi-animal/multi-block synchronized L/R eye video viewer for manual outlier vetting.

    For each row of `events_subset_df`, the BlockSync object for (animal, block) is looked
    up in `block_dict`, its first left/right eye videos are opened, and the viewer steps
    through [start_ms, end_ms] on the eye-data 'ms_axis'. For each eye it shows the video
    frame of the nearest 'ms_axis' sample, a green ellipse from center_x/center_y/width/
    height/phi (when present) and the values of `text_cols`.

    Windows: "Controls" (clickable buttons), "Left Eye", "Right Eye".
    Buttons / keys:
      Play, Pause (SPACE toggles)          Prev, Next event ([ and ])
      Step -1, Step +1 within the event (, and .)
      mark_bad  (B) -> manual_outlier_detected=True  (button shown RED for current event)
      mark_good (G) -> manual_outlier_detected=False (button shown GREEN for current event)
      export_annotated_df (E) -> writes CSV with columns
          animal, block, eye, start_ms, end_ms, manual_outlier_detected
        to export_dir/joint_event_annotations.csv (overwritten on each export) or, if
        export_dir is None, to the CWD with a timestamped filename. A failed write is
        reported in the Controls window instead of ending the session.
      Q / ESC -> quit

    Args:
        block_dict: BlockSync-like objects, matched by their `animal_call`/`block_num`
            attributes, else by key "<animal>_block_<n>" or "<animal>_block_<nnn>". Each must
            expose `left_eye_data`/`right_eye_data` (DataFrames with an 'ms_axis' column and
            optionally a frame column such as 'eye_frame') and `le_videos`/`re_videos`.
        events_subset_df: events to review; not modified.
        export_dir: CSV destination folder (str or Path); created if missing.
        animal_col, block_col, eye_col: event identifier columns; block values must be
            convertible with int().
        start_ms_col, end_ms_col: event time columns; auto-detected when None.
        window_scale: eye windows are sized to this fraction of the first block's video size.
        text_cols: eye-data columns printed on each frame when present and numeric.
        font_scale, thickness: overlay text/ellipse styling.
        wait_ms: cv2.waitKey delay per redraw, in ms.
        flip_mode: "vertical" flips frames (after data overlays, before text); any other
            value leaves frames as-is.

    Returns:
        Copy of `events_subset_df` with a fresh RangeIndex (the original index is dropped)
        and a 'manual_outlier_detected' column: None/NaN = unset, True = bad, False = good.
        Video handles and windows are released even if an error interrupts the session.

    Raises:
        ValueError: required columns missing or time columns not resolvable.
        KeyError: no BlockSync matches an event's animal/block.
        RuntimeError: eye data or 'ms_axis' missing, or a video cannot be opened.
    """

    # ---------- resolve timing cols ----------
    def _resolve_time_cols(df: pd.DataFrame, start_c: Optional[str], end_c: Optional[str]) -> Tuple[str, str]:
        cand_start = [start_c, "start_ms", "saccade_on_ms"]
        cand_end   = [end_c,   "end_ms",   "saccade_off_ms"]
        s = next((c for c in cand_start if c and c in df.columns), None)
        e = next((c for c in cand_end   if c and c in df.columns), None)
        if s is None or e is None:
            raise ValueError("Could not resolve start/end ms columns. "
                             "Tried: start in {start_ms,saccade_on_ms}, end in {end_ms,saccade_off_ms}.")
        return s, e

    start_ms_col, end_ms_col = _resolve_time_cols(events_subset_df, start_ms_col, end_ms_col)

    if export_dir is not None:
        export_dir = Path(export_dir)  # accept plain strings as well as Path objects

    # ---------- working copy ----------
    events = events_subset_df.copy().reset_index(drop=True)
    if "manual_outlier_detected" not in events.columns:
        # None means "unset" until the user marks the event
        events["manual_outlier_detected"] = None
    else:
        # object dtype keeps marks as plain True/False next to pre-existing None/NaN values
        events["manual_outlier_detected"] = events["manual_outlier_detected"].astype(object)

    # ---------- guards ----------
    for c in (animal_col, block_col, eye_col, start_ms_col, end_ms_col):
        if c not in events.columns:
            raise ValueError(f"events_subset_df is missing required column '{c}'")
    if events.empty:
        # Nothing to review: don't open windows or index row 0 of an empty table.
        return events

    # ---------- helpers ----------
    def _frame_col(df: pd.DataFrame) -> Optional[str]:
        """First available video-frame-index column in an eye-data table, or None."""
        for c in ("eye_frame", "frame", "frame_idx", "video_frame"):
            if c in df.columns:
                return c
        return None

    def _lookup_block(animal: str, block_num: int):
        """Find the BlockSync for (animal, block): attribute match first, then key patterns."""
        for obj in block_dict.values():
            if getattr(obj, "animal_call", None) == animal and getattr(obj, "block_num", None) == block_num:
                return obj
        key1 = f"{animal}_block_{block_num}"
        key2 = f"{animal}_block_{int(block_num):03d}"
        if key1 in block_dict:
            return block_dict[key1]
        if key2 in block_dict:
            return block_dict[key2]
        raise KeyError(f"BlockSync not found for animal='{animal}', block={block_num}")

    def _nearest_row(df: pd.DataFrame, ms: float) -> Optional[pd.Series]:
        """Row whose ms_axis is closest to `ms`, ignoring NaN timestamps; None if none usable."""
        arr = np.asarray(df["ms_axis"].to_numpy(), dtype=float)
        if arr.size == 0:
            return None
        dist = np.abs(arr - ms)
        if np.all(np.isnan(dist)):
            return None
        # nanargmin: plain argmin would return the first NaN position
        return df.iloc[int(np.nanargmin(dist))]

    def _median_step_ms(df: Optional[pd.DataFrame]) -> float:
        """Median positive ms_axis spacing (falls back to 60 fps)."""
        if df is None or df.empty:
            return 1000.0/60.0
        ms = np.asarray(df["ms_axis"].to_numpy(), dtype=float)
        if ms.size < 2:
            return 1000.0/60.0
        d = np.diff(ms)
        d = d[np.isfinite(d) & (d > 0)]
        if d.size == 0:
            return 1000.0/60.0
        return float(np.median(d))

    def _apply_flip(img: np.ndarray) -> np.ndarray:
        if flip_mode == "vertical":
            return cv2.flip(img, 0)  # flip around the x-axis (upside down)
        return img

    def _state_of(value) -> Optional[bool]:
        """Normalize a manual_outlier_detected cell to None (unset), True (bad) or False (good)."""
        return None if pd.isna(value) else bool(value)

    # ---------- lazy video state ----------
    capL, capR = None, None
    cur_animal, cur_block = None, None
    left_df = right_df = None
    left_frame_col = right_frame_col = None
    Wl = Hl = Wr = Hr = 0
    # Step/Play interval; recomputed once per opened block rather than on every redraw
    step_ms = 1000.0/60.0

    def _release_caps():
        nonlocal capL, capR
        if capL is not None:
            capL.release(); capL = None
        if capR is not None:
            capR.release(); capR = None

    def _open_for(animal: str, block_num: int):
        """Switch video/eye-data state to (animal, block_num); no-op if already open."""
        nonlocal cur_animal, cur_block
        nonlocal left_df, right_df, left_frame_col, right_frame_col
        nonlocal capL, capR, Wl, Hl, Wr, Hr, step_ms

        if animal == cur_animal and block_num == cur_block:
            return

        _release_caps()
        bs = _lookup_block(animal, block_num)

        new_left = getattr(bs, "left_eye_data", None)
        new_right = getattr(bs, "right_eye_data", None)
        if new_left is None or new_right is None:
            raise RuntimeError(f"Missing left/right eye data for {animal} block {block_num}.")
        if "ms_axis" not in new_left.columns or "ms_axis" not in new_right.columns:
            raise RuntimeError(f"'ms_axis' missing in eye data for {animal} block {block_num}.")

        lv = Path(bs.le_videos[0])
        rv = Path(bs.re_videos[0])
        new_capL = cv2.VideoCapture(str(lv))
        if not new_capL.isOpened():
            new_capL.release()
            raise RuntimeError(f"Cannot open left video: {lv}")
        new_capR = cv2.VideoCapture(str(rv))
        if not new_capR.isOpened():
            new_capL.release()  # don't leak the already-opened left capture
            new_capR.release()
            raise RuntimeError(f"Cannot open right video: {rv}")

        capL, capR = new_capL, new_capR
        left_df, right_df = new_left, new_right
        left_frame_col, right_frame_col = _frame_col(left_df), _frame_col(right_df)
        Wl, Hl = int(capL.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capL.get(cv2.CAP_PROP_FRAME_HEIGHT))
        Wr, Hr = int(capR.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capR.get(cv2.CAP_PROP_FRAME_HEIGHT))
        step_ms = float(np.mean([_median_step_ms(left_df), _median_step_ms(right_df)]))
        # mark as current only once everything opened successfully
        cur_animal, cur_block = animal, block_num

    def _seek(cap, idx: int):
        cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, int(idx)))

    def _read(cap):
        ret, f = cap.read()
        return f if ret else None

    def _overlay_text(img, lines: List[str], origin=(10, 24), vstep=22, color=(255,255,255)):
        x, y = origin
        for ln in lines:
            cv2.putText(img, ln, (x, y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness, cv2.LINE_AA)
            y += vstep

    def _overlay_ellipse(img, df_eye: Optional[pd.DataFrame], frame_idx: Optional[int]):
        """Draw the ellipse stored for `frame_idx` (raw-frame coordinates) in green."""
        if df_eye is None or frame_idx is None: return
        col = _frame_col(df_eye)
        if col is None: return
        hit = df_eye[df_eye[col] == frame_idx]
        if hit.empty: return
        row = hit.iloc[0]
        cx, cy = row.get("center_x", np.nan), row.get("center_y", np.nan)
        w, h  = row.get("width", np.nan), row.get("height", np.nan)
        phi   = row.get("phi", np.nan)
        if not (pd.isna(cx) or pd.isna(cy) or pd.isna(w) or pd.isna(h)):
            # NOTE(review): cv2.ellipse expects half-axes and an angle in degrees; confirm
            # that 'width'/'height' are semi-axes and 'phi' is in degrees.
            cv2.ellipse(
                img,
                (int(round(cx)), int(round(cy))),
                (max(1,int(round(w))), max(1,int(round(h)))),
                float(0 if pd.isna(phi) else phi),
                0, 360, (0,255,0), thickness
            )

    def _render_eye(side: str, cap, df_eye: pd.DataFrame, frame_col: Optional[str],
                    row_eye: Optional[pd.Series], H: int, W: int,
                    animal: str, block_num: int, t_ms: float) -> np.ndarray:
        """Build one eye's image: overlays on the raw frame, then flip, then upright text."""
        header = f"{side} | {animal} B{block_num}"
        canvas = np.zeros((H, W, 3), dtype=np.uint8)
        if row_eye is None:
            # flip before writing so placeholder text is not drawn upside down
            canvas = _apply_flip(canvas)
            _overlay_text(canvas, [header, "no data"], origin=(10, 24))
            return canvas
        if frame_col is None or pd.isna(row_eye[frame_col]):
            canvas = _apply_flip(canvas)
            _overlay_text(canvas, [header, f"t={t_ms:.1f}ms", "no synchronized frame"], origin=(10, 24))
            return canvas
        frame_idx = int(row_eye[frame_col])
        _seek(cap, frame_idx)
        raw = _read(cap)
        if raw is None:
            canvas = _apply_flip(canvas)
            _overlay_text(canvas, [header, f"t={t_ms:.1f}ms | frame={frame_idx}", "frame read failed"],
                          origin=(10, 24))
            return canvas
        img = raw.copy()
        _overlay_ellipse(img, df_eye, frame_idx)   # (1) data overlays on the raw frame
        img = _apply_flip(img)                     # (2) flip
        lines = [f"{header} | t={t_ms:.1f}ms | frame={frame_idx}"]
        for c in text_cols:
            if c in df_eye.columns:
                v = row_eye.get(c, np.nan)
                if pd.notna(v):
                    try:
                        lines.append(f"{c}={float(v):.3f}")
                    except (TypeError, ValueError, OverflowError):
                        pass  # non-numeric value: not shown
        _overlay_text(img, lines, origin=(10, 24))  # (3) text after flip so it reads upright
        return img

    # ---------- display geometry ----------
    disp_Wl, disp_Hl = 0, 0
    disp_Wr, disp_Hr = 0, 0

    # ---------- controls ----------
    ctrl_w, ctrl_h = 420, 360
    buttons = {
        "Play":                  ((10,  10),(200, 60)),
        "Pause":                 ((220, 10),(410, 60)),
        "Prev":                  ((10,  80),(200,130)),
        "Next":                  ((220, 80),(410,130)),
        "Step -1":               ((10,  150),(200,200)),
        "Step +1":               ((220, 150),(410,200)),
        "mark_bad":              ((10,  220),(200,270)),   # red when BAD
        "mark_good":             ((220, 220),(410,270)),   # green when GOOD
        "export_annotated_df":   ((10,  290),(410,340)),
    }

    # BGR colors
    COLOR_BG      = (60, 60, 60)
    COLOR_BORDER  = (180,180,180)
    COLOR_TEXT    = (220,220,220)
    COLOR_BAD     = (0,   0, 255)   # bright red
    COLOR_GOOD    = (0, 255,   0)   # bright green
    COLOR_EXPORT  = (0, 165, 255)   # orange

    last_status = ""  # echoed message (e.g., after export)

    def _draw_controls(idx: int) -> np.ndarray:
        img = np.zeros((ctrl_h, ctrl_w, 3), dtype=np.uint8)

        state = _state_of(events.at[events.index[idx], "manual_outlier_detected"])
        state_str = "UNSET" if state is None else ("BAD" if state else "GOOD")
        header = f"Event {idx+1}/{len(events)} | state={state_str}"
        cv2.putText(img, header, (10, ctrl_h-12), cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_TEXT, 1, cv2.LINE_AA)

        for name, ((x1,y1),(x2,y2)) in buttons.items():
            if name == "export_annotated_df":
                fill = COLOR_EXPORT
            elif name == "mark_bad" and state is True:
                fill = COLOR_BAD
            elif name == "mark_good" and state is False:
                fill = COLOR_GOOD
            else:
                fill = COLOR_BG

            cv2.rectangle(img, (x1,y1), (x2,y2), fill, -1)
            cv2.rectangle(img, (x1,y1), (x2,y2), COLOR_BORDER, 2)
            text_size, _ = cv2.getTextSize(name, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
            tx = x1 + (x2 - x1 - text_size[0]) // 2
            ty = y1 + (y2 - y1 + text_size[1]) // 2
            cv2.putText(img, name, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)

        if last_status:
            cv2.putText(img, last_status, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (180,255,180), 1, cv2.LINE_AA)

        return img

    def _hit_button(x, y):
        for name, ((x1,y1),(x2,y2)) in buttons.items():
            if x1 <= x <= x2 and y1 <= y <= y2:
                return name
        return None

    def _export_now():
        """Write the annotated table. If export_dir is None, write to CWD with timestamp."""
        nonlocal last_status
        out = events[[animal_col, block_col, eye_col, start_ms_col, end_ms_col, "manual_outlier_detected"]].copy()
        out = out.rename(columns={
            animal_col: "animal",
            block_col: "block",
            eye_col: "eye",
            start_ms_col: "start_ms",
            end_ms_col: "end_ms",
        })
        if export_dir is None:
            ts = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
            out_path = Path.cwd() / f"joint_event_annotations_{ts}.csv"
        else:
            out_path = export_dir / "joint_event_annotations.csv"
        try:
            out_path.parent.mkdir(parents=True, exist_ok=True)
            out.to_csv(out_path, index=False)
        except OSError as exc:
            # Report instead of raising: an exception here would end the session and the
            # in-progress annotations would only survive via the return value.
            last_status = f"Export failed: {exc}"
            return
        last_status = f"Exported: {str(out_path)}"

    # ---------- interaction state ----------
    playing = False
    cur_idx = 0
    cur_ms = None
    controls_img = None

    def _apply_action(name: Optional[str]) -> None:
        """Apply one button/key action to the viewer state, then redraw the Controls window."""
        nonlocal playing, cur_idx, cur_ms, controls_img
        if name == "Play":
            playing = True
        elif name == "Pause":
            playing = False
        elif name == "toggle":
            playing = not playing
        elif name in ("Prev", "Next"):
            playing = False
            cur_idx = (cur_idx + (1 if name == "Next" else -1)) % len(events)
            cur_ms = None  # restart at the new event's onset
        elif name in ("Step -1", "Step +1"):
            playing = False
            if cur_ms is not None:
                cur_ms += step_ms if name == "Step +1" else -step_ms
        elif name == "mark_bad":
            events.at[events.index[cur_idx], "manual_outlier_detected"] = True
        elif name == "mark_good":
            events.at[events.index[cur_idx], "manual_outlier_detected"] = False
        elif name == "export_annotated_df":
            _export_now()
        controls_img = _draw_controls(cur_idx)
        cv2.imshow("Controls", controls_img)

    def on_mouse_controls(event, x, y, flags, param):
        nonlocal last_status
        if event != cv2.EVENT_LBUTTONDOWN:
            return
        last_status = ""  # clear unless export sets it
        _apply_action(_hit_button(x, y))

    # keyboard shortcuts (Q/ESC quit is handled in the loop)
    key_actions = {
        32: "toggle",                                   # space
        ord('['): "Prev", ord(']'): "Next",
        ord(','): "Step -1", ord('.'): "Step +1",
        ord('b'): "mark_bad", ord('B'): "mark_bad",
        ord('g'): "mark_good", ord('G'): "mark_good",
        ord('e'): "export_annotated_df", ord('E'): "export_annotated_df",
    }

    try:
        cv2.namedWindow("Controls", cv2.WINDOW_NORMAL)
        cv2.namedWindow("Left Eye", cv2.WINDOW_NORMAL)
        cv2.namedWindow("Right Eye", cv2.WINDOW_NORMAL)
        controls_img = _draw_controls(cur_idx)
        cv2.imshow("Controls", controls_img)
        cv2.setMouseCallback("Controls", on_mouse_controls)

        while True:
            k = cv2.waitKey(wait_ms) & 0xFF
            if k in (27, ord('q'), ord('Q')):
                break
            if k in key_actions:
                _apply_action(key_actions[k])

            # load current event info
            row = events.iloc[cur_idx]
            animal = str(row[animal_col])
            block_num = int(row[block_col])
            start_ms = float(row[start_ms_col])
            end_ms   = float(row[end_ms_col])

            _open_for(animal, block_num)

            if disp_Wl == 0:
                # window sizes are fixed from the first opened block
                disp_Wl, disp_Hl = int(Wl * window_scale), int(Hl * window_scale)
                disp_Wr, disp_Hr = int(Wr * window_scale), int(Hr * window_scale)
                cv2.resizeWindow("Left Eye", disp_Wl, disp_Hl)
                cv2.resizeWindow("Right Eye", disp_Wr, disp_Hr)
                cv2.resizeWindow("Controls", ctrl_w, ctrl_h)

            if cur_ms is None:
                cur_ms = start_ms
            cur_ms = min(max(cur_ms, start_ms), end_ms)

            L_img = _render_eye("Left", capL, left_df, left_frame_col,
                                _nearest_row(left_df, cur_ms), Hl, Wl, animal, block_num, cur_ms)
            R_img = _render_eye("Right", capR, right_df, right_frame_col,
                                _nearest_row(right_df, cur_ms), Hr, Wr, animal, block_num, cur_ms)

            cv2.imshow("Left Eye",  cv2.resize(L_img, (disp_Wl, disp_Hl)))
            cv2.imshow("Right Eye", cv2.resize(R_img, (disp_Wr, disp_Hr)))
            cv2.imshow("Controls", controls_img)

            # advance time if playing; past the event end, stop and move to the next event
            if playing:
                cur_ms += step_ms
                if cur_ms > end_ms:
                    playing = False
                    cur_idx = (cur_idx + 1) % len(events)
                    cur_ms = None
                    controls_img = _draw_controls(cur_idx)
                    cv2.imshow("Controls", controls_img)
    finally:
        # always release video handles and windows, even if a lookup/open raised
        _release_caps()
        cv2.destroyAllWindows()

    return events


In [93]:
# Review the lasso-picked events in the synchronized L/R eye viewer.
# Requires `block_dict` (BlockSync objects for every animal/block in the subset) from the pipeline.
# To review a random sample instead: subset = all_saccade_collection.sample(300, random_state=1)
subset = picked
# global (non-animal) folder for the exported annotation CSV
export_dir = Path(r"Z:\Nimrod\experiments\multi_animal_analysis\outlier_manual_filtering")

reviewed = review_events_multi(
    block_dict=block_dict,
    events_subset_df=subset,
    export_dir=export_dir,
    window_scale=0.85,        # eye windows at 85% of the native video size
    flip_mode='vertical',
    # Time columns 'saccade_on_ms'/'saccade_off_ms' are auto-detected. BlockSync objects are
    # matched by their .animal_call/.block_num, falling back to '<animal>_block_<n>' keys.
)
