In [1]:
from flyanalysis import braidz
import pandas as pd
import numpy as np
import os

In [8]:
from flyanalysis import trajectory
from typing import Optional


def extract_stimulus_centered_data(
    df: pd.DataFrame,
    csv: pd.DataFrame,
    n_before: int = 50,
    n_after: int = 100,
    columns: Optional[list] = None,
    padding: Optional[int] = None,
    min_length: int = 150,
):
    """Collect windows of per-object data centred on each stimulus event.

    For every row of ``csv`` (one stimulus event, identified by its ``obj_id``
    and ``frame``), the rows of ``df`` belonging to that object are selected.
    The window of rows from ``n_before`` before to ``n_after`` after the
    stimulus row is then extracted. Both ends are inclusive, so a complete
    window has ``n_before + n_after + 1`` rows.

    Parameters
    ----------
    df : pd.DataFrame
        Table with at least ``obj_id`` and ``frame`` columns. It also needs
        ``x``, ``y`` and ``z`` columns when "position" is requested.
    csv : pd.DataFrame
        Stimulus events, with ``obj_id`` and ``frame`` columns.
    n_before, n_after : int
        Number of rows to take before / after the stimulus row.
    columns : list of str, optional
        Quantities to extract. Supported names are "angular_velocity",
        "linear_velocity" and "position". Defaults to all three. Unsupported
        names are reported once with ``print`` and map to an empty list.
    padding : int, optional
        Maximum number of rows the window may extend past either end of the
        object's data.
        - Events that overflow by more than this are skipped.
        - Accepted events are NaN-padded (as float arrays) to the full window
          length.
        - If None, no overflow check is done, and windows at the edges are
          returned truncated, so their lengths can differ.
    min_length : int
        Objects with fewer rows than this are skipped (default 150).

    Returns
    -------
    dict
        Maps each requested column name to a list of numpy arrays, one per
        accepted event, in ``csv`` row order. Position windows have shape
        (rows, 3). Velocity windows have whatever trailing shape
        ``trajectory.get_*_velocity`` returns (presumably 1-D; TODO confirm).
        They are always padded along the first axis only.
    """
    if columns is None:
        columns = ["angular_velocity", "linear_velocity", "position"]

    # How each supported quantity is derived from one object's rows.
    # The lambdas defer looking up `trajectory` until a velocity is actually requested.
    extractors = {
        "angular_velocity": lambda grp: trajectory.get_angular_velocity(grp),
        "linear_velocity": lambda grp: trajectory.get_linear_velocity(grp),
        "position": lambda grp: grp[["x", "y", "z"]].values,
    }

    data_dict = {col: [] for col in columns}
    for col in columns:
        if col not in extractors:
            # Report unknown names once, not once per stimulus event.
            print(f"Column {col} not found")
    wanted = [col for col in columns if col in extractors]

    def pad_rows(segment, before_overflow, after_overflow):
        """Return ``segment`` as an array, NaN-padded along axis 0 when padding is enabled."""
        segment = np.asarray(segment)
        if padding is None:
            return segment
        # NaN can only be stored in a float array; integer input would make np.pad raise.
        segment = segment.astype(float)
        # Pad only the first (row) axis; trailing axes such as x/y/z are untouched.
        pad_width = [(before_overflow, after_overflow)] + [(0, 0)] * (segment.ndim - 1)
        return np.pad(segment, pad_width, constant_values=np.nan)

    groups = {}  # obj_id -> rows of df for that object
    derived = {}  # (obj_id, column) -> full-length array for that object

    for _, row in csv.iterrows():
        obj_id = int(row["obj_id"])
        frame = int(row["frame"])

        # Filter df (and later compute velocities) only once per object,
        # even when the object has several stimulus events.
        if obj_id not in groups:
            groups[obj_id] = df[df.obj_id == obj_id]
        grp = groups[obj_id]
        n_rows = len(grp)

        if n_rows < min_length:
            continue

        # Positional index of the stimulus frame within this object's rows.
        matches = np.flatnonzero(grp.frame.to_numpy() == frame)
        if matches.size == 0:
            continue
        stim_idx = int(matches[0])

        # Requested window, inclusive at both ends.
        idx_before = stim_idx - n_before
        idx_after = stim_idx + n_after

        # How far the window sticks out past the first / last available row.
        before_overflow = max(0, -idx_before)
        after_overflow = max(0, idx_after - (n_rows - 1))

        if padding is not None and max(before_overflow, after_overflow) > padding:
            continue

        # Clip the window to rows that exist (stop is exclusive).
        start = max(0, idx_before)
        stop = min(n_rows, idx_after + 1)

        for col in wanted:
            key = (obj_id, col)
            if key not in derived:
                derived[key] = np.asarray(extractors[col](grp))
            data_dict[col].append(
                pad_rows(derived[key][start:stop], before_overflow, after_overflow)
            )

    return data_dict

In [3]:
# Directory holding the raw .braidz recordings. On another machine, set the
# FLY_EXPERIMENTS_DIR environment variable instead of editing this path.
root_folder = os.environ.get(
    "FLY_EXPERIMENTS_DIR", "/home/buchsbaum/mnt/md0/Experiments/"
)
wtcs_file = "20230206_141606.braidz"
filepath = os.path.join(root_folder, wtcs_file)

In [4]:
# Load the recording. This yields the main per-object table plus the auxiliary
# CSV tables (presumably a dict keyed by table name; the stimulus events are
# read below as csvs["stim"]).
(df, csvs) = braidz.read_braidz(filepath)

Reading /home/buchsbaum/mnt/md0/Experiments/20230206_141606.braidz using pyarrow


In [10]:
# `padding` is the maximum number of rows a window may overflow the object's data
# (overflowing rows are NaN-filled). The function has no `max_overflow` keyword,
# so passing one raises TypeError.
wtcs_data = extract_stimulus_centered_data(df, csvs["stim"], padding=10)