In [6]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.signal import savgol_filter, find_peaks
from scipy.stats import circmean
import fly_analysis as fa
from tqdm import tqdm
import pickle
from datetime import datetime
import os
# Timestamp of this kernel session, formatted YYYYmmdd_HHMMSS.
date_time = f"{datetime.now():%Y%m%d_%H%M%S}"
print(date_time)

20241028_171404


In [7]:
def sg_smooth(df, window_length=21, polyorder=3):
    """
    Smooth the position and velocity columns of a trajectory with a Savitzky-Golay filter.

    Parameters:
    df (pandas.DataFrame): Trajectory containing the columns 'x', 'y', 'z',
        'xvel', 'yvel' and 'zvel'.
    window_length (int, optional): Filter window length in samples. Default is 21.
        The trajectory must be at least this long (see scipy.signal.savgol_filter
        for the exact constraints).
    polyorder (int, optional): Order of the fitted polynomial; must be smaller
        than window_length. Default is 3.

    Returns:
    pandas.DataFrame: A copy of `df` in which those six columns are smoothed
    column by column. All other columns are unchanged and the input DataFrame
    is not modified.
    """
    columns = ["x", "y", "z", "xvel", "yvel", "zvel"]
    # Work on a copy: callers pass boolean-filtered slices of a larger frame, and
    # assigning into such a slice triggers SettingWithCopyWarning and mutates the
    # caller's object.
    smoothed = df.copy()
    smoothed[columns] = smoothed[columns].apply(
        lambda col: savgol_filter(col, window_length, polyorder)
    )
    return smoothed


def get_angular_velocity(xvel, yvel, dt=0.01):
    """
    Compute the rate of change of the horizontal travel direction.

    The heading is arctan2(yvel, xvel). It is unwrapped so that crossings of
    +/-pi do not produce artificial 2*pi jumps, then differentiated with
    np.gradient using sample spacing `dt`.

    Parameters:
    xvel (array-like): Velocity component along x.
    yvel (array-like): Velocity component along y.
    dt (float, optional): Sample spacing passed to np.gradient. Default is 0.01.

    Returns:
    numpy.ndarray: Angular velocity in radians per unit of `dt`.
    """
    heading = np.unwrap(np.arctan2(yvel, xvel))
    angular_velocity = np.gradient(heading, dt)
    return angular_velocity


def get_linear_velocity(xvel, yvel):
    """
    Compute the horizontal speed from the x and y velocity components.

    Parameters:
    xvel (float or numpy.ndarray): Velocity along x.
    yvel (float or numpy.ndarray): Velocity along y.

    Returns:
    float or numpy.ndarray: Element-wise Euclidean norm sqrt(xvel**2 + yvel**2).
    """
    sum_of_squares = xvel**2 + yvel**2
    speed = np.sqrt(sum_of_squares)
    return speed


def get_saccades(angular_velocity, linear_velocity, **kwargs):
    """
    Detect saccades as peaks in absolute angular velocity that occur while moving.

    Parameters:
    angular_velocity (array-like): Angular velocity trace. The default `height`
        assumes rad/s.
    linear_velocity (array-like): Speed trace with the same length as
        angular_velocity. It is indexed positionally.
    **kwargs: Optional detection settings (unrecognised keys are ignored):
        - height (float): Minimum peak height of |angular_velocity|.
          Default is np.deg2rad(300).
        - distance (int): Minimum number of samples between peaks. Default is 20.
        - min_speed (float): Peaks where linear_velocity is not strictly greater
          than this are discarded. Default is 0.1.

    Returns:
    numpy.ndarray: Integer sample indices of the detected saccades.
    """
    angular_velocity = np.asarray(angular_velocity)
    # Convert to ndarray so the lookup below is positional; a pandas Series with a
    # non-default index would otherwise be indexed by label.
    linear_velocity = np.asarray(linear_velocity)

    peaks, _ = find_peaks(
        np.abs(angular_velocity),
        height=kwargs.get("height", np.deg2rad(300)),
        distance=kwargs.get("distance", 20),
    )
    return peaks[linear_velocity[peaks] > kwargs.get("min_speed", 0.1)]


