<a href="https://colab.research.google.com/github/dtabuena/Patch_Ephys/blob/main/Ephys_wrapper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def ephys_wrapper_local(dataset, VC_prot, IC_prot, strat_cols=None, verbose=False,
                        spike_args=None, manual_exclusions=None, age_bin_dict=None):
    """Run the single-dataset pipeline end to end and save stratified results locally.

    Stages: catalogue ABF files, drop wrong-clamp recordings, write a recording
    look-up CSV, run the per-protocol analyzers, group recordings into cells,
    consolidate / simplify / flatten per-cell values, convert currents to
    current density, drop bookkeeping columns, add age bins, stratify, and
    write one sheet per stratum with ``write_strat_dfs_local``.

    Parameters
    ----------
    dataset : dict
        Must provide 'data_name', 'data_source' and 'file_naming_scheme'.
    VC_prot, IC_prot
        Forwarded to ``purge_wrong_clamp``.
    strat_cols : list of str, optional
        Columns joined into the stratification key; defaults to ['Cell_Type'].
    verbose : bool
        Forwarded to ``analysis_iterator``.
    spike_args, manual_exclusions
        Accepted for backward compatibility; not used by this function.
    age_bin_dict : dict, optional
        Maps the 'age' value to an age-bin label; a built-in month map is used
        when None.

    Returns
    -------
    dict
        The product of every stage, keyed by stage name ('abf_recordings_df',
        'protocol_set', 'prot_lut', 'problem_recs', 'cell_df', 'cell_df_con',
        'cell_df_nd', 'cell_df_csv', 'cell_df_csv_abrg', 'strat_df_dict').
    """
    if strat_cols is None:
        strat_cols = ['Cell_Type']
    results = {}

    # Unpack
    data_name = dataset['data_name']
    data_source = dataset['data_source']
    file_naming_scheme = dataset['file_naming_scheme']

    # Gather and catalogue source data
    abf_recordings_df, protocol_set = catalogue_recs(data_source, file_naming_scheme)
    results['abf_recordings_df'] = abf_recordings_df
    results['protocol_set'] = protocol_set

    abf_recordings_df, _ = purge_wrong_clamp(abf_recordings_df, VC_prot, IC_prot)
    results['abf_recordings_df'] = abf_recordings_df

    csv_name = cell_prot_lut(abf_recordings_df, protocol_set,
                             csv_name=data_name + '_Recording_LookUp')
    results['prot_lut'] = csv_name

    # Set internal analysis params
    func_dict, arg_dict = init_func_arg_dicts()

    # Analyze dataset
    abf_recordings_df, problem_recs = analysis_iterator(abf_recordings_df, func_dict,
                                                        arg_dict, verbose=verbose)
    print('problem_recs')
    for rec in problem_recs:
        print('     ' + rec)
    results['problem_recs'] = problem_recs
    results['abf_recordings_df'] = abf_recordings_df

    # Sort recordings into cells
    cell_df = cell_sorting(abf_recordings_df)
    results['cell_df'] = cell_df

    # Consolidate to one row per cell
    list_types = ['Recording_name', 'protocol', 'abf_timestamp', 'channelList']
    any_types = list(file_naming_scheme)
    cell_df_con = cell_consolidation_v2(cell_df, list_types, any_types)
    results['cell_df_con'] = cell_df_con

    # Simplify IV data (dict columns -> one column per key)
    cols_to_simplify = ['IV_Early', 'IV_Steady_State']
    cell_df_nd = simplify_dicts(cell_df_con, cols_to_simplify)
    results['cell_df_nd'] = cell_df_nd

    # Make Excel friendly (list columns -> one column per stimulus level)
    keys_and_data_cols = {
        'Stim_Levels_(pA)': ['Stim_Levels_(pA)', 'Spike_Counts'],
        'IV_Early_(V_stim)': ['IV_Early_(V_stim)', 'IV_Early_(I_peak)', 'IV_Steady_State_(I_mean)'],
    }
    cell_df_csv = csv_frinedly(cell_df_nd, keys_and_data_cols)
    results['cell_df_csv'] = cell_df_csv

    # Convert to current density
    size_col = 'Cmq_160.0'
    current_col_list = ['IV_Early_(I_peak)_', 'IV_Steady_State_(I_mean)_']
    cell_df_csv = current_density_correction(cell_df_csv, size_col, current_col_list)
    results['cell_df_csv'] = cell_df_csv

    # Abridge DataFrame.  NOTE: 'sum_delta' previously lacked a trailing comma,
    # which fused it with 'IV_Early_(range)' so neither column was excluded.
    abrg_exclusions = {
        'Recording_name', 'protocol', 'abf_timestamp', 'channelList',
        'Ra_10.0', 'Rm_10.0', 'tau_10.0', 'Cmq_10.0', 'Cmf_10.0',
        'Cmqf_10.0', 'Cmf_160.0', 'Cmqf_160.0', 'Cm_pc_160.0',
        'Gain_R2', 'Stim_Levels_(pA)', 'Spike_Counts', 'Gain_Vh', 'Vhold_spike',
        'Rin_Rsqr', 'Ramp_AP_thresh', 'Ramp_Vh', 'Ramp_Rheobase',
        'v_half', 'is_compensated', 'sum_delta',
        'IV_Early_(range)', 'IV_Early_(I_peak)', 'IV_Early_(I_mean)', 'IV_Early_(V_stim)',
        'IV_Steady_State_(range)', 'IV_Steady_State_(I_peak)',
        'IV_Steady_State_(I_mean)', 'IV_Steady_State_(V_stim)',
    }
    abrg_keep = [c for c in cell_df_csv.columns if c not in abrg_exclusions]
    cell_df_csv_abrg = cell_df_csv[abrg_keep]
    results['cell_df_csv_abrg'] = cell_df_csv_abrg

    # Add age bins
    if age_bin_dict is None:
        age_bin_dict = {6: '<=6mo',
                        7: '7-9mo',
                        8: '7-9mo',
                        9: '7-9mo',
                        17: '17-19mo',
                        18: '17-19mo',
                        19: '17-19mo'}
    cell_df_csv_abrg = convert_age_month_bins(cell_df_csv_abrg, age_bin_dict, age_key='age')

    # Stratify cells by type and write one sheet per stratum
    strat_df_dict = stratify_rec_v2(cell_df_csv_abrg, strat_cols)
    write_strat_dfs_local(strat_df_dict, data_name + '_results_stratified')
    results['strat_df_dict'] = strat_df_dict
    return results

In [None]:
def csv_frinedly(cell_df, keys_and_data_cols, remove_source=True):
    """Spread list-valued columns into one scalar column per stimulus level.

    For every label column ``label_col`` in *keys_and_data_cols* and every
    data column listed under it, the i-th element of a cell's data list is
    written to the column ``<data_col>_<label>``.  ``<label>`` is the i-th
    label value truncated to int and formatted with an explicit sign and
    zero padding to width 4 (e.g. ``+010``, ``-030``).  Cells whose label
    value is None are skipped.

    ``remove_source`` is currently unused; the source columns are kept.
    Returns a new DataFrame; *cell_df* itself is not modified.
    """
    expanded = cell_df.copy()
    idx_name = cell_df.index.name
    # Staging frame sharing the cell index; new values are collected here and
    # merged into ``expanded`` in one update at the end.
    staged = pd.DataFrame({idx_name: cell_df.index}).set_index(idx_name)

    for label_col, data_cols in keys_and_data_cols.items():
        for data_col in data_cols:
            for cell in expanded.index:
                labels = expanded.loc[cell, label_col]
                values = expanded.loc[cell, data_col]
                if labels is None:
                    continue
                for i, label in enumerate(labels):
                    suffix = format(int(label), "=+04.0f")
                    target = data_col + '_' + suffix
                    if target not in expanded.columns:
                        expanded[target] = None
                    staged.at[cell, target] = values[i]

    staged = staged.reindex(sorted(staged.columns), axis=1)
    expanded.update(staged)
    return expanded

