In [33]:
# painting pixels

import os
import zarr
import napari
import numpy as np
import copick
from pathlib import Path
from cellcanvas_spp.segmentation import superpixels
from skimage.feature import multiscale_basic_features
from sklearn.ensemble import RandomForestClassifier
from skimage import future
from functools import partial
import threading
import toolz as tz
from psygnal import debounced
from superqt import ensure_main_thread
from qtpy.QtWidgets import (
    QVBoxLayout,
    QHBoxLayout,
    QComboBox,
    QLabel,
    QCheckBox,
    QDoubleSpinBox,
    QGroupBox,
    QWidget,
)
from appdirs import user_data_dir
import logging

# Set up logging.
# force=True makes this take effect even when logging was already configured
# earlier in the same kernel (e.g. by another cell): without it,
# logging.basicConfig silently does nothing once the root logger has handlers.
logging.basicConfig(level=logging.ERROR, force=True)
logger = logging.getLogger(__name__)

# Path to the copick project configuration file (a JSON file, not a directory).
# NOTE(review): machine-specific absolute path; adjust for your environment.
DATA_DIR = Path("/Users/kharrington/git/cellcanvas/superpixels/notebooks/my_synthetic_data_10439_dataportal.json")

# Load the tomogram
def load_tomogram(config_file=None, run_name="16193"):
    """Load the highest-resolution tomogram of a copick run into memory.

    Parameters
    ----------
    config_file : str or Path, optional
        copick project configuration file. Defaults to the module-level
        ``DATA_DIR``.
    run_name : str, optional
        Name of the run to load (default ``"16193"``).

    Returns
    -------
    numpy.ndarray
        Scale ``"0"`` of the first tomogram at the first voxel spacing of the
        run, fully read into memory.

    Raises
    ------
    ValueError
        If the run is not found or has no voxel spacing / tomogram.
    """
    if config_file is None:
        config_file = DATA_DIR
    root = copick.from_file(config_file)
    run = root.get_run(run_name)
    if run is None:
        raise ValueError(f"Run {run_name!r} not found in {config_file}")
    if not run.voxel_spacings or not run.voxel_spacings[0].tomograms:
        raise ValueError(f"Run {run_name!r} has no tomograms")
    tomogram = run.voxel_spacings[0].tomograms[0]
    z = zarr.open(tomogram.zarr())
    img = z["0"]  # Scale "0" is the highest-resolution level
    return np.asarray(img)

# Load and crop the tomogram
# load_tomogram() reads the whole volume into memory; only this crop is used below.
full_tomogram = load_tomogram()
crop_3D = full_tomogram[50:100, 180:360, 210:430]  # Adjust crop as needed

# Compute superpixels
# Superpixel label volume, presumably the same shape as crop_3D — confirm.
# NOTE(review): the meaning of sigma / h_minima is defined by
# cellcanvas_spp.superpixels, which is not visible here.
superpixel_seg = superpixels(crop_3D, sigma=4, h_minima=0.0025)

# Set up Napari viewer
viewer = napari.Viewer()
scale = (1, 1, 1)  # Adjust scale if needed
# Display range covers the full intensity range of the crop.
contrast_limits = (crop_3D.min(), crop_3D.max())

# Add layers
data_layer = viewer.add_image(crop_3D, scale=scale, contrast_limits=contrast_limits, name="Tomogram")
superpixel_layer = viewer.add_labels(superpixel_seg, scale=scale, name="Superpixels", opacity=0.5)

# Set up zarr for prediction and painting layers
# Both int32 arrays live under a per-user data directory (appdirs.user_data_dir).
# NOTE(review): with mode='a' an existing array at this path is reused, so old
# paintings persist and its stored shape may differ from crop_3D.shape —
# presumably why threaded_on_data_change trims everything to a common shape.
zarr_path = os.path.join(user_data_dir("napari_dl_at_mbl_2024", "napari"), "diy_segmentation.zarr")
prediction_data = zarr.open(f"{zarr_path}/prediction", mode='a', shape=crop_3D.shape, dtype='i4', dimension_separator="/")
painting_data = zarr.open(f"{zarr_path}/painting", mode='a', shape=crop_3D.shape, dtype='i4', dimension_separator="/")

# Label layers backed by the zarr arrays: threaded_on_data_change reads the
# paintings from painting_data and writes predictions into prediction_layer.data.
prediction_layer = viewer.add_labels(prediction_data, name="Prediction", scale=scale)
painting_layer = viewer.add_labels(painting_data, name="Painting", scale=scale)

