# Prosaccade Feedback Session Analysis

This notebook analyzes data from a saccade task where:
- The animal's eye movements map to a green dot on the monitor
- A blue target dot appears at some location
- The animal is rewarded when the green dot (eye position) touches the blue dot (target)
- After a delay, a new trial starts

The notebook produces:
1. Trajectory plots showing eye position paths relative to target position
2. Time-to-target analysis showing trial durations
3. Eye-position density heatmaps, final-position plots grouped by target, and a session summary of duration and path statistics

## 1. Setup and Imports

In [None]:
from __future__ import annotations

import sys
from pathlib import Path
from typing import Tuple, Optional
from collections import defaultdict
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.animation as animation

# Enable inline plotting
%matplotlib inline

# For interactive plots (uncomment if using interactive viewers)
# %matplotlib widget

# Make the notebook's parent directory importable so the local `eyehead` package
# resolves. This relies on the working directory (Path.cwd()) being the notebook's
# own folder — NOTE(review): confirm when launching Jupyter from elsewhere.
notebook_dir = Path.cwd()
python_dir = notebook_dir.parent
# Prepend so the local package wins over any installed copy; the membership check
# keeps re-runs of this cell from stacking duplicate sys.path entries.
if str(python_dir) not in sys.path:
    sys.path.insert(0, str(python_dir))

# Import utilities (clean_csv pre-processes each CSV before numpy parses it in
# load_feedback_data)
from eyehead.io import clean_csv

print(f"Notebook directory: {notebook_dir}")
print(f"Python directory: {python_dir}")
print("Setup complete!")

## 2. Configuration

Set your data folder path and parameters here:

In [None]:
# =============================================================================
# CONFIGURATION - MODIFY THESE VALUES
# =============================================================================

# Path to your data folder containing the CSV files
FOLDER_PATH = Path("/path/to/your/data/folder")

# Animal ID
ANIMAL_ID = "Tsh001"

# Results directory (set to None to use FOLDER_PATH/results)
RESULTS_DIR = None

# Trial duration filters for analysis
TRIAL_MIN_DURATION = 0.01  # seconds
TRIAL_MAX_DURATION = 15.0  # seconds

# Whether to include failed trials in analysis
INCLUDE_FAILED_TRIALS = False

# =============================================================================

# Derived settings — nothing below needs editing.
# By default, figures and tables are written to a "results" folder inside the data folder.
RESULTS_DIR = FOLDER_PATH / "results" if RESULTS_DIR is None else Path(RESULTS_DIR)

# The session date is the first YYYY-MM-DD pattern found anywhere in the data path
# (empty string when the path contains no date).
date_match = re.search(r'\d{4}-\d{2}-\d{2}', str(FOLDER_PATH))
DATE_STR = "" if date_match is None else date_match.group()

for _label, _value in (
    ("Data folder", FOLDER_PATH),
    ("Results directory", RESULTS_DIR),
    ("Animal ID", ANIMAL_ID),
    ("Session date", DATE_STR),
):
    print(f"{_label}: {_value}")

## 3. Data Loading Functions

In [None]:
def _read_numeric_csv(path: Path) -> np.ndarray:
    """Return the contents of *path* as a 2-D float array.

    The file is passed through the project helper ``clean_csv`` and the header row is
    skipped. A file with a single data row parses as 1-D; it is promoted to shape
    ``(1, n_cols)`` so callers can always index columns with ``arr[:, j]``.
    """
    cleaned = clean_csv(str(path))
    arr = np.genfromtxt(cleaned, delimiter=",", skip_header=1, dtype=float)
    if arr.ndim == 1:
        arr = arr.reshape(1, -1)
    return arr


def _find_feedback_files(folder_path: Path) -> Tuple[Path, Path, Path]:
    """Locate the endoftrial, vstim_go and vstim_cue CSVs in *folder_path*.

    Names are matched case-insensitively. Each file is assigned to the first pattern its
    name contains, in the order endoftrial -> vstim_go -> vstim_cue. Files are scanned in
    sorted order so the selection is reproducible. If several files match one pattern, a
    warning is printed and the last one in sorted order is used.

    Raises
    ------
    FileNotFoundError
        If any of the three files is missing.
    """
    patterns = ("endoftrial", "vstim_go", "vstim_cue")
    found: dict = dict.fromkeys(patterns)
    for csv_path in sorted(folder_path.glob("*.csv")):
        fname = csv_path.name.lower()
        match = next((p for p in patterns if p in fname), None)
        if match is None:
            continue
        if found[match] is not None:
            print(f"  Warning: multiple {match} files found; using {csv_path.name}")
        found[match] = csv_path

    for pattern in patterns:
        if found[pattern] is None:
            raise FileNotFoundError(f"Could not find {pattern} file in {folder_path}")
    return found["endoftrial"], found["vstim_go"], found["vstim_cue"]


