In [None]:

import datetime
import numpy as np
import cv2
from itertools import cycle
import pickle
import pathlib
import math
import tqdm
import scipy.io
from matplotlib import pyplot as plt
import scipy.io
import h5py
import re
from lxml import etree as ET
import scipy.signal as sig
import pandas as pd
from scipy.stats import kde
from BlockSync_current import BlockSync
import UtilityFunctions_newOE as uf
from scipy import signal
import bokeh
import seaborn as sns
from matplotlib import rcParams
%matplotlib inline
plt.style.use('default')
rcParams['pdf.fonttype'] = 42  # Ensure fonts are embedded and editable
rcParams['ps.fonttype'] = 42  # Ensure compatibility with vector outputs


def bokeh_plotter(data_list, x_axis_list=None, label_list=None,
                  plot_name='default',
                  x_axis_label='X', y_axis_label='Y',
                  peaks=None, peaks_list=False, export_path=False):
    """Show an interactive Bokeh line plot with one line per data vector.

    Args:
        data_list (list): The data vectors to plot, one line each.
        x_axis_list (list, optional): One x-vector per entry of data_list.
            When None, every vector is plotted against its sample index.
        label_list (list of str, optional): One legend label per entry of
            data_list. When None, lines are labelled "Line 1", "Line 2", ...
        plot_name (str, optional): Used in the figure title and, when
            exporting, as the html file name. Defaults to 'default'.
        x_axis_label (str, optional): The label for the x-axis. Defaults to 'X'.
        y_axis_label (str, optional): The label for the y-axis. Defaults to 'Y'.
        peaks (optional): Positional indices to highlight with circles.
            With peaks_list=False this is a single index collection applied
            to the last vector of data_list; with peaks_list=True it holds
            one index collection per vector.
        peaks_list (bool, optional): Selects how `peaks` is interpreted (see above).
        export_path (False, str or pathlib.Path): when not False, the figure
            is written to '<export_path>/<plot_name>.html'.

    Raises:
        ValueError: if x_axis_list or label_list is given with a length that
            differs from data_list.
    """
    # Validate both optional lists the same way, so a mismatch is reported
    # instead of silently producing a figure without lines.
    if x_axis_list is not None and len(x_axis_list) != len(data_list):
        raise ValueError(
            'problem with x_axis_list input - should be either None, or a list with the same length as data_list')
    if label_list is not None and len(label_list) != len(data_list):
        raise ValueError(
            'problem with label_list input - should be either None, or a list with the same length as data_list')

    color_cycle = cycle(bokeh.palettes.Category10_10)
    fig = bokeh.plotting.figure(title=f'bokeh explorer: {plot_name}',
                                x_axis_label=x_axis_label,
                                y_axis_label=y_axis_label,
                                plot_width=1500,
                                plot_height=700)

    def _mark_peaks(x_values, y_values, peak_inds, color):
        # Index positionally so the markers sit on the line they belong to,
        # whatever x-axis that line was drawn against.
        fig.circle(np.asarray(x_values)[peak_inds], np.asarray(y_values)[peak_inds], size=10, color=color)

    x_axis = None
    data_vector = None
    for i, data_vector in enumerate(data_list):
        color = next(color_cycle)

        if x_axis_list is None:
            x_axis = range(len(data_vector))
        else:
            print('x_axis manually set')
            x_axis = x_axis_list[i]

        legend = f"Line {i + 1}" if label_list is None else f"{label_list[i]}"
        # The same x_axis is used whether or not labels were supplied.
        fig.line(x_axis, data_vector, line_color=color, legend_label=legend)

        if peaks is not None and peaks_list is True:
            _mark_peaks(x_axis, data_vector, peaks[i], color)

    # A single peaks collection is drawn on the last plotted vector.
    if peaks is not None and peaks_list is False and data_vector is not None:
        _mark_peaks(x_axis, data_vector, peaks, 'red')

    if export_path is not False:
        print(f'exporting to {export_path}')
        # pathlib.Path(...) lets callers pass either a str or a Path.
        bokeh.io.output.output_file(filename=str(pathlib.Path(export_path) / f'{plot_name}.html'),
                                    title=f'{plot_name}')
    bokeh.plotting.show(fig)


def load_eye_data_2d_w_rotation_matrix(block):
    """
    Attach the saved eye dataframes and rotation parameters to a block.

    Reads 'left_eye_data.csv' and 'right_eye_data.csv' from block.analysis_path
    into block.left_eye_data / block.right_eye_data. If either file is missing a
    message is printed and the function returns without touching the rotation
    parameters. Otherwise 'rotate_eye_data_params.pkl' is unpickled and its
    rotation matrices and angles are copied onto the block; a missing pickle
    only prints a message.

    :param block: object exposing `analysis_path` (a path supporting `/`)
    :return: None
    """
    try:
        # Left is read and attached before right, so a missing right file
        # still leaves the left dataframe on the block.
        for side in ('left', 'right'):
            eye_frame = pd.read_csv(block.analysis_path / f'{side}_eye_data.csv', index_col=0, engine='python')
            setattr(block, f'{side}_eye_data', eye_frame)
    except FileNotFoundError:
        print('eye_data files not found, run the pipeline!')
        return

    rotation_keys = ('left_rotation_matrix', 'right_rotation_matrix',
                     'left_rotation_angle', 'right_rotation_angle')
    try:
        with open(block.analysis_path / 'rotate_eye_data_params.pkl', 'rb') as handle:
            rotation_params = pickle.load(handle)
            for key in rotation_keys:
                setattr(block, key, rotation_params[key])
    except FileNotFoundError:
        print('No rotation matrix file, create it')


def create_saccade_events_df(eye_data_df, speed_threshold, bokeh_verify_threshold=False, magnitude_calib=1,
                             speed_profile=True):
    """
    Detects saccade events in eye tracking data and computes relevant metrics.

    Parameters:
    - eye_data_df (pd.DataFrame): Eye tracking data. Must contain the columns
      'center_x', 'center_y', 'ms_axis', 'OE_timestamp' and 'pupil_diameter'.
    - speed_threshold (float): A sample belongs to a saccade when its frame-to-frame
      displacement ('speed_r') is greater than this value.
    - bokeh_verify_threshold (bool): When True, plot 'speed_r' with the detected
      onsets marked (calls bokeh_plotter).
    - magnitude_calib (float): Multiplier applied to raw magnitudes and to dx/dy.
    - speed_profile (bool): When True, store the per-sample speed trace of each saccade.

    Returns:
    - df (pd.DataFrame): The eye data with 'speed_x', 'speed_y' and 'speed_r' added.
    - saccade_events_df (pd.DataFrame): One row per detected saccade with the columns
        - 'saccade_start_ind', 'saccade_end_ind': positional row indices of onset/offset.
        - 'saccade_start_timestamp', 'saccade_end_timestamp': 'OE_timestamp' at onset/offset.
        - 'saccade_on_ms', 'saccade_off_ms': 'ms_axis' at onset/offset.
        - 'length': saccade_end_ind - saccade_start_ind (in samples).
        - 'magnitude_raw': sum of 'speed_r' over the rows onset..offset (inclusive).
        - 'magnitude': magnitude_raw * magnitude_calib.
        - 'angle': direction from the onset position to the offset position, degrees in [0, 360).
        - 'initial_x', 'initial_y', 'end_x', 'end_y': pupil centre at onset/offset.
        - 'calib_dx', 'calib_dy': (end - initial) * magnitude_calib.
        - 'speed_profile' (only if speed_profile): 'speed_r' values over onset..offset.
        - 'diameter_profile': 'pupil_diameter' values over onset..offset.

    Notes:
    - The onset index is placed one sample before the threshold crossing, so the
      first (sometimes slower) frame of the movement is included.
    - A saccade that is still in progress at the last sample has no offset and is
      not reported.
    - The speed columns (and a boolean 'is_saccade' column) are written onto
      `eye_data_df` itself; the returned `df` is a new frame without 'is_saccade'.
      Pass a copy if the input must stay untouched.
    """
    df = eye_data_df
    df['speed_x'] = df['center_x'].diff()  # Difference between consecutive 'center_x' values
    df['speed_y'] = df['center_y'].diff()  # Difference between consecutive 'center_y' values

    # Magnitude of the velocity vector (R vector speed)
    df['speed_r'] = (df['speed_x'] ** 2 + df['speed_y'] ** 2) ** 0.5

    # Binary saccade mask; the first sample has NaN speed and is therefore never a saccade
    df['is_saccade'] = df['speed_r'] > speed_threshold

    # +1 marks a rising edge and -1 a falling edge of the mask
    saccade_on_off = df.is_saccade.astype(int) - df.is_saccade.shift(periods=1, fill_value=False).astype(int)
    # manual shift by one sample: include the frame just before the threshold crossing
    saccade_on_inds = np.where(saccade_on_off == 1)[0] - 1
    saccade_off_inds = np.where(saccade_on_off == -1)[0]
    # Edges alternate starting with an onset, so the only possible mismatch is a
    # trailing onset whose saccade runs past the end of the data; drop it so
    # onsets and offsets pair up one-to-one.
    saccade_on_inds = saccade_on_inds[:len(saccade_off_inds)]

    timestamps = df['OE_timestamp'].to_numpy()
    ms_axis = df['ms_axis'].to_numpy()
    saccade_events_df = pd.DataFrame({'saccade_start_ind': saccade_on_inds,
                                      'saccade_start_timestamp': timestamps[saccade_on_inds],
                                      'saccade_end_ind': saccade_off_inds,
                                      'saccade_end_timestamp': timestamps[saccade_off_inds],
                                      'saccade_on_ms': ms_axis[saccade_on_inds],
                                      'saccade_off_ms': ms_axis[saccade_off_inds]})
    saccade_events_df['length'] = saccade_events_df['saccade_end_ind'] - saccade_events_df['saccade_start_ind']
    # Drop the column used for the intermediate step
    df = df.drop(['is_saccade'], axis=1)

    # Per-event traces are taken by row position (onset..offset inclusive), which
    # does not depend on timestamps being unique or free of NaN.
    speed_r = df['speed_r']
    pupil_diameter = df['pupil_diameter'].to_numpy()
    distances = []
    speed_list = []
    diameter_list = []
    for on_ind, off_ind in zip(saccade_on_inds, saccade_off_inds):
        event_speed = speed_r.iloc[on_ind:off_ind + 1]
        distances.append(event_speed.sum())  # pandas sum skips NaN samples
        if speed_profile:
            speed_list.append(event_speed.to_numpy().copy())
        diameter_list.append(pupil_diameter[on_ind:off_ind + 1].copy())

    center_x = df['center_x'].to_numpy()
    center_y = df['center_y'].to_numpy()
    initial_x, initial_y = center_x[saccade_on_inds], center_y[saccade_on_inds]
    end_x, end_y = center_x[saccade_off_inds], center_y[saccade_off_inds]
    angles = np.arctan2(end_y - initial_y, end_x - initial_x)

    saccade_events_df['magnitude_raw'] = np.array(distances)
    saccade_events_df['magnitude'] = np.array(distances) * magnitude_calib
    # Convert radians to degrees in [0, 360); NaN angles are passed through unchanged
    saccade_events_df['angle'] = np.where(np.isnan(angles), angles, np.rad2deg(angles) % 360)
    saccade_events_df['initial_x'] = initial_x
    saccade_events_df['initial_y'] = initial_y
    saccade_events_df['end_x'] = end_x
    saccade_events_df['end_y'] = end_y
    saccade_events_df['calib_dx'] = (saccade_events_df['end_x'].values - saccade_events_df[
        'initial_x'].values) * magnitude_calib
    saccade_events_df['calib_dy'] = (saccade_events_df['end_y'].values - saccade_events_df[
        'initial_y'].values) * magnitude_calib
    if speed_profile:
        saccade_events_df['speed_profile'] = speed_list
    saccade_events_df['diameter_profile'] = diameter_list
    if bokeh_verify_threshold:
        bokeh_plotter(data_list=[df.speed_r], label_list=['Pupil Velocity'], peaks=saccade_on_inds)

    return df, saccade_events_df


# create a multi-animal block_collection:

