# In [1]:
import numpy as np
import cv2
from skimage import measure
from skimage.draw import line
import napari

class CellTracker:
    """Link labeled cells across consecutive frames of a time-lapse.

    Each processed frame is stored in ``cell_data`` as a list of dicts with the keys
    ``"centroid"`` (tuple of ints), ``"area"`` and ``"ID"``. Daughter cells found by the
    division check also carry a ``"parent_ID"`` key. IDs are strings: ``"C<n>"`` for a
    cell that starts a new track and ``"D<n>"`` for a daughter cell. Cells of a skipped
    frame keep ``"ID": None``.
    """

    def __init__(self, area_threshold_linking=100, area_threshold_division=150,
                 distance_threshold_linking=50, distance_threshold_division=60,
                 distance_threshold_skipframe=100):
        """Create an empty tracker.

        Parameters:
            - area_threshold_linking: a cell links to a previous cell only if their areas differ by less than this.
            - area_threshold_division: a cell counts as a daughter of a previous cell only if the area
              dropped by more than this.
            - distance_threshold_linking: maximum centroid distance for linking.
            - distance_threshold_division: maximum centroid distance between a daughter and its parent.
            - distance_threshold_skipframe: a frame with no cell this close to any cell of the reference
              frame is skipped (its cells get no ID).
        Distances are measured in the coordinate units of the centroids.
        """
        self.cell_data = []  # Per frame: list of cell dicts (centroid, area, ID)
        self.frame_labels = []  # Per frame: the labeled image, used by draw_tracks_napari
        self.current_id = 1  # Counter for assigning new IDs (shared by "C" and "D" IDs)
        self.skip_frame = False  # True if the most recently linked frame was skipped
        self.area_threshold_linking = area_threshold_linking
        self.area_threshold_division = area_threshold_division
        self.distance_threshold_linking = distance_threshold_linking
        self.distance_threshold_division = distance_threshold_division
        self.distance_threshold_skipframe = distance_threshold_skipframe

    def add_frame(self, labeled_frame):
        """Process one labeled frame or a whole stack of labeled frames.

        Parameters:
            - labeled_frame: a path to a tiff file of vstacked filtered labels, a single 2-D labeled
              frame, or an array whose first axis indexes timepoints.

        Raises:
            ValueError: if the input array has fewer than two dimensions.
        """
        import os

        if isinstance(labeled_frame, (str, os.PathLike)):
            import tifffile  # only needed when reading from disk

            labeled_frame = tifffile.imread(labeled_frame)

        frames = np.asarray(labeled_frame)
        if frames.ndim < 2:
            raise ValueError(
                f"Expected a labeled image or a stack of them, got an array with shape {frames.shape}"
            )
        if frames.ndim == 2:
            frames = frames[np.newaxis]  # a single frame: treat it as a stack of one

        for frame_timepoint in frames:
            self._process_frame(frame_timepoint)

    def _process_frame(self, labeled_frame):
        """Extract centroid and area of every labeled region, link the cells and store the frame."""
        labeled_frame = np.asarray(labeled_frame)
        region_props_list = [
            # int() truncates the sub-pixel centroid coordinates
            {"centroid": tuple(map(int, region.centroid)), "area": region.area, "ID": None}
            for region in measure.regionprops(labeled_frame)
        ]

        self._link_cells(region_props_list, self.area_threshold_linking, self.area_threshold_division,
                         self.distance_threshold_linking, self.distance_threshold_division,
                         self.distance_threshold_skipframe)

        self.cell_data.append(region_props_list)
        self.frame_labels.append(labeled_frame)

    def _new_id(self, prefix):
        """Return a fresh ID such as ``"C7"`` and advance the shared counter."""
        new_id = f"{prefix}{self.current_id}"
        self.current_id += 1
        return new_id

    def _reference_frame(self):
        """Return the cells that carry an ID in the most recent frame that has any, or None."""
        for frame in reversed(self.cell_data):
            linked = [cell for cell in frame if cell["ID"] is not None]
            if linked:
                return linked
        return None

    def _link_cells(self, region_props_list, area_threshold_linking, area_threshold_division,
                    distance_threshold_linking, distance_threshold_division, distance_threshold_skipframe):
        """
        Assign an ID to every cell of the current frame (mutates the dicts in region_props_list).

        The current frame is compared with the reference frame. This is the most recent stored frame
        that has cells with IDs, so skipped and empty frames are bridged.
            - No reference frame (e.g. the first frame): every cell starts a new "C" track.
            - No current cell within distance_threshold_skipframe of any reference cell: the frame is
              skipped. Its cells keep ID None and self.skip_frame is set.
            - Otherwise cells are linked nearest-first. A reference cell is claimed by at most one
              current cell. A pair links only if it is closer than distance_threshold_linking and its
              area change is below area_threshold_linking.
            - An unlinked cell close to a reference cell (< distance_threshold_division) whose area is
              smaller by more than area_threshold_division becomes a "D" daughter of the nearest such
              cell. That cell's ID is stored in "parent_ID".
            - Any remaining cell starts a new "C" track.

        Parameters:
            - region_props_list: List of dictionaries containing region properties for the current frame.
            - area_threshold_linking: Threshold for considering a change in area during cell linking.
            - area_threshold_division: Threshold for considering a drop in area during cell division.
            - distance_threshold_linking: Threshold for considering proximity between centroids during cell linking.
            - distance_threshold_division: Threshold for considering proximity between centroids during cell division.
            - distance_threshold_skipframe: Threshold for considering proximity to skip a frame.
        """
        self.skip_frame = False
        if not region_props_list:
            return

        reference = self._reference_frame()
        if reference is None:
            for cell in region_props_list:
                cell["ID"] = self._new_id("C")
            return

        cur_centroids = np.array([cell["centroid"] for cell in region_props_list], dtype=float)
        ref_centroids = np.array([cell["centroid"] for cell in reference], dtype=float)
        # dist[i, j]: distance between current cell i and reference cell j
        dist = np.linalg.norm(cur_centroids[:, None, :] - ref_centroids[None, :, :], axis=-1)

        if not np.any(dist < distance_threshold_skipframe):
            self.skip_frame = True
            return

        cur_area = np.array([cell["area"] for cell in region_props_list], dtype=float)
        ref_area = np.array([cell["area"] for cell in reference], dtype=float)
        area_drop = ref_area[None, :] - cur_area[:, None]  # > 0 when the cell got smaller

        # Greedy nearest-first linking so that each reference ID is used at most once
        linkable = (dist < distance_threshold_linking) & (np.abs(area_drop) < area_threshold_linking)
        rows, cols = np.nonzero(linkable)
        order = np.argsort(dist[rows, cols], kind="stable")
        linked_cur, linked_ref = set(), set()
        for i, j in zip(rows[order].tolist(), cols[order].tolist()):
            if i in linked_cur or j in linked_ref:
                continue
            region_props_list[i]["ID"] = reference[j]["ID"]
            linked_cur.add(i)
            linked_ref.add(j)

        for i, cell in enumerate(region_props_list):
            if i in linked_cur:
                continue
            parents = np.flatnonzero((dist[i] < distance_threshold_division)
                                     & (area_drop[i] > area_threshold_division))
            if parents.size:
                parent = reference[parents[np.argmin(dist[i, parents])]]
                cell["ID"] = self._new_id("D")
                cell["parent_ID"] = parent["ID"]
            else:
                # A cell entering the frame
                cell["ID"] = self._new_id("C")

    def draw_tracks_napari(self, original_images, output_path):
        """
        Draw tracks between cells in Napari and save the resulting image.

        For each frame, the function adds an image layer and a labels layer. It also adds one shapes
        layer of line segments. Each segment runs from the last position where a track ID was seen to
        its current centroid. A daughter's segment starts at its parent's last position. Skipped frames
        are bridged.

        Parameters:
            - original_images: Vstacked original images (TCYX), indexed by frame along the first axis.
            - output_path: Path to save the Napari viewer screenshot
        """
        viewer = napari.Viewer()
        last_seen = {}  # track ID -> centroid at its most recent appearance

        for frame_index, frame_data in enumerate(self.cell_data):
            viewer.add_image(original_images[frame_index], name=f'Original Frame {frame_index}')
            viewer.add_labels(self.frame_labels[frame_index], name=f'Labels Frame {frame_index}')

            segments = []
            for cell in frame_data:
                if cell["ID"] is None:
                    continue
                start = last_seen.get(cell["ID"])
                if start is None and cell.get("parent_ID") is not None:
                    start = last_seen.get(cell["parent_ID"])
                if start is not None:
                    segments.append(np.array([start, cell["centroid"]], dtype=float))

            if segments:
                viewer.add_shapes(segments, shape_type="line", edge_color="red", edge_width=2,
                                  name=f'Tracks Frame {frame_index}')

            # Update after drawing so segments only connect to earlier frames
            for cell in frame_data:
                if cell["ID"] is not None:
                    last_seen[cell["ID"]] = cell["centroid"]

        viewer.screenshot(path=output_path)


# USAGE:
import os

import tifffile

# Create an instance of the CellTracker class
cell_tracker = CellTracker()

# Specify the path to the vstacked tiffile of filtered labels
tiffile_from_directory_vstack = r"C:\Users\micha\Desktop\MM_data_Pipeline_results_all\MM_pipeline_results_converted_Z1\Filtered_labels\filtered_labels.tiff"

# Call the add_frame method to process frames and extract region properties
cell_tracker.add_frame(tiffile_from_directory_vstack)

# Optional: visualize tracks in Napari.
# Point original_images_path at the original images (TCYX). This step is skipped until that file exists.
original_images_path = "path/to/your/original_images.tiff"
output_path = r"C:\Users\micha\Desktop\MM_data_Pipeline_results_all\MM_pipeline_results_converted_Z1\Tracks\napari_screenshot.png"
if os.path.exists(original_images_path):
    original_images = tifffile.imread(original_images_path)
    # The Tracks folder may not exist yet; the screenshot cannot be written without it
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    cell_tracker.draw_tracks_napari(original_images, output_path)
else:
    print(f"Skipping Napari visualization: {original_images_path!r} not found")
