Skip to content

Commit

Permalink
Merge pull request #4 from brian-team/better_connection_plots
Browse files Browse the repository at this point in the history
Better connection plots
  • Loading branch information
mstimberg committed May 10, 2016
2 parents 6797867 + e1f5c99 commit 1e6d6e4
Show file tree
Hide file tree
Showing 15 changed files with 4,245 additions and 18,445 deletions.
15 changes: 14 additions & 1 deletion brian2tools/plotting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Base module for the plotting facilities.
'''
import matplotlib.pyplot as plt
import numpy as np

from brian2.spatialneuron.morphology import Morphology
from brian2.monitors import SpikeMonitor, StateMonitor, PopulationRateMonitor
Expand Down Expand Up @@ -95,7 +96,19 @@ def brian_plot(brian_obj,
'arguments, ignoring them.')
plot_dendrogram(brian_obj, axes=axes)
elif isinstance(brian_obj, Synapses):
plot_synapses(brian_obj.i, brian_obj.j, axes=axes)
if len(brian_obj) == 0:
raise TypeError('Synapses object does not have any synapses.')
min_sources, max_sources = np.min(brian_obj.i[:]), np.max(brian_obj.i[:])
min_targets, max_targets = np.min(brian_obj.j[:]), np.max(brian_obj.j[:])
source_range = max_sources - min_sources
target_range = max_targets - min_targets
if source_range < 1000 and target_range < 1000:
plot_type = 'image'
elif len(brian_obj) < 10000:
plot_type = 'scatter'
else:
plot_type = 'hexbin'
plot_synapses(brian_obj.i, brian_obj.j, plot_type=plot_type, axes=axes)
else:
raise NotImplementedError('Do not know how to plot object of type '
'%s' % type(brian_obj))
106 changes: 85 additions & 21 deletions brian2tools/plotting/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def plot_synapses(sources, targets, values=None, var_unit=None,
var_name=None, axes=None, **kwds):
var_name=None, plot_type='scatter', axes=None, **kwds):
'''
Parameters
----------
Expand All @@ -31,14 +31,25 @@ def plot_synapses(sources, targets, values=None, var_unit=None,
find a good scale automatically based on the ``values``.
var_name : str, optional
The name of the variable that is plotted. Used for the axis label.
plot_type : {``'scatter'``, ``'image'``, ``'hexbin'``}, optional
What type of plot to use. Can be ``'scatter'`` (the default) to draw
a scatter plot, ``'image'`` to display the connections as a matrix or
``'hexbin'`` to display a 2D histogram using matplotlib's
`~matplotlib.axes.Axes.hexbin` function.
For a large number of synapses, ``'scatter'`` will be very slow.
Similarly, an ``'image'`` plot will use a lot of memory for connections
between two large groups. For a small number of neurons and synapses,
``'hexbin'`` will be hard to interpret.
axes : `~matplotlib.axes.Axes`, optional
The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
``None`` which means that a new `~matplotlib.axes.Axes` will be
created for the plot.
kwds : dict, optional
Any additional keywords command will be handed over to matplotlib's
`~matplotlib.axes.Axes.scatter` command. This can be used to set plot
properties such as the ``marker``.
Any additional keywords command will be handed over to the respective
matplotlib command (`~matplotlib.axes.Axes.scatter` if the
``plot_type`` is ``'scatter'``, `~matplotlib.axes.Axes.imshow` for
``'image'``, and `~matplotlib.axes.Axes.hexbin` for ``'hexbin'``).
This can be used to set plot properties such as the ``marker``.
Returns
-------
Expand All @@ -51,9 +62,15 @@ def plot_synapses(sources, targets, values=None, var_unit=None,
from brian2tools.plotting.base import _setup_axes_matplotlib
axes = _setup_axes_matplotlib(axes)

sources = np.asarray(sources)
targets = np.asarray(targets)
if not len(sources) == len(targets):
raise TypeError('Length of sources and targets does not match.')

if plot_type not in ['scatter', 'image', 'hexbin']:
raise ValueError("plot_type has to be either 'scatter', 'image', or "
"'hexbin' (was: %r)" % plot_type)

# Get some information out of the values if provided
if values is not None:
if len(values) != len(sources):
Expand All @@ -69,23 +86,51 @@ def plot_synapses(sources, targets, values=None, var_unit=None,
if var_unit is not None:
values = values / var_unit

connection_count = Counter(zip(sources, targets))
multiple_synapses = np.any(np.array(list(connection_count.values())) > 1)
if plot_type != 'hexbin':
# For "hexbin", we are binning multiple synapses anyway, so we don't
# have to make a difference for multiple synapses
connection_count = Counter(zip(sources, targets))
multiple_synapses = np.any(np.array(list(connection_count.values())) > 1)

edgecolor = kwds.pop('edgecolor', 'none')

if multiple_synapses:
if plot_type != 'hexbin' and multiple_synapses:
if values is not None:
raise NotImplementedError('Plotting variables with multiple '
'synapses per source-target pair is not '
'implemented yet.')
raise NotImplementedError("Plotting variables with multiple "
"synapses per source-target pair is only "
"implemented for 'hexbin' plots.")
unique_sources, unique_targets = zip(*connection_count.keys())
n_synapses = list(connection_count.values())
cmap = mpl.cm.get_cmap(kwds.pop('cmap', 'Accent'), max(n_synapses))
bounds = np.arange(max(n_synapses) + 1) + 0.5
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
axes.scatter(unique_sources, unique_targets, c=n_synapses,
edgecolor=edgecolor, cmap=cmap, **kwds)

if plot_type == 'scatter':
marker = kwds.pop('marker', ',')
cmap = mpl.cm.get_cmap(kwds.pop('cmap', None), max(n_synapses))
bounds = np.arange(max(n_synapses) + 1) + 0.5
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
axes.scatter(unique_sources, unique_targets, marker=marker,
c=n_synapses, edgecolor=edgecolor, cmap=cmap, norm=norm,
**kwds)
elif plot_type == 'image':
assert np.max(n_synapses) < 256
full_matrix = np.zeros((np.max(unique_sources) - np.min(unique_sources) + 1,
np.max(unique_targets) - np.min(unique_targets) + 1),
dtype=np.uint8)
full_matrix[unique_sources - np.min(unique_sources),
unique_targets - np.min(unique_targets)] = n_synapses
cmap = mpl.cm.get_cmap(kwds.pop('cmap', None),
max(n_synapses) + 1)
bounds = np.arange(max(n_synapses) + 2) - 0.5
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
origin = kwds.pop('origin', 'lower')
axes.imshow(full_matrix, origin=origin, cmap=cmap, norm=norm,
**kwds)
elif plot_type == 'hexbin':
cmap = mpl.cm.get_cmap(kwds.pop('cmap', None),
max(n_synapses) + 1)
bounds = np.arange(max(n_synapses) + 2) - 0.5
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
axes.hexbin(sources, targets, cmap=cmap, norm=norm, **kwds)

locatable_axes = make_axes_locatable(axes)
cax = locatable_axes.append_axes('right', size='5%', pad=0.05)
mpl.colorbar.ColorbarBase(cax, cmap=cmap,
Expand All @@ -94,14 +139,33 @@ def plot_synapses(sources, targets, values=None, var_unit=None,
spacing='proportional')
cax.set_ylabel('number of synapses')
else:
if values is None:
axes.scatter(sources, targets, edgecolor=edgecolor, **kwds)
else:
s = axes.scatter(sources, targets, c=values, edgecolor=edgecolor,
**kwds)
if plot_type == 'scatter':
marker = kwds.pop('marker', ',')
color = kwds.pop('color', values if values is not None else 'none')
plotted = axes.scatter(sources, targets, marker=marker, c=color,
edgecolor=edgecolor, **kwds)
elif plot_type == 'image':
full_matrix = np.zeros((np.max(sources) - np.min(sources) + 1,
np.max(targets) - np.min(targets) + 1))
if values is not None:
full_matrix[sources-np.min(sources), targets-np.min(targets)] = values
else:
full_matrix[sources - np.min(sources), targets - np.min(targets)] = 1
origin = kwds.pop('origin', 'lower')
interpolation = kwds.pop('interpolation', 'nearest')
if values is None:
vmin = kwds.pop('vmin', 0)
else:
vmin = kwds.pop('vmin', None)
plotted = axes.imshow(full_matrix, origin=origin, interpolation=interpolation,
vmin=vmin, **kwds)
elif plot_type == 'hexbin':
plotted = axes.hexbin(sources, targets, C=values, **kwds)

if values is not None or plot_type == 'hexbin':
locatable_axes = make_axes_locatable(axes)
cax = locatable_axes.append_axes('right', size='7.5%', pad=0.05)
plt.colorbar(s, cax=cax)
plt.colorbar(plotted, cax=cax)
if var_name is None:
if var_unit is not None:
cax.set_ylabel('in units of %s' % str(var_unit))
Expand Down
18 changes: 18 additions & 0 deletions brian2tools/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
'''
import matplotlib
matplotlib.use('Agg')

from brian2 import *
from brian2tools import *

Expand Down Expand Up @@ -50,12 +51,29 @@ def test_plot_synapses():
close()
plot_synapses(synapses.i, synapses.j)
close()
plot_synapses(synapses.i, synapses.j, plot_type='scatter')
close()
plot_synapses(synapses.i, synapses.j, plot_type='image')
close()
plot_synapses(synapses.i, synapses.j, plot_type='hexbin')
close()
plot_synapses(synapses.i, synapses.j, synapses.w)
close()
plot_synapses(synapses.i, synapses.j, synapses.w, plot_type='scatter')
close()
plot_synapses(synapses.i, synapses.j, synapses.w, plot_type='image')
close()
plot_synapses(synapses.i, synapses.j, synapses.w, plot_type='hexbin')
close()

synapses.connect('i > 5') # More than one synapse per connection
brian_plot(synapses)
close()
# It should be possible to plot synaptic variables for multiple connections
# with hexbin
plot_synapses(synapses.i, synapses.j, synapses.w, plot_type='hexbin')
close()


def test_plot_morphology():
# Only testing 2D plotting for now
Expand Down
22 changes: 16 additions & 6 deletions dev/doc_tools/plotting_examples/synapses_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from brian2tools import *
fig_dir = '../../../docs_sphinx/images'
brian_plot(synapses)
plt.savefig(os.path.join(fig_dir, 'brian_plot_synapses.svg'))
plt.savefig(os.path.join(fig_dir, 'brian_plot_synapses.png'))
close()

plot_synapses(synapses.i, synapses.j, color='gray', marker='s')
plot_synapses(synapses.i, synapses.j, plot_type='scatter', color='gray', marker='s')
plt.savefig(os.path.join(fig_dir, 'plot_synapses_connections.svg'))
close()

Expand All @@ -35,13 +35,23 @@
close()

ax = plot_synapses(synapses.i, synapses.j, synapses.w, var_name='synaptic weights',
marker='s', cmap='hot')
ax.set_axis_bgcolor('gray')
plot_type='image', cmap='hot')
ax.set_title('Recurrent connections')
synapses.connect(j='i+k for k in sample(-10, 10, p=0.5) if k != 0',
skip_if_invalid=True) # ignore values outside of the limits
plt.savefig(os.path.join(fig_dir, 'plot_synapses_weights_custom.svg'))
plt.savefig(os.path.join(fig_dir, 'plot_synapses_weights_custom.png'))
close()

brian_plot(synapses)
plt.savefig(os.path.join(fig_dir, 'brian_plot_multiple_synapses.svg'))
plt.savefig(os.path.join(fig_dir, 'brian_plot_multiple_synapses.png'))
close()

big_group = NeuronGroup(10000, '')
many_synapses = Synapses(big_group, big_group)
many_synapses.connect(j='i+k for k in range(-2000, 2000) if rand() < exp(-(k/1000.)**2)',
skip_if_invalid=True)
brian_plot(many_synapses)
plt.savefig(os.path.join(fig_dir, 'brian_plot_synapses_big.png'))
close()


3 changes: 2 additions & 1 deletion docs_sphinx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@
# Create api docs
def run_apidoc(_):
import sphinx.apidoc as apidoc
apidoc.main(argv=['sphinx-apidoc', '-f', '-e', '-M', '-o', './reference', '../brian2tools'])
apidoc.main(argv=['sphinx-apidoc', '-f', '-e', '-M', '-o', './reference',
'../brian2tools', '../brian2tools/tests'])

def setup(app):
app.connect('builder-inited', run_apidoc)
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 1e6d6e4

Please sign in to comment.