Skip to content

Commit

Permalink
[DataLoader] __getitems__ added to description of Dataset API and b…
Browse files Browse the repository at this point in the history
…etter supported within `Subset` (pytorch#100375)

DataLoader supports batched loading from Mapped Datasets.

This is the fetcher's implementation of auto-detection of batch loading support.

torch.utils.data._utils.fetch._MapDatasetFetcher
```
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
```

Description of Dataset API now shows this feature.

Additionally, Subset dataset now supports `__getitems__` if parent dataset supports it.
Pull Request resolved: pytorch#100375
Approved by: https://github.com/ejguan, https://github.com/NivekT
  • Loading branch information
stsouko authored and kiersten-stokes committed May 8, 2023
1 parent d8ec904 commit bc41e94
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion torch/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ class Dataset(Generic[T_co]):
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
of :class:`~torch.utils.data.DataLoader`. Subclasses could also
optionally implement :meth:`__getitems__`, for speedup batched samples
loading. This method accepts list of indices of samples of batch and returns
list of samples.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
Expand All @@ -52,6 +55,10 @@ class Dataset(Generic[T_co]):
def __getitem__(self, index) -> T_co:
raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

# def __getitems__(self, indices: List) -> List[T_co]:
# Not implemented to prevent false-positives in fetcher check in
# torch.utils.data._utils.fetch._MapDatasetFetcher

def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])

Expand Down Expand Up @@ -296,6 +303,14 @@ def __getitem__(self, idx):
return self.dataset[[self.indices[i] for i in idx]]
return self.dataset[self.indices[idx]]

def __getitems__(self, indices: List[int]) -> List[T_co]:
# add batched sampling support when parent dataset supports it.
# see torch.utils.data._utils.fetch._MapDatasetFetcher
if callable(getattr(self.dataset, "__getitems__", None)):
return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined]
else:
return [self.dataset[self.indices[idx]] for idx in indices]

def __len__(self):
return len(self.indices)

Expand Down

0 comments on commit bc41e94

Please sign in to comment.