In [1]:

import datetime
import numpy as np
import cv2
from itertools import cycle
import pickle
import pathlib
import math
import tqdm
import scipy.io
from matplotlib import pyplot as plt
import scipy.io
import h5py
import re
from lxml import etree as ET
import scipy.signal as sig
import pandas as pd
from scipy.stats import kde
from BlockSync_current import BlockSync
import UtilityFunctions_newOE as uf
from scipy import signal
import bokeh
import seaborn as sns
from matplotlib import rcParams

# Font type 42 (TrueType) for PDF/PS output, instead of matplotlib's default Type 3 glyphs.
rcParams['pdf.fonttype'] = 42  # Ensure fonts are embedded and editable
rcParams['ps.fonttype'] = 42  # Ensure compatibility with vector outputs
# IPython/Jupyter line magic (render figures inline); not valid syntax in a plain .py script.
%matplotlib inline

def bokeh_plotter(data_list, x_axis_list=None, label_list=None,
                  plot_name='default',
                  x_axis_label='X', y_axis_label='Y',
                  peaks=None, peaks_list=False, export_path=False):
    """Generate an interactive Bokeh line plot for one or more data vectors.

    Args:
        data_list (list of array-like): The data vectors to plot, one line each.
        x_axis_list (list of array-like, optional): Per-vector x values; must have the same
            length as ``data_list``. Defaults to sample indices (0, 1, 2, ...).
        label_list (list of str, optional): Legend label per vector; must have the same length
            as ``data_list``. Defaults to "Line 1", "Line 2", ...
        plot_name (str, optional): Title suffix and exported file stem. Defaults to 'default'.
        x_axis_label (str, optional): The label for the x-axis. Defaults to 'X'.
        y_axis_label (str, optional): The label for the y-axis. Defaults to 'Y'.
        peaks (array-like or list of array-like, optional): POSITIONAL indices of samples to
            highlight. Defaults to None.
        peaks_list (bool, optional): If True, ``peaks[i]`` holds the peaks of ``data_list[i]`` and
            each set is drawn in its line's colour. If False, ``peaks`` indexes the last vector in
            ``data_list`` and is drawn in red.
        export_path (False, str or pathlib.Path): When set, the figure is also written to
            ``<export_path>/<plot_name>.html``.

    Raises:
        ValueError: If ``x_axis_list`` or ``label_list`` is given with a length different from
            ``data_list``.
    """
    # A bare `import bokeh` does not load these submodules, so import them explicitly.
    from bokeh.palettes import Category10_10
    from bokeh.plotting import figure, output_file, show

    if x_axis_list is not None and len(x_axis_list) != len(data_list):
        raise ValueError(
            'problem with x_axis_list input - should be either None, or a list with the same length as data_list')
    if label_list is not None and len(label_list) != len(data_list):
        raise ValueError(
            'problem with label_list input - should be either None, or a list with the same length as data_list')
    if x_axis_list is not None:
        print('x_axis manually set')

    color_cycle = cycle(Category10_10)
    # `width`/`height` work in Bokeh 2.x and are the only spelling in 3.x (plot_width was removed).
    fig = figure(title=f'bokeh explorer: {plot_name}',
                 x_axis_label=x_axis_label,
                 y_axis_label=y_axis_label,
                 width=1500,
                 height=700)

    x_values = y_values = None
    for i, data_vector in enumerate(data_list):
        color = next(color_cycle)
        # Plain arrays so peak indexing is positional even for pandas Series with a custom index.
        y_values = np.asarray(data_vector)
        x_values = np.arange(len(y_values)) if x_axis_list is None else np.asarray(x_axis_list[i])
        label = f"Line {i + 1}" if label_list is None else f"{label_list[i]}"
        # The same x values are used whether or not labels are supplied.
        fig.line(x_values, y_values, line_color=color, legend_label=label)
        if peaks is not None and peaks_list is True:
            peak_inds = np.asarray(peaks[i], dtype=int)
            fig.circle(x_values[peak_inds], y_values[peak_inds], size=10, color=color)

    # A single peak set refers to the last plotted vector (the only one in the common one-vector call).
    if peaks is not None and peaks_list is False and y_values is not None:
        peak_inds = np.asarray(peaks, dtype=int)
        fig.circle(x_values[peak_inds], y_values[peak_inds], size=10, color='red')

    if export_path is not False and export_path is not None:
        export_file = pathlib.Path(export_path) / f'{plot_name}.html'
        print(f'exporting to {export_path}')
        output_file(filename=str(export_file), title=f'{plot_name}')
    show(fig)


