In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.ndimage
from matplotlib.colors import Normalize

from src2 import (
    feature_extract as feat
)

# Read and clean data

In [None]:
## Read the parquet files with the raw data for each individual genotype.
## If a parquet file contains rows for any genotype other than the one it was
## pre-processed for, keep only the rows of the correct genotype.

# NOTE(review): absolute, machine-specific location; adjust when running elsewhere.
PREDICTIONS_DIR = Path(r'C:\2_P9Project_allData\1_DataStructure\ball_predictions')


def _load_genotype(genotype):
    """Read ``PREDICTIONS_DIR/<genotype>/df_ballpredictn.parquet`` and keep only that genotype.

    Parameters
    ----------
    genotype : str
        Sub-folder name and value expected in the ``genotype`` column.

    Returns
    -------
    pandas.DataFrame
        Rows of the parquet file whose ``genotype`` column equals ``genotype``.
    """
    # `engine` is an argument of pd.read_parquet, not of Path. Passing it to Path
    # (as the previous version did) is silently ignored on older Python versions
    # and rejected by newer ones.
    table = pd.read_parquet(
        PREDICTIONS_DIR / genotype / 'df_ballpredictn.parquet',
        engine='pyarrow',
    )
    return table.loc[table['genotype'] == genotype]


BPN = _load_genotype('BPN')
P9LT = _load_genotype('P9LT')
P9RT = _load_genotype('P9RT')

In [None]:
## Drop badly tracked flies (identified by flynum) from each genotype table.
P9LT_remove = [6, 8]
P9RT_remove = [6, 11, 13]
BPN_remove = [10]

# Keep only the rows whose flynum is NOT in the corresponding removal list.
P9LT = P9LT.loc[~P9LT['flynum'].isin(P9LT_remove)]
P9RT = P9RT.loc[~P9RT['flynum'].isin(P9RT_remove)]
BPN = BPN.loc[~BPN['flynum'].isin(BPN_remove)]

# Generate Tables

In [None]:
## Bin size is defined in frames; 1 frame = 5 ms.
## Choose the bin size so that at least two full step cycles fall inside each window.
## If the window is too small, some bins will not contain a single completed step,
## which results in NaN values for the step parameters.
bin_size = 90
# Run feat.smoothed_table on the P9LT data with this bin size and sigma = 5;
# the `_90` suffix of the result records the bin size used.
P9LT_90 = feat.smoothed_table(P9LT, bin_size, sigma = 5)

In [None]:
# Deliberately raise so that execution stops at this cell.
# NOTE(review): presumably meant to keep "Run All" from reaching the
# development-only troubleshooting cells that follow — confirm.
raise TypeError("End of Section")

# Troubleshoot (only for development)

In [None]:
# Small subsets used by the troubleshooting cells:
# flynum 1 / tnum 2 for P9LT and flynum 1 / tnum 1 for P9RT.
is_p9lt_test = (P9LT['flynum'] == 1) & (P9LT['tnum'] == 2)
P9LT_test = P9LT.loc[is_p9lt_test]

is_p9rt_test = (P9RT['flynum'] == 1) & (P9RT['tnum'] == 1)
P9RT_test = P9RT.loc[is_p9rt_test]

In [None]:
# Plot four tables one below the other: rows 445-845 of the raw P9LT table and
# rows 355-755 of three other tables.
# NOTE(review): P9LT_1, P9LT_5 and P9LT_10 are not defined in any visible cell, and
# plot_trajectory is only defined in a later cell, so this cell fails on a fresh kernel.
plot_trajectory([P9LT.iloc[45+400:45+800, :].reset_index(drop=True), 
                 P9LT_1.iloc[400-45:800-45].reset_index(drop=True),
                 P9LT_5.iloc[400-45:800-45].reset_index(drop=True),
                 P9LT_10.iloc[400-45:800-45].reset_index(drop=True)])

