### Analyze + Plot Eye Saccades
The goal of this script is to determine wether the animal's saccades
are aligned with the location of the stimulus. In other words, if the
animal is looking at the salient part of the screen. The script:

  - Requires a path to a folder on the lab's server (X:\)
  - Loads right- and left-eye Bonsai CSV outputs, IMU & stimulus TTL files.
  - Computes eye-centric gaze, detects saccades (optionally removing blinks).
  - Creates three kinds of figures and writes them to  <folder_path>/Results/ :
      ─ “ALL”  : quiver of every saccade  + polar & linear angle histograms
      ─ LR/UD  : same layout but restricted to each stimulus direction
  - Global X/Y limits are derived from the full session; every quiver plot
    shares them so comparisons are direct.
  - Figure filenames:  <session>_<Eye>_<condition>.png
      e.g.  Tsh001_2025-06-11T13_02_29_Right_ALL_Interleaved.png


### How saccades are detected in this script

1. **Eye-centred coordinates**  
   Subtract the mid-point of the two eye-corner markers so gaze is relative to the eye, not the camera.

2. **Denoise (`medfilt_vec`)**  
   A 3-point median filter removes single-frame tracking jitter.

3. **Pixel → degree conversion**  
   Divide by the calibration factor ( ≈ 3.76 px = 1 °) to work in visual degrees.

4. **Instantaneous velocity (`frame_velocity`)**  
   Take the frame-to-frame difference in x and y; combine into a scalar speed (deg / frame).

5. **Speed threshold**  
   If speed ≥ `1.5 deg/frame`, mark that frame as a saccade (`saccade_indices`).

6. **Blink removal (`blink_frames`, optional)**  
   When VD-axis files are present, large jumps in eyelid separation flag blinks; those frames are deleted from `saccade_indices`.

The cleaned list **`saccade_indices`** feeds every quiver plot, PCA arrow, and angle histogram in the notebook.




Tested with Python 3.10 (Conda env “EyeHeadCoupling”).

  Author:  <Ratnadeep Pal @SarvestaniLab> , <Madineh Sedigh-Sarvestani @SarvestaniLab> 

## Import the required libraries


In [113]:

import sys
import os
import numpy as np
import pandas as pd
from scipy.fft import fft
from scipy.signal import ShortTimeFFT,butter,hilbert,sosfiltfilt,medfilt
from scipy.signal.windows import gaussian
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.colors
from matplotlib.ticker import ScalarFormatter
import tkinter as tk
from tkinter import filedialog
from sklearn.decomposition import PCA
import os
from io import StringIO
from matplotlib import gridspec
from pathlib import Path
import matplotlib.lines as mlines
import matplotlib.gridspec as gridspec
import re
from datetime import datetime


%matplotlib qt


## Define parameters

In [None]:
'''Some helper functions to select folders and files using a GUI.
These are needed to define some parameters'''
# Prompt the user to select a folder
def select_folder():
    root = tk.Tk()
    root.withdraw()  # Hide the main window
    directory = filedialog.askdirectory()  # Open the file selection dialog
    return directory

# Prompt the user to open a file 
def select_file():
    root = tk.Tk()
    root.withdraw()  # Hide the main window
    file_path = filedialog.askopenfilename()  # Open the file selection dialog
    return file_path



cal = 3.76  # Calibration factor for the pixels to degrees
ttl_freq = 60  # TTL frequency in Hz

# Parameters for nlink and saccade detection
blink_detection = 1
blink_thresh= 10
saccade_thresh= 3
saccade_win=0.7    # Window size for saccade detection in seconds

#folder_path = select_folder() #this won't work if you're running jupyter lab in browser, so hard coding


###################################### Paris

#First day when we started doing interleaved stim.
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-09T13_09_02\\" # interleaved

#Second day where she licked a lot!
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-11T12_50_45\\" #no stim
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-11T13_02_29\\" # interleaved
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-11T13_15_39\\" #  interleaved

#Not much licking on this day
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-17T12_50_28\\"#interleaved
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-13T13_25_04\\" #no-stim session

#Motivated on this day, but first day where juice was only given for saccades
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-16T12_13_51\\" #interleaved

#second day juice was given for saccades only but she was stressed 
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-17T12_50_28\\" #interleaved

#third day juce was given for saccades only, but she was not thirsty
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-18T12_19_39\\" #interleaved

#fourth day, but she didn't pay attention the whole time
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-19T12_17_52\\" #interleaved


folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-06-30T15_20_56\\" #just L/R


#this is after ratnadeep fixed a bunch of issues with the code! and data looks great!!
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-02T15_14_29\\" #just L/R


#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-07T14_49_12\\" #just L/R
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-07T15_12_22\\" #just L/R

#good session!
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-08T15_04_12\\" #just L/R

#good session!
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-09T15_29_43\\" #just L/R
folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-09T15_49_30\\" #L/R and U/D 



######################################## Bayleaf
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\Rat22_Bayleaf_server\Rat022_2025-06-10T16_23_02\\"   #no stim
#folder_path = r"X:\Experimental_Data\EyeHeadCoupling_RatTS_server\Rat22_Bayleaf_server\Rat022_2025-06-10T16_10_21\\"   #no stim