In [2]:
def analysis_iterator(abf_recordings_df, func_dict, arg_dict, verbose=True):
    """Run the protocol-specific analyzer on every ABF file in the frame's index.

    For each file name in ``abf_recordings_df.index``: the file is opened with
    ``pyabf.ABF`` and skipped unless its protocol has an entry in both
    *func_dict* and *arg_dict* and ``command_match(abf)`` is truthy.
    Otherwise ``func_dict[prot](abf, *arg_dict[prot])`` is called and every
    key of the returned mapping is stored as a column (created with object
    dtype via ``init_col_object`` when new) on that file's row.

    Exceptions raised by an analyzer are not caught and abort the loop.
    Failures while storing its results are reported and the file is listed
    in ``problem_recs``.

    Returns
    -------
    (DataFrame, list)
        The updated frame and the file names whose results could not be stored.
    """
    problem_recs = []

    for file_name in tqdm(abf_recordings_df.index):
        abf = pyabf.ABF(file_name)
        prot_name = abf.protocol
        if verbose:
            print('\n', '     ', file_name)
            print('     ', prot_name)

        # Only analyse protocols with both an analyzer and its arguments registered.
        if prot_name not in func_dict or prot_name not in arg_dict:
            continue
        if not command_match(abf):
            continue

        analyzer_func = func_dict[prot_name]
        # Accept the extra args as any sequence (list or tuple).
        results = analyzer_func(abf, *arg_dict[prot_name])

        try:
            for key in results.keys():
                if key not in abf_recordings_df.columns:
                    abf_recordings_df = init_col_object(abf_recordings_df, key)
                abf_recordings_df.at[file_name, key] = results[key]
        except Exception as err:
            # Narrowed from a bare ``except:`` so Ctrl-C still stops the run.
            print('\n', 'error on: ', file_name)
            print('recording results failed')
            print('     ', repr(err))
            problem_recs.append(file_name)

    return abf_recordings_df, problem_recs

def init_col_object(df, name):
    """Create (or reset) column *name* on *df*, filled with None, as object dtype.

    Object dtype lets callers such as ``analysis_iterator`` later store
    non-scalar results in single cells.  *df* is modified in place and also
    returned for convenience.
    """
    df[name] = [None] * len(df)
    df[name] = df[name].astype(object)
    return df

IndentationError: expected an indented block after 'if' statement on line 18 (ipython-input-2-2888019125.py, line 20)

In [None]:
def cell_sorting(abf_recordings_df):
    """Group recordings by cell: one row per unique 'cell_id'.

    Every column whose name does not contain 'cell_id' is carried over.  Each
    cell of the result holds a list of that column's values for all
    recordings of the cell, in the original row order.  Rows are sorted by
    cell id.
    """
    unique_cells = sorted(set(abf_recordings_df['cell_id']))
    transfer_cols = [c for c in abf_recordings_df.columns if 'cell_id' not in c]
    cell_df = pd.DataFrame(index=list(unique_cells), columns=transfer_cols)

    cell_ids = abf_recordings_df['cell_id']
    for cell in cell_df.index:
        # Exact match.  The former substring test (``cell in r``) also pulled
        # the recordings of e.g. 'c10', 'c11' into cell 'c1'.
        matched = abf_recordings_df[(cell_ids == cell).to_numpy()]
        for col in transfer_cols:
            cell_df.at[cell, col] = list(matched[col])
    return cell_df

In [None]:
def mean_loose(x):
    """Collapse a list of per-recording values into one value.

    - empty: NaN.  The original ``if``/``if`` chain made its ``[]`` branch
      dead, so empty input already produced NaN via ``np.nanmean([])``; this
      is now explicit and no longer emits "Mean of empty slice" warnings.
    - one element: that element, unchanged (no numeric coercion).
    - otherwise: NaN-ignoring mean along axis 0.
    """
    if len(x) == 0:
        return np.nan
    if len(x) == 1:
        return x[0]
    return np.nanmean(x, 0)

In [None]:
def cell_consolidation(cell_df, list_types, any_types, average_types=True):
    """Collapse each cell's per-recording lists into one value per cell.

    Parameters
    ----------
    cell_df : DataFrame
        One row per cell whose entries are per-recording lists (as built by
        ``cell_sorting``).
    list_types : list of str
        Columns left as lists.
    any_types : list of str
        Columns replaced by their first element (identical across recordings).
    average_types : True, iterable of str, or False/None
        True (default) averages every column not in *list_types*,
        *any_types* or the explicitly handled IV/gain columns.  An iterable
        averages exactly those columns.  False/None skips averaging; it used
        to crash.  Averaging drops None entries and applies ``mean_loose``.
        A value that cannot be averaged is left unchanged.

    Returns
    -------
    DataFrame
        A consolidated copy:
        - 'IV_Early' / 'IV_Steady_State' are reduced with
          ``consolidate_iv_recs`` and always stored as a list.  A failed
          consolidation becomes [None] for missing data, else ['ERROR'].
        - 'Stim_Levels_(pA)' / 'Spike_Counts' hold the curve merged by
          ``consolidate_gain_recs``.
    """
    cell_df_con = cell_df.copy()
    explicit_cols = ['IV_Early', 'IV_Steady_State', 'Stim_Levels_(pA)', 'Spike_Counts']

    if average_types is True:
        average_types = [c for c in cell_df_con.columns
                         if c not in any_types and c not in list_types
                         and c not in explicit_cols]
    elif average_types is None or average_types is False:
        average_types = []
    else:
        # An explicit list used to be silently replaced by the automatic selection.
        average_types = list(average_types)

    for cell in cell_df_con.index:
        # list_types need no work: they stay as per-recording lists.
        for col in any_types:
            # Same for every recording of the cell, so the first is representative.
            cell_df_con.at[cell, col] = cell_df_con.at[cell, col][0]

        for col in average_types:
            multi_vals = cell_df_con.loc[cell, col]
            try:
                multi_vals = [v for v in multi_vals if v is not None]
                cell_df_con.at[cell, col] = mean_loose(multi_vals)
            except Exception:
                pass  # not averageable: keep the original value

    # Explicitly defined consolidations
    for col in ['IV_Early', 'IV_Steady_State']:
        for cell in cell_df_con.index:
            multi_vals = None
            try:
                multi_vals = cell_df_con.loc[cell, col]
                multi_vals = consolidate_iv_recs(multi_vals)
            except Exception:
                # np.isnan raises on non-numeric data (lists of dicts, strings);
                # such values are errors, not missing data.
                try:
                    missing = multi_vals is None or bool(np.all(np.isnan(multi_vals)))
                except (TypeError, ValueError):
                    missing = False
                multi_vals = None if missing else 'ERROR'

            if not isinstance(multi_vals, list):
                multi_vals = [multi_vals]
            cell_df_con.at[cell, col] = multi_vals

    for cell in cell_df_con.index:
        new_stim, new_firing = consolidate_gain_recs(
            (cell_df_con.loc[cell, 'Stim_Levels_(pA)'], cell_df_con.loc[cell, 'Spike_Counts']))
        # A single recording comes back as a list holding one list: unwrap it.
        if len(new_stim) > 0 and isinstance(new_stim[0], list):
            new_stim = new_stim[0]
        if len(new_firing) > 0 and isinstance(new_firing[0], list):
            new_firing = new_firing[0]
        cell_df_con.at[cell, 'Stim_Levels_(pA)'] = new_stim
        cell_df_con.at[cell, 'Spike_Counts'] = new_firing

    return cell_df_con


In [None]:
def mean_loose(x):
    """Collapse a list of per-recording values into one value.

    - empty: NaN.  The original ``if``/``if`` chain made its ``[]`` branch
      dead, so empty input already produced NaN via ``np.nanmean([])``; this
      is now explicit and no longer emits "Mean of empty slice" warnings.
    - one element: that element, unchanged (no numeric coercion).
    - otherwise: NaN-ignoring mean along axis 0.
    """
    if len(x) == 0:
        return np.nan
    if len(x) == 1:
        return x[0]
    return np.nanmean(x, 0)

