### Analyze + Plot Eye Saccades
The goal of this script is to determine whether the animal's saccades
are aligned with the location of the stimulus, in other words, whether the
animal is looking at the salient part of the screen. The script:

  - Requires a path to a folder on the lab's server (X:\)
  - Loads right- and left-eye Bonsai CSV outputs, IMU & stimulus TTL files.
  - Computes eye-centric gaze, detects saccades (optionally removing blinks).
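  - Creates three kinds of figures and writes them to <folder_path>/Results/ :
      ─ speed trace: instantaneous XY (and torsional) speed with the threshold and detected saccade frames
      ─ “ALL”  : quiver of every saccade + polar & linear angle histograms
      ─ LR/UD  : same layout, restricted to the first saccade after each stimulus of a given direction
                 (plus a torsional-velocity histogram when torsion is available)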
  - Global X/Y limits are derived from all translational saccades in the session;
    every quiver plot shares them so comparisons are direct.
  - Figure filenames:  <session>_saccades.png,  <session>_<Eye>_ALL_<stim_type>.png,  <session>_<Eye>_<direction>.png
      e.g.  Tsh001_2025-06-11T13_02_29_Right_ALL_Interleaved.png
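
For the LR/UD figures, `sort_plot_saccades` keeps only the first saccade that starts inside a window after each stimulus onset, using a nested Python loop over all saccades. Below is a minimal vectorised sketch of the same matching rule with `np.searchsorted`. The helper name `first_saccade_after_stim` and the frame numbers are hypothetical. The window is counted in frames, and `saccade_frames` is assumed to be sorted in ascending order.

```python
import numpy as np


def first_saccade_after_stim(stim_frames, saccade_frames, window_frames):
    """For each stimulus onset, return the position (into saccade_frames) of the
    first saccade with onset <= frame <= onset + window_frames - 1, or -1 if none.

    Hypothetical helper; saccade_frames must be sorted in ascending order.
    """
    stim_frames = np.asarray(stim_frames)
    saccade_frames = np.asarray(saccade_frames)

    # Position of the first saccade at or after each stimulus onset
    pos = np.searchsorted(saccade_frames, stim_frames, side="left")

    out = np.full(stim_frames.shape, -1, dtype=int)
    has_later = pos < saccade_frames.size          # some saccade exists at/after onset
    cand_pos = pos[has_later]
    # Keep the candidate only if it starts before the window closes
    in_window = saccade_frames[cand_pos] <= stim_frames[has_later] + window_frames - 1
    out[np.flatnonzero(has_later)[in_window]] = cand_pos[in_window]
    return out


saccade_frames = np.array([10, 25, 40, 90])   # sorted saccade frames (made up)
stim_frames = np.array([5, 30, 60, 95])       # stimulus onsets (made up)
print(first_saccade_after_stim(stim_frames, saccade_frames, window_frames=20))
# [ 0  2 -1 -1]
```

The stimulus at frame 60 gets -1 because the next saccade (frame 90) falls outside its 20-frame window, and the one at 95 has no later saccade at all. `np.where` returns row indices in ascending order, and Bonsai frame numbers usually increase with row index. If both hold, the non-negative positions can be used to index `saccade_indices_xy` directly.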


### How saccades are detected in this script

1. **Eye-centred coordinates**  
   Subtract the mid-point of the two eye-corner markers so gaze is relative to the eye, not the camera.

2. **Denoise (`medfilt`)**  
   A 3-point median filter (`scipy.signal.medfilt`, `kernel_size=3`) removes single-frame tracking jitter.

3. **Pixel → degree conversion**  
   Divide by the calibration factor (≈ 3.76 px per degree; a per-axis `[fx, fy]` pair is also accepted) to work in visual degrees.

4. **Instantaneous velocity (`np.ediff1d`)**  
   Take the frame-to-frame difference in x and y (0 for the first frame) and combine them into a scalar speed √(dx² + dy²) in deg/frame.

5. **Speed threshold**  
   If speed ≥ `saccade_thresh` (1.0 deg/frame in the parameters cell), mark that frame as a saccade (`saccade_indices_xy`).

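6. **Blink removal (optional)**  
   When `blink_detection` is enabled, the distance between the left and right VD-axis points (eyelid separation) is differentiated frame by frame. Frames where its absolute rate of change exceeds `blink_thresh` are flagged as blinks and removed from `saccade_indices_xy`.

7. **Torsional saccades (optional)**  
   When a torsion trace is supplied, NaNs are interpolated and the absolute frame-to-frame change |Δθ| is compared with `saccade_threshold_torsion`. Frames at or above it are stored in `saccade_indices_theta`. Blink removal is not applied to these.

The cleaned list **`saccade_indices_xy`** feeds every quiver plot and angle histogram in the notebook.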
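
The steps above can be sketched on synthetic data as follows. This is a simplified stand-in for `detect_saccades`, not the notebook's function. The name `detect_xy_saccades`, the single scalar calibration factor, and the `blink_frames` argument (a precomputed list of blink rows) are assumptions made for illustration.

```python
import numpy as np
from scipy.signal import medfilt


def detect_xy_saccades(gaze_x, gaze_y, corner1, corner2,
                       px_per_deg=3.76, speed_thresh=1.0, blink_frames=None):
    """Hypothetical sketch of steps 1-6; returns row indices of saccade frames."""
    # 1. eye-centred coordinates: subtract the midpoint of the two corner markers
    origin = (np.asarray(corner1, dtype=float) + np.asarray(corner2, dtype=float)) / 2.0
    x = np.asarray(gaze_x, dtype=float) - origin[:, 0]
    y = np.asarray(gaze_y, dtype=float) - origin[:, 1]

    # 2. 3-point median filter removes single-frame tracking jitter
    x = medfilt(x, kernel_size=3)
    y = medfilt(y, kernel_size=3)

    # 3. pixels -> visual degrees
    x /= px_per_deg
    y /= px_per_deg

    # 4. frame-to-frame displacement -> scalar speed in deg/frame (first frame = 0)
    dx = np.ediff1d(x, to_begin=0)
    dy = np.ediff1d(y, to_begin=0)
    speed = np.hypot(dx, dy)

    # 5. speed threshold
    idx = np.flatnonzero(speed >= speed_thresh)

    # 6. optional blink removal (blink rows supplied by the caller here)
    if blink_frames is not None:
        idx = idx[~np.isin(idx, blink_frames)]
    return idx


n = 20
corner1 = np.tile([0.0, 0.0], (n, 1))    # corner marker 1 (made up, static)
corner2 = np.tile([20.0, 0.0], (n, 1))   # corner marker 2 (made up, static)
gaze_x = np.full(n, 10.0)
gaze_x[10:] = 30.0                        # 20 px step at frame 10 (about 5.3 deg)
gaze_x[5] = 50.0                          # one-frame tracking glitch
gaze_y = np.zeros(n)

print(detect_xy_saccades(gaze_x, gaze_y, corner1, corner2))                     # [10]
print(detect_xy_saccades(gaze_x, gaze_y, corner1, corner2, blink_frames=[10]))  # []
```

The one-frame glitch at frame 5 is removed by the median filter and never crosses the threshold, while the sustained step at frame 10 does. In the notebook, blink rows come from the VD-axis rate of change and `blink_thresh` rather than being passed in.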




Tested with Python 3.12 (Conda env “EyeHeadCoupling”).

  Author:  <Ratnadeep Pal @SarvestaniLab>   –  Last update: 2025-06-15

## Import the required libraries


In [None]:

import sys
import os
import numpy as np
import pandas as pd
from scipy.fft import fft
from scipy.signal import ShortTimeFFT,butter,hilbert,sosfiltfilt,medfilt
from scipy.signal.windows import gaussian
import scipy.stats as stats
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import tkinter as tk
from tkinter import filedialog
from sklearn.decomposition import PCA
from io import StringIO
from pathlib import Path
import matplotlib.lines as mlines
import matplotlib.gridspec as gridspec
import re
from datetime import datetime
from matplotlib.patches import FancyArrowPatch
from itertools import cycle
from matplotlib import cm
from matplotlib.collections import LineCollection
import matplotlib

%matplotlib qt


## Define parameters

In [None]:
cal = 3.76  # Calibration factor for the pixels to degrees
ttl_freq = 60  # TTL frequency in Hz

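# Parameters for blink and saccade detection
blink_detection = 1              # 1 = drop XY saccades that coincide with blinks, 0 = keep them
blink_thresh = 10                # blink threshold on the rate of change of eyelid (VD-axis) separation, per frame
saccade_thresh = 1.0             # XY saccade speed threshold (deg/frame)
torsion_velocity_thresh = 1.5    # torsional saccade speed threshold (deg/frame)
saccade_win = 0.7                # post-stimulus window for matching saccades to stimuli (seconds)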

#folder_path = select_folder() #this won't work if you're running jupyter lab in browser, so hard coding


###################################### Paris

#First day when we started doing interleaved stim.
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-09T13_09_02\\" # interleaved

#Second day where she licked a lot!
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-11T12_50_45\\" #no stim
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-11T13_02_29\\" # interleaved
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-11T13_15_39\\" #  interleaved

#Not much licking on this day
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-13T13_07_55\\" #interleaved
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-13T13_25_04\\" #no-stim session

#Motivated on this day, but first day where juice was only given for saccades
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-16T12_13_51\\" #interleaved

#second day juice was given for saccades only but she was stressed 
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-17T12_50_28\\" #interleaved

#third day juice was given for saccades only, but she was not thirsty
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-18T12_19_39\\" #interleaved

#fourth day, but she didn't pay attention the whole time
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-19T12_17_52\\" #interleaved


folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-30T15_20_56\\" #just L/R


#this is after Ratnadeep fixed a bunch of issues with the code, and the data looks great!!
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-02T15_14_29\\" #just L/R


#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-07T14_49_12\\" #just L/R
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-07T15_12_22\\" #just L/R

#good session!
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-08T15_04_12\\" #just L/R

#good session! first day we tried U/D stim
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-09T15_29_43\\" #just L/R
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-09T15_49_30\\" #L/R and U/D 

#tough day, she was stressed and didn't pay attention
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-10T15_36_39\\" # U/D --bad
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-10T15_47_04\\" #L/R and U/D 
######################################## Bayleaf
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\Rat22_Bayleaf_server\Rat022_2025-06-10T16_23_02\\"   #no stim
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\Rat22_Bayleaf_server\Rat022_2025-06-10T16_10_21\\"   #no stim

# pretty bad performance this day, used wrong head bar
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-11T14_22_34\\" #L/R and U/D 
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-11T14_49_16\\" #L/R and U/D 


#came in on a Sunday to do a session, figuring she'd be motivated and calm
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-13T17_38_53\\" #L/R and U/D 
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-13T17_46_15\\" #L/R and U/D, moved u/down closer
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-13T17_52_47\\" #U/D only
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-13T18_19_05\\" #L/R only

#good long session today
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-14T14_58_27\\" #L/R only
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-14T15_04_46\\" #L/R only
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-14T15_11_17\\" #interleaved
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-14T15_16_32\\" #interleaved
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-14T15_26_38\\" #Up/Down
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-14T15_39_13\\" #Up / 2 levels
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-14T15_49_51\\" #Up / Down 2 levels
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-14T15_56_54\\" #Up / Down 1 level

#good long session today, where we looked at torsion online
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-15T16_16_03\\" #interleaved
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-15T16_36_23\\" #U/D only
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-15T16_45_37\\" #L/R


################################################### #torsion training
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-17T15_32_42\\" #interleaved


folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-21T15_08_33\\" #Up/down
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-21T15_45_48\\" #Up/down
#testing monitor positions to see if down torsion gets better
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-22T15_02_57\\" #Up/down
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-22T15_25_56\\" #Up/down

##################################################### fixation training ####################
# #first day where we started punishing for not fixating during blue spot
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-30T15_29_31\\" #interleaved
# folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-30T15_39_26\\" #interleaved
# folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-30T15_47_15\\" #interleaved

# # #second day where we punish for not fixating during blue spot
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-31T15_23_59\\" #interleaved
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-31T15_56_37\\" #interleaved

#third day where we punish for not fixating during blue spot
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-08-01T15_15_48\\" #interleaved
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-08-06T16_40_40\\" #interleaved


###################################################################### anti-saccade training
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-08-11T16_40_24\\" 
# folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-08-12T15_04_14\\"
# folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-08-13T14_48_16\\" #interleaved
# folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-08-13T15_02_30\\" #interleaved
# folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-08-14T15_06_12\\" #interleaved
# folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-08-14T15_12_24\\" #interleaved



results_dir = Path(folder_path) / "Results"
results_dir.mkdir(exist_ok=True)

session_name = os.path.basename(folder_path.rstrip("/\\"))
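
The speed thresholds above are per frame and `saccade_win` is in seconds. However, `sort_plot_saccades` builds its search window with `np.arange(0, saccade_window, 1)`, which treats the window as a frame count. The call site is not part of this section, so it is not visible whether the window is converted before that call. The sketch below shows the conversion under the assumption that the eye camera samples at `ttl_freq` (60 Hz). The helper name `to_frame_units` is hypothetical.

```python
import numpy as np


def to_frame_units(win_s, thresh_deg_per_frame, fps):
    """Hypothetical helper: window seconds -> frames, and deg/frame -> deg/s."""
    win_frames = int(round(win_s * fps))
    thresh_deg_per_s = thresh_deg_per_frame * fps
    return win_frames, thresh_deg_per_s


fps = 60  # ASSUMPTION: eye-camera frame rate equals ttl_freq
win_frames, thresh_deg_per_s = to_frame_units(0.7, 1.0, fps)
print(win_frames, thresh_deg_per_s)           # 42 60.0

# Passing the window in seconds collapses the search range to a single frame
print(np.arange(0, 0.7, 1))                    # [0.]
print(np.arange(0, win_frames, 1)[-1])         # 41
```

Under this assumption, `saccade_thresh = 1.0` deg/frame corresponds to 60 deg/s. A 0.7 s window spans stimulus onset through onset + 41 frames.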




## Define functions


In [None]:

def get_session_date_from_path(path):
    match = re.search(r"\d{4}-\d{2}-\d{2}", path)
    if match:
        return datetime.strptime(match.group(), "%Y-%m-%d")
    else:
        raise ValueError("No valid date (YYYY-MM-DD) found in path")
    
def determine_camera_side(path, cutoff_date_str="2025-06-30"):
    session_date = get_session_date_from_path(path)
    cutoff_date = datetime.strptime(cutoff_date_str, "%Y-%m-%d")
    return "L" if session_date >= cutoff_date else "R"


# Prompt the user to select a folder
def select_folder():
    root = tk.Tk()
    root.withdraw()  # Hide the main window
    directory = filedialog.askdirectory()  # Open the file selection dialog
    return directory

# Prompt the user to open a file 
def select_file():
    root = tk.Tk()
    root.withdraw()  # Hide the main window
    file_path = filedialog.askopenfilename()  # Open the file selection dialog
    return file_path

def choose_option(option1, option2, option3, option4):
    result = {}

    def select(choice):
        result['value'] = choice
        root.destroy()

    root = tk.Tk()
    root.title("Choose the type of visual stim")

    tk.Label(root, text="Please choose the type of visual stim:").pack(pady=10)
    tk.Button(root, text=option1, width=12, command=lambda: select(option1)).pack(side='left', padx=10, pady=10)
    tk.Button(root, text=option2, width=12, command=lambda: select(option2)).pack(side='left', padx=10, pady=10)
    tk.Button(root, text=option3, width=12, command=lambda: select(option3)).pack(side='left', padx=10, pady=10)
    tk.Button(root, text=option4, width=12, command=lambda: select(option4)).pack(side='left', padx=10, pady=10)

    # Manual event loop, blocks until window is destroyed
    while not result.get('value'):
        root.update()

    return result['value']

# Prompt the user to choose the type of visual stim
#stim_type = choose_option("None","LR","UD","Interleaved")


# Remove '(' and ')' characters and map True/False to 1/0 so each line parses as plain numeric CSV
def remove_parentheses_chars(line):
    return line.replace('(', '').replace(')', '').replace('True', '1').replace('False', '0')

def clean_csv(filename):
    with open(filename, 'r') as f:
        lines = [remove_parentheses_chars(line) for line in f]
    # Join lines and create a file-like object that pandas.read_csv can consume
    cleaned = StringIO(''.join(lines))
    return cleaned

# Zero-phase (non-causal) low-pass Butterworth filter to remove high-frequency noise
def butter_noncausal(signal, fs, cutoff_freq=1, order=4):
    # cutoff_freq is in Hz (default 1 Hz) and is normalised by the Nyquist frequency fs/2
    sos = butter(order, cutoff_freq/(fs/2), btype='low', output='sos')
    return sosfiltfilt(sos, signal)

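def interpolate_nans(arr):
    # Linearly interpolate over NaNs; works on a float copy so the caller's array is not modified
    arr = np.array(arr, dtype=np.float64)
    nans = np.isnan(arr)
    x = np.arange(len(arr))
    arr[nans] = np.interp(x[nans], x[~nans], arr[~nans])
    return arr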

def rotation_matrix(angle_rad):
    return np.array([[np.cos(angle_rad), -np.sin(angle_rad)],
                     [np.sin(angle_rad), np.cos(angle_rad)]])


pca = PCA(n_components=2)

def vector_to_rgb(angle, absolute): ##Got it from https://stackoverflow.com/questions/19576495/color-matplotlib-quiver-field-according-to-magnitude-and-direction
    # Map the vector direction to a hue at full saturation and value.
    # `absolute` (vector magnitude) is currently unused; the commented-out variant
    # below also scales by magnitude and would need a `max_abs` value in scope.
    angle = angle % (2 * np.pi)  # a positive modulus already gives a value in [0, 2π)

    # return matplotlib.colors.hsv_to_rgb((angle / 2 / np.pi, 
    #                                      absolute / max_abs, 
    #                                      absolute / max_abs))
    return matplotlib.colors.hsv_to_rgb((angle / 2 / np.pi, 1, 1))


def plot_angle_distribution(angle, ax_polar, num_bins=18):
    """
    Plots a normalized polar histogram of angles.

    Parameters:
        angle (np.ndarray): array of saccade angles in radians
        ax_polar (matplotlib.projections.polar.PolarAxes): the polar subplot to draw on
        num_bins (int): number of histogram bins
    """
    angle_2pi = np.where(angle < 0, angle + 2 * np.pi, angle)
    counts, bin_edges = np.histogram(angle_2pi, bins=num_bins, range=(0, 2 * np.pi))
    counts = counts / np.size(angle_2pi)  # Normalize
    width = np.diff(bin_edges)

    bars = ax_polar.bar(bin_edges[:-1], counts, width=width, align='edge', color='b', alpha=0.5, edgecolor='k')
    ax_polar.set_title("Normalized angle distribution")
    ax_polar.set_yticklabels([])

def plot_linear_histogram(angles, ax, num_bins=18):
    ang_deg = np.degrees(angles)
    ang_deg = np.mod(ang_deg, 360)
    counts, bins = np.histogram(ang_deg, bins=num_bins, range=(0, 360))
    counts = counts / ang_deg.size
    ax.bar(bins[:-1], counts, width=np.diff(bins), color="b", alpha=0.5, edgecolor="k")
    ax.set_xlabel("Angle (deg)")
    ax.set_ylabel("Normalised count")
    ax.set_title("Linear angle histogram")    

def detect_saccades(
    marker1_x, marker1_y, marker2_x, marker2_y,
    gaze_x, gaze_y,
    eye_frames,
    calibration_factor,
    blink_velocity_threshold, 
    saccade_threshold,
    blink_detection=0,
    vd_axis_lx=None, vd_axis_ly=None, vd_axis_rx=None, vd_axis_ry=None,
    torsion_angle=None,
    saccade_threshold_torsion=None,  
):
    ################# Analyze eye saccades function #################
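    """
    Detect translational (XY) and, optionally, torsional saccades.

    - Compute eye position in eye-centred coordinates.
    - Filter and convert raw gaze to degrees using the calibration factor.
    - Compute frame-to-frame eye velocity and flag frames above the speed threshold(s).
    - (Optional) Remove XY saccades that coincide with blinks.

    Also saves a speed-trace figure to results_dir (uses the notebook globals
    session_name and results_dir).
    """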
    

    # 1. eye-centred coordinates  →  degrees
    eye_origin = np.column_stack(((marker1_x + marker2_x) / 2.0,
                                (marker1_y + marker2_y) / 2.0))
    eye_camera = np.column_stack((gaze_x - eye_origin[:, 0],
                                gaze_y - eye_origin[:, 1])).astype(np.float64, copy=False)
    
    eye_angle = np.arctan2(marker2_y - marker1_y, marker2_x - marker1_x)

    # tiny denoise
    eye_camera[:, 0] = medfilt(eye_camera[:, 0], kernel_size=3)
    eye_camera[:, 1] = medfilt(eye_camera[:, 1], kernel_size=3)

    # read in 1 or 2 calibration factors (px per degree)
    cal = np.asarray(calibration_factor, dtype=np.float64)

    if cal.ndim == 0:                # scalar px/deg (same for X,Y)
        fx = fy = float(cal)
    elif cal.shape == (2,):          # [fx, fy] per-axis px/deg
        fx, fy = float(cal[0]), float(cal[1])
    else:
        raise ValueError("calibration_factor must be a scalar or a pair [fx, fy]")

    eye_camera[:, 0] /= fx
    eye_camera[:, 1] /= fy

    
    # 2. instantaneous velocity  →  speed
    dx = np.ediff1d(eye_camera[:, 0], to_begin=0)
    dy = np.ediff1d(eye_camera[:, 1], to_begin=0)
    xy_speed = np.sqrt(dx**2 + dy**2)

    xy_mask = xy_speed >= saccade_threshold

    # 3. torsional velocity (only if a torsion trace is supplied)
    if torsion_angle is not None:
        torsion_angle = interpolate_nans(torsion_angle)
        dtheta = np.ediff1d(torsion_angle, to_begin=0)
        torsion_speed = np.abs(dtheta)
    else:
        torsion_speed = np.zeros_like(xy_speed)

    # 4. Thresholding based on speed masks
    torsion_mask = torsion_speed >= (saccade_threshold_torsion or np.inf)
    
    #5. Detect saccades based on xy_speed and torsion_speed

    saccade_indices_xy = np.where(xy_mask)[0]             # row indices into the per-frame arrays
    saccade_frames_xy = eye_frames[saccade_indices_xy]    # absolute Bonsai frame numbers

    saccade_indices_theta = np.where(torsion_mask)[0]          # row indices into the per-frame arrays
    saccade_frames_theta = eye_frames[saccade_indices_theta]   # absolute Bonsai frame numbers

    # 6. Package eye positions and velocity into output (can be 3D if torsion is included)
    if torsion_angle is not None:
        eye_pos = np.column_stack([eye_camera, torsion_angle])
        eye_vel = np.column_stack([dx, dy, dtheta])
    else:
        eye_pos = eye_camera
        eye_vel = np.column_stack([dx, dy])





    # Plot saccade and threshold to make sure it's detected
    fig, (ax, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    frames = np.arange(len(xy_speed))

    # ─── Plot XY (translational) saccades ───
    ax.plot(frames, xy_speed, linewidth=0.8, label='Speed (°/frame)')
    ax.scatter(saccade_indices_xy, xy_speed[saccade_indices_xy],
            color='tab:red', s=12, label='Saccade idx')
    ax.axhline(saccade_threshold, color='tab:orange',
            linestyle='--', label=f'Threshold = {saccade_threshold}')
    ax.set_ylabel('Speed (° / frame)')
    ax.set_title('Instantaneous XY speed with detected saccade frames')
    ax.legend()
    ax.grid(alpha=.3)

    # ─── Plot torsional saccades ───
    ax2.plot(frames, torsion_speed, linewidth=0.8, label='Torsion Speed (°/frame)')
    ax2.scatter(saccade_indices_theta, torsion_speed[saccade_indices_theta],
                color='tab:purple', s=12, label='Torsion idx')
    if saccade_threshold_torsion is not None:   # axhline cannot draw a None threshold
        ax2.axhline(saccade_threshold_torsion, color='tab:purple',
                    linestyle='--', label=f'Threshold = {saccade_threshold_torsion}')
    ax2.set_xlabel('Frame number')
    ax2.set_ylabel('Torsion Speed (° / frame)')
    ax2.set_title('Instantaneous torsion speed with detected torsional saccades')
    ax2.legend()
    ax2.grid(alpha=.3)

    plt.tight_layout()
    plt.show()


    # save alongside other figures
    prob_fname = f"{session_name}_saccades.png"
    fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')


    # 7. optional blink removal (applied after plotting, so the figure above still shows blink frames)
    if blink_detection:
        vd_axis_left = np.vstack([vd_axis_lx, vd_axis_ly]).T
        vd_axis_right = np.vstack([vd_axis_rx, vd_axis_ry]).T
        vd_axis_d = np.linalg.norm(vd_axis_right - vd_axis_left, axis=1)   # eyelid separation
        vd_axis_vel = np.gradient(vd_axis_d)
        blink_indices = np.where(np.abs(vd_axis_vel) > blink_velocity_threshold)[0]
        # filter indices and frames with the same mask so the two stay aligned
        keep = ~np.isin(saccade_indices_xy, blink_indices)
        saccade_indices_xy = saccade_indices_xy[keep]
        saccade_frames_xy = saccade_frames_xy[keep]

    return {
        "eye_pos":          eye_pos, #in degrees
        "eye_vel":          eye_vel, #in degrees/frame
        "saccade_indices_xy":  saccade_indices_xy,
        "saccade_frames_xy":   saccade_frames_xy,
        "saccade_indices_theta":  saccade_indices_theta,
        "saccade_frames_theta":   saccade_frames_theta,
        }



def organize_stims(
    go_frame, 
    go_dir_x = None,
    go_dir_y = None):
    
    has_lr = go_dir_x is not None and np.any(go_dir_x != 0)
    has_ud = go_dir_y is not None and np.any(go_dir_y != 0)

    direction_sets = {}

    if has_lr:
        direction_sets["Left"]  = go_dir_x < 0
        direction_sets["Right"] = go_dir_x > 0
    if has_ud:
        direction_sets["Down"] = go_dir_y < 0
        direction_sets["Up"]   = go_dir_y > 0
    if not direction_sets:
        direction_sets["All"] = np.full(len(go_frame), True)

    stim_frames = {lab: go_frame[mask] for lab, mask in direction_sets.items()}

    # Return the inferred stim type too
    if has_lr and has_ud:
        stim_type = "Interleaved"
    elif has_lr:
        stim_type = "LR"
    elif has_ud:
        stim_type = "UD"
    else:
        stim_type = "None"

    return stim_frames, stim_type



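def sort_plot_saccades(
    saccades,
    saccade_window,  # length of the post-stimulus search window; np.arange below treats it as a frame count
    session_path,
    stim_type='None',
    eye_name='Eye',
):
    """
    Sort saccades by stimulus onset and plot them.

    `saccades` is the dict returned by detect_saccades with an extra
    "stim_frames" entry mapping each stimulus label to its onset frames
    (e.g. the first value returned by organize_stims).
    """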

    eye_pos = saccades["eye_pos"]
    eye_pos_diff = saccades["eye_vel"]
    saccade_indices_xy = saccades["saccade_indices_xy"]
    saccade_frames_xy= saccades["saccade_frames_xy"]
    saccade_indices_theta= saccades["saccade_indices_theta"]
    saccade_frames_theta = saccades["saccade_frames_theta"]
    stim_frames = saccades["stim_frames"]
    session_name = os.path.basename(session_path.rstrip("/\\"))

    # ───────── global axis limits (all saccades) ─────────

    if saccade_indices_theta is not None and len(saccade_indices_theta) > 0:
        saccade_indices_theta = np.array(saccade_indices_theta, dtype=int)
        t_all = eye_pos[saccade_indices_theta, 2]
    else:
        t_all = None

    if eye_pos_diff.shape[1] == 3:
        dx, dy, dtheta = eye_pos_diff[:, 0], eye_pos_diff[:, 1], eye_pos_diff[:, 2]
        x_all, y_all = eye_pos[saccade_indices_xy, 0], eye_pos[saccade_indices_xy, 1]
        # torsion angles at the torsional-saccade rows (no unit conversion is applied)
        t_all = eye_pos[saccade_indices_theta,2] if saccade_indices_theta is not None else None
        torsion_present = True
    else:
        dx, dy = eye_pos_diff[:, 0], eye_pos_diff[:, 1]
        x_all, y_all = eye_pos[saccade_indices_xy, 0], eye_pos[saccade_indices_xy, 1]
        t_all = None
        dtheta = None
        saccade_indices_theta = None
        saccade_frames_theta = None
        torsion_present = False


    pad   = 0.10
    rngX  = x_all.max() - x_all.min()
    rngY  = y_all.max() - y_all.min()
    X_LIM = (x_all.min() - pad*rngX, x_all.max() + pad*rngX)
    Y_LIM = (y_all.min() - pad*rngY, y_all.max() + pad*rngY)
    max_abs = np.max(np.hypot(dx[saccade_indices_xy], dy[saccade_indices_xy]))
    
    #calculate angles for all translational saccades
    angle_all = np.arctan2(dy[saccade_indices_xy],
                            dx[saccade_indices_xy])
    n_all = len(saccade_indices_xy)

    # ───────── master figure (ALL saccades) ─────────
    fig = plt.figure(figsize=(11, 6))
    gs  = gridspec.GridSpec(2, 2, width_ratios=[3, 2])
    ax_quiver = fig.add_subplot(gs[:, 0])
    ax_polar  = fig.add_subplot(gs[0, 1], polar=True)
    ax_linear = fig.add_subplot(gs[1, 1])

    ax_quiver.set_xlim(*X_LIM); ax_quiver.set_ylim(*Y_LIM)
    ax_quiver.set_xlabel('X (°)'); ax_quiver.set_ylabel('Y (°)')
    ax_quiver.set_title(
        f"{session_name}\n"
        f"All translational saccades ({n_all}) — {eye_name}  (stim: {stim_type})\n"
        f"saccade_thresh = {saccade_thresh}, saccade_window = {saccade_window}\n"
        f"blink_thresh = {blink_thresh}, blink_detection = {blink_detection}"
        )

    cols = np.array([vector_to_rgb(a, max_abs) for a in angle_all])
    ax_quiver.quiver(x_all, y_all,
                        dx[saccade_indices_xy],
                        dy[saccade_indices_xy],
                        angles='xy', scale_units='xy', scale=1,
                        color=cols, alpha=.5)

    # PCA arrows (currently disabled)
    # pca.fit(eye_pos_diff[saccade_indices] /
    #         np.linalg.norm(eye_pos_diff[saccade_indices], axis=1, keepdims=True))
    # for i, (vec, var) in enumerate(zip(pca.components_, pca.explained_variance_ratio_)):
    #     ax_quiver.arrow(np.mean(x_all), np.mean(y_all),
    #                     *(vec * 10 * np.sqrt(var)),
    #                     color=['k', 'b'][i], width=0.1,
    #                     label=f'PC{i+1} ({var:.2f} var)')
    # ax_quiver.legend()

    plot_angle_distribution(angle_all, ax_polar)
    plot_linear_histogram(angle_all, ax_linear)
    plt.tight_layout()

    # save master figure
    all_fname = f"{session_name}_{eye_name}_ALL_{stim_type}.png"
    fig.savefig(results_dir / all_fname, dpi=300, bbox_inches='tight')


    # Post-stimulus search window: frames onset .. onset + len(plot_window) - 1.
    # np.arange treats saccade_window as a frame count.
    plot_window = np.arange(0, saccade_window, 1)

    # sort saccades by frame once, so each stimulus can stop at its first match
    sorted_pairs_xy = sorted(zip(saccade_frames_xy, saccade_indices_xy))

    # ───────── one figure per stimulus label (skip "All") ─────────
    for label, frames in stim_frames.items():
        if label == "All":
            continue

        # gather the first saccade inside the window after each stimulus onset
        idx_buf = []  # buffer to collect saccade indices for this label

        for f in frames:

            lower_bound = max(f + plot_window[0], 0)
            upper_bound = min(f + plot_window[-1], saccade_frames_xy.max())

            for sf, idx in sorted_pairs_xy:
                if sf < lower_bound:
                    continue
                elif sf <= upper_bound:
                    idx_buf.append(idx)   # first valid saccade
                    break                 # only take the first one
                else:
                    break                 # past the window: move on to the next stim

        idx_use = np.array(idx_buf, dtype=int)
        if idx_use.size == 0:
            continue

        ang = np.arctan2(dy[idx_use],
                            dx[idx_use])
        n_cond = len(idx_use)

        fig = plt.figure(figsize=(9, 5))
        gs  = gridspec.GridSpec(3, 2, width_ratios=[3, 2])
        ax_q = fig.add_subplot(gs[:, 0])
        ax_p = fig.add_subplot(gs[0, 1], polar=True)
        ax_l = fig.add_subplot(gs[1, 1])
        ax_t = fig.add_subplot(gs[2, 1]) if torsion_present else None

        ax_q.set_xlim(*X_LIM); ax_q.set_ylim(*Y_LIM)
        ax_q.set_xlabel('X (°)'); ax_q.set_ylabel('Y (°)')
        ax_q.set_title(f"{session_name}\n{eye_name} — {label} (n={n_cond})")

        cols = np.array([vector_to_rgb(a, max_abs) for a in ang])
        ax_q.quiver(eye_pos[idx_use, 0], eye_pos[idx_use, 1],
                    dx[idx_use], dy[idx_use],
                    angles='xy', scale_units='xy', scale=1,
                    color=cols, alpha=.5)

        plot_angle_distribution(ang, ax_p)
        plot_linear_histogram(ang, ax_l)

        if torsion_present and len(saccade_frames_theta) > 0:
            # Plot histogram of dtheta only for torsional saccades within window
            idx_buf_torsion = []

            # sort torsion saccade frames
            sorted_pairs_theta = sorted(zip(saccade_frames_theta, saccade_indices_theta))

            for f in frames:
                lower_bound = max(f + plot_window[0], 0)
                upper_bound = min(f + plot_window[-1], saccade_frames_theta.max())

                for sf, idx in sorted_pairs_theta:
                    if sf < lower_bound:
                        continue
                    elif sf <= upper_bound:
                        idx_buf_torsion.append(idx)
                        # break  # first torsional saccade
                    else:
                        break

            idx_torsion_use = np.array(idx_buf_torsion, dtype=int)
            if idx_torsion_use.size > 0:
                dtheta_torsion = dtheta[idx_torsion_use]
                ax_t.hist(dtheta_torsion, bins=20, color='purple', alpha=0.5, edgecolor='k')
                ax_t.set_title("Torsional velocity distribution")
                ax_t.set_xlabel("deg/frame")
                ax_t.set_ylabel("Count")
                ax_t.set_xlim(-15, 15)

                # Add curved arrows for each torsional saccade
                for i in idx_torsion_use:
                    x0, y0 = eye_pos[i, 0], eye_pos[i, 1]
                    rotation_magnitude = np.abs(dtheta[i])
                    #print(f"Rotation magnitude for index {i}: {rotation_magnitude}")
                    curvature = -0.3 * np.sign(dtheta[i])  # direction of rotation
                    arrow = FancyArrowPatch(
                        posA=(x0 - 0.7, y0-1),
                        posB=(x0 + 0.7, y0),
                        connectionstyle=f"arc3,rad={curvature}",
                        color='purple',
                        arrowstyle='->',
                        mutation_scale=10 + 2 * rotation_magnitude,  # scale by magnitude
                        linewidth=1.0,
                        alpha=0.8
                    )
                    ax_q.add_patch(arrow)

        fig.tight_layout()
        fname = f"{session_name}_{eye_name}_{label.replace('/','-')}.png"
        fig.savefig(results_dir / fname, dpi=300, bbox_inches='tight')


def plot_eye_fixations_between_cue_and_go_by_trial(
    eye_frame, eye_pos, eye_timestamp,
    cue_frame, cue_time, go_frame, go_time,
    max_interval_s=1.0,
    color_all='0.85', s_all=2, alpha_all=0.25,
    s_subset=5,  alpha_subset=0.9,
    cmap_name='tab20',
    results_dir=None, session_name=None, eye_name='Eye'
):
    # ---- coerce to arrays
    eye_ts     = np.asarray(eye_timestamp).ravel()
    eye_x      = np.asarray(eye_pos[:, 0]).ravel()
    eye_y      = np.asarray(eye_pos[:, 1]).ravel()
    cue_frame  = np.asarray(cue_frame).astype(int).ravel()
    cue_time   = np.asarray(cue_time).astype(float).ravel()
    go_frame   = np.asarray(go_frame).astype(int).ravel()
    go_time    = np.asarray(go_time).astype(float).ravel()

    # ---- sort cues & gos by time (carry frames alongside)
    ci = np.argsort(cue_time); cue_time, cue_frame = cue_time[ci], cue_frame[ci]
    gi = np.argsort(go_time);  go_time,  go_frame  = go_time[gi],  go_frame[gi]

    # ---- no de-duplication here: cue onsets are already extracted upstream
    #      (TRIAL_GAP_S step), so the "_on" arrays simply alias the sorted inputs
    cue_time_on, cue_frame_on = cue_time, cue_frame
    go_time_on, go_frame_on = go_time, go_frame

    # ---- one-to-one time-based pairing: for each cue, take the NEXT go
    pairs_ct, pairs_gt, pairs_cf, pairs_gf, pairs_dt = [], [], [], [], []
    gptr = 0
    for ct, cf in zip(cue_time_on, cue_frame_on):
        while gptr < len(go_time_on) and go_time_on[gptr] < ct:
            gptr += 1
        if gptr >= len(go_time_on):
            break
        dt = float(go_time_on[gptr] - ct)
        pairs_ct.append(ct);  pairs_gt.append(go_time_on[gptr])
        pairs_cf.append(cf);  pairs_gf.append(int(go_frame_on[gptr]))
        pairs_dt.append(dt)
        gptr += 1  # consume this GO so it’s one-to-one

    pairs_ct = np.asarray(pairs_ct); pairs_gt = np.asarray(pairs_gt)
    pairs_cf = np.asarray(pairs_cf, dtype=int); pairs_gf = np.asarray(pairs_gf, dtype=int)
    pairs_dt = np.asarray(pairs_dt, dtype=float)

    # ---- filter by Δt window (seconds)
    valid_trials = (pairs_dt >= 0) & (pairs_dt < max_interval_s)

    # ---- plotting (use time to find eye samples; safer than frames)
    cmap = plt.get_cmap(cmap_name)  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    base_colors = [cmap(i) for i in np.linspace(0, 1, 20)]
    color_cycle = cycle(base_colors)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(eye_x, eye_y, s=s_all, c=color_all, alpha=alpha_all, label='All eye centers')

    legend_handles = []
    trial_num = 0
    for ct, gt, ok, dt in zip(pairs_ct, pairs_gt, valid_trials, pairs_dt):
        if not ok:
            continue
        a = np.searchsorted(eye_ts, min(ct, gt), side='left')
        b = np.searchsorted(eye_ts, max(ct, gt), side='right')
        if b <= a:
            continue
        col = next(color_cycle)
        h = ax.scatter(eye_x[a:b], eye_y[a:b], s=s_subset, c=[col], alpha=alpha_subset,
                       label=f'Trial {trial_num} (Δt={dt:.2f}s)')
        legend_handles.append(h)
        trial_num += 1

    ax.set_aspect('equal')
    ax.set_xlabel('Eye center X (deg)')
    ax.set_ylabel('Eye center Y (deg)')
    ax.set_title(f'Eye centers: all vs. time-paired cue→go windows (<{max_interval_s:.1f}s)')

    # Trim legend clutter
    if len(legend_handles) > 10:
        ax.legend([ax.collections[0], *legend_handles[:10]],
                  ['All eye centers', *[lh.get_label() for lh in legend_handles[:10]]],
                  frameon=False, loc='best')
    else:
        ax.legend(frameon=False, loc='best')

    # optional save
    if results_dir is not None:
        results_dir = Path(results_dir); results_dir.mkdir(exist_ok=True, parents=True)
        fname = f"{session_name or 'session'}_{(eye_name or 'Eye').replace(' ', '')}_cue_go_timepaired.png"
        fig.savefig(results_dir / fname, dpi=300, bbox_inches='tight')

    # ---- diagnostics so you can sanity-check counts & thresholds
    print(f"Cue onsets: {cue_time.size} | GO events: {go_time.size}")
    if pairs_dt.size:
        print(f"Paired trials: {pairs_dt.size} | dt min/median/max = "
              f"{np.nanmin(pairs_dt):.3f} / {np.nanmedian(pairs_dt):.3f} / {np.nanmax(pairs_dt):.3f} s")
        print(f"Passing (<{max_interval_s:.2f}s): {valid_trials.sum()} trials")

    return (pairs_cf, pairs_gf, pairs_ct, pairs_gt, pairs_dt, valid_trials,fig,ax)

def quantify_fixation_stability_vs_random(
    eye_timestamp, eye_pos,
    pairs_ct, pairs_gt, valid_trials,
    plot=True,
    rng_seed=0
):
    """
    Compare eye stability during fixation windows (cue->go for valid trials)
    to random, equal-duration windows drawn from the rest of the session.

    Returns a dict with arrays of per-window metrics and high-level means.
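    Metrics per window (lengths are in the units of eye_pos; the "_px" names are
    kept for the returned keys even when eye_pos is in degrees):
      - mean_step_disp_px  : mean Euclidean step size
      - mean_speed_px_s    : mean |velocity| (units per second)
      - net_drift_px       : |last - first|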
    """

    # --- coerce & sort eye samples by time
    ts = np.asarray(eye_timestamp, dtype=float).ravel()
    x  = np.asarray(eye_pos[:, 0]).ravel()
    y  = np.asarray(eye_pos[:, 1]).ravel()
    if not np.all(np.diff(ts) >= 0):
        order = np.argsort(ts)
        ts, x, y = ts[order], x[order], y[order]

    # --- build fixation windows from valid cue->go pairs
    ct = np.asarray(pairs_ct, dtype=float).ravel()
    gt = np.asarray(pairs_gt, dtype=float).ravel()
    ok = np.asarray(valid_trials, dtype=bool).ravel()

    fix_windows = [(c, g) for c, g, v in zip(ct, gt, ok) if v and (g > c)]
    if len(fix_windows) == 0:
        print("No valid fixation windows. Nothing to compute.")
        return None

    # Merge/clean fixation windows (ensure sorted, non-overlapping)
    fix_windows = sorted(fix_windows, key=lambda w: w[0])
    merged = []
    for s, e in fix_windows:
        if not merged or s > merged[-1][1]:
            merged.append([s, e])
        else:
            merged[-1][1] = max(merged[-1][1], e)  # merge overlaps
    fix_windows = [(s, e) for s, e in merged]

    # --- helper: compute metrics inside [t0, t1]
    def window_metrics(t0, t1):
        a = np.searchsorted(ts, t0, side='left')
        b = np.searchsorted(ts, t1, side='right')
        if b - a < 2:
            return np.nan, np.nan, np.nan
        dx = np.diff(x[a:b])
        dy = np.diff(y[a:b])
        dt = np.diff(ts[a:b])

        # valid finite steps with positive dt
        m = np.isfinite(dx) & np.isfinite(dy) & np.isfinite(dt) & (dt > 0)
        if not np.any(m):
            return np.nan, np.nan, np.nan

        step_disp = np.hypot(dx[m], dy[m])                  # eye_pos units
        speed     = step_disp / dt[m]                       # eye_pos units per second
        drift     = np.hypot(x[b-1] - x[a], y[b-1] - y[a])  # eye_pos units

        return float(step_disp.mean()), float(speed.mean()), float(drift)

    # --- compute fixation metrics per *original* (unmerged) window
    # (We’ll compare one random window per fixation with the same duration)
    orig_fix_windows = [(c, g) for c, g, v in zip(ct, gt, ok) if v and (g > c)]
    fix_len = np.array([g - c for c, g in orig_fix_windows], dtype=float)

    fix_mean_step = np.empty(len(orig_fix_windows))
    fix_mean_speed = np.empty(len(orig_fix_windows))
    fix_drift = np.empty(len(orig_fix_windows))
    for i, (c, g) in enumerate(orig_fix_windows):
        fix_mean_step[i], fix_mean_speed[i], fix_drift[i] = window_metrics(c, g)

    # --- build allowed (non-fixation) intervals across the whole session
    session_start, session_end = float(ts[0]), float(ts[-1])
    # complement of merged fixation windows
    allowed = []
    cursor = session_start
    for s, e in fix_windows:
        if s > cursor:
            allowed.append((cursor, s))
        cursor = max(cursor, e)
    if cursor < session_end:
        allowed.append((cursor, session_end))

    # convenience: function to draw a random start for a given duration
    rng = np.random.default_rng(rng_seed)
    def sample_random_window(duration):
        # find allowed intervals that can fit this duration
        candidates = [(a, b) for (a, b) in allowed if (b - a) >= duration]
        if not candidates:
            return None  # cannot fit (rare)
        a, b = candidates[rng.integers(0, len(candidates))]
        start = float(a) + rng.random() * float((b - a) - duration)
        return (start, start + duration)

    # --- draw one random window per fixation (equal duration) and compute metrics
    rnd_mean_step = np.empty(len(orig_fix_windows))
    rnd_mean_speed = np.empty(len(orig_fix_windows))
    rnd_drift = np.empty(len(orig_fix_windows))

    for i, L in enumerate(fix_len):
        rw = sample_random_window(L)
        if rw is None:
            rnd_mean_step[i] = rnd_mean_speed[i] = rnd_drift[i] = np.nan
        else:
            rnd_mean_step[i], rnd_mean_speed[i], rnd_drift[i] = window_metrics(*rw)

    # --- summarize (ignore NaNs)
    def nice_stats(arr):
        arr = np.asarray(arr, dtype=float)
        m = np.isfinite(arr)
        if not m.any():
            return np.nan, np.nan, 0
        vals = arr[m]
        return float(vals.mean()), float(vals.std(ddof=1) / np.sqrt(vals.size)), int(vals.size)

    ms_fix,  se_fix,  n_fix  = nice_stats(fix_mean_step)
    ms_rnd,  se_rnd,  n_rnd  = nice_stats(rnd_mean_step)
    sp_fix,  se_spf,  _      = nice_stats(fix_mean_speed)
    sp_rnd,  se_spr,  _      = nice_stats(rnd_mean_speed)
    dr_fix,  se_drf,  _      = nice_stats(fix_drift)
    dr_rnd,  se_drr,  _      = nice_stats(rnd_drift)

    print("=== Stability summary (mean ± s.e.m.; lengths in eye_pos units) ===")
    print(f"Mean step displacement:  fix {ms_fix:.3f} ± {se_fix:.3f}   vs   rand {ms_rnd:.3f} ± {se_rnd:.3f}  (n={n_fix} pairs)")
    print(f"Mean speed (per s):      fix {sp_fix:.3f} ± {se_spf:.3f}   vs   rand {sp_rnd:.3f} ± {se_spr:.3f}")
    print(f"Net drift:               fix {dr_fix:.3f} ± {se_drf:.3f}   vs   rand {dr_rnd:.3f} ± {se_drr:.3f}")

    # --- optional quick plot
    if plot:
        fig, axes = plt.subplots(1, 3, figsize=(12, 4), constrained_layout=True)
        pairs = [
            ("Mean step (deg)", fix_mean_step, rnd_mean_step),
            ("Mean speed (deg/s)", fix_mean_speed, rnd_mean_speed),
            ("Net drift (deg)", fix_drift, rnd_drift),
        ]
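        for ax, (title, a, b) in zip(axes, pairs):
            m = np.isfinite(a) & np.isfinite(b)
            if not m.any():
                # nothing to plot; np.min on an empty array would raise
                ax.set_title(f"{title}\n(no paired data)")
                continue
            ax.scatter(a[m], b[m], s=10, alpha=0.6)
            # y=x reference line
            lo = np.min(np.concatenate([a[m], b[m]]))
            hi = np.max(np.concatenate([a[m], b[m]]))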
            if np.isfinite(lo) and np.isfinite(hi) and hi > lo:
                ax.plot([lo, hi], [lo, hi], linestyle='--', linewidth=1, alpha=0.5)
                ax.set_xlim(lo, hi); ax.set_ylim(lo, hi)
            ax.set_xlabel("Fixation"); ax.set_ylabel("Random")
            ax.set_title(title)
            ax.set_aspect('equal', adjustable='box')
        fig.suptitle("Fixation vs. random windows (paired, equal duration)")
    else:
        fig = None

    return {
        "fix_mean_step_px": fix_mean_step,
        "rnd_mean_step_px": rnd_mean_step,
        "fix_mean_speed_px_s": fix_mean_speed,
        "rnd_mean_speed_px_s": rnd_mean_speed,
        "fix_net_drift_px": fix_drift,
        "rnd_net_drift_px": rnd_drift,
        "summary": {
            "mean_step_fix_mean±sem": (ms_fix, se_fix, n_fix),
            "mean_step_rand_mean±sem": (ms_rnd, se_rnd, n_rnd),
            "mean_speed_fix_mean±sem": (sp_fix, se_spf),
            "mean_speed_rand_mean±sem": (sp_rnd, se_spr),
            "net_drift_fix_mean±sem": (dr_fix, se_drf),
            "net_drift_rand_mean±sem": (dr_rnd, se_drr),
        },
        "figure": fig,
    }





## Extract relevant files from filepath
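The scan below assigns each file by a case-insensitive substring match, so a short keyword such as `go` can pick up an unrelated file and the last match silently wins. A minimal sketch of the same idea with `pathlib` that instead collects *every* match per keyword, so missing and ambiguous keys can be reported. The helper `find_session_files` and the filenames are hypothetical, not the lab's actual naming scheme.

```python
from pathlib import Path
import tempfile


def find_session_files(folder, patterns):
    """Return {key: [Path, ...]} listing every file whose lower-cased name contains patterns[key]."""
    found = {key: [] for key in patterns}
    for path in sorted(Path(folder).iterdir()):
        name = path.name.lower()
        for key, substr in patterns.items():
            if substr.lower() in name:
                found[key].append(path)
    return found


# Demo with hypothetical filenames in a throw-away folder
with tempfile.TemporaryDirectory() as tmp:
    for fname in ["IMU_data.csv", "Camera.csv", "Cue.csv", "Go_stim.csv",
                  "Go_stim_old.csv", "ellipse_center_XY_R.csv"]:
        (Path(tmp) / fname).touch()
    found_files = find_session_files(tmp, {"imu": "imu", "go": "go", "cue": "cue",
                                           "ellipse": "ellipse_center_XY_R"})

missing_keys = [k for k, v in found_files.items() if not v]        # nothing matched
ambiguous_keys = [k for k, v in found_files.items() if len(v) > 1]  # several matches
print(missing_keys, ambiguous_keys)  # [] ['go']
```

Returning lists rather than a single path makes the "last match wins" behaviour explicit: the caller decides whether an ambiguous key is an error.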

In [None]:
# determine which camera is used based on the folder name and cut off date
camera_side = determine_camera_side(folder_path)
eye_name = f"{camera_side} Eye"  # e.g. "Left Eye" or "Right Eye"
print(f"Using camera side: {camera_side}")



print(f"Scanning folder: {folder_path}")
print(f"Found {len(os.listdir(folder_path))} files")

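# Scan the folder for specific files (case-insensitive substring match on the file name).
# Short keywords such as 'go' can match unrelated names; if several files match one
# keyword, the last one listed wins.
IMU_file = camera_file = go_file = cue_file = end_of_trial_file = None
ellipse_center_XY_file = origin_of_eye_coordinate_file = vdaxis_file = torsion_file = None
for f in os.listdir(folder_path):
    f_lower = f.lower()
    full_path = os.path.join(folder_path, f)

    if 'imu' in f_lower:
        IMU_file = full_path
    if 'camera' in f_lower:
        camera_file = full_path
    if 'go' in f_lower:
        go_file = full_path
    if f"ellipse_center_XY_{camera_side}".lower() in f_lower:
        ellipse_center_XY_file = full_path
    if f"origin_of_eyecoordinate_{camera_side}".lower() in f_lower:
        origin_of_eye_coordinate_file = full_path
    if f"vdaxis_{camera_side}".lower() in f_lower:
        vdaxis_file = full_path
    if f"torsion_{camera_side}".lower() in f_lower:
        torsion_file = full_path
    if 'endoftrial' in f_lower:
        end_of_trial_file = full_path
    if 'cue' in f_lower:
        cue_file = full_path

# Report any expected file that was not found
expected_files = {'IMU': IMU_file, 'camera': camera_file, 'go': go_file, 'cue': cue_file,
                  'endoftrial': end_of_trial_file, 'ellipse_center_XY': ellipse_center_XY_file,
                  'origin_of_eyecoordinate': origin_of_eye_coordinate_file,
                  'vdaxis': vdaxis_file, 'torsion': torsion_file}
missing_files = [name for name, path in expected_files.items() if path is None]
if missing_files:
    print(f"Warning: no file found for: {', '.join(missing_files)}")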




## Extract stim and eye position data 
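The loading cell below passes most traces through `interpolate_nans`, which is defined earlier in the notebook. A minimal sketch of linear NaN interpolation with `np.interp`, assuming that is roughly what the helper does; `interpolate_nans_sketch` is a hypothetical name, and the real helper may treat leading and trailing NaNs differently (here they take the nearest valid value, which is `np.interp`'s edge behaviour).

```python
import numpy as np


def interpolate_nans_sketch(y):
    """Linearly interpolate NaNs in a 1-D array (hypothetical stand-in for interpolate_nans).

    Leading/trailing NaNs take the nearest valid value (np.interp's edge behaviour);
    an all-NaN or NaN-free input is returned as a copy.
    """
    y = np.asarray(y, dtype=float).copy()
    nans = np.isnan(y)
    if nans.all() or not nans.any():
        return y
    idx = np.arange(y.size)
    y[nans] = np.interp(idx[nans], idx[~nans], y[~nans])
    return y


trace = np.array([np.nan, 1.0, np.nan, 3.0, np.nan])
filled = interpolate_nans_sketch(trace)
print(filled)  # [1. 1. 2. 3. 3.]
```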

In [None]:
## Read the camera data and map between camera TTL (for saccades) and Bonsai TTLs (for frames)
camera_data = np.genfromtxt(camera_file, delimiter=',', skip_header=1, dtype=np.float64)
[bonsai_frame, bonsai_time] = camera_data[:, 0], camera_data[:, 1]
bonsai_frame = bonsai_frame.astype(int)  # Convert bonsai_frame to integer type


### Read the go file for the start of stim in the trial 
new_go_data_format=0
go_data = np.genfromtxt(clean_csv(go_file), delimiter=',', skip_header=1, dtype=np.float64)
if go_data.shape[1]>3:
    new_go_data_format = 1
    [go_frame, go_time, go_direction_x,go_direction_y] = go_data[:, 0], go_data[:, 1], go_data[:, 2], go_data[:,3]
else:
    [go_frame, go_time, go_direction] = go_data[:, 0], go_data[:, 1], go_data[:, 2]
go_frame = go_frame.astype(int)  # Convert go_frame to integer type

### Read the ellipse center XY file  
ellipse_center_XY_data = np.genfromtxt(clean_csv(ellipse_center_XY_file), delimiter=',', skip_header=1, dtype=np.float64)
[eye_frame,eye_timestamp,eye_x,eye_y] = ellipse_center_XY_data[:, 0], ellipse_center_XY_data[:, 1], ellipse_center_XY_data[:, 2], ellipse_center_XY_data[:, 3]
eye_frame = eye_frame.astype(int)  # Convert eye_frame to integer type
eye_x = interpolate_nans(eye_x)  # Interpolate NaN values in eye_x
eye_y = -1*interpolate_nans(eye_y)  # Interpolate NaNs, then flip the sign so +y points up (image rows grow downward)

### Read the origin of eye coordinate file
origin_of_eye_coordinate_data = np.genfromtxt(clean_csv(origin_of_eye_coordinate_file), delimiter=',', skip_header=1, dtype=np.float64)
[origin_frame,o_ts,l_x,l_y,r_x,r_y] = origin_of_eye_coordinate_data[:, 0], origin_of_eye_coordinate_data[:, 1], origin_of_eye_coordinate_data[:, 2], origin_of_eye_coordinate_data[:, 3], origin_of_eye_coordinate_data[:, 4], origin_of_eye_coordinate_data[:, 5]
origin_frame = origin_frame.astype(int)  # Convert origin_frame to integer type
l_x = interpolate_nans(l_x)  # Interpolate NaN values in l_x
r_x = interpolate_nans(r_x)  # Interpolate NaN values in r_x
l_y = interpolate_nans(l_y)  # Interpolate NaN values in l_y
r_y = interpolate_nans(r_y)  # Interpolate NaN values in r_y


## Read the torsion data - this is used for torsion detection
torsion_data = np.genfromtxt(clean_csv(torsion_file), delimiter=',', skip_header=1, dtype=np.float64)
[torsion_frame, torsion_ts, torsion] = torsion_data[:, 0], torsion_data[:, 1], torsion_data[:, 2]
torsion_frame = torsion_frame.astype(int)   # Convert torsion_frame to integer type
# Interpolate NaN values        
torsion = interpolate_nans(torsion)


## Read the vertical (VD) axis data - this is used for blink detection
vdaxis_data = np.genfromtxt(clean_csv(vdaxis_file),delimiter=',',skip_header=1,dtype=np.float64)
[vd_frame,vd_ts,vd_lx,vd_ly,vd_rx,vd_ry] = vdaxis_data[:,0],vdaxis_data[:,1],vdaxis_data[:,2],vdaxis_data[:,3],vdaxis_data[:,4],vdaxis_data[:,5]
vd_frame = vd_frame.astype(int)
# Interpolate NaN values
vd_lx = interpolate_nans(vd_lx)
vd_ly = interpolate_nans(vd_ly)
vd_rx = interpolate_nans(vd_rx)
vd_ry = interpolate_nans(vd_ry)

        
### Read the IMU data for the accelerometer and gyroscope
imu_data = np.genfromtxt(IMU_file, delimiter=',', skip_header=1, dtype=np.float64)
[imu_time,a_x,a_y,a_z,g_x,g_y,g_z,m_x,m_y,m_z] = imu_data[:, 0], imu_data[:, 1], imu_data[:, 2], imu_data[:, 3], imu_data[:, 4], imu_data[:, 5], imu_data[:, 6], imu_data[:, 7], imu_data[:, 8], imu_data[:, 9]
imu_time = imu_time.astype(np.float64)  # Ensure imu_time is in float64 format
# Interpolate NaN values in IMU data
a_x = interpolate_nans(a_x)
a_y = interpolate_nans(a_y)
a_z = interpolate_nans(a_z)
g_x = interpolate_nans(g_x)
g_y = interpolate_nans(g_y)
g_z = interpolate_nans(g_z)
m_x = interpolate_nans(m_x)
m_y = interpolate_nans(m_y)
m_z = interpolate_nans(m_z)


### Read the end-of-trial file. Per trial it gives the end frame/time, stimulus direction,
### eye-movement direction, torsion angle (newer files only) and trial outcome.
try:
    end_of_trial_data = np.genfromtxt(clean_csv(end_of_trial_file), delimiter=',', skip_header=1, dtype=np.float64)
    end_of_trial_frame = end_of_trial_data[:, 0].astype(int)  # Convert end_of_trial_frame to integer type
    end_of_trial_ts = end_of_trial_data[:, 1]
    trial_stim_direction = end_of_trial_data[:, 2]
    trial_eye_movement_direction = interpolate_nans(end_of_trial_data[:, 3])
    if end_of_trial_data.shape[1] >= 6:
        # Newer format with a torsion-angle column before the success flag
        trial_torsion_angle = interpolate_nans(end_of_trial_data[:, 4])
        trial_success = end_of_trial_data[:, 5]
    else:
        # Older format without the torsion-angle column
        trial_success = end_of_trial_data[:, 4]
except ValueError:
    print("No end of trial data found. Skipping this step.")
else:
    # A failed trial (0) in which an eye movement was still recorded
    # (direction != -1) is relabelled as an incorrect trial (-1)
    incorrect = (trial_success == 0) & (trial_eye_movement_direction != -1)
    trial_success[incorrect] = -1

# --- Read the cue file (every-frame logging) and keep only trial onsets ---
cue_data = np.genfromtxt(clean_csv(cue_file), delimiter=',', skip_header=1, dtype=np.float64)

cue_frame_raw     = cue_data[:, 0].astype(int)
cue_time_raw      = cue_data[:, 1].astype(float)
cue_direction_raw = cue_data[:, 2]  # keep dtype as-is (often int/float)

# Sort by time to be safe (carry frames/directions along)
order = np.argsort(cue_time_raw)
cue_time_raw      = cue_time_raw[order]
cue_frame_raw     = cue_frame_raw[order]
cue_direction_raw = cue_direction_raw[order]

# Define what counts as a 'new trial' gap between consecutive cue rows
TRIAL_GAP_S = 1.5  # <-- adjust to 2–3 if your inter-trial gap is longer

# Keep only the FIRST row after each large time jump (trial onset)
onset_idx = np.r_[0, np.where(np.diff(cue_time_raw) > TRIAL_GAP_S)[0] + 1]
cue_frame     = cue_frame_raw[onset_idx]
cue_time      = cue_time_raw[onset_idx]
cue_direction = cue_direction_raw[onset_idx]

print(f"Detected {cue_frame.size} cue onsets from {cue_frame_raw.size} cue rows (gap > {TRIAL_GAP_S}s).")

# --- Align lengths with GO events (1 line per trial) ---
if len(cue_frame) != len(go_frame):
    n = min(len(cue_frame), len(go_frame))
    if len(cue_frame) > len(go_frame):
        print(f"Warning: {len(cue_frame)} cue onsets but {len(go_frame)} GO rows; truncating cues to {n}.")
        cue_frame, cue_time, cue_direction = cue_frame[:n], cue_time[:n], cue_direction[:n]
    else:
        print(f"Warning: {len(cue_frame)} cue onsets but {len(go_frame)} GO rows; truncating GO to {n}.")
        go_frame, go_time = go_frame[:n], go_time[:n]

## Sanity check: How far apart are the stimuli?  
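The check below divides successive Go-frame differences by the camera TTL rate. A small worked example with hypothetical frame IDs and an assumed 60 Hz rate; a dropped or missed stimulus shows up as an interval roughly twice the typical one. The 50 % tolerance used to flag it is an arbitrary choice for illustration.

```python
import numpy as np

ttl_hz = 60                                       # assumed camera TTL rate (Hz)
go_frames_demo = np.array([120, 420, 690, 1290])  # hypothetical Go-stim frame IDs

d_frames = np.diff(go_frames_demo)  # frames between successive stimuli: 300, 270, 600
d_sec = d_frames / ttl_hz           # seconds: 5.0, 4.5, 10.0

# Flag intervals more than 50 % away from the median (arbitrary tolerance)
typical = np.median(d_sec)
irregular = np.abs(d_sec - typical) > 0.5 * typical
print(d_sec, irregular)
```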


In [None]:

d_frames = np.diff(go_frame)    # successive differences (frames)
d_sec = d_frames / ttl_freq

fig = plt.figure(figsize=(8, 3))
plt.plot(d_sec, marker='o')
plt.xlabel('Stimulus index')
plt.ylabel('Δtime (s) to next stim')
plt.title('Seconds between successive Go Stims')
plt.grid(alpha=.3)
plt.tight_layout()
plt.show()

# optional: save alongside other figures
prob_fname = f"{session_name}_{eye_name}_Stim_Interval.png"
fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')

## Analyze and plot saccade statistics
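The call below delegates to `detect_saccades`, defined earlier in the notebook. A minimal sketch of steps 1 to 5 from the detection outline at the top (eye-centred coordinates, 3-point median filter, px to deg, frame-to-frame speed, speed threshold), assuming the calibration factor of about 3.76 px/deg and the 1.5 deg/frame threshold quoted there. Blink removal (step 6) is omitted, and reporting a saccade at the frame where the supra-threshold step ends is an assumed convention; the notebook's own `detect_saccades` may index differently. `median3` and `detect_saccades_sketch` are hypothetical names.

```python
import numpy as np


def median3(x):
    """3-point running median; the ends are padded by repeating the edge values."""
    p = np.pad(np.asarray(x, dtype=float), 1, mode="edge")
    return np.median(np.stack([p[:-2], p[1:-1], p[2:]]), axis=0)


def detect_saccades_sketch(eye_x, eye_y, lx, ly, rx, ry,
                           px_per_deg=3.76, thresh_deg_per_frame=1.5):
    """Return (frame indices flagged as saccades, per-step speed in deg/frame)."""
    # 1. eye-centred coordinates: subtract the mid-point of the two eye-corner markers
    x = np.asarray(eye_x, dtype=float) - (np.asarray(lx) + np.asarray(rx)) / 2
    y = np.asarray(eye_y, dtype=float) - (np.asarray(ly) + np.asarray(ry)) / 2
    # 2. remove single-frame tracking jitter
    x, y = median3(x), median3(y)
    # 3. pixels -> visual degrees
    x, y = x / px_per_deg, y / px_per_deg
    # 4. speed[i] is the displacement from frame i to frame i + 1 (deg/frame)
    speed = np.hypot(np.diff(x), np.diff(y))
    # 5. report the frame at which each supra-threshold step ends (assumed convention)
    saccade_idx = np.flatnonzero(speed >= thresh_deg_per_frame) + 1
    return saccade_idx, speed


# Synthetic trace: a one-frame glitch at frame 2 and a real 20 px jump at frame 5
demo_x = np.array([0, 0, 40, 0, 0, 20, 20, 20, 20, 20], dtype=float)
zeros = np.zeros_like(demo_x)
idx, speed = detect_saccades_sketch(demo_x, zeros, zeros, zeros, zeros, zeros)
print(idx)  # [5]
```

The median filter removes the isolated 40 px glitch but preserves the step, so only the real jump crosses the threshold.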

In [None]:

# --------------------------------------------------------------
# One analyse-and-plot call for the eye saccades
saccades = detect_saccades(
    l_x, l_y, r_x, r_y,
    eye_x, eye_y,
    eye_frame,
    cal,
    blink_detection=blink_detection,
    vd_axis_lx=vd_lx, vd_axis_ly=vd_ly,
    vd_axis_rx=vd_rx, vd_axis_ry=vd_ry,
    saccade_threshold=saccade_thresh,
    blink_velocity_threshold=blink_thresh,
    torsion_angle=torsion,
    saccade_threshold_torsion=torsion_velocity_thresh,
)

saccades["stim_frames"], stim_type = organize_stims(
    go_frame,
    go_dir_x=go_direction_x,
    go_dir_y=go_direction_y,
)

sort_plot_saccades(
    saccades,
    saccade_window=saccade_win * ttl_freq,   # seconds → frames
    session_path=folder_path,
    stim_type=stim_type,
    eye_name=eye_name,
)


## Look at probability of saccades as a function of stimulus onset
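The cell below builds a peri-stimulus histogram of saccade times and expresses each bin as a percent change from the session-wide saccade rate. A vectorised sketch of the same computation, taking all saccade-minus-stimulus lags at once with broadcasting (memory grows as n_stim × n_saccades, which is usually fine for one session). `peristim_percent_change` is a hypothetical helper and the numbers in the example are synthetic.

```python
import numpy as np


def peristim_percent_change(saccade_frames, stim_frames, frame_win, bin_frames, total_frames):
    """% change of stim-locked saccades per bin relative to the session-wide rate.

    Returns None when there are no stimuli.
    """
    sacc = np.asarray(saccade_frames)
    stim = np.asarray(stim_frames)
    if stim.size == 0:
        return None
    bins = np.arange(-frame_win, frame_win + 1, bin_frames)
    # every saccade-minus-stimulus lag, kept only inside [-frame_win, +frame_win]
    lags = (sacc[None, :] - stim[:, None]).ravel()
    lags = lags[np.abs(lags) <= frame_win]
    counts, _ = np.histogram(lags, bins=bins)
    per_stim = counts / stim.size                       # saccades per bin per stimulus
    baseline = sacc.size / (total_frames / bin_frames)  # saccades per bin, whole session
    return 100 * (per_stim - baseline) / baseline


# Synthetic session: 600 frames, 10 saccades, 2 stimuli each followed by one saccade
sacc_demo = [10, 50, 101, 150, 200, 250, 302, 400, 450, 500]
change = peristim_percent_change(sacc_demo, [100, 300], frame_win=6, bin_frames=3, total_frames=600)
print(change)  # ≈ [-100, -100, 1900, -100]
```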

In [None]:
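go_frame        = np.asarray(go_frame).flatten()
saccade_frames  = saccades["saccade_frames_xy"]
ttl_freq        = 60     # camera TTL rate (Hz); overrides any earlier value
t_window_s      = 2      # half-width of the peri-stimulus window (s)
bin_ms          = 50     # histogram bin width (ms)

total_duration_frames = None   # set to the session length in frames if known

# ─── Setup ─────────────────────────────────────────────
frame_win   = int(t_window_s * ttl_freq)
bin_frames  = max(1, int((bin_ms / 1000) * ttl_freq))   # at least one frame per bin
bins        = np.arange(-frame_win, frame_win + 1, bin_frames)
t_sec       = (bins[:-1] + bin_frames / 2) / ttl_freq    # bin centres (s)

# ─── Estimate baseline saccade probability ─────────────
# Fallback: approximate the session length by the last saccade frame ID
if total_duration_frames is None: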
    total_duration_frames = saccade_frames.max()

total_bins = total_duration_frames / bin_frames
baseline_rate = len(saccade_frames) / total_bins  # saccades per bin

# ─── Helper: peri-stimulus histogram ──────────────────
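def peristim_change(stim_frames):
    """% change of stim-locked saccades per bin relative to baseline_rate (None if no stimuli)."""
    if len(stim_frames) == 0:
        return None   # lets the `is not None` checks below skip absent directions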
    rel_times = []
    for f0 in stim_frames:
        nearby = saccade_frames[
            (saccade_frames >= f0 - frame_win) &
            (saccade_frames <= f0 + frame_win)
        ]
        rel_times.extend(nearby - f0)
    rel_times = np.array(rel_times)

    counts, _ = np.histogram(rel_times, bins=bins)
    prob_bin = counts / len(stim_frames)  # saccades per bin per stim
    change = 100 * (prob_bin - baseline_rate) / baseline_rate
    return change

# ─── Group by direction ────────────────────────────────
go_dir_x = np.nan_to_num(go_direction_x)
go_dir_y = np.nan_to_num(go_direction_y)

stim_all = go_frame
stim_L   = go_frame[go_dir_x < 0]
stim_R   = go_frame[go_dir_x > 0]
stim_D   = go_frame[go_dir_y < 0]
stim_U   = go_frame[go_dir_y > 0]

# ─── Compute changes ───────────────────────────────────
ch_all = peristim_change(stim_all)
ch_L   = peristim_change(stim_L)
ch_R   = peristim_change(stim_R)
ch_D   = peristim_change(stim_D)
ch_U   = peristim_change(stim_U)

# ─── Plot ──────────────────────────────────────────────
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 4), sharex=True)

# Top: all stimuli
ax1.plot(t_sec, ch_all, label="All stimuli", color="black", lw=2)
ax1.axvline(0, color='k', ls='--', alpha=.5)
ax1.axhline(0, color='gray', ls=':', lw=1)
ax1.set_title('Stim-locked saccade probability (relative to baseline saccade rate)')
ax1.grid(alpha=.3)
ax1.set_ylim(-100,200)
ax1.legend()

# Bottom: each direction
if ch_L is not None:
    ax2.plot(t_sec, ch_L, label=f"Left (n={len(stim_L)})", color='green')
if ch_R is not None:
    ax2.plot(t_sec, ch_R, label=f"Right (n={len(stim_R)})", color='pink')
if ch_D is not None:
    ax2.plot(t_sec, ch_D, label=f"Down (n={len(stim_D)})", color='blue')
if ch_U is not None:
    ax2.plot(t_sec, ch_U, label=f"Up (n={len(stim_U)})", color='red')

ax2.axvline(0, color='k', ls='--', alpha=.5)
ax2.axhline(0, color='gray', ls=':', lw=1)
ax2.set_xlabel('Time from stimulus onset (s)')
ax2.set_title('Modulation by direction')
ax2.set_ylim(-100,200)
ax2.grid(alpha=.3)
#ax2.legend()

# Shared y-axis label
fig.text(0.01, 0.5, '% Change ',
         va='center', rotation='vertical', fontsize=11)

fig.tight_layout()


# optional: save alongside other figures
prob_fname = f"{session_name}_{eye_name}_StimLockedProb.png"
fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')

## Sanity check: Plot an overlay of Go-stims and Saccades 
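The timeline below colours each stimulus by its direction code, giving vertical components priority. A compact sketch of that classification, assuming NaN codes mean "no component" (treated as 0) and using the same `1e-6` tolerance; `classify_directions` is a hypothetical helper. Diagonal codes are labelled by their vertical component here, whereas the per-direction counts in the cell only include purely horizontal or purely vertical stimuli.

```python
import numpy as np


def classify_directions(gx, gy, eps=1e-6):
    """Label each stimulus 'L', 'R', 'D', 'U' or 'NA'; the vertical component wins on diagonals."""
    gx = np.nan_to_num(np.asarray(gx, dtype=float))  # NaN -> 0 (assumption)
    gy = np.nan_to_num(np.asarray(gy, dtype=float))
    out = np.full(gx.shape, "NA", dtype=object)
    out[gx < -eps] = "L"
    out[gx > eps] = "R"
    # vertical labels are assigned last so they overwrite horizontal ones
    out[gy < -eps] = "D"
    out[gy > eps] = "U"
    return out


stim_labels = classify_directions([-1, 1, 0, 0, 0, np.nan, 1], [0, 0, -1, 1, 0, 1, -1])
print(list(stim_labels))  # ['L', 'R', 'D', 'U', 'NA', 'U', 'D']
```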

In [None]:
# ============================================================
# Timeline plot: saccades (grey) + colour-coded stimuli
# ============================================================
# -----------------------------------------------------------------
# Assumes these objects already exist in the workspace
# -----------------------------------------------------------------
# saccades["saccade_frames_xy"] – Bonsai frame IDs of all saccades
# go_frame                  – Bonsai frame IDs of stimuli
# go_dir_x, go_dir_y        – direction codes (can be None)
# ttl_freq                  – camera TTL rate (Hz)
# results_dir               – Path(folder_path) / "Results"
# -----------------------------------------------------------------

# 1) prep data ----------------------------------------------------
saccade_frames = np.asarray(saccades["saccade_frames_xy"], dtype=int)

# fallback arrays if dir arrays are missing
gx = go_dir_x if (go_dir_x is not None) else np.zeros_like(go_frame)
gy = go_dir_y if (go_dir_y is not None) else np.zeros_like(go_frame)


# ── 1-bis. count how many of each stimulus -----------------------------
eps = 1e-6                      # tolerance for “zero”
is_left   = (np.abs(gy) < eps) & (gx < -eps)
is_right  = (np.abs(gy) < eps) & (gx >  eps)
is_down   = (np.abs(gx) < eps) & (gy < -eps)
is_up     = (np.abs(gx) < eps) & (gy >  eps)

n_left, n_right = is_left.sum(),  is_right.sum()
n_down, n_up    = is_down.sum(),  is_up.sum()


# palette mapping
palette = {'L': 'green', 'R': 'pink', 'D': 'blue', 'U': 'red', 'NA': 'gray'}

# build colour list per stimulus
colors = []
for x, y in zip(gx, gy):
    if abs(y) > eps:                        # Up / Down has priority
        colors.append(palette['U' if y > 0 else 'D'])
    elif abs(x) > eps:                      # Left / Right
        colors.append(palette['R' if x > 0 else 'L'])
    else:
        colors.append(palette['NA'])

# sort frames & colours together
order      = np.argsort(go_frame)
t_stim     = go_frame[order] / ttl_freq
colors     = [colors[i] for i in order]

# convert saccade frames to seconds
t_sacc = np.sort(saccade_frames) / ttl_freq

# 2) plot ---------------------------------------------------------
fig, ax = plt.subplots(figsize=(12, 2.5))

# saccades: grey vertical ticks at y = 0
ax.vlines(t_sacc, -0.1, 0.1, colors='0.25', linewidth=1)

# stimuli: colour ticks at y = 1
ax.vlines(t_stim, 0.9, 1.1, colors=colors, linewidth=2)

# axes formatting
ax.set_yticks([0, 1])
ax.set_yticklabels(['Saccade', 'Stim'])
ax.set_xlabel('Time (s)')
ax.set_title('Timeline of saccades and stimuli')
ax.set_xlim(t_sacc.min() - 1, t_sacc.max() + 1)
ax.set_ylim(-0.5, 1.5)
ax.grid(axis='x', alpha=.3)

# legend
handles = [
    mlines.Line2D([], [], color='0.25', marker='|', ls='', markersize=10,
                  label='Saccade'),
    mlines.Line2D([], [], color='green', marker='|', ls='', markersize=10,
                  label=f'Stim Left  (n={n_left})'),
    mlines.Line2D([], [], color='pink',  marker='|', ls='', markersize=10,
                  label=f'Stim Right (n={n_right})'),
    mlines.Line2D([], [], color='blue',  marker='|', ls='', markersize=10,
                  label=f'Stim Down  (n={n_down})'),
    mlines.Line2D([], [], color='red',   marker='|', ls='', markersize=10,
                  label=f'Stim Up    (n={n_up})')
]
ax.legend(handles=handles, loc='upper right', ncol=5, fontsize=9, framealpha=.9)

plt.tight_layout()


# optional: save alongside other figures
prob_fname = f"{session_name}_{eye_name}_timeline_saccade_vs_stim.png"
fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')


## Plot how many stims produced at least one saccade
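The cell below checks every stimulus against every saccade. An equivalent sketch using `np.searchsorted` on the sorted saccade frames, which tests the closed window [f, f + w_frames] for all stimuli at once; `stim_has_saccade` is a hypothetical helper and the frame numbers are made up.

```python
import numpy as np


def stim_has_saccade(saccade_frames, stim_frames, w_frames):
    """True for each stimulus with at least one saccade in [f, f + w_frames]."""
    s = np.sort(np.asarray(saccade_frames))
    f = np.asarray(stim_frames)
    first_at_or_after = np.searchsorted(s, f, side="left")             # first saccade >= f
    first_after_win = np.searchsorted(s, f + w_frames, side="right")   # first saccade > f + w
    return first_after_win > first_at_or_after


has_sacc_demo = stim_has_saccade([5, 40, 100], [0, 30, 60, 95], w_frames=10)
print(has_sacc_demo, has_sacc_demo.mean())  # True, True, False, True -> fraction 0.75
```

The two binary searches cost O(log n) per stimulus instead of a full pass over all saccades.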

In [None]:
# ============================================================
#  Probability of ≥1 saccade within `saccade_win` s after each stimulus


# ----- parameters -------------------------------------------
win      = saccade_win               # window after onset (s)
w_frames = int(win * ttl_freq)       # convert to frames

# direction masks
gx = go_direction_x if go_direction_x is not None else np.zeros_like(go_frame)
gy = go_direction_y if go_direction_y is not None else np.zeros_like(go_frame)

dir_info = {
    'Left' :  (gx < -1e-6,  'green'),
    'Right':  (gx >  1e-6,  'pink'),
    'Down' :  (gy < -1e-6,  'blue'),
    'Up'   :  (gy >  1e-6,  'red')
}

labels, probs, colors = [], [], []

for label, (mask, col) in dir_info.items():
    stim_frames = go_frame[mask]
    n_stim      = len(stim_frames)
    if n_stim == 0:
        continue                                 # skip if this direction absent
    # check each stimulus: does ANY saccade happen within +win seconds?
    has_sacc = [( (saccade_frames >= f) & (saccade_frames <= f + w_frames) ).any()
                for f in stim_frames]
    prob = np.mean(has_sacc)                    # fraction of stimuli with ≥1 sac
    labels.append(label)
    probs.append(prob)
    colors.append(col)
    print(f"{label:5s}: {prob*100:5.1f}%  ({sum(has_sacc)}/{n_stim} stimuli)")

# ----- bar chart --------------------------------------------
fig, ax = plt.subplots(figsize=(6,4))
ax.bar(labels, probs, color=colors, edgecolor='k')
ax.set_ylim(0, 1)
ax.set_ylabel(f"P(saccade within {win}s)")
ax.set_title(f"Probability of a saccade in first {win} s after stimulus")
ax.grid(axis='y', alpha=.3)
plt.tight_layout()

# optional: save alongside other figures
prob_fname = f"{session_name}_{eye_name}_saccade_prob_within_{win*1000:.0f}ms.png"
fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')



## Fixation: Plot the fixation eye positions and the distribution.

Get the fixation position from the eye positions just before the stimulus onset (Go Cue)
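A minimal sketch of that averaging step, assuming `eye_frame` is sorted ascending. `pre_go_position` is a hypothetical helper that averages the last `n_frames` samples strictly before the Go frame and returns `None` when none exist; the cell below uses the slice `[-7:-1]`, which averages six samples and skips the one immediately before the Go frame.

```python
import numpy as np


def pre_go_position(eye_frame, eye_pos, go_frame, n_frames=6):
    """Mean eye position over the last `n_frames` samples strictly before `go_frame`.

    Assumes eye_frame is sorted ascending; returns None if no sample precedes go_frame.
    """
    before = np.flatnonzero(np.asarray(eye_frame) < go_frame)
    if before.size == 0:
        return None
    return np.asarray(eye_pos)[before[-n_frames:]].mean(axis=0)


demo_frames = np.arange(10)
demo_pos = np.column_stack([np.arange(10.0), np.zeros(10)])  # x = frame number, y = 0
print(pre_go_position(demo_frames, demo_pos, go_frame=8, n_frames=3))  # mean of frames 5..7
print(pre_go_position(demo_frames, demo_pos, go_frame=0))              # None
```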


In [None]:
## For each go frame, average the eye position over the eye frames just before it.
eye_position_during_fixation = []
eye_position_during_fixation_success = []
## Basic sanity check on trial counts collected from different sources
if (len(end_of_trial_frame) != len(go_frame)):
    print(f"Warning: Number of end of trial frames ({len(end_of_trial_frame)}) does not match number of go frames ({len(go_frame)}).")

for i, f in enumerate(go_frame[:len(trial_success)]):
    #last_frame_before_go = eye_frame[eye_frame < f][-31:-1]  # last 30 eye frame before stim
    #eye_pos = np.mean(saccades["eye_pos"][np.where(eye_frame == last_frame_before_go)[0]], axis=0)  # average eye position at that frame
    # Slice [-7:-1]: 6 eye samples before the go frame, skipping the sample immediately before it
    pre_go_idx = np.where(eye_frame < f)[0][-7:-1]
    if pre_go_idx.size == 0:   # check before averaging; np.mean of an empty selection gives NaNs, not an empty array
        print(f"Warning: No eye position data found for go frame {f}. Skipping this frame.")
        continue
    eye_pos = np.mean(saccades["eye_pos"][pre_go_idx], axis=0)
    eye_position_during_fixation.append(eye_pos)
    if trial_success[i] == 1:
        eye_position_during_fixation_success.append(eye_pos)
eye_position_during_fixation = np.array(eye_position_during_fixation)
eye_position_during_fixation_success = np.array(eye_position_during_fixation_success)
# Ratio of the total spread of the eye positions during fixation to the total spread of all eye positions
eye_pos_all = saccades["eye_pos"]  
spread_fixation = np.std(eye_position_during_fixation, axis=0)
spread_all = np.std(eye_pos_all, axis=0)
ratio_spread = spread_fixation / spread_all
print(f"Ratio of spread during fixation to all eye positions: {ratio_spread}")


# Plot all the eye positions during the session first and then the eye positions during fixation
fig = plt.figure(figsize=(8, 6))
plt.scatter(saccades["eye_pos"][:, 0], saccades["eye_pos"][:, 1], color='red', alpha=0.1, label='All Eye Positions')
plt.scatter(eye_position_during_fixation[:, 0], eye_position_during_fixation[:, 1], color='blue', alpha=0.4, label='Eye Positions During Fixation')
plt.scatter(eye_position_during_fixation_success[:, 0], eye_position_during_fixation_success[:, 1], color='green', alpha=0.5, label='Eye Positions During Fixation (Successful Trials)')
plt.xlabel('X Position (deg)')
plt.ylabel('Y Position (deg)')
plt.title('Eye Positions in the orbit during the whole session and during fixation')
plt.legend()
plt.grid()
plt.show()
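The pre-go averaging above can also be written without a per-trial `np.where` scan over the whole session. This is a minimal sketch that assumes `eye_frame` is sorted in ascending order. The helper name `mean_position_before` and the demo data are hypothetical. With `n_avg=6, n_skip=1` it reproduces the `[-7:-1]` slice: it drops the last frame before go and averages the 6 frames before that.

```python
import numpy as np

def mean_position_before(eye_frame, eye_pos, event_frames, n_avg=6, n_skip=1):
    """Mean (x, y) eye position over n_avg eye frames preceding each event frame.

    The n_skip frames closest to the event are excluded; with n_avg=6 and n_skip=1
    this matches eye_pos[np.where(eye_frame < f)[0][-7:-1]] for sorted eye_frame.
    Rows with no usable frames are returned as NaN.
    """
    eye_frame = np.asarray(eye_frame)
    eye_pos = np.asarray(eye_pos, dtype=float)
    # Number of eye frames strictly before each event (= index of the first frame >= event)
    n_before = np.searchsorted(eye_frame, event_frames, side="left")
    out = np.full((len(event_frames), eye_pos.shape[1]), np.nan)
    for k, stop in enumerate(n_before):
        stop = stop - n_skip              # exclude the frames right before the event
        start = max(stop - n_avg, 0)
        if stop > start:
            out[k] = eye_pos[start:stop].mean(axis=0)
    return out

# Usage on synthetic data: 10 eye frames (100..109) with x = 0..9 and y = 0
eye_frame_demo = np.arange(100, 110)
eye_pos_demo = np.column_stack([np.arange(10.0), np.zeros(10)])
print(mean_position_before(eye_frame_demo, eye_pos_demo, [108, 100]))
```

`np.searchsorted` does a binary search per event. This replaces the full boolean comparison against `eye_frame` that `np.where(eye_frame < f)` performs on every trial.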



## Fixation: Compare eye movements during the fixation period with randomly selected, duration-matched windows from other times

In [None]:
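rng = np.random.default_rng(123)  # fixed seed for reproducibility
# Optional control: shift each cue by a random offset (0-100 frames; 0-5 s). The frame and time
# offsets are drawn independently, so they do not correspond to each other. To use them, swap in
# the commented cue_frame/cue_time line in the call below.
cue_frame_jit = (cue_frame + rng.integers(0, 101, size=cue_frame.shape)).astype(int)
cue_time_jit = cue_time + rng.uniform(0.0, 5.0, size=cue_time.shape)

(pairs_cf, pairs_gf, pairs_ct, pairs_gt, pairs_dt, valid_trials, fig, ax) = plot_eye_fixations_between_cue_and_go_by_trial(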
    eye_frame=eye_frame, eye_pos=saccades["eye_pos"], eye_timestamp=eye_timestamp,
    cue_frame=cue_frame, cue_time=cue_time,
    #cue_frame=cue_frame_jit, cue_time=cue_time_jit,    
    go_frame=go_frame,  go_time=go_time,
    max_interval_s=1,
    results_dir=results_dir, session_name=session_name, eye_name=eye_name
)

# Now quantify stability vs random (and show a small paired scatter summary)
stats = quantify_fixation_stability_vs_random(
    eye_timestamp=eye_timestamp,
    eye_pos=saccades["eye_pos"],
    pairs_ct=pairs_ct, pairs_gt=pairs_gt,
    valid_trials=valid_trials,
    plot=True,     # set False if you only want numbers
    rng_seed=0
)
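The sketch below shows the idea behind a matched-duration random control, for intuition only. It is not the implementation of `quantify_fixation_stability_vs_random`, whose internals are not shown here. The function names, the dispersion measure (RMS distance from the window's mean position), and the uniform placement of one control window per trial are all assumptions.

```python
import numpy as np

def rms_dispersion(pos):
    """RMS distance (deg) of (x, y) samples from their mean; NaN if fewer than 2 valid samples."""
    pos = np.asarray(pos, dtype=float)
    pos = pos[~np.isnan(pos).any(axis=1)]
    if len(pos) < 2:
        return np.nan
    return float(np.sqrt(np.mean(np.sum((pos - pos.mean(axis=0)) ** 2, axis=1))))

def fixation_vs_random_dispersion(t, pos, starts, stops, rng_seed=0):
    """Dispersion in each [start, stop) window and in one random window of equal duration.

    t: sorted sample times (s); pos: (N, 2) eye positions; starts/stops: window bounds (s).
    """
    rng = np.random.default_rng(rng_seed)
    t = np.asarray(t, dtype=float)
    pos = np.asarray(pos, dtype=float)
    fix, ctrl = [], []
    for t0, t1 in zip(starts, stops):
        dur = t1 - t0
        fix.append(rms_dispersion(pos[(t >= t0) & (t < t1)]))
        if dur <= 0 or dur >= t[-1] - t[0]:
            ctrl.append(np.nan)  # no room for a control window of this duration
            continue
        # Control: same duration, start drawn uniformly over the session
        r0 = rng.uniform(t[0], t[-1] - dur)
        ctrl.append(rms_dispersion(pos[(t >= r0) & (t < r0 + dur)]))
    return np.array(fix), np.array(ctrl)
```

In this notebook's terms, `t` and `pos` would correspond to `eye_timestamp` and `saccades["eye_pos"]`, and the window bounds to the per-trial cue and go times. A stable fixation shows up as fixation dispersions that are systematically smaller than their paired controls.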

In [None]:
## This section is only for fixation experiments. It may throw an error if the data are not from a fixation experiment.
fixation_experiment = False
#----------------------------------------------------------------------------------------------------------
## Divide the successful fixation trials by how long the animal took to reach the fixation point,
## measured as the time between the cue and the go signal (go_time - cue_time).
if fixation_experiment:
    fixation_time_threshold = 0.75  # seconds
    total_fixation_time_per_trial = go_time - cue_time
    #plt.hist(total_fixation_time_per_trial, bins=20, color='gray', alpha=0.7)
    mask_short_fixation = total_fixation_time_per_trial <= fixation_time_threshold
    mask_long_fixation = total_fixation_time_per_trial > fixation_time_threshold

    def eye_positions_between(start_frame, end_frame):
        """Eye positions from the first eye frame >= start_frame up to (not including) the
        last eye frame <= end_frame. Assumes eye_frame is sorted and contains no repeats."""
        i_start = np.where(eye_frame >= start_frame)[0][0]
        i_end = np.where(eye_frame <= end_frame)[0][-1]
        positions = saccades["eye_pos"][i_start:i_end]
        # positions = positions - np.mean(positions, axis=0)  # optional: centre on the trial mean
        return positions

    # One (n_frames, 2) array per trial, then stacked across trials
    short_fixation_eye_positions = [eye_positions_between(c, g) for c, g in
                                    zip(cue_frame[mask_short_fixation], go_frame[mask_short_fixation])]
    long_fixation_eye_positions = [eye_positions_between(c, g) for c, g in
                                   zip(cue_frame[mask_long_fixation], go_frame[mask_long_fixation])]
    short_fixation_all_positions = np.vstack(short_fixation_eye_positions)
    long_fixation_all_positions = np.vstack(long_fixation_eye_positions)

    ## TODO (unfinished): randomly selected control windows of matched duration from the whole session
    # control_window_eye_positions = []
    # control_window_size = int(fixation_time_threshold * ttl_freq)  # in frames
    # control_window_start = np.random.choice

    #----------------------------------------------------------------------------------------------------------
    ## This commented section is for side-by-side plots
    #----------------------------------------------------------------------------------------------------------
    # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    # ax1.scatter(short_fixation_all_positions[:, 0], short_fixation_all_positions[:, 1], color='blue', alpha=0.3)
    # ax1.set_title(f'Eye Positions During Short Fixation Trials (<= {fixation_time_threshold}s)')
    # ax1.set_xlabel('X Position (deg)')
    # ax1.set_ylabel('Y Position (deg)')
    # ax1.grid()
    # ax2.scatter(long_fixation_all_positions[:, 0], long_fixation_all_positions[:, 1], color='orange', alpha=0.3)
    # ax2.set_title(f'Eye Positions During Long Fixation Trials (> {fixation_time_threshold}s)')
    # ax2.set_xlabel('X Position (deg)')
    # ax2.set_ylabel('Y Position (deg)')
    # ax2.grid()
    # plt.show()
    #----------------------------------------------------------------------------------------------------------
    # This section is for overlaid plots
    #----------------------------------------------------------------------------------------------------------
    fig = plt.figure(figsize=(8, 6))
    plt.scatter(saccades["eye_pos"][:, 0], saccades["eye_pos"][:, 1], color='red', alpha=0.1, label='All Eye Positions')
    plt.scatter(long_fixation_all_positions[:, 0], long_fixation_all_positions[:, 1], color='orange', alpha=0.3, label=f'Long Fixation (> {fixation_time_threshold}s)')
    plt.scatter(short_fixation_all_positions[:, 0], short_fixation_all_positions[:, 1], color='blue', alpha=0.3, label=f'Short Fixation (<= {fixation_time_threshold}s)')
    plt.xlabel('X Position (deg)')
    plt.ylabel('Y Position (deg)')
    plt.title('Eye Positions During Short and Long Fixation Trials')
    plt.legend()
    plt.grid()
    plt.show()



## Performance: Plot the performance of the animal throughout the session.

Useful for spotting when the animal becomes distracted or tired, and for checking whether it is learning the task.

In [None]:
## Plot the moving average of the trial success rate over a sliding window of `window_size` trials
window_size = 10
trial_success = np.array(trial_success, dtype=int)  # ensure an integer array (later cells rely on this)
# Only outcome 1 counts as a success; 0 (missed) and -1 (incorrect) both count as failures
is_success = (trial_success == 1).astype(float)
moving_avg_success = np.convolve(is_success, np.ones(window_size) / window_size, mode='valid')
window_end_trial = np.arange(window_size - 1, len(trial_success))  # index of the last trial in each window
fig = plt.figure(figsize=(10, 5))
plt.plot(window_end_trial, moving_avg_success, color='blue', label='Moving Average Success Rate')
plt.xlabel('Trial Index (last trial in window)')
plt.ylabel('Success Rate (Moving Average)')
plt.title(f'Moving Average of Trial Success Rate (Window Size: {window_size})')
plt.axhline(y=0.5, color='red', linestyle='--', label='50% Success Rate')
plt.legend()
plt.grid()
plt.tight_layout()
plt.show()
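With a 10-trial window, each point on this curve is estimated from very few trials, so a swing across the 50% line may be noise. One way to gauge this is a Wilson score interval for each window. This technique is swapped in here and is not part of the analysis above. The helper name `windowed_success_with_ci` is hypothetical, and `z = 1.96` gives an approximate 95% interval.

```python
import numpy as np

def windowed_success_with_ci(outcomes, window=10, z=1.96):
    """Trailing-window success rate with Wilson score intervals.

    outcomes: per-trial codes; only 1 counts as a success (0 and -1 are failures).
    Returns (last_trial_index, rate, lower, upper), one entry per full window.
    """
    success = (np.asarray(outcomes) == 1).astype(float)
    if len(success) < window:
        raise ValueError("need at least `window` trials")
    n = float(window)
    p = np.convolve(success, np.ones(window), mode="valid") / n  # success rate per window
    denom = 1.0 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * np.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    last_trial = np.arange(window - 1, len(success))
    return last_trial, p, center - half, center + half
```

The band could be drawn under the moving-average curve with `plt.fill_between(last_trial, lower, upper, alpha=0.2)`. Windows whose whole band lies above 0.5 are the ones where performance is most clearly above the 50% reference line.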

## Anti-saccade: Plot the performance of the animal in the antisaccade task.

Plots the number of correct, incorrect, and missed trials; the magnitude and timing of saccades in correct and incorrect trials; and the percentage of correct trials whose first saccade went in the incorrect direction.

In [None]:
## First plot the number of correct, missed, and incorrect trials, labelled with percentages
# end_of_trial_frame, end_of_trial_ts, trial_stim_direction, trial_eye_movement_direction, trial_torsion_angle, trial_success
# Outcome codes in trial_success: 1 = correct, 0 = missed, -1 = incorrect
trial_success = np.asarray(trial_success)  # element-wise comparisons below need an array, not a list
num_trials = len(trial_success)  # percentages are relative to the trials that have an outcome
num_correct = np.sum(trial_success == 1)
num_missed = np.sum(trial_success == 0)
num_incorrect = np.sum(trial_success == -1)
fig, ax = plt.subplots(figsize=(8, 6))
labels = ['Correct Trials', 'Missed Trials', 'Incorrect Trials']
sizes = [num_correct, num_missed, num_incorrect]
# Percentages written above each bar
percentages = [f"{size/num_trials*100:.1f}%" for size in sizes]
colors = ['green', 'orange', 'red']
bars = ax.bar(labels, sizes, color=colors)
# Add percentage labels on top of the bars
for bar, percentage in zip(bars, percentages):
    yval = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2, yval + 0.5, percentage, ha='center', va='bottom')
ax.set_ylabel('Number of Trials')
ax.set_title('Trial Outcomes')
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
# Save the figure
trial_outcome_fname = f"{session_name}_{eye_name}_Trial_Outcomes.png"   
fig.savefig(results_dir / trial_outcome_fname, dpi=300, bbox_inches='tight')


In [None]:
## Compare saccade peak speed and latency between correct and incorrect trials.
## Saccades are counted between each trial's go frame and its end-of-trial frame.
frame_rate_hz = 60.0  # converts frame differences to seconds (assumes 60 frames/s)
saccade_speeds_correct_all = []                    # last saccade of each correct trial
saccade_speeds_correct_latency_all = []
saccade_speeds_incorrect_before_correct_all = []   # first saccade of correct trials with >1 saccade
saccade_latency_incorrect_before_correct_all = []
saccade_speeds_incorrect_all = []                  # first saccade of each incorrect trial
saccade_speeds_incorrect_latency_all = []
saccade_speeds_correct_first_saccade_all = []      # correct trials with exactly one saccade
saccade_latency_correct_first_saccade_all = []
sacc_frames = saccades['saccade_frames_xy']
for i, f in enumerate(go_frame[:len(end_of_trial_frame)]):  # loop through each go frame
    if trial_success[i] not in (1, -1):  # skip missed trials
        continue
    # Saccades within this trial; named trial_sacc_idx so the session-wide `saccade_indices` is not overwritten
    trial_sacc_idx = np.where((sacc_frames >= f) & (sacc_frames <= end_of_trial_frame[i]))[0]
    if len(trial_sacc_idx) == 0:
        continue
    vel = saccades['eye_vel'][saccades['saccade_indices_xy'][trial_sacc_idx]]
    speeds = np.linalg.norm(vel, axis=1)                            # deg/frame
    directions = np.rad2deg(np.arctan2(vel[:, 1], vel[:, 0]))       # deg (computed, not yet plotted)
    latencies = (sacc_frames[trial_sacc_idx] - f) / frame_rate_hz  # s after the go frame
    if trial_success[i] == 1:  ## Correct trial: the last saccade is the correct one
        saccade_speeds_correct_all.append(speeds[-1])
        saccade_speeds_correct_latency_all.append(latencies[-1])
        if len(trial_sacc_idx) > 1:
            # An earlier saccade preceded the correct one: treat the first as incorrect
            saccade_speeds_incorrect_before_correct_all.append(speeds[0])
            saccade_latency_incorrect_before_correct_all.append(latencies[0])
        else:
            saccade_speeds_correct_first_saccade_all.append(speeds[0])
            saccade_latency_correct_first_saccade_all.append(latencies[0])
    else:  ## Incorrect trial: use the first saccade
        saccade_speeds_incorrect_all.append(speeds[0])
        saccade_speeds_incorrect_latency_all.append(latencies[0])
# saccade_speeds_correct_all = np.array(saccade_speeds_correct_all)
# saccade_speeds_correct_latency_all = np.array(saccade_speeds_correct_latency_all)
# saccade_speeds_incorrect_before_correct_all = np.array(saccade_speeds_incorrect_before_correct_all)
# saccade_speeds_incorrect_all = np.array(saccade_speeds_incorrect_all)
# saccade_speeds_incorrect_latency_all = np.array(saccade_speeds_incorrect_latency_all)
# saccade_speeds_correct_first_saccade_all = np.array(saccade_speeds_correct_first_saccade_all)
# saccade_latency_correct_first_saccade_all = np.array(saccade_latency_correct_first_saccade_all)
print(f"Number of correct trials with saccades: {len(saccade_speeds_correct_all)}, Number of incorrect trials with saccades: {len(saccade_speeds_incorrect_all)}")
print(f"Number of correct trials with both incorrect and correct saccades: {len(saccade_speeds_incorrect_before_correct_all)}")
print(f"Number of correct trials with only one correct saccade: {len(saccade_speeds_correct_first_saccade_all)}")
# Box plots of saccade speeds for correct, incorrect, incorrect-before-correct, and correct-first-saccade trials.
# Tick labels are set with set_xticks because boxplot's `labels` argument was renamed `tick_labels` in Matplotlib 3.9.
box_labels = ['Correct Trials', 'Incorrect Trials', 'Incorrect Before Correct', 'Correct First Saccade']
fig, ax = plt.subplots(figsize=(10, 6))
data = [saccade_speeds_correct_all, saccade_speeds_incorrect_all, saccade_speeds_incorrect_before_correct_all, saccade_speeds_correct_first_saccade_all]
ax.boxplot(data, patch_artist=True,
           boxprops=dict(facecolor='lightgreen', color='green'),
           medianprops=dict(color='red'))
ax.set_xticks(range(1, len(box_labels) + 1), box_labels)
ax.set_ylabel('Saccade Speed (deg/frame)')
ax.set_title('Comparison of Saccade Speeds by Trial Outcome')
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
# Box plots of saccade latencies for the same four groups
fig, ax = plt.subplots(figsize=(10, 6))
data = [saccade_speeds_correct_latency_all, saccade_speeds_incorrect_latency_all, saccade_latency_incorrect_before_correct_all, saccade_latency_correct_first_saccade_all]
ax.boxplot(data, patch_artist=True,
           boxprops=dict(facecolor='lightgreen', color='green'),
           medianprops=dict(color='red'))
ax.set_xticks(range(1, len(box_labels) + 1), box_labels)
ax.set_ylabel('Saccade Latency (s)')
ax.set_title('Comparison of Saccade Latencies by Trial Outcome')
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
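The cell above labels a correct trial "incorrect before correct" whenever it contains more than one saccade, so it uses saccade count as a proxy for direction. The first saccade's direction could be checked directly if the required direction for each trial were available. For example, it might be derived from `trial_stim_direction`, whose encoding is not shown here. In an anti-saccade trial the required direction is opposite the stimulus, i.e. stimulus angle + 180°. The sketch below assumes angles in degrees using the same convention as `np.arctan2(vy, vx)` above. The function names and the ±90° tolerance are hypothetical choices.

```python
import numpy as np

def angular_difference_deg(a, b):
    """Signed angular difference a - b in degrees, wrapped to [-180, 180)."""
    return (np.asarray(a, dtype=float) - np.asarray(b, dtype=float) + 180.0) % 360.0 - 180.0

def fraction_first_saccade_wrong(first_dir_deg, required_dir_deg, tol_deg=90.0):
    """Fraction of trials whose first saccade deviates from the required direction by more than tol_deg."""
    diff = np.abs(angular_difference_deg(first_dir_deg, required_dir_deg))
    return float(np.mean(diff > tol_deg))

# Usage: first-saccade directions vs. a required direction of 0 deg (rightward)
print(fraction_first_saccade_wrong([0, 180, 85, -170], [0, 0, 0, 0]))
```

Wrapping the difference matters near ±180°. For example, 350° and 10° are only 20° apart, and a plain subtraction would report 340°.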

