diff --git a/conf/base.config b/conf/base.config index e1668aaf..45b62604 100644 --- a/conf/base.config +++ b/conf/base.config @@ -16,6 +16,15 @@ process { // resourceLimits = [ cpus: 192, memory: 750.GB, time: 72.h ] + // Retry signal-induced exits and "killed without exit code" cases: + // 130..145 = signal exits (SIGINT=130, SIGKILL=137, SIGTERM=143, etc.) + // 104 = ECONNRESET (transient network failures during stage-in/out) + // 2147483647 = Integer.MAX_VALUE, Nextflow's sentinel for tasks that died + // before writing .exitcode (Nextflow surfaces this as + // "terminated for an unknown reason -- Likely it has been + // terminated by the external system"). Common on AWS Batch + // spot capacity, kubernetes preemption, and grid-scheduler + // cancellations. See nextflow docs/aws.md for the AWS case. errorStrategy = { task.exitStatus in ((130..145) + 104 + 2147483647) ? 'retry' : 'finish' } maxRetries = 3 maxErrors = '-1' diff --git a/conf/modules.config b/conf/modules.config index 22726919..bb5a8786 100644 --- a/conf/modules.config +++ b/conf/modules.config @@ -157,6 +157,7 @@ process { ext.filter_method = params.patch_filter_method ?: null ext.iqr_multiplier = params.patch_filter_iqr_multiplier ext.z_threshold = params.patch_filter_z_threshold + ext.args = { "--min-transcripts-per-cell ${params.baysor_tiling_min_transcripts_per_cell}" } publishDir = [ path: { "${params.outdir}/${meta.id}/xenium_patch" }, mode: params.publish_dir_mode, @@ -307,7 +308,7 @@ process { params.stardist_prob_thresh != null ? "--prob_thresh ${params.stardist_prob_thresh}" : "", params.stardist_nms_thresh != null ? "--nms_thresh ${params.stardist_nms_thresh}" : "", params.stardist_n_tiles != null ? "--n_tiles ${params.stardist_n_tiles}" : "", - ].findAll().join(' ')} + ].join(' ').trim()} } withName: 'STARDIST_NUCLEI' { diff --git a/modules/local/segger/create_dataset/main.nf b/modules/local/segger/create_dataset/main.nf index 06848c41..81320eff 100644 --- a/modules/local/segger/create_dataset/main.nf +++ b/modules/local/segger/create_dataset/main.nf @@ -22,7 +22,6 @@ process SEGGER_CREATE_DATASET { } def args = task.ext.args ?: '' - def script_path = "/workspace/segger_dev/src/segger/cli/create_dataset_fast.py" prefix = task.ext.prefix ?: "${meta.id}" // check for platform values @@ -31,193 +30,17 @@ process SEGGER_CREATE_DATASET { } """ - # Set numba cache directory to avoid caching issues in container export NUMBA_CACHE_DIR=\$PWD/.numba_cache mkdir -p \$NUMBA_CACHE_DIR - # Create local bundle directory with symlinks to all original files - # This is necessary because input files from S3/Fusion are read-only - # Use absolute paths to avoid broken relative symlinks - mkdir -p bundle_local - for item in ${base_dir}/*; do - # Resolve to absolute path (follow any symlinks) - abs_path=\$(readlink -f "\$item" 2>/dev/null || realpath "\$item" 2>/dev/null || echo "\$item") - basename=\$(basename "\$item") - ln -sf "\$abs_path" "bundle_local/\$basename" - done - - # Segger expects nucleus_boundaries.parquet but Xenium bundles have cell_boundaries.parquet - # Create the symlink if nucleus_boundaries doesn't exist but cell_boundaries does - if [ ! -e "bundle_local/nucleus_boundaries.parquet" ] && [ -e "bundle_local/cell_boundaries.parquet" ]; then - echo "Creating nucleus_boundaries.parquet symlink from cell_boundaries.parquet" - cell_bounds_path=\$(readlink -f "bundle_local/cell_boundaries.parquet" 2>/dev/null || realpath "bundle_local/cell_boundaries.parquet" 2>/dev/null) - ln -sf "\$cell_bounds_path" bundle_local/nucleus_boundaries.parquet - fi - - # List bundle contents for debugging - echo "Bundle contents:" - ls -la bundle_local/ - - # Fix: Add parquet column statistics for segger - echo "Adding statistics to parquet files..." - python3 - << 'PYEOF' -import pyarrow.parquet as pq -import os - -def add_stats(inp, out): - if not os.path.exists(inp): - print(f" Skip {inp}") - return - t = pq.read_table(inp) - pq.write_table(t, out, write_statistics=True, compression='snappy') - print(f" Done {os.path.basename(inp)} ({len(t)} rows)") - -os.makedirs('bundle_stats', exist_ok=True) -for f in ['transcripts.parquet', 'nucleus_boundaries.parquet']: - add_stats(f'bundle_local/{f}', f'bundle_stats/{f}') - -for item in os.listdir('bundle_local'): - s, d = f'bundle_local/{item}', f'bundle_stats/{item}' - if not os.path.exists(d): - os.symlink(os.path.realpath(s), d) -print("Done") - -# Debug: Check overlaps_nucleus column data -print("") -print("=== Debugging overlaps_nucleus data ===") -import pyarrow.compute as pc - -tx = pq.read_table('bundle_stats/transcripts.parquet') -bd = pq.read_table('bundle_stats/nucleus_boundaries.parquet') - -if 'overlaps_nucleus' in tx.column_names: - col = tx.column('overlaps_nucleus') - print(f"overlaps_nucleus dtype: {col.type}") - unique_vals = pc.unique(col) - print(f"overlaps_nucleus unique values: {unique_vals.to_pylist()[:10]}") - val_counts = pc.value_counts(col) - print(f"overlaps_nucleus value_counts: {val_counts.to_pylist()}") -else: - print("WARNING: overlaps_nucleus column NOT FOUND in transcripts.parquet") - -# Check cell_id overlap between transcripts and boundaries -if 'cell_id' in tx.column_names and 'cell_id' in bd.column_names: - tx_cells = set(pc.unique(tx.column('cell_id')).to_pylist()) - bd_cells = set(pc.unique(bd.column('cell_id')).to_pylist()) - overlap = tx_cells & bd_cells - print("") - print(f"Transcripts unique cell_ids: {len(tx_cells)}") - print(f"Boundaries unique cell_ids: {len(bd_cells)}") - print(f"Overlapping cell_ids: {len(overlap)}") - -print("=== End Debug ===") -PYEOF - ls -la bundle_stats/ - - python3 ${script_path} \\ - --base_dir bundle_stats \\ - --data_dir ${prefix} \\ - --sample_type ${params.format} \\ - --tile_width ${params.tile_width} \\ - --tile_height ${params.tile_height} \\ - --n_workers ${task.cpus} \\ + run_create_dataset.py \\ + --bundle-dir ${base_dir} \\ + --output-dir ${prefix} \\ + --sample-type ${params.format} \\ + --tile-width ${params.tile_width} \\ + --tile-height ${params.tile_height} \\ + --n-workers ${task.cpus} \\ ${args} - - # Verify tiles were created and show distribution - echo "Dataset split (before fix):" - echo " train_tiles: \$(ls ${prefix}/train_tiles/processed/ 2>/dev/null | wc -l) files" - echo " val_tiles: \$(ls ${prefix}/val_tiles/processed/ 2>/dev/null | wc -l) files" - echo " test_tiles: \$(ls ${prefix}/test_tiles/processed/ 2>/dev/null | wc -l) files" - - # Workaround: segger commit 0787167 has a bug where all tiles go to test_tiles - # regardless of test_prob/val_prob settings. Move ONLY trainable tiles (those with - # edge_label_index) from test_tiles to train_tiles. - # Tiles without tx-belongs-bd edges don't have edge_label_index and cannot be used for training. - train_count=\$(ls ${prefix}/train_tiles/processed/ 2>/dev/null | wc -l) - test_count=\$(ls ${prefix}/test_tiles/processed/ 2>/dev/null | wc -l) - - if [ "\$train_count" -eq 0 ] && [ "\$test_count" -gt 0 ]; then - echo "Applying workaround: filtering trainable tiles from test_tiles (segger split bug)" - export SEGGER_PREFIX="${prefix}" - python3 - << 'PYEOF' -import torch -import os -import shutil - -prefix = os.environ['SEGGER_PREFIX'] -test_dir = f"{prefix}/test_tiles/processed" -train_dir = f"{prefix}/train_tiles/processed" - -moved = 0 -skipped = 0 - -for f in os.listdir(test_dir): - if not f.endswith('.pt'): - continue - fpath = os.path.join(test_dir, f) - try: - tile = torch.load(fpath, weights_only=False) - edge_store = tile['tx', 'belongs', 'bd'] - # Check if edge_label_index exists and has data - if hasattr(edge_store, 'edge_label_index') and edge_store.edge_label_index.numel() > 0: - shutil.move(fpath, os.path.join(train_dir, f)) - moved += 1 - else: - skipped += 1 - except Exception as e: - print(f"Warning: Could not process {f}: {e}") - skipped += 1 - -print(f"Moved {moved} trainable tiles to train_tiles") -print(f"Skipped {skipped} test-only tiles (no edge_label_index)") -PYEOF - fi - - echo "Dataset split (after fix):" - echo " train_tiles: \$(ls ${prefix}/train_tiles/processed/ 2>/dev/null | wc -l) files" - echo " val_tiles: \$(ls ${prefix}/val_tiles/processed/ 2>/dev/null | wc -l) files" - echo " test_tiles: \$(ls ${prefix}/test_tiles/processed/ 2>/dev/null | wc -l) files" - - train_tiles_dir="${prefix}/train_tiles/processed" - if [ ! -d "\$train_tiles_dir" ] || [ -z "\$(ls -A \$train_tiles_dir 2>/dev/null)" ]; then - echo "ERROR: No trainable tiles were created in \$train_tiles_dir" - echo "This usually means no transcripts overlap with nucleus boundaries in the dataset." - echo "Check if the Xenium bundle contains valid overlaps_nucleus data in transcripts.parquet." - exit 1 - fi - echo "Successfully created \$(ls \$train_tiles_dir | wc -l) trainable tiles" - - # Workaround: Segger's get_polygon_props() produces NaN boundary features (bd.x) - # when polygon geometries have zero area or index misalignment during GeoDataFrame - # construction. Replace NaN bd.x with zeros so BCEWithLogitsLoss doesn't propagate NaN. - export SEGGER_PREFIX="${prefix}" - python3 - << 'PYEOF' -import torch -import os - -prefix = os.environ['SEGGER_PREFIX'] -fixed = 0 -total = 0 - -for split in ['train_tiles', 'test_tiles', 'val_tiles']: - tile_dir = f"{prefix}/{split}/processed" - if not os.path.isdir(tile_dir): - continue - for f in os.listdir(tile_dir): - if not f.endswith('.pt'): - continue - total += 1 - fpath = os.path.join(tile_dir, f) - tile = torch.load(fpath, weights_only=False) - bd_x = tile['bd'].x - if bd_x.isnan().any(): - tile['bd'].x = torch.nan_to_num(bd_x, nan=0.0) - torch.save(tile, fpath) - fixed += 1 - -print(f"Fixed NaN bd.x in {fixed}/{total} tiles") -PYEOF - """ stub: diff --git a/modules/local/segger/create_dataset/resources/usr/bin/run_create_dataset.py b/modules/local/segger/create_dataset/resources/usr/bin/run_create_dataset.py new file mode 100755 index 00000000..c73ab006 --- /dev/null +++ b/modules/local/segger/create_dataset/resources/usr/bin/run_create_dataset.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +""" +Run segger create_dataset with spatialxe-specific preprocessing and workarounds. + +Wraps segger's create_dataset_fast.py with: + - bundle_local symlink prep (handles read-only S3/Fusion mounts) + - parquet column statistics (segger needs these) + - WORKAROUND: filter trainable tiles from test_tiles when segger commit 0787167 mis-splits + - WORKAROUND: replace NaN bd.x with zeros after get_polygon_props produces NaN + +Each WORKAROUND should be removable when the upstream segger bug is fixed. +""" + +import argparse +import os +import shutil +import subprocess +import sys +from pathlib import Path + +# imports for actual work (used in functions below) +import pyarrow.parquet as pq +import pyarrow.compute as pc +import torch + + +SEGGER_CLI = "/workspace/segger_dev/src/segger/cli/create_dataset_fast.py" + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--bundle-dir", required=True) + p.add_argument("--output-dir", required=True) + p.add_argument("--sample-type", required=True, choices=["xenium"]) + p.add_argument("--tile-width", type=int, required=True) + p.add_argument("--tile-height", type=int, required=True) + p.add_argument("--n-workers", type=int, required=True) + # remaining args forwarded to segger CLI + args, extra = p.parse_known_args() + return args, extra + + +def prepare_bundle(bundle_dir): + """Create local bundle dir with absolute symlinks (S3/Fusion read-only-safe).""" + Path("bundle_local").mkdir(exist_ok=True) + for item in Path(bundle_dir).iterdir(): + try: + abs_path = item.resolve() + except Exception: + abs_path = item + target = Path("bundle_local") / item.name + if target.exists() or target.is_symlink(): + target.unlink() + target.symlink_to(abs_path) + + # Segger expects nucleus_boundaries.parquet but Xenium bundles have cell_boundaries.parquet + nb = Path("bundle_local/nucleus_boundaries.parquet") + cb = Path("bundle_local/cell_boundaries.parquet") + if not nb.exists() and cb.exists(): + print( + "Creating nucleus_boundaries.parquet symlink from cell_boundaries.parquet" + ) + nb.symlink_to(cb.resolve()) + + print("Bundle contents:") + for item in sorted(Path("bundle_local").iterdir()): + print(f" {item.name}") + + +def add_parquet_stats(): + """Rewrite key parquet files with column statistics (segger requires them).""" + Path("bundle_stats").mkdir(exist_ok=True) + for fname in ["transcripts.parquet", "nucleus_boundaries.parquet"]: + src = Path("bundle_local") / fname + dst = Path("bundle_stats") / fname + if not src.exists(): + print(f" Skip {src}") + continue + t = pq.read_table(str(src)) + pq.write_table(t, str(dst), write_statistics=True, compression="snappy") + print(f" Done {fname} ({len(t)} rows)") + + # Symlink everything else from bundle_local into bundle_stats + for item in Path("bundle_local").iterdir(): + dst = Path("bundle_stats") / item.name + if not dst.exists(): + dst.symlink_to(item.resolve()) + + # Debug: check overlaps_nucleus column in transcripts + print("\n=== Debugging overlaps_nucleus data ===") + tx = pq.read_table("bundle_stats/transcripts.parquet") + bd = pq.read_table("bundle_stats/nucleus_boundaries.parquet") + if "overlaps_nucleus" in tx.column_names: + col = tx.column("overlaps_nucleus") + print(f"overlaps_nucleus dtype: {col.type}") + unique_vals = pc.unique(col) + print(f"overlaps_nucleus unique values: {unique_vals.to_pylist()[:10]}") + val_counts = pc.value_counts(col) + print(f"overlaps_nucleus value_counts: {val_counts.to_pylist()}") + else: + print("WARNING: overlaps_nucleus column NOT FOUND in transcripts.parquet") + + if "cell_id" in tx.column_names and "cell_id" in bd.column_names: + tx_cells = set(pc.unique(tx.column("cell_id")).to_pylist()) + bd_cells = set(pc.unique(bd.column("cell_id")).to_pylist()) + overlap = tx_cells & bd_cells + print(f"Transcripts unique cell_ids: {len(tx_cells)}") + print(f"Boundaries unique cell_ids: {len(bd_cells)}") + print(f"Overlapping cell_ids: {len(overlap)}") + print("=== End Debug ===\n") + + +def run_segger_cli(args, extra): + cmd = [ + "python3", + SEGGER_CLI, + "--base_dir", + "bundle_stats", + "--data_dir", + args.output_dir, + "--sample_type", + args.sample_type, + "--tile_width", + str(args.tile_width), + "--tile_height", + str(args.tile_height), + "--n_workers", + str(args.n_workers), + *extra, + ] + print(f"Running: {' '.join(cmd)}") + result = subprocess.run(cmd) + if result.returncode != 0: + sys.exit(result.returncode) + + +def filter_trainable_tiles_if_needed(prefix): + """ + WORKAROUND: segger commit 0787167 has a bug where all tiles end up in test_tiles + regardless of test_prob/val_prob settings. Move ONLY trainable tiles (those with + edge_label_index) from test_tiles to train_tiles. + + Remove this function once segger >= 0.1.x is bumped with the upstream fix. + """ + train_dir = Path(prefix) / "train_tiles" / "processed" + test_dir = Path(prefix) / "test_tiles" / "processed" + val_dir = Path(prefix) / "val_tiles" / "processed" + + train_count = len(list(train_dir.iterdir())) if train_dir.exists() else 0 + test_count = len(list(test_dir.iterdir())) if test_dir.exists() else 0 + val_count = len(list(val_dir.iterdir())) if val_dir.exists() else 0 + print( + f"Dataset split (before fix): train={train_count} val={val_count} test={test_count}" + ) + + if train_count == 0 and test_count > 0: + print( + "Applying workaround: filtering trainable tiles from test_tiles (segger split bug)" + ) + moved = 0 + skipped = 0 + for tile_path in list(test_dir.iterdir()): + if not tile_path.name.endswith(".pt"): + continue + try: + tile = torch.load(str(tile_path), weights_only=False) + edge_store = tile["tx", "belongs", "bd"] + if ( + hasattr(edge_store, "edge_label_index") + and edge_store.edge_label_index.numel() > 0 + ): + shutil.move(str(tile_path), str(train_dir / tile_path.name)) + moved += 1 + else: + skipped += 1 + except Exception as e: + print(f"Warning: Could not process {tile_path.name}: {e}") + skipped += 1 + print(f"Moved {moved} trainable tiles to train_tiles") + print(f"Skipped {skipped} test-only tiles (no edge_label_index)") + + train_count = len(list(train_dir.iterdir())) if train_dir.exists() else 0 + test_count = len(list(test_dir.iterdir())) if test_dir.exists() else 0 + val_count = len(list(val_dir.iterdir())) if val_dir.exists() else 0 + print( + f"Dataset split (after fix): train={train_count} val={val_count} test={test_count}" + ) + + if train_count == 0: + print(f"ERROR: No trainable tiles were created in {train_dir}", file=sys.stderr) + print( + "This usually means no transcripts overlap with nucleus boundaries in the dataset.", + file=sys.stderr, + ) + print( + "Check if the Xenium bundle contains valid overlaps_nucleus data in transcripts.parquet.", + file=sys.stderr, + ) + sys.exit(1) + print(f"Successfully created {train_count} trainable tiles") + + +def fix_bd_x_nan(prefix): + """ + WORKAROUND: segger's get_polygon_props() produces NaN boundary features (bd.x) + when polygon geometries have zero area or index misalignment during GeoDataFrame + construction. Replace NaN bd.x with zeros so BCEWithLogitsLoss doesn't propagate NaN. + + Remove this function once segger >= 0.1.x is bumped with the upstream fix. + """ + fixed = 0 + total = 0 + for split in ["train_tiles", "test_tiles", "val_tiles"]: + tile_dir = Path(prefix) / split / "processed" + if not tile_dir.is_dir(): + continue + for tile_path in tile_dir.iterdir(): + if not tile_path.name.endswith(".pt"): + continue + total += 1 + tile = torch.load(str(tile_path), weights_only=False) + bd_x = tile["bd"].x + if bd_x.isnan().any(): + tile["bd"].x = torch.nan_to_num(bd_x, nan=0.0) + torch.save(tile, str(tile_path)) + fixed += 1 + print(f"Fixed NaN bd.x in {fixed}/{total} tiles") + + +def main(): + args, extra = parse_args() + + # Ensure numba cache dir is writable (env var should be set by caller, but belt-and-suspenders) + os.environ.setdefault("NUMBA_CACHE_DIR", os.path.join(os.getcwd(), ".numba_cache")) + os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True) + + prepare_bundle(args.bundle_dir) + print("Adding statistics to parquet files...") + add_parquet_stats() + + # Sanity-check bundle_stats + print("bundle_stats contents:") + for item in sorted(Path("bundle_stats").iterdir()): + print(f" {item.name}") + + run_segger_cli(args, extra) + + filter_trainable_tiles_if_needed(args.output_dir) + fix_bd_x_nan(args.output_dir) + + +if __name__ == "__main__": + main() diff --git a/modules/local/segger/predict/main.nf b/modules/local/segger/predict/main.nf index a7aec75d..0da7a594 100644 --- a/modules/local/segger/predict/main.nf +++ b/modules/local/segger/predict/main.nf @@ -25,52 +25,17 @@ process SEGGER_PREDICT { } def args = task.ext.args ?: '' - def script_path = "/workspace/segger_dev/src/segger/cli/predict_fast.py" prefix = task.ext.prefix ?: "${meta.id}" """ - # Limit cupy GPU memory to 80% so PyTorch has headroom for graph attention ops - export CUPY_GPU_MEMORY_LIMIT="80%" - # Belt-and-suspenders: ensure PyTorch uses expandable segments (also set in env {} block) - export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,max_split_size_mb:512" - - # Set numba cache directory to avoid caching issues in container - export NUMBA_CACHE_DIR=\$PWD/.numba_cache - mkdir -p \$NUMBA_CACHE_DIR - - # GPU detection logging - echo "=== GPU Detection (SEGGER_PREDICT) ===" - nvidia-smi 2>/dev/null && echo "GPU available: yes" || echo "GPU available: no (nvidia-smi failed)" - python3 -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'CUDA device count: {torch.cuda.device_count()}')" 2>/dev/null || echo "PyTorch CUDA check failed" - echo "======================================" - - # Use all available GPUs (autocast reduces VRAM ~50%, so multi-GPU is safe) - GPU_IDS=\$(python3 -c " -import torch -n = torch.cuda.device_count() -print(','.join(str(i) for i in range(n)) if n > 0 else '0') -" 2>/dev/null || echo "0") - echo "Using GPUs: \$GPU_IDS" - - # Patch predict_parquet.py at runtime (avoids Docker rebuild) - PRED_PY=\$(python3 -c "import segger.prediction.predict_parquet as m; print(m.__file__)") - - # 1. Add torch.no_grad() to disable gradient graphs during inference (~30-50% VRAM savings) - sed -i 's/with cp.cuda.Device(gpu_id):/with cp.cuda.Device(gpu_id), torch.no_grad():/' "\$PRED_PY" - - # 2. Seed random for deterministic GPU assignment (avoids stochastic OOM) - sed -i 's/gpu_id = random.choice(gpu_ids)/random.seed(0); gpu_id = random.choice(gpu_ids)/' "\$PRED_PY" - echo "Patched \$PRED_PY: torch.no_grad() + round-robin GPU assignment" - - python3 ${script_path} \\ - --models_dir ${models_dir} \\ - --segger_data_dir ${segger_dataset} \\ - --transcripts_file ${transcripts} \\ - --benchmarks_dir benchmarks_dir \\ - --batch_size ${params.batch_size_predict} \\ - --use_cc ${params.cc_analysis} \\ - --knn_method ${params.segger_knn_method} \\ - --num_workers ${task.cpus} \\ - --gpu_ids \$GPU_IDS \\ + run_predict.py \\ + --models-dir ${models_dir} \\ + --segger-data-dir ${segger_dataset} \\ + --transcripts-file ${transcripts} \\ + --benchmarks-dir benchmarks_dir \\ + --batch-size ${params.batch_size_predict} \\ + --use-cc ${params.cc_analysis} \\ + --knn-method ${params.segger_knn_method} \\ + --num-workers ${task.cpus} \\ ${args} """ diff --git a/modules/local/segger/predict/resources/usr/bin/run_predict.py b/modules/local/segger/predict/resources/usr/bin/run_predict.py new file mode 100755 index 00000000..56a77ffc --- /dev/null +++ b/modules/local/segger/predict/resources/usr/bin/run_predict.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +""" +Run segger predict with spatialxe-specific preprocessing. + +Wraps segger's predict_fast.py with: + - GPU enumeration (replaces inline python3 -c torch check) + - WORKAROUND: patch predict_parquet.py at runtime to add torch.no_grad() for ~30-50% VRAM savings + - WORKAROUND: seed random.choice for deterministic GPU assignment (avoids stochastic OOM) + +Both WORKAROUNDs should be removable once the patches are upstreamed to segger. +""" + +import argparse +import os +import subprocess +import sys + + +SEGGER_CLI = "/workspace/segger_dev/src/segger/cli/predict_fast.py" + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--models-dir", required=True) + p.add_argument("--segger-data-dir", required=True) + p.add_argument("--transcripts-file", required=True) + p.add_argument("--benchmarks-dir", required=True) + p.add_argument("--batch-size", type=int, required=True) + p.add_argument("--use-cc", required=True) + p.add_argument("--knn-method", required=True) + p.add_argument("--num-workers", type=int, required=True) + args, extra = p.parse_known_args() + return args, extra + + +def detect_gpus(): + """Return comma-separated list of available CUDA device ids (or "0" if none).""" + import torch + + print("=== GPU Detection (SEGGER_PREDICT) ===") + print(f"PyTorch CUDA available: {torch.cuda.is_available()}") + n = torch.cuda.device_count() + print(f"CUDA device count: {n}") + print("======================================") + if n > 0: + return ",".join(str(i) for i in range(n)) + return "0" + + +def patch_predict_parquet(): + """ + WORKAROUND: patch segger.prediction.predict_parquet at runtime. + + Avoids rebuilding the segger Docker image. Two patches: + 1. Add torch.no_grad() to disable gradient graphs during inference (~30-50% VRAM savings). + 2. Seed random for deterministic GPU assignment (avoids stochastic OOM). + + Remove this function once the patches are upstreamed to segger. + """ + import segger.prediction.predict_parquet as m + + pred_py = m.__file__ + print(f"Patching {pred_py}: torch.no_grad() + round-robin GPU assignment") + # Use sed via subprocess for in-place edit (matches the original behavior exactly) + subprocess.run( + [ + "sed", + "-i", + "s/with cp.cuda.Device(gpu_id):/with cp.cuda.Device(gpu_id), torch.no_grad():/", + pred_py, + ], + check=True, + ) + subprocess.run( + [ + "sed", + "-i", + "s/gpu_id = random.choice(gpu_ids)/random.seed(0); gpu_id = random.choice(gpu_ids)/", + pred_py, + ], + check=True, + ) + + +def run_segger_cli(args, extra, gpu_ids): + cmd = [ + "python3", + SEGGER_CLI, + "--models_dir", + args.models_dir, + "--segger_data_dir", + args.segger_data_dir, + "--transcripts_file", + args.transcripts_file, + "--benchmarks_dir", + args.benchmarks_dir, + "--batch_size", + str(args.batch_size), + "--use_cc", + str(args.use_cc), + "--knn_method", + args.knn_method, + "--num_workers", + str(args.num_workers), + "--gpu_ids", + gpu_ids, + *extra, + ] + print(f"Running: {' '.join(cmd)}") + result = subprocess.run(cmd) + if result.returncode != 0: + sys.exit(result.returncode) + + +def main(): + args, extra = parse_args() + + # Limit cupy GPU memory to 80% so PyTorch has headroom for graph attention ops + os.environ.setdefault("CUPY_GPU_MEMORY_LIMIT", "80%") + # Belt-and-suspenders: ensure PyTorch uses expandable segments + os.environ.setdefault( + "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:512" + ) + # Numba cache directory + os.environ.setdefault("NUMBA_CACHE_DIR", os.path.join(os.getcwd(), ".numba_cache")) + os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True) + + gpu_ids = detect_gpus() + print(f"Using GPUs: {gpu_ids}") + + patch_predict_parquet() + + run_segger_cli(args, extra, gpu_ids) + + +if __name__ == "__main__": + main() diff --git a/modules/local/utility/Dockerfile b/modules/local/utility/Dockerfile deleted file mode 100644 index 79213ebb..00000000 --- a/modules/local/utility/Dockerfile +++ /dev/null @@ -1,11 +0,0 @@ -FROM mambaorg/micromamba:1.5.10-noble -COPY --chown=$MAMBA_USER:$MAMBA_USER conda.yml /tmp/conda.yml -RUN micromamba install -y -n base -f /tmp/conda.yml \ - && micromamba install -y -n base conda-forge::procps-ng \ - && micromamba env export --name base --explicit > environment.lock \ - && echo ">> CONDA_LOCK_START" \ - && cat environment.lock \ - && echo "<< CONDA_LOCK_END" \ - && micromamba clean -a -y -USER root -ENV PATH="$MAMBA_ROOT_PREFIX/bin:$PATH" diff --git a/modules/local/xenium_patch/stitch/main.nf b/modules/local/xenium_patch/stitch/main.nf index 3d523971..5f050e23 100644 --- a/modules/local/xenium_patch/stitch/main.nf +++ b/modules/local/xenium_patch/stitch/main.nf @@ -33,11 +33,12 @@ process XENIUM_PATCH_STITCH { task.ext.when == null || task.ext.when script: + def args = task.ext.args ?: '' """ stitch_transcripts.py \\ --patches ${patches} \\ --output output \\ - --min-transcripts-per-cell ${params.baysor_tiling_min_transcripts_per_cell} + ${args} # Post-process: ensure all GeoJSON geometries are Polygon. # make_valid() and solve_conflicts() can produce MultiPolygon,