In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.signal import savgol_filter, find_peaks
import fly_analysis as fa

In [22]:
def filter_trajecs(
    df,
    min_length=300,  # 50 frames = 0.5 seconds
    max_z_speed=0.3,  # shouldn't be flying up/down a lot
    min_speed=0.05,  # should be flying (not wallking)
):
    selected_obj_ids = []

    for obj_id, gdf in df.groupby("obj_id"):
        if len(gdf) > min_length:
            speed = np.linalg.norm(np.vstack([gdf.zvel, gdf.xvel, gdf.yvel]), axis=0)

            if gdf.z.min() > 0.05:  # not on floor
                if gdf.z.max() < 0.5:  # not on ceiling
                    if np.max(np.abs(gdf.zvel)) < max_z_speed:
                        if np.min(speed[5:]) > min_speed:  # first few frames are off
                            selected_obj_ids.append(obj_id)

    return df.loc[df["obj_id"].isin(selected_obj_ids)]

def angdiff(theta1, theta2):
    return np.arctan2(np.sin(theta1 - theta2), np.cos(theta1 - theta2))

def get_saccade_amplitude(xvel, yvel, idx):
    heading_before = np.arctan2(yvel[idx - 10], xvel[idx - 10])
    heading_after = np.arctan2(yvel[idx + 10], xvel[idx + 10])
    heading_difference = angdiff(heading_before, heading_after)
    return heading_difference

def split_on_jumps(df, column='frame', k=1, n=300):
    # Calculate the differences between consecutive values
    diffs = df[column].diff()
    
    # Find where the differences are greater than k
    split_indices = np.where(diffs > k)[0]
    
    # Split the dataframe
    if len(split_indices) > 0:
        result = np.split(df, split_indices)
    
        # Remove empty dataframes (if any)
        result = [r for r in result if len(r) > n and not r.empty]
    else:
        result = [df]
    
    return result

def detect_saccades(angvel, **kwargs):
    thresh = np.deg2rad(kwargs.get("threshold", 500))
    dist = kwargs.get("distance", 10)
    # positive_peaks, _ = find_peaks(angvel, height=thresh, distance=dist)
    # negative_peaks, _ = find_peaks(-angvel, height=thresh, distance=dist)
    peaks,_ = find_peaks(np.abs(angvel), height=thresh, distance=dist)
    return peaks

def process_fly_trajectories(
    df,
    x_bounds=(-0.25, 0.25),
    y_bounds=(-0.25, 0.25),
    z_bounds=(0.00, 0.35),
    min_trajectory_length=300,
):
    
    angvels = []
    linvels = []
    accels = []
    amplitudes = []
    isi = []
    for obj_id, grp in df.groupby("unique_obj_id"):

        if len(grp) < min_trajectory_length:
            continue
        
        # make a clone of grp
        grp_new = grp.copy()

        # find indices in range
        x, y, z = grp.x.to_numpy(), grp.y.to_numpy(), grp.z.to_numpy()
        indices_in_range = (x >= x_bounds[0]) & (x <= x_bounds[1]) & (y >= y_bounds[0]) & (y <= y_bounds[1]) & (z >= z_bounds[0]) & (z <= z_bounds[1])
        grp_new = grp_new.iloc[indices_in_range]

        # split dataframes
        result = split_on_jumps(grp_new, column="frame", k=1)
        x = savgol_filter(result[0]["x"], 21, 3)
        y = savgol_filter(result[0]["y"], 21, 3)
        z = savgol_filter(result[0]["z"], 21, 3)
        xvel = savgol_filter(result[0]["xvel"], 21, 3)
        yvel = savgol_filter(result[0]["yvel"], 21, 3)

        # angular velocity
        theta = np.arctan2(yvel, xvel)
        theta_u = np.unwrap(theta)
        angular_velocity = np.gradient(theta_u, 0.01)

        # linear velocity
        linear_velocity = np.sqrt(xvel**2+yvel**2)

        # linear acceleration
        linear_acceleration = np.gradient(linear_velocity, 0.01)

        # saccade indices
        saccades = detect_saccades(angular_velocity, threshold=500, distance=10)
        
        # loop over saccade indices
        for sac in saccades:
            if sac-25 < 0 or sac+25 > len(angular_velocity):
                continue

            angvels.append(angular_velocity[sac-25:sac+25])
            linvels.append(linear_velocity[sac-25:sac+25])
            accels.append(linear_acceleration[sac-25:sac+25])
            amplitudes.append(get_saccade_amplitude(xvel, yvel, sac))
        isi.append(np.diff(saccades))

    return np.asarray(angvels), np.asarray(linvels), np.asarray(accels), np.asarray(amplitudes), isi

In [3]:
def plot_mean_and_std(data):
    npdata = np.asarray(data)
    fig = plt.figure()
    mean = np.nanmean(npdata, axis=0)
    std = np.nanmean(npdata, axis=0)
    plt.plot(mean)
    plt.fill_between(np.arange(len(mean)), mean-std, mean+std, alpha=0.5)
    plt.show()

