Skip to content

Commit

Permalink
initial label mesh viewer
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinyamauchi committed Dec 22, 2023
1 parent 9e8ec13 commit 6dc1b02
Showing 1 changed file with 249 additions and 2 deletions.
251 changes: 249 additions & 2 deletions src/morphometrics/_gui/_qt/multiple_viewer_widget.py
Expand Up @@ -2,14 +2,19 @@
from typing import List, Optional, Tuple

import napari
import numpy as np
from magicgui import magicgui
from napari import Viewer
from napari.components.viewer_model import ViewerModel
from napari.layers import Image, Labels, Layer
from napari.layers import Image, Labels, Layer, Points, Surface
from napari.qt import QtViewer
from napari.utils.events.event import WarningEmitter
from napari_threedee.utils.napari_utils import get_dims_displayed
from packaging.version import parse as parse_version
from qtpy.QtCore import Qt
from qtpy.QtWidgets import QSplitter, QWidget
from qtpy.QtWidgets import QSplitter, QVBoxLayout, QWidget

from ...utils.surface_utils import binary_mask_to_surface

NAPARI_GE_4_16 = parse_version(napari.__version__) > parse_version("0.4.16")

Expand Down Expand Up @@ -315,3 +320,245 @@ def _property_sync(self, name, event):
)
finally:
self._block = False


class MeshLabelViewerWidget(QSplitter):
"""Add a viewer of a mesh of a selected label in 3D."""

def __init__(self, viewer: napari.Viewer, side_widget: Optional[QWidget] = None):
super().__init__()
self.main_viewer = viewer
self.surface_layer = None
self.slice_layer = None

# make the viewer for the mesh
self.mesh_viewer_model = ViewerModel(title="model1")

# connect the viewer sync events
self._connect_main_viewer_events()
self._connect_ortho_viewer_events()

# make the qt viewers
qt_viewer, viewer_splitter = self._setup_ortho_view_qt(
self.mesh_viewer_model, viewer
)
self.mesh_qt_viewer = qt_viewer
self.mesh_viewer_model.dims.ndisplay = 3

# make the label selection widget
self.selection_widget = LabelSelectionWidget(mesh_widget=self)

# add the widgets
self.addWidget(self.selection_widget)
self.addWidget(viewer_splitter)
if side_widget is not None:
self.addWidget(side_widget)

def _setup_ortho_view_qt(
self, viewer_model: List[ViewerModel], main_viewer: Viewer
) -> Tuple[QtViewerWrapper, QSplitter]:
# create the QtViewer objects
qt_viewer = QtViewerWrapper(main_viewer, viewer_model)

# create and populate the QSplitter for the mesh QtViewer
viewer_splitter = QSplitter()
viewer_splitter.setOrientation(Qt.Vertical)
viewer_splitter.addWidget(qt_viewer)
viewer_splitter.setContentsMargins(0, 0, 0, 0)

return qt_viewer, viewer_splitter

def _connect_main_viewer_events(self):
"""Connect the update functions to the main viewer events.
These events sync the ortho viewers with changes in the main viewer.
"""
self.main_viewer.dims.events.current_step.connect(self._point_update)

def _connect_ortho_viewer_events(self):
"""Connect the update functions to the orthoviewer events.
These events sync the main viewer with changes in the ortho viewer.
"""

def _point_update(self, event):
"""Callback from when the dims point is changed."""

def update_surface(
self, vertices: np.ndarray, faces: np.ndarray, values: np.ndarray
):
mesh_data = (vertices, faces, values)
if self.surface_layer is None:
surface_layer = Surface(mesh_data)
self.surface_layer = surface_layer

# set up the lighting
self.mesh_viewer_model.layers.insert(0, self.surface_layer)
self.mesh_visual = self.mesh_qt_viewer.layer_to_visual[self.surface_layer]
self.mesh_viewer_model.camera.events.angles.connect(self._on_camera_change)

# setup the points layer
self.main_points_layer = self.main_viewer.add_points(np.empty((1, 3)))
self.mesh_points_layer = Points(np.empty((1, 3)))
self.mesh_viewer_model.layers.insert(1, self.mesh_points_layer)

# connect the click event and ensure the surface layer is selected
self.surface_layer.mouse_drag_callbacks.append(self._on_mesh_clicK)
self.mesh_viewer_model.layers.selection = {self.surface_layer}
else:
self.surface_layer.data = mesh_data

def update_segment_bounding_box(self, bounding_box: np.ndarray):
self.segment_bounding_box = bounding_box