results_dir = Path(folder_path) / "Results\\"
results_dir.mkdir(exist_ok=True)

session_name = os.path.basename(folder_path.rstrip("/\\"))




## Define functions


In [115]:

def get_session_date_from_path(path):
    match = re.search(r"\d{4}-\d{2}-\d{2}", path)
    if match:
        return datetime.strptime(match.group(), "%Y-%m-%d")
    else:
        raise ValueError("No valid date (YYYY-MM-DD) found in path")
    
def determine_camera_side(path, cutoff_date_str="2025-06-30"):
    session_date = get_session_date_from_path(path)
    cutoff_date = datetime.strptime(cutoff_date_str, "%Y-%m-%d")
    return "L" if session_date >= cutoff_date else "R"


# Prompt the user to select a folder
def select_folder():
    root = tk.Tk()
    root.withdraw()  # Hide the main window
    directory = filedialog.askdirectory()  # Open the file selection dialog
    return directory


def choose_option(option1, option2, option3, option4):
    result = {}

    def select(choice):
        result['value'] = choice
        root.destroy()

    root = tk.Tk()
    root.title("Choose the type of visual stim")

    tk.Label(root, text="Please choose the type of visual stim:").pack(pady=10)
    tk.Button(root, text=option1, width=12, command=lambda: select(option1)).pack(side='left', padx=10, pady=10)
    tk.Button(root, text=option2, width=12, command=lambda: select(option2)).pack(side='left', padx=10, pady=10)
    tk.Button(root, text=option3, width=12, command=lambda: select(option3)).pack(side='left', padx=10, pady=10)
    tk.Button(root, text=option4, width=12, command=lambda: select(option4)).pack(side='left', padx=10, pady=10)

    # Manual event loop, blocks until window is destroyed
    while not result.get('value'):
        root.update()

    return result['value']

# Prompt the user to choose the type of visual stim
#stim_type = choose_option("None","LR","UD","Interleaved")


# Function to remove parentheses characters from a line
def remove_parentheses_chars(line):
    # Remove only '(' and ')' characters
    return line.replace('(', '').replace(')', '')
def clean_csv(filename):
    with open(filename, 'r') as f:
        lines = [remove_parentheses_chars(line) for line in f]
    # Join lines and create a file-like object
        cleaned = StringIO(''.join(lines))
        return cleaned

# Butterworth filter to remove high frequency noise
def butter_noncausal(signal, fs, cutoff_freq=1, order=4):
    sos = butter(order, cutoff_freq/(fs/2), btype='low', output='sos')  # 50 Hz cutoff frequency
    return sosfiltfilt(sos, signal)   

def interpolate_nans(arr):
    nans = np.isnan(arr)
    x = np.arange(len(arr))
    arr[nans] = np.interp(x[nans], x[~nans], arr[~nans])
    return arr

def rotation_matrix(angle_rad):
    return np.array([[np.cos(angle_rad), -np.sin(angle_rad)],
                     [np.sin(angle_rad), np.cos(angle_rad)]])


pca = PCA(n_components=2)

def vector_to_rgb(angle, absolute): ##Got it from https://stackoverflow.com/questions/19576495/color-matplotlib-quiver-field-according-to-magnitude-and-direction
    global max_abs

    # normalize angle
    angle = angle % (2 * np.pi)
    if angle < 0:
        angle += 2 * np.pi

    # return matplotlib.colors.hsv_to_rgb((angle / 2 / np.pi, 
    #                                      absolute / max_abs, 
    #                                      absolute / max_abs))
    return matplotlib.colors.hsv_to_rgb((angle / 2 / np.pi, 
                                         1, 
                                         1))
def plot_angle_distribution(angle, ax_polar, num_bins=18):
    """
    Plots a normalized polar histogram of angles.

    Parameters:
        angle (np.ndarray): array of saccade angles in radians
        ax_polar (matplotlib.axes._subplots.PolarAxesSubplot): the polar subplot to draw on
        num_bins (int): number of histogram bins
    """
    angle_2pi = np.where(angle < 0, angle + 2 * np.pi, angle)
    counts, bin_edges = np.histogram(angle_2pi, bins=num_bins, range=(0, 2 * np.pi))
    counts = counts / np.size(angle_2pi)  # Normalize
    width = np.diff(bin_edges)

    bars = ax_polar.bar(bin_edges[:-1], counts, width=width, align='edge', color='b', alpha=0.5, edgecolor='k')
    ax_polar.set_title("Normalized angle distribution")
    ax_polar.set_yticklabels([])

def plot_linear_histogram(angles, ax, num_bins=18):
    ang_deg = np.degrees(angles)
    ang_deg = np.mod(ang_deg, 360)
    counts, bins = np.histogram(ang_deg, bins=num_bins, range=(0, 360))
    counts = counts / ang_deg.size
    ax.bar(bins[:-1], counts, width=np.diff(bins), color="b", alpha=0.5, edgecolor="k")
    ax.set_xlabel("Angle (deg)")
    ax.set_ylabel("Normalised count")
    ax.set_title("Linear angle histogram")    

