In [1]:
from typing import List
import pandas as pd
import numpy as np
import glob
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import os
import gzip
from tqdm import tqdm
import seaborn as sns
import natsort
from scipy.signal import savgol_filter
import hashlib


In [2]:
def read_files(folder_path):
    """Collect the first three trial files of every fish found in a folder.

    Files matching ``*.csv.gz`` are visited in natural sort order. The fish ID
    and the trial ID are the first two underscore-separated fields of the
    file name.

    Parameters
    ----------
    folder_path : str
        Directory to scan (non-recursive).

    Returns
    -------
    dict
        ``{fish_id: {trial_id: file_path}}`` holding at most three trials per
        fish. When a trial ID repeats, the first file encountered is kept.
    """
    all_files = natsort.natsorted(glob.glob(os.path.join(folder_path, "*.csv.gz")))

    fish_data = {}

    for file_path in all_files:
        parts = os.path.basename(file_path).split('_')
        if len(parts) < 2:
            # No "<fish>_<trial>_..." prefix: skip the file instead of
            # crashing with an IndexError on parts[1].
            print(f"Skipping file with unexpected name: {file_path}")
            continue
        fish_id, trial_id = parts[0], parts[1]

        # Only keep the first 3 trials for each fish
        trials = fish_data.setdefault(fish_id, {})
        if trial_id not in trials and len(trials) < 3:
            trials[trial_id] = file_path

    # Debug: Print out the files being processed for each fish
    for fish_id, trials in fish_data.items():
        print(f"Fish ID: {fish_id}, Trials: {list(trials.keys())}")

    return fish_data

In [3]:
def calculate_time_stamp(df):
    """Add a ``time_stamp`` column relative to the first ``start_time`` value.

    ``time_stamp`` is ``(realtime - start_time of the first row) / 100``.
    The frame is modified in place and also returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``realtime`` and ``start_time`` columns and at least
        one row.

    Returns
    -------
    pandas.DataFrame
        The same object with the ``time_stamp`` column added.
    """
    # Positional access: ``df['start_time'][0]`` looks up the *label* 0 and
    # fails whenever the index does not contain it (e.g. after filtering).
    first_start = df['start_time'].iloc[0]
    # NOTE(review): the divisor 100 is kept from the original code; confirm it
    # matches the units of `realtime`.
    df['time_stamp'] = (df['realtime'] - first_start) / 100
    return df


In [4]:
def cartesian_to_spherical(x, y, z):
    """Convert Cartesian coordinates to spherical ``(r, theta, phi)``.

    ``r`` is the Euclidean norm, ``theta`` the azimuth ``arctan2(y, x)`` and
    ``phi`` the inclination ``arccos(z / r)``; both angles are in radians.
    Scalars and NumPy-compatible array-likes are accepted.
    """
    squared_norm = x**2 + y**2 + z**2
    radius = np.sqrt(squared_norm)
    azimuth = np.arctan2(y, x)
    inclination = np.arccos(z / radius)
    return radius, azimuth, inclination

def spherical_to_cartesian(r, theta, phi):
    """Convert spherical ``(r, theta, phi)`` to Cartesian ``(x, y, z)``.

    ``theta`` is the azimuth and ``phi`` the inclination, both in radians.
    """
    # Length of the projection onto the x-y plane, shared by x and y.
    planar = r * np.sin(phi)
    return planar * np.cos(theta), planar * np.sin(theta), r * np.cos(phi)

In [5]:
def filter_large_jumps(df, max_stepsize=0.02, window_size=10):
    """Blank out position samples around implausibly large frame-to-frame jumps.

    The step size is the Euclidean distance between consecutive rows of
    ``fishx``/``fishy``/``fishz``. Every row whose step exceeds
    ``max_stepsize``, together with the ``window_size`` rows before and after
    it, gets its three position columns set to NaN.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``fishx``, ``fishy`` and ``fishz``. Modified in place.
    max_stepsize : float, optional
        Largest step that is still accepted.
    window_size : int, optional
        Number of rows blanked on each side of a large step.

    Returns
    -------
    pandas.DataFrame
        The same object with the affected positions replaced by NaN.
    """
    # Calculate the step size for each frame
    steps = np.sqrt(df['fishx'].diff()**2 + df['fishy'].diff()**2 + df['fishz'].diff()**2)

    # Identify large steps
    large_steps = steps > max_stepsize
    print(f"Found {int(large_steps.sum())} large steps")

    # A centred rolling maximum spreads each flag over +/- window_size rows.
    # It works on row positions, so it does not depend on the index labels,
    # and it replaces one .loc assignment per jump with a single masked write.
    blank = (
        large_steps.astype(float)
        .rolling(2 * window_size + 1, center=True, min_periods=1)
        .max()
        .gt(0)
        .to_numpy()
    )
    df.loc[blank, ['fishz', 'fishy', 'fishx']] = np.nan

    return df

