<a href="https://colab.research.google.com/github/dtabuena/Patch_Ephys/blob/main/Basic_Ephys.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def spikes_per_stim(abf, spike_args, mode='count', to_plot=0):
    '''Loop through the sweeps of an abf, detect spikes in each, and summarise them.

    Parameters
    ----------
    abf : object
        Recording exposing ``sweepList``, ``setSweep``, ``sweepX``, ``sweepY``,
        ``sweepC`` and ``sampleRate`` (presumably a pyabf.ABF in current clamp -
        confirm against callers).
    spike_args : dict
        Detection thresholds forwarded unchanged to ``find_spike_in_trace``.
    mode : str
        Rate mode forwarded to ``find_spike_in_trace`` ('count' or 'isi').
    to_plot : int
        Diagnostic plotting level; > 1 plots the detected spikes of every sweep
        (and the value is also forwarded to ``find_spike_in_trace``).

    Returns
    -------
    dict
        'stim_currents'   median command value during the stimulus, per sweep
        'spike_counts'    number of detected spikes, per sweep
        'spike_rates'     mean spike rate reported by find_spike_in_trace, per sweep
        'v_before_spike1' sweepY at the first detected threshold crossing (NaN if none)
        'v_before_stim'   mean sweepY over the samples preceding the first stimulus sample
        'fire_dur'        check_inactivation result for the sweep with most spikes (None if no sweeps)
        'isi_rates'       mean_inst_firing_rate of the spike times, per sweep
        'spike_times'     list of spike-time arrays relative to stimulus onset (kept ragged)
    '''
    stim_currents = []
    spike_rates = []
    spike_counts = []
    v_before_spike1 = []
    v_before_stim = []
    fire_dur = []
    inds_list = []

    # Baseline/stimulus masks are shared by every sweep of the protocol.
    _, is_stim = protocol_baseline_and_stim(abf)

    # Samples strictly before the first stimulus sample. (A cumsum-of-diff over
    # the baseline mask also flagged the post-stimulus baseline, because the
    # running sum returns to zero when the command goes back to holding.)
    stim_idx = np.flatnonzero(is_stim)
    first_stim = stim_idx[0] if stim_idx.size else len(is_stim)
    is_prestim = np.arange(len(is_stim)) < first_stim

    for s in abf.sweepList:
        abf.setSweep(s)
        dVds, over_thresh, inds, mean_spike_rate = find_spike_in_trace(
            abf.sweepY, abf.sampleRate, spike_args,
            is_stim=is_stim, mode=mode, to_plot=to_plot)
        rel_firing_duration = check_inactivation(
            abf.sweepX, abf.sweepY, is_stim, abf.sampleRate,
            dVds, inds, mean_spike_rate, to_plot=0)

        # Plot the identified spikes on top of the raw trace.
        if to_plot > 1:
            fig, axs = plt.subplots(1)
            axs.scatter(abf.sweepX[inds], abf.sweepY[inds], color='red', zorder=2)
            axs.plot(abf.sweepX, abf.sweepY, zorder=1)
            plt.show()

        stim_currents.append(np.median(abf.sweepC[is_stim]))
        spike_rates.append(mean_spike_rate)
        spike_counts.append(len(inds))
        v_before_stim.append(np.mean(abf.sweepY[is_prestim]))
        fire_dur.append(rel_firing_duration)
        inds_list.append(inds)
        # Voltage at the first threshold-crossing sample, NaN when no spikes.
        v_before_spike1.append(abf.sweepY[inds[0]] if len(inds) > 0 else np.nan)

    # Firing duration of the sweep with the most spikes (first one on ties).
    fire_dur_max = fire_dur[int(np.argmax(spike_counts))] if spike_counts else None

    # Express spike times relative to stimulus onset.
    time_offset = abf.sweepX[is_stim][0]
    spike_times = [abf.sweepX[il] - time_offset for il in inds_list]
    isi_rates = mean_inst_firing_rate(spike_times)

    results_dict = {}
    results_dict['stim_currents'] = np.array(stim_currents)
    results_dict['spike_counts'] = np.array(spike_counts)
    results_dict['spike_rates'] = np.array(spike_rates)
    results_dict['v_before_spike1'] = np.array(v_before_spike1)
    results_dict['v_before_stim'] = np.array(v_before_stim)
    results_dict['fire_dur'] = np.array(fire_dur_max)
    results_dict['isi_rates'] = np.array(isi_rates)
    # Left as a list: sweeps have different spike counts (ragged).
    results_dict['spike_times'] = spike_times

    return results_dict

