diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 2d51c8f4..7c171456 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -2,9 +2,9 @@ name: Test and Deploy bioimageio.core on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ "**" ] + branches: ['**'] defaults: run: @@ -58,81 +58,69 @@ jobs: matrix: include: - python-version: '3.9' - conda-env: dev - spec: conda numpy-version: 1 - python-version: '3.9' - conda-env: dev - spec: main numpy-version: 2 - python-version: '3.10' - conda-env: full run-expensive-tests: true report-coverage: true save-cache: true - spec: conda numpy-version: 1 - python-version: '3.11' - conda-env: dev - spec: main numpy-version: 2 - python-version: '3.12' - conda-env: dev - spec: conda numpy-version: 1 # - python-version: '3.13' - # conda-env: '313' - # spec: main # numpy-version: 2 steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v6 - with: - python-version: ${{matrix.python-version}} - cache: 'pip' - - name: Install dependencies - run: | - pip install --upgrade pip - pip install -e .[dev] numpy==${{matrix.numpy-version}}.* - - name: Pyright - if: matrix.run-expensive-tests - run: | - pyright --version - pyright -p pyproject.toml --pythonversion ${{ matrix.python-version }} - - name: Restore bioimageio cache ${{needs.populate-cache.outputs.cache-key}} - uses: actions/cache/restore@v4 - with: - path: bioimageio_cache - key: ${{needs.populate-cache.outputs.cache-key}} - - name: pytest - run: pytest --cov bioimageio --cov-report xml --cov-append --capture no --disable-pytest-warnings - env: - BIOIMAGEIO_CACHE_PATH: bioimageio_cache - RUN_EXPENSIVE_TESTS: ${{ matrix.run-expensive-tests && 'true' || 'false' }} - - name: Save bioimageio cache ${{needs.populate-cache.outputs.cache-key}} - if: matrix.save-cache - uses: actions/cache/save@v4 - with: - path: bioimageio_cache - key: ${{needs.populate-cache.outputs.cache-key}} + - uses: actions/checkout@v4 + - uses: actions/setup-python@v6 + with: + python-version: ${{matrix.python-version}} + cache: 'pip' + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e .[dev] numpy==${{matrix.numpy-version}}.* + - name: Pyright + if: matrix.run-expensive-tests # pyright is not expensive, but we only want to run it once due to otherwise inconsistent typing + run: | + pyright --version + pyright -p pyproject.toml --pythonversion ${{ matrix.python-version }} + - name: Restore bioimageio cache ${{needs.populate-cache.outputs.cache-key}} + uses: actions/cache/restore@v4 + with: + path: bioimageio_cache + key: ${{needs.populate-cache.outputs.cache-key}} + - name: pytest + run: pytest --cov bioimageio --cov-report xml --cov-append --capture no --disable-pytest-warnings + env: + BIOIMAGEIO_CACHE_PATH: bioimageio_cache + RUN_EXPENSIVE_TESTS: ${{ matrix.run-expensive-tests && 'true' || 'false' }} + - name: Save bioimageio cache ${{needs.populate-cache.outputs.cache-key}} + if: matrix.save-cache + uses: actions/cache/save@v4 + with: + path: bioimageio_cache + key: ${{needs.populate-cache.outputs.cache-key}} - - if: matrix.report-coverage && github.event_name == 'pull_request' - uses: orgoro/coverage@v3.2 - with: - coverageFile: coverage.xml - token: ${{secrets.GITHUB_TOKEN}} - - if: matrix.report-coverage && github.ref == 'refs/heads/main' - run: | - pip install genbadge[coverage] - genbadge coverage --input-file coverage.xml --output-file ./dist/coverage/coverage-badge.svg - coverage html -d dist/coverage - - if: matrix.report-coverage && github.ref == 'refs/heads/main' - uses: actions/upload-artifact@v4 - with: - name: coverage - retention-days: 1 - path: dist + - if: matrix.report-coverage && github.event_name == 'pull_request' + uses: orgoro/coverage@v3.2 + with: + coverageFile: coverage.xml + token: ${{secrets.GITHUB_TOKEN}} + - if: matrix.report-coverage && github.ref == 'refs/heads/main' + run: | + pip install genbadge[coverage] + genbadge coverage --input-file coverage.xml --output-file ./dist/coverage/coverage-badge.svg + coverage html -d dist/coverage + - if: matrix.report-coverage && github.ref == 'refs/heads/main' + uses: actions/upload-artifact@v4 + with: + name: coverage + retention-days: 1 + path: dist conda-build: needs: test @@ -145,7 +133,7 @@ jobs: with: auto-update-conda: true auto-activate-base: true - activate-environment: "" + activate-environment: '' channel-priority: strict miniforge-version: latest conda-solver: libmamba @@ -242,7 +230,7 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1.12 with: user: __token__ - password: "${{ secrets.PYPI_TOKEN }}" + password: '${{ secrets.PYPI_TOKEN }}' packages-dir: dist/ verbose: true - name: Publish the release notes @@ -250,6 +238,6 @@ jobs: uses: release-drafter/release-drafter@v6.0.0 with: publish: "${{ steps.tag-version.outputs.new_tag != '' }}" - tag: "${{ steps.tag-version.outputs.new_tag }}" + tag: '${{ steps.tag-version.outputs.new_tag }}' env: - GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" + GITHUB_TOKEN: '${{ secrets.GITHUB_TOKEN }}' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9544bc10..196c94bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,4 @@ repos: - - repo: https://github.com/ambv/black - rev: 25.1.0 - hooks: - - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.12.8 hooks: diff --git a/README.md b/README.md index 7adb7800..c9a929fc 100644 --- a/README.md +++ b/README.md @@ -364,9 +364,16 @@ may be controlled with the `LOGURU_LEVEL` environment variable. ## Changelog +### 0.9.3 + +- bump bioimageio.spec library version to 0.5.5.5 +- more robust test model reporting +- improved user input axis intepretation +- fixed conda subprocess calls + ### 0.9.2 -fix model inference tolerance reporting +- fix model inference tolerance reporting ### 0.9.1 diff --git a/example/dataset_statistics_demo.ipynb b/example/dataset_statistics_demo.ipynb index 25e0d676..705fe295 100644 --- a/example/dataset_statistics_demo.ipynb +++ b/example/dataset_statistics_demo.ipynb @@ -329,12 +329,14 @@ "source": [ "# compute dataset statistics on all samples\n", "# (in this case we should really use the non-overlapping tiles as samples in dataset_for_initial_statistics)\n", - "with create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " dataset_for_initial_statistics=dataset,\n", - " update_dataset_stats_for_n_samples=0, # if you call the prediciton pipeline more then len(dataset)\n", - " # times you might want to set this to zero to avoid further updates to the dataset statistics\n", - ") as pp:\n", + "with (\n", + " create_prediction_pipeline(\n", + " bioimageio_model=model_resource,\n", + " dataset_for_initial_statistics=dataset,\n", + " update_dataset_stats_for_n_samples=0, # if you call the prediciton pipeline more then len(dataset)\n", + " # times you might want to set this to zero to avoid further updates to the dataset statistics\n", + " ) as pp\n", + "):\n", " only_init_dataset_stats = process_dataset(pp, dataset)" ] }, diff --git a/example/model_usage.ipynb b/example/model_usage.ipynb index a253cc60..74fbe830 100644 --- a/example/model_usage.ipynb +++ b/example/model_usage.ipynb @@ -111,7 +111,6 @@ " Mapping[str, NDArray[Any]], Mapping[TensorId, Union[Tensor, NDArray[Any]]]\n", " ],\n", ") -> None:\n", - "\n", " for title, image in images.items():\n", " if isinstance(image, Tensor):\n", " input_array = image.data.data\n", diff --git a/pyproject.toml b/pyproject.toml index 059d8e9b..c2393c11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires-python = ">=3.9" readme = "README.md" dynamic = ["version"] dependencies = [ - "bioimageio.spec ==0.5.5.4", + "bioimageio.spec ==0.5.5.5", "h5py", "imagecodecs", "imageio>=2.10", @@ -44,15 +44,14 @@ onnx = ["onnxruntime"] pytorch = ["torch>=1.6,<3", "torchvision>=0.21", "keras>=3.0,<4"] tensorflow = ["tensorflow", "keras>=2.15,<4"] dev = [ - "black", "cellpose", # for model testing "crick", "httpx", - "jupyter-black", "jupyter", "keras>=3.0,<4", "matplotlib", "monai", # for model testing + "numpy", "onnx", "onnxruntime", "packaging>=17.0", @@ -79,12 +78,6 @@ where = ["src/"] [tool.setuptools.dynamic] version = { attr = "bioimageio.core.__version__" } -[tool.black] -line-length = 88 -extend-exclude = "/presentations/" -target-version = ["py39", "py310", "py311", "py312"] -preview = true - [tool.pyright] exclude = [ "**/__pycache__", diff --git a/src/bioimageio/core/__init__.py b/src/bioimageio/core/__init__.py index e836908f..f35cb69a 100644 --- a/src/bioimageio/core/__init__.py +++ b/src/bioimageio/core/__init__.py @@ -3,7 +3,7 @@ """ # ruff: noqa: E402 -__version__ = "0.9.2" +__version__ = "0.9.3" from loguru import logger logger.disable("bioimageio.core") diff --git a/src/bioimageio/core/_magic_tensor_ops.py b/src/bioimageio/core/_magic_tensor_ops.py index c1526fef..9b73efa6 100644 --- a/src/bioimageio/core/_magic_tensor_ops.py +++ b/src/bioimageio/core/_magic_tensor_ops.py @@ -71,12 +71,14 @@ def __ge__(self, other: _Compatible) -> Self: def __eq__(self, other: _Compatible) -> Self: # type: ignore[override] return self._binary_op( - other, nputils.array_eq # pyright: ignore[reportUnknownArgumentType] + other, + nputils.array_eq, # pyright: ignore[reportUnknownArgumentType] ) def __ne__(self, other: _Compatible) -> Self: # type: ignore[override] return self._binary_op( - other, nputils.array_ne # pyright: ignore[reportUnknownArgumentType] + other, + nputils.array_ne, # pyright: ignore[reportUnknownArgumentType] ) # When __eq__ is defined but __hash__ is not, then an object is unhashable, @@ -171,22 +173,30 @@ def __invert__(self) -> Self: def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op( - ops.round_, *args, **kwargs # pyright: ignore[reportUnknownArgumentType] + ops.round_, # pyright: ignore[reportUnknownArgumentType] + *args, + **kwargs, ) def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op( - ops.argsort, *args, **kwargs # pyright: ignore[reportUnknownArgumentType] + ops.argsort, # pyright: ignore[reportUnknownArgumentType] + *args, + **kwargs, ) def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op( - ops.conj, *args, **kwargs # pyright: ignore[reportUnknownArgumentType] + ops.conj, # pyright: ignore[reportUnknownArgumentType] + *args, + **kwargs, ) def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op( - ops.conjugate, *args, **kwargs # pyright: ignore[reportUnknownArgumentType] + ops.conjugate, # pyright: ignore[reportUnknownArgumentType] + *args, + **kwargs, ) __add__.__doc__ = operator.add.__doc__ diff --git a/src/bioimageio/core/_resource_tests.py b/src/bioimageio/core/_resource_tests.py index 024ec16e..6b5764e9 100644 --- a/src/bioimageio/core/_resource_tests.py +++ b/src/bioimageio/core/_resource_tests.py @@ -66,6 +66,7 @@ from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args from bioimageio.core import __version__ +from bioimageio.core.io import save_tensor from ._prediction_pipeline import create_prediction_pipeline from .axis import AxisId, BatchSize @@ -73,6 +74,8 @@ from .digest_spec import get_test_input_sample, get_test_output_sample from .sample import Sample +CONDA_CMD = "conda.bat" if platform.system() == "Windows" else "conda" + class DeprecatedKwargs(TypedDict): absolute_tolerance: NotRequired[AbsoluteTolerance] @@ -190,7 +193,7 @@ def test_model( def default_run_command(args: Sequence[str]): logger.info("running '{}'...", " ".join(args)) - _ = subprocess.run(args, shell=True, text=True, check=True) + _ = subprocess.check_call(args) def test_description( @@ -378,21 +381,22 @@ def _test_in_env( env_name = hashlib.sha256(encoded_env).hexdigest() try: - run_command(["where" if platform.system() == "Windows" else "which", "conda"]) + run_command(["where" if platform.system() == "Windows" else "which", CONDA_CMD]) except Exception as e: raise RuntimeError("Conda not available") from e try: - run_command(["conda", "activate", env_name]) + run_command([CONDA_CMD, "activate", env_name]) except Exception: + working_dir.mkdir(parents=True, exist_ok=True) path = working_dir / "env.yaml" try: _ = path.write_bytes(encoded_env) logger.debug("written conda env to {}", path) run_command( - ["conda", "env", "create", f"--file={path}", f"--name={env_name}"] + [CONDA_CMD, "env", "create", f"--file={path}", f"--name={env_name}"] ) - run_command(["conda", "activate", env_name]) + run_command([CONDA_CMD, "activate", env_name]) except Exception as e: summary = descr.validation_summary summary.add_detail( @@ -423,7 +427,7 @@ def _test_in_env( run_command( cmd := ( [ - "conda", + CONDA_CMD, "run", "-n", env_name, @@ -789,58 +793,74 @@ def add_warning_entry(msg: str): else: continue - expected_np = expected.data.to_numpy().astype(np.float32) - del expected - actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) - del actual + try: + expected_np = expected.data.to_numpy().astype(np.float32) + del expected + actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) - rtol, atol, mismatched_tol = _get_tolerance( - model, wf=weight_format, m=m, **deprecated - ) - rtol_value = rtol * abs(expected_np) - abs_diff = abs(actual_np - expected_np) - mismatched = abs_diff > atol + rtol_value - mismatched_elements = mismatched.sum().item() - if not mismatched_elements: - continue - - mismatched_ppm = mismatched_elements / expected_np.size * 1e6 - abs_diff[~mismatched] = 0 # ignore non-mismatched elements - - r_max_idx_flat = ( - r_diff := (abs_diff / (abs(expected_np) + 1e-6)) - ).argmax() - r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape) - r_max = r_diff[r_max_idx].item() - r_actual = actual_np[r_max_idx].item() - r_expected = expected_np[r_max_idx].item() - - # Calculate the max absolute difference with the relative tolerance subtracted - abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value - a_max_idx = np.unravel_index( - abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape - ) + rtol, atol, mismatched_tol = _get_tolerance( + model, wf=weight_format, m=m, **deprecated + ) + rtol_value = rtol * abs(expected_np) + abs_diff = abs(actual_np - expected_np) + mismatched = abs_diff > atol + rtol_value + mismatched_elements = mismatched.sum().item() + if not mismatched_elements: + continue - a_max = abs_diff[a_max_idx].item() - a_actual = actual_np[a_max_idx].item() - a_expected = expected_np[a_max_idx].item() - - msg = ( - f"Output '{m}' disagrees with {mismatched_elements} of" - + f" {expected_np.size} expected values" - + f" ({mismatched_ppm:.1f} ppm)." - + f"\n Max relative difference: {r_max:.2e}" - + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" - + f" at {dict(zip(dims, r_max_idx))}" - + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}" - + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" - ) - if mismatched_ppm > mismatched_tol: + actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy") + try: + save_tensor(actual_output_path, actual) + except Exception as e: + logger.error( + "Failed to save actual output tensor to {}: {}", + actual_output_path, + e, + ) + + mismatched_ppm = mismatched_elements / expected_np.size * 1e6 + abs_diff[~mismatched] = 0 # ignore non-mismatched elements + + r_max_idx_flat = ( + r_diff := (abs_diff / (abs(expected_np) + 1e-6)) + ).argmax() + r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape) + r_max = r_diff[r_max_idx].item() + r_actual = actual_np[r_max_idx].item() + r_expected = expected_np[r_max_idx].item() + + # Calculate the max absolute difference with the relative tolerance subtracted + abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value + a_max_idx = np.unravel_index( + abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape + ) + + a_max = abs_diff[a_max_idx].item() + a_actual = actual_np[a_max_idx].item() + a_expected = expected_np[a_max_idx].item() + except Exception as e: + msg = f"Output '{m}' disagrees with expected values." add_error_entry(msg) if stop_early: break else: - add_warning_entry(msg) + msg = ( + f"Output '{m}' disagrees with {mismatched_elements} of" + + f" {expected_np.size} expected values" + + f" ({mismatched_ppm:.1f} ppm)." + + f"\n Max relative difference: {r_max:.2e}" + + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" + + f" at {dict(zip(dims, r_max_idx))}" + + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}" + + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" + + f"\n Saved actual output to {actual_output_path}." + ) + if mismatched_ppm > mismatched_tol: + add_error_entry(msg) + if stop_early: + break + else: + add_warning_entry(msg) except Exception as e: if get_validation_context().raise_errors: diff --git a/src/bioimageio/core/axis.py b/src/bioimageio/core/axis.py index 0b39045e..b57be688 100644 --- a/src/bioimageio/core/axis.py +++ b/src/bioimageio/core/axis.py @@ -3,9 +3,8 @@ from dataclasses import dataclass from typing import Literal, Mapping, Optional, TypeVar, Union -from typing_extensions import assert_never - from bioimageio.spec.model import v0_5 +from typing_extensions import Protocol, assert_never, runtime_checkable def _guess_axis_type(a: str): @@ -42,7 +41,16 @@ def _guess_axis_type(a: str): BatchSize = int AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"] -AxisLike = Union[AxisId, AxisLetter, v0_5.AnyAxis, "Axis"] +_AxisLikePlain = Union[str, AxisId, AxisLetter] + + +@runtime_checkable +class AxisDescrLike(Protocol): + id: _AxisLikePlain + type: Literal["batch", "channel", "index", "space", "time"] + + +AxisLike = Union[_AxisLikePlain, AxisDescrLike, v0_5.AnyAxis, "Axis"] @dataclass @@ -60,14 +68,22 @@ def __post_init__(self): def create(cls, axis: AxisLike) -> Axis: if isinstance(axis, cls): return axis - elif isinstance(axis, Axis): - return Axis(id=axis.id, type=axis.type) - elif isinstance(axis, v0_5.AxisBase): - return Axis(id=AxisId(axis.id), type=axis.type) - elif isinstance(axis, str): - return Axis(id=AxisId(axis), type=_guess_axis_type(axis)) + + if isinstance(axis, (AxisId, str)): + axis_id = axis + axis_type = _guess_axis_type(str(axis)) else: - assert_never(axis) + if hasattr(axis, "type"): + axis_type = axis.type + else: + axis_type = _guess_axis_type(str(axis)) + + if hasattr(axis, "id"): + axis_id = axis.id + else: + axis_id = axis + + return Axis(id=AxisId(axis_id), type=axis_type) @dataclass @@ -81,7 +97,7 @@ def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisI axis_base = super().create(axis) if maybe_singleton is None: - if isinstance(axis, (Axis, str)): + if not isinstance(axis, v0_5.AxisBase): maybe_singleton = True else: if axis.size is None: diff --git a/src/bioimageio/core/backends/torchscript_backend.py b/src/bioimageio/core/backends/torchscript_backend.py index 8c2de21b..fa86e964 100644 --- a/src/bioimageio/core/backends/torchscript_backend.py +++ b/src/bioimageio/core/backends/torchscript_backend.py @@ -44,7 +44,6 @@ def __init__( def _forward_impl( self, input_arrays: Sequence[Optional[NDArray[Any]]] ) -> List[Optional[NDArray[Any]]]: - with torch.no_grad(): torch_tensor = [ None if a is None else torch.from_numpy(a).to(self.devices[0]) @@ -60,7 +59,9 @@ def _forward_impl( ( None if r is None - else r.cpu().numpy() if isinstance(r, torch.Tensor) else r + else r.cpu().numpy() + if isinstance(r, torch.Tensor) + else r ) for r in output_seq ] diff --git a/src/bioimageio/core/block_meta.py b/src/bioimageio/core/block_meta.py index 4e40c1cf..b8246953 100644 --- a/src/bioimageio/core/block_meta.py +++ b/src/bioimageio/core/block_meta.py @@ -198,13 +198,13 @@ def __post_init__(self): if not isinstance(self.halo, Frozen): object.__setattr__(self, "halo", Frozen(self.halo)) - assert all( - a in self.sample_shape for a in self.inner_slice - ), "block has axes not present in sample" + assert all(a in self.sample_shape for a in self.inner_slice), ( + "block has axes not present in sample" + ) - assert all( - a in self.inner_slice for a in self.halo - ), "halo has axes not present in block" + assert all(a in self.inner_slice for a in self.halo), ( + "halo has axes not present in block" + ) if any(s > self.sample_shape[a] for a, s in self.shape.items()): logger.warning( @@ -343,9 +343,9 @@ def split_multiple_shapes_into_blocks( if strides is None: strides = {} - assert not ( - unknown_block := [t for t in strides if t not in block_shapes] - ), f"`stride` specified for tensors without block shape: {unknown_block}" + assert not (unknown_block := [t for t in strides if t not in block_shapes]), ( + f"`stride` specified for tensors without block shape: {unknown_block}" + ) blocks: Dict[MemberId, Iterable[BlockMeta]] = {} n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {} diff --git a/src/bioimageio/core/proc_ops.py b/src/bioimageio/core/proc_ops.py index 95f7466a..65c4975a 100644 --- a/src/bioimageio/core/proc_ops.py +++ b/src/bioimageio/core/proc_ops.py @@ -237,9 +237,9 @@ class Clip(_SimpleOperator): def __post_init__(self): assert self.min is not None or self.max is not None, "missing min or max value" - assert ( - self.min is None or self.max is None or self.min < self.max - ), f"expected min < max, but {self.min} !< {self.max}" + assert self.min is None or self.max is None or self.min < self.max, ( + f"expected min < max, but {self.min} !< {self.max}" + ) def _apply(self, x: Tensor, stat: Stat) -> Tensor: return x.clip(self.min, self.max) @@ -321,9 +321,9 @@ def from_proc_descr( gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis) offset = xr.DataArray(np.atleast_1d(kwargs.offset), dims=axis) else: - assert ( - isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1 - ), kwargs.gain + assert isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1, ( + kwargs.gain + ) gain = ( kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0] ) diff --git a/src/bioimageio/core/sample.py b/src/bioimageio/core/sample.py index 0d4c3724..00308d7f 100644 --- a/src/bioimageio/core/sample.py +++ b/src/bioimageio/core/sample.py @@ -76,12 +76,12 @@ def split_into_blocks( pad_mode: PadMode, broadcast: bool = False, ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]: - assert not ( - missing := [m for m in block_shapes if m not in self.members] - ), f"`block_shapes` specified for unknown members: {missing}" - assert not ( - missing := [m for m in halo if m not in block_shapes] - ), f"`halo` specified for members without `block_shape`: {missing}" + assert not (missing := [m for m in block_shapes if m not in self.members]), ( + f"`block_shapes` specified for unknown members: {missing}" + ) + assert not (missing := [m for m in halo if m not in block_shapes]), ( + f"`halo` specified for members without `block_shape`: {missing}" + ) n_blocks, blocks = split_multiple_shapes_into_blocks( shapes=self.shape, diff --git a/src/bioimageio/core/tensor.py b/src/bioimageio/core/tensor.py index 408865de..17358b00 100644 --- a/src/bioimageio/core/tensor.py +++ b/src/bioimageio/core/tensor.py @@ -19,14 +19,13 @@ import numpy as np import xarray as xr +from bioimageio.spec.model import v0_5 from loguru import logger from numpy.typing import DTypeLike, NDArray from typing_extensions import Self, assert_never -from bioimageio.spec.model import v0_5 - from ._magic_tensor_ops import MagicTensorOpsMixin -from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis +from .axis import AxisId, AxisInfo, AxisLike, PerAxis from .common import ( CropWhere, DTypeStr, @@ -187,10 +186,12 @@ def from_numpy( if dims is None: return cls._interprete_array_wo_known_axes(array) - elif isinstance(dims, (str, Axis, v0_5.AxisBase)): - dims = [dims] + elif isinstance(dims, collections.abc.Sequence): + dim_seq = list(dims) + else: + dim_seq = [dims] - axis_infos = [AxisInfo.create(a) for a in dims] + axis_infos = [AxisInfo.create(a) for a in dim_seq] original_shape = tuple(array.shape) successful_view = _get_array_view(array, axis_infos) diff --git a/tests/conftest.py b/tests/conftest.py index dd3bbec0..06c89641 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -129,7 +129,9 @@ KERAS_MODELS = ( [] if keras is None - else ["unet2d_keras"] if tf_major_version == 1 else ["unet2d_keras_tf2"] + else ["unet2d_keras"] + if tf_major_version == 1 + else ["unet2d_keras_tf2"] ) TENSORFLOW_JS_MODELS: List[str] = [] # TODO: add a tensorflow_js example model @@ -230,7 +232,9 @@ def convert_to_onnx(request: FixtureRequest): params=( [] if tf_major_version is None - else ["unet2d_keras"] if tf_major_version == 1 else ["unet2d_keras_tf2"] + else ["unet2d_keras"] + if tf_major_version == 1 + else ["unet2d_keras_tf2"] ) ) def unet2d_keras(request: FixtureRequest): diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py index 615d97ed..394f7e0e 100644 --- a/tests/test_prediction_pipeline.py +++ b/tests/test_prediction_pipeline.py @@ -18,9 +18,9 @@ def _test_prediction_pipeline( ) bio_model = load_description(model_package) - assert isinstance( - bio_model, (ModelDescr, ModelDescr04) - ), bio_model.validation_summary.format() + assert isinstance(bio_model, (ModelDescr, ModelDescr04)), ( + bio_model.validation_summary.format() + ) pp = create_prediction_pipeline( bioimageio_model=bio_model, weight_format=weights_format ) diff --git a/tests/test_tensor.py b/tests/test_tensor.py index c57980bd..1d537ce6 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -23,7 +23,6 @@ ], ) def test_transpose_tensor_2d(axes: Sequence[str]): - tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None) transposed = tensor.transpose([AxisId(a) for a in axes]) assert transposed.ndim == len(axes)