<a href="https://colab.research.google.com/github/dtabuena/EphysLib/blob/main/Pipeline_Wrapper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def ephys_wrapper(dataset,VC_prot,IC_prot,strat_cols=['Cell_Type'],verbose=False, spike_args=True):
    '''wrapper for single dataset pipeline'''
    results = {}
    # try:
    '''Unpack'''
    data_name = dataset['data_name']
    data_source = dataset['data_source']
    file_naming_scheme = dataset['file_naming_scheme']



    ''' Gather and Catalog Source Data'''
    file_loc = get_drobox_folder(data_source, 'my_ephys_data_' + data_name)
    # clear_output(wait=False)   
    abf_recordings_df, protocol_set = catalogue_recs(file_loc,file_naming_scheme)
    results['abf_recordings_df'] = abf_recordings_df
    results['protocol_set'] = protocol_set

    abf_recordings_df, _ = purge_wrong_clamp(abf_recordings_df,VC_prot,IC_prot)
    results['abf_recordings_df'] = abf_recordings_df

    _ = cell_prot_lut(abf_recordings_df,protocol_set,csv_name=data_name+'_Recording_LookUp')


    '''Set Internal Analysis Params'''
    if spike_args:
        spike_args =  {'spike_thresh':20, 'high_dv_thresh': 50,'low_dv_thresh': -30,'window_ms': 2}

    func_dict = {}
    arg_dict = {}

    func_dict['VC - 3min GapFree']= rmp_analyzer
    arg_dict['VC - 3min GapFree'] = [True] # [to_plot?]

    func_dict['IC - Rheobase']= rheobase_analyzer
    arg_dict['IC - Rheobase'] = [spike_args, True, False, False]  # [spike_args, to_plot, verbose, force_singlespike]

    func_dict['IC - Gain - D20pA']= gain_analyzer
    arg_dict['IC - Gain - D20pA']= [spike_args, 1]  # [spike_args, to_plot [0:2],]
    func_dict['IC - Gain - D50pA']= func_dict['IC - Gain - D20pA'] 
    arg_dict['IC - Gain - D50pA']= arg_dict['IC - Gain - D20pA']

    func_dict['VC - MemTest-10ms-160ms']= membrane_analyzer
    arg_dict['VC - MemTest-10ms-160ms']= [True, False, ['Ra', 'Rm', 'Cm', 'tau',	'Cmq',	'Cmf',	'Cmqf', 'Cm_pc']]  # [to_plot, verbose]

    func_dict['IC - Latentcy 800pA-1s']= latencey_analyzer 
    arg_dict['IC - Latentcy 800pA-1s']= [spike_args, True]  # [spike_args, to_plot]

    func_dict['IC - R input']= input_resistance_analyzer 
    arg_dict['IC - R input']= [[-30, 10] ,True]  # [dVm_limits, to_plot]

    func_dict['VC - Multi IV - 150ms'] = IV_analyzer_v2
    arg_dict['VC - Multi IV - 150ms']= [{'IV_Early':(16.5, 30),'IV_Steady_State':(100,120)} ,[False, True]]  # [measure_windows, to_plot]


    '''Analyze Dataset'''
    abf_recordings_df, problem_recs = analysis_iterator(abf_recordings_df,func_dict,arg_dict,verbose=verbose)
    # clear_output(wait=True)
    print('problem_recs')
    _=[print('     '+r) for r in problem_recs]
    results['problem_recs'] = problem_recs
    results['abf_recordings_df'] = abf_recordings_df

    '''Download Analysis figs'''
    zip_name = '/content/' + data_name + '_Saved_Figs.zip'
    !zip -r $zip_name /content/Saved_Figs 
    # clear_output()
    files.download(data_name + '_Saved_Figs.zip')

    '''Sort Cells'''
    cell_df = cell_sorting(abf_recordings_df)
    results['cell_df'] = cell_df

    '''Consolidate to Cells'''
    list_types = ['Recording_name','protocol','abf_timestamp', 'channelList']
    any_types = [] + dataset['file_naming_scheme']
    cell_df_con = cell_consolidation(cell_df,list_types,any_types)
    results['cell_df_con'] = cell_df_con

    '''Simplify IV Data'''
    cols_to_simplify = ['IV_Early', 'IV_Steady_State']
    cell_df_nd = simplify_dicts(cell_df_con,cols_to_simplify)     
    results['cell_df_nd'] = cell_df_nd

    '''Make Excell Friendly'''
    keys_and_data_cols={'Stim_Levels_(pA)': ['Stim_Levels_(pA)', 'Spike_Counts' ],
                    'IV_Early_(V_stim)': ['IV_Early_(V_stim)', 'IV_Early_(I_peak)', 'IV_Steady_State_(I_mean)']}
    cell_df_csv = csv_frinedly(cell_df_nd,keys_and_data_cols)
    results['cell_df_csv'] = cell_df_csv


    ''' Convert to Current Density'''
    size_col = 'Cmq_160.0'
    current_col_list = ['IV_Early_(I_peak)_', 'IV_Steady_State_(I_mean)_']
    cell_df_csv = current_density_correction(cell_df_csv, size_col, current_col_list)
    results['cell_df_csv'] = cell_df_csv

    '''Abridge DataFrame'''
    abrg_exclusions = ['Recording_name', 
                    'protocol', 'abf_timestamp', 'channelList',  'Ra_10.0', 'Rm_10.0', 'tau_10.0', 'Cmq_10.0', 'Cmf_10.0',
                    'Cmqf_10.0',  'Cmf_160.0', 'Cmqf_160.0', 'Cm_pc_160.0',
                    'Gain_R2', 'Stim_Levels_(pA)', 'Spike_Counts', 'Firing_Duration_%', 'Gain_Vh',  'Vhold_spike',
                        'Rin_Rsqr',  'Ramp_AP_thresh', 'Ramp_Vh', 'Ramp_Rheobase', 
                    'ap_thresh_us', 'v_half','is_compensated','sum_delta'
                    'IV_Early_(range)', 'IV_Early_(I_peak)', 'IV_Early_(I_mean)', 'IV_Early_(V_stim)', 'IV_Steady_State_(range)',
                    'IV_Steady_State_(I_peak)', 'IV_Steady_State_(I_mean)', 'IV_Steady_State_(V_stim)', ]

    abrg_keep = [c for c in cell_df_csv.columns if c not in abrg_exclusions]
    cell_df_csv_abrg = cell_df_csv[abrg_keep]
    results['cell_df_csv_abrg'] = cell_df_csv_abrg

    '''Stratify Cells By Type'''
    strat_df_dict = stratify_rec(cell_df_csv_abrg,strat_cols)
    strat_df_dict,_ = flatten_dict(strat_df_dict,{})
    write_strat_dfs(strat_df_dict, dataset['data_name']+'_results_stratified')
    results['strat_df_dict'] = strat_df_dict
    # except: None
    return results

