In [None]:
from typing import List
import pandas as pd
import numpy as np
import glob
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
from scipy.signal import find_peaks
import os
import gzip
from tqdm import tqdm
from typing import List, Tuple
from scipy.signal import find_peaks
import pickle

In [None]:

# Directory containing the *.csv.gz tracking files. Set the FISH_DATA_DIR environment
# variable to point elsewhere (e.g. the alternative location "C:/PhD/long_term_free_swim")
# instead of editing this cell; the default below is used when it is unset.
path = os.environ.get("FISH_DATA_DIR", "/home/kkumari/PhD/fish-data/long-term-free-swim/")

In [None]:
def cart2sph(x, y, z):
    """Convert Cartesian coordinates to spherical ones.

    Operates element-wise, so scalars, NumPy arrays and pandas Series all work.

    Returns
    -------
    tuple
        ``(azimuth, elevation, R)``:
        azimuth is ``arctan2(y, x)`` (angle in the x-y plane, radians),
        elevation is ``arctan2(z, sqrt(x**2 + y**2))`` (angle above the x-y plane, radians),
        R is the Euclidean distance from the origin.
    """
    planar_sq = x**2 + y**2
    radius = np.sqrt(planar_sq + z**2)
    return np.arctan2(y, x), np.arctan2(z, np.sqrt(planar_sq)), radius

In [None]:
def load_and_collate_data(all_files: List[str], desired_cols: List[str], max_files=None) -> dict:
    """Read gzipped CSV tracking files and group their data frames by fish ID.

    Parameters
    ----------
    all_files : list of str
        Paths to gzip-compressed CSV files, processed in the given order.
    desired_cols : list of str
        Columns to read from every file (passed to ``pd.read_csv(usecols=...)``).
    max_files : int or None, optional
        If given, only the first ``max_files`` paths are read. The default ``None``
        reads every path; the caller decides how many files to pass.

    Returns
    -------
    dict
        ``{fish_id: {'df': [DataFrame, ...], 'files': [path, ...]}}`` where the fish ID
        is the first two characters of each file's base name and both lists are in
        read order.
    """
    files_to_read = all_files if max_files is None else all_files[:max_files]

    fish_data = {}
    for file in tqdm(files_to_read, desc="Processing files"):
        # Decompress and parse only the requested columns.
        with gzip.open(file, 'rb') as f:
            df = pd.read_csv(f, usecols=desired_cols)

        # NOTE(review): assumes every file name starts with a 2-character fish ID — confirm naming scheme.
        fish_id = os.path.basename(file)[:2]

        # First file for this fish creates an entry holding a list of frames and a list of paths.
        entry = fish_data.setdefault(fish_id, {'df': [], 'files': []})
        entry['df'].append(df)
        entry['files'].append(file)

    return fish_data

In [None]:
def calculate_angles(df_diff):
    """Return the unwrapped heading angle (radians) of each frame-to-frame displacement.

    The heading of each row is ``arctan2(dy, dx)`` computed from the ``fishx`` and
    ``fishy`` columns of ``df_diff``. Missing displacements are first back-filled from
    the next valid row. Angles are then unwrapped sequentially, starting from a
    reference of 0, so that each angle lies within pi of the previous valid one.

    Parameters
    ----------
    df_diff : pandas.DataFrame
        Frame-wise position differences; must contain ``fishx`` and ``fishy`` columns.

    Returns
    -------
    list of float
        One unwrapped angle per row of ``df_diff``. Trailing rows that cannot be
        back-filled yield NaN.
    """
    # Series.interpolate(method='bfill') is deprecated since pandas 2.1; .bfill() is equivalent.
    dx = df_diff["fishx"].bfill()
    dy = df_diff["fishy"].bfill()

    angle_wrapped = np.arctan2(dy, dx)
    last = 0.0
    angles = []
    for phi in angle_wrapped:
        # Shift by whole turns until phi lies within pi of the previous unwrapped angle.
        while phi < last - np.pi:
            phi += 2 * np.pi
        while phi > last + np.pi:
            phi -= 2 * np.pi
        # A NaN must not become the reference, otherwise no later angle would be unwrapped.
        if not np.isnan(phi):
            last = phi
        angles.append(phi)

    return angles


def create_velocity_dataframe(df, fHz):
    """Build a per-frame table of time, speed, position and heading.

    Parameters
    ----------
    df : pandas.DataFrame
        Tracking data, one row per frame, with ``realtime``, ``fishx``, ``fishy`` and
        ``fishz`` columns.
    fHz : float
        Sampling rate in frames per second; consecutive rows are treated as
        ``1 / fHz`` seconds apart.

    Returns
    -------
    pandas.DataFrame
        Columns ``Time``, ``Velocity`` (magnitude of the frame-to-frame position change
        divided by the frame interval), ``X``, ``Y``, ``Z`` and ``Angle`` (from
        ``calculate_angles``). Gaps after the first valid value of each column are
        filled by linear interpolation.
    """
    # Only position columns are differenced, so extra (possibly non-numeric) columns are ignored.
    df_diff = df[["fishx", "fishy", "fishz"]].diff(periods=1, axis=0)

    dt = 1 / fHz
    # Speed from the change in position. Differencing a pre-computed dx/dy/dz column
    # here would yield a second difference instead of a velocity.
    velocity = np.sqrt(df_diff["fishx"] ** 2 + df_diff["fishy"] ** 2 + df_diff["fishz"] ** 2) / dt

    velocity_dataframe = pd.DataFrame({
        'Time': df['realtime'],
        'Velocity': velocity,
        'X': df['fishx'],
        'Y': df['fishy'],
        'Z': df['fishz'],
        'Angle': calculate_angles(df_diff),
    })

    return velocity_dataframe.interpolate()