def detect_saccades(
    marker1_x, marker1_y, marker2_x, marker2_y,
    gaze_x, gaze_y,
    eye_frames,
    calibration_factor,
    blink_velocity_threshold, 
    saccade_threshold,
    blink_detection=0,
    vd_axis_lx=None, vd_axis_ly=None, vd_axis_rx=None, vd_axis_ry=None,
):
    ################# Analyze eye saccades function #################
    """
    #  Compute eye position in eye-centered coordinates.
    # Filter and convert raw gaze to degrees using calibration factor.
    # Compute speed of gaze movement to detect saccades.
    # Identify saccades that exceed a velocity threshold.
    # (Optional) Remove saccades likely caused by blinks.
    # Classify stimulus onsets (go_frame) by their direction.
    # Return relevant arrays for further analysis/plotting.
    """
    

    # 1. eye-centred coordinates  →  degrees
    eye_origin = np.column_stack((
        (marker1_x + marker2_x) / 2,
        (marker1_y + marker2_y) / 2
    ))
    eye_angle = np.arctan2(marker2_y - marker1_y, marker2_x - marker1_x)

    eye_camera = np.column_stack((gaze_x - eye_origin[:, 0], gaze_y - eye_origin[:, 1]))
    eye_camera[:, 0] = medfilt(eye_camera[:, 0], kernel_size=3)
    eye_camera[:, 1] = medfilt(eye_camera[:, 1], kernel_size=3)
    eye_camera = eye_camera / calibration_factor

    # 2. instantaneous velocity  →  speed
    eye_vel = np.zeros_like(eye_camera)
    eye_vel[:, 0] = np.ediff1d(eye_camera[:, 0], to_begin=0) ##TODO
    eye_vel[:, 1] = np.ediff1d(eye_camera[:, 1], to_begin=0)
    speed   = np.linalg.norm(eye_vel, axis=1)

    mask = speed >= saccade_threshold
    saccade_indices = np.where(mask)[0] # ← row indices (0…7158)
    saccade_frames = eye_frames[saccade_indices] # ← absolute Bonsai frames

    fig, ax = plt.subplots(figsize=(12, 4))

    frames = np.arange(len(speed))

    # 1) speed trace
    ax.plot(frames, speed, linewidth=0.8, label='Speed (°/frame)')

    # 2) highlight saccade frames
    ax.scatter(saccade_indices, speed[saccade_indices],
            color='tab:red', s=12, label='Saccade idx')

    # 3) threshold line
    ax.axhline(saccade_threshold, color='tab:orange',
            linestyle='--', label=f'Threshold = {saccade_threshold}')

    ax.set_xlabel('Frame number')
    ax.set_ylabel('Speed (° / frame)')
    ax.set_title('Instantaneous eye speed with detected saccade frames')
    ax.legend()
    ax.grid(alpha=.3)
    plt.tight_layout()
    plt.show()

    # save alongside other figures
    prob_fname = f"{session_name}_saccades.png"
    fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')



    # 3. optional blink removal
    if blink_detection:
        vd_axis_left = np.vstack([vd_axis_lx, vd_axis_ly]).T
        vd_axis_right = np.vstack([vd_axis_rx, vd_axis_ry]).T
        vd_axis_d = np.linalg.norm(vd_axis_right - vd_axis_left, axis=1)
        vd_axis_vel = np.gradient(vd_axis_d)
        blink_indices = np.where((np.abs(vd_axis_vel) > blink_velocity_threshold) | (vd_axis_d <= np.mean(vd_axis_d)-5*np.std(vd_axis_d)))[0] ## If the Ventral Dorsal axis distance falls 
         # below the mean minus 5 standard deviations and rate of change of the distance is more than threshold, it is considered a blink.
        saccade_indices = saccade_indices[~np.isin(saccade_indices, blink_indices)]

    return {
        "eye_camera":       eye_camera,
        "eye_vel":          eye_vel,
        "saccade_indices":  saccade_indices,
        "saccade_frames":   saccade_frames,
        }



def organize_stims(
    go_frame, 
    go_dir_x = None,
    go_dir_y = None):
    
    has_lr = go_dir_x is not None and np.any(go_dir_x != 0)
    has_ud = go_dir_y is not None and np.any(go_dir_y != 0)

    direction_sets = {}

    if has_lr:
        direction_sets["Left"]  = go_dir_x < 0
        direction_sets["Right"] = go_dir_x > 0
    if has_ud:
        direction_sets["Down"] = go_dir_y < 0
        direction_sets["Up"]   = go_dir_y > 0
    if not direction_sets:
        direction_sets["All"] = np.full(len(go_frame), True)

    stim_frames = {lab: go_frame[mask] for lab, mask in direction_sets.items()}

    # Return the inferred stim type too
    if has_lr and has_ud:
        stim_type = "Interleaved"
    elif has_lr:
        stim_type = "LR"
    elif has_ud:
        stim_type = "UD"
    else:
        stim_type = "None"

    return stim_frames, stim_type



