**Table of contents**<a id='toc0_'></a>    
- [Import packages](#toc1_1_)    
- [Set up working directory](#toc1_2_)    
- [Define functions for image processing](#toc1_3_)    
- [Process files](#toc1_4_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

## <a id='toc1_1_'></a>[Import packages](#toc0_)

In [None]:
import numpy as np
import czifile
from aicspylibczi import CziFile
from scipy import ndimage
import plotly.graph_objects as go
from pathlib import Path
import glob
import plotly.io as pio
import napari
from napari.utils import nbscreenshot
from napari_animation import Animation
import os

## <a id='toc1_2_'></a>[Set up working directory](#toc0_)

In [None]:
"""
Define the path to the working directory and the output folder where the 3D volumes will be saved.
Define the name of the files to be processed.
"""
data_path = 'path_to_working_directory'
output_path = 'path_to_output_directory'
file_name = glob.glob(f'{data_path}*.*', recursive = True) # List of file names to process

In [None]:
"""
Ensure the output directory exists and that files are read correctly.
"""
total_files = len(file_name)
print(f"Wolking forlder contains {total_files} files")

## <a id='toc1_3_'></a>[Define functions for image processing](#toc0_)

In [None]:
class ZebrafishVolumeAnalyzer:
    """
    Load confocal image stacks (.czi files), pre-process them into 3D volumes,
    visualise the volumes interactively in Napari and save rendered views as .png images.

    Attributes
    ----------
    file_paths : list of pathlib.Path
        The .czi files to process.
    volumes : dict
        Loaded volumes keyed by file stem (file name without extension).
        Files that share a stem overwrite each other.
    """

    # Accepted values for the ``stack_type`` argument of load_czi_stack / load_all_volumes.
    STACK_TYPES = ("cranial", "trunk")

    def __init__(self, file_paths):
        """
        Parameters
        ----------
        file_paths : iterable of str or pathlib.Path
            Paths of the .czi files to process. No file is opened here.
        """
        self.file_paths = [Path(p) for p in file_paths]
        self.volumes = {}

    # -------------------------
    # Load CZI z-stack
    # -------------------------
    def load_czi_stack(self, file_path, stack_type="cranial", tile_number=3):
        """
        Load a .czi z-stack with the loader that matches its acquisition layout.

        Parameters
        ----------
        file_path : str or pathlib.Path
            Path to the .czi file.
        stack_type : {"cranial", "trunk"}
            "cranial" uses load_czi_stack_cranial (single field of view);
            "trunk" uses load_czi_stack_trunk (horizontally tiled acquisition).
        tile_number : int
            Number of tiles. Only used by the "trunk" loader.

        Returns
        -------
        numpy.ndarray
            3D image array (Z × Y × X).

        Raises
        ------
        ValueError
            If stack_type is unknown or the loader cannot reduce the image to 3D.
        """
        if stack_type == "cranial":
            return self.load_czi_stack_cranial(file_path)
        if stack_type == "trunk":
            return self.load_czi_stack_trunk(file_path, tile_number=tile_number)
        raise ValueError(
            f"Unknown stack_type {stack_type!r}; expected one of {self.STACK_TYPES}"
        )

    def load_czi_stack_cranial(self, file_path):
        """
        Load a single-field confocal z-stack (.czi) with aicspylibczi and reduce it to Z × Y × X.

        All size-1 dimensions are squeezed out. If four dimensions remain, one of the two
        leading axes is treated as the channel axis and only the first channel is kept.

        Raises
        ------
        ValueError
            If the image is not 3D after squeezing and channel selection.
        """
        czi = CziFile(file_path)
        img_data = czi.read_image()
        # read_image() may return a tuple whose first element is the pixel array
        img_array = img_data[0] if isinstance(img_data, tuple) else img_data
        img_array = np.squeeze(img_array)  # Drop all size-1 dimensions
        if img_array.ndim == 4:
            # Heuristic: axis 0 is taken as the channel axis when it is shorter than axis 1
            # (C × Z × Y × X); otherwise axis 1 is (Z × C × Y × X). Keep the first channel.
            img_array = img_array[0] if img_array.shape[0] < img_array.shape[1] else img_array[:, 0, :, :]
        if img_array.ndim != 3:  # Anything other than Z × Y × X cannot be rendered as a volume
            raise ValueError(f"Unexpected shape: {img_array.shape}")
        return img_array

    def load_czi_stack_trunk(self, file_path, tile_number=3):
        """
        Load a tiled confocal z-stack (.czi) with czifile and reduce it to Z × Y × X.

        All size-1 dimensions are squeezed out. A 3D result is returned as-is. For a 4D
        result, the leading axis is treated as the tile axis, and the ``tile_number`` tiles
        are placed side by side along X.

        Raises
        ------
        ValueError
            If a 4D image does not have ``tile_number`` tiles on its leading axis, or the
            result is not 3D.
        """
        with czifile.CziFile(file_path) as czi:
            data = czi.asarray()
        img_array = data[0] if isinstance(data, tuple) else data
        img_array = np.squeeze(img_array)  # Drop all size-1 dimensions
        if img_array.ndim == 4:
            # NOTE(review): assumes a 4D squeezed array is tile × Z × Y × X; confirm with czi.axes.
            img_array = self._stitch_tiles(img_array, tile_number)
        if img_array.ndim != 3:  # Anything other than Z × Y × X cannot be rendered as a volume
            raise ValueError(f"Unexpected shape: {img_array.shape}")
        return img_array

    @staticmethod
    def _stitch_tiles(tiles, tile_number=3):
        """Concatenate a (tile × Z × Y × X) array left to right along X into Z × Y × (tile·X)."""
        if tiles.shape[0] != tile_number:
            raise ValueError(
                f"Expected {tile_number} tiles on the first axis, got shape {tiles.shape}"
            )
        return np.concatenate(list(tiles), axis=-1)

    # -------------------------
    # Pre-process volume
    # -------------------------
    def preprocess_volume(self, volume, threshold_percentile=50, smooth_sigma=1):
        """
        Smooth, threshold and normalise a 3D volume.

        Parameters
        ----------
        volume : array-like
            Input volume. It is converted to float64 before processing.
        threshold_percentile : float
            Voxels at or below this intensity percentile are set to 0 (background removal).
        smooth_sigma : float
            Gaussian sigma in voxels. Values <= 0 disable smoothing.

        Returns
        -------
        numpy.ndarray
            Float volume rescaled to [0, 1]. It is left unscaled when it has no positive
            voxel or no intensity range.
        """
        # Work in float: gaussian_filter keeps the input dtype, so smoothing integer
        # microscopy data would round away faint signal.
        volume = np.asarray(volume, dtype=np.float64)
        if smooth_sigma > 0:
            volume = ndimage.gaussian_filter(volume, sigma=smooth_sigma)
        threshold = np.percentile(volume, threshold_percentile)
        volume = np.where(volume > threshold, volume, 0)  # Remove low-intensity background
        vmin, vmax = volume.min(), volume.max()
        if vmax > 0 and vmax > vmin:  # Guard against division by a zero intensity range
            volume = (volume - vmin) / (vmax - vmin)
        return volume

    # -------------------------
    # Load all volumes
    # -------------------------
    def load_all_volumes(self, preprocess=True, *, stack_type="cranial", tile_number=3, **kwargs):
        """
        Load (and optionally pre-process) every file in ``self.file_paths`` into ``self.volumes``.

        Parameters
        ----------
        preprocess : bool
            Whether to run preprocess_volume on each loaded stack.
        stack_type : {"cranial", "trunk"}
            Loader used for every file (see load_czi_stack).
        tile_number : int
            Tile count passed to the "trunk" loader.
        **kwargs
            Forwarded to preprocess_volume (e.g. threshold_percentile, smooth_sigma).
        """
        for file_path in self.file_paths:
            vol = self.load_czi_stack(file_path, stack_type=stack_type, tile_number=tile_number)
            if preprocess:
                vol = self.preprocess_volume(vol, **kwargs)
            self.volumes[file_path.stem] = vol  # Keyed by file name without extension
        print(f"Loaded {len(self.volumes)} volumes.")

    # -------------------------
    # Interactive Napari viewer for a single volume
    # -------------------------
    def napari_view_volume(self, volume_name, colormap="green", rendering="attenuated_mip",
                           contrast_limits=(0, 1), projection_mode='max', attenuation=0.5):
        """
        Open a 3D Napari viewer showing one loaded volume.

        Colourmap, rendering mode and contrast limits can be changed afterwards in the
        interactive interface.

        Parameters
        ----------
        volume_name : str
            Key in ``self.volumes``.
        colormap, rendering, contrast_limits, projection_mode, attenuation
            Passed unchanged to ``viewer.add_image``. Rendering options include
            'attenuated_mip', 'mip', 'additive' and 'translucent'.

        Returns
        -------
        tuple
            (viewer, layer) for the created Napari viewer and image layer.

        Raises
        ------
        ValueError
            If ``volume_name`` has not been loaded.
        """
        if volume_name not in self.volumes:
            raise ValueError(f"Volume {volume_name} not loaded.")
        viewer = napari.Viewer(ndisplay=3)  # 3D display mode
        viewer.theme = 'light'  # Optional: light theme
        vol = self.volumes[volume_name]
        layer = viewer.add_image(
            vol,
            name=volume_name,
            contrast_limits=contrast_limits,
            projection_mode=projection_mode,
            rendering=rendering,
            colormap=colormap,
            attenuation=attenuation,  # Only used by the 'attenuated_mip' rendering mode
        )
        return viewer, layer

    # -------------------------
    # Save rendered volume as an image (.png)
    # -------------------------
    def save_volume_png(self, viewer, volume_name, folder_path="volumes_png"):
        """
        Save a screenshot of the viewer as ``<folder_path>/<volume_name>.png``.

        The folder is created if it does not exist. Returns the path of the saved file.
        """
        os.makedirs(folder_path, exist_ok=True)
        png_path = os.path.join(folder_path, f"{volume_name}.png")
        viewer.screenshot(path=png_path)
        print(f"Saved PNG for {volume_name}: {png_path}")
        return png_path

    # -------------------------
    # View and save all volumes
    # -------------------------
    def napari_view_and_save_all(self, folder_path="volumes_png"):
        """
        Open each loaded volume in Napari with the default rendering, then save it as a .png.
        """
        for name in self.volumes:
            print(f"Visualizing {name}")
            viewer, layer = self.napari_view_volume(name)
            # NOTE(review): napari.run() is blocking in plain scripts (the window may be closed
            # before the screenshot) and may return immediately under IPython with an event
            # loop; confirm that the screenshot captures the user's adjustments.
            napari.run()
            self.save_volume_png(viewer, name, folder_path)

## <a id='toc1_4_'></a>[Process files](#toc0_)

In [None]:
# Pre-processing settings applied to every volume; tune these to the acquisition
preprocessing_settings = {
    "threshold_percentile": 50,  # Percentile below which voxels are treated as background
    "smooth_sigma": 1.0,         # Gaussian smoothing strength, in voxels
}

# Build the analyser from the collected .czi paths, then load and pre-process every stack
analyzer = ZebrafishVolumeAnalyzer(file_name)
analyzer.load_all_volumes(preprocess=True, **preprocessing_settings)

In [None]:
# Rendering options shared by every volume.
# NOTE(review): confirm that 'black' is a colormap name napari accepts in this environment.
render_options = dict(colormap='black', rendering='attenuated_mip')

# Show each volume in Napari so the view can be adjusted, then save it as a .png in output_path
for volume_key in analyzer.volumes:
    viewer, layer = analyzer.napari_view_volume(volume_key, **render_options)
    print("Adjust the projection, colormap, attenuation, etc. in Napari.")
    # NOTE(review): napari.run() blocks in scripts and may return immediately in IPython;
    # confirm that the screenshot below reflects the adjusted view.
    napari.run()
    analyzer.save_volume_png(viewer, volume_key, folder_path=output_path)