# Load PlantSeg Models from BioImage.IO Model Zoo by Nicknames

To get started, create a conda/mamba environment with both `plant-seg` and `bioimageio.spec` from the `conda-forge` channel. Alternatively, you may install only `bioimageio.spec`; in that case, skip the `plantseg` import in the first cell and do not run the very last cell (PlantSeg model loading).

In [1]:
from pathlib import Path
import json
import pooch

from bioimageio.spec import InvalidDescr, load_description
from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4
from bioimageio.spec.model.v0_5 import ModelDescr as ModelDescr_v0_5
from bioimageio.spec.utils import download

from plantseg.training.model import UNet2D, UNet3D

In [2]:
# JSON index of the BioImage.IO collection; the entries under its "collection" key
# are filtered by type and tags in the cells below.
BIOIMAGE_IO_COLLECTION_URL = "https://raw.githubusercontent.com/bioimage-io/collection-bioimage-io/gh-pages/collection.json"

Let's filter the models by tags. This doesn't guarantee finding all PlantSeg-compatible models, but it's quick:

In [3]:
def _is_plantseg_model(collection_entry: dict) -> bool:
    """Determines if the 'tags' field in a collection entry contains the keyword 'plantseg'."""
    tags = collection_entry.get("tags")
    if tags is None:
        return False
    if not isinstance(tags, list):
        raise ValueError(f"Tags in a collection entry must be a list of strings, got {type(tags).__name__}")

    # Normalize tags to lower case and remove non-alphanumeric characters
    normalized_tags = ["".join(filter(str.isalnum, tag.lower())) for tag in tags]
    return 'plantseg' in normalized_tags

In [4]:
# Fetch the collection index through pooch (no hash check is requested) and parse it as JSON
collection_path = Path(pooch.retrieve(BIOIMAGE_IO_COLLECTION_URL, known_hash=None))
with collection_path.open(encoding='utf-8') as f:
    collection = json.load(f)
bioimageio_zoo_collection = collection

# Lazily select the entries that are models and carry a PlantSeg tag ...
plantseg_entries = (
    entry
    for entry in collection["collection"]
    if entry["type"] == "model" and _is_plantseg_model(entry)
)
# ... and map each nickname to the source of its resource description file (RDF)
bioimageio_zoo_plantseg_model_url_dict = {
    entry["nickname"]: entry["rdf_source"] for entry in plantseg_entries
}
set(bioimageio_zoo_plantseg_model_url_dict)

{'efficient-chipmunk',
 'emotional-cricket',
 'laid-back-lobster',
 'loyal-squid',
 'noisy-fish',
 'passionate-t-rex',
 'pioneering-rhino',
 'powerful-fish',
 'thoughtful-turtle'}

Pick a model from the list and load it with `bioimageio.spec`:

In [5]:
# Nickname of the model to load; it must be a key of the nickname -> RDF mapping
model_id = 'efficient-chipmunk'

is_known_plantseg_model = model_id in bioimageio_zoo_plantseg_model_url_dict
if not is_known_plantseg_model:
    raise ValueError(f"Model ID {model_id} may not be a PlantSeg model in BioImage.IO Model Zoo")

In [6]:
# Look up the RDF location registered for the chosen nickname and load its description
rdf_url = bioimageio_zoo_plantseg_model_url_dict[model_id]
model_description = load_description(rdf_url)

# An `InvalidDescr` result is treated as a failed load: show the validation summary, then abort
if isinstance(model_description, InvalidDescr):
    model_description.validation_summary.display()
    raise ValueError(f"Failed to load {model_id}")

# Only the v0.4 and v0.5 model description classes are handled below
supported_formats = (ModelDescr_v0_4, ModelDescr_v0_5)
if not isinstance(model_description, supported_formats):
    raise ValueError(
        f"Model description {model_id} is not in v0.4 or v0.5 BioImage.IO model description format. "
        "Only v0.4 and v0.5 formats are supported by BioImage.IO Spec and PlantSeg."
    )

# A PyTorch state-dict weights entry is required
pytorch_weights = model_description.weights.pytorch_state_dict
if pytorch_weights is None:
    raise ValueError(f"Model {model_id} does not have PyTorch weights")

# The two spec versions store the architecture differently:
# v0.4 keeps `architecture` and `kwargs` directly on the weights entry,
# v0.5 nests them as `architecture.callable` and `architecture.kwargs`.
if isinstance(model_description, ModelDescr_v0_4):
    architecture_callable = pytorch_weights.architecture
    architecture_kwargs = pytorch_weights.kwargs
else:  # ModelDescr_v0_5, guaranteed by the format check above
    architecture_callable = pytorch_weights.architecture.callable
    architecture_kwargs = pytorch_weights.architecture.kwargs
print(f"Got {architecture_callable} model with kwargs {architecture_kwargs}.")

Got plantseg.models.model.UNet3D model with kwargs {'f_maps': 16, 'in_channels': 1, 'out_channels': 2}.


In [7]:
# Choose the PlantSeg network class from the architecture name, e.g. 'plantseg.models.model.UNet3D'
architecture_name = str(architecture_callable)
if 'UNet3D' in architecture_name:
    architecture = UNet3D
else:  # every other architecture name falls back to the 2D network
    architecture = UNet2D

# Start from default constructor arguments; kwargs from the model description override them
model_config = dict(
    in_channels=1,
    out_channels=1,
    layer_order='gcr',
    f_maps=32,
    num_groups=8,
    final_sigmoid=True,
)
model_config.update(architecture_kwargs)
model = architecture(**model_config)

# NOTE: this cell only downloads the weights file and keeps its path;
# nothing here loads those weights into `model`.
weights_source = model_description.weights.pytorch_state_dict.source
model_weights_path = download(weights_source).path

print(f"Created {architecture} model with kwargs {model_config}.")
print(f"Loaded model from BioImage.IO Model Zoo: {model_id}")

Created <class 'plantseg.training.model.UNet3D'> model with kwargs {'in_channels': 1, 'out_channels': 2, 'layer_order': 'gcr', 'f_maps': 16, 'num_groups': 8, 'final_sigmoid': True}.
Loaded model from BioImage.IO Model Zoo: efficient-chipmunk
