In [1]:
import itk
import itkwidgets
import ipywidgets as widgets

import numpy as np

In [2]:
class ImageViewer:
    
    def __init__(self, image, image_axis=None, non_spatial_axis=None):
        self.image = image
        self.image_axis = image_axis
        self.non_spatial_axis = non_spatial_axis
        
        if not self.image_axis:
            self.image_axis = ('T', 'C', 'X', 'Y')
            
        if not self.non_spatial_axis:
            self.non_spatial_axis = ['T', 'C']
            
        assert len(image.shape) == len(self.image_axis), "Image axis and image shape must hav the same length."

        # Log widget
        self.out = widgets.Output()
        self.log_widget = widgets.Accordion(children=[self.out])
        self.log_widget.set_title(0, 'Log')
    
        # Init positions
        self.axis_positions = np.zeros((len(self.non_spatial_axis), ), dtype='int')
        
        # Init positions label
        self.position_label = widgets.Label(value="")
        self._update_position_label()
        
        # Init image viewer
        self.image_widget = None
        self._update_image()
        
        # Init sliders
        sliders = []
        for axis in self.non_spatial_axis:
            idx = self.image_axis.index(axis)
            n = self.image.shape[idx]

            slider = widgets.IntSlider(value=0, min=0, max=n - 1, step=1, description=f'Axis {axis}', continuous_update=True)
            slider.name = axis
            slider.observe(self._position_sliders_change, names='value')
            
            play_widget = widgets.Play(value=0, min=0, max=n - 1, step=1, continuous_update=True)
            widgets.link((play_widget, 'value'), (slider, 'value'))
            sliders.append(widgets.HBox([play_widget, slider]))
 
        # Create layout
        self.slider_widgets = widgets.VBox(sliders)
        self.control_widget = widgets.VBox([self.position_label, self.slider_widgets])
        
        self.widget = widgets.VBox([self.control_widget, self.image_widget, self.log_widget])

    def log(self, message):
        with self.out:
            print(str(message))
        
    def _position_sliders_change(self, change):
        axis = change['owner'].name
        idx = self.image_axis.index(axis)
        self.axis_positions[idx] = change['new']
        self._update_position_label()
        self._update_image()

    def _update_position_label(self):
        label_string = "Axis Position: "
        label_string += ' | '.join([f'{axis}: {self.axis_positions[i]}' for i, axis in enumerate(self.non_spatial_axis)])
        self.position_label.value = label_string
        
    def _update_image(self):
        indexes = [self.axis_positions[i] for i in range(len(self.non_spatial_axis))]
        indexes = [slice(idx, idx+1) for idx in indexes]
        self.current_image = self.image[indexes]
        self.current_image = np.squeeze(self.current_image)
        
        if not self.image_widget:
            self.image_widget = itkwidgets.view(self.current_image, ui_collapsed=False, annotations=True,
                                                interpolation=False, cmap='Viridis (matplotlib)', mode='v',
                                                shadow=True, slicing_planes=False, gradient_opacity=0.22)
        else:
            self.image_widget.image = self.current_image

In [3]:
# Create image with random noise
image = np.random.random((50, 3, 512, 512))

# Define axis
image_axis = ('T', 'C', 'X', 'Y')
non_spatial_axis = ['T', 'C']

v = ImageViewer(image, image_axis=image_axis, non_spatial_axis=non_spatial_axis)
v.widget

Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.


VBox(children=(VBox(children=(Label(value='Axis Position: T: 0 | C: 0'), VBox(children=(HBox(children=(Play(va…