<a href="https://colab.research.google.com/github/dtabuena/EphysLib/blob/main/Firing_Rate_Gain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
'Functions for Basic analysis of firing rate gain parameters'
'requires: numpy (np), pandas (pd), pyabf (pyabf) , matplotlib'
from matplotlib import pyplot as plt


def analyze_gain_abf(local_abf_file_name,
                     spike_args=None,
                     R2_thresh=0.8,
                     to_plot=0):
    """Analyze one ABF of increasing current injections for firing-rate gain.

    Parameters
    ----------
    local_abf_file_name : str
        Path to the ABF file, opened with ``pyabf.ABF``.
    spike_args : dict or None
        Spike-detection thresholds forwarded to ``spikes_per_stim`` /
        ``find_spike_in_trace`` (keys ``spike_thresh``, ``high_dv_thresh``,
        ``low_dv_thresh``, ``window_ms``). ``None`` uses
        ``{'spike_thresh': 10, 'high_dv_thresh': 25, 'low_dv_thresh': -5, 'window_ms': 2}``.
    R2_thresh : float
        NOTE(review): currently unused; kept for interface compatibility.
    to_plot : int
        0 = no plots; >=1 plot the final F-I fit; >=2 also plot every sweep with
        its detected spikes; >=3 also plot each sweep's dV/dt trace.

    Returns
    -------
    tuple
        ``(gain_slope, R2, stim_currents, spike_counts, v_before_stim, fire_dur)``.
        All six are NaN when the file has fewer than 5 sweeps, when no spikes are
        detected, or when any of the first five outputs contains a NaN.
    """
    if spike_args is None:
        # Build a fresh dict per call instead of sharing one mutable default.
        spike_args = {'spike_thresh': 10, 'high_dv_thresh': 25, 'low_dv_thresh': -5, 'window_ms': 2}
    nan_result = (np.nan,) * 6

    abf = pyabf.ABF(local_abf_file_name)
    if len(abf.sweepList) < 5:  # not enough sweeps to build an F-I curve
        return nan_result

    (stim_currents, spike_counts, spike_rates, _,
     v_before_stim, fire_dur) = spikes_per_stim(abf, spike_args, mode='count', to_plot=to_plot)
    if sum(spike_counts) == 0:  # no spikes in any sweep
        return nan_result

    if_fit = fit_firing_gain(stim_currents, spike_counts, spike_rates, to_plot=to_plot > 0)
    gain_slope = if_fit['slope']
    R2 = if_fit['R2']

    # All-or-none: if any key output contains a NaN, every output is reported as NaN.
    checked = {
        'gain_slope': gain_slope,
        'R2': R2,
        'stim_currents': stim_currents,
        'spike_counts': spike_counts,
        'v_before_stim': v_before_stim,
    }
    bad_vals = [name for name, val in checked.items() if np.any(np.isnan(val))]
    if bad_vals:
        print('\n')
        print('one fail all on ', local_abf_file_name)
        print('     ', bad_vals, ' = nan')
        return nan_result

    return gain_slope, R2, stim_currents, spike_counts, v_before_stim, fire_dur

def len_one(x):
    """Return ``len(x)``, or 1 when ``x`` has no length (scalars, 0-d arrays).

    Only ``TypeError`` (what ``len()`` raises for unsized objects) is treated as
    "scalar"; any other exception propagates instead of being silently swallowed.
    """
    try:
        return len(x)
    except TypeError:
        return 1