In [None]:
def cell_consolidation_v2(cell_df, list_types, any_types, average_types=True):
    """Collapse each cell's per-recording lists into one value per cell.

    Membrane-fit columns are first reduced to the values of the single best
    recording (``consolidate_membrane_fit``) and are never averaged.

    Parameters
    ----------
    cell_df : DataFrame
        One row per cell whose entries are per-recording lists (as built by
        ``cell_sorting``).
    list_types : list of str
        Columns left as lists.
    any_types : list of str
        Columns replaced by their first element (identical across recordings).
    average_types : True, iterable of str, or False/None
        True (default) averages every column not in *list_types*,
        *any_types*, the IV/gain columns or the membrane-fit columns.  An
        iterable averages exactly those columns.  False/None skips averaging;
        it used to crash.  Averaging drops None entries and applies
        ``mean_loose``.  A value that cannot be averaged is left unchanged.

    Returns
    -------
    DataFrame
        A consolidated copy:
        - 'IV_Early' / 'IV_Steady_State' are reduced with
          ``consolidate_iv_recs`` and always stored as a list.  A failed
          consolidation becomes [None] for missing data, else ['ERROR'].
        - 'Stim_Levels_(pA)' / 'Spike_Counts' hold the curve merged by
          ``consolidate_gain_recs``.

    NOTE: this notebook defines ``cell_consolidation_v2`` more than once; the
    last definition executed is the one in effect.
    """
    cell_df_con = cell_df.copy()
    explicit_cols = ['IV_Early', 'IV_Steady_State', 'Stim_Levels_(pA)', 'Spike_Counts']
    mem_fit_columns = ['Ra_10.0', 'Rm_10.0', 'tau_10.0', 'Cmq_10.0', 'Cmf_10.0',
                       'Cmqf_10.0', 'Cm_pc_10.0', 'Ra_160.0', 'Rm_160.0', 'tau_160.0',
                       'Cmq_160.0', 'Cmf_160.0', 'Cmqf_160.0', 'Cm_pc_160.0', ]

    if average_types is True:
        average_types = [c for c in cell_df_con.columns
                         if c not in any_types and c not in list_types
                         and c not in explicit_cols and c not in mem_fit_columns]
    elif average_types is None or average_types is False:
        average_types = []
    else:
        # An explicit list used to be silently replaced by the automatic selection.
        average_types = list(average_types)

    # Keep only the best recording's membrane-test values.
    for cell in cell_df_con.index:
        cell_df_con.loc[cell] = consolidate_membrane_fit(cell_df_con.loc[cell])

    for cell in cell_df_con.index:
        # list_types need no work: they stay as per-recording lists.
        for col in any_types:
            # Same for every recording of the cell, so the first is representative.
            cell_df_con.at[cell, col] = cell_df_con.at[cell, col][0]

        for col in average_types:
            multi_vals = cell_df_con.loc[cell, col]
            try:
                multi_vals = [v for v in multi_vals if v is not None]
                cell_df_con.at[cell, col] = mean_loose(multi_vals)
            except Exception:
                pass  # not averageable: keep the original value

    # Explicitly defined consolidations
    for col in ['IV_Early', 'IV_Steady_State']:
        for cell in cell_df_con.index:
            multi_vals = None
            try:
                multi_vals = cell_df_con.loc[cell, col]
                multi_vals = consolidate_iv_recs(multi_vals)
            except Exception:
                # np.isnan raises on non-numeric data (lists of dicts, strings);
                # such values are errors, not missing data.
                try:
                    missing = multi_vals is None or bool(np.all(np.isnan(multi_vals)))
                except (TypeError, ValueError):
                    missing = False
                multi_vals = None if missing else 'ERROR'

            if not isinstance(multi_vals, list):
                multi_vals = [multi_vals]
            cell_df_con.at[cell, col] = multi_vals

    for cell in cell_df_con.index:
        new_stim, new_firing = consolidate_gain_recs(
            (cell_df_con.loc[cell, 'Stim_Levels_(pA)'], cell_df_con.loc[cell, 'Spike_Counts']))
        # A single recording comes back as a list holding one list: unwrap it.
        if len(new_stim) > 0 and isinstance(new_stim[0], list):
            new_stim = new_stim[0]
        if len(new_firing) > 0 and isinstance(new_firing[0], list):
            new_firing = new_firing[0]
        cell_df_con.at[cell, 'Stim_Levels_(pA)'] = new_stim
        cell_df_con.at[cell, 'Spike_Counts'] = new_firing

    return cell_df_con

def cell_consolidation_v2(cell_df, list_types, any_types, average_types=True):
    """Collapse each cell's per-recording lists into one value per cell.

    Membrane-fit columns are first reduced to the values of the single best
    recording (``consolidate_membrane_fit``) and are never averaged.

    Parameters
    ----------
    cell_df : DataFrame
        One row per cell whose entries are per-recording lists (as built by
        ``cell_sorting``).
    list_types : list of str
        Columns left as lists.
    any_types : list of str
        Columns replaced by their first element (identical across recordings).
    average_types : True, iterable of str, or False/None
        True (default) averages every column not in *list_types*,
        *any_types*, the IV/gain columns or the membrane-fit columns.  An
        iterable averages exactly those columns.  False/None skips averaging;
        it used to crash.  Averaging drops None entries and applies
        ``np.nanmean(..., 0)``.  A value that cannot be averaged is left
        unchanged.

    Returns
    -------
    DataFrame
        A consolidated copy:
        - 'IV_Early' / 'IV_Steady_State' are reduced with
          ``consolidate_iv_recs`` and always stored as a list.  A failed
          consolidation becomes [None] for missing data, else ['ERROR'].
        - 'Stim_Levels_(pA)' / 'Spike_Counts' hold the curve merged by
          ``consolidate_gain_recs``.

    NOTE: this notebook defines ``cell_consolidation_v2`` more than once; the
    last definition executed is the one in effect.
    """
    cell_df_con = cell_df.copy()
    explicit_cols = ['IV_Early', 'IV_Steady_State', 'Stim_Levels_(pA)', 'Spike_Counts']
    mem_fit_columns = ['Ra_10.0', 'Rm_10.0', 'tau_10.0', 'Cmq_10.0', 'Cmf_10.0',
                       'Cmqf_10.0', 'Cm_pc_10.0', 'Ra_160.0', 'Rm_160.0', 'tau_160.0',
                       'Cmq_160.0', 'Cmf_160.0', 'Cmqf_160.0', 'Cm_pc_160.0', ]

    if average_types is True:
        average_types = [c for c in cell_df_con.columns
                         if c not in any_types and c not in list_types
                         and c not in explicit_cols and c not in mem_fit_columns]
    elif average_types is None or average_types is False:
        average_types = []
    else:
        # An explicit list used to be silently replaced by the automatic selection.
        average_types = list(average_types)

    # Keep only the best recording's membrane-test values.
    for cell in cell_df_con.index:
        cell_df_con.loc[cell] = consolidate_membrane_fit(cell_df_con.loc[cell])

    for cell in cell_df_con.index:
        # list_types need no work: they stay as per-recording lists.
        for col in any_types:
            # Same for every recording of the cell, so the first is representative.
            cell_df_con.at[cell, col] = cell_df_con.at[cell, col][0]

        for col in average_types:
            multi_vals = cell_df_con.loc[cell, col]
            try:
                multi_vals = [v for v in multi_vals if v is not None]
                cell_df_con.at[cell, col] = np.nanmean(multi_vals, 0)
            except Exception:
                pass  # not averageable: keep the original value

    # Explicitly defined consolidations
    for col in ['IV_Early', 'IV_Steady_State']:
        for cell in cell_df_con.index:
            multi_vals = None
            try:
                multi_vals = cell_df_con.loc[cell, col]
                multi_vals = consolidate_iv_recs(multi_vals)
            except Exception:
                # np.isnan raises on non-numeric data (lists of dicts, strings);
                # such values are errors, not missing data.
                try:
                    missing = multi_vals is None or bool(np.all(np.isnan(multi_vals)))
                except (TypeError, ValueError):
                    missing = False
                multi_vals = None if missing else 'ERROR'

            if not isinstance(multi_vals, list):
                multi_vals = [multi_vals]
            cell_df_con.at[cell, col] = multi_vals

    for cell in cell_df_con.index:
        new_stim, new_firing = consolidate_gain_recs(
            (cell_df_con.loc[cell, 'Stim_Levels_(pA)'], cell_df_con.loc[cell, 'Spike_Counts']))
        # A single recording comes back as a list holding one list: unwrap it.
        if len(new_stim) > 0 and isinstance(new_stim[0], list):
            new_stim = new_stim[0]
        if len(new_firing) > 0 and isinstance(new_firing[0], list):
            new_firing = new_firing[0]
        cell_df_con.at[cell, 'Stim_Levels_(pA)'] = new_stim
        cell_df_con.at[cell, 'Spike_Counts'] = new_firing

    return cell_df_con

