# Testing Neural Network-based image segmentation models on your data
---
This notebook gives you access to several pre-installed segmentation models that can easily be run and compared on your image dataset. Please follow the instructions below to upload your data and test the models.

The notebook is licensed under CC BY-NC 4.0
Copyright (C) 2024 Franziska Oschmann, Scientific IT Services of ETH Zurich.

Contributing Authors: Franziska Oschmann together with Andrzej Rzepiela (ScopeM ETH) and Szymon Stoma (ScopeM ETH).

## 1. Install dependencies
- The notebook runs on the **L4 accelerator**. Select it in 'Runtime' by clicking 'Change runtime type'.
- To run a code `cell` (separate piece of code), click on it and press the `play` button on the top left of it.
- It will take about 1 min to install dependencies in the cell below
- Ignore the restart warning message
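
Before installing, it can be useful to confirm that the runtime actually has a GPU attached. The helper below is a minimal sketch, not part of the notebook's own code: `gpu_available` is a hypothetical name, and it assumes that the NVIDIA command-line tool `nvidia-smi` is on the path whenever a GPU runtime is active.

```python
import shutil
import subprocess

def gpu_available() -> bool:
    """Return True if `nvidia-smi` is present and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False  # no NVIDIA driver tools on this runtime
    try:
        # `nvidia-smi -L` prints one line per GPU, e.g. "GPU 0: ..."
        result = subprocess.run([exe, "-L"], capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.SubprocessError):
        return False
    return result.returncode == 0 and "GPU" in result.stdout

print("GPU detected:", gpu_available())
```

If this prints `False`, re-check the runtime type before continuing.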



In [None]:
%%capture
!pip install --no-deps git+https://github.com/ajrzepiela/midap.git@dev
!pip install --no-deps git+https://www.github.com/mouseland/cellpose.git
!pip install -q --no-deps numpy==1.26.4 "scipy>=1.11.4,<1.12" "scikit-image>=0.22" \
  "opencv-python>=4.8.1" "pandas>=2.0.2" "stardist>=0.9.1" "omnipose>=1.0.6" tqdm gitpython coverage mpl_interactions \
  ipympl csbdeep fastremap edt igraph texttable mgen pbr ncolor mahotas torchvf peakdetect fill_voids roifile segment_anything
try:
  from gem.utils import graph_util, plot_util
except (ImportError, KeyError, ModuleNotFoundError):
  exit()


If you want to use the notebook more often, follow the instructions [here](https://medium.com/@ismailelalaoui/how-to-install-external-libraries-permanently-on-google-colab-eaa4509fb43f) to install the dependencies permanently on your Google Drive.

Now, **download the custom segmentation models**. Standard models are loaded from libraries.

In [None]:
!midap_download --force

## 2. Upload of images



- Run the `cell` below, then select the `data/raw_im/` data folder (don't miss the 'select' button that appears below the cell). Afterwards, run the second cell to select the files you want to analyse.

- This will take about 30 sec

- In `/content/data/raw_im/` there are several example images that can be used to test the notebook. If you want other images, upload them to the folder `/content/data/raw_im/`. To do this, in the panel on the left click on the folder icon and then drag and drop your image files into directory `/content/data/raw_im/`.

- All uploaded images should have the same size. The example images are 256x256; you can remove them if you do not need them.
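
The same-size requirement above can be checked before running the notebook. The following is a minimal sketch under stated assumptions: `find_size_mismatches` is a hypothetical helper (not part of midap), the images are already loaded as arrays, and the first two array axes are the spatial ones.

```python
from collections import Counter

import numpy as np

def find_size_mismatches(images):
    """Return (most common 2D size, names of images with a different size).

    `images` maps a file name to an image array whose first two axes
    are assumed to be the spatial dimensions.
    """
    sizes = {name: tuple(np.asarray(img).shape[:2]) for name, img in images.items()}
    if not sizes:
        return None, []
    reference = Counter(sizes.values()).most_common(1)[0][0]
    odd = sorted(name for name, size in sizes.items() if size != reference)
    return reference, odd

# Synthetic stand-ins for loaded images
demo = {
    "a.tif": np.zeros((256, 256)),
    "b.tif": np.zeros((256, 256)),
    "c.tif": np.zeros((512, 256)),
}
reference, odd = find_size_mismatches(demo)
print(reference, odd)  # (256, 256) ['c.tif']
```

Any file reported in the second return value would need to be resized, cropped or removed before it is uploaded.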


In [None]:
import os; os.environ["MPLBACKEND"] = "module://ipympl.backend_nbagg"  # MPLBACKEND is the variable matplotlib reads
import matplotlib
matplotlib.use('module://ipympl.backend_nbagg')
import matplotlib.pyplot as plt
import pandas as pd
import inspect
from google.colab import data_table
from IPython.display import display, clear_output
import ipywidgets as widgets
from ipywidgets import Layout

from google.colab import output
output.enable_custom_widget_manager()

%matplotlib ipympl
from midap.midap_jupyter.segmentation_jupyter import SegmentationJupyter

path = '/content/data/'
sj = SegmentationJupyter(path = path)

sj.get_input_dir()
display(sj.fc_file)

Please make sure that the folder has been selected. Then run the next cell to select the files (mark them with the mouse).

In [None]:
sj.get_input_files(sj.fc_file.selected)

## 3. Choose image axes

This is where we define the labels for the image axes. We need to specify which axes contain the number of images and the number of channels in the uploaded image stack. Based on this information, the image stack will be transformed into the following shape (num_images, width, height, num_channels). Please run the two cells below and select the correct options.
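
To illustrate the target shape, here is a minimal NumPy sketch of such a reordering. It is not the implementation behind `sj.align_img_dims()`; `to_nhwc` and its parameters are hypothetical names, and the sketch assumes a 3D stack (no channel axis) or a 4D stack (with a channel axis).

```python
import numpy as np

def to_nhwc(stack, image_axis, channel_axis=None):
    """Reorder a stack to (num_images, dim1, dim2, num_channels).

    image_axis:   axis that indexes the images
    channel_axis: axis that indexes the channels, or None if there is none
    """
    stack = np.asarray(stack)
    image_axis = image_axis % stack.ndim  # normalise negative indices
    if channel_axis is None:
        # Single-channel data: append a channel axis of length 1
        stack = stack[..., np.newaxis]
        channel_axis = stack.ndim - 1
    # Move images to the front and channels to the back; the two
    # remaining spatial axes keep their original order.
    return np.moveaxis(stack, (image_axis, channel_axis), (0, -1))

stack = np.zeros((256, 256, 10))            # 10 images stored along the last axis
print(to_nhwc(stack, image_axis=2).shape)   # (10, 256, 256, 1)
```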

In [None]:
sj.load_input_image()

In [None]:
sj.spec_img_dims()
sj.align_img_dims()

## 4. Select channel

Select the channel which will be used for the further analysis. If images in your set contain only one channel, keep the channel '0'. Please run both cells below.
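
In array terms, choosing a channel amounts to slicing the last axis. The sketch below is an illustration only: `select_channel` here is a hypothetical stand-alone function, not the `sj.select_channel()` widget, and it assumes a stack of shape (num_images, width, height, num_channels) whose last axis is kept with length 1.

```python
import numpy as np

def select_channel(stack, channel):
    """Keep a single channel of a 4D stack while preserving the last axis."""
    stack = np.asarray(stack)
    n_channels = stack.shape[-1]
    if not 0 <= channel < n_channels:
        raise IndexError(f"channel {channel} out of range for {n_channels} channel(s)")
    # Slicing with a range (instead of an integer) keeps the axis
    return stack[..., channel:channel + 1]

stack = np.zeros((10, 256, 256, 3))
print(select_channel(stack, 0).shape)  # (10, 256, 256, 1)
```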

In [None]:
%matplotlib ipympl
sj.select_channel()
display(sj.output_sel_ch)

In [None]:
sj.set_channel()

## 5. Define ROI

Define the region of interest (the same for all the images) by zooming into the part of the image you want to segment (use the 'zoom to rectangle' tool from the tool icons at the left-hand side). Run all the cells below.
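
Conceptually, the zoomed view is turned into one crop that is applied to every image. The sketch below shows one way this could work; it is an assumption about the approach, not the code behind `sj.get_corners_cutout()` or `sj.make_cutouts()`. `roi_slices` is a hypothetical helper that takes axis limits as returned by matplotlib's `ax.get_xlim()` and `ax.get_ylim()`.

```python
import numpy as np

def roi_slices(xlim, ylim, shape):
    """Convert axis limits of a zoomed image view into (row, column) slices.

    xlim, ylim: pairs of floats; with `imshow` the y limits are usually
                given in descending order, so both pairs are sorted here.
    shape:      (rows, cols) of a single image, used to clip the ROI.
    """
    rows, cols = shape
    x0, x1 = sorted(int(round(v)) for v in xlim)
    y0, y1 = sorted(int(round(v)) for v in ylim)
    x0, x1 = max(x0, 0), min(x1, cols)
    y0, y1 = max(y0, 0), min(y1, rows)
    return slice(y0, y1), slice(x0, x1)

stack = np.zeros((3, 256, 256, 1))  # (num_images, rows, cols, channels)
rows, cols = roi_slices((40.2, 120.7), (200.4, 64.6), stack.shape[1:3])
cutouts = stack[:, rows, cols, :]   # the same ROI for every image
print(cutouts.shape)                # (3, 135, 81, 1)
```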

In [None]:
%matplotlib ipympl
sj.show_example_image(sj.imgs_sel_ch[0,:,:,0])

In [None]:
sj.get_corners_cutout()
sj.make_cutouts()

%matplotlib ipympl
sj.show_all_cutouts()
sj.output_all_cuts

In [None]:
sj.save_cutouts()

## 6. Model selection

You can choose between different models trained on different species, markers and neural network types. Select the models by running the cell below and clicking through the options (the selected models will appear in the list below the cell).

In [None]:
csv_url = "https://raw.githubusercontent.com/emmarant/biscotto/main/model_table.csv"
df_model_interact = pd.read_csv(csv_url)

sj.get_segmentation_models()
display(data_table.DataTable(df_model_interact, include_index=False, num_rows_per_page=10))

all_names = df_model_interact["Model Name"].astype(str).tolist()

search = widgets.Text(placeholder="filter models with ... (substring match)", layout=Layout(width="40%"))
sel    = widgets.SelectMultiple(options=sorted(all_names), rows=12, description="Select")
btn_all   = widgets.Button(description="Select all (filtered)", tooltip='Select all models matching filter keywords')
btn_clear  = widgets.Button(description="Clear", tooltip='Clear selection')
btn_apply = widgets.Button(description="Apply selection",tooltip='Apply selected models')
btn_applyrun   = widgets.Button(description="Apply & run", tooltip='Apply selected models and run segmentation',button_style="primary")
out       = widgets.Output()

def refresh_options(_=None):
    q = search.value.lower().strip()
    # An empty query matches every name, so one sorted expression covers both cases
    opts = sorted(n for n in all_names if q in n.lower())
    current = set(sel.value)
    sel.options = opts
    sel.value = tuple([o for o in opts if o in current])

search.observe(refresh_options, names="value")
refresh_options()

def on_all_clicked(_):
    sel.value = tuple(sel.options)

def on_none_clicked(_):
    sel.value = ()

btn_all.on_click(on_all_clicked)
btn_clear.on_click(on_none_clicked)


def _apply_selection(run_now=False):
    selected = set(sel.value)

    sj.model_checkboxes = {
        name: widgets.Checkbox(value=(name in selected), indent=False, layout=Layout(width="1px", height="1px"))
        for name in all_names
    }

    with out:
        clear_output()
        print(f"Selected {len(selected)} model(s):")
        for n in sorted(selected):
            print("  •", n)

    if run_now:
        sj.select_segmentation_models()
        sj.run_all_chosen_models()

def on_apply_clicked(_):
    _apply_selection(run_now=False)

def on_applyrun_clicked(_):
    _apply_selection(run_now=True)

btn_apply.on_click(on_apply_clicked)
btn_applyrun.on_click(on_applyrun_clicked)

widgets.VBox([
    widgets.HBox([search, btn_all, btn_clear, btn_apply, btn_applyrun]),
    sel,
    out
])


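- Running the following cell generates segmentations with all selected models and displays them for comparison. The cell calls `compare_and_plot_segmentations`, which is defined in the SCRATCH section at the end of this notebook, so run that definition cell first.

- This will take some time, depending on how many models and images you have.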
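
The comparison plot includes a bar chart of per-model disagreement scores, stored as a (mean, standard deviation) pair per model. How `sj.compute_model_diff_scores()` computes them is not shown in this notebook. The sketch below is one plausible definition, stated purely as an assumption: the fraction of pixels on which two models disagree about foreground versus background, averaged over all other models and all images. All function names are hypothetical.

```python
import numpy as np

def semantic_disagreement(labels_a, labels_b):
    """Fraction of pixels where two label images disagree on foreground/background."""
    fg_a = np.asarray(labels_a) > 0
    fg_b = np.asarray(labels_b) > 0
    return float(np.mean(fg_a != fg_b))

def per_model_scores(label_stacks):
    """Map each model name to (mean, std) of its disagreement with all other models.

    `label_stacks` maps a model name to an array of label images with shape
    (num_images, rows, cols); all models must cover the same images.
    """
    scores = {}
    for name, stack in label_stacks.items():
        values = [
            semantic_disagreement(stack[i], other_stack[i])
            for other, other_stack in label_stacks.items() if other != name
            for i in range(len(stack))
        ]
        scores[name] = (float(np.mean(values)), float(np.std(values))) if values else (0.0, 0.0)
    return scores

# Toy example: model_b marks one extra pixel out of 16 as foreground
a = np.zeros((1, 4, 4), dtype=int)
b = a.copy()
b[0, 0, 0] = 1
print(per_model_scores({"model_a": a, "model_b": b}))
# {'model_a': (0.0625, 0.0), 'model_b': (0.0625, 0.0)}
```

Under this definition, a model with a high score deviates from the consensus of the other models, which does not by itself say whether it is better or worse.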

In [None]:
#sj.compare_segmentations()
#sj.output_seg_comp

compare_and_plot_segmentations(sj)


Choose the name of the model weights giving the best segmentation result:

In [None]:
sj.display_buttons_weights()
display(sj.out_weights)

## 7. Save segmentations

Based on the chosen model and model weights, the whole image stack will be segmented. In case you would like to upload an additional file for the segmentation, please do that below.

In [None]:
sj.load_add_files()
sj.out_add_file

In [None]:
sj.process_images()

You can find the segmented images under `/content/data/seg_im/`.
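
Files on a Colab runtime are lost when the session ends, so you may want to download the results. The following is a minimal sketch using only the Python standard library; `zip_folder` is a hypothetical helper, and the Colab download call is shown as a comment because it only works inside Colab.

```python
import shutil
from pathlib import Path

def zip_folder(folder, archive_stem):
    """Zip the contents of `folder` into `<archive_stem>.zip` and return the archive path.

    Keep `archive_stem` outside of `folder` so the archive does not
    end up inside the directory that is being zipped.
    """
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"{folder} is not a directory")
    return shutil.make_archive(str(archive_stem), "zip", root_dir=folder)

# Usage inside Colab (paths taken from this notebook):
# archive = zip_folder("/content/data/seg_im", "/content/seg_im")
# from google.colab import files
# files.download(archive)
```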

### **General comments and TODOs**

+ [X] Add 'functionality' field (segmentation, classification,denoising)
+ [ ] Get inspiration for fields, terms, tags, etc from biii.eu
+ [ ] There might be no need for both buttons "Apply to sj" and "Apply and Run". Re-think it: is there any foreseen user action that, after selecting the desired models, does not proceed with running these models on the images determined earlier?
+ [ ] **Color scheme** for ROI etc not the best. How about **grayscale** instead??
+ [ ] DataTable does not render hyperlinks (just keeps the address string). Can use HTML instead, but then the table doesn't have the nice filter functionalities of a DataTable
+ [X] **Useful Docs column**: can add relevant papers too or other useful links. Or decide to keep it simple with just the github links
+ [X] Add 'version' and 'comments' columns. To be updated by model list editors.
+ [X] Make model list a CSV file, imported at run time (from same git where notebook lives)


### **Resources**

- *Omnipose*

https://github.com/kevinjohncutler/omnipose

https://omnipose.readthedocs.io/index.html

- *Cellpose*

https://github.com/mouseland/cellpose

https://cellpose.readthedocs.io/en/latest/index.html


- *Stardist*

https://github.com/stardist/stardist#pretrained-models

## **SCRATCH**

### New comparison plotting function

This is entirely based on the existing `sj.compare_segmentations()` function.

It was edited for visual appeal, and new plots (overlays) were added.

In [None]:
def compare_and_plot_segmentations(sj):
        """
        Visualises:
          1. raw image
          2. instance segmentation of model-1
          3. instance segmentation of model-2
          4. overlay raw image + outlines of instance segmentation of model-1
          5. overlay raw image + outlines of instance segmentation of model-2
          6. bar-plot of the per-model mean semantic disagreement scores
             with standard deviation as error bars.
        """
        from ipywidgets import interactive
        import numpy as np
        # ----------------------------------------------------------------
        # prepare bar-plot data (only once)
        # ----------------------------------------------------------------
        if not hasattr(sj, "model_diff_scores"):
            sj.model_diff_scores = sj.compute_model_diff_scores()

        def draw_instance_outlines(ax, inst_labels, color="yellow", lw=1.5):
            #import numpy as np, matplotlib.pyplot as plt
            inst = np.asarray(inst_labels)
            if inst.ndim == 3 and inst.shape[-1] == 2:
                inst = inst[..., 0]
            labels = np.unique(inst)
            labels = labels[labels != 0]

            for lab in labels:
                ax.contour(inst == lab, levels=[0.5], colors=[color], linewidths=lw)

        def f(a, b, c):
            fig = plt.figure(figsize=(20, 22))
            gs = fig.add_gridspec(
                nrows=4, ncols=2,
                height_ratios=[1.0, 1.0, 1.0, 0.8],
                hspace=0.13, wspace=0.08
            )

        # ---- Row 1 (raw) ----
            ax0 = fig.add_subplot(gs[0, :])

        # ---- Row 2 (two segmentations) — side by side ----
            ax1 = fig.add_subplot(gs[1, 0], sharex=ax0, sharey=ax0)
            ax2 = fig.add_subplot(gs[1, 1], sharex=ax0, sharey=ax0)

        # Row 3 (two segmentations) — side by side
            ax3 = fig.add_subplot(gs[2, 0], sharex=ax0, sharey=ax0)
            ax4 = fig.add_subplot(gs[2, 1], sharex=ax0, sharey=ax0)

        # Row 4 (bar plot) — span both columns
            ax5 = fig.add_subplot(gs[3, :])

        # ---- raw image ----
            raw = sj.imgs_cut[int(c)]
            ax0.imshow(raw, cmap="gray")
            ax0.set_xticks([]); ax0.set_yticks([])
            ax0.set_title("Raw image")

        # ---- instance seg – model 1 ----
            inst_a = sj.dict_all_models_label[a][int(c)]
            inst_a = np.asarray(inst_a)
            if inst_a.ndim == 3 and inst_a.shape[-1] == 2:
                inst_a = inst_a[..., 0]
            inst_a = np.ma.masked_where(inst_a == 0, inst_a)
            ax1.imshow(inst_a, cmap="tab20")
            ax1.set_xticks([]); ax1.set_yticks([])
            ax1.set_title("Model 1 (instance)")


        # ---- instance seg – model 2 ----
            inst_b = sj.dict_all_models_label[b][int(c)]
            inst_b = np.asarray(inst_b)
            if inst_b.ndim == 3 and inst_b.shape[-1] == 2:
                inst_b = inst_b[..., 0]
            inst_b = np.ma.masked_where(inst_b == 0, inst_b)
            ax2.imshow(inst_b, cmap="tab20")
            ax2.set_xticks([]); ax2.set_yticks([])
            ax2.set_title("Model 2 (instance)")


        # ---- raw + seg overlay – model 1 ----
            inst_a = sj.dict_all_models_label[a][int(c)]
            ax3.imshow(raw, cmap="gray")
            draw_instance_outlines(ax3, inst_a)
            ax3.set_xticks([]); ax3.set_yticks([])
            ax3.set_title("Raw + Model 1 (outlines)")


        # ---- raw + seg overlay – model 2 ----
            inst_b = sj.dict_all_models_label[b][int(c)]
            ax4.imshow(raw, cmap="gray")
            draw_instance_outlines(ax4, inst_b, color="cyan", lw=1.5)
            ax4.set_xticks([]); ax4.set_yticks([])
            ax4.set_title("Raw + Model 2 (outlines)")

        # ---- bar plot: mean disagreements + std dev ----
            mdl_ids = list(sj.model_diff_scores.keys())
            scores, std_devs = zip(*[sj.model_diff_scores[m] for m in mdl_ids])
            short_mdl_ids = [f"{m[:5]}...{m.split('_')[-1]}" for m in mdl_ids]

            ax5.bar(range(len(mdl_ids)), scores, yerr=std_devs, capsize=5)
            ax5.set_xticks(range(len(mdl_ids)))
            ax5.set_xticklabels(short_mdl_ids, rotation=90)
            ax5.set_ylabel("Mean semantic difference")
            ax5.set_title("Per-model disagreement")

            plt.show()
            plt.close(fig) # stop figure count increasing with every run


        sj.output_seg_comp = interactive(
            f,
            a=widgets.Dropdown(
                options=sj.dict_all_models.keys(),
                description="Model 1", layout=widgets.Layout(width="50%")
            ),
            b=widgets.Dropdown(
                options=sj.dict_all_models.keys(),
                description="Model 2", layout=widgets.Layout(width="50%")
            ),
            c=widgets.IntSlider(
                min=0,
                max=len(next(iter(sj.dict_all_models.values()))) - 1,
                description="Image ID"
            ),
        )
        display(sj.output_seg_comp)



In [None]:

compare_and_plot_segmentations(sj)