def _mask_implausible_positions(df: pd.DataFrame, position_cols: List[str],
                                max_stepsize: float = 0.02, window: int = 10) -> None:
    """Overwrite implausible position samples in ``df`` with NaN (modifies ``df`` in place).

    Also stores the frame-to-frame displacement magnitude in ``df["steps"]``.
    Thresholds are in the same units as the position columns.
    """
    # Per-frame step size; a step larger than max_stepsize is treated as a tracking jump.
    raw_diff = df[position_cols].diff()
    df["steps"] = np.sqrt(raw_diff["fishx"] ** 2 + raw_diff["fishy"] ** 2 + raw_diff["fishz"] ** 2)
    large_step_labels = df.index[df["steps"] > max_stepsize]
    # Blank `window` frames on either side of each jump. This is a label-based slice,
    # so it assumes an integer index such as the default RangeIndex from read_csv.
    for label in large_step_labels:
        df.loc[label - window:label + window, position_cols] = np.nan

    # Keep only samples with fishz in [-0.09, 0] (small tolerance).
    err = 0.001
    df.loc[df['fishz'] < -(0.09 + err), position_cols] = np.nan
    df.loc[df['fishz'] > 0 + err, position_cols] = np.nan

    # Keep only samples whose distance from the point (0, 0, 0.11) is within [0.11, 0.2].
    zoffset = 0.11
    _, _, R = cart2sph(df['fishx'], df['fishy'], df['fishz'] - zoffset)
    err = 0.005
    df.loc[R > 0.2 + err, position_cols] = np.nan
    df.loc[R < 0.11 - err, position_cols] = np.nan