def cell_consolidation_v2(cell_df, list_types, any_types, average_types=True):
    """Collapse each cell's per-recording lists into one value per cell.

    Membrane-fit columns are first reduced to the values of the single best
    recording (``consolidate_membrane_fit``) and are never averaged.

    Parameters
    ----------
    cell_df : DataFrame
        One row per cell whose entries are per-recording lists (as built by
        ``cell_sorting``).
    list_types : list of str
        Columns left as lists.
    any_types : list of str
        Columns replaced by their first element (identical across recordings).
    average_types : True, iterable of str, or False/None
        True (default) averages every column not in *list_types*,
        *any_types*, the IV/gain columns or the membrane-fit columns.  An
        iterable averages exactly those columns.  False/None skips averaging;
        it used to crash.  Averaging drops None entries and applies
        ``np.nanmean(..., 0)``.  A value that cannot be averaged is left
        unchanged.

    Returns
    -------
    DataFrame
        A consolidated copy:
        - 'IV_Early' / 'IV_Steady_State' are reduced with
          ``consolidate_iv_recs`` and always stored as a list.  A failed
          consolidation becomes [None] for missing data, else ['ERROR'].
        - 'Stim_Levels_(pA)' / 'Spike_Counts' hold the curve merged by
          ``consolidate_gain_recs``.

    NOTE: this notebook defines ``cell_consolidation_v2`` more than once; the
    last definition executed is the one in effect.
    """
    cell_df_con = cell_df.copy()
    explicit_cols = ['IV_Early', 'IV_Steady_State', 'Stim_Levels_(pA)', 'Spike_Counts']
    mem_fit_columns = ['Ra_10.0', 'Rm_10.0', 'tau_10.0', 'Cmq_10.0', 'Cmf_10.0',
                       'Cmqf_10.0', 'Cm_pc_10.0', 'Ra_160.0', 'Rm_160.0', 'tau_160.0',
                       'Cmq_160.0', 'Cmf_160.0', 'Cmqf_160.0', 'Cm_pc_160.0', ]

    if average_types is True:
        average_types = [c for c in cell_df_con.columns
                         if c not in any_types and c not in list_types
                         and c not in explicit_cols and c not in mem_fit_columns]
    elif average_types is None or average_types is False:
        average_types = []
    else:
        # An explicit list used to be silently replaced by the automatic selection.
        average_types = list(average_types)

    # Keep only the best recording's membrane-test values.
    for cell in cell_df_con.index:
        cell_df_con.loc[cell] = consolidate_membrane_fit(cell_df_con.loc[cell])

    for cell in cell_df_con.index:
        # list_types need no work: they stay as per-recording lists.
        for col in any_types:
            # Same for every recording of the cell, so the first is representative.
            cell_df_con.at[cell, col] = cell_df_con.at[cell, col][0]

        for col in average_types:
            multi_vals = cell_df_con.loc[cell, col]
            try:
                multi_vals = [v for v in multi_vals if v is not None]
                cell_df_con.at[cell, col] = np.nanmean(multi_vals, 0)
            except Exception:
                pass  # not averageable: keep the original value

    # Explicitly defined consolidations
    for col in ['IV_Early', 'IV_Steady_State']:
        for cell in cell_df_con.index:
            multi_vals = None
            try:
                multi_vals = cell_df_con.loc[cell, col]
                multi_vals = consolidate_iv_recs(multi_vals)
            except Exception:
                # np.isnan raises on non-numeric data (lists of dicts, strings);
                # such values are errors, not missing data.
                try:
                    missing = multi_vals is None or bool(np.all(np.isnan(multi_vals)))
                except (TypeError, ValueError):
                    missing = False
                multi_vals = None if missing else 'ERROR'

            if not isinstance(multi_vals, list):
                multi_vals = [multi_vals]
            cell_df_con.at[cell, col] = multi_vals

    for cell in cell_df_con.index:
        new_stim, new_firing = consolidate_gain_recs(
            (cell_df_con.loc[cell, 'Stim_Levels_(pA)'], cell_df_con.loc[cell, 'Spike_Counts']))
        # A single recording comes back as a list holding one list: unwrap it.
        if len(new_stim) > 0 and isinstance(new_stim[0], list):
            new_stim = new_stim[0]
        if len(new_firing) > 0 and isinstance(new_firing[0], list):
            new_firing = new_firing[0]
        cell_df_con.at[cell, 'Stim_Levels_(pA)'] = new_stim
        cell_df_con.at[cell, 'Spike_Counts'] = new_firing

    return cell_df_con

def consolidate_membrane_fit(cell):
    """Keep the membrane-test values of a cell's single best recording.

    *cell* is one row of the per-cell frame.  Each membrane-fit column holds
    a list with one entry per recording.

    Ranking:
    - Recordings are ranked by 'Ra_160.0' (lower is better) and by
      'Cmq_160.0' (higher is better).
    - The two ranks are summed; the recording with the highest total wins.
    - Missing values (None or NaN) are given the worst possible value so they
      never win.  Previously NaN sorted last in ``argsort`` and was
      therefore ranked best.

    Every membrane-fit column is then replaced by the winning recording's
    value.  Returns a modified copy of *cell*.
    """
    cell_con = cell.copy()
    columns = ['Ra_10.0',
               'Rm_10.0',
               'tau_10.0',
               'Cmq_10.0',
               'Cmf_10.0',
               'Cmqf_10.0',
               'Cm_pc_10.0',
               'Ra_160.0',
               'Rm_160.0',
               'tau_160.0',
               'Cmq_160.0',
               'Cmf_160.0',
               'Cmqf_160.0',
               'Cm_pc_160.0', ]

    def _missing(value):
        """True for None or a NaN scalar."""
        if value is None:
            return True
        try:
            return bool(np.isnan(value))
        except (TypeError, ValueError):
            return False

    ra = [np.inf if _missing(v) else v for v in cell_con['Ra_160.0']]
    cm = [-1 if _missing(v) else v for v in cell_con['Cmq_160.0']]

    ra_rank = np.argsort(np.argsort(-1 * np.array(ra)))  # lowest Ra -> highest rank
    cm_rank = np.argsort(np.argsort(cm))                 # highest Cm -> highest rank
    rank = cm_rank + ra_rank

    best_index = np.argsort(-rank)[0]
    for c in columns:
        cell_con.at[c] = cell_con.at[c][best_index]

    return cell_con

In [None]:
def consolidate_iv_recs(multi_vals, min_len=5):
    """Pick one IV result from a cell's list of per-recording IV dicts.

    None entries are dropped first.  Then:
    - zero or one recording: the filtered list itself is returned;
    - several recordings: the first whose 'V_stim' has more than *min_len*
      points is returned (a single dict).  If none is long enough, the one
      with the most points is returned; this used to raise IndexError.

    Raises KeyError/TypeError when an entry lacks 'V_stim', which callers
    treat as a failed consolidation.
    """
    recs = [v for v in multi_vals if v is not None]
    v_stim = [rec['V_stim'] for rec in recs]  # also validates that entries are IV dicts
    if len(v_stim) <= 1:
        return recs

    rec_lengths = [len(v) for v in v_stim]
    for rec, n_points in zip(recs, rec_lengths):
        if n_points > min_len:
            return rec
    return recs[int(np.argmax(rec_lengths))]

In [None]:

def simplify_dicts(cell_df, cols_to_simplify, remove_source=True):
    """Expand columns holding lists of dicts into one column per dict key.

    Processing, for each column ``col`` in *cols_to_simplify* and each cell:
    - Entries that are not dicts (None, 'ERROR' markers, …) are ignored.  A
      bare dict is treated as a one-element list, and scalar values (e.g.
      NaN) are skipped.
    - The keys come from the first remaining dict.  Each key ``k`` becomes
      column ``col_(k)`` (object dtype), holding that key's values across
      the cell's dicts.  A single value is unwrapped.

    When *remove_source* is true (default) ``col`` is dropped afterwards; the
    flag used to be ignored.  Returns a new DataFrame.
    """
    cell_df_new = cell_df.copy()
    for col in cols_to_simplify:
        for cell in cell_df_new.index:
            entries = cell_df_new.loc[cell, col]
            if isinstance(entries, dict):
                entries = [entries]
            try:
                list_of_dicts = [d for d in entries if isinstance(d, dict)]
            except TypeError:
                continue  # scalar such as NaN: nothing to expand
            if not list_of_dicts:
                continue
            for k in list(list_of_dicts[0].keys()):
                vals_of_key = [d[k] for d in list_of_dicts]
                if len(vals_of_key) == 1:
                    vals_of_key = vals_of_key[0]
                new_col = col + '_(' + str(k) + ')'
                if new_col not in cell_df_new.columns:
                    cell_df_new[new_col] = None
                    cell_df_new[new_col] = cell_df_new[new_col].astype(object)
                cell_df_new.at[cell, new_col] = vals_of_key
        if remove_source:
            cell_df_new.drop(labels=col, axis=1, inplace=True)
    return cell_df_new

