Skip to content

Commit

Permalink
Skeleton model (#56)
Browse files Browse the repository at this point in the history
* add skeleton model

* add tests for node/edge attributes

* fix import
  • Loading branch information
kevinyamauchi committed Jun 5, 2023
1 parent 1505ee1 commit 96d97e3
Show file tree
Hide file tree
Showing 6 changed files with 388 additions and 0 deletions.
Empty file.
Empty file.
72 changes: 72 additions & 0 deletions src/morphometrics/skeleton/_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Callable, Tuple

import networkx as nx
import numpy as np
import pytest
from morphosamplers.spline import Spline3D

from morphometrics.skeleton.constants import (
EDGE_COORDINATES_KEY,
EDGE_SPLINE_KEY,
NODE_COORDINATE_KEY,
)


@pytest.fixture
def make_bare_skeleton_graph() -> Callable[[], Tuple[nx.Graph, np.ndarray]]:
"""Make a skeleton graph with no properties."""

def factory_function():
node_coordinates = np.array(
[
[10, 25, 25],
[20, 25, 25],
[40, 35, 25],
[40, 15, 25],
[20, 30, 30],
[20, 45, 45],
]
)
edges = [(0, 1), (1, 2), (1, 3), (4, 5)]
skeleton_graph = nx.Graph(edges)
return skeleton_graph, node_coordinates

return factory_function


@pytest.fixture
def make_valid_skeleton_graph(make_bare_skeleton_graph) -> Callable[[], nx.Graph]:
"""Make a skeleton graph with all required properties."""

def factory_function() -> nx.Graph:
skeleton_graph, node_coordinates = make_bare_skeleton_graph()

# add the node coordinates
node_attributes = {}
for node_index in skeleton_graph.nodes(data=False):
node_attributes[node_index] = {
NODE_COORDINATE_KEY: node_coordinates[node_index]
}
nx.set_node_attributes(skeleton_graph, node_attributes)

# add the edge properties
edge_attributes = {}
for start_node, end_node in skeleton_graph.edges(data=False):
start_point = node_coordinates[start_node]
end_point = node_coordinates[end_node]
line_length = np.linalg.norm(end_point - start_point)
n_skeleton_points = int(line_length) // 2
edge_coordinates = np.linspace(
start_point, end_point, n_skeleton_points
).astype(int)
edge_spline = Spline3D(
points=edge_coordinates,
)
edge_attributes[(start_node, end_node)] = {
EDGE_COORDINATES_KEY: node_coordinates[[start_node, end_node]],
EDGE_SPLINE_KEY: edge_spline,
}
nx.set_edge_attributes(skeleton_graph, edge_attributes)
return skeleton_graph

return factory_function
86 changes: 86 additions & 0 deletions src/morphometrics/skeleton/_tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import networkx as nx
import numpy as np

from morphometrics.skeleton.constants import NODE_COORDINATE_KEY
from morphometrics.skeleton.model import Skeleton3D


def test_skeleton_model_instantiation(make_valid_skeleton_graph):
"""Test that the Skeleton3D class can be instantiated."""
graph = make_valid_skeleton_graph()
skeleton = Skeleton3D(graph=graph)

# the graph should be a copy of the original graph
assert skeleton.graph is not graph

# the graph should be identical to the original graph
assert skeleton.graph.edges(data=False) == graph.edges(data=False)


def test_skeleton_model_parse(make_valid_skeleton_graph):
"""Test that the Skeleton3D class can be instantiated with the parser."""

# get the skeleton graph and node coordinates
skeleton = make_valid_skeleton_graph()

# update an edge attribute
nx.set_edge_attributes(skeleton, {(0, 1): {"validated": False}})

# add an edge attribute that should be deleted by the parser
# because it isn't passed as an edge_attribute
bad_attribute_key = "bad_attribute"
bad_attribute_edge = (0, 1)
nx.set_edge_attributes(skeleton, {bad_attribute_edge: {bad_attribute_key: False}})

# add a node attribute
nx.set_node_attributes(skeleton, {0: {"good_node": True}})

# add a node attribute that should be deleted by the parser
# because it isn't passed as a node_attribute
bad_attribute_node = 0
nx.set_node_attributes(skeleton, {bad_attribute_node: {bad_attribute_key: True}})

# parse the skeleton
scale = (1, 2, 1)
parsed_skeleton = Skeleton3D.parse(
graph=skeleton,
edge_attributes={"validated": True},
node_attributes={"good_node": False},
edge_coordinates_key="edge_coordinates",
node_coordinate_key="node_coordinate",
scale=scale,
)

# verify that the graph is a copy
assert parsed_skeleton.graph is not skeleton

# check that the original edge attributes were not overwritten
assert parsed_skeleton.graph.edges[(0, 1)]["validated"] is False

# check the default value was given to the edge attribute with missing values
assert parsed_skeleton.graph.edges[(1, 2)]["validated"] is True

# check the edge attribute not passed in edge_attributes is removed
assert bad_attribute_key not in parsed_skeleton.graph.edges[bad_attribute_edge]

# check that the original node attributes were not overwritten
assert parsed_skeleton.nodes()[0]["good_node"] is True

# check that the rest of the nodes got the default value
assert parsed_skeleton.nodes()[1]["good_node"] is False

# check that the node attribute not in node_attributes is removed
assert bad_attribute_key not in parsed_skeleton.nodes()[bad_attribute_node]

# check that the scale was applied to the node coordinates
original_node_coordinates = np.stack(
[node_data[NODE_COORDINATE_KEY] for _, node_data in skeleton.nodes(data=True)]
)
parsed_node_coordinates = np.stack(
[
node_data[NODE_COORDINATE_KEY]
for _, node_data in parsed_skeleton.nodes(data=True)
]
)
scaled_original_coordinates = original_node_coordinates * scale
np.testing.assert_allclose(parsed_node_coordinates, scaled_original_coordinates)
8 changes: 8 additions & 0 deletions src/morphometrics/skeleton/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
NODE_COORDINATE_KEY = "node_coordinate"
EDGE_COORDINATES_KEY = "edge_coordinates"
EDGE_SPLINE_KEY = "edge_spline"
NODE_INDEX_KEY = "node_index"

EDGE_FEATURES_START_NODE_KEY = "start_node"
EDGE_FEATURES_END_NODE_KEY = "end_node"
EDGE_FEATURES_HIGHLIGHT_KEY = "highlight"
222 changes: 222 additions & 0 deletions src/morphometrics/skeleton/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
from morphosamplers.sampler import (
generate_2d_grid,
place_sampling_grids,
sample_volume_at_coordinates,
)
from morphosamplers.spline import Spline3D

from morphometrics.skeleton.constants import (
EDGE_COORDINATES_KEY,
EDGE_SPLINE_KEY,
NODE_COORDINATE_KEY,
)


class Skeleton3D:
def __init__(self, graph: nx.Graph):
self.graph = deepcopy(graph)

def nodes(self, data: bool = True):
"""Passthrough for nx.Graph.nodes"""
return self.graph.nodes(data=data)

def edges(self, data: bool = True):
"""Passthrough for nx.Graph.edges"""
return self.graph.edges(data=data)

@property
def node_coordinates(self) -> np.ndarray:
"""Coordinates of the nodes.
Index matched to nx.Graph.nodes()
"""
node_data = self.nodes(data=True)
coordinates = [data[NODE_COORDINATE_KEY] for _, data in node_data]
return np.stack(coordinates)

def sample_points_on_edge(
self, start_node: int, end_node: int, u: List[float], derivative_order: int = 0
):
spline = self.graph[start_node][end_node][EDGE_SPLINE_KEY]
return spline.sample(u=u, derivative_order=derivative_order)

def sample_slices_on_edge(
self,
image: np.ndarray,
image_voxel_size: Tuple[float, float, float],
start_node: int,
end_node: int,
slice_pixel_size: float,
slice_width: int,
slice_spacing: float,
interpolation_order: int = 1,
) -> np.ndarray:
# get the spline object
spline = self.graph[start_node][end_node][EDGE_SPLINE_KEY]

# get the positions along the spline
positions = spline.sample(separation=slice_spacing)
orientations = spline.sample_orientations(separation=slice_spacing)

# get the sampling coordinates
sampling_shape = (slice_width, slice_width)
grid = generate_2d_grid(
grid_shape=sampling_shape, grid_spacing=(slice_pixel_size, slice_pixel_size)
)
sampling_coords = place_sampling_grids(grid, positions, orientations)

# convert the sampling coordinates into the image indices
sampling_coords = sampling_coords / np.array(image_voxel_size)

return sample_volume_at_coordinates(
image, sampling_coords, interpolation_order=interpolation_order
)

def sample_image_around_node(
self,
node_index: int,
image: np.ndarray,
image_voxel_size: Tuple[float, float, float],
bounding_box_shape: Union[float, Tuple[float, float, float]] = 10,
) -> Tuple[np.ndarray, np.ndarray]:
"""Extract an axis-aligned bounding box from an image around a node.
Parameters
----------
node_index : int
The index of the node to sample around.
image : np.ndarray
The image to sample from.
image_voxel_size : Tuple[float, float, float]
Size of the image voxel in each axis. Should convert to the same
scale as the skeleton graph.
bounding_box_shape : Union[float, Tuple[float, float, float]]
The shape of the bounding box to extract. Size should be specified
in the coordinate system of the skeleton. If a single float is provided,
a cube with edge-length bounding_box_shape will be extracted. Otherwise,
provide a tuple with one element for each axis.
Returns
-------
sub_volume : np.ndarray
The extracted bounding box.
bounding_box : np.ndarray
(2, 3) array with the coordinates of the
upper left and lower right hand corners of the bounding box.
"""

# get the node coordinates
node_coordinate = self.graph.nodes(data=NODE_COORDINATE_KEY)[node_index]

# convert node coordinate to
graph_to_image_factor = 1 / np.array(image_voxel_size)
node_coordinate_image = node_coordinate * graph_to_image_factor

# convert the bounding box to image coordinates
if isinstance(bounding_box_shape, int) or isinstance(bounding_box_shape, float):
bounding_box_shape = (
bounding_box_shape,
bounding_box_shape,
bounding_box_shape,
)
grid_shape = np.asarray(bounding_box_shape) * graph_to_image_factor
bounding_box_min = np.clip(
node_coordinate_image - (grid_shape / 2), a_min=[0, 0, 0], a_max=image.shape
)
bounding_box_max = np.clip(
node_coordinate_image + (grid_shape / 2), a_min=[0, 0, 0], a_max=image.shape
)
bounding_box = np.stack([bounding_box_min, bounding_box_max]).astype(int)

# sample the image
sub_volume = image[
bounding_box[0, 0] : bounding_box[1, 0],
bounding_box[0, 1] : bounding_box[1, 1],
bounding_box[0, 2] : bounding_box[1, 2],
]

return np.asarray(sub_volume), bounding_box

def shortest_path(self, start_node: int, end_node: int) -> Optional[List[int]]:
return nx.shortest_path(self.graph, source=start_node, target=end_node)

@classmethod
def parse(
cls,
graph: nx.Graph,
edge_attributes: Optional[Dict[str, Any]] = None,
node_attributes: Optional[Dict[str, Any]] = None,
edge_coordinates_key: str = EDGE_COORDINATES_KEY,
node_coordinate_key: str = NODE_COORDINATE_KEY,
scale: Tuple[float, float, float] = (1, 1, 1),
):
# make a copy of the graph so we don't clobber the original attributes
graph = deepcopy(graph)

scale = np.asarray(scale)
if edge_attributes is None:
edge_attributes = {}
if node_attributes is None:
node_attributes = {}

# parse the edge attributes
parsed_edge_attributes = {}
for start_index, end_index, attributes in graph.edges(data=True):
# remove attribute not specified
keys_to_delete = [
key
for key in attributes
if ((key not in edge_attributes) and (key != edge_coordinates_key))
]
for key in keys_to_delete:
del attributes[key]

for expected_key, default_value in edge_attributes.items():
# add expected keys that are missing
if expected_key not in attributes:
attributes.update({expected_key: default_value})

# make the edge spline
coordinates = np.asarray(attributes[edge_coordinates_key]) * scale
spline = Spline3D(points=coordinates)
parsed_edge_attributes.update(
{
(start_index, end_index): {
EDGE_COORDINATES_KEY: coordinates,
EDGE_SPLINE_KEY: spline,
}
}
)
nx.set_edge_attributes(graph, parsed_edge_attributes)

# parse the node attributes
parsed_node_attributes = {}
for node_index, attributes in graph.nodes(data=True):
# remove attribute not specified
keys_to_delete = [
key
for key in attributes
if ((key not in node_attributes) and (key != node_coordinate_key))
]
for key in keys_to_delete:
del attributes[key]

for expected_key, default_value in node_attributes.items():
# add expected keys that are missing
if expected_key not in attributes:
attributes.update({expected_key: default_value})
# add the node coordinates
coordinate = np.asarray(attributes[node_coordinate_key])
coordinate = coordinate * scale

parsed_node_attributes.update(
{node_index: {NODE_COORDINATE_KEY: coordinate}}
)
nx.set_node_attributes(graph, parsed_node_attributes)

return cls(graph=graph)

0 comments on commit 96d97e3

Please sign in to comment.