In [None]:
import os
import numpy as np
import glob as glob
import pandas as pd
from pathlib import Path
from scipy.io import loadmat
import matplotlib.pyplot as plt
import math
import seaborn as sns
import scipy 
from matplotlib.ticker import MultipleLocator
import sklearn
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
from scipy.ndimage import gaussian_filter1d

# Import Data

In [None]:
# Location of the preprocessed dataset. This is a hard-coded absolute path;
# change it here if the data lives elsewhere on your machine.
STOP_DATA_PATH = r'C:\1_Stop_project_allData\Ball_fitting_V2\data\df_preproc.parquet'
All_stopdata = pd.read_parquet(STOP_DATA_PATH, engine='pyarrow')

In [None]:
## Define genotypes
def _select_genotype(data, tag):
    """Return the rows of ``data`` whose 'genotype' value contains ``tag``.

    ``tag`` is matched as a literal substring (not a regex). Rows whose
    genotype is missing are excluded instead of making the boolean mask
    (and therefore ``.loc``) fail.
    """
    # NOTE(review): substring match -- any genotype label that merely contains
    # the tag (e.g. 'ES' inside a longer name) is also selected; confirm the
    # labels are unambiguous.
    mask = data['genotype'].str.contains(tag, regex=False, na=False)
    return data.loc[mask]


Stop1 = _select_genotype(All_stopdata, 'Stop')
FG = _select_genotype(All_stopdata, 'FG')
BB = _select_genotype(All_stopdata, 'BB')
ES = _select_genotype(All_stopdata, 'ES')

In [None]:
def plot_all_stopping_bouts(df, title, column='L1C_flex',
                            ball_dir=r'Z:\STOP PROJECT SHARED FOLDER\Fig2\PreliminaryData_invalid\07222023\9_Ball_velocity'):
    """Plot the mean ball velocity (+/- SEM) above a per-trial heatmap of one joint-angle column.

    The top panel draws the first 1000 samples of the genotype's mean ball
    velocity, read from ``Mean_<stem>.csv``. It is surrounded by a shaded band
    of +/- the values in ``SEM_<stem>.csv``. The bottom panel is a heatmap
    with one row per trial of ``df[column]``, showing the first 1000 samples
    of each trial. Both panels carry a dashed line at sample 400.

    Parameters
    ----------
    df : DataFrame
        Raw dataframe containing ``column``. Rows are reshaped into trials of
        1400 consecutive samples each, so they are assumed to be ordered
        trial by trial -- TODO confirm against the preprocessing step.
    title : str
        'ES_66', 'FG_66', 'BB_66' or 'BRK_66'. Selects which ball-velocity
        files are loaded ('BRK_66' uses the 'Stop1_66' files). Also used as
        the plot title and line label.
    column : str, optional
        Column of ``df`` to show in the heatmap (default 'L1C_flex').
    ball_dir : str, optional
        Directory containing the Mean_*/SEM_* CSV files.

    Raises
    ------
    ValueError
        If ``title`` is not a supported genotype, if ``df`` is empty, or if
        ``len(df)`` is not a multiple of 1400.
    """
    samples_per_trial = 1400  # rows of df that make up one trial
    n_plot = 1000             # leading samples of each trial that are drawn
    onset_sample = 400        # dashed marker; relabelled as tick 0 below

    # title -> CSV file stem; 'BRK_66' is the only one that differs.
    ball_stems = {
        'ES_66': 'ES_66',
        'FG_66': 'FG_66',
        'BB_66': 'BB_66',
        'BRK_66': 'Stop1_66',
    }
    if title not in ball_stems:
        raise ValueError(
            f"Unknown title {title!r}; expected one of {sorted(ball_stems)}")

    n_rows = len(df)
    if n_rows == 0 or n_rows % samples_per_trial:
        raise ValueError(
            f"df must contain a positive multiple of {samples_per_trial} rows "
            f"(one block per trial); got {n_rows}")
    trials = pd.DataFrame(
        df[column].to_numpy().reshape(n_rows // samples_per_trial, samples_per_trial))

    stem = ball_stems[title]
    mean = pd.read_csv(os.path.join(ball_dir, f'Mean_{stem}.csv'))
    sem = pd.read_csv(os.path.join(ball_dir, f'SEM_{stem}.csv'))

    fig, ax = plt.subplots(2, 1, sharex=True, sharey=False, figsize=(10, 5))

    # Bottom panel: one heatmap row per trial.
    sns.heatmap(ax=ax[1], data=trials.iloc[:, :n_plot],
                vmin=10, vmax=150, cmap='coolwarm', cbar=False)
    ax[1].axvline(x=onset_sample, color='k', alpha=0.7, linestyle='--')
    ax[1].xaxis.set_minor_locator(MultipleLocator(100))

    # Top panel: mean ball velocity with a +/- SEM band.
    y = mean.iloc[:n_plot, 0]
    error = sem.iloc[:n_plot, 0]
    ax[0].plot(y, color='r', label=title)
    ax[0].fill_between(np.arange(len(y)), y - error, y + error, alpha=0.2, color='r')
    ax[0].axvline(x=onset_sample, color='k', alpha=0.7, linestyle='--')
    ax[0].set_ylim([-4, 12])
    ax[0].set_title(title)

    # Ticks every 200 samples, relabelled -2..3 so that onset_sample reads 0.
    ax[1].set_xticks(np.arange(0, 1200, 200))
    ax[1].set_xticklabels([-2, -1, 0, 1, 2, 3], rotation=0)
    fig.tight_layout()


In [None]:
plot_all_stopping_bouts(df=ES, title='ES_66');