# Feature extraction function
def extract_features(image, feature_params):
    """Compute per-voxel multiscale features for a single-channel image.

    The ``"intensity"``, ``"edges"``, ``"texture"``, ``"sigma_min"`` and
    ``"sigma_max"`` entries of ``feature_params`` are forwarded unchanged to
    ``multiscale_basic_features``; ``channel_axis=None`` marks the image as
    single-channel. Returns the feature array that call produces (callers
    treat the last axis as the feature axis).
    """
    option_names = ("intensity", "edges", "texture", "sigma_min", "sigma_max")
    options = {name: feature_params[name] for name in option_names}
    return multiscale_basic_features(image, channel_axis=None, **options)

# Model update and prediction functions
def update_model(labels, features, model_type):
    """Fit a fresh per-voxel classifier on the painted voxels.

    Parameters
    ----------
    labels : array
        Painted label volume; 0 marks unpainted voxels, which are ignored.
    features : array
        Per-voxel features with the feature dimension last. After flattening
        the other axes it must have one row per element of ``labels``.
    model_type : str
        Classifier to build. Only ``"Random Forest"`` is supported.

    Returns
    -------
    The fitted classifier, or ``None`` when nothing is painted, the model
    type is unsupported, or fitting raises. Painted labels are trained as
    ``label - 1``; ``predict`` adds the 1 back.
    """
    logger.debug(f"Labels shape: {labels.shape}, Features shape: {features.shape}")
    logger.debug(f"Unique labels: {np.unique(labels)}")

    # One row per voxel, one column per feature.
    labels_flat = labels.ravel()
    features_flat = features.reshape(-1, features.shape[-1])

    # Keep painted voxels only; background (0) carries no class information.
    mask = labels_flat > 0
    filtered_features = features_flat[mask]
    filtered_labels = labels_flat[mask] - 1

    logger.debug(f"Filtered labels shape: {filtered_labels.shape}, Filtered features shape: {filtered_features.shape}")
    logger.debug(f"Unique filtered labels: {np.unique(filtered_labels)}")

    if filtered_labels.size == 0:
        logger.warning("No non-background labels found. Skipping model update.")
        return None

    if model_type == "Random Forest":
        # Each tree bootstraps 5% of the painted voxels (max_samples=0.05).
        clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, max_depth=10, max_samples=0.05)
    else:
        # Fail explicitly instead of reaching clf.fit with ``clf`` unbound.
        logger.error(f"Unsupported model type: {model_type!r}")
        return None

    try:
        clf.fit(filtered_features, filtered_labels)
    except Exception as e:
        logger.exception(f"Error updating model: {str(e)}")
        return None
    logger.info("Model successfully updated")
    return clf

# Predict a label for every voxel (features are flattened for the classifier, then reshaped back)
def predict(model, features, model_type):
    """Predict a label for every voxel of a feature volume.

    ``features`` carries the feature dimension on its last axis. It is
    flattened to ``(n_voxels, n_features)`` for ``model.predict`` and the
    result is reshaped back to ``features.shape[:-1]``. 1 is added to undo
    the ``label - 1`` shift used for training in ``update_model``.
    ``model_type`` is accepted for interface symmetry but not used.
    """
    *spatial_shape, n_features = features.shape
    voxel_table = features.reshape(-1, n_features)
    shifted = model.predict(voxel_table).reshape(spatial_shape)
    return shifted + 1