In [None]:
def find_spike_in_trace(trace, rate, spike_args, refract=0.005, is_stim=None, mode='count', sanity_check=True, to_plot=0):
    '''
    Find action potentials in a current-clamp voltage trace using its derivative (dVds).

    Returns the dVds trace, a boolean array marking where dVds > threshold, the
    indices where dVds first crossed threshold, and the mean firing rate. The
    rate is computed either as 'count' (spike count per second of stimulus,
    the default) or 'isi' (mean of 1/inter-spike interval). An unrecognised
    mode prints a warning and falls back to 'count'.

    spike_args is a dict of thresholds for spike detection:
    spike_thresh: minimum dVds for an AP to occur; also marks the 'start' of the AP.
    high_dv_thresh: an AP must also exceed this higher threshold, e.g. after crossing
        20 mV/ms it must keep rising to > 40 mV/ms. This filters out incomplete APs,
        such as those undergoing depolarization/inactivation block.
    low_dv_thresh & window_ms: strong stimuli can cross spike_thresh purely by charging
        the membrane. A true AP must therefore also have a falling phase shortly after,
        e.g. min(dVds) < -20 mV/ms within window_ms of the crossing.

    Parameters
    ----------
    trace : 1-D array of voltage samples (units assumed mV - confirm).
    rate : sample rate; dVds = diff(trace) * rate / 1000 (per ms if rate is in Hz).
    refract : seconds of below-threshold dVds required before a new crossing counts.
    is_stim : optional boolean mask; crossings outside it are ignored. None means the whole trace.
    sanity_check : if True, keep only crossings that pass the high/low dV checks.
    to_plot : > 2 plots dVds with the detected indices.
    '''
    high_dv_thresh = spike_args['high_dv_thresh']
    low_dv_thresh = spike_args['low_dv_thresh']
    spike_thresh = spike_args['spike_thresh']
    window_ms = spike_args['window_ms']

    # `any(is_stim == None)` raised TypeError for None (the default) and for lists.
    if is_stim is None:
        is_stim = np.ones(len(trace), dtype=bool)
    else:
        is_stim = np.asarray(is_stim, dtype=bool)

    dVds = np.diff(trace, prepend=trace[0])*rate/1000
    over_thresh = dVds > spike_thresh
    over_thresh[np.logical_not(is_stim)] = False

    # A sample starts a new event if it is above threshold and none of the
    # preceding `refract_window` samples were. Prefix sums give the count in
    # each window in O(n), instead of rescanning every window in Python.
    refract_window = int(np.round((refract*rate)))
    csum = np.concatenate(([0], np.cumsum(over_thresh)))
    candidates = np.arange(refract_window, len(over_thresh))
    recent = csum[candidates] - csum[candidates - refract_window]
    inds = candidates[over_thresh[candidates] & (recent == 0)].tolist()

    if sanity_check:
        samp_window = window_ms/1000 * rate
        confirmed = []
        for i in inds:
            ind_range = np.arange(i-samp_window, i+samp_window).astype(int)
            # Keep the window inside the trace: negative indices would wrap to
            # the end of the array and indices past the end raise IndexError.
            ind_range = ind_range[(ind_range >= 0) & (ind_range < len(dVds))]
            if ind_range.size == 0:
                continue
            nearby_dVds = dVds[ind_range]
            if np.max(nearby_dVds) > high_dv_thresh and np.min(nearby_dVds) < low_dv_thresh:
                confirmed.append(i)
        inds = confirmed

    if to_plot > 2:
        fig1, axs1 = plt.subplots(1, figsize=[9, 2])
        axs1.plot(np.arange(len(dVds))/rate, dVds, zorder=1)
        axs1.scatter((np.arange(len(dVds))/rate)[inds], dVds[inds], color='red', zorder=2)
        plt.show()

    if len(inds) < 1:
        mean_spike_rate = 0
    elif mode == 'isi':
        mean_spike_rate = np.mean(rate/np.diff(inds))
    else:
        if mode != 'count':
            # Previously this printed and then raised UnboundLocalError.
            print('invalid mode. using default (count)')
        mean_spike_rate = len(inds)/(np.sum(is_stim)/rate)
    return dVds, over_thresh, inds, mean_spike_rate

In [None]:
def protocol_baseline_and_stim(abf):
    '''Split the protocol timeline into holding and stimulus samples.

    Every sweep's command waveform (``sweepC``) is collected. A sample whose
    command value is identical across all sweeps (zero standard deviation) is
    treated as holding/baseline; any sample where the command differs between
    sweeps is treated as stimulus.

    Returns
    -------
    (is_base, is_stim) : two complementary boolean arrays, one entry per sample.
    '''
    def _command_for(sweep):
        abf.setSweep(sweepNumber=sweep)
        return abf.sweepC

    all_commands = np.stack([_command_for(sweep) for sweep in abf.sweepList])
    varies_between_sweeps = np.std(all_commands, axis=0) != 0
    return np.logical_not(varies_between_sweeps), varies_between_sweeps

