In [6]:
from typing import Any

import cv2
import matplotlib.pyplot as plt
import nrrd
import numpy as np
from ipywidgets import Button, HBox, IntSlider, Output, Text, VBox


class ImageViewer:
    """Notebook widget that shows an image volume slice by slice with a mask overlay.

    The user types a file path into a text box and presses "Submit". Two kinds
    of input are handled:

    * paths ending in ``.nrrd``: the image is read with ``nrrd.read`` and the
      mask is read from the same path with the suffix changed to ``.seg.nrrd``;
    * anything else: the file is read as a grayscale image with OpenCV and the
      mask is read from the same path with ``/images/`` replaced by ``/masks/``.

    The last array axis is treated as the slice axis and is driven by the slider.
    """

    def __init__(self) -> None:
        """Create the widgets and connect their callbacks; nothing is shown yet."""
        # Loaded arrays; both stay None until a file has been read successfully.
        self.img, self.mask = None, None
        # Matplotlib handles of the current figure, reused when the slice changes.
        self.fig, self.ax, self.ax_image, self.ax_mask = None, None, None, None
        self.output = Output()
        self.text_input = Text(placeholder="Enter path to nrrd or png file")
        self.button = Button(description="Submit")
        # The slider starts disabled with an empty range; it is enabled once an
        # image has been loaded and its slice count is known.
        self.slider = IntSlider(
            orientation="horizontal",
            description="Slice",
            value=0,
            min=0,
            max=0,
            disabled=True,
        )
        self.slider.layout.width = "50%"

        self.button.on_click(self._load_image)
        self.slider.observe(self._update_slice, names="value")

    def display(self) -> None:
        """Render the path input, the slice slider and the figure output area.

        Uses the ``display`` function that IPython makes available in notebook
        cells.
        """
        app_layout = VBox(
            [
                HBox([self.text_input, self.button]),
                self.slider,
                self.output,
            ]
        )
        display(app_layout)

    def _reset_view(self) -> None:
        """Close the current figure (if any), forget loaded data, disable the slider."""
        if self.fig is not None:
            # pyplot keeps a reference to every figure it created until the
            # figure is closed, so repeated loads would otherwise pile up.
            plt.close(self.fig)
        self.img, self.mask = None, None
        self.fig, self.ax, self.ax_image, self.ax_mask = None, None, None, None
        self.slider.disabled = True

    @staticmethod
    def _read_image_and_mask(path: str):
        """Read the image at ``path`` and its companion mask.

        Returns:
            ``(img, mask)`` arrays. Images read through OpenCV get a trailing
            axis of length 1 so they can be indexed like a one-slice volume.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
            Errors raised by ``nrrd.read`` propagate unchanged.
        """
        if path.endswith(".nrrd"):
            img, _ = nrrd.read(path)
            # Swap only the trailing suffix; str.replace would also rewrite any
            # earlier ".nrrd" occurring elsewhere in the path.
            mask, _ = nrrd.read(path[: -len(".nrrd")] + ".seg.nrrd")
            return img, mask

        mask_path = path.replace("/images/", "/masks/")
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image file: {path}")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"could not read mask file: {mask_path}")
        return img[:, :, np.newaxis], mask[:, :, np.newaxis]

    @staticmethod
    def _mask_overlay(mask_slice):
        """Map non-zero mask entries to 1 and zero entries to NaN.

        The NaN entries leave the underlying image visible when the result is
        drawn on top of it with ``imshow``.
        """
        return np.where(mask_slice, 1, np.nan)

    def _load_image(self, b: Button) -> None:
        """Button callback: load the file named in the text box and show slice 0.

        Any error raised while loading or plotting is printed into the output
        area and the viewer is reset to its empty state.

        Args:
            b: The button that was clicked (unused).
        """
        self.output.clear_output()
        # The previous figure has just been removed from the output area, so
        # release it and the old arrays before doing anything else.
        self._reset_view()

        path: str = self.text_input.value.strip()
        if not path:
            with self.output:
                print("Error: Please enter a file path.")
            return

        try:
            img, mask = self._read_image_and_mask(path)
            if img.shape != mask.shape:
                # The overlay indexes image and mask with the same slice index.
                raise ValueError(
                    f"image shape {img.shape} does not match mask shape {mask.shape}"
                )

            self.img = img
            self.mask = mask

            # The figure handles are still None here, so the slider callback
            # fired by these assignments is a no-op.
            self.slider.value = 0
            self.slider.max = self.img.shape[-1] - 1
            self.slider.disabled = False

            with self.output:
                self.fig, self.ax = plt.subplots(figsize=(7, 7))
                self.ax.axis("off")

                self.ax_image = self.ax.imshow(self.img[:, :, 0], cmap="gray")
                self.ax_mask = self.ax.imshow(
                    self._mask_overlay(self.mask[:, :, 0]), cmap="jet", alpha=0.5
                )
                self.ax.set_title(f"Slice 1 / {self.slider.max + 1}")
                plt.show(self.fig)

        except Exception as e:
            with self.output:
                print(f"An error occurred during image loading: {e}")
            # Also closes a figure that was created before the failure.
            self._reset_view()

    def _display_slice(self, idx: int) -> None:
        """Show slice ``idx`` of the image together with the same slice of the mask.

        Does nothing when no image is currently loaded.

        Args:
            idx: Zero-based index along the last axis of the loaded arrays.
        """
        if self.img is None or self.mask is None or self.ax is None:
            return

        self.ax_image.set_data(self.img[:, :, idx])
        # Index the mask with idx as well so the overlay follows the slider.
        self.ax_mask.set_data(self._mask_overlay(self.mask[:, :, idx]))
        self.ax.set_title(f"Slice {idx + 1} / {self.slider.max + 1}")
        if self.fig is not None:
            # Explicitly request a redraw so the update does not depend on
            # pyplot's interactive auto-draw being active.
            self.fig.canvas.draw_idle()

    def _update_slice(self, change: Any) -> None:
        """Slider observer: display the slice given by the slider's new value."""
        self._display_slice(change.new)

In [7]:
# Select the interactive widget backend (ipympl) before any figure is created.
%matplotlib widget

# NOTE(review): the two paths below are presumably sample inputs to paste into
# the viewer's text box; they are machine-specific absolute paths, so
# substitute the location of your own data.
# /Users/cameronjohnson/Downloads/14806362/Rider/R3/R3.nrrd
# /Users/cameronjohnson/Documents/repos/aaa-seg/data/images/D1_21.png
# Build the viewer and render its widgets below this cell.
image_viewer = ImageViewer()
image_viewer.display()

VBox(children=(HBox(children=(Text(value='', placeholder='Enter path to nrrd or png file'), Button(description…