From d35832418431b089ebd19d2146dc0937d5771e57 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 21:23:35 +0200 Subject: [PATCH 01/11] Refactor inference into runtime workflow modules --- docs/documentation.md | 3 + slide2vec/inference.py | 757 ++--------------------- slide2vec/runtime/__init__.py | 2 + slide2vec/runtime/batching.py | 441 +++++++++++++ slide2vec/runtime/hierarchical.py | 105 ++++ slide2vec/runtime/progress_bridge.py | 52 ++ slide2vec/runtime/serialization.py | 119 ++++ slide2vec/runtime/types.py | 37 ++ tasks/lessons.md | 2 + tests/test_architecture_runtime_split.py | 23 + 10 files changed, 831 insertions(+), 710 deletions(-) create mode 100644 slide2vec/runtime/__init__.py create mode 100644 slide2vec/runtime/batching.py create mode 100644 slide2vec/runtime/hierarchical.py create mode 100644 slide2vec/runtime/progress_bridge.py create mode 100644 slide2vec/runtime/serialization.py create mode 100644 slide2vec/runtime/types.py create mode 100644 tests/test_architecture_runtime_split.py diff --git a/docs/documentation.md b/docs/documentation.md index c4e1f62..7cab558 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -2,6 +2,9 @@ ## 2026-04-18 +- Split `slide2vec.inference` into workflow-scoped internal runtime helpers under `slide2vec.runtime` (`batching`, `hierarchical`, `progress_bridge`, `serialization`, `types`) while keeping `slide2vec.inference` as the stable orchestration entrypoint. +- Added architecture guardrail tests that keep workflow helpers bounded (soft target around 400 lines, enforced ceiling 500) and prevent `slide2vec/inference.py` from regressing toward the previous monolith size. + - Aligned slide2vec with hs2p 4.0.0's unified tiling/sampling contract by preserving the new `annotation` column in process lists and translating preview configs to hs2p's `save_mask_preview` / `save_tiling_preview` / `tissue_contour_color` fields. - Split the live tiling UI into a coordinates-extraction bar plus a separate preview-generation bar, and moved the final tiling summary into a dedicated `tiling.summary` event so it prints once at the very end. diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 2d4e52f..1b84491 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -9,7 +9,7 @@ import threading import time from contextlib import contextmanager, nullcontext -from dataclasses import dataclass, replace +from dataclasses import replace from pathlib import Path from types import SimpleNamespace from typing import Any, Callable, Sequence @@ -18,11 +18,43 @@ import pandas as pd import torch from hs2p import SlideSpec, FilterConfig, PreviewConfig, SegmentationConfig, TilingConfig, load_tiling_result, tile_slides -from hs2p import progress as hs2p_progress from hs2p.utils.stderr import run_with_filtered_stderr import numpy as np -from transformers.image_processing_utils import BaseImageProcessor +from slide2vec.runtime.batching import ( + autocast_dtype as _autocast_dtype, + build_batch_preprocessor as _build_batch_preprocessor, + build_batch_preprocessor_for_tile_images as _build_batch_preprocessor_for_tile_images, + build_batch_transform_spec as _build_batch_transform_spec, + center_crop_batch as _center_crop_batch, + embedding_dataloader_kwargs as _embedding_dataloader_kwargs, + interp_mode_to_str as _interp_mode_to_str, + iter_transform_steps as _iter_transform_steps, + normalize_hw as _normalize_hw, + prepare_batch_tensor as _prepare_batch_tensor, + resize_image_batch as _resize_image_batch, + resolve_device as _resolve_device, + run_forward_pass as _run_forward_pass, + should_suppress_cucim_dataloader_stderr as _should_suppress_cucim_dataloader_stderr, + uses_cuda_runtime as _uses_cuda_runtime, +) +from slide2vec.runtime.hierarchical import ( + build_hierarchical_index as _build_hierarchical_index, + is_hierarchical_preprocessing as _is_hierarchical_preprocessing, + num_embedding_items as _num_embedding_items, + num_tiles as _num_tiles, + resolve_hierarchical_geometry as _resolve_hierarchical_geometry, +) +from slide2vec.runtime.progress_bridge import ( + bridge_hs2p_progress_to_slide2vec as _bridge_hs2p_progress_to_slide2vec, +) +from slide2vec.runtime.serialization import ( + deserialize_execution, + deserialize_preprocessing, + serialize_execution as _serialize_execution_base, + serialize_model as _serialize_model, + serialize_preprocessing as _serialize_preprocessing, +) from slide2vec.api import ( EmbeddedPatient, EmbeddedSlide, @@ -52,18 +84,15 @@ from slide2vec.model_settings import canonicalize_model_name from slide2vec.runtime_types import LoadedModel from slide2vec.progress import ( - NullProgressReporter, - ProgressEvent as Slide2VecProgressEvent, emit_progress, emit_progress_event, - get_progress_reporter, read_progress_events, read_tiling_progress_snapshot, ) +from slide2vec.utils.coordinates import coordinate_arrays from slide2vec.utils.log_utils import suppress_c_stderr from slide2vec.data.dataset import BatchTileCollator, TileIndexDataset from slide2vec.data.tile_reader import OnTheFlyBatchTileCollator, OnTheFlyHierarchicalBatchCollator -from slide2vec.utils.coordinates import coordinate_arrays from slide2vec.utils.tiling_io import ( load_embedding_process_df, load_patient_id_mapping, @@ -75,173 +104,18 @@ from slide2vec.utils.utils import cpu_worker_limit, slurm_cpu_limit -@dataclass(frozen=True, kw_only=True) -class BatchTransformSpec: - resize_size: tuple[int, int] | None - center_crop_size: tuple[int, int] | None - mean: tuple[float, ...] | None - std: tuple[float, ...] | None - resize_interpolation: str = "bilinear" - - -_BRIDGED_HS2P_PROGRESS_KINDS = { - "backend.selected", - "tissue.started", - "tissue.progress", - "tissue.finished", - "tiling.progress", - "tiling.finished", - "preview.started", - "preview.progress", - "preview.finished", -} - - -class _Hs2pProgressBridge: - def __init__(self, downstream) -> None: - self._downstream = downstream - - def emit(self, event) -> None: - if event.kind not in _BRIDGED_HS2P_PROGRESS_KINDS: - return - self._downstream.emit( - Slide2VecProgressEvent(kind=event.kind, payload=dict(event.payload)) - ) - - def close(self) -> None: - return None - - def write_log(self, message: str, *, stream=None) -> None: - if hasattr(self._downstream, "write_log"): - self._downstream.write_log(message, stream=stream) - - -@contextmanager -def _bridge_hs2p_progress_to_slide2vec(): - downstream = get_progress_reporter() - if isinstance(downstream, NullProgressReporter): - yield - return - bridge = _Hs2pProgressBridge(downstream) - with hs2p_progress.activate_progress_reporter(bridge): - yield - - -@dataclass(kw_only=True) -class PreparedBatch: - indices: Any - image: Any - loader_wait_ms: float - preprocess_ms: float - ready_wait_ms: float = 0.0 - worker_batch_ms: float = 0.0 - reader_open_ms: float = 0.0 - reader_read_ms: float = 0.0 - - -@dataclass(frozen=True, kw_only=True) -class HierarchicalIndex: - flat_index: np.ndarray - region_index: np.ndarray - subtile_index_within_region: np.ndarray - subtile_x: np.ndarray - subtile_y: np.ndarray - num_regions: int - tiles_per_region: int - - -def _is_hierarchical_preprocessing(preprocessing: PreprocessingConfig | None) -> bool: - if preprocessing is None: - return False - return preprocessing.region_tile_multiple is not None or preprocessing.requested_region_size_px is not None - - -def _resolve_hierarchical_geometry(preprocessing: PreprocessingConfig, tiling_result) -> dict[str, int]: - if preprocessing.region_tile_multiple is None: - raise ValueError("Hierarchical preprocessing requires region_tile_multiple") - if preprocessing.requested_region_size_px is None: - raise ValueError("Hierarchical preprocessing requires requested_region_size_px") - requested_tile_size_px = int(preprocessing.requested_tile_size_px) - requested_region_size_px = int(preprocessing.requested_region_size_px) - requested_spacing_um = float(preprocessing.requested_spacing_um) - multiple = int(preprocessing.region_tile_multiple) - if requested_region_size_px % multiple != 0: - raise ValueError("requested_region_size_px must be divisible by region_tile_multiple") - read_spacing_um = float(getattr(tiling_result, "read_spacing_um")) - base_spacing_um = float(getattr(tiling_result, "base_spacing_um")) - if abs(read_spacing_um - requested_spacing_um) / requested_spacing_um <= float(preprocessing.tolerance): - read_tile_size_px = requested_tile_size_px - else: - read_tile_size_px = int( - round(requested_tile_size_px * requested_spacing_um / read_spacing_um) - ) - read_region_size_px = read_tile_size_px * multiple - # Use the actual read geometry that produced the tile crop. When the - # resolved spacing is considered equivalent to the requested spacing, - # this keeps the level-0 footprint aligned with the real crop size. - tile_size_lv0 = int(round(read_tile_size_px * read_spacing_um / base_spacing_um)) - return { - "region_tile_multiple": multiple, - "tiles_per_region": multiple * multiple, - "requested_tile_size_px": requested_tile_size_px, - "read_tile_size_px": read_tile_size_px, - "requested_region_size_px": requested_region_size_px, - "read_region_size_px": read_region_size_px, - "tile_size_lv0": tile_size_lv0, - } - - -def _build_hierarchical_index( - tiling_result, +def _serialize_execution( + execution: ExecutionOptions, *, - region_tile_multiple: int, - tile_size_lv0: int | None = None, -) -> HierarchicalIndex: - x_values, y_values = coordinate_arrays(tiling_result) - num_regions = int(len(x_values)) - multiple = int(region_tile_multiple) - if multiple < 2: - raise ValueError("region_tile_multiple must be at least 2") - subtile_size_lv0 = ( - int(tile_size_lv0) - if tile_size_lv0 is not None - else int(getattr(tiling_result, "tile_size_lv0")) // multiple - ) - tiles_per_region = multiple * multiple - if num_regions == 0: - empty = np.empty(0, dtype=np.int64) - return HierarchicalIndex( - flat_index=empty, - region_index=np.empty(0, dtype=np.int32), - subtile_index_within_region=np.empty(0, dtype=np.int32), - subtile_x=empty, - subtile_y=empty, - num_regions=0, - tiles_per_region=tiles_per_region, - ) - rows, cols = np.divmod(np.arange(tiles_per_region, dtype=np.int32), multiple) - offsets_x = cols.astype(np.int64) * subtile_size_lv0 - offsets_y = rows.astype(np.int64) * subtile_size_lv0 - region_x = np.asarray(x_values, dtype=np.int64)[:, np.newaxis] - region_y = np.asarray(y_values, dtype=np.int64)[:, np.newaxis] - subtile_x = (region_x + offsets_x[np.newaxis, :]).reshape(-1) - subtile_y = (region_y + offsets_y[np.newaxis, :]).reshape(-1) - return HierarchicalIndex( - flat_index=np.arange(num_regions * tiles_per_region, dtype=np.int64), - region_index=np.repeat(np.arange(num_regions, dtype=np.int32), tiles_per_region), - subtile_index_within_region=np.tile(np.arange(tiles_per_region, dtype=np.int32), num_regions), - subtile_x=subtile_x, - subtile_y=subtile_y, - num_regions=num_regions, - tiles_per_region=tiles_per_region, - ) - - -def _num_embedding_items(tiling_result, preprocessing: PreprocessingConfig | None) -> int: - if not _is_hierarchical_preprocessing(preprocessing): - return _num_tiles(tiling_result) - geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) - return _num_tiles(tiling_result) * int(geometry["tiles_per_region"]) + preprocessing: PreprocessingConfig | None = None, +) -> dict[str, Any]: + effective_num_workers = None + if preprocessing is not None and preprocessing.on_the_fly and preprocessing.read_tiles_from is None: + effective_num_workers, _ = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) + return _serialize_execution_base( + execution, + effective_num_workers=effective_num_workers, + ) @@ -2110,413 +1984,6 @@ def _write_hierarchical_embedding_artifact( -def _embedding_dataloader_kwargs(loaded: LoadedModel, execution: ExecutionOptions) -> dict[str, Any]: - resolved_num_workers = execution.resolved_num_workers() - kwargs: dict[str, Any] = { - "num_workers": resolved_num_workers, - "pin_memory": _uses_cuda_runtime(loaded.device), - } - if resolved_num_workers > 0: - kwargs["persistent_workers"] = bool(execution.persistent_workers) - kwargs["prefetch_factor"] = int(execution.prefetch_factor) - return kwargs - - - -def _build_batch_preprocessor( - loaded: LoadedModel, - tiling_result, -): - return _build_batch_preprocessor_for_tile_images( - loaded, - requested_tile_size_px=int(getattr(tiling_result, "requested_tile_size_px")), - ) - - -def _build_batch_preprocessor_for_tile_images( - loaded: LoadedModel, - *, - requested_tile_size_px: int, -): - spec = _build_batch_transform_spec(loaded.transforms) - if spec is None: - logging.getLogger(__name__).warning( - "Batched preprocessing is disabled for %s because the transform stack is not supported; " - "falling back to per-item preprocessing", - loaded.name, - ) - return None - - def preprocess(batch): - image = _prepare_batch_tensor(batch) - if spec.resize_size is None: - image = _resize_image_batch( - image, - (int(requested_tile_size_px), int(requested_tile_size_px)), - ) - image = _apply_batch_transform_spec(image, spec) - if image.device != loaded.device: - image = image.to(loaded.device, non_blocking=str(loaded.device).startswith("cuda")) - return image.contiguous() - - return preprocess - - -def _build_batch_transform_spec(transforms) -> BatchTransformSpec | None: - if isinstance(transforms, BaseImageProcessor): - crop_size = transforms.crop_size if hasattr(transforms, "crop_size") else None - size = transforms.size if hasattr(transforms, "size") else None - resize_size = _normalize_hw(crop_size or size) - if resize_size is None: - return None - mean = transforms.image_mean if hasattr(transforms, "image_mean") else None - std = transforms.image_std if hasattr(transforms, "image_std") else None - return BatchTransformSpec( - resize_size=resize_size, - center_crop_size=None, - mean=tuple(float(value) for value in mean) if mean is not None else None, - std=tuple(float(value) for value in std) if std is not None else None, - ) - - transform_steps = _iter_transform_steps(transforms) - if transform_steps is None: - return None - - resize_size = None - resize_interpolation = "bilinear" - center_crop_size = None - mean = None - std = None - supported_step_names = { - "Resize", - "CenterCrop", - "Normalize", - "ToTensor", - "MaybeToTensor", - "ToImage", - "ConvertImageDtype", - } - for step in transform_steps: - step_name = type(step).__name__ - if step_name not in supported_step_names: - return None - if step_name == "Resize": - resize_size = _normalize_hw(step.size if hasattr(step, "size") else None) - resize_interpolation = _interp_mode_to_str(step.interpolation if hasattr(step, "interpolation") else None) - elif step_name == "CenterCrop": - center_crop_size = _normalize_hw(step.size if hasattr(step, "size") else None) - elif step_name == "Normalize": - mean = tuple(float(value) for value in step.mean) - std = tuple(float(value) for value in step.std) - return BatchTransformSpec( - resize_size=resize_size, - center_crop_size=center_crop_size, - mean=mean, - std=std, - resize_interpolation=resize_interpolation, - ) - - -def _iter_transform_steps(transforms): - transform_steps = transforms.transforms if hasattr(transforms, "transforms") else None - if transform_steps is None: - return None - flattened = [] - for step in transform_steps: - nested = _iter_transform_steps(step) - if nested is not None: - flattened.extend(nested) - else: - flattened.append(step) - return flattened -def _prepare_batch_tensor(image): - if image.dtype == torch.uint8: - return image.float().div(255.0) - return image.float() - - -def _apply_transforms_itemwise(image, transforms): - if not torch.is_tensor(image) or image.ndim <= 3: - return transforms(image) - - transformed_items = [transforms(sample) for sample in image.cpu()] - if not transformed_items: - return image.new_empty((0,), dtype=torch.float32) - if not all(torch.is_tensor(item) for item in transformed_items): - transformed_items = [torch.as_tensor(item) for item in transformed_items] - return torch.stack(transformed_items, dim=0) - - -def _interp_mode_to_str(interp_mode) -> str: - """Map a torchvision InterpolationMode to the string accepted by F.interpolate.""" - if interp_mode is None: - return "bilinear" - name = str(interp_mode).upper() - if "BICUBIC" in name: - return "bicubic" - if "NEAREST" in name: - return "nearest" - return "bilinear" - - -def _resize_image_batch(image, size: tuple[int, int], *, mode: str = "bilinear"): - if tuple(int(dim) for dim in image.shape[-2:]) == size: - return image - - align_corners = False if mode in ("bilinear", "bicubic") else None - kwargs = {"antialias": True} if mode in ("bilinear", "bicubic") else {} - return torch.nn.functional.interpolate( - image, - size=size, - mode=mode, - **({"align_corners": align_corners} if align_corners is not None else {}), - **kwargs, - ) - - -def _apply_batch_transform_spec(image, spec: BatchTransformSpec): - - if spec.resize_size is not None: - image = _resize_image_batch(image, spec.resize_size, mode=spec.resize_interpolation) - if spec.center_crop_size is not None: - image = _center_crop_batch(image, spec.center_crop_size) - if spec.mean is not None and spec.std is not None: - mean = torch.tensor(spec.mean, dtype=image.dtype, device=image.device).view(1, -1, 1, 1) - std = torch.tensor(spec.std, dtype=image.dtype, device=image.device).view(1, -1, 1, 1) - image = (image - mean) / std - return image -def _normalize_hw(value) -> tuple[int, int] | None: - if value is None: - return None - if isinstance(value, int): - return (int(value), int(value)) - if isinstance(value, (tuple, list)): - if len(value) == 1: - return (int(value[0]), int(value[0])) - if len(value) >= 2: - return (int(value[0]), int(value[1])) - return None - if isinstance(value, dict): - if "height" in value and "width" in value: - return (int(value["height"]), int(value["width"])) - if "shortest_edge" in value: - edge = int(value["shortest_edge"]) - return (edge, edge) - return None - - -def _center_crop_batch(image, size: tuple[int, int]): - target_h, target_w = size - height, width = int(image.shape[-2]), int(image.shape[-1]) - crop_h = min(target_h, height) - crop_w = min(target_w, width) - top = max((height - crop_h) // 2, 0) - left = max((width - crop_w) // 2, 0) - return image[..., top : top + crop_h, left : left + crop_w] - - -class _BatchPrefetcher: - def __init__(self, dataloader, loaded: LoadedModel, batch_preprocessor): - self.iterator = iter(dataloader) - self.loaded = loaded - self.batch_preprocessor = batch_preprocessor - self.copy_stream = self._make_copy_stream() - self._pinned_host_buffer = None - self._next_batch: PreparedBatch | None = None - self._preload() - - def _unpack_loader_batch(self, batch): - if isinstance(batch, (tuple, list)): - if len(batch) == 3 and isinstance(batch[2], dict): - return batch[0], batch[1], batch[2] - if len(batch) == 2: - return batch[0], batch[1], {} - raise ValueError("Expected the embedding dataloader to yield (indices, image) or (indices, image, timing)") - - def _make_copy_stream(self): - if not _uses_cuda_runtime(self.loaded.device): - return None - return torch.cuda.Stream(device=self.loaded.device) - - def _stage_host_batch(self, image): - if self.copy_stream is None or not torch.is_tensor(image): - return image - if image.device.type != "cpu" or image.is_pinned(): - return image - if ( - self._pinned_host_buffer is None - or tuple(self._pinned_host_buffer.shape) != tuple(image.shape) - or self._pinned_host_buffer.dtype != image.dtype - ): - self._pinned_host_buffer = torch.empty( - image.shape, - dtype=image.dtype, - pin_memory=True, - ) - self._pinned_host_buffer.copy_(image) - return self._pinned_host_buffer - - def _prepare_batch(self, image): - preprocess_start = time.perf_counter() - if self.batch_preprocessor is not None: - prepared = self.batch_preprocessor(image) - else: - prepared = _apply_transforms_itemwise(image, self.loaded.transforms) - if torch.is_tensor(prepared) and prepared.device != self.loaded.device: - prepared = prepared.to( - self.loaded.device, - non_blocking=_uses_cuda_runtime(self.loaded.device), - ) - preprocess_ms = (time.perf_counter() - preprocess_start) * 1000.0 - return prepared, preprocess_ms - - def _preload(self) -> None: - wait_start = time.perf_counter() - try: - batch = next(self.iterator) - except StopIteration: - self._next_batch = None - return - loader_wait_ms = (time.perf_counter() - wait_start) * 1000.0 - indices, image, timing = self._unpack_loader_batch(batch) - worker_batch_ms = float(timing["worker_batch_ms"]) if "worker_batch_ms" in timing else 0.0 - reader_open_ms = float(timing["reader_open_ms"]) if "reader_open_ms" in timing else 0.0 - reader_read_ms = float(timing["reader_read_ms"]) if "reader_read_ms" in timing else 0.0 - if self.copy_stream is None or self.batch_preprocessor is None: - prepared, preprocess_ms = self._prepare_batch(image) - self._next_batch = PreparedBatch( - indices=indices, - image=prepared, - loader_wait_ms=loader_wait_ms, - preprocess_ms=preprocess_ms, - worker_batch_ms=worker_batch_ms, - reader_open_ms=reader_open_ms, - reader_read_ms=reader_read_ms, - ) - return - - staged = self._stage_host_batch(image) - preprocess_start = time.perf_counter() - with torch.cuda.stream(self.copy_stream): - prepared = self.batch_preprocessor(staged) if self.batch_preprocessor is not None else staged.to( - self.loaded.device, - non_blocking=True, - ) - preprocess_ms = (time.perf_counter() - preprocess_start) * 1000.0 - self._next_batch = PreparedBatch( - indices=indices, - image=prepared, - loader_wait_ms=loader_wait_ms, - preprocess_ms=preprocess_ms, - worker_batch_ms=worker_batch_ms, - reader_open_ms=reader_open_ms, - reader_read_ms=reader_read_ms, - ) - - def __iter__(self): - return self - - def __next__(self) -> PreparedBatch: - if self._next_batch is None: - raise StopIteration - current = self._next_batch - if self.copy_stream is not None: - ready_start = time.perf_counter() - current_stream = torch.cuda.current_stream(device=self.loaded.device) - current_stream.wait_stream(self.copy_stream) - current.ready_wait_ms = (time.perf_counter() - ready_start) * 1000.0 - self._preload() - return current - - -def _run_forward_pass( - dataloader, - loaded: LoadedModel, - autocast_context, - *, - batch_preprocessor=None, - sample_id: str | None = None, - total_items: int | None = None, - unit_label: str = "tile", - return_indices: bool = False, -): - - outputs = [] - batch_indices = [] if return_indices else None - processed = 0 - batch_index = 0 - prefetcher_context = ( - suppress_c_stderr() - if _should_suppress_cucim_dataloader_stderr(dataloader) - else nullcontext() - ) - with prefetcher_context: - prefetcher = _BatchPrefetcher(dataloader, loaded, batch_preprocessor) - with torch.inference_mode(), autocast_context: - for prepared_batch in prefetcher: - image = prepared_batch.image - forward_start = time.perf_counter() - embedding = loaded.model.encode_tiles(image).detach().cpu() - forward_ms = (time.perf_counter() - forward_start) * 1000.0 - outputs.append(embedding) - if batch_indices is not None: - batch_indices.append(torch.as_tensor(prepared_batch.indices, dtype=torch.long).detach().cpu()) - processed += int(embedding.shape[0]) - batch_index += 1 - batch_total_ms = ( - prepared_batch.loader_wait_ms - + prepared_batch.ready_wait_ms - + prepared_batch.preprocess_ms - + forward_ms - ) - gpu_busy_fraction = ( - (prepared_batch.ready_wait_ms + prepared_batch.preprocess_ms + forward_ms) / batch_total_ms - if batch_total_ms > 0 - else 0.0 - ) - emit_progress( - "embedding.batch.timing", - sample_id=sample_id, - batch_index=batch_index, - batch_size=int(embedding.shape[0]), - loader_wait_ms=round(prepared_batch.loader_wait_ms, 4), - ready_wait_ms=round(prepared_batch.ready_wait_ms, 4), - preprocess_ms=round(prepared_batch.preprocess_ms, 4), - worker_batch_ms=round(prepared_batch.worker_batch_ms, 4), - reader_open_ms=round(prepared_batch.reader_open_ms, 4), - reader_read_ms=round(prepared_batch.reader_read_ms, 4), - forward_ms=round(forward_ms, 4), - gpu_busy_fraction=round(gpu_busy_fraction, 4), - unit=unit_label, - ) - if sample_id is not None: - emit_progress( - "embedding.tile.progress", - sample_id=sample_id, - processed=processed, - total=int(total_items or processed), - unit=unit_label, - ) - if not outputs: - feature_dim = loaded.tile_feature_dim if loaded.tile_feature_dim is not None else loaded.feature_dim - empty = torch.empty((0, int(feature_dim)), dtype=torch.float32) - if batch_indices is not None: - return torch.empty((0,), dtype=torch.long), empty - return empty - embeddings = torch.cat(outputs, dim=0) - if batch_indices is not None: - return torch.cat(batch_indices, dim=0), embeddings - return embeddings - - - -def _resolve_device(device: str, default_device): - - if device == "auto": - return default_device - return torch.device(device) - - def _describe_device_mode(model, execution: ExecutionOptions) -> str: requested_device = getattr(model, "_requested_device", None) if requested_device == "cpu": @@ -3445,136 +2912,6 @@ def _load_embedded_slide_payload(coordination_dir: Path, sample_id: str): return torch.load(payload_path, map_location="cpu", weights_only=True) -def _num_tiles(tiling_result) -> int: - x_values, _y_values = coordinate_arrays(tiling_result) - return int(len(x_values)) - - -def _serialize_model(model) -> dict[str, Any]: - return { - "name": model.name, - "output_variant": model._output_variant if hasattr(model, "_output_variant") else None, - "allow_non_recommended_settings": bool( - getattr(model, "allow_non_recommended_settings", False) - ), - } - - -def _serialize_preprocessing(preprocessing: PreprocessingConfig) -> dict[str, Any]: - return { - "backend": preprocessing.backend, - "requested_spacing_um": preprocessing.requested_spacing_um, - "requested_tile_size_px": preprocessing.requested_tile_size_px, - "requested_region_size_px": preprocessing.requested_region_size_px, - "region_tile_multiple": preprocessing.region_tile_multiple, - "tolerance": preprocessing.tolerance, - "overlap": preprocessing.overlap, - "tissue_threshold": preprocessing.tissue_threshold, - "read_coordinates_from": str(preprocessing.read_coordinates_from) if preprocessing.read_coordinates_from is not None else None, - "read_tiles_from": str(preprocessing.read_tiles_from) if preprocessing.read_tiles_from is not None else None, - "resume": preprocessing.resume, - "segmentation": dict(preprocessing.segmentation), - "filtering": dict(preprocessing.filtering), - "preview": dict(preprocessing.preview), - } - - -def _serialize_execution( - execution: ExecutionOptions, - *, - preprocessing: PreprocessingConfig | None = None, -) -> dict[str, Any]: - effective_num_workers = execution.num_workers - if preprocessing is not None and preprocessing.on_the_fly and preprocessing.read_tiles_from is None: - effective_num_workers, _ = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) - return { - "output_dir": str(execution.output_dir) if execution.output_dir is not None else None, - "output_format": execution.output_format, - "batch_size": execution.batch_size, - "num_workers": effective_num_workers, - "num_preprocessing_workers": execution.num_preprocessing_workers, - "num_gpus": execution.num_gpus, - "precision": execution.precision, - "prefetch_factor": execution.prefetch_factor, - "persistent_workers": execution.persistent_workers, - "save_tile_embeddings": execution.save_tile_embeddings, - "save_latents": execution.save_latents, - } - - -def deserialize_preprocessing(payload: dict[str, Any]) -> PreprocessingConfig: - read_coordinates_from = ( - Path(payload["read_coordinates_from"]) - if "read_coordinates_from" in payload and payload["read_coordinates_from"] - else None - ) - read_tiles_from = ( - Path(payload["read_tiles_from"]) - if "read_tiles_from" in payload and payload["read_tiles_from"] - else None - ) - return PreprocessingConfig( - backend=payload["backend"], - requested_spacing_um=float(payload["requested_spacing_um"]), - requested_tile_size_px=int(payload["requested_tile_size_px"]), - requested_region_size_px=( - int(payload["requested_region_size_px"]) - if "requested_region_size_px" in payload and payload["requested_region_size_px"] is not None - else None - ), - region_tile_multiple=( - int(payload["region_tile_multiple"]) - if "region_tile_multiple" in payload and payload["region_tile_multiple"] is not None - else None - ), - tolerance=float(payload["tolerance"]), - overlap=float(payload["overlap"]), - tissue_threshold=float(payload["tissue_threshold"]), - read_coordinates_from=read_coordinates_from, - read_tiles_from=read_tiles_from, - resume=bool(payload["resume"]) if "resume" in payload else False, - segmentation=dict(payload["segmentation"]) if "segmentation" in payload else {}, - filtering=dict(payload["filtering"]) if "filtering" in payload else {}, - preview=dict(payload["preview"]) if "preview" in payload else {}, - ) - - -def deserialize_execution(payload: dict[str, Any]) -> ExecutionOptions: - output_dir = payload["output_dir"] if "output_dir" in payload else None - batch_size = payload["batch_size"] if "batch_size" in payload else None - num_workers = payload["num_workers"] if "num_workers" in payload else None - num_gpus = payload["num_gpus"] if "num_gpus" in payload else 1 - precision = payload["precision"] if "precision" in payload else "fp32" - prefetch_factor = payload["prefetch_factor"] if "prefetch_factor" in payload else 4 - persistent_workers = ( - bool(payload["persistent_workers"]) if "persistent_workers" in payload else True - ) - save_tile_embeddings = ( - bool(payload["save_tile_embeddings"]) if "save_tile_embeddings" in payload else False - ) - save_latents = bool(payload["save_latents"]) if "save_latents" in payload else False - return ExecutionOptions( - output_dir=Path(output_dir) if output_dir is not None else None, - output_format=payload["output_format"] if "output_format" in payload else "pt", - batch_size=batch_size, - num_workers=int(num_workers) if num_workers is not None else None, - num_gpus=int(num_gpus), - precision=precision, - prefetch_factor=int(prefetch_factor), - persistent_workers=persistent_workers, - save_tile_embeddings=save_tile_embeddings, - save_latents=save_latents, - ) - - -def _autocast_dtype(torch, precision: str): - if precision == "fp16": - return torch.float16 - if precision == "bf16": - return torch.bfloat16 - return None - - def _collect_pipeline_artifacts( slide_records: Sequence[SlideSpec], *, diff --git a/slide2vec/runtime/__init__.py b/slide2vec/runtime/__init__.py new file mode 100644 index 0000000..af4011c --- /dev/null +++ b/slide2vec/runtime/__init__.py @@ -0,0 +1,2 @@ +"""Internal runtime helpers for slide2vec inference orchestration.""" + diff --git a/slide2vec/runtime/batching.py b/slide2vec/runtime/batching.py new file mode 100644 index 0000000..4cae17c --- /dev/null +++ b/slide2vec/runtime/batching.py @@ -0,0 +1,441 @@ +from __future__ import annotations + +import logging +import time +from contextlib import nullcontext +from typing import Any + +import torch +from transformers.image_processing_utils import BaseImageProcessor + +from slide2vec.progress import emit_progress +from slide2vec.runtime_types import LoadedModel +from slide2vec.utils.log_utils import suppress_c_stderr + +from .types import BatchTransformSpec, PreparedBatch + + +def uses_cuda_runtime(device) -> bool: + return str(device).startswith("cuda") and torch.cuda.is_available() + + +def embedding_dataloader_kwargs(loaded: LoadedModel, execution) -> dict[str, Any]: + resolved_num_workers = execution.resolved_num_workers() + kwargs: dict[str, Any] = { + "num_workers": resolved_num_workers, + "pin_memory": uses_cuda_runtime(loaded.device), + } + if resolved_num_workers > 0: + kwargs["persistent_workers"] = bool(execution.persistent_workers) + kwargs["prefetch_factor"] = int(execution.prefetch_factor) + return kwargs + + +def should_suppress_cucim_dataloader_stderr(dataloader) -> bool: + if dataloader is None: + return False + collate_fn = getattr(dataloader, "collate_fn", None) + if collate_fn is None: + return False + return bool(getattr(collate_fn, "_suppress_cucim_stderr", False)) + + +def build_batch_preprocessor( + loaded: LoadedModel, + tiling_result, +): + return build_batch_preprocessor_for_tile_images( + loaded, + requested_tile_size_px=int(getattr(tiling_result, "requested_tile_size_px")), + ) + + +def build_batch_preprocessor_for_tile_images( + loaded: LoadedModel, + *, + requested_tile_size_px: int, +): + spec = build_batch_transform_spec(loaded.transforms) + if spec is None: + logging.getLogger(__name__).warning( + "Batched preprocessing is disabled for %s because the transform stack is not supported; " + "falling back to per-item preprocessing", + loaded.name, + ) + return None + + def preprocess(batch): + image = prepare_batch_tensor(batch) + if spec.resize_size is None: + image = resize_image_batch( + image, + (int(requested_tile_size_px), int(requested_tile_size_px)), + ) + image = apply_batch_transform_spec(image, spec) + if image.device != loaded.device: + image = image.to(loaded.device, non_blocking=str(loaded.device).startswith("cuda")) + return image.contiguous() + + return preprocess + + +def build_batch_transform_spec(transforms) -> BatchTransformSpec | None: + if isinstance(transforms, BaseImageProcessor): + crop_size = transforms.crop_size if hasattr(transforms, "crop_size") else None + size = transforms.size if hasattr(transforms, "size") else None + resize_size = normalize_hw(crop_size or size) + if resize_size is None: + return None + mean = transforms.image_mean if hasattr(transforms, "image_mean") else None + std = transforms.image_std if hasattr(transforms, "image_std") else None + return BatchTransformSpec( + resize_size=resize_size, + center_crop_size=None, + mean=tuple(float(value) for value in mean) if mean is not None else None, + std=tuple(float(value) for value in std) if std is not None else None, + ) + + transform_steps = iter_transform_steps(transforms) + if transform_steps is None: + return None + + resize_size = None + resize_interpolation = "bilinear" + center_crop_size = None + mean = None + std = None + supported_step_names = { + "Resize", + "CenterCrop", + "Normalize", + "ToTensor", + "MaybeToTensor", + "ToImage", + "ConvertImageDtype", + } + for step in transform_steps: + step_name = type(step).__name__ + if step_name not in supported_step_names: + return None + if step_name == "Resize": + resize_size = normalize_hw(step.size if hasattr(step, "size") else None) + resize_interpolation = interp_mode_to_str(step.interpolation if hasattr(step, "interpolation") else None) + elif step_name == "CenterCrop": + center_crop_size = normalize_hw(step.size if hasattr(step, "size") else None) + elif step_name == "Normalize": + mean = tuple(float(value) for value in step.mean) + std = tuple(float(value) for value in step.std) + return BatchTransformSpec( + resize_size=resize_size, + center_crop_size=center_crop_size, + mean=mean, + std=std, + resize_interpolation=resize_interpolation, + ) + + +def iter_transform_steps(transforms): + transform_steps = transforms.transforms if hasattr(transforms, "transforms") else None + if transform_steps is None: + return None + flattened = [] + for step in transform_steps: + nested = iter_transform_steps(step) + if nested is not None: + flattened.extend(nested) + else: + flattened.append(step) + return flattened + + +def prepare_batch_tensor(image): + if image.dtype == torch.uint8: + return image.float().div(255.0) + return image.float() + + +def apply_transforms_itemwise(image, transforms): + if not torch.is_tensor(image) or image.ndim <= 3: + return transforms(image) + + transformed_items = [transforms(sample) for sample in image.cpu()] + if not transformed_items: + return image.new_empty((0,), dtype=torch.float32) + if not all(torch.is_tensor(item) for item in transformed_items): + transformed_items = [torch.as_tensor(item) for item in transformed_items] + return torch.stack(transformed_items, dim=0) + + +def interp_mode_to_str(interp_mode) -> str: + if interp_mode is None: + return "bilinear" + name = str(interp_mode).upper() + if "BICUBIC" in name: + return "bicubic" + if "NEAREST" in name: + return "nearest" + return "bilinear" + + +def resize_image_batch(image, size: tuple[int, int], *, mode: str = "bilinear"): + if tuple(int(dim) for dim in image.shape[-2:]) == size: + return image + + align_corners = False if mode in ("bilinear", "bicubic") else None + kwargs = {"antialias": True} if mode in ("bilinear", "bicubic") else {} + return torch.nn.functional.interpolate( + image, + size=size, + mode=mode, + **({"align_corners": align_corners} if align_corners is not None else {}), + **kwargs, + ) + + +def apply_batch_transform_spec(image, spec: BatchTransformSpec): + if spec.resize_size is not None: + image = resize_image_batch(image, spec.resize_size, mode=spec.resize_interpolation) + if spec.center_crop_size is not None: + image = center_crop_batch(image, spec.center_crop_size) + if spec.mean is not None and spec.std is not None: + mean = torch.tensor(spec.mean, dtype=image.dtype, device=image.device).view(1, -1, 1, 1) + std = torch.tensor(spec.std, dtype=image.dtype, device=image.device).view(1, -1, 1, 1) + image = (image - mean) / std + return image + + +def normalize_hw(value) -> tuple[int, int] | None: + if value is None: + return None + if isinstance(value, int): + return (int(value), int(value)) + if isinstance(value, (tuple, list)): + if len(value) == 1: + return (int(value[0]), int(value[0])) + if len(value) >= 2: + return (int(value[0]), int(value[1])) + return None + if isinstance(value, dict): + if "height" in value and "width" in value: + return (int(value["height"]), int(value["width"])) + if "shortest_edge" in value: + edge = int(value["shortest_edge"]) + return (edge, edge) + return None + + +def center_crop_batch(image, size: tuple[int, int]): + target_h, target_w = size + height, width = int(image.shape[-2]), int(image.shape[-1]) + crop_h = min(target_h, height) + crop_w = min(target_w, width) + top = max((height - crop_h) // 2, 0) + left = max((width - crop_w) // 2, 0) + return image[..., top : top + crop_h, left : left + crop_w] + + +class BatchPrefetcher: + def __init__(self, dataloader, loaded: LoadedModel, batch_preprocessor): + self.iterator = iter(dataloader) + self.loaded = loaded + self.batch_preprocessor = batch_preprocessor + self.copy_stream = self._make_copy_stream() + self._pinned_host_buffer = None + self._next_batch: PreparedBatch | None = None + self._preload() + + def _unpack_loader_batch(self, batch): + if isinstance(batch, (tuple, list)): + if len(batch) == 3 and isinstance(batch[2], dict): + return batch[0], batch[1], batch[2] + if len(batch) == 2: + return batch[0], batch[1], {} + raise ValueError("Expected the embedding dataloader to yield (indices, image) or (indices, image, timing)") + + def _make_copy_stream(self): + if not uses_cuda_runtime(self.loaded.device): + return None + return torch.cuda.Stream(device=self.loaded.device) + + def _stage_host_batch(self, image): + if self.copy_stream is None or not torch.is_tensor(image): + return image + if image.device.type != "cpu" or image.is_pinned(): + return image + if ( + self._pinned_host_buffer is None + or tuple(self._pinned_host_buffer.shape) != tuple(image.shape) + or self._pinned_host_buffer.dtype != image.dtype + ): + self._pinned_host_buffer = torch.empty( + image.shape, + dtype=image.dtype, + pin_memory=True, + ) + self._pinned_host_buffer.copy_(image) + return self._pinned_host_buffer + + def _prepare_batch(self, image): + preprocess_start = time.perf_counter() + if self.batch_preprocessor is not None: + prepared = self.batch_preprocessor(image) + else: + prepared = apply_transforms_itemwise(image, self.loaded.transforms) + if torch.is_tensor(prepared) and prepared.device != self.loaded.device: + prepared = prepared.to( + self.loaded.device, + non_blocking=uses_cuda_runtime(self.loaded.device), + ) + preprocess_ms = (time.perf_counter() - preprocess_start) * 1000.0 + return prepared, preprocess_ms + + def _preload(self) -> None: + wait_start = time.perf_counter() + try: + batch = next(self.iterator) + except StopIteration: + self._next_batch = None + return + loader_wait_ms = (time.perf_counter() - wait_start) * 1000.0 + indices, image, timing = self._unpack_loader_batch(batch) + worker_batch_ms = float(timing["worker_batch_ms"]) if "worker_batch_ms" in timing else 0.0 + reader_open_ms = float(timing["reader_open_ms"]) if "reader_open_ms" in timing else 0.0 + reader_read_ms = float(timing["reader_read_ms"]) if "reader_read_ms" in timing else 0.0 + if self.copy_stream is None or self.batch_preprocessor is None: + prepared, preprocess_ms = self._prepare_batch(image) + self._next_batch = PreparedBatch( + indices=indices, + image=prepared, + loader_wait_ms=loader_wait_ms, + preprocess_ms=preprocess_ms, + worker_batch_ms=worker_batch_ms, + reader_open_ms=reader_open_ms, + reader_read_ms=reader_read_ms, + ) + return + + staged = self._stage_host_batch(image) + preprocess_start = time.perf_counter() + with torch.cuda.stream(self.copy_stream): + prepared = self.batch_preprocessor(staged) if self.batch_preprocessor is not None else staged.to( + self.loaded.device, + non_blocking=True, + ) + preprocess_ms = (time.perf_counter() - preprocess_start) * 1000.0 + self._next_batch = PreparedBatch( + indices=indices, + image=prepared, + loader_wait_ms=loader_wait_ms, + preprocess_ms=preprocess_ms, + worker_batch_ms=worker_batch_ms, + reader_open_ms=reader_open_ms, + reader_read_ms=reader_read_ms, + ) + + def __iter__(self): + return self + + def __next__(self) -> PreparedBatch: + if self._next_batch is None: + raise StopIteration + current = self._next_batch + if self.copy_stream is not None: + ready_start = time.perf_counter() + current_stream = torch.cuda.current_stream(device=self.loaded.device) + current_stream.wait_stream(self.copy_stream) + current.ready_wait_ms = (time.perf_counter() - ready_start) * 1000.0 + self._preload() + return current + + +def run_forward_pass( + dataloader, + loaded: LoadedModel, + autocast_context, + *, + batch_preprocessor=None, + sample_id: str | None = None, + total_items: int | None = None, + unit_label: str = "tile", + return_indices: bool = False, +): + outputs = [] + batch_indices = [] if return_indices else None + processed = 0 + batch_index = 0 + prefetcher_context = ( + suppress_c_stderr() + if should_suppress_cucim_dataloader_stderr(dataloader) + else nullcontext() + ) + with prefetcher_context: + prefetcher = BatchPrefetcher(dataloader, loaded, batch_preprocessor) + with torch.inference_mode(), autocast_context: + for prepared_batch in prefetcher: + image = prepared_batch.image + forward_start = time.perf_counter() + embedding = loaded.model.encode_tiles(image).detach().cpu() + forward_ms = (time.perf_counter() - forward_start) * 1000.0 + outputs.append(embedding) + if batch_indices is not None: + batch_indices.append(torch.as_tensor(prepared_batch.indices, dtype=torch.long).detach().cpu()) + processed += int(embedding.shape[0]) + batch_index += 1 + batch_total_ms = ( + prepared_batch.loader_wait_ms + + prepared_batch.ready_wait_ms + + prepared_batch.preprocess_ms + + forward_ms + ) + gpu_busy_fraction = ( + (prepared_batch.ready_wait_ms + prepared_batch.preprocess_ms + forward_ms) / batch_total_ms + if batch_total_ms > 0 + else 0.0 + ) + emit_progress( + "embedding.batch.timing", + sample_id=sample_id, + batch_index=batch_index, + batch_size=int(embedding.shape[0]), + loader_wait_ms=round(prepared_batch.loader_wait_ms, 4), + ready_wait_ms=round(prepared_batch.ready_wait_ms, 4), + preprocess_ms=round(prepared_batch.preprocess_ms, 4), + worker_batch_ms=round(prepared_batch.worker_batch_ms, 4), + reader_open_ms=round(prepared_batch.reader_open_ms, 4), + reader_read_ms=round(prepared_batch.reader_read_ms, 4), + forward_ms=round(forward_ms, 4), + gpu_busy_fraction=round(gpu_busy_fraction, 4), + unit=unit_label, + ) + if sample_id is not None: + emit_progress( + "embedding.tile.progress", + sample_id=sample_id, + processed=processed, + total=int(total_items or processed), + unit=unit_label, + ) + if not outputs: + feature_dim = loaded.tile_feature_dim if loaded.tile_feature_dim is not None else loaded.feature_dim + empty = torch.empty((0, int(feature_dim)), dtype=torch.float32) + if batch_indices is not None: + return torch.empty((0,), dtype=torch.long), empty + return empty + embeddings = torch.cat(outputs, dim=0) + if batch_indices is not None: + return torch.cat(batch_indices, dim=0), embeddings + return embeddings + + +def resolve_device(device: str, default_device): + if device == "auto": + return default_device + return torch.device(device) + + +def autocast_dtype(torch_module, precision: str): + if precision == "fp16": + return torch_module.float16 + if precision == "bf16": + return torch_module.bfloat16 + return None diff --git a/slide2vec/runtime/hierarchical.py b/slide2vec/runtime/hierarchical.py new file mode 100644 index 0000000..4a42a2a --- /dev/null +++ b/slide2vec/runtime/hierarchical.py @@ -0,0 +1,105 @@ +import numpy as np + +from slide2vec.api import PreprocessingConfig +from slide2vec.utils.coordinates import coordinate_arrays + +from .types import HierarchicalIndex + + +def num_tiles(tiling_result) -> int: + x_values, _y_values = coordinate_arrays(tiling_result) + return int(len(x_values)) + + +def is_hierarchical_preprocessing(preprocessing: PreprocessingConfig | None) -> bool: + if preprocessing is None: + return False + return preprocessing.region_tile_multiple is not None or preprocessing.requested_region_size_px is not None + + +def resolve_hierarchical_geometry(preprocessing: PreprocessingConfig, tiling_result) -> dict[str, int]: + if preprocessing.region_tile_multiple is None: + raise ValueError("Hierarchical preprocessing requires region_tile_multiple") + if preprocessing.requested_region_size_px is None: + raise ValueError("Hierarchical preprocessing requires requested_region_size_px") + requested_tile_size_px = int(preprocessing.requested_tile_size_px) + requested_region_size_px = int(preprocessing.requested_region_size_px) + requested_spacing_um = float(preprocessing.requested_spacing_um) + multiple = int(preprocessing.region_tile_multiple) + if requested_region_size_px % multiple != 0: + raise ValueError("requested_region_size_px must be divisible by region_tile_multiple") + read_spacing_um = float(getattr(tiling_result, "read_spacing_um")) + base_spacing_um = float(getattr(tiling_result, "base_spacing_um")) + if abs(read_spacing_um - requested_spacing_um) / requested_spacing_um <= float(preprocessing.tolerance): + read_tile_size_px = requested_tile_size_px + else: + read_tile_size_px = int( + round(requested_tile_size_px * requested_spacing_um / read_spacing_um) + ) + read_region_size_px = read_tile_size_px * multiple + # Use the actual read geometry that produced the tile crop. When the + # resolved spacing is considered equivalent to the requested spacing, + # this keeps the level-0 footprint aligned with the real crop size. + tile_size_lv0 = int(round(read_tile_size_px * read_spacing_um / base_spacing_um)) + return { + "region_tile_multiple": multiple, + "tiles_per_region": multiple * multiple, + "requested_tile_size_px": requested_tile_size_px, + "read_tile_size_px": read_tile_size_px, + "requested_region_size_px": requested_region_size_px, + "read_region_size_px": read_region_size_px, + "tile_size_lv0": tile_size_lv0, + } + + +def build_hierarchical_index( + tiling_result, + *, + region_tile_multiple: int, + tile_size_lv0: int | None = None, +) -> HierarchicalIndex: + x_values, y_values = coordinate_arrays(tiling_result) + num_regions = int(len(x_values)) + multiple = int(region_tile_multiple) + if multiple < 2: + raise ValueError("region_tile_multiple must be at least 2") + subtile_size_lv0 = ( + int(tile_size_lv0) + if tile_size_lv0 is not None + else int(getattr(tiling_result, "tile_size_lv0")) // multiple + ) + tiles_per_region = multiple * multiple + if num_regions == 0: + empty = np.empty(0, dtype=np.int64) + return HierarchicalIndex( + flat_index=empty, + region_index=np.empty(0, dtype=np.int32), + subtile_index_within_region=np.empty(0, dtype=np.int32), + subtile_x=empty, + subtile_y=empty, + num_regions=0, + tiles_per_region=tiles_per_region, + ) + rows, cols = np.divmod(np.arange(tiles_per_region, dtype=np.int32), multiple) + offsets_x = cols.astype(np.int64) * subtile_size_lv0 + offsets_y = rows.astype(np.int64) * subtile_size_lv0 + region_x = np.asarray(x_values, dtype=np.int64)[:, np.newaxis] + region_y = np.asarray(y_values, dtype=np.int64)[:, np.newaxis] + subtile_x = (region_x + offsets_x[np.newaxis, :]).reshape(-1) + subtile_y = (region_y + offsets_y[np.newaxis, :]).reshape(-1) + return HierarchicalIndex( + flat_index=np.arange(num_regions * tiles_per_region, dtype=np.int64), + region_index=np.repeat(np.arange(num_regions, dtype=np.int32), tiles_per_region), + subtile_index_within_region=np.tile(np.arange(tiles_per_region, dtype=np.int32), num_regions), + subtile_x=subtile_x, + subtile_y=subtile_y, + num_regions=num_regions, + tiles_per_region=tiles_per_region, + ) + + +def num_embedding_items(tiling_result, preprocessing: PreprocessingConfig | None) -> int: + if not is_hierarchical_preprocessing(preprocessing): + return num_tiles(tiling_result) + geometry = resolve_hierarchical_geometry(preprocessing, tiling_result) + return num_tiles(tiling_result) * int(geometry["tiles_per_region"]) diff --git a/slide2vec/runtime/progress_bridge.py b/slide2vec/runtime/progress_bridge.py new file mode 100644 index 0000000..1be85cd --- /dev/null +++ b/slide2vec/runtime/progress_bridge.py @@ -0,0 +1,52 @@ +from contextlib import contextmanager + +from hs2p import progress as hs2p_progress + +from slide2vec.progress import ( + NullProgressReporter, + ProgressEvent as Slide2VecProgressEvent, + get_progress_reporter, +) + +_BRIDGED_HS2P_PROGRESS_KINDS = { + "backend.selected", + "tissue.started", + "tissue.progress", + "tissue.finished", + "tiling.progress", + "tiling.finished", + "preview.started", + "preview.progress", + "preview.finished", +} + + +class _Hs2pProgressBridge: + def __init__(self, downstream) -> None: + self._downstream = downstream + + def emit(self, event) -> None: + if event.kind not in _BRIDGED_HS2P_PROGRESS_KINDS: + return + self._downstream.emit( + Slide2VecProgressEvent(kind=event.kind, payload=dict(event.payload)) + ) + + def close(self) -> None: + return None + + def write_log(self, message: str, *, stream=None) -> None: + if hasattr(self._downstream, "write_log"): + self._downstream.write_log(message, stream=stream) + + +@contextmanager +def bridge_hs2p_progress_to_slide2vec(): + downstream = get_progress_reporter() + if isinstance(downstream, NullProgressReporter): + yield + return + bridge = _Hs2pProgressBridge(downstream) + with hs2p_progress.activate_progress_reporter(bridge): + yield + diff --git a/slide2vec/runtime/serialization.py b/slide2vec/runtime/serialization.py new file mode 100644 index 0000000..47199e6 --- /dev/null +++ b/slide2vec/runtime/serialization.py @@ -0,0 +1,119 @@ +from pathlib import Path +from typing import Any + +from slide2vec.api import ExecutionOptions, PreprocessingConfig + + +def serialize_model(model) -> dict[str, Any]: + return { + "name": model.name, + "output_variant": model._output_variant if hasattr(model, "_output_variant") else None, + "allow_non_recommended_settings": bool( + getattr(model, "allow_non_recommended_settings", False) + ), + } + + +def serialize_preprocessing(preprocessing: PreprocessingConfig) -> dict[str, Any]: + return { + "backend": preprocessing.backend, + "requested_spacing_um": preprocessing.requested_spacing_um, + "requested_tile_size_px": preprocessing.requested_tile_size_px, + "requested_region_size_px": preprocessing.requested_region_size_px, + "region_tile_multiple": preprocessing.region_tile_multiple, + "tolerance": preprocessing.tolerance, + "overlap": preprocessing.overlap, + "tissue_threshold": preprocessing.tissue_threshold, + "read_coordinates_from": str(preprocessing.read_coordinates_from) if preprocessing.read_coordinates_from is not None else None, + "read_tiles_from": str(preprocessing.read_tiles_from) if preprocessing.read_tiles_from is not None else None, + "resume": preprocessing.resume, + "segmentation": dict(preprocessing.segmentation), + "filtering": dict(preprocessing.filtering), + "preview": dict(preprocessing.preview), + } + + +def serialize_execution( + execution: ExecutionOptions, + *, + effective_num_workers: int | None = None, +) -> dict[str, Any]: + return { + "output_dir": str(execution.output_dir) if execution.output_dir is not None else None, + "output_format": execution.output_format, + "batch_size": execution.batch_size, + "num_workers": effective_num_workers if effective_num_workers is not None else execution.num_workers, + "num_preprocessing_workers": execution.num_preprocessing_workers, + "num_gpus": execution.num_gpus, + "precision": execution.precision, + "prefetch_factor": execution.prefetch_factor, + "persistent_workers": execution.persistent_workers, + "save_tile_embeddings": execution.save_tile_embeddings, + "save_latents": execution.save_latents, + } + + +def deserialize_preprocessing(payload: dict[str, Any]) -> PreprocessingConfig: + read_coordinates_from = ( + Path(payload["read_coordinates_from"]) + if "read_coordinates_from" in payload and payload["read_coordinates_from"] + else None + ) + read_tiles_from = ( + Path(payload["read_tiles_from"]) + if "read_tiles_from" in payload and payload["read_tiles_from"] + else None + ) + return PreprocessingConfig( + backend=payload["backend"], + requested_spacing_um=float(payload["requested_spacing_um"]), + requested_tile_size_px=int(payload["requested_tile_size_px"]), + requested_region_size_px=( + int(payload["requested_region_size_px"]) + if "requested_region_size_px" in payload and payload["requested_region_size_px"] is not None + else None + ), + region_tile_multiple=( + int(payload["region_tile_multiple"]) + if "region_tile_multiple" in payload and payload["region_tile_multiple"] is not None + else None + ), + tolerance=float(payload["tolerance"]), + overlap=float(payload["overlap"]), + tissue_threshold=float(payload["tissue_threshold"]), + read_coordinates_from=read_coordinates_from, + read_tiles_from=read_tiles_from, + resume=bool(payload["resume"]) if "resume" in payload else False, + segmentation=dict(payload["segmentation"]) if "segmentation" in payload else {}, + filtering=dict(payload["filtering"]) if "filtering" in payload else {}, + preview=dict(payload["preview"]) if "preview" in payload else {}, + ) + + +def deserialize_execution(payload: dict[str, Any]) -> ExecutionOptions: + output_dir = payload["output_dir"] if "output_dir" in payload else None + batch_size = payload["batch_size"] if "batch_size" in payload else None + num_workers = payload["num_workers"] if "num_workers" in payload else None + num_gpus = payload["num_gpus"] if "num_gpus" in payload else 1 + precision = payload["precision"] if "precision" in payload else "fp32" + prefetch_factor = payload["prefetch_factor"] if "prefetch_factor" in payload else 4 + persistent_workers = ( + bool(payload["persistent_workers"]) if "persistent_workers" in payload else True + ) + save_tile_embeddings = ( + bool(payload["save_tile_embeddings"]) if "save_tile_embeddings" in payload else False + ) + save_latents = bool(payload["save_latents"]) if "save_latents" in payload else False + return ExecutionOptions( + output_dir=Path(output_dir) if output_dir is not None else None, + output_format=payload["output_format"] if "output_format" in payload else "pt", + batch_size=batch_size, + num_workers=int(num_workers) if num_workers is not None else None, + num_gpus=int(num_gpus), + precision=precision, + prefetch_factor=int(prefetch_factor), + persistent_workers=persistent_workers, + save_tile_embeddings=save_tile_embeddings, + save_latents=save_latents, + ) + diff --git a/slide2vec/runtime/types.py b/slide2vec/runtime/types.py new file mode 100644 index 0000000..2e67823 --- /dev/null +++ b/slide2vec/runtime/types.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import Any + +import numpy as np + + +@dataclass(frozen=True, kw_only=True) +class BatchTransformSpec: + resize_size: tuple[int, int] | None + center_crop_size: tuple[int, int] | None + mean: tuple[float, ...] | None + std: tuple[float, ...] | None + resize_interpolation: str = "bilinear" + + +@dataclass(kw_only=True) +class PreparedBatch: + indices: Any + image: Any + loader_wait_ms: float + preprocess_ms: float + ready_wait_ms: float = 0.0 + worker_batch_ms: float = 0.0 + reader_open_ms: float = 0.0 + reader_read_ms: float = 0.0 + + +@dataclass(frozen=True, kw_only=True) +class HierarchicalIndex: + flat_index: np.ndarray + region_index: np.ndarray + subtile_index_within_region: np.ndarray + subtile_x: np.ndarray + subtile_y: np.ndarray + num_regions: int + tiles_per_region: int + diff --git a/tasks/lessons.md b/tasks/lessons.md index 25a9736..ec3d646 100644 --- a/tasks/lessons.md +++ b/tasks/lessons.md @@ -2,6 +2,8 @@ ## 2026-04-18 +- Prefer neutral package names like `runtime/` for internal implementation modules unless the user explicitly wants a private-style namespace; leading underscores in directory names read as accidental or overly internal. + - When slide2vec depends on bridged HS2P progress events, keep the bridge whitelist in sync with every reporter stage the UI renders; otherwise the code can define a preview bar and still never receive preview events. ## 2026-04-18 diff --git a/tests/test_architecture_runtime_split.py b/tests/test_architecture_runtime_split.py new file mode 100644 index 0000000..fdd084b --- /dev/null +++ b/tests/test_architecture_runtime_split.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from pathlib import Path + + +def test_internal_runtime_modules_stay_small(): + package_root = Path(__file__).resolve().parents[1] / "slide2vec" / "runtime" + module_paths = [ + package_root / "batching.py", + package_root / "hierarchical.py", + package_root / "progress_bridge.py", + package_root / "serialization.py", + package_root / "types.py", + ] + for module_path in module_paths: + line_count = len(module_path.read_text(encoding="utf-8").splitlines()) + assert line_count <= 500, f"{module_path.name} grew beyond the workflow-module size target ({line_count} > 500)" + + +def test_inference_module_is_orchestration_sized(): + inference_path = Path(__file__).resolve().parents[1] / "slide2vec" / "inference.py" + line_count = len(inference_path.read_text(encoding="utf-8").splitlines()) + assert line_count <= 3200, f"inference.py should stay below the legacy monolith size (current lines: {line_count})" From ea8cab8c07a485b55e11ab6d98ba3bfdd6aba775 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 21:27:22 +0200 Subject: [PATCH 02/11] Extract distributed inference helpers into runtime module --- docs/documentation.md | 1 + slide2vec/inference.py | 172 +++++--------------- slide2vec/runtime/distributed.py | 194 +++++++++++++++++++++++ tests/test_architecture_runtime_split.py | 1 + 4 files changed, 240 insertions(+), 128 deletions(-) create mode 100644 slide2vec/runtime/distributed.py diff --git a/docs/documentation.md b/docs/documentation.md index 7cab558..bcb032b 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -4,6 +4,7 @@ - Split `slide2vec.inference` into workflow-scoped internal runtime helpers under `slide2vec.runtime` (`batching`, `hierarchical`, `progress_bridge`, `serialization`, `types`) while keeping `slide2vec.inference` as the stable orchestration entrypoint. - Added architecture guardrail tests that keep workflow helpers bounded (soft target around 400 lines, enforced ceiling 500) and prevent `slide2vec/inference.py` from regressing toward the previous monolith size. +- Extracted distributed torchrun orchestration, shard merge/loading, and rank-assignment helpers into `slide2vec.runtime.distributed`, with inference-level compatibility shims preserved for existing tests and monkeypatch patterns. - Aligned slide2vec with hs2p 4.0.0's unified tiling/sampling contract by preserving the new `annotation` column in process lists and translating preview configs to hs2p's `save_mask_preview` / `save_tiling_preview` / `tissue_contour_color` fields. diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 1b84491..1c43dea 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -1,10 +1,7 @@ import json import importlib -import heapq import os -import shutil import subprocess -import sys import tempfile import threading import time @@ -38,6 +35,17 @@ should_suppress_cucim_dataloader_stderr as _should_suppress_cucim_dataloader_stderr, uses_cuda_runtime as _uses_cuda_runtime, ) +from slide2vec.runtime.distributed import ( + assign_slides_to_ranks as _assign_slides_to_ranks_runtime, + distributed_coordination_dir as _distributed_coordination_dir_runtime, + load_embedded_slide_payload as _load_embedded_slide_payload_runtime, + load_hierarchical_embedding_shards as _load_hierarchical_embedding_shards_runtime, + load_tile_embedding_shards as _load_tile_embedding_shards_runtime, + merge_hierarchical_embedding_shards as _merge_hierarchical_embedding_shards_runtime, + merge_tile_embedding_shards as _merge_tile_embedding_shards_runtime, + reset_progress_event_logs as _reset_progress_event_logs_runtime, + run_torchrun_worker as _run_torchrun_worker_runtime, +) from slide2vec.runtime.hierarchical import ( build_hierarchical_index as _build_hierarchical_index, is_hierarchical_preprocessing as _is_hierarchical_preprocessing, @@ -85,8 +93,6 @@ from slide2vec.runtime_types import LoadedModel from slide2vec.progress import ( emit_progress, - emit_progress_event, - read_progress_events, read_tiling_progress_snapshot, ) from slide2vec.utils.coordinates import coordinate_arrays @@ -2650,11 +2656,8 @@ def _embed_multi_slides_distributed( @contextmanager def _distributed_coordination_dir(work_dir: Path): - coordination_dir = Path(tempfile.mkdtemp(prefix="slide2vec-dist-", dir=work_dir)) - try: + with _distributed_coordination_dir_runtime(work_dir) as coordination_dir: yield coordination_dir - finally: - shutil.rmtree(coordination_dir, ignore_errors=True) def _run_distributed_direct_embedding_stage( @@ -2701,58 +2704,15 @@ def _run_torchrun_worker( failure_title: str, progress_events_path: Path | None = None, ) -> None: - command = [ - sys.executable, - "-m", - "torch.distributed.run", - f"--nproc_per_node={execution.num_gpus}", - "-m", - module, - "--output-dir", - str(output_dir), - "--request-path", - str(request_path), - ] - process = subprocess.Popen( - command, - cwd=str(Path(__file__).resolve().parents[1]), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - bufsize=1, - ) - stdout_chunks: list[str] = [] - stderr_chunks: list[str] = [] - stdout_thread = threading.Thread(target=_drain_stream_to_buffer, args=(process.stdout, stdout_chunks), daemon=True) - stderr_thread = threading.Thread(target=_drain_stream_to_buffer, args=(process.stderr, stderr_chunks), daemon=True) - stdout_thread.start() - stderr_thread.start() - offsets: dict[Path, int] = {} - while process.poll() is None: - if progress_events_path is not None: - events, offsets = read_progress_events(progress_events_path, offsets=offsets) - for event in events: - emit_progress_event(event) - time.sleep(0.1) - if progress_events_path is not None: - events, offsets = read_progress_events(progress_events_path, offsets=offsets) - for event in events: - emit_progress_event(event) - returncode = process.wait() - stdout_thread.join(timeout=1.0) - stderr_thread.join(timeout=1.0) - stdout_text = "".join(stdout_chunks) - stderr_text = "".join(stderr_chunks) - stdout_log_path, stderr_log_path = _write_worker_logs(module, output_dir, stdout_text, stderr_text) - if returncode != 0: - raise RuntimeError( - f"{failure_title}.\n" - f"See logs:\n" - f"stdout: {stdout_log_path}\n" - f"stderr: {stderr_log_path}\n" - f"stdout:\n{stdout_text}\n" - f"stderr:\n{stderr_text}" - ) + _run_torchrun_worker_runtime( + module=module, + num_gpus=execution.num_gpus, + output_dir=output_dir, + request_path=request_path, + failure_title=failure_title, + progress_events_path=progress_events_path, + popen_factory=subprocess.Popen, + ) def _build_pipeline_worker_request_payload( @@ -2809,33 +2769,20 @@ def _build_direct_embed_worker_request_payload( def _reset_progress_event_logs(progress_events_path: Path) -> None: - progress_events_path.parent.mkdir(parents=True, exist_ok=True) - for path in [progress_events_path, *progress_events_path.parent.glob(f"{progress_events_path.stem}.rank*{progress_events_path.suffix}")]: - if path.exists(): - path.unlink() + _reset_progress_event_logs_runtime(progress_events_path) def _drain_stream_to_buffer(stream, chunks: list[str]) -> None: - if stream is None: - return - try: - for line in iter(stream.readline, ""): - if line == "": - break - chunks.append(line) - finally: - stream.close() + # Compatibility shim for tests monkeypatching this helper in slide2vec.inference. + from slide2vec.runtime.distributed import drain_stream_to_buffer as _drain_stream_to_buffer_runtime + + _drain_stream_to_buffer_runtime(stream, chunks) def _write_worker_logs(module: str, output_dir: Path, stdout_text: str, stderr_text: str) -> tuple[Path, Path]: - logs_dir = output_dir / "logs" - logs_dir.mkdir(parents=True, exist_ok=True) - module_name = module.rsplit(".", 1)[-1] - stdout_log_path = logs_dir / f"{module_name}.stdout.log" - stderr_log_path = logs_dir / f"{module_name}.stderr.log" - stdout_log_path.write_text(stdout_text, encoding="utf-8") - stderr_log_path.write_text(stderr_text, encoding="utf-8") - return stdout_log_path, stderr_log_path + from slide2vec.runtime.distributed import write_worker_logs as _write_worker_logs_runtime + + return _write_worker_logs_runtime(module, output_dir, stdout_text, stderr_text) def _assign_slides_to_ranks( @@ -2844,31 +2791,16 @@ def _assign_slides_to_ranks( *, num_gpus: int, ) -> dict[int, list[str]]: - assignments: dict[int, list[str]] = {rank: [] for rank in range(num_gpus)} - assigned_ranks = [(0, rank) for rank in range(num_gpus)] - heapq.heapify(assigned_ranks) - sortable = [] - for slide, tiling_result in zip(slide_records, tiling_results): - sortable.append((slide.sample_id, _num_tiles(tiling_result))) - for sample_id, num_tiles in sorted(sortable, key=lambda item: (-item[1], item[0])): - assigned_tiles, rank = heapq.heappop(assigned_ranks) - assignments[rank].append(sample_id) - heapq.heappush(assigned_ranks, (assigned_tiles + int(num_tiles), rank)) - return assignments + return _assign_slides_to_ranks_runtime( + slide_records, + tiling_results, + num_gpus=num_gpus, + num_tiles_fn=_num_tiles, + ) def _merge_tile_embedding_shards(shard_payloads): - if not shard_payloads: - raise ValueError("No tile embedding shards were produced") - indices = np.concatenate([np.asarray(payload["tile_index"], dtype=np.int64) for payload in shard_payloads], axis=0) - order = np.argsort(indices, kind="stable") - embeddings = [payload["tile_embeddings"] for payload in shard_payloads] - first = embeddings[0] - if torch.is_tensor(first): - merged = torch.cat(embeddings, dim=0) - return merged[torch.as_tensor(order, dtype=torch.long)] - merged = np.concatenate([np.asarray(embedding) for embedding in embeddings], axis=0) - return merged[order] + return _merge_tile_embedding_shards_runtime(shard_payloads) def _merge_hierarchical_embedding_shards( @@ -2877,39 +2809,23 @@ def _merge_hierarchical_embedding_shards( num_regions: int, tiles_per_region: int, ): - if not shard_payloads: - raise ValueError("No hierarchical embedding shards were produced") - indices = np.concatenate( - [np.asarray(payload["flat_index"], dtype=np.int64) for payload in shard_payloads], - axis=0, - ) - order = np.argsort(indices, kind="stable") - embeddings = [payload["tile_embeddings"] for payload in shard_payloads] - first = embeddings[0] - if torch.is_tensor(first): - merged = torch.cat(embeddings, dim=0) - merged = merged[torch.as_tensor(order, dtype=torch.long)] - return merged.reshape(int(num_regions), int(tiles_per_region), int(merged.shape[-1])) - merged = np.concatenate([np.asarray(embedding) for embedding in embeddings], axis=0) - merged = merged[order] - return merged.reshape(int(num_regions), int(tiles_per_region), int(merged.shape[-1])) + return _merge_hierarchical_embedding_shards_runtime( + shard_payloads, + num_regions=num_regions, + tiles_per_region=tiles_per_region, + ) def _load_tile_embedding_shards(coordination_dir: Path, sample_id: str): - - shard_paths = sorted(coordination_dir.glob(f"{sample_id}.tiles.rank*.pt")) - return [torch.load(path, map_location="cpu", weights_only=True) for path in shard_paths] + return _load_tile_embedding_shards_runtime(coordination_dir, sample_id) def _load_hierarchical_embedding_shards(coordination_dir: Path, sample_id: str): - shard_paths = sorted(coordination_dir.glob(f"{sample_id}.hier.rank*.pt")) - return [torch.load(path, map_location="cpu", weights_only=True) for path in shard_paths] + return _load_hierarchical_embedding_shards_runtime(coordination_dir, sample_id) def _load_embedded_slide_payload(coordination_dir: Path, sample_id: str): - - payload_path = coordination_dir / f"{sample_id}.embedded.pt" - return torch.load(payload_path, map_location="cpu", weights_only=True) + return _load_embedded_slide_payload_runtime(coordination_dir, sample_id) def _collect_pipeline_artifacts( diff --git a/slide2vec/runtime/distributed.py b/slide2vec/runtime/distributed.py new file mode 100644 index 0000000..cc83064 --- /dev/null +++ b/slide2vec/runtime/distributed.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +import heapq +import shutil +import subprocess +import sys +import tempfile +import threading +import time +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Sequence + +import numpy as np +import torch +from hs2p import SlideSpec + +from slide2vec.progress import emit_progress_event, read_progress_events + + +@contextmanager +def distributed_coordination_dir(work_dir: Path): + coordination_dir = Path(tempfile.mkdtemp(prefix="slide2vec-dist-", dir=work_dir)) + try: + yield coordination_dir + finally: + shutil.rmtree(coordination_dir, ignore_errors=True) + + +def reset_progress_event_logs(progress_events_path: Path) -> None: + progress_events_path.parent.mkdir(parents=True, exist_ok=True) + for path in [progress_events_path, *progress_events_path.parent.glob(f"{progress_events_path.stem}.rank*{progress_events_path.suffix}")]: + if path.exists(): + path.unlink() + + +def drain_stream_to_buffer(stream, chunks: list[str]) -> None: + if stream is None: + return + try: + for line in iter(stream.readline, ""): + if line == "": + break + chunks.append(line) + finally: + stream.close() + + +def write_worker_logs(module: str, output_dir: Path, stdout_text: str, stderr_text: str) -> tuple[Path, Path]: + logs_dir = output_dir / "logs" + logs_dir.mkdir(parents=True, exist_ok=True) + module_name = module.rsplit(".", 1)[-1] + stdout_log_path = logs_dir / f"{module_name}.stdout.log" + stderr_log_path = logs_dir / f"{module_name}.stderr.log" + stdout_log_path.write_text(stdout_text, encoding="utf-8") + stderr_log_path.write_text(stderr_text, encoding="utf-8") + return stdout_log_path, stderr_log_path + + +def run_torchrun_worker( + *, + module: str, + num_gpus: int, + output_dir: Path, + request_path: Path, + failure_title: str, + progress_events_path: Path | None = None, + popen_factory=subprocess.Popen, +) -> None: + command = [ + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc_per_node={num_gpus}", + "-m", + module, + "--output-dir", + str(output_dir), + "--request-path", + str(request_path), + ] + process = popen_factory( + command, + cwd=str(Path(__file__).resolve().parents[2]), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + stdout_chunks: list[str] = [] + stderr_chunks: list[str] = [] + stdout_thread = threading.Thread(target=drain_stream_to_buffer, args=(process.stdout, stdout_chunks), daemon=True) + stderr_thread = threading.Thread(target=drain_stream_to_buffer, args=(process.stderr, stderr_chunks), daemon=True) + stdout_thread.start() + stderr_thread.start() + offsets: dict[Path, int] = {} + while process.poll() is None: + if progress_events_path is not None: + events, offsets = read_progress_events(progress_events_path, offsets=offsets) + for event in events: + emit_progress_event(event) + time.sleep(0.1) + if progress_events_path is not None: + events, offsets = read_progress_events(progress_events_path, offsets=offsets) + for event in events: + emit_progress_event(event) + returncode = process.wait() + stdout_thread.join(timeout=1.0) + stderr_thread.join(timeout=1.0) + stdout_text = "".join(stdout_chunks) + stderr_text = "".join(stderr_chunks) + stdout_log_path, stderr_log_path = write_worker_logs(module, output_dir, stdout_text, stderr_text) + if returncode != 0: + raise RuntimeError( + f"{failure_title}.\n" + f"See logs:\n" + f"stdout: {stdout_log_path}\n" + f"stderr: {stderr_log_path}\n" + f"stdout:\n{stdout_text}\n" + f"stderr:\n{stderr_text}" + ) + + +def assign_slides_to_ranks( + slide_records: Sequence[SlideSpec], + tiling_results, + *, + num_gpus: int, + num_tiles_fn, +) -> dict[int, list[str]]: + assignments: dict[int, list[str]] = {rank: [] for rank in range(num_gpus)} + assigned_ranks = [(0, rank) for rank in range(num_gpus)] + heapq.heapify(assigned_ranks) + sortable = [] + for slide, tiling_result in zip(slide_records, tiling_results): + sortable.append((slide.sample_id, num_tiles_fn(tiling_result))) + for sample_id, num_tiles in sorted(sortable, key=lambda item: (-item[1], item[0])): + assigned_tiles, rank = heapq.heappop(assigned_ranks) + assignments[rank].append(sample_id) + heapq.heappush(assigned_ranks, (assigned_tiles + int(num_tiles), rank)) + return assignments + + +def merge_tile_embedding_shards(shard_payloads): + if not shard_payloads: + raise ValueError("No tile embedding shards were produced") + indices = np.concatenate([np.asarray(payload["tile_index"], dtype=np.int64) for payload in shard_payloads], axis=0) + order = np.argsort(indices, kind="stable") + embeddings = [payload["tile_embeddings"] for payload in shard_payloads] + first = embeddings[0] + if torch.is_tensor(first): + merged = torch.cat(embeddings, dim=0) + return merged[torch.as_tensor(order, dtype=torch.long)] + merged = np.concatenate([np.asarray(embedding) for embedding in embeddings], axis=0) + return merged[order] + + +def merge_hierarchical_embedding_shards( + shard_payloads, + *, + num_regions: int, + tiles_per_region: int, +): + if not shard_payloads: + raise ValueError("No hierarchical embedding shards were produced") + indices = np.concatenate( + [np.asarray(payload["flat_index"], dtype=np.int64) for payload in shard_payloads], + axis=0, + ) + order = np.argsort(indices, kind="stable") + embeddings = [payload["tile_embeddings"] for payload in shard_payloads] + first = embeddings[0] + if torch.is_tensor(first): + merged = torch.cat(embeddings, dim=0) + merged = merged[torch.as_tensor(order, dtype=torch.long)] + return merged.reshape(int(num_regions), int(tiles_per_region), int(merged.shape[-1])) + merged = np.concatenate([np.asarray(embedding) for embedding in embeddings], axis=0) + merged = merged[order] + return merged.reshape(int(num_regions), int(tiles_per_region), int(merged.shape[-1])) + + +def load_tile_embedding_shards(coordination_dir: Path, sample_id: str): + shard_paths = sorted(coordination_dir.glob(f"{sample_id}.tiles.rank*.pt")) + return [torch.load(path, map_location="cpu", weights_only=True) for path in shard_paths] + + +def load_hierarchical_embedding_shards(coordination_dir: Path, sample_id: str): + shard_paths = sorted(coordination_dir.glob(f"{sample_id}.hier.rank*.pt")) + return [torch.load(path, map_location="cpu", weights_only=True) for path in shard_paths] + + +def load_embedded_slide_payload(coordination_dir: Path, sample_id: str): + payload_path = coordination_dir / f"{sample_id}.embedded.pt" + return torch.load(payload_path, map_location="cpu", weights_only=True) diff --git a/tests/test_architecture_runtime_split.py b/tests/test_architecture_runtime_split.py index fdd084b..a802b0c 100644 --- a/tests/test_architecture_runtime_split.py +++ b/tests/test_architecture_runtime_split.py @@ -7,6 +7,7 @@ def test_internal_runtime_modules_stay_small(): package_root = Path(__file__).resolve().parents[1] / "slide2vec" / "runtime" module_paths = [ package_root / "batching.py", + package_root / "distributed.py", package_root / "hierarchical.py", package_root / "progress_bridge.py", package_root / "serialization.py", From 363bd563c514c91f41a6d7b8f7fe99694ee7383a Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 21:29:38 +0200 Subject: [PATCH 03/11] Extract pipeline artifact persistence from inference --- docs/documentation.md | 1 + slide2vec/inference.py | 156 +-------------------- slide2vec/runtime/persistence.py | 165 +++++++++++++++++++++++ tests/test_architecture_runtime_split.py | 1 + 4 files changed, 171 insertions(+), 152 deletions(-) create mode 100644 slide2vec/runtime/persistence.py diff --git a/docs/documentation.md b/docs/documentation.md index bcb032b..a233250 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -5,6 +5,7 @@ - Split `slide2vec.inference` into workflow-scoped internal runtime helpers under `slide2vec.runtime` (`batching`, `hierarchical`, `progress_bridge`, `serialization`, `types`) while keeping `slide2vec.inference` as the stable orchestration entrypoint. - Added architecture guardrail tests that keep workflow helpers bounded (soft target around 400 lines, enforced ceiling 500) and prevent `slide2vec/inference.py` from regressing toward the previous monolith size. - Extracted distributed torchrun orchestration, shard merge/loading, and rank-assignment helpers into `slide2vec.runtime.distributed`, with inference-level compatibility shims preserved for existing tests and monkeypatch patterns. +- Moved artifact collection/loading and process-list embedding status updates into `slide2vec.runtime.persistence` so the pipeline orchestration flow in `slide2vec.inference` stays focused on control flow. - Aligned slide2vec with hs2p 4.0.0's unified tiling/sampling contract by preserving the new `annotation` column in process lists and translating preview configs to hs2p's `save_mask_preview` / `save_tiling_preview` / `tissue_contour_color` fields. diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 1c43dea..7eb8a69 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -53,6 +53,10 @@ num_tiles as _num_tiles, resolve_hierarchical_geometry as _resolve_hierarchical_geometry, ) +from slide2vec.runtime.persistence import ( + collect_pipeline_artifacts as _collect_pipeline_artifacts, + update_process_list_after_embedding as _update_process_list_after_embedding, +) from slide2vec.runtime.progress_bridge import ( bridge_hs2p_progress_to_slide2vec as _bridge_hs2p_progress_to_slide2vec, ) @@ -78,7 +82,6 @@ TileEmbeddingArtifact, write_hierarchical_embeddings, load_array, - load_metadata, write_patient_embeddings, write_slide_embeddings, write_tile_embedding_metadata, @@ -2828,157 +2831,6 @@ def _load_embedded_slide_payload(coordination_dir: Path, sample_id: str): return _load_embedded_slide_payload_runtime(coordination_dir, sample_id) -def _collect_pipeline_artifacts( - slide_records: Sequence[SlideSpec], - *, - output_dir: Path, - output_format: str, - include_tile_embeddings: bool, - include_hierarchical_embeddings: bool, - include_slide_embeddings: bool, -) -> tuple[ - list[TileEmbeddingArtifact], - list[HierarchicalEmbeddingArtifact], - list[SlideEmbeddingArtifact], -]: - tile_artifacts: list[TileEmbeddingArtifact] = [] - hierarchical_artifacts: list[HierarchicalEmbeddingArtifact] = [] - slide_artifacts: list[SlideEmbeddingArtifact] = [] - for slide in slide_records: - if include_tile_embeddings: - tile_artifacts.append(_load_tile_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format)) - if include_hierarchical_embeddings: - hierarchical_artifacts.append( - _load_hierarchical_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format) - ) - if include_slide_embeddings: - slide_artifacts.append( - _load_slide_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format) - ) - return tile_artifacts, hierarchical_artifacts, slide_artifacts - - -def _load_tile_artifact(sample_id: str, *, output_dir: Path, output_format: str) -> TileEmbeddingArtifact: - artifact_path = output_dir / "tile_embeddings" / f"{sample_id}.{output_format}" - metadata_path = output_dir / "tile_embeddings" / f"{sample_id}.meta.json" - metadata = load_metadata(metadata_path) - return TileEmbeddingArtifact( - sample_id=sample_id, - path=artifact_path, - metadata_path=metadata_path, - format=output_format, - feature_dim=int(metadata["feature_dim"]), - num_tiles=int(metadata["num_tiles"]), - ) - - -def _load_hierarchical_artifact( - sample_id: str, - *, - output_dir: Path, - output_format: str, -) -> HierarchicalEmbeddingArtifact: - artifact_path = output_dir / "hierarchical_embeddings" / f"{sample_id}.{output_format}" - metadata_path = output_dir / "hierarchical_embeddings" / f"{sample_id}.meta.json" - metadata = load_metadata(metadata_path) - return HierarchicalEmbeddingArtifact( - sample_id=sample_id, - path=artifact_path, - metadata_path=metadata_path, - format=output_format, - feature_dim=int(metadata["feature_dim"]), - num_regions=int(metadata["num_regions"]), - tiles_per_region=int(metadata["tiles_per_region"]), - ) - - -def _load_slide_artifact(sample_id: str, *, output_dir: Path, output_format: str) -> SlideEmbeddingArtifact: - artifact_path = output_dir / "slide_embeddings" / f"{sample_id}.{output_format}" - metadata_path = output_dir / "slide_embeddings" / f"{sample_id}.meta.json" - metadata = load_metadata(metadata_path) - latent_suffix = "pt" if output_format == "pt" else "npz" - latent_path = output_dir / "slide_latents" / f"{sample_id}.{latent_suffix}" - return SlideEmbeddingArtifact( - sample_id=sample_id, - path=artifact_path, - metadata_path=metadata_path, - format=output_format, - feature_dim=int(metadata["feature_dim"]), - latent_path=latent_path if latent_path.is_file() else None, - ) - - -def _update_process_list_after_embedding( - process_list_path: Path, - *, - successful_slides: Sequence[SlideSpec], - persist_tile_embeddings: bool, - persist_hierarchical_embeddings: bool, - include_slide_embeddings: bool, - encoder_name: str, - output_variant: str | None, - tile_artifacts: Sequence[TileEmbeddingArtifact], - hierarchical_artifacts: Sequence[HierarchicalEmbeddingArtifact], - slide_artifacts: Sequence[SlideEmbeddingArtifact], -) -> None: - def _resolve_path_str(value: Any) -> str | None: - if value is None or pd.isna(value): - return None - return str(Path(value).resolve()) - - df = pd.read_csv(process_list_path) - if "feature_status" not in df.columns: - df["feature_status"] = ["tbp"] * len(df) - if "feature_path" not in df.columns: - df["feature_path"] = [None] * len(df) - if "encoder_name" not in df.columns: - df["encoder_name"] = [None] * len(df) - if "output_variant" not in df.columns: - df["output_variant"] = [None] * len(df) - if "feature_kind" not in df.columns: - df["feature_kind"] = [None] * len(df) - if include_slide_embeddings and "aggregation_status" not in df.columns: - df["aggregation_status"] = ["tbp"] * len(df) - tile_success_ids = {artifact.sample_id for artifact in tile_artifacts} - hierarchical_success_ids = {artifact.sample_id for artifact in hierarchical_artifacts} - slide_success_ids = {artifact.sample_id for artifact in slide_artifacts} - if slide_artifacts: - feature_path_by_sample_id = {artifact.sample_id: _resolve_path_str(artifact.path) for artifact in slide_artifacts} - feature_kind = "slide" - feature_success_ids = slide_success_ids - elif persist_hierarchical_embeddings: - feature_path_by_sample_id = { - artifact.sample_id: _resolve_path_str(artifact.path) for artifact in hierarchical_artifacts - } - feature_kind = "hierarchical" - feature_success_ids = hierarchical_success_ids - elif persist_tile_embeddings: - feature_path_by_sample_id = { - artifact.sample_id: _resolve_path_str(artifact.path) for artifact in tile_artifacts - } - feature_kind = "tile" - feature_success_ids = tile_success_ids - else: - feature_path_by_sample_id = {} - feature_kind = None - feature_success_ids = {slide.sample_id for slide in successful_slides} - for slide in successful_slides: - mask = df["sample_id"].astype(str) == slide.sample_id - feature_status = "success" if slide.sample_id in feature_success_ids else "error" - df.loc[mask, "feature_status"] = feature_status - mapped_feature_path = feature_path_by_sample_id.get(slide.sample_id) - if mapped_feature_path is not None: - df.loc[mask, "feature_path"] = mapped_feature_path - df.loc[mask, "encoder_name"] = encoder_name - df.loc[mask, "output_variant"] = output_variant - df.loc[mask, "feature_kind"] = feature_kind - if include_slide_embeddings: - df.loc[mask, "aggregation_status"] = ( - "success" if slide.sample_id in slide_success_ids else "error" - ) - df.to_csv(process_list_path, index=False) - - def load_successful_tiled_slides(output_dir: str | Path) -> tuple[list[SlideSpec], list[Any]]: base_dir = Path(output_dir) process_df = load_tiling_process_df(base_dir / "process_list.csv") diff --git a/slide2vec/runtime/persistence.py b/slide2vec/runtime/persistence.py new file mode 100644 index 0000000..c640093 --- /dev/null +++ b/slide2vec/runtime/persistence.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Sequence + +import pandas as pd +from hs2p import SlideSpec + +from slide2vec.artifacts import ( + HierarchicalEmbeddingArtifact, + SlideEmbeddingArtifact, + TileEmbeddingArtifact, + load_metadata, +) + + +def collect_pipeline_artifacts( + slide_records: Sequence[SlideSpec], + *, + output_dir: Path, + output_format: str, + include_tile_embeddings: bool, + include_hierarchical_embeddings: bool, + include_slide_embeddings: bool, +) -> tuple[ + list[TileEmbeddingArtifact], + list[HierarchicalEmbeddingArtifact], + list[SlideEmbeddingArtifact], +]: + tile_artifacts: list[TileEmbeddingArtifact] = [] + hierarchical_artifacts: list[HierarchicalEmbeddingArtifact] = [] + slide_artifacts: list[SlideEmbeddingArtifact] = [] + for slide in slide_records: + if include_tile_embeddings: + tile_artifacts.append(load_tile_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format)) + if include_hierarchical_embeddings: + hierarchical_artifacts.append( + load_hierarchical_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format) + ) + if include_slide_embeddings: + slide_artifacts.append( + load_slide_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format) + ) + return tile_artifacts, hierarchical_artifacts, slide_artifacts + + +def load_tile_artifact(sample_id: str, *, output_dir: Path, output_format: str) -> TileEmbeddingArtifact: + artifact_path = output_dir / "tile_embeddings" / f"{sample_id}.{output_format}" + metadata_path = output_dir / "tile_embeddings" / f"{sample_id}.meta.json" + metadata = load_metadata(metadata_path) + return TileEmbeddingArtifact( + sample_id=sample_id, + path=artifact_path, + metadata_path=metadata_path, + format=output_format, + feature_dim=int(metadata["feature_dim"]), + num_tiles=int(metadata["num_tiles"]), + ) + + +def load_hierarchical_artifact( + sample_id: str, + *, + output_dir: Path, + output_format: str, +) -> HierarchicalEmbeddingArtifact: + artifact_path = output_dir / "hierarchical_embeddings" / f"{sample_id}.{output_format}" + metadata_path = output_dir / "hierarchical_embeddings" / f"{sample_id}.meta.json" + metadata = load_metadata(metadata_path) + return HierarchicalEmbeddingArtifact( + sample_id=sample_id, + path=artifact_path, + metadata_path=metadata_path, + format=output_format, + feature_dim=int(metadata["feature_dim"]), + num_regions=int(metadata["num_regions"]), + tiles_per_region=int(metadata["tiles_per_region"]), + ) + + +def load_slide_artifact(sample_id: str, *, output_dir: Path, output_format: str) -> SlideEmbeddingArtifact: + artifact_path = output_dir / "slide_embeddings" / f"{sample_id}.{output_format}" + metadata_path = output_dir / "slide_embeddings" / f"{sample_id}.meta.json" + metadata = load_metadata(metadata_path) + latent_suffix = "pt" if output_format == "pt" else "npz" + latent_path = output_dir / "slide_latents" / f"{sample_id}.{latent_suffix}" + return SlideEmbeddingArtifact( + sample_id=sample_id, + path=artifact_path, + metadata_path=metadata_path, + format=output_format, + feature_dim=int(metadata["feature_dim"]), + latent_path=latent_path if latent_path.is_file() else None, + ) + + +def update_process_list_after_embedding( + process_list_path: Path, + *, + successful_slides: Sequence[SlideSpec], + persist_tile_embeddings: bool, + persist_hierarchical_embeddings: bool, + include_slide_embeddings: bool, + encoder_name: str, + output_variant: str | None, + tile_artifacts: Sequence[TileEmbeddingArtifact], + hierarchical_artifacts: Sequence[HierarchicalEmbeddingArtifact], + slide_artifacts: Sequence[SlideEmbeddingArtifact], +) -> None: + def _resolve_path_str(value: Any) -> str | None: + if value is None or pd.isna(value): + return None + return str(Path(value).resolve()) + + df = pd.read_csv(process_list_path) + if "feature_status" not in df.columns: + df["feature_status"] = ["tbp"] * len(df) + if "feature_path" not in df.columns: + df["feature_path"] = [None] * len(df) + if "encoder_name" not in df.columns: + df["encoder_name"] = [None] * len(df) + if "output_variant" not in df.columns: + df["output_variant"] = [None] * len(df) + if "feature_kind" not in df.columns: + df["feature_kind"] = [None] * len(df) + if include_slide_embeddings and "aggregation_status" not in df.columns: + df["aggregation_status"] = ["tbp"] * len(df) + tile_success_ids = {artifact.sample_id for artifact in tile_artifacts} + hierarchical_success_ids = {artifact.sample_id for artifact in hierarchical_artifacts} + slide_success_ids = {artifact.sample_id for artifact in slide_artifacts} + if slide_artifacts: + feature_path_by_sample_id = {artifact.sample_id: _resolve_path_str(artifact.path) for artifact in slide_artifacts} + feature_kind = "slide" + feature_success_ids = slide_success_ids + elif persist_hierarchical_embeddings: + feature_path_by_sample_id = { + artifact.sample_id: _resolve_path_str(artifact.path) for artifact in hierarchical_artifacts + } + feature_kind = "hierarchical" + feature_success_ids = hierarchical_success_ids + elif persist_tile_embeddings: + feature_path_by_sample_id = { + artifact.sample_id: _resolve_path_str(artifact.path) for artifact in tile_artifacts + } + feature_kind = "tile" + feature_success_ids = tile_success_ids + else: + feature_path_by_sample_id = {} + feature_kind = None + feature_success_ids = {slide.sample_id for slide in successful_slides} + for slide in successful_slides: + mask = df["sample_id"].astype(str) == slide.sample_id + feature_status = "success" if slide.sample_id in feature_success_ids else "error" + df.loc[mask, "feature_status"] = feature_status + mapped_feature_path = feature_path_by_sample_id.get(slide.sample_id) + if mapped_feature_path is not None: + df.loc[mask, "feature_path"] = mapped_feature_path + df.loc[mask, "encoder_name"] = encoder_name + df.loc[mask, "output_variant"] = output_variant + df.loc[mask, "feature_kind"] = feature_kind + if include_slide_embeddings: + df.loc[mask, "aggregation_status"] = ( + "success" if slide.sample_id in slide_success_ids else "error" + ) + df.to_csv(process_list_path, index=False) diff --git a/tests/test_architecture_runtime_split.py b/tests/test_architecture_runtime_split.py index a802b0c..06c7346 100644 --- a/tests/test_architecture_runtime_split.py +++ b/tests/test_architecture_runtime_split.py @@ -9,6 +9,7 @@ def test_internal_runtime_modules_stay_small(): package_root / "batching.py", package_root / "distributed.py", package_root / "hierarchical.py", + package_root / "persistence.py", package_root / "progress_bridge.py", package_root / "serialization.py", package_root / "types.py", From 592fd9fad9b99e4b3b216791f86a41f286860042 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 21:38:43 +0200 Subject: [PATCH 04/11] Extract tiling and embedding helper domains into runtime --- docs/documentation.md | 1 + slide2vec/inference.py | 183 ++++++++--------------- slide2vec/runtime/embedding.py | 153 +++++++++++++++++++ slide2vec/runtime/tiling.py | 100 +++++++++++++ tests/test_architecture_runtime_split.py | 18 +++ 5 files changed, 338 insertions(+), 117 deletions(-) create mode 100644 slide2vec/runtime/embedding.py create mode 100644 slide2vec/runtime/tiling.py diff --git a/docs/documentation.md b/docs/documentation.md index a233250..0bbb24c 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -6,6 +6,7 @@ - Added architecture guardrail tests that keep workflow helpers bounded (soft target around 400 lines, enforced ceiling 500) and prevent `slide2vec/inference.py` from regressing toward the previous monolith size. - Extracted distributed torchrun orchestration, shard merge/loading, and rank-assignment helpers into `slide2vec.runtime.distributed`, with inference-level compatibility shims preserved for existing tests and monkeypatch patterns. - Moved artifact collection/loading and process-list embedding status updates into `slide2vec.runtime.persistence` so the pipeline orchestration flow in `slide2vec.inference` stays focused on control flow. +- Extracted pure tiling and embedding metadata/writer helpers into `slide2vec.runtime.tiling` and `slide2vec.runtime.embedding`, keeping inference-level wrappers so existing monkeypatch-based regression tests remain stable. - Aligned slide2vec with hs2p 4.0.0's unified tiling/sampling contract by preserving the new `annotation` column in process lists and translating preview configs to hs2p's `save_mask_preview` / `save_tiling_preview` / `tissue_contour_color` fields. diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 7eb8a69..6e7b9c4 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -35,6 +35,15 @@ should_suppress_cucim_dataloader_stderr as _should_suppress_cucim_dataloader_stderr, uses_cuda_runtime as _uses_cuda_runtime, ) +from slide2vec.runtime.embedding import ( + build_hierarchical_embedding_metadata as _build_hierarchical_embedding_metadata_runtime, + build_slide_embedding_metadata as _build_slide_embedding_metadata_runtime, + build_tile_embedding_metadata as _build_tile_embedding_metadata_runtime, + should_persist_tile_embeddings as _should_persist_tile_embeddings_runtime, + write_hierarchical_embedding_artifact as _write_hierarchical_embedding_artifact_runtime, + write_slide_embedding_artifact as _write_slide_embedding_artifact_runtime, + write_tile_embedding_artifact as _write_tile_embedding_artifact_runtime, +) from slide2vec.runtime.distributed import ( assign_slides_to_ranks as _assign_slides_to_ranks_runtime, distributed_coordination_dir as _distributed_coordination_dir_runtime, @@ -67,6 +76,16 @@ serialize_model as _serialize_model, serialize_preprocessing as _serialize_preprocessing, ) +from slide2vec.runtime.tiling import ( + build_hs2p_configs as _build_hs2p_configs_runtime, + build_preview_config as _build_preview_config_runtime, + load_tiling_result as _load_tiling_result_runtime, + resolve_slide_backend as _resolve_slide_backend_runtime, + resolve_tile_store_archive_for_slide as _resolve_tile_store_archive_for_slide_runtime, + resolve_tiling_backend as _resolve_tiling_backend_runtime, + scale_coordinates as _scale_coordinates_runtime, + tile_store_archive_path as _tile_store_archive_path_runtime, +) from slide2vec.api import ( EmbeddedPatient, EmbeddedSlide, @@ -83,9 +102,7 @@ write_hierarchical_embeddings, load_array, write_patient_embeddings, - write_slide_embeddings, write_tile_embedding_metadata, - write_tile_embeddings, ) from slide2vec.encoders.registry import ( encoder_registry, @@ -1873,32 +1890,18 @@ def _build_tile_embedding_metadata( tile_size_lv0: int, backend: str, ) -> dict[str, Any]: - coordinates_npz_path = ( - tiling_result.coordinates_npz_path if hasattr(tiling_result, "coordinates_npz_path") else None - ) - coordinates_meta_path = ( - tiling_result.coordinates_meta_path if hasattr(tiling_result, "coordinates_meta_path") else None + return _build_tile_embedding_metadata_runtime( + model, + tiling_result=tiling_result, + image_path=image_path, + mask_path=mask_path, + tile_size_lv0=tile_size_lv0, + backend=backend, ) - tiles_tar_path = tiling_result.tiles_tar_path if hasattr(tiling_result, "tiles_tar_path") else None - return { - "encoder_name": model.name, - "encoder_level": model.level, - "coordinates_npz_path": str(coordinates_npz_path or ""), - "coordinates_meta_path": str(coordinates_meta_path or ""), - "tiles_tar_path": str(tiles_tar_path or ""), - "image_path": str(image_path), - "mask_path": str(mask_path) if mask_path is not None else None, - "tile_size_lv0": int(tile_size_lv0), - "backend": backend, - } def _build_slide_embedding_metadata(model, *, image_path: Path | str) -> dict[str, Any]: - return { - "encoder_name": model.name, - "encoder_level": model.level, - "image_path": str(image_path), - } + return _build_slide_embedding_metadata_runtime(model, image_path=image_path) def _build_hierarchical_embedding_metadata( @@ -1910,29 +1913,15 @@ def _build_hierarchical_embedding_metadata( backend: str, preprocessing: PreprocessingConfig, ) -> dict[str, Any]: - coordinates_npz_path = ( - tiling_result.coordinates_npz_path if hasattr(tiling_result, "coordinates_npz_path") else None - ) - coordinates_meta_path = ( - tiling_result.coordinates_meta_path if hasattr(tiling_result, "coordinates_meta_path") else None + return _build_hierarchical_embedding_metadata_runtime( + model, + tiling_result=tiling_result, + image_path=image_path, + mask_path=mask_path, + backend=backend, + preprocessing=preprocessing, + resolve_hierarchical_geometry_fn=_resolve_hierarchical_geometry, ) - geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) - return { - "encoder_name": model.name, - "encoder_level": model.level, - "coordinates_npz_path": str(coordinates_npz_path or ""), - "coordinates_meta_path": str(coordinates_meta_path or ""), - "image_path": str(image_path), - "mask_path": str(mask_path) if mask_path is not None else None, - "backend": backend, - "region_tile_multiple": int(geometry["region_tile_multiple"]), - "requested_tile_size_px": int(geometry["requested_tile_size_px"]), - "read_tile_size_px": int(geometry["read_tile_size_px"]), - "requested_region_size_px": int(geometry["requested_region_size_px"]), - "read_region_size_px": int(geometry["read_region_size_px"]), - "requested_spacing_um": float(preprocessing.requested_spacing_um), - "subtile_order": "row_major", - } def _write_tile_embedding_artifact( @@ -1942,15 +1931,12 @@ def _write_tile_embedding_artifact( execution: ExecutionOptions, metadata: dict[str, Any], ) -> TileEmbeddingArtifact: - if execution.output_dir is None: - raise ValueError("ExecutionOptions.output_dir is required to persist tile embeddings") - return write_tile_embeddings( + return _write_tile_embedding_artifact_runtime( sample_id, features, - output_dir=execution.output_dir, - output_format=execution.output_format, + execution=execution, metadata=metadata, - tile_index=np.arange(_num_rows(features), dtype=np.int64), + num_rows_fn=_num_rows, ) @@ -1962,13 +1948,10 @@ def _write_slide_embedding_artifact( metadata: dict[str, Any], latents=None, ) -> SlideEmbeddingArtifact: - if execution.output_dir is None: - raise ValueError("ExecutionOptions.output_dir is required to persist slide embeddings") - return write_slide_embeddings( + return _write_slide_embedding_artifact_runtime( sample_id, embedding, - output_dir=execution.output_dir, - output_format=execution.output_format, + execution=execution, metadata=metadata, latents=latents, ) @@ -1981,13 +1964,10 @@ def _write_hierarchical_embedding_artifact( execution: ExecutionOptions, metadata: dict[str, Any], ) -> HierarchicalEmbeddingArtifact: - if execution.output_dir is None: - raise ValueError("ExecutionOptions.output_dir is required to persist hierarchical embeddings") - return write_hierarchical_embeddings( + return _write_hierarchical_embedding_artifact_runtime( sample_id, features, - output_dir=execution.output_dir, - output_format=execution.output_format, + execution=execution, metadata=metadata, ) @@ -2174,9 +2154,7 @@ def _emit_tiling_summary( def _should_persist_tile_embeddings(model, execution: ExecutionOptions) -> bool: - if model.level in {"slide", "patient"}: - return bool(execution.save_tile_embeddings) - return True + return _should_persist_tile_embeddings_runtime(model, execution) def _resolved_process_list_output_variant(model) -> str | None: @@ -2397,39 +2375,18 @@ def _resolve_path_str(value: Any) -> str | None: def _build_preview_config(preview: dict[str, Any]) -> PreviewConfig: - return PreviewConfig( - save_mask_preview=bool(preview["save_mask_preview"]), - save_tiling_preview=bool(preview["save_tiling_preview"]), - downsample=int(preview["downsample"]), - tissue_contour_color=tuple(int(channel) for channel in preview["tissue_contour_color"]), - mask_overlay_alpha=float(preview["mask_overlay_alpha"]), - ) + return _build_preview_config_runtime(preview, preview_config_cls=PreviewConfig) def _build_hs2p_configs(preprocessing: PreprocessingConfig): - requested_tile_size_px = ( - preprocessing.requested_region_size_px - if _is_hierarchical_preprocessing(preprocessing) - else preprocessing.requested_tile_size_px - ) - tiling_cfg = TilingConfig( - backend=_resolve_tiling_backend(preprocessing), - requested_spacing_um=preprocessing.requested_spacing_um, - requested_tile_size_px=requested_tile_size_px, - tolerance=preprocessing.tolerance, - overlap=preprocessing.overlap, - tissue_threshold=preprocessing.tissue_threshold, - ) - segmentation_cfg = SegmentationConfig(**dict(preprocessing.segmentation)) - filtering_cfg = FilterConfig(**dict(preprocessing.filtering)) - preview_cfg = _build_preview_config(dict(preprocessing.preview)) - return ( - tiling_cfg, - segmentation_cfg, - filtering_cfg, - preview_cfg, - preprocessing.read_coordinates_from, - preprocessing.resume, + return _build_hs2p_configs_runtime( + preprocessing, + is_hierarchical_preprocessing_fn=_is_hierarchical_preprocessing, + resolve_tiling_backend_fn=_resolve_tiling_backend, + tiling_config_cls=TilingConfig, + segmentation_config_cls=SegmentationConfig, + filter_config_cls=FilterConfig, + preview_config_cls=PreviewConfig, ) @@ -2439,45 +2396,37 @@ def _resolve_tile_store_archive_for_slide( tiling_result, preprocessing: PreprocessingConfig, ) -> Path | None: - if preprocessing.read_tiles_from is not None: - return _tile_store_archive_path(preprocessing.read_tiles_from, slide.sample_id) - return tiling_result.tiles_tar_path if hasattr(tiling_result, "tiles_tar_path") else None + return _resolve_tile_store_archive_for_slide_runtime( + slide_sample_id=slide.sample_id, + tiling_result=tiling_result, + preprocessing=preprocessing, + ) def _tile_store_archive_path(tile_store_root: Path, sample_id: str) -> Path: - root = Path(tile_store_root) - if root.is_file(): - return root - if root.suffix == ".tar" and root.exists(): - return root - return root / f"{sample_id}.tiles.tar" + return _tile_store_archive_path_runtime(tile_store_root, sample_id) def _load_tiling_result(coordinates_npz_path: Path, coordinates_meta_path: Path): - return load_tiling_result(coordinates_npz_path=coordinates_npz_path, coordinates_meta_path=coordinates_meta_path) + return _load_tiling_result_runtime( + coordinates_npz_path, + coordinates_meta_path, + load_tiling_result_fn=load_tiling_result, + ) def _scale_coordinates(coordinates: np.ndarray, base_spacing_um: float, spacing: float) -> np.ndarray: - scale = base_spacing_um / spacing - return (coordinates * scale).astype(int) + return _scale_coordinates_runtime(coordinates, base_spacing_um, spacing) def _resolve_tiling_backend(preprocessing: PreprocessingConfig | None) -> str: - if preprocessing is None: - return "asap" - return preprocessing.backend + return _resolve_tiling_backend_runtime(preprocessing) def _resolve_slide_backend(preprocessing: PreprocessingConfig | None, tiling_result) -> str: - backend = _resolve_tiling_backend(preprocessing) - if backend != "auto": - return backend - resolved_backend = tiling_result.backend if hasattr(tiling_result, "backend") else None - if isinstance(resolved_backend, str) and resolved_backend and resolved_backend != "auto": - return resolved_backend - return "asap" + return _resolve_slide_backend_runtime(preprocessing, tiling_result) def _resolve_model_preprocessing(model, preprocessing: PreprocessingConfig | None) -> PreprocessingConfig: diff --git a/slide2vec/runtime/embedding.py b/slide2vec/runtime/embedding.py new file mode 100644 index 0000000..01e16d9 --- /dev/null +++ b/slide2vec/runtime/embedding.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable + +import numpy as np + +from slide2vec.api import ExecutionOptions, PreprocessingConfig +from slide2vec.artifacts import ( + HierarchicalEmbeddingArtifact, + SlideEmbeddingArtifact, + TileEmbeddingArtifact, + write_hierarchical_embeddings, + write_slide_embeddings, + write_tile_embeddings, +) + + +def should_persist_tile_embeddings(model, execution: ExecutionOptions) -> bool: + if model.level in {"slide", "patient"}: + return bool(execution.save_tile_embeddings) + return True + + +def build_tile_embedding_metadata( + model, + *, + tiling_result, + image_path: Path | str, + mask_path: Path | str | None, + tile_size_lv0: int, + backend: str, +) -> dict[str, Any]: + coordinates_npz_path = ( + tiling_result.coordinates_npz_path if hasattr(tiling_result, "coordinates_npz_path") else None + ) + coordinates_meta_path = ( + tiling_result.coordinates_meta_path if hasattr(tiling_result, "coordinates_meta_path") else None + ) + tiles_tar_path = tiling_result.tiles_tar_path if hasattr(tiling_result, "tiles_tar_path") else None + return { + "encoder_name": model.name, + "encoder_level": model.level, + "coordinates_npz_path": str(coordinates_npz_path or ""), + "coordinates_meta_path": str(coordinates_meta_path or ""), + "tiles_tar_path": str(tiles_tar_path or ""), + "image_path": str(image_path), + "mask_path": str(mask_path) if mask_path is not None else None, + "tile_size_lv0": int(tile_size_lv0), + "backend": backend, + } + + +def build_slide_embedding_metadata(model, *, image_path: Path | str) -> dict[str, Any]: + return { + "encoder_name": model.name, + "encoder_level": model.level, + "image_path": str(image_path), + } + + +def build_hierarchical_embedding_metadata( + model, + *, + tiling_result, + image_path: Path | str, + mask_path: Path | str | None, + backend: str, + preprocessing: PreprocessingConfig, + resolve_hierarchical_geometry_fn: Callable[[PreprocessingConfig, Any], dict[str, int]], +) -> dict[str, Any]: + coordinates_npz_path = ( + tiling_result.coordinates_npz_path if hasattr(tiling_result, "coordinates_npz_path") else None + ) + coordinates_meta_path = ( + tiling_result.coordinates_meta_path if hasattr(tiling_result, "coordinates_meta_path") else None + ) + geometry = resolve_hierarchical_geometry_fn(preprocessing, tiling_result) + return { + "encoder_name": model.name, + "encoder_level": model.level, + "coordinates_npz_path": str(coordinates_npz_path or ""), + "coordinates_meta_path": str(coordinates_meta_path or ""), + "image_path": str(image_path), + "mask_path": str(mask_path) if mask_path is not None else None, + "backend": backend, + "region_tile_multiple": int(geometry["region_tile_multiple"]), + "requested_tile_size_px": int(geometry["requested_tile_size_px"]), + "read_tile_size_px": int(geometry["read_tile_size_px"]), + "requested_region_size_px": int(geometry["requested_region_size_px"]), + "read_region_size_px": int(geometry["read_region_size_px"]), + "requested_spacing_um": float(preprocessing.requested_spacing_um), + "subtile_order": "row_major", + } + + +def write_tile_embedding_artifact( + sample_id: str, + features, + *, + execution: ExecutionOptions, + metadata: dict[str, Any], + num_rows_fn: Callable[[Any], int], +) -> TileEmbeddingArtifact: + if execution.output_dir is None: + raise ValueError("ExecutionOptions.output_dir is required to persist tile embeddings") + return write_tile_embeddings( + sample_id, + features, + output_dir=execution.output_dir, + output_format=execution.output_format, + metadata=metadata, + tile_index=np.arange(num_rows_fn(features), dtype=np.int64), + ) + + +def write_slide_embedding_artifact( + sample_id: str, + embedding, + *, + execution: ExecutionOptions, + metadata: dict[str, Any], + latents=None, +) -> SlideEmbeddingArtifact: + if execution.output_dir is None: + raise ValueError("ExecutionOptions.output_dir is required to persist slide embeddings") + return write_slide_embeddings( + sample_id, + embedding, + output_dir=execution.output_dir, + output_format=execution.output_format, + metadata=metadata, + latents=latents, + ) + + +def write_hierarchical_embedding_artifact( + sample_id: str, + features, + *, + execution: ExecutionOptions, + metadata: dict[str, Any], +) -> HierarchicalEmbeddingArtifact: + if execution.output_dir is None: + raise ValueError("ExecutionOptions.output_dir is required to persist hierarchical embeddings") + return write_hierarchical_embeddings( + sample_id, + features, + output_dir=execution.output_dir, + output_format=execution.output_format, + metadata=metadata, + ) + diff --git a/slide2vec/runtime/tiling.py b/slide2vec/runtime/tiling.py new file mode 100644 index 0000000..ecd5beb --- /dev/null +++ b/slide2vec/runtime/tiling.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable + +import numpy as np + +from slide2vec.api import PreprocessingConfig + + +def resolve_tiling_backend(preprocessing: PreprocessingConfig | None) -> str: + if preprocessing is None: + return "asap" + return preprocessing.backend + + +def resolve_slide_backend(preprocessing: PreprocessingConfig | None, tiling_result) -> str: + backend = resolve_tiling_backend(preprocessing) + if backend != "auto": + return backend + resolved_backend = tiling_result.backend if hasattr(tiling_result, "backend") else None + if isinstance(resolved_backend, str) and resolved_backend and resolved_backend != "auto": + return resolved_backend + return "asap" + + +def build_preview_config(preview: dict[str, Any], *, preview_config_cls): + return preview_config_cls( + save_mask_preview=bool(preview["save_mask_preview"]), + save_tiling_preview=bool(preview["save_tiling_preview"]), + downsample=int(preview["downsample"]), + tissue_contour_color=tuple(int(channel) for channel in preview["tissue_contour_color"]), + mask_overlay_alpha=float(preview["mask_overlay_alpha"]), + ) + + +def build_hs2p_configs( + preprocessing: PreprocessingConfig, + *, + is_hierarchical_preprocessing_fn: Callable[[PreprocessingConfig | None], bool], + resolve_tiling_backend_fn: Callable[[PreprocessingConfig | None], str], + tiling_config_cls, + segmentation_config_cls, + filter_config_cls, + preview_config_cls, +): + requested_tile_size_px = ( + preprocessing.requested_region_size_px + if is_hierarchical_preprocessing_fn(preprocessing) + else preprocessing.requested_tile_size_px + ) + tiling_cfg = tiling_config_cls( + backend=resolve_tiling_backend_fn(preprocessing), + requested_spacing_um=preprocessing.requested_spacing_um, + requested_tile_size_px=requested_tile_size_px, + tolerance=preprocessing.tolerance, + overlap=preprocessing.overlap, + tissue_threshold=preprocessing.tissue_threshold, + ) + segmentation_cfg = segmentation_config_cls(**dict(preprocessing.segmentation)) + filtering_cfg = filter_config_cls(**dict(preprocessing.filtering)) + preview_cfg = build_preview_config(dict(preprocessing.preview), preview_config_cls=preview_config_cls) + return ( + tiling_cfg, + segmentation_cfg, + filtering_cfg, + preview_cfg, + preprocessing.read_coordinates_from, + preprocessing.resume, + ) + + +def tile_store_archive_path(tile_store_root: Path, sample_id: str) -> Path: + root = Path(tile_store_root) + if root.is_file(): + return root + if root.suffix == ".tar" and root.exists(): + return root + return root / f"{sample_id}.tiles.tar" + + +def resolve_tile_store_archive_for_slide( + *, + slide_sample_id: str, + tiling_result, + preprocessing: PreprocessingConfig, +) -> Path | None: + if preprocessing.read_tiles_from is not None: + return tile_store_archive_path(preprocessing.read_tiles_from, slide_sample_id) + return tiling_result.tiles_tar_path if hasattr(tiling_result, "tiles_tar_path") else None + + +def load_tiling_result(coordinates_npz_path: Path, coordinates_meta_path: Path, *, load_tiling_result_fn): + return load_tiling_result_fn(coordinates_npz_path=coordinates_npz_path, coordinates_meta_path=coordinates_meta_path) + + +def scale_coordinates(coordinates: np.ndarray, base_spacing_um: float, spacing: float) -> np.ndarray: + scale = base_spacing_um / spacing + return (coordinates * scale).astype(int) + diff --git a/tests/test_architecture_runtime_split.py b/tests/test_architecture_runtime_split.py index 06c7346..b937912 100644 --- a/tests/test_architecture_runtime_split.py +++ b/tests/test_architecture_runtime_split.py @@ -3,15 +3,33 @@ from pathlib import Path +def test_runtime_modules_do_not_depend_on_cli_or_package_facade(): + package_root = Path(__file__).resolve().parents[1] / "slide2vec" / "runtime" + runtime_modules = sorted(package_root.glob("*.py")) + forbidden_fragments = [ + "from slide2vec import", + "import slide2vec.cli", + "from slide2vec.cli import", + "import slide2vec.__init__", + "from slide2vec.__init__ import", + ] + for module_path in runtime_modules: + source = module_path.read_text(encoding="utf-8") + for fragment in forbidden_fragments: + assert fragment not in source, f"{module_path.name} should not import public CLI/facade modules" + + def test_internal_runtime_modules_stay_small(): package_root = Path(__file__).resolve().parents[1] / "slide2vec" / "runtime" module_paths = [ package_root / "batching.py", package_root / "distributed.py", + package_root / "embedding.py", package_root / "hierarchical.py", package_root / "persistence.py", package_root / "progress_bridge.py", package_root / "serialization.py", + package_root / "tiling.py", package_root / "types.py", ] for module_path in module_paths: From a50259eb9fe403da05b1379a1837494d463b6f13 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 21:45:21 +0200 Subject: [PATCH 05/11] Declutter root package by relocating internal modules --- docs/documentation.md | 1 + slide2vec/api.py | 4 ++-- slide2vec/configs/__init__.py | 2 +- slide2vec/{ => configs}/resources.py | 3 ++- slide2vec/encoders/registry.py | 2 +- slide2vec/encoders/validation.py | 2 +- slide2vec/inference.py | 4 ++-- slide2vec/runtime/batching.py | 2 +- slide2vec/{ => runtime}/model_settings.py | 1 + slide2vec/{ => runtime}/registry.py | 1 + slide2vec/runtime/types.py | 11 +++++++++++ slide2vec/runtime_types.py | 14 -------------- slide2vec/utils/config.py | 2 +- tests/test_architecture_runtime_split.py | 17 +++++++++++++++++ tests/test_regression_core.py | 2 +- tests/test_regression_inference.py | 2 +- tests/test_regression_models.py | 2 +- 17 files changed, 45 insertions(+), 27 deletions(-) rename slide2vec/{ => configs}/resources.py (99%) rename slide2vec/{ => runtime}/model_settings.py (99%) rename slide2vec/{ => runtime}/registry.py (99%) delete mode 100644 slide2vec/runtime_types.py diff --git a/docs/documentation.md b/docs/documentation.md index 0bbb24c..2474fb7 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -7,6 +7,7 @@ - Extracted distributed torchrun orchestration, shard merge/loading, and rank-assignment helpers into `slide2vec.runtime.distributed`, with inference-level compatibility shims preserved for existing tests and monkeypatch patterns. - Moved artifact collection/loading and process-list embedding status updates into `slide2vec.runtime.persistence` so the pipeline orchestration flow in `slide2vec.inference` stays focused on control flow. - Extracted pure tiling and embedding metadata/writer helpers into `slide2vec.runtime.tiling` and `slide2vec.runtime.embedding`, keeping inference-level wrappers so existing monkeypatch-based regression tests remain stable. +- Moved root-level internal helpers (`runtime_types.py`, `model_settings.py`, `registry.py`, `resources.py`) into clearer homes: `slide2vec.runtime.*` and `slide2vec.configs.resources`, and added a guardrail test that keeps the root package module list intentionally minimal. - Aligned slide2vec with hs2p 4.0.0's unified tiling/sampling contract by preserving the new `annotation` column in process lists and translating preview configs to hs2p's `save_mask_preview` / `save_tiling_preview` / `tissue_contour_color` fields. diff --git a/slide2vec/api.py b/slide2vec/api.py index 3465634..5848d50 100644 --- a/slide2vec/api.py +++ b/slide2vec/api.py @@ -20,9 +20,9 @@ resolve_preprocessing_defaults, ) from slide2vec.encoders.validation import validate_encoder_config -from slide2vec.model_settings import canonicalize_model_name, normalize_precision_name +from slide2vec.runtime.model_settings import canonicalize_model_name, normalize_precision_name from slide2vec.progress import emit_progress -from slide2vec.runtime_types import LoadedModel +from slide2vec.runtime.types import LoadedModel from slide2vec.utils.utils import cpu_worker_limit, slurm_cpu_limit PathLike = str | Path diff --git a/slide2vec/configs/__init__.py b/slide2vec/configs/__init__.py index 3532d60..2085613 100644 --- a/slide2vec/configs/__init__.py +++ b/slide2vec/configs/__init__.py @@ -1,4 +1,4 @@ -from slide2vec.resources import load_config +from slide2vec.configs.resources import load_config default_config = load_config("default") diff --git a/slide2vec/resources.py b/slide2vec/configs/resources.py similarity index 99% rename from slide2vec/resources.py rename to slide2vec/configs/resources.py index 034b43e..13f96a0 100644 --- a/slide2vec/resources.py +++ b/slide2vec/configs/resources.py @@ -1,7 +1,7 @@ +from contextlib import contextmanager from importlib.resources import as_file, files from pathlib import Path from typing import Iterator -from contextlib import contextmanager def config_resource(*parts: str): @@ -24,3 +24,4 @@ def config_path(*parts: str) -> Iterator[Path]: resource = config_resource(*parts) with as_file(resource) as resolved: yield resolved + diff --git a/slide2vec/encoders/registry.py b/slide2vec/encoders/registry.py index b2772b5..fadff7c 100644 --- a/slide2vec/encoders/registry.py +++ b/slide2vec/encoders/registry.py @@ -2,7 +2,7 @@ from typing import Any -from slide2vec.registry import Registry +from slide2vec.runtime.registry import Registry encoder_registry = Registry("encoders") diff --git a/slide2vec/encoders/validation.py b/slide2vec/encoders/validation.py index bb5080e..bbc94e1 100644 --- a/slide2vec/encoders/validation.py +++ b/slide2vec/encoders/validation.py @@ -8,7 +8,7 @@ resolve_encoder_output, resolve_preprocessing_requirements, ) -from slide2vec.model_settings import normalize_precision_name +from slide2vec.runtime.model_settings import normalize_precision_name logger = logging.getLogger("slide2vec") diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 6e7b9c4..3572256 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -109,8 +109,8 @@ resolve_encoder_output, resolve_preprocessing_defaults, ) -from slide2vec.model_settings import canonicalize_model_name -from slide2vec.runtime_types import LoadedModel +from slide2vec.runtime.model_settings import canonicalize_model_name +from slide2vec.runtime.types import LoadedModel from slide2vec.progress import ( emit_progress, read_tiling_progress_snapshot, diff --git a/slide2vec/runtime/batching.py b/slide2vec/runtime/batching.py index 4cae17c..c02cdd0 100644 --- a/slide2vec/runtime/batching.py +++ b/slide2vec/runtime/batching.py @@ -9,7 +9,7 @@ from transformers.image_processing_utils import BaseImageProcessor from slide2vec.progress import emit_progress -from slide2vec.runtime_types import LoadedModel +from slide2vec.runtime.types import LoadedModel from slide2vec.utils.log_utils import suppress_c_stderr from .types import BatchTransformSpec, PreparedBatch diff --git a/slide2vec/model_settings.py b/slide2vec/runtime/model_settings.py similarity index 99% rename from slide2vec/model_settings.py rename to slide2vec/runtime/model_settings.py index 095ccfd..3078cb7 100644 --- a/slide2vec/model_settings.py +++ b/slide2vec/runtime/model_settings.py @@ -45,3 +45,4 @@ def normalize_precision_name(value: Any) -> str | None: def canonicalize_model_name(name: str) -> str: normalized = name.strip().lower() return MODEL_NAME_ALIASES.get(normalized, normalized) + diff --git a/slide2vec/registry.py b/slide2vec/runtime/registry.py similarity index 99% rename from slide2vec/registry.py rename to slide2vec/runtime/registry.py index 62fa50c..8db8972 100644 --- a/slide2vec/registry.py +++ b/slide2vec/runtime/registry.py @@ -68,3 +68,4 @@ class _Entry: def __init__(self, cls: type, metadata: dict[str, Any]) -> None: self.cls = cls self.metadata = metadata + diff --git a/slide2vec/runtime/types.py b/slide2vec/runtime/types.py index 2e67823..ddb9737 100644 --- a/slide2vec/runtime/types.py +++ b/slide2vec/runtime/types.py @@ -2,6 +2,7 @@ from typing import Any import numpy as np +import torch @dataclass(frozen=True, kw_only=True) @@ -35,3 +36,13 @@ class HierarchicalIndex: num_regions: int tiles_per_region: int + +@dataclass(kw_only=True) +class LoadedModel: + name: str + level: str + model: object + transforms: object + feature_dim: int + device: torch.device + tile_feature_dim: int | None = None diff --git a/slide2vec/runtime_types.py b/slide2vec/runtime_types.py deleted file mode 100644 index 2260487..0000000 --- a/slide2vec/runtime_types.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -import torch - - -@dataclass(kw_only=True) -class LoadedModel: - name: str - level: str - model: object - transforms: object - feature_dim: int - device: torch.device - tile_feature_dim: int | None = None diff --git a/slide2vec/utils/config.py b/slide2vec/utils/config.py index 25ebf42..64e3f67 100644 --- a/slide2vec/utils/config.py +++ b/slide2vec/utils/config.py @@ -7,7 +7,7 @@ from omegaconf import OmegaConf import slide2vec.distributed as distributed -from slide2vec.model_settings import canonicalize_model_name +from slide2vec.runtime.model_settings import canonicalize_model_name from slide2vec.utils import initialize_wandb, fix_random_seeds, get_sha, setup_logging from slide2vec.configs import default_config diff --git a/tests/test_architecture_runtime_split.py b/tests/test_architecture_runtime_split.py index b937912..f466491 100644 --- a/tests/test_architecture_runtime_split.py +++ b/tests/test_architecture_runtime_split.py @@ -3,6 +3,21 @@ from pathlib import Path +def test_root_package_python_modules_are_curated(): + package_root = Path(__file__).resolve().parents[1] / "slide2vec" + root_modules = {path.name for path in package_root.glob("*.py")} + assert root_modules == { + "__init__.py", + "__main__.py", + "api.py", + "artifacts.py", + "cli.py", + "inference.py", + "main.py", + "progress.py", + } + + def test_runtime_modules_do_not_depend_on_cli_or_package_facade(): package_root = Path(__file__).resolve().parents[1] / "slide2vec" / "runtime" runtime_modules = sorted(package_root.glob("*.py")) @@ -26,8 +41,10 @@ def test_internal_runtime_modules_stay_small(): package_root / "distributed.py", package_root / "embedding.py", package_root / "hierarchical.py", + package_root / "model_settings.py", package_root / "persistence.py", package_root / "progress_bridge.py", + package_root / "registry.py", package_root / "serialization.py", package_root / "tiling.py", package_root / "types.py", diff --git a/tests/test_regression_core.py b/tests/test_regression_core.py index 58d8799..98d9349 100644 --- a/tests/test_regression_core.py +++ b/tests/test_regression_core.py @@ -20,7 +20,7 @@ write_slide_embeddings, write_tile_embeddings, ) -from slide2vec.resources import config_resource, load_config +from slide2vec.configs.resources import config_resource, load_config ROOT = Path(__file__).resolve().parents[1] DEFAULT_PREPROCESSING = PreprocessingConfig(requested_spacing_um=0.5, requested_tile_size_px=224) diff --git a/tests/test_regression_inference.py b/tests/test_regression_inference.py index 011825e..dfcd28d 100644 --- a/tests/test_regression_inference.py +++ b/tests/test_regression_inference.py @@ -24,7 +24,7 @@ write_tile_embedding_metadata, write_tile_embeddings, ) -from slide2vec.resources import config_resource, load_config +from slide2vec.configs.resources import config_resource, load_config ROOT = Path(__file__).resolve().parents[1] diff --git a/tests/test_regression_models.py b/tests/test_regression_models.py index 0290e62..1f2ec13 100644 --- a/tests/test_regression_models.py +++ b/tests/test_regression_models.py @@ -17,7 +17,7 @@ write_slide_embeddings, write_tile_embeddings, ) -from slide2vec.resources import config_resource, load_config +from slide2vec.configs.resources import config_resource, load_config ROOT = Path(__file__).resolve().parents[1] DEFAULT_PREPROCESSING = PreprocessingConfig(requested_spacing_um=0.5, requested_tile_size_px=224) From 650d62890f9c722d4b1d19de766b886018e4fc13 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 21:48:46 +0200 Subject: [PATCH 06/11] Trim secondary tests to keep suite focused --- docs/documentation.md | 1 + tests/test_batch_collator_timing.py | 161 ------------------------ tests/test_docs.py | 45 ------- tests/test_output_consistency.py | 186 ---------------------------- tests/test_packaging_metadata.py | 23 ---- 5 files changed, 1 insertion(+), 415 deletions(-) delete mode 100644 tests/test_batch_collator_timing.py delete mode 100644 tests/test_docs.py delete mode 100644 tests/test_output_consistency.py delete mode 100644 tests/test_packaging_metadata.py diff --git a/docs/documentation.md b/docs/documentation.md index 2474fb7..9a9c44e 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -8,6 +8,7 @@ - Moved artifact collection/loading and process-list embedding status updates into `slide2vec.runtime.persistence` so the pipeline orchestration flow in `slide2vec.inference` stays focused on control flow. - Extracted pure tiling and embedding metadata/writer helpers into `slide2vec.runtime.tiling` and `slide2vec.runtime.embedding`, keeping inference-level wrappers so existing monkeypatch-based regression tests remain stable. - Moved root-level internal helpers (`runtime_types.py`, `model_settings.py`, `registry.py`, `resources.py`) into clearer homes: `slide2vec.runtime.*` and `slide2vec.configs.resources`, and added a guardrail test that keeps the root package module list intentionally minimal. +- Trimmed secondary/low-signal tests (docs build, packaging metadata smoke, collator timing micro-tests, and heavyweight output-consistency smoke) to keep the test suite focused on high-signal core regression coverage. - Aligned slide2vec with hs2p 4.0.0's unified tiling/sampling contract by preserving the new `annotation` column in process lists and translating preview configs to hs2p's `save_mask_preview` / `save_tiling_preview` / `tissue_contour_color` fields. diff --git a/tests/test_batch_collator_timing.py b/tests/test_batch_collator_timing.py deleted file mode 100644 index 0bad9a7..0000000 --- a/tests/test_batch_collator_timing.py +++ /dev/null @@ -1,161 +0,0 @@ -from pathlib import Path -from types import SimpleNamespace - -import numpy as np -import pytest - - -def test_batch_tile_collator_emits_worker_and_reader_timing(monkeypatch: pytest.MonkeyPatch): - torch = pytest.importorskip("torch") - from slide2vec.data import dataset - - timings = { - "reader_open_ms": 1.25, - "reader_read_ms": 8.5, - } - - class FakeReader: - def __init__(self, tar_path: Path, tile_size_px: int): - self.tar_path = tar_path - self.tile_size_px = tile_size_px - - def read_batch_with_timing(self, tile_indices): - tensor = torch.zeros((len(tile_indices), 3, self.tile_size_px, self.tile_size_px), dtype=torch.uint8) - return tensor, dict(timings) - - monkeypatch.setattr(dataset, "TarTileReader", FakeReader) - - collator = dataset.BatchTileCollator( - tar_path=Path("/tmp/fake.tiles.tar"), - tiling_result=SimpleNamespace(requested_tile_size_px=4), - ) - - indices, tensor, timing = collator([2, 5]) - - assert indices.tolist() == [2, 5] - assert tuple(tensor.shape) == (2, 3, 4, 4) - assert timing["reader_open_ms"] == pytest.approx(1.25) - assert timing["reader_read_ms"] == pytest.approx(8.5) - assert timing["worker_batch_ms"] >= 0.0 - - -def test_on_the_fly_collator_emits_worker_and_reader_timing(monkeypatch: pytest.MonkeyPatch): - torch = pytest.importorskip("torch") - import slide2vec.data.tile_reader as tile_reader - - class FakeReader: - ordered_indices = None - - def __init__(self, image_path, tiling_result, *, backend: str, num_cucim_workers: int, gpu_decode: bool, use_supertiles: bool): - self.tile_size = int(tiling_result.read_tile_size_px) - - def read_batch_with_timing(self, tile_indices): - tensor = torch.zeros((len(tile_indices), 3, self.tile_size, self.tile_size), dtype=torch.uint8) - return tensor, {"reader_open_ms": 2.0, "reader_read_ms": 7.25} - - monkeypatch.setattr(tile_reader, "WSITileReader", FakeReader) - - collator = tile_reader.OnTheFlyBatchTileCollator( - image_path=Path("/tmp/fake.svs"), - tiling_result=SimpleNamespace(read_tile_size_px=4), - backend="cucim", - num_cucim_workers=4, - gpu_decode=False, - use_supertiles=False, - ) - - indices, tensor, timing = collator([0, 4]) - - assert indices.tolist() == [0, 4] - assert tuple(tensor.shape) == (2, 3, 4, 4) - assert timing["reader_open_ms"] == pytest.approx(2.0) - assert timing["reader_read_ms"] == pytest.approx(7.25) - assert timing["worker_batch_ms"] >= 0.0 - - -def test_wsi_tile_reader_suppresses_native_stderr_for_cucim(monkeypatch: pytest.MonkeyPatch): - torch = pytest.importorskip("torch") - import slide2vec.data.tile_reader as tile_reader - - calls: list[str] = [] - - class _FakeSuppress: - def __enter__(self): - calls.append("enter") - return self - - def __exit__(self, *args): - calls.append("exit") - return False - - class FakeBackendReader: - def read_regions(self, locations, level, size, *, num_workers=None): - del locations, level, num_workers - width, height = size - return [np.zeros((height, width, 3), dtype=np.uint8), np.zeros((height, width, 3), dtype=np.uint8)] - - monkeypatch.setattr(tile_reader, "_open_wsi_backend", lambda *args, **kwargs: FakeBackendReader()) - monkeypatch.setattr(tile_reader, "suppress_c_stderr", lambda: _FakeSuppress()) - reader = tile_reader.WSITileReader( - Path("/tmp/fake.svs"), - SimpleNamespace( - read_tile_size_px=4, - read_level=0, - x=np.array([0, 4]), - y=np.array([0, 0]), - ), - backend="cucim", - num_cucim_workers=4, - gpu_decode=False, - use_supertiles=False, - ) - - tensor, timing = reader.read_batch_with_timing(np.array([0, 1], dtype=np.int64)) - - assert tuple(tensor.shape) == (2, 3, 4, 4) - assert timing["reader_open_ms"] >= 0.0 - assert timing["reader_read_ms"] >= 0.0 - assert calls == ["enter", "exit"] - - -def test_on_the_fly_collator_filters_native_stderr_for_cucim(monkeypatch: pytest.MonkeyPatch): - torch = pytest.importorskip("torch") - import slide2vec.data.tile_reader as tile_reader - - calls: list[str] = [] - - class FakeReader: - ordered_indices = None - _backend = "cucim" - - def __init__(self, image_path, tiling_result, *, backend: str, num_cucim_workers: int, gpu_decode: bool, use_supertiles: bool): - self.tile_size = int(tiling_result.read_tile_size_px) - - def read_batch_with_timing(self, tile_indices): - tensor = torch.zeros((len(tile_indices), 3, self.tile_size, self.tile_size), dtype=torch.uint8) - return tensor, {"reader_open_ms": 2.0, "reader_read_ms": 7.25} - - def _fake_run_with_filtered_stderr(func, *, suppress_patterns=()): - del suppress_patterns - calls.append("filtered") - return func() - - monkeypatch.setattr(tile_reader, "WSITileReader", FakeReader) - monkeypatch.setattr(tile_reader, "run_with_filtered_stderr", _fake_run_with_filtered_stderr) - - collator = tile_reader.OnTheFlyBatchTileCollator( - image_path=Path("/tmp/fake.svs"), - tiling_result=SimpleNamespace(read_tile_size_px=4), - backend="cucim", - num_cucim_workers=4, - gpu_decode=False, - use_supertiles=False, - ) - - indices, tensor, timing = collator([0, 4]) - - assert indices.tolist() == [0, 4] - assert tuple(tensor.shape) == (2, 3, 4, 4) - assert timing["reader_open_ms"] == pytest.approx(2.0) - assert timing["reader_read_ms"] == pytest.approx(7.25) - assert calls == ["filtered"] diff --git a/tests/test_docs.py b/tests/test_docs.py deleted file mode 100644 index 0f6733e..0000000 --- a/tests/test_docs.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import importlib.util -from pathlib import Path - -import pytest - - -def _load_reference_generator(): - docs_dir = Path(__file__).resolve().parents[1] / "docs" - module_path = docs_dir / "_generate_reference.py" - spec = importlib.util.spec_from_file_location("_generate_reference", module_path) - if spec is None or spec.loader is None: - raise RuntimeError(f"Unable to load {module_path}") - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module, docs_dir - - -def test_reference_generator_matches_checked_in_file() -> None: - generator, docs_dir = _load_reference_generator() - generated = generator.build_reference_rst().strip() - checked_in = (docs_dir / "reference.rst").read_text(encoding="utf-8").strip() - - assert generated == checked_in - assert "Compact Reference" in generated - assert "Main entry points" in generated - assert "Registered presets" in generated - - -def test_sphinx_docs_build(tmp_path: Path) -> None: - pytest.importorskip("sphinx") - from sphinx.cmd.build import build_main - - docs_dir = Path(__file__).resolve().parents[1] / "docs" - out_dir = tmp_path / "html" - status = build_main(["-W", "-b", "html", str(docs_dir), str(out_dir)]) - - assert status == 0 - index_html = (out_dir / "index.html").read_text(encoding="utf-8") - assert "Made with" not in index_html - assert "@pradyunsg" not in index_html - assert (out_dir / "index.html").exists() - assert (out_dir / "reference.html").exists() - diff --git a/tests/test_output_consistency.py b/tests/test_output_consistency.py deleted file mode 100644 index b00e2dd..0000000 --- a/tests/test_output_consistency.py +++ /dev/null @@ -1,186 +0,0 @@ -import os -import json -import subprocess -import sys -from pathlib import Path - -import numpy as np -import pytest - -torch = pytest.importorskip("torch") -OmegaConf = pytest.importorskip("omegaconf").OmegaConf - -# --------------------------------------------------------------------------- -# Hardcoded pipeline parameters -# --------------------------------------------------------------------------- - -# -- tiling.params -- -TILING_PARAMS = dict( - requested_spacing_um=0.5, - tolerance=0.07, # override (default: 0.05) - requested_tile_size_px=224, # override (default: 256) - overlap=0.0, - tissue_threshold=0.1, # override (default: 0.01) -) - -# -- tiling.seg_params -- -TILING_SEG_PARAMS = dict( - downsample=64, # override (default: 16) - sthresh=8, - sthresh_up=255, - mthresh=7, - close=4, - method="hsv", -) - -# -- tiling.filter_params -- -TILING_FILTER_PARAMS = dict( - ref_tile_size=224, # override (default: 16) - a_t=4, - a_h=2, - filter_white=False, - filter_black=False, - white_threshold=220, - black_threshold=25, - fraction_threshold=0.9, -) - -# -- tiling.preview -- -TILING_PREVIEW = dict(save=False, downsample=32) - -# -- model -- -MODEL_PARAMS = dict( - name="prism", # override (default: null) - batch_size=8, # override (default: 256) - save_tile_embeddings=True, - save_slide_embeddings=False, - save_latents=False, -) - -# -- speed -- -SPEED_PARAMS = dict( - precision="fp16", # override (default: fp32) - num_dataloader_workers=0, # keep the Prism subprocess path single-process to avoid worker SHM pressure -) - -# --------------------------------------------------------------------------- -# Paths relative to this test file -# --------------------------------------------------------------------------- -TEST_DIR = Path(__file__).parent -INPUT_DIR = TEST_DIR / "fixtures" / "input" -GT_DIR = TEST_DIR / "fixtures" / "gt" -REPO_ROOT = TEST_DIR.parent - - -@pytest.fixture(scope="module") -def wsi_path() -> Path: - p = INPUT_DIR / "test-wsi.tif" - if not p.is_file(): - pytest.skip(f"Test fixture missing: {p}") - return p - - -@pytest.fixture(scope="module") -def mask_path() -> Path: - p = INPUT_DIR / "test-mask.tif" - if not p.is_file(): - pytest.skip(f"Test fixture missing: {p}") - return p - - -@pytest.mark.skipif( - not os.environ.get("HF_TOKEN"), - reason="HF_TOKEN required for model weight download", -) -def test_output_consistency(wsi_path, mask_path, tmp_path): - """Running the full pipeline with hardcoded params produces x/y coordinates and - embeddings that match the ground truth fixtures in test/gt/.""" - - pytest.importorskip("transformers") - pytest.importorskip("wholeslidedata") - - # 1. Build a temporary CSV with resolved absolute paths - tmp_csv = tmp_path / "test.csv" - tmp_csv.write_text( - f"sample_id,image_path,mask_path\ntest-wsi,{wsi_path},{mask_path}\n" - ) - - # 2. Build config from hardcoded constants (no dependency on test/input/config.yaml) - cfg = OmegaConf.create({ - "csv": str(tmp_csv), - "output_dir": str(tmp_path), - "resume": False, - "resume_dirname": None, - "seed": 0, - "tiling": { - "read_coordinates_from": None, - "read_tiles_from": None, - "on_the_fly": True, - "backend": "asap", - "params": TILING_PARAMS, - "seg_params": TILING_SEG_PARAMS, - "filter_params": TILING_FILTER_PARAMS, - "preview": TILING_PREVIEW, - }, - "model": MODEL_PARAMS, - "speed": SPEED_PARAMS, - "wandb": {"enable": False}, - }) - cfg_path = tmp_path / "config.yaml" - OmegaConf.save(cfg, cfg_path) - - # 3. Run the pipeline - subprocess.run( - [ - "slide2vec", - str(cfg_path), - "--skip-datetime", - "--run-on-cpu", - ], - cwd=REPO_ROOT, - check=True, - ) - - # 4. Assert coordinates match exactly (tiling is deterministic) - gt_coords = np.load(GT_DIR / "test-wsi.coordinates.npz", allow_pickle=False) - coords = np.load(tmp_path / "tiles" / "test-wsi.coordinates.npz", allow_pickle=False) - np.testing.assert_array_equal(coords, gt_coords) - - meta = json.loads((tmp_path / "tiles" / "test-wsi.coordinates.meta.json").read_text()) - assert meta["provenance"]["sample_id"] == "test-wsi" - assert meta["provenance"]["backend"] == "asap" - assert meta["tiling"]["requested_spacing_um"] == pytest.approx(0.5) - assert meta["tiling"]["requested_tile_size_px"] == 224 - - # 5. Assert slide embeddings are within tolerance - gt_emb = torch.load(GT_DIR / "test-wsi.pt", map_location="cpu", weights_only=True) - emb = torch.load(tmp_path / "slide_embeddings" / "test-wsi.pt", map_location="cpu", weights_only=True) - assert emb.shape == gt_emb.shape, f"Shape mismatch: {emb.shape} vs {gt_emb.shape}" - - cos = torch.nn.functional.cosine_similarity(emb, gt_emb, dim=-1) - mean_cos = float(cos.mean()) - atol, rtol = 1e-2, 1e-3 - if not torch.allclose(emb, gt_emb, atol=atol, rtol=rtol): - assert mean_cos >= 0.99, ( - f"Embedding mismatch: mean cosine similarity={mean_cos:.4f} " - f"(atol={atol}, rtol={rtol})" - ) - else: - print(f"OK: slide embeddings within tolerance; mean cosine similarity={mean_cos:.4f}") - - # 6. Assert tile-level embeddings match ground truth (verifies tile ordering) - gt_tile_emb = torch.load(GT_DIR / "test-wsi.tiles.pt", map_location="cpu", weights_only=True) - tile_emb = torch.load(tmp_path / "tile_embeddings" / "test-wsi.pt", map_location="cpu", weights_only=True) - assert tile_emb.shape == gt_tile_emb.shape, ( - f"Tile embedding shape mismatch: {tile_emb.shape} vs {gt_tile_emb.shape}" - ) - tile_cos = torch.nn.functional.cosine_similarity(tile_emb, gt_tile_emb, dim=-1) - mean_tile_cos = float(tile_cos.mean()) - atol, rtol = 1e-2, 1e-3 - if not torch.allclose(tile_emb, gt_tile_emb, atol=atol, rtol=rtol): - assert mean_tile_cos >= 0.99, ( - f"Tile embedding mismatch: mean cosine similarity={mean_tile_cos:.4f} " - f"(atol={atol}, rtol={rtol})" - ) - else: - print(f"OK: tile embeddings within tolerance; mean cosine similarity={mean_tile_cos:.4f}") diff --git a/tests/test_packaging_metadata.py b/tests/test_packaging_metadata.py deleted file mode 100644 index 85957f6..0000000 --- a/tests/test_packaging_metadata.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -try: - import tomllib -except ModuleNotFoundError: # pragma: no cover - Python < 3.11 test environments - import tomli as tomllib - - -def test_optional_dependencies_do_not_publish_direct_vcs_urls(): - pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml" - pyproject = tomllib.loads(pyproject_path.read_text(encoding="utf-8")) - - optional_dependencies = pyproject["project"]["optional-dependencies"] - published_dependency_strings = [ - requirement - for requirements in optional_dependencies.values() - for requirement in requirements - ] - - assert published_dependency_strings - assert all(" @ git+" not in requirement for requirement in published_dependency_strings) From 35d53b972aab996c51f73071fb435eb5adb97312 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 21:55:57 +0200 Subject: [PATCH 07/11] Fix docs workflow after docs test suite trim --- .github/workflows/docs.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 0924afa..24d0661 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -38,9 +38,6 @@ jobs: python -m pip install --upgrade pip pip install -e ".[testing,docs]" - - name: Run docs smoke test - run: python -m pytest -q -o addopts= -p no:cov tests/test_docs.py - - name: Build Sphinx site run: python -m sphinx -W -b html docs docs/_build/html From 57a30dceb1feeb233ce8c27aa94a916b6cee51f7 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 21:58:55 +0200 Subject: [PATCH 08/11] Reduce unnecessary import aliasing in inference --- slide2vec/inference.py | 106 ++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 70 deletions(-) diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 3572256..fba19dc 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -32,29 +32,9 @@ resize_image_batch as _resize_image_batch, resolve_device as _resolve_device, run_forward_pass as _run_forward_pass, - should_suppress_cucim_dataloader_stderr as _should_suppress_cucim_dataloader_stderr, - uses_cuda_runtime as _uses_cuda_runtime, -) -from slide2vec.runtime.embedding import ( - build_hierarchical_embedding_metadata as _build_hierarchical_embedding_metadata_runtime, - build_slide_embedding_metadata as _build_slide_embedding_metadata_runtime, - build_tile_embedding_metadata as _build_tile_embedding_metadata_runtime, - should_persist_tile_embeddings as _should_persist_tile_embeddings_runtime, - write_hierarchical_embedding_artifact as _write_hierarchical_embedding_artifact_runtime, - write_slide_embedding_artifact as _write_slide_embedding_artifact_runtime, - write_tile_embedding_artifact as _write_tile_embedding_artifact_runtime, -) -from slide2vec.runtime.distributed import ( - assign_slides_to_ranks as _assign_slides_to_ranks_runtime, - distributed_coordination_dir as _distributed_coordination_dir_runtime, - load_embedded_slide_payload as _load_embedded_slide_payload_runtime, - load_hierarchical_embedding_shards as _load_hierarchical_embedding_shards_runtime, - load_tile_embedding_shards as _load_tile_embedding_shards_runtime, - merge_hierarchical_embedding_shards as _merge_hierarchical_embedding_shards_runtime, - merge_tile_embedding_shards as _merge_tile_embedding_shards_runtime, - reset_progress_event_logs as _reset_progress_event_logs_runtime, - run_torchrun_worker as _run_torchrun_worker_runtime, ) +import slide2vec.runtime.embedding as runtime_embedding +import slide2vec.runtime.distributed as runtime_distributed from slide2vec.runtime.hierarchical import ( build_hierarchical_index as _build_hierarchical_index, is_hierarchical_preprocessing as _is_hierarchical_preprocessing, @@ -69,23 +49,8 @@ from slide2vec.runtime.progress_bridge import ( bridge_hs2p_progress_to_slide2vec as _bridge_hs2p_progress_to_slide2vec, ) -from slide2vec.runtime.serialization import ( - deserialize_execution, - deserialize_preprocessing, - serialize_execution as _serialize_execution_base, - serialize_model as _serialize_model, - serialize_preprocessing as _serialize_preprocessing, -) -from slide2vec.runtime.tiling import ( - build_hs2p_configs as _build_hs2p_configs_runtime, - build_preview_config as _build_preview_config_runtime, - load_tiling_result as _load_tiling_result_runtime, - resolve_slide_backend as _resolve_slide_backend_runtime, - resolve_tile_store_archive_for_slide as _resolve_tile_store_archive_for_slide_runtime, - resolve_tiling_backend as _resolve_tiling_backend_runtime, - scale_coordinates as _scale_coordinates_runtime, - tile_store_archive_path as _tile_store_archive_path_runtime, -) +import slide2vec.runtime.serialization as runtime_serialization +import slide2vec.runtime.tiling as runtime_tiling from slide2vec.api import ( EmbeddedPatient, EmbeddedSlide, @@ -129,6 +94,11 @@ ) from slide2vec.utils.utils import cpu_worker_limit, slurm_cpu_limit +deserialize_execution = runtime_serialization.deserialize_execution +deserialize_preprocessing = runtime_serialization.deserialize_preprocessing +_serialize_model = runtime_serialization.serialize_model +_serialize_preprocessing = runtime_serialization.serialize_preprocessing + def _serialize_execution( execution: ExecutionOptions, @@ -138,7 +108,7 @@ def _serialize_execution( effective_num_workers = None if preprocessing is not None and preprocessing.on_the_fly and preprocessing.read_tiles_from is None: effective_num_workers, _ = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) - return _serialize_execution_base( + return runtime_serialization.serialize_execution( execution, effective_num_workers=effective_num_workers, ) @@ -1890,7 +1860,7 @@ def _build_tile_embedding_metadata( tile_size_lv0: int, backend: str, ) -> dict[str, Any]: - return _build_tile_embedding_metadata_runtime( + return runtime_embedding.build_tile_embedding_metadata( model, tiling_result=tiling_result, image_path=image_path, @@ -1901,7 +1871,7 @@ def _build_tile_embedding_metadata( def _build_slide_embedding_metadata(model, *, image_path: Path | str) -> dict[str, Any]: - return _build_slide_embedding_metadata_runtime(model, image_path=image_path) + return runtime_embedding.build_slide_embedding_metadata(model, image_path=image_path) def _build_hierarchical_embedding_metadata( @@ -1913,7 +1883,7 @@ def _build_hierarchical_embedding_metadata( backend: str, preprocessing: PreprocessingConfig, ) -> dict[str, Any]: - return _build_hierarchical_embedding_metadata_runtime( + return runtime_embedding.build_hierarchical_embedding_metadata( model, tiling_result=tiling_result, image_path=image_path, @@ -1931,7 +1901,7 @@ def _write_tile_embedding_artifact( execution: ExecutionOptions, metadata: dict[str, Any], ) -> TileEmbeddingArtifact: - return _write_tile_embedding_artifact_runtime( + return runtime_embedding.write_tile_embedding_artifact( sample_id, features, execution=execution, @@ -1948,7 +1918,7 @@ def _write_slide_embedding_artifact( metadata: dict[str, Any], latents=None, ) -> SlideEmbeddingArtifact: - return _write_slide_embedding_artifact_runtime( + return runtime_embedding.write_slide_embedding_artifact( sample_id, embedding, execution=execution, @@ -1964,7 +1934,7 @@ def _write_hierarchical_embedding_artifact( execution: ExecutionOptions, metadata: dict[str, Any], ) -> HierarchicalEmbeddingArtifact: - return _write_hierarchical_embedding_artifact_runtime( + return runtime_embedding.write_hierarchical_embedding_artifact( sample_id, features, execution=execution, @@ -2154,7 +2124,7 @@ def _emit_tiling_summary( def _should_persist_tile_embeddings(model, execution: ExecutionOptions) -> bool: - return _should_persist_tile_embeddings_runtime(model, execution) + return runtime_embedding.should_persist_tile_embeddings(model, execution) def _resolved_process_list_output_variant(model) -> str | None: @@ -2375,11 +2345,11 @@ def _resolve_path_str(value: Any) -> str | None: def _build_preview_config(preview: dict[str, Any]) -> PreviewConfig: - return _build_preview_config_runtime(preview, preview_config_cls=PreviewConfig) + return runtime_tiling.build_preview_config(preview, preview_config_cls=PreviewConfig) def _build_hs2p_configs(preprocessing: PreprocessingConfig): - return _build_hs2p_configs_runtime( + return runtime_tiling.build_hs2p_configs( preprocessing, is_hierarchical_preprocessing_fn=_is_hierarchical_preprocessing, resolve_tiling_backend_fn=_resolve_tiling_backend, @@ -2396,7 +2366,7 @@ def _resolve_tile_store_archive_for_slide( tiling_result, preprocessing: PreprocessingConfig, ) -> Path | None: - return _resolve_tile_store_archive_for_slide_runtime( + return runtime_tiling.resolve_tile_store_archive_for_slide( slide_sample_id=slide.sample_id, tiling_result=tiling_result, preprocessing=preprocessing, @@ -2404,12 +2374,12 @@ def _resolve_tile_store_archive_for_slide( def _tile_store_archive_path(tile_store_root: Path, sample_id: str) -> Path: - return _tile_store_archive_path_runtime(tile_store_root, sample_id) + return runtime_tiling.tile_store_archive_path(tile_store_root, sample_id) def _load_tiling_result(coordinates_npz_path: Path, coordinates_meta_path: Path): - return _load_tiling_result_runtime( + return runtime_tiling.load_tiling_result( coordinates_npz_path, coordinates_meta_path, load_tiling_result_fn=load_tiling_result, @@ -2417,16 +2387,16 @@ def _load_tiling_result(coordinates_npz_path: Path, coordinates_meta_path: Path) def _scale_coordinates(coordinates: np.ndarray, base_spacing_um: float, spacing: float) -> np.ndarray: - return _scale_coordinates_runtime(coordinates, base_spacing_um, spacing) + return runtime_tiling.scale_coordinates(coordinates, base_spacing_um, spacing) def _resolve_tiling_backend(preprocessing: PreprocessingConfig | None) -> str: - return _resolve_tiling_backend_runtime(preprocessing) + return runtime_tiling.resolve_tiling_backend(preprocessing) def _resolve_slide_backend(preprocessing: PreprocessingConfig | None, tiling_result) -> str: - return _resolve_slide_backend_runtime(preprocessing, tiling_result) + return runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) def _resolve_model_preprocessing(model, preprocessing: PreprocessingConfig | None) -> PreprocessingConfig: @@ -2608,7 +2578,7 @@ def _embed_multi_slides_distributed( @contextmanager def _distributed_coordination_dir(work_dir: Path): - with _distributed_coordination_dir_runtime(work_dir) as coordination_dir: + with runtime_distributed.distributed_coordination_dir(work_dir) as coordination_dir: yield coordination_dir @@ -2656,7 +2626,7 @@ def _run_torchrun_worker( failure_title: str, progress_events_path: Path | None = None, ) -> None: - _run_torchrun_worker_runtime( + runtime_distributed.run_torchrun_worker( module=module, num_gpus=execution.num_gpus, output_dir=output_dir, @@ -2721,20 +2691,16 @@ def _build_direct_embed_worker_request_payload( def _reset_progress_event_logs(progress_events_path: Path) -> None: - _reset_progress_event_logs_runtime(progress_events_path) + runtime_distributed.reset_progress_event_logs(progress_events_path) def _drain_stream_to_buffer(stream, chunks: list[str]) -> None: # Compatibility shim for tests monkeypatching this helper in slide2vec.inference. - from slide2vec.runtime.distributed import drain_stream_to_buffer as _drain_stream_to_buffer_runtime - - _drain_stream_to_buffer_runtime(stream, chunks) + runtime_distributed.drain_stream_to_buffer(stream, chunks) def _write_worker_logs(module: str, output_dir: Path, stdout_text: str, stderr_text: str) -> tuple[Path, Path]: - from slide2vec.runtime.distributed import write_worker_logs as _write_worker_logs_runtime - - return _write_worker_logs_runtime(module, output_dir, stdout_text, stderr_text) + return runtime_distributed.write_worker_logs(module, output_dir, stdout_text, stderr_text) def _assign_slides_to_ranks( @@ -2743,7 +2709,7 @@ def _assign_slides_to_ranks( *, num_gpus: int, ) -> dict[int, list[str]]: - return _assign_slides_to_ranks_runtime( + return runtime_distributed.assign_slides_to_ranks( slide_records, tiling_results, num_gpus=num_gpus, @@ -2752,7 +2718,7 @@ def _assign_slides_to_ranks( def _merge_tile_embedding_shards(shard_payloads): - return _merge_tile_embedding_shards_runtime(shard_payloads) + return runtime_distributed.merge_tile_embedding_shards(shard_payloads) def _merge_hierarchical_embedding_shards( @@ -2761,7 +2727,7 @@ def _merge_hierarchical_embedding_shards( num_regions: int, tiles_per_region: int, ): - return _merge_hierarchical_embedding_shards_runtime( + return runtime_distributed.merge_hierarchical_embedding_shards( shard_payloads, num_regions=num_regions, tiles_per_region=tiles_per_region, @@ -2769,15 +2735,15 @@ def _merge_hierarchical_embedding_shards( def _load_tile_embedding_shards(coordination_dir: Path, sample_id: str): - return _load_tile_embedding_shards_runtime(coordination_dir, sample_id) + return runtime_distributed.load_tile_embedding_shards(coordination_dir, sample_id) def _load_hierarchical_embedding_shards(coordination_dir: Path, sample_id: str): - return _load_hierarchical_embedding_shards_runtime(coordination_dir, sample_id) + return runtime_distributed.load_hierarchical_embedding_shards(coordination_dir, sample_id) def _load_embedded_slide_payload(coordination_dir: Path, sample_id: str): - return _load_embedded_slide_payload_runtime(coordination_dir, sample_id) + return runtime_distributed.load_embedded_slide_payload(coordination_dir, sample_id) def load_successful_tiled_slides(output_dir: str | Path) -> tuple[list[SlideSpec], list[Any]]: From 2daf6212f787bd7cbb24abb54d65243ae7a885f7 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 22:10:44 +0200 Subject: [PATCH 09/11] Remove inference passthrough wrappers and call runtime modules directly --- docs/documentation.md | 1 + slide2vec/distributed/direct_embed_worker.py | 3 +- slide2vec/distributed/pipeline_worker.py | 7 +- slide2vec/inference.py | 331 ++++--------------- slide2vec/runtime/distributed.py | 8 +- slide2vec/runtime/embedding.py | 16 +- slide2vec/runtime/tiling.py | 40 +-- tasks/lessons.md | 1 + tests/test_regression_inference.py | 124 +++---- 9 files changed, 154 insertions(+), 377 deletions(-) diff --git a/docs/documentation.md b/docs/documentation.md index 9a9c44e..98b61cc 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -3,6 +3,7 @@ ## 2026-04-18 - Split `slide2vec.inference` into workflow-scoped internal runtime helpers under `slide2vec.runtime` (`batching`, `hierarchical`, `progress_bridge`, `serialization`, `types`) while keeping `slide2vec.inference` as the stable orchestration entrypoint. +- Removed remaining pass-through wrappers from `slide2vec.inference` (tiling/embedding/distributed helpers) and switched workers/tests to import `slide2vec.runtime.*` directly, with simplified runtime helper signatures replacing callback/class injection patterns. - Added architecture guardrail tests that keep workflow helpers bounded (soft target around 400 lines, enforced ceiling 500) and prevent `slide2vec/inference.py` from regressing toward the previous monolith size. - Extracted distributed torchrun orchestration, shard merge/loading, and rank-assignment helpers into `slide2vec.runtime.distributed`, with inference-level compatibility shims preserved for existing tests and monkeypatch patterns. - Moved artifact collection/loading and process-list embedding status updates into `slide2vec.runtime.persistence` so the pipeline orchestration flow in `slide2vec.inference` stays focused on control flow. diff --git a/slide2vec/distributed/direct_embed_worker.py b/slide2vec/distributed/direct_embed_worker.py index 7328608..e28f4de 100644 --- a/slide2vec/distributed/direct_embed_worker.py +++ b/slide2vec/distributed/direct_embed_worker.py @@ -26,11 +26,10 @@ def main(argv=None) -> int: _compute_tile_embeddings_for_slide, _is_hierarchical_preprocessing, _resolve_hierarchical_geometry, - deserialize_execution, - deserialize_preprocessing, load_successful_tiled_slides, ) from slide2vec.progress import JsonlProgressReporter, activate_progress_reporter + from slide2vec.runtime.serialization import deserialize_execution, deserialize_preprocessing parser = get_args_parser(add_help=True) args = parser.parse_args(argv) diff --git a/slide2vec/distributed/pipeline_worker.py b/slide2vec/distributed/pipeline_worker.py index 9b752ed..bdaa950 100644 --- a/slide2vec/distributed/pipeline_worker.py +++ b/slide2vec/distributed/pipeline_worker.py @@ -3,7 +3,7 @@ import json from pathlib import Path -from slide2vec.inference import _assign_slides_to_ranks +from slide2vec.runtime.distributed import assign_slides_to_ranks def get_args_parser(add_help: bool = True) -> argparse.ArgumentParser: @@ -21,11 +21,10 @@ def main(argv=None) -> int: from slide2vec.inference import ( _compute_embedded_slides, _persist_embedded_slide, - deserialize_execution, - deserialize_preprocessing, load_successful_tiled_slides, ) from slide2vec.progress import JsonlProgressReporter, activate_progress_reporter + from slide2vec.runtime.serialization import deserialize_execution, deserialize_preprocessing parser = get_args_parser(add_help=True) args = parser.parse_args(argv) @@ -48,7 +47,7 @@ def main(argv=None) -> int: preprocessing = deserialize_preprocessing(request["preprocessing"]) execution = deserialize_execution(request["execution"]) slide_records, tiling_results = load_successful_tiled_slides(output_dir) - assignments = _assign_slides_to_ranks(slide_records, tiling_results, num_gpus=world_size) + assignments = assign_slides_to_ranks(slide_records, tiling_results, num_gpus=world_size) assigned_ids = assignments.get(global_rank, []) if not assigned_ids: return 0 diff --git a/slide2vec/inference.py b/slide2vec/inference.py index fba19dc..4dac491 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -14,7 +14,7 @@ import logging import pandas as pd import torch -from hs2p import SlideSpec, FilterConfig, PreviewConfig, SegmentationConfig, TilingConfig, load_tiling_result, tile_slides +from hs2p import SlideSpec, tile_slides from hs2p.utils.stderr import run_with_filtered_stderr import numpy as np @@ -94,12 +94,6 @@ ) from slide2vec.utils.utils import cpu_worker_limit, slurm_cpu_limit -deserialize_execution = runtime_serialization.deserialize_execution -deserialize_preprocessing = runtime_serialization.deserialize_preprocessing -_serialize_model = runtime_serialization.serialize_model -_serialize_preprocessing = runtime_serialization.serialize_preprocessing - - def _serialize_execution( execution: ExecutionOptions, *, @@ -134,7 +128,7 @@ def _log_on_the_fly_worker_override_once( ) -> None: if not preprocessing.on_the_fly or preprocessing.read_tiles_from is not None: return - if not any(_resolve_slide_backend(preprocessing, tiling_result) == "cucim" for tiling_result in tiling_results): + if not any(runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) == "cucim" for tiling_result in tiling_results): return effective_num_workers, worker_context = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) if effective_num_workers == execution.num_workers: @@ -357,7 +351,7 @@ def embed_slides( if slide_artifact is not None: slide_artifacts.append(slide_artifact) if process_list_path.is_file(): - persist_tile_embeddings = _should_persist_tile_embeddings(model, execution) + persist_tile_embeddings = runtime_embedding.should_persist_tile_embeddings(model, execution) persist_hierarchical_embeddings = _is_hierarchical_preprocessing(preprocessing) include_slide_embeddings = model.level == "slide" _update_process_list_after_embedding( @@ -626,16 +620,16 @@ def embed_tiles( preprocessing=resolved_preprocessing, execution=execution, ) - artifact = _write_hierarchical_embedding_artifact( + artifact = runtime_embedding.write_hierarchical_embedding_artifact( slide.sample_id, features, execution=execution, - metadata=_build_hierarchical_embedding_metadata( + metadata=runtime_embedding.build_hierarchical_embedding_metadata( model, tiling_result=tiling_result, image_path=slide.image_path, mask_path=slide.mask_path, - backend=_resolve_slide_backend(resolved_preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(resolved_preprocessing, tiling_result), preprocessing=resolved_preprocessing, ), ) @@ -648,15 +642,15 @@ def embed_tiles( preprocessing=resolved_preprocessing, execution=execution, ) - metadata = _build_tile_embedding_metadata( + metadata = runtime_embedding.build_tile_embedding_metadata( model, tiling_result=tiling_result, image_path=slide.image_path, mask_path=slide.mask_path, tile_size_lv0=int(tiling_result.tile_size_lv0), - backend=_resolve_slide_backend(resolved_preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(resolved_preprocessing, tiling_result), ) - artifact = _write_tile_embedding_artifact( + artifact = runtime_embedding.write_tile_embedding_artifact( slide.sample_id, features, execution=execution, @@ -689,7 +683,7 @@ def aggregate_tiles( raise ValueError( f"Tile artifact for {artifact.sample_id} is missing tiling metadata paths required for slide aggregation" ) - tiling_result = _load_tiling_result( + tiling_result = runtime_tiling.load_tiling_result_from_paths( Path(metadata["coordinates_npz_path"]), Path(metadata["coordinates_meta_path"]), ) @@ -697,7 +691,7 @@ def aggregate_tiles( coordinates = np.column_stack((x_values, y_values)) image_path = Path(metadata["image_path"]) if model.name == "prov-gigapath": - coordinates = _scale_coordinates( + coordinates = runtime_tiling.scale_coordinates( coordinates, float(tiling_result.base_spacing_um), float(tiling_result.requested_spacing_um), @@ -714,11 +708,11 @@ def aggregate_tiles( tile_size_lv0=int(tiling_result.tile_size_lv0), ) latents = None - slide_artifact = _write_slide_embedding_artifact( + slide_artifact = runtime_embedding.write_slide_embedding_artifact( artifact.sample_id, embedding, execution=execution, - metadata=_build_slide_embedding_metadata(model, image_path=metadata["image_path"]), + metadata=runtime_embedding.build_slide_embedding_metadata(model, image_path=metadata["image_path"]), latents=latents, ) outputs.append(slide_artifact) @@ -862,7 +856,7 @@ def run_pipeline( process_list_path=process_list_path, ) - persist_tile_embeddings = _should_persist_tile_embeddings(model, execution) + persist_tile_embeddings = runtime_embedding.should_persist_tile_embeddings(model, execution) persist_hierarchical_embeddings = _is_hierarchical_preprocessing(resolved_preprocessing) include_slide_embeddings = model.level == "slide" include_tile_embeddings = persist_tile_embeddings and not persist_hierarchical_embeddings @@ -1074,17 +1068,17 @@ def _run_patient_pipeline( ) if execution.save_tile_embeddings: - tile_artifact = _write_tile_embedding_artifact( + tile_artifact = runtime_embedding.write_tile_embedding_artifact( slide.sample_id, tile_embeddings, execution=execution, - metadata=_build_tile_embedding_metadata( + metadata=runtime_embedding.build_tile_embedding_metadata( model, tiling_result=tiling_result, image_path=slide.image_path, mask_path=slide.mask_path, tile_size_lv0=int(tiling_result.tile_size_lv0), - backend=_resolve_slide_backend(preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(preprocessing, tiling_result), ), ) tile_artifacts.append(tile_artifact) @@ -1098,11 +1092,11 @@ def _run_patient_pipeline( emit_progress("aggregation.finished", sample_id=slide.sample_id, has_latents=False) if execution.save_slide_embeddings: - slide_artifact = _write_slide_embedding_artifact( + slide_artifact = runtime_embedding.write_slide_embedding_artifact( slide.sample_id, slide_emb, execution=execution, - metadata=_build_slide_embedding_metadata(model, image_path=slide.image_path), + metadata=runtime_embedding.build_slide_embedding_metadata(model, image_path=slide.image_path), ) slide_artifacts.append(slide_artifact) @@ -1179,7 +1173,7 @@ def _build_incremental_persist_callback( if execution.output_dir is None: return None, tile_artifacts, slide_artifacts - persist_tile_embeddings = _should_persist_tile_embeddings(model, execution) + persist_tile_embeddings = runtime_embedding.should_persist_tile_embeddings(model, execution) persist_hierarchical_embeddings = _is_hierarchical_preprocessing(preprocessing) include_slide_embeddings = model.level == "slide" @@ -1330,7 +1324,7 @@ def _collect_distributed_pipeline_artifacts( list[HierarchicalEmbeddingArtifact], list[SlideEmbeddingArtifact], ]: - persist_tile_embeddings = _should_persist_tile_embeddings(model, execution) + persist_tile_embeddings = runtime_embedding.should_persist_tile_embeddings(model, execution) persist_hierarchical_embeddings = _is_hierarchical_preprocessing(preprocessing) include_slide_embeddings = model.level == "slide" include_tile_embeddings = persist_tile_embeddings and not persist_hierarchical_embeddings @@ -1461,7 +1455,7 @@ def _compute_tile_embeddings_for_slide( return torch.empty((0, int(feature_dim)), dtype=torch.float32) _supertile_reorder = None if preprocessing.on_the_fly and preprocessing.read_tiles_from is None: - resolved_backend = _resolve_slide_backend(preprocessing, tiling_result) + resolved_backend = runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) collate_fn = OnTheFlyBatchTileCollator( image_path=slide.image_path, tiling_result=tiling_result, @@ -1488,8 +1482,8 @@ def _compute_tile_embeddings_for_slide( logging.getLogger(__name__).warning( "read_tiles_from is set; ignoring on_the_fly=True and reading tiles from tar archives" ) - tar_path = _resolve_tile_store_archive_for_slide( - slide=slide, + tar_path = runtime_tiling.resolve_tile_store_archive_for_slide( + slide_sample_id=slide.sample_id, tiling_result=tiling_result, preprocessing=preprocessing, ) @@ -1508,7 +1502,7 @@ def _compute_tile_embeddings_for_slide( tiling_result, ) loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) - resolved_backend = _resolve_slide_backend(preprocessing, tiling_result) + resolved_backend = runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) if preprocessing.on_the_fly and preprocessing.read_tiles_from is None and resolved_backend == "cucim": effective_num_workers, _ = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) loader_kwargs["num_workers"] = effective_num_workers @@ -1577,7 +1571,7 @@ def _compute_hierarchical_embeddings_for_slide( subtile_index_within_region=index.subtile_index_within_region, read_region_size_px=int(geometry["read_region_size_px"]), read_tile_size_px=int(geometry["read_tile_size_px"]), - backend=_resolve_slide_backend(preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(preprocessing, tiling_result), num_cucim_workers=preprocessing.num_cucim_workers, gpu_decode=preprocessing.gpu_decode, ) @@ -1587,7 +1581,7 @@ def _compute_hierarchical_embeddings_for_slide( requested_tile_size_px=int(geometry["requested_tile_size_px"]), ) loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) - resolved_backend = _resolve_slide_backend(preprocessing, tiling_result) + resolved_backend = runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) if resolved_backend == "cucim": effective_num_workers, _ = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) loader_kwargs["num_workers"] = effective_num_workers @@ -1660,7 +1654,7 @@ def _compute_hierarchical_embedding_shard_for_slide( subtile_index_within_region=index.subtile_index_within_region, read_region_size_px=int(geometry["read_region_size_px"]), read_tile_size_px=int(geometry["read_tile_size_px"]), - backend=_resolve_slide_backend(preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(preprocessing, tiling_result), num_cucim_workers=preprocessing.num_cucim_workers, gpu_decode=preprocessing.gpu_decode, ) @@ -1670,7 +1664,7 @@ def _compute_hierarchical_embedding_shard_for_slide( requested_tile_size_px=int(geometry["requested_tile_size_px"]), ) loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) - resolved_backend = _resolve_slide_backend(preprocessing, tiling_result) + resolved_backend = runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) if resolved_backend == "cucim": effective_num_workers, _worker_context = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) loader_kwargs["num_workers"] = effective_num_workers @@ -1727,7 +1721,7 @@ def _aggregate_tile_embeddings_for_slide( x_values, y_values = coordinate_arrays(tiling_result) coordinates = np.column_stack((x_values, y_values)) if model.name == "prov-gigapath": - coordinates = _scale_coordinates( + coordinates = runtime_tiling.scale_coordinates( coordinates, float(tiling_result.base_spacing_um), float(tiling_result.requested_spacing_um), @@ -1799,150 +1793,58 @@ def _persist_embedded_slide( output_format=execution.output_format, feature_dim=None, num_tiles=0, - metadata=_build_tile_embedding_metadata( + metadata=runtime_embedding.build_tile_embedding_metadata( model, tiling_result=tiling_result, image_path=embedded_slide.image_path, mask_path=embedded_slide.mask_path, tile_size_lv0=embedded_slide.tile_size_lv0, - backend=_resolve_slide_backend(preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(preprocessing, tiling_result), ), ) return None, None if _is_hierarchical_preprocessing(preprocessing): - hierarchical_artifact = _write_hierarchical_embedding_artifact( + hierarchical_artifact = runtime_embedding.write_hierarchical_embedding_artifact( embedded_slide.sample_id, embedded_slide.tile_embeddings, execution=execution, - metadata=_build_hierarchical_embedding_metadata( + metadata=runtime_embedding.build_hierarchical_embedding_metadata( model, tiling_result=tiling_result, image_path=embedded_slide.image_path, mask_path=embedded_slide.mask_path, - backend=_resolve_slide_backend(preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(preprocessing, tiling_result), preprocessing=preprocessing, ), ) return hierarchical_artifact, None tile_artifact = None - if _should_persist_tile_embeddings(model, execution): - tile_artifact = _write_tile_embedding_artifact( + if runtime_embedding.should_persist_tile_embeddings(model, execution): + tile_artifact = runtime_embedding.write_tile_embedding_artifact( embedded_slide.sample_id, embedded_slide.tile_embeddings, execution=execution, - metadata=_build_tile_embedding_metadata( + metadata=runtime_embedding.build_tile_embedding_metadata( model, tiling_result=tiling_result, image_path=embedded_slide.image_path, mask_path=embedded_slide.mask_path, tile_size_lv0=embedded_slide.tile_size_lv0, - backend=_resolve_slide_backend(preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(preprocessing, tiling_result), ), ) slide_artifact = None if embedded_slide.slide_embedding is not None: - slide_artifact = _write_slide_embedding_artifact( + slide_artifact = runtime_embedding.write_slide_embedding_artifact( embedded_slide.sample_id, embedded_slide.slide_embedding, execution=execution, - metadata=_build_slide_embedding_metadata(model, image_path=embedded_slide.image_path), + metadata=runtime_embedding.build_slide_embedding_metadata(model, image_path=embedded_slide.image_path), latents=embedded_slide.latents, ) return tile_artifact, slide_artifact -def _build_tile_embedding_metadata( - model, - *, - tiling_result, - image_path: Path | str, - mask_path: Path | str | None, - tile_size_lv0: int, - backend: str, -) -> dict[str, Any]: - return runtime_embedding.build_tile_embedding_metadata( - model, - tiling_result=tiling_result, - image_path=image_path, - mask_path=mask_path, - tile_size_lv0=tile_size_lv0, - backend=backend, - ) - - -def _build_slide_embedding_metadata(model, *, image_path: Path | str) -> dict[str, Any]: - return runtime_embedding.build_slide_embedding_metadata(model, image_path=image_path) - - -def _build_hierarchical_embedding_metadata( - model, - *, - tiling_result, - image_path: Path | str, - mask_path: Path | str | None, - backend: str, - preprocessing: PreprocessingConfig, -) -> dict[str, Any]: - return runtime_embedding.build_hierarchical_embedding_metadata( - model, - tiling_result=tiling_result, - image_path=image_path, - mask_path=mask_path, - backend=backend, - preprocessing=preprocessing, - resolve_hierarchical_geometry_fn=_resolve_hierarchical_geometry, - ) - - -def _write_tile_embedding_artifact( - sample_id: str, - features, - *, - execution: ExecutionOptions, - metadata: dict[str, Any], -) -> TileEmbeddingArtifact: - return runtime_embedding.write_tile_embedding_artifact( - sample_id, - features, - execution=execution, - metadata=metadata, - num_rows_fn=_num_rows, - ) - - -def _write_slide_embedding_artifact( - sample_id: str, - embedding, - *, - execution: ExecutionOptions, - metadata: dict[str, Any], - latents=None, -) -> SlideEmbeddingArtifact: - return runtime_embedding.write_slide_embedding_artifact( - sample_id, - embedding, - execution=execution, - metadata=metadata, - latents=latents, - ) - - -def _write_hierarchical_embedding_artifact( - sample_id: str, - features, - *, - execution: ExecutionOptions, - metadata: dict[str, Any], -) -> HierarchicalEmbeddingArtifact: - return runtime_embedding.write_hierarchical_embedding_artifact( - sample_id, - features, - execution=execution, - metadata=metadata, - ) - - - def _describe_device_mode(model, execution: ExecutionOptions) -> str: requested_device = getattr(model, "_requested_device", None) if requested_device == "cpu": @@ -2062,12 +1964,12 @@ def _write_zero_tile_embedding_sidecars( np.empty((0, int(geometry["tiles_per_region"]), 0), dtype=np.float32), output_dir=output_dir, output_format=output_format, - metadata=_build_hierarchical_embedding_metadata( + metadata=runtime_embedding.build_hierarchical_embedding_metadata( model, tiling_result=tiling_result, image_path=slide.image_path, mask_path=slide.mask_path, - backend=_resolve_slide_backend(preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(preprocessing, tiling_result), preprocessing=preprocessing, ), ) @@ -2078,13 +1980,13 @@ def _write_zero_tile_embedding_sidecars( output_format=output_format, feature_dim=None, num_tiles=0, - metadata=_build_tile_embedding_metadata( + metadata=runtime_embedding.build_tile_embedding_metadata( model=model, tiling_result=tiling_result, image_path=slide.image_path, mask_path=slide.mask_path, tile_size_lv0=int(tiling_result.tile_size_lv0), - backend=_resolve_slide_backend(preprocessing, tiling_result), + backend=runtime_tiling.resolve_slide_backend(preprocessing, tiling_result), ), ) @@ -2123,10 +2025,6 @@ def _emit_tiling_summary( ) -def _should_persist_tile_embeddings(model, execution: ExecutionOptions) -> bool: - return runtime_embedding.should_persist_tile_embeddings(model, execution) - - def _resolved_process_list_output_variant(model) -> str | None: requested_output_variant = getattr(model, "_output_variant", None) if not hasattr(model, "name") or model.name not in encoder_registry: @@ -2239,7 +2137,7 @@ def _tile_slides( num_workers: int, ) -> list[Any]: _preload_asap_wholeslidedata(preprocessing) - tiling_cfg, segmentation_cfg, filtering_cfg, preview_cfg, read_coordinates_from, resume = _build_hs2p_configs(preprocessing) + tiling_cfg, segmentation_cfg, filtering_cfg, preview_cfg, read_coordinates_from, resume = runtime_tiling.build_hs2p_configs(preprocessing) def _run_tile_slides(): return tile_slides( @@ -2262,7 +2160,7 @@ def _run_tile_slides(): def _preload_asap_wholeslidedata(preprocessing: PreprocessingConfig) -> None: """Load wholeslidedata quietly so ASAP backend import noise stays off stderr.""" - if _resolve_tiling_backend(preprocessing) != "asap": + if runtime_tiling.resolve_tiling_backend(preprocessing) != "asap": return with suppress_c_stderr(): try: @@ -2344,61 +2242,6 @@ def _resolve_path_str(value: Any) -> str | None: process_df.to_csv(process_list_path, index=False) -def _build_preview_config(preview: dict[str, Any]) -> PreviewConfig: - return runtime_tiling.build_preview_config(preview, preview_config_cls=PreviewConfig) - - -def _build_hs2p_configs(preprocessing: PreprocessingConfig): - return runtime_tiling.build_hs2p_configs( - preprocessing, - is_hierarchical_preprocessing_fn=_is_hierarchical_preprocessing, - resolve_tiling_backend_fn=_resolve_tiling_backend, - tiling_config_cls=TilingConfig, - segmentation_config_cls=SegmentationConfig, - filter_config_cls=FilterConfig, - preview_config_cls=PreviewConfig, - ) - - -def _resolve_tile_store_archive_for_slide( - *, - slide: SlideSpec, - tiling_result, - preprocessing: PreprocessingConfig, -) -> Path | None: - return runtime_tiling.resolve_tile_store_archive_for_slide( - slide_sample_id=slide.sample_id, - tiling_result=tiling_result, - preprocessing=preprocessing, - ) - - -def _tile_store_archive_path(tile_store_root: Path, sample_id: str) -> Path: - return runtime_tiling.tile_store_archive_path(tile_store_root, sample_id) - - - -def _load_tiling_result(coordinates_npz_path: Path, coordinates_meta_path: Path): - return runtime_tiling.load_tiling_result( - coordinates_npz_path, - coordinates_meta_path, - load_tiling_result_fn=load_tiling_result, - ) - - -def _scale_coordinates(coordinates: np.ndarray, base_spacing_um: float, spacing: float) -> np.ndarray: - return runtime_tiling.scale_coordinates(coordinates, base_spacing_um, spacing) - - - -def _resolve_tiling_backend(preprocessing: PreprocessingConfig | None) -> str: - return runtime_tiling.resolve_tiling_backend(preprocessing) - - -def _resolve_slide_backend(preprocessing: PreprocessingConfig | None, tiling_result) -> str: - return runtime_tiling.resolve_slide_backend(preprocessing, tiling_result) - - def _resolve_model_preprocessing(model, preprocessing: PreprocessingConfig | None) -> PreprocessingConfig: defaults = None @@ -2504,16 +2347,16 @@ def _embed_single_slide_distributed( sample_id=slide.sample_id, ) if _is_hierarchical_preprocessing(preprocessing): - shard_payloads = _load_hierarchical_embedding_shards(coordination_dir, slide.sample_id) + shard_payloads = runtime_distributed.load_hierarchical_embedding_shards(coordination_dir, slide.sample_id) geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) - tile_embeddings = _merge_hierarchical_embedding_shards( + tile_embeddings = runtime_distributed.merge_hierarchical_embedding_shards( shard_payloads, num_regions=_num_tiles(tiling_result), tiles_per_region=int(geometry["tiles_per_region"]), ) else: - shard_payloads = _load_tile_embedding_shards(coordination_dir, slide.sample_id) - tile_embeddings = _merge_tile_embedding_shards(shard_payloads) + shard_payloads = runtime_distributed.load_tile_embedding_shards(coordination_dir, slide.sample_id) + tile_embeddings = runtime_distributed.merge_tile_embedding_shards(shard_payloads) if model.level != "slide": return _make_embedded_slide( slide=slide, @@ -2548,7 +2391,11 @@ def _embed_multi_slides_distributed( execution: ExecutionOptions, work_dir: Path, ) -> list[EmbeddedSlide]: - assignments = _assign_slides_to_ranks(slide_records, tiling_results, num_gpus=execution.num_gpus) + assignments = runtime_distributed.assign_slides_to_ranks( + slide_records, + tiling_results, + num_gpus=execution.num_gpus, + ) with _distributed_coordination_dir(work_dir) as coordination_dir: _run_distributed_direct_embedding_stage( model, @@ -2561,7 +2408,7 @@ def _embed_multi_slides_distributed( ) results = [] for slide, tiling_result in zip(slide_records, tiling_results): - payload = _load_embedded_slide_payload(coordination_dir, slide.sample_id) + payload = runtime_distributed.load_embedded_slide_payload(coordination_dir, slide.sample_id) slide_embedding = payload["slide_embedding"] if "slide_embedding" in payload else None latents = payload["latents"] if "latents" in payload else None results.append( @@ -2645,8 +2492,8 @@ def _build_pipeline_worker_request_payload( progress_events_path: Path | None = None, ) -> dict[str, Any]: return { - "model": _serialize_model(model), - "preprocessing": _serialize_preprocessing(preprocessing), + "model": runtime_serialization.serialize_model(model), + "preprocessing": runtime_serialization.serialize_preprocessing(preprocessing), "execution": _serialize_execution(execution, preprocessing=preprocessing), "progress_events_path": str(progress_events_path) if progress_events_path is not None else None, } @@ -2659,8 +2506,8 @@ def _write_embedding_request( output_dir: Path, ) -> None: payload = { - "model": _serialize_model(model), - "preprocessing": _serialize_preprocessing(preprocessing), + "model": runtime_serialization.serialize_model(model), + "preprocessing": runtime_serialization.serialize_preprocessing(preprocessing), "execution": _serialize_execution(execution, preprocessing=preprocessing), } request_path = output_dir / "embedding_request.json" @@ -2680,8 +2527,8 @@ def _build_direct_embed_worker_request_payload( ) -> dict[str, Any]: return { "strategy": strategy, - "model": _serialize_model(model), - "preprocessing": _serialize_preprocessing(preprocessing), + "model": runtime_serialization.serialize_model(model), + "preprocessing": runtime_serialization.serialize_preprocessing(preprocessing), "execution": _serialize_execution(execution, preprocessing=preprocessing), "coordination_dir": str(coordination_dir), "sample_id": sample_id, @@ -2694,58 +2541,6 @@ def _reset_progress_event_logs(progress_events_path: Path) -> None: runtime_distributed.reset_progress_event_logs(progress_events_path) -def _drain_stream_to_buffer(stream, chunks: list[str]) -> None: - # Compatibility shim for tests monkeypatching this helper in slide2vec.inference. - runtime_distributed.drain_stream_to_buffer(stream, chunks) - - -def _write_worker_logs(module: str, output_dir: Path, stdout_text: str, stderr_text: str) -> tuple[Path, Path]: - return runtime_distributed.write_worker_logs(module, output_dir, stdout_text, stderr_text) - - -def _assign_slides_to_ranks( - slide_records: Sequence[SlideSpec], - tiling_results, - *, - num_gpus: int, -) -> dict[int, list[str]]: - return runtime_distributed.assign_slides_to_ranks( - slide_records, - tiling_results, - num_gpus=num_gpus, - num_tiles_fn=_num_tiles, - ) - - -def _merge_tile_embedding_shards(shard_payloads): - return runtime_distributed.merge_tile_embedding_shards(shard_payloads) - - -def _merge_hierarchical_embedding_shards( - shard_payloads, - *, - num_regions: int, - tiles_per_region: int, -): - return runtime_distributed.merge_hierarchical_embedding_shards( - shard_payloads, - num_regions=num_regions, - tiles_per_region=tiles_per_region, - ) - - -def _load_tile_embedding_shards(coordination_dir: Path, sample_id: str): - return runtime_distributed.load_tile_embedding_shards(coordination_dir, sample_id) - - -def _load_hierarchical_embedding_shards(coordination_dir: Path, sample_id: str): - return runtime_distributed.load_hierarchical_embedding_shards(coordination_dir, sample_id) - - -def _load_embedded_slide_payload(coordination_dir: Path, sample_id: str): - return runtime_distributed.load_embedded_slide_payload(coordination_dir, sample_id) - - def load_successful_tiled_slides(output_dir: str | Path) -> tuple[list[SlideSpec], list[Any]]: base_dir = Path(output_dir) process_df = load_tiling_process_df(base_dir / "process_list.csv") diff --git a/slide2vec/runtime/distributed.py b/slide2vec/runtime/distributed.py index cc83064..c5de80f 100644 --- a/slide2vec/runtime/distributed.py +++ b/slide2vec/runtime/distributed.py @@ -16,6 +16,7 @@ from hs2p import SlideSpec from slide2vec.progress import emit_progress_event, read_progress_events +from slide2vec.runtime.hierarchical import num_tiles @contextmanager @@ -126,18 +127,17 @@ def assign_slides_to_ranks( tiling_results, *, num_gpus: int, - num_tiles_fn, ) -> dict[int, list[str]]: assignments: dict[int, list[str]] = {rank: [] for rank in range(num_gpus)} assigned_ranks = [(0, rank) for rank in range(num_gpus)] heapq.heapify(assigned_ranks) sortable = [] for slide, tiling_result in zip(slide_records, tiling_results): - sortable.append((slide.sample_id, num_tiles_fn(tiling_result))) - for sample_id, num_tiles in sorted(sortable, key=lambda item: (-item[1], item[0])): + sortable.append((slide.sample_id, num_tiles(tiling_result))) + for sample_id, tile_count in sorted(sortable, key=lambda item: (-item[1], item[0])): assigned_tiles, rank = heapq.heappop(assigned_ranks) assignments[rank].append(sample_id) - heapq.heappush(assigned_ranks, (assigned_tiles + int(num_tiles), rank)) + heapq.heappush(assigned_ranks, (assigned_tiles + int(tile_count), rank)) return assignments diff --git a/slide2vec/runtime/embedding.py b/slide2vec/runtime/embedding.py index 01e16d9..978e2e4 100644 --- a/slide2vec/runtime/embedding.py +++ b/slide2vec/runtime/embedding.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Callable +from typing import Any import numpy as np @@ -14,6 +14,7 @@ write_slide_embeddings, write_tile_embeddings, ) +from slide2vec.runtime.hierarchical import resolve_hierarchical_geometry def should_persist_tile_embeddings(model, execution: ExecutionOptions) -> bool: @@ -67,7 +68,6 @@ def build_hierarchical_embedding_metadata( mask_path: Path | str | None, backend: str, preprocessing: PreprocessingConfig, - resolve_hierarchical_geometry_fn: Callable[[PreprocessingConfig, Any], dict[str, int]], ) -> dict[str, Any]: coordinates_npz_path = ( tiling_result.coordinates_npz_path if hasattr(tiling_result, "coordinates_npz_path") else None @@ -75,7 +75,7 @@ def build_hierarchical_embedding_metadata( coordinates_meta_path = ( tiling_result.coordinates_meta_path if hasattr(tiling_result, "coordinates_meta_path") else None ) - geometry = resolve_hierarchical_geometry_fn(preprocessing, tiling_result) + geometry = resolve_hierarchical_geometry(preprocessing, tiling_result) return { "encoder_name": model.name, "encoder_level": model.level, @@ -100,7 +100,6 @@ def write_tile_embedding_artifact( *, execution: ExecutionOptions, metadata: dict[str, Any], - num_rows_fn: Callable[[Any], int], ) -> TileEmbeddingArtifact: if execution.output_dir is None: raise ValueError("ExecutionOptions.output_dir is required to persist tile embeddings") @@ -110,10 +109,16 @@ def write_tile_embedding_artifact( output_dir=execution.output_dir, output_format=execution.output_format, metadata=metadata, - tile_index=np.arange(num_rows_fn(features), dtype=np.int64), + tile_index=np.arange(_num_rows(features), dtype=np.int64), ) +def _num_rows(data: Any) -> int: + if hasattr(data, "shape") and len(data.shape) >= 1: + return int(data.shape[0]) + return len(data) + + def write_slide_embedding_artifact( sample_id: str, embedding, @@ -150,4 +155,3 @@ def write_hierarchical_embedding_artifact( output_format=execution.output_format, metadata=metadata, ) - diff --git a/slide2vec/runtime/tiling.py b/slide2vec/runtime/tiling.py index ecd5beb..289901c 100644 --- a/slide2vec/runtime/tiling.py +++ b/slide2vec/runtime/tiling.py @@ -1,11 +1,13 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Callable +from typing import Any import numpy as np +from hs2p import FilterConfig, PreviewConfig, SegmentationConfig, TilingConfig, load_tiling_result from slide2vec.api import PreprocessingConfig +from slide2vec.runtime.hierarchical import is_hierarchical_preprocessing def resolve_tiling_backend(preprocessing: PreprocessingConfig | None) -> str: @@ -24,42 +26,38 @@ def resolve_slide_backend(preprocessing: PreprocessingConfig | None, tiling_resu return "asap" -def build_preview_config(preview: dict[str, Any], *, preview_config_cls): - return preview_config_cls( +def build_preview_config(preview: dict[str, Any]) -> PreviewConfig: + overlay_color = preview.get("mask_overlay_color") + if overlay_color is None: + overlay_color = preview["tissue_contour_color"] + return PreviewConfig( save_mask_preview=bool(preview["save_mask_preview"]), save_tiling_preview=bool(preview["save_tiling_preview"]), downsample=int(preview["downsample"]), - tissue_contour_color=tuple(int(channel) for channel in preview["tissue_contour_color"]), + mask_overlay_color=tuple(int(channel) for channel in overlay_color), mask_overlay_alpha=float(preview["mask_overlay_alpha"]), ) def build_hs2p_configs( preprocessing: PreprocessingConfig, - *, - is_hierarchical_preprocessing_fn: Callable[[PreprocessingConfig | None], bool], - resolve_tiling_backend_fn: Callable[[PreprocessingConfig | None], str], - tiling_config_cls, - segmentation_config_cls, - filter_config_cls, - preview_config_cls, ): requested_tile_size_px = ( preprocessing.requested_region_size_px - if is_hierarchical_preprocessing_fn(preprocessing) + if is_hierarchical_preprocessing(preprocessing) else preprocessing.requested_tile_size_px ) - tiling_cfg = tiling_config_cls( - backend=resolve_tiling_backend_fn(preprocessing), + tiling_cfg = TilingConfig( + backend=resolve_tiling_backend(preprocessing), requested_spacing_um=preprocessing.requested_spacing_um, requested_tile_size_px=requested_tile_size_px, tolerance=preprocessing.tolerance, overlap=preprocessing.overlap, tissue_threshold=preprocessing.tissue_threshold, ) - segmentation_cfg = segmentation_config_cls(**dict(preprocessing.segmentation)) - filtering_cfg = filter_config_cls(**dict(preprocessing.filtering)) - preview_cfg = build_preview_config(dict(preprocessing.preview), preview_config_cls=preview_config_cls) + segmentation_cfg = SegmentationConfig(**dict(preprocessing.segmentation)) + filtering_cfg = FilterConfig(**dict(preprocessing.filtering)) + preview_cfg = build_preview_config(dict(preprocessing.preview)) return ( tiling_cfg, segmentation_cfg, @@ -90,11 +88,13 @@ def resolve_tile_store_archive_for_slide( return tiling_result.tiles_tar_path if hasattr(tiling_result, "tiles_tar_path") else None -def load_tiling_result(coordinates_npz_path: Path, coordinates_meta_path: Path, *, load_tiling_result_fn): - return load_tiling_result_fn(coordinates_npz_path=coordinates_npz_path, coordinates_meta_path=coordinates_meta_path) +def load_tiling_result_from_paths(coordinates_npz_path: Path, coordinates_meta_path: Path): + return load_tiling_result( + coordinates_npz_path=coordinates_npz_path, + coordinates_meta_path=coordinates_meta_path, + ) def scale_coordinates(coordinates: np.ndarray, base_spacing_um: float, spacing: float) -> np.ndarray: scale = base_spacing_um / spacing return (coordinates * scale).astype(int) - diff --git a/tasks/lessons.md b/tasks/lessons.md index ec3d646..3b7b027 100644 --- a/tasks/lessons.md +++ b/tasks/lessons.md @@ -2,6 +2,7 @@ ## 2026-04-18 +- When internal runtime modules already own the implementation, call them directly from orchestrators; avoid pass-through wrappers or alias assignments in `inference.py` that only forward arguments unchanged. - Prefer neutral package names like `runtime/` for internal implementation modules unless the user explicitly wants a private-style namespace; leading underscores in directory names read as accidental or overly internal. - When slide2vec depends on bridged HS2P progress events, keep the bridge whitelist in sync with every reporter stage the UI renders; otherwise the code can define a preview bar and still never receive preview events. diff --git a/tests/test_regression_inference.py b/tests/test_regression_inference.py index dfcd28d..312dd1b 100644 --- a/tests/test_regression_inference.py +++ b/tests/test_regression_inference.py @@ -903,8 +903,8 @@ def fake_tile_slides(slides, **kwargs): monkeypatch.setattr(inference, "tile_slides", fake_tile_slides) monkeypatch.setattr( - inference, - "_build_hs2p_configs", + inference.runtime_tiling, + "build_hs2p_configs", lambda preprocessing: ( SimpleNamespace(requested_backend="cucim"), "segmentation", @@ -946,8 +946,8 @@ def fake_tile_slides(slides, **kwargs): monkeypatch.setattr(inference, "tile_slides", fake_tile_slides) monkeypatch.setattr( - inference, - "_build_hs2p_configs", + inference.runtime_tiling, + "build_hs2p_configs", lambda preprocessing: ( SimpleNamespace(requested_backend="auto"), "segmentation", @@ -1049,8 +1049,8 @@ def fake_tile_slides(slides, **kwargs): assert not hasattr(inference, "resolve_backend") monkeypatch.setattr(inference, "tile_slides", fake_tile_slides) monkeypatch.setattr( - inference, - "_build_hs2p_configs", + inference.runtime_tiling, + "build_hs2p_configs", lambda preprocessing: ( SimpleNamespace(requested_backend="auto"), "segmentation", @@ -1084,29 +1084,8 @@ def fake_tile_slides(slides, **kwargs): ] -def test_build_hs2p_configs_constructs_preview_config(monkeypatch): - import slide2vec.inference as inference - - class FakeTilingConfig: - def __init__(self, **kwargs): - self.kwargs = kwargs - - class FakeSegmentationConfig: - def __init__(self, **kwargs): - self.kwargs = kwargs - - class FakeFilterConfig: - def __init__(self, **kwargs): - self.kwargs = kwargs - - class FakePreviewConfig: - def __init__(self, **kwargs): - self.kwargs = kwargs - - monkeypatch.setattr(inference, "TilingConfig", FakeTilingConfig) - monkeypatch.setattr(inference, "SegmentationConfig", FakeSegmentationConfig) - monkeypatch.setattr(inference, "FilterConfig", FakeFilterConfig) - monkeypatch.setattr(inference, "PreviewConfig", FakePreviewConfig) +def test_build_hs2p_configs_constructs_preview_config(): + import slide2vec.runtime.tiling as runtime_tiling preprocessing = PreprocessingConfig( backend="asap", @@ -1127,19 +1106,17 @@ def __init__(self, **kwargs): ) tiling_cfg, segmentation_cfg, filtering_cfg, preview_cfg, read_coordinates_from, resume = ( - inference._build_hs2p_configs(preprocessing) - ) - - assert tiling_cfg.kwargs["backend"] == "asap" - assert segmentation_cfg.kwargs == {"downsample": 64} - assert filtering_cfg.kwargs == {"ref_tile_size": 224} - assert preview_cfg.kwargs == { - "save_mask_preview": True, - "save_tiling_preview": False, - "downsample": 32, - "tissue_contour_color": (157, 219, 129), - "mask_overlay_alpha": 0.5, - } + runtime_tiling.build_hs2p_configs(preprocessing) + ) + + assert tiling_cfg.backend == "asap" + assert segmentation_cfg.downsample == 64 + assert filtering_cfg.ref_tile_size == 224 + assert preview_cfg.save_mask_preview is True + assert preview_cfg.save_tiling_preview is False + assert preview_cfg.downsample == 32 + assert preview_cfg.mask_overlay_color == (157, 219, 129) + assert preview_cfg.mask_overlay_alpha == pytest.approx(0.5) assert read_coordinates_from is None assert resume is False @@ -1244,12 +1221,12 @@ def test_record_slide_metadata_in_process_list_adds_backend_columns(monkeypatch, def test_resolve_slide_backend_uses_tiling_result_backend_for_auto(): - import slide2vec.inference as inference + import slide2vec.runtime.tiling as runtime_tiling - assert inference._resolve_slide_backend(replace(DEFAULT_PREPROCESSING, backend="auto"), SimpleNamespace(backend="cucim")) == "cucim" - assert inference._resolve_slide_backend(replace(DEFAULT_PREPROCESSING, backend="auto"), SimpleNamespace(backend="asap")) == "asap" - assert inference._resolve_slide_backend(replace(DEFAULT_PREPROCESSING, backend="auto"), SimpleNamespace()) == "asap" - assert inference._resolve_slide_backend(replace(DEFAULT_PREPROCESSING, backend="cucim"), SimpleNamespace(backend="asap")) == "cucim" + assert runtime_tiling.resolve_slide_backend(replace(DEFAULT_PREPROCESSING, backend="auto"), SimpleNamespace(backend="cucim")) == "cucim" + assert runtime_tiling.resolve_slide_backend(replace(DEFAULT_PREPROCESSING, backend="auto"), SimpleNamespace(backend="asap")) == "asap" + assert runtime_tiling.resolve_slide_backend(replace(DEFAULT_PREPROCESSING, backend="auto"), SimpleNamespace()) == "asap" + assert runtime_tiling.resolve_slide_backend(replace(DEFAULT_PREPROCESSING, backend="cucim"), SimpleNamespace(backend="asap")) == "cucim" def test_preload_asap_wholeslidedata_suppresses_noisy_import(monkeypatch, capfd): @@ -1368,8 +1345,8 @@ def fake_coordination_dir(work_dir: Path): monkeypatch.setattr(inference, "_distributed_coordination_dir", fake_coordination_dir) monkeypatch.setattr(inference, "_run_distributed_direct_embedding_stage", lambda *args, **kwargs: None) monkeypatch.setattr( - inference, - "_load_tile_embedding_shards", + inference.runtime_distributed, + "load_tile_embedding_shards", lambda *_args, **_kwargs: [ { "tile_index": np.array([0, 1], dtype=np.int64), @@ -1432,8 +1409,8 @@ def fake_coordination_dir(work_dir: Path): monkeypatch.setattr(inference, "_distributed_coordination_dir", fake_coordination_dir) monkeypatch.setattr(inference, "_run_distributed_direct_embedding_stage", lambda *args, **kwargs: None) monkeypatch.setattr( - inference, - "_load_tile_embedding_shards", + inference.runtime_distributed, + "load_tile_embedding_shards", lambda *_args, **_kwargs: [ { "tile_index": np.array([0, 1], dtype=np.int64), @@ -1442,8 +1419,8 @@ def fake_coordination_dir(work_dir: Path): ], ) monkeypatch.setattr( - inference, - "_merge_tile_embedding_shards", + inference.runtime_distributed, + "merge_tile_embedding_shards", lambda shard_payloads: shard_payloads[0]["tile_embeddings"], ) @@ -2007,7 +1984,7 @@ def test_pipeline_worker_assigns_slides_by_tile_count(): SimpleNamespace(x=np.arange(6), y=np.arange(6), tile_size_lv0=224), ] - assignments = pipeline_worker._assign_slides_to_ranks(slides, tiling_results, num_gpus=2) + assignments = pipeline_worker.assign_slides_to_ranks(slides, tiling_results, num_gpus=2) assert assignments == { 0: ["slide-a", "slide-d"], @@ -2015,7 +1992,7 @@ def test_pipeline_worker_assigns_slides_by_tile_count(): } def test_assign_slides_to_ranks_balances_by_tile_count(): - import slide2vec.inference as inference + from slide2vec.runtime.distributed import assign_slides_to_ranks slides = [ make_slide("slide-a"), @@ -2030,7 +2007,7 @@ def test_assign_slides_to_ranks_balances_by_tile_count(): SimpleNamespace(x=np.arange(6), y=np.arange(6), tile_size_lv0=224), ] - assignments = inference._assign_slides_to_ranks(slides, tiling_results, num_gpus=2) + assignments = assign_slides_to_ranks(slides, tiling_results, num_gpus=2) assert assignments == { 0: ["slide-a", "slide-d"], @@ -2039,7 +2016,7 @@ def test_assign_slides_to_ranks_balances_by_tile_count(): def test_assign_slides_to_ranks_tiebreaks_by_rank_deterministically(): - import slide2vec.inference as inference + from slide2vec.runtime.distributed import assign_slides_to_ranks slides = [ make_slide("slide-a"), @@ -2056,7 +2033,7 @@ def test_assign_slides_to_ranks_tiebreaks_by_rank_deterministically(): SimpleNamespace(x=np.arange(1), y=np.arange(1), tile_size_lv0=224), ] - assignments = inference._assign_slides_to_ranks(slides, tiling_results, num_gpus=3) + assignments = assign_slides_to_ranks(slides, tiling_results, num_gpus=3) assert assignments == { 0: ["slide-a", "slide-d"], @@ -2066,9 +2043,9 @@ def test_assign_slides_to_ranks_tiebreaks_by_rank_deterministically(): def test_merge_tile_embedding_shards_restores_original_tile_order(): - import slide2vec.inference as inference + from slide2vec.runtime.distributed import merge_tile_embedding_shards - merged = inference._merge_tile_embedding_shards( + merged = merge_tile_embedding_shards( [ { "tile_index": np.array([2, 0], dtype=np.int64), @@ -2207,6 +2184,7 @@ def encode_tiles(self, image): def test_serialize_execution_preserves_loader_optimization_fields(): import slide2vec.inference as inference + from slide2vec.runtime.serialization import deserialize_execution execution = ExecutionOptions( output_dir=Path("/tmp/output"), @@ -2221,7 +2199,7 @@ def test_serialize_execution_preserves_loader_optimization_fields(): ) payload = inference._serialize_execution(execution) - restored = inference.deserialize_execution(payload) + restored = deserialize_execution(payload) assert payload["prefetch_factor"] == 7 assert payload["persistent_workers"] is False @@ -2232,17 +2210,17 @@ def test_serialize_execution_preserves_loader_optimization_fields(): def test_deserialize_execution_defaults_num_workers_to_auto(): - import slide2vec.inference as inference + from slide2vec.runtime.serialization import deserialize_execution - restored = inference.deserialize_execution({"batch_size": 4, "num_gpus": 1}) + restored = deserialize_execution({"batch_size": 4, "num_gpus": 1}) assert restored.num_workers is None def test_deserialize_execution_preserves_auto_num_workers(): - import slide2vec.inference as inference + from slide2vec.runtime.serialization import deserialize_execution - restored = inference.deserialize_execution({"batch_size": 4, "num_workers": None, "num_gpus": 1}) + restored = deserialize_execution({"batch_size": 4, "num_workers": None, "num_gpus": 1}) assert restored.num_workers is None @@ -2940,8 +2918,8 @@ def test_persist_embedded_slide_records_resolved_backend_when_auto(monkeypatch, captured = {} monkeypatch.setattr( - inference, - "_write_tile_embedding_artifact", + inference.runtime_embedding, + "write_tile_embedding_artifact", lambda sample_id, features, *, execution, metadata: captured.setdefault("metadata", metadata) or SimpleNamespace(), ) @@ -3097,9 +3075,9 @@ def test_build_hierarchical_index_uses_tile_first_level0_offsets_under_spacing_m def test_merge_hierarchical_embedding_shards_restores_original_region_shape(): - import slide2vec.inference as inference + from slide2vec.runtime.distributed import merge_hierarchical_embedding_shards - merged = inference._merge_hierarchical_embedding_shards( + merged = merge_hierarchical_embedding_shards( [ { "flat_index": np.array([2, 0, 7], dtype=np.int64), @@ -3304,17 +3282,17 @@ def to(self, device): def test_scale_coordinates_scales_down(): - from slide2vec.inference import _scale_coordinates + from slide2vec.runtime.tiling import scale_coordinates coords = np.array([[10, 20], [30, 40]]) # base=0.25, target=0.5 → scale=0.5 → coordinates halved - result = _scale_coordinates(coords, base_spacing_um=0.25, spacing=0.5) + result = scale_coordinates(coords, base_spacing_um=0.25, spacing=0.5) np.testing.assert_array_equal(result, [[5, 10], [15, 20]]) def test_scale_coordinates_identity_when_spacings_equal(): - from slide2vec.inference import _scale_coordinates + from slide2vec.runtime.tiling import scale_coordinates coords = np.array([[10, 20], [30, 40]]) - result = _scale_coordinates(coords, base_spacing_um=0.5, spacing=0.5) + result = scale_coordinates(coords, base_spacing_um=0.5, spacing=0.5) np.testing.assert_array_equal(result, [[10, 20], [30, 40]]) From abb52a9cbcb80fda2c9d774a6d991b66a7c1d92e Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 22:14:05 +0200 Subject: [PATCH 10/11] Drop stale distributed wrappers and remove preview compat branch --- docs/documentation.md | 2 ++ slide2vec/inference.py | 51 ++++++------------------------ slide2vec/runtime/tiling.py | 5 +-- tests/test_progress.py | 11 ++++--- tests/test_regression_inference.py | 4 +-- 5 files changed, 21 insertions(+), 52 deletions(-) diff --git a/docs/documentation.md b/docs/documentation.md index 98b61cc..ed3c12e 100644 --- a/docs/documentation.md +++ b/docs/documentation.md @@ -4,6 +4,8 @@ - Split `slide2vec.inference` into workflow-scoped internal runtime helpers under `slide2vec.runtime` (`batching`, `hierarchical`, `progress_bridge`, `serialization`, `types`) while keeping `slide2vec.inference` as the stable orchestration entrypoint. - Removed remaining pass-through wrappers from `slide2vec.inference` (tiling/embedding/distributed helpers) and switched workers/tests to import `slide2vec.runtime.*` directly, with simplified runtime helper signatures replacing callback/class injection patterns. +- Removed the leftover distributed orchestration pass-through layer in `slide2vec.inference` (`_distributed_coordination_dir`, `_run_torchrun_worker`, `_reset_progress_event_logs`) so torchrun/progress coordination now calls `slide2vec.runtime.distributed` directly. +- Dropped the temporary preview-key compatibility branch in `runtime.tiling.build_preview_config`; preview config now uses the canonical `tissue_contour_color` input only. - Added architecture guardrail tests that keep workflow helpers bounded (soft target around 400 lines, enforced ceiling 500) and prevent `slide2vec/inference.py` from regressing toward the previous monolith size. - Extracted distributed torchrun orchestration, shard merge/loading, and rank-assignment helpers into `slide2vec.runtime.distributed`, with inference-level compatibility shims preserved for existing tests and monkeypatch patterns. - Moved artifact collection/loading and process-list embedding status updates into `slide2vec.runtime.persistence` so the pipeline orchestration flow in `slide2vec.inference` stays focused on control flow. diff --git a/slide2vec/inference.py b/slide2vec/inference.py index 4dac491..5002955 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -1,7 +1,6 @@ import json import importlib import os -import subprocess import tempfile import threading import time @@ -2299,7 +2298,7 @@ def _run_distributed_embedding_stage( return request_path = output_dir / "embedding_request.json" progress_events_path = output_dir / "logs" / "pipeline_worker.progress.jsonl" - _reset_progress_event_logs(progress_events_path) + runtime_distributed.reset_progress_event_logs(progress_events_path) request_payload = _build_pipeline_worker_request_payload( model, preprocessing, @@ -2312,13 +2311,14 @@ def _run_distributed_embedding_stage( slide_count=len(successful_slides), num_gpus=execution.num_gpus, ) - _run_torchrun_worker( + runtime_distributed.run_torchrun_worker( module="slide2vec.distributed.pipeline_worker", - execution=execution, + num_gpus=execution.num_gpus, output_dir=output_dir, request_path=request_path, failure_title="Distributed feature extraction failed", progress_events_path=progress_events_path, + popen_factory=runtime_distributed.subprocess.Popen, ) emit_progress( "embedding.assignment.finished", @@ -2336,7 +2336,7 @@ def _embed_single_slide_distributed( execution: ExecutionOptions, work_dir: Path, ) -> EmbeddedSlide: - with _distributed_coordination_dir(work_dir) as coordination_dir: + with runtime_distributed.distributed_coordination_dir(work_dir) as coordination_dir: _run_distributed_direct_embedding_stage( model, preprocessing=preprocessing, @@ -2396,7 +2396,7 @@ def _embed_multi_slides_distributed( tiling_results, num_gpus=execution.num_gpus, ) - with _distributed_coordination_dir(work_dir) as coordination_dir: + with runtime_distributed.distributed_coordination_dir(work_dir) as coordination_dir: _run_distributed_direct_embedding_stage( model, preprocessing=preprocessing, @@ -2423,12 +2423,6 @@ def _embed_multi_slides_distributed( return results -@contextmanager -def _distributed_coordination_dir(work_dir: Path): - with runtime_distributed.distributed_coordination_dir(work_dir) as coordination_dir: - yield coordination_dir - - def _run_distributed_direct_embedding_stage( model, *, @@ -2442,7 +2436,7 @@ def _run_distributed_direct_embedding_stage( ) -> None: request_path = coordination_dir / "direct_embedding_request.json" progress_events_path = output_dir / "logs" / "direct_embed_worker.progress.jsonl" - _reset_progress_event_logs(progress_events_path) + runtime_distributed.reset_progress_event_logs(progress_events_path) request_payload = _build_direct_embed_worker_request_payload( model=model, preprocessing=preprocessing, @@ -2454,33 +2448,14 @@ def _run_distributed_direct_embedding_stage( progress_events_path=progress_events_path, ) request_path.write_text(json.dumps(request_payload, indent=2, sort_keys=True), encoding="utf-8") - _run_torchrun_worker( - module="slide2vec.distributed.direct_embed_worker", - execution=execution, - output_dir=output_dir, - request_path=request_path, - failure_title="Distributed direct embedding failed", - progress_events_path=progress_events_path, - ) - - -def _run_torchrun_worker( - *, - module: str, - execution: ExecutionOptions, - output_dir: Path, - request_path: Path, - failure_title: str, - progress_events_path: Path | None = None, -) -> None: runtime_distributed.run_torchrun_worker( - module=module, + module="slide2vec.distributed.direct_embed_worker", num_gpus=execution.num_gpus, output_dir=output_dir, request_path=request_path, - failure_title=failure_title, + failure_title="Distributed direct embedding failed", progress_events_path=progress_events_path, - popen_factory=subprocess.Popen, + popen_factory=runtime_distributed.subprocess.Popen, ) @@ -2535,12 +2510,6 @@ def _build_direct_embed_worker_request_payload( "assignments": {str(rank): sample_ids for rank, sample_ids in (assignments or {}).items()}, "progress_events_path": str(progress_events_path) if progress_events_path is not None else None, } - - -def _reset_progress_event_logs(progress_events_path: Path) -> None: - runtime_distributed.reset_progress_event_logs(progress_events_path) - - def load_successful_tiled_slides(output_dir: str | Path) -> tuple[list[SlideSpec], list[Any]]: base_dir = Path(output_dir) process_df = load_tiling_process_df(base_dir / "process_list.csv") diff --git a/slide2vec/runtime/tiling.py b/slide2vec/runtime/tiling.py index 289901c..6047ae5 100644 --- a/slide2vec/runtime/tiling.py +++ b/slide2vec/runtime/tiling.py @@ -27,14 +27,11 @@ def resolve_slide_backend(preprocessing: PreprocessingConfig | None, tiling_resu def build_preview_config(preview: dict[str, Any]) -> PreviewConfig: - overlay_color = preview.get("mask_overlay_color") - if overlay_color is None: - overlay_color = preview["tissue_contour_color"] return PreviewConfig( save_mask_preview=bool(preview["save_mask_preview"]), save_tiling_preview=bool(preview["save_tiling_preview"]), downsample=int(preview["downsample"]), - mask_overlay_color=tuple(int(channel) for channel in overlay_color), + mask_overlay_color=tuple(int(channel) for channel in preview["tissue_contour_color"]), mask_overlay_alpha=float(preview["mask_overlay_alpha"]), ) diff --git a/tests/test_progress.py b/tests/test_progress.py index 930db4b..292baa3 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -329,7 +329,7 @@ def test_run_pipeline_emits_assignment_progress_for_multi_gpu_embedding(monkeypa lambda *args, **kwargs: [embedded_a, embedded_b], ) monkeypatch.setattr(inference, "_persist_embedded_slide", lambda *args, **kwargs: (None, None)) - monkeypatch.setattr(inference, "_run_torchrun_worker", lambda *args, **kwargs: None) + monkeypatch.setattr(inference.runtime_distributed, "run_torchrun_worker", lambda *args, **kwargs: None) monkeypatch.setattr( inference, "_collect_pipeline_artifacts", @@ -636,17 +636,18 @@ def wait(self, timeout=None): self.returncode = 0 return 0 - monkeypatch.setattr(inference.subprocess, "Popen", FakePopen) - monkeypatch.setattr(inference.time, "sleep", lambda _seconds: None) + monkeypatch.setattr(inference.runtime_distributed.subprocess, "Popen", FakePopen) + monkeypatch.setattr(inference.runtime_distributed.time, "sleep", lambda _seconds: None) with progress.activate_progress_reporter(reporter): - inference._run_torchrun_worker( + inference.runtime_distributed.run_torchrun_worker( module="slide2vec.distributed.direct_embed_worker", - execution=inference.ExecutionOptions(output_dir=tmp_path, num_gpus=2), + num_gpus=2, output_dir=tmp_path, request_path=request_path, failure_title="boom", progress_events_path=progress_events_path, + popen_factory=FakePopen, ) assert [event.kind for event in reporter.events] == ["embedding.slide.started"] diff --git a/tests/test_regression_inference.py b/tests/test_regression_inference.py index 312dd1b..6a9c04e 100644 --- a/tests/test_regression_inference.py +++ b/tests/test_regression_inference.py @@ -1342,7 +1342,7 @@ def test_embed_single_slide_distributed_uses_shared_slide_aggregation_helper(mon def fake_coordination_dir(work_dir: Path): yield work_dir / "coord" - monkeypatch.setattr(inference, "_distributed_coordination_dir", fake_coordination_dir) + monkeypatch.setattr(inference.runtime_distributed, "distributed_coordination_dir", fake_coordination_dir) monkeypatch.setattr(inference, "_run_distributed_direct_embedding_stage", lambda *args, **kwargs: None) monkeypatch.setattr( inference.runtime_distributed, @@ -1406,7 +1406,7 @@ def test_embed_single_slide_distributed_skips_parent_backend_load_for_tile_model def fake_coordination_dir(work_dir: Path): yield work_dir / "coord" - monkeypatch.setattr(inference, "_distributed_coordination_dir", fake_coordination_dir) + monkeypatch.setattr(inference.runtime_distributed, "distributed_coordination_dir", fake_coordination_dir) monkeypatch.setattr(inference, "_run_distributed_direct_embedding_stage", lambda *args, **kwargs: None) monkeypatch.setattr( inference.runtime_distributed, From 2ac5beca692fb87b85780efa8361044f15f8b384 Mon Sep 17 00:00:00 2001 From: clement grisi Date: Sat, 18 Apr 2026 22:21:39 +0200 Subject: [PATCH 11/11] Fix hs2p config regression test for CI runtime schema --- slide2vec/runtime/tiling.py | 2 +- tests/test_regression_inference.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/slide2vec/runtime/tiling.py b/slide2vec/runtime/tiling.py index 6047ae5..53766ce 100644 --- a/slide2vec/runtime/tiling.py +++ b/slide2vec/runtime/tiling.py @@ -31,7 +31,7 @@ def build_preview_config(preview: dict[str, Any]) -> PreviewConfig: save_mask_preview=bool(preview["save_mask_preview"]), save_tiling_preview=bool(preview["save_tiling_preview"]), downsample=int(preview["downsample"]), - mask_overlay_color=tuple(int(channel) for channel in preview["tissue_contour_color"]), + tissue_contour_color=tuple(int(channel) for channel in preview["tissue_contour_color"]), mask_overlay_alpha=float(preview["mask_overlay_alpha"]), ) diff --git a/tests/test_regression_inference.py b/tests/test_regression_inference.py index 6a9c04e..c8325e0 100644 --- a/tests/test_regression_inference.py +++ b/tests/test_regression_inference.py @@ -1094,7 +1094,7 @@ def test_build_hs2p_configs_constructs_preview_config(): tolerance=0.05, overlap=0.0, tissue_threshold=0.1, - segmentation={"downsample": 64}, + segmentation={"downsample": 64, "method": "hsv"}, filtering={"ref_tile_size": 224}, preview={ "save_mask_preview": True, @@ -1111,11 +1111,12 @@ def test_build_hs2p_configs_constructs_preview_config(): assert tiling_cfg.backend == "asap" assert segmentation_cfg.downsample == 64 + assert segmentation_cfg.method == "hsv" assert filtering_cfg.ref_tile_size == 224 assert preview_cfg.save_mask_preview is True assert preview_cfg.save_tiling_preview is False assert preview_cfg.downsample == 32 - assert preview_cfg.mask_overlay_color == (157, 219, 129) + assert preview_cfg.tissue_contour_color == (157, 219, 129) assert preview_cfg.mask_overlay_alpha == pytest.approx(0.5) assert read_coordinates_from is None assert resume is False