From 6dc1b027fcb3481e7e88458a819d2f7c8eda5036 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Fri, 22 Dec 2023 10:56:31 +0100 Subject: [PATCH] initial label mesh viewer --- .../_gui/_qt/multiple_viewer_widget.py | 251 +++++++++++++++++- 1 file changed, 249 insertions(+), 2 deletions(-) diff --git a/src/morphometrics/_gui/_qt/multiple_viewer_widget.py b/src/morphometrics/_gui/_qt/multiple_viewer_widget.py index 14b0899..c7047f5 100644 --- a/src/morphometrics/_gui/_qt/multiple_viewer_widget.py +++ b/src/morphometrics/_gui/_qt/multiple_viewer_widget.py @@ -2,14 +2,19 @@ from typing import List, Optional, Tuple import napari +import numpy as np +from magicgui import magicgui from napari import Viewer from napari.components.viewer_model import ViewerModel -from napari.layers import Image, Labels, Layer +from napari.layers import Image, Labels, Layer, Points, Surface from napari.qt import QtViewer from napari.utils.events.event import WarningEmitter +from napari_threedee.utils.napari_utils import get_dims_displayed from packaging.version import parse as parse_version from qtpy.QtCore import Qt -from qtpy.QtWidgets import QSplitter, QWidget +from qtpy.QtWidgets import QSplitter, QVBoxLayout, QWidget + +from ...utils.surface_utils import binary_mask_to_surface NAPARI_GE_4_16 = parse_version(napari.__version__) > parse_version("0.4.16") @@ -315,3 +320,245 @@ def _property_sync(self, name, event): ) finally: self._block = False + + +class MeshLabelViewerWidget(QSplitter): + """Add a viewer of a mesh of a selected label in 3D.""" + + def __init__(self, viewer: napari.Viewer, side_widget: Optional[QWidget] = None): + super().__init__() + self.main_viewer = viewer + self.surface_layer = None + self.slice_layer = None + + # make the viewer for the mesh + self.mesh_viewer_model = ViewerModel(title="model1") + + # connect the viewer sync events + self._connect_main_viewer_events() + self._connect_ortho_viewer_events() + + # make the qt viewers + qt_viewer, viewer_splitter = self._setup_ortho_view_qt( + self.mesh_viewer_model, viewer + ) + self.mesh_qt_viewer = qt_viewer + self.mesh_viewer_model.dims.ndisplay = 3 + + # make the label selection widget + self.selection_widget = LabelSelectionWidget(mesh_widget=self) + + # add the widgets + self.addWidget(self.selection_widget) + self.addWidget(viewer_splitter) + if side_widget is not None: + self.addWidget(side_widget) + + def _setup_ortho_view_qt( + self, viewer_model: List[ViewerModel], main_viewer: Viewer + ) -> Tuple[QtViewerWrapper, QSplitter]: + # create the QtViewer objects + qt_viewer = QtViewerWrapper(main_viewer, viewer_model) + + # create and populate the QSplitter for the mesh QtViewer + viewer_splitter = QSplitter() + viewer_splitter.setOrientation(Qt.Vertical) + viewer_splitter.addWidget(qt_viewer) + viewer_splitter.setContentsMargins(0, 0, 0, 0) + + return qt_viewer, viewer_splitter + + def _connect_main_viewer_events(self): + """Connect the update functions to the main viewer events. + + These events sync the ortho viewers with changes in the main viewer. + """ + self.main_viewer.dims.events.current_step.connect(self._point_update) + + def _connect_ortho_viewer_events(self): + """Connect the update functions to the orthoviewer events. + + These events sync the main viewer with changes in the ortho viewer. + """ + + def _point_update(self, event): + """Callback from when the dims point is changed.""" + + def update_surface( + self, vertices: np.ndarray, faces: np.ndarray, values: np.ndarray + ): + mesh_data = (vertices, faces, values) + if self.surface_layer is None: + surface_layer = Surface(mesh_data) + self.surface_layer = surface_layer + + # set up the lighting + self.mesh_viewer_model.layers.insert(0, self.surface_layer) + self.mesh_visual = self.mesh_qt_viewer.layer_to_visual[self.surface_layer] + self.mesh_viewer_model.camera.events.angles.connect(self._on_camera_change) + + # setup the points layer + self.main_points_layer = self.main_viewer.add_points(np.empty((1, 3))) + self.mesh_points_layer = Points(np.empty((1, 3))) + self.mesh_viewer_model.layers.insert(1, self.mesh_points_layer) + + # connect the click event and ensure the surface layer is selected + self.surface_layer.mouse_drag_callbacks.append(self._on_mesh_clicK) + self.mesh_viewer_model.layers.selection = {self.surface_layer} + else: + self.surface_layer.data = mesh_data + + def update_segment_bounding_box(self, bounding_box: np.ndarray): + self.segment_bounding_box = bounding_box + + slice_mesh = self._make_slice_mesh() + if self.slice_layer is None: + self.slice_layer = Surface(slice_mesh, opacity=0.7) + self.mesh_viewer_model.layers.insert(1, self.slice_layer) + self.mesh_viewer_model.layers.selection = {self.surface_layer} + self.main_viewer.dims.events.point.connect(self._on_dims_change) + + else: + self.slice_layer.data = slice_mesh + + def _make_slice_mesh(self): + """Make the mesh for the display the slice currently being viewed in the main viewer""" + slice_index = self.main_viewer.dims.point[0] + slice_mesh_vertices = np.array( + [ + [ + slice_index, + self.segment_bounding_box[0, 1], + self.segment_bounding_box[0, 2], + ], + [ + slice_index, + self.segment_bounding_box[0, 1], + self.segment_bounding_box[1, 2], + ], + [ + slice_index, + self.segment_bounding_box[1, 1], + self.segment_bounding_box[1, 2], + ], + [ + slice_index, + self.segment_bounding_box[1, 1], + self.segment_bounding_box[0, 2], + ], + ] + ) + slice_mesh_faces = np.array([[0, 1, 2], [0, 2, 3]]) + slice_mesh_values = np.ones((4,)) + return slice_mesh_vertices, slice_mesh_faces, slice_mesh_values + + def _on_camera_change(self, event=None): + if self.surface_layer is None: + return + + # get the view direction in layer coordinates + view_direction = np.asarray(self.mesh_viewer_model.camera.view_direction) + dims_displayed = get_dims_displayed(self.surface_layer) + layer_view_direction = np.asarray( + self.surface_layer._world_to_data_ray(view_direction) + )[dims_displayed] + + # update the node + self.mesh_visual.node.shading_filter.light_dir = -1 * layer_view_direction[::-1] + + def _on_mesh_clicK(self, layer, event): + """Mouse callback for when clicking in on the mesh in the viewer.""" + _, triangle_index = layer.get_value( + event.position, + view_direction=event.view_direction, + dims_displayed=event.dims_displayed, + world=True, + ) + + if triangle_index is None: + # if the click did not intersect the mesh, don't do anything + return + + candidate_vertices = layer.data[1][triangle_index] + candidate_points = layer.data[0][candidate_vertices] + point = np.mean(candidate_points, axis=0) + + self.main_points_layer.data = point + self.mesh_points_layer.data = point + + def _on_dims_change(self, event): + mesh_data = self._make_slice_mesh() + self.slice_layer.data = mesh_data + + +class LabelSelectionWidget(QWidget): + """Widget for selecting labels to view as a mesh.""" + + def __init__( + self, + mesh_widget: MeshLabelViewerWidget, + labels_layer: Optional[napari.layers.Labels] = None, + ) -> None: + super().__init__() + + # store the widget and layer + self.mesh_widget = mesh_widget + self.labels_layer = labels_layer + + # create the widget to select the labels layer and label index + self._label_selection_widget = magicgui( + self._set_labels_layer, + labels_layer={"choices": self._get_valid_labels_layers}, + call_button="update segment mesh", + ) + + # add widgets to layout + self.setLayout(QVBoxLayout()) + self.layout().addWidget(self._label_selection_widget.native) + + def _set_labels_layer( + self, labels_layer: napari.layers.Labels, label_index: int = 1 + ): + + self._labels_layer = labels_layer + vertices, faces, vertex_values = self._make_segment_mesh( + label_image=labels_layer.data, label_index=label_index + ) + self.mesh_widget.update_surface(vertices, faces, vertex_values) + self.mesh_widget.update_segment_bounding_box( + self._compute_bounding_box( + label_image=labels_layer.data, label_index=label_index + ) + ) + + def _make_segment_mesh( + self, label_image: np.ndarray, label_index: int + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Make a mesh from a selected label""" + label_mask = label_image == label_index + mesh = binary_mask_to_surface(label_mask, n_mesh_smoothing_iterations=0) + + vertices = mesh.vertices + faces = mesh.faces + vertex_values = np.ones((vertices.shape[0],)) + + return vertices, faces, vertex_values + + def _get_valid_labels_layers(self, combo_box) -> List[napari.layers.Labels]: + return [ + layer + for layer in self.mesh_widget.main_viewer.layers + if isinstance(layer, napari.layers.Labels) + ] + + def _compute_bounding_box( + self, label_image: np.ndarray, label_index: int + ) -> np.ndarray: + """Compute the bounding box around the selected label.""" + label_mask = label_image == label_index + + segment_coordinates = np.column_stack(np.where(label_mask)) + min_coordinates = np.min(segment_coordinates, axis=0) + max_coordinates = np.max(segment_coordinates, axis=0) + + return np.stack([min_coordinates, max_coordinates])