In [1]:
# Cell 1: Imports and Setup
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import fly_analysis as fa
from scipy.signal import savgol_filter
import multiprocessing as mp
from tqdm.notebook import tqdm
from pycircstat2.descriptive import circ_median


In [2]:
# Cell 2: Define the get_and_plot_data function
# Smoothing / trial-geometry constants used by the helpers below.
_SAVGOL_WINDOW = 21          # Savitzky-Golay window length (frames)
_SAVGOL_ORDER = 3            # Savitzky-Golay polynomial order
_MIN_TRACK_FRAMES = 150      # tracks shorter than this are skipped before any smoothing
_HEADING_BEFORE_FRAMES = 10  # "heading before" = displacement over the frames leading up to the stimulus
_HEADING_AFTER_FRAMES = 50   # "heading after" = displacement over the frames following the stimulus


def get_and_plot_data(
    data_row,
    root_folder='/home/buchsbaum/mnt/md0/Experiments/',
    out_dir='/home/buchsbaum/src/fly_analysis/notebooks/Figures/all_flies',
    pre_frames=50,
    post_frames=100,
    dt=0.01,
):
    """Load one fly's recordings, summarise its stimulus responses and save a figure.

    Parameters
    ----------
    data_row : pandas.Series or mapping
        Row of the experiment list; needs a comma-separated ``'filename'`` entry and
        a ``'fly_code'`` entry.
    root_folder : str
        Folder the recording file names are resolved against.
    out_dir : str
        Folder the ``<fly_code>_velocity.png`` figure is written to.
    pre_frames, post_frames : int
        Frames kept before / after the stimulus frame in each trial window.
    dt : float
        Frame interval used for the angular-velocity derivative and the time axis
        (the time axis is labelled in ms, so presumably seconds).

    Returns
    -------
    fly_code or None
        ``fly_code`` when a figure was saved; ``None`` when the files could not be
        read, there was no opto/stim table, or no trial had a usable window.
    """
    import os

    files = data_row['filename'].split(',')
    fly_code = data_row['fly_code']

    try:
        data = fa.braidz.read_multiple_braidz(files, root_folder=root_folder)
    except Exception as exc:
        # Best effort: one unreadable fly must not abort the batch, but say why, and let
        # KeyboardInterrupt / SystemExit propagate instead of being swallowed.
        print(f"Skipping {fly_code}: could not read {files} ({exc!r})")
        return None

    # Prefer the opto table; fall back to the generic stimulus table.
    if len(data["opto"]) > 0:
        opto = data["opto"]
    elif len(data["stim"]) > 0:
        opto = data["stim"]
    else:
        return None

    angvel, linvel, heading_difference = _extract_trial_windows(
        data["df"], opto, pre_frames, post_frames, dt
    )
    if len(angvel) == 0:
        return None

    test_result = _describe_heading_test(heading_difference)

    # Time relative to stimulus onset in ms, one sample per frame (derived from the window
    # so it can never disagree with the trace length).
    time_ms = np.arange(-pre_frames, post_frames) * (dt * 1000)
    out_path = os.path.join(out_dir, f'{fly_code}_velocity.png')
    _plot_fly_summary(fly_code, angvel, linvel, heading_difference, test_result, time_ms, out_path)

    return fly_code  # Return fly_code to confirm successful processing


