In [22]:
import os
import numpy as np
import glob as glob
import pandas as pd
from pathlib import Path
from scipy.io import loadmat
from natsort import os_sorted
import matplotlib.pyplot as plt

# Functions

In [23]:
# Extract folders for each dataset
def extract_folders(anipose_folders_path, exception=None):
    """Collect the per-fly CSV file paths of an Anipose project.

    Expected project layout (taken from the glob patterns below)::

        <project>/<fly>/pose-3d/*.csv
        <project>/<fly>/angles/*.csv
        <project>/BallPos/*.csv
        <project>/BallVel/*.csv

    Parameters
    ----------
    anipose_folders_path : Path or str
        Anipose project folder.
    exception : list of str, optional
        Fly ids to skip (e.g. ``['N1']``). For pose-3d/angle files the fly id
        is the name of the folder two levels above the CSV; for BallPos/BallVel
        files it is the file name with its last 11 characters removed.

    Returns
    -------
    tuple of list
        ``(pos_folders, angle_folders, ballpos_folders, ballvel_folders)``,
        each naturally sorted with ``os_sorted``.

    Raises
    ------
    TypeError
        If the BallPos or BallVel subfolder is missing (exception type kept
        for backward compatibility).
    """
    print("------------------Step1 : Extracting CSV file paths------------------")

    project = Path(anipose_folders_path)
    excluded = set(exception or ())

    ballpos_folders_path = project / "BallPos"
    if not ballpos_folders_path.is_dir():
        raise TypeError("No BallPos subfolder found in project")
    print("INFO : BallPos folder found in project")

    ballvel_folders_path = project / "BallVel"
    if not ballvel_folders_path.is_dir():
        raise TypeError("No BallVel subfolder found in project")
    print("INFO : BallVel folder found")

    def _keep_sorted(paths, fly_of):
        # Drop files that belong to excluded flies, then sort naturally.
        return os_sorted([p for p in paths if fly_of(p) not in excluded])

    # Forward slashes in glob patterns work on both Windows and POSIX.
    pos_folders = _keep_sorted(project.glob('*/pose-3d/*.csv'), lambda p: p.parts[-3])
    angle_folders = _keep_sorted(project.glob('*/angles/*.csv'), lambda p: p.parts[-3])
    ballpos_folders = _keep_sorted(ballpos_folders_path.glob('*.csv'), lambda p: p.name[:-11])
    ballvel_folders = _keep_sorted(ballvel_folders_path.glob('*.csv'), lambda p: p.name[:-11])

    return pos_folders, angle_folders, ballpos_folders, ballvel_folders

In [24]:
def clean_all_dicts (pos3d_dict, angle_dict, ballpos_dict, ballvel_dict, trial_len=1400):
    """Trim every fly's four tables to one common length made of whole trials.

    For each fly key in ``pos3d_dict``, the pose-3D, angle, ball-position and
    ball-velocity tables are cut to the shortest of the four lengths, rounded
    down to a multiple of ``trial_len``. The tables then line up row for row
    when they are later concatenated side by side.

    Parameters
    ----------
    pos3d_dict, angle_dict, ballpos_dict, ballvel_dict : dict
        Fly id -> DataFrame. Every key of ``pos3d_dict`` must also be present
        in the other three dicts.
    trial_len : int, optional
        Number of frames per trial, by default 1400.

    Returns
    -------
    tuple of dict
        The same four dict objects (updated in place), in input order.

    Raises
    ------
    KeyError
        If a fly in ``pos3d_dict`` is missing from one of the other dicts.
    """
    tables = (("pose3d", pos3d_dict), ("angle", angle_dict),
              ("ballpos", ballpos_dict), ("ballvel", ballvel_dict))

    def _lengths_msg(prefix, fly):
        lens = [len(table[fly]) for _, table in tables]
        return prefix + " for {flynum}: (pose3d, {xx}) , (angle, {yy}), (ballpos, {zz}), (ballvel, {aa})".format(
            flynum = fly, xx = lens[0], yy = lens[1], zz = lens[2], aa = lens[3])

    for fly in list(pos3d_dict.keys()):
        missing = [name for name, table in tables if fly not in table]
        if missing:
            raise KeyError("No {names} data found for {flynum}".format(names = ", ".join(missing), flynum = fly))

        lengths = [len(table[fly]) for _, table in tables]
        shortest = min(lengths)
        # Round down so that every table holds only complete trials.
        target = shortest - shortest % trial_len

        if all(n == target for n in lengths):
            print(_lengths_msg("INFO: All dataframes of equal lengths with complete trials", fly))
            continue

        print("INFO: Unequal or incomplete data for {flynum}. Trimming all dataframes to {num} frames".format(flynum = fly, num = target))
        if target == 0:
            print("WARNING: No complete trial left for {flynum}".format(flynum = fly))
        for _, table in tables:
            table[fly] = table[fly].iloc[:target]
        print(_lengths_msg("INFO: All dataframes updated", fly))

    return pos3d_dict, angle_dict, ballpos_dict, ballvel_dict

