<a href="https://colab.research.google.com/github/dtabuena/EphysLib/blob/main/Basic_Ephys.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def find_spike_in_trace(trace,rate,spike_args,refract=0.005,is_stim = None ,mode='count',sanity_check=True,to_plot=0):
    '''Detect action potentials in a current-clamp voltage trace from its derivative.

    A spike onset is a sample where the scaled derivative (dVds) exceeds
    ``spike_args['spike_thresh']`` after ``refract`` seconds in which it did
    not. With ``sanity_check`` each onset must also have, within
    +/- ``spike_args['window_ms']`` ms, a derivative above ``high_dv_thresh``
    and one below ``low_dv_thresh``.

    Parameters
    ----------
    trace : array-like
        Voltage samples (presumably mV -- TODO confirm with callers).
    rate : float
        Sampling rate in samples per second.
    spike_args : dict
        Keys 'spike_thresh', 'high_dv_thresh', 'low_dv_thresh' (in dVds units)
        and 'window_ms' (half-width of the sanity-check window, in ms).
    refract : float
        Seconds dVds must stay at/below threshold before a crossing counts as
        a new spike.
    is_stim : array-like of bool, optional
        Mask of samples to search; None (default) searches the whole trace.
    mode : {'count', 'isi'}
        'count': spikes per second of masked time. 'isi': mean of
        1/inter-spike-interval. Any other value prints a warning and uses 'count'.
    sanity_check : bool
        Apply the high/low derivative confirmation described above.
    to_plot : int
        Plot dVds with the detected spikes when > 2.

    Returns
    -------
    dVds : np.ndarray
        First difference of ``trace`` times rate/1000 (per ms when rate is in Hz).
    over_thresh : np.ndarray of bool
        dVds > spike_thresh, forced False outside ``is_stim``.
    inds : list of int
        Sample indices of the detected spike onsets.
    mean_spike_rate : float
        0 when no spikes are found; NaN in 'isi' mode with only one spike.
    '''
    high_dv_thresh = spike_args['high_dv_thresh']
    low_dv_thresh = spike_args['low_dv_thresh']
    spike_thresh = spike_args['spike_thresh']
    window_ms = spike_args['window_ms']

    trace = np.asarray(trace)
    # Identity test: `any(is_stim == None)` raised TypeError for None and lists.
    if is_stim is None:
        is_stim = np.ones(len(trace), dtype=bool)
    else:
        is_stim = np.asarray(is_stim, dtype=bool)

    dVds = np.diff(trace, prepend=trace[0])*rate/1000
    over_thresh = dVds > spike_thresh
    over_thresh[~is_stim] = False
    n_samples = len(over_thresh)

    # Onset = supra-threshold sample whose previous `refract_window` samples are
    # all sub-threshold. cum[t] counts True values in over_thresh[:t], so the
    # look-back count is a single subtraction (O(n) instead of O(n * window)).
    refract_window = int(np.round(refract*rate))
    cum = np.concatenate(([0], np.cumsum(over_thresh)))
    candidates = np.arange(refract_window, n_samples)
    recent = cum[candidates] - cum[candidates - refract_window]
    inds = candidates[over_thresh[candidates] & (recent == 0)].tolist()

    if sanity_check:
        samp_window = window_ms/1000 * rate
        confirmed = []
        for i in inds:
            ind_range = np.arange(i-samp_window, i+samp_window).astype(int)
            # Drop out-of-range samples: negative indices would silently wrap to
            # the end of the trace and indices past the end raise IndexError.
            ind_range = ind_range[(ind_range >= 0) & (ind_range < n_samples)]
            if ind_range.size == 0:
                continue
            nearby_dVds = dVds[ind_range]
            if np.max(nearby_dVds) > high_dv_thresh and np.min(nearby_dVds) < low_dv_thresh:
                confirmed.append(i)
        inds = confirmed

    if to_plot>2:
        fig1, axs1 = plt.subplots(1,figsize = [9,2])
        axs1.plot(np.arange(len(dVds))/rate,dVds,zorder=1)
        axs1.scatter((np.arange(len(dVds))/rate)[inds],dVds[inds],color='red',zorder=2)
        plt.show()

    if mode not in ('count', 'isi'):
        print('invalid mode. using default (count)')
        mode = 'count'

    if len(inds) < 1:
        mean_spike_rate = 0
    elif mode == 'isi':
        # An inter-spike interval needs at least two spikes.
        mean_spike_rate = np.nan if len(inds) < 2 else np.mean(rate/np.diff(inds))
    else:
        mean_spike_rate = len(inds)/(np.sum(is_stim)/rate)
    return dVds, over_thresh, inds, mean_spike_rate