In [6]:
# print length of df


In [7]:
def filter_spherical(df, err=0.005, z_offset=0.11, r_min=0.11, r_max=0.2):
    """Blank out positions whose distance from the sphere centre is out of range.

    The distance is measured from the point ``(0, 0, z_offset)``. Rows with a
    distance above ``r_max + err`` or below ``r_min - err`` get ``fishx``,
    ``fishy`` and ``fishz`` set to NaN.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``fishx``, ``fishy`` and ``fishz``. Modified in place.
    err : float, optional
        Tolerance added to the outer and subtracted from the inner limit.
    z_offset : float, optional
        Value subtracted from ``fishz`` before computing the distance.
    r_min, r_max : float, optional
        Inner and outer radius limits (previously hard-coded 0.11 and 0.2).

    Returns
    -------
    pandas.DataFrame
        The same object with out-of-range positions replaced by NaN.
    """
    # Only the radius is needed. The earlier version unpacked the
    # (r, theta, phi) result of cartesian_to_spherical in the wrong order and
    # therefore compared the inclination angle against the radius limits.
    radius = np.sqrt(df['fishx']**2 + df['fishy']**2 + (df['fishz'] - z_offset)**2)

    # Apply spherical filters
    outside = (radius > r_max + err) | (radius < r_min - err)
    df.loc[outside, ['fishz', 'fishy', 'fishx']] = np.nan

    return df

In [8]:
def calculate_angles_at_peaks_efficient(df, fHz):
    """Detect velocity peaks and annotate ``df`` with per-peak quantities.

    Peaks come from ``find_velocity_peaks(df, fHz)``. For every peak after the
    first, ``calculate_angle`` is evaluated against the preceding peak. The
    following columns are written (NaN everywhere except at peak rows):

    * ``peak_angles`` - angle returned by ``calculate_angle``.
    * ``time_stamp_peak`` - ``time_stamp`` of every peak row.
    * ``interbout_duration`` - ``time_stamp`` difference to the previous peak.
    * ``turn_bias`` - sign of (next peak's angle - this peak's angle); the
      last peak has no successor and stays NaN.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``velocity``, ``dx``, ``dy`` and ``time_stamp``. Peak
        positions are used as ``.loc`` labels, so a default integer index is
        expected. Modified in place.
    fHz : number
        Forwarded to ``find_velocity_peaks``.

    Returns
    -------
    tuple
        ``(df, peaks, peak_angles)`` where ``peak_angles`` is a list with one
        entry per peak except the first.
    """
    # Find peaks
    peaks, _ = find_velocity_peaks(df, fHz)

    # Calculate angles at peaks
    peak_angles = [calculate_angle(df, peaks[i-1], peaks[i]) for i in range(1, len(peaks))]

    # Create every output column up front so they exist even without peaks.
    for column in ('peak_angles', 'time_stamp_peak', 'interbout_duration', 'turn_bias'):
        df[column] = np.nan

    if len(peaks) > 0:
        df.loc[peaks, 'time_stamp_peak'] = df.loc[peaks, 'time_stamp']

    # With fewer than two peaks there is nothing to assign. Without this guard
    # the single appended NaN in turn_bias would be written to an empty
    # selection (length mismatch).
    if peak_angles:
        later_peaks, earlier_peaks = peaks[1:], peaks[:-1]

        # Skip the first peak as there's no angle for it
        df.loc[later_peaks, 'peak_angles'] = peak_angles

        df.loc[later_peaks, 'interbout_duration'] = (
            df.loc[later_peaks, 'time_stamp'] - df.loc[earlier_peaks, 'time_stamp'].values
        )

        # Append NaN so the array has one entry per row in later_peaks.
        turn_bias = np.append(np.sign(np.diff(peak_angles)), np.nan)
        df.loc[later_peaks, 'turn_bias'] = turn_bias

    return df, peaks, peak_angles