In [25]:
def dict_to_df (dict):
    """Stack the per-fly DataFrames of ``dict`` vertically, in key order.

    Parameters
    ----------
    dict : dict
        Fly id -> DataFrame. (The name shadows the builtin ``dict``; it is kept
        so keyword callers keep working.)

    Returns
    -------
    DataFrame
        All frames concatenated row-wise with a fresh 0..n-1 index, or an
        empty DataFrame when ``dict`` is empty.
    """
    frames = list(dict.values())
    if not frames:
        return pd.DataFrame()
    # One concat instead of growing a frame inside a loop (which is quadratic).
    return pd.concat(frames, ignore_index=True)

In [26]:
def _pose3d_xyz_columns(raw_csv, n_meta_cols=13, stride=6, width=3):
    """Keep the first ``width`` columns of every ``stride``-wide column group.

    The trailing ``n_meta_cols`` columns are dropped first. These are labelled
    Anipose metadata in the original code.
    """
    # NOTE(review): presumably each joint occupies 6 columns, of which the first
    # 3 are x, y, z; confirm against the Anipose pose-3d CSV header.
    body = raw_csv.iloc[:, :max(raw_csv.shape[1] - n_meta_cols, 0)]
    n_cols = body.shape[1]
    picked = [col
              for start in range(0, n_cols, stride)
              for col in range(start, min(start + width, n_cols))]
    return body.iloc[:, picked]


def _pad_to_full_trials(frame, source, trial_len, indent):
    """Append all-NaN rows so that ``len(frame)`` becomes a multiple of ``trial_len``.

    Returns ``(frame, padded)`` where ``padded`` tells whether rows were added.
    """
    n_rows = len(frame)
    remainder = n_rows % trial_len
    if remainder == 0:
        return frame, False
    missing = trial_len - remainder
    print("INFO: Missing frames found in {filepath}, missing count = {frames}".format(filepath = source, frames = missing))
    print(indent +'Info: length before:', n_rows)
    # reindex onto a longer RangeIndex fills the new rows with NaN in every column.
    padded = frame.reset_index(drop=True).reindex(range(n_rows + missing))
    print(indent +'Info: length after:', len(padded))
    return padded, True


def extract_data(pos_folders, angle_folders, ballpos_folders, ballvel_folders):
    """Extract position, angle, ball velocity and ball position data.

    Parameters
    ----------
    pos_folders : list
        Anipose 3D pose CSV file paths (``<fly>/pose-3d/<file>.csv``).
    angle_folders : list
        Anipose angle CSV file paths (``<fly>/angles/<file>.csv``).
    ballpos_folders : list
        Ball position CSV file paths (fly id = file name minus its last 11 characters).
    ballvel_folders : list
        Ball velocity CSV file paths (same naming as ball position).

    Returns
    -------
    tuple of dict
        ``(pos3d_dict, angle_dict, ballpos_dict, ballvel_dict)``, each mapping
        fly id -> DataFrame, after ``clean_all_dicts`` has aligned the lengths.
    """
    print("------------------Step2 : Extracting data from CSV files to DataFrames------------------")

    indent = " "*6
    trial_len = 1400  # frames per trial; same value clean_all_dicts uses

    pos3d_dict = {}
    for data_csv in pos_folders:
        fly = Path(data_csv).parts[-3]
        pos3d_dict[fly] = _pose3d_xyz_columns(pd.read_csv(data_csv))

    angle_dict = {}
    for data_csv in angle_folders:
        fly = Path(data_csv).parts[-3]
        angle_dict[fly] = pd.read_csv(data_csv).drop("fnum", axis = 1)

    ballpos_dict = {}
    for data_csv in ballpos_folders:
        fly = Path(data_csv).name[:-11]
        frame, padded = _pad_to_full_trials(pd.read_csv(data_csv), data_csv, trial_len, indent)
        if padded:
            # NOTE(review): only padded files are renamed; unpadded files keep
            # their CSV header. Confirm the headers already use these names.
            frame.columns = ["x_pos", "y_pos", "z_pos"]
        ballpos_dict[fly] = frame

    ballvel_dict = {}
    for data_csv in ballvel_folders:
        fly = Path(data_csv).name[:-11]
        frame, padded = _pad_to_full_trials(pd.read_csv(data_csv), data_csv, trial_len, indent)
        if padded:
            frame.columns = ["x_vel", "y_vel", "z_vel"]
        ballvel_dict[fly] = frame

    print("INFO: Finished extracting dataframes with 3d pose, angles, ballpos and ballvel data")
    pos3d_dict, angle_dict, ballpos_dict, ballvel_dict = clean_all_dicts (pos3d_dict, angle_dict, ballpos_dict, ballvel_dict)

    return pos3d_dict, angle_dict, ballpos_dict, ballvel_dict

