From b1263f9ceae4dffec299529cfd24da7d6bc51cc4 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 19 Nov 2025 10:40:20 +0100 Subject: [PATCH 1/2] Update SGN model --- development/check_uploaded_models.py | 31 +++++++++++++++++++--------- development/export_models.py | 4 ++-- flamingo_tools/model_utils.py | 4 ++-- flamingo_tools/segmentation/cli.py | 14 +++++++++++-- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/development/check_uploaded_models.py b/development/check_uploaded_models.py index eb18262..a7e79a1 100644 --- a/development/check_uploaded_models.py +++ b/development/check_uploaded_models.py @@ -18,17 +18,23 @@ } -def check_segmentation_model(model_name, checkpoint_path=None): +def check_segmentation_model(model_name, checkpoint_path=None, input_path=None, check_prediction=False, **kwargs): output_folder = f"result_{model_name}" os.makedirs(output_folder, exist_ok=True) - input_path = os.path.join(output_folder, f"{model_name}.tif") - if not os.path.exists(input_path): - data_path = _sample_registry().fetch(data_dict[model_name]) - copyfile(data_path, input_path) + if input_path is None: + input_path = os.path.join(output_folder, f"{model_name}.tif") + if not os.path.exists(input_path): + data_path = _sample_registry().fetch(data_dict[model_name]) + copyfile(data_path, input_path) output_path = os.path.join(output_folder, "segmentation.zarr") if not os.path.exists(output_path): - cmd = ["flamingo_tools.run_segmentation", "-i", input_path, "-o", output_folder, "-m", model_name] + cmd = [ + "flamingo_tools.run_segmentation", "-i", input_path, "-o", output_folder, "-m", model_name, + "--disable_masking", "--min_size", "5", + ] + for name, val in kwargs.items(): + cmd.extend([f"--{name}", str(val)]) if checkpoint_path is not None: cmd.extend(["-c", checkpoint_path]) subprocess.run(cmd) @@ -38,6 +44,9 @@ def check_segmentation_model(model_name, checkpoint_path=None): image = imageio.imread(input_path) v = napari.Viewer() v.add_image(image) + if check_prediction: + prediction = zarr.open(os.path.join(output_folder, "predictions.zarr"))["prediction"][:] + v.add_image(prediction) v.add_labels(segmentation, name=f"{model_name}-segmentation") napari.run() @@ -77,11 +86,13 @@ def main(): # - Prediction works well on the GPU. # check_segmentation_model("IHC") - # TODO: Update model. # SGN segmentation (lowres): - # - Prediction does not work well on the CPU. - # - Prediction does not work well on the GPU. - check_segmentation_model("SGN-lowres", checkpoint_path="SGN-lowres.pt") + # - Prediction works well on the CPU. + # - Prediction works well on the GPU. + check_segmentation_model( + "SGN-lowres", + # boundary_distance_threshold=0.5, center_distance_threshold=None, + ) # IHC segmentation (lowres): # - Prediction works well on the CPU. diff --git a/development/export_models.py b/development/export_models.py index 69a62ec..f4b833d 100644 --- a/development/export_models.py +++ b/development/export_models.py @@ -21,8 +21,8 @@ def export_synapses(): def export_sgn_lowres(): - path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_sgn-low-res-v4" # noqa - model = load_model(path, device="cpu") + path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/cochlea_distance_unet_sgn-low-res-v5" # noqa + model = load_model(path, device="cpu", name="latest") torch.save(model, "SGN-lowres.pt") diff --git a/flamingo_tools/model_utils.py b/flamingo_tools/model_utils.py index c8676fe..64b0f67 100644 --- a/flamingo_tools/model_utils.py +++ b/flamingo_tools/model_utils.py @@ -64,14 +64,14 @@ def get_model_registry() -> None: "SGN": "3058690b49015d6210a8e8414eb341c34189fee660b8fac438f1fdc41bdfff98", "IHC": "752dab7995b076ec4b8526b0539d1b33ade5de9251aaf6863d9bd8cc9cd036b6", "Synapses": "2a42712b056f082b4794f15cf41b15678aab0bec1acc922ff9f0dc76abe6747e", - "SGN-lowres": "6accba4b4c65158fccf25623dcd0fb3b14203305d033a0d443a307114ec5dd8c", + "SGN-lowres": "2c773792f0ef6022c7d052c452071cf7bf45dfce6b498b408ad6cd1cc3a30d35", "IHC-lowres": "537f1d4afc5a582771b87adeccadfa5635e1defd13636702363992188ef5bdbd", } urls = { "SGN": "https://owncloud.gwdg.de/index.php/s/NZ2vv7hxX1imITG/download", "IHC": "https://owncloud.gwdg.de/index.php/s/wB7d2MjV5LRTP06/download", "Synapses": "https://owncloud.gwdg.de/index.php/s/A9W5NmOeBxiyZgY/download", - "SGN-lowres": "https://owncloud.gwdg.de/index.php/s/8hwZjBVzkuYhHLm/download", + "SGN-lowres": "https://owncloud.gwdg.de/index.php/s/OS7985CKaTTBT5g/download", "IHC-lowres": "https://owncloud.gwdg.de/index.php/s/EhnV4brhpvFbSsy/download", } cache_dir = get_cache_dir() diff --git a/flamingo_tools/segmentation/cli.py b/flamingo_tools/segmentation/cli.py index 3dd445b..b971ebc 100644 --- a/flamingo_tools/segmentation/cli.py +++ b/flamingo_tools/segmentation/cli.py @@ -19,6 +19,8 @@ def _parse_kwargs(extra_kwargs, **default_kwargs): def _convert_argval(value): # The values for the parsed arguments need to be in the expected input structure as provided. # i.e. integers and floats should be in their original types. + if value is None or value == "None": + return None try: return int(value) except ValueError: @@ -47,13 +49,21 @@ def _convert_argval(value): def _parse_segmentation_kwargs(extra_kwargs, model_type): - if model_type.startswith("SGN"): + if model_type == "SGN": default_kwargs = { "center_distance_threshold": 0.4, "boundary_distance_threshold": 0.5, "fg_threshold": 0.5, "distance_smoothing": 0.0, - "seg_class": "sgn_low" if model_type == "SGN-lowres" else "sgn", + "seg_class": "sgn", + } + elif model_type == "SGN-lowres": + default_kwargs = { + "center_distance_threshold": None, + "boundary_distance_threshold": 0.5, + "fg_threshold": 0.5, + "distance_smoothing": 0.0, + "seg_class": "sgn_low", } else: assert model_type.startswith("IHC") From c97f99ef909486c0ed530999e3cbd3c7a21bdd98 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 20 Nov 2025 10:18:57 +0100 Subject: [PATCH 2/2] Update segmentation setting logic --- flamingo_tools/model_utils.py | 45 ++++++++++++++++++++ flamingo_tools/plugin/segmentation_widget.py | 27 ++++++++---- flamingo_tools/segmentation/cli.py | 29 +------------ 3 files changed, 65 insertions(+), 36 deletions(-) diff --git a/flamingo_tools/model_utils.py b/flamingo_tools/model_utils.py index 64b0f67..714f6ff 100644 --- a/flamingo_tools/model_utils.py +++ b/flamingo_tools/model_utils.py @@ -117,6 +117,51 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None return model +def get_default_segmentation_settings(model_type: str) -> Dict[str, Union[str, float]]: + """Get the default settings for instance segmentation post-processing for a given model. + + Args: + model_type: The model. One of 'SGN', 'SGN-lowres', 'IHC', 'IHC-lowres'. + + Returns: + Dictionary with the default segmentation settings. + """ + all_default_kwargs = { + "SGN": { + "center_distance_threshold": 0.4, + "boundary_distance_threshold": 0.5, + "fg_threshold": 0.5, + "distance_smoothing": 0.0, + "seg_class": "sgn", + }, + "SGN-lowres": { + "center_distance_threshold": None, + "boundary_distance_threshold": 0.5, + "fg_threshold": 0.5, + "distance_smoothing": 0.0, + "seg_class": "sgn_low", + }, + "IHC": { + "center_distance_threshold": 0.5, + "boundary_distance_threshold": 0.6, + "fg_threshold": 0.5, + "distance_smoothing": 0.6, + "seg_class": "ihc", + }, + "IHC-lowres": { + "center_distance_threshold": 0.5, + "boundary_distance_threshold": 0.6, + "fg_threshold": 0.5, + "distance_smoothing": 0.6, + "seg_class": "ihc", + }, + } + if model_type not in all_default_kwargs: + raise ValueError(f"Invalid model: {model_type}. Choose one of {list(all_default_kwargs.keys())}.") + default_kwargs = all_default_kwargs[model_type] + return default_kwargs + + def get_default_tiling() -> Dict[str, Dict[str, int]]: """Determine the tile shape and halo depending on the available VRAM. diff --git a/flamingo_tools/plugin/segmentation_widget.py b/flamingo_tools/plugin/segmentation_widget.py index 5d22fed..d629bd9 100644 --- a/flamingo_tools/plugin/segmentation_widget.py +++ b/flamingo_tools/plugin/segmentation_widget.py @@ -9,25 +9,27 @@ from .base_widget import BaseWidget from .util import _load_custom_model, _available_devices, _get_current_tiling -from ..model_utils import get_model, get_model_registry, get_device, get_default_tiling +from ..model_utils import ( + get_model, get_model_registry, get_device, get_default_tiling, get_default_segmentation_settings +) -# TODO Expose segmentation kwargs. -def _run_segmentation(image, model, model_type, tiling, device): +def _run_segmentation(image, model, model_type, tiling, device, min_size): block_shape = [tiling["tile"][ax] for ax in "zyx"] halo = [tiling["halo"][ax] for ax in "zyx"] prediction = predict_with_halo( image, model, gpu_ids=[device], block_shape=block_shape, halo=halo, tqdm_desc="Run prediction" ) + settings = get_default_segmentation_settings(model_type) + foreground_threshold = settings.pop("fg_threshold", 0.5) + settings.pop("seg_class", None) + settings = {name: 1.0 if val is None else val for name, val in settings.items()} foreground_map, center_distances, boundary_distances = prediction segmentation = watershed_from_center_and_boundary_distances( center_distances, boundary_distances, foreground_map, - center_distance_threshold=0.5, - boundary_distance_threshold=0.5, - foreground_threshold=0.5, - distance_smoothing=1.6, - min_size=100, + min_size=min_size, foreground_threshold=foreground_threshold, + **settings, ) return segmentation @@ -110,7 +112,9 @@ def on_predict(self): # Get the current tiling. self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type) - segmentation = _run_segmentation(image, model=model, model_type=model_type, tiling=self.tiling, device=device) + segmentation = _run_segmentation( + image, model=model, model_type=model_type, tiling=self.tiling, device=device, min_size=self.min_size + ) self.viewer.add_labels(segmentation, name=model_type) show_info(f"INFO: Segmentation of {model_type} added to layers.") @@ -120,6 +124,11 @@ def _create_settings_widget(self): # setting_values.setToolTip(get_tooltip("embedding", "settings")) setting_values.setLayout(QVBoxLayout()) + # Create UI for the min-size parameter. + self.min_size = 100 + self.min_size_menu, layout = self._add_int_param("min_size", self.min_size, 0, 10000) + setting_values.layout().addLayout(layout) + # Create UI for the device. device = "auto" device_options = ["auto"] + _available_devices() diff --git a/flamingo_tools/segmentation/cli.py b/flamingo_tools/segmentation/cli.py index b971ebc..953af2e 100644 --- a/flamingo_tools/segmentation/cli.py +++ b/flamingo_tools/segmentation/cli.py @@ -4,7 +4,7 @@ from .unet_prediction import run_unet_prediction from .synapse_detection import marker_detection -from ..model_utils import get_model_path +from ..model_utils import get_model_path, get_default_segmentation_settings def _get_model_path(model_type, checkpoint_path=None): @@ -49,32 +49,7 @@ def _convert_argval(value): def _parse_segmentation_kwargs(extra_kwargs, model_type): - if model_type == "SGN": - default_kwargs = { - "center_distance_threshold": 0.4, - "boundary_distance_threshold": 0.5, - "fg_threshold": 0.5, - "distance_smoothing": 0.0, - "seg_class": "sgn", - } - elif model_type == "SGN-lowres": - default_kwargs = { - "center_distance_threshold": None, - "boundary_distance_threshold": 0.5, - "fg_threshold": 0.5, - "distance_smoothing": 0.0, - "seg_class": "sgn_low", - } - else: - assert model_type.startswith("IHC") - default_kwargs = { - "center_distance_threshold": 0.5, - "boundary_distance_threshold": 0.6, - "fg_threshold": 0.5, - "distance_smoothing": 0.6, - "seg_class": "ihc", - } - + default_kwargs = get_default_segmentation_settings(model_type) kwargs = _parse_kwargs(extra_kwargs, **default_kwargs) return kwargs