feat: support for multi-gpu inference and OOM datasets #53
Merged
Pull Request Overview
This PR enables scalable inference across multiple GPUs and adds an out-of-memory (OOM) safe map-style dataloader for very large .h5ad files. This addresses memory bottlenecks and processing speed limitations when working with large single-cell datasets.
- Multi-GPU inference support using DistributedDataParallel (DDP) with configurable GPU allocation
- OOM-safe map-style dataset that uses backed reads and per-row densification to handle large files
- New CLI flags for enabling OOM dataloader, specifying data workers, and controlling GPU usage
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| test/test_compare_umap.py | New test utility for comparing UMAPs between embeddings with Procrustes analysis |
| test/test_compare_emb.py | Updated embedding comparison to use "embeddings" key and added statistical tests |
| test/test_cli.py | Removed deprecated CLI test functions |
| src/transcriptformer/model/inference.py | Added multi-GPU support and OOM dataloader integration |
| src/transcriptformer/data/dataloader.py | Major refactoring with new AnnDatasetOOM class and extracted utility functions |
| src/transcriptformer/data/dataclasses.py | Updated InferenceConfig with num_gpus and use_oom_dataloader fields |
| src/transcriptformer/cli/conf/inference_config.yaml | Added new configuration options for multi-GPU and OOM handling |
| src/transcriptformer/cli/__init__.py | Complete CLI rewrite with direct inference execution instead of Hydra delegation |
| inference.py | Removed legacy inference script |
| download_artifacts.py | Removed legacy download script |
| README.md | Updated documentation with new CLI flags and usage examples |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Summary
This PR enables scalable inference across multiple GPUs and adds an out-of-memory (OOM) safe map-style dataloader for very large `.h5ad` files.

Key Changes
Multi-GPU inference (DDP)
- Sets `devices` and `accelerator=gpu` based on `inference_config.num_gpus`.
- `DistributedSampler` is used with the new map-style dataset when `devices > 1` (see the sketch below).
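
As a rough illustration (not the PR's actual code), a plain-PyTorch version of this wiring might look like the sketch below. Since the PR configures `devices`/`accelerator`, a framework such as PyTorch Lightning may handle most of this in practice; `ddp_inference` and its arguments are hypothetical names.

```python
# Hypothetical sketch only: plain-PyTorch DDP inference over a map-style
# dataset. The PR's actual wiring (devices/accelerator) may be delegated
# to a training framework rather than hand-rolled like this.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def ddp_inference(rank: int, world_size: int, model, dataset):
    # Assumes MASTER_ADDR/MASTER_PORT are set in the environment.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = DDP(model.to(rank), device_ids=[rank])

    # DistributedSampler gives each rank a disjoint shard of the map-style
    # dataset; shuffle=False keeps row order stable across workers.
    sampler = DistributedSampler(dataset, shuffle=False)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4)

    model.eval()
    with torch.no_grad():
        for batch in loader:
            _ = model(batch.to(rank, non_blocking=True))

    dist.destroy_process_group()
```

`shuffle=False` matters for inference: each rank's shard keeps a stable order, so per-row outputs can be stitched back together deterministically.
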
OOM-safe map-style dataloader
- `AnnDatasetOOM` opens the `.h5ad` in `backed='r'` mode.
- Rows are densified one at a time in `__getitem__`.
- Works with `DistributedSampler` (order-safe with multiple workers), as sketched below.
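
A minimal sketch of this pattern, assuming `anndata` and `scipy`; the real `AnnDatasetOOM` in `src/transcriptformer/data/dataloader.py` likely differs in naming, return types, and metadata handling, so treat this as illustrative only.

```python
# Illustrative sketch of a backed, map-style AnnData dataset (not the PR's
# exact AnnDatasetOOM). Only the rows touched in __getitem__ are loaded.
import anndata as ad
import numpy as np
import scipy.sparse as sp
import torch
from torch.utils.data import Dataset

class BackedH5adDataset(Dataset):
    def __init__(self, path: str):
        # backed='r' keeps X on disk instead of reading it into RAM.
        self.adata = ad.read_h5ad(path, backed="r")

    def __len__(self) -> int:
        return self.adata.n_obs

    def __getitem__(self, idx: int) -> torch.Tensor:
        row = self.adata.X[idx]
        # Densify one row at a time so peak memory stays O(n_vars).
        if sp.issparse(row):
            row = row.toarray()
        return torch.as_tensor(np.asarray(row).ravel(), dtype=torch.float32)
```

Because this is a map-style `Dataset` with a defined `__len__`, it composes directly with the `DistributedSampler` from the previous sketch.
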
CLI and config
- `--oom-dataloader`: enable the OOM-safe map-style dataloader.
- `--n-data-workers`: number of DataLoader workers per process.
- Added `use_oom_dataloader` to `InferenceConfig` and `inference_config.yaml`.
- `n_data_workers` flows from the CLI into `DataConfig` (see the sketch below).
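
A hypothetical `argparse` sketch of how these flags could flow into the config objects; the actual CLI in `src/transcriptformer/cli/__init__.py` may use a different parser and option plumbing.

```python
# Hypothetical wiring of the new flags (illustrative, not the PR's CLI code).
import argparse

parser = argparse.ArgumentParser(description="TranscriptFormer inference")
parser.add_argument("--oom-dataloader", action="store_true",
                    help="Use the OOM-safe map-style dataloader.")
parser.add_argument("--n-data-workers", type=int, default=4,
                    help="Number of DataLoader workers per process.")
args = parser.parse_args()

# Per the PR description, the values end up in the config dataclasses:
#   args.oom_dataloader -> InferenceConfig.use_oom_dataloader
#   args.n_data_workers -> DataConfig.n_data_workers
```
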
Sparse-aware data handling
- `get_counts_layer` safely selects `raw.X` or `X`, with clear logging.
- `is_raw_counts` supports sparse inputs by sampling non-zero data.
- `AnnDataset` explicitly densifies before batch processing to satisfy downstream expectations.
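
To make this concrete, here is a minimal sketch of what the two helpers might look like; the actual implementations may differ, and the non-negative-integer heuristic in `is_raw_counts` is an assumption on my part.

```python
# Illustrative sketches of the sparse-aware helpers (not the PR's exact code).
import logging
import numpy as np
import scipy.sparse as sp

def get_counts_layer(adata):
    """Prefer raw counts when present, otherwise fall back to X."""
    if adata.raw is not None:
        logging.info("Using adata.raw.X as the counts layer.")
        return adata.raw.X
    logging.info("No adata.raw found; using adata.X as the counts layer.")
    return adata.X

def is_raw_counts(X, n_sample: int = 1000) -> bool:
    """Heuristic check on a sample of non-zero entries (sparse-safe)."""
    # For sparse inputs, .data exposes stored values without densifying.
    values = X.data if sp.issparse(X) else np.asarray(X).ravel()
    values = values[values != 0]
    if values.size == 0:
        return True
    sample = np.random.choice(values, size=min(n_sample, values.size),
                              replace=False)
    # Assumed heuristic: raw counts are non-negative integers.
    return bool(np.all(sample >= 0) and np.allclose(sample, np.round(sample)))
```
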
Testing
- Embedding and UMAP comparison tests (`test/test_compare_emb.py`, `test/test_compare_umap.py`) validate results on `.h5ad` data.
Documentation
- Updated `README.md` with the new CLI flags (`--oom-dataloader`, `--n-data-workers`) and usage examples.