def _extract_trial_windows(df, opto, pre_frames, post_frames, dt):
    """Cut stimulus-aligned velocity windows and heading changes out of the trajectories.

    For each row of ``opto`` the trajectory with the same ``obj_id`` and ``exp_num`` is
    smoothed and the window ``[stim_idx - pre_frames, stim_idx + post_frames)`` kept.
    Trials are skipped when the track is short, does not contain the stimulus frame,
    or the window / heading samples would run off either end of the track.

    Returns
    -------
    (angvel, linvel, heading_difference) : tuple of numpy arrays
        ``angvel`` (degrees per unit of ``dt``) and ``linvel`` are
        ``(n_trials, pre_frames + post_frames)``; ``heading_difference`` is
        ``(n_trials,)`` as returned by ``fa.helpers.angdiff``. All empty if no trial qualifies.
    """
    angvel, linvel, heading_difference = [], [], []

    # Frames required on each side of the stimulus: the kept window plus the positions
    # sampled by the before/after heading estimates.
    need_before = max(pre_frames, _HEADING_BEFORE_FRAMES)
    need_after = max(post_frames, _HEADING_AFTER_FRAMES + 1)

    for _, stim in opto.iterrows():
        grp = df[(df['obj_id'] == stim['obj_id']) & (df['exp_num'] == stim['exp_num'])]
        if len(grp) < _MIN_TRACK_FRAMES:
            continue

        hits = np.where(grp.frame.values == stim['frame'])[0]
        if hits.size == 0:
            continue  # stimulus frame is not part of this track
        opto_idx = int(hits[0])

        # Exclusive upper bound: the slice needs opto_idx + post_frames <= len(grp).
        if opto_idx - need_before < 0 or opto_idx + need_after > len(grp):
            continue

        x = savgol_filter(grp.x.values, _SAVGOL_WINDOW, _SAVGOL_ORDER)
        y = savgol_filter(grp.y.values, _SAVGOL_WINDOW, _SAVGOL_ORDER)
        xvel = savgol_filter(grp.xvel.values, _SAVGOL_WINDOW, _SAVGOL_ORDER)
        yvel = savgol_filter(grp.yvel.values, _SAVGOL_WINDOW, _SAVGOL_ORDER)

        window = slice(opto_idx - pre_frames, opto_idx + post_frames)

        # Heading of the velocity vector, unwrapped so the derivative has no 2*pi jumps.
        theta = np.unwrap(np.arctan2(yvel, xvel))
        angvel.append(np.rad2deg(np.gradient(theta, dt))[window])
        linvel.append(np.sqrt(xvel**2 + yvel**2)[window])

        before, after = _HEADING_BEFORE_FRAMES, _HEADING_AFTER_FRAMES
        heading_before = np.arctan2(y[opto_idx] - y[opto_idx - before],
                                    x[opto_idx] - x[opto_idx - before])
        heading_after = np.arctan2(y[opto_idx + after] - y[opto_idx],
                                   x[opto_idx + after] - x[opto_idx])
        heading_difference.append(fa.helpers.angdiff(heading_before, heading_after))

    return np.array(angvel), np.array(linvel), np.array(heading_difference)


def _cohens_d(values):
    """One-sample Cohen's d against zero: mean divided by the sample (ddof=1) std."""
    return np.mean(values) / np.std(values, ddof=1)


def _describe_heading_test(heading_difference):
    """Test whether heading changes are centred on 0 and return a two-line summary string.

    n < 5: circular mean/median/std only. 5 <= n < 20: Wilcoxon signed-rank.
    n >= 20: one-sample t-test when Shapiro-Wilk p > 0.05, otherwise Wilcoxon.
    Test branches also report Cohen's d.
    """
    sample_size = len(heading_difference)

    if sample_size < 5:
        mean = stats.circmean(heading_difference, low=-np.pi, high=np.pi, nan_policy='omit')
        median = circ_median(heading_difference)
        std = stats.circstd(heading_difference, low=-np.pi, high=np.pi, nan_policy='omit')
        return f"Sample size too small (n={sample_size})\nMean: {mean:.4f}, Median: {median:.4f}, Std: {std:.4f}"

    # Only large samples are checked for normality; small ones go straight to Wilcoxon.
    use_ttest = sample_size >= 20 and stats.shapiro(heading_difference)[1] > 0.05
    if use_ttest:
        p_value = stats.ttest_1samp(heading_difference, 0, nan_policy='omit').pvalue
        effect_size = _cohens_d(heading_difference)
        return f"One-sample t-test: p={p_value:.4f}\nEffect size (Cohen's d): {effect_size:.4f}"

    p_value = stats.wilcoxon(heading_difference, nan_policy='omit').pvalue
    effect_size = _cohens_d(heading_difference)
    return f"Wilcoxon signed-rank test: p={p_value:.4f}\nEffect size (Cohen's d): {effect_size:.4f}"


def _plot_fly_summary(fly_code, angvel, linvel, heading_difference, test_result, time_ms, out_path):
    """Save a 4-panel figure to ``out_path``.

    Panels: mean +/- SD of angular velocity, |angular velocity| and linear velocity
    over ``time_ms``, plus a heading-change histogram with the 0 and circular-mean
    reference lines.
    """
    fig, axs = plt.subplots(ncols=4, nrows=1, figsize=(16, 4))
    try:
        trace_panels = (
            (angvel, 'Angular velocity'),
            (np.abs(angvel), 'abs Angular velocity'),
            (linvel, 'Linear velocity'),
        )
        for ax, (traces, title) in zip(axs, trace_panels):
            trace_mean = np.mean(traces, axis=0)
            trace_std = np.std(traces, axis=0)
            ax.plot(time_ms, trace_mean)
            ax.fill_between(time_ms, trace_mean - trace_std, trace_mean + trace_std, alpha=0.5)
            ax.set_xlabel('Time (ms)')
            ax.set_title(title)

        hist_ax = axs[3]
        sample_size = len(heading_difference)
        hist_ax.hist(heading_difference, bins=min(36, sample_size), density=True)
        hist_ax.set_xlabel('Heading difference (rad)')
        hist_ax.set_xlim(-np.pi, np.pi)

        mean_heading = stats.circmean(heading_difference, low=-np.pi, high=np.pi, nan_policy='omit')
        # Label the reference lines explicitly; positional legend labels depend on artist order.
        hist_ax.axvline(0, color='black', linestyle='--', label='0')
        hist_ax.axvline(mean_heading, color='red', linestyle='--', label='mean')
        hist_ax.set_title(f'Heading difference\n{test_result}')
        hist_ax.legend()

        fig.suptitle(f'{fly_code} n={sample_size}')
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Release this specific figure even if saving fails; many figures are made per batch.
        plt.close(fig)