def load_feedback_data(folder_path: Path, animal_id: str = "Tsh001") -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load the three CSV files for saccade feedback analysis.

    Parameters
    ----------
    folder_path : Path
        Folder containing the endoftrial, vstim_go and vstim_cue CSV files.
    animal_id : str
        Animal identifier (default: "Tsh001"). Currently NOT used for file selection:
        files are matched on the substrings above only. Kept for API compatibility.

    Returns
    -------
    tuple of (end_of_trial_df, eye_position_df, target_position_df)
        end_of_trial_df : 'frame', 'timestamp', 'trial_success'. Six- and seven-column
            files add 'green_x', 'green_y', 'diameter'; seven-column files also add
            'trial_number'.
        eye_position_df : 'frame', 'timestamp', 'green_x', 'green_y', 'diameter'.
        target_position_df : 'frame', 'timestamp', 'target_x', 'target_y', 'diameter',
            'visible' ('visible' is 1 for five-column files). Duplicate frames are dropped.

    Raises
    ------
    FileNotFoundError
        If one of the three files is missing.
    ValueError
        If a file has an unsupported number of columns.
    """
    endoftrial_file, vstim_go_file, vstim_cue_file = _find_feedback_files(folder_path)

    # ---- End-of-trial events ------------------------------------------------
    print(f"\nLoading {endoftrial_file.name}...")
    eot_arr = _read_numeric_csv(endoftrial_file)
    n_cols = eot_arr.shape[1]
    print(f"  Detected {n_cols} columns in endoftrial file")

    if n_cols < 3:
        raise ValueError(f"Unexpected number of columns: {n_cols}. Expected at least 3.")
    if n_cols == 7:
        eot_df = pd.DataFrame(eot_arr, columns=['frame', 'timestamp', 'trial_success', 'trial_number', 'green_x', 'green_y', 'diameter'])
    elif n_cols == 6:
        eot_df = pd.DataFrame(eot_arr, columns=['frame', 'timestamp', 'trial_success', 'green_x', 'green_y', 'diameter'])
        # NOTE(review): the parsed diameter column is overwritten with a constant 0.2
        # for the 6-column format — confirm this is intentional.
        eot_df['diameter'] = 0.2
    else:
        # Unrecognised layout: keep only the three leading columns.
        eot_df = pd.DataFrame({
            'frame': eot_arr[:, 0],
            'timestamp': eot_arr[:, 1],
            'trial_success': eot_arr[:, 2],
        })

    eot_df['frame'] = eot_df['frame'].astype(int)
    eot_df['trial_success'] = eot_df['trial_success'].astype(int)
    if 'trial_number' in eot_df.columns:
        eot_df['trial_number'] = eot_df['trial_number'].astype(int)

    print(f"  Loaded {len(eot_df)} end-of-trial events")
    n_success = (eot_df['trial_success'] == 1).sum()
    n_failed = (eot_df['trial_success'] == 0).sum()
    print(f"  Trial success indicators: {n_success} successful, {n_failed} failed")

    # ---- Eye (green dot) positions ------------------------------------------
    print(f"\nLoading {vstim_go_file.name}...")
    eye_arr = _read_numeric_csv(vstim_go_file)
    n_cols = eye_arr.shape[1]
    print(f"  Detected {n_cols} columns in vstim_go file")

    if n_cols < 4:
        raise ValueError(f"Too few columns: {n_cols}. Expected at least 4")

    # Position and diameter are read from the right-hand end of each row:
    # (..., diameter, green_x, green_y, <last column not read>).
    # NOTE(review): with only 4 or 5 columns these negative indices overlap the
    # frame/timestamp columns — confirm the minimum supported layout.
    eye_df = pd.DataFrame({
        'frame': eye_arr[:, 0],
        'timestamp': eye_arr[:, 1],
        'green_x': eye_arr[:, -3],
        'green_y': eye_arr[:, -2],
        'diameter': eye_arr[:, -4],
    })
    eye_df['frame'] = eye_df['frame'].astype(int)
    print(f"  Loaded {len(eye_df)} eye position samples")

    # ---- Target (blue dot) positions ----------------------------------------
    print(f"\nLoading {vstim_cue_file.name}...")
    target_arr = _read_numeric_csv(vstim_cue_file)
    n_cols = target_arr.shape[1]
    print(f"  Detected {n_cols} columns in vstim_cue file")

    if n_cols == 6:
        target_df = pd.DataFrame(target_arr, columns=['frame', 'timestamp', 'target_x', 'target_y', 'diameter', 'visible'])
        print("  Target visibility column detected")
    elif n_cols == 5:
        target_df = pd.DataFrame(target_arr, columns=['frame', 'timestamp', 'target_x', 'target_y', 'diameter'])
        # Older files have no visibility column: treat every target as visible.
        target_df['visible'] = 1
    else:
        raise ValueError(f"Unexpected number of columns: {n_cols}. Expected 5 or 6.")

    target_df['frame'] = target_df['frame'].astype(int)
    target_df['visible'] = target_df['visible'].astype(int)
    target_df['diameter'] = target_df['diameter'].astype(float)

    # Keep only the first target row logged for any given frame.
    duplicates = target_df.duplicated(subset=['frame'], keep='first')
    n_duplicates = duplicates.sum()
    if n_duplicates > 0:
        print(f"  Warning: Removed {n_duplicates} duplicate entries")
        target_df = target_df[~duplicates].reset_index(drop=True)

    print(f"  Loaded {len(target_df)} target position samples")

    print("\nData loaded successfully!")
    print(f"  Frame range: {eye_df['frame'].min()} to {eye_df['frame'].max()}")
    print(f"  Timestamp range: {eot_df['timestamp'].min():.2f} to {eot_df['timestamp'].max():.2f}")

    return eot_df, eye_df, target_df

In [None]:
def identify_and_filter_failed_trials(target_df: pd.DataFrame, eot_df: pd.DataFrame,
                                      exclude_failed: bool = True) -> Tuple[pd.DataFrame, list, list]:
    """Split trials into successful and failed using the end-of-trial ``trial_success`` flags.

    When the flag column exists, both tables are sorted by timestamp and paired row by
    row: target onset ``i`` goes with end-of-trial event ``i``. Target rows beyond the
    number of end-of-trial events count toward the printed total but appear in neither
    index list. Without the flag column, every target row is treated as successful.

    Returns
    -------
    (target_df, failed_indices, successful_indices)
        The indices are positions in the (sorted) target table. ``target_df`` is reduced
        to the successful rows only when ``exclude_failed`` is True and at least one trial
        succeeded; otherwise it is returned whole.
    """
    if not len(target_df) or not len(eot_df):
        print("\nWarning: Empty target or end-of-trial data")
        return target_df, [], []

    if 'trial_success' in eot_df.columns:
        target_df = target_df.sort_values('timestamp').reset_index(drop=True)
        eot_df = eot_df.sort_values('timestamp').reset_index(drop=True)
        n_paired = min(len(target_df), len(eot_df))
        flags = eot_df['trial_success'].values[:n_paired]
        successful_indices = [pos for pos in range(n_paired) if flags[pos] == 1]
        failed_indices = [pos for pos in range(n_paired) if not flags[pos] == 1]
    else:
        print("\nWarning: trial_success column not found, assuming all trials successful")
        successful_indices = list(range(len(target_df)))
        failed_indices = []

    n_total = len(target_df)
    n_success, n_failed = len(successful_indices), len(failed_indices)
    rule = '=' * 60

    print(f"\n{rule}")
    print("Trial Summary:")
    print(f"  Total trials: {n_total}")
    print(f"  Successful trials: {n_success} ({100*n_success/n_total:.1f}%)")
    print(f"  Failed trials: {n_failed} ({100*n_failed/n_total:.1f}%)")
    print(f"{rule}\n")

    if exclude_failed and n_success > 0:
        target_df = target_df.iloc[successful_indices].reset_index(drop=True)
    return target_df, failed_indices, successful_indices

In [None]:
def _empty_eye_fields() -> dict:
    """Eye-derived trial fields for a trial window that contains no valid eye samples."""
    return {
        'duration': 0.0,
        'start_eye_x': np.nan,
        'start_eye_y': np.nan,
        'final_eye_x': np.nan,
        'final_eye_y': np.nan,
        'eye_x': np.array([]),
        'eye_y': np.array([]),
        'eye_times': np.array([]),
        'eye_start_time': np.nan,
        'path_length': 0.0,
        'path_efficiency': 0.0,
        'has_eye_data': False,
    }


def _eye_trajectory_fields(eye_df: pd.DataFrame, eye_trajectory: pd.DataFrame,
                           target_x: float, target_y: float) -> dict:
    """Start/final positions, sample arrays and path metrics for a non-empty trial window.

    ``eye_x``/``eye_y`` hold the in-window samples followed by the final position.
    ``duration`` is the span of the in-window eye timestamps. ``path_length`` sums the
    in-window segments only. ``path_efficiency`` is the straight-line start-to-target
    distance divided by ``path_length`` (0.0 when the path length is 0).
    """
    xs = eye_trajectory['green_x'].values
    ys = eye_trajectory['green_y'].values
    times = eye_trajectory['timestamp'].values
    start_x, start_y = xs[0], ys[0]

    # The final position is the eye_df row immediately after the last in-window sample.
    # It falls back to that last sample when there is no following row or that row has
    # NaN coordinates.
    # Assumes eye_df has a unique index (true for the frame built by load_feedback_data).
    final_x, final_y = xs[-1], ys[-1]
    next_pos = eye_df.index.get_loc(eye_trajectory.index[-1]) + 1
    if next_pos < len(eye_df):
        next_row = eye_df.iloc[next_pos]
        if pd.notna(next_row['green_x']) and pd.notna(next_row['green_y']):
            final_x, final_y = next_row['green_x'], next_row['green_y']

    if len(xs) > 1:
        segment_lengths = np.sqrt(np.diff(xs) ** 2 + np.diff(ys) ** 2)
        path_length = np.sum(segment_lengths)
        straight_line_distance = np.sqrt((target_x - start_x) ** 2 + (target_y - start_y) ** 2)
        path_efficiency = straight_line_distance / path_length if path_length > 0 else 0.0
    else:
        path_length = 0.0
        path_efficiency = 0.0

    return {
        'duration': times[-1] - times[0],
        'start_eye_x': start_x,
        'start_eye_y': start_y,
        'final_eye_x': final_x,
        'final_eye_y': final_y,
        'eye_x': np.append(xs, final_x),
        'eye_y': np.append(ys, final_y),
        'eye_times': times,
        'eye_start_time': times[0],
        'path_length': path_length,
        'path_efficiency': path_efficiency,
        'has_eye_data': True,
    }


def extract_trial_trajectories(eot_df: pd.DataFrame, eye_df: pd.DataFrame,
                                target_df: pd.DataFrame,
                                successful_indices: Optional[list] = None) -> list[dict]:
    """Extract eye-position trajectories and simple path metrics for each trial.

    Trial ``i`` runs from the onset frame in row ``i`` of ``target_df`` to the frame in
    row ``i`` of ``eot_df``, both inclusive. Rows are paired by position, so both tables
    must be in the same (chronological) order. When ``eot_df`` has fewer rows, the
    remaining trials end at ``start_frame + 1000`` frames and ``start_time + ITI``.

    Parameters
    ----------
    eot_df : DataFrame
        End-of-trial events with 'frame' and 'timestamp' columns.
    eye_df : DataFrame
        Eye samples with 'frame', 'timestamp', 'green_x' and 'green_y' columns.
    target_df : DataFrame
        Target onsets with 'frame', 'timestamp', 'target_x', 'target_y', 'diameter' and
        'visible' columns, plus an optional 'original_trial_number' column.
    successful_indices : list of int, optional
        Positions in ``target_df`` of successful trials. When given, every other trial
        is flagged ``trial_failed``.

    Returns
    -------
    list of dict
        One record per row of ``target_df``.
    """
    trials = []
    n_trials = len(target_df)
    # A set makes the per-trial membership test O(1) instead of O(n).
    successful_set = None if successful_indices is None else set(successful_indices)

    # The ITI estimate is the shortest gap between consecutive target onsets, floored to
    # whole seconds. It is only used for the fallback end time.
    if n_trials > 1:
        ITI = np.floor(np.min(np.diff(target_df['timestamp'].values)))
        print(f"\nCalculated ITI: {ITI:.0f} seconds")
    else:
        ITI = 0

    has_trial_numbers = 'original_trial_number' in target_df.columns
    for i in range(n_trials):
        row = target_df.iloc[i]
        trial_num = int(row['original_trial_number']) if has_trial_numbers else i + 1
        target_x = row['target_x']
        target_y = row['target_y']
        start_frame = row['frame']
        start_time = row['timestamp']

        if i < len(eot_df):
            end_frame = int(eot_df.iloc[i]['frame'])
            end_time = eot_df.iloc[i]['timestamp']
        else:
            end_frame = start_frame + 1000
            end_time = start_time + ITI

        eye_mask = (eye_df['frame'] >= start_frame) & (eye_df['frame'] <= end_frame)
        eye_trajectory = eye_df[eye_mask].dropna(subset=['green_x', 'green_y', 'timestamp'])

        if len(eye_trajectory) == 0:
            eye_fields = _empty_eye_fields()
        else:
            eye_fields = _eye_trajectory_fields(eye_df, eye_trajectory, target_x, target_y)

        trials.append({
            'trial_number': trial_num,
            'start_frame': start_frame,
            'end_frame': end_frame,
            'start_time': start_time,
            'end_time': end_time,
            'target_x': target_x,
            'target_y': target_y,
            'target_diameter': row['diameter'],
            'target_visible': row['visible'],
            **eye_fields,
            'initial_direction_error': np.nan,  # placeholder: not computed yet
            'trial_failed': successful_set is not None and i not in successful_set,
        })

    print(f"\nExtracted {len(trials)} trials")
    durations = [t['duration'] for t in trials if t['has_eye_data']]
    if durations:
        print(f"  Mean trial duration: {np.mean(durations):.2f}s")

    return trials

## 4. Plotting Functions

In [None]:
def plot_trajectories(trials: list[dict], results_dir: Optional[Path] = None,
                      animal_id: Optional[str] = None, session_date: str = "") -> plt.Figure:
    """Overlay every trial's eye path on its target, coloured purple (early) to green (late).

    Each path gets a circle marker at its first sample and a square at the trial's final
    eye position. The target is a black outline: solid when ``target_visible`` is truthy,
    dashed and faded otherwise. Trials with ``has_eye_data`` False are skipped but still
    advance the colour scale.

    Parameters
    ----------
    trials : list of dict
        Trial records as produced by extract_trial_trajectories.
    results_dir : Path, optional
        When given, the figure is saved there as
        ``[<animal_id>_]saccade_feedback_trajectories.png``.
    animal_id, session_date : str, optional
        Appended to the title when non-empty.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=(12, 10))
    cmap = LinearSegmentedColormap.from_list('purple_green', ['#9b59b6', '#3498db', '#2ecc71'])
    n_trials = len(trials)
    color_span = max(1, n_trials - 1)
    endpoint_style = {'markersize': 8, 'alpha': 0.9,
                      'markeredgecolor': 'white', 'markeredgewidth': 1}

    for idx, trial in enumerate(trials):
        if not trial.get('has_eye_data', True):
            continue

        color = cmap(idx / color_span)
        xs = trial['eye_x']
        ys = trial['eye_y']
        ax.plot(xs, ys, '-', color=color, alpha=0.6, linewidth=1.5)
        ax.plot(xs[0], ys[0], 'o', color=color, **endpoint_style)
        end_x = trial.get('final_eye_x', xs[-1])
        end_y = trial.get('final_eye_y', ys[-1])
        ax.plot(end_x, end_y, 's', color=color, **endpoint_style)

        tx, ty = trial['target_x'], trial['target_y']
        is_visible = trial.get('target_visible', 1)
        ax.add_patch(Circle((tx, ty), radius=trial['target_diameter'] / 2.0, fill=False,
                            edgecolor='black', linewidth=2.5,
                            linestyle='-' if is_visible else '--',
                            alpha=1.0 if is_visible else 0.4))
        ax.plot(tx, ty, 'ko', markersize=4)

    ax.set_xlabel('Horizontal Position (stimulus units)', fontsize=12)
    ax.set_ylabel('Vertical Position (stimulus units)', fontsize=12)

    title = 'Eye Position Trajectories to Target'
    title += f' - {animal_id}' if animal_id else ''
    title += f' ({session_date})' if session_date else ''
    ax.set_title(title, fontsize=14, fontweight='bold')

    ax.grid(True, alpha=0.3)
    ax.set_xlim(-1.7, 1.7)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal', adjustable='box')

    scalar_map = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=1, vmax=n_trials))
    scalar_map.set_array([])
    plt.colorbar(scalar_map, ax=ax, label='Trial Number')

    plt.tight_layout()

    if results_dir:
        results_dir.mkdir(parents=True, exist_ok=True)
        prefix = f"{animal_id}_" if animal_id else ""
        out_path = results_dir / f"{prefix}saccade_feedback_trajectories.png"
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {out_path}")

    return fig

