diff --git a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
index 9dced9754..765c5cea1 100644
--- a/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
+++ b/api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
@@ -2,22 +2,23 @@
 import itertools
 import logging
 import os
+import typing
 from contextlib import contextmanager
 from datetime import timedelta
 from math import ceil
 from time import time
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import numpy.typing as npt
 import pandas as pd
 import psutil
-import scipy
 import tiledbsoma as soma
 import torch
 import torchdata.datapipes.iter as pipes
 from attr import define
 from numpy.random import Generator
+from pyarrow import Table
 from scipy import sparse
 from torch import Tensor
 from torch import distributed as dist
@@ -36,6 +37,10 @@
 The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2."""


+# "Chunk" of X data, returned by each `Method` above
+ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix]
+
+
 @define
 class _SOMAChunk:
     """Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the respective rows from the ``X``
@@ -46,7 +51,7 @@ class _SOMAChunk:
     """

     obs: pd.DataFrame
-    X: scipy.sparse.spmatrix
+    X: ChunkX
     stats: "Stats"

     def __len__(self) -> int:
@@ -72,7 +77,7 @@ class Stats:
     nnz: int = 0
     """The total number of values retrieved"""

-    elapsed: int = 0
+    elapsed: float = 0
     """The total elapsed time in seconds for retrieving all batches"""

     n_soma_chunks: int = 0
@@ -101,6 +106,17 @@ def _open_experiment(
         yield exp


+def _tables_to_np(
+    tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int]
+) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]:
+    for tbl, indices in tables:
+        row_indices, col_indices, data = (x.to_numpy() for x in tbl.columns)
+        nnz = len(data)
+        dense_matrix = np.zeros(shape, dtype=data.dtype)
+        dense_matrix[row_indices, col_indices] = data
+        yield dense_matrix, indices, nnz
+
+
 class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]):
     """Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class,
     not intended for public use.
@@ -123,11 +139,12 @@ def __init__(
         var_joinids: npt.NDArray[np.int64],
         shuffle_chunk_count: Optional[int] = None,
         shuffle_rng: Optional[Generator] = None,
+        return_sparse_X: bool = False,
     ):
         self.obs = obs
         self.X = X
         self.obs_column_names = obs_column_names
-        if shuffle_chunk_count:
+        if shuffle_chunk_count is not None:
             assert shuffle_rng is not None

             # At the start of this step, `obs_joinids_chunked` is a list of one dimensional
@@ -145,6 +162,7 @@ def __init__(
         self.obs_joinids_chunks_iter = iter(obs_joinids_chunked)
         self.var_joinids = var_joinids
         self.shuffle_chunk_count = shuffle_chunk_count
+        self.return_sparse_X = return_sparse_X

     def __next__(self) -> _SOMAChunk:
         pytorch_logger.debug("Retrieving next SOMA chunk...")
@@ -178,18 +196,25 @@ def __next__(self) -> _SOMAChunk:

         # note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix,
         # but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block)
-        scipy_iter = (
-            self.X.read(coords=(obs_joinids_chunk, self.var_joinids))
-            .blockwise(axis=0, size=len(obs_joinids_chunk), eager=False)
-            .scipy(compress=True)
+        blockwise_iter = self.X.read(coords=(obs_joinids_chunk, self.var_joinids)).blockwise(
+            axis=0, size=len(obs_joinids_chunk), eager=False
         )
-        X_batch, _ = next(scipy_iter)
+
+        X_batch: ChunkX
+        if not self.return_sparse_X:
+            res = next(_tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))))
+            X_batch, nnz = res[0], res[2]
+        else:
+            X_batch = next(blockwise_iter.scipy(compress=True))[0]
+            nnz = X_batch.nnz
+
         assert obs_batch.shape[0] == X_batch.shape[0]

+        end_time = time()
         stats = Stats()
         stats.n_obs += X_batch.shape[0]
-        stats.nnz += X_batch.nnz
-        stats.elapsed += int(time() - start_time)
+        stats.nnz += nnz
+        stats.elapsed += end_time - start_time
         stats.n_soma_chunks += 1

         pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}")
@@ -213,17 +238,19 @@ def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]:
     return result


-def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any]]:  # noqa: D103
+def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any], float]:  # noqa: D103
     proc = psutil.Process(os.getpid())

     pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()
+    start = time()
     gc.collect()
+    gc_elapsed = time() - start
     post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()

     pytorch_logger.debug(f"gc: pre={pre_gc}")
     pytorch_logger.debug(f"gc: post={post_gc}")

-    return pre_gc, post_gc
+    return pre_gc, post_gc, gc_elapsed


 class _ObsAndXIterator(Iterator[ObsAndXDatum]):
@@ -268,6 +295,7 @@ def __init__(
             var_joinids,
             shuffle_chunk_count,
             shuffle_rng,
+            return_sparse_X=return_sparse_X,
         )
         if use_eager_fetch:
             self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter)
@@ -277,24 +305,33 @@ def __init__(
         self.return_sparse_X = return_sparse_X
         self.encoders = encoders
         self.stats = stats
+        self.gc_elapsed = 0.0
         self.max_process_mem_usage_bytes = 0
         self.X_dtype = X.schema[2].type.to_pandas_dtype()

     def __next__(self) -> ObsAndXDatum:
         """Read the next torch batch, possibly across multiple soma chunks."""
-        obs: pd.DataFrame = pd.DataFrame()
-        X: sparse.csr_matrix = sparse.csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype)
+        obss: list[pd.DataFrame] = []
+        Xs: list[ChunkX] = []
+        n_obs = 0

-        while len(obs) < self.batch_size:
+        while n_obs < self.batch_size:
             try:
-                obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs))
-                obs = pd.concat([obs, obs_partial], axis=0)
-                X = sparse.vstack([X, X_partial])
+                obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - n_obs)
+                n_obs += len(obs_partial)
+                obss.append(obs_partial)
+                Xs.append(X_partial)
             except StopIteration:
                 break

-        if len(obs) == 0:
+        if len(Xs) == 0:
             # If we ran out of data
             raise StopIteration
+        else:
+            if self.return_sparse_X:
+                X = sparse.vstack(Xs)
+            else:
+                X = np.concatenate(Xs, axis=0)
+        obs = pd.concat(obss, axis=0)

         obs_encoded = pd.DataFrame()
@@ -308,7 +345,7 @@ def __next__(self) -> ObsAndXDatum:
         obs_tensor = torch.from_numpy(obs_encoded.to_numpy())

         if not self.return_sparse_X:
-            X_tensor = torch.from_numpy(X.todense())
+            X_tensor = torch.from_numpy(X)
         else:
             coo = X.tocoo()

@@ -325,7 +362,7 @@ def __next__(self) -> ObsAndXDatum:

         return X_tensor, obs_tensor

-    def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum:
+    def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, ChunkX]:
         """Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size batch whose size may
         contain fewer rows than the requested ``batch_size``. This can happen when the remaining rows in the current
         SOMA chunk are fewer than the requested ``batch_size``.
@@ -333,17 +370,20 @@ def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum:
         if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)):
             # GC memory from previous soma_chunk
             self.soma_chunk = None
-            mem_info = run_gc()
-            self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, mem_info[0][0].uss)
+            pre_gc, _, gc_elapsed = run_gc()
+            self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, pre_gc[0].uss)

             self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter)
             self.stats += self.soma_chunk.stats
+            self.gc_elapsed += gc_elapsed
             self.i = 0

-            pytorch_logger.debug(f"Retrieved SOMA chunk totals: {self.stats}")
+            pytorch_logger.debug(
+                f"Retrieved SOMA chunk totals: {self.stats}, gc_elapsed={timedelta(seconds=self.gc_elapsed)}"
+            )

         obs_batch = self.soma_chunk.obs
-        X_batch = self.soma_chunk.X
+        X_chunk = self.soma_chunk.X

         safe_batch_size = min(batch_size, len(obs_batch) - self.i)
         slice_ = slice(self.i, self.i + safe_batch_size)
@@ -353,12 +393,13 @@ def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum:
         assert obs_rows.index.is_unique
         assert safe_batch_size == obs_rows.shape[0]

-        X_csr_scipy = X_batch[slice_]
-        assert obs_rows.shape[0] == X_csr_scipy.shape[0]
+        X_batch = X_chunk[slice_]
+
+        assert obs_rows.shape[0] == X_batch.shape[0]

         self.i += safe_batch_size

-        return obs_rows, X_csr_scipy
+        return obs_rows, X_batch


 class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]):  # type: ignore
@@ -517,6 +558,7 @@ def __init__(
         self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None
         self._shuffle_rng = np.random.default_rng(seed) if shuffle else None
         self._initialized = False
+        self.max_process_mem_usage_bytes = 0

         if obs_column_names and encoders:
             raise ValueError(
@@ -645,8 +687,9 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:

         yield from obs_and_x_iter

+        self.max_process_mem_usage_bytes = obs_and_x_iter.max_process_mem_usage_bytes
         pytorch_logger.debug(
-            "max process memory usage=" f"{obs_and_x_iter.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB"
+            "max process memory usage=" f"{self.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB"
         )

     @staticmethod
diff --git a/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
index 74a94f885..3490c4bbd 100644
--- a/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
+++ b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
@@ -155,6 +155,39 @@ def test_non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None
     assert row[1].tolist() == [0]


+@pytest.mark.experimental
+# noinspection PyTestParametrized
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen,use_eager_fetch",
+    [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
+)
+@pytest.mark.parametrize("return_sparse_X", [True, False])
+def test_uneven_soma_and_result_batches(
+    soma_experiment: Experiment, use_eager_fetch: bool, return_sparse_X: bool
+) -> None:
+    """This is checking that batches are correctly created when they require fetching multiple chunks.
+
+    This was added due to failures in _ObsAndXIterator.__next__.
+    """
+    exp_data_pipe = ExperimentDataPipe(
+        soma_experiment,
+        measurement_name="RNA",
+        X_name="raw",
+        obs_column_names=["label"],
+        shuffle=False,
+        batch_size=3,
+        soma_chunk_size=2,
+        return_sparse_X=return_sparse_X,
+        use_eager_fetch=use_eager_fetch,
+    )
+    row_iter = iter(exp_data_pipe)
+
+    row = next(row_iter)
+    X_batch = row[0].to_dense() if return_sparse_X else row[0]
+    assert X_batch.int()[0].tolist() == [0, 1, 0]
+    assert row[1].tolist() == [[0], [1], [2]]
+
+
 @pytest.mark.experimental
 # noinspection PyTestParametrized,DuplicatedCode
 @pytest.mark.parametrize(
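
Usage sketch (illustrative, not part of the diff): with this change, `return_sparse_X=False` densifies each SOMA chunk up front via `_tables_to_np`, so a batch's X is a plain numpy array handed straight to `torch.from_numpy`, while `return_sparse_X=True` keeps the scipy/COO path. The snippet below is a minimal sketch; the experiment URI and obs column name are placeholders, and the import mirrors the module path shown in this diff.

    import tiledbsoma as soma
    from cellxgene_census.experimental.ml.pytorch import ExperimentDataPipe

    # Placeholder URI: any SOMA Experiment with an "RNA" measurement and a "raw" X layer.
    with soma.Experiment.open("path/to/census-experiment") as experiment:
        dp = ExperimentDataPipe(
            experiment,
            measurement_name="RNA",
            X_name="raw",
            obs_column_names=["cell_type"],  # placeholder obs column
            batch_size=128,
            shuffle=False,
            return_sparse_X=False,  # X densified per chunk by _tables_to_np
        )
        X, obs = next(iter(dp))
        assert not X.is_sparse  # dense torch.Tensor; True when return_sparse_X=True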