Skip to content

Commit

Permalink
Plot morpho improvements (#176)
Browse files Browse the repository at this point in the history
* Added a function to create an empty network figure

* Added `use_last_soma_comp` to use either first or last soma comp as soma

* reorder traces function & x-axis for plot_traces

* Use add_trace instead of deprecated append_trace

* get_soma_trace pass kwargs

* fix `use_last_soma_comp == False` offset and added soma_opacity

* stopped drawing width==0 branches

* Added plot_traces default titles and improved subplot_fig/fig copy

* fixed black
  • Loading branch information
Helveg committed Nov 16, 2020
1 parent e28bef3 commit a48b28d
Showing 1 changed file with 60 additions and 28 deletions.
88 changes: 60 additions & 28 deletions bsb/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np, math, functools
from .morphologies import Compartment
from contextlib import contextmanager
import random
import random, types


class CellTrace:
Expand Down Expand Up @@ -64,6 +64,11 @@ def __len__(self):
def order(self):
self.cells = dict(sorted(self.cells.items(), key=lambda t: t[1].order or 0))

def reorder(self, order):
for o, key in zip(iter(order), self.cells.keys()):
self.cells[key].order = o
self.order()


def _figure(f):
"""
Expand Down Expand Up @@ -140,7 +145,7 @@ def wrapper_function(
set_range=set_range,
swapaxes=swapaxes,
soma_radius=soma_radius,
**kwargs
**kwargs,
)
if set_range:
rng = get_morphology_range(morphology, offset=offset, soma_radius=soma_radius)
Expand Down Expand Up @@ -233,6 +238,11 @@ def plot_network(
return fig


@_network_figure
def network_figure(fig=None, **kwargs):
return fig


@_network_figure
def plot_detailed_network(
network, fig=None, cubic=True, swapaxes=True, show=True, legend=True, ids=None
Expand Down Expand Up @@ -338,20 +348,23 @@ def plot_voxel_cloud(


def get_branch_trace(compartments, offset=[0.0, 0.0, 0.0], color="black", width=1.0):
x = [c.start[0] + offset[0] for c in compartments]
y = [c.start[1] + offset[1] for c in compartments]
z = [c.start[2] + offset[2] for c in compartments]
# Add branch endpoint
x.append(compartments[-1].end[0] + offset[0])
y.append(compartments[-1].end[1] + offset[1])
z.append(compartments[-1].end[2] + offset[2])
if width == 0:
x, y, z = [], [], []
else:
x = [c.start[0] + offset[0] for c in compartments]
y = [c.start[1] + offset[1] for c in compartments]
z = [c.start[2] + offset[2] for c in compartments]
# Add branch endpoint
x.append(compartments[-1].end[0] + offset[0])
y.append(compartments[-1].end[1] + offset[1])
z.append(compartments[-1].end[2] + offset[2])
return go.Scatter3d(
x=x, y=z, z=y, mode="lines", line=dict(width=width, color=color), showlegend=False
)


def get_soma_trace(
soma_radius, offset=[0.0, 0.0, 0.0], color="black", opacity=1, steps=5
soma_radius, offset=[0.0, 0.0, 0.0], color="black", opacity=1, steps=5, **kwargs
):
phi = np.linspace(0, 2 * np.pi, num=steps * 2)
theta = np.linspace(-np.pi / 2, np.pi / 2, num=steps)
Expand All @@ -368,6 +381,7 @@ def get_soma_trace(
opacity=opacity,
color=color,
alphahull=0,
**kwargs,
)


Expand Down Expand Up @@ -413,7 +427,9 @@ def plot_morphology(
color="black",
reduce_branches=False,
soma_radius=None,
soma_opacity=1.0,
segment_radius=1.0,
use_last_soma_comp=True,
):
compartments = np.array(morphology.compartments.copy())
dfs_list = all_depth_first_branches(morphology.get_compartment_network())
Expand All @@ -429,11 +445,15 @@ def plot_morphology(
if isinstance(color, dict) and "soma" not in color:
raise Exception("Please specify a color for the `soma`.")
soma_color = color["soma"] if isinstance(color, dict) else color
soma_comps = [c for c in compartments if "soma" in c.labels]
# Negative bool = -1/0 (True: -1, last soma comp, False: 0, first soma comp)
soma_comp = soma_comps[-use_last_soma_comp]
traces.append(
get_soma_trace(
soma_radius if soma_radius is not None else compartments[0].radius,
offset,
soma_radius if soma_radius is not None else soma_comp.radius,
offset + (soma_comp.end if use_last_soma_comp else soma_comp.start),
soma_color,
opacity=soma_opacity,
)
)
for trace in traces:
Expand Down Expand Up @@ -529,7 +549,7 @@ def plotly_block_faces(
j=[3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
k=[0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
opacity=0.3,
**color_args
**color_args,
)


Expand Down Expand Up @@ -698,7 +718,7 @@ def hdf5_gdf_plot_spike_raster(spike_recorders, input_region=None, fig=None, sho
show=False,
color=colors[l],
input_region=input_region,
**kwargs
**kwargs,
)
if show:
fig.show()
Expand Down Expand Up @@ -737,7 +757,15 @@ def hdf5_gather_voltage_traces(handle, root, groups=None):
traces = CellTraceCollection()
for group in groups:
path = root + group
for name, dataset in handle[path].items():
# If an element of `groups` point to a single set, rather than a group
# catch the exception and construct a single element group from the single set
try:
iter = handle[path].items()
except AttributeError:
target = handle[path]
iter = ((group, target),)
path = root
for name, dataset in iter:
meta = {}
id = int(name.split(".")[0])
meta["id"] = id
Expand All @@ -751,28 +779,32 @@ def hdf5_gather_voltage_traces(handle, root, groups=None):

@_figure
@_input_highlight
def plot_traces(traces, fig=None, show=True, legend=True, mod=None, cutoff=0):
def plot_traces(traces, fig=None, show=True, legend=True, cutoff=0, x=None):
traces.order()
subplots_fig = make_subplots(
cols=1, rows=len(traces), subplot_titles=[trace.title for trace in traces]
cols=1,
rows=len(traces),
subplot_titles=[trace.title for trace in traces],
x_title="Time (ms)",
y_title="Membrane potential (mV)",
)
subplots_fig.update_layout(height=max(len(traces) * 130, 300))

if mod is not None:
mod(subplots_fig)
# Overwrite the layout and grid of the single plot that is handed to us
# to turn it into a subplots figure.
fig._grid_ref = subplots_fig._grid_ref
fig._layout = subplots_fig._layout
for k in dir(subplots_fig):
v = getattr(subplots_fig, k)
if isinstance(v, types.MethodType):
# Unbind subplots_fig methods and bind to fig.
v = v.__func__.__get__(fig)
fig.__dict__[k] = v
fig.update_layout(height=max(len(traces) * 130, 300))
legend_groups = set()
legends = traces.legends
for i, cell_traces in enumerate(traces):
for j, trace in enumerate(cell_traces):
showlegend = legends[j] not in legend_groups
trace.data = trace.data[cutoff:]
fig.append_trace(
data = trace.data[cutoff:]
fig.add_trace(
go.Scatter(
y=trace.data,
x=x,
y=data,
legendgroup=legends[j],
name=legends[j],
showlegend=showlegend,
Expand Down

0 comments on commit a48b28d

Please sign in to comment.