In [None]:
def command_match(abf, error_thresh=.05):
    '''Check that the recorded command signal matches the protocol command.

    Reads sweep 0's protocol waveform (``sweepC`` on channel 0) and the signal
    recorded on channel 1 (``sweepY``). It then compares their variance and
    mean as relative errors with respect to the protocol waveform.

    Returns True when both relative errors are <= error_thresh. Otherwise it
    prints 'Not Correct Command' and returns False. Leaves ``abf`` set to
    sweep 0, channel 1.
    '''
    abf.setSweep(0, channel=0)
    desired_command = abf.sweepC
    abf.setSweep(0, channel=1)
    observed_command = abf.sweepY

    def _rel_err(reference, observed):
        # A zero reference gives inf (flagged) or nan when both are zero (not
        # flagged), exactly as plain numpy division would, minus the warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            return abs((reference - observed) / reference)

    # Previously both checks compared desired_command against itself, so they
    # could never fail.
    if _rel_err(np.var(desired_command), np.var(observed_command)) > error_thresh:
        print('Not Correct Command')
        return False

    if _rel_err(np.mean(desired_command), np.mean(observed_command)) > error_thresh:
        print('Not Correct Command')
        return False

    return True

In [None]:
# def single_ap_stats(abf,spike_args,window_ms=[-3, 6.5],rise_fraction=0.90,to_plot=True,verbose=False,up_sample = True):
#     '''Takes in a voltage trace from current clamp mode at/near rheobase and calulates various AP parameters'''

#     "Hard Code Threshold"
#     spike_args['spike_thresh'] = 20



#     # window_ms=[-1, 3.5]
#     x_trace = abf.sweepX
#     y_trace = abf.sweepY
#     sample_rate = abf.sampleRate

#     fahp_max_delay = window_ms[1]

#     if up_sample:
#         factor = 4
#         x_new = np.linspace(x_trace[0],x_trace[-1], num=len(x_trace)*factor )
#         interp_func = scipy.interpolate.interp1d(x_trace, y_trace, kind='quadratic')
#         y_trace = interp_func(x_new)
#         x_trace = x_new
#         sample_rate = sample_rate*factor


#     is_stim = np.array([True for i in range(len(x_trace))])

#     args = find_spike_in_trace(y_trace,sample_rate,spike_args,is_stim = is_stim)
#     inds = args[2]

#     window_ind = np.arange(window_ms[0]/1000*sample_rate,window_ms[1]/1000*sample_rate)
#     first_spike = inds[0]
#     capture_indicies = np.array(window_ind+first_spike,dtype='int')

#     spike_trace_x = x_trace[capture_indicies]
#     spike_trace_y = y_trace[capture_indicies]


#     spike_trace_dvds = np.diff(spike_trace_y,prepend=spike_trace_y[0])*sample_rate/1000


#     v_max = np.max(spike_trace_y)
#     v_max_ind = np.where(spike_trace_y==v_max)[0]
#     fahp_wind = np.array(np.arange(0,fahp_max_delay/1000*sample_rate),dtype='int')+v_max_ind

#     fahp_wind = fahp_wind[fahp_wind<len(spike_trace_x)]

#     fast_after_hyperpol = np.min(spike_trace_y[fahp_wind])
#     fast_after_hyperpol_ind= np.where(spike_trace_y==fast_after_hyperpol)[0]

#     dv_max = np.max(spike_trace_dvds)
#     dv_min = np.min(spike_trace_dvds)

#     ap_start_ind = np.where(spike_trace_dvds>spike_args['spike_thresh'])[0][0]

#     # v_baseline -> ap_thresh_us
#     ap_thresh_us= spike_trace_y[ap_start_ind]
#     print('ap_thresh_us', spike_trace_y[ap_start_ind])
#     v_half = np.mean([v_max,ap_thresh_us])

#     over_half_ind = np.where( spike_trace_y>=v_half )[0]
#     half_start = over_half_ind[0]
#     half_stop = over_half_ind[-1]
#     ap50_width_ms = (spike_trace_x[half_stop] - spike_trace_x[half_start])*1000

#     fractional_peak = ap_thresh_us+rise_fraction*(v_max-ap_thresh_us)
#     fractional_base = ap_thresh_us+(1-rise_fraction)*(v_max-ap_thresh_us)
#     rising_bool = np.array(spike_trace_y>=fractional_base) * np.array(spike_trace_y<=fractional_peak) * np.array(spike_trace_dvds>0)
#     falling_bool = np.array(spike_trace_y>=fractional_base) * np.array(spike_trace_y<=fractional_peak) * np.array(spike_trace_dvds<0)