def spikes_per_stim(abf, spike_args, thresh=20, mode='count', to_plot=0):
    """Loop through the sweeps of an ABF, detect spikes, and collect per-sweep measures.

    Parameters
    ----------
    abf : pyabf.ABF
        Loaded recording; uses ``sweepList``, ``setSweep``, ``sweepX``, ``sweepY``,
        ``sweepC`` and ``sampleRate``.
    spike_args : dict
        Detection thresholds forwarded to ``find_spike_in_trace``.
    thresh : int
        Unused; kept for backward compatibility.
    mode : str
        Firing-rate mode forwarded to ``find_spike_in_trace`` (``'count'`` or ``'isi'``).
    to_plot : int
        > 1 plots each sweep with its detected spikes; also forwarded to
        ``find_spike_in_trace``.

    Returns
    -------
    tuple of np.ndarray, one entry per sweep:
        stim_currents   : median command value (``sweepC``) over the stimulus samples
        spike_counts    : number of detected spikes
        spike_rates     : firing rate reported by ``find_spike_in_trace``
        v_before_spike1 : ``sweepY`` at the first spike's detection index (NaN if none)
        v_before_stim   : mean ``sweepY`` over the ``is_prestim`` samples
        fire_dur        : relative firing duration from ``check_inactivation``
    """
    stim_currents, spike_rates, spike_counts = [], [], []
    v_before_spike1, v_before_stim, fire_dur = [], [], []

    # Baseline / stimulus masks come from the protocol and are the same for every sweep.
    is_base, is_stim = protocol_baseline_and_stim(abf)
    # Loop-invariant, so computed once here.
    # NOTE(review): for a 0/1 is_base, cumsum(diff(is_base, prepend=1)) telescopes to
    # is_base - 1, so this selects every baseline sample (including any post-stimulus
    # baseline), not only the segment before the stimulus -- confirm intent.
    is_prestim = np.equal(np.cumsum(np.diff(is_base, prepend=1)), 0)

    for s in abf.sweepList:
        abf.setSweep(s)
        # `mode` is forwarded (it used to be ignored in favour of a hard-coded 'count').
        dVds, over_thresh, inds, mean_spike_rate = find_spike_in_trace(
            abf.sweepY, abf.sampleRate, spike_args, is_stim=is_stim, mode=mode, to_plot=to_plot)
        rel_firing_duration = check_inactivation(
            abf.sweepX, abf.sweepY, is_stim, abf.sampleRate, dVds, inds, mean_spike_rate, to_plot=0)

        if to_plot > 1:  # show the sweep with detected spikes marked
            fig, axs = plt.subplots(1)
            axs.scatter(abf.sweepX[inds], abf.sweepY[inds], color='red', zorder=2)
            axs.plot(abf.sweepX, abf.sweepY, zorder=1)
            plt.show()

        stim_currents.append(np.median(abf.sweepC[is_stim]))
        spike_rates.append(mean_spike_rate)
        spike_counts.append(len(inds))
        v_before_stim.append(np.mean(abf.sweepY[is_prestim]))
        fire_dur.append(rel_firing_duration)
        v_before_spike1.append(abf.sweepY[inds[0]] if len(inds) > 0 else np.nan)

    return (np.array(stim_currents), np.array(spike_counts), np.array(spike_rates),
            np.array(v_before_spike1), np.array(v_before_stim), np.array(fire_dur))