# Napari ML Widget
class NapariMLWidget(QWidget):
    """Dock widget with the pixel-classifier settings.

    ``on_data_change`` reads these controls on every event:
    ``model_dropdown``, ``sigma_start_spinbox`` / ``sigma_end_spinbox`` (the
    ``sigma_min`` / ``sigma_max`` feature parameters), the
    ``intensity_checkbox``, ``edges_checkbox`` and ``texture_checkbox``
    feature toggles, and the ``live_fit_checkbox`` / ``live_pred_checkbox``
    switches.
    """

    # Bounds for the sigma spin boxes. multiscale_basic_features uses sigma_min
    # multiplied by powers of 2 up to sigma_max, so both sigmas must be strictly
    # positive and sigma_min must not exceed sigma_max; 0 is not usable.
    SIGMA_LOWER_BOUND = 0.1
    SIGMA_UPPER_BOUND = 10

    def __init__(self, parent=None):
        super(NapariMLWidget, self).__init__(parent)
        self.initUI()

    def initUI(self):
        """Build the controls and lay them out vertically."""
        layout = QVBoxLayout()

        model_label = QLabel("Select Model")
        self.model_dropdown = QComboBox()
        self.model_dropdown.addItems(["Random Forest"])
        model_layout = QHBoxLayout()
        model_layout.addWidget(model_label)
        model_layout.addWidget(self.model_dropdown)
        layout.addLayout(model_layout)

        self.sigma_start_spinbox = QDoubleSpinBox()
        self.sigma_start_spinbox.setRange(self.SIGMA_LOWER_BOUND, self.SIGMA_UPPER_BOUND)
        self.sigma_start_spinbox.setValue(1)

        self.sigma_end_spinbox = QDoubleSpinBox()
        self.sigma_end_spinbox.setRange(self.SIGMA_LOWER_BOUND, self.SIGMA_UPPER_BOUND)
        self.sigma_end_spinbox.setValue(5)

        # Keep sigma_min <= sigma_max: each box bounds the other, and Qt
        # clamps a box's value when its bound moves past it.
        self.sigma_end_spinbox.setMinimum(self.sigma_start_spinbox.value())
        self.sigma_start_spinbox.setMaximum(self.sigma_end_spinbox.value())
        self.sigma_start_spinbox.valueChanged.connect(self.sigma_end_spinbox.setMinimum)
        self.sigma_end_spinbox.valueChanged.connect(self.sigma_start_spinbox.setMaximum)

        sigma_layout = QHBoxLayout()
        sigma_layout.addWidget(QLabel("Sigma Range: From"))
        sigma_layout.addWidget(self.sigma_start_spinbox)
        sigma_layout.addWidget(QLabel("To"))
        sigma_layout.addWidget(self.sigma_end_spinbox)
        layout.addLayout(sigma_layout)

        self.intensity_checkbox = QCheckBox("Intensity")
        self.intensity_checkbox.setChecked(True)
        self.edges_checkbox = QCheckBox("Edges")
        self.texture_checkbox = QCheckBox("Texture")
        self.texture_checkbox.setChecked(True)

        features_group = QGroupBox("Features")
        features_layout = QVBoxLayout()
        features_layout.addWidget(self.intensity_checkbox)
        features_layout.addWidget(self.edges_checkbox)
        features_layout.addWidget(self.texture_checkbox)
        features_group.setLayout(features_layout)
        layout.addWidget(features_group)

        self.live_fit_checkbox = QCheckBox("Live Model Fitting")
        self.live_fit_checkbox.setChecked(True)
        layout.addWidget(self.live_fit_checkbox)

        self.live_pred_checkbox = QCheckBox("Live Prediction")
        self.live_pred_checkbox.setChecked(True)
        layout.addWidget(self.live_pred_checkbox)

        self.setLayout(layout)

# Add widget to Napari
# Dock the settings panel; on_data_change reads its controls on every event.
widget = NapariMLWidget()
viewer.window.add_dock_widget(widget, name="Interactive Segmentation")

# Event listener
# Current classifier, rebound by threaded_on_data_change. None until a fit has
# run, and again whenever update_model returns None (e.g. a failed fit).
model = None

@tz.curry
def on_data_change(event, viewer=None, widget=None):
    """Event handler: snapshot the widget settings, then retrain/predict.

    ``tz.curry`` lets ``on_data_change(viewer=..., widget=...)`` serve as a
    handler that only takes ``event``. The heavy work runs in
    ``threaded_on_data_change`` on a worker thread, but this function waits
    for it (``join``) before refreshing the prediction layer, so the caller
    blocks until the update is done.
    """
    painting_layer.refresh()

    dims = viewer.dims
    chosen_model = widget.model_dropdown.currentText()
    feature_settings = dict(
        sigma_min=widget.sigma_start_spinbox.value(),
        sigma_max=widget.sigma_end_spinbox.value(),
        intensity=widget.intensity_checkbox.isChecked(),
        edges=widget.edges_checkbox.isChecked(),
        texture=widget.texture_checkbox.isChecked(),
    )
    fit_enabled = widget.live_fit_checkbox.isChecked()
    predict_enabled = widget.live_pred_checkbox.isChecked()

    worker = threading.Thread(
        target=threaded_on_data_change,
        args=(event, dims, chosen_model, feature_settings, fit_enabled, predict_enabled),
    )
    worker.start()
    worker.join()  # wait for the update before refreshing the display

    prediction_layer.refresh()

