In [None]:
import napari
from napari.qt import thread_worker
import numpy as np
import dask.array as da
from dask import delayed

from PyQt5.QtCore import Qt

from qtpy.QtWidgets import (
    QWidget, 
    QSizePolicy, 
    QLabel, 
    QGridLayout, 
    QPushButton,
    QProgressBar,
    QSpinBox,
)

# Lazy wrapper around np.load: ``delayed_load(path)`` returns a dask Delayed
# object instead of reading the file immediately.
delayed_load = delayed(np.load)

In [None]:
# Volume size (z, y, x), read from the .npy header instead of being hard-coded.
# With mmap_mode='r' the file is only memory-mapped, so no voxel data is loaded here.
N_z, N_y, N_x = np.load('neuron6.npy', mmap_mode='r').shape

print(N_z, N_y, N_x)

In [None]:
# Skeleton path nodes, one row per node. The column order is reversed, presumably
# turning stored (x, y, z) rows into the (z, y, x) order used to index the volumes
# — TODO confirm the on-disk column order.
path_coordinates = np.flip(np.load('../DemoData/demo_locations2.npy'), axis=1)

path_coordinates.shape

In [None]:
img = da.from_delayed(
    delayed_load('neuron6.npy'),
    shape=(N_z, N_y, N_x),
    dtype=float
).rechunk((100, 200, 200))

seg = da.from_delayed(
    delayed_load('pred6.npy') < 8.0,
    shape=(N_z, N_y, N_x),
    dtype=bool
).rechunk((100, 200, 200))

In [None]:
def _chunk_bounds(img_shape, center_loc, chunk_shape):
    """Compute clipped, half-open bounds of a box centred on ``center_loc``.

    On each axis the box starts at ``center - size // 2`` and spans exactly
    ``size`` elements (so odd sizes are not shortened by one), then both ends
    are clipped to ``[0, img_shape[axis]]``.

    Parameters
    ----------
    img_shape : tuple of int
        Shape of the volume being indexed.
    center_loc : array-like
        Box centre, one coordinate per axis; converted with ``astype(int)``.
    chunk_shape : tuple of int
        Requested box size per axis.

    Returns
    -------
    (starts, stops) : tuple of tuple of int
        Clipped start (inclusive) and stop (exclusive) index per axis.
    """
    center = np.asarray(center_loc).astype(int)
    starts, stops = [], []
    for c, size, limit in zip(center, chunk_shape, img_shape):
        start = int(c) - int(size) // 2
        stop = start + int(size)
        starts.append(max(start, 0))
        stops.append(min(stop, limit))
    return tuple(starts), tuple(stops)


def get_image_chunk(img: "da.Array", center_loc, chunk_shape) -> "da.Array":
    """Return the sub-volume of ``img`` of size ``chunk_shape`` around ``center_loc``.

    The box is clipped to the volume, so chunks near the borders are smaller
    than ``chunk_shape``. Slicing is lazy when ``img`` is a dask array.

    Parameters
    ----------
    img : array-like
        Volume indexed (z, y, x); only ``.shape`` and slicing are used.
    center_loc : array-like
        (z, y, x) centre of the chunk.
    chunk_shape : tuple of int
        Requested (depth, height, width) of the chunk.
    """
    starts, stops = _chunk_bounds(img.shape, center_loc, chunk_shape)
    return img[tuple(slice(start, stop) for start, stop in zip(starts, stops))]