def find_spike_in_trace(trace, rate, spike_args, refract=0.005, is_stim=None, mode='count', sanity_check=True, to_plot=0):
    """Find action potentials in a current-clamp voltage trace from its derivative.

    A sample becomes a spike candidate when dV/dt exceeds ``spike_args['spike_thresh']``
    inside the stimulus window and no other supra-threshold sample occurred in the
    preceding ``refract`` seconds. With ``sanity_check`` enabled, a candidate is kept
    only if, within +/- ``spike_args['window_ms']`` milliseconds of it, dV/dt both rises
    above ``high_dv_thresh`` and drops below ``low_dv_thresh``.

    Parameters
    ----------
    trace : array-like
        1-D voltage trace.
    rate : float
        Sampling rate in samples per second.
    spike_args : dict
        Keys ``spike_thresh``, ``high_dv_thresh``, ``low_dv_thresh``, ``window_ms``.
    refract : float
        Refractory period in seconds.
    is_stim : array-like of bool or None
        Mask of stimulus samples; ``None`` means the whole trace.
    mode : str
        ``'count'`` (spikes per second of stimulus) or ``'isi'`` (mean of 1/ISI).
        Any other value prints a warning and falls back to ``'count'``.
    sanity_check : bool
        Apply the high/low dV/dt window test described above.
    to_plot : int
        > 2 plots dV/dt with the detected spikes.

    Returns
    -------
    dVds : np.ndarray
        ``diff(trace) * rate / 1000`` (mV/ms if trace is in mV -- confirm units).
    over_thresh : np.ndarray of bool
        ``dVds > spike_thresh``, forced False outside the stimulus.
    inds : list of int
        Sample indices of detected spikes.
    mean_spike_rate : float
        0 when no spikes; NaN in ``'isi'`` mode with a single spike.
    """
    high_dv_thresh = spike_args['high_dv_thresh']
    low_dv_thresh = spike_args['low_dv_thresh']
    spike_thresh = spike_args['spike_thresh']
    window_ms = spike_args['window_ms']

    # `any(is_stim == None)` raised TypeError for the default None; test identity instead.
    if is_stim is None:
        is_stim = np.ones(len(trace), dtype=bool)
    else:
        is_stim = np.asarray(is_stim, dtype=bool)

    dVds = np.diff(trace, prepend=trace[0]) * rate / 1000
    over_thresh = dVds > spike_thresh
    over_thresh[np.logical_not(is_stim)] = False
    refract_window = int(np.round(refract * rate))

    # Refractory test, vectorized: csum[k] = number of crossings in over_thresh[0:k], so
    # csum[t] - csum[t - w] counts crossings in the w samples strictly before t.
    csum = np.concatenate(([0], np.cumsum(over_thresh)))
    t = np.arange(refract_window, len(over_thresh))
    recent = csum[t] - csum[t - refract_window]
    inds = t[over_thresh[t] & (recent == 0)].tolist()

    if sanity_check:
        samp_window = window_ms / 1000 * rate
        last = len(dVds) - 1
        kept = []
        for i in inds:
            ind_range = np.arange(i - samp_window, i + samp_window).astype(int)
            # Clip so spikes near the edges neither wrap around (negative index)
            # nor run past the end of the trace (IndexError).
            nearby_dVds = dVds[np.clip(ind_range, 0, last)]
            if np.max(nearby_dVds) > high_dv_thresh and np.min(nearby_dVds) < low_dv_thresh:
                kept.append(i)
        inds = kept

    if to_plot > 2:
        fig1, axs1 = plt.subplots(1, figsize=[9, 2])
        axs1.plot(np.arange(len(dVds)) / rate, dVds, zorder=1)
        axs1.scatter((np.arange(len(dVds)) / rate)[inds], dVds[inds], color='red', zorder=2)
        plt.show()

    if len(inds) < 1:
        mean_spike_rate = 0
    elif mode == 'isi':
        # A single spike has no interval; report NaN explicitly (mean of empty array).
        mean_spike_rate = np.mean(rate / np.diff(inds)) if len(inds) > 1 else np.nan
    else:
        if mode != 'count':
            print('invalid mode. using default (count)')
        mean_spike_rate = len(inds) / (np.sum(is_stim) / rate)
    return dVds, over_thresh, inds, mean_spike_rate

