Skip to content

Commit

Permalink
Merge pull request #143 from kushalkolar/large-images
Browse files Browse the repository at this point in the history
HeatmapGraphic, supports dims larger than 8192
  • Loading branch information
kushalkolar committed Mar 5, 2023
2 parents 70b4908 + e464925 commit fd417d2
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 19 deletions.
4 changes: 2 additions & 2 deletions fastplotlib/graphics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .histogram import HistogramGraphic
from .line import LineGraphic
from .scatter import ScatterGraphic
from .image import ImageGraphic
from .heatmap import HeatmapGraphic
from .image import ImageGraphic, HeatmapGraphic
# from .heatmap import HeatmapGraphic
from .text import TextGraphic
from .line_collection import LineCollection, LineStack

Expand Down
4 changes: 2 additions & 2 deletions fastplotlib/graphics/features/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature
from ._data import PointsDataFeature, ImageDataFeature
from ._colors import ColorFeature, CmapFeature, ImageCmapFeature, HeatmapCmapFeature
from ._data import PointsDataFeature, ImageDataFeature, HeatmapDataFeature
from ._present import PresentFeature
from ._thickness import ThicknessFeature
from ._base import GraphicFeature, GraphicFeatureIndexable
30 changes: 29 additions & 1 deletion fastplotlib/graphics/features/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,33 @@
from pygfx import Buffer


supported_dtypes = [
np.uint8,
np.uint16,
np.uint32,
np.int8,
np.int16,
np.int32,
np.float16,
np.float32
]


def to_gpu_supported_dtype(array):
if isinstance(array, np.ndarray):
if array.dtype not in supported_dtypes:
if np.issubdtype(array.dtype, np.integer):
warn(f"converting {array.dtype} array to int32")
return array.astype(np.int32)
elif np.issubdtype(array.dtype, np.floating):
warn(f"converting {array.dtype} array to float32")
return array.astype(np.float32, copy=False)
else:
raise TypeError("Unsupported type, supported array types must be int or float dtypes")

return array


class FeatureEvent:
"""
type: <feature_name>, example: "colors"
Expand Down Expand Up @@ -43,7 +70,7 @@ def __init__(self, parent, data: Any, collection_index: int = None):
"""
self._parent = parent
if isinstance(data, np.ndarray):
data = data.astype(np.float32)
data = to_gpu_supported_dtype(data)

self._data = data

Expand Down Expand Up @@ -227,3 +254,4 @@ def _update_range_indices(self, key):
self._buffer.update_range(ix, size=1)
else:
raise TypeError("must pass int or slice to update range")

13 changes: 13 additions & 0 deletions fastplotlib/graphics/features/_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,16 @@ def _feature_changed(self, key, new_data):
event_data = FeatureEvent(type="cmap", pick_info=pick_info)

self._call_event_handlers(event_data)


class HeatmapCmapFeature(ImageCmapFeature):
"""
Colormap for HeatmapGraphic
"""

def _set(self, cmap_name: str):
self._parent._material.map.texture.data[:] = make_colors(256, cmap_name)
self._parent._material.map.texture.update_range((0, 0, 0), size=(256, 1, 1))
self.name = cmap_name

self._feature_changed(key=None, new_data=self.name)
57 changes: 47 additions & 10 deletions fastplotlib/graphics/features/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@
import numpy as np
from pygfx import Buffer, Texture

from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent


def to_float32(array):
if isinstance(array, np.ndarray):
return array.astype(np.float32, copy=False)

return array
from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent, to_gpu_supported_dtype


class PointsDataFeature(GraphicFeatureIndexable):
Expand Down Expand Up @@ -102,7 +95,7 @@ def __init__(self, parent, data: Any):
"``[x_dim, y_dim]`` or ``[x_dim, y_dim, rgb]``"
)

data = to_float32(data)
data = to_gpu_supported_dtype(data)
super(ImageDataFeature, self).__init__(parent, data)

@property
Expand All @@ -114,7 +107,7 @@ def __getitem__(self, item):

def __setitem__(self, key, value):
# make sure float32
value = to_float32(value)
value = to_gpu_supported_dtype(value)

self._buffer.data[key] = value
self._update_range(key)
Expand Down Expand Up @@ -145,3 +138,47 @@ def _feature_changed(self, key, new_data):
event_data = FeatureEvent(type="data", pick_info=pick_info)

self._call_event_handlers(event_data)


class HeatmapDataFeature(ImageDataFeature):
@property
def _buffer(self) -> List[Texture]:
return [img.geometry.grid.texture for img in self._parent.world_object.children]

def __getitem__(self, item):
return self._data[item]

def __setitem__(self, key, value):
# make sure supported type, not float64 etc.
value = to_gpu_supported_dtype(value)

self._data[key] = value
self._update_range(key)

# avoid creating dicts constantly if there are no events to handle
if len(self._event_handlers) > 0:
self._feature_changed(key, value)

def _update_range(self, key):
for buffer in self._buffer:
buffer.update_range((0, 0, 0), size=buffer.size)

def _feature_changed(self, key, new_data):
if key is not None:
key = cleanup_slice(key, self._upper_bound)
if isinstance(key, int):
indices = [key]
elif isinstance(key, slice):
indices = range(key.start, key.stop, key.step)
elif key is None:
indices = None

pick_info = {
"index": indices,
"world_object": self._parent.world_object,
"new_data": new_data
}

event_data = FeatureEvent(type="data", pick_info=pick_info)

self._call_event_handlers(event_data)
176 changes: 175 additions & 1 deletion fastplotlib/graphics/image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import *
from math import ceil
from itertools import product

import pygfx
from pygfx.utils import unpack_bitfield