def sort_plot_saccades(
        #sorts saccades by stimulus onset
    eye_camera, eye_camera_diff, saccade_indices, saccade_frames,
    stim_frames,                # dict: label → array of frames
    saccade_window,  # seconds before/after each stimulus
    session_path,
    stim_type='None',
    eye_name='Eye',
):

    session_name = os.path.basename(session_path.rstrip("/\\"))

    # ───────── global axis limits (all saccades) ─────────
    x_all = eye_camera[saccade_indices, 0]
    y_all = eye_camera[saccade_indices, 1]
    pad   = 0.10
    rngX  = x_all.max() - x_all.min()
    rngY  = y_all.max() - y_all.min()
    X_LIM = (x_all.min() - pad*rngX, x_all.max() + pad*rngX)
    Y_LIM = (y_all.min() - pad*rngY, y_all.max() + pad*rngY)

    max_abs = np.max(np.abs(eye_camera_diff))

    
    #calculate angles for all saccades
    angle_all = np.arctan2(eye_camera_diff[saccade_indices, 1],
                            eye_camera_diff[saccade_indices, 0])
    n_all = len(saccade_indices)

    # ───────── master figure (ALL saccades) ─────────
    fig = plt.figure(figsize=(11, 6))
    gs  = gridspec.GridSpec(2, 2, width_ratios=[3, 2])
    ax_quiver = fig.add_subplot(gs[:, 0])
    ax_polar  = fig.add_subplot(gs[0, 1], polar=True)
    ax_linear = fig.add_subplot(gs[1, 1])

    ax_quiver.set_xlim(*X_LIM); ax_quiver.set_ylim(*Y_LIM)
    ax_quiver.set_xlabel('X (°)'); ax_quiver.set_ylabel('Y (°)')
    ax_quiver.set_title(
        f"{session_name}\n"
        f"All saccades ({n_all}) — {eye_name}  (stim: {stim_type})\n"
        f"saccade_thresh = {saccade_thresh}, saccade_win = {saccade_win}s\n"
        f"blink_thresh = {blink_thresh}, blink_detection = {blink_detection}s\n"
        )

    cols = np.array([vector_to_rgb(a, max_abs) for a in angle_all])
    ax_quiver.quiver(x_all, y_all,
                        eye_camera_diff[saccade_indices, 0],
                        eye_camera_diff[saccade_indices, 1],
                        angles='xy', scale_units='xy', scale=1,
                        color=cols, alpha=.5)

    # PCA arrows (unchanged)
    pca.fit(eye_camera_diff[saccade_indices] /
            np.linalg.norm(eye_camera_diff[saccade_indices], axis=1, keepdims=True))
    for i, (vec, var) in enumerate(zip(pca.components_, pca.explained_variance_ratio_)):
        ax_quiver.arrow(np.mean(x_all), np.mean(y_all),
                        *(vec * 10 * np.sqrt(var)),
                        color=['k', 'b'][i], width=0.1,
                        label=f'PC{i+1} ({var:.2f} var)')
    ax_quiver.legend()

    plot_angle_distribution(angle_all, ax_polar)
    plot_linear_histogram(angle_all, ax_linear)
    plt.tight_layout()

    # save master figure
    all_fname = f"{session_name}_{eye_name}_ALL_{stim_type}.png"
    fig.savefig(results_dir / all_fname, dpi=300, bbox_inches='tight')


    # Determine the overall frame range [0, last_frame]
    last_frame = int(saccade_frames.max())
    clipped_any = False
    plot_window = np.arange(0,saccade_window,1)

    # ───────── one figure per stimulus label (skip "All") ─────────
    for label, frames in stim_frames.items():
        if label == "All":
            continue

        # gather 1st saccades within ±plot_window around each stim

        idx_buf = []  # buffer to collect saccade indices for this label

        # sort saccade frames to ensure they are in order
        sorted_pairs = sorted(zip(saccade_frames, saccade_indices))

        for f in frames:

            lower_bound = max(f + plot_window[0], 0)
            upper_bound = min(f + plot_window[-1], saccade_frames.max())

            for sf, idx in sorted_pairs:
                if sf < lower_bound:
                    continue
                elif sf <= upper_bound:
                    idx_buf.append(idx)   # first valid saccade
                    break                 # only take the first one
                else:
                    break                 # skip to next stim

        idx_use = np.array(idx_buf, dtype=int)
        if idx_use.size == 0:
            continue

        ang = np.arctan2(eye_camera_diff[idx_use, 1],
                            eye_camera_diff[idx_use, 0])
        n_cond = len(idx_use)

        fig = plt.figure(figsize=(6, 3))
        gs  = gridspec.GridSpec(2, 2, width_ratios=[3, 2])
        ax_q = fig.add_subplot(gs[:, 0])
        ax_p = fig.add_subplot(gs[0, 1], polar=True)
        ax_l = fig.add_subplot(gs[1, 1])

        ax_q.set_xlim(*X_LIM); ax_q.set_ylim(*Y_LIM)
        ax_q.set_xlabel('X (°)'); ax_q.set_ylabel('Y (°)')
        ax_q.set_title(f"{session_name}\n{eye_name} — {label} (n={n_cond})")

        cols = np.array([vector_to_rgb(a, max_abs) for a in ang])
        ax_q.quiver(eye_camera[idx_use, 0], eye_camera[idx_use, 1],
                    eye_camera_diff[idx_use, 0], eye_camera_diff[idx_use, 1],
                    angles='xy', scale_units='xy', scale=1,
                    color=cols, alpha=.5)

        plot_angle_distribution(ang, ax_p)
        plot_linear_histogram(ang, ax_l)

        fig.tight_layout()
        fname = f"{session_name}_{eye_name}_{label.replace('/','-')}.png"
        fig.savefig(results_dir / fname, dpi=300, bbox_inches='tight')




