Skip to content

Commit

Permalink
Fix lagging 3d view for big data in auto color mode (#6411)
Browse files Browse the repository at this point in the history
closes #6397

This array fixes fps performance issues in OpenGL introduced by #3308.
In that PR, the texture type was changed to float32 in order to directly
pass the labels values to the texture. It turns out that OpenGL
performance for float32 textures is much worse than for uint8 textures.

Here we change the code to use uint8 whenever the final number of colors
is less than 255 in automatic coloring mode, or uint16 if the number is
less than 65535.

This is achieved by transforming original data using a modulo-like
operation that avoids the background label landing on 0.

This PR introduces numba dependency, which might not be a long-term
solution. We may try to move this utility to some package that already
contains compiled code. We can revisit the decision if it causes issues
(such as a delay in supporting newer Python versions), and perhaps push
such a function to a compiled dependency such as scikit-image.

This PR also disables caching used for speedup painting until someone starts
painting. It is a significant speedup and reduces memory usage.

---------

Co-authored-by: Juan Nunez-Iglesias <jni@fastmail.com>
Co-authored-by: Matthias Bussonnier <bussonniermatthias@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people committed Nov 14, 2023
1 parent e929eeb commit 41314da
Show file tree
Hide file tree
Showing 17 changed files with 448 additions and 113 deletions.
3 changes: 2 additions & 1 deletion napari/_qt/_tests/test_qt_viewer.py
Expand Up @@ -732,7 +732,8 @@ def test_label_colors_matching_widget(qtbot, make_napari_viewer):
shape = np.array(screenshot.shape[:2])
middle_pixel = screenshot[tuple(shape // 2)]

np.testing.assert_equal(color_box_color, middle_pixel)
assert np.allclose(color_box_color, middle_pixel, atol=1), label
# there is a difference of rounding between the QtColorBox and the screenshot


def test_axes_labels(make_napari_viewer):
Expand Down
2 changes: 1 addition & 1 deletion napari/_qt/layer_controls/qt_labels_controls.py
Expand Up @@ -511,7 +511,7 @@ class QtColorBox(QWidget):
Parameters
----------
layer : napari.layers.Layer
layer : napari.layers.Labels
An instance of a napari layer.
"""

Expand Down
2 changes: 1 addition & 1 deletion napari/_vispy/_tests/test_vispy_labels_layer.py
Expand Up @@ -68,7 +68,7 @@ def test_labels_fill_slice(make_napari_viewer, array_type, qtbot):
QCoreApplication.instance().processEvents()
layer.fill((1, 10, 10), 13, refresh=True)
visual = viewer.window._qt_viewer.layer_to_visual[layer]
assert np.sum(visual.node._data) == 13
assert np.sum(visual.node._data) == 14


@skip_local_popups
Expand Down
64 changes: 57 additions & 7 deletions napari/_vispy/layers/image.py
@@ -1,4 +1,7 @@
from __future__ import annotations

import warnings
from typing import Dict, Optional

import numpy as np
from vispy.color import Colormap as VispyColormap
Expand Down Expand Up @@ -37,15 +40,25 @@ def __init__(self, custom_node: Node = None, texture_format=None) -> None:
texture_format=texture_format,
)

def get_node(self, ndisplay: int) -> Node:
def get_node(
self, ndisplay: int, dtype: Optional[np.dtype] = None
) -> Node:
# Return custom node if we have one.
if self._custom_node is not None:
return self._custom_node

# Return Image or Volume node based on 2D or 3D.
if ndisplay == 2:
return self._image_node
return self._volume_node
res = self._image_node if ndisplay == 2 else self._volume_node
if (
res.texture_format != "auto"
and dtype is not None
and _VISPY_FORMAT_TO_DTYPE[res.texture_format] != dtype
):
# it is a bug to hit this error — it is here to catch bugs
# early when we are creating the wrong nodes or
# textures for our data
raise ValueError("dtype does not match texture_format")
return res


class VispyImageLayer(VispyBaseLayer):
Expand Down Expand Up @@ -103,10 +116,16 @@ def _on_display_change(self, data=None):
parent = self.node.parent
self.node.parent = None
ndisplay = self.layer._slice_input.ndisplay
self.node = self._layer_node.get_node(ndisplay)
self.node = self._layer_node.get_node(
ndisplay, getattr(data, "dtype", None)
)

if data is None:
data = np.zeros((1,) * ndisplay, dtype=np.float32)
texture_format = self.node.texture_format
data = np.zeros(
(1,) * ndisplay,
dtype=get_dtype_from_vispy_texture_format(texture_format),
)

if self.layer._empty:
self.node.visible = False
Expand Down Expand Up @@ -135,6 +154,9 @@ def _set_node_data(self, node, data):

data = fix_data_dtype(data)
ndisplay = self.layer._slice_input.ndisplay
node = self._layer_node.get_node(
ndisplay, getattr(data, "dtype", None)
)

if ndisplay == 3 and self.layer.ndim == 2:
data = np.expand_dims(data, axis=0)
Expand All @@ -147,7 +169,9 @@ def _set_node_data(self, node, data):

# Check if ndisplay has changed current node type needs updating
if (ndisplay == 3 and not isinstance(node, VolumeNode)) or (
ndisplay == 2 and not isinstance(node, ImageNode)
ndisplay == 2
and not isinstance(node, ImageNode)
or node != self.node
):
self._on_display_change(data)
else:
Expand Down Expand Up @@ -301,3 +325,29 @@ def downsample_texture(self, data, MAX_TEXTURE_SIZE):
slices = tuple(slice(None, None, ds) for ds in downsample)
data = data[slices]
return data


_VISPY_FORMAT_TO_DTYPE: Dict[Optional[str], np.dtype] = {
"r8": np.dtype(np.uint8),
"r16": np.dtype(np.uint16),
"r32f": np.dtype(np.float32),
None: np.dtype(np.float32),
}

_DTYPE_TO_VISPY_FORMAT = {v: k for k, v in _VISPY_FORMAT_TO_DTYPE.items()}


def get_dtype_from_vispy_texture_format(format_str: str) -> np.dtype:
"""Get the numpy dtype from a vispy texture format string.
Parameters
----------
format_str : str
The vispy texture format string.
Returns
-------
dtype : numpy.dtype
The numpy dtype corresponding to the vispy texture format string.
"""
return _VISPY_FORMAT_TO_DTYPE.get(format_str, np.dtype(np.float32))
81 changes: 42 additions & 39 deletions napari/_vispy/layers/labels.py
Expand Up @@ -6,14 +6,19 @@
from vispy.color import Colormap as VispyColormap
from vispy.gloo import Texture2D
from vispy.scene.node import Node
from vispy.scene.visuals import create_visual_node
from vispy.visuals.image import ImageVisual
from vispy.visuals.shaders import Function, FunctionChain

from napari._vispy.layers.image import ImageLayerNode, VispyImageLayer
from napari._vispy.layers.image import (
_DTYPE_TO_VISPY_FORMAT,
_VISPY_FORMAT_TO_DTYPE,
ImageLayerNode,
VispyImageLayer,
get_dtype_from_vispy_texture_format,
)
from napari._vispy.utils.gl import get_max_texture_sizes
from napari._vispy.visuals.labels import LabelNode
from napari._vispy.visuals.volume import Volume as VolumeNode
from napari.utils._dtype import vispy_texture_dtype
from napari.utils.colormaps.colormap import minimum_dtype_for_labels

if TYPE_CHECKING:
from napari.layers import Labels
Expand Down Expand Up @@ -56,9 +61,9 @@
uniform sampler2D texture2D_values;
vec4 sample_label_color(float t) {
if (t == $background_value) {
return vec4(0);
}
// VisPy automatically scales uint8 and uint16 to [0, 1].
// this line fixes returns values to their original range.
t = t * $scale;
if (($use_selection) && ($selection != t)) {
return vec4(0);
Expand Down Expand Up @@ -148,7 +153,7 @@ def __init__(
colors,
use_selection=False,
selection=0.0,
background_value=0.0,
scale=1.0,
):
super().__init__(
colors=["w", "w"], controls=None, interpolation='zero'
Expand All @@ -157,7 +162,7 @@ def __init__(
auto_lookup_shader.replace('$color_map_size', str(len(colors)))
.replace('$use_selection', str(use_selection).lower())
.replace('$selection', str(selection))
.replace('$background_value', str(background_value))
.replace('$scale', str(scale))
)


Expand Down Expand Up @@ -434,7 +439,7 @@ def build_textures_from_dict(
class VispyLabelsLayer(VispyImageLayer):
layer: 'Labels'

def __init__(self, layer, node=None, texture_format='r32f') -> None:
def __init__(self, layer, node=None, texture_format='r8') -> None:
super().__init__(
layer,
node=node,
Expand Down Expand Up @@ -473,11 +478,16 @@ def _on_colormap_change(self, event=None):
mode = self.layer.color_mode

if mode == 'auto':
dtype = minimum_dtype_for_labels(self.layer.num_colors + 1)
if issubclass(dtype.type, np.integer):
scale = np.iinfo(dtype).max
else: # float32 texture
scale = 1.0
self.node.cmap = LabelVispyColormap(
colors=colormap.colors,
use_selection=colormap.use_selection,
selection=colormap.selection,
background_value=colormap.background_value,
scale=scale,
)
self.node.shared_program['texture2D_values'] = Texture2D(
colormap.colors.reshape(
Expand Down Expand Up @@ -534,47 +544,40 @@ def _on_partial_labels_update(self, event):
self.node.update()


class LabelVisual(ImageVisual):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _build_color_transform(self):
fun = FunctionChain(
None,
[
Function(self._func_templates['red_to_luminance']),
Function(self.cmap.glsl_map),
],
)
return fun


class LabelLayerNode(ImageLayerNode):
def __init__(self, custom_node: Node = None, texture_format=None):
self._custom_node = custom_node
self._setup_nodes(texture_format)

def _setup_nodes(self, texture_format):
self._image_node = LabelNode(
None
if (texture_format is None or texture_format == 'auto')
else np.array([[0.0]], dtype=np.float32),
else np.array(
[[0.0]],
dtype=get_dtype_from_vispy_texture_format(texture_format),
),
method='auto',
texture_format=texture_format,
)

self._volume_node = VolumeNode(
np.zeros((1, 1, 1), dtype=np.float32),
np.zeros(
(1, 1, 1),
dtype=get_dtype_from_vispy_texture_format(texture_format),
),
clim=[0, 2**23 - 1],
texture_format=texture_format,
)

def get_node(self, ndisplay: int, dtype=None) -> Node:
res = self._image_node if ndisplay == 2 else self._volume_node

BaseLabel = create_visual_node(LabelVisual)


class LabelNode(BaseLabel): # type: ignore [valid-type,misc]
def _compute_bounds(self, axis, view):
if self._data is None:
return None
elif axis > 1: # noqa: RET505
return 0, 0
else:
return 0, self.size[axis]
if (
res.texture_format != "auto"
and dtype is not None
and _VISPY_FORMAT_TO_DTYPE[res.texture_format] != dtype
):
self._setup_nodes(_DTYPE_TO_VISPY_FORMAT[dtype])
return self.get_node(ndisplay, dtype)
return res
2 changes: 1 addition & 1 deletion napari/_vispy/utils/visual.py
Expand Up @@ -183,7 +183,7 @@ def get_view_direction_in_scene_coordinates(
d2 = p1 - p0

# in 3D world coordinates
d3 = d2[0:3]
d3 = d2[:3]
d4 = d3 / np.linalg.norm(d3)

# data are ordered xyz on vispy Volume
Expand Down
4 changes: 3 additions & 1 deletion napari/_vispy/visuals/image.py
@@ -1,8 +1,10 @@
from vispy.scene.visuals import Image as BaseImage

from napari._vispy.visuals.util import TextureMixin


# If data is not present, we need bounds to be None (see napari#3517)
class Image(BaseImage):
class Image(TextureMixin, BaseImage):
def _compute_bounds(self, axis, view):
if self._data is None:
return None
Expand Down
40 changes: 40 additions & 0 deletions napari/_vispy/visuals/labels.py
@@ -0,0 +1,40 @@
from typing import TYPE_CHECKING, Optional, Tuple

from vispy.scene.visuals import create_visual_node
from vispy.visuals.image import ImageVisual
from vispy.visuals.shaders import Function, FunctionChain

from napari._vispy.visuals.util import TextureMixin

if TYPE_CHECKING:
from vispy.visuals.visual import VisualView


class LabelVisual(TextureMixin, ImageVisual):
"""Visual subclass displaying a 2D array of labels."""

def _build_color_transform(self) -> FunctionChain:
"""Build the color transform function chain."""
funcs = [
Function(self._func_templates['red_to_luminance']),
Function(self.cmap.glsl_map),
]

return FunctionChain(
funcs=funcs,
)


BaseLabel = create_visual_node(LabelVisual)


class LabelNode(BaseLabel): # type: ignore [valid-type,misc]
def _compute_bounds(
self, axis: int, view: 'VisualView'
) -> Optional[Tuple[float, float]]:
if self._data is None:
return None
elif axis > 1: # noqa: RET505
return 0, 0
else:
return 0, self.size[axis]
30 changes: 30 additions & 0 deletions napari/_vispy/visuals/util.py
@@ -0,0 +1,30 @@
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from vispy.visuals.visual import Visual
else:

class Visual:
pass


class TextureMixin(Visual):
"""Store texture format passed to VisPy classes.
We need to refer back to the texture format, but VisPy
stores it in a private attribute — ``node._texture.internalformat``.
This mixin is added to our Node subclasses to avoid having to
access private VisPy attributes.
"""

def __init__(self, *args, texture_format: Optional[str], **kwargs) -> None: # type: ignore [no-untyped-def]
super().__init__(*args, texture_format=texture_format, **kwargs)
# classes using this mixin may be frozen dataclasses.
# we save the texture format between unfreeze/freeze.
self.unfreeze()
self._texture_format = texture_format
self.freeze()

@property
def texture_format(self) -> Optional[str]:
return self._texture_format
4 changes: 3 additions & 1 deletion napari/_vispy/visuals/volume.py
@@ -1,5 +1,7 @@
from vispy.scene.visuals import Volume as BaseVolume

from napari._vispy.visuals.util import TextureMixin

FUNCTION_DEFINITIONS = """
// the tolerance for testing equality of floats with floatEqual and floatNotEqual
const float equality_tolerance = 1e-8;
Expand Down Expand Up @@ -200,7 +202,7 @@
rendering_methods['translucent_categorical'] = TRANSLUCENT_CATEGORICAL_SNIPPETS


class Volume(BaseVolume):
class Volume(TextureMixin, BaseVolume):
# add the new rendering method to the snippets dict
_shaders = shaders
_rendering_methods = rendering_methods

0 comments on commit 41314da

Please sign in to comment.