In [None]:
def _integrate_trajectory(df, rotate_first=True):
    """Integrate the per-row velocities of one table into a 2-D path.

    Uses the ``x_vel``/``y_vel``/``z_vel`` columns when all three are present,
    otherwise ``mean_x_vel``/``mean_y_vel``/``mean_z_vel``. The third column is
    accumulated into the heading ``d``; the first two are rotated by the heading
    and accumulated into ``x`` and ``y``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding one of the two velocity column sets.
    rotate_first : bool
        If True the heading is updated before the step of the same row is
        applied, otherwise after it.

    Returns
    -------
    pandas.DataFrame
        Columns ``x``, ``y`` and ``d`` with one row per input row.

    Raises
    ------
    KeyError
        If neither velocity column set is present.
    """
    raw_cols = ['x_vel', 'y_vel', 'z_vel']
    mean_cols = ['mean_x_vel', 'mean_y_vel', 'mean_z_vel']
    if set(raw_cols).issubset(df.columns):
        cols = raw_cols
    elif set(mean_cols).issubset(df.columns):
        cols = mean_cols
    else:
        raise KeyError(f"expected columns {raw_cols} or {mean_cols}")

    vel = df.loc[:, cols].to_numpy(dtype=float)
    vx, vy, vd = vel[:, 0], vel[:, 1], vel[:, 2]

    # Heading after each row has been applied (this is what is reported as `d`).
    heading = np.cumsum(vd)
    if rotate_first:
        step_heading = heading
    else:
        # Use the heading *before* the current row's rotation: shift by one row.
        step_heading = np.concatenate(([0.0], heading))[:len(heading)]

    x = np.cumsum(np.cos(step_heading) * vx + np.cos(step_heading - np.pi/2) * -vy)
    y = np.cumsum(np.sin(step_heading) * vx + np.sin(step_heading - np.pi/2) * -vy)
    return pd.DataFrame({'x': x, 'y': y, 'd': heading})


def plot_trajectory(df_list, arrows=True, rotate_first=True, cmap='spring', scatter_c = 'k'):
    """Plot the integrated trajectory of one or several velocity tables.

    Each table gets its own subplot (stacked vertically, shared axes).

    Parameters
    ----------
    df_list : pandas.DataFrame or sequence of pandas.DataFrame
        Table(s) with ``x_vel``/``y_vel``/``z_vel`` or the ``mean_*`` equivalents.
    arrows : bool
        If True draw heading arrows coloured by row position (``quiver``),
        otherwise draw a scatter of the positions.
    rotate_first : bool
        Passed on to the integration: update heading before (True) or after
        (False) applying each row's translation.
    cmap : str
        Colormap for the arrows.
    scatter_c : str
        Point colour used when ``arrows`` is False.

    Returns
    -------
    pandas.DataFrame or list of pandas.DataFrame
        The ``x``/``y``/``d`` table for a single DataFrame input, or one such
        table per entry for a sequence input (empty list for an empty sequence).
    """
    # Callers in this notebook pass both a bare DataFrame and a list of them.
    single = isinstance(df_list, pd.DataFrame)
    tables = [df_list] if single else list(df_list)
    if not tables:
        return []

    # squeeze=False keeps `ax` two-dimensional even for a single subplot.
    fig, ax = plt.subplots(len(tables), 1, sharex = True, sharey = True,
                           figsize=(10, len(tables) * 5), squeeze=False)

    all_coords = []
    for panel, df in zip(ax[:, 0], tables):
        coords = _integrate_trajectory(df, rotate_first)
        x = coords['x'].to_numpy()
        y = coords['y'].to_numpy()
        d = coords['d'].to_numpy()

        if arrows:
            # Colour encodes the row position along the trajectory.
            c = np.arange(0, len(coords), 1)
            vmax = c.max() if c.size else 0
            norm = Normalize(vmin=0, vmax=vmax)
            panel.quiver(x, y, np.cos(d), np.sin(d), c, cmap=cmap, norm=norm)
        else:
            panel.scatter(x, y, s=1, color = scatter_c)
        panel.set_xlabel('x')
        panel.set_ylabel('y')
        all_coords.append(coords)

    fig.tight_layout()
    return all_coords[0] if single else all_coords


In [None]:
# Trajectory of rows 400-600 of the P9LT test subset; `raw` is indexed by 'x', 'y'
# and 'd' in the comparison cells further down.
# NOTE(review): this relies on plot_trajectory accepting a single DataFrame and
# returning the x/y/d table — confirm against the function definition.
raw = plot_trajectory(P9LT_test[400:600])

