In [37]:
%load_ext autoreload
%autoreload 2
import glob
import os
import pathlib
import re
import subprocess
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap
from joblib import Parallel, delayed
from scipy.signal import find_peaks, savgol_filter
from scipy.stats import circmean
from sklearn.preprocessing import StandardScaler
from tqdm.contrib.concurrent import process_map
from tqdm.notebook import tqdm

from extract_stimulus_heading_for_camera import process_braidz_file, create_interpolation_function

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
# Lookup table, stored as CSV text, that maps a `screen` position to a `heading` angle.
# The heading values lie within about -pi..pi (so they look like radians), and the
# first and last rows (screen 0 and 640) carry the same heading.
# NOTE(review): `screen` is presumably a horizontal pixel coordinate — confirm.
SCREEN2HEADING_DATA = """
screen,heading
0,2.3513283485530456
80,1.2179812647799937
160,0.5031545295746856
240,-0.3078141744904855
320,-0.8746949393526915
400,-1.5019022477483523
480,-2.185375561680841
560,-3.0123437340031307
640,2.3513283485530456
"""
# NOTE(review): SCREEN2HEADING_DATA is not passed to create_interpolation_function()
# and is not referenced again in the visible part of this notebook; presumably the
# imported helper carries its own copy of the table — confirm, or pass it explicitly.
get_heading_func = create_interpolation_function()

In [29]:
# Setup paths
# Directory that holds the Braid recordings, and every *.braidz file found directly in it.
braidz_path = "/gpfs/soma_fs/nfc/nfc3008/Experiments/"
braidz_files = glob.glob(os.path.join(braidz_path, "*.braidz"))

# Directory of SLEAP predictions; every entry directly below it is listed (files and folders).
slp_path = "/gpfs/soma_fs/home/buchsbaum/sleap_projects/highspeed/predictions/"
slp_folders = glob.glob(os.path.join(slp_path, "*",))

# NOTE(review): output_path is not referenced by the cells visible here (the braidz
# extraction loop writes into the relative folder "data") — confirm the intended target.
output_path = "/gpfs/soma_fs/home/buchsbaum/src/sleap_video_analysis/output"

In [42]:
# Loop over slp files and find the corresponding braidz files
# Candidates are sorted so that "the first match" is deterministic (glob order is not),
# and matching is done on the file name only, never on the directory part of the path.
sorted_braidz_files = sorted(braidz_files)

for slp_folder in tqdm(slp_folders):
    filename = pathlib.Path(slp_folder).stem
    date = filename.split("_")[0]

    # look for a recording whose file name contains the slp folder name
    candidates = [f for f in sorted_braidz_files if filename in os.path.basename(f)]
    braidz_file = candidates[0] if candidates else None
    if braidz_file is None:
        print(f"No braidz file found for {filename}, trying to search for another file with date {date}")

        # fall back to a recording from the same date but a different time
        candidates = [
            f for f in sorted_braidz_files if os.path.basename(f).startswith(date)
        ]
        braidz_file = candidates[0] if candidates else None

        if braidz_file is None:
            print(f"No braidz file found for {filename}, skipping")
            continue

    # the file list was collected earlier; make sure the recording is still there
    if not os.path.exists(braidz_file):
        print(f"Braidz file {braidz_file} does not exist, skipping")
        continue

    # extract stimulus information from braidz file
    process_braidz_file(braidz_file, "data", get_heading_func)

  0%|          | 0/36 [00:00<?, ?it/s]

Interpolated /gpfs/soma_fs/nfc/nfc3008/Experiments/20250421_174810.braidz to data/20250421_174810.csv
No stim or opto data found in /gpfs/soma_fs/nfc/nfc3008/Experiments/20250120_173635.braidz
No stim or opto data found in /gpfs/soma_fs/nfc/nfc3008/Experiments/20250125_150808.braidz
Interpolated /gpfs/soma_fs/nfc/nfc3008/Experiments/20250430_155345.braidz to data/20250430_155345.csv
Interpolated /gpfs/soma_fs/nfc/nfc3008/Experiments/20241116_154109.braidz to data/20241116_154109.csv
Interpolated /gpfs/soma_fs/nfc/nfc3008/Experiments/20241112_124059.braidz to data/20241112_124059.csv
No braidz file found for 20250504_191950, trying to search for another file with date 20250504
No braidz file found for 20250504_191950, skipping
No braidz file found for 20241202_153526, trying to search for another file with date 20241202
No braidz file found for 20241202_153526, skipping
No stim or opto data found in /gpfs/soma_fs/nfc/nfc3008/Experiments/20250122_161817.braidz
Interpolated /gpfs/soma_fs/

# Helper functions

In [None]:
def calculate_heading_difference(a1, a2):
    """Return the signed angular difference ``a1 - a2`` in radians.

    The raw difference is passed through ``arctan2(sin, cos)`` so the result is
    wrapped into the [-pi, pi] interval. Works element-wise on arrays.
    """
    delta = a1 - a2
    sin_delta, cos_delta = np.sin(delta), np.cos(delta)
    return np.arctan2(sin_delta, cos_delta)


