<a href="https://colab.research.google.com/github/dtabuena/EphysLib/blob/main/Analyzers/rheobase_analyzer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
''' Requires:
    Simple_ABF_tools.ipynb
    Basic_Ephys.ipynb
'''

def rheobase_analyzer(file_name,
                        spike_args =  {'spike_thresh':20, 'high_dv_thresh': 25,'low_dv_thresh': -50,'window_ms': 2},
                        to_plot=False,
                        verbose=False,
                        single_spike=True,figopt={'type':'jpg','dpi':300}):
    '''Measure rheobase and first-spike AP parameters from a current-step recording.

    The rheobase sweep is the first sweep (index >= 1) that fires while the
    preceding sweep fired no spikes. With ``single_spike=True`` the firing
    sweep must contain exactly one spike; otherwise any spike count > 0
    qualifies.

    Args:
        file_name: recording path or an already-loaded recording object,
            resolved through ``abf_or_name``.
        spike_args: spike-detection settings forwarded to
            ``spikes_per_stim_LEGACY`` and ``single_ap_stats``. The dict is
            shallow-copied first, so the helpers cannot modify the caller's
            dict (or the shared default).
        to_plot: if True, forward plotting to the helpers and also plot every
            sweep into one overview figure saved under Saved_Figs/Rheobase/.
        verbose: forwarded to ``single_ap_stats``.
        single_spike: require a 0 -> exactly-1 spike transition (True) or a
            0 -> at-least-1 spike transition (False).
        figopt: ``{'type': <file extension>, 'dpi': <int>}`` for the saved
            overview figure.

    Returns:
        dict: empty when there are fewer than two sweeps or no qualifying
        transition; otherwise 'Rheobase', 'Vhold_spike', 'AP_thresh' plus
        every key returned by ``single_ap_stats``.
    '''
    results = {}  # default return when no rheobase can be determined

    # single_ap_stats overwrites spike_args['spike_thresh']; work on a private
    # copy so that write never leaks into the caller's dict or the default.
    spike_args = dict(spike_args)

    abf = abf_or_name(file_name)

    # A rheobase transition needs a non-firing sweep before the firing one.
    if len(abf.sweepList) < 2:
        return results

    is_base, is_stim = protocol_baseline_and_stim(abf)  # NOTE(review): values unused here
    (stim_currents, spike_counts, spike_rates,
     V_before_AP, V_before_stim, _, _) = spikes_per_stim_LEGACY(abf, spike_args, to_plot=to_plot)

    zero_spikes = spike_counts == 0
    fired = (spike_counts == 1) if single_spike else (spike_counts > 0)

    # Mark sweeps that fire while the previous sweep did not; sweep 0 has no
    # predecessor and therefore never qualifies.
    onset = np.full(fired.shape, False)
    onset[1:] = np.logical_and(fired[1:], zero_spikes[:-1])
    candidates = np.where(onset)[0]

    if candidates.size == 0:
        return results

    # np.where yields ascending indices, so the first candidate is the lowest
    # qualifying sweep. Keep it as a length-1 index array: indexing with it
    # returns a length-1 result, which the [0] below unwraps. (Reducing it to
    # a scalar would make both the [0] here and first_spike_stim[0] fail.)
    first_spike_stim = candidates[:1]
    results['Rheobase'] = stim_currents[first_spike_stim][0]
    results['Vhold_spike'] = V_before_stim[first_spike_stim][0]
    results['AP_thresh'] = V_before_AP[first_spike_stim][0]

    # Re-resolve the recording, select the rheobase sweep and measure its AP.
    abf = abf_or_name(file_name)
    abf.setSweep(int(first_spike_stim[0]))
    ap_params = single_ap_stats(abf, spike_args, window_ms=[-3, 9.5], to_plot=to_plot, verbose=verbose)
    results.update(ap_params)

    if to_plot:
        rheo_fig, ax = plt.subplots(1, 1)
        os.makedirs('Saved_Figs/Rheobase/', exist_ok=True)
        for sweep in abf.sweepList:
            abf.setSweep(sweep)
            ax.plot(abf.sweepX, abf.sweepY, label=str(stim_currents[sweep]) + ' pA')
        ax.legend(loc='center right')
        # Lay out and save BEFORE show(): with non-interactive backends show()
        # can close the figure, and a later savefig would write a blank image.
        rheo_fig.tight_layout()
        rheo_fig.savefig('Saved_Figs/Rheobase/Rheobase' + '_' + abf.abfID + '.' + figopt['type'], dpi=figopt['dpi'])
        plt.show()

    return results