def load_eye_data_2d_w_rotation_matrix(block):
    """
    Load the saved eye dataframes and, when available, the eye-rotation parameters onto ``block``.

    ``left_eye_data.csv`` and ``right_eye_data.csv`` in ``block.analysis_path`` are read (first
    column used as the index) into ``block.left_eye_data`` / ``block.right_eye_data``. If either
    file is missing, a message is printed and nothing else is loaded; the left frame may already
    be set when only the right file is missing.

    ``rotate_eye_data_params.pkl`` is then unpickled. Its 'left_rotation_matrix',
    'right_rotation_matrix', 'left_rotation_angle' and 'right_rotation_angle' entries are copied
    to attributes of the same names. A missing pickle only prints a message.

    :param block: The current BlockSync object; its ``analysis_path`` locates the files
    :return: None
    """
    analysis_dir = block.analysis_path

    try:
        for attr_name, csv_name in (('left_eye_data', 'left_eye_data.csv'),
                                    ('right_eye_data', 'right_eye_data.csv')):
            setattr(block, attr_name,
                    pd.read_csv(analysis_dir / csv_name, index_col=0, engine='python'))
    except FileNotFoundError:
        print('eye_data files not found, run the pipeline!')
        return

    try:
        with open(analysis_dir / 'rotate_eye_data_params.pkl', 'rb') as params_file:
            # NOTE: unpickling can execute code; only load files produced by this pipeline.
            rotation_params = pickle.load(params_file)
    except FileNotFoundError:
        print('No rotation matrix file, create it')
        return

    for key in ('left_rotation_matrix', 'right_rotation_matrix',
                'left_rotation_angle', 'right_rotation_angle'):
        setattr(block, key, rotation_params[key])