slice_mesh = self._make_slice_mesh()
if self.slice_layer is None:
self.slice_layer = Surface(slice_mesh, opacity=0.7)
self.mesh_viewer_model.layers.insert(1, self.slice_layer)
self.mesh_viewer_model.layers.selection = {self.surface_layer}
self.main_viewer.dims.events.point.connect(self._on_dims_change)

else:
self.slice_layer.data = slice_mesh

def _make_slice_mesh(self):
"""Make the mesh for the display the slice currently being viewed in the main viewer"""
slice_index = self.main_viewer.dims.point[0]
slice_mesh_vertices = np.array(
[
[
slice_index,
self.segment_bounding_box[0, 1],
self.segment_bounding_box[0, 2],
],
[
slice_index,
self.segment_bounding_box[0, 1],
self.segment_bounding_box[1, 2],
],
[
slice_index,
self.segment_bounding_box[1, 1],
self.segment_bounding_box[1, 2],
],
[
slice_index,
self.segment_bounding_box[1, 1],
self.segment_bounding_box[0, 2],
],
]
)
slice_mesh_faces = np.array([[0, 1, 2], [0, 2, 3]])
slice_mesh_values = np.ones((4,))
return slice_mesh_vertices, slice_mesh_faces, slice_mesh_values

def _on_camera_change(self, event=None):
if self.surface_layer is None:
return

# get the view direction in layer coordinates
view_direction = np.asarray(self.mesh_viewer_model.camera.view_direction)
dims_displayed = get_dims_displayed(self.surface_layer)
layer_view_direction = np.asarray(
self.surface_layer._world_to_data_ray(view_direction)
)[dims_displayed]

# update the node
self.mesh_visual.node.shading_filter.light_dir = -1 * layer_view_direction[::-1]

def _on_mesh_clicK(self, layer, event):
"""Mouse callback for when clicking in on the mesh in the viewer."""
_, triangle_index = layer.get_value(
event.position,
view_direction=event.view_direction,
dims_displayed=event.dims_displayed,
world=True,
)

if triangle_index is None:
# if the click did not intersect the mesh, don't do anything
return

candidate_vertices = layer.data[1][triangle_index]
candidate_points = layer.data[0][candidate_vertices]
point = np.mean(candidate_points, axis=0)

self.main_points_layer.data = point
self.mesh_points_layer.data = point

def _on_dims_change(self, event):
mesh_data = self._make_slice_mesh()
self.slice_layer.data = mesh_data


class LabelSelectionWidget(QWidget):
"""Widget for selecting labels to view as a mesh."""

def __init__(
self,
mesh_widget: MeshLabelViewerWidget,
labels_layer: Optional[napari.layers.Labels] = None,
) -> None:
super().__init__()

# store the widget and layer
self.mesh_widget = mesh_widget
self.labels_layer = labels_layer

# create the widget to select the labels layer and label index
self._label_selection_widget = magicgui(
self._set_labels_layer,
labels_layer={"choices": self._get_valid_labels_layers},
call_button="update segment mesh",
)

# add widgets to layout
self.setLayout(QVBoxLayout())
self.layout().addWidget(self._label_selection_widget.native)

def _set_labels_layer(
self, labels_layer: napari.layers.Labels, label_index: int = 1
):

self._labels_layer = labels_layer
vertices, faces, vertex_values = self._make_segment_mesh(
label_image=labels_layer.data, label_index=label_index
)
self.mesh_widget.update_surface(vertices, faces, vertex_values)
self.mesh_widget.update_segment_bounding_box(
self._compute_bounding_box(
label_image=labels_layer.data, label_index=label_index
)
)

def _make_segment_mesh(
self, label_image: np.ndarray, label_index: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Make a mesh from a selected label"""
label_mask = label_image == label_index
mesh = binary_mask_to_surface(label_mask, n_mesh_smoothing_iterations=0)

vertices = mesh.vertices
faces = mesh.faces
vertex_values = np.ones((vertices.shape[0],))

return vertices, faces, vertex_values

def _get_valid_labels_layers(self, combo_box) -> List[napari.layers.Labels]:
return [
layer
for layer in self.mesh_widget.main_viewer.layers
if isinstance(layer, napari.layers.Labels)
]

def _compute_bounding_box(
self, label_image: np.ndarray, label_index: int
) -> np.ndarray:
"""Compute the bounding box around the selected label."""
label_mask = label_image == label_index

segment_coordinates = np.column_stack(np.where(label_mask))
min_coordinates = np.min(segment_coordinates, axis=0)
max_coordinates = np.max(segment_coordinates, axis=0)

return np.stack([min_coordinates, max_coordinates])

0 comments on commit 6dc1b02

Please sign in to comment.