In [None]:
def consolidate_gain_recs(multi_val_pair, min_stims=5):
    """Merge a cell's current-step (gain) recordings into one stimulus/response curve.

    Parameters
    ----------
    multi_val_pair : (stims_per_rec, spikes_per_rec)
        Parallel sequences with one array-like (or None) per recording.
    min_stims : int
        With several recordings, only those having at least this many
        stimulus levels are merged.  If none qualifies, all are kept rather
        than discarding the cell's data.

    Returns
    -------
    (stims, responses)
        With at most one usable recording: lists of per-recording lists (a
        pair with a None member is dropped).  Otherwise: flat lists sorted by
        stimulus level, holding the mean response across recordings at each
        level.
    """
    raw_stims, raw_fires = multi_val_pair[0], multi_val_pair[1]
    # Keep stim/response pairs together so filtering cannot misalign them.
    pairs = [(np.asarray(s).tolist(), np.asarray(f).tolist())
             for s, f in zip(raw_stims, raw_fires)
             if s is not None and f is not None]

    if len(pairs) > 1:
        long_pairs = [p for p in pairs if len(p[0]) >= min_stims]
        if long_pairs:
            pairs = long_pairs

    mv_stim = [s for s, _ in pairs]
    mv_fire = [f for _, f in pairs]
    if len(pairs) <= 1:
        return (mv_stim, mv_fire)

    responses_by_stim = {}
    for stims, fires in pairs:
        for stim, fire in zip(stims, fires):
            responses_by_stim.setdefault(stim, []).append(fire)

    new_stim_list = sorted(responses_by_stim)
    new_response_list = [np.mean(responses_by_stim[s]) for s in new_stim_list]
    return (new_stim_list, new_response_list)

In [None]:
def current_density_correction(cell_df, size_col, current_col_list, remove_old=True):
    """Add current-density (pA/pF) columns by dividing currents by cell size.

    Every column of *cell_df* whose name contains any substring in
    *current_col_list* gets a companion column ``<col>_pApF``, equal to the
    value divided by ``cell_df[size_col]``.  If the division fails (missing
    or non-numeric data), None is stored.

    When *remove_old* is true (default) the original current columns are
    dropped; the flag used to be ignored.  Returns a new DataFrame.
    """
    cell_df_cd = cell_df.copy()
    # Expand substring patterns to concrete columns, keeping order, without duplicates.
    current_cols = list(dict.fromkeys(
        c for pattern in current_col_list for c in cell_df.columns if pattern in c))

    for cell in cell_df.index:
        size = cell_df.loc[cell, size_col]
        for col in current_cols:
            new_col = col + '_pApF'
            try:
                cell_df_cd.at[cell, new_col] = cell_df_cd.at[cell, col] / size
            except Exception:
                cell_df_cd.at[cell, new_col] = None

    if remove_old:
        cell_df_cd = cell_df_cd[[c for c in cell_df_cd.columns if c not in current_cols]]
    return cell_df_cd

In [None]:
def stratify_cells(cell_df, strat_col, xl_file_name='stratified_data.xlsx'):
    """Split *cell_df* by the values of *strat_col* and save one sheet per value.

    The workbook is written to *xl_file_name* and then downloaded with
    Colab's ``files.download``.

    Sheet names are made Excel-safe: '/' becomes '_per_' and names are cut
    to Excel's 31-character limit.

    Returns a dict mapping each raw value to its subset DataFrame.
    """
    types = list(set(cell_df[strat_col]))

    new_dfs = {}
    # Context manager guarantees the workbook is closed even if a sheet fails.
    with pd.ExcelWriter(xl_file_name) as writer:
        for t in types:
            new_dfs[t] = cell_df[cell_df[strat_col] == t]
            sheet_name = str(t).replace('/', '_per_')[:31]
            new_dfs[t].to_excel(writer, sheet_name=sheet_name)
    files.download(xl_file_name)
    return new_dfs

In [None]:
def stratify_rec_v2(cell_df, strat_cols):
    """Split *cell_df* into one DataFrame per analysis group.

    A cell's group key is its *strat_cols* values, each passed through
    ``str()`` so numeric columns such as ages work, joined with '_'.  For
    example, strat_cols ['Cell_Type', 'Marker'] give keys like 'A_GFP'.

    NOTE: a 'strat_keys' column holding these keys is added to *cell_df* in
    place, and every returned frame carries it.

    Returns a dict mapping key -> sub-DataFrame, ordered by sorted key so the
    sheet order written downstream is reproducible.  Set order depended on
    string hash randomization.
    """
    strat_keys = ['_'.join(str(cell_df.loc[ind, col]) for col in strat_cols)
                  for ind in cell_df.index]
    cell_df['strat_keys'] = strat_keys
    return {k: cell_df[cell_df['strat_keys'] == k] for k in sorted(set(strat_keys))}

In [None]:
def flatten_dict(my_dict, flat_dict=None, parent_key=None, sep='_'):
    """Flatten nested dicts into a single level, joining keys with *sep*.

    Example::

        {'A': {'x': df1, 'y': df2}, 'B': df3}
        -> {'A_x': df1, 'A_y': df2, 'B': df3}

    Top-level keys of non-dict values are kept as-is; nested keys are joined
    as strings.  Entries are added to *flat_dict* when one is supplied,
    otherwise to a fresh dict.  The former mutable default was shared
    between calls, and the old algorithm dropped the parent keys.

    Returns
    -------
    (dict, list)
        The flat dict and its keys.
    """
    if flat_dict is None:
        flat_dict = {}
    for k, v in my_dict.items():
        key = k if parent_key is None else f"{parent_key}{sep}{k}"
        if isinstance(v, dict):
            flatten_dict(v, flat_dict, key, sep)
        else:
            flat_dict[key] = v
    return flat_dict, list(flat_dict.keys())

In [None]:
def write_strat_dfs(strat_dfs, xl_file_name='stratified_data.xlsx'):
    """Write each DataFrame of *strat_dfs* to its own sheet and download the workbook.

    The download uses Colab's ``files.download``.  '.xlsx' is appended when
    *xl_file_name* does not already end with it.

    Sheet names are the dict keys with '/' replaced by '_per_' (Excel
    forbids '/'), cut to Excel's 31-character limit.

    NOTE: this notebook redefines ``write_strat_dfs`` later; the last
    definition executed is the one in effect.
    """
    if not xl_file_name.endswith('.xlsx'):
        xl_file_name = xl_file_name + '.xlsx'
    # Context manager guarantees the workbook is closed even if a sheet fails.
    with pd.ExcelWriter(xl_file_name) as writer:
        for k, cur_df in strat_dfs.items():
            cur_df.to_excel(writer, sheet_name=str(k).replace('/', '_per_')[:31])
    files.download(xl_file_name)
    return None

def write_strat_dfs_local(strat_dfs, xl_file_name='stratified_data.xlsx'):
    """Write each DataFrame of *strat_dfs* to its own sheet of a new local .xlsx file.

    File name:
    - '.xlsx' is appended if missing.
    - If the file already exists, '_C1', '_C2', … is inserted before the
      extension until an unused name is found, so existing files are never
      overwritten.

    Sheet names are made valid for xlsxwriter, which raises otherwise:
    - '/' becomes '_per_';
    - the other forbidden characters ``[]:*?\\`` become '_';
    - names are cut to 31 characters;
    - case-insensitive duplicates get a numeric suffix.

    Returns the path of the file actually written.
    """
    if not xl_file_name.endswith('.xlsx'):
        xl_file_name += '.xlsx'

    # Find a free file name; only the final '.xlsx' is replaced.
    stem = xl_file_name[:-len('.xlsx')]
    counter = 1
    new_file_name = xl_file_name
    while os.path.exists(new_file_name):
        new_file_name = f'{stem}_C{counter}.xlsx'
        counter += 1

    used_names = set()

    def _sheet_name(key):
        """Return a valid, unique Excel sheet name derived from *key*."""
        name = str(key).replace('/', '_per_')
        for ch in '[]:*?\\':
            name = name.replace(ch, '_')
        name = name[:31]
        base, n = name, 1
        while name.lower() in used_names:
            suffix = f'_{n}'
            name = base[:31 - len(suffix)] + suffix
            n += 1
        used_names.add(name.lower())
        return name

    # Save DataFrames to an Excel file with a unique name
    with pd.ExcelWriter(new_file_name, engine='xlsxwriter') as writer:
        for k, cur_df in strat_dfs.items():
            cur_df.to_excel(writer, sheet_name=_sheet_name(k))

    return new_file_name