def sg_smooth(array, window_length=51, polyorder=3, **kwargs):
    """Smooth ``array`` with a Savitzky-Golay filter.

    ``window_length`` and ``polyorder`` default to 51 and 3; any further keyword
    arguments are forwarded unchanged to ``scipy.signal.savgol_filter``.
    """
    filter_options = {"window_length": window_length, "polyorder": polyorder}
    filter_options.update(kwargs)
    return savgol_filter(array, **filter_options)


def unwrap_with_nan(array):
    """Unwrap a phase series that may contain NaNs.

    The non-NaN samples are unwrapped with ``np.unwrap`` as if they were
    adjacent; NaN positions are left as NaN.

    Parameters:
        array (array_like): Angles in radians, possibly containing NaN.

    Returns:
        np.ndarray: A new float array with the unwrapped angles. The input is
        not modified.
    """
    # np.array copies, so the caller's data is never overwritten; the float dtype
    # keeps the unwrapped values from being truncated for integer input.
    unwrapped = np.array(array, dtype=float)
    has_value = ~np.isnan(unwrapped)
    unwrapped[has_value] = np.unwrap(unwrapped[has_value])
    return unwrapped


def detect_tracking_gaps(df, min_tracked_frames=10, min_gap_size=20):
    """
    Detects if there are multiple tracking sections separated by NaN gaps in the data.

    A frame counts as tracked when head.x, head.y, abdomen.x and abdomen.y are all
    non-NaN. Consecutive tracked frames form a section; sections shorter than
    `min_tracked_frames` are ignored. The gap between two neighbouring retained
    sections is the number of frames lying between them (this includes any ignored
    short section in between).

    Parameters:
        df (pd.DataFrame): The dataframe with tracking data (complete_df)
        min_tracked_frames (int): Minimum consecutive frames to consider a valid tracking section
        min_gap_size (int): Minimum size of NaN gap to alert about

    Returns:
        bool: True if at least two retained sections are separated by a gap of
        `min_gap_size` frames or more, otherwise False
    """
    tracked_columns = ["head.x", "head.y", "abdomen.x", "abdomen.y"]
    valid = df[tracked_columns].notna().all(axis=1).to_numpy()

    if not valid.any():
        return False

    # Pad with zeros so sections touching the first or last frame also produce
    # a +1 (section start) and a -1 (one past the section end) edge.
    edges = np.diff(np.concatenate(([0], valid.astype(int), [0])))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1) - 1  # inclusive last frame of each section

    # Keep only sections that are long enough
    long_enough = (ends - starts + 1) >= min_tracked_frames
    starts, ends = starts[long_enough], ends[long_enough]

    # With zero or one retained section there is nothing to alert about
    if len(starts) < 2:
        return False

    # Number of frames strictly between each pair of neighbouring sections
    gap_sizes = starts[1:] - ends[:-1] - 1
    return bool(np.any(gap_sizes >= min_gap_size))


def savgol_filter_with_nans(y, window_length, polyorder, **kwargs):
    """
    Savitzky-Golay smoothing for data that contains NaNs.

    Every run of consecutive non-NaN samples is filtered on its own; runs shorter
    than `window_length` are copied through unfiltered and NaNs stay where they are.

    Parameters:
    -----------
    y : array_like
        The data to be filtered
    window_length : int
        The length of the filter window
    polyorder : int
        The order of the polynomial used to fit the samples
    **kwargs : dict
        Additional arguments to pass to savgol_filter

    Returns:
    --------
    y_filtered : ndarray
        A filtered copy of `y`; the input itself is left untouched
    """
    smoothed = np.copy(y)
    is_valid = ~np.isnan(y)

    # Nothing to filter when every sample is NaN
    if not np.any(is_valid):
        return smoothed

    # Zero-padding turns each run of valid samples into a +1 edge at its first
    # index and a -1 edge at the index just past its last sample.
    padded = np.concatenate(([0], is_valid.astype(int), [0]))
    edges = np.diff(padded)
    run_starts = np.where(edges == 1)[0]
    run_stops = np.where(edges == -1)[0]

    for lo, hi in zip(run_starts, run_stops):
        if hi - lo < window_length:
            continue  # too short for the window: keep the raw values
        smoothed[lo:hi] = savgol_filter(y[lo:hi], window_length, polyorder, **kwargs)

    return smoothed

