[python] dataloader optimization (picking up 1169) (#1224)
* ExperimentDataPipe: configurable array-conversion method

`np.ndarray`, `scipy.coo`, `scipy.csr`

* group chunks iff `shuffle_chunk_count > 1`

* rm scipy.coo method

* add `METHODS` constant

* Use return_sparse_X instead of adding a method kwarg (see the usage sketch after this list)

* Oops, one more thing

* revert change to shuffle_chunk_count branch

* mypy + linting fixes

* Add test for bug I found

* Fix bug, possibly speed up dataloaders

* Remove unused (I think) branch

* Remove more unused code

* Slightly more precise typing

* Reduce additional copies in _tables_to_np

* Simplify branching in _ObsAndXSOMAIterator

* typing fix

* Revert "typing fix"

This reverts commit 357582e.

* Use simpler typing to avoid mypy + old Python wrath
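
A minimal usage sketch of the `return_sparse_X` switch; the census handles and `obs_column_names` below are illustrative assumptions, not part of this PR, while the `ExperimentDataPipe` kwargs mirror the test added in this commit:

    import cellxgene_census
    from cellxgene_census.experimental.ml import ExperimentDataPipe

    census = cellxgene_census.open_soma()
    experiment = census["census_data"]["homo_sapiens"]

    pipe = ExperimentDataPipe(
        experiment,
        measurement_name="RNA",
        X_name="raw",
        obs_column_names=["cell_type"],
        batch_size=128,
        return_sparse_X=False,  # dense torch Tensor batches (the new fast path)
    )
    X, obs = next(iter(pipe))  # with return_sparse_X=True, X is a sparse COO Tensor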

---------

Co-authored-by: Ryan Williams <ryan.williams@tiledb.com>
ivirshup and ryan-williams committed Jul 8, 2024
1 parent 82a3f86 commit 473ba97
Showing 2 changed files with 107 additions and 31 deletions.
105 changes: 74 additions & 31 deletions api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py
@@ -2,22 +2,23 @@
import itertools
import logging
import os
+import typing
from contextlib import contextmanager
from datetime import timedelta
from math import ceil
from time import time
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import numpy.typing as npt
import pandas as pd
import psutil
-import scipy
import tiledbsoma as soma
import torch
import torchdata.datapipes.iter as pipes
from attr import define
from numpy.random import Generator
+from pyarrow import Table
from scipy import sparse
from torch import Tensor
from torch import distributed as dist
@@ -36,6 +37,10 @@
The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2."""


# "Chunk" of X data, returned by each `Method` above
ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix]


@define
class _SOMAChunk:
"""Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the respective rows from the ``X``
@@ -46,7 +51,7 @@ class _SOMAChunk:
"""

obs: pd.DataFrame
-    X: scipy.sparse.spmatrix
+    X: ChunkX
stats: "Stats"

def __len__(self) -> int:
@@ -72,7 +77,7 @@ class Stats:
nnz: int = 0
"""The total number of values retrieved"""

-    elapsed: int = 0
+    elapsed: float = 0
"""The total elapsed time in seconds for retrieving all batches"""

n_soma_chunks: int = 0
@@ -101,6 +106,17 @@ def _open_experiment(
yield exp


+def _tables_to_np(
+    tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int]
+) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]:
+    for tbl, indices in tables:
+        row_indices, col_indices, data = (x.to_numpy() for x in tbl.columns)
+        nnz = len(data)
+        dense_matrix = np.zeros(shape, dtype=data.dtype)
+        dense_matrix[row_indices, col_indices] = data
+        yield dense_matrix, indices, nnz
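
# A sketch, not part of this commit: the densification trick used by _tables_to_np,
# in isolation. Each Arrow table carries COO triplets (rows, cols, values), so the
# conversion is one np.zeros allocation plus a single fancy-indexed assignment
# (toy values below):
#
#     import numpy as np
#     rows = np.array([0, 0, 1]); cols = np.array([2, 0, 1]); vals = np.array([7.0, 3.0, 5.0])
#     dense = np.zeros((2, 3), dtype=vals.dtype)
#     dense[rows, cols] = vals   # nnz == len(vals) == 3
#     # dense -> [[3., 0., 7.], [0., 5., 0.]]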


class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]):
"""Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class,
not intended for public use.
@@ -123,11 +139,12 @@ def __init__(
var_joinids: npt.NDArray[np.int64],
shuffle_chunk_count: Optional[int] = None,
shuffle_rng: Optional[Generator] = None,
+        return_sparse_X: bool = False,
):
self.obs = obs
self.X = X
self.obs_column_names = obs_column_names
-        if shuffle_chunk_count:
+        if shuffle_chunk_count is not None:
assert shuffle_rng is not None

# At the start of this step, `obs_joinids_chunked` is a list of one dimensional
@@ -145,6 +162,7 @@ def __init__(
self.obs_joinids_chunks_iter = iter(obs_joinids_chunked)
self.var_joinids = var_joinids
self.shuffle_chunk_count = shuffle_chunk_count
+        self.return_sparse_X = return_sparse_X

def __next__(self) -> _SOMAChunk:
pytorch_logger.debug("Retrieving next SOMA chunk...")
@@ -178,18 +196,25 @@ def __next__(self) -> _SOMAChunk:

# note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix,
# but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block)
-        scipy_iter = (
-            self.X.read(coords=(obs_joinids_chunk, self.var_joinids))
-            .blockwise(axis=0, size=len(obs_joinids_chunk), eager=False)
-            .scipy(compress=True)
-        )
-        X_batch, _ = next(scipy_iter)
+        blockwise_iter = self.X.read(coords=(obs_joinids_chunk, self.var_joinids)).blockwise(
+            axis=0, size=len(obs_joinids_chunk), eager=False
+        )
+
+        X_batch: ChunkX
+        if not self.return_sparse_X:
+            res = next(_tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))))
+            X_batch, nnz = res[0], res[2]
+        else:
+            X_batch = next(blockwise_iter.scipy(compress=True))[0]
+            nnz = X_batch.nnz

assert obs_batch.shape[0] == X_batch.shape[0]

+        end_time = time()
stats = Stats()
stats.n_obs += X_batch.shape[0]
-        stats.nnz += X_batch.nnz
-        stats.elapsed += int(time() - start_time)
+        stats.nnz += nnz
+        stats.elapsed += end_time - start_time
stats.n_soma_chunks += 1

pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}")
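
# A sketch, not part of this commit: the two X read paths above, side by side.
# Both consume a single block, because size=len(obs_joinids_chunk) makes the
# whole chunk one block:
#
#     n_rows, n_cols = len(obs_joinids_chunk), len(var_joinids)
#     bw = X.read(coords=(obs_joinids_chunk, var_joinids)).blockwise(
#         axis=0, size=n_rows, eager=False
#     )
#     dense, joinids, nnz = next(_tables_to_np(bw.tables(), shape=(n_rows, n_cols)))  # dense path
#     csr, joinids = next(bw.scipy(compress=True))  # sparse path (use a fresh blockwise iterator)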
@@ -213,17 +238,19 @@ def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]:
return result


-def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any]]:  # noqa: D103
+def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any], float]:  # noqa: D103
proc = psutil.Process(os.getpid())

pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()
+    start = time()
gc.collect()
+    gc_elapsed = time() - start
post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()

pytorch_logger.debug(f"gc: pre={pre_gc}")
pytorch_logger.debug(f"gc: post={post_gc}")

-    return pre_gc, post_gc
+    return pre_gc, post_gc, gc_elapsed
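
# A sketch, not part of this commit: consuming run_gc()'s new 3-tuple. pre_gc[0]
# is the psutil memory_full_info() snapshot whose .uss field feeds
# max_process_mem_usage_bytes elsewhere in this diff:
#
#     pre_gc, post_gc, gc_elapsed = run_gc()
#     print(f"gc took {gc_elapsed:.3f}s; USS before gc: {pre_gc[0].uss / 2**30:.2f} GiB")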


class _ObsAndXIterator(Iterator[ObsAndXDatum]):
@@ -268,6 +295,7 @@ def __init__(
var_joinids,
shuffle_chunk_count,
shuffle_rng,
+            return_sparse_X=return_sparse_X,
)
if use_eager_fetch:
self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter)
@@ -277,24 +305,33 @@ def __init__(
self.return_sparse_X = return_sparse_X
self.encoders = encoders
self.stats = stats
+        self.gc_elapsed = 0.0
self.max_process_mem_usage_bytes = 0
self.X_dtype = X.schema[2].type.to_pandas_dtype()

def __next__(self) -> ObsAndXDatum:
"""Read the next torch batch, possibly across multiple soma chunks."""
-        obs: pd.DataFrame = pd.DataFrame()
-        X: sparse.csr_matrix = sparse.csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype)
+        obss: list[pd.DataFrame] = []
+        Xs: list[ChunkX] = []
+        n_obs = 0

-        while len(obs) < self.batch_size:
+        while n_obs < self.batch_size:
try:
-                obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs))
-                obs = pd.concat([obs, obs_partial], axis=0)
-                X = sparse.vstack([X, X_partial])
+                obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - n_obs)
+                n_obs += len(obs_partial)
+                obss.append(obs_partial)
+                Xs.append(X_partial)
except StopIteration:
break

-        if len(obs) == 0:
+        if len(Xs) == 0:  # If we ran out of data
raise StopIteration
+        else:
+            if self.return_sparse_X:
+                X = sparse.vstack(Xs)
+            else:
+                X = np.concatenate(Xs, axis=0)
+            obs = pd.concat(obss, axis=0)

obs_encoded = pd.DataFrame()

@@ -308,7 +345,7 @@ def __next__(self) -> ObsAndXDatum:
obs_tensor = torch.from_numpy(obs_encoded.to_numpy())

if not self.return_sparse_X:
-            X_tensor = torch.from_numpy(X.todense())
+            X_tensor = torch.from_numpy(X)
else:
coo = X.tocoo()

@@ -325,25 +362,28 @@

return X_tensor, obs_tensor

-    def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum:
+    def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, ChunkX]:
"""Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size batch whose size may
contain fewer rows than the requested ``batch_size``. This can happen when the remaining rows in the current
SOMA chunk are fewer than the requested ``batch_size``.
"""
if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)):
# GC memory from previous soma_chunk
self.soma_chunk = None
-            mem_info = run_gc()
-            self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, mem_info[0][0].uss)
+            pre_gc, _, gc_elapsed = run_gc()
+            self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, pre_gc[0].uss)

self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter)
self.stats += self.soma_chunk.stats
+            self.gc_elapsed += gc_elapsed
self.i = 0

-        pytorch_logger.debug(f"Retrieved SOMA chunk totals: {self.stats}")
+        pytorch_logger.debug(
+            f"Retrieved SOMA chunk totals: {self.stats}, gc_elapsed={timedelta(seconds=self.gc_elapsed)}"
+        )

obs_batch = self.soma_chunk.obs
-        X_batch = self.soma_chunk.X
+        X_chunk = self.soma_chunk.X

safe_batch_size = min(batch_size, len(obs_batch) - self.i)
slice_ = slice(self.i, self.i + safe_batch_size)
@@ -353,12 +393,13 @@ def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum:
assert obs_rows.index.is_unique
assert safe_batch_size == obs_rows.shape[0]

-        X_csr_scipy = X_batch[slice_]
-        assert obs_rows.shape[0] == X_csr_scipy.shape[0]
+        X_batch = X_chunk[slice_]
+
+        assert obs_rows.shape[0] == X_batch.shape[0]

self.i += safe_batch_size

-        return obs_rows, X_csr_scipy
+        return obs_rows, X_batch
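
# A worked micro-example, not part of this commit: why __next__ may call
# _read_partial_torch_batch more than once per torch batch. With soma_chunk_size=2
# and batch_size=3 (as in the new test below):
#
#     SOMA chunks:   [r0 r1] [r2 r3] [r4 r5]
#     torch batches: [r0 r1 r2]  [r3 r4 r5]
#
# Each call returns at most len(chunk) - i rows, so __next__'s while-loop keeps
# appending partial results to obss/Xs until n_obs == batch_size or data runs out.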


class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ignore
@@ -517,6 +558,7 @@ def __init__(
self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None
self._shuffle_rng = np.random.default_rng(seed) if shuffle else None
self._initialized = False
+        self.max_process_mem_usage_bytes = 0

if obs_column_names and encoders:
raise ValueError(
@@ -645,8 +687,9 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:

yield from obs_and_x_iter

+        self.max_process_mem_usage_bytes = obs_and_x_iter.max_process_mem_usage_bytes
pytorch_logger.debug(
"max process memory usage=" f"{obs_and_x_iter.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB"
"max process memory usage=" f"{self.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB"
)

@staticmethod
33 changes: 33 additions & 0 deletions api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
@@ -155,6 +155,39 @@ def test_non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None:
assert row[1].tolist() == [0]


+@pytest.mark.experimental
+# noinspection PyTestParametrized
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen,use_eager_fetch",
+    [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
+)
+@pytest.mark.parametrize("return_sparse_X", [True, False])
+def test_uneven_soma_and_result_batches(
+    soma_experiment: Experiment, use_eager_fetch: bool, return_sparse_X: bool
+) -> None:
+    """This is checking that batches are correctly created when they require fetching multiple chunks.
+    This was added due to failures in _ObsAndXIterator.__next__.
+    """
+    exp_data_pipe = ExperimentDataPipe(
+        soma_experiment,
+        measurement_name="RNA",
+        X_name="raw",
+        obs_column_names=["label"],
+        shuffle=False,
+        batch_size=3,
+        soma_chunk_size=2,
+        return_sparse_X=return_sparse_X,
+        use_eager_fetch=use_eager_fetch,
+    )
+    row_iter = iter(exp_data_pipe)
+
+    row = next(row_iter)
+    X_batch = row[0].to_dense() if return_sparse_X else row[0]
+    assert X_batch.int()[0].tolist() == [0, 1, 0]
+    assert row[1].tolist() == [[0], [1], [2]]
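
# A note, not part of this commit: the shape bookkeeping this test exercises.
# Six obs rows with soma_chunk_size=2 give three SOMA chunks; batch_size=3 gives
# two torch batches, each spanning a chunk boundary. With return_sparse_X=True,
# row[0] arrives as a sparse COO torch Tensor, hence the .to_dense() above.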


@pytest.mark.experimental
# noinspection PyTestParametrized,DuplicatedCode
@pytest.mark.parametrize(