## Extract relevant files from filepath

In [116]:
# determine which camera is used based on the folder name and cut off date
camera_side = determine_camera_side(folder_path)
eye_name = f"{camera_side} Eye"  # e.g. "Left camera" or "Right camera"
print(f"Using camera side: {camera_side}")



print(f"Scanning folder: {folder_path}")
print(f"Found {len(os.listdir(folder_path))} files")

# Scan the folder for specific files
for f in os.listdir(folder_path):
    f_lower = f.lower()
    full_path = os.path.join(folder_path, f)

    if 'imu' in f_lower:
        IMU_file = full_path
    if 'camera' in f_lower:
        camera_file = full_path
    if 'go' in f_lower:
        go_file = full_path
    if f"ellipse_center_XY_{camera_side}".lower() in f_lower:
        ellipse_center_XY_file = full_path
    if f"origin_of_eyecoordinate_{camera_side}".lower() in f_lower:
        origin_of_eye_coordinate_file = full_path
    if f"vdaxis_{camera_side}".lower() in f_lower:
        vdaxis_file = full_path



Using camera side: L
Scanning folder: X:\Experimental_Data\EyeHeadCoupling_RatTS_server\TSh01_Paris_server\Tsh001_2025-07-09T15_29_43\\
Found 25 files


## Extract stim and eye position data 

In [117]:
## Read the camera data and map between camera TTL (for saccades) and Bonsai TTLs (for frames)
camera_data = np.genfromtxt(camera_file, delimiter=',', skip_header=1, dtype=np.float64)
[bonsai_frame, bonsai_time] = camera_data[:, 0], camera_data[:, 1]
bonsai_frame = bonsai_frame.astype(int)  # Convert bonsai_frame to integer type


### Read the go file for the start of stim in the trial 
new_go_data_format=0
go_data = np.genfromtxt(clean_csv(go_file), delimiter=',', skip_header=1, dtype=np.float64)
if go_data.shape[1]>3:
    new_go_data_format = 1
    [go_frame, go_time, go_direction_x,go_direction_y] = go_data[:, 0], go_data[:, 1], go_data[:, 2], go_data[:,3]
else:
    [go_frame, go_time, go_direction] = go_data[:, 0], go_data[:, 1], go_data[:, 2]
go_frame = go_frame.astype(int)  # Convert go_frame to integer type

### Read the ellipse center XY file  
ellipse_center_XY_data = np.genfromtxt(clean_csv(ellipse_center_XY_file), delimiter=',', skip_header=1, dtype=np.float64)
[eye_frame,eye_timestamp,eye_x,eye_y] = ellipse_center_XY_data[:, 0], ellipse_center_XY_data[:, 1], ellipse_center_XY_data[:, 2], ellipse_center_XY_data[:, 3]
eye_frame = eye_frame.astype(int)  # Convert eye_frame to integer type
eye_x = interpolate_nans(eye_x)  # Interpolate NaN values in eye_x
eye_y = interpolate_nans(eye_y)  # Interpolate NaN values in eye_y

### Read the origin of eye coordinate file
origin_of_eye_coordinate_data = np.genfromtxt(clean_csv(origin_of_eye_coordinate_file), delimiter=',', skip_header=1, dtype=np.float64)
[origin_frame,o_ts,l_x,l_y,r_x,r_y] = origin_of_eye_coordinate_data[:, 0], origin_of_eye_coordinate_data[:, 1], origin_of_eye_coordinate_data[:, 2], origin_of_eye_coordinate_data[:, 3], origin_of_eye_coordinate_data[:, 4], origin_of_eye_coordinate_data[:, 5]
origin_frame = origin_frame.astype(int)  # Convert origin_frame_r to integer type
l_x = interpolate_nans(l_x)  # Interpolate NaN values in l_rx
r_x = interpolate_nans(r_x)  # Interpolate NaN values in r_rx 
l_y = interpolate_nans(l_y)  # Interpolate NaN values in l_ry
r_y = interpolate_nans(r_y)  # Interpolate NaN values in r_ry

## Read the vertical (VD) axis data - this is used for blink detection
vdaxis_data = np.genfromtxt(clean_csv(vdaxis_file),delimiter=',',skip_header=1,dtype=np.float64)
[vd_frame,vd_ts,vd_lx,vd_ly,vd_rx,vd_ry] = vdaxis_data[:,0],vdaxis_data[:,1],vdaxis_data[:,2],vdaxis_data[:,3],vdaxis_data[:,4],vdaxis_data[:,5]
vd_frame = vd_frame.astype(int)
# Interpolate NaN values
vd_lx = interpolate_nans(vd_lx)
vd_ly = interpolate_nans(vd_ly)
vd_rx = interpolate_nans(vd_rx)
vd_ry = interpolate_nans(vd_ry)

        
### Read the IMU data for the accelerometer and gyroscope
imu_data = np.genfromtxt(IMU_file, delimiter=',', skip_header=1, dtype=np.float64)
[imu_time,a_x,a_y,a_z,g_x,g_y,g_z,m_x,m_y,m_z] = imu_data[:, 0], imu_data[:, 1], imu_data[:, 2], imu_data[:, 3], imu_data[:, 4], imu_data[:, 5], imu_data[:, 6], imu_data[:, 7], imu_data[:, 8], imu_data[:, 9]
imu_time = imu_time.astype(np.float64)  # Ensure imu_time is in float64 format
# Interpolate NaN values in IMU data
a_x = interpolate_nans(a_x)
a_y = interpolate_nans(a_y)
a_z = interpolate_nans(a_z)
g_x = interpolate_nans(g_x)
g_y = interpolate_nans(g_y)
g_z = interpolate_nans(g_z)
m_x = interpolate_nans(m_x)
m_y = interpolate_nans(m_y)
m_z = interpolate_nans(m_z)