In [None]:
def process_data(stim_csvs_folder, pre_range=[0, 400], post_range=[400, 750]):
    """
    Compute pre/post-stimulus heading changes for every stimulus event.

    This function accepts a folder with all the csv files that contain the stimulus data as
    extracted from the braid recording. For each stim csv it looks in the sub-folder of the
    same name (without ".csv") for the tracking csv whose file name contains
    "obj_id_<obj_id>_frame_<frame>" matching the stim row.

    The tracking data is placed on a fixed 0-749 frame grid, short gaps are linearly
    interpolated, the head/abdomen coordinates are smoothed, and the circular mean heading
    (abdomen -> head) is computed in the pre and post ranges. Rows are skipped when no
    tracking file matches, when the tracking file has fewer than 51 rows, or when either
    range has fewer than 10 fully tracked frames.

    Parameters:
        stim_csvs_folder (str): The folder containing the stimulus CSV files.
        pre_range (list): [start, stop) frames to consider for pre-stimulus data.
        post_range (list): [start, stop) frames to consider for post-stimulus data.

    Returns:
        pd.DataFrame: One row per processed stimulus event: the original stim row plus
        prepost_heading_difference, prestim_heading_difference, poststim_heading_difference,
        pre_heading, post_heading, turned_away and turn_direction. Empty if nothing matched.
    """
    n_frames = 750  # every clip is laid out on frame_idx 0..749
    smoothing_window = 51  # Savitzky-Golay window; clips with fewer rows are skipped
    min_valid_frames = 10  # fully tracked frames required in each of the two ranges
    tracked_columns = ["head.x", "head.y", "abdomen.x", "abdomen.y"]
    pattern = re.compile(r"obj_id_(\d+)_frame_(\d+)")

    # Collected result rows (one dict per processed stimulus event)
    all_data = []

    # get all csv files in the stim_csvs_folder (these are the stim files)
    stim_csvs = sorted(glob.glob(os.path.join(stim_csvs_folder, "*.csv")))

    for stim_csv in stim_csvs:
        print(f"==== Processing {stim_csv} ====")
        stim_df = pd.read_csv(stim_csv)

        # tracking files live in a folder named after the stim file
        slp2csv_folder = os.path.join(
            stim_csvs_folder,
            os.path.basename(os.path.normpath(stim_csv)).replace(".csv", ""),
        )

        # Index the tracking files once per stim file instead of re-scanning them for
        # every row; setdefault keeps the first file (in sorted order) for a given key.
        files_by_event = {}
        for path in sorted(glob.glob(os.path.join(slp2csv_folder, "*.csv"))):
            match = pattern.search(path)
            if match:
                key = (int(match.group(1)), int(match.group(2)))
                files_by_event.setdefault(key, path)

        for _, row in stim_df.iterrows():
            stim_obj_id = int(row["obj_id"])
            stim_frame = int(row["frame"])
            stim_heading = float(row["stim_heading"])

            # if no matching file was found, skip
            matching_file = files_by_event.get((stim_obj_id, stim_frame))
            if matching_file is None:
                continue

            data_df = pd.read_csv(matching_file)

            # too few tracked frames to smooth
            if len(data_df) < smoothing_window:
                continue

            # Build an all-NaN frame with the same columns/dtypes on the full frame grid
            complete_df = pd.DataFrame(columns=data_df.columns)
            for col in data_df.columns:
                complete_df[col] = complete_df[col].astype(data_df[col].dtype)
            complete_df["frame_idx"] = np.arange(n_frames)

            # Copy the tracked values onto the grid, aligned on frame_idx
            complete_df = complete_df.set_index("frame_idx")
            complete_df.update(data_df.set_index("frame_idx"))
            complete_df = complete_df.reset_index()

            # Fill short tracking gaps (at most 25 frames in each direction)
            data_df_interp = complete_df.interpolate(
                method="linear", limit_direction="both", limit=25
            )

            # extract all data and apply smoothing
            frames = data_df_interp["frame_idx"].to_numpy()
            head_x, head_y, abdomen_x, abdomen_y = (
                savgol_filter_with_nans(
                    data_df_interp[column].to_numpy(),
                    window_length=smoothing_window,
                    polyorder=3,
                )
                for column in tracked_columns
            )

            # frames where all four coordinates are available after smoothing
            tracked = ~(
                np.isnan(head_x)
                | np.isnan(head_y)
                | np.isnan(abdomen_x)
                | np.isnan(abdomen_y)
            )
            pre_indices = np.flatnonzero(
                tracked & (frames >= pre_range[0]) & (frames < pre_range[1])
            )
            post_indices = np.flatnonzero(
                tracked & (frames >= post_range[0]) & (frames < post_range[1])
            )

            # Skip if not enough valid frames in these ranges
            if len(pre_indices) < min_valid_frames or len(post_indices) < min_valid_frames:
                continue

            # heading for each frame (angle of vector from abdomen to head)
            heading = np.arctan2(head_y - abdomen_y, head_x - abdomen_x)

            # circular means account for the wrap-around of angle data
            pre_heading_mean = circmean(heading[pre_indices], high=np.pi, low=-np.pi)
            post_heading_mean = circmean(heading[post_indices], high=np.pi, low=-np.pi)

            # post minus pre: how much (and in which direction) the fly turned
            prepost_heading_difference = calculate_heading_difference(
                post_heading_mean, pre_heading_mean
            )
            # stimulus minus pre: where the stimulus was relative to the fly
            prestim_heading_difference = calculate_heading_difference(
                stim_heading, pre_heading_mean
            )
            # stimulus minus post
            poststim_heading_difference = calculate_heading_difference(
                stim_heading, post_heading_mean
            )

            # Opposite signs of "stimulus side" and "turn side" mean the fly
            # turned away from the stimulus; a zero product counts as "Toward".
            stimulus_direction = np.sign(prestim_heading_difference)
            turn_direction = np.sign(prepost_heading_difference)
            turned_away = stimulus_direction * turn_direction < 0

            # Copy the stim row and add the new calculations
            row_data = row.to_dict()
            row_data["prepost_heading_difference"] = prepost_heading_difference
            row_data["prestim_heading_difference"] = prestim_heading_difference
            row_data["poststim_heading_difference"] = poststim_heading_difference
            row_data["pre_heading"] = pre_heading_mean
            row_data["post_heading"] = post_heading_mean
            row_data["turned_away"] = turned_away
            row_data["turn_direction"] = "Away" if turned_away else "Toward"
            all_data.append(row_data)

    # Create a DataFrame from all the collected data
    result_df = pd.DataFrame(all_data)
    print(f"Combined DataFrame has {len(result_df)} rows")
    return result_df