def create_block_collections(animals, block_lists, experiment_path, bad_blocks=None):
    """
    Build a flat list and a lookup dictionary of blocks for several animals.

    Parameters:
    - animals: list of str, names of the animals.
    - block_lists: list of lists of int, block numbers for each animal; paired
      with `animals` by position.
    - experiment_path: pathlib.Path, path to the experiment directory.
    - bad_blocks: list of int, blocks to exclude. None is treated as an empty list.

    Returns:
    - block_collection: list with the objects produced by uf.block_generator for
      every animal, in input order.
    - block_dict: dictionary mapping "<animal>_block_<block_num>" to the same objects.
    """
    import UtilityFunctions_newOE as uf

    excluded = [] if bad_blocks is None else bad_blocks
    block_collection, block_dict = [], {}

    for animal_name, block_numbers in zip(animals, block_lists):
        generated = uf.block_generator(
            block_numbers=block_numbers,
            experiment_path=experiment_path,
            animal=animal_name,
            bad_blocks=excluded
        )
        block_collection.extend(generated)
        # Register each generated block under its animal-qualified key
        block_dict.update({f"{animal_name}_block_{blk.block_num}": blk for blk in generated})

    return block_collection, block_dict


In [None]:
# BLOCK DEFINITION #
# This was the previous run
#animals = ['PV_62', 'PV_126', 'PV_57']
#block_lists = [[24, 26, 38], [7, 8, 9, 10, 11, 12], [7, 8, 9, 12, 13]]
#This with new animals:
animals = ['PV_126']
block_lists = [[7]]
experiment_path = pathlib.Path(r"Z:\Nimrod\experiments")
bad_blocks = [0]  # Example of bad blocks

block_collection, block_dict = create_block_collections(
    animals=animals,
    block_lists=block_lists,
    experiment_path=experiment_path,
    bad_blocks=bad_blocks
)

# Eye-data files to load from each block's analysis folder. Other exports that
# have been used here (swap the two names below to switch):
#   left/right_eye_data_corr_angles.csv
#   left/right_eye_data_degrees_raw_xflipped.csv
#   left/right_eye_data_3d_corr_verified.csv
#   left/right_eye_data_degrees_rotated_verified.csv
left_eye_csv = 'left_eye_data_degrees_raw_verified.csv'
right_eye_csv = 'right_eye_data_degrees_raw_verified.csv'

for block in block_collection:
    block.parse_open_ephys_events()
    block.get_eye_brightness_vectors()
    block.synchronize_block()
    block.create_eye_brightness_df(threshold_value=20)

    # if the code fails here, go to manual synchronization
    block.import_manual_sync_df()
    block.read_dlc_data()
    block.calibrate_pixel_size(10)
    #load_eye_data_2d_w_rotation_matrix(block) #should be integrated again... later

    # Each block's eye data is read once, as part of that block's own pass.
    block.left_eye_data = pd.read_csv(block.analysis_path / left_eye_csv)
    block.right_eye_data = pd.read_csv(block.analysis_path / right_eye_csv)

# Pupil diameter is calibrated in the following cell.

In [None]:
# Add 'pupil_diameter_pixels' (taken from major_ax) and 'pupil_diameter'
# (pixels scaled by the block's per-eye pixel size) where they are missing.
for block in block_collection:
    # Each eye is checked on its own frame, so a block where only one eye
    # already carries 'pupil_diameter' still gets the other one computed.
    eye_frames = (
        (block.left_eye_data, 'L_pix_size'),
        (block.right_eye_data, 'R_pix_size'),
    )
    if any('pupil_diameter' not in eye_df.columns for eye_df, _ in eye_frames):
        print(f'calculating pupil diameter for {block} ')
    for eye_df, pix_size_attr in eye_frames:
        if 'pupil_diameter' in eye_df.columns:
            continue
        eye_df['pupil_diameter_pixels'] = eye_df.major_ax
        # The pixel size is only read when it is actually needed.
        eye_df['pupil_diameter'] = eye_df['pupil_diameter_pixels'] * getattr(block, pix_size_attr)

In [None]:
# This cell defines functions for manual annotation of extreme datapoints

import numpy as np
import pandas as pd
import cv2
from pathlib import Path
from typing import Optional, Tuple, Sequence, Union, List, Dict


# ---------------------------- helpers: contiguous runs on integer index ----------------------------
def _runs_from_index(int_index: np.ndarray, min_run_len: int = 1, max_gap: int = 1) -> List[Tuple[int, int]]:
    """
    Group a sorted integer index array into runs and return their (first, last) values.

    Two neighbouring entries stay in the same run while their difference is at
    most max_gap (max_gap=1: strictly consecutive indices; larger values bridge
    small holes). When min_run_len > 1, runs whose span (last - first + 1) is
    shorter than min_run_len are discarded.
    """
    if int_index.size == 0:
        return []

    # Positions at which a new run starts: right after every gap wider than max_gap
    split_points = np.flatnonzero(np.diff(int_index) > max_gap) + 1
    segments = np.split(int_index, split_points)
    runs = [(int(seg[0]), int(seg[-1])) for seg in segments]

    if min_run_len > 1:
        runs = [(first, last) for first, last in runs if last - first + 1 >= min_run_len]
    return runs


def _pad_ms_bounds(df: pd.DataFrame, start_ms: float, end_ms: float, pre_pad_ms: float, post_pad_ms: float) -> Tuple[
    float, float]:
    """
    Widen [start_ms, end_ms] by the given padding and clip it to the recording.

    Without an 'ms_axis' column the bounds are returned untouched (no padding).
    Otherwise pre_pad_ms is subtracted from the start, post_pad_ms is added to
    the end, and both are clamped between the first and last 'ms_axis' values
    (whichever order those two come in). An empty 'ms_axis' skips the clamping.
    """
    if "ms_axis" not in df.columns:
        return start_ms, end_ms

    ms = df["ms_axis"].values
    padded_start = start_ms - float(pre_pad_ms)
    padded_end = end_ms + float(post_pad_ms)
    if not ms.size:
        return float(padded_start), float(padded_end)

    # Recording limits taken from the two end points of the axis
    lower = min(ms[0], ms[-1])
    upper = max(ms[0], ms[-1])
    padded_start = max(lower, min(padded_start, upper))
    padded_end = max(lower, min(padded_end, upper))
    return float(padded_start), float(padded_end)


# ---------------------------- main: query -> events df ----------------------------
def query_to_events_df(
        eye_df: pd.DataFrame,
        query_str: str,
        *,
        animal: str,
        block: Union[int, str],
        eye: str,
        pre_pad_ms: float = 0.0,
        post_pad_ms: float = 0.0,
        min_run_len: int = 1,
        max_gap: int = 1,  # stitch small holes between hits (in index units)
        sort_by_start: bool = True,
) -> pd.DataFrame:
    """
    Turn a DataFrame slice via df.query(...) into lumped contiguous events for manual review.

    Parameters
    ----------
    eye_df : eye dataframe to filter.
    query_str : expression handed to `eye_df.query`. `@name` references are
        resolved by pandas from the scope that calls `query` (this function),
        so values should be written into the string as literals.
    animal, block : stored (as strings) in the output rows.
    eye : 'L'/'R' or 'left'/'right', case-insensitive; stored as 'L' or 'R'.
    pre_pad_ms, post_pad_ms : padding added around each event (see Notes).
    min_run_len, max_gap : forwarded to `_runs_from_index`.
    sort_by_start : sort the result by animal, block, eye and start_ms.

    Notes
    -----
    - Contiguity is defined on the *DataFrame index* (after filtering) when that
      index is integer-typed; otherwise on the row positions within `eye_df`.
    - start_ms/end_ms come from the first/last row's `ms_axis`. If ms_axis absent, it falls back to
      the index value / row position.
    - pre/post padding applied in ms (if ms_axis present; otherwise ignored).
    - Events whose end is not later than their start after padding (e.g. a single
      hit with no padding) are dropped.

    Returns
    -------
    DataFrame with: ['animal','block','eye','start_ms','end_ms']

    Raises
    ------
    ValueError : for an unrecognised `eye`, or when the query cannot be evaluated
        (the pandas error is attached as the cause).
    """
    out_cols = ["animal", "block", "eye", "start_ms", "end_ms"]

    eye_key = eye.lower() if isinstance(eye, str) else None
    if eye_key not in ("l", "r", "left", "right"):
        raise ValueError("eye should be 'L'/'R' (or 'left'/'right').")
    eye_short = "L" if eye_key.startswith("l") else "R"

    try:
        sub = eye_df.query(query_str)
    except Exception as e:
        raise ValueError(f"query failed: {e}") from e

    if sub.empty:
        return pd.DataFrame(columns=out_cols)

    # Remember whether the run boundaries are index *labels* or row *positions*,
    # so the rows are fetched with the matching accessor further down.
    by_label = pd.api.types.is_integer_dtype(sub.index.dtype)
    if by_label:
        int_idx = np.asarray(sub.index.values, dtype=int)
    else:
        pos_idx = eye_df.index.get_indexer(sub.index)
        int_idx = np.asarray(pos_idx[pos_idx >= 0], dtype=int)

    runs = _runs_from_index(np.sort(int_idx), min_run_len=min_run_len, max_gap=max_gap)

    has_ms = "ms_axis" in eye_df.columns
    if has_ms:
        ms_lookup = eye_df["ms_axis"].loc if by_label else eye_df["ms_axis"].iloc

    rows = []
    for i0, i1 in runs:
        if has_ms:
            s_ms = float(ms_lookup[i0])
            e_ms = float(ms_lookup[i1])
        else:
            # fall back to treating index as a proxy time
            s_ms = float(i0)
            e_ms = float(i1)

        s_ms, e_ms = _pad_ms_bounds(eye_df, s_ms, e_ms, pre_pad_ms, post_pad_ms)
        if e_ms <= s_ms:
            continue

        rows.append({
            "animal": str(animal),
            "block": str(block),
            "eye": eye_short,
            "start_ms": s_ms,
            "end_ms": e_ms,
        })

    out = pd.DataFrame(rows)
    if sort_by_start and not out.empty:
        out = out.sort_values(["animal", "block", "eye", "start_ms"]).reset_index(drop=True)
    return out


# ---------------------------- convenience: threshold -> events df ----------------------------
def threshold_to_events_df(
        eye_df: pd.DataFrame,
        column: str,
        *,
        op: str,  # one of: '>', '>=', '<', '<=', '==', '!=', 'abs>'
        value: float,
        animal: str,
        block: Union[int, str],
        eye: str,
        pre_pad_ms: float = 0.0,
        post_pad_ms: float = 0.0,
        min_run_len: int = 1,
        max_gap: int = 1,
        sort_by_start: bool = True,
) -> pd.DataFrame:
    """
    Build events by thresholding a single column without writing a query string.

    The comparison is turned into a query string (`abs(<column>) > <value>` for
    op='abs>', otherwise `<column> <op> <value>`) and passed to
    `query_to_events_df` together with the remaining keyword arguments.

    Examples:
        threshold_to_events_df(df, 'k_theta', op='<', value=-50, animal='PV_126', block=7, eye='L')
        threshold_to_events_df(df, 'speed_r', op='abs>', value=2.0, ...)

    Raises:
        ValueError: unknown column, unsupported `op`, or a non-finite `value`.
    """
    if column not in eye_df.columns:
        raise ValueError(f"column '{column}' not found in DataFrame.")

    if op != 'abs>' and op not in ('>', '>=', '<', '<=', '==', '!='):
        raise ValueError("op must be one of {'>','>=','<','<=','==','!=','abs>'}")

    # The value is written into the query as a literal. An '@value' reference
    # would be looked up by pandas in the scope that calls DataFrame.query,
    # which is query_to_events_df, where no such variable exists.
    if isinstance(value, (bool, np.bool_)):
        literal = repr(bool(value))
    elif isinstance(value, str):
        literal = repr(value)
    elif isinstance(value, (int, np.integer)):
        literal = repr(int(value))
    else:
        as_float = float(value)
        if not np.isfinite(as_float):
            raise ValueError("value must be a finite number.")
        literal = repr(as_float)

    if op == 'abs>':
        q = f"abs({column}) > {literal}"
    else:
        q = f"{column} {op} {literal}"

    return query_to_events_df(
        eye_df, q,
        animal=animal, block=block, eye=eye,
        pre_pad_ms=pre_pad_ms, post_pad_ms=post_pad_ms,
        min_run_len=min_run_len, max_gap=max_gap, sort_by_start=sort_by_start
    )


# ---------------------------- per-block CSV IO ----------------------------
# Names shared by the annotation load/merge/write helpers below.
# File name of the annotation CSV that lives inside a block's analysis folder.
CSV_NAME = "manual_event_annotations.csv"  # one file per block in its analysis folder
# Column holding the manual verdict: True, False, or None while untagged.
TAG_COL = "manual_outlier_detected"  # boolean (object-dtype), None if untagged
# Column holding when a tag was set, formatted as produced by _now_stamp().
TS_COL = "annotation_timestamp"  # 'YYYY_MM_DD_HH_MM'


