From b8182d8c6100007e18a79243634151a86cbe4a1f Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Fri, 31 May 2024 09:09:07 -0400 Subject: [PATCH 01/18] ExperimentDataPipe: configurable array-conversion method `np.array`, `scipy.coo`, `scipy.csr` --- .../experimental/ml/pytorch.py | 131 ++++++++++++++---- 1 file changed, 104 insertions(+), 27 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 6bf9aa30c..7748490ed 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -1,24 +1,26 @@ import gc import logging import os +import typing from contextlib import contextmanager from datetime import timedelta from math import ceil from time import time -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence, Tuple, Union import numpy as np import numpy.typing as npt import pandas as pd import psutil import pyarrow as pa -import scipy import tiledbsoma as soma import torch import torchdata.datapipes.iter as pipes from attr import define from numpy.random import Generator +from pyarrow import Table from scipy import sparse +from scipy.sparse import coo_matrix, csr_matrix from sklearn.preprocessing import LabelEncoder from torch import Tensor from torch import distributed as dist @@ -36,6 +38,14 @@ The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" +# Various "methods" for converting from TileDB COO (on disk) to `torch.Tensor` +Method = Literal["np.array", "scipy.csr", "scipy.coo"] + + +# "Chunk" of X data, returned by each `Method` above +ChunkX = Union[np.array, csr_matrix, coo_matrix] + + @define class _SOMAChunk: """Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the respective rows from the ``X`` @@ -46,7 +56,7 @@ class _SOMAChunk: """ obs: pd.DataFrame - X: scipy.sparse.spmatrix + X: ChunkX stats: "Stats" def __len__(self) -> int: @@ -72,7 +82,7 @@ class Stats: nnz: int = 0 """The total number of values retrieved""" - elapsed: int = 0 + elapsed: float = 0 """The total elapsed time in seconds for retrieving all batches""" n_soma_chunks: int = 0 @@ -101,6 +111,19 @@ def _open_experiment( yield exp +def tables_to_np( + tables: Iterator[Tuple[Table, any]], shape: Tuple[int, int] +) -> typing.Generator[Tuple[np.ndarray, any, int], None, None]: + for tbl, indices in tables: + row_indices_np = np.array(tbl.columns[0]) + col_indices_np = np.array(tbl.columns[1]) + data_np = np.array(tbl.columns[2]) + nnz = len(data_np) + dense_matrix = np.zeros(shape, dtype=data_np.dtype) + dense_matrix[row_indices_np, col_indices_np] = data_np + yield dense_matrix, indices, nnz + + class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]): """Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class, not intended for public use. 
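Review note (not part of the patch): the new `tables_to_np` helper is the core of the `np.array` method. It scatters the Arrow COO triples from each blockwise read straight into a preallocated dense block, skipping any intermediate scipy matrix. A minimal standalone sketch of that scatter with toy values; only the row/col/data column layout is taken from the patch, the numbers are illustrative:

```python
import numpy as np
import pyarrow as pa

# Toy stand-in for one blockwise-read table: COO triples in three columns.
tbl = pa.table({"row": [0, 0, 2], "col": [1, 3, 2], "data": [5.0, 7.0, 9.0]})
shape = (3, 4)  # (rows in this obs chunk, len(var_joinids))

row_idx = np.array(tbl.columns[0])
col_idx = np.array(tbl.columns[1])
data = np.array(tbl.columns[2])

dense = np.zeros(shape, dtype=data.dtype)
# Scatter the nonzeros; this assumes each (row, col) pair occurs once, as it
# does in a TileDB sparse read.
dense[row_idx, col_idx] = data
assert dense[0, 1] == 5.0 and int((dense != 0).sum()) == len(data)
```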
@@ -123,6 +146,7 @@ def __init__( var_joinids: npt.NDArray[np.int64], shuffle_chunk_count: Optional[int] = None, shuffle_rng: Optional[Generator] = None, + method: Method = "scipy.csr", ): self.obs = obs self.X = X @@ -145,6 +169,7 @@ def __init__( self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) self.var_joinids = var_joinids self.shuffle_chunk_count = shuffle_chunk_count + self.method = method def __next__(self) -> _SOMAChunk: pytorch_logger.debug("Retrieving next SOMA chunk...") @@ -173,18 +198,33 @@ def __next__(self) -> _SOMAChunk: # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix, # but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block) - scipy_iter = ( - self.X.read(coords=(obs_joinids_chunk, self.var_joinids)) - .blockwise(axis=0, size=len(obs_joinids_chunk), eager=False) - .scipy(compress=True) + blockwise_iter = self.X.read(coords=(obs_joinids_chunk, self.var_joinids)).blockwise( + axis=0, size=len(obs_joinids_chunk), eager=False ) - X_batch, _ = next(scipy_iter) + + method = self.method + if method == "np.array": + batch_iter = tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))) + elif method == "scipy.coo": + batch_iter = blockwise_iter.scipy(compress=False) + elif method == "scipy.csr": + batch_iter = blockwise_iter.scipy(compress=True) + else: + raise ValueError(f"Invalid format: {method}") + + res = next(batch_iter) + X_batch: ChunkX = res[0] + if isinstance(X_batch, np.ndarray): + nnz = res[2] + else: + nnz = X_batch.nnz assert obs_batch.shape[0] == X_batch.shape[0] + end_time = time() stats = Stats() stats.n_obs += X_batch.shape[0] - stats.nnz += X_batch.nnz - stats.elapsed += int(time() - start_time) + stats.nnz += nnz + stats.elapsed += end_time - start_time stats.n_soma_chunks += 1 pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}") @@ -208,17 +248,19 @@ def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]: return result -def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any]]: # noqa: D103 +def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any], float]: # noqa: D103 proc = psutil.Process(os.getpid()) pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() + start = time() gc.collect() + gc_elapsed = time() - start post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() pytorch_logger.debug(f"gc: pre={pre_gc}") pytorch_logger.debug(f"gc: post={post_gc}") - return pre_gc, post_gc + return pre_gc, post_gc, gc_elapsed class _ObsAndXIterator(Iterator[ObsAndXDatum]): @@ -254,9 +296,17 @@ def __init__( use_eager_fetch: bool, shuffle_chunk_count: Optional[int] = None, shuffle_rng: Optional[Generator] = None, + method: Method = "scipy.csr", ) -> None: self.soma_chunk_iter = _ObsAndXSOMAIterator( - obs, X, obs_column_names, obs_joinids_chunked, var_joinids, shuffle_chunk_count, shuffle_rng + obs, + X, + obs_column_names, + obs_joinids_chunked, + var_joinids, + shuffle_chunk_count, + shuffle_rng, + method=method, ) if use_eager_fetch: self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter) @@ -266,19 +316,26 @@ def __init__( self.return_sparse_X = return_sparse_X self.encoders = encoders self.stats = stats + self.gc_elapsed = 0 self.max_process_mem_usage_bytes = 0 self.X_dtype = X.schema[2].type.to_pandas_dtype() def __next__(self) -> ObsAndXDatum: """Read the next torch batch, possibly across multiple soma chunks.""" obs: pd.DataFrame = 
pd.DataFrame() - X: sparse.csr_matrix = sparse.csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype) + X: ChunkX = csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype) + first = True while len(obs) < self.batch_size: try: obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs)) - obs = pd.concat([obs, obs_partial], axis=0) - X = sparse.vstack([X, X_partial]) + if first: + obs = obs_partial + X = X_partial + first = False + else: + obs = pd.concat([obs, obs_partial], axis=0) + X = sparse.vstack([X, X_partial]) except StopIteration: break @@ -298,7 +355,9 @@ def __next__(self) -> ObsAndXDatum: obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) if not self.return_sparse_X: - X_tensor = torch.from_numpy(X.todense()) + if isinstance(X, (csr_matrix, coo_matrix)): + X = X.todense() + X_tensor = torch.from_numpy(X) else: coo = X.tocoo() @@ -315,7 +374,7 @@ def __next__(self) -> ObsAndXDatum: return X_tensor, obs_tensor - def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum: + def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, ChunkX]: """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size batch whose size may contain fewer rows than the requested ``batch_size``. This can happen when the remaining rows in the current SOMA chunk are fewer than the requested ``batch_size``. @@ -323,17 +382,20 @@ def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum: if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)): # GC memory from previous soma_chunk self.soma_chunk = None - mem_info = run_gc() - self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, mem_info[0][0].uss) + pre_gc, _, gc_elapsed = run_gc() + self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, pre_gc[0].uss) self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter) self.stats += self.soma_chunk.stats + self.gc_elapsed += gc_elapsed self.i = 0 - pytorch_logger.debug(f"Retrieved SOMA chunk totals: {self.stats}") + pytorch_logger.debug( + f"Retrieved SOMA chunk totals: {self.stats}, gc_elapsed={timedelta(seconds=self.gc_elapsed)}" + ) obs_batch = self.soma_chunk.obs - X_batch = self.soma_chunk.X + X_chunk = self.soma_chunk.X safe_batch_size = min(batch_size, len(obs_batch) - self.i) slice_ = slice(self.i, self.i + safe_batch_size) @@ -343,12 +405,22 @@ def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum: assert obs_rows.index.is_unique assert safe_batch_size == obs_rows.shape[0] - X_csr_scipy = X_batch[slice_] - assert obs_rows.shape[0] == X_csr_scipy.shape[0] + if isinstance(X_chunk, coo_matrix): + start = np.searchsorted(X_chunk.row, slice_.start) + stop = np.searchsorted(X_chunk.row, slice_.stop) + data = X_chunk.data[start:stop] + row = X_chunk.row[start:stop] - slice_.start + col = X_chunk.col[start:stop] + shape = (slice_.stop - slice_.start, X_chunk.shape[1]) + X_batch = coo_matrix((data, (row, col)), shape=shape) + else: + X_batch = X_chunk[slice_] + + assert obs_rows.shape[0] == X_batch.shape[0] self.i += safe_batch_size - return obs_rows, X_csr_scipy + return obs_rows, X_batch class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ignore @@ -419,6 +491,7 @@ def __init__( soma_chunk_size: Optional[int] = 64, use_eager_fetch: bool = True, shuffle_chunk_count: Optional[int] = 2000, + method: Method = "scipy.csr", ) -> None: r"""Construct a new ``ExperimentDataPipe``. 
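Review note (not part of the patch): the `scipy.coo` branch in `_read_partial_torch_batch` above slices a row window out of a COO chunk with `np.searchsorted`, which is only valid because the blockwise read yields `X_chunk.row` in sorted order. A self-contained sketch of that slice, with a hypothetical helper name and toy data:

```python
import numpy as np
from scipy.sparse import coo_matrix

def slice_coo_rows(X: coo_matrix, start: int, stop: int) -> coo_matrix:
    """Slice rows [start, stop) of a COO matrix whose .row array is sorted."""
    lo = np.searchsorted(X.row, start)  # first entry with row >= start
    hi = np.searchsorted(X.row, stop)   # first entry with row >= stop
    return coo_matrix(
        (X.data[lo:hi], (X.row[lo:hi] - start, X.col[lo:hi])),
        shape=(stop - start, X.shape[1]),
    )

X = coo_matrix((np.ones(4), ([0, 1, 1, 3], [2, 0, 1, 3])), shape=(4, 4))
sub = slice_coo_rows(X, 1, 3)  # keep rows 1 and 2
assert sub.shape == (2, 4) and sub.nnz == 2
```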
@@ -498,6 +571,8 @@ def __init__( self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None self._shuffle_rng = np.random.default_rng(seed) if shuffle else None self._initialized = False + self.method = method + self.max_process_mem_usage_bytes = 0 if "soma_joinid" not in self.obs_column_names: self.obs_column_names = ["soma_joinid", *self.obs_column_names] @@ -613,12 +688,14 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: use_eager_fetch=self.use_eager_fetch, shuffle_rng=self._shuffle_rng, shuffle_chunk_count=self._shuffle_chunk_count, + method=self.method, ) yield from obs_and_x_iter + self.max_process_mem_usage_bytes = obs_and_x_iter.max_process_mem_usage_bytes pytorch_logger.debug( - "max process memory usage=" f"{obs_and_x_iter.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" + "max process memory usage=" f"{self.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB" ) @staticmethod From f7f6121080f2e9cdc9cf0f52578f46e4796e36bf Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 6 Jun 2024 11:44:48 -0400 Subject: [PATCH 02/18] group chunks iff `shuffle_chunk_count > 1` --- .../src/cellxgene_census/experimental/ml/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 7748490ed..fae429f24 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -151,7 +151,7 @@ def __init__( self.obs = obs self.X = X self.obs_column_names = obs_column_names - if shuffle_chunk_count: + if shuffle_chunk_count > 1: assert shuffle_rng is not None # At the start of this step, `obs_joinids_chunked` is a list of one dimensional From 9655f8bea00702a5e94bd65db00c3946bba9086b Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 6 Jun 2024 16:14:48 -0400 Subject: [PATCH 03/18] rm scipy.coo method --- .../src/cellxgene_census/experimental/ml/pytorch.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index fae429f24..93ce5d7f1 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -39,11 +39,11 @@ # Various "methods" for converting from TileDB COO (on disk) to `torch.Tensor` -Method = Literal["np.array", "scipy.csr", "scipy.coo"] +Method = Literal["np.array", "scipy.csr"] # "Chunk" of X data, returned by each `Method` above -ChunkX = Union[np.array, csr_matrix, coo_matrix] +ChunkX = Union[np.array, csr_matrix] @define @@ -205,8 +205,6 @@ def __next__(self) -> _SOMAChunk: method = self.method if method == "np.array": batch_iter = tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))) - elif method == "scipy.coo": - batch_iter = blockwise_iter.scipy(compress=False) elif method == "scipy.csr": batch_iter = blockwise_iter.scipy(compress=True) else: @@ -355,7 +353,7 @@ def __next__(self) -> ObsAndXDatum: obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) if not self.return_sparse_X: - if isinstance(X, (csr_matrix, coo_matrix)): + if isinstance(X, csr_matrix): X = X.todense() X_tensor = torch.from_numpy(X) else: From 54ffe73b2cff43e7c4b445db7999717125d57a23 Mon Sep 17 00:00:00 
2001 From: Ryan Williams Date: Fri, 7 Jun 2024 10:06:35 -0400 Subject: [PATCH 04/18] add `METHODS` constant --- .../src/cellxgene_census/experimental/ml/pytorch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 93ce5d7f1..cc5faa2ba 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -40,6 +40,7 @@ # Various "methods" for converting from TileDB COO (on disk) to `torch.Tensor` Method = Literal["np.array", "scipy.csr"] +METHODS = ["np.array", "scipy.csr"] # "Chunk" of X data, returned by each `Method` above From 2fd1b142cdef6abd28e3668af0b94817610cd498 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Wed, 3 Jul 2024 23:07:07 +0000 Subject: [PATCH 05/18] Use return_sparse_X instead of adding a method kwarg --- .../cellxgene_census/experimental/ml/pytorch.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index cc5faa2ba..8b80797b8 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -147,7 +147,7 @@ def __init__( var_joinids: npt.NDArray[np.int64], shuffle_chunk_count: Optional[int] = None, shuffle_rng: Optional[Generator] = None, - method: Method = "scipy.csr", + return_sparse_X: bool = False, ): self.obs = obs self.X = X @@ -170,7 +170,7 @@ def __init__( self.obs_joinids_chunks_iter = iter(obs_joinids_chunked) self.var_joinids = var_joinids self.shuffle_chunk_count = shuffle_chunk_count - self.method = method + self.return_sparse_X = return_sparse_X def __next__(self) -> _SOMAChunk: pytorch_logger.debug("Retrieving next SOMA chunk...") @@ -203,13 +203,10 @@ def __next__(self) -> _SOMAChunk: axis=0, size=len(obs_joinids_chunk), eager=False ) - method = self.method - if method == "np.array": + if not self.return_sparse_X: batch_iter = tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))) - elif method == "scipy.csr": - batch_iter = blockwise_iter.scipy(compress=True) else: - raise ValueError(f"Invalid format: {method}") + batch_iter = blockwise_iter.scipy(compress=True) res = next(batch_iter) X_batch: ChunkX = res[0] @@ -295,7 +292,6 @@ def __init__( use_eager_fetch: bool, shuffle_chunk_count: Optional[int] = None, shuffle_rng: Optional[Generator] = None, - method: Method = "scipy.csr", ) -> None: self.soma_chunk_iter = _ObsAndXSOMAIterator( obs, @@ -305,7 +301,7 @@ def __init__( var_joinids, shuffle_chunk_count, shuffle_rng, - method=method, + return_sparse_X=return_sparse_X, ) if use_eager_fetch: self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter) @@ -490,7 +486,6 @@ def __init__( soma_chunk_size: Optional[int] = 64, use_eager_fetch: bool = True, shuffle_chunk_count: Optional[int] = 2000, - method: Method = "scipy.csr", ) -> None: r"""Construct a new ``ExperimentDataPipe``. 
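Aside (not part of the patch): after this change the conversion strategy is no longer a public knob. The existing `return_sparse_X` flag selects the path: dense ndarray chunks via `tables_to_np` when `False`, CSR chunks via `blockwise(...).scipy(compress=True)` when `True`. A hedged usage sketch; the Census access pattern below is assumed from the library's public API and is not part of this diff:

```python
import cellxgene_census
from cellxgene_census.experimental.ml import ExperimentDataPipe

census = cellxgene_census.open_soma()
dp = ExperimentDataPipe(
    census["census_data"]["homo_sapiens"],
    measurement_name="RNA",
    X_name="raw",
    obs_column_names=["cell_type"],
    batch_size=128,
    return_sparse_X=False,  # False: dense ndarray chunks; True: CSR chunks
)
X, obs = next(iter(dp))  # X is a dense torch.Tensor on this path
census.close()
```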
@@ -687,7 +682,6 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: use_eager_fetch=self.use_eager_fetch, shuffle_rng=self._shuffle_rng, shuffle_chunk_count=self._shuffle_chunk_count, - method=self.method, ) yield from obs_and_x_iter From b74bf67938e5563e0e4364240c47bffd9157ba5c Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Wed, 3 Jul 2024 23:32:06 +0000 Subject: [PATCH 06/18] Oops, one more thing --- .../src/cellxgene_census/experimental/ml/pytorch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 8b80797b8..4a2f42530 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -565,7 +565,6 @@ def __init__( self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None self._shuffle_rng = np.random.default_rng(seed) if shuffle else None self._initialized = False - self.method = method self.max_process_mem_usage_bytes = 0 if "soma_joinid" not in self.obs_column_names: From 664812a0ad6fee9aa535365c456ceab36502131b Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Wed, 3 Jul 2024 23:45:13 +0000 Subject: [PATCH 07/18] revert change to shuffle_chunk_count branch --- .../src/cellxgene_census/experimental/ml/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 4a2f42530..eec95f70f 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -152,7 +152,7 @@ def __init__( self.obs = obs self.X = X self.obs_column_names = obs_column_names - if shuffle_chunk_count > 1: + if shuffle_chunk_count is not None: assert shuffle_rng is not None # At the start of this step, `obs_joinids_chunked` is a list of one dimensional From 629af82c104959591df283baf270c04b7048105e Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Wed, 3 Jul 2024 23:55:34 +0000 Subject: [PATCH 08/18] mypy + linting fixes --- .../experimental/ml/pytorch.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index eec95f70f..1dc04a036 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -6,7 +6,7 @@ from datetime import timedelta from math import ceil from time import time -from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union import numpy as np import numpy.typing as npt @@ -38,13 +38,8 @@ The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2.""" -# Various "methods" for converting from TileDB COO (on disk) to `torch.Tensor` -Method = Literal["np.array", "scipy.csr"] -METHODS = ["np.array", "scipy.csr"] - - # "Chunk" of X data, returned by each `Method` above -ChunkX = Union[np.array, csr_matrix] +ChunkX = Union[npt.NDArray[Any], csr_matrix] @define @@ -112,9 +107,9 @@ def _open_experiment( yield exp -def tables_to_np( - 
tables: Iterator[Tuple[Table, any]], shape: Tuple[int, int] -) -> typing.Generator[Tuple[np.ndarray, any, int], None, None]: +def _tables_to_np( + tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int] +) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]: for tbl, indices in tables: row_indices_np = np.array(tbl.columns[0]) col_indices_np = np.array(tbl.columns[1]) @@ -204,7 +199,7 @@ def __next__(self) -> _SOMAChunk: ) if not self.return_sparse_X: - batch_iter = tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))) + batch_iter = _tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))) else: batch_iter = blockwise_iter.scipy(compress=True) @@ -311,7 +306,7 @@ def __init__( self.return_sparse_X = return_sparse_X self.encoders = encoders self.stats = stats - self.gc_elapsed = 0 + self.gc_elapsed = 0.0 self.max_process_mem_usage_bytes = 0 self.X_dtype = X.schema[2].type.to_pandas_dtype() @@ -354,7 +349,7 @@ def __next__(self) -> ObsAndXDatum: X = X.todense() X_tensor = torch.from_numpy(X) else: - coo = X.tocoo() + coo = X.tocoo() # type: ignore X_tensor = torch.sparse_coo_tensor( # Note: The `np.array` seems unnecessary, but PyTorch warns bare array is "extremely slow" From 8bab49a8267cfb29e9a27bf866fa44934a4d2429 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Thu, 4 Jul 2024 03:10:31 +0000 Subject: [PATCH 09/18] Add test for bug I found --- .../tests/experimental/ml/test_pytorch.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py index dd524a691..872037ba9 100644 --- a/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py +++ b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py @@ -152,6 +152,39 @@ def test_non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None assert row[1].tolist() == [0, 0] +@pytest.mark.experimental +# noinspection PyTestParametrized +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +def test_uneven_soma_and_result_batches( + soma_experiment: Experiment, use_eager_fetch: bool, return_sparse_X: bool +) -> None: + """This is checking that batches are correctly created when they require fetching multiple chunks. + + This was added due to failures in _ObsAndXIterator.__next__. 
+ """ + exp_data_pipe = ExperimentDataPipe( + soma_experiment, + measurement_name="RNA", + X_name="raw", + obs_column_names=["label"], + shuffle=False, + batch_size=3, + soma_chunk_size=2, + return_sparse_X=return_sparse_X, + use_eager_fetch=use_eager_fetch, + ) + row_iter = iter(exp_data_pipe) + + row = next(row_iter) + X_batch = row[0].to_dense() if return_sparse_X else row[0] + assert X_batch.int()[0].tolist() == [0, 1, 0] + assert row[1].tolist() == [[0, 0], [1, 1], [2, 2]] + + @pytest.mark.experimental # noinspection PyTestParametrized,DuplicatedCode @pytest.mark.parametrize( From 4c042ee677e7868c6e07eeace5da877b388ba325 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Thu, 4 Jul 2024 03:39:06 +0000 Subject: [PATCH 10/18] Fix bug, possibly speed up dataloaders --- .../experimental/ml/pytorch.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 1dc04a036..bee45d49e 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -312,25 +312,27 @@ def __init__( def __next__(self) -> ObsAndXDatum: """Read the next torch batch, possibly across multiple soma chunks.""" - obs: pd.DataFrame = pd.DataFrame() - X: ChunkX = csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype) - first = True + obss: list[pd.DataFrame] = [] + Xs: list[ChunkX] = [] + n_obs = 0 - while len(obs) < self.batch_size: + while n_obs < self.batch_size: try: - obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs)) - if first: - obs = obs_partial - X = X_partial - first = False - else: - obs = pd.concat([obs, obs_partial], axis=0) - X = sparse.vstack([X, X_partial]) + obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - n_obs) + n_obs += len(obs_partial) + obss.append(obs_partial) + Xs.append(X_partial) except StopIteration: break - if len(obs) == 0: + if len(Xs) == 0: # If we ran out of data raise StopIteration + else: + if self.return_sparse_X: + X = sparse.vstack(Xs) + else: + X = np.concatenate(Xs, axis=0) + obs = pd.concat(obss, axis=0) obs_encoded = pd.DataFrame( data={"soma_joinid": obs.index}, @@ -349,7 +351,7 @@ def __next__(self) -> ObsAndXDatum: X = X.todense() X_tensor = torch.from_numpy(X) else: - coo = X.tocoo() # type: ignore + coo = X.tocoo() X_tensor = torch.sparse_coo_tensor( # Note: The `np.array` seems unnecessary, but PyTorch warns bare array is "extremely slow" From 9d5cb738c54c1e90da018a94906e5e29d2ddf247 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Thu, 4 Jul 2024 03:42:14 +0000 Subject: [PATCH 11/18] Remove unused (I think) branch --- .../src/cellxgene_census/experimental/ml/pytorch.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index bee45d49e..d28fce935 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -347,8 +347,6 @@ def __next__(self) -> ObsAndXDatum: obs_tensor = torch.from_numpy(obs_encoded.to_numpy()) if not self.return_sparse_X: - if isinstance(X, csr_matrix): - X = X.todense() X_tensor = torch.from_numpy(X) else: coo = X.tocoo() From 
295b30828a347a649428583236759dddf079fe55 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Thu, 4 Jul 2024 04:30:15 +0000 Subject: [PATCH 12/18] Remove more unused code --- .../cellxgene_census/experimental/ml/pytorch.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index d28fce935..c42a29659 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -20,7 +20,6 @@ from numpy.random import Generator from pyarrow import Table from scipy import sparse -from scipy.sparse import coo_matrix, csr_matrix from sklearn.preprocessing import LabelEncoder from torch import Tensor from torch import distributed as dist @@ -39,7 +38,7 @@ # "Chunk" of X data, returned by each `Method` above -ChunkX = Union[npt.NDArray[Any], csr_matrix] +ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix] @define @@ -395,16 +394,7 @@ def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, Chun assert obs_rows.index.is_unique assert safe_batch_size == obs_rows.shape[0] - if isinstance(X_chunk, coo_matrix): - start = np.searchsorted(X_chunk.row, slice_.start) - stop = np.searchsorted(X_chunk.row, slice_.stop) - data = X_chunk.data[start:stop] - row = X_chunk.row[start:stop] - slice_.start - col = X_chunk.col[start:stop] - shape = (slice_.stop - slice_.start, X_chunk.shape[1]) - X_batch = coo_matrix((data, (row, col)), shape=shape) - else: - X_batch = X_chunk[slice_] + X_batch = X_chunk[slice_] assert obs_rows.shape[0] == X_batch.shape[0] From 137f77d8ff7aa5b3001ac2fd0b42bf059df807ec Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 8 Jul 2024 17:29:19 +0000 Subject: [PATCH 13/18] Slightly more precise typing --- .../src/cellxgene_census/experimental/ml/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index c42a29659..9d0154b6a 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -38,7 +38,7 @@ # "Chunk" of X data, returned by each `Method` above -ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix] +ChunkX = Union[npt.NDArray[np.number[Any]], sparse.csr_matrix] @define From 4ac851153de7b817ec547340a5299f1db846a2d8 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 8 Jul 2024 17:35:16 +0000 Subject: [PATCH 14/18] Reduce additional copies in _tables_to_np --- .../src/cellxgene_census/experimental/ml/pytorch.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 9d0154b6a..5612566eb 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -110,12 +110,10 @@ def _tables_to_np( tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int] ) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]: for tbl, indices in tables: - row_indices_np = np.array(tbl.columns[0]) - col_indices_np = 
np.array(tbl.columns[1]) - data_np = np.array(tbl.columns[2]) - nnz = len(data_np) - dense_matrix = np.zeros(shape, dtype=data_np.dtype) - dense_matrix[row_indices_np, col_indices_np] = data_np + row_indices, col_indices, data = (x.to_numpy() for x in tbl.columns) + nnz = len(data) + dense_matrix = np.zeros(shape, dtype=data.dtype) + dense_matrix[row_indices, col_indices] = data yield dense_matrix, indices, nnz From f9434042cb1d8d4a7cdd431ac1e6989b55414f17 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 8 Jul 2024 18:09:18 +0000 Subject: [PATCH 15/18] Simplify branching in _ObsAndXSOMAIterator --- .../src/cellxgene_census/experimental/ml/pytorch.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 5612566eb..c806c5f23 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -195,17 +195,14 @@ def __next__(self) -> _SOMAChunk: axis=0, size=len(obs_joinids_chunk), eager=False ) + X_batch: ChunkX if not self.return_sparse_X: - batch_iter = _tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))) - else: - batch_iter = blockwise_iter.scipy(compress=True) - - res = next(batch_iter) - X_batch: ChunkX = res[0] - if isinstance(X_batch, np.ndarray): - nnz = res[2] + res = next(_tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids)))) + X_batch, nnz = res[0], res[2] else: + X_batch = next(blockwise_iter.scipy(compress=True))[0] nnz = X_batch.nnz + assert obs_batch.shape[0] == X_batch.shape[0] end_time = time() From 357582ece6766ca4562c61a148ce4a3636236370 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 8 Jul 2024 20:24:43 +0000 Subject: [PATCH 16/18] typing fix --- .../experimental/ml/pytorch.py | 70 ++++++++++--------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 5d00426df..5d83aa5ab 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import gc import itertools import logging @@ -7,7 +9,7 @@ from datetime import timedelta from math import ceil from time import time -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, Sequence, Tuple, Union import numpy as np import numpy.typing as npt @@ -52,7 +54,7 @@ class _SOMAChunk: obs: pd.DataFrame X: ChunkX - stats: "Stats" + stats: Stats def __len__(self) -> int: return len(self.obs) @@ -86,7 +88,7 @@ class Stats: def __str__(self) -> str: return f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " f"elapsed={timedelta(seconds=self.elapsed)}" - def __add__(self, other: "Stats") -> "Stats": + def __add__(self, other: Stats) -> Stats: self.n_obs += other.n_obs self.nnz += other.nnz self.elapsed += other.elapsed @@ -97,7 +99,7 @@ def __add__(self, other: "Stats") -> "Stats": @contextmanager def _open_experiment( uri: str, - aws_region: Optional[str] = None, + aws_region: str | None = None, ) -> soma.Experiment: """Internal method for opening a SOMA ``Experiment`` as 
a context manager.""" context = get_default_soma_context().replace(tiledb_config={"vfs.s3.region": aws_region} if aws_region else {}) @@ -107,8 +109,8 @@ def _open_experiment( def _tables_to_np( - tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int] -) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]: + tables: Iterator[tuple[Table, Any]], shape: tuple[int, int] +) -> typing.Generator[tuple[npt.NDArray[Any], Any, int], None, None]: for tbl, indices in tables: row_indices, col_indices, data = (x.to_numpy() for x in tbl.columns) nnz = len(data) @@ -135,10 +137,10 @@ def __init__( obs: soma.DataFrame, X: soma.SparseNDArray, obs_column_names: Sequence[str], - obs_joinids_chunked: List[npt.NDArray[np.int64]], + obs_joinids_chunked: list[npt.NDArray[np.int64]], var_joinids: npt.NDArray[np.int64], - shuffle_chunk_count: Optional[int] = None, - shuffle_rng: Optional[Generator] = None, + shuffle_chunk_count: int | None = None, + shuffle_rng: Generator | None = None, return_sparse_X: bool = False, ): self.obs = obs @@ -221,7 +223,7 @@ def __next__(self) -> _SOMAChunk: return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats) -def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]: +def list_split(arr_list: list[Any], sublist_len: int) -> list[list[Any]]: """Splits a python list into a list of sublists where each sublist is of size `sublist_len`. TODO: Replace with `itertools.batched` when Python 3.12 becomes the minimum supported version. """ @@ -238,7 +240,7 @@ def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]: return result -def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any], float]: # noqa: D103 +def run_gc() -> tuple[tuple[Any, Any, Any], tuple[Any, Any, Any], float]: # noqa: D103 proc = psutil.Process(os.getpid()) pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() @@ -266,7 +268,7 @@ class _ObsAndXIterator(Iterator[ObsAndXDatum]): soma_chunk_iter: Iterator[_SOMAChunk] """The iterator for SOMA chunks of paired obs and X data""" - soma_chunk: Optional[_SOMAChunk] + soma_chunk: _SOMAChunk | None """The current SOMA chunk of obs and X data""" i: int = -1 @@ -277,15 +279,15 @@ def __init__( obs: soma.DataFrame, X: soma.SparseNDArray, obs_column_names: Sequence[str], - obs_joinids_chunked: List[npt.NDArray[np.int64]], + obs_joinids_chunked: list[npt.NDArray[np.int64]], var_joinids: npt.NDArray[np.int64], batch_size: int, - encoders: List[Encoder], + encoders: list[Encoder], stats: Stats, return_sparse_X: bool, use_eager_fetch: bool, - shuffle_chunk_count: Optional[int] = None, - shuffle_rng: Optional[Generator] = None, + shuffle_chunk_count: int | None = None, + shuffle_rng: Generator | None = None, ) -> None: self.soma_chunk_iter = _ObsAndXSOMAIterator( obs, @@ -362,7 +364,7 @@ def __next__(self) -> ObsAndXDatum: return X_tensor, obs_tensor - def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, ChunkX]: + def _read_partial_torch_batch(self, batch_size: int) -> tuple[pd.DataFrame, ChunkX]: """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size batch whose size may contain fewer rows than the requested ``batch_size``. This can happen when the remaining rows in the current SOMA chunk are fewer than the requested ``batch_size``. 
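Aside (not part of the patch): this commit leans on `from __future__ import annotations` (PEP 563), under which function and attribute annotations are stored as strings and never evaluated, so `list[...]` and `X | None` syntax parses on the older Pythons the package supports. The limit of that trick, and my reading of why the next two commits walk it back, is that module-level alias assignments are ordinary expressions and still run eagerly:

```python
from __future__ import annotations


def f(x: list[int] | None) -> list[int]:
    # Fine even on Python < 3.10: under PEP 563 this annotation is kept as a
    # string and never evaluated at runtime.
    return x or []


# By contrast, a module-level alias such as
#   ChunkX = npt.NDArray[Any] | sparse.csr_matrix
# executes eagerly and raises TypeError at import time on Python < 3.10,
# which is consistent with the series ending on typing.Union instead.
```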
@@ -443,15 +445,15 @@ class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ig _initialized: bool - _obs_joinids: Optional[npt.NDArray[np.int64]] + _obs_joinids: npt.NDArray[np.int64] | None - _var_joinids: Optional[npt.NDArray[np.int64]] + _var_joinids: npt.NDArray[np.int64] | None - _encoders: List[Encoder] + _encoders: list[Encoder] _stats: Stats - _shuffle_rng: Optional[Generator] + _shuffle_rng: Generator | None # TODO: Consider adding another convenience method wrapper to construct this object whose signature is more closely # aligned with get_anndata() params (i.e. "exploded" AxisQuery params). @@ -460,17 +462,17 @@ def __init__( experiment: soma.Experiment, measurement_name: str = "RNA", X_name: str = "raw", - obs_query: Optional[soma.AxisQuery] = None, - var_query: Optional[soma.AxisQuery] = None, + obs_query: soma.AxisQuery | None = None, + var_query: soma.AxisQuery | None = None, obs_column_names: Sequence[str] = (), batch_size: int = 1, shuffle: bool = True, - seed: Optional[int] = None, + seed: int | None = None, return_sparse_X: bool = False, - soma_chunk_size: Optional[int] = 64, + soma_chunk_size: int | None = 64, use_eager_fetch: bool = True, - encoders: Optional[List[Encoder]] = None, - shuffle_chunk_count: Optional[int] = 2000, + encoders: list[Encoder] | None = None, + shuffle_chunk_count: int | None = 2000, ) -> None: r"""Construct a new ``ExperimentDataPipe``. @@ -596,10 +598,10 @@ def _init(self) -> None: @staticmethod def _subset_ids_to_partition( - ids_chunked: List[npt.NDArray[np.int64]], + ids_chunked: list[npt.NDArray[np.int64]], partition_index: int, num_partitions: int, - ) -> List[npt.NDArray[np.int64]]: + ) -> list[npt.NDArray[np.int64]]: """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), based upon the current process's distributed rank and world size. 
""" @@ -622,7 +624,7 @@ def _compute_partitions( loader_partitions: int, dist_partition: int, num_dist_partitions: int, - ) -> Tuple[int, int]: + ) -> tuple[int, int]: # NOTE: Can alternately use a `worker_init_fn` to split among workers split workload total_partitions = num_dist_partitions * loader_partitions partition = dist_partition * loader_partitions + loader_partition @@ -665,7 +667,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: dist_partition=dist.get_rank() if dist.is_initialized() else 0, num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, ) - obs_joinids_chunked_partition: List[npt.NDArray[np.int64]] = self._subset_ids_to_partition( + obs_joinids_chunked_partition: list[npt.NDArray[np.int64]] = self._subset_ids_to_partition( obs_joinids_chunked, partition, partitions ) @@ -693,7 +695,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: ) @staticmethod - def _chunk_ids(ids: npt.NDArray[np.int64], chunk_size: int) -> List[npt.NDArray[np.int64]]: + def _chunk_ids(ids: npt.NDArray[np.int64], chunk_size: int) -> list[npt.NDArray[np.int64]]: num_chunks = max(1, ceil(len(ids) / chunk_size)) pytorch_logger.debug(f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}") return np.array_split(ids, num_chunks) @@ -708,7 +710,7 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> ObsAndXDatum: raise NotImplementedError("IterDataPipe can only be iterated") - def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> List[Encoder]: + def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> list[Encoder]: pytorch_logger.debug("Initializing encoders") encoders = [] @@ -748,7 +750,7 @@ def stats(self) -> Stats: return self._stats @property - def shape(self) -> Tuple[int, int]: + def shape(self) -> tuple[int, int]: """Get the shape of the data that will be returned by this :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect From 15cb1a8cdc2d52b614b1c892465730fc58ddfd50 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 8 Jul 2024 21:08:12 +0000 Subject: [PATCH 17/18] Revert "typing fix" This reverts commit 357582ece6766ca4562c61a148ce4a3636236370. 
--- .../experimental/ml/pytorch.py | 70 +++++++++---------- 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 5d83aa5ab..5d00426df 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import gc import itertools import logging @@ -9,7 +7,7 @@ from datetime import timedelta from math import ceil from time import time -from typing import Any, Dict, Iterator, Sequence, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union import numpy as np import numpy.typing as npt @@ -54,7 +52,7 @@ class _SOMAChunk: obs: pd.DataFrame X: ChunkX - stats: Stats + stats: "Stats" def __len__(self) -> int: return len(self.obs) @@ -88,7 +86,7 @@ class Stats: def __str__(self) -> str: return f"{self.n_soma_chunks=}, {self.n_obs=}, {self.nnz=}, " f"elapsed={timedelta(seconds=self.elapsed)}" - def __add__(self, other: Stats) -> Stats: + def __add__(self, other: "Stats") -> "Stats": self.n_obs += other.n_obs self.nnz += other.nnz self.elapsed += other.elapsed @@ -99,7 +97,7 @@ def __add__(self, other: Stats) -> Stats: @contextmanager def _open_experiment( uri: str, - aws_region: str | None = None, + aws_region: Optional[str] = None, ) -> soma.Experiment: """Internal method for opening a SOMA ``Experiment`` as a context manager.""" context = get_default_soma_context().replace(tiledb_config={"vfs.s3.region": aws_region} if aws_region else {}) @@ -109,8 +107,8 @@ def _open_experiment( def _tables_to_np( - tables: Iterator[tuple[Table, Any]], shape: tuple[int, int] -) -> typing.Generator[tuple[npt.NDArray[Any], Any, int], None, None]: + tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int] +) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]: for tbl, indices in tables: row_indices, col_indices, data = (x.to_numpy() for x in tbl.columns) nnz = len(data) @@ -137,10 +135,10 @@ def __init__( obs: soma.DataFrame, X: soma.SparseNDArray, obs_column_names: Sequence[str], - obs_joinids_chunked: list[npt.NDArray[np.int64]], + obs_joinids_chunked: List[npt.NDArray[np.int64]], var_joinids: npt.NDArray[np.int64], - shuffle_chunk_count: int | None = None, - shuffle_rng: Generator | None = None, + shuffle_chunk_count: Optional[int] = None, + shuffle_rng: Optional[Generator] = None, return_sparse_X: bool = False, ): self.obs = obs @@ -223,7 +221,7 @@ def __next__(self) -> _SOMAChunk: return _SOMAChunk(obs=obs_batch, X=X_batch, stats=stats) -def list_split(arr_list: list[Any], sublist_len: int) -> list[list[Any]]: +def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]: """Splits a python list into a list of sublists where each sublist is of size `sublist_len`. TODO: Replace with `itertools.batched` when Python 3.12 becomes the minimum supported version. 
""" @@ -240,7 +238,7 @@ def list_split(arr_list: list[Any], sublist_len: int) -> list[list[Any]]: return result -def run_gc() -> tuple[tuple[Any, Any, Any], tuple[Any, Any, Any], float]: # noqa: D103 +def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any], float]: # noqa: D103 proc = psutil.Process(os.getpid()) pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory() @@ -268,7 +266,7 @@ class _ObsAndXIterator(Iterator[ObsAndXDatum]): soma_chunk_iter: Iterator[_SOMAChunk] """The iterator for SOMA chunks of paired obs and X data""" - soma_chunk: _SOMAChunk | None + soma_chunk: Optional[_SOMAChunk] """The current SOMA chunk of obs and X data""" i: int = -1 @@ -279,15 +277,15 @@ def __init__( obs: soma.DataFrame, X: soma.SparseNDArray, obs_column_names: Sequence[str], - obs_joinids_chunked: list[npt.NDArray[np.int64]], + obs_joinids_chunked: List[npt.NDArray[np.int64]], var_joinids: npt.NDArray[np.int64], batch_size: int, - encoders: list[Encoder], + encoders: List[Encoder], stats: Stats, return_sparse_X: bool, use_eager_fetch: bool, - shuffle_chunk_count: int | None = None, - shuffle_rng: Generator | None = None, + shuffle_chunk_count: Optional[int] = None, + shuffle_rng: Optional[Generator] = None, ) -> None: self.soma_chunk_iter = _ObsAndXSOMAIterator( obs, @@ -364,7 +362,7 @@ def __next__(self) -> ObsAndXDatum: return X_tensor, obs_tensor - def _read_partial_torch_batch(self, batch_size: int) -> tuple[pd.DataFrame, ChunkX]: + def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, ChunkX]: """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size batch whose size may contain fewer rows than the requested ``batch_size``. This can happen when the remaining rows in the current SOMA chunk are fewer than the requested ``batch_size``. @@ -445,15 +443,15 @@ class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ig _initialized: bool - _obs_joinids: npt.NDArray[np.int64] | None + _obs_joinids: Optional[npt.NDArray[np.int64]] - _var_joinids: npt.NDArray[np.int64] | None + _var_joinids: Optional[npt.NDArray[np.int64]] - _encoders: list[Encoder] + _encoders: List[Encoder] _stats: Stats - _shuffle_rng: Generator | None + _shuffle_rng: Optional[Generator] # TODO: Consider adding another convenience method wrapper to construct this object whose signature is more closely # aligned with get_anndata() params (i.e. "exploded" AxisQuery params). @@ -462,17 +460,17 @@ def __init__( experiment: soma.Experiment, measurement_name: str = "RNA", X_name: str = "raw", - obs_query: soma.AxisQuery | None = None, - var_query: soma.AxisQuery | None = None, + obs_query: Optional[soma.AxisQuery] = None, + var_query: Optional[soma.AxisQuery] = None, obs_column_names: Sequence[str] = (), batch_size: int = 1, shuffle: bool = True, - seed: int | None = None, + seed: Optional[int] = None, return_sparse_X: bool = False, - soma_chunk_size: int | None = 64, + soma_chunk_size: Optional[int] = 64, use_eager_fetch: bool = True, - encoders: list[Encoder] | None = None, - shuffle_chunk_count: int | None = 2000, + encoders: Optional[List[Encoder]] = None, + shuffle_chunk_count: Optional[int] = 2000, ) -> None: r"""Construct a new ``ExperimentDataPipe``. 
@@ -598,10 +596,10 @@ def _init(self) -> None: @staticmethod def _subset_ids_to_partition( - ids_chunked: list[npt.NDArray[np.int64]], + ids_chunked: List[npt.NDArray[np.int64]], partition_index: int, num_partitions: int, - ) -> list[npt.NDArray[np.int64]]: + ) -> List[npt.NDArray[np.int64]]: """Returns a single partition of the obs_joinids_chunked (a 2D ndarray), based upon the current process's distributed rank and world size. """ @@ -624,7 +622,7 @@ def _compute_partitions( loader_partitions: int, dist_partition: int, num_dist_partitions: int, - ) -> tuple[int, int]: + ) -> Tuple[int, int]: # NOTE: Can alternately use a `worker_init_fn` to split among workers split workload total_partitions = num_dist_partitions * loader_partitions partition = dist_partition * loader_partitions + loader_partition @@ -667,7 +665,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: dist_partition=dist.get_rank() if dist.is_initialized() else 0, num_dist_partitions=dist.get_world_size() if dist.is_initialized() else 1, ) - obs_joinids_chunked_partition: list[npt.NDArray[np.int64]] = self._subset_ids_to_partition( + obs_joinids_chunked_partition: List[npt.NDArray[np.int64]] = self._subset_ids_to_partition( obs_joinids_chunked, partition, partitions ) @@ -695,7 +693,7 @@ def __iter__(self) -> Iterator[ObsAndXDatum]: ) @staticmethod - def _chunk_ids(ids: npt.NDArray[np.int64], chunk_size: int) -> list[npt.NDArray[np.int64]]: + def _chunk_ids(ids: npt.NDArray[np.int64], chunk_size: int) -> List[npt.NDArray[np.int64]]: num_chunks = max(1, ceil(len(ids) / chunk_size)) pytorch_logger.debug(f"Shuffling {len(ids)} obs joinids into {num_chunks} chunks of {chunk_size}") return np.array_split(ids, num_chunks) @@ -710,7 +708,7 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> ObsAndXDatum: raise NotImplementedError("IterDataPipe can only be iterated") - def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> list[Encoder]: + def _build_obs_encoders(self, query: soma.ExperimentAxisQuery) -> List[Encoder]: pytorch_logger.debug("Initializing encoders") encoders = [] @@ -750,7 +748,7 @@ def stats(self) -> Stats: return self._stats @property - def shape(self) -> tuple[int, int]: + def shape(self) -> Tuple[int, int]: """Get the shape of the data that will be returned by this :class:`cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect From 4b1a99364d179086c106c1e4b76eca29d8282e85 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 8 Jul 2024 21:10:31 +0000 Subject: [PATCH 18/18] Use more simplistic typing to avoid mypy + old python wrath --- .../src/cellxgene_census/experimental/ml/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py index 5d00426df..765c5cea1 100644 --- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py +++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py @@ -38,7 +38,7 @@ # "Chunk" of X data, returned by each `Method` above -ChunkX = Union[npt.NDArray[np.number[Any]], sparse.csr_matrix] +ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix] @define
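Closing note (not part of the patches): the series lands on `ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix]`, with `return_sparse_X` deciding which variant flows through the pipeline and `_ObsAndXIterator.__next__` stacking the per-chunk partials into one torch batch. A toy sketch mirroring that final stacking logic; the names and data here are illustrative, not the library's API:

```python
from typing import Any, List, Union

import numpy as np
import numpy.typing as npt
import torch
from scipy import sparse

ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix]

def stack_partials(partials: List[ChunkX], return_sparse_X: bool) -> torch.Tensor:
    if return_sparse_X:
        coo = sparse.vstack(partials).tocoo()  # CSR partials -> one sparse batch
        indices = torch.from_numpy(np.array([coo.row, coo.col], dtype=np.int64))
        return torch.sparse_coo_tensor(indices, torch.from_numpy(coo.data), coo.shape)
    return torch.from_numpy(np.concatenate(partials, axis=0))  # dense path

dense = stack_partials([np.ones((2, 3)), np.zeros((1, 3))], return_sparse_X=False)
assert dense.shape == (3, 3)

sp = stack_partials(
    [sparse.random(2, 3, density=0.5, format="csr"), sparse.eye(3, format="csr")],
    return_sparse_X=True,
)
assert sp.is_sparse and sp.shape == (5, 3)
```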