In [1]:
import warnings
warnings.filterwarnings("ignore")

import fastplotlib as fpl
import h5py
from tqdm.notebook import tqdm
import os
import numpy as np
import pygfx as gfx
from ipywidgets import IntSlider, HBox, VBox, Layout, Play, jslink, ToggleButtons, Button, IntRangeSlider
from skimage.exposure import rescale_intensity

pygfx version from git (0.4.1) and __version__ (0.5.0) don't match.
No windowing system present. Using surfaceless platform
No config found!
No config found!
Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.
Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.


In [2]:
# Per-channel brightness multipliers applied to the raw uint8 data when a
# timepoint is loaded (green is brightened, red is dimmed).
green_scale, red_scale = 2, 0.5

In [3]:
# data = np.memmap("/MCAM_data/mark/zebrafish_20220721_set3_4D_reconstructions_20230502_195635_830/memmap.mmap", dtype=np.uint8, mode="r", shape=(120, 400, 400, 400, 3), order="C")

In [4]:
import zarr

In [5]:
# Read-only handle to the 4D reconstruction store; `data[t]` is indexed per timepoint.
# Expected layout (from earlier notes, unverified): shape=(121, 400, 400, 400, 3),
# chunks=(None, 40, 40, 1), dtype="uint8" — TODO confirm against the store.
zarr_store_path = "/MCAM_data/mark/zebrafish_20220721_set3_4D_reconstructions_20230502_195635_830/data.zarr"
data = zarr.open(zarr_store_path, mode="r")

In [6]:
from pygfx.materials import VolumeRayMaterial as _VolumeRayMaterial
import pygfx as gfx
from pygfx.renderers.wgpu import register_wgpu_render_function
from pygfx.renderers.wgpu.shaders.volumeshader import VolumeRayShader