def threaded_on_data_change(
    event,
    dims,
    model_type,
    feature_params,
    live_fit,
    live_prediction,
):
    """Worker: refit the per-voxel classifier on the paintings and predict.

    Uses the module-level ``crop_3D`` and ``painting_data`` trimmed to their
    common shape. It extracts features, refits the global ``model`` when
    ``live_fit`` is set and any voxel is painted, and, when
    ``live_prediction`` is set and a model exists, writes the prediction
    into ``prediction_layer.data``. Prediction errors are logged, not raised.
    ``event`` and ``dims`` are currently unused.
    """
    global model, crop_3D, painting_data

    # Work on the region shared by the image and the painting store, in case
    # their shapes differ.
    common = [min(a, b) for a, b in zip(crop_3D.shape, painting_data.shape)]
    window = (slice(None, common[0]), slice(None, common[1]), slice(None, common[2]))
    image_view = crop_3D[window]
    painted = painting_data[window]

    logger.debug(f"Active image shape: {image_view.shape}, Active labels shape: {painted.shape}")

    features = extract_features(image_view, feature_params)
    logger.debug(f"Training features shape: {features.shape}, Training labels shape: {painted.shape}")

    if np.any(painted > 0) and live_fit:
        model = update_model(painted, features, model_type)

    if not live_prediction or model is None:
        return

    try:
        voxel_labels = predict(model, features, model_type)
        prediction_layer.data[window] = voxel_labels
        logger.info("Prediction updated successfully")
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")

# Connect event listeners
# Recompute on any dims event (e.g. moving the slice slider) and on paint
# events. on_data_change is curried (tz.curry), so calling it with only
# viewer/widget yields a handler that takes the event. ensure_main_thread
# (superqt) runs the handler on the Qt main thread, and psygnal's debounced
# (timeout=1000, presumably milliseconds) rate-limits bursts of events.
for listener in [viewer.dims.events, painting_layer.events.paint]:
    listener.connect(
        debounced(
            ensure_main_thread(
                on_data_change(viewer=viewer, widget=widget)
            ),
            timeout=1000,
        )
    )

# Start the napari / Qt event loop.
napari.run()

In [3]:
# painting superpixels

import os
import zarr
import napari
import numpy as np
import pandas as pd
import copick
from pathlib import Path
from cellcanvas_spp.segmentation import superpixels
from skimage.measure import regionprops_table
from scipy import stats
from sklearn.ensemble import RandomForestClassifier
from skimage import future
from functools import partial
import threading
import toolz as tz
from psygnal import debounced
from superqt import ensure_main_thread
from qtpy.QtWidgets import (
    QVBoxLayout,
    QHBoxLayout,
    QComboBox,
    QLabel,
    QCheckBox,
    QDoubleSpinBox,
    QGroupBox,
    QWidget,
)
from appdirs import user_data_dir
import logging

# Set up logging.
# force=True makes this take effect even when logging was already configured
# earlier in the same kernel (e.g. by another cell): without it,
# logging.basicConfig silently does nothing once the root logger has handlers.
# NOTE: DEBUG on the root logger also enables debug output from third-party
# libraries (napari, gql, in_n_out, ...).
logging.basicConfig(level=logging.DEBUG, force=True)
logger = logging.getLogger(__name__)

# Path to the copick project configuration file (a JSON file, not a directory).
# NOTE(review): machine-specific absolute path; adjust for your environment.
DATA_DIR = Path("/Users/kharrington/git/cellcanvas/superpixels/notebooks/my_synthetic_data_10439_dataportal.json")

# Load the tomogram
def load_tomogram(config_file=None, run_name="16193"):
    """Load the highest-resolution tomogram of a copick run into memory.

    Parameters
    ----------
    config_file : str or Path, optional
        copick project configuration file. Defaults to the module-level
        ``DATA_DIR``.
    run_name : str, optional
        Name of the run to load (default ``"16193"``).

    Returns
    -------
    numpy.ndarray
        Scale ``"0"`` of the first tomogram at the first voxel spacing of the
        run, fully read into memory.

    Raises
    ------
    ValueError
        If the run is not found or has no voxel spacing / tomogram.
    """
    if config_file is None:
        config_file = DATA_DIR
    root = copick.from_file(config_file)
    run = root.get_run(run_name)
    if run is None:
        raise ValueError(f"Run {run_name!r} not found in {config_file}")
    if not run.voxel_spacings or not run.voxel_spacings[0].tomograms:
        raise ValueError(f"Run {run_name!r} has no tomograms")
    tomogram = run.voxel_spacings[0].tomograms[0]
    z = zarr.open(tomogram.zarr())
    img = z["0"]  # Scale "0" is the highest-resolution level
    return np.asarray(img)