from ._base import Graphic, Interaction, PreviouslyModifiedData
from .features import ImageCmapFeature, ImageDataFeature
from .features import ImageCmapFeature, ImageDataFeature, HeatmapDataFeature, HeatmapCmapFeature
from ..utils import quick_min_max


Expand Down Expand Up @@ -119,5 +122,176 @@ def _reset_feature(self, feature: str):
pass


class _ImageTile(pygfx.Image):
"""
Similar to pygfx.Image, only difference is that it contains a few properties to keep track of
row chunk index, column chunk index
"""
def _wgpu_get_pick_info(self, pick_value):
tex = self.geometry.grid
if hasattr(tex, "texture"):
tex = tex.texture # tex was a view
# This should match with the shader
values = unpack_bitfield(pick_value, wobject_id=20, x=22, y=22)
x = values["x"] / 4194304 * tex.size[0] - 0.5
y = values["y"] / 4194304 * tex.size[1] - 0.5
ix, iy = int(x + 0.5), int(y + 0.5)
return {
"index": (ix, iy),
"pixel_coord": (x - ix, y - iy),
"row_chunk_index": self.row_chunk_index,
"col_chunk_index": self.col_chunk_index
}

@property
def row_chunk_index(self) -> int:
return self._row_chunk_index

@row_chunk_index.setter
def row_chunk_index(self, index: int):
self._row_chunk_index = index

@property
def col_chunk_index(self) -> int:
return self._col_chunk_index

@col_chunk_index.setter
def col_chunk_index(self, index: int):
self._col_chunk_index = index


class HeatmapGraphic(Graphic, Interaction):
feature_events = (
"data",
"cmap",
)

def __init__(
self,
data: Any,
vmin: int = None,
vmax: int = None,
cmap: str = 'plasma',
filter: str = "nearest",
chunk_size: int = 8192,
*args,
**kwargs
):
"""
Create an Image Graphic
Parameters
----------
data: array-like
array-like, usually numpy.ndarray, must support ``memoryview()``
Tensorflow Tensors also work **probably**, but not thoroughly tested
| shape must be ``[x_dim, y_dim]``
vmin: int, optional
minimum value for color scaling, calculated from data if not provided
vmax: int, optional
maximum value for color scaling, calculated from data if not provided
cmap: str, optional, default "plasma"
colormap to use to display the data
filter: str, optional, default "nearest"
interpolation filter, one of "nearest" or "linear"
chunk_size: int, default 8192, max 8192
chunk size for each tile used to make up the heatmap texture
args:
additional arguments passed to Graphic
kwargs:
additional keyword arguments passed to Graphic
Examples
--------
.. code-block:: python
from fastplotlib import Plot
# create a `Plot` instance
plot = Plot()
# make some random 2D image data
data = np.random.rand(512, 512)
# plot the image data
plot.add_image(data=data)
# show the plot
plot.show()
"""

super().__init__(*args, **kwargs)

if chunk_size > 8192:
raise ValueError("Maximum chunk size is 8192")

self.data = HeatmapDataFeature(self, data)

row_chunks = range(ceil(data.shape[0] / chunk_size))
col_chunks = range(ceil(data.shape[1] / chunk_size))

chunks = list(product(row_chunks, col_chunks))
# chunks is the index position of each chunk

start_ixs = [list(map(lambda c: c * chunk_size, chunk)) for chunk in chunks]
stop_ixs = [list(map(lambda c: c + chunk_size, chunk)) for chunk in start_ixs]

self._world_object = pygfx.Group()

if (vmin is None) or (vmax is None):
vmin, vmax = quick_min_max(data)

self.cmap = HeatmapCmapFeature(self, cmap)
self._material = pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=self.cmap())

for start, stop, chunk in zip(start_ixs, stop_ixs, chunks):
row_start, col_start = start
row_stop, col_stop = stop

# x and y positions of the Tile in world space coordinates
y_pos, x_pos = row_start, col_start

tex_view = pygfx.Texture(data[row_start:row_stop, col_start:col_stop], dim=2).get_view(filter=filter)
geometry = pygfx.Geometry(grid=tex_view)
# material = pygfx.ImageBasicMaterial(clim=(0, 1), map=self.cmap())

img = _ImageTile(geometry, self._material)

# row and column chunk index for this Tile
img.row_chunk_index = chunk[0]
img.col_chunk_index = chunk[1]

img.position.set_x(x_pos)
img.position.set_y(y_pos)

self.world_object.add(img)

@property
def vmin(self) -> float:
"""Minimum contrast limit."""
return self._material.clim[0]

@vmin.setter
def vmin(self, value: float):
"""Minimum contrast limit."""
self._material.clim = (
value,
self._material.clim[1]
)

@property
def vmax(self) -> float:
"""Maximum contrast limit."""
return self._material.clim[1]

@vmax.setter
def vmax(self, value: float):
"""Maximum contrast limit."""
self._material.clim = (
self._material.clim[0],
value
)

def _set_feature(self, feature: str, new_data: Any, indices: Any):
pass

def _reset_feature(self, feature: str):
pass
3 changes: 0 additions & 3 deletions fastplotlib/layouts/_subplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,6 @@ def add_graphic(self, graphic, center: bool = True):
graphic.world_object.position.z = len(self._graphics)
super(Subplot, self).add_graphic(graphic, center)

if isinstance(graphic, graphics.HeatmapGraphic):
self.controller.scale.y = copysign(self.controller.scale.y, -1)

def set_axes_visibility(self, visible: bool):
"""Toggles axes visibility."""
if visible:
Expand Down

0 comments on commit fd417d2

Please sign in to comment.