In [1]:
import hashlib
import torch

In [2]:
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
from bioimageio.spec.model.v0_5 import (
    ArchitectureFromLibraryDescr,
    ArchitectureFromFileDescr,
    Author,
    CiteEntry,
    AxisBase,
    AxisId,
    BatchAxis,
    ChannelAxis,
    EnvironmentFileDescr,
    FileDescr,
    FixedZeroMeanUnitVarianceDescr,
    FixedZeroMeanUnitVarianceKwargs,
    Identifier,
    InputTensorDescr,
    ModelDescr,
    OutputTensorDescr,
    PytorchStateDictWeightsDescr,
    SpaceInputAxis,
    SpaceOutputAxis,
    IndexOutputAxis,
    TensorId,
    Version,
    WeightsDescr,
)

In [3]:
SEED = 13  # single source of truth for every RNG seeded in this cell
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # per PyTorch docs, silently ignored when CUDA is unavailable

In [3]:
# Markdown documentation file for the model (passed as `documentation=` to ModelDescr).
doc_md: str = "doc.md"

### Authors and Citations

In [None]:
# Authors credited in the model package.
# TODO(review): "John Doe" looks like a placeholder — replace with the real
# packager's name/contact details before publishing.
authors = [
    Author(
        name="John Doe",
        email=None,
        affiliation=None,
        orcid=None,
        github_user=None,
    ),
]
au1 = authors[0]  # named alias to the first (and only) author

In [None]:
# Publications to cite for Cellpose 1, 2 and 3, as (reference text, DOI) pairs.
citation1, citation2, citation3 = (
    CiteEntry(text=ref_text, doi=ref_doi)
    for ref_text, ref_doi in (
        (
            """Stringer, C., Wang, T., Michaelos, M. et al. Cellpose: a generalist algorithm for cellular segmentation. Nat Methods 18, 100–106 (2021).""",
            "10.1038/s41592-020-01018-x",
        ),
        (
            """Pachitariu, M., Stringer, C. Cellpose 2.0: how to train your own model. Nat Methods 19, 1634–1641 (2022).""",
            "10.1038/s41592-022-01663-4",
        ),
        (
            """Stringer, Carsen, and Marius Pachitariu. "Cellpose3: one-click image restoration for improved cellular segmentation." bioRxiv (2024).""",
            "10.1101/2024.02.10.579780",
        ),
    )
)

### Model Input
**Must have shape (B, C=1, Y, X): batch, a single channel, then the Y and X spatial dimensions.**

In [None]:
# Model input: sample tensor laid out as (B, C=1, Y, X).
input_path = "./data/input_sample.npy"
input_sample = np.load(input_path)

# The descriptor below declares exactly one channel and reads the Y/X sizes
# from dims 2 and 3. Fail fast if the sample does not follow that layout,
# rather than emitting a description that disagrees with its own test tensor.
if input_sample.ndim != 4 or input_sample.shape[1] != 1:
    raise ValueError(
        "expected input sample with shape (B, 1, Y, X), "
        f"got {input_sample.shape} from {input_path}"
    )

# Axes in tensor order: batch, channel, y, x.
# NOTE(review): integer sizes fix Y/X to this sample's extent; if the wrapper
# accepts arbitrary image sizes, a parameterized size may be more appropriate.
in_axes = [
    BatchAxis(),  # batch is always there!
    ChannelAxis(channel_names=[Identifier("channel")]),
    SpaceInputAxis(id=AxisId("y"), size=input_sample.shape[2]),
    SpaceInputAxis(id=AxisId("x"), size=input_sample.shape[3]),
]

# input descriptor
input_descr = InputTensorDescr(
    id=TensorId("input"),
    axes=in_axes,
    test_tensor=FileDescr(source=input_path),
)

### Model Outputs
The model returns four tensors: **masks**, **flows**, **styles**, and **diams**.

In [None]:
# masks: output declared as (batch, y, x); the Y/X extents come from the sample file.
output_1_path = "./data/output_sample1_masks.npy"
output_1_sample = np.load(output_1_path)

out_1_axes = [BatchAxis()] + [  # batch is always there!
    SpaceOutputAxis(id=AxisId(axis_name), size=output_1_sample.shape[dim])
    for axis_name, dim in (("y", 1), ("x", 2))
]

output_1_descr = OutputTensorDescr(
    id=TensorId("masks"),
    axes=out_1_axes,
    test_tensor=FileDescr(source=output_1_path),
)

In [None]:
# flows: output declared as (batch, channel, y, x); one generated channel name
# ("ch_0", "ch_1", ...) per entry of the sample's second dimension.
output_2_path = "./data/output_sample2_flows.npy"
output_2_sample = np.load(output_2_path)