# Load and crop the tomogram
# load_tomogram() reads the whole volume into memory; only this crop is used below.
full_tomogram = load_tomogram()
crop_3D = full_tomogram[50:100, 180:360, 210:430]  # Adjust crop as needed

# Compute superpixels
# Superpixel label volume; its labels are the training/prediction units below.
# NOTE(review): the meaning of sigma / h_minima is defined by
# cellcanvas_spp.superpixels, which is not visible here.
superpixel_seg = superpixels(crop_3D, sigma=4, h_minima=0.0025)

# Set up Napari viewer
viewer = napari.Viewer()
scale = (1, 1, 1)  # Adjust scale if needed
# Display range covers the full intensity range of the crop.
contrast_limits = (crop_3D.min(), crop_3D.max())

# Add layers
# NOTE(review): napari may keep a reference to superpixel_seg rather than a
# copy, so edits to the "Superpixels" layer could change it in place — confirm.
data_layer = viewer.add_image(crop_3D, scale=scale, contrast_limits=contrast_limits, name="Tomogram")
superpixel_layer = viewer.add_labels(superpixel_seg, scale=scale, name="Superpixels", opacity=0.5)

# Set up zarr for prediction and painting layers
# Same store path as the pixel-painting cell, so the two cells share the
# "painting" and "prediction" arrays.
# NOTE(review): with mode='a' an existing array at this path is reused, so old
# paintings persist and its stored shape may differ from crop_3D.shape —
# presumably why threaded_on_data_change trims everything to a common shape.
zarr_path = os.path.join(user_data_dir("napari_dl_at_mbl_2024", "napari"), "diy_segmentation.zarr")
print(f"zarr path: {zarr_path}")
prediction_data = zarr.open(f"{zarr_path}/prediction", mode='a', shape=crop_3D.shape, dtype='i4', dimension_separator="/")
painting_data = zarr.open(f"{zarr_path}/painting", mode='a', shape=crop_3D.shape, dtype='i4', dimension_separator="/")

# Label layers backed by the zarr arrays: threaded_on_data_change reads the
# paintings from painting_data and writes predictions into prediction_layer.data.
prediction_layer = viewer.add_labels(prediction_data, name="Prediction", scale=scale)
painting_layer = viewer.add_labels(painting_data, name="Painting", scale=scale)

# Precompute regionprops features for each superpixel
def compute_superpixel_features(image, superpixels):
    """Measure region properties for every superpixel.

    Thin wrapper around ``regionprops_table`` with ``superpixels`` as the
    label image and ``image`` as the intensity image. Returns its result
    unchanged: a dict mapping property column names to per-region arrays
    (see the scikit-image docs for how background and multi-valued
    properties such as ``bbox``/``centroid`` are handled).

    Note: the ``superpixels`` parameter shadows the module-level
    ``superpixels`` function inside this body.

    NOTE(review): several names below are legacy aliases in newer
    scikit-image releases (e.g. ``bbox_area`` vs ``area_bbox``); renaming
    them would also rename the returned dict keys.
    """
    wanted = (
        'label',
        'area',
        'bbox',
        'bbox_area',
        'centroid',
        'equivalent_diameter',
        'euler_number',
        'extent',
        'filled_area',
        'major_axis_length',
        'max_intensity',
        'mean_intensity',
        'min_intensity',
        'std_intensity',
    )
    return regionprops_table(superpixels, intensity_image=image, properties=wanted)

# Initial per-superpixel features for the whole crop; threaded_on_data_change
# recomputes and overwrites this global on every event.
superpixel_features = compute_superpixel_features(crop_3D, superpixel_seg)

