Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tiling #161

Merged
merged 10 commits into from
Nov 25, 2021
49 changes: 37 additions & 12 deletions bioimageio/core/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def load_tile(tile):
inp = input_[tile]
# whether to pad on the right or left of the dim for the spatial dims
# + placeholders for batch and axis dimension, where we don't pad
pad_right = [None, None] + [tile[ax].start == 0 for ax in input_axes if ax in "xyz"]
pad_right = [tile[ax].start == 0 if ax in "xyz" else None for ax in input_axes]
return inp, pad_right

# we need to use padded prediction for the individual tiles in case the
Expand Down Expand Up @@ -318,6 +318,8 @@ def predict_with_tiling(prediction_pipeline: PredictionPipeline, inputs, tiling)


def parse_padding(padding, model):
if padding is None: # no padding
return padding
if len(model.inputs) > 1:
raise NotImplementedError("Padding for multiple inputs not yet implemented")

Expand All @@ -327,9 +329,7 @@ def parse_padding(padding, model):
def check_padding(padding):
    # Validate a user-supplied padding dict: every key must be one of the
    # allowed keys. NOTE(review): `pad_keys` comes from the enclosing
    # parse_padding scope (closure) — presumably the spatial axes plus
    # "mode"; confirm against the surrounding function (truncated here).
    assert all(k in pad_keys for k in padding.keys())

if padding is None: # no padding
return padding
elif isinstance(padding, dict): # pre-defined padding
if isinstance(padding, dict): # pre-defined padding
check_padding(padding)
elif isinstance(padding, bool): # determine padding from spec
if padding:
Expand All @@ -350,7 +350,25 @@ def check_padding(padding):
return padding


# simple heuristic to determine suitable shape from min and step
def _determine_shape(min_shape, step, axes):
is3d = "z" in axes
min_len = 64 if is3d else 256
shape = []
for ax, min_ax, step_ax in zip(axes, min_shape, step):
if ax in "zyx" and step_ax > 0:
len_ax = min_ax
while len_ax < min_len:
len_ax += step_ax
shape.append(len_ax)
else:
shape.append(min_ax)
return shape


def parse_tiling(tiling, model):
if tiling is None: # no tiling
return tiling
if len(model.inputs) > 1:
raise NotImplementedError("Tiling for multiple inputs not yet implemented")

Expand All @@ -359,13 +377,17 @@ def parse_tiling(tiling, model):

input_spec = model.inputs[0]
output_spec = model.outputs[0]
axes = input_spec.axes

def check_tiling(tiling):
    # Validate a tiling dict: it must contain "halo" and "tile" mappings,
    # and both must give a strictly positive value for every spatial axis.
    # NOTE(review): `axes` is a closure variable from the enclosing
    # parse_tiling scope (the input spec's axes) — the enclosing function
    # is truncated in this view; confirm there.
    assert "halo" in tiling and "tile" in tiling
    # only the spatial axes need halo/tile entries; batch/channel are ignored
    spatial_axes = [ax for ax in axes if ax in "xyz"]
    halo = tiling["halo"]
    tile = tiling["tile"]
    assert all(halo.get(ax, 0) > 0 for ax in spatial_axes)
    assert all(tile.get(ax, 0) > 0 for ax in spatial_axes)

if tiling is None: # no tiling
return tiling
elif isinstance(tiling, dict):
if isinstance(tiling, dict):
check_tiling(tiling)
elif isinstance(tiling, bool):
if tiling:
Expand All @@ -374,18 +396,21 @@ def check_tiling(tiling):
# output space and then request the corresponding input tiles
# so we would need to apply the output scale and offset to the
# input shape to compute the tile size and halo here
axes = input_spec.axes
shape = input_spec.shape
if not isinstance(shape, list):
# NOTE this might result in very small tiles.
# it would be good to have some heuristic to determine a suitable tilesize
# from shape.min and shape.step
shape = shape.min
shape = _determine_shape(shape.min, shape.step, axes)
assert isinstance(shape, list)
assert len(shape) == len(axes)