def _now_stamp() -> str:
    """Return the current local system time formatted as 'YYYY_MM_DD_HH_MM'."""
    stamp_format = "%Y_%m_%d_%H_%M"
    current_time = pd.Timestamp.now()
    return current_time.strftime(stamp_format)


def _block_for_key(block_dict: Dict[str, object], animal: str, block_num: Union[str, int]):
    """
    Find the block object for (animal, block_num) in block_dict.

    Lookup order:
      1. a value whose `animal_call` equals `animal` and whose `block_num`,
         as a string, equals str(block_num);
      2. the dictionary keys "<animal>_block_<n>" and "<animal>_block_<nnn>"
         (zero-padded to three digits), when block_num is numeric;
      3. a value whose `animal_call` matches and whose `block_num` is the same
         number (so '007' and 7 are treated as the same block).

    Raises:
        KeyError: when nothing matches, including for a non-numeric block_num.
    """

    def _as_int(v):
        # Numeric view of a block number, or None when it has none
        try:
            return int(v)
        except (TypeError, ValueError):
            return None

    # your block_dict has various keys; be permissive
    wanted = str(block_num)
    for obj in block_dict.values():
        if getattr(obj, "animal_call", None) == animal and str(getattr(obj, "block_num", None)) == wanted:
            return obj

    wanted_int = _as_int(block_num)
    if wanted_int is not None:
        # try direct keys e.g. "PV_126_block_7" / "PV_126_block_007"
        for key in (f"{animal}_block_{wanted_int}", f"{animal}_block_{wanted_int:03d}"):
            if key in block_dict:
                return block_dict[key]
        # last resort: compare block numbers numerically to bridge zero-padding
        for obj in block_dict.values():
            if getattr(obj, "animal_call", None) == animal and _as_int(
                    getattr(obj, "block_num", None)) == wanted_int:
                return obj

    raise KeyError(f"BlockSync not found for animal='{animal}', block='{block_num}'")


def _ann_path_for_block(bs) -> Path:
    """Return the annotation CSV path inside bs.analysis_path, creating that folder if needed."""
    analysis_dir = Path(bs.analysis_path)
    analysis_dir.mkdir(parents=True, exist_ok=True)
    return analysis_dir.joinpath(CSV_NAME)


def load_block_annotations(bs: object) -> pd.DataFrame:
    """
    Load per-block annotation CSV if exists. Returns standardized df:
    ['animal_call','block','eye','start_ms','end_ms','manual_outlier_detected','annotation_timestamp']

    A missing file yields an empty frame with those columns; columns missing
    from the file are added as NaN. Text columns are returned with the pandas
    'string' dtype, start_ms/end_ms as float (unparseable entries become NaN),
    and the tag column as object holding True, False or None.
    """
    path = _ann_path_for_block(bs)
    cols = ['animal_call', 'block', 'eye', 'start_ms', 'end_ms', TAG_COL, TS_COL]
    if not path.exists():
        return pd.DataFrame(columns=cols)
    df = pd.read_csv(path)
    for c in cols:
        if c not in df.columns:
            df[c] = np.nan
    df = df[cols].copy()
    # normalize dtypes
    df['animal_call'] = df['animal_call'].astype('string')
    df['block'] = df['block'].astype('string')
    df['eye'] = df['eye'].astype('string')
    df['start_ms'] = pd.to_numeric(df['start_ms'], errors='coerce').astype(float)
    df['end_ms'] = pd.to_numeric(df['end_ms'], errors='coerce').astype(float)

    # boolean-as-object tri-state
    def _coerce_tag(v):
        if v is None or (isinstance(v, float) and np.isnan(v)):
            return None
        if isinstance(v, (bool, np.bool_)):
            return bool(v)
        # A column written as 1/0 with blanks is parsed by read_csv as floats;
        # treat 1.0/0.0 like the textual "1"/"0" accepted below.
        if isinstance(v, (int, float, np.integer, np.floating)):
            if v == 1:
                return True
            if v == 0:
                return False
            return None
        s = str(v).strip().lower()
        if s in ("true", "t", "1", "yes", "y"):
            return True
        if s in ("false", "f", "0", "no", "n"):
            return False
        return None

    df[TAG_COL] = df[TAG_COL].map(_coerce_tag).astype('object')
    df[TS_COL] = df[TS_COL].astype('string')
    return df


def _merge_annotations(old_df: pd.DataFrame, new_df: pd.DataFrame, overwrite: bool = True) -> pd.DataFrame:
    """
    Combine two annotation frames, keeping one row per event.

    Rows are considered the same event when animal_call, block, eye and the
    start_ms/end_ms values rounded to the nearest integer all agree. With
    overwrite=True the row coming from new_df survives a clash, otherwise the
    row from old_df does. If old_df is None or empty, a copy of new_df is
    returned as is.
    """
    if old_df is None or old_df.empty:
        return new_df.copy()

    key_cols = ['animal_call', 'block', 'eye', '_s', '_e']
    helper_cols = ['_s', '_e', '__ord__']

    combined = pd.concat([old_df, new_df], ignore_index=True)
    # integer keys make the start/end comparison robust to float noise
    combined['_s'] = np.rint(combined['start_ms']).astype('Int64')
    combined['_e'] = np.rint(combined['end_ms']).astype('Int64')
    # After sorting, the preferred row of each clash must come last:
    # ascending position favours new_df rows, negated position favours old_df rows.
    combined['__ord__'] = combined.index if overwrite else -combined.index

    ordered = combined.sort_values(key_cols + ['__ord__'])
    deduplicated = ordered.drop_duplicates(key_cols, keep='last')
    return deduplicated.drop(columns=helper_cols).reset_index(drop=True)


def _write_block_annotations(bs: object, df_block: pd.DataFrame, overwrite: bool = True) -> Path:
    """
    Merge df_block into the block's annotation CSV and save it.

    The annotations already on disk for `bs` are loaded, merged with df_block
    (`overwrite` decides which side wins a clash) and written back without the
    index. df_block must already be filtered to one (animal_call, block).
    Returns the path of the CSV that was written.
    """
    csv_path = _ann_path_for_block(bs)
    merged = _merge_annotations(load_block_annotations(bs), df_block, overwrite=overwrite)
    merged.to_csv(csv_path, index=False)
    return csv_path