In [None]:
def csv_frinedly(cell_df, keys_and_data_cols, remove_source=True):
    '''Spread list-valued columns into one scalar column per label.

    ``keys_and_data_cols`` maps a *key* column, whose cells hold a sequence
    of labels (e.g. stimulus levels), to the *data* columns that run parallel
    to it. For every cell and every label position ``i``, a column named
    ``data_col + '_' + str(label)`` is created (initialised to ``None``) and
    filled with ``data[i]``.

    Cells where either the key value or the data value is ``None`` are
    skipped. A data sequence shorter than its label sequence raises
    IndexError. The source columns are left in place. ``remove_source`` is
    currently unused and is kept only for signature compatibility.

    Returns a modified copy; ``cell_df`` is not changed.
    '''
    cell_df_csv = cell_df.copy()
    for key_col, data_cols in keys_and_data_cols.items():
        for data_col in data_cols:
            for cell in cell_df_csv.index:
                labels = cell_df_csv.loc[cell, key_col]
                values = cell_df_csv.loc[cell, data_col]
                # A missing data value used to crash with TypeError on values[i].
                if labels is None or values is None:
                    continue
                for i, label in enumerate(labels):
                    new_col_name = data_col + '_' + str(label)
                    if new_col_name not in cell_df_csv.columns:
                        cell_df_csv[new_col_name] = None
                    cell_df_csv.at[cell, new_col_name] = values[i]

    return cell_df_csv

In [None]:

def analysis_iterator(abf_recordings_df, func_dict, arg_dict, verbose=False):
    '''Run the protocol-specific analyzer on every recording and store its outputs.

    For each recording (the index of ``abf_recordings_df``), the ABF is
    loaded via ``abf_or_name``. Its ``protocol`` selects an analyzer from
    ``func_dict`` and extra positional arguments from ``arg_dict``.
    Recordings whose protocol is missing from either dict are skipped. Every
    key of the dict returned by the analyzer becomes (or fills) an
    object-dtype column, and the value is written to that recording's row.

    Returns
    -------
    (DataFrame, list)
        The updated recordings table and the names of recordings whose
        analysis or result storage failed.
    '''
    problem_recs = []

    for file_name in tqdm(abf_recordings_df.index):
        abf = abf_or_name(file_name)
        prot_name = abf.protocol
        if verbose:
            print('\n', '     ', file_name)
            print('     ', prot_name)

        # Skip protocols that have no registered analyzer or arguments.
        if prot_name not in func_dict or prot_name not in arg_dict:
            continue

        # These two names were previously never bound, so every call raised
        # NameError (swallowed below) and nothing was ever analyzed.
        # NOTE(review): assumes every analyzer takes the ABF as its first
        # positional argument, followed by arg_dict[prot_name]. Confirm
        # against the analyzer definitions.
        analyzer_func = func_dict[prot_name]
        args_for_analyzer = [abf] + list(arg_dict[prot_name])

        try:
            results = analyzer_func(*args_for_analyzer)  # run analyzer
        except Exception as err:
            print('\n', 'error on: ', file_name)
            print('analysis failed')
            print('     ', repr(err))
            problem_recs.append(file_name)
            # Falling through would store the previous recording's (stale)
            # results in this row, or hit an unbound name on the first one.
            continue

        try:
            for k in results.keys():
                if k not in abf_recordings_df.columns:
                    abf_recordings_df = init_col_object(abf_recordings_df, k)
                abf_recordings_df.at[file_name, k] = results[k]
        except Exception:
            print('\n', 'error on: ', file_name)
            print('recording results failed')
            problem_recs.append(file_name)

    return abf_recordings_df, problem_recs

def init_col_object(df, name):
    '''Add (or reset) column ``name`` as an all-None, object-dtype column.

    ``df`` is modified in place and also returned, so the call can be used
    in an assignment.
    '''
    df[name] = None
    as_object = df[name].astype(object)
    df[name] = as_object
    return df

In [None]:
def cell_sorting(abf_recordings_df):
    '''Group per-recording rows into one row per cell.

    The result is indexed by the sorted distinct values of the 'cell_id'
    column. Every column whose name does not contain 'cell_id' is carried
    over. Each cell's entry is the list of that column's values across the
    cell's recordings, in recording-table order.
    '''
    unique_cells = sorted(set(abf_recordings_df['cell_id']))
    transfer_cols = [c for c in abf_recordings_df.columns if 'cell_id' not in c]
    cell_df = pd.DataFrame(index=list(unique_cells), columns=transfer_cols)

    cell_ids = abf_recordings_df['cell_id']
    for cell in cell_df.index:
        # Exact match. The previous substring test (`cell in r`) also pulled
        # recordings of e.g. 'C10' into cell 'C1'. The mask is built once per cell.
        matched = abf_recordings_df[cell_ids == cell]
        for col in transfer_cols:
            cell_df.at[cell, col] = list(matched[col])
    return cell_df