halo = output_spec.halo
if halo is None:
raise ValueError("Model does not provide a valid halo to use for tiling with default parameters")

tiling = {
"halo": {ax: ha for ax, ha in zip(axes, halo) if ax in "xyz"},
"tile": {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"},
}
check_tiling(tiling)
else:
tiling = None
else:
Expand Down
2 changes: 1 addition & 1 deletion dev/environment-torch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ dependencies:
- pytest
- python >=3.7
- xarray
- pytorch
- pytorch <1.10
- onnxruntime
35 changes: 18 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
torch_models_pre_3_10 = ["unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"]
tensorflow1_models = ["FruNet_model", "stardist"]
tensorflow1_models = ["stardist"]
tensorflow2_models = []
keras_models = ["FruNet_model"]
tensorflow_js_models = ["FruNet_model"]
keras_models = []
tensorflow_js_models = []

model_sources = {
"FruNet_model": "https://sandbox.zenodo.org/record/894498/files/rdf.yaml",
# "FruNet_model": "https://raw.githubusercontent.com/deepimagej/models/master/fru-net_sev_segmentation/model.yaml",
# TODO add unet2d_keras_tf from https://github.com/bioimage-io/spec-bioimage-io/pull/267
# "unet2d_keras_tf": (""),
"unet2d_nuclei_broad_model": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"unet2d_nuclei_broad/rdf.yaml"
Expand All @@ -35,10 +35,12 @@
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml"
),
"stardist": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf.yaml"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models"
"/stardist_example_model/rdf.yaml"
),
"stardist_wrong_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf_wrong_shape.yaml"
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"stardist_example_model/rdf_wrong_shape.yaml"
),
}

Expand All @@ -50,7 +52,6 @@
except ImportError:
torch = None
torch_version = None

skip_torch = torch is None

try:
Expand All @@ -66,17 +67,15 @@
except ImportError:
tensorflow = None
tf_major_version = None

skip_tensorflow = tensorflow is None
skip_tensorflow = True # todo: update FruNet and remove this
skip_tensorflow_js = True # todo: update FruNet and figure out how to test tensorflow_js weights in python
skip_tensorflow_js = True # TODO: add a tensorflow_js example model

try:
import keras
except ImportError:
keras = None
skip_keras = keras is None
skip_keras = True # FruNet requires update
skip_keras = True # TODO add unet2d_keras_tf to have a model for keras tests

# load all model packages we need for testing
load_model_packages = set()
Expand Down Expand Up @@ -120,14 +119,14 @@ def unet2d_nuclei_broad_model(request):


# written as model group to automatically skip on missing tensorflow 1
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["FruNet_model"])
def FruNet_model(request):
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
def stardist_wrong_shape(request):
return pytest.model_packages[request.param]


# written as model group to automatically skip on missing tensorflow 1
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
def stardist_wrong_shape(request):
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"])
def stardist(request):
return pytest.model_packages[request.param]


Expand Down Expand Up @@ -164,7 +163,9 @@ def any_tensorflow_js_model(request):


# fixture to test with all models that should run in the current environment
@pytest.fixture(params=load_model_packages)
# we exclude stardist_wrong_shape here because it is not a valid model
# and included only to test that validation for this model fails
@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape"})
def any_model(request):
return pytest.model_packages[request.param]

Expand Down
84 changes: 64 additions & 20 deletions tests/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,19 @@
from bioimageio.core.resource_io.nodes import Model


def test_test_model(unet2d_nuclei_broad_model):
def test_test_model(any_model):
from bioimageio.core.resource_tests import test_model

assert test_model(unet2d_nuclei_broad_model)
assert test_model(any_model)


def test_test_resource(unet2d_nuclei_broad_model):
def test_test_resource(any_model):
from bioimageio.core.resource_tests import test_resource

assert test_resource(unet2d_nuclei_broad_model)
assert test_resource(any_model)