def create_saccade_events_df(eye_data_df, speed_threshold, bokeh_verify_threshold=False, magnitude_calib=1,
                             speed_profile=True):
    """
    Detect saccade events in eye-tracking data and compute per-saccade metrics.

    Parameters:
    - eye_data_df (pd.DataFrame): Eye-tracking samples with columns 'center_x', 'center_y',
      'ms_axis', 'OE_timestamp' and 'pupil_diameter'. Rows are treated as consecutive samples
      in positional order. The frame is NOT modified.
    - speed_threshold (float): A sample is saccadic when its per-sample displacement 'speed_r'
      is strictly greater than this value.
    - bokeh_verify_threshold (bool): If True, plot 'speed_r' with the detected onsets (bokeh_plotter).
    - magnitude_calib (float): Multiplier applied to 'magnitude', 'calib_dx' and 'calib_dy'.
    - speed_profile (bool): If True, add a 'speed_profile' column with each saccade's 'speed_r' samples.

    Returns:
    - df (pd.DataFrame): A copy of eye_data_df with added 'speed_x', 'speed_y' (per-sample
      differences of center_x / center_y) and 'speed_r' (their Euclidean norm).
    - saccade_events_df (pd.DataFrame): One row per completed saccade, with columns:
        - 'saccade_start_ind' / 'saccade_end_ind': positional indices of onset / offset. The onset
          is one sample before the threshold crossing; the offset is the first sample back under threshold.
        - 'saccade_start_timestamp' / 'saccade_end_timestamp': 'OE_timestamp' at onset / offset.
        - 'saccade_on_ms' / 'saccade_off_ms': 'ms_axis' at onset / offset.
        - 'length': end index minus start index (in samples).
        - 'magnitude_raw': NaN-skipping sum of 'speed_r' over [start, end]; 'magnitude' = raw * magnitude_calib.
        - 'angle': direction from start to end position in degrees, in [0, 360) (NaN if undefined).
        - 'initial_x', 'initial_y', 'end_x', 'end_y': positions at onset / offset.
        - 'calib_dx', 'calib_dy': (end - initial) * magnitude_calib.
        - 'speed_profile' (optional), 'diameter_profile': 'speed_r' / 'pupil_diameter' samples over [start, end].

    A saccade that is still above threshold at the last sample has no offset and is not reported.
    """
    df = eye_data_df.copy()  # honour the contract above: never mutate the caller's frame
    df['speed_x'] = df['center_x'].diff()  # Difference between consecutive 'center_x' values
    df['speed_y'] = df['center_y'].diff()  # Difference between consecutive 'center_y' values
    df['speed_r'] = (df['speed_x'] ** 2 + df['speed_y'] ** 2) ** 0.5

    # Rising (+1) / falling (-1) edges of the binary saccade mask. The first speed_r is NaN and
    # NaN > threshold is False, so edges always alternate starting with a rising edge.
    is_saccade = (df['speed_r'] > speed_threshold).to_numpy().astype(int)
    edges = np.diff(is_saccade, prepend=0)
    rising = np.flatnonzero(edges == 1)
    falling = np.flatnonzero(edges == -1)
    # Drop a trailing saccade with no offset (otherwise onset/offset arrays differ in length).
    rising = rising[:falling.size]

    # Manual shift: include the first (sometimes slower) eye frame just before the threshold
    # crossing. Clamp so the shift can never wrap around to the last row.
    saccade_on_inds = np.maximum(rising - 1, 0)
    saccade_off_inds = falling

    saccade_events_df = pd.DataFrame.from_dict({
        'saccade_start_ind': saccade_on_inds,
        'saccade_start_timestamp': df['OE_timestamp'].iloc[saccade_on_inds].values,
        'saccade_end_ind': saccade_off_inds,
        'saccade_end_timestamp': df['OE_timestamp'].iloc[saccade_off_inds].values,
        'saccade_on_ms': df['ms_axis'].iloc[saccade_on_inds].values,
        'saccade_off_ms': df['ms_axis'].iloc[saccade_off_inds].values,
    })
    saccade_events_df['length'] = saccade_events_df['saccade_end_ind'] - saccade_events_df['saccade_start_ind']

    # Saccade samples are selected positionally (inclusive [start, end]) rather than by scanning
    # OE_timestamp ranges. This is O(saccade length) per event and robust to duplicate/NaN timestamps.
    speed = df['speed_r'].to_numpy()
    diameter = df['pupil_diameter'].to_numpy()
    center_x = df['center_x'].to_numpy()
    center_y = df['center_y'].to_numpy()
    windows = [slice(start, end + 1) for start, end in zip(saccade_on_inds, saccade_off_inds)]

    # NOTE(review): speed_r at the start sample is the displacement from the PRECEDING frame, so the
    # magnitude includes one step before (initial_x, initial_y). Kept to preserve existing results.
    distances = np.array([np.nansum(speed[w]) for w in windows], dtype=float)

    start_x, start_y = center_x[saccade_on_inds], center_y[saccade_on_inds]
    end_x, end_y = center_x[saccade_off_inds], center_y[saccade_off_inds]
    angles = np.arctan2(end_y - start_y, end_x - start_x)

    saccade_events_df['magnitude_raw'] = distances
    saccade_events_df['magnitude'] = distances * magnitude_calib
    # Convert radians to degrees and ensure the result is in [0, 360); NaN stays NaN
    saccade_events_df['angle'] = np.where(np.isnan(angles), angles, np.rad2deg(angles) % 360)
    saccade_events_df['initial_x'] = start_x
    saccade_events_df['initial_y'] = start_y
    saccade_events_df['end_x'] = end_x
    saccade_events_df['end_y'] = end_y
    saccade_events_df['calib_dx'] = (end_x - start_x) * magnitude_calib
    saccade_events_df['calib_dy'] = (end_y - start_y) * magnitude_calib
    if speed_profile:
        saccade_events_df['speed_profile'] = [speed[w].copy() for w in windows]
    saccade_events_df['diameter_profile'] = [diameter[w].copy() for w in windows]

    if bokeh_verify_threshold:
        bokeh_plotter(data_list=[df.speed_r], label_list=['Pupil Velocity'], peaks=saccade_on_inds)

    return df, saccade_events_df


