In [None]:
from bioimageio.spec.pretty_validation_errors import (
    enable_pretty_validation_errors_in_ipynb,
)

# Switch on the notebook-friendly rendering of validation errors before any
# description is created or loaded in the cells below.
enable_pretty_validation_errors_in_ipynb()

## Load and inspect a model description

In [None]:
from bioimageio.spec import InvalidDescr, load_description
from bioimageio.spec.common import HttpUrl
from bioimageio.spec.model import ModelDescr

# Ids of the example models this notebook has been tested with.
IMPARTIAL_SHRIMP = "impartial-shrimp"
PIONEERING_RHINO = "pioneering-rhino"

# Select the example to work with by (un)commenting one of these lines.
# example_model_id = IMPARTIAL_SHRIMP
example_model_id = PIONEERING_RHINO

# TODO: load bioimageio id from new S3 collection
# Until then each tested id is mapped to the URL of its `rdf.yaml`.
known_urls = {
    IMPARTIAL_SHRIMP: "https://bioimage-io.github.io/collection-bioimage-io/rdfs/10.5281/zenodo.5874741/5874742/rdf.yaml",
    PIONEERING_RHINO: "https://bioimage-io.github.io/collection-bioimage-io/rdfs/10.5281/zenodo.6334383/7805067/rdf.yaml",
}
if example_model_id not in known_urls:
    raise NotImplementedError(example_model_id)

url = known_urls[example_model_id]

# NOTE(review): `format_version="latest"` presumably yields the description in
# the newest format version; later cells use the v0_5 classes - confirm.
loaded_descr = load_description(url, format_version="latest")

# Fail early with a readable message if loading did not produce a valid model.
if isinstance(loaded_descr, InvalidDescr):
    raise ValueError(f"Failed to load {example_model_id}:\n{loaded_descr.validation_summary.format()}")
if not isinstance(loaded_descr, ModelDescr):
    raise ValueError("This notebook expects a model description")

model = loaded_descr
print(f"loaded '{model.name}'")


In [None]:
# Render the validation summary attached to the loaded model description.
model.validation_summary.display()

In [None]:
import matplotlib.pyplot as plt
from imageio.v3 import imread

from bioimageio.spec.utils import download

# Show each cover image of the loaded model in its own figure, without ticks.
for cover in model.covers:
    local_cover_path = download(cover).path  # local path of the downloaded cover
    cover_data = imread(local_cover_path)
    plt.imshow(cover_data)
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([])
    plt.show()


The following parts assume that the model has exactly one input tensor and one output tensor.

In [None]:
# The cells below only look at `model.inputs[0]` and `model.outputs[0]`.
# Stop here, with an explanatory message, if the model has more (or fewer) tensors.
assert len(model.inputs) == 1, f"expected exactly one input tensor, got {len(model.inputs)}"
assert len(model.outputs) == 1, f"expected exactly one output tensor, got {len(model.outputs)}"


In [None]:
from pprint import pprint

import numpy as np

# Inspect the single input tensor: print its axes, then download its test
# tensor and load it with NumPy to check the array shape.
first_input = model.inputs[0]
pprint(first_input.axes)

test_input_path = first_input.test_tensor.download().path
test_input_array = np.load(test_input_path)
print(test_input_array.shape)


In [None]:
# Inspect the single output tensor: print its axes, then download its test
# tensor and load it with NumPy to check the array shape.
first_output = model.outputs[0]
pprint(first_output.axes)

test_output_path = first_output.test_tensor.download().path
test_output_array = np.load(test_output_path)
print(test_output_array.shape)


In [None]:
# Both weight entries are optional on the description, so assert that each is
# present before downloading it and remembering the local path.
state_dict_entry = model.weights.pytorch_state_dict
assert state_dict_entry is not None
pytorch_state_dict_weights_src = state_dict_entry.download().path
print(pytorch_state_dict_weights_src)

torchscript_entry = model.weights.torchscript
assert torchscript_entry is not None
torchscript_weights_src = torchscript_entry.download().path
print(torchscript_weights_src)


In [None]:
from bioimageio.spec.model.v0_5 import ArchitectureFromFileDescr
from bioimageio.spec.utils import download