In [None]:
# Gaussian-smooth (sigma = 5) the three velocity components of rows 400-600 of the
# P9LT test subset and plot the trajectory of the smoothed velocities.
segment = P9LT_test.iloc[400:600]
temp_x = scipy.ndimage.gaussian_filter(segment['x_vel'], sigma = 5)
temp_y = scipy.ndimage.gaussian_filter(segment['y_vel'], sigma = 5)
temp_z = scipy.ndimage.gaussian_filter(segment['z_vel'], sigma = 5)
# The z_vel column must come from temp_z (it was previously filled with temp_x).
test_arr = pd.DataFrame({'x_vel': temp_x, 'y_vel': temp_y, 'z_vel': temp_z})
smoothed_5= plot_trajectory(test_arr)

In [None]:
# Overlay the 'x' column of the raw and the sigma = 5 smoothed trajectory tables.
plt.plot(raw['x'])
plt.plot(smoothed_5['x'])

In [None]:
# Overlay the 'd' column of the raw and the sigma = 5 smoothed trajectory tables.
plt.plot(raw['d'])
plt.plot(smoothed_5['d'])

In [None]:
# Overlay the 'y' column of the raw and the sigma = 5 smoothed trajectory tables.
plt.plot(raw['y'])
plt.plot(smoothed_5['y'])

In [None]:
# Compare the unsmoothed z_vel trace of the P9LT test subset with Gaussian-filtered
# versions for several sigma values (only the filtered traces are labelled).
s_list = [1,5,10,15,20]
z_vel_trace = P9LT_test['z_vel']
plt.plot(z_vel_trace.reset_index(drop=True))
for s in s_list:
    filtered = scipy.ndimage.gaussian_filter(z_vel_trace, sigma = s)
    plt.plot(filtered, label = s)
plt.legend()

In [None]:
# Gaussian-smooth (sigma = 5) the three velocity components of the whole P9LT test
# subset and plot the trajectory of the smoothed velocities.
temp_x = scipy.ndimage.gaussian_filter(P9LT_test['x_vel'], sigma = 5)
temp_y = scipy.ndimage.gaussian_filter(P9LT_test['y_vel'], sigma = 5)
temp_z = scipy.ndimage.gaussian_filter(P9LT_test['z_vel'], sigma = 5)
# The z_vel column must come from temp_z (it was previously filled with temp_x).
test_arr = pd.DataFrame({'x_vel': temp_x, 'y_vel': temp_y, 'z_vel': temp_z})
plot_trajectory(test_arr)

In [None]:
# Gaussian-smooth (sigma = 0.5) the three velocity components of the whole P9LT test
# subset and plot the trajectory of the smoothed velocities.
temp_x = scipy.ndimage.gaussian_filter(P9LT_test['x_vel'], sigma = 0.5)
temp_y = scipy.ndimage.gaussian_filter(P9LT_test['y_vel'], sigma = 0.5)
temp_z = scipy.ndimage.gaussian_filter(P9LT_test['z_vel'], sigma = 0.5)
# The z_vel column must come from temp_z (it was previously filled with temp_x).
test_arr = pd.DataFrame({'x_vel': temp_x, 'y_vel': temp_y, 'z_vel': temp_z})
plot_trajectory(test_arr)

In [None]:
# Trajectory of the whole P9RT test subset.
# NOTE(review): passes a bare DataFrame rather than a list — confirm plot_trajectory supports this.
plot_trajectory(P9RT_test)

In [None]:
# Overlay three columns of the P9RT test subset for rows 400-1000.
# NOTE(review): the factors 160 and 15 presumably rescale the traces so they share
# a comparable range on one axis — confirm.
plt.plot(P9RT_test['R2C_flex'][400:1000])
plt.plot(P9RT_test['R-M_stepcycle'][400:1000]* 160) 
plt.plot(P9RT_test['z_vel'][400:1000] * 15)

In [None]:
# Cut one window of `win` rows out of the P9RT test subset, starting `idx` rows
# after row 400, and plot two of its columns.
idx = 87
win = 90
window_start = 400 + idx
data_stim = P9RT_test.iloc[window_start:window_start + win, :]
plt.plot(data_stim['R2C_flex'])
# NOTE(review): factor 160 presumably rescales the step-cycle trace for display — confirm.
plt.plot(data_stim['R-M_stepcycle'] * 160)

In [None]:
# Index label of the first row of the window; passed to feat.get_TD_LO in the cells below.
idx_ts = data_stim.index[0]
idx_ts

In [None]:
# Inspect the output of feat.get_TD_LO for leg "R2" on the current window.
temp = feat.get_TD_LO(data_stim, "R2", idx_ts)
temp

