diff --git a/examples/image_widget.ipynb b/examples/image_widget.ipynb new file mode 100644 index 000000000..8a88b3b67 --- /dev/null +++ b/examples/image_widget.ipynb @@ -0,0 +1,393 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "04f453ca-d0bc-411f-b2a6-d38294dd0a26", + "metadata": {}, + "outputs": [], + "source": [ + "from fastplotlib.widgets import ImageWidget\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "id": "e933771b-f172-4fa9-b2f8-129723efb808", + "metadata": {}, + "source": [ + "# Single image sequence" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ea87f9a6-437f-41f6-8739-c957fb04bdbf", + "metadata": {}, + "outputs": [], + "source": [ + "a = np.random.rand(500, 512, 512)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8b7a6066-ff69-4bee-bae6-160fb4038393", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6d575ba7671047ca88c36606344714fa", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw = ImageWidget(\n", + " data=a, \n", + " slider_dims=[\"t\"],\n", + " vmin_vmax_sliders=True,\n", + " cmap=\"gnuplot2\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3d4cb44e-2c71-4bff-aeed-b2129f34d724", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8de187407b7746168c8d20a428d8712e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(JupyterWgpuCanvas(), IntSlider(value=0, description='dimension: t', max=499), FloatRangeSlider(…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw.show()" + ] + }, + { + "cell_type": "markdown", + "id": "9908103c-c35c-4f33-ada1-0fc357c3fd5e", + "metadata": {}, + "source": [ + "### Play with setting different window functions\n", + "\n", + "These can also be given as kwargs to `ImageWidget` during instantiation" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f278b26a-1b71-4e76-9cc7-efaddbd7b122", + "metadata": {}, + "outputs": [], + "source": [ + "# must be in the form of {dim: (func, window_size)}\n", + "iw.window_funcs = {\"t\": (np.mean, 13)}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "cb4d4b7c-919f-41c0-b1cc-b4496473d760", + "metadata": {}, + "outputs": [], + "source": [ + "# change the winow size\n", + "iw.window_funcs[\"t\"].window_size = 23" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2eea6432-4d38-4d42-ab75-f6aa1bab36f4", + "metadata": {}, + "outputs": [], + "source": [ + "# change the function\n", + "iw.window_funcs[\"t\"].func = np.max" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "afa2436f-2741-49d6-87f6-7a91a343fe0e", + "metadata": {}, + "outputs": [], + "source": [ + "# or set it again\n", + "iw.window_funcs = {\"t\": (np.min, 11)}" + ] + }, + { + "cell_type": "markdown", + "id": "aca22179-1b1f-4c51-97bf-ce2d7044e451", + "metadata": {}, + "source": [ + "# Gridplot of txy data" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "882162eb-c873-42df-a945-d5e05ad141c9", + "metadata": {}, + "outputs": [], + "source": [ + "dims = (100, 512, 512)\n", + "data = [np.random.rand(*dims) for i in range(4)]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bf9f92b6-38ad-4d78-b88c-a32d473b6462", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "005bcbc7755748cfaf0644e28beb3b0e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw = ImageWidget(\n", + " data=data, \n", + " slider_dims=[\"t\"], \n", + " # dims_order=\"txy\", # you can set this manually if dim order is not the usual\n", + " vmin_vmax_sliders=True,\n", + " names=[\"zero\", \"one\", \"two\", \"three\"],\n", + " window_funcs={\"t\": (np.mean, 5)},\n", + " cmap=\"gnuplot2\", \n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0721dc40-677e-431d-94c6-da59606199cb", + "metadata": {}, + "source": [ + "### pan-zoom controllers are all synced in a `ImageWidget`" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "403dde31-981a-46fb-b005-1bcef19c4f2c", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2b0a10be5d5b43b5a08f51a9d8f9b1dc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(JupyterWgpuCanvas(), IntSlider(value=0, description='dimension: t', max=99), FloatRangeSlider(v…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw.show()" + ] + }, + { + "cell_type": "markdown", + "id": "82545214-13c4-475e-87da-962117085834", + "metadata": {}, + "source": [ + "### Index the subplots using the names given to `ImageWidget`" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b59d95e2-9092-4915-beef-01661d164781", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "two: Subplot @ 0x7f91486a7a00\n", + " parent: None\n", + " Graphics:\n", + "\tfastplotlib.ImageGraphic @ 0x7f914881ceb0" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "iw.plot[\"two\"]" + ] + }, + { + "cell_type": "markdown", + "id": "dc727d1a-681e-4cbf-bfb2-898ceb31cbe0", + "metadata": {}, + "source": [ + "### change window functions just like before" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a8f070db-da11-4062-95aa-f19b96351ee8", + "metadata": {}, + "outputs": [], + "source": [ + "iw.window_funcs[\"t\"].func = np.max" + ] + }, + { + "cell_type": "markdown", + "id": "3e89c10f-6e34-4d63-9805-88403d487432", + "metadata": {}, + "source": [ + "## Gridplot of volumetric data" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b1587410-a08e-484c-8795-195a413d6374", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a2e4d723405345e0a7bd7b005330d018", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "dims = (256, 256, 5, 100)\n", + "data = [np.random.rand(*dims) for i in range(4)]\n", + "\n", + "iw = ImageWidget(\n", + " data=data, \n", + " slider_dims=[\"t\", \"z\"], \n", + " dims_order=\"xyzt\", # example of how you can set this for non-standard orders\n", + " vmin_vmax_sliders=True,\n", + " names=[\"zero\", \"one\", \"two\", \"three\"],\n", + " # window_funcs={\"t\": (np.mean, 5)}, # window functions can be slow when indexing multiple dims\n", + " cmap=\"gnuplot2\", \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "3ccea6c6-9580-4720-bce8-a5507cf867a3", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "78a4ed0f59734124a7f3ee23e373e64a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(JupyterWgpuCanvas(), IntSlider(value=0, description='dimension: t', max=99), IntSlider(value=0,…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw.show()" + ] + }, + { + "cell_type": "markdown", + "id": "2382809c-4c7d-4da4-9955-71d316dee46a", + "metadata": {}, + "source": [ + "### window functions, can be slow when you have \"t\" and \"z\"" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "fd4433a9-2add-417c-a618-5891371efae0", + "metadata": {}, + "outputs": [], + "source": [ + "iw.window_funcs = {\"t\": (np.mean, 11)}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3090a7e2-558e-4975-82f4-6a67ae141900", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/imagewidget.ipynb b/examples/imagewidget.ipynb new file mode 100644 index 000000000..4f56cf473 --- /dev/null +++ b/examples/imagewidget.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "111354d5-36ee-4bd5-9376-aaece6eb5b4e", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8404aaae-ba87-426c-a3cf-f3968640b8e3", + "metadata": {}, + "outputs": [], + "source": [ + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2253ccaa-670c-4c6a-9e81-48963bd1d964", + "metadata": {}, + "outputs": [], + "source": [ + "from fastplotlib.widgets.image import ImageWidget\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9775a1f0-34c3-4583-8dcc-00707095150b", + "metadata": {}, + "outputs": [], + "source": [ + "a = np.random.rand(100, 5, 512, 512)\n", + "b = np.random.rand(100, 5, 512, 512)\n", + "c = np.random.rand(100, 5, 512, 512)\n", + "d = np.random.rand(100, 5, 512, 512)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f1ade87e-c5bf-4258-9e5a-89d5cd41f348", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3783d46da0c2448a82e7209ccf48b0c8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw = ImageWidget(a, slider_axes=[0, 1], cmap=\"gnuplot2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4ad670f9-53d7-4499-9d50-5ae2ef838f25", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "91eb46fd92e8431ea22b78bfd687d0b7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(JupyterWgpuCanvas(), IntSlider(value=0, description='Axis: t', max=99), IntSlider(value=0, desc…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8b910f06-2cf4-4363-8b58-71c32d6f9c64", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9194bd04719b4665bfc33e912474659b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw = ImageWidget([a, b, c, d], slider_axes=[\"t\", \"z\"], axes_order=\"tzxy\", cmap=\"gnuplot2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f1be6cd2-9263-4da0-9280-48444dd74c1f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c803b14974eb45c6b2a17be83faccc39", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(JupyterWgpuCanvas(), IntSlider(value=0, description='Axis: t', max=99), IntSlider(value=0, desc…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "iw.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f39df8c-5248-471e-a05d-5cf667da138e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/fastplotlib/utils/functions.py b/fastplotlib/utils/functions.py index 9cd7dc6a2..650cfb053 100644 --- a/fastplotlib/utils/functions.py +++ b/fastplotlib/utils/functions.py @@ -78,9 +78,15 @@ def map_labels_to_colors(labels: iter, cmap: str, **kwargs) -> list: def quick_min_max(data: np.ndarray) -> Tuple[float, float]: - # from pyqtgraph.ImageView + # adapted from pyqtgraph.ImageView # Estimate the min/max values of *data* by subsampling. # Returns [(min, max), ...] with one item per channel + + if hasattr(data, "min") and hasattr(data, "max"): + # if value is pre-computed + if isinstance(data.min, (float, int)) and isinstance(data.max, (float, int)): + return data.min, data.max + while data.size > 1e6: ax = np.argmax(data.shape) sl = [slice(None)] * data.ndim diff --git a/fastplotlib/widgets/__init__.py b/fastplotlib/widgets/__init__.py new file mode 100644 index 000000000..553e990bf --- /dev/null +++ b/fastplotlib/widgets/__init__.py @@ -0,0 +1 @@ +from .image import ImageWidget diff --git a/fastplotlib/widgets/image.py b/fastplotlib/widgets/image.py new file mode 100644 index 000000000..c79da680a --- /dev/null +++ b/fastplotlib/widgets/image.py @@ -0,0 +1,799 @@ +from ..plot import Plot +from ..layouts import GridPlot +from ..graphics import ImageGraphic +from ..utils import quick_min_max +from ipywidgets.widgets import IntSlider, VBox, HBox, Layout, FloatRangeSlider +import numpy as np +from typing import * +from warnings import warn +from functools import partial + + +DEFAULT_DIMS_ORDER = \ + { + 2: "xy", + 3: "txy", + 4: "tzxy", + # 5: "tczxy", # no 5 dim stuff for now + } + + +def _calc_gridshape(n): + sr = np.sqrt(n) + return ( + int(np.ceil(sr)), + int(np.round(sr)) + ) + + +def _is_arraylike(obj) -> bool: + """ + Checks if the object is array-like. + For now just checks if obj has `__getitem__()` + """ + for attr in [ + "__getitem__", + "shape", + "ndim" + ]: + if not hasattr(obj, attr): + return False + + return True + + +class _WindowFunctions: + def __init__(self, func: callable, window_size: int): + self.func = func + + self._window_size = 0 + self.window_size = window_size + + @property + def window_size(self) -> int: + return self._window_size + + @window_size.setter + def window_size(self, ws: int): + if ws is None: + self._window_size = None + return + + if not isinstance(ws, int): + raise TypeError("window size must be an int") + + if ws < 3: + warn( + f"Invalid 'window size' value for function: {self.func}, " + f"setting 'window size' = None for this function. " + f"Valid values are integers >= 3." + ) + self.window_size = None + return + + if ws % 2 == 0: + ws += 1 + + self._window_size = ws + + def __repr__(self): + return f"func: {self.func}, window_size: {self.window_size}" + + +class ImageWidget: + @property + def plot(self) -> Union[Plot, GridPlot]: + """ + The plotter used by the ImageWidget. Either a simple ``Plot`` or ``GridPlot``. + """ + return self._plot + + @property + def data(self) -> List[np.ndarray]: + """data currently displayed in the widget""" + return self._data + + @property + def ndim(self) -> int: + """number of dimensions in the image data displayed in the widget""" + return self._ndim + + @property + def dims_order(self) -> List[str]: + """dimension order of the data displayed in the widget""" + return self._dims_order + + @property + def sliders(self) -> Dict[str, IntSlider]: + """the slider instances used by the widget for indexing the desired dimensions""" + return self._sliders + + @property + def slider_dims(self) -> List[str]: + """the dimensions that the sliders index""" + return self._slider_dims + + @property + def current_index(self) -> Dict[str, int]: + return self._current_index + + @current_index.setter + def current_index(self, index: Dict[str, int]): + """ + Set the current index + + Parameters + ---------- + index: Dict[str, int] + | ``dict`` for indexing each dimension, provide a ``dict`` with indices for all dimensions used by sliders + or only a subset of dimensions used by the sliders. + | example: if you have sliders for dims "t" and "z", you can pass either ``{"t": 10}`` to index to position + 10 on dimension "t" or ``{"t": 5, "z": 20}`` to index to position 5 on dimension "t" and position 20 on + dimension "z" simultaneously. + """ + if not set(index.keys()).issubset(set(self._current_index.keys())): + raise KeyError( + f"All dimension keys for setting `current_index` must be present in the widget sliders. " + f"The dimensions currently used for sliders are: {list(self.current_index.keys())}" + ) + + for k, val in index.items(): + if not isinstance(val, int): + raise TypeError("Indices for all dimensions must be int") + if val < 0: + raise IndexError("negative indexing is not supported for ImageWidget") + if val > self._dims_max_bounds[k]: + raise IndexError(f"index {val} is out of bounds for dimension '{k}' " + f"which has a max bound of: {self._dims_max_bounds[k]}") + + self._current_index.update(index) + + # can make a callback_block decorator later + self.block_sliders = True + for k in index.keys(): + self.sliders[k].value = index[k] + self.block_sliders = False + + for i, (ig, data) in enumerate(zip(self.image_graphics, self.data)): + frame = self._process_indices(data, self._current_index) + frame = self._process_frame_apply(frame, i) + ig.update_data(frame) + + def __init__( + self, + data: Union[np.ndarray, List[np.ndarray]], + dims_order: Union[str, Dict[int, str]] = None, + slider_dims: Union[str, int, List[Union[str, int]]] = None, + window_funcs: Union[int, Dict[str, int]] = None, + frame_apply: Union[callable, Dict[int, callable]] = None, + vmin_vmax_sliders: bool = False, + grid_shape: Tuple[int, int] = None, + names: List[str] = None, + **kwargs + ): + """ + A high level for displaying n-dimensional image data in conjunction with automatically generated sliders for + navigating through 1-2 selected dimensions within the image data. + + Can display a single n-dimensional image array or a grid of n-dimensional images. + + Default dimension orders: + + ======= ========== + n_dims dims order + ======= ========== + 2 "xy" + 3 "txy" + 4 "tzxy" + ======= ========== + + Parameters + ---------- + data: Union[np.ndarray, List[np.ndarray] + array-like or a list of array-like + + dims_order: Optional[Union[str, Dict[np.ndarray, str]]] + | ``str`` or a dict mapping to indicate dimension order + | a single ``str`` if ``data`` is a single array, or a list of arrays with the same dimension order + | examples: ``"xyt"``, ``"tzxy"`` + | ``dict`` mapping of ``{array_index: axis_order}`` if specific arrays have a non-default axes order. + | "array_index" is the position of the corresponding array in the data list. + | examples: ``{array_index: "tzxy", another_array_index: "xytz"}`` + + slider_dims: Optional[Union[str, int, List[Union[str, int]]]] + | The dimensions for which to create a slider + | can be a single ``str`` such as **"t"**, **"z"** or a numerical ``int`` that indexes the desired dimension + | can also be a list of ``str`` or ``int`` if multiple sliders are desired for multiple dimensions + | examples: ``"t"``, ``["t", "z"]`` + + window_funcs: Dict[Union[int, str], int] + | average one or more dimensions using a given window + | if a slider exists for only one dimension this can be an ``int``. + | if multiple sliders exist, then it must be a `dict`` mapping in the form of: ``{dimension: window_size}`` + | dimension/axes can be specified using ``str`` such as "t", "z" etc. or ``int`` that indexes the dimension + | if window_size is not an odd number, adds 1 + | use ``None`` to disable averaging for a dimension, example: ``{"t": 5, "z": None}`` + + frame_apply: Union[callable, Dict[int, callable]] + | apply a function to slices of the array before displaying the frame + | pass a single function or a dict of functions to apply to each array individually + | examples: ``{array_index: to_grayscale}``, ``{0: to_grayscale, 2: threshold_img}`` + | "array_index" is the position of the corresponding array in the data list. + | if `window_funcs` is used, then this function is applied after `window_funcs` + | this function must be a callable that returns a 2D array + | example use case: converting an RGB frame from video to a 2D grayscale frame + + grid_shape: Optional[Tuple[int, int]] + manually provide the shape for a gridplot, otherwise a square gridplot is approximated. + + names: Optional[str] + gives names to the subplots + + kwargs: Any + passed to fastplotlib.graphics.Image + """ + self._names = None + + if isinstance(data, list): + # verify that it's a list of np.ndarray + if all([_is_arraylike(d) for d in data]): + if grid_shape is None: + grid_shape = _calc_gridshape(len(data)) + + # verify that user-specified grid shape is large enough for the number of image arrays passed + elif grid_shape[0] * grid_shape[1] < len(data): + grid_shape = _calc_gridshape(len(data)) + warn(f"Invalid `grid_shape` passed, setting grid shape to: {grid_shape}") + + _ndim = [d.ndim for d in data] + + # verify that all image arrays have same number of dimensions + # sliders get messy otherwise + if not len(set(_ndim)) == 1: + raise ValueError( + f"Number of dimensions of all data arrays must match, your ndims are: {_ndim}" + ) + + self._data: List[np.ndarray] = data + self._ndim = self.data[0].ndim # all ndim must be same + + if names is not None: + if not all([isinstance(n, str) for n in names]): + raise TypeError("optinal argument `names` must be a list of str") + + if len(names) != len(self.data): + raise ValueError( + "number of `names` for subplots must be same as the number of data arrays" + ) + self._names = names + + self._plot_type = "grid" + + else: + raise TypeError( + f"If passing a list to `data` all elements must be an " + f"array-like type representing an n-dimensional image" + ) + + elif _is_arraylike(data): + self._data = [data] + self._ndim = self.data[0].ndim + + self._plot_type = "single" + else: + raise TypeError( + f"`data` must be an array-like type representing an n-dimensional image " + f"or a list of array-like representing a grid of n-dimensional images" + ) + + # default dims order if not passed + # updated later if passed + self._dims_order: List[str] = [DEFAULT_DIMS_ORDER[self.ndim]] * len(self.data) + + if dims_order is not None: + if isinstance(dims_order, str): + dims_order = dims_order.lower() + if len(dims_order) != self.ndim: + raise ValueError( + f"number of dims '{len(dims_order)} passed to `dims_order` " + f"does not match ndim '{self.ndim}' of data" + ) + self._dims_order: List[str] = [dims_order] * len(self.data) + elif isinstance(dims_order, dict): + self._dims_order: List[str] = [DEFAULT_DIMS_ORDER[self.ndim]] * len(self.data) + + # dict of {array_ix: dims_order_str} + for data_ix in list(dims_order.keys()): + if not isinstance(data_ix, int): + raise TypeError("`dims_oder` dict keys must be ") + if len(dims_order[data_ix]) != self.ndim: + raise ValueError( + f"number of dims '{len(dims_order)} passed to `dims_order` " + f"does not match ndim '{self.ndim}' of data" + ) + _do = dims_order[data_ix].lower() + # make sure the same dims are present + if not set(_do) == set(DEFAULT_DIMS_ORDER[self.ndim]): + raise ValueError( + f"Invalid `dims_order` passed for one of your arrays, " + f"valid `dims_order` for given number of dimensions " + f"can only contain the following characters: " + f"{DEFAULT_DIMS_ORDER[self.ndim]}" + ) + try: + self.dims_order[data_ix] = _do + except Exception: + raise IndexError( + f"index {data_ix} out of bounds for `dims_order`, the bounds are 0 - {len(self.data)}" + ) + else: + raise TypeError(f"`dims_order` must be a or , you have passed a: <{type(dims_order)}>") + + if not len(self.dims_order[0]) == self.ndim: + raise ValueError( + f"Number of dims specified by `dims_order`: {len(self.dims_order[0])} does not" + f" match number of dimensions in the `data`: {self.ndim}" + ) + + ao = np.array([sorted(v) for v in self.dims_order]) + + if not np.all(ao == ao[0]): + raise ValueError( + f"`dims_order` for all arrays must contain the same combination of dimensions, your `dims_order` are: " + f"{self.dims_order}" + ) + + # by default slider is only made for "t" - time dimension + if slider_dims is None: + slider_dims = "t" + + # slider for only one of the dimensions + if isinstance(slider_dims, (int, str)): + # if numerical dimension is specified + if isinstance(slider_dims, int): + ao = np.array([v for v in self.dims_order]) + if not np.all(ao == ao[0]): + raise ValueError( + f"`dims_order` for all arrays must be identical if passing in a `slider_dims` argument. " + f"Pass in a argument if the `dims_order` are different for each array." + ) + self._slider_dims: List[str] = [self.dims_order[0][slider_dims]] + + # if dimension specified by str + elif isinstance(slider_dims, str): + if slider_dims not in self.dims_order[0]: + raise ValueError( + f"if `slider_dims` is a , it must be a character found in `dims_order`. " + f"Your `dims_order` characters are: {set(self.dims_order[0])}." + ) + self._slider_dims: List[str] = [slider_dims] + + # multiple sliders, one for each dimension + elif isinstance(slider_dims, list): + self._slider_dims: List[str] = list() + + # make sure window_funcs and frame_apply are dicts if multiple sliders are desired + if (not isinstance(window_funcs, dict)) and (window_funcs is not None): + raise TypeError( + f"`window_funcs` must be a if multiple `slider_dims` are provided. You must specify the " + f"window for each dimension." + ) + if (not isinstance(frame_apply, dict)) and (frame_apply is not None): + raise TypeError( + f"`frame_apply` must be a if multiple `slider_dims` are provided. You must specify a " + f"function for each dimension." + ) + + for sdm in slider_dims: + if isinstance(sdm, int): + ao = np.array([v for v in self.dims_order]) + if not np.all(ao == ao[0]): + raise ValueError( + f"`dims_order` for all arrays must be identical if passing in a `slider_dims` argument. " + f"Pass in a argument if the `dims_order` are different for each array." + ) + # parse int to a str + self.slider_dims.append(self.dims_order[0][sdm]) + + elif isinstance(sdm, str): + if sdm not in self.dims_order[0]: + raise ValueError( + f"if `slider_dims` is a , it must be a character found in `dims_order`. " + f"Your `dims_order` characters are: {set(self.dims_order[0])}." + ) + self.slider_dims.append(sdm) + + else: + raise TypeError( + "If passing a list for `slider_dims` each element must be either an or " + ) + + else: + raise TypeError(f"`slider_dims` must a , or , you have passed a: {type(slider_dims)}") + + self.frame_apply: Dict[int, callable] = dict() + + if frame_apply is not None: + if callable(frame_apply): + self.frame_apply = {0: frame_apply} + + elif isinstance(frame_apply, dict): + self.frame_apply: Dict[int, callable] = dict.fromkeys(list(range(len(self.data)))) + + # dict of {array: dims_order_str} + for data_ix in list(frame_apply.keys()): + if not isinstance(data_ix, int): + raise TypeError("`frame_apply` dict keys must be ") + try: + self.frame_apply[data_ix] = frame_apply[data_ix] + except Exception: + raise IndexError( + f"key index {data_ix} out of bounds for `frame_apply`, the bounds are 0 - {len(self.data)}" + ) + else: + raise TypeError( + f"`frame_apply` must be a callable or , " + f"you have passed a: <{type(frame_apply)}>") + + self._window_funcs = None + self.window_funcs = window_funcs + + self._sliders: Dict[str, IntSlider] = dict() + self._vertical_sliders = list() + self._horizontal_sliders = list() + + # current_index stores {dimension_index: slice_index} for every dimension + self._current_index: Dict[str, int] = {sax: 0 for sax in self.slider_dims} + + self.vmin_vmax_sliders: List[FloatRangeSlider] = list() + + # get max bound for all data arrays for all dimensions + self._dims_max_bounds: Dict[str, int] = {k: np.inf for k in self.slider_dims} + for _dim in list(self._dims_max_bounds.keys()): + for array, order in zip(self.data, self.dims_order): + self._dims_max_bounds[_dim] = min(self._dims_max_bounds[_dim], array.shape[order.index(_dim)]) + + if self._plot_type == "single": + self._plot: Plot = Plot() + + minmax = quick_min_max(self.data[0]) + + if vmin_vmax_sliders: + data_range = np.ptp(minmax) + data_range_30p = np.ptp(minmax) * 0.3 + + minmax_slider = FloatRangeSlider( + value=minmax, + min=minmax[0] - data_range_30p, + max=minmax[1] + data_range_30p, + step=data_range / 150, + description=f"min-max", + readout = True, + readout_format = '.3f', + ) + + minmax_slider.observe( + partial(self._vmin_vmax_slider_changed, 0), + names="value" + ) + + self.vmin_vmax_sliders.append(minmax_slider) + + if ("vmin" not in kwargs.keys()) or ("vmax" not in kwargs.keys()): + kwargs["vmin"], kwargs["vmax"] = minmax + + frame = self._process_indices(self.data[0], slice_indices=self._current_index) + + self.image_graphics: List[ImageGraphic] = [self.plot.image(data=frame, **kwargs)] + + elif self._plot_type == "grid": + self._plot: GridPlot = GridPlot(shape=grid_shape, controllers="sync") + + self.image_graphics = list() + for i, (d, subplot) in enumerate(zip(self.data, self.plot)): + minmax = quick_min_max(self.data[0]) + + if self._names is not None: + name = self._names[i] + name_slider = name + else: + name = None + name_slider = "" + + if vmin_vmax_sliders: + data_range = np.ptp(minmax) + data_range_30p = np.ptp(minmax) * 0.4 + + minmax_slider = FloatRangeSlider( + value=minmax, + min=minmax[0] - data_range_30p, + max=minmax[1] + data_range_30p, + step=data_range / 150, + description=f"mm ['{name_slider}']", + readout=True, + readout_format='.3f', + ) + + minmax_slider.observe( + partial(self._vmin_vmax_slider_changed, i), + names="value" + ) + + self.vmin_vmax_sliders.append(minmax_slider) + + if ("vmin" not in kwargs.keys()) or ("vmax" not in kwargs.keys()): + kwargs["vmin"], kwargs["vmax"] = minmax + + frame = self._process_indices(d, slice_indices=self._current_index) + ig = ImageGraphic(frame, **kwargs) + subplot.add_graphic(ig) + subplot.name = name + self.image_graphics.append(ig) + + self.plot.renderer.add_event_handler(self._set_slider_layout, "resize") + + for sdm in self.slider_dims: + if sdm == "z": + # TODO: once ipywidgets plays nicely with HBox and jupyter-rfb, use vertical + # orientation = "vertical" + orientation = "horizontal" + else: + orientation = "horizontal" + + slider = IntSlider( + min=0, + max=self._dims_max_bounds[sdm] - 1, + step=1, + value=0, + description=f"dimension: {sdm}", + orientation=orientation + ) + + slider.observe( + partial(self._slider_value_changed, sdm), + names="value" + ) + + self._sliders[sdm] = slider + if orientation == "horizontal": + self._horizontal_sliders.append(slider) + elif orientation == "vertical": + self._vertical_sliders.append(slider) + + # will change later + # prevent the slider callback if value is self.current_index is changed programmatically + self.block_sliders: bool = False + + # TODO: So just stack everything vertically for now + self.widget = VBox([ + self.plot.canvas, + *list(self._sliders.values()), + *self.vmin_vmax_sliders + ]) + + # TODO: there is currently an issue with ipywidgets or jupyter-rfb and HBox doesn't work with RFB canvas + # self.widget = None + # hbox = None + # if len(self.vertical_sliders) > 0: + # hbox = HBox(self.vertical_sliders) + # + # if len(self.horizontal_sliders) > 0: + # if hbox is not None: + # self.widget = VBox([ + # HBox([self.plot.canvas, hbox]), + # *self.horizontal_sliders, + # ]) + # + # else: + # self.widget = VBox([self.plot.canvas, *self.horizontal_sliders]) + + @property + def window_funcs(self) -> Dict[str, _WindowFunctions]: + return self._window_funcs + + @window_funcs.setter + def window_funcs(self, sa: Union[int, Dict[str, int]]): + if sa is None: + self._window_funcs = None + return + + # for a single dim + elif isinstance(sa, tuple): + if len(self.slider_dims) > 1: + raise TypeError( + "Must pass dict argument to window_funcs if using multiple sliders. See the docstring." + ) + if not callable(sa[0]) or not isinstance(sa[1], int): + raise TypeError( + "Tuple argument to `window_funcs` must be in the form of (func, window_size). See the docstring." + ) + + dim_str = self.slider_dims[0] + self._window_funcs = dict() + self._window_funcs[dim_str] = _WindowFunctions(*sa) + + # for multiple dims + elif isinstance(sa, dict): + if not all([isinstance(_sa, tuple) or (_sa is None) for _sa in sa.values()]): + raise TypeError( + "dict argument to `window_funcs` must be in the form of: " + "`{dimension: (func, window_size)}`. " + "See the docstring." + ) + for v in sa.values(): + if v is not None: + if not callable(v[0]) or not (isinstance(v[1], int) or v[1] is None): + raise TypeError( + "dict argument to `window_funcs` must be in the form of: " + "`{dimension: (func, window_size)}`. " + "See the docstring." + ) + + if not isinstance(self._window_funcs, dict): + self._window_funcs = dict() + + for k in list(sa.keys()): + if sa[k] is None: + self._window_funcs[k] = None + else: + self._window_funcs[k] = _WindowFunctions(*sa[k]) + + else: + raise TypeError( + f"`window_funcs` must be of type `int` if using a single slider or a dict if using multiple sliders. " + f"You have passed a {type(sa)}. See the docstring." + ) + + def _process_indices( + self, + array: np.ndarray, + slice_indices: dict[Union[int, str], int] + ) -> np.ndarray: + """ + Get the 2D array from the given slice indices. If not returning a 2D slice (such as due to window_funcs) + then `frame_apply` must take this output and return a 2D array + + Parameters + ---------- + array: np.ndarray + array-like to get a 2D slice from + + slice_indices: dict[int, int] + dict in form of {dimension_index: slice_index} + For example if an array has shape [1000, 30, 512, 512] corresponding to [t, z, x, y]: + To get the 100th timepoint and 3rd z-plane pass: + {"t": 100, "z": 3}, or {0: 100, 1: 3} + + Returns + ------- + np.ndarray + array-like, 2D slice + + """ + indexer = [slice(None)] * self.ndim + + numerical_dims = list() + for dim in list(slice_indices.keys()): + if isinstance(dim, str): + data_ix = None + for i in range(len(self.data)): + if self.data[i] is array: + data_ix = i + break + if data_ix is None: + raise ValueError( + f"Given `array` not found in `self.data`" + ) + # get axes order for that specific array + numerical_dim = self.dims_order[data_ix].index(dim) + else: + numerical_dim = dim + + indices_dim = slice_indices[dim] + + # takes care of averaging if it was specified + indices_dim = self._get_window_indices(data_ix, numerical_dim, indices_dim) + + # set the indices for this dimension + indexer[numerical_dim] = indices_dim + + numerical_dims.append(numerical_dim) + + # apply indexing to the array + # use window function is given for this dimension + if self.window_funcs is not None: + a = array + for i, dim in enumerate(sorted(numerical_dims)): + dim_str = self.dims_order[data_ix][dim] + dim = dim - i # since we loose a dimension every iteration + _indexer = [slice(None)] * (self.ndim - i) + _indexer[dim] = indexer[dim + i] + + # if the indexer is an int, this dim has no window func + if isinstance(_indexer[dim], int): + a = a[tuple(_indexer)] + else: + # if the indices are from `self._get_window_indices` + func = self.window_funcs[dim_str].func + window = a[tuple(_indexer)] + a = func(window, axis=dim) + # a = np.mean(a[tuple(_indexer)], axis=dim) + return a + else: + return array[tuple(indexer)] + + def _get_window_indices(self, data_ix, dim, indices_dim): + if self.window_funcs is None: + return indices_dim + + else: + ix = indices_dim + + dim_str = self.dims_order[data_ix][dim] + + # if no window stuff specified for this dim + if dim_str not in self.window_funcs.keys(): + return indices_dim + + # if window stuff is set to None for this dim + # example: {"t": None} + if self.window_funcs[dim_str] is None: + return indices_dim + + window_size = self.window_funcs[dim_str].window_size + + if (window_size == 0) or (window_size is None): + return indices_dim + + half_window = int((window_size - 1) / 2) # half-window size + # get the max bound for that dimension + max_bound = self._dims_max_bounds[dim_str] + indices_dim = range(max(0, ix - half_window), min(max_bound, ix + half_window)) + return indices_dim + + def _process_frame_apply(self, array, data_ix) -> np.ndarray: + if data_ix not in self.frame_apply.keys(): + return array + if self.frame_apply[data_ix] is not None: + return self.frame_apply[data_ix](array) + + def _slider_value_changed( + self, + dimension: str, + change: dict + ): + if self.block_sliders: + return + self.current_index = {dimension: change["new"]} + + def _vmin_vmax_slider_changed( + self, + data_ix: int, + change: dict + ): + self.image_graphics[data_ix].clim = change["new"] + + def _set_slider_layout(self, *args): + w, h = self.plot.renderer.logical_size + for hs in self._horizontal_sliders: + hs.layout = Layout(width=f"{w}px") + + for vs in self._vertical_sliders: + vs.layout = Layout(height=f"{h}px") + + for mm in self.vmin_vmax_sliders: + mm.layout = Layout(width=f"{w}px") + + def show(self): + """ + Show the widget + + Returns + ------- + VBox + ``ipywidgets.VBox`` stacking the plotter and sliders in a vertical layout + """ + # start render loop + self.plot.show() + + return self.widget