# Collect everything needed to re-describe the PyTorch architecture later on:
# the source file (downloaded and checked against its SHA-256), the callable's
# name and its keyword arguments.
state_dict_weights = model.weights.pytorch_state_dict
assert state_dict_weights is not None

arch = state_dict_weights.architecture
# This notebook only handles architectures that are defined in a source file.
assert isinstance(arch, ArchitectureFromFileDescr)

arch_name = arch.callable
arch_kwargs = arch.kwargs
arch_file_sha256 = arch.sha256

print(f"Model architecture given by '{arch_name}' in {arch.source}")
print("architecture key word arguments:")
pprint(arch_kwargs)

arch_file_path = download(arch.source, sha256=arch_file_sha256).path


## Create a model description

Let's recreate a model based on parts of the loaded model description from above!

Creating a model description in Python means creating a `ModelDescr` object.
Without any input data this will raise a `ValidationError` listing missing fields that are required:

In [None]:
from bioimageio.spec.common import ValidationError
from bioimageio.spec.model.v0_5 import ModelDescr

# Deliberately construct a `ModelDescr` without any fields: the resulting
# `ValidationError` is the demonstration, so it is caught and printed instead
# of aborting the notebook.
try:
    my_model_descr = ModelDescr()  # type: ignore
except ValidationError as e:
    print(e)


To populate a `ModelDescr` appropriately, we create the required subparts.
Let's start with the inputs:

In [None]:
from bioimageio.spec.model.v0_5 import (
    Author,
    AxisId,
    BatchAxis,
    ChannelAxis,
    CiteEntry,
    Doi,
    FileDescr,
    Identifier,
    InputTensorDescr,
    IntervalOrRatioDataDescr,
    ModelDescr,
    OutputTensorDescr,
    ParameterizedSize,
    PytorchStateDictWeightsDescr,
    SizeReference,
    SpaceInputAxis,
    SpaceOutputAxis,
    TensorId,
    TorchscriptWeightsDescr,
    WeightsDescr,
)

# Batch and channel axes are the same for both examples; only the spatial axes
# (and the data description) depend on the selected model.
input_axes = [BatchAxis(), ChannelAxis(channel_names=[Identifier("raw")])]

if example_model_id == IMPARTIAL_SHRIMP:
    # three spatial axes (z, y, x), each with a parameterized size (min + n * step)
    input_axes.extend(
        [
            SpaceInputAxis(id=AxisId(axis_name), size=ParameterizedSize(min=min_size, step=step))
            for axis_name, min_size, step in (("z", 16, 8), ("y", 144, 72), ("x", 144, 72))
        ]
    )
    data_descr = IntervalOrRatioDataDescr(type="uint8")
elif example_model_id == PIONEERING_RHINO:
    # two spatial axes (y, x), both parameterized identically
    input_axes.extend(
        [
            SpaceInputAxis(id=AxisId(axis_name), size=ParameterizedSize(min=256, step=8))
            for axis_name in ("y", "x")
        ]
    )
    data_descr = IntervalOrRatioDataDescr()  # relies on the class defaults
else:
    raise NotImplementedError(f"Recreating inputs for {example_model_id} is not implemented")

# The test tensor is the file downloaded from the loaded model earlier on.
input_descr = InputTensorDescr(
    id=TensorId("raw"),
    axes=input_axes,
    test_tensor=FileDescr(source=test_input_path),
    data=data_descr,
)


... and describe the outputs very similarly:

In [None]:
# Batch and channel axes are the same for both examples; only the spatial axes
# depend on the selected model.
output_axes = [BatchAxis(), ChannelAxis(channel_names=[Identifier("membrane")])]

if example_model_id == IMPARTIAL_SHRIMP:
    # implicitly the same size as the corresponding `raw` axes, as they are
    # parametrized the same
    output_axes.extend(
        [
            SpaceOutputAxis(id=AxisId(axis_name), size=ParameterizedSize(min=min_size, step=step))
            for axis_name, min_size, step in (("z", 16, 8), ("y", 144, 72), ("x", 144, 72))
        ]
    )