In [None]:
# Run the heading analysis once per data folder, using the default pre/post frame ranges.
# NOTE(review): "canton-s" and "native" presumably name the two fly groups — confirm.
cs_results_df = process_data(
    "/gpfs/soma_fs/home/buchsbaum/src/sleap_video_analysis/output/canton-s/"
)
native_results_df = process_data(
    "/gpfs/soma_fs/home/buchsbaum/src/sleap_video_analysis/output/native/"
)

In [None]:
# Sort each row into an angular sector according to `prestim_heading_difference` (radians):
# [-45, +45] deg -> "front", (+45, +135] deg -> "right", [-135, -45) deg -> "left", else "back"
def bin_heading_difference(row):
    """Return "front", "right", "left" or "back" for ``row["prestim_heading_difference"]``.

    Values that fall in none of the three bounded sectors (including NaN) map to "back".
    """
    angle = row["prestim_heading_difference"]

    if -np.pi / 4 <= angle <= np.pi / 4:
        return "front"
    if np.pi / 4 < angle <= 3 * np.pi / 4:
        return "right"
    if -3 * np.pi / 4 <= angle < -np.pi / 4:
        return "left"
    return "back"


# Apply the function to create a new column
# Each result table gets a "prestim_heading_bin" column, computed row by row (axis=1).
# This adds the column to the existing DataFrames in place.
cs_results_df["prestim_heading_bin"] = cs_results_df.apply(
    bin_heading_difference, axis=1
)
native_results_df["prestim_heading_bin"] = native_results_df.apply(
    bin_heading_difference, axis=1
)

In [None]:
def plot_histogram_for_all_bins(results_df):
    """Draw one density histogram of `prepost_heading_difference` per heading bin.

    For every distinct value of `prestim_heading_bin` a separate figure is created
    with a 30-edge histogram over [-pi, pi], a KDE overlay, and a dashed vertical
    line at the circular mean of the bin's values. Each figure is shown and then
    closed so repeated calls do not accumulate open figures.

    Parameters:
        results_df (pd.DataFrame): Table with the columns `prestim_heading_bin`
            and `prepost_heading_difference`.
    """
    for heading_bin in results_df["prestim_heading_bin"].unique():
        fig, ax = plt.subplots()
        bin_df = results_df[results_df["prestim_heading_bin"] == heading_bin]
        sns.histplot(
            data=bin_df,
            x="prepost_heading_difference",
            bins=np.linspace(-np.pi, np.pi, 30),
            common_bins=True,
            stat="density",
            label=heading_bin,
            ax=ax,
            common_norm=True,
            kde=True,
        )

        # mark the circular mean of the heading change for this bin
        mean = circmean(bin_df["prepost_heading_difference"], high=np.pi, low=-np.pi)

        ax.axvline(mean, color="g", linestyle="--", label="Mean")
        ax.set_title(f"Heading Difference Histogram - {heading_bin}")
        ax.legend()  # the labels above are only visible once a legend is drawn
        fig.tight_layout()
        plt.show()
        plt.close(fig)


plot_histogram_for_all_bins(cs_results_df)
plot_histogram_for_all_bins(native_results_df)

# Clustering