def get_visible_nodes(img: "da.Array", center_loc, chunk_shape, path_coordinates) -> np.ndarray:
    """Return the path nodes lying inside the chunk, in chunk-local coordinates.

    Coordinates are relative to the origin of the array returned by
    :func:`get_image_chunk` for the same arguments (the clipped chunk start),
    so the nodes line up with the displayed image also at the volume borders.

    Parameters
    ----------
    img : array-like
        Volume indexed (z, y, x); only ``.shape`` is used.
    center_loc : array-like
        (z, y, x) centre of the chunk.
    chunk_shape : tuple of int
        Requested chunk size.
    path_coordinates : array-like, shape (N, 3)
        Path nodes in (z, y, x); converted with ``astype(int)``.

    Returns
    -------
    np.ndarray, shape (M, 3)
        Integer chunk-local coordinates of the M visible nodes, in path order.
    """
    starts, stops = _chunk_bounds(img.shape, center_loc, chunk_shape)
    lo = np.asarray(starts)
    hi = np.asarray(stops)
    # reshape keeps an empty path as a (0, ndim) array instead of failing on [:, k].
    coords = np.asarray(path_coordinates).astype(int).reshape(-1, lo.size)
    inside = np.all((coords >= lo) & (coords < hi), axis=1)
    return coords[inside] - lo


def get_bbox_location(img: "da.Array", center_loc, chunk_shape) -> np.ndarray:
    """Return the (y, x) corners of the clipped chunk box.

    Returns
    -------
    np.ndarray, shape (2, 2)
        ``[[start_y, start_x], [stop_y, stop_x]]``, with stops exclusive.
    """
    starts, stops = _chunk_bounds(img.shape, center_loc, chunk_shape)
    # Keep only the last two axes (y, x); the z extent is dropped.
    return np.array([starts[-2:], stops[-2:]])