#     rise_time_ms = len(spike_trace_x[rising_bool])/sample_rate*1000
#     fall_time_ms = len(spike_trace_x[falling_bool])/sample_rate*1000

#     fast_after_hyperpol = fast_after_hyperpol - ap_thresh_us

#     ap_amplitutude = v_max-ap_thresh_us

#     ap_params = {'ap_amplitutude':ap_amplitutude,
#                  'fast_after_hyperpol':fast_after_hyperpol,
#                  'AP_thresh_US':ap_thresh_us,
#                  'v_half':v_half,
#                  'ap50_width_ms':ap50_width_ms,
#                  'rise_time_ms':rise_time_ms,
#                  'fall_time_ms':fall_time_ms,
#                  'dv_max':dv_max,
#                  'dv_min':dv_min,}


#     if to_plot:
#         fig, axs = plt.subplots(2,1,figsize=(8,6))
#         for ax in axs:
#             ax.plot(spike_trace_x,spike_trace_y,'k.-')
#             ax.scatter(spike_trace_x[ap_start_ind],spike_trace_y[ap_start_ind],color='red')
#             ax.scatter(spike_trace_x[fast_after_hyperpol_ind],spike_trace_y[fast_after_hyperpol_ind],color='magenta')
#             ax.plot(spike_trace_x[rising_bool],spike_trace_y[rising_bool],color='red' )
#             ax.plot(spike_trace_x[falling_bool],spike_trace_y[falling_bool],color='cyan' )
#             # yright = ax.twinx()
#             # dy = np.diff(spike_trace_y,prepend=spike_trace_y[0])/(np.diff(spike_trace_x[0:2])[0]*1000) # np.diff(spike_trace_y,append=spike_trace_y[-1]))/2
#             # yright.plot(spike_trace_x,dy,'orange',marker='o')
#         # yright.set_ylim(-20,60)
#         axs[1].set_ylim(ap_thresh_us+fast_after_hyperpol-5,ap_thresh_us+5,)
#         try:    os.makedirs('Saved_Figs/AP_Params/')
#         except:     None
#         plt.savefig( 'Saved_Figs/AP_Params/AP_Params_' + abf.abfID +'.png',dpi=600)
#         plt.show()
#     if verbose:
#         print(ap_params)
#         # print('spike_trace_x[half_start]',spike_trace_x[half_start])
#         # print('spike_trace_x[half_stop]',spike_trace_x[half_stop])

#     return ap_params

In [None]:
def mean_inst_firing_rate(spike_times):
    '''Return, per sweep, the reciprocal of the mean inter-spike interval.

    Parameters
    ----------
    spike_times : iterable of 1-D sequences
        Spike times for each sweep. The rate comes out in the reciprocal unit
        (Hz if the times are in seconds).

    Returns
    -------
    list
        One value per sweep, rounded to 0.1. The value is 0 when the sweep has
        fewer than two spikes or the rate is otherwise NaN.
    '''
    rates = []
    for times in spike_times:
        isis = np.diff(times)
        if isis.size == 0:
            # Fewer than two spikes: no interval exists. Skip np.mean([]),
            # which returns NaN and emits a RuntimeWarning for every such sweep.
            rates.append(0)
            continue
        rate = np.round(1/np.mean(isis), 1)
        rates.append(0 if np.isnan(rate) else rate)
    return rates

In [None]:
def initial_inst_firing_rate(sweepX, inds_list, num_spikes=2, to_plot=False):
    '''Return, per sweep, the firing rate over the first ``num_spikes`` spikes.

    Parameters
    ----------
    sweepX : array-like
        Time base used to convert spike sample indices to times.
    inds_list : iterable of sequences of int
        Spike sample indices for each sweep.
    num_spikes : int
        Number of leading spikes to use (i.e. num_spikes - 1 intervals). Values
        below 2 are raised to 2, because at least one interval is needed.
    to_plot : bool
        Unused; kept for call compatibility.

    Returns
    -------
    numpy.ndarray
        1 / mean ISI of the first ``num_spikes`` spikes for each sweep, or 0
        for sweeps with fewer than ``num_spikes`` spikes.
    '''
    num_spikes = max(num_spikes, 2)
    rate_list = []
    for inds in inds_list:
        times = [sweepX[i] for i in inds]
        # Was `len(times>=num_spikes)`: comparing a list with an int raised TypeError.
        if len(times) >= num_spikes:
            # Slice exactly num_spikes spikes. The old [:num_spikes+1] used one
            # more spike whenever a sweep had more than num_spikes.
            rate = 1/np.mean(np.diff(times[:num_spikes]))
        else:
            rate = 0
        rate_list.append(rate)
    return np.array(rate_list)

In [None]:
# Confirmation printed when this notebook is executed or imported.
print('basic_ephys loaded successfully')