Skip to content

Commit

Permalink
Merge pull request #78 from baccuslab/html
Browse files Browse the repository at this point in the history
Renames some plotting functions, better handling of low-rank STA components, and more. Bumps version to v0.5.0.
  • Loading branch information
bnaecker committed Nov 18, 2016
2 parents edfd31f + 0d505ae commit 59a0405
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Changelog

A list of new features, improvements, and bug-fixes in each release.

.. include:: releases/v0.5.rst
.. include:: releases/v0.4.rst
.. include:: releases/v0.3.rst
.. include:: releases/v0.2.rst
2 changes: 1 addition & 1 deletion docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ the *spike-triggered average* (STA) for the cell.
>>> filter_length_seconds = 0.5 # 500 ms filter
>>> filter_length = int(filter_length_second / data_file['stimulus'].attrs.get('frame-rate'))
>>> sta, tax = pyret.filtertools.sta(time, stimulus, spikes, filter_length)
>>> fig, axes = pyret.visualizations.plotsta(tax[::-1], sta)
>>> fig, axes = pyret.visualizations.plot_sta(tax[::-1], sta)
>>> axes[0].set_title('Recovered spatial filter (STA)')
>>> axes[1].set_title('Recovered temporal filter (STA)')
>>> axes[1].set_xlabel('Time before spike (s)')
Expand Down
31 changes: 31 additions & 0 deletions docs/releases/v0.5.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
==================
v0.5 (17 Nov 2016)
==================

New features
------------
- Better handling of low-rank STA component signs in ``filtertools.lowranksta``.
- Functionality for embedding STA animations into HTML, via ``visualizations.anim_to_html()``.
- New classes for estimating nonlinearities: ``Binterp``, ``Sigmoid`` and
``GaussianProcess``. These follow the ``scikit-learn`` interface, meaning
they have ``fit()`` and ``predict()`` methods, which return ``self``.

API changes
-----------
- Renamed ``filtertools.getsta`` -> ``filtertools.sta``
- Renamed ``filtertools.getste`` -> ``filtertools.ste``
- Renamed ``filtertools.getstc`` -> ``filtertools.stc``
- Renamed ``visualizations.rasterandpsth`` -> ``visualizations.raster_and_psth``
- Renamed ``visualizations.plotcells`` -> ``visualizations.plot_cells``
- Renamed ``visualizations.plotsta`` -> ``visualizations.plot_sta``
- Renamed ``visualizations.playrates`` -> ``visualizations.play_rates``
- Renamed ``visualizations.playsta`` -> ``visualizations.play_sta``
- ``spiketools.binspikes`` and ``spiketools.estfr`` no longer return the time axis. Only the
binned spikes and firing rate are returned, respectively.
- Removed ``containers`` module.
- ``filtertools.rolling_window`` has been moved to the ``stimulustools`` module,
and is renamed ``slicestim``. ``rolling_window`` is an alias for ``slicestim``,
for the time being, which raises a warning about future deprecation.
- Renamed ``stimulustools.stimcov`` -> ``stimulustools.cov``.
- Renamed ``stimulustools.upsample_stim`` -> ``stimulustools.upsample``.
- Renamed ``stimulustools.downsample_stim`` -> ``stimulustools.downsample``.
38 changes: 32 additions & 6 deletions pyret/visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from matplotlib.patches import Ellipse

__all__ = ['raster', 'psth', 'raster_and_psth', 'spatial', 'temporal',
'plotsta', 'playsta', 'ellipse', 'plotcells', 'playrates']
'plot_sta', 'play_sta', 'ellipse', 'plot_cells', 'play_rates']


@plotwrapper
Expand Down Expand Up @@ -204,7 +204,7 @@ def raster_and_psth(spikes, trial_length=None, binsize=0.01, **kwargs):
tick.set_color('k')