## Sanity check: How far apart are the stimuli?  


In [118]:

d_frames = np.diff(go_frame)    # successive differences (frames)
d_sec = d_frames / ttl_freq

fig =plt.figure(figsize=(8,3))
plt.plot(d_sec, marker='o')
plt.xlabel('Stimulus index')
plt.ylabel('Δtime (s) to next stim')
plt.title('Seconds between successive Go Stims')
plt.grid(alpha=.3)
plt.tight_layout()
plt.show()

# optional: save alongside other figures
prob_fname = f"{session_name}_{eye_name}_Stim_Interval.png"
fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')

## Analyze and plot saccade statistics

In [119]:

# --------------------------------------------------------------
# ONE analyse-&-plot call for the Eye saccades
saccades = detect_saccades(
    l_x, l_y, r_x, r_y,
    eye_x, eye_y,
    eye_frame,
    cal,
    blink_detection = blink_detection,
    vd_axis_lx = vd_lx, vd_axis_ly = vd_ly,
    vd_axis_rx =vd_rx, vd_axis_ry = vd_ry,
    saccade_threshold       = saccade_thresh,
    blink_velocity_threshold= blink_thresh
)

saccades["stim_frames"], stim_type = organize_stims(
    go_frame,
    go_dir_x = go_direction_x,
    go_dir_y = go_direction_y,
)
sort_plot_saccades(
    saccades["eye_camera"],
    saccades["eye_vel"],
    saccades["saccade_indices"],
    saccades["saccade_frames"],
    saccades["stim_frames"],
    saccade_window= saccade_win*ttl_freq,
    session_path = folder_path,
    stim_type    = stim_type,
    eye_name     = eye_name,
)


## Look at probability of saccades as a function of stimulus onset

In [120]:
go_frame        = go_frame = np.array(go_frame).flatten()
saccade_frames  = saccades["saccade_frames"]
ttl_freq        = 60
t_window_s      = 2
bin_ms          = 50

total_duration_frames = None

# ─── Setup ─────────────────────────────────────────────
frame_win   = int(t_window_s * ttl_freq)
bin_frames  = int((bin_ms / 1000) * ttl_freq)
bins        = np.arange(-frame_win, frame_win + 1, bin_frames)
t_sec       = (bins[:-1] + bin_frames / 2) / ttl_freq

# ─── Estimate baseline saccade probability ─────────────
if total_duration_frames is None:
    total_duration_frames = saccade_frames.max()

total_bins = total_duration_frames / bin_frames
baseline_rate = len(saccade_frames) / total_bins  # saccades per bin

# ─── Helper: peri-stimulus histogram ──────────────────
def peristim_change(stim_frames):
    rel_times = []
    for f0 in stim_frames:
        nearby = saccade_frames[
            (saccade_frames >= f0 - frame_win) &
            (saccade_frames <= f0 + frame_win)
        ]
        rel_times.extend(nearby - f0)
    rel_times = np.array(rel_times)

    counts, _ = np.histogram(rel_times, bins=bins)
    prob_bin = counts / len(stim_frames)  # saccades per bin per stim
    change = 100 * (prob_bin - baseline_rate) / baseline_rate
    return change

# ─── Group by direction ────────────────────────────────
go_dir_x = np.nan_to_num(go_direction_x)
go_dir_y = np.nan_to_num(go_direction_y)

stim_all = go_frame
stim_L   = go_frame[go_dir_x < 0]
stim_R   = go_frame[go_dir_x > 0]
stim_D   = go_frame[go_dir_y < 0]
stim_U   = go_frame[go_dir_y > 0]

# ─── Compute changes ───────────────────────────────────
ch_all = peristim_change(stim_all)
ch_L   = peristim_change(stim_L)
ch_R   = peristim_change(stim_R)
ch_D   = peristim_change(stim_D)
ch_U   = peristim_change(stim_U)

# ─── Plot ──────────────────────────────────────────────
# Plotting
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 4), sharex=True)

# Top: all stimuli
ax1.plot(t_sec, ch_all, label="All stimuli", color="black", lw=2)
ax1.axvline(0, color='k', ls='--', alpha=.5)
ax1.axhline(0, color='gray', ls=':', lw=1)
ax1.set_title('Stim-locked saccade probability (relative to baseline saccade rate)')
ax1.grid(alpha=.3)
ax1.set_ylim(-100,200)
ax1.legend()