def plot_mean_and_std(
    arr,
    ax=None,
    take_abs=True,
    units="frames",
    convert_to_degrees=False,
    peak_align=False,
    label=None,
    n_before=50,
    dt=0.01,
    peak_window=30,
):
    """
    Plot the across-trial mean of event-aligned traces with a +/-1 SD band.

    Parameters:
    arr (array-like): Array of shape (n_trials, n_samples). Each row is one
        trace aligned so that the event sits at column `n_before`. A 1-D input
        is treated as a single trial.
    ax (matplotlib.axes.Axes, optional): Axes to draw on. If None, a new figure
        and axes are created.
    take_abs (bool, optional): If True, average |arr| instead of arr.
        Default is True. Ignored when peak_align is True.
    units (str, optional): Currently unused; kept for backward compatibility.
        The x-axis is always in seconds.
    convert_to_degrees (bool, optional): If True, mean and SD are converted
        from radians to degrees. This also selects the deg/s y-label.
    peak_align (bool, optional): If True, each trace is sign-flipped so that its
        largest-magnitude sample in columns [n_before, n_before + peak_window)
        is positive. This replaces taking the absolute value.
    label (str, optional): Legend label for the mean line.
    n_before (int, optional): Column index of the event (t = 0). Default is 50.
    dt (float, optional): Sample spacing in seconds. Default is 0.01.
    peak_window (int, optional): Number of samples after the event searched
        for the alignment peak. Default is 30.

    Returns:
    matplotlib.axes.Axes: The axes with the plot.
    """
    if ax is None:
        _, ax = plt.subplots()

    # np.array copies, so the caller's array is never modified by the flips below.
    data = np.atleast_2d(np.array(arr, dtype=float))

    if peak_align:
        take_abs = False
        start, stop = n_before, n_before + peak_window
        peak_idx = start + np.argmax(np.abs(data[:, start:stop]), axis=1)
        rows = np.arange(data.shape[0])
        flip = data[rows, peak_idx] < 0
        data[flip] = -data[flip]

    if take_abs:
        data = np.abs(data)

    mean = np.nanmean(data, axis=0)
    std = np.nanstd(data, axis=0)

    if convert_to_degrees:
        mean = np.rad2deg(mean)
        std = np.rad2deg(std)

    # Build the time axis from the sampling grid so that t = 0 falls exactly on
    # column n_before, samples are dt apart, and any trace length is supported.
    t = (np.arange(data.shape[1]) - n_before) * dt

    ax.plot(t, mean, label=label)
    ax.fill_between(t, mean - std, mean + std, alpha=0.2)

    if convert_to_degrees:
        ax.set_ylabel("Angular Velocity (deg/s)")
    else:
        ax.set_ylabel("Angular Velocity (rad/s)")

    return ax