# create a multi-animal block_collection:

def create_block_collections(animals, block_lists, experiment_path, bad_blocks=None):
    """
    Create block collections and a block dictionary from multiple animals and their respective block lists.

    Parameters:
    - animals: list of str, names of the animals.
    - block_lists: list of lists of int; ``block_lists[i]`` holds the block numbers for ``animals[i]``.
    - experiment_path: pathlib.Path, path to the experiment directory.
    - bad_blocks: list of int, blocks to exclude (forwarded to ``uf.block_generator``).
      Defaults to None, which is treated as an empty list.

    Returns:
    - block_collection: list of the block objects returned by ``uf.block_generator``, all animals
      concatenated in input order.
    - block_dict: dictionary keyed by ``f"{animal}_block_{block.block_num}"`` whose values are the
      same block objects.

    Raises:
    - ValueError: if ``animals`` and ``block_lists`` have different lengths. Otherwise ``zip``
      would silently drop the unmatched entries.
    """
    animals = list(animals)
    block_lists = list(block_lists)
    if len(animals) != len(block_lists):
        raise ValueError(
            f'animals ({len(animals)}) and block_lists ({len(block_lists)}) must have the same length')

    import UtilityFunctions_newOE as uf

    if bad_blocks is None:
        bad_blocks = []

    block_collection = []
    block_dict = {}

    for animal, blocks in zip(animals, block_lists):
        # Generate blocks for the current animal
        current_blocks = uf.block_generator(
            block_numbers=blocks,
            experiment_path=experiment_path,
            animal=animal,
            bad_blocks=bad_blocks
        )
        # Add to collection and dictionary
        block_collection.extend(current_blocks)
        for b in current_blocks:
            block_dict[f"{animal}_block_{b.block_num}"] = b

    return block_collection, block_dict


# --- Block selection -----------------------------------------------------------------------
# block_lists[i] holds the block numbers to load for animals[i].
animals = ['PV_126']
block_lists = [[7]]
experiment_path = pathlib.Path(r"Z:\Nimrod\experiments")
bad_blocks = [0]  # Example of bad blocks

block_collection, block_dict = create_block_collections(
    animals=animals,
    block_lists=block_lists,
    experiment_path=experiment_path,
    bad_blocks=bad_blocks
)
# NOTE(review): ref_points is not used anywhere in this section - confirm a later cell needs it.
ref_points = pd.read_csv(r'Z:\Nimrod\experiments\cross_animals_data\kerr_reference_all_animals.csv')
# Per-block preprocessing: the BlockSync methods below are called in this order for every block.
# The loop variable `block` is still referenced after the loop by the playback cell, which
# therefore operates on the LAST block in block_collection.
for block in block_collection:
    block.parse_open_ephys_events()
    block.get_eye_brightness_vectors()
    block.synchronize_block()
    block.create_eye_brightness_df(threshold_value=20)

    # if the code fails here, go to manual synchronization
    block.import_manual_sync_df()
    block.read_dlc_data()
    block.calibrate_pixel_size(10)
    #load_eye_data_2d_w_rotation_matrix(block)  #should be integrated again... later

    # NOTE(review): unlike load_eye_data_2d_w_rotation_matrix (index_col=0), these reads pass no
    # index_col, so a saved index column would come back as an extra column - confirm intended.
    block.left_eye_data = pd.read_csv(block.analysis_path / 'left_eye_data_original.csv')
    block.right_eye_data = pd.read_csv(block.analysis_path / 'right_eye_data_original.csv')