In [9]:
def calculate_angle(df, prev_peak, current_peak):
    """Return the direction, in degrees, of the change in ``(dx, dy)`` between two rows.

    The angle is ``arctan2(dy_cur - dy_prev, dx_cur - dx_prev)`` converted to
    degrees and wrapped with ``(a + 180) % 360 - 180``. The value is also
    stored in ``df.loc[current_peak, 'angle_degrees']``.

    NOTE(review): this is the direction of the difference vector, not the
    angle between the two ``(dx, dy)`` vectors - confirm this is the intended
    definition of the turn angle.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``dx`` and ``dy``. Modified in place.
    prev_peak, current_peak : index label
        Row labels of the previous and the current peak.

    Returns
    -------
    float
        The wrapped angle in degrees.
    """
    # Use column labels: integer keys on a Series labelled 'dx'/'dy' only work
    # through pandas' deprecated positional fallback.
    delta_x = df.loc[current_peak, 'dx'] - df.loc[prev_peak, 'dx']
    delta_y = df.loc[current_peak, 'dy'] - df.loc[prev_peak, 'dy']

    angle_degrees = np.degrees(np.arctan2(delta_y, delta_x))
    # center angle between -180 and 180
    angle_degrees = (angle_degrees + 180) % 360 - 180
    # save angle_degrees to df
    df.loc[current_peak, 'angle_degrees'] = angle_degrees

    return angle_degrees

In [10]:
def calculate_cumulative_angles(df, peaks):
    """Return the running sum of the angles between consecutive peaks.

    The previous implementation forwarded ``peaks`` to
    ``calculate_angles_at_peaks_efficient`` in place of its ``fHz`` argument
    and summed the returned ``(df, peaks, peak_angles)`` tuple. The angles are
    now computed directly from the supplied peaks.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``dx`` and ``dy``. ``calculate_angle`` also writes an
        ``angle_degrees`` value for each peak after the first.
    peaks : sequence
        Row labels of the peaks, in order.

    Returns
    -------
    numpy.ndarray
        Cumulative sum of the angles, one entry per peak except the first;
        empty when fewer than two peaks are given.
    """
    # Angle of every peak relative to the one before it
    peak_angles = [calculate_angle(df, peaks[i - 1], peaks[i]) for i in range(1, len(peaks))]

    # Calculate cumulative angles
    return np.cumsum(peak_angles)

In [11]:
def find_velocity_peaks(df, fHz, height_min=0.015, height_max=0.4, prominence=0.03):
    """Locate peaks in ``df['velocity']`` with ``scipy.signal.find_peaks``.

    The minimum peak distance is ``round(fHz / 10)`` samples (never less than
    1) and the minimum width is ``round(fHz / 100)`` samples. The first five
    peak positions and the peak count are printed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``velocity`` column.
    fHz : number
        Sampling rate used to scale the distance and width constraints.
    height_min, height_max : float, optional
        Accepted peak height interval.
    prominence : float, optional
        Minimum required prominence.

    Returns
    -------
    tuple
        ``(peaks, properties)`` exactly as returned by ``find_peaks``.
    """
    # find_peaks rejects distance < 1; clamp so low sampling rates still work.
    distance = max(1, round(fHz / 10))
    width = round(fHz / 100)

    peaks, properties = find_peaks(df['velocity'], height=(height_min, height_max),
                                   distance=distance, width=width, prominence=prominence)

    # print first 5 peaks
    print(peaks[:5])

    print(f"Found {len(peaks)} peaks")
    return peaks, properties