In [27]:
def extract_initcol(pos3d_dict,  SF_path = None, trial_len = 1400):
    """Extract fly number, trial number, frame number and (optionally) stimulation frequency.

    Parameters
    ----------
    pos3d_dict : dict
        Fly id (e.g. ``'N3'``) -> 3D pose DataFrame. Each fly's length must be
        a multiple of ``trial_len``.
    SF_path : Path or str, optional
        CSV with one column per fly id and one row per trial holding the
        stimulation parameter; by default None (no SF column).
    trial_len : int, optional
        Frames per trial, by default 1400.

    Returns
    -------
    DataFrame
        Columns ``flynum``, ``tnum`` (1-based trial), ``fnum`` (0-based frame
        counter running across all trials of a fly) and optionally ``SF``.
        Flies are stacked in key order with a 0..n-1 index.

    Raises
    ------
    TypeError
        If a fly's frame count is not a multiple of ``trial_len``.
    """
    print("------------------Step3 : Extracting Metadata columns------------------")

    indent = " "*6

    if SF_path is not None:
        SF = pd.read_csv(SF_path)
        print("INFO : SF path found!")
    else:
        SF = None
        print("INFO : NO SF path found!!")
    use_sf = SF is not None and len(SF) > 0

    per_fly = []
    for fly, data in pos3d_dict.items():
        data_len = len(data)
        flynum = int(fly.replace('N', ''))  # 'N12' -> 12
        tot_trials = data_len // trial_len

        if tot_trials * trial_len != data_len:
            raise TypeError("INFO: flynum, tnum,fnum array lengths don't match for {fly}".format(fly = fly))
        print(indent + "INFO:flynum, tnum, fnum arrays created for {fly}".format(fly = fly))

        # Explicit int64 so dtypes do not depend on the platform default int.
        init_cols = pd.DataFrame({
            'flynum': np.full(data_len, flynum, dtype=np.int64),
            'tnum': np.repeat(np.arange(1, tot_trials + 1, dtype=np.int64), trial_len),
            'fnum': np.arange(data_len, dtype=np.int64),
        })

        if use_sf:
            sf_values = SF[fly].iloc[:tot_trials].to_numpy()
            if len(sf_values) < tot_trials:
                print(indent + "WARNING: only {n} SF values for {fly} ({t} trials); remaining frames get NaN".format(
                    n = len(sf_values), fly = fly, t = tot_trials))
            # Index alignment leaves NaN for any trial without an SF value.
            init_cols['SF'] = pd.Series(np.repeat(sf_values, trial_len))

        per_fly.append(init_cols)

    if not per_fly:
        return pd.DataFrame()
    return pd.concat(per_fly, axis = 0, ignore_index=True)