def plot_histogram(arr, ax=None, label=None, bins=36):
    """
    Plot a density histogram with a KDE overlay over the range [-pi, pi].

    Parameters:
    arr (array-like): Values to histogram, e.g. direction changes in radians.
    ax (matplotlib.axes.Axes, optional): Axes to plot on. If None, a new figure
        and axes are created.
    label (str, optional): Label for this histogram. It is forwarded to seaborn
        so that a later ax.legend() can identify it.
    bins (int, optional): Number of bins spanning [-pi, pi]. Default is 36.

    Returns:
    matplotlib.axes.Axes: The axes with the plotted histogram.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Only pass `label` when one was given, so unlabeled calls behave as before.
    label_kws = {} if label is None else {"label": label}
    sns.histplot(
        arr,
        ax=ax,
        bins=bins,
        binrange=(-np.pi, np.pi),
        stat="density",
        kde=True,
        **label_kws,
    )
    return ax


def calculate_inverted_signed_angle(v1, v2):
    """
    Signed angle between two 2-D vectors, measured from the anti-parallel direction.

    The signed angle from v1 to v2 (counter-clockwise positive) is computed with
    arctan2(cross, dot). It is then remapped so that anti-parallel vectors give 0
    and parallel vectors give +/-pi: angle -> pi - angle for angle >= 0, and
    angle -> -pi - angle otherwise.

    When v1 points back along a path and v2 points forward (as in
    get_heading_difference_for_trajectory), 0 means "no change of direction".

    NOTE(review): with that usage, a counter-clockwise (left) turn yields a
    negative value. This is the opposite sign to the arctan2-based angular
    velocity. Confirm this convention is intended.

    Parameters:
    v1 (array-like): First vector, length 2.
    v2 (array-like): Second vector, length 2.

    Returns:
    float: Angle in radians, in the range [-pi, pi].
    """
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)

    # Scalar 2-D cross product written out explicitly: np.cross on 2-element
    # vectors is deprecated as of NumPy 2.0.
    cross = v1[0] * v2[1] - v1[1] * v2[0]
    angle = np.arctan2(cross, np.dot(v1, v2))

    if angle >= 0:
        return np.pi - angle
    return -np.pi - angle


def get_heading_difference_for_trajectory(xyz, midpoint=65, n_around=10):
    """
    Direction change of a trajectory at `midpoint`.

    The vector from the midpoint back to sample midpoint - n_around and the
    vector from the midpoint forward to sample midpoint + n_around are passed to
    calculate_inverted_signed_angle. Only the first two columns (x, y) are used.

    Parameters:
    xyz (array-like): Positions of shape (n_samples, >=2).
    midpoint (int, optional): Sample at which the direction change is measured.
        Default is 65.
    n_around (int, optional): Number of samples before and after the midpoint
        used to build the two vectors. Default is 10.

    Returns:
    float: The value returned by calculate_inverted_signed_angle (radians).

    Raises:
    IndexError: If midpoint - n_around < 0 or midpoint + n_around is past the end.
    """
    xy = np.asarray(xyz)[:, :2]

    start = midpoint - n_around
    stop = midpoint + n_around
    # A negative start would silently wrap to the end of the array.
    if start < 0 or stop >= len(xy):
        raise IndexError(
            f"samples {start}..{stop} around midpoint {midpoint} are outside "
            f"a trajectory of length {len(xy)}"
        )

    vector_before = xy[start] - xy[midpoint]
    vector_after = xy[stop] - xy[midpoint]

    return calculate_inverted_signed_angle(vector_before, vector_after)


def _extract_event_window(trace, idx, n_before, n_after):
    """Return trace[idx - n_before : idx + n_after], refusing windows that start before 0."""
    start = idx - n_before
    if start < 0:
        # A negative start would silently wrap around to the end of the trace.
        raise ValueError(
            f"window start {start} is negative (idx={idx}, n_before={n_before})"
        )
    return trace[start : idx + n_after]


def extract_saccade_angvel(angular_velocity, idx, n_before=50, n_after=100):
    """
    Cut the angular-velocity window around sample `idx`.

    Returns angular_velocity[idx - n_before : idx + n_after]. The result is
    shorter than n_before + n_after if the trace ends early.
    Raises ValueError if idx - n_before < 0.
    """
    return _extract_event_window(angular_velocity, idx, n_before, n_after)


def extract_saccade_linvel(linear_velocity, idx, n_before=50, n_after=100):
    """
    Cut the linear-velocity window around sample `idx`.

    Returns linear_velocity[idx - n_before : idx + n_after]. The result is
    shorter than n_before + n_after if the trace ends early.
    Raises ValueError if idx - n_before < 0.
    """
    return _extract_event_window(linear_velocity, idx, n_before, n_after)


def flatten_nested(lst):
    """
    Flatten arbitrarily nested lists into a single flat list.

    Only `list` instances are descended into. Tuples, arrays and other iterables
    are kept as single elements. Depth-first element order is preserved.
    """
    def _walk(items):
        for element in items:
            if isinstance(element, list):
                yield from _walk(element)
            else:
                yield element

    return list(_walk(lst))


def _stack_windows(windows, width):
    """Stack equal-length 1-D windows into an (n_windows, width) array; empty input gives (0, width)."""
    if len(windows) == 0:
        return np.empty((0, width))
    return np.vstack(windows)


def _spontaneous_saccade_windows(
    angular_velocity,
    linear_velocity,
    xy,
    stim_positions,
    n_before,
    n_after,
    delay,
    heading_n_around,
    saccade_kwargs,
):
    """Yield (angvel, linvel, heading_diff) for each saccade outside every stimulus response window."""
    n_frames = len(angular_velocity)
    for saccade in get_saccades(angular_velocity, linear_velocity, **saccade_kwargs):
        # Saccades inside any stimulus response window of this trajectory are
        # treated as stimulus-related, not spontaneous.
        if np.any((stim_positions < saccade) & (saccade < stim_positions + n_after)):
            continue
        start = saccade - n_before - delay
        stop = saccade + n_after - delay
        if start < 0 or stop > n_frames:
            continue
        if saccade - heading_n_around < 0 or saccade + heading_n_around >= n_frames:
            continue
        yield (
            angular_velocity[start:stop],
            linear_velocity[start:stop],
            get_heading_difference_for_trajectory(xy, saccade, heading_n_around),
        )


def get_all_saccade_data(df, stim, **kwargs):
    """
    Extract spontaneous and stimulus-locked saccade data for every stimulus event.

    For each stimulus row, the matching trajectory (same obj_id and exp_num) is
    smoothed with sg_smooth, and angular and linear velocity are computed from
    xvel/yvel.
    - Stimulus-locked data: the window [stim - n_before, stim + n_after) and
      the direction change at n_before + delay samples into that window.
    - Spontaneous data: every saccade found by get_saccades that does not fall
      inside the (stim, stim + n_after) window of any stimulus on the same
      trajectory. Each trajectory contributes its spontaneous saccades once,
      even if it received several stimuli.

    Parameters:
    df (pandas.DataFrame): Tracking data. It must contain 'obj_id', 'exp_num',
        'frame' and the columns used by sg_smooth ('x', 'y', 'z', 'xvel',
        'yvel', 'zvel').
    stim (pandas.DataFrame): Stimulus events with 'obj_id', 'exp_num', 'frame'.
    **kwargs:
        - n_before (int): Samples kept before the reference point. Default 50.
        - n_after (int): Samples kept after the reference point. Default 100.
        - delay (int): Spontaneous windows are shifted so that the saccade lands
          `delay` samples after column n_before. Presumably this mirrors a
          stimulus-to-saccade latency; confirm. Default 15.
        - min_grp_length (int): Trajectories shorter than this are skipped.
          Default 150.
        - heading_n_around (int): Half-width, in samples, used for the
          direction change. Default 10.
        - height, distance, min_speed: Forwarded to get_saccades.

    Returns:
    dict with keys:
        - 'spont_angvels', 'spont_linvels': arrays of shape
          (n_spont, n_before + n_after).
        - 'stim_angvels', 'stim_linvels': arrays of shape
          (n_stim, n_before + n_after), with the stimulus at column n_before.
        - 'spont_heading_diffs', 'stim_heading_diffs': 1-D arrays (radians).
    """
    n_before = kwargs.get("n_before", 50)
    n_after = kwargs.get("n_after", 100)
    delay = kwargs.get("delay", 15)
    min_grp_length = kwargs.get("min_grp_length", 150)
    heading_n_around = kwargs.get("heading_n_around", 10)
    window_len = n_before + n_after

    spont_angvels, spont_linvels, spont_heading_diffs = [], [], []
    stim_angvels, stim_linvels, stim_heading_diffs = [], [], []

    # Split the tracking data once instead of boolean-filtering the whole frame
    # for every stimulus row.
    trajectories = df.groupby(["obj_id", "exp_num"], sort=False)
    # Trajectories whose spontaneous saccades were already collected.
    spont_done = set()

    for _, row in tqdm(stim.iterrows(), total=len(stim)):
        obj_id = row["obj_id"]
        exp_num = row["exp_num"]
        frame = row["frame"]

        try:
            grp = trajectories.get_group((obj_id, exp_num))
        except KeyError:
            continue  # no tracking data for this stimulus

        n_frames = len(grp)
        if n_frames < min_grp_length:
            continue

        frames = grp["frame"].to_numpy()
        stim_matches = np.flatnonzero(frames == frame)
        if stim_matches.size == 0:
            continue  # stimulus frame is not part of the trajectory
        stim_idx = stim_matches[0]

        if stim_idx - n_before < 0 or stim_idx + n_after > n_frames:
            continue  # stimulus-locked window would run off the trajectory

        # Copy so smoothing never writes into an object derived from `df`.
        grp = sg_smooth(grp.copy())
        xvel = grp["xvel"].to_numpy()
        yvel = grp["yvel"].to_numpy()
        angular_velocity = get_angular_velocity(xvel, yvel)
        linear_velocity = get_linear_velocity(xvel, yvel)
        xy = grp[["x", "y"]].to_numpy()

        # Spontaneous saccades: once per trajectory, excluding the response
        # windows of all of its stimuli.
        traj_key = (obj_id, exp_num)
        if traj_key not in spont_done:
            spont_done.add(traj_key)
            traj_stim_frames = stim.loc[
                (stim["obj_id"] == obj_id) & (stim["exp_num"] == exp_num), "frame"
            ].to_numpy()
            stim_positions = np.flatnonzero(np.isin(frames, traj_stim_frames))
            for angvel, linvel, heading_diff in _spontaneous_saccade_windows(
                angular_velocity,
                linear_velocity,
                xy,
                stim_positions,
                n_before,
                n_after,
                delay,
                heading_n_around,
                kwargs,
            ):
                spont_angvels.append(angvel)
                spont_linvels.append(linvel)
                spont_heading_diffs.append(heading_diff)

        # Stimulus-locked window.
        stim_slice = slice(stim_idx - n_before, stim_idx + n_after)
        stim_angvels.append(angular_velocity[stim_slice])
        stim_linvels.append(linear_velocity[stim_slice])
        stim_heading_diffs.append(
            get_heading_difference_for_trajectory(
                xy[stim_slice], midpoint=n_before + delay, n_around=heading_n_around
            )
        )

    return {
        "spont_angvels": _stack_windows(spont_angvels, window_len),
        "stim_angvels": _stack_windows(stim_angvels, window_len),
        "spont_linvels": _stack_windows(spont_linvels, window_len),
        "stim_linvels": _stack_windows(stim_linvels, window_len),
        "spont_heading_diffs": np.asarray(spont_heading_diffs, dtype=float),
        "stim_heading_diffs": np.asarray(stim_heading_diffs, dtype=float),
    }


In [8]:
# Data and checkpoint locations. Override via environment variables when running
# on another machine; the defaults are the original workstation paths.
root_folder = os.environ.get("FLY_EXPERIMENTS_ROOT", "/home/buchsbaum/mnt/md0/Experiments/")
checkpoint_path = os.environ.get(
    "FLY_CHECKPOINT_DIR", "/home/buchsbaum/src/fly_analysis/notebooks/checkpoints/"
)

# Plots

## Angular velocity, linear velocity, and heading change per group

### DNp03

In [9]:
# Load the DNp03 (j53xu68) recordings.
print("Processing j53xu68 files")
j53xu68_files = [
    "20230321_162524.braidz",
    "20230519_130210.braidz",
]
j53_data = fa.braidz.read_multiple_braidz(j53xu68_files, root_folder)

Processing j53xu68 files
Reading /home/buchsbaum/mnt/md0/Experiments/20230321_162524.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20230519_130210.braidz using pyarrow


In [10]:
# Cache the expensive saccade extraction under a stable file name so that
# re-running the notebook reuses it. A per-run timestamp in the name would mean
# the file never exists on a fresh kernel. Delete the file to force
# recomputation, e.g. after changing get_all_saccade_data or its parameters.
pkl_file = os.path.join(checkpoint_path, "dnp03_saccade_data.pkl")

if os.path.exists(pkl_file):
    # pickle.load can execute arbitrary code: only load checkpoints this notebook wrote.
    with open(pkl_file, "rb") as f:
        j53_results = pickle.load(f)
else:
    j53_results = get_all_saccade_data(j53_data["df"], j53_data["stim"])
    os.makedirs(checkpoint_path, exist_ok=True)
    with open(pkl_file, "wb") as f:
        pickle.dump(j53_results, f)

100%|██████████| 531/531 [01:30<00:00,  5.85it/s]


In [73]:
fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

# angular velocity (absolute value, deg/s)
plot_mean_and_std(j53_results["spont_angvels"], label="Spont", ax=axs[0], take_abs=True, convert_to_degrees=True)
plot_mean_and_std(j53_results["stim_angvels"], label="Stim", ax=axs[0], take_abs=True, convert_to_degrees=True)

# linear velocity
plot_mean_and_std(j53_results["spont_linvels"], label="Spont", ax=axs[1])
plot_mean_and_std(j53_results["stim_linvels"], label="Stim", ax=axs[1])

# heading difference
plot_histogram(j53_results["spont_heading_diffs"], label="Spont", ax=axs[2])
plot_histogram(j53_results["stim_heading_diffs"], label="Stim", ax=axs[2])

for ax in axs[:2]:
    ax.set_xlim(-0.5, 1.0)
    ax.set_xlabel("Time (s)")

axs[1].set_ylabel("Linear Velocity (m/s)")
axs[2].set_xlabel("Direction change (rad)")
axs[2].set_ylabel("Probability Density")

# The Spont/Stim labels are only visible if a legend is drawn.
for ax in axs:
    ax.legend()

fig.suptitle("DNp03")
fig.tight_layout()

for ext in ("png", "svg", "pdf"):
    fig.savefig(f"/home/buchsbaum/src/fly_analysis/notebooks/Figures/dnp03_paper/DNp03.{ext}", dpi=300)

plt.close(fig)

### AX-Split

In [None]:
# Load the AX-Split (g29xu68) recordings.
print("Processing g29xu68 files")
g29xu68_files = [
    "20230512_144203.braidz",
    "20230203_145747.braidz",
]
g29_data = fa.braidz.read_multiple_braidz(g29xu68_files, root_folder)

In [None]:
# Cache the expensive saccade extraction under a stable file name. Delete the
# file to force recomputation, e.g. after changing get_all_saccade_data.
g29_pkl_file = os.path.join(checkpoint_path, "ax_split_saccade_data.pkl")

if os.path.exists(g29_pkl_file):
    # pickle.load can execute arbitrary code: only load checkpoints this notebook wrote.
    with open(g29_pkl_file, "rb") as f:
        g29_results = pickle.load(f)
else:
    g29_results = get_all_saccade_data(g29_data["df"], g29_data["stim"])
    os.makedirs(checkpoint_path, exist_ok=True)
    with open(g29_pkl_file, "wb") as f:
        pickle.dump(g29_results, f)

In [74]:
fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

# angular velocity (absolute value, deg/s)
plot_mean_and_std(g29_results["spont_angvels"], label="Spont", ax=axs[0], take_abs=True, convert_to_degrees=True)
plot_mean_and_std(g29_results["stim_angvels"], label="Stim", ax=axs[0], take_abs=True, convert_to_degrees=True)

# linear velocity
plot_mean_and_std(g29_results["spont_linvels"], label="Spont", ax=axs[1])
plot_mean_and_std(g29_results["stim_linvels"], label="Stim", ax=axs[1])

# heading difference
plot_histogram(g29_results["spont_heading_diffs"], label="Spont", ax=axs[2])
plot_histogram(g29_results["stim_heading_diffs"], label="Stim", ax=axs[2])

for ax in axs[:2]:
    ax.set_xlim(-0.5, 1.0)
    ax.set_xlabel("Time (s)")

axs[1].set_ylabel("Linear Velocity (m/s)")
axs[2].set_xlabel("Direction change (rad)")
axs[2].set_ylabel("Probability Density")

# The Spont/Stim labels are only visible if a legend is drawn.
for ax in axs:
    ax.legend()

fig.suptitle("AX-Split")
fig.tight_layout()

for ext in ("png", "svg", "pdf"):
    fig.savefig(f"/home/buchsbaum/src/fly_analysis/notebooks/Figures/dnp03_paper/AX-Split.{ext}", dpi=300)

plt.close(fig)

### Empty-Split

In [None]:
# Load the Empty-Split (emptyxu68) control recordings.
print("Processing emptyxu68 files")
emptyxu68_files = ["20231020_150051.braidz", "20230530_115028.braidz"]
empty_data = fa.braidz.read_multiple_braidz(emptyxu68_files, root_folder)

In [None]:
# Cache the expensive saccade extraction under a stable file name. Delete the
# file to force recomputation, e.g. after changing get_all_saccade_data.
empty_pkl_file = os.path.join(checkpoint_path, "empty_split_saccade_data.pkl")

if os.path.exists(empty_pkl_file):
    # pickle.load can execute arbitrary code: only load checkpoints this notebook wrote.
    with open(empty_pkl_file, "rb") as f:
        empty_split_results = pickle.load(f)
else:
    empty_split_results = get_all_saccade_data(empty_data["df"], empty_data["stim"])
    os.makedirs(checkpoint_path, exist_ok=True)
    with open(empty_pkl_file, "wb") as f:
        pickle.dump(empty_split_results, f)

In [75]:
fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

# angular velocity (absolute value, deg/s)
plot_mean_and_std(empty_split_results["spont_angvels"], label="Spont", ax=axs[0], take_abs=True, convert_to_degrees=True)
plot_mean_and_std(empty_split_results["stim_angvels"], label="Stim", ax=axs[0], take_abs=True, convert_to_degrees=True)

# linear velocity
plot_mean_and_std(empty_split_results["spont_linvels"], label="Spont", ax=axs[1])
plot_mean_and_std(empty_split_results["stim_linvels"], label="Stim", ax=axs[1])

# heading difference
plot_histogram(empty_split_results["spont_heading_diffs"], label="Spont", ax=axs[2])
plot_histogram(empty_split_results["stim_heading_diffs"], label="Stim", ax=axs[2])

for ax in axs[:2]:
    ax.set_xlim(-0.5, 1.0)
    ax.set_xlabel("Time (s)")

axs[1].set_ylabel("Linear Velocity (m/s)")
axs[2].set_xlabel("Direction change (rad)")
axs[2].set_ylabel("Probability Density")

# The Spont/Stim labels are only visible if a legend is drawn.
for ax in axs:
    ax.legend()

fig.suptitle("Empty-Split")
fig.tight_layout()

for ext in ("png", "svg", "pdf"):
    fig.savefig(f"/home/buchsbaum/src/fly_analysis/notebooks/Figures/dnp03_paper/Empty_Split.{ext}", dpi=300)

plt.close(fig)

## Combined comparison of stimulus-elicited responses (angular velocity, linear velocity, heading change)

In [82]:
fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

# Stimulus-elicited responses only: one trace/histogram per genotype.
# Each axes has its own colour cycle, so a genotype gets the same colour in
# every panel as long as the plotting order below is kept.
genotype_results = [
    ("Empty-Split", empty_split_results),
    ("AX-Split", g29_results),
    ("DNp03", j53_results),
]
for genotype, group_results in genotype_results:
    # angular velocity (absolute value, deg/s)
    plot_mean_and_std(group_results["stim_angvels"], label=genotype, ax=axs[0], take_abs=True, convert_to_degrees=True, peak_align=False)
    # linear velocity
    plot_mean_and_std(group_results["stim_linvels"], label=genotype, ax=axs[1], peak_align=False)
    # heading difference
    plot_histogram(group_results["stim_heading_diffs"], label=genotype, ax=axs[2])

# Shade 0-0.3 s after stimulus onset (presumably the stimulus period; confirm).
for ax in axs[:2]:
    ax.axvspan(0, 0.3, color="gray", alpha=0.2)
    ax.set_xlim(-0.5, 1.0)
    ax.set_xlabel("Time (s)")

axs[2].set_xlim(-np.pi, np.pi)
axs[2].set_xlabel("Direction change (rad)")

axs[0].set_ylabel("Angular Velocity (deg/s)")
axs[1].set_ylabel("Linear Velocity (m/s)")
axs[2].set_ylabel("Probability Density")

# The genotype labels are only visible if a legend is drawn.
for ax in axs:
    ax.legend()

fig.suptitle("Stimulus-Elicited Saccades")
fig.tight_layout()

for ext in ("png", "svg", "pdf"):
    fig.savefig(f"/home/buchsbaum/src/fly_analysis/notebooks/Figures/dnp03_paper/combined_analysis.{ext}", dpi=300)
plt.close("all")


In [None]:
# Polar (rose) histogram of stimulus-elicited direction changes per genotype,
# using the same 36 bins over [-pi, pi] as plot_histogram.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="polar")

bin_edges = np.linspace(-np.pi, np.pi, 37)
bin_widths = np.diff(bin_edges)
for genotype, group_results in [
    ("Empty-Split", empty_split_results),
    ("AX-Split", g29_results),
    ("DNp03", j53_results),
]:
    density, _ = np.histogram(group_results["stim_heading_diffs"], bins=bin_edges, density=True)
    ax.bar(bin_edges[:-1], density, width=bin_widths, align="edge", alpha=0.4, label=genotype)

ax.set_title("Stimulus-elicited direction change (probability density)")
ax.legend()

plt.show()