Skip to content

Commit

Permalink
Added methods to filtertools module for computing linear predictions
Browse files Browse the repository at this point in the history
from a stimulus and filter, and for performing basic reverse correlation
for computing filters from continuous responses. Added basic tests for
them as well.

Reverse correlation currently only supports 1D stimuli, and does not
correct for any correlations in the stimulus.
  • Loading branch information
bnaecker committed Oct 10, 2016
1 parent b0828ca commit 5608de8
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ _build/
*.egg-info/
tags
dist/
*.swp
117 changes: 114 additions & 3 deletions pyret/filtertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

import numpy as np
import scipy
from scipy.signal import fftconvolve
from numpy.linalg import LinAlgError
from skimage.measure import label, regionprops, find_contours
from functools import reduce

from pyret.stimulustools import slicestim

__all__ = ['getste', 'getsta', 'getstc', 'lowranksta', 'decompose',
'filterpeak', 'smooth', 'cutout', 'rolling_window', 'resample',
'get_ellipse', 'get_contours', 'get_regionprops',
Expand Down Expand Up @@ -219,7 +222,7 @@ def lowranksta(f_orig, k=10):

# get out the temporal filter at the RF center
peakidx = filterpeak(f)[1]
tsta = f[:, peakidx[1], peakidx[0]].reshape(-1, 1)
tsta = f[:, peakidx[-1::]].reshape(-1, 1)
tsta -= np.mean(tsta)

# project onto the temporal filters and keep the sign
Expand Down Expand Up @@ -269,8 +272,8 @@ def filterpeak(sta):
idx : int
Linear index of the maximal point
sidx : int
Spatial index of the maximal point
sidx : 1- or 2-element tuple
Spatial index of the maximal point. For a 1D spatiotemporal
tidx : int
Temporal index of the maximal point
Expand Down Expand Up @@ -570,6 +573,114 @@ def rfsize(spatial_filter, dx, dy=None, pvalue=0.6827):
return widths[0] * dx, widths[1] * dy


def linear_prediction(filt, stim):
"""
Compute the predicted linear response of a receptive field to a stimulus.
Parameters
----------
filt : array_like
The linear filter whose response is to be computed. The array should
have shape ``(t, ...)``, where ``t`` is the number of time points in the
filter and the ellipsis indicates any remaining spatial dimenions.
The number of dimensions and the sizes of the spatial dimensions
must match that of ``stim``.
stim : array_like
The stimulus to which the predicted response is computed. The array
should have shape (T,...), where ``T`` is the number of time points
in the stimulus and the ellipsis indicates any remaining spatial
dimensions. The number of dimensions and the sizes of the spatial
dimenions must match that of ``filt``.
Returns
-------
pred : array_like
The predicted linear response. The shape is (T,) where T is the
number of time points in the input stimulus array.
Raises
------
ValueError : If the number of dimensions of ``stim`` and ``filt`` do not
match, or if the spatial dimensions differ.
"""

if (filt.ndim != stim.ndim) or (filt.shape[1:] != stim.shape[1:]):
raise ValueError("The filter and stimulus must have the same " +
"number of dimensions and match in size along spatial dimensions")

slices = slicestim(stim, filt.shape[0])
dim_start = ord('i')
indices = ''.join(map(chr, range(dim_start, dim_start + slices.ndim)))
subscripts = '{0},{1}{2}->{3}'.format(indices, indices[0],
indices[2:], indices[1])
return np.einsum(subscripts, slices, filt)


def revco(response, stimulus, filter_length, norm=False):
"""
Compute the reverse-correlation between a stimulus and a response.
This returns the best-fitting linear filter which predicts the given
response from the stimulus. It is analogous to the spike-triggered
average for continuous variables. ``response`` is most often a membrane
potential.
Parameters
----------
response : array_like
A continuous output response correlated with the stimulus. Must
be one-dimensional.
stimulus : array_like
A input stimulus correlated with the ``response``. Must be
one-dimensional.
filter_length : int
The length of the returned filter, in samples of the ``stimulus`` and
``response`` arrays.
norm : bool [optional]
If True, normalize the computed filter to a unit vector. Defaults
to False.
Returns
-------
filt : array_like
An array of shape ``(filter_length,)`` containing the best-fitting
linear filter which predicts the response from the stimulus.
Raises
------
ValueError : If the ``stimulus`` and ``response`` arrays are of different
shapes.
Notes
-----
The ``response`` and ``stimulus`` arrays must share the same sampling
rate. As the stimulus often has a lower sampling rate, one can use
``stimulustools.upsamplestim`` to upsample it.
"""

if response.ndim > 1 or stimulus.ndim > 1:
raise ValueError("The `response` and `stimulus` must be 1-dimensional")
if response.shape != stimulus.shape:
raise ValueError("The `response` and `stimulus` must have the same shape")

filt = fftconvolve(response, stimulus[::-1], mode='full')
mid = int(filt.size / 2) + np.mod(filt.size, 2) # Account for odd-sized arrays
filt = filt[mid - filter_length : mid]
return filt / np.linalg.norm(filt) if norm else filt

def _gaussian_function(data, x0, y0, a, b, c):
"""
A 2D gaussian function (used for fitting an ellipse to RFs)
Expand Down
24 changes: 21 additions & 3 deletions pyret/nonlinearities.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,28 @@


class Nonlinearity:
def plot(self, span=(-5, 5), n=100):
"""Creates a 1D plot of the nonlinearity"""
def plot(self, span=(-5, 5), n=100, **kwargs):
"""Creates a 1D plot of the nonlinearity
Parameters
----------
span : 2-element array_like
The span of the x-axis to plot.
n : integer_like
The number of points to plot.
kwargs : mapping
Keyword arguments passed directly to `matplotlib.pyplot.plot()`
Returns
-------
line : matplotlib.lines.Line2D
The line object that represents this nonlinearity.
"""
x = np.linspace(span[0], span[1], n)
plt.plot(x, self.predict(x))
return plt.plot(x, self.predict(x), **kwargs)

def fit(self, x, y):
"""Fits the parameters of the nonlinearity
Expand Down
63 changes: 63 additions & 0 deletions tests/test_filtertools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""test_filtertools.py
Test code for pyret's filtertools module.
(C) 2016 The Baccus Lab.
"""

import numpy as np
import pytest

from pyret import filtertools as flt
from pyret.stimulustools import slicestim

def test_linear_prediction_one_dim():
"""Test method for computing linear prediction from a
filter to a one-dimensional stimulus.
"""
filt = np.random.randn(100,)
stim = np.random.randn(1000,)
pred = flt.linear_prediction(filt, stim)

sl = slicestim(stim, filt.shape[0])
assert np.allclose(filt.reshape(1, -1).dot(sl), pred)

def test_linear_prediction_multi_dim():
"""Test method for computing linear prediction from a
filter to a multi-dimensional stimulus.
"""
for ndim in range(2, 4):
filt = np.random.randn(100, *((10,) * ndim))
stim = np.random.randn(1000, *((10,) * ndim))
pred = flt.linear_prediction(filt, stim)

sl = slicestim(stim, filt.shape[0])
tmp = np.zeros(sl.shape[1])
filt_reshape = filt.reshape(1, -1)
for i in range(tmp.size):
tmp[i] = filt_reshape.dot(sl[:, i, :].reshape(-1, 1))

assert np.allclose(tmp, pred)

def test_linear_prediction_raises():
"""Test raising ValueErrors with incorrect inputs"""
with pytest.raises(ValueError):
flt.linear_prediction(np.random.randn(10,), np.random.randn(10,2))
flt.linear_prediction(np.random.randn(10, 2), np.random.randn(10, 3))

def test_revco():
"""Test computation of a linear filter by reverse correlation"""
# Create fake filter
filter_length = 100
x = np.linspace(0, 2 * np.pi, filter_length)
true = np.exp(-1. * x) * np.sin(x)
true /= np.linalg.norm(true)

# Compute linear response
stim_length = 10000
stimulus = np.random.randn(stim_length,)
response = np.convolve(stimulus, true, mode='full')[-stimulus.size:]

# Reverse correlation
filt = flt.revco(response, stimulus, filter_length, norm=True)
tol = 0.1
assert np.allclose(true, filt, atol=tol)

0 comments on commit 5608de8

Please sign in to comment.