def test_predict_image(unet2d_fixed_shape_or_not, tmpdir):
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
def test_predict_image(any_model, tmpdir):
from bioimageio.core.prediction import predict_image

spec = load_resource_description(any_model)
Expand Down Expand Up @@ -57,46 +56,81 @@ def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir):
assert_array_almost_equal(res, exp, decimal=4)


def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path):
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
def _test_predict_with_padding(model, tmp_path):
from bioimageio.core.prediction import predict_image

spec = load_resource_description(any_model)
spec = load_resource_description(model)
assert isinstance(spec, Model)
image = np.load(str(spec.test_inputs[0]))[0, 0]

input_spec, output_spec = spec.inputs[0], spec.outputs[0]
channel_axis = input_spec.axes.index("c")
channel_first = channel_axis == 1

image = np.load(str(spec.test_inputs[0]))
assert image.shape[channel_axis] == 1
if channel_first:
image = image[0, 0]
else:
image = image[0, ..., 0]
original_shape = image.shape
assert image.ndim == 2

if isinstance(output_spec.shape, list):
n_channels = output_spec.shape[channel_axis]
else:
scale = output_spec.shape.scale[channel_axis]
offset = output_spec.shape.offset[channel_axis]
in_channels = 1
n_channels = int(2 * offset + scale * in_channels)

# write the padded image
image = image[3:-2, 1:-12]
in_path = tmp_path / "in.tif"
out_path = tmp_path / "out.tif"
imageio.imwrite(in_path, image)

def check_result():
assert out_path.exists()
res = imageio.imread(out_path)
assert res.shape == image.shape
if n_channels == 1:
assert out_path.exists()
res = imageio.imread(out_path)
assert res.shape == image.shape
else:
path = str(out_path)
for c in range(n_channels):
channel_out_path = Path(path.replace(".tif", f"-c{c}.tif"))
assert channel_out_path.exists()
res = imageio.imread(channel_out_path)
assert res.shape == image.shape

# test with dynamic padding
predict_image(any_model, in_path, out_path, padding={"x": 8, "y": 8, "mode": "dynamic"})
predict_image(model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"})
check_result()

# test with fixed padding
predict_image(
any_model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}
model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}
)
constantinpape marked this conversation as resolved.
Show resolved Hide resolved
check_result()

# test with automated padding
predict_image(any_model, in_path, out_path, padding=True)
predict_image(model, in_path, out_path, padding=True)
check_result()


def test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path):
# prediction with padding with the parameters above may not be suited for any model
# so we only run it for the pytorch unet2d here
def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path):
_test_predict_with_padding(unet2d_fixed_shape_or_not, tmp_path)


def test_predict_image_with_padding_channel_last(stardist, tmp_path):
_test_predict_with_padding(stardist, tmp_path)


def _test_predict_image_with_tiling(model, tmp_path):
from bioimageio.core.prediction import predict_image

spec = load_resource_description(unet2d_nuclei_broad_model)
spec = load_resource_description(model)
assert isinstance(spec, Model)
inputs = spec.test_inputs
assert len(inputs) == 1
Expand All @@ -114,14 +148,24 @@ def check_result():

# with tiling config
tiling = {"halo": {"x": 32, "y": 32}, "tile": {"x": 256, "y": 256}}
predict_image(unet2d_nuclei_broad_model, inputs, [out_path], tiling=tiling)
predict_image(model, inputs, [out_path], tiling=tiling)
check_result()

# with tiling determined from spec
predict_image(unet2d_nuclei_broad_model, inputs, [out_path], tiling=True)
predict_image(model, inputs, [out_path], tiling=True)
check_result()


# prediction with tiling with the parameters above may not be suited for any model
# so we only run it for the pytorch unet2d here
def test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path):
_test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path)


def test_predict_image_with_tiling_channel_last(stardist, tmp_path):
_test_predict_image_with_tiling(stardist, tmp_path)


def test_predict_images(unet2d_nuclei_broad_model, tmp_path):
from bioimageio.core.prediction import predict_images

Expand Down