In [3]:
# Cell 3: Define the parallel processing function
def process_data_parallel(df, max_workers=4):
    """Run ``get_and_plot_data`` on every row of ``df`` in a multiprocessing pool.

    Parameters
    ----------
    df : pandas.DataFrame
        Experiment-list rows; each row is passed to ``get_and_plot_data`` as-is.
    max_workers : int, default 4
        Upper bound on worker processes; never more than ``mp.cpu_count()``, never fewer than 1.

    Returns
    -------
    list
        The non-``None`` results (``fly_code`` values), in input-row order (``imap``
        preserves order).

    NOTE(review): workers must find ``get_and_plot_data``, which is defined in the
    notebook's ``__main__``; this presumably relies on the ``fork`` start method
    (Linux default) — confirm before running under ``spawn``/``forkserver``.
    """
    num_workers = max(1, min(max_workers, mp.cpu_count()))
    rows = [row for _, row in df.iterrows()]

    # The context manager terminates the pool if anything escapes the block (e.g. a
    # KeyboardInterrupt while waiting on results), so interrupted runs don't leave
    # orphaned worker processes behind.
    with mp.Pool(num_workers) as pool:
        results = list(tqdm(pool.imap(get_and_plot_data, rows), total=len(rows)))
        # Normal path: let workers finish and exit cleanly before the context exits.
        pool.close()
        pool.join()

    # None marks rows that failed or had nothing to plot.
    return [r for r in results if r is not None]

In [4]:
# Cell 4: Load the experiment list and split it by fly line
# Read the experiment spreadsheet
excel_file = "/home/buchsbaum/experiment_list.ods"
df = pd.read_excel(excel_file)

# Select rows by fly_code substring.
# na=False: rows with an empty fly_code are excluded instead of putting NaN into the
# boolean mask (which would make the indexing raise). regex=False: match codes literally.
axflp = df[df['fly_code'].str.contains('axflp', na=False, regex=False)]
axsparc = df[df['fly_code'].str.contains('AXSPARC', na=False, regex=False)]

In [5]:
# Cell 5: Run the per-fly pipeline over the AXFLP group and report the success count
print("Processing AXFLP data...")
axflp_results = process_data_parallel(axflp)
print("Processed {} AXFLP flies successfully".format(len(axflp_results)))

Processing AXFLP data...
Reading /home/buchsbaum/mnt/md0/Experiments/20220126_111954.braidz using pyarrow

  0%|          | 0/64 [00:00<?, ?it/s]

Reading /home/buchsbaum/mnt/md0/Experiments/20220127_115455.braidz using pyarrow

Reading /home/buchsbaum/mnt/md0/Experiments/20220131_111636.braidz using pyarrowReading /home/buchsbaum/mnt/md0/Experiments/20220128_121657.braidz using pyarrow

Reading /home/buchsbaum/mnt/md0/Experiments/20220131_174134.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20220201_153340.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20220202_114157.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20220201_173551.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20220202_173839.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20220207_124256.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20220208_154942.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20220208_173007.braidz using pyarrow
Reading /home/buchsbaum/mnt/md0/Experiments/20220202_150231.braidz using pyarrow
Reading /home/buchsbaum/mnt

KeyboardInterrupt: 

In [None]:

# Cell 6: Run the per-fly pipeline over the AXSPARC group and report the success count
print("Processing AXSPARC data...")
axsparc_results = process_data_parallel(axsparc)
print("Processed {} AXSPARC flies successfully".format(len(axsparc_results)))

print("All processing complete!")

In [None]:
# Cell 7: How many flies in each group produced a figure, out of how many rows
print("Summary of processed flies:")
print("AXFLP: {} out of {} processed successfully".format(len(axflp_results), len(axflp)))
print("AXSPARC: {} out of {} processed successfully".format(len(axsparc_results), len(axsparc)))

In [None]:

# Optional: show which fly codes made it through, one group per heading
print("\nSuccessfully processed AXFLP fly codes:", axflp_results, sep="\n")
print("\nSuccessfully processed AXSPARC fly codes:", axsparc_results, sep="\n")