def update_model(y, X, model_type):
    """Fit a fresh classifier on per-superpixel training data.

    Parameters
    ----------
    y : numpy.ndarray
        Class label of each painted superpixel.
    X : numpy.ndarray
        Feature matrix with one row per entry of ``y``.
    model_type : str
        Classifier to build. Only ``"Random Forest"`` is supported.

    Returns
    -------
    The fitted classifier, or ``None`` when ``y`` is empty, the model type
    is unsupported, or fitting raises.
    """
    logger.debug(f"X shape: {X.shape}")
    logger.debug(f"y shape: {y.shape}")
    logger.debug(f"Unique labels: {np.unique(y)}")

    if y.size == 0:
        logger.warning("No labeled data found. Skipping model update.")
        return None

    if model_type == "Random Forest":
        # NOTE(review): max_samples=0.05 makes each tree bootstrap only 5% of
        # the samples; with a few dozen painted superpixels that is about two
        # samples per tree. Consider a larger fraction (or None) here.
        clf = RandomForestClassifier(n_estimators=50, n_jobs=-1, max_depth=10, max_samples=0.05)
    else:
        # Fail explicitly instead of reaching clf.fit with ``clf`` unbound.
        logger.error(f"Unsupported model type: {model_type!r}")
        return None

    try:
        clf.fit(X, y)
    except Exception as e:
        logger.exception(f"Error updating model: {str(e)}")
        return None
    logger.info("Model successfully updated")
    return clf

def predict(model, superpixel_features):
    """Predict one class per superpixel.

    Builds the feature matrix with one row per superpixel and one column per
    non-``label`` key of ``superpixel_features``, in dict order (the same
    column order the training rows use). Returns ``model.predict`` on it.
    The columns are stacked in a single NumPy call instead of indexing every
    element from Python.
    """
    feature_columns = [
        np.asarray(values)
        for name, values in superpixel_features.items()
        if name != 'label'
    ]
    features = np.column_stack(feature_columns)
    return model.predict(features)