In [None]:
def cell_consolidation(cell_df, list_types, any_types, average_types=True):
    '''Collapse each cell's per-recording value lists into cell-level values.

    ``cell_df`` is expected to hold, in every cell, the list of per-recording
    values produced by ``cell_sorting``. Columns are handled as follows:

    * ``list_types`` columns keep the full list.
    * ``any_types`` columns are assumed identical across recordings; the
      first value is kept.
    * ``average_types`` columns are reduced with ``np.nanmean`` over axis 0
      after dropping ``None`` entries. Cells that cannot be averaged keep
      their list. ``True`` (the default) means "every column not named
      elsewhere"; a list, tuple or set selects columns explicitly; a falsy
      value disables averaging.
    * 'IV_Early' and 'IV_Steady_State' are reduced by ``consolidate_iv_recs``
      and always stored as a list. A failure stores ``[None]`` if the raw
      value was a float NaN, otherwise ``['ERROR']``.
    * 'Stim_Levels_(pA)' and 'Spike_Counts' are merged by
      ``consolidate_gain_recs``.

    Returns a modified copy of ``cell_df``.
    '''
    cell_df_con = cell_df.copy()
    explicit_cols = ['IV_Early', 'IV_Steady_State', 'Stim_Levels_(pA)', 'Spike_Counts']

    # Resolve which columns get averaged. Previously any truthy value
    # (including an explicit list) was replaced by the computed default, and
    # a falsy value crashed the loop below with `for col in False`.
    if isinstance(average_types, (list, tuple, set)):
        average_types = list(average_types)
    elif average_types:
        average_types = [c for c in cell_df_con.columns
                         if c not in any_types and c not in list_types and c not in explicit_cols]
    else:
        average_types = []

    for cell in cell_df_con.index:
        # list_types: intentionally untouched; the per-recording list is kept.
        for col in any_types:
            cell_df_con.at[cell, col] = cell_df_con.at[cell, col][0]

        for col in average_types:
            multi_vals = cell_df_con.loc[cell, col]
            try:
                multi_vals = [v for v in multi_vals if v is not None]
                cell_df_con.at[cell, col] = np.nanmean(multi_vals, 0)
            except Exception:
                # Best effort: non-numeric or ragged values keep their list.
                pass

    # Explicitly defined consolidations.
    for col in ['IV_Early', 'IV_Steady_State']:
        for cell in cell_df_con.index:
            raw_vals = cell_df_con.loc[cell, col]
            try:
                multi_vals = consolidate_iv_recs(raw_vals)
            except Exception:
                # The old handler called np.isnan on the raw list, which
                # itself raised TypeError and aborted the whole run.
                if isinstance(raw_vals, float) and np.isnan(raw_vals):
                    multi_vals = None
                else:
                    multi_vals = 'ERROR'

            if not isinstance(multi_vals, list):
                multi_vals = [multi_vals]
            cell_df_con.at[cell, col] = multi_vals

    for cell in cell_df_con.index:
        multi_val_pair = (cell_df_con.loc[cell, 'Stim_Levels_(pA)'], cell_df_con.loc[cell, 'Spike_Counts'])
        new_stim, new_firing = consolidate_gain_recs(multi_val_pair)

        # A single surviving recording comes back wrapped in an outer list.
        if len(new_stim) > 0 and isinstance(new_stim[0], list):
            new_stim = new_stim[0]
        if len(new_firing) > 0 and isinstance(new_firing[0], list):
            new_firing = new_firing[0]

        cell_df_con.at[cell, 'Stim_Levels_(pA)'] = new_stim
        cell_df_con.at[cell, 'Spike_Counts'] = new_firing

    return cell_df_con

In [None]:
def consolidate_iv_recs(multi_vals):
    '''Choose a single IV result for a cell from its per-recording results.

    ``None`` entries are discarded.

    * With at most one remaining result, the filtered list itself is
      returned.
    * With several, the first result whose 'V_stim' sequence has more than
      5 entries is returned on its own (not wrapped in a list). An
      IndexError propagates when none qualifies.

    A result lacking 'V_stim' or 'I_peak' raises KeyError.
    '''
    records = [rec for rec in multi_vals if rec is not None]
    stim_sweeps = [rec['V_stim'] for rec in records]
    # 'I_peak' is looked up on every record so a malformed result fails here.
    _ = [rec['I_peak'] for rec in records]
    if len(stim_sweeps) <= 1:
        return records
    sweep_counts = np.array([len(sweep) for sweep in stim_sweeps])
    first_long_idx = np.where(sweep_counts > 5)[0][0]
    return records[first_long_idx]

In [None]:

def simplify_dicts(cell_df, cols_to_simplify, remove_source=True):
    '''Explode dict-valued cells into one column per dict key.

    For every column ``col`` in ``cols_to_simplify``, each cell is expected
    to hold a list of dicts. A bare value is treated as a one-item list.
    Entries that are not dicts are ignored; this covers ``None`` and the
    ``'ERROR'`` marker written by ``cell_consolidation``.

    For every key ``k`` of the first dict, a column ``col + '_(' + str(k) + ')'``
    (object dtype) receives the list of that key's values across the dicts.
    The list is unwrapped when only one dict is present, and a dict lacking
    ``k`` contributes ``None``. The source column is dropped when
    ``remove_source`` is true (the default).

    Returns a modified copy of ``cell_df``.
    '''
    cell_df_new = cell_df.copy()
    for col in cols_to_simplify:
        for cell in cell_df_new.index:
            entries = cell_df_new.loc[cell, col]
            if not isinstance(entries, (list, tuple)):
                entries = [entries]
            # Non-dict entries previously crashed on `.keys()` (AttributeError).
            dicts = [d for d in entries if isinstance(d, dict)]
            if not dicts:
                continue
            for k in list(dicts[0].keys()):
                vals_of_key = [d.get(k) for d in dicts]
                if len(vals_of_key) == 1:
                    vals_of_key = vals_of_key[0]
                new_col = col + '_(' + str(k) + ')'
                if new_col not in cell_df_new.columns:
                    cell_df_new[new_col] = None
                    cell_df_new[new_col] = cell_df_new[new_col].astype(object)
                cell_df_new.at[cell, new_col] = vals_of_key
        if remove_source:
            cell_df_new.drop(labels=col, axis=1, inplace=True)
    return cell_df_new

