Skip to content

Commit

Permalink
Merge pull request #47 from brian-team/morphology_plot_improvements
Browse files Browse the repository at this point in the history
Morphology plot improvements
  • Loading branch information
mstimberg committed Dec 2, 2020
2 parents 18a1338 + 59d4964 commit 7dc3a7d
Show file tree
Hide file tree
Showing 7 changed files with 6,493 additions and 37 deletions.
169 changes: 133 additions & 36 deletions brian2tools/plotting/morphology.py
@@ -1,15 +1,20 @@
'''
Module to plot Brian `~brian2.spatialneuron.morphology.Morphology` objects.
'''
from typing import Mapping

import numpy as np

from matplotlib.colors import colorConverter, Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.patches import Circle, Polygon
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from brian2 import Unit, have_same_dimensions
from brian2.spatialneuron.spatialneuron import FlatMorphology
from brian2.units.stdunits import um
from brian2.units.fundamentalunits import fail_for_dimension_mismatch
from brian2.units.fundamentalunits import fail_for_dimension_mismatch, DIMENSIONLESS
from brian2.spatialneuron.morphology import Soma

__all__ = ['plot_morphology', 'plot_dendrogram']
Expand All @@ -20,7 +25,7 @@ def _plot_morphology2D(morpho, axes, colors,
voltage_colormap,
show_diameter=False, show_compartments=True,
color_counter=0):
if values:
if values is not None:
# Determine colors based on compartment values
normed_values = value_norm(values[morpho.indices[:]])
colors = voltage_colormap(normed_values)
Expand Down Expand Up @@ -72,10 +77,22 @@ def _plot_morphology2D(morpho, axes, colors,
colors=colors, color_counter=color_counter+1)