In [None]:
def get_data_for_clustering(stim_csvs_folder, window=0):
    """
    Extract kinematic features around the turn response of every stimulus event.

    For each stim csv in `stim_csvs_folder`, the tracking csv matching each row
    ("obj_id_<obj_id>_frame_<frame>" in the file name, inside the sub-folder named after
    the stim csv) is laid out on a 0-749 frame grid, gap-filled and smoothed. Angular
    velocity peaks above 1000 deg/s (either sign) are detected, and the first peak with
    350 < index < 450 is taken as the response. Events without such a peak, with NaN
    heading in the selected frames, or whose window would leave the frame grid are skipped.

    Parameters:
        stim_csvs_folder (str): The folder containing the stimulus CSV files.
        window (int): Half-width of the frame window around the response peak. With 0
            only the peak frame is returned; otherwise frames
            [peak - window, peak + window) are returned.

    Returns:
        pd.DataFrame: Concatenated per-event rows with the columns file, time, x, y,
        heading, heading_change, distance_between_head_and_abdomen, centroid_velocity,
        centroid_acceleration and angular_velocity. The index restarts at 0 for every
        event. An empty frame with these columns is returned when no event qualifies.
    """
    feature_columns = [
        "file",
        "time",
        "x",
        "y",
        "heading",
        "heading_change",
        "distance_between_head_and_abdomen",
        "centroid_velocity",
        "centroid_acceleration",
        "angular_velocity",
    ]
    n_frames = 750  # every clip is laid out on frame_idx 0..749
    dt = 1 / 500  # sample spacing used for derivatives — presumably a 500 fps camera, confirm
    pattern = re.compile(r"obj_id_(\d+)_frame_(\d+)")

    # One feature table per processed event
    all_features = []

    # get all csv files in the stim_csvs_folder (these are the stim files)
    stim_csvs = sorted(glob.glob(os.path.join(stim_csvs_folder, "*.csv")))

    for i, stim_csv in enumerate(stim_csvs):
        print(f"Processing file {stim_csv} ({i} out of {len(stim_csvs)})")
        stim_df = pd.read_csv(stim_csv)

        # tracking files live in a folder named after the stim file
        slp2csv_folder = os.path.join(
            stim_csvs_folder,
            os.path.basename(os.path.normpath(stim_csv)).replace(".csv", ""),
        )

        # Index the tracking files once per stim file; setdefault keeps the first
        # file (in sorted order) for a given (obj_id, frame) key.
        files_by_event = {}
        for path in sorted(glob.glob(os.path.join(slp2csv_folder, "*.csv"))):
            match = pattern.search(path)
            if match:
                key = (int(match.group(1)), int(match.group(2)))
                files_by_event.setdefault(key, path)

        for _, row in stim_df.iterrows():
            stim_obj_id = int(row["obj_id"])
            stim_frame = int(row["frame"])

            # if no matching file was found, skip
            matching_file = files_by_event.get((stim_obj_id, stim_frame))
            if matching_file is None:
                continue

            data_df = pd.read_csv(matching_file)

            # Build an all-NaN frame with the same columns/dtypes on the full frame grid
            complete_df = pd.DataFrame(columns=data_df.columns)
            for col in data_df.columns:
                complete_df[col] = complete_df[col].astype(data_df[col].dtype)
            complete_df["frame_idx"] = list(range(n_frames))

            # Copy the tracked values onto the grid, aligned on frame_idx
            complete_df = complete_df.set_index("frame_idx")
            complete_df.update(data_df.set_index("frame_idx"))
            complete_df = complete_df.reset_index()

            # Fill short tracking gaps (at most 25 frames in each direction)
            data_df_interp = complete_df.interpolate(
                method="linear", limit_direction="both", limit=25
            )

            # extract all data and apply smoothing
            frames = data_df_interp["frame_idx"].to_numpy()
            head_x, head_y, abdomen_x, abdomen_y = (
                savgol_filter_with_nans(
                    data_df_interp[column].to_numpy(), window_length=51, polyorder=3
                )
                for column in ("head.x", "head.y", "abdomen.x", "abdomen.y")
            )

            # Body-axis heading. The unwrapped series is used both as the "heading"
            # feature and as the input of the derivative below.
            heading = unwrap_with_nan(
                np.arctan2(head_y - abdomen_y, head_x - abdomen_x)
            )
            heading_change = savgol_filter(
                np.gradient(heading, dt), window_length=21, polyorder=3
            )
            distance_between_head_and_abdomen = np.sqrt(
                (head_x - abdomen_x) ** 2 + (head_y - abdomen_y) ** 2
            )

            # centroid = midpoint between head and abdomen
            centroid_x = (head_x + abdomen_x) / 2
            centroid_y = (head_y + abdomen_y) / 2
            centroid_velocity_x = savgol_filter(
                np.gradient(centroid_x, dt), window_length=21, polyorder=3
            )
            centroid_velocity_y = savgol_filter(
                np.gradient(centroid_y, dt), window_length=21, polyorder=3
            )
            centroid_velocity = np.sqrt(centroid_velocity_x**2 + centroid_velocity_y**2)
            centroid_acceleration = savgol_filter(
                np.gradient(centroid_velocity, dt), window_length=21, polyorder=3
            )

            # direction of travel and its rate of change
            theta = unwrap_with_nan(np.arctan2(centroid_velocity_y, centroid_velocity_x))
            angular_velocity = savgol_filter(
                np.gradient(theta, dt), window_length=21, polyorder=3
            )

            # peaks of either sign in the angular velocity
            positive_peaks, _ = find_peaks(
                angular_velocity, height=np.deg2rad(1000), distance=50
            )
            negative_peaks, _ = find_peaks(
                -angular_velocity, height=np.deg2rad(1000), distance=50
            )
            all_peaks = np.sort(np.concatenate((positive_peaks, negative_peaks)))

            # the response is the first peak inside the (350, 450) frame range
            response_peaks = all_peaks[(all_peaks > 350) & (all_peaks < 450)]
            if response_peaks.size == 0:
                continue
            response_peak = int(response_peaks[0])

            if window == 0:
                # a one-element array keeps every feature 1-D, like the windowed case
                indices = np.array([response_peak])
            else:
                indices = np.arange(response_peak - window, response_peak + window)
                # A window reaching past the frame grid would wrap around through
                # negative indices or raise; skip such events instead.
                if indices.size == 0 or indices[0] < 0 or indices[-1] >= len(frames):
                    continue

            heading = heading[indices]
            if np.any(np.isnan(heading)):
                continue

            temp_dict = {
                "file": np.full(len(indices), i),
                "time": frames[indices],
                "x": centroid_x[indices],
                "y": centroid_y[indices],
                "heading": heading,
                "heading_change": heading_change[indices],
                "distance_between_head_and_abdomen": distance_between_head_and_abdomen[
                    indices
                ],
                "centroid_velocity": centroid_velocity[indices],
                "centroid_acceleration": centroid_acceleration[indices],
                "angular_velocity": angular_velocity[indices],
            }
            all_features.append(pd.DataFrame(temp_dict))

    # pd.concat raises on an empty list, so return an empty table with the same columns
    if not all_features:
        return pd.DataFrame(columns=feature_columns)
    return pd.concat(all_features)