def plot_all_traces(data):
    fig = plt.figure()
    for d in data:
        plt.plot(d)
    plt.show()

In [4]:
files = ["20240702_140401.braidz","20240712_153309.braidz","20240715_171206.braidz"]
root_folder = "/home/buchsbaum/mnt/DATA/Experiments/"
wtcs = fa.braidz.read_multiple_braidz(files, root_folder)

Reading /home/buchsbaum/mnt/DATA/Experiments/20240702_140401.braidz using pyarrow
Reading /home/buchsbaum/mnt/DATA/Experiments/20240712_153309.braidz using pyarrow
Reading /home/buchsbaum/mnt/DATA/Experiments/20240715_171206.braidz using pyarrow


In [40]:
import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
from typing import List, Tuple, Dict

# Constants
ANGULAR_VELOCITY_THRESHOLD = np.radians(500)  # 500 deg/s in rad/s
STIM_SACCADE_MIN_FRAMES = 25
STIM_SACCADE_MAX_FRAMES = 35
ARENA_CENTER_X = 0
ARENA_CENTER_Y = 0
ARENA_CENTER_Z = 0.175
ARENA_RADIUS = 0.15
ARENA_HEIGHT = 0.15
SAVGOL_WINDOW = 21
SAVGOL_ORDER = 3

def smooth_data(data: np.ndarray) -> np.ndarray:
    return savgol_filter(data, SAVGOL_WINDOW, SAVGOL_ORDER)