In [None]:
# Inspect the result of feeding the get_TD_LO output for leg "R2" through feat.align_arr.
feat.align_arr(feat.get_TD_LO(data_stim, "R2", idx_ts))

In [None]:
# Same pipeline as the previous cells, one step at a time:
# get_TD_LO -> align_arr -> get_temp_params_2, all for leg "R2".
r2_td_lo = feat.get_TD_LO(data_stim, "R2", idx_ts)
r2_aligned = feat.align_arr(r2_td_lo)
temp_param = feat.get_temp_params_2(r2_aligned)
temp_param

In [None]:
# Feed the get_temp_params_2 result into feat.get_stance_dist2.
# NOTE(review): the first argument is computed for leg "R2" on data_stim (a window of
# P9RT_test), while get_stance_dist2 is called with P9LT_test and "L1" — confirm that
# this mix of legs and genotypes is intended.
r2_events = feat.get_TD_LO(data_stim, "R2", idx_ts)
r2_temporal = feat.get_temp_params_2(feat.align_arr(r2_events))
L1_stepdata = feat.get_stance_dist2(r2_temporal, P9LT_test, "L1")
L1_stepdata

In [None]:
# Output of feat.get_ang_params for the current window.
# NOTE(review): the variable is named L1_* but the call requests leg "R2" — confirm.
L1_angdata = feat.get_ang_params(data_stim, "R2")
L1_angdata

In [None]:
# Put the step table and the angle table side by side and take the column-wise mean.
window_tables = [L1_stepdata, L1_angdata]
combined_window = pd.concat(window_tables, axis=1)
temp_win_data = combined_window.mean()

In [None]:
# List the columns of mean_df.
# NOTE(review): mean_df is not defined in any visible cell — this fails on a fresh kernel.
mean_df.columns

In [None]:
## find indices
temp_df = pd.DataFrame(mean_df['L1_step_period'].isna())
temp_df[temp_df['L1_step_period'] == True].index.tolist()

In [None]:
# Display the x_vel column of the current window.
data_stim['x_vel']

In [None]:
# Overlay the window's x_vel (index reset to 0..n-1) with its Gaussian-filtered (sigma = 5) version.
plt.plot(data_stim['x_vel'].reset_index(drop=True))
plt.plot(scipy.ndimage.gaussian_filter(data_stim['x_vel'], sigma = 5))

In [None]:
# Display the column-wise means computed for the current window.
temp_win_data

### Generate Tables

In [None]:
## Bin size is defined in frames; 1 frame = 5 ms.
## Choose the bin size so that at least two full step cycles fall inside each window.
## If the window is too small, some bins will not contain a single completed step,
## which results in NaN values for the step parameters.
bin_size = 90
# NOTE(review): the result is named `_60` although bin_size is 90 here; P9LT_test_60 is
# assigned again with bin size 60 in a later cell.
P9LT_test_60 = feat.smoothed_table(P9LT_test, bin_size)


In [None]:
# NOTE(review): uses whatever `bin_size` currently holds (90 if the previous cell was run)
# although the result is named `_60`; P9RT_test_60 is assigned again with bin size 60 later.
P9RT_test_60 = feat.smoothed_table(P9RT_test, bin_size)

In [None]:
# smoothed_table output for the P9LT test subset with bin size 5.
P9LT_test_5 = feat.smoothed_table(P9LT_test, 5)

In [None]:
# smoothed_table output for the P9LT test subset with bin sizes 10, 20, ..., 100
# (one variable per bin size, computed in ascending order).
(
    P9LT_test_10, P9LT_test_20, P9LT_test_30, P9LT_test_40, P9LT_test_50,
    P9LT_test_60, P9LT_test_70, P9LT_test_80, P9LT_test_90, P9LT_test_100,
) = [feat.smoothed_table(P9LT_test, size) for size in range(10, 101, 10)]

In [None]:
# smoothed_table output for the P9LT test subset with the larger bin sizes.
(
    P9LT_test_120, P9LT_test_150, P9LT_test_170, P9LT_test_200,
) = [feat.smoothed_table(P9LT_test, size) for size in (120, 150, 170, 200)]