In [None]:
def restrat(strat_df_dict, alt_strat_groups):
    """Regroup stratified results by measurement instead of by stratum.

    Each entry of *alt_strat_groups* is either a column name, or a list of
    column names whose first element names the output.  For each entry, one
    wide DataFrame is built with a column ``<col>_<stratum key>`` per
    stratum, using ``add_col``; strata are visited in sorted key order.

    Returns a dict mapping group name -> wide DataFrame.  An empty
    *strat_df_dict* now yields empty frames instead of a NameError.
    """
    sorted_keys = sorted(strat_df_dict.keys())  # invariant across groups
    alt_strat_dict = {}
    for group in alt_strat_groups:
        is_multi = isinstance(group, list)
        sub_groups = group if is_multi else [group]
        group_name = group[0] if is_multi else group
        new_df = pd.DataFrame()
        for k in sorted_keys:
            v = strat_df_dict[k]
            for sub_group in sub_groups:
                new_df = add_col(new_df, sub_group, k, v)
        alt_strat_dict[group_name] = new_df
    return alt_strat_dict


def add_col(new_df, g, k, v):
    """Append column *g* of stratum frame *v* to the wide frame *new_df* as '<g>_<k>'.

    The values are re-indexed 0..n-1 (the cell labels are dropped) so strata
    of different sizes line up row by row.  ``reset_index(drop=True)`` also
    works when the index is named; the old ``drop(labels='index')`` raised
    KeyError in that case.

    Padding:
    - *new_df* gets empty rows when *v* is longer;
    - a shorter column is NaN-padded.

    Returns the resulting frame.  When no padding is needed, *new_df* itself
    is modified.
    """
    new_col_name = g + '_' + k
    clean_ser = v[g].reset_index(drop=True)
    len_diff = len(clean_ser) - len(new_df)
    if len_diff > 0:
        blank_df = pd.DataFrame(index=range(len_diff), columns=new_df.columns)
        new_df = pd.concat([new_df, blank_df], ignore_index=True)
    new_df[new_col_name] = clean_ser
    return new_df



In [None]:
def restratify_results(results_dict, labels, alt_strat_groups):
    """Tag stratified frames with their key and regroup them by measurement.

    Side effect: every frame in *results_dict* is replaced by a copy that has
    a 'Strat_ID' column equal to its key; ``stratify_response_curve`` reads
    that column.  *labels* is accepted but unused.

    Each entry of *alt_strat_groups* is either a column name, or a list of
    column names whose first element names the output.  For each entry, one
    wide DataFrame is built with a column ``<col>_<stratum key>`` per
    stratum, in the dict's iteration order.

    Returns a dict mapping group name -> wide DataFrame.  An empty
    *results_dict* now yields empty frames instead of a NameError.
    """
    def add_col(new_df, g, k, v):
        """Append column *g* of *v* as '<g>_<k>', aligned by row position."""
        new_col_name = g + '_' + k
        # drop=True also works for a named index (the old drop('index') did not).
        clean_ser = v[g].reset_index(drop=True)
        len_diff = len(clean_ser) - len(new_df)
        if len_diff > 0:
            blank_df = pd.DataFrame(index=range(len_diff), columns=new_df.columns)
            new_df = pd.concat([new_df, blank_df], ignore_index=True)
        new_df[new_col_name] = clean_ser
        return new_df

    for k, v in results_dict.items():
        results_dict[k] = v.assign(Strat_ID=[k] * len(v.index))

    alt_strat_dict = {}
    for group in alt_strat_groups:
        is_multi = isinstance(group, list)
        sub_groups = group if is_multi else [group]
        group_name = group[0] if is_multi else group
        new_df = pd.DataFrame()
        for k, v in results_dict.items():
            for sub_group in sub_groups:
                new_df = add_col(new_df, sub_group, k, v)
        alt_strat_dict[group_name] = new_df
    return alt_strat_dict







In [None]:
def write_strat_dfs(strat_dfs, xl_file_name='stratified_data.xlsx'):
    """Write each DataFrame of *strat_dfs* to its own sheet and download the workbook.

    The download uses Colab's ``files.download``.  '.xlsx' is appended when
    *xl_file_name* does not already end with it.

    Sheet names are the dict keys (converted with ``str``) with '/' replaced
    by '_per_', cut to Excel's 31-character limit.  Each sheet name is
    printed as it is written.
    """
    if not xl_file_name.endswith('.xlsx'):
        xl_file_name = xl_file_name + '.xlsx'
    # Context manager guarantees the workbook is closed even if a sheet fails.
    with pd.ExcelWriter(xl_file_name) as writer:
        for k, cur_df in strat_dfs.items():
            sheet_name = str(k).replace('/', '_per_')[:31]
            print(sheet_name)
            cur_df.to_excel(writer, sheet_name=sheet_name)
    files.download(xl_file_name)
    return None


In [None]:
def stratify_response_curve(strat_df_dict, resp_curve_list, strat_list):
    """Assemble per-group response curves: one column per cell, one row per level.

    For each curve name in *resp_curve_list*, every row of every frame in
    *strat_df_dict* contributes its values from the columns whose name
    contains that curve name.

    - A row is assigned to the first entry of *strat_list* that is a
      substring of the row's 'Strat_ID'; no match raises IndexError.
    - Output keys are '<strat>_<curve>'.
    - Rows of each output frame are ordered by the first number parsed from
      the row label (e.g. the '+010' in 'IV_Early_(I_peak)_+010').
    """
    import re
    number_pattern = re.compile(r"[-+]?[.]?\d+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?")

    response_curve_data = {}
    for curve in resp_curve_list:
        per_strat = [pd.DataFrame() for _ in strat_list]
        for frame in strat_df_dict.values():
            curve_cols = [c for c in frame.columns if curve in c]  # same for every row
            for row in frame.index:
                row_values = frame.loc[row, curve_cols]
                strat_id = frame.loc[row, 'Strat_ID']
                slot = [i for i, name in enumerate(strat_list) if name in strat_id][0]
                per_strat[slot] = pd.concat([per_strat[slot], row_values], axis=1)
                response_curve_data[strat_list[slot] + '_' + curve] = per_strat[slot]

    for key, curve_df in response_curve_data.items():
        row_labels = curve_df.index
        levels = [float(number_pattern.findall(label)[0]) for label in row_labels]
        response_curve_data[key] = curve_df.reindex(row_labels[np.argsort(levels)])

    return response_curve_data

In [None]:
##### reorg Files
def strat_abfs_by_prot(new_dir, vm_local_dir, lut_df_loc):
    """Copy every ABF file under *vm_local_dir* into per-protocol folders, zip and download.

    Steps (Colab only):
    - Files are copied to '/content/<new_dir>/<protocol>/'; the protocol
      comes from ``abf_or_name(path).protocol``.  The target folder is
      recreated from scratch each call.
    - The folder is zipped to '<new_dir>.zip' in the working directory, which
      is '/content' on Colab, and downloaded with ``files.download``.
    - Files that cannot be read or copied are reported and skipped.

    *lut_df_loc* is read with ``pd.read_csv``, but the table is not otherwise
    used here.
    """
    pd.read_csv(lut_df_loc)  # result unused; kept so a bad path still fails early

    target_root = os.path.join('/content', new_dir)
    shutil.rmtree(target_root, ignore_errors=True)  # start from an empty folder
    os.makedirs(target_root)

    for subdir, _dirs, file_names in os.walk(vm_local_dir):
        for file in file_names:
            if not file.lower().endswith('.abf'):
                continue
            try:
                full_path = os.path.join(subdir, file)
                prot_name = abf_or_name(full_path).protocol
                prot_dir = os.path.join(target_root, prot_name)
                os.makedirs(prot_dir, exist_ok=True)
                shutil.copyfile(full_path, os.path.join(prot_dir, file))
            except Exception as err:
                print('Failed', file, err)

    shutil.make_archive(new_dir, 'zip', target_root)
    from google.colab import files
    files.download(new_dir + '.zip')

