[python] dataloader optimization (picking up 1169) #1224

Merged: 21 commits, Jul 8, 2024
@@ -2,22 +2,23 @@
import itertools
import logging
import os
import typing
from contextlib import contextmanager
from datetime import timedelta
from math import ceil
from time import time
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import numpy.typing as npt
import pandas as pd
import psutil
import scipy
import tiledbsoma as soma
import torch
import torchdata.datapipes.iter as pipes
from attr import define
from numpy.random import Generator
from pyarrow import Table
from scipy import sparse
from torch import Tensor
from torch import distributed as dist
@@ -36,6 +37,10 @@
The Tensors are rank 1 if ``batch_size`` is 1, otherwise the Tensors are rank 2."""


# "Chunk" of X data, returned by each `Method` above
ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix]
Contributor:
nit (optional): why npt.NDArray[Any] and not npt.NDArray[np.number] (or something even more specific?)

Contributor:
Also, is the API signature really saying that it could return any sparse type? If it is actually constrained to a more specific set of sub-types (e.g., csr_matrix), it would be useful to further specify this type alias (even if just for the static type checking...)

Collaborator (author):
Used np.number, but mypy also wants explicit generics, so I went with np.number[Any].

Not totally sure I'm understanding the second part, since the alias here is already csr_matrix?

Collaborator (author):
Hmm. This is opening up a big can of typing worms. There wouldn't be as much of an issue once we drop support for Python 3.8.

I think I will go back to Any until we move to SPEC-0.
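
For illustration, a sketch of the two alias spellings discussed in this thread; the merged code keeps Any. The stricter spelling satisfies mypy, but subscripting np.number at runtime requires numpy >= 1.22 on Python >= 3.9, which is likely the Python 3.8 wrinkle alluded to above.

```python
# Sketch only, not the merged code: ChunkX (merged) vs. a stricter variant.
from typing import Any, Union

import numpy as np
import numpy.typing as npt
from scipy import sparse

# Merged: a dense ndarray of any dtype, or a scipy CSR matrix.
ChunkX = Union[npt.NDArray[Any], sparse.csr_matrix]

# Stricter variant floated in review: constrain the dense dtype to numeric
# scalars. mypy rejects a bare np.number and wants an explicit parameter,
# hence np.number[Any]; evaluating that subscript at runtime needs
# numpy >= 1.22 on Python >= 3.9, hence the friction with 3.8 support.
ChunkXStrict = Union[npt.NDArray[np.number[Any]], sparse.csr_matrix]
```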



@define
class _SOMAChunk:
"""Return type of ``_ObsAndXSOMAIterator`` that pairs a chunk of ``obs`` rows with the respective rows from the ``X``
@@ -46,7 +51,7 @@ class _SOMAChunk:
"""

obs: pd.DataFrame
X: scipy.sparse.spmatrix
X: ChunkX
stats: "Stats"

def __len__(self) -> int:
@@ -72,7 +77,7 @@ class Stats:
nnz: int = 0
"""The total number of values retrieved"""

elapsed: int = 0
elapsed: float = 0
"""The total elapsed time in seconds for retrieving all batches"""

n_soma_chunks: int = 0
@@ -101,6 +106,17 @@ def _open_experiment(
yield exp


def _tables_to_np(
tables: Iterator[Tuple[Table, Any]], shape: Tuple[int, int]
) -> typing.Generator[Tuple[npt.NDArray[Any], Any, int], None, None]:
for tbl, indices in tables:
row_indices, col_indices, data = (x.to_numpy() for x in tbl.columns)
nnz = len(data)
dense_matrix = np.zeros(shape, dtype=data.dtype)
dense_matrix[row_indices, col_indices] = data
yield dense_matrix, indices, nnz
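
As a quick illustration of _tables_to_np: each Arrow table carries COO triples (row, column, value) for one block, which are scattered into a preallocated dense array. The column names below are made up for the example; the function reads columns positionally.

```python
import numpy as np
import pyarrow as pa

# One COO-encoded block of a 3 x 3 chunk: values 7, 8, 9 at (0,1), (0,2), (2,0).
tbl = pa.table(
    {
        "dim_0": pa.array([0, 0, 2], type=pa.int64()),    # row indices
        "dim_1": pa.array([1, 2, 0], type=pa.int64()),    # column indices
        "data": pa.array([7.0, 8.0, 9.0], type=pa.float32()),
    }
)
dense, passthrough, nnz = next(_tables_to_np(iter([(tbl, None)]), shape=(3, 3)))
assert nnz == 3 and dense.dtype == np.float32
assert dense[0, 1] == 7.0 and dense[2, 0] == 9.0
```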


class _ObsAndXSOMAIterator(Iterator[_SOMAChunk]):
"""Iterates the SOMA chunks of corresponding ``obs`` and ``X`` data. This is an internal class,
not intended for public use.
@@ -123,11 +139,12 @@ def __init__(
var_joinids: npt.NDArray[np.int64],
shuffle_chunk_count: Optional[int] = None,
shuffle_rng: Optional[Generator] = None,
return_sparse_X: bool = False,
):
self.obs = obs
self.X = X
self.obs_column_names = obs_column_names
if shuffle_chunk_count:
if shuffle_chunk_count is not None:
assert shuffle_rng is not None

# At the start of this step, `obs_joinids_chunked` is a list of one dimensional
@@ -145,6 +162,7 @@
self.obs_joinids_chunks_iter = iter(obs_joinids_chunked)
self.var_joinids = var_joinids
self.shuffle_chunk_count = shuffle_chunk_count
self.return_sparse_X = return_sparse_X

def __next__(self) -> _SOMAChunk:
pytorch_logger.debug("Retrieving next SOMA chunk...")
@@ -178,18 +196,25 @@ def __next__(self) -> _SOMAChunk:

# note: the `blockwise` call is employed for its ability to reindex the axes of the sparse matrix,
# but the blockwise iteration feature is not used (block_size is set to retrieve the chunk as a single block)
scipy_iter = (
self.X.read(coords=(obs_joinids_chunk, self.var_joinids))
.blockwise(axis=0, size=len(obs_joinids_chunk), eager=False)
.scipy(compress=True)
blockwise_iter = self.X.read(coords=(obs_joinids_chunk, self.var_joinids)).blockwise(
axis=0, size=len(obs_joinids_chunk), eager=False
)
X_batch, _ = next(scipy_iter)

X_batch: ChunkX
if not self.return_sparse_X:
res = next(_tables_to_np(blockwise_iter.tables(), shape=(obs_batch.shape[0], len(self.var_joinids))))
X_batch, nnz = res[0], res[2]
else:
X_batch = next(blockwise_iter.scipy(compress=True))[0]
nnz = X_batch.nnz

assert obs_batch.shape[0] == X_batch.shape[0]

end_time = time()
stats = Stats()
stats.n_obs += X_batch.shape[0]
stats.nnz += X_batch.nnz
stats.elapsed += int(time() - start_time)
stats.nnz += nnz
stats.elapsed += end_time - start_time
stats.n_soma_chunks += 1

pytorch_logger.debug(f"Retrieved SOMA chunk: {stats}")
@@ -213,17 +238,19 @@ def list_split(arr_list: List[Any], sublist_len: int) -> List[List[Any]]:
return result


def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any]]: # noqa: D103
def run_gc() -> Tuple[Tuple[Any, Any, Any], Tuple[Any, Any, Any], float]: # noqa: D103
proc = psutil.Process(os.getpid())

pre_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()
start = time()
gc.collect()
gc_elapsed = time() - start
post_gc = proc.memory_full_info(), psutil.virtual_memory(), psutil.swap_memory()

pytorch_logger.debug(f"gc: pre={pre_gc}")
pytorch_logger.debug(f"gc: post={post_gc}")

return pre_gc, post_gc
return pre_gc, post_gc, gc_elapsed
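
A short sketch of how the widened return value is consumed (mirroring _read_partial_torch_batch further down); the variable names here are illustrative:

```python
# pre/post snapshots are (memory_full_info, virtual_memory, swap_memory);
# the new third element is the wall-clock cost of gc.collect() itself.
pre_gc, _post_gc, gc_elapsed = run_gc()

peak_uss_bytes = pre_gc[0].uss  # USS from proc.memory_full_info()
total_gc_seconds = gc_elapsed   # accumulated into self.gc_elapsed downstream
```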


class _ObsAndXIterator(Iterator[ObsAndXDatum]):
@@ -268,6 +295,7 @@ def __init__(
var_joinids,
shuffle_chunk_count,
shuffle_rng,
return_sparse_X=return_sparse_X,
)
if use_eager_fetch:
self.soma_chunk_iter = _EagerIterator(self.soma_chunk_iter)
@@ -277,24 +305,33 @@ def __init__(
self.return_sparse_X = return_sparse_X
self.encoders = encoders
self.stats = stats
self.gc_elapsed = 0.0
self.max_process_mem_usage_bytes = 0
self.X_dtype = X.schema[2].type.to_pandas_dtype()

def __next__(self) -> ObsAndXDatum:
"""Read the next torch batch, possibly across multiple soma chunks."""
obs: pd.DataFrame = pd.DataFrame()
X: sparse.csr_matrix = sparse.csr_matrix((0, len(self.var_joinids)), dtype=self.X_dtype)
obss: list[pd.DataFrame] = []
Xs: list[ChunkX] = []
n_obs = 0

while len(obs) < self.batch_size:
while n_obs < self.batch_size:
try:
obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - len(obs))
obs = pd.concat([obs, obs_partial], axis=0)
X = sparse.vstack([X, X_partial])
obs_partial, X_partial = self._read_partial_torch_batch(self.batch_size - n_obs)
n_obs += len(obs_partial)
obss.append(obs_partial)
Xs.append(X_partial)
except StopIteration:
break

if len(obs) == 0:
if len(Xs) == 0: # If we ran out of data
raise StopIteration
else:
if self.return_sparse_X:
X = sparse.vstack(Xs)
else:
X = np.concatenate(Xs, axis=0)
obs = pd.concat(obss, axis=0)

obs_encoded = pd.DataFrame()

@@ -308,7 +345,7 @@ def __next__(self) -> ObsAndXDatum:
obs_tensor = torch.from_numpy(obs_encoded.to_numpy())

if not self.return_sparse_X:
X_tensor = torch.from_numpy(X.todense())
X_tensor = torch.from_numpy(X)
else:
coo = X.tocoo()

@@ -325,25 +362,28 @@

return X_tensor, obs_tensor
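
To see the new assembly logic end-to-end, here is a standalone sketch of both return modes with toy chunks. The dense path is what the diff shows; the sparse-tensor construction is collapsed in this diff, so that part is a typical pattern, not necessarily the PR's exact code.

```python
import numpy as np
import torch
from scipy import sparse

# Partial reads accumulated by the while-loop, one per _read_partial_torch_batch.
chunks = [np.eye(2, 3, dtype=np.float32), np.zeros((1, 3), dtype=np.float32)]

# Dense path: concatenate once, then a zero-copy wrap into a torch tensor.
X_dense = np.concatenate(chunks, axis=0)
dense_tensor = torch.from_numpy(X_dense)

# Sparse path: stack CSR chunks, then build a COO tensor (indices must be int64).
X_sparse = sparse.vstack([sparse.csr_matrix(c) for c in chunks])
coo = X_sparse.tocoo()
sparse_tensor = torch.sparse_coo_tensor(
    indices=torch.from_numpy(np.vstack([coo.row, coo.col]).astype(np.int64)),
    values=torch.from_numpy(coo.data),
    size=coo.shape,
)
assert dense_tensor.shape == sparse_tensor.shape == (3, 3)
```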

def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum:
def _read_partial_torch_batch(self, batch_size: int) -> Tuple[pd.DataFrame, ChunkX]:
"""Reads a torch-size batch of data from the current SOMA chunk, returning a torch-size batch whose size may
contain fewer rows than the requested ``batch_size``. This can happen when the remaining rows in the current
SOMA chunk are fewer than the requested ``batch_size``.
"""
if self.soma_chunk is None or not (0 <= self.i < len(self.soma_chunk)):
# GC memory from previous soma_chunk
self.soma_chunk = None
mem_info = run_gc()
self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, mem_info[0][0].uss)
pre_gc, _, gc_elapsed = run_gc()
self.max_process_mem_usage_bytes = max(self.max_process_mem_usage_bytes, pre_gc[0].uss)

self.soma_chunk: _SOMAChunk = next(self.soma_chunk_iter)
self.stats += self.soma_chunk.stats
self.gc_elapsed += gc_elapsed
self.i = 0

pytorch_logger.debug(f"Retrieved SOMA chunk totals: {self.stats}")
pytorch_logger.debug(
f"Retrieved SOMA chunk totals: {self.stats}, gc_elapsed={timedelta(seconds=self.gc_elapsed)}"
)

obs_batch = self.soma_chunk.obs
X_batch = self.soma_chunk.X
X_chunk = self.soma_chunk.X

safe_batch_size = min(batch_size, len(obs_batch) - self.i)
slice_ = slice(self.i, self.i + safe_batch_size)
@@ -353,12 +393,13 @@ def _read_partial_torch_batch(self, batch_size: int) -> ObsAndXDatum:
assert obs_rows.index.is_unique
assert safe_batch_size == obs_rows.shape[0]

X_csr_scipy = X_batch[slice_]
assert obs_rows.shape[0] == X_csr_scipy.shape[0]
X_batch = X_chunk[slice_]

assert obs_rows.shape[0] == X_batch.shape[0]

self.i += safe_batch_size

return obs_rows, X_csr_scipy
return obs_rows, X_batch


class ExperimentDataPipe(pipes.IterDataPipe[Dataset[ObsAndXDatum]]): # type: ignore
@@ -517,6 +558,7 @@ def __init__(
self._shuffle_chunk_count = shuffle_chunk_count if shuffle else None
self._shuffle_rng = np.random.default_rng(seed) if shuffle else None
self._initialized = False
self.max_process_mem_usage_bytes = 0

if obs_column_names and encoders:
raise ValueError(
@@ -645,8 +687,9 @@ def __iter__(self) -> Iterator[ObsAndXDatum]:

yield from obs_and_x_iter

self.max_process_mem_usage_bytes = obs_and_x_iter.max_process_mem_usage_bytes
pytorch_logger.debug(
"max process memory usage=" f"{obs_and_x_iter.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB"
"max process memory usage=" f"{self.max_process_mem_usage_bytes / (1024 ** 3):.3f} GiB"
)

@staticmethod
33 changes: 33 additions & 0 deletions api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
@@ -155,6 +155,39 @@ def test_non_batched(soma_experiment: Experiment, use_eager_fetch: bool) -> None
assert row[1].tolist() == [0]


@pytest.mark.experimental
# noinspection PyTestParametrized
@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize("return_sparse_X", [True, False])
def test_uneven_soma_and_result_batches(
soma_experiment: Experiment, use_eager_fetch: bool, return_sparse_X: bool
) -> None:
"""This is checking that batches are correctly created when they require fetching multiple chunks.

This was added due to failures in _ObsAndXIterator.__next__.
"""
exp_data_pipe = ExperimentDataPipe(
soma_experiment,
measurement_name="RNA",
X_name="raw",
obs_column_names=["label"],
shuffle=False,
batch_size=3,
soma_chunk_size=2,
return_sparse_X=return_sparse_X,
use_eager_fetch=use_eager_fetch,
)
row_iter = iter(exp_data_pipe)

row = next(row_iter)
X_batch = row[0].to_dense() if return_sparse_X else row[0]
assert X_batch.int()[0].tolist() == [0, 1, 0]
assert row[1].tolist() == [[0], [1], [2]]
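
The sizes here are chosen so that every torch batch is stitched from two SOMA chunks; a tiny sanity check of that arithmetic (illustrative only):

```python
# 6 obs rows, SOMA chunks of 2, torch batches of 3: batch 0 = chunk 0 plus half
# of chunk 1, batch 1 = the rest of chunk 1 plus chunk 2. Every batch spans a
# chunk boundary, exercising the multi-chunk path in _ObsAndXIterator.__next__.
n_obs, soma_chunk_size, batch_size = 6, 2, 3
assert n_obs // batch_size == 2            # two full torch batches
assert batch_size % soma_chunk_size != 0   # batches never align with chunks
```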


@pytest.mark.experimental
# noinspection PyTestParametrized,DuplicatedCode
@pytest.mark.parametrize(