In [None]:
# smoothed_table output for the P9RT test subset with bin sizes 10-100 (step 10)
# followed by 120, 150, 170 and 200 (one variable per bin size, in that order).
(
    P9RT_test_10, P9RT_test_20, P9RT_test_30, P9RT_test_40, P9RT_test_50,
    P9RT_test_60, P9RT_test_70, P9RT_test_80, P9RT_test_90, P9RT_test_100,
    P9RT_test_120, P9RT_test_150, P9RT_test_170, P9RT_test_200,
) = [
    feat.smoothed_table(P9RT_test, size)
    for size in (*range(10, 101, 10), 120, 150, 170, 200)
]

In [None]:
# smoothed_table output for the P9RT test subset with bin size 5.
P9RT_test_5 = feat.smoothed_table(P9RT_test, 5)

In [None]:
# Overlay mean_z_vel of the P9LT test subset for every bin size computed above.
traj_list = [
    P9LT_test_5, P9LT_test_10, P9LT_test_20, P9LT_test_30, P9LT_test_40,
    P9LT_test_50, P9LT_test_60, P9LT_test_70, P9LT_test_80, P9LT_test_90,
    P9LT_test_100, P9LT_test_120, P9LT_test_150, P9LT_test_170, P9LT_test_200,
]

name_list = [
    'P9LT_test_5', 'P9LT_test_10', 'P9LT_test_20', 'P9LT_test_30', 'P9LT_test_40',
    'P9LT_test_50', 'P9LT_test_60', 'P9LT_test_70', 'P9LT_test_80', 'P9LT_test_90',
    'P9LT_test_100', 'P9LT_test_120', 'P9LT_test_150', 'P9LT_test_170', 'P9LT_test_200',
]
# plot_trajectory(traj_list)
plt.figure(figsize=(15,6))
for idx, (traj, name) in enumerate(zip(traj_list, name_list)):
    # The first eight traces keep the default line style, the remaining ones are dashed.
    line_style = {} if idx < 8 else {'ls': '--'}
    plt.plot(traj['mean_z_vel'], label = name, **line_style)

plt.legend()

In [None]:
# Overlay mean_z_vel of the P9RT test subset for every bin size computed above.
traj_list = [
    P9RT_test_5, P9RT_test_10, P9RT_test_20, P9RT_test_30, P9RT_test_40,
    P9RT_test_50, P9RT_test_60, P9RT_test_70, P9RT_test_80, P9RT_test_90,
    P9RT_test_100, P9RT_test_120, P9RT_test_150, P9RT_test_170, P9RT_test_200,
]

name_list = [
    'P9RT_test_5', 'P9RT_test_10', 'P9RT_test_20', 'P9RT_test_30', 'P9RT_test_40',
    'P9RT_test_50', 'P9RT_test_60', 'P9RT_test_70', 'P9RT_test_80', 'P9RT_test_90',
    'P9RT_test_100', 'P9RT_test_120', 'P9RT_test_150', 'P9RT_test_170', 'P9RT_test_200',
]
# plot_trajectory(traj_list)
plt.figure(figsize=(15,6))
for idx, (traj, name) in enumerate(zip(traj_list, name_list)):
    # The first eight traces keep the default line style, the remaining ones are dashed.
    line_style = {} if idx < 8 else {'ls': '--'}
    plt.plot(traj['mean_z_vel'], label = name, **line_style)

plt.legend()

In [None]:
# Display the P9RT test table computed with bin size 60.
P9RT_test_60

In [None]:
# For each leg, print the number of NaN entries in P9LT_90 for
# step period, swing duration, stance duration and normalised stance distance.
for leg in ('L1', 'L2', 'L3', 'R1', 'R2', 'R3'):
    nan_counts = [
        P9LT_90[f'{leg}_{param}'].isna().sum()
        for param in ('step_period', 'swing_dur', 'stance_dur', 'stance_dist_norm')
    ]
    print(f'{leg}: ', nan_counts)

In [None]:
# For each leg, print the number of NaN entries in P9RT_test_60 for
# step period, swing duration, stance duration and normalised stance distance.
for leg in ('L1', 'L2', 'L3', 'R1', 'R2', 'R3'):
    nan_counts = [
        P9RT_test_60[f'{leg}_{param}'].isna().sum()
        for param in ('step_period', 'swing_dur', 'stance_dur', 'stance_dist_norm')
    ]
    print(f'{leg}: ', nan_counts)

In [None]:
# Index labels of the rows of P9RT_test_60 where R2_step_period is NaN.
df = P9RT_test_60['R2_step_period'].isna().to_frame()
df.loc[df['R2_step_period']].index.tolist()