diff --git a/.github/workflows/pr-test.yaml b/.github/workflows/pr-test.yaml index ff7b1d1..05a11c0 100644 --- a/.github/workflows/pr-test.yaml +++ b/.github/workflows/pr-test.yaml @@ -6,28 +6,7 @@ on: workflow_dispatch: jobs: - regression-guards: - runs-on: ubuntu-latest - timeout-minutes: 10 - permissions: - contents: read - - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - - name: Run regression guard tests - run: | - set -euo pipefail - python -m unittest discover -s test -p 'test_regression_bugfixes.py' - docker-test: - needs: regression-guards runs-on: ubuntu-latest timeout-minutes: 60 permissions: diff --git a/docs/benchmarking.md b/docs/benchmarking.md index b7c6c9b..1420bb9 100644 --- a/docs/benchmarking.md +++ b/docs/benchmarking.md @@ -5,7 +5,7 @@ The script samples a balanced subset of your manifest, runs untimed warmups plus repeated measured trials, tunes only: - `model.batch_size` -- `speed.num_workers_embedding` +- `speed.num_dataloader_workers` It keeps the rest of each model config fixed, disables previews / resume / Weights & Biases, and writes: diff --git a/docs/cli.md b/docs/cli.md index a4ac2ad..ceff009 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -79,7 +79,7 @@ Common overrides: - `output_dir=/path/to/output` - `speed.num_gpus=4` -- `speed.num_workers_embedding=8` +- `speed.num_dataloader_workers=8` - `tiling.preview.save=true` - `model.name=...` - `model.output_variant=...` diff --git a/pyproject.toml b/pyproject.toml index 0871ef6..ef99eec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,13 +23,10 @@ classifiers = [ dependencies = [ "hs2p>=3.1.2", "omegaconf", - "h5py", - "matplotlib", "numpy<2", "pandas", "pillow", "rich", - "tqdm", "torch", "torchvision", "transformers", diff --git a/slide2vec/__init__.py b/slide2vec/__init__.py index 5df5b8f..d174ab2 100644 --- a/slide2vec/__init__.py +++ b/slide2vec/__init__.py @@ -1,5 +1,5 @@ from slide2vec.api import EmbeddedSlide, ExecutionOptions, Model, Pipeline, PreprocessingConfig, RunResult -from slide2vec.artifacts import SlideEmbeddingArtifact, TileEmbeddingArtifact +from slide2vec.artifacts import HierarchicalEmbeddingArtifact, SlideEmbeddingArtifact, TileEmbeddingArtifact __version__ = "3.2.1" @@ -12,6 +12,7 @@ "RunResult", "EmbeddedSlide", "SlideEmbeddingArtifact", + "HierarchicalEmbeddingArtifact", "TileEmbeddingArtifact", "__version__", ] diff --git a/slide2vec/api.py b/slide2vec/api.py index b3a8380..f039db2 100644 --- a/slide2vec/api.py +++ b/slide2vec/api.py @@ -1,5 +1,4 @@ -import os from dataclasses import dataclass, field, replace from contextlib import contextmanager from pathlib import Path @@ -8,7 +7,11 @@ import torch from hs2p import SlideSpec -from slide2vec.artifacts import SlideEmbeddingArtifact, TileEmbeddingArtifact +from slide2vec.artifacts import ( + HierarchicalEmbeddingArtifact, + SlideEmbeddingArtifact, + TileEmbeddingArtifact, +) from slide2vec.encoders.registry import ( encoder_registry, resolve_preprocessing_defaults, @@ -17,6 +20,7 @@ from slide2vec.model_settings import canonicalize_model_name, normalize_precision_name from slide2vec.progress import emit_progress from slide2vec.runtime_types import LoadedModel +from slide2vec.utils.utils import slurm_cpu_limit PathLike = str | Path @@ -38,6 +42,8 @@ class PreprocessingConfig: backend: str = "auto" target_spacing_um: float | None = None target_tile_size_px: int | None = None + target_region_size_px: int | None = None + region_tile_multiple: int | None = None tolerance: float = 0.05 overlap: float = 0.0 tissue_threshold: float = 0.01 @@ -69,6 +75,16 @@ def from_config(cls, cfg: Any) -> "PreprocessingConfig": backend=tiling.backend, target_spacing_um=float(tiling.params.target_spacing_um), target_tile_size_px=int(tiling.params.target_tile_size_px), + target_region_size_px=( + int(v) + if (v := getattr(tiling.params, "target_region_size_px", None)) is not None + else None + ), + region_tile_multiple=( + int(v) + if (v := getattr(tiling.params, "region_tile_multiple", None)) is not None + else None + ), tolerance=float(tiling.params.tolerance), overlap=float(tiling.params.overlap), tissue_threshold=float(tiling.params.tissue_threshold), @@ -140,17 +156,10 @@ def __post_init__(self) -> None: raise ValueError("ExecutionOptions.num_gpus must be at least 1") if self.prefetch_factor < 1: raise ValueError("ExecutionOptions.prefetch_factor must be at least 1") - slurm_cpu_limit = None - for env_name in ("SLURM_CPUS_PER_TASK", "SLURM_CPUS_ON_NODE", "SLURM_JOB_CPUS_PER_NODE"): - if env_name not in os.environ: - continue - value = os.environ[env_name] - if value and value.strip().isdigit() and int(value.strip()) > 0: - slurm_cpu_limit = int(value.strip()) - break - if slurm_cpu_limit is not None: - object.__setattr__(self, "num_workers", min(self.num_workers, slurm_cpu_limit)) - object.__setattr__(self, "num_preprocessing_workers", min(self.num_preprocessing_workers, slurm_cpu_limit)) + limit = slurm_cpu_limit() + if limit is not None: + object.__setattr__(self, "num_workers", min(self.num_workers, limit)) + object.__setattr__(self, "num_preprocessing_workers", min(self.num_preprocessing_workers, limit)) def with_output_dir(self, output_dir: PathLike | None) -> "ExecutionOptions": if output_dir is None: @@ -161,6 +170,7 @@ def with_output_dir(self, output_dir: PathLike | None) -> "ExecutionOptions": @dataclass(frozen=True, kw_only=True) class RunResult: tile_artifacts: list[TileEmbeddingArtifact] + hierarchical_artifacts: list[HierarchicalEmbeddingArtifact] slide_artifacts: list[SlideEmbeddingArtifact] process_list_path: Path | None = None @@ -228,7 +238,7 @@ def embed_tiles( *, preprocessing: PreprocessingConfig | None = None, execution: ExecutionOptions | None = None, - ) -> list[TileEmbeddingArtifact]: + ) -> list[TileEmbeddingArtifact] | list[HierarchicalEmbeddingArtifact]: from slide2vec.inference import embed_tiles resolved = _coerce_execution_options(execution, model=self) @@ -353,6 +363,25 @@ def run( execution=self.execution, ) + def run_with_coordinates( + self, + coordinates_dir: str | Path, + *, + slides: SlideSequence | None = None, + ) -> RunResult: + from slide2vec.inference import run_pipeline_with_coordinates + + with _auto_progress_reporting(output_dir=self.execution.output_dir): + resolved_preprocessing = _resolve_direct_api_preprocessing(self.model, self.preprocessing) + _validate_model_config(self.model, resolved_preprocessing, self.execution) + return run_pipeline_with_coordinates( + self.model, + coordinates_dir=coordinates_dir, + slides=slides, + preprocessing=resolved_preprocessing, + execution=self.execution, + ) + def _coerce_execution_options( options: ExecutionOptions | None, @@ -398,10 +427,12 @@ def ensure_defaults() -> tuple[int, float]: if preprocessing is None: target_tile_size_px, target_spacing_um = ensure_defaults() - return PreprocessingConfig( - backend="auto", - target_spacing_um=target_spacing_um, - target_tile_size_px=target_tile_size_px, + return _resolve_hierarchical_preprocessing( + PreprocessingConfig( + backend="auto", + target_spacing_um=target_spacing_um, + target_tile_size_px=target_tile_size_px, + ) ) target_spacing_um = preprocessing.target_spacing_um @@ -412,10 +443,12 @@ def ensure_defaults() -> tuple[int, float]: target_spacing_um = default_spacing_um if target_tile_size_px is None: target_tile_size_px = default_tile_size_px - return replace( - preprocessing, - target_spacing_um=target_spacing_um, - target_tile_size_px=target_tile_size_px, + return _resolve_hierarchical_preprocessing( + replace( + preprocessing, + target_spacing_um=target_spacing_um, + target_tile_size_px=target_tile_size_px, + ) ) @@ -438,6 +471,10 @@ def _validate_model_config( name = model.name if name not in encoder_registry: return + if preprocessing.region_tile_multiple is not None or preprocessing.target_region_size_px is not None: + info = encoder_registry.info(name) + if info["level"] != "tile": + raise ValueError("Hierarchical preprocessing is only supported for tile encoders") # Skip precision validation for CPU execution (fp32 is always valid on CPU). on_cpu = model._requested_device == "cpu" precision = None if on_cpu or execution is None else execution.precision @@ -451,6 +488,38 @@ def _validate_model_config( ) +def _resolve_hierarchical_preprocessing(preprocessing: PreprocessingConfig) -> PreprocessingConfig: + multiple = preprocessing.region_tile_multiple + target_region_size_px = preprocessing.target_region_size_px + if multiple is not None: + multiple = int(multiple) + if multiple < 2: + raise ValueError("region_tile_multiple must be at least 2") + if multiple is None and target_region_size_px is None: + return preprocessing + if preprocessing.target_tile_size_px is None: + raise ValueError( + "target_tile_size_px must be resolved before deriving hierarchical region geometry" + ) + if target_region_size_px is None: + target_region_size_px = int(preprocessing.target_tile_size_px) * int(multiple) + elif multiple is None: + if int(target_region_size_px) % int(preprocessing.target_tile_size_px) != 0: + raise ValueError( + "target_region_size_px must be an exact multiple of target_tile_size_px" + ) + multiple = int(target_region_size_px) // int(preprocessing.target_tile_size_px) + elif int(target_region_size_px) != int(preprocessing.target_tile_size_px) * int(multiple): + raise ValueError( + "target_region_size_px must match target_tile_size_px * region_tile_multiple" + ) + return replace( + preprocessing, + target_region_size_px=int(target_region_size_px), + region_tile_multiple=int(multiple), + ) + + @contextmanager def _auto_progress_reporting(*, output_dir: PathLike | None): from slide2vec.progress import ( diff --git a/slide2vec/artifacts.py b/slide2vec/artifacts.py index ffa3c71..6011101 100644 --- a/slide2vec/artifacts.py +++ b/slide2vec/artifacts.py @@ -35,6 +35,21 @@ def metadata(self) -> dict[str, Any]: return load_metadata(self.metadata_path) +@dataclass(frozen=True, kw_only=True) +class HierarchicalEmbeddingArtifact: + sample_id: str + path: Path + metadata_path: Path + format: str + feature_dim: int + num_regions: int + tiles_per_region: int + + @property + def metadata(self) -> dict[str, Any]: + return load_metadata(self.metadata_path) + + def _validate_output_format(output_format: str) -> str: normalized = output_format.lower() if normalized not in {"pt", "npz"}: @@ -61,6 +76,14 @@ def _write_metadata(path: Path, metadata: dict[str, Any]) -> None: path.write_text(json.dumps(metadata, indent=2, sort_keys=True), encoding="utf-8") +def _setup_artifact_paths( + output_dir: str | Path, subdir: str, sample_id: str, output_format: str +) -> tuple[Path, Path]: + base_dir = Path(output_dir) / subdir + base_dir.mkdir(parents=True, exist_ok=True) + return base_dir / f"{sample_id}.{output_format}", base_dir / f"{sample_id}.meta.json" + + def _build_tile_embedding_metadata( sample_id: str, *, @@ -107,11 +130,7 @@ def write_tile_embeddings( tile_index: Any | None = None, ) -> TileEmbeddingArtifact: output_format = _validate_output_format(output_format) - base_dir = Path(output_dir) / "tile_embeddings" - base_dir.mkdir(parents=True, exist_ok=True) - artifact_path = base_dir / f"{sample_id}.{output_format}" - metadata_path = base_dir / f"{sample_id}.meta.json" - + artifact_path, metadata_path = _setup_artifact_paths(output_dir, "tile_embeddings", sample_id, output_format) feature_array = _ensure_array(features) if output_format == "pt": torch.save(_ensure_tensor(features), artifact_path) @@ -149,9 +168,7 @@ def write_tile_embedding_metadata( metadata: dict[str, Any] | None = None, ) -> Path: output_format = _validate_output_format(output_format) - base_dir = Path(output_dir) / "tile_embeddings" - base_dir.mkdir(parents=True, exist_ok=True) - metadata_path = base_dir / f"{sample_id}.meta.json" + _, metadata_path = _setup_artifact_paths(output_dir, "tile_embeddings", sample_id, output_format) tile_metadata = _build_tile_embedding_metadata( sample_id, output_format=output_format, @@ -173,27 +190,18 @@ def write_slide_embeddings( latents: Any | None = None, ) -> SlideEmbeddingArtifact: output_format = _validate_output_format(output_format) - base_dir = Path(output_dir) / "slide_embeddings" - base_dir.mkdir(parents=True, exist_ok=True) - artifact_path = base_dir / f"{sample_id}.{output_format}" - metadata_path = base_dir / f"{sample_id}.meta.json" - + artifact_path, metadata_path = _setup_artifact_paths(output_dir, "slide_embeddings", sample_id, output_format) embedding_array = _ensure_array(embedding) latent_path = None if output_format == "pt": torch.save(_ensure_tensor(embedding), artifact_path) - if latents is not None: - latents_dir = Path(output_dir) / "slide_latents" - latents_dir.mkdir(parents=True, exist_ok=True) - latent_path = latents_dir / f"{sample_id}.pt" - torch.save(_ensure_tensor(latents), latent_path) else: - payload = {"features": embedding_array} - np.savez_compressed(artifact_path, **payload) - if latents is not None: - latents_dir = Path(output_dir) / "slide_latents" - latents_dir.mkdir(parents=True, exist_ok=True) - latent_path = latents_dir / f"{sample_id}.npz" + np.savez_compressed(artifact_path, features=embedding_array) + if latents is not None: + latent_path, _ = _setup_artifact_paths(output_dir, "slide_latents", sample_id, output_format) + if output_format == "pt": + torch.save(_ensure_tensor(latents), latent_path) + else: np.savez_compressed(latent_path, latents=_ensure_array(latents)) slide_metadata = { @@ -213,3 +221,45 @@ def write_slide_embeddings( feature_dim=slide_metadata["feature_dim"], latent_path=latent_path, ) + + +def write_hierarchical_embeddings( + sample_id: str, + features, + *, + output_dir: str | Path, + output_format: str = "pt", + metadata: dict[str, Any] | None = None, +) -> HierarchicalEmbeddingArtifact: + output_format = _validate_output_format(output_format) + artifact_path, metadata_path = _setup_artifact_paths(output_dir, "hierarchical_embeddings", sample_id, output_format) + feature_array = _ensure_array(features) + if feature_array.ndim != 3: + raise ValueError( + "Hierarchical embeddings must have shape (num_regions, tiles_per_region, feature_dim)" + ) + if output_format == "pt": + torch.save(_ensure_tensor(features), artifact_path) + else: + np.savez_compressed(artifact_path, features=feature_array) + + hierarchical_metadata = { + "sample_id": sample_id, + "artifact_type": "hierarchical_embeddings", + "format": output_format, + "feature_dim": int(feature_array.shape[2]), + "num_regions": int(feature_array.shape[0]), + "tiles_per_region": int(feature_array.shape[1]), + } + if metadata: + hierarchical_metadata.update(metadata) + _write_metadata(metadata_path, hierarchical_metadata) + return HierarchicalEmbeddingArtifact( + sample_id=sample_id, + path=artifact_path, + metadata_path=metadata_path, + format=output_format, + feature_dim=int(hierarchical_metadata["feature_dim"]), + num_regions=int(hierarchical_metadata["num_regions"]), + tiles_per_region=int(hierarchical_metadata["tiles_per_region"]), + ) diff --git a/slide2vec/configs/default.yaml b/slide2vec/configs/default.yaml index dd7cc6a..c83d5da 100644 --- a/slide2vec/configs/default.yaml +++ b/slide2vec/configs/default.yaml @@ -29,6 +29,8 @@ tiling: target_spacing_um: # spacing at which to tile the slide, in microns per pixel; filled from a preset model when available tolerance: 0.05 # tolerance for matching the spacing (float between 0 and 1, deciding how much the spacing can deviate from the one specified in the slide metadata) target_tile_size_px: # size of the tiles to extract, in pixels; filled from a preset model when available + target_region_size_px: # size of hierarchical parent regions in pixels; when unset and region_tile_multiple is set, derived from target_tile_size_px * region_tile_multiple + region_tile_multiple: # hierarchical region grid width/height in tiles; e.g. 6 means 6x6 tiles per region overlap: 0.0 # percentage of overlap between two consecutive tiles (float between 0 and 1) tissue_threshold: 0.1 # minimum fraction of pixels that must be tissue to keep a tile (float between 0 and 1) seg_params: diff --git a/slide2vec/data/augmentations.py b/slide2vec/data/augmentations.py deleted file mode 100644 index e18c9db..0000000 --- a/slide2vec/data/augmentations.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch -import torchvision.transforms.functional as F - -from typing import Sequence -from einops import rearrange -from torchvision import transforms - -# Use timm's names -IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) -IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) - - -def make_normalize_transform( - mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, - std: Sequence[float] = IMAGENET_DEFAULT_STD, -) -> transforms.Normalize: - return transforms.Normalize(mean=mean, std=std) - - -class MaybeToTensor(transforms.ToTensor): - """ - Convert a PIL Image or ndarray to tensor if it's not already one. - """ - - def __init__(self): - super().__init__() - - def __call__(self, pic): - """ - Args: - pic (PIL Image or numpy.ndarray): Image to be converted to tensor. - - Returns: - Tensor: Converted image. - """ - if isinstance(pic, torch.Tensor): - return pic - return F.to_tensor(pic) - - def __repr__(self): - return f"{self.__class__.__name__}()" - - -class RegionUnfolding: - def __init__(self, tile_size): - self.tile_size = tile_size - - def __call__(self, x): - # x = [3, region_size, region_size] - # unfold into tilees and rearrange - x = x.unfold(1, self.tile_size, self.tile_size).unfold( - 2, self.tile_size, self.tile_size - ) # [3, ntile, region_size, tile_size] -> [3, ntile, ntile, tile_size, tile_size] - x = rearrange( - x, "c p1 p2 w h -> (p1 p2) c w h" - ) # [num_tilees, 3, tile_size, tile_size] - return x diff --git a/slide2vec/data/tile_reader.py b/slide2vec/data/tile_reader.py index d82bc7d..08880c0 100644 --- a/slide2vec/data/tile_reader.py +++ b/slide2vec/data/tile_reader.py @@ -1,24 +1,24 @@ from collections import defaultdict import time +from pathlib import Path import numpy as np import torch from hs2p import TilingResult +from hs2p.wsi.streaming.plans import build_supertile_index class SuperTileBatchSampler: - """Batch sampler that keeps super tiles intact. + """Greedily packs whole groups into batches of approximately ``batch_size`` items. - Greedily packs whole super tiles into batches of approximately - ``batch_size`` tiles. No super tile is ever split across batches, - so each WSI region is read exactly once. + No group is ever split across batches. """ - def __init__(self, *, supertile_groups: list[np.ndarray], batch_size: int): + def __init__(self, *, groups: list[np.ndarray], batch_size: int): self.batches: list[list[int]] = [] current: list[int] = [] - for group in supertile_groups: + for group in groups: positions = group.tolist() if current and len(current) + len(positions) > batch_size: self.batches.append(current) @@ -35,6 +35,29 @@ def __len__(self): return len(self.batches) +def _open_wsi_backend(image_path: str, backend: str, gpu_decode: bool): + """Open a WSI file with the given backend and return the reader.""" + if backend == "cucim": + from hs2p.wsi.backends.cucim import CuCIMReader + return CuCIMReader(image_path, gpu_decode=gpu_decode) + elif backend == "openslide": + from hs2p.wsi.backends.openslide import OpenSlideReader + return OpenSlideReader(image_path) + elif backend == "vips": + from hs2p.wsi.backends.vips import VIPSReader + return VIPSReader(image_path) + elif backend == "asap": + from hs2p.wsi.backends.asap import ASAPReader + from slide2vec.utils.log_utils import suppress_c_stderr + with suppress_c_stderr(): + return ASAPReader(image_path) + else: + raise ValueError( + f"Unknown backend: {backend!r}. " + "Choose from: cucim, openslide, vips, asap" + ) + + class WSITileReader: """Random-access tile reader for WSI files supporting four backends. @@ -50,7 +73,7 @@ class WSITileReader: def __init__( self, - image_path: "Path", + image_path: Path, tiling_result: TilingResult, *, backend: str = "cucim", @@ -69,8 +92,6 @@ def __init__( self._reader = None if use_supertiles: - from hs2p.wsi.streaming.plans import build_supertile_index - index = build_supertile_index(tiling_result) self._supertile_plans = index.plans self._tile_to_st = index.tile_to_st @@ -84,31 +105,8 @@ def __init__( self._use_supertiles = use_supertiles def _ensure_open(self) -> None: - if self._reader is not None: - return - if self._backend == "cucim": - from hs2p.wsi.backends.cucim import CuCIMReader - - self._reader = CuCIMReader(self._image_path, gpu_decode=self._gpu_decode) - elif self._backend == "openslide": - from hs2p.wsi.backends.openslide import OpenSlideReader - - self._reader = OpenSlideReader(self._image_path) - elif self._backend == "vips": - from hs2p.wsi.backends.vips import VIPSReader - - self._reader = VIPSReader(self._image_path) - elif self._backend == "asap": - from hs2p.wsi.backends.asap import ASAPReader - from slide2vec.utils.log_utils import suppress_c_stderr - - with suppress_c_stderr(): - self._reader = ASAPReader(self._image_path) - else: - raise ValueError( - f"Unknown backend: {self._backend!r}. " - "Choose from: cucim, openslide, vips, asap" - ) + if self._reader is None: + self._reader = _open_wsi_backend(self._image_path, self._backend, self._gpu_decode) def _read_regions_batch( self, locations: list[tuple[int, int]], size: int @@ -205,7 +203,7 @@ class OnTheFlyBatchTileCollator: def __init__( self, *, - image_path: "Path", + image_path: Path, tiling_result: TilingResult, backend: str = "cucim", num_cucim_workers: int = 4, @@ -253,7 +251,7 @@ def build_batch_sampler( start = pos if start < len(dataset_indices): groups.append(np.arange(start, len(dataset_indices), dtype=np.int64)) - return SuperTileBatchSampler(supertile_groups=groups, batch_size=batch_size) + return SuperTileBatchSampler(groups=groups, batch_size=batch_size) def __call__(self, batch_indices): if not batch_indices: @@ -267,3 +265,149 @@ def __call__(self, batch_indices): tensor, timing = self._reader.read_batch_with_timing(tile_indices) timing["worker_batch_ms"] = (time.perf_counter() - worker_start) * 1000.0 return torch.as_tensor(tile_indices, dtype=torch.long), tensor, timing + + +class WSIRegionReader: + """Random-access region reader for hierarchical extraction.""" + + def __init__( + self, + image_path: Path, + *, + read_level: int, + region_size_px: int, + backend: str = "cucim", + num_cucim_workers: int = 4, + gpu_decode: bool = False, + ): + self._image_path = str(image_path) + self._backend = backend + self._num_cucim_workers = num_cucim_workers + self._gpu_decode = gpu_decode + self._read_level = int(read_level) + self._region_size_px = int(region_size_px) + self._reader = None + + def _ensure_open(self) -> None: + if self._reader is None: + self._reader = _open_wsi_backend(self._image_path, self._backend, self._gpu_decode) + + def _read_regions_batch(self, locations: list[tuple[int, int]]) -> list[np.ndarray]: + if self._backend == "cucim": + return list( + self._reader.read_regions( + locations, + self._read_level, + (self._region_size_px, self._region_size_px), + num_workers=self._num_cucim_workers, + ) + ) + return [ + self._reader.read_region( + loc, + self._read_level, + (self._region_size_px, self._region_size_px), + ) + for loc in locations + ] + + def read_batch_with_timing( + self, + locations: list[tuple[int, int]], + ) -> tuple[torch.Tensor, dict[str, float]]: + if not locations: + return ( + torch.empty((0, 3, self._region_size_px, self._region_size_px), dtype=torch.uint8), + {"reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + was_closed = self._reader is None + open_start = time.perf_counter() + self._ensure_open() + reader_open_ms = (time.perf_counter() - open_start) * 1000.0 if was_closed else 0.0 + read_start = time.perf_counter() + regions = self._read_regions_batch(locations) + reader_read_ms = (time.perf_counter() - read_start) * 1000.0 + batch = np.stack([np.asarray(region)[:, :, :3] for region in regions], axis=0) + tensor = torch.from_numpy(batch).permute(0, 3, 1, 2).contiguous() + return tensor, {"reader_open_ms": reader_open_ms, "reader_read_ms": reader_read_ms} + + +class OnTheFlyHierarchicalBatchCollator: + """Collator that reads region crops once and unfolds selected subtiles.""" + + def __init__( + self, + *, + image_path: Path, + tiling_result: TilingResult, + region_index: np.ndarray, + subtile_index_within_region: np.ndarray, + effective_region_size_px: int, + effective_tile_size_px: int, + backend: str = "cucim", + num_cucim_workers: int = 4, + gpu_decode: bool = False, + ): + self._region_index = np.asarray(region_index, dtype=np.int32) + self._subtile_index_within_region = np.asarray(subtile_index_within_region, dtype=np.int32) + self._tiles_per_region = int(self._subtile_index_within_region.max()) + 1 if len(self._subtile_index_within_region) else 0 + self._tile_size = int(effective_tile_size_px) + self._reader = WSIRegionReader( + image_path, + read_level=int(tiling_result.read_level), + region_size_px=int(effective_region_size_px), + backend=backend, + num_cucim_workers=num_cucim_workers, + gpu_decode=gpu_decode, + ) + self._region_locations = [ + (int(x), int(y)) + for x, y in zip(np.asarray(tiling_result.x), np.asarray(tiling_result.y)) + ] + + def build_batch_sampler( + self, + *, + batch_size: int, + dataset_indices: np.ndarray, + ) -> SuperTileBatchSampler: + if len(dataset_indices) == 0: + return SuperTileBatchSampler(groups=[], batch_size=batch_size) + regions = self._region_index[dataset_indices] + boundaries = np.where(np.concatenate(([True], regions[1:] != regions[:-1], [True])))[0] + groups = [np.arange(boundaries[i], boundaries[i + 1], dtype=np.int64) for i in range(len(boundaries) - 1)] + return SuperTileBatchSampler(groups=groups, batch_size=batch_size) + + def __call__(self, batch_indices): + if not batch_indices: + return ( + torch.empty((0,), dtype=torch.long), + torch.empty((0, 3, self._tile_size, self._tile_size), dtype=torch.uint8), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ) + worker_start = time.perf_counter() + flat_indices = np.asarray(batch_indices, dtype=np.int64) + requested_regions = self._region_index[flat_indices] + unique_regions, inverse = np.unique(requested_regions, return_inverse=True) + locations = [self._region_locations[int(region)] for region in unique_regions] + region_tensor, timing = self._reader.read_batch_with_timing(locations) + unfolded = _unfold_region_tensor_uint8(region_tensor, self._tile_size) + subtile_indices = self._subtile_index_within_region[flat_indices] + out = unfolded[torch.as_tensor(inverse, dtype=torch.long), torch.as_tensor(subtile_indices, dtype=torch.long)] + timing["worker_batch_ms"] = (time.perf_counter() - worker_start) * 1000.0 + return torch.as_tensor(flat_indices, dtype=torch.long), out, timing + + +def _unfold_region_tensor_uint8(region_tensor: torch.Tensor, tile_size: int) -> torch.Tensor: + if region_tensor.numel() == 0: + return torch.empty((0, 0, 3, tile_size, tile_size), dtype=torch.uint8) + if int(region_tensor.shape[-1]) % tile_size != 0 or int(region_tensor.shape[-2]) % tile_size != 0: + raise ValueError("Region tensor dimensions must be divisible by the tile size") + unfolded = torch.nn.functional.unfold( + region_tensor.to(torch.float32), + kernel_size=tile_size, + stride=tile_size, + ) + unfolded = unfolded.transpose(1, 2) + reshaped = unfolded.reshape(region_tensor.shape[0], -1, region_tensor.shape[1], tile_size, tile_size) + return reshaped.round().clamp(0, 255).to(torch.uint8) diff --git a/slide2vec/distributed/__init__.py b/slide2vec/distributed/__init__.py index 069b60f..4d0c299 100644 --- a/slide2vec/distributed/__init__.py +++ b/slide2vec/distributed/__init__.py @@ -1,9 +1,7 @@ import datetime import os import random -import re import socket -from typing import Dict, List import torch import torch.distributed as dist @@ -93,17 +91,6 @@ def print(*args, **kwargs): __builtin__.print = print -def _get_master_port(seed: int = 0) -> int: - MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) - - master_port_str = os.environ.get("MASTER_PORT") - if master_port_str is None: - rng = random.Random(seed) - return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) - - return int(master_port_str) - - def _get_available_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: # A "" host address means INADDR_ANY i.e. binding to all interfaces. @@ -123,7 +110,7 @@ def _get_available_port() -> int: ) -def _collect_env_vars() -> Dict[str, str]: +def _collect_env_vars() -> dict[str, str]: return { env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS @@ -131,27 +118,6 @@ def _collect_env_vars() -> Dict[str, str]: } -def _is_slurm_job_process() -> bool: - return "SLURM_JOB_ID" in os.environ - - -def _parse_slurm_node_list(s: str) -> List[str]: - nodes = [] - # Extract "hostname", "hostname[1-2,3,4-5]," substrings - p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") - for m in p.finditer(s): - prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] - for suffix in suffixes.split(","): - span = suffix.split("-") - if len(span) == 1: - nodes.append(prefix + suffix) - else: - width = len(span[0]) - start, end = int(span[0]), int(span[1]) + 1 - nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) - return nodes - - def _check_env_variable(key: str, new_value: str): # Only check for difference with preset environment variables if key in os.environ and os.environ[key] != new_value: @@ -169,9 +135,6 @@ def __init__(self): self.local_rank = -1 self.local_world_size = -1 - # if _is_slurm_job_process(): - # return self._set_from_slurm_env() - env_vars = _collect_env_vars() if not env_vars: # Environment is not set @@ -189,23 +152,6 @@ def __init__(self): raise RuntimeError("Can't initialize PyTorch distributed environment") - # Slurm job created with sbatch, submitit, etc... - def _set_from_slurm_env(self): - # logger.info("Initialization from Slurm environment") - job_id = int(os.environ["SLURM_JOB_ID"]) - node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) - nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) - assert len(nodes) == node_count - - self.master_addr = nodes[0] - self.master_port = _get_master_port(seed=job_id) - self.rank = int(os.environ["SLURM_PROCID"]) - self.world_size = int(os.environ["SLURM_NTASKS"]) - assert self.rank < self.world_size - self.local_rank = int(os.environ["SLURM_LOCALID"]) - self.local_world_size = self.world_size // node_count - assert self.local_rank < self.local_world_size - # Single node job with preset environment (i.e. torchrun) def _set_from_preset_env(self): # logger.info("Initialization from preset environment") diff --git a/slide2vec/distributed/direct_embed_worker.py b/slide2vec/distributed/direct_embed_worker.py index cf793b9..45fde5e 100644 --- a/slide2vec/distributed/direct_embed_worker.py +++ b/slide2vec/distributed/direct_embed_worker.py @@ -20,8 +20,11 @@ def main(argv=None) -> int: import slide2vec.distributed as distributed from slide2vec.api import Model from slide2vec.inference import ( + _build_hierarchical_index, _compute_embedded_slides, + _compute_hierarchical_embedding_shard_for_slide, _compute_tile_embeddings_for_slide, + _is_hierarchical_preprocessing, deserialize_execution, deserialize_preprocessing, load_successful_tiled_slides, @@ -69,23 +72,43 @@ def main(argv=None) -> int: if request["strategy"] == "tile_shard": sample_id = request["sample_id"] slide, tiling_result = paired_by_sample[sample_id] - num_tiles = len(tiling_result.x) - tile_indices = np.array_split(np.arange(num_tiles, dtype=np.int64), world_size)[global_rank] loaded = model._load_backend() - tile_embeddings = _compute_tile_embeddings_for_slide( - loaded, - model, - slide, - tiling_result, - preprocessing=preprocessing, - execution=execution, - tile_indices=tile_indices, - ) - payload = { - "tile_index": torch.as_tensor(tile_indices, dtype=torch.long), - "tile_embeddings": tile_embeddings.detach().cpu() if torch.is_tensor(tile_embeddings) else torch.as_tensor(tile_embeddings), - } - torch.save(payload, coordination_dir / f"{sample_id}.tiles.rank{global_rank}.pt") + if _is_hierarchical_preprocessing(preprocessing): + index = _build_hierarchical_index( + tiling_result, + region_tile_multiple=int(preprocessing.region_tile_multiple), + ) + flat_indices = np.array_split(index.flat_index, world_size)[global_rank] + shard_indices, tile_embeddings = _compute_hierarchical_embedding_shard_for_slide( + loaded, + slide, + tiling_result, + preprocessing=preprocessing, + execution=execution, + flat_indices=flat_indices, + ) + payload = { + "flat_index": torch.as_tensor(shard_indices, dtype=torch.long), + "tile_embeddings": tile_embeddings.detach().cpu() if torch.is_tensor(tile_embeddings) else torch.as_tensor(tile_embeddings), + } + torch.save(payload, coordination_dir / f"{sample_id}.hier.rank{global_rank}.pt") + else: + num_tiles = len(tiling_result.x) + tile_indices = np.array_split(np.arange(num_tiles, dtype=np.int64), world_size)[global_rank] + tile_embeddings = _compute_tile_embeddings_for_slide( + loaded, + model, + slide, + tiling_result, + preprocessing=preprocessing, + execution=execution, + tile_indices=tile_indices, + ) + payload = { + "tile_index": torch.as_tensor(tile_indices, dtype=torch.long), + "tile_embeddings": tile_embeddings.detach().cpu() if torch.is_tensor(tile_embeddings) else torch.as_tensor(tile_embeddings), + } + torch.save(payload, coordination_dir / f"{sample_id}.tiles.rank{global_rank}.pt") return 0 assigned_ids = list(request.get("assignments", {}).get(str(global_rank), [])) @@ -102,9 +125,9 @@ def main(argv=None) -> int: ) for embedded_slide in embedded_slides: payload = { - "tile_embeddings": _to_cpu_payload(torch, embedded_slide.tile_embeddings), - "slide_embedding": _to_cpu_payload(torch, embedded_slide.slide_embedding), - "latents": _to_cpu_payload(torch, embedded_slide.latents), + "tile_embeddings": _to_cpu_payload(embedded_slide.tile_embeddings), + "slide_embedding": _to_cpu_payload(embedded_slide.slide_embedding), + "latents": _to_cpu_payload(embedded_slide.latents), } torch.save(payload, coordination_dir / f"{embedded_slide.sample_id}.embedded.pt") return 0 @@ -113,7 +136,8 @@ def main(argv=None) -> int: dist.destroy_process_group() -def _to_cpu_payload(torch, value): +def _to_cpu_payload(value): + import torch if value is None: return None if torch.is_tensor(value): diff --git a/slide2vec/inference.py b/slide2vec/inference.py index ce98e0c..eb9456f 100644 --- a/slide2vec/inference.py +++ b/slide2vec/inference.py @@ -1,7 +1,6 @@ import json import importlib import os -import re import shutil import subprocess import sys @@ -26,10 +25,13 @@ ExecutionOptions, PreprocessingConfig, RunResult, + _resolve_hierarchical_preprocessing, ) from slide2vec.artifacts import ( + HierarchicalEmbeddingArtifact, SlideEmbeddingArtifact, TileEmbeddingArtifact, + write_hierarchical_embeddings, load_array, load_metadata, write_slide_embeddings, @@ -47,9 +49,10 @@ ) from slide2vec.utils.log_utils import suppress_c_stderr from slide2vec.data.dataset import BatchTileCollator, TileIndexDataset -from slide2vec.data.tile_reader import OnTheFlyBatchTileCollator +from slide2vec.data.tile_reader import OnTheFlyBatchTileCollator, OnTheFlyHierarchicalBatchCollator from slide2vec.utils.coordinates import coordinate_arrays -from slide2vec.utils.tiling_io import load_process_df, load_slide_manifest, load_tiling_result_from_row +from slide2vec.utils.tiling_io import load_process_df, load_slide_manifest, load_tiling_result_from_row, _optional_float +from slide2vec.utils.utils import slurm_cpu_limit @dataclass(frozen=True, kw_only=True) @@ -58,7 +61,6 @@ class BatchTransformSpec: center_crop_size: tuple[int, int] | None mean: tuple[float, ...] | None std: tuple[float, ...] | None - region_unfold_tile_size: int | None = None resize_interpolation: str = "bilinear" @@ -74,37 +76,107 @@ class PreparedBatch: reader_read_ms: float = 0.0 +@dataclass(frozen=True, kw_only=True) +class HierarchicalIndex: + flat_index: np.ndarray + region_index: np.ndarray + subtile_index_within_region: np.ndarray + subtile_x: np.ndarray + subtile_y: np.ndarray + num_regions: int + tiles_per_region: int -def _optional_float(value: Any) -> float | None: - if value is None: - return None - try: - if np.isnan(value): - return None - except TypeError: - pass - return float(value) +def _is_hierarchical_preprocessing(preprocessing: PreprocessingConfig | None) -> bool: + if preprocessing is None: + return False + return preprocessing.region_tile_multiple is not None or preprocessing.target_region_size_px is not None + + +def _resolve_hierarchical_geometry(preprocessing: PreprocessingConfig, tiling_result) -> dict[str, int]: + if preprocessing.region_tile_multiple is None: + raise ValueError("Hierarchical preprocessing requires region_tile_multiple") + if preprocessing.target_region_size_px is None: + raise ValueError("Hierarchical preprocessing requires target_region_size_px") + target_tile_size_px = int(preprocessing.target_tile_size_px) + target_region_size_px = int(preprocessing.target_region_size_px) + effective_region_size_px = int(getattr(tiling_result, "effective_tile_size_px")) + tile_size_lv0 = int(getattr(tiling_result, "tile_size_lv0")) + multiple = int(preprocessing.region_tile_multiple) + if target_region_size_px % multiple != 0: + raise ValueError("target_region_size_px must be divisible by region_tile_multiple") + if effective_region_size_px % multiple != 0: + raise ValueError("effective_region_size_px must be divisible by region_tile_multiple") + if tile_size_lv0 % multiple != 0: + raise ValueError("tile_size_lv0 must be divisible by region_tile_multiple") + return { + "region_tile_multiple": multiple, + "tiles_per_region": multiple * multiple, + "target_tile_size_px": target_tile_size_px, + "effective_tile_size_px": effective_region_size_px // multiple, + "target_region_size_px": target_region_size_px, + "effective_region_size_px": effective_region_size_px, + "tile_size_lv0": tile_size_lv0 // multiple, + } + + +def _build_hierarchical_index( + tiling_result, + *, + region_tile_multiple: int, +) -> HierarchicalIndex: + x_values, y_values = coordinate_arrays(tiling_result) + num_regions = int(len(x_values)) + multiple = int(region_tile_multiple) + if multiple < 2: + raise ValueError("region_tile_multiple must be at least 2") + tile_size_lv0 = int(getattr(tiling_result, "tile_size_lv0")) + if tile_size_lv0 % multiple != 0: + raise ValueError("tile_size_lv0 must be divisible by region_tile_multiple") + subtile_size_lv0 = tile_size_lv0 // multiple + tiles_per_region = multiple * multiple + if num_regions == 0: + empty = np.empty(0, dtype=np.int64) + return HierarchicalIndex( + flat_index=empty, + region_index=np.empty(0, dtype=np.int32), + subtile_index_within_region=np.empty(0, dtype=np.int32), + subtile_x=empty, + subtile_y=empty, + num_regions=0, + tiles_per_region=tiles_per_region, + ) + rows, cols = np.divmod(np.arange(tiles_per_region, dtype=np.int32), multiple) + offsets_x = cols.astype(np.int64) * subtile_size_lv0 + offsets_y = rows.astype(np.int64) * subtile_size_lv0 + region_x = np.asarray(x_values, dtype=np.int64)[:, np.newaxis] + region_y = np.asarray(y_values, dtype=np.int64)[:, np.newaxis] + subtile_x = (region_x + offsets_x[np.newaxis, :]).reshape(-1) + subtile_y = (region_y + offsets_y[np.newaxis, :]).reshape(-1) + return HierarchicalIndex( + flat_index=np.arange(num_regions * tiles_per_region, dtype=np.int64), + region_index=np.repeat(np.arange(num_regions, dtype=np.int32), tiles_per_region), + subtile_index_within_region=np.tile(np.arange(tiles_per_region, dtype=np.int32), num_regions), + subtile_x=subtile_x, + subtile_y=subtile_y, + num_regions=num_regions, + tiles_per_region=tiles_per_region, + ) + + +def _num_embedding_items(tiling_result, preprocessing: PreprocessingConfig | None) -> int: + if not _is_hierarchical_preprocessing(preprocessing): + return _num_tiles(tiling_result) + geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) + return _num_tiles(tiling_result) * int(geometry["tiles_per_region"]) -def _slurm_cpu_limit() -> int | None: - for env_name in ("SLURM_CPUS_PER_TASK", "SLURM_CPUS_ON_NODE", "SLURM_JOB_CPUS_PER_NODE"): - if env_name not in os.environ: - continue - value = os.environ[env_name] - match = re.match(r"\s*(\d+)", value) - if match is None: - continue - limit = int(match.group(1)) - if limit > 0: - return limit - return None def _resolve_on_the_fly_num_workers(num_cucim_workers: int) -> tuple[int, str]: cpu_count = os.cpu_count() or 4 worker_budget = cpu_count details = [f"cpu_count={cpu_count}"] - slurm_limit = _slurm_cpu_limit() + slurm_limit = slurm_cpu_limit() if slurm_limit is not None: worker_budget = min(worker_budget, slurm_limit) details.append(f"slurm_cpu_limit={slurm_limit}") @@ -314,7 +386,7 @@ def embed_tiles( *, execution: ExecutionOptions, preprocessing: PreprocessingConfig | None = None, -) -> list[TileEmbeddingArtifact]: +) -> list[TileEmbeddingArtifact] | list[HierarchicalEmbeddingArtifact]: if execution.output_dir is None: raise ValueError("ExecutionOptions.output_dir is required to persist tile embeddings") @@ -322,30 +394,53 @@ def embed_tiles( slide_records = [_coerce_slide_spec(slide) for slide in slides] resolved_tiling_results = _normalize_tiling_results(tiling_results, slide_records) resolved_preprocessing = _resolve_model_preprocessing(model, preprocessing) - artifacts: list[TileEmbeddingArtifact] = [] + hierarchical_mode = _is_hierarchical_preprocessing(resolved_preprocessing) + artifacts: list[TileEmbeddingArtifact] | list[HierarchicalEmbeddingArtifact] = [] for slide, tiling_result in zip(slide_records, resolved_tiling_results): - features = _compute_tile_embeddings_for_slide( - loaded, - model, - slide, - tiling_result, - preprocessing=resolved_preprocessing, - execution=execution, - ) - metadata = _build_tile_embedding_metadata( - model, - tiling_result=tiling_result, - image_path=slide.image_path, - mask_path=slide.mask_path, - tile_size_lv0=int(tiling_result.tile_size_lv0), - backend=_resolve_slide_backend(resolved_preprocessing, tiling_result), - ) - artifact = _write_tile_embedding_artifact( - slide.sample_id, - features, - execution=execution, - metadata=metadata, - ) + if hierarchical_mode: + features = _compute_hierarchical_embeddings_for_slide( + loaded, + slide, + tiling_result, + preprocessing=resolved_preprocessing, + execution=execution, + ) + artifact = _write_hierarchical_embedding_artifact( + slide.sample_id, + features, + execution=execution, + metadata=_build_hierarchical_embedding_metadata( + model, + tiling_result=tiling_result, + image_path=slide.image_path, + mask_path=slide.mask_path, + backend=_resolve_slide_backend(resolved_preprocessing, tiling_result), + preprocessing=resolved_preprocessing, + ), + ) + else: + features = _compute_tile_embeddings_for_slide( + loaded, + model, + slide, + tiling_result, + preprocessing=resolved_preprocessing, + execution=execution, + ) + metadata = _build_tile_embedding_metadata( + model, + tiling_result=tiling_result, + image_path=slide.image_path, + mask_path=slide.mask_path, + tile_size_lv0=int(tiling_result.tile_size_lv0), + backend=_resolve_slide_backend(resolved_preprocessing, tiling_result), + ) + artifact = _write_tile_embedding_artifact( + slide.sample_id, + features, + execution=execution, + metadata=metadata, + ) artifacts.append(artifact) return artifacts @@ -462,7 +557,12 @@ def run_pipeline( output_dir=str(output_dir), logs_dir=str(output_dir / "logs"), ) - return RunResult(tile_artifacts=[], slide_artifacts=[], process_list_path=process_list_path) + return RunResult( + tile_artifacts=[], + hierarchical_artifacts=[], + slide_artifacts=[], + process_list_path=process_list_path, + ) _write_zero_tile_embedding_sidecars( zero_tile_pairs, @@ -474,7 +574,7 @@ def run_pipeline( emit_progress("embedding.started", slide_count=len(embeddable_slides)) if execution.num_gpus > 1: - tile_artifacts, slide_artifacts = _collect_distributed_pipeline_artifacts( + tile_artifacts, hierarchical_artifacts, slide_artifacts = _collect_distributed_pipeline_artifacts( model=model, successful_slides=embeddable_slides, process_list_path=process_list_path, @@ -486,7 +586,7 @@ def run_pipeline( "embedding.finished", slide_count=len(embeddable_slides), slides_completed=len(embeddable_slides), - tile_artifacts=len(tile_artifacts), + tile_artifacts=len(tile_artifacts) + len(hierarchical_artifacts), slide_artifacts=len(slide_artifacts), ) emit_progress( @@ -496,11 +596,13 @@ def run_pipeline( ) return RunResult( tile_artifacts=tile_artifacts, + hierarchical_artifacts=hierarchical_artifacts, slide_artifacts=slide_artifacts, process_list_path=process_list_path, ) persist_tile_embeddings = _should_persist_tile_embeddings(model, execution) + persist_hierarchical_embeddings = _is_hierarchical_preprocessing(resolved_preprocessing) include_slide_embeddings = model.level == "slide" pending_slides, pending_tiling_results = _pending_local_embedding_records( embeddable_slides, @@ -509,6 +611,7 @@ def run_pipeline( output_dir=output_dir, output_format=execution.output_format, persist_tile_embeddings=persist_tile_embeddings, + persist_hierarchical_embeddings=persist_hierarchical_embeddings, include_slide_embeddings=include_slide_embeddings, save_latents=execution.save_latents, resume=resolved_preprocessing.resume, @@ -529,26 +632,29 @@ def run_pipeline( execution=execution, on_embedded_slide=local_persist_callback, ) - tile_artifacts, slide_artifacts = _collect_pipeline_artifacts( + tile_artifacts, hierarchical_artifacts, slide_artifacts = _collect_pipeline_artifacts( embeddable_slides, output_dir=output_dir, output_format=execution.output_format, include_tile_embeddings=persist_tile_embeddings, + include_hierarchical_embeddings=persist_hierarchical_embeddings, include_slide_embeddings=include_slide_embeddings, ) _update_process_list_after_embedding( process_list_path, successful_slides=embeddable_slides, persist_tile_embeddings=persist_tile_embeddings, + persist_hierarchical_embeddings=persist_hierarchical_embeddings, include_slide_embeddings=include_slide_embeddings, tile_artifacts=tile_artifacts, + hierarchical_artifacts=hierarchical_artifacts, slide_artifacts=slide_artifacts, ) emit_progress( "embedding.finished", slide_count=len(embeddable_slides), slides_completed=len(embeddable_slides), - tile_artifacts=len(tile_artifacts), + tile_artifacts=len(tile_artifacts) + len(hierarchical_artifacts), slide_artifacts=len(slide_artifacts), ) emit_progress( @@ -558,6 +664,98 @@ def run_pipeline( ) return RunResult( tile_artifacts=tile_artifacts, + hierarchical_artifacts=hierarchical_artifacts, + slide_artifacts=slide_artifacts, + process_list_path=process_list_path, + ) + except Exception as exc: + emit_progress("run.failed", stage="pipeline", error=str(exc)) + raise + + +def run_pipeline_with_coordinates( + model, + *, + coordinates_dir: str | Path, + slides=None, + preprocessing: PreprocessingConfig | None = None, + execution: ExecutionOptions, +) -> RunResult: + if execution.output_dir is None: + raise ValueError("ExecutionOptions.output_dir is required for Pipeline.run_with_coordinates(...)") + if execution.num_gpus > 1: + _validate_multi_gpu_execution(model, execution) + + output_dir = Path(execution.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + resolved_preprocessing = _resolve_model_preprocessing(model, preprocessing) + available_slides, available_tilings = load_successful_tiled_slides(coordinates_dir) + if slides is None: + slide_records = available_slides + tiling_results = available_tilings + else: + requested_ids = {slide.sample_id: slide for slide in [_coerce_slide_spec(slide) for slide in slides]} + slide_records = [] + tiling_results = [] + for slide, tiling_result in zip(available_slides, available_tilings): + if slide.sample_id not in requested_ids: + continue + slide_records.append(requested_ids[slide.sample_id]) + tiling_results.append(tiling_result) + process_list_path = Path(coordinates_dir) / "process_list.csv" + emit_progress( + "run.started", + model_name=model.name, + level=model.level, + device_mode=_describe_device_mode(model, execution), + slide_count=len(slide_records), + output_dir=str(output_dir), + ) + try: + embeddable_slides, embeddable_tiling_results, zero_tile_pairs = _partition_slides_by_tile_count( + slide_records, + tiling_results, + ) + _write_zero_tile_embedding_sidecars( + zero_tile_pairs, + model=model, + preprocessing=resolved_preprocessing, + output_dir=output_dir, + output_format=execution.output_format, + ) + emit_progress("embedding.started", slide_count=len(embeddable_slides)) + if execution.num_gpus > 1: + tile_artifacts, hierarchical_artifacts, slide_artifacts = _collect_distributed_pipeline_artifacts( + model=model, + successful_slides=embeddable_slides, + process_list_path=process_list_path, + preprocessing=resolved_preprocessing, + execution=execution, + output_dir=output_dir, + ) + return RunResult( + tile_artifacts=tile_artifacts, + hierarchical_artifacts=hierarchical_artifacts, + slide_artifacts=slide_artifacts, + process_list_path=process_list_path, + ) + embedded_slides = _compute_embedded_slides( + model, + embeddable_slides, + embeddable_tiling_results, + preprocessing=resolved_preprocessing, + execution=execution, + ) + tile_artifacts, hierarchical_artifacts, slide_artifacts = _collect_local_pipeline_artifacts( + model=model, + embedded_slides=embedded_slides, + tiling_results=embeddable_tiling_results, + preprocessing=resolved_preprocessing, + execution=execution, + ) + return RunResult( + tile_artifacts=tile_artifacts, + hierarchical_artifacts=hierarchical_artifacts, slide_artifacts=slide_artifacts, process_list_path=process_list_path, ) @@ -573,8 +771,9 @@ def _collect_local_pipeline_artifacts( tiling_results, preprocessing: PreprocessingConfig, execution: ExecutionOptions, -) -> tuple[list[TileEmbeddingArtifact], list[SlideEmbeddingArtifact]]: +) -> tuple[list[TileEmbeddingArtifact], list[HierarchicalEmbeddingArtifact], list[SlideEmbeddingArtifact]]: tile_artifacts: list[TileEmbeddingArtifact] = [] + hierarchical_artifacts: list[HierarchicalEmbeddingArtifact] = [] slide_artifacts: list[SlideEmbeddingArtifact] = [] for embedded_slide, tiling_result in zip(embedded_slides, tiling_results): tile_artifact, slide_artifact = _persist_embedded_slide( @@ -584,11 +783,13 @@ def _collect_local_pipeline_artifacts( preprocessing=preprocessing, execution=execution, ) - if tile_artifact is not None: + if isinstance(tile_artifact, HierarchicalEmbeddingArtifact): + hierarchical_artifacts.append(tile_artifact) + elif tile_artifact is not None: tile_artifacts.append(tile_artifact) if slide_artifact is not None: slide_artifacts.append(slide_artifact) - return tile_artifacts, slide_artifacts + return tile_artifacts, hierarchical_artifacts, slide_artifacts def _build_incremental_persist_callback( @@ -599,15 +800,16 @@ def _build_incremental_persist_callback( process_list_path: Path | None = None, ) -> tuple[ Callable[[SlideSpec, Any, EmbeddedSlide], None] | None, - list[TileEmbeddingArtifact], + list[TileEmbeddingArtifact] | list[HierarchicalEmbeddingArtifact], list[SlideEmbeddingArtifact], ]: - tile_artifacts: list[TileEmbeddingArtifact] = [] + tile_artifacts: list[TileEmbeddingArtifact] | list[HierarchicalEmbeddingArtifact] = [] slide_artifacts: list[SlideEmbeddingArtifact] = [] if execution.output_dir is None: return None, tile_artifacts, slide_artifacts persist_tile_embeddings = _should_persist_tile_embeddings(model, execution) + persist_hierarchical_embeddings = _is_hierarchical_preprocessing(preprocessing) include_slide_embeddings = model.level == "slide" def _persist_completed_slide(slide: SlideSpec, tiling_result, embedded_slide: EmbeddedSlide) -> None: @@ -627,8 +829,10 @@ def _persist_completed_slide(slide: SlideSpec, tiling_result, embedded_slide: Em process_list_path, successful_slides=[slide], persist_tile_embeddings=persist_tile_embeddings, + persist_hierarchical_embeddings=persist_hierarchical_embeddings, include_slide_embeddings=include_slide_embeddings, - tile_artifacts=[tile_artifact] if tile_artifact is not None else [], + tile_artifacts=[tile_artifact] if isinstance(tile_artifact, TileEmbeddingArtifact) else [], + hierarchical_artifacts=[tile_artifact] if isinstance(tile_artifact, HierarchicalEmbeddingArtifact) else [], slide_artifacts=[slide_artifact] if slide_artifact is not None else [], ) @@ -643,6 +847,7 @@ def _pending_local_embedding_records( output_dir: Path, output_format: str, persist_tile_embeddings: bool, + persist_hierarchical_embeddings: bool, include_slide_embeddings: bool, save_latents: bool, resume: bool, @@ -655,6 +860,7 @@ def _pending_local_embedding_records( output_dir=output_dir, output_format=output_format, persist_tile_embeddings=persist_tile_embeddings, + persist_hierarchical_embeddings=persist_hierarchical_embeddings, include_slide_embeddings=include_slide_embeddings, save_latents=save_latents, ) @@ -674,12 +880,13 @@ def _completed_local_embedding_sample_ids( output_dir: Path, output_format: str, persist_tile_embeddings: bool, + persist_hierarchical_embeddings: bool, include_slide_embeddings: bool, save_latents: bool, ) -> set[str]: process_df = load_process_df( process_list_path, - include_feature_status=persist_tile_embeddings or include_slide_embeddings, + include_feature_status=persist_tile_embeddings or persist_hierarchical_embeddings or include_slide_embeddings, include_aggregation_status=include_slide_embeddings, ) completed_ids: set[str] = set() @@ -696,6 +903,7 @@ def _completed_local_embedding_sample_ids( output_dir=output_dir, output_format=output_format, persist_tile_embeddings=persist_tile_embeddings, + persist_hierarchical_embeddings=persist_hierarchical_embeddings, include_slide_embeddings=include_slide_embeddings, save_latents=save_latents, ): @@ -710,6 +918,7 @@ def _has_complete_local_embedding_outputs( output_dir: Path, output_format: str, persist_tile_embeddings: bool, + persist_hierarchical_embeddings: bool, include_slide_embeddings: bool, save_latents: bool, ) -> bool: @@ -718,6 +927,11 @@ def _has_complete_local_embedding_outputs( tile_metadata_path = output_dir / "tile_embeddings" / f"{sample_id}.meta.json" if not tile_artifact_path.is_file() or not tile_metadata_path.is_file(): return False + if persist_hierarchical_embeddings: + hierarchical_artifact_path = output_dir / "hierarchical_embeddings" / f"{sample_id}.{output_format}" + hierarchical_metadata_path = output_dir / "hierarchical_embeddings" / f"{sample_id}.meta.json" + if not hierarchical_artifact_path.is_file() or not hierarchical_metadata_path.is_file(): + return False if include_slide_embeddings: slide_artifact_path = output_dir / "slide_embeddings" / f"{sample_id}.{output_format}" slide_metadata_path = output_dir / "slide_embeddings" / f"{sample_id}.meta.json" @@ -739,8 +953,13 @@ def _collect_distributed_pipeline_artifacts( preprocessing: PreprocessingConfig, execution: ExecutionOptions, output_dir: Path, -) -> tuple[list[TileEmbeddingArtifact], list[SlideEmbeddingArtifact]]: +) -> tuple[ + list[TileEmbeddingArtifact], + list[HierarchicalEmbeddingArtifact], + list[SlideEmbeddingArtifact], +]: persist_tile_embeddings = _should_persist_tile_embeddings(model, execution) + persist_hierarchical_embeddings = _is_hierarchical_preprocessing(preprocessing) include_slide_embeddings = model.level == "slide" _run_distributed_embedding_stage( model=model, @@ -749,22 +968,25 @@ def _collect_distributed_pipeline_artifacts( execution=execution, output_dir=output_dir, ) - tile_artifacts, slide_artifacts = _collect_pipeline_artifacts( + tile_artifacts, hierarchical_artifacts, slide_artifacts = _collect_pipeline_artifacts( successful_slides, output_dir=output_dir, output_format=execution.output_format, include_tile_embeddings=persist_tile_embeddings, + include_hierarchical_embeddings=persist_hierarchical_embeddings, include_slide_embeddings=include_slide_embeddings, ) _update_process_list_after_embedding( process_list_path, successful_slides=successful_slides, persist_tile_embeddings=persist_tile_embeddings, + persist_hierarchical_embeddings=persist_hierarchical_embeddings, include_slide_embeddings=include_slide_embeddings, tile_artifacts=tile_artifacts, + hierarchical_artifacts=hierarchical_artifacts, slide_artifacts=slide_artifacts, ) - return tile_artifacts, slide_artifacts + return tile_artifacts, hierarchical_artifacts, slide_artifacts def _compute_embedded_slides( @@ -782,21 +1004,30 @@ def _compute_embedded_slides( emit_progress( "embedding.slide.started", sample_id=slide.sample_id, - total_tiles=_num_tiles(tiling_result), - ) - tile_embeddings = _compute_tile_embeddings_for_slide( - loaded, - model, - slide, - tiling_result, - preprocessing=preprocessing, - execution=execution, + total_tiles=_num_embedding_items(tiling_result, preprocessing), ) + if _is_hierarchical_preprocessing(preprocessing): + tile_embeddings = _compute_hierarchical_embeddings_for_slide( + loaded, + slide, + tiling_result, + preprocessing=preprocessing, + execution=execution, + ) + else: + tile_embeddings = _compute_tile_embeddings_for_slide( + loaded, + model, + slide, + tiling_result, + preprocessing=preprocessing, + execution=execution, + ) if model.level == "slide": emit_progress( "aggregation.started", sample_id=slide.sample_id, - total_tiles=_num_tiles(tiling_result), + total_tiles=_num_embedding_items(tiling_result, preprocessing), ) slide_embedding, latents = _aggregate_tile_embeddings_for_slide( loaded, @@ -826,7 +1057,7 @@ def _compute_embedded_slides( emit_progress( "embedding.slide.finished", sample_id=slide.sample_id, - num_tiles=_num_tiles(tiling_result), + num_tiles=_num_embedding_items(tiling_result, preprocessing), ) return embedded_slides @@ -899,7 +1130,6 @@ def _compute_tile_embeddings_for_slide( dataset = TileIndexDataset(resolved_indices) batch_preprocessor = _build_batch_preprocessor( loaded, - model, tiling_result, ) loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) @@ -932,7 +1162,7 @@ def _compute_tile_embeddings_for_slide( batch_preprocessor=batch_preprocessor, sample_id=slide.sample_id, total_items=len(dataset), - unit_label="region" if model.level == "region" else "tile", + unit_label="tile", ) if _supertile_reorder is not None: inverse = np.argsort(_supertile_reorder, kind="stable") @@ -940,6 +1170,150 @@ def _compute_tile_embeddings_for_slide( return tile_embeddings +def _compute_hierarchical_embeddings_for_slide( + loaded: LoadedModel, + slide: SlideSpec, + tiling_result, + *, + preprocessing: PreprocessingConfig, + execution: ExecutionOptions, + flat_indices=None, +): + geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) + index = _build_hierarchical_index( + tiling_result, + region_tile_multiple=int(geometry["region_tile_multiple"]), + ) + resolved_indices = index.flat_index + if flat_indices is not None: + resolved_indices = np.asarray(flat_indices, dtype=np.int64) + if resolved_indices.size == 0: + return torch.empty( + (index.num_regions, index.tiles_per_region, int(loaded.feature_dim)), + dtype=torch.float32, + ) + collate_fn = OnTheFlyHierarchicalBatchCollator( + image_path=slide.image_path, + tiling_result=tiling_result, + region_index=index.region_index, + subtile_index_within_region=index.subtile_index_within_region, + effective_region_size_px=int(geometry["effective_region_size_px"]), + effective_tile_size_px=int(geometry["effective_tile_size_px"]), + backend=_resolve_slide_backend(preprocessing, tiling_result), + num_cucim_workers=preprocessing.num_cucim_workers, + gpu_decode=preprocessing.gpu_decode, + ) + dataset = TileIndexDataset(resolved_indices) + batch_preprocessor = _build_batch_preprocessor_for_tile_images( + loaded, + target_tile_size_px=int(geometry["target_tile_size_px"]), + ) + loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) + effective_num_workers, worker_context = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) + if effective_num_workers != execution.num_workers: + logging.getLogger(__name__).info( + f"on-the-fly hierarchical mode: setting DataLoader num_workers={effective_num_workers} " + f"({worker_context}); " + f"ignoring speed.num_dataloader_workers={execution.num_workers}" + ) + loader_kwargs["num_workers"] = effective_num_workers + if effective_num_workers == 0: + loader_kwargs.pop("persistent_workers", None) + loader_kwargs.pop("prefetch_factor", None) + loader_kwargs["batch_sampler"] = collate_fn.build_batch_sampler( + batch_size=execution.batch_size, + dataset_indices=np.asarray(resolved_indices, dtype=np.int64), + ) + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=collate_fn, + **loader_kwargs, + ) + autocast_dtype = _autocast_dtype(torch, execution.precision) + autocast_context = ( + torch.autocast(device_type="cuda", dtype=autocast_dtype) + if autocast_dtype is not None and str(loaded.device).startswith("cuda") + else nullcontext() + ) + batch_flat_indices, flat_embeddings = _run_forward_pass( + dataloader, + loaded, + autocast_context, + batch_preprocessor=batch_preprocessor, + sample_id=slide.sample_id, + total_items=len(dataset), + unit_label="tile", + return_indices=True, + ) + result = torch.empty( + (index.num_regions * index.tiles_per_region, int(flat_embeddings.shape[-1])), + dtype=flat_embeddings.dtype, + ) + result[batch_flat_indices] = flat_embeddings + return result.reshape(index.num_regions, index.tiles_per_region, int(flat_embeddings.shape[-1])) + + +def _compute_hierarchical_embedding_shard_for_slide( + loaded: LoadedModel, + slide: SlideSpec, + tiling_result, + *, + preprocessing: PreprocessingConfig, + execution: ExecutionOptions, + flat_indices, +): + geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) + index = _build_hierarchical_index( + tiling_result, + region_tile_multiple=int(geometry["region_tile_multiple"]), + ) + resolved_indices = np.asarray(flat_indices, dtype=np.int64) + collate_fn = OnTheFlyHierarchicalBatchCollator( + image_path=slide.image_path, + tiling_result=tiling_result, + region_index=index.region_index, + subtile_index_within_region=index.subtile_index_within_region, + effective_region_size_px=int(geometry["effective_region_size_px"]), + effective_tile_size_px=int(geometry["effective_tile_size_px"]), + backend=_resolve_slide_backend(preprocessing, tiling_result), + num_cucim_workers=preprocessing.num_cucim_workers, + gpu_decode=preprocessing.gpu_decode, + ) + dataset = TileIndexDataset(resolved_indices) + batch_preprocessor = _build_batch_preprocessor_for_tile_images( + loaded, + target_tile_size_px=int(geometry["target_tile_size_px"]), + ) + loader_kwargs = _embedding_dataloader_kwargs(loaded, execution) + effective_num_workers, _worker_context = _resolve_on_the_fly_num_workers(preprocessing.num_cucim_workers) + loader_kwargs["num_workers"] = effective_num_workers + if effective_num_workers == 0: + loader_kwargs.pop("persistent_workers", None) + loader_kwargs.pop("prefetch_factor", None) + loader_kwargs["batch_sampler"] = collate_fn.build_batch_sampler( + batch_size=execution.batch_size, + dataset_indices=resolved_indices, + ) + dataloader = torch.utils.data.DataLoader(dataset, collate_fn=collate_fn, **loader_kwargs) + autocast_dtype = _autocast_dtype(torch, execution.precision) + autocast_context = ( + torch.autocast(device_type="cuda", dtype=autocast_dtype) + if autocast_dtype is not None and str(loaded.device).startswith("cuda") + else nullcontext() + ) + batch_flat_indices, flat_embeddings = _run_forward_pass( + dataloader, + loaded, + autocast_context, + batch_preprocessor=batch_preprocessor, + sample_id=slide.sample_id, + total_items=len(dataset), + unit_label="tile", + return_indices=True, + ) + return batch_flat_indices.numpy(), flat_embeddings + + def _aggregate_tile_embeddings_for_slide( loaded: LoadedModel, model, @@ -1018,7 +1392,7 @@ def _persist_embedded_slide( *, preprocessing: PreprocessingConfig, execution: ExecutionOptions, -) -> tuple[TileEmbeddingArtifact | None, SlideEmbeddingArtifact | None]: +) -> tuple[TileEmbeddingArtifact | HierarchicalEmbeddingArtifact | None, SlideEmbeddingArtifact | None]: if execution.output_dir is None: raise ValueError("ExecutionOptions.output_dir is required to persist embedded slides") if _num_rows(embedded_slide.tile_embeddings) == 0: @@ -1038,6 +1412,21 @@ def _persist_embedded_slide( ), ) return None, None + if _is_hierarchical_preprocessing(preprocessing): + hierarchical_artifact = _write_hierarchical_embedding_artifact( + embedded_slide.sample_id, + embedded_slide.tile_embeddings, + execution=execution, + metadata=_build_hierarchical_embedding_metadata( + model, + tiling_result=tiling_result, + image_path=embedded_slide.image_path, + mask_path=embedded_slide.mask_path, + backend=_resolve_slide_backend(preprocessing, tiling_result), + preprocessing=preprocessing, + ), + ) + return hierarchical_artifact, None tile_artifact = None if _should_persist_tile_embeddings(model, execution): tile_artifact = _write_tile_embedding_artifact( @@ -1102,6 +1491,40 @@ def _build_slide_embedding_metadata(model, *, image_path: Path | str) -> dict[st } +def _build_hierarchical_embedding_metadata( + model, + *, + tiling_result, + image_path: Path | str, + mask_path: Path | str | None, + backend: str, + preprocessing: PreprocessingConfig, +) -> dict[str, Any]: + coordinates_npz_path = ( + tiling_result.coordinates_npz_path if hasattr(tiling_result, "coordinates_npz_path") else None + ) + coordinates_meta_path = ( + tiling_result.coordinates_meta_path if hasattr(tiling_result, "coordinates_meta_path") else None + ) + geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) + return { + "encoder_name": model.name, + "encoder_level": model.level, + "coordinates_npz_path": str(coordinates_npz_path or ""), + "coordinates_meta_path": str(coordinates_meta_path or ""), + "image_path": str(image_path), + "mask_path": str(mask_path) if mask_path is not None else None, + "backend": backend, + "region_tile_multiple": int(geometry["region_tile_multiple"]), + "target_tile_size_px": int(geometry["target_tile_size_px"]), + "effective_tile_size_px": int(geometry["effective_tile_size_px"]), + "target_region_size_px": int(geometry["target_region_size_px"]), + "effective_region_size_px": int(geometry["effective_region_size_px"]), + "target_spacing_um": float(preprocessing.target_spacing_um), + "subtile_order": "row_major", + } + + def _write_tile_embedding_artifact( sample_id: str, features, @@ -1141,6 +1564,24 @@ def _write_slide_embedding_artifact( ) +def _write_hierarchical_embedding_artifact( + sample_id: str, + features, + *, + execution: ExecutionOptions, + metadata: dict[str, Any], +) -> HierarchicalEmbeddingArtifact: + if execution.output_dir is None: + raise ValueError("ExecutionOptions.output_dir is required to persist hierarchical embeddings") + return write_hierarchical_embeddings( + sample_id, + features, + output_dir=execution.output_dir, + output_format=execution.output_format, + metadata=metadata, + ) + + def _embedding_dataloader_kwargs(loaded: LoadedModel, execution: ExecutionOptions) -> dict[str, Any]: kwargs: dict[str, Any] = { @@ -1156,10 +1597,19 @@ def _embedding_dataloader_kwargs(loaded: LoadedModel, execution: ExecutionOption def _build_batch_preprocessor( loaded: LoadedModel, - model, tiling_result, ): + return _build_batch_preprocessor_for_tile_images( + loaded, + target_tile_size_px=int(getattr(tiling_result, "requested_tile_size_px")), + ) + +def _build_batch_preprocessor_for_tile_images( + loaded: LoadedModel, + *, + target_tile_size_px: int, +): spec = _build_batch_transform_spec(loaded.transforms) if spec is None: logging.getLogger(__name__).warning( @@ -1168,23 +1618,15 @@ def _build_batch_preprocessor( loaded.name, ) return None + def preprocess(batch): - image = batch - image = _prepare_batch_tensor(image) + image = _prepare_batch_tensor(batch) if spec.resize_size is None: - # Model has no Resize transform: apply bilinear resize to target tile size as fallback image = _resize_image_batch( image, - (int(tiling_result.requested_tile_size_px), int(tiling_result.requested_tile_size_px)), - ) - if model.level == "region": - image = _apply_region_batch_transform_spec( - image, - spec, - tile_size=int(loaded.model.tile_size), + (int(target_tile_size_px), int(target_tile_size_px)), ) - else: - image = _apply_batch_transform_spec(image, spec) + image = _apply_batch_transform_spec(image, spec) if image.device != loaded.device: image = image.to(loaded.device, non_blocking=str(loaded.device).startswith("cuda")) return image.contiguous() @@ -1206,7 +1648,6 @@ def _build_batch_transform_spec(transforms) -> BatchTransformSpec | None: center_crop_size=None, mean=tuple(float(value) for value in mean) if mean is not None else None, std=tuple(float(value) for value in std) if std is not None else None, - region_unfold_tile_size=None, ) transform_steps = _iter_transform_steps(transforms) @@ -1218,7 +1659,6 @@ def _build_batch_transform_spec(transforms) -> BatchTransformSpec | None: center_crop_size = None mean = None std = None - region_unfold_tile_size = None supported_step_names = { "Resize", "CenterCrop", @@ -1229,12 +1669,6 @@ def _build_batch_transform_spec(transforms) -> BatchTransformSpec | None: "ConvertImageDtype", } for step in transform_steps: - if _is_region_unfolding_transform(step): - step_tile_size = int(step.tile_size) - if region_unfold_tile_size is not None and region_unfold_tile_size != step_tile_size: - return None - region_unfold_tile_size = step_tile_size - continue step_name = type(step).__name__ if step_name not in supported_step_names: return None @@ -1251,7 +1685,6 @@ def _build_batch_transform_spec(transforms) -> BatchTransformSpec | None: center_crop_size=center_crop_size, mean=mean, std=std, - region_unfold_tile_size=region_unfold_tile_size, resize_interpolation=resize_interpolation, ) @@ -1268,12 +1701,6 @@ def _iter_transform_steps(transforms): else: flattened.append(step) return flattened - - -def _is_region_unfolding_transform(step) -> bool: - return type(step).__name__ == "RegionUnfolding" and hasattr(step, "tile_size") - - def _prepare_batch_tensor(image): if image.dtype == torch.uint8: return image.float().div(255.0) @@ -1330,34 +1757,6 @@ def _apply_batch_transform_spec(image, spec: BatchTransformSpec): std = torch.tensor(spec.std, dtype=image.dtype, device=image.device).view(1, -1, 1, 1) image = (image - mean) / std return image - - -def _apply_region_batch_transform_spec(image, spec: BatchTransformSpec, *, tile_size: int): - if spec.region_unfold_tile_size is not None and spec.region_unfold_tile_size != tile_size: - raise ValueError( - "Region transform stack RegionUnfolding tile_size does not match the region model tile_size" - ) - region_tile_size = spec.region_unfold_tile_size or tile_size - batch_size = int(image.shape[0]) - unfolded = _unfold_region_batch(image, region_tile_size) - num_tiles = int(unfolded.shape[1]) - flattened = unfolded.reshape(batch_size * num_tiles, *unfolded.shape[-3:]) - transformed = _apply_batch_transform_spec(flattened, spec) - return transformed.reshape(batch_size, num_tiles, *transformed.shape[-3:]) - - -def _unfold_region_batch(image, tile_size: int): - - height, width = (int(image.shape[-2]), int(image.shape[-1])) - if height % tile_size != 0 or width % tile_size != 0: - raise ValueError( - f"Region batch with shape {height}x{width} is not divisible by tile_size={tile_size}" - ) - unfolded = torch.nn.functional.unfold(image, kernel_size=tile_size, stride=tile_size) - unfolded = unfolded.transpose(1, 2) - return unfolded.reshape(image.shape[0], -1, image.shape[1], tile_size, tile_size) - - def _normalize_hw(value) -> tuple[int, int] | None: if value is None: return None @@ -1511,9 +1910,11 @@ def _run_forward_pass( sample_id: str | None = None, total_items: int | None = None, unit_label: str = "tile", + return_indices: bool = False, ): outputs = [] + batch_indices = [] if return_indices else None processed = 0 batch_index = 0 prefetcher = _BatchPrefetcher(dataloader, loaded, batch_preprocessor) @@ -1524,6 +1925,8 @@ def _run_forward_pass( embedding = loaded.model.encode_tiles(image).detach().cpu() forward_ms = (time.perf_counter() - forward_start) * 1000.0 outputs.append(embedding) + if batch_indices is not None: + batch_indices.append(torch.as_tensor(prepared_batch.indices, dtype=torch.long).detach().cpu()) processed += int(embedding.shape[0]) batch_index += 1 batch_total_ms = ( @@ -1562,8 +1965,14 @@ def _run_forward_pass( ) if not outputs: feature_dim = loaded.tile_feature_dim if loaded.tile_feature_dim is not None else loaded.feature_dim - return torch.empty((0, int(feature_dim)), dtype=torch.float32) - return torch.cat(outputs, dim=0) + empty = torch.empty((0, int(feature_dim)), dtype=torch.float32) + if batch_indices is not None: + return torch.empty((0,), dtype=torch.long), empty + return empty + embeddings = torch.cat(outputs, dim=0) + if batch_indices is not None: + return torch.cat(batch_indices, dim=0), embeddings + return embeddings @@ -1654,6 +2063,23 @@ def _write_zero_tile_embedding_sidecars( if output_dir is None: return for slide, tiling_result in zero_tile_pairs: + if _is_hierarchical_preprocessing(preprocessing): + geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) + write_hierarchical_embeddings( + slide.sample_id, + np.empty((0, int(geometry["tiles_per_region"]), 0), dtype=np.float32), + output_dir=output_dir, + output_format=output_format, + metadata=_build_hierarchical_embedding_metadata( + model, + tiling_result=tiling_result, + image_path=slide.image_path, + mask_path=slide.mask_path, + backend=_resolve_slide_backend(preprocessing, tiling_result), + preprocessing=preprocessing, + ), + ) + continue write_tile_embedding_metadata( slide.sample_id, output_dir=output_dir, @@ -1878,10 +2304,15 @@ def _record_slide_metadata_in_process_list( def _build_hs2p_configs(preprocessing: PreprocessingConfig): + target_tile_size_px = ( + preprocessing.target_region_size_px + if _is_hierarchical_preprocessing(preprocessing) + else preprocessing.target_tile_size_px + ) tiling_cfg = TilingConfig( backend=_resolve_tiling_backend(preprocessing), target_spacing_um=preprocessing.target_spacing_um, - target_tile_size_px=preprocessing.target_tile_size_px, + target_tile_size_px=target_tile_size_px, tolerance=preprocessing.tolerance, overlap=preprocessing.overlap, tissue_threshold=preprocessing.tissue_threshold, @@ -1957,11 +2388,11 @@ def ensure_defaults() -> tuple[int, float]: if preprocessing is None: target_tile_size_px, target_spacing_um = ensure_defaults() - return PreprocessingConfig( + return _resolve_hierarchical_preprocessing(PreprocessingConfig( backend="auto", target_spacing_um=target_spacing_um, target_tile_size_px=target_tile_size_px, - ) + )) target_spacing_um = preprocessing.target_spacing_um target_tile_size_px = preprocessing.target_tile_size_px @@ -1971,11 +2402,11 @@ def ensure_defaults() -> tuple[int, float]: target_spacing_um = default_spacing_um if target_tile_size_px is None: target_tile_size_px = default_tile_size_px - return replace( + return _resolve_hierarchical_preprocessing(replace( preprocessing, target_spacing_um=target_spacing_um, target_tile_size_px=target_tile_size_px, - ) + )) def _validate_multi_gpu_execution(model, execution: ExecutionOptions) -> None: @@ -2039,8 +2470,17 @@ def _embed_single_slide_distributed( strategy="tile_shard", sample_id=slide.sample_id, ) - shard_payloads = _load_tile_embedding_shards(coordination_dir, slide.sample_id) - tile_embeddings = _merge_tile_embedding_shards(shard_payloads) + if _is_hierarchical_preprocessing(preprocessing): + shard_payloads = _load_hierarchical_embedding_shards(coordination_dir, slide.sample_id) + geometry = _resolve_hierarchical_geometry(preprocessing, tiling_result) + tile_embeddings = _merge_hierarchical_embedding_shards( + shard_payloads, + num_regions=_num_tiles(tiling_result), + tiles_per_region=int(geometry["tiles_per_region"]), + ) + else: + shard_payloads = _load_tile_embedding_shards(coordination_dir, slide.sample_id) + tile_embeddings = _merge_tile_embedding_shards(shard_payloads) if model.level != "slide": return _make_embedded_slide( slide=slide, @@ -2310,12 +2750,41 @@ def _merge_tile_embedding_shards(shard_payloads): return merged[order] +def _merge_hierarchical_embedding_shards( + shard_payloads, + *, + num_regions: int, + tiles_per_region: int, +): + if not shard_payloads: + raise ValueError("No hierarchical embedding shards were produced") + indices = np.concatenate( + [np.asarray(payload["flat_index"], dtype=np.int64) for payload in shard_payloads], + axis=0, + ) + order = np.argsort(indices, kind="stable") + embeddings = [payload["tile_embeddings"] for payload in shard_payloads] + first = embeddings[0] + if torch.is_tensor(first): + merged = torch.cat(embeddings, dim=0) + merged = merged[torch.as_tensor(order, dtype=torch.long)] + return merged.reshape(int(num_regions), int(tiles_per_region), int(merged.shape[-1])) + merged = np.concatenate([np.asarray(embedding) for embedding in embeddings], axis=0) + merged = merged[order] + return merged.reshape(int(num_regions), int(tiles_per_region), int(merged.shape[-1])) + + def _load_tile_embedding_shards(coordination_dir: Path, sample_id: str): shard_paths = sorted(coordination_dir.glob(f"{sample_id}.tiles.rank*.pt")) return [torch.load(path, map_location="cpu", weights_only=True) for path in shard_paths] +def _load_hierarchical_embedding_shards(coordination_dir: Path, sample_id: str): + shard_paths = sorted(coordination_dir.glob(f"{sample_id}.hier.rank*.pt")) + return [torch.load(path, map_location="cpu", weights_only=True) for path in shard_paths] + + def _load_embedded_slide_payload(coordination_dir: Path, sample_id: str): payload_path = coordination_dir / f"{sample_id}.embedded.pt" @@ -2339,6 +2808,8 @@ def _serialize_preprocessing(preprocessing: PreprocessingConfig) -> dict[str, An "backend": preprocessing.backend, "target_spacing_um": preprocessing.target_spacing_um, "target_tile_size_px": preprocessing.target_tile_size_px, + "target_region_size_px": preprocessing.target_region_size_px, + "region_tile_multiple": preprocessing.region_tile_multiple, "tolerance": preprocessing.tolerance, "overlap": preprocessing.overlap, "tissue_threshold": preprocessing.tissue_threshold, @@ -2381,6 +2852,16 @@ def deserialize_preprocessing(payload: dict[str, Any]) -> PreprocessingConfig: backend=payload["backend"], target_spacing_um=float(payload["target_spacing_um"]), target_tile_size_px=int(payload["target_tile_size_px"]), + target_region_size_px=( + int(payload["target_region_size_px"]) + if "target_region_size_px" in payload and payload["target_region_size_px"] is not None + else None + ), + region_tile_multiple=( + int(payload["region_tile_multiple"]) + if "region_tile_multiple" in payload and payload["region_tile_multiple"] is not None + else None + ), tolerance=float(payload["tolerance"]), overlap=float(payload["overlap"]), tissue_threshold=float(payload["tissue_threshold"]), @@ -2435,18 +2916,28 @@ def _collect_pipeline_artifacts( output_dir: Path, output_format: str, include_tile_embeddings: bool, + include_hierarchical_embeddings: bool, include_slide_embeddings: bool, -) -> tuple[list[TileEmbeddingArtifact], list[SlideEmbeddingArtifact]]: +) -> tuple[ + list[TileEmbeddingArtifact], + list[HierarchicalEmbeddingArtifact], + list[SlideEmbeddingArtifact], +]: tile_artifacts: list[TileEmbeddingArtifact] = [] + hierarchical_artifacts: list[HierarchicalEmbeddingArtifact] = [] slide_artifacts: list[SlideEmbeddingArtifact] = [] for slide in slide_records: if include_tile_embeddings: tile_artifacts.append(_load_tile_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format)) + if include_hierarchical_embeddings: + hierarchical_artifacts.append( + _load_hierarchical_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format) + ) if include_slide_embeddings: slide_artifacts.append( _load_slide_artifact(slide.sample_id, output_dir=output_dir, output_format=output_format) ) - return tile_artifacts, slide_artifacts + return tile_artifacts, hierarchical_artifacts, slide_artifacts def _load_tile_artifact(sample_id: str, *, output_dir: Path, output_format: str) -> TileEmbeddingArtifact: @@ -2463,6 +2954,26 @@ def _load_tile_artifact(sample_id: str, *, output_dir: Path, output_format: str) ) +def _load_hierarchical_artifact( + sample_id: str, + *, + output_dir: Path, + output_format: str, +) -> HierarchicalEmbeddingArtifact: + artifact_path = output_dir / "hierarchical_embeddings" / f"{sample_id}.{output_format}" + metadata_path = output_dir / "hierarchical_embeddings" / f"{sample_id}.meta.json" + metadata = load_metadata(metadata_path) + return HierarchicalEmbeddingArtifact( + sample_id=sample_id, + path=artifact_path, + metadata_path=metadata_path, + format=output_format, + feature_dim=int(metadata["feature_dim"]), + num_regions=int(metadata["num_regions"]), + tiles_per_region=int(metadata["tiles_per_region"]), + ) + + def _load_slide_artifact(sample_id: str, *, output_dir: Path, output_format: str) -> SlideEmbeddingArtifact: artifact_path = output_dir / "slide_embeddings" / f"{sample_id}.{output_format}" metadata_path = output_dir / "slide_embeddings" / f"{sample_id}.meta.json" @@ -2484,8 +2995,10 @@ def _update_process_list_after_embedding( *, successful_slides: Sequence[SlideSpec], persist_tile_embeddings: bool, + persist_hierarchical_embeddings: bool, include_slide_embeddings: bool, tile_artifacts: Sequence[TileEmbeddingArtifact], + hierarchical_artifacts: Sequence[HierarchicalEmbeddingArtifact], slide_artifacts: Sequence[SlideEmbeddingArtifact], ) -> None: df = pd.read_csv(process_list_path) @@ -2494,12 +3007,19 @@ def _update_process_list_after_embedding( if include_slide_embeddings and "aggregation_status" not in df.columns: df["aggregation_status"] = ["tbp"] * len(df) tile_success_ids = {artifact.sample_id for artifact in tile_artifacts} + hierarchical_success_ids = {artifact.sample_id for artifact in hierarchical_artifacts} slide_success_ids = {artifact.sample_id for artifact in slide_artifacts} for slide in successful_slides: mask = df["sample_id"].astype(str) == slide.sample_id df.loc[mask, "feature_status"] = ( "success" - if not persist_tile_embeddings or slide.sample_id in tile_success_ids + if ( + (not persist_tile_embeddings or slide.sample_id in tile_success_ids) + and ( + not persist_hierarchical_embeddings + or slide.sample_id in hierarchical_success_ids + ) + ) else "error" ) if include_slide_embeddings: diff --git a/slide2vec/models/__init__.py b/slide2vec/models/__init__.py deleted file mode 100644 index 22ed285..0000000 --- a/slide2vec/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Model implementations have moved to slide2vec.encoders. diff --git a/slide2vec/progress.py b/slide2vec/progress.py index d83da1a..a93d47b 100644 --- a/slide2vec/progress.py +++ b/slide2vec/progress.py @@ -5,7 +5,7 @@ from pathlib import Path import sys import time -from typing import Any, Iterable +from typing import Any import pandas as pd diff --git a/slide2vec/utils/__init__.py b/slide2vec/utils/__init__.py index 44018c3..6b6c610 100644 --- a/slide2vec/utils/__init__.py +++ b/slide2vec/utils/__init__.py @@ -1,10 +1,10 @@ -from .utils import initialize_wandb, fix_random_seeds, get_sha, update_state_dict +from .utils import initialize_wandb, fix_random_seeds, get_sha, slurm_cpu_limit from .log_utils import setup_logging __all__ = [ "initialize_wandb", "fix_random_seeds", "get_sha", - "update_state_dict", + "slurm_cpu_limit", "setup_logging", ] diff --git a/slide2vec/utils/log_utils.py b/slide2vec/utils/log_utils.py index 639aecb..331717c 100644 --- a/slide2vec/utils/log_utils.py +++ b/slide2vec/utils/log_utils.py @@ -3,7 +3,6 @@ import logging import os import sys -from typing import Optional import slide2vec.distributed as distributed from slide2vec.progress import emit_progress_log @@ -43,10 +42,10 @@ def emit(self, record: logging.LogRecord) -> None: # So that calling _configure_logger multiple times won't add many handlers @functools.lru_cache() def _configure_logger( - name: Optional[str] = None, + name: str | None = None, *, level: int = logging.DEBUG, - output: Optional[str] = None, + output: str | None = None, ): """ Configure a logger. @@ -110,8 +109,8 @@ def _configure_logger( def setup_logging( *, - output: Optional[str] = None, - name: Optional[str] = None, + output: str | None = None, + name: str | None = None, level: int = logging.DEBUG, capture_warnings: bool = True, ) -> None: diff --git a/slide2vec/utils/paths.py b/slide2vec/utils/paths.py deleted file mode 100644 index 4111997..0000000 --- a/slide2vec/utils/paths.py +++ /dev/null @@ -1,10 +0,0 @@ -from pathlib import Path - - -def resolve_output_dir(*, config_output_dir: str, cli_output_dir: str | None) -> Path: - if cli_output_dir is None: - return Path(config_output_dir) - cli_path = Path(cli_output_dir) - if cli_path.is_absolute(): - return cli_path - return Path(config_output_dir, cli_output_dir) diff --git a/slide2vec/utils/tiling_io.py b/slide2vec/utils/tiling_io.py index 1813de6..fc77d15 100644 --- a/slide2vec/utils/tiling_io.py +++ b/slide2vec/utils/tiling_io.py @@ -30,6 +30,8 @@ def _optional_float(value: Any) -> float | None: if value is None or pd.isna(value): return None return float(value) + + def load_slide_manifest(csv_path: str | Path) -> list[SlideSpec]: manifest_path = Path(csv_path).resolve() df = pd.read_csv(manifest_path) diff --git a/slide2vec/utils/utils.py b/slide2vec/utils/utils.py index ab25141..79560ea 100644 --- a/slide2vec/utils/utils.py +++ b/slide2vec/utils/utils.py @@ -4,8 +4,6 @@ import numpy as np import torch -from typing import Any, Optional - def fix_random_seeds(seed=31): """ @@ -36,16 +34,17 @@ def _run(command): pass message = f"sha: {sha}, status: {diff}, branch: {branch}" return message + + def initialize_wandb( - cfg: Any, + cfg, *, - key: Optional[str] = "", + key: str | None = "", ): import wandb from omegaconf import OmegaConf - command = f"wandb login {key}" - subprocess.call(command, shell=True) + subprocess.call(["wandb", "login", key]) if cfg.wandb.tags is None: tags = [] else: @@ -76,57 +75,10 @@ def initialize_wandb( return run -def update_state_dict( - *, - model_dict: dict, - state_dict: dict, -): - """ - Matches weights between `model_dict` and `state_dict`, accounting for: - - Key mismatches (missing in model_dict) - - Shape mismatches (tensor size differences) - - Args: - model_dict (dict): model state dictionary (expected keys and shapes) - state_dict (dict): checkpoint state dictionary (loaded keys and values) - - Returns: - updated_state_dict (dict): Weights mapped correctly to `model_dict` - msg (str): Log message summarizing the result - """ - success = 0 - shape_mismatch = 0 - missing_keys = 0 - updated_state_dict = {} - shape_mismatch_list = [] - missing_keys_list = [] - used_keys = set() - for model_key, model_val in model_dict.items(): - matched_key = False - for state_key, state_val in state_dict.items(): - if state_key in used_keys: - continue - if model_key == state_key: - if model_val.size() == state_val.size(): - updated_state_dict[model_key] = state_val - used_keys.add(state_key) - success += 1 - matched_key = True # key is successfully matched - break - else: - shape_mismatch += 1 - shape_mismatch_list.append(model_key) - matched_key = True # key is matched, but weight cannot be loaded - break - if not matched_key: - # key not found in state_dict - updated_state_dict[model_key] = model_val # keep original weights - missing_keys += 1 - missing_keys_list.append(model_key) - # log summary - msg = f"{success}/{len(model_dict)} weight(s) loaded successfully" - if shape_mismatch > 0: - msg += f"\n{shape_mismatch} weight(s) not loaded due to mismatching shapes: {shape_mismatch_list}" - if missing_keys > 0: - msg += f"\n{missing_keys} key(s) from checkpoint not found in model: {missing_keys_list}" - return updated_state_dict, msg +def slurm_cpu_limit() -> int | None: + """Return the CPU limit imposed by SLURM, or None if not running under SLURM.""" + for env_name in ("SLURM_CPUS_PER_TASK", "SLURM_CPUS_ON_NODE", "SLURM_JOB_CPUS_PER_NODE"): + value = os.environ.get(env_name, "") + if value.strip().isdigit() and int(value.strip()) > 0: + return int(value.strip()) + return None diff --git a/tests/test_progress.py b/tests/test_progress.py index e8ad6af..8eebe49 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -167,7 +167,7 @@ def test_run_pipeline_emits_local_progress_events_in_order(monkeypatch, tmp_path monkeypatch.setattr( inference, "_collect_pipeline_artifacts", - lambda *args, **kwargs: (["tile-artifact"], ["slide-artifact"]), + lambda *args, **kwargs: (["tile-artifact"], [], ["slide-artifact"]), ) monkeypatch.setattr(inference, "_update_process_list_after_embedding", lambda *args, **kwargs: None) diff --git a/tests/test_regression_core.py b/tests/test_regression_core.py index fce6a10..0c6ed3d 100644 --- a/tests/test_regression_core.py +++ b/tests/test_regression_core.py @@ -15,6 +15,7 @@ from slide2vec.artifacts import ( load_array, load_metadata, + write_hierarchical_embeddings, write_slide_embeddings, write_tile_embeddings, ) @@ -130,6 +131,85 @@ def test_pt_artifacts_round_trip(tmp_path: Path): assert torch.equal(loaded, features) assert metadata["image_path"] == "/tmp/sample-b.svs" + +def test_hierarchical_npz_artifacts_round_trip(tmp_path: Path): + features = np.arange(24, dtype=np.float32).reshape(2, 3, 4) + artifact = write_hierarchical_embeddings( + "sample-h", + features, + output_dir=tmp_path, + output_format="npz", + metadata={ + "coordinates_npz_path": "/tmp/sample-h.coordinates.npz", + "target_tile_size_px": 224, + "effective_tile_size_px": 224, + "target_region_size_px": 672, + "effective_region_size_px": 672, + "tiles_per_region": 3, + }, + ) + + loaded = load_array(artifact.path) + metadata = load_metadata(artifact.metadata_path) + + np.testing.assert_array_equal(loaded, features) + assert artifact.path == tmp_path / "hierarchical_embeddings" / "sample-h.npz" + assert metadata["artifact_type"] == "hierarchical_embeddings" + assert metadata["num_regions"] == 2 + assert metadata["tiles_per_region"] == 3 + assert metadata["feature_dim"] == 4 + assert metadata["target_region_size_px"] == 672 + + +def test_resolve_direct_api_preprocessing_derives_target_region_size_from_multiple(): + import slide2vec.api as api + + model = Model.from_preset("uni") + resolved = api._resolve_direct_api_preprocessing( + model, + PreprocessingConfig( + target_spacing_um=0.5, + target_tile_size_px=224, + region_tile_multiple=6, + ), + ) + + assert resolved.target_tile_size_px == 224 + assert resolved.target_region_size_px == 1344 + + +def test_resolve_direct_api_preprocessing_uses_model_defaults_before_region_derivation(): + import slide2vec.api as api + + model = Model.from_preset("conchv15") + resolved = api._resolve_direct_api_preprocessing( + model, + PreprocessingConfig( + region_tile_multiple=6, + ), + ) + + assert resolved.target_spacing_um == pytest.approx(0.5) + assert resolved.target_tile_size_px == 448 + assert resolved.target_region_size_px == 2688 + + +def test_resolve_direct_api_preprocessing_rejects_mismatched_region_size_and_multiple(): + import slide2vec.api as api + + model = Model.from_preset("uni") + + with pytest.raises(ValueError, match="target_region_size_px"): + api._resolve_direct_api_preprocessing( + model, + PreprocessingConfig( + target_spacing_um=0.5, + target_tile_size_px=224, + target_region_size_px=1024, + region_tile_multiple=6, + ), + ) + def test_pipeline_run_delegates_to_internal_runner(monkeypatch, tmp_path: Path): model = Model.from_preset("virchow2") preprocessing = DEFAULT_PREPROCESSING diff --git a/tests/test_regression_inference.py b/tests/test_regression_inference.py index d494b31..82a04bc 100644 --- a/tests/test_regression_inference.py +++ b/tests/test_regression_inference.py @@ -81,7 +81,7 @@ def test_pipeline_run_uses_distributed_embedding_path_when_num_gpus_is_greater_t monkeypatch.setattr( inference, "_collect_pipeline_artifacts", - lambda *args, **kwargs: (["tile-artifact"], ["slide-artifact"]), + lambda *args, **kwargs: (["tile-artifact"], [], ["slide-artifact"]), ) monkeypatch.setattr(inference, "_update_process_list_after_embedding", lambda *args, **kwargs: None) @@ -119,7 +119,7 @@ def fake_collect(*, model, successful_slides, process_list_path, preprocessing, captured["preprocessing"] = preprocessing captured["execution"] = execution captured["output_dir"] = output_dir - return ["tile-artifact"], ["slide-artifact"] + return ["tile-artifact"], [], ["slide-artifact"] monkeypatch.setattr(inference, "_collect_distributed_pipeline_artifacts", fake_collect) @@ -157,31 +157,36 @@ def fake_run_stage(*, model, successful_slides, preprocessing, execution, output "output_dir": output_dir, } - def fake_collect(slides, *, output_dir, output_format, include_tile_embeddings, include_slide_embeddings): + def fake_collect(slides, *, output_dir, output_format, include_tile_embeddings, include_hierarchical_embeddings, include_slide_embeddings): captured["collect"] = { "slides": slides, "output_dir": output_dir, "output_format": output_format, "include_tile_embeddings": include_tile_embeddings, + "include_hierarchical_embeddings": include_hierarchical_embeddings, "include_slide_embeddings": include_slide_embeddings, } - return ["tile-artifact"], ["slide-artifact"] + return ["tile-artifact"], [], ["slide-artifact"] def fake_update( process_list_path_arg, *, successful_slides, persist_tile_embeddings, + persist_hierarchical_embeddings, include_slide_embeddings, tile_artifacts, + hierarchical_artifacts, slide_artifacts, ): captured["update"] = { "process_list_path": process_list_path_arg, "successful_slides": successful_slides, "persist_tile_embeddings": persist_tile_embeddings, + "persist_hierarchical_embeddings": persist_hierarchical_embeddings, "include_slide_embeddings": include_slide_embeddings, "tile_artifacts": tile_artifacts, + "hierarchical_artifacts": hierarchical_artifacts, "slide_artifacts": slide_artifacts, } @@ -189,7 +194,7 @@ def fake_update( monkeypatch.setattr(inference, "_collect_pipeline_artifacts", fake_collect) monkeypatch.setattr(inference, "_update_process_list_after_embedding", fake_update) - tile_artifacts, slide_artifacts = inference._collect_distributed_pipeline_artifacts( + tile_artifacts, hierarchical_artifacts, slide_artifacts = inference._collect_distributed_pipeline_artifacts( model=model, successful_slides=[slide], process_list_path=process_list_path, @@ -206,16 +211,20 @@ def fake_update( assert captured["collect"]["output_dir"] == tmp_path assert captured["collect"]["output_format"] == "npz" assert captured["collect"]["include_tile_embeddings"] is True + assert captured["collect"]["include_hierarchical_embeddings"] is False assert captured["collect"]["include_slide_embeddings"] is True assert captured["update"]["process_list_path"] == process_list_path assert captured["update"]["successful_slides"] == [slide] assert captured["update"]["persist_tile_embeddings"] is True + assert captured["update"]["persist_hierarchical_embeddings"] is False assert captured["update"]["include_slide_embeddings"] is True assert captured["update"]["tile_artifacts"] == ["tile-artifact"] + assert captured["update"]["hierarchical_artifacts"] == [] assert captured["update"]["slide_artifacts"] == ["slide-artifact"] assert tile_artifacts == ["tile-artifact"] + assert hierarchical_artifacts == [] assert slide_artifacts == ["slide-artifact"] @@ -290,7 +299,7 @@ def fake_compute_embedded_slides(model, slide_records, tiling_results, *, prepro return [embedded_full] monkeypatch.setattr(inference, "_compute_embedded_slides", fake_compute_embedded_slides) - monkeypatch.setattr(inference, "_collect_pipeline_artifacts", lambda *args, **kwargs: (["tile-artifact"], ["slide-artifact"])) + monkeypatch.setattr(inference, "_collect_pipeline_artifacts", lambda *args, **kwargs: (["tile-artifact"], [], ["slide-artifact"])) monkeypatch.setattr(inference, "_update_process_list_after_embedding", lambda *args, **kwargs: None) model = SimpleNamespace( @@ -375,7 +384,7 @@ def test_collect_local_pipeline_artifacts_filters_none_artifacts(monkeypatch): ] monkeypatch.setattr(inference, "_persist_embedded_slide", lambda *args, **kwargs: responses.pop(0)) - tile_artifacts, slide_artifacts = inference._collect_local_pipeline_artifacts( + tile_artifacts, hierarchical_artifacts, slide_artifacts = inference._collect_local_pipeline_artifacts( model=SimpleNamespace(), embedded_slides=embedded_slides, tiling_results=tiling_results, @@ -384,6 +393,7 @@ def test_collect_local_pipeline_artifacts_filters_none_artifacts(monkeypatch): ) assert tile_artifacts == ["tile-a"] + assert hierarchical_artifacts == [] assert slide_artifacts == ["slide-a", "slide-b"] @@ -456,7 +466,7 @@ def fake_build_callback(*, model, preprocessing, execution, process_list_path): monkeypatch.setattr( inference, "_collect_pipeline_artifacts", - lambda *args, **kwargs: (["tile-artifact"], ["slide-artifact"]), + lambda *args, **kwargs: (["tile-artifact"], [], ["slide-artifact"]), ) monkeypatch.setattr(inference, "_update_process_list_after_embedding", lambda *args, **kwargs: None) @@ -1617,119 +1627,6 @@ def encode_tiles(self, image): assert result.dtype == torch.float32 -def test_region_batch_preprocessor_resizes_whole_region_before_unfolding(): - import slide2vec.inference as inference - torch = pytest.importorskip("torch") - - loaded = inference.LoadedModel( - name="region-model", - level="region", - model=SimpleNamespace(tile_size=2), - transforms=SimpleNamespace(transforms=[]), - feature_dim=3, - device=torch.device("cpu"), - ) - tiling_result = SimpleNamespace( - requested_tile_size_px=4, - effective_tile_size_px=2, - ) - preprocess = inference._build_batch_preprocessor( - loaded, - SimpleNamespace(level="region"), - tiling_result, - ) - - batch = torch.full((1, 3, 2, 2), 255, dtype=torch.uint8) - processed = preprocess(batch) - - assert processed.shape == (1, 4, 3, 2, 2) - assert processed.dtype == torch.float32 - assert torch.allclose(processed, torch.ones_like(processed)) - - -def test_region_batch_preprocessor_unfolds_then_applies_tile_transforms(): - import slide2vec.inference as inference - torch = pytest.importorskip("torch") - - class Resize: - def __init__(self, size): - self.size = size - - loaded = inference.LoadedModel( - name="region-model", - level="region", - model=SimpleNamespace(tile_size=2), - transforms=SimpleNamespace(transforms=[Resize(1)]), - feature_dim=3, - device=torch.device("cpu"), - ) - tiling_result = SimpleNamespace( - requested_tile_size_px=4, - effective_tile_size_px=4, - ) - preprocess = inference._build_batch_preprocessor( - loaded, - SimpleNamespace(level="region"), - tiling_result, - ) - - quadrant_values = torch.tensor( - [ - [ - [0, 0, 85, 85], - [0, 0, 85, 85], - [170, 170, 255, 255], - [170, 170, 255, 255], - ] - ], - dtype=torch.uint8, - ) - batch = quadrant_values.unsqueeze(0).repeat(1, 3, 1, 1) - processed = preprocess(batch) - - expected = torch.tensor([0.0, 85.0 / 255.0, 170.0 / 255.0, 1.0], dtype=torch.float32) - - assert processed.shape == (1, 4, 3, 1, 1) - assert torch.allclose(processed[0, :, 0, 0, 0], expected, atol=1e-5) - assert torch.allclose(processed[0, :, 1, 0, 0], expected, atol=1e-5) - assert torch.allclose(processed[0, :, 2, 0, 0], expected, atol=1e-5) - - -def test_build_batch_transform_spec_supports_nested_region_unfolding_transform(): - import slide2vec.inference as inference - - class Compose: - def __init__(self, transforms): - self.transforms = transforms - - class RegionUnfolding: - def __init__(self, tile_size): - self.tile_size = tile_size - - class Normalize: - def __init__(self, mean, std): - self.mean = mean - self.std = std - - transforms = Compose( - [ - Compose( - [ - RegionUnfolding(8), - Normalize((0.5, 0.4, 0.3), (0.2, 0.3, 0.4)), - ] - ) - ] - ) - - spec = inference._build_batch_transform_spec(transforms) - - assert spec is not None - assert spec.region_unfold_tile_size == 8 - assert spec.mean == (0.5, 0.4, 0.3) - assert spec.std == (0.2, 0.3, 0.4) - - def test_build_batch_preprocessor_falls_back_for_unsupported_transform_stack(caplog): import slide2vec.inference as inference torch = pytest.importorskip("torch") @@ -1750,7 +1647,6 @@ class UnsupportedTransform: with caplog.at_level("WARNING", logger="slide2vec.inference"): preprocess = inference._build_batch_preprocessor( loaded, - SimpleNamespace(level="tile"), tiling_result, ) @@ -1811,78 +1707,6 @@ def encode_tiles(self, image): assert torch.allclose(result, torch.ones((2, 3), dtype=torch.float32)) -def test_region_batch_preprocessor_uses_region_unfolding_from_transform_stack(): - import slide2vec.inference as inference - torch = pytest.importorskip("torch") - - class Compose: - def __init__(self, transforms): - self.transforms = transforms - - class RegionUnfolding: - def __init__(self, tile_size): - self.tile_size = tile_size - - loaded = inference.LoadedModel( - name="region-model", - level="region", - model=SimpleNamespace(tile_size=4), - transforms=Compose([RegionUnfolding(4)]), - feature_dim=3, - device=torch.device("cpu"), - ) - tiling_result = SimpleNamespace( - requested_tile_size_px=8, - effective_tile_size_px=8, - ) - - preprocess = inference._build_batch_preprocessor( - loaded, - SimpleNamespace(level="region"), - tiling_result, - ) - - batch = torch.ones((1, 3, 8, 8), dtype=torch.uint8) - processed = preprocess(batch) - - assert processed.shape == (1, 4, 3, 4, 4) - - -def test_region_batch_preprocessor_rejects_mismatched_region_unfolding_tile_size(): - import slide2vec.inference as inference - torch = pytest.importorskip("torch") - - class Compose: - def __init__(self, transforms): - self.transforms = transforms - - class RegionUnfolding: - def __init__(self, tile_size): - self.tile_size = tile_size - - loaded = inference.LoadedModel( - name="region-model", - level="region", - model=SimpleNamespace(tile_size=2), - transforms=Compose([RegionUnfolding(4)]), - feature_dim=3, - device=torch.device("cpu"), - ) - tiling_result = SimpleNamespace( - requested_tile_size_px=8, - effective_tile_size_px=8, - ) - - preprocess = inference._build_batch_preprocessor( - loaded, - SimpleNamespace(level="region"), - tiling_result, - ) - - with pytest.raises(ValueError, match="tile_size"): - preprocess(torch.ones((1, 3, 8, 8), dtype=torch.uint8)) - - def test_serialize_execution_preserves_loader_optimization_fields(): import slide2vec.inference as inference @@ -2396,89 +2220,190 @@ def test_compute_tile_embeddings_for_slide_requires_current_run_tile_store_witho ) -def test_compute_tile_embeddings_for_slide_uses_batched_loader_for_region_models(monkeypatch): +def test_build_hierarchical_index_is_region_major_and_row_major_within_region(): + import slide2vec.inference as inference + + tiling_result = SimpleNamespace( + x=np.array([100, 1000], dtype=np.int64), + y=np.array([200, 1200], dtype=np.int64), + tile_size_lv0=672, + effective_region_size_px=672, + target_region_size_px=672, + effective_tile_size_px=224, + target_tile_size_px=224, + ) + + index = inference._build_hierarchical_index( + tiling_result, + region_tile_multiple=3, + ) + + np.testing.assert_array_equal(index.flat_index, np.arange(18, dtype=np.int64)) + np.testing.assert_array_equal( + index.region_index, + np.array([0] * 9 + [1] * 9, dtype=np.int32), + ) + np.testing.assert_array_equal( + index.subtile_index_within_region, + np.array(list(range(9)) * 2, dtype=np.int32), + ) + np.testing.assert_array_equal( + index.subtile_x[:9], + np.array([100, 324, 548, 100, 324, 548, 100, 324, 548], dtype=np.int64), + ) + np.testing.assert_array_equal( + index.subtile_y[:9], + np.array([200, 200, 200, 424, 424, 424, 648, 648, 648], dtype=np.int64), + ) + + +def test_merge_hierarchical_embedding_shards_restores_original_region_shape(): + import slide2vec.inference as inference + + merged = inference._merge_hierarchical_embedding_shards( + [ + { + "flat_index": np.array([2, 0, 7], dtype=np.int64), + "tile_embeddings": np.array([[20.0, 21.0], [0.0, 1.0], [70.0, 71.0]], dtype=np.float32), + }, + { + "flat_index": np.array([6, 3, 1, 5, 4], dtype=np.int64), + "tile_embeddings": np.array( + [[60.0, 61.0], [30.0, 31.0], [10.0, 11.0], [50.0, 51.0], [40.0, 41.0]], + dtype=np.float32, + ), + }, + ], + num_regions=2, + tiles_per_region=4, + ) + + np.testing.assert_array_equal( + merged, + np.array( + [ + [[0.0, 1.0], [10.0, 11.0], [20.0, 21.0], [30.0, 31.0]], + [[40.0, 41.0], [50.0, 51.0], [60.0, 61.0], [70.0, 71.0]], + ], + dtype=np.float32, + ), + ) + + +def test_compute_hierarchical_embeddings_for_slide_encodes_flat_tile_batches_and_reshapes(monkeypatch): import slide2vec.inference as inference torch = pytest.importorskip("torch") captured = {} + class DummyDataset: + def __init__(self, flat_indices): + self._flat_indices = np.asarray(flat_indices, dtype=np.int64) + + def __len__(self): + return int(self._flat_indices.shape[0]) + + def __getitem__(self, idx): + return int(self._flat_indices[idx]) + class DummyLoader: def __init__(self, dataset, **kwargs): - captured["dataset"] = dataset - captured["kwargs"] = kwargs + captured["loader_kwargs"] = kwargs + self._batches = [ + ( + torch.tensor([0, 3, 4, 7], dtype=torch.long), + torch.tensor( + [ + [[[0, 0], [0, 0]]] * 3, + [[[3, 3], [3, 3]]] * 3, + [[[4, 4], [4, 4]]] * 3, + [[[7, 7], [7, 7]]] * 3, + ], + dtype=torch.uint8, + ), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ), + ( + torch.tensor([1, 2, 5, 6], dtype=torch.long), + torch.tensor( + [ + [[[1, 1], [1, 1]]] * 3, + [[[2, 2], [2, 2]]] * 3, + [[[5, 5], [5, 5]]] * 3, + [[[6, 6], [6, 6]]] * 3, + ], + dtype=torch.uint8, + ), + {"worker_batch_ms": 0.0, "reader_open_ms": 0.0, "reader_read_ms": 0.0}, + ), + ] def __iter__(self): - yield ( - torch.tensor([0, 1], dtype=torch.long), - torch.full((2, 3, 4, 4), 255, dtype=torch.uint8), - ) + return iter(self._batches) def __len__(self): - return 1 - - class DummyRegionModel: - tile_size = 2 + return len(self._batches) + class DummyTileModel: def encode_tiles(self, image): - assert image.ndim == 5 - assert image.shape[1:] == (4, 3, 2, 2) - return torch.ones((image.shape[0], image.shape[1], 3), dtype=torch.float32, device=image.device) + assert image.ndim == 4 + values = image[:, 0, 0, 0].to(torch.float32) + return torch.stack((values, values + 100.0), dim=1) - monkeypatch.setattr(inference, "BatchTileCollator", lambda **kwargs: ("collator", kwargs)) - monkeypatch.setattr(inference, "TileIndexDataset", lambda tile_indices: list(tile_indices)) - monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) + class DummyCollator: + def __init__(self, **kwargs): + captured["collator_kwargs"] = kwargs + + def build_batch_sampler(self, *, batch_size, dataset_indices): + return None - class Normalize: - def __init__(self, mean, std): - self.mean = mean - self.std = std + monkeypatch.setattr(inference, "TileIndexDataset", DummyDataset) + monkeypatch.setattr(inference, "OnTheFlyHierarchicalBatchCollator", DummyCollator) + monkeypatch.setattr(torch.utils.data, "DataLoader", DummyLoader) loaded = inference.LoadedModel( - name="region-model", - level="region", - model=DummyRegionModel(), - transforms=SimpleNamespace(transforms=[Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]), - feature_dim=3, + name="uni", + level="tile", + model=DummyTileModel(), + transforms=SimpleNamespace(transforms=[]), + feature_dim=2, device=torch.device("cpu"), ) - slide = make_slide("slide-a") + slide = make_slide("slide-h") tiling_result = SimpleNamespace( - x=np.array([0, 10]), - y=np.array([5, 15]), + x=np.array([0, 100], dtype=np.int64), + y=np.array([0, 100], dtype=np.int64), + requested_tile_size_px=224, + effective_tile_size_px=224, + target_tile_size_px=224, + target_region_size_px=448, + effective_region_size_px=448, + tile_size_lv0=448, target_spacing_um=0.5, - requested_tile_size_px=4, - read_spacing_um=0.5, - effective_tile_size_px=4, - tile_size_lv0=224, - tiles_tar_path=Path("/tmp/slide-a.tiles.tar"), - ) - execution = ExecutionOptions( - batch_size=2, - num_workers=3, - num_gpus=1, - prefetch_factor=9, - persistent_workers=True, + effective_spacing_um=0.5, + read_level=0, ) - result = inference._compute_tile_embeddings_for_slide( + result = inference._compute_hierarchical_embeddings_for_slide( loaded, - SimpleNamespace(level="region"), slide, tiling_result, - preprocessing=replace(DEFAULT_PREPROCESSING, on_the_fly=False), - execution=execution, + preprocessing=replace(DEFAULT_PREPROCESSING, region_tile_multiple=2, target_region_size_px=448), + execution=ExecutionOptions(batch_size=4, num_workers=0, num_gpus=1), ) - assert result.shape == (2, 4, 3) - assert captured["kwargs"]["persistent_workers"] is True - assert captured["kwargs"]["prefetch_factor"] == 9 - assert captured["kwargs"]["collate_fn"] == ( - "collator", - { - "tar_path": Path("/tmp/slide-a.tiles.tar"), - "tiling_result": tiling_result, - }, + assert result.shape == (2, 4, 2) + np.testing.assert_array_equal( + result.numpy(), + np.array( + [ + [[0.0 / 255.0, 100.0], [1.0 / 255.0, 100.0 + 1.0 / 255.0], [2.0 / 255.0, 100.0 + 2.0 / 255.0], [3.0 / 255.0, 100.0 + 3.0 / 255.0]], + [[4.0 / 255.0, 100.0 + 4.0 / 255.0], [5.0 / 255.0, 100.0 + 5.0 / 255.0], [6.0 / 255.0, 100.0 + 6.0 / 255.0], [7.0 / 255.0, 100.0 + 7.0 / 255.0]], + ], + dtype=np.float32, + ), ) + assert "collator_kwargs" in captured def test_scale_coordinates_scales_down():