Skip to content

Commit

Permalink
[python] Fix ExperimentDataPipe length (#1221)
Browse files Browse the repository at this point in the history
* Fix ExperimentDataPipe length

* Update api/python/cellxgene_census/src/cellxgene_census/experimental/ml/pytorch.py

Co-authored-by: Isaac Virshup <ivirshup@gmail.com>

* Update api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py

Co-authored-by: Isaac Virshup <ivirshup@gmail.com>

---------

Co-authored-by: Isaac Virshup <ivirshup@gmail.com>
  • Loading branch information
ebezzi and ivirshup committed Jul 8, 2024
1 parent 8457e3f commit af88f3d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,8 @@ def __len__(self) -> int:
self._init()
assert self._obs_joinids is not None

- return len(self._obs_joinids)
+ div, rem = divmod(len(self._obs_joinids), self.batch_size)
+ return div + bool(rem)

def __getitem__(self, index: int) -> ObsAndXDatum:
    """Reject indexed access; this object can only be consumed by iteration.

    Args:
        index: Requested position. It is never used.

    Raises:
        NotImplementedError: Always, regardless of ``index``.
    """
    message = "IterDataPipe can only be iterated"
    raise NotImplementedError(message)
Expand Down
20 changes: 20 additions & 0 deletions api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,26 @@ def test_experiment_dataloader__batched(soma_experiment: Experiment, use_eager_f
assert batch[1].tolist() == [[0, 0], [1, 1], [2, 2]]


@pytest.mark.experimental
# noinspection PyTestParametrized,DuplicatedCode
@pytest.mark.parametrize(
    "obs_range,var_range,X_value_gen,use_eager_fetch",
    [
        (10, 3, pytorch_x_value_gen, True),
        (10, 3, pytorch_x_value_gen, False),
    ],
)
def test_experiment_dataloader__batched_length(soma_experiment: Experiment, use_eager_fetch: bool) -> None:
    """Check that the dataloader's reported length matches the number of batches it yields.

    Runs once per ``use_eager_fetch`` setting with ``obs_range=10`` and
    ``batch_size=3``. Presumably the fixture builds 10 obs rows, so the final
    batch would be a partial one -- confirm against the ``soma_experiment``
    fixture.
    """
    datapipe = ExperimentDataPipe(
        soma_experiment,
        measurement_name="RNA",
        X_name="raw",
        obs_column_names=["label"],
        batch_size=3,
        shuffle=False,
        use_eager_fetch=use_eager_fetch,
    )
    loader = experiment_dataloader(datapipe)

    # Query the advertised length first, then exhaust the loader to count
    # the batches it actually produces.
    reported_length = len(loader)
    yielded_batches = list(loader)
    assert reported_length == len(yielded_batches)


@pytest.mark.experimental
# noinspection PyTestParametrized,DuplicatedCode
@pytest.mark.parametrize(
Expand Down

0 comments on commit af88f3d

Please sign in to comment.