In [None]:
def final_qc(strat_df_dict,file_naming_scheme,qc_RR=.3,qc_AP_amp=50,qc_Rmp=-50,qc_Ra=50,manual_exclusions=None,exclusion_overide=None):
    '''Apply per-cell quality control to every stratified DataFrame.

    For each row (cell) the checks run in this order:
      * row in ``exclusion_overide``  -> kept untouched, reason "Exlusion Override";
      * row in ``manual_exclusions``  -> blanked, reason "Manual Exclusion";
      * Ra_160.0 / Rm_160.0 > qc_RR   -> blanked;
      * Ra_160.0 > qc_Ra              -> blanked;
      * ap_amplitutude < qc_AP_amp    -> blanked;
      * Rmp_mV > qc_Rmp               -> blanked.
    The first failing check is recorded and the rest are skipped. "Blanked"
    means every column not listed in ``file_naming_scheme`` is set to NaN.

    Returns ``(filtered_dict, fail_dict)``. ``filtered_dict`` maps the same
    keys to QC'd copies of the input frames (the inputs are not modified).
    ``fail_dict`` maps row label -> reason string; it is also printed.
    '''
    manual_exclusions = [] if manual_exclusions is None else manual_exclusions
    exclusion_overide = [] if exclusion_overide is None else exclusion_overide

    def _blank_row(df, row):
        # Keep the identifier columns, wipe every measurement.
        for col in df.columns:
            if col not in file_naming_scheme:
                df.at[row, col] = np.nan

    def _qc_failure(df, row):
        # Return the reason for the first failing check, or None if the row passes.
        # Scalars are read with .loc[row, col] so division stays NumPy-typed (x/0 -> inf).
        Ra = df.loc[row, 'Ra_160.0']
        Rm = df.loc[row, 'Rm_160.0']
        RR = Ra / Rm
        if RR > qc_RR:
            return f"Fail - RR = {RR} > {qc_RR}"
        if Ra > qc_Ra:
            return f"Fail - Ra = {Ra} > {qc_Ra}"
        if df.loc[row, 'ap_amplitutude'] < qc_AP_amp:
            return f"Fail - AP = {df.loc[row,'ap_amplitutude']} mV < {qc_AP_amp}mV"
        if df.loc[row, 'Rmp_mV'] > qc_Rmp:
            return f"Fail - Rmp = {df.loc[row,'Rmp_mV']} mV > {qc_Rmp} mV"
        return None

    fail_dict = {}
    filtered_dict = strat_df_dict.copy()
    for k in list(filtered_dict):
        data_df = filtered_dict[k].copy()
        for r in data_df.index:
            if r in exclusion_overide:
                fail_dict[r] = "Exlusion Override"
                continue
            if r in manual_exclusions:
                reason = "Manual Exclusion"
            else:
                reason = _qc_failure(data_df, r)
            if reason is not None:
                fail_dict[r] = reason
                _blank_row(data_df, r)
        filtered_dict[k] = data_df
    for k, v in fail_dict.items():
        print(k , v)
    return filtered_dict, fail_dict


In [None]:
##### alternate Stratification and QC params

def analysis_consolidation(result_dict,
                           manual_exclusions = [''],
                           exclusion_overide = [''],
                           single_val_strat_groups=['ap_amplitutude','Rmp_mV','Ra_160.0','Rm_160.0','Cm_pc_10.0','Cmq_160.0','Ra_160.0',
                                                    'Rm_160.0','Gain_(HzpA)','max_adapt%','adapt_thresh_90','Rheobase','AP_thresh_US',
                                                    'fast_after_hyperpol','Spike_latency_(ms)','Input_Resistance_MO'],
                           resp_curve_list = ['IV_Early_(V_stim)','IV_Early_(I_peak)','IV_Steady_State_(I_mean)','Stim_Levels_(pA)','Spike_Counts'],
                           file_naming_scheme = ['Rec_date', 'Genotype', 'Sex', 'Age', 'Slice_Num', 'Cell_num', 'Cell_Type'],
                           QC_param = None,
                           data_name = None,
                           ):
    '''Run final QC on ``result_dict['strat_df_dict']``, restratify, and write the results.

    Args:
        result_dict: must contain ``'strat_df_dict'``.
        manual_exclusions / exclusion_overide: row labels forwarded to ``final_qc``.
        single_val_strat_groups: measurements passed to ``restratify_results``.
        resp_curve_list: curve names passed to ``stratify_response_curve``.
        file_naming_scheme: identifier columns that QC never blanks.
        QC_param: optional thresholds. Missing keys default to
            qc_Rmp=-45, qc_AP_amp=40, qc_RR=.35, qc_Ra=65. The caller's dict
            is not modified.
        data_name: prefix for the written output. When omitted it falls back
            to the notebook-global ``dataset['data_name']`` (original behaviour).

    Returns a dict with keys filtered_dict, fail_dict, alt_strat_dict,
    response_curve_data, single_val_strat_groups and resp_curve_list.
    '''
    # Work on a copy: the old code filled defaults into the caller's dict
    # (and into a shared mutable default).
    qc = dict(QC_param) if QC_param else {}
    print(qc)
    qc.setdefault('qc_Rmp', -45)
    qc.setdefault('qc_AP_amp', 40)
    qc.setdefault('qc_RR', .35)
    qc.setdefault('qc_Ra', 65)
    print(qc)

    strat_df_dict = result_dict['strat_df_dict'].copy()
    # Keyword arguments: the previous positional call passed qc_Rmp into
    # qc_AP_amp and qc_AP_amp into qc_Rmp, silently disabling both checks.
    filtered_dict, fail_dict = final_qc(strat_df_dict, file_naming_scheme,
                                        qc_RR=qc['qc_RR'],
                                        qc_AP_amp=qc['qc_AP_amp'],
                                        qc_Rmp=qc['qc_Rmp'],
                                        qc_Ra=qc['qc_Ra'],
                                        manual_exclusions=manual_exclusions,
                                        exclusion_overide=exclusion_overide)
    alt_strat_dict = restratify_results(filtered_dict,file_naming_scheme,single_val_strat_groups)
    response_curve_data = stratify_response_curve(filtered_dict,resp_curve_list,strat_list=[''])
    alt_strat_dict.update(response_curve_data)
    if data_name is None:
        data_name = dataset['data_name']  # notebook-global fallback
    write_strat_dfs(alt_strat_dict, data_name+'_results_stratified_alternate')

    return {'filtered_dict':filtered_dict,
            'fail_dict':fail_dict,
            'alt_strat_dict':alt_strat_dict,
            'response_curve_data':response_curve_data,
            'single_val_strat_groups':single_val_strat_groups,
            'resp_curve_list':resp_curve_list}


In [None]:

### Make Some Bar Plots
def summary_plots(alt_strat_dict,single_val_strat_groups):
    '''Draw one bar + scatter figure per single-value measurement table.

    For each ``alt_strat_dict`` entry whose key is in ``single_val_strat_groups``,
    every column becomes one bar. The bar shows the mean of the column's finite
    values, with error bar std/sqrt(n), and the individual values are
    overlaid as black dots.
    Column names are expected as ``'<measure>__<stratum>'``. The stratum part
    labels the x ticks and the first column's measure part labels the y axis.
    Names without ``'__'`` are used whole. Tables with no columns are skipped.

    NOTE(review): the error bar uses the population SD (ddof=0); the usual
    SEM uses the sample SD (ddof=1). The original choice is kept here.
    '''
    def _strat_label(col):
        parts = str(col).split("__")
        return parts[1] if len(parts) > 1 else parts[0]

    for k, v in alt_strat_dict.items():
        if k not in single_val_strat_groups:
            continue
        categs = list(v.columns)
        if not categs:
            continue  # nothing to plot (previously crashed on categs[0])
        fig, ax = plt.subplots(1, 1, figsize=[.3 + .5*len(categs), 2])
        for ci, categ in enumerate(categs):
            c_vals = np.asarray(v[categ], dtype=float)
            c_vals = c_vals[np.isfinite(c_vals)]
            if c_vals.size:
                c_mean = c_vals.mean()
                c_sem = c_vals.std() / np.sqrt(c_vals.size)
            else:
                # Same NaN bar as before, without mean-of-empty RuntimeWarnings.
                c_mean = c_sem = np.nan
            ax.bar(ci, c_mean, yerr=c_sem)
            ax.scatter([ci]*c_vals.size, c_vals, color='k', marker='o')
        ax.set_xticks(range(len(categs)))
        ax.set_xticklabels([_strat_label(cat) for cat in categs], rotation=45)
        ax.set_ylabel(str(categs[0]).split("__")[0])
    return None

