# BioImage Model Zoo Example notebook

This notebook shows how to use the `bioimageio.spec` library programmatically to explore, load, and export content from the [BioImage Model Zoo](https://bioimage.io).

## 0. Setup

### 0.1 Install dependencies
(if in Google Colab)

In [None]:
import os

if os.getenv("COLAB_RELEASE_TAG"):
    %pip install bioimageio.spec python-devtools

### 0.2 Enable pretty validation errors

Improves the readability of format validation errors in Jupyter notebooks by removing redundant error details and hiding calls within the pydantic library from the stack trace.

In [None]:
from bioimageio.spec.pretty_validation_errors import (
    enable_pretty_validation_errors_in_ipynb,
)

enable_pretty_validation_errors_in_ipynb()

## 1. Inspect the available models in the BioImage Model Zoo

Go to https://bioimage.io to browse the available models.

## 2. Load and inspect a model description

bioimage.io resources may be identified via their bioimage.io ID, e.g. "affable-shark", or via the [DOI](https://doi.org/) of their [Zenodo](https://zenodo.org/) backup.

Both options can be version specific ("affable-shark/1" or a version-specific [Zenodo](https://zenodo.org/) backup [DOI](https://doi.org/)).

Alternatively, any RDF (resource description file) may be loaded by providing a local path or URL.
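
The source forms listed above differ in shape. The sketch below classifies a source string accordingly; the function `classify_source` and its regular expressions are hypothetical approximations of the examples in this notebook (it assumes IDs are lowercase words joined by hyphens), not the rules `load_description` actually uses to resolve a source.

```python
import re
from pathlib import Path

# Hypothetical patterns for illustration only; bioimageio.spec applies its own rules.
DOI_PATTERN = re.compile(r"^10\.\d{4,9}/\S+$")
# e.g. "affable-shark", "affable-shark/1", "affable-shark/draft"
ID_PATTERN = re.compile(r"^[a-z]+(?:-[a-z]+)+(?:/(?:\d+|draft))?$")


def classify_source(source: str) -> str:
    """Guess which kind of model source a string refers to (illustrative sketch)."""
    if source.startswith(("http://", "https://")):
        return "url"
    if DOI_PATTERN.match(source):
        return "doi"
    if ID_PATTERN.match(source):
        if source.endswith("/draft"):
            return "draft id"
        return "versioned id" if "/" in source else "id"
    if Path(source).suffix in (".yaml", ".yml", ".zip"):
        return "local path"
    return "unknown"


for example in ["affable-shark", "affable-shark/1", "10.5281/zenodo.11092562", "some/local/package.zip"]:
    print(f"{example!r}: {classify_source(example)}")
```

Checking URLs and DOIs first keeps a DOI such as "10.5281/zenodo.11092561" from being mistaken for a path.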

In [None]:
# Load the model description with one of these options
# 1. version unspecific (implicitly referring to the latest version):
MODEL_ID = "affable-shark"
MODEL_DOI = "10.5281/zenodo.11092561"

# 2. version specific
MODEL_VERSION_ID = "affable-shark/1"  # not yet available for this legacy model
MODEL_VERSION_DOI = "10.5281/zenodo.11092562"

# 3. an uploaded draft
MODEL_DRAFT = "affable-shark/draft"

# 4. from source
MODEL_URL = "https://zenodo.org/records/11092562/files/rdf.yaml"
MODEL_PATH = "some/local/rdf.yaml"
MODEL_PACKAGE_PATH = "some/local/package.zip"  # with an rdf.yaml inside

In [None]:
# Another set of examples to source a bioimage.io model
# 1. version unspecific (implicitly referring to the latest version):
MODEL_ID = "emotional-cricket"
MODEL_DOI = "10.5281/zenodo.6346511"

# 2. version specific
MODEL_VERSION_ID = "emotional-cricket/1"  # not yet available for this legacy model
MODEL_VERSION_DOI = "10.5281/zenodo.7768142"

# 3. an uploaded draft
MODEL_DRAFT = "emotional-cricket/draft"

# 4. from source
MODEL_URL = "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/emotional-cricket/1/files/rdf.yaml"
MODEL_PATH = "some/local/rdf.yaml"
MODEL_PACKAGE_PATH = "some/local/package.zip"  # with an rdf.yaml inside

In [None]:
from bioimageio.spec import InvalidDescr, load_description
from bioimageio.spec.model.v0_5 import ModelDescr

source = MODEL_DRAFT  # let's use the latest draft version

loaded_description = load_description(source)

## 3. Validation summary of the model
A model description is validated against our format specification.
To inspect the corresponding validation summary, access the `validation_summary` attribute.

The validation summary will indicate:
- the version of the `bioimageio.spec` library used to run the validation
- the status of several validation steps
    - ✔️: Success
    - 🔍: information about the validation context
    - ⚠: Warning
    - ❌: Error

To display the validation summary in a terminal or notebook, we recommend running:

In [None]:
loaded_description.validation_summary.display()

In [None]:
# let's make sure we have a valid model...
if isinstance(loaded_description, InvalidDescr):
    raise ValueError(f"Failed to load {source}")
elif not isinstance(loaded_description, ModelDescr):
    raise ValueError("This notebook expects a model 0.5 description")

model = loaded_description
example_model_id = model.id
assert example_model_id is not None

## 4. Inspect the model description

In [None]:
from typing import Any

import imageio.v3
import matplotlib.pyplot as plt
from numpy.typing import NDArray

from bioimageio.spec._internal.io import FileSource
from bioimageio.spec.utils import download


def imread(src: FileSource) -> NDArray[Any]:
    """typed `imageio.v3.imread`"""
    img: NDArray[Any] = imageio.v3.imread(download(src).path)
    return img

print(f"The model is named '{model.name}'")
print(f"Description:\n{model.description}")
print(f"License: {model.license}")

In [None]:
try:
    from devtools import pprint
except ImportError:
    from pprint import pprint

print("\nThe authors of the model are:")
pprint(model.authors)
print("\nIn addition to the authors, it is maintained by:")
pprint(model.maintainers)

In [None]:
print("\nIf you use this model, you are expected to cite:")
pprint(model.cite)

print(f"\nFurther documentation can be found here: {model.documentation}")

In [None]:
if model.git_repo is None:
    print("\nThere is no associated GitHub repository.")
else:
    print(f"\nThere is an associated GitHub repository: {model.git_repo}.")

for i, cover in enumerate(model.covers):
    downloaded_cover = download(cover)
    cover_data: NDArray[Any] = imread(downloaded_cover.path)
    _ = plt.figure(figsize=(10, 10))
    plt.imshow(cover_data)  # type: ignore
    plt.xticks([])  # type: ignore
    plt.yticks([])  # type: ignore
    plt.title(f"cover image {downloaded_cover.original_file_name}")  # type: ignore
    plt.show()

In [None]:
from bioimageio.spec.utils import download

cover_path = download(model.covers[0]).path
plt.imshow(imread(cover_path))
plt.xticks([])
plt.yticks([])
plt.show()

### 4.1 Inspect available weight formats of the model

In [None]:
weights = model.weights
for w in [
    weights.onnx,
    weights.keras_hdf5,
    weights.tensorflow_js,
    weights.tensorflow_saved_model_bundle,
    weights.torchscript,
    weights.pytorch_state_dict,
]:
    if w is None:
        continue

    print(w.weights_format_name)
    print(f"weights are available at {w.source.absolute()}")
    print(f"and have a SHA-256 value of {w.sha256}")
    details = {
        k: v
        for k, v in w.model_dump(mode="json", exclude_none=True).items()
        if k not in ("source", "sha256")
    }
    if details:
        print(f"additional metadata for {w.weights_format_name}:")
        pprint(details)

    print()
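
The `sha256` values printed above let you check that a downloaded weights file is intact. A minimal sketch of such a check using the standard `hashlib` module; the helpers `sha256_of_file` and `verify_sha256` are hypothetical and not part of `bioimageio.spec` (note that `download` itself also accepts a `sha256` argument).

```python
import hashlib
from pathlib import Path


def sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Compute the hex SHA-256 digest of a file, reading it in chunks to limit memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: Path, expected: str) -> bool:
    """Return True if the file's SHA-256 digest matches the expected hex string."""
    return sha256_of_file(path) == expected.lower()
```

Reading in chunks keeps memory use constant even for weights files of several gigabytes.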

### 4.2 Inspect expected inputs and outputs of the model

In [None]:
print(f"Model '{model.name}' requires {len(model.inputs)} input(s) with the following features:")
for ipt in model.inputs:
    print(f"\ninput '{ipt.id}' with axes:")
    pprint(ipt.axes)
    print(f"Data description: {ipt.data}")
    print(f"Test tensor available at: {ipt.test_tensor.source.absolute()}")
    if len(ipt.preprocessing) > 1:
        print("This input is preprocessed with: ")
        for p in ipt.preprocessing:
            print(p)

print("\n-------------------------------------------------------------------------------")
# ...and what the model outputs are
print(f"Model '{model.name}' produces {len(model.outputs)} output(s) with the following features:")
for out in model.outputs:
    print(f"\noutput '{out.id}' with axes:")
    pprint(out.axes)
    print(f"Data description: {out.data}")
    print(f"Test tensor available at: {out.test_tensor.source.absolute()}")
    if len(out.postprocessing) > 1:
        print("This output is postprocessed with: ")
        for p in out.postprocessing:
            print(p)

### 4.3 Inspect model architecture

(In this notebook, architecture inspection is only implemented for PyTorch state dict weights.)

In [None]:
from typing_extensions import assert_never

from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    ArchitectureFromLibraryDescr,
)

assert isinstance(model, ModelDescr)
if (w := model.weights.pytorch_state_dict) is not None:
    arch = w.architecture
    print(f"callable: {arch.callable}")
    if isinstance(arch, ArchitectureFromFileDescr):
        print(f"import from file: {arch.source.absolute()}")
        if arch.sha256 is not None:
            print(f"SHA-256: {arch.sha256}")
    elif isinstance(arch, ArchitectureFromLibraryDescr):
        print(f"import from module: {arch.import_from}")
    else:
        assert_never(arch)

### 4.4 Inspect it all!

Of course, we can also inspect the model description in full detail.
This produces a lot of text, which is why the `ModelDescr` object organizes this metadata into attributes in the first place.

In [None]:
pprint(model)

## 5. Create a model description

Let's recreate a model based on parts of the loaded model description from above!

Creating a model description with bioimageio.spec means creating a `bioimageio.spec.model.ModelDescr` object. This description object can be exported and uploaded to the BioImage Model Zoo or deployed directly with community partner software.


Without any input data, initializing a `ModelDescr` will raise a `ValidationError` listing missing required fields:

In [None]:
from bioimageio.spec.model.v0_5 import ModelDescr

_ = ModelDescr()  # pyright: ignore[reportCallIssue]

To populate a `ModelDescr` appropriately, we need to create the required subparts. These are part of the model metadata needed to document the model and ensure its correct deployment.

### 5.1 Inputs

In [None]:
from bioimageio.spec.model.v0_5 import (
    AxisId,
    BatchAxis,
    ChannelAxis,
    FileDescr,
    Identifier,
    InputTensorDescr,
    IntervalOrRatioDataDescr,
    ParameterizedSize,
    SpaceInputAxis,
    SpaceOutputAxis,
    TensorId,
    WeightsDescr,
)

input_axes = [
    BatchAxis(),
    ChannelAxis(channel_names=[Identifier("raw")]),
]
if len(model.inputs[0].axes) == 5:  # e.g. impartial-shrimp
    input_axes += [
        SpaceInputAxis(id=AxisId("z"), size=ParameterizedSize(min=16, step=8)),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=144, step=72)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=144, step=72)),
    ]
    data_descr = IntervalOrRatioDataDescr(type="float32")
elif len(model.inputs[0].axes) == 4:  # e.g. pioneering-rhino
    input_axes += [
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=256, step=8)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=256, step=8)),
    ]
    data_descr = IntervalOrRatioDataDescr(type="float32")
else:
    raise NotImplementedError(f"Recreating inputs for {example_model_id} is not implemented")

test_input_path = model.inputs[0].test_tensor.download().path
input_descr = InputTensorDescr(
    id=TensorId("raw"),
    axes=input_axes,
    test_tensor=FileDescr(source=test_input_path),
    data=data_descr,
)
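
Each `ParameterizedSize(min=..., step=...)` above describes a family of valid axis sizes, `min + n * step` for `n = 0, 1, 2, ...`. A minimal sketch in plain Python, assuming `step > 0`; the helpers `valid_sizes` and `smallest_valid_size` are hypothetical, not `bioimageio.spec` functions. It lists allowed sizes and finds the smallest one that fits a given image extent, e.g. to decide how much to pad an input.

```python
from typing import List


def valid_sizes(min_size: int, step: int, count: int) -> List[int]:
    """Return the first `count` sizes of the form min_size + n * step (n = 0, 1, 2, ...)."""
    return [min_size + n * step for n in range(count)]


def smallest_valid_size(extent: int, min_size: int, step: int) -> int:
    """Return the smallest allowed size that is at least `extent` (assumes step > 0)."""
    if extent <= min_size:
        return min_size
    n = (extent - min_size + step - 1) // step  # integer ceiling division
    return min_size + n * step


# e.g. the x axis of the 3D input above: min=144, step=72
print(valid_sizes(144, 72, 4))  # [144, 216, 288, 360]
print(smallest_valid_size(200, 144, 72))  # 216
```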

### 5.2 Outputs

In [None]:
from bioimageio.spec.model.v0_5 import OutputTensorDescr, SizeReference

assert isinstance(model.outputs[0].axes[1], ChannelAxis)
output_axes = [
    BatchAxis(),
    ChannelAxis(channel_names=[Identifier(n) for n in model.outputs[0].axes[1].channel_names]),
]
if len(model.outputs[0].axes) == 5:  # e.g. impartial-shrimp
    output_axes += [
        # same size as axis `z` of the input tensor `raw`
        SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
        SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))),
        SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))),
    ]
elif len(model.outputs[0].axes) == 4:  # e.g. pioneering-rhino
    output_axes += [
        # same size as axis `y` of the input tensor `raw`
        SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))),
        SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))),
    ]
else:
    raise NotImplementedError(f"Recreating outputs for {example_model_id} is not implemented")

test_output_path = model.outputs[0].test_tensor.download().path
output_descr = OutputTensorDescr(
    id=TensorId("prob"),
    axes=output_axes,
    test_tensor=FileDescr(source=test_output_path),
)
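
Each `SizeReference` above ties an output axis to an axis of the input tensor `raw`, so the output's spatial shape follows from the input's shape. A simplified sketch using plain dictionaries and a hypothetical `resolve_output_shape` function; it ignores any offset or axis scaling that a size reference may involve in the specification, and the example shapes (including the 2 output channels) are made up.

```python
from typing import Dict, List, Tuple, Union

# An axis size is either a fixed integer or a (tensor_id, axis_id) reference.
SizeSpec = Union[int, Tuple[str, str]]


def resolve_output_shape(
    output_axes: List[Tuple[str, SizeSpec]],
    input_shapes: Dict[str, Dict[str, int]],
) -> Dict[str, int]:
    """Map each output axis id to a concrete size, following size references to input axes."""
    shape: Dict[str, int] = {}
    for axis_id, size in output_axes:
        if isinstance(size, int):
            shape[axis_id] = size
        else:
            tensor_id, ref_axis = size
            shape[axis_id] = input_shapes[tensor_id][ref_axis]
    return shape


example_input_shapes = {"raw": {"batch": 1, "channel": 1, "y": 264, "x": 264}}
example_output_axes = [("batch", 1), ("channel", 2), ("y", ("raw", "y")), ("x", ("raw", "x"))]
print(resolve_output_shape(example_output_axes, example_input_shapes))
# {'batch': 1, 'channel': 2, 'y': 264, 'x': 264}
```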

### 5.3 Model architecture
PyTorch state dict weights must come with the corresponding architecture (e.g., a 2D U-Net):

In [None]:
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    ArchitectureFromLibraryDescr,
    Version,
)

try:
    import torch
except ImportError:
    pytorch_version = Version("1.13")  # fallback if PyTorch is not installed
else:
    pytorch_version = Version(torch.__version__)

# Recover the architecture information from the original model
assert model.weights.pytorch_state_dict is not None

arch = model.weights.pytorch_state_dict.architecture
if isinstance(arch, ArchitectureFromFileDescr):
    arch_file_path = download(arch.source, sha256=arch.sha256).path
    arch_file_sha256 = arch.sha256
    arch_name = arch.callable
    arch_kwargs = arch.kwargs

    pytorch_architecture = ArchitectureFromFileDescr(
        source=arch_file_path,
        sha256=arch_file_sha256,
        callable=arch_name,
        kwargs=arch_kwargs
    )
else:
    # For a model architecture that is published in a Python package,
    # make sure to include the Python library referenced in `import_from` in the weights entry's `dependencies`
    pytorch_architecture = ArchitectureFromLibraryDescr(
        callable=arch.callable,
        kwargs=arch.kwargs,
        import_from=arch.import_from,
    )
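
Conceptually, a consumer turns such an architecture description into a network by importing `callable` either from the architecture source file or from the `import_from` module, and calling it with `kwargs`. A minimal stdlib sketch of that idea with a hypothetical `instantiate_architecture` function; consumers such as `bioimageio.core` implement their own loading logic.

```python
import importlib
import importlib.util
from pathlib import Path
from typing import Any, Dict, Optional


def instantiate_architecture(
    callable_name: str,
    kwargs: Dict[str, Any],
    source_file: Optional[Path] = None,
    import_from: Optional[str] = None,
) -> Any:
    """Import `callable_name` from a Python file or an installed module and call it with `kwargs`."""
    if source_file is not None:
        # Load the architecture file as a standalone module.
        spec = importlib.util.spec_from_file_location(source_file.stem, str(source_file))
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot import {source_file}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    elif import_from is not None:
        # Import from an installed package (listed in the weights' dependencies).
        module = importlib.import_module(import_from)
    else:
        raise ValueError("either source_file or import_from is required")
    return getattr(module, callable_name)(**kwargs)
```

Executing an architecture file runs arbitrary code, which is one reason the description records its `sha256`.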


### 5.4 Create the model

In [None]:
from bioimageio.spec.model.v0_5 import (
    Author,
    CiteEntry,
    Doi,
    HttpUrl,
    LicenseId,
    PytorchStateDictWeightsDescr,
    TorchscriptWeightsDescr,
)

assert model.weights.pytorch_state_dict is not None
assert model.weights.torchscript is not None
my_model_descr = ModelDescr(
    name="My cool model",
    description="A test model for demonstration purposes only",
    authors=[Author(name="me", affiliation="my institute", github_user="bioimageiobot")],  # change github_user to your GitHub account name
    cite=[CiteEntry(text="for model training see my paper", doi=Doi("10.1234something"))],
    license=LicenseId("MIT"),
    documentation=HttpUrl("https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/README.md"),
    git_repo=HttpUrl("https://github.com/bioimage-io/spec-bioimage-io"),  # change to repo where your model is developed
    inputs=model.inputs,
    # inputs=[input_descr],  # try out our recreated input description
    outputs=model.outputs,
    # outputs=[output_descr],  # try out our recreated output description
    weights=WeightsDescr(
        pytorch_state_dict=PytorchStateDictWeightsDescr(
            source=model.weights.pytorch_state_dict.source,
            sha256=model.weights.pytorch_state_dict.sha256,
            architecture=pytorch_architecture,
            pytorch_version=pytorch_version
        ),
        torchscript=TorchscriptWeightsDescr(
            source=model.weights.torchscript.source,
            sha256=model.weights.torchscript.sha256,
            pytorch_version=pytorch_version,
            parent="pytorch_state_dict",  # these weights were converted from the pytorch_state_dict weights
        ),
    ),
)
print(f"created '{my_model_descr.name}'")


### 5.5 Covers
Some optional fields were filled with default values; for example, we did not specify `covers`.
When possible, a default visualization of the test inputs and test outputs is generated.
If the input or the output has more than one channel, the current implementation cannot generate a cover image automatically.

Automatically generated cover images:

In [None]:
for cover in my_model_descr.covers:
    img: NDArray[Any] = imread(download(cover).path)
    _ = plt.imshow(img)
    plt.xticks([])  # type: ignore
    plt.yticks([])  # type: ignore
    plt.show()

## 6. Test the newly created model
### 6.1 Static validation
(The same format validation as at the very beginning, now for `my_model_descr`.)

In [None]:
my_model_descr.validation_summary.display()

### 6.2 Dynamic validation

If you have the `bioimageio.core` library installed, you can run the dynamic validation to test whether the model runs and correctly reproduces the test output from the test input.
This extends the validation summary from above:
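
At its core, such a dynamic test runs the model on the test input and compares the prediction with the stored test output within a numerical tolerance. A pure-Python sketch of that comparison on flattened values; the helper `outputs_match` and its tolerance values are illustrative assumptions, not the criteria or defaults used by `bioimageio.core`.

```python
import math
from typing import Sequence


def outputs_match(
    predicted: Sequence[float],
    expected: Sequence[float],
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-4,
) -> bool:
    """True if both sequences have the same length and all values agree within tolerance."""
    if len(predicted) != len(expected):
        return False
    return all(
        math.isclose(p, e, rel_tol=rel_tol, abs_tol=abs_tol)
        for p, e in zip(predicted, expected)
    )
```

A combined relative and absolute tolerance is a common choice, because values close to zero make purely relative comparisons overly strict.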

In [None]:
from bioimageio.core import test_model

summary = test_model(my_model_descr)
summary.display()

## 7. Package your model

A model is more than its YAML description file! We refer to a zip file containing all files relevant to a model as a model package.

In [None]:
from pathlib import Path

from bioimageio.spec import save_bioimageio_package

print("package path:", save_bioimageio_package(my_model_descr, output_path=Path('my_model.zip')))
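
Because a model package is an ordinary zip archive, its contents can be inspected with Python's standard `zipfile` module. The sketch below lists the files in a package and checks for a description file at the archive root; the helper names and the assumption that the description file is named `rdf.yaml` or `bioimageio.yaml` are specific to this sketch.

```python
import zipfile
from pathlib import Path
from typing import List

# Assumed names of the description file inside a package.
DESCRIPTION_NAMES = ("rdf.yaml", "bioimageio.yaml")


def list_package(package: Path) -> List[str]:
    """Return the names of all entries stored in the zip package."""
    with zipfile.ZipFile(package) as zf:
        return zf.namelist()


def has_description(package: Path) -> bool:
    """True if a description file sits at the root of the package."""
    return any(name in DESCRIPTION_NAMES for name in list_package(package))
```

For example, `list_package(Path("my_model.zip"))` shows the files bundled into the package written above.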