class MarkVolmeRayMaterial(gfx.VolumeRayMaterial):
    """Volume ray material with per-channel contrast limits and gamma.

    Adds three ``vec4`` uniforms (``image_min``, ``image_max``, ``gamma``) on top
    of pygfx's VolumeRayMaterial. ``MarkVolumeRayShader`` reads them to remap each
    sampled RGBA value as ``((c - image_min) / (image_max - image_min)) ** (1 / gamma)``.
    """

    uniform_type = dict(
        gfx.VolumeRayMaterial.uniform_type,
        image_min="4xf4",
        image_max="4xf4",
        gamma="4xf4",
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Identity mapping: no contrast stretch and no gamma correction.
        self.image_min = (0., 0., 0., 0.)
        self.image_max = (1., 1., 1., 1.)
        self.gamma = (1., 1., 1., 1.)

    def _read_vec4(self, key):
        """Return uniform ``key`` as a tuple of four Python floats."""
        return tuple(float(component) for component in self.uniform_buffer.data[key])

    def _write_vec4(self, key, values, fallback):
        """Store the first four entries of ``values`` (``fallback`` if None) in uniform ``key``."""
        source = fallback if values is None else values
        self.uniform_buffer.data[key] = tuple(float(source[idx]) for idx in range(4))
        self.uniform_buffer.update_range()

    @property
    def image_min(self):
        """Per-channel (RGBA) value subtracted from each sample before stretching."""
        return self._read_vec4("image_min")

    @image_min.setter
    def image_min(self, image_min):
        self._write_vec4("image_min", image_min, (0., 0., 0., 0.))

    @property
    def image_max(self):
        """Per-channel (RGBA) value that maps to 1.0 after stretching."""
        return self._read_vec4("image_max")

    @image_max.setter
    def image_max(self, image_max):
        self._write_vec4("image_max", image_max, (1., 1., 1., 1.))

    @property
    def gamma(self):
        """Per-channel (RGBA) gamma; the shader raises the stretched value to 1 / gamma."""
        return self._read_vec4("gamma")

    @gamma.setter
    def gamma(self, gamma):
        self._write_vec4("gamma", gamma, (1., 1., 1., 1.))

@register_wgpu_render_function(gfx.Volume, MarkVolmeRayMaterial)
class MarkVolumeRayShader(VolumeRayShader):
    """Volume ray shader that applies MarkVolmeRayMaterial's contrast and gamma.

    pygfx's stock volume shader source is patched textually. Every
    ``= sample_vol(`` call is rerouted to ``sample_vol_corrected``, a WGSL helper
    prepended to the shader. The helper remaps each sampled RGBA value as
    ``pow(max((c - image_min) / (image_max - image_min), 0), 1 / gamma)``, then
    swizzles it with ``.rgra``, so the output blue channel is a copy of red.

    The patch depends on the exact text of pygfx's volume shader. ``get_code``
    therefore raises if the expected call site is missing, rather than silently
    rendering without the correction.
    """

    # Call-site text in the stock shader, and its replacement.
    _SAMPLE_CALL = " = sample_vol("
    _CORRECTED_CALL = " = sample_vol_corrected("

    # WGSL helper prepended to the shader source. WGSL's pow() is undefined for a
    # negative base, so the stretched value is clamped at 0 first. Without the
    # clamp, samples below image_min give undefined results (e.g. NaN).
    _CORRECTION_FN = """
fn sample_vol_corrected(texcoord: vec3<f32>, sizef: vec3<f32>) -> vec4<f32> {
    let original_color = sample_vol(texcoord, sizef);
    let stretched = (original_color - u_material.image_min) / (u_material.image_max - u_material.image_min);
    let contrast_adjusted = max(stretched, vec4<f32>(0.0));
    let gamma_adjusted = pow(contrast_adjusted, 1.0 / u_material.gamma);
    return gamma_adjusted.rgra;
}

"""

    def get_code(self):
        """Return the stock volume shader with sampling routed through the correction.

        Raises
        ------
        RuntimeError
            If the stock shader no longer contains the ``= sample_vol(`` call this
            patch relies on (e.g. after a pygfx upgrade).
        """
        original = super().get_code()
        if self._SAMPLE_CALL not in original:
            # Fail loudly: warnings are globally ignored in this notebook, and an
            # unpatched shader would make image_min/image_max/gamma silently inert.
            raise RuntimeError(
                "MarkVolumeRayShader could not find "
                f"{self._SAMPLE_CALL!r} in pygfx's volume shader; "
                "the pygfx shader source has probably changed."
            )
        patched = original.replace(self._SAMPLE_CALL, self._CORRECTED_CALL)
        return self._CORRECTION_FN + patched
        

In [7]:
# Load timepoint 0, reorder its axes with (1, 2, 0, 3) for display, and apply the
# per-channel brightness scaling. Both channels are scaled in float32 and clipped
# to the uint8 range, so neither can wrap around on overflow. (Previously red was
# not clipped, and green relied on a uint16 cast.)
current_volume = np.transpose(data[0], (1, 2, 0, 3))

current_volume[..., 1] = (current_volume[..., 1].astype(np.float32) * green_scale).clip(0, 255).astype(np.uint8)
current_volume[..., 0] = (current_volume[..., 0].astype(np.float32) * red_scale).clip(0, 255).astype(np.uint8)

# set blue value to match red (as update_data does), before the texture is built
current_volume[..., -1] = current_volume[..., 0]

In [8]:
# 3D texture built from current_volume; update_data later overwrites its
# `.data` in place and calls `update_full()` to re-upload it.
tex = gfx.Texture(current_volume, dim=3)

In [9]:
# The volume graphic renders the texture through the custom contrast/gamma
# material; clim=(0, 255) covers the full uint8 range of the data.
volume_material = MarkVolmeRayMaterial(clim=(0, 255))
vol = gfx.Volume(gfx.Geometry(grid=tex), volume_material)

# The shader divides alpha by (image_max.a - image_min.a). With image_min.a at its
# default of 0, an alpha max of 0.1 multiplies sampled alpha by 10.
vol.material.image_max = (1., 1., 1., 0.1)

In [10]:
# Spatial extents of the raw store (data axes 1..3, dropping time and channel).
# Note: these are in the store's axis order, i.e. BEFORE the (1, 2, 0, 3)
# transpose applied to current_volume.
vol_dims = data.shape[1:-1]

In [11]:
# selectors to move through cross_sections
xy_selector = fpl.RectangleSelector(
    selection=(-5, vol_dims[0] + 20, -5, vol_dims[0] + 20), 
    limits=(-5, vol_dims[0] + 20, -5, vol_dims[0] + 20), 
    resizable=False,
    edge_color="b",
    fill_color=(0, 0, 0, 0),
)
xz_selector = fpl.RectangleSelector(
    selection=(-5, vol_dims[0] + 20, -5, vol_dims[0] + 20), 
    limits=(-5, vol_dims[0] + 20, -5, vol_dims[0] + 20), 
    resizable=False,
    edge_color="g",
    fill_color=(0, 0, 0, 0),
)
yz_selector = fpl.RectangleSelector(
    selection=(-5, vol_dims[0] + 20, -5, vol_dims[0] + 20), 
    limits=(-5, vol_dims[0] + 20, -5, vol_dims[0] + 20), 
    resizable=False,
    edge_color="r",
    fill_color=(0, 0, 0, 0),
)

for s in [xy_selector, xz_selector, yz_selector]:
    for edge in s._edges:
        edge.material.thickness = 2
    
    for vertex in s.vertices:
        vertex.visible = False

In [12]:
# Rotate two selectors out of their default orientation: yz by -90 degrees about
# the y axis, xz by +90 degrees about the x axis. Presumably this makes each
# rectangle lie in the plane of its cross-section — confirm visually.
yz_selector.rotate(-np.pi/2, "y")
xz_selector.rotate(np.pi/2, "x")

In [13]:
# 3D view of the volume: a single subplot with a 3D camera and an orbit controller.
fig = fpl.Figure(
    size=(650, 650),
    names=[["Volumetric Reconstruction"]],
    cameras="3d",
    controller_types="orbit",
)

RFBOutputContext()

Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.


In [14]:
# `vol` is a plain pygfx Volume (not a fastplotlib graphic), so it is added
# straight to the subplot's underlying pygfx scene via the private
# `_fpl_graphics_scene` attribute. `_ =` suppresses the cell output.
_ = fig[0, 0]._fpl_graphics_scene.add(vol)

In [15]:
# Show the three slice-position outlines inside the 3D view.
for _selector in (xy_selector, xz_selector, yz_selector):
    fig[0, 0].add_graphic(_selector)

In [16]:
# 2x2 grid of cross-section panels with controller_ids="sync"; the bottom-right
# panel (named "") is left empty.
slice_fig = fpl.Figure(
    shape=(2, 2),
    names=[["xy", "xz"], ["yz", ""]],
    controller_ids="sync",
    size=(650, 650),
)

RFBOutputContext()

In [17]:
# Initial cross-sections at index 0. The slicing axes match update_xy/update_xz/
# update_yz (xy -> axis 0, xz -> axis 1, yz -> axis 2 of current_volume).
# Previously xy and yz were swapped here.
i = 0

# set blue value to match red, so red renders as magenta in the slices
current_volume[..., -1] = current_volume[..., 0]

slice_fig["xy"].add_image(current_volume[i], vmin=0, vmax=255)
slice_fig["xz"].add_image(current_volume[:, i], vmin=0, vmax=255)
_ = slice_fig["yz"].add_image(current_volume[:, :, i], vmin=0, vmax=255)

In [18]:
layout = Layout(width="860px")

# The time slider and play widget span every timepoint in the store
# (previously hardcoded to a max of 120).
n_timepoints = data.shape[0]
time = IntSlider(min=0, max=n_timepoints - 1, description="time:", layout=layout)

play = Play(value=0, min=0, max=n_timepoints - 1, step=1, interval=500)
jslink((play, "value"), (time, "value"))

# The slice sliders index current_volume, which is already transposed with
# (1, 2, 0, 3), so their ranges come from its shape rather than from the
# untransposed vol_dims: xy -> axis 0, xz -> axis 1, yz -> axis 2.
xy = IntSlider(min=0, max=current_volume.shape[0] - 1, description="xy:", layout=layout)
xz = IntSlider(min=0, max=current_volume.shape[1] - 1, description="xz:", layout=layout)
yz = IntSlider(min=0, max=current_volume.shape[2] - 1, description="yz:", layout=layout)

In [19]:
# Contrast-limit range sliders (uint8 range 0-255) for the red and green channels.
clim_red, clim_green = (
    IntRangeSlider(value=(0, 255), min=0, max=255, description=label)
    for label in ("clim red:", "clim green:")
)

In [20]:
def update_data(ev):
    """Load the timepoint chosen on the time slider and refresh every view.

    Reads ``data[ev["new"]]`` and applies the per-channel brightness scaling. It
    then uploads the result to the 3D volume texture, redraws the three
    cross-sections with the current red/green contrast limits, and updates the
    time label.

    Parameters
    ----------
    ev : dict
        ipywidgets change dict; ``ev["new"]`` is the new timepoint index.
    """
    new_time = ev["new"]

    # Overwrite the shared buffer in place (no rebinding, so no `global` needed).
    current_volume[:] = np.transpose(data[new_time], (1, 2, 0, 3))

    # Scale in float32 and clip so neither channel wraps around on uint8 overflow.
    current_volume[..., 1] = (current_volume[..., 1].astype(np.float32) * green_scale).clip(0, 255).astype(np.uint8)
    current_volume[..., 0] = (current_volume[..., 0].astype(np.float32) * red_scale).clip(0, 255).astype(np.uint8)

    # set blue value to match red
    current_volume[..., -1] = current_volume[..., 0]

    vol.geometry.grid.data[:] = current_volume
    vol.geometry.grid.update_full()

    # Update the slices, applying the current contrast limits per channel.
    limits_red = clim_red.value
    limits_green = clim_green.value
    raw_slices = {
        "xy": current_volume[xy.value],
        "xz": current_volume[:, xz.value],
        "yz": current_volume[:, :, yz.value],
    }
    for name, raw_slice in raw_slices.items():
        adjusted = raw_slice.copy()
        adjusted[..., 0] = rescale_intensity(adjusted[..., 0], limits_red)
        adjusted[..., -1] = adjusted[..., 0]
        adjusted[..., 1] = rescale_intensity(adjusted[..., 1], limits_green)
        slice_fig[name].graphics[0].data = adjusted

    # update text with new time
    fig[0, 0].docks["bottom"].graphics[0].text = f"time: {new_time}"

In [21]:
def _update_volume_clims(red_limits, green_limits):
    """Write red/green contrast limits (0-255) into the volume material, normalized by 255.

    The blue component follows red, because blue is a copy of red in
    ``current_volume``. Alpha is remapped from the range [0, 0.5].
    """
    red_lo, red_hi = red_limits[0] / 255, red_limits[1] / 255
    green_lo, green_hi = green_limits[0] / 255, green_limits[1] / 255
    vol.material.image_min = (red_lo, green_lo, red_lo, 0)
    vol.material.image_max = (red_hi, green_hi, red_hi, 0.5)


def _raw_slices():
    """Return the unadjusted xy/xz/yz cross-sections at the current slider positions."""
    return {
        "xy": current_volume[xy.value],
        "xz": current_volume[:, xz.value],
        "yz": current_volume[:, :, yz.value],
    }


def clim_red_changed(change):
    """Apply new red contrast limits to the 3D volume and to each slice's red/blue channels.

    Parameters
    ----------
    change : dict
        ipywidgets change dict; ``change["new"]`` is the (low, high) red range.
    """
    limits = change["new"]
    # The green component always takes the green slider's limits; previously the
    # red low limit was written into green's image_min here.
    _update_volume_clims(limits, clim_green.value)

    for name, raw in _raw_slices().items():
        slice_fig[name].graphics[0].data[..., 0] = rescale_intensity(raw[..., 0], limits)
        slice_fig[name].graphics[0].data[..., -1] = rescale_intensity(raw[..., -1], limits)


def clim_green_changed(change):
    """Apply new green contrast limits to the 3D volume and to each slice's green channel.

    Parameters
    ----------
    change : dict
        ipywidgets change dict; ``change["new"]`` is the (low, high) green range.
    """
    limits = change["new"]
    _update_volume_clims(clim_red.value, limits)

    for name, raw in _raw_slices().items():
        slice_fig[name].graphics[0].data[..., 1] = rescale_intensity(raw[..., 1], limits)


clim_red.observe(clim_red_changed, "value")
clim_green.observe(clim_green_changed, "value")

In [22]:
def update_xy(ev):
    """Move the xy selector and redraw the xy cross-section (axis 0 of current_volume).

    The current red/green contrast limits are applied, so moving the slider keeps
    the same contrast that update_data and the clim handlers produce. (Previously
    the raw slice was shown and the clim adjustment was lost.)

    Parameters
    ----------
    ev : dict
        ipywidgets change dict; ``ev["new"]`` is the new slice index.
    """
    new = ev["new"]

    xy_selector.offset = (0, 0, new)

    xy_data = current_volume[new].copy()
    xy_data[..., 0] = rescale_intensity(xy_data[..., 0], clim_red.value)
    xy_data[..., -1] = xy_data[..., 0]
    xy_data[..., 1] = rescale_intensity(xy_data[..., 1], clim_green.value)

    slice_fig["xy"].graphics[0].data = xy_data

In [23]:
def update_xz(ev):
    """Move the xz selector and redraw the xz cross-section (axis 1 of current_volume).

    The current red/green contrast limits are applied, so moving the slider keeps
    the same contrast that update_data and the clim handlers produce. (Previously
    the raw slice was shown and the clim adjustment was lost.)

    Parameters
    ----------
    ev : dict
        ipywidgets change dict; ``ev["new"]`` is the new slice index.
    """
    new = ev["new"]

    xz_selector.offset = (0, new, 0)

    xz_data = current_volume[:, new].copy()
    xz_data[..., 0] = rescale_intensity(xz_data[..., 0], clim_red.value)
    xz_data[..., -1] = xz_data[..., 0]
    xz_data[..., 1] = rescale_intensity(xz_data[..., 1], clim_green.value)

    slice_fig["xz"].graphics[0].data = xz_data

In [24]:
def update_yz(ev):
    """Move the yz selector and redraw the yz cross-section (axis 2 of current_volume).

    The current red/green contrast limits are applied, so moving the slider keeps
    the same contrast that update_data and the clim handlers produce. (Previously
    the raw slice was shown and the clim adjustment was lost.)

    Parameters
    ----------
    ev : dict
        ipywidgets change dict; ``ev["new"]`` is the new slice index.
    """
    new = ev["new"]

    yz_selector.offset = (new, 0, 0)

    yz_data = current_volume[:, :, new].copy()
    yz_data[..., 0] = rescale_intensity(yz_data[..., 0], clim_red.value)
    yz_data[..., -1] = yz_data[..., 0]
    yz_data[..., 1] = rescale_intensity(yz_data[..., 1], clim_green.value)

    slice_fig["yz"].graphics[0].data = yz_data

In [25]:
# Reload the volume and slices whenever the time slider (or the linked Play widget) changes.
time.observe(update_data, "value")

In [26]:
# Each slice slider redraws its own cross-section and moves its selector.
for _slider, _handler in ((xy, update_xy), (xz, update_xz), (yz, update_yz)):
    _slider.observe(_handler, "value")

In [27]:
# add axes helper for orientation
# (a pygfx object, so it is added to the subplot's underlying pygfx scene directly)
axes_helper = gfx.AxesHelper(size=50, thickness=5)
_ = fig[0, 0]._fpl_graphics_scene.add(axes_helper)

In [28]:
# Label the axes helper with x/y/z text placed along each axis.
_axis_label_offsets = {
    "x": (20, -5, 0),
    "y": (-5, 20, 0),
    "z": (0, -5, 20),
}
for _label, _offset in _axis_label_offsets.items():
    fig[0, 0].add_text(_label, offset=_offset)

In [29]:
# remove axes from the 3D view (the axes helper above provides orientation instead)
fig[0,0].axes.visible = False

In [30]:
# Cross-section panels: hide the toolbar and the axes on every subplot.
for subplot in slice_fig:
    subplot.toolbar = False
    subplot.axes.visible = False

In [31]:
# hide the toolbar on the 3D view as well
fig[0, 0].toolbar = False

In [32]:
# Time label in the dock below the 3D view; update_data rewrites its text.
bottom_dock = fig[0, 0].docks["bottom"]

t_graphic = fpl.TextGraphic(text=f"time: {time.value}", anchor="middle-center")
bottom_dock.add_graphic(t_graphic)
bottom_dock.size = 25

In [33]:
# Start each cross-section at the middle of its slider range: (max + 1) // 2,
# which is 200 for 400-voxel axes (the previously hardcoded value).
# Setting .value triggers update_xy/update_xz/update_yz.
for s in [xy, xz, yz]:
    s.value = (s.max + 1) // 2

In [34]:
controller_button = ToggleButtons(options=["orbit", "fly"], description="controller:")

def update_controller(change):
    """Switch the 3D subplot's controller to the selected type ("orbit" or "fly")."""
    new = change["new"]

    fig[0, 0].controller = new

controller_button.observe(update_controller, "value")

# Saved camera pose, restored by the "reset view" button and by the last cell.
camera_state = {
    'position': np.array([-382.16958611,  314.87869604,  -53.04635314]),
    'rotation': np.array([-0.06085387, -0.84882078, -0.10019799,  0.5155196 ]),
    'scale': np.array([1., 1., 1.]),
    'reference_up': np.array([0., 1., 0.]),
    'fov': 50.0,
    'width': 477.7115315482451,
    'height': 477.7115315482451,
    'zoom': 0.75,
    'maintain_aspect': True,
    'depth_range': None
}

reset_camera_button = Button(description="reset view")

def reset_camera_and_slices(*args):
    """Restore the field of view, camera pose, slice zoom and slice positions.

    Relies on ``fov`` (the field-of-view slider created in the next cell) existing
    by the time the button is clicked.
    """
    # Keep the fov slider in sync with the saved pose (50 degrees).
    fov.value = int(camera_state["fov"])
    fig[0, 0].camera.set_state(camera_state)

    for subplot in slice_fig:
        subplot.camera.maintain_aspect = True
        subplot.auto_scale(maintain_aspect=True)

    # Return each slice slider to the middle of its range (200 for 400-voxel axes).
    for s in [xy, xz, yz]:
        s.value = (s.max + 1) // 2

reset_camera_button.on_click(reset_camera_and_slices)

In [35]:
fov = IntSlider(min=1, max=70, value=50, description="fov, degrees:")

def update_fov(change):
    """Apply the slider's field of view (degrees) to the 3D camera without animation."""
    # The controller is given the difference from the camera's current fov.
    delta = change["new"] - fig[0, 0].camera.fov
    fig[0, 0].controller.update_fov(delta, animate=False)

fov.observe(update_fov, "value")

In [36]:
# Display the 3D view and the cross-section grid side by side.
HBox([fig.show(), slice_fig.show()])

HBox(children=(JupyterWgpuCanvas(css_height='650px', css_width='650px'), JupyterWgpuCanvas(css_height='650px',…

In [48]:
# Switch the 3D renderer's blend mode.
# NOTE(review): executed out of order (In [48]) — looks like an interactive
# experiment; move it before the display cell if it is meant to be the default.
fig.renderer.blend_mode = "weighted_plus"

In [56]:
# Inspect the current per-channel contrast limits on the volume material.
vol.material.image_min, vol.material.image_max

((0.0, 0.0, 0.0, 0.0), (1.0, 1.0, 1.0, 0.5))

In [78]:
# NOTE(review): this alpha min (1.0) exceeds the alpha max set elsewhere (0.1
# initially, 0.5 after any clim change). The shader's (a - min) / (max - min)
# remap therefore inverts alpha, and the next clim change resets it to 0.
# Looks like a leftover experiment (In [78]) — confirm intent or remove.
vol.material.image_min = (0.0, 0.0, 0.0, 1.0)

Powered by:

[![fastplotlib](./logo_small.png)](https://github.com/fastplotlib/fastplotlib)

In [54]:
# Display all interactive controls stacked vertically.
VBox([play, time, xy, xz, yz, clim_red, clim_green, fov, controller_button, reset_camera_button])

VBox(children=(Play(value=0, interval=500, max=120), IntSlider(value=0, description='time:', layout=Layout(wid…

In [38]:
# Redundant: matches the interval=500 already passed when `play` was created.
play.interval = 500

In [39]:
# Apply the saved camera pose to the 3D view.
fig[0, 0].camera.set_state(camera_state)