In [None]:
# Extract clustering features for both data folders, 25 frames on either side of each response peak.
# NOTE(review): these folders omit the "output/" path component that the process_data
# calls use (".../sleap_video_analysis/output/canton-s/") — confirm which location is correct.
cs_features = get_data_for_clustering(
    "/gpfs/soma_fs/home/buchsbaum/src/sleap_video_analysis/canton-s/", window=25
)
native_features = get_data_for_clustering(
    "/gpfs/soma_fs/home/buchsbaum/src/sleap_video_analysis/native/", window=25
)

In [None]:
# Give every row a unique index and linearly fill any remaining NaNs.
# The tables are rebound instead of mutated with inplace=True.
# NOTE(review): interpolate() runs down the whole concatenated table, so a NaN at the
# edge of one event is filled using rows of the neighbouring event — confirm this is intended.
cs_features = cs_features.reset_index(drop=True).interpolate()
native_features = native_features.reset_index(drop=True).interpolate()

# print number of nan values in each column
print("CS features:")
print(cs_features.isna().sum())
print("Native features:")
print(native_features.isna().sum())

In [None]:
# 1. Standardize the selected feature columns, guarding against zero variance
def standardize_with_checks(df, columns_to_transform):
    """Return a standardized copy of `df` together with the fitted scaler.

    Columns are selected by position (`columns_to_transform` is used with `.iloc`).
    A selected column whose standard deviation is below 1e-10 gets a warning and
    N(0, 1e-5) noise added before scaling. The input frame is not modified.

    NOTE(review): the noise is drawn from the unseeded global NumPy RNG, and after
    scaling such a column consists of rescaled noise only — confirm this is wanted.

    Returns:
        tuple: (standardized DataFrame copy, fitted StandardScaler)
    """
    df_std = df.copy()

    for name in df.iloc[:, columns_to_transform].columns:
        if not df[name].std() < 1e-10:
            continue
        print(f"Warning: Column {name} has near-zero variance")
        # tiny jitter so the column no longer has zero variance
        df_std[name] += np.random.normal(0, 1e-5, size=len(df))

    scaler = StandardScaler()
    scaled_values = scaler.fit_transform(df_std.iloc[:, columns_to_transform])
    df_std.iloc[:, columns_to_transform] = scaled_values

    return df_std, scaler


# 2. UMAP embedding + scatter plot, with metric-specific parameters prepared up front
def draw_umap_safe(
    data,
    features,
    n_neighbors=15,
    min_dist=0.1,
    n_components=2,
    metric="euclidean",
    title="",
):
    """Fit a UMAP embedding on selected columns of `data` and plot it.

    Parameters:
        data (pd.DataFrame): Feature table; must contain a "file" column, which is
            used to colour the points in the 2-D and 3-D plots.
        features: Positional column selection (used with `.iloc`) fed to UMAP.
        n_neighbors, min_dist, n_components, metric: Passed through to `umap.UMAP`.
        title (str): Figure title.

    Returns:
        tuple: (fitted umap.UMAP instance, embedding array from fit_transform)

    Requires the `umap` module to be imported at notebook level.
    """
    # Get the data matrix
    X = data.iloc[:, features].to_numpy()

    # Set up metric-specific parameters
    # NOTE(review): both metrics below pass their parameter under the key "V";
    # confirm this is the keyword umap expects for each of them.
    metric_params = {}

    if metric == "mahalanobis":
        # np.cov returns a 0-d array for a single feature; keep it 2-D so the
        # regularization and inversion below work for any number of features
        cov = np.atleast_2d(np.cov(X, rowvar=False))
        # Add small regularization to ensure positive definiteness
        cov += np.eye(cov.shape[0]) * 1e-6
        try:
            inv_cov = np.linalg.inv(cov)
        except np.linalg.LinAlgError:
            print(
                "Warning: Covariance matrix is singular, using stronger regularization"
            )
            cov += np.eye(cov.shape[0]) * 1e-4
            inv_cov = np.linalg.inv(cov)
        metric_params = {"V": inv_cov}

    elif metric == "seuclidean":
        # Per-dimension sample variance, floored so no entry is zero
        variances = np.maximum(np.var(X, axis=0, ddof=1), 1e-8)
        metric_params = {"V": variances}

    # Create UMAP with proper parameters
    fit = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        metric=metric,
        metric_kwds=metric_params,
        random_state=42,  # For reproducibility
    )

    u = fit.fit_transform(X)

    # Create visualization (only 1, 2 or 3 components are drawn)
    fig = plt.figure(figsize=(10, 8))
    if n_components == 1:
        ax = fig.add_subplot(111)
        ax.scatter(u[:, 0], range(len(u)))
    elif n_components == 2:
        ax = fig.add_subplot(111)
        scatter = ax.scatter(u[:, 0], u[:, 1], c=data["file"], alpha=0.7)
        plt.colorbar(scatter, label="File")
    elif n_components == 3:
        ax = fig.add_subplot(111, projection="3d")
        scatter = ax.scatter(u[:, 0], u[:, 1], u[:, 2], c=data["file"], alpha=0.7)
        plt.colorbar(scatter, label="File")

    plt.title(title, fontsize=18)
    plt.tight_layout()

    return fit, u


# 3. Use these functions in your main code
# Positional columns of the feature table: heading, heading_change,
# distance_between_head_and_abdomen, centroid_velocity, centroid_acceleration,
# angular_velocity (columns 0-3 are file, time, x, y).
columns_to_transform = [4, 5, 6, 7, 8, 9]

