Skip to content

Commit

Permalink
Added utils for testing (may move into main tree eventually), updated
Browse files Browse the repository at this point in the history
`filtertools.revco()` to handle multi-dimensional stimuli and filters.
  • Loading branch information
bnaecker committed Nov 1, 2016
1 parent 770d9e9 commit 503d6d0
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 31 deletions.
38 changes: 19 additions & 19 deletions pyret/filtertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
__all__ = ['getste', 'getsta', 'getstc', 'lowranksta', 'decompose',
'filterpeak', 'smooth', 'cutout', 'rolling_window', 'resample',
'get_ellipse', 'get_contours', 'get_regionprops',
'normalize_spatial']
'normalize_spatial', 'linear_prediction', 'revco']


def getste(time, stimulus, spikes, filter_length):
Expand Down Expand Up @@ -621,7 +621,7 @@ def linear_prediction(filt, stim):
return np.einsum(subscripts, slices, filt)


def revco(response, stimulus, filter_length, norm=False):
def revco(response, stimulus, filter_length):
"""
Compute the reverse-correlation between a stimulus and a response.
Expand All @@ -638,23 +638,20 @@ def revco(response, stimulus, filter_length, norm=False):
be one-dimensional.
stimulus : array_like
A input stimulus correlated with the ``response``. Must be
one-dimensional.
A input stimulus correlated with the ``response``. Must be of shape
(t, ...), where t is the time and ... indicates any spatial dimensions.
filter_length : int
The length of the returned filter, in samples of the ``stimulus`` and
``response`` arrays.
norm : bool [optional]
If True, normalize the computed filter to a unit vector. Defaults
to False.
Returns
-------
filt : array_like
An array of shape ``(filter_length,)`` containing the best-fitting
linear filter which predicts the response from the stimulus.
An array of shape ``(filter_length, ...)`` containing the best-fitting
linear filter which predicts the response from the stimulus. The ellipses
indicates spatial dimensions of the filter.
Raises
------
Expand All @@ -671,15 +668,18 @@ def revco(response, stimulus, filter_length, norm=False):
"""

if response.ndim > 1 or stimulus.ndim > 1:
raise ValueError("The `response` and `stimulus` must be 1-dimensional")
if response.shape != stimulus.shape:
raise ValueError("The `response` and `stimulus` must have the same shape")

filt = fftconvolve(response, stimulus[::-1], mode='full')
mid = int(filt.size / 2) + np.mod(filt.size, 2) # Account for odd-sized arrays
filt = filt[mid - filter_length : mid]
return filt / np.linalg.norm(filt) if norm else filt
if response.ndim > 1:
raise ValueError("The `response` must be 1-dimensional")
if response.size != (stimulus.shape[0] - filter_length):
raise ValueError(("`stimulus` must have {:#d} time points " +
"(`response.size` + `filter_length`").format(response.size + filter_length))

slices = slicestim(stimulus, filter_length)
dim_start = ord('i')
indices = ''.join(map(chr, range(dim_start, dim_start + slices.ndim)))
subscripts = '{0},{1}->{2}{3}'.format(indices, indices[1], indices[0], indices[2:])
recovered = np.einsum(subscripts, slices, response)
return recovered

def _gaussian_function(data, x0, y0, a, b, c):
"""
Expand Down
39 changes: 27 additions & 12 deletions tests/test_filtertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from pyret import filtertools as flt
from pyret.stimulustools import slicestim

import utils

def test_linear_prediction_one_dim():
"""Test method for computing linear prediction from a
def test_linear_prediction_1d():
"""Test method for computing linear prediction from a
filter to a one-dimensional stimulus.
"""
filt = np.random.randn(100,)
Expand All @@ -21,9 +22,8 @@ def test_linear_prediction_one_dim():
sl = slicestim(stim, filt.shape[0])
assert np.allclose(filt.reshape(1, -1).dot(sl), pred)


def test_linear_prediction_multi_dim():
"""Test method for computing linear prediction from a
def test_linear_prediction_nd():
"""Test method for computing linear prediction from a
filter to a multi-dimensional stimulus.
"""
for ndim in range(2, 4):
Expand All @@ -46,19 +46,34 @@ def test_linear_prediction_raises():
flt.linear_prediction(np.random.randn(10,), np.random.randn(10,2))
flt.linear_prediction(np.random.randn(10, 2), np.random.randn(10, 3))

def test_revco_1d():
"""Test computation of a 1D linear filter by reverse correlation"""
# Create fake filter, 100 time points
filter_length = 100
true = utils.create_temporal_filter(filter_length)

def test_revco():
"""Test computation of a linear filter by reverse correlation"""
# Compute linear response
stim_length = 10000
stimulus = np.random.randn(stim_length,)
response = flt.linear_prediction(true, stimulus)

# Reverse correlation
filt = flt.revco(response, stimulus, filter_length, norm=True)
tol = 0.1
assert np.allclose(true, filt, atol=tol)


def test_revco_nd():
"""Test computation of 3D linear filter by reverse correlation"""
# Create fake filter
filter_length = 100
x = np.linspace(0, 2 * np.pi, filter_length)
true = np.exp(-1. * x) * np.sin(x)
true /= np.linalg.norm(true)
nx, ny = 10, 10
true = utils.create_spatiotemporal_filter(nx, ny, filter_length)

# Compute linear response
stim_length = 10000
stimulus = np.random.randn(stim_length,)
response = np.convolve(stimulus, true, mode='full')[-stimulus.size:]
stimulus = np.random.randn(stim_length, nx, ny)
response = flt.linear_prediction(true, stimulus)

# Reverse correlation
filt = flt.revco(response, stimulus, filter_length, norm=True)
Expand Down
71 changes: 71 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""utils.py
Some general utilities used in various testing routines.
(C) 2016 The Baccus Lab
"""

import numpy as np

from pyret.filtertools import _gaussian_function as gaussian_function

def create_temporal_filter(n, norm=True):
"""Returns a fake temporal linear filter that superficially resembles
those seen in retinal ganglion cells.
Parameters
----------
n : int
Number of time points in the filter.
norm : bool [optional]
If True, normalize the filter to have unit 2-norm. Defaults to True.
Returns
-------
f : ndarray
The fake linear filter
"""
time_axis = np.linspace(0, 2 * np.pi, n)
filt = np.exp(-1. * time_axis) * np.sin(time_axis)
return filt / np.linalg.norm(filt) if norm else filt


def create_spatiotemporal_filter(nx, ny, nt, norm=True):
"""Returns a fake 3D spatiotemporal filter.
The filter is created as the outer product of a 2D gaussian with a fake
temporal filter as returned by `create_temporal_filter()`.
Parameters
----------
nx, ny : int
Number of points in the two spatial dimensions of the stimulus.
nt : int
Number of time points in the stimulus.
norm : bool [optional]
If True, normalize the filter to have unit 2-norm. Defaults to True.
Returns
-------
f : ndarray
The linear filter, shaped (nt, nx, ny)
"""
temporal_filter = create_temporal_filter(nt, norm)

grid = np.meshgrid(np.arange(nx), np.arange(ny), indexing='ij')
points = np.array([each.flatten() for each in grid])
gaussian = gaussian_function(points, int(ny / 2), int(nx / 2), 1, 0, 1).reshape(nx, ny)
if norm:
gaussian /= np.linalg.norm(gaussian)

# Outer product
filt = np.rollaxis(np.einsum('i,jk->jki', temporal_filter, gaussian), -1, 0)

return filt / np.linalg.norm(filt) if norm else filt


0 comments on commit 503d6d0

Please sign in to comment.