out_2_axes = [
    BatchAxis(),  # batch is always there!
    ChannelAxis(
        channel_names=[Identifier(f"ch_{i}") for i in range(output_2_sample.shape[1])]
    ),
    SpaceOutputAxis(id=AxisId("y"), size=output_2_sample.shape[2]),
    SpaceOutputAxis(id=AxisId("x"), size=output_2_sample.shape[3]),
]

output_2_descr = OutputTensorDescr(
    id=TensorId("flows"),
    axes=out_2_axes,
    test_tensor=FileDescr(source=output_2_path),
)

In [None]:
# styles: output declared as (batch, y), with the y extent taken from the sample.
# NOTE(review): "styles" is presumably a per-image feature vector rather than a
# spatial dimension; `IndexOutputAxis` (imported but unused) may describe it more
# faithfully — confirm the bioimageio tooling accepts it before switching.
output_3_path = "./data/output_sample3_styles.npy"
output_3_sample = np.load(output_3_path)

out_3_axes = [
    BatchAxis(),  # batch is always there!
    SpaceOutputAxis(id=AxisId("y"), size=output_3_sample.shape[1]),
]

output_3_descr = OutputTensorDescr(
    id=TensorId("styles"),
    axes=out_3_axes,
    test_tensor=FileDescr(source=output_3_path),
)

In [None]:
# diams: output declared as (batch, y) with the y extent fixed to 1.
# NOTE(review): a diameter estimate is presumably one value per image, not a
# spatial dimension; the size-1 "y" axis is kept as-is — confirm whether an
# index axis would describe it better.
output_4_path = "./data/output_sample4_diams.npy"
# Loaded (not just referenced) so a missing or unreadable test file fails here.
output_4_sample = np.load(output_4_path)

out_4_axes = [BatchAxis(), SpaceOutputAxis(id=AxisId("y"), size=1)]

output_4_descr = OutputTensorDescr(
    id=TensorId("diams"),
    axes=out_4_axes,
    test_tensor=FileDescr(source=output_4_path),
)

### Model Architecture & Weights Description

In [None]:
# --- architecture -------------------------------------------------------------
# Python source defining the wrapper class, plus its SHA-256 checksum so the
# architecture descriptor records exactly which file was packaged.
model_src_file = "./model.py"
model_sha256 = hashlib.sha256(Path(model_src_file).read_bytes()).hexdigest()

# Keyword arguments passed to CellPoseWrapper.__init__ when instantiating the model.
model_kwargs = dict(
    diam_mean=30.,
    cp_batch_size=8,
    channels=[0, 0],
    flow_threshold=0.4,
    cellprob_threshold=0.0,
    stitch_threshold=0.0,
    estimate_diam=True,
    normalize=True,
    do_3D=False,
    gpu=False,
)

arch_descr = ArchitectureFromFileDescr(
    source=model_src_file,
    sha256=model_sha256,
    callable="CellPoseWrapper",
    kwargs=model_kwargs,
)

# --- weights ------------------------------------------------------------------
model_weights_file = Path("./cellpose_models/cyto3.pth")
pytorch_version = str(torch.__version__)  # version of the torch running this notebook
env_path = "./environment.yml"

state_dict_weights = PytorchStateDictWeightsDescr(
    source=model_weights_file,
    architecture=arch_descr,
    pytorch_version=Version(pytorch_version),
    dependencies=EnvironmentFileDescr(source=env_path),
)
weights_descr = WeightsDescr(pytorch_state_dict=state_dict_weights)

### Final Model Description

In [None]:
# Assemble the complete bioimage.io model description from the parts built above.
model_descr = ModelDescr(
    # identity
    name="CellPose(cyto3)",
    description="CellPose 'cyto3' model",
    version="0.1.0",
    license="BSD-3-Clause",
    # people & provenance
    authors=authors,
    cite=[citation1, citation2, citation3],
    git_repo="https://github.com/juglab/cellpose-bmz-wrapper",
    links=["https://github.com/mouseland/cellpose"],
    # documentation & presentation
    documentation=doc_md,
    covers=["cover.png"],
    tags=["Cellpose", "Cell Segmentation", "Segmentation"],
    # tensors & weights
    inputs=[input_descr],
    outputs=[output_1_descr, output_2_descr, output_3_descr, output_4_descr],
    weights=weights_descr,
)

In [None]:
# Show the validation report attached to the assembled model description.
validation_report = model_descr.validation_summary
validation_report.display()

In [None]:
from bioimageio.core import test_model

# Run bioimageio.core's `test_model` on the assembled description and display
# the resulting report; the report object is kept in `summary`.
summary = test_model(model_descr)
summary.display()

In [None]:
from bioimageio.spec import save_bioimageio_package

# Export the model description as a bioimage.io package archive.
package_path = Path("cellpose_cyto3.zip")
save_bioimageio_package(model_descr, output_path=package_path)