# Bottom: each direction
if ch_L is not None:
    ax2.plot(t_sec, ch_L, label=f"Left (n={len(stim_L)})", color='green')
if ch_R is not None:
    ax2.plot(t_sec, ch_R, label=f"Right (n={len(stim_R)})", color='pink')
if ch_D is not None:
    ax2.plot(t_sec, ch_D, label=f"Down (n={len(stim_D)})", color='blue')
if ch_U is not None:
    ax2.plot(t_sec, ch_U, label=f"Up (n={len(stim_U)})", color='red')

ax2.axvline(0, color='k', ls='--', alpha=.5)
ax2.axhline(0, color='gray', ls=':', lw=1)
ax2.set_xlabel('Time from stimulus onset (s)')
ax2.set_title('Modulation by direction')
ax2.set_ylim(-100,200)
ax2.grid(alpha=.3)
#ax2.legend()

# Shared y-axis label
fig.text(0.01, 0.5, '% Change ',
         va='center', rotation='vertical', fontsize=11)

fig.tight_layout()


# optional: save alongside other figures
prob_fname = f"{session_name}_{eye_name}_StimLockedProb.png"
fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')

  prob_bin = counts / len(stim_frames)  # saccades per bin per stim


## Sanity check: Plot an overlay of Go-stims and Saccades 

In [121]:
# ============================================================
# Timeline plot: saccades (grey) + colour-coded stimuli
# ============================================================
# -----------------------------------------------------------------
# Assumes these objects already exist in the workspace
# -----------------------------------------------------------------
# right["saccade_frames"]   – Bonsai frame IDs of all saccades
# go_frame                  – Bonsai frame IDs of stimuli
# go_dir_x, go_dir_y        – direction codes (can be None)
# ttl_freq                  – camera TTL rate (Hz)
# results_dir               – Path(folder_path) / "Results"
# -----------------------------------------------------------------

# 1) prep data ----------------------------------------------------
saccade_frames = np.asarray(saccades["saccade_frames"], dtype=int)

# fallback arrays if dir arrays are missing
gx = go_dir_x if (go_dir_x is not None) else np.zeros_like(go_frame)
gy = go_dir_y if (go_dir_y is not None) else np.zeros_like(go_frame)


# ── 1-bis. count how many of each stimulus -----------------------------
eps = 1e-6                      # tolerance for “zero”
is_left   = (np.abs(gy) < eps) & (gx < -eps)
is_right  = (np.abs(gy) < eps) & (gx >  eps)
is_down   = (np.abs(gx) < eps) & (gy < -eps)
is_up     = (np.abs(gx) < eps) & (gy >  eps)

n_left, n_right = is_left.sum(),  is_right.sum()
n_down, n_up    = is_down.sum(),  is_up.sum()


# palette mapping
palette = {'L': 'green', 'R': 'pink', 'D': 'blue', 'U': 'red', 'NA': 'gray'}

# build colour list per stimulus
colors = []
for x, y in zip(gx, gy):
    if abs(y) > 1e-6:                       # Up / Down has priority
        colors.append(palette['U' if y > 0 else 'D'])
    elif abs(x) > 1e-6:                     # Left / Right
        colors.append(palette['R' if x > 0 else 'L'])
    else:
        colors.append(palette['NA'])

# sort frames & colours together
order      = np.argsort(go_frame)
t_stim     = go_frame[order] / ttl_freq
colors     = [colors[i] for i in order]

# convert saccade frames to seconds
t_sacc = np.sort(saccade_frames) / ttl_freq

# 2) plot ---------------------------------------------------------
fig, ax = plt.subplots(figsize=(12, 2.5))

# saccades: grey vertical ticks at y = 0
ax.vlines(t_sacc, -0.1, 0.1, colors='0.25', linewidth=1)

# stimuli: colour ticks at y = 1
ax.vlines(t_stim, 0.9, 1.1, colors=colors, linewidth=2)

# axes formatting
ax.set_yticks([0, 1])
ax.set_yticklabels(['Saccade', 'Stim'])
ax.set_xlabel('Time (s)')
ax.set_title('Timeline of saccades and stimuli')
ax.set_xlim(t_sacc.min() - 1, t_sacc.max() + 1)
ax.set_ylim(-0.5, 1.5)
ax.grid(axis='x', alpha=.3)

# legend
handles = [
    mlines.Line2D([], [], color='0.25', marker='|', ls='', markersize=10,
                  label='Saccade'),
    mlines.Line2D([], [], color='green', marker='|', ls='', markersize=10,
                  label=f'Stim Left  (n={n_left})'),
    mlines.Line2D([], [], color='pink',  marker='|', ls='', markersize=10,
                  label=f'Stim Right (n={n_right})'),
    mlines.Line2D([], [], color='blue',  marker='|', ls='', markersize=10,
                  label=f'Stim Down  (n={n_down})'),
    mlines.Line2D([], [], color='red',   marker='|', ls='', markersize=10,
                  label=f'Stim Up    (n={n_up})')
]
ax.legend(handles=handles, loc='upper right', ncol=5, fontsize=9, framealpha=.9)

plt.tight_layout()


