# Day 3: Using Pandas Data Frames to analyze single cell electrophysiology data


<img src="http://www.zocalopublicsquare.org/wp-content/uploads/2016/12/Mathews-on-US-China.jpg" width="300" height="300" />

Today we will analyze patch-clamp data from a single PV+ neuron in a mouse cortical slice. The experiment was performed in current clamp. The experimenter injected 20 different square pulses of current and recorded the voltage response of the neuron to each pulse. Using the data from this experiment, we will create an F-I (frequency-current) curve.

## Outline of this notebook

[3.0 Import statements](#3.0-Import-statements)

[3.1 More on for-loops and if-else statements](#3.1-for-loops-and-if-else-statements) 
* A review on for-loops and if/else statements, and some new tricks for for-loops

[3.2 Load single cell electrophysiology data from csv file](#3.2-Load-single-cell-electrophysiology-data-from-.csv-file)
* Inspect data using pandas
* Visualize data with matplotlib

[3.3 Analyze a single sweep of electrophysiology data](#3.3-Analyze-a-single-sweep-of-electrophysiology-data)
* Calculate the firing rate of the neuron during a single sweep of data

[3.4 Create an F-I (frequency - current) curve](#3.4-Create-an-F-I-curve-for-a-single-cell)
* Write a function to calculate firing frequency during one sweep (3.4)
* Employ function in a for-loop to perform the operation for all sweeps

[3.5 Bonus exercises](#3.5-Bonus-exercises)
* Calculate the input resistance of a cell
* Extract spike cutouts
    * Create a phase plot for a spike (rate of change of voltage vs. voltage)
    
[3.6 Appendix](#Appendix)
* Additional information about loading binary file formats

# 3.0 Import statements

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import scipy.signal as ss

In [None]:
## Note: the following line is only needed inside the 
## Jupyter notebook, it is not a Python statement
%matplotlib inline

# 3.1 for-loops and if-else statements
Yesterday you learned how to use a for loop to iterate over the elements of a list (refer to section 2.7). Today we'll learn a couple of additional tricks you can use when constructing for-loops.

Often we only want to perform an operation on the current element when a specific condition is met. For this, we use **if-else statements**. 

Though it's nice to loop over the elements in a list, sometimes we also would like the index where that element occurs. For this we can use the function **`enumerate()`**.

Let's create some dummy data in a Pandas `DataFrame` to illustrate these points:

In [None]:
# Use NumPy's rand() function to create an array of random floats with
# shape (100, 10)
data = np.random.rand(100,10)   # creates a 100 x 10 numpy array of random numbers between 0 and 1

# then create column names for the DataFrame
cols = ['one','two','three','four','five','six','seven','eight','nine','ten']    

# convert the array into a DataFrame with an integer index (0 to 99)
df = pd.DataFrame(columns = cols, index = range(0,data.shape[0]), data = data)
df.head(5)

### Exercise 1 - Review from day 2
> * Refer to section 2.7 from yesterday to write a for loop that loops over all the column names in `df` and prints them out.

In [None]:
# list of all column names in the df above
cols = df.columns 

# Your loop goes here
for col in cols:
    print(col)

Sometimes, we want to have access not only to the column names ('one', 'two', etc.), but also to the corresponding indexes for these elements (0, 1, 2, etc.). We can achieve this by using the built-in function [enumerate()](http://book.pythontips.com/en/latest/enumerate.html).

One example of why this might be useful is if you only want to perform an operation on the **odd-numbered** columns (the 1st, 3rd, 5th, ...). With only the column name, we can't do this. Using a combination of `enumerate()` and if-else statements (refer to section 1.14 of Day 1), we can:

In [None]:
for i, col in enumerate(cols):
    if (i+1)%2 == 0:    # notice that i starts counting from 0, so we must add 1
        pass            # "pass" does nothing; since no code follows the if-else block,
                        # the loop simply moves on to the next iteration
    else:
        print("index: "+str(i))
        print("column name: "+col)
    

### Exercise 2

>* Use what you just learned about for-loops to write a loop that prints out only the names of the even-numbered columns.

In [None]:
# Your code goes here
for i, col in enumerate(cols):
    if (i+1)%2!=0:
        pass
    else:
        print(col)

Yesterday you also learned a little bit about how to write your own functions. Below is an illustration of how functions, loops, and if-else statements can be combined in a useful way. It's a lot, but we will use these concepts later so make sure you have a good grasp on the code below.

In [None]:
def print_even_or_odd_column(dataframe, cols='even'):   # note that the default value of cols is 'even'
    
    columns=dataframe.columns
    
    # code to be executed if you tell the function to print even columns
    if cols=='even':
        for i, col in enumerate(columns):
            if (i+1)%2==0:
                print(col)
            else:
                pass
            
    # code to be executed if you tell the function to print odd columns
    elif cols=='odd':
        for i, col in enumerate(columns):
            if (i+1)%2!=0:
                print(col)
            else:
                pass


# ==== Call the function ====
print_even_or_odd_column(df,cols='even')   # can set the last argument as: cols='odd' or cols='even'

# 3.2 Load single cell electrophysiology data from .csv file

**Check that the current working directory is the day-3 folder of the python_neurobootcamp (concepts from Day 2 - section 2.1)**

In [None]:
#print "Current Working Directory (cwd)" to check we're in the right place and don't need absolute path
os.getcwd()

In [None]:
path = 'csv_data/'

#### The data:

As those of you who do patch clamp know, e-phys data doesn't come nicely packaged in a simple spreadsheet format. Therefore, we have written a function that converts Axon binary files (.abf) into spreadsheet format (.csv) for the purposes of this course.

If you are interested in exactly how this process works, look into the file: `binary_file_loading_tools.py` in the `day-3` directory or see Appendix section at the bottom of this notebook. It's a little bit rough, but will help give you an idea of how you can convert binary files to spreadsheets. After completing this course, you should be able to modify this file on your own to suit your needs!

A reminder about csv files: CSV (Comma-Separated Values) is a plain-text format for tabular data (numbers and text). You can save an Excel spreadsheet as .csv and then import it into Python. Let's look at an example (open one of the csv files in the day-3/csv_data folder in Excel).

**Load spreadsheet data into pandas**

Note the use of `index_col=0`. Try to get rid of this and see what happens when you load based on default settings. For more information: https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_csv.html
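To see what `index_col=0` changes without touching the real files, here is a minimal sketch that parses a made-up, three-row CSV held in memory with `io.StringIO`. The column names only mimic the layout assumed for the exported files (an unnamed first column of time stamps followed by one column per trace).

```python
import io

import pandas as pd

# A tiny, made-up stand-in for the exported CSV: the first column has an empty
# header and holds the time stamps; the other columns (hypothetical names) hold traces.
csv_text = """,ch1_sweep1,ch2_sweep1
0.0000,-70.1,0.0
0.0001,-70.3,0.0
0.0002,-69.8,100.0
"""

# Default settings: pandas creates a new 0, 1, 2, ... index and keeps the time
# stamps as an ordinary column named "Unnamed: 0".
df_default = pd.read_csv(io.StringIO(csv_text))
print(df_default.columns.tolist())   # ['Unnamed: 0', 'ch1_sweep1', 'ch2_sweep1']

# index_col=0: the first column becomes the row index, so rows are labeled by time.
df_indexed = pd.read_csv(io.StringIO(csv_text), index_col=0)
print(df_indexed.columns.tolist())   # ['ch1_sweep1', 'ch2_sweep1']
print(df_indexed.index.tolist())
```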

In [None]:
meta = pd.read_csv(os.path.join(path, "meta_data_PV_3_10_03_2014.csv"), index_col=0)
data = pd.read_csv(os.path.join(path, "data_PV_3_10_03_2014.csv"), index_col=0)

In [None]:
print(meta.shape)
meta.head()

Notice that in our meta data, we have information about the sampling rate (fs), cell type, date of recording, units for ch1 data, and units for ch2 data for each sweep in the binary file that we loaded. This information is particularly useful for analyses in which, for example, we want to group analyses by cell type. Today we won't use it a whole lot, but it's a common way this type of information is stored.

In [None]:
print(data.shape)
data.head()

In [None]:
data.index.name='Time'
data.head()

Though the channels aren't labeled, we can determine what each represents by their units (picoamps or millivolts in this case), which are listed in the meta data above.

Notice that data contains the time series for each channel on each sweep. The indexes here represent the time points that each value was acquired at.
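Since each index value is an acquisition time, a constant sampling rate ties the index to the sample number: sample *n* is recorded at *t = n / fs*. The sketch below assumes a hypothetical `fs` of 10 kHz and a made-up five-sample trace, and recovers the sampling interval from the spacing of the index.

```python
import numpy as np
import pandas as pd

fs = 10_000      # hypothetical sampling rate in Hz; the real value should be in the meta data (column assumed to be 'fs')
n_samples = 5    # a very short, made-up trace

# Sample n is acquired n / fs seconds after the start of the sweep.
time = np.arange(n_samples) / fs
trace = pd.Series([-70.0, -70.2, -69.9, -70.1, -70.0], index=time, name="ch1_sweep1")
trace.index.name = "Time"

# The spacing between consecutive index values recovers the sampling interval 1 / fs.
dt = np.diff(trace.index.values)
print(dt)             # each entry is 1 / fs = 0.0001 s
print(1 / dt.mean())  # approximately fs
```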

### Visualization using matplotlib and seaborn

We learned a bit about seaborn yesterday. Matplotlib is another widely used plotting library in Python. Everyone has their own preference for which they like to use. Here we will go through examples of each, though we will focus on matplotlib. You can decide for yourself which you prefer. 

Before we get into analysis, it's always a good idea to inspect your raw data to make sure you know what's there, and that it was loaded properly.

First let's use the DataFrame `filter()` method to assign all ch1 traces to a `DataFrame` called `voltage_traces` and all ch2 traces to a `DataFrame` called `current_traces`. For the basics of filtering, refer back to section 2.9.

Here we use regular expressions to filter the data. These can be very confusing at first, but don't worry. Regular expressions are very powerful and worth spending some time getting used to. For more information: https://docs.python.org/3/library/re.html
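A small sketch of how `DataFrame.filter(regex=...)` behaves, on a one-row table with hypothetical column names. `filter` keeps every column whose name contains a match anywhere (it searches rather than requiring a full match), so an unanchored pattern like `"ch1."` would also catch a hypothetical `ch10` channel; anchoring with `^` avoids that. This data set has only two channels, so the pattern used below is fine here.

```python
import pandas as pd

# One made-up row of data; the column names are hypothetical but follow the
# "ch<channel>_<sweep>" pattern, plus an extra "ch10" channel to show a pitfall.
df = pd.DataFrame(
    [[-70.0, 0.0, -65.0, 50.0, 1.0]],
    columns=["ch1_sweep1", "ch2_sweep1", "ch1_sweep2", "ch2_sweep2", "ch10_sweep1"],
)

# "ch1." means "ch1 followed by any single character", found anywhere in the name,
# so it also picks up "ch10_sweep1".
loose = df.filter(regex="ch1.", axis=1)

# "^ch1_" anchors the pattern to the start of the name and requires the underscore.
strict = df.filter(regex="^ch1_", axis=1)

print(loose.columns.tolist())   # ['ch1_sweep1', 'ch1_sweep2', 'ch10_sweep1']
print(strict.columns.tolist())  # ['ch1_sweep1', 'ch1_sweep2']
```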

In [None]:
# Based on meta data, we can tell that ch1 refers to voltage and ch2 refers to current
voltage_traces = data.filter(regex="ch1.", axis=1) # to filter on rows instead of columns, use axis=0
current_traces = data.filter(regex="ch2.",axis=1)

In [None]:
voltage_traces.head()

### Plot using matplotlib

In [None]:
fig, ax = plt.subplots(2,1,sharex=True)   # Create a figure with a grid of subplots: (2 rows, 1 column)
                                          # there are now two axes for your plots (ax[0] and ax[1])

# plot the voltages traces on the first subplot
ax[0].plot(voltage_traces) 

# Add a figure legend (this is easy because pandas and matplotlib play nice together)
ax[0].legend(voltage_traces,loc='upper right')  
# If data wasn't in a pandas frame, could also do something more explicit, like pass it a list of names:
#  ax[0].legend(['ch1_sweep1', 'ch1_sweep2', ...], loc='upper right')

# set the title of the first subplot
ax[0].set_title('Voltage traces')

# set the ylabel of the first subplot (using the meta data to get units)
ax[0].set_ylabel(meta['ch1_units'].iloc[0])


# Create the same plot in your second set of axes for current_traces
ax[1].plot(current_traces)
ax[1].legend(current_traces,loc='upper right')
ax[1].set_title('Current traces')
ax[1].set_xlabel('Time (s)')
ax[1].set_ylabel(meta['ch2_units'].iloc[0])

plt.tight_layout()  # trick to make formatting look a little nicer

### Plot using Pandas
Pandas offers wrapper functions around matplotlib that make the above code a bit more concise. We'll use this (plotting with Pandas) quite a bit in this notebook. 

Just note that what it's really doing is written out explicitly above. See [here](https://pandas.pydata.org/pandas-docs/stable/visualization.html) for a little more explanation.

In [None]:
fig, ax = plt.subplots(2,1,sharex=True)

# Call plot. plot is a dataframe method that in turn calls the matplotlib function plot
voltage_traces.plot(title='Voltage Traces', ax=ax[0])
current_traces.plot(title='Current Traces', ax=ax[1])

## Position the legends
ax[0].legend(loc='upper right')
ax[1].legend(loc='upper right')

## Set X and Y axis labels
ax[0].set(ylabel=meta['ch1_units'].iloc[0])
ax[1].set(ylabel=meta['ch2_units'].iloc[0], xlabel='Time (s)')
plt.tight_layout()

### Plot using seaborn
Note that there are some slight tweaks you need to make to the data in this case in order to use seaborn, including downsampling the data (otherwise it takes A LONG time to run). Mostly for this reason, we'll stick with using matplotlib for the rest of today. Just be aware that you can also use seaborn if you prefer.

Reformat (melt) the data from "wide form" into "long form". This step is necessary for any wide-form data that you want to plot this way, using tsplot.
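The sketch below melts a made-up table with two sweeps and two samples to show the shape change: each (time, sweep) pair becomes its own row. It uses `reset_index()` to turn the time index into a `Time` column, which should be equivalent to copying the index into a new column as the next cell does.

```python
import pandas as pd

# Wide form: one column per sweep, time in the index (values are made up).
wide = pd.DataFrame(
    {"ch1_sweep1": [-70.0, -69.5], "ch1_sweep2": [-71.0, -40.0]},
    index=pd.Index([0.0, 0.0001], name="Time"),
)

# Move the time stamps into a regular column so melt() can keep them as an id.
wide_with_time = wide.reset_index()

# Long form: one row per (Time, Sweep) pair; sweep names go to "Sweep", values to "mV".
long = pd.melt(wide_with_time, id_vars="Time", var_name="Sweep", value_name="mV")
print(long)
```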

In [None]:
## Reformat the dataframe to work with a Seaborn time-series plot
voltage_ts = voltage_traces.iloc[:,:5].copy()
voltage_ts['Time'] = voltage_ts.index
voltage_ts = pd.melt(voltage_ts, id_vars='Time', var_name='Sweep', value_name='mV')
voltage_ts['Dummy'] = 0

Downsample the data so that this runs faster. This is a quirk of data sampled at a very high rate, which makes seaborn run rather slowly.

In [None]:
## Downsample the data (take every 50th row)
voltage_ts = voltage_ts.iloc[::50,:]
voltage_ts.head()

Perform the same operations on the current trace data.

In [None]:
## Do the same for the current traces
## Reformat the dataframe to work with a Seaborn time-series plot
current_ts = current_traces.iloc[:,:5].copy()
current_ts['Time'] = current_ts.index
current_ts = pd.melt(current_ts, id_vars='Time', var_name='Sweep', value_name='pA')
current_ts['Dummy'] = 0
## Downsample the data (take every 50th row)
current_ts = current_ts.iloc[::50,:]

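Now, plot the data using the seaborn function `tsplot` ([time-series plot](http://seaborn.pydata.org/generated/seaborn.tsplot.html)). Note that `tsplot` was deprecated in seaborn 0.9 in favor of `lineplot` and has since been removed, so this cell needs an older seaborn release to run.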

Note in particular the argument "unit". We must pass it something otherwise seaborn will run into issues when it tries to calculate confidence intervals for this data. Therefore we just pass a "Dummy" unit to make seaborn happy since we don't really care about the confidence interval in this case.

In [None]:
# This will take a minute to load
fig, ax = plt.subplots(2,1,sharex=True)
sns.tsplot(data=voltage_ts, time='Time', condition='Sweep', unit='Dummy', value='mV', err_style=None, ax=ax[0])
sns.tsplot(data=current_ts, time='Time', condition='Sweep', unit='Dummy', value='pA', err_style=None, ax=ax[1])
# Position the legends
ax[0].legend(loc='upper right')
ax[1].legend(loc='upper right');

These are all pretty messy and don't tell us a ton about the data. Can we make it better?

### Exercise 3

* Try plotting only the last sweep of both current and voltage so that we can see what's going on more easily.
* Use the code from above as a starting point for how to create your own plot (you should be able to get most of the way there with copy and paste). Ask a TA or instructor if you need help.

**Step 1)** Find the column id for the last sweep (there are multiple ways to do this).

Hint: As you might remember from Day-1, we can use "-1" to grab the last index of something. This is equivalent to using "end" in Matlab.

Hint: `our_data_frame.columns` returns the column labels of a DataFrame.

In [None]:
lastSweep_voltage = voltage_traces.columns[-1]     
lastSweep_current = current_traces.columns[-1]

print(lastSweep_voltage)
print(lastSweep_current)

**Step 2)** Now, use these column IDs to select only the columns containing data for the last sweep and assign these to new variables called `last_voltage` and `last_current`.

In [None]:
last_voltage = voltage_traces[lastSweep_voltage]
last_current = current_traces[lastSweep_current]

In [None]:
# Plot with Pandas
# Create a figure with two subplots (like we did above)

fig, ax = plt.subplots(2,1,sharex=True)

# plot the last voltage trace and last current trace
last_voltage.plot(title='Voltage Trace', ax=ax[0])
last_current.plot(title='Current Trace', ax=ax[1])

## Position the legends
ax[0].legend(loc='upper right')
ax[1].legend(loc='upper right')

## Set X and Y axis labels
ax[0].set(ylabel=meta['ch1_units'].iloc[0])
ax[1].set(ylabel=meta['ch2_units'].iloc[0], xlabel='Time (s)')
plt.tight_layout();

That looks better, but can we zoom in on some of the spikes in the voltage traces?

### Exercise 4
* Plot only time points from 0.02 to 0.10 seconds:

**Step 1)** Define the list of indexes (times) that are greater than 0.02 and less than 0.1. Create a mask (array of True and False values) where the condition is true. Use the mask to select only the appropriate indexes (time values).

In [None]:
# Make a mask
mask = (0.02 < data.index) & (data.index < 0.1)
print(mask)
# Use mask to select the appropriate indexes (time values) ***remember these are stored in the index of our data frame
indexes = data.index[mask]  
print(indexes)

**Step 2)** Use the indexes you just defined to select only the rows of last_voltage and last_current where time > 0.02 and time < 0.1. Save these results into two new Series called **last_volt_short** and **last_curr_short**.

In [None]:
last_volt_short=last_voltage[indexes]
last_curr_short=last_current[indexes]

**Step 3)** Plot the result with pandas
* Note: I've filled in the plotting commands to save time, which means this next cell will only work if you correctly defined **last_volt_short** and **last_curr_short** in the cell above

In [None]:
## Plot with Pandas
## Create a figure and plot voltage and current traces
fig, ax = plt.subplots(2,1,sharex=True)
last_volt_short.plot(title='Voltage Traces', ax=ax[0])
last_curr_short.plot(title='Current Traces', ax=ax[1])
## Position the legends
ax[0].legend(loc='upper right')
ax[1].legend(loc='upper right')
## Set X and Y axis labels
ax[0].set(ylabel=meta['ch1_units'].iloc[0])
ax[1].set(ylabel=meta['ch2_units'].iloc[0], xlabel='Time (s)')
plt.tight_layout();

# 3.3 Analyze a single sweep of electrophysiology data

A common way to show a neuron's response to input is to make a frequency-intensity curve (F-I curve).  The frequency of firing is plotted on the y-axis and the amplitude of the current injection that caused the spikes is plotted on the x-axis.  

From looking at these F-I curves you can identify whether the cell fires only transiently, whether the spike rate reaches a maximum at some high current step, the gain of firing (the slope of the line, in Hz/pA), etc. To construct an F-I curve we first need to find the number of spikes per unit time in each sweep and the current injection that drove those spikes, and then plot one against the other.
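The gain mentioned above can be estimated as the slope of a straight-line fit to the part of the F-I curve where the cell fires. A minimal sketch with `np.polyfit` on made-up points (fitting a line is an assumption here; the notebook itself only plots the curve):

```python
import numpy as np

# Hypothetical F-I points above rheobase (made-up numbers, not from this cell).
current_pA = np.array([100.0, 150.0, 200.0, 250.0, 300.0])
rate_Hz = np.array([20.0, 45.0, 70.0, 95.0, 120.0])

# Fit rate = gain * current + offset; polyfit returns coefficients from the
# highest power down, so the slope (gain) comes first.
gain, offset = np.polyfit(current_pA, rate_Hz, deg=1)
print(f"gain = {gain:.2f} Hz/pA")   # gain = 0.50 Hz/pA
```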

First, we will perform the analysis for just one sweep of data. Then, if there is time, we will define a function to help automate the process. We will run the function on all sweeps using a for-loop, and plot the results in an F-I curve.

Because we have already pulled out the last voltage trace, let's analyze this time series first.

The analysis will consist of a few key steps:
* Step 1: Locate the time points of all spikes
* Step 2: Find the time window and magnitude of current injection
* Step 3: Calculate the firing rate

### Step 1: Locate the time points of all spikes
* There are a couple of ways to do this. We will do it by locating the relative maxima of the voltage trace during the current injection
* We will also set a threshold above which to detect maxima (to avoid counting local maxima in the baseline noise as spikes)

In [None]:
threshold = -20  
last_voltage_thresholded = last_voltage[last_voltage > threshold]  

Let's take a look at what happened when we thresholded the voltage trace:

In [None]:
## Plot with Pandas

## Note: X-axis is different for each subplot (times below threshold were cut out by thresholding)
fig, ax = plt.subplots(2,1)
last_voltage_thresholded.plot(title='Thresholded Trace', ax=ax[0])
last_voltage.plot(title='Raw Trace', ax=ax[1])

## Position the legends
ax[0].legend(loc='upper right')
ax[1].legend(loc='upper right')

## Set X and Y axis labels
ax[0].set(ylabel=meta['ch1_units'].iloc[0], xlabel='Time (s)')
ax[1].set(ylabel=meta['ch1_units'].iloc[0], xlabel='Time (s)')
plt.tight_layout();

Using a threshold of -20 mV, we have selected all the voltage points above -20 mV. In other words, these are all points during which an action potential was happening.

*Scipy*, another useful Python package, has a function that finds the indices where relative extrema occur: `argrelextrema`
This function takes two required arguments: data and comparator (for details see [documentation](https://docs.scipy.org/doc/scipy-0.19.0/reference/generated/scipy.signal.argrelextrema.html)).

* Note: `last_voltage_thresholded` is a pandas Series. `argrelextrema` expects data to be a NumPy array, not a Series, so we pass `.values`.
* Note: `argrelextrema` also takes an optional argument, `order`, that specifies how many points on each side to use for the comparison.
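A minimal sketch of this detection step on a made-up ten-sample trace with two spikes, showing that `argrelextrema` returns a *tuple* of position arrays and how those positions map back to time stamps through the index of the thresholded Series:

```python
import numpy as np
import pandas as pd
import scipy.signal as ss

# A made-up voltage trace (mV) sampled every 0.1 ms with two "spikes".
t = np.arange(10) * 1e-4
v = pd.Series([-70, -65, 10, 30, 5, -60, -40, 25, 0, -70], index=t, dtype=float)

threshold = -20
v_thresh = v[v > threshold]          # keeps 10, 30, 5, 25, 0

# argrelextrema needs a NumPy array, so pass .values; it returns a tuple with
# one array of positions per dimension (just one here, since the data are 1-D).
positions = ss.argrelextrema(v_thresh.values, np.greater, order=1)
print(positions)                     # (array([1, 3]),)

# Convert positions within v_thresh back into time stamps via its index.
spike_times = v_thresh.index[positions[0]]
print(spike_times)
```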

In [None]:
indexes_of_maxima = ss.argrelextrema(data=last_voltage_thresholded.values, comparator= np.greater, order=1)
np.greater(4,3)  # np.greater is a function that compares two input arguments

The Scipy function returns integer positions (0, 1, 2, ...), packed in a tuple with one array per dimension, BUT we want the time points (which are stored in the index of our Series).

In [None]:
spike_times = last_voltage_thresholded.index[indexes_of_maxima[0]]  # [0] unpacks the 1-D positions from the returned tuple

It's possible that the voltage trace is a little noisy, even during a spike. This might cause a spike to be double counted by the Scipy function, which only locates relative/local maxima. To fix these cases, we will impose a constraint based on the refractory period of the neuron. To do this, we'll first define a constant refractory period limit:

In [None]:
refractory_period_low_limit = 0.002  # 2 ms is a lower limit on the refractory period for a neuron

Now, let's calculate the interspike intervals for all the maxima we located (using numpy's `ediff1d` function) and exclude any spikes that violate the refractory period.
* `np.ediff1d`: calculates the difference between all adjacent elements in a 1d array: https://docs.scipy.org/doc/numpy/reference/generated/numpy.ediff1d.html
* `np.argwhere`: locates the indexes where the given condition is met: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.argwhere.html
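The refractory-period correction can be sketched on a handful of made-up spike times in which one noisy spike was detected twice, 0.5 ms apart. The `.ravel()` call is a small addition to the version in the next cells: it flattens the `(n, 1)` array that `np.argwhere` returns into a plain list of positions.

```python
import numpy as np
import pandas as pd

# Made-up spike times in seconds; 0.0500 and 0.0505 are one noisy spike detected twice.
spike_times = pd.Index([0.010, 0.030, 0.0500, 0.0505, 0.070])
refractory_period_low_limit = 0.002   # 2 ms

# ISI[k] = spike_times[k+1] - spike_times[k]
interspike_interval = np.ediff1d(spike_times)

# Positions k where the next spike follows too soon; ravel() makes them 1-D.
violations = np.argwhere(interspike_interval < refractory_period_low_limit).ravel()
cleaned = spike_times.delete(violations)
print(violations)   # [2]
print(cleaned)
```

Because `ISI[k]` compares spike `k` with spike `k+1`, deleting position `k` removes the *earlier* spike of each too-close pair.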

In [None]:
interspike_interval = np.ediff1d(spike_times)

Remove the refractory period violations

In [None]:
VIOLATION_indexes = np.argwhere(interspike_interval<refractory_period_low_limit)
spike_times = spike_times.delete(VIOLATION_indexes)
print(spike_times)

We now have an array of time points where we think spikes occurred. Let's plot them to see if we were successful.

In [None]:
## Plot with Pandas
fig, ax = plt.subplots(1,1)
last_voltage.plot(title='Raw Trace', ax=ax)
ax.set(ylabel=meta['ch1_units'].iloc[0], xlabel='Time (s)')

# plot a point at each time point where a spike was detected
ax.plot(spike_times, last_voltage[spike_times], ".", markersize=20); 

This is looking pretty good. 

Now, let's zoom in on some spikes to make sure we really are counting each individual spike.

In [None]:
## Plot with Pandas
fig, ax = plt.subplots(1,1)
ax = last_volt_short.plot(title='Raw Trace')
ax.set(ylabel=meta['ch1_units'][0], xlabel='Time (s)')
ax.plot(last_volt_short[spike_times], ".", markersize=9); # plot a point at each time point where a spike was detected

Great, looks like we are capturing all the spikes, not double counting, and not counting any noise as spikes. 

One final check we can perform is to look at all the raw ISI (interspike interval) values to make sure there wasn't a change in frequency over the course of the stimulation (high firing at the beginning, slower at the end). To do this, let's define a variable called ISI_check using numpy's ediff1d function:

In [None]:
ISI_check = np.ediff1d(spike_times)

In [None]:
plt.figure()
plt.hist(ISI_check*1000,bins=50) # convert to ms by multiplying by 1000
plt.ylabel('Count')
plt.xlabel('Interspike interval (ms)');

### Exercise 5
* Go back and re-run everything from "Step 1" but this time, DO NOT run the cells that corrected for refractory period violations (the cell that begins with: VIOLATION_indexes...) This should illustrate why this final histogram can be a good check on your analysis.

### Step 2: Find out the time window of current injection
This is pretty straightforward for this experiment because, as we saw earlier, the injections are square pulses. Therefore, we can just find the time points where `current == max(current)`.
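A minimal sketch on a made-up square pulse (eight samples at an assumed 0.1 ms interval) showing how the plateau times give the start, end, and duration of the injection:

```python
import numpy as np
import pandas as pd

dt = 1e-4   # assumed sampling interval in seconds (0.1 ms)
t = np.arange(8) * dt
# Made-up pulse: 0 pA, then 200 pA on samples 2 to 5, then back to 0 pA.
current = pd.Series([0, 0, 200, 200, 200, 200, 0, 0], index=t, dtype=float)

current_magnitude = current.max()
on_times = current.index[(current == current_magnitude).values]   # plateau samples

t_start, t_end = on_times.min(), on_times.max()
duration = t_end - t_start                   # first to last plateau sample: 0.3 ms
duration_from_count = len(on_times) * dt     # 4 plateau samples * 0.1 ms = 0.4 ms
print(current_magnitude, t_start, t_end, duration, duration_from_count)
```

Note that `t_end - t_start` spans from the first to the last plateau sample, so it comes out one sampling interval shorter than the number of plateau samples times the interval; for a pulse lasting hundreds of milliseconds this difference should be negligible.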

In [None]:
current_magnitude = max(last_current.values)  
current_inj_times = last_current.index[last_current==current_magnitude] # time points when current is being injected   
current_inj_length = max(current_inj_times) - min(current_inj_times) # duration in seconds
print("Current injection: "+str(current_magnitude)+" pA")
print("Current injection times:")
print(current_inj_times)
print('Current injection length: '+str(current_inj_length)+" s")

Before we calculate the firing frequency, there is one last check we must perform. It is possible that some spontaneous spikes occurred outside the current injection window (it didn't happen in this case, but it could, and that would change our results). To handle this, we remove all spike times outside the current injection window using a boolean mask (an array of `True` and `False` values), which we can then use to keep only the spike times where the mask is `True`:

In [None]:
boolean_array = (max(current_inj_times)>= spike_times) & (spike_times >= min(current_inj_times)) # boolean array (True/False)
print(boolean_array)

Now, use the boolean mask to get only the spike times inside the appropriate time window

In [None]:
spike_times = spike_times[boolean_array]   # only saving spike times where boolean_array=True e.g. during current injection

### Step 3: Calculate the firing rate 
There are a couple of different ways to calculate firing frequency. They are both useful and tell you different things.  
The most straightforward way is to simply count the number of spikes and divide by the amount of time (in seconds) that the current was being injected.

In [None]:
n_spikes = len(spike_times)                 # count the number of spikes
spike_freq = n_spikes/current_inj_length    # divide the number of spikes by the current injection time

print("spiking frequency = " + str(round(spike_freq,2))+" Hz")

The second way to go about this is to calculate the average interspike interval. This is particularly useful if the neuron does not continue to fire or adapts during the current injection.

In [None]:
ISI = np.ediff1d(spike_times)     # interspike interval (ISI) (as we did for the histogram)
meanISI = np.mean(ISI)            # calculate the mean ISI
spike_freq = 1/meanISI            # ISI is in units of seconds, do: 1s/meanISI to convert to mean firing frequency

print("spiking frequency = " +str(round(spike_freq,2))+" Hz")

For this case, both methods give us basically the same result because, as we saw above, this neuron fires pretty regularly during the entire stimulation. This isn't necessarily the case, so it's good to check.
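To see when the two methods disagree, here is a sketch with a made-up, strongly adapting spike train: the ISIs lengthen from 10 ms to 80 ms and the cell stops firing long before the end of an assumed 1 s injection. Counting spikes averages over the silent period, while 1 / mean ISI reflects only the time the cell was actually firing.

```python
import numpy as np

# Made-up spike train: ISIs grow from 10 ms to 80 ms, then the cell falls silent.
isis = np.array([0.010, 0.020, 0.040, 0.080])
spike_times = 0.1 + np.concatenate([[0.0], np.cumsum(isis)])   # 5 spikes, last at 0.25 s
injection_length = 1.0   # assumed injection duration in seconds

rate_avg = len(spike_times) / injection_length        # spikes per second of injection
rate_isi = 1 / np.mean(np.ediff1d(spike_times))       # 1 / mean ISI

print(f"count / duration: {rate_avg:.1f} Hz")   # count / duration: 5.0 Hz
print(f"1 / mean ISI:     {rate_isi:.1f} Hz")   # 1 / mean ISI:     26.7 Hz
```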

####  Summary of single sweep analysis 
Although what we really want is the F-I curve for the cell, analyzing one sweep at a time, like we did here, is often a good idea. This way, you can make sure that the code you write is really analyzing the data in the way you intend. Plus, at this point we have basically written all the code we need to analyze any sweep of data. In order to perform the analysis for the whole experiment, we just need to wrap all the code we wrote into a function that can be called iteratively from a for-loop!

# 3.4 Create an F-I curve for a single cell

To do this, we will perform the above analysis once for each sweep (i.e., each current injection).

Because we are doing the same thing many times, rather than write all of the above code 20 times (once for each sweep), we'll write a for-loop that repeats the same operation for every sweep.

When you are performing an operation many times, it is often nice to write a function to perform this operation. We will write a function that returns the firing frequency for any sweep.

**Let's write a function that returns the firing rate for a single sweep of data:**

Note: this looks like a lot of code, but, everything inside this function is stuff we already wrote above!

In [None]:
def get_firing_rate(v, threshold, tstart, tend, method="ISI"):
    '''
    ================ This is called a doc string ====================
    ========== It tells you how to use the function =================
    ======= It is what is printed when you call "help" ==============
    Arguments:
        v (pandas Series): voltage values during one sweep, indexed by time (in seconds)
        threshold (float): voltage cut-off to count spikes
        tstart (float): time current injection begins
        tend (float): time current injection ends
        method (string, optional): method for calculating the firing rate
                ISI: use interspike interval
                AVG: use average over whole current injection window
    Output:
        firing_rate (float): firing rate during the period defined by tstart and tend
    '''
    refractory_limit=0.002 # set refractory period limit (to avoid double counting spikes)
    current_duration = tend-tstart
    
    v_thresh = v[v>threshold]
    
    spike_indexes = ss.argrelextrema(v_thresh.values, np.greater,order=1)[0]  # positions of spikes (0,1,2...); [0] unpacks the returned tuple
    spike_times = v_thresh.index[spike_indexes]  # convert to time (in seconds)   
    
    # get rid of refractory violations
    interspike_interval = np.ediff1d(spike_times)
    VIOLATION_indexes = np.argwhere(interspike_interval<refractory_limit)
    spike_times = spike_times.delete(VIOLATION_indexes)    # delete any spike times that occurred too soon after a previous spike
    
    # get rid of spikes outside the current injection window
    tf = (tend >= spike_times) & (spike_times >= tstart) # boolean array (True/False)
    spike_times = spike_times[tf]   # only keep spike times where tf is True, i.e. during current injection
    
    if method=="ISI":
        # calculate the firing frequency using interspike interval
        # with fewer than two spikes there is no ISI, so set firing rate to 0 Hz
        # (a choice; np.nan would also be reasonable)
        if spike_times.size < 2:
            firing_rate=0
        else:
            ISI = np.mean(np.ediff1d(spike_times))
            firing_rate = 1/ISI
    elif method=="AVG":
        firing_rate = len(spike_times)/current_duration
    else:
        raise ValueError("method must be 'ISI' or 'AVG'")
    
    return firing_rate

There is a lot going on in that function. If you are confused about any part of it, please ask the instructor or TAs for assistance.
* You should understand how the function works, and how to modify it if you need/want to
* You should understand the difference between required and optional function arguments
* You should understand the if-else statements at the end of the function definition
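A minimal illustration of required versus optional arguments, using a hypothetical helper (`describe_sweep` is not part of the notebook): `sweep_name` must always be passed, while `units` falls back to its default unless you override it, just like `method` in `get_firing_rate`.

```python
def describe_sweep(sweep_name, units="mV"):
    """Hypothetical helper: sweep_name is required, units is optional."""
    return sweep_name + " (" + units + ")"

print(describe_sweep("ch1_sweep20"))               # ch1_sweep20 (mV)  <- default used
print(describe_sweep("ch2_sweep20", units="pA"))   # ch2_sweep20 (pA)  <- default overridden

try:
    describe_sweep()   # leaving out a required argument raises TypeError
except TypeError as err:
    print("TypeError:", err)
```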

Note: check out what happens when you call "help" on the function we just wrote:

In [None]:
help(get_firing_rate)

**Write a for-loop to calculate the firing rate for every sweep of data**

First, let's initialize the constants that we will use throughout the analysis:
* `threshold`
* `I_start`, `I_end`, and `current_duration` (the current injection window)

In [None]:
# set threshold above which to detect spikes
threshold = -20   

# Find current duration (same for all sweeps so we can just do this once)
current_magnitude = max(last_current.values)  
current_inj_times = last_current.index[last_current==current_magnitude] # time points when current is being injected  
I_start = min(current_inj_times)
I_end = max(current_inj_times)
current_duration = max(current_inj_times) - min(current_inj_times) # duration in seconds   

Loop over all the sweeps, find the current magnitude and firing rate, print them out

In [None]:
sweeps = meta.index # define the list of sweeps to loop over

In [None]:
for i, sweep in enumerate(sweeps):
    
    # Call our function that we wrote above (get_firing_rate)    
    firing_rate = get_firing_rate(data['ch1_'+sweep], threshold, I_start, I_end, method="ISI")  
    
    # get current injection magnitude
    # Note how we deal with negative current injections using if-else statements
    I_mag = max(data['ch2_'+sweep])
    
    if I_mag == 0 and min(data['ch2_'+sweep])==0:      # case where current = 0
        I_mag = 0
    elif I_mag == 0 and min(data['ch2_'+sweep]) != 0:  # case where current < 0
        I_mag = min(data['ch2_'+sweep])
    
    # print the result
    print(sweep+":  "+"I: "+str(I_mag)+" "+meta['ch2_units'].iloc[i]+ ",   Firing rate: "+str(firing_rate)+" Hz")

There's a lot of stuff going on here as well. Some key takeaways:
* Make sure you understand the for loop, what it's doing, and why it's useful
* Make sure that the if-else statements inside the for loop are making sense
* If you are confused on these points, ask the instructor or a TA
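The if-elif chain boils down to one rule: report whichever extreme of the current trace is farther from zero. Below is a compact alternative written as a hypothetical helper, `injection_magnitude`, which is not used elsewhere in this notebook. For a clean square pulse it should give the same answer as the loop above:

```python
import numpy as np

def injection_magnitude(trace):
    """Return the value of the current step, keeping its sign.

    Picks whichever extreme (max or min) is farther from zero, so a
    negative step returns a negative number and a 0 pA sweep returns 0.
    """
    trace = np.asarray(trace, dtype=float)
    hi, lo = trace.max(), trace.min()
    return hi if abs(hi) >= abs(lo) else lo

print(injection_magnitude([0, 0, 150, 150, 0]))   # positive step
print(injection_magnitude([0, -50, -50, 0]))      # negative step
print(injection_magnitude([0, 0, 0]))             # no step
```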

Now that we understand how the code above is working, let's run it again but this time actually save the results instead of just printing them out. To save our results, let's first create a new data frame called "results" to store everything.

In [None]:
results = pd.DataFrame(columns=["Current Injection", "Firing Rate"], index=meta.index, dtype=float)
results.head()

We've now created an empty data frame to store the results for all sweeps. Each cell holds `NaN` ("Not a Number"), which pandas uses to mark missing values. Using a for loop, let's again calculate the current injection and firing rate for each sweep and fill in the empty `results` DataFrame.
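Here is a toy table showing how `NaN` placeholders get filled one cell at a time with `.loc[row, column]`. The sweep names and values are made up:

```python
import numpy as np
import pandas as pd

sweeps = ["sweep1", "sweep2", "sweep3"]   # hypothetical sweep names
table = pd.DataFrame(np.nan, index=sweeps, columns=["Current Injection", "Firing Rate"])
print(table.isna().all().all())           # every cell starts as NaN

# .loc[row, column] writes into `table` itself; chained indexing such as
# table["Firing Rate"]["sweep1"] = ... can silently write to a temporary copy instead
table.loc["sweep1", "Current Injection"] = 100.0
table.loc["sweep1", "Firing Rate"] = 12.5
print(table)
```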

In [None]:
# copy and paste the code from above, but save results this time
for i, sweep in enumerate(sweeps):
    firing_rate = get_firing_rate(data['ch1_'+sweep],threshold,I_start,I_end,method="ISI")
    
    I_mag = max(data['ch2_'+sweep])
    if I_mag == 0 and min(data['ch2_'+sweep])==0:
        I_mag = 0
    elif I_mag == 0 and min(data['ch2_'+sweep]) != 0:
        I_mag = min(data['ch2_'+sweep])
       
    results.loc[sweep, 'Current Injection'] = I_mag
    results.loc[sweep, 'Firing Rate'] = firing_rate

All of our results are now stored in a new data frame.

In [None]:
results

Now we can create the F-I plot for this neuron:

In [None]:
## Plot with Pandas (DataFrame.plot makes its own figure, so plt.figure() is not needed)
ax = results.plot(x='Current Injection', y='Firing Rate', title='F-I Curve', marker=".", color="r", lw=2, markersize=5, alpha=0.6, legend=False)
ax.set(xlabel='Current ('+meta['ch2_units'].iloc[0]+')', ylabel='Firing rate (Hz)');

Cool, we now have an F-I curve for this PV neuron. 

### Final Exercise: do the same analysis, but calculate frequency as n_spikes/time instead of using the ISI as we just did
* Add the result to the DataFrame called `results` and name this new column `"Firing Rate (avg)"`
* **Hint**: one way to add a column to a DataFrame is to make a Series and then concatenate it with the original DataFrame:
    * `newcolumn = pd.Series(index=, name=)`
    * `pd.concat([results, newcolumn], axis=1)`
* Once you've accomplished all this, plot the new firing rates alongside the F-I curve above

[More on how to create a new column using a pandas Series](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.Series.html)
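Why can the two methods disagree? The sketch below implements both rate definitions. It assumes that `method="ISI"` in `get_firing_rate` takes the reciprocal of the mean inter-spike interval; check the function body earlier in the notebook to confirm. `rate_from_isi` and `rate_from_count` are hypothetical helpers, and the spike times are invented to mimic a single burst:

```python
import numpy as np

def rate_from_isi(spike_times):
    """Reciprocal of the mean inter-spike interval (Hz); needs at least 2 spikes."""
    isi = np.diff(spike_times)
    return 1.0 / isi.mean()

def rate_from_count(spike_times, t_start, t_end):
    """Number of spikes divided by the length of the window (Hz)."""
    return len(spike_times) / (t_end - t_start)

# Hypothetical burst: 4 spikes, 10 ms apart, at the start of a 1 s current step
burst = np.array([0.10, 0.11, 0.12, 0.13])
print(rate_from_isi(burst))                 # about 100 Hz: only the gaps inside the burst count
print(rate_from_count(burst, 0.10, 1.10))   # about 4 Hz: the silent rest of the step counts too
```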

In [None]:
# Create a new, empty (all-NaN) column with one row per sweep using pd.Series()
newColumn = pd.Series(index=sweeps, name="Firing Rate (avg)", dtype=float)

Now loop over all the data and fill in the new column.

**Remember:** Useful variables (like `threshold`, `I_start` and `I_end`) are already defined above, so we can use them again here!

In [None]:
# Write a loop to calculate firing rate and save it into the new column
for i, sweep in enumerate(sweeps):
    
    firing_rate = get_firing_rate(data['ch1_'+sweep],threshold,I_start,I_end, method="AVG")
    
    newColumn[sweep]=firing_rate


Concatenate `newColumn` with `results`. Check out [the `pd.concat` documentation](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.concat.html) for help, or recall how we concatenated data frames during yesterday's lesson.
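Here is a toy example of what `pd.concat(..., axis=1)` does with a named Series. The sweep labels and values are made up:

```python
import pandas as pd

left = pd.DataFrame({"Firing Rate": [0.0, 12.0, 30.0]}, index=["s1", "s2", "s3"])
extra = pd.Series([0.0, 8.0, 25.0], index=["s1", "s2", "s3"], name="Firing Rate (avg)")

# axis=1 places the Series next to the existing columns; rows are matched by
# index label, and the Series' name becomes the new column header
combined = pd.concat([left, extra], axis=1)
print(combined)
```

Rows are matched by index label, not by position. Any label that appears in only one of the inputs gets `NaN` in the other input's columns.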

In [None]:
# concatenate data frame
results_big = pd.concat([results,newColumn],axis=1)
results_big


Plot the results for both methods on the same figure. You can steal (copy and paste) a lot of code from above to do this. You want to create a single plot with two lines on it. 

In [None]:
# Use examples of plotting from above
plt.figure()
plt.plot(results_big['Current Injection'].values, results_big['Firing Rate'].values, ".-", color="r", lw=2,
         markersize=5, alpha=0.6)
plt.plot(results_big['Current Injection'].values, results_big['Firing Rate (avg)'].values, ".-", color="b", lw=2,
         markersize=5, alpha=0.6)
plt.xlabel('Current ('+meta['ch2_units'].iloc[0]+')')
plt.ylabel('Firing rate (Hz)')
plt.legend(['ISI', 'avg'])
plt.title('F-I curve')
plt.show()


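Overall, the two methods give essentially the same result. The clear exception is the first positive current step. There, the cell did not keep firing throughout the current injection and instead fired only a burst of spikes at the beginning of the stimulus. The ISI method sees only the short intervals inside that burst, while the average method divides the spike count by the full duration of the step, so the two estimates diverge.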

# 3.5 Bonus exercises

### 1. Calculate the input resistance of the cell (Easier)

Recall Ohm's law: voltage = current * resistance (V = IR). We know both the current and the voltage, so we can solve for R, the cell's input resistance. Using Ohm's law and the hints below, solve for this cell's input resistance.
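The membrane sits at a nonzero resting potential, so R is given by the change in voltage per change in current. This recording appears to use mV and pA, judging by the axis labels in the solution below. The unit conversion is worth working through once by hand; the numbers in the second line are made up for illustration:

$$
R = \frac{\Delta V}{\Delta I}, \qquad
\frac{1\ \text{mV}}{1\ \text{pA}} = \frac{10^{-3}\ \text{V}}{10^{-12}\ \text{A}} = 10^{9}\ \Omega = 1000\ \text{M}\Omega
$$

$$
\text{e.g. } R = \frac{-10\ \text{mV}}{-100\ \text{pA}} = 0.1\ \frac{\text{mV}}{\text{pA}} = 0.1 \times 10^{9}\ \Omega = 100\ \text{M}\Omega
$$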

* Hint 1: Input resistance is a passive property of a neuron, so we have to measure it while voltage-activated channels are not opening or closing. Remember from our analysis above that several negative current steps elicited no action potentials.

* Hint 2: Even though the negative current steps didn't make the cell fire action potentials, some channels could still have opened and changed the resistance of the cell (for example, channels carrying the hyperpolarization-activated current, Ih). Therefore, it's best to use voltage measurements from when the cell was in a steady state (later in the current injection).

* Hint 3: Ohm's law is a linear relationship. In other words, if we plot voltage vs. current we should see a straight line (if no channels were opening and closing). The slope of this line is R. Plot voltage vs. current for all negative current injections and measure the slope where the line is most linear.
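Hint 3 can also be done by fitting a least-squares line through all the selected points, instead of the two-point slope used in the solution below. The sketch uses `np.polyfit` as a swapped-in technique, and the current and voltage values are hypothetical:

```python
import numpy as np

# Hypothetical steady-state data for four negative steps
current_pA = np.array([-200.0, -150.0, -100.0, -50.0])
voltage_mV = np.array([-90.0, -85.0, -80.0, -75.0])

# Fit V = R*I + V0; with these units the slope is in mV/pA
slope_mV_per_pA, intercept_mV = np.polyfit(current_pA, voltage_mV, 1)

# mV/pA -> ohms: (1e-3 V) / (1e-12 A) = 1e9 ohms per mV/pA; then ohms -> megaohms
R_ohms = slope_mV_per_pA * 1e-3 / 1e-12
R_megaohms = R_ohms / 1e6
print(f"Input resistance: {R_megaohms:.1f} MOhm")
```

The intercept estimates the membrane potential at zero injected current. With real data, restrict the fit to the points where the V-I relationship looks linear.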

In [None]:
# ==============================   Solution    ========================================

negative_current_sweeps = results[results['Current Injection'] < 0].index
tstart = 0.4   # window (s) late in the current step, when the voltage has reached steady state
tend = 0.5
steady_state_voltage = []
current = []
for sweep in negative_current_sweeps:
    steady_state_voltage.append(np.mean(data['ch1_'+sweep].loc[tstart:tend]))
    current.append(results.loc[sweep, 'Current Injection'])

plt.figure()
plt.plot(current, steady_state_voltage, '.-')
plt.ylabel('Steady-state voltage (mV)')
plt.xlabel('Current (pA)')


# use the first and second-to-last points, leaving out the last negative step
# (the V-I relationship looks non-linear there), and convert mV -> V and pA -> A
R = ((steady_state_voltage[-2] - steady_state_voltage[0]) / 1e3) / ((current[-2] - current[0]) / 1e12)
R_mohms = round(R / 1e6, 2)   # ohms -> megaohms
plt.title('Input resistance = ' + str(R_mohms) + ' MOhm')


### 2. Create a phase-plot for this cell (More difficult)
A phase plot is a plot of voltage vs. the derivative of the voltage (rate of change) during an action potential. Here's an example:
<img src = 'phase_plot.png'>

Your goal is to create a phase plot for the action potentials fired during the 500 pA current injection. This involves a few steps:
* Step 1: Generate an array of "spike cutouts" during this current step. In other words, make a 2D numpy array where each row is an individual action potential and each column is a time point. Do this by taking 1 ms of data before each AP peak and 2 ms of data after it. Hint: remember how we found the time points of AP peaks earlier...

* Step 2: Generate an equivalent numpy array in which each row is the derivative of the corresponding action potential from Step 1

* Step 3: Average across action potentials (rows) in both arrays, then plot the mean action potential on the x-axis against the mean derivative on the y-axis
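The three steps are sketched below on a synthetic trace. The sketch assumes the peak positions are already known as integer sample indices. Slicing by sample count guarantees that every cutout has the same length, which slicing by floating-point time stamps does not. All values (sampling rate, spike shape, peak positions) are made up:

```python
import numpy as np

fs = 20_000                                   # hypothetical sampling rate (Hz)
n_before = int(round(0.001 * fs))             # samples in the 1 ms before each peak
n_after = int(round(0.002 * fs))              # samples in the 2 ms after each peak

# Hypothetical trace: -70 mV baseline with three identical Gaussian "spikes"
t = np.arange(int(0.5 * fs)) / fs
peak_idx = np.array([2000, 4000, 6000])       # sample positions of the peaks
v = np.full(t.size, -70.0)
for p in peak_idx:
    v += 100.0 * np.exp(-((t - t[p]) / 0.0003) ** 2)

# Step 1: one row per spike, one column per time point
# (peaks closer than n_before samples to the trace start would need to be skipped)
cutouts = np.stack([v[p - n_before:p + n_after] for p in peak_idx])

# Step 2: finite-difference derivative along time, in mV/s
dvdt = np.diff(cutouts, axis=1) * fs

# Step 3: average across spikes (rows); drop the last voltage sample so both
# curves have the same length before plotting dV/dt against V
mean_v = cutouts.mean(axis=0)[:-1]
mean_dvdt = dvdt.mean(axis=0)
print(cutouts.shape, dvdt.shape, mean_v.shape, mean_dvdt.shape)
```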

In [None]:
# ============================== Solution =======================================

# get the voltage values for the 500 pA sweep (last sweep)

voltages = voltage_traces[voltage_traces.columns[-1]]

# find spike times: local maxima among the samples above threshold
threshold = -20
voltages_thresh = voltages[voltages > threshold]
indexes_of_maxima = ss.argrelextrema(voltages_thresh.values, np.greater, order=1)[0]
spike_times = voltages_thresh.index[indexes_of_maxima]

# get rid of refractory violations (peaks less than 2 ms after the previous one)
interspike_interval = np.ediff1d(spike_times)
violation_indexes = np.flatnonzero(interspike_interval < 0.002)
spike_times = spike_times.delete(violation_indexes)

# window sizes in samples: 1 ms before and 2 ms after each peak
fs = meta['fs'].iloc[0]
n_before = int(round(0.001 * fs))
n_after = int(round(0.002 * fs))
n_spikes = len(spike_times)

# integer sample positions of the peaks, so every cutout has exactly the same length
peak_positions = voltages.index.get_indexer(spike_times)

# make empty arrays (one row per spike) to hold spike cutouts and derivatives
spike_cutouts = np.empty((n_spikes, n_before + n_after))
derivative = np.empty((n_spikes, n_before + n_after - 1))

# fill arrays
for i, pos in enumerate(peak_positions):
    spike_cutouts[i, :] = voltages.values[pos - n_before:pos + n_after]
    derivative[i, :] = np.diff(spike_cutouts[i, :]) * fs   # dV/dt in mV/s

# average across spikes; drop the last voltage sample so it lines up with the derivative
spike = np.mean(spike_cutouts, axis=0)[:-1]
deriv = np.mean(derivative, axis=0)

plt.figure()
time = (np.arange(len(spike)) - n_before) / fs   # time relative to the peak (s)
plt.subplot(221)
plt.title('Action potential')
plt.plot(time,spike)
plt.xlabel('Time (s)')
plt.ylabel('mV')

plt.subplot(223)
plt.title('Derivative')
plt.xlabel('Time (s)')
plt.ylabel('mV/s')
plt.plot(time,deriv)

plt.subplot(122)
plt.title('Phase plot')
plt.ylabel('dV/dt (mV/s)')
plt.xlabel('mV')
plt.plot(spike, deriv,'.-')
plt.tight_layout()

plt.show()
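As an alternative technique (not the method used in the solution above), `scipy.signal.find_peaks` can apply the threshold and the refractory-period rule in a single call. `height` sets the minimum peak voltage and `distance` sets the minimum separation between peaks, in samples. The sketch runs on a synthetic trace with a made-up sampling rate and spike times:

```python
import numpy as np
import scipy.signal as ss

fs = 20_000                                   # hypothetical sampling rate (Hz)
t = np.arange(4000) / fs                      # 0.2 s of time stamps
v = np.full(t.size, -70.0)
# the 0.5 ms pair violates a 2 ms refractory period; the second spike is smaller
for t_peak, amp in [(0.050, 100.0), (0.0505, 80.0), (0.120, 100.0)]:
    v += amp * np.exp(-((t - t_peak) / 0.0002) ** 2)

threshold = -20                               # mV, as in the notebook
peaks, _ = ss.find_peaks(v, height=threshold, distance=int(round(0.002 * fs)))
spike_times = t[peaks]
print(spike_times)
```

This differs from the loop in the solution in one respect. When two peaks are closer than `distance`, `find_peaks` keeps the taller one, whereas the solution drops the earlier one.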

# 3.6 Appendix

### Loading binary files (.abf, .dat, etc.) using the python package, neo

*neo* is an open-source Python package, like numpy or pandas. It was written specifically to help load binary file formats from electrophysiology data into Python. We have written a function that uses this package to load .abf files and save them in .csv format (a spreadsheet). If you are interested in learning more about this, ask one of your TAs or instructors, or look at the file `binary_file_loading_tools.py` in the day 3 folder. Note: to run the loading function, you will need neo installed on your machine. To install it, type `pip install neo` into a terminal window. For more information on neo, see: http://neuralensemble.org/neo/
 