In [None]:
import os

import cv2
import ipywidgets as widgets
import matplotlib.pyplot as plt
import nrrd
import numpy as np
from IPython.display import clear_output, display
from ipywidgets import HBox, Layout, VBox


class CTViewer:
    """
    An ipywidgets-based CT viewer for visualizing raw and processed CT data
    for machine learning projects.

    The UI consists of a scan dropdown above three side-by-side panels:

    1. the NRRD volume with an optional segmentation overlay,
    2. the PNG slices with an optional mask overlay,
    3. the contour PNGs.

    Expected on-disk layout, as read by ``_load_scan_data``::

        <data_folder>/<scan>/<scan>.nrrd       CT volume
        <data_folder>/<scan>/<scan>_seg.nrrd   segmentation volume
        <data_folder>/<scan>/*.png             slices; names containing
                                               'mask' / 'contour' are treated
                                               as masks / contours

    Volumes are sliced along their third axis (``volume[:, :, index]``).
    """

    def __init__(self, data_folder):
        """
        Initializes the CTViewer and loads the first scan, if any.

        Args:
            data_folder (str): The path to the folder containing subfolders of CT scans.

        Raises:
            OSError: If ``data_folder`` cannot be scanned (e.g. it does not exist).
        """
        self.data_folder = data_folder
        # Sort so the dropdown order (and the initially selected scan) is
        # deterministic; os.scandir yields entries in arbitrary order.
        with os.scandir(data_folder) as entries:
            self.scan_folders = sorted(e.path for e in entries if e.is_dir())
        self.scan_names = [os.path.basename(f) for f in self.scan_folders]

        # UI Components
        self.scan_dropdown = widgets.Dropdown(options=self.scan_names, description='CT Scan:')
        self.nrrd_slice_slider = widgets.IntSlider(description='Slice:')
        self.png_slice_slider = widgets.IntSlider(description='Slice:')
        self.toggle_nrrd_mask_button = widgets.Button(description='Toggle NRRD Mask')
        self.toggle_png_mask_button = widgets.Button(description='Toggle PNG Mask')

        # Output areas are created up front so that widget events arriving
        # before display() is called still have somewhere to render.
        self.nrrd_plot_output = widgets.Output()
        self.png_plot_output = widgets.Output()
        self.contour_plot_output = widgets.Output()

        self.nrrd_mask_visible = True
        self.png_mask_visible = True

        # True while _load_scan_data is resetting the sliders; the slice
        # observers skip redrawing during that window.
        self._loading = False

        # Observers
        self.scan_dropdown.observe(self._on_scan_change, names='value')
        self.nrrd_slice_slider.observe(self._on_nrrd_slice_change, names='value')
        self.png_slice_slider.observe(self._on_png_slice_change, names='value')
        self.toggle_nrrd_mask_button.on_click(self._on_toggle_nrrd_mask)
        self.toggle_png_mask_button.on_click(self._on_toggle_png_mask)

        # Data placeholders
        self.ct_nrrd = None
        self.seg_nrrd = None
        self.png_files = None
        self.mask_png_files = None
        self.contour_png_files = None

        # Initial data load (skipped when the folder holds no scan subfolders,
        # in which case the dropdown value is None and the panels stay empty).
        if self.scan_names:
            self._load_scan_data(self.scan_dropdown.value)

    @staticmethod
    def _discover_png_files(scan_path):
        """
        Classifies the PNG files of one scan folder.

        Args:
            scan_path (str): Folder to list.

        Returns:
            tuple[list[str], list[str], list[str]]: ``(images, masks, contours)``
            as sorted full paths. A name containing 'contour' is a contour,
            otherwise a name containing 'mask' is a mask, otherwise it is a
            plain image. Non-``.png`` files are ignored.
        """
        images, masks, contours = [], [], []
        # One directory listing; sorting names is equivalent to sorting the
        # joined paths because they all share the same prefix.
        for name in sorted(os.listdir(scan_path)):
            if not name.endswith('.png'):
                continue
            path = os.path.join(scan_path, name)
            if 'contour' in name:
                contours.append(path)
            elif 'mask' in name:
                masks.append(path)
            else:
                images.append(path)
        return images, masks, contours

    @staticmethod
    def _file_at(files, index):
        """Returns ``files[index]``, or None when the list is empty/None or the index is out of range."""
        if files and 0 <= index < len(files):
            return files[index]
        return None

    def _read_gray(self, files, index):
        """
        Reads ``files[index]`` as a grayscale image.

        Returns:
            The image array, or None when there is no such file or OpenCV
            could not read it (``cv2.imread`` returns None on failure).
        """
        path = self._file_at(files, index)
        if path is None:
            return None
        return cv2.imread(path, cv2.IMREAD_GRAYSCALE)

    def _load_scan_data(self, scan_name):
        """
        Loads the data for the selected CT scan and resets both sliders.

        Args:
            scan_name (str): Name of a subfolder of ``data_folder``.
        """
        scan_path = os.path.join(self.data_folder, scan_name)

        # Read everything into locals first so a failed read does not leave
        # the viewer with a mix of old and new scan data.
        ct_volume, _ = nrrd.read(os.path.join(scan_path, f"{scan_name}.nrrd"))
        seg_volume, _ = nrrd.read(os.path.join(scan_path, f"{scan_name}_seg.nrrd"))
        images, masks, contours = self._discover_png_files(scan_path)

        self.ct_nrrd = ct_volume
        self.seg_nrrd = seg_volume
        self.png_files = images
        self.mask_png_files = masks
        self.contour_png_files = contours

        # Update sliders without triggering a redraw per assignment; callers
        # redraw explicitly once loading is complete.
        self._loading = True
        try:
            # Reset the value before shrinking max so it is never clamped.
            self.nrrd_slice_slider.value = 0
            self.png_slice_slider.value = 0
            # Clamp to 0: a max below the slider's min (0) is rejected.
            self.nrrd_slice_slider.max = max(self.ct_nrrd.shape[2] - 1, 0)
            self.png_slice_slider.max = max(len(self.png_files) - 1, 0)
            self.png_slice_slider.disabled = not self.png_files
        finally:
            self._loading = False

    def _on_scan_change(self, change):
        """Handles the event when a new CT scan is selected."""
        self._load_scan_data(change.new)
        self._update_nrrd_plot()
        self._update_png_plot()
        self._update_contour_plot()

    def _on_nrrd_slice_change(self, change):
        """Handles the event when the NRRD slice slider is changed."""
        if self._loading:
            return
        self._update_nrrd_plot()

    def _on_png_slice_change(self, change):
        """Handles the event when the PNG slice slider is changed."""
        if self._loading:
            return
        self._update_png_plot()
        self._update_contour_plot()

    def _on_toggle_nrrd_mask(self, b):
        """Toggles the visibility of the NRRD segmentation mask."""
        self.nrrd_mask_visible = not self.nrrd_mask_visible
        self._update_nrrd_plot()

    def _on_toggle_png_mask(self, b):
        """Toggles the visibility of the PNG segmentation mask."""
        self.png_mask_visible = not self.png_mask_visible
        self._update_png_plot()

    def _render(self, output, layers):
        """
        Redraws one panel.

        Args:
            output: The ``widgets.Output`` area to draw into.
            layers: Iterable of ``(image, cmap, alpha)`` tuples drawn in order
                on a single axis; an empty iterable yields a blank panel.
        """
        with output:
            clear_output(wait=True)
            fig, ax = plt.subplots(figsize=(5, 5))
            for image, cmap, alpha in layers:
                ax.imshow(image, cmap=cmap, alpha=alpha)
            ax.axis('off')
            plt.show()
            # Release the figure explicitly; not every backend closes it on
            # show(), and a new figure is created on every redraw.
            plt.close(fig)

    def _update_nrrd_plot(self):
        """Updates the NRRD plot."""
        slice_idx = self.nrrd_slice_slider.value
        layers = []
        if self.ct_nrrd is not None:
            layers.append((self.ct_nrrd[:, :, slice_idx], 'gray', None))
            if self.nrrd_mask_visible and self.seg_nrrd is not None:
                layers.append((self.seg_nrrd[:, :, slice_idx], 'Reds', 0.5))
        self._render(self.nrrd_plot_output, layers)

    def _update_png_plot(self):
        """Updates the PNG plot."""
        slice_idx = self.png_slice_slider.value
        layers = []
        img = self._read_gray(self.png_files, slice_idx)
        if img is not None:
            layers.append((img, 'gray', None))
        if self.png_mask_visible:
            # There may be fewer masks than images; skip the overlay then.
            mask = self._read_gray(self.mask_png_files, slice_idx)
            if mask is not None:
                layers.append((mask, 'Reds', 0.5))
        self._render(self.png_plot_output, layers)

    def _update_contour_plot(self):
        """Updates the contour plot."""
        slice_idx = self.png_slice_slider.value
        layers = []
        contour = self._read_gray(self.contour_png_files, slice_idx)
        if contour is not None:
            layers.append((contour, 'gray', None))
        self._render(self.contour_plot_output, layers)

    def display(self):
        """Renders the CT viewer UI."""
        panel1 = VBox([self.nrrd_slice_slider, self.toggle_nrrd_mask_button, self.nrrd_plot_output])
        panel2 = VBox([self.png_slice_slider, self.toggle_png_mask_button, self.png_plot_output])
        panel3 = VBox([self.contour_plot_output])

        self._update_nrrd_plot()
        self._update_png_plot()
        self._update_contour_plot()

        main_ui = VBox([self.scan_dropdown, HBox([panel1, panel2, panel3])])
        display(main_ui)

In [None]:
# Build a viewer over the scan subfolders of 'data_cv2' and render its UI
# in this notebook cell.
viewer_cv2 = CTViewer('data_cv2')
viewer_cv2.display()