-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add skeleton model * add tests for node/edge attributes * fix import
- Loading branch information
1 parent
1505ee1
commit 96d97e3
Showing
6 changed files
with
388 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from typing import Callable, Tuple | ||
|
||
import networkx as nx | ||
import numpy as np | ||
import pytest | ||
from morphosamplers.spline import Spline3D | ||
|
||
from morphometrics.skeleton.constants import ( | ||
EDGE_COORDINATES_KEY, | ||
EDGE_SPLINE_KEY, | ||
NODE_COORDINATE_KEY, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def make_bare_skeleton_graph() -> Callable[[], Tuple[nx.Graph, np.ndarray]]: | ||
"""Make a skeleton graph with no properties.""" | ||
|
||
def factory_function(): | ||
node_coordinates = np.array( | ||
[ | ||
[10, 25, 25], | ||
[20, 25, 25], | ||
[40, 35, 25], | ||
[40, 15, 25], | ||
[20, 30, 30], | ||
[20, 45, 45], | ||
] | ||
) | ||
edges = [(0, 1), (1, 2), (1, 3), (4, 5)] | ||
skeleton_graph = nx.Graph(edges) | ||
return skeleton_graph, node_coordinates | ||
|
||
return factory_function | ||
|
||
|
||
@pytest.fixture | ||
def make_valid_skeleton_graph(make_bare_skeleton_graph) -> Callable[[], nx.Graph]: | ||
"""Make a skeleton graph with all required properties.""" | ||
|
||
def factory_function() -> nx.Graph: | ||
skeleton_graph, node_coordinates = make_bare_skeleton_graph() | ||
|
||
# add the node coordinates | ||
node_attributes = {} | ||
for node_index in skeleton_graph.nodes(data=False): | ||
node_attributes[node_index] = { | ||
NODE_COORDINATE_KEY: node_coordinates[node_index] | ||
} | ||
nx.set_node_attributes(skeleton_graph, node_attributes) | ||
|
||
# add the edge properties | ||
edge_attributes = {} | ||
for start_node, end_node in skeleton_graph.edges(data=False): | ||
start_point = node_coordinates[start_node] | ||
end_point = node_coordinates[end_node] | ||
line_length = np.linalg.norm(end_point - start_point) | ||
n_skeleton_points = int(line_length) // 2 | ||
edge_coordinates = np.linspace( | ||
start_point, end_point, n_skeleton_points | ||
).astype(int) | ||
edge_spline = Spline3D( | ||
points=edge_coordinates, | ||
) | ||
edge_attributes[(start_node, end_node)] = { | ||
EDGE_COORDINATES_KEY: node_coordinates[[start_node, end_node]], | ||
EDGE_SPLINE_KEY: edge_spline, | ||
} | ||
nx.set_edge_attributes(skeleton_graph, edge_attributes) | ||
return skeleton_graph | ||
|
||
return factory_function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import networkx as nx | ||
import numpy as np | ||
|
||
from morphometrics.skeleton.constants import NODE_COORDINATE_KEY | ||
from morphometrics.skeleton.model import Skeleton3D | ||
|
||
|
||
def test_skeleton_model_instantiation(make_valid_skeleton_graph): | ||
"""Test that the Skeleton3D class can be instantiated.""" | ||
graph = make_valid_skeleton_graph() | ||
skeleton = Skeleton3D(graph=graph) | ||
|
||
# the graph should be a copy of the original graph | ||
assert skeleton.graph is not graph | ||
|
||
# the graph should be identical to the original graph | ||
assert skeleton.graph.edges(data=False) == graph.edges(data=False) | ||
|
||
|
||
def test_skeleton_model_parse(make_valid_skeleton_graph): | ||
"""Test that the Skeleton3D class can be instantiated with the parser.""" | ||
|
||
# get the skeleton graph and node coordinates | ||
skeleton = make_valid_skeleton_graph() | ||
|
||
# update an edge attribute | ||
nx.set_edge_attributes(skeleton, {(0, 1): {"validated": False}}) | ||
|
||
# add an edge attribute that should be deleted by the parser | ||
# because it isn't passed as an edge_attribute | ||
bad_attribute_key = "bad_attribute" | ||
bad_attribute_edge = (0, 1) | ||
nx.set_edge_attributes(skeleton, {bad_attribute_edge: {bad_attribute_key: False}}) | ||
|
||
# add a node attribute | ||
nx.set_node_attributes(skeleton, {0: {"good_node": True}}) | ||
|
||
# add a node attribute that should be deleted by the parser | ||
# because it isn't passed as a node_attribute | ||
bad_attribute_node = 0 | ||
nx.set_node_attributes(skeleton, {bad_attribute_node: {bad_attribute_key: True}}) | ||
|
||
# parse the skeleton | ||
scale = (1, 2, 1) | ||
parsed_skeleton = Skeleton3D.parse( | ||
graph=skeleton, | ||
edge_attributes={"validated": True}, | ||
node_attributes={"good_node": False}, | ||
edge_coordinates_key="edge_coordinates", | ||
node_coordinate_key="node_coordinate", | ||
scale=scale, | ||
) | ||
|
||
# verify that the graph is a copy | ||
assert parsed_skeleton.graph is not skeleton | ||
|
||
# check that the original edge attributes were not overwritten | ||
assert parsed_skeleton.graph.edges[(0, 1)]["validated"] is False | ||
|
||
# check the default value was given to the edge attribute with missing values | ||
assert parsed_skeleton.graph.edges[(1, 2)]["validated"] is True | ||
|
||
# check the edge attribute not passed in edge_attributes is removed | ||
assert bad_attribute_key not in parsed_skeleton.graph.edges[bad_attribute_edge] | ||
|
||
# check that the original node attributes were not overwritten | ||
assert parsed_skeleton.nodes()[0]["good_node"] is True | ||
|
||
# check that the rest of the nodes got the default value | ||
assert parsed_skeleton.nodes()[1]["good_node"] is False | ||
|
||
# check that the node attribute not in node_attributes is removed | ||
assert bad_attribute_key not in parsed_skeleton.nodes()[bad_attribute_node] | ||
|
||
# check that the scale was applied to the node coordinates | ||
original_node_coordinates = np.stack( | ||
[node_data[NODE_COORDINATE_KEY] for _, node_data in skeleton.nodes(data=True)] | ||
) | ||
parsed_node_coordinates = np.stack( | ||
[ | ||
node_data[NODE_COORDINATE_KEY] | ||
for _, node_data in parsed_skeleton.nodes(data=True) | ||
] | ||
) | ||
scaled_original_coordinates = original_node_coordinates * scale | ||
np.testing.assert_allclose(parsed_node_coordinates, scaled_original_coordinates) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
NODE_COORDINATE_KEY = "node_coordinate" | ||
EDGE_COORDINATES_KEY = "edge_coordinates" | ||
EDGE_SPLINE_KEY = "edge_spline" | ||
NODE_INDEX_KEY = "node_index" | ||
|
||
EDGE_FEATURES_START_NODE_KEY = "start_node" | ||
EDGE_FEATURES_END_NODE_KEY = "end_node" | ||
EDGE_FEATURES_HIGHLIGHT_KEY = "highlight" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
from copy import deepcopy | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
||
import networkx as nx | ||
import numpy as np | ||
from morphosamplers.sampler import ( | ||
generate_2d_grid, | ||
place_sampling_grids, | ||
sample_volume_at_coordinates, | ||
) | ||
from morphosamplers.spline import Spline3D | ||
|
||
from morphometrics.skeleton.constants import ( | ||
EDGE_COORDINATES_KEY, | ||
EDGE_SPLINE_KEY, | ||
NODE_COORDINATE_KEY, | ||
) | ||
|
||
|
||
class Skeleton3D: | ||
def __init__(self, graph: nx.Graph): | ||
self.graph = deepcopy(graph) | ||
|
||
def nodes(self, data: bool = True): | ||
"""Passthrough for nx.Graph.nodes""" | ||
return self.graph.nodes(data=data) | ||
|
||
def edges(self, data: bool = True): | ||
"""Passthrough for nx.Graph.edges""" | ||
return self.graph.edges(data=data) | ||
|
||
@property | ||
def node_coordinates(self) -> np.ndarray: | ||
"""Coordinates of the nodes. | ||
Index matched to nx.Graph.nodes() | ||
""" | ||
node_data = self.nodes(data=True) | ||
coordinates = [data[NODE_COORDINATE_KEY] for _, data in node_data] | ||
return np.stack(coordinates) | ||
|
||
def sample_points_on_edge( | ||
self, start_node: int, end_node: int, u: List[float], derivative_order: int = 0 | ||
): | ||
spline = self.graph[start_node][end_node][EDGE_SPLINE_KEY] | ||
return spline.sample(u=u, derivative_order=derivative_order) | ||
|
||
def sample_slices_on_edge( | ||
self, | ||
image: np.ndarray, | ||
image_voxel_size: Tuple[float, float, float], | ||
start_node: int, | ||
end_node: int, | ||
slice_pixel_size: float, | ||
slice_width: int, | ||
slice_spacing: float, | ||
interpolation_order: int = 1, | ||
) -> np.ndarray: | ||
# get the spline object | ||
spline = self.graph[start_node][end_node][EDGE_SPLINE_KEY] | ||
|
||
# get the positions along the spline | ||
positions = spline.sample(separation=slice_spacing) | ||
orientations = spline.sample_orientations(separation=slice_spacing) | ||
|
||
# get the sampling coordinates | ||
sampling_shape = (slice_width, slice_width) | ||
grid = generate_2d_grid( | ||
grid_shape=sampling_shape, grid_spacing=(slice_pixel_size, slice_pixel_size) | ||
) | ||
sampling_coords = place_sampling_grids(grid, positions, orientations) | ||
|
||
# convert the sampling coordinates into the image indices | ||
sampling_coords = sampling_coords / np.array(image_voxel_size) | ||
|
||
return sample_volume_at_coordinates( | ||
image, sampling_coords, interpolation_order=interpolation_order | ||
) | ||
|
||
def sample_image_around_node( | ||
self, | ||
node_index: int, | ||
image: np.ndarray, | ||
image_voxel_size: Tuple[float, float, float], | ||
bounding_box_shape: Union[float, Tuple[float, float, float]] = 10, | ||
) -> Tuple[np.ndarray, np.ndarray]: | ||
"""Extract an axis-aligned bounding box from an image around a node. | ||
Parameters | ||
---------- | ||
node_index : int | ||
The index of the node to sample around. | ||
image : np.ndarray | ||
The image to sample from. | ||
image_voxel_size : Tuple[float, float, float] | ||
Size of the image voxel in each axis. Should convert to the same | ||
scale as the skeleton graph. | ||
bounding_box_shape : Union[float, Tuple[float, float, float]] | ||
The shape of the bounding box to extract. Size should be specified | ||
in the coordinate system of the skeleton. If a single float is provided, | ||
a cube with edge-length bounding_box_shape will be extracted. Otherwise, | ||
provide a tuple with one element for each axis. | ||
Returns | ||
------- | ||
sub_volume : np.ndarray | ||
The extracted bounding box. | ||
bounding_box : np.ndarray | ||
(2, 3) array with the coordinates of the | ||
upper left and lower right hand corners of the bounding box. | ||
""" | ||
|
||
# get the node coordinates | ||
node_coordinate = self.graph.nodes(data=NODE_COORDINATE_KEY)[node_index] | ||
|
||
# convert node coordinate to | ||
graph_to_image_factor = 1 / np.array(image_voxel_size) | ||
node_coordinate_image = node_coordinate * graph_to_image_factor | ||
|
||
# convert the bounding box to image coordinates | ||
if isinstance(bounding_box_shape, int) or isinstance(bounding_box_shape, float): | ||
bounding_box_shape = ( | ||
bounding_box_shape, | ||
bounding_box_shape, | ||
bounding_box_shape, | ||
) | ||
grid_shape = np.asarray(bounding_box_shape) * graph_to_image_factor | ||
bounding_box_min = np.clip( | ||
node_coordinate_image - (grid_shape / 2), a_min=[0, 0, 0], a_max=image.shape | ||
) | ||
bounding_box_max = np.clip( | ||
node_coordinate_image + (grid_shape / 2), a_min=[0, 0, 0], a_max=image.shape | ||
) | ||
bounding_box = np.stack([bounding_box_min, bounding_box_max]).astype(int) | ||
|
||
# sample the image | ||
sub_volume = image[ | ||
bounding_box[0, 0] : bounding_box[1, 0], | ||
bounding_box[0, 1] : bounding_box[1, 1], | ||
bounding_box[0, 2] : bounding_box[1, 2], | ||
] | ||
|
||
return np.asarray(sub_volume), bounding_box | ||
|
||
def shortest_path(self, start_node: int, end_node: int) -> Optional[List[int]]: | ||
return nx.shortest_path(self.graph, source=start_node, target=end_node) | ||
|
||
@classmethod | ||
def parse( | ||
cls, | ||
graph: nx.Graph, | ||
edge_attributes: Optional[Dict[str, Any]] = None, | ||
node_attributes: Optional[Dict[str, Any]] = None, | ||
edge_coordinates_key: str = EDGE_COORDINATES_KEY, | ||
node_coordinate_key: str = NODE_COORDINATE_KEY, | ||
scale: Tuple[float, float, float] = (1, 1, 1), | ||
): | ||
# make a copy of the graph so we don't clobber the original attributes | ||
graph = deepcopy(graph) | ||
|
||
scale = np.asarray(scale) | ||
if edge_attributes is None: | ||
edge_attributes = {} | ||
if node_attributes is None: | ||
node_attributes = {} | ||
|
||
# parse the edge attributes | ||
parsed_edge_attributes = {} | ||
for start_index, end_index, attributes in graph.edges(data=True): | ||
# remove attribute not specified | ||
keys_to_delete = [ | ||
key | ||
for key in attributes | ||
if ((key not in edge_attributes) and (key != edge_coordinates_key)) | ||
] | ||
for key in keys_to_delete: | ||
del attributes[key] | ||
|
||
for expected_key, default_value in edge_attributes.items(): | ||
# add expected keys that are missing | ||
if expected_key not in attributes: | ||
attributes.update({expected_key: default_value}) | ||
|
||
# make the edge spline | ||
coordinates = np.asarray(attributes[edge_coordinates_key]) * scale | ||
spline = Spline3D(points=coordinates) | ||
parsed_edge_attributes.update( | ||
{ | ||
(start_index, end_index): { | ||
EDGE_COORDINATES_KEY: coordinates, | ||
EDGE_SPLINE_KEY: spline, | ||
} | ||
} | ||
) | ||
nx.set_edge_attributes(graph, parsed_edge_attributes) | ||
|
||
# parse the node attributes | ||
parsed_node_attributes = {} | ||
for node_index, attributes in graph.nodes(data=True): | ||
# remove attribute not specified | ||
keys_to_delete = [ | ||
key | ||
for key in attributes | ||
if ((key not in node_attributes) and (key != node_coordinate_key)) | ||
] | ||
for key in keys_to_delete: | ||
del attributes[key] | ||
|
||
for expected_key, default_value in node_attributes.items(): | ||
# add expected keys that are missing | ||
if expected_key not in attributes: | ||
attributes.update({expected_key: default_value}) | ||
# add the node coordinates | ||
coordinate = np.asarray(attributes[node_coordinate_key]) | ||
coordinate = coordinate * scale | ||
|
||
parsed_node_attributes.update( | ||
{node_index: {NODE_COORDINATE_KEY: coordinate}} | ||
) | ||
nx.set_node_attributes(graph, parsed_node_attributes) | ||
|
||
return cls(graph=graph) |