# Standardize data with checks
cs_features_std, scaler = standardize_with_checks(cs_features, columns_to_transform)

# Parameter grid for the UMAP sweep
neighbours = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
min_dists = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25]
metrics = ["euclidean", "mahalanobis", "correlation", "cosine"]

with tqdm(total=len(neighbours) * len(min_dists) * len(metrics)) as pbar:
    for n in neighbours:
        for m in min_dists:
            for met in metrics:
                try:
                    fit, u = draw_umap_safe(
                        cs_features_std,
                        features=columns_to_transform,
                        n_neighbors=n,
                        min_dist=m,
                        metric=met,
                        title=f"CS: n_neighbors={n}, min_dist={m}, metric={met}",
                    )
                    plt.show()
                except (NameError, ImportError):
                    # a missing name/module is a notebook bug, not a property of this
                    # parameter combination: stop instead of printing it for every run
                    raise
                except Exception as e:
                    # best effort: one failing combination must not stop the sweep
                    print(f"Error with {met} metric: {str(e)}")
                finally:
                    # release the figure so the sweep does not accumulate open figures
                    plt.close("all")
                    pbar.update(1)

In [None]:
import matplotlib.pyplot as plt
import os
from tqdm import tqdm


# 1. First, check for and handle zero-variance features
def standardize_with_checks(df, columns_to_transform):
    """Return a copy of ``df`` with the selected columns standardized.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    columns_to_transform : list of int
        Positional indices (as used with ``.iloc``) of the columns to scale.

    Returns
    -------
    df_std : pandas.DataFrame
        Copy of ``df`` in which the selected columns hold the scaler output
        (as float64); all other columns are unchanged.
    scaler : StandardScaler
        The scaler fitted on the selected columns.
    """
    # Create a copy to avoid modifying the original
    df_std = df.copy()

    # Work on a float64 copy of the selected columns. Writing the scaler's float
    # output back into integer columns through ``.iloc`` is an incompatible-dtype
    # assignment, so the columns are replaced by label instead.
    selected = df_std.iloc[:, columns_to_transform].astype("float64")

    # Check for zero variance columns
    for col in selected.columns:
        if df[col].std() < 1e-10:
            print(f"Warning: Column {col} has near-zero variance")
            # Add tiny amount of noise to prevent zero variance.
            # NOTE(review): after standardization this column consists of the
            # rescaled noise only; confirm that is acceptable downstream.
            selected[col] = selected[col] + np.random.normal(0, 1e-5, size=len(df))

    # Apply standardization
    scaler = StandardScaler()
    scaled = scaler.fit_transform(selected)
    df_std[list(selected.columns)] = scaled

    return df_std, scaler


# 2. Define a function to compute UMAP embedding in parallel (no plotting)
def compute_umap(n, m, met, data, features, n_components=2):
    """Fit one UMAP embedding and save it to ``umap_temp/`` as an ``.npz`` file.

    Parameters
    ----------
    n : int
        ``n_neighbors`` passed to UMAP.
    m : float
        ``min_dist`` passed to UMAP.
    met : str
        Metric name passed to UMAP. ``"mahalanobis"`` and ``"seuclidean"`` get
        extra metric keyword arguments computed from the data.
    data : pandas.DataFrame
        Frame holding the feature columns.
    features : list of int
        Positional indices of the columns used as the data matrix.
    n_components : int, default 2
        Dimensionality of the embedding.

    Returns
    -------
    (bool, tuple)
        ``(True, (n, m, met))`` when the embedding was computed and saved,
        ``(False, (n, m, met))`` when any step raised. Errors are printed
        rather than propagated, so one bad combination does not stop a sweep.
    """
    try:
        # Get the data matrix
        X = data.iloc[:, features].to_numpy()

        # Set up metric-specific parameters
        metric_params = {}

        if met == "mahalanobis":
            # Calculate covariance matrix with regularization
            cov = np.cov(X, rowvar=False)
            # Add small regularization to ensure positive definiteness
            cov += np.eye(cov.shape[0]) * 1e-6
            try:
                inv_cov = np.linalg.inv(cov)
            except np.linalg.LinAlgError:
                print(
                    "Warning: Covariance matrix is singular, using stronger regularization"
                )
                cov += np.eye(cov.shape[0]) * 1e-4
                inv_cov = np.linalg.inv(cov)
            # NOTE(review): scipy names the inverse-covariance keyword "VI";
            # confirm that "V" is the key the UMAP metric expects here.
            metric_params = {"V": inv_cov}

        elif met == "seuclidean":
            # Per-dimension sample variance, floored so no entry is zero
            variances = np.maximum(np.var(X, axis=0, ddof=1), 1e-8)
            metric_params = {"V": variances}

        # Create UMAP with proper parameters
        fit = umap.UMAP(
            n_neighbors=n,
            min_dist=m,
            n_components=n_components,
            metric=met,
            metric_kwds=metric_params,
            random_state=42,  # For reproducibility
        )

        u = fit.fit_transform(X)

        # Save results to file to avoid memory issues. The directory is created
        # here so the function does not depend on the caller having made it;
        # exist_ok keeps this safe when several workers get here together.
        os.makedirs("umap_temp", exist_ok=True)
        result_filename = f"umap_temp/umap_n{n}_m{m}_{met}.npz"
        np.savez(result_filename, embedding=u)

        return True, (n, m, met)

    except Exception as e:
        print(f"Error with n={n}, min_dist={m}, metric={met}: {str(e)}")
        return False, (n, m, met)