In [None]:
def plot_trajectories_by_direction(trials: list[dict], results_dir: Optional[Path] = None,
                                   animal_id: Optional[str] = None, session_date: str = "") -> plt.Figure:
    """Plot eye trajectories coloured by target side (left: target_x < 0, right: >= 0).

    Trials with ``has_eye_data`` False are not drawn as paths, but every trial's target
    contributes a circle. Targets whose centres agree to 2 decimal places are drawn once.

    Parameters
    ----------
    trials : list of dict
        Trial records as produced by extract_trial_trajectories.
    results_dir : Path, optional
        When given, the figure is saved there as
        ``[<animal_id>_]saccade_feedback_trajectories_by_direction.png``.
    animal_id, session_date : str, optional
        Appended to the title when non-empty (same convention as plot_trajectories).

    Returns
    -------
    matplotlib.figure.Figure
    """
    from matplotlib.lines import Line2D

    fig, ax = plt.subplots(figsize=(12, 10))

    left_color = 'blue'
    right_color = 'red'
    left_trials = [t for t in trials if t['target_x'] < 0 and t.get('has_eye_data', True)]
    right_trials = [t for t in trials if t['target_x'] >= 0 and t.get('has_eye_data', True)]

    for group, color in ((left_trials, left_color), (right_trials, right_color)):
        for trial in group:
            eye_x = trial['eye_x']
            eye_y = trial['eye_y']
            ax.plot(eye_x, eye_y, '-', color=color, alpha=0.5, linewidth=1.5)
            ax.plot(eye_x[0], eye_y[0], 'o', color=color, markersize=6, alpha=0.7)

    # Many trials share a target location; draw each distinct one only once.
    targets_drawn = set()
    for trial in trials:
        target_x = trial['target_x']
        target_y = trial['target_y']
        key = (round(target_x, 2), round(target_y, 2))
        if key in targets_drawn:
            continue
        ax.add_patch(Circle((target_x, target_y), radius=trial['target_diameter'] / 2.0,
                            fill=False, edgecolor='black', linewidth=2.5))
        ax.plot(target_x, target_y, 'ko', markersize=4)
        targets_drawn.add(key)

    ax.set_xlabel('Horizontal Position (stimulus units)', fontsize=12)
    ax.set_ylabel('Vertical Position (stimulus units)', fontsize=12)

    title = 'Eye Trajectories by Target Direction'
    if animal_id:
        title += f' - {animal_id}'
    if session_date:
        title += f' ({session_date})'
    ax.set_title(title, fontsize=14, fontweight='bold')

    ax.grid(True, alpha=0.3)
    ax.set_xlim(-1.7, 1.7)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal', adjustable='box')

    legend_elements = [
        Line2D([0], [0], color=left_color, label=f'Left targets (n={len(left_trials)})'),
        Line2D([0], [0], color=right_color, label=f'Right targets (n={len(right_trials)})'),
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()

    if results_dir:
        results_dir.mkdir(parents=True, exist_ok=True)
        prefix = f"{animal_id}_" if animal_id else ""
        filename = f"{prefix}saccade_feedback_trajectories_by_direction.png"
        fig.savefig(results_dir / filename, dpi=150, bbox_inches='tight')
        print(f"Saved: {results_dir / filename}")

    return fig

In [None]:
def plot_density_heatmap(trials: list[dict], results_dir: Optional[Path] = None,
                         animal_id: Optional[str] = None, session_date: str = "") -> plt.Figure:
    """Plot a 2-D histogram heatmap of all eye-position samples, with target outlines.

    Samples from trials with eye data are binned on a 50x50 grid over
    x in [-1.7, 1.7] and y in [-1, 1]. Each distinct target location (rounded to 2
    decimal places) is outlined once in cyan, so repeated targets do not stack alpha.

    Parameters
    ----------
    trials : list of dict
        Trial records as produced by extract_trial_trajectories.
    results_dir : Path, optional
        When given, the figure is saved there as
        ``[<animal_id>_]saccade_feedback_heatmap.png``.
    animal_id, session_date : str, optional
        Appended to the title when non-empty (same convention as plot_trajectories).

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=(12, 10))

    all_x = []
    all_y = []
    for trial in trials:
        if trial.get('has_eye_data', True):
            all_x.extend(trial['eye_x'])
            all_y.extend(trial['eye_y'])

    h, xedges, yedges = np.histogram2d(np.array(all_x), np.array(all_y), bins=50,
                                       range=[[-1.7, 1.7], [-1, 1]])

    # histogram2d indexes as h[x_bin, y_bin]; transpose so rows map to y for imshow.
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    im = ax.imshow(h.T, extent=extent, origin='lower', cmap='hot', aspect='auto', interpolation='bilinear')
    plt.colorbar(im, ax=ax, label='Number of Samples')

    targets_drawn = set()
    for trial in trials:
        target_x = trial['target_x']
        target_y = trial['target_y']
        key = (round(target_x, 2), round(target_y, 2))
        if key in targets_drawn:
            continue
        ax.add_patch(Circle((target_x, target_y), radius=trial['target_diameter'] / 2.0,
                            fill=False, edgecolor='cyan', linewidth=2, alpha=0.7))
        targets_drawn.add(key)

    ax.set_xlabel('Horizontal Position (stimulus units)', fontsize=12)
    ax.set_ylabel('Vertical Position (stimulus units)', fontsize=12)

    title = 'Eye Position Density Heatmap'
    if animal_id:
        title += f' - {animal_id}'
    if session_date:
        title += f' ({session_date})'
    ax.set_title(title, fontsize=14, fontweight='bold')

    ax.set_xlim(-1.7, 1.7)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal', adjustable='box')

    plt.tight_layout()

    if results_dir:
        results_dir.mkdir(parents=True, exist_ok=True)
        prefix = f"{animal_id}_" if animal_id else ""
        filename = f"{prefix}saccade_feedback_heatmap.png"
        fig.savefig(results_dir / filename, dpi=150, bbox_inches='tight')
        print(f"Saved: {results_dir / filename}")

    return fig

In [None]:
def plot_time_to_target(trials: list[dict], results_dir: Optional[Path] = None,
                        animal_id: Optional[str] = None, session_date: str = "") -> Optional[plt.Figure]:
    """Plot each trial's duration against trial number, plus a duration histogram.

    Only trials with eye data are included. The plotted "time to target" is the trial's
    ``duration`` field, i.e. the span of its in-window eye timestamps.

    Parameters
    ----------
    trials : list of dict
        Trial records as produced by extract_trial_trajectories.
    results_dir : Path, optional
        When given, the figure is saved there as
        ``[<animal_id>_]saccade_feedback_time_to_target.png``.
    animal_id, session_date : str, optional
        Appended to the title when non-empty (same convention as plot_trajectories).

    Returns
    -------
    matplotlib.figure.Figure or None
        None when no trial has eye data. This matches plot_final_positions_by_target,
        and avoids NaN statistics and an empty histogram.
    """
    valid_trials = [t for t in trials if t.get('has_eye_data', True)]
    if not valid_trials:
        print("\nTime-to-target plot: no trials with eye data, nothing to plot.")
        return None

    trial_numbers = [t['trial_number'] for t in valid_trials]
    durations = [t['duration'] for t in valid_trials]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

    ax1.plot(trial_numbers, durations, 'o-', linewidth=2, markersize=8,
             color='steelblue', markerfacecolor='lightblue', markeredgecolor='steelblue')
    ax1.set_xlabel('Trial Number', fontsize=12)
    ax1.set_ylabel('Time to Target (seconds)', fontsize=12)

    title = 'Time to Reach Target Across Trials'
    if animal_id:
        title += f' - {animal_id}'
    if session_date:
        title += f' ({session_date})'
    ax1.set_title(title, fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)

    mean_duration = np.mean(durations)
    ax1.axhline(mean_duration, color='red', linestyle='--', linewidth=2,
                label=f'Mean: {mean_duration:.2f}s')
    ax1.legend(fontsize=10)

    ax2.hist(durations, bins=20, color='steelblue', alpha=0.7, edgecolor='black')
    ax2.set_xlabel('Time to Target (seconds)', fontsize=12)
    ax2.set_ylabel('Number of Trials', fontsize=12)
    ax2.set_title('Distribution of Trial Durations', fontsize=12, fontweight='bold')
    ax2.grid(True, alpha=0.3, axis='y')

    std_duration = np.std(durations)  # population std (ddof=0)
    median_duration = np.median(durations)
    stats_text = f'Mean: {mean_duration:.2f}s\nMedian: {median_duration:.2f}s\nStd: {std_duration:.2f}s\nN: {len(durations)}'
    ax2.text(0.95, 0.95, stats_text, transform=ax2.transAxes,
             fontsize=10, verticalalignment='top', horizontalalignment='right',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()

    if results_dir:
        results_dir.mkdir(parents=True, exist_ok=True)
        prefix = f"{animal_id}_" if animal_id else ""
        filename = f"{prefix}saccade_feedback_time_to_target.png"
        fig.savefig(results_dir / filename, dpi=150, bbox_inches='tight')
        print(f"Saved: {results_dir / filename}")

    return fig

In [None]:
def plot_final_positions_by_target(trials: list[dict], min_duration: float = 0.01, max_duration: float = 15.0,
                                   results_dir: Optional[Path] = None, animal_id: Optional[str] = None,
                                   session_date: str = "") -> plt.Figure:
    """Plot final cursor positions grouped by target location, with per-group means.

    Only trials with eye data and ``min_duration <= duration <= max_duration`` are used.
    Trials are grouped by target centre rounded to 2 decimal places. Each group gets its
    own colour, a scatter of final positions, a star at the mean final position (NaN
    positions ignored), and its target outline.

    Parameters
    ----------
    trials : list of dict
        Trial records as produced by extract_trial_trajectories.
    min_duration, max_duration : float
        Inclusive duration window in seconds.
    results_dir : Path, optional
        When given, the figure is saved there as ``[<animal_id>_]final_positions_by_target.png``.
    animal_id, session_date : str, optional
        Appended to the title when non-empty (same convention as plot_trajectories).

    Returns
    -------
    matplotlib.figure.Figure or None
        None when no trial survives the filter.
    """
    filtered_trials = [t for t in trials if t.get('has_eye_data', True) and
                       min_duration <= t['duration'] <= max_duration]

    print(f"\nFinal Position Analysis:")
    print(f"  Total trials: {len(trials)}")
    print(f"  Trials after filtering: {len(filtered_trials)}")

    if len(filtered_trials) == 0:
        print("  Warning: No trials left after filtering!")
        return None

    target_groups = defaultdict(list)
    for t in filtered_trials:
        final_x = t.get('final_eye_x', t['eye_x'][-1] if len(t['eye_x']) > 0 else np.nan)
        final_y = t.get('final_eye_y', t['eye_y'][-1] if len(t['eye_y']) > 0 else np.nan)

        target_key = (round(t['target_x'], 2), round(t['target_y'], 2))
        target_groups[target_key].append({
            'final_x': final_x,
            'final_y': final_y,
            'target_x': t['target_x'],
            'target_y': t['target_y'],
            'target_diameter': t['target_diameter']
        })

    sorted_groups = sorted(target_groups)

    fig, ax = plt.subplots(figsize=(10, 10))
    colors = plt.cm.tab10(np.linspace(0, 1, len(sorted_groups)))

    for color, target_key in zip(colors, sorted_groups):
        tx, ty = target_key
        trials_data = target_groups[target_key]

        final_xs = [d['final_x'] for d in trials_data]
        final_ys = [d['final_y'] for d in trials_data]

        # nanmean: one NaN final position must not erase the group's mean marker.
        mean_x = np.nanmean(final_xs)
        mean_y = np.nanmean(final_ys)

        label = f"Target ({tx:+.1f}, {ty:+.1f}) (n={len(trials_data)})"

        ax.scatter(final_xs, final_ys, alpha=0.4, color=color, s=30, label=label)
        ax.scatter([mean_x], [mean_y], color=color, s=300, marker='*',
                   edgecolors='black', linewidths=2, zorder=10)

        # The group's circle uses the first trial's diameter.
        target_radius = trials_data[0]['target_diameter'] / 2.0
        ax.add_patch(Circle((tx, ty), radius=target_radius, fill=False,
                            edgecolor=color, linewidth=2.5, alpha=0.7))
        ax.plot(tx, ty, 'o', color=color, markersize=5, markeredgecolor='black')

    ax.set_xlabel('Horizontal Position (stimulus units)', fontsize=14)
    ax.set_ylabel('Vertical Position (stimulus units)', fontsize=14)

    title = 'Final Cursor Positions by Target Type'
    if animal_id:
        title += f' - {animal_id}'
    if session_date:
        title += f' ({session_date})'
    title += f'\n(N={len(filtered_trials)})'
    ax.set_title(title, fontsize=14, fontweight='bold')

    ax.set_xlim(-1.7, 1.7)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=9, loc='best')

    plt.tight_layout()

    if results_dir:
        results_dir.mkdir(parents=True, exist_ok=True)
        prefix = f"{animal_id}_" if animal_id else ""
        filename = f"{prefix}final_positions_by_target.png"
        fig.savefig(results_dir / filename, dpi=150, bbox_inches='tight')
        print(f"Saved: {results_dir / filename}")

    return fig

## 5. Load Data

Run these cells to load your data:

In [None]:
# Load the three CSV files
eot_df, eye_df, target_df_all = load_feedback_data(FOLDER_PATH, ANIMAL_ID)

# Put both event tables in chronological order up front.
# identify_and_filter_failed_trials sorts its own copies by timestamp and returns
# indices into that order. extract_trial_trajectories pairs rows of these DataFrames
# by position. Sorting here keeps the two consistent; it is a no-op when the files
# are already chronological.
eot_df = eot_df.sort_values('timestamp').reset_index(drop=True)
target_df_all = target_df_all.sort_values('timestamp').reset_index(drop=True)

# Add original trial numbers (1-based, in chronological order)
target_df_all['original_trial_number'] = range(1, len(target_df_all) + 1)

In [None]:
# Split trials into successful / failed using the end-of-trial success flags.
# The returned index lists are positions in the timestamp-sorted target table.
target_df_successful, failed_indices, successful_indices = identify_and_filter_failed_trials(
    target_df=target_df_all,
    eot_df=eot_df,
    exclude_failed=True,
)

In [None]:
# Build one trajectory record per trial; failed trials are flagged, not dropped.
print("Extracting all trials...")
trials_all = extract_trial_trajectories(eot_df, eye_df, target_df_all,
                                        successful_indices=successful_indices)

# Successful trials that also have eye samples
trials_successful = [
    trial for trial in trials_all
    if trial.get('has_eye_data', True) and not trial.get('trial_failed', False)
]

# The trial set used by every plot and summary below
trials_for_analysis = trials_all if INCLUDE_FAILED_TRIALS else trials_successful
if INCLUDE_FAILED_TRIALS:
    print(f"\nUsing ALL {len(trials_all)} trials for analysis")
else:
    print(f"\nUsing {len(trials_successful)} successful trials for analysis")

## 6. Visualize Data

Run these cells to generate plots:

In [None]:
# Every analysed trial's eye path over its target, coloured by trial order
fig_traj = plot_trajectories(
    trials_for_analysis,
    results_dir=RESULTS_DIR,
    animal_id=ANIMAL_ID,
    session_date=DATE_STR,
)
plt.show()

In [None]:
# Paths split by target side (left vs right of centre)
fig_dir = plot_trajectories_by_direction(
    trials_for_analysis,
    results_dir=RESULTS_DIR,
    animal_id=ANIMAL_ID,
    session_date=DATE_STR,
)
plt.show()

In [None]:
# Where the eye spent its time across all analysed trials
fig_heat = plot_density_heatmap(
    trials_for_analysis,
    results_dir=RESULTS_DIR,
    animal_id=ANIMAL_ID,
    session_date=DATE_STR,
)
plt.show()

In [None]:
# Trial durations over the session and their distribution
fig_time = plot_time_to_target(
    trials_for_analysis,
    results_dir=RESULTS_DIR,
    animal_id=ANIMAL_ID,
    session_date=DATE_STR,
)
plt.show()

In [None]:
# Final cursor positions grouped by target, limited to trials whose duration lies in
# [TRIAL_MIN_DURATION, TRIAL_MAX_DURATION]
fig_final = plot_final_positions_by_target(
    trials_for_analysis,
    TRIAL_MIN_DURATION,
    TRIAL_MAX_DURATION,
    results_dir=RESULTS_DIR,
    animal_id=ANIMAL_ID,
    session_date=DATE_STR,
)
plt.show()

## 7. Session Summary

In [None]:
# Calculate summary statistics over the analysed trials that actually have eye samples
valid_trials = [t for t in trials_for_analysis if t.get('has_eye_data', True)]
durations = [t['duration'] for t in valid_trials]
path_lengths = [t['path_length'] for t in valid_trials]
efficiencies = [t['path_efficiency'] for t in valid_trials]

print("="*60)
print("SESSION SUMMARY")
print("="*60)
print(f"Folder: {FOLDER_PATH}")
print(f"Animal: {ANIMAL_ID}")
print(f"Date: {DATE_STR}")
print(f"Valid trials: {len(valid_trials)}")
if valid_trials:
    print(f"\nTime to Target:")
    print(f"  Mean: {np.mean(durations):.2f} ± {np.std(durations):.2f} s")
    print(f"  Median: {np.median(durations):.2f} s")
    print(f"  Range: {np.min(durations):.2f} - {np.max(durations):.2f} s")
    print(f"\nPath Length:")
    print(f"  Mean: {np.mean(path_lengths):.3f} ± {np.std(path_lengths):.3f}")
    print(f"  Median: {np.median(path_lengths):.3f}")
    print(f"\nPath Efficiency (1.0 = perfectly direct):")
    print(f"  Mean: {np.mean(efficiencies):.3f} ± {np.std(efficiencies):.3f}")
    print(f"  Median: {np.median(efficiencies):.3f}")
else:
    # np.min / np.max raise ValueError on an empty list, so report instead of crashing.
    print("\nNo trials with eye data - no statistics to report.")
print("="*60)

## 8. Interactive Exploration

Use this section to explore individual trials:

In [None]:
# View a single trial
def plot_single_trial(trial_idx: int, trials: Optional[list] = None):
    """Plot one trial's eye path and target in detail, then print its metrics.

    Parameters
    ----------
    trial_idx : int
        Position in ``trials``. Negative values index from the end, as with lists;
        out-of-range values print a message instead of raising.
    trials : list of dict, optional
        Trial records from extract_trial_trajectories. Defaults to the notebook's
        global ``trials_for_analysis``.
    """
    if trials is None:
        trials = trials_for_analysis

    if not trials:
        print("No trials available to plot")
        return
    if not -len(trials) <= trial_idx < len(trials):
        print(f"Trial index {trial_idx} out of range (0-{len(trials)-1})")
        return

    trial = trials[trial_idx]

    if not trial.get('has_eye_data', True):
        print(f"Trial {trial['trial_number']} has no eye data")
        return

    fig, ax = plt.subplots(figsize=(10, 8))

    eye_x = trial['eye_x']
    eye_y = trial['eye_y']

    # Plot trajectory
    ax.plot(eye_x, eye_y, 'b-', linewidth=2, alpha=0.7, label='Trajectory')
    ax.plot(eye_x[0], eye_y[0], 'go', markersize=15, label='Start')
    ax.plot(trial['final_eye_x'], trial['final_eye_y'], 'rs', markersize=15, label='End')

    # Draw target: solid outline when visible, dashed when hidden
    target_circle = Circle((trial['target_x'], trial['target_y']),
                           radius=trial['target_diameter']/2.0,
                           fill=False, edgecolor='black', linewidth=3,
                           linestyle='-' if trial['target_visible'] else '--')
    ax.add_patch(target_circle)
    ax.plot(trial['target_x'], trial['target_y'], 'ko', markersize=8, label='Target')

    ax.set_xlim(-1.7, 1.7)
    ax.set_ylim(-1, 1)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right')

    title = f"Trial {trial['trial_number']} - Duration: {trial['duration']:.2f}s"
    ax.set_title(title, fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Print trial details
    print(f"\nTrial {trial['trial_number']} Details:")
    print(f"  Duration: {trial['duration']:.3f}s")
    print(f"  Target: ({trial['target_x']:.2f}, {trial['target_y']:.2f})")
    print(f"  Target visible: {bool(trial['target_visible'])}")
    print(f"  Start: ({trial['start_eye_x']:.3f}, {trial['start_eye_y']:.3f})")
    print(f"  End: ({trial['final_eye_x']:.3f}, {trial['final_eye_y']:.3f})")
    print(f"  Path length: {trial['path_length']:.3f}")
    print(f"  Path efficiency: {trial['path_efficiency']:.3f}")

In [None]:
# Example: inspect the first analysed trial (change the index to browse others)
plot_single_trial(trial_idx=0)

In [None]:
def _trial_summary_row(trial: dict) -> dict:
    """Flatten one trial record into the scalar columns shown in the summary table."""
    row = {key: trial[key] for key in ('trial_number', 'duration', 'target_x', 'target_y',
                                        'target_visible', 'path_length', 'path_efficiency')}
    row['trial_failed'] = trial.get('trial_failed', False)
    row['has_eye_data'] = trial.get('has_eye_data', True)
    return row


# One row per extracted trial, failed and no-eye-data trials included
trial_summary = pd.DataFrame([_trial_summary_row(trial) for trial in trials_all])

# Display summary
print("Trial Summary Table:")
display(trial_summary.head(20))

In [None]:
# Write the per-trial table to <RESULTS_DIR>/<ANIMAL_ID>_trial_summary.csv (no index column)
if RESULTS_DIR:
    csv_path = RESULTS_DIR / f"{ANIMAL_ID}_trial_summary.csv"
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    trial_summary.to_csv(csv_path, index=False)
    print(f"Saved trial summary to: {csv_path}")