In [12]:
def process_file(file_path, fHz=100):
    """Load one gzip-compressed trial CSV and derive velocity and turn metrics.

    Steps: read the position/time columns, blank large jumps, smooth each
    axis with a Savitzky-Golay filter (window 11, order 1), difference the
    smoothed positions, derive the velocity, add time stamps, then detect
    peaks and per-peak angles.

    Returns
    -------
    tuple
        ``(df, peaks, peak_angles, wrapped_angles, unwrapped_angles,
        cumulative_angles)``. ``wrapped_angles`` is a list in degrees,
        ``unwrapped_angles`` is in radians, and ``cumulative_angles`` is the
        running sum of the unwrapped angles converted back to degrees.
    """
    wanted_columns = ['realtime', 'fishx', 'fishy', 'fishz', 'start_time']
    raw = pd.read_csv(file_path, compression='gzip', usecols=wanted_columns)

    # Remove tracking glitches before smoothing
    df = filter_large_jumps(raw)

    # Smooth every axis first, then difference, so the columns are created in
    # the order smooth_x/y/z followed by dx/dy/dz.
    axis_columns = (
        ('fishx', 'smooth_fishx', 'dx'),
        ('fishy', 'smooth_fishy', 'dy'),
        ('fishz', 'smooth_fishz', 'dz'),
    )
    for raw_name, smooth_name, _ in axis_columns:
        df[smooth_name] = savgol_filter(df[raw_name], 11, 1)
    for _, smooth_name, delta_name in axis_columns:
        df[delta_name] = df[smooth_name].diff().fillna(0)

    # Per-frame displacement divided by the frame period
    displacement = np.sqrt(df['dx']**2 + df['dy']**2 + df['dz']**2)
    df['velocity'] = displacement / (1/fHz)

    df = calculate_time_stamp(df)

    # Peak detection also fills in interbout duration and turn bias
    df, peaks, peak_angles = calculate_angles_at_peaks_efficient(df, fHz)

    # Re-centre on [-180, 180) degrees, then unwrap in radians
    wrapped_angles = []
    for angle in peak_angles:
        wrapped_angles.append((angle + 180) % 360 - 180)
    unwrapped_angles = np.unwrap(np.radians(wrapped_angles))

    cumulative_angles = np.cumsum(np.degrees(unwrapped_angles))

    return df, peaks, peak_angles, wrapped_angles, unwrapped_angles, cumulative_angles


In [13]:
def aggregate_trial_results(fish_data, fHz):
    """Compute the proportion of left and right turns for every trial.

    Each file is run through ``process_file``. The turn bias is read from the
    ``turn_bias`` column of the processed frame, where -1 counts as a left
    turn and 1 as a right turn. Proportions are relative to the number of
    non-NaN turn-bias values and are 0 when there are none.

    Parameters
    ----------
    fish_data : dict
        ``{fish_id: {trial_id: file_path}}``.
    fHz : number
        Forwarded to ``process_file``.

    Returns
    -------
    dict
        ``{fish_id: {'turn_bias_per_trial': {trial_id: {'left_turn_prop': ...,
        'right_turn_prop': ...}}}}``.
    """
    aggregated_data = {}
    for fish_id, trials in fish_data.items():
        per_trial = {}
        for trial_id, file_path in trials.items():
            # process_file returns six values and none of them is the turn
            # bias (the old five-name unpacking raised ValueError). Take the
            # processed frame, which carries the turn_bias column.
            df = process_file(file_path, fHz)[0]
            if 'turn_bias' in df.columns:
                turn_bias = df['turn_bias'].dropna().to_numpy()
            else:
                turn_bias = np.array([])

            left_turns = np.sum(turn_bias == -1)
            right_turns = np.sum(turn_bias == 1)
            # NaN placeholders are already dropped, so no "- 1" correction.
            total_turns = len(turn_bias)
            per_trial[trial_id] = {
                'left_turn_prop': left_turns / total_turns if total_turns > 0 else 0,
                'right_turn_prop': right_turns / total_turns if total_turns > 0 else 0,
            }
        aggregated_data[fish_id] = {'turn_bias_per_trial': per_trial}

    return aggregated_data


In [14]:
def aggregate_fish_results(fish_data, fHz):
    """Run ``process_file`` on every selected trial file.

    Parameters
    ----------
    fish_data : dict
        ``{fish_id: {trial_id: file_path}}``.
    fHz : number
        Forwarded to ``process_file``.

    Returns
    -------
    dict
        Keyed by file path; each value is a dict with the entries ``df``,
        ``peaks``, ``peak_angles``, ``wrapped_angles``, ``unwrapped_angles``
        and ``cumulative_angles``.
    """
    results_by_file = {}
    for trials in fish_data.values():
        for file_path in trials.values():
            (df, peaks, peak_angles,
             wrapped_angles, unwrapped_angles, cumulative_angles) = process_file(file_path, fHz)
            # The file path is unique, unlike the fish or trial ID alone
            results_by_file[file_path] = dict(
                df=df,
                peaks=peaks,
                peak_angles=peak_angles,
                wrapped_angles=wrapped_angles,
                unwrapped_angles=unwrapped_angles,
                cumulative_angles=cumulative_angles,
            )
    return results_by_file