def calculate_velocities(xvel: np.ndarray, yvel: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    angular_velocity = np.arctan2(yvel, xvel)
    angular_velocity = np.unwrap(angular_velocity)
    angular_velocity = np.gradient(angular_velocity)
    linear_velocity = np.sqrt(xvel**2 + yvel**2)
    return angular_velocity, linear_velocity

def is_flying(x: float, y: float, z: float) -> bool:
    distance_from_center = np.sqrt((x - ARENA_CENTER_X)**2 + (y - ARENA_CENTER_Y)**2)
    height_from_center = abs(z - ARENA_CENTER_Z)
    return (distance_from_center <= ARENA_RADIUS) and (height_from_center <= ARENA_HEIGHT / 2)

def detect_saccades(angular_velocity: np.ndarray, linear_velocity: np.ndarray, 
                    positions: np.ndarray) -> List[Tuple[int, int]]:
    saccades = []
    in_saccade = False
    start_idx = 0
    
    for i, (av, pos) in enumerate(zip(angular_velocity, positions)):
        if not in_saccade and abs(av) > ANGULAR_VELOCITY_THRESHOLD and is_flying(*pos):
            in_saccade = True
            start_idx = i
        elif in_saccade and (abs(av) <= ANGULAR_VELOCITY_THRESHOLD or not is_flying(*pos)):
            in_saccade = False
            saccades.append((start_idx, i))
    
    if in_saccade:
        saccades.append((start_idx, len(angular_velocity) - 1))
    
    return saccades

def analyze_group(group: pd.DataFrame, stim_frames: List[int]) -> Dict[str, List]:
    xvel_smooth = smooth_data(group['xvel'].values)
    yvel_smooth = smooth_data(group['yvel'].values)
    
    angular_velocity, linear_velocity = calculate_velocities(xvel_smooth, yvel_smooth)
    positions = group[['x', 'y', 'z']].values
    
    saccades = detect_saccades(angular_velocity, linear_velocity, positions)
    
    results = {
        'angular_velocity': [],
        'linear_velocity': [],
        'is_stim_saccade': []
    }
    
    for start, end in saccades:
        results['angular_velocity'].append(angular_velocity[start:end])
        results['linear_velocity'].append(linear_velocity[start:end])
        
        is_stim = any(STIM_SACCADE_MIN_FRAMES <= start - stim_frame <= STIM_SACCADE_MAX_FRAMES 
                      for stim_frame in stim_frames)
        results['is_stim_saccade'].append(is_stim)
    
    return results

def analyze_saccades(df: pd.DataFrame, stim: pd.DataFrame) -> Dict[str, List]:
    results = {
        'angular_velocity': [],
        'linear_velocity': [],
        'is_stim_saccade': []
    }
    
    analyzed_groups = set()
    
    for _, stim_row in stim.iterrows():
        group_key = (stim_row['obj_id'], stim_row['exp_num'])
        
        if group_key in analyzed_groups:
            continue
        
        group = df[(df['obj_id'] == stim_row['obj_id']) & (df['exp_num'] == stim_row['exp_num'])]
        stim_frames = stim[(stim['obj_id'] == stim_row['obj_id']) & 
                           (stim['exp_num'] == stim_row['exp_num'])]['frame'].tolist()
        
        group_results = analyze_group(group, stim_frames)
        
        results['angular_velocity'].extend(group_results['angular_velocity'])
        results['linear_velocity'].extend(group_results['linear_velocity'])
        results['is_stim_saccade'].extend(group_results['is_stim_saccade'])
        
        analyzed_groups.add(group_key)
    
    return results

# Usage
# df and stim should be pandas DataFrames with the specified columns
results = analyze_saccades(wtcs["df"], wtcs["stim"])

In [42]:
import matplotlib.pyplot as plt
import numpy as np

def plot_saccade_comparison(results: Dict[str, List], time_window: float = 0.1):
    # Separate stim and non-stim saccades
    stim_angular = np.array([av for av, is_stim in zip(results['angular_velocity'], results['is_stim_saccade']) if is_stim])
    non_stim_angular = np.array([av for av, is_stim in zip(results['angular_velocity'], results['is_stim_saccade']) if not is_stim])
    stim_linear = np.array([lv for lv, is_stim in zip(results['linear_velocity'], results['is_stim_saccade']) if is_stim])
    non_stim_linear = np.array([lv for lv, is_stim in zip(results['linear_velocity'], results['is_stim_saccade']) if not is_stim])

    # Check if we have saccades to plot
    if len(stim_angular) == 0 and len(non_stim_angular) == 0:
        print("No saccades detected. Unable to create plot.")
        return

    # Find the maximum length of non-empty arrays
    max_len = max(
        max((len(arr) for arr in stim_angular), default=0),
        max((len(arr) for arr in non_stim_angular), default=0),
        max((len(arr) for arr in stim_linear), default=0),
        max((len(arr) for arr in non_stim_linear), default=0)
    )

    if max_len == 0:
        print("All saccade arrays are empty. Unable to create plot.")
        return

    def pad_to_length(arr, length):
        return np.pad(arr, (0, length - len(arr)), mode='constant', constant_values=np.nan)

    # Pad arrays only if they are not empty
    stim_angular = np.array([pad_to_length(arr, max_len) for arr in stim_angular]) if len(stim_angular) > 0 else np.array([])
    non_stim_angular = np.array([pad_to_length(arr, max_len) for arr in non_stim_angular]) if len(non_stim_angular) > 0 else np.array([])
    stim_linear = np.array([pad_to_length(arr, max_len) for arr in stim_linear]) if len(stim_linear) > 0 else np.array([])
    non_stim_linear = np.array([pad_to_length(arr, max_len) for arr in non_stim_linear]) if len(non_stim_linear) > 0 else np.array([])

    # Calculate mean and std, handling empty arrays
    def safe_mean_std(arr):
        if len(arr) > 0:
            return np.nanmean(arr, axis=0), np.nanstd(arr, axis=0)
        else:
            return np.array([]), np.array([])

    stim_angular_mean, stim_angular_std = safe_mean_std(stim_angular)
    non_stim_angular_mean, non_stim_angular_std = safe_mean_std(non_stim_angular)
    stim_linear_mean, stim_linear_std = safe_mean_std(stim_linear)
    non_stim_linear_mean, non_stim_linear_std = safe_mean_std(non_stim_linear)

    # Create time array
    time = np.linspace(-time_window, time_window, max_len)

    # Create subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
    
    # Plot angular velocity
    if len(stim_angular) > 0:
        ax1.plot(time, np.degrees(stim_angular_mean), label='Stimulus Saccades', color='blue')
        ax1.fill_between(time, np.degrees(stim_angular_mean - stim_angular_std), 
                         np.degrees(stim_angular_mean + stim_angular_std), alpha=0.3, color='blue')
    if len(non_stim_angular) > 0:
        ax1.plot(time, np.degrees(non_stim_angular_mean), label='Spontaneous Saccades', color='red')
        ax1.fill_between(time, np.degrees(non_stim_angular_mean - non_stim_angular_std), 
                         np.degrees(non_stim_angular_mean + non_stim_angular_std), alpha=0.3, color='red')
    ax1.set_title('Angular Velocity: Stimulus vs Spontaneous Saccades')
    ax1.set_xlabel('Time (s)')
    ax1.set_ylabel('Angular Velocity (deg/s)')
    ax1.legend()
    ax1.grid(True)

    # Plot linear velocity
    if len(stim_linear) > 0:
        ax2.plot(time, stim_linear_mean, label='Stimulus Saccades', color='blue')
        ax2.fill_between(time, stim_linear_mean - stim_linear_std, 
                         stim_linear_mean + stim_linear_std, alpha=0.3, color='blue')
    if len(non_stim_linear) > 0:
        ax2.plot(time, non_stim_linear_mean, label='Spontaneous Saccades', color='red')
        ax2.fill_between(time, non_stim_linear_mean - non_stim_linear_std, 
                         non_stim_linear_mean + non_stim_linear_std, alpha=0.3, color='red')
    ax2.set_title('Linear Velocity: Stimulus vs Spontaneous Saccades')
    ax2.set_xlabel('Time (s)')
    ax2.set_ylabel('Linear Velocity (m/s)')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

# Usage
plot_saccade_comparison(results)

No saccades detected. Unable to create plot.