In [None]:
def day_to_bin(day,age_bin_dict):
    '''Map an age in days to an age-bin label via whole months.

    ``day`` may be an int, float or numeric string (e.g. ``150`` or ``'150'``).
    It is converted to months as ``days / 365 * 12`` and truncated to an int,
    and that month is looked up in ``age_bin_dict``.

    Returns ``(bin_label, month_int, month_float)``.
    Raises KeyError (naming the month) if the month has no entry in ``age_bin_dict``.
    '''
    days = int(float(day))
    month = days / 365 * 12
    month_flr = int(month)
    try:
        age_bin = age_bin_dict[month_flr]
    except KeyError:
        raise KeyError(f"No age bin defined for month {month_flr} (age {day} days); "
                       f"defined months: {list(age_bin_dict)}") from None
    return age_bin, month_flr, month

def convert_age_month_bins(cell_df,age_bin_dict,age_key='Age'):
    '''Fill an ``'Age_Bin'`` column of ``cell_df`` in place and return ``cell_df``.

    Ages are read from column ``age_key``. Every ``'P'`` is stripped, so
    ``'P150'`` is accepted, and plain numbers work too. Each age is then
    mapped with ``day_to_bin``.
    '''
    cell_df['Age_Bin'] = ""
    for cell in cell_df.index:
        # str() so numeric ages work; the old `.replace(...)*1` was a no-op repeat.
        age_day = str(cell_df.loc[cell, age_key]).replace('P', "")
        age_bin, _, _ = day_to_bin(age_day, age_bin_dict)
        cell_df.at[cell, 'Age_Bin'] = age_bin
    return cell_df


In [None]:
def substitute_gain_rheobase(wrapper_results,rheo_gain_func_dict=None,rheo_gain_arg_dict=None,strat_cols=None,age_bin_dict=None,file_naming_scheme=None,data_name=None):
    '''Re-run rheobase analysis on the gain protocol for cells missing a Rheobase,
    then rebuild every downstream table.

    Reads ``cell_df_csv_abrg``, ``prot_lut`` (a CSV path) and ``abf_recordings_df``
    from ``wrapper_results``. The dict is updated in place and returned with
    refreshed abf_recordings_df, cell_df, cell_df_con, cell_df_nd,
    cell_df_csv, cell_df_csv_abrg and strat_df_dict.

    Args:
        rheo_gain_func_dict / rheo_gain_arg_dict: analysis overrides for
            ``analysis_iterator``. Unless BOTH are given, defaults using
            ``rheobase_analyzer`` on 'IC - Gain - D10pA' replace them.
        strat_cols: columns to stratify by (default ``['Cell_Type']``).
        age_bin_dict: month -> age-bin label (default covers 0-9 and 17-19 months).
        file_naming_scheme / data_name: default to the notebook-global
            ``dataset`` entries, as before.
    '''
    if strat_cols is None:
        strat_cols = ['Cell_Type']
    results = wrapper_results  # updated in place

    rheo_df = wrapper_results['cell_df_csv_abrg']['Rheobase']
    # pd.isna also handles None/object values, where np.isnan would raise TypeError.
    missing_rheo = [index for index, value in rheo_df.items() if pd.isna(value)]
    print(missing_rheo)

    prot_lut_df = pd.read_csv(wrapper_results['prot_lut'],index_col=0)
    gain_files_to_use = prot_lut_df['IC - Gain - D10pA'].loc[missing_rheo].values

    abf_recordings_df = wrapper_results['abf_recordings_df']
    is_gain_rec = abf_recordings_df['Recording_name'].isin(gain_files_to_use)
    # Explicit copy: the re-analysis writes to it; results merge back via .update below.
    abf_recordings_df_subset = abf_recordings_df[is_gain_rec].copy()

    if rheo_gain_func_dict is None or rheo_gain_arg_dict is None:
        spike_args_rheo={'spike_thresh':10, 'high_dv_thresh': 30,'low_dv_thresh': -10,'window_ms': 2}
        rheo_gain_func_dict = {'IC - Gain - D10pA': rheobase_analyzer}
        rheo_gain_arg_dict = {'IC - Gain - D10pA': [spike_args_rheo, True, False, False]}

    abf_recordings_df_subset, _ = analysis_iterator(abf_recordings_df_subset,rheo_gain_func_dict,rheo_gain_arg_dict)

    abf_recordings_df.update(abf_recordings_df_subset)
    results['abf_recordings_df'] = abf_recordings_df

    if file_naming_scheme is None:
        file_naming_scheme = dataset['file_naming_scheme']  # notebook-global fallback

    '''Sort Cells'''
    cell_df = cell_sorting(abf_recordings_df)
    results['cell_df'] = cell_df

    '''Consolidate to Cells'''
    list_types = ['Recording_name','protocol','abf_timestamp', 'channelList']
    any_types = [] + file_naming_scheme
    cell_df_con = cell_consolidation_v2(cell_df,list_types,any_types)
    results['cell_df_con'] = cell_df_con

    '''Simplify IV Data'''
    cols_to_simplify = ['IV_Early', 'IV_Steady_State']
    cell_df_nd = simplify_dicts(cell_df_con,cols_to_simplify)
    results['cell_df_nd'] = cell_df_nd

    '''Make Excel Friendly'''
    keys_and_data_cols={'Stim_Levels_(pA)': ['Stim_Levels_(pA)', 'Spike_Counts' ],
                    'IV_Early_(V_stim)': ['IV_Early_(V_stim)', 'IV_Early_(I_peak)', 'IV_Steady_State_(I_mean)']}
    cell_df_csv = csv_frinedly(cell_df_nd,keys_and_data_cols)
    results['cell_df_csv'] = cell_df_csv

    ''' Convert to Current Density'''
    size_col = 'Cmq_160.0'
    current_col_list = ['IV_Early_(I_peak)_', 'IV_Steady_State_(I_mean)_']
    cell_df_csv = current_density_correction(cell_df_csv, size_col, current_col_list)
    results['cell_df_csv'] = cell_df_csv

    '''Abridge DataFrame'''
    # NOTE: a missing comma after 'sum_delta' used to fuse it with
    # 'IV_Early_(range)', so neither column was actually excluded.
    abrg_exclusions = ['Recording_name',
                    'protocol', 'abf_timestamp', 'channelList',  'Ra_10.0', 'Rm_10.0', 'tau_10.0', 'Cmq_10.0', 'Cmf_10.0',
                    'Cmqf_10.0',  'Cmf_160.0', 'Cmqf_160.0', 'Cm_pc_160.0',
                    'Gain_R2', 'Stim_Levels_(pA)', 'Spike_Counts',  'Gain_Vh',  'Vhold_spike',
                        'Rin_Rsqr',  'Ramp_AP_thresh', 'Ramp_Vh', 'Ramp_Rheobase',
                    'v_half','is_compensated','sum_delta',
                    'IV_Early_(range)', 'IV_Early_(I_peak)', 'IV_Early_(I_mean)', 'IV_Early_(V_stim)', 'IV_Steady_State_(range)',
                    'IV_Steady_State_(I_peak)', 'IV_Steady_State_(I_mean)', 'IV_Steady_State_(V_stim)', ]

    abrg_keep = [c for c in cell_df_csv.columns if c not in abrg_exclusions]
    cell_df_csv_abrg = cell_df_csv[abrg_keep]
    results['cell_df_csv_abrg'] = cell_df_csv_abrg

    """
    Add Age Bins
    """
    if age_bin_dict is None:
        # '<=6mo' must cover months 0-6, not only month 6 (younger animals used to KeyError).
        age_bin_dict = {**dict.fromkeys(range(7), '<=6mo'),
                        **dict.fromkeys((7, 8, 9), '7-9mo'),
                        **dict.fromkeys((17, 18, 19), '17-19mo')}
    cell_df_csv_abrg = convert_age_month_bins(cell_df_csv_abrg,age_bin_dict,age_key='Age')

    '''Stratify Cells By Type'''
    strat_df_dict = stratify_rec_v2(cell_df_csv_abrg,strat_cols)
    if data_name is None:
        data_name = dataset['data_name']  # notebook-global fallback
    write_strat_dfs_local(strat_df_dict, data_name+'_results_stratified')
    results['strat_df_dict'] = strat_df_dict
    return results