In [None]:
def single_ap_stats(abf,spike_args,window_ms=[-3, 6.5],rise_fraction=0.90,to_plot=True,verbose=False,up_sample = True):
    '''Measure the shape of the first action potential in the currently selected sweep.

    Intended for a current-clamp sweep at or near rheobase. Steps:
      1. Optionally up-sample the sweep.
      2. Take the first index from the third element returned by
         ``find_spike_in_trace`` (presumably spike indices -- confirm).
      3. Cut a snippet around that index using ``window_ms``.
      4. Derive the AP parameters below from the snippet.

    Args:
        abf: object exposing ``sweepX``, ``sweepY``, ``sampleRate`` and
            ``abfID`` with the sweep of interest already selected
            (presumably a pyabf ``ABF`` -- confirm).
        spike_args: spike-detection settings passed to ``find_spike_in_trace``.
            NOTE: modified in place -- ``spike_args['spike_thresh']`` is
            overwritten with 20, which is visible to the caller.
        window_ms: [start, stop] of the snippet in ms relative to the first
            spike index. ``window_ms[1]`` also bounds the fast-AHP search
            window that starts at the peak.
        rise_fraction: upper fraction of the threshold-to-peak span used for
            rise/fall time; the lower bound is ``1 - rise_fraction``.
        to_plot: plot the snippet and save it under Saved_Figs/AP_Params/.
        verbose: print the resulting dict.
        up_sample: quadratically interpolate the sweep onto 4x as many
            evenly spaced samples before analysis.

    Returns:
        dict with keys 'ap_amplitutude' (sic -- key spelling kept as-is),
        'fast_after_hyperpol', 'AP_thresh_US', 'v_half', 'ap50_width_ms',
        'rise_time_ms', 'fall_time_ms', 'dv_max', 'dv_min'.

    NOTE(review): the line after the ``fahp_wind`` computation is corrupted in
    this copy of the notebook (see the comment there). As shown, this function
    does not parse.
    '''

    # The threshold is forced to 20 regardless of the caller's spike_args
    # (this mutates the caller's dict).
    "Hard Code Threshold"
    spike_args['spike_thresh'] = 20

    # window_ms=[-1, 3.5]
    x_trace = abf.sweepX
    y_trace = abf.sweepY
    sample_rate = abf.sampleRate

    # The fast-AHP search extends this many ms after the peak.
    fahp_max_delay = window_ms[1]

    if up_sample:
        # Resample onto a 4x denser, evenly spaced time base and scale the
        # sample rate to match so later ms <-> sample conversions stay valid.
        factor = 4
        x_new = np.linspace(x_trace[0],x_trace[-1], num=len(x_trace)*factor )
        interp_func = scipy.interpolate.interp1d(x_trace, y_trace, kind='quadratic')
        y_trace = interp_func(x_new)
        x_trace = x_new
        sample_rate = sample_rate*factor


    # Every sample is treated as eligible for spike detection.
    is_stim = np.array([True for i in range(len(x_trace))])

    args = find_spike_in_trace(y_trace,sample_rate,spike_args,is_stim = is_stim)
    inds = args[2]  # presumably per-spike sample indices -- confirm against find_spike_in_trace

    # Sample offsets spanning window_ms around the first spike index.
    window_ind = np.arange(window_ms[0]/1000*sample_rate,window_ms[1]/1000*sample_rate)
    first_spike = inds[0]  # NOTE(review): IndexError if no spike was found
    # NOTE(review): if the spike lies within |window_ms[0]| ms of the trace
    # start, negative indices silently wrap to the end of the trace. If it lies
    # within window_ms[1] ms of the end, this raises IndexError.
    capture_indicies = np.array(window_ind+first_spike,dtype='int')

    spike_trace_x = x_trace[capture_indicies]
    spike_trace_y = y_trace[capture_indicies]


    # Sample-to-sample difference scaled by sample_rate/1000, i.e. a
    # per-millisecond derivative in sweepY units (first element is 0).
    spike_trace_dvds = np.diff(spike_trace_y,prepend=spike_trace_y[0])*sample_rate/1000


    v_max = np.max(spike_trace_y)
    # NOTE(review): v_max_ind is an array. If the peak value occurs at more than
    # one sample, the addition below broadcasts against it and fails or
    # misbehaves.
    v_max_ind = np.where(spike_trace_y==v_max)[0]
    # Candidate fast-AHP sample indices: from the peak up to fahp_max_delay ms later.
    fahp_wind = np.array(np.arange(0,fahp_max_delay/1000*sample_rate),dtype='int')+v_max_ind

    # NOTE(review): CORRUPTED LINE -- text between a '<' and a '>' appears to
    # have been stripped (e.g. by an HTML export). As written this is a
    # SyntaxError. The lost statements must have assigned fast_after_hyperpol,
    # fast_after_hyperpol_ind, dv_max, dv_min and ap_start_ind, which are all
    # used below but never assigned in the visible code. Restore this line from
    # the upstream notebook.
    fahp_wind = fahp_wind[fahp_windspike_args['spike_thresh'])[0][0]

    # v_baseline -> ap_thresh_us
    ap_thresh_us= spike_trace_y[ap_start_ind]
    print('ap_thresh_us', spike_trace_y[ap_start_ind])  # NOTE(review): debug print runs even when verbose=False
    # Voltage midway between threshold and peak.
    v_half = np.mean([v_max,ap_thresh_us])

    # Half-width: time from the first to the last snippet sample at or above
    # v_half. NOTE(review): if a second crossing occurs inside the window,
    # this overestimates the width.
    over_half_ind = np.where( spike_trace_y>=v_half )[0]
    half_start = over_half_ind[0]
    half_stop = over_half_ind[-1]
    ap50_width_ms = (spike_trace_x[half_stop] - spike_trace_x[half_start])*1000

    # Rise/fall time: count samples inside the [1-rise_fraction, rise_fraction]
    # band of the threshold-to-peak span, split by sign of the slope, then
    # convert the sample counts to ms.
    fractional_peak = ap_thresh_us+rise_fraction*(v_max-ap_thresh_us)
    fractional_base = ap_thresh_us+(1-rise_fraction)*(v_max-ap_thresh_us)
    rising_bool = np.array(spike_trace_y>=fractional_base) * np.array(spike_trace_y<=fractional_peak) * np.array(spike_trace_dvds>0)
    falling_bool = np.array(spike_trace_y>=fractional_base) * np.array(spike_trace_y<=fractional_peak) * np.array(spike_trace_dvds<0)

    rise_time_ms = len(spike_trace_x[rising_bool])/sample_rate*1000
    fall_time_ms = len(spike_trace_x[falling_bool])/sample_rate*1000

    # Report the fast AHP relative to the AP threshold.
    fast_after_hyperpol = fast_after_hyperpol - ap_thresh_us

    ap_amplitutude = v_max-ap_thresh_us

    ap_params = {'ap_amplitutude':ap_amplitutude,
                 'fast_after_hyperpol':fast_after_hyperpol,
                 'AP_thresh_US':ap_thresh_us,
                 'v_half':v_half,
                 'ap50_width_ms':ap50_width_ms,
                 'rise_time_ms':rise_time_ms,
                 'fall_time_ms':fall_time_ms,
                 'dv_max':dv_max,
                 'dv_min':dv_min,}


    if to_plot:
        # Both axes show the same snippet; the second is zoomed (set_ylim
        # below) to the band around threshold and the fast AHP.
        fig, axs = plt.subplots(2,1,figsize=(4,3))
        for ax in axs:
            ax.plot(spike_trace_x,spike_trace_y,'k.-')
            ax.scatter(spike_trace_x[ap_start_ind],spike_trace_y[ap_start_ind],color='red')
            ax.scatter(spike_trace_x[fast_after_hyperpol_ind],spike_trace_y[fast_after_hyperpol_ind],color='magenta')
            ax.plot(spike_trace_x[rising_bool],spike_trace_y[rising_bool],color='red' )
            ax.plot(spike_trace_x[falling_bool],spike_trace_y[falling_bool],color='cyan' )
            # yright = ax.twinx()
            # dy = np.diff(spike_trace_y,prepend=spike_trace_y[0])/(np.diff(spike_trace_x[0:2])[0]*1000) # np.diff(spike_trace_y,append=spike_trace_y[-1]))/2
            # yright.plot(spike_trace_x,dy,'orange',marker='o')
        # yright.set_ylim(-20,60)
        axs[1].set_ylim(ap_thresh_us+fast_after_hyperpol-5,ap_thresh_us+5,)
        # NOTE(review): the bare except below hides every error, not only
        # "directory already exists"; os.makedirs(..., exist_ok=True) is the
        # targeted form.
        try:    os.makedirs('Saved_Figs/AP_Params/')
        except:     None
        plt.savefig( 'Saved_Figs/AP_Params/AP_Params_' + abf.abfID +'.png',dpi=600)
        plt.show()
    if verbose:
        print(ap_params)
        # print('spike_trace_x[half_start]',spike_trace_x[half_start])
        # print('spike_trace_x[half_stop]',spike_trace_x[half_stop])

    return ap_params