def fit_firing_gain(stim_currents, spike_counts, spike_rates, to_plot=False):
    """Fit the rising, linear portion of the F-I curve and return its slope (the gain).

    A sweep is used for the fit when it:
      (a) has a non-zero spike count,
      (b) comes strictly before the first sweep that reaches the maximum count, and
      (c) sits on a rising trend: ``movmean`` (defined elsewhere; presumably a moving
          average) of the first difference of ``spike_counts`` with window 4 is > 0.

    Parameters
    ----------
    stim_currents : array-like
        Stimulus amplitude per sweep (presumably pA, giving a gain in Hz/pA -- confirm).
    spike_counts : array-like
        Spikes detected per sweep.
    spike_rates : array-like
        Firing rate (Hz) per sweep; the y-values of the fit.
    to_plot : bool
        If truthy, plot all points, the fitted line, the fit points (red) and the
        peak sweep (magenta).

    Returns
    -------
    dict
        Keys ``stim_currents``, ``spike_rates``, ``slope``, ``intercept``, ``R2``.
        When no fit is possible, ``slope`` and ``intercept`` are NaN and ``R2`` is 0.
        No fit is possible when there are no spikes, fewer than 3 sweeps spike, or
        fewer than two distinct currents fall in the fit window.
    """
    stim_currents = np.asarray(stim_currents)
    spike_counts = np.asarray(spike_counts)
    spike_rates = np.asarray(spike_rates)

    if_fit = {'stim_currents': stim_currents, 'spike_rates': spike_rates}

    def _no_fit():
        # Shared "could not fit" result, matching the original failure values.
        if_fit.update(slope=np.nan, intercept=np.nan, R2=0)
        return if_fit

    if np.sum(spike_rates) == 0:  # no spikes detected
        return _no_fit()
    if np.sum(spike_rates > 0) < 3:  # not enough sweeps with spikes
        return _no_fit()

    is_pos_slope = movmean(np.diff(spike_counts, prepend=0), 4) > 0
    peak_ind = int(np.argmax(spike_counts))  # first sweep reaching the max count
    before_peak = np.arange(len(spike_counts)) < peak_ind
    is_nonzero = spike_counts > 0
    use_for_fit = np.logical_and.reduce((is_pos_slope, is_nonzero, before_peak))

    x_fit = stim_currents[use_for_fit]
    y_fit = spike_rates[use_for_fit]
    # linregress raises ValueError with < 2 points or identical x values (e.g. when the
    # peak is reached early); report "no fit" instead of crashing.
    if np.unique(x_fit).size < 2:
        return _no_fit()

    slope, intercept, r_value, _, _ = stats.linregress(x_fit, y_fit)
    if_fit['slope'] = slope
    if_fit['intercept'] = intercept
    if_fit['R2'] = r_value ** 2

    if to_plot:
        fig, ax = plt.subplots(1, figsize=[3, 3])
        ax.scatter(stim_currents, spike_rates)
        ax.plot(stim_currents, slope * stim_currents + intercept)
        ax.scatter(x_fit, y_fit, color='r')
        ax.scatter(stim_currents[peak_ind], spike_rates[peak_ind], color='m')
        ax.set_xlabel('current')
        ax.set_ylabel('Spike Rate (Hz)')
        _, y_max = ax.get_ylim()  # avoid shadowing builtins min/max
        ax.text(0, y_max / 2, 'R**2=' + str(round(if_fit['R2'], 2)), fontsize='large')
        plt.show()
    return if_fit

def check_inactivation(time, trace, is_stim, sample_rate, dVds, inds, mean_spike_rate, to_plot=0):
    """Return how long a sweep keeps firing, as a fraction of the stimulus window.

    The firing duration runs from stimulus onset to the last detected spike. It is
    computed as the onset-to-first-spike latency plus the summed inter-spike intervals.
    It is then divided by the span from stimulus onset to the last stimulus sample.

    Parameters
    ----------
    time : np.ndarray
        Sample times; multiplied by 1000 internally (presumably seconds -> ms, confirm).
    trace, sample_rate, dVds, mean_spike_rate, to_plot
        Not used here; accepted for call-site compatibility.
    is_stim : np.ndarray
        Mask selecting the stimulus samples of ``time``.
    inds : sequence of int
        Sample indices of detected spikes.

    Returns
    -------
    float
        Relative firing duration, or NaN when fewer than two spikes were detected.
    """
    t_ms = time * 1000
    if len(inds) <= 1:
        return np.nan

    spike_ms = t_ms[inds]
    onset_ms = t_ms[np.where(is_stim)[0][0]]
    latency_ms = spike_ms[0] - onset_ms
    active_ms = latency_ms + np.sum(np.diff(spike_ms))
    window_ms = np.max(time[is_stim] * 1000) - onset_ms
    return active_ms / window_ms