# optional: save alongside other figures
prob_fname = f"{session_name}_{eye_name}_timeline_saccade_vs_stim.png"
fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')


## Plot how many stims produced at least one saccade

In [122]:
# ============================================================
#  Probability of ≥1 saccade within 0.5 s of each stimulus


# ----- parameters -------------------------------------------
win     = saccade_win                  # seconds after onset
w_frames = int(win * ttl_freq)  # convert to frames

# direction masks
gx = go_direction_x if go_direction_x is not None else np.zeros_like(go_frame)
gy = go_direction_y if go_direction_y is not None else np.zeros_like(go_frame)

dir_info = {
    'Left' :  (gx < -1e-6,  'green'),
    'Right':  (gx >  1e-6,  'pink'),
    'Down' :  (gy < -1e-6,  'blue'),
    'Up'   :  (gy >  1e-6,  'red')
}

labels, probs, colors = [], [], []

for label, (mask, col) in dir_info.items():
    stim_frames = go_frame[mask]
    n_stim      = len(stim_frames)
    if n_stim == 0:
        continue                                 # skip if this direction absent
    # check each stimulus: does ANY saccade happen within +win seconds?
    has_sacc = [( (saccade_frames >= f) & (saccade_frames <= f + w_frames) ).any()
                for f in stim_frames]
    prob = np.mean(has_sacc)                    # fraction of stimuli with ≥1 sac
    labels.append(label)
    probs.append(prob)
    colors.append(col)
    print(f"{label:5s}: {prob*100:5.1f}%  ({sum(has_sacc)}/{n_stim} stimuli)")

# ----- bar chart --------------------------------------------
fig, ax = plt.subplots(figsize=(6,4))
ax.bar(labels, probs, color=colors, edgecolor='k')
ax.set_ylim(0, 1)
ax.set_ylabel(f"P(saccade within {win}s)")
ax.set_title(f"Probability of a saccade in first {win} s after stimulus")
ax.grid(axis='y', alpha=.3)
plt.tight_layout()

# ----- save (optional) --------------------------------------


# optional: save alongside other figures
prob_fname = f"{session_name}_{eye_name}_saccade_prob_within_{win*1000:.0f}ms.png"
fig.savefig(results_dir / prob_fname, dpi=300, bbox_inches='tight')



Left :  43.9%  (58/132 stimuli)
Right:  43.4%  (53/122 stimuli)


## Plot a dynamic figure that shows each stim and the corresponding eye movements until the next stim

In [123]:
# import matplotlib.pyplot as plt
# import numpy as np
# from matplotlib.animation import FFMpegWriter
# import matplotlib as mpl

# # Point to ffmpeg if needed
# mpl.rcParams['animation.ffmpeg_path'] = r'X:\Software and Code\ffmpeg\bin\ffmpeg.exe'

# # Data
# eye_camera = right["eye_camera"] - right["eye_camera"].mean(axis=0)
# eye_camera = np.array(eye_camera)
# go_time = np.array(go_frame) / 60  # seconds
# eye_time = np.arange(len(eye_camera)) / 60
# gx = go_dir_x
# gy = go_dir_y

# # Display and video params
# SCREEN_XLIM = (-25, 25)
# SCREEN_YLIM = (-18, 18)
# scale_factor = 3
# stim_offset = 20
# fps = 10
# filename = "eye_trace_by_stim.mp4"
# video_path = results_dir / filename

# # Writer setup
# metadata = dict(title='Scaled eye trace, 0-0.5 seconds after stim onset', artist='Matplotlib')
# writer = FFMpegWriter(fps=fps, metadata=metadata)

# fig, ax = plt.subplots(figsize=(8, 6))

# with writer.saving(fig, str(video_path), dpi=200):  # make sure to convert to string
#     for i in range(len(go_time) - 1):   # iterate over each stimulus
#         start_t = go_time[i]
#         end_t = go_time[i] + 0.5
#         mask = (eye_time >= start_t) & (eye_time < end_t)
#         trace = eye_camera[mask] * scale_factor

#         stim_x = gx[i] * stim_offset
#         stim_y = gy[i] * stim_offset

#         ax.clear()
#         ax.set_title(f"Scaled eye trace, 0-0.5 seconds after stim onset. Stimulus {i+1}  (Time {start_t:.2f}s)")
#         ax.set_xlim(SCREEN_XLIM)
#         ax.set_ylim(SCREEN_YLIM)
#         ax.set_xlabel("Horizontal (deg)")
#         ax.set_ylabel("Vertical (deg)")
#         ax.grid(True)
#         ax.plot(stim_x, stim_y, 'o', color='blue', markersize=10, label='Stimulus')
#         ax.legend(loc='upper right')
#         writer.grab_frame()  # capture initial frame

#         # animate trace dynamically
#         for j in range(1, len(trace)):
#             ax.plot(trace[j-1:j+1, 0], trace[j-1:j+1, 1], 'k-')
#             if j == 1:
#                 ax.plot(trace[0, 0], trace[0, 1], 'go', markersize=6, label='Start')
#             elif j == len(trace) - 1:
#                 ax.plot(trace[j, 0], trace[j, 1], 'ro', markersize=6, label='End')
#             writer.grab_frame()

# plt.close()
# print(f"Saved video → {video_path}")
