Skip to content

Commit

Permalink
The hdf5_plot_spike_raster function now takes 2 sorting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Helveg committed Jan 9, 2021
1 parent 78d9608 commit 06aa8df
Showing 1 changed file with 48 additions and 39 deletions.
87 changes: 48 additions & 39 deletions bsb/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,17 +628,30 @@ def hdf5_plot_spike_raster(
input_region=None,
show=True,
cutoff=0,
sorted_labels=None,
sorted_ids=None,
cell_type_sort=None,
cell_sort=None,
):
"""
Create a spike raster plot from an HDF5 group of spike recorders.
sorted_labels can be specified to plot population rasters ordered from bottom to top as in the given list.
:param input_region: Specifies an interval ``[min, max]`` on the x axis to highlight
as active input to the simulation.
:type input_region: 2-element list-like
:param show: Immediately plot the result
:type show: bool
:param cutoff: Amount of ms initial simulation to ignore.
:type cutoff: float
:param cell_type_sort: A function to sort the cell types. Must take 2 dictionaries
as arguments, being the raster plot's x values per label and y values per label.
Must return an array labels matching those of the x and y values to order them.
:type cell_type_sort: function-like
:param cell_sort: A function that takes the cell type label and set of ids and returns
a map to sort them.
:type cell_sort: function-like
"""
x_labelled = {}
y_labelled = {}
colors = {}
ids = {}
for cell_id, dataset in spike_recorders.items():
attrs = dict(dataset.attrs)
if len(dataset.shape) == 1 or dataset.shape[1] == 1:
Expand All @@ -656,9 +669,6 @@ def hdf5_plot_spike_raster(
y_labelled[label] = []
if not label in colors:
colors[label] = attrs.get("color", "black")
if not label in ids:
ids[label] = 0
ids[label] += 1
# Add the spike timings on the X axis.
x_labelled[label].extend(times)
# Set the cell id for the Y axis of each added spike timing.
Expand All @@ -669,40 +679,39 @@ def hdf5_plot_spike_raster(
xaxis=dict(title_text="Time (ms)"), yaxis=dict(title_text="Cell (ID)")
)
)
if sorted_labels is None:
sort_by_size = lambda d: {
k: v for k, v in sorted(d.items(), key=lambda i: len(i[1]))
}
sorted_labels = sort_by_size(x_labelled).keys()
if cell_type_sort is None:
# Sorts the cell type dictionary by cell type size
cell_type_sort = lambda x, y: [
k for k, v in sorted(y.items(), key=lambda kv: len(kv[1]))
]
# This lambda maps each unique y value to a sorted index starting from 0
# We define this here so that it can be used as fallback mechanism later
_cell_sort = lambda l, sy: dict(zip(sy, np.argsort(sy)))
if cell_sort is None:
# If no cell sorter is given we use the fallback sorter as default sorter.
cell_sort = _cell_sort

sorted_labels = cell_type_sort(x_labelled, y_labelled)
start_id = 0

for label in sorted_labels:
x = x_labelled[label]
y = y_labelled[label]
if sorted_labels is None:
y = [yi + start_id for yi in y]
start_id += ids[label]
else:
if len(y) > 0:
sy = set(y)
if sorted_ids is None or (label not in sorted_ids.keys()):
# Create a map between the scattered y and ordered y
a = dict(zip(list(set(y)), range(start_id, start_id + len(sy))))
else:
# Create a map between the given sorted_labels and the ordered y
a = dict(
zip(
sorted_ids[label],
range(start_id, start_id + len(sorted_ids[label])),
)
)
len_diff = len(sy) - len(sorted_ids)
if len_diff > 0:
warn(
f"Sorted '{label}' array do not contain all cell ids, {len_diff} {label} omitted from raster."
)
y = [a[l] for l in y]
start_id += len(set(y)) + ids[label]
x = np.array(x_labelled[label])
y = np.array(y_labelled[label])
if len(y) > 0:
uy = np.unique(y)
# Ask the cell sorter to give a map for the unique y values. If it returns
# something Falsy (such as None) we use the default cell sorter.
id_map = cell_sort(label, uy) or _cell_sort(label, uy)
len_diff = len(uy) - len(id_map)
if len_diff > 0:
warn(
f"Sorted '{label}' array do not contain all cell ids, {len_diff} {label} omitted from raster."
)
y_mask = np.isin(y, id_map.keys())
y = y[y_mask]
x = x[x_mask]
# Build a new numpy array using the `id_map` dictionary lookup
y = np.vectorize(id_map.__getitem__)(y) + start_id
start_id += len(uy)
plot_spike_raster(
x,
y,
Expand Down

0 comments on commit 06aa8df

Please sign in to comment.