In [28]:
def combine_flydata(anipose_folders_path, exception, SF_path = None):
    """Load one Anipose project into a single frame-by-frame DataFrame.

    Parameters
    ----------
    anipose_folders_path : Path
        Anipose project folder (see ``extract_folders``).
    exception : list of str
        Fly ids to exclude, e.g. ``['N1']``.
    SF_path : Path, optional
        CSV with per-fly, per-trial stimulation parameters; by default None.

    Returns
    -------
    DataFrame
        Metadata columns (``flynum``, ``tnum``, ``fnum`` and optionally ``SF``)
        followed by the pose-3D, angle, ball-position and ball-velocity columns.

    Raises
    ------
    ValueError
        If the tables do not all have the same number of rows. A side-by-side
        concat would otherwise silently misalign them.
    """
    pos_folders, angle_folders, ballpos_folders, ballvel_folders = extract_folders(anipose_folders_path, exception)
    pos3d_dict, angle_dict, ballpos_dict, ballvel_dict = extract_data(pos_folders, angle_folders, ballpos_folders, ballvel_folders)

    Init_cols_all= extract_initcol(pos3d_dict, SF_path)

    parts = {
        "metadata": Init_cols_all,
        "pose3d": dict_to_df (pos3d_dict),
        "angle": dict_to_df (angle_dict),
        "ballpos": dict_to_df (ballpos_dict),
        "ballvel": dict_to_df (ballvel_dict),
    }
    row_counts = {name: len(frame) for name, frame in parts.items()}
    if len(set(row_counts.values())) > 1:
        raise ValueError("Row counts differ between combined tables: {}".format(row_counts))

    fly_data = pd.concat(list(parts.values()), axis=1)

    print("INFO: All Dataframes combined sucessfully")
    return fly_data
    

In [29]:
def generate_table(all_data):
    """Build a summary table with one row per genotype.

    Parameters
    ----------
    all_data : dict
        Genotype label -> combined fly DataFrame (must have a ``flynum`` column).

    Returns
    -------
    DataFrame
        Columns ``Genotype``, ``no. of flies`` (number of distinct ``flynum``
        values, which stays correct when flies are excluded or numbered with
        gaps) and ``flydata`` (the genotype's DataFrame stored as an object).
        Rows are in dict order with a 0..n-1 index.
    """
    rows = [
        pd.DataFrame(data={"Genotype": genotype,
                           "no. of flies": [fly_data["flynum"].nunique()],
                           "flydata": [fly_data]})
        for genotype, fly_data in all_data.items()
    ]
    if not rows:
        return pd.DataFrame()
    return pd.concat(rows, axis=0, ignore_index=True)

In [30]:
def filter_prestim_vel (dataframe, threhsold, trial_len=1400, window=(320, 400), vel_col='x_vel'):
    """Keep only the trials whose mean pre-stimulus velocity reaches a threshold.

    A trial is one (``flynum``, ``tnum``) group. Its pre-stimulus velocity is
    the mean of ``vel_col`` over frames with
    ``window[0] < fnum % trial_len < window[1]`` (NaNs ignored). Trials with no
    valid value in that window are dropped.

    Parameters
    ----------
    dataframe : DataFrame
        Must contain ``flynum``, ``tnum``, ``fnum`` and ``vel_col``.
    threhsold : float
        Minimum mean velocity (inclusive). Fractional values are honoured; the
        misspelled name is kept for keyword callers.
    trial_len : int, optional
        Frames per trial, by default 1400.
    window : tuple of int, optional
        Exclusive (start, end) frame bounds within a trial, by default (320, 400).
    vel_col : str, optional
        Velocity column to average, by default ``'x_vel'``.

    Returns
    -------
    DataFrame
        Rows of the kept trials in their original order, with a 0..n-1 index.
    """
    threshold = float(threhsold)  # int() would truncate e.g. 0.5 to 0
    lo, hi = window
    frame_in_trial = dataframe['fnum'] % trial_len
    in_window = (frame_in_trial > lo) & (frame_in_trial < hi)
    # Outside the window the velocity becomes NaN, so the per-trial mean only sees window frames.
    window_vel = dataframe[vel_col].where(in_window)
    trial_mean = window_vel.groupby([dataframe['flynum'], dataframe['tnum']]).transform('mean')
    keep = (trial_mean >= threshold).to_numpy()
    return dataframe.loc[keep].reset_index(drop=True)

# Generate Dataset

## Define genotypes