# Napari ML Widget
class NapariMLWidget(QWidget):
    """Dock widget with the superpixel-classifier settings.

    Provides ``model_dropdown`` plus the ``live_fit_checkbox`` and
    ``live_pred_checkbox`` toggles (both initially checked), which
    ``on_data_change`` reads on every event.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.initUI()

    def initUI(self):
        """Build the controls and lay them out vertically."""
        root_layout = QVBoxLayout()

        # Model selector row: caption next to the dropdown.
        self.model_dropdown = QComboBox()
        self.model_dropdown.addItems(["Random Forest"])
        selector_row = QHBoxLayout()
        selector_row.addWidget(QLabel("Select Model"))
        selector_row.addWidget(self.model_dropdown)
        root_layout.addLayout(selector_row)

        # Live toggles, both enabled by default.
        self.live_fit_checkbox = self._checked_box("Live Model Fitting")
        self.live_pred_checkbox = self._checked_box("Live Prediction")
        for toggle in (self.live_fit_checkbox, self.live_pred_checkbox):
            root_layout.addWidget(toggle)

        self.setLayout(root_layout)

    @staticmethod
    def _checked_box(text):
        """Return a QCheckBox labelled *text* that starts checked."""
        box = QCheckBox(text)
        box.setChecked(True)
        return box

# Add widget to Napari
# Dock the settings panel; on_data_change reads its controls on every event.
widget = NapariMLWidget()
viewer.window.add_dock_widget(widget, name="Interactive Segmentation")

# Event listener
# Current superpixel classifier, rebound by threaded_on_data_change. None until
# a fit has run, and again whenever update_model returns None (e.g. a failed fit).
model = None

@tz.curry
def on_data_change(event, viewer=None, widget=None):
    """Event handler for the superpixel classifier.

    ``tz.curry`` lets ``on_data_change(viewer=..., widget=...)`` serve as a
    handler that only takes ``event``. The settings are read from ``widget``
    here and handed to ``threaded_on_data_change``, which runs on a worker
    thread; this function waits for it (``join``) and then refreshes the
    prediction layer.
    """
    painting_layer.refresh()

    dims = viewer.dims
    chosen_model = widget.model_dropdown.currentText()
    fit_enabled = widget.live_fit_checkbox.isChecked()
    predict_enabled = widget.live_pred_checkbox.isChecked()

    worker = threading.Thread(
        target=threaded_on_data_change,
        args=(event, dims, chosen_model, fit_enabled, predict_enabled),
    )
    worker.start()
    worker.join()  # wait for the update before refreshing the display

    prediction_layer.refresh()

def _painted_superpixel_training_set(superpixel_features, superpixel_ids, paint, painted_mask):
    """Build (X, y) from every superpixel that contains at least one painted voxel.

    Each row of X is the superpixel's feature vector: all non-``label``
    columns of ``superpixel_features``, in dict order. Each y entry is the
    most common painted value inside that superpixel
    (``scipy.stats.mode``). Superpixels are visited in ascending id order.
    Painted voxels whose superpixel id has no entry in
    ``superpixel_features['label']`` are ignored.
    """
    feature_names = [name for name in superpixel_features if name != 'label']
    feature_matrix = np.column_stack([np.asarray(superpixel_features[name]) for name in feature_names])
    row_of_label = {int(lab): row for row, lab in enumerate(superpixel_features['label'])}

    # Gather painted voxels once instead of scanning the whole volume for
    # every superpixel.
    painted_ids = superpixel_ids[painted_mask]
    painted_values = paint[painted_mask]

    X, y = [], []
    for sp_id in np.unique(painted_ids):
        row = row_of_label.get(int(sp_id))
        if row is None:
            continue
        X.append(feature_matrix[row])
        y.append(stats.mode(painted_values[painted_ids == sp_id], axis=None)[0])
    return np.array(X), np.array(y)


def _labels_to_volume(superpixel_ids, labels, values):
    """Paint values[i] into every voxel whose superpixel id equals labels[i].

    Voxels whose id is not in ``labels`` stay 0. The result has the shape
    and dtype of ``superpixel_ids``; values are cast on assignment.
    """
    volume = np.zeros_like(superpixel_ids)
    labels = np.asarray(labels)
    if labels.size == 0:
        return volume
    order = np.argsort(labels)
    sorted_labels = labels[order]
    sorted_values = np.asarray(values)[order]
    # Locate each voxel's id in sorted_labels in one pass (instead of one
    # full-volume comparison per superpixel).
    pos = np.clip(np.searchsorted(sorted_labels, superpixel_ids), 0, sorted_labels.size - 1)
    hit = sorted_labels[pos] == superpixel_ids
    volume[hit] = sorted_values[pos[hit]]
    return volume


def threaded_on_data_change(
    event,
    dims,
    model_type,
    live_fit,
    live_prediction,
):
    """Worker: refit the superpixel classifier on the paintings and predict.

    Trims the module-level ``crop_3D``, ``painting_data`` and
    ``superpixel_seg`` to their common shape, and recomputes the global
    ``superpixel_features`` for that region.

    With ``live_fit``, every superpixel that contains a painted voxel becomes
    a training sample, labelled with its most common painted value, and the
    global ``model`` is refit.

    With ``live_prediction`` and a model available, each superpixel's
    predicted class is painted into ``prediction_layer.data``. Prediction
    errors are logged, not raised.

    ``event`` and ``dims`` are currently unused.
    """
    global model, crop_3D, painting_data, superpixel_seg, superpixel_features

    # Ensure consistent shapes
    min_shape = [min(s1, s2, s3) for s1, s2, s3 in zip(crop_3D.shape, painting_data.shape, superpixel_seg.shape)]
    logger.debug(f"min_shape: {min_shape}")

    active_labels = painting_data[:min_shape[0], :min_shape[1], :min_shape[2]]
    crop_3D_subset = crop_3D[:min_shape[0], :min_shape[1], :min_shape[2]]
    superpixel_seg_subset = superpixel_seg[:min_shape[0], :min_shape[1], :min_shape[2]]

    # Recompute superpixel features on every call (stored in the global).
    superpixel_features = compute_superpixel_features(crop_3D_subset, superpixel_seg_subset)

    # Painted voxels are those with a non-zero label.
    painted_mask = active_labels > 0

    if live_fit:
        X, y = _painted_superpixel_training_set(
            superpixel_features, superpixel_seg_subset, active_labels, painted_mask
        )

        logger.debug(f"Number of painted superpixels: {len(X)}")
        logger.debug(f"X shape: {X.shape}, y shape: {y.shape}")

        if len(X) > 0:
            model = update_model(y, X, model_type)
        else:
            logger.warning("No painted superpixels found. Skipping model update.")

    if live_prediction and model is not None:
        try:
            superpixel_predictions = predict(model, superpixel_features)
            prediction = _labels_to_volume(
                superpixel_seg_subset, superpixel_features['label'], superpixel_predictions
            )
            prediction_layer.data[:min_shape[0], :min_shape[1], :min_shape[2]] = prediction
            logger.debug("Prediction updated successfully")
        except Exception as e:
            logger.exception(f"Error during prediction: {str(e)}")

# Connect event listeners
# Recompute on any dims event (e.g. moving the slice slider) and on paint
# events. on_data_change is curried (tz.curry), so calling it with only
# viewer/widget yields a handler that takes the event. ensure_main_thread
# (superqt) runs the handler on the Qt main thread, and psygnal's debounced
# (timeout=1000, presumably milliseconds) rate-limits bursts of events.
for listener in [viewer.dims.events, painting_layer.events.paint]:
    listener.connect(
        debounced(
            ensure_main_thread(
                on_data_change(viewer=viewer, widget=widget)
            ),
            timeout=1000,
        )
    )

# Start the napari / Qt event loop.
napari.run()

DEBUG:gql.dsl:Creating <DSLType <GraphQLObjectType 'datasets'>>)
DEBUG:gql.dsl:Creating <DSLField datasets::id>
DEBUG:gql.dsl:Creating <DSLField datasets::cell_component_id>
DEBUG:gql.dsl:Creating <DSLField datasets::cell_component_name>
DEBUG:gql.dsl:Creating <DSLField datasets::cell_name>
DEBUG:gql.dsl:Creating <DSLField datasets::cell_strain_id>
DEBUG:gql.dsl:Creating <DSLField datasets::cell_strain_name>
DEBUG:gql.dsl:Creating <DSLField datasets::cell_type_id>
DEBUG:gql.dsl:Creating <DSLField datasets::dataset_citations>
DEBUG:gql.dsl:Creating <DSLField datasets::dataset_publications>
DEBUG:gql.dsl:Creating <DSLField datasets::deposition_date>
DEBUG:gql.dsl:Creating <DSLField datasets::description>
DEBUG:gql.dsl:Creating <DSLField datasets::grid_preparation>
DEBUG:gql.dsl:Creating <DSLField datasets::https_prefix>
DEBUG:gql.dsl:Creating <DSLField datasets::key_photo_thumbnail_url>
DEBUG:gql.dsl:Creating <DSLField datasets::key_photo_url>
DEBUG:gql.dsl:Creating <DSLField datasets::l

zarr path: /Users/kharrington/Library/Application Support/napari_dl_at_mbl_2024/diy_segmentation.zarr


DEBUG:napari.components._layer_slicer:_LayerSlicer.submit: layers=[<Labels layer 'Painting' at 0x1d0ee0190>], dims=ndim=3 ndisplay=2 order=(0, 1, 2) axis_labels=('0', '1', '2') range=(RangeTuple(start=0.0, stop=49.0, step=1.0), RangeTuple(start=0.0, stop=179.0, step=1.0), RangeTuple(start=0.0, stop=219.0, step=1.0)) margin_left=(0.0, 0.0, 0.0) margin_right=(0.0, 0.0, 0.0) point=(24.0, 89.0, 109.0) last_used=0, force=False
DEBUG:napari.components._layer_slicer:Sync slicing for Painting
DEBUG:napari.layers.base.base:Layer._slice_dims: Painting, dims=ndim=3 ndisplay=2 order=(0, 1, 2) axis_labels=('0', '1', '2') range=(RangeTuple(start=0.0, stop=49.0, step=1.0), RangeTuple(start=0.0, stop=179.0, step=1.0), RangeTuple(start=0.0, stop=219.0, step=1.0)) margin_left=(0.0, 0.0, 0.0) margin_right=(0.0, 0.0, 0.0) point=(24.0, 89.0, 109.0) last_used=0, force=False
DEBUG:napari.layers.base.base:Layer._refresh_sync: Painting


DEBUG:napari.layers.base.base:Layer.refresh: Superpixels
DEBUG:napari.layers.base.base:Layer._refresh_sync: Superpixels
DEBUG:in_n_out:Executing @injected activate_labels_paint_mode(layer: napari.layers.labels.labels.Labels) with args: (<Labels layer 'Painting' at 0x1d0ee0190>,), kwargs: {}
DEBUG:in_n_out:  Calling activate_labels_paint_mode with {'layer': <Labels layer 'Painting' at 0x1d0ee0190>} (injected set())
DEBUG:napari.layers.base.base:Layer.refresh: Superpixels
DEBUG:napari.layers.base.base:Layer._refresh_sync: Superpixels
DEBUG:napari.layers.base.base:Layer.refresh: Superpixels
DEBUG:napari.layers.base.base:Layer._refresh_sync: Superpixels
DEBUG:napari.layers.base.base:Layer.refresh: Painting
DEBUG:napari.layers.base.base:Layer._refresh_sync: Painting
DEBUG:__main__:min_shape: [50, 180, 220]
DEBUG:__main__:Number of painted superpixels: 47
DEBUG:__main__:X shape: (47, 20), y shape: (47,)
DEBUG:__main__:X shape: (47, 20)
DEBUG:__main__:y shape: (47,)
DEBUG:__main__:Unique labe