# 3. Function to visualize a single UMAP result
def visualize_umap(n, m, met, data, features, save_dir="umap_plots", show_plot=False):
    """Plot a saved UMAP embedding and write the figure to ``save_dir``.

    The embedding is read from ``umap_temp/umap_n{n}_m{m}_{met}.npz``.

    Parameters
    ----------
    n, m, met
        Parameter combination identifying the saved embedding; also used in
        the plot title and the output file name.
    data : pandas.DataFrame
        Frame whose ``"file"`` column colors the points (2-D and 3-D plots).
        Presumably numeric or otherwise accepted by matplotlib's ``c=``
        argument -- TODO confirm.
    features
        Unused; kept so the signature stays the same for existing callers.
    save_dir : str, default "umap_plots"
        Directory the PNG is written to; created if missing.
    show_plot : bool, default False
        If true the figure is shown, otherwise it is closed after saving.

    Returns
    -------
    bool
        ``True`` if the plot was saved, ``False`` if the embedding file does
        not exist or any step raised (the error is printed).
    """
    fig = None
    try:
        # Load saved embedding
        result_filename = f"umap_temp/umap_n{n}_m{m}_{met}.npz"
        if not os.path.exists(result_filename):
            return False

        # The context manager closes the underlying .npz file handle.
        with np.load(result_filename) as loaded:
            embedding = loaded["embedding"]

        fig = plt.figure(figsize=(10, 8))

        # Determine number of components from embedding shape
        n_components = embedding.shape[1]

        # Plot based on the number of components
        if n_components == 1:
            plt.scatter(embedding[:, 0], range(len(embedding)))
        elif n_components == 2:
            scatter = plt.scatter(
                embedding[:, 0], embedding[:, 1], c=data["file"], alpha=0.7
            )
            plt.colorbar(scatter, label="File")
        elif n_components == 3:
            ax = fig.add_subplot(111, projection="3d")
            scatter = ax.scatter(
                embedding[:, 0],
                embedding[:, 1],
                embedding[:, 2],
                c=data["file"],
                alpha=0.7,
            )
            plt.colorbar(scatter, label="File")

        plt.title(f"CS: n_neighbors={n}, min_dist={m}, metric={met}", fontsize=18)
        plt.tight_layout()

        # Save the figure, making sure the target directory exists first
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        filename = f"{save_dir}/umap_n{n}_m{m}_{met}.png"
        fig.savefig(filename)

        # Show the plot if requested
        if show_plot:
            plt.show()
        else:
            plt.close(fig)

        return True

    except Exception as e:
        # Do not leave a half-built figure open when plotting or saving fails.
        if fig is not None:
            plt.close(fig)
        print(f"Error visualizing n={n}, min_dist={m}, metric={met}: {str(e)}")
        return False


# Standardize data with checks
columns_to_transform = [4, 5, 6, 7, 8, 9]
cs_features_std, scaler = standardize_with_checks(cs_features, columns_to_transform)

# Define parameter ranges
neighbours = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
min_dists = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25]
metrics = ["euclidean", "mahalanobis", "correlation", "cosine"]

# Create directories for saving plots and temporary results
os.makedirs("umap_plots", exist_ok=True)
os.makedirs("umap_temp", exist_ok=True)

# Generate all parameter combinations
param_combinations = [
    (n, m, met) for n in neighbours for m in min_dists for met in metrics
]
print(
    f"Running UMAP with {len(param_combinations)} parameter combinations in parallel..."
)

# Process in batches to avoid memory issues
batch_size = 10  # Adjust based on your system's capabilities
total_batches = (len(param_combinations) + batch_size - 1) // batch_size

# Parameter tuples whose embedding was computed and saved successfully
all_successful = []

for batch_idx in range(total_batches):
    start_idx = batch_idx * batch_size
    end_idx = min((batch_idx + 1) * batch_size, len(param_combinations))
    current_batch = param_combinations[start_idx:end_idx]

    print(
        f"Processing batch {batch_idx + 1}/{total_batches} ({len(current_batch)} combinations)"
    )

    # Run UMAP computations in parallel
    results = Parallel(n_jobs=-1, backend="loky", verbose=10)(
        delayed(compute_umap)(n, m, met, cs_features_std, columns_to_transform)
        for n, m, met in current_batch
    )

    # Track successful results
    all_successful.extend(params for success, params in results if success)

    # Allow system to recover before next batch
    time.sleep(1)

# Process results and create visualizations sequentially
print("Creating visualizations...")
# Count the plots that were actually written: visualize_umap returns False when
# plotting fails, so the number of computed embeddings is only an upper bound.
plots_created = 0
for n, m, met in tqdm(all_successful):
    if visualize_umap(n, m, met, cs_features_std, columns_to_transform):
        plots_created += 1

# Print summary
print(
    f"Successfully created {plots_created} UMAP plots out of {len(param_combinations)} combinations."
)
print("Plots saved to the 'umap_plots' directory.")

# Optional: clean up temporary files
# import shutil
# shutil.rmtree("umap_temp")
# print("Temporary files removed.")