In [None]:
def consolidate_gain_recs(multi_val_pair, min_stims=5):
    '''Merge a cell's gain recordings into one stimulus/response curve.

    Parameters
    ----------
    multi_val_pair : (stim_recs, fire_recs)
        Parallel per-recording sequences of stimulus-level vectors and
        response vectors. ``None`` marks a recording without data; a pair is
        dropped if either side is ``None``.
    min_stims : int
        When more than one recording is present, recordings with fewer than
        this many stimulus levels are discarded (default 5).

    Returns
    -------
    (stims, responses)
        * Zero or one usable recording: the per-recording vectors wrapped in
          lists, i.e. ``([], [])`` or ``([stim], [fire])``.
        * Several: the sorted distinct stimulus levels and the mean response
          (``np.mean``) at each level.
    '''
    stim_recs, fire_recs = multi_val_pair
    # Keep stimulus and response vectors paired. Previously each list was
    # filtered on its own, so a None or short vector on only one side shifted
    # the pairing (wrong matches or IndexError).
    pairs = [
        (np.asarray(s).tolist(), np.asarray(f).tolist())
        for s, f in zip(stim_recs, fire_recs)
        if s is not None and f is not None
    ]
    if len(pairs) > 1:
        pairs = [(s, f) for s, f in pairs if len(s) >= min_stims and len(f) >= min_stims]

    if len(pairs) <= 1:
        return ([s for s, _ in pairs], [f for _, f in pairs])

    responses_by_stim = {}
    for stim_vec, fire_vec in pairs:
        for s, f in zip(stim_vec, fire_vec):
            responses_by_stim.setdefault(s, []).append(f)

    new_stim_list = sorted(responses_by_stim)
    new_response_list = [np.mean(responses_by_stim[s]) for s in new_stim_list]
    return (new_stim_list, new_response_list)

In [None]:
def current_density_correction(cell_df, size_col, current_col_list, remove_old=True):
    '''Divide current columns by cell size to obtain current densities.

    Every column of ``cell_df`` whose name contains one of the substrings in
    ``current_col_list`` gets a companion ``<col>_pApF`` column equal to the
    value divided by ``cell_df[size_col]`` for that cell. Where the division
    fails (e.g. a non-numeric entry), ``None`` is stored instead. The matched
    source columns are dropped when ``remove_old`` is true (the default).

    Returns a modified copy of ``cell_df``.
    '''
    cell_df_cd = cell_df.copy()
    # Expand substring patterns into concrete column names (deduplicated,
    # pattern order kept). Only columns present before the loop are matched.
    matched_cols = list(dict.fromkeys(
        c for pattern in current_col_list for c in cell_df.columns if pattern in c
    ))
    for cell in cell_df.index:
        size = cell_df.loc[cell, size_col]
        for col in matched_cols:
            new_col = col + '_pApF'
            try:
                cell_df_cd.at[cell, new_col] = cell_df_cd.at[cell, col] / size
            except Exception:
                cell_df_cd.at[cell, new_col] = None

    # `remove_old` used to be ignored; the sources were always dropped.
    if remove_old:
        cell_df_cd = cell_df_cd[[c for c in cell_df_cd.columns if c not in matched_cols]]

    return cell_df_cd

In [None]:
def stratify_cells(cell_df, strat_col, xl_file_name='stratified_data.xlsx'):
    '''Split cells by one column, write one Excel sheet per group, and download it.

    Parameters
    ----------
    cell_df : DataFrame
    strat_col : column label whose distinct values define the groups.
    xl_file_name : str
        Output workbook path (written with XlsxWriter, which must be installed).

    Returns
    -------
    dict
        Group value -> rows of ``cell_df`` with that value.
    '''
    # Distinct values in first-appearance order; set() order can differ
    # between runs for string values.
    types = list(dict.fromkeys(cell_df[strat_col]))

    new_dfs = {}
    # pandas >= 2.0 removed both the ``options=`` keyword and
    # ``ExcelWriter.save()``. XlsxWriter workbook options now go through
    # ``engine_kwargs``, and the context manager saves and closes the file once.
    engine_kwargs = {'options': {'strings_to_formulas': False, 'strings_to_urls': False}}
    with pd.ExcelWriter(xl_file_name, engine='xlsxwriter', engine_kwargs=engine_kwargs) as writer:
        for t in types:
            new_dfs[t] = cell_df[cell_df[strat_col] == t]
            new_dfs[t].to_excel(writer, sheet_name=str(t))
    files.download(xl_file_name)
    return new_dfs

