From 3217285aec811417b634359b11087066dd7752b7 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 17:13:58 +0200 Subject: [PATCH 01/20] Add fields for 3d and qv filtering --- src/segger/io/fields.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/segger/io/fields.py b/src/segger/io/fields.py index 40bd6be..ff1ed35 100644 --- a/src/segger/io/fields.py +++ b/src/segger/io/fields.py @@ -8,6 +8,7 @@ class XeniumTranscriptFields: filename: str = 'transcripts.parquet' x: str = 'x_location' y: str = 'y_location' + z: str = 'z_location' feature: str = 'feature_name' cell_id: str = 'cell_id' null_cell_id: str = 'UNASSIGNED' @@ -38,6 +39,7 @@ class MerscopeTranscriptFields: filename: str = 'detected_transcripts.csv' x: str = 'global_x' y: str = 'global_y' + z: str = 'global_z' feature: str = 'gene' cell_id: str = 'cell_id' @@ -54,6 +56,7 @@ class CosMxTranscriptFields: filename: str = '*_tx_file.csv' x: str = 'x_global_px' y: str = 'y_global_px' + z: str = 'z' feature: str = 'target' cell_id: str = 'cell' compartment: str = 'CellComp' @@ -87,8 +90,10 @@ class StandardTranscriptFields: row_index: str = 'row_index' x: str = 'x' y: str = 'y' + z: str = 'z' feature: str = 'feature_name' cell_id: str = 'cell_id' + quality: str = 'qv' compartment: str = 'cell_compartment' extracellular_value: int = 0 cytoplasmic_value: int = 1 From ebf5105afe914d0a8712c3558365d396927f411e Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 17:16:17 +0200 Subject: [PATCH 02/20] Adapt for 3d and qv filtering, add column remapping for robustness --- src/segger/io/preprocessor.py | 84 ++++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 20 deletions(-) diff --git a/src/segger/io/preprocessor.py b/src/segger/io/preprocessor.py index 597a818..23634ed 100644 --- a/src/segger/io/preprocessor.py +++ b/src/segger/io/preprocessor.py @@ -2,7 +2,7 @@ from functools import cached_property from abc import ABC, abstractmethod from anndata import AnnData -from typing import Literal +from typing import Literal, Optional from pathlib import Path import geopandas as gpd import polars as pl @@ -34,6 +34,14 @@ # Register of available ISTPreprocessor subclasses keyed by platform name. PREPROCESSORS = {} + +def _lazyframe_column_names(lf: pl.LazyFrame) -> list[str]: + """Return column names for a LazyFrame across Polars versions.""" + try: + return lf.collect_schema().names() + except AttributeError: + return lf.columns + def register_preprocessor(name): """ Decorator to register a preprocessor class under a given platform name. @@ -60,7 +68,14 @@ class ISTPreprocessor(ABC): transcript and boundary GeoDataFrames for the given platform. 
""" - def __init__(self, data_dir: Path): + DEFAULT_MIN_QV: Optional[float] = None + + def __init__( + self, + data_dir: Path, + min_qv: Optional[float] = None, + include_z: bool = True, + ): """ Parameters ---------- @@ -70,6 +85,8 @@ def __init__(self, data_dir: Path): data_dir = Path(data_dir) type(self)._validate_directory(data_dir) self.data_dir = data_dir + self.min_qv = self.DEFAULT_MIN_QV if min_qv is None else min_qv + self.include_z = include_z @staticmethod @abstractmethod @@ -280,7 +297,7 @@ def transcripts(self) -> pl.DataFrame: raw = CosMxTranscriptFields() std = StandardTranscriptFields() - return ( + lf = ( # Read in lazily pl.scan_csv(next(self.data_dir.glob(raw.filename))) .with_row_index(name=std.row_index) @@ -310,13 +327,22 @@ def transcripts(self) -> pl.DataFrame: .otherwise(None) .alias(std.cell_id) ) - # Map to standard field names - .rename({raw.x: std.x, raw.y: std.y, raw.feature: std.feature}) - - # Subset to necessary fields - .select([std.row_index, std.x, std.y, std.feature, std.cell_id, - std.compartment]) + ) + rename_map = {raw.x: std.x, raw.y: std.y, raw.feature: std.feature} + select_cols = [std.row_index, std.x, std.y, std.feature, std.cell_id, std.compartment] + if self.include_z: + schema_names = _lazyframe_column_names(lf) + if raw.z in schema_names: + rename_map[raw.z] = std.z + select_cols.append(std.z) + + return ( + lf + # Map to standard field names + .rename(rename_map) + # Subset to necessary fields + .select(select_cols) # Add numeric index .with_row_index() .collect() @@ -372,6 +398,8 @@ class XeniumPreprocessor(ISTPreprocessor): """ Preprocessor for 10x Genomics Xenium datasets. """ + DEFAULT_MIN_QV: float = 20.0 + @staticmethod def _validate_directory(data_dir: Path): @@ -397,7 +425,7 @@ def transcripts(self) -> pl.DataFrame: raw = XeniumTranscriptFields() std = StandardTranscriptFields() - return ( + lf = ( # Read in lazily pl.scan_parquet( self.data_dir / raw.filename, @@ -405,8 +433,12 @@ def transcripts(self) -> pl.DataFrame: ) # Add numeric index at beginning .with_row_index(name=std.row_index) - # Filter data - .filter(pl.col(raw.quality) >= 20) + ) + if self.min_qv is not None and self.min_qv > 0: + lf = lf.filter(pl.col(raw.quality) >= self.min_qv) + + lf = ( + lf .filter(pl.col(raw.feature).str.contains( '|'.join(raw.filter_substrings)).not_() ) @@ -415,7 +447,7 @@ def transcripts(self) -> pl.DataFrame: pl.when(pl.col(raw.compartment) == raw.nucleus_value) .then(std.nucleus_value) .when( - (pl.col(raw.compartment) != raw.nucleus_value) & + (pl.col(raw.compartment) != raw.nucleus_value) & (pl.col(raw.cell_id) != raw.null_cell_id) ) .then(std.cytoplasmic_value) @@ -428,12 +460,22 @@ def transcripts(self) -> pl.DataFrame: .replace(raw.null_cell_id, None) .alias(std.cell_id) ) + ) + + rename_map = {raw.x: std.x, raw.y: std.y, raw.feature: std.feature} + select_cols = [std.row_index, std.x, std.y, std.feature, std.cell_id, std.compartment] + if self.include_z: + schema_names = _lazyframe_column_names(lf) + if raw.z in schema_names: + rename_map[raw.z] = std.z + select_cols.append(std.z) + + return ( + lf # Map to standard field names - .rename({raw.x: std.x, raw.y: std.y, raw.feature: std.feature}) - - # Subset to necessary fields - .select([std.row_index, std.x, std.y, std.feature, std.cell_id, - std.compartment]) + .rename(rename_map) + # Subset to necessary fields + .select(select_cols) .collect() ) @@ -540,7 +582,9 @@ def _infer_platform(data_dir: Path) -> str: def get_preprocessor( data_dir: Path, - platform: str | None = None + 
platform: str | None = None, + min_qv: Optional[float] = None, + include_z: bool = True, ) -> ISTPreprocessor: data_dir = Path(data_dir) if platform is None: @@ -551,4 +595,4 @@ def get_preprocessor( f"Available: {list(PREPROCESSORS)}" ) cls = PREPROCESSORS[platform.lower()] - return cls(data_dir) + return cls(data_dir, min_qv=min_qv, include_z=include_z) From f99b3fb7f5067c13ee3c7eef97225a788dd21c06 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 17:18:39 +0200 Subject: [PATCH 03/20] Add arguments for 3d, qv filtering; pass spatialdata save flag to with ISTSegmentationWriter --- src/segger/cli/segment.py | 48 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/segger/cli/segment.py b/src/segger/cli/segment.py index dff7e05..a828ef8 100644 --- a/src/segger/cli/segment.py +++ b/src/segger/cli/segment.py @@ -58,6 +58,21 @@ help="Related to loss function parameters.", sort_key=7, ) +group_quality = Group( + name="Quality Filtering", + help="Related to transcript quality filtering.", + sort_key=8, +) +group_3d = Group( + name="3D Support", + help="Related to 3D coordinate handling.", + sort_key=9, +) + +def _resolve_use_3d_flag(use_3d: Literal["auto", "true", "false"]) -> bool | str: + if use_3d == "auto": + return "auto" + return use_3d == "true" app_segment = App(name="segment", help="Run cell segmentation on spatial transcriptomics data.") @@ -293,16 +308,46 @@ def segment( "save_anndata", group=group_io, )] = registry.get_default("save_anndata"), + + save_spatialdata: Annotated[bool, registry.get_parameter( + "save_spatialdata", + group=group_io, + )] = registry.get_default("save_spatialdata"), debug: Annotated[bool, Parameter( help="Whether to save additional debug information (trainer, predictions).", )] = "none", + + # Quality filtering + min_qv: Annotated[float | None, Parameter( + help="Minimum transcript quality threshold. 
Set to 0 to disable.", + validator=validators.Number(gte=0), + group=group_quality, + )] = 20.0, + + # 3D support + use_3d: Annotated[ + Literal["auto", "true", "false"], + Parameter( + help="Use 3D coordinates for graph construction ('false' default).", + group=group_3d, + ), + ] = "false", ): """Run cell segmentation on spatial transcriptomics data.""" # Setup logger and debug directory logger = logging.getLogger(__name__) + use_3d_value = _resolve_use_3d_flag(use_3d) + + output_directory = Path(output_directory) + if output_directory.exists() and not output_directory.is_dir(): + raise ValueError( + f"Output path exists and is not a directory: {output_directory}" + ) + output_directory.mkdir(parents=True, exist_ok=True) + # Remove SLURM environment autodetect from lightning.pytorch.plugins.environments import SLURMEnvironment SLURMEnvironment.detect = lambda: False @@ -328,6 +373,8 @@ def segment( tiling_margin_prediction=tiling_margin_prediction, tiling_nodes_per_tile=max_nodes_per_tile, edges_per_batch=max_edges_per_batch, + use_3d=use_3d_value, + min_qv=min_qv, ) # Setup Lightning Model @@ -366,6 +413,7 @@ def segment( writer = ISTSegmentationWriter( output_directory, save_anndata=save_anndata, + save_spatialdata=save_spatialdata, debug=debug, ) trainer = Trainer( From 6e9a2642738efa9e7f43e587ae00a3cdf7c7f7ea Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 17:28:02 +0200 Subject: [PATCH 04/20] Support 3d input and qv filtering, initialize spatialdata loader, other changes from v2-incremental --- src/segger/data/data_module.py | 70 ++++++++++++++++++++++++++--- src/segger/data/utils/heterodata.py | 3 ++ src/segger/data/utils/neighbors.py | 36 +++++++++++++-- 3 files changed, 99 insertions(+), 10 deletions(-) diff --git a/src/segger/data/data_module.py b/src/segger/data/data_module.py index efcdde2..650fffd 100644 --- a/src/segger/data/data_module.py +++ b/src/segger/data/data_module.py @@ -5,11 +5,12 @@ from lightning.pytorch import LightningDataModule from torchvision.transforms import Compose from dataclasses import dataclass -from typing import Literal +from typing import Literal, Optional from pathlib import Path import polars as pl import torch import gc +import os import numpy as np from .tile_dataset import ( @@ -143,6 +144,8 @@ class ISTDataModule(LightningDataModule): prediction_graph_mode: Literal["nucleus", "cell", "uniform"] = "cell" prediction_graph_max_k: int = 3 prediction_graph_buffer_ratio: float = 0.05 + use_3d: bool | Literal["auto"] = False + min_qv: Optional[float] = 20.0 tiling_mode: Literal["adaptive", "square"] = "adaptive" # TODO: Remove (benchmarking only) tiling_margin_training: float = 20. tiling_margin_prediction: float = 20. 
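A minimal usage sketch for the two ISTDataModule fields added in the hunk above (min_qv and use_3d), assuming input_directory is the only other required argument; the path value and the standalone load() call are illustrative and not part of the patch:

from segger.data.data_module import ISTDataModule

# Hypothetical input path; either a raw platform directory or a SpatialData .zarr store.
dm = ISTDataModule(
    input_directory="path/to/sample_dir_or_sdata.zarr",
    min_qv=20.0,    # drop transcripts below this quality score (None or 0 disables filtering)
    use_3d="auto",  # build kNN graphs in 3D only when a usable z column is present
)
dm.load()           # dispatches to the SpatialData loader or get_preprocessor() accordingly
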
@@ -166,11 +169,54 @@ def load(self): tx_fields = StandardTranscriptFields() bd_fields = StandardBoundaryFields() - # Load standardized IST data self.logger.debug(f"Loading standardized IST data from {self.input_directory}...") - pp = get_preprocessor(self.input_directory) - tx = self.tx = pp.transcripts - bd = self.bd = pp.boundaries + # Load standardized IST data (raw platform directory or SpatialData .zarr) + input_path = Path(self.input_directory) + tx = None + bd = None + + try: + from ..io.spatialdata_loader import ( + is_spatialdata_path, + load_from_spatialdata, + ) + has_spatialdata_loader = True + except Exception: + has_spatialdata_loader = False + + if has_spatialdata_loader and is_spatialdata_path(input_path): + tx_lf, bd = load_from_spatialdata( + input_path, + boundary_type="all", + normalize=True, + ) + tx = tx_lf.collect() if isinstance(tx_lf, pl.LazyFrame) else tx_lf + + # Keep behavior consistent with raw Xenium filtering when quality exists. + quality_col = getattr(tx_fields, "quality", "qv") + if ( + self.min_qv is not None + and self.min_qv > 0 + and quality_col in tx.columns + ): + tx = tx.filter(pl.col(quality_col) >= self.min_qv) + else: + pp = get_preprocessor( + self.input_directory, + min_qv=self.min_qv, + include_z=(self.use_3d is not False), + ) + tx = pp.transcripts + bd = pp.boundaries + + self.tx = tx + self.bd = bd + + if bd is None or len(bd) == 0: + raise ValueError( + "No boundary shapes found in input data. " + "Segger requires cell/nucleus polygons in raw input or SpatialData shapes." + ) # Mask transcripts to reference segmentation if self.segmentation_graph_mode == "nucleus": @@ -187,8 +233,16 @@ def load(self): f"Unrecognized segmentation graph mode: " f"'{self.segmentation_graph_mode}'." ) - tx_mask = pl.col(tx_fields.compartment).is_in(compartments) - bd_mask = bd[bd_fields.boundary_type] == boundary_type + + if tx_fields.compartment in tx.columns: + tx_mask = pl.col(tx_fields.compartment).is_in(compartments) + else: + tx_mask = pl.col(tx_fields.cell_id).is_not_null() + + if bd_fields.boundary_type in bd.columns: + bd_mask = bd[bd_fields.boundary_type] == boundary_type + else: + bd_mask = np.ones(len(bd), dtype=bool) # Generate reference AnnData self.logger.debug("Generating reference AnnData object...") @@ -222,6 +276,8 @@ def load(self): prediction_graph_mode=self.prediction_graph_mode, prediction_graph_max_k=self.prediction_graph_max_k, prediction_graph_buffer_ratio=self.prediction_graph_buffer_ratio, + use_3d=self.use_3d, + me_gene_pairs=self.me_gene_pairs, ) # Tile graph dataset diff --git a/src/segger/data/utils/heterodata.py b/src/segger/data/utils/heterodata.py index 40f43d9..fd3bea1 100644 --- a/src/segger/data/utils/heterodata.py +++ b/src/segger/data/utils/heterodata.py @@ -24,6 +24,7 @@ def setup_heterodata( prediction_graph_mode: Literal["nucleus", "cell", "uniform"], prediction_graph_max_k: int, prediction_graph_buffer_ratio: float, + use_3d: bool | Literal["auto"] = False, cells_embedding_key: str = 'X_pca', cells_clusters_column: str = 'phenograph_cluster', cells_encoding_column: str = 'cell_encoding', @@ -135,6 +136,7 @@ def setup_heterodata( transcripts, max_k=transcripts_graph_max_k, max_dist=transcripts_graph_max_dist, + use_3d=use_3d, ) # Reference segmentation graph @@ -150,6 +152,7 @@ def setup_heterodata( max_k=prediction_graph_max_k, buffer_ratio=prediction_graph_buffer_ratio, mode=prediction_graph_mode, + use_3d=use_3d if prediction_graph_mode == "uniform" else False, ) return data diff --git 
a/src/segger/data/utils/neighbors.py b/src/segger/data/utils/neighbors.py index ce7ab3e..3c9d654 100644 --- a/src/segger/data/utils/neighbors.py +++ b/src/segger/data/utils/neighbors.py @@ -122,7 +122,7 @@ def edge_index_to_knn( def kdtree_neighbors( points: np.ndarray, max_k: int, - max_dist: float, + max_dist: float = np.inf, query: np.ndarray | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Wrapper for KDTree kNN and conversion to edge_index COO format. @@ -148,11 +148,25 @@ def setup_transcripts_graph( tx: pl.DataFrame, max_k: int, max_dist: float, + use_3d: bool | Literal["auto"] = False, ) -> torch.Tensor: """TODO: Add description. """ tx_fields = TrainingTranscriptFields() - points = tx[[tx_fields.x, tx_fields.y]].to_numpy() + coord_cols = [tx_fields.x, tx_fields.y] + has_z = tx_fields.z in tx.columns + + if use_3d == "auto": + use_3d = has_z and tx[tx_fields.z].null_count() < len(tx) + elif use_3d is True and not has_z: + raise ValueError( + f"use_3d=True but z column '{tx_fields.z}' not found in transcripts. " + f"Available columns: {tx.columns}" + ) + if use_3d and has_z: + coord_cols.append(tx_fields.z) + + points = tx[coord_cols].to_numpy() edge_index, _ = kdtree_neighbors( points=points, max_k=max_k, @@ -184,6 +198,7 @@ def setup_prediction_graph( max_k: int, buffer_ratio: float, mode: Literal['nucleus', 'cell', 'uniform'] = 'cell', + use_3d: bool | Literal["auto"] = False, ) -> torch.Tensor: """TODO: Add description. """ @@ -192,12 +207,27 @@ def setup_prediction_graph( # Uniform kNN graph if mode == "uniform": - points = tx[[tx_fields.x, tx_fields.y]].to_numpy() + coord_cols = [tx_fields.x, tx_fields.y] + has_z = tx_fields.z in tx.columns + if use_3d == "auto": + use_3d = has_z and tx[tx_fields.z].null_count() < len(tx) + elif use_3d is True and not has_z: + raise ValueError( + f"use_3d=True but z column '{tx_fields.z}' not found in transcripts. 
" + f"Available columns: {tx.columns}" + ) + if use_3d and has_z: + coord_cols.append(tx_fields.z) + + points = tx[coord_cols].to_numpy() query = bd.geometry.centroid.get_coordinates().values + if use_3d and len(coord_cols) == 3: + query = np.hstack([query, np.zeros((len(query), 1))]) edge_index, _ = kdtree_neighbors( points=points, query=query, max_k=max_k, + max_dist=np.inf, ) return edge_index From e0424e89e6868353bf534abd5cce487653541154 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 17:32:50 +0200 Subject: [PATCH 05/20] Add spatialdata loader module, expose it, add all optional dependencies frm v2-incremental --- src/segger/io/__init__.py | 88 ++++- src/segger/io/spatialdata_loader.py | 476 ++++++++++++++++++++++++++++ src/segger/utils/__init__.py | 52 +++ src/segger/utils/optional_deps.py | 402 +++++++++++++++++++++++ 4 files changed, 1011 insertions(+), 7 deletions(-) create mode 100644 src/segger/io/spatialdata_loader.py create mode 100644 src/segger/utils/__init__.py create mode 100644 src/segger/utils/optional_deps.py diff --git a/src/segger/io/__init__.py b/src/segger/io/__init__.py index 1f1ad20..a913449 100644 --- a/src/segger/io/__init__.py +++ b/src/segger/io/__init__.py @@ -1,7 +1,81 @@ -from .preprocessor import get_preprocessor -from .fields import ( - StandardBoundaryFields, - TrainingBoundaryFields, - StandardTranscriptFields, - TrainingTranscriptFields, -) \ No newline at end of file +"""Input/output modules for spatial transcriptomics data.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +import importlib + +__all__ = [ + # Preprocessors + "get_preprocessor", + # Fields + "StandardBoundaryFields", + "TrainingBoundaryFields", + "StandardTranscriptFields", + "TrainingTranscriptFields", + # SpatialData (optional) + "SpatialDataLoader", + "load_from_spatialdata", + "is_spatialdata_path", +] + +if TYPE_CHECKING: # pragma: no cover + from .fields import ( + StandardBoundaryFields, + TrainingBoundaryFields, + StandardTranscriptFields, + TrainingTranscriptFields, + ) + from .preprocessor import get_preprocessor + from .spatialdata_loader import ( + SpatialDataLoader, + load_from_spatialdata, + is_spatialdata_path, + ) + + +def __getattr__(name: str): + if name in { + "StandardBoundaryFields", + "TrainingBoundaryFields", + "StandardTranscriptFields", + "TrainingTranscriptFields", + }: + from .fields import ( + StandardBoundaryFields, + TrainingBoundaryFields, + StandardTranscriptFields, + TrainingTranscriptFields, + ) + return locals()[name] + + if name == "get_preprocessor": + from .preprocessor import get_preprocessor + return get_preprocessor + + if name in { + "SpatialDataLoader", + "load_from_spatialdata", + "is_spatialdata_path", + }: + try: + from .spatialdata_loader import ( + SpatialDataLoader, + load_from_spatialdata, + is_spatialdata_path, + ) + except Exception: + return None + return locals()[name] + + if name in { + "fields", + "preprocessor", + "spatialdata_loader", + }: + try: + return importlib.import_module(f"{__name__}.{name}") + except Exception as exc: + raise ImportError(f"Failed to import module '{name}'.") from exc + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/segger/io/spatialdata_loader.py b/src/segger/io/spatialdata_loader.py new file mode 100644 index 0000000..9d27f5c --- /dev/null +++ b/src/segger/io/spatialdata_loader.py @@ -0,0 +1,476 @@ +"""Load transcript and boundary data from SpatialData .zarr stores. 
+ +This loader normalizes heterogeneous SpatialData point/shape schemas to +Segger's internal fields so the same downstream data module can run on both +vendor raw inputs and SpatialData inputs. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Literal, Optional + +import geopandas as gpd +import pandas as pd +import polars as pl + +from segger.io.fields import StandardBoundaryFields, StandardTranscriptFields +from segger.utils.optional_deps import ( + SPATIALDATA_IO_AVAILABLE, + require_spatialdata, + warn_spatialdata_io_unavailable, +) + + +_COMMON_POINTS_KEYS = [ + "transcripts", + "molecules", + "points", + "spots", + "tx", +] + +_COMMON_CELL_SHAPES_KEYS = [ + "cells", + "cell_boundaries", + "cell_shapes", + "cell_polygons", + "boundaries", +] + +_COMMON_NUCLEUS_SHAPES_KEYS = [ + "nuclei", + "nucleus_boundaries", + "nucleus_shapes", + "nucleus_polygons", + "nuclei_boundaries", +] + + +def _lazyframe_column_names(lf: pl.LazyFrame) -> list[str]: + """Return column names for a LazyFrame across Polars versions.""" + try: + return lf.collect_schema().names() + except AttributeError: + return lf.columns + + +def _safe_to_geodataframe(data: object) -> gpd.GeoDataFrame: + """Best-effort conversion of a SpatialData shapes element to GeoDataFrame.""" + if isinstance(data, gpd.GeoDataFrame): + return data.copy() + if hasattr(data, "compute"): + data = data.compute() + if isinstance(data, gpd.GeoDataFrame): + return data.copy() + if hasattr(data, "to_geopandas"): + return data.to_geopandas().copy() + if hasattr(data, "to_pandas"): + df = data.to_pandas() + elif isinstance(data, pd.DataFrame): + df = data + else: + df = pd.DataFrame(data) + + if "geometry" in df.columns: + return gpd.GeoDataFrame(df, geometry="geometry") + + raise ValueError( + "Could not convert shapes element to GeoDataFrame: no geometry column found." 
+ ) + + +def _largest_polygon(geom): + """Convert MultiPolygon/GeometryCollection to a single polygon when possible.""" + if geom is None or geom.is_empty: + return geom + gtype = geom.geom_type + if gtype == "Polygon": + return geom + if gtype == "MultiPolygon": + parts = list(geom.geoms) + if not parts: + return geom + return max(parts, key=lambda p: p.area) + if gtype == "GeometryCollection": + parts = [g for g in geom.geoms if g.geom_type == "Polygon"] + if parts: + return max(parts, key=lambda p: p.area) + return geom + + +class SpatialDataLoader: + """Load and normalize points/shapes from a SpatialData Zarr store.""" + + def __init__( + self, + path: Path | str, + points_key: Optional[str] = None, + cell_shapes_key: Optional[str] = None, + nucleus_shapes_key: Optional[str] = None, + coordinate_system: str = "global", + ): + require_spatialdata() + if not SPATIALDATA_IO_AVAILABLE: + warn_spatialdata_io_unavailable( + "Platform-specific SpatialData readers (Xenium/MERSCOPE/CosMX)" + ) + + self._path = Path(path) + self._points_key = points_key + self._cell_shapes_key = cell_shapes_key + self._nucleus_shapes_key = nucleus_shapes_key + self._coordinate_system = coordinate_system + self._sdata = None + + if not self._path.exists(): + raise FileNotFoundError(f"SpatialData store not found: {self._path}") + + @property + def sdata(self): + if self._sdata is None: + spatialdata = require_spatialdata() + if hasattr(spatialdata, "read_zarr"): + self._sdata = spatialdata.read_zarr(str(self._path)) + else: + # Fallback for API variants + self._sdata = spatialdata.SpatialData.read(str(self._path)) + return self._sdata + + @property + def points_key(self) -> str: + if self._points_key is None: + self._points_key = self._detect_points_key() + return self._points_key + + @property + def cell_shapes_key(self) -> Optional[str]: + if self._cell_shapes_key is None: + self._cell_shapes_key = self._detect_shapes_key(_COMMON_CELL_SHAPES_KEYS) + return self._cell_shapes_key + + @property + def nucleus_shapes_key(self) -> Optional[str]: + if self._nucleus_shapes_key is None: + self._nucleus_shapes_key = self._detect_shapes_key(_COMMON_NUCLEUS_SHAPES_KEYS) + return self._nucleus_shapes_key + + def _detect_points_key(self) -> str: + available = list(self.sdata.points.keys()) + if not available: + raise ValueError( + f"No points elements found in SpatialData store: {self._path}" + ) + + for key in _COMMON_POINTS_KEYS: + if key in available: + return key + + lowered = {k.lower(): k for k in available} + for pattern in ("transcript", "molecule", "spot", "point"): + for lk, orig in lowered.items(): + if pattern in lk: + return orig + + return available[0] + + def _detect_shapes_key(self, preferred: list[str]) -> Optional[str]: + available = list(self.sdata.shapes.keys()) + if not available: + return None + + for key in preferred: + if key in available: + return key + + lowered = {k.lower(): k for k in available} + # Fuzzy fallback for newer naming conventions + for pattern in ("cell", "nucleus", "nuclei", "boundar", "polygon", "shape"): + for lk, orig in lowered.items(): + if pattern in lk: + return orig + + return available[0] + + @staticmethod + def _detect_column( + columns: set[str], + candidates: list[str], + optional: bool = False, + ) -> Optional[str]: + for candidate in candidates: + if candidate in columns: + return candidate + if optional: + return None + raise ValueError( + f"Could not detect required column. Tried {candidates}. 
" + f"Available columns: {sorted(columns)}" + ) + + def _points_to_pandas(self, points_obj) -> pd.DataFrame: + if hasattr(points_obj, "compute"): + points_obj = points_obj.compute() + if isinstance(points_obj, pd.DataFrame): + df = points_obj.copy() + elif hasattr(points_obj, "to_pandas"): + df = points_obj.to_pandas() + else: + df = pd.DataFrame(points_obj) + + # Recover coordinates from geometry when needed + if "geometry" in df.columns and ("x" not in df.columns or "y" not in df.columns): + geom = df["geometry"] + if len(geom) > 0: + try: + df = df.copy() + df["x"] = geom.x + df["y"] = geom.y + except Exception: + pass + + return df + + def transcripts( + self, + normalize: bool = True, + gene_column: Optional[str] = None, + quality_column: Optional[str] = None, + ) -> pl.LazyFrame: + """Load transcripts from SpatialData and normalize to standard fields.""" + std = StandardTranscriptFields() + points_obj = self.sdata.points[self.points_key] + df = self._points_to_pandas(points_obj) + + lf = pl.from_pandas(df).lazy().with_row_index(name=std.row_index) + if not normalize: + return lf + + columns = set(df.columns) + + x_col = self._detect_column(columns, ["x", "x_location", "global_x", "x_global_px"]) + y_col = self._detect_column(columns, ["y", "y_location", "global_y", "y_global_px"]) + z_col = self._detect_column( + columns, + ["z", "z_location", "global_z", "z_global_px"], + optional=True, + ) + + if gene_column is None: + gene_column = self._detect_column( + columns, + ["feature_name", "gene", "target", "gene_name", "feature"], + ) + + if quality_column is None: + quality_column = self._detect_column( + columns, + ["qv", "quality", "quality_score", "score"], + optional=True, + ) + + cell_id_col = self._detect_column( + columns, + ["cell_id", "cell", "segger_cell_id", "segmentation_cell_id", "instance_id"], + optional=True, + ) + + compartment_col = self._detect_column( + columns, + ["cell_compartment", "overlaps_nucleus", "compartment", "CellComp"], + optional=True, + ) + + rename_map = { + x_col: std.x, + y_col: std.y, + gene_column: std.feature, + } + if z_col: + rename_map[z_col] = std.z + if cell_id_col: + rename_map[cell_id_col] = std.cell_id + quality_field = getattr(std, "quality", None) + if quality_column and quality_field: + rename_map[quality_column] = quality_field + + lf = lf.rename({k: v for k, v in rename_map.items() if k != v}) + + # Normalize/derive compartment labels for segmentation masking. + if compartment_col: + # Handle common formats: bool overlaps_nucleus, numeric labels, strings. 
+ source_col = compartment_col + if source_col in rename_map: + source_col = rename_map[source_col] + source_dtype = lf.collect_schema().get(source_col) + + if source_dtype == pl.Boolean: + lf = lf.with_columns( + pl.when(pl.col(source_col)) + .then(std.nucleus_value) + .when(pl.col(std.cell_id).is_not_null()) + .then(std.cytoplasmic_value) + .otherwise(std.extracellular_value) + .alias(std.compartment) + ) + else: + as_str = pl.col(source_col).cast(pl.Utf8).str.to_lowercase() + lf = lf.with_columns( + pl.when( + as_str.is_in(["1", "true", "t", "nucleus", "nuclear"]) + ) + .then(std.nucleus_value) + .when( + as_str.is_in(["2", "cytoplasm", "cytoplasmic", "membrane"]) + ) + .then(std.cytoplasmic_value) + .when(pl.col(std.cell_id).is_not_null()) + .then(std.cytoplasmic_value) + .otherwise(std.extracellular_value) + .alias(std.compartment) + ) + else: + lf = lf.with_columns( + pl.when(pl.col(std.cell_id).is_not_null()) + .then(std.nucleus_value) + .otherwise(std.extracellular_value) + .alias(std.compartment) + ) + + select_cols = [std.row_index, std.x, std.y, std.feature, std.cell_id, std.compartment] + schema_names = _lazyframe_column_names(lf) + + if z_col and std.z in schema_names: + select_cols.append(std.z) + + quality_field = getattr(std, "quality", None) + if quality_field and quality_field in schema_names: + select_cols.append(quality_field) + + return lf.select(select_cols) + + def _normalize_boundary_ids(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + std = StandardBoundaryFields() + if std.id in gdf.columns: + return gdf + + for candidate in ( + "cell_id", + "cell", + "instance_id", + "segger_cell_id", + "id", + "label", + "EntityID", + ): + if candidate in gdf.columns: + gdf = gdf.copy() + gdf[std.id] = gdf[candidate] + return gdf + + gdf = gdf.reset_index(drop=False) + index_col = gdf.columns[0] + gdf[std.id] = gdf[index_col] + return gdf + + def _prepare_shapes( + self, + shape_key: str, + boundary_label: str, + ) -> gpd.GeoDataFrame: + std = StandardBoundaryFields() + raw = self.sdata.shapes[shape_key] + gdf = _safe_to_geodataframe(raw) + gdf = self._normalize_boundary_ids(gdf) + + gdf = gdf[gdf.geometry.notnull()].copy() + if not gdf.empty: + try: + gdf["geometry"] = gdf.geometry.buffer(0) + except Exception: + pass + gdf["geometry"] = gdf.geometry.apply(_largest_polygon) + gdf = gdf[gdf.geometry.notnull()].copy() + gdf = gdf[~gdf.geometry.is_empty].copy() + + gdf[std.boundary_type] = boundary_label + return gdf + + def boundaries( + self, + boundary_type: Literal["cell", "nucleus", "all"] = "all", + ) -> Optional[gpd.GeoDataFrame]: + """Load boundaries from SpatialData and normalize to standard fields.""" + std = StandardBoundaryFields() + + parts: list[gpd.GeoDataFrame] = [] + if boundary_type in {"cell", "all"} and self.cell_shapes_key is not None: + parts.append(self._prepare_shapes(self.cell_shapes_key, std.cell_value)) + + if boundary_type in {"nucleus", "all"} and self.nucleus_shapes_key is not None: + parts.append(self._prepare_shapes(self.nucleus_shapes_key, std.nucleus_value)) + + if not parts: + # Fallback: if no specific key detected but shapes exist, use first key as cell shapes. 
+ available = list(self.sdata.shapes.keys()) + if not available: + return None + parts.append(self._prepare_shapes(available[0], std.cell_value)) + + result = gpd.GeoDataFrame( + pd.concat(parts, ignore_index=True), + geometry="geometry", + crs=parts[0].crs if parts and hasattr(parts[0], "crs") else None, + ) + + # Compute contains_nucleus when possible + if std.contains_nucleus not in result.columns: + if std.boundary_type in result.columns: + nucleus_ids = set( + result.loc[ + result[std.boundary_type] == std.nucleus_value, + std.id, + ].astype(str) + ) + result[std.contains_nucleus] = result[std.id].astype(str).isin(nucleus_ids) + result.loc[ + result[std.boundary_type] == std.nucleus_value, + std.contains_nucleus, + ] = True + else: + result[std.contains_nucleus] = False + + return result + + +def load_from_spatialdata( + path: Path | str, + points_key: Optional[str] = None, + cell_shapes_key: Optional[str] = None, + nucleus_shapes_key: Optional[str] = None, + boundary_type: Literal["cell", "nucleus", "all"] = "all", + normalize: bool = True, +) -> tuple[pl.LazyFrame, Optional[gpd.GeoDataFrame]]: + """Convenience loader for SpatialData .zarr stores.""" + loader = SpatialDataLoader( + path=path, + points_key=points_key, + cell_shapes_key=cell_shapes_key, + nucleus_shapes_key=nucleus_shapes_key, + ) + tx = loader.transcripts(normalize=normalize) + bd = loader.boundaries(boundary_type=boundary_type) + return tx, bd + + +def is_spatialdata_path(path: Path | str) -> bool: + """Check whether a path looks like a SpatialData zarr store.""" + p = Path(path) + return ( + p.suffix == ".zarr" + or (p / ".zgroup").exists() + or (p / "zarr.json").exists() + or (p / "points").exists() + or (p / "shapes").exists() + or (p / "tables").exists() + ) diff --git a/src/segger/utils/__init__.py b/src/segger/utils/__init__.py new file mode 100644 index 0000000..be161eb --- /dev/null +++ b/src/segger/utils/__init__.py @@ -0,0 +1,52 @@ +"""Utility modules for Segger.""" + +from segger.utils.optional_deps import ( + # Availability flags + SPATIALDATA_AVAILABLE, + SPATIALDATA_IO_AVAILABLE, + SOPA_AVAILABLE, + # Import functions (raise ImportError if missing) + require_spatialdata, + require_spatialdata_io, + require_sopa, + # Decorators for functions requiring optional deps + requires_spatialdata, + requires_spatialdata_io, + requires_sopa, + # Warning functions for soft failures + warn_spatialdata_unavailable, + warn_spatialdata_io_unavailable, + warn_sopa_unavailable, + warn_rapids_unavailable, + # RAPIDS helpers + require_rapids, + # Version utilities + get_spatialdata_version, + get_sopa_version, + check_spatialdata_version, +) + +__all__ = [ + # Availability flags + "SPATIALDATA_AVAILABLE", + "SPATIALDATA_IO_AVAILABLE", + "SOPA_AVAILABLE", + # Import functions + "require_spatialdata", + "require_spatialdata_io", + "require_sopa", + # Decorators + "requires_spatialdata", + "requires_spatialdata_io", + "requires_sopa", + # Warning functions + "warn_spatialdata_unavailable", + "warn_spatialdata_io_unavailable", + "warn_sopa_unavailable", + "warn_rapids_unavailable", + "require_rapids", + # Version utilities + "get_spatialdata_version", + "get_sopa_version", + "check_spatialdata_version", +] diff --git a/src/segger/utils/optional_deps.py b/src/segger/utils/optional_deps.py new file mode 100644 index 0000000..b7231ca --- /dev/null +++ b/src/segger/utils/optional_deps.py @@ -0,0 +1,402 @@ +"""Optional dependency handling with informative warnings. 
+ +This module provides lazy import wrappers for optional dependencies +(spatialdata, spatialdata-io, sopa) with clear installation instructions +when the dependencies are not available. + +Usage +----- +Check availability: + >>> from segger.utils.optional_deps import SPATIALDATA_AVAILABLE + >>> if SPATIALDATA_AVAILABLE: + ... import spatialdata + +Require and get import (raises ImportError with instructions if missing): + >>> from segger.utils.optional_deps import require_spatialdata + >>> spatialdata = require_spatialdata() + +Decorator for functions requiring optional deps: + >>> from segger.utils.optional_deps import requires_spatialdata + >>> @requires_spatialdata + ... def my_function(): + ... import spatialdata + ... return spatialdata.SpatialData() +""" + +from __future__ import annotations + +import functools +import importlib +import importlib.util +import warnings +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +if TYPE_CHECKING: + import types + +# Type variable for decorator +F = TypeVar("F", bound=Callable[..., Any]) + + +# ----------------------------------------------------------------------------- +# Availability flags +# ----------------------------------------------------------------------------- + +def _check_spatialdata() -> bool: + """Check if spatialdata is available.""" + try: + return importlib.util.find_spec("spatialdata") is not None + except Exception: + return False + + +def _check_spatialdata_io() -> bool: + """Check if spatialdata-io is available.""" + try: + return importlib.util.find_spec("spatialdata_io") is not None + except Exception: + return False + + +def _check_sopa() -> bool: + """Check if sopa is available.""" + try: + return importlib.util.find_spec("sopa") is not None + except Exception: + return False + + +# Availability flags (evaluated once at import time) +SPATIALDATA_AVAILABLE: bool = _check_spatialdata() +SPATIALDATA_IO_AVAILABLE: bool = _check_spatialdata_io() +SOPA_AVAILABLE: bool = _check_sopa() + + +# ----------------------------------------------------------------------------- +# Installation instructions +# ----------------------------------------------------------------------------- + +SPATIALDATA_INSTALL_MSG = """ +spatialdata is not installed. This package is required for SpatialData I/O support. + +To install spatialdata support: + pip install segger[spatialdata] + +Or install spatialdata directly: + pip install spatialdata>=0.7.2 +""" + +SPATIALDATA_IO_INSTALL_MSG = """ +spatialdata-io is not installed. This package is required for reading platform-specific +SpatialData formats (Xenium, MERSCOPE, CosMX). + +To install spatialdata-io support: + pip install segger[spatialdata-io] + +For full SpatialData support: + pip install segger[spatialdata] + +Or install spatialdata-io directly: + pip install spatialdata-io>=0.6.0 +""" + +SOPA_INSTALL_MSG = """ +sopa is not installed. This package is required for SOPA compatibility features. + +To install SOPA support: + pip install segger[sopa] + +Or install sopa directly: + pip install sopa>=2.0.0 + +For all SpatialData features including SOPA: + pip install segger[spatialdata-all] +""" + +RAPIDS_INSTALL_MSG = """ +RAPIDS GPU packages are not installed. Segger requires CuPy/cuDF/cuML/cuGraph/cuSpatial and a CUDA-enabled GPU. + +See docs/INSTALLATION.md for RAPIDS/CUDA setup. 
+""" + + +# ----------------------------------------------------------------------------- +# Import functions with error messages +# ----------------------------------------------------------------------------- + +def require_spatialdata() -> "types.ModuleType": + """Import and return spatialdata, raising ImportError if not available. + + Returns + ------- + types.ModuleType + The spatialdata module. + + Raises + ------ + ImportError + If spatialdata is not installed, with installation instructions. + """ + if not SPATIALDATA_AVAILABLE: + raise ImportError(SPATIALDATA_INSTALL_MSG) + import spatialdata + return spatialdata + + +def require_spatialdata_io() -> "types.ModuleType": + """Import and return spatialdata_io, raising ImportError if not available. + + Returns + ------- + types.ModuleType + The spatialdata_io module. + + Raises + ------ + ImportError + If spatialdata-io is not installed, with installation instructions. + """ + if not SPATIALDATA_IO_AVAILABLE: + raise ImportError(SPATIALDATA_IO_INSTALL_MSG) + import spatialdata_io + return spatialdata_io + + +def require_sopa() -> "types.ModuleType": + """Import and return sopa, raising ImportError if not available. + + Returns + ------- + types.ModuleType + The sopa module. + + Raises + ------ + ImportError + If sopa is not installed, with installation instructions. + """ + if not SOPA_AVAILABLE: + raise ImportError(SOPA_INSTALL_MSG) + import sopa + return sopa + + +# ----------------------------------------------------------------------------- +# Decorators for requiring optional dependencies +# ----------------------------------------------------------------------------- + +def requires_spatialdata(func: F) -> F: + """Decorator that raises ImportError if spatialdata is not available. + + Parameters + ---------- + func + Function that requires spatialdata. + + Returns + ------- + F + Wrapped function that checks for spatialdata before execution. + + Examples + -------- + >>> @requires_spatialdata + ... def load_from_zarr(path): + ... import spatialdata + ... return spatialdata.read_zarr(path) + """ + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + require_spatialdata() + return func(*args, **kwargs) + return wrapper # type: ignore[return-value] + + +def requires_spatialdata_io(func: F) -> F: + """Decorator that raises ImportError if spatialdata-io is not available. + + Parameters + ---------- + func + Function that requires spatialdata-io. + + Returns + ------- + F + Wrapped function that checks for spatialdata-io before execution. + """ + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + require_spatialdata_io() + return func(*args, **kwargs) + return wrapper # type: ignore[return-value] + + +def requires_sopa(func: F) -> F: + """Decorator that raises ImportError if sopa is not available. + + Parameters + ---------- + func + Function that requires sopa. + + Returns + ------- + F + Wrapped function that checks for sopa before execution. + """ + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + require_sopa() + return func(*args, **kwargs) + return wrapper # type: ignore[return-value] + + +# ----------------------------------------------------------------------------- +# Warning functions for soft failures +# ----------------------------------------------------------------------------- + +def warn_spatialdata_unavailable(feature: str = "SpatialData support") -> None: + """Emit a warning that spatialdata is not available. 
+ + Parameters + ---------- + feature + Description of the feature requiring spatialdata. + """ + warnings.warn( + f"{feature} requires spatialdata. " + "Install with: pip install segger[spatialdata]", + UserWarning, + stacklevel=2, + ) + + +def warn_spatialdata_io_unavailable(feature: str = "Platform-specific SpatialData readers") -> None: + """Emit a warning that spatialdata-io is not available. + + Parameters + ---------- + feature + Description of the feature requiring spatialdata-io. + """ + warnings.warn( + f"{feature} requires spatialdata-io. " + "Install with: pip install segger[spatialdata-io]", + UserWarning, + stacklevel=2, + ) + + +def warn_sopa_unavailable(feature: str = "SOPA compatibility") -> None: + """Emit a warning that sopa is not available. + + Parameters + ---------- + feature + Description of the feature requiring sopa. + """ + warnings.warn( + f"{feature} requires sopa. " + "Install with: pip install segger[sopa]", + UserWarning, + stacklevel=2, + ) + + +def _import_optional_packages(packages: list[str]) -> tuple[dict[str, "types.ModuleType"], list[str]]: + """Import optional packages and return (modules, missing).""" + modules: dict[str, "types.ModuleType"] = {} + missing: list[str] = [] + for package in packages: + try: + modules[package] = importlib.import_module(package) + except Exception: + missing.append(package) + return modules, missing + + +def require_rapids( + packages: list[str] | None = None, + feature: str = "Segger", +) -> dict[str, "types.ModuleType"]: + """Import RAPIDS-related packages or raise with installation instructions.""" + package_list = packages or ["cupy", "cudf", "cuml", "cugraph", "cuspatial"] + modules, missing = _import_optional_packages(package_list) + if missing: + missing_list = ", ".join(missing) + raise ImportError( + f"{feature} requires RAPIDS GPU packages: {missing_list}. " + + RAPIDS_INSTALL_MSG.strip() + ) + return modules + + +def warn_rapids_unavailable( + feature: str = "Segger", + packages: list[str] | None = None, +) -> bool: + """Warn if RAPIDS-related packages are unavailable. Returns True if present.""" + package_list = packages or ["cupy", "cudf", "cuml", "cugraph", "cuspatial"] + _, missing = _import_optional_packages(package_list) + if not missing: + return True + missing_list = ", ".join(missing) + warnings.warn( + f"{feature} requires RAPIDS GPU packages ({missing_list}). " + + RAPIDS_INSTALL_MSG.strip(), + UserWarning, + stacklevel=2, + ) + return False + + +# ----------------------------------------------------------------------------- +# Version checking +# ----------------------------------------------------------------------------- + +def get_spatialdata_version() -> str | None: + """Get the installed spatialdata version, or None if not installed.""" + if not SPATIALDATA_AVAILABLE: + return None + try: + import spatialdata + return getattr(spatialdata, "__version__", "unknown") + except Exception: + return None + + +def get_sopa_version() -> str | None: + """Get the installed sopa version, or None if not installed.""" + if not SOPA_AVAILABLE: + return None + try: + import sopa + return getattr(sopa, "__version__", "unknown") + except Exception: + return None + + +def check_spatialdata_version(min_version: str = "0.7.2") -> bool: + """Check if spatialdata version meets minimum requirement. + + Parameters + ---------- + min_version + Minimum required version string. + + Returns + ------- + bool + True if version is sufficient, False otherwise. 
+ """ + version = get_spatialdata_version() + if version is None or version == "unknown": + return False + + try: + from packaging.version import Version + return Version(version) >= Version(min_version) + except ImportError: + # Fallback to simple string comparison + return version >= min_version From e4e8846b494d90853eb900febb29c28f9152af8a Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 17:36:43 +0200 Subject: [PATCH 06/20] Add all export modules from v2-incremental --- src/segger/export/__init__.py | 144 ++++ src/segger/export/adapter.py | 165 +++++ src/segger/export/anndata_writer.py | 250 +++++++ src/segger/export/boundary.py | 525 +++++++++++++++ src/segger/export/merged_writer.py | 317 +++++++++ src/segger/export/output_formats.py | 309 +++++++++ src/segger/export/sopa_compat.py | 396 +++++++++++ src/segger/export/spatialdata_writer.py | 535 +++++++++++++++ src/segger/export/xenium.py | 862 ++++++++++++++++++++++++ 9 files changed, 3503 insertions(+) create mode 100644 src/segger/export/__init__.py create mode 100644 src/segger/export/adapter.py create mode 100644 src/segger/export/anndata_writer.py create mode 100644 src/segger/export/boundary.py create mode 100644 src/segger/export/merged_writer.py create mode 100644 src/segger/export/output_formats.py create mode 100644 src/segger/export/sopa_compat.py create mode 100644 src/segger/export/spatialdata_writer.py create mode 100644 src/segger/export/xenium.py diff --git a/src/segger/export/__init__.py b/src/segger/export/__init__.py new file mode 100644 index 0000000..0df59d6 --- /dev/null +++ b/src/segger/export/__init__.py @@ -0,0 +1,144 @@ +"""Export module for segmentation results. + +This module provides functionality to export segmentation results to various formats: +- Xenium Explorer format for visualization and validation +- Merged transcripts (original data with segmentation results) +- SpatialData Zarr format for scverse ecosystem +- SOPA-compatible format for spatial omics workflows +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +import importlib + +__all__ = [ + # Existing exports + "BoundaryIdentification", + "generate_boundary", + "generate_boundaries", + "seg2explorer", + "seg2explorer_pqdm", + "predictions_to_dataframe", + # Output formats + "OutputFormat", + "OutputWriter", + "get_writer", + "register_writer", + "write_all_formats", + # Writers + "MergedTranscriptsWriter", + "SeggerRawWriter", + "AnnDataWriter", + "merge_predictions_with_transcripts", + # SpatialData (optional) + "SpatialDataWriter", + "write_spatialdata", + # SOPA (optional) + "validate_sopa_compatibility", + "export_for_sopa", + "sopa_to_segger_input", + "check_sopa_installation", +] + +if TYPE_CHECKING: # pragma: no cover + from .boundary import BoundaryIdentification, generate_boundary, generate_boundaries + from .xenium import seg2explorer, seg2explorer_pqdm + from .adapter import predictions_to_dataframe + from .output_formats import ( + OutputFormat, + OutputWriter, + get_writer, + register_writer, + write_all_formats, + ) + from .merged_writer import ( + MergedTranscriptsWriter, + SeggerRawWriter, + merge_predictions_with_transcripts, + ) + from .anndata_writer import AnnDataWriter + from .spatialdata_writer import SpatialDataWriter, write_spatialdata + from .sopa_compat import ( + validate_sopa_compatibility, + export_for_sopa, + sopa_to_segger_input, + check_sopa_installation, + ) + + +def __getattr__(name: str): + if name in {"BoundaryIdentification", "generate_boundary", 
"generate_boundaries"}: + from .boundary import BoundaryIdentification, generate_boundary, generate_boundaries + return locals()[name] + if name in {"seg2explorer", "seg2explorer_pqdm"}: + from .xenium import seg2explorer, seg2explorer_pqdm + return locals()[name] + if name == "predictions_to_dataframe": + from .adapter import predictions_to_dataframe + return predictions_to_dataframe + if name in { + "OutputFormat", + "OutputWriter", + "get_writer", + "register_writer", + "write_all_formats", + }: + from .output_formats import ( + OutputFormat, + OutputWriter, + get_writer, + register_writer, + write_all_formats, + ) + return locals()[name] + if name in { + "MergedTranscriptsWriter", + "SeggerRawWriter", + "AnnDataWriter", + "merge_predictions_with_transcripts", + }: + from .merged_writer import ( + MergedTranscriptsWriter, + SeggerRawWriter, + merge_predictions_with_transcripts, + ) + if name == "AnnDataWriter": + from .anndata_writer import AnnDataWriter + return locals()[name] + if name in {"SpatialDataWriter", "write_spatialdata"}: + try: + from .spatialdata_writer import SpatialDataWriter, write_spatialdata + except Exception: + return None + return locals()[name] + if name in { + "validate_sopa_compatibility", + "export_for_sopa", + "sopa_to_segger_input", + "check_sopa_installation", + }: + try: + from .sopa_compat import ( + validate_sopa_compatibility, + export_for_sopa, + sopa_to_segger_input, + check_sopa_installation, + ) + except Exception: + return None + return locals()[name] + if name in { + "boundary", + "xenium", + "adapter", + "output_formats", + "merged_writer", + "spatialdata_writer", + "sopa_compat", + }: + try: + return importlib.import_module(f"{__name__}.{name}") + except Exception as exc: + raise ImportError(f"Failed to import optional module '{name}'.") from exc + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/segger/export/adapter.py b/src/segger/export/adapter.py new file mode 100644 index 0000000..541daa9 --- /dev/null +++ b/src/segger/export/adapter.py @@ -0,0 +1,165 @@ +"""Adapter to convert model predictions to export-compatible format. + +This module bridges the gap between LitISTEncoder.predict_step() output +and the seg2explorer functions for Xenium Explorer export. +""" + +from typing import Optional, Union +import pandas as pd +import polars as pl +import torch + + +def predictions_to_dataframe( + src_idx: torch.Tensor, + seg_idx: torch.Tensor, + max_sim: torch.Tensor, + gen_idx: torch.Tensor, + transcript_data: Union[pd.DataFrame, pl.DataFrame], + min_similarity: float = 0.5, + x_column: str = "x", + y_column: str = "y", + gene_column: str = "feature_name", +) -> pd.DataFrame: + """Convert prediction tensors to seg2explorer-compatible DataFrame. + + This function takes the output from LitISTEncoder.predict_step() and + combines it with the original transcript data to create a DataFrame + suitable for Xenium Explorer export. + + Parameters + ---------- + src_idx : torch.Tensor + Transcript indices from prediction, shape (N,). + seg_idx : torch.Tensor + Assigned boundary/cell indices, shape (N,). Value of -1 indicates + unassigned transcripts. + max_sim : torch.Tensor + Maximum similarity scores, shape (N,). + gen_idx : torch.Tensor + Gene indices for each transcript, shape (N,). + transcript_data : Union[pd.DataFrame, pl.DataFrame] + Original transcript DataFrame with coordinates. + min_similarity : float + Minimum similarity threshold for valid assignments. + x_column : str + Column name for x coordinates. 
+ y_column : str + Column name for y coordinates. + gene_column : str + Column name for gene/feature names. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - row_index: Original transcript index + - x: X coordinate + - y: Y coordinate + - seg_cell_id: Assigned cell ID (or -1 if unassigned) + - similarity: Assignment confidence score + - feature_name: Gene name + """ + # Convert to numpy + src_idx_np = src_idx.cpu().numpy() + seg_idx_np = seg_idx.cpu().numpy() + max_sim_np = max_sim.cpu().numpy() + + # Filter by similarity threshold + valid_mask = (seg_idx_np >= 0) & (max_sim_np >= min_similarity) + + # Convert Polars to pandas if needed + if isinstance(transcript_data, pl.DataFrame): + transcript_data = transcript_data.to_pandas() + + # Build result DataFrame + result = pd.DataFrame({ + "row_index": src_idx_np, + "seg_cell_id": seg_idx_np, + "similarity": max_sim_np, + }) + + # Mark low-similarity assignments as unassigned + result.loc[~valid_mask, "seg_cell_id"] = -1 + + # Merge with original transcript data for coordinates + if "row_index" in transcript_data.columns: + # Use existing row_index + result = result.merge( + transcript_data[["row_index", x_column, y_column, gene_column]], + on="row_index", + how="left", + ) + else: + # Use index as row_index + transcript_data = transcript_data.reset_index() + transcript_data = transcript_data.rename(columns={"index": "row_index"}) + result = result.merge( + transcript_data[["row_index", x_column, y_column, gene_column]], + on="row_index", + how="left", + ) + + # Rename columns for consistency + result = result.rename(columns={ + gene_column: "feature_name", + x_column: "x", + y_column: "y", + }) + + return result + + +def collect_predictions( + predictions: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Collect predictions from multiple batches. + + Parameters + ---------- + predictions : list[tuple] + List of (src_idx, seg_idx, max_sim, gen_idx) tuples from predict_step. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Concatenated (src_idx, seg_idx, max_sim, gen_idx) tensors. + """ + src_indices = [] + seg_indices = [] + similarities = [] + gene_indices = [] + + for src_idx, seg_idx, max_sim, gen_idx in predictions: + src_indices.append(src_idx) + seg_indices.append(seg_idx) + similarities.append(max_sim) + gene_indices.append(gen_idx) + + return ( + torch.cat(src_indices), + torch.cat(seg_indices), + torch.cat(similarities), + torch.cat(gene_indices), + ) + + +def filter_assigned_transcripts( + seg_df: pd.DataFrame, + cell_id_column: str = "seg_cell_id", +) -> pd.DataFrame: + """Filter DataFrame to only include assigned transcripts. + + Parameters + ---------- + seg_df : pd.DataFrame + Segmentation result DataFrame. + cell_id_column : str + Column name for cell IDs. + + Returns + ------- + pd.DataFrame + DataFrame with only assigned transcripts. + """ + return seg_df[seg_df[cell_id_column] >= 0].copy() diff --git a/src/segger/export/anndata_writer.py b/src/segger/export/anndata_writer.py new file mode 100644 index 0000000..716c2cc --- /dev/null +++ b/src/segger/export/anndata_writer.py @@ -0,0 +1,250 @@ +"""Write segmentation results as AnnData (.h5ad). + +This writer builds a cell x gene count matrix from transcript assignments +and saves it as an AnnData object. The output can also be embedded as a +table in SpatialData. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional, Union + +import numpy as np +import pandas as pd +import polars as pl +from anndata import AnnData +from scipy import sparse as sp + +from segger.export.output_formats import OutputFormat, register_writer +from segger.export.merged_writer import merge_predictions_with_transcripts + + +def build_anndata_table( + transcripts: pl.DataFrame, + cell_id_column: str = "segger_cell_id", + feature_column: str = "feature_name", + x_column: Optional[str] = "x", + y_column: Optional[str] = "y", + z_column: Optional[str] = "z", + unassigned_value: Union[int, str, None] = -1, + region: Optional[str] = None, + region_key: Optional[str] = None, + obs_index_as_str: bool = False, +) -> AnnData: + """Build AnnData from assigned transcripts. + + Parameters + ---------- + transcripts + Transcript DataFrame with segmentation assignments. + cell_id_column + Column with assigned cell IDs. + feature_column + Column with gene/feature names. + x_column, y_column, z_column + Coordinate columns (optional). If present, centroids are stored in + ``obsm["X_spatial"]``. + unassigned_value + Marker for unassigned transcripts (filtered out). + region, region_key + SpatialData table linkage metadata. + obs_index_as_str + If True, cast cell IDs to string for ``obs`` index. + """ + if cell_id_column not in transcripts.columns: + raise ValueError(f"Missing cell_id column: {cell_id_column}") + if feature_column not in transcripts.columns: + raise ValueError(f"Missing feature column: {feature_column}") + + assigned = transcripts.filter(pl.col(cell_id_column).is_not_null()) + if unassigned_value is not None: + col_dtype = transcripts.schema.get(cell_id_column) + try: + compare_value = pl.Series([unassigned_value]).cast(col_dtype).item() + filter_expr = pl.col(cell_id_column) != compare_value + except Exception: + filter_expr = ( + pl.col(cell_id_column).cast(pl.Utf8) != str(unassigned_value) + ) + assigned = assigned.filter(filter_expr) + + # Gene list from all transcripts (even if no assignments) + var_idx = ( + transcripts + .select(feature_column) + .unique() + .sort(feature_column) + .get_column(feature_column) + .to_list() + ) + + if assigned.height == 0: + obs_index = pd.Index([], name=cell_id_column) + if obs_index_as_str: + var_index = pd.Index([str(v) for v in var_idx], name=feature_column) + else: + var_index = pd.Index(var_idx, name=feature_column) + X = sp.csr_matrix((0, len(var_index))) + adata = AnnData(X=X, obs=pd.DataFrame(index=obs_index), var=pd.DataFrame(index=var_index)) + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + return adata + + feature_idx = ( + assigned + .select(feature_column) + .unique() + .sort(feature_column) + .with_row_index(name="_fid") + ) + cell_idx = ( + assigned + .select(cell_id_column) + .unique() + .sort(cell_id_column) + .with_row_index(name="_cid") + ) + + mapped = ( + assigned + .join(feature_idx, on=feature_column) + .join(cell_idx, on=cell_id_column) + ) + counts = ( + mapped + .group_by(["_cid", "_fid"]) + .agg(pl.len().alias("_count")) + ) + ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T + rows = ijv[0].astype(np.int64, copy=False) + cols = ijv[1].astype(np.int64, copy=False) + data = ijv[2].astype(np.int64, copy=False) + + n_cells = cell_idx.height + n_genes = feature_idx.height + X = sp.coo_matrix((data, (rows, cols)), shape=(n_cells, n_genes)).tocsr() + + obs_ids = 
cell_idx.get_column(cell_id_column).to_list() + var_ids = feature_idx.get_column(feature_column).to_list() + if obs_index_as_str: + obs_ids = [str(v) for v in obs_ids] + var_ids = [str(v) for v in var_ids] + + adata = AnnData( + X=X, + obs=pd.DataFrame(index=pd.Index(obs_ids, name=cell_id_column)), + var=pd.DataFrame(index=pd.Index(var_ids, name=feature_column)), + ) + + # Add centroid coordinates if present + if x_column in assigned.columns and y_column in assigned.columns: + coords_cols = [x_column, y_column] + if z_column and z_column in assigned.columns: + coords_cols.append(z_column) + centroids = ( + assigned + .group_by(cell_id_column) + .agg([pl.col(c).mean().alias(c) for c in coords_cols]) + ) + centroids_pd = ( + centroids + .to_pandas() + .set_index(cell_id_column) + .reindex(adata.obs.index) + ) + adata.obsm["X_spatial"] = centroids_pd[coords_cols].to_numpy() + + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + + return adata + + +@register_writer(OutputFormat.ANNDATA) +class AnnDataWriter: + """Write segmentation results as AnnData (.h5ad).""" + + def __init__( + self, + unassigned_marker: Union[int, str, None] = -1, + compression: Optional[str] = "gzip", + compression_opts: Optional[int] = 4, + ): + self.unassigned_marker = unassigned_marker + self.compression = compression + self.compression_opts = compression_opts + + def write( + self, + predictions: pl.DataFrame, + output_dir: Path, + transcripts: Optional[pl.DataFrame] = None, + output_name: str = "segger_segmentation.h5ad", + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + feature_column: str = "feature_name", + x_column: Optional[str] = "x", + y_column: Optional[str] = "y", + z_column: Optional[str] = "z", + overwrite: bool = False, + **kwargs, + ) -> Path: + """Write segmentation results to AnnData (.h5ad). + + Parameters + ---------- + predictions + Segmentation predictions. + output_dir + Output directory. + transcripts + Original transcripts DataFrame (required). + output_name + Output filename. Default "segger_segmentation.h5ad". + """ + if transcripts is None: + raise ValueError("AnnData output requires transcripts DataFrame.") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / output_name + + if output_path.exists() and not overwrite: + raise FileExistsError( + f"Output path exists: {output_path}. " + "Use overwrite=True to replace." 
+ ) + + merged = merge_predictions_with_transcripts( + predictions=predictions, + transcripts=transcripts, + row_index_column=row_index_column, + cell_id_column=cell_id_column, + similarity_column=similarity_column, + unassigned_marker=self.unassigned_marker, + ) + + adata = build_anndata_table( + transcripts=merged, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + unassigned_value=self.unassigned_marker, + ) + + write_kwargs = {} + if self.compression is not None: + write_kwargs["compression"] = self.compression + if self.compression_opts is not None: + write_kwargs["compression_opts"] = self.compression_opts + + adata.write_h5ad(output_path, **write_kwargs) + return output_path diff --git a/src/segger/export/boundary.py b/src/segger/export/boundary.py new file mode 100644 index 0000000..82eca0f --- /dev/null +++ b/src/segger/export/boundary.py @@ -0,0 +1,525 @@ +"""Delaunay triangulation-based cell boundary generation. + +This module provides sophisticated boundary extraction using Delaunay triangulation +with iterative edge refinement and cycle detection. This produces more accurate +cell boundaries than simple convex hulls. +""" + +from typing import Iterable, Tuple, Union +from concurrent.futures import ThreadPoolExecutor +import geopandas as gpd +import numpy as np +import pandas as pd +import polars as pl +import rtree.index +from scipy.spatial import Delaunay +from shapely.geometry import MultiPolygon, Polygon +from tqdm import tqdm + + +def vector_angle(v1: np.ndarray, v2: np.ndarray) -> float: + """Calculate angle between two vectors in degrees. + + Parameters + ---------- + v1 : np.ndarray + First vector. + v2 : np.ndarray + Second vector. + + Returns + ------- + float + Angle in degrees. + """ + dot_product = np.dot(v1, v2) + magnitude_v1 = np.linalg.norm(v1) + magnitude_v2 = np.linalg.norm(v2) + cos_angle = np.clip(dot_product / (magnitude_v1 * magnitude_v2 + 1e-8), -1.0, 1.0) + return np.degrees(np.arccos(cos_angle)) + + +def triangle_angles_from_points( + points: np.ndarray, + triangles: np.ndarray, +) -> np.ndarray: + """Calculate angles for all triangles in a Delaunay triangulation. + + Parameters + ---------- + points : np.ndarray + Point coordinates, shape (N, 2). + triangles : np.ndarray + Triangle vertex indices, shape (M, 3). + + Returns + ------- + np.ndarray + Angles for each triangle vertex, shape (M, 3). + """ + # Vectorized angle computation for all triangles + p1 = points[triangles[:, 0]] + p2 = points[triangles[:, 1]] + p3 = points[triangles[:, 2]] + + v1 = p2 - p1 + v2 = p3 - p1 + v3 = p3 - p2 + + def _angles(u: np.ndarray, v: np.ndarray) -> np.ndarray: + dot = (u * v).sum(axis=1) + denom = (np.linalg.norm(u, axis=1) * np.linalg.norm(v, axis=1)) + 1e-8 + cos = np.clip(dot / denom, -1.0, 1.0) + return np.degrees(np.arccos(cos)) + + a = _angles(v1, v2) + b = _angles(-v1, v3) + c = _angles(-v2, -v3) + return np.stack([a, b, c], axis=1) + + +def dfs(v: int, graph: dict, path: list, colors: dict) -> None: + """Depth-first search for cycle detection. + + Parameters + ---------- + v : int + Current vertex. + graph : dict + Adjacency list representation of graph. + path : list + Current path being built. + colors : dict + Vertex visit status (0=unvisited, 1=visited). + """ + colors[v] = 1 + path.append(v) + for d in graph[v]: + if colors[d] == 0: + dfs(d, graph, path, colors) + + +class BoundaryIdentification: + """Delaunay triangulation-based polygon boundary extraction. 
+ + This class implements a two-phase iterative algorithm for extracting + cell boundaries from transcript point clouds: + + 1. Phase 1: Remove long boundary edges (> 2 * d_max) + 2. Phase 2: Remove boundary edges with extreme angles + + Parameters + ---------- + data : np.ndarray + 2D point coordinates, shape (N, 2). + """ + + def __init__(self, data: np.ndarray): + self.graph = None + self.edges = {} + self.d = Delaunay(data) + self.d_max = self.calculate_d_max(self.d.points) + self.generate_edges() + + def generate_edges(self) -> None: + """Generate edge dictionary from Delaunay triangulation.""" + d = self.d + edges = {} + angles = triangle_angles_from_points(d.points, d.simplices) + + for index, simplex in enumerate(d.simplices): + for p in range(3): + edge = tuple(sorted((simplex[p], simplex[(p + 1) % 3]))) + if edge not in edges: + edges[edge] = {"simplices": {}} + edges[edge]["simplices"][index] = angles[index][(p + 2) % 3] + + edges_coordinates = d.points[np.array(list(edges.keys()))] + edges_length = np.sqrt( + (edges_coordinates[:, 1, 0] - edges_coordinates[:, 0, 0]) ** 2 + + (edges_coordinates[:, 1, 1] - edges_coordinates[:, 0, 1]) ** 2 + ) + + for edge, coords, length in zip(edges, edges_coordinates, edges_length): + edges[edge]["coords"] = coords + edges[edge]["length"] = length + + self.edges = edges + + def calculate_part_1(self, plot: bool = False) -> None: + """Phase 1: Remove long boundary edges iteratively. + + Removes edges longer than 2 * d_max from the boundary. + + Parameters + ---------- + plot : bool + Whether to generate visualization (not implemented). + """ + edges = self.edges + d = self.d + d_max = self.d_max + + boundary_edges = [edge for edge in edges if len(edges[edge]["simplices"]) < 2] + + flag = True + while flag: + flag = False + next_boundary_edges = [] + + for current_edge in boundary_edges: + if current_edge not in edges: + continue + + if edges[current_edge]["length"] > 2 * d_max: + if len(edges[current_edge]["simplices"].keys()) == 0: + del edges[current_edge] + continue + + simplex_id = list(edges[current_edge]["simplices"].keys())[0] + simplex = d.simplices[simplex_id] + + for edge in self.get_edges_from_simplex(simplex): + if edge != current_edge: + edges[edge]["simplices"].pop(simplex_id) + next_boundary_edges.append(edge) + + del edges[current_edge] + flag = True + else: + next_boundary_edges.append(current_edge) + + boundary_edges = next_boundary_edges + + def calculate_part_2(self, plot: bool = False) -> None: + """Phase 2: Remove boundary edges with extreme angles. + + Removes edges where the opposite angle is too large, indicating + a concave region that should be excluded. + + Parameters + ---------- + plot : bool + Whether to generate visualization (not implemented). 
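+
+        Notes
+        -----
+        An edge is removed when it is longer than ``1.5 * d_max`` and the
+        angle opposite to it exceeds 90 degrees, or when that opposite angle
+        exceeds ``180 - 180/16`` (168.75) degrees regardless of edge length.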
+ """ + edges = self.edges + d = self.d + d_max = self.d_max + + boundary_edges = [edge for edge in edges if len(edges[edge]["simplices"]) < 2] + boundary_edges_length = len(boundary_edges) + next_boundary_edges = [] + + while len(next_boundary_edges) != boundary_edges_length: + next_boundary_edges = [] + + for current_edge in boundary_edges: + if current_edge not in edges: + continue + + if len(edges[current_edge]["simplices"].keys()) == 0: + del edges[current_edge] + continue + + simplex_id = list(edges[current_edge]["simplices"].keys())[0] + simplex = d.simplices[simplex_id] + + # Remove if edge is long with large angle, or if angle is very obtuse + if ( + edges[current_edge]["length"] > 1.5 * d_max + and edges[current_edge]["simplices"][simplex_id] > 90 + ) or edges[current_edge]["simplices"][simplex_id] > 180 - 180 / 16: + + for edge in self.get_edges_from_simplex(simplex): + if edge != current_edge: + edges[edge]["simplices"].pop(simplex_id) + next_boundary_edges.append(edge) + + del edges[current_edge] + else: + next_boundary_edges.append(current_edge) + + boundary_edges_length = len(boundary_edges) + boundary_edges = next_boundary_edges + + def find_cycles(self) -> Union[Polygon, MultiPolygon, None]: + """Find boundary cycles and convert to Shapely geometry. + + Returns + ------- + Union[Polygon, MultiPolygon, None] + Polygon if single cycle, MultiPolygon if multiple, None on error. + """ + e = self.edges + boundary_edges = [edge for edge in e if len(e[edge]["simplices"]) < 2] + self.graph = self.generate_graph(boundary_edges) + cycles = self.get_cycles(self.graph) + + try: + if len(cycles) == 1: + geom = Polygon(self.d.points[cycles[0]]) + else: + geom = MultiPolygon( + [Polygon(self.d.points[c]) for c in cycles if len(c) >= 3] + ) + except Exception: + return None + + return geom + + @staticmethod + def calculate_d_max(points: np.ndarray) -> float: + """Calculate maximum nearest-neighbor distance. + + Parameters + ---------- + points : np.ndarray + Point coordinates, shape (N, 2). + + Returns + ------- + float + Maximum nearest-neighbor distance. + """ + index = rtree.index.Index() + for i, p in enumerate(points): + index.insert(i, p[[0, 1, 0, 1]]) + + short_edges = [] + for i, p in enumerate(points): + res = list(index.nearest(p[[0, 1, 0, 1]], 2))[-1] + short_edges.append([i, res]) + + nearest_points = points[short_edges] + nearest_dists = np.sqrt( + (nearest_points[:, 0, 0] - nearest_points[:, 1, 0]) ** 2 + + (nearest_points[:, 0, 1] - nearest_points[:, 1, 1]) ** 2 + ) + return nearest_dists.max() + + @staticmethod + def get_edges_from_simplex(simplex: np.ndarray) -> list: + """Extract edge tuples from a triangle simplex. + + Parameters + ---------- + simplex : np.ndarray + Triangle vertex indices, shape (3,). + + Returns + ------- + list + List of edge tuples. + """ + edges = [] + for p in range(3): + edges.append(tuple(sorted((simplex[p], simplex[(p + 1) % 3])))) + return edges + + @staticmethod + def generate_graph(edges: list) -> dict: + """Generate adjacency list from edge list. + + Parameters + ---------- + edges : list + List of edge tuples. + + Returns + ------- + dict + Adjacency list representation. 
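+
+        Examples
+        --------
+        Two edges sharing a vertex yield a small chain:
+
+        >>> BoundaryIdentification.generate_graph([(0, 1), (1, 2)])
+        {0: [1], 1: [0, 2], 2: [1]}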
+ """ + vertices = set() + for edge in edges: + vertices.add(edge[0]) + vertices.add(edge[1]) + + vertices = sorted(list(vertices)) + graph = {v: [] for v in vertices} + + for e in edges: + graph[e[0]].append(e[1]) + graph[e[1]].append(e[0]) + + return graph + + @staticmethod + def get_cycles(graph: dict) -> list: + """Find all connected components (cycles) in boundary graph. + + Parameters + ---------- + graph : dict + Adjacency list representation. + + Returns + ------- + list + List of cycles (each cycle is a list of vertex indices). + """ + colors = {v: 0 for v in graph} + cycles = [] + + for v in graph.keys(): + if colors[v] == 0: + cycle = [] + dfs(v, graph, cycle, colors) + cycles.append(cycle) + + return cycles + + +def generate_boundary( + df: Union[pd.DataFrame, pl.DataFrame], + x: str = "x", + y: str = "y", +) -> Union[Polygon, MultiPolygon, None]: + """Generate boundary polygon for a single cell's transcripts. + + Uses Delaunay triangulation with iterative edge refinement to produce + more accurate boundaries than simple convex hulls. + + Parameters + ---------- + df : Union[pd.DataFrame, pl.DataFrame] + Transcript data with x, y coordinates. + x : str + Column name for x coordinate. + y : str + Column name for y coordinate. + + Returns + ------- + Union[Polygon, MultiPolygon, None] + Cell boundary geometry, or None if insufficient points. + """ + # Convert Polars to pandas if needed + if isinstance(df, pl.DataFrame): + df = df.to_pandas() + + if len(df) < 3: + return None + + bi = BoundaryIdentification(df[[x, y]].values) + bi.calculate_part_1(plot=False) + bi.calculate_part_2(plot=False) + return bi.find_cycles() + + +def generate_boundaries( + df: Union[pd.DataFrame, pl.DataFrame], + x: str = "x", + y: str = "y", + cell_id: str = "seg_cell_id", + n_jobs: int = 1, + chunksize: int = 8, + progress: bool = True, +) -> gpd.GeoDataFrame: + """Generate boundaries for all cells in a segmentation result. + + Parameters + ---------- + df : Union[pd.DataFrame, pl.DataFrame] + Transcript data with cell assignments. + x : str + Column name for x coordinate. + y : str + Column name for y coordinate. + cell_id : str + Column name for cell ID. + + Returns + ------- + gpd.GeoDataFrame + GeoDataFrame with cell_id, length, and geometry columns. 
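+
+    Notes
+    -----
+    ``n_jobs`` enables thread-based parallelism over cells, ``chunksize``
+    sets how many cells each worker task receives, and ``progress`` toggles
+    the tqdm progress bar.
+
+    Examples
+    --------
+    A minimal sketch, assuming ``seg_df`` holds transcripts with ``x``,
+    ``y``, and ``seg_cell_id`` columns:
+
+    >>> boundaries = generate_boundaries(seg_df, n_jobs=4)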
+ """ + def iter_groups() -> Tuple[Iterable[Tuple[object, np.ndarray]], int]: + if isinstance(df, pl.DataFrame): + grouped = df.group_by(cell_id).agg( + [ + pl.col(x).list().alias("_x"), + pl.col(y).list().alias("_y"), + ] + ) + total = grouped.height + + def _gen(): + for cid, xs, ys in grouped.iter_rows(): + yield cid, np.column_stack((xs, ys)) + + return _gen(), total + + group_df = df.groupby(cell_id) + total = group_df.ngroups + + def _gen(): + for cid, t in group_df: + yield cid, t[[x, y]].to_numpy() + + return _gen(), total + + def _compute_one(item: Tuple[object, np.ndarray]) -> Tuple[object, int, Union[Polygon, MultiPolygon, None]]: + cid, points = item + n_points = points.shape[0] + if n_points < 3: + return cid, n_points, None + try: + bi = BoundaryIdentification(points) + bi.calculate_part_1(plot=False) + bi.calculate_part_2(plot=False) + geom = bi.find_cycles() + except Exception: + geom = None + return cid, n_points, geom + + group_iter, total = iter_groups() + res = [] + + if n_jobs and n_jobs > 1: + with ThreadPoolExecutor(max_workers=n_jobs) as ex: + iterator = ex.map(_compute_one, group_iter, chunksize=chunksize) + if progress: + iterator = tqdm(iterator, total=total, desc="Generating boundaries") + for cid, length, geom in iterator: + res.append({"cell_id": cid, "length": length, "geom": geom}) + else: + iterator = group_iter + if progress: + iterator = tqdm(iterator, total=total, desc="Generating boundaries") + for item in iterator: + cid, length, geom = _compute_one(item) + res.append({"cell_id": cid, "length": length, "geom": geom}) + + return gpd.GeoDataFrame( + data=[[b["cell_id"], b["length"]] for b in res], + geometry=[b["geom"] for b in res], + columns=["cell_id", "length"], + ) + + +def extract_largest_polygon( + geom: Union[Polygon, MultiPolygon, None], +) -> Union[Polygon, None]: + """Extract the largest polygon from a geometry. + + Parameters + ---------- + geom : Union[Polygon, MultiPolygon, None] + Input geometry. + + Returns + ------- + Union[Polygon, None] + Largest polygon, or None if input is None. + """ + if geom is None: + return None + if getattr(geom, "is_empty", False): + return None + if isinstance(geom, MultiPolygon): + candidates = [p for p in geom.geoms if p is not None and not p.is_empty] + if not candidates: + return None + return max(candidates, key=lambda p: p.area) + return geom diff --git a/src/segger/export/merged_writer.py b/src/segger/export/merged_writer.py new file mode 100644 index 0000000..eb687df --- /dev/null +++ b/src/segger/export/merged_writer.py @@ -0,0 +1,317 @@ +"""Write segmentation results merged back to original transcripts. + +This writer joins segmentation predictions with the original transcript data, +producing a single output file that contains all original columns plus +the segmentation results (segger_cell_id, segger_similarity). + +Usage +----- +>>> from segger.export.merged_writer import MergedTranscriptsWriter +>>> writer = MergedTranscriptsWriter( +... original_transcripts_path=Path("data/transcripts.parquet") +... 
) +>>> output_path = writer.write(predictions, Path("output/")) + +The output file contains: +- All original transcript columns +- segger_cell_id: Assigned cell ID (-1 for unassigned) +- segger_similarity: Assignment confidence score (0.0 for unassigned) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Optional, Union + +import polars as pl + +from segger.export.output_formats import OutputFormat, register_writer + +if TYPE_CHECKING: + pass + + +@register_writer(OutputFormat.SEGGER_RAW) +class SeggerRawWriter: + """Write raw Segger prediction output (default format). + + This writer outputs just the predictions DataFrame without merging + with original transcripts. This is the default Segger output format. + + Output columns: + - row_index: Original transcript row index + - segger_cell_id: Assigned cell ID + - segger_similarity: Assignment confidence score + """ + + def __init__( + self, + compression: Literal["snappy", "gzip", "lz4", "zstd", "none"] = "snappy", + ): + """Initialize the raw writer. + + Parameters + ---------- + compression + Parquet compression algorithm. Default is 'snappy'. + """ + self.compression = compression if compression != "none" else None + + def write( + self, + predictions: pl.DataFrame, + output_dir: Path, + output_name: str = "predictions.parquet", + **kwargs, + ) -> Path: + """Write predictions to Parquet file. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + output_dir + Output directory. + output_name + Output filename. Default is 'predictions.parquet'. + + Returns + ------- + Path + Path to the written Parquet file. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + output_path = output_dir / output_name + predictions.write_parquet(output_path, compression=self.compression) + + return output_path + + +@register_writer(OutputFormat.MERGED_TRANSCRIPTS) +class MergedTranscriptsWriter: + """Write segmentation results merged with original transcripts. + + This writer joins predictions with original transcript data, producing + a complete output file with all original columns plus segmentation results. + + Output columns: + - All original transcript columns + - segger_cell_id: Assigned cell ID (configurable marker for unassigned) + - segger_similarity: Assignment confidence score + + Parameters + ---------- + original_transcripts_path + Path to the original transcripts file (Parquet or CSV). + If not provided, must be passed to write() via kwargs. + unassigned_marker + Value to use for unassigned transcripts. Default is -1. + Can be int, str, or None. + include_similarity + Whether to include the similarity score column. Default True. + compression + Parquet compression algorithm. Default is 'snappy'. 
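+
+    Examples
+    --------
+    A sketch of the in-memory path, assuming ``predictions`` and
+    ``transcripts`` are Polars DataFrames produced earlier in the run:
+
+    >>> writer = MergedTranscriptsWriter(unassigned_marker=-1)
+    >>> path = writer.write(
+    ...     predictions,
+    ...     Path("output/"),
+    ...     transcripts=transcripts,
+    ... )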
+ """ + + def __init__( + self, + original_transcripts_path: Optional[Path] = None, + unassigned_marker: Union[int, str, None] = -1, + include_similarity: bool = True, + compression: Literal["snappy", "gzip", "lz4", "zstd", "none"] = "snappy", + ): + self.original_transcripts_path = ( + Path(original_transcripts_path) if original_transcripts_path else None + ) + self.unassigned_marker = unassigned_marker + self.include_similarity = include_similarity + self.compression = compression if compression != "none" else None + + def write( + self, + predictions: pl.DataFrame, + output_dir: Path, + output_name: str = "transcripts_segmented.parquet", + transcripts: Optional[pl.DataFrame] = None, + original_transcripts_path: Optional[Path] = None, + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + **kwargs, + ) -> Path: + """Merge predictions with original transcripts and write to file. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. Must contain: + - row_index: Original transcript row index + - segger_cell_id: Assigned cell ID + - segger_similarity: Assignment confidence score (optional) + output_dir + Output directory. + output_name + Output filename. Default is 'transcripts_segmented.parquet'. + transcripts + Original transcripts DataFrame. If provided, used instead of + loading from original_transcripts_path. + original_transcripts_path + Path to original transcripts. Overrides constructor parameter. + row_index_column + Column name for row index in predictions. Default 'row_index'. + cell_id_column + Column name for cell ID in predictions. Default 'segger_cell_id'. + similarity_column + Column name for similarity in predictions. Default 'segger_similarity'. + + Returns + ------- + Path + Path to the written Parquet file. + + Raises + ------ + ValueError + If no transcripts source is provided. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Get original transcripts + if transcripts is not None: + original = transcripts + else: + path = original_transcripts_path or self.original_transcripts_path + if path is None: + raise ValueError( + "No original transcripts provided. Either pass 'transcripts' " + "DataFrame or specify 'original_transcripts_path'." + ) + original = self._load_transcripts(path) + + # Prepare predictions for join + pred_cols = [row_index_column, cell_id_column] + if self.include_similarity and similarity_column in predictions.columns: + pred_cols.append(similarity_column) + + pred_subset = predictions.select(pred_cols) + + # Handle missing row_index in original (add if needed) + if row_index_column not in original.columns: + original = original.with_row_index(name=row_index_column) + + # Join predictions with original transcripts + merged = original.join( + pred_subset, + on=row_index_column, + how="left", + ) + + # Fill unassigned values + if self.unassigned_marker is not None: + merged = merged.with_columns( + pl.col(cell_id_column).fill_null(self.unassigned_marker) + ) + if self.include_similarity and similarity_column in merged.columns: + merged = merged.with_columns( + pl.col(similarity_column).fill_null(0.0) + ) + + # Write output + output_path = output_dir / output_name + merged.write_parquet(output_path, compression=self.compression) + + return output_path + + def _load_transcripts(self, path: Path) -> pl.DataFrame: + """Load transcripts from file. 
+ + Parameters + ---------- + path + Path to transcripts file (Parquet or CSV). + + Returns + ------- + pl.DataFrame + Loaded transcripts. + """ + path = Path(path) + suffix = path.suffix.lower() + + if suffix == ".parquet": + return pl.read_parquet(path) + elif suffix in (".csv", ".tsv"): + separator = "\t" if suffix == ".tsv" else "," + return pl.read_csv(path, separator=separator) + else: + # Try Parquet first, then CSV + try: + return pl.read_parquet(path) + except Exception: + return pl.read_csv(path) + + +def merge_predictions_with_transcripts( + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + unassigned_marker: Union[int, str, None] = -1, +) -> pl.DataFrame: + """Merge predictions with transcripts (functional interface). + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + transcripts + Original transcripts DataFrame. + row_index_column + Column name for row index. + cell_id_column + Column name for cell ID in predictions. + similarity_column + Column name for similarity in predictions. + unassigned_marker + Value for unassigned transcripts. + + Returns + ------- + pl.DataFrame + Merged DataFrame with all original columns plus predictions. + + Examples + -------- + >>> merged = merge_predictions_with_transcripts(predictions, transcripts) + >>> print(merged.columns) + ['row_index', 'x', 'y', 'feature_name', 'segger_cell_id', 'segger_similarity'] + """ + # Prepare predictions + pred_cols = [row_index_column, cell_id_column] + if similarity_column in predictions.columns: + pred_cols.append(similarity_column) + + pred_subset = predictions.select(pred_cols) + + # Add row_index if missing + if row_index_column not in transcripts.columns: + transcripts = transcripts.with_row_index(name=row_index_column) + + # Join + merged = transcripts.join(pred_subset, on=row_index_column, how="left") + + # Fill unassigned + if unassigned_marker is not None: + merged = merged.with_columns( + pl.col(cell_id_column).fill_null(unassigned_marker) + ) + if similarity_column in merged.columns: + merged = merged.with_columns( + pl.col(similarity_column).fill_null(0.0) + ) + + return merged diff --git a/src/segger/export/output_formats.py b/src/segger/export/output_formats.py new file mode 100644 index 0000000..d08a990 --- /dev/null +++ b/src/segger/export/output_formats.py @@ -0,0 +1,309 @@ +"""Output format definitions and writer registry for segmentation results. + +This module provides: +- OutputFormat enum for available output formats +- OutputWriter protocol for implementing format-specific writers +- Factory function to get the appropriate writer for a format + +Available formats: +- SEGGER_RAW: Default Segger output (predictions parquet) +- MERGED_TRANSCRIPTS: Original transcripts merged with assignments +- SPATIALDATA: SpatialData Zarr format for scverse ecosystem +- ANNDATA: AnnData (.h5ad) cell x gene matrix + +Usage +----- +>>> from segger.export.output_formats import OutputFormat, get_writer +>>> writer = get_writer(OutputFormat.MERGED_TRANSCRIPTS) +>>> writer.write(predictions, transcripts, output_dir) +""" + +from __future__ import annotations + +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + import geopandas as gpd + import polars as pl + + +class OutputFormat(str, Enum): + """Available output formats for segmentation results. 
+ + Attributes + ---------- + SEGGER_RAW : str + Default Segger output format. Writes predictions as Parquet file + with columns: row_index, segger_cell_id, segger_similarity. + + MERGED_TRANSCRIPTS : str + Merged transcripts format. Original transcript data with segmentation + results joined (segger_cell_id, segger_similarity columns added). + + SPATIALDATA : str + SpatialData Zarr format. Creates a .zarr store compatible with + the scverse ecosystem, containing transcripts and optional boundaries. + + ANNDATA : str + AnnData format. Creates a .h5ad file with a cell x gene matrix + derived from transcript assignments. + """ + + SEGGER_RAW = "segger_raw" + MERGED_TRANSCRIPTS = "merged" + SPATIALDATA = "spatialdata" + ANNDATA = "anndata" + + @classmethod + def from_string(cls, value: str) -> "OutputFormat": + """Parse OutputFormat from string, case-insensitive. + + Parameters + ---------- + value + Format name ('segger_raw', 'merged', 'spatialdata', 'anndata', or 'all'). + + Returns + ------- + OutputFormat + Corresponding enum value. + + Raises + ------ + ValueError + If value is not a valid format name. + """ + value_lower = value.lower().strip() + + # Handle aliases + aliases = { + "raw": cls.SEGGER_RAW, + "segger": cls.SEGGER_RAW, + "default": cls.SEGGER_RAW, + "merge": cls.MERGED_TRANSCRIPTS, + "merged": cls.MERGED_TRANSCRIPTS, + "transcripts": cls.MERGED_TRANSCRIPTS, + "sdata": cls.SPATIALDATA, + "zarr": cls.SPATIALDATA, + "h5ad": cls.ANNDATA, + "ann": cls.ANNDATA, + "anndata": cls.ANNDATA, + } + + if value_lower in aliases: + return aliases[value_lower] + + # Try direct match + for fmt in cls: + if fmt.value == value_lower: + return fmt + + valid = [f.value for f in cls] + list(aliases.keys()) + raise ValueError( + f"Unknown output format: '{value}'. " + f"Valid formats: {sorted(set(valid))}" + ) + + +@runtime_checkable +class OutputWriter(Protocol): + """Protocol for output format writers. + + Implementations must provide a `write` method that writes segmentation + results to the specified output directory. + """ + + def write( + self, + predictions: "pl.DataFrame", + output_dir: Path, + **kwargs: Any, + ) -> Path: + """Write segmentation results to output format. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. Must contain: + - row_index: Original transcript row index + - segger_cell_id: Assigned cell ID (or -1/None for unassigned) + - segger_similarity: Assignment confidence score + + output_dir + Directory to write output files. + + **kwargs + Format-specific options (e.g., transcripts, boundaries). + + Returns + ------- + Path + Path to the primary output file/directory. + """ + ... + + +# Registry of output writers by format +_OUTPUT_WRITERS: dict[OutputFormat, type] = {} + + +def register_writer(fmt: OutputFormat): + """Decorator to register an output writer class. + + Parameters + ---------- + fmt + Output format this writer handles. + + Returns + ------- + decorator + Class decorator that registers the writer. + + Examples + -------- + >>> @register_writer(OutputFormat.MERGED_TRANSCRIPTS) + ... class MergedTranscriptsWriter: + ... def write(self, predictions, output_dir, **kwargs): + ... ... + """ + def decorator(cls): + _OUTPUT_WRITERS[fmt] = cls + return cls + return decorator + + +def get_writer(fmt: OutputFormat | str, **init_kwargs: Any) -> OutputWriter: + """Get an output writer for the specified format. + + Parameters + ---------- + fmt + Output format (enum or string). 
+ **init_kwargs + Keyword arguments passed to the writer constructor. + + Returns + ------- + OutputWriter + Writer instance for the specified format. + + Raises + ------ + ValueError + If format is not recognized or writer not registered. + + Examples + -------- + >>> writer = get_writer(OutputFormat.MERGED_TRANSCRIPTS, unassigned_marker=-1) + >>> writer.write(predictions, Path("output/")) + """ + if isinstance(fmt, str): + fmt = OutputFormat.from_string(fmt) + + if fmt not in _OUTPUT_WRITERS: + raise ValueError( + f"No writer registered for format: {fmt.value}. " + f"Available formats: {[f.value for f in _OUTPUT_WRITERS.keys()]}" + ) + + writer_cls = _OUTPUT_WRITERS[fmt] + return writer_cls(**init_kwargs) + + +def get_all_writers(**init_kwargs: Any) -> dict[OutputFormat, OutputWriter]: + """Get writers for all registered formats. + + Parameters + ---------- + **init_kwargs + Keyword arguments passed to each writer constructor. + + Returns + ------- + dict[OutputFormat, OutputWriter] + Dictionary mapping formats to writer instances. + """ + return {fmt: get_writer(fmt, **init_kwargs) for fmt in _OUTPUT_WRITERS} + + +def write_all_formats( + predictions: "pl.DataFrame", + output_dir: Path, + **kwargs: Any, +) -> dict[OutputFormat, Path]: + """Write segmentation results in all available formats. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + output_dir + Base output directory. Subdirectories may be created for each format. + **kwargs + Additional arguments passed to each writer (transcripts, boundaries, etc.). + + Returns + ------- + dict[OutputFormat, Path] + Dictionary mapping formats to output paths. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + results = {} + for fmt, writer in get_all_writers().items(): + try: + path = writer.write(predictions, output_dir, **kwargs) + results[fmt] = path + except Exception as e: + # Log error but continue with other formats + import warnings + warnings.warn( + f"Failed to write {fmt.value} format: {e}", + UserWarning, + stacklevel=2, + ) + + return results + + +# Import writers to register them (done at end to avoid circular imports) +def _register_builtin_writers(): + """Register built-in output writers. + + Called lazily to avoid import errors if optional dependencies are missing. + """ + # Import here to register writers via decorators + from segger.export import merged_writer # noqa: F401 + from segger.export import anndata_writer # noqa: F401 + + # SpatialData writer is optional + try: + from segger.export import spatialdata_writer # noqa: F401 + except ImportError: + pass + + +# Lazy registration on first use +_writers_registered = False + + +def _ensure_writers_registered(): + """Ensure built-in writers are registered.""" + global _writers_registered + if not _writers_registered: + _register_builtin_writers() + _writers_registered = True + + +# Override get_writer to ensure registration +_original_get_writer = get_writer + + +def get_writer(fmt: OutputFormat | str, **init_kwargs: Any) -> OutputWriter: + """Get an output writer for the specified format.""" + _ensure_writers_registered() + return _original_get_writer(fmt, **init_kwargs) diff --git a/src/segger/export/sopa_compat.py b/src/segger/export/sopa_compat.py new file mode 100644 index 0000000..230157c --- /dev/null +++ b/src/segger/export/sopa_compat.py @@ -0,0 +1,396 @@ +"""SOPA compatibility utilities for SpatialData export. 
+ +SOPA (Spatial Omics Pipeline Architecture) is a framework for spatial omics +analysis built on SpatialData. This module provides utilities to ensure +Segger output is compatible with SOPA workflows. + +SOPA Conventions +---------------- +- shapes[cell_key]: Cell polygons with 'cell_id' column +- points[transcript_key]: Transcripts with 'cell_id' assignment column +- No images required for segmentation workflows +- Cell IDs should be consistent between shapes and points + +Usage +----- +>>> from segger.export.sopa_compat import validate_sopa_compatibility +>>> issues = validate_sopa_compatibility(sdata) +>>> if not issues: +... print("SpatialData is SOPA-compatible") + +>>> from segger.export.sopa_compat import export_for_sopa +>>> path = export_for_sopa(sdata, Path("output/sopa_compatible.zarr")) + +Installation +------------ +Requires the spatialdata optional dependency: + pip install segger[spatialdata] + +For full SOPA integration: + pip install segger[sopa] +""" + +from __future__ import annotations + +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Optional + +import polars as pl + +from segger.utils.optional_deps import ( + SPATIALDATA_AVAILABLE, + SOPA_AVAILABLE, + require_spatialdata, + warn_sopa_unavailable, +) + +if TYPE_CHECKING: + import geopandas as gpd + from spatialdata import SpatialData + + +# SOPA expected keys and columns +SOPA_DEFAULT_CELL_KEY = "cells" +SOPA_DEFAULT_TRANSCRIPT_KEY = "transcripts" +SOPA_CELL_ID_COLUMN = "cell_id" + + +def validate_sopa_compatibility( + sdata: "SpatialData", + cell_key: str = SOPA_DEFAULT_CELL_KEY, + transcript_key: str = SOPA_DEFAULT_TRANSCRIPT_KEY, +) -> list[str]: + """Validate SpatialData object for SOPA compatibility. + + Checks that the SpatialData object follows SOPA conventions: + - Cell shapes exist with cell_id column + - Transcripts exist with cell_id assignment column + - Cell IDs are consistent between shapes and points + + Parameters + ---------- + sdata + SpatialData object to validate. + cell_key + Expected key for cell shapes. Default "cells". + transcript_key + Expected key for transcripts. Default "transcripts". + + Returns + ------- + list[str] + List of compatibility issues (empty if fully compatible). + + Examples + -------- + >>> issues = validate_sopa_compatibility(sdata) + >>> if issues: + ... for issue in issues: + ... print(f"- {issue}") + """ + require_spatialdata() + + issues = [] + + # Check for cell shapes + if cell_key not in sdata.shapes: + issues.append( + f"Missing cell shapes: expected shapes['{cell_key}']. " + f"Available shapes: {list(sdata.shapes.keys())}" + ) + else: + cells = sdata.shapes[cell_key] + if SOPA_CELL_ID_COLUMN not in cells.columns: + issues.append( + f"Cell shapes missing '{SOPA_CELL_ID_COLUMN}' column. " + f"Available columns: {list(cells.columns)}" + ) + + # Check for transcripts + if transcript_key not in sdata.points: + issues.append( + f"Missing transcripts: expected points['{transcript_key}']. " + f"Available points: {list(sdata.points.keys())}" + ) + else: + transcripts = sdata.points[transcript_key] + # Get column names from Dask DataFrame + if hasattr(transcripts, "columns"): + tx_columns = list(transcripts.columns) + else: + tx_columns = [] + + if SOPA_CELL_ID_COLUMN not in tx_columns: + # Check for alternative names + alt_names = ["segger_cell_id", "seg_cell_id", "cell"] + found = [c for c in alt_names if c in tx_columns] + if found: + issues.append( + f"Transcripts use '{found[0]}' instead of '{SOPA_CELL_ID_COLUMN}'. 
" + "SOPA expects 'cell_id' column for assignments." + ) + else: + issues.append( + f"Transcripts missing '{SOPA_CELL_ID_COLUMN}' column. " + f"Available columns: {tx_columns}" + ) + + # Check cell ID consistency + if cell_key in sdata.shapes and transcript_key in sdata.points: + try: + cells = sdata.shapes[cell_key] + transcripts = sdata.points[transcript_key] + + if SOPA_CELL_ID_COLUMN in cells.columns: + cell_ids_shapes = set(cells[SOPA_CELL_ID_COLUMN].unique()) + + if hasattr(transcripts, "compute"): + tx_computed = transcripts.compute() + else: + tx_computed = transcripts + + if SOPA_CELL_ID_COLUMN in tx_computed.columns: + cell_ids_tx = set( + tx_computed[SOPA_CELL_ID_COLUMN].dropna().unique() + ) + # Filter out unassigned (-1 or negative) + cell_ids_tx = {c for c in cell_ids_tx if c >= 0} + + missing_in_shapes = cell_ids_tx - cell_ids_shapes + if missing_in_shapes: + issues.append( + f"Cell IDs in transcripts not found in shapes: " + f"{len(missing_in_shapes)} IDs missing" + ) + except Exception as e: + issues.append(f"Could not verify cell ID consistency: {e}") + + return issues + + +def export_for_sopa( + sdata: "SpatialData", + output_path: Path, + cell_key: str = SOPA_DEFAULT_CELL_KEY, + transcript_key: str = SOPA_DEFAULT_TRANSCRIPT_KEY, + rename_cell_id: bool = True, + overwrite: bool = False, +) -> Path: + """Export SpatialData in SOPA-expected structure. + + Ensures the output follows SOPA conventions: + - shapes[cell_key]: Cell polygons with 'cell_id' column + - points[transcript_key]: Transcripts with 'cell_id' assignment + + Parameters + ---------- + sdata + SpatialData object to export. + output_path + Path for output .zarr store. + cell_key + Key for cell shapes. Default "cells". + transcript_key + Key for transcripts. Default "transcripts". + rename_cell_id + If True, rename 'segger_cell_id' to 'cell_id' for SOPA. + overwrite + Whether to overwrite existing output. + + Returns + ------- + Path + Path to exported .zarr store. + + Examples + -------- + >>> path = export_for_sopa(sdata, Path("output/sopa_ready.zarr")) + """ + require_spatialdata() + import spatialdata + + output_path = Path(output_path) + + if output_path.exists() and not overwrite: + raise FileExistsError( + f"Output exists: {output_path}. Use overwrite=True to replace." 
+ ) + + # Create a modified copy for SOPA compatibility + elements = {} + + # Process points (transcripts) + for key in sdata.points: + points = sdata.points[key] + + # Rename to expected key if needed + target_key = transcript_key if key == list(sdata.points.keys())[0] else key + + # Rename cell_id column if needed + if rename_cell_id and hasattr(points, "columns"): + if "segger_cell_id" in points.columns and SOPA_CELL_ID_COLUMN not in points.columns: + points = points.rename(columns={"segger_cell_id": SOPA_CELL_ID_COLUMN}) + + elements[f"points/{target_key}"] = points + + # Process shapes + for key in sdata.shapes: + shapes = sdata.shapes[key] + + # Rename to expected key if needed + target_key = cell_key if key == list(sdata.shapes.keys())[0] else key + + # Ensure cell_id column exists + if SOPA_CELL_ID_COLUMN not in shapes.columns: + if "segger_cell_id" in shapes.columns: + shapes = shapes.rename(columns={"segger_cell_id": SOPA_CELL_ID_COLUMN}) + elif shapes.index.name: + shapes = shapes.reset_index() + if shapes.columns[0] != SOPA_CELL_ID_COLUMN: + shapes = shapes.rename(columns={shapes.columns[0]: SOPA_CELL_ID_COLUMN}) + + elements[f"shapes/{target_key}"] = shapes + + # Create new SpatialData + sdata_sopa = spatialdata.SpatialData.from_elements_dict(elements) + + # Write + if output_path.exists(): + import shutil + shutil.rmtree(output_path) + + sdata_sopa.write(output_path) + + return output_path + + +def sopa_to_segger_input( + sopa_sdata: "SpatialData", + cell_key: str = SOPA_DEFAULT_CELL_KEY, + transcript_key: str = SOPA_DEFAULT_TRANSCRIPT_KEY, +) -> tuple[pl.LazyFrame, "gpd.GeoDataFrame"]: + """Convert SOPA SpatialData to Segger internal format. + + Enables round-trip: SOPA → Segger → SOPA + + Parameters + ---------- + sopa_sdata + SOPA-formatted SpatialData object. + cell_key + Key for cell shapes. + transcript_key + Key for transcripts. + + Returns + ------- + tuple[pl.LazyFrame, gpd.GeoDataFrame] + (transcripts, boundaries) in Segger internal format. + + Examples + -------- + >>> transcripts, boundaries = sopa_to_segger_input(sdata) + >>> # Run Segger segmentation + >>> predictions = segment(transcripts, boundaries) + >>> # Export back to SOPA format + >>> export_for_sopa(results, "output.zarr") + """ + require_spatialdata() + import geopandas as gpd + + # Extract transcripts + if transcript_key not in sopa_sdata.points: + available = list(sopa_sdata.points.keys()) + raise ValueError( + f"Transcript key '{transcript_key}' not found. 
Available: {available}" + ) + + points = sopa_sdata.points[transcript_key] + + # Convert to Polars + if hasattr(points, "compute"): + points_pd = points.compute() + else: + points_pd = points + + transcripts = pl.from_pandas(points_pd).lazy() + + # Normalize column names + column_map = { + SOPA_CELL_ID_COLUMN: "cell_id", + } + for old, new in column_map.items(): + if old in transcripts.collect_schema().names() and old != new: + transcripts = transcripts.rename({old: new}) + + # Add row_index if missing + schema = transcripts.collect_schema() + if "row_index" not in schema.names(): + transcripts = transcripts.with_row_index(name="row_index") + + # Extract boundaries + boundaries = None + if cell_key in sopa_sdata.shapes: + boundaries = sopa_sdata.shapes[cell_key].copy() + + # Normalize cell_id column + if SOPA_CELL_ID_COLUMN not in boundaries.columns: + if boundaries.index.name: + boundaries = boundaries.reset_index() + boundaries = boundaries.rename( + columns={boundaries.columns[0]: SOPA_CELL_ID_COLUMN} + ) + + return transcripts, boundaries + + +def check_sopa_installation() -> dict[str, bool]: + """Check SOPA and related package installation status. + + Returns + ------- + dict[str, bool] + Dictionary with package names and installation status. + """ + status = { + "spatialdata": SPATIALDATA_AVAILABLE, + "sopa": SOPA_AVAILABLE, + } + + # Check spatialdata-io + try: + import spatialdata_io # noqa: F401 + status["spatialdata_io"] = True + except ImportError: + status["spatialdata_io"] = False + + return status + + +def get_sopa_installation_instructions() -> str: + """Get installation instructions for SOPA integration. + + Returns + ------- + str + Installation instructions. + """ + status = check_sopa_installation() + + lines = ["SOPA Integration Installation Status:", ""] + + for pkg, installed in status.items(): + mark = "✓" if installed else "✗" + lines.append(f" {mark} {pkg}: {'installed' if installed else 'not installed'}") + + lines.append("") + lines.append("To install all SOPA dependencies:") + lines.append(" pip install segger[spatialdata-all]") + lines.append("") + lines.append("Or install individually:") + lines.append(" pip install spatialdata>=0.2.0") + lines.append(" pip install spatialdata-io>=0.1.0") + lines.append(" pip install sopa>=1.0.0") + + return "\n".join(lines) diff --git a/src/segger/export/spatialdata_writer.py b/src/segger/export/spatialdata_writer.py new file mode 100644 index 0000000..18720c5 --- /dev/null +++ b/src/segger/export/spatialdata_writer.py @@ -0,0 +1,535 @@ +"""Write segmentation results as SpatialData Zarr stores. + +This writer creates SpatialData-compatible Zarr stores containing: +- points["transcripts"]: Transcripts with segger_cell_id column +- shapes["cells"]: Cell boundaries (optional, can be input or generated) +- tables["cell_table"]: AnnData table with cell x gene counts (optional) + +NO images are included (per requirements). + +Usage +----- +>>> from segger.export.spatialdata_writer import SpatialDataWriter +>>> writer = SpatialDataWriter() +>>> output_path = writer.write( +... predictions=predictions, +... transcripts=transcripts, +... output_dir=Path("output/"), +... boundaries=boundaries, # Optional +... 
) + +Installation +------------ +Requires the spatialdata optional dependency: + pip install segger[spatialdata] +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Optional + +import polars as pl + +from segger.utils.optional_deps import ( + require_spatialdata, +) +from segger.export.output_formats import OutputFormat, register_writer +from segger.export.anndata_writer import build_anndata_table + +if TYPE_CHECKING: + import geopandas as gpd + from spatialdata import SpatialData + + +@register_writer(OutputFormat.SPATIALDATA) +class SpatialDataWriter: + """Write segmentation results as SpatialData Zarr store. + + Creates a SpatialData object with: + - points["transcripts"]: Transcripts with cell assignments + - shapes["cells"]: Cell boundaries (if provided or generated) + + Parameters + ---------- + include_boundaries + Whether to include cell shapes in output. Default True. + boundary_method + How to generate boundaries if not provided: + - "input": Use input boundaries if available + - "convex_hull": Generate convex hull per cell + - "delaunay": Delaunay triangulation-based boundary extraction + - "skip": Don't include shapes + boundary_n_jobs + Parallel workers for Delaunay boundary generation (threads). + points_key + Key for transcripts in sdata.points. Default "transcripts". + shapes_key + Key for cell shapes in sdata.shapes. Default "cells". + include_table + Whether to include AnnData table in sdata.tables. Default True. + table_key + Key for AnnData table in sdata.tables. Default "cell_table". + table_region_key + Column in shapes that identifies cells. Default "cell_id". + """ + + def __init__( + self, + include_boundaries: bool = True, + boundary_method: Literal["input", "convex_hull", "delaunay", "skip"] = "input", + boundary_n_jobs: int = 1, + points_key: str = "transcripts", + shapes_key: str = "cells", + include_table: bool = True, + table_key: str = "cell_table", + table_region_key: str = "cell_id", + ): + require_spatialdata() + + self.include_boundaries = include_boundaries + self.boundary_method = boundary_method + self.boundary_n_jobs = boundary_n_jobs + self.points_key = points_key + self.shapes_key = shapes_key + self.include_table = include_table + self.table_key = table_key + self.table_region_key = table_region_key + + def write( + self, + predictions: pl.DataFrame, + output_dir: Path, + transcripts: Optional[pl.DataFrame] = None, + boundaries: Optional["gpd.GeoDataFrame"] = None, + output_name: str = "segmentation.zarr", + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + feature_column: str = "feature_name", + x_column: str = "x", + y_column: str = "y", + z_column: Optional[str] = "z", + overwrite: bool = False, + **kwargs, + ) -> Path: + """Write segmentation results to SpatialData Zarr store. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + output_dir + Output directory. + transcripts + Original transcripts DataFrame. Required for SPATIALDATA format. + boundaries + Cell boundaries GeoDataFrame. Optional. + output_name + Output Zarr store name. Default "segmentation.zarr". + row_index_column + Column name for row index. + cell_id_column + Column name for cell ID in predictions. + similarity_column + Column name for similarity in predictions. + feature_column + Column name for gene/feature in transcripts. + x_column + Column name for x-coordinate. 
+ y_column + Column name for y-coordinate. + z_column + Column name for z-coordinate (optional). + overwrite + Whether to overwrite existing Zarr store. + + Returns + ------- + Path + Path to the written .zarr store. + + Raises + ------ + ValueError + If transcripts are not provided. + """ + if transcripts is None: + raise ValueError( + "SpatialData format requires transcripts DataFrame. " + "Pass 'transcripts' parameter to write()." + ) + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / output_name + + # Check if exists + if output_path.exists() and not overwrite: + raise FileExistsError( + f"Output path exists: {output_path}. " + "Use overwrite=True to replace." + ) + + # Merge predictions with transcripts + merged = self._merge_predictions( + predictions=predictions, + transcripts=transcripts, + row_index_column=row_index_column, + cell_id_column=cell_id_column, + similarity_column=similarity_column, + ) + + # Create SpatialData object + sdata = self._create_spatialdata( + transcripts=merged, + boundaries=boundaries, + x_column=x_column, + y_column=y_column, + z_column=z_column, + cell_id_column=cell_id_column, + feature_column=feature_column, + ) + + # Write to Zarr + self._write_spatialdata_zarr( + sdata=sdata, + output_path=output_path, + overwrite=overwrite, + ) + + return output_path + + def _merge_predictions( + self, + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + row_index_column: str, + cell_id_column: str, + similarity_column: str, + ) -> pl.DataFrame: + """Merge predictions with transcripts.""" + # Prepare predictions + pred_cols = [row_index_column, cell_id_column] + if similarity_column in predictions.columns: + pred_cols.append(similarity_column) + + pred_subset = predictions.select(pred_cols) + + # Add row_index if missing + if row_index_column not in transcripts.columns: + transcripts = transcripts.with_row_index(name=row_index_column) + + # Join + merged = transcripts.join(pred_subset, on=row_index_column, how="left") + + # Fill unassigned with -1 + merged = merged.with_columns( + pl.col(cell_id_column).fill_null(-1) + ) + if similarity_column in merged.columns: + merged = merged.with_columns( + pl.col(similarity_column).fill_null(0.0) + ) + + return merged + + def _create_spatialdata( + self, + transcripts: pl.DataFrame, + boundaries: Optional["gpd.GeoDataFrame"], + x_column: str, + y_column: str, + z_column: Optional[str], + cell_id_column: str, + feature_column: str, + ) -> "SpatialData": + """Create SpatialData object from transcripts and boundaries.""" + import spatialdata + from spatialdata.models import PointsModel, ShapesModel, TableModel + import dask.dataframe as dd + + identity = self._identity_transform() + transformations = {"global": identity} if identity is not None else None + + # Convert transcripts to pandas for SpatialData + tx_pd = transcripts.to_pandas() + + # SOPA expects "cell_id" assignment in points. 
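+        # Duplicate the segger column under the conventional name (only when
+        # "cell_id" is not already present); both refer to the same assignments.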
+ if cell_id_column in tx_pd.columns and "cell_id" not in tx_pd.columns: + tx_pd["cell_id"] = tx_pd[cell_id_column] + + # Check for z-coordinate + has_z = z_column and z_column in tx_pd.columns + + # Create points element + # SpatialData expects coordinates in specific columns + coords_cols = [x_column, y_column] + if has_z: + coords_cols.append(z_column) + + # Ensure coordinates are float + for col in coords_cols: + if col in tx_pd.columns: + tx_pd[col] = tx_pd[col].astype(float) + + # Create Dask DataFrame for points + tx_dask = dd.from_pandas(tx_pd, npartitions=1) + + # Points element + points_parse_kwargs = { + "coordinates": { + "x": x_column, + "y": y_column, + **({"z": z_column} if has_z else {}), + }, + } + if transformations is not None: + points_parse_kwargs["transformations"] = transformations + + points = PointsModel.parse(tx_dask, **points_parse_kwargs) + points_elements = {self.points_key: points} + shapes_elements = {} + + # Shapes element (if boundaries provided or generated) + if self.include_boundaries and self.boundary_method != "skip": + shapes = self._get_boundaries( + transcripts=tx_pd, + boundaries=boundaries, + x_column=x_column, + y_column=y_column, + cell_id_column=cell_id_column, + ) + if shapes is not None and len(shapes) > 0: + shapes_parse_kwargs = {} + if transformations is not None: + shapes_parse_kwargs["transformations"] = transformations + shapes_parsed = ShapesModel.parse(shapes, **shapes_parse_kwargs) + shapes_elements[self.shapes_key] = shapes_parsed + + tables_elements = {} + + # Optional AnnData table + if self.include_table: + region = self.shapes_key if self.shapes_key in shapes_elements else None + instance_key = self.table_region_key if region is not None else None + table = build_anndata_table( + transcripts=transcripts, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + unassigned_value=-1, + region=None, + region_key=None, + obs_index_as_str=True, + ) + if region is not None: + table.obs["region"] = region + if instance_key and instance_key not in table.obs.columns: + table.obs[instance_key] = table.obs.index.astype(str) + try: + table = TableModel.parse( + table, + region=region, + region_key="region", + instance_key=instance_key or "instance_id", + ) + except Exception: + pass + tables_elements[self.table_key] = table + + # Create SpatialData (prefer modern constructor methods, keep fallback) + sdata = self._build_spatialdata( + spatialdata=spatialdata, + points=points_elements, + shapes=shapes_elements, + tables=tables_elements, + ) + + return sdata + + def _identity_transform(self): + """Return SpatialData identity transform when available.""" + try: + from spatialdata.transformations import Identity + return Identity() + except Exception: + return None + + def _build_spatialdata(self, spatialdata, points: dict, shapes: dict, tables: dict): + """Build a SpatialData object across SpatialData API variants.""" + shapes_arg = shapes or None + tables_arg = tables or None + + if hasattr(spatialdata.SpatialData, "init_from_elements"): + return spatialdata.SpatialData.init_from_elements( + points=points, + shapes=shapes_arg, + tables=tables_arg, + ) + + try: + return spatialdata.SpatialData( + points=points, + shapes=shapes_arg, + tables=tables_arg, + ) + except Exception: + elements = {} + for key, value in points.items(): + elements[f"points/{key}"] = value + for key, value in shapes.items(): + elements[f"shapes/{key}"] = value + sdata = 
spatialdata.SpatialData.from_elements_dict(elements) + for key, value in (tables or {}).items(): + sdata.tables[key] = value + return sdata + + def _write_spatialdata_zarr(self, sdata, output_path: Path, overwrite: bool) -> None: + """Write SpatialData object with compatibility fallback.""" + try: + sdata.write(output_path, overwrite=overwrite) + return + except TypeError: + pass + + if output_path.exists(): + import shutil + shutil.rmtree(output_path) + sdata.write(output_path) + + def _get_boundaries( + self, + transcripts: "pd.DataFrame", + boundaries: Optional["gpd.GeoDataFrame"], + x_column: str, + y_column: str, + cell_id_column: str, + ) -> Optional["gpd.GeoDataFrame"]: + """Get or generate cell boundaries.""" + import geopandas as gpd + import pandas as pd + from shapely.geometry import MultiPoint + + def _ensure_cell_id(gdf: "gpd.GeoDataFrame") -> "gpd.GeoDataFrame": + if "cell_id" in gdf.columns: + return gdf + if cell_id_column in gdf.columns: + gdf = gdf.copy() + gdf["cell_id"] = gdf[cell_id_column] + return gdf + gdf = gdf.reset_index(drop=False) + if "cell_id" not in gdf.columns and len(gdf.columns) > 0: + gdf["cell_id"] = gdf[gdf.columns[0]] + return gdf + + # Use input boundaries if available + if boundaries is not None: + return _ensure_cell_id(boundaries) + + # Generate boundaries based on method + if self.boundary_method == "input": + # No input boundaries, skip + return None + + elif self.boundary_method == "convex_hull": + # Generate convex hulls from transcript positions + assigned = transcripts[transcripts[cell_id_column] != -1].copy() + + if len(assigned) == 0: + return None + + # Group by cell and create convex hulls + hulls = [] + cell_ids = [] + + for cell_id, group in assigned.groupby(cell_id_column): + if len(group) < 3: + continue # Need at least 3 points for convex hull + + points = list(zip(group[x_column], group[y_column])) + mp = MultiPoint(points) + hull = mp.convex_hull + + if not hull.is_empty: + hulls.append(hull) + cell_ids.append(cell_id) + + if not hulls: + return None + + return _ensure_cell_id(gpd.GeoDataFrame( + {"cell_id": cell_ids}, + geometry=hulls, + )) + + elif self.boundary_method == "delaunay": + from segger.export.boundary import generate_boundaries + + assigned = transcripts[transcripts[cell_id_column] != -1].copy() + if len(assigned) == 0: + return None + + boundaries_gdf = generate_boundaries( + assigned, + x=x_column, + y=y_column, + cell_id=cell_id_column, + n_jobs=self.boundary_n_jobs, + ) + if boundaries_gdf is None or len(boundaries_gdf) == 0: + return None + return _ensure_cell_id(boundaries_gdf) + + return None + + +def write_spatialdata( + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + output_dir: Path, + boundaries: Optional["gpd.GeoDataFrame"] = None, + output_name: str = "segmentation.zarr", + **kwargs, +) -> Path: + """Convenience function to write SpatialData output. + + Parameters + ---------- + predictions + Segmentation predictions. + transcripts + Original transcripts. + output_dir + Output directory. + boundaries + Cell boundaries (optional). + output_name + Output filename. + **kwargs + Additional arguments passed to SpatialDataWriter.write(). + + Returns + ------- + Path + Path to written .zarr store. + + Examples + -------- + >>> path = write_spatialdata( + ... predictions=preds, + ... transcripts=tx, + ... output_dir=Path("output/"), + ... 
) + """ + writer = SpatialDataWriter() + return writer.write( + predictions=predictions, + output_dir=output_dir, + transcripts=transcripts, + boundaries=boundaries, + output_name=output_name, + **kwargs, + ) diff --git a/src/segger/export/xenium.py b/src/segger/export/xenium.py new file mode 100644 index 0000000..3ab1ebe --- /dev/null +++ b/src/segger/export/xenium.py @@ -0,0 +1,862 @@ +"""Xenium Explorer export functionality. + +This module converts segmentation results into Xenium Explorer-compatible +Zarr format for visualization and validation. +""" + +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +import json +from concurrent.futures.process import BrokenProcessPool + +import numpy as np +import pandas as pd +import polars as pl +import zarr +from pqdm.processes import pqdm as pqdm_processes +try: + from pqdm.threads import pqdm as pqdm_threads +except Exception: + pqdm_threads = None +from shapely.geometry import MultiPoint, MultiPolygon, Polygon +from tqdm import tqdm +from zarr.storage import ZipStore + +from .boundary import extract_largest_polygon, generate_boundary + + +def _normalize_polygon_vertices( + polygon: Polygon, + max_vertices: int, +) -> Tuple[List[Tuple[float, float]], int]: + """Normalize polygon vertices to a fixed length with closure. + + Returns a list of vertices padded/truncated to ``max_vertices`` and the + true number of vertices including the closing vertex. + """ + coords = list(polygon.exterior.coords) + # Remove duplicate closing vertex + if coords[0] == coords[-1]: + coords = coords[:-1] + + if len(coords) < 3: + return [], 0 + + num_vertices = len(coords) + 1 # include closing vertex + target = max_vertices - 1 + + if len(coords) > target: + indices = np.linspace(0, len(coords) - 1, target, dtype=int) + coords = [coords[i] for i in indices] + + # Close polygon and pad + coords.append(coords[0]) + if len(coords) < max_vertices: + coords += [coords[0]] * (max_vertices - len(coords)) + + return coords, num_vertices + + +def _safe_boundary_polygon( + seg_cell: pd.DataFrame, + x: str, + y: str, + boundary_method: str = "delaunay", + boundary_voxel_size: float = 0.0, +) -> Optional[Polygon]: + """Generate a robust polygon boundary for a cell. + + Uses the requested boundary method with robust fallbacks. 
+ """ + if boundary_method in {"convex_hull", "input"}: + mp = MultiPoint(seg_cell[[x, y]].values) + cell_poly = mp.convex_hull if not mp.is_empty else None + elif boundary_method == "voxel": + if boundary_voxel_size <= 0: + return None + points = seg_cell[[x, y]].to_numpy(dtype=np.float64) + if len(points) < 3: + return None + mins = points.min(axis=0) + bins = np.floor((points - mins) / boundary_voxel_size).astype(np.int64) + _, keep = np.unique(bins, axis=0, return_index=True) + reduced = points[np.sort(keep)] + if len(reduced) < 3: + return None + mp = MultiPoint(reduced) + cell_poly = mp.convex_hull if not mp.is_empty else None + else: + working = seg_cell + if boundary_voxel_size > 0: + points = seg_cell[[x, y]].to_numpy(dtype=np.float64) + mins = points.min(axis=0) + bins = np.floor((points - mins) / boundary_voxel_size).astype(np.int64) + _, keep = np.unique(bins, axis=0, return_index=True) + working = seg_cell.iloc[np.sort(keep)] + + try: + cell_poly = generate_boundary(working, x=x, y=y) + if isinstance(cell_poly, MultiPolygon): + cell_poly = extract_largest_polygon(cell_poly) + except Exception: + cell_poly = None + + if cell_poly is None or not isinstance(cell_poly, Polygon) or cell_poly.is_empty: + # Fallback: convex hull of points + mp = MultiPoint(seg_cell[[x, y]].values) + cell_poly = mp.convex_hull if not mp.is_empty else None + + if cell_poly is None or not isinstance(cell_poly, Polygon) or cell_poly.is_empty: + return None + + return cell_poly + + +def _prepare_input_boundaries( + boundaries, + boundary_id_column: str = "cell_id", + boundary_type_column: str = "boundary_type", + boundary_cell_value: str = "cell", + boundary_nucleus_value: str = "nucleus", +) -> Tuple[Dict[Any, Polygon], Dict[Any, Polygon]]: + """Prepare lookup tables for input cell/nucleus boundaries.""" + if boundaries is None: + return {}, {} + + gdf = boundaries + if boundary_id_column not in gdf.columns: + if gdf.index.name == boundary_id_column: + gdf = gdf.reset_index() + else: + return {}, {} + + def _pick_largest(group): + largest = None + max_area = -1.0 + for geom in group.geometry: + if geom is None or getattr(geom, "is_empty", True): + continue + if isinstance(geom, MultiPolygon): + geom = extract_largest_polygon(geom) + if not isinstance(geom, Polygon) or geom is None or geom.is_empty: + continue + area = float(geom.area) + if area > max_area: + max_area = area + largest = geom + return largest + + if boundary_type_column in gdf.columns: + cells = gdf[gdf[boundary_type_column] == boundary_cell_value] + nuclei = gdf[gdf[boundary_type_column] == boundary_nucleus_value] + else: + cells = gdf + nuclei = gdf.iloc[0:0] + + cell_lookup: Dict[Any, Polygon] = {} + for cell_id, group in cells.groupby(boundary_id_column): + poly = _pick_largest(group) + if poly is not None: + cell_lookup[cell_id] = poly + + nucleus_lookup: Dict[Any, Polygon] = {} + for cell_id, group in nuclei.groupby(boundary_id_column): + poly = _pick_largest(group) + if poly is not None: + nucleus_lookup[cell_id] = poly + + return cell_lookup, nucleus_lookup + + +def get_indices_indptr(input_array: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get sparse matrix representation for cluster assignments. + + Parameters + ---------- + input_array : np.ndarray + Array of cluster labels. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + Indices and indptr arrays for CSR-like representation. 
+ """ + clusters = sorted(np.unique(input_array[input_array != 0])) + indptr = np.zeros(len(clusters), dtype=np.uint32) + indices = [] + + for cluster in clusters: + cluster_indices = np.where(input_array == cluster)[0] + indptr[cluster - 1] = len(indices) + indices.extend(cluster_indices) + + indices.extend(-np.zeros(len(input_array[input_array == 0]))) + indices = np.array(indices, dtype=np.int32).astype(np.uint32) + return indices, indptr + + +def generate_experiment_file( + template_path: Path, + output_path: Path, + cells_name: str = "seg_cells", + analysis_name: str = "seg_analysis", +) -> None: + """Generate Xenium experiment manifest file. + + Parameters + ---------- + template_path : Path + Path to template experiment.xenium file. + output_path : Path + Path for output experiment file. + cells_name : str + Name of cells Zarr file (without extension). + analysis_name : str + Name of analysis Zarr file (without extension). + Notes + ----- + We only replace the cells and analysis Zarr paths, preserving all other + entries (including morphology image references). This keeps multi-channel + morphology_focus image stacks intact for segmentation kit datasets. + """ + with open(template_path) as f: + experiment = json.load(f) + + experiment["xenium_explorer_files"]["cells_zarr_filepath"] = f"{cells_name}.zarr.zip" + experiment["xenium_explorer_files"].pop("cell_features_zarr_filepath", None) + experiment["xenium_explorer_files"]["analysis_zarr_filepath"] = f"{analysis_name}.zarr.zip" + + with open(output_path, "w") as f: + json.dump(experiment, f, indent=2) + + +def seg2explorer( + seg_df: Union[pd.DataFrame, pl.DataFrame], + source_path: Union[str, Path], + output_dir: Union[str, Path], + cells_filename: str = "seg_cells", + analysis_filename: str = "seg_analysis", + xenium_filename: str = "seg_experiment.xenium", + analysis_df: Optional[pd.DataFrame] = None, + cell_id_column: str = "seg_cell_id", + x_column: str = "x", + y_column: str = "y", + z_column: Optional[str] = "z", + nucleus_column: Optional[str] = "cell_compartment", + nucleus_value: int = 2, + area_low: float = 10, + area_high: float = 100, + polygon_max_vertices: int = 13, + boundary_method: str = "delaunay", + boundary_voxel_size: float = 0.0, + boundaries: Optional["gpd.GeoDataFrame"] = None, + boundary_id_column: str = "cell_id", + boundary_type_column: str = "boundary_type", + boundary_cell_value: str = "cell", + boundary_nucleus_value: str = "nucleus", + cell_id_columns: Optional[str] = None, +) -> None: + """Convert segmentation results to Xenium Explorer format. + + Parameters + ---------- + seg_df : Union[pd.DataFrame, pl.DataFrame] + Segmented transcript DataFrame with cell assignments. + source_path : Union[str, Path] + Path to source Xenium data directory. + output_dir : Union[str, Path] + Output directory for Zarr files. + cells_filename : str + Filename prefix for cells Zarr. + analysis_filename : str + Filename prefix for analysis Zarr. + xenium_filename : str + Filename for experiment manifest. + analysis_df : Optional[pd.DataFrame] + Optional clustering/annotation DataFrame. + cell_id_column : str + Column name for cell IDs. + x_column : str + Column name for x coordinates. + y_column : str + Column name for y coordinates. + z_column : Optional[str] + Column name for z coordinates (if available). + nucleus_column : Optional[str] + Column name for nucleus/compartment assignment. + nucleus_value : int + Value indicating nuclear compartment. + area_low : float + Minimum cell area threshold. 
+ area_high : float + Maximum cell area threshold. + polygon_max_vertices : int + Maximum number of vertices per polygon (including closure). + """ + if cell_id_columns is not None: + cell_id_column = cell_id_columns + + if boundary_method == "skip": + raise ValueError("boundary_method='skip' is not supported for Xenium export.") + + # Convert Polars to pandas + if isinstance(seg_df, pl.DataFrame): + seg_df = seg_df.to_pandas() + + source_path = Path(source_path) + storage = Path(output_dir) + storage.mkdir(parents=True, exist_ok=True) + + cell_boundaries: Dict[Any, Polygon] = {} + nucleus_boundaries: Dict[Any, Polygon] = {} + if boundary_method == "input": + cell_boundaries, nucleus_boundaries = _prepare_input_boundaries( + boundaries=boundaries, + boundary_id_column=boundary_id_column, + boundary_type_column=boundary_type_column, + boundary_cell_value=boundary_cell_value, + boundary_nucleus_value=boundary_nucleus_value, + ) + + # Drop unassigned cells if numeric + if cell_id_column in seg_df.columns: + if pd.api.types.is_numeric_dtype(seg_df[cell_id_column]): + seg_df = seg_df[seg_df[cell_id_column] >= 0] + else: + seg_df = seg_df[seg_df[cell_id_column].notna()] + + cell_id2old_id: Dict[int, Any] = {} + cell_id: List[int] = [] + cell_summary_rows: List[List[float]] = [] + cell_num_vertices: List[int] = [] + nucleus_num_vertices: List[int] = [] + cell_vertices: List[List[Tuple[float, float]]] = [] + nucleus_vertices: List[List[Tuple[float, float]]] = [] + + grouped_by = seg_df.groupby(cell_id_column) + + for cell_incremental_id, (seg_cell_id, seg_cell) in tqdm( + enumerate(grouped_by), total=len(grouped_by), desc="Processing cells" + ): + if len(seg_cell) < 5: + continue + + if boundary_method == "input" and cell_boundaries: + cell_poly = cell_boundaries.get(seg_cell_id) + else: + fallback_method = "delaunay" if boundary_method == "input" else boundary_method + cell_poly = _safe_boundary_polygon( + seg_cell, + x=x_column, + y=y_column, + boundary_method=fallback_method, + boundary_voxel_size=boundary_voxel_size, + ) + if cell_poly is None or not (area_low <= cell_poly.area <= area_high): + continue + + # Nucleus polygon (optional) + nucleus_poly = None + if boundary_method == "input" and nucleus_boundaries: + nucleus_poly = nucleus_boundaries.get(seg_cell_id) + elif nucleus_column is not None and nucleus_column in seg_cell.columns: + seg_nucleus = seg_cell[seg_cell[nucleus_column] == nucleus_value] + if len(seg_nucleus) >= 3: + nucleus_poly = MultiPoint(seg_nucleus[[x_column, y_column]].values).convex_hull + if isinstance(nucleus_poly, MultiPolygon): + nucleus_poly = extract_largest_polygon(nucleus_poly) + if not isinstance(nucleus_poly, Polygon) or nucleus_poly.is_empty: + nucleus_poly = None + + cell_coords, cell_nv = _normalize_polygon_vertices(cell_poly, polygon_max_vertices) + if cell_nv == 0: + continue + + zero_vertices = [(0.0, 0.0)] * polygon_max_vertices + if nucleus_poly is not None: + nuc_coords, nuc_nv = _normalize_polygon_vertices(nucleus_poly, polygon_max_vertices) + else: + nuc_coords, nuc_nv = zero_vertices, 0 + + uint_cell_id = cell_incremental_id + 1 + cell_id2old_id[uint_cell_id] = seg_cell_id + cell_id.append(uint_cell_id) + + # Compute z-level if available + z_level = 0.0 + if z_column is not None and z_column in seg_cell.columns: + z_level = (seg_cell[z_column].mean() // 3) * 3 + + cell_centroid = cell_poly.centroid + nucleus_centroid = nucleus_poly.centroid if nucleus_poly is not None else None + + cell_summary_rows.append([ + float(cell_centroid.x), + 
float(cell_centroid.y), + float(cell_poly.area), + float(nucleus_centroid.x) if nucleus_centroid is not None else 0.0, + float(nucleus_centroid.y) if nucleus_centroid is not None else 0.0, + float(nucleus_poly.area) if nucleus_poly is not None else 0.0, + float(z_level), + float(1 if nucleus_poly is not None else 0), + ]) + + cell_num_vertices.append(cell_nv) + nucleus_num_vertices.append(nuc_nv) + cell_vertices.append(cell_coords) + nucleus_vertices.append(nuc_coords) + + if len(cell_id) == 0: + raise ValueError("No valid cells found in segmentation data.") + + n_cells = len(cell_id) + cell_vertices_arr = np.array(cell_vertices, dtype=np.float32) + nucleus_vertices_arr = np.array(nucleus_vertices, dtype=np.float32) + cell_vertices_flat = cell_vertices_arr.reshape(n_cells, -1) + nucleus_vertices_flat = nucleus_vertices_arr.reshape(n_cells, -1) + + # Open source store and create new store + source_zarr_store = ZipStore(source_path / "cells.zarr.zip", mode="r") + existing_store = zarr.open(source_zarr_store, mode="r") + new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") + + # Root datasets + cell_id_arr = np.zeros((n_cells, 2), dtype=np.uint32) + cell_id_arr[:, 1] = np.array(cell_id, dtype=np.uint32) + new_store["cell_id"] = cell_id_arr + new_store["cell_summary"] = np.array(cell_summary_rows, dtype=np.float64) + + # Polygon sets + polygon_group = new_store.create_group("polygon_sets") + + # Nucleus polygons (set 0) + set0 = polygon_group.create_group("0") + set0["cell_index"] = np.array(cell_id, dtype=np.uint32) + set0["method"] = np.zeros(n_cells, dtype=np.uint32) + set0["num_vertices"] = np.array(nucleus_num_vertices, dtype=np.int32) + set0["vertices"] = nucleus_vertices_flat.astype(np.float32) + + # Cell polygons (set 1) + set1 = polygon_group.create_group("1") + set1["cell_index"] = np.array(cell_id, dtype=np.uint32) + set1["method"] = np.full(n_cells, 1, dtype=np.uint32) + set1["num_vertices"] = np.array(cell_num_vertices, dtype=np.int32) + set1["vertices"] = cell_vertices_flat.astype(np.float32) + + # Update attributes + attrs = dict(existing_store.attrs) + attrs["number_cells"] = n_cells + attrs["polygon_set_names"] = ["nucleus", "cell"] + attrs["polygon_set_display_names"] = ["Nucleus", "Cell"] + attrs["polygon_set_descriptions"] = [ + "Segger nucleus boundaries", + "Segger cell boundaries", + ] + cell_method = f"segger_cell_{boundary_method}" + nucleus_method = "segger_nucleus_convex_hull" + if boundary_method == "input" and nucleus_boundaries: + nucleus_method = "segger_nucleus_input" + attrs["segmentation_methods"] = [nucleus_method, cell_method] + attrs.setdefault("spatial_units", "microns") + attrs.setdefault("major_version", 4) + attrs.setdefault("minor_version", 0) + new_store.attrs.update(attrs) + + new_store.store.close() + source_zarr_store.close() + + # Create analysis data + if analysis_df is None: + analysis_df = pd.DataFrame( + [cell_id2old_id[i] for i in cell_id], columns=[cell_id_column] + ) + analysis_df["default"] = "segger" + + zarr_df = pd.DataFrame( + [cell_id2old_id[i] for i in cell_id], columns=[cell_id_column] + ) + clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_column) + clusters_names = [col for col in analysis_df.columns if col != cell_id_column] + + clusters_dict = { + cluster: { + label: idx + 1 + for idx, label in enumerate(sorted(np.unique(clustering_df[cluster].dropna()))) + } + for cluster in clusters_names + } + + new_zarr = zarr.open(storage / f"{analysis_filename}.zarr.zip", mode="w") + 
new_zarr.create_group("/cell_groups") + + for i, cluster in enumerate(clusters_names): + new_zarr["cell_groups"].create_group(str(i)) + group_values = [clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster]] + indices, indptr = get_indices_indptr(np.array(group_values)) + new_zarr["cell_groups"][str(i)]["indices"] = indices + new_zarr["cell_groups"][str(i)]["indptr"] = indptr + + new_zarr["cell_groups"].attrs.update({ + "major_version": 1, + "minor_version": 0, + "number_groupings": len(clusters_names), + "grouping_names": clusters_names, + "group_names": [ + sorted(clusters_dict[cluster], key=clusters_dict[cluster].get) + for cluster in clusters_names + ], + }) + new_zarr.store.close() + + generate_experiment_file( + template_path=source_path / "experiment.xenium", + output_path=storage / xenium_filename, + cells_name=cells_filename, + analysis_name=analysis_filename, + ) + + +def _process_one_cell(args: tuple) -> Optional[dict]: + """Process a single cell for parallel boundary generation.""" + ( + seg_cell_id, + seg_cell, + x_col, + y_col, + z_col, + nucleus_column, + nucleus_value, + area_low, + area_high, + polygon_max_vertices, + boundary_method, + boundary_voxel_size, + ) = args + + if len(seg_cell) < 5: + return None + + cell_poly = _safe_boundary_polygon( + seg_cell, + x=x_col, + y=y_col, + boundary_method=boundary_method, + boundary_voxel_size=boundary_voxel_size, + ) + if cell_poly is None or not (area_low <= cell_poly.area <= area_high): + return None + + cell_vertices, cell_nv = _normalize_polygon_vertices(cell_poly, polygon_max_vertices) + if cell_nv == 0: + return None + + # Nucleus polygon (optional) + nucleus_poly = None + if nucleus_column is not None and nucleus_column in seg_cell.columns: + seg_nucleus = seg_cell[seg_cell[nucleus_column] == nucleus_value] + if len(seg_nucleus) >= 3: + nucleus_poly = MultiPoint(seg_nucleus[[x_col, y_col]].values).convex_hull + if isinstance(nucleus_poly, MultiPolygon): + nucleus_poly = extract_largest_polygon(nucleus_poly) + if not isinstance(nucleus_poly, Polygon) or nucleus_poly.is_empty: + nucleus_poly = None + + if nucleus_poly is not None: + nucleus_vertices, nucleus_nv = _normalize_polygon_vertices( + nucleus_poly, polygon_max_vertices + ) + else: + nucleus_vertices = [(0.0, 0.0)] * polygon_max_vertices + nucleus_nv = 0 + + # Compute z-level if available + z_level = 0.0 + if z_col is not None and z_col in seg_cell.columns: + z_level = (seg_cell[z_col].mean() // 3) * 3 + + cell_centroid = cell_poly.centroid + nucleus_centroid = nucleus_poly.centroid if nucleus_poly is not None else None + + return { + "seg_cell_id": seg_cell_id, + "cell_area": float(cell_poly.area), + "cell_vertices": cell_vertices, + "cell_num_vertices": cell_nv, + "nucleus_vertices": nucleus_vertices, + "nucleus_num_vertices": nucleus_nv, + "cell_centroid_x": float(cell_centroid.x), + "cell_centroid_y": float(cell_centroid.y), + "nucleus_centroid_x": float(nucleus_centroid.x) if nucleus_centroid else 0.0, + "nucleus_centroid_y": float(nucleus_centroid.y) if nucleus_centroid else 0.0, + "nucleus_area": float(nucleus_poly.area) if nucleus_poly is not None else 0.0, + "z_level": float(z_level), + "nucleus_count": float(1 if nucleus_poly is not None else 0), + } + + +def seg2explorer_pqdm( + seg_df: Union[pd.DataFrame, pl.DataFrame], + source_path: Union[str, Path], + output_dir: Union[str, Path], + cells_filename: str = "seg_cells", + analysis_filename: str = "seg_analysis", + xenium_filename: str = "seg_experiment.xenium", + analysis_df: 
Optional[pd.DataFrame] = None, + cell_id_column: str = "seg_cell_id", + x_column: str = "x", + y_column: str = "y", + z_column: Optional[str] = "z", + nucleus_column: Optional[str] = "cell_compartment", + nucleus_value: int = 2, + area_low: float = 10, + area_high: float = 100, + n_jobs: int = 1, + polygon_max_vertices: int = 13, + boundary_method: str = "delaunay", + boundary_voxel_size: float = 0.0, + boundaries: Optional["gpd.GeoDataFrame"] = None, + boundary_id_column: str = "cell_id", + boundary_type_column: str = "boundary_type", + boundary_cell_value: str = "cell", + boundary_nucleus_value: str = "nucleus", + cell_id_columns: Optional[str] = None, +) -> None: + """Parallelized version of seg2explorer using pqdm. + + Parameters + ---------- + seg_df : Union[pd.DataFrame, pl.DataFrame] + Segmented transcript DataFrame. + source_path : Union[str, Path] + Path to source Xenium data. + output_dir : Union[str, Path] + Output directory. + cells_filename : str + Cells Zarr filename prefix. + analysis_filename : str + Analysis Zarr filename prefix. + xenium_filename : str + Experiment manifest filename. + analysis_df : Optional[pd.DataFrame] + Optional clustering annotations. + cell_id_column : str + Cell ID column name. + x_column : str + X coordinate column name. + y_column : str + Y coordinate column name. + z_column : Optional[str] + Z coordinate column name (if available). + nucleus_column : Optional[str] + Column name for nucleus/compartment assignment. + nucleus_value : int + Value indicating nuclear compartment. + area_low : float + Minimum cell area. + area_high : float + Maximum cell area. + n_jobs : int + Number of parallel workers. + polygon_max_vertices : int + Maximum number of vertices per polygon (including closure). + """ + if cell_id_columns is not None: + cell_id_column = cell_id_columns + + if boundary_method == "skip": + raise ValueError("boundary_method='skip' is not supported for Xenium export.") + if boundary_method == "input" and boundaries is not None: + raise ValueError( + "Parallel Xenium export does not support boundary_method='input'. " + "Use seg2explorer (serial) when passing input boundaries." + ) + if boundary_method == "input": + boundary_method = "delaunay" + + # Convert Polars to pandas + if isinstance(seg_df, pl.DataFrame): + seg_df = seg_df.to_pandas() + + source_path = Path(source_path) + storage = Path(output_dir) + storage.mkdir(parents=True, exist_ok=True) + + grouped_by = seg_df.groupby(cell_id_column) + + def _work_iter(): + return ( + ( + seg_cell_id, + seg_cell, + x_column, + y_column, + z_column, + nucleus_column, + nucleus_value, + area_low, + area_high, + polygon_max_vertices, + boundary_method, + boundary_voxel_size, + ) + for seg_cell_id, seg_cell in grouped_by + ) + + # Process backend first for throughput and "whole job" progress visibility. + # If the process pool crashes, restart once with thread workers. + try: + results = pqdm_processes( + _work_iter(), + _process_one_cell, + n_jobs=n_jobs, + desc="Processing cells", + exception_behaviour="immediate", + ) + except BrokenProcessPool: + if pqdm_threads is None: + raise RuntimeError( + "Process workers crashed and pqdm thread backend is unavailable." + ) + tqdm.write( + "Warning: process workers crashed during Xenium export. " + "Retrying with thread workers from 0% (completed process results " + "cannot be recovered by pqdm)." 
+ ) + results = pqdm_threads( + _work_iter(), + _process_one_cell, + n_jobs=n_jobs, + desc="Processing cells (thread fallback)", + exception_behaviour="immediate", + ) + + # Collate results + cell_id2old_id: Dict[int, Any] = {} + cell_id: List[int] = [] + cell_num_vertices: List[int] = [] + nucleus_num_vertices: List[int] = [] + cell_vertices: List[List[Any]] = [] + nucleus_vertices: List[List[Any]] = [] + cell_summary_rows: List[List[float]] = [] + + kept = [r for r in results if r is not None] + for cell_incremental_id, r in enumerate(kept): + uint_cell_id = cell_incremental_id + 1 + cell_id2old_id[uint_cell_id] = r["seg_cell_id"] + cell_id.append(uint_cell_id) + cell_num_vertices.append(r["cell_num_vertices"]) + nucleus_num_vertices.append(r["nucleus_num_vertices"]) + cell_vertices.append(r["cell_vertices"]) + nucleus_vertices.append(r["nucleus_vertices"]) + cell_summary_rows.append([ + r["cell_centroid_x"], + r["cell_centroid_y"], + r["cell_area"], + r["nucleus_centroid_x"], + r["nucleus_centroid_y"], + r["nucleus_area"], + r["z_level"], + r["nucleus_count"], + ]) + + if len(cell_id) == 0: + raise ValueError("No valid cells found in segmentation data.") + + n_cells = len(cell_id) + cell_vertices_arr = np.array(cell_vertices, dtype=np.float32) + nucleus_vertices_arr = np.array(nucleus_vertices, dtype=np.float32) + cell_vertices_flat = cell_vertices_arr.reshape(n_cells, -1) + nucleus_vertices_flat = nucleus_vertices_arr.reshape(n_cells, -1) + + # Open source and create new store + source_zarr_store = ZipStore(source_path / "cells.zarr.zip", mode="r") + existing_store = zarr.open(source_zarr_store, mode="r") + new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") + + # Root datasets + cell_id_arr = np.zeros((n_cells, 2), dtype=np.uint32) + cell_id_arr[:, 1] = np.array(cell_id, dtype=np.uint32) + new_store["cell_id"] = cell_id_arr + new_store["cell_summary"] = np.array(cell_summary_rows, dtype=np.float64) + + polygon_group = new_store.create_group("polygon_sets") + + # Nucleus polygons (set 0) + set0 = polygon_group.create_group("0") + set0["cell_index"] = np.array(cell_id, dtype=np.uint32) + set0["method"] = np.zeros(n_cells, dtype=np.uint32) + set0["num_vertices"] = np.array(nucleus_num_vertices, dtype=np.int32) + set0["vertices"] = nucleus_vertices_flat.astype(np.float32) + + # Cell polygons (set 1) + set1 = polygon_group.create_group("1") + set1["cell_index"] = np.array(cell_id, dtype=np.uint32) + set1["method"] = np.full(n_cells, 1, dtype=np.uint32) + set1["num_vertices"] = np.array(cell_num_vertices, dtype=np.int32) + set1["vertices"] = cell_vertices_flat.astype(np.float32) + + attrs = dict(existing_store.attrs) + attrs["number_cells"] = n_cells + attrs["polygon_set_names"] = ["nucleus", "cell"] + attrs["polygon_set_display_names"] = ["Nucleus", "Cell"] + attrs["polygon_set_descriptions"] = [ + "Segger nucleus boundaries", + "Segger cell boundaries", + ] + attrs["segmentation_methods"] = ["segger_nucleus_convex_hull", f"segger_cell_{boundary_method}"] + attrs.setdefault("spatial_units", "microns") + attrs.setdefault("major_version", 4) + attrs.setdefault("minor_version", 0) + new_store.attrs.update(attrs) + new_store.store.close() + source_zarr_store.close() + + # Create analysis data + if analysis_df is None: + analysis_df = pd.DataFrame( + [cell_id2old_id[i] for i in cell_id], columns=[cell_id_column] + ) + analysis_df["default"] = "segger" + + zarr_df = pd.DataFrame( + [cell_id2old_id[i] for i in cell_id], columns=[cell_id_column] + ) + clustering_df = 
pd.merge(zarr_df, analysis_df, how="left", on=cell_id_column) + clusters_names = [col for col in analysis_df.columns if col != cell_id_column] + + clusters_dict = { + cluster: { + label: idx + 1 + for idx, label in enumerate(sorted(np.unique(clustering_df[cluster].dropna()))) + } + for cluster in clusters_names + } + + new_zarr = zarr.open(storage / f"{analysis_filename}.zarr.zip", mode="w") + new_zarr.create_group("/cell_groups") + + for i, cluster in enumerate(clusters_names): + new_zarr["cell_groups"].create_group(str(i)) + group_values = [clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster]] + indices, indptr = get_indices_indptr(np.array(group_values)) + new_zarr["cell_groups"][str(i)]["indices"] = indices + new_zarr["cell_groups"][str(i)]["indptr"] = indptr + + new_zarr["cell_groups"].attrs.update({ + "major_version": 1, + "minor_version": 0, + "number_groupings": len(clusters_names), + "grouping_names": clusters_names, + "group_names": [ + sorted(clusters_dict[cluster], key=clusters_dict[cluster].get) + for cluster in clusters_names + ], + }) + new_zarr.store.close() + + generate_experiment_file( + template_path=source_path / "experiment.xenium", + output_path=storage / xenium_filename, + cells_name=cells_filename, + analysis_name=analysis_filename, + ) From 27e68168b4f29c93da132109d416dbe53bc83942 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 17:37:02 +0200 Subject: [PATCH 07/20] Update with optional dependencies --- pyproject.toml | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 097a80e..f336294 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,9 @@ dependencies = [ "opencv-python", "pandas", "polars", + "pqdm", "pyarrow", + "rtree", "scanpy", "scipy", "shapely", @@ -29,6 +31,33 @@ dependencies = [ "scikit-learn", "tifffile", "torch_geometric", + "zarr", +] + +[project.optional-dependencies] +spatialdata = [ + "spatialdata>=0.7.2", + "spatialdata-io>=0.6.0", +] + +spatialdata-io = [ + "spatialdata-io>=0.6.0", +] + +sopa = [ + "sopa>=2.0.0", + "spatialdata>=0.7.2", +] + +spatialdata-all = [ + "spatialdata>=0.7.2", + "spatialdata-io>=0.6.0", + "sopa>=2.0.0", +] + +plot = [ + "matplotlib>=3.7", + "uniplot>=0.10.0", ] [build-system] @@ -39,4 +68,4 @@ build-backend = "hatchling.build" packages = ["src/segger"] [project.scripts] -segger = "segger.cli.main:app" \ No newline at end of file +segger = "segger.cli.main:app" From 2049276cc2cfd67fd1c0516e813d3471d8609742 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 18:49:49 +0200 Subject: [PATCH 08/20] Restore debugging API --- src/segger/utils/__init__.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/segger/utils/__init__.py b/src/segger/utils/__init__.py index be161eb..4ae2257 100644 --- a/src/segger/utils/__init__.py +++ b/src/segger/utils/__init__.py @@ -1,5 +1,24 @@ """Utility modules for Segger.""" +import logging +import os +import sys +def setup_logging(level: str = "WARNING", log_file: str = None): + fmt = "%(asctime)s | %(levelname)-8s | %(name)s:%(lineno)d - %(message)s" + datefmt = "%Y-%m-%d %H:%M:%S" + + handlers = [logging.StreamHandler(sys.stdout)] + if log_file: + handlers.append(logging.FileHandler(log_file)) + + logging.basicConfig( + level=getattr(logging, level.upper()), + format=fmt, + datefmt=datefmt, + handlers=handlers, + force=True, # override any previously set handlers + ) + from segger.utils.optional_deps import ( # Availability flags 
SPATIALDATA_AVAILABLE, From 0bcff40cb3ed8061055d518a672e8ca1bd0a8009 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 19:44:38 +0200 Subject: [PATCH 09/20] Adjust parameters registering, align spatialdata writers arguments --- src/segger/cli/segment.py | 1 + src/segger/data/data_module.py | 1 - src/segger/data/utils/anndata.py | 6 ++- src/segger/data/writer.py | 67 ++++++++++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 2 deletions(-) diff --git a/src/segger/cli/segment.py b/src/segger/cli/segment.py index a828ef8..6dc1fa0 100644 --- a/src/segger/cli/segment.py +++ b/src/segger/cli/segment.py @@ -411,6 +411,7 @@ def segment( csvlogger = CSVLogger(output_directory) writer = ISTSegmentationWriter( + input_directory, output_directory, save_anndata=save_anndata, save_spatialdata=save_spatialdata, diff --git a/src/segger/data/data_module.py b/src/segger/data/data_module.py index 650fffd..512c699 100644 --- a/src/segger/data/data_module.py +++ b/src/segger/data/data_module.py @@ -277,7 +277,6 @@ def load(self): prediction_graph_max_k=self.prediction_graph_max_k, prediction_graph_buffer_ratio=self.prediction_graph_buffer_ratio, use_3d=self.use_3d, - me_gene_pairs=self.me_gene_pairs, ) # Tile graph dataset diff --git a/src/segger/data/utils/anndata.py b/src/segger/data/utils/anndata.py index 93db4c4..293fccc 100644 --- a/src/segger/data/utils/anndata.py +++ b/src/segger/data/utils/anndata.py @@ -195,7 +195,11 @@ def setup_anndata( # Build gene embedding on filtered dataset C = np.corrcoef(ad[ad.obs['filtered']].layers['norm'].todense().T) C = np.nan_to_num(C, 0, posinf=True, neginf=True) - model = sklearn.decomposition.PCA(n_components=cells_embedding_size) + # model = sklearn.decomposition.PCA(n_components=cells_embedding_size) + model = sklearn.decomposition.PCA(n_components=min(cells_embedding_size, ad.var.shape[0])) + if ad.var.shape[0] < cells_embedding_size: + import warnings + warnings.warn('cell embedding size is larger than input feature space, falling back to that size.') ad.varm['X_corr'] = model.fit_transform(C) # Build PCs on filtered cells and project all cells diff --git a/src/segger/data/writer.py b/src/segger/data/writer.py index 5db1498..3e57526 100644 --- a/src/segger/data/writer.py +++ b/src/segger/data/writer.py @@ -12,6 +12,7 @@ from ..io import TrainingTranscriptFields, TrainingBoundaryFields from . 
import ISTDataModule
 from .utils.anndata import anndata_from_transcripts
+from ..export.spatialdata_writer import SpatialDataWriter
 
 class ISTSegmentationWriter(BasePredictionWriter):
     """TODO: Description
@@ -23,14 +24,18 @@ class ISTSegmentationWriter(BasePredictionWriter):
     """
     def __init__(
         self,
+        input_directory: Path,
         output_directory: Path,
         save_anndata: bool = True,
+        save_spatialdata: bool = True,
         debug: bool = False
     ):
         # "write" callback at the end of prediction epoch
         super().__init__(write_interval="epoch")
+        self.input_directory = Path(input_directory)
         self.output_directory = Path(output_directory)
         self.save_anndata = save_anndata
+        self.save_spatialdata = save_spatialdata
         self.segger_logger = logging.getLogger(__name__)
 
         # setup debugging
@@ -124,6 +129,24 @@ def write_anndata(
             )
             adata.write_h5ad(self.output_directory / 'segger_anndata.h5ad')
 
+        if self.save_spatialdata:
+            writer = SpatialDataWriter(
+                include_boundaries=True,
+                boundary_method='convex_hull',
+                # boundary_n_jobs=max(num_workers, 1),
+            )
+            tx, _ = _resolve_transcripts_and_boundaries(self.input_directory)
+            output_path = writer.write(
+                predictions=segmentation,
+                output_dir=self.output_directory,
+                transcripts=tx,
+                # boundaries=bd,
+                output_name="segger_segmentation.zarr",
+            )
+            print(f"Written SpatialData output: {output_path}")
+
+
+
     @classmethod
     def assign_transcripts_to_cells(
         cls,
@@ -286,3 +309,47 @@ def on_fit_end(self, trainer, pl_module):
             self.segger_logger.debug(f"Saving trainer state to {self.path_debug / 'trainer_state_final.ckpt'}")
             trainer.save_checkpoint(self.path_debug / "trainer_state_final.ckpt")
 
+def _is_spatialdata_path(path: Path | str) -> bool:
+    try:
+        from ..io.spatialdata_loader import is_spatialdata_path as _impl
+        return _impl(path)
+    except Exception:
+        p = Path(path)
+        return (
+            p.suffix == ".zarr"
+            or (p / ".zgroup").exists()
+            or (p / "zarr.json").exists()
+            or (p / "points").exists()
+            or (p / "shapes").exists()
+        )
+
+
+def _resolve_transcripts_and_boundaries(source_path):
+    """Load transcripts and boundaries from a SpatialData store or a raw platform directory (element names hardcoded to Xenium)."""
+    if _is_spatialdata_path(source_path):
+        try:
+            from ..io.spatialdata_loader import load_from_spatialdata
+        except Exception as exc:
+            raise ImportError(
+                "SpatialData input requested, but spatialdata support is unavailable. 
" + "Install with: pip install segger[spatialdata]" + ) from exc + tx, bd = load_from_spatialdata( + source_path, + points_key="transcripts", + cell_shapes_key="cell_boundaries", + nucleus_shapes_key="nucleus_boundaries", + boundary_type="all", + ) + return (tx.collect() if isinstance(tx, pl.LazyFrame) else tx), bd + + from ..io import get_preprocessor + pp = get_preprocessor(source_path) + tx = pp.transcripts + if isinstance(tx, pl.LazyFrame): + tx = tx.collect() + try: + bd = pp.boundaries + except Exception: + bd = None + return tx, bd \ No newline at end of file From 344365e48162b4899d5711ab63b91296545ce6ad Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 19:51:07 +0200 Subject: [PATCH 10/20] Return separate shape elements on input boundarie --- src/segger/export/spatialdata_writer.py | 253 +++++++++++++++--------- 1 file changed, 155 insertions(+), 98 deletions(-) diff --git a/src/segger/export/spatialdata_writer.py b/src/segger/export/spatialdata_writer.py index 18720c5..2b1e4f7 100644 --- a/src/segger/export/spatialdata_writer.py +++ b/src/segger/export/spatialdata_writer.py @@ -26,6 +26,7 @@ from __future__ import annotations +import warnings from pathlib import Path from typing import TYPE_CHECKING, Literal, Optional @@ -82,7 +83,8 @@ def __init__( points_key: str = "transcripts", shapes_key: str = "cells", include_table: bool = True, - table_key: str = "cell_table", + table_key: str = "cells_table", # no duplicate names allowed + # fragment_table_key: str = "fragments_table", table_region_key: str = "cell_id", ): require_spatialdata() @@ -247,6 +249,7 @@ def _create_spatialdata( """Create SpatialData object from transcripts and boundaries.""" import spatialdata from spatialdata.models import PointsModel, ShapesModel, TableModel + import spatialdata.models._accessor # for points parsing on pre-release (https://github.com/scverse/spatialdata/issues/1093) import dask.dataframe as dd identity = self._identity_transform() @@ -257,7 +260,10 @@ def _create_spatialdata( # SOPA expects "cell_id" assignment in points. 
if cell_id_column in tx_pd.columns and "cell_id" not in tx_pd.columns: - tx_pd["cell_id"] = tx_pd[cell_id_column] + tx_pd['cell_id']= tx_pd[cell_id_column] + #NOTE: having both 'cell_id' and 'segger_cell_id' creates confusion + # tx_pd = tx_pd.rename(columns={cell_id_column: "cell_id"}) + # this would be better but fails as later code still relies on cell_id_column # Check for z-coordinate has_z = z_column and z_column in tx_pd.columns @@ -274,7 +280,7 @@ def _create_spatialdata( tx_pd[col] = tx_pd[col].astype(float) # Create Dask DataFrame for points - tx_dask = dd.from_pandas(tx_pd, npartitions=1) + tx_dask = dd.from_pandas(tx_pd) # Points element points_parse_kwargs = { @@ -283,33 +289,68 @@ def _create_spatialdata( "y": y_column, **({"z": z_column} if has_z else {}), }, + "instance_key": cell_id_column, # or 'cell_id' which is hard-coded now + "feature_key": feature_column, } if transformations is not None: points_parse_kwargs["transformations"] = transformations points = PointsModel.parse(tx_dask, **points_parse_kwargs) points_elements = {self.points_key: points} + + # Shapes + def _ensure_cell_id(gdf): + if gdf is None: + return None + if "cell_id" in gdf.columns: + return gdf + if cell_id_column in gdf.columns: + gdf = gdf.copy() + gdf["cell_id"] = gdf[cell_id_column] + return gdf + gdf = gdf.reset_index(drop=False) + if "cell_id" not in gdf.columns and len(gdf.columns) > 0: + gdf["cell_id"] = gdf[gdf.columns[0]] + return gdf + + + def _parse_shapes(shapes): + if shapes is None or len(shapes) == 0: + return None + kwargs = {"transformations": transformations} if transformations is not None else {} + return ShapesModel.parse(shapes, **kwargs) + + shapes_elements = {} - # Shapes element (if boundaries provided or generated) if self.include_boundaries and self.boundary_method != "skip": - shapes = self._get_boundaries( - transcripts=tx_pd, - boundaries=boundaries, - x_column=x_column, - y_column=y_column, - cell_id_column=cell_id_column, - ) - if shapes is not None and len(shapes) > 0: - shapes_parse_kwargs = {} - if transformations is not None: - shapes_parse_kwargs["transformations"] = transformations - shapes_parsed = ShapesModel.parse(shapes, **shapes_parse_kwargs) - shapes_elements[self.shapes_key] = shapes_parsed - - tables_elements = {} + if self.boundary_method == "input": + for bd_type in ["cell", "nucleus"]: # these are segger hard-coded + shapes = self._get_input_boundaries( + cell_tx_pd, + cell_id_column, + boundaries, + bd_type) + shapes = _ensure_cell_id(shapes) + parsed = _parse_shapes(shapes) + if parsed is not None: + shapes_elements[f"{bd_type}_boundaries"] = parsed + # this naming convention is very Xenium-based (ideally one would maintain the input one which is currently lost) + else: + shape_specs = [(self.shapes_key, cell_tx_pd)] + if has_fragments and fragment_tx_pd is not None: + shape_specs.append((self.fragment_shapes_key, fragment_tx_pd)) + + for shape_key, shape_tx_pd in shape_specs: + shapes = self._get_generated_boundaries(shape_tx_pd, x_column, y_column, cell_id_column) + shapes = _ensure_cell_id(shapes) + parsed = _parse_shapes(shapes) + if parsed is not None: + shapes_elements[shape_key] = parsed + # Optional AnnData table + tables_elements = {} if self.include_table: region = self.shapes_key if self.shapes_key in shapes_elements else None instance_key = self.table_region_key if region is not None else None @@ -340,12 +381,12 @@ def _create_spatialdata( pass tables_elements[self.table_key] = table - # Create SpatialData (prefer modern constructor 
methods, keep fallback) + # Create SpatialData (prefer modern constructor methods, keep fallback on single elemnts) sdata = self._build_spatialdata( spatialdata=spatialdata, - points=points_elements, - shapes=shapes_elements, - tables=tables_elements, + points_elements=points_elements, + shapes_elements=shapes_elements, + tables_elements=tables_elements, ) return sdata @@ -358,34 +399,61 @@ def _identity_transform(self): except Exception: return None - def _build_spatialdata(self, spatialdata, points: dict, shapes: dict, tables: dict): + def _build_spatialdata(self, spatialdata, points_elements: dict, shapes_elements: dict, tables_elements: dict): """Build a SpatialData object across SpatialData API variants.""" - shapes_arg = shapes or None - tables_arg = tables or None if hasattr(spatialdata.SpatialData, "init_from_elements"): - return spatialdata.SpatialData.init_from_elements( - points=points, - shapes=shapes_arg, - tables=tables_arg, + return spatialdata.SpatialData.init_from_elements(points_elements | shapes_elements | tables_elements) + else: + return spatialdata.SpatialData( + points=points_elements, + shapes=shapes_elements, + tables=tables_elements, ) + + + def _build_table_element( + self, + TableModel, + transcripts: pl.DataFrame, + var_transcripts: pl.DataFrame, + region: Optional[str], + cell_id_column: str, + feature_column: str, + x_column: str, + y_column: str, + z_column: Optional[str], + ): + """Build a SpatialData table and attach region metadata when available.""" + table = build_anndata_table( + transcripts=transcripts, + var_transcripts=var_transcripts, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + z_column=z_column, + unassigned_value=-1, + region=None, + region_key=None, + obs_index_as_str=True, + ) + if region is None: + return table + instance_key = self.table_region_key + table.obs["region"] = region + if instance_key and instance_key not in table.obs.columns: + table.obs[instance_key] = table.obs.index.astype(str) try: - return spatialdata.SpatialData( - points=points, - shapes=shapes_arg, - tables=tables_arg, + return TableModel.parse( + table, + region=region, + region_key="region", + instance_key=instance_key or "instance_id", ) except Exception: - elements = {} - for key, value in points.items(): - elements[f"points/{key}"] = value - for key, value in shapes.items(): - elements[f"shapes/{key}"] = value - sdata = spatialdata.SpatialData.from_elements_dict(elements) - for key, value in (tables or {}).items(): - sdata.tables[key] = value - return sdata + return table def _write_spatialdata_zarr(self, sdata, output_path: Path, overwrite: bool) -> None: """Write SpatialData object with compatibility fallback.""" @@ -400,78 +468,64 @@ def _write_spatialdata_zarr(self, sdata, output_path: Path, overwrite: bool) -> shutil.rmtree(output_path) sdata.write(output_path) - def _get_boundaries( + + + def _get_input_boundaries(self, cell_tx_pd, cell_id_column, boundaries, bd_type): + + selected_ids = cell_tx_pd[cell_id_column].dropna().unique() + if len(selected_ids) == 0 or boundaries is None: + if boundaries is None: + warnings.warn("No input boundaries were found. 
Skipping boundary generation.") + return None + + boundaries_filtered = boundaries.loc[boundaries['boundary_type'] == bd_type] + boundaries_gdf = boundaries_filtered[boundaries_filtered["cell_id"].isin(selected_ids)].copy() + + return boundaries_gdf if not boundaries_gdf.empty else None + + + + def _get_generated_boundaries( self, - transcripts: "pd.DataFrame", - boundaries: Optional["gpd.GeoDataFrame"], + transcripts: pd.DataFrame, x_column: str, y_column: str, cell_id_column: str, - ) -> Optional["gpd.GeoDataFrame"]: - """Get or generate cell boundaries.""" + ) -> Optional[gpd.GeoDataFrame]: + """Generate cell boundaries based on the selected boundary method. + Args + transcripts: dataframe of group transcripts (cells or fragments) + x_column, y_column: transcripts 2D coordinates + cell_id_column: cell ID + """ import geopandas as gpd - import pandas as pd - from shapely.geometry import MultiPoint - - def _ensure_cell_id(gdf: "gpd.GeoDataFrame") -> "gpd.GeoDataFrame": - if "cell_id" in gdf.columns: - return gdf - if cell_id_column in gdf.columns: - gdf = gdf.copy() - gdf["cell_id"] = gdf[cell_id_column] - return gdf - gdf = gdf.reset_index(drop=False) - if "cell_id" not in gdf.columns and len(gdf.columns) > 0: - gdf["cell_id"] = gdf[gdf.columns[0]] - return gdf - - # Use input boundaries if available - if boundaries is not None: - return _ensure_cell_id(boundaries) - - # Generate boundaries based on method - if self.boundary_method == "input": - # No input boundaries, skip + + assigned = transcripts[transcripts[cell_id_column] != -1].copy() + if assigned.empty: return None - elif self.boundary_method == "convex_hull": - # Generate convex hulls from transcript positions - assigned = transcripts[transcripts[cell_id_column] != -1].copy() - - if len(assigned) == 0: - return None + if self.boundary_method == "convex_hull": + from shapely.geometry import MultiPoint - # Group by cell and create convex hulls - hulls = [] - cell_ids = [] + hulls, cell_ids = [], [] for cell_id, group in assigned.groupby(cell_id_column): if len(group) < 3: - continue # Need at least 3 points for convex hull - + continue points = list(zip(group[x_column], group[y_column])) - mp = MultiPoint(points) - hull = mp.convex_hull - - if not hull.is_empty: - hulls.append(hull) - cell_ids.append(cell_id) + hull = MultiPoint(points).convex_hull + if hull.is_empty or hull.geom_type != "Polygon": + continue + hulls.append(hull) + cell_ids.append(cell_id) if not hulls: return None - - return _ensure_cell_id(gpd.GeoDataFrame( - {"cell_id": cell_ids}, - geometry=hulls, - )) + return gpd.GeoDataFrame({"cell_id": cell_ids}, geometry=hulls) elif self.boundary_method == "delaunay": from segger.export.boundary import generate_boundaries - assigned = transcripts[transcripts[cell_id_column] != -1].copy() - if len(assigned) == 0: - return None - boundaries_gdf = generate_boundaries( assigned, x=x_column, @@ -479,9 +533,12 @@ def _ensure_cell_id(gdf: "gpd.GeoDataFrame") -> "gpd.GeoDataFrame": cell_id=cell_id_column, n_jobs=self.boundary_n_jobs, ) - if boundaries_gdf is None or len(boundaries_gdf) == 0: + boundaries_gdf = boundaries_gdf[ + boundaries_gdf.geometry.notna() & ~boundaries_gdf.geometry.is_empty + ] + if len(boundaries_gdf) == 0: return None - return _ensure_cell_id(boundaries_gdf) + return boundaries_gdf return None From 8948d7a340c0bde19b052ee6c003d6d46d8e5ae9 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 19:51:34 +0200 Subject: [PATCH 11/20] Fix table parsing on input boundaries and improve code behavior --- 
scripts/benchmark_status_dashboard.sh | 1191 +++++++++++++++++ scripts/build_benchmark_pdf_report.py | 1168 ++++++++++++++++ scripts/build_benchmark_validation_table.sh | 783 +++++++++++ .../build_default_10x_reference_artifacts.py | 230 ++++ scripts/presentation/experiments.md | 361 +++++ scripts/presentation/experiments_plan.md | 327 +++++ scripts/run_ablation_study.sh | 1138 ++++++++++++++++ scripts/run_param_benchmark_2gpu.sh | 764 +++++++++++ scripts/run_robustness_ablation_2gpu.sh | 845 ++++++++++++ src/segger/export/spatialdata_writer.py | 26 +- src/segger/models/alignment_loss.py | 118 ++ src/segger/validation/__init__.py | 14 + src/segger/validation/me_genes.py | 421 ++++++ src/segger/validation/quick_metrics.py | 1050 +++++++++++++++ 14 files changed, 8427 insertions(+), 9 deletions(-) create mode 100755 scripts/benchmark_status_dashboard.sh create mode 100644 scripts/build_benchmark_pdf_report.py create mode 100755 scripts/build_benchmark_validation_table.sh create mode 100755 scripts/build_default_10x_reference_artifacts.py create mode 100644 scripts/presentation/experiments.md create mode 100644 scripts/presentation/experiments_plan.md create mode 100755 scripts/run_ablation_study.sh create mode 100755 scripts/run_param_benchmark_2gpu.sh create mode 100755 scripts/run_robustness_ablation_2gpu.sh create mode 100644 src/segger/models/alignment_loss.py create mode 100644 src/segger/validation/__init__.py create mode 100644 src/segger/validation/me_genes.py create mode 100644 src/segger/validation/quick_metrics.py diff --git a/scripts/benchmark_status_dashboard.sh b/scripts/benchmark_status_dashboard.sh new file mode 100755 index 0000000..667b8f9 --- /dev/null +++ b/scripts/benchmark_status_dashboard.sh @@ -0,0 +1,1191 @@ +#!/usr/bin/env bash +set -euo pipefail + +usage() { + cat <<'EOF' +Benchmark status snapshot + terminal dashboard. + +Usage: + bash scripts/benchmark_status_dashboard.sh [options] + +Options: + --root Benchmark root directory + (default: ./results/mossi_main_big_benchmark_nightly) + --out-tsv Snapshot TSV output path + (default: /summaries/status_snapshot.tsv) + --watch [sec] Refresh dashboard every N seconds (default: 20) + --no-color Disable ANSI colors + -h, --help Show this help +EOF +} + +ROOT="./results/mossi_main_big_benchmark_nightly" +OUT_TSV="" +WATCH_SEC=0 +NO_COLOR=0 + +while [[ $# -gt 0 ]]; do + case "$1" in + --root) + if [[ $# -lt 2 ]]; then + echo "ERROR: --root requires a value." >&2 + exit 1 + fi + ROOT="$2" + shift 2 + ;; + --out-tsv) + if [[ $# -lt 2 ]]; then + echo "ERROR: --out-tsv requires a value." >&2 + exit 1 + fi + OUT_TSV="$2" + shift 2 + ;; + --watch) + if [[ $# -ge 2 ]] && [[ ! "${2-}" =~ ^- ]]; then + WATCH_SEC="$2" + shift 2 + else + WATCH_SEC=20 + shift + fi + ;; + --no-color) + NO_COLOR=1 + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage + exit 1 + ;; + esac +done + +if [[ -z "${OUT_TSV}" ]]; then + OUT_TSV="${ROOT}/summaries/status_snapshot.tsv" +fi + +if ! [[ "${WATCH_SEC}" =~ ^[0-9]+$ ]]; then + echo "ERROR: --watch must be a non-negative integer." >&2 + exit 1 +fi + +PLAN_FILE="${ROOT}/job_plan.tsv" +SUMMARY_DIR="${ROOT}/summaries" +LOGS_DIR="${ROOT}/logs" +RUNS_DIR="${ROOT}/runs" +EXPORTS_DIR="${ROOT}/exports" +VALIDATION_TSV="${SUMMARY_DIR}/validation_metrics.tsv" +if [[ ! -f "${VALIDATION_TSV}" ]] && [[ -f "${ROOT}/validation_metrics.tsv" ]]; then + VALIDATION_TSV="${ROOT}/validation_metrics.tsv" +fi + +if [[ ! 
-f "${PLAN_FILE}" ]]; then + echo "ERROR: Missing plan file: ${PLAN_FILE}" >&2 + exit 1 +fi + +mkdir -p "$(dirname "${OUT_TSV}")" + +if [[ "${NO_COLOR}" == "0" ]] && [[ -t 1 ]]; then + C_RESET=$'\033[0m' + C_BOLD=$'\033[1m' + C_BOLD_OFF=$'\033[22m' + C_GREEN=$'\033[32m' + C_RED=$'\033[31m' + C_YELLOW=$'\033[33m' + C_BLUE=$'\033[34m' + C_CYAN=$'\033[36m' +else + C_RESET="" + C_BOLD="" + C_BOLD_OFF="" + C_GREEN="" + C_RED="" + C_YELLOW="" + C_BLUE="" + C_CYAN="" +fi + +collect_status_map() { + local out_file="$1" + local running_jobs_file="${2:-}" + local -a status_files=() + local f + local include_recovery=1 + + if [[ -n "${running_jobs_file}" ]] && [[ -s "${running_jobs_file}" ]]; then + include_recovery=0 + fi + + for f in "${SUMMARY_DIR}"/gpu*.tsv; do + [[ -f "${f}" ]] || continue + status_files+=("${f}") + done + if [[ "${include_recovery}" == "1" ]] && [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then + status_files+=("${SUMMARY_DIR}/recovery.tsv") + fi + + if [[ "${#status_files[@]}" -eq 0 ]]; then + : > "${out_file}" + return 0 + fi + + awk -F'\t' ' + FNR == 1 { next } + { + job = $1 + gpu = $2 + status = $3 + elapsed = $4 + note = "" + seg = "" + log_path = "" + if (NF >= 7) { + note = $5 + seg = $6 + log_path = $7 + } else if (NF >= 6) { + seg = $5 + log_path = $6 + } + + if (note == "") { + note = "-" + } + gpu_map[job] = gpu + status_map[job] = status + elapsed_map[job] = elapsed + note_map[job] = note + seg_map[job] = seg + log_map[job] = log_path + } + END { + for (job in status_map) { + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n", + job, + gpu_map[job], + status_map[job], + elapsed_map[job], + note_map[job], + seg_map[job], + log_map[job] + } + } + ' "${status_files[@]}" > "${out_file}" +} + +collect_running_jobs() { + local out_file="$1" + if ! 
command -v pgrep >/dev/null 2>&1; then + : > "${out_file}" + return 0 + fi + + pgrep -af 'segger segment|segger predict' 2>/dev/null \ + | awk ' + { + for (i = 1; i <= NF; i++) { + if ($i == "-o" && (i + 1) <= NF) { + out = $(i + 1) + gsub(/\/+$/, "", out) + n = split(out, a, "/") + if (n > 0) { + print a[n] + } + } + } + } + ' \ + | sed '/^$/d' \ + | sort -u > "${out_file}" || : > "${out_file}" +} + +pick_log_file() { + local job="$1" + local f + local found="" + for f in "${LOGS_DIR}/${job}.gpu"*.log; do + [[ -f "${f}" ]] || continue + found="${f}" + done + printf '%s' "${found}" +} + +build_snapshot() { + local status_map="$1" + local running_jobs="$2" + local out_file="$3" + local tmp_file + local plan_header plan_has_study_block + tmp_file="$(mktemp)" + + plan_header="$(head -n 1 "${PLAN_FILE}")" + plan_has_study_block=0 + if printf '%s\n' "${plan_header}" | tr '\t' '\n' | grep -Fxq "study_block"; then + plan_has_study_block=1 + fi + + printf "job\tgroup\tgpu\tstatus\tstate\trunning\telapsed_s\trun_count\thad_rerun\thad_anc_retry\thad_predict_fallback\thad_recovery_pass\tseg_exists\tanndata_exists\txenium_exists\tseg_dir\tlog_file\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\tnote\n" > "${tmp_file}" + + tail -n +2 "${PLAN_FILE}" | while IFS=$'\t' read -r -a cols; do + local job group use_3d expansion tx_max_k tx_max_dist n_mid_layers n_heads cells_min_counts min_qv alignment_loss + local row gpu status elapsed note seg_dir log_file + local running seg_exists anndata_exists xenium_exists + local run_count had_rerun had_anc_retry had_predict_fallback had_recovery_pass + local state + + if [[ "${plan_has_study_block}" == "1" ]]; then + job="${cols[0]-}" + group="${cols[2]-}" + use_3d="${cols[3]-}" + expansion="${cols[4]-}" + tx_max_k="${cols[5]-}" + tx_max_dist="${cols[6]-}" + n_mid_layers="${cols[7]-}" + n_heads="${cols[8]-}" + cells_min_counts="${cols[9]-}" + min_qv="${cols[10]-}" + alignment_loss="${cols[11]-}" + else + job="${cols[0]-}" + group="${cols[1]-}" + use_3d="${cols[2]-}" + expansion="${cols[3]-}" + tx_max_k="${cols[4]-}" + tx_max_dist="${cols[5]-}" + n_mid_layers="${cols[6]-}" + n_heads="${cols[7]-}" + cells_min_counts="${cols[8]-}" + min_qv="${cols[9]-}" + alignment_loss="${cols[10]-}" + fi + + row="$(awk -F'\t' -v j="${job}" '$1 == j { print; exit }' "${status_map}")" + gpu="" + status="" + elapsed="" + note="" + seg_dir="" + log_file="" + if [[ -n "${row}" ]]; then + IFS=$'\t' read -r _ gpu status elapsed note seg_dir log_file <<< "${row}" + if [[ "${note}" == "-" ]]; then + note="" + fi + fi + + [[ -n "${seg_dir}" ]] || seg_dir="${RUNS_DIR}/${job}" + [[ -n "${log_file}" ]] || log_file="$(pick_log_file "${job}")" + [[ -n "${log_file}" ]] || log_file="${LOGS_DIR}/${job}.gpu?.log" + + running=0 + if [[ -s "${running_jobs}" ]] && grep -Fxq "${job}" "${running_jobs}"; then + running=1 + fi + + seg_exists=0 + anndata_exists=0 + xenium_exists=0 + [[ -f "${RUNS_DIR}/${job}/segger_segmentation.parquet" ]] && seg_exists=1 + [[ -f "${EXPORTS_DIR}/${job}/anndata/segger_segmentation.h5ad" ]] && anndata_exists=1 + [[ -f "${EXPORTS_DIR}/${job}/xenium_explorer/seg_experiment.xenium" ]] && xenium_exists=1 + + run_count=0 + had_rerun=0 + had_anc_retry=0 + had_predict_fallback=0 + had_recovery_pass=0 + if [[ -f "${log_file}" ]]; then + run_count="$(grep -c "START job=${job}" "${log_file}" 2>/dev/null || printf '0')" + if [[ "${run_count}" -gt 1 ]]; then + had_rerun=1 + fi + grep -q "segment failed with ancdata; retrying" 
"${log_file}" 2>/dev/null && had_anc_retry=1 || true + grep -q "predict fallback succeeded after OOM" "${log_file}" 2>/dev/null && had_predict_fallback=1 || true + grep -q "RECOVERY job=${job}" "${log_file}" 2>/dev/null && had_recovery_pass=1 || true + fi + + state="pending" + if [[ "${running}" == "1" ]]; then + state="running" + elif [[ "${seg_exists}" == "1" && "${anndata_exists}" == "1" && "${xenium_exists}" == "1" ]]; then + state="done" + elif [[ -n "${status}" ]]; then + case "${status}" in + ok|skipped_existing|recovered_predict_ok) + state="partial" + ;; + *) + state="failed" + ;; + esac + else + if [[ "${seg_exists}" == "1" || "${anndata_exists}" == "1" || "${xenium_exists}" == "1" ]]; then + state="partial" + else + state="pending" + fi + fi + + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job}" "${group}" "${gpu}" "${status}" "${state}" "${running}" "${elapsed}" \ + "${run_count}" "${had_rerun}" "${had_anc_retry}" "${had_predict_fallback}" "${had_recovery_pass}" \ + "${seg_exists}" "${anndata_exists}" "${xenium_exists}" "${seg_dir}" "${log_file}" \ + "${use_3d}" "${expansion}" "${tx_max_k}" "${tx_max_dist}" "${n_mid_layers}" "${n_heads}" \ + "${cells_min_counts}" "${min_qv}" "${alignment_loss}" "${note}" \ + >> "${tmp_file}" + done + + mv "${tmp_file}" "${out_file}" +} + +draw_progress_bar() { + local current="$1" + local total="$2" + local width=40 + local fill=0 + local pct=0 + if [[ "${total}" -gt 0 ]]; then + fill=$((current * width / total)) + pct=$((current * 100 / total)) + fi + local empty=$((width - fill)) + local left right + left="$(printf '%*s' "${fill}" '' | tr ' ' '#')" + right="$(printf '%*s' "${empty}" '' | tr ' ' '-')" + printf "[%s%s] %d/%d (%d%%)" "${left}" "${right}" "${current}" "${total}" "${pct}" +} + +render_tsv_table() { + awk -F'\t' ' + function repeat_char(ch, n, out, i) { + out = "" + for (i = 0; i < n; i++) out = out ch + return out + } + function visible_len(s, t) { + t = s + gsub(/\033\[[0-9;]*m/, "", t) + return length(t) + } + function print_cell(val, width, vis, i) { + vis = visible_len(val) + printf(" %s", val) + for (i = vis; i < width; i++) { + printf(" ") + } + printf(" |") + } + { + rows = NR + if ($1 == "__ROW_SEP__") { + row_sep[NR] = 1 + next + } + if (NF > ncols) ncols = NF + for (i = 1; i <= NF; i++) { + val = $i + sub(/\r$/, "", val) + cells[NR, i] = val + vis = visible_len(val) + if (vis > widths[i]) widths[i] = vis + } + } + END { + if (rows == 0) exit + + sep = "+" + for (i = 1; i <= ncols; i++) { + sep = sep repeat_char("-", widths[i] + 2) "+" + } + + print sep + printf("|") + for (i = 1; i <= ncols; i++) { + print_cell(cells[1, i], widths[i]) + } + printf("\n") + print sep + + for (r = 2; r <= rows; r++) { + if (row_sep[r]) { + print sep + continue + } + printf("|") + for (i = 1; i <= ncols; i++) { + print_cell(cells[r, i], widths[i]) + } + printf("\n") + } + print sep + } + ' +} + +colorize_rows_by_state_column() { + local state_col="$1" + awk -F'\t' \ + -v state_col="${state_col}" \ + -v c_done="${C_GREEN}" \ + -v c_run="${C_BLUE}" \ + -v c_fail="${C_RED}" \ + -v c_pending="${C_YELLOW}" \ + -v c_partial="${C_CYAN}" \ + -v c_reset="${C_RESET}" \ + ' + function color_for_state(s, lower) { + lower = tolower(s) + if (lower == "done") return c_done + if (lower == "running") return c_run + if (lower == "failed") return c_fail + if (lower == "pending") return c_pending + if (lower == "reference") return c_partial + if (lower == "partial") return c_partial 
+ return "" + } + NR == 1 { print; next } + { + state = (state_col > 0 && state_col <= NF) ? $state_col : "" + color = color_for_state(state) + if (color != "") { + for (i = 1; i <= NF; i++) { + $i = color $i c_reset + } + } + print + } + ' OFS=$'\t' +} + +colorize_rows_by_status_column() { + local status_col="$1" + awk -F'\t' \ + -v status_col="${status_col}" \ + -v c_done="${C_GREEN}" \ + -v c_run="${C_BLUE}" \ + -v c_fail="${C_RED}" \ + -v c_pending="${C_YELLOW}" \ + -v c_partial="${C_CYAN}" \ + -v c_reset="${C_RESET}" \ + ' + function state_from_status(status, lower) { + lower = tolower(status) + if (status == "" || status == "") return "pending" + if (lower ~ /running|in_progress/) return "running" + if (lower ~ /oom|oot|ancdata|fail|error|missing|recovery_no_checkpoint/) return "failed" + if (lower == "ok" || lower == "skipped_existing" || lower == "recovered_predict_ok") return "done" + return "pending" + } + function color_for_state(s, lower) { + lower = tolower(s) + if (lower == "done") return c_done + if (lower == "running") return c_run + if (lower == "failed") return c_fail + if (lower == "pending") return c_pending + if (lower == "reference") return c_partial + if (lower == "partial") return c_partial + return "" + } + NR == 1 { print; next } + { + status = (status_col > 0 && status_col <= NF) ? $status_col : "" + state = state_from_status(status) + color = color_for_state(state) + if (color != "") { + for (i = 1; i <= NF; i++) { + $i = color $i c_reset + } + } + print + } + ' OFS=$'\t' +} + +render_dashboard() { + local snapshot="$1" + local total done_count running_count pending_count partial_count failed_count + local oom_count oot_count anc_count rerun_count recovered_count processed + local now + + total="$(awk 'END { print NR-1 }' "${snapshot}")" + done_count="$(awk -F'\t' 'NR>1 && $5=="done" {c++} END{print c+0}' "${snapshot}")" + running_count="$(awk -F'\t' 'NR>1 && $5=="running" {c++} END{print c+0}' "${snapshot}")" + pending_count="$(awk -F'\t' 'NR>1 && $5=="pending" {c++} END{print c+0}' "${snapshot}")" + partial_count="$(awk -F'\t' 'NR>1 && $5=="partial" {c++} END{print c+0}' "${snapshot}")" + failed_count="$(awk -F'\t' 'NR>1 && $5=="failed" {c++} END{print c+0}' "${snapshot}")" + + oom_count="$(awk -F'\t' 'NR>1 && $4 ~ /oom/ {c++} END{print c+0}' "${snapshot}")" + oot_count="$(awk -F'\t' 'NR>1 && $4=="segment_oot" {c++} END{print c+0}' "${snapshot}")" + anc_count="$(awk -F'\t' 'NR>1 && ($4=="segment_ancdata" || $10=="1") {c++} END{print c+0}' "${snapshot}")" + rerun_count="$(awk -F'\t' 'NR>1 && $9=="1" {c++} END{print c+0}' "${snapshot}")" + recovered_count="$(awk -F'\t' 'NR>1 && ($4=="recovered_predict_ok" || $11=="1" || $12=="1") {c++} END{print c+0}' "${snapshot}")" + processed=$((done_count + partial_count + failed_count)) + + now="$(date '+%Y-%m-%d %H:%M:%S')" + echo "${C_CYAN}Benchmark Dashboard${C_RESET} | ${now}" + echo "Root: ${ROOT}" + echo "Snapshot: ${snapshot}" + printf "Progress: " + draw_progress_bar "${processed}" "${total}" + echo + echo + printf "%b\n" "${C_BLUE}running=${running_count}${C_RESET} ${C_YELLOW}pending=${pending_count}${C_RESET} ${C_RED}failed=${failed_count}${C_RESET} ${C_GREEN}done=${done_count}${C_RESET} ${C_CYAN}partial=${partial_count}${C_RESET}" + printf "oom=%s oot=%s ancdata=%s rerun=%s recovered=%s\n" "${oom_count}" "${oot_count}" "${anc_count}" "${rerun_count}" "${recovered_count}" + + echo + echo "State Counts:" + awk -F'\t' ' + NR > 1 { + c[$5]++ + } + END { + print "state\tcount" + order[1] = "running" + order[2] = 
"pending" + order[3] = "failed" + order[4] = "done" + order[5] = "partial" + + for (i = 1; i <= 5; i++) { + s = order[i] + print s "\t" (c[s] + 0) + seen[s] = 1 + } + for (k in c) { + if (!(k in seen)) { + print k "\t" c[k] + } + } + } + ' "${snapshot}" \ + | colorize_rows_by_state_column 1 \ + | render_tsv_table + + echo + echo "Status Counts:" + awk -F'\t' ' + function state_from_status(status, lower) { + lower = tolower(status) + if (status == "" || status == "") return "pending" + if (lower ~ /running|in_progress/) return "running" + if (lower ~ /oom|oot|ancdata|fail|error|missing|recovery_no_checkpoint/) return "failed" + if (lower == "ok" || lower == "skipped_existing" || lower == "recovered_predict_ok") return "done" + return "pending" + } + function rank(state, lower) { + lower = tolower(state) + if (lower == "running") return 1 + if (lower == "pending") return 2 + if (lower == "failed") return 3 + if (lower == "done") return 4 + if (lower == "partial") return 5 + return 99 + } + NR > 1 { + key = $4 + if (key == "") key = "" + c[key]++ + } + END { + print "rank\tstatus\tcount" + for (k in c) { + st = state_from_status(k) + print rank(st) "\t" k "\t" c[k] + } + } + ' "${snapshot}" \ + | { + IFS= read -r header || true + if [[ -n "${header}" ]]; then + printf "status\tcount\n" + fi + sort -t $'\t' -k1,1n -k3,3nr -k2,2 \ + | cut -f2- + } \ + | colorize_rows_by_status_column 1 \ + | render_tsv_table + + echo + echo "All Jobs Overview:" + awk -F'\t' ' + function rank(st, lower) { + lower = tolower(st) + if (lower == "done") return 1 + if (lower == "running") return 2 + if (lower == "failed") return 3 + if (lower == "pending") return 4 + if (lower == "partial") return 5 + if (lower == "reference") return 6 + return 99 + } + function fmt_minutes(v, lower, n) { + lower = tolower(v) + if (v == "" || lower == "nan" || lower == "none" || lower == "-") return "nan" + n = (v + 0.0) / 60.0 + if (n < 0) n = 0 + return sprintf("%.2f", n) + } + BEGIN { + print "rank\tjob\tgroup\tgpu\tstate\tstatus\truns\telapsed_min\trerun\tanc_retry\toom_pred_fallback\trecovery\tseg\tanndata\txenium" + } + NR > 1 { + st = $5 + if (st == "") st = "pending" + status = $4 + if (status == "") status = "" + print rank(st) "\t" $1 "\t" $2 "\t" $3 "\t" st "\t" status "\t" $8 "\t" fmt_minutes($7) "\t" $9 "\t" $10 "\t" $11 "\t" $12 "\t" $13 "\t" $14 "\t" $15 + } + ' "${snapshot}" \ + | { + IFS= read -r header || true + if [[ -n "${header}" ]]; then + printf "job\tgroup\tgpu\tstate\tstatus\truns\telapsed_min\trerun\tanc_retry\toom_pred_fallback\trecovery\tseg\tanndata\txenium\n" + fi + sort -t $'\t' -k1,1n -k2,2 \ + | cut -f2- + } \ + | colorize_rows_by_state_column 4 \ + | render_tsv_table + + echo + echo "Model Parameterization:" + awk -F'\t' ' + function block_rank(block, lower) { + lower = tolower(block) + if (lower == "stability") return 1 + if (lower == "interaction") return 2 + if (lower == "stress") return 3 + if (block == "-") return 4 + return 5 + } + function model_label(job, lower) { + lower = tolower(job) + if (lower == "baseline") return "baseline" + if (index(lower, "stbl_baseline_") == 1) return "baseline_repeat" + if (index(lower, "stbl_anchor_") == 1) return "anchor_repeat" + if (index(lower, "stbl_sens_") == 1) return "sensitivity_repeat" + if (index(lower, "int_") == 1) return "interaction_ablation" + if (index(lower, "stress_") == 1) return "stress_test" + if (index(lower, "use3d_") == 1) return "ablation_use3d" + if (index(lower, "expansion_") == 1) return "ablation_expansion" + if (index(lower, 
"txk_") == 1) return "ablation_txk" + if (index(lower, "txdist_") == 1) return "ablation_txdist" + if (index(lower, "layers_") == 1) return "ablation_layers" + if (index(lower, "heads_") == 1) return "ablation_heads" + if (index(lower, "cellsmin_") == 1) return "ablation_cellsmin" + if (index(lower, "align_") == 1) return "ablation_alignment" + return "custom" + } + function col(name, fallback) { + if ((name in idx) && idx[name] > 0) return $(idx[name]) + return fallback + } + FNR == NR { + if (FNR == 1) { + for (i = 1; i <= NF; i++) { + if ($i == "job") snap_job_col = i + if ($i == "state") snap_state_col = i + if ($i == "status") snap_status_col = i + } + next + } + if (snap_job_col > 0) { + j = $snap_job_col + state_by_job[j] = (snap_state_col > 0 ? $snap_state_col : "") + status_by_job[j] = (snap_status_col > 0 ? $snap_status_col : "") + } + next + } + FNR == 1 { + for (i = 1; i <= NF; i++) idx[$i] = i + has_block = (("study_block" in idx) && idx["study_block"] > 0) ? 1 : 0 + print "rank\tjob\tmodel\tstudy_block\tgroup\tstate\tstatus\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss" + next + } + { + job = col("job", $1) + block = has_block ? col("study_block", "-") : "-" + group = col("group", "-") + state = state_by_job[job] + if (state == "") state = "pending" + status = status_by_job[job] + if (status == "") status = "" + + print block_rank(block) "\t" \ + job "\t" \ + model_label(job) "\t" \ + block "\t" \ + group "\t" \ + state "\t" \ + status "\t" \ + col("use_3d", "-") "\t" \ + col("expansion", "-") "\t" \ + col("tx_max_k", "-") "\t" \ + col("tx_max_dist", "-") "\t" \ + col("n_mid_layers", "-") "\t" \ + col("n_heads", "-") "\t" \ + col("cells_min_counts", "-") "\t" \ + col("min_qv", "-") "\t" \ + col("alignment_loss", "-") + } + ' "${snapshot}" "${PLAN_FILE}" \ + | { + IFS= read -r header || true + if [[ -n "${header}" ]]; then + printf "job\tmodel\tstudy_block\tgroup\tstate\tstatus\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\n" + fi + sort -t $'\t' -k1,1n -k2,2 \ + | cut -f2- + } \ + | colorize_rows_by_state_column 5 \ + | render_tsv_table + + echo + echo "Validation Metrics:" + if [[ -f "${VALIDATION_TSV}" ]] && [[ "$(awk 'END { print NR-1 }' "${VALIDATION_TSV}")" -gt 0 ]]; then + awk -F'\t' ' + function has_col(name) { + return (name in idx) && idx[name] > 0 + } + function get_col(name) { + if (has_col(name)) return $(idx[name]) + return "" + } + function fmt_float(v, lower) { + lower = tolower(v) + if (v == "" || lower == "nan" || lower == "none" || lower == "-") return "nan" + return sprintf("%.4f", v + 0.0) + } + function fmt_nonneg_int(v, lower, n) { + lower = tolower(v) + if (v == "" || lower == "nan" || lower == "none" || lower == "-") return "0" + n = v + 0.0 + if (n < 0) n = 0 + return sprintf("%.0f", n) + } + function fmt_minutes(v, lower, n) { + lower = tolower(v) + if (v == "" || lower == "nan" || lower == "none" || lower == "-") return "nan" + n = (v + 0.0) / 60.0 + if (n < 0) n = 0 + return sprintf("%.2f", n) + } + FNR == NR { + if (FNR == 1) { + for (i = 1; i <= NF; i++) { + if ($i == "job") snap_job_col = i + if ($i == "state") snap_state_col = i + if ($i == "elapsed_s") snap_elapsed_col = i + } + next + } + if (snap_job_col > 0) { + job_key = $snap_job_col + if (snap_state_col > 0) state_by_job[job_key] = $snap_state_col + if (snap_elapsed_col > 0) elapsed_by_job[job_key] = $snap_elapsed_col + } + next + } + FNR == 1 { + for (i = 1; i <= NF; 
i++) { + idx[$i] = i + } + print "job\tkind\tstate\tvalidate_status\tgpu_time_min v\tcells\tassigned_pct ^\tmecr v\tcontamination_pct v\tresolvi_contam_pct v\ttco ^\tdoublet_pct v" + next + } + { + job = get_col("job") + job_disp = job + group = get_col("group") + is_reference = get_col("is_reference") + reference_kind = get_col("reference_kind") + if (reference_kind != "" && reference_kind != "-") { + kind = reference_kind + } else if (is_reference == "1" || group == "R") { + kind = "reference" + } else { + kind = "segger" + } + if (job == "baseline" && kind == "segger") { + job_disp = "baseline*" + } + + state = state_by_job[job] + if (kind != "segger") { + state = "reference" + } else if (state == "") { + state = "" + } + + validate_status = get_col("validate_status") + if (validate_status == "") validate_status = "" + + gpu_time = get_col("gpu_time_s") + if (gpu_time == "") gpu_time = get_col("elapsed_s") + if (gpu_time == "") gpu_time = elapsed_by_job[job] + gpu_time = fmt_minutes(gpu_time) + + cells = get_col("cells") + if (cells == "") cells = get_col("cells_total") + if (cells == "") cells = get_col("cells_assigned") + cells = fmt_nonneg_int(cells) + + assigned = get_col("assigned_pct") + if (assigned == "") assigned = get_col("transcripts_assigned_pct") + + mecr = get_col("mecr") + if (mecr == "") mecr = get_col("mecr_fast") + + contamination = get_col("contamination_pct") + if (contamination == "") contamination = get_col("border_contaminated_cells_pct_fast") + + resolvi = get_col("resolvi_contamination_pct") + if (resolvi == "") resolvi = get_col("resolvi_contamination_pct_fast") + + tco = get_col("tco") + if (tco == "") tco = get_col("transcript_centroid_offset_fast") + + doublet = get_col("doublet_pct") + if (doublet == "") { + doublet = get_col("signal_doublet_like_fraction_fast") + if (doublet != "" && tolower(doublet) != "nan" && tolower(doublet) != "none") { + doublet = 100.0 * (doublet + 0.0) + } + } + + print job_disp "\t" kind "\t" state "\t" validate_status "\t" gpu_time "\t" cells "\t" \ + fmt_float(assigned) "\t" fmt_float(mecr) "\t" fmt_float(contamination) "\t" \ + fmt_float(resolvi) "\t" fmt_float(tco) "\t" fmt_float(doublet) + } + ' "${snapshot}" "${VALIDATION_TSV}" \ + | { + IFS= read -r header || true + if [[ -n "${header}" ]]; then + printf "%s\n" "${header}" + fi + awk -F'\t' -v b_on="${C_BOLD}" -v b_off="${C_BOLD_OFF}" ' + function is_num(v, lower) { + lower = tolower(v) + return !(v == "" || lower == "nan" || lower == "none" || lower == "-") + } + function to_num(v) { + return v + 0.0 + } + function update_top2_up(v, x) { + if (!is_num(v)) return + x = to_num(v) + if (!have1 || x > top1) { + top2 = top1 + have2 = have1 + top1 = x + have1 = 1 + } else if (!have2 || x > top2) { + top2 = x + have2 = 1 + } + } + function update_top2_down(v, x) { + if (!is_num(v)) return + x = to_num(v) + if (!have1 || x < top1) { + top2 = top1 + have2 = have1 + top1 = x + have1 = 1 + } else if (!have2 || x < top2) { + top2 = x + have2 = 1 + } + } + function is_top2_up(v, best1, best2, have_2, x) { + if (!is_num(v) || !is_num(best1)) return 0 + x = to_num(v) + if (!have_2) return (x == best1) + return (x >= best2) + } + function is_top2_down(v, best1, best2, have_2, x) { + if (!is_num(v) || !is_num(best1)) return 0 + x = to_num(v) + if (!have_2) return (x == best1) + return (x <= best2) + } + function norm_up(v, lo, hi) { + if (!is_num(v)) return "" + if (hi <= lo) return 1.0 + return (to_num(v) - lo) / (hi - lo) + } + function norm_down(v, lo, hi) { + if (!is_num(v)) return "" 
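+                # Mirror of norm_up: min-max normalize onto [0, 1] with the
+                # smallest value scoring 1.0, so lower-is-better metrics still
+                # feed the score as higher-is-better; degenerate ranges
+                # (hi <= lo) score 1.0 for every row.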
+ if (hi <= lo) return 1.0 + return (hi - to_num(v)) / (hi - lo) + } + { + n++ + for (j = 1; j <= NF; j++) { + cell[n, j] = $j + } + nf[n] = NF + m_assigned[n] = $7 + m_mecr[n] = $8 + m_contam[n] = $9 + m_resolvi[n] = $10 + m_tco[n] = $11 + m_doublet[n] = $12 + m_gpu[n] = $5 + st[n] = tolower($4) + is_ref[n] = (tolower($3) == "reference") + + if (is_num($7)) { + v = to_num($7) + if (!has_a || v < min_a) min_a = v + if (!has_a || v > max_a) max_a = v + has_a = 1 + } + if (is_num($8)) { + v = to_num($8) + if (!has_m || v < min_m) min_m = v + if (!has_m || v > max_m) max_m = v + has_m = 1 + } + if (is_num($9)) { + v = to_num($9) + if (!has_c || v < min_c) min_c = v + if (!has_c || v > max_c) max_c = v + has_c = 1 + } + if (is_num($10)) { + v = to_num($10) + if (!has_r || v < min_r) min_r = v + if (!has_r || v > max_r) max_r = v + has_r = 1 + } + if (is_num($11)) { + v = to_num($11) + if (!has_t || v < min_t) min_t = v + if (!has_t || v > max_t) max_t = v + has_t = 1 + } + if (is_num($12)) { + v = to_num($12) + if (!has_d || v < min_d) min_d = v + if (!has_d || v > max_d) max_d = v + has_d = 1 + } + + if (st[n] == "ok") { + have1 = have_a_best1; top1 = a_best1; have2 = have_a_best2; top2 = a_best2 + update_top2_up($7) + have_a_best1 = have1; a_best1 = top1; have_a_best2 = have2; a_best2 = top2 + + have1 = have_m_best1; top1 = m_best1; have2 = have_m_best2; top2 = m_best2 + update_top2_down($8) + have_m_best1 = have1; m_best1 = top1; have_m_best2 = have2; m_best2 = top2 + + have1 = have_c_best1; top1 = c_best1; have2 = have_c_best2; top2 = c_best2 + update_top2_down($9) + have_c_best1 = have1; c_best1 = top1; have_c_best2 = have2; c_best2 = top2 + + have1 = have_r_best1; top1 = r_best1; have2 = have_r_best2; top2 = r_best2 + update_top2_down($10) + have_r_best1 = have1; r_best1 = top1; have_r_best2 = have2; r_best2 = top2 + + have1 = have_t_best1; top1 = t_best1; have2 = have_t_best2; top2 = t_best2 + update_top2_up($11) + have_t_best1 = have1; t_best1 = top1; have_t_best2 = have2; t_best2 = top2 + + have1 = have_d_best1; top1 = d_best1; have2 = have_d_best2; top2 = d_best2 + update_top2_down($12) + have_d_best1 = have1; d_best1 = top1; have_d_best2 = have2; d_best2 = top2 + + if (!is_ref[n]) { + have1 = have_g_best1; top1 = g_best1; have2 = have_g_best2; top2 = g_best2 + update_top2_down($5) + have_g_best1 = have1; g_best1 = top1; have_g_best2 = have2; g_best2 = top2 + } + } + } + END { + for (i = 1; i <= n; i++) { + # Rank rows by overall score across assigned/mecr/contam/tco/doublet. + # Rows with no numeric metrics (or non-ok status) stay at the bottom. 
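+                # The score is the mean of the normalized metric columns
+                # (assigned and tco count up, mecr/contamination/doublet count
+                # down); gpu_time and the resolvi column are excluded from the
+                # ranking. Reference rows keep their own gkey group and are
+                # emitted ahead of segger rows by the sort further down.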
+ if (st[i] != "ok") { + score = -1e9 + } else { + score = 0.0 + cnt = 0 + + s = norm_up(m_assigned[i], min_a, max_a) + if (s != "") { score += s; cnt++ } + + s = norm_down(m_mecr[i], min_m, max_m) + if (s != "") { score += s; cnt++ } + + s = norm_down(m_contam[i], min_c, max_c) + if (s != "") { score += s; cnt++ } + + s = norm_up(m_tco[i], min_t, max_t) + if (s != "") { score += s; cnt++ } + + s = norm_down(m_doublet[i], min_d, max_d) + if (s != "") { score += s; cnt++ } + + if (cnt > 0) { + score /= cnt + } else { + score = -1e9 + } + } + + if (b_on != "" && st[i] == "ok") { + if (!is_ref[i] && is_top2_down(m_gpu[i], g_best1, g_best2, have_g_best2)) cell[i, 5] = b_on cell[i, 5] b_off + if (is_top2_up(m_assigned[i], a_best1, a_best2, have_a_best2)) cell[i, 7] = b_on cell[i, 7] b_off + if (is_top2_down(m_mecr[i], m_best1, m_best2, have_m_best2)) cell[i, 8] = b_on cell[i, 8] b_off + if (is_top2_down(m_contam[i], c_best1, c_best2, have_c_best2)) cell[i, 9] = b_on cell[i, 9] b_off + if (is_top2_down(m_resolvi[i], r_best1, r_best2, have_r_best2)) cell[i, 10] = b_on cell[i, 10] b_off + if (is_top2_up(m_tco[i], t_best1, t_best2, have_t_best2)) cell[i, 11] = b_on cell[i, 11] b_off + if (is_top2_down(m_doublet[i], d_best1, d_best2, have_d_best2)) cell[i, 12] = b_on cell[i, 12] b_off + } + + row = cell[i, 1] + for (j = 2; j <= nf[i]; j++) { + row = row "\t" cell[i, j] + } + gkey = 1 + if (tolower(cell[i, 3]) == "reference") gkey = 0 + printf "%d\t%.10f\t%s\n", gkey, score, row + } + } + ' \ + | sort -t $'\t' -k1,1n -k2,2gr -k3,3 \ + | cut -f3- \ + | awk -F'\t' ' + { + state = tolower($3) + if (!seen_data) { + seen_data = 1 + } + if (state == "reference") { + seen_ref = 1 + print + next + } + if (seen_ref && !inserted_sep) { + print "__ROW_SEP__" + inserted_sep = 1 + } + print + } + ' + } \ + | colorize_rows_by_state_column 3 \ + | render_tsv_table + else + echo "No validation TSV found at ${VALIDATION_TSV}" + echo "Run: bash scripts/build_benchmark_validation_table.sh --root ${ROOT}" + fi + + if [[ "${running_count}" -gt 0 ]]; then + echo + echo "Running Jobs:" + awk -F'\t' ' + BEGIN { + print "job\tgpu\tstatus\tstate\truns\tlog_file" + } + NR > 1 && $6 == "1" { + st = $4 + if (st == "") st = "" + print $1 "\t" $3 "\t" st "\t" $5 "\t" $8 "\t" $17 + } + ' "${snapshot}" \ + | colorize_rows_by_state_column 4 \ + | render_tsv_table + fi + + if [[ "${failed_count}" -gt 0 ]]; then + echo + echo "Failed Jobs:" + awk -F'\t' ' + BEGIN { + print "job\tgpu\tstatus\tstate\truns\trerun\tlog_file" + } + NR > 1 && $5 == "failed" { + st = $4 + if (st == "") st = "" + print $1 "\t" $3 "\t" st "\t" $5 "\t" $8 "\t" $9 "\t" $17 + } + ' "${snapshot}" \ + | colorize_rows_by_state_column 4 \ + | render_tsv_table + fi + + if [[ "${rerun_count}" -gt 0 ]]; then + echo + echo "Rerun/Retry Jobs:" + awk -F'\t' ' + BEGIN { + print "job\tstate\tstatus\truns\tanc_retry\toom_predict_fallback\trecovery_pass" + } + NR > 1 && $9 == "1" { + st = $4 + if (st == "") st = "" + print $1 "\t" $5 "\t" st "\t" $8 "\t" $10 "\t" $11 "\t" $12 + } + ' "${snapshot}" \ + | { + IFS= read -r header || true + if [[ -n "${header}" ]]; then + printf "%s\n" "${header}" + fi + awk -F'\t' ' + function rank(state, lower) { + lower = tolower(state) + if (lower == "running") return 1 + if (lower == "pending") return 2 + if (lower == "failed") return 3 + if (lower == "done") return 4 + if (lower == "partial") return 5 + return 99 + } + { + print rank($2) "\t" $0 + } + ' \ + | sort -t $'\t' -k1,1n -k2,2 \ + | cut -f2- + } \ + | colorize_rows_by_state_column 2 
\ + | render_tsv_table + fi +} + +snapshot_once() { + local tmp_status_map tmp_running + tmp_status_map="$(mktemp)" + tmp_running="$(mktemp)" + collect_running_jobs "${tmp_running}" + collect_status_map "${tmp_status_map}" "${tmp_running}" + build_snapshot "${tmp_status_map}" "${tmp_running}" "${OUT_TSV}" + rm -f "${tmp_status_map}" "${tmp_running}" +} + +if [[ "${WATCH_SEC}" -gt 0 ]]; then + while true; do + snapshot_once + if [[ -t 1 ]]; then + clear + fi + render_dashboard "${OUT_TSV}" + sleep "${WATCH_SEC}" + done +else + snapshot_once + render_dashboard "${OUT_TSV}" +fi diff --git a/scripts/build_benchmark_pdf_report.py b/scripts/build_benchmark_pdf_report.py new file mode 100644 index 0000000..89db660 --- /dev/null +++ b/scripts/build_benchmark_pdf_report.py @@ -0,0 +1,1168 @@ +#!/usr/bin/env python3 +"""Build a multi-page PDF report for benchmark validation metrics.""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Iterable, Sequence + +import numpy as np +import pandas as pd + +try: + import anndata as ad +except Exception: # pragma: no cover + ad = None + +try: + import matplotlib.pyplot as plt + from matplotlib.backends.backend_pdf import PdfPages + from matplotlib.patches import Polygon as MplPolygon +except Exception: # pragma: no cover + plt = None + PdfPages = None + MplPolygon = None + +try: + import polars as pl +except Exception: # pragma: no cover + pl = None + + +METRIC_SPECS = [ + ("assigned_pct", "assigned_ci95", "Assigned %", "up"), + ("mecr", "mecr_ci95", "MECR", "down"), + ("contamination_pct", "contamination_ci95", "Contamination %", "down"), + ("tco", "tco_ci95", "TCO", "up"), + ("doublet_pct", "doublet_ci95", "Doublet %", "down"), + ("gpu_time_min", None, "GPU time (min)", "down"), +] + +OVERALL_RANK_METRICS = [ + ("assigned_pct", True), + ("mecr", False), + ("contamination_pct", False), + ("tco", True), + ("doublet_pct", False), +] + +REFERENCE_COLORS = { + "10x_cell": "#c97a3d", + "10x_nucleus": "#e0ab66", + "ref_other": "#b68f6f", +} +SEGGER_SHADE_LIGHT = "#d8e6f4" +SEGGER_SHADE_DARK = "#0f4c81" + + +def _to_numeric(df: pd.DataFrame, cols: Iterable[str]) -> None: + for c in cols: + if c in df.columns: + df[c] = pd.to_numeric(df[c], errors="coerce") + + +def _safe_str(value: object) -> str: + if value is None: + return "" + return str(value) + + +def _is_reference_row(row: pd.Series) -> bool: + is_ref = _safe_str(row.get("is_reference", "0")).strip().lower() in {"1", "true", "yes"} + group = _safe_str(row.get("group", "")).strip().upper() + job = _safe_str(row.get("job", "")).strip().lower() + ref_kind = _safe_str(row.get("reference_kind", "")).strip() + return is_ref or group == "R" or job.startswith("ref_") or (ref_kind not in {"", "-", "nan", "None"}) + + +def _kind_label(row: pd.Series) -> str: + ref_kind = _safe_str(row.get("reference_kind", "")).strip() + if ref_kind and ref_kind not in {"-", "nan", "None"}: + return ref_kind + return "segger" + + +def _display_job_label(row: pd.Series) -> str: + job = _safe_str(row.get("job", "")).strip() + if job == "baseline": + return "baseline*" + if _is_reference_row(row): + kind = _kind_label(row) + if kind == "10x_cell": + return "10x cell (ref)" + if kind == "10x_nucleus": + return "10x nucleus (ref)" + if kind and kind != "segger": + return f"{kind} (ref)" + return job + + +def _hex_to_rgb(color_hex: str) -> tuple[float, float, float]: + c = color_hex.strip().lstrip("#") + if len(c) != 6: + return (0.5, 0.5, 0.5) + return tuple(int(c[i : i + 2], 
16) / 255.0 for i in (0, 2, 4)) + + +def _rgb_to_hex(rgb: Sequence[float]) -> str: + vals = [max(0, min(255, int(round(float(v) * 255.0)))) for v in rgb] + return f"#{vals[0]:02x}{vals[1]:02x}{vals[2]:02x}" + + +def _mix_hex(color_a: str, color_b: str, t: float) -> str: + t = float(max(0.0, min(1.0, t))) + a = _hex_to_rgb(color_a) + b = _hex_to_rgb(color_b) + return _rgb_to_hex(tuple((1.0 - t) * x + t * y for x, y in zip(a, b))) + + +def _build_gradient(light_hex: str, dark_hex: str, n: int) -> list[str]: + if n <= 0: + return [] + if n == 1: + return [dark_hex] + vals = np.linspace(0.0, 1.0, n) + return [_mix_hex(dark_hex, light_hex, float(v)) for v in vals] + + +def _normalize_metric(vals: np.ndarray, higher_is_better: bool) -> np.ndarray: + x = np.asarray(vals, dtype=float) + if x.size == 0: + return x + finite = np.isfinite(x) + out = np.full_like(x, np.nan) + if not np.any(finite): + return out + lo = np.nanmin(x[finite]) + hi = np.nanmax(x[finite]) + if not np.isfinite(lo) or not np.isfinite(hi): + return out + if hi <= lo: + out[finite] = 1.0 + else: + out[finite] = (x[finite] - lo) / (hi - lo) + if not higher_is_better: + out[finite] = 1.0 - out[finite] + return out + + +def _compute_overall_score(df: pd.DataFrame) -> pd.Series: + if df.empty: + return pd.Series(dtype=float) + mats = [] + for metric, hib in OVERALL_RANK_METRICS: + if metric not in df.columns: + mats.append(np.full(len(df), np.nan, dtype=float)) + continue + arr = pd.to_numeric(df[metric], errors="coerce").to_numpy(dtype=float) + mats.append(_normalize_metric(arr, hib)) + if not mats: + return pd.Series(np.nan, index=df.index, dtype=float) + stack = np.vstack(mats).T + with np.errstate(invalid="ignore"): + denom = np.sum(np.isfinite(stack), axis=1).astype(float) + numer = np.nansum(stack, axis=1) + score = np.divide(numer, denom, out=np.full_like(numer, np.nan), where=denom > 0) + return pd.Series(score, index=df.index, dtype=float) + + +def _rank_df(df: pd.DataFrame) -> pd.DataFrame: + out = df.copy() + if "_overall_score" not in out.columns: + out["_overall_score"] = _compute_overall_score(out) + for c in ["assigned_pct", "mecr"]: + if c not in out.columns: + out[c] = np.nan + out = out.sort_values( + by=["_overall_score", "assigned_pct", "mecr"], + ascending=[False, False, True], + na_position="last", + ) + return out + + +def _ensure_columns(df: pd.DataFrame) -> pd.DataFrame: + out = df.copy() + for col, _, _, _ in METRIC_SPECS: + if col not in out.columns: + out[col] = np.nan + for _, ci_col, _, _ in METRIC_SPECS: + if ci_col and ci_col not in out.columns: + out[ci_col] = np.nan + if "gpu_time_s" in out.columns and "gpu_time_min" not in out.columns: + out["gpu_time_min"] = pd.to_numeric(out["gpu_time_s"], errors="coerce") / 60.0 + if "gpu_time_min" not in out.columns: + out["gpu_time_min"] = np.nan + if "validate_status" not in out.columns: + out["validate_status"] = "" + if "job" not in out.columns: + out["job"] = "" + if "anndata_path" not in out.columns: + out["anndata_path"] = "" + if "segmentation_path" not in out.columns: + out["segmentation_path"] = "" + return out + + +def _assign_plot_colors(df: pd.DataFrame) -> pd.DataFrame: + out = df.copy() + out["_plot_color"] = "#4f6f8f" + out["_plot_label"] = out.apply(_display_job_label, axis=1) + out["_overall_score"] = _compute_overall_score(out) + + ref_mask = out["is_reference_row"].astype(bool) + for idx, row in out[ref_mask].iterrows(): + kind = _safe_str(row.get("kind", "")) + if kind == "10x_cell": + out.at[idx, "_plot_color"] = 
REFERENCE_COLORS["10x_cell"] + elif kind == "10x_nucleus": + out.at[idx, "_plot_color"] = REFERENCE_COLORS["10x_nucleus"] + else: + out.at[idx, "_plot_color"] = REFERENCE_COLORS["ref_other"] + + seg = out[~ref_mask].copy() + seg = _rank_df(seg) + shades = _build_gradient(SEGGER_SHADE_LIGHT, SEGGER_SHADE_DARK, len(seg)) + for i, (idx, _) in enumerate(seg.iterrows()): + out.at[idx, "_plot_color"] = shades[i] if i < len(shades) else SEGGER_SHADE_DARK + + return out + + +def _load_metrics(path: Path) -> pd.DataFrame: + df = pd.read_csv(path, sep="\t", dtype=str) + num_cols = [ + "gpu_time_s", + "cells", + "assigned_pct", + "assigned_ci95", + "mecr", + "mecr_ci95", + "contamination_pct", + "contamination_ci95", + "tco", + "tco_ci95", + "doublet_pct", + "doublet_ci95", + ] + _to_numeric(df, num_cols) + if "gpu_time_s" in df.columns: + df["gpu_time_min"] = df["gpu_time_s"] / 60.0 + else: + df["gpu_time_min"] = np.nan + if "validate_status" not in df.columns: + df["validate_status"] = "" + df["is_reference_row"] = df.apply(_is_reference_row, axis=1) + df["kind"] = df.apply(_kind_label, axis=1) + df = _ensure_columns(df) + df = _assign_plot_colors(df) + return df + + +def _ok_rows(df: pd.DataFrame) -> pd.DataFrame: + return df[df["validate_status"].astype(str).str.lower() == "ok"].copy() + + +def _ordered_ok_rows(df: pd.DataFrame) -> pd.DataFrame: + ok = _ok_rows(df) + if ok.empty: + return ok + refs = ok[ok["is_reference_row"]].copy() + refs["_ref_order"] = refs["kind"].map({"10x_cell": 0, "10x_nucleus": 1}).fillna(9) + refs = refs.sort_values(by=["_ref_order", "job"], ascending=[True, True]) + seg = _rank_df(ok[~ok["is_reference_row"]].copy()) + return pd.concat([refs, seg], axis=0, ignore_index=False) + + +def _apply_report_style() -> None: + if plt is None: + return + style_candidates = [ + Path(__file__).resolve().parents[2] / "segger-analysis" / "assets" / "paper.mplstyle", + Path(__file__).resolve().parents[1] / "assets" / "paper.mplstyle", + Path("../segger-analysis/assets/paper.mplstyle").resolve(), + ] + for style_path in style_candidates: + if style_path.exists(): + try: + plt.style.use(str(style_path)) + break + except Exception: + continue + + plt.rcParams.update( + { + "font.family": "sans-serif", + "font.sans-serif": ["Helvetica Neue", "Helvetica", "Arial", "DejaVu Sans"], + "axes.titlesize": 9, + "axes.labelsize": 8, + "xtick.labelsize": 7, + "ytick.labelsize": 7, + "axes.linewidth": 0.6, + "grid.linewidth": 0.45, + "grid.alpha": 0.16, + "figure.dpi": 220, + "savefig.dpi": 300, + "legend.frameon": False, + "legend.fontsize": 7, + } + ) + + +def _clean_axes(ax) -> None: + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + +def _metric_title(base: str, direction: str) -> str: + arrow = "^" if direction == "up" else "v" + return f"{base} {arrow}" + + +def _plot_bar_page(pdf: PdfPages, df: pd.DataFrame) -> None: + fig, axes = plt.subplots(2, 3, figsize=(11.6, 8.3)) + axes = axes.flatten() + + disp = _ordered_ok_rows(df) + if disp.empty: + fig.suptitle("Benchmark Comparison: no valid rows", fontsize=11) + for ax in axes: + ax.axis("off") + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + return + + y = np.arange(len(disp)) + labels = [_safe_str(x) for x in disp["_plot_label"].tolist()] + colors = [_safe_str(x) for x in disp["_plot_color"].tolist()] + + for ax, (metric, ci_col, title, direction) in zip(axes, METRIC_SPECS): + vals = pd.to_numeric(disp[metric], errors="coerce").to_numpy(dtype=float) + errs = ( + pd.to_numeric(disp[ci_col], 
errors="coerce").to_numpy(dtype=float) + if ci_col is not None + else np.full_like(vals, np.nan, dtype=float) + ) + valid = np.isfinite(vals) + if not np.any(valid): + ax.set_title(_metric_title(title, direction), fontsize=9) + ax.text(0.5, 0.5, "no data", transform=ax.transAxes, ha="center", va="center", fontsize=8) + ax.set_yticks([]) + ax.grid(False) + _clean_axes(ax) + continue + + val_v = vals[valid] + err_v = errs[valid] + err_plot = np.where(np.isfinite(err_v) & (err_v >= 0), err_v, 0.0) + color_v = np.asarray(colors, dtype=object)[valid] + y_v = y[valid] + + ax.barh( + y_v, + val_v, + xerr=err_plot, + color=color_v, + alpha=0.9, + edgecolor="none", + error_kw={ + "elinewidth": 0.6, + "capthick": 0.6, + "capsize": 1.8, + "ecolor": "#2f2f2f", + "alpha": 0.9, + }, + ) + ax.set_title(_metric_title(title, direction), fontsize=9) + ax.set_yticks(y) + ax.set_yticklabels(labels, fontsize=6.8) + ax.invert_yaxis() + ax.grid(axis="x") + ax.tick_params(axis="x", labelsize=7) + _clean_axes(ax) + + for ax in axes[len(METRIC_SPECS) :]: + ax.axis("off") + + fig.suptitle("Benchmark Overview (Segger shades + 10x references)", fontsize=11, y=0.995) + fig.tight_layout(rect=[0, 0, 1, 0.98]) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _annotate_selected(ax, df: pd.DataFrame, x_col: str, y_col: str) -> None: + if df.empty: + return + refs = df[df["is_reference_row"]].copy() + seg = df[~df["is_reference_row"]].copy() + seg = _rank_df(seg) + + candidates = [] + if not refs.empty: + candidates.extend(refs.index.tolist()[:2]) + if not seg.empty: + candidates.extend(seg.index.tolist()[:3]) + + for idx in candidates: + row = df.loc[idx] + x = pd.to_numeric(pd.Series([row.get(x_col)]), errors="coerce").iloc[0] + y = pd.to_numeric(pd.Series([row.get(y_col)]), errors="coerce").iloc[0] + if not (np.isfinite(x) and np.isfinite(y)): + continue + ax.text(float(x), float(y), _safe_str(row.get("_plot_label", "")), fontsize=6.5, ha="left", va="bottom") + + +def _plot_scatter_page(pdf: PdfPages, df: pd.DataFrame) -> None: + ok = _ordered_ok_rows(df) + fig, axes = plt.subplots(1, 2, figsize=(11.5, 4.2)) + + for ax in axes: + _clean_axes(ax) + ax.grid(alpha=0.2) + + for _, row in ok.iterrows(): + color = _safe_str(row.get("_plot_color", "#4f6f8f")) + is_ref = bool(row.get("is_reference_row")) + marker = "s" if is_ref else "o" + size = 44 if is_ref else 32 + alpha = 0.95 if is_ref else 0.86 + + a = float(pd.to_numeric(pd.Series([row.get("assigned_pct")]), errors="coerce").iloc[0]) + c = float(pd.to_numeric(pd.Series([row.get("contamination_pct")]), errors="coerce").iloc[0]) + m = float(pd.to_numeric(pd.Series([row.get("mecr")]), errors="coerce").iloc[0]) + if np.isfinite(a) and np.isfinite(c): + axes[0].scatter(a, c, c=color, marker=marker, s=size, alpha=alpha, linewidths=0) + if np.isfinite(a) and np.isfinite(m): + axes[1].scatter(a, m, c=color, marker=marker, s=size, alpha=alpha, linewidths=0) + + axes[0].set_title("Sensitivity vs Contamination") + axes[0].set_xlabel("Assigned transcripts (%)") + axes[0].set_ylabel("Contamination (%) v") + axes[1].set_title("Sensitivity vs MECR") + axes[1].set_xlabel("Assigned transcripts (%)") + axes[1].set_ylabel("MECR v") + + _annotate_selected(axes[0], ok, "assigned_pct", "contamination_pct") + _annotate_selected(axes[1], ok, "assigned_pct", "mecr") + + fig.suptitle("Fast-Metric Trade-offs", fontsize=11) + fig.tight_layout(rect=[0, 0, 1, 0.95]) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _plot_heatmap_page(pdf: PdfPages, df: pd.DataFrame) -> 
None: + ok = _ordered_ok_rows(df) + metric_defs = [ + ("assigned_pct", True), + ("mecr", False), + ("contamination_pct", False), + ("tco", True), + ("doublet_pct", False), + ("gpu_time_min", False), + ] + if ok.empty: + fig, ax = plt.subplots(figsize=(11, 6)) + ax.axis("off") + ax.text(0.5, 0.5, "No valid rows for heatmap", ha="center", va="center", fontsize=10) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + return + + arr = [] + labels = [] + for metric, hib in metric_defs: + vals = pd.to_numeric(ok[metric], errors="coerce").to_numpy(dtype=float) + arr.append(_normalize_metric(vals, hib)) + labels.append(metric) + mat = np.vstack(arr).T + mat_masked = np.ma.masked_invalid(mat) + + fig_h = max(4.0, 0.26 * len(ok) + 2.0) + fig, ax = plt.subplots(figsize=(10.3, fig_h)) + im = ax.imshow(mat_masked, aspect="auto", cmap="cividis", vmin=0.0, vmax=1.0) + ax.set_yticks(np.arange(len(ok))) + ax.set_yticklabels(ok["_plot_label"].astype(str).tolist(), fontsize=7) + ax.set_xticks(np.arange(len(labels))) + ax.set_xticklabels(labels, fontsize=8, rotation=30, ha="right") + ax.set_title("Metric Heatmap (normalized, higher is better)") + cbar = fig.colorbar(im, ax=ax, shrink=0.9, fraction=0.028, pad=0.015) + cbar.set_label("relative score", fontsize=8) + cbar.ax.tick_params(labelsize=7) + _clean_axes(ax) + fig.tight_layout() + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _find_sgutils_src() -> Path | None: + candidates = [ + Path(__file__).resolve().parents[2] / "segger-analysis" / "src", + Path("../segger-analysis/src").resolve(), + Path.cwd().parent / "segger-analysis" / "src", + ] + for c in candidates: + if c.exists(): + return c + return None + + +def _enable_sgutils_import() -> bool: + src = _find_sgutils_src() + if src is None: + return False + src_str = str(src) + if src_str not in sys.path: + sys.path.insert(0, src_str) + return True + + +def _valid_umap_xy(xy: np.ndarray | None) -> np.ndarray | None: + if xy is None: + return None + arr = np.asarray(xy) + if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[1] < 2: + return None + finite = np.isfinite(arr[:, 0]) & np.isfinite(arr[:, 1]) + arr = arr[finite] + if arr.shape[0] == 0: + return None + return arr[:, :2] + + +def _downsample_adata(adata_obj, max_cells: int, seed: int): + if getattr(adata_obj, "n_obs", 0) <= max_cells: + return adata_obj + rng = np.random.default_rng(seed) + keep = rng.choice(np.arange(adata_obj.n_obs), size=max_cells, replace=False) + keep = np.sort(keep) + return adata_obj[keep, :].copy() + + +def _compute_umap_with_sgutils(adata_obj, seed: int) -> np.ndarray | None: + if not _enable_sgutils_import(): + return None + try: + from sg_utils.pp.preprocess_rapids import preprocess_rapids + except Exception: + return None + try: + work = adata_obj.copy() + if getattr(work, "n_obs", 0) < 30 or getattr(work, "n_vars", 0) < 20: + return None + if getattr(work, "raw", None) is None: + work.raw = work.copy() + preprocess_rapids( + work, + n_hvgs=min(2000, max(256, int(getattr(work, "n_vars", 2000)))), + pca_total_var=0.9, + knn_neighbors=min(15, max(5, int(getattr(work, "n_obs", 100) - 1))), + umap_n_epochs=300, + random_state=seed, + show_progress=False, + ) + return _valid_umap_xy(work.obsm.get("X_umap")) + except Exception: + return None + + +def _compute_umap_with_scanpy(adata_obj, seed: int) -> np.ndarray | None: + try: + import scanpy as sc + except Exception: + return None + try: + work = adata_obj.copy() + if getattr(work, "n_obs", 0) < 25 or getattr(work, "n_vars", 0) < 15: + return None + 
sc.pp.filter_cells(work, min_counts=1) + sc.pp.filter_genes(work, min_counts=1) + if work.n_obs < 25 or work.n_vars < 15: + return None + + sc.pp.normalize_total(work, target_sum=1e4) + sc.pp.log1p(work) + + if work.n_vars > 80: + n_top = min(2000, max(80, int(0.5 * work.n_vars))) + sc.pp.highly_variable_genes(work, n_top_genes=n_top, flavor="seurat") + if "highly_variable" in work.var.columns: + hv_mask = np.asarray(work.var["highly_variable"]).astype(bool) + if int(hv_mask.sum()) >= 20: + work = work[:, hv_mask].copy() + + n_comps = min(35, work.n_obs - 1, work.n_vars - 1) + if n_comps < 2: + return None + sc.pp.pca(work, n_comps=n_comps) + n_neighbors = min(15, max(5, work.n_obs - 1)) + sc.pp.neighbors(work, n_neighbors=n_neighbors, n_pcs=min(20, n_comps)) + sc.tl.umap(work, min_dist=0.35, spread=1.0, random_state=seed) + return _valid_umap_xy(work.obsm.get("X_umap")) + except Exception: + return None + + +def _umap_points_for_path(anndata_path: Path, seed: int, max_cells: int) -> tuple[np.ndarray | None, str]: + if ad is None: + return None, "anndata_missing" + if not anndata_path.exists(): + return None, "missing_h5ad" + try: + adata_obj = ad.read_h5ad(anndata_path) + except Exception: + return None, "read_h5ad_failed" + + adata_obj = _downsample_adata(adata_obj, max_cells=max_cells, seed=seed) + xy = _valid_umap_xy(adata_obj.obsm.get("X_umap")) + if xy is not None: + return xy, "precomputed" + + xy = _compute_umap_with_sgutils(adata_obj, seed=seed) + if xy is not None: + return xy, "sgutils" + + xy = _compute_umap_with_scanpy(adata_obj, seed=seed) + if xy is not None: + return xy, "scanpy" + + return None, "umap_unavailable" + + +def _pick_umap_rows(df: pd.DataFrame) -> list[pd.Series]: + ok = _ordered_ok_rows(df) + ok = ok[ok["anndata_path"].astype(str).str.strip() != ""].copy() + if ok.empty: + return [] + + refs = ok[ok["is_reference_row"]].copy() + refs["_ref_order"] = refs["kind"].map({"10x_cell": 0, "10x_nucleus": 1}).fillna(9) + refs = refs.sort_values(by=["_ref_order", "job"]).head(2) + + seg = _rank_df(ok[~ok["is_reference_row"]].copy()) + best = seg.head(2) + worst = seg.tail(2) + + picked: list[pd.Series] = [] + seen_jobs: set[str] = set() + + def _push_rows(sub_df: pd.DataFrame) -> None: + for _, row in sub_df.iterrows(): + job = _safe_str(row.get("job", "")) + if job in seen_jobs: + continue + seen_jobs.add(job) + picked.append(row) + + _push_rows(refs) + _push_rows(best) + _push_rows(worst) + + if len(picked) < 6: + _push_rows(seg) + if len(picked) < 6: + _push_rows(refs) + + return picked[:6] + + +def _panel_title_from_row(row: pd.Series) -> str: + label = _safe_str(row.get("_plot_label", "")).strip() + if label: + return label + return _safe_str(row.get("job", "")) + + +def _plot_umap_panel(ax, row: pd.Series, seed: int, max_cells: int, cache: dict[str, tuple[np.ndarray | None, str]]) -> None: + _clean_axes(ax) + ax.set_xticks([]) + ax.set_yticks([]) + + title = _panel_title_from_row(row) + path = Path(_safe_str(row.get("anndata_path", "")).strip()) + color = _safe_str(row.get("_plot_color", "#4f6f8f")) + cache_key = str(path.resolve()) if path.as_posix() not in {"", "."} else _safe_str(row.get("job", "")) + + if cache_key in cache: + xy, source = cache[cache_key] + else: + xy, source = _umap_points_for_path(path, seed=seed, max_cells=max_cells) + cache[cache_key] = (xy, source) + + ax.set_title(title, fontsize=8.2) + if xy is None: + ax.text(0.5, 0.5, f"-- error: UMAP missing for {title}", ha="center", va="center", fontsize=7.2) + return + + ax.scatter(xy[:, 0], 
xy[:, 1], s=1.8, c=color, alpha=0.72, linewidths=0, rasterized=True) + ax.text( + 0.02, + 0.03, + f"n={xy.shape[0]} | {source}", + transform=ax.transAxes, + ha="left", + va="bottom", + fontsize=6.2, + color="#555555", + ) + ax.set_aspect("equal", adjustable="box") + + +def _plot_umap_page(pdf: PdfPages, df: pd.DataFrame, seed: int, umap_max_cells: int) -> None: + picks = _pick_umap_rows(df) + fig, axes = plt.subplots(2, 3, figsize=(11.2, 7.6)) + axes = axes.flatten() + cache: dict[str, tuple[np.ndarray | None, str]] = {} + + for i, ax in enumerate(axes): + if i < len(picks): + _plot_umap_panel(ax, picks[i], seed=seed + i, max_cells=umap_max_cells, cache=cache) + else: + ax.axis("off") + ax.text(0.5, 0.5, "no panel", ha="center", va="center", fontsize=8) + + fig.suptitle("UMAP Panels: 2 references + 2 best + 2 worst", fontsize=11) + fig.tight_layout(rect=[0, 0, 1, 0.96]) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def _load_transcript_xy(input_dir: Path) -> pd.DataFrame | None: + tx_path = input_dir / "transcripts.parquet" + if not tx_path.exists(): + return None + + if pl is not None: + try: + lf = pl.scan_parquet(tx_path, parallel="row_groups") + cols = lf.collect_schema().names() + if "row_index" in cols: + lf = lf.with_columns(pl.col("row_index").cast(pl.Int64)) + else: + lf = lf.with_row_index(name="row_index") + + x_col = "x_location" if "x_location" in cols else ("x" if "x" in cols else None) + y_col = "y_location" if "y_location" in cols else ("y" if "y" in cols else None) + if x_col is None or y_col is None: + return None + + tx = ( + lf.select(["row_index", x_col, y_col]) + .rename({x_col: "x", y_col: "y"}) + .collect() + .to_pandas() + ) + tx["row_index"] = pd.to_numeric(tx["row_index"], errors="coerce").astype("Int64") + tx["x"] = pd.to_numeric(tx["x"], errors="coerce") + tx["y"] = pd.to_numeric(tx["y"], errors="coerce") + return tx.dropna(subset=["row_index", "x", "y"]).copy() + except Exception: + pass + + try: + tx = pd.read_parquet(tx_path) + except Exception: + return None + if "row_index" not in tx.columns: + tx = tx.copy() + tx["row_index"] = np.arange(len(tx), dtype=np.int64) + x_col = "x_location" if "x_location" in tx.columns else ("x" if "x" in tx.columns else None) + y_col = "y_location" if "y_location" in tx.columns else ("y" if "y" in tx.columns else None) + if x_col is None or y_col is None: + return None + out = tx[["row_index", x_col, y_col]].rename(columns={x_col: "x", y_col: "y"}) + out["row_index"] = pd.to_numeric(out["row_index"], errors="coerce").astype("Int64") + out["x"] = pd.to_numeric(out["x"], errors="coerce") + out["y"] = pd.to_numeric(out["y"], errors="coerce") + return out.dropna(subset=["row_index", "x", "y"]).copy() + + +def _load_seg_assign(seg_path: Path) -> pd.DataFrame | None: + if not seg_path.exists(): + return None + + id_col_candidates = [ + "segger_cell_id", + "cell_id", + "xenium_cell_id", + "tenx_cell_id", + ] + + if pl is not None: + try: + lf = pl.scan_parquet(seg_path) + cols = lf.collect_schema().names() + id_col = None + for c in id_col_candidates: + if c in cols: + id_col = c + break + if id_col is None: + for c in cols: + if c.endswith("_cell_id"): + id_col = c + break + if id_col is None or "row_index" not in cols: + return None + + df = lf.select(["row_index", id_col]).collect().to_pandas() + df.columns = ["row_index", "segger_cell_id"] + df["row_index"] = pd.to_numeric(df["row_index"], errors="coerce").astype("Int64") + return df.dropna(subset=["row_index"]).copy() + except Exception: + pass + + try: + df = 
pd.read_parquet(seg_path) + except Exception: + return None + if "row_index" not in df.columns: + return None + id_col = None + for c in id_col_candidates: + if c in df.columns: + id_col = c + break + if id_col is None: + for c in df.columns: + if str(c).endswith("_cell_id"): + id_col = c + break + if id_col is None: + return None + out = df[["row_index", id_col]].copy() + out.columns = ["row_index", "segger_cell_id"] + out["row_index"] = pd.to_numeric(out["row_index"], errors="coerce").astype("Int64") + return out.dropna(subset=["row_index"]).copy() + + +def _is_assigned(series: pd.Series) -> pd.Series: + s = series.astype(str).str.strip() + return series.notna() & ~s.eq("") & ~s.str.upper().eq("UNASSIGNED") & ~s.str.lower().eq("nan") + + +def _select_small_fovs( + tx: pd.DataFrame, + n: int, + max_tx: int, + min_tx: int, + seed: int, +) -> list[tuple[float, float, float, float, int]]: + if tx.empty: + return [] + x = pd.to_numeric(tx["x"], errors="coerce").to_numpy(dtype=float) + y = pd.to_numeric(tx["y"], errors="coerce").to_numpy(dtype=float) + finite = np.isfinite(x) & np.isfinite(y) + x = x[finite] + y = y[finite] + if x.size == 0: + return [] + + n_bins = int(np.clip(np.sqrt(max(x.size / max(max_tx, 1), 8)) * 8, 24, 80)) + counts, xedges, yedges = np.histogram2d(x, y, bins=[n_bins, n_bins]) + + candidates = [] + for i in range(n_bins): + for j in range(n_bins): + c = int(counts[i, j]) + if min_tx <= c <= max_tx: + candidates.append((i, j, c)) + + if not candidates: + fallback = [] + for i in range(n_bins): + for j in range(n_bins): + c = int(counts[i, j]) + if c > 0: + fallback.append((i, j, c)) + fallback = sorted(fallback, key=lambda x: x[2]) + candidates = fallback[: max(1, n)] + + rng = np.random.default_rng(seed) + if len(candidates) > n: + order = rng.permutation(len(candidates)) + candidates = [candidates[k] for k in order[:n]] + + windows = [] + for i, j, c in candidates[:n]: + x0, x1 = float(xedges[i]), float(xedges[i + 1]) + y0, y1 = float(yedges[j]), float(yedges[j + 1]) + windows.append((x0, x1, y0, y1, int(c))) + return windows + + +def _convex_hull(points: np.ndarray) -> np.ndarray | None: + if points.ndim != 2 or points.shape[1] != 2: + return None + pts = np.unique(points, axis=0) + if pts.shape[0] < 3: + return None + pts = pts[np.lexsort((pts[:, 1], pts[:, 0]))] + + def cross(o, a, b) -> float: + return float((a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])) + + lower: list[np.ndarray] = [] + for p in pts: + while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0: + lower.pop() + lower.append(p) + + upper: list[np.ndarray] = [] + for p in pts[::-1]: + while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0: + upper.pop() + upper.append(p) + + hull = np.vstack((lower[:-1], upper[:-1])) + if hull.shape[0] < 3: + return None + return hull + + +def _pick_mask_rows(df: pd.DataFrame) -> list[pd.Series]: + ok = _ordered_ok_rows(df) + ok = ok[ok["segmentation_path"].astype(str).str.strip() != ""].copy() + if ok.empty: + return [] + + refs = ok[ok["is_reference_row"]].copy() + refs["_ref_order"] = refs["kind"].map({"10x_cell": 0, "10x_nucleus": 1}).fillna(9) + refs = refs.sort_values(by=["_ref_order", "job"]).head(2) + seg = _rank_df(ok[~ok["is_reference_row"]].copy()) + + picks: list[pd.Series] = [] + seen: set[str] = set() + for _, row in refs.iterrows(): + job = _safe_str(row.get("job", "")) + if job not in seen: + seen.add(job) + picks.append(row) + for sub in [seg.head(1), seg.tail(1), seg.head(3)]: + for _, row in sub.iterrows(): + job = 
_safe_str(row.get("job", "")) + if job not in seen: + seen.add(job) + picks.append(row) + if len(picks) >= 4: + break + if len(picks) >= 4: + break + return picks[:4] + + +def _plot_mask_panel(ax, sub: pd.DataFrame, base_color: str, title: str) -> None: + _clean_axes(ax) + ax.set_title(title, fontsize=7.2, pad=3.0) + ax.set_xticks([]) + ax.set_yticks([]) + + if sub.empty: + ax.text(0.5, 0.5, "no transcripts", transform=ax.transAxes, ha="center", va="center", fontsize=7) + return + + assigned_mask = _is_assigned(sub["segger_cell_id"]) + un = sub[~assigned_mask] + asn = sub[assigned_mask] + if not un.empty: + ax.scatter(un["x"], un["y"], s=1.0, c="#d3d3d3", alpha=0.22, linewidths=0, rasterized=True, zorder=1) + + n_cells = 0 + if not asn.empty and MplPolygon is not None: + for i, (_, grp) in enumerate(asn.groupby("segger_cell_id", sort=False)): + if len(grp) < 4: + continue + hull = _convex_hull(grp[["x", "y"]].to_numpy(dtype=float)) + if hull is None: + continue + t = (i % 8) / 8.0 + face = _mix_hex(base_color, "#ffffff", 0.18 + 0.35 * t) + edge = _mix_hex(base_color, "#0b2038", 0.35) + patch = MplPolygon( + hull, + closed=True, + facecolor=face, + edgecolor=edge, + linewidth=0.28, + alpha=0.48, + zorder=2, + ) + ax.add_patch(patch) + n_cells += 1 + + ax.scatter(asn["x"], asn["y"], s=1.1, c=base_color, alpha=0.44, linewidths=0, rasterized=True, zorder=3) + + ax.text( + 0.02, + 0.02, + f"tx={len(sub)} | cells={n_cells}", + transform=ax.transAxes, + ha="left", + va="bottom", + fontsize=6.2, + color="#5b5b5b", + ) + ax.set_aspect("equal", adjustable="box") + + +def _plot_fov_page( + pdf: PdfPages, + df: pd.DataFrame, + input_dir: Path | None, + seed: int, + fov_count: int, + fov_max_tx: int, + fov_min_tx: int, +) -> None: + if input_dir is None: + fig, ax = plt.subplots(figsize=(11, 3.8)) + ax.axis("off") + ax.text(0.5, 0.5, "FOV panels skipped: --input-dir not provided", ha="center", va="center", fontsize=10) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + return + + tx = _load_transcript_xy(input_dir) + if tx is None or tx.empty: + fig, ax = plt.subplots(figsize=(11, 3.8)) + ax.axis("off") + ax.text(0.5, 0.5, "FOV panels skipped: transcripts.parquet unavailable", ha="center", va="center", fontsize=10) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + return + + rows = _pick_mask_rows(df) + if not rows: + fig, ax = plt.subplots(figsize=(11, 3.8)) + ax.axis("off") + ax.text(0.5, 0.5, "FOV panels skipped: no usable segmentation rows", ha="center", va="center", fontsize=10) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + return + + windows = _select_small_fovs( + tx=tx, + n=max(1, fov_count), + max_tx=max(100, fov_max_tx), + min_tx=max(10, fov_min_tx), + seed=seed, + ) + if not windows: + fig, ax = plt.subplots(figsize=(11, 3.8)) + ax.axis("off") + ax.text(0.5, 0.5, "FOV panels skipped: unable to find small windows", ha="center", va="center", fontsize=10) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + return + + tx_fovs = [] + for x0, x1, y0, y1, n_tx in windows: + sub = tx[(tx["x"] >= x0) & (tx["x"] <= x1) & (tx["y"] >= y0) & (tx["y"] <= y1)].copy() + tx_fovs.append((x0, x1, y0, y1, n_tx, sub)) + + nrows = len(rows) + ncols = len(tx_fovs) + fig_w = max(7.0, 3.9 * ncols) + fig_h = max(5.2, 2.7 * nrows) + fig, axes = plt.subplots(nrows, ncols, figsize=(fig_w, fig_h), squeeze=False) + + for r, row in enumerate(rows): + seg_path = Path(_safe_str(row.get("segmentation_path", "")).strip()) + seg_df = _load_seg_assign(seg_path) + if seg_df is None or 
seg_df.empty: + for c in range(ncols): + ax = axes[r, c] + ax.axis("off") + ax.text(0.5, 0.5, f"missing segmentation\n{_panel_title_from_row(row)}", ha="center", va="center", fontsize=7) + continue + seg_map = seg_df.set_index("row_index")["segger_cell_id"] + base_color = _safe_str(row.get("_plot_color", "#4f6f8f")) + method_name = _panel_title_from_row(row) + + for c, (x0, x1, y0, y1, n_tx, sub_tx) in enumerate(tx_fovs): + ax = axes[r, c] + sub = sub_tx.copy() + sub["segger_cell_id"] = sub["row_index"].map(seg_map) + title = f"{method_name} | FOV{c+1} ({n_tx} tx)" + _plot_mask_panel(ax, sub, base_color=base_color, title=title) + ax.set_xlim(x0, x1) + ax.set_ylim(y0, y1) + + fig.suptitle(f"Cell-mask FOV panels (convex hulls, < {fov_max_tx} transcripts/FOV)", fontsize=11, y=0.995) + fig.tight_layout(rect=[0, 0, 1, 0.98]) + pdf.savefig(fig, bbox_inches="tight") + plt.close(fig) + + +def build_report( + root: Path, + validation_tsv: Path, + out_pdf: Path, + input_dir: Path | None, + seed: int, + fov_count: int, + fov_max_tx: int, + fov_min_tx: int, + umap_max_cells: int, +) -> None: + if plt is None or PdfPages is None: + raise RuntimeError("matplotlib is required for PDF report generation") + if not validation_tsv.exists(): + raise FileNotFoundError(f"Validation TSV not found: {validation_tsv}") + + df = _load_metrics(validation_tsv) + out_pdf.parent.mkdir(parents=True, exist_ok=True) + _apply_report_style() + + with PdfPages(out_pdf) as pdf: + _plot_bar_page(pdf, df) + _plot_scatter_page(pdf, df) + _plot_heatmap_page(pdf, df) + _plot_umap_page(pdf, df, seed=seed, umap_max_cells=max(1000, umap_max_cells)) + _plot_fov_page( + pdf, + df, + input_dir=input_dir, + seed=seed, + fov_count=max(1, fov_count), + fov_max_tx=max(50, fov_max_tx), + fov_min_tx=max(5, min(fov_min_tx, fov_max_tx)), + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Build benchmark multi-page PDF report") + parser.add_argument("--root", type=Path, default=Path("./results/mossi_main_big_benchmark_nightly")) + parser.add_argument("--validation-tsv", type=Path, default=None) + parser.add_argument("--out-pdf", type=Path, default=None) + parser.add_argument("--input-dir", type=Path, default=None) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--fov-count", type=int, default=2) + parser.add_argument("--fov-max-transcripts", type=int, default=2000) + parser.add_argument("--fov-min-transcripts", type=int, default=150) + parser.add_argument("--umap-max-cells", type=int, default=12000) + args = parser.parse_args() + + root = args.root + validation_tsv = args.validation_tsv or (root / "summaries" / "validation_metrics.tsv") + out_pdf = args.out_pdf or (root / "summaries" / "benchmark_report.pdf") + + build_report( + root=root, + validation_tsv=validation_tsv, + out_pdf=out_pdf, + input_dir=args.input_dir, + seed=args.seed, + fov_count=max(1, args.fov_count), + fov_max_tx=max(50, args.fov_max_transcripts), + fov_min_tx=max(5, args.fov_min_transcripts), + umap_max_cells=max(1000, args.umap_max_cells), + ) + print(f"Wrote report: {out_pdf}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/build_benchmark_validation_table.sh b/scripts/build_benchmark_validation_table.sh new file mode 100755 index 0000000..64f33ba --- /dev/null +++ b/scripts/build_benchmark_validation_table.sh @@ -0,0 +1,783 @@ +#!/usr/bin/env bash +set -euo pipefail + +usage() { + cat <<'EOF' +Build per-job validation metrics table for a benchmark root. 
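+The table is assembled from <root>/job_plan.tsv and the per-job outputs under
+<root>/runs/; by default it is written to <root>/summaries/validation_metrics.tsv.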
+ +Usage: + bash scripts/build_benchmark_validation_table.sh [options] + +Options: + --root Benchmark root (default: ./results/mossi_main_big_benchmark_nightly) + --input-dir Source dataset path for contamination/geometry/doublet metrics (optional) + --out-tsv Output TSV (default: /summaries/validation_metrics.tsv) + --recompute Recompute all jobs even if already present in output TSV + --segger-bin Segger executable/command (default: segger) + --me-gene-pairs-path Optional ME-gene pair file passed to segger validate + --scrna-reference-path Optional scRNA h5ad passed to segger validate + --scrna-celltype-column scRNA cell type column (default: cell_type) + --max-me-gene-pairs Max sampled ME-gene pairs (default: 500) + --gpu-a GPU id used for group A labels (default: env GPU_A or 0) + --gpu-b GPU id used for group B labels (default: env GPU_B or 1) + --include-default-10x Include ref_10x_cell/ref_10x_nucleus rows (default: true) + --reference-universe-seg Canonical Segger universe segmentation override + -h, --help Show this help +EOF +} + +timestamp() { + date '+%Y-%m-%d %H:%M:%S' +} + +sanitize_tsv_field() { + local value="${1:-}" + value="${value//$'\t'/ }" + value="${value//$'\r'/ }" + value="${value//$'\n'/ }" + printf '%s' "${value}" +} + +normalize_token() { + local value="${1:-}" + value="$(sanitize_tsv_field "${value}")" + if [[ -z "${value}" ]]; then + printf '%s' "-" + else + printf '%s' "${value}" + fi +} + +normalize_bool() { + local v="${1:-}" + local lc + lc="$(printf '%s' "${v}" | tr '[:upper:]' '[:lower:]')" + case "${lc}" in + 1|true|t|yes|y|on) + printf 'true' + ;; + 0|false|f|no|n|off) + printf 'false' + ;; + *) + return 1 + ;; + esac +} + +ROOT="./results/mossi_main_big_benchmark_nightly" +INPUT_DIR="" +OUT_TSV="" +SEGGER_BIN="segger" +ME_GENE_PAIRS_PATH="" +SCRNA_REFERENCE_PATH="" +SCRNA_CELLTYPE_COLUMN="cell_type" +MAX_ME_GENE_PAIRS=500 +GPU_A="${GPU_A:-0}" +GPU_B="${GPU_B:-1}" +INCLUDE_DEFAULT_10X="true" +REFERENCE_UNIVERSE_SEG="" +RECOMPUTE=0 + +require_value() { + local opt="$1" + if [[ $# -lt 2 ]] || [[ -z "${2}" ]] || [[ "${2}" == -* ]]; then + echo "ERROR: ${opt} requires a value." >&2 + exit 1 + fi +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --root) + require_value "$1" "${2-}" + ROOT="$2" + shift 2 + ;; + --input-dir) + require_value "$1" "${2-}" + INPUT_DIR="$2" + shift 2 + ;; + --out-tsv) + require_value "$1" "${2-}" + OUT_TSV="$2" + shift 2 + ;; + --recompute) + RECOMPUTE=1 + shift + ;; + --segger-bin) + require_value "$1" "${2-}" + SEGGER_BIN="$2" + shift 2 + ;; + --me-gene-pairs-path) + require_value "$1" "${2-}" + ME_GENE_PAIRS_PATH="$2" + shift 2 + ;; + --scrna-reference-path) + require_value "$1" "${2-}" + SCRNA_REFERENCE_PATH="$2" + shift 2 + ;; + --scrna-celltype-column) + require_value "$1" "${2-}" + SCRNA_CELLTYPE_COLUMN="$2" + shift 2 + ;; + --max-me-gene-pairs) + require_value "$1" "${2-}" + MAX_ME_GENE_PAIRS="$2" + shift 2 + ;; + --gpu-a) + require_value "$1" "${2-}" + GPU_A="$2" + shift 2 + ;; + --gpu-b) + require_value "$1" "${2-}" + GPU_B="$2" + shift 2 + ;; + --include-default-10x) + require_value "$1" "${2-}" + INCLUDE_DEFAULT_10X="$2" + shift 2 + ;; + --reference-universe-seg) + require_value "$1" "${2-}" + REFERENCE_UNIVERSE_SEG="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage + exit 1 + ;; + esac +done + +if [[ -z "${OUT_TSV}" ]]; then + OUT_TSV="${ROOT}/summaries/validation_metrics.tsv" +fi + +if ! 
[[ "${MAX_ME_GENE_PAIRS}" =~ ^[0-9]+$ ]] || [[ "${MAX_ME_GENE_PAIRS}" -le 0 ]]; then + echo "ERROR: --max-me-gene-pairs must be a positive integer." >&2 + exit 1 +fi + +if ! INCLUDE_DEFAULT_10X="$(normalize_bool "${INCLUDE_DEFAULT_10X}")"; then + echo "ERROR: --include-default-10x must be true/false." >&2 + exit 1 +fi + +if [[ -n "${INPUT_DIR}" ]] && [[ ! -e "${INPUT_DIR}" ]]; then + echo "ERROR: --input-dir not found: ${INPUT_DIR}" >&2 + exit 1 +fi + +if [[ -n "${ME_GENE_PAIRS_PATH}" ]] && [[ ! -f "${ME_GENE_PAIRS_PATH}" ]]; then + echo "ERROR: --me-gene-pairs-path not found: ${ME_GENE_PAIRS_PATH}" >&2 + exit 1 +fi + +if [[ -n "${SCRNA_REFERENCE_PATH}" ]] && [[ ! -f "${SCRNA_REFERENCE_PATH}" ]]; then + echo "ERROR: --scrna-reference-path not found: ${SCRNA_REFERENCE_PATH}" >&2 + exit 1 +fi + +if [[ -n "${REFERENCE_UNIVERSE_SEG}" ]] && [[ ! -f "${REFERENCE_UNIVERSE_SEG}" ]]; then + echo "ERROR: --reference-universe-seg not found: ${REFERENCE_UNIVERSE_SEG}" >&2 + exit 1 +fi + +SEGGER_CMD_PATH="" +if [[ "${SEGGER_BIN}" == */* ]]; then + if [[ ! -x "${SEGGER_BIN}" ]]; then + echo "ERROR: --segger-bin is not executable: ${SEGGER_BIN}" >&2 + exit 1 + fi + SEGGER_CMD_PATH="${SEGGER_BIN}" +else + if ! SEGGER_CMD_PATH="$(command -v "${SEGGER_BIN}" 2>/dev/null)"; then + echo "ERROR: segger command not found: ${SEGGER_BIN}" >&2 + exit 1 + fi +fi + +SEGGER_PYTHON="$(dirname "${SEGGER_CMD_PATH}")/python" +if [[ ! -x "${SEGGER_PYTHON}" ]]; then + if command -v python3 >/dev/null 2>&1; then + SEGGER_PYTHON="$(command -v python3)" + else + echo "ERROR: Could not resolve python interpreter for reference artifact builder." >&2 + exit 1 + fi +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REFERENCE_BUILDER_PY="${SCRIPT_DIR}/build_default_10x_reference_artifacts.py" +if [[ "${INCLUDE_DEFAULT_10X}" == "true" ]] && [[ ! -f "${REFERENCE_BUILDER_PY}" ]]; then + echo "ERROR: Missing reference artifact builder: ${REFERENCE_BUILDER_PY}" >&2 + exit 1 +fi + +PLAN_FILE="${ROOT}/job_plan.tsv" +RUNS_DIR="${ROOT}/runs" +EXPORTS_DIR="${ROOT}/exports" +SUMMARY_DIR="${ROOT}/summaries" +LOG_FILE="${SUMMARY_DIR}/validation_metrics.log" +REFERENCE_ARTIFACTS_DIR="${SUMMARY_DIR}/reference_artifacts" + +if [[ ! 
-f "${PLAN_FILE}" ]]; then + echo "ERROR: Missing benchmark plan file: ${PLAN_FILE}" >&2 + exit 1 +fi + +mkdir -p "${SUMMARY_DIR}" "$(dirname "${OUT_TSV}")" "${REFERENCE_ARTIFACTS_DIR}" + +tmp_out="$(mktemp)" +tmp_row="$(mktemp "${TMPDIR:-/tmp}/segger_validate_row.XXXXXX.tsv")" +trap 'rm -f "${tmp_out}" "${tmp_row}"' EXIT + +METRIC_SCHEMA_VERSION="2026-02-25-v8" +RUN_INPUT_DIR_TOKEN="$(normalize_token "${INPUT_DIR}")" +RUN_SCRNA_REFERENCE_TOKEN="$(normalize_token "${SCRNA_REFERENCE_PATH}")" +RUN_ME_GENE_PAIRS_TOKEN="$(normalize_token "${ME_GENE_PAIRS_PATH}")" + +REFERENCE_UNIVERSE_SEG_RESOLVED="" +if [[ -n "${REFERENCE_UNIVERSE_SEG}" ]]; then + REFERENCE_UNIVERSE_SEG_RESOLVED="${REFERENCE_UNIVERSE_SEG}" +elif [[ -f "${RUNS_DIR}/baseline/segger_segmentation.parquet" ]]; then + REFERENCE_UNIVERSE_SEG_RESOLVED="${RUNS_DIR}/baseline/segger_segmentation.parquet" +else + first_run_seg="$(find "${RUNS_DIR}" -mindepth 2 -maxdepth 2 -type f -name 'segger_segmentation.parquet' | sort | head -n1 || true)" + if [[ -n "${first_run_seg}" ]]; then + REFERENCE_UNIVERSE_SEG_RESOLVED="${first_run_seg}" + elif [[ -n "${INPUT_DIR}" ]] && [[ -f "${INPUT_DIR}/segger_segmentation.parquet" ]]; then + REFERENCE_UNIVERSE_SEG_RESOLVED="${INPUT_DIR}/segger_segmentation.parquet" + fi +fi +RUN_REFERENCE_UNIVERSE_TOKEN="$(normalize_token "${REFERENCE_UNIVERSE_SEG_RESOLVED}")" + +OUTPUT_HEADER=$'job\tgroup\tgpu\tis_reference\treference_kind\tvalidate_status\tvalidate_error\tgpu_time_s\tcells\tassigned_pct\tassigned_ci95\tmecr\tmecr_ci95\tcontamination_pct\tcontamination_ci95\tresolvi_contamination_pct\tresolvi_contamination_ci95\ttco\ttco_ci95\tdoublet_pct\tdoublet_ci95\tsegmentation_path\tanndata_path\toutput_path\tupdated_at\tmetric_schema_version\trun_input_dir\trun_scrna_reference_path\trun_me_gene_pairs_path\trun_reference_universe_seg' +reuse_existing=0 +if [[ "${RECOMPUTE}" != "1" ]] && [[ -s "${OUT_TSV}" ]]; then + existing_header="$(head -n1 "${OUT_TSV}" || true)" + if [[ "${existing_header}" == "${OUTPUT_HEADER}" ]]; then + reuse_existing=1 + else + echo "[$(timestamp)] WARN existing TSV header mismatch; recomputing all jobs" >> "${LOG_FILE}" + fi +fi + +get_field() { + local column_name="$1" + if [[ ! 
-s "${tmp_row}" ]]; then + return 0 + fi + awk -F'\t' -v key="${column_name}" ' + NR == 1 { + for (i = 1; i <= NF; i++) { + if ($i == key) { + idx = i + break + } + } + next + } + NR == 2 { + if (idx > 0) { + print $idx + } + exit + } + ' "${tmp_row}" +} + +get_existing_field_by_job() { + local job_name="$1" + local column_name="$2" + awk -F'\t' -v j="${job_name}" -v key="${column_name}" ' + NR == 1 { + for (i = 1; i <= NF; i++) { + if ($i == key) { + idx = i + break + } + } + next + } + $1 == j { + if (idx > 0) { + print $idx + } + exit + } + ' "${OUT_TSV}" +} + +lookup_gpu_time() { + local job_name="$1" + local preferred_gpu="$2" + local elapsed="" + local f row + + if [[ -f "${SUMMARY_DIR}/gpu${preferred_gpu}.tsv" ]]; then + row="$(awk -F'\t' -v j="${job_name}" ' + NR == 1 { + for (i = 1; i <= NF; i++) { + if ($i == "job") job_col = i + if ($i == "elapsed_s") elapsed_col = i + } + next + } + job_col > 0 && $job_col == j { + if (elapsed_col > 0) print $elapsed_col + exit + } + ' "${SUMMARY_DIR}/gpu${preferred_gpu}.tsv")" + if [[ -n "${row}" ]]; then + elapsed="${row}" + fi + fi + + if [[ -z "${elapsed}" ]]; then + for f in "${SUMMARY_DIR}"/gpu*.tsv "${SUMMARY_DIR}/recovery.tsv"; do + [[ -f "${f}" ]] || continue + row="$(awk -F'\t' -v j="${job_name}" ' + NR == 1 { + for (i = 1; i <= NF; i++) { + if ($i == "job") job_col = i + if ($i == "elapsed_s") elapsed_col = i + } + next + } + job_col > 0 && $job_col == j { + if (elapsed_col > 0) print $elapsed_col + exit + } + ' "${f}")" + if [[ -n "${row}" ]]; then + elapsed="${row}" + break + fi + done + fi + + if [[ -z "${elapsed}" ]]; then + elapsed="0" + fi + printf '%s' "${elapsed}" +} + +scale_frac_to_pct() { + local value="${1:-}" + local lower + lower="$(printf '%s' "${value}" | tr '[:upper:]' '[:lower:]')" + if [[ -z "${value}" ]] || [[ "${lower}" == "nan" ]] || [[ "${lower}" == "none" ]]; then + printf '%s' "nan" + return 0 + fi + awk -v v="${value}" 'BEGIN { printf "%.6f", (v + 0.0) * 100.0 }' +} + +append_row() { + local job="$1" + local group="$2" + local gpu="$3" + local is_reference="$4" + local reference_kind="$5" + local validate_status="$6" + local validate_error="$7" + local gpu_time_s="$8" + local cells="$9" + local assigned_pct="${10}" + local assigned_ci95="${11}" + local mecr="${12}" + local mecr_ci95="${13}" + local contamination_pct="${14}" + local contamination_ci95="${15}" + local resolvi_contamination_pct="${16}" + local resolvi_contamination_ci95="${17}" + local tco="${18}" + local tco_ci95="${19}" + local doublet_pct="${20}" + local doublet_ci95="${21}" + local row_seg_path="${22}" + local row_anndata_path="${23}" + + validate_error="$(sanitize_tsv_field "${validate_error}")" + reference_kind="$(normalize_token "${reference_kind}")" + + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job}" "${group}" "${gpu}" "${is_reference}" "${reference_kind}" "${validate_status}" "${validate_error}" "${gpu_time_s}" \ + "${cells}" "${assigned_pct}" "${assigned_ci95}" "${mecr}" "${mecr_ci95}" "${contamination_pct}" "${contamination_ci95}" \ + "${resolvi_contamination_pct}" "${resolvi_contamination_ci95}" "${tco}" "${tco_ci95}" "${doublet_pct}" "${doublet_ci95}" "${row_seg_path}" "${row_anndata_path}" "${OUT_TSV}" \ + "$(timestamp)" "${METRIC_SCHEMA_VERSION}" "${RUN_INPUT_DIR_TOKEN}" "${RUN_SCRNA_REFERENCE_TOKEN}" "${RUN_ME_GENE_PAIRS_TOKEN}" "${RUN_REFERENCE_UNIVERSE_TOKEN}" \ + >> "${tmp_out}" +} + +should_reuse_row() { + local job="$1" + local 
is_reference="$2" + local should_reuse=0 + local reason="" + + if [[ "${reuse_existing}" != "1" ]]; then + printf '0\tno_reuse_mode\n' + return 0 + fi + + existing_row="$(awk -F'\t' -v j="${job}" 'NR > 1 && $1 == j { print; exit }' "${OUT_TSV}")" + if [[ -z "${existing_row}" ]]; then + printf '0\tmissing_existing_row\n' + return 0 + fi + + existing_status="$(get_existing_field_by_job "${job}" "validate_status")" + existing_metric_schema_version="$(get_existing_field_by_job "${job}" "metric_schema_version")" + existing_run_input_dir="$(get_existing_field_by_job "${job}" "run_input_dir")" + existing_run_scrna_ref="$(get_existing_field_by_job "${job}" "run_scrna_reference_path")" + existing_run_me_pairs="$(get_existing_field_by_job "${job}" "run_me_gene_pairs_path")" + existing_run_ref_universe="$(get_existing_field_by_job "${job}" "run_reference_universe_seg")" + existing_cells="$(get_existing_field_by_job "${job}" "cells")" + existing_is_reference="$(get_existing_field_by_job "${job}" "is_reference")" + + should_reuse=1 + reason="ok" + + if [[ "${existing_status}" != "ok" ]]; then + should_reuse=0 + reason="status=${existing_status}" + elif [[ "${existing_metric_schema_version:-}" != "${METRIC_SCHEMA_VERSION}" ]]; then + should_reuse=0 + reason="schema_mismatch" + elif [[ "${existing_run_input_dir:-}" != "${RUN_INPUT_DIR_TOKEN}" ]] || \ + [[ "${existing_run_scrna_ref:-}" != "${RUN_SCRNA_REFERENCE_TOKEN}" ]] || \ + [[ "${existing_run_me_pairs:-}" != "${RUN_ME_GENE_PAIRS_TOKEN}" ]] || \ + [[ "${existing_run_ref_universe:-}" != "${RUN_REFERENCE_UNIVERSE_TOKEN}" ]]; then + should_reuse=0 + reason="validation_inputs_changed" + elif [[ "${existing_is_reference:-0}" != "${is_reference}" ]]; then + should_reuse=0 + reason="reference_flag_changed" + elif [[ -z "${existing_cells:-}" ]] || [[ "$(printf '%s' "${existing_cells}" | tr '[:upper:]' '[:lower:]')" == "nan" ]]; then + should_reuse=0 + reason="missing_cells" + elif [[ -n "${INPUT_DIR}" ]]; then + existing_tco="$(get_existing_field_by_job "${job}" "tco")" + existing_contam="$(get_existing_field_by_job "${job}" "contamination_pct")" + existing_doublet="$(get_existing_field_by_job "${job}" "doublet_pct")" + if [[ -z "${existing_tco:-}" ]] || [[ -z "${existing_contam:-}" ]] || [[ -z "${existing_doublet:-}" ]]; then + should_reuse=0 + reason="missing_input_dependent_metrics" + elif [[ -n "${SCRNA_REFERENCE_PATH}" ]]; then + existing_resolvi="$(get_existing_field_by_job "${job}" "resolvi_contamination_pct")" + if [[ -z "${existing_resolvi:-}" ]]; then + should_reuse=0 + reason="missing_resolvi_metric" + fi + fi + fi + + printf '%s\t%s\n' "${should_reuse}" "${reason}" +} + +run_validate_for_row() { + local job="$1" + local seg_path="$2" + local anndata_path="$3" + local apply_elapsed_fallback="$4" + + validate_status="ok" + validate_error="" + cells="0" + assigned_pct="nan" + assigned_ci95="nan" + mecr="nan" + mecr_ci95="nan" + contamination_pct="nan" + contamination_ci95="nan" + resolvi_contamination_pct="nan" + resolvi_contamination_ci95="nan" + tco="nan" + tco_ci95="nan" + doublet_pct="nan" + doublet_ci95="nan" + row_seg_path="${seg_path}" + row_anndata_path="" + + if [[ ! 
-f "${seg_path}" ]]; then + validate_status="missing_segmentation" + validate_error="segger_segmentation.parquet not found" + return 0 + fi + + : > "${tmp_row}" + cmd=( + "${SEGGER_BIN}" validate + -s "${seg_path}" + -o "${tmp_row}" + --max-me-gene-pairs "${MAX_ME_GENE_PAIRS}" + ) + if [[ -f "${anndata_path}" ]]; then + cmd+=(-a "${anndata_path}") + fi + if [[ -n "${INPUT_DIR}" ]]; then + cmd+=(-i "${INPUT_DIR}") + fi + if [[ -n "${ME_GENE_PAIRS_PATH}" ]]; then + cmd+=(--me-gene-pairs-path "${ME_GENE_PAIRS_PATH}") + fi + if [[ -n "${SCRNA_REFERENCE_PATH}" ]]; then + cmd+=( + --scrna-reference-path "${SCRNA_REFERENCE_PATH}" + --scrna-celltype-column "${SCRNA_CELLTYPE_COLUMN}" + ) + fi + + { + printf '[%s] job=%s CMD:' "$(timestamp)" "${job}" + printf ' %q' "${cmd[@]}" + printf '\n' + } >> "${LOG_FILE}" + + if "${cmd[@]}" >> "${LOG_FILE}" 2>&1; then + parsed_status="$(get_field "validate_status")" + [[ -n "${parsed_status}" ]] && validate_status="${parsed_status}" + parsed_error="$(get_field "validate_error")" + [[ -n "${parsed_error}" ]] && validate_error="${parsed_error}" + + parsed_elapsed="$(get_field "elapsed_s")" + if [[ "${apply_elapsed_fallback}" == "1" ]] && [[ -n "${parsed_elapsed}" ]] && [[ "${gpu_time_s}" == "0" || -z "${gpu_time_s}" ]]; then + gpu_time_s="${parsed_elapsed}" + fi + + parsed_val="$(get_field "cells_total")" + if [[ -n "${parsed_val}" ]] && [[ "$(printf '%s' "${parsed_val}" | tr '[:upper:]' '[:lower:]')" != "nan" ]]; then + cells="${parsed_val}" + else + parsed_val="$(get_field "cells_assigned")" + [[ -n "${parsed_val}" ]] && cells="${parsed_val}" + fi + + parsed_val="$(get_field "transcripts_assigned_pct")" + [[ -n "${parsed_val}" ]] && assigned_pct="${parsed_val}" + parsed_val="$(get_field "transcripts_assigned_pct_ci95")" + [[ -n "${parsed_val}" ]] && assigned_ci95="${parsed_val}" + + parsed_val="$(get_field "mecr_fast")" + [[ -n "${parsed_val}" ]] && mecr="${parsed_val}" + parsed_val="$(get_field "mecr_ci95_fast")" + [[ -n "${parsed_val}" ]] && mecr_ci95="${parsed_val}" + + parsed_val="$(get_field "border_contaminated_cells_pct_fast")" + [[ -n "${parsed_val}" ]] && contamination_pct="${parsed_val}" + parsed_val="$(get_field "border_contaminated_cells_pct_ci95_fast")" + [[ -n "${parsed_val}" ]] && contamination_ci95="${parsed_val}" + + parsed_val="$(get_field "resolvi_contamination_pct_fast")" + [[ -n "${parsed_val}" ]] && resolvi_contamination_pct="${parsed_val}" + parsed_val="$(get_field "resolvi_contamination_ci95_fast")" + [[ -n "${parsed_val}" ]] && resolvi_contamination_ci95="${parsed_val}" + + parsed_val="$(get_field "transcript_centroid_offset_fast")" + [[ -n "${parsed_val}" ]] && tco="${parsed_val}" + parsed_val="$(get_field "transcript_centroid_offset_ci95_fast")" + [[ -n "${parsed_val}" ]] && tco_ci95="${parsed_val}" + + parsed_val="$(get_field "signal_doublet_like_fraction_fast")" + if [[ -n "${parsed_val}" ]]; then + doublet_pct="$(scale_frac_to_pct "${parsed_val}")" + fi + parsed_val="$(get_field "signal_doublet_like_fraction_ci95_fast")" + if [[ -n "${parsed_val}" ]]; then + doublet_ci95="$(scale_frac_to_pct "${parsed_val}")" + fi + + parsed_val="$(get_field "segmentation_path")" + [[ -n "${parsed_val}" ]] && row_seg_path="${parsed_val}" + parsed_val="$(get_field "anndata_path")" + [[ -n "${parsed_val}" ]] && row_anndata_path="${parsed_val}" + else + validate_status="validate_command_failed" + validate_error="segger validate command failed" + fi +} + +printf "%s\n" "${OUTPUT_HEADER}" > "${tmp_out}" + +echo "[$(timestamp)] START benchmark 
validation table build" >> "${LOG_FILE}" +echo "[$(timestamp)] ROOT=${ROOT}" >> "${LOG_FILE}" +echo "[$(timestamp)] OUT_TSV=${OUT_TSV}" >> "${LOG_FILE}" +echo "[$(timestamp)] RECOMPUTE=${RECOMPUTE}" >> "${LOG_FILE}" +echo "[$(timestamp)] INCLUDE_DEFAULT_10X=${INCLUDE_DEFAULT_10X}" >> "${LOG_FILE}" +echo "[$(timestamp)] REFERENCE_UNIVERSE_SEG=${RUN_REFERENCE_UNIVERSE_TOKEN}" >> "${LOG_FILE}" + +reused_count=0 +computed_count=0 + +while IFS=$'\t' read -r \ + job group _use3d _expansion _txk _txdist _layers _heads _cellsmin _minqv _alignment; do + if [[ -z "${job:-}" ]] || [[ "${job}" == "job" ]]; then + continue + fi + + reuse_info="$(should_reuse_row "${job}" "0")" + should_reuse="${reuse_info%%$'\t'*}" + reuse_reason="${reuse_info#*$'\t'}" + if [[ "${should_reuse}" == "1" ]]; then + existing_row="$(awk -F'\t' -v j="${job}" 'NR > 1 && $1 == j { print; exit }' "${OUT_TSV}")" + printf "%s\n" "${existing_row}" >> "${tmp_out}" + reused_count=$((reused_count + 1)) + echo "[$(timestamp)] SKIP job=${job}: existing row reused" >> "${LOG_FILE}" + continue + fi + if [[ "${reuse_existing}" == "1" ]]; then + echo "[$(timestamp)] RECOMPUTE job=${job}: ${reuse_reason}" >> "${LOG_FILE}" + fi + + computed_count=$((computed_count + 1)) + + gpu="${GPU_A}" + if [[ "${group}" == "B" ]]; then + gpu="${GPU_B}" + fi + + seg_path="${RUNS_DIR}/${job}/segger_segmentation.parquet" + anndata_path="${EXPORTS_DIR}/${job}/anndata/segger_segmentation.h5ad" + gpu_time_s="$(lookup_gpu_time "${job}" "${gpu}")" + + start_ts="$(date +%s)" + run_validate_for_row "${job}" "${seg_path}" "${anndata_path}" "1" + end_ts="$(date +%s)" + if [[ "${gpu_time_s}" == "0" || -z "${gpu_time_s}" ]]; then + gpu_time_s="$((end_ts - start_ts))" + fi + + append_row \ + "${job}" "${group}" "${gpu}" "0" "-" \ + "${validate_status}" "${validate_error}" "${gpu_time_s}" "${cells}" \ + "${assigned_pct}" "${assigned_ci95}" "${mecr}" "${mecr_ci95}" \ + "${contamination_pct}" "${contamination_ci95}" "${resolvi_contamination_pct}" "${resolvi_contamination_ci95}" "${tco}" "${tco_ci95}" \ + "${doublet_pct}" "${doublet_ci95}" "${row_seg_path}" "${row_anndata_path}" +done < "${PLAN_FILE}" + +if [[ "${INCLUDE_DEFAULT_10X}" == "true" ]]; then + for reference_kind in 10x_cell 10x_nucleus; do + job="ref_${reference_kind}" + group="R" + gpu="-" + + reuse_info="$(should_reuse_row "${job}" "1")" + should_reuse="${reuse_info%%$'\t'*}" + reuse_reason="${reuse_info#*$'\t'}" + if [[ "${should_reuse}" == "1" ]]; then + existing_row="$(awk -F'\t' -v j="${job}" 'NR > 1 && $1 == j { print; exit }' "${OUT_TSV}")" + printf "%s\n" "${existing_row}" >> "${tmp_out}" + reused_count=$((reused_count + 1)) + echo "[$(timestamp)] SKIP job=${job}: existing row reused" >> "${LOG_FILE}" + continue + fi + if [[ "${reuse_existing}" == "1" ]]; then + echo "[$(timestamp)] RECOMPUTE job=${job}: ${reuse_reason}" >> "${LOG_FILE}" + fi + + computed_count=$((computed_count + 1)) + gpu_time_s="0" + validate_status="ok" + validate_error="" + cells="0" + assigned_pct="nan" + assigned_ci95="nan" + mecr="nan" + mecr_ci95="nan" + contamination_pct="nan" + contamination_ci95="nan" + resolvi_contamination_pct="nan" + resolvi_contamination_ci95="nan" + tco="nan" + tco_ci95="nan" + doublet_pct="nan" + doublet_ci95="nan" + row_seg_path="${REFERENCE_ARTIFACTS_DIR}/${job}/segger_segmentation.parquet" + row_anndata_path="${REFERENCE_ARTIFACTS_DIR}/${job}/segger_segmentation.h5ad" + + if [[ -z "${INPUT_DIR}" ]]; then + validate_status="missing_input_dir" + validate_error="--input-dir is required for default 10x 
references" + elif [[ -z "${REFERENCE_UNIVERSE_SEG_RESOLVED}" ]] || [[ ! -f "${REFERENCE_UNIVERSE_SEG_RESOLVED}" ]]; then + validate_status="missing_universe_segmentation" + validate_error="canonical Segger universe segmentation not found" + else + mkdir -p "$(dirname "${row_seg_path}")" + build_cmd=( + "${SEGGER_PYTHON}" "${REFERENCE_BUILDER_PY}" + --input-dir "${INPUT_DIR}" + --canonical-seg "${REFERENCE_UNIVERSE_SEG_RESOLVED}" + --kind "${reference_kind}" + --out-seg "${row_seg_path}" + --out-h5ad "${row_anndata_path}" + ) + + { + printf '[%s] job=%s BUILD_REF_CMD:' "$(timestamp)" "${job}" + printf ' %q' "${build_cmd[@]}" + printf '\n' + } >> "${LOG_FILE}" + + if "${build_cmd[@]}" >> "${LOG_FILE}" 2>&1; then + run_validate_for_row "${job}" "${row_seg_path}" "${row_anndata_path}" "0" + else + validate_status="reference_artifact_failed" + validate_error="failed to build default 10x reference artifacts" + fi + fi + + append_row \ + "${job}" "${group}" "${gpu}" "1" "${reference_kind}" \ + "${validate_status}" "${validate_error}" "${gpu_time_s}" "${cells}" \ + "${assigned_pct}" "${assigned_ci95}" "${mecr}" "${mecr_ci95}" \ + "${contamination_pct}" "${contamination_ci95}" "${resolvi_contamination_pct}" "${resolvi_contamination_ci95}" "${tco}" "${tco_ci95}" \ + "${doublet_pct}" "${doublet_ci95}" "${row_seg_path}" "${row_anndata_path}" + done +fi + +tmp_sorted="$(mktemp)" +{ + head -n1 "${tmp_out}" + tail -n +2 "${tmp_out}" \ + | awk -F'\t' ' + function norm_num(v, lower) { + lower = tolower(v) + if (v == "" || lower == "nan" || lower == "none") return "" + return v + 0.0 + } + { + assigned = norm_num($10) + mecr = norm_num($12) + if (assigned == "") assigned_key = -1 + else assigned_key = assigned + if (mecr == "") mecr_key = 1e99 + else mecr_key = mecr + printf "%.10f\t%.10f\t%s\n", assigned_key, mecr_key, $0 + } + ' \ + | sort -t $'\t' -k1,1gr -k2,2g -k3,3 \ + | cut -f3- +} > "${tmp_sorted}" +mv "${tmp_sorted}" "${tmp_out}" + +mv "${tmp_out}" "${OUT_TSV}" +echo "[$(timestamp)] WROTE validation table: ${OUT_TSV}" >> "${LOG_FILE}" +echo "[$(timestamp)] SUMMARY reused=${reused_count} computed=${computed_count}" >> "${LOG_FILE}" +echo "Wrote validation table: ${OUT_TSV} (reused=${reused_count}, computed=${computed_count})" diff --git a/scripts/build_default_10x_reference_artifacts.py b/scripts/build_default_10x_reference_artifacts.py new file mode 100755 index 0000000..ddc5787 --- /dev/null +++ b/scripts/build_default_10x_reference_artifacts.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +"""Build temporary 10x reference segmentation + AnnData on Segger row_index universe.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +import anndata as ad +import numpy as np +import pandas as pd +import polars as pl +from scipy import sparse as sp + + +def _pick_column(columns: list[str], candidates: list[str], required: bool = True) -> str | None: + for c in candidates: + if c in columns: + return c + if required: + raise ValueError(f"Missing required column; tried: {candidates}") + return None + + +def _clean_cell_id_expr(cell_col: str) -> pl.Expr: + cell_str = pl.col(cell_col).cast(pl.Utf8) + return ( + pl.when( + pl.col(cell_col).is_null() + | (cell_str == "") + | cell_str.str.to_uppercase().is_in(["UNASSIGNED", "NONE", "-1"]) + ) + .then(None) + .otherwise(cell_str) + ) + + +def _nucleus_overlap_expr(overlap_col: str) -> pl.Expr: + overlap_str = pl.col(overlap_col).cast(pl.Utf8).str.to_lowercase() + return overlap_str.is_in(["1", "2", "true", "t", "yes", 
"y", "nuclear"]) + + +def _build_anndata(tx: pl.DataFrame, out_h5ad: Path) -> tuple[int, int]: + if tx.height == 0: + adata = ad.AnnData( + X=sp.csr_matrix((0, 0)), + obs=pd.DataFrame(index=pd.Index([], name="segger_cell_id")), + var=pd.DataFrame(index=pd.Index([], name="feature_name")), + ) + adata.write_h5ad(out_h5ad, compression="gzip", compression_opts=4) + return 0, 0 + + assigned = tx.filter( + pl.col("segger_cell_id").is_not_null() & pl.col("feature_name").is_not_null() + ) + if assigned.height == 0: + genes = ( + tx.select("feature_name") + .drop_nulls() + .unique() + .sort("feature_name") + .get_column("feature_name") + .cast(pl.Utf8) + .to_list() + ) + adata = ad.AnnData( + X=sp.csr_matrix((0, len(genes))), + obs=pd.DataFrame(index=pd.Index([], name="segger_cell_id")), + var=pd.DataFrame(index=pd.Index([str(g) for g in genes], name="feature_name")), + ) + adata.write_h5ad(out_h5ad, compression="gzip", compression_opts=4) + return 0, len(genes) + + feature_idx = ( + assigned.select("feature_name") + .with_columns(pl.col("feature_name").cast(pl.Utf8)) + .unique() + .sort("feature_name") + .with_row_index(name="_fid") + ) + cell_idx = ( + assigned.select("segger_cell_id") + .with_columns(pl.col("segger_cell_id").cast(pl.Utf8)) + .unique() + .sort("segger_cell_id") + .with_row_index(name="_cid") + ) + + mapped = assigned.join(feature_idx, on="feature_name").join(cell_idx, on="segger_cell_id") + counts = mapped.group_by(["_cid", "_fid"]).agg(pl.len().alias("_count")) + + ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T + rows = ijv[0].astype(np.int64, copy=False) + cols = ijv[1].astype(np.int64, copy=False) + data = ijv[2].astype(np.int64, copy=False) + + X = sp.coo_matrix((data, (rows, cols)), shape=(cell_idx.height, feature_idx.height)).tocsr() + adata = ad.AnnData( + X=X, + obs=pd.DataFrame(index=pd.Index(cell_idx.get_column("segger_cell_id").to_list(), name="segger_cell_id")), + var=pd.DataFrame(index=pd.Index(feature_idx.get_column("feature_name").to_list(), name="feature_name")), + ) + + coord_cols = [c for c in ["x", "y", "z"] if c in assigned.columns] + if "x" in coord_cols and "y" in coord_cols: + centroids = ( + assigned.group_by("segger_cell_id") + .agg([pl.col(c).mean().alias(c) for c in coord_cols]) + .to_pandas() + .set_index("segger_cell_id") + .reindex(adata.obs.index) + ) + adata.obsm["X_spatial"] = centroids[coord_cols].to_numpy() + + adata.write_h5ad(out_h5ad, compression="gzip", compression_opts=4) + return int(adata.n_obs), int(adata.n_vars) + + +def build_reference_artifacts( + *, + input_dir: Path, + canonical_seg: Path, + kind: str, + out_seg: Path, + out_h5ad: Path, +) -> dict[str, object]: + if kind not in {"10x_cell", "10x_nucleus"}: + raise ValueError("kind must be one of: 10x_cell, 10x_nucleus") + + tx_path = input_dir / "transcripts.parquet" + if not tx_path.exists(): + raise FileNotFoundError(f"transcripts.parquet not found under input dir: {tx_path}") + + if not canonical_seg.exists(): + raise FileNotFoundError(f"Canonical universe segmentation missing: {canonical_seg}") + + universe = pl.read_parquet(canonical_seg).select(pl.col("row_index").cast(pl.Int64)).unique().sort("row_index") + + tx_lf = pl.scan_parquet(tx_path, parallel="row_groups").with_row_index(name="row_index") + schema_names = tx_lf.collect_schema().names() + + feature_col = _pick_column(schema_names, ["feature_name", "target", "gene"]) # Xenium/CosMX/MERSCOPE + x_col = _pick_column(schema_names, ["x_location", "x", "global_x", "x_global_px"]) + y_col = 
_pick_column(schema_names, ["y_location", "y", "global_y", "y_global_px"]) + z_col = _pick_column(schema_names, ["z_location", "z", "global_z"], required=False) + cell_col = _pick_column(schema_names, ["cell_id", "cell"]) + overlap_col = _pick_column(schema_names, ["overlaps_nucleus", "cell_compartment", "CellComp"], required=False) + + select_cols = ["row_index", feature_col, x_col, y_col, cell_col] + if z_col is not None: + select_cols.append(z_col) + if overlap_col is not None: + select_cols.append(overlap_col) + + tx_raw = tx_lf.select(select_cols) + universe_tx = universe.lazy().join(tx_raw, on="row_index", how="left") + + universe_tx = universe_tx.with_columns( + _clean_cell_id_expr(cell_col).alias("_cell_id_clean"), + ) + if kind == "10x_nucleus": + if overlap_col is None: + raise ValueError("Cannot build 10x_nucleus reference: overlaps_nucleus-like column not found") + universe_tx = universe_tx.with_columns( + pl.when(_nucleus_overlap_expr(overlap_col)).then(pl.col("_cell_id_clean")).otherwise(None).alias("segger_cell_id") + ) + else: + universe_tx = universe_tx.with_columns(pl.col("_cell_id_clean").alias("segger_cell_id")) + + rename_map = { + feature_col: "feature_name", + x_col: "x", + y_col: "y", + } + if z_col is not None: + rename_map[z_col] = "z" + + tx = ( + universe_tx + .select(["row_index", "segger_cell_id", *list(rename_map.keys())]) + .rename(rename_map) + .collect() + ) + + out_seg.parent.mkdir(parents=True, exist_ok=True) + out_h5ad.parent.mkdir(parents=True, exist_ok=True) + + seg_df = tx.select(["row_index", "segger_cell_id"]).with_columns(pl.lit(1.0).alias("segger_similarity")) + seg_df.write_parquet(out_seg) + + n_cells, n_genes = _build_anndata(tx, out_h5ad) + + assigned_n = int(tx.filter(pl.col("segger_cell_id").is_not_null()).height) + summary = { + "kind": kind, + "canonical_seg": str(canonical_seg), + "input_transcripts": str(tx_path), + "rows_universe": int(tx.height), + "rows_assigned": assigned_n, + "cells": n_cells, + "genes": n_genes, + "out_seg": str(out_seg), + "out_h5ad": str(out_h5ad), + } + return summary + + +def main() -> int: + parser = argparse.ArgumentParser(description="Build default 10x reference artifacts on Segger universe") + parser.add_argument("--input-dir", type=Path, required=True) + parser.add_argument("--canonical-seg", type=Path, required=True) + parser.add_argument("--kind", choices=["10x_cell", "10x_nucleus"], required=True) + parser.add_argument("--out-seg", type=Path, required=True) + parser.add_argument("--out-h5ad", type=Path, required=True) + args = parser.parse_args() + + summary = build_reference_artifacts( + input_dir=args.input_dir, + canonical_seg=args.canonical_seg, + kind=args.kind, + out_seg=args.out_seg, + out_h5ad=args.out_h5ad, + ) + print(json.dumps(summary, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/presentation/experiments.md b/scripts/presentation/experiments.md new file mode 100644 index 0000000..8706e52 --- /dev/null +++ b/scripts/presentation/experiments.md @@ -0,0 +1,361 @@ +# Segger v0.2.0 Experiment Design + +## Overview + +The `scripts/` directory contains a fully automated benchmarking pipeline that trains, predicts, exports, validates, and reports on Segger segmentation results across a systematic grid of hyperparameter configurations. The dataset is **Xenium pancreas (Mossi)** and the pipeline runs overnight on **N GPUs** (auto-detected or overridable). + +The experiments answer three questions: + +1. 
**Which hyperparameters matter most for segmentation quality?** (parameter sensitivity) +2. **Are the best configurations stable and robust to perturbations?** (repeatability & robustness) +3. **Which architectural components and loss terms are actually necessary?** (ablation study) + +--- + +## Scripts at a Glance + +| Script | Role | +|--------|------| +| `run_param_benchmark_2gpu.sh` | **Parameter sweep** -- one-factor-at-a-time around a baseline | +| `run_robustness_ablation_2gpu.sh` | **Robustness & ablation** -- stability repeats, interaction grid, stress tests | +| `run_ablation_study.sh` | **Component ablation** -- loss decomposition, architecture, features (auto-detects GPUs) | +| `build_default_10x_reference_artifacts.py` | Build 10x cell/nucleus **reference segmentations** for comparison | +| `build_benchmark_validation_table.sh` | Compute **validation metrics** (MECR, contamination, TCO, doublet, assignment %) for every run | +| `benchmark_status_dashboard.sh` | Live **terminal dashboard** showing progress, failures, and ranked metrics | +| `build_benchmark_pdf_report.py` | Generate a **multi-page PDF** with bar charts, scatter plots, heatmaps, UMAPs, and FOV panels | + +--- + +## Experiment 1: Parameter Sweep (`run_param_benchmark_2gpu.sh`) + +### What it does + +Runs a **one-factor-at-a-time (OFAT)** sensitivity analysis. Starting from a fixed baseline configuration, it varies exactly one parameter at a time while holding all others constant. This isolates the marginal effect of each parameter on segmentation quality. + +### Baseline configuration + +| Parameter | Value | Meaning | +|-----------|-------|---------| +| `use_3d` | `true` | Include z-coordinate in graph construction | +| `expansion_ratio` | `2.0` | Scale factor for boundary polygons (captures edge transcripts) | +| `tx_max_k` | `5` | Max k-nearest-neighbors per transcript node | +| `tx_max_dist` | `5` | Max distance (microns) for tx-tx edges | +| `n_mid_layers` | `2` | Number of GNN message-passing layers | +| `n_heads` | `2` | Number of attention heads in the transformer encoder | +| `cells_min_counts` | `5` | Minimum transcripts per cell to include | +| `alignment_loss` | `true` | Enable ME-gene alignment loss from scRNA reference | + +### Sweep axes + +Each axis varies one parameter while the rest stay at baseline: + +| Axis | Values tested | Why | +|------|--------------|-----| +| **use_3d** | `false`, `true` | Does z-coordinate information improve segmentation? Xenium captures z-stacks, but not all platforms do. Tests whether the model benefits from 3D spatial context. | +| **expansion_ratio** | 1.0, 1.5, 2.0, 2.5, 3.0 | Controls how far beyond the nucleus boundary Segger looks for transcripts. Too small = misses cytoplasmic transcripts (low sensitivity). Too large = captures transcripts from neighboring cells (low specificity / high contamination). | +| **tx_max_k** | 5, 10, 20 | How many transcript neighbors each node connects to. More neighbors = richer local context for the GNN but higher memory/compute cost and potential for over-smoothing. | +| **tx_max_dist** | 3, 5, 10, 20 | Maximum edge distance for tx-tx connections. Interacts with tissue density -- sparse tissue needs larger distances, dense tissue smaller. Affects the receptive field of the GNN. | +| **n_mid_layers** | 1, 2, 3 | Depth of the GNN. Deeper = larger receptive field but risk of over-smoothing. For cell-level segmentation, 1-3 layers is the typical range. | +| **n_heads** | 2, 4, 8 | Attention heads in the IST encoder. 
More heads = more representational capacity, but diminishing returns and higher cost. | +| **cells_min_counts** | 3, 5, 10 | Minimum transcript threshold to define a cell. Lower = more cells detected (potentially fragments/noise). Higher = stricter, fewer false-positive cells. | +| **alignment_loss** | `false`, `true` | Whether to add the ME-gene constraint loss during training. This loss penalizes co-expression of mutually exclusive genes within the same cell, leveraging scRNA-seq priors. | + +### Logic + +The OFAT design keeps the experiment tractable (approximately 20 jobs instead of a combinatorial explosion of hundreds). Each comparison has a single variable, so any metric change can be attributed directly to that parameter. The trade-off is that OFAT misses parameter interactions -- that's what Experiment 2 addresses. + +### Total jobs + +~20 runs (1 baseline + ~19 single-parameter variants), split across 2 GPUs in round-robin order. Each run: + +1. **Trains** Segger for 20 epochs +2. **Predicts** cell assignments +3. **Exports** to AnnData (.h5ad) and Xenium Explorer format +4. If training OOMs during prediction, falls back to the last checkpoint +5. If an `ancdata` multiprocessing error occurs, retries with 0 workers + +--- + +## Experiment 2: Robustness & Ablation (`run_robustness_ablation_2gpu.sh`) + +### What it does + +Three study blocks that go beyond single-parameter variation: + +### Block A: Stability / Repeatability + +Runs the **same** configuration multiple times (default 3 repeats) to measure run-to-run variance: + +| Config | Repeats | Purpose | +|--------|---------|---------| +| Baseline (legacy) | 3 | Is the original configuration stable across random seeds/initialization? | +| Anchor (current best) | 3 | Is the improved configuration equally stable? | +| High-sensitivity variant | 2 | Does pushing expansion to 3.0 produce consistent results? | + +**Why:** GNN training involves stochastic initialization and mini-batch sampling. If the same hyperparameters produce wildly different metrics across runs, we can't trust the parameter sweep results. Stability repeats give us error bars. + +The **anchor** configuration (expansion=2.5, tx_dist=20, n_heads=4) represents the "current best" derived from early validation trends, distinct from the legacy baseline. + +### Block B: Interaction Grid + +Tests **combinations** of the most impactful parameters simultaneously: + +- `expansion_ratio` x `tx_max_dist` x `n_heads` (2 x 2 x 2 = 8 combinations with alignment=true) +- Plus alignment ablation at each corner (expansion x dist with heads=4, alignment=false): 4 more jobs + +**Why:** The OFAT sweep can't detect interactions. For example, a larger expansion ratio might only help when combined with a larger tx_max_dist (because expanded boundaries need longer-range edges to connect properly). This grid covers the "high-performing region" identified by the sweep -- it's not exhaustive but targets where interactions are most likely to matter. + +The alignment ablation within the grid specifically tests: **does the ME-gene loss help or hurt at different graph configurations?** If the loss only helps in some regimes, that's critical to know. + +### Block C: Stress Tests + +Deliberately pushes single parameters to **extreme or degraded** values to test robustness: + +| Test | What changes | Why | +|------|-------------|-----| +| `stress_use3d_false_anchor` | Drops z from anchor config | Does the anchor config fall apart without 3D? 
| +| `stress_use3d_false_sens` | Drops z from high-sensitivity config | Same for the aggressive expansion config | +| `stress_cellsmin3_anchor` | Very permissive cell threshold | How noisy do results get with loose filters? | +| `stress_cellsmin10_anchor` | Strict cell threshold | How many cells/transcripts do we lose? | +| `stress_txk20_anchor` | Very dense transcript graph | Does OOM or over-smoothing kick in? | +| `stress_layers1_anchor` | Minimal GNN depth | Can a single layer still segment well? | + +**Why:** Practical deployment means users may have different tissue types, densities, and platform configurations. Stress tests reveal how gracefully the model degrades when conditions shift from the ideal. + +### Total jobs + +~24-28 runs (8 stability + 12 interaction + 6 stress), again split round-robin across 2 GPUs. + +--- + +## Experiment 3: Component Ablation (`run_ablation_study.sh`) + +### What it does + +Systematically removes or swaps individual components -- loss terms, architecture choices, and feature representations -- to measure each one's contribution to segmentation quality. Unlike the OFAT parameter sweep (Experiment 1) which varies continuous hyperparameters, this script tests **discrete design decisions**: "Is this component necessary?" + +### Key differences from the other benchmark scripts + +| Feature | `run_param_benchmark_2gpu.sh` / `run_robustness_ablation_2gpu.sh` | `run_ablation_study.sh` | +|---------|------------------------------------------------------------------|------------------------| +| GPU handling | Hardcoded 2 GPUs (`GPU_A`, `GPU_B`) | Auto-detects N GPUs (round-robin distribution) | +| Job spec | 10 pipe-delimited fields | 21 fields (adds loss weights, architecture, features, LR) | +| Focus | Hyperparameter sensitivity & stability | Component necessity & design decisions | +| Block toggles | `RUN_INTERACTION_GRID`, `RUN_STRESS_TESTS` | 6 toggles: `RUN_LOSS_ABLATION`, `RUN_SGLOSS_ABLATION`, `RUN_ALIGNMENT_SWEEP`, `RUN_ARCH_ABLATION`, `RUN_PREDICTION_ABLATION`, `RUN_LR_ABLATION` | + +### Anchor configuration + +All ablation jobs start from the current best ("anchor") configuration and modify exactly one aspect: + +| Parameter | Anchor value | +|-----------|-------------| +| `use_3d` | `true` | +| `expansion_ratio` | `2.5` | +| `tx_max_k` | `5` | +| `tx_max_dist` | `20` | +| `n_mid_layers` | `2` | +| `n_heads` | `4` | +| `hidden_channels` / `out_channels` | `64` / `64` | +| `sg_loss_type` | `triplet` | +| `tx_weight_end` / `bd_weight_end` / `sg_weight_end` | `1.0` / `1.0` / `0.5` | +| `alignment_loss` | `true` (weight `0.03`) | +| `positional_embeddings` | `true` | +| `normalize_embeddings` | `true` | +| `cells_representation` | `pca` | +| `learning_rate` | `1e-3` | +| `prediction_mode` | `nucleus` | + +### Block A: Loss Decomposition (6 jobs) + +Tests every meaningful subset of the 4 loss terms to determine which are necessary: + +| Job | sg_loss | tx_triplet | bd_metric | alignment | Question | +|-----|:-------:|:----------:|:---------:|:---------:|----------| +| `abl_sg_only` | ON | - | - | - | Is the segmentation loss alone sufficient? | +| `abl_sg_tx` | ON | ON | - | - | Does transcript clustering help? | +| `abl_sg_bd` | ON | - | ON | - | Does boundary clustering help? | +| `abl_sg_tx_bd` | ON | ON | ON | - | Full v1 loss (no alignment) -- the pre-alignment baseline | +| `abl_sg_align` | ON | - | - | ON | Can alignment replace the triplet losses? 
| +| `abl_full` | ON | ON | ON | ON | Anchor baseline -- should be best | + +**Why:** The multi-task loss has 4 components with scheduled weights. We don't know if the triplet/metric losses for transcript and boundary clustering are necessary, or if they just slow convergence. If `sg_only` performs nearly as well as `full`, the training pipeline can be simplified significantly. + +### Block B: Segmentation Loss Type (2 jobs) + +| Job | sg_loss_type | Notes | +|-----|-------------|-------| +| `abl_sgloss_triplet` | triplet | Current default (margin-based) | +| `abl_sgloss_bce` | bce | Binary cross-entropy (v0.1.0 approach) | + +**Why:** Direct comparison on the same data reveals which formulation produces better assignment boundaries. + +### Block C: Alignment Weight Sweep (5 jobs) + +| Job | alignment_weight_end | Notes | +|-----|---------------------|-------| +| `abl_aw_0` | 0.0 | No alignment (control) | +| `abl_aw_001` | 0.01 | Light regularization | +| `abl_aw_003` | 0.03 | Current default | +| `abl_aw_01` | 0.1 | Strong regularization | +| `abl_aw_03` | 0.3 | Very strong -- may over-regularize | + +**Why:** The alignment loss weight was chosen somewhat arbitrarily. This sweep identifies the sweet spot. If 0.1 beats 0.03 on MECR without hurting assigned %, we should increase it. + +### Block D: Architecture Ablation (10 jobs) + +| Job | What changes | Question | +|-----|-------------|----------| +| `abl_depth_0` | 0 mid layers (in+out only) | Can a non-message-passing encoder segment? | +| `abl_depth_1` | 1 mid layer | Is 2 layers deeper than needed? | +| `abl_depth_3` | 3 mid layers | Does more depth help or over-smooth? | +| `abl_width_32` | 32/32 hidden/out | Can a 4x smaller model match the default? | +| `abl_width_128` | 128/128 hidden/out | Does 4x more capacity help? | +| `abl_heads_1` | 1 attention head | Is multi-head attention necessary? | +| `abl_heads_8` | 8 attention heads | Diminishing returns from more heads? | +| `abl_no_pos` | No positional embeddings | Are spatial encodings redundant given graph structure? | +| `abl_no_norm` | No embedding normalization | Does L2 normalization help or constrain? | +| `abl_morph` | Morphology cell features | Are polygon-derived features better than PCA? | + +**Why:** Each tests whether a specific design choice is earning its complexity. Findings directly inform model simplification or capacity recommendations. + +### Block E: Prediction Mode (2 jobs) + +| Job | prediction_mode | Notes | +|-----|----------------|-------| +| `abl_pred_cell` | cell | All transcripts within cell boundary for training edges | +| `abl_pred_uniform` | uniform | Uniform sampling around boundary | + +The anchor uses `nucleus` mode. These test whether alternative prediction graph construction strategies improve or degrade quality. + +### Block F: Learning Rate (3 jobs) + +| Job | learning_rate | Notes | +|-----|--------------|-------| +| `abl_lr_3e4` | 3e-4 | Conservative (slower convergence) | +| `abl_lr_3e3` | 3e-3 | Aggressive (faster, riskier) | +| `abl_lr_1e2` | 1e-2 | Very aggressive (may diverge) | + +The anchor uses `1e-3`. This identifies whether the learning rate is well-tuned or if training could be faster. + +### Total jobs + +**28 jobs** across 6 blocks. Each block can be toggled independently. Fits in one overnight session on 2+ GPUs. + +### GPU auto-detection + +The script automatically detects available GPUs: + +1. If `CUDA_VISIBLE_DEVICES` is set, counts the comma-separated IDs +2. Otherwise, queries `nvidia-smi --list-gpus` +3. 
Falls back to 1 GPU if neither is available +4. Can be overridden with `NUM_GPUS=N` + +Jobs are distributed round-robin across all detected GPUs and launched as parallel background processes. + +### Usage + +```bash +# Full ablation (auto-detect GPUs) +bash scripts/run_ablation_study.sh + +# Dry run -- prints job plan and exits +DRY_RUN=1 bash scripts/run_ablation_study.sh + +# Run only loss and architecture blocks on 4 GPUs +RUN_SGLOSS_ABLATION=0 RUN_ALIGNMENT_SWEEP=0 RUN_PREDICTION_ABLATION=0 \ +RUN_LR_ABLATION=0 NUM_GPUS=4 bash scripts/run_ablation_study.sh + +# Override anchor values +ANCHOR_N_HEADS=2 ANCHOR_EXPANSION=3.0 bash scripts/run_ablation_study.sh +``` + +### Recovery and fault tolerance + +Identical to the other benchmark scripts: OOM predict fallback, ancdata retry with reduced workers, timeout enforcement, and a post-run recovery pass that attempts predict-only from saved checkpoints. + +--- + +## Validation Metrics (`build_benchmark_validation_table.sh`) + +After all runs complete, this script calls `segger validate` on every segmentation output and collects metrics into a single TSV: + +| Metric | Direction | What it measures | +|--------|-----------|-----------------| +| **assigned_pct** | higher is better | Fraction of transcripts assigned to a cell (sensitivity) | +| **MECR** | lower is better | Mutually Exclusive Co-expression Rate -- are biologically impossible gene pairs showing up in the same cell? (specificity) | +| **contamination_pct** | lower is better | Fraction of cells with border contamination from neighbors | +| **TCO** | higher is better | Transcript-Centroid Offset -- how well do assigned transcripts cluster toward their cell center | +| **doublet_pct** | lower is better | Fraction of cells that look like merged doublets | + +### Reference baselines + +The script also builds **10x default segmentations** (cell-level and nucleus-only) on the same transcript universe as Segger. This provides an apples-to-apples comparison: "How does Segger compare to the manufacturer's built-in segmentation?" + +`build_default_10x_reference_artifacts.py` handles this by: +1. Reading the raw `transcripts.parquet` from the Xenium dataset +2. Filtering to the same `row_index` universe that Segger used +3. Using the 10x `cell_id` column (all transcripts) or `overlaps_nucleus` (nuclear only) as the assignment +4. Building a matching AnnData for metric computation + +### Incremental computation + +The validation table is **incremental** -- it reuses existing rows if the metric schema version, input paths, and reference universe haven't changed. This means you can re-run after fixing a failed job without recomputing everything. + +--- + +## Dashboard (`benchmark_status_dashboard.sh`) + +A terminal tool that reads the job plan, GPU summary files, and validation TSV to show: + +- Progress bar (done/total) +- State counts (running, pending, failed, done) +- Failure categorization (OOM, timeout, ancdata errors) +- Ranked validation metrics table with bold highlighting on top-2 performers +- Running/failed/retried job details + +Supports `--watch N` for auto-refresh every N seconds during overnight runs. + +--- + +## PDF Report (`build_benchmark_pdf_report.py`) + +Generates a publication-quality multi-page PDF: + +| Page | Content | +|------|---------| +| **Bar charts** | All 6 metrics side-by-side for every run, ranked by overall score. Segger runs in a blue gradient (darker = better), 10x references in orange. 
| +| **Scatter plots** | Sensitivity vs Contamination, Sensitivity vs MECR -- visualizes the trade-off frontier. | +| **Heatmap** | Normalized 0-1 metric matrix across all runs (cividis colormap). Quick visual comparison. | +| **UMAP panels** | 6 panels showing cell embedding structure for 2 references, 2 best Segger, 2 worst Segger runs. Uses scanpy or sg_utils for dimensionality reduction. | +| **FOV panels** | Small field-of-view cutouts showing actual cell boundaries (convex hulls) overlaid on transcript positions. Compares how different configurations segment the same tissue region. | + +--- + +## Why This Experimental Design + +### The scientific logic + +Segger frames cell segmentation as **link prediction on a heterogeneous graph**. The quality of segmentation depends on: + +1. **Graph topology** -- which transcripts and boundaries are connected (controlled by `expansion_ratio`, `tx_max_k`, `tx_max_dist`, `use_3d`) +2. **Model capacity** -- how the GNN processes the graph (controlled by `n_mid_layers`, `n_heads`) +3. **Training signal** -- what the loss function optimizes (controlled by `alignment_loss`) +4. **Post-processing** -- how results are filtered (controlled by `cells_min_counts`) + +The experiments systematically vary each of these four aspects: + +- **OFAT sweep** identifies which knobs matter most (often expansion ratio and tx_max_dist dominate) +- **Interaction grid** checks if the top parameters synergize or conflict +- **Stability repeats** quantify noise so we know if a 2% MECR improvement is real or random +- **Stress tests** reveal failure modes for the recommended configuration +- **10x references** provide a competitive baseline -- "is Segger actually better?" + +### The engineering logic + +- **N-GPU parallelism** distributes jobs round-robin across all available GPUs (auto-detected or overridable) +- **OOM fallback** (predict from last checkpoint) salvages partially-trained runs +- **Ancdata retry** (reduce dataloader workers) handles a known PyTorch multiprocessing bug +- **Timeout enforcement** (90 min default) prevents a single hung job from blocking the entire queue +- **Post-run recovery pass** catches jobs that failed during prediction but left a usable checkpoint +- **Incremental validation** avoids recomputing expensive metrics when only a few jobs were re-run +- **TSV-based outputs** integrate easily with downstream analysis notebooks diff --git a/scripts/presentation/experiments_plan.md b/scripts/presentation/experiments_plan.md new file mode 100644 index 0000000..85a20c6 --- /dev/null +++ b/scripts/presentation/experiments_plan.md @@ -0,0 +1,327 @@ +# Ablation & Extended Experiment Plan + +## Motivation + +The current benchmarks (see `experiments.md`) answer "which hyperparameter values work best?" via OFAT sweeps and interaction grids. What they **don't** answer is: + +- Which **architectural components** are actually necessary? +- Which **loss terms** contribute signal vs add noise? +- Does the model **generalize** across tissues and platforms? +- Where are the **failure modes**? + +This plan proposes a structured ablation study organized into 5 tiers, from highest expected impact to exploratory. Each experiment isolates one design decision and measures the delta on our 5 validation metrics (assigned %, MECR, contamination, TCO, doublet %). + +--- + +## Tier 1: Loss Function Ablation + +These are the most informative experiments because they test whether each loss term is earning its weight. + +### 1A. 
Full loss decomposition + +Train with every possible subset of the 4 loss terms: + +| Experiment | tx_triplet | bd_metric | sg_loss | alignment | Expected insight | +|-----------|:---:|:---:|:---:|:---:|---| +| `abl_sg_only` | - | - | ON | - | Minimum viable loss -- is the segmentation loss alone sufficient? | +| `abl_sg+tx` | ON | - | ON | - | Does transcript clustering help? | +| `abl_sg+bd` | - | ON | ON | - | Does boundary clustering help? | +| `abl_sg+tx+bd` | ON | ON | ON | - | Full v1 loss (no alignment) -- the pre-alignment baseline | +| `abl_full` | ON | ON | ON | ON | Current default -- should be best, or alignment is hurting | +| `abl_sg+align` | - | - | ON | ON | Can alignment replace the triplet losses entirely? | + +**Why:** The multi-task loss has 4 components with scheduled weights. We don't know if the triplet/metric losses for transcript and boundary clustering are actually necessary, or if they just slow convergence. If `sg_only` performs nearly as well as `full`, we can simplify the training pipeline significantly. + +### 1B. Segmentation loss type + +| Experiment | sg_loss_type | Notes | +|-----------|-------------|-------| +| `abl_sg_triplet` | triplet | Current default (margin-based) | +| `abl_sg_bce` | bce | Binary cross-entropy with random negatives | + +**Why:** The code supports both but defaults to triplet. BCE was the v0.1.0 approach. Direct comparison on the same data reveals which formulation produces better assignment boundaries. + +### 1C. Alignment loss strength + +| Experiment | alignment_weight_end | Notes | +|-----------|---------------------|-------| +| `abl_align_0` | 0.0 | No alignment (control) | +| `abl_align_001` | 0.01 | Light regularization | +| `abl_align_003` | 0.03 | Current default | +| `abl_align_01` | 0.1 | Strong regularization | +| `abl_align_03` | 0.3 | Very strong -- likely over-regularizes | + +**Why:** The alignment loss weight was chosen somewhat arbitrarily. This sweep identifies the sweet spot. If 0.1 beats 0.03 on MECR without hurting assigned %, we should increase it. If 0.01 is equivalent to 0.03, we're wasting gradient signal. + +### 1D. Loss weight schedule + +| Experiment | Schedule | Notes | +|-----------|---------|-------| +| `abl_sched_cosine` | Cosine ramp (current) | sg: 0 -> 0.5 over training | +| `abl_sched_fixed` | Fixed weights | sg: 0.5 from epoch 0 | +| `abl_sched_linear` | Linear ramp | sg: 0 -> 0.5 linearly | +| `abl_sched_late` | Late activation | sg: 0 for first 50% of training, then 0.5 | + +**Why:** The cosine ramp was designed to let the encoder warm up before the segmentation loss kicks in. If fixed weights perform equally, the schedule adds unnecessary complexity. + +--- + +## Tier 2: Graph Topology Ablation + +The heterogeneous graph has 3 edge types. Each encodes different information. Removing them reveals what the GNN actually needs. + +### 2A. Edge type removal + +| Experiment | tx-tx | tx-bd (ref) | tx-bd (pred) | Expected insight | +|-----------|:---:|:---:|:---:|---| +| `abl_edges_full` | ON | ON | ON | Baseline (all edges) | +| `abl_edges_no_txtx` | - | ON | ON | Is local transcript context necessary? | +| `abl_edges_no_ref` | ON | - | ON | Can the model learn without reference segmentation? | +| `abl_edges_txtx_only` | ON | - | - | Pure transcript clustering (no boundary info) | + +**Why:** tx-tx edges are the most expensive to construct (KDTree over millions of transcripts). If removing them doesn't hurt, we can dramatically speed up data processing. 
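+
+As a rough illustration of that cost, the sketch below (an assumption for intuition only -- `scipy.spatial.cKDTree` stands in here for whatever segger actually uses) frames tx-tx edge construction as a k- and distance-bounded nearest-neighbour query over every transcript, governed by the same `tx_max_k` / `tx_max_dist` knobs swept in 2B:
+
+```python
+# Illustrative sketch only -- not the segger implementation.
+import numpy as np
+from scipy.spatial import cKDTree
+
+
+def build_tx_tx_edges(coords: np.ndarray, tx_max_k: int = 5, tx_max_dist: float = 5.0) -> np.ndarray:
+    """Return a (2, n_edges) array of transcript-transcript edge indices."""
+    tree = cKDTree(coords)  # O(n log n) build over millions of transcript coordinates
+    # k + 1 because the nearest neighbour of each transcript is the transcript itself
+    _, idx = tree.query(coords, k=tx_max_k + 1, distance_upper_bound=tx_max_dist)
+    src = np.repeat(np.arange(coords.shape[0]), tx_max_k)
+    dst = idx[:, 1:].reshape(-1)   # drop the self-match in column 0
+    keep = dst < coords.shape[0]   # query() pads neighbours beyond tx_max_dist with index n
+    return np.stack([src[keep], dst[keep]])
+```
+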
Conversely, if they're essential, we know local context is a critical signal. + +### 2B. tx-tx graph density + +| Experiment | tx_max_k | tx_max_dist | Effective density | +|-----------|---------|------------|-------------------| +| `abl_txtx_sparse` | 3 | 3.0 | Very sparse local context | +| `abl_txtx_default` | 5 | 5.0 | Current default | +| `abl_txtx_medium` | 10 | 10.0 | Medium density | +| `abl_txtx_dense` | 20 | 20.0 | Dense (high memory) | + +**Why:** These two parameters interact -- both must be large for a dense graph. This tests whether a minimal local graph (k=3, d=3) captures the same information as the more expensive default. + +### 2C. Prediction graph mode + +| Experiment | prediction_mode | Notes | +|-----------|----------------|-------| +| `abl_pred_nucleus` | nucleus | Only nuclear transcripts for training edges | +| `abl_pred_cell` | cell | All transcripts within cell boundary | +| `abl_pred_uniform` | uniform | Uniform sampling around boundary | + +**Why:** The prediction graph mode controls which transcript-boundary edges are used during training. Nucleus mode is conservative (high-confidence assignments), cell mode is permissive (more edges, potentially noisier labels). + +--- + +## Tier 3: Feature & Embedding Ablation + +Tests whether the input representations matter or if the GNN can learn from raw structure alone. + +### 3A. Gene embedding source + +| Experiment | Gene features | Notes | +|-----------|--------------|-------| +| `abl_gene_scrnaseq` | scRNA PCA embeddings | Current default (cell-type proportion vectors) | +| `abl_gene_onehot` | One-hot encoding | No biological prior, pure token identity | +| `abl_gene_random` | Random fixed vectors | Control -- no gene-level information at all | +| `abl_gene_learned` | Trainable from scratch | Let the GNN discover gene relationships | + +**Why:** The scRNA-derived gene embeddings encode cell-type co-expression priors. If one-hot performs similarly, the embeddings aren't adding value. If random vectors work, the GNN is learning purely from spatial structure -- which would be a significant finding. + +### 3B. Boundary (cell) features + +| Experiment | cells_representation_mode | Notes | +|-----------|--------------------------|-------| +| `abl_bd_pca` | pca (dim=128) | Current default -- gene expression PCA | +| `abl_bd_morph` | morphology | Polygon-derived features (area, convexity, elongation) | +| `abl_bd_none` | zeros | No boundary features (structure only) | +| `abl_bd_pca_small` | pca (dim=32) | Reduced dimensionality | + +**Why:** Boundary features are expensive to compute (morphology requires polygon operations; PCA requires gene counting). If zeros work, the model learns cell identity purely from connected transcripts. + +### 3C. Positional embeddings + +| Experiment | use_positional_embeddings | Notes | +|-----------|--------------------------|-------| +| `abl_pos_on` | True | Current default -- sinusoidal 2D encoding | +| `abl_pos_off` | False | No spatial encoding in embeddings | + +**Why:** The graph structure already encodes spatial relationships through edge construction. Positional embeddings may be redundant. If removing them doesn't hurt, it simplifies the model. + +### 3D. 
Embedding normalization + +| Experiment | normalize_embeddings | Notes | +|-----------|---------------------|-------| +| `abl_norm_on` | True | L2-normalize output embeddings (current) | +| `abl_norm_off` | False | Raw unnormalized embeddings | + +**Why:** L2 normalization constrains embeddings to a unit hypersphere, which affects how cosine similarity (used at prediction) distributes. Without normalization, the model can use magnitude as an additional signal. + +--- + +## Tier 4: Architecture Ablation + +Tests structural choices in the GNN itself. + +### 4A. GNN depth + +| Experiment | n_mid_layers | Total layers | Receptive field | +|-----------|-------------|-------------|-----------------| +| `abl_depth_0` | 0 | 2 (in+out only) | 1-hop | +| `abl_depth_1` | 1 | 3 | 2-hop | +| `abl_depth_2` | 2 | 4 | 3-hop (default) | +| `abl_depth_3` | 3 | 5 | 4-hop | +| `abl_depth_4` | 4 | 6 | 5-hop | + +**Why:** Deeper GNNs have larger receptive fields but risk over-smoothing (all node embeddings converge). For cell segmentation, the optimal depth depends on cell size relative to transcript density. 0 layers tests whether a non-message-passing encoder can segment at all. + +### 4B. Model width + +| Experiment | hidden_channels | out_channels | Parameters (approx) | +|-----------|----------------|-------------|---------------------| +| `abl_width_32` | 32 | 32 | ~25% of default | +| `abl_width_64` | 64 | 64 | Default | +| `abl_width_128` | 128 | 128 | ~4x default | +| `abl_width_256` | 256 | 256 | ~16x default | + +**Why:** Determines the capacity vs efficiency tradeoff. If 32 channels perform within 2% of 64, we can deploy a much faster model. + +### 4C. Attention heads + +| Experiment | n_heads | Notes | +|-----------|--------|-------| +| `abl_heads_1` | 1 | Single-head attention (simplest) | +| `abl_heads_2` | 2 | Current default | +| `abl_heads_4` | 4 | Double capacity | +| `abl_heads_8` | 8 | Diminishing returns? | + +**Why:** Multi-head attention allows the model to attend to different relationship types simultaneously. But for a graph with only 3 edge types, 8 heads may be overkill. + +### 4D. Skip connections + +| Experiment | skip_connections | Notes | +|-----------|-----------------|-------| +| `abl_skip_none` | None | Current default (despite class name "SkipGAT") | +| `abl_skip_residual` | Residual add | Standard ResNet-style | + +**Why:** The model class is called SkipGAT but doesn't implement skip connections. Adding them could help with gradient flow in deeper models and would test whether the current architecture is leaving performance on the table. + +--- + +## Tier 5: Generalization & Cross-Dataset + +Tests whether findings transfer beyond the Xenium pancreas dataset. + +### 5A. Cross-tissue (same platform) + +| Experiment | Dataset | Tissue | Density | Notes | +|-----------|---------|--------|---------|-------| +| `gen_pancreas` | Xenium pancreas (Mossi) | Pancreas | Medium | Current benchmark dataset | +| `gen_brain` | Xenium brain | Brain cortex | High | Dense, many cell types | +| `gen_lung` | Xenium lung | Lung | Mixed | Sparse stroma + dense epithelium | +| `gen_tumor` | Xenium tumor (CRC) | Colorectal | Variable | Disordered tissue, heterogeneous | + +**Why:** All current experiments use one dataset. If the optimal hyperparameters shift dramatically between tissues, we need tissue-specific recommendations or a more robust default. + +### 5B. 
Cross-platform + +| Experiment | Platform | Key differences | +|-----------|---------|----------------| +| `gen_xenium` | 10x Xenium | High QV scores, nuclear boundaries available | +| `gen_merscope` | Vizgen MERSCOPE | FOV-based stitching, polygon boundaries | +| `gen_cosmx` | NanoString CosMx | Different noise profile, z-stacks | + +**Why:** Segger claims platform-agnostic segmentation. Cross-platform experiments validate this claim and identify platform-specific failure modes. + +### 5C. Data efficiency + +| Experiment | Subsample % | Transcripts (approx) | Notes | +|-----------|------------|----------------------|-------| +| `gen_full` | 100% | ~5M | Full dataset | +| `gen_50pct` | 50% | ~2.5M | Moderate reduction | +| `gen_25pct` | 25% | ~1.25M | Aggressive reduction | +| `gen_10pct` | 10% | ~500K | Stress test | + +**Why:** Lower-depth sequencing or smaller gene panels produce fewer transcripts per cell. This tests how gracefully Segger degrades and identifies the minimum data requirement for useful segmentation. + +### 5D. Training data vs inference data shift + +| Experiment | Train on | Predict on | Notes | +|-----------|---------|-----------|-------| +| `gen_same` | Pancreas | Pancreas | Standard (baseline) | +| `gen_transfer_brain` | Pancreas | Brain | Zero-shot cross-tissue | +| `gen_transfer_platform` | Xenium | MERSCOPE | Zero-shot cross-platform | +| `gen_finetune_brain` | Pancreas -> Brain (finetune) | Brain | Few-epoch adaptation | + +**Why:** Tests whether Segger learns general spatial segmentation rules or memorizes pancreas-specific patterns. Transfer learning results determine whether per-tissue training is required. + +--- + +## Implementation Priority + +### Phase 1 (implemented in `run_ablation_study.sh`) + +These are now implemented as the 6 blocks of `scripts/run_ablation_study.sh`: + +1. **1A** Loss decomposition (6 runs) -- Block A: `abl_sg_only`, `abl_sg_tx`, `abl_sg_bd`, `abl_sg_tx_bd`, `abl_sg_align`, `abl_full` +2. **1B** Triplet vs BCE (2 runs) -- Block B: `abl_sgloss_triplet`, `abl_sgloss_bce` +3. **1C** Alignment weight sweep (5 runs) -- Block C: `abl_aw_0` through `abl_aw_03` +4. **3C** Positional embeddings on/off (1 run) -- Block D: `abl_no_pos` +5. **3D** Embedding normalization on/off (1 run) -- Block D: `abl_no_norm` +6. **4A** GNN depth (3 runs) -- Block D: `abl_depth_0`, `abl_depth_1`, `abl_depth_3` +7. **4B** Model width (2 runs) -- Block D: `abl_width_32`, `abl_width_128` +8. **4C** Attention heads (2 runs) -- Block D: `abl_heads_1`, `abl_heads_8` +9. **3B** Boundary features / morphology (1 run) -- Block D: `abl_morph` +10. **2C** Prediction graph mode (2 runs) -- Block E: `abl_pred_cell`, `abl_pred_uniform` +11. **Learning rate** (3 runs) -- Block F: `abl_lr_3e4`, `abl_lr_3e3`, `abl_lr_1e2` + +**Total: 28 runs**, fits in one overnight session on 2+ GPUs. The script auto-detects available GPUs (N-way round-robin) and each block can be toggled independently. + +**Not yet implemented from this tier:** +- **2A** Edge type removal (requires data module changes to selectively drop edge types) +- **1D** Loss weight schedule variants (requires new scheduler options) + +### Phase 2 (requires code changes) + +6. **3A** Gene embedding ablation -- needs a `--gene-embedding-mode` CLI parameter +7. **3B** Boundary feature ablation -- needs a null/zero mode for cell features +8. **4D** Skip connections -- needs `ist_encoder.py` modification +9. **1D** Loss schedule variants -- needs new scheduler options + +### Phase 3 (requires new datasets) + +10. 
**5A-5D** Cross-tissue and cross-platform generalization +11. **5C** Subsampling experiments + +--- + +## Expected Outcomes + +### What would change our recommendations + +| Finding | Implication | +|---------|------------| +| `sg_only` matches `full` on all metrics | Simplify to single-loss training, 3x faster | +| Alignment weight 0.1 >> 0.03 on MECR | Increase default alignment strength | +| Removing tx-tx edges doesn't hurt | Skip KDTree construction, 2x faster data prep | +| One-hot gene embeddings match scRNA PCA | Remove scRNA reference dependency (major UX win) | +| 32 channels match 64 channels | Deploy 4x smaller model | +| Cross-tissue transfer fails | Need per-tissue training protocol | +| 10% subsample still works | Segger viable for low-depth experiments | + +### What would confirm our design + +| Finding | Implication | +|---------|------------| +| Full multi-task loss >> sg_only | Multi-task learning is justified | +| Alignment improves MECR without hurting assigned % | ME-gene loss is well-calibrated | +| tx-tx edges significantly improve metrics | Local context is essential | +| scRNA embeddings >> one-hot | Biological priors are valuable | +| Stable across 3 repeats (CV < 5%) | Results are trustworthy | +| Cross-tissue transfer works | Architecture is general | + +--- + +## Reporting + +All ablation results flow through the same validation pipeline: + +``` +Train + Predict → segger_segmentation.parquet + → build_benchmark_validation_table.sh (metrics) + → benchmark_status_dashboard.sh (live monitoring) + → build_benchmark_pdf_report.py (publication figures) +``` + +The PDF report automatically ranks all runs (including ablations) by an overall normalized score across assigned %, MECR, contamination, TCO, and doublet %. Ablation results will appear directly in the bar charts and heatmaps alongside the parameter sweep results. diff --git a/scripts/run_ablation_study.sh b/scripts/run_ablation_study.sh new file mode 100755 index 0000000..f3a42c2 --- /dev/null +++ b/scripts/run_ablation_study.sh @@ -0,0 +1,1138 @@ +#!/usr/bin/env bash +set -u -o pipefail + +# ------------------------------------------------------------------------- +# Segger comprehensive ablation study (auto-detect GPUs, N-way parallel) +# ------------------------------------------------------------------------- +# Systematically removes or swaps individual components (loss terms, +# architecture choices, features) to measure their contribution. 
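+#
+# Outputs (per job): segmentation under OUTPUT_ROOT/runs/<job>/, exports under
+# OUTPUT_ROOT/exports/<job>/, logs under OUTPUT_ROOT/logs/; the full parameter
+# grid is written to OUTPUT_ROOT/job_plan.tsv and per-job status to
+# OUTPUT_ROOT/summaries/all_jobs.tsv.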
+# +# Usage: +# bash scripts/run_ablation_study.sh +# +# Optional overrides (environment variables): +# INPUT_DIR=data/xe_pancreas_mossi/ +# OUTPUT_ROOT=./results/mossi_ablation_study +# NUM_GPUS= # Override auto-detected GPU count +# N_EPOCHS=20 +# RESUME_IF_EXISTS=1 +# DRY_RUN=0 +# SEGMENT_TIMEOUT_MIN=90 +# ALIGNMENT_SCRNA_REFERENCE_PATH=data/ref_pancreas.h5ad +# ALIGNMENT_SCRNA_CELLTYPE_COLUMN=cell_type +# SEGMENT_NUM_WORKERS=8 +# SEGMENT_ANC_RETRY_WORKERS=0 +# TORCH_SHARING_STRATEGY=file_system +# RUN_VALIDATION_TABLE=1 +# VALIDATION_SCRIPT=scripts/build_benchmark_validation_table.sh +# +# Block toggles (set to 0 to skip): +# RUN_LOSS_ABLATION=1 +# RUN_SGLOSS_ABLATION=1 +# RUN_ALIGNMENT_SWEEP=1 +# RUN_ARCH_ABLATION=1 +# RUN_PREDICTION_ABLATION=1 +# RUN_LR_ABLATION=1 +# ------------------------------------------------------------------------- + +timestamp() { + date '+%Y-%m-%d %H:%M:%S' +} + +# ------------------------------------------------------------------------- +# GPU detection +# ------------------------------------------------------------------------- +detect_gpus() { + if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + echo $(( $(echo "${CUDA_VISIBLE_DEVICES}" | tr ',' '\n' | wc -l) )) + elif command -v nvidia-smi >/dev/null 2>&1; then + nvidia-smi --list-gpus 2>/dev/null | wc -l + else + echo 1 + fi +} + +NUM_GPUS="${NUM_GPUS:-$(detect_gpus)}" +if [[ "${NUM_GPUS}" -lt 1 ]]; then + NUM_GPUS=1 +fi + +# Build array of GPU IDs (0..N-1, or from CUDA_VISIBLE_DEVICES). +GPU_IDS=() +if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + IFS=',' read -ra GPU_IDS <<< "${CUDA_VISIBLE_DEVICES}" +else + for ((g = 0; g < NUM_GPUS; g++)); do + GPU_IDS+=("${g}") + done +fi + +# ------------------------------------------------------------------------- +# Paths and defaults +# ------------------------------------------------------------------------- +DEFAULT_INPUT_DIR="data/xe_pancreas_mossi/" +INPUT_DIR="${INPUT_DIR:-${DEFAULT_INPUT_DIR}}" +OUTPUT_ROOT="${OUTPUT_ROOT:-./results/mossi_ablation_study}" + +if [[ "${INPUT_DIR}" == "${DEFAULT_INPUT_DIR}" ]] && \ + [[ ! -d "${INPUT_DIR}" ]] && \ + [[ -d "../data/xe_pancreas_mossi/" ]]; then + INPUT_DIR="../data/xe_pancreas_mossi/" +fi + +N_EPOCHS="${N_EPOCHS:-20}" +PREDICTION_MODE="${PREDICTION_MODE:-nucleus}" + +BOUNDARY_METHOD="${BOUNDARY_METHOD:-convex_hull}" +BOUNDARY_VOXEL_SIZE="${BOUNDARY_VOXEL_SIZE:-5}" +XENIUM_NUM_WORKERS="${XENIUM_NUM_WORKERS:-8}" + +RESUME_IF_EXISTS="${RESUME_IF_EXISTS:-1}" +DRY_RUN="${DRY_RUN:-0}" +PREDICT_FALLBACK_ON_OOM="${PREDICT_FALLBACK_ON_OOM:-1}" +SEGMENT_TIMEOUT_MIN="${SEGMENT_TIMEOUT_MIN:-90}" +SEGMENT_TIMEOUT_SEC=$((SEGMENT_TIMEOUT_MIN * 60)) +SEGMENT_NUM_WORKERS="${SEGMENT_NUM_WORKERS:-8}" +SEGMENT_ANC_RETRY_WORKERS="${SEGMENT_ANC_RETRY_WORKERS:-0}" +TORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY:-file_system}" + +# Alignment defaults (needed by anchor and alignment sweep). +ALIGNMENT_LOSS_WEIGHT_START="${ALIGNMENT_LOSS_WEIGHT_START:-0.0}" +ALIGNMENT_ME_GENE_PAIRS_PATH="${ALIGNMENT_ME_GENE_PAIRS_PATH:-}" +ALIGNMENT_SCRNA_REFERENCE_PATH="${ALIGNMENT_SCRNA_REFERENCE_PATH:-data/ref_pancreas.h5ad}" +ALIGNMENT_SCRNA_CELLTYPE_COLUMN="${ALIGNMENT_SCRNA_CELLTYPE_COLUMN:-cell_type}" + +if [[ "${ALIGNMENT_SCRNA_REFERENCE_PATH}" == "data/ref_pancreas.h5ad" ]] && \ + [[ ! 
-f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && \ + [[ -f "../data/ref_pancreas.h5ad" ]]; then + ALIGNMENT_SCRNA_REFERENCE_PATH="../data/ref_pancreas.h5ad" +fi + +# ------------------------------------------------------------------------- +# Anchor configuration (defaults matching current best config) +# ------------------------------------------------------------------------- +ANCHOR_USE_3D="${ANCHOR_USE_3D:-true}" +ANCHOR_EXPANSION="${ANCHOR_EXPANSION:-2.5}" +ANCHOR_TX_K="${ANCHOR_TX_K:-5}" +ANCHOR_TX_DIST="${ANCHOR_TX_DIST:-20}" +ANCHOR_N_LAYERS="${ANCHOR_N_LAYERS:-2}" +ANCHOR_N_HEADS="${ANCHOR_N_HEADS:-4}" +ANCHOR_CELLS_MIN="${ANCHOR_CELLS_MIN:-5}" +ANCHOR_MIN_QV="${ANCHOR_MIN_QV:-0}" +ANCHOR_ALIGNMENT="${ANCHOR_ALIGNMENT:-true}" +ANCHOR_SG_LOSS="${ANCHOR_SG_LOSS:-triplet}" +ANCHOR_HIDDEN="${ANCHOR_HIDDEN:-64}" +ANCHOR_OUT="${ANCHOR_OUT:-64}" +ANCHOR_TX_WEIGHT="${ANCHOR_TX_WEIGHT:-1.0}" +ANCHOR_BD_WEIGHT="${ANCHOR_BD_WEIGHT:-1.0}" +ANCHOR_SG_WEIGHT="${ANCHOR_SG_WEIGHT:-0.5}" +ANCHOR_ALIGN_WEIGHT="${ANCHOR_ALIGN_WEIGHT:-0.03}" +ANCHOR_POS_EMB="${ANCHOR_POS_EMB:-true}" +ANCHOR_NORM_EMB="${ANCHOR_NORM_EMB:-true}" +ANCHOR_CELLS_REP="${ANCHOR_CELLS_REP:-pca}" +ANCHOR_LR="${ANCHOR_LR:-1e-3}" + +# ------------------------------------------------------------------------- +# Block toggles +# ------------------------------------------------------------------------- +RUN_LOSS_ABLATION="${RUN_LOSS_ABLATION:-1}" +RUN_SGLOSS_ABLATION="${RUN_SGLOSS_ABLATION:-1}" +RUN_ALIGNMENT_SWEEP="${RUN_ALIGNMENT_SWEEP:-1}" +RUN_ARCH_ABLATION="${RUN_ARCH_ABLATION:-1}" +RUN_PREDICTION_ABLATION="${RUN_PREDICTION_ABLATION:-1}" +RUN_LR_ABLATION="${RUN_LR_ABLATION:-1}" + +# ------------------------------------------------------------------------- +# Directories +# ------------------------------------------------------------------------- +RUNS_DIR="${OUTPUT_ROOT}/runs" +EXPORTS_DIR="${OUTPUT_ROOT}/exports" +LOGS_DIR="${OUTPUT_ROOT}/logs" +SUMMARY_DIR="${OUTPUT_ROOT}/summaries" +PLAN_FILE="${OUTPUT_ROOT}/job_plan.tsv" +RUN_VALIDATION_TABLE="${RUN_VALIDATION_TABLE:-1}" +VALIDATION_SCRIPT="${VALIDATION_SCRIPT:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/build_benchmark_validation_table.sh}" +VALIDATION_INCLUDE_DEFAULT_10X="${VALIDATION_INCLUDE_DEFAULT_10X:-true}" + +mkdir -p "${RUNS_DIR}" "${EXPORTS_DIR}" "${LOGS_DIR}" "${SUMMARY_DIR}" + +if [[ ! -d "${INPUT_DIR}" ]]; then + if [[ "${DRY_RUN}" == "1" ]]; then + echo "WARN: INPUT_DIR does not exist (dry run only): ${INPUT_DIR}" + else + echo "ERROR: INPUT_DIR does not exist: ${INPUT_DIR}" + exit 1 + fi +fi + +if [[ "${DRY_RUN}" != "1" ]] && ! command -v segger >/dev/null 2>&1; then + echo "ERROR: 'segger' command not found in PATH." + exit 1 +fi + +# Check alignment inputs (needed if any ablation uses alignment). +need_alignment_inputs=0 +if [[ "${RUN_LOSS_ABLATION}" == "1" ]] || \ + [[ "${RUN_ALIGNMENT_SWEEP}" == "1" ]] || \ + [[ "${ANCHOR_ALIGNMENT}" == "true" ]]; then + need_alignment_inputs=1 +fi + +if [[ "${need_alignment_inputs}" == "1" ]]; then + if [[ -z "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ -z "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + echo "ERROR: Alignment ablation requires ALIGNMENT_ME_GENE_PAIRS_PATH or ALIGNMENT_SCRNA_REFERENCE_PATH." + exit 1 + fi + if [[ "${DRY_RUN}" != "1" ]]; then + if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ ! -f "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then + echo "ERROR: ALIGNMENT_ME_GENE_PAIRS_PATH not found: ${ALIGNMENT_ME_GENE_PAIRS_PATH}" + exit 1 + fi + if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && [[ ! 
-f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + echo "ERROR: ALIGNMENT_SCRNA_REFERENCE_PATH not found: ${ALIGNMENT_SCRNA_REFERENCE_PATH}" + exit 1 + fi + fi +fi + +# ========================================================================= +# Job spec: extended 21-field pipe-delimited format +# ========================================================================= +# Fields: +# 1 job_name +# 2 use_3d +# 3 expansion +# 4 tx_k +# 5 tx_dist +# 6 n_layers +# 7 n_heads +# 8 cells_min_counts +# 9 min_qv +# 10 alignment_loss (true/false) +# 11 sg_loss_type (triplet/bce) +# 12 hidden_channels +# 13 out_channels +# 14 tx_weight_end +# 15 bd_weight_end +# 16 sg_weight_end +# 17 alignment_weight_end +# 18 positional_embeddings (true/false) +# 19 normalize_embeddings (true/false) +# 20 cells_representation (pca/morphology) +# 21 learning_rate +# ========================================================================= + +JOB_SPECS=() + +add_job() { + local job_name="$1" + local use_3d="${2:-${ANCHOR_USE_3D}}" + local expansion="${3:-${ANCHOR_EXPANSION}}" + local tx_k="${4:-${ANCHOR_TX_K}}" + local tx_dist="${5:-${ANCHOR_TX_DIST}}" + local n_layers="${6:-${ANCHOR_N_LAYERS}}" + local n_heads="${7:-${ANCHOR_N_HEADS}}" + local cells_min="${8:-${ANCHOR_CELLS_MIN}}" + local min_qv="${9:-${ANCHOR_MIN_QV}}" + local align="${10:-${ANCHOR_ALIGNMENT}}" + local sg_loss="${11:-${ANCHOR_SG_LOSS}}" + local hidden="${12:-${ANCHOR_HIDDEN}}" + local out="${13:-${ANCHOR_OUT}}" + local tx_w="${14:-${ANCHOR_TX_WEIGHT}}" + local bd_w="${15:-${ANCHOR_BD_WEIGHT}}" + local sg_w="${16:-${ANCHOR_SG_WEIGHT}}" + local align_w="${17:-${ANCHOR_ALIGN_WEIGHT}}" + local pos_emb="${18:-${ANCHOR_POS_EMB}}" + local norm_emb="${19:-${ANCHOR_NORM_EMB}}" + local cells_rep="${20:-${ANCHOR_CELLS_REP}}" + local lr="${21:-${ANCHOR_LR}}" + + JOB_SPECS+=("${job_name}|${use_3d}|${expansion}|${tx_k}|${tx_dist}|${n_layers}|${n_heads}|${cells_min}|${min_qv}|${align}|${sg_loss}|${hidden}|${out}|${tx_w}|${bd_w}|${sg_w}|${align_w}|${pos_emb}|${norm_emb}|${cells_rep}|${lr}") +} + +# Helper: add_job with only overridden fields (positional anchor defaults). +# Usage: add_ablation_job [field=value ...] +# This is a convenience wrapper; for clarity each block calls add_job directly. 
+ +job_block() { + local job_name="$1" + case "${job_name}" in + abl_sg_*|abl_full) echo "loss_decomposition" ;; + abl_sgloss_*) echo "sg_loss_type" ;; + abl_aw_*) echo "alignment_sweep" ;; + abl_depth_*|abl_width_*|abl_heads_*|abl_no_pos|abl_no_norm|abl_morph) echo "architecture" ;; + abl_pred_*) echo "prediction_mode" ;; + abl_lr_*) echo "learning_rate" ;; + *) echo "other" ;; + esac +} + +build_jobs() { + # ------------------------------------------------------------------- + # Block A: Loss decomposition (6 jobs) + # ------------------------------------------------------------------- + if [[ "${RUN_LOSS_ABLATION}" == "1" ]]; then + # sg only: tx=0, bd=0, no alignment + add_job "abl_sg_only" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "false" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "0" "0" "${ANCHOR_SG_WEIGHT}" "0" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + # sg + transcript triplet: bd=0, no alignment + add_job "abl_sg_tx" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "false" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "0" "${ANCHOR_SG_WEIGHT}" "0" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + # sg + boundary metric: tx=0, no alignment + add_job "abl_sg_bd" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "false" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "0" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "0" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + # sg + both clustering: no alignment + add_job "abl_sg_tx_bd" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "false" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "0" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + # sg + alignment only: tx=0, bd=0 + add_job "abl_sg_align" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "true" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "0" "0" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + # full (anchor baseline) + add_job "abl_full" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + fi + + # ------------------------------------------------------------------- + # Block B: Segmentation loss type (2 jobs) + # ------------------------------------------------------------------- + if [[ "${RUN_SGLOSS_ABLATION}" == "1" 
]]; then + add_job "abl_sgloss_triplet" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "triplet" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + add_job "abl_sgloss_bce" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "bce" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + fi + + # ------------------------------------------------------------------- + # Block C: Alignment weight sweep (5 jobs) + # ------------------------------------------------------------------- + if [[ "${RUN_ALIGNMENT_SWEEP}" == "1" ]]; then + local aw_values=(0 0.01 0.03 0.1 0.3) + local aw_tags=(0 001 003 01 03) + local aw_i + for aw_i in "${!aw_values[@]}"; do + local aw="${aw_values[$aw_i]}" + local aw_tag="${aw_tags[$aw_i]}" + local aw_align="true" + if [[ "${aw}" == "0" ]]; then + aw_align="false" + fi + add_job "abl_aw_${aw_tag}" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${aw_align}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${aw}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + done + fi + + # ------------------------------------------------------------------- + # Block D: Architecture ablation (10 jobs) + # ------------------------------------------------------------------- + if [[ "${RUN_ARCH_ABLATION}" == "1" ]]; then + # Depth: 0, 1, 3 mid layers + for depth in 0 1 3; do + add_job "abl_depth_${depth}" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${depth}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + done + + # Width: 32/32 and 128/128 + for width in 32 128; do + add_job "abl_width_${width}" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${width}" "${width}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + done + + # Heads: 1 and 8 + for heads in 1 8; do + add_job "abl_heads_${heads}" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${heads}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" 
"${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + done + + # No positional embeddings + add_job "abl_no_pos" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "false" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + # No embedding normalization + add_job "abl_no_norm" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "false" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + # Morphology instead of PCA + add_job "abl_morph" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "morphology" "${ANCHOR_LR}" + fi + + # ------------------------------------------------------------------- + # Block E: Prediction mode (2 jobs) + # ------------------------------------------------------------------- + if [[ "${RUN_PREDICTION_ABLATION}" == "1" ]]; then + add_job "abl_pred_cell" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + + add_job "abl_pred_uniform" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" + fi + + # ------------------------------------------------------------------- + # Block F: Learning rate (3 jobs) + # ------------------------------------------------------------------- + if [[ "${RUN_LR_ABLATION}" == "1" ]]; then + add_job "abl_lr_3e4" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "3e-4" + + add_job "abl_lr_3e3" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" 
"${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "3e-3" + + add_job "abl_lr_1e2" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ + "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ + "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ + "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ + "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "1e-2" + fi +} + +# ========================================================================= +# Helper functions (identical to run_robustness_ablation_2gpu.sh) +# ========================================================================= + +run_cmd() { + local log_file="$1" + shift + local -a cmd=("$@") + + { + printf '[%s] CMD:' "$(timestamp)" + printf ' %q' "${cmd[@]}" + printf '\n' + } >> "${log_file}" + + if [[ "${DRY_RUN}" == "1" ]]; then + return 0 + fi + + "${cmd[@]}" >> "${log_file}" 2>&1 +} + +run_cmd_with_timeout() { + local log_file="$1" + local timeout_seconds="$2" + shift 2 + local -a cmd=("$@") + + { + printf '[%s] CMD(timeout=%ss):' "$(timestamp)" "${timeout_seconds}" + printf ' %q' "${cmd[@]}" + printf '\n' + } >> "${log_file}" + + if [[ "${DRY_RUN}" == "1" ]]; then + return 0 + fi + + if [[ "${timeout_seconds}" -le 0 ]]; then + "${cmd[@]}" >> "${log_file}" 2>&1 + return $? + fi + + local start_ts now elapsed + local cmd_pid timed_out rc + timed_out=0 + start_ts="$(date +%s)" + + "${cmd[@]}" >> "${log_file}" 2>&1 & + cmd_pid=$! + + while kill -0 "${cmd_pid}" 2>/dev/null; do + now="$(date +%s)" + elapsed=$((now - start_ts)) + if (( elapsed >= timeout_seconds )); then + timed_out=1 + echo "[$(timestamp)] OOT: command exceeded ${timeout_seconds}s; terminating PID=${cmd_pid}" >> "${log_file}" + kill -TERM "${cmd_pid}" 2>/dev/null || true + pkill -TERM -P "${cmd_pid}" 2>/dev/null || true + sleep 5 + kill -KILL "${cmd_pid}" 2>/dev/null || true + pkill -KILL -P "${cmd_pid}" 2>/dev/null || true + break + fi + sleep 10 + done + + wait "${cmd_pid}" + rc=$? + if (( timed_out == 1 )); then + return 124 + fi + return "${rc}" +} + +is_oom_failure() { + local log_file="$1" + if [[ ! -f "${log_file}" ]]; then + return 1 + fi + local pattern="out of memory|cuda error: out of memory|cublas status alloc failed|cuda driver error.*memory" + if command -v rg >/dev/null 2>&1; then + rg -qi "${pattern}" "${log_file}" + else + grep -Eiq "${pattern}" "${log_file}" + fi +} + +is_ancdata_failure() { + local log_file="$1" + if [[ ! -f "${log_file}" ]]; then + return 1 + fi + local pattern="received [0-9]+ items of ancdata|multiprocessing/resource_sharer\\.py" + if command -v rg >/dev/null 2>&1; then + rg -qi "${pattern}" "${log_file}" + else + grep -Eiq "${pattern}" "${log_file}" + fi +} + +LAST_EXPORT_STATUS="ok" + +run_exports_for_job() { + local job_name="$1" + local seg_dir="$2" + local log_file="$3" + + local seg_file="${seg_dir}/segger_segmentation.parquet" + local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" + local anndata_file="${anndata_dir}/segger_segmentation.h5ad" + local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" + local xenium_file="${xenium_dir}/seg_experiment.xenium" + + mkdir -p "${anndata_dir}" "${xenium_dir}" + + if [[ ! 
-f "${seg_file}" ]] && [[ "${DRY_RUN}" != "1" ]]; then + LAST_EXPORT_STATUS="missing_segmentation" + return 1 + fi + + if [[ ! -f "${anndata_file}" ]]; then + local -a anndata_cmd=( + segger export + -s "${seg_file}" + -i "${INPUT_DIR}" + -o "${anndata_dir}" + --format anndata + ) + if ! run_cmd "${log_file}" "${anndata_cmd[@]}"; then + LAST_EXPORT_STATUS="anndata_export_failed" + return 1 + fi + else + echo "[$(timestamp)] SKIP anndata export (existing): ${anndata_file}" >> "${log_file}" + fi + + if [[ ! -f "${xenium_file}" ]]; then + local -a xenium_cmd=( + segger export + -s "${seg_file}" + -i "${INPUT_DIR}" + -o "${xenium_dir}" + --format xenium_explorer + --boundary-method "${BOUNDARY_METHOD}" + --boundary-voxel-size "${BOUNDARY_VOXEL_SIZE}" + --num-workers "${XENIUM_NUM_WORKERS}" + ) + if ! run_cmd "${log_file}" "${xenium_cmd[@]}"; then + LAST_EXPORT_STATUS="xenium_export_failed" + return 1 + fi + else + echo "[$(timestamp)] SKIP xenium export (existing): ${xenium_file}" >> "${log_file}" + fi + + LAST_EXPORT_STATUS="ok" + return 0 +} + +# ========================================================================= +# run_job — extended to handle 21-field spec +# ========================================================================= + +LAST_JOB_STATUS="unknown" + +run_job() { + local gpu="$1" + local spec="$2" + + local job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv + local alignment_loss sg_loss_type hidden_channels out_channels + local tx_weight_end bd_weight_end sg_weight_end alignment_weight_end + local positional_embeddings normalize_embeddings cells_representation learning_rate + IFS='|' read -r \ + job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv \ + alignment_loss sg_loss_type hidden_channels out_channels \ + tx_weight_end bd_weight_end sg_weight_end alignment_weight_end \ + positional_embeddings normalize_embeddings cells_representation learning_rate \ + <<< "${spec}" + + # Resolve prediction mode from job name (Block E override). 
+ local job_prediction_mode="${PREDICTION_MODE}" + case "${job_name}" in + abl_pred_cell) job_prediction_mode="cell" ;; + abl_pred_uniform) job_prediction_mode="uniform" ;; + esac + + local seg_dir="${RUNS_DIR}/${job_name}" + local seg_file="${seg_dir}/segger_segmentation.parquet" + local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" + local anndata_file="${anndata_dir}/segger_segmentation.h5ad" + local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" + local xenium_file="${xenium_dir}/seg_experiment.xenium" + local log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" + + mkdir -p "${seg_dir}" "${anndata_dir}" "${xenium_dir}" + + { + echo "==================================================================" + echo "[$(timestamp)] START job=${job_name} gpu=${gpu}" + echo "params: use3d=${use_3d} expansion=${expansion} tx_k=${tx_k} tx_dist=${tx_dist} layers=${n_layers} heads=${n_heads} cells_min=${cells_min_counts} min_qv=${min_qv} align=${alignment_loss} sg_loss=${sg_loss_type} hidden=${hidden_channels} out=${out_channels} tx_w=${tx_weight_end} bd_w=${bd_weight_end} sg_w=${sg_weight_end} align_w=${alignment_weight_end} pos_emb=${positional_embeddings} norm_emb=${normalize_embeddings} cells_rep=${cells_representation} lr=${learning_rate} pred_mode=${job_prediction_mode} timeout_min=${SEGMENT_TIMEOUT_MIN}" + } | tee -a "${log_file}" >/dev/null + + if [[ "${RESUME_IF_EXISTS}" == "1" ]] && \ + [[ -f "${seg_file}" ]] && \ + [[ -f "${anndata_file}" ]] && \ + [[ -f "${xenium_file}" ]]; then + echo "[$(timestamp)] SKIP job=${job_name} (all outputs already present)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="skipped_existing" + return 0 + fi + + if [[ ! -f "${seg_file}" ]]; then + # Build positional/normalize flags for cyclopts booleans. + local pos_flag="--use-positional-embeddings" + if [[ "${positional_embeddings}" == "false" ]]; then + pos_flag="--no-use-positional-embeddings" + fi + local norm_flag="--normalize-embeddings" + if [[ "${normalize_embeddings}" == "false" ]]; then + norm_flag="--no-normalize-embeddings" + fi + + local -a seg_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger segment + -i "${INPUT_DIR}" + -o "${seg_dir}" + --n-epochs "${N_EPOCHS}" + --prediction-mode "${job_prediction_mode}" + --prediction-expansion-ratio "${expansion}" + --cells-min-counts "${cells_min_counts}" + --min-qv "${min_qv}" + --use-3d "${use_3d}" + --transcripts-max-k "${tx_k}" + --transcripts-max-dist "${tx_dist}" + --n-mid-layers "${n_layers}" + --n-heads "${n_heads}" + --segmentation-loss "${sg_loss_type}" + --hidden-channels "${hidden_channels}" + --out-channels "${out_channels}" + --transcripts-loss-weight-end "${tx_weight_end}" + --cells-loss-weight-end "${bd_weight_end}" + --segmentation-loss-weight-end "${sg_weight_end}" + --learning-rate "${learning_rate}" + --cells-representation "${cells_representation}" + "${pos_flag}" + "${norm_flag}" + ) + + if [[ "${alignment_loss}" == "true" ]]; then + seg_cmd+=( + --alignment-loss + --alignment-loss-weight-start "${ALIGNMENT_LOSS_WEIGHT_START}" + --alignment-loss-weight-end "${alignment_weight_end}" + ) + if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then + seg_cmd+=(--alignment-me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") + fi + if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + seg_cmd+=( + --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" + --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" + ) + fi + fi + 
+ run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_cmd[@]}" + local seg_rc=$? + if [[ "${seg_rc}" -ne 0 ]]; then + if [[ "${seg_rc}" -eq 124 ]]; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oot" + return 1 + fi + + if [[ "${SEGMENT_ANC_RETRY_WORKERS}" != "${SEGMENT_NUM_WORKERS}" ]] && is_ancdata_failure "${log_file}"; then + echo "[$(timestamp)] WARN job=${job_name} segment failed with ancdata; retrying with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null + local -a seg_retry_cmd=("${seg_cmd[@]}") + local i + for i in "${!seg_retry_cmd[@]}"; do + if [[ "${seg_retry_cmd[$i]}" == SEGGER_NUM_WORKERS=* ]]; then + seg_retry_cmd[$i]="SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" + break + fi + done + run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_retry_cmd[@]}" + seg_rc=$? + if [[ "${seg_rc}" -eq 0 ]]; then + echo "[$(timestamp)] OK job=${job_name} segment retry succeeded with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null + elif [[ "${seg_rc}" -eq 124 ]]; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment_retry (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oot" + return 1 + fi + fi + + if [[ "${seg_rc}" -eq 0 ]]; then + : + else + local last_ckpt="${seg_dir}/checkpoints/last.ckpt" + if [[ "${PREDICT_FALLBACK_ON_OOM}" == "1" ]] && is_oom_failure "${log_file}" && [[ -f "${last_ckpt}" ]]; then + echo "[$(timestamp)] WARN job=${job_name} segment OOM; trying checkpoint predict fallback (${last_ckpt})" | tee -a "${log_file}" >/dev/null + local -a predict_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger predict + -c "${last_ckpt}" + -i "${INPUT_DIR}" + -o "${seg_dir}" + ) + if run_cmd "${log_file}" "${predict_cmd[@]}"; then + echo "[$(timestamp)] OK job=${job_name} predict fallback succeeded after OOM" | tee -a "${log_file}" >/dev/null + else + echo "[$(timestamp)] FAIL job=${job_name} step=predict_fallback_after_oom" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="predict_fallback_failed" + return 1 + fi + else + if is_ancdata_failure "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (ancdata)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_ancdata" + elif is_oom_failure "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (oom)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oom" + else + echo "[$(timestamp)] FAIL job=${job_name} step=segment" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_failed" + fi + return 1 + fi + fi + fi + else + echo "[$(timestamp)] SKIP segmentation (existing): ${seg_file}" | tee -a "${log_file}" >/dev/null + fi + + if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=${LAST_EXPORT_STATUS}" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="${LAST_EXPORT_STATUS}" + return 1 + fi + + echo "[$(timestamp)] DONE job=${job_name}" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="ok" + return 0 +} + +# ========================================================================= +# GPU group runner +# ========================================================================= + +run_gpu_group() { + local gpu="$1" + shift + local -a indices=("$@") + local summary_file="${SUMMARY_DIR}/gpu${gpu}.tsv" + + printf "job\tgpu\tstatus\telapsed_s\tseg_dir\tlog_file\n" > "${summary_file}" + + local idx spec job_name start_ts end_ts elapsed_s + for idx in "${indices[@]}"; do + spec="${JOB_SPECS[$idx]}" + IFS='|' read -r job_name _ <<< "${spec}" + + start_ts="$(date +%s)" + run_job "${gpu}" "${spec}" + end_ts="$(date +%s)" + elapsed_s=$((end_ts - start_ts)) + + printf "%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" \ + "${gpu}" \ + "${LAST_JOB_STATUS}" \ + "${elapsed_s}" \ + "${RUNS_DIR}/${job_name}" \ + "${LOGS_DIR}/${job_name}.gpu${gpu}.log" \ + >> "${summary_file}" + done +} + +# ========================================================================= +# Post-run recovery (predict-only from checkpoints) +# ========================================================================= + +run_post_recovery_predict_only_group() { + local gpu="$1" + local out_file="$2" + shift 2 + local -a indices=("$@") + + printf "job\tgpu\tstatus\telapsed_s\tnote\tseg_dir\tlog_file\n" > "${out_file}" + + local idx spec job_name + local seg_dir seg_file last_ckpt log_file note status + local start_ts end_ts elapsed_s + + for idx in "${indices[@]}"; do + spec="${JOB_SPECS[$idx]}" + IFS='|' read -r job_name _ <<< "${spec}" + + seg_dir="${RUNS_DIR}/${job_name}" + seg_file="${seg_dir}/segger_segmentation.parquet" + last_ckpt="${seg_dir}/checkpoints/last.ckpt" + log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" + mkdir -p "${seg_dir}" + + start_ts="$(date +%s)" + note="" + status="ok" + + if [[ -f "${seg_file}" ]]; then + note="segmentation_exists" + if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + status="${LAST_EXPORT_STATUS}" + note="exports_failed_after_existing_seg" + fi + else + if [[ -f "${last_ckpt}" ]]; then + echo "[$(timestamp)] RECOVERY job=${job_name}: running predict-only from ${last_ckpt}" | tee -a "${log_file}" >/dev/null + local -a predict_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger predict + -c "${last_ckpt}" + -i "${INPUT_DIR}" + -o "${seg_dir}" + ) + if run_cmd "${log_file}" "${predict_cmd[@]}"; then + if run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + status="recovered_predict_ok" + note="predict_only_from_last_ckpt" + else + status="${LAST_EXPORT_STATUS}" + note="predict_recovered_but_exports_failed" + fi + else + status="recovered_predict_failed" + note="predict_only_failed" + fi + else + status="recovery_no_checkpoint" + note="missing_seg_and_last_ckpt" + fi + fi + + end_ts="$(date +%s)" + elapsed_s=$((end_ts - start_ts)) + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" \ + "${gpu}" \ + "${status}" \ + "${elapsed_s}" \ + "${note}" \ + "${seg_dir}" \ + "${log_file}" \ + >> "${out_file}" + done +} + +run_post_recovery_predict_only() { + local recovery_file="${SUMMARY_DIR}/recovery.tsv" + local pids=() + local g gpu recovery_per_gpu + + # Single GPU: run sequentially with all indices. + if [[ "${NUM_GPUS}" -eq 1 ]]; then + local all_indices=() + for g in $(seq 0 $((NUM_GPUS - 1))); do + local -n arr="GPU_${g}_INDICES" + all_indices+=("${arr[@]}") + done + recovery_per_gpu="${SUMMARY_DIR}/recovery.gpu${GPU_IDS[0]}.tsv" + run_post_recovery_predict_only_group "${GPU_IDS[0]}" "${recovery_per_gpu}" "${all_indices[@]}" + cp "${recovery_per_gpu}" "${recovery_file}" + return + fi + + # Multi-GPU: run recovery groups in parallel. + for g in $(seq 0 $((NUM_GPUS - 1))); do + gpu="${GPU_IDS[$g]}" + recovery_per_gpu="${SUMMARY_DIR}/recovery.gpu${gpu}.tsv" + local -n arr="GPU_${g}_INDICES" + if [[ "${#arr[@]}" -gt 0 ]]; then + run_post_recovery_predict_only_group "${gpu}" "${recovery_per_gpu}" "${arr[@]}" & + pids+=($!) + fi + done + + for pid in "${pids[@]}"; do + wait "${pid}" + done + + # Merge recovery files. + local first=1 + for g in $(seq 0 $((NUM_GPUS - 1))); do + gpu="${GPU_IDS[$g]}" + recovery_per_gpu="${SUMMARY_DIR}/recovery.gpu${gpu}.tsv" + if [[ -f "${recovery_per_gpu}" ]]; then + if [[ "${first}" -eq 1 ]]; then + cat "${recovery_per_gpu}" > "${recovery_file}" + first=0 + else + tail -n +2 "${recovery_per_gpu}" >> "${recovery_file}" + fi + fi + done +} + +# ========================================================================= +# Build jobs and distribute across GPUs +# ========================================================================= + +build_jobs + +if [[ "${#JOB_SPECS[@]}" -eq 0 ]]; then + echo "ERROR: No ablation jobs were generated. Check block toggles (RUN_LOSS_ABLATION, etc.)." + exit 1 +fi + +# Create per-GPU index arrays with round-robin distribution. 
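+# Example: with every block enabled there are 28 jobs; on 2 GPUs the modulo
+# below sends even job indices to the first GPU group and odd indices to the
+# second.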
+for g in $(seq 0 $((NUM_GPUS - 1))); do + declare -a "GPU_${g}_INDICES=()" +done + +idx=0 +for spec in "${JOB_SPECS[@]}"; do + g=$((idx % NUM_GPUS)) + eval "GPU_${g}_INDICES+=(${idx})" + idx=$((idx + 1)) +done + +# ========================================================================= +# Write job plan TSV +# ========================================================================= +{ + printf "job\tstudy_block\tgpu_group\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\tsg_loss_type\thidden_channels\tout_channels\ttx_weight_end\tbd_weight_end\tsg_weight_end\talignment_weight_end\tpositional_embeddings\tnormalize_embeddings\tcells_representation\tlearning_rate\n" + for idx in "${!JOB_SPECS[@]}"; do + local_group=$((idx % NUM_GPUS)) + IFS='|' read -r \ + job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv \ + alignment_loss sg_loss_type hidden_channels out_channels \ + tx_weight_end bd_weight_end sg_weight_end alignment_weight_end \ + positional_embeddings normalize_embeddings cells_representation learning_rate \ + <<< "${JOB_SPECS[$idx]}" + local_block="$(job_block "${job_name}")" + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" "${local_block}" "${local_group}" "${use_3d}" "${expansion}" "${tx_k}" "${tx_dist}" \ + "${n_layers}" "${n_heads}" "${cells_min_counts}" "${min_qv}" "${alignment_loss}" \ + "${sg_loss_type}" "${hidden_channels}" "${out_channels}" \ + "${tx_weight_end}" "${bd_weight_end}" "${sg_weight_end}" "${alignment_weight_end}" \ + "${positional_embeddings}" "${normalize_embeddings}" "${cells_representation}" "${learning_rate}" + done +} > "${PLAN_FILE}" + +echo "[$(timestamp)] Prepared ${#JOB_SPECS[@]} ablation jobs across ${NUM_GPUS} GPU(s)." +echo "[$(timestamp)] GPUs: ${GPU_IDS[*]}" +for g in $(seq 0 $((NUM_GPUS - 1))); do + eval "_count=\${#GPU_${g}_INDICES[@]}" + echo "[$(timestamp)] GPU ${GPU_IDS[$g]}: ${_count} jobs" +done +echo "[$(timestamp)] Job plan: ${PLAN_FILE}" +echo "[$(timestamp)] Logs: ${LOGS_DIR}" + +if [[ "${DRY_RUN}" == "1" ]]; then + echo "[$(timestamp)] DRY_RUN=1 — exiting without running jobs." + echo "" + echo "Job plan:" + column -t -s $'\t' "${PLAN_FILE}" 2>/dev/null || cat "${PLAN_FILE}" + exit 0 +fi + +# ========================================================================= +# Launch GPU groups in parallel +# ========================================================================= + +PIDS=() +for g in $(seq 0 $((NUM_GPUS - 1))); do + gpu="${GPU_IDS[$g]}" + eval "_arr=(\"\${GPU_${g}_INDICES[@]}\")" + if [[ "${#_arr[@]}" -gt 0 ]]; then + run_gpu_group "${gpu}" "${_arr[@]}" & + PIDS+=($!) + fi +done + +for pid in "${PIDS[@]}"; do + wait "${pid}" +done + +# ========================================================================= +# Post-run recovery pass +# ========================================================================= + +echo "[$(timestamp)] Starting post-run predict-only recovery pass..." 
+run_post_recovery_predict_only + +# ========================================================================= +# Combine summaries +# ========================================================================= + +COMBINED_SUMMARY="${SUMMARY_DIR}/all_jobs.tsv" +if [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then + awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv "${SUMMARY_DIR}/recovery.tsv" > "${COMBINED_SUMMARY}" + FAILED_COUNT=$( + awk -F'\t' 'NR>1 && $3!="ok" && $3!="recovered_predict_ok" {c++} END{print c+0}' "${SUMMARY_DIR}/recovery.tsv" + ) +else + awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv > "${COMBINED_SUMMARY}" + FAILED_COUNT=$( + awk -F'\t' 'NR>1 && $3!="ok" && $3!="skipped_existing" {c++} END{print c+0}' "${COMBINED_SUMMARY}" + ) +fi + +echo "[$(timestamp)] Combined summary: ${COMBINED_SUMMARY}" +if [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then + echo "[$(timestamp)] Recovery summary: ${SUMMARY_DIR}/recovery.tsv" +fi + +# ========================================================================= +# Validation table +# ========================================================================= + +if [[ "${RUN_VALIDATION_TABLE}" == "1" ]]; then + if [[ -f "${VALIDATION_SCRIPT}" ]]; then + echo "[$(timestamp)] Building validation metrics table..." + validation_log="${SUMMARY_DIR}/validation_metrics.log" + validation_cmd=( + bash "${VALIDATION_SCRIPT}" + --root "${OUTPUT_ROOT}" + --input-dir "${INPUT_DIR}" + --include-default-10x "${VALIDATION_INCLUDE_DEFAULT_10X}" + ) + # Pass first two GPU IDs for compatibility with validation script. + if [[ "${NUM_GPUS}" -ge 2 ]]; then + validation_cmd+=(--gpu-a "${GPU_IDS[0]}" --gpu-b "${GPU_IDS[1]}") + else + validation_cmd+=(--gpu-a "${GPU_IDS[0]}" --gpu-b "${GPU_IDS[0]}") + fi + if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then + validation_cmd+=(--me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") + fi + if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + validation_cmd+=( + --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" + --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" + ) + fi + if "${validation_cmd[@]}" >> "${validation_log}" 2>&1; then + echo "[$(timestamp)] Validation table updated: ${OUTPUT_ROOT}/summaries/validation_metrics.tsv" + else + echo "[$(timestamp)] WARN: validation table build failed (see ${validation_log})" + fi + else + echo "[$(timestamp)] WARN: VALIDATION_SCRIPT not found: ${VALIDATION_SCRIPT}" + fi +fi + +echo "[$(timestamp)] Failed jobs: ${FAILED_COUNT}" + +if [[ "${FAILED_COUNT}" -gt 0 ]]; then + exit 1 +fi diff --git a/scripts/run_param_benchmark_2gpu.sh b/scripts/run_param_benchmark_2gpu.sh new file mode 100755 index 0000000..8236231 --- /dev/null +++ b/scripts/run_param_benchmark_2gpu.sh @@ -0,0 +1,764 @@ +#!/usr/bin/env bash +set -u -o pipefail + +# ------------------------------------------------------------------------- +# Segger overnight benchmark runner (2 GPUs, 1 job per GPU at a time) +# ------------------------------------------------------------------------- +# Usage: +# bash scripts/run_param_benchmark_2gpu.sh +# +# Optional overrides (environment variables): +# INPUT_DIR=data/xe_pancreas_mossi/ +# OUTPUT_ROOT=./results/mossi_main_big_benchmark_nightly +# GPU_A=0 +# GPU_B=1 +# N_EPOCHS=20 +# INCLUDE_EXTRA_SWEEPS=1 +# RESUME_IF_EXISTS=1 +# DRY_RUN=0 +# SEGMENT_TIMEOUT_MIN=90 +# ALIGNMENT_LOSS=true +# ALIGNMENT_SCRNA_REFERENCE_PATH=data/ref_pancreas.h5ad +# ALIGNMENT_SCRNA_CELLTYPE_COLUMN=cell_type +# SEGMENT_NUM_WORKERS=8 +# 
SEGMENT_ANC_RETRY_WORKERS=0 +# TORCH_SHARING_STRATEGY=file_system +# ------------------------------------------------------------------------- + +timestamp() { + date '+%Y-%m-%d %H:%M:%S' +} + +DEFAULT_INPUT_DIR="data/xe_pancreas_mossi/" +INPUT_DIR="${INPUT_DIR:-${DEFAULT_INPUT_DIR}}" +OUTPUT_ROOT="${OUTPUT_ROOT:-./results/mossi_main_big_benchmark_nightly}" + +# Common layout fallback when running from segger-0.2.0 with data one level up. +if [[ "${INPUT_DIR}" == "${DEFAULT_INPUT_DIR}" ]] && \ + [[ ! -d "${INPUT_DIR}" ]] && \ + [[ -d "../data/xe_pancreas_mossi/" ]]; then + INPUT_DIR="../data/xe_pancreas_mossi/" +fi + +GPU_A="${GPU_A:-0}" +GPU_B="${GPU_B:-1}" + +N_EPOCHS="${N_EPOCHS:-20}" +PREDICTION_MODE="${PREDICTION_MODE:-nucleus}" + +BOUNDARY_METHOD="${BOUNDARY_METHOD:-convex_hull}" +BOUNDARY_VOXEL_SIZE="${BOUNDARY_VOXEL_SIZE:-5}" +XENIUM_NUM_WORKERS="${XENIUM_NUM_WORKERS:-8}" + +INCLUDE_EXTRA_SWEEPS="${INCLUDE_EXTRA_SWEEPS:-1}" +RESUME_IF_EXISTS="${RESUME_IF_EXISTS:-1}" +DRY_RUN="${DRY_RUN:-0}" +PREDICT_FALLBACK_ON_OOM="${PREDICT_FALLBACK_ON_OOM:-1}" +SEGMENT_TIMEOUT_MIN="${SEGMENT_TIMEOUT_MIN:-90}" +SEGMENT_TIMEOUT_SEC=$((SEGMENT_TIMEOUT_MIN * 60)) +SEGMENT_NUM_WORKERS="${SEGMENT_NUM_WORKERS:-8}" +SEGMENT_ANC_RETRY_WORKERS="${SEGMENT_ANC_RETRY_WORKERS:-0}" +TORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY:-file_system}" + +ALIGNMENT_LOSS="${ALIGNMENT_LOSS:-true}" +ALIGNMENT_LOSS_WEIGHT_START="${ALIGNMENT_LOSS_WEIGHT_START:-0.0}" +ALIGNMENT_LOSS_WEIGHT_END="${ALIGNMENT_LOSS_WEIGHT_END:-0.03}" +ALIGNMENT_ME_GENE_PAIRS_PATH="${ALIGNMENT_ME_GENE_PAIRS_PATH:-}" +ALIGNMENT_SCRNA_REFERENCE_PATH="${ALIGNMENT_SCRNA_REFERENCE_PATH:-data/ref_pancreas.h5ad}" +ALIGNMENT_SCRNA_CELLTYPE_COLUMN="${ALIGNMENT_SCRNA_CELLTYPE_COLUMN:-cell_type}" + +# Common layout fallback when running from segger-0.2.0 with data one level up. +if [[ "${ALIGNMENT_SCRNA_REFERENCE_PATH}" == "data/ref_pancreas.h5ad" ]] && \ + [[ ! -f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && \ + [[ -f "../data/ref_pancreas.h5ad" ]]; then + ALIGNMENT_SCRNA_REFERENCE_PATH="../data/ref_pancreas.h5ad" +fi + +# Baseline values (matches the command you provided). +BASE_USE_3D="${BASE_USE_3D:-true}" +BASE_EXPANSION_RATIO="${BASE_EXPANSION_RATIO:-2.0}" +BASE_TX_MAX_K="${BASE_TX_MAX_K:-5}" +BASE_TX_MAX_DIST="${BASE_TX_MAX_DIST:-5}" +BASE_N_MID_LAYERS="${BASE_N_MID_LAYERS:-2}" +BASE_N_HEADS="${BASE_N_HEADS:-2}" +BASE_CELLS_MIN_COUNTS="${BASE_CELLS_MIN_COUNTS:-5}" +BASE_MIN_QV="${BASE_MIN_QV:-0}" + +# One-factor-at-a-time sweep values around baseline. +USE_3D_VALUES=(false true) +EXPANSION_VALUES=(1 1.5 2.0 2.5 3.0) +TX_MAX_K_VALUES=(5 10 20) +TX_MAX_DIST_VALUES=(3 5 10 20) +N_MID_LAYER_VALUES=(1 2 3) +N_HEAD_VALUES=(2 4 8) +CELLS_MIN_COUNTS_VALUES=(3 5 10) +ALIGNMENT_VALUES=(false true) + +RUNS_DIR="${OUTPUT_ROOT}/runs" +EXPORTS_DIR="${OUTPUT_ROOT}/exports" +LOGS_DIR="${OUTPUT_ROOT}/logs" +SUMMARY_DIR="${OUTPUT_ROOT}/summaries" +PLAN_FILE="${OUTPUT_ROOT}/job_plan.tsv" + +mkdir -p "${RUNS_DIR}" "${EXPORTS_DIR}" "${LOGS_DIR}" "${SUMMARY_DIR}" + +if [[ ! -d "${INPUT_DIR}" ]]; then + if [[ "${DRY_RUN}" == "1" ]]; then + echo "WARN: INPUT_DIR does not exist (dry run only): ${INPUT_DIR}" + else + echo "ERROR: INPUT_DIR does not exist: ${INPUT_DIR}" + exit 1 + fi +fi + +if [[ "${DRY_RUN}" != "1" ]] && ! command -v segger >/dev/null 2>&1; then + echo "ERROR: 'segger' command not found in PATH." 
+ exit 1 +fi + +need_alignment_inputs=0 +if [[ "${ALIGNMENT_LOSS}" == "true" ]]; then + need_alignment_inputs=1 +elif [[ "${INCLUDE_EXTRA_SWEEPS}" == "1" ]]; then + for v in "${ALIGNMENT_VALUES[@]}"; do + if [[ "${v}" == "true" ]]; then + need_alignment_inputs=1 + break + fi + done +fi + +if [[ "${need_alignment_inputs}" == "1" ]]; then + if [[ -z "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ -z "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + echo "ERROR: ALIGNMENT_LOSS=true requires ALIGNMENT_ME_GENE_PAIRS_PATH or ALIGNMENT_SCRNA_REFERENCE_PATH." + exit 1 + fi + if [[ "${DRY_RUN}" != "1" ]]; then + if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ ! -f "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then + echo "ERROR: ALIGNMENT_ME_GENE_PAIRS_PATH not found: ${ALIGNMENT_ME_GENE_PAIRS_PATH}" + exit 1 + fi + if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && [[ ! -f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + echo "ERROR: ALIGNMENT_SCRNA_REFERENCE_PATH not found: ${ALIGNMENT_SCRNA_REFERENCE_PATH}" + exit 1 + fi + fi +fi + +JOB_SPECS=() + +add_job() { + local job_name="$1" + local use_3d="$2" + local expansion="$3" + local tx_k="$4" + local tx_dist="$5" + local n_layers="$6" + local n_heads="$7" + local cells_min_counts="$8" + local min_qv="$9" + local alignment_loss="${10}" + JOB_SPECS+=("${job_name}|${use_3d}|${expansion}|${tx_k}|${tx_dist}|${n_layers}|${n_heads}|${cells_min_counts}|${min_qv}|${alignment_loss}") +} + +build_jobs() { + local v tag + + add_job \ + "baseline" \ + "${BASE_USE_3D}" \ + "${BASE_EXPANSION_RATIO}" \ + "${BASE_TX_MAX_K}" \ + "${BASE_TX_MAX_DIST}" \ + "${BASE_N_MID_LAYERS}" \ + "${BASE_N_HEADS}" \ + "${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" \ + "${ALIGNMENT_LOSS}" + + for v in "${USE_3D_VALUES[@]}"; do + [[ "${v}" == "${BASE_USE_3D}" ]] && continue + add_job "use3d_${v}" \ + "${v}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ + "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" + done + + for v in "${EXPANSION_VALUES[@]}"; do + [[ "${v}" == "${BASE_EXPANSION_RATIO}" ]] && continue + tag="${v//./p}" + add_job "expansion_${tag}" \ + "${BASE_USE_3D}" "${v}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ + "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" + done + + for v in "${TX_MAX_K_VALUES[@]}"; do + [[ "${v}" == "${BASE_TX_MAX_K}" ]] && continue + add_job "txk_${v}" \ + "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${v}" "${BASE_TX_MAX_DIST}" \ + "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" + done + + for v in "${TX_MAX_DIST_VALUES[@]}"; do + [[ "${v}" == "${BASE_TX_MAX_DIST}" ]] && continue + tag="${v//./p}" + add_job "txdist_${tag}" \ + "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${v}" \ + "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" + done + + for v in "${N_MID_LAYER_VALUES[@]}"; do + [[ "${v}" == "${BASE_N_MID_LAYERS}" ]] && continue + add_job "layers_${v}" \ + "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ + "${v}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" + done + + for v in "${N_HEAD_VALUES[@]}"; do + [[ "${v}" == "${BASE_N_HEADS}" ]] && continue + add_job "heads_${v}" \ + "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ + "${BASE_N_MID_LAYERS}" "${v}" 
"${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" + done + + for v in "${CELLS_MIN_COUNTS_VALUES[@]}"; do + [[ "${v}" == "${BASE_CELLS_MIN_COUNTS}" ]] && continue + add_job "cellsmin_${v}" \ + "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ + "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${v}" \ + "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" + done + + if [[ "${INCLUDE_EXTRA_SWEEPS}" == "1" ]]; then + for v in "${ALIGNMENT_VALUES[@]}"; do + [[ "${v}" == "${ALIGNMENT_LOSS}" ]] && continue + add_job "align_${v}" \ + "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ + "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" "${v}" + done + fi +} + +run_cmd() { + local log_file="$1" + shift + local -a cmd=("$@") + + { + printf '[%s] CMD:' "$(timestamp)" + printf ' %q' "${cmd[@]}" + printf '\n' + } >> "${log_file}" + + if [[ "${DRY_RUN}" == "1" ]]; then + return 0 + fi + + "${cmd[@]}" >> "${log_file}" 2>&1 +} + +run_cmd_with_timeout() { + local log_file="$1" + local timeout_seconds="$2" + shift 2 + local -a cmd=("$@") + + { + printf '[%s] CMD(timeout=%ss):' "$(timestamp)" "${timeout_seconds}" + printf ' %q' "${cmd[@]}" + printf '\n' + } >> "${log_file}" + + if [[ "${DRY_RUN}" == "1" ]]; then + return 0 + fi + + if [[ "${timeout_seconds}" -le 0 ]]; then + "${cmd[@]}" >> "${log_file}" 2>&1 + return $? + fi + + local start_ts now elapsed + local cmd_pid timed_out rc + timed_out=0 + start_ts="$(date +%s)" + + "${cmd[@]}" >> "${log_file}" 2>&1 & + cmd_pid=$! + + while kill -0 "${cmd_pid}" 2>/dev/null; do + now="$(date +%s)" + elapsed=$((now - start_ts)) + if (( elapsed >= timeout_seconds )); then + timed_out=1 + echo "[$(timestamp)] OOT: command exceeded ${timeout_seconds}s; terminating PID=${cmd_pid}" >> "${log_file}" + kill -TERM "${cmd_pid}" 2>/dev/null || true + pkill -TERM -P "${cmd_pid}" 2>/dev/null || true + sleep 5 + kill -KILL "${cmd_pid}" 2>/dev/null || true + pkill -KILL -P "${cmd_pid}" 2>/dev/null || true + break + fi + sleep 10 + done + + wait "${cmd_pid}" + rc=$? + if (( timed_out == 1 )); then + return 124 + fi + return "${rc}" +} + +is_oom_failure() { + local log_file="$1" + if [[ ! -f "${log_file}" ]]; then + return 1 + fi + local pattern="out of memory|cuda error: out of memory|cublas status alloc failed|cuda driver error.*memory" + if command -v rg >/dev/null 2>&1; then + rg -qi "${pattern}" "${log_file}" + else + grep -Eiq "${pattern}" "${log_file}" + fi +} + +is_ancdata_failure() { + local log_file="$1" + if [[ ! -f "${log_file}" ]]; then + return 1 + fi + local pattern="received [0-9]+ items of ancdata|multiprocessing/resource_sharer\\.py" + if command -v rg >/dev/null 2>&1; then + rg -qi "${pattern}" "${log_file}" + else + grep -Eiq "${pattern}" "${log_file}" + fi +} + +LAST_EXPORT_STATUS="ok" + +run_exports_for_job() { + local job_name="$1" + local seg_dir="$2" + local log_file="$3" + + local seg_file="${seg_dir}/segger_segmentation.parquet" + local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" + local anndata_file="${anndata_dir}/segger_segmentation.h5ad" + local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" + local xenium_file="${xenium_dir}/seg_experiment.xenium" + + mkdir -p "${anndata_dir}" "${xenium_dir}" + + if [[ ! -f "${seg_file}" ]] && [[ "${DRY_RUN}" != "1" ]]; then + LAST_EXPORT_STATUS="missing_segmentation" + return 1 + fi + + if [[ ! 
-f "${anndata_file}" ]]; then + local -a anndata_cmd=( + segger export + -s "${seg_file}" + -i "${INPUT_DIR}" + -o "${anndata_dir}" + --format anndata + ) + if ! run_cmd "${log_file}" "${anndata_cmd[@]}"; then + LAST_EXPORT_STATUS="anndata_export_failed" + return 1 + fi + else + echo "[$(timestamp)] SKIP anndata export (existing): ${anndata_file}" >> "${log_file}" + fi + + if [[ ! -f "${xenium_file}" ]]; then + local -a xenium_cmd=( + segger export + -s "${seg_file}" + -i "${INPUT_DIR}" + -o "${xenium_dir}" + --format xenium_explorer + --boundary-method "${BOUNDARY_METHOD}" + --boundary-voxel-size "${BOUNDARY_VOXEL_SIZE}" + --num-workers "${XENIUM_NUM_WORKERS}" + ) + if ! run_cmd "${log_file}" "${xenium_cmd[@]}"; then + LAST_EXPORT_STATUS="xenium_export_failed" + return 1 + fi + else + echo "[$(timestamp)] SKIP xenium export (existing): ${xenium_file}" >> "${log_file}" + fi + + LAST_EXPORT_STATUS="ok" + return 0 +} + +LAST_JOB_STATUS="unknown" + +run_job() { + local gpu="$1" + local spec="$2" + + local job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss + IFS='|' read -r \ + job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ + <<< "${spec}" + + local seg_dir="${RUNS_DIR}/${job_name}" + local seg_file="${seg_dir}/segger_segmentation.parquet" + local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" + local anndata_file="${anndata_dir}/segger_segmentation.h5ad" + local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" + local xenium_file="${xenium_dir}/seg_experiment.xenium" + local log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" + + mkdir -p "${seg_dir}" "${anndata_dir}" "${xenium_dir}" + + { + echo "==================================================================" + echo "[$(timestamp)] START job=${job_name} gpu=${gpu}" + echo "params: use3d=${use_3d} expansion=${expansion} tx_k=${tx_k} tx_dist=${tx_dist} layers=${n_layers} heads=${n_heads} cells_min=${cells_min_counts} min_qv=${min_qv} align=${alignment_loss} timeout_min=${SEGMENT_TIMEOUT_MIN} dl_workers=${SEGMENT_NUM_WORKERS} anc_retry_workers=${SEGMENT_ANC_RETRY_WORKERS} sharing=${TORCH_SHARING_STRATEGY}" + } | tee -a "${log_file}" >/dev/null + + if [[ "${RESUME_IF_EXISTS}" == "1" ]] && \ + [[ -f "${seg_file}" ]] && \ + [[ -f "${anndata_file}" ]] && \ + [[ -f "${xenium_file}" ]]; then + echo "[$(timestamp)] SKIP job=${job_name} (all outputs already present)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="skipped_existing" + return 0 + fi + + if [[ ! 
-f "${seg_file}" ]]; then + local -a seg_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger segment + -i "${INPUT_DIR}" + -o "${seg_dir}" + --n-epochs "${N_EPOCHS}" + --prediction-mode "${PREDICTION_MODE}" + --prediction-expansion-ratio "${expansion}" + --cells-min-counts "${cells_min_counts}" + --min-qv "${min_qv}" + --use-3d "${use_3d}" + --transcripts-max-k "${tx_k}" + --transcripts-max-dist "${tx_dist}" + --n-mid-layers "${n_layers}" + --n-heads "${n_heads}" + ) + if [[ "${alignment_loss}" == "true" ]]; then + seg_cmd+=( + --alignment-loss + --alignment-loss-weight-start "${ALIGNMENT_LOSS_WEIGHT_START}" + --alignment-loss-weight-end "${ALIGNMENT_LOSS_WEIGHT_END}" + ) + if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then + seg_cmd+=(--alignment-me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") + fi + if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + seg_cmd+=( + --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" + --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" + ) + fi + fi + + run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_cmd[@]}" + local seg_rc=$? + if [[ "${seg_rc}" -ne 0 ]]; then + if [[ "${seg_rc}" -eq 124 ]]; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oot" + return 1 + fi + + if [[ "${SEGMENT_ANC_RETRY_WORKERS}" != "${SEGMENT_NUM_WORKERS}" ]] && is_ancdata_failure "${log_file}"; then + echo "[$(timestamp)] WARN job=${job_name} segment failed with ancdata; retrying with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null + local -a seg_retry_cmd=("${seg_cmd[@]}") + local i + for i in "${!seg_retry_cmd[@]}"; do + if [[ "${seg_retry_cmd[$i]}" == SEGGER_NUM_WORKERS=* ]]; then + seg_retry_cmd[$i]="SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" + break + fi + done + run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_retry_cmd[@]}" + seg_rc=$? 
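+          # At this point seg_rc holds the retry's exit code: 0 is handled as
+          # success below, 124 as a timeout, and anything else falls through
+          # to the OOM/ancdata triage.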
+ if [[ "${seg_rc}" -eq 0 ]]; then + echo "[$(timestamp)] OK job=${job_name} segment retry succeeded with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null + elif [[ "${seg_rc}" -eq 124 ]]; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment_retry (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oot" + return 1 + fi + fi + + if [[ "${seg_rc}" -eq 0 ]]; then + : + else + local last_ckpt="${seg_dir}/checkpoints/last.ckpt" + if [[ "${PREDICT_FALLBACK_ON_OOM}" == "1" ]] && is_oom_failure "${log_file}" && [[ -f "${last_ckpt}" ]]; then + echo "[$(timestamp)] WARN job=${job_name} segment OOM; trying checkpoint predict fallback (${last_ckpt})" | tee -a "${log_file}" >/dev/null + local -a predict_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger predict + -c "${last_ckpt}" + -i "${INPUT_DIR}" + -o "${seg_dir}" + ) + if run_cmd "${log_file}" "${predict_cmd[@]}"; then + echo "[$(timestamp)] OK job=${job_name} predict fallback succeeded after OOM" | tee -a "${log_file}" >/dev/null + else + echo "[$(timestamp)] FAIL job=${job_name} step=predict_fallback_after_oom" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="predict_fallback_failed" + return 1 + fi + else + if is_ancdata_failure "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (ancdata)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_ancdata" + elif is_oom_failure "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (oom)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oom" + else + echo "[$(timestamp)] FAIL job=${job_name} step=segment" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_failed" + fi + return 1 + fi + fi + fi + else + echo "[$(timestamp)] SKIP segmentation (existing): ${seg_file}" | tee -a "${log_file}" >/dev/null + fi + + if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=${LAST_EXPORT_STATUS}" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="${LAST_EXPORT_STATUS}" + return 1 + fi + + echo "[$(timestamp)] DONE job=${job_name}" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="ok" + return 0 +} + +run_gpu_group() { + local gpu="$1" + shift + local -a indices=("$@") + local summary_file="${SUMMARY_DIR}/gpu${gpu}.tsv" + + printf "job\tgpu\tstatus\telapsed_s\tseg_dir\tlog_file\n" > "${summary_file}" + + local idx spec job_name start_ts end_ts elapsed_s + for idx in "${indices[@]}"; do + spec="${JOB_SPECS[$idx]}" + IFS='|' read -r job_name _ <<< "${spec}" + + start_ts="$(date +%s)" + run_job "${gpu}" "${spec}" + end_ts="$(date +%s)" + elapsed_s=$((end_ts - start_ts)) + + printf "%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" \ + "${gpu}" \ + "${LAST_JOB_STATUS}" \ + "${elapsed_s}" \ + "${RUNS_DIR}/${job_name}" \ + "${LOGS_DIR}/${job_name}.gpu${gpu}.log" \ + >> "${summary_file}" + done +} + +run_post_recovery_predict_only_group() { + local gpu="$1" + local out_file="$2" + shift 2 + local -a indices=("$@") + + printf "job\tgpu\tstatus\telapsed_s\tnote\tseg_dir\tlog_file\n" > "${out_file}" + + local idx spec job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss + local seg_dir seg_file last_ckpt log_file note status + local start_ts end_ts elapsed_s + + for idx in "${indices[@]}"; do + spec="${JOB_SPECS[$idx]}" + IFS='|' read -r \ + job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ + <<< "${spec}" + + seg_dir="${RUNS_DIR}/${job_name}" + seg_file="${seg_dir}/segger_segmentation.parquet" + last_ckpt="${seg_dir}/checkpoints/last.ckpt" + log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" + mkdir -p "${seg_dir}" + + start_ts="$(date +%s)" + note="" + status="ok" + + if [[ -f "${seg_file}" ]]; then + note="segmentation_exists" + if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + status="${LAST_EXPORT_STATUS}" + note="exports_failed_after_existing_seg" + fi + else + if [[ -f "${last_ckpt}" ]]; then + echo "[$(timestamp)] RECOVERY job=${job_name}: running predict-only from ${last_ckpt}" | tee -a "${log_file}" >/dev/null + local -a predict_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger predict + -c "${last_ckpt}" + -i "${INPUT_DIR}" + -o "${seg_dir}" + ) + if run_cmd "${log_file}" "${predict_cmd[@]}"; then + if run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + status="recovered_predict_ok" + note="predict_only_from_last_ckpt" + else + status="${LAST_EXPORT_STATUS}" + note="predict_recovered_but_exports_failed" + fi + else + status="recovered_predict_failed" + note="predict_only_failed" + fi + else + status="recovery_no_checkpoint" + note="missing_seg_and_last_ckpt" + fi + fi + + end_ts="$(date +%s)" + elapsed_s=$((end_ts - start_ts)) + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" \ + "${gpu}" \ + "${status}" \ + "${elapsed_s}" \ + "${note}" \ + "${seg_dir}" \ + "${log_file}" \ + >> "${out_file}" + done +} + +run_post_recovery_predict_only() { + local recovery_file="${SUMMARY_DIR}/recovery.tsv" + local recovery_a="${SUMMARY_DIR}/recovery.gpu${GPU_A}.tsv" + local recovery_b="${SUMMARY_DIR}/recovery.gpu${GPU_B}.tsv" + local pid_a pid_b + + if [[ "${GPU_A}" == "${GPU_B}" ]]; then + run_post_recovery_predict_only_group "${GPU_A}" "${recovery_a}" "${GPU_A_INDICES[@]}" "${GPU_B_INDICES[@]}" + cp "${recovery_a}" "${recovery_file}" + return + fi + + run_post_recovery_predict_only_group "${GPU_A}" "${recovery_a}" "${GPU_A_INDICES[@]}" & + pid_a=$! + run_post_recovery_predict_only_group "${GPU_B}" "${recovery_b}" "${GPU_B_INDICES[@]}" & + pid_b=$! + + wait "${pid_a}" + wait "${pid_b}" + + awk 'FNR==1 && NR!=1 {next} {print}' "${recovery_a}" "${recovery_b}" > "${recovery_file}" +} + +build_jobs + +if [[ "${#JOB_SPECS[@]}" -eq 0 ]]; then + echo "ERROR: No benchmark jobs were generated." + exit 1 +fi + +GPU_A_INDICES=() +GPU_B_INDICES=() + +idx=0 +for spec in "${JOB_SPECS[@]}"; do + if (( idx % 2 == 0 )); then + GPU_A_INDICES+=("${idx}") + else + GPU_B_INDICES+=("${idx}") + fi + idx=$((idx + 1)) +done + +{ + printf "job\tgroup\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\n" + for idx in "${!JOB_SPECS[@]}"; do + local_group="A" + if (( idx % 2 == 1 )); then + local_group="B" + fi + IFS='|' read -r \ + job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ + <<< "${JOB_SPECS[$idx]}" + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" "${local_group}" "${use_3d}" "${expansion}" "${tx_k}" "${tx_dist}" \ + "${n_layers}" "${n_heads}" "${cells_min_counts}" "${min_qv}" "${alignment_loss}" + done +} > "${PLAN_FILE}" + +echo "[$(timestamp)] Prepared ${#JOB_SPECS[@]} jobs." +echo "[$(timestamp)] Group A (GPU ${GPU_A}): ${#GPU_A_INDICES[@]} jobs" +echo "[$(timestamp)] Group B (GPU ${GPU_B}): ${#GPU_B_INDICES[@]} jobs" +echo "[$(timestamp)] Job plan: ${PLAN_FILE}" +echo "[$(timestamp)] Logs: ${LOGS_DIR}" + +run_gpu_group "${GPU_A}" "${GPU_A_INDICES[@]}" & +PID_A=$! +run_gpu_group "${GPU_B}" "${GPU_B_INDICES[@]}" & +PID_B=$! + +wait "${PID_A}" +wait "${PID_B}" + +if [[ "${DRY_RUN}" != "1" ]]; then + echo "[$(timestamp)] Starting post-run predict-only recovery pass..." 
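+  # The recovery pass re-checks every job: existing segmentations only get
+  # missing exports regenerated, jobs with a last.ckpt get a predict-only run,
+  # and jobs with neither are reported as recovery_no_checkpoint.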
+ run_post_recovery_predict_only +fi + +COMBINED_SUMMARY="${SUMMARY_DIR}/all_jobs.tsv" +if [[ "${DRY_RUN}" != "1" ]] && [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then + awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv "${SUMMARY_DIR}/recovery.tsv" > "${COMBINED_SUMMARY}" + FAILED_COUNT=$( + awk -F'\t' 'NR>1 && $3!="ok" && $3!="recovered_predict_ok" {c++} END{print c+0}' "${SUMMARY_DIR}/recovery.tsv" + ) +else + awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv > "${COMBINED_SUMMARY}" + FAILED_COUNT=$( + awk -F'\t' 'NR>1 && $3!="ok" && $3!="skipped_existing" {c++} END{print c+0}' "${COMBINED_SUMMARY}" + ) +fi + +echo "[$(timestamp)] Combined summary: ${COMBINED_SUMMARY}" +if [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then + echo "[$(timestamp)] Recovery summary: ${SUMMARY_DIR}/recovery.tsv" +fi +echo "[$(timestamp)] Failed jobs: ${FAILED_COUNT}" + +if [[ "${FAILED_COUNT}" -gt 0 ]]; then + exit 1 +fi diff --git a/scripts/run_robustness_ablation_2gpu.sh b/scripts/run_robustness_ablation_2gpu.sh new file mode 100755 index 0000000..685eb8e --- /dev/null +++ b/scripts/run_robustness_ablation_2gpu.sh @@ -0,0 +1,845 @@ +#!/usr/bin/env bash +set -u -o pipefail + +# ------------------------------------------------------------------------- +# Segger robustness + ablation runner (2 GPUs, 1 job per GPU at a time) +# ------------------------------------------------------------------------- +# Usage: +# bash scripts/run_robustness_ablation_2gpu.sh +# +# Optional overrides (environment variables): +# INPUT_DIR=data/xe_pancreas_mossi/ +# OUTPUT_ROOT=./results/mossi_main_big_robustness_ablation +# GPU_A=0 +# GPU_B=1 +# N_EPOCHS=20 +# STABILITY_REPEATS=3 +# RUN_INTERACTION_GRID=1 +# RUN_STRESS_TESTS=1 +# RESUME_IF_EXISTS=1 +# DRY_RUN=0 +# SEGMENT_TIMEOUT_MIN=90 +# ALIGNMENT_LOSS=true +# ALIGNMENT_SCRNA_REFERENCE_PATH=data/ref_pancreas.h5ad +# ALIGNMENT_SCRNA_CELLTYPE_COLUMN=cell_type +# SEGMENT_NUM_WORKERS=8 +# SEGMENT_ANC_RETRY_WORKERS=0 +# TORCH_SHARING_STRATEGY=file_system +# RUN_VALIDATION_TABLE=1 +# VALIDATION_SCRIPT=scripts/build_benchmark_validation_table.sh +# ------------------------------------------------------------------------- + +timestamp() { + date '+%Y-%m-%d %H:%M:%S' +} + +DEFAULT_INPUT_DIR="data/xe_pancreas_mossi/" +INPUT_DIR="${INPUT_DIR:-${DEFAULT_INPUT_DIR}}" +OUTPUT_ROOT="${OUTPUT_ROOT:-./results/mossi_main_big_robustness_ablation}" + +# Common layout fallback when running from segger-0.2.0 with data one level up. +if [[ "${INPUT_DIR}" == "${DEFAULT_INPUT_DIR}" ]] && \ + [[ ! 
-d "${INPUT_DIR}" ]] && \ + [[ -d "../data/xe_pancreas_mossi/" ]]; then + INPUT_DIR="../data/xe_pancreas_mossi/" +fi + +GPU_A="${GPU_A:-0}" +GPU_B="${GPU_B:-1}" + +N_EPOCHS="${N_EPOCHS:-20}" +PREDICTION_MODE="${PREDICTION_MODE:-nucleus}" + +BOUNDARY_METHOD="${BOUNDARY_METHOD:-convex_hull}" +BOUNDARY_VOXEL_SIZE="${BOUNDARY_VOXEL_SIZE:-5}" +XENIUM_NUM_WORKERS="${XENIUM_NUM_WORKERS:-8}" + +RESUME_IF_EXISTS="${RESUME_IF_EXISTS:-1}" +DRY_RUN="${DRY_RUN:-0}" +PREDICT_FALLBACK_ON_OOM="${PREDICT_FALLBACK_ON_OOM:-1}" +SEGMENT_TIMEOUT_MIN="${SEGMENT_TIMEOUT_MIN:-90}" +SEGMENT_TIMEOUT_SEC=$((SEGMENT_TIMEOUT_MIN * 60)) +SEGMENT_NUM_WORKERS="${SEGMENT_NUM_WORKERS:-8}" +SEGMENT_ANC_RETRY_WORKERS="${SEGMENT_ANC_RETRY_WORKERS:-0}" +TORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY:-file_system}" + +ALIGNMENT_LOSS="${ALIGNMENT_LOSS:-true}" +ALIGNMENT_LOSS_WEIGHT_START="${ALIGNMENT_LOSS_WEIGHT_START:-0.0}" +ALIGNMENT_LOSS_WEIGHT_END="${ALIGNMENT_LOSS_WEIGHT_END:-0.03}" +ALIGNMENT_ME_GENE_PAIRS_PATH="${ALIGNMENT_ME_GENE_PAIRS_PATH:-}" +ALIGNMENT_SCRNA_REFERENCE_PATH="${ALIGNMENT_SCRNA_REFERENCE_PATH:-data/ref_pancreas.h5ad}" +ALIGNMENT_SCRNA_CELLTYPE_COLUMN="${ALIGNMENT_SCRNA_CELLTYPE_COLUMN:-cell_type}" + +# Common layout fallback when running from segger-0.2.0 with data one level up. +if [[ "${ALIGNMENT_SCRNA_REFERENCE_PATH}" == "data/ref_pancreas.h5ad" ]] && \ + [[ ! -f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && \ + [[ -f "../data/ref_pancreas.h5ad" ]]; then + ALIGNMENT_SCRNA_REFERENCE_PATH="../data/ref_pancreas.h5ad" +fi + +# Baseline values (legacy baseline). +BASE_USE_3D="${BASE_USE_3D:-true}" +BASE_EXPANSION_RATIO="${BASE_EXPANSION_RATIO:-2.0}" +BASE_TX_MAX_K="${BASE_TX_MAX_K:-5}" +BASE_TX_MAX_DIST="${BASE_TX_MAX_DIST:-5}" +BASE_N_MID_LAYERS="${BASE_N_MID_LAYERS:-2}" +BASE_N_HEADS="${BASE_N_HEADS:-2}" +BASE_CELLS_MIN_COUNTS="${BASE_CELLS_MIN_COUNTS:-5}" +BASE_MIN_QV="${BASE_MIN_QV:-0}" + +# Robust anchor values (derived from current validation trends). +ANCHOR_USE_3D="${ANCHOR_USE_3D:-true}" +ANCHOR_EXPANSION_RATIO="${ANCHOR_EXPANSION_RATIO:-2.5}" +ANCHOR_TX_MAX_K="${ANCHOR_TX_MAX_K:-5}" +ANCHOR_TX_MAX_DIST="${ANCHOR_TX_MAX_DIST:-20}" +ANCHOR_N_MID_LAYERS="${ANCHOR_N_MID_LAYERS:-2}" +ANCHOR_N_HEADS="${ANCHOR_N_HEADS:-4}" +ANCHOR_CELLS_MIN_COUNTS="${ANCHOR_CELLS_MIN_COUNTS:-5}" +ANCHOR_MIN_QV="${ANCHOR_MIN_QV:-0}" +ANCHOR_ALIGNMENT_LOSS="${ANCHOR_ALIGNMENT_LOSS:-true}" + +# High-sensitivity variant. +SENS_EXPANSION_RATIO="${SENS_EXPANSION_RATIO:-3.0}" + +# Study controls. +STABILITY_REPEATS="${STABILITY_REPEATS:-3}" +RUN_INTERACTION_GRID="${RUN_INTERACTION_GRID:-1}" +RUN_STRESS_TESTS="${RUN_STRESS_TESTS:-1}" + +# Interaction grid around high-performing region. +INTERACTION_EXPANSIONS=(2.5 3.0) +INTERACTION_TX_DISTS=(10 20) +INTERACTION_HEADS=(2 4) + +# Alignment ablation subset. +INTERACTION_ALIGN_VALUES=(true false) + +if ! [[ "${STABILITY_REPEATS}" =~ ^[0-9]+$ ]] || [[ "${STABILITY_REPEATS}" -lt 1 ]]; then + echo "ERROR: STABILITY_REPEATS must be a positive integer. 
Got: ${STABILITY_REPEATS}" + exit 1 +fi + +RUNS_DIR="${OUTPUT_ROOT}/runs" +EXPORTS_DIR="${OUTPUT_ROOT}/exports" +LOGS_DIR="${OUTPUT_ROOT}/logs" +SUMMARY_DIR="${OUTPUT_ROOT}/summaries" +PLAN_FILE="${OUTPUT_ROOT}/job_plan.tsv" +RUN_VALIDATION_TABLE="${RUN_VALIDATION_TABLE:-1}" +VALIDATION_SCRIPT="${VALIDATION_SCRIPT:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/build_benchmark_validation_table.sh}" +VALIDATION_INCLUDE_DEFAULT_10X="${VALIDATION_INCLUDE_DEFAULT_10X:-true}" + +mkdir -p "${RUNS_DIR}" "${EXPORTS_DIR}" "${LOGS_DIR}" "${SUMMARY_DIR}" + +if [[ ! -d "${INPUT_DIR}" ]]; then + if [[ "${DRY_RUN}" == "1" ]]; then + echo "WARN: INPUT_DIR does not exist (dry run only): ${INPUT_DIR}" + else + echo "ERROR: INPUT_DIR does not exist: ${INPUT_DIR}" + exit 1 + fi +fi + +if [[ "${DRY_RUN}" != "1" ]] && ! command -v segger >/dev/null 2>&1; then + echo "ERROR: 'segger' command not found in PATH." + exit 1 +fi + +need_alignment_inputs=0 +if [[ "${ALIGNMENT_LOSS}" == "true" ]]; then + need_alignment_inputs=1 +elif [[ "${ANCHOR_ALIGNMENT_LOSS}" == "true" ]]; then + need_alignment_inputs=1 +elif [[ "${RUN_INTERACTION_GRID}" == "1" ]]; then + need_alignment_inputs=1 +fi + +if [[ "${need_alignment_inputs}" == "1" ]]; then + if [[ -z "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ -z "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + echo "ERROR: ALIGNMENT_LOSS=true requires ALIGNMENT_ME_GENE_PAIRS_PATH or ALIGNMENT_SCRNA_REFERENCE_PATH." + exit 1 + fi + if [[ "${DRY_RUN}" != "1" ]]; then + if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ ! -f "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then + echo "ERROR: ALIGNMENT_ME_GENE_PAIRS_PATH not found: ${ALIGNMENT_ME_GENE_PAIRS_PATH}" + exit 1 + fi + if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && [[ ! -f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + echo "ERROR: ALIGNMENT_SCRNA_REFERENCE_PATH not found: ${ALIGNMENT_SCRNA_REFERENCE_PATH}" + exit 1 + fi + fi +fi + +JOB_SPECS=() + +add_job() { + local job_name="$1" + local use_3d="$2" + local expansion="$3" + local tx_k="$4" + local tx_dist="$5" + local n_layers="$6" + local n_heads="$7" + local cells_min_counts="$8" + local min_qv="$9" + local alignment_loss="${10}" + JOB_SPECS+=("${job_name}|${use_3d}|${expansion}|${tx_k}|${tx_dist}|${n_layers}|${n_heads}|${cells_min_counts}|${min_qv}|${alignment_loss}") +} + +job_block() { + local job_name="$1" + case "${job_name}" in + stbl_*) echo "stability" ;; + int_*) echo "interaction" ;; + stress_*) echo "stress" ;; + *) echo "other" ;; + esac +} + +build_jobs() { + local i exp dist heads align tag_exp tag_dist + + # ----------------------------------------------------------------------- + # Block A: stability / repeatability + # ----------------------------------------------------------------------- + for ((i = 1; i <= STABILITY_REPEATS; i++)); do + add_job "stbl_baseline_r${i}" \ + "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ + "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ + "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" + done + + for ((i = 1; i <= STABILITY_REPEATS; i++)); do + add_job "stbl_anchor_r${i}" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ + "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ + "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" + done + + for ((i = 1; i <= 2; i++)); do + add_job "stbl_sens_r${i}" \ + "${ANCHOR_USE_3D}" "${SENS_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ + "${ANCHOR_N_MID_LAYERS}" 
"${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ + "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" + done + + # ----------------------------------------------------------------------- + # Block B: interaction grid in high-performing region + # ----------------------------------------------------------------------- + if [[ "${RUN_INTERACTION_GRID}" == "1" ]]; then + for exp in "${INTERACTION_EXPANSIONS[@]}"; do + for dist in "${INTERACTION_TX_DISTS[@]}"; do + for heads in "${INTERACTION_HEADS[@]}"; do + tag_exp="${exp//./p}" + tag_dist="${dist//./p}" + add_job "int_e${tag_exp}_d${tag_dist}_h${heads}_aT" \ + "${ANCHOR_USE_3D}" "${exp}" "${ANCHOR_TX_MAX_K}" "${dist}" \ + "${ANCHOR_N_MID_LAYERS}" "${heads}" "${ANCHOR_CELLS_MIN_COUNTS}" \ + "${ANCHOR_MIN_QV}" "true" + done + done + done + + # Alignment ablation on selected interaction corners (heads=4). + for exp in "${INTERACTION_EXPANSIONS[@]}"; do + for dist in "${INTERACTION_TX_DISTS[@]}"; do + for align in "${INTERACTION_ALIGN_VALUES[@]}"; do + [[ "${align}" == "true" ]] && continue + tag_exp="${exp//./p}" + tag_dist="${dist//./p}" + add_job "int_e${tag_exp}_d${tag_dist}_h4_aF" \ + "${ANCHOR_USE_3D}" "${exp}" "${ANCHOR_TX_MAX_K}" "${dist}" \ + "${ANCHOR_N_MID_LAYERS}" "4" "${ANCHOR_CELLS_MIN_COUNTS}" \ + "${ANCHOR_MIN_QV}" "${align}" + done + done + done + fi + + # ----------------------------------------------------------------------- + # Block C: stress tests (robustness to practical shifts) + # ----------------------------------------------------------------------- + if [[ "${RUN_STRESS_TESTS}" == "1" ]]; then + add_job "stress_use3d_false_anchor" \ + "false" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ + "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ + "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" + + add_job "stress_use3d_false_sens" \ + "false" "${SENS_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ + "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ + "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" + + add_job "stress_cellsmin3_anchor" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ + "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "3" \ + "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" + + add_job "stress_cellsmin10_anchor" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ + "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "10" \ + "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" + + add_job "stress_txk20_anchor" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "20" "${ANCHOR_TX_MAX_DIST}" \ + "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ + "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" + + add_job "stress_layers1_anchor" \ + "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ + "1" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ + "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" + fi +} + +run_cmd() { + local log_file="$1" + shift + local -a cmd=("$@") + + { + printf '[%s] CMD:' "$(timestamp)" + printf ' %q' "${cmd[@]}" + printf '\n' + } >> "${log_file}" + + if [[ "${DRY_RUN}" == "1" ]]; then + return 0 + fi + + "${cmd[@]}" >> "${log_file}" 2>&1 +} + +run_cmd_with_timeout() { + local log_file="$1" + local timeout_seconds="$2" + shift 2 + local -a cmd=("$@") + + { + printf '[%s] CMD(timeout=%ss):' "$(timestamp)" "${timeout_seconds}" + printf ' %q' "${cmd[@]}" + printf '\n' + } 
>> "${log_file}" + + if [[ "${DRY_RUN}" == "1" ]]; then + return 0 + fi + + if [[ "${timeout_seconds}" -le 0 ]]; then + "${cmd[@]}" >> "${log_file}" 2>&1 + return $? + fi + + local start_ts now elapsed + local cmd_pid timed_out rc + timed_out=0 + start_ts="$(date +%s)" + + "${cmd[@]}" >> "${log_file}" 2>&1 & + cmd_pid=$! + + while kill -0 "${cmd_pid}" 2>/dev/null; do + now="$(date +%s)" + elapsed=$((now - start_ts)) + if (( elapsed >= timeout_seconds )); then + timed_out=1 + echo "[$(timestamp)] OOT: command exceeded ${timeout_seconds}s; terminating PID=${cmd_pid}" >> "${log_file}" + kill -TERM "${cmd_pid}" 2>/dev/null || true + pkill -TERM -P "${cmd_pid}" 2>/dev/null || true + sleep 5 + kill -KILL "${cmd_pid}" 2>/dev/null || true + pkill -KILL -P "${cmd_pid}" 2>/dev/null || true + break + fi + sleep 10 + done + + wait "${cmd_pid}" + rc=$? + if (( timed_out == 1 )); then + return 124 + fi + return "${rc}" +} + +is_oom_failure() { + local log_file="$1" + if [[ ! -f "${log_file}" ]]; then + return 1 + fi + local pattern="out of memory|cuda error: out of memory|cublas status alloc failed|cuda driver error.*memory" + if command -v rg >/dev/null 2>&1; then + rg -qi "${pattern}" "${log_file}" + else + grep -Eiq "${pattern}" "${log_file}" + fi +} + +is_ancdata_failure() { + local log_file="$1" + if [[ ! -f "${log_file}" ]]; then + return 1 + fi + local pattern="received [0-9]+ items of ancdata|multiprocessing/resource_sharer\\.py" + if command -v rg >/dev/null 2>&1; then + rg -qi "${pattern}" "${log_file}" + else + grep -Eiq "${pattern}" "${log_file}" + fi +} + +LAST_EXPORT_STATUS="ok" + +run_exports_for_job() { + local job_name="$1" + local seg_dir="$2" + local log_file="$3" + + local seg_file="${seg_dir}/segger_segmentation.parquet" + local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" + local anndata_file="${anndata_dir}/segger_segmentation.h5ad" + local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" + local xenium_file="${xenium_dir}/seg_experiment.xenium" + + mkdir -p "${anndata_dir}" "${xenium_dir}" + + if [[ ! -f "${seg_file}" ]] && [[ "${DRY_RUN}" != "1" ]]; then + LAST_EXPORT_STATUS="missing_segmentation" + return 1 + fi + + if [[ ! -f "${anndata_file}" ]]; then + local -a anndata_cmd=( + segger export + -s "${seg_file}" + -i "${INPUT_DIR}" + -o "${anndata_dir}" + --format anndata + ) + if ! run_cmd "${log_file}" "${anndata_cmd[@]}"; then + LAST_EXPORT_STATUS="anndata_export_failed" + return 1 + fi + else + echo "[$(timestamp)] SKIP anndata export (existing): ${anndata_file}" >> "${log_file}" + fi + + if [[ ! -f "${xenium_file}" ]]; then + local -a xenium_cmd=( + segger export + -s "${seg_file}" + -i "${INPUT_DIR}" + -o "${xenium_dir}" + --format xenium_explorer + --boundary-method "${BOUNDARY_METHOD}" + --boundary-voxel-size "${BOUNDARY_VOXEL_SIZE}" + --num-workers "${XENIUM_NUM_WORKERS}" + ) + if ! 
run_cmd "${log_file}" "${xenium_cmd[@]}"; then + LAST_EXPORT_STATUS="xenium_export_failed" + return 1 + fi + else + echo "[$(timestamp)] SKIP xenium export (existing): ${xenium_file}" >> "${log_file}" + fi + + LAST_EXPORT_STATUS="ok" + return 0 +} + +LAST_JOB_STATUS="unknown" + +run_job() { + local gpu="$1" + local spec="$2" + + local job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss + IFS='|' read -r \ + job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ + <<< "${spec}" + + local seg_dir="${RUNS_DIR}/${job_name}" + local seg_file="${seg_dir}/segger_segmentation.parquet" + local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" + local anndata_file="${anndata_dir}/segger_segmentation.h5ad" + local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" + local xenium_file="${xenium_dir}/seg_experiment.xenium" + local log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" + + mkdir -p "${seg_dir}" "${anndata_dir}" "${xenium_dir}" + + { + echo "==================================================================" + echo "[$(timestamp)] START job=${job_name} gpu=${gpu}" + echo "params: use3d=${use_3d} expansion=${expansion} tx_k=${tx_k} tx_dist=${tx_dist} layers=${n_layers} heads=${n_heads} cells_min=${cells_min_counts} min_qv=${min_qv} align=${alignment_loss} timeout_min=${SEGMENT_TIMEOUT_MIN} dl_workers=${SEGMENT_NUM_WORKERS} anc_retry_workers=${SEGMENT_ANC_RETRY_WORKERS} sharing=${TORCH_SHARING_STRATEGY}" + } | tee -a "${log_file}" >/dev/null + + if [[ "${RESUME_IF_EXISTS}" == "1" ]] && \ + [[ -f "${seg_file}" ]] && \ + [[ -f "${anndata_file}" ]] && \ + [[ -f "${xenium_file}" ]]; then + echo "[$(timestamp)] SKIP job=${job_name} (all outputs already present)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="skipped_existing" + return 0 + fi + + if [[ ! -f "${seg_file}" ]]; then + local -a seg_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger segment + -i "${INPUT_DIR}" + -o "${seg_dir}" + --n-epochs "${N_EPOCHS}" + --prediction-mode "${PREDICTION_MODE}" + --prediction-expansion-ratio "${expansion}" + --cells-min-counts "${cells_min_counts}" + --min-qv "${min_qv}" + --use-3d "${use_3d}" + --transcripts-max-k "${tx_k}" + --transcripts-max-dist "${tx_dist}" + --n-mid-layers "${n_layers}" + --n-heads "${n_heads}" + ) + if [[ "${alignment_loss}" == "true" ]]; then + seg_cmd+=( + --alignment-loss + --alignment-loss-weight-start "${ALIGNMENT_LOSS_WEIGHT_START}" + --alignment-loss-weight-end "${ALIGNMENT_LOSS_WEIGHT_END}" + ) + if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then + seg_cmd+=(--alignment-me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") + fi + if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + seg_cmd+=( + --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" + --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" + ) + fi + fi + + run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_cmd[@]}" + local seg_rc=$? 
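+    # Non-zero exit codes are triaged below: 124 means the segment step hit
+    # the wall-clock timeout, ancdata errors trigger a lower-worker retry, and
+    # OOM can fall back to predict-only from the last checkpoint.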
+ if [[ "${seg_rc}" -ne 0 ]]; then + if [[ "${seg_rc}" -eq 124 ]]; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oot" + return 1 + fi + + if [[ "${SEGMENT_ANC_RETRY_WORKERS}" != "${SEGMENT_NUM_WORKERS}" ]] && is_ancdata_failure "${log_file}"; then + echo "[$(timestamp)] WARN job=${job_name} segment failed with ancdata; retrying with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null + local -a seg_retry_cmd=("${seg_cmd[@]}") + local i + for i in "${!seg_retry_cmd[@]}"; do + if [[ "${seg_retry_cmd[$i]}" == SEGGER_NUM_WORKERS=* ]]; then + seg_retry_cmd[$i]="SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" + break + fi + done + run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_retry_cmd[@]}" + seg_rc=$? + if [[ "${seg_rc}" -eq 0 ]]; then + echo "[$(timestamp)] OK job=${job_name} segment retry succeeded with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null + elif [[ "${seg_rc}" -eq 124 ]]; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment_retry (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oot" + return 1 + fi + fi + + if [[ "${seg_rc}" -eq 0 ]]; then + : + else + local last_ckpt="${seg_dir}/checkpoints/last.ckpt" + if [[ "${PREDICT_FALLBACK_ON_OOM}" == "1" ]] && is_oom_failure "${log_file}" && [[ -f "${last_ckpt}" ]]; then + echo "[$(timestamp)] WARN job=${job_name} segment OOM; trying checkpoint predict fallback (${last_ckpt})" | tee -a "${log_file}" >/dev/null + local -a predict_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger predict + -c "${last_ckpt}" + -i "${INPUT_DIR}" + -o "${seg_dir}" + ) + if run_cmd "${log_file}" "${predict_cmd[@]}"; then + echo "[$(timestamp)] OK job=${job_name} predict fallback succeeded after OOM" | tee -a "${log_file}" >/dev/null + else + echo "[$(timestamp)] FAIL job=${job_name} step=predict_fallback_after_oom" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="predict_fallback_failed" + return 1 + fi + else + if is_ancdata_failure "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (ancdata)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_ancdata" + elif is_oom_failure "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=segment (oom)" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_oom" + else + echo "[$(timestamp)] FAIL job=${job_name} step=segment" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="segment_failed" + fi + return 1 + fi + fi + fi + else + echo "[$(timestamp)] SKIP segmentation (existing): ${seg_file}" | tee -a "${log_file}" >/dev/null + fi + + if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + echo "[$(timestamp)] FAIL job=${job_name} step=${LAST_EXPORT_STATUS}" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="${LAST_EXPORT_STATUS}" + return 1 + fi + + echo "[$(timestamp)] DONE job=${job_name}" | tee -a "${log_file}" >/dev/null + LAST_JOB_STATUS="ok" + return 0 +} + +run_gpu_group() { + local gpu="$1" + shift + local -a indices=("$@") + local summary_file="${SUMMARY_DIR}/gpu${gpu}.tsv" + + printf "job\tgpu\tstatus\telapsed_s\tseg_dir\tlog_file\n" > "${summary_file}" + + local idx spec job_name start_ts end_ts elapsed_s + for idx in "${indices[@]}"; do + spec="${JOB_SPECS[$idx]}" + IFS='|' read -r job_name _ <<< "${spec}" + + start_ts="$(date +%s)" + run_job "${gpu}" "${spec}" + end_ts="$(date +%s)" + elapsed_s=$((end_ts - start_ts)) + + printf "%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" \ + "${gpu}" \ + "${LAST_JOB_STATUS}" \ + "${elapsed_s}" \ + "${RUNS_DIR}/${job_name}" \ + "${LOGS_DIR}/${job_name}.gpu${gpu}.log" \ + >> "${summary_file}" + done +} + +run_post_recovery_predict_only_group() { + local gpu="$1" + local out_file="$2" + shift 2 + local -a indices=("$@") + + printf "job\tgpu\tstatus\telapsed_s\tnote\tseg_dir\tlog_file\n" > "${out_file}" + + local idx spec job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss + local seg_dir seg_file last_ckpt log_file note status + local start_ts end_ts elapsed_s + + for idx in "${indices[@]}"; do + spec="${JOB_SPECS[$idx]}" + IFS='|' read -r \ + job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ + <<< "${spec}" + + seg_dir="${RUNS_DIR}/${job_name}" + seg_file="${seg_dir}/segger_segmentation.parquet" + last_ckpt="${seg_dir}/checkpoints/last.ckpt" + log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" + mkdir -p "${seg_dir}" + + start_ts="$(date +%s)" + note="" + status="ok" + + if [[ -f "${seg_file}" ]]; then + note="segmentation_exists" + if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + status="${LAST_EXPORT_STATUS}" + note="exports_failed_after_existing_seg" + fi + else + if [[ -f "${last_ckpt}" ]]; then + echo "[$(timestamp)] RECOVERY job=${job_name}: running predict-only from ${last_ckpt}" | tee -a "${log_file}" >/dev/null + local -a predict_cmd=( + env CUDA_VISIBLE_DEVICES="${gpu}" + PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" + SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" + segger predict + -c "${last_ckpt}" + -i "${INPUT_DIR}" + -o "${seg_dir}" + ) + if run_cmd "${log_file}" "${predict_cmd[@]}"; then + if run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then + status="recovered_predict_ok" + note="predict_only_from_last_ckpt" + else + status="${LAST_EXPORT_STATUS}" + note="predict_recovered_but_exports_failed" + fi + else + status="recovered_predict_failed" + note="predict_only_failed" + fi + else + status="recovery_no_checkpoint" + note="missing_seg_and_last_ckpt" + fi + fi + + end_ts="$(date +%s)" + elapsed_s=$((end_ts - start_ts)) + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" \ + "${gpu}" \ + "${status}" \ + "${elapsed_s}" \ + "${note}" \ + "${seg_dir}" \ + "${log_file}" \ + >> "${out_file}" + done +} + +run_post_recovery_predict_only() { + local recovery_file="${SUMMARY_DIR}/recovery.tsv" + local recovery_a="${SUMMARY_DIR}/recovery.gpu${GPU_A}.tsv" + local recovery_b="${SUMMARY_DIR}/recovery.gpu${GPU_B}.tsv" + local pid_a pid_b + + if [[ "${GPU_A}" == "${GPU_B}" ]]; then + run_post_recovery_predict_only_group "${GPU_A}" "${recovery_a}" "${GPU_A_INDICES[@]}" "${GPU_B_INDICES[@]}" + cp "${recovery_a}" "${recovery_file}" + return + fi + + run_post_recovery_predict_only_group "${GPU_A}" "${recovery_a}" "${GPU_A_INDICES[@]}" & + pid_a=$! + run_post_recovery_predict_only_group "${GPU_B}" "${recovery_b}" "${GPU_B_INDICES[@]}" & + pid_b=$! + + wait "${pid_a}" + wait "${pid_b}" + + awk 'FNR==1 && NR!=1 {next} {print}' "${recovery_a}" "${recovery_b}" > "${recovery_file}" +} + +build_jobs + +if [[ "${#JOB_SPECS[@]}" -eq 0 ]]; then + echo "ERROR: No robustness/ablation jobs were generated." + exit 1 +fi + +GPU_A_INDICES=() +GPU_B_INDICES=() + +idx=0 +for spec in "${JOB_SPECS[@]}"; do + if (( idx % 2 == 0 )); then + GPU_A_INDICES+=("${idx}") + else + GPU_B_INDICES+=("${idx}") + fi + idx=$((idx + 1)) +done + +{ + printf "job\tstudy_block\tgroup\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\n" + for idx in "${!JOB_SPECS[@]}"; do + local_group="A" + if (( idx % 2 == 1 )); then + local_group="B" + fi + IFS='|' read -r \ + job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ + <<< "${JOB_SPECS[$idx]}" + local_block="$(job_block "${job_name}")" + printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ + "${job_name}" "${local_block}" "${local_group}" "${use_3d}" "${expansion}" "${tx_k}" "${tx_dist}" \ + "${n_layers}" "${n_heads}" "${cells_min_counts}" "${min_qv}" "${alignment_loss}" + done +} > "${PLAN_FILE}" + +echo "[$(timestamp)] Prepared ${#JOB_SPECS[@]} jobs." +echo "[$(timestamp)] Study blocks: stability + interaction + stress" +echo "[$(timestamp)] Group A (GPU ${GPU_A}): ${#GPU_A_INDICES[@]} jobs" +echo "[$(timestamp)] Group B (GPU ${GPU_B}): ${#GPU_B_INDICES[@]} jobs" +echo "[$(timestamp)] Job plan: ${PLAN_FILE}" +echo "[$(timestamp)] Logs: ${LOGS_DIR}" + +run_gpu_group "${GPU_A}" "${GPU_A_INDICES[@]}" & +PID_A=$! 
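+# Group B runs concurrently on the second GPU; both background groups are
+# awaited before the recovery pass and the summaries are combined.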
+run_gpu_group "${GPU_B}" "${GPU_B_INDICES[@]}" & +PID_B=$! + +wait "${PID_A}" +wait "${PID_B}" + +if [[ "${DRY_RUN}" != "1" ]]; then + echo "[$(timestamp)] Starting post-run predict-only recovery pass..." + run_post_recovery_predict_only +fi + +COMBINED_SUMMARY="${SUMMARY_DIR}/all_jobs.tsv" +if [[ "${DRY_RUN}" != "1" ]] && [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then + awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv "${SUMMARY_DIR}/recovery.tsv" > "${COMBINED_SUMMARY}" + FAILED_COUNT=$( + awk -F'\t' 'NR>1 && $3!="ok" && $3!="recovered_predict_ok" {c++} END{print c+0}' "${SUMMARY_DIR}/recovery.tsv" + ) +else + awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv > "${COMBINED_SUMMARY}" + FAILED_COUNT=$( + awk -F'\t' 'NR>1 && $3!="ok" && $3!="skipped_existing" {c++} END{print c+0}' "${COMBINED_SUMMARY}" + ) +fi + +echo "[$(timestamp)] Combined summary: ${COMBINED_SUMMARY}" +if [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then + echo "[$(timestamp)] Recovery summary: ${SUMMARY_DIR}/recovery.tsv" +fi + +if [[ "${DRY_RUN}" != "1" ]] && [[ "${RUN_VALIDATION_TABLE}" == "1" ]]; then + if [[ -f "${VALIDATION_SCRIPT}" ]]; then + echo "[$(timestamp)] Building validation metrics table..." + validation_log="${SUMMARY_DIR}/validation_metrics.log" + validation_cmd=( + bash "${VALIDATION_SCRIPT}" + --root "${OUTPUT_ROOT}" + --input-dir "${INPUT_DIR}" + --gpu-a "${GPU_A}" + --gpu-b "${GPU_B}" + --include-default-10x "${VALIDATION_INCLUDE_DEFAULT_10X}" + ) + if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then + validation_cmd+=(--me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") + fi + if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then + validation_cmd+=( + --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" + --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" + ) + fi + if "${validation_cmd[@]}" >> "${validation_log}" 2>&1; then + echo "[$(timestamp)] Validation table updated: ${OUTPUT_ROOT}/summaries/validation_metrics.tsv" + else + echo "[$(timestamp)] WARN: validation table build failed (see ${validation_log})" + fi + else + echo "[$(timestamp)] WARN: VALIDATION_SCRIPT not found: ${VALIDATION_SCRIPT}" + fi +fi + +echo "[$(timestamp)] Failed jobs: ${FAILED_COUNT}" + +if [[ "${FAILED_COUNT}" -gt 0 ]]; then + exit 1 +fi diff --git a/src/segger/export/spatialdata_writer.py b/src/segger/export/spatialdata_writer.py index 2b1e4f7..f94bb50 100644 --- a/src/segger/export/spatialdata_writer.py +++ b/src/segger/export/spatialdata_writer.py @@ -83,8 +83,8 @@ def __init__( points_key: str = "transcripts", shapes_key: str = "cells", include_table: bool = True, - table_key: str = "cells_table", # no duplicate names allowed - # fragment_table_key: str = "fragments_table", + table_key: str = "cells_table", + fragment_table_key: str = "fragments_table", table_region_key: str = "cell_id", ): require_spatialdata() @@ -249,7 +249,6 @@ def _create_spatialdata( """Create SpatialData object from transcripts and boundaries.""" import spatialdata from spatialdata.models import PointsModel, ShapesModel, TableModel - import spatialdata.models._accessor # for points parsing on pre-release (https://github.com/scverse/spatialdata/issues/1093) import dask.dataframe as dd identity = self._identity_transform() @@ -280,6 +279,7 @@ def _create_spatialdata( tx_pd[col] = tx_pd[col].astype(float) # Create Dask DataFrame for points + tx_pd[feature_column] = tx_pd[feature_column].astype("category") tx_dask = dd.from_pandas(tx_pd) # Points element @@ -325,17 +325,17 @@ def _parse_shapes(shapes): if 
self.include_boundaries and self.boundary_method != "skip": if self.boundary_method == "input": - for bd_type in ["cell", "nucleus"]: # these are segger hard-coded + bd_types = {"cell": "cells", "nucleus": "nuclei"} + for k, v in bd_types.items(): shapes = self._get_input_boundaries( cell_tx_pd, cell_id_column, boundaries, - bd_type) + k) shapes = _ensure_cell_id(shapes) parsed = _parse_shapes(shapes) if parsed is not None: - shapes_elements[f"{bd_type}_boundaries"] = parsed - # this naming convention is very Xenium-based (ideally one would maintain the input one which is currently lost) + shapes_elements[v] = parsed else: shape_specs = [(self.shapes_key, cell_tx_pd)] if has_fragments and fragment_tx_pd is not None: @@ -381,6 +381,12 @@ def _parse_shapes(shapes): pass tables_elements[self.table_key] = table + for name, table in tables_elements.items(): + if 'spatialdata_attrs' not in table.uns.keys(): + warnings.warn( + f"Table {name} does not contain the `uns['spatialdata_attrs']` field as no shapes element is associated." + ) + # Create SpatialData (prefer modern constructor methods, keep fallback on single elemnts) sdata = self._build_spatialdata( spatialdata=spatialdata, @@ -439,7 +445,7 @@ def _build_table_element( obs_index_as_str=True, ) if region is None: - return table + return TableModel.validate(table) instance_key = self.table_region_key table.obs["region"] = region @@ -452,7 +458,8 @@ def _build_table_element( region_key="region", instance_key=instance_key or "instance_id", ) - except Exception: + except Exception as e: + warnings.warn(f"TableModel.parse failed: {e}") return table def _write_spatialdata_zarr(self, sdata, output_path: Path, overwrite: bool) -> None: @@ -525,6 +532,7 @@ def _get_generated_boundaries( elif self.boundary_method == "delaunay": from segger.export.boundary import generate_boundaries + warnings.filterwarnings('ignore', 'GeoSeries.notna', UserWarning) boundaries_gdf = generate_boundaries( assigned, diff --git a/src/segger/models/alignment_loss.py b/src/segger/models/alignment_loss.py new file mode 100644 index 0000000..b7e9a9e --- /dev/null +++ b/src/segger/models/alignment_loss.py @@ -0,0 +1,118 @@ +"""Alignment loss for mutually exclusive gene constraints. + +This module implements alignment loss using ME gene pairs (negatives) and +same-gene transcript neighbors (positives). Other tx-tx edges are ignored +for the alignment objective. 
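+
+A minimal usage sketch (``emb_src``, ``emb_dst``, and ``labels`` are hypothetical
+tensors for the selected tx-tx edges; applying the scheduled weight outside the
+module is one plausible pattern, not the only one)::
+
+    loss_fn = AlignmentLoss(weight_start=0.0, weight_end=0.1)
+    w = loss_fn.get_scheduled_weight(current_epoch=3, max_epochs=20)
+    total = w * loss_fn(emb_src, emb_dst, labels)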
+""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AlignmentLoss(nn.Module): + """Contrastive loss for ME-gene constraints.""" + + def __init__( + self, + weight_start: float = 0.0, + weight_end: float = 0.1, + ): + super().__init__() + self.weight_start = weight_start + self.weight_end = weight_end + self._margin = 0.2 + + def get_scheduled_weight( + self, + current_epoch: int, + max_epochs: int, + ) -> float: + """Cosine schedule between start/end weights.""" + max_epochs = max(1, max_epochs - 1) + t = min(current_epoch, max_epochs) / max_epochs + alpha = 0.5 * (1.0 + math.cos(math.pi * t)) + return self.weight_end + (self.weight_start - self.weight_end) * alpha + + def forward( + self, + embeddings_src: torch.Tensor, + embeddings_dst: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Compute alignment loss for transcript-transcript edges.""" + sim = (embeddings_src * embeddings_dst).sum(dim=-1) + labels = labels.float() + + pos_mask = labels > 0.5 + neg_mask = ~pos_mask + + loss = torch.tensor(0.0, device=sim.device) + if pos_mask.any(): + pos_loss = (1.0 - sim[pos_mask]) ** 2 + loss = loss + pos_loss.mean() + if neg_mask.any(): + neg_loss = F.relu(sim[neg_mask] - self._margin) ** 2 + loss = loss + neg_loss.mean() + + return loss + + +def compute_me_gene_edges( + gene_indices: torch.Tensor, + me_gene_pairs: torch.Tensor, + edge_index: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Create tx-tx alignment edges: same-gene positives + ME negatives.""" + src, dst = edge_index + src_genes = gene_indices[src] + dst_genes = gene_indices[dst] + + pos_mask = src_genes == dst_genes + + neg_mask = torch.zeros_like(pos_mask, dtype=torch.bool) + if me_gene_pairs.numel() > 0 and src_genes.numel() > 0: + me_genes = torch.unique(me_gene_pairs.flatten()) + in_me = torch.isin(src_genes, me_genes) & torch.isin(dst_genes, me_genes) + if in_me.any(): + pair_min = torch.minimum(me_gene_pairs[:, 0], me_gene_pairs[:, 1]) + pair_max = torch.maximum(me_gene_pairs[:, 0], me_gene_pairs[:, 1]) + max_gene = max( + src_genes.max().item() if src_genes.numel() > 0 else 0, + dst_genes.max().item() if dst_genes.numel() > 0 else 0, + pair_max.max().item() if pair_max.numel() > 0 else 0, + ) + 1 + me_pair_keys = pair_min * max_gene + pair_max + + edge_min = torch.minimum(src_genes[in_me], dst_genes[in_me]) + edge_max = torch.maximum(src_genes[in_me], dst_genes[in_me]) + edge_pair_keys = edge_min * max_gene + edge_max + is_me = torch.isin(edge_pair_keys, me_pair_keys) + neg_mask[in_me] = is_me + + n_pos = int(pos_mask.sum().item()) + n_neg = int(neg_mask.sum().item()) + if n_neg == 0 and n_pos == 0: + return edge_index[:, :0], torch.empty((0,), device=edge_index.device) + if n_neg == 0: + return edge_index[:, :0], torch.empty((0,), device=edge_index.device) + + max_pos = 3 * n_neg + if n_pos > max_pos: + pos_idx = pos_mask.nonzero().flatten() + pos_idx = pos_idx[ + torch.randperm(n_pos, device=pos_idx.device)[:max_pos] + ] + keep = torch.zeros_like(pos_mask, dtype=torch.bool) + keep[pos_idx] = True + keep |= neg_mask + else: + keep = pos_mask | neg_mask + + if not keep.any(): + return edge_index[:, :0], torch.empty((0,), device=edge_index.device) + + labels = torch.zeros(keep.sum().item(), device=edge_index.device) + labels[pos_mask[keep]] = 1.0 + return edge_index[:, keep], labels diff --git a/src/segger/validation/__init__.py b/src/segger/validation/__init__.py new file mode 100644 index 0000000..ecde6fb --- /dev/null +++ 
b/src/segger/validation/__init__.py @@ -0,0 +1,14 @@ +from .me_genes import load_me_genes_from_scrna +from .quick_metrics import ( + count_cells_from_anndata, + compute_assignment_metrics, + compute_border_contamination_fast, + compute_mecr_fast, + compute_resolvi_contamination_fast, + compute_signal_doublet_fast, + compute_transcript_centroid_offset_fast, + load_me_gene_pairs, + load_segmentation, + load_source_transcripts, + merge_assigned_transcripts, +) diff --git a/src/segger/validation/me_genes.py b/src/segger/validation/me_genes.py new file mode 100644 index 0000000..8e64585 --- /dev/null +++ b/src/segger/validation/me_genes.py @@ -0,0 +1,421 @@ +"""Mutually exclusive gene discovery from scRNA-seq reference. + +This module provides functions to identify mutually exclusive (ME) gene pairs +from single-cell RNA-seq reference data. ME genes are markers that are highly +expressed in one cell type but not co-expressed in the same cell, making them +useful constraints for cell segmentation. + +Ported from segger v0.1.0 validation/utils.py. +""" + +from typing import Dict, List, Tuple, Optional +from pathlib import Path +import warnings +import json +import hashlib +import time +import os +import numpy as np +import anndata as ad +import scanpy as sc +import pandas as pd +from itertools import combinations + + +def find_markers( + adata: ad.AnnData, + cell_type_column: str, + pos_percentile: float = 10, + neg_percentile: float = 10, + percentage: float = 30, +) -> Dict[str, Dict[str, List[str]]]: + """Identify positive and negative markers for each cell type. + + Parameters + ---------- + adata : ad.AnnData + Annotated data object containing gene expression data. + cell_type_column : str + Column name in `adata.obs` that specifies cell types. + pos_percentile : float, optional + Percentile threshold for top highly expressed genes (default: 10). + neg_percentile : float, optional + Percentile threshold for top lowly expressed genes (default: 10). + percentage : float, optional + Minimum percentage of cells expressing the marker (default: 30). + + Returns + ------- + dict + Dictionary where keys are cell types and values contain: + 'positive': list of highly expressed genes + 'negative': list of lowly expressed genes + """ + markers = {} + sc.tl.rank_genes_groups(adata, groupby=cell_type_column) + genes = adata.var_names + + for cell_type in adata.obs[cell_type_column].unique(): + subset = adata[adata.obs[cell_type_column] == cell_type] + mean_expression = np.asarray(subset.X.mean(axis=0)).flatten() + + cutoff_high = np.percentile(mean_expression, 100 - pos_percentile) + cutoff_low = np.percentile(mean_expression, neg_percentile) + + pos_indices = np.where(mean_expression >= cutoff_high)[0] + neg_indices = np.where(mean_expression <= cutoff_low)[0] + + # Filter by expression percentage + expr_frac = np.asarray((subset.X[:, pos_indices] > 0).mean(axis=0)).flatten() + valid_pos_indices = pos_indices[expr_frac >= (percentage / 100)] + + positive_markers = genes[valid_pos_indices] + negative_markers = genes[neg_indices] + + markers[cell_type] = { + "positive": list(positive_markers), + "negative": list(negative_markers), + } + + return markers + + +def find_mutually_exclusive_genes( + adata: ad.AnnData, + markers: Dict[str, Dict[str, List[str]]], + cell_type_column: str, + expr_threshold_in: float = 0.25, + expr_threshold_out: float = 0.03, +) -> List[Tuple[str, str]]: + """Identify mutually exclusive genes based on expression criteria. 
+
+    A gene is considered ME if it's expressed in >expr_threshold_in of its
+    cell type but in <expr_threshold_out of all other cell types.
+
+    Parameters
+    ----------
+    adata : ad.AnnData
+        Annotated data object containing gene expression data.
+    markers : dict
+        Marker dictionary produced by `find_markers`.
+    cell_type_column : str
+        Column name in `adata.obs` that specifies cell types.
+    expr_threshold_in : float, optional
+        Minimum expression fraction in the gene's own cell type (default: 0.25).
+    expr_threshold_out : float, optional
+        Maximum expression fraction in other cell types (default: 0.03).
+
+    Returns
+    -------
+    list
+        List of (gene1, gene2) pairs from different cell types.
+    """
+    exclusive_genes = {}
+    all_exclusive = []
+    gene_expression = adata.to_df()
+
+    for cell_type, marker_sets in markers.items():
+        exclusive_genes[cell_type] = []
+        cell_type_mask = adata.obs[cell_type_column] == cell_type
+        non_cell_type_mask = ~cell_type_mask
+
+        for gene in marker_sets["positive"]:
+            if gene not in gene_expression.columns:
+                continue
+            gene_expr = gene_expression[gene]
+            expr_in = (gene_expr[cell_type_mask] > 0).mean()
+            expr_out = (gene_expr[non_cell_type_mask] > 0).mean()
+
+            if expr_in > expr_threshold_in and expr_out < expr_threshold_out:
+                exclusive_genes[cell_type].append(gene)
+                all_exclusive.append(gene)
+
+    # Get unique exclusive genes
+    unique_genes = list(set(all_exclusive))
+    filtered_exclusive_genes = {
+        ct: [g for g in genes if g in unique_genes]
+        for ct, genes in exclusive_genes.items()
+    }
+
+    # Create pairs from different cell types
+    mutually_exclusive_gene_pairs = [
+        (gene1, gene2)
+        for key1, key2 in combinations(filtered_exclusive_genes.keys(), 2)
+        for gene1 in filtered_exclusive_genes[key1]
+        for gene2 in filtered_exclusive_genes[key2]
+    ]
+
+    return mutually_exclusive_gene_pairs
+
+
+def compute_MECR(
+    adata: ad.AnnData,
+    gene_pairs: List[Tuple[str, str]],
+) -> Dict[Tuple[str, str], float]:
+    """Compute Mutually Exclusive Co-expression Rate (MECR) for gene pairs.
+
+    MECR = (both expressed) / (at least one expressed)
+    Lower MECR indicates better mutual exclusivity.
+
+    Parameters
+    ----------
+    adata : ad.AnnData
+        Annotated data object containing gene expression data.
+    gene_pairs : list
+        List of gene pairs to evaluate.
+
+    Returns
+    -------
+    dict
+        Dictionary mapping gene pairs to MECR values.
+    """
+    mecr_dict = {}
+    gene_expression = adata.to_df()
+
+    for gene1, gene2 in gene_pairs:
+        if gene1 not in gene_expression.columns or gene2 not in gene_expression.columns:
+            continue
+
+        expr_gene1 = gene_expression[gene1] > 0
+        expr_gene2 = gene_expression[gene2] > 0
+
+        both_expressed = (expr_gene1 & expr_gene2).mean()
+        at_least_one_expressed = (expr_gene1 | expr_gene2).mean()
+
+        mecr = (
+            both_expressed / at_least_one_expressed
+            if at_least_one_expressed > 0
+            else 0
+        )
+        mecr_dict[(gene1, gene2)] = mecr
+
+    return mecr_dict
+
+
+def load_me_genes_from_scrna(
+    scrna_path: Path,
+    cell_type_column: str = "celltype",
+    gene_name_column: Optional[str] = None,
+    pos_percentile: float = 10,
+    neg_percentile: float = 10,
+    percentage: float = 30,
+    expr_threshold_in: float = 0.25,
+    expr_threshold_out: float = 0.03,
+) -> Tuple[List[Tuple[str, str]], Dict[str, Dict[str, List[str]]]]:
+    """Load scRNA-seq reference and compute ME gene pairs.
+
+    Parameters
+    ----------
+    scrna_path : Path
+        Path to scRNA-seq reference h5ad file.
+    cell_type_column : str, optional
+        Column name for cell type annotations (default: "celltype").
+    gene_name_column : str | None, optional
+        Column in var for gene names. If None, uses var_names.
+    pos_percentile : float, optional
+        Percentile for positive markers (default: 10).
+    neg_percentile : float, optional
+        Percentile for negative markers (default: 10).
+    percentage : float, optional
+        Minimum expression percentage (default: 30).
+    expr_threshold_in : float, optional
+        Minimum expression in own cell type (default: 0.25).
+    expr_threshold_out : float, optional
+        Maximum expression in other cell types (default: 0.03).
+
+    Notes
+    -----
+    For performance, cells are subsampled to at most 1000 per cell type.
+
+    Returns
+    -------
+    tuple
+        (me_gene_pairs, markers) where me_gene_pairs is a list of
+        (gene1, gene2) tuples and markers is the full marker dictionary.
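+
+    Examples
+    --------
+    A hypothetical call; the reference path and column name are placeholders::
+
+        pairs, markers = load_me_genes_from_scrna(
+            Path("data/ref_pancreas.h5ad"),
+            cell_type_column="cell_type",
+        )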
+ """ + verbose = os.getenv("SEGGER_ME_VERBOSE", "").lower() not in {"0", "false", "no", "off"} + # Cache to avoid repeated expensive ME discovery + cache_key = _me_cache_key( + scrna_path=scrna_path, + cell_type_column=cell_type_column, + gene_name_column=gene_name_column, + pos_percentile=pos_percentile, + neg_percentile=neg_percentile, + percentage=percentage, + expr_threshold_in=expr_threshold_in, + expr_threshold_out=expr_threshold_out, + ) + cache_path = _me_cache_path(scrna_path, cache_key) + if cache_path.exists(): + try: + with open(cache_path, "r") as f: + cached = json.load(f) + if cached.get("key") == cache_key: + pairs = [ + (p[0], p[1]) + for p in cached.get("me_gene_pairs", []) + if len(p) == 2 + ] + markers = cached.get("markers", {}) + if verbose: + print( + f"[segger][me] cache hit: {len(pairs)} pairs", + flush=True, + ) + return pairs, markers + except Exception: + pass + + t0 = time.monotonic() + if verbose: + print( + "[segger][me] computing ME gene pairs (this can take a while)...", + flush=True, + ) + + # Load scRNA-seq data + adata = sc.read_h5ad(scrna_path) + + # Subsample cells per cell type to limit runtime + if cell_type_column in adata.obs: + rng = np.random.default_rng(0) + idx = [] + for ct in adata.obs[cell_type_column].unique(): + ct_idx = np.where(adata.obs[cell_type_column] == ct)[0] + if ct_idx.size > _ME_MAX_CELLS_PER_TYPE: + ct_idx = rng.choice( + ct_idx, + size=_ME_MAX_CELLS_PER_TYPE, + replace=False, + ) + idx.append(ct_idx) + if idx: + idx = np.concatenate(idx) + adata = adata[idx].copy() + + # Ensure unique var names and log-normalize if needed + if not adata.var_names.is_unique: + adata.var_names_make_unique() + if "log1p" not in adata.uns: + sc.pp.normalize_total(adata, target_sum=1e4) + sc.pp.log1p(adata) + + # Optionally remap gene names + if gene_name_column is not None and gene_name_column in adata.var.columns: + adata.var_names = adata.var[gene_name_column] + + # Find markers + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=pd.errors.PerformanceWarning, + ) + markers = find_markers( + adata, + cell_type_column=cell_type_column, + pos_percentile=pos_percentile, + neg_percentile=neg_percentile, + percentage=percentage, + ) + + # Find ME gene pairs + me_gene_pairs = find_mutually_exclusive_genes( + adata, + markers, + cell_type_column=cell_type_column, + expr_threshold_in=expr_threshold_in, + expr_threshold_out=expr_threshold_out, + ) + + if verbose: + n_types = adata.obs[cell_type_column].nunique() + elapsed = time.monotonic() - t0 + print( + f"[segger][me] done: {len(me_gene_pairs)} pairs " + f"across {n_types} cell types in {elapsed:.1f}s", + flush=True, + ) + + # Write cache (best-effort) + try: + payload = { + "key": cache_key, + "me_gene_pairs": [list(p) for p in me_gene_pairs], + "markers": markers, + } + with open(cache_path, "w") as f: + json.dump(payload, f) + except Exception: + pass + + return me_gene_pairs, markers + + +def me_gene_pairs_to_indices( + me_gene_pairs: List[Tuple[str, str]], + gene_names: List[str], +) -> List[Tuple[int, int]]: + """Convert gene name pairs to index pairs. + + Parameters + ---------- + me_gene_pairs : list + List of (gene1, gene2) name tuples. + gene_names : list + List of gene names in order (index corresponds to token). + + Returns + ------- + list + List of (idx1, idx2) index tuples. 
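+
+    Examples
+    --------
+    Illustrative gene names; pairs with a gene missing from `gene_names` are
+    dropped:
+
+    >>> me_gene_pairs_to_indices([("INS", "GCG"), ("INS", "XYZ")], ["INS", "GCG", "SST"])
+    [(0, 1)]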
+ """ + gene_to_idx = {name: idx for idx, name in enumerate(gene_names)} + + index_pairs = [] + for gene1, gene2 in me_gene_pairs: + if gene1 in gene_to_idx and gene2 in gene_to_idx: + index_pairs.append((gene_to_idx[gene1], gene_to_idx[gene2])) + + return index_pairs +_ME_CACHE_VERSION = 2 +_ME_MAX_CELLS_PER_TYPE = 1000 + + +def _me_cache_key( + scrna_path: Path, + cell_type_column: str, + gene_name_column: Optional[str], + pos_percentile: float, + neg_percentile: float, + percentage: float, + expr_threshold_in: float, + expr_threshold_out: float, +) -> str: + """Create a stable cache key for ME gene discovery inputs.""" + st = scrna_path.stat() + payload = { + "version": _ME_CACHE_VERSION, + "path": str(scrna_path.resolve()), + "size": st.st_size, + "mtime_ns": st.st_mtime_ns, + "cell_type_column": cell_type_column, + "gene_name_column": gene_name_column, + "pos_percentile": pos_percentile, + "neg_percentile": neg_percentile, + "percentage": percentage, + "expr_threshold_in": expr_threshold_in, + "expr_threshold_out": expr_threshold_out, + "max_cells_per_type": _ME_MAX_CELLS_PER_TYPE, + } + raw = json.dumps(payload, sort_keys=True).encode("utf-8") + return hashlib.sha256(raw).hexdigest()[:16] + + +def _me_cache_path(scrna_path: Path, key: str) -> Path: + """Cache file path for ME gene discovery outputs.""" + return Path(f"{scrna_path}.segger_me_cache.{key}.json") diff --git a/src/segger/validation/quick_metrics.py b/src/segger/validation/quick_metrics.py new file mode 100644 index 0000000..1eeb1f2 --- /dev/null +++ b/src/segger/validation/quick_metrics.py @@ -0,0 +1,1050 @@ +"""Lightweight validation metrics for Segger outputs. + +This module provides fast, reference-light metrics intended for quick model +selection and single-run quality checks. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional, Sequence + +import anndata as ad +import numpy as np +import polars as pl +from scipy import sparse +from scipy.spatial import cKDTree + +from ..io import StandardTranscriptFields, get_preprocessor +from .me_genes import load_me_genes_from_scrna + + +def assigned_cell_expr(cell_id_column: str = "segger_cell_id") -> pl.Expr: + """Expression selecting transcripts assigned to a valid cell.""" + cell = pl.col(cell_id_column) + cell_str = cell.cast(pl.Utf8) + return ( + cell.is_not_null() + & (cell_str != "-1") + & (cell_str.str.to_uppercase() != "UNASSIGNED") + & (cell_str.str.to_uppercase() != "NONE") + ) + + +def _effective_sample_size(weights: np.ndarray) -> float: + """Kish effective sample size for non-negative weights.""" + w = np.asarray(weights, dtype=np.float64) + w = w[np.isfinite(w) & (w > 0)] + if w.size == 0: + return float("nan") + sw = float(w.sum()) + sw2 = float(np.square(w).sum()) + if sw <= 0 or sw2 <= 0: + return float("nan") + return (sw * sw) / sw2 + + +def _weighted_mean_ci95(values: np.ndarray, weights: np.ndarray) -> float: + """Approximate 95% CI half-width for a weighted mean.""" + v = np.asarray(values, dtype=np.float64) + w = np.asarray(weights, dtype=np.float64) + mask = np.isfinite(v) & np.isfinite(w) & (w > 0) + if not np.any(mask): + return float("nan") + v = v[mask] + w = w[mask] + if v.size == 0: + return float("nan") + mu = float(np.average(v, weights=w)) + neff = _effective_sample_size(w) + if not np.isfinite(neff) or neff <= 1: + return float("nan") + var = float(np.average(np.square(v - mu), weights=w)) + se = np.sqrt(max(var, 0.0) / neff) + return float(1.96 * se) + + +def _weighted_bernoulli_ci95(flags: np.ndarray, weights: np.ndarray) -> float: + """Approximate 95% CI half-width for weighted Bernoulli proportion.""" + f = np.asarray(flags, dtype=np.float64) + w = np.asarray(weights, dtype=np.float64) + mask = np.isfinite(f) & np.isfinite(w) & (w > 0) + if not np.any(mask): + return float("nan") + f = f[mask] + w = w[mask] + if f.size == 0: + return float("nan") + p = float(np.average(f, weights=w)) + neff = _effective_sample_size(w) + if not np.isfinite(neff) or neff <= 1: + return float("nan") + se = np.sqrt(max(p * (1.0 - p), 0.0) / neff) + return float(1.96 * se) + + +def _binomial_pct_ci95(successes: int, total: int) -> float: + """95% CI half-width in percentage points for a binomial proportion.""" + if total <= 0: + return float("nan") + p = min(max(float(successes) / float(total), 0.0), 1.0) + se = np.sqrt(max(p * (1.0 - p), 0.0) / float(total)) + return float(100.0 * 1.96 * se) + + +def load_source_transcripts(source_path: Path) -> pl.DataFrame: + """Load standardized source transcripts with only needed columns.""" + tx_fields = StandardTranscriptFields() + source_path = Path(source_path) + tx = None + + # Optional SpatialData input support + try: + from ..io.spatialdata_loader import is_spatialdata_path, load_from_spatialdata + + if is_spatialdata_path(source_path): + tx_lf, _ = load_from_spatialdata( + source_path, + boundary_type="all", + normalize=True, + ) + tx = tx_lf.collect() if isinstance(tx_lf, pl.LazyFrame) else tx_lf + except Exception: + tx = None + + if tx is None: + pp = get_preprocessor(source_path, min_qv=0, include_z=True) + tx = pp.transcripts + if isinstance(tx, pl.LazyFrame): + tx = tx.collect() + + keep_cols = [ + tx_fields.row_index, + tx_fields.feature, + tx_fields.x, + tx_fields.y, + ] + if tx_fields.z in tx.columns: + 
keep_cols.append(tx_fields.z) + tx = tx.select([c for c in keep_cols if c in tx.columns]) + + if tx_fields.row_index not in tx.columns: + tx = tx.with_row_index(name=tx_fields.row_index) + + return tx + + +def load_segmentation(segmentation_path: Path) -> pl.DataFrame: + """Load segmentation parquet with canonical columns.""" + seg = pl.read_parquet(segmentation_path) + required = ["row_index", "segger_cell_id"] + missing = [c for c in required if c not in seg.columns] + if missing: + raise ValueError( + f"Segmentation file missing required columns {missing}: {segmentation_path}" + ) + return seg.select([c for c in ["row_index", "segger_cell_id", "segger_similarity"] if c in seg.columns]) + + +def compute_assignment_metrics( + seg_df: pl.DataFrame, + cell_id_column: str = "segger_cell_id", +) -> dict[str, float]: + """Compute transcript assignment coverage metrics.""" + total = int(seg_df.height) + if total == 0: + return { + "transcripts_total": 0, + "transcripts_assigned": 0, + "transcripts_assigned_pct": float("nan"), + "transcripts_assigned_pct_ci95": float("nan"), + "cells_assigned": 0, + } + + assigned_df = seg_df.filter(assigned_cell_expr(cell_id_column)) + assigned = int(assigned_df.height) + cells = int( + assigned_df.select(pl.col(cell_id_column).n_unique()).to_series().item() + if assigned > 0 + else 0 + ) + return { + "transcripts_total": total, + "transcripts_assigned": assigned, + "transcripts_assigned_pct": 100.0 * assigned / total, + "transcripts_assigned_pct_ci95": _binomial_pct_ci95(assigned, total), + "cells_assigned": cells, + } + + +def count_cells_from_anndata(anndata_path: Optional[Path]) -> Optional[int]: + """Return number of cells (n_obs) from AnnData, or None if unavailable.""" + if anndata_path is None: + return None + path = Path(anndata_path) + if not path.exists(): + return None + + adata = ad.read_h5ad(path, backed="r") + try: + return int(adata.n_obs) + finally: + try: + if getattr(adata, "isbacked", False): + adata.file.close() + except Exception: + pass + + +def merge_assigned_transcripts( + seg_df: pl.DataFrame, + source_tx: pl.DataFrame, + cell_id_column: str = "segger_cell_id", + row_index_column: str = "row_index", +) -> pl.DataFrame: + """Inner-join source transcripts with assigned segmentation rows.""" + left = source_tx + right = seg_df + + if row_index_column not in left.columns: + left = left.with_row_index(name=row_index_column) + + left = left.with_columns(pl.col(row_index_column).cast(pl.Int64)) + right = right.with_columns(pl.col(row_index_column).cast(pl.Int64)) + right = right.filter(assigned_cell_expr(cell_id_column)).select([row_index_column, cell_id_column]) + + return left.join(right, on=row_index_column, how="inner") + + +def _empty_resolvi_metrics() -> dict[str, float]: + """Default empty return payload for RESOLVI-like contamination metric.""" + return { + "resolvi_contamination_pct_fast": float("nan"), + "resolvi_contamination_ci95_fast": float("nan"), + "resolvi_contaminated_cells_pct_fast": float("nan"), + "resolvi_contaminated_cells_pct_ci95_fast": float("nan"), + "resolvi_metric_cells_used": 0, + "resolvi_shared_genes_used": 0, + "resolvi_cell_types_used": 0, + } + + +def _build_cell_gene_matrix( + assigned_tx: pl.DataFrame, + *, + cell_id_column: str, + feature_column: str, + x_column: str, + y_column: str, + min_transcripts_per_cell: int, + max_cells: int, + seed: int, +) -> tuple[sparse.csr_matrix, np.ndarray, np.ndarray, list[str]] | None: + """Build sparse cell x gene counts with centroids and per-cell weights.""" + req = 
[cell_id_column, feature_column, x_column, y_column] + for col in req: + if col not in assigned_tx.columns: + return None + + df = ( + assigned_tx.select(req) + .drop_nulls() + .with_columns( + pl.col(cell_id_column).cast(pl.Utf8), + pl.col(feature_column).cast(pl.Utf8), + ) + ) + if df.height == 0: + return None + + cell_stats = ( + df.group_by(cell_id_column) + .agg( + pl.len().alias("n_total"), + pl.col(x_column).mean().alias("cx"), + pl.col(y_column).mean().alias("cy"), + ) + .filter(pl.col("n_total") >= int(min_transcripts_per_cell)) + ) + if cell_stats.height == 0: + return None + + if max_cells > 0 and cell_stats.height > max_cells: + rng = np.random.default_rng(seed) + ids = np.asarray(cell_stats.get_column(cell_id_column).to_list(), dtype=object) + picked = rng.choice(ids, size=max_cells, replace=False).tolist() + cell_stats = cell_stats.filter(pl.col(cell_id_column).is_in(picked)) + + cell_stats = cell_stats.sort(cell_id_column).with_row_index(name="_cid") + if cell_stats.height == 0: + return None + + df = df.join(cell_stats.select([cell_id_column, "_cid"]), on=cell_id_column, how="inner") + if df.height == 0: + return None + + gene_idx = ( + df.select(feature_column) + .unique() + .sort(feature_column) + .with_row_index(name="_gid") + ) + if gene_idx.height == 0: + return None + + mapped = df.join(gene_idx, on=feature_column, how="inner") + counts = mapped.group_by(["_cid", "_gid"]).agg(pl.len().alias("_count")) + if counts.height == 0: + return None + + rows = counts.get_column("_cid").to_numpy().astype(np.int64, copy=False) + cols = counts.get_column("_gid").to_numpy().astype(np.int64, copy=False) + data = counts.get_column("_count").to_numpy().astype(np.float64, copy=False) + + X = sparse.coo_matrix( + (data, (rows, cols)), + shape=(int(cell_stats.height), int(gene_idx.height)), + ).tocsr() + + centroids = cell_stats.select(["cx", "cy"]).to_numpy().astype(np.float64, copy=False) + weights = cell_stats.get_column("n_total").to_numpy().astype(np.float64, copy=False) + gene_names = [str(g) for g in gene_idx.get_column(feature_column).to_list()] + return X, centroids, weights, gene_names + + +def _load_reference_type_profiles( + scrna_reference_path: Path, + scrna_celltype_column: str, + seg_gene_names: list[str], +) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None: + """Load per-celltype expression profiles on genes shared with segmentation.""" + if not Path(scrna_reference_path).exists(): + return None + + ref = ad.read_h5ad(scrna_reference_path) + if scrna_celltype_column not in ref.obs.columns: + return None + + labels_raw = ref.obs[scrna_celltype_column].astype(str).to_numpy() + labels = np.asarray([str(x) for x in labels_raw], dtype=object) + label_norm = np.char.lower(labels.astype(str)) + valid = ( + (labels.astype(str) != "") + & (label_norm != "nan") + & (label_norm != "none") + & (label_norm != "-1") + ) + if not np.any(valid): + return None + + ref_genes = np.asarray([str(g) for g in ref.var_names], dtype=object) + ref_gene_to_idx = {g: i for i, g in enumerate(ref_genes.tolist())} + + seg_shared_idx: list[int] = [] + ref_shared_idx: list[int] = [] + for i, g in enumerate(seg_gene_names): + j = ref_gene_to_idx.get(str(g)) + if j is None: + continue + seg_shared_idx.append(i) + ref_shared_idx.append(int(j)) + + if len(seg_shared_idx) == 0: + return None + + X_ref = ref.X + if sparse.issparse(X_ref): + X_ref = X_ref.tocsr()[valid][:, np.asarray(ref_shared_idx, dtype=np.int64)] + else: + X_ref = np.asarray(X_ref)[valid][:, np.asarray(ref_shared_idx, dtype=np.int64)] + 
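+    # Collapse the reference to one mean-expression profile per annotated
+    # cell type over the shared gene panel; the fast contamination metric
+    # below uses these per-type profiles as the expected "self" signal.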
+ labels_valid = labels[valid].astype(str) + type_names, type_inverse = np.unique(labels_valid, return_inverse=True) + if type_names.size == 0: + return None + + n_types = int(type_names.size) + n_genes = int(len(seg_shared_idx)) + profiles = np.zeros((n_types, n_genes), dtype=np.float64) + + for t in range(n_types): + idx = np.where(type_inverse == t)[0] + if idx.size == 0: + continue + if sparse.issparse(X_ref): + sub = X_ref[idx] + profiles[t] = np.asarray(sub.mean(axis=0)).ravel() + else: + profiles[t] = np.asarray(X_ref[idx], dtype=np.float64).mean(axis=0) + + profiles = np.nan_to_num(profiles, nan=0.0, posinf=0.0, neginf=0.0) + profiles = np.maximum(profiles, 0.0) + keep = np.asarray(profiles.sum(axis=1)).ravel() > 0 + if not np.any(keep): + return None + + return ( + np.asarray(seg_shared_idx, dtype=np.int64), + profiles[keep], + type_names[keep], + ) + + +def compute_resolvi_contamination_fast( + assigned_tx: pl.DataFrame, + *, + scrna_reference_path: Optional[Path], + scrna_celltype_column: str = "cell_type", + cell_id_column: str = "segger_cell_id", + feature_column: str = "feature_name", + x_column: str = "x", + y_column: str = "y", + min_transcripts_per_cell: int = 20, + max_cells: int = 3000, + k_neighbors: int = 10, + max_neighbor_distance: float = 20.0, + alpha_self: float = 0.8, + alpha_neighbor: float = 0.175, + alpha_background: float = 0.025, + contam_cutoff: float = 0.5, + seed: int = 0, +) -> dict[str, float]: + """Fast RESOLVI-style contamination estimate (lower is better). + + This approximates the RESOLVI neighborhood contamination formulation on a + sampled subset of segmented cells by: + 1) deriving host cell types from scRNA reference profile similarity, + 2) mixing self/neighbor/background expected expression per gene, + 3) flagging counts as contaminated when q_self < contam_cutoff. 
+ """ + out = _empty_resolvi_metrics() + if scrna_reference_path is None: + return out + if alpha_self < 0 or alpha_neighbor < 0 or alpha_background < 0: + return out + if alpha_self + alpha_neighbor + alpha_background <= 0: + return out + + try: + built = _build_cell_gene_matrix( + assigned_tx, + cell_id_column=cell_id_column, + feature_column=feature_column, + x_column=x_column, + y_column=y_column, + min_transcripts_per_cell=min_transcripts_per_cell, + max_cells=max_cells, + seed=seed, + ) + if built is None: + return out + X, centroids, cell_weights, gene_names = built + if X.shape[0] == 0 or X.shape[1] == 0: + return out + + ref_data = _load_reference_type_profiles( + Path(scrna_reference_path), + scrna_celltype_column=scrna_celltype_column, + seg_gene_names=gene_names, + ) + if ref_data is None: + return out + seg_shared_idx, ref_profiles, type_names = ref_data + if seg_shared_idx.size == 0 or ref_profiles.size == 0: + return out + + X = X[:, seg_shared_idx] + if X.shape[1] == 0: + return out + + totals = np.asarray(X.sum(axis=1)).ravel().astype(np.float64, copy=False) + keep = np.isfinite(totals) & (totals > 0) + if not np.any(keep): + return out + if not np.all(keep): + rows_keep = np.where(keep)[0] + X = X[rows_keep] + centroids = centroids[rows_keep] + cell_weights = cell_weights[rows_keep] + totals = totals[rows_keep] + + n_cells = int(X.shape[0]) + n_types = int(ref_profiles.shape[0]) + out["resolvi_shared_genes_used"] = int(X.shape[1]) + out["resolvi_cell_types_used"] = int(type_names.size) + if n_cells == 0 or n_types == 0: + return out + + eps = 1e-9 + ref = np.asarray(ref_profiles, dtype=np.float64) + ref_norm = np.linalg.norm(ref, axis=1) + ref_norm[~np.isfinite(ref_norm) | (ref_norm <= 0)] = 1.0 + + cell_norm = np.sqrt(np.asarray(X.multiply(X).sum(axis=1)).ravel()) + cell_norm[~np.isfinite(cell_norm) | (cell_norm <= 0)] = 1.0 + + sim = X @ ref.T + if sparse.issparse(sim): + sim = sim.toarray() + sim = np.asarray(sim, dtype=np.float64) + sim /= cell_norm[:, None] + sim /= ref_norm[None, :] + host_type = np.argmax(sim, axis=1).astype(np.int64) + + neighbor_freq = np.zeros((n_cells, n_types), dtype=np.float64) + if n_cells > 1 and int(k_neighbors) > 0: + k = min(int(k_neighbors) + 1, n_cells) + tree = cKDTree(centroids) + dists, idxs = tree.query(centroids, k=k) + if k > 1: + if dists.ndim == 1: + dists = dists[:, None] + idxs = idxs[:, None] + nbr_d = dists[:, 1:] + nbr_i = idxs[:, 1:] + for i in range(n_cells): + if nbr_i.shape[1] == 0: + continue + if np.isfinite(max_neighbor_distance) and max_neighbor_distance > 0: + valid_nbr = nbr_d[i] <= float(max_neighbor_distance) + else: + valid_nbr = np.ones(nbr_i.shape[1], dtype=bool) + pick = nbr_i[i][valid_nbr] + if pick.size == 0: + continue + np.add.at(neighbor_freq[i], host_type[pick], 1.0) + s = float(neighbor_freq[i].sum()) + if s > 0: + neighbor_freq[i] /= s + + bg = np.bincount( + host_type, + weights=np.asarray(cell_weights, dtype=np.float64), + minlength=n_types, + ).astype(np.float64, copy=False) + bsum = float(bg.sum()) + if bsum > 0: + bg /= bsum + p_back = bg @ ref + + per_cell_pct = np.full(n_cells, np.nan, dtype=np.float64) + per_cell_flag = np.zeros(n_cells, dtype=np.float64) + + for i in range(n_cells): + row = X.getrow(i) + if row.nnz == 0: + continue + h = int(host_type[i]) + neigh = neighbor_freq[i].copy() + if 0 <= h < n_types: + neigh[h] = 0.0 + nsum = float(neigh.sum()) + if nsum > 0: + neigh /= nsum + p_self = ref[h] + p_neigh = neigh @ ref + denom = (alpha_self * p_self) + (alpha_neighbor * p_neigh) + 
(alpha_background * p_back) + eps + q_self = (alpha_self * p_self) / denom + q_self = np.clip(q_self, 0.0, 1.0) + + vals = row.data.astype(np.float64, copy=False) + cols = row.indices + total = float(vals.sum()) + if total <= 0: + continue + contam = float(vals[q_self[cols] < float(contam_cutoff)].sum()) + pct = 100.0 * contam / total + per_cell_pct[i] = pct + per_cell_flag[i] = 1.0 if contam > 0 else 0.0 + + valid_cells = np.isfinite(per_cell_pct) & np.isfinite(totals) & (totals > 0) + if not np.any(valid_cells): + return out + + w = totals[valid_cells] + vals = per_cell_pct[valid_cells] + flags = per_cell_flag[valid_cells] + + out["resolvi_metric_cells_used"] = int(np.sum(valid_cells)) + out["resolvi_contamination_pct_fast"] = float(np.average(vals, weights=w)) + out["resolvi_contamination_ci95_fast"] = float(_weighted_mean_ci95(vals, w)) + out["resolvi_contaminated_cells_pct_fast"] = float(100.0 * np.average(flags, weights=w)) + out["resolvi_contaminated_cells_pct_ci95_fast"] = float(100.0 * _weighted_bernoulli_ci95(flags, w)) + return out + except Exception: + return out + + +def compute_border_contamination_fast( + assigned_tx: pl.DataFrame, + *, + cell_id_column: str = "segger_cell_id", + x_column: str = "x", + y_column: str = "y", + erosion_fraction: float = 0.3, + min_transcripts_per_cell: int = 20, + max_cells: int = 3000, + contaminated_enrichment_threshold: float = 1.25, + seed: int = 0, +) -> dict[str, float]: + """Fast border-enrichment contamination proxy (lower is better). + + This approximates periphery contamination by comparing transcript density + in border vs center regions defined by an eroded bounding box. + """ + eps = 1e-9 + req = [cell_id_column, x_column, y_column] + for col in req: + if col not in assigned_tx.columns: + return { + "border_contamination_fast": float("nan"), + "border_enrichment_fast": float("nan"), + "border_excess_pct_fast": float("nan"), + "border_contaminated_cells_pct_fast": float("nan"), + "border_contaminated_cells_pct_ci95_fast": float("nan"), + "border_metric_cells_used": 0, + } + + df = assigned_tx.select(req).drop_nulls() + if df.height == 0: + return { + "border_contamination_fast": float("nan"), + "border_enrichment_fast": float("nan"), + "border_excess_pct_fast": float("nan"), + "border_contaminated_cells_pct_fast": float("nan"), + "border_contaminated_cells_pct_ci95_fast": float("nan"), + "border_metric_cells_used": 0, + } + + cell_stats = ( + df.group_by(cell_id_column) + .agg( + pl.len().alias("n_total"), + pl.col(x_column).min().alias("x_min"), + pl.col(x_column).max().alias("x_max"), + pl.col(y_column).min().alias("y_min"), + pl.col(y_column).max().alias("y_max"), + ) + .with_columns( + (pl.col("x_max") - pl.col("x_min")).alias("width"), + (pl.col("y_max") - pl.col("y_min")).alias("height"), + ) + .with_columns( + pl.min_horizontal("width", "height").alias("min_side"), + (pl.min_horizontal("width", "height") * erosion_fraction).alias("erosion"), + ) + .filter(pl.col("n_total") >= min_transcripts_per_cell) + .filter((pl.col("width") > 0) & (pl.col("height") > 0)) + .filter((pl.col("min_side") > 0) & (pl.col("erosion") > 0)) + .with_columns( + (pl.col("x_min") + pl.col("erosion")).alias("cx_min"), + (pl.col("x_max") - pl.col("erosion")).alias("cx_max"), + (pl.col("y_min") + pl.col("erosion")).alias("cy_min"), + (pl.col("y_max") - pl.col("erosion")).alias("cy_max"), + ) + .filter((pl.col("cx_max") > pl.col("cx_min")) & (pl.col("cy_max") > pl.col("cy_min"))) + ) + + if cell_stats.height == 0: + return { + "border_contamination_fast": 
float("nan"), + "border_enrichment_fast": float("nan"), + "border_excess_pct_fast": float("nan"), + "border_contaminated_cells_pct_fast": float("nan"), + "border_contaminated_cells_pct_ci95_fast": float("nan"), + "border_metric_cells_used": 0, + } + + if max_cells > 0 and cell_stats.height > max_cells: + rng = np.random.default_rng(seed) + cell_ids = np.array(cell_stats.get_column(cell_id_column).to_list(), dtype=object) + picked = rng.choice(cell_ids, size=max_cells, replace=False).tolist() + cell_stats = cell_stats.filter(pl.col(cell_id_column).is_in(picked)) + df = df.join(cell_stats.select([cell_id_column]), on=cell_id_column, how="inner") + + classified = ( + df.join( + cell_stats.select( + [ + cell_id_column, + "cx_min", + "cx_max", + "cy_min", + "cy_max", + "width", + "height", + "n_total", + ] + ), + on=cell_id_column, + how="inner", + ) + .with_columns( + ( + (pl.col(x_column) >= pl.col("cx_min")) + & (pl.col(x_column) <= pl.col("cx_max")) + & (pl.col(y_column) >= pl.col("cy_min")) + & (pl.col(y_column) <= pl.col("cy_max")) + ).alias("is_center") + ) + ) + + grouped = ( + classified.group_by([cell_id_column, "is_center"]) + .agg(pl.len().alias("n")) + ) + n_center = grouped.filter(pl.col("is_center")).select([cell_id_column, pl.col("n").alias("n_center")]) + n_border = grouped.filter(~pl.col("is_center")).select([cell_id_column, pl.col("n").alias("n_border")]) + + per_cell = ( + cell_stats.join(n_center, on=cell_id_column, how="left") + .join(n_border, on=cell_id_column, how="left") + .with_columns( + pl.col("n_center").fill_null(0).cast(pl.Float64), + pl.col("n_border").fill_null(0).cast(pl.Float64), + pl.col("n_total").cast(pl.Float64), + ) + .with_columns( + (pl.col("width") * pl.col("height")).alias("bbox_area"), + ((pl.col("cx_max") - pl.col("cx_min")) * (pl.col("cy_max") - pl.col("cy_min"))).alias("center_area"), + ) + .with_columns( + (pl.col("bbox_area") - pl.col("center_area")).alias("border_area"), + ) + .with_columns( + (pl.col("n_center") / pl.max_horizontal(pl.col("center_area"), pl.lit(eps))).alias("center_density"), + (pl.col("n_border") / pl.max_horizontal(pl.col("border_area"), pl.lit(eps))).alias("border_density"), + ) + .with_columns( + (pl.col("border_density") / pl.max_horizontal(pl.col("center_density"), pl.lit(eps))).alias("border_enrichment"), + ( + pl.when( + (pl.col("border_density") / pl.max_horizontal(pl.col("center_density"), pl.lit(eps)) - 1.0) + > 0 + ) + .then(pl.col("border_density") / pl.max_horizontal(pl.col("center_density"), pl.lit(eps)) - 1.0) + .otherwise(0.0) + ).alias("contam_score"), + ) + ) + + if per_cell.height == 0: + return { + "border_contamination_fast": float("nan"), + "border_enrichment_fast": float("nan"), + "border_excess_pct_fast": float("nan"), + "border_contaminated_cells_pct_fast": float("nan"), + "border_contaminated_cells_pct_ci95_fast": float("nan"), + "border_metric_cells_used": 0, + } + + weights = per_cell.get_column("n_total").to_numpy() + contam = np.average(per_cell.get_column("contam_score").to_numpy(), weights=weights) + enrich = np.average(per_cell.get_column("border_enrichment").to_numpy(), weights=weights) + border_excess_pct = max(0.0, (float(enrich) - 1.0) * 100.0) + contaminated_flags = ( + per_cell.get_column("border_enrichment").to_numpy() + > float(contaminated_enrichment_threshold) + ).astype(np.float64) + contaminated_cells_pct = 100.0 * np.average( + contaminated_flags, + weights=weights, + ) + contaminated_cells_pct_ci95 = 100.0 * _weighted_bernoulli_ci95( + contaminated_flags, + weights, + ) + return 
{ + "border_contamination_fast": float(contam), + "border_enrichment_fast": float(enrich), + "border_excess_pct_fast": float(border_excess_pct), + "border_contaminated_cells_pct_fast": float(contaminated_cells_pct), + "border_contaminated_cells_pct_ci95_fast": float(contaminated_cells_pct_ci95), + "border_metric_cells_used": int(per_cell.height), + } + + +def compute_transcript_centroid_offset_fast( + assigned_tx: pl.DataFrame, + *, + cell_id_column: str = "segger_cell_id", + x_column: str = "x", + y_column: str = "y", + min_transcripts_per_cell: int = 20, + max_cells: int = 3000, + seed: int = 0, +) -> dict[str, float]: + """Fast transcript centroid-offset metric (higher is better). + + Uses a bounding-box center as cell centroid approximation. + """ + req = [cell_id_column, x_column, y_column] + for col in req: + if col not in assigned_tx.columns: + return { + "transcript_centroid_offset_fast": float("nan"), + "transcript_centroid_offset_ci95_fast": float("nan"), + "tco_metric_cells_used": 0, + } + + stats = ( + assigned_tx.select(req) + .drop_nulls() + .group_by(cell_id_column) + .agg( + pl.len().alias("n_total"), + pl.col(x_column).mean().alias("tx_cx"), + pl.col(y_column).mean().alias("tx_cy"), + pl.col(x_column).min().alias("x_min"), + pl.col(x_column).max().alias("x_max"), + pl.col(y_column).min().alias("y_min"), + pl.col(y_column).max().alias("y_max"), + ) + .filter(pl.col("n_total") >= min_transcripts_per_cell) + .with_columns( + (pl.col("x_max") - pl.col("x_min")).alias("width"), + (pl.col("y_max") - pl.col("y_min")).alias("height"), + ) + .filter((pl.col("width") > 0) & (pl.col("height") > 0)) + .with_columns( + ((pl.col("x_min") + pl.col("x_max")) / 2.0).alias("cell_cx"), + ((pl.col("y_min") + pl.col("y_max")) / 2.0).alias("cell_cy"), + (pl.col("width") * pl.col("height")).alias("area"), + ) + .filter(pl.col("area") > 0) + ) + + if stats.height == 0: + return { + "transcript_centroid_offset_fast": float("nan"), + "transcript_centroid_offset_ci95_fast": float("nan"), + "tco_metric_cells_used": 0, + } + + if max_cells > 0 and stats.height > max_cells: + rng = np.random.default_rng(seed) + ids = np.array(stats.get_column(cell_id_column).to_list(), dtype=object) + picked = rng.choice(ids, size=max_cells, replace=False).tolist() + stats = stats.filter(pl.col(cell_id_column).is_in(picked)) + + stats = stats.with_columns( + ( + ( + (pl.col("tx_cx") - pl.col("cell_cx")) ** 2 + + (pl.col("tx_cy") - pl.col("cell_cy")) ** 2 + ).sqrt() + ).alias("centroid_offset") + ).with_columns( + ( + 1.0 - (pl.col("centroid_offset") / (pl.col("area").sqrt() + 1e-9)) + ).clip(lower_bound=0.0, upper_bound=1.0).alias("tco_score") + ) + + weights = stats.get_column("n_total").to_numpy().astype(np.float64, copy=False) + tco_vals = stats.get_column("tco_score").to_numpy().astype(np.float64, copy=False) + tco = float(np.average(tco_vals, weights=weights)) + tco_ci95 = _weighted_mean_ci95(tco_vals, weights) + return { + "transcript_centroid_offset_fast": tco, + "transcript_centroid_offset_ci95_fast": float(tco_ci95), + "tco_metric_cells_used": int(stats.height), + } + + +def compute_signal_doublet_fast( + assigned_tx: pl.DataFrame, + *, + cell_id_column: str = "segger_cell_id", + z_column: str = "z", + min_transcripts_per_cell: int = 20, + max_cells: int = 3000, + seed: int = 0, + doublet_threshold: float = 0.6, +) -> dict[str, float]: + """Fast 3D doublet-like fraction based on per-cell z-spread.""" + if z_column not in assigned_tx.columns or cell_id_column not in assigned_tx.columns: + return { + 
"signal_doublet_like_fraction_fast": float("nan"), + "signal_doublet_like_fraction_ci95_fast": float("nan"), + "signal_metric_cells_used": 0, + } + + stats = ( + assigned_tx.select([cell_id_column, z_column]) + .drop_nulls() + .group_by(cell_id_column) + .agg( + pl.len().alias("n_total"), + pl.col(z_column).std().alias("z_std"), + ) + .filter(pl.col("n_total") >= min_transcripts_per_cell) + .drop_nulls(["z_std"]) + ) + + if stats.height == 0: + return { + "signal_doublet_like_fraction_fast": float("nan"), + "signal_doublet_like_fraction_ci95_fast": float("nan"), + "signal_metric_cells_used": 0, + } + + if max_cells > 0 and stats.height > max_cells: + rng = np.random.default_rng(seed) + ids = np.array(stats.get_column(cell_id_column).to_list(), dtype=object) + picked = rng.choice(ids, size=max_cells, replace=False).tolist() + stats = stats.filter(pl.col(cell_id_column).is_in(picked)) + + n_total = stats.get_column("n_total").to_numpy().astype(np.float64, copy=False) + z_std = stats.get_column("z_std").to_numpy().astype(np.float64, copy=False) + z_std = np.where(np.isfinite(z_std), z_std, np.nan) + positive = z_std[np.isfinite(z_std) & (z_std > 0)] + + if positive.size == 0: + ci95 = _weighted_bernoulli_ci95( + np.zeros_like(n_total, dtype=np.float64), + n_total, + ) + return { + "signal_doublet_like_fraction_fast": 0.0, + "signal_doublet_like_fraction_ci95_fast": float(ci95), + "signal_metric_cells_used": int(stats.height), + } + + expected = float(np.median(positive)) + if expected <= 1e-12: + ci95 = _weighted_bernoulli_ci95( + np.zeros_like(n_total, dtype=np.float64), + n_total, + ) + return { + "signal_doublet_like_fraction_fast": 0.0, + "signal_doublet_like_fraction_ci95_fast": float(ci95), + "signal_metric_cells_used": int(stats.height), + } + + integrity = np.clip(expected / (z_std + 1e-9), 0.0, 1.0) + doublet_flags = (integrity < doublet_threshold).astype(np.float64) + doublet_like = float(np.average(doublet_flags, weights=n_total)) + doublet_like_ci95 = _weighted_bernoulli_ci95(doublet_flags, n_total) + return { + "signal_doublet_like_fraction_fast": doublet_like, + "signal_doublet_like_fraction_ci95_fast": float(doublet_like_ci95), + "signal_metric_cells_used": int(stats.height), + } + + +def load_me_gene_pairs( + *, + me_gene_pairs_path: Optional[Path] = None, + scrna_reference_path: Optional[Path] = None, + scrna_celltype_column: str = "cell_type", +) -> list[tuple[str, str]]: + """Load mutually exclusive gene pairs from file or scRNA reference.""" + if me_gene_pairs_path is not None: + pairs: list[tuple[str, str]] = [] + with Path(me_gene_pairs_path).open("r", encoding="utf-8") as fh: + for raw_line in fh: + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if "\t" in line: + parts = [p.strip() for p in line.split("\t")] + elif "," in line: + parts = [p.strip() for p in line.split(",")] + else: + parts = [p.strip() for p in line.split()] + if len(parts) < 2: + continue + if parts[0].lower() in {"gene1", "gene_a"} and parts[1].lower() in {"gene2", "gene_b"}: + continue + pairs.append((parts[0], parts[1])) + return pairs + + if scrna_reference_path is not None: + pairs, _ = load_me_genes_from_scrna( + scrna_path=Path(scrna_reference_path), + cell_type_column=scrna_celltype_column, + ) + return [(str(g1), str(g2)) for g1, g2 in pairs] + + return [] + + +def compute_mecr_fast( + anndata_path: Path, + gene_pairs: Sequence[tuple[str, str]], + *, + max_pairs: int = 500, + soft: bool = True, + seed: int = 0, +) -> dict[str, float]: + """Compute fast MECR from 
an AnnData file (lower is better).""" + if anndata_path is None or not Path(anndata_path).exists(): + return {"mecr_fast": float("nan"), "mecr_ci95_fast": float("nan"), "mecr_pairs_used": 0} + if len(gene_pairs) == 0: + return {"mecr_fast": float("nan"), "mecr_ci95_fast": float("nan"), "mecr_pairs_used": 0} + + adata = ad.read_h5ad(anndata_path) + gene_to_idx = {str(g): i for i, g in enumerate(adata.var_names)} + + valid_pairs: list[tuple[int, int]] = [] + for g1, g2 in gene_pairs: + i = gene_to_idx.get(str(g1)) + j = gene_to_idx.get(str(g2)) + if i is None or j is None: + continue + valid_pairs.append((i, j)) + + if len(valid_pairs) == 0: + return {"mecr_fast": float("nan"), "mecr_ci95_fast": float("nan"), "mecr_pairs_used": 0} + + if max_pairs > 0 and len(valid_pairs) > max_pairs: + rng = np.random.default_rng(seed) + pick = rng.choice(len(valid_pairs), size=max_pairs, replace=False) + valid_pairs = [valid_pairs[int(i)] for i in pick] + + X = adata.X + is_sparse = sparse.issparse(X) + if is_sparse: + X = X.tocsc() + else: + X = np.asarray(X) + + vals: list[float] = [] + for i, j in valid_pairs: + if is_sparse: + a = np.asarray(X.getcol(i).toarray()).ravel() + b = np.asarray(X.getcol(j).toarray()).ravel() + else: + a = np.asarray(X[:, i]).ravel() + b = np.asarray(X[:, j]).ravel() + + if soft: + den = float(np.maximum(a, b).sum()) + if den <= 0: + continue + num = float(np.minimum(a, b).sum()) + vals.append(num / den) + else: + a_bin = a > 0 + b_bin = b > 0 + either = float((a_bin | b_bin).sum()) + if either <= 0: + continue + both = float((a_bin & b_bin).sum()) + vals.append(both / either) + + if len(vals) == 0: + return {"mecr_fast": float("nan"), "mecr_ci95_fast": float("nan"), "mecr_pairs_used": 0} + + arr = np.asarray(vals, dtype=np.float64) + ci95 = float("nan") + if arr.size > 1: + se = float(np.std(arr, ddof=1)) / np.sqrt(float(arr.size)) + ci95 = float(1.96 * se) + + return { + "mecr_fast": float(np.mean(arr)), + "mecr_ci95_fast": float(ci95), + "mecr_pairs_used": int(len(vals)), + } From 960fbf2558cc599a8054c009e2415970a62d21bf Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Tue, 5 May 2026 19:52:40 +0200 Subject: [PATCH 12/20] Fix regex construction for pattern filtering --- src/segger/io/filtering.py | 79 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 src/segger/io/filtering.py diff --git a/src/segger/io/filtering.py b/src/segger/io/filtering.py new file mode 100644 index 0000000..abb796d --- /dev/null +++ b/src/segger/io/filtering.py @@ -0,0 +1,79 @@ +"""Shared transcript filtering utilities for I/O readers.""" + +from __future__ import annotations + +import re +from typing import Collection, Sequence + +import polars as pl + +from .fields import CosMxTranscriptFields, MerscopeTranscriptFields, XeniumTranscriptFields + + +_PLATFORM_ALIASES: dict[str, str] = { + "10x_xenium": "xenium", + "nanostring_cosmx": "cosmx", + "vizgen_merscope": "merscope", +} + + +def normalize_platform_name(platform: str | None) -> str | None: + """Normalize platform aliases to canonical names.""" + if platform is None: + return None + lowered = str(platform).strip().lower() + return _PLATFORM_ALIASES.get(lowered, lowered) + + +def infer_platform_from_columns(columns: Collection[str]) -> str | None: + """Infer source platform from transcript table columns.""" + cols = set(columns) + + # CosMx marker columns are highly specific. + if "CellComp" in cols or {"x_global_px", "y_global_px"}.issubset(cols): + return "cosmx" + + # Xenium marker columns. 
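+    # "qv" (the per-transcript quality score) and "overlaps_nucleus" are
+    # specific to Xenium transcript tables, so either one identifies Xenium.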
+ if "overlaps_nucleus" in cols or "qv" in cols: + return "xenium" + if {"x_location", "y_location", "feature_name"}.issubset(cols): + return "xenium" + + # MERSCOPE marker columns. + if {"global_x", "global_y"}.issubset(cols): + return "merscope" + + return None + + +def platform_feature_filter_patterns(platform: str | None) -> list[str]: + """Return feature-name control patterns for the given platform.""" + normalized = normalize_platform_name(platform) + if normalized == "xenium": + return list(XeniumTranscriptFields.filter_substrings) + if normalized == "cosmx": + return list(CosMxTranscriptFields.filter_substrings) + if normalized == "merscope": + return list(MerscopeTranscriptFields.filter_substrings) + return [] + +def glob_patterns_to_regex(patterns: Sequence[str]) -> str: + """Convert glob-like patterns (`*`) to a regex union.""" + regexes = [] + for pattern in patterns: + regex_pattern = re.escape(pattern).replace(r"\*", ".*") + regexes.append(f"^{regex_pattern}$") + return "|".join(regexes) + + +def apply_feature_filters( + lf: pl.LazyFrame, + feature_column: str, + patterns: Sequence[str], +) -> pl.LazyFrame: + """Drop rows whose feature names match control/blank patterns.""" + if not patterns: + return lf + pattern_regex = glob_patterns_to_regex(patterns) + feature_expr = pl.col(feature_column).cast(pl.String, strict=False) + return lf.filter(feature_expr.str.contains(pattern_regex).fill_null(False).not_()) From 93d29b3cc784b624a556856a99a03edcfd780323 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Wed, 6 May 2026 17:38:52 +0200 Subject: [PATCH 13/20] Improve code robustness on dataframe joins --- src/segger/data/utils/anndata.py | 2 +- src/segger/data/utils/heterodata.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/segger/data/utils/anndata.py b/src/segger/data/utils/anndata.py index 293fccc..f2ba8dc 100644 --- a/src/segger/data/utils/anndata.py +++ b/src/segger/data/utils/anndata.py @@ -159,7 +159,7 @@ def setup_anndata( ad.obs .join( ( - boundaries + boundaries.drop_duplicates(subset=bd_fields.id) # some data oddly has duplicate boundary entries on the same cell id .reset_index(names=bd_fields.index) .set_index(bd_fields.id, verify_integrity=True) .get(bd_fields.index) diff --git a/src/segger/data/utils/heterodata.py b/src/segger/data/utils/heterodata.py index fd3bea1..9254609 100644 --- a/src/segger/data/utils/heterodata.py +++ b/src/segger/data/utils/heterodata.py @@ -45,6 +45,11 @@ def setup_heterodata( tx_fields.cell_cluster, tx_fields.gene_cluster, ] + + transcripts = transcripts.with_columns( + pl.col(tx_fields.feature).cast(pl.Utf8) + ) + # Update transcripts with fields for training transcripts = ( @@ -56,9 +61,14 @@ def setup_heterodata( pl.from_pandas( adata.var[[genes_encoding_column, genes_clusters_column]], include_index=True - ), + ).rename({ + pl.from_pandas( + adata.var[[genes_encoding_column, genes_clusters_column]], + include_index=True + ).columns[0]: tx_fields.feature + }), left_on=tx_fields.feature, - right_on=adata.var.index.name if adata.var.index.name else 'None', + right_on=tx_fields.feature, ) .rename( { From bdec1086b2f06be6a58b3fc3392db037554e13cc Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Wed, 6 May 2026 17:53:17 +0200 Subject: [PATCH 14/20] Clean code for convex hull and delaunay boundaries on cells (no fragments) --- src/segger/cli/segment.py | 10 ++++++- src/segger/data/writer.py | 6 ++-- src/segger/export/spatialdata_writer.py | 40 +++++++------------------ src/segger/io/spatialdata_loader.py | 
2 -- 4 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/segger/cli/segment.py b/src/segger/cli/segment.py index 6dc1fa0..6a1e0aa 100644 --- a/src/segger/cli/segment.py +++ b/src/segger/cli/segment.py @@ -311,8 +311,15 @@ def segment( save_spatialdata: Annotated[bool, registry.get_parameter( "save_spatialdata", - group=group_io, + group=group_io, # might change )] = registry.get_default("save_spatialdata"), + + boundary_method: Annotated[ + Literal["convex_hull", "delaunay", "skip"], + registry.get_parameter( + "boundary_method", + group=group_io, # might change + )] = registry.get_default("boundary_method"), debug: Annotated[bool, Parameter( help="Whether to save additional debug information (trainer, predictions).", @@ -415,6 +422,7 @@ def segment( output_directory, save_anndata=save_anndata, save_spatialdata=save_spatialdata, + boundary_method=boundary_method, debug=debug, ) trainer = Trainer( diff --git a/src/segger/data/writer.py b/src/segger/data/writer.py index 3e57526..ae22d34 100644 --- a/src/segger/data/writer.py +++ b/src/segger/data/writer.py @@ -28,6 +28,7 @@ def __init__( output_directory: Path, save_anndata: bool = True, save_spatialdata: bool = True, + boundary_method: str = "convex_hull", debug: bool = False ): # "write" callback at the end of prediction epoch @@ -36,6 +37,7 @@ def __init__( self.output_directory = Path(output_directory) self.save_anndata = save_anndata self.save_spatialdata = save_spatialdata + self.boundary_method = boundary_method self.segger_logger = logging.getLogger(__name__) # setup debugging @@ -132,8 +134,8 @@ def write_anndata( if self.save_spatialdata: writer = SpatialDataWriter( include_boundaries="True", - boundary_method='convex_hull', - # boundary_n_jobs=max(num_workers, 1), + boundary_method=self.boundary_method, + boundary_n_jobs=4, ) tx, _ = _resolve_transcripts_and_boundaries(self.input_directory) output_path = writer.write( diff --git a/src/segger/export/spatialdata_writer.py b/src/segger/export/spatialdata_writer.py index f94bb50..a4b47bb 100644 --- a/src/segger/export/spatialdata_writer.py +++ b/src/segger/export/spatialdata_writer.py @@ -78,7 +78,7 @@ class SpatialDataWriter: def __init__( self, include_boundaries: bool = True, - boundary_method: Literal["input", "convex_hull", "delaunay", "skip"] = "input", + boundary_method: Literal["input", "convex_hull", "delaunay", "skip"] = "convex_hull", boundary_n_jobs: int = 1, points_key: str = "transcripts", shapes_key: str = "cells", @@ -112,7 +112,7 @@ def write( x_column: str = "x", y_column: str = "y", z_column: Optional[str] = "z", - overwrite: bool = False, + overwrite: bool = True, **kwargs, ) -> Path: """Write segmentation results to SpatialData Zarr store. 
@@ -320,34 +320,17 @@
         def _parse_shapes(shapes):
             kwargs = {"transformations": transformations} if transformations is not None else {}
             return ShapesModel.parse(shapes, **kwargs)
 
-        shapes_elements = {}
-
-        if self.include_boundaries and self.boundary_method != "skip":
-            if self.boundary_method == "input":
-                bd_types = {"cell": "cells", "nucleus": "nuclei"}
-                for k, v in bd_types.items():
-                    shapes = self._get_input_boundaries(
-                        cell_tx_pd,
-                        cell_id_column,
-                        boundaries,
-                        k)
-                    shapes = _ensure_cell_id(shapes)
-                    parsed = _parse_shapes(shapes)
-                    if parsed is not None:
-                        shapes_elements[v] = parsed
-            else:
-                shape_specs = [(self.shapes_key, cell_tx_pd)]
-                if has_fragments and fragment_tx_pd is not None:
-                    shape_specs.append((self.fragment_shapes_key, fragment_tx_pd))
-
-                for shape_key, shape_tx_pd in shape_specs:
-                    shapes = self._get_generated_boundaries(shape_tx_pd, x_column, y_column, cell_id_column)
-                    shapes = _ensure_cell_id(shapes)
-                    parsed = _parse_shapes(shapes)
-                    if parsed is not None:
-                        shapes_elements[shape_key] = parsed
-
+
+        shapes_elements = {}
+        shape_specs = [(self.shapes_key, tx_pd)]
+
+        for shape_key, shape_tx_pd in shape_specs:
+            shapes = self._get_generated_boundaries(shape_tx_pd, x_column, y_column, cell_id_column)
+            shapes = _ensure_cell_id(shapes)
+            parsed = _parse_shapes(shapes)
+            if parsed is not None:
+                shapes_elements[shape_key] = parsed
 
         # Optional AnnData table
         tables_elements = {}
diff --git a/src/segger/io/spatialdata_loader.py b/src/segger/io/spatialdata_loader.py
index 9d27f5c..895716f 100644
--- a/src/segger/io/spatialdata_loader.py
+++ b/src/segger/io/spatialdata_loader.py
@@ -61,8 +61,6 @@ def _safe_to_geodataframe(data: object) -> gpd.GeoDataFrame:
         return data.copy()
     if hasattr(data, "compute"):
         data = data.compute()
-    if isinstance(data, gpd.GeoDataFrame):
-        return data.copy()
     if hasattr(data, "to_geopandas"):
         return data.to_geopandas().copy()
     if hasattr(data, "to_pandas"):

From c26dd83b2d32a4a923e56e991e66a401ce1d830c Mon Sep 17 00:00:00 2001
From: enric-bazz
Date: Wed, 6 May 2026 18:13:00 +0200
Subject: [PATCH 15/20] Clean export module to minimal APIs for spatialdata writing

---
 src/segger/export/__init__.py | 144 ---------
 src/segger/export/adapter.py | 165 ----------
 src/segger/export/anndata_writer.py | 250 ---------------
 src/segger/export/merged_writer.py | 317 ------------------
 src/segger/export/minimal_apis.py | 410 ++++++++++++++++++++++++
 src/segger/export/output_formats.py | 309 ------------------
 src/segger/export/sopa_compat.py | 396 -----------------------
 src/segger/export/spatialdata_writer.py | 3 +-
 8 files changed, 411 insertions(+), 1583 deletions(-)
 delete mode 100644 src/segger/export/adapter.py
 delete mode 100644 src/segger/export/anndata_writer.py
 delete mode 100644 src/segger/export/merged_writer.py
 create mode 100644 src/segger/export/minimal_apis.py
 delete mode 100644 src/segger/export/output_formats.py
 delete mode 100644 src/segger/export/sopa_compat.py

diff --git a/src/segger/export/__init__.py b/src/segger/export/__init__.py
index 0df59d6..e69de29 100644
--- a/src/segger/export/__init__.py
+++ b/src/segger/export/__init__.py
@@ -1,144 +0,0 @@
-"""Export module for segmentation results.
- -This module provides functionality to export segmentation results to various formats: -- Xenium Explorer format for visualization and validation -- Merged transcripts (original data with segmentation results) -- SpatialData Zarr format for scverse ecosystem -- SOPA-compatible format for spatial omics workflows -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING -import importlib - -__all__ = [ - # Existing exports - "BoundaryIdentification", - "generate_boundary", - "generate_boundaries", - "seg2explorer", - "seg2explorer_pqdm", - "predictions_to_dataframe", - # Output formats - "OutputFormat", - "OutputWriter", - "get_writer", - "register_writer", - "write_all_formats", - # Writers - "MergedTranscriptsWriter", - "SeggerRawWriter", - "AnnDataWriter", - "merge_predictions_with_transcripts", - # SpatialData (optional) - "SpatialDataWriter", - "write_spatialdata", - # SOPA (optional) - "validate_sopa_compatibility", - "export_for_sopa", - "sopa_to_segger_input", - "check_sopa_installation", -] - -if TYPE_CHECKING: # pragma: no cover - from .boundary import BoundaryIdentification, generate_boundary, generate_boundaries - from .xenium import seg2explorer, seg2explorer_pqdm - from .adapter import predictions_to_dataframe - from .output_formats import ( - OutputFormat, - OutputWriter, - get_writer, - register_writer, - write_all_formats, - ) - from .merged_writer import ( - MergedTranscriptsWriter, - SeggerRawWriter, - merge_predictions_with_transcripts, - ) - from .anndata_writer import AnnDataWriter - from .spatialdata_writer import SpatialDataWriter, write_spatialdata - from .sopa_compat import ( - validate_sopa_compatibility, - export_for_sopa, - sopa_to_segger_input, - check_sopa_installation, - ) - - -def __getattr__(name: str): - if name in {"BoundaryIdentification", "generate_boundary", "generate_boundaries"}: - from .boundary import BoundaryIdentification, generate_boundary, generate_boundaries - return locals()[name] - if name in {"seg2explorer", "seg2explorer_pqdm"}: - from .xenium import seg2explorer, seg2explorer_pqdm - return locals()[name] - if name == "predictions_to_dataframe": - from .adapter import predictions_to_dataframe - return predictions_to_dataframe - if name in { - "OutputFormat", - "OutputWriter", - "get_writer", - "register_writer", - "write_all_formats", - }: - from .output_formats import ( - OutputFormat, - OutputWriter, - get_writer, - register_writer, - write_all_formats, - ) - return locals()[name] - if name in { - "MergedTranscriptsWriter", - "SeggerRawWriter", - "AnnDataWriter", - "merge_predictions_with_transcripts", - }: - from .merged_writer import ( - MergedTranscriptsWriter, - SeggerRawWriter, - merge_predictions_with_transcripts, - ) - if name == "AnnDataWriter": - from .anndata_writer import AnnDataWriter - return locals()[name] - if name in {"SpatialDataWriter", "write_spatialdata"}: - try: - from .spatialdata_writer import SpatialDataWriter, write_spatialdata - except Exception: - return None - return locals()[name] - if name in { - "validate_sopa_compatibility", - "export_for_sopa", - "sopa_to_segger_input", - "check_sopa_installation", - }: - try: - from .sopa_compat import ( - validate_sopa_compatibility, - export_for_sopa, - sopa_to_segger_input, - check_sopa_installation, - ) - except Exception: - return None - return locals()[name] - if name in { - "boundary", - "xenium", - "adapter", - "output_formats", - "merged_writer", - "spatialdata_writer", - "sopa_compat", - }: - try: - return 
importlib.import_module(f"{__name__}.{name}") - except Exception as exc: - raise ImportError(f"Failed to import optional module '{name}'.") from exc - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/segger/export/adapter.py b/src/segger/export/adapter.py deleted file mode 100644 index 541daa9..0000000 --- a/src/segger/export/adapter.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Adapter to convert model predictions to export-compatible format. - -This module bridges the gap between LitISTEncoder.predict_step() output -and the seg2explorer functions for Xenium Explorer export. -""" - -from typing import Optional, Union -import pandas as pd -import polars as pl -import torch - - -def predictions_to_dataframe( - src_idx: torch.Tensor, - seg_idx: torch.Tensor, - max_sim: torch.Tensor, - gen_idx: torch.Tensor, - transcript_data: Union[pd.DataFrame, pl.DataFrame], - min_similarity: float = 0.5, - x_column: str = "x", - y_column: str = "y", - gene_column: str = "feature_name", -) -> pd.DataFrame: - """Convert prediction tensors to seg2explorer-compatible DataFrame. - - This function takes the output from LitISTEncoder.predict_step() and - combines it with the original transcript data to create a DataFrame - suitable for Xenium Explorer export. - - Parameters - ---------- - src_idx : torch.Tensor - Transcript indices from prediction, shape (N,). - seg_idx : torch.Tensor - Assigned boundary/cell indices, shape (N,). Value of -1 indicates - unassigned transcripts. - max_sim : torch.Tensor - Maximum similarity scores, shape (N,). - gen_idx : torch.Tensor - Gene indices for each transcript, shape (N,). - transcript_data : Union[pd.DataFrame, pl.DataFrame] - Original transcript DataFrame with coordinates. - min_similarity : float - Minimum similarity threshold for valid assignments. - x_column : str - Column name for x coordinates. - y_column : str - Column name for y coordinates. - gene_column : str - Column name for gene/feature names. 
- - Returns - ------- - pd.DataFrame - DataFrame with columns: - - row_index: Original transcript index - - x: X coordinate - - y: Y coordinate - - seg_cell_id: Assigned cell ID (or -1 if unassigned) - - similarity: Assignment confidence score - - feature_name: Gene name - """ - # Convert to numpy - src_idx_np = src_idx.cpu().numpy() - seg_idx_np = seg_idx.cpu().numpy() - max_sim_np = max_sim.cpu().numpy() - - # Filter by similarity threshold - valid_mask = (seg_idx_np >= 0) & (max_sim_np >= min_similarity) - - # Convert Polars to pandas if needed - if isinstance(transcript_data, pl.DataFrame): - transcript_data = transcript_data.to_pandas() - - # Build result DataFrame - result = pd.DataFrame({ - "row_index": src_idx_np, - "seg_cell_id": seg_idx_np, - "similarity": max_sim_np, - }) - - # Mark low-similarity assignments as unassigned - result.loc[~valid_mask, "seg_cell_id"] = -1 - - # Merge with original transcript data for coordinates - if "row_index" in transcript_data.columns: - # Use existing row_index - result = result.merge( - transcript_data[["row_index", x_column, y_column, gene_column]], - on="row_index", - how="left", - ) - else: - # Use index as row_index - transcript_data = transcript_data.reset_index() - transcript_data = transcript_data.rename(columns={"index": "row_index"}) - result = result.merge( - transcript_data[["row_index", x_column, y_column, gene_column]], - on="row_index", - how="left", - ) - - # Rename columns for consistency - result = result.rename(columns={ - gene_column: "feature_name", - x_column: "x", - y_column: "y", - }) - - return result - - -def collect_predictions( - predictions: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Collect predictions from multiple batches. - - Parameters - ---------- - predictions : list[tuple] - List of (src_idx, seg_idx, max_sim, gen_idx) tuples from predict_step. - - Returns - ------- - tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - Concatenated (src_idx, seg_idx, max_sim, gen_idx) tensors. - """ - src_indices = [] - seg_indices = [] - similarities = [] - gene_indices = [] - - for src_idx, seg_idx, max_sim, gen_idx in predictions: - src_indices.append(src_idx) - seg_indices.append(seg_idx) - similarities.append(max_sim) - gene_indices.append(gen_idx) - - return ( - torch.cat(src_indices), - torch.cat(seg_indices), - torch.cat(similarities), - torch.cat(gene_indices), - ) - - -def filter_assigned_transcripts( - seg_df: pd.DataFrame, - cell_id_column: str = "seg_cell_id", -) -> pd.DataFrame: - """Filter DataFrame to only include assigned transcripts. - - Parameters - ---------- - seg_df : pd.DataFrame - Segmentation result DataFrame. - cell_id_column : str - Column name for cell IDs. - - Returns - ------- - pd.DataFrame - DataFrame with only assigned transcripts. - """ - return seg_df[seg_df[cell_id_column] >= 0].copy() diff --git a/src/segger/export/anndata_writer.py b/src/segger/export/anndata_writer.py deleted file mode 100644 index 716c2cc..0000000 --- a/src/segger/export/anndata_writer.py +++ /dev/null @@ -1,250 +0,0 @@ -"""Write segmentation results as AnnData (.h5ad). - -This writer builds a cell x gene count matrix from transcript assignments -and saves it as an AnnData object. The output can also be embedded as a -table in SpatialData. 
-""" - -from __future__ import annotations - -from pathlib import Path -from typing import Optional, Union - -import numpy as np -import pandas as pd -import polars as pl -from anndata import AnnData -from scipy import sparse as sp - -from segger.export.output_formats import OutputFormat, register_writer -from segger.export.merged_writer import merge_predictions_with_transcripts - - -def build_anndata_table( - transcripts: pl.DataFrame, - cell_id_column: str = "segger_cell_id", - feature_column: str = "feature_name", - x_column: Optional[str] = "x", - y_column: Optional[str] = "y", - z_column: Optional[str] = "z", - unassigned_value: Union[int, str, None] = -1, - region: Optional[str] = None, - region_key: Optional[str] = None, - obs_index_as_str: bool = False, -) -> AnnData: - """Build AnnData from assigned transcripts. - - Parameters - ---------- - transcripts - Transcript DataFrame with segmentation assignments. - cell_id_column - Column with assigned cell IDs. - feature_column - Column with gene/feature names. - x_column, y_column, z_column - Coordinate columns (optional). If present, centroids are stored in - ``obsm["X_spatial"]``. - unassigned_value - Marker for unassigned transcripts (filtered out). - region, region_key - SpatialData table linkage metadata. - obs_index_as_str - If True, cast cell IDs to string for ``obs`` index. - """ - if cell_id_column not in transcripts.columns: - raise ValueError(f"Missing cell_id column: {cell_id_column}") - if feature_column not in transcripts.columns: - raise ValueError(f"Missing feature column: {feature_column}") - - assigned = transcripts.filter(pl.col(cell_id_column).is_not_null()) - if unassigned_value is not None: - col_dtype = transcripts.schema.get(cell_id_column) - try: - compare_value = pl.Series([unassigned_value]).cast(col_dtype).item() - filter_expr = pl.col(cell_id_column) != compare_value - except Exception: - filter_expr = ( - pl.col(cell_id_column).cast(pl.Utf8) != str(unassigned_value) - ) - assigned = assigned.filter(filter_expr) - - # Gene list from all transcripts (even if no assignments) - var_idx = ( - transcripts - .select(feature_column) - .unique() - .sort(feature_column) - .get_column(feature_column) - .to_list() - ) - - if assigned.height == 0: - obs_index = pd.Index([], name=cell_id_column) - if obs_index_as_str: - var_index = pd.Index([str(v) for v in var_idx], name=feature_column) - else: - var_index = pd.Index(var_idx, name=feature_column) - X = sp.csr_matrix((0, len(var_index))) - adata = AnnData(X=X, obs=pd.DataFrame(index=obs_index), var=pd.DataFrame(index=var_index)) - if region is not None: - adata.obs["region"] = region - if region_key is not None: - adata.obs["region_key"] = region_key - return adata - - feature_idx = ( - assigned - .select(feature_column) - .unique() - .sort(feature_column) - .with_row_index(name="_fid") - ) - cell_idx = ( - assigned - .select(cell_id_column) - .unique() - .sort(cell_id_column) - .with_row_index(name="_cid") - ) - - mapped = ( - assigned - .join(feature_idx, on=feature_column) - .join(cell_idx, on=cell_id_column) - ) - counts = ( - mapped - .group_by(["_cid", "_fid"]) - .agg(pl.len().alias("_count")) - ) - ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T - rows = ijv[0].astype(np.int64, copy=False) - cols = ijv[1].astype(np.int64, copy=False) - data = ijv[2].astype(np.int64, copy=False) - - n_cells = cell_idx.height - n_genes = feature_idx.height - X = sp.coo_matrix((data, (rows, cols)), shape=(n_cells, n_genes)).tocsr() - - obs_ids = 
cell_idx.get_column(cell_id_column).to_list() - var_ids = feature_idx.get_column(feature_column).to_list() - if obs_index_as_str: - obs_ids = [str(v) for v in obs_ids] - var_ids = [str(v) for v in var_ids] - - adata = AnnData( - X=X, - obs=pd.DataFrame(index=pd.Index(obs_ids, name=cell_id_column)), - var=pd.DataFrame(index=pd.Index(var_ids, name=feature_column)), - ) - - # Add centroid coordinates if present - if x_column in assigned.columns and y_column in assigned.columns: - coords_cols = [x_column, y_column] - if z_column and z_column in assigned.columns: - coords_cols.append(z_column) - centroids = ( - assigned - .group_by(cell_id_column) - .agg([pl.col(c).mean().alias(c) for c in coords_cols]) - ) - centroids_pd = ( - centroids - .to_pandas() - .set_index(cell_id_column) - .reindex(adata.obs.index) - ) - adata.obsm["X_spatial"] = centroids_pd[coords_cols].to_numpy() - - if region is not None: - adata.obs["region"] = region - if region_key is not None: - adata.obs["region_key"] = region_key - - return adata - - -@register_writer(OutputFormat.ANNDATA) -class AnnDataWriter: - """Write segmentation results as AnnData (.h5ad).""" - - def __init__( - self, - unassigned_marker: Union[int, str, None] = -1, - compression: Optional[str] = "gzip", - compression_opts: Optional[int] = 4, - ): - self.unassigned_marker = unassigned_marker - self.compression = compression - self.compression_opts = compression_opts - - def write( - self, - predictions: pl.DataFrame, - output_dir: Path, - transcripts: Optional[pl.DataFrame] = None, - output_name: str = "segger_segmentation.h5ad", - row_index_column: str = "row_index", - cell_id_column: str = "segger_cell_id", - similarity_column: str = "segger_similarity", - feature_column: str = "feature_name", - x_column: Optional[str] = "x", - y_column: Optional[str] = "y", - z_column: Optional[str] = "z", - overwrite: bool = False, - **kwargs, - ) -> Path: - """Write segmentation results to AnnData (.h5ad). - - Parameters - ---------- - predictions - Segmentation predictions. - output_dir - Output directory. - transcripts - Original transcripts DataFrame (required). - output_name - Output filename. Default "segger_segmentation.h5ad". - """ - if transcripts is None: - raise ValueError("AnnData output requires transcripts DataFrame.") - - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - output_path = output_dir / output_name - - if output_path.exists() and not overwrite: - raise FileExistsError( - f"Output path exists: {output_path}. " - "Use overwrite=True to replace." 
- ) - - merged = merge_predictions_with_transcripts( - predictions=predictions, - transcripts=transcripts, - row_index_column=row_index_column, - cell_id_column=cell_id_column, - similarity_column=similarity_column, - unassigned_marker=self.unassigned_marker, - ) - - adata = build_anndata_table( - transcripts=merged, - cell_id_column=cell_id_column, - feature_column=feature_column, - x_column=x_column, - y_column=y_column, - z_column=z_column, - unassigned_value=self.unassigned_marker, - ) - - write_kwargs = {} - if self.compression is not None: - write_kwargs["compression"] = self.compression - if self.compression_opts is not None: - write_kwargs["compression_opts"] = self.compression_opts - - adata.write_h5ad(output_path, **write_kwargs) - return output_path diff --git a/src/segger/export/merged_writer.py b/src/segger/export/merged_writer.py deleted file mode 100644 index eb687df..0000000 --- a/src/segger/export/merged_writer.py +++ /dev/null @@ -1,317 +0,0 @@ -"""Write segmentation results merged back to original transcripts. - -This writer joins segmentation predictions with the original transcript data, -producing a single output file that contains all original columns plus -the segmentation results (segger_cell_id, segger_similarity). - -Usage ------ ->>> from segger.export.merged_writer import MergedTranscriptsWriter ->>> writer = MergedTranscriptsWriter( -... original_transcripts_path=Path("data/transcripts.parquet") -... ) ->>> output_path = writer.write(predictions, Path("output/")) - -The output file contains: -- All original transcript columns -- segger_cell_id: Assigned cell ID (-1 for unassigned) -- segger_similarity: Assignment confidence score (0.0 for unassigned) -""" - -from __future__ import annotations - -from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional, Union - -import polars as pl - -from segger.export.output_formats import OutputFormat, register_writer - -if TYPE_CHECKING: - pass - - -@register_writer(OutputFormat.SEGGER_RAW) -class SeggerRawWriter: - """Write raw Segger prediction output (default format). - - This writer outputs just the predictions DataFrame without merging - with original transcripts. This is the default Segger output format. - - Output columns: - - row_index: Original transcript row index - - segger_cell_id: Assigned cell ID - - segger_similarity: Assignment confidence score - """ - - def __init__( - self, - compression: Literal["snappy", "gzip", "lz4", "zstd", "none"] = "snappy", - ): - """Initialize the raw writer. - - Parameters - ---------- - compression - Parquet compression algorithm. Default is 'snappy'. - """ - self.compression = compression if compression != "none" else None - - def write( - self, - predictions: pl.DataFrame, - output_dir: Path, - output_name: str = "predictions.parquet", - **kwargs, - ) -> Path: - """Write predictions to Parquet file. - - Parameters - ---------- - predictions - DataFrame with segmentation predictions. - output_dir - Output directory. - output_name - Output filename. Default is 'predictions.parquet'. - - Returns - ------- - Path - Path to the written Parquet file. - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - output_path = output_dir / output_name - predictions.write_parquet(output_path, compression=self.compression) - - return output_path - - -@register_writer(OutputFormat.MERGED_TRANSCRIPTS) -class MergedTranscriptsWriter: - """Write segmentation results merged with original transcripts. 
- - This writer joins predictions with original transcript data, producing - a complete output file with all original columns plus segmentation results. - - Output columns: - - All original transcript columns - - segger_cell_id: Assigned cell ID (configurable marker for unassigned) - - segger_similarity: Assignment confidence score - - Parameters - ---------- - original_transcripts_path - Path to the original transcripts file (Parquet or CSV). - If not provided, must be passed to write() via kwargs. - unassigned_marker - Value to use for unassigned transcripts. Default is -1. - Can be int, str, or None. - include_similarity - Whether to include the similarity score column. Default True. - compression - Parquet compression algorithm. Default is 'snappy'. - """ - - def __init__( - self, - original_transcripts_path: Optional[Path] = None, - unassigned_marker: Union[int, str, None] = -1, - include_similarity: bool = True, - compression: Literal["snappy", "gzip", "lz4", "zstd", "none"] = "snappy", - ): - self.original_transcripts_path = ( - Path(original_transcripts_path) if original_transcripts_path else None - ) - self.unassigned_marker = unassigned_marker - self.include_similarity = include_similarity - self.compression = compression if compression != "none" else None - - def write( - self, - predictions: pl.DataFrame, - output_dir: Path, - output_name: str = "transcripts_segmented.parquet", - transcripts: Optional[pl.DataFrame] = None, - original_transcripts_path: Optional[Path] = None, - row_index_column: str = "row_index", - cell_id_column: str = "segger_cell_id", - similarity_column: str = "segger_similarity", - **kwargs, - ) -> Path: - """Merge predictions with original transcripts and write to file. - - Parameters - ---------- - predictions - DataFrame with segmentation predictions. Must contain: - - row_index: Original transcript row index - - segger_cell_id: Assigned cell ID - - segger_similarity: Assignment confidence score (optional) - output_dir - Output directory. - output_name - Output filename. Default is 'transcripts_segmented.parquet'. - transcripts - Original transcripts DataFrame. If provided, used instead of - loading from original_transcripts_path. - original_transcripts_path - Path to original transcripts. Overrides constructor parameter. - row_index_column - Column name for row index in predictions. Default 'row_index'. - cell_id_column - Column name for cell ID in predictions. Default 'segger_cell_id'. - similarity_column - Column name for similarity in predictions. Default 'segger_similarity'. - - Returns - ------- - Path - Path to the written Parquet file. - - Raises - ------ - ValueError - If no transcripts source is provided. - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - # Get original transcripts - if transcripts is not None: - original = transcripts - else: - path = original_transcripts_path or self.original_transcripts_path - if path is None: - raise ValueError( - "No original transcripts provided. Either pass 'transcripts' " - "DataFrame or specify 'original_transcripts_path'." 
- ) - original = self._load_transcripts(path) - - # Prepare predictions for join - pred_cols = [row_index_column, cell_id_column] - if self.include_similarity and similarity_column in predictions.columns: - pred_cols.append(similarity_column) - - pred_subset = predictions.select(pred_cols) - - # Handle missing row_index in original (add if needed) - if row_index_column not in original.columns: - original = original.with_row_index(name=row_index_column) - - # Join predictions with original transcripts - merged = original.join( - pred_subset, - on=row_index_column, - how="left", - ) - - # Fill unassigned values - if self.unassigned_marker is not None: - merged = merged.with_columns( - pl.col(cell_id_column).fill_null(self.unassigned_marker) - ) - if self.include_similarity and similarity_column in merged.columns: - merged = merged.with_columns( - pl.col(similarity_column).fill_null(0.0) - ) - - # Write output - output_path = output_dir / output_name - merged.write_parquet(output_path, compression=self.compression) - - return output_path - - def _load_transcripts(self, path: Path) -> pl.DataFrame: - """Load transcripts from file. - - Parameters - ---------- - path - Path to transcripts file (Parquet or CSV). - - Returns - ------- - pl.DataFrame - Loaded transcripts. - """ - path = Path(path) - suffix = path.suffix.lower() - - if suffix == ".parquet": - return pl.read_parquet(path) - elif suffix in (".csv", ".tsv"): - separator = "\t" if suffix == ".tsv" else "," - return pl.read_csv(path, separator=separator) - else: - # Try Parquet first, then CSV - try: - return pl.read_parquet(path) - except Exception: - return pl.read_csv(path) - - -def merge_predictions_with_transcripts( - predictions: pl.DataFrame, - transcripts: pl.DataFrame, - row_index_column: str = "row_index", - cell_id_column: str = "segger_cell_id", - similarity_column: str = "segger_similarity", - unassigned_marker: Union[int, str, None] = -1, -) -> pl.DataFrame: - """Merge predictions with transcripts (functional interface). - - Parameters - ---------- - predictions - DataFrame with segmentation predictions. - transcripts - Original transcripts DataFrame. - row_index_column - Column name for row index. - cell_id_column - Column name for cell ID in predictions. - similarity_column - Column name for similarity in predictions. - unassigned_marker - Value for unassigned transcripts. - - Returns - ------- - pl.DataFrame - Merged DataFrame with all original columns plus predictions. 
- - Examples - -------- - >>> merged = merge_predictions_with_transcripts(predictions, transcripts) - >>> print(merged.columns) - ['row_index', 'x', 'y', 'feature_name', 'segger_cell_id', 'segger_similarity'] - """ - # Prepare predictions - pred_cols = [row_index_column, cell_id_column] - if similarity_column in predictions.columns: - pred_cols.append(similarity_column) - - pred_subset = predictions.select(pred_cols) - - # Add row_index if missing - if row_index_column not in transcripts.columns: - transcripts = transcripts.with_row_index(name=row_index_column) - - # Join - merged = transcripts.join(pred_subset, on=row_index_column, how="left") - - # Fill unassigned - if unassigned_marker is not None: - merged = merged.with_columns( - pl.col(cell_id_column).fill_null(unassigned_marker) - ) - if similarity_column in merged.columns: - merged = merged.with_columns( - pl.col(similarity_column).fill_null(0.0) - ) - - return merged diff --git a/src/segger/export/minimal_apis.py b/src/segger/export/minimal_apis.py new file mode 100644 index 0000000..4f0ddee --- /dev/null +++ b/src/segger/export/minimal_apis.py @@ -0,0 +1,410 @@ + +from __future__ import annotations + +from enum import Enum +from pathlib import Path +from typing import Optional, Union, Any, Protocol, runtime_checkable + +from typing import Optional, Union + +import numpy as np +import pandas as pd +import polars as pl +from anndata import AnnData +from scipy import sparse as sp + + +class OutputFormat(str, Enum): + """Available output formats for segmentation results. + + Attributes + ---------- + SEGGER_RAW : str + Default Segger output format. Writes predictions as Parquet file + with columns: row_index, segger_cell_id, segger_similarity. + + MERGED_TRANSCRIPTS : str + Merged transcripts format. Original transcript data with segmentation + results joined (segger_cell_id, segger_similarity columns added). + + SPATIALDATA : str + SpatialData Zarr format. Creates a .zarr store compatible with + the scverse ecosystem, containing transcripts and optional boundaries. + + ANNDATA : str + AnnData format. Creates a .h5ad file with a cell x gene matrix + derived from transcript assignments. + """ + + SEGGER_RAW = "segger_raw" + MERGED_TRANSCRIPTS = "merged" + SPATIALDATA = "spatialdata" + ANNDATA = "anndata" + + @classmethod + def from_string(cls, value: str) -> "OutputFormat": + """Parse OutputFormat from string, case-insensitive. + + Parameters + ---------- + value + Format name ('segger_raw', 'merged', 'spatialdata', 'anndata', or 'all'). + + Returns + ------- + OutputFormat + Corresponding enum value. + + Raises + ------ + ValueError + If value is not a valid format name. + """ + value_lower = value.lower().strip() + + # Handle aliases + aliases = { + "raw": cls.SEGGER_RAW, + "segger": cls.SEGGER_RAW, + "default": cls.SEGGER_RAW, + "merge": cls.MERGED_TRANSCRIPTS, + "merged": cls.MERGED_TRANSCRIPTS, + "transcripts": cls.MERGED_TRANSCRIPTS, + "sdata": cls.SPATIALDATA, + "zarr": cls.SPATIALDATA, + "h5ad": cls.ANNDATA, + "ann": cls.ANNDATA, + "anndata": cls.ANNDATA, + } + + if value_lower in aliases: + return aliases[value_lower] + + # Try direct match + for fmt in cls: + if fmt.value == value_lower: + return fmt + + valid = [f.value for f in cls] + list(aliases.keys()) + raise ValueError( + f"Unknown output format: '{value}'. " + f"Valid formats: {sorted(set(valid))}" + ) + + + +@runtime_checkable +class OutputWriter(Protocol): + """Protocol for output format writers. 
+ + Implementations must provide a `write` method that writes segmentation + results to the specified output directory. + """ + + def write( + self, + predictions: "pl.DataFrame", + output_dir: Path, + **kwargs: Any, + ) -> Path: + """Write segmentation results to output format. + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. Must contain: + - row_index: Original transcript row index + - segger_cell_id: Assigned cell ID (or -1/None for unassigned) + - segger_similarity: Assignment confidence score + + output_dir + Directory to write output files. + + **kwargs + Format-specific options (e.g., transcripts, boundaries). + + Returns + ------- + Path + Path to the primary output file/directory. + """ + ... + + +# Registry of output writers by format +_OUTPUT_WRITERS: dict[OutputFormat, type] = {} + +def register_writer(fmt: OutputFormat): + """Decorator to register an output writer class. + + Parameters + ---------- + fmt + Output format this writer handles. + + Returns + ------- + decorator + Class decorator that registers the writer. + + Examples + -------- + >>> @register_writer(OutputFormat.MERGED_TRANSCRIPTS) + ... class MergedTranscriptsWriter: + ... def write(self, predictions, output_dir, **kwargs): + ... ... + """ + def decorator(cls): + _OUTPUT_WRITERS[fmt] = cls + return cls + return decorator + + +def get_writer(fmt: OutputFormat | str, **init_kwargs: Any) -> OutputWriter: + """Get an output writer for the specified format. + + Parameters + ---------- + fmt + Output format (enum or string). + **init_kwargs + Keyword arguments passed to the writer constructor. + + Returns + ------- + OutputWriter + Writer instance for the specified format. + + Raises + ------ + ValueError + If format is not recognized or writer not registered. + + Examples + -------- + >>> writer = get_writer(OutputFormat.MERGED_TRANSCRIPTS, unassigned_marker=-1) + >>> writer.write(predictions, Path("output/")) + """ + if isinstance(fmt, str): + fmt = OutputFormat.from_string(fmt) + + if fmt not in _OUTPUT_WRITERS: + raise ValueError( + f"No writer registered for format: {fmt.value}. " + f"Available formats: {[f.value for f in _OUTPUT_WRITERS.keys()]}" + ) + + writer_cls = _OUTPUT_WRITERS[fmt] + return writer_cls(**init_kwargs) + + + +### ANNDATA EXPORT ### + +def build_anndata_table( + transcripts: pl.DataFrame, + cell_id_column: str = "segger_cell_id", + feature_column: str = "feature_name", + x_column: Optional[str] = "x", + y_column: Optional[str] = "y", + z_column: Optional[str] = "z", + unassigned_value: Union[int, str, None] = -1, + region: Optional[str] = None, + region_key: Optional[str] = None, + obs_index_as_str: bool = False, +) -> AnnData: + """Build AnnData from assigned transcripts. + + Parameters + ---------- + transcripts + Transcript DataFrame with segmentation assignments. + cell_id_column + Column with assigned cell IDs. + feature_column + Column with gene/feature names. + x_column, y_column, z_column + Coordinate columns (optional). If present, centroids are stored in + ``obsm["X_spatial"]``. + unassigned_value + Marker for unassigned transcripts (filtered out). + region, region_key + SpatialData table linkage metadata. + obs_index_as_str + If True, cast cell IDs to string for ``obs`` index. 
+ """ + if cell_id_column not in transcripts.columns: + raise ValueError(f"Missing cell_id column: {cell_id_column}") + if feature_column not in transcripts.columns: + raise ValueError(f"Missing feature column: {feature_column}") + + assigned = transcripts.filter(pl.col(cell_id_column).is_not_null()) + if unassigned_value is not None: + col_dtype = transcripts.schema.get(cell_id_column) + try: + compare_value = pl.Series([unassigned_value]).cast(col_dtype).item() + filter_expr = pl.col(cell_id_column) != compare_value + except Exception: + filter_expr = ( + pl.col(cell_id_column).cast(pl.Utf8) != str(unassigned_value) + ) + assigned = assigned.filter(filter_expr) + + # Gene list from all transcripts (even if no assignments) + var_idx = ( + transcripts + .select(feature_column) + .unique() + .sort(feature_column) + .get_column(feature_column) + .to_list() + ) + + if assigned.height == 0: + obs_index = pd.Index([], name=cell_id_column) + if obs_index_as_str: + var_index = pd.Index([str(v) for v in var_idx], name=feature_column) + else: + var_index = pd.Index(var_idx, name=feature_column) + X = sp.csr_matrix((0, len(var_index))) + adata = AnnData(X=X, obs=pd.DataFrame(index=obs_index), var=pd.DataFrame(index=var_index)) + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + return adata + + feature_idx = ( + assigned + .select(feature_column) + .unique() + .sort(feature_column) + .with_row_index(name="_fid") + ) + cell_idx = ( + assigned + .select(cell_id_column) + .unique() + .sort(cell_id_column) + .with_row_index(name="_cid") + ) + + mapped = ( + assigned + .join(feature_idx, on=feature_column) + .join(cell_idx, on=cell_id_column) + ) + counts = ( + mapped + .group_by(["_cid", "_fid"]) + .agg(pl.len().alias("_count")) + ) + ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T + rows = ijv[0].astype(np.int64, copy=False) + cols = ijv[1].astype(np.int64, copy=False) + data = ijv[2].astype(np.int64, copy=False) + + n_cells = cell_idx.height + n_genes = feature_idx.height + X = sp.coo_matrix((data, (rows, cols)), shape=(n_cells, n_genes)).tocsr() + + obs_ids = cell_idx.get_column(cell_id_column).to_list() + var_ids = feature_idx.get_column(feature_column).to_list() + if obs_index_as_str: + obs_ids = [str(v) for v in obs_ids] + var_ids = [str(v) for v in var_ids] + + adata = AnnData( + X=X, + obs=pd.DataFrame(index=pd.Index(obs_ids, name=cell_id_column)), + var=pd.DataFrame(index=pd.Index(var_ids, name=feature_column)), + ) + + # Add centroid coordinates if present + if x_column in assigned.columns and y_column in assigned.columns: + coords_cols = [x_column, y_column] + if z_column and z_column in assigned.columns: + coords_cols.append(z_column) + centroids = ( + assigned + .group_by(cell_id_column) + .agg([pl.col(c).mean().alias(c) for c in coords_cols]) + ) + centroids_pd = ( + centroids + .to_pandas() + .set_index(cell_id_column) + .reindex(adata.obs.index) + ) + adata.obsm["X_spatial"] = centroids_pd[coords_cols].to_numpy() + + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + + return adata + +### MERGED EXPORT ### + +def merge_predictions_with_transcripts( + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + unassigned_marker: Union[int, str, None] = -1, +) -> pl.DataFrame: + """Merge predictions with 
transcripts (functional interface). + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + transcripts + Original transcripts DataFrame. + row_index_column + Column name for row index. + cell_id_column + Column name for cell ID in predictions. + similarity_column + Column name for similarity in predictions. + unassigned_marker + Value for unassigned transcripts. + + Returns + ------- + pl.DataFrame + Merged DataFrame with all original columns plus predictions. + + Examples + -------- + >>> merged = merge_predictions_with_transcripts(predictions, transcripts) + >>> print(merged.columns) + ['row_index', 'x', 'y', 'feature_name', 'segger_cell_id', 'segger_similarity'] + """ + # Prepare predictions + pred_cols = [row_index_column, cell_id_column] + if similarity_column in predictions.columns: + pred_cols.append(similarity_column) + + pred_subset = predictions.select(pred_cols) + + # Add row_index if missing + if row_index_column not in transcripts.columns: + transcripts = transcripts.with_row_index(name=row_index_column) + + # Join + merged = transcripts.join(pred_subset, on=row_index_column, how="left") + + # Fill unassigned + if unassigned_marker is not None: + merged = merged.with_columns( + pl.col(cell_id_column).fill_null(unassigned_marker) + ) + if similarity_column in merged.columns: + merged = merged.with_columns( + pl.col(similarity_column).fill_null(0.0) + ) + + return merged diff --git a/src/segger/export/output_formats.py b/src/segger/export/output_formats.py deleted file mode 100644 index d08a990..0000000 --- a/src/segger/export/output_formats.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Output format definitions and writer registry for segmentation results. - -This module provides: -- OutputFormat enum for available output formats -- OutputWriter protocol for implementing format-specific writers -- Factory function to get the appropriate writer for a format - -Available formats: -- SEGGER_RAW: Default Segger output (predictions parquet) -- MERGED_TRANSCRIPTS: Original transcripts merged with assignments -- SPATIALDATA: SpatialData Zarr format for scverse ecosystem -- ANNDATA: AnnData (.h5ad) cell x gene matrix - -Usage ------ ->>> from segger.export.output_formats import OutputFormat, get_writer ->>> writer = get_writer(OutputFormat.MERGED_TRANSCRIPTS) ->>> writer.write(predictions, transcripts, output_dir) -""" - -from __future__ import annotations - -from enum import Enum -from pathlib import Path -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable - -if TYPE_CHECKING: - import geopandas as gpd - import polars as pl - - -class OutputFormat(str, Enum): - """Available output formats for segmentation results. - - Attributes - ---------- - SEGGER_RAW : str - Default Segger output format. Writes predictions as Parquet file - with columns: row_index, segger_cell_id, segger_similarity. - - MERGED_TRANSCRIPTS : str - Merged transcripts format. Original transcript data with segmentation - results joined (segger_cell_id, segger_similarity columns added). - - SPATIALDATA : str - SpatialData Zarr format. Creates a .zarr store compatible with - the scverse ecosystem, containing transcripts and optional boundaries. - - ANNDATA : str - AnnData format. Creates a .h5ad file with a cell x gene matrix - derived from transcript assignments. 
- """ - - SEGGER_RAW = "segger_raw" - MERGED_TRANSCRIPTS = "merged" - SPATIALDATA = "spatialdata" - ANNDATA = "anndata" - - @classmethod - def from_string(cls, value: str) -> "OutputFormat": - """Parse OutputFormat from string, case-insensitive. - - Parameters - ---------- - value - Format name ('segger_raw', 'merged', 'spatialdata', 'anndata', or 'all'). - - Returns - ------- - OutputFormat - Corresponding enum value. - - Raises - ------ - ValueError - If value is not a valid format name. - """ - value_lower = value.lower().strip() - - # Handle aliases - aliases = { - "raw": cls.SEGGER_RAW, - "segger": cls.SEGGER_RAW, - "default": cls.SEGGER_RAW, - "merge": cls.MERGED_TRANSCRIPTS, - "merged": cls.MERGED_TRANSCRIPTS, - "transcripts": cls.MERGED_TRANSCRIPTS, - "sdata": cls.SPATIALDATA, - "zarr": cls.SPATIALDATA, - "h5ad": cls.ANNDATA, - "ann": cls.ANNDATA, - "anndata": cls.ANNDATA, - } - - if value_lower in aliases: - return aliases[value_lower] - - # Try direct match - for fmt in cls: - if fmt.value == value_lower: - return fmt - - valid = [f.value for f in cls] + list(aliases.keys()) - raise ValueError( - f"Unknown output format: '{value}'. " - f"Valid formats: {sorted(set(valid))}" - ) - - -@runtime_checkable -class OutputWriter(Protocol): - """Protocol for output format writers. - - Implementations must provide a `write` method that writes segmentation - results to the specified output directory. - """ - - def write( - self, - predictions: "pl.DataFrame", - output_dir: Path, - **kwargs: Any, - ) -> Path: - """Write segmentation results to output format. - - Parameters - ---------- - predictions - DataFrame with segmentation predictions. Must contain: - - row_index: Original transcript row index - - segger_cell_id: Assigned cell ID (or -1/None for unassigned) - - segger_similarity: Assignment confidence score - - output_dir - Directory to write output files. - - **kwargs - Format-specific options (e.g., transcripts, boundaries). - - Returns - ------- - Path - Path to the primary output file/directory. - """ - ... - - -# Registry of output writers by format -_OUTPUT_WRITERS: dict[OutputFormat, type] = {} - - -def register_writer(fmt: OutputFormat): - """Decorator to register an output writer class. - - Parameters - ---------- - fmt - Output format this writer handles. - - Returns - ------- - decorator - Class decorator that registers the writer. - - Examples - -------- - >>> @register_writer(OutputFormat.MERGED_TRANSCRIPTS) - ... class MergedTranscriptsWriter: - ... def write(self, predictions, output_dir, **kwargs): - ... ... - """ - def decorator(cls): - _OUTPUT_WRITERS[fmt] = cls - return cls - return decorator - - -def get_writer(fmt: OutputFormat | str, **init_kwargs: Any) -> OutputWriter: - """Get an output writer for the specified format. - - Parameters - ---------- - fmt - Output format (enum or string). - **init_kwargs - Keyword arguments passed to the writer constructor. - - Returns - ------- - OutputWriter - Writer instance for the specified format. - - Raises - ------ - ValueError - If format is not recognized or writer not registered. - - Examples - -------- - >>> writer = get_writer(OutputFormat.MERGED_TRANSCRIPTS, unassigned_marker=-1) - >>> writer.write(predictions, Path("output/")) - """ - if isinstance(fmt, str): - fmt = OutputFormat.from_string(fmt) - - if fmt not in _OUTPUT_WRITERS: - raise ValueError( - f"No writer registered for format: {fmt.value}. 
" - f"Available formats: {[f.value for f in _OUTPUT_WRITERS.keys()]}" - ) - - writer_cls = _OUTPUT_WRITERS[fmt] - return writer_cls(**init_kwargs) - - -def get_all_writers(**init_kwargs: Any) -> dict[OutputFormat, OutputWriter]: - """Get writers for all registered formats. - - Parameters - ---------- - **init_kwargs - Keyword arguments passed to each writer constructor. - - Returns - ------- - dict[OutputFormat, OutputWriter] - Dictionary mapping formats to writer instances. - """ - return {fmt: get_writer(fmt, **init_kwargs) for fmt in _OUTPUT_WRITERS} - - -def write_all_formats( - predictions: "pl.DataFrame", - output_dir: Path, - **kwargs: Any, -) -> dict[OutputFormat, Path]: - """Write segmentation results in all available formats. - - Parameters - ---------- - predictions - DataFrame with segmentation predictions. - output_dir - Base output directory. Subdirectories may be created for each format. - **kwargs - Additional arguments passed to each writer (transcripts, boundaries, etc.). - - Returns - ------- - dict[OutputFormat, Path] - Dictionary mapping formats to output paths. - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - results = {} - for fmt, writer in get_all_writers().items(): - try: - path = writer.write(predictions, output_dir, **kwargs) - results[fmt] = path - except Exception as e: - # Log error but continue with other formats - import warnings - warnings.warn( - f"Failed to write {fmt.value} format: {e}", - UserWarning, - stacklevel=2, - ) - - return results - - -# Import writers to register them (done at end to avoid circular imports) -def _register_builtin_writers(): - """Register built-in output writers. - - Called lazily to avoid import errors if optional dependencies are missing. - """ - # Import here to register writers via decorators - from segger.export import merged_writer # noqa: F401 - from segger.export import anndata_writer # noqa: F401 - - # SpatialData writer is optional - try: - from segger.export import spatialdata_writer # noqa: F401 - except ImportError: - pass - - -# Lazy registration on first use -_writers_registered = False - - -def _ensure_writers_registered(): - """Ensure built-in writers are registered.""" - global _writers_registered - if not _writers_registered: - _register_builtin_writers() - _writers_registered = True - - -# Override get_writer to ensure registration -_original_get_writer = get_writer - - -def get_writer(fmt: OutputFormat | str, **init_kwargs: Any) -> OutputWriter: - """Get an output writer for the specified format.""" - _ensure_writers_registered() - return _original_get_writer(fmt, **init_kwargs) diff --git a/src/segger/export/sopa_compat.py b/src/segger/export/sopa_compat.py deleted file mode 100644 index 230157c..0000000 --- a/src/segger/export/sopa_compat.py +++ /dev/null @@ -1,396 +0,0 @@ -"""SOPA compatibility utilities for SpatialData export. - -SOPA (Spatial Omics Pipeline Architecture) is a framework for spatial omics -analysis built on SpatialData. This module provides utilities to ensure -Segger output is compatible with SOPA workflows. - -SOPA Conventions ----------------- -- shapes[cell_key]: Cell polygons with 'cell_id' column -- points[transcript_key]: Transcripts with 'cell_id' assignment column -- No images required for segmentation workflows -- Cell IDs should be consistent between shapes and points - -Usage ------ ->>> from segger.export.sopa_compat import validate_sopa_compatibility ->>> issues = validate_sopa_compatibility(sdata) ->>> if not issues: -... 
print("SpatialData is SOPA-compatible") - ->>> from segger.export.sopa_compat import export_for_sopa ->>> path = export_for_sopa(sdata, Path("output/sopa_compatible.zarr")) - -Installation ------------- -Requires the spatialdata optional dependency: - pip install segger[spatialdata] - -For full SOPA integration: - pip install segger[sopa] -""" - -from __future__ import annotations - -import warnings -from pathlib import Path -from typing import TYPE_CHECKING, Optional - -import polars as pl - -from segger.utils.optional_deps import ( - SPATIALDATA_AVAILABLE, - SOPA_AVAILABLE, - require_spatialdata, - warn_sopa_unavailable, -) - -if TYPE_CHECKING: - import geopandas as gpd - from spatialdata import SpatialData - - -# SOPA expected keys and columns -SOPA_DEFAULT_CELL_KEY = "cells" -SOPA_DEFAULT_TRANSCRIPT_KEY = "transcripts" -SOPA_CELL_ID_COLUMN = "cell_id" - - -def validate_sopa_compatibility( - sdata: "SpatialData", - cell_key: str = SOPA_DEFAULT_CELL_KEY, - transcript_key: str = SOPA_DEFAULT_TRANSCRIPT_KEY, -) -> list[str]: - """Validate SpatialData object for SOPA compatibility. - - Checks that the SpatialData object follows SOPA conventions: - - Cell shapes exist with cell_id column - - Transcripts exist with cell_id assignment column - - Cell IDs are consistent between shapes and points - - Parameters - ---------- - sdata - SpatialData object to validate. - cell_key - Expected key for cell shapes. Default "cells". - transcript_key - Expected key for transcripts. Default "transcripts". - - Returns - ------- - list[str] - List of compatibility issues (empty if fully compatible). - - Examples - -------- - >>> issues = validate_sopa_compatibility(sdata) - >>> if issues: - ... for issue in issues: - ... print(f"- {issue}") - """ - require_spatialdata() - - issues = [] - - # Check for cell shapes - if cell_key not in sdata.shapes: - issues.append( - f"Missing cell shapes: expected shapes['{cell_key}']. " - f"Available shapes: {list(sdata.shapes.keys())}" - ) - else: - cells = sdata.shapes[cell_key] - if SOPA_CELL_ID_COLUMN not in cells.columns: - issues.append( - f"Cell shapes missing '{SOPA_CELL_ID_COLUMN}' column. " - f"Available columns: {list(cells.columns)}" - ) - - # Check for transcripts - if transcript_key not in sdata.points: - issues.append( - f"Missing transcripts: expected points['{transcript_key}']. " - f"Available points: {list(sdata.points.keys())}" - ) - else: - transcripts = sdata.points[transcript_key] - # Get column names from Dask DataFrame - if hasattr(transcripts, "columns"): - tx_columns = list(transcripts.columns) - else: - tx_columns = [] - - if SOPA_CELL_ID_COLUMN not in tx_columns: - # Check for alternative names - alt_names = ["segger_cell_id", "seg_cell_id", "cell"] - found = [c for c in alt_names if c in tx_columns] - if found: - issues.append( - f"Transcripts use '{found[0]}' instead of '{SOPA_CELL_ID_COLUMN}'. " - "SOPA expects 'cell_id' column for assignments." - ) - else: - issues.append( - f"Transcripts missing '{SOPA_CELL_ID_COLUMN}' column. 
" - f"Available columns: {tx_columns}" - ) - - # Check cell ID consistency - if cell_key in sdata.shapes and transcript_key in sdata.points: - try: - cells = sdata.shapes[cell_key] - transcripts = sdata.points[transcript_key] - - if SOPA_CELL_ID_COLUMN in cells.columns: - cell_ids_shapes = set(cells[SOPA_CELL_ID_COLUMN].unique()) - - if hasattr(transcripts, "compute"): - tx_computed = transcripts.compute() - else: - tx_computed = transcripts - - if SOPA_CELL_ID_COLUMN in tx_computed.columns: - cell_ids_tx = set( - tx_computed[SOPA_CELL_ID_COLUMN].dropna().unique() - ) - # Filter out unassigned (-1 or negative) - cell_ids_tx = {c for c in cell_ids_tx if c >= 0} - - missing_in_shapes = cell_ids_tx - cell_ids_shapes - if missing_in_shapes: - issues.append( - f"Cell IDs in transcripts not found in shapes: " - f"{len(missing_in_shapes)} IDs missing" - ) - except Exception as e: - issues.append(f"Could not verify cell ID consistency: {e}") - - return issues - - -def export_for_sopa( - sdata: "SpatialData", - output_path: Path, - cell_key: str = SOPA_DEFAULT_CELL_KEY, - transcript_key: str = SOPA_DEFAULT_TRANSCRIPT_KEY, - rename_cell_id: bool = True, - overwrite: bool = False, -) -> Path: - """Export SpatialData in SOPA-expected structure. - - Ensures the output follows SOPA conventions: - - shapes[cell_key]: Cell polygons with 'cell_id' column - - points[transcript_key]: Transcripts with 'cell_id' assignment - - Parameters - ---------- - sdata - SpatialData object to export. - output_path - Path for output .zarr store. - cell_key - Key for cell shapes. Default "cells". - transcript_key - Key for transcripts. Default "transcripts". - rename_cell_id - If True, rename 'segger_cell_id' to 'cell_id' for SOPA. - overwrite - Whether to overwrite existing output. - - Returns - ------- - Path - Path to exported .zarr store. - - Examples - -------- - >>> path = export_for_sopa(sdata, Path("output/sopa_ready.zarr")) - """ - require_spatialdata() - import spatialdata - - output_path = Path(output_path) - - if output_path.exists() and not overwrite: - raise FileExistsError( - f"Output exists: {output_path}. Use overwrite=True to replace." 
- ) - - # Create a modified copy for SOPA compatibility - elements = {} - - # Process points (transcripts) - for key in sdata.points: - points = sdata.points[key] - - # Rename to expected key if needed - target_key = transcript_key if key == list(sdata.points.keys())[0] else key - - # Rename cell_id column if needed - if rename_cell_id and hasattr(points, "columns"): - if "segger_cell_id" in points.columns and SOPA_CELL_ID_COLUMN not in points.columns: - points = points.rename(columns={"segger_cell_id": SOPA_CELL_ID_COLUMN}) - - elements[f"points/{target_key}"] = points - - # Process shapes - for key in sdata.shapes: - shapes = sdata.shapes[key] - - # Rename to expected key if needed - target_key = cell_key if key == list(sdata.shapes.keys())[0] else key - - # Ensure cell_id column exists - if SOPA_CELL_ID_COLUMN not in shapes.columns: - if "segger_cell_id" in shapes.columns: - shapes = shapes.rename(columns={"segger_cell_id": SOPA_CELL_ID_COLUMN}) - elif shapes.index.name: - shapes = shapes.reset_index() - if shapes.columns[0] != SOPA_CELL_ID_COLUMN: - shapes = shapes.rename(columns={shapes.columns[0]: SOPA_CELL_ID_COLUMN}) - - elements[f"shapes/{target_key}"] = shapes - - # Create new SpatialData - sdata_sopa = spatialdata.SpatialData.from_elements_dict(elements) - - # Write - if output_path.exists(): - import shutil - shutil.rmtree(output_path) - - sdata_sopa.write(output_path) - - return output_path - - -def sopa_to_segger_input( - sopa_sdata: "SpatialData", - cell_key: str = SOPA_DEFAULT_CELL_KEY, - transcript_key: str = SOPA_DEFAULT_TRANSCRIPT_KEY, -) -> tuple[pl.LazyFrame, "gpd.GeoDataFrame"]: - """Convert SOPA SpatialData to Segger internal format. - - Enables round-trip: SOPA → Segger → SOPA - - Parameters - ---------- - sopa_sdata - SOPA-formatted SpatialData object. - cell_key - Key for cell shapes. - transcript_key - Key for transcripts. - - Returns - ------- - tuple[pl.LazyFrame, gpd.GeoDataFrame] - (transcripts, boundaries) in Segger internal format. - - Examples - -------- - >>> transcripts, boundaries = sopa_to_segger_input(sdata) - >>> # Run Segger segmentation - >>> predictions = segment(transcripts, boundaries) - >>> # Export back to SOPA format - >>> export_for_sopa(results, "output.zarr") - """ - require_spatialdata() - import geopandas as gpd - - # Extract transcripts - if transcript_key not in sopa_sdata.points: - available = list(sopa_sdata.points.keys()) - raise ValueError( - f"Transcript key '{transcript_key}' not found. 
Available: {available}" - ) - - points = sopa_sdata.points[transcript_key] - - # Convert to Polars - if hasattr(points, "compute"): - points_pd = points.compute() - else: - points_pd = points - - transcripts = pl.from_pandas(points_pd).lazy() - - # Normalize column names - column_map = { - SOPA_CELL_ID_COLUMN: "cell_id", - } - for old, new in column_map.items(): - if old in transcripts.collect_schema().names() and old != new: - transcripts = transcripts.rename({old: new}) - - # Add row_index if missing - schema = transcripts.collect_schema() - if "row_index" not in schema.names(): - transcripts = transcripts.with_row_index(name="row_index") - - # Extract boundaries - boundaries = None - if cell_key in sopa_sdata.shapes: - boundaries = sopa_sdata.shapes[cell_key].copy() - - # Normalize cell_id column - if SOPA_CELL_ID_COLUMN not in boundaries.columns: - if boundaries.index.name: - boundaries = boundaries.reset_index() - boundaries = boundaries.rename( - columns={boundaries.columns[0]: SOPA_CELL_ID_COLUMN} - ) - - return transcripts, boundaries - - -def check_sopa_installation() -> dict[str, bool]: - """Check SOPA and related package installation status. - - Returns - ------- - dict[str, bool] - Dictionary with package names and installation status. - """ - status = { - "spatialdata": SPATIALDATA_AVAILABLE, - "sopa": SOPA_AVAILABLE, - } - - # Check spatialdata-io - try: - import spatialdata_io # noqa: F401 - status["spatialdata_io"] = True - except ImportError: - status["spatialdata_io"] = False - - return status - - -def get_sopa_installation_instructions() -> str: - """Get installation instructions for SOPA integration. - - Returns - ------- - str - Installation instructions. - """ - status = check_sopa_installation() - - lines = ["SOPA Integration Installation Status:", ""] - - for pkg, installed in status.items(): - mark = "✓" if installed else "✗" - lines.append(f" {mark} {pkg}: {'installed' if installed else 'not installed'}") - - lines.append("") - lines.append("To install all SOPA dependencies:") - lines.append(" pip install segger[spatialdata-all]") - lines.append("") - lines.append("Or install individually:") - lines.append(" pip install spatialdata>=0.2.0") - lines.append(" pip install spatialdata-io>=0.1.0") - lines.append(" pip install sopa>=1.0.0") - - return "\n".join(lines) diff --git a/src/segger/export/spatialdata_writer.py b/src/segger/export/spatialdata_writer.py index a4b47bb..b0f10cf 100644 --- a/src/segger/export/spatialdata_writer.py +++ b/src/segger/export/spatialdata_writer.py @@ -35,8 +35,7 @@ from segger.utils.optional_deps import ( require_spatialdata, ) -from segger.export.output_formats import OutputFormat, register_writer -from segger.export.anndata_writer import build_anndata_table +from segger.export.minimal_apis import OutputFormat, register_writer, build_anndata_table if TYPE_CHECKING: import geopandas as gpd From 413c217989a7182e6dcd8c31eed963f5d5ae5d06 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Wed, 6 May 2026 18:38:19 +0200 Subject: [PATCH 16/20] Remove scripts dir from v2-incremental branch --- scripts/benchmark_status_dashboard.sh | 1191 ----------------- scripts/build_benchmark_pdf_report.py | 1168 ---------------- scripts/build_benchmark_validation_table.sh | 783 ----------- .../build_default_10x_reference_artifacts.py | 230 ---- scripts/presentation/experiments.md | 361 ----- scripts/presentation/experiments_plan.md | 327 ----- scripts/run_ablation_study.sh | 1138 ---------------- scripts/run_param_benchmark_2gpu.sh | 764 ----------- 
scripts/run_robustness_ablation_2gpu.sh | 845 ------------ 9 files changed, 6807 deletions(-) delete mode 100755 scripts/benchmark_status_dashboard.sh delete mode 100644 scripts/build_benchmark_pdf_report.py delete mode 100755 scripts/build_benchmark_validation_table.sh delete mode 100755 scripts/build_default_10x_reference_artifacts.py delete mode 100644 scripts/presentation/experiments.md delete mode 100644 scripts/presentation/experiments_plan.md delete mode 100755 scripts/run_ablation_study.sh delete mode 100755 scripts/run_param_benchmark_2gpu.sh delete mode 100755 scripts/run_robustness_ablation_2gpu.sh diff --git a/scripts/benchmark_status_dashboard.sh b/scripts/benchmark_status_dashboard.sh deleted file mode 100755 index 667b8f9..0000000 --- a/scripts/benchmark_status_dashboard.sh +++ /dev/null @@ -1,1191 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -usage() { - cat <<'EOF' -Benchmark status snapshot + terminal dashboard. - -Usage: - bash scripts/benchmark_status_dashboard.sh [options] - -Options: - --root Benchmark root directory - (default: ./results/mossi_main_big_benchmark_nightly) - --out-tsv Snapshot TSV output path - (default: /summaries/status_snapshot.tsv) - --watch [sec] Refresh dashboard every N seconds (default: 20) - --no-color Disable ANSI colors - -h, --help Show this help -EOF -} - -ROOT="./results/mossi_main_big_benchmark_nightly" -OUT_TSV="" -WATCH_SEC=0 -NO_COLOR=0 - -while [[ $# -gt 0 ]]; do - case "$1" in - --root) - if [[ $# -lt 2 ]]; then - echo "ERROR: --root requires a value." >&2 - exit 1 - fi - ROOT="$2" - shift 2 - ;; - --out-tsv) - if [[ $# -lt 2 ]]; then - echo "ERROR: --out-tsv requires a value." >&2 - exit 1 - fi - OUT_TSV="$2" - shift 2 - ;; - --watch) - if [[ $# -ge 2 ]] && [[ ! "${2-}" =~ ^- ]]; then - WATCH_SEC="$2" - shift 2 - else - WATCH_SEC=20 - shift - fi - ;; - --no-color) - NO_COLOR=1 - shift - ;; - -h|--help) - usage - exit 0 - ;; - *) - echo "Unknown argument: $1" >&2 - usage - exit 1 - ;; - esac -done - -if [[ -z "${OUT_TSV}" ]]; then - OUT_TSV="${ROOT}/summaries/status_snapshot.tsv" -fi - -if ! [[ "${WATCH_SEC}" =~ ^[0-9]+$ ]]; then - echo "ERROR: --watch must be a non-negative integer." >&2 - exit 1 -fi - -PLAN_FILE="${ROOT}/job_plan.tsv" -SUMMARY_DIR="${ROOT}/summaries" -LOGS_DIR="${ROOT}/logs" -RUNS_DIR="${ROOT}/runs" -EXPORTS_DIR="${ROOT}/exports" -VALIDATION_TSV="${SUMMARY_DIR}/validation_metrics.tsv" -if [[ ! -f "${VALIDATION_TSV}" ]] && [[ -f "${ROOT}/validation_metrics.tsv" ]]; then - VALIDATION_TSV="${ROOT}/validation_metrics.tsv" -fi - -if [[ ! 
-f "${PLAN_FILE}" ]]; then - echo "ERROR: Missing plan file: ${PLAN_FILE}" >&2 - exit 1 -fi - -mkdir -p "$(dirname "${OUT_TSV}")" - -if [[ "${NO_COLOR}" == "0" ]] && [[ -t 1 ]]; then - C_RESET=$'\033[0m' - C_BOLD=$'\033[1m' - C_BOLD_OFF=$'\033[22m' - C_GREEN=$'\033[32m' - C_RED=$'\033[31m' - C_YELLOW=$'\033[33m' - C_BLUE=$'\033[34m' - C_CYAN=$'\033[36m' -else - C_RESET="" - C_BOLD="" - C_BOLD_OFF="" - C_GREEN="" - C_RED="" - C_YELLOW="" - C_BLUE="" - C_CYAN="" -fi - -collect_status_map() { - local out_file="$1" - local running_jobs_file="${2:-}" - local -a status_files=() - local f - local include_recovery=1 - - if [[ -n "${running_jobs_file}" ]] && [[ -s "${running_jobs_file}" ]]; then - include_recovery=0 - fi - - for f in "${SUMMARY_DIR}"/gpu*.tsv; do - [[ -f "${f}" ]] || continue - status_files+=("${f}") - done - if [[ "${include_recovery}" == "1" ]] && [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then - status_files+=("${SUMMARY_DIR}/recovery.tsv") - fi - - if [[ "${#status_files[@]}" -eq 0 ]]; then - : > "${out_file}" - return 0 - fi - - awk -F'\t' ' - FNR == 1 { next } - { - job = $1 - gpu = $2 - status = $3 - elapsed = $4 - note = "" - seg = "" - log_path = "" - if (NF >= 7) { - note = $5 - seg = $6 - log_path = $7 - } else if (NF >= 6) { - seg = $5 - log_path = $6 - } - - if (note == "") { - note = "-" - } - gpu_map[job] = gpu - status_map[job] = status - elapsed_map[job] = elapsed - note_map[job] = note - seg_map[job] = seg - log_map[job] = log_path - } - END { - for (job in status_map) { - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n", - job, - gpu_map[job], - status_map[job], - elapsed_map[job], - note_map[job], - seg_map[job], - log_map[job] - } - } - ' "${status_files[@]}" > "${out_file}" -} - -collect_running_jobs() { - local out_file="$1" - if ! 
command -v pgrep >/dev/null 2>&1; then - : > "${out_file}" - return 0 - fi - - pgrep -af 'segger segment|segger predict' 2>/dev/null \ - | awk ' - { - for (i = 1; i <= NF; i++) { - if ($i == "-o" && (i + 1) <= NF) { - out = $(i + 1) - gsub(/\/+$/, "", out) - n = split(out, a, "/") - if (n > 0) { - print a[n] - } - } - } - } - ' \ - | sed '/^$/d' \ - | sort -u > "${out_file}" || : > "${out_file}" -} - -pick_log_file() { - local job="$1" - local f - local found="" - for f in "${LOGS_DIR}/${job}.gpu"*.log; do - [[ -f "${f}" ]] || continue - found="${f}" - done - printf '%s' "${found}" -} - -build_snapshot() { - local status_map="$1" - local running_jobs="$2" - local out_file="$3" - local tmp_file - local plan_header plan_has_study_block - tmp_file="$(mktemp)" - - plan_header="$(head -n 1 "${PLAN_FILE}")" - plan_has_study_block=0 - if printf '%s\n' "${plan_header}" | tr '\t' '\n' | grep -Fxq "study_block"; then - plan_has_study_block=1 - fi - - printf "job\tgroup\tgpu\tstatus\tstate\trunning\telapsed_s\trun_count\thad_rerun\thad_anc_retry\thad_predict_fallback\thad_recovery_pass\tseg_exists\tanndata_exists\txenium_exists\tseg_dir\tlog_file\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\tnote\n" > "${tmp_file}" - - tail -n +2 "${PLAN_FILE}" | while IFS=$'\t' read -r -a cols; do - local job group use_3d expansion tx_max_k tx_max_dist n_mid_layers n_heads cells_min_counts min_qv alignment_loss - local row gpu status elapsed note seg_dir log_file - local running seg_exists anndata_exists xenium_exists - local run_count had_rerun had_anc_retry had_predict_fallback had_recovery_pass - local state - - if [[ "${plan_has_study_block}" == "1" ]]; then - job="${cols[0]-}" - group="${cols[2]-}" - use_3d="${cols[3]-}" - expansion="${cols[4]-}" - tx_max_k="${cols[5]-}" - tx_max_dist="${cols[6]-}" - n_mid_layers="${cols[7]-}" - n_heads="${cols[8]-}" - cells_min_counts="${cols[9]-}" - min_qv="${cols[10]-}" - alignment_loss="${cols[11]-}" - else - job="${cols[0]-}" - group="${cols[1]-}" - use_3d="${cols[2]-}" - expansion="${cols[3]-}" - tx_max_k="${cols[4]-}" - tx_max_dist="${cols[5]-}" - n_mid_layers="${cols[6]-}" - n_heads="${cols[7]-}" - cells_min_counts="${cols[8]-}" - min_qv="${cols[9]-}" - alignment_loss="${cols[10]-}" - fi - - row="$(awk -F'\t' -v j="${job}" '$1 == j { print; exit }' "${status_map}")" - gpu="" - status="" - elapsed="" - note="" - seg_dir="" - log_file="" - if [[ -n "${row}" ]]; then - IFS=$'\t' read -r _ gpu status elapsed note seg_dir log_file <<< "${row}" - if [[ "${note}" == "-" ]]; then - note="" - fi - fi - - [[ -n "${seg_dir}" ]] || seg_dir="${RUNS_DIR}/${job}" - [[ -n "${log_file}" ]] || log_file="$(pick_log_file "${job}")" - [[ -n "${log_file}" ]] || log_file="${LOGS_DIR}/${job}.gpu?.log" - - running=0 - if [[ -s "${running_jobs}" ]] && grep -Fxq "${job}" "${running_jobs}"; then - running=1 - fi - - seg_exists=0 - anndata_exists=0 - xenium_exists=0 - [[ -f "${RUNS_DIR}/${job}/segger_segmentation.parquet" ]] && seg_exists=1 - [[ -f "${EXPORTS_DIR}/${job}/anndata/segger_segmentation.h5ad" ]] && anndata_exists=1 - [[ -f "${EXPORTS_DIR}/${job}/xenium_explorer/seg_experiment.xenium" ]] && xenium_exists=1 - - run_count=0 - had_rerun=0 - had_anc_retry=0 - had_predict_fallback=0 - had_recovery_pass=0 - if [[ -f "${log_file}" ]]; then - run_count="$(grep -c "START job=${job}" "${log_file}" 2>/dev/null || printf '0')" - if [[ "${run_count}" -gt 1 ]]; then - had_rerun=1 - fi - grep -q "segment failed with ancdata; retrying" 
"${log_file}" 2>/dev/null && had_anc_retry=1 || true - grep -q "predict fallback succeeded after OOM" "${log_file}" 2>/dev/null && had_predict_fallback=1 || true - grep -q "RECOVERY job=${job}" "${log_file}" 2>/dev/null && had_recovery_pass=1 || true - fi - - state="pending" - if [[ "${running}" == "1" ]]; then - state="running" - elif [[ "${seg_exists}" == "1" && "${anndata_exists}" == "1" && "${xenium_exists}" == "1" ]]; then - state="done" - elif [[ -n "${status}" ]]; then - case "${status}" in - ok|skipped_existing|recovered_predict_ok) - state="partial" - ;; - *) - state="failed" - ;; - esac - else - if [[ "${seg_exists}" == "1" || "${anndata_exists}" == "1" || "${xenium_exists}" == "1" ]]; then - state="partial" - else - state="pending" - fi - fi - - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job}" "${group}" "${gpu}" "${status}" "${state}" "${running}" "${elapsed}" \ - "${run_count}" "${had_rerun}" "${had_anc_retry}" "${had_predict_fallback}" "${had_recovery_pass}" \ - "${seg_exists}" "${anndata_exists}" "${xenium_exists}" "${seg_dir}" "${log_file}" \ - "${use_3d}" "${expansion}" "${tx_max_k}" "${tx_max_dist}" "${n_mid_layers}" "${n_heads}" \ - "${cells_min_counts}" "${min_qv}" "${alignment_loss}" "${note}" \ - >> "${tmp_file}" - done - - mv "${tmp_file}" "${out_file}" -} - -draw_progress_bar() { - local current="$1" - local total="$2" - local width=40 - local fill=0 - local pct=0 - if [[ "${total}" -gt 0 ]]; then - fill=$((current * width / total)) - pct=$((current * 100 / total)) - fi - local empty=$((width - fill)) - local left right - left="$(printf '%*s' "${fill}" '' | tr ' ' '#')" - right="$(printf '%*s' "${empty}" '' | tr ' ' '-')" - printf "[%s%s] %d/%d (%d%%)" "${left}" "${right}" "${current}" "${total}" "${pct}" -} - -render_tsv_table() { - awk -F'\t' ' - function repeat_char(ch, n, out, i) { - out = "" - for (i = 0; i < n; i++) out = out ch - return out - } - function visible_len(s, t) { - t = s - gsub(/\033\[[0-9;]*m/, "", t) - return length(t) - } - function print_cell(val, width, vis, i) { - vis = visible_len(val) - printf(" %s", val) - for (i = vis; i < width; i++) { - printf(" ") - } - printf(" |") - } - { - rows = NR - if ($1 == "__ROW_SEP__") { - row_sep[NR] = 1 - next - } - if (NF > ncols) ncols = NF - for (i = 1; i <= NF; i++) { - val = $i - sub(/\r$/, "", val) - cells[NR, i] = val - vis = visible_len(val) - if (vis > widths[i]) widths[i] = vis - } - } - END { - if (rows == 0) exit - - sep = "+" - for (i = 1; i <= ncols; i++) { - sep = sep repeat_char("-", widths[i] + 2) "+" - } - - print sep - printf("|") - for (i = 1; i <= ncols; i++) { - print_cell(cells[1, i], widths[i]) - } - printf("\n") - print sep - - for (r = 2; r <= rows; r++) { - if (row_sep[r]) { - print sep - continue - } - printf("|") - for (i = 1; i <= ncols; i++) { - print_cell(cells[r, i], widths[i]) - } - printf("\n") - } - print sep - } - ' -} - -colorize_rows_by_state_column() { - local state_col="$1" - awk -F'\t' \ - -v state_col="${state_col}" \ - -v c_done="${C_GREEN}" \ - -v c_run="${C_BLUE}" \ - -v c_fail="${C_RED}" \ - -v c_pending="${C_YELLOW}" \ - -v c_partial="${C_CYAN}" \ - -v c_reset="${C_RESET}" \ - ' - function color_for_state(s, lower) { - lower = tolower(s) - if (lower == "done") return c_done - if (lower == "running") return c_run - if (lower == "failed") return c_fail - if (lower == "pending") return c_pending - if (lower == "reference") return c_partial - if (lower == "partial") return c_partial 
- return "" - } - NR == 1 { print; next } - { - state = (state_col > 0 && state_col <= NF) ? $state_col : "" - color = color_for_state(state) - if (color != "") { - for (i = 1; i <= NF; i++) { - $i = color $i c_reset - } - } - print - } - ' OFS=$'\t' -} - -colorize_rows_by_status_column() { - local status_col="$1" - awk -F'\t' \ - -v status_col="${status_col}" \ - -v c_done="${C_GREEN}" \ - -v c_run="${C_BLUE}" \ - -v c_fail="${C_RED}" \ - -v c_pending="${C_YELLOW}" \ - -v c_partial="${C_CYAN}" \ - -v c_reset="${C_RESET}" \ - ' - function state_from_status(status, lower) { - lower = tolower(status) - if (status == "" || status == "") return "pending" - if (lower ~ /running|in_progress/) return "running" - if (lower ~ /oom|oot|ancdata|fail|error|missing|recovery_no_checkpoint/) return "failed" - if (lower == "ok" || lower == "skipped_existing" || lower == "recovered_predict_ok") return "done" - return "pending" - } - function color_for_state(s, lower) { - lower = tolower(s) - if (lower == "done") return c_done - if (lower == "running") return c_run - if (lower == "failed") return c_fail - if (lower == "pending") return c_pending - if (lower == "reference") return c_partial - if (lower == "partial") return c_partial - return "" - } - NR == 1 { print; next } - { - status = (status_col > 0 && status_col <= NF) ? $status_col : "" - state = state_from_status(status) - color = color_for_state(state) - if (color != "") { - for (i = 1; i <= NF; i++) { - $i = color $i c_reset - } - } - print - } - ' OFS=$'\t' -} - -render_dashboard() { - local snapshot="$1" - local total done_count running_count pending_count partial_count failed_count - local oom_count oot_count anc_count rerun_count recovered_count processed - local now - - total="$(awk 'END { print NR-1 }' "${snapshot}")" - done_count="$(awk -F'\t' 'NR>1 && $5=="done" {c++} END{print c+0}' "${snapshot}")" - running_count="$(awk -F'\t' 'NR>1 && $5=="running" {c++} END{print c+0}' "${snapshot}")" - pending_count="$(awk -F'\t' 'NR>1 && $5=="pending" {c++} END{print c+0}' "${snapshot}")" - partial_count="$(awk -F'\t' 'NR>1 && $5=="partial" {c++} END{print c+0}' "${snapshot}")" - failed_count="$(awk -F'\t' 'NR>1 && $5=="failed" {c++} END{print c+0}' "${snapshot}")" - - oom_count="$(awk -F'\t' 'NR>1 && $4 ~ /oom/ {c++} END{print c+0}' "${snapshot}")" - oot_count="$(awk -F'\t' 'NR>1 && $4=="segment_oot" {c++} END{print c+0}' "${snapshot}")" - anc_count="$(awk -F'\t' 'NR>1 && ($4=="segment_ancdata" || $10=="1") {c++} END{print c+0}' "${snapshot}")" - rerun_count="$(awk -F'\t' 'NR>1 && $9=="1" {c++} END{print c+0}' "${snapshot}")" - recovered_count="$(awk -F'\t' 'NR>1 && ($4=="recovered_predict_ok" || $11=="1" || $12=="1") {c++} END{print c+0}' "${snapshot}")" - processed=$((done_count + partial_count + failed_count)) - - now="$(date '+%Y-%m-%d %H:%M:%S')" - echo "${C_CYAN}Benchmark Dashboard${C_RESET} | ${now}" - echo "Root: ${ROOT}" - echo "Snapshot: ${snapshot}" - printf "Progress: " - draw_progress_bar "${processed}" "${total}" - echo - echo - printf "%b\n" "${C_BLUE}running=${running_count}${C_RESET} ${C_YELLOW}pending=${pending_count}${C_RESET} ${C_RED}failed=${failed_count}${C_RESET} ${C_GREEN}done=${done_count}${C_RESET} ${C_CYAN}partial=${partial_count}${C_RESET}" - printf "oom=%s oot=%s ancdata=%s rerun=%s recovered=%s\n" "${oom_count}" "${oot_count}" "${anc_count}" "${rerun_count}" "${recovered_count}" - - echo - echo "State Counts:" - awk -F'\t' ' - NR > 1 { - c[$5]++ - } - END { - print "state\tcount" - order[1] = "running" - order[2] = 
"pending" - order[3] = "failed" - order[4] = "done" - order[5] = "partial" - - for (i = 1; i <= 5; i++) { - s = order[i] - print s "\t" (c[s] + 0) - seen[s] = 1 - } - for (k in c) { - if (!(k in seen)) { - print k "\t" c[k] - } - } - } - ' "${snapshot}" \ - | colorize_rows_by_state_column 1 \ - | render_tsv_table - - echo - echo "Status Counts:" - awk -F'\t' ' - function state_from_status(status, lower) { - lower = tolower(status) - if (status == "" || status == "") return "pending" - if (lower ~ /running|in_progress/) return "running" - if (lower ~ /oom|oot|ancdata|fail|error|missing|recovery_no_checkpoint/) return "failed" - if (lower == "ok" || lower == "skipped_existing" || lower == "recovered_predict_ok") return "done" - return "pending" - } - function rank(state, lower) { - lower = tolower(state) - if (lower == "running") return 1 - if (lower == "pending") return 2 - if (lower == "failed") return 3 - if (lower == "done") return 4 - if (lower == "partial") return 5 - return 99 - } - NR > 1 { - key = $4 - if (key == "") key = "" - c[key]++ - } - END { - print "rank\tstatus\tcount" - for (k in c) { - st = state_from_status(k) - print rank(st) "\t" k "\t" c[k] - } - } - ' "${snapshot}" \ - | { - IFS= read -r header || true - if [[ -n "${header}" ]]; then - printf "status\tcount\n" - fi - sort -t $'\t' -k1,1n -k3,3nr -k2,2 \ - | cut -f2- - } \ - | colorize_rows_by_status_column 1 \ - | render_tsv_table - - echo - echo "All Jobs Overview:" - awk -F'\t' ' - function rank(st, lower) { - lower = tolower(st) - if (lower == "done") return 1 - if (lower == "running") return 2 - if (lower == "failed") return 3 - if (lower == "pending") return 4 - if (lower == "partial") return 5 - if (lower == "reference") return 6 - return 99 - } - function fmt_minutes(v, lower, n) { - lower = tolower(v) - if (v == "" || lower == "nan" || lower == "none" || lower == "-") return "nan" - n = (v + 0.0) / 60.0 - if (n < 0) n = 0 - return sprintf("%.2f", n) - } - BEGIN { - print "rank\tjob\tgroup\tgpu\tstate\tstatus\truns\telapsed_min\trerun\tanc_retry\toom_pred_fallback\trecovery\tseg\tanndata\txenium" - } - NR > 1 { - st = $5 - if (st == "") st = "pending" - status = $4 - if (status == "") status = "" - print rank(st) "\t" $1 "\t" $2 "\t" $3 "\t" st "\t" status "\t" $8 "\t" fmt_minutes($7) "\t" $9 "\t" $10 "\t" $11 "\t" $12 "\t" $13 "\t" $14 "\t" $15 - } - ' "${snapshot}" \ - | { - IFS= read -r header || true - if [[ -n "${header}" ]]; then - printf "job\tgroup\tgpu\tstate\tstatus\truns\telapsed_min\trerun\tanc_retry\toom_pred_fallback\trecovery\tseg\tanndata\txenium\n" - fi - sort -t $'\t' -k1,1n -k2,2 \ - | cut -f2- - } \ - | colorize_rows_by_state_column 4 \ - | render_tsv_table - - echo - echo "Model Parameterization:" - awk -F'\t' ' - function block_rank(block, lower) { - lower = tolower(block) - if (lower == "stability") return 1 - if (lower == "interaction") return 2 - if (lower == "stress") return 3 - if (block == "-") return 4 - return 5 - } - function model_label(job, lower) { - lower = tolower(job) - if (lower == "baseline") return "baseline" - if (index(lower, "stbl_baseline_") == 1) return "baseline_repeat" - if (index(lower, "stbl_anchor_") == 1) return "anchor_repeat" - if (index(lower, "stbl_sens_") == 1) return "sensitivity_repeat" - if (index(lower, "int_") == 1) return "interaction_ablation" - if (index(lower, "stress_") == 1) return "stress_test" - if (index(lower, "use3d_") == 1) return "ablation_use3d" - if (index(lower, "expansion_") == 1) return "ablation_expansion" - if (index(lower, 
"txk_") == 1) return "ablation_txk" - if (index(lower, "txdist_") == 1) return "ablation_txdist" - if (index(lower, "layers_") == 1) return "ablation_layers" - if (index(lower, "heads_") == 1) return "ablation_heads" - if (index(lower, "cellsmin_") == 1) return "ablation_cellsmin" - if (index(lower, "align_") == 1) return "ablation_alignment" - return "custom" - } - function col(name, fallback) { - if ((name in idx) && idx[name] > 0) return $(idx[name]) - return fallback - } - FNR == NR { - if (FNR == 1) { - for (i = 1; i <= NF; i++) { - if ($i == "job") snap_job_col = i - if ($i == "state") snap_state_col = i - if ($i == "status") snap_status_col = i - } - next - } - if (snap_job_col > 0) { - j = $snap_job_col - state_by_job[j] = (snap_state_col > 0 ? $snap_state_col : "") - status_by_job[j] = (snap_status_col > 0 ? $snap_status_col : "") - } - next - } - FNR == 1 { - for (i = 1; i <= NF; i++) idx[$i] = i - has_block = (("study_block" in idx) && idx["study_block"] > 0) ? 1 : 0 - print "rank\tjob\tmodel\tstudy_block\tgroup\tstate\tstatus\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss" - next - } - { - job = col("job", $1) - block = has_block ? col("study_block", "-") : "-" - group = col("group", "-") - state = state_by_job[job] - if (state == "") state = "pending" - status = status_by_job[job] - if (status == "") status = "" - - print block_rank(block) "\t" \ - job "\t" \ - model_label(job) "\t" \ - block "\t" \ - group "\t" \ - state "\t" \ - status "\t" \ - col("use_3d", "-") "\t" \ - col("expansion", "-") "\t" \ - col("tx_max_k", "-") "\t" \ - col("tx_max_dist", "-") "\t" \ - col("n_mid_layers", "-") "\t" \ - col("n_heads", "-") "\t" \ - col("cells_min_counts", "-") "\t" \ - col("min_qv", "-") "\t" \ - col("alignment_loss", "-") - } - ' "${snapshot}" "${PLAN_FILE}" \ - | { - IFS= read -r header || true - if [[ -n "${header}" ]]; then - printf "job\tmodel\tstudy_block\tgroup\tstate\tstatus\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\n" - fi - sort -t $'\t' -k1,1n -k2,2 \ - | cut -f2- - } \ - | colorize_rows_by_state_column 5 \ - | render_tsv_table - - echo - echo "Validation Metrics:" - if [[ -f "${VALIDATION_TSV}" ]] && [[ "$(awk 'END { print NR-1 }' "${VALIDATION_TSV}")" -gt 0 ]]; then - awk -F'\t' ' - function has_col(name) { - return (name in idx) && idx[name] > 0 - } - function get_col(name) { - if (has_col(name)) return $(idx[name]) - return "" - } - function fmt_float(v, lower) { - lower = tolower(v) - if (v == "" || lower == "nan" || lower == "none" || lower == "-") return "nan" - return sprintf("%.4f", v + 0.0) - } - function fmt_nonneg_int(v, lower, n) { - lower = tolower(v) - if (v == "" || lower == "nan" || lower == "none" || lower == "-") return "0" - n = v + 0.0 - if (n < 0) n = 0 - return sprintf("%.0f", n) - } - function fmt_minutes(v, lower, n) { - lower = tolower(v) - if (v == "" || lower == "nan" || lower == "none" || lower == "-") return "nan" - n = (v + 0.0) / 60.0 - if (n < 0) n = 0 - return sprintf("%.2f", n) - } - FNR == NR { - if (FNR == 1) { - for (i = 1; i <= NF; i++) { - if ($i == "job") snap_job_col = i - if ($i == "state") snap_state_col = i - if ($i == "elapsed_s") snap_elapsed_col = i - } - next - } - if (snap_job_col > 0) { - job_key = $snap_job_col - if (snap_state_col > 0) state_by_job[job_key] = $snap_state_col - if (snap_elapsed_col > 0) elapsed_by_job[job_key] = $snap_elapsed_col - } - next - } - FNR == 1 { - for (i = 1; i <= NF; 
i++) { - idx[$i] = i - } - print "job\tkind\tstate\tvalidate_status\tgpu_time_min v\tcells\tassigned_pct ^\tmecr v\tcontamination_pct v\tresolvi_contam_pct v\ttco ^\tdoublet_pct v" - next - } - { - job = get_col("job") - job_disp = job - group = get_col("group") - is_reference = get_col("is_reference") - reference_kind = get_col("reference_kind") - if (reference_kind != "" && reference_kind != "-") { - kind = reference_kind - } else if (is_reference == "1" || group == "R") { - kind = "reference" - } else { - kind = "segger" - } - if (job == "baseline" && kind == "segger") { - job_disp = "baseline*" - } - - state = state_by_job[job] - if (kind != "segger") { - state = "reference" - } else if (state == "") { - state = "" - } - - validate_status = get_col("validate_status") - if (validate_status == "") validate_status = "" - - gpu_time = get_col("gpu_time_s") - if (gpu_time == "") gpu_time = get_col("elapsed_s") - if (gpu_time == "") gpu_time = elapsed_by_job[job] - gpu_time = fmt_minutes(gpu_time) - - cells = get_col("cells") - if (cells == "") cells = get_col("cells_total") - if (cells == "") cells = get_col("cells_assigned") - cells = fmt_nonneg_int(cells) - - assigned = get_col("assigned_pct") - if (assigned == "") assigned = get_col("transcripts_assigned_pct") - - mecr = get_col("mecr") - if (mecr == "") mecr = get_col("mecr_fast") - - contamination = get_col("contamination_pct") - if (contamination == "") contamination = get_col("border_contaminated_cells_pct_fast") - - resolvi = get_col("resolvi_contamination_pct") - if (resolvi == "") resolvi = get_col("resolvi_contamination_pct_fast") - - tco = get_col("tco") - if (tco == "") tco = get_col("transcript_centroid_offset_fast") - - doublet = get_col("doublet_pct") - if (doublet == "") { - doublet = get_col("signal_doublet_like_fraction_fast") - if (doublet != "" && tolower(doublet) != "nan" && tolower(doublet) != "none") { - doublet = 100.0 * (doublet + 0.0) - } - } - - print job_disp "\t" kind "\t" state "\t" validate_status "\t" gpu_time "\t" cells "\t" \ - fmt_float(assigned) "\t" fmt_float(mecr) "\t" fmt_float(contamination) "\t" \ - fmt_float(resolvi) "\t" fmt_float(tco) "\t" fmt_float(doublet) - } - ' "${snapshot}" "${VALIDATION_TSV}" \ - | { - IFS= read -r header || true - if [[ -n "${header}" ]]; then - printf "%s\n" "${header}" - fi - awk -F'\t' -v b_on="${C_BOLD}" -v b_off="${C_BOLD_OFF}" ' - function is_num(v, lower) { - lower = tolower(v) - return !(v == "" || lower == "nan" || lower == "none" || lower == "-") - } - function to_num(v) { - return v + 0.0 - } - function update_top2_up(v, x) { - if (!is_num(v)) return - x = to_num(v) - if (!have1 || x > top1) { - top2 = top1 - have2 = have1 - top1 = x - have1 = 1 - } else if (!have2 || x > top2) { - top2 = x - have2 = 1 - } - } - function update_top2_down(v, x) { - if (!is_num(v)) return - x = to_num(v) - if (!have1 || x < top1) { - top2 = top1 - have2 = have1 - top1 = x - have1 = 1 - } else if (!have2 || x < top2) { - top2 = x - have2 = 1 - } - } - function is_top2_up(v, best1, best2, have_2, x) { - if (!is_num(v) || !is_num(best1)) return 0 - x = to_num(v) - if (!have_2) return (x == best1) - return (x >= best2) - } - function is_top2_down(v, best1, best2, have_2, x) { - if (!is_num(v) || !is_num(best1)) return 0 - x = to_num(v) - if (!have_2) return (x == best1) - return (x <= best2) - } - function norm_up(v, lo, hi) { - if (!is_num(v)) return "" - if (hi <= lo) return 1.0 - return (to_num(v) - lo) / (hi - lo) - } - function norm_down(v, lo, hi) { - if (!is_num(v)) return "" 
- if (hi <= lo) return 1.0 - return (hi - to_num(v)) / (hi - lo) - } - { - n++ - for (j = 1; j <= NF; j++) { - cell[n, j] = $j - } - nf[n] = NF - m_assigned[n] = $7 - m_mecr[n] = $8 - m_contam[n] = $9 - m_resolvi[n] = $10 - m_tco[n] = $11 - m_doublet[n] = $12 - m_gpu[n] = $5 - st[n] = tolower($4) - is_ref[n] = (tolower($3) == "reference") - - if (is_num($7)) { - v = to_num($7) - if (!has_a || v < min_a) min_a = v - if (!has_a || v > max_a) max_a = v - has_a = 1 - } - if (is_num($8)) { - v = to_num($8) - if (!has_m || v < min_m) min_m = v - if (!has_m || v > max_m) max_m = v - has_m = 1 - } - if (is_num($9)) { - v = to_num($9) - if (!has_c || v < min_c) min_c = v - if (!has_c || v > max_c) max_c = v - has_c = 1 - } - if (is_num($10)) { - v = to_num($10) - if (!has_r || v < min_r) min_r = v - if (!has_r || v > max_r) max_r = v - has_r = 1 - } - if (is_num($11)) { - v = to_num($11) - if (!has_t || v < min_t) min_t = v - if (!has_t || v > max_t) max_t = v - has_t = 1 - } - if (is_num($12)) { - v = to_num($12) - if (!has_d || v < min_d) min_d = v - if (!has_d || v > max_d) max_d = v - has_d = 1 - } - - if (st[n] == "ok") { - have1 = have_a_best1; top1 = a_best1; have2 = have_a_best2; top2 = a_best2 - update_top2_up($7) - have_a_best1 = have1; a_best1 = top1; have_a_best2 = have2; a_best2 = top2 - - have1 = have_m_best1; top1 = m_best1; have2 = have_m_best2; top2 = m_best2 - update_top2_down($8) - have_m_best1 = have1; m_best1 = top1; have_m_best2 = have2; m_best2 = top2 - - have1 = have_c_best1; top1 = c_best1; have2 = have_c_best2; top2 = c_best2 - update_top2_down($9) - have_c_best1 = have1; c_best1 = top1; have_c_best2 = have2; c_best2 = top2 - - have1 = have_r_best1; top1 = r_best1; have2 = have_r_best2; top2 = r_best2 - update_top2_down($10) - have_r_best1 = have1; r_best1 = top1; have_r_best2 = have2; r_best2 = top2 - - have1 = have_t_best1; top1 = t_best1; have2 = have_t_best2; top2 = t_best2 - update_top2_up($11) - have_t_best1 = have1; t_best1 = top1; have_t_best2 = have2; t_best2 = top2 - - have1 = have_d_best1; top1 = d_best1; have2 = have_d_best2; top2 = d_best2 - update_top2_down($12) - have_d_best1 = have1; d_best1 = top1; have_d_best2 = have2; d_best2 = top2 - - if (!is_ref[n]) { - have1 = have_g_best1; top1 = g_best1; have2 = have_g_best2; top2 = g_best2 - update_top2_down($5) - have_g_best1 = have1; g_best1 = top1; have_g_best2 = have2; g_best2 = top2 - } - } - } - END { - for (i = 1; i <= n; i++) { - # Rank rows by overall score across assigned/mecr/contam/tco/doublet. - # Rows with no numeric metrics (or non-ok status) stay at the bottom. 
- if (st[i] != "ok") { - score = -1e9 - } else { - score = 0.0 - cnt = 0 - - s = norm_up(m_assigned[i], min_a, max_a) - if (s != "") { score += s; cnt++ } - - s = norm_down(m_mecr[i], min_m, max_m) - if (s != "") { score += s; cnt++ } - - s = norm_down(m_contam[i], min_c, max_c) - if (s != "") { score += s; cnt++ } - - s = norm_up(m_tco[i], min_t, max_t) - if (s != "") { score += s; cnt++ } - - s = norm_down(m_doublet[i], min_d, max_d) - if (s != "") { score += s; cnt++ } - - if (cnt > 0) { - score /= cnt - } else { - score = -1e9 - } - } - - if (b_on != "" && st[i] == "ok") { - if (!is_ref[i] && is_top2_down(m_gpu[i], g_best1, g_best2, have_g_best2)) cell[i, 5] = b_on cell[i, 5] b_off - if (is_top2_up(m_assigned[i], a_best1, a_best2, have_a_best2)) cell[i, 7] = b_on cell[i, 7] b_off - if (is_top2_down(m_mecr[i], m_best1, m_best2, have_m_best2)) cell[i, 8] = b_on cell[i, 8] b_off - if (is_top2_down(m_contam[i], c_best1, c_best2, have_c_best2)) cell[i, 9] = b_on cell[i, 9] b_off - if (is_top2_down(m_resolvi[i], r_best1, r_best2, have_r_best2)) cell[i, 10] = b_on cell[i, 10] b_off - if (is_top2_up(m_tco[i], t_best1, t_best2, have_t_best2)) cell[i, 11] = b_on cell[i, 11] b_off - if (is_top2_down(m_doublet[i], d_best1, d_best2, have_d_best2)) cell[i, 12] = b_on cell[i, 12] b_off - } - - row = cell[i, 1] - for (j = 2; j <= nf[i]; j++) { - row = row "\t" cell[i, j] - } - gkey = 1 - if (tolower(cell[i, 3]) == "reference") gkey = 0 - printf "%d\t%.10f\t%s\n", gkey, score, row - } - } - ' \ - | sort -t $'\t' -k1,1n -k2,2gr -k3,3 \ - | cut -f3- \ - | awk -F'\t' ' - { - state = tolower($3) - if (!seen_data) { - seen_data = 1 - } - if (state == "reference") { - seen_ref = 1 - print - next - } - if (seen_ref && !inserted_sep) { - print "__ROW_SEP__" - inserted_sep = 1 - } - print - } - ' - } \ - | colorize_rows_by_state_column 3 \ - | render_tsv_table - else - echo "No validation TSV found at ${VALIDATION_TSV}" - echo "Run: bash scripts/build_benchmark_validation_table.sh --root ${ROOT}" - fi - - if [[ "${running_count}" -gt 0 ]]; then - echo - echo "Running Jobs:" - awk -F'\t' ' - BEGIN { - print "job\tgpu\tstatus\tstate\truns\tlog_file" - } - NR > 1 && $6 == "1" { - st = $4 - if (st == "") st = "" - print $1 "\t" $3 "\t" st "\t" $5 "\t" $8 "\t" $17 - } - ' "${snapshot}" \ - | colorize_rows_by_state_column 4 \ - | render_tsv_table - fi - - if [[ "${failed_count}" -gt 0 ]]; then - echo - echo "Failed Jobs:" - awk -F'\t' ' - BEGIN { - print "job\tgpu\tstatus\tstate\truns\trerun\tlog_file" - } - NR > 1 && $5 == "failed" { - st = $4 - if (st == "") st = "" - print $1 "\t" $3 "\t" st "\t" $5 "\t" $8 "\t" $9 "\t" $17 - } - ' "${snapshot}" \ - | colorize_rows_by_state_column 4 \ - | render_tsv_table - fi - - if [[ "${rerun_count}" -gt 0 ]]; then - echo - echo "Rerun/Retry Jobs:" - awk -F'\t' ' - BEGIN { - print "job\tstate\tstatus\truns\tanc_retry\toom_predict_fallback\trecovery_pass" - } - NR > 1 && $9 == "1" { - st = $4 - if (st == "") st = "" - print $1 "\t" $5 "\t" st "\t" $8 "\t" $10 "\t" $11 "\t" $12 - } - ' "${snapshot}" \ - | { - IFS= read -r header || true - if [[ -n "${header}" ]]; then - printf "%s\n" "${header}" - fi - awk -F'\t' ' - function rank(state, lower) { - lower = tolower(state) - if (lower == "running") return 1 - if (lower == "pending") return 2 - if (lower == "failed") return 3 - if (lower == "done") return 4 - if (lower == "partial") return 5 - return 99 - } - { - print rank($2) "\t" $0 - } - ' \ - | sort -t $'\t' -k1,1n -k2,2 \ - | cut -f2- - } \ - | colorize_rows_by_state_column 2 
\ - | render_tsv_table - fi -} - -snapshot_once() { - local tmp_status_map tmp_running - tmp_status_map="$(mktemp)" - tmp_running="$(mktemp)" - collect_running_jobs "${tmp_running}" - collect_status_map "${tmp_status_map}" "${tmp_running}" - build_snapshot "${tmp_status_map}" "${tmp_running}" "${OUT_TSV}" - rm -f "${tmp_status_map}" "${tmp_running}" -} - -if [[ "${WATCH_SEC}" -gt 0 ]]; then - while true; do - snapshot_once - if [[ -t 1 ]]; then - clear - fi - render_dashboard "${OUT_TSV}" - sleep "${WATCH_SEC}" - done -else - snapshot_once - render_dashboard "${OUT_TSV}" -fi diff --git a/scripts/build_benchmark_pdf_report.py b/scripts/build_benchmark_pdf_report.py deleted file mode 100644 index 89db660..0000000 --- a/scripts/build_benchmark_pdf_report.py +++ /dev/null @@ -1,1168 +0,0 @@ -#!/usr/bin/env python3 -"""Build a multi-page PDF report for benchmark validation metrics.""" - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path -from typing import Iterable, Sequence - -import numpy as np -import pandas as pd - -try: - import anndata as ad -except Exception: # pragma: no cover - ad = None - -try: - import matplotlib.pyplot as plt - from matplotlib.backends.backend_pdf import PdfPages - from matplotlib.patches import Polygon as MplPolygon -except Exception: # pragma: no cover - plt = None - PdfPages = None - MplPolygon = None - -try: - import polars as pl -except Exception: # pragma: no cover - pl = None - - -METRIC_SPECS = [ - ("assigned_pct", "assigned_ci95", "Assigned %", "up"), - ("mecr", "mecr_ci95", "MECR", "down"), - ("contamination_pct", "contamination_ci95", "Contamination %", "down"), - ("tco", "tco_ci95", "TCO", "up"), - ("doublet_pct", "doublet_ci95", "Doublet %", "down"), - ("gpu_time_min", None, "GPU time (min)", "down"), -] - -OVERALL_RANK_METRICS = [ - ("assigned_pct", True), - ("mecr", False), - ("contamination_pct", False), - ("tco", True), - ("doublet_pct", False), -] - -REFERENCE_COLORS = { - "10x_cell": "#c97a3d", - "10x_nucleus": "#e0ab66", - "ref_other": "#b68f6f", -} -SEGGER_SHADE_LIGHT = "#d8e6f4" -SEGGER_SHADE_DARK = "#0f4c81" - - -def _to_numeric(df: pd.DataFrame, cols: Iterable[str]) -> None: - for c in cols: - if c in df.columns: - df[c] = pd.to_numeric(df[c], errors="coerce") - - -def _safe_str(value: object) -> str: - if value is None: - return "" - return str(value) - - -def _is_reference_row(row: pd.Series) -> bool: - is_ref = _safe_str(row.get("is_reference", "0")).strip().lower() in {"1", "true", "yes"} - group = _safe_str(row.get("group", "")).strip().upper() - job = _safe_str(row.get("job", "")).strip().lower() - ref_kind = _safe_str(row.get("reference_kind", "")).strip() - return is_ref or group == "R" or job.startswith("ref_") or (ref_kind not in {"", "-", "nan", "None"}) - - -def _kind_label(row: pd.Series) -> str: - ref_kind = _safe_str(row.get("reference_kind", "")).strip() - if ref_kind and ref_kind not in {"-", "nan", "None"}: - return ref_kind - return "segger" - - -def _display_job_label(row: pd.Series) -> str: - job = _safe_str(row.get("job", "")).strip() - if job == "baseline": - return "baseline*" - if _is_reference_row(row): - kind = _kind_label(row) - if kind == "10x_cell": - return "10x cell (ref)" - if kind == "10x_nucleus": - return "10x nucleus (ref)" - if kind and kind != "segger": - return f"{kind} (ref)" - return job - - -def _hex_to_rgb(color_hex: str) -> tuple[float, float, float]: - c = color_hex.strip().lstrip("#") - if len(c) != 6: - return (0.5, 0.5, 0.5) - return tuple(int(c[i : i + 
2], 16) / 255.0 for i in (0, 2, 4)) - - -def _rgb_to_hex(rgb: Sequence[float]) -> str: - vals = [max(0, min(255, int(round(float(v) * 255.0)))) for v in rgb] - return f"#{vals[0]:02x}{vals[1]:02x}{vals[2]:02x}" - - -def _mix_hex(color_a: str, color_b: str, t: float) -> str: - t = float(max(0.0, min(1.0, t))) - a = _hex_to_rgb(color_a) - b = _hex_to_rgb(color_b) - return _rgb_to_hex(tuple((1.0 - t) * x + t * y for x, y in zip(a, b))) - - -def _build_gradient(light_hex: str, dark_hex: str, n: int) -> list[str]: - if n <= 0: - return [] - if n == 1: - return [dark_hex] - vals = np.linspace(0.0, 1.0, n) - return [_mix_hex(dark_hex, light_hex, float(v)) for v in vals] - - -def _normalize_metric(vals: np.ndarray, higher_is_better: bool) -> np.ndarray: - x = np.asarray(vals, dtype=float) - if x.size == 0: - return x - finite = np.isfinite(x) - out = np.full_like(x, np.nan) - if not np.any(finite): - return out - lo = np.nanmin(x[finite]) - hi = np.nanmax(x[finite]) - if not np.isfinite(lo) or not np.isfinite(hi): - return out - if hi <= lo: - out[finite] = 1.0 - else: - out[finite] = (x[finite] - lo) / (hi - lo) - if not higher_is_better: - out[finite] = 1.0 - out[finite] - return out - - -def _compute_overall_score(df: pd.DataFrame) -> pd.Series: - if df.empty: - return pd.Series(dtype=float) - mats = [] - for metric, hib in OVERALL_RANK_METRICS: - if metric not in df.columns: - mats.append(np.full(len(df), np.nan, dtype=float)) - continue - arr = pd.to_numeric(df[metric], errors="coerce").to_numpy(dtype=float) - mats.append(_normalize_metric(arr, hib)) - if not mats: - return pd.Series(np.nan, index=df.index, dtype=float) - stack = np.vstack(mats).T - with np.errstate(invalid="ignore"): - denom = np.sum(np.isfinite(stack), axis=1).astype(float) - numer = np.nansum(stack, axis=1) - score = np.divide(numer, denom, out=np.full_like(numer, np.nan), where=denom > 0) - return pd.Series(score, index=df.index, dtype=float) - - -def _rank_df(df: pd.DataFrame) -> pd.DataFrame: - out = df.copy() - if "_overall_score" not in out.columns: - out["_overall_score"] = _compute_overall_score(out) - for c in ["assigned_pct", "mecr"]: - if c not in out.columns: - out[c] = np.nan - out = out.sort_values( - by=["_overall_score", "assigned_pct", "mecr"], - ascending=[False, False, True], - na_position="last", - ) - return out - - -def _ensure_columns(df: pd.DataFrame) -> pd.DataFrame: - out = df.copy() - for col, _, _, _ in METRIC_SPECS: - if col not in out.columns: - out[col] = np.nan - for _, ci_col, _, _ in METRIC_SPECS: - if ci_col and ci_col not in out.columns: - out[ci_col] = np.nan - if "gpu_time_s" in out.columns and "gpu_time_min" not in out.columns: - out["gpu_time_min"] = pd.to_numeric(out["gpu_time_s"], errors="coerce") / 60.0 - if "gpu_time_min" not in out.columns: - out["gpu_time_min"] = np.nan - if "validate_status" not in out.columns: - out["validate_status"] = "" - if "job" not in out.columns: - out["job"] = "" - if "anndata_path" not in out.columns: - out["anndata_path"] = "" - if "segmentation_path" not in out.columns: - out["segmentation_path"] = "" - return out - - -def _assign_plot_colors(df: pd.DataFrame) -> pd.DataFrame: - out = df.copy() - out["_plot_color"] = "#4f6f8f" - out["_plot_label"] = out.apply(_display_job_label, axis=1) - out["_overall_score"] = _compute_overall_score(out) - - ref_mask = out["is_reference_row"].astype(bool) - for idx, row in out[ref_mask].iterrows(): - kind = _safe_str(row.get("kind", "")) - if kind == "10x_cell": - out.at[idx, "_plot_color"] = 
REFERENCE_COLORS["10x_cell"] - elif kind == "10x_nucleus": - out.at[idx, "_plot_color"] = REFERENCE_COLORS["10x_nucleus"] - else: - out.at[idx, "_plot_color"] = REFERENCE_COLORS["ref_other"] - - seg = out[~ref_mask].copy() - seg = _rank_df(seg) - shades = _build_gradient(SEGGER_SHADE_LIGHT, SEGGER_SHADE_DARK, len(seg)) - for i, (idx, _) in enumerate(seg.iterrows()): - out.at[idx, "_plot_color"] = shades[i] if i < len(shades) else SEGGER_SHADE_DARK - - return out - - -def _load_metrics(path: Path) -> pd.DataFrame: - df = pd.read_csv(path, sep="\t", dtype=str) - num_cols = [ - "gpu_time_s", - "cells", - "assigned_pct", - "assigned_ci95", - "mecr", - "mecr_ci95", - "contamination_pct", - "contamination_ci95", - "tco", - "tco_ci95", - "doublet_pct", - "doublet_ci95", - ] - _to_numeric(df, num_cols) - if "gpu_time_s" in df.columns: - df["gpu_time_min"] = df["gpu_time_s"] / 60.0 - else: - df["gpu_time_min"] = np.nan - if "validate_status" not in df.columns: - df["validate_status"] = "" - df["is_reference_row"] = df.apply(_is_reference_row, axis=1) - df["kind"] = df.apply(_kind_label, axis=1) - df = _ensure_columns(df) - df = _assign_plot_colors(df) - return df - - -def _ok_rows(df: pd.DataFrame) -> pd.DataFrame: - return df[df["validate_status"].astype(str).str.lower() == "ok"].copy() - - -def _ordered_ok_rows(df: pd.DataFrame) -> pd.DataFrame: - ok = _ok_rows(df) - if ok.empty: - return ok - refs = ok[ok["is_reference_row"]].copy() - refs["_ref_order"] = refs["kind"].map({"10x_cell": 0, "10x_nucleus": 1}).fillna(9) - refs = refs.sort_values(by=["_ref_order", "job"], ascending=[True, True]) - seg = _rank_df(ok[~ok["is_reference_row"]].copy()) - return pd.concat([refs, seg], axis=0, ignore_index=False) - - -def _apply_report_style() -> None: - if plt is None: - return - style_candidates = [ - Path(__file__).resolve().parents[2] / "segger-analysis" / "assets" / "paper.mplstyle", - Path(__file__).resolve().parents[1] / "assets" / "paper.mplstyle", - Path("../segger-analysis/assets/paper.mplstyle").resolve(), - ] - for style_path in style_candidates: - if style_path.exists(): - try: - plt.style.use(str(style_path)) - break - except Exception: - continue - - plt.rcParams.update( - { - "font.family": "sans-serif", - "font.sans-serif": ["Helvetica Neue", "Helvetica", "Arial", "DejaVu Sans"], - "axes.titlesize": 9, - "axes.labelsize": 8, - "xtick.labelsize": 7, - "ytick.labelsize": 7, - "axes.linewidth": 0.6, - "grid.linewidth": 0.45, - "grid.alpha": 0.16, - "figure.dpi": 220, - "savefig.dpi": 300, - "legend.frameon": False, - "legend.fontsize": 7, - } - ) - - -def _clean_axes(ax) -> None: - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - -def _metric_title(base: str, direction: str) -> str: - arrow = "^" if direction == "up" else "v" - return f"{base} {arrow}" - - -def _plot_bar_page(pdf: PdfPages, df: pd.DataFrame) -> None: - fig, axes = plt.subplots(2, 3, figsize=(11.6, 8.3)) - axes = axes.flatten() - - disp = _ordered_ok_rows(df) - if disp.empty: - fig.suptitle("Benchmark Comparison: no valid rows", fontsize=11) - for ax in axes: - ax.axis("off") - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - return - - y = np.arange(len(disp)) - labels = [_safe_str(x) for x in disp["_plot_label"].tolist()] - colors = [_safe_str(x) for x in disp["_plot_color"].tolist()] - - for ax, (metric, ci_col, title, direction) in zip(axes, METRIC_SPECS): - vals = pd.to_numeric(disp[metric], errors="coerce").to_numpy(dtype=float) - errs = ( - pd.to_numeric(disp[ci_col], 
errors="coerce").to_numpy(dtype=float) - if ci_col is not None - else np.full_like(vals, np.nan, dtype=float) - ) - valid = np.isfinite(vals) - if not np.any(valid): - ax.set_title(_metric_title(title, direction), fontsize=9) - ax.text(0.5, 0.5, "no data", transform=ax.transAxes, ha="center", va="center", fontsize=8) - ax.set_yticks([]) - ax.grid(False) - _clean_axes(ax) - continue - - val_v = vals[valid] - err_v = errs[valid] - err_plot = np.where(np.isfinite(err_v) & (err_v >= 0), err_v, 0.0) - color_v = np.asarray(colors, dtype=object)[valid] - y_v = y[valid] - - ax.barh( - y_v, - val_v, - xerr=err_plot, - color=color_v, - alpha=0.9, - edgecolor="none", - error_kw={ - "elinewidth": 0.6, - "capthick": 0.6, - "capsize": 1.8, - "ecolor": "#2f2f2f", - "alpha": 0.9, - }, - ) - ax.set_title(_metric_title(title, direction), fontsize=9) - ax.set_yticks(y) - ax.set_yticklabels(labels, fontsize=6.8) - ax.invert_yaxis() - ax.grid(axis="x") - ax.tick_params(axis="x", labelsize=7) - _clean_axes(ax) - - for ax in axes[len(METRIC_SPECS) :]: - ax.axis("off") - - fig.suptitle("Benchmark Overview (Segger shades + 10x references)", fontsize=11, y=0.995) - fig.tight_layout(rect=[0, 0, 1, 0.98]) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def _annotate_selected(ax, df: pd.DataFrame, x_col: str, y_col: str) -> None: - if df.empty: - return - refs = df[df["is_reference_row"]].copy() - seg = df[~df["is_reference_row"]].copy() - seg = _rank_df(seg) - - candidates = [] - if not refs.empty: - candidates.extend(refs.index.tolist()[:2]) - if not seg.empty: - candidates.extend(seg.index.tolist()[:3]) - - for idx in candidates: - row = df.loc[idx] - x = pd.to_numeric(pd.Series([row.get(x_col)]), errors="coerce").iloc[0] - y = pd.to_numeric(pd.Series([row.get(y_col)]), errors="coerce").iloc[0] - if not (np.isfinite(x) and np.isfinite(y)): - continue - ax.text(float(x), float(y), _safe_str(row.get("_plot_label", "")), fontsize=6.5, ha="left", va="bottom") - - -def _plot_scatter_page(pdf: PdfPages, df: pd.DataFrame) -> None: - ok = _ordered_ok_rows(df) - fig, axes = plt.subplots(1, 2, figsize=(11.5, 4.2)) - - for ax in axes: - _clean_axes(ax) - ax.grid(alpha=0.2) - - for _, row in ok.iterrows(): - color = _safe_str(row.get("_plot_color", "#4f6f8f")) - is_ref = bool(row.get("is_reference_row")) - marker = "s" if is_ref else "o" - size = 44 if is_ref else 32 - alpha = 0.95 if is_ref else 0.86 - - a = float(pd.to_numeric(pd.Series([row.get("assigned_pct")]), errors="coerce").iloc[0]) - c = float(pd.to_numeric(pd.Series([row.get("contamination_pct")]), errors="coerce").iloc[0]) - m = float(pd.to_numeric(pd.Series([row.get("mecr")]), errors="coerce").iloc[0]) - if np.isfinite(a) and np.isfinite(c): - axes[0].scatter(a, c, c=color, marker=marker, s=size, alpha=alpha, linewidths=0) - if np.isfinite(a) and np.isfinite(m): - axes[1].scatter(a, m, c=color, marker=marker, s=size, alpha=alpha, linewidths=0) - - axes[0].set_title("Sensitivity vs Contamination") - axes[0].set_xlabel("Assigned transcripts (%)") - axes[0].set_ylabel("Contamination (%) v") - axes[1].set_title("Sensitivity vs MECR") - axes[1].set_xlabel("Assigned transcripts (%)") - axes[1].set_ylabel("MECR v") - - _annotate_selected(axes[0], ok, "assigned_pct", "contamination_pct") - _annotate_selected(axes[1], ok, "assigned_pct", "mecr") - - fig.suptitle("Fast-Metric Trade-offs", fontsize=11) - fig.tight_layout(rect=[0, 0, 1, 0.95]) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def _plot_heatmap_page(pdf: PdfPages, df: pd.DataFrame) -> 
None: - ok = _ordered_ok_rows(df) - metric_defs = [ - ("assigned_pct", True), - ("mecr", False), - ("contamination_pct", False), - ("tco", True), - ("doublet_pct", False), - ("gpu_time_min", False), - ] - if ok.empty: - fig, ax = plt.subplots(figsize=(11, 6)) - ax.axis("off") - ax.text(0.5, 0.5, "No valid rows for heatmap", ha="center", va="center", fontsize=10) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - return - - arr = [] - labels = [] - for metric, hib in metric_defs: - vals = pd.to_numeric(ok[metric], errors="coerce").to_numpy(dtype=float) - arr.append(_normalize_metric(vals, hib)) - labels.append(metric) - mat = np.vstack(arr).T - mat_masked = np.ma.masked_invalid(mat) - - fig_h = max(4.0, 0.26 * len(ok) + 2.0) - fig, ax = plt.subplots(figsize=(10.3, fig_h)) - im = ax.imshow(mat_masked, aspect="auto", cmap="cividis", vmin=0.0, vmax=1.0) - ax.set_yticks(np.arange(len(ok))) - ax.set_yticklabels(ok["_plot_label"].astype(str).tolist(), fontsize=7) - ax.set_xticks(np.arange(len(labels))) - ax.set_xticklabels(labels, fontsize=8, rotation=30, ha="right") - ax.set_title("Metric Heatmap (normalized, higher is better)") - cbar = fig.colorbar(im, ax=ax, shrink=0.9, fraction=0.028, pad=0.015) - cbar.set_label("relative score", fontsize=8) - cbar.ax.tick_params(labelsize=7) - _clean_axes(ax) - fig.tight_layout() - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def _find_sgutils_src() -> Path | None: - candidates = [ - Path(__file__).resolve().parents[2] / "segger-analysis" / "src", - Path("../segger-analysis/src").resolve(), - Path.cwd().parent / "segger-analysis" / "src", - ] - for c in candidates: - if c.exists(): - return c - return None - - -def _enable_sgutils_import() -> bool: - src = _find_sgutils_src() - if src is None: - return False - src_str = str(src) - if src_str not in sys.path: - sys.path.insert(0, src_str) - return True - - -def _valid_umap_xy(xy: np.ndarray | None) -> np.ndarray | None: - if xy is None: - return None - arr = np.asarray(xy) - if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[1] < 2: - return None - finite = np.isfinite(arr[:, 0]) & np.isfinite(arr[:, 1]) - arr = arr[finite] - if arr.shape[0] == 0: - return None - return arr[:, :2] - - -def _downsample_adata(adata_obj, max_cells: int, seed: int): - if getattr(adata_obj, "n_obs", 0) <= max_cells: - return adata_obj - rng = np.random.default_rng(seed) - keep = rng.choice(np.arange(adata_obj.n_obs), size=max_cells, replace=False) - keep = np.sort(keep) - return adata_obj[keep, :].copy() - - -def _compute_umap_with_sgutils(adata_obj, seed: int) -> np.ndarray | None: - if not _enable_sgutils_import(): - return None - try: - from sg_utils.pp.preprocess_rapids import preprocess_rapids - except Exception: - return None - try: - work = adata_obj.copy() - if getattr(work, "n_obs", 0) < 30 or getattr(work, "n_vars", 0) < 20: - return None - if getattr(work, "raw", None) is None: - work.raw = work.copy() - preprocess_rapids( - work, - n_hvgs=min(2000, max(256, int(getattr(work, "n_vars", 2000)))), - pca_total_var=0.9, - knn_neighbors=min(15, max(5, int(getattr(work, "n_obs", 100) - 1))), - umap_n_epochs=300, - random_state=seed, - show_progress=False, - ) - return _valid_umap_xy(work.obsm.get("X_umap")) - except Exception: - return None - - -def _compute_umap_with_scanpy(adata_obj, seed: int) -> np.ndarray | None: - try: - import scanpy as sc - except Exception: - return None - try: - work = adata_obj.copy() - if getattr(work, "n_obs", 0) < 25 or getattr(work, "n_vars", 0) < 15: - return None - 
sc.pp.filter_cells(work, min_counts=1) - sc.pp.filter_genes(work, min_counts=1) - if work.n_obs < 25 or work.n_vars < 15: - return None - - sc.pp.normalize_total(work, target_sum=1e4) - sc.pp.log1p(work) - - if work.n_vars > 80: - n_top = min(2000, max(80, int(0.5 * work.n_vars))) - sc.pp.highly_variable_genes(work, n_top_genes=n_top, flavor="seurat") - if "highly_variable" in work.var.columns: - hv_mask = np.asarray(work.var["highly_variable"]).astype(bool) - if int(hv_mask.sum()) >= 20: - work = work[:, hv_mask].copy() - - n_comps = min(35, work.n_obs - 1, work.n_vars - 1) - if n_comps < 2: - return None - sc.pp.pca(work, n_comps=n_comps) - n_neighbors = min(15, max(5, work.n_obs - 1)) - sc.pp.neighbors(work, n_neighbors=n_neighbors, n_pcs=min(20, n_comps)) - sc.tl.umap(work, min_dist=0.35, spread=1.0, random_state=seed) - return _valid_umap_xy(work.obsm.get("X_umap")) - except Exception: - return None - - -def _umap_points_for_path(anndata_path: Path, seed: int, max_cells: int) -> tuple[np.ndarray | None, str]: - if ad is None: - return None, "anndata_missing" - if not anndata_path.exists(): - return None, "missing_h5ad" - try: - adata_obj = ad.read_h5ad(anndata_path) - except Exception: - return None, "read_h5ad_failed" - - adata_obj = _downsample_adata(adata_obj, max_cells=max_cells, seed=seed) - xy = _valid_umap_xy(adata_obj.obsm.get("X_umap")) - if xy is not None: - return xy, "precomputed" - - xy = _compute_umap_with_sgutils(adata_obj, seed=seed) - if xy is not None: - return xy, "sgutils" - - xy = _compute_umap_with_scanpy(adata_obj, seed=seed) - if xy is not None: - return xy, "scanpy" - - return None, "umap_unavailable" - - -def _pick_umap_rows(df: pd.DataFrame) -> list[pd.Series]: - ok = _ordered_ok_rows(df) - ok = ok[ok["anndata_path"].astype(str).str.strip() != ""].copy() - if ok.empty: - return [] - - refs = ok[ok["is_reference_row"]].copy() - refs["_ref_order"] = refs["kind"].map({"10x_cell": 0, "10x_nucleus": 1}).fillna(9) - refs = refs.sort_values(by=["_ref_order", "job"]).head(2) - - seg = _rank_df(ok[~ok["is_reference_row"]].copy()) - best = seg.head(2) - worst = seg.tail(2) - - picked: list[pd.Series] = [] - seen_jobs: set[str] = set() - - def _push_rows(sub_df: pd.DataFrame) -> None: - for _, row in sub_df.iterrows(): - job = _safe_str(row.get("job", "")) - if job in seen_jobs: - continue - seen_jobs.add(job) - picked.append(row) - - _push_rows(refs) - _push_rows(best) - _push_rows(worst) - - if len(picked) < 6: - _push_rows(seg) - if len(picked) < 6: - _push_rows(refs) - - return picked[:6] - - -def _panel_title_from_row(row: pd.Series) -> str: - label = _safe_str(row.get("_plot_label", "")).strip() - if label: - return label - return _safe_str(row.get("job", "")) - - -def _plot_umap_panel(ax, row: pd.Series, seed: int, max_cells: int, cache: dict[str, tuple[np.ndarray | None, str]]) -> None: - _clean_axes(ax) - ax.set_xticks([]) - ax.set_yticks([]) - - title = _panel_title_from_row(row) - path = Path(_safe_str(row.get("anndata_path", "")).strip()) - color = _safe_str(row.get("_plot_color", "#4f6f8f")) - cache_key = str(path.resolve()) if path.as_posix() not in {"", "."} else _safe_str(row.get("job", "")) - - if cache_key in cache: - xy, source = cache[cache_key] - else: - xy, source = _umap_points_for_path(path, seed=seed, max_cells=max_cells) - cache[cache_key] = (xy, source) - - ax.set_title(title, fontsize=8.2) - if xy is None: - ax.text(0.5, 0.5, f"-- error: UMAP missing for {title}", ha="center", va="center", fontsize=7.2) - return - - ax.scatter(xy[:, 0], 
xy[:, 1], s=1.8, c=color, alpha=0.72, linewidths=0, rasterized=True) - ax.text( - 0.02, - 0.03, - f"n={xy.shape[0]} | {source}", - transform=ax.transAxes, - ha="left", - va="bottom", - fontsize=6.2, - color="#555555", - ) - ax.set_aspect("equal", adjustable="box") - - -def _plot_umap_page(pdf: PdfPages, df: pd.DataFrame, seed: int, umap_max_cells: int) -> None: - picks = _pick_umap_rows(df) - fig, axes = plt.subplots(2, 3, figsize=(11.2, 7.6)) - axes = axes.flatten() - cache: dict[str, tuple[np.ndarray | None, str]] = {} - - for i, ax in enumerate(axes): - if i < len(picks): - _plot_umap_panel(ax, picks[i], seed=seed + i, max_cells=umap_max_cells, cache=cache) - else: - ax.axis("off") - ax.text(0.5, 0.5, "no panel", ha="center", va="center", fontsize=8) - - fig.suptitle("UMAP Panels: 2 references + 2 best + 2 worst", fontsize=11) - fig.tight_layout(rect=[0, 0, 1, 0.96]) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def _load_transcript_xy(input_dir: Path) -> pd.DataFrame | None: - tx_path = input_dir / "transcripts.parquet" - if not tx_path.exists(): - return None - - if pl is not None: - try: - lf = pl.scan_parquet(tx_path, parallel="row_groups") - cols = lf.collect_schema().names() - if "row_index" in cols: - lf = lf.with_columns(pl.col("row_index").cast(pl.Int64)) - else: - lf = lf.with_row_index(name="row_index") - - x_col = "x_location" if "x_location" in cols else ("x" if "x" in cols else None) - y_col = "y_location" if "y_location" in cols else ("y" if "y" in cols else None) - if x_col is None or y_col is None: - return None - - tx = ( - lf.select(["row_index", x_col, y_col]) - .rename({x_col: "x", y_col: "y"}) - .collect() - .to_pandas() - ) - tx["row_index"] = pd.to_numeric(tx["row_index"], errors="coerce").astype("Int64") - tx["x"] = pd.to_numeric(tx["x"], errors="coerce") - tx["y"] = pd.to_numeric(tx["y"], errors="coerce") - return tx.dropna(subset=["row_index", "x", "y"]).copy() - except Exception: - pass - - try: - tx = pd.read_parquet(tx_path) - except Exception: - return None - if "row_index" not in tx.columns: - tx = tx.copy() - tx["row_index"] = np.arange(len(tx), dtype=np.int64) - x_col = "x_location" if "x_location" in tx.columns else ("x" if "x" in tx.columns else None) - y_col = "y_location" if "y_location" in tx.columns else ("y" if "y" in tx.columns else None) - if x_col is None or y_col is None: - return None - out = tx[["row_index", x_col, y_col]].rename(columns={x_col: "x", y_col: "y"}) - out["row_index"] = pd.to_numeric(out["row_index"], errors="coerce").astype("Int64") - out["x"] = pd.to_numeric(out["x"], errors="coerce") - out["y"] = pd.to_numeric(out["y"], errors="coerce") - return out.dropna(subset=["row_index", "x", "y"]).copy() - - -def _load_seg_assign(seg_path: Path) -> pd.DataFrame | None: - if not seg_path.exists(): - return None - - id_col_candidates = [ - "segger_cell_id", - "cell_id", - "xenium_cell_id", - "tenx_cell_id", - ] - - if pl is not None: - try: - lf = pl.scan_parquet(seg_path) - cols = lf.collect_schema().names() - id_col = None - for c in id_col_candidates: - if c in cols: - id_col = c - break - if id_col is None: - for c in cols: - if c.endswith("_cell_id"): - id_col = c - break - if id_col is None or "row_index" not in cols: - return None - - df = lf.select(["row_index", id_col]).collect().to_pandas() - df.columns = ["row_index", "segger_cell_id"] - df["row_index"] = pd.to_numeric(df["row_index"], errors="coerce").astype("Int64") - return df.dropna(subset=["row_index"]).copy() - except Exception: - pass - - try: - df = 
pd.read_parquet(seg_path) - except Exception: - return None - if "row_index" not in df.columns: - return None - id_col = None - for c in id_col_candidates: - if c in df.columns: - id_col = c - break - if id_col is None: - for c in df.columns: - if str(c).endswith("_cell_id"): - id_col = c - break - if id_col is None: - return None - out = df[["row_index", id_col]].copy() - out.columns = ["row_index", "segger_cell_id"] - out["row_index"] = pd.to_numeric(out["row_index"], errors="coerce").astype("Int64") - return out.dropna(subset=["row_index"]).copy() - - -def _is_assigned(series: pd.Series) -> pd.Series: - s = series.astype(str).str.strip() - return series.notna() & ~s.eq("") & ~s.str.upper().eq("UNASSIGNED") & ~s.str.lower().eq("nan") - - -def _select_small_fovs( - tx: pd.DataFrame, - n: int, - max_tx: int, - min_tx: int, - seed: int, -) -> list[tuple[float, float, float, float, int]]: - if tx.empty: - return [] - x = pd.to_numeric(tx["x"], errors="coerce").to_numpy(dtype=float) - y = pd.to_numeric(tx["y"], errors="coerce").to_numpy(dtype=float) - finite = np.isfinite(x) & np.isfinite(y) - x = x[finite] - y = y[finite] - if x.size == 0: - return [] - - n_bins = int(np.clip(np.sqrt(max(x.size / max(max_tx, 1), 8)) * 8, 24, 80)) - counts, xedges, yedges = np.histogram2d(x, y, bins=[n_bins, n_bins]) - - candidates = [] - for i in range(n_bins): - for j in range(n_bins): - c = int(counts[i, j]) - if min_tx <= c <= max_tx: - candidates.append((i, j, c)) - - if not candidates: - fallback = [] - for i in range(n_bins): - for j in range(n_bins): - c = int(counts[i, j]) - if c > 0: - fallback.append((i, j, c)) - fallback = sorted(fallback, key=lambda x: x[2]) - candidates = fallback[: max(1, n)] - - rng = np.random.default_rng(seed) - if len(candidates) > n: - order = rng.permutation(len(candidates)) - candidates = [candidates[k] for k in order[:n]] - - windows = [] - for i, j, c in candidates[:n]: - x0, x1 = float(xedges[i]), float(xedges[i + 1]) - y0, y1 = float(yedges[j]), float(yedges[j + 1]) - windows.append((x0, x1, y0, y1, int(c))) - return windows - - -def _convex_hull(points: np.ndarray) -> np.ndarray | None: - if points.ndim != 2 or points.shape[1] != 2: - return None - pts = np.unique(points, axis=0) - if pts.shape[0] < 3: - return None - pts = pts[np.lexsort((pts[:, 1], pts[:, 0]))] - - def cross(o, a, b) -> float: - return float((a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])) - - lower: list[np.ndarray] = [] - for p in pts: - while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0: - lower.pop() - lower.append(p) - - upper: list[np.ndarray] = [] - for p in pts[::-1]: - while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0: - upper.pop() - upper.append(p) - - hull = np.vstack((lower[:-1], upper[:-1])) - if hull.shape[0] < 3: - return None - return hull - - -def _pick_mask_rows(df: pd.DataFrame) -> list[pd.Series]: - ok = _ordered_ok_rows(df) - ok = ok[ok["segmentation_path"].astype(str).str.strip() != ""].copy() - if ok.empty: - return [] - - refs = ok[ok["is_reference_row"]].copy() - refs["_ref_order"] = refs["kind"].map({"10x_cell": 0, "10x_nucleus": 1}).fillna(9) - refs = refs.sort_values(by=["_ref_order", "job"]).head(2) - seg = _rank_df(ok[~ok["is_reference_row"]].copy()) - - picks: list[pd.Series] = [] - seen: set[str] = set() - for _, row in refs.iterrows(): - job = _safe_str(row.get("job", "")) - if job not in seen: - seen.add(job) - picks.append(row) - for sub in [seg.head(1), seg.tail(1), seg.head(3)]: - for _, row in sub.iterrows(): - job = 
_safe_str(row.get("job", "")) - if job not in seen: - seen.add(job) - picks.append(row) - if len(picks) >= 4: - break - if len(picks) >= 4: - break - return picks[:4] - - -def _plot_mask_panel(ax, sub: pd.DataFrame, base_color: str, title: str) -> None: - _clean_axes(ax) - ax.set_title(title, fontsize=7.2, pad=3.0) - ax.set_xticks([]) - ax.set_yticks([]) - - if sub.empty: - ax.text(0.5, 0.5, "no transcripts", transform=ax.transAxes, ha="center", va="center", fontsize=7) - return - - assigned_mask = _is_assigned(sub["segger_cell_id"]) - un = sub[~assigned_mask] - asn = sub[assigned_mask] - if not un.empty: - ax.scatter(un["x"], un["y"], s=1.0, c="#d3d3d3", alpha=0.22, linewidths=0, rasterized=True, zorder=1) - - n_cells = 0 - if not asn.empty and MplPolygon is not None: - for i, (_, grp) in enumerate(asn.groupby("segger_cell_id", sort=False)): - if len(grp) < 4: - continue - hull = _convex_hull(grp[["x", "y"]].to_numpy(dtype=float)) - if hull is None: - continue - t = (i % 8) / 8.0 - face = _mix_hex(base_color, "#ffffff", 0.18 + 0.35 * t) - edge = _mix_hex(base_color, "#0b2038", 0.35) - patch = MplPolygon( - hull, - closed=True, - facecolor=face, - edgecolor=edge, - linewidth=0.28, - alpha=0.48, - zorder=2, - ) - ax.add_patch(patch) - n_cells += 1 - - ax.scatter(asn["x"], asn["y"], s=1.1, c=base_color, alpha=0.44, linewidths=0, rasterized=True, zorder=3) - - ax.text( - 0.02, - 0.02, - f"tx={len(sub)} | cells={n_cells}", - transform=ax.transAxes, - ha="left", - va="bottom", - fontsize=6.2, - color="#5b5b5b", - ) - ax.set_aspect("equal", adjustable="box") - - -def _plot_fov_page( - pdf: PdfPages, - df: pd.DataFrame, - input_dir: Path | None, - seed: int, - fov_count: int, - fov_max_tx: int, - fov_min_tx: int, -) -> None: - if input_dir is None: - fig, ax = plt.subplots(figsize=(11, 3.8)) - ax.axis("off") - ax.text(0.5, 0.5, "FOV panels skipped: --input-dir not provided", ha="center", va="center", fontsize=10) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - return - - tx = _load_transcript_xy(input_dir) - if tx is None or tx.empty: - fig, ax = plt.subplots(figsize=(11, 3.8)) - ax.axis("off") - ax.text(0.5, 0.5, "FOV panels skipped: transcripts.parquet unavailable", ha="center", va="center", fontsize=10) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - return - - rows = _pick_mask_rows(df) - if not rows: - fig, ax = plt.subplots(figsize=(11, 3.8)) - ax.axis("off") - ax.text(0.5, 0.5, "FOV panels skipped: no usable segmentation rows", ha="center", va="center", fontsize=10) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - return - - windows = _select_small_fovs( - tx=tx, - n=max(1, fov_count), - max_tx=max(100, fov_max_tx), - min_tx=max(10, fov_min_tx), - seed=seed, - ) - if not windows: - fig, ax = plt.subplots(figsize=(11, 3.8)) - ax.axis("off") - ax.text(0.5, 0.5, "FOV panels skipped: unable to find small windows", ha="center", va="center", fontsize=10) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - return - - tx_fovs = [] - for x0, x1, y0, y1, n_tx in windows: - sub = tx[(tx["x"] >= x0) & (tx["x"] <= x1) & (tx["y"] >= y0) & (tx["y"] <= y1)].copy() - tx_fovs.append((x0, x1, y0, y1, n_tx, sub)) - - nrows = len(rows) - ncols = len(tx_fovs) - fig_w = max(7.0, 3.9 * ncols) - fig_h = max(5.2, 2.7 * nrows) - fig, axes = plt.subplots(nrows, ncols, figsize=(fig_w, fig_h), squeeze=False) - - for r, row in enumerate(rows): - seg_path = Path(_safe_str(row.get("segmentation_path", "")).strip()) - seg_df = _load_seg_assign(seg_path) - if seg_df is None or 
seg_df.empty: - for c in range(ncols): - ax = axes[r, c] - ax.axis("off") - ax.text(0.5, 0.5, f"missing segmentation\n{_panel_title_from_row(row)}", ha="center", va="center", fontsize=7) - continue - seg_map = seg_df.set_index("row_index")["segger_cell_id"] - base_color = _safe_str(row.get("_plot_color", "#4f6f8f")) - method_name = _panel_title_from_row(row) - - for c, (x0, x1, y0, y1, n_tx, sub_tx) in enumerate(tx_fovs): - ax = axes[r, c] - sub = sub_tx.copy() - sub["segger_cell_id"] = sub["row_index"].map(seg_map) - title = f"{method_name} | FOV{c+1} ({n_tx} tx)" - _plot_mask_panel(ax, sub, base_color=base_color, title=title) - ax.set_xlim(x0, x1) - ax.set_ylim(y0, y1) - - fig.suptitle(f"Cell-mask FOV panels (convex hulls, < {fov_max_tx} transcripts/FOV)", fontsize=11, y=0.995) - fig.tight_layout(rect=[0, 0, 1, 0.98]) - pdf.savefig(fig, bbox_inches="tight") - plt.close(fig) - - -def build_report( - root: Path, - validation_tsv: Path, - out_pdf: Path, - input_dir: Path | None, - seed: int, - fov_count: int, - fov_max_tx: int, - fov_min_tx: int, - umap_max_cells: int, -) -> None: - if plt is None or PdfPages is None: - raise RuntimeError("matplotlib is required for PDF report generation") - if not validation_tsv.exists(): - raise FileNotFoundError(f"Validation TSV not found: {validation_tsv}") - - df = _load_metrics(validation_tsv) - out_pdf.parent.mkdir(parents=True, exist_ok=True) - _apply_report_style() - - with PdfPages(out_pdf) as pdf: - _plot_bar_page(pdf, df) - _plot_scatter_page(pdf, df) - _plot_heatmap_page(pdf, df) - _plot_umap_page(pdf, df, seed=seed, umap_max_cells=max(1000, umap_max_cells)) - _plot_fov_page( - pdf, - df, - input_dir=input_dir, - seed=seed, - fov_count=max(1, fov_count), - fov_max_tx=max(50, fov_max_tx), - fov_min_tx=max(5, min(fov_min_tx, fov_max_tx)), - ) - - -def main() -> int: - parser = argparse.ArgumentParser(description="Build benchmark multi-page PDF report") - parser.add_argument("--root", type=Path, default=Path("./results/mossi_main_big_benchmark_nightly")) - parser.add_argument("--validation-tsv", type=Path, default=None) - parser.add_argument("--out-pdf", type=Path, default=None) - parser.add_argument("--input-dir", type=Path, default=None) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--fov-count", type=int, default=2) - parser.add_argument("--fov-max-transcripts", type=int, default=2000) - parser.add_argument("--fov-min-transcripts", type=int, default=150) - parser.add_argument("--umap-max-cells", type=int, default=12000) - args = parser.parse_args() - - root = args.root - validation_tsv = args.validation_tsv or (root / "summaries" / "validation_metrics.tsv") - out_pdf = args.out_pdf or (root / "summaries" / "benchmark_report.pdf") - - build_report( - root=root, - validation_tsv=validation_tsv, - out_pdf=out_pdf, - input_dir=args.input_dir, - seed=args.seed, - fov_count=max(1, args.fov_count), - fov_max_tx=max(50, args.fov_max_transcripts), - fov_min_tx=max(5, args.fov_min_transcripts), - umap_max_cells=max(1000, args.umap_max_cells), - ) - print(f"Wrote report: {out_pdf}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/build_benchmark_validation_table.sh b/scripts/build_benchmark_validation_table.sh deleted file mode 100755 index 64f33ba..0000000 --- a/scripts/build_benchmark_validation_table.sh +++ /dev/null @@ -1,783 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -usage() { - cat <<'EOF' -Build per-job validation metrics table for a benchmark root. 
- -Usage: - bash scripts/build_benchmark_validation_table.sh [options] - -Options: - --root Benchmark root (default: ./results/mossi_main_big_benchmark_nightly) - --input-dir Source dataset path for contamination/geometry/doublet metrics (optional) - --out-tsv Output TSV (default: /summaries/validation_metrics.tsv) - --recompute Recompute all jobs even if already present in output TSV - --segger-bin Segger executable/command (default: segger) - --me-gene-pairs-path Optional ME-gene pair file passed to segger validate - --scrna-reference-path Optional scRNA h5ad passed to segger validate - --scrna-celltype-column scRNA cell type column (default: cell_type) - --max-me-gene-pairs Max sampled ME-gene pairs (default: 500) - --gpu-a GPU id used for group A labels (default: env GPU_A or 0) - --gpu-b GPU id used for group B labels (default: env GPU_B or 1) - --include-default-10x Include ref_10x_cell/ref_10x_nucleus rows (default: true) - --reference-universe-seg Canonical Segger universe segmentation override - -h, --help Show this help -EOF -} - -timestamp() { - date '+%Y-%m-%d %H:%M:%S' -} - -sanitize_tsv_field() { - local value="${1:-}" - value="${value//$'\t'/ }" - value="${value//$'\r'/ }" - value="${value//$'\n'/ }" - printf '%s' "${value}" -} - -normalize_token() { - local value="${1:-}" - value="$(sanitize_tsv_field "${value}")" - if [[ -z "${value}" ]]; then - printf '%s' "-" - else - printf '%s' "${value}" - fi -} - -normalize_bool() { - local v="${1:-}" - local lc - lc="$(printf '%s' "${v}" | tr '[:upper:]' '[:lower:]')" - case "${lc}" in - 1|true|t|yes|y|on) - printf 'true' - ;; - 0|false|f|no|n|off) - printf 'false' - ;; - *) - return 1 - ;; - esac -} - -ROOT="./results/mossi_main_big_benchmark_nightly" -INPUT_DIR="" -OUT_TSV="" -SEGGER_BIN="segger" -ME_GENE_PAIRS_PATH="" -SCRNA_REFERENCE_PATH="" -SCRNA_CELLTYPE_COLUMN="cell_type" -MAX_ME_GENE_PAIRS=500 -GPU_A="${GPU_A:-0}" -GPU_B="${GPU_B:-1}" -INCLUDE_DEFAULT_10X="true" -REFERENCE_UNIVERSE_SEG="" -RECOMPUTE=0 - -require_value() { - local opt="$1" - if [[ $# -lt 2 ]] || [[ -z "${2}" ]] || [[ "${2}" == -* ]]; then - echo "ERROR: ${opt} requires a value." >&2 - exit 1 - fi -} - -while [[ $# -gt 0 ]]; do - case "$1" in - --root) - require_value "$1" "${2-}" - ROOT="$2" - shift 2 - ;; - --input-dir) - require_value "$1" "${2-}" - INPUT_DIR="$2" - shift 2 - ;; - --out-tsv) - require_value "$1" "${2-}" - OUT_TSV="$2" - shift 2 - ;; - --recompute) - RECOMPUTE=1 - shift - ;; - --segger-bin) - require_value "$1" "${2-}" - SEGGER_BIN="$2" - shift 2 - ;; - --me-gene-pairs-path) - require_value "$1" "${2-}" - ME_GENE_PAIRS_PATH="$2" - shift 2 - ;; - --scrna-reference-path) - require_value "$1" "${2-}" - SCRNA_REFERENCE_PATH="$2" - shift 2 - ;; - --scrna-celltype-column) - require_value "$1" "${2-}" - SCRNA_CELLTYPE_COLUMN="$2" - shift 2 - ;; - --max-me-gene-pairs) - require_value "$1" "${2-}" - MAX_ME_GENE_PAIRS="$2" - shift 2 - ;; - --gpu-a) - require_value "$1" "${2-}" - GPU_A="$2" - shift 2 - ;; - --gpu-b) - require_value "$1" "${2-}" - GPU_B="$2" - shift 2 - ;; - --include-default-10x) - require_value "$1" "${2-}" - INCLUDE_DEFAULT_10X="$2" - shift 2 - ;; - --reference-universe-seg) - require_value "$1" "${2-}" - REFERENCE_UNIVERSE_SEG="$2" - shift 2 - ;; - -h|--help) - usage - exit 0 - ;; - *) - echo "Unknown argument: $1" >&2 - usage - exit 1 - ;; - esac -done - -if [[ -z "${OUT_TSV}" ]]; then - OUT_TSV="${ROOT}/summaries/validation_metrics.tsv" -fi - -if ! 
[[ "${MAX_ME_GENE_PAIRS}" =~ ^[0-9]+$ ]] || [[ "${MAX_ME_GENE_PAIRS}" -le 0 ]]; then - echo "ERROR: --max-me-gene-pairs must be a positive integer." >&2 - exit 1 -fi - -if ! INCLUDE_DEFAULT_10X="$(normalize_bool "${INCLUDE_DEFAULT_10X}")"; then - echo "ERROR: --include-default-10x must be true/false." >&2 - exit 1 -fi - -if [[ -n "${INPUT_DIR}" ]] && [[ ! -e "${INPUT_DIR}" ]]; then - echo "ERROR: --input-dir not found: ${INPUT_DIR}" >&2 - exit 1 -fi - -if [[ -n "${ME_GENE_PAIRS_PATH}" ]] && [[ ! -f "${ME_GENE_PAIRS_PATH}" ]]; then - echo "ERROR: --me-gene-pairs-path not found: ${ME_GENE_PAIRS_PATH}" >&2 - exit 1 -fi - -if [[ -n "${SCRNA_REFERENCE_PATH}" ]] && [[ ! -f "${SCRNA_REFERENCE_PATH}" ]]; then - echo "ERROR: --scrna-reference-path not found: ${SCRNA_REFERENCE_PATH}" >&2 - exit 1 -fi - -if [[ -n "${REFERENCE_UNIVERSE_SEG}" ]] && [[ ! -f "${REFERENCE_UNIVERSE_SEG}" ]]; then - echo "ERROR: --reference-universe-seg not found: ${REFERENCE_UNIVERSE_SEG}" >&2 - exit 1 -fi - -SEGGER_CMD_PATH="" -if [[ "${SEGGER_BIN}" == */* ]]; then - if [[ ! -x "${SEGGER_BIN}" ]]; then - echo "ERROR: --segger-bin is not executable: ${SEGGER_BIN}" >&2 - exit 1 - fi - SEGGER_CMD_PATH="${SEGGER_BIN}" -else - if ! SEGGER_CMD_PATH="$(command -v "${SEGGER_BIN}" 2>/dev/null)"; then - echo "ERROR: segger command not found: ${SEGGER_BIN}" >&2 - exit 1 - fi -fi - -SEGGER_PYTHON="$(dirname "${SEGGER_CMD_PATH}")/python" -if [[ ! -x "${SEGGER_PYTHON}" ]]; then - if command -v python3 >/dev/null 2>&1; then - SEGGER_PYTHON="$(command -v python3)" - else - echo "ERROR: Could not resolve python interpreter for reference artifact builder." >&2 - exit 1 - fi -fi - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -REFERENCE_BUILDER_PY="${SCRIPT_DIR}/build_default_10x_reference_artifacts.py" -if [[ "${INCLUDE_DEFAULT_10X}" == "true" ]] && [[ ! -f "${REFERENCE_BUILDER_PY}" ]]; then - echo "ERROR: Missing reference artifact builder: ${REFERENCE_BUILDER_PY}" >&2 - exit 1 -fi - -PLAN_FILE="${ROOT}/job_plan.tsv" -RUNS_DIR="${ROOT}/runs" -EXPORTS_DIR="${ROOT}/exports" -SUMMARY_DIR="${ROOT}/summaries" -LOG_FILE="${SUMMARY_DIR}/validation_metrics.log" -REFERENCE_ARTIFACTS_DIR="${SUMMARY_DIR}/reference_artifacts" - -if [[ ! 
-f "${PLAN_FILE}" ]]; then - echo "ERROR: Missing benchmark plan file: ${PLAN_FILE}" >&2 - exit 1 -fi - -mkdir -p "${SUMMARY_DIR}" "$(dirname "${OUT_TSV}")" "${REFERENCE_ARTIFACTS_DIR}" - -tmp_out="$(mktemp)" -tmp_row="$(mktemp "${TMPDIR:-/tmp}/segger_validate_row.XXXXXX.tsv")" -trap 'rm -f "${tmp_out}" "${tmp_row}"' EXIT - -METRIC_SCHEMA_VERSION="2026-02-25-v8" -RUN_INPUT_DIR_TOKEN="$(normalize_token "${INPUT_DIR}")" -RUN_SCRNA_REFERENCE_TOKEN="$(normalize_token "${SCRNA_REFERENCE_PATH}")" -RUN_ME_GENE_PAIRS_TOKEN="$(normalize_token "${ME_GENE_PAIRS_PATH}")" - -REFERENCE_UNIVERSE_SEG_RESOLVED="" -if [[ -n "${REFERENCE_UNIVERSE_SEG}" ]]; then - REFERENCE_UNIVERSE_SEG_RESOLVED="${REFERENCE_UNIVERSE_SEG}" -elif [[ -f "${RUNS_DIR}/baseline/segger_segmentation.parquet" ]]; then - REFERENCE_UNIVERSE_SEG_RESOLVED="${RUNS_DIR}/baseline/segger_segmentation.parquet" -else - first_run_seg="$(find "${RUNS_DIR}" -mindepth 2 -maxdepth 2 -type f -name 'segger_segmentation.parquet' | sort | head -n1 || true)" - if [[ -n "${first_run_seg}" ]]; then - REFERENCE_UNIVERSE_SEG_RESOLVED="${first_run_seg}" - elif [[ -n "${INPUT_DIR}" ]] && [[ -f "${INPUT_DIR}/segger_segmentation.parquet" ]]; then - REFERENCE_UNIVERSE_SEG_RESOLVED="${INPUT_DIR}/segger_segmentation.parquet" - fi -fi -RUN_REFERENCE_UNIVERSE_TOKEN="$(normalize_token "${REFERENCE_UNIVERSE_SEG_RESOLVED}")" - -OUTPUT_HEADER=$'job\tgroup\tgpu\tis_reference\treference_kind\tvalidate_status\tvalidate_error\tgpu_time_s\tcells\tassigned_pct\tassigned_ci95\tmecr\tmecr_ci95\tcontamination_pct\tcontamination_ci95\tresolvi_contamination_pct\tresolvi_contamination_ci95\ttco\ttco_ci95\tdoublet_pct\tdoublet_ci95\tsegmentation_path\tanndata_path\toutput_path\tupdated_at\tmetric_schema_version\trun_input_dir\trun_scrna_reference_path\trun_me_gene_pairs_path\trun_reference_universe_seg' -reuse_existing=0 -if [[ "${RECOMPUTE}" != "1" ]] && [[ -s "${OUT_TSV}" ]]; then - existing_header="$(head -n1 "${OUT_TSV}" || true)" - if [[ "${existing_header}" == "${OUTPUT_HEADER}" ]]; then - reuse_existing=1 - else - echo "[$(timestamp)] WARN existing TSV header mismatch; recomputing all jobs" >> "${LOG_FILE}" - fi -fi - -get_field() { - local column_name="$1" - if [[ ! 
-s "${tmp_row}" ]]; then - return 0 - fi - awk -F'\t' -v key="${column_name}" ' - NR == 1 { - for (i = 1; i <= NF; i++) { - if ($i == key) { - idx = i - break - } - } - next - } - NR == 2 { - if (idx > 0) { - print $idx - } - exit - } - ' "${tmp_row}" -} - -get_existing_field_by_job() { - local job_name="$1" - local column_name="$2" - awk -F'\t' -v j="${job_name}" -v key="${column_name}" ' - NR == 1 { - for (i = 1; i <= NF; i++) { - if ($i == key) { - idx = i - break - } - } - next - } - $1 == j { - if (idx > 0) { - print $idx - } - exit - } - ' "${OUT_TSV}" -} - -lookup_gpu_time() { - local job_name="$1" - local preferred_gpu="$2" - local elapsed="" - local f row - - if [[ -f "${SUMMARY_DIR}/gpu${preferred_gpu}.tsv" ]]; then - row="$(awk -F'\t' -v j="${job_name}" ' - NR == 1 { - for (i = 1; i <= NF; i++) { - if ($i == "job") job_col = i - if ($i == "elapsed_s") elapsed_col = i - } - next - } - job_col > 0 && $job_col == j { - if (elapsed_col > 0) print $elapsed_col - exit - } - ' "${SUMMARY_DIR}/gpu${preferred_gpu}.tsv")" - if [[ -n "${row}" ]]; then - elapsed="${row}" - fi - fi - - if [[ -z "${elapsed}" ]]; then - for f in "${SUMMARY_DIR}"/gpu*.tsv "${SUMMARY_DIR}/recovery.tsv"; do - [[ -f "${f}" ]] || continue - row="$(awk -F'\t' -v j="${job_name}" ' - NR == 1 { - for (i = 1; i <= NF; i++) { - if ($i == "job") job_col = i - if ($i == "elapsed_s") elapsed_col = i - } - next - } - job_col > 0 && $job_col == j { - if (elapsed_col > 0) print $elapsed_col - exit - } - ' "${f}")" - if [[ -n "${row}" ]]; then - elapsed="${row}" - break - fi - done - fi - - if [[ -z "${elapsed}" ]]; then - elapsed="0" - fi - printf '%s' "${elapsed}" -} - -scale_frac_to_pct() { - local value="${1:-}" - local lower - lower="$(printf '%s' "${value}" | tr '[:upper:]' '[:lower:]')" - if [[ -z "${value}" ]] || [[ "${lower}" == "nan" ]] || [[ "${lower}" == "none" ]]; then - printf '%s' "nan" - return 0 - fi - awk -v v="${value}" 'BEGIN { printf "%.6f", (v + 0.0) * 100.0 }' -} - -append_row() { - local job="$1" - local group="$2" - local gpu="$3" - local is_reference="$4" - local reference_kind="$5" - local validate_status="$6" - local validate_error="$7" - local gpu_time_s="$8" - local cells="$9" - local assigned_pct="${10}" - local assigned_ci95="${11}" - local mecr="${12}" - local mecr_ci95="${13}" - local contamination_pct="${14}" - local contamination_ci95="${15}" - local resolvi_contamination_pct="${16}" - local resolvi_contamination_ci95="${17}" - local tco="${18}" - local tco_ci95="${19}" - local doublet_pct="${20}" - local doublet_ci95="${21}" - local row_seg_path="${22}" - local row_anndata_path="${23}" - - validate_error="$(sanitize_tsv_field "${validate_error}")" - reference_kind="$(normalize_token "${reference_kind}")" - - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job}" "${group}" "${gpu}" "${is_reference}" "${reference_kind}" "${validate_status}" "${validate_error}" "${gpu_time_s}" \ - "${cells}" "${assigned_pct}" "${assigned_ci95}" "${mecr}" "${mecr_ci95}" "${contamination_pct}" "${contamination_ci95}" \ - "${resolvi_contamination_pct}" "${resolvi_contamination_ci95}" "${tco}" "${tco_ci95}" "${doublet_pct}" "${doublet_ci95}" "${row_seg_path}" "${row_anndata_path}" "${OUT_TSV}" \ - "$(timestamp)" "${METRIC_SCHEMA_VERSION}" "${RUN_INPUT_DIR_TOKEN}" "${RUN_SCRNA_REFERENCE_TOKEN}" "${RUN_ME_GENE_PAIRS_TOKEN}" "${RUN_REFERENCE_UNIVERSE_TOKEN}" \ - >> "${tmp_out}" -} - -should_reuse_row() { - local job="$1" - local 
is_reference="$2" - local should_reuse=0 - local reason="" - - if [[ "${reuse_existing}" != "1" ]]; then - printf '0\tno_reuse_mode\n' - return 0 - fi - - existing_row="$(awk -F'\t' -v j="${job}" 'NR > 1 && $1 == j { print; exit }' "${OUT_TSV}")" - if [[ -z "${existing_row}" ]]; then - printf '0\tmissing_existing_row\n' - return 0 - fi - - existing_status="$(get_existing_field_by_job "${job}" "validate_status")" - existing_metric_schema_version="$(get_existing_field_by_job "${job}" "metric_schema_version")" - existing_run_input_dir="$(get_existing_field_by_job "${job}" "run_input_dir")" - existing_run_scrna_ref="$(get_existing_field_by_job "${job}" "run_scrna_reference_path")" - existing_run_me_pairs="$(get_existing_field_by_job "${job}" "run_me_gene_pairs_path")" - existing_run_ref_universe="$(get_existing_field_by_job "${job}" "run_reference_universe_seg")" - existing_cells="$(get_existing_field_by_job "${job}" "cells")" - existing_is_reference="$(get_existing_field_by_job "${job}" "is_reference")" - - should_reuse=1 - reason="ok" - - if [[ "${existing_status}" != "ok" ]]; then - should_reuse=0 - reason="status=${existing_status}" - elif [[ "${existing_metric_schema_version:-}" != "${METRIC_SCHEMA_VERSION}" ]]; then - should_reuse=0 - reason="schema_mismatch" - elif [[ "${existing_run_input_dir:-}" != "${RUN_INPUT_DIR_TOKEN}" ]] || \ - [[ "${existing_run_scrna_ref:-}" != "${RUN_SCRNA_REFERENCE_TOKEN}" ]] || \ - [[ "${existing_run_me_pairs:-}" != "${RUN_ME_GENE_PAIRS_TOKEN}" ]] || \ - [[ "${existing_run_ref_universe:-}" != "${RUN_REFERENCE_UNIVERSE_TOKEN}" ]]; then - should_reuse=0 - reason="validation_inputs_changed" - elif [[ "${existing_is_reference:-0}" != "${is_reference}" ]]; then - should_reuse=0 - reason="reference_flag_changed" - elif [[ -z "${existing_cells:-}" ]] || [[ "$(printf '%s' "${existing_cells}" | tr '[:upper:]' '[:lower:]')" == "nan" ]]; then - should_reuse=0 - reason="missing_cells" - elif [[ -n "${INPUT_DIR}" ]]; then - existing_tco="$(get_existing_field_by_job "${job}" "tco")" - existing_contam="$(get_existing_field_by_job "${job}" "contamination_pct")" - existing_doublet="$(get_existing_field_by_job "${job}" "doublet_pct")" - if [[ -z "${existing_tco:-}" ]] || [[ -z "${existing_contam:-}" ]] || [[ -z "${existing_doublet:-}" ]]; then - should_reuse=0 - reason="missing_input_dependent_metrics" - elif [[ -n "${SCRNA_REFERENCE_PATH}" ]]; then - existing_resolvi="$(get_existing_field_by_job "${job}" "resolvi_contamination_pct")" - if [[ -z "${existing_resolvi:-}" ]]; then - should_reuse=0 - reason="missing_resolvi_metric" - fi - fi - fi - - printf '%s\t%s\n' "${should_reuse}" "${reason}" -} - -run_validate_for_row() { - local job="$1" - local seg_path="$2" - local anndata_path="$3" - local apply_elapsed_fallback="$4" - - validate_status="ok" - validate_error="" - cells="0" - assigned_pct="nan" - assigned_ci95="nan" - mecr="nan" - mecr_ci95="nan" - contamination_pct="nan" - contamination_ci95="nan" - resolvi_contamination_pct="nan" - resolvi_contamination_ci95="nan" - tco="nan" - tco_ci95="nan" - doublet_pct="nan" - doublet_ci95="nan" - row_seg_path="${seg_path}" - row_anndata_path="" - - if [[ ! 
-f "${seg_path}" ]]; then - validate_status="missing_segmentation" - validate_error="segger_segmentation.parquet not found" - return 0 - fi - - : > "${tmp_row}" - cmd=( - "${SEGGER_BIN}" validate - -s "${seg_path}" - -o "${tmp_row}" - --max-me-gene-pairs "${MAX_ME_GENE_PAIRS}" - ) - if [[ -f "${anndata_path}" ]]; then - cmd+=(-a "${anndata_path}") - fi - if [[ -n "${INPUT_DIR}" ]]; then - cmd+=(-i "${INPUT_DIR}") - fi - if [[ -n "${ME_GENE_PAIRS_PATH}" ]]; then - cmd+=(--me-gene-pairs-path "${ME_GENE_PAIRS_PATH}") - fi - if [[ -n "${SCRNA_REFERENCE_PATH}" ]]; then - cmd+=( - --scrna-reference-path "${SCRNA_REFERENCE_PATH}" - --scrna-celltype-column "${SCRNA_CELLTYPE_COLUMN}" - ) - fi - - { - printf '[%s] job=%s CMD:' "$(timestamp)" "${job}" - printf ' %q' "${cmd[@]}" - printf '\n' - } >> "${LOG_FILE}" - - if "${cmd[@]}" >> "${LOG_FILE}" 2>&1; then - parsed_status="$(get_field "validate_status")" - [[ -n "${parsed_status}" ]] && validate_status="${parsed_status}" - parsed_error="$(get_field "validate_error")" - [[ -n "${parsed_error}" ]] && validate_error="${parsed_error}" - - parsed_elapsed="$(get_field "elapsed_s")" - if [[ "${apply_elapsed_fallback}" == "1" ]] && [[ -n "${parsed_elapsed}" ]] && [[ "${gpu_time_s}" == "0" || -z "${gpu_time_s}" ]]; then - gpu_time_s="${parsed_elapsed}" - fi - - parsed_val="$(get_field "cells_total")" - if [[ -n "${parsed_val}" ]] && [[ "$(printf '%s' "${parsed_val}" | tr '[:upper:]' '[:lower:]')" != "nan" ]]; then - cells="${parsed_val}" - else - parsed_val="$(get_field "cells_assigned")" - [[ -n "${parsed_val}" ]] && cells="${parsed_val}" - fi - - parsed_val="$(get_field "transcripts_assigned_pct")" - [[ -n "${parsed_val}" ]] && assigned_pct="${parsed_val}" - parsed_val="$(get_field "transcripts_assigned_pct_ci95")" - [[ -n "${parsed_val}" ]] && assigned_ci95="${parsed_val}" - - parsed_val="$(get_field "mecr_fast")" - [[ -n "${parsed_val}" ]] && mecr="${parsed_val}" - parsed_val="$(get_field "mecr_ci95_fast")" - [[ -n "${parsed_val}" ]] && mecr_ci95="${parsed_val}" - - parsed_val="$(get_field "border_contaminated_cells_pct_fast")" - [[ -n "${parsed_val}" ]] && contamination_pct="${parsed_val}" - parsed_val="$(get_field "border_contaminated_cells_pct_ci95_fast")" - [[ -n "${parsed_val}" ]] && contamination_ci95="${parsed_val}" - - parsed_val="$(get_field "resolvi_contamination_pct_fast")" - [[ -n "${parsed_val}" ]] && resolvi_contamination_pct="${parsed_val}" - parsed_val="$(get_field "resolvi_contamination_ci95_fast")" - [[ -n "${parsed_val}" ]] && resolvi_contamination_ci95="${parsed_val}" - - parsed_val="$(get_field "transcript_centroid_offset_fast")" - [[ -n "${parsed_val}" ]] && tco="${parsed_val}" - parsed_val="$(get_field "transcript_centroid_offset_ci95_fast")" - [[ -n "${parsed_val}" ]] && tco_ci95="${parsed_val}" - - parsed_val="$(get_field "signal_doublet_like_fraction_fast")" - if [[ -n "${parsed_val}" ]]; then - doublet_pct="$(scale_frac_to_pct "${parsed_val}")" - fi - parsed_val="$(get_field "signal_doublet_like_fraction_ci95_fast")" - if [[ -n "${parsed_val}" ]]; then - doublet_ci95="$(scale_frac_to_pct "${parsed_val}")" - fi - - parsed_val="$(get_field "segmentation_path")" - [[ -n "${parsed_val}" ]] && row_seg_path="${parsed_val}" - parsed_val="$(get_field "anndata_path")" - [[ -n "${parsed_val}" ]] && row_anndata_path="${parsed_val}" - else - validate_status="validate_command_failed" - validate_error="segger validate command failed" - fi -} - -printf "%s\n" "${OUTPUT_HEADER}" > "${tmp_out}" - -echo "[$(timestamp)] START benchmark 
validation table build" >> "${LOG_FILE}" -echo "[$(timestamp)] ROOT=${ROOT}" >> "${LOG_FILE}" -echo "[$(timestamp)] OUT_TSV=${OUT_TSV}" >> "${LOG_FILE}" -echo "[$(timestamp)] RECOMPUTE=${RECOMPUTE}" >> "${LOG_FILE}" -echo "[$(timestamp)] INCLUDE_DEFAULT_10X=${INCLUDE_DEFAULT_10X}" >> "${LOG_FILE}" -echo "[$(timestamp)] REFERENCE_UNIVERSE_SEG=${RUN_REFERENCE_UNIVERSE_TOKEN}" >> "${LOG_FILE}" - -reused_count=0 -computed_count=0 - -while IFS=$'\t' read -r \ - job group _use3d _expansion _txk _txdist _layers _heads _cellsmin _minqv _alignment; do - if [[ -z "${job:-}" ]] || [[ "${job}" == "job" ]]; then - continue - fi - - reuse_info="$(should_reuse_row "${job}" "0")" - should_reuse="${reuse_info%%$'\t'*}" - reuse_reason="${reuse_info#*$'\t'}" - if [[ "${should_reuse}" == "1" ]]; then - existing_row="$(awk -F'\t' -v j="${job}" 'NR > 1 && $1 == j { print; exit }' "${OUT_TSV}")" - printf "%s\n" "${existing_row}" >> "${tmp_out}" - reused_count=$((reused_count + 1)) - echo "[$(timestamp)] SKIP job=${job}: existing row reused" >> "${LOG_FILE}" - continue - fi - if [[ "${reuse_existing}" == "1" ]]; then - echo "[$(timestamp)] RECOMPUTE job=${job}: ${reuse_reason}" >> "${LOG_FILE}" - fi - - computed_count=$((computed_count + 1)) - - gpu="${GPU_A}" - if [[ "${group}" == "B" ]]; then - gpu="${GPU_B}" - fi - - seg_path="${RUNS_DIR}/${job}/segger_segmentation.parquet" - anndata_path="${EXPORTS_DIR}/${job}/anndata/segger_segmentation.h5ad" - gpu_time_s="$(lookup_gpu_time "${job}" "${gpu}")" - - start_ts="$(date +%s)" - run_validate_for_row "${job}" "${seg_path}" "${anndata_path}" "1" - end_ts="$(date +%s)" - if [[ "${gpu_time_s}" == "0" || -z "${gpu_time_s}" ]]; then - gpu_time_s="$((end_ts - start_ts))" - fi - - append_row \ - "${job}" "${group}" "${gpu}" "0" "-" \ - "${validate_status}" "${validate_error}" "${gpu_time_s}" "${cells}" \ - "${assigned_pct}" "${assigned_ci95}" "${mecr}" "${mecr_ci95}" \ - "${contamination_pct}" "${contamination_ci95}" "${resolvi_contamination_pct}" "${resolvi_contamination_ci95}" "${tco}" "${tco_ci95}" \ - "${doublet_pct}" "${doublet_ci95}" "${row_seg_path}" "${row_anndata_path}" -done < "${PLAN_FILE}" - -if [[ "${INCLUDE_DEFAULT_10X}" == "true" ]]; then - for reference_kind in 10x_cell 10x_nucleus; do - job="ref_${reference_kind}" - group="R" - gpu="-" - - reuse_info="$(should_reuse_row "${job}" "1")" - should_reuse="${reuse_info%%$'\t'*}" - reuse_reason="${reuse_info#*$'\t'}" - if [[ "${should_reuse}" == "1" ]]; then - existing_row="$(awk -F'\t' -v j="${job}" 'NR > 1 && $1 == j { print; exit }' "${OUT_TSV}")" - printf "%s\n" "${existing_row}" >> "${tmp_out}" - reused_count=$((reused_count + 1)) - echo "[$(timestamp)] SKIP job=${job}: existing row reused" >> "${LOG_FILE}" - continue - fi - if [[ "${reuse_existing}" == "1" ]]; then - echo "[$(timestamp)] RECOMPUTE job=${job}: ${reuse_reason}" >> "${LOG_FILE}" - fi - - computed_count=$((computed_count + 1)) - gpu_time_s="0" - validate_status="ok" - validate_error="" - cells="0" - assigned_pct="nan" - assigned_ci95="nan" - mecr="nan" - mecr_ci95="nan" - contamination_pct="nan" - contamination_ci95="nan" - resolvi_contamination_pct="nan" - resolvi_contamination_ci95="nan" - tco="nan" - tco_ci95="nan" - doublet_pct="nan" - doublet_ci95="nan" - row_seg_path="${REFERENCE_ARTIFACTS_DIR}/${job}/segger_segmentation.parquet" - row_anndata_path="${REFERENCE_ARTIFACTS_DIR}/${job}/segger_segmentation.h5ad" - - if [[ -z "${INPUT_DIR}" ]]; then - validate_status="missing_input_dir" - validate_error="--input-dir is required for default 10x 
references" - elif [[ -z "${REFERENCE_UNIVERSE_SEG_RESOLVED}" ]] || [[ ! -f "${REFERENCE_UNIVERSE_SEG_RESOLVED}" ]]; then - validate_status="missing_universe_segmentation" - validate_error="canonical Segger universe segmentation not found" - else - mkdir -p "$(dirname "${row_seg_path}")" - build_cmd=( - "${SEGGER_PYTHON}" "${REFERENCE_BUILDER_PY}" - --input-dir "${INPUT_DIR}" - --canonical-seg "${REFERENCE_UNIVERSE_SEG_RESOLVED}" - --kind "${reference_kind}" - --out-seg "${row_seg_path}" - --out-h5ad "${row_anndata_path}" - ) - - { - printf '[%s] job=%s BUILD_REF_CMD:' "$(timestamp)" "${job}" - printf ' %q' "${build_cmd[@]}" - printf '\n' - } >> "${LOG_FILE}" - - if "${build_cmd[@]}" >> "${LOG_FILE}" 2>&1; then - run_validate_for_row "${job}" "${row_seg_path}" "${row_anndata_path}" "0" - else - validate_status="reference_artifact_failed" - validate_error="failed to build default 10x reference artifacts" - fi - fi - - append_row \ - "${job}" "${group}" "${gpu}" "1" "${reference_kind}" \ - "${validate_status}" "${validate_error}" "${gpu_time_s}" "${cells}" \ - "${assigned_pct}" "${assigned_ci95}" "${mecr}" "${mecr_ci95}" \ - "${contamination_pct}" "${contamination_ci95}" "${resolvi_contamination_pct}" "${resolvi_contamination_ci95}" "${tco}" "${tco_ci95}" \ - "${doublet_pct}" "${doublet_ci95}" "${row_seg_path}" "${row_anndata_path}" - done -fi - -tmp_sorted="$(mktemp)" -{ - head -n1 "${tmp_out}" - tail -n +2 "${tmp_out}" \ - | awk -F'\t' ' - function norm_num(v, lower) { - lower = tolower(v) - if (v == "" || lower == "nan" || lower == "none") return "" - return v + 0.0 - } - { - assigned = norm_num($10) - mecr = norm_num($12) - if (assigned == "") assigned_key = -1 - else assigned_key = assigned - if (mecr == "") mecr_key = 1e99 - else mecr_key = mecr - printf "%.10f\t%.10f\t%s\n", assigned_key, mecr_key, $0 - } - ' \ - | sort -t $'\t' -k1,1gr -k2,2g -k3,3 \ - | cut -f3- -} > "${tmp_sorted}" -mv "${tmp_sorted}" "${tmp_out}" - -mv "${tmp_out}" "${OUT_TSV}" -echo "[$(timestamp)] WROTE validation table: ${OUT_TSV}" >> "${LOG_FILE}" -echo "[$(timestamp)] SUMMARY reused=${reused_count} computed=${computed_count}" >> "${LOG_FILE}" -echo "Wrote validation table: ${OUT_TSV} (reused=${reused_count}, computed=${computed_count})" diff --git a/scripts/build_default_10x_reference_artifacts.py b/scripts/build_default_10x_reference_artifacts.py deleted file mode 100755 index ddc5787..0000000 --- a/scripts/build_default_10x_reference_artifacts.py +++ /dev/null @@ -1,230 +0,0 @@ -#!/usr/bin/env python3 -"""Build temporary 10x reference segmentation + AnnData on Segger row_index universe.""" - -from __future__ import annotations - -import argparse -import json -from pathlib import Path - -import anndata as ad -import numpy as np -import pandas as pd -import polars as pl -from scipy import sparse as sp - - -def _pick_column(columns: list[str], candidates: list[str], required: bool = True) -> str | None: - for c in candidates: - if c in columns: - return c - if required: - raise ValueError(f"Missing required column; tried: {candidates}") - return None - - -def _clean_cell_id_expr(cell_col: str) -> pl.Expr: - cell_str = pl.col(cell_col).cast(pl.Utf8) - return ( - pl.when( - pl.col(cell_col).is_null() - | (cell_str == "") - | cell_str.str.to_uppercase().is_in(["UNASSIGNED", "NONE", "-1"]) - ) - .then(None) - .otherwise(cell_str) - ) - - -def _nucleus_overlap_expr(overlap_col: str) -> pl.Expr: - overlap_str = pl.col(overlap_col).cast(pl.Utf8).str.to_lowercase() - return overlap_str.is_in(["1", "2", "true", "t", 
"yes", "y", "nuclear"]) - - -def _build_anndata(tx: pl.DataFrame, out_h5ad: Path) -> tuple[int, int]: - if tx.height == 0: - adata = ad.AnnData( - X=sp.csr_matrix((0, 0)), - obs=pd.DataFrame(index=pd.Index([], name="segger_cell_id")), - var=pd.DataFrame(index=pd.Index([], name="feature_name")), - ) - adata.write_h5ad(out_h5ad, compression="gzip", compression_opts=4) - return 0, 0 - - assigned = tx.filter( - pl.col("segger_cell_id").is_not_null() & pl.col("feature_name").is_not_null() - ) - if assigned.height == 0: - genes = ( - tx.select("feature_name") - .drop_nulls() - .unique() - .sort("feature_name") - .get_column("feature_name") - .cast(pl.Utf8) - .to_list() - ) - adata = ad.AnnData( - X=sp.csr_matrix((0, len(genes))), - obs=pd.DataFrame(index=pd.Index([], name="segger_cell_id")), - var=pd.DataFrame(index=pd.Index([str(g) for g in genes], name="feature_name")), - ) - adata.write_h5ad(out_h5ad, compression="gzip", compression_opts=4) - return 0, len(genes) - - feature_idx = ( - assigned.select("feature_name") - .with_columns(pl.col("feature_name").cast(pl.Utf8)) - .unique() - .sort("feature_name") - .with_row_index(name="_fid") - ) - cell_idx = ( - assigned.select("segger_cell_id") - .with_columns(pl.col("segger_cell_id").cast(pl.Utf8)) - .unique() - .sort("segger_cell_id") - .with_row_index(name="_cid") - ) - - mapped = assigned.join(feature_idx, on="feature_name").join(cell_idx, on="segger_cell_id") - counts = mapped.group_by(["_cid", "_fid"]).agg(pl.len().alias("_count")) - - ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T - rows = ijv[0].astype(np.int64, copy=False) - cols = ijv[1].astype(np.int64, copy=False) - data = ijv[2].astype(np.int64, copy=False) - - X = sp.coo_matrix((data, (rows, cols)), shape=(cell_idx.height, feature_idx.height)).tocsr() - adata = ad.AnnData( - X=X, - obs=pd.DataFrame(index=pd.Index(cell_idx.get_column("segger_cell_id").to_list(), name="segger_cell_id")), - var=pd.DataFrame(index=pd.Index(feature_idx.get_column("feature_name").to_list(), name="feature_name")), - ) - - coord_cols = [c for c in ["x", "y", "z"] if c in assigned.columns] - if "x" in coord_cols and "y" in coord_cols: - centroids = ( - assigned.group_by("segger_cell_id") - .agg([pl.col(c).mean().alias(c) for c in coord_cols]) - .to_pandas() - .set_index("segger_cell_id") - .reindex(adata.obs.index) - ) - adata.obsm["X_spatial"] = centroids[coord_cols].to_numpy() - - adata.write_h5ad(out_h5ad, compression="gzip", compression_opts=4) - return int(adata.n_obs), int(adata.n_vars) - - -def build_reference_artifacts( - *, - input_dir: Path, - canonical_seg: Path, - kind: str, - out_seg: Path, - out_h5ad: Path, -) -> dict[str, object]: - if kind not in {"10x_cell", "10x_nucleus"}: - raise ValueError("kind must be one of: 10x_cell, 10x_nucleus") - - tx_path = input_dir / "transcripts.parquet" - if not tx_path.exists(): - raise FileNotFoundError(f"transcripts.parquet not found under input dir: {tx_path}") - - if not canonical_seg.exists(): - raise FileNotFoundError(f"Canonical universe segmentation missing: {canonical_seg}") - - universe = pl.read_parquet(canonical_seg).select(pl.col("row_index").cast(pl.Int64)).unique().sort("row_index") - - tx_lf = pl.scan_parquet(tx_path, parallel="row_groups").with_row_index(name="row_index") - schema_names = tx_lf.collect_schema().names() - - feature_col = _pick_column(schema_names, ["feature_name", "target", "gene"]) # Xenium/CosMX/MERSCOPE - x_col = _pick_column(schema_names, ["x_location", "x", "global_x", "x_global_px"]) - y_col = 
_pick_column(schema_names, ["y_location", "y", "global_y", "y_global_px"]) - z_col = _pick_column(schema_names, ["z_location", "z", "global_z"], required=False) - cell_col = _pick_column(schema_names, ["cell_id", "cell"]) - overlap_col = _pick_column(schema_names, ["overlaps_nucleus", "cell_compartment", "CellComp"], required=False) - - select_cols = ["row_index", feature_col, x_col, y_col, cell_col] - if z_col is not None: - select_cols.append(z_col) - if overlap_col is not None: - select_cols.append(overlap_col) - - tx_raw = tx_lf.select(select_cols) - universe_tx = universe.lazy().join(tx_raw, on="row_index", how="left") - - universe_tx = universe_tx.with_columns( - _clean_cell_id_expr(cell_col).alias("_cell_id_clean"), - ) - if kind == "10x_nucleus": - if overlap_col is None: - raise ValueError("Cannot build 10x_nucleus reference: overlaps_nucleus-like column not found") - universe_tx = universe_tx.with_columns( - pl.when(_nucleus_overlap_expr(overlap_col)).then(pl.col("_cell_id_clean")).otherwise(None).alias("segger_cell_id") - ) - else: - universe_tx = universe_tx.with_columns(pl.col("_cell_id_clean").alias("segger_cell_id")) - - rename_map = { - feature_col: "feature_name", - x_col: "x", - y_col: "y", - } - if z_col is not None: - rename_map[z_col] = "z" - - tx = ( - universe_tx - .select(["row_index", "segger_cell_id", *list(rename_map.keys())]) - .rename(rename_map) - .collect() - ) - - out_seg.parent.mkdir(parents=True, exist_ok=True) - out_h5ad.parent.mkdir(parents=True, exist_ok=True) - - seg_df = tx.select(["row_index", "segger_cell_id"]).with_columns(pl.lit(1.0).alias("segger_similarity")) - seg_df.write_parquet(out_seg) - - n_cells, n_genes = _build_anndata(tx, out_h5ad) - - assigned_n = int(tx.filter(pl.col("segger_cell_id").is_not_null()).height) - summary = { - "kind": kind, - "canonical_seg": str(canonical_seg), - "input_transcripts": str(tx_path), - "rows_universe": int(tx.height), - "rows_assigned": assigned_n, - "cells": n_cells, - "genes": n_genes, - "out_seg": str(out_seg), - "out_h5ad": str(out_h5ad), - } - return summary - - -def main() -> int: - parser = argparse.ArgumentParser(description="Build default 10x reference artifacts on Segger universe") - parser.add_argument("--input-dir", type=Path, required=True) - parser.add_argument("--canonical-seg", type=Path, required=True) - parser.add_argument("--kind", choices=["10x_cell", "10x_nucleus"], required=True) - parser.add_argument("--out-seg", type=Path, required=True) - parser.add_argument("--out-h5ad", type=Path, required=True) - args = parser.parse_args() - - summary = build_reference_artifacts( - input_dir=args.input_dir, - canonical_seg=args.canonical_seg, - kind=args.kind, - out_seg=args.out_seg, - out_h5ad=args.out_h5ad, - ) - print(json.dumps(summary, sort_keys=True)) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/presentation/experiments.md b/scripts/presentation/experiments.md deleted file mode 100644 index 8706e52..0000000 --- a/scripts/presentation/experiments.md +++ /dev/null @@ -1,361 +0,0 @@ -# Segger v0.2.0 Experiment Design - -## Overview - -The `scripts/` directory contains a fully automated benchmarking pipeline that trains, predicts, exports, validates, and reports on Segger segmentation results across a systematic grid of hyperparameter configurations. The dataset is **Xenium pancreas (Mossi)** and the pipeline runs overnight on **N GPUs** (auto-detected or overridable). - -The experiments answer three questions: - -1. 
**Which hyperparameters matter most for segmentation quality?** (parameter sensitivity) -2. **Are the best configurations stable and robust to perturbations?** (repeatability & robustness) -3. **Which architectural components and loss terms are actually necessary?** (ablation study) - ---- - -## Scripts at a Glance - -| Script | Role | -|--------|------| -| `run_param_benchmark_2gpu.sh` | **Parameter sweep** -- one-factor-at-a-time around a baseline | -| `run_robustness_ablation_2gpu.sh` | **Robustness & ablation** -- stability repeats, interaction grid, stress tests | -| `run_ablation_study.sh` | **Component ablation** -- loss decomposition, architecture, features (auto-detects GPUs) | -| `build_default_10x_reference_artifacts.py` | Build 10x cell/nucleus **reference segmentations** for comparison | -| `build_benchmark_validation_table.sh` | Compute **validation metrics** (MECR, contamination, TCO, doublet, assignment %) for every run | -| `benchmark_status_dashboard.sh` | Live **terminal dashboard** showing progress, failures, and ranked metrics | -| `build_benchmark_pdf_report.py` | Generate a **multi-page PDF** with bar charts, scatter plots, heatmaps, UMAPs, and FOV panels | - ---- - -## Experiment 1: Parameter Sweep (`run_param_benchmark_2gpu.sh`) - -### What it does - -Runs a **one-factor-at-a-time (OFAT)** sensitivity analysis. Starting from a fixed baseline configuration, it varies exactly one parameter at a time while holding all others constant. This isolates the marginal effect of each parameter on segmentation quality. - -### Baseline configuration - -| Parameter | Value | Meaning | -|-----------|-------|---------| -| `use_3d` | `true` | Include z-coordinate in graph construction | -| `expansion_ratio` | `2.0` | Scale factor for boundary polygons (captures edge transcripts) | -| `tx_max_k` | `5` | Max k-nearest-neighbors per transcript node | -| `tx_max_dist` | `5` | Max distance (microns) for tx-tx edges | -| `n_mid_layers` | `2` | Number of GNN message-passing layers | -| `n_heads` | `2` | Number of attention heads in the transformer encoder | -| `cells_min_counts` | `5` | Minimum transcripts per cell to include | -| `alignment_loss` | `true` | Enable ME-gene alignment loss from scRNA reference | - -### Sweep axes - -Each axis varies one parameter while the rest stay at baseline: - -| Axis | Values tested | Why | -|------|--------------|-----| -| **use_3d** | `false`, `true` | Does z-coordinate information improve segmentation? Xenium captures z-stacks, but not all platforms do. Tests whether the model benefits from 3D spatial context. | -| **expansion_ratio** | 1.0, 1.5, 2.0, 2.5, 3.0 | Controls how far beyond the nucleus boundary Segger looks for transcripts. Too small = misses cytoplasmic transcripts (low sensitivity). Too large = captures transcripts from neighboring cells (low specificity / high contamination). | -| **tx_max_k** | 5, 10, 20 | How many transcript neighbors each node connects to. More neighbors = richer local context for the GNN but higher memory/compute cost and potential for over-smoothing. | -| **tx_max_dist** | 3, 5, 10, 20 | Maximum edge distance for tx-tx connections. Interacts with tissue density -- sparse tissue needs larger distances, dense tissue smaller. Affects the receptive field of the GNN. | -| **n_mid_layers** | 1, 2, 3 | Depth of the GNN. Deeper = larger receptive field but risk of over-smoothing. For cell-level segmentation, 1-3 layers is the typical range. | -| **n_heads** | 2, 4, 8 | Attention heads in the IST encoder. 
More heads = more representational capacity, but diminishing returns and higher cost. | -| **cells_min_counts** | 3, 5, 10 | Minimum transcript threshold to define a cell. Lower = more cells detected (potentially fragments/noise). Higher = stricter, fewer false-positive cells. | -| **alignment_loss** | `false`, `true` | Whether to add the ME-gene constraint loss during training. This loss penalizes co-expression of mutually exclusive genes within the same cell, leveraging scRNA-seq priors. | - -### Logic - -The OFAT design keeps the experiment tractable (approximately 20 jobs instead of a combinatorial explosion of hundreds). Each comparison has a single variable, so any metric change can be attributed directly to that parameter. The trade-off is that OFAT misses parameter interactions -- that's what Experiment 2 addresses. - -### Total jobs - -~20 runs (1 baseline + ~19 single-parameter variants), split across 2 GPUs in round-robin order. Each run: - -1. **Trains** Segger for 20 epochs -2. **Predicts** cell assignments -3. **Exports** to AnnData (.h5ad) and Xenium Explorer format -4. If training OOMs during prediction, falls back to the last checkpoint -5. If an `ancdata` multiprocessing error occurs, retries with 0 workers - ---- - -## Experiment 2: Robustness & Ablation (`run_robustness_ablation_2gpu.sh`) - -### What it does - -Three study blocks that go beyond single-parameter variation: - -### Block A: Stability / Repeatability - -Runs the **same** configuration multiple times (default 3 repeats) to measure run-to-run variance: - -| Config | Repeats | Purpose | -|--------|---------|---------| -| Baseline (legacy) | 3 | Is the original configuration stable across random seeds/initialization? | -| Anchor (current best) | 3 | Is the improved configuration equally stable? | -| High-sensitivity variant | 2 | Does pushing expansion to 3.0 produce consistent results? | - -**Why:** GNN training involves stochastic initialization and mini-batch sampling. If the same hyperparameters produce wildly different metrics across runs, we can't trust the parameter sweep results. Stability repeats give us error bars. - -The **anchor** configuration (expansion=2.5, tx_dist=20, n_heads=4) represents the "current best" derived from early validation trends, distinct from the legacy baseline. - -### Block B: Interaction Grid - -Tests **combinations** of the most impactful parameters simultaneously: - -- `expansion_ratio` x `tx_max_dist` x `n_heads` (2 x 2 x 2 = 8 combinations with alignment=true) -- Plus alignment ablation at each corner (expansion x dist with heads=4, alignment=false): 4 more jobs - -**Why:** The OFAT sweep can't detect interactions. For example, a larger expansion ratio might only help when combined with a larger tx_max_dist (because expanded boundaries need longer-range edges to connect properly). This grid covers the "high-performing region" identified by the sweep -- it's not exhaustive but targets where interactions are most likely to matter. - -The alignment ablation within the grid specifically tests: **does the ME-gene loss help or hurt at different graph configurations?** If the loss only helps in some regimes, that's critical to know. - -### Block C: Stress Tests - -Deliberately pushes single parameters to **extreme or degraded** values to test robustness: - -| Test | What changes | Why | -|------|-------------|-----| -| `stress_use3d_false_anchor` | Drops z from anchor config | Does the anchor config fall apart without 3D? 
| -| `stress_use3d_false_sens` | Drops z from high-sensitivity config | Same for the aggressive expansion config | -| `stress_cellsmin3_anchor` | Very permissive cell threshold | How noisy do results get with loose filters? | -| `stress_cellsmin10_anchor` | Strict cell threshold | How many cells/transcripts do we lose? | -| `stress_txk20_anchor` | Very dense transcript graph | Does OOM or over-smoothing kick in? | -| `stress_layers1_anchor` | Minimal GNN depth | Can a single layer still segment well? | - -**Why:** Practical deployment means users may have different tissue types, densities, and platform configurations. Stress tests reveal how gracefully the model degrades when conditions shift from the ideal. - -### Total jobs - -~24-28 runs (8 stability + 12 interaction + 6 stress), again split round-robin across 2 GPUs. - ---- - -## Experiment 3: Component Ablation (`run_ablation_study.sh`) - -### What it does - -Systematically removes or swaps individual components -- loss terms, architecture choices, and feature representations -- to measure each one's contribution to segmentation quality. Unlike the OFAT parameter sweep (Experiment 1) which varies continuous hyperparameters, this script tests **discrete design decisions**: "Is this component necessary?" - -### Key differences from the other benchmark scripts - -| Feature | `run_param_benchmark_2gpu.sh` / `run_robustness_ablation_2gpu.sh` | `run_ablation_study.sh` | -|---------|------------------------------------------------------------------|------------------------| -| GPU handling | Hardcoded 2 GPUs (`GPU_A`, `GPU_B`) | Auto-detects N GPUs (round-robin distribution) | -| Job spec | 10 pipe-delimited fields | 21 fields (adds loss weights, architecture, features, LR) | -| Focus | Hyperparameter sensitivity & stability | Component necessity & design decisions | -| Block toggles | `RUN_INTERACTION_GRID`, `RUN_STRESS_TESTS` | 6 toggles: `RUN_LOSS_ABLATION`, `RUN_SGLOSS_ABLATION`, `RUN_ALIGNMENT_SWEEP`, `RUN_ARCH_ABLATION`, `RUN_PREDICTION_ABLATION`, `RUN_LR_ABLATION` | - -### Anchor configuration - -All ablation jobs start from the current best ("anchor") configuration and modify exactly one aspect: - -| Parameter | Anchor value | -|-----------|-------------| -| `use_3d` | `true` | -| `expansion_ratio` | `2.5` | -| `tx_max_k` | `5` | -| `tx_max_dist` | `20` | -| `n_mid_layers` | `2` | -| `n_heads` | `4` | -| `hidden_channels` / `out_channels` | `64` / `64` | -| `sg_loss_type` | `triplet` | -| `tx_weight_end` / `bd_weight_end` / `sg_weight_end` | `1.0` / `1.0` / `0.5` | -| `alignment_loss` | `true` (weight `0.03`) | -| `positional_embeddings` | `true` | -| `normalize_embeddings` | `true` | -| `cells_representation` | `pca` | -| `learning_rate` | `1e-3` | -| `prediction_mode` | `nucleus` | - -### Block A: Loss Decomposition (6 jobs) - -Tests every meaningful subset of the 4 loss terms to determine which are necessary: - -| Job | sg_loss | tx_triplet | bd_metric | alignment | Question | -|-----|:-------:|:----------:|:---------:|:---------:|----------| -| `abl_sg_only` | ON | - | - | - | Is the segmentation loss alone sufficient? | -| `abl_sg_tx` | ON | ON | - | - | Does transcript clustering help? | -| `abl_sg_bd` | ON | - | ON | - | Does boundary clustering help? | -| `abl_sg_tx_bd` | ON | ON | ON | - | Full v1 loss (no alignment) -- the pre-alignment baseline | -| `abl_sg_align` | ON | - | - | ON | Can alignment replace the triplet losses? 
| -| `abl_full` | ON | ON | ON | ON | Anchor baseline -- should be best | - -**Why:** The multi-task loss has 4 components with scheduled weights. We don't know if the triplet/metric losses for transcript and boundary clustering are necessary, or if they just slow convergence. If `sg_only` performs nearly as well as `full`, the training pipeline can be simplified significantly. - -### Block B: Segmentation Loss Type (2 jobs) - -| Job | sg_loss_type | Notes | -|-----|-------------|-------| -| `abl_sgloss_triplet` | triplet | Current default (margin-based) | -| `abl_sgloss_bce` | bce | Binary cross-entropy (v0.1.0 approach) | - -**Why:** Direct comparison on the same data reveals which formulation produces better assignment boundaries. - -### Block C: Alignment Weight Sweep (5 jobs) - -| Job | alignment_weight_end | Notes | -|-----|---------------------|-------| -| `abl_aw_0` | 0.0 | No alignment (control) | -| `abl_aw_001` | 0.01 | Light regularization | -| `abl_aw_003` | 0.03 | Current default | -| `abl_aw_01` | 0.1 | Strong regularization | -| `abl_aw_03` | 0.3 | Very strong -- may over-regularize | - -**Why:** The alignment loss weight was chosen somewhat arbitrarily. This sweep identifies the sweet spot. If 0.1 beats 0.03 on MECR without hurting assigned %, we should increase it. - -### Block D: Architecture Ablation (10 jobs) - -| Job | What changes | Question | -|-----|-------------|----------| -| `abl_depth_0` | 0 mid layers (in+out only) | Can a non-message-passing encoder segment? | -| `abl_depth_1` | 1 mid layer | Is 2 layers deeper than needed? | -| `abl_depth_3` | 3 mid layers | Does more depth help or over-smooth? | -| `abl_width_32` | 32/32 hidden/out | Can a 4x smaller model match the default? | -| `abl_width_128` | 128/128 hidden/out | Does 4x more capacity help? | -| `abl_heads_1` | 1 attention head | Is multi-head attention necessary? | -| `abl_heads_8` | 8 attention heads | Diminishing returns from more heads? | -| `abl_no_pos` | No positional embeddings | Are spatial encodings redundant given graph structure? | -| `abl_no_norm` | No embedding normalization | Does L2 normalization help or constrain? | -| `abl_morph` | Morphology cell features | Are polygon-derived features better than PCA? | - -**Why:** Each tests whether a specific design choice is earning its complexity. Findings directly inform model simplification or capacity recommendations. - -### Block E: Prediction Mode (2 jobs) - -| Job | prediction_mode | Notes | -|-----|----------------|-------| -| `abl_pred_cell` | cell | All transcripts within cell boundary for training edges | -| `abl_pred_uniform` | uniform | Uniform sampling around boundary | - -The anchor uses `nucleus` mode. These test whether alternative prediction graph construction strategies improve or degrade quality. - -### Block F: Learning Rate (3 jobs) - -| Job | learning_rate | Notes | -|-----|--------------|-------| -| `abl_lr_3e4` | 3e-4 | Conservative (slower convergence) | -| `abl_lr_3e3` | 3e-3 | Aggressive (faster, riskier) | -| `abl_lr_1e2` | 1e-2 | Very aggressive (may diverge) | - -The anchor uses `1e-3`. This identifies whether the learning rate is well-tuned or if training could be faster. - -### Total jobs - -**28 jobs** across 6 blocks. Each block can be toggled independently. Fits in one overnight session on 2+ GPUs. - -### GPU auto-detection - -The script automatically detects available GPUs: - -1. If `CUDA_VISIBLE_DEVICES` is set, counts the comma-separated IDs -2. Otherwise, queries `nvidia-smi --list-gpus` -3. 
Falls back to 1 GPU if neither is available -4. Can be overridden with `NUM_GPUS=N` - -Jobs are distributed round-robin across all detected GPUs and launched as parallel background processes. - -### Usage - -```bash -# Full ablation (auto-detect GPUs) -bash scripts/run_ablation_study.sh - -# Dry run -- prints job plan and exits -DRY_RUN=1 bash scripts/run_ablation_study.sh - -# Run only loss and architecture blocks on 4 GPUs -RUN_SGLOSS_ABLATION=0 RUN_ALIGNMENT_SWEEP=0 RUN_PREDICTION_ABLATION=0 \ -RUN_LR_ABLATION=0 NUM_GPUS=4 bash scripts/run_ablation_study.sh - -# Override anchor values -ANCHOR_N_HEADS=2 ANCHOR_EXPANSION=3.0 bash scripts/run_ablation_study.sh -``` - -### Recovery and fault tolerance - -Identical to the other benchmark scripts: OOM predict fallback, ancdata retry with reduced workers, timeout enforcement, and a post-run recovery pass that attempts predict-only from saved checkpoints. - ---- - -## Validation Metrics (`build_benchmark_validation_table.sh`) - -After all runs complete, this script calls `segger validate` on every segmentation output and collects metrics into a single TSV: - -| Metric | Direction | What it measures | -|--------|-----------|-----------------| -| **assigned_pct** | higher is better | Fraction of transcripts assigned to a cell (sensitivity) | -| **MECR** | lower is better | Mutually Exclusive Co-expression Rate -- are biologically impossible gene pairs showing up in the same cell? (specificity) | -| **contamination_pct** | lower is better | Fraction of cells with border contamination from neighbors | -| **TCO** | higher is better | Transcript-Centroid Offset -- how well do assigned transcripts cluster toward their cell center | -| **doublet_pct** | lower is better | Fraction of cells that look like merged doublets | - -### Reference baselines - -The script also builds **10x default segmentations** (cell-level and nucleus-only) on the same transcript universe as Segger. This provides an apples-to-apples comparison: "How does Segger compare to the manufacturer's built-in segmentation?" - -`build_default_10x_reference_artifacts.py` handles this by: -1. Reading the raw `transcripts.parquet` from the Xenium dataset -2. Filtering to the same `row_index` universe that Segger used -3. Using the 10x `cell_id` column (all transcripts) or `overlaps_nucleus` (nuclear only) as the assignment -4. Building a matching AnnData for metric computation - -### Incremental computation - -The validation table is **incremental** -- it reuses existing rows if the metric schema version, input paths, and reference universe haven't changed. This means you can re-run after fixing a failed job without recomputing everything. - ---- - -## Dashboard (`benchmark_status_dashboard.sh`) - -A terminal tool that reads the job plan, GPU summary files, and validation TSV to show: - -- Progress bar (done/total) -- State counts (running, pending, failed, done) -- Failure categorization (OOM, timeout, ancdata errors) -- Ranked validation metrics table with bold highlighting on top-2 performers -- Running/failed/retried job details - -Supports `--watch N` for auto-refresh every N seconds during overnight runs. - ---- - -## PDF Report (`build_benchmark_pdf_report.py`) - -Generates a publication-quality multi-page PDF: - -| Page | Content | -|------|---------| -| **Bar charts** | All 6 metrics side-by-side for every run, ranked by overall score. Segger runs in a blue gradient (darker = better), 10x references in orange. 
| -| **Scatter plots** | Sensitivity vs Contamination, Sensitivity vs MECR -- visualizes the trade-off frontier. | -| **Heatmap** | Normalized 0-1 metric matrix across all runs (cividis colormap). Quick visual comparison. | -| **UMAP panels** | 6 panels showing cell embedding structure for 2 references, 2 best Segger, 2 worst Segger runs. Uses scanpy or sg_utils for dimensionality reduction. | -| **FOV panels** | Small field-of-view cutouts showing actual cell boundaries (convex hulls) overlaid on transcript positions. Compares how different configurations segment the same tissue region. | - ---- - -## Why This Experimental Design - -### The scientific logic - -Segger frames cell segmentation as **link prediction on a heterogeneous graph**. The quality of segmentation depends on: - -1. **Graph topology** -- which transcripts and boundaries are connected (controlled by `expansion_ratio`, `tx_max_k`, `tx_max_dist`, `use_3d`) -2. **Model capacity** -- how the GNN processes the graph (controlled by `n_mid_layers`, `n_heads`) -3. **Training signal** -- what the loss function optimizes (controlled by `alignment_loss`) -4. **Post-processing** -- how results are filtered (controlled by `cells_min_counts`) - -The experiments systematically vary each of these four aspects: - -- **OFAT sweep** identifies which knobs matter most (often expansion ratio and tx_max_dist dominate) -- **Interaction grid** checks if the top parameters synergize or conflict -- **Stability repeats** quantify noise so we know if a 2% MECR improvement is real or random -- **Stress tests** reveal failure modes for the recommended configuration -- **10x references** provide a competitive baseline -- "is Segger actually better?" - -### The engineering logic - -- **N-GPU parallelism** distributes jobs round-robin across all available GPUs (auto-detected or overridable) -- **OOM fallback** (predict from last checkpoint) salvages partially-trained runs -- **Ancdata retry** (reduce dataloader workers) handles a known PyTorch multiprocessing bug -- **Timeout enforcement** (90 min default) prevents a single hung job from blocking the entire queue -- **Post-run recovery pass** catches jobs that failed during prediction but left a usable checkpoint -- **Incremental validation** avoids recomputing expensive metrics when only a few jobs were re-run -- **TSV-based outputs** integrate easily with downstream analysis notebooks diff --git a/scripts/presentation/experiments_plan.md b/scripts/presentation/experiments_plan.md deleted file mode 100644 index 85a20c6..0000000 --- a/scripts/presentation/experiments_plan.md +++ /dev/null @@ -1,327 +0,0 @@ -# Ablation & Extended Experiment Plan - -## Motivation - -The current benchmarks (see `experiments.md`) answer "which hyperparameter values work best?" via OFAT sweeps and interaction grids. What they **don't** answer is: - -- Which **architectural components** are actually necessary? -- Which **loss terms** contribute signal vs add noise? -- Does the model **generalize** across tissues and platforms? -- Where are the **failure modes**? - -This plan proposes a structured ablation study organized into 5 tiers, from highest expected impact to exploratory. Each experiment isolates one design decision and measures the delta on our 5 validation metrics (assigned %, MECR, contamination, TCO, doublet %). - ---- - -## Tier 1: Loss Function Ablation - -These are the most informative experiments because they test whether each loss term is earning its weight. - -### 1A. 
Full loss decomposition - -Train with every possible subset of the 4 loss terms: - -| Experiment | tx_triplet | bd_metric | sg_loss | alignment | Expected insight | -|-----------|:---:|:---:|:---:|:---:|---| -| `abl_sg_only` | - | - | ON | - | Minimum viable loss -- is the segmentation loss alone sufficient? | -| `abl_sg+tx` | ON | - | ON | - | Does transcript clustering help? | -| `abl_sg+bd` | - | ON | ON | - | Does boundary clustering help? | -| `abl_sg+tx+bd` | ON | ON | ON | - | Full v1 loss (no alignment) -- the pre-alignment baseline | -| `abl_full` | ON | ON | ON | ON | Current default -- should be best, or alignment is hurting | -| `abl_sg+align` | - | - | ON | ON | Can alignment replace the triplet losses entirely? | - -**Why:** The multi-task loss has 4 components with scheduled weights. We don't know if the triplet/metric losses for transcript and boundary clustering are actually necessary, or if they just slow convergence. If `sg_only` performs nearly as well as `full`, we can simplify the training pipeline significantly. - -### 1B. Segmentation loss type - -| Experiment | sg_loss_type | Notes | -|-----------|-------------|-------| -| `abl_sg_triplet` | triplet | Current default (margin-based) | -| `abl_sg_bce` | bce | Binary cross-entropy with random negatives | - -**Why:** The code supports both but defaults to triplet. BCE was the v0.1.0 approach. Direct comparison on the same data reveals which formulation produces better assignment boundaries. - -### 1C. Alignment loss strength - -| Experiment | alignment_weight_end | Notes | -|-----------|---------------------|-------| -| `abl_align_0` | 0.0 | No alignment (control) | -| `abl_align_001` | 0.01 | Light regularization | -| `abl_align_003` | 0.03 | Current default | -| `abl_align_01` | 0.1 | Strong regularization | -| `abl_align_03` | 0.3 | Very strong -- likely over-regularizes | - -**Why:** The alignment loss weight was chosen somewhat arbitrarily. This sweep identifies the sweet spot. If 0.1 beats 0.03 on MECR without hurting assigned %, we should increase it. If 0.01 is equivalent to 0.03, we're wasting gradient signal. - -### 1D. Loss weight schedule - -| Experiment | Schedule | Notes | -|-----------|---------|-------| -| `abl_sched_cosine` | Cosine ramp (current) | sg: 0 -> 0.5 over training | -| `abl_sched_fixed` | Fixed weights | sg: 0.5 from epoch 0 | -| `abl_sched_linear` | Linear ramp | sg: 0 -> 0.5 linearly | -| `abl_sched_late` | Late activation | sg: 0 for first 50% of training, then 0.5 | - -**Why:** The cosine ramp was designed to let the encoder warm up before the segmentation loss kicks in. If fixed weights perform equally, the schedule adds unnecessary complexity. - ---- - -## Tier 2: Graph Topology Ablation - -The heterogeneous graph has 3 edge types. Each encodes different information. Removing them reveals what the GNN actually needs. - -### 2A. Edge type removal - -| Experiment | tx-tx | tx-bd (ref) | tx-bd (pred) | Expected insight | -|-----------|:---:|:---:|:---:|---| -| `abl_edges_full` | ON | ON | ON | Baseline (all edges) | -| `abl_edges_no_txtx` | - | ON | ON | Is local transcript context necessary? | -| `abl_edges_no_ref` | ON | - | ON | Can the model learn without reference segmentation? | -| `abl_edges_txtx_only` | ON | - | - | Pure transcript clustering (no boundary info) | - -**Why:** tx-tx edges are the most expensive to construct (KDTree over millions of transcripts). If removing them doesn't hurt, we can dramatically speed up data processing. 
Conversely, if they're essential, we know local context is a critical signal. - -### 2B. tx-tx graph density - -| Experiment | tx_max_k | tx_max_dist | Effective density | -|-----------|---------|------------|-------------------| -| `abl_txtx_sparse` | 3 | 3.0 | Very sparse local context | -| `abl_txtx_default` | 5 | 5.0 | Current default | -| `abl_txtx_medium` | 10 | 10.0 | Medium density | -| `abl_txtx_dense` | 20 | 20.0 | Dense (high memory) | - -**Why:** These two parameters interact -- both must be large for a dense graph. This tests whether a minimal local graph (k=3, d=3) captures the same information as the more expensive default. - -### 2C. Prediction graph mode - -| Experiment | prediction_mode | Notes | -|-----------|----------------|-------| -| `abl_pred_nucleus` | nucleus | Only nuclear transcripts for training edges | -| `abl_pred_cell` | cell | All transcripts within cell boundary | -| `abl_pred_uniform` | uniform | Uniform sampling around boundary | - -**Why:** The prediction graph mode controls which transcript-boundary edges are used during training. Nucleus mode is conservative (high-confidence assignments), cell mode is permissive (more edges, potentially noisier labels). - ---- - -## Tier 3: Feature & Embedding Ablation - -Tests whether the input representations matter or if the GNN can learn from raw structure alone. - -### 3A. Gene embedding source - -| Experiment | Gene features | Notes | -|-----------|--------------|-------| -| `abl_gene_scrnaseq` | scRNA PCA embeddings | Current default (cell-type proportion vectors) | -| `abl_gene_onehot` | One-hot encoding | No biological prior, pure token identity | -| `abl_gene_random` | Random fixed vectors | Control -- no gene-level information at all | -| `abl_gene_learned` | Trainable from scratch | Let the GNN discover gene relationships | - -**Why:** The scRNA-derived gene embeddings encode cell-type co-expression priors. If one-hot performs similarly, the embeddings aren't adding value. If random vectors work, the GNN is learning purely from spatial structure -- which would be a significant finding. - -### 3B. Boundary (cell) features - -| Experiment | cells_representation_mode | Notes | -|-----------|--------------------------|-------| -| `abl_bd_pca` | pca (dim=128) | Current default -- gene expression PCA | -| `abl_bd_morph` | morphology | Polygon-derived features (area, convexity, elongation) | -| `abl_bd_none` | zeros | No boundary features (structure only) | -| `abl_bd_pca_small` | pca (dim=32) | Reduced dimensionality | - -**Why:** Boundary features are expensive to compute (morphology requires polygon operations; PCA requires gene counting). If zeros work, the model learns cell identity purely from connected transcripts. - -### 3C. Positional embeddings - -| Experiment | use_positional_embeddings | Notes | -|-----------|--------------------------|-------| -| `abl_pos_on` | True | Current default -- sinusoidal 2D encoding | -| `abl_pos_off` | False | No spatial encoding in embeddings | - -**Why:** The graph structure already encodes spatial relationships through edge construction. Positional embeddings may be redundant. If removing them doesn't hurt, it simplifies the model. - -### 3D. 
Embedding normalization - -| Experiment | normalize_embeddings | Notes | -|-----------|---------------------|-------| -| `abl_norm_on` | True | L2-normalize output embeddings (current) | -| `abl_norm_off` | False | Raw unnormalized embeddings | - -**Why:** L2 normalization constrains embeddings to a unit hypersphere, which affects how cosine similarity (used at prediction) distributes. Without normalization, the model can use magnitude as an additional signal. - ---- - -## Tier 4: Architecture Ablation - -Tests structural choices in the GNN itself. - -### 4A. GNN depth - -| Experiment | n_mid_layers | Total layers | Receptive field | -|-----------|-------------|-------------|-----------------| -| `abl_depth_0` | 0 | 2 (in+out only) | 1-hop | -| `abl_depth_1` | 1 | 3 | 2-hop | -| `abl_depth_2` | 2 | 4 | 3-hop (default) | -| `abl_depth_3` | 3 | 5 | 4-hop | -| `abl_depth_4` | 4 | 6 | 5-hop | - -**Why:** Deeper GNNs have larger receptive fields but risk over-smoothing (all node embeddings converge). For cell segmentation, the optimal depth depends on cell size relative to transcript density. 0 layers tests whether a non-message-passing encoder can segment at all. - -### 4B. Model width - -| Experiment | hidden_channels | out_channels | Parameters (approx) | -|-----------|----------------|-------------|---------------------| -| `abl_width_32` | 32 | 32 | ~25% of default | -| `abl_width_64` | 64 | 64 | Default | -| `abl_width_128` | 128 | 128 | ~4x default | -| `abl_width_256` | 256 | 256 | ~16x default | - -**Why:** Determines the capacity vs efficiency tradeoff. If 32 channels perform within 2% of 64, we can deploy a much faster model. - -### 4C. Attention heads - -| Experiment | n_heads | Notes | -|-----------|--------|-------| -| `abl_heads_1` | 1 | Single-head attention (simplest) | -| `abl_heads_2` | 2 | Current default | -| `abl_heads_4` | 4 | Double capacity | -| `abl_heads_8` | 8 | Diminishing returns? | - -**Why:** Multi-head attention allows the model to attend to different relationship types simultaneously. But for a graph with only 3 edge types, 8 heads may be overkill. - -### 4D. Skip connections - -| Experiment | skip_connections | Notes | -|-----------|-----------------|-------| -| `abl_skip_none` | None | Current default (despite class name "SkipGAT") | -| `abl_skip_residual` | Residual add | Standard ResNet-style | - -**Why:** The model class is called SkipGAT but doesn't implement skip connections. Adding them could help with gradient flow in deeper models and would test whether the current architecture is leaving performance on the table. - ---- - -## Tier 5: Generalization & Cross-Dataset - -Tests whether findings transfer beyond the Xenium pancreas dataset. - -### 5A. Cross-tissue (same platform) - -| Experiment | Dataset | Tissue | Density | Notes | -|-----------|---------|--------|---------|-------| -| `gen_pancreas` | Xenium pancreas (Mossi) | Pancreas | Medium | Current benchmark dataset | -| `gen_brain` | Xenium brain | Brain cortex | High | Dense, many cell types | -| `gen_lung` | Xenium lung | Lung | Mixed | Sparse stroma + dense epithelium | -| `gen_tumor` | Xenium tumor (CRC) | Colorectal | Variable | Disordered tissue, heterogeneous | - -**Why:** All current experiments use one dataset. If the optimal hyperparameters shift dramatically between tissues, we need tissue-specific recommendations or a more robust default. - -### 5B. 
Cross-platform - -| Experiment | Platform | Key differences | -|-----------|---------|----------------| -| `gen_xenium` | 10x Xenium | High QV scores, nuclear boundaries available | -| `gen_merscope` | Vizgen MERSCOPE | FOV-based stitching, polygon boundaries | -| `gen_cosmx` | NanoString CosMx | Different noise profile, z-stacks | - -**Why:** Segger claims platform-agnostic segmentation. Cross-platform experiments validate this claim and identify platform-specific failure modes. - -### 5C. Data efficiency - -| Experiment | Subsample % | Transcripts (approx) | Notes | -|-----------|------------|----------------------|-------| -| `gen_full` | 100% | ~5M | Full dataset | -| `gen_50pct` | 50% | ~2.5M | Moderate reduction | -| `gen_25pct` | 25% | ~1.25M | Aggressive reduction | -| `gen_10pct` | 10% | ~500K | Stress test | - -**Why:** Lower-depth sequencing or smaller gene panels produce fewer transcripts per cell. This tests how gracefully Segger degrades and identifies the minimum data requirement for useful segmentation. - -### 5D. Training data vs inference data shift - -| Experiment | Train on | Predict on | Notes | -|-----------|---------|-----------|-------| -| `gen_same` | Pancreas | Pancreas | Standard (baseline) | -| `gen_transfer_brain` | Pancreas | Brain | Zero-shot cross-tissue | -| `gen_transfer_platform` | Xenium | MERSCOPE | Zero-shot cross-platform | -| `gen_finetune_brain` | Pancreas -> Brain (finetune) | Brain | Few-epoch adaptation | - -**Why:** Tests whether Segger learns general spatial segmentation rules or memorizes pancreas-specific patterns. Transfer learning results determine whether per-tissue training is required. - ---- - -## Implementation Priority - -### Phase 1 (implemented in `run_ablation_study.sh`) - -These are now implemented as the 6 blocks of `scripts/run_ablation_study.sh`: - -1. **1A** Loss decomposition (6 runs) -- Block A: `abl_sg_only`, `abl_sg_tx`, `abl_sg_bd`, `abl_sg_tx_bd`, `abl_sg_align`, `abl_full` -2. **1B** Triplet vs BCE (2 runs) -- Block B: `abl_sgloss_triplet`, `abl_sgloss_bce` -3. **1C** Alignment weight sweep (5 runs) -- Block C: `abl_aw_0` through `abl_aw_03` -4. **3C** Positional embeddings on/off (1 run) -- Block D: `abl_no_pos` -5. **3D** Embedding normalization on/off (1 run) -- Block D: `abl_no_norm` -6. **4A** GNN depth (3 runs) -- Block D: `abl_depth_0`, `abl_depth_1`, `abl_depth_3` -7. **4B** Model width (2 runs) -- Block D: `abl_width_32`, `abl_width_128` -8. **4C** Attention heads (2 runs) -- Block D: `abl_heads_1`, `abl_heads_8` -9. **3B** Boundary features / morphology (1 run) -- Block D: `abl_morph` -10. **2C** Prediction graph mode (2 runs) -- Block E: `abl_pred_cell`, `abl_pred_uniform` -11. **Learning rate** (3 runs) -- Block F: `abl_lr_3e4`, `abl_lr_3e3`, `abl_lr_1e2` - -**Total: 28 runs**, fits in one overnight session on 2+ GPUs. The script auto-detects available GPUs (N-way round-robin) and each block can be toggled independently. - -**Not yet implemented from this tier:** -- **2A** Edge type removal (requires data module changes to selectively drop edge types) -- **1D** Loss weight schedule variants (requires new scheduler options) - -### Phase 2 (requires code changes) - -6. **3A** Gene embedding ablation -- needs a `--gene-embedding-mode` CLI parameter -7. **3B** Boundary feature ablation -- needs a null/zero mode for cell features -8. **4D** Skip connections -- needs `ist_encoder.py` modification -9. **1D** Loss schedule variants -- needs new scheduler options - -### Phase 3 (requires new datasets) - -10. 
**5A-5D** Cross-tissue and cross-platform generalization -11. **5C** Subsampling experiments - ---- - -## Expected Outcomes - -### What would change our recommendations - -| Finding | Implication | -|---------|------------| -| `sg_only` matches `full` on all metrics | Simplify to single-loss training, 3x faster | -| Alignment weight 0.1 >> 0.03 on MECR | Increase default alignment strength | -| Removing tx-tx edges doesn't hurt | Skip KDTree construction, 2x faster data prep | -| One-hot gene embeddings match scRNA PCA | Remove scRNA reference dependency (major UX win) | -| 32 channels match 64 channels | Deploy 4x smaller model | -| Cross-tissue transfer fails | Need per-tissue training protocol | -| 10% subsample still works | Segger viable for low-depth experiments | - -### What would confirm our design - -| Finding | Implication | -|---------|------------| -| Full multi-task loss >> sg_only | Multi-task learning is justified | -| Alignment improves MECR without hurting assigned % | ME-gene loss is well-calibrated | -| tx-tx edges significantly improve metrics | Local context is essential | -| scRNA embeddings >> one-hot | Biological priors are valuable | -| Stable across 3 repeats (CV < 5%) | Results are trustworthy | -| Cross-tissue transfer works | Architecture is general | - ---- - -## Reporting - -All ablation results flow through the same validation pipeline: - -``` -Train + Predict → segger_segmentation.parquet - → build_benchmark_validation_table.sh (metrics) - → benchmark_status_dashboard.sh (live monitoring) - → build_benchmark_pdf_report.py (publication figures) -``` - -The PDF report automatically ranks all runs (including ablations) by an overall normalized score across assigned %, MECR, contamination, TCO, and doublet %. Ablation results will appear directly in the bar charts and heatmaps alongside the parameter sweep results. diff --git a/scripts/run_ablation_study.sh b/scripts/run_ablation_study.sh deleted file mode 100755 index f3a42c2..0000000 --- a/scripts/run_ablation_study.sh +++ /dev/null @@ -1,1138 +0,0 @@ -#!/usr/bin/env bash -set -u -o pipefail - -# ------------------------------------------------------------------------- -# Segger comprehensive ablation study (auto-detect GPUs, N-way parallel) -# ------------------------------------------------------------------------- -# Systematically removes or swaps individual components (loss terms, -# architecture choices, features) to measure their contribution. 
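-# With all blocks enabled this produces 28 jobs, grouped into blocks A-F
-# (loss decomposition, segmentation-loss type, alignment-weight sweep,
-# architecture, prediction mode, learning rate); each block can be switched
-# off individually via the RUN_* toggles listed below.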
-# -# Usage: -# bash scripts/run_ablation_study.sh -# -# Optional overrides (environment variables): -# INPUT_DIR=data/xe_pancreas_mossi/ -# OUTPUT_ROOT=./results/mossi_ablation_study -# NUM_GPUS= # Override auto-detected GPU count -# N_EPOCHS=20 -# RESUME_IF_EXISTS=1 -# DRY_RUN=0 -# SEGMENT_TIMEOUT_MIN=90 -# ALIGNMENT_SCRNA_REFERENCE_PATH=data/ref_pancreas.h5ad -# ALIGNMENT_SCRNA_CELLTYPE_COLUMN=cell_type -# SEGMENT_NUM_WORKERS=8 -# SEGMENT_ANC_RETRY_WORKERS=0 -# TORCH_SHARING_STRATEGY=file_system -# RUN_VALIDATION_TABLE=1 -# VALIDATION_SCRIPT=scripts/build_benchmark_validation_table.sh -# -# Block toggles (set to 0 to skip): -# RUN_LOSS_ABLATION=1 -# RUN_SGLOSS_ABLATION=1 -# RUN_ALIGNMENT_SWEEP=1 -# RUN_ARCH_ABLATION=1 -# RUN_PREDICTION_ABLATION=1 -# RUN_LR_ABLATION=1 -# ------------------------------------------------------------------------- - -timestamp() { - date '+%Y-%m-%d %H:%M:%S' -} - -# ------------------------------------------------------------------------- -# GPU detection -# ------------------------------------------------------------------------- -detect_gpus() { - if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then - echo $(( $(echo "${CUDA_VISIBLE_DEVICES}" | tr ',' '\n' | wc -l) )) - elif command -v nvidia-smi >/dev/null 2>&1; then - nvidia-smi --list-gpus 2>/dev/null | wc -l - else - echo 1 - fi -} - -NUM_GPUS="${NUM_GPUS:-$(detect_gpus)}" -if [[ "${NUM_GPUS}" -lt 1 ]]; then - NUM_GPUS=1 -fi - -# Build array of GPU IDs (0..N-1, or from CUDA_VISIBLE_DEVICES). -GPU_IDS=() -if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then - IFS=',' read -ra GPU_IDS <<< "${CUDA_VISIBLE_DEVICES}" -else - for ((g = 0; g < NUM_GPUS; g++)); do - GPU_IDS+=("${g}") - done -fi - -# ------------------------------------------------------------------------- -# Paths and defaults -# ------------------------------------------------------------------------- -DEFAULT_INPUT_DIR="data/xe_pancreas_mossi/" -INPUT_DIR="${INPUT_DIR:-${DEFAULT_INPUT_DIR}}" -OUTPUT_ROOT="${OUTPUT_ROOT:-./results/mossi_ablation_study}" - -if [[ "${INPUT_DIR}" == "${DEFAULT_INPUT_DIR}" ]] && \ - [[ ! -d "${INPUT_DIR}" ]] && \ - [[ -d "../data/xe_pancreas_mossi/" ]]; then - INPUT_DIR="../data/xe_pancreas_mossi/" -fi - -N_EPOCHS="${N_EPOCHS:-20}" -PREDICTION_MODE="${PREDICTION_MODE:-nucleus}" - -BOUNDARY_METHOD="${BOUNDARY_METHOD:-convex_hull}" -BOUNDARY_VOXEL_SIZE="${BOUNDARY_VOXEL_SIZE:-5}" -XENIUM_NUM_WORKERS="${XENIUM_NUM_WORKERS:-8}" - -RESUME_IF_EXISTS="${RESUME_IF_EXISTS:-1}" -DRY_RUN="${DRY_RUN:-0}" -PREDICT_FALLBACK_ON_OOM="${PREDICT_FALLBACK_ON_OOM:-1}" -SEGMENT_TIMEOUT_MIN="${SEGMENT_TIMEOUT_MIN:-90}" -SEGMENT_TIMEOUT_SEC=$((SEGMENT_TIMEOUT_MIN * 60)) -SEGMENT_NUM_WORKERS="${SEGMENT_NUM_WORKERS:-8}" -SEGMENT_ANC_RETRY_WORKERS="${SEGMENT_ANC_RETRY_WORKERS:-0}" -TORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY:-file_system}" - -# Alignment defaults (needed by anchor and alignment sweep). -ALIGNMENT_LOSS_WEIGHT_START="${ALIGNMENT_LOSS_WEIGHT_START:-0.0}" -ALIGNMENT_ME_GENE_PAIRS_PATH="${ALIGNMENT_ME_GENE_PAIRS_PATH:-}" -ALIGNMENT_SCRNA_REFERENCE_PATH="${ALIGNMENT_SCRNA_REFERENCE_PATH:-data/ref_pancreas.h5ad}" -ALIGNMENT_SCRNA_CELLTYPE_COLUMN="${ALIGNMENT_SCRNA_CELLTYPE_COLUMN:-cell_type}" - -if [[ "${ALIGNMENT_SCRNA_REFERENCE_PATH}" == "data/ref_pancreas.h5ad" ]] && \ - [[ ! 
-f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && \ - [[ -f "../data/ref_pancreas.h5ad" ]]; then - ALIGNMENT_SCRNA_REFERENCE_PATH="../data/ref_pancreas.h5ad" -fi - -# ------------------------------------------------------------------------- -# Anchor configuration (defaults matching current best config) -# ------------------------------------------------------------------------- -ANCHOR_USE_3D="${ANCHOR_USE_3D:-true}" -ANCHOR_EXPANSION="${ANCHOR_EXPANSION:-2.5}" -ANCHOR_TX_K="${ANCHOR_TX_K:-5}" -ANCHOR_TX_DIST="${ANCHOR_TX_DIST:-20}" -ANCHOR_N_LAYERS="${ANCHOR_N_LAYERS:-2}" -ANCHOR_N_HEADS="${ANCHOR_N_HEADS:-4}" -ANCHOR_CELLS_MIN="${ANCHOR_CELLS_MIN:-5}" -ANCHOR_MIN_QV="${ANCHOR_MIN_QV:-0}" -ANCHOR_ALIGNMENT="${ANCHOR_ALIGNMENT:-true}" -ANCHOR_SG_LOSS="${ANCHOR_SG_LOSS:-triplet}" -ANCHOR_HIDDEN="${ANCHOR_HIDDEN:-64}" -ANCHOR_OUT="${ANCHOR_OUT:-64}" -ANCHOR_TX_WEIGHT="${ANCHOR_TX_WEIGHT:-1.0}" -ANCHOR_BD_WEIGHT="${ANCHOR_BD_WEIGHT:-1.0}" -ANCHOR_SG_WEIGHT="${ANCHOR_SG_WEIGHT:-0.5}" -ANCHOR_ALIGN_WEIGHT="${ANCHOR_ALIGN_WEIGHT:-0.03}" -ANCHOR_POS_EMB="${ANCHOR_POS_EMB:-true}" -ANCHOR_NORM_EMB="${ANCHOR_NORM_EMB:-true}" -ANCHOR_CELLS_REP="${ANCHOR_CELLS_REP:-pca}" -ANCHOR_LR="${ANCHOR_LR:-1e-3}" - -# ------------------------------------------------------------------------- -# Block toggles -# ------------------------------------------------------------------------- -RUN_LOSS_ABLATION="${RUN_LOSS_ABLATION:-1}" -RUN_SGLOSS_ABLATION="${RUN_SGLOSS_ABLATION:-1}" -RUN_ALIGNMENT_SWEEP="${RUN_ALIGNMENT_SWEEP:-1}" -RUN_ARCH_ABLATION="${RUN_ARCH_ABLATION:-1}" -RUN_PREDICTION_ABLATION="${RUN_PREDICTION_ABLATION:-1}" -RUN_LR_ABLATION="${RUN_LR_ABLATION:-1}" - -# ------------------------------------------------------------------------- -# Directories -# ------------------------------------------------------------------------- -RUNS_DIR="${OUTPUT_ROOT}/runs" -EXPORTS_DIR="${OUTPUT_ROOT}/exports" -LOGS_DIR="${OUTPUT_ROOT}/logs" -SUMMARY_DIR="${OUTPUT_ROOT}/summaries" -PLAN_FILE="${OUTPUT_ROOT}/job_plan.tsv" -RUN_VALIDATION_TABLE="${RUN_VALIDATION_TABLE:-1}" -VALIDATION_SCRIPT="${VALIDATION_SCRIPT:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/build_benchmark_validation_table.sh}" -VALIDATION_INCLUDE_DEFAULT_10X="${VALIDATION_INCLUDE_DEFAULT_10X:-true}" - -mkdir -p "${RUNS_DIR}" "${EXPORTS_DIR}" "${LOGS_DIR}" "${SUMMARY_DIR}" - -if [[ ! -d "${INPUT_DIR}" ]]; then - if [[ "${DRY_RUN}" == "1" ]]; then - echo "WARN: INPUT_DIR does not exist (dry run only): ${INPUT_DIR}" - else - echo "ERROR: INPUT_DIR does not exist: ${INPUT_DIR}" - exit 1 - fi -fi - -if [[ "${DRY_RUN}" != "1" ]] && ! command -v segger >/dev/null 2>&1; then - echo "ERROR: 'segger' command not found in PATH." - exit 1 -fi - -# Check alignment inputs (needed if any ablation uses alignment). -need_alignment_inputs=0 -if [[ "${RUN_LOSS_ABLATION}" == "1" ]] || \ - [[ "${RUN_ALIGNMENT_SWEEP}" == "1" ]] || \ - [[ "${ANCHOR_ALIGNMENT}" == "true" ]]; then - need_alignment_inputs=1 -fi - -if [[ "${need_alignment_inputs}" == "1" ]]; then - if [[ -z "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ -z "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - echo "ERROR: Alignment ablation requires ALIGNMENT_ME_GENE_PAIRS_PATH or ALIGNMENT_SCRNA_REFERENCE_PATH." - exit 1 - fi - if [[ "${DRY_RUN}" != "1" ]]; then - if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ ! -f "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then - echo "ERROR: ALIGNMENT_ME_GENE_PAIRS_PATH not found: ${ALIGNMENT_ME_GENE_PAIRS_PATH}" - exit 1 - fi - if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && [[ ! 
-f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - echo "ERROR: ALIGNMENT_SCRNA_REFERENCE_PATH not found: ${ALIGNMENT_SCRNA_REFERENCE_PATH}" - exit 1 - fi - fi -fi - -# ========================================================================= -# Job spec: extended 21-field pipe-delimited format -# ========================================================================= -# Fields: -# 1 job_name -# 2 use_3d -# 3 expansion -# 4 tx_k -# 5 tx_dist -# 6 n_layers -# 7 n_heads -# 8 cells_min_counts -# 9 min_qv -# 10 alignment_loss (true/false) -# 11 sg_loss_type (triplet/bce) -# 12 hidden_channels -# 13 out_channels -# 14 tx_weight_end -# 15 bd_weight_end -# 16 sg_weight_end -# 17 alignment_weight_end -# 18 positional_embeddings (true/false) -# 19 normalize_embeddings (true/false) -# 20 cells_representation (pca/morphology) -# 21 learning_rate -# ========================================================================= - -JOB_SPECS=() - -add_job() { - local job_name="$1" - local use_3d="${2:-${ANCHOR_USE_3D}}" - local expansion="${3:-${ANCHOR_EXPANSION}}" - local tx_k="${4:-${ANCHOR_TX_K}}" - local tx_dist="${5:-${ANCHOR_TX_DIST}}" - local n_layers="${6:-${ANCHOR_N_LAYERS}}" - local n_heads="${7:-${ANCHOR_N_HEADS}}" - local cells_min="${8:-${ANCHOR_CELLS_MIN}}" - local min_qv="${9:-${ANCHOR_MIN_QV}}" - local align="${10:-${ANCHOR_ALIGNMENT}}" - local sg_loss="${11:-${ANCHOR_SG_LOSS}}" - local hidden="${12:-${ANCHOR_HIDDEN}}" - local out="${13:-${ANCHOR_OUT}}" - local tx_w="${14:-${ANCHOR_TX_WEIGHT}}" - local bd_w="${15:-${ANCHOR_BD_WEIGHT}}" - local sg_w="${16:-${ANCHOR_SG_WEIGHT}}" - local align_w="${17:-${ANCHOR_ALIGN_WEIGHT}}" - local pos_emb="${18:-${ANCHOR_POS_EMB}}" - local norm_emb="${19:-${ANCHOR_NORM_EMB}}" - local cells_rep="${20:-${ANCHOR_CELLS_REP}}" - local lr="${21:-${ANCHOR_LR}}" - - JOB_SPECS+=("${job_name}|${use_3d}|${expansion}|${tx_k}|${tx_dist}|${n_layers}|${n_heads}|${cells_min}|${min_qv}|${align}|${sg_loss}|${hidden}|${out}|${tx_w}|${bd_w}|${sg_w}|${align_w}|${pos_emb}|${norm_emb}|${cells_rep}|${lr}") -} - -# Helper: add_job with only overridden fields (positional anchor defaults). -# Usage: add_ablation_job [field=value ...] -# This is a convenience wrapper; for clarity each block calls add_job directly. 
- -job_block() { - local job_name="$1" - case "${job_name}" in - abl_sg_*|abl_full) echo "loss_decomposition" ;; - abl_sgloss_*) echo "sg_loss_type" ;; - abl_aw_*) echo "alignment_sweep" ;; - abl_depth_*|abl_width_*|abl_heads_*|abl_no_pos|abl_no_norm|abl_morph) echo "architecture" ;; - abl_pred_*) echo "prediction_mode" ;; - abl_lr_*) echo "learning_rate" ;; - *) echo "other" ;; - esac -} - -build_jobs() { - # ------------------------------------------------------------------- - # Block A: Loss decomposition (6 jobs) - # ------------------------------------------------------------------- - if [[ "${RUN_LOSS_ABLATION}" == "1" ]]; then - # sg only: tx=0, bd=0, no alignment - add_job "abl_sg_only" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "false" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "0" "0" "${ANCHOR_SG_WEIGHT}" "0" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - # sg + transcript triplet: bd=0, no alignment - add_job "abl_sg_tx" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "false" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "0" "${ANCHOR_SG_WEIGHT}" "0" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - # sg + boundary metric: tx=0, no alignment - add_job "abl_sg_bd" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "false" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "0" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "0" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - # sg + both clustering: no alignment - add_job "abl_sg_tx_bd" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "false" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "0" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - # sg + alignment only: tx=0, bd=0 - add_job "abl_sg_align" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "true" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "0" "0" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - # full (anchor baseline) - add_job "abl_full" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - fi - - # ------------------------------------------------------------------- - # Block B: Segmentation loss type (2 jobs) - # ------------------------------------------------------------------- - if [[ "${RUN_SGLOSS_ABLATION}" == "1" 
]]; then - add_job "abl_sgloss_triplet" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "triplet" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - add_job "abl_sgloss_bce" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "bce" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - fi - - # ------------------------------------------------------------------- - # Block C: Alignment weight sweep (5 jobs) - # ------------------------------------------------------------------- - if [[ "${RUN_ALIGNMENT_SWEEP}" == "1" ]]; then - local aw_values=(0 0.01 0.03 0.1 0.3) - local aw_tags=(0 001 003 01 03) - local aw_i - for aw_i in "${!aw_values[@]}"; do - local aw="${aw_values[$aw_i]}" - local aw_tag="${aw_tags[$aw_i]}" - local aw_align="true" - if [[ "${aw}" == "0" ]]; then - aw_align="false" - fi - add_job "abl_aw_${aw_tag}" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${aw_align}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${aw}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - done - fi - - # ------------------------------------------------------------------- - # Block D: Architecture ablation (10 jobs) - # ------------------------------------------------------------------- - if [[ "${RUN_ARCH_ABLATION}" == "1" ]]; then - # Depth: 0, 1, 3 mid layers - for depth in 0 1 3; do - add_job "abl_depth_${depth}" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${depth}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - done - - # Width: 32/32 and 128/128 - for width in 32 128; do - add_job "abl_width_${width}" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${width}" "${width}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - done - - # Heads: 1 and 8 - for heads in 1 8; do - add_job "abl_heads_${heads}" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${heads}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" 
"${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - done - - # No positional embeddings - add_job "abl_no_pos" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "false" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - # No embedding normalization - add_job "abl_no_norm" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "false" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - # Morphology instead of PCA - add_job "abl_morph" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "morphology" "${ANCHOR_LR}" - fi - - # ------------------------------------------------------------------- - # Block E: Prediction mode (2 jobs) - # ------------------------------------------------------------------- - if [[ "${RUN_PREDICTION_ABLATION}" == "1" ]]; then - add_job "abl_pred_cell" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - - add_job "abl_pred_uniform" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "${ANCHOR_LR}" - fi - - # ------------------------------------------------------------------- - # Block F: Learning rate (3 jobs) - # ------------------------------------------------------------------- - if [[ "${RUN_LR_ABLATION}" == "1" ]]; then - add_job "abl_lr_3e4" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "3e-4" - - add_job "abl_lr_3e3" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" 
"${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "3e-3" - - add_job "abl_lr_1e2" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION}" "${ANCHOR_TX_K}" "${ANCHOR_TX_DIST}" \ - "${ANCHOR_N_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN}" "${ANCHOR_MIN_QV}" \ - "${ANCHOR_ALIGNMENT}" "${ANCHOR_SG_LOSS}" "${ANCHOR_HIDDEN}" "${ANCHOR_OUT}" \ - "${ANCHOR_TX_WEIGHT}" "${ANCHOR_BD_WEIGHT}" "${ANCHOR_SG_WEIGHT}" "${ANCHOR_ALIGN_WEIGHT}" \ - "${ANCHOR_POS_EMB}" "${ANCHOR_NORM_EMB}" "${ANCHOR_CELLS_REP}" "1e-2" - fi -} - -# ========================================================================= -# Helper functions (identical to run_robustness_ablation_2gpu.sh) -# ========================================================================= - -run_cmd() { - local log_file="$1" - shift - local -a cmd=("$@") - - { - printf '[%s] CMD:' "$(timestamp)" - printf ' %q' "${cmd[@]}" - printf '\n' - } >> "${log_file}" - - if [[ "${DRY_RUN}" == "1" ]]; then - return 0 - fi - - "${cmd[@]}" >> "${log_file}" 2>&1 -} - -run_cmd_with_timeout() { - local log_file="$1" - local timeout_seconds="$2" - shift 2 - local -a cmd=("$@") - - { - printf '[%s] CMD(timeout=%ss):' "$(timestamp)" "${timeout_seconds}" - printf ' %q' "${cmd[@]}" - printf '\n' - } >> "${log_file}" - - if [[ "${DRY_RUN}" == "1" ]]; then - return 0 - fi - - if [[ "${timeout_seconds}" -le 0 ]]; then - "${cmd[@]}" >> "${log_file}" 2>&1 - return $? - fi - - local start_ts now elapsed - local cmd_pid timed_out rc - timed_out=0 - start_ts="$(date +%s)" - - "${cmd[@]}" >> "${log_file}" 2>&1 & - cmd_pid=$! - - while kill -0 "${cmd_pid}" 2>/dev/null; do - now="$(date +%s)" - elapsed=$((now - start_ts)) - if (( elapsed >= timeout_seconds )); then - timed_out=1 - echo "[$(timestamp)] OOT: command exceeded ${timeout_seconds}s; terminating PID=${cmd_pid}" >> "${log_file}" - kill -TERM "${cmd_pid}" 2>/dev/null || true - pkill -TERM -P "${cmd_pid}" 2>/dev/null || true - sleep 5 - kill -KILL "${cmd_pid}" 2>/dev/null || true - pkill -KILL -P "${cmd_pid}" 2>/dev/null || true - break - fi - sleep 10 - done - - wait "${cmd_pid}" - rc=$? - if (( timed_out == 1 )); then - return 124 - fi - return "${rc}" -} - -is_oom_failure() { - local log_file="$1" - if [[ ! -f "${log_file}" ]]; then - return 1 - fi - local pattern="out of memory|cuda error: out of memory|cublas status alloc failed|cuda driver error.*memory" - if command -v rg >/dev/null 2>&1; then - rg -qi "${pattern}" "${log_file}" - else - grep -Eiq "${pattern}" "${log_file}" - fi -} - -is_ancdata_failure() { - local log_file="$1" - if [[ ! -f "${log_file}" ]]; then - return 1 - fi - local pattern="received [0-9]+ items of ancdata|multiprocessing/resource_sharer\\.py" - if command -v rg >/dev/null 2>&1; then - rg -qi "${pattern}" "${log_file}" - else - grep -Eiq "${pattern}" "${log_file}" - fi -} - -LAST_EXPORT_STATUS="ok" - -run_exports_for_job() { - local job_name="$1" - local seg_dir="$2" - local log_file="$3" - - local seg_file="${seg_dir}/segger_segmentation.parquet" - local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" - local anndata_file="${anndata_dir}/segger_segmentation.h5ad" - local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" - local xenium_file="${xenium_dir}/seg_experiment.xenium" - - mkdir -p "${anndata_dir}" "${xenium_dir}" - - if [[ ! 
-f "${seg_file}" ]] && [[ "${DRY_RUN}" != "1" ]]; then - LAST_EXPORT_STATUS="missing_segmentation" - return 1 - fi - - if [[ ! -f "${anndata_file}" ]]; then - local -a anndata_cmd=( - segger export - -s "${seg_file}" - -i "${INPUT_DIR}" - -o "${anndata_dir}" - --format anndata - ) - if ! run_cmd "${log_file}" "${anndata_cmd[@]}"; then - LAST_EXPORT_STATUS="anndata_export_failed" - return 1 - fi - else - echo "[$(timestamp)] SKIP anndata export (existing): ${anndata_file}" >> "${log_file}" - fi - - if [[ ! -f "${xenium_file}" ]]; then - local -a xenium_cmd=( - segger export - -s "${seg_file}" - -i "${INPUT_DIR}" - -o "${xenium_dir}" - --format xenium_explorer - --boundary-method "${BOUNDARY_METHOD}" - --boundary-voxel-size "${BOUNDARY_VOXEL_SIZE}" - --num-workers "${XENIUM_NUM_WORKERS}" - ) - if ! run_cmd "${log_file}" "${xenium_cmd[@]}"; then - LAST_EXPORT_STATUS="xenium_export_failed" - return 1 - fi - else - echo "[$(timestamp)] SKIP xenium export (existing): ${xenium_file}" >> "${log_file}" - fi - - LAST_EXPORT_STATUS="ok" - return 0 -} - -# ========================================================================= -# run_job — extended to handle 21-field spec -# ========================================================================= - -LAST_JOB_STATUS="unknown" - -run_job() { - local gpu="$1" - local spec="$2" - - local job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv - local alignment_loss sg_loss_type hidden_channels out_channels - local tx_weight_end bd_weight_end sg_weight_end alignment_weight_end - local positional_embeddings normalize_embeddings cells_representation learning_rate - IFS='|' read -r \ - job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv \ - alignment_loss sg_loss_type hidden_channels out_channels \ - tx_weight_end bd_weight_end sg_weight_end alignment_weight_end \ - positional_embeddings normalize_embeddings cells_representation learning_rate \ - <<< "${spec}" - - # Resolve prediction mode from job name (Block E override). 
- local job_prediction_mode="${PREDICTION_MODE}" - case "${job_name}" in - abl_pred_cell) job_prediction_mode="cell" ;; - abl_pred_uniform) job_prediction_mode="uniform" ;; - esac - - local seg_dir="${RUNS_DIR}/${job_name}" - local seg_file="${seg_dir}/segger_segmentation.parquet" - local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" - local anndata_file="${anndata_dir}/segger_segmentation.h5ad" - local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" - local xenium_file="${xenium_dir}/seg_experiment.xenium" - local log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" - - mkdir -p "${seg_dir}" "${anndata_dir}" "${xenium_dir}" - - { - echo "==================================================================" - echo "[$(timestamp)] START job=${job_name} gpu=${gpu}" - echo "params: use3d=${use_3d} expansion=${expansion} tx_k=${tx_k} tx_dist=${tx_dist} layers=${n_layers} heads=${n_heads} cells_min=${cells_min_counts} min_qv=${min_qv} align=${alignment_loss} sg_loss=${sg_loss_type} hidden=${hidden_channels} out=${out_channels} tx_w=${tx_weight_end} bd_w=${bd_weight_end} sg_w=${sg_weight_end} align_w=${alignment_weight_end} pos_emb=${positional_embeddings} norm_emb=${normalize_embeddings} cells_rep=${cells_representation} lr=${learning_rate} pred_mode=${job_prediction_mode} timeout_min=${SEGMENT_TIMEOUT_MIN}" - } | tee -a "${log_file}" >/dev/null - - if [[ "${RESUME_IF_EXISTS}" == "1" ]] && \ - [[ -f "${seg_file}" ]] && \ - [[ -f "${anndata_file}" ]] && \ - [[ -f "${xenium_file}" ]]; then - echo "[$(timestamp)] SKIP job=${job_name} (all outputs already present)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="skipped_existing" - return 0 - fi - - if [[ ! -f "${seg_file}" ]]; then - # Build positional/normalize flags for cyclopts booleans. - local pos_flag="--use-positional-embeddings" - if [[ "${positional_embeddings}" == "false" ]]; then - pos_flag="--no-use-positional-embeddings" - fi - local norm_flag="--normalize-embeddings" - if [[ "${normalize_embeddings}" == "false" ]]; then - norm_flag="--no-normalize-embeddings" - fi - - local -a seg_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger segment - -i "${INPUT_DIR}" - -o "${seg_dir}" - --n-epochs "${N_EPOCHS}" - --prediction-mode "${job_prediction_mode}" - --prediction-expansion-ratio "${expansion}" - --cells-min-counts "${cells_min_counts}" - --min-qv "${min_qv}" - --use-3d "${use_3d}" - --transcripts-max-k "${tx_k}" - --transcripts-max-dist "${tx_dist}" - --n-mid-layers "${n_layers}" - --n-heads "${n_heads}" - --segmentation-loss "${sg_loss_type}" - --hidden-channels "${hidden_channels}" - --out-channels "${out_channels}" - --transcripts-loss-weight-end "${tx_weight_end}" - --cells-loss-weight-end "${bd_weight_end}" - --segmentation-loss-weight-end "${sg_weight_end}" - --learning-rate "${learning_rate}" - --cells-representation "${cells_representation}" - "${pos_flag}" - "${norm_flag}" - ) - - if [[ "${alignment_loss}" == "true" ]]; then - seg_cmd+=( - --alignment-loss - --alignment-loss-weight-start "${ALIGNMENT_LOSS_WEIGHT_START}" - --alignment-loss-weight-end "${alignment_weight_end}" - ) - if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then - seg_cmd+=(--alignment-me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") - fi - if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - seg_cmd+=( - --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" - --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" - ) - fi - fi - 
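-        # For the anchor baseline (abl_full) the array above expands to
-        # roughly the following (values from the ANCHOR_* defaults; paths
-        # and a few flags omitted):
-        #   CUDA_VISIBLE_DEVICES=<gpu> segger segment -i <INPUT_DIR> -o <seg_dir> \
-        #     --n-epochs 20 --prediction-mode nucleus --prediction-expansion-ratio 2.5 \
-        #     --use-3d true --transcripts-max-k 5 --transcripts-max-dist 20 \
-        #     --n-mid-layers 2 --n-heads 4 --segmentation-loss triplet \
-        #     --hidden-channels 64 --out-channels 64 --learning-rate 1e-3 \
-        #     --alignment-loss --alignment-loss-weight-end 0.03 ...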
- run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_cmd[@]}" - local seg_rc=$? - if [[ "${seg_rc}" -ne 0 ]]; then - if [[ "${seg_rc}" -eq 124 ]]; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oot" - return 1 - fi - - if [[ "${SEGMENT_ANC_RETRY_WORKERS}" != "${SEGMENT_NUM_WORKERS}" ]] && is_ancdata_failure "${log_file}"; then - echo "[$(timestamp)] WARN job=${job_name} segment failed with ancdata; retrying with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null - local -a seg_retry_cmd=("${seg_cmd[@]}") - local i - for i in "${!seg_retry_cmd[@]}"; do - if [[ "${seg_retry_cmd[$i]}" == SEGGER_NUM_WORKERS=* ]]; then - seg_retry_cmd[$i]="SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" - break - fi - done - run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_retry_cmd[@]}" - seg_rc=$? - if [[ "${seg_rc}" -eq 0 ]]; then - echo "[$(timestamp)] OK job=${job_name} segment retry succeeded with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null - elif [[ "${seg_rc}" -eq 124 ]]; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment_retry (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oot" - return 1 - fi - fi - - if [[ "${seg_rc}" -eq 0 ]]; then - : - else - local last_ckpt="${seg_dir}/checkpoints/last.ckpt" - if [[ "${PREDICT_FALLBACK_ON_OOM}" == "1" ]] && is_oom_failure "${log_file}" && [[ -f "${last_ckpt}" ]]; then - echo "[$(timestamp)] WARN job=${job_name} segment OOM; trying checkpoint predict fallback (${last_ckpt})" | tee -a "${log_file}" >/dev/null - local -a predict_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger predict - -c "${last_ckpt}" - -i "${INPUT_DIR}" - -o "${seg_dir}" - ) - if run_cmd "${log_file}" "${predict_cmd[@]}"; then - echo "[$(timestamp)] OK job=${job_name} predict fallback succeeded after OOM" | tee -a "${log_file}" >/dev/null - else - echo "[$(timestamp)] FAIL job=${job_name} step=predict_fallback_after_oom" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="predict_fallback_failed" - return 1 - fi - else - if is_ancdata_failure "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (ancdata)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_ancdata" - elif is_oom_failure "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (oom)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oom" - else - echo "[$(timestamp)] FAIL job=${job_name} step=segment" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_failed" - fi - return 1 - fi - fi - fi - else - echo "[$(timestamp)] SKIP segmentation (existing): ${seg_file}" | tee -a "${log_file}" >/dev/null - fi - - if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=${LAST_EXPORT_STATUS}" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="${LAST_EXPORT_STATUS}" - return 1 - fi - - echo "[$(timestamp)] DONE job=${job_name}" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="ok" - return 0 -} - -# ========================================================================= -# GPU group runner -# ========================================================================= - -run_gpu_group() { - local gpu="$1" - shift - local -a indices=("$@") - local summary_file="${SUMMARY_DIR}/gpu${gpu}.tsv" - - printf "job\tgpu\tstatus\telapsed_s\tseg_dir\tlog_file\n" > "${summary_file}" - - local idx spec job_name start_ts end_ts elapsed_s - for idx in "${indices[@]}"; do - spec="${JOB_SPECS[$idx]}" - IFS='|' read -r job_name _ <<< "${spec}" - - start_ts="$(date +%s)" - run_job "${gpu}" "${spec}" - end_ts="$(date +%s)" - elapsed_s=$((end_ts - start_ts)) - - printf "%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" \ - "${gpu}" \ - "${LAST_JOB_STATUS}" \ - "${elapsed_s}" \ - "${RUNS_DIR}/${job_name}" \ - "${LOGS_DIR}/${job_name}.gpu${gpu}.log" \ - >> "${summary_file}" - done -} - -# ========================================================================= -# Post-run recovery (predict-only from checkpoints) -# ========================================================================= - -run_post_recovery_predict_only_group() { - local gpu="$1" - local out_file="$2" - shift 2 - local -a indices=("$@") - - printf "job\tgpu\tstatus\telapsed_s\tnote\tseg_dir\tlog_file\n" > "${out_file}" - - local idx spec job_name - local seg_dir seg_file last_ckpt log_file note status - local start_ts end_ts elapsed_s - - for idx in "${indices[@]}"; do - spec="${JOB_SPECS[$idx]}" - IFS='|' read -r job_name _ <<< "${spec}" - - seg_dir="${RUNS_DIR}/${job_name}" - seg_file="${seg_dir}/segger_segmentation.parquet" - last_ckpt="${seg_dir}/checkpoints/last.ckpt" - log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" - mkdir -p "${seg_dir}" - - start_ts="$(date +%s)" - note="" - status="ok" - - if [[ -f "${seg_file}" ]]; then - note="segmentation_exists" - if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - status="${LAST_EXPORT_STATUS}" - note="exports_failed_after_existing_seg" - fi - else - if [[ -f "${last_ckpt}" ]]; then - echo "[$(timestamp)] RECOVERY job=${job_name}: running predict-only from ${last_ckpt}" | tee -a "${log_file}" >/dev/null - local -a predict_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger predict - -c "${last_ckpt}" - -i "${INPUT_DIR}" - -o "${seg_dir}" - ) - if run_cmd "${log_file}" "${predict_cmd[@]}"; then - if run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - status="recovered_predict_ok" - note="predict_only_from_last_ckpt" - else - status="${LAST_EXPORT_STATUS}" - note="predict_recovered_but_exports_failed" - fi - else - status="recovered_predict_failed" - note="predict_only_failed" - fi - else - status="recovery_no_checkpoint" - note="missing_seg_and_last_ckpt" - fi - fi - - end_ts="$(date +%s)" - elapsed_s=$((end_ts - start_ts)) - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" \ - "${gpu}" \ - "${status}" \ - "${elapsed_s}" \ - "${note}" \ - "${seg_dir}" \ - "${log_file}" \ - >> "${out_file}" - done -} - -run_post_recovery_predict_only() { - local recovery_file="${SUMMARY_DIR}/recovery.tsv" - local pids=() - local g gpu recovery_per_gpu - - # Single GPU: run sequentially with all indices. - if [[ "${NUM_GPUS}" -eq 1 ]]; then - local all_indices=() - for g in $(seq 0 $((NUM_GPUS - 1))); do - local -n arr="GPU_${g}_INDICES" - all_indices+=("${arr[@]}") - done - recovery_per_gpu="${SUMMARY_DIR}/recovery.gpu${GPU_IDS[0]}.tsv" - run_post_recovery_predict_only_group "${GPU_IDS[0]}" "${recovery_per_gpu}" "${all_indices[@]}" - cp "${recovery_per_gpu}" "${recovery_file}" - return - fi - - # Multi-GPU: run recovery groups in parallel. - for g in $(seq 0 $((NUM_GPUS - 1))); do - gpu="${GPU_IDS[$g]}" - recovery_per_gpu="${SUMMARY_DIR}/recovery.gpu${gpu}.tsv" - local -n arr="GPU_${g}_INDICES" - if [[ "${#arr[@]}" -gt 0 ]]; then - run_post_recovery_predict_only_group "${gpu}" "${recovery_per_gpu}" "${arr[@]}" & - pids+=($!) - fi - done - - for pid in "${pids[@]}"; do - wait "${pid}" - done - - # Merge recovery files. - local first=1 - for g in $(seq 0 $((NUM_GPUS - 1))); do - gpu="${GPU_IDS[$g]}" - recovery_per_gpu="${SUMMARY_DIR}/recovery.gpu${gpu}.tsv" - if [[ -f "${recovery_per_gpu}" ]]; then - if [[ "${first}" -eq 1 ]]; then - cat "${recovery_per_gpu}" > "${recovery_file}" - first=0 - else - tail -n +2 "${recovery_per_gpu}" >> "${recovery_file}" - fi - fi - done -} - -# ========================================================================= -# Build jobs and distribute across GPUs -# ========================================================================= - -build_jobs - -if [[ "${#JOB_SPECS[@]}" -eq 0 ]]; then - echo "ERROR: No ablation jobs were generated. Check block toggles (RUN_LOSS_ABLATION, etc.)." - exit 1 -fi - -# Create per-GPU index arrays with round-robin distribution. 
-for g in $(seq 0 $((NUM_GPUS - 1))); do - declare -a "GPU_${g}_INDICES=()" -done - -idx=0 -for spec in "${JOB_SPECS[@]}"; do - g=$((idx % NUM_GPUS)) - eval "GPU_${g}_INDICES+=(${idx})" - idx=$((idx + 1)) -done - -# ========================================================================= -# Write job plan TSV -# ========================================================================= -{ - printf "job\tstudy_block\tgpu_group\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\tsg_loss_type\thidden_channels\tout_channels\ttx_weight_end\tbd_weight_end\tsg_weight_end\talignment_weight_end\tpositional_embeddings\tnormalize_embeddings\tcells_representation\tlearning_rate\n" - for idx in "${!JOB_SPECS[@]}"; do - local_group=$((idx % NUM_GPUS)) - IFS='|' read -r \ - job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv \ - alignment_loss sg_loss_type hidden_channels out_channels \ - tx_weight_end bd_weight_end sg_weight_end alignment_weight_end \ - positional_embeddings normalize_embeddings cells_representation learning_rate \ - <<< "${JOB_SPECS[$idx]}" - local_block="$(job_block "${job_name}")" - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" "${local_block}" "${local_group}" "${use_3d}" "${expansion}" "${tx_k}" "${tx_dist}" \ - "${n_layers}" "${n_heads}" "${cells_min_counts}" "${min_qv}" "${alignment_loss}" \ - "${sg_loss_type}" "${hidden_channels}" "${out_channels}" \ - "${tx_weight_end}" "${bd_weight_end}" "${sg_weight_end}" "${alignment_weight_end}" \ - "${positional_embeddings}" "${normalize_embeddings}" "${cells_representation}" "${learning_rate}" - done -} > "${PLAN_FILE}" - -echo "[$(timestamp)] Prepared ${#JOB_SPECS[@]} ablation jobs across ${NUM_GPUS} GPU(s)." -echo "[$(timestamp)] GPUs: ${GPU_IDS[*]}" -for g in $(seq 0 $((NUM_GPUS - 1))); do - eval "_count=\${#GPU_${g}_INDICES[@]}" - echo "[$(timestamp)] GPU ${GPU_IDS[$g]}: ${_count} jobs" -done -echo "[$(timestamp)] Job plan: ${PLAN_FILE}" -echo "[$(timestamp)] Logs: ${LOGS_DIR}" - -if [[ "${DRY_RUN}" == "1" ]]; then - echo "[$(timestamp)] DRY_RUN=1 — exiting without running jobs." - echo "" - echo "Job plan:" - column -t -s $'\t' "${PLAN_FILE}" 2>/dev/null || cat "${PLAN_FILE}" - exit 0 -fi - -# ========================================================================= -# Launch GPU groups in parallel -# ========================================================================= - -PIDS=() -for g in $(seq 0 $((NUM_GPUS - 1))); do - gpu="${GPU_IDS[$g]}" - eval "_arr=(\"\${GPU_${g}_INDICES[@]}\")" - if [[ "${#_arr[@]}" -gt 0 ]]; then - run_gpu_group "${gpu}" "${_arr[@]}" & - PIDS+=($!) - fi -done - -for pid in "${PIDS[@]}"; do - wait "${pid}" -done - -# ========================================================================= -# Post-run recovery pass -# ========================================================================= - -echo "[$(timestamp)] Starting post-run predict-only recovery pass..." 
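-# For any job that still has no segger_segmentation.parquet but does have a
-# checkpoints/last.ckpt (typically after a timeout or an unrecovered OOM),
-# the recovery pass runs `segger predict` from that checkpoint, re-attempts
-# the exports, and records the outcome in summaries/recovery.tsv.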
-run_post_recovery_predict_only - -# ========================================================================= -# Combine summaries -# ========================================================================= - -COMBINED_SUMMARY="${SUMMARY_DIR}/all_jobs.tsv" -if [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then - awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv "${SUMMARY_DIR}/recovery.tsv" > "${COMBINED_SUMMARY}" - FAILED_COUNT=$( - awk -F'\t' 'NR>1 && $3!="ok" && $3!="recovered_predict_ok" {c++} END{print c+0}' "${SUMMARY_DIR}/recovery.tsv" - ) -else - awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv > "${COMBINED_SUMMARY}" - FAILED_COUNT=$( - awk -F'\t' 'NR>1 && $3!="ok" && $3!="skipped_existing" {c++} END{print c+0}' "${COMBINED_SUMMARY}" - ) -fi - -echo "[$(timestamp)] Combined summary: ${COMBINED_SUMMARY}" -if [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then - echo "[$(timestamp)] Recovery summary: ${SUMMARY_DIR}/recovery.tsv" -fi - -# ========================================================================= -# Validation table -# ========================================================================= - -if [[ "${RUN_VALIDATION_TABLE}" == "1" ]]; then - if [[ -f "${VALIDATION_SCRIPT}" ]]; then - echo "[$(timestamp)] Building validation metrics table..." - validation_log="${SUMMARY_DIR}/validation_metrics.log" - validation_cmd=( - bash "${VALIDATION_SCRIPT}" - --root "${OUTPUT_ROOT}" - --input-dir "${INPUT_DIR}" - --include-default-10x "${VALIDATION_INCLUDE_DEFAULT_10X}" - ) - # Pass first two GPU IDs for compatibility with validation script. - if [[ "${NUM_GPUS}" -ge 2 ]]; then - validation_cmd+=(--gpu-a "${GPU_IDS[0]}" --gpu-b "${GPU_IDS[1]}") - else - validation_cmd+=(--gpu-a "${GPU_IDS[0]}" --gpu-b "${GPU_IDS[0]}") - fi - if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then - validation_cmd+=(--me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") - fi - if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - validation_cmd+=( - --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" - --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" - ) - fi - if "${validation_cmd[@]}" >> "${validation_log}" 2>&1; then - echo "[$(timestamp)] Validation table updated: ${OUTPUT_ROOT}/summaries/validation_metrics.tsv" - else - echo "[$(timestamp)] WARN: validation table build failed (see ${validation_log})" - fi - else - echo "[$(timestamp)] WARN: VALIDATION_SCRIPT not found: ${VALIDATION_SCRIPT}" - fi -fi - -echo "[$(timestamp)] Failed jobs: ${FAILED_COUNT}" - -if [[ "${FAILED_COUNT}" -gt 0 ]]; then - exit 1 -fi diff --git a/scripts/run_param_benchmark_2gpu.sh b/scripts/run_param_benchmark_2gpu.sh deleted file mode 100755 index 8236231..0000000 --- a/scripts/run_param_benchmark_2gpu.sh +++ /dev/null @@ -1,764 +0,0 @@ -#!/usr/bin/env bash -set -u -o pipefail - -# ------------------------------------------------------------------------- -# Segger overnight benchmark runner (2 GPUs, 1 job per GPU at a time) -# ------------------------------------------------------------------------- -# Usage: -# bash scripts/run_param_benchmark_2gpu.sh -# -# Optional overrides (environment variables): -# INPUT_DIR=data/xe_pancreas_mossi/ -# OUTPUT_ROOT=./results/mossi_main_big_benchmark_nightly -# GPU_A=0 -# GPU_B=1 -# N_EPOCHS=20 -# INCLUDE_EXTRA_SWEEPS=1 -# RESUME_IF_EXISTS=1 -# DRY_RUN=0 -# SEGMENT_TIMEOUT_MIN=90 -# ALIGNMENT_LOSS=true -# ALIGNMENT_SCRNA_REFERENCE_PATH=data/ref_pancreas.h5ad -# ALIGNMENT_SCRNA_CELLTYPE_COLUMN=cell_type -# SEGMENT_NUM_WORKERS=8 -# 
SEGMENT_ANC_RETRY_WORKERS=0 -# TORCH_SHARING_STRATEGY=file_system -# ------------------------------------------------------------------------- - -timestamp() { - date '+%Y-%m-%d %H:%M:%S' -} - -DEFAULT_INPUT_DIR="data/xe_pancreas_mossi/" -INPUT_DIR="${INPUT_DIR:-${DEFAULT_INPUT_DIR}}" -OUTPUT_ROOT="${OUTPUT_ROOT:-./results/mossi_main_big_benchmark_nightly}" - -# Common layout fallback when running from segger-0.2.0 with data one level up. -if [[ "${INPUT_DIR}" == "${DEFAULT_INPUT_DIR}" ]] && \ - [[ ! -d "${INPUT_DIR}" ]] && \ - [[ -d "../data/xe_pancreas_mossi/" ]]; then - INPUT_DIR="../data/xe_pancreas_mossi/" -fi - -GPU_A="${GPU_A:-0}" -GPU_B="${GPU_B:-1}" - -N_EPOCHS="${N_EPOCHS:-20}" -PREDICTION_MODE="${PREDICTION_MODE:-nucleus}" - -BOUNDARY_METHOD="${BOUNDARY_METHOD:-convex_hull}" -BOUNDARY_VOXEL_SIZE="${BOUNDARY_VOXEL_SIZE:-5}" -XENIUM_NUM_WORKERS="${XENIUM_NUM_WORKERS:-8}" - -INCLUDE_EXTRA_SWEEPS="${INCLUDE_EXTRA_SWEEPS:-1}" -RESUME_IF_EXISTS="${RESUME_IF_EXISTS:-1}" -DRY_RUN="${DRY_RUN:-0}" -PREDICT_FALLBACK_ON_OOM="${PREDICT_FALLBACK_ON_OOM:-1}" -SEGMENT_TIMEOUT_MIN="${SEGMENT_TIMEOUT_MIN:-90}" -SEGMENT_TIMEOUT_SEC=$((SEGMENT_TIMEOUT_MIN * 60)) -SEGMENT_NUM_WORKERS="${SEGMENT_NUM_WORKERS:-8}" -SEGMENT_ANC_RETRY_WORKERS="${SEGMENT_ANC_RETRY_WORKERS:-0}" -TORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY:-file_system}" - -ALIGNMENT_LOSS="${ALIGNMENT_LOSS:-true}" -ALIGNMENT_LOSS_WEIGHT_START="${ALIGNMENT_LOSS_WEIGHT_START:-0.0}" -ALIGNMENT_LOSS_WEIGHT_END="${ALIGNMENT_LOSS_WEIGHT_END:-0.03}" -ALIGNMENT_ME_GENE_PAIRS_PATH="${ALIGNMENT_ME_GENE_PAIRS_PATH:-}" -ALIGNMENT_SCRNA_REFERENCE_PATH="${ALIGNMENT_SCRNA_REFERENCE_PATH:-data/ref_pancreas.h5ad}" -ALIGNMENT_SCRNA_CELLTYPE_COLUMN="${ALIGNMENT_SCRNA_CELLTYPE_COLUMN:-cell_type}" - -# Common layout fallback when running from segger-0.2.0 with data one level up. -if [[ "${ALIGNMENT_SCRNA_REFERENCE_PATH}" == "data/ref_pancreas.h5ad" ]] && \ - [[ ! -f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && \ - [[ -f "../data/ref_pancreas.h5ad" ]]; then - ALIGNMENT_SCRNA_REFERENCE_PATH="../data/ref_pancreas.h5ad" -fi - -# Baseline values (matches the command you provided). -BASE_USE_3D="${BASE_USE_3D:-true}" -BASE_EXPANSION_RATIO="${BASE_EXPANSION_RATIO:-2.0}" -BASE_TX_MAX_K="${BASE_TX_MAX_K:-5}" -BASE_TX_MAX_DIST="${BASE_TX_MAX_DIST:-5}" -BASE_N_MID_LAYERS="${BASE_N_MID_LAYERS:-2}" -BASE_N_HEADS="${BASE_N_HEADS:-2}" -BASE_CELLS_MIN_COUNTS="${BASE_CELLS_MIN_COUNTS:-5}" -BASE_MIN_QV="${BASE_MIN_QV:-0}" - -# One-factor-at-a-time sweep values around baseline. -USE_3D_VALUES=(false true) -EXPANSION_VALUES=(1 1.5 2.0 2.5 3.0) -TX_MAX_K_VALUES=(5 10 20) -TX_MAX_DIST_VALUES=(3 5 10 20) -N_MID_LAYER_VALUES=(1 2 3) -N_HEAD_VALUES=(2 4 8) -CELLS_MIN_COUNTS_VALUES=(3 5 10) -ALIGNMENT_VALUES=(false true) - -RUNS_DIR="${OUTPUT_ROOT}/runs" -EXPORTS_DIR="${OUTPUT_ROOT}/exports" -LOGS_DIR="${OUTPUT_ROOT}/logs" -SUMMARY_DIR="${OUTPUT_ROOT}/summaries" -PLAN_FILE="${OUTPUT_ROOT}/job_plan.tsv" - -mkdir -p "${RUNS_DIR}" "${EXPORTS_DIR}" "${LOGS_DIR}" "${SUMMARY_DIR}" - -if [[ ! -d "${INPUT_DIR}" ]]; then - if [[ "${DRY_RUN}" == "1" ]]; then - echo "WARN: INPUT_DIR does not exist (dry run only): ${INPUT_DIR}" - else - echo "ERROR: INPUT_DIR does not exist: ${INPUT_DIR}" - exit 1 - fi -fi - -if [[ "${DRY_RUN}" != "1" ]] && ! command -v segger >/dev/null 2>&1; then - echo "ERROR: 'segger' command not found in PATH." 
- exit 1 -fi - -need_alignment_inputs=0 -if [[ "${ALIGNMENT_LOSS}" == "true" ]]; then - need_alignment_inputs=1 -elif [[ "${INCLUDE_EXTRA_SWEEPS}" == "1" ]]; then - for v in "${ALIGNMENT_VALUES[@]}"; do - if [[ "${v}" == "true" ]]; then - need_alignment_inputs=1 - break - fi - done -fi - -if [[ "${need_alignment_inputs}" == "1" ]]; then - if [[ -z "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ -z "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - echo "ERROR: ALIGNMENT_LOSS=true requires ALIGNMENT_ME_GENE_PAIRS_PATH or ALIGNMENT_SCRNA_REFERENCE_PATH." - exit 1 - fi - if [[ "${DRY_RUN}" != "1" ]]; then - if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ ! -f "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then - echo "ERROR: ALIGNMENT_ME_GENE_PAIRS_PATH not found: ${ALIGNMENT_ME_GENE_PAIRS_PATH}" - exit 1 - fi - if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && [[ ! -f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - echo "ERROR: ALIGNMENT_SCRNA_REFERENCE_PATH not found: ${ALIGNMENT_SCRNA_REFERENCE_PATH}" - exit 1 - fi - fi -fi - -JOB_SPECS=() - -add_job() { - local job_name="$1" - local use_3d="$2" - local expansion="$3" - local tx_k="$4" - local tx_dist="$5" - local n_layers="$6" - local n_heads="$7" - local cells_min_counts="$8" - local min_qv="$9" - local alignment_loss="${10}" - JOB_SPECS+=("${job_name}|${use_3d}|${expansion}|${tx_k}|${tx_dist}|${n_layers}|${n_heads}|${cells_min_counts}|${min_qv}|${alignment_loss}") -} - -build_jobs() { - local v tag - - add_job \ - "baseline" \ - "${BASE_USE_3D}" \ - "${BASE_EXPANSION_RATIO}" \ - "${BASE_TX_MAX_K}" \ - "${BASE_TX_MAX_DIST}" \ - "${BASE_N_MID_LAYERS}" \ - "${BASE_N_HEADS}" \ - "${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" \ - "${ALIGNMENT_LOSS}" - - for v in "${USE_3D_VALUES[@]}"; do - [[ "${v}" == "${BASE_USE_3D}" ]] && continue - add_job "use3d_${v}" \ - "${v}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ - "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" - done - - for v in "${EXPANSION_VALUES[@]}"; do - [[ "${v}" == "${BASE_EXPANSION_RATIO}" ]] && continue - tag="${v//./p}" - add_job "expansion_${tag}" \ - "${BASE_USE_3D}" "${v}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ - "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" - done - - for v in "${TX_MAX_K_VALUES[@]}"; do - [[ "${v}" == "${BASE_TX_MAX_K}" ]] && continue - add_job "txk_${v}" \ - "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${v}" "${BASE_TX_MAX_DIST}" \ - "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" - done - - for v in "${TX_MAX_DIST_VALUES[@]}"; do - [[ "${v}" == "${BASE_TX_MAX_DIST}" ]] && continue - tag="${v//./p}" - add_job "txdist_${tag}" \ - "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${v}" \ - "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" - done - - for v in "${N_MID_LAYER_VALUES[@]}"; do - [[ "${v}" == "${BASE_N_MID_LAYERS}" ]] && continue - add_job "layers_${v}" \ - "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ - "${v}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" - done - - for v in "${N_HEAD_VALUES[@]}"; do - [[ "${v}" == "${BASE_N_HEADS}" ]] && continue - add_job "heads_${v}" \ - "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ - "${BASE_N_MID_LAYERS}" "${v}" 
"${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" - done - - for v in "${CELLS_MIN_COUNTS_VALUES[@]}"; do - [[ "${v}" == "${BASE_CELLS_MIN_COUNTS}" ]] && continue - add_job "cellsmin_${v}" \ - "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ - "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${v}" \ - "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" - done - - if [[ "${INCLUDE_EXTRA_SWEEPS}" == "1" ]]; then - for v in "${ALIGNMENT_VALUES[@]}"; do - [[ "${v}" == "${ALIGNMENT_LOSS}" ]] && continue - add_job "align_${v}" \ - "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ - "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" "${v}" - done - fi -} - -run_cmd() { - local log_file="$1" - shift - local -a cmd=("$@") - - { - printf '[%s] CMD:' "$(timestamp)" - printf ' %q' "${cmd[@]}" - printf '\n' - } >> "${log_file}" - - if [[ "${DRY_RUN}" == "1" ]]; then - return 0 - fi - - "${cmd[@]}" >> "${log_file}" 2>&1 -} - -run_cmd_with_timeout() { - local log_file="$1" - local timeout_seconds="$2" - shift 2 - local -a cmd=("$@") - - { - printf '[%s] CMD(timeout=%ss):' "$(timestamp)" "${timeout_seconds}" - printf ' %q' "${cmd[@]}" - printf '\n' - } >> "${log_file}" - - if [[ "${DRY_RUN}" == "1" ]]; then - return 0 - fi - - if [[ "${timeout_seconds}" -le 0 ]]; then - "${cmd[@]}" >> "${log_file}" 2>&1 - return $? - fi - - local start_ts now elapsed - local cmd_pid timed_out rc - timed_out=0 - start_ts="$(date +%s)" - - "${cmd[@]}" >> "${log_file}" 2>&1 & - cmd_pid=$! - - while kill -0 "${cmd_pid}" 2>/dev/null; do - now="$(date +%s)" - elapsed=$((now - start_ts)) - if (( elapsed >= timeout_seconds )); then - timed_out=1 - echo "[$(timestamp)] OOT: command exceeded ${timeout_seconds}s; terminating PID=${cmd_pid}" >> "${log_file}" - kill -TERM "${cmd_pid}" 2>/dev/null || true - pkill -TERM -P "${cmd_pid}" 2>/dev/null || true - sleep 5 - kill -KILL "${cmd_pid}" 2>/dev/null || true - pkill -KILL -P "${cmd_pid}" 2>/dev/null || true - break - fi - sleep 10 - done - - wait "${cmd_pid}" - rc=$? - if (( timed_out == 1 )); then - return 124 - fi - return "${rc}" -} - -is_oom_failure() { - local log_file="$1" - if [[ ! -f "${log_file}" ]]; then - return 1 - fi - local pattern="out of memory|cuda error: out of memory|cublas status alloc failed|cuda driver error.*memory" - if command -v rg >/dev/null 2>&1; then - rg -qi "${pattern}" "${log_file}" - else - grep -Eiq "${pattern}" "${log_file}" - fi -} - -is_ancdata_failure() { - local log_file="$1" - if [[ ! -f "${log_file}" ]]; then - return 1 - fi - local pattern="received [0-9]+ items of ancdata|multiprocessing/resource_sharer\\.py" - if command -v rg >/dev/null 2>&1; then - rg -qi "${pattern}" "${log_file}" - else - grep -Eiq "${pattern}" "${log_file}" - fi -} - -LAST_EXPORT_STATUS="ok" - -run_exports_for_job() { - local job_name="$1" - local seg_dir="$2" - local log_file="$3" - - local seg_file="${seg_dir}/segger_segmentation.parquet" - local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" - local anndata_file="${anndata_dir}/segger_segmentation.h5ad" - local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" - local xenium_file="${xenium_dir}/seg_experiment.xenium" - - mkdir -p "${anndata_dir}" "${xenium_dir}" - - if [[ ! -f "${seg_file}" ]] && [[ "${DRY_RUN}" != "1" ]]; then - LAST_EXPORT_STATUS="missing_segmentation" - return 1 - fi - - if [[ ! 
-f "${anndata_file}" ]]; then - local -a anndata_cmd=( - segger export - -s "${seg_file}" - -i "${INPUT_DIR}" - -o "${anndata_dir}" - --format anndata - ) - if ! run_cmd "${log_file}" "${anndata_cmd[@]}"; then - LAST_EXPORT_STATUS="anndata_export_failed" - return 1 - fi - else - echo "[$(timestamp)] SKIP anndata export (existing): ${anndata_file}" >> "${log_file}" - fi - - if [[ ! -f "${xenium_file}" ]]; then - local -a xenium_cmd=( - segger export - -s "${seg_file}" - -i "${INPUT_DIR}" - -o "${xenium_dir}" - --format xenium_explorer - --boundary-method "${BOUNDARY_METHOD}" - --boundary-voxel-size "${BOUNDARY_VOXEL_SIZE}" - --num-workers "${XENIUM_NUM_WORKERS}" - ) - if ! run_cmd "${log_file}" "${xenium_cmd[@]}"; then - LAST_EXPORT_STATUS="xenium_export_failed" - return 1 - fi - else - echo "[$(timestamp)] SKIP xenium export (existing): ${xenium_file}" >> "${log_file}" - fi - - LAST_EXPORT_STATUS="ok" - return 0 -} - -LAST_JOB_STATUS="unknown" - -run_job() { - local gpu="$1" - local spec="$2" - - local job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss - IFS='|' read -r \ - job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ - <<< "${spec}" - - local seg_dir="${RUNS_DIR}/${job_name}" - local seg_file="${seg_dir}/segger_segmentation.parquet" - local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" - local anndata_file="${anndata_dir}/segger_segmentation.h5ad" - local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" - local xenium_file="${xenium_dir}/seg_experiment.xenium" - local log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" - - mkdir -p "${seg_dir}" "${anndata_dir}" "${xenium_dir}" - - { - echo "==================================================================" - echo "[$(timestamp)] START job=${job_name} gpu=${gpu}" - echo "params: use3d=${use_3d} expansion=${expansion} tx_k=${tx_k} tx_dist=${tx_dist} layers=${n_layers} heads=${n_heads} cells_min=${cells_min_counts} min_qv=${min_qv} align=${alignment_loss} timeout_min=${SEGMENT_TIMEOUT_MIN} dl_workers=${SEGMENT_NUM_WORKERS} anc_retry_workers=${SEGMENT_ANC_RETRY_WORKERS} sharing=${TORCH_SHARING_STRATEGY}" - } | tee -a "${log_file}" >/dev/null - - if [[ "${RESUME_IF_EXISTS}" == "1" ]] && \ - [[ -f "${seg_file}" ]] && \ - [[ -f "${anndata_file}" ]] && \ - [[ -f "${xenium_file}" ]]; then - echo "[$(timestamp)] SKIP job=${job_name} (all outputs already present)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="skipped_existing" - return 0 - fi - - if [[ ! 
-f "${seg_file}" ]]; then - local -a seg_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger segment - -i "${INPUT_DIR}" - -o "${seg_dir}" - --n-epochs "${N_EPOCHS}" - --prediction-mode "${PREDICTION_MODE}" - --prediction-expansion-ratio "${expansion}" - --cells-min-counts "${cells_min_counts}" - --min-qv "${min_qv}" - --use-3d "${use_3d}" - --transcripts-max-k "${tx_k}" - --transcripts-max-dist "${tx_dist}" - --n-mid-layers "${n_layers}" - --n-heads "${n_heads}" - ) - if [[ "${alignment_loss}" == "true" ]]; then - seg_cmd+=( - --alignment-loss - --alignment-loss-weight-start "${ALIGNMENT_LOSS_WEIGHT_START}" - --alignment-loss-weight-end "${ALIGNMENT_LOSS_WEIGHT_END}" - ) - if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then - seg_cmd+=(--alignment-me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") - fi - if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - seg_cmd+=( - --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" - --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" - ) - fi - fi - - run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_cmd[@]}" - local seg_rc=$? - if [[ "${seg_rc}" -ne 0 ]]; then - if [[ "${seg_rc}" -eq 124 ]]; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oot" - return 1 - fi - - if [[ "${SEGMENT_ANC_RETRY_WORKERS}" != "${SEGMENT_NUM_WORKERS}" ]] && is_ancdata_failure "${log_file}"; then - echo "[$(timestamp)] WARN job=${job_name} segment failed with ancdata; retrying with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null - local -a seg_retry_cmd=("${seg_cmd[@]}") - local i - for i in "${!seg_retry_cmd[@]}"; do - if [[ "${seg_retry_cmd[$i]}" == SEGGER_NUM_WORKERS=* ]]; then - seg_retry_cmd[$i]="SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" - break - fi - done - run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_retry_cmd[@]}" - seg_rc=$? 
- if [[ "${seg_rc}" -eq 0 ]]; then - echo "[$(timestamp)] OK job=${job_name} segment retry succeeded with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null - elif [[ "${seg_rc}" -eq 124 ]]; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment_retry (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oot" - return 1 - fi - fi - - if [[ "${seg_rc}" -eq 0 ]]; then - : - else - local last_ckpt="${seg_dir}/checkpoints/last.ckpt" - if [[ "${PREDICT_FALLBACK_ON_OOM}" == "1" ]] && is_oom_failure "${log_file}" && [[ -f "${last_ckpt}" ]]; then - echo "[$(timestamp)] WARN job=${job_name} segment OOM; trying checkpoint predict fallback (${last_ckpt})" | tee -a "${log_file}" >/dev/null - local -a predict_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger predict - -c "${last_ckpt}" - -i "${INPUT_DIR}" - -o "${seg_dir}" - ) - if run_cmd "${log_file}" "${predict_cmd[@]}"; then - echo "[$(timestamp)] OK job=${job_name} predict fallback succeeded after OOM" | tee -a "${log_file}" >/dev/null - else - echo "[$(timestamp)] FAIL job=${job_name} step=predict_fallback_after_oom" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="predict_fallback_failed" - return 1 - fi - else - if is_ancdata_failure "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (ancdata)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_ancdata" - elif is_oom_failure "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (oom)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oom" - else - echo "[$(timestamp)] FAIL job=${job_name} step=segment" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_failed" - fi - return 1 - fi - fi - fi - else - echo "[$(timestamp)] SKIP segmentation (existing): ${seg_file}" | tee -a "${log_file}" >/dev/null - fi - - if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=${LAST_EXPORT_STATUS}" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="${LAST_EXPORT_STATUS}" - return 1 - fi - - echo "[$(timestamp)] DONE job=${job_name}" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="ok" - return 0 -} - -run_gpu_group() { - local gpu="$1" - shift - local -a indices=("$@") - local summary_file="${SUMMARY_DIR}/gpu${gpu}.tsv" - - printf "job\tgpu\tstatus\telapsed_s\tseg_dir\tlog_file\n" > "${summary_file}" - - local idx spec job_name start_ts end_ts elapsed_s - for idx in "${indices[@]}"; do - spec="${JOB_SPECS[$idx]}" - IFS='|' read -r job_name _ <<< "${spec}" - - start_ts="$(date +%s)" - run_job "${gpu}" "${spec}" - end_ts="$(date +%s)" - elapsed_s=$((end_ts - start_ts)) - - printf "%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" \ - "${gpu}" \ - "${LAST_JOB_STATUS}" \ - "${elapsed_s}" \ - "${RUNS_DIR}/${job_name}" \ - "${LOGS_DIR}/${job_name}.gpu${gpu}.log" \ - >> "${summary_file}" - done -} - -run_post_recovery_predict_only_group() { - local gpu="$1" - local out_file="$2" - shift 2 - local -a indices=("$@") - - printf "job\tgpu\tstatus\telapsed_s\tnote\tseg_dir\tlog_file\n" > "${out_file}" - - local idx spec job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss - local seg_dir seg_file last_ckpt log_file note status - local start_ts end_ts elapsed_s - - for idx in "${indices[@]}"; do - spec="${JOB_SPECS[$idx]}" - IFS='|' read -r \ - job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ - <<< "${spec}" - - seg_dir="${RUNS_DIR}/${job_name}" - seg_file="${seg_dir}/segger_segmentation.parquet" - last_ckpt="${seg_dir}/checkpoints/last.ckpt" - log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" - mkdir -p "${seg_dir}" - - start_ts="$(date +%s)" - note="" - status="ok" - - if [[ -f "${seg_file}" ]]; then - note="segmentation_exists" - if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - status="${LAST_EXPORT_STATUS}" - note="exports_failed_after_existing_seg" - fi - else - if [[ -f "${last_ckpt}" ]]; then - echo "[$(timestamp)] RECOVERY job=${job_name}: running predict-only from ${last_ckpt}" | tee -a "${log_file}" >/dev/null - local -a predict_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger predict - -c "${last_ckpt}" - -i "${INPUT_DIR}" - -o "${seg_dir}" - ) - if run_cmd "${log_file}" "${predict_cmd[@]}"; then - if run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - status="recovered_predict_ok" - note="predict_only_from_last_ckpt" - else - status="${LAST_EXPORT_STATUS}" - note="predict_recovered_but_exports_failed" - fi - else - status="recovered_predict_failed" - note="predict_only_failed" - fi - else - status="recovery_no_checkpoint" - note="missing_seg_and_last_ckpt" - fi - fi - - end_ts="$(date +%s)" - elapsed_s=$((end_ts - start_ts)) - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" \ - "${gpu}" \ - "${status}" \ - "${elapsed_s}" \ - "${note}" \ - "${seg_dir}" \ - "${log_file}" \ - >> "${out_file}" - done -} - -run_post_recovery_predict_only() { - local recovery_file="${SUMMARY_DIR}/recovery.tsv" - local recovery_a="${SUMMARY_DIR}/recovery.gpu${GPU_A}.tsv" - local recovery_b="${SUMMARY_DIR}/recovery.gpu${GPU_B}.tsv" - local pid_a pid_b - - if [[ "${GPU_A}" == "${GPU_B}" ]]; then - run_post_recovery_predict_only_group "${GPU_A}" "${recovery_a}" "${GPU_A_INDICES[@]}" "${GPU_B_INDICES[@]}" - cp "${recovery_a}" "${recovery_file}" - return - fi - - run_post_recovery_predict_only_group "${GPU_A}" "${recovery_a}" "${GPU_A_INDICES[@]}" & - pid_a=$! - run_post_recovery_predict_only_group "${GPU_B}" "${recovery_b}" "${GPU_B_INDICES[@]}" & - pid_b=$! - - wait "${pid_a}" - wait "${pid_b}" - - awk 'FNR==1 && NR!=1 {next} {print}' "${recovery_a}" "${recovery_b}" > "${recovery_file}" -} - -build_jobs - -if [[ "${#JOB_SPECS[@]}" -eq 0 ]]; then - echo "ERROR: No benchmark jobs were generated." - exit 1 -fi - -GPU_A_INDICES=() -GPU_B_INDICES=() - -idx=0 -for spec in "${JOB_SPECS[@]}"; do - if (( idx % 2 == 0 )); then - GPU_A_INDICES+=("${idx}") - else - GPU_B_INDICES+=("${idx}") - fi - idx=$((idx + 1)) -done - -{ - printf "job\tgroup\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\n" - for idx in "${!JOB_SPECS[@]}"; do - local_group="A" - if (( idx % 2 == 1 )); then - local_group="B" - fi - IFS='|' read -r \ - job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ - <<< "${JOB_SPECS[$idx]}" - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" "${local_group}" "${use_3d}" "${expansion}" "${tx_k}" "${tx_dist}" \ - "${n_layers}" "${n_heads}" "${cells_min_counts}" "${min_qv}" "${alignment_loss}" - done -} > "${PLAN_FILE}" - -echo "[$(timestamp)] Prepared ${#JOB_SPECS[@]} jobs." -echo "[$(timestamp)] Group A (GPU ${GPU_A}): ${#GPU_A_INDICES[@]} jobs" -echo "[$(timestamp)] Group B (GPU ${GPU_B}): ${#GPU_B_INDICES[@]} jobs" -echo "[$(timestamp)] Job plan: ${PLAN_FILE}" -echo "[$(timestamp)] Logs: ${LOGS_DIR}" - -run_gpu_group "${GPU_A}" "${GPU_A_INDICES[@]}" & -PID_A=$! -run_gpu_group "${GPU_B}" "${GPU_B_INDICES[@]}" & -PID_B=$! - -wait "${PID_A}" -wait "${PID_B}" - -if [[ "${DRY_RUN}" != "1" ]]; then - echo "[$(timestamp)] Starting post-run predict-only recovery pass..." 
- run_post_recovery_predict_only -fi - -COMBINED_SUMMARY="${SUMMARY_DIR}/all_jobs.tsv" -if [[ "${DRY_RUN}" != "1" ]] && [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then - awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv "${SUMMARY_DIR}/recovery.tsv" > "${COMBINED_SUMMARY}" - FAILED_COUNT=$( - awk -F'\t' 'NR>1 && $3!="ok" && $3!="recovered_predict_ok" {c++} END{print c+0}' "${SUMMARY_DIR}/recovery.tsv" - ) -else - awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv > "${COMBINED_SUMMARY}" - FAILED_COUNT=$( - awk -F'\t' 'NR>1 && $3!="ok" && $3!="skipped_existing" {c++} END{print c+0}' "${COMBINED_SUMMARY}" - ) -fi - -echo "[$(timestamp)] Combined summary: ${COMBINED_SUMMARY}" -if [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then - echo "[$(timestamp)] Recovery summary: ${SUMMARY_DIR}/recovery.tsv" -fi -echo "[$(timestamp)] Failed jobs: ${FAILED_COUNT}" - -if [[ "${FAILED_COUNT}" -gt 0 ]]; then - exit 1 -fi diff --git a/scripts/run_robustness_ablation_2gpu.sh b/scripts/run_robustness_ablation_2gpu.sh deleted file mode 100755 index 685eb8e..0000000 --- a/scripts/run_robustness_ablation_2gpu.sh +++ /dev/null @@ -1,845 +0,0 @@ -#!/usr/bin/env bash -set -u -o pipefail - -# ------------------------------------------------------------------------- -# Segger robustness + ablation runner (2 GPUs, 1 job per GPU at a time) -# ------------------------------------------------------------------------- -# Usage: -# bash scripts/run_robustness_ablation_2gpu.sh -# -# Optional overrides (environment variables): -# INPUT_DIR=data/xe_pancreas_mossi/ -# OUTPUT_ROOT=./results/mossi_main_big_robustness_ablation -# GPU_A=0 -# GPU_B=1 -# N_EPOCHS=20 -# STABILITY_REPEATS=3 -# RUN_INTERACTION_GRID=1 -# RUN_STRESS_TESTS=1 -# RESUME_IF_EXISTS=1 -# DRY_RUN=0 -# SEGMENT_TIMEOUT_MIN=90 -# ALIGNMENT_LOSS=true -# ALIGNMENT_SCRNA_REFERENCE_PATH=data/ref_pancreas.h5ad -# ALIGNMENT_SCRNA_CELLTYPE_COLUMN=cell_type -# SEGMENT_NUM_WORKERS=8 -# SEGMENT_ANC_RETRY_WORKERS=0 -# TORCH_SHARING_STRATEGY=file_system -# RUN_VALIDATION_TABLE=1 -# VALIDATION_SCRIPT=scripts/build_benchmark_validation_table.sh -# ------------------------------------------------------------------------- - -timestamp() { - date '+%Y-%m-%d %H:%M:%S' -} - -DEFAULT_INPUT_DIR="data/xe_pancreas_mossi/" -INPUT_DIR="${INPUT_DIR:-${DEFAULT_INPUT_DIR}}" -OUTPUT_ROOT="${OUTPUT_ROOT:-./results/mossi_main_big_robustness_ablation}" - -# Common layout fallback when running from segger-0.2.0 with data one level up. -if [[ "${INPUT_DIR}" == "${DEFAULT_INPUT_DIR}" ]] && \ - [[ ! 
-d "${INPUT_DIR}" ]] && \ - [[ -d "../data/xe_pancreas_mossi/" ]]; then - INPUT_DIR="../data/xe_pancreas_mossi/" -fi - -GPU_A="${GPU_A:-0}" -GPU_B="${GPU_B:-1}" - -N_EPOCHS="${N_EPOCHS:-20}" -PREDICTION_MODE="${PREDICTION_MODE:-nucleus}" - -BOUNDARY_METHOD="${BOUNDARY_METHOD:-convex_hull}" -BOUNDARY_VOXEL_SIZE="${BOUNDARY_VOXEL_SIZE:-5}" -XENIUM_NUM_WORKERS="${XENIUM_NUM_WORKERS:-8}" - -RESUME_IF_EXISTS="${RESUME_IF_EXISTS:-1}" -DRY_RUN="${DRY_RUN:-0}" -PREDICT_FALLBACK_ON_OOM="${PREDICT_FALLBACK_ON_OOM:-1}" -SEGMENT_TIMEOUT_MIN="${SEGMENT_TIMEOUT_MIN:-90}" -SEGMENT_TIMEOUT_SEC=$((SEGMENT_TIMEOUT_MIN * 60)) -SEGMENT_NUM_WORKERS="${SEGMENT_NUM_WORKERS:-8}" -SEGMENT_ANC_RETRY_WORKERS="${SEGMENT_ANC_RETRY_WORKERS:-0}" -TORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY:-file_system}" - -ALIGNMENT_LOSS="${ALIGNMENT_LOSS:-true}" -ALIGNMENT_LOSS_WEIGHT_START="${ALIGNMENT_LOSS_WEIGHT_START:-0.0}" -ALIGNMENT_LOSS_WEIGHT_END="${ALIGNMENT_LOSS_WEIGHT_END:-0.03}" -ALIGNMENT_ME_GENE_PAIRS_PATH="${ALIGNMENT_ME_GENE_PAIRS_PATH:-}" -ALIGNMENT_SCRNA_REFERENCE_PATH="${ALIGNMENT_SCRNA_REFERENCE_PATH:-data/ref_pancreas.h5ad}" -ALIGNMENT_SCRNA_CELLTYPE_COLUMN="${ALIGNMENT_SCRNA_CELLTYPE_COLUMN:-cell_type}" - -# Common layout fallback when running from segger-0.2.0 with data one level up. -if [[ "${ALIGNMENT_SCRNA_REFERENCE_PATH}" == "data/ref_pancreas.h5ad" ]] && \ - [[ ! -f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && \ - [[ -f "../data/ref_pancreas.h5ad" ]]; then - ALIGNMENT_SCRNA_REFERENCE_PATH="../data/ref_pancreas.h5ad" -fi - -# Baseline values (legacy baseline). -BASE_USE_3D="${BASE_USE_3D:-true}" -BASE_EXPANSION_RATIO="${BASE_EXPANSION_RATIO:-2.0}" -BASE_TX_MAX_K="${BASE_TX_MAX_K:-5}" -BASE_TX_MAX_DIST="${BASE_TX_MAX_DIST:-5}" -BASE_N_MID_LAYERS="${BASE_N_MID_LAYERS:-2}" -BASE_N_HEADS="${BASE_N_HEADS:-2}" -BASE_CELLS_MIN_COUNTS="${BASE_CELLS_MIN_COUNTS:-5}" -BASE_MIN_QV="${BASE_MIN_QV:-0}" - -# Robust anchor values (derived from current validation trends). -ANCHOR_USE_3D="${ANCHOR_USE_3D:-true}" -ANCHOR_EXPANSION_RATIO="${ANCHOR_EXPANSION_RATIO:-2.5}" -ANCHOR_TX_MAX_K="${ANCHOR_TX_MAX_K:-5}" -ANCHOR_TX_MAX_DIST="${ANCHOR_TX_MAX_DIST:-20}" -ANCHOR_N_MID_LAYERS="${ANCHOR_N_MID_LAYERS:-2}" -ANCHOR_N_HEADS="${ANCHOR_N_HEADS:-4}" -ANCHOR_CELLS_MIN_COUNTS="${ANCHOR_CELLS_MIN_COUNTS:-5}" -ANCHOR_MIN_QV="${ANCHOR_MIN_QV:-0}" -ANCHOR_ALIGNMENT_LOSS="${ANCHOR_ALIGNMENT_LOSS:-true}" - -# High-sensitivity variant. -SENS_EXPANSION_RATIO="${SENS_EXPANSION_RATIO:-3.0}" - -# Study controls. -STABILITY_REPEATS="${STABILITY_REPEATS:-3}" -RUN_INTERACTION_GRID="${RUN_INTERACTION_GRID:-1}" -RUN_STRESS_TESTS="${RUN_STRESS_TESTS:-1}" - -# Interaction grid around high-performing region. -INTERACTION_EXPANSIONS=(2.5 3.0) -INTERACTION_TX_DISTS=(10 20) -INTERACTION_HEADS=(2 4) - -# Alignment ablation subset. -INTERACTION_ALIGN_VALUES=(true false) - -if ! [[ "${STABILITY_REPEATS}" =~ ^[0-9]+$ ]] || [[ "${STABILITY_REPEATS}" -lt 1 ]]; then - echo "ERROR: STABILITY_REPEATS must be a positive integer. 
Got: ${STABILITY_REPEATS}" - exit 1 -fi - -RUNS_DIR="${OUTPUT_ROOT}/runs" -EXPORTS_DIR="${OUTPUT_ROOT}/exports" -LOGS_DIR="${OUTPUT_ROOT}/logs" -SUMMARY_DIR="${OUTPUT_ROOT}/summaries" -PLAN_FILE="${OUTPUT_ROOT}/job_plan.tsv" -RUN_VALIDATION_TABLE="${RUN_VALIDATION_TABLE:-1}" -VALIDATION_SCRIPT="${VALIDATION_SCRIPT:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/build_benchmark_validation_table.sh}" -VALIDATION_INCLUDE_DEFAULT_10X="${VALIDATION_INCLUDE_DEFAULT_10X:-true}" - -mkdir -p "${RUNS_DIR}" "${EXPORTS_DIR}" "${LOGS_DIR}" "${SUMMARY_DIR}" - -if [[ ! -d "${INPUT_DIR}" ]]; then - if [[ "${DRY_RUN}" == "1" ]]; then - echo "WARN: INPUT_DIR does not exist (dry run only): ${INPUT_DIR}" - else - echo "ERROR: INPUT_DIR does not exist: ${INPUT_DIR}" - exit 1 - fi -fi - -if [[ "${DRY_RUN}" != "1" ]] && ! command -v segger >/dev/null 2>&1; then - echo "ERROR: 'segger' command not found in PATH." - exit 1 -fi - -need_alignment_inputs=0 -if [[ "${ALIGNMENT_LOSS}" == "true" ]]; then - need_alignment_inputs=1 -elif [[ "${ANCHOR_ALIGNMENT_LOSS}" == "true" ]]; then - need_alignment_inputs=1 -elif [[ "${RUN_INTERACTION_GRID}" == "1" ]]; then - need_alignment_inputs=1 -fi - -if [[ "${need_alignment_inputs}" == "1" ]]; then - if [[ -z "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ -z "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - echo "ERROR: ALIGNMENT_LOSS=true requires ALIGNMENT_ME_GENE_PAIRS_PATH or ALIGNMENT_SCRNA_REFERENCE_PATH." - exit 1 - fi - if [[ "${DRY_RUN}" != "1" ]]; then - if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]] && [[ ! -f "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then - echo "ERROR: ALIGNMENT_ME_GENE_PAIRS_PATH not found: ${ALIGNMENT_ME_GENE_PAIRS_PATH}" - exit 1 - fi - if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]] && [[ ! -f "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - echo "ERROR: ALIGNMENT_SCRNA_REFERENCE_PATH not found: ${ALIGNMENT_SCRNA_REFERENCE_PATH}" - exit 1 - fi - fi -fi - -JOB_SPECS=() - -add_job() { - local job_name="$1" - local use_3d="$2" - local expansion="$3" - local tx_k="$4" - local tx_dist="$5" - local n_layers="$6" - local n_heads="$7" - local cells_min_counts="$8" - local min_qv="$9" - local alignment_loss="${10}" - JOB_SPECS+=("${job_name}|${use_3d}|${expansion}|${tx_k}|${tx_dist}|${n_layers}|${n_heads}|${cells_min_counts}|${min_qv}|${alignment_loss}") -} - -job_block() { - local job_name="$1" - case "${job_name}" in - stbl_*) echo "stability" ;; - int_*) echo "interaction" ;; - stress_*) echo "stress" ;; - *) echo "other" ;; - esac -} - -build_jobs() { - local i exp dist heads align tag_exp tag_dist - - # ----------------------------------------------------------------------- - # Block A: stability / repeatability - # ----------------------------------------------------------------------- - for ((i = 1; i <= STABILITY_REPEATS; i++)); do - add_job "stbl_baseline_r${i}" \ - "${BASE_USE_3D}" "${BASE_EXPANSION_RATIO}" "${BASE_TX_MAX_K}" "${BASE_TX_MAX_DIST}" \ - "${BASE_N_MID_LAYERS}" "${BASE_N_HEADS}" "${BASE_CELLS_MIN_COUNTS}" \ - "${BASE_MIN_QV}" "${ALIGNMENT_LOSS}" - done - - for ((i = 1; i <= STABILITY_REPEATS; i++)); do - add_job "stbl_anchor_r${i}" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ - "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ - "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" - done - - for ((i = 1; i <= 2; i++)); do - add_job "stbl_sens_r${i}" \ - "${ANCHOR_USE_3D}" "${SENS_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ - "${ANCHOR_N_MID_LAYERS}" 
"${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ - "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" - done - - # ----------------------------------------------------------------------- - # Block B: interaction grid in high-performing region - # ----------------------------------------------------------------------- - if [[ "${RUN_INTERACTION_GRID}" == "1" ]]; then - for exp in "${INTERACTION_EXPANSIONS[@]}"; do - for dist in "${INTERACTION_TX_DISTS[@]}"; do - for heads in "${INTERACTION_HEADS[@]}"; do - tag_exp="${exp//./p}" - tag_dist="${dist//./p}" - add_job "int_e${tag_exp}_d${tag_dist}_h${heads}_aT" \ - "${ANCHOR_USE_3D}" "${exp}" "${ANCHOR_TX_MAX_K}" "${dist}" \ - "${ANCHOR_N_MID_LAYERS}" "${heads}" "${ANCHOR_CELLS_MIN_COUNTS}" \ - "${ANCHOR_MIN_QV}" "true" - done - done - done - - # Alignment ablation on selected interaction corners (heads=4). - for exp in "${INTERACTION_EXPANSIONS[@]}"; do - for dist in "${INTERACTION_TX_DISTS[@]}"; do - for align in "${INTERACTION_ALIGN_VALUES[@]}"; do - [[ "${align}" == "true" ]] && continue - tag_exp="${exp//./p}" - tag_dist="${dist//./p}" - add_job "int_e${tag_exp}_d${tag_dist}_h4_aF" \ - "${ANCHOR_USE_3D}" "${exp}" "${ANCHOR_TX_MAX_K}" "${dist}" \ - "${ANCHOR_N_MID_LAYERS}" "4" "${ANCHOR_CELLS_MIN_COUNTS}" \ - "${ANCHOR_MIN_QV}" "${align}" - done - done - done - fi - - # ----------------------------------------------------------------------- - # Block C: stress tests (robustness to practical shifts) - # ----------------------------------------------------------------------- - if [[ "${RUN_STRESS_TESTS}" == "1" ]]; then - add_job "stress_use3d_false_anchor" \ - "false" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ - "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ - "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" - - add_job "stress_use3d_false_sens" \ - "false" "${SENS_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ - "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ - "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" - - add_job "stress_cellsmin3_anchor" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ - "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "3" \ - "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" - - add_job "stress_cellsmin10_anchor" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ - "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "10" \ - "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" - - add_job "stress_txk20_anchor" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "20" "${ANCHOR_TX_MAX_DIST}" \ - "${ANCHOR_N_MID_LAYERS}" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ - "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" - - add_job "stress_layers1_anchor" \ - "${ANCHOR_USE_3D}" "${ANCHOR_EXPANSION_RATIO}" "${ANCHOR_TX_MAX_K}" "${ANCHOR_TX_MAX_DIST}" \ - "1" "${ANCHOR_N_HEADS}" "${ANCHOR_CELLS_MIN_COUNTS}" \ - "${ANCHOR_MIN_QV}" "${ANCHOR_ALIGNMENT_LOSS}" - fi -} - -run_cmd() { - local log_file="$1" - shift - local -a cmd=("$@") - - { - printf '[%s] CMD:' "$(timestamp)" - printf ' %q' "${cmd[@]}" - printf '\n' - } >> "${log_file}" - - if [[ "${DRY_RUN}" == "1" ]]; then - return 0 - fi - - "${cmd[@]}" >> "${log_file}" 2>&1 -} - -run_cmd_with_timeout() { - local log_file="$1" - local timeout_seconds="$2" - shift 2 - local -a cmd=("$@") - - { - printf '[%s] CMD(timeout=%ss):' "$(timestamp)" "${timeout_seconds}" - printf ' %q' "${cmd[@]}" - printf '\n' - } 
>> "${log_file}" - - if [[ "${DRY_RUN}" == "1" ]]; then - return 0 - fi - - if [[ "${timeout_seconds}" -le 0 ]]; then - "${cmd[@]}" >> "${log_file}" 2>&1 - return $? - fi - - local start_ts now elapsed - local cmd_pid timed_out rc - timed_out=0 - start_ts="$(date +%s)" - - "${cmd[@]}" >> "${log_file}" 2>&1 & - cmd_pid=$! - - while kill -0 "${cmd_pid}" 2>/dev/null; do - now="$(date +%s)" - elapsed=$((now - start_ts)) - if (( elapsed >= timeout_seconds )); then - timed_out=1 - echo "[$(timestamp)] OOT: command exceeded ${timeout_seconds}s; terminating PID=${cmd_pid}" >> "${log_file}" - kill -TERM "${cmd_pid}" 2>/dev/null || true - pkill -TERM -P "${cmd_pid}" 2>/dev/null || true - sleep 5 - kill -KILL "${cmd_pid}" 2>/dev/null || true - pkill -KILL -P "${cmd_pid}" 2>/dev/null || true - break - fi - sleep 10 - done - - wait "${cmd_pid}" - rc=$? - if (( timed_out == 1 )); then - return 124 - fi - return "${rc}" -} - -is_oom_failure() { - local log_file="$1" - if [[ ! -f "${log_file}" ]]; then - return 1 - fi - local pattern="out of memory|cuda error: out of memory|cublas status alloc failed|cuda driver error.*memory" - if command -v rg >/dev/null 2>&1; then - rg -qi "${pattern}" "${log_file}" - else - grep -Eiq "${pattern}" "${log_file}" - fi -} - -is_ancdata_failure() { - local log_file="$1" - if [[ ! -f "${log_file}" ]]; then - return 1 - fi - local pattern="received [0-9]+ items of ancdata|multiprocessing/resource_sharer\\.py" - if command -v rg >/dev/null 2>&1; then - rg -qi "${pattern}" "${log_file}" - else - grep -Eiq "${pattern}" "${log_file}" - fi -} - -LAST_EXPORT_STATUS="ok" - -run_exports_for_job() { - local job_name="$1" - local seg_dir="$2" - local log_file="$3" - - local seg_file="${seg_dir}/segger_segmentation.parquet" - local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" - local anndata_file="${anndata_dir}/segger_segmentation.h5ad" - local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" - local xenium_file="${xenium_dir}/seg_experiment.xenium" - - mkdir -p "${anndata_dir}" "${xenium_dir}" - - if [[ ! -f "${seg_file}" ]] && [[ "${DRY_RUN}" != "1" ]]; then - LAST_EXPORT_STATUS="missing_segmentation" - return 1 - fi - - if [[ ! -f "${anndata_file}" ]]; then - local -a anndata_cmd=( - segger export - -s "${seg_file}" - -i "${INPUT_DIR}" - -o "${anndata_dir}" - --format anndata - ) - if ! run_cmd "${log_file}" "${anndata_cmd[@]}"; then - LAST_EXPORT_STATUS="anndata_export_failed" - return 1 - fi - else - echo "[$(timestamp)] SKIP anndata export (existing): ${anndata_file}" >> "${log_file}" - fi - - if [[ ! -f "${xenium_file}" ]]; then - local -a xenium_cmd=( - segger export - -s "${seg_file}" - -i "${INPUT_DIR}" - -o "${xenium_dir}" - --format xenium_explorer - --boundary-method "${BOUNDARY_METHOD}" - --boundary-voxel-size "${BOUNDARY_VOXEL_SIZE}" - --num-workers "${XENIUM_NUM_WORKERS}" - ) - if ! 
run_cmd "${log_file}" "${xenium_cmd[@]}"; then - LAST_EXPORT_STATUS="xenium_export_failed" - return 1 - fi - else - echo "[$(timestamp)] SKIP xenium export (existing): ${xenium_file}" >> "${log_file}" - fi - - LAST_EXPORT_STATUS="ok" - return 0 -} - -LAST_JOB_STATUS="unknown" - -run_job() { - local gpu="$1" - local spec="$2" - - local job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss - IFS='|' read -r \ - job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ - <<< "${spec}" - - local seg_dir="${RUNS_DIR}/${job_name}" - local seg_file="${seg_dir}/segger_segmentation.parquet" - local anndata_dir="${EXPORTS_DIR}/${job_name}/anndata" - local anndata_file="${anndata_dir}/segger_segmentation.h5ad" - local xenium_dir="${EXPORTS_DIR}/${job_name}/xenium_explorer" - local xenium_file="${xenium_dir}/seg_experiment.xenium" - local log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" - - mkdir -p "${seg_dir}" "${anndata_dir}" "${xenium_dir}" - - { - echo "==================================================================" - echo "[$(timestamp)] START job=${job_name} gpu=${gpu}" - echo "params: use3d=${use_3d} expansion=${expansion} tx_k=${tx_k} tx_dist=${tx_dist} layers=${n_layers} heads=${n_heads} cells_min=${cells_min_counts} min_qv=${min_qv} align=${alignment_loss} timeout_min=${SEGMENT_TIMEOUT_MIN} dl_workers=${SEGMENT_NUM_WORKERS} anc_retry_workers=${SEGMENT_ANC_RETRY_WORKERS} sharing=${TORCH_SHARING_STRATEGY}" - } | tee -a "${log_file}" >/dev/null - - if [[ "${RESUME_IF_EXISTS}" == "1" ]] && \ - [[ -f "${seg_file}" ]] && \ - [[ -f "${anndata_file}" ]] && \ - [[ -f "${xenium_file}" ]]; then - echo "[$(timestamp)] SKIP job=${job_name} (all outputs already present)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="skipped_existing" - return 0 - fi - - if [[ ! -f "${seg_file}" ]]; then - local -a seg_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger segment - -i "${INPUT_DIR}" - -o "${seg_dir}" - --n-epochs "${N_EPOCHS}" - --prediction-mode "${PREDICTION_MODE}" - --prediction-expansion-ratio "${expansion}" - --cells-min-counts "${cells_min_counts}" - --min-qv "${min_qv}" - --use-3d "${use_3d}" - --transcripts-max-k "${tx_k}" - --transcripts-max-dist "${tx_dist}" - --n-mid-layers "${n_layers}" - --n-heads "${n_heads}" - ) - if [[ "${alignment_loss}" == "true" ]]; then - seg_cmd+=( - --alignment-loss - --alignment-loss-weight-start "${ALIGNMENT_LOSS_WEIGHT_START}" - --alignment-loss-weight-end "${ALIGNMENT_LOSS_WEIGHT_END}" - ) - if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then - seg_cmd+=(--alignment-me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") - fi - if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - seg_cmd+=( - --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" - --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" - ) - fi - fi - - run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_cmd[@]}" - local seg_rc=$? 
- if [[ "${seg_rc}" -ne 0 ]]; then - if [[ "${seg_rc}" -eq 124 ]]; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oot" - return 1 - fi - - if [[ "${SEGMENT_ANC_RETRY_WORKERS}" != "${SEGMENT_NUM_WORKERS}" ]] && is_ancdata_failure "${log_file}"; then - echo "[$(timestamp)] WARN job=${job_name} segment failed with ancdata; retrying with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null - local -a seg_retry_cmd=("${seg_cmd[@]}") - local i - for i in "${!seg_retry_cmd[@]}"; do - if [[ "${seg_retry_cmd[$i]}" == SEGGER_NUM_WORKERS=* ]]; then - seg_retry_cmd[$i]="SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" - break - fi - done - run_cmd_with_timeout "${log_file}" "${SEGMENT_TIMEOUT_SEC}" "${seg_retry_cmd[@]}" - seg_rc=$? - if [[ "${seg_rc}" -eq 0 ]]; then - echo "[$(timestamp)] OK job=${job_name} segment retry succeeded with SEGGER_NUM_WORKERS=${SEGMENT_ANC_RETRY_WORKERS}" | tee -a "${log_file}" >/dev/null - elif [[ "${seg_rc}" -eq 124 ]]; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment_retry (OOT ${SEGMENT_TIMEOUT_MIN}m)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oot" - return 1 - fi - fi - - if [[ "${seg_rc}" -eq 0 ]]; then - : - else - local last_ckpt="${seg_dir}/checkpoints/last.ckpt" - if [[ "${PREDICT_FALLBACK_ON_OOM}" == "1" ]] && is_oom_failure "${log_file}" && [[ -f "${last_ckpt}" ]]; then - echo "[$(timestamp)] WARN job=${job_name} segment OOM; trying checkpoint predict fallback (${last_ckpt})" | tee -a "${log_file}" >/dev/null - local -a predict_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger predict - -c "${last_ckpt}" - -i "${INPUT_DIR}" - -o "${seg_dir}" - ) - if run_cmd "${log_file}" "${predict_cmd[@]}"; then - echo "[$(timestamp)] OK job=${job_name} predict fallback succeeded after OOM" | tee -a "${log_file}" >/dev/null - else - echo "[$(timestamp)] FAIL job=${job_name} step=predict_fallback_after_oom" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="predict_fallback_failed" - return 1 - fi - else - if is_ancdata_failure "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (ancdata)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_ancdata" - elif is_oom_failure "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=segment (oom)" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_oom" - else - echo "[$(timestamp)] FAIL job=${job_name} step=segment" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="segment_failed" - fi - return 1 - fi - fi - fi - else - echo "[$(timestamp)] SKIP segmentation (existing): ${seg_file}" | tee -a "${log_file}" >/dev/null - fi - - if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - echo "[$(timestamp)] FAIL job=${job_name} step=${LAST_EXPORT_STATUS}" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="${LAST_EXPORT_STATUS}" - return 1 - fi - - echo "[$(timestamp)] DONE job=${job_name}" | tee -a "${log_file}" >/dev/null - LAST_JOB_STATUS="ok" - return 0 -} - -run_gpu_group() { - local gpu="$1" - shift - local -a indices=("$@") - local summary_file="${SUMMARY_DIR}/gpu${gpu}.tsv" - - printf "job\tgpu\tstatus\telapsed_s\tseg_dir\tlog_file\n" > "${summary_file}" - - local idx spec job_name start_ts end_ts elapsed_s - for idx in "${indices[@]}"; do - spec="${JOB_SPECS[$idx]}" - IFS='|' read -r job_name _ <<< "${spec}" - - start_ts="$(date +%s)" - run_job "${gpu}" "${spec}" - end_ts="$(date +%s)" - elapsed_s=$((end_ts - start_ts)) - - printf "%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" \ - "${gpu}" \ - "${LAST_JOB_STATUS}" \ - "${elapsed_s}" \ - "${RUNS_DIR}/${job_name}" \ - "${LOGS_DIR}/${job_name}.gpu${gpu}.log" \ - >> "${summary_file}" - done -} - -run_post_recovery_predict_only_group() { - local gpu="$1" - local out_file="$2" - shift 2 - local -a indices=("$@") - - printf "job\tgpu\tstatus\telapsed_s\tnote\tseg_dir\tlog_file\n" > "${out_file}" - - local idx spec job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss - local seg_dir seg_file last_ckpt log_file note status - local start_ts end_ts elapsed_s - - for idx in "${indices[@]}"; do - spec="${JOB_SPECS[$idx]}" - IFS='|' read -r \ - job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ - <<< "${spec}" - - seg_dir="${RUNS_DIR}/${job_name}" - seg_file="${seg_dir}/segger_segmentation.parquet" - last_ckpt="${seg_dir}/checkpoints/last.ckpt" - log_file="${LOGS_DIR}/${job_name}.gpu${gpu}.log" - mkdir -p "${seg_dir}" - - start_ts="$(date +%s)" - note="" - status="ok" - - if [[ -f "${seg_file}" ]]; then - note="segmentation_exists" - if ! 
run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - status="${LAST_EXPORT_STATUS}" - note="exports_failed_after_existing_seg" - fi - else - if [[ -f "${last_ckpt}" ]]; then - echo "[$(timestamp)] RECOVERY job=${job_name}: running predict-only from ${last_ckpt}" | tee -a "${log_file}" >/dev/null - local -a predict_cmd=( - env CUDA_VISIBLE_DEVICES="${gpu}" - PYTORCH_SHARING_STRATEGY="${TORCH_SHARING_STRATEGY}" - SEGGER_NUM_WORKERS="${SEGMENT_NUM_WORKERS}" - segger predict - -c "${last_ckpt}" - -i "${INPUT_DIR}" - -o "${seg_dir}" - ) - if run_cmd "${log_file}" "${predict_cmd[@]}"; then - if run_exports_for_job "${job_name}" "${seg_dir}" "${log_file}"; then - status="recovered_predict_ok" - note="predict_only_from_last_ckpt" - else - status="${LAST_EXPORT_STATUS}" - note="predict_recovered_but_exports_failed" - fi - else - status="recovered_predict_failed" - note="predict_only_failed" - fi - else - status="recovery_no_checkpoint" - note="missing_seg_and_last_ckpt" - fi - fi - - end_ts="$(date +%s)" - elapsed_s=$((end_ts - start_ts)) - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" \ - "${gpu}" \ - "${status}" \ - "${elapsed_s}" \ - "${note}" \ - "${seg_dir}" \ - "${log_file}" \ - >> "${out_file}" - done -} - -run_post_recovery_predict_only() { - local recovery_file="${SUMMARY_DIR}/recovery.tsv" - local recovery_a="${SUMMARY_DIR}/recovery.gpu${GPU_A}.tsv" - local recovery_b="${SUMMARY_DIR}/recovery.gpu${GPU_B}.tsv" - local pid_a pid_b - - if [[ "${GPU_A}" == "${GPU_B}" ]]; then - run_post_recovery_predict_only_group "${GPU_A}" "${recovery_a}" "${GPU_A_INDICES[@]}" "${GPU_B_INDICES[@]}" - cp "${recovery_a}" "${recovery_file}" - return - fi - - run_post_recovery_predict_only_group "${GPU_A}" "${recovery_a}" "${GPU_A_INDICES[@]}" & - pid_a=$! - run_post_recovery_predict_only_group "${GPU_B}" "${recovery_b}" "${GPU_B_INDICES[@]}" & - pid_b=$! - - wait "${pid_a}" - wait "${pid_b}" - - awk 'FNR==1 && NR!=1 {next} {print}' "${recovery_a}" "${recovery_b}" > "${recovery_file}" -} - -build_jobs - -if [[ "${#JOB_SPECS[@]}" -eq 0 ]]; then - echo "ERROR: No robustness/ablation jobs were generated." - exit 1 -fi - -GPU_A_INDICES=() -GPU_B_INDICES=() - -idx=0 -for spec in "${JOB_SPECS[@]}"; do - if (( idx % 2 == 0 )); then - GPU_A_INDICES+=("${idx}") - else - GPU_B_INDICES+=("${idx}") - fi - idx=$((idx + 1)) -done - -{ - printf "job\tstudy_block\tgroup\tuse_3d\texpansion\ttx_max_k\ttx_max_dist\tn_mid_layers\tn_heads\tcells_min_counts\tmin_qv\talignment_loss\n" - for idx in "${!JOB_SPECS[@]}"; do - local_group="A" - if (( idx % 2 == 1 )); then - local_group="B" - fi - IFS='|' read -r \ - job_name use_3d expansion tx_k tx_dist n_layers n_heads cells_min_counts min_qv alignment_loss \ - <<< "${JOB_SPECS[$idx]}" - local_block="$(job_block "${job_name}")" - printf "%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n" \ - "${job_name}" "${local_block}" "${local_group}" "${use_3d}" "${expansion}" "${tx_k}" "${tx_dist}" \ - "${n_layers}" "${n_heads}" "${cells_min_counts}" "${min_qv}" "${alignment_loss}" - done -} > "${PLAN_FILE}" - -echo "[$(timestamp)] Prepared ${#JOB_SPECS[@]} jobs." -echo "[$(timestamp)] Study blocks: stability + interaction + stress" -echo "[$(timestamp)] Group A (GPU ${GPU_A}): ${#GPU_A_INDICES[@]} jobs" -echo "[$(timestamp)] Group B (GPU ${GPU_B}): ${#GPU_B_INDICES[@]} jobs" -echo "[$(timestamp)] Job plan: ${PLAN_FILE}" -echo "[$(timestamp)] Logs: ${LOGS_DIR}" - -run_gpu_group "${GPU_A}" "${GPU_A_INDICES[@]}" & -PID_A=$! 
-run_gpu_group "${GPU_B}" "${GPU_B_INDICES[@]}" & -PID_B=$! - -wait "${PID_A}" -wait "${PID_B}" - -if [[ "${DRY_RUN}" != "1" ]]; then - echo "[$(timestamp)] Starting post-run predict-only recovery pass..." - run_post_recovery_predict_only -fi - -COMBINED_SUMMARY="${SUMMARY_DIR}/all_jobs.tsv" -if [[ "${DRY_RUN}" != "1" ]] && [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then - awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv "${SUMMARY_DIR}/recovery.tsv" > "${COMBINED_SUMMARY}" - FAILED_COUNT=$( - awk -F'\t' 'NR>1 && $3!="ok" && $3!="recovered_predict_ok" {c++} END{print c+0}' "${SUMMARY_DIR}/recovery.tsv" - ) -else - awk 'FNR==1 && NR!=1 {next} {print}' "${SUMMARY_DIR}"/gpu*.tsv > "${COMBINED_SUMMARY}" - FAILED_COUNT=$( - awk -F'\t' 'NR>1 && $3!="ok" && $3!="skipped_existing" {c++} END{print c+0}' "${COMBINED_SUMMARY}" - ) -fi - -echo "[$(timestamp)] Combined summary: ${COMBINED_SUMMARY}" -if [[ -f "${SUMMARY_DIR}/recovery.tsv" ]]; then - echo "[$(timestamp)] Recovery summary: ${SUMMARY_DIR}/recovery.tsv" -fi - -if [[ "${DRY_RUN}" != "1" ]] && [[ "${RUN_VALIDATION_TABLE}" == "1" ]]; then - if [[ -f "${VALIDATION_SCRIPT}" ]]; then - echo "[$(timestamp)] Building validation metrics table..." - validation_log="${SUMMARY_DIR}/validation_metrics.log" - validation_cmd=( - bash "${VALIDATION_SCRIPT}" - --root "${OUTPUT_ROOT}" - --input-dir "${INPUT_DIR}" - --gpu-a "${GPU_A}" - --gpu-b "${GPU_B}" - --include-default-10x "${VALIDATION_INCLUDE_DEFAULT_10X}" - ) - if [[ -n "${ALIGNMENT_ME_GENE_PAIRS_PATH}" ]]; then - validation_cmd+=(--me-gene-pairs-path "${ALIGNMENT_ME_GENE_PAIRS_PATH}") - fi - if [[ -n "${ALIGNMENT_SCRNA_REFERENCE_PATH}" ]]; then - validation_cmd+=( - --scrna-reference-path "${ALIGNMENT_SCRNA_REFERENCE_PATH}" - --scrna-celltype-column "${ALIGNMENT_SCRNA_CELLTYPE_COLUMN}" - ) - fi - if "${validation_cmd[@]}" >> "${validation_log}" 2>&1; then - echo "[$(timestamp)] Validation table updated: ${OUTPUT_ROOT}/summaries/validation_metrics.tsv" - else - echo "[$(timestamp)] WARN: validation table build failed (see ${validation_log})" - fi - else - echo "[$(timestamp)] WARN: VALIDATION_SCRIPT not found: ${VALIDATION_SCRIPT}" - fi -fi - -echo "[$(timestamp)] Failed jobs: ${FAILED_COUNT}" - -if [[ "${FAILED_COUNT}" -gt 0 ]]; then - exit 1 -fi From 59d91fb8772cd2204623129b39f062242eb6674d Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Wed, 6 May 2026 19:01:58 +0200 Subject: [PATCH 17/20] Remove sopa support from modules and dependencies --- pyproject.toml | 5 -- src/segger/utils/__init__.py | 10 ---- src/segger/utils/optional_deps.py | 95 +------------------------------ 3 files changed, 1 insertion(+), 109 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f336294..972d3b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,11 +44,6 @@ spatialdata-io = [ "spatialdata-io>=0.6.0", ] -sopa = [ - "sopa>=2.0.0", - "spatialdata>=0.7.2", -] - spatialdata-all = [ "spatialdata>=0.7.2", "spatialdata-io>=0.6.0", diff --git a/src/segger/utils/__init__.py b/src/segger/utils/__init__.py index 4ae2257..78a5eb6 100644 --- a/src/segger/utils/__init__.py +++ b/src/segger/utils/__init__.py @@ -23,25 +23,20 @@ def setup_logging(level: str = "WARNING", log_file: str = None): # Availability flags SPATIALDATA_AVAILABLE, SPATIALDATA_IO_AVAILABLE, - SOPA_AVAILABLE, # Import functions (raise ImportError if missing) require_spatialdata, require_spatialdata_io, - require_sopa, # Decorators for functions requiring optional deps requires_spatialdata, requires_spatialdata_io, - requires_sopa, # 
Warning functions for soft failures warn_spatialdata_unavailable, warn_spatialdata_io_unavailable, - warn_sopa_unavailable, warn_rapids_unavailable, # RAPIDS helpers require_rapids, # Version utilities get_spatialdata_version, - get_sopa_version, check_spatialdata_version, ) @@ -49,23 +44,18 @@ def setup_logging(level: str = "WARNING", log_file: str = None): # Availability flags "SPATIALDATA_AVAILABLE", "SPATIALDATA_IO_AVAILABLE", - "SOPA_AVAILABLE", # Import functions "require_spatialdata", "require_spatialdata_io", - "require_sopa", # Decorators "requires_spatialdata", "requires_spatialdata_io", - "requires_sopa", # Warning functions "warn_spatialdata_unavailable", "warn_spatialdata_io_unavailable", - "warn_sopa_unavailable", "warn_rapids_unavailable", "require_rapids", # Version utilities "get_spatialdata_version", - "get_sopa_version", "check_spatialdata_version", ] diff --git a/src/segger/utils/optional_deps.py b/src/segger/utils/optional_deps.py index b7231ca..72b9553 100644 --- a/src/segger/utils/optional_deps.py +++ b/src/segger/utils/optional_deps.py @@ -1,7 +1,7 @@ """Optional dependency handling with informative warnings. This module provides lazy import wrappers for optional dependencies -(spatialdata, spatialdata-io, sopa) with clear installation instructions +(spatialdata, spatialdata-io) with clear installation instructions when the dependencies are not available. Usage @@ -58,18 +58,9 @@ def _check_spatialdata_io() -> bool: return False -def _check_sopa() -> bool: - """Check if sopa is available.""" - try: - return importlib.util.find_spec("sopa") is not None - except Exception: - return False - - # Availability flags (evaluated once at import time) SPATIALDATA_AVAILABLE: bool = _check_spatialdata() SPATIALDATA_IO_AVAILABLE: bool = _check_spatialdata_io() -SOPA_AVAILABLE: bool = _check_sopa() # ----------------------------------------------------------------------------- @@ -100,24 +91,6 @@ def _check_sopa() -> bool: pip install spatialdata-io>=0.6.0 """ -SOPA_INSTALL_MSG = """ -sopa is not installed. This package is required for SOPA compatibility features. - -To install SOPA support: - pip install segger[sopa] - -Or install sopa directly: - pip install sopa>=2.0.0 - -For all SpatialData features including SOPA: - pip install segger[spatialdata-all] -""" - -RAPIDS_INSTALL_MSG = """ -RAPIDS GPU packages are not installed. Segger requires CuPy/cuDF/cuML/cuGraph/cuSpatial and a CUDA-enabled GPU. - -See docs/INSTALLATION.md for RAPIDS/CUDA setup. -""" # ----------------------------------------------------------------------------- @@ -162,25 +135,6 @@ def require_spatialdata_io() -> "types.ModuleType": return spatialdata_io -def require_sopa() -> "types.ModuleType": - """Import and return sopa, raising ImportError if not available. - - Returns - ------- - types.ModuleType - The sopa module. - - Raises - ------ - ImportError - If sopa is not installed, with installation instructions. - """ - if not SOPA_AVAILABLE: - raise ImportError(SOPA_INSTALL_MSG) - import sopa - return sopa - - # ----------------------------------------------------------------------------- # Decorators for requiring optional dependencies # ----------------------------------------------------------------------------- @@ -232,26 +186,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper # type: ignore[return-value] -def requires_sopa(func: F) -> F: - """Decorator that raises ImportError if sopa is not available. - - Parameters - ---------- - func - Function that requires sopa. 
- - Returns - ------- - F - Wrapped function that checks for sopa before execution. - """ - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - require_sopa() - return func(*args, **kwargs) - return wrapper # type: ignore[return-value] - - # ----------------------------------------------------------------------------- # Warning functions for soft failures # ----------------------------------------------------------------------------- @@ -288,22 +222,6 @@ def warn_spatialdata_io_unavailable(feature: str = "Platform-specific SpatialDat ) -def warn_sopa_unavailable(feature: str = "SOPA compatibility") -> None: - """Emit a warning that sopa is not available. - - Parameters - ---------- - feature - Description of the feature requiring sopa. - """ - warnings.warn( - f"{feature} requires sopa. " - "Install with: pip install segger[sopa]", - UserWarning, - stacklevel=2, - ) - - def _import_optional_packages(packages: list[str]) -> tuple[dict[str, "types.ModuleType"], list[str]]: """Import optional packages and return (modules, missing).""" modules: dict[str, "types.ModuleType"] = {} @@ -366,17 +284,6 @@ def get_spatialdata_version() -> str | None: return None -def get_sopa_version() -> str | None: - """Get the installed sopa version, or None if not installed.""" - if not SOPA_AVAILABLE: - return None - try: - import sopa - return getattr(sopa, "__version__", "unknown") - except Exception: - return None - - def check_spatialdata_version(min_version: str = "0.7.2") -> bool: """Check if spatialdata version meets minimum requirement. From 733c24bfbef1a526eaf52db4f29ec15bcfaf16eb Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Wed, 6 May 2026 19:21:30 +0200 Subject: [PATCH 18/20] Remove additional modules from v2-incremental --- src/segger/export/xenium.py | 862 ------------------- src/segger/models/alignment_loss.py | 118 --- src/segger/validation/__init__.py | 14 - src/segger/validation/me_genes.py | 421 ---------- src/segger/validation/quick_metrics.py | 1050 ------------------------ 5 files changed, 2465 deletions(-) delete mode 100644 src/segger/export/xenium.py delete mode 100644 src/segger/models/alignment_loss.py delete mode 100644 src/segger/validation/__init__.py delete mode 100644 src/segger/validation/me_genes.py delete mode 100644 src/segger/validation/quick_metrics.py diff --git a/src/segger/export/xenium.py b/src/segger/export/xenium.py deleted file mode 100644 index 3ab1ebe..0000000 --- a/src/segger/export/xenium.py +++ /dev/null @@ -1,862 +0,0 @@ -"""Xenium Explorer export functionality. - -This module converts segmentation results into Xenium Explorer-compatible -Zarr format for visualization and validation. -""" - -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union -import json -from concurrent.futures.process import BrokenProcessPool - -import numpy as np -import pandas as pd -import polars as pl -import zarr -from pqdm.processes import pqdm as pqdm_processes -try: - from pqdm.threads import pqdm as pqdm_threads -except Exception: - pqdm_threads = None -from shapely.geometry import MultiPoint, MultiPolygon, Polygon -from tqdm import tqdm -from zarr.storage import ZipStore - -from .boundary import extract_largest_polygon, generate_boundary - - -def _normalize_polygon_vertices( - polygon: Polygon, - max_vertices: int, -) -> Tuple[List[Tuple[float, float]], int]: - """Normalize polygon vertices to a fixed length with closure. 
- - Returns a list of vertices padded/truncated to ``max_vertices`` and the - true number of vertices including the closing vertex. - """ - coords = list(polygon.exterior.coords) - # Remove duplicate closing vertex - if coords[0] == coords[-1]: - coords = coords[:-1] - - if len(coords) < 3: - return [], 0 - - num_vertices = len(coords) + 1 # include closing vertex - target = max_vertices - 1 - - if len(coords) > target: - indices = np.linspace(0, len(coords) - 1, target, dtype=int) - coords = [coords[i] for i in indices] - - # Close polygon and pad - coords.append(coords[0]) - if len(coords) < max_vertices: - coords += [coords[0]] * (max_vertices - len(coords)) - - return coords, num_vertices - - -def _safe_boundary_polygon( - seg_cell: pd.DataFrame, - x: str, - y: str, - boundary_method: str = "delaunay", - boundary_voxel_size: float = 0.0, -) -> Optional[Polygon]: - """Generate a robust polygon boundary for a cell. - - Uses the requested boundary method with robust fallbacks. - """ - if boundary_method in {"convex_hull", "input"}: - mp = MultiPoint(seg_cell[[x, y]].values) - cell_poly = mp.convex_hull if not mp.is_empty else None - elif boundary_method == "voxel": - if boundary_voxel_size <= 0: - return None - points = seg_cell[[x, y]].to_numpy(dtype=np.float64) - if len(points) < 3: - return None - mins = points.min(axis=0) - bins = np.floor((points - mins) / boundary_voxel_size).astype(np.int64) - _, keep = np.unique(bins, axis=0, return_index=True) - reduced = points[np.sort(keep)] - if len(reduced) < 3: - return None - mp = MultiPoint(reduced) - cell_poly = mp.convex_hull if not mp.is_empty else None - else: - working = seg_cell - if boundary_voxel_size > 0: - points = seg_cell[[x, y]].to_numpy(dtype=np.float64) - mins = points.min(axis=0) - bins = np.floor((points - mins) / boundary_voxel_size).astype(np.int64) - _, keep = np.unique(bins, axis=0, return_index=True) - working = seg_cell.iloc[np.sort(keep)] - - try: - cell_poly = generate_boundary(working, x=x, y=y) - if isinstance(cell_poly, MultiPolygon): - cell_poly = extract_largest_polygon(cell_poly) - except Exception: - cell_poly = None - - if cell_poly is None or not isinstance(cell_poly, Polygon) or cell_poly.is_empty: - # Fallback: convex hull of points - mp = MultiPoint(seg_cell[[x, y]].values) - cell_poly = mp.convex_hull if not mp.is_empty else None - - if cell_poly is None or not isinstance(cell_poly, Polygon) or cell_poly.is_empty: - return None - - return cell_poly - - -def _prepare_input_boundaries( - boundaries, - boundary_id_column: str = "cell_id", - boundary_type_column: str = "boundary_type", - boundary_cell_value: str = "cell", - boundary_nucleus_value: str = "nucleus", -) -> Tuple[Dict[Any, Polygon], Dict[Any, Polygon]]: - """Prepare lookup tables for input cell/nucleus boundaries.""" - if boundaries is None: - return {}, {} - - gdf = boundaries - if boundary_id_column not in gdf.columns: - if gdf.index.name == boundary_id_column: - gdf = gdf.reset_index() - else: - return {}, {} - - def _pick_largest(group): - largest = None - max_area = -1.0 - for geom in group.geometry: - if geom is None or getattr(geom, "is_empty", True): - continue - if isinstance(geom, MultiPolygon): - geom = extract_largest_polygon(geom) - if not isinstance(geom, Polygon) or geom is None or geom.is_empty: - continue - area = float(geom.area) - if area > max_area: - max_area = area - largest = geom - return largest - - if boundary_type_column in gdf.columns: - cells = gdf[gdf[boundary_type_column] == boundary_cell_value] - nuclei = 
gdf[gdf[boundary_type_column] == boundary_nucleus_value] - else: - cells = gdf - nuclei = gdf.iloc[0:0] - - cell_lookup: Dict[Any, Polygon] = {} - for cell_id, group in cells.groupby(boundary_id_column): - poly = _pick_largest(group) - if poly is not None: - cell_lookup[cell_id] = poly - - nucleus_lookup: Dict[Any, Polygon] = {} - for cell_id, group in nuclei.groupby(boundary_id_column): - poly = _pick_largest(group) - if poly is not None: - nucleus_lookup[cell_id] = poly - - return cell_lookup, nucleus_lookup - - -def get_indices_indptr(input_array: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """Get sparse matrix representation for cluster assignments. - - Parameters - ---------- - input_array : np.ndarray - Array of cluster labels. - - Returns - ------- - Tuple[np.ndarray, np.ndarray] - Indices and indptr arrays for CSR-like representation. - """ - clusters = sorted(np.unique(input_array[input_array != 0])) - indptr = np.zeros(len(clusters), dtype=np.uint32) - indices = [] - - for cluster in clusters: - cluster_indices = np.where(input_array == cluster)[0] - indptr[cluster - 1] = len(indices) - indices.extend(cluster_indices) - - indices.extend(-np.zeros(len(input_array[input_array == 0]))) - indices = np.array(indices, dtype=np.int32).astype(np.uint32) - return indices, indptr - - -def generate_experiment_file( - template_path: Path, - output_path: Path, - cells_name: str = "seg_cells", - analysis_name: str = "seg_analysis", -) -> None: - """Generate Xenium experiment manifest file. - - Parameters - ---------- - template_path : Path - Path to template experiment.xenium file. - output_path : Path - Path for output experiment file. - cells_name : str - Name of cells Zarr file (without extension). - analysis_name : str - Name of analysis Zarr file (without extension). - Notes - ----- - We only replace the cells and analysis Zarr paths, preserving all other - entries (including morphology image references). This keeps multi-channel - morphology_focus image stacks intact for segmentation kit datasets. - """ - with open(template_path) as f: - experiment = json.load(f) - - experiment["xenium_explorer_files"]["cells_zarr_filepath"] = f"{cells_name}.zarr.zip" - experiment["xenium_explorer_files"].pop("cell_features_zarr_filepath", None) - experiment["xenium_explorer_files"]["analysis_zarr_filepath"] = f"{analysis_name}.zarr.zip" - - with open(output_path, "w") as f: - json.dump(experiment, f, indent=2) - - -def seg2explorer( - seg_df: Union[pd.DataFrame, pl.DataFrame], - source_path: Union[str, Path], - output_dir: Union[str, Path], - cells_filename: str = "seg_cells", - analysis_filename: str = "seg_analysis", - xenium_filename: str = "seg_experiment.xenium", - analysis_df: Optional[pd.DataFrame] = None, - cell_id_column: str = "seg_cell_id", - x_column: str = "x", - y_column: str = "y", - z_column: Optional[str] = "z", - nucleus_column: Optional[str] = "cell_compartment", - nucleus_value: int = 2, - area_low: float = 10, - area_high: float = 100, - polygon_max_vertices: int = 13, - boundary_method: str = "delaunay", - boundary_voxel_size: float = 0.0, - boundaries: Optional["gpd.GeoDataFrame"] = None, - boundary_id_column: str = "cell_id", - boundary_type_column: str = "boundary_type", - boundary_cell_value: str = "cell", - boundary_nucleus_value: str = "nucleus", - cell_id_columns: Optional[str] = None, -) -> None: - """Convert segmentation results to Xenium Explorer format. 
- - Parameters - ---------- - seg_df : Union[pd.DataFrame, pl.DataFrame] - Segmented transcript DataFrame with cell assignments. - source_path : Union[str, Path] - Path to source Xenium data directory. - output_dir : Union[str, Path] - Output directory for Zarr files. - cells_filename : str - Filename prefix for cells Zarr. - analysis_filename : str - Filename prefix for analysis Zarr. - xenium_filename : str - Filename for experiment manifest. - analysis_df : Optional[pd.DataFrame] - Optional clustering/annotation DataFrame. - cell_id_column : str - Column name for cell IDs. - x_column : str - Column name for x coordinates. - y_column : str - Column name for y coordinates. - z_column : Optional[str] - Column name for z coordinates (if available). - nucleus_column : Optional[str] - Column name for nucleus/compartment assignment. - nucleus_value : int - Value indicating nuclear compartment. - area_low : float - Minimum cell area threshold. - area_high : float - Maximum cell area threshold. - polygon_max_vertices : int - Maximum number of vertices per polygon (including closure). - """ - if cell_id_columns is not None: - cell_id_column = cell_id_columns - - if boundary_method == "skip": - raise ValueError("boundary_method='skip' is not supported for Xenium export.") - - # Convert Polars to pandas - if isinstance(seg_df, pl.DataFrame): - seg_df = seg_df.to_pandas() - - source_path = Path(source_path) - storage = Path(output_dir) - storage.mkdir(parents=True, exist_ok=True) - - cell_boundaries: Dict[Any, Polygon] = {} - nucleus_boundaries: Dict[Any, Polygon] = {} - if boundary_method == "input": - cell_boundaries, nucleus_boundaries = _prepare_input_boundaries( - boundaries=boundaries, - boundary_id_column=boundary_id_column, - boundary_type_column=boundary_type_column, - boundary_cell_value=boundary_cell_value, - boundary_nucleus_value=boundary_nucleus_value, - ) - - # Drop unassigned cells if numeric - if cell_id_column in seg_df.columns: - if pd.api.types.is_numeric_dtype(seg_df[cell_id_column]): - seg_df = seg_df[seg_df[cell_id_column] >= 0] - else: - seg_df = seg_df[seg_df[cell_id_column].notna()] - - cell_id2old_id: Dict[int, Any] = {} - cell_id: List[int] = [] - cell_summary_rows: List[List[float]] = [] - cell_num_vertices: List[int] = [] - nucleus_num_vertices: List[int] = [] - cell_vertices: List[List[Tuple[float, float]]] = [] - nucleus_vertices: List[List[Tuple[float, float]]] = [] - - grouped_by = seg_df.groupby(cell_id_column) - - for cell_incremental_id, (seg_cell_id, seg_cell) in tqdm( - enumerate(grouped_by), total=len(grouped_by), desc="Processing cells" - ): - if len(seg_cell) < 5: - continue - - if boundary_method == "input" and cell_boundaries: - cell_poly = cell_boundaries.get(seg_cell_id) - else: - fallback_method = "delaunay" if boundary_method == "input" else boundary_method - cell_poly = _safe_boundary_polygon( - seg_cell, - x=x_column, - y=y_column, - boundary_method=fallback_method, - boundary_voxel_size=boundary_voxel_size, - ) - if cell_poly is None or not (area_low <= cell_poly.area <= area_high): - continue - - # Nucleus polygon (optional) - nucleus_poly = None - if boundary_method == "input" and nucleus_boundaries: - nucleus_poly = nucleus_boundaries.get(seg_cell_id) - elif nucleus_column is not None and nucleus_column in seg_cell.columns: - seg_nucleus = seg_cell[seg_cell[nucleus_column] == nucleus_value] - if len(seg_nucleus) >= 3: - nucleus_poly = MultiPoint(seg_nucleus[[x_column, y_column]].values).convex_hull - if isinstance(nucleus_poly, MultiPolygon): - 
nucleus_poly = extract_largest_polygon(nucleus_poly) - if not isinstance(nucleus_poly, Polygon) or nucleus_poly.is_empty: - nucleus_poly = None - - cell_coords, cell_nv = _normalize_polygon_vertices(cell_poly, polygon_max_vertices) - if cell_nv == 0: - continue - - zero_vertices = [(0.0, 0.0)] * polygon_max_vertices - if nucleus_poly is not None: - nuc_coords, nuc_nv = _normalize_polygon_vertices(nucleus_poly, polygon_max_vertices) - else: - nuc_coords, nuc_nv = zero_vertices, 0 - - uint_cell_id = cell_incremental_id + 1 - cell_id2old_id[uint_cell_id] = seg_cell_id - cell_id.append(uint_cell_id) - - # Compute z-level if available - z_level = 0.0 - if z_column is not None and z_column in seg_cell.columns: - z_level = (seg_cell[z_column].mean() // 3) * 3 - - cell_centroid = cell_poly.centroid - nucleus_centroid = nucleus_poly.centroid if nucleus_poly is not None else None - - cell_summary_rows.append([ - float(cell_centroid.x), - float(cell_centroid.y), - float(cell_poly.area), - float(nucleus_centroid.x) if nucleus_centroid is not None else 0.0, - float(nucleus_centroid.y) if nucleus_centroid is not None else 0.0, - float(nucleus_poly.area) if nucleus_poly is not None else 0.0, - float(z_level), - float(1 if nucleus_poly is not None else 0), - ]) - - cell_num_vertices.append(cell_nv) - nucleus_num_vertices.append(nuc_nv) - cell_vertices.append(cell_coords) - nucleus_vertices.append(nuc_coords) - - if len(cell_id) == 0: - raise ValueError("No valid cells found in segmentation data.") - - n_cells = len(cell_id) - cell_vertices_arr = np.array(cell_vertices, dtype=np.float32) - nucleus_vertices_arr = np.array(nucleus_vertices, dtype=np.float32) - cell_vertices_flat = cell_vertices_arr.reshape(n_cells, -1) - nucleus_vertices_flat = nucleus_vertices_arr.reshape(n_cells, -1) - - # Open source store and create new store - source_zarr_store = ZipStore(source_path / "cells.zarr.zip", mode="r") - existing_store = zarr.open(source_zarr_store, mode="r") - new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") - - # Root datasets - cell_id_arr = np.zeros((n_cells, 2), dtype=np.uint32) - cell_id_arr[:, 1] = np.array(cell_id, dtype=np.uint32) - new_store["cell_id"] = cell_id_arr - new_store["cell_summary"] = np.array(cell_summary_rows, dtype=np.float64) - - # Polygon sets - polygon_group = new_store.create_group("polygon_sets") - - # Nucleus polygons (set 0) - set0 = polygon_group.create_group("0") - set0["cell_index"] = np.array(cell_id, dtype=np.uint32) - set0["method"] = np.zeros(n_cells, dtype=np.uint32) - set0["num_vertices"] = np.array(nucleus_num_vertices, dtype=np.int32) - set0["vertices"] = nucleus_vertices_flat.astype(np.float32) - - # Cell polygons (set 1) - set1 = polygon_group.create_group("1") - set1["cell_index"] = np.array(cell_id, dtype=np.uint32) - set1["method"] = np.full(n_cells, 1, dtype=np.uint32) - set1["num_vertices"] = np.array(cell_num_vertices, dtype=np.int32) - set1["vertices"] = cell_vertices_flat.astype(np.float32) - - # Update attributes - attrs = dict(existing_store.attrs) - attrs["number_cells"] = n_cells - attrs["polygon_set_names"] = ["nucleus", "cell"] - attrs["polygon_set_display_names"] = ["Nucleus", "Cell"] - attrs["polygon_set_descriptions"] = [ - "Segger nucleus boundaries", - "Segger cell boundaries", - ] - cell_method = f"segger_cell_{boundary_method}" - nucleus_method = "segger_nucleus_convex_hull" - if boundary_method == "input" and nucleus_boundaries: - nucleus_method = "segger_nucleus_input" - attrs["segmentation_methods"] = [nucleus_method, 
cell_method] - attrs.setdefault("spatial_units", "microns") - attrs.setdefault("major_version", 4) - attrs.setdefault("minor_version", 0) - new_store.attrs.update(attrs) - - new_store.store.close() - source_zarr_store.close() - - # Create analysis data - if analysis_df is None: - analysis_df = pd.DataFrame( - [cell_id2old_id[i] for i in cell_id], columns=[cell_id_column] - ) - analysis_df["default"] = "segger" - - zarr_df = pd.DataFrame( - [cell_id2old_id[i] for i in cell_id], columns=[cell_id_column] - ) - clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_column) - clusters_names = [col for col in analysis_df.columns if col != cell_id_column] - - clusters_dict = { - cluster: { - label: idx + 1 - for idx, label in enumerate(sorted(np.unique(clustering_df[cluster].dropna()))) - } - for cluster in clusters_names - } - - new_zarr = zarr.open(storage / f"{analysis_filename}.zarr.zip", mode="w") - new_zarr.create_group("/cell_groups") - - for i, cluster in enumerate(clusters_names): - new_zarr["cell_groups"].create_group(str(i)) - group_values = [clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster]] - indices, indptr = get_indices_indptr(np.array(group_values)) - new_zarr["cell_groups"][str(i)]["indices"] = indices - new_zarr["cell_groups"][str(i)]["indptr"] = indptr - - new_zarr["cell_groups"].attrs.update({ - "major_version": 1, - "minor_version": 0, - "number_groupings": len(clusters_names), - "grouping_names": clusters_names, - "group_names": [ - sorted(clusters_dict[cluster], key=clusters_dict[cluster].get) - for cluster in clusters_names - ], - }) - new_zarr.store.close() - - generate_experiment_file( - template_path=source_path / "experiment.xenium", - output_path=storage / xenium_filename, - cells_name=cells_filename, - analysis_name=analysis_filename, - ) - - -def _process_one_cell(args: tuple) -> Optional[dict]: - """Process a single cell for parallel boundary generation.""" - ( - seg_cell_id, - seg_cell, - x_col, - y_col, - z_col, - nucleus_column, - nucleus_value, - area_low, - area_high, - polygon_max_vertices, - boundary_method, - boundary_voxel_size, - ) = args - - if len(seg_cell) < 5: - return None - - cell_poly = _safe_boundary_polygon( - seg_cell, - x=x_col, - y=y_col, - boundary_method=boundary_method, - boundary_voxel_size=boundary_voxel_size, - ) - if cell_poly is None or not (area_low <= cell_poly.area <= area_high): - return None - - cell_vertices, cell_nv = _normalize_polygon_vertices(cell_poly, polygon_max_vertices) - if cell_nv == 0: - return None - - # Nucleus polygon (optional) - nucleus_poly = None - if nucleus_column is not None and nucleus_column in seg_cell.columns: - seg_nucleus = seg_cell[seg_cell[nucleus_column] == nucleus_value] - if len(seg_nucleus) >= 3: - nucleus_poly = MultiPoint(seg_nucleus[[x_col, y_col]].values).convex_hull - if isinstance(nucleus_poly, MultiPolygon): - nucleus_poly = extract_largest_polygon(nucleus_poly) - if not isinstance(nucleus_poly, Polygon) or nucleus_poly.is_empty: - nucleus_poly = None - - if nucleus_poly is not None: - nucleus_vertices, nucleus_nv = _normalize_polygon_vertices( - nucleus_poly, polygon_max_vertices - ) - else: - nucleus_vertices = [(0.0, 0.0)] * polygon_max_vertices - nucleus_nv = 0 - - # Compute z-level if available - z_level = 0.0 - if z_col is not None and z_col in seg_cell.columns: - z_level = (seg_cell[z_col].mean() // 3) * 3 - - cell_centroid = cell_poly.centroid - nucleus_centroid = nucleus_poly.centroid if nucleus_poly is not None else None - - return { - "seg_cell_id": 
seg_cell_id, - "cell_area": float(cell_poly.area), - "cell_vertices": cell_vertices, - "cell_num_vertices": cell_nv, - "nucleus_vertices": nucleus_vertices, - "nucleus_num_vertices": nucleus_nv, - "cell_centroid_x": float(cell_centroid.x), - "cell_centroid_y": float(cell_centroid.y), - "nucleus_centroid_x": float(nucleus_centroid.x) if nucleus_centroid else 0.0, - "nucleus_centroid_y": float(nucleus_centroid.y) if nucleus_centroid else 0.0, - "nucleus_area": float(nucleus_poly.area) if nucleus_poly is not None else 0.0, - "z_level": float(z_level), - "nucleus_count": float(1 if nucleus_poly is not None else 0), - } - - -def seg2explorer_pqdm( - seg_df: Union[pd.DataFrame, pl.DataFrame], - source_path: Union[str, Path], - output_dir: Union[str, Path], - cells_filename: str = "seg_cells", - analysis_filename: str = "seg_analysis", - xenium_filename: str = "seg_experiment.xenium", - analysis_df: Optional[pd.DataFrame] = None, - cell_id_column: str = "seg_cell_id", - x_column: str = "x", - y_column: str = "y", - z_column: Optional[str] = "z", - nucleus_column: Optional[str] = "cell_compartment", - nucleus_value: int = 2, - area_low: float = 10, - area_high: float = 100, - n_jobs: int = 1, - polygon_max_vertices: int = 13, - boundary_method: str = "delaunay", - boundary_voxel_size: float = 0.0, - boundaries: Optional["gpd.GeoDataFrame"] = None, - boundary_id_column: str = "cell_id", - boundary_type_column: str = "boundary_type", - boundary_cell_value: str = "cell", - boundary_nucleus_value: str = "nucleus", - cell_id_columns: Optional[str] = None, -) -> None: - """Parallelized version of seg2explorer using pqdm. - - Parameters - ---------- - seg_df : Union[pd.DataFrame, pl.DataFrame] - Segmented transcript DataFrame. - source_path : Union[str, Path] - Path to source Xenium data. - output_dir : Union[str, Path] - Output directory. - cells_filename : str - Cells Zarr filename prefix. - analysis_filename : str - Analysis Zarr filename prefix. - xenium_filename : str - Experiment manifest filename. - analysis_df : Optional[pd.DataFrame] - Optional clustering annotations. - cell_id_column : str - Cell ID column name. - x_column : str - X coordinate column name. - y_column : str - Y coordinate column name. - z_column : Optional[str] - Z coordinate column name (if available). - nucleus_column : Optional[str] - Column name for nucleus/compartment assignment. - nucleus_value : int - Value indicating nuclear compartment. - area_low : float - Minimum cell area. - area_high : float - Maximum cell area. - n_jobs : int - Number of parallel workers. - polygon_max_vertices : int - Maximum number of vertices per polygon (including closure). - """ - if cell_id_columns is not None: - cell_id_column = cell_id_columns - - if boundary_method == "skip": - raise ValueError("boundary_method='skip' is not supported for Xenium export.") - if boundary_method == "input" and boundaries is not None: - raise ValueError( - "Parallel Xenium export does not support boundary_method='input'. " - "Use seg2explorer (serial) when passing input boundaries." 
- ) - if boundary_method == "input": - boundary_method = "delaunay" - - # Convert Polars to pandas - if isinstance(seg_df, pl.DataFrame): - seg_df = seg_df.to_pandas() - - source_path = Path(source_path) - storage = Path(output_dir) - storage.mkdir(parents=True, exist_ok=True) - - grouped_by = seg_df.groupby(cell_id_column) - - def _work_iter(): - return ( - ( - seg_cell_id, - seg_cell, - x_column, - y_column, - z_column, - nucleus_column, - nucleus_value, - area_low, - area_high, - polygon_max_vertices, - boundary_method, - boundary_voxel_size, - ) - for seg_cell_id, seg_cell in grouped_by - ) - - # Process backend first for throughput and "whole job" progress visibility. - # If the process pool crashes, restart once with thread workers. - try: - results = pqdm_processes( - _work_iter(), - _process_one_cell, - n_jobs=n_jobs, - desc="Processing cells", - exception_behaviour="immediate", - ) - except BrokenProcessPool: - if pqdm_threads is None: - raise RuntimeError( - "Process workers crashed and pqdm thread backend is unavailable." - ) - tqdm.write( - "Warning: process workers crashed during Xenium export. " - "Retrying with thread workers from 0% (completed process results " - "cannot be recovered by pqdm)." - ) - results = pqdm_threads( - _work_iter(), - _process_one_cell, - n_jobs=n_jobs, - desc="Processing cells (thread fallback)", - exception_behaviour="immediate", - ) - - # Collate results - cell_id2old_id: Dict[int, Any] = {} - cell_id: List[int] = [] - cell_num_vertices: List[int] = [] - nucleus_num_vertices: List[int] = [] - cell_vertices: List[List[Any]] = [] - nucleus_vertices: List[List[Any]] = [] - cell_summary_rows: List[List[float]] = [] - - kept = [r for r in results if r is not None] - for cell_incremental_id, r in enumerate(kept): - uint_cell_id = cell_incremental_id + 1 - cell_id2old_id[uint_cell_id] = r["seg_cell_id"] - cell_id.append(uint_cell_id) - cell_num_vertices.append(r["cell_num_vertices"]) - nucleus_num_vertices.append(r["nucleus_num_vertices"]) - cell_vertices.append(r["cell_vertices"]) - nucleus_vertices.append(r["nucleus_vertices"]) - cell_summary_rows.append([ - r["cell_centroid_x"], - r["cell_centroid_y"], - r["cell_area"], - r["nucleus_centroid_x"], - r["nucleus_centroid_y"], - r["nucleus_area"], - r["z_level"], - r["nucleus_count"], - ]) - - if len(cell_id) == 0: - raise ValueError("No valid cells found in segmentation data.") - - n_cells = len(cell_id) - cell_vertices_arr = np.array(cell_vertices, dtype=np.float32) - nucleus_vertices_arr = np.array(nucleus_vertices, dtype=np.float32) - cell_vertices_flat = cell_vertices_arr.reshape(n_cells, -1) - nucleus_vertices_flat = nucleus_vertices_arr.reshape(n_cells, -1) - - # Open source and create new store - source_zarr_store = ZipStore(source_path / "cells.zarr.zip", mode="r") - existing_store = zarr.open(source_zarr_store, mode="r") - new_store = zarr.open(storage / f"{cells_filename}.zarr.zip", mode="w") - - # Root datasets - cell_id_arr = np.zeros((n_cells, 2), dtype=np.uint32) - cell_id_arr[:, 1] = np.array(cell_id, dtype=np.uint32) - new_store["cell_id"] = cell_id_arr - new_store["cell_summary"] = np.array(cell_summary_rows, dtype=np.float64) - - polygon_group = new_store.create_group("polygon_sets") - - # Nucleus polygons (set 0) - set0 = polygon_group.create_group("0") - set0["cell_index"] = np.array(cell_id, dtype=np.uint32) - set0["method"] = np.zeros(n_cells, dtype=np.uint32) - set0["num_vertices"] = np.array(nucleus_num_vertices, dtype=np.int32) - set0["vertices"] = 
nucleus_vertices_flat.astype(np.float32) - - # Cell polygons (set 1) - set1 = polygon_group.create_group("1") - set1["cell_index"] = np.array(cell_id, dtype=np.uint32) - set1["method"] = np.full(n_cells, 1, dtype=np.uint32) - set1["num_vertices"] = np.array(cell_num_vertices, dtype=np.int32) - set1["vertices"] = cell_vertices_flat.astype(np.float32) - - attrs = dict(existing_store.attrs) - attrs["number_cells"] = n_cells - attrs["polygon_set_names"] = ["nucleus", "cell"] - attrs["polygon_set_display_names"] = ["Nucleus", "Cell"] - attrs["polygon_set_descriptions"] = [ - "Segger nucleus boundaries", - "Segger cell boundaries", - ] - attrs["segmentation_methods"] = ["segger_nucleus_convex_hull", f"segger_cell_{boundary_method}"] - attrs.setdefault("spatial_units", "microns") - attrs.setdefault("major_version", 4) - attrs.setdefault("minor_version", 0) - new_store.attrs.update(attrs) - new_store.store.close() - source_zarr_store.close() - - # Create analysis data - if analysis_df is None: - analysis_df = pd.DataFrame( - [cell_id2old_id[i] for i in cell_id], columns=[cell_id_column] - ) - analysis_df["default"] = "segger" - - zarr_df = pd.DataFrame( - [cell_id2old_id[i] for i in cell_id], columns=[cell_id_column] - ) - clustering_df = pd.merge(zarr_df, analysis_df, how="left", on=cell_id_column) - clusters_names = [col for col in analysis_df.columns if col != cell_id_column] - - clusters_dict = { - cluster: { - label: idx + 1 - for idx, label in enumerate(sorted(np.unique(clustering_df[cluster].dropna()))) - } - for cluster in clusters_names - } - - new_zarr = zarr.open(storage / f"{analysis_filename}.zarr.zip", mode="w") - new_zarr.create_group("/cell_groups") - - for i, cluster in enumerate(clusters_names): - new_zarr["cell_groups"].create_group(str(i)) - group_values = [clusters_dict[cluster].get(x, 0) for x in clustering_df[cluster]] - indices, indptr = get_indices_indptr(np.array(group_values)) - new_zarr["cell_groups"][str(i)]["indices"] = indices - new_zarr["cell_groups"][str(i)]["indptr"] = indptr - - new_zarr["cell_groups"].attrs.update({ - "major_version": 1, - "minor_version": 0, - "number_groupings": len(clusters_names), - "grouping_names": clusters_names, - "group_names": [ - sorted(clusters_dict[cluster], key=clusters_dict[cluster].get) - for cluster in clusters_names - ], - }) - new_zarr.store.close() - - generate_experiment_file( - template_path=source_path / "experiment.xenium", - output_path=storage / xenium_filename, - cells_name=cells_filename, - analysis_name=analysis_filename, - ) diff --git a/src/segger/models/alignment_loss.py b/src/segger/models/alignment_loss.py deleted file mode 100644 index b7e9a9e..0000000 --- a/src/segger/models/alignment_loss.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Alignment loss for mutually exclusive gene constraints. - -This module implements alignment loss using ME gene pairs (negatives) and -same-gene transcript neighbors (positives). Other tx-tx edges are ignored -for the alignment objective. 
-""" - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class AlignmentLoss(nn.Module): - """Contrastive loss for ME-gene constraints.""" - - def __init__( - self, - weight_start: float = 0.0, - weight_end: float = 0.1, - ): - super().__init__() - self.weight_start = weight_start - self.weight_end = weight_end - self._margin = 0.2 - - def get_scheduled_weight( - self, - current_epoch: int, - max_epochs: int, - ) -> float: - """Cosine schedule between start/end weights.""" - max_epochs = max(1, max_epochs - 1) - t = min(current_epoch, max_epochs) / max_epochs - alpha = 0.5 * (1.0 + math.cos(math.pi * t)) - return self.weight_end + (self.weight_start - self.weight_end) * alpha - - def forward( - self, - embeddings_src: torch.Tensor, - embeddings_dst: torch.Tensor, - labels: torch.Tensor, - ) -> torch.Tensor: - """Compute alignment loss for transcript-transcript edges.""" - sim = (embeddings_src * embeddings_dst).sum(dim=-1) - labels = labels.float() - - pos_mask = labels > 0.5 - neg_mask = ~pos_mask - - loss = torch.tensor(0.0, device=sim.device) - if pos_mask.any(): - pos_loss = (1.0 - sim[pos_mask]) ** 2 - loss = loss + pos_loss.mean() - if neg_mask.any(): - neg_loss = F.relu(sim[neg_mask] - self._margin) ** 2 - loss = loss + neg_loss.mean() - - return loss - - -def compute_me_gene_edges( - gene_indices: torch.Tensor, - me_gene_pairs: torch.Tensor, - edge_index: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """Create tx-tx alignment edges: same-gene positives + ME negatives.""" - src, dst = edge_index - src_genes = gene_indices[src] - dst_genes = gene_indices[dst] - - pos_mask = src_genes == dst_genes - - neg_mask = torch.zeros_like(pos_mask, dtype=torch.bool) - if me_gene_pairs.numel() > 0 and src_genes.numel() > 0: - me_genes = torch.unique(me_gene_pairs.flatten()) - in_me = torch.isin(src_genes, me_genes) & torch.isin(dst_genes, me_genes) - if in_me.any(): - pair_min = torch.minimum(me_gene_pairs[:, 0], me_gene_pairs[:, 1]) - pair_max = torch.maximum(me_gene_pairs[:, 0], me_gene_pairs[:, 1]) - max_gene = max( - src_genes.max().item() if src_genes.numel() > 0 else 0, - dst_genes.max().item() if dst_genes.numel() > 0 else 0, - pair_max.max().item() if pair_max.numel() > 0 else 0, - ) + 1 - me_pair_keys = pair_min * max_gene + pair_max - - edge_min = torch.minimum(src_genes[in_me], dst_genes[in_me]) - edge_max = torch.maximum(src_genes[in_me], dst_genes[in_me]) - edge_pair_keys = edge_min * max_gene + edge_max - is_me = torch.isin(edge_pair_keys, me_pair_keys) - neg_mask[in_me] = is_me - - n_pos = int(pos_mask.sum().item()) - n_neg = int(neg_mask.sum().item()) - if n_neg == 0 and n_pos == 0: - return edge_index[:, :0], torch.empty((0,), device=edge_index.device) - if n_neg == 0: - return edge_index[:, :0], torch.empty((0,), device=edge_index.device) - - max_pos = 3 * n_neg - if n_pos > max_pos: - pos_idx = pos_mask.nonzero().flatten() - pos_idx = pos_idx[ - torch.randperm(n_pos, device=pos_idx.device)[:max_pos] - ] - keep = torch.zeros_like(pos_mask, dtype=torch.bool) - keep[pos_idx] = True - keep |= neg_mask - else: - keep = pos_mask | neg_mask - - if not keep.any(): - return edge_index[:, :0], torch.empty((0,), device=edge_index.device) - - labels = torch.zeros(keep.sum().item(), device=edge_index.device) - labels[pos_mask[keep]] = 1.0 - return edge_index[:, keep], labels diff --git a/src/segger/validation/__init__.py b/src/segger/validation/__init__.py deleted file mode 100644 index ecde6fb..0000000 --- 
a/src/segger/validation/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from .me_genes import load_me_genes_from_scrna -from .quick_metrics import ( - count_cells_from_anndata, - compute_assignment_metrics, - compute_border_contamination_fast, - compute_mecr_fast, - compute_resolvi_contamination_fast, - compute_signal_doublet_fast, - compute_transcript_centroid_offset_fast, - load_me_gene_pairs, - load_segmentation, - load_source_transcripts, - merge_assigned_transcripts, -) diff --git a/src/segger/validation/me_genes.py b/src/segger/validation/me_genes.py deleted file mode 100644 index 8e64585..0000000 --- a/src/segger/validation/me_genes.py +++ /dev/null @@ -1,421 +0,0 @@ -"""Mutually exclusive gene discovery from scRNA-seq reference. - -This module provides functions to identify mutually exclusive (ME) gene pairs -from single-cell RNA-seq reference data. ME genes are markers that are highly -expressed in one cell type but not co-expressed in the same cell, making them -useful constraints for cell segmentation. - -Ported from segger v0.1.0 validation/utils.py. -""" - -from typing import Dict, List, Tuple, Optional -from pathlib import Path -import warnings -import json -import hashlib -import time -import os -import numpy as np -import anndata as ad -import scanpy as sc -import pandas as pd -from itertools import combinations - - -def find_markers( - adata: ad.AnnData, - cell_type_column: str, - pos_percentile: float = 10, - neg_percentile: float = 10, - percentage: float = 30, -) -> Dict[str, Dict[str, List[str]]]: - """Identify positive and negative markers for each cell type. - - Parameters - ---------- - adata : ad.AnnData - Annotated data object containing gene expression data. - cell_type_column : str - Column name in `adata.obs` that specifies cell types. - pos_percentile : float, optional - Percentile threshold for top highly expressed genes (default: 10). - neg_percentile : float, optional - Percentile threshold for top lowly expressed genes (default: 10). - percentage : float, optional - Minimum percentage of cells expressing the marker (default: 30). - - Returns - ------- - dict - Dictionary where keys are cell types and values contain: - 'positive': list of highly expressed genes - 'negative': list of lowly expressed genes - """ - markers = {} - sc.tl.rank_genes_groups(adata, groupby=cell_type_column) - genes = adata.var_names - - for cell_type in adata.obs[cell_type_column].unique(): - subset = adata[adata.obs[cell_type_column] == cell_type] - mean_expression = np.asarray(subset.X.mean(axis=0)).flatten() - - cutoff_high = np.percentile(mean_expression, 100 - pos_percentile) - cutoff_low = np.percentile(mean_expression, neg_percentile) - - pos_indices = np.where(mean_expression >= cutoff_high)[0] - neg_indices = np.where(mean_expression <= cutoff_low)[0] - - # Filter by expression percentage - expr_frac = np.asarray((subset.X[:, pos_indices] > 0).mean(axis=0)).flatten() - valid_pos_indices = pos_indices[expr_frac >= (percentage / 100)] - - positive_markers = genes[valid_pos_indices] - negative_markers = genes[neg_indices] - - markers[cell_type] = { - "positive": list(positive_markers), - "negative": list(negative_markers), - } - - return markers - - -def find_mutually_exclusive_genes( - adata: ad.AnnData, - markers: Dict[str, Dict[str, List[str]]], - cell_type_column: str, - expr_threshold_in: float = 0.25, - expr_threshold_out: float = 0.03, -) -> List[Tuple[str, str]]: - """Identify mutually exclusive genes based on expression criteria. 
- - A gene is considered ME if it's expressed in >expr_threshold_in of its - cell type but in 0).mean() - expr_out = (gene_expr[non_cell_type_mask] > 0).mean() - - if expr_in > expr_threshold_in and expr_out < expr_threshold_out: - exclusive_genes[cell_type].append(gene) - all_exclusive.append(gene) - - # Get unique exclusive genes - unique_genes = list(set(all_exclusive)) - filtered_exclusive_genes = { - ct: [g for g in genes if g in unique_genes] - for ct, genes in exclusive_genes.items() - } - - # Create pairs from different cell types - mutually_exclusive_gene_pairs = [ - (gene1, gene2) - for key1, key2 in combinations(filtered_exclusive_genes.keys(), 2) - for gene1 in filtered_exclusive_genes[key1] - for gene2 in filtered_exclusive_genes[key2] - ] - - return mutually_exclusive_gene_pairs - - -def compute_MECR( - adata: ad.AnnData, - gene_pairs: List[Tuple[str, str]], -) -> Dict[Tuple[str, str], float]: - """Compute Mutually Exclusive Co-expression Rate (MECR) for gene pairs. - - MECR = (both expressed) / (at least one expressed) - Lower MECR indicates better mutual exclusivity. - - Parameters - ---------- - adata : ad.AnnData - Annotated data object containing gene expression data. - gene_pairs : list - List of gene pairs to evaluate. - - Returns - ------- - dict - Dictionary mapping gene pairs to MECR values. - """ - mecr_dict = {} - gene_expression = adata.to_df() - - for gene1, gene2 in gene_pairs: - if gene1 not in gene_expression.columns or gene2 not in gene_expression.columns: - continue - - expr_gene1 = gene_expression[gene1] > 0 - expr_gene2 = gene_expression[gene2] > 0 - - both_expressed = (expr_gene1 & expr_gene2).mean() - at_least_one_expressed = (expr_gene1 | expr_gene2).mean() - - mecr = ( - both_expressed / at_least_one_expressed - if at_least_one_expressed > 0 - else 0 - ) - mecr_dict[(gene1, gene2)] = mecr - - return mecr_dict - - -def load_me_genes_from_scrna( - scrna_path: Path, - cell_type_column: str = "celltype", - gene_name_column: Optional[str] = None, - pos_percentile: float = 10, - neg_percentile: float = 10, - percentage: float = 30, - expr_threshold_in: float = 0.25, - expr_threshold_out: float = 0.03, -) -> Tuple[List[Tuple[str, str]], Dict[str, Dict[str, List[str]]]]: - """Load scRNA-seq reference and compute ME gene pairs. - - Parameters - ---------- - scrna_path : Path - Path to scRNA-seq reference h5ad file. - cell_type_column : str, optional - Column name for cell type annotations (default: "celltype"). - gene_name_column : str | None, optional - Column in var for gene names. If None, uses var_names. - pos_percentile : float, optional - Percentile for positive markers (default: 10). - neg_percentile : float, optional - Percentile for negative markers (default: 10). - percentage : float, optional - Minimum expression percentage (default: 30). - expr_threshold_in : float, optional - Minimum expression in own cell type (default: 0.25). - expr_threshold_out : float, optional - Maximum expression in other cell types (default: 0.03). - Notes - ----- - For performance, cells are subsampled to at most 1000 per cell type. - - Returns - ------- - tuple - (me_gene_pairs, markers) where me_gene_pairs is a list of - (gene1, gene2) tuples and markers is the full marker dictionary. 
- """ - verbose = os.getenv("SEGGER_ME_VERBOSE", "").lower() not in {"0", "false", "no", "off"} - # Cache to avoid repeated expensive ME discovery - cache_key = _me_cache_key( - scrna_path=scrna_path, - cell_type_column=cell_type_column, - gene_name_column=gene_name_column, - pos_percentile=pos_percentile, - neg_percentile=neg_percentile, - percentage=percentage, - expr_threshold_in=expr_threshold_in, - expr_threshold_out=expr_threshold_out, - ) - cache_path = _me_cache_path(scrna_path, cache_key) - if cache_path.exists(): - try: - with open(cache_path, "r") as f: - cached = json.load(f) - if cached.get("key") == cache_key: - pairs = [ - (p[0], p[1]) - for p in cached.get("me_gene_pairs", []) - if len(p) == 2 - ] - markers = cached.get("markers", {}) - if verbose: - print( - f"[segger][me] cache hit: {len(pairs)} pairs", - flush=True, - ) - return pairs, markers - except Exception: - pass - - t0 = time.monotonic() - if verbose: - print( - "[segger][me] computing ME gene pairs (this can take a while)...", - flush=True, - ) - - # Load scRNA-seq data - adata = sc.read_h5ad(scrna_path) - - # Subsample cells per cell type to limit runtime - if cell_type_column in adata.obs: - rng = np.random.default_rng(0) - idx = [] - for ct in adata.obs[cell_type_column].unique(): - ct_idx = np.where(adata.obs[cell_type_column] == ct)[0] - if ct_idx.size > _ME_MAX_CELLS_PER_TYPE: - ct_idx = rng.choice( - ct_idx, - size=_ME_MAX_CELLS_PER_TYPE, - replace=False, - ) - idx.append(ct_idx) - if idx: - idx = np.concatenate(idx) - adata = adata[idx].copy() - - # Ensure unique var names and log-normalize if needed - if not adata.var_names.is_unique: - adata.var_names_make_unique() - if "log1p" not in adata.uns: - sc.pp.normalize_total(adata, target_sum=1e4) - sc.pp.log1p(adata) - - # Optionally remap gene names - if gene_name_column is not None and gene_name_column in adata.var.columns: - adata.var_names = adata.var[gene_name_column] - - # Find markers - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - category=pd.errors.PerformanceWarning, - ) - markers = find_markers( - adata, - cell_type_column=cell_type_column, - pos_percentile=pos_percentile, - neg_percentile=neg_percentile, - percentage=percentage, - ) - - # Find ME gene pairs - me_gene_pairs = find_mutually_exclusive_genes( - adata, - markers, - cell_type_column=cell_type_column, - expr_threshold_in=expr_threshold_in, - expr_threshold_out=expr_threshold_out, - ) - - if verbose: - n_types = adata.obs[cell_type_column].nunique() - elapsed = time.monotonic() - t0 - print( - f"[segger][me] done: {len(me_gene_pairs)} pairs " - f"across {n_types} cell types in {elapsed:.1f}s", - flush=True, - ) - - # Write cache (best-effort) - try: - payload = { - "key": cache_key, - "me_gene_pairs": [list(p) for p in me_gene_pairs], - "markers": markers, - } - with open(cache_path, "w") as f: - json.dump(payload, f) - except Exception: - pass - - return me_gene_pairs, markers - - -def me_gene_pairs_to_indices( - me_gene_pairs: List[Tuple[str, str]], - gene_names: List[str], -) -> List[Tuple[int, int]]: - """Convert gene name pairs to index pairs. - - Parameters - ---------- - me_gene_pairs : list - List of (gene1, gene2) name tuples. - gene_names : list - List of gene names in order (index corresponds to token). - - Returns - ------- - list - List of (idx1, idx2) index tuples. 
- """ - gene_to_idx = {name: idx for idx, name in enumerate(gene_names)} - - index_pairs = [] - for gene1, gene2 in me_gene_pairs: - if gene1 in gene_to_idx and gene2 in gene_to_idx: - index_pairs.append((gene_to_idx[gene1], gene_to_idx[gene2])) - - return index_pairs -_ME_CACHE_VERSION = 2 -_ME_MAX_CELLS_PER_TYPE = 1000 - - -def _me_cache_key( - scrna_path: Path, - cell_type_column: str, - gene_name_column: Optional[str], - pos_percentile: float, - neg_percentile: float, - percentage: float, - expr_threshold_in: float, - expr_threshold_out: float, -) -> str: - """Create a stable cache key for ME gene discovery inputs.""" - st = scrna_path.stat() - payload = { - "version": _ME_CACHE_VERSION, - "path": str(scrna_path.resolve()), - "size": st.st_size, - "mtime_ns": st.st_mtime_ns, - "cell_type_column": cell_type_column, - "gene_name_column": gene_name_column, - "pos_percentile": pos_percentile, - "neg_percentile": neg_percentile, - "percentage": percentage, - "expr_threshold_in": expr_threshold_in, - "expr_threshold_out": expr_threshold_out, - "max_cells_per_type": _ME_MAX_CELLS_PER_TYPE, - } - raw = json.dumps(payload, sort_keys=True).encode("utf-8") - return hashlib.sha256(raw).hexdigest()[:16] - - -def _me_cache_path(scrna_path: Path, key: str) -> Path: - """Cache file path for ME gene discovery outputs.""" - return Path(f"{scrna_path}.segger_me_cache.{key}.json") diff --git a/src/segger/validation/quick_metrics.py b/src/segger/validation/quick_metrics.py deleted file mode 100644 index 1eeb1f2..0000000 --- a/src/segger/validation/quick_metrics.py +++ /dev/null @@ -1,1050 +0,0 @@ -"""Lightweight validation metrics for Segger outputs. - -This module provides fast, reference-light metrics intended for quick model -selection and single-run quality checks. 
-""" - -from __future__ import annotations - -from pathlib import Path -from typing import Optional, Sequence - -import anndata as ad -import numpy as np -import polars as pl -from scipy import sparse -from scipy.spatial import cKDTree - -from ..io import StandardTranscriptFields, get_preprocessor -from .me_genes import load_me_genes_from_scrna - - -def assigned_cell_expr(cell_id_column: str = "segger_cell_id") -> pl.Expr: - """Expression selecting transcripts assigned to a valid cell.""" - cell = pl.col(cell_id_column) - cell_str = cell.cast(pl.Utf8) - return ( - cell.is_not_null() - & (cell_str != "-1") - & (cell_str.str.to_uppercase() != "UNASSIGNED") - & (cell_str.str.to_uppercase() != "NONE") - ) - - -def _effective_sample_size(weights: np.ndarray) -> float: - """Kish effective sample size for non-negative weights.""" - w = np.asarray(weights, dtype=np.float64) - w = w[np.isfinite(w) & (w > 0)] - if w.size == 0: - return float("nan") - sw = float(w.sum()) - sw2 = float(np.square(w).sum()) - if sw <= 0 or sw2 <= 0: - return float("nan") - return (sw * sw) / sw2 - - -def _weighted_mean_ci95(values: np.ndarray, weights: np.ndarray) -> float: - """Approximate 95% CI half-width for a weighted mean.""" - v = np.asarray(values, dtype=np.float64) - w = np.asarray(weights, dtype=np.float64) - mask = np.isfinite(v) & np.isfinite(w) & (w > 0) - if not np.any(mask): - return float("nan") - v = v[mask] - w = w[mask] - if v.size == 0: - return float("nan") - mu = float(np.average(v, weights=w)) - neff = _effective_sample_size(w) - if not np.isfinite(neff) or neff <= 1: - return float("nan") - var = float(np.average(np.square(v - mu), weights=w)) - se = np.sqrt(max(var, 0.0) / neff) - return float(1.96 * se) - - -def _weighted_bernoulli_ci95(flags: np.ndarray, weights: np.ndarray) -> float: - """Approximate 95% CI half-width for weighted Bernoulli proportion.""" - f = np.asarray(flags, dtype=np.float64) - w = np.asarray(weights, dtype=np.float64) - mask = np.isfinite(f) & np.isfinite(w) & (w > 0) - if not np.any(mask): - return float("nan") - f = f[mask] - w = w[mask] - if f.size == 0: - return float("nan") - p = float(np.average(f, weights=w)) - neff = _effective_sample_size(w) - if not np.isfinite(neff) or neff <= 1: - return float("nan") - se = np.sqrt(max(p * (1.0 - p), 0.0) / neff) - return float(1.96 * se) - - -def _binomial_pct_ci95(successes: int, total: int) -> float: - """95% CI half-width in percentage points for a binomial proportion.""" - if total <= 0: - return float("nan") - p = min(max(float(successes) / float(total), 0.0), 1.0) - se = np.sqrt(max(p * (1.0 - p), 0.0) / float(total)) - return float(100.0 * 1.96 * se) - - -def load_source_transcripts(source_path: Path) -> pl.DataFrame: - """Load standardized source transcripts with only needed columns.""" - tx_fields = StandardTranscriptFields() - source_path = Path(source_path) - tx = None - - # Optional SpatialData input support - try: - from ..io.spatialdata_loader import is_spatialdata_path, load_from_spatialdata - - if is_spatialdata_path(source_path): - tx_lf, _ = load_from_spatialdata( - source_path, - boundary_type="all", - normalize=True, - ) - tx = tx_lf.collect() if isinstance(tx_lf, pl.LazyFrame) else tx_lf - except Exception: - tx = None - - if tx is None: - pp = get_preprocessor(source_path, min_qv=0, include_z=True) - tx = pp.transcripts - if isinstance(tx, pl.LazyFrame): - tx = tx.collect() - - keep_cols = [ - tx_fields.row_index, - tx_fields.feature, - tx_fields.x, - tx_fields.y, - ] - if tx_fields.z in tx.columns: - 
keep_cols.append(tx_fields.z) - tx = tx.select([c for c in keep_cols if c in tx.columns]) - - if tx_fields.row_index not in tx.columns: - tx = tx.with_row_index(name=tx_fields.row_index) - - return tx - - -def load_segmentation(segmentation_path: Path) -> pl.DataFrame: - """Load segmentation parquet with canonical columns.""" - seg = pl.read_parquet(segmentation_path) - required = ["row_index", "segger_cell_id"] - missing = [c for c in required if c not in seg.columns] - if missing: - raise ValueError( - f"Segmentation file missing required columns {missing}: {segmentation_path}" - ) - return seg.select([c for c in ["row_index", "segger_cell_id", "segger_similarity"] if c in seg.columns]) - - -def compute_assignment_metrics( - seg_df: pl.DataFrame, - cell_id_column: str = "segger_cell_id", -) -> dict[str, float]: - """Compute transcript assignment coverage metrics.""" - total = int(seg_df.height) - if total == 0: - return { - "transcripts_total": 0, - "transcripts_assigned": 0, - "transcripts_assigned_pct": float("nan"), - "transcripts_assigned_pct_ci95": float("nan"), - "cells_assigned": 0, - } - - assigned_df = seg_df.filter(assigned_cell_expr(cell_id_column)) - assigned = int(assigned_df.height) - cells = int( - assigned_df.select(pl.col(cell_id_column).n_unique()).to_series().item() - if assigned > 0 - else 0 - ) - return { - "transcripts_total": total, - "transcripts_assigned": assigned, - "transcripts_assigned_pct": 100.0 * assigned / total, - "transcripts_assigned_pct_ci95": _binomial_pct_ci95(assigned, total), - "cells_assigned": cells, - } - - -def count_cells_from_anndata(anndata_path: Optional[Path]) -> Optional[int]: - """Return number of cells (n_obs) from AnnData, or None if unavailable.""" - if anndata_path is None: - return None - path = Path(anndata_path) - if not path.exists(): - return None - - adata = ad.read_h5ad(path, backed="r") - try: - return int(adata.n_obs) - finally: - try: - if getattr(adata, "isbacked", False): - adata.file.close() - except Exception: - pass - - -def merge_assigned_transcripts( - seg_df: pl.DataFrame, - source_tx: pl.DataFrame, - cell_id_column: str = "segger_cell_id", - row_index_column: str = "row_index", -) -> pl.DataFrame: - """Inner-join source transcripts with assigned segmentation rows.""" - left = source_tx - right = seg_df - - if row_index_column not in left.columns: - left = left.with_row_index(name=row_index_column) - - left = left.with_columns(pl.col(row_index_column).cast(pl.Int64)) - right = right.with_columns(pl.col(row_index_column).cast(pl.Int64)) - right = right.filter(assigned_cell_expr(cell_id_column)).select([row_index_column, cell_id_column]) - - return left.join(right, on=row_index_column, how="inner") - - -def _empty_resolvi_metrics() -> dict[str, float]: - """Default empty return payload for RESOLVI-like contamination metric.""" - return { - "resolvi_contamination_pct_fast": float("nan"), - "resolvi_contamination_ci95_fast": float("nan"), - "resolvi_contaminated_cells_pct_fast": float("nan"), - "resolvi_contaminated_cells_pct_ci95_fast": float("nan"), - "resolvi_metric_cells_used": 0, - "resolvi_shared_genes_used": 0, - "resolvi_cell_types_used": 0, - } - - -def _build_cell_gene_matrix( - assigned_tx: pl.DataFrame, - *, - cell_id_column: str, - feature_column: str, - x_column: str, - y_column: str, - min_transcripts_per_cell: int, - max_cells: int, - seed: int, -) -> tuple[sparse.csr_matrix, np.ndarray, np.ndarray, list[str]] | None: - """Build sparse cell x gene counts with centroids and per-cell weights.""" - req = 
[cell_id_column, feature_column, x_column, y_column] - for col in req: - if col not in assigned_tx.columns: - return None - - df = ( - assigned_tx.select(req) - .drop_nulls() - .with_columns( - pl.col(cell_id_column).cast(pl.Utf8), - pl.col(feature_column).cast(pl.Utf8), - ) - ) - if df.height == 0: - return None - - cell_stats = ( - df.group_by(cell_id_column) - .agg( - pl.len().alias("n_total"), - pl.col(x_column).mean().alias("cx"), - pl.col(y_column).mean().alias("cy"), - ) - .filter(pl.col("n_total") >= int(min_transcripts_per_cell)) - ) - if cell_stats.height == 0: - return None - - if max_cells > 0 and cell_stats.height > max_cells: - rng = np.random.default_rng(seed) - ids = np.asarray(cell_stats.get_column(cell_id_column).to_list(), dtype=object) - picked = rng.choice(ids, size=max_cells, replace=False).tolist() - cell_stats = cell_stats.filter(pl.col(cell_id_column).is_in(picked)) - - cell_stats = cell_stats.sort(cell_id_column).with_row_index(name="_cid") - if cell_stats.height == 0: - return None - - df = df.join(cell_stats.select([cell_id_column, "_cid"]), on=cell_id_column, how="inner") - if df.height == 0: - return None - - gene_idx = ( - df.select(feature_column) - .unique() - .sort(feature_column) - .with_row_index(name="_gid") - ) - if gene_idx.height == 0: - return None - - mapped = df.join(gene_idx, on=feature_column, how="inner") - counts = mapped.group_by(["_cid", "_gid"]).agg(pl.len().alias("_count")) - if counts.height == 0: - return None - - rows = counts.get_column("_cid").to_numpy().astype(np.int64, copy=False) - cols = counts.get_column("_gid").to_numpy().astype(np.int64, copy=False) - data = counts.get_column("_count").to_numpy().astype(np.float64, copy=False) - - X = sparse.coo_matrix( - (data, (rows, cols)), - shape=(int(cell_stats.height), int(gene_idx.height)), - ).tocsr() - - centroids = cell_stats.select(["cx", "cy"]).to_numpy().astype(np.float64, copy=False) - weights = cell_stats.get_column("n_total").to_numpy().astype(np.float64, copy=False) - gene_names = [str(g) for g in gene_idx.get_column(feature_column).to_list()] - return X, centroids, weights, gene_names - - -def _load_reference_type_profiles( - scrna_reference_path: Path, - scrna_celltype_column: str, - seg_gene_names: list[str], -) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None: - """Load per-celltype expression profiles on genes shared with segmentation.""" - if not Path(scrna_reference_path).exists(): - return None - - ref = ad.read_h5ad(scrna_reference_path) - if scrna_celltype_column not in ref.obs.columns: - return None - - labels_raw = ref.obs[scrna_celltype_column].astype(str).to_numpy() - labels = np.asarray([str(x) for x in labels_raw], dtype=object) - label_norm = np.char.lower(labels.astype(str)) - valid = ( - (labels.astype(str) != "") - & (label_norm != "nan") - & (label_norm != "none") - & (label_norm != "-1") - ) - if not np.any(valid): - return None - - ref_genes = np.asarray([str(g) for g in ref.var_names], dtype=object) - ref_gene_to_idx = {g: i for i, g in enumerate(ref_genes.tolist())} - - seg_shared_idx: list[int] = [] - ref_shared_idx: list[int] = [] - for i, g in enumerate(seg_gene_names): - j = ref_gene_to_idx.get(str(g)) - if j is None: - continue - seg_shared_idx.append(i) - ref_shared_idx.append(int(j)) - - if len(seg_shared_idx) == 0: - return None - - X_ref = ref.X - if sparse.issparse(X_ref): - X_ref = X_ref.tocsr()[valid][:, np.asarray(ref_shared_idx, dtype=np.int64)] - else: - X_ref = np.asarray(X_ref)[valid][:, np.asarray(ref_shared_idx, dtype=np.int64)] - 
- labels_valid = labels[valid].astype(str) - type_names, type_inverse = np.unique(labels_valid, return_inverse=True) - if type_names.size == 0: - return None - - n_types = int(type_names.size) - n_genes = int(len(seg_shared_idx)) - profiles = np.zeros((n_types, n_genes), dtype=np.float64) - - for t in range(n_types): - idx = np.where(type_inverse == t)[0] - if idx.size == 0: - continue - if sparse.issparse(X_ref): - sub = X_ref[idx] - profiles[t] = np.asarray(sub.mean(axis=0)).ravel() - else: - profiles[t] = np.asarray(X_ref[idx], dtype=np.float64).mean(axis=0) - - profiles = np.nan_to_num(profiles, nan=0.0, posinf=0.0, neginf=0.0) - profiles = np.maximum(profiles, 0.0) - keep = np.asarray(profiles.sum(axis=1)).ravel() > 0 - if not np.any(keep): - return None - - return ( - np.asarray(seg_shared_idx, dtype=np.int64), - profiles[keep], - type_names[keep], - ) - - -def compute_resolvi_contamination_fast( - assigned_tx: pl.DataFrame, - *, - scrna_reference_path: Optional[Path], - scrna_celltype_column: str = "cell_type", - cell_id_column: str = "segger_cell_id", - feature_column: str = "feature_name", - x_column: str = "x", - y_column: str = "y", - min_transcripts_per_cell: int = 20, - max_cells: int = 3000, - k_neighbors: int = 10, - max_neighbor_distance: float = 20.0, - alpha_self: float = 0.8, - alpha_neighbor: float = 0.175, - alpha_background: float = 0.025, - contam_cutoff: float = 0.5, - seed: int = 0, -) -> dict[str, float]: - """Fast RESOLVI-style contamination estimate (lower is better). - - This approximates the RESOLVI neighborhood contamination formulation on a - sampled subset of segmented cells by: - 1) deriving host cell types from scRNA reference profile similarity, - 2) mixing self/neighbor/background expected expression per gene, - 3) flagging counts as contaminated when q_self < contam_cutoff. 
- """ - out = _empty_resolvi_metrics() - if scrna_reference_path is None: - return out - if alpha_self < 0 or alpha_neighbor < 0 or alpha_background < 0: - return out - if alpha_self + alpha_neighbor + alpha_background <= 0: - return out - - try: - built = _build_cell_gene_matrix( - assigned_tx, - cell_id_column=cell_id_column, - feature_column=feature_column, - x_column=x_column, - y_column=y_column, - min_transcripts_per_cell=min_transcripts_per_cell, - max_cells=max_cells, - seed=seed, - ) - if built is None: - return out - X, centroids, cell_weights, gene_names = built - if X.shape[0] == 0 or X.shape[1] == 0: - return out - - ref_data = _load_reference_type_profiles( - Path(scrna_reference_path), - scrna_celltype_column=scrna_celltype_column, - seg_gene_names=gene_names, - ) - if ref_data is None: - return out - seg_shared_idx, ref_profiles, type_names = ref_data - if seg_shared_idx.size == 0 or ref_profiles.size == 0: - return out - - X = X[:, seg_shared_idx] - if X.shape[1] == 0: - return out - - totals = np.asarray(X.sum(axis=1)).ravel().astype(np.float64, copy=False) - keep = np.isfinite(totals) & (totals > 0) - if not np.any(keep): - return out - if not np.all(keep): - rows_keep = np.where(keep)[0] - X = X[rows_keep] - centroids = centroids[rows_keep] - cell_weights = cell_weights[rows_keep] - totals = totals[rows_keep] - - n_cells = int(X.shape[0]) - n_types = int(ref_profiles.shape[0]) - out["resolvi_shared_genes_used"] = int(X.shape[1]) - out["resolvi_cell_types_used"] = int(type_names.size) - if n_cells == 0 or n_types == 0: - return out - - eps = 1e-9 - ref = np.asarray(ref_profiles, dtype=np.float64) - ref_norm = np.linalg.norm(ref, axis=1) - ref_norm[~np.isfinite(ref_norm) | (ref_norm <= 0)] = 1.0 - - cell_norm = np.sqrt(np.asarray(X.multiply(X).sum(axis=1)).ravel()) - cell_norm[~np.isfinite(cell_norm) | (cell_norm <= 0)] = 1.0 - - sim = X @ ref.T - if sparse.issparse(sim): - sim = sim.toarray() - sim = np.asarray(sim, dtype=np.float64) - sim /= cell_norm[:, None] - sim /= ref_norm[None, :] - host_type = np.argmax(sim, axis=1).astype(np.int64) - - neighbor_freq = np.zeros((n_cells, n_types), dtype=np.float64) - if n_cells > 1 and int(k_neighbors) > 0: - k = min(int(k_neighbors) + 1, n_cells) - tree = cKDTree(centroids) - dists, idxs = tree.query(centroids, k=k) - if k > 1: - if dists.ndim == 1: - dists = dists[:, None] - idxs = idxs[:, None] - nbr_d = dists[:, 1:] - nbr_i = idxs[:, 1:] - for i in range(n_cells): - if nbr_i.shape[1] == 0: - continue - if np.isfinite(max_neighbor_distance) and max_neighbor_distance > 0: - valid_nbr = nbr_d[i] <= float(max_neighbor_distance) - else: - valid_nbr = np.ones(nbr_i.shape[1], dtype=bool) - pick = nbr_i[i][valid_nbr] - if pick.size == 0: - continue - np.add.at(neighbor_freq[i], host_type[pick], 1.0) - s = float(neighbor_freq[i].sum()) - if s > 0: - neighbor_freq[i] /= s - - bg = np.bincount( - host_type, - weights=np.asarray(cell_weights, dtype=np.float64), - minlength=n_types, - ).astype(np.float64, copy=False) - bsum = float(bg.sum()) - if bsum > 0: - bg /= bsum - p_back = bg @ ref - - per_cell_pct = np.full(n_cells, np.nan, dtype=np.float64) - per_cell_flag = np.zeros(n_cells, dtype=np.float64) - - for i in range(n_cells): - row = X.getrow(i) - if row.nnz == 0: - continue - h = int(host_type[i]) - neigh = neighbor_freq[i].copy() - if 0 <= h < n_types: - neigh[h] = 0.0 - nsum = float(neigh.sum()) - if nsum > 0: - neigh /= nsum - p_self = ref[h] - p_neigh = neigh @ ref - denom = (alpha_self * p_self) + (alpha_neighbor * p_neigh) + 
(alpha_background * p_back) + eps - q_self = (alpha_self * p_self) / denom - q_self = np.clip(q_self, 0.0, 1.0) - - vals = row.data.astype(np.float64, copy=False) - cols = row.indices - total = float(vals.sum()) - if total <= 0: - continue - contam = float(vals[q_self[cols] < float(contam_cutoff)].sum()) - pct = 100.0 * contam / total - per_cell_pct[i] = pct - per_cell_flag[i] = 1.0 if contam > 0 else 0.0 - - valid_cells = np.isfinite(per_cell_pct) & np.isfinite(totals) & (totals > 0) - if not np.any(valid_cells): - return out - - w = totals[valid_cells] - vals = per_cell_pct[valid_cells] - flags = per_cell_flag[valid_cells] - - out["resolvi_metric_cells_used"] = int(np.sum(valid_cells)) - out["resolvi_contamination_pct_fast"] = float(np.average(vals, weights=w)) - out["resolvi_contamination_ci95_fast"] = float(_weighted_mean_ci95(vals, w)) - out["resolvi_contaminated_cells_pct_fast"] = float(100.0 * np.average(flags, weights=w)) - out["resolvi_contaminated_cells_pct_ci95_fast"] = float(100.0 * _weighted_bernoulli_ci95(flags, w)) - return out - except Exception: - return out - - -def compute_border_contamination_fast( - assigned_tx: pl.DataFrame, - *, - cell_id_column: str = "segger_cell_id", - x_column: str = "x", - y_column: str = "y", - erosion_fraction: float = 0.3, - min_transcripts_per_cell: int = 20, - max_cells: int = 3000, - contaminated_enrichment_threshold: float = 1.25, - seed: int = 0, -) -> dict[str, float]: - """Fast border-enrichment contamination proxy (lower is better). - - This approximates periphery contamination by comparing transcript density - in border vs center regions defined by an eroded bounding box. - """ - eps = 1e-9 - req = [cell_id_column, x_column, y_column] - for col in req: - if col not in assigned_tx.columns: - return { - "border_contamination_fast": float("nan"), - "border_enrichment_fast": float("nan"), - "border_excess_pct_fast": float("nan"), - "border_contaminated_cells_pct_fast": float("nan"), - "border_contaminated_cells_pct_ci95_fast": float("nan"), - "border_metric_cells_used": 0, - } - - df = assigned_tx.select(req).drop_nulls() - if df.height == 0: - return { - "border_contamination_fast": float("nan"), - "border_enrichment_fast": float("nan"), - "border_excess_pct_fast": float("nan"), - "border_contaminated_cells_pct_fast": float("nan"), - "border_contaminated_cells_pct_ci95_fast": float("nan"), - "border_metric_cells_used": 0, - } - - cell_stats = ( - df.group_by(cell_id_column) - .agg( - pl.len().alias("n_total"), - pl.col(x_column).min().alias("x_min"), - pl.col(x_column).max().alias("x_max"), - pl.col(y_column).min().alias("y_min"), - pl.col(y_column).max().alias("y_max"), - ) - .with_columns( - (pl.col("x_max") - pl.col("x_min")).alias("width"), - (pl.col("y_max") - pl.col("y_min")).alias("height"), - ) - .with_columns( - pl.min_horizontal("width", "height").alias("min_side"), - (pl.min_horizontal("width", "height") * erosion_fraction).alias("erosion"), - ) - .filter(pl.col("n_total") >= min_transcripts_per_cell) - .filter((pl.col("width") > 0) & (pl.col("height") > 0)) - .filter((pl.col("min_side") > 0) & (pl.col("erosion") > 0)) - .with_columns( - (pl.col("x_min") + pl.col("erosion")).alias("cx_min"), - (pl.col("x_max") - pl.col("erosion")).alias("cx_max"), - (pl.col("y_min") + pl.col("erosion")).alias("cy_min"), - (pl.col("y_max") - pl.col("erosion")).alias("cy_max"), - ) - .filter((pl.col("cx_max") > pl.col("cx_min")) & (pl.col("cy_max") > pl.col("cy_min"))) - ) - - if cell_stats.height == 0: - return { - "border_contamination_fast": 
float("nan"), - "border_enrichment_fast": float("nan"), - "border_excess_pct_fast": float("nan"), - "border_contaminated_cells_pct_fast": float("nan"), - "border_contaminated_cells_pct_ci95_fast": float("nan"), - "border_metric_cells_used": 0, - } - - if max_cells > 0 and cell_stats.height > max_cells: - rng = np.random.default_rng(seed) - cell_ids = np.array(cell_stats.get_column(cell_id_column).to_list(), dtype=object) - picked = rng.choice(cell_ids, size=max_cells, replace=False).tolist() - cell_stats = cell_stats.filter(pl.col(cell_id_column).is_in(picked)) - df = df.join(cell_stats.select([cell_id_column]), on=cell_id_column, how="inner") - - classified = ( - df.join( - cell_stats.select( - [ - cell_id_column, - "cx_min", - "cx_max", - "cy_min", - "cy_max", - "width", - "height", - "n_total", - ] - ), - on=cell_id_column, - how="inner", - ) - .with_columns( - ( - (pl.col(x_column) >= pl.col("cx_min")) - & (pl.col(x_column) <= pl.col("cx_max")) - & (pl.col(y_column) >= pl.col("cy_min")) - & (pl.col(y_column) <= pl.col("cy_max")) - ).alias("is_center") - ) - ) - - grouped = ( - classified.group_by([cell_id_column, "is_center"]) - .agg(pl.len().alias("n")) - ) - n_center = grouped.filter(pl.col("is_center")).select([cell_id_column, pl.col("n").alias("n_center")]) - n_border = grouped.filter(~pl.col("is_center")).select([cell_id_column, pl.col("n").alias("n_border")]) - - per_cell = ( - cell_stats.join(n_center, on=cell_id_column, how="left") - .join(n_border, on=cell_id_column, how="left") - .with_columns( - pl.col("n_center").fill_null(0).cast(pl.Float64), - pl.col("n_border").fill_null(0).cast(pl.Float64), - pl.col("n_total").cast(pl.Float64), - ) - .with_columns( - (pl.col("width") * pl.col("height")).alias("bbox_area"), - ((pl.col("cx_max") - pl.col("cx_min")) * (pl.col("cy_max") - pl.col("cy_min"))).alias("center_area"), - ) - .with_columns( - (pl.col("bbox_area") - pl.col("center_area")).alias("border_area"), - ) - .with_columns( - (pl.col("n_center") / pl.max_horizontal(pl.col("center_area"), pl.lit(eps))).alias("center_density"), - (pl.col("n_border") / pl.max_horizontal(pl.col("border_area"), pl.lit(eps))).alias("border_density"), - ) - .with_columns( - (pl.col("border_density") / pl.max_horizontal(pl.col("center_density"), pl.lit(eps))).alias("border_enrichment"), - ( - pl.when( - (pl.col("border_density") / pl.max_horizontal(pl.col("center_density"), pl.lit(eps)) - 1.0) - > 0 - ) - .then(pl.col("border_density") / pl.max_horizontal(pl.col("center_density"), pl.lit(eps)) - 1.0) - .otherwise(0.0) - ).alias("contam_score"), - ) - ) - - if per_cell.height == 0: - return { - "border_contamination_fast": float("nan"), - "border_enrichment_fast": float("nan"), - "border_excess_pct_fast": float("nan"), - "border_contaminated_cells_pct_fast": float("nan"), - "border_contaminated_cells_pct_ci95_fast": float("nan"), - "border_metric_cells_used": 0, - } - - weights = per_cell.get_column("n_total").to_numpy() - contam = np.average(per_cell.get_column("contam_score").to_numpy(), weights=weights) - enrich = np.average(per_cell.get_column("border_enrichment").to_numpy(), weights=weights) - border_excess_pct = max(0.0, (float(enrich) - 1.0) * 100.0) - contaminated_flags = ( - per_cell.get_column("border_enrichment").to_numpy() - > float(contaminated_enrichment_threshold) - ).astype(np.float64) - contaminated_cells_pct = 100.0 * np.average( - contaminated_flags, - weights=weights, - ) - contaminated_cells_pct_ci95 = 100.0 * _weighted_bernoulli_ci95( - contaminated_flags, - weights, - ) - return 
{ - "border_contamination_fast": float(contam), - "border_enrichment_fast": float(enrich), - "border_excess_pct_fast": float(border_excess_pct), - "border_contaminated_cells_pct_fast": float(contaminated_cells_pct), - "border_contaminated_cells_pct_ci95_fast": float(contaminated_cells_pct_ci95), - "border_metric_cells_used": int(per_cell.height), - } - - -def compute_transcript_centroid_offset_fast( - assigned_tx: pl.DataFrame, - *, - cell_id_column: str = "segger_cell_id", - x_column: str = "x", - y_column: str = "y", - min_transcripts_per_cell: int = 20, - max_cells: int = 3000, - seed: int = 0, -) -> dict[str, float]: - """Fast transcript centroid-offset metric (higher is better). - - Uses a bounding-box center as cell centroid approximation. - """ - req = [cell_id_column, x_column, y_column] - for col in req: - if col not in assigned_tx.columns: - return { - "transcript_centroid_offset_fast": float("nan"), - "transcript_centroid_offset_ci95_fast": float("nan"), - "tco_metric_cells_used": 0, - } - - stats = ( - assigned_tx.select(req) - .drop_nulls() - .group_by(cell_id_column) - .agg( - pl.len().alias("n_total"), - pl.col(x_column).mean().alias("tx_cx"), - pl.col(y_column).mean().alias("tx_cy"), - pl.col(x_column).min().alias("x_min"), - pl.col(x_column).max().alias("x_max"), - pl.col(y_column).min().alias("y_min"), - pl.col(y_column).max().alias("y_max"), - ) - .filter(pl.col("n_total") >= min_transcripts_per_cell) - .with_columns( - (pl.col("x_max") - pl.col("x_min")).alias("width"), - (pl.col("y_max") - pl.col("y_min")).alias("height"), - ) - .filter((pl.col("width") > 0) & (pl.col("height") > 0)) - .with_columns( - ((pl.col("x_min") + pl.col("x_max")) / 2.0).alias("cell_cx"), - ((pl.col("y_min") + pl.col("y_max")) / 2.0).alias("cell_cy"), - (pl.col("width") * pl.col("height")).alias("area"), - ) - .filter(pl.col("area") > 0) - ) - - if stats.height == 0: - return { - "transcript_centroid_offset_fast": float("nan"), - "transcript_centroid_offset_ci95_fast": float("nan"), - "tco_metric_cells_used": 0, - } - - if max_cells > 0 and stats.height > max_cells: - rng = np.random.default_rng(seed) - ids = np.array(stats.get_column(cell_id_column).to_list(), dtype=object) - picked = rng.choice(ids, size=max_cells, replace=False).tolist() - stats = stats.filter(pl.col(cell_id_column).is_in(picked)) - - stats = stats.with_columns( - ( - ( - (pl.col("tx_cx") - pl.col("cell_cx")) ** 2 - + (pl.col("tx_cy") - pl.col("cell_cy")) ** 2 - ).sqrt() - ).alias("centroid_offset") - ).with_columns( - ( - 1.0 - (pl.col("centroid_offset") / (pl.col("area").sqrt() + 1e-9)) - ).clip(lower_bound=0.0, upper_bound=1.0).alias("tco_score") - ) - - weights = stats.get_column("n_total").to_numpy().astype(np.float64, copy=False) - tco_vals = stats.get_column("tco_score").to_numpy().astype(np.float64, copy=False) - tco = float(np.average(tco_vals, weights=weights)) - tco_ci95 = _weighted_mean_ci95(tco_vals, weights) - return { - "transcript_centroid_offset_fast": tco, - "transcript_centroid_offset_ci95_fast": float(tco_ci95), - "tco_metric_cells_used": int(stats.height), - } - - -def compute_signal_doublet_fast( - assigned_tx: pl.DataFrame, - *, - cell_id_column: str = "segger_cell_id", - z_column: str = "z", - min_transcripts_per_cell: int = 20, - max_cells: int = 3000, - seed: int = 0, - doublet_threshold: float = 0.6, -) -> dict[str, float]: - """Fast 3D doublet-like fraction based on per-cell z-spread.""" - if z_column not in assigned_tx.columns or cell_id_column not in assigned_tx.columns: - return { - 
"signal_doublet_like_fraction_fast": float("nan"), - "signal_doublet_like_fraction_ci95_fast": float("nan"), - "signal_metric_cells_used": 0, - } - - stats = ( - assigned_tx.select([cell_id_column, z_column]) - .drop_nulls() - .group_by(cell_id_column) - .agg( - pl.len().alias("n_total"), - pl.col(z_column).std().alias("z_std"), - ) - .filter(pl.col("n_total") >= min_transcripts_per_cell) - .drop_nulls(["z_std"]) - ) - - if stats.height == 0: - return { - "signal_doublet_like_fraction_fast": float("nan"), - "signal_doublet_like_fraction_ci95_fast": float("nan"), - "signal_metric_cells_used": 0, - } - - if max_cells > 0 and stats.height > max_cells: - rng = np.random.default_rng(seed) - ids = np.array(stats.get_column(cell_id_column).to_list(), dtype=object) - picked = rng.choice(ids, size=max_cells, replace=False).tolist() - stats = stats.filter(pl.col(cell_id_column).is_in(picked)) - - n_total = stats.get_column("n_total").to_numpy().astype(np.float64, copy=False) - z_std = stats.get_column("z_std").to_numpy().astype(np.float64, copy=False) - z_std = np.where(np.isfinite(z_std), z_std, np.nan) - positive = z_std[np.isfinite(z_std) & (z_std > 0)] - - if positive.size == 0: - ci95 = _weighted_bernoulli_ci95( - np.zeros_like(n_total, dtype=np.float64), - n_total, - ) - return { - "signal_doublet_like_fraction_fast": 0.0, - "signal_doublet_like_fraction_ci95_fast": float(ci95), - "signal_metric_cells_used": int(stats.height), - } - - expected = float(np.median(positive)) - if expected <= 1e-12: - ci95 = _weighted_bernoulli_ci95( - np.zeros_like(n_total, dtype=np.float64), - n_total, - ) - return { - "signal_doublet_like_fraction_fast": 0.0, - "signal_doublet_like_fraction_ci95_fast": float(ci95), - "signal_metric_cells_used": int(stats.height), - } - - integrity = np.clip(expected / (z_std + 1e-9), 0.0, 1.0) - doublet_flags = (integrity < doublet_threshold).astype(np.float64) - doublet_like = float(np.average(doublet_flags, weights=n_total)) - doublet_like_ci95 = _weighted_bernoulli_ci95(doublet_flags, n_total) - return { - "signal_doublet_like_fraction_fast": doublet_like, - "signal_doublet_like_fraction_ci95_fast": float(doublet_like_ci95), - "signal_metric_cells_used": int(stats.height), - } - - -def load_me_gene_pairs( - *, - me_gene_pairs_path: Optional[Path] = None, - scrna_reference_path: Optional[Path] = None, - scrna_celltype_column: str = "cell_type", -) -> list[tuple[str, str]]: - """Load mutually exclusive gene pairs from file or scRNA reference.""" - if me_gene_pairs_path is not None: - pairs: list[tuple[str, str]] = [] - with Path(me_gene_pairs_path).open("r", encoding="utf-8") as fh: - for raw_line in fh: - line = raw_line.strip() - if not line or line.startswith("#"): - continue - if "\t" in line: - parts = [p.strip() for p in line.split("\t")] - elif "," in line: - parts = [p.strip() for p in line.split(",")] - else: - parts = [p.strip() for p in line.split()] - if len(parts) < 2: - continue - if parts[0].lower() in {"gene1", "gene_a"} and parts[1].lower() in {"gene2", "gene_b"}: - continue - pairs.append((parts[0], parts[1])) - return pairs - - if scrna_reference_path is not None: - pairs, _ = load_me_genes_from_scrna( - scrna_path=Path(scrna_reference_path), - cell_type_column=scrna_celltype_column, - ) - return [(str(g1), str(g2)) for g1, g2 in pairs] - - return [] - - -def compute_mecr_fast( - anndata_path: Path, - gene_pairs: Sequence[tuple[str, str]], - *, - max_pairs: int = 500, - soft: bool = True, - seed: int = 0, -) -> dict[str, float]: - """Compute fast MECR from 
an AnnData file (lower is better).""" - if anndata_path is None or not Path(anndata_path).exists(): - return {"mecr_fast": float("nan"), "mecr_ci95_fast": float("nan"), "mecr_pairs_used": 0} - if len(gene_pairs) == 0: - return {"mecr_fast": float("nan"), "mecr_ci95_fast": float("nan"), "mecr_pairs_used": 0} - - adata = ad.read_h5ad(anndata_path) - gene_to_idx = {str(g): i for i, g in enumerate(adata.var_names)} - - valid_pairs: list[tuple[int, int]] = [] - for g1, g2 in gene_pairs: - i = gene_to_idx.get(str(g1)) - j = gene_to_idx.get(str(g2)) - if i is None or j is None: - continue - valid_pairs.append((i, j)) - - if len(valid_pairs) == 0: - return {"mecr_fast": float("nan"), "mecr_ci95_fast": float("nan"), "mecr_pairs_used": 0} - - if max_pairs > 0 and len(valid_pairs) > max_pairs: - rng = np.random.default_rng(seed) - pick = rng.choice(len(valid_pairs), size=max_pairs, replace=False) - valid_pairs = [valid_pairs[int(i)] for i in pick] - - X = adata.X - is_sparse = sparse.issparse(X) - if is_sparse: - X = X.tocsc() - else: - X = np.asarray(X) - - vals: list[float] = [] - for i, j in valid_pairs: - if is_sparse: - a = np.asarray(X.getcol(i).toarray()).ravel() - b = np.asarray(X.getcol(j).toarray()).ravel() - else: - a = np.asarray(X[:, i]).ravel() - b = np.asarray(X[:, j]).ravel() - - if soft: - den = float(np.maximum(a, b).sum()) - if den <= 0: - continue - num = float(np.minimum(a, b).sum()) - vals.append(num / den) - else: - a_bin = a > 0 - b_bin = b > 0 - either = float((a_bin | b_bin).sum()) - if either <= 0: - continue - both = float((a_bin & b_bin).sum()) - vals.append(both / either) - - if len(vals) == 0: - return {"mecr_fast": float("nan"), "mecr_ci95_fast": float("nan"), "mecr_pairs_used": 0} - - arr = np.asarray(vals, dtype=np.float64) - ci95 = float("nan") - if arr.size > 1: - se = float(np.std(arr, ddof=1)) / np.sqrt(float(arr.size)) - ci95 = float(1.96 * se) - - return { - "mecr_fast": float(np.mean(arr)), - "mecr_ci95_fast": float(ci95), - "mecr_pairs_used": int(len(vals)), - } From 149af1d59f6ff841c7c158da89df40d4003dc1b5 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Wed, 6 May 2026 19:22:10 +0200 Subject: [PATCH 19/20] Lower bound PCA dimensionality to number of genes --- src/segger/data/utils/anndata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/segger/data/utils/anndata.py b/src/segger/data/utils/anndata.py index f2ba8dc..725f1ef 100644 --- a/src/segger/data/utils/anndata.py +++ b/src/segger/data/utils/anndata.py @@ -195,7 +195,6 @@ def setup_anndata( # Build gene embedding on filtered dataset C = np.corrcoef(ad[ad.obs['filtered']].layers['norm'].todense().T) C = np.nan_to_num(C, 0, posinf=True, neginf=True) - # model = sklearn.decomposition.PCA(n_components=cells_embedding_size) model = sklearn.decomposition.PCA(n_components=min(cells_embedding_size, ad.var.shape[0])) if ad.var.shape[0] < cells_embedding_size: import warnings From 29e8028d9d421ed0e5d7b5e7a172be170af1dee4 Mon Sep 17 00:00:00 2001 From: enric-bazz Date: Wed, 6 May 2026 19:22:41 +0200 Subject: [PATCH 20/20] Move required functions within spatialdata loader module --- src/segger/export/minimal_apis.py | 410 ------------------------ src/segger/export/spatialdata_writer.py | 224 ++++++++++++- 2 files changed, 220 insertions(+), 414 deletions(-) delete mode 100644 src/segger/export/minimal_apis.py diff --git a/src/segger/export/minimal_apis.py b/src/segger/export/minimal_apis.py deleted file mode 100644 index 4f0ddee..0000000 --- a/src/segger/export/minimal_apis.py +++ /dev/null @@ -1,410 +0,0 @@ - 
-from __future__ import annotations - -from enum import Enum -from pathlib import Path -from typing import Optional, Union, Any, Protocol, runtime_checkable - -from typing import Optional, Union - -import numpy as np -import pandas as pd -import polars as pl -from anndata import AnnData -from scipy import sparse as sp - - -class OutputFormat(str, Enum): - """Available output formats for segmentation results. - - Attributes - ---------- - SEGGER_RAW : str - Default Segger output format. Writes predictions as Parquet file - with columns: row_index, segger_cell_id, segger_similarity. - - MERGED_TRANSCRIPTS : str - Merged transcripts format. Original transcript data with segmentation - results joined (segger_cell_id, segger_similarity columns added). - - SPATIALDATA : str - SpatialData Zarr format. Creates a .zarr store compatible with - the scverse ecosystem, containing transcripts and optional boundaries. - - ANNDATA : str - AnnData format. Creates a .h5ad file with a cell x gene matrix - derived from transcript assignments. - """ - - SEGGER_RAW = "segger_raw" - MERGED_TRANSCRIPTS = "merged" - SPATIALDATA = "spatialdata" - ANNDATA = "anndata" - - @classmethod - def from_string(cls, value: str) -> "OutputFormat": - """Parse OutputFormat from string, case-insensitive. - - Parameters - ---------- - value - Format name ('segger_raw', 'merged', 'spatialdata', 'anndata', or 'all'). - - Returns - ------- - OutputFormat - Corresponding enum value. - - Raises - ------ - ValueError - If value is not a valid format name. - """ - value_lower = value.lower().strip() - - # Handle aliases - aliases = { - "raw": cls.SEGGER_RAW, - "segger": cls.SEGGER_RAW, - "default": cls.SEGGER_RAW, - "merge": cls.MERGED_TRANSCRIPTS, - "merged": cls.MERGED_TRANSCRIPTS, - "transcripts": cls.MERGED_TRANSCRIPTS, - "sdata": cls.SPATIALDATA, - "zarr": cls.SPATIALDATA, - "h5ad": cls.ANNDATA, - "ann": cls.ANNDATA, - "anndata": cls.ANNDATA, - } - - if value_lower in aliases: - return aliases[value_lower] - - # Try direct match - for fmt in cls: - if fmt.value == value_lower: - return fmt - - valid = [f.value for f in cls] + list(aliases.keys()) - raise ValueError( - f"Unknown output format: '{value}'. " - f"Valid formats: {sorted(set(valid))}" - ) - - - -@runtime_checkable -class OutputWriter(Protocol): - """Protocol for output format writers. - - Implementations must provide a `write` method that writes segmentation - results to the specified output directory. - """ - - def write( - self, - predictions: "pl.DataFrame", - output_dir: Path, - **kwargs: Any, - ) -> Path: - """Write segmentation results to output format. - - Parameters - ---------- - predictions - DataFrame with segmentation predictions. Must contain: - - row_index: Original transcript row index - - segger_cell_id: Assigned cell ID (or -1/None for unassigned) - - segger_similarity: Assignment confidence score - - output_dir - Directory to write output files. - - **kwargs - Format-specific options (e.g., transcripts, boundaries). - - Returns - ------- - Path - Path to the primary output file/directory. - """ - ... - - -# Registry of output writers by format -_OUTPUT_WRITERS: dict[OutputFormat, type] = {} - -def register_writer(fmt: OutputFormat): - """Decorator to register an output writer class. - - Parameters - ---------- - fmt - Output format this writer handles. - - Returns - ------- - decorator - Class decorator that registers the writer. - - Examples - -------- - >>> @register_writer(OutputFormat.MERGED_TRANSCRIPTS) - ... class MergedTranscriptsWriter: - ... 
def write(self, predictions, output_dir, **kwargs): - ... ... - """ - def decorator(cls): - _OUTPUT_WRITERS[fmt] = cls - return cls - return decorator - - -def get_writer(fmt: OutputFormat | str, **init_kwargs: Any) -> OutputWriter: - """Get an output writer for the specified format. - - Parameters - ---------- - fmt - Output format (enum or string). - **init_kwargs - Keyword arguments passed to the writer constructor. - - Returns - ------- - OutputWriter - Writer instance for the specified format. - - Raises - ------ - ValueError - If format is not recognized or writer not registered. - - Examples - -------- - >>> writer = get_writer(OutputFormat.MERGED_TRANSCRIPTS, unassigned_marker=-1) - >>> writer.write(predictions, Path("output/")) - """ - if isinstance(fmt, str): - fmt = OutputFormat.from_string(fmt) - - if fmt not in _OUTPUT_WRITERS: - raise ValueError( - f"No writer registered for format: {fmt.value}. " - f"Available formats: {[f.value for f in _OUTPUT_WRITERS.keys()]}" - ) - - writer_cls = _OUTPUT_WRITERS[fmt] - return writer_cls(**init_kwargs) - - - -### ANNDATA EXPORT ### - -def build_anndata_table( - transcripts: pl.DataFrame, - cell_id_column: str = "segger_cell_id", - feature_column: str = "feature_name", - x_column: Optional[str] = "x", - y_column: Optional[str] = "y", - z_column: Optional[str] = "z", - unassigned_value: Union[int, str, None] = -1, - region: Optional[str] = None, - region_key: Optional[str] = None, - obs_index_as_str: bool = False, -) -> AnnData: - """Build AnnData from assigned transcripts. - - Parameters - ---------- - transcripts - Transcript DataFrame with segmentation assignments. - cell_id_column - Column with assigned cell IDs. - feature_column - Column with gene/feature names. - x_column, y_column, z_column - Coordinate columns (optional). If present, centroids are stored in - ``obsm["X_spatial"]``. - unassigned_value - Marker for unassigned transcripts (filtered out). - region, region_key - SpatialData table linkage metadata. - obs_index_as_str - If True, cast cell IDs to string for ``obs`` index. 
- """ - if cell_id_column not in transcripts.columns: - raise ValueError(f"Missing cell_id column: {cell_id_column}") - if feature_column not in transcripts.columns: - raise ValueError(f"Missing feature column: {feature_column}") - - assigned = transcripts.filter(pl.col(cell_id_column).is_not_null()) - if unassigned_value is not None: - col_dtype = transcripts.schema.get(cell_id_column) - try: - compare_value = pl.Series([unassigned_value]).cast(col_dtype).item() - filter_expr = pl.col(cell_id_column) != compare_value - except Exception: - filter_expr = ( - pl.col(cell_id_column).cast(pl.Utf8) != str(unassigned_value) - ) - assigned = assigned.filter(filter_expr) - - # Gene list from all transcripts (even if no assignments) - var_idx = ( - transcripts - .select(feature_column) - .unique() - .sort(feature_column) - .get_column(feature_column) - .to_list() - ) - - if assigned.height == 0: - obs_index = pd.Index([], name=cell_id_column) - if obs_index_as_str: - var_index = pd.Index([str(v) for v in var_idx], name=feature_column) - else: - var_index = pd.Index(var_idx, name=feature_column) - X = sp.csr_matrix((0, len(var_index))) - adata = AnnData(X=X, obs=pd.DataFrame(index=obs_index), var=pd.DataFrame(index=var_index)) - if region is not None: - adata.obs["region"] = region - if region_key is not None: - adata.obs["region_key"] = region_key - return adata - - feature_idx = ( - assigned - .select(feature_column) - .unique() - .sort(feature_column) - .with_row_index(name="_fid") - ) - cell_idx = ( - assigned - .select(cell_id_column) - .unique() - .sort(cell_id_column) - .with_row_index(name="_cid") - ) - - mapped = ( - assigned - .join(feature_idx, on=feature_column) - .join(cell_idx, on=cell_id_column) - ) - counts = ( - mapped - .group_by(["_cid", "_fid"]) - .agg(pl.len().alias("_count")) - ) - ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T - rows = ijv[0].astype(np.int64, copy=False) - cols = ijv[1].astype(np.int64, copy=False) - data = ijv[2].astype(np.int64, copy=False) - - n_cells = cell_idx.height - n_genes = feature_idx.height - X = sp.coo_matrix((data, (rows, cols)), shape=(n_cells, n_genes)).tocsr() - - obs_ids = cell_idx.get_column(cell_id_column).to_list() - var_ids = feature_idx.get_column(feature_column).to_list() - if obs_index_as_str: - obs_ids = [str(v) for v in obs_ids] - var_ids = [str(v) for v in var_ids] - - adata = AnnData( - X=X, - obs=pd.DataFrame(index=pd.Index(obs_ids, name=cell_id_column)), - var=pd.DataFrame(index=pd.Index(var_ids, name=feature_column)), - ) - - # Add centroid coordinates if present - if x_column in assigned.columns and y_column in assigned.columns: - coords_cols = [x_column, y_column] - if z_column and z_column in assigned.columns: - coords_cols.append(z_column) - centroids = ( - assigned - .group_by(cell_id_column) - .agg([pl.col(c).mean().alias(c) for c in coords_cols]) - ) - centroids_pd = ( - centroids - .to_pandas() - .set_index(cell_id_column) - .reindex(adata.obs.index) - ) - adata.obsm["X_spatial"] = centroids_pd[coords_cols].to_numpy() - - if region is not None: - adata.obs["region"] = region - if region_key is not None: - adata.obs["region_key"] = region_key - - return adata - -### MERGED EXPORT ### - -def merge_predictions_with_transcripts( - predictions: pl.DataFrame, - transcripts: pl.DataFrame, - row_index_column: str = "row_index", - cell_id_column: str = "segger_cell_id", - similarity_column: str = "segger_similarity", - unassigned_marker: Union[int, str, None] = -1, -) -> pl.DataFrame: - """Merge predictions with 
transcripts (functional interface). - - Parameters - ---------- - predictions - DataFrame with segmentation predictions. - transcripts - Original transcripts DataFrame. - row_index_column - Column name for row index. - cell_id_column - Column name for cell ID in predictions. - similarity_column - Column name for similarity in predictions. - unassigned_marker - Value for unassigned transcripts. - - Returns - ------- - pl.DataFrame - Merged DataFrame with all original columns plus predictions. - - Examples - -------- - >>> merged = merge_predictions_with_transcripts(predictions, transcripts) - >>> print(merged.columns) - ['row_index', 'x', 'y', 'feature_name', 'segger_cell_id', 'segger_similarity'] - """ - # Prepare predictions - pred_cols = [row_index_column, cell_id_column] - if similarity_column in predictions.columns: - pred_cols.append(similarity_column) - - pred_subset = predictions.select(pred_cols) - - # Add row_index if missing - if row_index_column not in transcripts.columns: - transcripts = transcripts.with_row_index(name=row_index_column) - - # Join - merged = transcripts.join(pred_subset, on=row_index_column, how="left") - - # Fill unassigned - if unassigned_marker is not None: - merged = merged.with_columns( - pl.col(cell_id_column).fill_null(unassigned_marker) - ) - if similarity_column in merged.columns: - merged = merged.with_columns( - pl.col(similarity_column).fill_null(0.0) - ) - - return merged diff --git a/src/segger/export/spatialdata_writer.py b/src/segger/export/spatialdata_writer.py index b0f10cf..761d3bb 100644 --- a/src/segger/export/spatialdata_writer.py +++ b/src/segger/export/spatialdata_writer.py @@ -28,21 +28,24 @@ import warnings from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional, Union +import numpy as np +import pandas as pd import polars as pl +from anndata import AnnData +from scipy import sparse as sp + from segger.utils.optional_deps import ( require_spatialdata, ) -from segger.export.minimal_apis import OutputFormat, register_writer, build_anndata_table - if TYPE_CHECKING: import geopandas as gpd from spatialdata import SpatialData -@register_writer(OutputFormat.SPATIALDATA) +# @register_writer(OutputFormat.SPATIALDATA) class SpatialDataWriter: """Write segmentation results as SpatialData Zarr store. @@ -579,3 +582,216 @@ def write_spatialdata( output_name=output_name, **kwargs, ) + + +### APIs from other exporting formats in v2-incremental ### + +### ANNDATA EXPORT ### + +def build_anndata_table( + transcripts: pl.DataFrame, + cell_id_column: str = "segger_cell_id", + feature_column: str = "feature_name", + x_column: Optional[str] = "x", + y_column: Optional[str] = "y", + z_column: Optional[str] = "z", + unassigned_value: Union[int, str, None] = -1, + region: Optional[str] = None, + region_key: Optional[str] = None, + obs_index_as_str: bool = False, +) -> AnnData: + """Build AnnData from assigned transcripts. + + Parameters + ---------- + transcripts + Transcript DataFrame with segmentation assignments. + cell_id_column + Column with assigned cell IDs. + feature_column + Column with gene/feature names. + x_column, y_column, z_column + Coordinate columns (optional). If present, centroids are stored in + ``obsm["X_spatial"]``. + unassigned_value + Marker for unassigned transcripts (filtered out). + region, region_key + SpatialData table linkage metadata. + obs_index_as_str + If True, cast cell IDs to string for ``obs`` index. 
+ """ + if cell_id_column not in transcripts.columns: + raise ValueError(f"Missing cell_id column: {cell_id_column}") + if feature_column not in transcripts.columns: + raise ValueError(f"Missing feature column: {feature_column}") + + assigned = transcripts.filter(pl.col(cell_id_column).is_not_null()) + if unassigned_value is not None: + col_dtype = transcripts.schema.get(cell_id_column) + try: + compare_value = pl.Series([unassigned_value]).cast(col_dtype).item() + filter_expr = pl.col(cell_id_column) != compare_value + except Exception: + filter_expr = ( + pl.col(cell_id_column).cast(pl.Utf8) != str(unassigned_value) + ) + assigned = assigned.filter(filter_expr) + + # Gene list from all transcripts (even if no assignments) + var_idx = ( + transcripts + .select(feature_column) + .unique() + .sort(feature_column) + .get_column(feature_column) + .to_list() + ) + + if assigned.height == 0: + obs_index = pd.Index([], name=cell_id_column) + if obs_index_as_str: + var_index = pd.Index([str(v) for v in var_idx], name=feature_column) + else: + var_index = pd.Index(var_idx, name=feature_column) + X = sp.csr_matrix((0, len(var_index))) + adata = AnnData(X=X, obs=pd.DataFrame(index=obs_index), var=pd.DataFrame(index=var_index)) + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + return adata + + feature_idx = ( + assigned + .select(feature_column) + .unique() + .sort(feature_column) + .with_row_index(name="_fid") + ) + cell_idx = ( + assigned + .select(cell_id_column) + .unique() + .sort(cell_id_column) + .with_row_index(name="_cid") + ) + + mapped = ( + assigned + .join(feature_idx, on=feature_column) + .join(cell_idx, on=cell_id_column) + ) + counts = ( + mapped + .group_by(["_cid", "_fid"]) + .agg(pl.len().alias("_count")) + ) + ijv = counts.select(["_cid", "_fid", "_count"]).to_numpy().T + rows = ijv[0].astype(np.int64, copy=False) + cols = ijv[1].astype(np.int64, copy=False) + data = ijv[2].astype(np.int64, copy=False) + + n_cells = cell_idx.height + n_genes = feature_idx.height + X = sp.coo_matrix((data, (rows, cols)), shape=(n_cells, n_genes)).tocsr() + + obs_ids = cell_idx.get_column(cell_id_column).to_list() + var_ids = feature_idx.get_column(feature_column).to_list() + if obs_index_as_str: + obs_ids = [str(v) for v in obs_ids] + var_ids = [str(v) for v in var_ids] + + adata = AnnData( + X=X, + obs=pd.DataFrame(index=pd.Index(obs_ids, name=cell_id_column)), + var=pd.DataFrame(index=pd.Index(var_ids, name=feature_column)), + ) + + # Add centroid coordinates if present + if x_column in assigned.columns and y_column in assigned.columns: + coords_cols = [x_column, y_column] + if z_column and z_column in assigned.columns: + coords_cols.append(z_column) + centroids = ( + assigned + .group_by(cell_id_column) + .agg([pl.col(c).mean().alias(c) for c in coords_cols]) + ) + centroids_pd = ( + centroids + .to_pandas() + .set_index(cell_id_column) + .reindex(adata.obs.index) + ) + adata.obsm["X_spatial"] = centroids_pd[coords_cols].to_numpy() + + if region is not None: + adata.obs["region"] = region + if region_key is not None: + adata.obs["region_key"] = region_key + + return adata + +### MERGED EXPORT ### + +def merge_predictions_with_transcripts( + predictions: pl.DataFrame, + transcripts: pl.DataFrame, + row_index_column: str = "row_index", + cell_id_column: str = "segger_cell_id", + similarity_column: str = "segger_similarity", + unassigned_marker: Union[int, str, None] = -1, +) -> pl.DataFrame: + """Merge predictions with 
transcripts (functional interface). + + Parameters + ---------- + predictions + DataFrame with segmentation predictions. + transcripts + Original transcripts DataFrame. + row_index_column + Column name for row index. + cell_id_column + Column name for cell ID in predictions. + similarity_column + Column name for similarity in predictions. + unassigned_marker + Value for unassigned transcripts. + + Returns + ------- + pl.DataFrame + Merged DataFrame with all original columns plus predictions. + + Examples + -------- + >>> merged = merge_predictions_with_transcripts(predictions, transcripts) + >>> print(merged.columns) + ['row_index', 'x', 'y', 'feature_name', 'segger_cell_id', 'segger_similarity'] + """ + # Prepare predictions + pred_cols = [row_index_column, cell_id_column] + if similarity_column in predictions.columns: + pred_cols.append(similarity_column) + + pred_subset = predictions.select(pred_cols) + + # Add row_index if missing + if row_index_column not in transcripts.columns: + transcripts = transcripts.with_row_index(name=row_index_column) + + # Join + merged = transcripts.join(pred_subset, on=row_index_column, how="left") + + # Fill unassigned + if unassigned_marker is not None: + merged = merged.with_columns( + pl.col(cell_id_column).fill_null(unassigned_marker) + ) + if similarity_column in merged.columns: + merged = merged.with_columns( + pl.col(similarity_column).fill_null(0.0) + ) + + return merged
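
For reference, a minimal usage sketch of the two helpers relocated into spatialdata_writer.py by this patch. It assumes a predictions frame with the row_index / segger_cell_id / segger_similarity columns described above and transcripts already standardized to x, y, feature_name; the file paths are illustrative only, not part of the repository.

import polars as pl
from segger.export.spatialdata_writer import (
    build_anndata_table,
    merge_predictions_with_transcripts,
)

# Hypothetical inputs: standardized transcripts plus per-transcript predictions.
transcripts = pl.read_parquet("transcripts_standardized.parquet")  # illustrative path
predictions = pl.read_parquet("segger_predictions.parquet")        # illustrative path

# Join predictions back onto the original transcripts; rows without an
# assignment are filled with the unassigned marker (-1 by default).
merged = merge_predictions_with_transcripts(
    predictions,
    transcripts,
    unassigned_marker=-1,
)

# Aggregate assigned transcripts into a cell x gene AnnData; per-cell
# centroids are stored in obsm["X_spatial"] when x/y (and z) columns exist.
adata = build_anndata_table(
    merged,
    cell_id_column="segger_cell_id",
    feature_column="feature_name",
    unassigned_value=-1,
)
adata.write_h5ad("segger_cells.h5ad")  # illustrative output path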