Skip to content

Commit

Permalink
Multi canvas labeler (#50)
Browse files Browse the repository at this point in the history
* slow version working

* add multi labeler

* guard against validated labels not set
  • Loading branch information
kevinyamauchi committed Mar 27, 2023
1 parent 9083bd7 commit 871fb23
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: trailing-whitespace
exclude: ^.napari-hub/*
- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]
Expand Down
10 changes: 8 additions & 2 deletions src/morphometrics/_gui/_qt/multiple_viewer_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,16 @@ def _order_update(self):
return

order[-3:] = order[-2], order[-3], order[-1]
self.ortho_viewer_models[0].dims.order = order
self.ortho_viewer_models[1].dims.order = order
order = list(self.viewer.dims.order)
order[-3:] = order[-1], order[-2], order[-3]
self.ortho_viewer_models[1].dims.order = order
self.ortho_viewer_models[0].dims.order = order

# order[-3:] = order[-2], order[-3], order[-1]
# self.ortho_viewer_models[0].dims.order = order
# order = list(self.viewer.dims.order)
# order[-3:] = order[-1], order[-2], order[-3]
# self.ortho_viewer_models[1].dims.order = order

def _layer_added(self, event):
"""add layer to additional viewers and connect all required events.
Expand Down
143 changes: 122 additions & 21 deletions src/morphometrics/_gui/label_curator/label_cleaning.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Dict, List

import napari
import numpy as np
from napari.utils.colormaps import color_dict_to_colormap
from napari.utils.events import EmitterGroup, Event, EventedSet

if TYPE_CHECKING:
Expand Down Expand Up @@ -36,29 +37,49 @@ def enabled(self, enabled: bool) -> None:
self.events.enabled()

def _on_enable(self):

# check if the validated labels layer is set
validated_labels_layer_set = self._curator_model._ortho_viewers is not None

# add the lables layer events
self._selected_labels.events.changed.connect(self._on_selection_changed)
self._curator_model.labels_layer.mouse_drag_callbacks.append(
self._on_click_selection
)

# add the validated labels layer events
self._selected_validated_labels.events.changed.connect(
self._on_selection_changed
)
self._curator_model.validated_labels_layer.mouse_drag_callbacks.append(
self._on_click_selection
)
if self._curator_model.validated_labels_layer is not None:
self._selected_validated_labels.events.changed.connect(
self._on_selection_changed
)
self._curator_model.validated_labels_layer.mouse_drag_callbacks.append(
self._on_click_selection
)

labels_layer_name = self._curator_model.labels_layer.name
if validated_labels_layer_set:
validated_labels_layer_name = (
self._curator_model.validated_labels_layer.name
)
if validated_labels_layer_set:
for viewer in self._curator_model._ortho_viewers:
viewer.layers[labels_layer_name].mouse_drag_callbacks.append(
self._on_click_selection
)
if validated_labels_layer_set:
viewer.layers[
validated_labels_layer_name
].mouse_drag_callbacks.append(self._on_click_selection)
# set the labels layer coloring mode
self._curator_model.labels_layer.color_mode = "direct"
self._curator_model.labels_layer.color = (
self._curator_model._colormap_manager._default_colormap
)

self._curator_model.validated_labels_layer.color_mode = "direct"
self._curator_model.validated_labels_layer.color = (
self._curator_model._colormap_manager._default_colormap
if validated_labels_layer_set:
self._curator_model.validated_labels_layer.color_mode = "direct"

# set the colormaps
self._set_all_labels_colormamaps(
labels_colormap=self._curator_model._colormap_manager._default_colormap,
validated_labels_colormap=self._curator_model._colormap_manager._default_colormap,
)

def _on_disable(self):
Expand All @@ -76,9 +97,9 @@ def _on_click_selection(self, layer: napari.layers.Labels, event: Event):
world=True,
)

if layer is self._curator_model.labels_layer:
if layer.name == self._curator_model.labels_layer.name:
selection_set = self._selected_labels
elif layer is self._curator_model.validated_labels_layer:
elif layer.name == self._curator_model.validated_labels_layer.name:
selection_set = self._selected_validated_labels
else:
return
Expand All @@ -100,18 +121,98 @@ def _on_selection_changed(self, event: Event):
if not self.enabled:
# don't do anything if not curating
return

selected_labels = event.source
# get the indices of the validated labels
self._curator_model.labels_layer.color = (
new_colormap = (
self._curator_model._colormap_manager.create_highlighted_colormap(
list(self._selected_labels)
list(selected_labels)
)
)

self._curator_model.validated_labels_layer.color = (
self._curator_model._colormap_manager.create_highlighted_colormap(
list(self._selected_validated_labels)
if selected_labels is self._selected_labels:
self._set_labels_layer_colormap(
labels_colormap=new_colormap,
)
)
elif selected_labels is self._selected_validated_labels:
self._set_validated_labels_layer_colormap(labels_colormap=new_colormap)
else:
raise ValueError("unknown selction model")

def _set_all_labels_colormamaps(
self,
labels_colormap: Dict[int, np.ndarray],
validated_labels_colormap: Dict[int, np.ndarray],
) -> None:
"""Set the label colormaps for all viewers"""
self._set_labels_layer_colormap(labels_colormap)
self._set_validated_labels_layer_colormap(validated_labels_colormap)

def _set_labels_layer_colormap(
self, labels_colormap: Dict[int, np.ndarray]
) -> None:
self._curator_model.labels_layer.color = labels_colormap

if self._curator_model._ortho_viewers is None:
# we can return early if there aren't orthoviewers
return

# update the orthoviewers
labels_layer_name = self._curator_model.labels_layer.name

for viewer in self._curator_model._ortho_viewers:
viewer.layers[labels_layer_name].color = labels_colormap

# # set the main viewers
# self._fast_set_labels_colormap(layer=self._curator_model.labels_layer, colormap=labels_colormap)
#
# if self._curator_model._ortho_viewers is None:
# # we can return early if there aren't orthoviewers
# return
#
# # update the orthoviewers
# labels_layer_name = self._curator_model.labels_layer.name
#
# for viewer in self._curator_model._ortho_viewers:
# self._fast_set_labels_colormap(
# layer=viewer.layers[labels_layer_name],
# colormap=labels_colormap
# )

def _set_validated_labels_layer_colormap(
self, labels_colormap: Dict[int, np.ndarray]
) -> None:
if self._curator_model.validated_labels_layer is not None:
self._curator_model.validated_labels_layer.color = labels_colormap

if self._curator_model._ortho_viewers is None:
# we can return early if there aren't orthoviewers
return

# update the orthoviewers
validated_labels_layer_name = self._curator_model.validated_labels_layer.name
for viewer in self._curator_model._ortho_viewers:
viewer.layers[validated_labels_layer_name].color = labels_colormap

def _fast_set_labels_colormap(self, layer, colormap):
if layer._background_label not in colormap:
colormap[layer._background_label] = np.array([0, 0, 0, 0])
if None not in colormap:
colormap[None] = np.array([0, 0, 0, 1])

# colors = {
# label: transform_color(color_str)[0]
# for label, color_str in colormap.items()
# }

layer._color = colormap
# set the colormap
custom_colormap, label_color_index = color_dict_to_colormap(layer.color)
#
# # layer._colormap = custom_colormap
# # layer._label_color_index = label_color_index
# #
# # layer._selected_color = layer.get_color(layer.selected_label)

def merge_selected_labels(self):
"""Merge the selected label values.
Expand Down
6 changes: 4 additions & 2 deletions src/morphometrics/_gui/label_curator/label_curator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from itertools import cycle
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import glasbey
import napari
Expand Down Expand Up @@ -48,7 +48,7 @@ def _initialize_colormaps(
colors_default[:, 0:3] = colors_default[:, 0:3] * 0.7

colors_default_cylcer = cycle(colors_default)
colormap_default = {i: next(colors_default_cylcer) for i in range(5000)}
colormap_default = {i: next(colors_default_cylcer) for i in range(2000)}
colormap_default[0] = np.array([0, 0, 0, 0])
colormap_default[None] = np.array([0, 0, 0, 1])

Expand All @@ -68,11 +68,13 @@ class LabelCurator:
def __init__(
self,
viewer: napari.Viewer,
ortho_viewers: Optional[List[napari.Viewer]] = None,
labels_layer: Optional[Labels] = None,
validated_labels_layer: Optional[Labels] = None,
mode: Union[CurationMode, str] = "paint",
):
self._viewer = viewer
self._ortho_viewers = ortho_viewers
self._mode = CurationMode(mode)
self._labels_layer = None
self._validated_labels_layer = None
Expand Down
35 changes: 32 additions & 3 deletions src/morphometrics/_gui/label_curator/qt_label_curator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
QWidget,
)

from morphometrics._gui._qt.multiple_viewer_widget import MultipleViewerWidget
from morphometrics._gui.label_curator.label_cleaning import LabelCleaningModel
from morphometrics._gui.label_curator.label_curator import CurationMode, LabelCurator
from morphometrics.label.image_utils import expand_selected_labels_using_crop
Expand Down Expand Up @@ -215,10 +216,15 @@ def _on_paint_enabled_changed(self, event: Event) -> None:


class QtLabelingWidget(QWidget):
def __init__(self, viewer: napari.Viewer, parent: Optional[QWidget] = None):
def __init__(
self,
main_viewer: napari.Viewer,
ortho_viewers: Optional[List[napari.Viewer]] = None,
parent: Optional[QWidget] = None,
):
super().__init__(parent=parent)
self._viewer = viewer
self._model = LabelCurator(viewer=viewer)
self._viewer = main_viewer
self._model = LabelCurator(viewer=main_viewer, ortho_viewers=ortho_viewers)

# make the label selection widget
self._label_selection_widget = magicgui(
Expand All @@ -228,6 +234,12 @@ def __init__(self, viewer: napari.Viewer, parent: Optional[QWidget] = None):
auto_call=True,
call_button=None,
)
self._viewer.layers.events.inserted.connect(
self._label_selection_widget.reset_choices
)
self._viewer.layers.events.removed.connect(
self._label_selection_widget.reset_choices
)

# get the curation mode widget
self.curation_mode_widget = QtLabelingModeWidget(curator_model=self._model)
Expand All @@ -254,3 +266,20 @@ def _get_valid_labels_layers(self, combo_box) -> List[napari.layers.Labels]:
for layer in self._viewer.layers
if isinstance(layer, napari.layers.Labels)
]


class QtMultiCanvasLabelingWidget(MultipleViewerWidget):
def __init__(self, viewer: napari.Viewer):
super().__init__(viewer)
# self.flood_fill_widget = FloodFillWidget(
# main_viewer=viewer, ortho_viewers=self.ortho_viewer_models
# )
# self.addWidget(self.flood_fill_widget)
self.labeling_widget = QtLabelingWidget(
main_viewer=viewer, ortho_viewers=self.ortho_viewer_models
)
self.addWidget(self.labeling_widget)

self.viewer.axes.visible = True
for model in self.ortho_viewer_models:
model.axes.visible = True

0 comments on commit 871fb23

Please sign in to comment.