In [15]:
def extract_fish_trial(full_path):
    """
    Return ``(fish_id, trial)`` parsed from a file path.

    The file name is split on underscores; the first two fields are the fish
    ID and the trial. ``('Unknown', 'Unknown')`` is returned when the name
    has fewer than two fields. The intermediate values are printed for
    debugging.
    """
    filename = os.path.basename(full_path)
    print(f"Debug: Extracted filename: {filename}")  # Debug print

    parts = filename.split('_')
    print(f"Debug: Split parts: {parts}")  # Debug print

    # Guard clause: not enough fields to identify fish and trial
    if len(parts) < 2:
        return 'Unknown', 'Unknown'

    fish_id, trial = parts[:2]
    return fish_id, trial

def save_dataframe(df, output_folder_path, file_name):
    """
    Write ``df`` as CSV (without the index) to ``output_folder_path/file_name``.

    A message is printed before and after the file is written.
    """
    destination = os.path.join(output_folder_path, file_name)

    print(f"Debug: Saving dataframe to {destination}")  # Debug print
    df.to_csv(destination, index=False)
    print(f"Dataframe saved as '{destination}'")

def save_dataframes_from_nested_dict(nested_dict, output_folder_path, file_name_suffix):
    """
    Save every DataFrame found one level deep in ``nested_dict`` as CSV.

    The outer key is parsed with ``extract_fish_trial`` and the output file
    is named ``<fish>_<trial>_<inner_key>_<file_name_suffix>.csv``. Outer
    values that are not dicts and inner values that are not DataFrames are
    skipped with a printed message. The output folder is created if missing.
    """
    if not os.path.exists(output_folder_path):
        os.makedirs(output_folder_path)

    for outer_key, inner_dict in nested_dict.items():
        if not isinstance(inner_dict, dict):
            print(f"Skipping non-dict item for key: {outer_key}")
            continue

        for inner_key, df in inner_dict.items():
            if not isinstance(df, pd.DataFrame):
                print(f"Skipping non-dataframe item for key: {inner_key} in nested dict of {outer_key}")
                continue

            fish_id, trial = extract_fish_trial(outer_key)
            if fish_id != 'Unknown' and trial != 'Unknown':
                file_identifier = f'{fish_id}_{trial}'
            else:
                file_identifier = 'Unknown_Unknown'
            file_name = f'{file_identifier}_{inner_key}_{file_name_suffix}.csv'
            save_dataframe(df, output_folder_path, file_name)




In [16]:
# Sampling rate in frames per second; the processing functions receive it as fHz.
fHz = 100  # or the appropriate value for your data
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
folder_path = "/home/kkumari/PhD/fish-data/long-term-free-swim/"
# Map every fish ID to (at most) its first three trial files.
fish_data = read_files(folder_path)
# Per-trial turn-bias aggregation is currently disabled:
# trial_results = aggregate_trial_results(fish_data, fHz)
# Process every selected file; the result dict is keyed by file path.
fish_results = aggregate_fish_results(fish_data, fHz)


Fish ID: 01, Trials: ['T1', 'T2', 'T3']
Fish ID: 02, Trials: ['T1', 'T2', 'T3']
Fish ID: 03, Trials: ['T1', 'T2', 'T4']
Fish ID: 04, Trials: ['T1', 'T2', 'T3']
Fish ID: 05, Trials: ['T1', 'T2', 'T3']
Fish ID: 06, Trials: ['T1', 'T2', 'T3']
Fish ID: 07, Trials: ['T1', 'T2', 'T3']
Fish ID: 08, Trials: ['T1', 'T2', 'T3']
Fish ID: 09, Trials: ['T1', 'T2', 'T3']
Fish ID: 10, Trials: ['T1', 'T2', 'T3']
Fish ID: 11, Trials: ['T1', 'T2', 'T3']
Fish ID: 12, Trials: ['T1', 'T2', 'T3']
Fish ID: 13, Trials: ['T1', 'T2', 'T3']
Fish ID: 14, Trials: ['T1', 'T2', 'T3']
Fish ID: 15, Trials: ['T1', 'T2', 'T3']
Fish ID: 16, Trials: ['T1', 'T2', 'T3']
Found 447 large steps
[22 47 58 78 93]
Found 14911 peaks
Found 1068 large steps
[ 39  96 110 138 160]
Found 12431 peaks
Found 711 large steps
[ 19  46  86 106 135]
Found 15081 peaks
Found 193 large steps
[ 13  36  71 113 141]
Found 15657 peaks
Found 346 large steps
[ 16  50  70  99 126]
Found 15648 peaks
Found 197 large steps
[ 12  34  71 108 151]
Found 1596