elif example_model_id == PIONEERING_RHINO:
    # explicitly the same size as the corresponding `raw` axes, via a size reference
    output_axes.extend(
        [
            SpaceOutputAxis(
                id=AxisId(axis_name),
                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId(axis_name)),
            )
            for axis_name in ("y", "x")
        ]
    )
else:
    raise NotImplementedError(f"Recreating outputs for {example_model_id} is not implemented")

# The test tensor is the file downloaded from the loaded model earlier on.
output_descr = OutputTensorDescr(
    id=TensorId("prob"),
    axes=output_axes,
    test_tensor=FileDescr(source=test_output_path),
)


... and finish by describing the architecture needed for the PyTorch state dict weights:

In [None]:
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    ArchitectureFromLibraryDescr,
    Version,
)

# Re-describe the architecture with the values extracted from the loaded model.
pytorch_architecture = ArchitectureFromFileDescr(
    source=arch_file_path,
    sha256=arch_file_sha256,
    callable=arch_name,
    kwargs=arch_kwargs,
)

# A model architecture published as a package may also be referenced.
# Make sure to include the library referenced in `import_from` in the `dependencies`.
my_unused_arch = ArchitectureFromLibraryDescr(
    callable=Identifier("MyModel"),
    import_from="my_library.subpackage",
)

# Record the installed PyTorch version; fall back to a fixed one if torch is
# not importable so that the notebook still runs.
# NOTE(review): "1.15" does not look like a released PyTorch version - confirm
# the intended fallback value.
try:
    import torch
except ImportError:
    pytorch_version = Version("1.15")
else:
    pytorch_version = Version(torch.__version__)


Now we are ready to create a new model:

In [None]:
from bioimageio.spec.model.v0_5 import LicenseId

# Assemble the new model description from the parts created above: the tensor
# descriptions, the downloaded weight files and the architecture description.
my_model_descr = ModelDescr(
    name="My cool model",
    description="A test model for demonstration purposes only",
    authors=[Author(name="me", affiliation="my institute", github_user="bioimageiobot")],  # change github_user to your GitHub account name
    cite=[CiteEntry(text="for model training see my paper", doi=Doi("10.1234something"))],
    license=LicenseId("MIT"),
    documentation=HttpUrl("https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/README.md"),
    git_repo=HttpUrl("https://github.com/bioimage-io/spec-bioimage-io"),  # change to repo where your model is developed
    inputs=[input_descr],
    outputs=[output_descr],
    weights=WeightsDescr(
        pytorch_state_dict=PytorchStateDictWeightsDescr(
            source=pytorch_state_dict_weights_src,
            architecture=pytorch_architecture,
            pytorch_version=pytorch_version,
        ),
        torchscript=TorchscriptWeightsDescr(
            source=torchscript_weights_src,
            pytorch_version=pytorch_version,
            parent="pytorch_state_dict",  # these weights were converted from the pytorch_state_dict weights
        ),
    ),
)

# f-string, so that the model name is interpolated instead of printing the
# placeholder text literally
print(f"created '{my_model_descr.name}'")


Some optional fields were filled with default values, e.g. as we did not specify `covers`, a default visualization of the test inputs and test outputs was used:

In [None]:
# Show each cover of the newly created model in its own figure, without ticks.
for cover in my_model_descr.covers:
    cover_image = imread(cover)
    plt.imshow(cover_image)
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([])
    plt.show()


## Test your model

In [None]:
from bioimageio.core import test_model

# Run the bioimageio.core model test on the new description and render the
# summary it returns.
summary = test_model(my_model_descr)
summary.display()

Side note: the validation summary is also available as a property of the model description.

In [None]:
# Check that the summary returned by `test_model` equals the one exposed via
# the description's `validation_summary` property.
assert summary == my_model_descr.validation_summary

## Package your model

A model is more than its YAML description file! We refer to a zip file containing all files relevant to a model as a model package.

In [None]:
from pathlib import Path

from bioimageio.spec import save_bioimageio_package

# Write the model package as a zip file and report the path that was returned.
package_path = save_bioimageio_package(my_model_descr, output_path=Path("my_model.zip"))
print("package path:", package_path)