<a href="https://colab.research.google.com/github/emmarant/biscotto/blob/main/image_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Testing neural-network-based image segmentation models on your data
---
This notebook gives you access to several pre-installed segmentation models that you can easily run and compare on your own image dataset. Follow the instructions below to upload your data and test the models.

The notebook is licensed under CC BY-NC 4.0.
Copyright (C) 2024 Franziska Oschmann, Scientific IT Services of ETH Zurich.

Contributing authors: Franziska Oschmann, together with Andrzej Rzepiela (ScopeM ETH) and Szymon Stoma (ScopeM ETH).

## 1. Install dependencies
- This notebook runs on the **L4 GPU accelerator**. Select it under 'Runtime' > 'Change runtime type'.
- To run a code `cell` (a separate piece of code), click on it and press the `play` button at its top left.
- Installing the dependencies in the cell below takes about 1 minute.
- Ignore the warning message about restarting the runtime.



In [None]:
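%%capture
# Install MIDAP (dev branch) and Cellpose from GitHub without pulling in their dependencies
!pip install --no-deps git+https://github.com/ajrzepiela/midap.git@dev
!pip install --no-deps git+https://www.github.com/mouseland/cellpose.git
# Version specifiers are quoted so that the shell does not treat ">" as output redirection
!pip install -q --no-deps numpy==1.26.4 "scipy>=1.11.4,<1.12" "scikit-image>=0.22" \
  "opencv-python>=4.8.1" "pandas>=2.0.2" "stardist>=0.9.1" "omnipose>=1.0.6" tqdm gitpython coverage mpl_interactions \
  ipympl csbdeep fastremap edt igraph texttable mgen pbr ncolor mahotas torchvf peakdetect fill_voids roifile segment_anything
# If this import fails, exit() shuts down the kernel and Colab restarts the runtime,
# so the freshly installed packages are loaded in the new session
try:
  from gem.utils import graph_util, plot_util
except (ImportError, KeyError, ModuleNotFoundError):
  exit()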

If you want to use the notebook regularly, follow [these instructions](https://medium.com/@ismailelalaoui/how-to-install-external-libraries-permanently-on-google-colab-eaa4509fb43f) to install the dependencies permanently on your Google Drive.
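The dependency cell pins numpy to exactly 1.26.4. After the runtime has restarted, you may want to confirm which versions are actually active. A minimal sketch using the standard library's `importlib.metadata` follows; `check_pins` is a hypothetical helper (not part of MIDAP) and only handles exact `==` pins, not ranges such as the scipy constraint.

```python
from importlib.metadata import PackageNotFoundError, version


def check_pins(pins):
    """Return {package: (installed, expected)} for every exact pin that is not satisfied.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[package] = (installed, expected)
    return mismatches


# Example: the install cell pins numpy to 1.26.4
print(check_pins({"numpy": "1.26.4"}) or "all pins satisfied")
```

An empty result means every listed package is installed at exactly the pinned version.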

Next, **download the custom segmentation models**. Standard models are loaded from their respective libraries.

In [None]:
!midap_download --force

## 2. Upload images



- Run the `cell` below and select the `data/raw_im/` folder (don't miss the 'Select' button that appears below the cell). Then run the second cell to select the files you want to analyse.

- This takes about 30 seconds.

- The folder `/content/data/raw_im/` contains several example images for testing the notebook. To use your own images, click the folder icon in the panel on the left and drag and drop your image files into `/content/data/raw_im/`.

- All uploaded images should have the same size. The example images are 256x256 pixels; you can delete them if you don't need them.


In [None]:
import os
os.environ["MPLBACKEND"] = "module://ipympl.backend_nbagg"  # MPLBACKEND is the variable matplotlib reads
import matplotlib
matplotlib.use('module://ipympl.backend_nbagg')
import matplotlib.pyplot as plt

from google.colab import output
output.enable_custom_widget_manager()

%matplotlib ipympl
from midap.midap_jupyter.segmentation_jupyter import SegmentationJupyter

path = '/content/data/'
sj = SegmentationJupyter(path = path)

sj.get_input_dir()
display(sj.fc_file)

Please make sure that the folder has been selected. Then run the next cell to select the files (mark them with the mouse).

In [None]:
sj.get_input_files(sj.fc_file.selected)
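Because all images should have the same size, it can help to check the upload folder before continuing. The sketch below is a hypothetical helper (not part of MIDAP) that groups files by array shape; by default it reads images with scikit-image's `skimage.io.imread`, which the dependency cell installs. The `*.tif` pattern in the usage comment is an assumption about your file format.

```python
from collections import defaultdict
from pathlib import Path

import numpy as np


def group_by_shape(paths, reader=None):
    """Map each array shape to the file names that have it; more than one key means mixed sizes."""
    if reader is None:
        from skimage.io import imread as reader  # installed by the dependency cell
    groups = defaultdict(list)
    for p in paths:
        groups[np.shape(reader(p))].append(Path(p).name)
    return dict(groups)


# Hypothetical usage (the *.tif pattern is an assumption about your files):
# shapes = group_by_shape(sorted(Path("/content/data/raw_im").glob("*.tif")))
# if len(shapes) > 1:
#     print("Images differ in size:", shapes)
```

The `reader` argument makes it easy to swap in another image loader if your files are not readable by scikit-image.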

## 3. Choose image axes

This step defines the labels of the image axes: specify which axis of the uploaded image stack holds the individual images and which holds the channels. Based on this information, the stack is transformed into the shape `(num_images, width, height, num_channels)`. Run the two cells below and select the correct options.

In [None]:
sj.load_input_image()

In [None]:
sj.spec_img_dims()
sj.align_img_dims()
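The notebook performs the axis alignment through `sj.spec_img_dims()` and `sj.align_img_dims()`, whose internals are not shown here. As an illustration only, the reordering can be sketched with NumPy: `align_axes` is a hypothetical helper that moves the image axis to the front and the channel axis to the end, keeps the two spatial axes in their original order, and adds length-1 axes where the stack has no image or channel axis.

```python
import numpy as np


def align_axes(stack, image_axis=None, channel_axis=None):
    """Reorder a 2D-4D array to (num_images, dim1, dim2, num_channels).

    image_axis / channel_axis give the positions of those axes in `stack`
    (None if the stack has no such axis).
    """
    arr = np.asarray(stack)
    if image_axis is not None:
        image_axis %= arr.ndim  # allow negative indices such as -1
    if channel_axis is not None:
        channel_axis %= arr.ndim
    spatial = [ax for ax in range(arr.ndim) if ax not in (image_axis, channel_axis)]
    if len(spatial) != 2:
        raise ValueError(f"expected exactly 2 spatial axes, found {len(spatial)} in shape {arr.shape}")
    order = ([image_axis] if image_axis is not None else []) + spatial
    order += [channel_axis] if channel_axis is not None else []
    arr = np.transpose(arr, order)
    if image_axis is None:
        arr = arr[np.newaxis]  # single image -> num_images = 1
    if channel_axis is None:
        arr = arr[..., np.newaxis]  # grayscale -> num_channels = 1
    return arr


# A stack stored as (images, channels, y, x) becomes (images, y, x, channels)
print(align_axes(np.zeros((5, 2, 6, 4)), image_axis=0, channel_axis=1).shape)
```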

## 4. Select channel

Select the channel to use for the subsequent analysis. If the images in your set contain only one channel, keep channel '0'. Run both cells below.

In [None]:
%matplotlib ipympl
sj.select_channel()
display(sj.output_sel_ch)

In [None]:
sj.set_channel()
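The next section indexes the result as `sj.imgs_sel_ch[0,:,:,0]`, which suggests that the selected-channel stack keeps all four axes. A minimal NumPy sketch of that behaviour, assuming a `(num_images, dim1, dim2, num_channels)` stack; `select_channel` is a hypothetical helper, not the MIDAP implementation.

```python
import numpy as np


def select_channel(stack, channel=0):
    """Keep one channel of a (num_images, dim1, dim2, num_channels) stack, preserving the 4D shape."""
    stack = np.asarray(stack)
    if stack.ndim != 4:
        raise ValueError(f"expected a 4D stack, got shape {stack.shape}")
    if not 0 <= channel < stack.shape[-1]:
        raise IndexError(f"channel {channel} out of range for {stack.shape[-1]} channel(s)")
    # slicing with channel:channel + 1 keeps the trailing axis (length 1)
    return stack[..., channel:channel + 1]
```

Slicing instead of integer indexing (`stack[..., channel]`) is what keeps the channel axis, so downstream code can always assume four dimensions.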

## 5. Define ROI

Define the region of interest (ROI), which is applied to all images, by zooming into the part of the image you want to segment (use the 'zoom to rectangle' tool in the toolbar on the left-hand side). Then run all the cells below.

In [None]:
%matplotlib ipympl
sj.show_example_image(sj.imgs_sel_ch[0,:,:,0])

In [None]:
sj.get_corners_cutout()
sj.make_cutouts()

%matplotlib ipympl
sj.show_all_cutouts()
sj.output_all_cuts

In [None]:
sj.save_cutouts()
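How `sj.get_corners_cutout()` reads the zoomed view is not shown here. One way to turn a zoomed `imshow` view into a crop is to read the axis limits with matplotlib's `Axes.get_xlim()` / `Axes.get_ylim()` and convert them to pixel slices. The sketch below assumes the image was drawn with `imshow` (pixel `i` covers `[i - 0.5, i + 0.5]`, y axis inverted by default), so axis 1 of the stack maps to y and axis 2 to x; `limits_to_slices` and `crop_stack` are hypothetical helpers.

```python
import numpy as np


def limits_to_slices(xlim, ylim, width, height):
    """Convert the axis limits of an imshow() view into (row, column) slices.

    Limits are sorted first because imshow() inverts the y axis by default.
    Partially visible pixels at the edges are kept.
    """
    x0, x1 = sorted(xlim)
    y0, y1 = sorted(ylim)
    col0 = max(int(np.floor(x0 + 0.5)), 0)
    col1 = min(int(np.ceil(x1 + 0.5)), width)
    row0 = max(int(np.floor(y0 + 0.5)), 0)
    row1 = min(int(np.ceil(y1 + 0.5)), height)
    return slice(row0, row1), slice(col0, col1)


def crop_stack(stack, xlim, ylim):
    """Apply the same ROI to every image of a (num_images, rows, cols, channels) stack."""
    stack = np.asarray(stack)
    rows, cols = limits_to_slices(xlim, ylim, width=stack.shape[2], height=stack.shape[1])
    return stack[:, rows, cols, :]


# Hypothetical usage after zooming: ax = plt.gca(); roi = crop_stack(stack, ax.get_xlim(), ax.get_ylim())
```

Applying one pair of slices to the whole stack is what makes the ROI identical for all images.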

## 6. Model selection

You can choose between models trained on different species and markers and built on different neural network types. Select the models by running the cell below and clicking through the options; the selected models appear in the list below the cell.

In [None]:
sj.get_segmentation_models()
sj.display_segmentation_models()
#sj.outp_interact_table

- Run the following cell to segment the images with all selected models. The cell after it displays the segmentations for comparison.

- The run time depends on the number of selected models and images.

In [None]:
# run all models
sj.select_segmentation_models()
sj.run_all_chosen_models()

In [None]:
sj.compare_segmentations()
sj.output_seg_comp
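`sj.compare_segmentations()` shows the results visually. If you also want a number for how well two models agree on an image, one common measure is the intersection over union (IoU) of their foreground pixels. This is a hedged sketch: it assumes you have the two segmentations as NumPy arrays with 0 as background (how to extract them from `sj` is not covered here), and `foreground_iou` is a hypothetical helper. It ignores object identities, so two models that split the same foreground differently still score 1.0.

```python
import numpy as np


def foreground_iou(mask_a, mask_b):
    """IoU of the foreground (label > 0) of two segmentations of the same image."""
    a = np.asarray(mask_a) > 0
    b = np.asarray(mask_b) > 0
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return float(np.logical_and(a, b).sum() / union)
```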

Select the model weights that give the best segmentation result:

In [None]:
sj.display_buttons_weights()
display(sj.out_weights)

## 7. Save segmentations

The whole image stack is segmented with the chosen model and model weights. If you want to segment an additional file as well, upload it using the first cell below.

In [None]:
sj.load_add_files()
sj.out_add_file

In [None]:
sj.process_images()

The segmented images are saved in `/content/data/seg_im/`.
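To inspect the saved results, here is a minimal sketch that counts objects and their areas. It assumes the files in `seg_im` are instance label images (0 = background, one integer per object); if they are binary masks instead, the result is not an object count. `object_areas` and the file name in the usage comment are hypothetical.

```python
import numpy as np


def object_areas(label_img):
    """Return {label: pixel_count} for every non-zero label in an instance label image."""
    labels, counts = np.unique(np.asarray(label_img), return_counts=True)
    return {int(lab): int(cnt) for lab, cnt in zip(labels, counts) if lab != 0}


# Hypothetical usage (the file name is an assumption):
# from skimage.io import imread
# areas = object_areas(imread("/content/data/seg_im/example.tif"))
# print(len(areas), "objects; mean area:", np.mean(list(areas.values())))
```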

## 8. Manou's Sandbox

### I. interactive table for model selection

In [None]:
## Imports for the sandbox
## NOTE: check that these are not already imported; eventually move them to the top of the notebook

import pandas as pd
from google.colab import data_table
from IPython.display import display, clear_output
import ipywidgets as widgets
from ipywidgets import Layout


In [None]:
## Pull the registry from SegmentationJupyter (build sj.df_models)
#_ = sj.get_segmentation_models()

In [None]:
## let's get the list of models
sj.get_segmentation_models()
df_model_interact = sj.df_models.copy()
df_model_interact.index.name = "model_name"
df_model_interact = df_model_interact.reset_index()

## check what columns or descriptive fields already exist
df_model_interact.columns.tolist()

In [None]:
## take a look at the dataframe listing the available models with the already existing description fields
df_model_interact

In [None]:
## let's get the list of models
df_model_interact = sj.df_models.copy()
df_model_interact.index.name = "model_name"
df_model_interact = df_model_interact.reset_index()



#df_model_interact.insert(0, "model_name", df_model_interact.index.astype(str))  # does NOT rename any existing fields

## Add new (fixed) columns the dumb way: check the existing model names and infer functionality or features.
## As long as the number of models stays small, that is still OK. If not, find a cleverer way to do it.

## FUNCTION for INFERRED fields/characteristics of the models based on their names.
## Pure brute force and ignorance, but it can serve as the code skeleton.

def infer_from_name(name: str):
    nm = str(name).lower()
    return pd.Series({
        "Family": ("StarDist" if "stardist" in nm else
                   "Omnipose" if "omni" in nm else
                   "Cellpose+SAM" if "cpsam" in nm else None),
        "Target": ("nuclei" if any(k in nm for k in ["nuclei", "dsb2018", "fluo", "stardist"]) else
                   "bacteria" if "bact" in nm else
                   "worm" if "worm" in nm else None),
        "Modality": ("fluorescence" if ("fluo" in nm or "stardist" in nm) else
                     "phase/brightfield" if "phase" in nm else None),
        "Dims": ("2D" if "2d" in nm else None),
        "Functionality": "clas/den/seg",
        "Trained on": ("StarDist 2D versatile (fluo)" if "versatile_fluo" in nm else
                       "DSB2018 nuclei" if "dsb2018" in nm else
                       "omnipose corpus" if "omni" in nm else
                       "cellpose+sam generalist" if "cpsam" in nm else None),
        "Good for": ("star-convex nuclei" if "stardist" in nm else
                     "bacteria/elongated cells" if ("bact" in nm or "omni" in nm) else
                     "general cell bodies" if any(t in nm for t in ["cp", "cellpose", "cpsam"]) else None),
        "NOT good for": ("blah" if "stardist" in nm else
                         "blah" if ("bact" in nm or "omni" in nm) else
                         "blahblah" if any(t in nm for t in ["cp", "cellpose", "cpsam"]) else None),
        "Channels expected": 1,
    })

df_inferred = df_model_interact["model_name"].apply(infer_from_name)
df_model_interact = pd.concat([df_model_interact, df_inferred], axis=1)


## return interactive table with model names and pre-selected known info or specs for each model. This is a searchable table, i.e. user can filter based on existing fields.

#data_table.DataTable(df_model_interact, include_index=False, num_rows_per_page=10)
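The comments in the cell above suggest finding a cleverer approach if the number of models grows. One option is to move the keyword logic into a rule table, so that supporting a new model only means adding rows. This sketch covers the Family, Target, Modality and Dims columns of `infer_from_name` with the same keywords and the same first-match-wins order; `RULES` and `infer_fields` are hypothetical names.

```python
# Hypothetical rule table: (column, value, keywords). Rules are checked in order,
# and the first rule whose keyword occurs in the lowercased model name wins per column.
RULES = [
    ("Family", "StarDist", ("stardist",)),
    ("Family", "Omnipose", ("omni",)),
    ("Family", "Cellpose+SAM", ("cpsam",)),
    ("Target", "nuclei", ("nuclei", "dsb2018", "fluo", "stardist")),
    ("Target", "bacteria", ("bact",)),
    ("Target", "worm", ("worm",)),
    ("Modality", "fluorescence", ("fluo", "stardist")),
    ("Modality", "phase/brightfield", ("phase",)),
    ("Dims", "2D", ("2d",)),
]


def infer_fields(name, rules=RULES):
    """Return {column: value or None} for one model name."""
    nm = str(name).lower()
    fields = {}
    for column, value, keywords in rules:
        fields.setdefault(column, None)  # every column appears, even without a match
        if fields[column] is None and any(k in nm for k in keywords):
            fields[column] = value
    return fields
```

To get the same column expansion as `infer_from_name`, wrap the result in a Series, e.g. `df_model_interact["model_name"].apply(lambda n: pd.Series(infer_fields(n)))`.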


In [None]:
## Show table with model list and their attributes

display(data_table.DataTable(df_model_interact, include_index=False, num_rows_per_page=10))

# Add a selection UI that drives sj.select_segmentation_models() ---

# 1) Build a simple, fast selector with search + multi-select
all_names = df_model_interact["model_name"].astype(str).tolist()

search = widgets.Text(placeholder="filter models… (substring match)", layout=Layout(width="40%"))
sel    = widgets.SelectMultiple(options=sorted(all_names), rows=12, description="Select")
btn_all   = widgets.Button(description="Select all (filtered)")
btn_none  = widgets.Button(description="Clear")
btn_apply = widgets.Button(description="Apply to sj")
btn_run   = widgets.Button(description="Apply & run", button_style="primary")
out       = widgets.Output()

def refresh_options(_=None):
    q = search.value.lower().strip()
    opts = sorted(n for n in all_names if q in n.lower())  # an empty query matches every name
    # preserve already-selected items that still match the filter
    current = set(sel.value)
    sel.options = opts
    sel.value = tuple([o for o in opts if o in current])

search.observe(refresh_options, names="value")
refresh_options()

def on_all_clicked(_):
    sel.value = tuple(sel.options)

def on_none_clicked(_):
    sel.value = ()

btn_all.on_click(on_all_clicked)
btn_none.on_click(on_none_clicked)

# 2) Apply selection to sj.model_checkboxes (what sj.select_segmentation_models uses)
def _apply_selection(run_after=False):
    chosen = set(sel.value)

    # Create the exact structure sj expects: dict[model_id] -> Checkbox widget (True if selected)
    sj.model_checkboxes = {
        name: widgets.Checkbox(value=(name in chosen), indent=False, layout=Layout(width="1px", height="1px"))
        for name in all_names
    }

    with out:
        clear_output()
        print(f"Selected {len(chosen)} model(s):")
        for n in sorted(chosen):
            print("  •", n)

    if run_after:
        # these use sj.model_checkboxes to decide what to run;
        # running inside the Output widget makes messages printed from the callback visible
        with out:
            sj.select_segmentation_models()
            sj.run_all_chosen_models()

def on_apply_clicked(_):
    _apply_selection(run_after=False)

def on_run_clicked(_):
    _apply_selection(run_after=True)

btn_apply.on_click(on_apply_clicked)
btn_run.on_click(on_run_clicked)

# 3) Render the selector UI
widgets.VBox([
    widgets.HBox([search, btn_all, btn_none, btn_apply, btn_run]),
    sel,
    out
])


In [None]:
sj.compare_segmentations()
sj.output_seg_comp

In [None]:
sj.display_buttons_weights()
display(sj.out_weights)

In [None]:
sj.load_add_files()
sj.out_add_file

In [None]:
sj.process_images()

### **General comments**

+ Add a 'functionality' field (segmentation, classification, denoising)

+ Get inspiration for fields, terms, tags, etc. from biii.eu