class NeuronSkeletonWalker(QWidget):
    """Dock widget that steps a napari viewer along a neuron skeleton path.

    At every step the main viewer shows a sub-volume of size ``chunk_shape``
    (set by the Z/Y/X spinboxes) around the current path node: the image
    chunk, the matching segmentation chunk (hidden by default), and the path
    nodes inside it as points and as a path shape. A second "minimap" viewer
    shows the visited part of the path and the current bounding box in (y, x).

    Parameters
    ----------
    img, seg
        Volumes indexed (z, y, x); ``seg`` is shown as a labels layer.
    path_coordinates
        Path nodes, one (z, y, x) row per node.
    napari_viewer
        Viewer that displays the local chunk.
    minimap_viewer
        Viewer that displays the 2D overview.
    """

    # Seconds between automatic steps while playing. Throttling keeps the worker
    # from queueing steps faster than the GUI thread can render them.
    animation_interval = 0.1

    def __init__(self, img, seg, path_coordinates, napari_viewer, minimap_viewer) -> None:
        super().__init__()

        self.path_coordinates = path_coordinates
        self.num_locs = len(path_coordinates)
        self.current_idx = 0
        self.img = img
        self.seg = seg

        self.viewer = napari_viewer
        self.viewer.text_overlay.visible = True

        self.minimap_viewer = minimap_viewer

        # Key bindings (mirror the button layout: "Step forward" sits on the left)
        self.viewer.bind_key('Left', self.move_forward)
        self.viewer.bind_key('Right', self.move_backward)

        ### QT Layout
        grid_layout = QGridLayout()
        grid_layout.setAlignment(Qt.AlignTop)
        self.setLayout(grid_layout)

        # Step forward / backward
        self.forward_btn = QPushButton("Step forward", self)
        self.forward_btn.clicked.connect(self.move_forward)
        grid_layout.addWidget(self.forward_btn, 0, 0)

        self.backward_btn = QPushButton("Step backward", self)
        self.backward_btn.clicked.connect(self.move_backward)
        grid_layout.addWidget(self.backward_btn, 0, 1)

        # Start / Stop button
        self.play_btn = QPushButton("Start", self)
        self.play_btn.clicked.connect(self.toggle_play)
        grid_layout.addWidget(self.play_btn, 1, 0, 1, 2)
        self.running = False

        # Chunk size in X / Y / Z
        grid_layout.addWidget(QLabel("Z"), 2, 0)
        self.z_chunk_spinbox = QSpinBox()
        self.z_chunk_spinbox.setMinimum(1)
        self.z_chunk_spinbox.setMaximum(2000)
        self.z_chunk_spinbox.setValue(20)
        grid_layout.addWidget(self.z_chunk_spinbox, 2, 1)

        grid_layout.addWidget(QLabel("Y"), 3, 0)
        self.y_chunk_spinbox = QSpinBox()
        self.y_chunk_spinbox.setMinimum(1)
        self.y_chunk_spinbox.setMaximum(2000)
        self.y_chunk_spinbox.setValue(100)
        grid_layout.addWidget(self.y_chunk_spinbox, 3, 1)

        grid_layout.addWidget(QLabel("X"), 4, 0)
        self.x_chunk_spinbox = QSpinBox()
        self.x_chunk_spinbox.setMinimum(1)
        self.x_chunk_spinbox.setMaximum(2000)
        self.x_chunk_spinbox.setValue(100)
        grid_layout.addWidget(self.x_chunk_spinbox, 4, 1)

        # Progress bar (maximum=0 turns it into a busy indicator while playing)
        self.pbar = QProgressBar(self, minimum=0, maximum=1)
        self.pbar.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        grid_layout.addWidget(self.pbar, 7, 0, 1, 2)

        # Update the view when the values change in the spinboxes
        self.z_chunk_spinbox.valueChanged.connect(self._update_view)
        self.y_chunk_spinbox.valueChanged.connect(self._update_view)
        self.x_chunk_spinbox.valueChanged.connect(self._update_view)

        # Image layer
        self.image_layer = self.viewer.add_image(
            self.current_image_chunk(),
            multiscale=False,
            contrast_limits = [0, 1]
        )

        # Labels layer (hide it by default)
        self.labels_layer = self.viewer.add_labels(
            self.current_labels_chunk(),
            visible=False
        )

        # Points layer
        self.points_layer = self.viewer.add_points(
            self.current_visible_nodes(),
            face_color='red',
            size=1,
        )

        # Shapes layer (path)
        self.shapes_layer = self.viewer.add_shapes(
            self.current_visible_nodes(),
            shape_type='path',
            edge_color='red',
            edge_width=0.2
        )

        # Shapes layer (path) in the minimap viewer
        self.minimap_path_layer = self.minimap_viewer.add_shapes(
            self.current_visited_locs(),
            shape_type='path',
            edge_color='red',
            edge_width=5
        )

        # Shapes layer (bounding box) in the minimap viewer
        self.minimap_shapes_layer = self.minimap_viewer.add_shapes(
            self.current_bbox(),
            shape_type='rectangle',
            edge_color='red',
            edge_width=5,
            face_color='transparent',
            name="Current location"
        )

        self._update_view()

    @property
    def chunk_shape(self):
        """Current (z, y, x) chunk size read from the spinboxes."""
        cz = self.z_chunk_spinbox.value()
        cy = self.y_chunk_spinbox.value()
        cx = self.x_chunk_spinbox.value()
        return (cz, cy, cx)

    def _update_view(self, *_):
        """Refresh every layer and the overlay for the current index / chunk size.

        Extra positional arguments (e.g. the new value sent by a spinbox's
        ``valueChanged`` signal) are ignored.
        """
        self._update_image()
        self._update_labels()
        self._update_shapes()
        self._update_points()
        self._update_minimap_path()
        self._update_minimap_bbox()
        self._update_overlay()

    def _update_image(self):
        self.image_layer.data = self.current_image_chunk()

    def _update_labels(self):
        self.labels_layer.data = self.current_labels_chunk()

    def _update_shapes(self):
        self.shapes_layer.data = self.current_visible_nodes()

    def _update_points(self):
        self.points_layer.data = self.current_visible_nodes()

    def _update_minimap_path(self):
        self.minimap_path_layer.data = self.current_visited_locs()

    def _update_minimap_bbox(self):
        self.minimap_shapes_layer.data = self.current_bbox()

    def _update_overlay(self):
        self.viewer.text_overlay.text = f"idx={self.current_idx+1} / {self.num_locs}"

    def current_visited_locs(self):
        """(y, x) coordinates of the path from its start up to and including the current node.

        At least two nodes are returned so the minimap path shape always has a segment.
        """
        n_visited = max(self.current_idx + 1, 2)
        return self.path_coordinates[:n_visited, 1:]

    def current_bbox(self):
        """(y, x) corners of the current chunk, for the minimap rectangle."""
        return get_bbox_location(
            self.img,
            center_loc=self.path_coordinates[self.current_idx],
            chunk_shape=self.chunk_shape
        )

    def current_visible_nodes(self):
        """Path nodes inside the current chunk, in chunk-local coordinates."""
        return get_visible_nodes(
            self.img,
            center_loc=self.path_coordinates[self.current_idx],
            chunk_shape=self.chunk_shape,
            path_coordinates=self.path_coordinates
        )

    def current_image_chunk(self) -> "da.Array":
        """Image sub-volume around the current node."""
        return get_image_chunk(
            self.img,
            center_loc=self.path_coordinates[self.current_idx],
            chunk_shape=self.chunk_shape,
        )

    def current_labels_chunk(self) -> "da.Array":
        """Segmentation sub-volume around the current node."""
        return get_image_chunk(
            self.seg,
            center_loc=self.path_coordinates[self.current_idx],
            chunk_shape=self.chunk_shape,
        )

    def move_forward(self, *args, **kwargs):
        """Advance to the next node (no-op at the end). Accepts and ignores signal/key args."""
        if self.current_idx + 1 >= self.num_locs:
            return
        self.current_idx += 1
        self._update_view()

    def move_backward(self, *args, **kwargs):
        """Go back to the previous node (no-op at the start). Accepts and ignores signal/key args."""
        if self.current_idx <= 0:
            return
        self.current_idx -= 1
        self._update_view()

    @thread_worker
    def run_animation(self):
        """Background ticker for playback.

        This is a generator worker: it only yields, and ``toggle_play`` connects
        its ``yielded`` signal to ``_animation_step``, so layer updates happen on
        the Qt main thread rather than in this worker thread. It returns once
        ``self.running`` is cleared or the end of the path is reached.
        """
        import time

        while self.running and self.current_idx + 1 < self.num_locs:
            yield
            time.sleep(self.animation_interval)

    def _animation_step(self, *_):
        """Advance one node while playing; called on the main thread per worker yield."""
        if self.running:
            self.move_forward()

    def toggle_play(self, *_):
        """Start playback, or request that the running playback stops."""
        self.running = not self.running
        if not self.running:
            # The worker sees the cleared flag on its next iteration and returns,
            # which triggers thread_worker_returned to reset the UI.
            return

        self.play_btn.setText('Stop')
        self.pbar.setMaximum(0)  # Start the progress bar (busy indicator)
        worker = self.run_animation()
        worker.yielded.connect(self._animation_step)
        worker.returned.connect(self.thread_worker_returned)
        worker.start()

    def thread_worker_returned(self, return_value=None):
        """Reset the playback state and UI once the animation worker has finished."""
        self.running = False
        self.play_btn.setText('Start')
        self.pbar.setMaximum(1) # Stop the progress bar

In [None]:
# Minimap: 2D maximum projection of the whole volume along axis 0, computed once.
minimap_viewer = napari.Viewer()
minimap_viewer.add_image(
    img.max(axis=0).compute(),
    contrast_limits=[0, 1],
    multiscale=False,
)

# Main 3D viewer showing the local chunk around the current path node.
viewer = napari.Viewer(ndisplay=3)

skeleton_walker = NeuronSkeletonWalker(img, seg, path_coordinates, viewer, minimap_viewer)
viewer.window.add_dock_widget(skeleton_walker, name="Neuron walker");

In [None]:
### Todos

# 4D dataset lazily loaded with a slider?
# Camera along neuron's local orientation?