def _plot_morphology3D(morpho, figure, colors, show_diameters=True,
def _plot_morphology3D(morpho, figure, colors, values, value_norm,
value_colormap,
show_diameters=True,
show_compartments=False):
import mayavi.mlab as mayavi
colors = np.vstack(colorConverter.to_rgba(c) for c in colors)
if values is not None:
# calculate color for the soma
vmin, vmax = value_norm
if vmin is None:
vmin = min(values)
if vmax is None:
vmax = max(values)
normed_value = (values[0] - vmin)/(vmax - vmin)
colors = np.vstack(value_colormap([normed_value]))
else:
colors = np.vstack([colorConverter.to_rgba(c) for c in colors])
flat_morpho = FlatMorphology(morpho)
if isinstance(morpho, Soma):
start_idx = 1
Expand Down Expand Up @@ -111,10 +128,14 @@ def _plot_morphology3D(morpho, figure, colors, show_diameters=True,
points[::2, :] = start_points
points[1::2, :] = end_points
# Create the points at start and end of the compartments
if values is not None:
scatter_values = values[start_idx:].repeat(2)
else:
scatter_values = flat_morpho.depth[start_idx:].repeat(2)
src = mayavi.pipeline.scalar_scatter(points[:, 0],
points[:, 1],
points[:, 2],
flat_morpho.depth[start_idx:].repeat(2),
scatter_values,
scale_factor=1)
# Create the lines between compartments
connections = []
Expand All @@ -141,19 +162,27 @@ def _plot_morphology3D(morpho, figure, colors, show_diameters=True,
else:
tubes = mayavi.pipeline.tube(lines, tube_radius=1)
max_depth = max(flat_morpho.depth)
surf = mayavi.pipeline.surface(tubes, colormap='prism', line_width=1,
opacity=0.5,
vmin=0, vmax=max(flat_morpho.depth))
surf.module_manager.scalar_lut_manager.lut.number_of_colors = max_depth + start_idx
cmap = np.int_(np.round(255*colors[np.arange(max_depth + start_idx)%len(colors), :]))
if values is not None:
surf = mayavi.pipeline.surface(tubes, colormap='prism', line_width=1,
opacity=0.5, vmin=vmin, vmax=vmax)
surf.module_manager.scalar_lut_manager.lut.number_of_colors = 256
cmap = np.array(np.vstack(value_colormap(np.linspace(0., 1., num=256, endpoint=True)))*255.,
dtype=np.uint8)
else:
surf = mayavi.pipeline.surface(tubes, colormap='prism', line_width=1,
opacity=0.5,
vmin=0, vmax=max(flat_morpho.depth))
surf.module_manager.scalar_lut_manager.lut.number_of_colors = max_depth + start_idx
cmap = np.int_(np.round(255*colors[np.arange(max_depth + start_idx)%len(colors), :]))
surf.module_manager.scalar_lut_manager.lut.table = cmap
src.update()
return surf


def plot_morphology(morphology, plot_3d=None, show_compartments=False,
show_diameter=False, colors=('darkblue', 'darkred'),
values=None, value_norm=(None, None), value_colormap='hot',
axes=None):
value_colorbar=True, value_unit=None, axes=None):
'''
Plot a given `~brian2.spatialneuron.morphology.Morphology` in 2D or 3D.
Expand Down Expand Up @@ -189,6 +218,18 @@ def plot_morphology(morphology, plot_3d=None, show_compartments=False,
value_colormap : str or matplotlib.colors.Colormap, optional
Desired colormap for plots. Either the name of a standard colormap
or a `.matplotlib.colors.Colormap` instance. Defaults to ``'hot'``.
Note that this uses ``matplotlib`` color maps even for 3D plots with
Mayavi.
value_colorbar : bool or dict, optional
Whether to add a colorbar for the ``values``. Defaults to ``True``,
but will be ignored if no ``values`` are provided. Can also be a
dictionary with the keyword arguments for matplotlib's
`~.matplotlib.figure.Figure.colorbar` method (2D plot), or for
Mayavi's `~.mayavi.mlab.scalarbar` method (3D plot).
value_unit : `Unit`, optional
A `Unit` to rescale the values for display in the colorbar. Does not
have any visible effect if no colorbar is used. If not specified, will
try to determine the "best unit" to itself.
axes : `~matplotlib.axes.Axes` or `~mayavi.core.api.Scene`, optional
A matplotlib `~matplotlib.axes.Axes` (for 2D plots) or mayavi
`~mayavi.core.api.Scene` ( for 3D plots) instance, where the plot will
Expand All @@ -204,11 +245,59 @@ def plot_morphology(morphology, plot_3d=None, show_compartments=False,
# Avoid circular import issues
from brian2tools.plotting.base import (_setup_axes_matplotlib,
_setup_axes_mayavi)

if plot_3d is None:
# Decide whether to use 2d or 3d plotting based on the coordinates
flat_morphology = FlatMorphology(morphology)
plot_3d = any(np.abs(flat_morphology.z) > 1e-12)

if values is not None:
if hasattr(values, 'name'):
value_varname = values.name
else:
value_varname = 'values'
if value_unit is not None:
if not isinstance(value_unit, Unit):
raise TypeError(f'\'value_unit\' has to be a unit but is'
f'\'{type(value_unit)}\'.')
fail_for_dimension_mismatch(value_unit, values,
'The \'value_unit\' arguments needs '
'to have the same dimensions as '
'the \'values\'.')
else:
if have_same_dimensions(values, DIMENSIONLESS):
value_unit = 1.
else:
value_unit = values[:].get_best_unit()
orig_values = values
values = values/value_unit
if isinstance(value_norm, tuple):
if not len(value_norm) == 2:
raise TypeError('Need a (vmin, vmax) tuple for the value '
'normalization, but got a tuple of length '
f'{len(value_norm)}.')
vmin, vmax = value_norm
if vmin is not None:
err_msg = ('The minimum value in \'value_norm\' needs to '
'have the same units as \'values\'.')
fail_for_dimension_mismatch(vmin, orig_values,
error_message=err_msg)
vmin /= value_unit
if vmax is not None:
err_msg = ('The maximum value in \'value_norm\' needs to '
'have the same units as \'values\'.')
fail_for_dimension_mismatch(vmax, orig_values,
error_message=err_msg)
vmax /= value_unit
if plot_3d:
value_norm = (vmin, vmax)
else:
value_norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
value_norm.autoscale_None(values)
elif plot_3d:
raise TypeError('3d plots only support normalizations given by '
'a (min, max) tuple.')
value_colormap = plt.get_cmap(value_colormap)

if plot_3d:
try:
Expand All @@ -217,42 +306,50 @@ def plot_morphology(morphology, plot_3d=None, show_compartments=False,
raise ImportError('3D plotting needs the mayavi library')
axes = _setup_axes_mayavi(axes)
axes.scene.disable_render = True
_plot_morphology3D(morphology, axes, colors=colors,
show_diameters=show_diameter,
show_compartments=show_compartments)
surf = _plot_morphology3D(morphology, axes, colors=colors,
values=values, value_norm=value_norm,
value_colormap=value_colormap,
show_diameters=show_diameter,
show_compartments=show_compartments)
if values is not None and value_colorbar:
if not isinstance(value_colorbar, Mapping):
value_colorbar = {}
if not have_same_dimensions(value_unit, DIMENSIONLESS):
unit_str = f' ({value_unit!s})'
else:
unit_str = ''
if value_varname:
value_colorbar['title'] = f'{value_varname}{unit_str}'
cb = mayavi.scalarbar(surf, **value_colorbar)
# Make text dark gray
cb.title_text_property.color = (0.1, 0.1, 0.1)
cb.label_text_property.color = (0.1, 0.1, 0.1)
axes.scene.disable_render = False
else:
axes = _setup_axes_matplotlib(axes)

if values is not None:
if isinstance(value_norm, tuple):
if not len(value_norm) == 2:
raise TypeError('Need a (vmin, vmax) tuple for the value '
'normalization, but got a tuple of length '
f'{len(value_norm)}.')
vmin, vmax = value_norm
if vmin is not None:
err_msg = ('The minimum value in \'value_norm\' needs to '
'have the same units as \'values\'.')
fail_for_dimension_mismatch(vmin, values,
error_message=err_msg)
if vmax is not None:
err_msg = ('The maximum value in \'value_norm\' needs to '
'have the same units as \'values\'.')
fail_for_dimension_mismatch(vmax, values,
error_message=err_msg)
value_norm = Normalize(vmin=vmin, vmax=vmax, clip=True)
value_norm.autoscale_None(values)
value_colormap = plt.get_cmap(value_colormap)

_plot_morphology2D(morphology, axes, colors,
values, value_norm, value_colormap,
show_compartments=show_compartments,
show_diameter=show_diameter)
axes.set_xlabel('x (um)')
axes.set_ylabel('y (um)')
axes.set_aspect('equal')

if values is not None and value_colorbar:
divider = make_axes_locatable(axes)
cax = divider.append_axes("right", size="5%", pad=0.1)
mappable = ScalarMappable(norm=value_norm, cmap=value_colormap)
mappable.set_array([])
fig = axes.get_figure()
if not isinstance(value_colorbar, Mapping):
value_colorbar = {}
if not have_same_dimensions(value_unit, DIMENSIONLESS):
unit_str = f' ({value_unit!s})'
else:
unit_str = ''
if value_varname:
value_colorbar['label'] = f'{value_varname}{unit_str}'
fig.colorbar(mappable, cax=cax, **value_colorbar)
return axes


Expand Down
61 changes: 60 additions & 1 deletion brian2tools/tests/test_plotting.py
Expand Up @@ -5,11 +5,12 @@
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pytest

# We avoid "from brian2 import *", as this would also import Brian's test
# function which will then be collected by py.test
from brian2 import (NeuronGroup, SpikeMonitor, PopulationRateMonitor,
StateMonitor, Synapses, run, set_device)
StateMonitor, Synapses, run, set_device, SpatialNeuron, DimensionMismatchError, meter)
from brian2 import Cylinder, Soma, Section
from brian2 import ms, mV, um

Expand Down Expand Up @@ -123,6 +124,64 @@ def test_plot_morphology():
ax = plot_morphology(morpho)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()
ax = plot_morphology(morpho, show_diameter=True)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()
ax = plot_morphology(morpho, show_compartments=True)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()
ax = plot_morphology(morpho, show_diameter=True,
show_compartments=True)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()

def test_plot_morphology_values():
set_device('runtime')
# Only testing 2D plotting for now
morpho = Soma(diameter=30*um)
morpho.axon = Cylinder(diameter=10*um, n=10, length=100*um)
morpho.dend = Section(diameter=np.linspace(10, 1, 11)*um, n=10,
length=np.ones(10)*5*um)
morpho = morpho.generate_coordinates()

neuron = SpatialNeuron(morpho, 'Im = 0*amp/meter**2 : amp/meter**2')

# Just checking whether the plotting does not fail with an error and that
# it retuns an Axis object as promised
ax = plot_morphology(morpho, values=neuron.distance, plot_3d=False)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()

ax = plot_morphology(morpho, values=neuron.distance,
show_diameter=True, plot_3d=False)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()

ax = plot_morphology(morpho, values=neuron.distance,
show_compartments=True, plot_3d=False)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()

ax = plot_morphology(morpho, values=neuron.distance,
show_diameter=True,
show_compartments=True, plot_3d=False)
assert isinstance(ax, matplotlib.axes.Axes)
plt.close()

# Check a few wrong usages
with pytest.raises(DimensionMismatchError):
plot_morphology(morpho, values=neuron.distance, value_unit=mV, plot_3d=False)
with pytest.raises(DimensionMismatchError):
plot_morphology(morpho, values=neuron.distance, value_norm=(-65*mV, None),
plot_3d=False)
with pytest.raises(DimensionMismatchError):
plot_morphology(morpho, values=neuron.distance, value_norm=(None, -60*mV),
plot_3d=False)
with pytest.raises(TypeError):
plot_morphology(morpho, values=neuron.distance, value_norm=(0*meter,
1*meter,
2*meter),
plot_3d=False)


if __name__ == '__main__':
Expand Down
24 changes: 24 additions & 0 deletions dev/doc_tools/plotting_examples/morphology_plots.py
Expand Up @@ -19,6 +19,30 @@
plot_morphology(morpho, plot_3d=False)
savefig(os.path.join(fig_dir, 'plot_morphology_2d.svg'))
close()

# Value plots
neuron = SpatialNeuron(morpho, 'Im = 0*amp/meter**2 : amp/meter**2')

figure()
plot_morphology(neuron.morphology, values=neuron.distance,
plot_3d=False)
savefig(os.path.join(fig_dir, 'plot_morphology_values_2d.svg'))
close()

figure()
plot_morphology(neuron.morphology, values=neuron.distance,
value_norm=(50*um, 200*um),
value_colormap='viridis', value_unit=mm,
value_colorbar={'label': 'distance from soma in mm',
'extend': 'both'},
plot_3d=False)
savefig(os.path.join(fig_dir, 'plot_morphology_values_2d_custom.svg'))
close()

plot_morphology(morpho, plot_3d=True, show_compartments=True,
show_diameter=True, colors=('darkblue',))
mayavi.show()

neuron = SpatialNeuron(morpho, 'Im = 0*amp/meter**2 : amp/meter**2')
plot_morphology(morpho, values=neuron.distance, plot_3d=True)
mayavi.show()

0 comments on commit 7dc3a7d

Please sign in to comment.