Skip to content

Commit

Permalink
Extend root_rotate to allow rotation from a point of the subtree (#697)
Browse files Browse the repository at this point in the history
* Add yaml dependency node.

* Fix sphinx issue for types in documentation.

* Apply suggestions from code review

Co-authored-by: Robin De Schepper <robin.deschepper93@gmail.com>

* Apply suggestions from code review

Co-authored-by: Robin De Schepper <robin.deschepper93@gmail.com>

* Add test for ValueError.

---------

Co-authored-by: Robin De Schepper <robin.deschepper93@gmail.com>
  • Loading branch information
drodarie and Helveg committed Mar 17, 2023
1 parent a5879b6 commit d91532e
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 11 deletions.
62 changes: 51 additions & 11 deletions bsb/morphologies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,11 @@ def rotate(self, rot, center=None):
Point rotation
:param rot: Scipy rotation
:type: :class:`scipy.spatial.transform.Rotation`
:type rot: scipy.spatial.transform.Rotation
:param center: rotation offset point.
:type center: numpy.ndarray
"""

if self._is_shared:
self._shared._points[:] = self._rotate(self._shared._points, rot, center)
else:
Expand All @@ -531,13 +534,37 @@ def _rotate(self, points, rot, center):
rotated_points = rot.apply(points)
return rotated_points

def root_rotate(self, rot):
def root_rotate(self, rot, downstream_of=0):
"""
Rotate the subtree emanating from each root around the start of that root
If downstream_of is provided, will rotate points starting from the index provided (only for
subtrees with a single root).
:param rot: Scipy rotation to apply to the subtree.
:type rot: scipy.spatial.transform.Rotation
:param downstream_of: index of the point in the subtree from which the rotation should be
applied. This feature works only when the subtree has only one root branch.
:returns: rotated Morphology
:rtype: bsb.morphologies.Morphology
"""
for b in self.roots:
group = SubTree([b])
group.rotate(rot, group.origin)

if downstream_of != 0:
if len(self.roots) > 1:
raise ValueError(
"Can't rotate with subbranch precision with multiple roots"
)
elif type(downstream_of) == int and 0 < downstream_of < len(
self.roots[0].points
):
b = self.roots[0]
group = SubTree([b])
upstream = np.copy(b.points[:downstream_of])
group.rotate(rot, b.points[downstream_of])
b.points[:downstream_of] = upstream
else:
for b in self.roots:
group = SubTree([b])
group.rotate(rot, group.origin)
return self

def translate(self, point):
Expand Down Expand Up @@ -989,6 +1016,19 @@ class Branch:
"""

def __init__(self, points, radii, labels=None, properties=None, children=None):
"""
:param points: Array of 3D coordinates defining the point of the branch
:type points: list | numpy.ndarray
:param radii: Array of radii associated to each point
:type radii: list | numpy.ndarray
:param labels: Array of labels to associate to each point
:type labels: EncodedLabels | List[str] | set | numpy.ndarray
:param properties: dictionary of metadata to store in the branch
:type properties: dict
:param children: list of child branches to attach to the branch
:type children: List[bsb.morphologies.Branch]
"""

self._points = _gutil.sanitize_ndarray(points, (-1, 3), float)
self._radii = _gutil.sanitize_ndarray(radii, (-1,), float)
_gutil.assert_samelen(self._points, self._radii)
Expand Down Expand Up @@ -1399,11 +1439,11 @@ def contains_labels(self, labels):

def get_points_labelled(self, labels):
"""
Filter out all points with a certain label
Filter out all points with certain labels
:param label: The label to check for.
:type label: str
:returns: All points with the label.
:param labels: The labels to check for.
:type labels: List[str] | numpy.ndarray[str]
:returns: All points with the labels.
:rtype: List[numpy.ndarray]
"""
return self.points[self.get_label_mask(labels)]
Expand All @@ -1412,8 +1452,8 @@ def get_label_mask(self, labels):
"""
Return a mask for the specified labels
:param label: The label to check for.
:type label: str
:param labels: The labels to check for.
:type labels: List[str] | numpy.ndarray[str]
:returns: A boolean mask that selects out the points that match the label.
:rtype: List[numpy.ndarray]
"""
Expand Down
15 changes: 15 additions & 0 deletions docs/morphologies/intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,21 @@ respective root in the tree:
dendrites.root_rotate(r)
Additionally, you can :meth:`root-rotate <.morphologies.SubTree.root_rotate>` from a point of the
subtree instead of its root. In this case, points starting from the point selected will be rotated.

To do so, set the `downstream_of` parameter with the index of the point of your interest.

.. code-block:: python
# rotate all points after the second point in the subtree
# i.e.: points at index 0 and 1 will not be rotated.
dendrites.root_rotate(r, downstream_of=2)
.. note::

This feature can only be applied to subtrees with a single root

Gap closing
-----------

Expand Down
34 changes: 34 additions & 0 deletions tests/test_morphologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,40 @@ def test_chaining(self):
res = m.rotate(r).root_rotate(r).translate([0, 0, 0]).collapse().close_gaps()
self.assertEqual(m, res, "chaining calls should return self")

def test_root_rotate(self):
points = np.array([[0, 0, 0], [1, 1, 0], [0, 4, 0], [0, 6, 0], [2, 4, 8]])
radii = np.array([0, 1, 2, 2, 1])
m = Morphology([Branch(points, radii)])
rot = Rotation.from_euler("x", np.pi)
# rotate from root
rotated = m.copy().root_rotate(rot)
rot_points = np.copy(points)
rot_points[:, 1:] = -rot_points[:, 1:]
expected = Morphology([Branch(rot_points, radii)])
self.assertEqual(rotated, expected)

# rotate from second point
rotated = m.copy().root_rotate(rot, downstream_of=1)
rot_points = np.copy(points)
rot_points[1:, 1:] = 2 * rot_points[1, 1:] - rot_points[1:, 1:]
expected = Morphology([Branch(rot_points, radii)])
self.assertEqual(rotated, expected)

# Wrong point index -> no rotation
expected = Morphology([Branch(points, radii)])
rotated = m.copy().root_rotate(rot, downstream_of=5)
self.assertEqual(rotated, expected)
rotated = m.copy().root_rotate(rot, downstream_of=-1)
self.assertEqual(rotated, expected)
rotated = m.copy().root_rotate(
Rotation.from_euler("x", np.pi), downstream_of="bla"
)
self.assertEqual(rotated, expected)

# More than one root
m = Morphology([Branch(points, radii), Branch(points, radii)])
self.assertRaises(ValueError, m.root_rotate, rot=rot, downstream_of=1)

def test_simplification(self):
def branch_one():
return Branch(
Expand Down

0 comments on commit d91532e

Please sign in to comment.