instantiated block number 007 at Path: Z:\Nimrod\experiments\PV_126\2024_07_18\block_007, new OE version
Found the sample rate for block 007 in the xml file, it is 20000 Hz
created the .oe_rec attribute as an open ephys recording obj with get_data functionality
retrieving zertoh sample number for block 007
got it!
running parse_open_ephys_events...
block 007 has a parsed events file, reading...
Getting eye brightness values for block 007...
Found an existing file!
Eye brightness vectors generation complete.
blocksync_df loaded from analysis folder
eye_brightness_df loaded from analysis folder
eye dataframes loaded from analysis folder
got the calibration values from the analysis folder


In [3]:
# Synchronized playback based on block.final_sync_df  
import cv2
import numpy as np
import pandas as pd
import time


def _read_synced_frame(stream, target):
    """Return frame number ``target`` from ``stream``, or None if ``target`` is NaN or unreadable.

    ``stream`` is a dict with keys 'cap' (cv2.VideoCapture), 'last' (frame number last read, or
    None) and 'img' (the image read for 'last').
    - A repeated frame number reuses the cached image.
    - The immediately following frame is read sequentially (cheap).
    - Any other frame number triggers a seek.
    This keeps every stream aligned to the sync table even when one video runs at a different rate.
    """
    if pd.isna(target):
        return None
    target = int(target)
    if target == stream['last']:
        return stream['img']
    cap = stream['cap']
    if stream['last'] is None or target != stream['last'] + 1:
        cap.set(cv2.CAP_PROP_POS_FRAMES, target)
    ok, img = cap.read()
    if not ok:
        # Read position is now uncertain: forget it so the next request seeks explicitly.
        stream['last'], stream['img'] = None, None
        return None
    stream['last'], stream['img'] = target, img
    return img