In [None]:
def single_ap_stats(x_trace,y_trace,sample_rate,spike_args=True,window_ms=(-1, 3.5),rise_fraction=0.95,to_plot=True,verbose=False):
    '''Compute shape parameters of the first action potential in a current-clamp trace.

    Intended for sweeps at/near rheobase. The first spike reported by
    ``find_spike_in_trace`` is cut out over ``window_ms`` (ms relative to the
    detected onset), and these values are measured on the cut-out.

    Parameters
    ----------
    x_trace : array-like
        Sample times (presumably seconds, since widths are reported as
        x-differences * 1000 in ms -- TODO confirm).
    y_trace : array-like
        Voltage samples, same length as ``x_trace``.
    sample_rate : float
        Samples per second.
    spike_args : dict or bool
        A dict overrides the default detection thresholds key by key. Any
        non-dict value (the default True, False, None) uses the defaults.
    window_ms : (float, float)
        Start/stop of the capture window in ms relative to the spike onset.
    rise_fraction : float
        Rise/fall times are measured between (1 - rise_fraction) and
        rise_fraction of the baseline-to-peak amplitude.
    to_plot, verbose : bool
        Plot the captured spike / print the parameter dict.

    Returns
    -------
    dict with keys 'v_max', 'fast_after_hyperpol', 'v_baseline', 'v_half',
    'ap50_width_ms', 'rise_time_ms', 'fall_time_ms', 'dv_max', 'dv_min'.

    Raises
    ------
    IndexError
        If no spike is detected in ``y_trace``.
    '''
    default_spike_args = {'spike_thresh':10,
                          'high_dv_thresh': 25,
                          'low_dv_thresh': -5,
                          'window_ms': 2}
    # Any truthy value, including a caller's dict, used to be replaced by the
    # defaults. Now a dict overrides the defaults key by key.
    if isinstance(spike_args, dict):
        spike_args = {**default_spike_args, **spike_args}
    else:
        spike_args = dict(default_spike_args)

    x_trace = np.asarray(x_trace)
    y_trace = np.asarray(y_trace)
    is_stim = np.ones(len(x_trace), dtype=bool)

    inds = find_spike_in_trace(y_trace,sample_rate,spike_args,is_stim = is_stim)[2]
    if len(inds) == 0:
        raise IndexError('single_ap_stats: no action potential detected in y_trace')

    # Capture window around the first spike. Out-of-range samples are dropped:
    # negative indices would wrap to the trace end, and indices past the end
    # would raise.
    window_ind = np.arange(window_ms[0]/1000*sample_rate,window_ms[1]/1000*sample_rate)
    capture_indicies = np.array(window_ind+inds[0],dtype='int')
    capture_indicies = capture_indicies[(capture_indicies >= 0) & (capture_indicies < len(y_trace))]
    spike_trace_x = x_trace[capture_indicies]
    spike_trace_y = y_trace[capture_indicies]
    spike_trace_dvds = np.diff(spike_trace_y,prepend=spike_trace_y[0])*sample_rate/1000

    # Peak: argmax gives a single index even if the maximum value repeats.
    v_max_ind = int(np.argmax(spike_trace_y))
    v_max = spike_trace_y[v_max_ind]

    # Fast AHP: minimum within 2 ms starting at the peak, clipped to the window.
    fahp_len = max(int(np.ceil(2/1000*sample_rate)), 1)
    fahp_stop = min(v_max_ind + fahp_len, len(spike_trace_y))
    fast_after_hyperpol_ind = v_max_ind + int(np.argmin(spike_trace_y[v_max_ind:fahp_stop]))
    fast_after_hyperpol = spike_trace_y[fast_after_hyperpol_ind]

    dv_max = np.max(spike_trace_dvds)
    dv_min = np.min(spike_trace_dvds)

    # AP onset within the window: first sample whose dV/dt exceeds the threshold.
    ap_start_ind = np.where(spike_trace_dvds>spike_args['spike_thresh'])[0][0]
    # Baseline = median voltage before onset (at least one sample, avoiding NaN).
    v_baseline = np.median(spike_trace_y[:max(ap_start_ind, 1)])
    v_half = np.mean([v_max,v_baseline])

    # Width at half amplitude: first to last sample at or above v_half.
    over_half_ind = np.where(spike_trace_y>=v_half)[0]
    half_start = over_half_ind[0]
    half_stop = over_half_ind[-1]
    ap50_width_ms = (spike_trace_x[half_stop] - spike_trace_x[half_start])*1000

    # Rise/fall time = number of samples between the fractional levels with
    # positive/negative slope, converted to ms.
    fractional_peak = v_baseline+rise_fraction*(v_max-v_baseline)
    fractional_base = v_baseline+(1-rise_fraction)*(v_max-v_baseline)
    in_band = (spike_trace_y>=fractional_base) & (spike_trace_y<=fractional_peak)
    rising_bool = in_band & (spike_trace_dvds>0)
    falling_bool = in_band & (spike_trace_dvds<0)

    rise_time_ms = np.count_nonzero(rising_bool)/sample_rate*1000
    fall_time_ms = np.count_nonzero(falling_bool)/sample_rate*1000

    ap_params = {'v_max':v_max,
                 'fast_after_hyperpol':fast_after_hyperpol,
                 'v_baseline':v_baseline,
                 'v_half':v_half,
                 'ap50_width_ms':ap50_width_ms,
                 'rise_time_ms':rise_time_ms,
                 'fall_time_ms':fall_time_ms,
                 'dv_max':dv_max,
                 'dv_min':dv_min,}

    if to_plot:
        fig, axs = plt.subplots(1,1)
        axs.plot(spike_trace_x,spike_trace_y,'k.-')
        axs.scatter(spike_trace_x[ap_start_ind],spike_trace_y[ap_start_ind],color='red')
        axs.scatter(spike_trace_x[fast_after_hyperpol_ind],spike_trace_y[fast_after_hyperpol_ind],color='magenta')
        axs.plot(spike_trace_x[rising_bool],spike_trace_y[rising_bool],color='red' )
        axs.plot(spike_trace_x[falling_bool],spike_trace_y[falling_bool],color='cyan' )
        plt.show()
    if verbose:
        print(ap_params)

    return ap_params