
Commit

Added quickstart to the documentation. Improved how sign of low-rank
filter components is computed. Fixed doc bug. Added better str()
handling to Sigmoid nonlinearity.
bnaecker committed Nov 16, 2016
1 parent 3594831 commit de28753
Showing 12 changed files with 163 additions and 24 deletions.
2 changes: 2 additions & 0 deletions docs/install.rst
@@ -24,6 +24,8 @@ Pyret requires the following dependencies:

- ``scikit-image``

- ``scikit-learn``

- ``matplotlib``

Development
Binary file added docs/pyret-tutorial-figures/firing-rate.png
Binary file added docs/pyret-tutorial-figures/recovered-sta.png
147 changes: 144 additions & 3 deletions docs/quickstart.rst
@@ -4,14 +4,155 @@ Quickstart

Overview
--------
Coming soon
``Pyret`` is a Python package that provides tools for analyzing stimulus-evoked
neurophysiology data. The project grew out of work in a retinal neurophysiology
and computation lab (hence the name), but its functionality should be applicable
to any neuroscience work in which you wish to characterize how neurons behave
in response to an input.

``Pyret``'s functionality is broken into modules.

- ``stimulustools``: Functions for manipulating input stimuli.
- ``spiketools``: Tools to characterize spikes.
- ``filtertools``: Tools to estimate and characterize linear filters fitted to neural data.
- ``nonlinearities``: Classes for estimating static nonlinearities.
- ``visualizations``: Functions to visualize responses and fitted filters/nonlinearities.

Demo
----
Coming soon

Let's explore how ``pyret`` might be used in a very common analysis pipeline. First, we'll
import the relevant modules.

>>> import pyret
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import h5py

For this demo, we'll be using data from a retinal ganglion cell (RGC), whose spike times were
recorded using a multi-electrode array. (Data courtesy of Lane McIntosh.) We'll load the
stimulus used in the experiment, as well as the spike times for the cell.

>>> data_file = h5py.File('tutorial-data.h5', 'r')
>>> spikes = data_file['spike-times'][:] # Spike times for one cell, loaded into memory
>>> stimulus = data_file['stimulus'][:]  # Stimulus loaded as a NumPy array
>>> stimulus -= stimulus.mean()
>>> stimulus /= stimulus.std()
>>> time = np.arange(stimulus.shape[0]) * data_file['stimulus'].attrs.get('frame-rate')

The stimulus is a spatio-temporal Gaussian white noise checkerboard, with shape ``(time, nx, ny)``.
Each spatial position is drawn independently from a normal distribution on each
temporal frame.
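
If you wanted to generate a comparable stimulus yourself (purely illustrative; the
tutorial file already contains one), it might look something like the following sketch.

>>> # A synthetic white-noise checkerboard with the same layout (illustrative only).
>>> nframes, nx, ny = 1000, 20, 20
>>> fake_stimulus = np.random.randn(nframes, nx, ny)
>>> fake_stimulus.shape
(1000, 20, 20)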

To begin, let's look at the spiking behavior of the RGC. We'll create a peri-stimulus
time histogram, by binning the spike times and smoothing a bit. This is an estimate of the
firing rate of the RGC over time.

>>> binned = pyret.spiketools.binspikes(spikes, time)
>>> rate = pyret.spiketools.estfr(binned, time)
>>> plt.plot(time[:500], rate[:500])
>>> plt.xlabel('Time (s)')
>>> plt.ylabel('Firing rate (Hz)')

.. image:: /pyret-tutorial-figures/firing-rate.png
:height: 500px
:width: 500px
:alt: Estimated RGC firing rate over time

One widely-used and informative description of the cell is its receptive field. This
is a linear approximation to the function of the cell, and captures the average visual
feature to which it responds. Because our data consists of spike times, we'll compute
the *spike-triggered average* (STA) for the cell.

>>> filter_length_seconds = 0.5 # 500 ms filter
>>> filter_length = int(filter_length_seconds / data_file['stimulus'].attrs.get('frame-rate'))
>>> sta, tax = pyret.filtertools.sta(time, stimulus, spikes, filter_length)
>>> fig, axes = pyret.visualizations.plotsta(tax[::-1], sta)
>>> axes[0].set_title('Recovered spatial filter (STA)')
>>> axes[1].set_title('Recovered temporal filter (STA)')
>>> axes[1].set_xlabel('Time before spike (s)')
>>> axes[1].set_ylabel('Filter response')

.. image:: /pyret-tutorial-figures/recovered-sta.png
:height: 500px
:width: 500px
:alt: Spatial and temporal RGC filters recovered via STA
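
For intuition, the STA is essentially the average stimulus history preceding each spike.
Here is a rough sketch of that idea (not ``pyret``'s exact implementation, and assuming
the spike times and stimulus frame times share the same clock):

>>> # Average the stimulus snippet preceding each spike; shape (filter_length, nx, ny).
>>> spike_frames = np.searchsorted(time, spikes)
>>> snippets = [stimulus[i - filter_length : i]
...             for i in spike_frames
...             if filter_length <= i <= stimulus.shape[0]]
>>> rough_sta = np.mean(snippets, axis=0)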

While the STA gives a lot of information, it is not the whole story. Real RGCs are definitely
*not* linear. One common way to correct for this fact is to fit a single, time-invariant
(static), point-wise nonlinearity to the data. This is a mapping from the linear response
to the real spiking data; in other words, it captures the difference between how the cell
*would respond if it were linear* and how the cell actually responds.

The first step in computing a nonlinearity is to compute how the recovered linear
filter responds to the input stimulus. This is done via convolution of the linear filter
with the stimulus.

>>> pred = pyret.filtertools.linear_prediction(sta, stimulus)
>>> stimulus.shape
(30011, 20, 20)
>>> pred.shape
(29962,)

The linear prediction is shorter than the full stimulus, because it only takes the
portion of the convolution in which the stimulus and filter fully overlap
(the ``valid`` keyword argument to ``np.convolve``).
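
We can check that length, and get a feel for the computation, with a sliding dot product
(a sketch, assuming ``sta`` has shape ``(filter_length, nx, ny)``; whether it matches
``linear_prediction`` exactly depends on how the filter is oriented in time):

>>> # One dot product per window in which the filter and stimulus fully overlap.
>>> n_valid = stimulus.shape[0] - filter_length + 1
>>> n_valid
29962
>>> manual = np.array([np.sum(sta * stimulus[i : i + filter_length])
...                    for i in range(n_valid)])
>>> manual.shape
(29962,)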

We can get a sense for how poor our linear prediction is simply by plotting the
predicted versus the actual response at each time point.

>>> plt.plot(pred, rate[filter_length - 1 :], linestyle='none', marker='o', mew=1, mec='w')
>>> plt.xlabel('Linearly predicted output')
>>> plt.ylabel('True output (Hz)')

.. image:: /pyret-tutorial-figures/pred-vs-true-no-fit.png
:height: 500px
:width: 500px
:alt: Predicted vs true firing rates for one RGC

It's clear that there is at least some nonlinear behavior in the cell. For one thing,
firing rates can never be negative, but our linear prediction certainly can be.

``pyret`` contains several classes for fitting nonlinearities to data. The simplest is
the ``Binterp`` class (a portmanteau of "bin" and "interpolate"), which computes the
average true output in specified bins along the input axis. It uses variable-sized
bins, so that each bin has roughly the same number of data points.
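
For intuition, equal-count bin edges are essentially quantiles of the input. A short
sketch of the idea (not necessarily ``Binterp``'s exact implementation):

>>> # Quantile-based edges put roughly the same number of points in each bin.
>>> edges = np.percentile(pred, np.linspace(0, 100, 51))  # 50 bins -> 51 edges
>>> edges.shape
(51,)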

>>> nbins = 50
>>> binterp = pyret.nonlinearities.Binterp(nbins)
>>> binterp.fit(pred, rate[filter_length - 1 :])
>>> nonlin_range = (pred.min(), pred.max())
>>> binterp.plot(nonlin_range, linewidth=5, label='Binterp') # Plot nonlinearity over the given range

.. image:: /pyret-tutorial-figures/pred-vs-true-with-binterp.png
:height: 500px
:width: 500px
:alt: Predicted vs true firing rates for one RGC

One can also fit sigmoidal nonlinearities, or a nonlinearity using a Gaussian process
(which has some nice advantages and returns error bars automatically). These are shown below.
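
For example (a sketch; this assumes the ``Sigmoid`` and ``GaussianProcess`` classes expose
the same ``fit``/``plot`` interface as ``Binterp``):

>>> sigmoid = pyret.nonlinearities.Sigmoid()
>>> sigmoid.fit(pred, rate[filter_length - 1 :])
>>> gaussian_process = pyret.nonlinearities.GaussianProcess()
>>> gaussian_process.fit(pred[::10], rate[filter_length - 1 :][::10])  # subsample: GP fits scale poorly
>>> sigmoid.plot(nonlin_range, linewidth=2, label='Sigmoid')
>>> gaussian_process.plot(nonlin_range, linewidth=2, label='Gaussian process')
>>> plt.legend()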

We can now compare how well the full LN model captures the cell's response characteristics.

>>> predicted_rate = binterp.predict(pred)
>>> plt.figure()
>>> plt.plot(time[:500], rate[filter_length - 1 : filter_length - 1 + 500], linewidth=5, color=(0.75,) * 3, alpha=0.7, label='True rate')
>>> plt.plot(time[:500], predicted_rate[:500], linewidth=2, color=(0.75, 0.1, 0.1), label='LN predicted rate')
>>> plt.legend()
>>> plt.xlabel('Time (s)')
>>> plt.ylabel('Firing rate (Hz)')
>>> np.corrcoef(rate[filter_length - 1 :], predicted_rate)[0, 1]
0.70315310866999448


.. image:: /pyret-tutorial-figures/pred-vs-true-rates.png
:height: 500px
:width: 500px
:alt: True firing rate with LN model prediction for one RGC


Bugs
----

Please report any bugs you encounter through the github `issue tracker
Please report any bugs you encounter through the GitHub `issue tracker
<https://github.com/baccuslab/pyret/issues/new>`_.
Binary file added docs/tutorial-data.h5
25 changes: 10 additions & 15 deletions pyret/filtertools.py
@@ -218,19 +218,12 @@ def lowranksta(f_orig, k=10):
# Compute the rank-k filter
fk = (u[:, :k].dot(np.diag(s[:k]).dot(v[:k, :]))).reshape(f.shape)

# make sure the temporal kernels have the correct sign

# get out the temporal filter at the RF center
peakidx = filterpeak(f)[1]
tsta = f[:, peakidx[0], peakidx[1]].reshape(-1, 1)
tsta -= np.mean(tsta)

# project onto the temporal filters and keep the sign
signs = np.sign((u - np.mean(u, axis=0)).T.dot(tsta))

# flip signs according to this projection
v *= signs
u *= signs.T
# Ensure that the computed filter components have the correct sign.
# The mean-subtracted filter should have positive projection onto
# the low-rank filter.
sign = np.sign(fk.ravel().dot((f - np.mean(f)).ravel()))
u *= sign
v *= sign

# Return the rank-k approximate filter, and the SVD components
return fk, u, s, v
@@ -630,8 +623,10 @@ def linear_prediction(filt, stim):
Returns
-------
pred : array_like
The predicted linear response. The shape is (T,) where T is the
number of time points in the input stimulus array.
The predicted linear response. The shape is ``(T - t + 1,)`` where
``T`` is the number of time points in the stimulus, and ``t`` is
the number of time points in the filter. This is the valid portion
of the convolution between the stimulus and filter.
Raises
------
5 changes: 3 additions & 2 deletions pyret/nonlinearities.py
@@ -69,6 +69,7 @@ def __init__(self, baseline=0., peak=1., slope=1., threshold=0.):

def fit(self, x, y, **kwargs):
self.params, self.pcov = curve_fit(self._sigmoid, x, y, self.init_params, **kwargs)
self.set_params(**dict(zip(self.get_params().keys(), self.params)))
return self

@staticmethod
@@ -136,9 +137,9 @@ def predict(self, x):


class GaussianProcess(GaussianProcessRegressor, NonlinearityMixin):
def __init__(self, *args, **kwargs):
def __init__(self, **kwargs):
self._fitted = False
super().__init__(*args, **kwargs)
super().__init__(**kwargs)

def fit(self, x, y):
super().fit(x.reshape(-1, 1), y)
8 changes: 4 additions & 4 deletions pyret/spiketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def estfr(bspk, time, sigma=0.01):
Parameters
----------
time : array_like
Array of time points corresponding to bins
bspk : array_like
Array of binned spike counts (e.g. from binspikes)
time : array_like
Array of time points corresponding to bins
sigma : float, optional
The width of the Gaussian filter, in seconds (Default: 0.01 seconds)
@@ -61,7 +61,7 @@
tau = np.arange(-5 * sigma, 5 * sigma, dt)
filt = np.exp(-0.5 * (tau / sigma) ** 2)
filt = filt / np.sum(filt)
size = np.round(filt.size / 2)
size = int(np.round(filt.size / 2))

# Filter binned spike times
return np.convolve(filt, bspk, mode='full')[size:size + time.size] / dt
