diff --git a/examples/dev/direct-colormap-aliasing-test.py b/examples/dev/direct-colormap-aliasing-test.py index 0008adb150b..24197b37cf7 100644 --- a/examples/dev/direct-colormap-aliasing-test.py +++ b/examples/dev/direct-colormap-aliasing-test.py @@ -56,7 +56,7 @@ texel_pos_img = np.zeros((1, nb_steps, 4)) texel_pos_img[..., -1] = 1 # alpha for k in range(nb_steps): - grid_position = hash2d_get(k + 1, keys, values)[0] + grid_position = hash2d_get(k + 1, keys)[0] # divide by shape and set to RG values like in shader (tex coords) texel_pos_img[:, k, :2] = (np.array(grid_position) + 0.5) / tex_shape diff --git a/napari/_vispy/_tests/test_vispy_labels.py b/napari/_vispy/_tests/test_vispy_labels.py new file mode 100644 index 00000000000..10a46fa158a --- /dev/null +++ b/napari/_vispy/_tests/test_vispy_labels.py @@ -0,0 +1,121 @@ +from itertools import product +from unittest.mock import patch + +import numpy as np +import pytest + +from napari._vispy.layers.labels import ( + MAX_LOAD_FACTOR, + PRIME_NUM_TABLE, + build_textures_from_dict, + hash2d_get, + idx_to_2d, +) + + +@pytest.fixture(scope='module', autouse=True) +def mock_max_texture_size(): + """When running tests in this file, pretend max texture size is 2^16.""" + with patch('napari._vispy.layers.labels.MAX_TEXTURE_SIZE', 2**16): + yield + + +def test_idx_to_2d(): + assert idx_to_2d(0, (100, 100)) == (0, 0) + assert idx_to_2d(1, (100, 100)) == (0, 1) + assert idx_to_2d(101, (100, 100)) == (1, 1) + assert idx_to_2d(521, (100, 100)) == (5, 21) + assert idx_to_2d(100 * 100 + 521, (100, 100)) == (5, 21) + + +def test_build_textures_from_dict(): + keys, values, collision = build_textures_from_dict( + {1: (1, 1, 1, 1), 2: (2, 2, 2, 2)} + ) + assert not collision + assert keys.shape == (37, 37) + assert values.shape == (37, 37, 4) + assert keys[0, 1] == 1 + assert keys[0, 2] == 2 + assert np.array_equiv(values[0, 1], (1, 1, 1, 1)) + assert np.array_equiv(values[0, 2], (2, 2, 2, 2)) + + +def test_build_textures_from_dict_too_many_labels(monkeypatch): + with pytest.raises(MemoryError): + build_textures_from_dict( + {i: (i, i, i, i) for i in range(1001)}, shape=(10, 10) + ) + monkeypatch.setattr( + "napari._vispy.layers.labels.PRIME_NUM_TABLE", [[61], [127]] + ) + with pytest.raises(MemoryError): + build_textures_from_dict( + {i: (i, i, i, i) for i in range((251**2) // 2)}, + ) + + +def test_size_of_texture_square(): + count = int(127 * 127 * MAX_LOAD_FACTOR) - 1 + keys, values, *_ = build_textures_from_dict( + {i: (i, i, i, i) for i in range(count)} + ) + assert keys.shape == (127, 127) + assert values.shape == (127, 127, 4) + + +def test_size_of_texture_rectangle(): + count = int(128 * 128 * MAX_LOAD_FACTOR) + 5 + keys, values, *_ = build_textures_from_dict( + {i: (i, i, i, i) for i in range(count)} + ) + assert keys.shape == (251, 127) + assert values.shape == (251, 127, 4) + + +def test_build_textures_from_dict_collision(): + keys, values, collision = build_textures_from_dict( + {1: (1, 1, 1, 1), 26: (2, 2, 2, 2), 27: (3, 3, 3, 3)}, shape=(5, 5) + ) + assert collision + assert keys.shape == (5, 5) + assert keys[0, 1] == 1 + assert keys[0, 2] == 26 + assert keys[0, 3] == 27 + assert np.array_equiv(values[0, 1], (1, 1, 1, 1)) + assert np.array_equiv(values[0, 2], (2, 2, 2, 2)) + assert np.array_equiv(values[0, 3], (3, 3, 3, 3)) + + assert hash2d_get(1, keys) == (0, 1) + assert hash2d_get(26, keys) == (0, 2) + assert hash2d_get(27, keys) == (0, 3) + + +def test_collide_keys(): + base_keys = [x * y for x, y in product(PRIME_NUM_TABLE[0], 
repeat=2)] + colors = {0: (0, 0, 0, 0), 1: (1, 1, 1, 1)} + colors.update({i + 10: (1, 0, 0, 1) for i in base_keys}) + colors.update({2 * i + 10: (0, 1, 0, 1) for i in base_keys}) + keys, values, collision = build_textures_from_dict(colors) + assert not collision + assert keys.shape == (37, 61) + assert values.shape == (37, 61, 4) + + +def test_collide_keys2(): + base_keys = [x * y for x, y in product(PRIME_NUM_TABLE[0], repeat=2)] + [ + x * y for x, y in product(PRIME_NUM_TABLE[0], PRIME_NUM_TABLE[1]) + ] + colors = {0: (0, 0, 0, 0), 1: (1, 1, 1, 1)} + colors.update({i + 10: (1, 0, 0, 1) for i in base_keys}) + colors.update({2 * i + 10: (0, 1, 0, 1) for i in base_keys}) + + # enforce collision for collision table of size 31 + colors.update({31 * i + 10: (0, 0, 1, 1) for i in base_keys}) + # enforce collision for collision table of size 29 + colors.update({29 * i + 10: (0, 0, 1, 1) for i in base_keys}) + + keys, values, collision = build_textures_from_dict(colors) + assert collision + assert keys.shape == (37, 37) + assert values.shape == (37, 37, 4) diff --git a/napari/_vispy/layers/image.py b/napari/_vispy/layers/image.py index f4090ffbb26..6e703f79b5c 100644 --- a/napari/_vispy/layers/image.py +++ b/napari/_vispy/layers/image.py @@ -167,7 +167,7 @@ def _on_depiction_change(self): if isinstance(self.node, VolumeNode): self.node.raycasting_mode = str(self.layer.depiction) - def _on_colormap_change(self): + def _on_colormap_change(self, event=None): self.node.cmap = VispyColormap(*self.layer.colormap) def _update_mip_minip_cutoff(self): diff --git a/napari/_vispy/layers/labels.py b/napari/_vispy/layers/labels.py index e932c59d2f2..d1d2dcd1a64 100644 --- a/napari/_vispy/layers/labels.py +++ b/napari/_vispy/layers/labels.py @@ -1,3 +1,7 @@ +from itertools import product +from math import ceil, isnan, log2, sqrt +from typing import Dict, Optional, Tuple, Union + import numpy as np from vispy.color import Colormap as VispyColormap from vispy.gloo import Texture2D @@ -7,7 +11,36 @@ from vispy.visuals.shaders import Function, FunctionChain from napari._vispy.layers.image import ImageLayerNode, VispyImageLayer +from napari._vispy.utils.gl import get_max_texture_sizes from napari._vispy.visuals.volume import Volume as VolumeNode +from napari.utils._dtype import vispy_texture_dtype + +# We use table sizes that are prime numbers near powers of 2. +# For each power of 2, we keep three candidate sizes. This allows us to +# maximize the chances of finding a collision-free table for a given set of +# keys (which we typically know at build time). 
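+# For example, the first row (37, 31, 29) brackets 2**5 == 32, while the
+# last row (65521, 65519, 65497) sits just below 2**16 == 65536.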
+PRIME_NUM_TABLE = [ + [37, 31, 29], + [61, 59, 53], + [127, 113, 109], + [251, 241, 239], + [509, 503, 499], + [1021, 1019, 1013], + [2039, 2029, 2027], + [4093, 4091, 4079], + [8191, 8179, 8171], + [16381, 16369, 16363], + [32749, 32719, 32717], + [65521, 65519, 65497], +] + +START_TWO_POWER = 5 + +MAX_LOAD_FACTOR = 0.25 + +MAX_TEXTURE_SIZE = None + +ColorTuple = Tuple[float, float, float, float] low_disc_lookup_shader = """ uniform sampler2D texture2D_LUT; @@ -39,6 +72,7 @@ uniform sampler2D texture2D_keys; uniform sampler2D texture2D_values; uniform vec2 LUT_shape; +uniform int color_count; vec4 sample_label_color(float t) { @@ -68,8 +102,14 @@ // we get a different value: // - if it's the empty key, exit; // - otherwise, it's a hash collision: continue searching - while ((abs(found - t) > 1e-8) && (abs(found - empty) > 1e-8)) { - t = t + 1; + float initial_t = t; + int count = 0; + while ((abs(found - initial_t) > 1e-8) && (abs(found - empty) > 1e-8)) { + count = count + 1; + t = initial_t + float(count); + if (count >= color_count) { + return vec4(0); + } // same as above vec2 pos = vec2( mod(int(t / LUT_shape.y), LUT_shape.x), @@ -120,60 +160,255 @@ def __init__( self, use_selection=False, selection=0.0, + collision=True, ): colors = ['w', 'w'] # dummy values, since we use our own machinery super().__init__(colors, controls=None, interpolation='zero') - self.glsl_map = direct_lookup_shader.replace( - '$use_selection', str(use_selection).lower() - ).replace('$selection', str(selection)) + self.glsl_map = ( + direct_lookup_shader.replace( + "$use_selection", str(use_selection).lower() + ) + .replace("$selection", str(selection)) + .replace("$collision", str(collision).lower()) + ) -def idx_to_2D(idx, shape): +def idx_to_2d(idx, shape): """ From a 1D index generate a 2D index that fits the given shape. The 2D index will wrap around line by line and back to the beginning. """ - return (idx // shape[1]) % shape[0], (idx % shape[1]) + return int((idx // shape[1]) % shape[0]), int(idx % shape[1]) -def hash2d_get(key, keys, values, empty_val=0): +def hash2d_get(key, keys, empty_val=0): """ Given a key, retrieve its location in the keys table. """ - pos = idx_to_2D(key, keys.shape) + pos = idx_to_2d(key, keys.shape) initial_key = key - while keys[pos] != key and keys[pos] != empty_val: + while keys[pos] != initial_key and keys[pos] != empty_val: if key - initial_key > keys.size: raise KeyError('label does not exist') key += 1 - pos = idx_to_2D(key, keys.shape) - return pos if keys[pos] == key else None + pos = idx_to_2d(key, keys.shape) + return pos if keys[pos] == initial_key else None -def hash2d_set(key, value, keys, values, empty_val=0): +def hash2d_set( + key: Union[float, np.floating], + value: ColorTuple, + keys: np.ndarray, + values: np.ndarray, + empty_val=0, +) -> bool: """ Set a value in the 2d hashmap, wrapping around to avoid collision. 
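+
+    Returns True if the slot derived from ``key`` was already occupied and
+    the value had to be placed by linear probing (i.e. a collision occurred).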
""" - if key is None: - return - pos = idx_to_2D(key, keys.shape) + if key is None or isnan(key): + return False + pos = idx_to_2d(key, keys.shape) initial_key = key + collision = False while keys[pos] != empty_val: + collision = True if key - initial_key > keys.size: raise OverflowError('too many labels') key += 1 - pos = idx_to_2D(key, keys.shape) - keys[pos] = key + pos = idx_to_2d(key, keys.shape) + keys[pos] = initial_key values[pos] = value + return collision + + +def _get_shape_from_keys( + keys: np.ndarray, first_dim_index: int, second_dim_index: int +) -> Optional[Tuple[int, int]]: + """Get the smallest hashmap size without collisions, if any. + + This function uses precomputed prime numbers from PRIME_NUM_TABLE. + + For each index, it gets a list of prime numbers close to + ``2**(index + START_TWO_POWER)`` (where ``START_TWO_POWER=5``), that is, + the smallest table is close to ``32 * 32``. + + The function then iterates over all combinations of prime numbers from the + lists and checks for a combination that has no collisions for the + given keys, returning that combination. + + If no combination can be found, returns None. -def build_textures_from_dict(color_dict, empty_val=0, shape=(1000, 1000)): - keys = np.full(shape, empty_val, dtype=np.float32) - values = np.zeros(shape + (4,), dtype=np.float32) + Although keys that collide for all table combinations are rare, they are + possible: see ``test_collide_keys`` and ``test_collide_keys2``. + + Parameters + ---------- + keys: np.ndarray + array of keys to be inserted into the hashmap, + used for collision detection + first_dim_index: int + index for first dimension of PRIME_NUM_TABLE + second_dim_index: int + index for second dimension of PRIME_NUM_TABLE + + Returns + shp : 2-tuple of int, optional + If a table shape can be found that has no collisions for the given + keys, return that shape. Otherwise, return None. + """ + for fst_size, snd_size in product( + PRIME_NUM_TABLE[first_dim_index], + PRIME_NUM_TABLE[second_dim_index], + ): + fst_crd = (keys // snd_size) % fst_size + snd_crd = keys % snd_size + + collision_set = set(zip(fst_crd, snd_crd)) + if len(collision_set) == len(set(keys)): + return fst_size, snd_size + return None + + +def _get_shape_from_dict( + color_dict: Dict[float, Tuple[float, float, float, float]] +) -> Tuple[int, int]: + """Compute the shape of a 2D hashmap based on the keys in `color_dict`. + + This function finds indices for the first and second dimensions of a + table in PRIME_NUM_TABLE based on a target load factor of 0.125-0.25, + then calls `_get_shape_from_keys` based on those indices. + + This is quite a low load-factor, but, ultimately, the hash table + textures are tiny compared to most datasets, so we choose these + factors to minimize the chance of collisions and trade a bit of GPU + memory for speed. + """ + keys = np.array([x for x in color_dict if x is not None], dtype=np.int64) + + size = len(keys) / MAX_LOAD_FACTOR + size_sqrt = sqrt(size) + size_log2 = log2(size_sqrt) + max_idx = len(PRIME_NUM_TABLE) - 1 + max_size = PRIME_NUM_TABLE[max_idx][0] ** 2 + fst_dim = min(max(int(ceil(size_log2)) - START_TWO_POWER, 0), max_idx) + snd_dim = min(max(int(round(size_log2, 0)) - START_TWO_POWER, 0), max_idx) + + if len(keys) > max_size: + raise MemoryError( + f'Too many labels: napari supports at most {max_size} labels, ' + f'got {len(keys)}.' 
+ ) + + shp = _get_shape_from_keys(keys, fst_dim, snd_dim) + if shp is None and snd_dim < max_idx: + # if we still have room to grow, try the next size up to get a + # collision-free table + shp = _get_shape_from_keys(keys, fst_dim, snd_dim + 1) + if shp is None: + # at this point, if there's still collisions, we give up and return + # the largest possible table given these indices and the target load + # factor. + # (To see a set of keys that cause collision, + # and land on this branch, see test_collide_keys2.) + shp = PRIME_NUM_TABLE[fst_dim][0], PRIME_NUM_TABLE[snd_dim][0] + return shp + + +def get_shape_from_dict(color_dict): + global MAX_TEXTURE_SIZE + if MAX_TEXTURE_SIZE is None: + MAX_TEXTURE_SIZE = get_max_texture_sizes()[0] + + shape = _get_shape_from_dict(color_dict) + + if MAX_TEXTURE_SIZE is not None and ( + shape[0] > MAX_TEXTURE_SIZE or shape[1] > MAX_TEXTURE_SIZE + ): + raise MemoryError( + f'Too many labels. GPU does not support textures of this size.' + f' Requested size is {shape[0]}x{shape[1]}, but maximum supported' + f' size is {MAX_TEXTURE_SIZE}x{MAX_TEXTURE_SIZE}' + ) + return shape + + +def build_textures_from_dict( + color_dict: Dict[float, ColorTuple], + empty_val=0, + shape=None, + use_selection=False, + selection=0.0, +) -> Tuple[np.ndarray, np.ndarray, bool]: + """ + This function construct hash table for fast lookup of colors. + It uses pair of textures. + First texture is a table of keys, used to determine position, + second is a table of values. + + The procedure of selection table and collision table is + implemented in hash2d_get function. + + Parameters + ---------- + color_dict: Dict[float, Tuple[float, float, float, float]] + Dictionary from labels to colors + empty_val: float + Value to use for empty cells in the hash table + shape: Optional[Tuple[int, int]] + Shape of the hash table. + If None, it is calculated from the number of + labels using _get_shape_from_dict + use_selection: bool + If True, only the selected label is shown. + The generated colormap is single-color of size (1, 1) + selection: float + used only if use_selection is True. + Determines the selected label. + + Returns + ------- + keys: np.ndarray + Texture of keys for the hash table + values: np.ndarray + Texture of values for the hash table + collision: bool + True if there are collisions in the hash table + """ + if use_selection: + keys = np.full((1, 1), selection, dtype=vispy_texture_dtype) + values = np.zeros((1, 1, 4), dtype=vispy_texture_dtype) + values[0, 0] = color_dict[selection] + return keys, values, False + + if len(color_dict) > 2**31 - 2: + raise MemoryError( + f'Too many labels ({len(color_dict)}). Maximum supported number of labels is 2^31-2' + ) + + if shape is None: + shape = get_shape_from_dict(color_dict) + + if len(color_dict) > shape[0] * shape[1]: + raise MemoryError( + f'Too many labels ({len(color_dict)}). Maximum supported number of labels for the given shape is {shape[0] * shape[1]}' + ) + + keys = np.full(shape, empty_val, dtype=vispy_texture_dtype) + values = np.zeros(shape + (4,), dtype=vispy_texture_dtype) + visited = set() + collision = False for key, value in color_dict.items(): - hash2d_set(key, value, keys, values) - return keys, values + key_ = vispy_texture_dtype(key) + if key_ in visited: + # input int keys are unique but can map to the same float. + # if so, we ignore all but the first appearance. 
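+            # e.g. 2**24 and 2**24 + 1 both become 16777216.0 as float32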
+ continue + visited.add(key_) + collision |= hash2d_set(key_, value, keys, values) + + return keys, values, collision class VispyLabelsLayer(VispyImageLayer): @@ -185,13 +420,13 @@ def __init__(self, layer, node=None, texture_format='r32f') -> None: layer_node_class=LabelLayerNode, ) - self.layer.events.color_mode.connect(self._on_colormap_change) + # self.layer.events.color_mode.connect(self._on_colormap_change) self.layer.events.labels_update.connect(self._on_partial_labels_update) self.layer.events.selected_label.connect(self._on_colormap_change) self.layer.events.show_selected_label.connect(self._on_colormap_change) def _on_rendering_change(self): - # overriding the Image method so we can maintain the same old rendering name + # overriding the Image method, so we can maintain the same old rendering name if isinstance(self.node, VolumeNode): rendering = self.layer.rendering self.node.method = ( @@ -206,6 +441,12 @@ def _on_colormap_change(self, event=None): # self.layer.colormap is a labels_colormap, which is an evented model # from napari.utils.colormaps.Colormap (or similar). If we use it # in our constructor, we have access to the texture data we need + if ( + event is not None + and event.type == 'selected_label' + and not self.layer.show_selected_label + ): + return colormap = self.layer.colormap mode = self.layer.color_mode @@ -221,10 +462,16 @@ def _on_colormap_change(self, event=None): color_dict = ( self.layer.color ) # TODO: should probably account for non-given labels - key_texture, val_texture = build_textures_from_dict(color_dict) + key_texture, val_texture, collision = build_textures_from_dict( + color_dict, + use_selection=colormap.use_selection, + selection=colormap.selection, + ) + self.node.cmap = DirectLabelVispyColormap( use_selection=colormap.use_selection, selection=colormap.selection, + collision=collision, ) # note that textures have to be transposed here! 
self.node.shared_program['texture2D_keys'] = Texture2D( @@ -297,6 +544,6 @@ def _compute_bounds(self, axis, view): if self._data is None: return None elif axis > 1: # noqa: RET505 - return (0, 0) + return 0, 0 else: - return (0, self.size[axis]) + return 0, self.size[axis] diff --git a/napari/layers/labels/labels.py b/napari/layers/labels/labels.py index d18519af667..dc558be19dc 100644 --- a/napari/layers/labels/labels.py +++ b/napari/layers/labels/labels.py @@ -45,9 +45,10 @@ ) from napari.layers.utils.color_transformations import transform_color from napari.layers.utils.layer_utils import _FeatureTable -from napari.utils._dtype import normalize_dtype +from napari.utils._dtype import normalize_dtype, vispy_texture_dtype from napari.utils.colormaps import ( direct_colormap, + ensure_colormap, label_colormap, ) from napari.utils.events import Event @@ -756,12 +757,12 @@ def color_mode(self, color_mode: Union[str, LabelColorMode]): self._cached_labels = None # invalidates labels cache self._color_mode = color_mode if color_mode == LabelColorMode.AUTO: - super()._set_colormap(self._random_colormap) + self._colormap = ensure_colormap(self._random_colormap) else: - super()._set_colormap(self._direct_colormap) + self._colormap = ensure_colormap(self._direct_colormap) self._selected_color = self.get_color(self.selected_label) self.events.color_mode() - self.events.colormap() + self.events.colormap() # If remove this emitting, connect shader update to color_mode self.events.selected_label() self.refresh() @@ -774,6 +775,7 @@ def show_selected_label(self): def show_selected_label(self, show_selected): self._show_selected_label = show_selected self.colormap.use_selection = show_selected + self.colormap.selection = self.selected_label self.events.show_selected_label(show_selected_label=show_selected) self._cached_labels = None self.refresh() @@ -863,7 +865,7 @@ def _to_vispy_texture_dtype(data): float32 as it can represent all input values (though not losslessly, see https://github.com/napari/napari/issues/6084). """ - return data.astype(np.float32) + return vispy_texture_dtype(data) def _update_slice_response(self, response: _ImageSliceResponse) -> None: """Override to convert raw slice data to displayed label colors.""" @@ -881,7 +883,7 @@ def _partial_labels_refresh(self): # Keep only the dimensions that correspond to the current view updated_slice = tuple( - [self._updated_slice[index] for index in dims_displayed] + self._updated_slice[index] for index in dims_displayed ) offset = [axis_slice.start for axis_slice in updated_slice] diff --git a/napari/utils/_dtype.py b/napari/utils/_dtype.py index c77ae904b81..4f9d3516228 100644 --- a/napari/utils/_dtype.py +++ b/napari/utils/_dtype.py @@ -111,3 +111,6 @@ def get_dtype_limits(dtype_spec) -> Tuple[float, float]: else: raise TypeError(f'Unrecognized or non-numeric dtype: {dtype_spec}') return info.min, info.max + + +vispy_texture_dtype = np.float32
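For reference, the structure this patch builds is a small open-addressing hash table stored in a pair of 2D textures (keys and RGBA values), probed linearly both on the CPU (`hash2d_set`/`hash2d_get`) and in the fragment shader. The following is a minimal, self-contained sketch of that scheme, not the napari implementation itself: the 5x5 toy table, the empty sentinel and the demo keys are illustrative only, and the NaN handling, probe limits and GPU texture-size checks of the real helpers are left out.

import numpy as np

# Toy table for the demo only; the real code picks prime shapes from
# PRIME_NUM_TABLE and checks them against the GPU's maximum texture size.
SHAPE = (5, 5)
EMPTY = 0.0


def idx_to_2d(idx, shape):
    # Wrap a scalar key onto (row, column) coordinates of the 2D table.
    return int((idx // shape[1]) % shape[0]), int(idx % shape[1])


def hash2d_set(key, value, keys, values):
    # Open addressing: probe key, key + 1, key + 2, ... until a free slot.
    pos = idx_to_2d(key, keys.shape)
    probe = key
    collision = False
    while keys[pos] != EMPTY:
        collision = True
        probe += 1
        pos = idx_to_2d(probe, keys.shape)
    keys[pos] = key  # the original key is stored, not the probed one
    values[pos] = value
    return collision


def hash2d_get(key, keys):
    # Probe until the key is found or an empty slot proves it is absent.
    probe = key
    pos = idx_to_2d(probe, keys.shape)
    while keys[pos] != key and keys[pos] != EMPTY:
        probe += 1
        pos = idx_to_2d(probe, keys.shape)
    return pos if keys[pos] == key else None


keys = np.full(SHAPE, EMPTY, dtype=np.float32)
values = np.zeros(SHAPE + (4,), dtype=np.float32)

# 1 and 26 (= 1 + 5 * 5) land on the same texel of a 5x5 table, so the
# second insert is pushed one texel to the right, as in
# test_build_textures_from_dict_collision above.
hash2d_set(1.0, (1.0, 0.0, 0.0, 1.0), keys, values)
hash2d_set(26.0, (0.0, 1.0, 0.0, 1.0), keys, values)

assert hash2d_get(1.0, keys) == (0, 1)
assert hash2d_get(26.0, keys) == (0, 2)

In the real code the table shape is chosen from PRIME_NUM_TABLE so that the load factor stays at or below MAX_LOAD_FACTOR (0.25), which keeps such probe chains short and gives the shader's probe loop a small upper bound (`color_count`).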