# Trajectory analysis per single track - extended
The notebook loads 2D trajectory data, applies Savitzky-Golay smoothing, computes instantaneous and windowed features (speed, angular velocity, radius of fitted arc, etc.), and classifies behavioral states (Swim, Reorient, Reverse) based on deviation angle variance. Outputs include figures and CSV files with all calculated metrics and state assignments.

Input:
- folder_main: Path to the main data folder containing the trajectory file
- npy_file: Name of the input NumPy file with trajectory data
- fps: Frames per second of the recording
- px_mm: Conversion factor from pixels to millimeters
- Smoothing and feature extraction parameters (e.g., window sizes, thresholds)

Output:
- figures (PNG, SVG) of trajectories, kinematic features, and state ethograms
- CSV files with time series of all calculated features and state labels

In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.colors import ListedColormap
from scipy.signal import savgol_filter
from scipy.optimize import least_squares
from scipy.stats import circvar
import csv
import pandas as pd
# %matplotlib widget
%matplotlib inline

import matplotlib as mpl
# Ensure text is saved as editable text in SVG
mpl.rcParams['svg.fonttype'] = 'none'
# Set global font to Arial
mpl.rcParams['font.family'] = 'Arial'

Define input:

In [None]:
# Path to main data folder:
folder_main = "W:\\Users\\Daphne\\Imaging_Daphne\\25-12-19_RPi_ptetwt_swimming_deciliated\\bgd_subs_4000\\processed_251219_etoh_allcilia_3_analysis\\"

# input/output paths and files:
npy_file = "filtered_trajectories_2.npy"
sample_nr = "_allcilia_3" # used to name output files
file_path = folder_main + npy_file
output_path =  folder_main
fps = 30

# track_data = pd.read_csv(file_path)
track_array = np.load(file_path, allow_pickle=True)
# track_data = pd.DataFrame(track_array[0], columns=['x', 'y'])
track_data = pd.DataFrame(track_array, columns=['x', 'y'])

#OPTIONAL: Limit the number of rows for testing purposes
# track_data = track_data[:1000]  # Limit to first 1000 rows for testing 

# Define the conversion to mm:
px_mm = 91.8 # 91.8 pixels = 1 mm

Quick check that the data is read correctly:

In [None]:
track_data

Visualize the track to compare later with the smoothed version:

In [None]:
# Extract x, y, and time coordinates from the data:
x = track_data['x'] / px_mm
y = track_data['y'] / px_mm
time = pd.Series(track_data.index) / fps  # assumes the index represents the frame number
time = (time - time.iloc[0])/60  # division by 60 to get the time in min
    
# For nicer display the following translates the data to be in a range of positive x and y
# 1. Finds the center of the current data:
center_x = (np.min(x) + np.max(x)) / 2
center_y = (np.min(y) + np.max(y)) / 2
# 2. Choose a target center:
target_center_x, target_center_y = 5,5
# 3. Calculate the translation:
translation_x = target_center_x - center_x
translation_y = target_center_y - center_y
# 4. Apply the translation:
translated_x = x + translation_x
translated_y = y + translation_y
    
# Plots the trajectory for visualization:
plt.figure()
plt.scatter(translated_x, translated_y, c=time, cmap="magma", s=1) # Plot the data points with color representing time
plt.colorbar(label='time (min)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
# plt.savefig(output_path + "trajectory_no_correction.png", dpi=300, bbox_inches='tight')
plt.show()

Make convenient arrays to work with later:

In [None]:
# Combines the x and y coordinates in an array:
translation_x_array = translated_x.to_numpy()
translation_y_array = translated_y.to_numpy()
translations = np.column_stack((translation_x_array, translation_y_array))

# Prepares the time axis in seconds for the plots:
time_steps = (np.diff(time.to_numpy()).reshape(-1, 1)) * 60   # step between consecutive frames in seconds
time_axis_sec = np.cumsum(time_steps)                         # timestamps in seconds; the first entry is one frame step, not 0
time_axis_sec_with_0 = np.insert(time_axis_sec, 0, 0)

Below we calculate the distance travelled, used for visual checks:

In [None]:
# Calculate the cumulative distance travelled:
dx = np.diff(translations[:,0])  
dy = np.diff(translations[:,1])  
distance = np.sqrt(dx**2 + dy**2)  
cumulative_distance = np.cumsum(distance)

## 1. Noise estimation and trajectory smoothing:

#### We need to remove noise (e.g. from tracking) while preserving the features and transitions in the data.

Reorientations take roughly 1/2 sec to complete, so the smoothing window should not be larger than that. At 30 fps this corresponds to a window of about 9-15 frames.
Avoid a moving average, because it flattens such features. A filter based on local polynomial fitting, e.g. Savitzky-Golay, is a better choice.

**Savitzky-Golay filter**: preserves features like peaks or rapid changes in the trajectory. It works by fitting a low-order polynomial to successive small, overlapping windows of the data. The degree of smoothing is tuned with window_length and polyorder.
Because it preserves features, the window can be slightly larger than the timescale of the dynamics we anticipate.
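
The difference between the two filters can be illustrated on synthetic data. The sketch below is not part of the analysis: the peak shape, noise level and seed are assumptions chosen only to make the comparison visible. It smooths a narrow peak with a Savitzky-Golay filter and with a plain moving average of the same window, then compares how much of the peak height each one loses.

```python
import numpy as np
from scipy.signal import savgol_filter

rng = np.random.default_rng(0)                    # fixed seed so the sketch is reproducible
n = np.arange(201)
clean = np.exp(-0.5 * ((n - 100) / 4.0) ** 2)     # narrow peak of height 1 at sample 100
noisy = clean + rng.normal(0.0, 0.02, n.size)     # assumed tracking-like noise

window, polyorder = 15, 2
sg = savgol_filter(noisy, window, polyorder)                       # local polynomial fit
ma = np.convolve(noisy, np.ones(window) / window, mode="same")     # plain moving average

# How far is each smoothed signal from the true peak height?
sg_peak_error = abs(sg[100] - clean[100])
ma_peak_error = abs(ma[100] - clean[100])
print(f"peak height error: Savitzky-Golay {sg_peak_error:.3f}, moving average {ma_peak_error:.3f}")
```

With these settings the moving average flattens the peak far more than the polynomial fit does, which is the reason for preferring Savitzky-Golay when short events such as reorientations must survive the smoothing.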


In [None]:
# Target window size for the trajectory smoothing:
target_window_smoothing = 1/2  # in sec; converted below to an odd number of frames
window_smoothing_SG = int(np.ceil(target_window_smoothing * fps))
window_smoothing_SG_odd = window_smoothing_SG if window_smoothing_SG % 2 != 0 else window_smoothing_SG + 1
print("chosen window size: ", int(window_smoothing_SG_odd))
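
The same seconds-to-odd-frames conversion is needed again later for the feature window, so it can be wrapped in a small helper. The function name `odd_window` is hypothetical and only illustrates the rounding rule used in the cell above.

```python
import numpy as np

def odd_window(duration_sec, fps):
    """Smallest odd number of frames covering at least duration_sec (hypothetical helper)."""
    n_frames = int(np.ceil(duration_sec * fps))
    # Bump even counts up by one so the window has a well-defined centre frame.
    return n_frames if n_frames % 2 == 1 else n_frames + 1

print(odd_window(0.5, 30))   # 15
print(odd_window(0.2, 30))   # 7
```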

Below we plot the x and y coordinates and the cumulative distance before and after smoothing, to check whether we preserve features while reducing noise or whether the window is too big/small.

In [None]:
# Track smoothing with Savitzky-Golay
t = time_axis_sec_with_0
noisy_signal_x = translations[:,0]   
noisy_signal_y = translations[:,1]

# 1. Apply Savitzky-Golay filter
window_length = window_smoothing_SG_odd  # It has to be an odd number! 15 frames for 0.5 sec at 30 fps
polyorder = 2       # Usually 2 or 3 
smoothed_x = savgol_filter(noisy_signal_x, window_length, polyorder)
smoothed_y = savgol_filter(noisy_signal_y, window_length, polyorder)
smoothed_translations = np.column_stack((smoothed_x, smoothed_y))

# 2. Plot result
fontsize_1 = 8
plt.figure(figsize = (6,1),dpi=300)
plt.ylabel("x (mm)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.plot(t, smoothed_x, label="Smoothed Data")
plt.title("Savitzky-Golay filtered x - window " + str(window_length) + ", polyorder " + str(polyorder), fontsize = fontsize_1)
plt.grid()
# plt.savefig(output_path + "x_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

plt.figure(figsize = (6,1),dpi=300)
plt.ylabel("x (mm)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.plot(t, noisy_signal_x, label="Noisy data")
plt.title("Original x", fontsize = fontsize_1)
plt.grid()
# plt.savefig(output_path + "x_in_time_raw.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

plt.figure(figsize = (6,1),dpi=300)
plt.ylabel("y (mm)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.plot(t, smoothed_y, label="Smoothed Data")
plt.title("Savitzky-Golay filtered y - window " + str(window_length) + ", polyorder " + str(polyorder), fontsize = fontsize_1)
plt.grid()
# plt.savefig(output_path + "y_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

plt.figure(figsize = (6,1),dpi=300)
plt.ylabel("y (mm)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.plot(t, noisy_signal_y, label="Noisy data")
plt.title("Original y", fontsize = fontsize_1)
plt.grid()
# plt.savefig(output_path + "y_in_time_raw.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
# Calculate the cumulative distance travelled:
dx_SG = np.diff(smoothed_translations[:,0])  
dy_SG = np.diff(smoothed_translations[:,1])  
distance_SG = np.sqrt(dx_SG**2 + dy_SG**2)  
cumulative_distance_SG = np.cumsum(distance_SG)

plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec, cumulative_distance)
plt.title("Cumulative distance travelled", fontsize = fontsize_1)
plt.ylabel("d (mm)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
# plt.savefig(output_path + "cumul_distance_in_time_no_smoothing.png", dpi=300, bbox_inches='tight', pad_inches=0.1)

plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec, cumulative_distance_SG)
plt.title("Cumulative distance travelled - Savitzky-Golay", fontsize = fontsize_1)
plt.ylabel("d (mm)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
# plt.savefig(output_path + "cumul_distance_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)

cmap = mpl.cm.Greys(np.linspace(0,1,250))
cmap = mpl.colors.ListedColormap(cmap[100:,:-1])

# Plots the smoothed trajectory for visualization:
plt.figure()
# Plot the data points with color representing time
plt.scatter(smoothed_x, smoothed_y, c=time, cmap=cmap, s=2)
# plt.plot(translated_x, translated_y)
plt.colorbar(label='time (min)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
# plt.xlim(8,11)
# plt.ylim(8,11)
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight')
plt.savefig(output_path + "trajectory_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".svg", dpi=300, bbox_inches='tight')
plt.show()

## Extracting linear + angular velocity, speed and reorientation

#### Main trajectory parameters are calculated below. Because single points can produce abrupt peaks, a moving average with a small window can be applied to reduce the noise.

Instantaneous velocity and its amplitude (speed):

In [None]:
### Calculate instantaneous linear velocity and speed:
translations_to_analyze = smoothed_translations  # if you want to check with the raw trajectory set: translations_to_analyze =translations
velocities = []
velocities = np.diff(translations_to_analyze, axis = 0) / time_steps
speed_cell = np.linalg.norm(velocities, ord=2,  axis=1) 


# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec,speed_cell, color='0.2', linewidth=0.5) #,alpha=0.7)
# plt.plot(time_axis_sec, speed_cell_smooth, label='Smoothed Speed', color='orange', linewidth=2)
plt.title("Speed in time", fontsize = fontsize_1)
plt.ylabel("v (mm /sec)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(False)
plt.savefig(output_path + "instantaneous_speed_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.savefig(output_path + "instantaneous_speed_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".svg", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(speed_cell, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('instantaneous speed  (mm/sec)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
# plt.savefig(output_path + "instantaneous_speed_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing the instantaneous speed
plt.scatter(translations_to_analyze[:-1,0], translations_to_analyze[:-1,1], c= speed_cell, cmap="magma", s=10)
plt.colorbar(label='v (mm/sec)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
# plt.savefig(output_path + "trajectory_speed_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()
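
As a sanity check of the velocity and speed calculation, the same steps can be run on a synthetic straight track whose speed is known in advance. The calibration of 100 px per mm below is a hypothetical value chosen for round numbers, not the one used in this notebook.

```python
import numpy as np

fps = 30.0
px_per_mm = 100.0                       # hypothetical calibration, for round numbers only
frames = np.arange(10)
xy_px = np.column_stack((3.0 * frames, 4.0 * frames))   # 5 px per frame along a straight line

xy_mm = xy_px / px_per_mm               # pixel coordinates to mm
dt = 1.0 / fps                          # time between frames in seconds
velocity = np.diff(xy_mm, axis=0) / dt  # shape (n - 1, 2), in mm/sec
speed = np.linalg.norm(velocity, axis=1)

# 5 px/frame = 0.05 mm/frame, times 30 frames/sec gives 1.5 mm/sec
print(speed[0])
```

Note that n positions give n - 1 velocities, which is why the plots above drop the last trajectory point when coloring the track by speed.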

Relative angle between consecutive velocity vectors, and its absolute value if we don't care about directionality:

In [None]:
### Calculate instantaneous angles and reorientation:
dot_product = np.sum(velocities[:-1, :] * velocities[1:, :], axis=1)
norms_prev = np.linalg.norm(velocities[:-1, :], axis=1)
norms_next = np.linalg.norm(velocities[1:, :], axis=1)
denominator = np.clip(norms_prev * norms_next, a_min=1e-18, a_max=None)
cosine_of_consequetive_angle = dot_product / denominator

# # Angle from 0 to π
angles = np.arccos(np.clip(cosine_of_consequetive_angle, -1.0, 1.0))
# Correction to have it from -π to +π:
cross_product_z = (velocities[:-1, 0] * velocities[1:, 1] - 
                   velocities[:-1, 1] * velocities[1:, 0])
# Adjust angle based on cross product sign
signed_angles_in_rad = np.where(cross_product_z >= 0, angles, -angles)
signed_angles_in_deg = np.degrees(signed_angles_in_rad)
absolute_signed_angles_in_rad = np.abs(signed_angles_in_rad)
absolute_signed_angles_in_deg = np.abs(signed_angles_in_deg)

# Plot results:
# 1.a. Δθ: (θ_i) - θ_(i-1) , negative when turning towards right, positive when turning towards left (cell frame of reference)
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec[:-1],signed_angles_in_deg)
plt.title("Instantaneous reorientation in time", fontsize = fontsize_1)
plt.ylabel("\u03B8$_{\\mathrm{i}}$ - \u03B8$_\\mathrm{i-1}$", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "inst_reorientation_with_sign_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 1.b. Corresponding histogram:
fontsize_2 = 5
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(signed_angles_in_deg, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('instantaneous reorientation (degrees)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "inst_reorientation_with_sign_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 1.c. Color coding the track:
plt.figure()
# Plot the data points with color representing the signed reorientation angle
plt.scatter(translations_to_analyze[1:-1,0], translations_to_analyze[1:-1,1], c= signed_angles_in_deg, cmap="magma", s=10)
plt.colorbar(label='\u03B8$_{\\mathrm{i}}$ - \u03B8$_\\mathrm{i-1}$ (degrees)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_inst_reorient_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. absolute Δθ: amplitude of consecutive turns
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec[:-1], absolute_signed_angles_in_deg)
plt.title("Absolute instantaneous reorientation in time", fontsize = fontsize_1)
plt.ylabel("|\u03B8$_{\\mathrm{i}}$ - \u03B8$_\\mathrm{i-1}|$", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "inst_reorientation_absolute_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2.b. Corresponding histogram:
fontsize_2 = 5
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(absolute_signed_angles_in_deg, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('absolute instantaneous reorientation (degrees)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "inst_reorientation_with_sign_absolute_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2.c. Color coding the track:
plt.figure()
# Plot the data points with color representing the absolute reorientation angle
plt.scatter(translations_to_analyze[1:-1,0], translations_to_analyze[1:-1,1], c= absolute_signed_angles_in_deg, cmap="magma", s=10)
plt.colorbar(label='|\u03B8$_{\\mathrm{i}}$ - \u03B8$_\\mathrm{i-1}|$')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_inst_reorient_abs_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()
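
The signed angle between consecutive velocity vectors can also be obtained in a single step with `arctan2` applied to the cross and dot products. This is an alternative formulation, not the one used in the cell above; the function name `signed_turn_angles` is hypothetical. For non-zero velocities the two routes should agree, since both use the same dot and cross products.

```python
import numpy as np

def signed_turn_angles(velocities):
    """Signed angle (rad, in (-pi, pi]) from each velocity vector to the next one."""
    v_prev, v_next = velocities[:-1], velocities[1:]
    dot = np.sum(v_prev * v_next, axis=1)
    cross_z = v_prev[:, 0] * v_next[:, 1] - v_prev[:, 1] * v_next[:, 0]
    # arctan2 handles the sign and the full angular range without clipping
    return np.arctan2(cross_z, dot)

# Heading east, then north (left turn), then east again (right turn), then straight on
v = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0]])
print(np.degrees(signed_turn_angles(v)))   # [ 90. -90.   0.]
```

One behavioral difference is worth knowing: for a zero-length velocity (a stationary step) this version returns 0, whereas the arccos route with a clipped denominator appears to return 90 degrees.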

The cumulative angle if we want to see more global trends:

In [None]:
### Cumulative reorientation:
cumulative_angles_in_rad = np.cumsum(signed_angles_in_rad)
cumulative_angles_in_deg = np.cumsum(signed_angles_in_deg)

# 1. Cumulative angles in time
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec[:-1], cumulative_angles_in_deg)
plt.title("Cumulative reorientation in time", fontsize = fontsize_1)
plt.ylabel("\u03B8", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "cumul_reorientation_with_sign_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram of cumulative angles
fontsize_2 = 5
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(cumulative_angles_in_deg, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('cumulative reorientation (degrees)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "cumul_reorientation_with_sign_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

Instantaneous angular velocity:

In [None]:
### Calculate angular velocity:

ang_velocity = signed_angles_in_rad /  time_steps.flatten()[1:]

# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec[:-1], ang_velocity)
plt.title("Angular velocity in time", fontsize = fontsize_1)
plt.ylabel("\u03A9 (rad/sec)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "instantaneous_ang_vel_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(ang_velocity, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('\u03A9 (rad/sec)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "instantaneous_ang_vel_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing the angular velocity
plt.scatter(translations_to_analyze[1:-1,0], translations_to_analyze[1:-1,1], c= ang_velocity, cmap="magma", s=10)
plt.colorbar(label='\u03A9 (rad/sec)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_ang_vel_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()
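
A quick way to validate the angular velocity is a synthetic circular track, where the expected value is known: a point advancing by a fixed angle per frame turns at that angle times the frame rate, and its speed divided by the radius should give nearly the same number. The radius and angular step below are assumed values for illustration.

```python
import numpy as np

fps = 30.0
radius_mm = 2.0                       # assumed circle radius
dphi = 0.05                           # assumed angle advanced per frame, in rad
phi = dphi * np.arange(200)
xy = radius_mm * np.column_stack((np.cos(phi), np.sin(phi)))

dt = 1.0 / fps
velocities = np.diff(xy, axis=0) / dt
v_prev, v_next = velocities[:-1], velocities[1:]
# Signed turning angle between consecutive velocity vectors
turn = np.arctan2(v_prev[:, 0] * v_next[:, 1] - v_prev[:, 1] * v_next[:, 0],
                  np.sum(v_prev * v_next, axis=1))
ang_velocity = turn / dt              # rad/sec
speed = np.linalg.norm(velocities, axis=1)

print(ang_velocity[0], speed[0] / radius_mm)   # both close to dphi * fps = 1.5 rad/sec
```

The small gap between the two numbers comes from measuring chords rather than arcs between frames, and it shrinks as the angular step per frame gets smaller.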

## Extracting additional metrics: RFA, MAID, DAV, velocity variance

These metrics are calculated for a given window in time.

In [None]:
def radius_of_fitter_circle(points):
    # Define the function to minimize
    def residuals(params, x, y):
        h, k, r = params
        return ((x - h)**2 + (y - k)**2) - r**2

    # Initial guess for h, k, r
    h0, k0 = np.mean(points[:, [0,1]], axis=0)
    r0 = np.sqrt(np.mean((points[:,0] - h0)**2 + (points[:,1] - k0)**2))

    # Least squares optimization
    result = least_squares(residuals, x0=[h0, k0, r0], args=(points[:,0], points[:,1]))

    # Extract the radius; the residual depends only on r**2, so the sign of r is arbitrary
    _, _, radius = result.x
    return abs(radius)

def compute_features_for_window_size(trajectory, velocities, speed_cell, signed_angles_in_rad, window_size):
        current_trajectory = trajectory # .smoothed_trajectory[1:-1, :]
        radius_of_fitted_arc = []
        average_incremental_discplacements = []
        variance_in_directionality = []
        variance_in_speed =[]
        
        for i in range(current_trajectory.shape[0]):
            # Step 1: Determine loop boundaries
            start = max(0,  i - window_size//2)
            end = min(current_trajectory.shape[0], i + window_size//2)

            # If the window is clipped at either end of the track, shift it so it still spans window_size points
            if start == 0:
                end = min(current_trajectory.shape[0], start + window_size) 
            elif end == current_trajectory.shape[0]:
                start = max(0, end - window_size)

#             xy = current_trajectory[start:end,[1,2]]
            xy = current_trajectory[start:end, :]  # Use all columns (x, y)
            radius_of_fitted_arc.append(radius_of_fitter_circle(points=xy)) # Radius of fitted circle
            average_incremental_discplacements.append(np.mean(speed_cell[start:end]))  # mean speed over the window; summing the x and y velocity components does not give a distance
            variance_in_directionality.append(circvar(signed_angles_in_rad[start:end], high=np.pi, low=-np.pi))
            variance_in_speed.append(np.var(speed_cell[start:end]))
            
        return pd.DataFrame({
                            f'{window_size}_RFA' : np.asanyarray(radius_of_fitted_arc),
                            f'{window_size}_MAID' : np.asanyarray(average_incremental_discplacements), 
                            f'{window_size}_DAV' : np.asanyarray(variance_in_directionality),
                            f'{window_size}_VS' : np.asanyarray(variance_in_speed),
                            })
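
To check that a circle fit recovers a known radius, it helps to test it on a synthetic arc. The sketch below deliberately swaps in a different technique from the function above: a linear algebraic fit (often called the Kåsa method) solved with `np.linalg.lstsq`, which has no initial guess and no iteration. The function name `fit_circle_linear` and the test arc are assumptions for illustration.

```python
import numpy as np

def fit_circle_linear(points):
    """Algebraic circle fit: solve x^2 + y^2 = 2*h*x + 2*k*y + c in the least-squares sense."""
    x, y = points[:, 0], points[:, 1]
    A = np.column_stack((2 * x, 2 * y, np.ones_like(x)))
    b = x**2 + y**2
    (h, k, c), *_ = np.linalg.lstsq(A, b, rcond=None)
    # c = r^2 - h^2 - k^2, so the radius follows from the fitted centre
    return h, k, np.sqrt(c + h**2 + k**2)

# Quarter arc of a circle with centre (1, -1) and radius 2, sampled at 15 points
theta = np.linspace(0.0, np.pi / 2, 15)
arc = np.column_stack((1.0 + 2.0 * np.cos(theta), -1.0 + 2.0 * np.sin(theta)))
h, k, r = fit_circle_linear(arc)
print(h, k, r)   # close to 1, -1, 2
```

For points that lie almost on a straight line the fitted radius becomes very large and ill-conditioned with either method, so large RFA values are best read as "nearly straight" rather than as precise radii.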

The window should not be much longer than the timescale of the dynamics of interest (here 1/2 sec):

In [None]:
target_window_features = 1/2  # in sec
window_features = int(np.ceil(target_window_features * fps))
window_features_odd = window_features if window_features % 2 != 0 else window_features + 1
track_features = compute_features_for_window_size(translations_to_analyze, velocities, speed_cell, signed_angles_in_rad,  window_features_odd)

rfa  = track_features[str(window_features_odd) + '_RFA']
maid = track_features[str(window_features_odd) + '_MAID']
dav  = track_features[str(window_features_odd) + '_DAV']
velocity_variance  = track_features[str(window_features_odd) + '_VS']
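
The DAV is a circular variance, which behaves differently from the ordinary variance when angles wrap around ±π. The sketch below computes it by hand, assuming the definition 1 - R, where R is the length of the mean unit vector; recent SciPy documentation gives this definition for `circvar`, but older releases may differ. The function name `circular_variance` and the example angles are hypothetical.

```python
import numpy as np

def circular_variance(angles_rad):
    """1 - R, where R is the length of the mean of the unit vectors exp(i*theta)."""
    return 1.0 - np.abs(np.mean(np.exp(1j * np.asarray(angles_rad))))

steady = np.radians([5.0, 5.0, 5.0, 5.0])        # always the same turn: no spread
spread = np.radians([-180.0, -90.0, 0.0, 90.0])  # evenly spread around the circle
wrapped = np.radians([179.0, -179.0])            # only 2 degrees apart across the wrap

print(circular_variance(steady))    # ~0
print(circular_variance(spread))    # ~1
print(circular_variance(wrapped), np.var(wrapped))  # tiny circular variance, large ordinary variance
```

Values near 0 therefore indicate a consistent turning angle within the window, and values near 1 indicate erratic changes of direction.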


Plotting results for RFA:

In [None]:
# RFA
rfa_array = rfa.to_numpy()
rfa_array = rfa_array.reshape(1, -1)[0]

# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec_with_0, rfa_array)
plt.title("Radius of Fitted Arc (RFA)", fontsize = fontsize_1)
plt.ylabel("RFA (mm)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "rfa_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(rfa_array, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('RFA (mm)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "rfa_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing the RFA
plt.scatter(translations_to_analyze[:,0], translations_to_analyze[:,1], c= rfa_array, cmap="magma", s=10)
plt.colorbar(label='RFA (mm)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_rfa_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

Plotting results for MAID:

In [None]:
# MAID
maid_array = maid.to_numpy()
maid_array = np.abs(maid_array.reshape(1, -1)[0])

# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec_with_0, maid_array)
plt.title("Moving Average of Incremental Distances (MAID)", fontsize = fontsize_1)
plt.ylabel("MAID (mm/sec)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "maid_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(maid_array, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('MAID (mm/sec)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "maid_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing the MAID
plt.scatter(translations_to_analyze[:,0], translations_to_analyze[:,1], c= maid_array, cmap="magma", s=10)
plt.colorbar(label='MAID (mm/sec)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_maid_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

Plotting results for DAV:

In [None]:
# DAV
dav_array = dav.to_numpy()
dav_array = dav_array.reshape(1, -1)[0]

# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec_with_0, dav_array, color='0.2', linewidth=0.5)
plt.title("Deviation Angle Variance (DAV)", fontsize = fontsize_1)
plt.ylabel("DAV", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(False)
plt.savefig(output_path + "dav_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.savefig(output_path + "dav_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".svg", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(dav_array, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('DAV', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "dav_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing the DAV
plt.scatter(translations_to_analyze[:,0], translations_to_analyze[:,1], c= dav_array, cmap="magma", s=10)
plt.colorbar(label='DAV')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_dav_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

Save all relevant data calculated above.

In [None]:
def save_output_csv(output_file, velocity_x, velocity_y, speed,signed_angles_in_deg, translation_x, translation_y, maid, dav):
    """
    Save the per-frame kinematic measurements (velocity, speed, signed angle,
    position, MAID, DAV) to a CSV file, appending if the file already exists.
    """
    # Check if the file exists
    file_exists = os.path.exists(output_file)

    new_signed_angles_in_deg = np.concatenate(([0], signed_angles_in_deg))
    # Ensure all arrays have the same length
    n = len(velocity_x)
    assert len(velocity_y) == n
    assert len(speed) == n
    assert len(translation_x) == n
    assert len(translation_y) == n
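    assert len(new_signed_angles_in_deg) == n   # padded with a leading 0 to match the velocity length
    assert len(maid) >= n and len(dav) >= n     # windowed features have one entry per trajectory point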
    
    
    with open(output_file, mode='a' if file_exists else 'w', newline='') as file:
        writer = csv.writer(file)
    
        # Write header if the file is newly created
        if not file_exists:
            writer.writerow(['velocity_x (mm/sec)', 'velocity_y (mm/sec)', 'speed (mm/sec)', 'signed_angles_in_deg', 'translation_x (mm)', 'translation_y (mm)','maid (mm/sec)', 'dav'])
    
        # Write data to CSV
        for i in range(n):
            writer.writerow([velocity_x[i], velocity_y[i], speed[i], new_signed_angles_in_deg[i], translation_x[i], translation_y[i], maid[i], dav[i]])

    return

# output_path = main_path +"results\\" +"_basic_measurements.csv"
output_path_csv = os.path.join(output_path ,"basic_measurements_SG" + str(window_length) + "_" +  str(polyorder) + ".csv")
# save_output_csv(output_path_csv, velocities[:,0], velocities[:,1], speed_cell, signed_angles_in_deg, translations_to_analyze[:-1,0], translations_to_analyze[:-1,1], rfa_array, maid_array, dav_array, time_axis_sec)
save_output_csv(output_path_csv, velocities[:,0], velocities[:,1], speed_cell, signed_angles_in_deg, translations_to_analyze[:-1,0], translations_to_analyze[:-1,1], maid_array, dav_array)
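
An alternative way to write such a table is a pandas DataFrame, which keeps column names and values together and makes length mismatches fail early. This is a sketch on made-up values, not a replacement for the function above; the file name and numbers are hypothetical, and padding the shorter angle series with NaN instead of 0 is a design choice that marks the first frame as "undefined" rather than "no turn".

```python
import os
import tempfile
import numpy as np
import pandas as pd

n = 5
speed = np.linspace(1.0, 2.0, n)                 # made-up speeds, one per velocity sample
turn_deg = np.array([10.0, -5.0, 0.0, 2.0])      # one entry shorter: an angle needs two velocities
turn_padded = np.concatenate(([np.nan], turn_deg))

# The DataFrame constructor raises if the columns have different lengths
table = pd.DataFrame({"speed (mm/sec)": speed, "signed_angles_in_deg": turn_padded})

out_file = os.path.join(tempfile.mkdtemp(), "example_measurements.csv")
table.to_csv(out_file, index=False)
read_back = pd.read_csv(out_file)
print(read_back)
```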

## Identifying states

First apply a moving-average filter with a very small window to the linear and angular velocity, then recalculate some metrics from the filtered signals to decide on the classification criteria. Next, set thresholds for the states of interest (for Paramecium the DAV is used) and identify the borders of each segment of the classified track. With those indices, the corresponding segments can be separated and studied in isolation.
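
The thresholding and border-finding steps described here can be sketched on a toy signal. The threshold of 0.3, the example values and the helper name `find_segments` are all hypothetical; they only show one common way to turn a boolean mask into start and end indices, and the thresholds used for the actual classification may differ.

```python
import numpy as np

def find_segments(mask):
    """Return (start, end) index pairs of consecutive True runs; end is exclusive."""
    mask = np.asarray(mask, dtype=bool)
    # Pad with False so runs touching either end of the array are also detected
    padded = np.concatenate(([False], mask, [False]))
    changes = np.diff(padded.astype(int))
    starts = np.flatnonzero(changes == 1)
    ends = np.flatnonzero(changes == -1)
    return list(zip(starts.tolist(), ends.tolist()))

dav_example = np.array([0.01, 0.02, 0.40, 0.55, 0.03, 0.02, 0.60, 0.01])  # made-up DAV values
dav_threshold = 0.3                                                       # hypothetical threshold
segments = find_segments(dav_example > dav_threshold)
print(segments)   # [(2, 4), (6, 7)]
```

Each pair can be used directly as a slice, e.g. `dav_example[2:4]`, to study one segment in isolation.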

In [None]:
### Calculate instantaneous linear velocity and speed:
window_size_mov_av = 3  # keep it very small
speed_series = pd.Series(speed_cell)
speed_cell_smooth = speed_series.rolling(window=window_size_mov_av, center=True, min_periods=1).mean()

# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec,speed_cell, alpha=0.6)
plt.plot(time_axis_sec, speed_cell_smooth, label='Smoothed Speed', color='orange', linewidth=2)
plt.title("Speed in time", fontsize = fontsize_1)
plt.ylabel("v (mm /sec)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "instantaneous_speed_mov_av_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(speed_cell_smooth, bins=30, density=True, alpha=0.5, color='orange', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('instantaneous speed (mm/sec)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "instantaneous_speed_mov_av_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing speed
plt.scatter(translations_to_analyze[:-1,0], translations_to_analyze[:-1,1], c= speed_cell_smooth, cmap="magma", s=10)
plt.colorbar(label='v (mm/sec)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_speed_mov_av_smooth_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
### Calculate angular velocity:
# To smooth for noise/abrupt peaks:
window_size_mov_av = 3  # keep it very small
ang_velocity_series = pd.Series(ang_velocity)
ang_velocity_cell_smooth = ang_velocity_series.rolling(window=window_size_mov_av, center=True, min_periods=1).mean()

# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
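# Angular velocity is defined at the interior points translations_to_analyze[1:-1],
# so it is plotted against time_axis_sec[1:] (as in the per-state plots further below)
plt.plot(time_axis_sec[1:], ang_velocity, alpha=0.6)
plt.plot(time_axis_sec[1:], ang_velocity_cell_smooth, label='Smoothed Angular velocity', color='orange', linewidth=2)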
plt.title("Angular velocity in time", fontsize = fontsize_1)
plt.ylabel("\u03A9 (rad/sec)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
# plt.savefig(output_path + "instantaneous_ang_vel_mov_av_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(ang_velocity_cell_smooth, bins=30, density=True, alpha=0.5, color='orange', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('\u03A9 (rad/sec)', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
# plt.savefig(output_path + "instantaneous_ang_vel_mov_av_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing angular velocity
plt.scatter(translations_to_analyze[1:-1,0], translations_to_analyze[1:-1,1], c= ang_velocity_cell_smooth, cmap="magma", s=10)
plt.colorbar(label='\u03A9 (rad/sec)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
# plt.savefig(output_path + "trajectory_ang_vel_mov_av_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
target_window_features = 0.5  # in sec
window_features = int(np.ceil(target_window_features * fps))
window_features_odd = window_features if window_features % 2 != 0 else window_features + 1
track_features = compute_features_for_window_size(translations_to_analyze, velocities, speed_cell_smooth, signed_angles_in_rad,  window_features_odd)

rfa  = track_features[str(window_features_odd) + '_RFA']
maid = track_features[str(window_features_odd) + '_MAID']
dav  = track_features[str(window_features_odd) + '_DAV']
velocity_variance  = track_features[str(window_features_odd) + '_VS']

In [None]:
# DAV
dav_array = dav.to_numpy()
dav_array = dav_array.reshape(1, -1)[0]

# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec_with_0, dav_array)
plt.title("Deviation Angle Variance (DAV)", fontsize = fontsize_1)
plt.ylabel("DAV", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "dav_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(dav_array, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('DAV', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "dav_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing DAV
plt.scatter(translations_to_analyze[:,0], translations_to_analyze[:,1], c= dav_array, cmap="magma", s=10)
plt.colorbar(label='DAV')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_dav_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
## Velocity variance

# 1. Time evolution:
plt.figure(figsize = (6,1),dpi=300)
plt.plot(time_axis_sec_with_0, velocity_variance)
plt.title("Variance of speed", fontsize = fontsize_1)
plt.ylabel("VS", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(True)
plt.savefig(output_path + "var_speed_in_time_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 2. Histogram:
fontsize_2 = 8
plt.figure(figsize=(2, 1),dpi=300)
plt.hist(velocity_variance, bins=30, density=True, alpha=0.5, color='teal', edgecolor='black', label='data')
plt.ylim()
plt.xlabel('VS', fontsize = fontsize_2)
plt.ylabel('density', fontsize = fontsize_2)
plt.xticks(fontsize = fontsize_2)
plt.yticks(fontsize = fontsize_2)
plt.grid(False)
plt.savefig(output_path + "var_speed_histogram_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# 3. Color coding the track:
plt.figure()
# Plot the data points with color representing the variance of speed
plt.scatter(translations_to_analyze[:,0], translations_to_analyze[:,1], c= velocity_variance, cmap="magma", s=10)
plt.colorbar(label='VS')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(output_path + "trajectory_var_speed_smoothing_SG" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

Definition of thresholds and state classification:
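
A minimal sketch of the two-threshold rule applied in the next cell, run on a handful of made-up DAV values. The array `dav_example` is hypothetical; the two threshold values mirror `threshold_dav1` and `threshold_dav2` from the cell below and should be tuned per dataset.

```python
import numpy as np

# Hypothetical DAV samples chosen to sit below, on, and above the thresholds
dav_example = np.array([0.001, 0.009, 0.01, 0.02, 0.039, 0.04, 0.12])
low, high = 0.01, 0.04  # same values as threshold_dav1 / threshold_dav2

# 0 = Swim, 1 = Reorient, 2 = Reverse
is_reorient = (dav_example >= low) & (dav_example < high)
is_reverse = dav_example >= high
states_example = np.select([is_reorient, is_reverse], [1, 2], default=0)
print(states_example)
```

Both lower bounds are inclusive, so a value exactly on a threshold is assigned to the higher state.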

In [None]:
# Swim → Reorient → Swim
# or
# Swim → Reverse → Swim
# but never Reorient → Reverse or Reverse → Reorient directly

# --- Thresholds ---
threshold_dav1 = 0.01
threshold_dav2 = 0.04
# threshold_vv = 0.005 # if you want to add the velocity variance
merge_gap = 10  # max gap (in frames) to merge close bouts

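# --- Optional smoothing of dav ---
# (requires: from scipy.ndimage import gaussian_filter1d)
# dav_smooth = gaussian_filter1d(dav_array, sigma=2)
# OR use the unsmoothed DAV as a 1D NumPy array, so that the state array below is 1D as well:
dav_smooth = dav_array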

# Plot smoothed DAV
plt.figure(figsize=(6, 2), dpi=300)
plt.plot(time_axis_sec_with_0, dav_smooth, label='Smoothed DAV', color='orange', linewidth=1)
plt.title("Deviation Angle Variance (DAV)", fontsize=fontsize_1)
plt.ylabel("DAV", fontsize=fontsize_1)
plt.xlabel("time (sec)", fontsize=fontsize_1)
plt.xticks(fontsize=fontsize_1)
plt.yticks(fontsize=fontsize_1)
plt.grid(True)
plt.savefig(output_path + "dav_smooth_in_time.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# --- State classification ---
# 0 = Swim, 1 = Reorient, 2 = Reverse
cond_Reorient = (dav_smooth >= threshold_dav1) & (dav_smooth < threshold_dav2) #& (velocity_variance > threshold_vv)
cond_Reverse = (dav_smooth >= threshold_dav2) #& (velocity_variance > threshold_vv)
states = np.select([cond_Reorient, cond_Reverse], [1, 2], default=0)

# --- Helper functions ---
# When a Reorient (1) and Reverse (2) are directly next to each other, they should be merged into a single Reverse (2) state.
def merge_Reorient_Reverse_to_Reverse(states):
    """
    When Reorient (1) and Reverse (2) are adjacent, merge both into Reverse (2).
    """
    states = states.copy()
    i = 0
    while i < len(states) - 1:
        if (states[i] in [1, 2]) and (states[i + 1] in [1, 2]) and (states[i] != states[i + 1]):
            # Find start of the contiguous 1-2 or 2-1 segment
            start = i
            while start > 0 and states[start - 1] in [1, 2]:
                start -= 1

            # Find end of the contiguous 1-2 or 2-1 segment
            end = i + 1
            while end < len(states) - 1 and states[end + 1] in [1, 2]:
                end += 1

            # Convert all to Reverse (2)
            states[start:end + 1] = 2
            i = end + 1
        else:
            i += 1
    return states


def merge_short_gaps(states, target_state, gap_max_len=10):
    """
    Fill short Swim (0) gaps between bouts of the same target_state.
    """
    states = states.copy()
    i = 0
    while i < len(states):
        if states[i] == target_state:
            # Find end of current bout
            start = i
            while i < len(states) and states[i] == target_state:
                i += 1
            end = i

            # Look ahead for a short gap followed by same state
            gap_start = i
            while i < len(states) and states[i] == 0:
                i += 1
            gap_end = i

            if i < len(states) and states[i] == target_state:
                if (gap_end - gap_start) <= gap_max_len:
                    states[gap_start:gap_end] = target_state
        else:
            i += 1
    return states

def remove_short_states(states, min_len=6):
    """
    Convert short bouts of Reorient (1) or Reverse (2) into Swim (0) if their duration < min_len frames.
    """
    states = states.copy()
    i = 0
    while i < len(states):
        current_state = states[i]
        if current_state in [1, 2]:
            start = i
            while i < len(states) and states[i] == current_state:
                i += 1
            end = i
            if (end - start) < min_len:
                states[start:end] = 0
        else:
            i += 1
    return states

# --- Apply cleaning ---
states = merge_Reorient_Reverse_to_Reverse(states)
states = merge_short_gaps(states, target_state=1, gap_max_len=merge_gap)
states = merge_short_gaps(states, target_state=2, gap_max_len=merge_gap)
# Remove short Reorient/Reverse bouts (≤ 5 frames)
states = remove_short_states(states, min_len=6)


# --- Optional final smoothing (use with care) ---
# (requires: from scipy.ndimage import median_filter)
# states = median_filter(states, size=5)

# --- State names ---
state_labels = np.array(["Swim", "Reorient", "Reverse"])
state_names = state_labels[states]

# --- Plotting ethogram ---
plt.figure(figsize=(10, 2), dpi=300)

for i, label in enumerate(state_labels):
    mask = states == i
    plt.plot(time_axis_sec_with_0[mask], states[mask], linestyle='None', marker='|',
             markersize=10, label=label)

plt.title("State ethogram", fontsize=fontsize_1)
plt.ylabel("state", fontsize=fontsize_1)
plt.xlabel("time (sec)", fontsize=fontsize_1)
plt.xticks(fontsize=fontsize_1)
plt.yticks([0, 1, 2], ["Swim", "Reorient", "Reverse"], fontsize=fontsize_1)
plt.legend(loc='upper right')
plt.grid(True)
plt.savefig(output_path + "state_ethogram_colored.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

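# --- Export to CSV ---
# Note: the speed and angular-velocity Series are shorter than the position arrays.
# pandas aligns them on their index, so the missing trailing rows are filled with NaN.
# The "velocity" column holds the smoothed speed (magnitude of the velocity).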
df_states = pd.DataFrame({
    "time_sec": time_axis_sec_with_0,
    "dav": dav,
    "x":smoothed_x,
    "y":smoothed_y,
    "velocity": speed_cell_smooth,
    "velocity_variance": velocity_variance,
    "ang_velocity":ang_velocity_cell_smooth,
    "state": states,
    "state_label": state_names
})

output_path2 = "W:\\Users\\Daphne\\WT_RESULTS\\WT_swimming\\States\\"

states_csv_path = output_path + npy_file[:-4] + sample_nr + "_states_timeseries.csv"
df_states.to_csv(states_csv_path, index=False)
print(f"CSV saved to: {states_csv_path}")




In [None]:
# To look only at a small segment to check the states:
# reorient 5.1-6.1 sec
# reverse 103.9-105.5 sec
# swim 159.6-163.2 sec
# new swim 161-162 sec
# new reverse 104-105 sec
start_time = 161  # in seconds
end_time = 162  # in seconds

# --- Filter the data for zoomed range ---
mask = (time_axis_sec_with_0 >= start_time) & (time_axis_sec_with_0 <= end_time)

time_zoom     = time_axis_sec_with_0[mask]
dav_zoom      = dav_array[mask]
dav_smooth_zoom = dav_smooth[mask] if 'dav_smooth' in locals() else dav_zoom
states_zoom   = states[mask]
translations_to_analyze_zoom = translations_to_analyze[mask]
speed_cell_smooth_zoom = speed_cell_smooth[mask[:-1]]

time_axis_sec_zoom = time_axis_sec[mask[:-1]]
speed_cell_zoom = speed_cell[mask[:-1]]

# --- Plot smoothed DAV ---
# plt.figure(figsize=(10, 2.5), dpi=300)
plt.figure(figsize=(4, 4), dpi=300)
plt.plot(time_zoom, dav_zoom, label='Raw DAV', alpha=0.3, color='gray')
plt.plot(time_zoom, dav_smooth_zoom, label='Smoothed DAV', color='black', linewidth=1.5)
plt.title("DAV Zoomed", fontsize=12)
plt.xlabel("Time (sec)", fontsize=12)
plt.ylim(0, 0.17) 
plt.ylabel("DAV", fontsize=12)
plt.legend(fontsize=10)
plt.grid(False)
plt.tight_layout()
plt.savefig(output_path + "dav_smooth_zoomed.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.savefig(output_path + "dav_smooth_zoomed.svg", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# speed plot zoomed
plt.figure(figsize = (4,4),dpi=300)
plt.plot(time_axis_sec_zoom,speed_cell_zoom, color='0.2', alpha=0.3)
plt.title("Speed in time", fontsize = fontsize_1)
plt.ylabel("v (mm /sec)", fontsize = fontsize_1)
plt.xlabel("time (sec)", fontsize = fontsize_1)
plt.ylim(0, 0.6)
plt.xticks(fontsize = fontsize_1)
plt.yticks(fontsize = fontsize_1)
plt.grid(False)
plt.savefig(output_path + "instantaneous_speed_in_time_zoom" + str(window_length) + "_" +  str(polyorder) + ".png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.savefig(output_path + "instantaneous_speed_in_time_zoom" + str(window_length) + "_" +  str(polyorder) + ".svg", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

# --- Plot State Ethogram ---
state_labels = ["Swim", "Reorient", "Reverse"]
state_colors = ["dodgerblue", "orange", "forestgreen"]

plt.figure(figsize=(10, 1.5), dpi=300)
for i, label in enumerate(state_labels):
    plt.plot(time_zoom[states_zoom == i], [i] * np.sum(states_zoom == i),
             '|', markersize=10, color=state_colors[i], label=label)

plt.yticks([0, 1, 2], state_labels, fontsize=10)
plt.xlabel("Time (sec)", fontsize=12)
plt.title("State Ethogram (Zoomed)", fontsize=12)
plt.grid(True, axis='x')
plt.legend(loc='upper right', fontsize=10)
plt.tight_layout()
plt.show()


# Plot the trajectory with states:
# Define the colormap: 0 = Swim (blue), 1 = Reorient (orange), 2 = Reverse (green)
three_state_cmap = ListedColormap([
    (31/255, 119/255, 180/255, 0.6),  # Swim (blue, semi-transparent)
    (255/255, 127/255, 14/255, 0.9),  # Reorient (orange)
    (44/255, 160/255, 44/255, 0.9)    # Reverse (green)
])

# Create the plot
plt.figure()
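# vmin/vmax pin the color mapping to the three state codes, so the colors stay
# correct even when the zoomed segment does not contain all three states
plt.scatter(translations_to_analyze_zoom[:, 0], translations_to_analyze_zoom[:, 1],
            c=states_zoom, cmap=three_state_cmap, s=2, vmin=0, vmax=2)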

plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')

# Legend handles with matching colors
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=(31/255, 119/255, 180/255, 0.6),
           markersize=5, label="Swim"),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=(255/255, 127/255, 14/255, 0.9),
           markersize=5, label="Reorient"),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=(44/255, 160/255, 44/255, 0.9),
           markersize=5, label="Reverse"),
]

plt.legend(handles=legend_elements, fontsize=12)
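# Use a separate file name so the full-track "traj_states.png" saved later does not overwrite this zoomed figure
plt.savefig(output_path + "traj_states_zoomed.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.savefig(output_path + "traj_states_zoomed.svg", dpi=300, bbox_inches='tight', pad_inches=0.1)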
plt.show()

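# Plot part of the trajectory, color-coded by speed
# (speed has one sample fewer than the positions, so the positions are masked with mask[:-1] as well)
positions_speed_zoom = translations_to_analyze[:-1][mask[:-1]]
plt.figure()
plt.scatter(positions_speed_zoom[:, 0], positions_speed_zoom[:, 1],
            c=speed_cell_smooth_zoom, cmap="viridis", s=10, vmin=0, vmax=0.55)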
plt.colorbar(label='v (mm/sec)')
plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
# plt.xlim(4.3, 4.5) 
# equal aspect ratio
plt.gca().set_aspect('equal', adjustable='box')
# plt.savefig(output_path + "traj_speed.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.savefig(output_path + "traj_speed_reverse.svg", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
# If your signal is length 8994 and states is 8995, trim states:
# plt.plot(time_axis_sec_with_0[:-1], states[:-1], color='black', linewidth=0.8)

Below we split the relevant metrics (DAV, variance of speed, angular velocity) by state and plot them over time to assess the chosen thresholds:
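
The per-state split relies on NaN masking: samples that belong to other states are replaced by `np.nan`, which matplotlib skips when drawing a line. A small sketch with made-up values (the arrays `signal` and `labels` are hypothetical):

```python
import numpy as np

# Hypothetical per-frame signal and state labels (0 = Swim, 1 = Reorient, 2 = Reverse)
signal = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
labels = np.array([0, 0, 1, 2, 0])

# One array per state; samples from other states become NaN
per_state = [np.where(labels == k, signal, np.nan) for k in range(3)]

# NaN-aware statistics summarize each state without extra masking
swim_mean = np.nanmean(per_state[0])
print(per_state[1], swim_mean)
```

Because every masked array keeps the original length, it can be plotted against the same time axis as the full signal.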

In [None]:
state_colors = ["dodgerblue", "orange", "forestgreen"]

def split_by_state(signal, states, n_states=3):
    """
    Return one array per state (0 = Swim, 1 = Reorient, 2 = Reverse).
    In each array, samples that belong to other states are set to np.nan,
    so that plotting draws only the segments of that state.
    """
    return [np.where(states == i, signal, np.nan) for i in range(n_states)]


dav_by_state = split_by_state(dav_array, states)

plt.figure(figsize=(6, 1), dpi=300)
plt.plot(time_axis_sec_with_0, dav_array, alpha=0.3)
for i, label in enumerate(state_labels):
    plt.plot(time_axis_sec_with_0, dav_by_state[i], color=state_colors[i], linewidth=2, label=label)

plt.title("Deviation Angle Variance (DAV)", fontsize=fontsize_1)
plt.ylabel("DAV", fontsize=fontsize_1)
plt.xlabel("time (sec)", fontsize=fontsize_1)
plt.xticks(fontsize=fontsize_1)
plt.yticks(fontsize=fontsize_1)
plt.grid(True)
plt.legend(fontsize=fontsize_1)
plt.savefig(output_path + "dav_states.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()


vs_by_state = split_by_state(velocity_variance, states)

plt.figure(figsize=(6, 1), dpi=300)
plt.plot(time_axis_sec_with_0, velocity_variance, alpha=0.3)
for i, label in enumerate(state_labels):
    plt.plot(time_axis_sec_with_0, vs_by_state[i], color=state_colors[i], linewidth=2, label=label)

plt.title("Variance of speed", fontsize=fontsize_1)
plt.ylabel("VS", fontsize=fontsize_1)
plt.xlabel("time (sec)", fontsize=fontsize_1)
plt.xticks(fontsize=fontsize_1)
plt.yticks(fontsize=fontsize_1)
plt.grid(True)
plt.legend(fontsize=fontsize_1)
plt.savefig(output_path + "var_speed_states.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()


angvel_by_state = split_by_state(ang_velocity_cell_smooth, states[1:-1])

plt.figure(figsize=(6, 1), dpi=300)
plt.plot(time_axis_sec[1:], ang_velocity_cell_smooth, alpha=0.3)
for i, label in enumerate(state_labels):
    plt.plot(time_axis_sec[1:], angvel_by_state[i], color=state_colors[i], linewidth=2, label=label)

plt.title("Angular velocity in time", fontsize=fontsize_1)
plt.ylabel("\u03A9 (rad/sec)", fontsize=fontsize_1)
plt.xlabel("time (sec)", fontsize=fontsize_1)
plt.xticks(fontsize=fontsize_1)
plt.yticks(fontsize=fontsize_1)
plt.grid(True)
plt.legend(fontsize=fontsize_1)
plt.savefig(output_path + "ang_vel_states.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()

In [None]:
# Define the colormap: 0 = Swim (blue), 1 = Reorient (orange), 2 = Reverse (green)
three_state_cmap = ListedColormap([
    (31/255, 119/255, 180/255, 0.6),  # Swim (blue, semi-transparent)
    (255/255, 127/255, 14/255, 0.9),  # Reorient (orange)
    (44/255, 160/255, 44/255, 0.9)    # Reverse (green)
])

# Create the plot
plt.figure()
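# vmin/vmax pin the color mapping to the three state codes, so the colors stay
# correct even when a track does not contain all three states
plt.scatter(translations_to_analyze[:, 0], translations_to_analyze[:, 1],
            c=states, cmap=three_state_cmap, s=10, vmin=0, vmax=2)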

plt.xlabel('X (mm)')
plt.ylabel('Y (mm)')
plt.title('Smoothed Trajectory')
plt.grid(False)
plt.gca().set_aspect('equal', adjustable='box')

# Legend handles with matching colors
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=(31/255, 119/255, 180/255, 0.6),
           markersize=5, label="Swim"),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=(255/255, 127/255, 14/255, 0.9),
           markersize=5, label="Reorient"),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=(44/255, 160/255, 44/255, 0.9),
           markersize=5, label="Reverse"),
]

plt.legend(handles=legend_elements, fontsize=12)
plt.savefig(output_path + "traj_states.png", dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.show()