In [None]:

def stratify_rec(cell_df, strat_col, prefix=''):
    '''Recursively split ``cell_df`` by each column in ``strat_col`` in turn.

    Returns ``cell_df`` itself when ``strat_col`` is empty. Otherwise it
    returns a dict keyed ``prefix + '_' + str(value)`` for each distinct
    value of the first column, in first-appearance order. Each entry holds
    the matching rows when that was the last column. Otherwise it holds the
    nested dict produced for the remaining columns, built with
    ``prefix=str(value)``. A single column name may be given as a plain string.
    '''
    if isinstance(strat_col, str):
        strat_col = [strat_col]
    strat_col = list(strat_col)
    # The old guard tested the length of the column *name* and returned an
    # undefined `df`; an empty list also crashed earlier on strat_col[0].
    if len(strat_col) == 0:
        return cell_df

    # Positional split. The old `c not in cur_col` was a substring test, so
    # e.g. 'Type' was silently dropped after 'Cell_Type'.
    cur_col, rem_cols = strat_col[0], strat_col[1:]

    strat_dfs = {}
    for t in dict.fromkeys(cell_df[cur_col]):
        df_name = prefix + '_' + str(t)
        subset = cell_df[cell_df[cur_col] == t]
        if rem_cols:
            strat_dfs[df_name] = stratify_rec(subset.copy(), rem_cols, prefix=str(t))
        else:
            strat_dfs[df_name] = subset
    return strat_dfs

In [None]:
def flatten_dict(my_dict, flat_dict=None):
    '''Flatten a nested dict, joining keys with '_'.

    Non-dict values are copied under their key. A dict value is flattened
    recursively, and each of its flattened entries is stored under
    ``key + '_' + sub_key``. Keys must be strings.

    Parameters
    ----------
    my_dict : dict
        Dict to flatten (not modified).
    flat_dict : dict, optional
        Accumulator the entries are written into; a fresh dict by default.

    Returns
    -------
    (dict, list)
        The flattened dict and its keys in insertion order.
    '''
    # A shared mutable default used to leak entries between calls.
    if flat_dict is None:
        flat_dict = {}
    for key, value in my_dict.items():
        if isinstance(value, dict):
            # Flatten into a fresh dict. Reusing the caller's accumulator
            # produced duplicate keys at three or more levels of nesting.
            sub_flat, _ = flatten_dict(value, {})
            for sub_key, sub_value in sub_flat.items():
                flat_dict[key + '_' + sub_key] = sub_value
        else:
            flat_dict[key] = value
    return flat_dict, list(flat_dict.keys())

In [None]:
def write_strat_dfs(strat_dfs, xl_file_name='stratified_data.xlsx'):
    '''Write each DataFrame in ``strat_dfs`` to its own sheet and download the workbook.

    Sheet names are ``str(key)``. '.xlsx' is appended to ``xl_file_name``
    unless it already ends with it. The workbook is written with XlsxWriter,
    which must be installed.

    NOTE(review): Excel limits sheet names to 31 characters; long stratified
    keys may be rejected by the writer. Confirm whether truncation is wanted.
    '''
    if not xl_file_name.endswith('.xlsx'):
        xl_file_name = xl_file_name + '.xlsx'
    # pandas >= 2.0 removed both the ``options=`` keyword and
    # ``ExcelWriter.save()``. XlsxWriter workbook options now go through
    # ``engine_kwargs``, and the context manager saves and closes the file once.
    engine_kwargs = {'options': {'strings_to_formulas': False, 'strings_to_urls': False}}
    with pd.ExcelWriter(xl_file_name, engine='xlsxwriter', engine_kwargs=engine_kwargs) as writer:
        for k, cur_df in strat_dfs.items():
            cur_df.to_excel(writer, sheet_name=str(k))
    files.download(xl_file_name)
    return None