def playsta(sta, repeat=True, frametime=100, cmap='seismic_r', clim=None):
def play_sta(sta, repeat=True, frametime=100, cmap='seismic_r', clim=None):
"""
Plays a spatiotemporal spike-triggered average as a movie.
Expand Down Expand Up @@ -357,7 +357,7 @@ def temporal(time, filt, **kwargs):
kwargs['ax'].plot(time, temporal_filter, linestyle='-', linewidth=2, color='LightCoral')


def plotsta(time, sta):
def plot_sta(time, sta):
"""
Plot a linear filter.
Expand Down Expand Up @@ -496,7 +496,7 @@ def ellipse(filt, sigma=2.0, alpha=0.8, fc='none', ec='black', lw=3, **kwargs):


@plotwrapper
def plotcells(cells, **kwargs):
def plot_cells(cells, **kwargs):
"""
Plot the spatial receptive fields for multiple cells.
Expand Down Expand Up @@ -530,11 +530,11 @@ def plotcells(cells, **kwargs):
fig, ax = ellipse(sp, fc=color, ec=color, lw=2, alpha=0.3, ax=ax)


def playrates(rates, patches, num_levels=255, time=None, repeat=True, frametime=100):
def play_rates(rates, patches, num_levels=255, time=None, repeat=True, frametime=100):
"""
Plays a movie representation of the firing rate of a list of cells, by
coloring a list of patches with a color proportional to the firing rate. This
is useful, for example, in conjunction with ``plotcells``, to color the
is useful, for example, in conjunction with ``plot_cells``, to color the
ellipses fitted to a set of receptive fields proportional to the firing rate.
Parameters
Expand Down Expand Up @@ -581,3 +581,29 @@ def animate(t):
np.arange(T), interval=frametime, repeat=repeat)
return anim