In [17]:
# Sanity check: show the container type returned by aggregate_fish_results.
print(type(fish_results))

<class 'dict'>


In [18]:
# Show the first few keys (file paths) and the type of the value stored under each.
for key in list(fish_results.keys())[:5]:  # Adjust the number to print more or fewer keys
    print(f"Key: {key}, Type of value: {type(fish_results[key])}")


Key: /home/kkumari/PhD/fish-data/long-term-free-swim/01_T1_1b8cd8200e6211edb285003053fc6914_VR03.csv.gz, Type of value: <class 'dict'>
Key: /home/kkumari/PhD/fish-data/long-term-free-swim/01_T2_667418920e7311edb1af003053fc6914_VR03.csv.gz, Type of value: <class 'dict'>
Key: /home/kkumari/PhD/fish-data/long-term-free-swim/01_T3_cf108f460e8411edb1af003053fc6914_VR03.csv.gz, Type of value: <class 'dict'>
Key: /home/kkumari/PhD/fish-data/long-term-free-swim/02_T1_1ce42cc80e6211ed9279003053fc8758_VR04.csv.gz, Type of value: <class 'dict'>
Key: /home/kkumari/PhD/fish-data/long-term-free-swim/02_T2_683acd880e7311edb6c7003053fc8758_VR04.csv.gz, Type of value: <class 'dict'>


In [19]:
# Are the top-level values DataFrames? They are per-file dicts, so this evaluates to False.
all(isinstance(value, pd.DataFrame) for value in fish_results.values())


False

In [20]:
# Summarise the result dict: number of processed files and the set of value types.
print(f"Number of keys in fish_results: {len(fish_results)}")
print(f"Types of values: {set(type(value) for value in fish_results.values())}")


Number of keys in fish_results: 48
Types of values: {<class 'dict'>}


In [21]:
# Destination folder and file-name suffix for the exported CSV files.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
output_folder_path = "/home/kkumari/PhD/fish-data/processed-data-long-term-free-swim/"
file_name_suffix = 'processed-data'
# Write every DataFrame held in fish_results to CSV; non-DataFrame entries are skipped with a message.
save_dataframes_from_nested_dict(fish_results, output_folder_path, file_name_suffix)

Debug: Extracted filename: 01_T1_1b8cd8200e6211edb285003053fc6914_VR03.csv.gz
Debug: Split parts: ['01', 'T1', '1b8cd8200e6211edb285003053fc6914', 'VR03.csv.gz']
Debug: Saving dataframe to /home/kkumari/PhD/fish-data/processed-data-long-term-free-swim/01_T1_df_processed-data.csv
Dataframe saved as '/home/kkumari/PhD/fish-data/processed-data-long-term-free-swim/01_T1_df_processed-data.csv'
Skipping non-dataframe item for key: peaks in nested dict of /home/kkumari/PhD/fish-data/long-term-free-swim/01_T1_1b8cd8200e6211edb285003053fc6914_VR03.csv.gz
Skipping non-dataframe item for key: peak_angles in nested dict of /home/kkumari/PhD/fish-data/long-term-free-swim/01_T1_1b8cd8200e6211edb285003053fc6914_VR03.csv.gz
Skipping non-dataframe item for key: wrapped_angles in nested dict of /home/kkumari/PhD/fish-data/long-term-free-swim/01_T1_1b8cd8200e6211edb285003053fc6914_VR03.csv.gz
Skipping non-dataframe item for key: unwrapped_angles in nested dict of /home/kkumari/PhD/fish-data/long-term-fre