# ---------------------------- UPDATED reviewer (per-block IO) ----------------------------
def review_events_multi_with_arena_v2(
        block_dict: Dict[str, object],
        events_subset_df: pd.DataFrame,
        *,
        animal_col: str = "animal",
        block_col: str = "block",
        eye_col: str = "eye",
        start_ms_col: Optional[str] = None,  # auto-detect from {'start_ms','saccade_on_ms'}
        end_ms_col: Optional[str] = None,  # auto-detect from {'end_ms','saccade_off_ms'}
        window_scale: float = 0.85,
        text_cols: Tuple[str, ...] = ("phi", "theta", "peak_velocity", "magnitude_raw_angular", "pupil_diameter"),
        font_scale: float = 0.6,
        thickness: int = 2,
        wait_ms: int = 15,
        flip_mode: str = "vertical",
        ms_tolerance: float = 0.0,  # for preloading within tolerance
        overwrite_existing: bool = True,
        show_arena: bool = True,
) -> pd.DataFrame:
    """
    Interactively review events in OpenCV windows and tag each one as BAD/GOOD.

    Workflow:
      * A working copy of ``events_subset_df`` is made (the input is not modified).
      * For every (animal, block) present, the block's stored annotations are loaded
        and used to pre-fill events that are still untagged, so the stored GOOD/BAD
        state is visible immediately.
      * Tags live in memory only; they are written to disk (one CSV per block,
        merged with what is already stored) when export is triggered.
      * Every tag change stamps the annotation timestamp column.

    Controls (buttons in the "Controls" window / keyboard):
      Play, Pause / Space toggles playback
      Prev, Next / '[' , ']'
      Step -1, Step +1 / ',' , '.'
      mark_bad / 'b'        mark_good / 'g'
      Arena Switch / 'a'    export_annotated_df / 'e'
      Esc or 'q' quits.

    Args:
        block_dict: container from which block objects are resolved per (animal, block).
        events_subset_df: events to review.
        animal_col, block_col, eye_col: names of the identifying columns.
        start_ms_col, end_ms_col: time columns; when None the first of
            'start_ms'/'saccade_on_ms' (resp. 'end_ms'/'saccade_off_ms') that exists is used.
        window_scale: display size as a fraction of the video frame size.
        text_cols: eye-data columns printed on top of each eye frame when present.
        font_scale, thickness: OpenCV text / ellipse drawing parameters.
        wait_ms: delay passed to cv2.waitKey on each loop iteration.
        flip_mode: "vertical" flips frames top-to-bottom; any other value shows them as-is.
        ms_tolerance: <= 0 matches stored annotations on rounded integer ms,
            > 0 matches the first stored row whose start and end are within this many ms.
        overwrite_existing: on export, whether in-memory rows replace stored rows.
        show_arena: whether to open and display the arena video.

    Returns:
        The working copy of the events, including the tag and timestamp columns and
        the internal ``_*_std`` / ``_*_key`` helper columns.

    Raises:
        ValueError: start/end columns cannot be resolved or a required column is missing.
        RuntimeError: eye data, 'ms_axis', or an eye video is missing / cannot be opened.
    """

    def _resolve_time_cols(df: pd.DataFrame, start_c: Optional[str], end_c: Optional[str]) -> Tuple[str, str]:
        """Pick the first existing start/end column, explicit names taking priority."""
        cand_start = [start_c, "start_ms", "saccade_on_ms"]
        cand_end = [end_c, "end_ms", "saccade_off_ms"]
        s = next((c for c in cand_start if c and c in df.columns), None)
        e = next((c for c in cand_end if c and c in df.columns), None)
        if s is None or e is None:
            raise ValueError("Could not resolve start/end ms columns.")
        return s, e

    # ---------------- normalise the input ----------------
    s_col, e_col = _resolve_time_cols(events_subset_df, start_ms_col, end_ms_col)
    events = events_subset_df.copy().reset_index(drop=True)

    for c in (animal_col, block_col, eye_col, s_col, e_col):
        if c not in events.columns:
            raise ValueError(f"events_subset_df is missing required column '{c}'")

    # Working columns with fixed names/dtypes, independent of the caller's schema.
    events["_animal_std"] = events[animal_col].astype(str)
    events["_block_std"] = events[block_col].astype(str)
    events["_eye_std"] = events[eye_col].astype(str)
    events["_start_ms_std"] = pd.to_numeric(events[s_col], errors="coerce").astype(float)
    events["_end_ms_std"] = pd.to_numeric(events[e_col], errors="coerce").astype(float)
    events["_start_key"] = np.rint(events["_start_ms_std"]).astype("Int64")
    events["_end_key"] = np.rint(events["_end_ms_std"]).astype("Int64")

    # Tri-state tag (None / True / False) and its timestamp.
    for c in (TAG_COL, TS_COL):
        if c not in events.columns:
            events[c] = pd.Series([None] * len(events), dtype='object')

    def _coerce_tag(v):
        """Map assorted truthy/falsy spellings to True/False; anything else to None."""
        if v is None or (isinstance(v, float) and np.isnan(v)):
            return None
        if isinstance(v, bool):
            return v
        s = str(v).strip().lower()
        if s in ("true", "t", "1", "yes", "y"):
            return True
        if s in ("false", "f", "0", "no", "n"):
            return False
        return None

    def _is_untagged(i) -> bool:
        # Treat None, NaN and unrecognised values alike so that a pre-existing
        # but empty tag column does not block preloading.
        return _coerce_tag(events.at[i, TAG_COL]) is None

    # ---------------- preload per-block annotations ----------------
    # For each (animal, block): load that block's stored annotations, then match
    # by (eye, start, end) — exactly on integer ms, or within ms_tolerance.
    grouped_keys = events.groupby(["_animal_std", "_block_std"]).indices
    filled = 0
    for (animal, block), idxs in grouped_keys.items():
        bs = _block_for_key(block_dict, animal, block)
        old = load_block_annotations(bs)
        if old is None or old.empty:
            continue
        old["_start_key"] = np.rint(old["start_ms"]).astype("Int64")
        old["_end_key"] = np.rint(old["end_ms"]).astype("Int64")

        if ms_tolerance <= 0:
            lookup = {
                (str(r["animal_call"]), str(r["block"]), str(r["eye"]),
                 int(r["_start_key"]), int(r["_end_key"])): (r[TAG_COL], r[TS_COL])
                for _, r in old.dropna(subset=["_start_key", "_end_key"]).iterrows()
            }
            for i in idxs:
                sk, ek = events.at[i, "_start_key"], events.at[i, "_end_key"]
                if pd.isna(sk) or pd.isna(ek):
                    continue  # events without a valid time cannot be keyed
                k = (animal, block, str(events.at[i, "_eye_std"]), int(sk), int(ek))
                if k in lookup and _is_untagged(i):
                    tag, ts = lookup[k]
                    events.at[i, TAG_COL] = _coerce_tag(tag)
                    events.at[i, TS_COL] = ts if pd.notna(ts) else None
                    filled += 1
        else:
            tol = float(ms_tolerance)
            for i in idxs:
                e_eye = str(events.at[i, "_eye_std"])
                s0, e0 = events.at[i, "_start_ms_std"], events.at[i, "_end_ms_std"]
                sub_old = old[old["eye"].astype(str) == e_eye]
                if sub_old.empty:
                    continue
                # first stored row within tolerance on both ends
                hit = sub_old[(sub_old["start_ms"].sub(s0).abs() <= tol) &
                              (sub_old["end_ms"].sub(e0).abs() <= tol)]
                if not hit.empty and _is_untagged(i):
                    r = hit.iloc[0]
                    events.at[i, TAG_COL] = _coerce_tag(r[TAG_COL])
                    events.at[i, TS_COL] = r[TS_COL] if pd.notna(r[TS_COL]) else None
                    filled += 1
    if filled:
        print(f"[preload] Filled {filled} existing annotations from per-block CSVs.")

    if events.empty:
        # Nothing to show; bail out before any window is created.
        print("[review] No events to review.")
        return events

    # ---------------- small data helpers ----------------
    def _frame_col(df: pd.DataFrame) -> Optional[str]:
        """Return the name of the video-frame column of an eye table, if any."""
        for c in ("eye_frame", "frame", "frame_idx", "video_frame"):
            if c in df.columns:
                return c
        return None

    def _lookup_block(animal: str, block_num: str):
        return _block_for_key(block_dict, animal, block_num)

    def _nearest_row(df: pd.DataFrame, ms: float) -> Optional[pd.Series]:
        """Row of df whose 'ms_axis' is closest to ms (None for an empty table)."""
        arr = df["ms_axis"].values
        if arr.size == 0:
            return None
        idx = int(np.argmin(np.abs(arr - ms)))
        return df.iloc[idx]

    def _apply_flip(img: np.ndarray) -> np.ndarray:
        return cv2.flip(img, 0) if flip_mode == "vertical" else img

    def _median_step_ms(df: pd.DataFrame) -> float:
        """Median positive sample spacing of 'ms_axis'; 1/60 s when it cannot be estimated."""
        if df is None or df.empty:
            return 1000.0 / 60.0
        ms = df["ms_axis"].values
        if ms.size < 2:
            return 1000.0 / 60.0
        d = np.diff(ms)
        d = d[np.isfinite(d) & (d > 0)]
        return float(np.median(d)) if d.size else 1000.0 / 60.0

    # ---------------- video state (shared by the closures below) ----------------
    capL = capR = capA = None
    cur_animal = cur_block = None
    left_df = right_df = arena_df = None
    left_frame_col = right_frame_col = None
    arena_frame_col = None
    fpsL = fpsR = fpsA = 60.0
    Wl = Hl = Wr = Hr = Wa = Ha = 0
    arena_idx = 0

    def _release_caps():
        """Release all open captures; failures on release are ignored (best effort)."""
        nonlocal capL, capR, capA
        for c in (capL, capR, capA):
            try:
                if c is not None:
                    c.release()
            except Exception:
                pass
        capL = capR = capA = None

    def _resolve_arena_frame_col(df: pd.DataFrame) -> str:
        for c in ["Arena_frame", "arena_frame", "arena_frames", "arena_frame_idx", "frame", "frame_idx", "video_frame",
                  "arena_idx"]:
            if c in df.columns:
                return c
        raise RuntimeError("final_sync_df has no recognizable arena frame column.")

    def _ensure_arena_ms_axis(bs) -> Tuple[pd.DataFrame, str]:
        """
        Return (arena table, frame column). If final_sync_df lacks 'ms_axis' it is
        joined from an eye table via 'OE_timestamp'; if that is impossible the table
        is flagged with '__need_video_time__' so time can be derived from the fps later.
        """
        fs = bs.final_sync_df.copy()
        fcol = _resolve_arena_frame_col(fs)
        if "ms_axis" not in fs.columns:
            joined = False
            for eye_df in [getattr(bs, "left_eye_data", None), getattr(bs, "right_eye_data", None)]:
                if eye_df is not None and "OE_timestamp" in fs.columns and \
                        "OE_timestamp" in eye_df.columns and "ms_axis" in eye_df.columns:
                    tmp = eye_df[["OE_timestamp", "ms_axis"]].dropna().drop_duplicates(subset=["OE_timestamp"])
                    fs = fs.merge(tmp, on="OE_timestamp", how="left")
                    joined = True
                    break
            if not joined:
                fs["__need_video_time__"] = True
        keep = [c for c in ["ms_axis", fcol, "OE_timestamp", "__need_video_time__"] if c in fs.columns]
        out = fs[keep].dropna(subset=[fcol]).copy()
        out[fcol] = pd.to_numeric(out[fcol], errors="coerce").astype("Int64")
        out = out.dropna(subset=[fcol])
        return out, fcol

    def _adopt_arena_capture(cap) -> None:
        """Make an opened capture the active arena video and refresh the derived state."""
        nonlocal capA, Wa, Ha, fpsA, arena_df
        capA = cap
        Wa, Ha = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fpsA = cap.get(cv2.CAP_PROP_FPS) or 60.0
        if arena_df is not None and "__need_video_time__" in arena_df.columns:
            # No synchronised time available: derive it from frame index and fps.
            arena_df["ms_axis"] = (arena_df[arena_frame_col].astype(float) / float(fpsA)) * 1000.0
            arena_df = arena_df.drop(columns=["__need_video_time__"], errors="ignore")

    def _open_for(animal: str, block_num: str):
        """Open the videos and tables of the given block (no-op if already current)."""
        nonlocal cur_animal, cur_block, left_df, right_df, arena_df
        nonlocal left_frame_col, right_frame_col, arena_frame_col
        nonlocal capL, capR, fpsL, fpsR, Wl, Hl, Wr, Hr, arena_idx

        if animal == cur_animal and block_num == cur_block:
            return

        _release_caps()
        bs = _lookup_block(animal, block_num)
        cur_animal, cur_block = animal, block_num

        left_df = getattr(bs, "left_eye_data", None)
        right_df = getattr(bs, "right_eye_data", None)
        if left_df is None or right_df is None:
            raise RuntimeError(f"Missing eye data for {animal} B{block_num}")
        if "ms_axis" not in left_df.columns or "ms_axis" not in right_df.columns:
            raise RuntimeError(f"'ms_axis' missing for {animal} B{block_num}")
        left_frame_col = _frame_col(left_df)
        right_frame_col = _frame_col(right_df)

        # eye videos
        lv = Path(bs.le_videos[0])
        rv = Path(bs.re_videos[0])
        capL_local = cv2.VideoCapture(str(lv))
        capR_local = cv2.VideoCapture(str(rv))
        if not (capL_local.isOpened() and capR_local.isOpened()):
            failed = lv if not capL_local.isOpened() else rv
            # Do not leak the capture that did open.
            capL_local.release()
            capR_local.release()
            raise RuntimeError(f"Cannot open {failed}")
        capL, capR = capL_local, capR_local
        Wl, Hl = int(capL.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capL.get(cv2.CAP_PROP_FRAME_HEIGHT))
        Wr, Hr = int(capR.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capR.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fpsL = capL.get(cv2.CAP_PROP_FPS) or 60.0
        fpsR = capR.get(cv2.CAP_PROP_FPS) or 60.0

        # arena video (optional)
        if show_arena and getattr(bs, "arena_videos", None):
            arena_df, arena_frame_col = _ensure_arena_ms_axis(bs)
            arena_idx = max(0, min(arena_idx, len(bs.arena_videos) - 1))
            av = Path(bs.arena_videos[arena_idx])
            cap = cv2.VideoCapture(str(av))
            if cap.isOpened():
                _adopt_arena_capture(cap)
            else:
                cap.release()
                print(f"[arena] Cannot open: {av}")
        else:
            arena_df = None

    def _seek(cap, idx: int):
        cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, int(idx)))

    def _read(cap):
        ret, f = cap.read()
        return f if ret else None

    # ---------------- drawing helpers ----------------
    def _overlay_text(img, lines: List[str], origin=(10, 24), vstep=22, color=(255, 255, 255)):
        x, y = origin
        for ln in lines:
            cv2.putText(img, ln, (x, y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness, cv2.LINE_AA)
            y += vstep

    def _overlay_ellipse(img, df_eye: Optional[pd.DataFrame], frame_idx: Optional[int]):
        """Draw the ellipse stored for frame_idx (center/width/height/phi columns), if complete."""
        if df_eye is None or frame_idx is None:
            return
        col = _frame_col(df_eye)
        if col is None:
            return
        hit = df_eye[df_eye[col] == frame_idx]
        if hit.empty:
            return
        row = hit.iloc[0]
        cx, cy = row.get("center_x", np.nan), row.get("center_y", np.nan)
        w, h = row.get("width", np.nan), row.get("height", np.nan)
        phi = row.get("phi", np.nan)
        if not (pd.isna(cx) or pd.isna(cy) or pd.isna(w) or pd.isna(h)):
            cv2.ellipse(
                img,
                (int(round(cx)), int(round(cy))),
                (max(1, int(round(w))), max(1, int(round(h)))),
                float(0 if pd.isna(phi) else phi),
                0, 360, (0, 255, 0), thickness
            )

    def _render_eye(label: str, cap, df_eye: pd.DataFrame, frame_col: Optional[str],
                    row: Optional[pd.Series], height: int, width: int,
                    animal: str, block_num: str, t_ms: float) -> np.ndarray:
        """Build the annotated image for one eye at time t_ms (black canvas on failure)."""
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        if row is None or frame_col is None or pd.isna(row[frame_col]):
            # Text goes on the unflipped canvas so it stays readable.
            _overlay_text(canvas, [f"{label} | {animal} B{block_num}", f"t={t_ms:.1f}ms", "no synchronized frame"],
                          origin=(10, 24))
            return canvas

        frame_idx = int(row[frame_col])
        _seek(cap, frame_idx)
        frame = _read(cap)
        if frame is None:
            return canvas

        # Ellipse is drawn in raw frame coordinates, i.e. before the flip.
        _overlay_ellipse(frame, df_eye, frame_idx)
        img = _apply_flip(frame.copy())
        lines = [f"{label} | {animal} B{block_num} | t={t_ms:.1f}ms | frame={frame_idx}"]
        for c in text_cols:
            if c in df_eye.columns:
                v = row.get(c, np.nan)
                if pd.notna(v):
                    try:
                        lines.append(f"{c}={float(v):.3f}")
                    except (TypeError, ValueError):
                        pass  # non-numeric value: just leave it off the overlay
        _overlay_text(img, lines, origin=(10, 24))
        return img

    # display geometry (0 means "not sized yet")
    disp_Wl = disp_Hl = disp_Wr = disp_Hr = disp_Wa = disp_Ha = 0

    # ---------------- control panel ----------------
    ctrl_w, ctrl_h = 540, 420
    buttons = {
        "Play": ((10, 10), (260, 60)),
        "Pause": ((280, 10), (530, 60)),
        "Prev": ((10, 80), (260, 130)),
        "Next": ((280, 80), (530, 130)),
        "Step -1": ((10, 150), (260, 200)),
        "Step +1": ((280, 150), (530, 200)),
        "mark_bad": ((10, 220), (260, 270)),
        "mark_good": ((280, 220), (530, 270)),
        "Arena Switch": ((10, 290), (260, 340)),
        "export_annotated_df": ((280, 290), (530, 340)),
    }

    COLOR_BG = (60, 60, 60)
    COLOR_BORDER = (180, 180, 180)
    COLOR_TEXT = (220, 220, 220)
    COLOR_BAD = (0, 0, 255)
    COLOR_GOOD = (0, 255, 0)
    COLOR_EXPORT = (0, 165, 255)
    COLOR_INFO = (180, 255, 180)

    last_status = ""

    def _draw_controls(idx: int, current_arena_name: Optional[str]) -> np.ndarray:
        """Render the control panel for event idx (tag True = BAD, False = GOOD)."""
        img = np.zeros((ctrl_h, ctrl_w, 3), dtype=np.uint8)
        state = _coerce_tag(events[TAG_COL].iloc[idx])
        state_str = "UNSET" if state is None else ("BAD" if state else "GOOD")
        header = f"Event {idx + 1}/{len(events)} | state={state_str}"
        cv2.putText(img, header, (10, ctrl_h - 12), cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_TEXT, 1, cv2.LINE_AA)

        for name, ((x1, y1), (x2, y2)) in buttons.items():
            fill = COLOR_BG
            if name == "mark_bad" and state is True:
                fill = COLOR_BAD
            if name == "mark_good" and state is False:
                fill = COLOR_GOOD
            if name == "export_annotated_df":
                fill = COLOR_EXPORT
            cv2.rectangle(img, (x1, y1), (x2, y2), fill, -1)
            cv2.rectangle(img, (x1, y1), (x2, y2), COLOR_BORDER, 2)
            text_size, _ = cv2.getTextSize(name, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
            tx = x1 + (x2 - x1 - text_size[0]) // 2
            ty = y1 + (y2 - y1 + text_size[1]) // 2
            cv2.putText(img, name, (tx, ty), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)

        if last_status:
            cv2.putText(img, last_status, (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, COLOR_INFO, 1, cv2.LINE_AA)

        if current_arena_name:
            cv2.putText(img, f"Arena: {current_arena_name}", (10, 44), cv2.FONT_HERSHEY_SIMPLEX, 0.55, COLOR_INFO, 1,
                        cv2.LINE_AA)

        return img

    def _hit_button(x, y):
        for name, ((x1, y1), (x2, y2)) in buttons.items():
            if x1 <= x <= x2 and y1 <= y <= y2:
                return name
        return None

    def _export_per_block():
        """Split the in-memory annotations by (animal, block) and merge each into its block CSV."""
        nonlocal last_status
        out = pd.DataFrame({
            "animal_call": events["_animal_std"].astype(str),
            "block": events["_block_std"].astype(str),
            "eye": events["_eye_std"].astype(str),
            "start_ms": events["_start_ms_std"].astype(float),
            "end_ms": events["_end_ms_std"].astype(float),
            TAG_COL: events[TAG_COL].map(_coerce_tag).astype("object"),
            TS_COL: events[TS_COL].astype('string')
        })
        written = []
        for (animal, block), sub in out.groupby(["animal_call", "block"], dropna=False):
            bs = _block_for_key(block_dict, animal, block)
            path = _write_block_annotations(bs, sub.copy(), overwrite=overwrite_existing)
            written.append(str(path))
        last_status = f"Exported to {len(written)} block file(s)."

    # ---------------- playback state ----------------
    playing = False
    cur_idx = 0
    cur_ms = None
    step_ms = 1000.0 / 60.0
    controls_img = None

    def _current_arena_name(bs) -> Optional[str]:
        """File name of the active arena video, or None; purely cosmetic, so never raises."""
        try:
            if show_arena and getattr(bs, "arena_videos", None):
                return Path(bs.arena_videos[arena_idx]).name
        except Exception:
            pass
        return None

    def _goto_event(delta: int) -> None:
        """Jump to another event (wrapping around) and restart it from its beginning."""
        nonlocal playing, cur_idx, cur_ms
        playing = False
        cur_idx = (cur_idx + delta) % len(events)
        cur_ms = None

    def _step_time(direction: int) -> None:
        """Pause and move the playhead by one sample step."""
        nonlocal playing, cur_ms
        playing = False
        cur_ms = None if cur_ms is None else cur_ms + direction * step_ms

    def _tag_current(is_bad: bool) -> None:
        """Set the tag of the current event and stamp the change time."""
        label = events.index[cur_idx]
        events.at[label, TAG_COL] = is_bad
        events.at[label, TS_COL] = _now_stamp()

    def on_mouse_controls(event, x, y, flags, param):
        nonlocal playing, controls_img, last_status, arena_idx, capA
        if event != cv2.EVENT_LBUTTONDOWN:
            return
        name = _hit_button(x, y)
        last_status = ""
        if name == "Play":
            playing = True
        elif name == "Pause":
            playing = False
        elif name == "Prev":
            _goto_event(-1)
        elif name == "Next":
            _goto_event(+1)
        elif name == "Step -1":
            _step_time(-1)
        elif name == "Step +1":
            _step_time(+1)
        elif name == "mark_bad":
            _tag_current(True)
        elif name == "mark_good":
            _tag_current(False)
        elif name == "Arena Switch":
            bs = _lookup_block(str(events.loc[cur_idx, "_animal_std"]), str(events.loc[cur_idx, "_block_std"]))
            if show_arena and getattr(bs, "arena_videos", None):
                arena_idx = (arena_idx + 1) % len(bs.arena_videos)
                if capA is not None:
                    capA.release()
                    capA = None  # never leave a released capture as the active one
                av = Path(bs.arena_videos[arena_idx])
                capA_local = cv2.VideoCapture(str(av))
                if capA_local.isOpened():
                    # Must rebind the enclosing-scope capture that the main loop reads.
                    _adopt_arena_capture(capA_local)
                    last_status = f"[arena] Showing: {av.name}"
                else:
                    capA_local.release()
                    last_status = f"[arena] Cannot open: {av.name}"
        elif name == "export_annotated_df":
            _export_per_block()

        bs = _lookup_block(str(events.loc[cur_idx, "_animal_std"]), str(events.loc[cur_idx, "_block_std"]))
        controls_img = _draw_controls(cur_idx, _current_arena_name(bs))
        cv2.imshow("Controls", controls_img)

    # ---------------- UI loop ----------------
    try:
        cv2.namedWindow("Controls", cv2.WINDOW_NORMAL)
        cv2.namedWindow("Left Eye", cv2.WINDOW_NORMAL)
        cv2.namedWindow("Right Eye", cv2.WINDOW_NORMAL)
        if show_arena:
            cv2.namedWindow("Arena", cv2.WINDOW_NORMAL)

        controls_img = _draw_controls(cur_idx, None)
        cv2.imshow("Controls", controls_img)
        cv2.setMouseCallback("Controls", on_mouse_controls)

        while True:
            k = cv2.waitKey(wait_ms) & 0xFF
            if k in (27, ord('q'), ord('Q')):
                break
            elif k == 32:
                playing = not playing
            elif k == ord('['):
                _goto_event(-1)
            elif k == ord(']'):
                _goto_event(+1)
            elif k == ord(','):
                _step_time(-1)
            elif k == ord('.'):
                _step_time(+1)
            elif k in (ord('b'), ord('B')):
                _tag_current(True)
            elif k in (ord('g'), ord('G')):
                _tag_current(False)
            elif k in (ord('e'), ord('E')):
                _export_per_block()
            elif k in (ord('a'), ord('A')):
                # Reuse the button handler by simulating a click inside "Arena Switch".
                on_mouse_controls(cv2.EVENT_LBUTTONDOWN, *buttons["Arena Switch"][0], None, None)

            row = events.iloc[cur_idx]
            animal = str(row["_animal_std"])
            block_num = str(row["_block_std"])
            start_ms = float(row["_start_ms_std"])
            end_ms = float(row["_end_ms_std"])

            _open_for(animal, block_num)

            # first-time sizing of the eye/control windows
            if disp_Wl == 0:
                disp_Wl, disp_Hl = int(Wl * window_scale), int(Hl * window_scale)
                disp_Wr, disp_Hr = int(Wr * window_scale), int(Hr * window_scale)
                cv2.resizeWindow("Left Eye", disp_Wl, disp_Hl)
                cv2.resizeWindow("Right Eye", disp_Wr, disp_Hr)
                cv2.resizeWindow("Controls", ctrl_w, ctrl_h)
            # The arena capture may first become available on a later block or after
            # a switch, so its window is sized independently of the eye windows.
            if show_arena and capA is not None and disp_Wa == 0:
                disp_Wa = max(1, int(Wa * window_scale))
                disp_Ha = max(1, int(Ha * window_scale))
                cv2.resizeWindow("Arena", disp_Wa, disp_Ha)

            # playback step follows the eye-data sampling interval
            step_ms = np.mean([_median_step_ms(left_df), _median_step_ms(right_df)])

            if cur_ms is None:
                cur_ms = start_ms
            cur_ms = min(max(cur_ms, start_ms), end_ms)

            rowL = _nearest_row(left_df, cur_ms)
            rowR = _nearest_row(right_df, cur_ms)

            L_img = _render_eye("Left", capL, left_df, left_frame_col, rowL, Hl, Wl, animal, block_num, cur_ms)
            R_img = _render_eye("Right", capR, right_df, right_frame_col, rowR, Hr, Wr, animal, block_num, cur_ms)

            # Arena frame
            if show_arena and capA is not None and arena_df is not None:
                arow = _nearest_row(arena_df, cur_ms)
                if arow is not None and pd.notna(arow[arena_frame_col]):
                    A_idx = int(arow[arena_frame_col])
                    _seek(capA, A_idx)
                    fA = _read(capA)
                    if fA is not None:
                        A_img = _apply_flip(fA.copy())
                        _overlay_text(A_img, [f"Arena | {animal} B{block_num} | t={cur_ms:.1f}ms | frame={A_idx}"],
                                      origin=(10, 24))
                        cv2.imshow("Arena", cv2.resize(A_img, (disp_Wa, disp_Ha)))
                else:
                    blank = np.zeros((max(1, Ha), max(1, Wa), 3), dtype=np.uint8)
                    _overlay_text(blank, [f"Arena | no synchronized frame"], origin=(10, 24))
                    cv2.imshow("Arena", cv2.resize(blank, (disp_Wa, disp_Ha)))

            cv2.imshow("Left Eye", cv2.resize(L_img, (disp_Wl, disp_Hl)))
            cv2.imshow("Right Eye", cv2.resize(R_img, (disp_Wr, disp_Hr)))
            bs = _lookup_block(animal, block_num)
            controls_img = _draw_controls(cur_idx, _current_arena_name(bs))
            cv2.imshow("Controls", controls_img)

            # playback: advance, and move on to the next event once this one is finished
            if playing:
                cur_ms += step_ms
                if cur_ms > end_ms:
                    playing = False
                    cur_idx = (cur_idx + 1) % len(events)
                    cur_ms = None
    finally:
        # Always give the video handles and windows back, also when an error escapes.
        _release_caps()
        cv2.destroyAllWindows()

    return events


# ======================= tiny helper to build events_df from a query quickly =======================
def make_events_df(
        starts: Sequence[Union[int, float]],
        ends: Sequence[Union[int, float]],
        *,
        animal: Union[str, Sequence[str]],
        block: Union[Union[int, str], Sequence[Union[int, str]]],
        eye: Union[str, Sequence[str]],
        time_unit: str = "ms",
        sort_by_start: bool = True,
        drop_invalid: bool = True,
) -> pd.DataFrame:
    """
    Build an events table with columns ['animal', 'block', 'eye', 'start_ms', 'end_ms'].

    Inputs are combined positionally: the i-th start goes with the i-th end and the
    i-th animal/block/eye. Any pandas index carried by the inputs is ignored.

    Args:
        starts, ends: event boundaries, same length; non-numeric entries become NaN
            and those rows are dropped.
        animal, block, eye: either one value applied to every event, or a
            list/tuple/ndarray/Series with one value per event.
        time_unit: 'ms' (also 'millisecond(s)') or anything starting with 's'
            (seconds, converted to milliseconds). Case-insensitive.
        sort_by_start: sort by (animal, block, eye, start_ms, end_ms) and renumber rows.
        drop_invalid: drop events whose start is not strictly before their end.

    Returns:
        The events DataFrame; identity columns use the pandas 'string' dtype.

    Raises:
        ValueError: on length mismatches or an unsupported time_unit.
    """
    n = len(starts)
    if n != len(ends):
        raise ValueError("starts and ends must be the same length")

    # Validate the unit up front, before doing any conversion work.
    unit = time_unit.lower()
    if unit.startswith("s"):
        to_ms = 1000.0
    elif unit in {"ms", "millisecond", "milliseconds"}:
        to_ms = 1.0
    else:
        raise ValueError("time_unit must be 'ms' or 's'")

    def _broadcast(x, name):
        if isinstance(x, (list, tuple, np.ndarray, pd.Series)):
            if len(x) != n:
                raise ValueError(f"{name} length must match {n}")
            return list(x)
        return [x] * n

    def _ms_series(values) -> pd.Series:
        # list() drops any incoming pandas index; otherwise the DataFrame constructor
        # would align on index labels and detach times from their animal/block/eye.
        numeric = pd.to_numeric(pd.Series(list(values)), errors="coerce").astype(float)
        return numeric * to_ms

    df = pd.DataFrame({
        "animal": pd.Series(_broadcast(animal, "animal"), dtype="string"),
        "block": pd.Series(_broadcast(block, "block"), dtype="string"),
        "eye": pd.Series(_broadcast(eye, "eye"), dtype="string"),
        "start_ms": _ms_series(starts),
        "end_ms": _ms_series(ends),
    }).dropna(subset=["start_ms", "end_ms"])

    if drop_invalid:
        df = df[df["start_ms"] < df["end_ms"]].copy()
    if sort_by_start and not df.empty:
        df = df.sort_values(["animal", "block", "eye", "start_ms", "end_ms"]).reset_index(drop=True)
    return df
def query_to_events_df_merge(
    eye_df: pd.DataFrame,
    query_str: str,
    *,
    animal: str,
    block: Union[int, str],
    eye: str,                      # 'L' or 'R' (case-insensitive)
    pre_pad_ms: float = 0.0,
    post_pad_ms: float = 0.0,
    min_run_len: int = 1,
    max_gap: int = 1,
    sort_by_start: bool = True,
    # --- merging controls ---
    existing_qdf: Optional[pd.DataFrame] = None,
    merge_tol_ms: float = 1.0,     # rows with |Δstart|<=tol & |Δend|<=tol are considered the same event
    binocular_label: str = "LR",   # label to use when an event exists for both eyes
) -> pd.DataFrame:
    """
    Turn the rows selected by a pandas query into events and merge them into an
    existing events table, collapsing duplicates seen by both eyes.

    Rows matching ``query_str`` are grouped into runs of neighbouring samples; each
    run becomes one event spanning the 'ms_axis' of its first and last sample
    (sample indices are used when 'ms_axis' is absent), padded and clamped to the
    recording range. Events of zero or negative duration are discarded, so a
    single-sample run without padding produces no event.

    Args:
        eye_df: per-sample eye table to query.
        query_str: expression passed to ``DataFrame.query``.
        animal, block: identifiers written (as strings) to every event.
        eye: values starting with 'l'/'L' are labelled 'L', everything else 'R'.
        pre_pad_ms, post_pad_ms: padding applied when 'ms_axis' exists.
        min_run_len: minimum index span (last - first + 1) of a run.
        max_gap: largest index difference still treated as contiguous.
        sort_by_start: sort the result by identity and start time.
        existing_qdf: events to merge into; not modified. None/empty returns only
            the newly built events.
        merge_tol_ms: events are matched by rounding start/end to multiples of this
            value, so it is a bucket width rather than an exact distance test.
        binocular_label: eye label for an event present for both 'L' and 'R'.

    Returns:
        DataFrame with columns ['animal', 'block', 'eye', 'start_ms', 'end_ms'].

    Raises:
        ValueError: when the query cannot be evaluated.
    """
    event_cols = ["animal", "block", "eye", "start_ms", "end_ms"]

    def _runs_from_index(int_index: np.ndarray, min_len: int = 1, gap: int = 1) -> List[Tuple[int, int]]:
        """Split sorted indices into (first, last) runs wherever the step exceeds gap."""
        if int_index.size == 0:
            return []
        boundaries = np.where(np.diff(int_index) > gap)[0]
        firsts = np.r_[0, boundaries + 1]
        lasts = np.r_[boundaries, len(int_index) - 1]
        runs = [(int(int_index[a]), int(int_index[b])) for a, b in zip(firsts, lasts)]
        if min_len > 1:
            runs = [r for r in runs if (r[1] - r[0] + 1) >= min_len]
        return runs

    def _pad_ms_bounds(df: pd.DataFrame, start_ms: float, end_ms: float) -> Tuple[float, float]:
        """Apply the padding and clamp to the range given by the first/last 'ms_axis' value."""
        if "ms_axis" not in df.columns:
            return start_ms, end_ms
        ms = df["ms_axis"].values
        s = start_ms - float(pre_pad_ms)
        e = end_ms + float(post_pad_ms)
        if ms.size:
            lo, hi = (ms[0], ms[-1]) if ms[-1] >= ms[0] else (ms[-1], ms[0])
            s = min(max(s, lo), hi)
            e = min(max(e, lo), hi)
        return float(s), float(e)

    def _build_from_query() -> pd.DataFrame:
        try:
            sub = eye_df.query(query_str)
        except Exception as exc:
            raise ValueError(f"query failed: {exc}") from exc
        if sub.empty:
            return pd.DataFrame(columns=event_cols)

        # Contiguity is judged on integer index labels; for any other index type
        # fall back to row positions within eye_df.
        positional = not np.issubdtype(sub.index.dtype, np.integer)
        if positional:
            pos_idx = eye_df.index.get_indexer(sub.index)
            int_idx = pos_idx[pos_idx >= 0]
        else:
            int_idx = sub.index.values
        int_idx = np.asarray(np.sort(int_idx), dtype=int)
        runs = _runs_from_index(int_idx, min_len=min_run_len, gap=max_gap)

        def _row_at(i: int, fallback: pd.Series) -> pd.Series:
            if positional:
                # Positions must never be reinterpreted as labels (e.g. float index).
                return eye_df.iloc[i]
            try:
                return eye_df.loc[i]
            except KeyError:
                return eye_df.iloc[i] if (0 <= i < len(eye_df)) else fallback

        has_ms = "ms_axis" in eye_df.columns
        eye_short = "L" if str(eye).lower().startswith("l") else "R"
        rows = []
        for i0, i1 in runs:
            if has_ms:
                s_ms = float(_row_at(i0, sub.iloc[0])["ms_axis"])
                e_ms = float(_row_at(i1, sub.iloc[-1])["ms_axis"])
            else:
                s_ms, e_ms = float(i0), float(i1)

            s_ms, e_ms = _pad_ms_bounds(eye_df, s_ms, e_ms)
            if e_ms <= s_ms:
                continue

            rows.append({
                "animal": str(animal),
                "block": str(block),
                "eye": eye_short,
                "start_ms": s_ms,
                "end_ms": e_ms,
            })

        # Passing the columns keeps the schema stable even when no run survives.
        out = pd.DataFrame(rows, columns=event_cols)
        if sort_by_start and not out.empty:
            out = out.sort_values(["animal", "block", "eye", "start_ms"]).reset_index(drop=True)
        return out

    def _merge_eye_labels(labels: Sequence[str]) -> str:
        """Combine eyes across duplicates; {L,R} -> binocular_label; pass through singletons."""
        uniq = set(s.upper() for s in labels if isinstance(s, str) and len(s) > 0)
        if binocular_label.upper() in uniq:
            return binocular_label  # already merged
        if "L" in uniq and "R" in uniq:
            return binocular_label
        if "L" in uniq:
            return "L"
        if "R" in uniq:
            return "R"
        # unknown labels: several -> binocular, one -> keep it, none -> binocular
        return binocular_label if len(uniq) > 1 else (next(iter(uniq)) if uniq else binocular_label)

    def _standardise(df: pd.DataFrame) -> pd.DataFrame:
        """Copy of df restricted to the event columns (missing ones filled with NaN), typed, with bucket keys."""
        std = df.reindex(columns=event_cols).copy()  # never touches the caller's frame
        for c in ("start_ms", "end_ms"):
            std[c] = pd.to_numeric(std[c], errors="coerce").astype(float)
        for c in ("animal", "block", "eye"):
            std[c] = std[c].astype(str)
        tol = max(1e-9, float(merge_tol_ms))
        std["_k_start"] = np.rint(std["start_ms"] / tol).astype("Int64")
        std["_k_end"] = np.rint(std["end_ms"] / tol).astype("Int64")
        return std

    def _merge_into_existing(new_df: pd.DataFrame, existing: Optional[pd.DataFrame]) -> pd.DataFrame:
        if existing is None or existing.empty:
            return new_df.copy()

        combined = pd.concat([_standardise(existing), _standardise(new_df)], ignore_index=True)

        # One output row per (animal, block, start bucket, end bucket). Iterating the
        # groups explicitly avoids relying on groupby.apply seeing the key columns.
        records = []
        for _, group in combined.groupby(["animal", "block", "_k_start", "_k_end"], dropna=False, sort=True):
            records.append({
                "animal": str(group["animal"].iloc[0]),
                "block": str(group["block"].iloc[0]),
                "eye": _merge_eye_labels(group["eye"].tolist()),
                # median is robust to small jitter between the merged rows
                "start_ms": float(np.median(group["start_ms"].values)),
                "end_ms": float(np.median(group["end_ms"].values)),
            })
        merged = pd.DataFrame(records, columns=event_cols)

        if sort_by_start and not merged.empty:
            merged = merged.sort_values(["animal", "block", "start_ms", "end_ms", "eye"]).reset_index(drop=True)
        return merged

    built = _build_from_query()
    return _merge_into_existing(built, existing_qdf)

# NEW FROM HERE:

# === Outlier reporting, tagging, and Bokeh verification ===
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Sequence, Union
from dataclasses import dataclass
from itertools import chain
from bokeh.plotting import figure, show
from bokeh.layouts import gridplot, column
from bokeh.models import ColumnDataSource, BoxAnnotation, HoverTool
from bokeh.io import output_notebook

# Configure Bokeh to render show() output inline in the notebook session.
output_notebook()  # comment out if you prefer standalone HTML files only


# ------------------------------ configuration dataclass ------------------------------
@dataclass
class EyeColumns:
    """Column-name mapping for the per-eye DataFrames used by the outlier tools.

    Each field holds the name of the DataFrame column that carries the
    corresponding signal. Override a field when the eye DataFrames use a
    different column name.
    """
    ms: str = "ms_axis"            # time-axis column (treated as milliseconds by the functions in this file)
    phi: str = "k_phi"             # angular elevation (deg)
    theta: str = "k_theta"         # angular azimuth (deg)
    pupil: str = "pupil_diameter"  # units: your pipeline's pupil metric
    frame: Optional[str] = None    # optional, not required here


# ------------------------------ robust zscore ------------------------------
def _robust_zscore(x: pd.Series) -> pd.Series:
    """Median/MAD z-score (less sensitive to tails); NaN-safe."""
    x = pd.to_numeric(x, errors="coerce")
    med = np.nanmedian(x.values)
    mad = np.nanmedian(np.abs(x.values - med))
    scale = 1.4826 * (mad if mad > 0 else (np.nanstd(x.values) if np.nanstd(x.values) > 0 else 1.0))
    return (x - med) / scale


# ------------------------------ helpers: boolean runs -> (start_ms, end_ms) ------------------------------
def _boolean_runs_to_events(ms: np.ndarray, mask: np.ndarray, bridge_ms: float) -> List[Tuple[float, float]]:
    """
    Convert a boolean mask over ms-axis into merged [start_ms, end_ms] events, stitching gaps <= bridge_ms.
    """
    ms = ms.astype(float)
    mask = mask.astype(bool)
    if ms.size == 0 or not mask.any():
        return []

    # find initial contiguous runs where mask is True
    idx = np.flatnonzero(mask)
    # split on gaps > 1 in index (contiguity by sampling steps)
    split_points = np.where(np.diff(idx) > 1)[0]
    starts = np.r_[0, split_points + 1]
    ends = np.r_[split_points, len(idx) - 1]

    raw_events = [(ms[idx[s]], ms[idx[e]]) for s, e in zip(starts, ends)]
    if not raw_events:
        return []

    # stitch neighboring events if the gap between them ≤ bridge_ms
    merged = [raw_events[0]]
    for s, e in raw_events[1:]:
        prev_s, prev_e = merged[-1]
        gap = max(0.0, s - prev_e)
        if gap <= float(bridge_ms):
            merged[-1] = (prev_s, e)
        else:
            merged.append((s, e))
    return merged


def _overlap(a: Tuple[float, float], b: Tuple[float, float], tol_ms: float = 0.0) -> bool:
    return not (a[1] < b[0] - tol_ms or b[1] < a[0] - tol_ms)


def _merge_binocular(
    events_L: List[Tuple[float, float, str]],
    events_R: List[Tuple[float, float, str]],
    tol_ms: float = 1.0
) -> List[Tuple[float, float, str, str]]:
    """
    Merge L/R event lists (start,end,source) into possibly binocular LR events.
    Returns list of (start_ms, end_ms, eye_label, source_merged).
    Merging rule: if any L and R event overlap within tol_ms -> one LR event with unioned sources.
    Otherwise keep single-eye events.
    """
    out: List[Tuple[float, float, str, str]] = []
    used_R = np.zeros(len(events_R), dtype=bool)

    for sL, eL, srcL in events_L:
        merged_flag = False
        for j, (sR, eR, srcR) in enumerate(events_R):
            if used_R[j]:
                continue
            if _overlap((sL, eL), (sR, eR), tol_ms=tol_ms):
                s = float(min(sL, sR))
                e = float(max(eL, eR))
                src = f"L[{srcL}] + R[{srcR}]"
                out.append((s, e, "LR", src))
                used_R[j] = True
                merged_flag = True
                break
        if not merged_flag:
            out.append((float(sL), float(eL), "L", f"L[{srcL}]"))

    # append remaining R-only
    for (sR, eR, srcR), used in zip(events_R, used_R):
        if not used:
            out.append((float(sR), float(eR), "R", f"R[{srcR}]"))

    # stable sort by start
    out.sort(key=lambda r: (r[0], r[1]))
    return out


# ------------------------------ (1) threshold sweep report ------------------------------
def outlier_threshold_report(
    block,
    z_abs_list: Sequence[float],
    cols: EyeColumns = EyeColumns(),
    physiol_limits: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None,
    eyes: Sequence[str] = ("L", "R")
) -> pd.DataFrame:
    """
    For each |z| in z_abs_list, report how many samples would be flagged per eye & signal.

    The hard-limit rows use the same rule as tag_outliers_to_events
    (value < min OR value > max), so NaN / non-numeric samples are NOT counted
    as limit violations and the report matches what tagging would flag.

    Args:
        block: object exposing animal_call, block_num and the eye DataFrames
            left_eye_data / right_eye_data (only the requested eyes are read).
        z_abs_list: absolute robust z-score thresholds to sweep.
        cols: column-name mapping for the eye DataFrames.
        physiol_limits: dict with keys in {'phi','theta','pupil'} -> (min, max)
            hard bounds (None disables a side).
        eyes: eye labels to include, each 'L' or 'R'.

    Returns:
        Tidy DataFrame with columns
        ['animal','block','eye','signal','criterion','count','percent'],
        sorted by eye, signal, criterion. Empty (but with these columns) when
        nothing was requested.

    Raises:
        KeyError: for an eye label other than 'L' / 'R'.
        ValueError: if a required column is missing from an eye DataFrame.
    """
    physiol_limits = physiol_limits or {}
    eye_attrs = {"L": "left_eye_data", "R": "right_eye_data"}
    signal_names = ("phi", "theta", "pupil")
    out_columns = ["animal", "block", "eye", "signal", "criterion", "count", "percent"]
    rows = []

    def _row(eye_label: str, name: str, criterion: str, flagged: np.ndarray) -> dict:
        # percent is NaN for an empty eye DataFrame (nothing to average)
        percent = 100 * flagged.mean() if flagged.size else float("nan")
        return {"animal": block.animal_call, "block": str(block.block_num), "eye": eye_label,
                "signal": name, "criterion": criterion,
                "count": int(flagged.sum()), "percent": percent}

    for eye in eyes:
        # fetched lazily so a report for one eye does not require the other eye's data
        df = getattr(block, eye_attrs[eye])
        # ensure columns exist
        for c in (cols.ms, cols.phi, cols.theta, cols.pupil):
            if c not in df.columns:
                raise ValueError(f"Column '{c}' missing for eye {eye}")

        abs_z = {}
        hard = {}
        for name in signal_names:
            column = df[getattr(cols, name)]
            abs_z[name] = np.abs(_robust_zscore(column).values)

            values = pd.to_numeric(column, errors="coerce").values
            lo, hi = physiol_limits.get(name, (None, None))
            too_low = np.zeros(len(values), dtype=bool) if lo is None else (values < float(lo))
            too_high = np.zeros(len(values), dtype=bool) if hi is None else (values > float(hi))
            hard[name] = too_low | too_high  # True where the value violates a limit

        for zthr in z_abs_list:
            for name in signal_names:
                rows.append(_row(eye, name, f"|z|>{zthr}", abs_z[name] > float(zthr)))

        # add hard-limit rows once (criterion label 'limits')
        for name in signal_names:
            rows.append(_row(eye, name, "limits", hard[name]))

    if not rows:
        return pd.DataFrame(columns=out_columns)
    rep = pd.DataFrame(rows, columns=out_columns)
    return rep.sort_values(["eye", "signal", "criterion"]).reset_index(drop=True)


# ------------------------------ (2) tagging with bridging + sources ------------------------------
def tag_outliers_to_events(
    block,
    z_abs_threshold: float,
    cols: EyeColumns = EyeColumns(),
    physiol_limits: Dict[str, Tuple[Optional[float], Optional[float]]] = None,
    bridge_ms: float = 50.0,
    binocular_merge: bool = True,
    binocular_tol_ms: float = 1.0,
    min_duration_ms: float = 0.0,
    source_mode: str = "union",   # 'union' or 'max'
    *,
    phi_limits: Optional[Tuple[Optional[float], Optional[float]]] = None,
    theta_limits: Optional[Tuple[Optional[float], Optional[float]]] = None,
    pupil_limits: Optional[Tuple[Optional[float], Optional[float]]] = None,
) -> pd.DataFrame:
    """
    Tag outlier samples and collapse them into an events table with columns
      ['animal','block','eye','start_ms','end_ms','source'].

    A sample is flagged when, for any of {phi, theta, pupil},
      • its robust |z| exceeds z_abs_threshold, or
      • it lies outside the hard (min, max) limits for that signal.
    Limits come from physiol_limits; phi_limits / theta_limits / pupil_limits,
    when given, replace the matching physiol_limits entry. A None bound
    disables that side, e.g. (-40, None).

    Per eye, flagged samples are turned into intervals (gaps <= bridge_ms are
    stitched) and intervals shorter than min_duration_ms are dropped. With
    binocular_merge the L and R intervals are combined into 'LR' events
    (within binocular_tol_ms); otherwise they are listed side by side.

    'source' names the triggers found inside each event: all of them joined by
    '+' (source_mode other than 'max'), or only the trigger with the most
    flagged samples (source_mode == 'max').

    Raises:
        ValueError: if a required column is missing from an eye DataFrame.
    """
    # Effective hard limits: dict entries first, explicit keyword limits on top.
    limits = dict(physiol_limits or {})
    for key, explicit in (("phi", phi_limits), ("theta", theta_limits), ("pupil", pupil_limits)):
        if explicit is not None:
            limits[key] = explicit

    def _limit_violations(frame: pd.DataFrame, key: str) -> np.ndarray:
        """Boolean array, True where the signal falls outside its hard limits."""
        values = pd.to_numeric(frame[getattr(cols, key)], errors="coerce").values
        lo, hi = limits.get(key, (None, None))
        below = (values < float(lo)) if lo is not None else np.zeros_like(values, dtype=bool)
        above = (values > float(hi)) if hi is not None else np.zeros_like(values, dtype=bool)
        return below | above

    def _events_for_eye(frame: pd.DataFrame) -> List[Tuple[float, float, str]]:
        """Build (start_ms, end_ms, source) events for one eye DataFrame."""
        zscores = [_robust_zscore(frame[name]) for name in (cols.phi, cols.theta, cols.pupil)]
        z_flags = [np.abs(z.values) > float(z_abs_threshold) for z in zscores]

        t_ms = pd.to_numeric(frame[cols.ms], errors="coerce").values.astype(float)

        hard_flags = [_limit_violations(frame, key) for key in ("phi", "theta", "pupil")]

        # Ordered (label, per-sample flags); the order decides ties in 'max' mode.
        triggers = list(zip(
            ("phi_z", "theta_z", "pupil_z", "phi_limit", "theta_limit", "pupil_limit"),
            z_flags + hard_flags,
        ))

        # A sample is flagged when any trigger fires.
        flagged = triggers[0][1]
        for _, flags in triggers[1:]:
            flagged = flagged | flags

        tagged: List[Tuple[float, float, str]] = []
        for t0, t1 in _boolean_runs_to_events(t_ms, flagged, bridge_ms=bridge_ms):
            inside = (t_ms >= t0) & (t_ms <= t1)
            hits = [(label, int(flags[inside].sum())) for label, flags in triggers]
            present = [pair for pair in hits if pair[1] > 0]

            # Too-short (or NaN-length) events are discarded.
            if not float(t1 - t0) >= float(min_duration_ms):
                continue

            if not present:
                src = "unknown"
            elif source_mode == "max":
                # dominant trigger by sample count; first listed wins a tie
                src = max(present, key=lambda pair: pair[1])[0]
            else:
                src = "+".join(sorted({label for label, _ in present}))
            tagged.append((float(t0), float(t1), src))
        return tagged

    # Both eye DataFrames must carry every mapped column before any work starts.
    left_df = getattr(block, "left_eye_data")
    right_df = getattr(block, "right_eye_data")
    for frame in (left_df, right_df):
        for name in (cols.ms, cols.phi, cols.theta, cols.pupil):
            if name not in frame.columns:
                raise ValueError(f"Required column '{name}' missing in eye dataframe.")

    left_events = _events_for_eye(left_df)
    right_events = _events_for_eye(right_df)

    if binocular_merge:
        # overlapping L/R events collapse into 'LR'
        rows = _merge_binocular(left_events, right_events, tol_ms=binocular_tol_ms)
    else:
        rows = [(s, e, "L", f"L[{src}]") for s, e, src in left_events]
        rows.extend((s, e, "R", f"R[{src}]") for s, e, src in right_events)
        rows.sort(key=lambda r: r[:2])

    table = pd.DataFrame({
        "animal": str(block.animal_call),
        "block": str(block.block_num),
        "eye": [r[2] for r in rows],
        "start_ms": [r[0] for r in rows],
        "end_ms": [r[1] for r in rows],
        "source": [r[3] for r in rows],
    })
    return table[["animal", "block", "eye", "start_ms", "end_ms", "source"]]


# ------------------------------ (3) Bokeh verification plotter ------------------------------
def bokeh_verify_outliers(
    block,
    events_df: pd.DataFrame,
    cols: EyeColumns = EyeColumns(),
    title: str = "Eye signals with tagged intervals",
    export_html: Optional[str] = None  # path to write a standalone HTML (optional)
):
    """
    Show an interactive Bokeh view for checking tagged intervals against the raw signals.

      - Six stacked panels: left phi, theta, pupil, then right phi, theta, pupil,
        all sharing one x-range (ms) so pan/zoom is linked.
      - Each row of events_df is drawn as a shaded box: 'L' (red) on the left-eye
        panels, 'R' (blue) on the right-eye panels, anything else (e.g. 'LR',
        purple) on all panels. Rows with a NaN start or end are skipped.

    Args:
        block: object exposing left_eye_data / right_eye_data DataFrames.
        events_df: events with columns 'start_ms', 'end_ms', 'eye'.
        cols: column-name mapping for the eye DataFrames.
        title: title prefix of the first panel and title of the exported page.
        export_html: if truthy, the layout is also written once to this path as
            standalone HTML. Bokeh's global output settings are left untouched,
            so later show() calls are not redirected to that file.

    Raises:
        ValueError: if a required column is missing from an eye DataFrame.
    """
    required = (cols.ms, cols.phi, cols.theta, cols.pupil)
    eye_frames = (
        ("left", getattr(block, "left_eye_data")),
        ("right", getattr(block, "right_eye_data")),
    )
    for eye_name, df in eye_frames:
        for c in required:
            if c not in df.columns:
                raise ValueError(f"Column '{c}' missing in {eye_name} eye data (needed for verification)")

    def _source(df: pd.DataFrame) -> ColumnDataSource:
        return ColumnDataSource(dict(
            ms=df[cols.ms].astype(float),
            phi=pd.to_numeric(df[cols.phi], errors="coerce"),
            theta=pd.to_numeric(df[cols.theta], errors="coerce"),
            pupil=pd.to_numeric(df[cols.pupil], errors="coerce"),
        ))

    sources = [_source(df) for _, df in eye_frames]

    # (source index, data key, panel title, label the x axis?) in display order
    panel_specs = (
        (0, "phi", f"{title} – Left φ", True),
        (0, "theta", "Left θ", False),
        (0, "pupil", "Left pupil", True),
        (1, "phi", "Right φ", False),
        (1, "theta", "Right θ", False),
        (1, "pupil", "Right pupil", True),
    )
    figs = []
    for src_idx, key, panel_title, label_x in panel_specs:
        fig_kwargs = dict(title=panel_title, width=1200, height=180,
                          tools="pan,wheel_zoom,box_zoom,reset,save")
        if figs:
            # every later panel reuses the first panel's x-range -> linked pan/zoom
            fig_kwargs["x_range"] = figs[0].x_range
        if label_x:
            fig_kwargs["x_axis_label"] = "time (ms)"
        p = figure(**fig_kwargs)
        p.add_tools(HoverTool(tooltips=[("t (ms)", "@ms")], mode="vline"))
        p.line("ms", key, source=sources[src_idx])
        figs.append(p)

    # eye label -> (target panels, fill alpha, fill colour); unknown labels are treated as LR
    span_styles = {
        "L": (figs[:3], 0.18, "red"),
        "R": (figs[3:], 0.18, "blue"),
    }
    both_eyes_style = (figs, 0.12, "purple")

    for _, row in events_df.iterrows():
        start_ms = float(row["start_ms"])
        end_ms = float(row["end_ms"])
        if np.isnan(start_ms) or np.isnan(end_ms):
            continue  # an interval without defined bounds cannot be drawn meaningfully
        targets, alpha, color = span_styles.get(str(row["eye"]).upper(), both_eyes_style)
        box = BoxAnnotation(left=start_ms, right=end_ms, fill_alpha=alpha, fill_color=color)
        for p in targets:
            p.add_layout(box)

    layout = gridplot([[p] for p in figs], toolbar_location="above")
    if export_html:
        # One-shot export: output_file() would permanently switch Bokeh's global
        # output state and make every later show() overwrite this file.
        from bokeh.io import save
        from bokeh.resources import CDN
        save(layout, filename=str(export_html), title=title, resources=CDN)
    show(layout)


In [None]:
# === Example usage with a single BlockSync object ===

# 0) Column mapping: adjust the names if this block's eye DataFrames differ.
cols = EyeColumns(ms="ms_axis", phi="k_phi", theta="k_theta", pupil="pupil_diameter")

# 1) Sweep report: preview the flag counts for several absolute robust z-scores
#    to help pick a threshold.
z_list = [2.0, 2.5, 3.0, 3.5, 4.0, 4.5]

# Physiology-based hard bounds as (min, max) in deg / diameter units;
# use None to disable a side.
phys_limits = dict(
    phi=(-30.0, 30.0),
    theta=(-50.0, 25.0),
    pupil=(1.5, 2.5),
)

report_df = outlier_threshold_report(
    block,
    z_abs_list=z_list,
    cols=cols,
    physiol_limits=phys_limits,
)
display(report_df.head(20))

In [None]:
# 2) Tag events with the chosen parameters.
events_df = tag_outliers_to_events(
    block,
    # robust |z| cut-off; pick it based on the sweep report
    z_abs_threshold=4.5,
    cols=cols,
    physiol_limits=phys_limits,
    # flagged runs separated by gaps <= 50 ms are treated as one event
    bridge_ms=50.0,
    # collapse overlapping L/R events into LR, with this overlap tolerance
    binocular_merge=True,
    binocular_tol_ms=1.0,
    # raise above 0 to drop micro events
    min_duration_ms=0.0,
    # 'union' (collect all causes) or 'max' (dominant cause)
    source_mode="union",
    # explicit limits take precedence over the 'phi' / 'theta' entries of phys_limits
    phi_limits=(-40, 40),
    theta_limits=(-45, 45),
)

In [None]:
# Preview the first rows of the tagged events table
# (columns: animal, block, eye, start_ms, end_ms, source).
print(events_df.head())

In [None]:
# Total number of tagged events.
print(len(events_df))

In [None]:
# 3) Visual verification: interactive Bokeh view in which shaded spans mark
#    the L / R / LR tags; the same view is also exported as HTML.
bokeh_verify_outliers(
    block,
    events_df,
    cols=cols,
    title=f"{block.animal_call} B{block.block_num} verification",
    export_html=block.analysis_path / 'outlier_tags_verifier.html',
)

In [None]:
# 4) (optional) Launch the existing manual reviewer for the final GOOD/BAD triage.
# review_events_multi_with_arena_v2 and block_dict are not defined in this section;
# they must already exist in the same kernel/session.
reviewed_df = review_events_multi_with_arena_v2(block_dict, events_df)
# NOTE(review): reviewed_df is expected to include the tri-state 'manual_outlier_detected'
# column + timestamps, and the reviewer is expected to write a per-block CSV on export;
# confirm against the reviewer's own documentation.


In [None]:
# Display the reviewed events table as the cell output.
reviewed_df

In [None]:
# THIS IS WHERE WE MAKE CLEAN EYE_DFs
import numpy as np
import pandas as pd
from typing import Optional, Tuple, Dict

# assumes EyeColumns dataclass, TAG_COL, load_block_annotations() exist in scope

def _cleanup_flag_is_bad(value) -> bool:
    """Interpret one annotation cell as a BAD flag; anything unrecognised counts as False."""
    if isinstance(value, (bool, np.bool_)):
        return bool(value)
    if isinstance(value, (int, float, np.integer, np.floating)):
        # Numeric tag columns (e.g. 1.0 / 0.0 / NaN): only 1 means BAD.
        return bool(value == 1)
    return str(value).strip().lower() in {"true", "t", "1", "yes", "y"}


def _cleanup_publish(block, left_clean: pd.DataFrame, right_clean: pd.DataFrame) -> Dict[str, pd.DataFrame]:
    """Attach the cleaned eye DataFrames to *block* and build the return dict."""
    setattr(block, "left_eye_data_clean", left_clean)
    setattr(block, "right_eye_data_clean", right_clean)
    # optional compatibility: mirror under *_eye_df_clean if the block also uses *_eye_df names
    if hasattr(block, "left_eye_df"):
        setattr(block, "left_eye_df_clean", left_clean.copy())
    if hasattr(block, "right_eye_df"):
        setattr(block, "right_eye_df_clean", right_clean.copy())
    return {"left_eye_data_clean": left_clean, "right_eye_data_clean": right_clean}


def _cleanup_passthrough(block, add_mask_columns: bool) -> Dict[str, pd.DataFrame]:
    """Publish unmodified copies of the eye data (used when there is nothing to wipe)."""
    left_clean = getattr(block, "left_eye_data").copy()
    right_clean = getattr(block, "right_eye_data").copy()
    if add_mask_columns:
        left_clean["clean_badmask"] = False
        right_clean["clean_badmask"] = False
    return _cleanup_publish(block, left_clean, right_clean)


def apply_manual_outlier_cleanup(
    block,
    annotations_df: Optional[pd.DataFrame] = None,
    *,
    cols: EyeColumns = EyeColumns(),
    bad_col: str = TAG_COL,            # "manual_outlier_detected"
    value_cols: Tuple[str, str, str] = ("k_theta", "k_phi", "pupil_diameter"),
    add_mask_columns: bool = True,     # add boolean mask columns to the *clean* dfs for audit
) -> Dict[str, pd.DataFrame]:
    """
    Create cleaned eye-data copies on the block by setting 'bad' intervals to NaN in value_cols.
    'Bad' is defined by the per-interval tag bad_col being True.

    Inputs
    ------
    block : BlockSync-like object with:
        - animal_call, block_num
        - left_eye_data, right_eye_data (DataFrames with cols.ms present)
    annotations_df : optional DataFrame of events with columns:
        ['animal','block','eye','start_ms','end_ms', bad_col]
        If None, loads annotations via load_block_annotations(block).
        Only rows whose animal/block match this block (compared as strings) are used.
    cols : EyeColumns
        Only cols.ms (the time column of the eye DataFrames) is used here.
    bad_col : str
        Column in annotations_df marking BAD intervals. Accepted as True:
        booleans, the numbers 1 / 1.0, and the strings true/t/1/yes/y
        (case-insensitive). Everything else, including NaN, is treated as False.
    value_cols : tuple
        The columns that will be set to NaN inside BAD intervals.
    add_mask_columns : bool
        If True, adds a boolean 'clean_badmask' column to each clean df for traceability.

    Interval handling
    -----------------
    Intervals are inclusive on both ends; a misordered (end < start) interval is swapped.
    'eye' is matched case-insensitively: 'L' wipes the left eye, 'R' the right eye,
    'LR' (or a missing eye value) both. Any other label wipes nothing.

    Side effects
    -----------
    Sets on `block`:
        - block.left_eye_data_clean  (pd.DataFrame)
        - block.right_eye_data_clean (pd.DataFrame)
      If block has `left_eye_df`/`right_eye_df`, also sets:
        - block.left_eye_df_clean / block.right_eye_df_clean
    The original left_eye_data / right_eye_data are not modified.

    Returns
    -------
    dict with keys: {'left_eye_data_clean','right_eye_data_clean'}

    Raises
    ------
    ValueError
        If annotations_df lacks a required column, or (when there is something
        to wipe) an eye DataFrame lacks the time column or one of value_cols.
    """
    # --- resolve annotations ---
    if annotations_df is None:
        annotations_df = load_block_annotations(block)
    if annotations_df is None or annotations_df.empty:
        return _cleanup_passthrough(block, add_mask_columns)

    # --- validate expected columns (tuple keeps the error message order stable) ---
    required = ("animal", "block", "eye", "start_ms", "end_ms", bad_col)
    missing = [c for c in required if c not in annotations_df.columns]
    if missing:
        raise ValueError(f"annotations_df missing required columns: {missing}")

    # --- filter to this block only ---
    same_block = (
        (annotations_df["animal"].astype(str) == str(block.animal_call))
        & (annotations_df["block"].astype(str) == str(block.block_num))
    )
    ann = annotations_df[same_block]
    if ann.empty:
        return _cleanup_passthrough(block, add_mask_columns)

    # --- keep only rows explicitly tagged BAD ---
    is_bad = ann[bad_col].map(_cleanup_flag_is_bad).astype(bool)
    ann_bad = ann[is_bad]
    if ann_bad.empty:
        return _cleanup_passthrough(block, add_mask_columns)

    # --- pull eye data, sanity checks ---
    L = getattr(block, "left_eye_data")
    R = getattr(block, "right_eye_data")
    for df_name, df in (("left_eye_data", L), ("right_eye_data", R)):
        if cols.ms not in df.columns:
            raise ValueError(f"{df_name} missing time column '{cols.ms}'")
        for vc in value_cols:
            if vc not in df.columns:
                raise ValueError(f"{df_name} missing value column '{vc}'")

    # --- build per-sample masks from the BAD intervals ---
    msL = pd.to_numeric(L[cols.ms], errors="coerce").values.astype(float)
    msR = pd.to_numeric(R[cols.ms], errors="coerce").values.astype(float)
    maskL = np.zeros(msL.shape, dtype=bool)
    maskR = np.zeros(msR.shape, dtype=bool)

    for start, end, eye_raw in zip(ann_bad["start_ms"], ann_bad["end_ms"], ann_bad["eye"]):
        s = float(start)
        e = float(end)
        if e < s:
            s, e = e, s  # swap if misordered
        # a missing eye label is treated as both eyes
        eye = str(eye_raw).upper().strip() if pd.notna(eye_raw) else "LR"

        if eye in ("L", "LR"):
            maskL |= (msL >= s) & (msL <= e)
        if eye in ("R", "LR"):
            maskR |= (msR >= s) & (msR <= e)

    # --- create cleaned copies and set NaNs in requested columns ---
    L_clean = L.copy()
    R_clean = R.copy()
    for vc in value_cols:
        L_clean.loc[maskL, vc] = np.nan
        R_clean.loc[maskR, vc] = np.nan

    if add_mask_columns:
        L_clean["clean_badmask"] = maskL
        R_clean["clean_badmask"] = maskR

    # --- attach back to block and return ---
    return _cleanup_publish(block, L_clean, R_clean)


In [None]:
# Wipe the manually confirmed BAD intervals; the cleaned copies are attached to
# `block` as left_eye_data_clean / right_eye_data_clean (and returned as a dict).
apply_manual_outlier_cleanup(block,reviewed_df)