def _render_view(image, canvas, flip=False):
    """Letterbox ``image`` into ``canvas`` (aspect ratio kept), or return a placeholder if None."""
    if image is None:
        display = np.zeros(canvas.shape, dtype=np.uint8)
        cv2.putText(display, "No synchronized frame", (50, canvas.shape[0] // 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        return display
    if flip:
        image = cv2.flip(image, -1)  # Flip vertically and horizontally
    canvas_h, canvas_w = canvas.shape[:2]
    # Scale from the actual frame size, so each video (including a differently sized right eye) fits.
    scale = min(canvas_w / image.shape[1], canvas_h / image.shape[0])
    new_w, new_h = int(image.shape[1] * scale), int(image.shape[0] * scale)
    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    canvas.fill(0)
    top, left = (canvas_h - new_h) // 2, (canvas_w - new_w) // 2
    canvas[top:top + new_h, left:left + new_w] = resized
    return canvas


def display_synchronized_frames(sync_df, arena_vid, left_vid, right_vid):
    """
    Display synchronized frames from three video sources with manual stepping and rolling playback.

    Each row of ``sync_df`` names one frame per video. A NaN entry means that video has no frame
    for the row, and a placeholder is shown instead. Both eye views are flipped on both axes.

    Keys while paused: q quit | n / m step back / forward 1 row | c / v back / forward 30 rows |
    f forward 120 rows | p start rolling playback.
    Keys while rolling: p pause | o toggle the ~1/60 s per-row delay | q quit.

    Parameters:
        sync_df (pd.DataFrame): A dataframe with columns ['Arena_frame', 'L_eye_frame', 'R_eye_frame'].
        arena_vid (str or path-like): Path to the arena video file.
        left_vid (str or path-like): Path to the left eye video file.
        right_vid (str or path-like): Path to the right eye video file.
    """
    if len(sync_df) == 0:
        print("Error: sync_df is empty, nothing to display.")
        return

    window_names = ("Arena View", "Left Eye View", "Right Eye View")
    frame_columns = ('Arena_frame', 'L_eye_frame', 'R_eye_frame')
    flips = (False, True, True)
    window_size = (640, 480)  # Default window size for each video
    # Keys that move the paused cursor, mapped to their row offset.
    step_keys = {ord('n'): -1, ord('c'): -30, ord('m'): 1, ord('v'): 30, ord('f'): 120}

    captures = [cv2.VideoCapture(str(path)) for path in (arena_vid, left_vid, right_vid)]
    windows_open = False
    try:
        if not all(cap.isOpened() for cap in captures):
            print("Error: Could not open one or more video files.")
            return

        for name in window_names:
            cv2.namedWindow(name, cv2.WINDOW_NORMAL)
            cv2.resizeWindow(name, window_size[0], window_size[1])
        windows_open = True

        canvases = [np.zeros((window_size[1], window_size[0], 3), dtype=np.uint8) for _ in window_names]
        streams = [{'cap': cap, 'last': None, 'img': None} for cap in captures]

        current_idx = 0
        max_idx = len(sync_df)
        rolling_playback = False
        throttled = True  # while True, rolling playback sleeps 1/60 s per row

        while True:
            row = sync_df.iloc[current_idx]
            frame_numbers = [row[col] for col in frame_columns]
            for name, stream, canvas, flip, frame_no in zip(window_names, streams, canvases, flips, frame_numbers):
                image = _read_synced_frame(stream, frame_no)
                cv2.imshow(name, _render_view(image, canvas, flip=flip))

            # The printed index is the row whose frames are currently displayed.
            arena_frame, left_frame, right_frame = frame_numbers
            print(
                f"\rFrame Info: Arena: {arena_frame}, Left Eye: {left_frame}, Right Eye: {right_frame} (Index: {current_idx})",
                end='', flush=True)

            if rolling_playback:
                if throttled:
                    time.sleep(1 / 60)  # Playback at ~60 fps
                # Poll the keyboard exactly once per frame; a second waitKey call would swallow
                # a key pressed during the first one.
                key = cv2.waitKey(1) & 0xFF
                if key == ord('q'):
                    break
                if key == ord('p'):  # Stop rolling playback
                    rolling_playback = False
                    continue
                if key == ord('o'):
                    throttled = not throttled
                if current_idx < max_idx - 1:
                    current_idx += 1
                else:
                    rolling_playback = False  # Stop playback at the last frame
                continue

            key = cv2.waitKey(0) & 0xFF
            if key == ord('q'):  # Quit
                break
            if key in step_keys:
                current_idx = min(max_idx - 1, max(0, current_idx + step_keys[key]))
            elif key == ord('p'):  # Start rolling playback
                rolling_playback = True
    finally:
        # Always release the captures (also on early return or exceptions).
        for cap in captures:
            cap.release()
        if windows_open:
            cv2.destroyAllWindows()


# Play back the three videos of `block` (the loop variable left over from the preprocessing loop,
# i.e. the last block in block_collection), aligned row-by-row by block.final_sync_df.
# NOTE(review): arena_vid is hard-coded to a PV_126 / 2024_07_18 / block_007 file and will not follow
# `block` if the block selection changes - confirm, or derive it from the block object.
arena_vid = r'Z:\Nimrod\experiments\PV_126\2024_07_18\block_007\arena_videos\videos\top_20240718T124930.mp4'
left_vid = block.le_videos[0]  # first entry of the block's left-eye video list
right_vid = block.re_videos[0]  # first entry of the block's right-eye video list
sync_df = block.final_sync_df
display_synchronized_frames(sync_df, arena_vid=arena_vid, left_vid=left_vid, right_vid=right_vid)

Frame Info: Arena: 2144.0, Left Eye: 2569.0, Right Eye: 2566.0 (Index: 2559)