# Tutorial: Using foundation models for segmentation and automating prompts

In this tutorial, we introduce the use of foundation models for instance segmentation in a variety of microscopy images. Specifically, you will learn about and use _[μSAM](https://doi.org/10.1038/s41592-024-02580-4)_, a foundation model for segmentation specialized in microscopy images.

- TODO: Add some images


_μSAM_ is based on the [Segment Anything Model (SAM)](https://openaccess.thecvf.com/content/ICCV2023/papers/Kirillov_Segment_Anything_ICCV_2023_paper.pdf) by Meta, which introduced the first widely used foundation model for segmentation in natural images. SAM's backbone network is a "simple" Vision Transformer (ViT); the key to its success lies in 1) its training scheme and 2) the large amount of training data (1 __billion__ masks!). Specifically, SAM is trained on a _promptable_ segmentation task, which makes it general-purpose and, at the same time, gives it excellent zero-shot capabilities. By _prompting_, we mean that the model accepts user input on _what_ to segment. A prompt can come in two different flavours: _bounding boxes_ (i.e. drawing a rectangle around the object) and _point prompts_ (i.e. clicking inside the object).
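To make the two prompt types concrete, here is a minimal sketch that derives a box prompt and a point prompt from a binary object mask. The helper `prompts_from_mask` is hypothetical and is not part of SAM or `micro_sam`; the (y, x) coordinate order and the box layout are assumptions chosen for illustration only.

```python
import numpy as np


def prompts_from_mask(mask: np.ndarray):
    """Derive a box prompt and a point prompt from a binary object mask.

    Hypothetical helper for illustration only; not part of SAM or micro_sam.
    """
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        raise ValueError("The mask does not contain any foreground pixel.")

    # Box prompt: the tightest rectangle around the object,
    # stored here as (min_y, min_x, max_y, max_x).
    box = (int(ys.min()), int(xs.min()), int(ys.max()), int(xs.max()))

    # Point prompt: the foreground pixel closest to the centroid, so that
    # the simulated click always lies inside the object, stored as (y, x).
    distances = (ys - ys.mean()) ** 2 + (xs - xs.mean()) ** 2
    nearest = int(np.argmin(distances))
    point = (int(ys[nearest]), int(xs[nearest]))

    return box, point


# A toy 8x8 image with one rectangular object.
toy_mask = np.zeros((8, 8), dtype=bool)
toy_mask[2:5, 3:7] = True
box, point = prompts_from_mask(toy_mask)
```

Picking the foreground pixel nearest to the centroid, rather than the centroid itself, keeps the point inside the object even for concave shapes.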

_μSAM_ is a specialized version of SAM that was fine-tuned on a large and diverse dataset of microscopy images. This was necessary due to the large domain gap between natural images (naturally RGB, which SAM was trained on) and microscopy images (potentially multichannel, controlled acquisitions).

The fact that (μ)SAM is promptable (and the current general DL landscape :)) raises several interesting questions. Does prompting always improve performance? What is the best way to prompt the model? Can we automate prompting? We will focus on the last question in this tutorial. While foundation models like _μSAM_ are trained to be general-purpose, they may still struggle in certain domains, and they generally need to be fine-tuned to close this _domain gap_. However, fine-tuning such models tends to be computationally demanding and requires resources that are not always available.

In this tutorial, we will show you that another smaller, but specialized, neural network can be used to effectively prompt _μSAM_ in order to overcome the _domain gap_. In particular, we will use [_Spotiflow_](https://doi.org/10.1101/2024.02.01.578426), a neural network-based spot detection method for microscopy, which will generate _point prompts_ for _μSAM_. The usage of another smaller neural network for prompting requires it to be highly specialized and, thus, trained for each task, so it is likely that this automatic prompting strategy doesn't work right off the bat for many data modalities. Nevertheless, we encourage you to try the methods on your data and discuss the results with the TAs!


_Authors_: Anwai Archit, Albert Dominguez Mantes 

### Import Libraries

In [None]:
import os
from glob import glob
from tqdm.auto import tqdm
from typing import Optional, Literal, Tuple, List, Union

import numpy as np
import imageio.v3 as imageio
from skimage.measure import regionprops
from skimage.measure import label as connected_components

from torch_em.data.datasets.light_microscopy.ifnuclei import get_ifnuclei_paths

from spotiflow.model import Spotiflow

from segment_anything import SamPredictor

from micro_sam.prompt_based_segmentation import segment_from_points
from micro_sam.util import get_sam_model, precompute_image_embeddings
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation


In [2]:
# Path to the folder where all required files are stored.
ROOT = "/scratch/denbi/k8s/ADL4IA_flexprojects/flexproject1"

### Bacteria colony segmentation

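We will begin with the first task: segmenting bacterial colonies in images from the AGAR dataset. A Spotiflow model trained on AGAR data detects the colonies, and the detected spots are then used as point prompts for `micro-sam`.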

In [4]:
# Get the filepaths to the input images from AGAR data.

def get_agar_paths(
    path: Union[os.PathLike, str], resolution: Optional[Literal["higher", "lower"]] = None
) -> Tuple[List[str], List[str]]:
    """Get the filepaths to the input image and corresponding metadata file.

    Args:
        path: The folder where the input data is stored.
        resolution: The choice of resolution.

    Returns:
        List of filepaths for the input data.
        List of filepaths for the corresponding metadata.
    """

    data_dir = os.path.join(path, "AGAR_representative")

    # Collect all images for the requested resolution (or for both, if none is given) and their metadata files.
    resolution = ("*" if resolution is None else resolution) + "-resolution"
    image_paths = sorted(glob(os.path.join(data_dir, resolution, "*.jpg")))
    metadata_paths = [p.replace(".jpg", ".json") for p in image_paths]
    metadata_paths = [p for p in metadata_paths if os.path.exists(p)]

    assert image_paths and len(image_paths) == len(metadata_paths)

    return image_paths, metadata_paths
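The function above asserts that every image has a metadata file. As a minimal sketch of an alternative, the hypothetical helper `pair_images_with_metadata` below keeps only the complete pairs instead of failing. It assumes a flat folder in which `name.jpg` and `name.json` sit side by side, and it is tried on a temporary folder with empty placeholder files rather than on the AGAR data.

```python
import os
import tempfile
from glob import glob


def pair_images_with_metadata(folder):
    """Return (image, metadata) path pairs, keeping only images that have a JSON file.

    Hypothetical helper for illustration only.
    """
    pairs = []
    for image_path in sorted(glob(os.path.join(folder, "*.jpg"))):
        metadata_path = os.path.splitext(image_path)[0] + ".json"
        if os.path.exists(metadata_path):
            pairs.append((image_path, metadata_path))
    return pairs


with tempfile.TemporaryDirectory() as folder:
    # Create two empty placeholder images, only one of which has a metadata file.
    for name in ("1.jpg", "1.json", "2.jpg"):
        open(os.path.join(folder, name), "w").close()

    pairs = pair_images_with_metadata(folder)
    pair_names = [(os.path.basename(i), os.path.basename(m)) for i, m in pairs]
```

Returning pairs keeps each image aligned with its own metadata file, even when some metadata files are missing.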

In [None]:
# Get the folder to the Spotiflow model trained on AGAR data.
spotiflow_model_dir = os.path.join(ROOT, "models", "spotiflow")

# Get the filepaths to the input images (low-resolution images from AGAR)
image_paths, _ = get_agar_paths(path=os.path.join(ROOT, "data", "agar"), resolution="lower")

In [10]:
# Run the AGAR model
model = Spotiflow.from_folder(
    pretrained_path=os.path.join(spotiflow_model_dir, "spotiflow_agar", "agar_model"),
    verbose=True,
)

# Get the detected spots
spots_per_image = []
for image_path in tqdm(image_paths, desc="Running Spotiflow on each image"):
    image = imageio.imread(image_path)
    detected_spots = model.predict(image, verbose=False, min_distance=9)[0]
    spots_per_image.append(detected_spots)

    # TODO: visualize spots for 1 image only (?)

INFO:spotiflow.model.spotiflow:Loading model from folder: /scratch/denbi/k8s/ADL4IA_flexprojects/flexproject1/models/spotiflow/spotiflow_agar/agar_model
Running Spotiflow on each image: 100%|██████████| 10/10 [00:03<00:00,  2.56it/s]
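The detected spots can be tidied up before they are used as prompts. The sketch below assumes that the spots form an (N, 2) array of sub-pixel (y, x) coordinates. The helper `spots_to_point_prompts` is hypothetical and is not part of Spotiflow or `micro_sam`. It rounds the coordinates to pixel positions, clips them to the image bounds, and removes duplicates created by rounding.

```python
import numpy as np


def spots_to_point_prompts(spots, image_shape):
    """Round (y, x) spot coordinates to pixels and clip them to the image bounds.

    Hypothetical helper for illustration only.
    """
    spots = np.asarray(spots, dtype=float).reshape(-1, 2)
    height, width = image_shape[:2]

    points = np.rint(spots).astype(int)
    points[:, 0] = np.clip(points[:, 0], 0, height - 1)
    points[:, 1] = np.clip(points[:, 1], 0, width - 1)

    # Drop duplicates created by rounding, so that no pixel is prompted twice.
    return np.unique(points, axis=0)


# Toy coordinates: two nearly identical spots and one slightly outside the image.
example_spots = np.array([[10.4, 20.6], [10.3, 20.7], [99.8, -0.3]])
point_prompts = spots_to_point_prompts(example_spots, image_shape=(100, 200, 3))
```

Whether this step is needed depends on the data; it mainly guards against spots that land on the image border or on top of each other.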


Next, run `micro-sam` using the detected spots as point prompts.

In [None]:
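def run_promptable_segmentation(
    predictor: SamPredictor, image: np.ndarray, point_prompts: Union[np.ndarray, List[Tuple[float, float]]],
) -> np.ndarray:
    """Segment one object per point prompt and merge the results into one instance segmentation.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image.
        point_prompts: The point prompts, one (Y, X) coordinate per object.

    Returns:
        The instance segmentation, with one label per point prompt and 0 as background.
    """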

    # Compute the image embeddings.
    image_embeddings = precompute_image_embeddings(
        predictor=predictor,
        input_=image,
        ndim=2,  # With RGB images, we should have channels last and must set ndim to 2.
        verbose=False,
        # tile_shape=(384, 384),  # Tile shape for larger images.
        # halo=(64, 64),  # Overlap shape for larger images.
        # save_path=f"embeddings_{i}.zarr",  # Caches the image embeddings.
    )

    # Run promptable segmentation.
    masks = [
        segment_from_points(
            predictor=predictor,
            points=np.array([each_point_prompt]),  # One point per call, passed as an array of (Y, X) coordinates.
            labels=np.array([1]),  # One label per point; 1 marks a positive (foreground) point.
            image_embeddings=image_embeddings,
        ).squeeze() for each_point_prompt in point_prompts
    ]

    # Merge all masks into one segmentation.
    # 1. First, we compute the area of each mask, so that large objects can be written first and
    #    small ones last (this avoids losing small objects that overlap with larger ones).
    #    Counting the foreground pixels also handles empty masks, for which regionprops returns no region.
    mask_props = [{"mask": mask, "area": int(np.count_nonzero(mask))} for mask in masks]

    # 2. Next, we sort the masks by area, from largest to smallest.
    sorted_masks = sorted(mask_props, key=(lambda x: x["area"]), reverse=True)
    masks = [per_mask["mask"] for per_mask in sorted_masks]

    # 3. Finally, we merge all individual segmentations into one.
    segmentation = np.zeros(image.shape[:2], dtype=int)
    for j, mask in enumerate(masks, start=1):
        segmentation[mask > 0] = j

    return segmentation
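The merge order matters when masks overlap. The toy sketch below re-implements only the merging step with NumPy (the name `merge_masks` is hypothetical) and shows that painting the larger mask first keeps a small object that lies entirely inside it.

```python
import numpy as np


def merge_masks(masks, shape):
    """Merge binary masks into one label image, painting larger masks first.

    Hypothetical helper for illustration only.
    """
    ordered = sorted(masks, key=lambda m: int(np.count_nonzero(m)), reverse=True)
    segmentation = np.zeros(shape, dtype=int)
    for label_id, mask in enumerate(ordered, start=1):
        # Later (smaller) masks overwrite earlier (larger) ones where they overlap.
        segmentation[mask > 0] = label_id
    return segmentation


# A large object that fully covers a small one.
large = np.zeros((6, 6), dtype=bool)
large[1:5, 1:5] = True
small = np.zeros((6, 6), dtype=bool)
small[2:4, 2:4] = True

merged = merge_masks([small, large], shape=(6, 6))
labels_present = sorted(np.unique(merged).tolist())
```

Had the masks been painted in the opposite order, the large object would have overwritten the small one and only a single label would remain.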

In [None]:
view = False

# Get the Segment Anything Model to simulate interactive segmentation with detected spots.
predictor = get_sam_model(
    model_type="vit_b_lm",
    checkpoint_path=os.path.join(ROOT, "models", "micro-sam", "vit_b_lm_v3.pt")
)

# Run simulated interactive segmentation per image.
for i, (image_path, point_prompts) in tqdm(
    enumerate(zip(image_paths, spots_per_image)),
    desc="Running micro-sam on each image",
    total=len(image_paths),
):
    image = imageio.imread(image_path)

    segmentation = run_promptable_segmentation(predictor=predictor, image=image, point_prompts=point_prompts)

    if view:
        # Visualize the image and corresponding segmentation (and detected spots).
        import napari
        v = napari.Viewer()
        v.add_image(image)
        v.add_labels(segmentation)
        v.add_points(point_prompts)
        napari.run()

Running micro-sam on each image: 100%|██████████| 10/10 [00:10<00:00,  1.10s/it]


Compare this with the automatic instance segmentation of `micro-sam`, which does not need any user-provided prompts.

In [11]:
# Run automatic instance segmentation with default parameters.
view = False

# Get the Segment Anything model and the corresponding segmentation class.
predictor, segmenter = get_predictor_and_segmenter(
    model_type="vit_b",
    checkpoint=os.path.join(ROOT, "models", "micro-sam", "vit_b_lm_v3.pt"),
    amg=False,  # Use automatic instance segmentation (AIS) instead of automatic mask generation (AMG).
    is_tiled=False,  # Set to True if the automatic segmentation should run with tiling.
)

for i, (image_path, point_prompts) in tqdm(
    enumerate(zip(image_paths, spots_per_image)),
    desc="Running automatic segmentation with micro-sam",
    total=len(image_paths)
):
    image = imageio.imread(image_path)

    # Get automatic segmentation
    segmentation = automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=image,
        ndim=2,
        verbose=False,
        # tile_shape=(384, 384),
        # halo=(64, 64),
    )

    if view:
        # Visualize the image and corresponding segmentation.
        import napari
        v = napari.Viewer()
        v.add_image(image)
        v.add_labels(segmentation)
        napari.run()

Running automatic segmentation with micro-sam: 100%|██████████| 10/10 [00:16<00:00,  1.62s/it]
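One simple way to compare the two results without ground truth is to count the segmented objects and to check how many detected spots fall on a segmented pixel. The sketch below uses NumPy only. Both helper names are hypothetical, and the spots are again assumed to be (y, x) coordinates.

```python
import numpy as np


def count_objects(segmentation):
    """Number of distinct non-zero labels in a label image."""
    return int(np.count_nonzero(np.unique(segmentation)))


def fraction_of_spots_covered(segmentation, spots):
    """Fraction of (y, x) spots that fall on a segmented (non-zero) pixel."""
    spots = np.rint(np.asarray(spots, dtype=float).reshape(-1, 2)).astype(int)
    if len(spots) == 0:
        return 0.0
    ys = np.clip(spots[:, 0], 0, segmentation.shape[0] - 1)
    xs = np.clip(spots[:, 1], 0, segmentation.shape[1] - 1)
    return float(np.mean(segmentation[ys, xs] > 0))


# Toy label image with two objects and three spots, one of them on background.
toy_segmentation = np.zeros((5, 5), dtype=int)
toy_segmentation[0:2, 0:2] = 1
toy_segmentation[3:5, 3:5] = 2
toy_spots = [[0.6, 0.8], [4.0, 4.0], [2.0, 2.0]]

n_objects = count_objects(toy_segmentation)
coverage = fraction_of_spots_covered(toy_segmentation, toy_spots)
```

A low coverage for the automatic segmentation would suggest that it misses objects that Spotiflow detects. It says nothing about the quality of the object boundaries.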


Next Task: Use `Spotiflow`, trained on DAPI-stained images for detecting nuclei, on new fluorescence images.

TODO: explain IFNuclei data

TODO: explain the idea of the task - to check zero-shot performance on both tasks.

TODO: elaborate on the idea that the participants are expected to establish this pipeline and try everything here onwards themselves. We can provide them some hints (eg. choice of vit model, hyperparameters in spotiflow to try, etc)

In [27]:
# Run the DSB model
model = Spotiflow.from_folder(
    pretrained_path=os.path.join(spotiflow_model_dir, "spotiflow_dsb18", "dsb18_model"),
    verbose=True,
)

# Get DAPI-stained images.
image_paths, gt_paths = get_ifnuclei_paths(path=os.path.join(ROOT, "data", "if_nuclei"))
image_paths = [p for p in image_paths if os.path.basename(p).startswith("normal")]  # Consider DAPI-stained images.
gt_paths = [p for p in gt_paths if os.path.basename(p).startswith("normal")]  # Consider DAPI-stained images.

# Get the detected spots
spots_per_image = []
for image_path in tqdm(image_paths, desc="Running Spotiflow on each image"):
    image = imageio.imread(image_path)
    detected_spots = model.predict(image, verbose=False, min_distance=9, prob_thresh=0.4)[0]
    spots_per_image.append(detected_spots)

    # TODO: visualize spots for 1 image only (?)

INFO:spotiflow.model.spotiflow:Loading model from folder: /scratch/denbi/k8s/ADL4IA_flexprojects/flexproject1/models/spotiflow/spotiflow_dsb18/dsb18_model
Running Spotiflow on each image: 100%|██████████| 41/41 [00:04<00:00, 10.12it/s]


In [None]:
view = False

# Get the Segment Anything Model to simulate interactive segmentation with detected spots.
predictor = get_sam_model(
    model_type="vit_b_lm",
    checkpoint_path=os.path.join(ROOT, "models", "micro-sam", "vit_b_lm_v3.pt")
)

# Run simulated interactive segmentation per image.
for i, (image_path, point_prompts) in tqdm(
    enumerate(zip(image_paths, spots_per_image)),
    desc="Running micro-sam on each image",
    total=len(image_paths),
):
    image = imageio.imread(image_path)
    
    segmentation = run_promptable_segmentation(predictor=predictor, image=image, point_prompts=point_prompts)

    if view:
        # Visualize the image and corresponding segmentation (and detected spots).
        import napari
        v = napari.Viewer()
        v.add_image(image)
        v.add_labels(segmentation)
        v.add_points(point_prompts)
        napari.run()

Running micro-sam on each image: 100%|██████████| 41/41 [01:54<00:00,  2.79s/it]


In [None]:
# Run automatic instance segmentation with default parameters.
view = False

# Get the Segment Anything model and the corresponding segmentation class.
predictor, segmenter = get_predictor_and_segmenter(
    model_type="vit_b",
    checkpoint=os.path.join(ROOT, "models", "micro-sam", "vit_b_lm_v3.pt"),
    amg=False,  # Use automatic instance segmentation (AIS) instead of automatic mask generation (AMG).
    is_tiled=False,  # Set to True if the automatic segmentation should run with tiling.
)

for i, (image_path, point_prompts) in tqdm(
    enumerate(zip(image_paths, spots_per_image)),
    desc="Running automatic segmentation with micro-sam",
    total=len(image_paths)
):
    image = imageio.imread(image_path)

    # Get automatic segmentation
    segmentation = automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=image,
        ndim=2,
        verbose=False,
        # tile_shape=(384, 384),
        # halo=(64, 64),
    )

    if view:
        # Visualize the image and corresponding segmentation.
        import napari
        v = napari.Viewer()
        v.add_image(image)
        v.add_labels(segmentation)
        napari.run()

Running automatic segmentation with micro-sam: 100%|██████████| 41/41 [01:17<00:00,  1.90s/it]
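Since ground-truth paths are available for this dataset, the two segmentations can also be scored against them. The sketch below is one simple option: for each ground-truth object it takes the best intersection over union (IoU) with any predicted object and averages the result. It assumes that the ground truth is stored as label images with 0 as background. The function name `mean_best_iou` is hypothetical, and this score is not necessarily the metric used by the authors of either method.

```python
import numpy as np


def mean_best_iou(ground_truth, prediction):
    """Average, over ground-truth objects, of the best IoU with any predicted object.

    Hypothetical helper for illustration only.
    """
    gt_ids = [i for i in np.unique(ground_truth) if i != 0]
    if not gt_ids:
        return 0.0

    scores = []
    for gt_id in gt_ids:
        gt_mask = ground_truth == gt_id
        # Only predicted objects that overlap this ground-truth object can match it.
        candidates = [i for i in np.unique(prediction[gt_mask]) if i != 0]
        best = 0.0
        for pred_id in candidates:
            pred_mask = prediction == pred_id
            intersection = np.count_nonzero(gt_mask & pred_mask)
            union = np.count_nonzero(gt_mask | pred_mask)
            best = max(best, intersection / union)
        scores.append(best)
    return float(np.mean(scores))


# Toy example: one object matched perfectly, one object only half covered.
toy_gt = np.zeros((4, 4), dtype=int)
toy_gt[0:2, 0:2] = 1
toy_gt[2:4, 2:4] = 2
toy_pred = np.zeros((4, 4), dtype=int)
toy_pred[0:2, 0:2] = 5
toy_pred[2:4, 3:4] = 7

score = mean_best_iou(toy_gt, toy_pred)
```

This score ignores predicted objects that do not overlap any ground-truth object, so it should be read together with the object counts.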