In [31]:
# Flies to leave out of the P9 left-turning dataset, given as folder ids such as 'N1'
P9_LT_exception = []
# Anipose project root and the per-trial activation-frequency table for this genotype
P9_LT_anipose_folders_path = Path(r'Z:\BallSystem_AniposeReconstructions\4_P9-LeftTurning\project')
P9_LT_SF_path = Path(r"Z:\BallSystem_AniposeReconstructions\4_P9-LeftTurning\project\Metadata_freq.csv")
P9_LT_data = combine_flydata(
    P9_LT_anipose_folders_path,
    P9_LT_exception,
    SF_path=P9_LT_SF_path,
)

------------------Step1 : Extracting CSV file paths------------------
INFO : BallPos folder found in project
INFO : BallVel folder found
------------------Step2 : Extracting data from CSV files to DataFrames------------------
INFO: Finished extracting dataframes with 3d pose, angles, ballpos and ballvel data
INFO: All dataframes of equal lengths with complete trials for N1: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N2: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N3: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N4: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N5: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All 

In [32]:
# Flies to leave out of the P9 right-turning dataset, given as folder ids such as 'N1'
P9_RT_exception = []
# Anipose project root and the per-trial activation-frequency table for this genotype
P9_RT_anipose_folders_path = Path(r'Z:\BallSystem_AniposeReconstructions\5_P9-RightTurning\project')
P9_RT_SF_path = Path(r"Z:\BallSystem_AniposeReconstructions\5_P9-RightTurning\project\Metadata_freq.csv")
P9_RT_data = combine_flydata(
    P9_RT_anipose_folders_path,
    P9_RT_exception,
    SF_path=P9_RT_SF_path,
)

------------------Step1 : Extracting CSV file paths------------------
INFO : BallPos folder found in project
INFO : BallVel folder found
------------------Step2 : Extracting data from CSV files to DataFrames------------------
INFO: Finished extracting dataframes with 3d pose, angles, ballpos and ballvel data
INFO: All dataframes of equal lengths with complete trials for N1: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N2: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N3: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N4: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N5: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All 

In [33]:
# Flies to leave out of the BPN activation dataset, given as folder ids such as 'N1'
BPN_exception = []
# Anipose project root and the per-trial activation-frequency table for this genotype
BPN_anipose_folders_path = Path(r'Z:\BallSystem_AniposeReconstructions\3_BPN-S1-Activation\newCalib_flybased\project')
BPN_SF_path = Path(r"Z:\BallSystem_AniposeReconstructions\3_BPN-S1-Activation\newCalib_flybased\project\Metadata_freq.csv")
BPN_data = combine_flydata(
    BPN_anipose_folders_path,
    BPN_exception,
    SF_path=BPN_SF_path,
)

------------------Step1 : Extracting CSV file paths------------------
INFO : BallPos folder found in project
INFO : BallVel folder found
------------------Step2 : Extracting data from CSV files to DataFrames------------------
INFO: Finished extracting dataframes with 3d pose, angles, ballpos and ballvel data
INFO: All dataframes of equal lengths with complete trials for N1: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N2: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N3: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N4: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All dataframes of equal lengths with complete trials for N5: (pose3d, 14000) , (angle, 14000), (ballpos, 14000), (ballvel, 14000)
INFO: All 

## Create table with all data

In [34]:
# Genotype label -> combined per-frame DataFrame for that genotype
all_data = dict(
    P9LT=P9_LT_data,
    P9RT=P9_RT_data,
    BPN=BPN_data,
)

gen_df = generate_table(all_data).reset_index(drop=True)
display(gen_df)

Unnamed: 0,Genotype,no. of flies,flydata
0,P9LT,12,flynum tnum fnum SF R-F-ThC_x R-...
1,P9RT,13,flynum tnum fnum SF R-F-ThC_x R...
2,BPN,10,flynum tnum fnum SF R-F-ThC_x R-...


## Save the data structure as an HDF5 file

In [35]:
# Overwrite (mode='w') the HDF5 store with the genotype summary table.
# The object columns ('Genotype', 'flydata') are pickled by PyTables, which is what
# the PerformanceWarning below reports; only load this file from a trusted source.
gen_df.to_hdf(
    Path(r'C:\2_P9Project_allData\1_DataStructure\P9LT_P9RT_BPNS1_V1_04252024.h5'),
    key='df',
    mode='w',
)

your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['Genotype', 'flydata'], dtype='object')]

  gen_df.to_hdf(Path(r'C:\2_P9Project_allData\1_DataStructure\P9LT_P9RT_BPNS1_V1_04252024.h5'), key='df', mode='w')