def anim_to_html(anim):
"""
Convert an animation into an embedable HTML element.
This converts the animation objects returned by ``play_sta()`` and
``play_rates()`` into an HTML tag that can be embedded, for example
in a Jupyter notebook.
Paramters
---------
anim : matplotlib.animation.Animation
The animation object to embed.
Returns
-------
html : IPython.display.HTML
An HTML object with the encoded video. This can be directly embedded
into an IPython notebook.
Raises
------
An ImportError is raised if the IPython modules required to convert the
animation are not installed.
"""
from IPython.display import HTML
return HTML(anim.to_html5_video())
42 changes: 30 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,38 @@
author='Benjamin Naecker, Niru Maheshwaranathan',
author_email='bnaecker@stanford.edu',
url='https://github.com/baccuslab/pyret',
requires=['scipy', 'skimage', 'numpy', 'matplotlib'],
requires=[
'numpy',
'scipy',
'matplotlib'
'sklearn',
'skimage',
],
long_description='''
The pyret package contains tools for analyzing neural
data. In particular, it contains methods for manipulating
spike trains (such as binning and smoothing), computing
spike-triggered averages and ensembles, computing nonlinearities,
as well as a suite of visualization tools.
''',
The pyret package contains tools for analyzing neural
data. In particular, it contains methods for manipulating
spike trains (such as binning and smoothing), computing
spike-triggered averages and ensembles, computing nonlinearities,
as well as a suite of visualization tools.
''',
classifiers=[
'Intended Audience :: Science/Research',
'Operating System :: MacOS :: MacOS X',
'Topic :: Scientific/Engineering :: Information Analysis'
],
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering :: Information Analysis',
'Topic :: Scientific/Engineering :: Visualization',
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3 :: Only'
],
packages=find_packages(),
install_requires=['numpy', 'scipy', 'matplotlib', 'scikit-image'],
install_requires=[
'numpy',
'scipy',
'matplotlib',
'scikit-image',
'scikit-learn'
],
license='MIT',
extras_require={
'html' : ['jupyter>=1.0']
}
)
35 changes: 25 additions & 10 deletions tests/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_spatiotemporal_filter():

# Test plotting temporal component
filename = os.path.join(IMG_DIR, 'test-temporal-filter.png')
viz.plotsta(time, t)
viz.plot_sta(time, t)
plt.savefig(filename)
assert not compare_images(
os.path.join(IMG_DIR, 'baseline-temporal-from-spatiotemporal-filter.png'),
Expand All @@ -64,7 +64,7 @@ def test_spatiotemporal_filter():

# Test plotting spatial component
filename = os.path.join(IMG_DIR, 'test-temporal-filter.png')
viz.plotsta(time, s)
viz.plot_sta(time, s)
plt.savefig(filename)
assert not compare_images(
os.path.join(IMG_DIR, 'baseline-spatial-from-spatiotemporal-filter.png'),
Expand All @@ -74,7 +74,7 @@ def test_spatiotemporal_filter():

# Test plotting both spatial/temporal components
filename = os.path.join(IMG_DIR, 'test-full-spatiotemporal-filter.png')
viz.plotsta(time, sta)
viz.plot_sta(time, sta)
plt.savefig(filename)
assert not compare_images(
os.path.join(IMG_DIR, 'baseline-full-spatiotemporal-filter.png'), filename, 1)
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_raster_and_psth():
plt.close('all')


def test_playsta():
def test_play_sta():
"""Test playing an STA as a movie.
Matplotlib doesn't yet have a way to compare movies, and the formats
Expand All @@ -129,7 +129,7 @@ def test_playsta():
"""
nx, ny, nt = 10, 10, 50
sta = utils.create_spatiotemporal_filter(nx, ny, nt)[-1]
anim = viz.playsta(sta)
anim = viz.play_sta(sta)
filename = os.path.join(IMG_DIR, 'test-sta-movie.png')
frame = 10
anim._func(frame)
Expand All @@ -153,7 +153,7 @@ def test_ellipse():
plt.close('all')


def test_plotcells():
def test_plot_cells():
"""Test plotting ellipses for multiple cells on the same axes."""
nx, ny, nt = 10, 10, 50
stas = []
Expand All @@ -162,16 +162,16 @@ def test_plotcells():
stas.append(utils.create_spatiotemporal_filter(nx, ny, nt)[-1])

filename = os.path.join(IMG_DIR, 'test-plotcells.png')
np.random.seed(0) # plotcells() uses random colors for each cell
viz.plotcells(stas)
np.random.seed(0) # plot_cells() uses random colors for each cell
viz.plot_cells(stas)
plt.savefig(filename)
assert not compare_images(
os.path.join(IMG_DIR, 'baseline-plotcells.png'), filename, 1)
os.remove(filename)
plt.close('all')


def test_playrates():
def test_play_rates():
"""Test playing firing rates for cells as a movie."""
nx, ny, nt = 10, 10, 50
sta = utils.create_spatiotemporal_filter(nx, ny, nt)[-1]
Expand All @@ -183,7 +183,7 @@ def test_playrates():
# Plot cell
fig, axes = viz.ellipse(sta)
patch = plt.findobj(axes, Ellipse)[0]
anim = viz.playrates(rate, patch)
anim = viz.play_rates(rate, patch)
filename = os.path.join(IMG_DIR, 'test-rates-movie.png')
frame = 10
anim._func(frame)
Expand All @@ -193,3 +193,18 @@ def test_playrates():
filename, 1)
os.remove(filename)
plt.close('all')


def test_anim_to_html():
"""Test converting an animation to HTML."""
try:
from IPython.display import HTML
except ImportError:
pytest.skip('Cannot convert movie to HTML without IPython.')

nx, ny, nt = 10, 10, 50
sta = utils.create_spatiotemporal_filter(nx, ny, nt)[-1]
anim = viz.play_sta(sta)
html = viz.anim_to_html(viz.play_sta(sta))
assert isinstance(html, HTML)

0 comments on commit 59a0405

Please sign in to comment.