def _wrapped_turn_angles(angles_at_peaks) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(normalized, unwrapped, dangles)`` for the heading angles sampled at peaks.

    ``dangles`` is the change in heading between consecutive peaks, wrapped to [-pi, pi).
    """
    normalized = np.mod(angles_at_peaks, 2 * np.pi) - np.pi
    unwrapped = np.unwrap(normalized)
    dangles = np.mod(np.diff(unwrapped) + np.pi, 2 * np.pi) - np.pi
    return normalized, unwrapped, dangles


def calculate_additional_variables_per_trial(fish_data: dict, fHz: int):
    """Clean each trial's trajectory and derive speed, bouts and turning angles.

    For every fish in ``fish_data`` and every trial frame in ``data['df']``:

    * ``realtime`` is shifted to start at 0 and implausible positions are set to NaN.
      Both changes modify the frame in place. ``steps``, ``dx``, ``dy`` and ``dz``
      columns are added.
    * Speed is computed from the cleaned frame-to-frame displacement. Bouts are the
      peaks of the speed trace found by ``scipy.signal.find_peaks``.
    * The heading at each bout and the heading change between bouts are stored.

    Per-trial results are appended to these lists in ``fish_data[fish_id]``:
    ``angles``, ``velocity``, ``avg_velocity`` (median speed), ``peaks``,
    ``peak_times``, ``angles_at_peaks``, ``angles_at_peaks_normalized``,
    ``angles_at_peaks_unwrapped``, ``dangles`` and ``velocity_dataframe``.

    Parameters
    ----------
    fish_data : dict
        ``{fish_id: {'df': [DataFrame, ...], ...}}``. Each frame needs ``realtime``,
        ``fishx``, ``fishy`` and ``fishz`` columns.
    fHz : int
        Sampling rate in frames per second.

    Returns
    -------
    dict
        ``{fish_id: dangles}``, with the between-bout heading changes of all of that
        fish's trials concatenated in trial order.
    """
    position_cols = ['fishz', 'fishy', 'fishx']
    dt = 1 / fHz
    # Bout-detection parameters for find_peaks on the speed trace.
    height = (0.1, 0.5)
    frames_btw_2bouts = max(1, round(fHz / 10))  # find_peaks requires distance >= 1
    bout_width = round(fHz / 100)
    prominence = 0.05

    result_keys = ('angles', 'avg_velocity', 'angles_at_peaks', 'peak_times', 'dangles',
                   'velocity_dataframe', 'peaks', 'velocity',
                   'angles_at_peaks_normalized', 'angles_at_peaks_unwrapped')

    dangles_dict = {}
    for fish_id, data in fish_data.items():
        for key in result_keys:
            data[key] = []

        for df in data['df']:
            df["realtime"] = df["realtime"] - df["realtime"].iloc[0]

            _mask_implausible_positions(df, position_cols)

            # Differences are taken AFTER cleaning so masked samples do not feed speed or heading.
            df_diff = df[position_cols].diff()
            df["dx"] = df_diff["fishx"].bfill()
            df["dy"] = df_diff["fishy"].bfill()
            df["dz"] = df_diff["fishz"].bfill()

            angles = calculate_angles(df_diff)
            velocity = np.sqrt(df["dx"] ** 2 + df["dy"] ** 2 + df["dz"] ** 2) / dt

            # Trailing NaNs (masked final samples) cannot be back-filled; treat them as
            # zero speed so they are never reported as peaks.
            peaks, _ = find_peaks(velocity.fillna(0.0).to_numpy(), height=height,
                                  distance=frames_btw_2bouts, width=bout_width,
                                  prominence=prominence)
            angles_at_peaks = [angles[i] for i in peaks]
            normalized, unwrapped, dangles = _wrapped_turn_angles(angles_at_peaks)

            data['angles'].append(angles)
            data['velocity'].append(velocity)
            data['avg_velocity'].append(velocity.median())
            data['peaks'].append(peaks)
            data['peak_times'].append(df["realtime"].iloc[peaks].tolist())
            data['angles_at_peaks'].append(angles_at_peaks)
            data['angles_at_peaks_normalized'].append(normalized)
            data['angles_at_peaks_unwrapped'].append(unwrapped)
            data['dangles'].append(dangles)
            data['velocity_dataframe'].append(create_velocity_dataframe(df, fHz))

        # Pool all trials of this fish (previously only the last trial was kept).
        dangles_dict[fish_id] = (np.concatenate(data['dangles'])
                                 if data['dangles'] else np.array([]))

    return dangles_dict


In [None]:
def calculate_turning_angle_properties(dangles):
    """Summarise the direction of consecutive turns.

    Parameters
    ----------
    dangles : array-like of float
        Signed angle changes between successive bouts. Positive values are counted
        as counterclockwise turns and negative values as clockwise turns.

    Returns
    -------
    dict
        counterclockwise_turns, clockwise_turns
            Number of positive / negative entries (zeros count as neither).
        probability_counterclockwise_turns, probability_clockwise_turns
            Fraction of signed turns in each direction; NaN when there are none.
        turns
            ``np.sign(dangles)``: 1, -1, or 0 for no change.
        streaks
            Integer lengths of runs of identical consecutive ``turns`` values.
    """
    # Accept lists as well as arrays; comparisons below need an ndarray.
    dangles = np.asarray(dangles)

    counterclockwise_turns = np.sum(dangles > 0)
    clockwise_turns = np.sum(dangles < 0)
    total_turns = counterclockwise_turns + clockwise_turns

    # Avoid a 0/0 RuntimeWarning when there are no signed turns; the result is NaN either way.
    if total_turns > 0:
        probability_counterclockwise_turns = counterclockwise_turns / total_turns
        probability_clockwise_turns = clockwise_turns / total_turns
    else:
        probability_counterclockwise_turns = np.nan
        probability_clockwise_turns = np.nan

    # Sequence of turn directions as 1 (counterclockwise) and -1 (clockwise).
    turns = np.sign(dangles)

    def streak_lengths(turns):
        """Lengths of runs of equal consecutive values (NaN never equals its neighbour)."""
        if len(turns) == 0:
            return np.array([])  # empty input -> empty result
        # Positions where a new run starts, i.e. the value differs from the previous one.
        run_starts = np.flatnonzero(turns[1:] != turns[:-1]) + 1
        boundaries = np.concatenate(([0], run_starts, [len(turns)]))
        return np.diff(boundaries)

    streaks = streak_lengths(turns)

    turning_properties = {
        'counterclockwise_turns': counterclockwise_turns,
        'clockwise_turns': clockwise_turns,
        'probability_counterclockwise_turns': probability_counterclockwise_turns,
        'probability_clockwise_turns': probability_clockwise_turns,
        'turns': turns,
        'streaks': streaks
    }

    return turning_properties


In [None]:
# Number of (sorted) files to load; set to None to process every file in `path`.
N_FILES = 5
# Tracking sampling rate in frames per second (passed as fHz).
FRAME_RATE_HZ = 100

all_files = sorted(glob.glob(os.path.join(path, "*.csv.gz")))[:N_FILES]
# Fail early with a clear message instead of silently producing an empty result.
if not all_files:
    raise FileNotFoundError(f"No *.csv.gz files found in {path!r}")
desired_cols = ['fishx', 'fishy', 'fishz', 'realtime']
fish_data = load_and_collate_data(all_files, desired_cols)
dangles_dict = calculate_additional_variables_per_trial(fish_data, fHz=FRAME_RATE_HZ)

In [None]:
# Turning statistics per fish. `fish_data` is keyed by fish ID (it has no 'dangles' key),
# so the per-fish turn angles are taken from `dangles_dict`.
turning_properties_dict = {
    fish_id: calculate_turning_angle_properties(dangles)
    for fish_id, dangles in dangles_dict.items()
}
