[Easy][BE]: remove itertools.accumulate Python 2 shim and apply UFMT (pytorch#116192)

Removes an unnecessarily duplicated utility function and relies on itertools.accumulate from the standard library instead. Since the files are low-traffic, I also added them to the set of UFMT-formatted files and formatted them.

Pull Request resolved: pytorch#116192
Approved by: https://github.com/malfet
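For context, the removed _accumulate helper was a Python 2-era copy of itertools.accumulate, which is available in every supported Python 3 version. A minimal sketch of the equivalence (example values are hypothetical, mirroring the removed function's doc comments):

import itertools
import operator

values = [1, 2, 3, 4, 5]  # hypothetical input

# Running sums: the removed helper's default behavior.
print(list(itertools.accumulate(values)))  # [1, 3, 6, 10, 15]

# With an explicit binary function, e.g. running products.
print(list(itertools.accumulate(values, operator.mul)))  # [1, 2, 6, 24, 120]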
Skylion007 authored and dmenig committed Dec 21, 2023
1 parent 9e2932f commit 35e9f5c
Showing 3 changed files with 59 additions and 45 deletions.
1 change: 0 additions & 1 deletion .lintrunner.toml
@@ -2544,7 +2544,6 @@ exclude_patterns = [
     'torch/utils/data/datapipes/utils/common.py',
     'torch/utils/data/datapipes/utils/decoder.py',
     'torch/utils/data/datapipes/utils/snapshot.py',
-    'torch/utils/data/dataset.py',
     'torch/utils/data/distributed.py',
     'torch/utils/data/graph.py',
     'torch/utils/data/graph_settings.py',
16 changes: 0 additions & 16 deletions torch/_utils.py
@@ -493,22 +493,6 @@ def _import_dotted_name(name):
     return obj


-# Taken from python 3.5 docs
-def _accumulate(iterable, fn=lambda x, y: x + y):
-    "Return running totals"
-    # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
-    # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
-    it = iter(iterable)
-    try:
-        total = next(it)
-    except StopIteration:
-        return
-    yield total
-    for element in it:
-        total = fn(total, element)
-        yield total
-
-
 def _flatten_dense_tensors(tensors):
     """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
     same dense type.
87 changes: 59 additions & 28 deletions torch/utils/data/dataset.py
@@ -1,7 +1,10 @@
 import bisect
-import warnings
+import itertools
 import math
+import warnings
 from typing import (
+    cast,
+    Dict,
     Generic,
     Iterable,
     List,
@@ -10,12 +13,10 @@
     Tuple,
     TypeVar,
     Union,
-    Dict
 )

 # No 'default_generator' in torch/__init__.pyi
 from torch import default_generator, randperm
-from torch._utils import _accumulate

 from ... import Generator, Tensor
@@ -30,11 +31,11 @@
     "random_split",
 ]

-T_co = TypeVar('T_co', covariant=True)
-T = TypeVar('T')
+T_co = TypeVar("T_co", covariant=True)
+T = TypeVar("T")
 T_dict = Dict[str, T_co]
 T_tuple = Tuple[T_co, ...]
-T_stack = TypeVar('T_stack', T_tuple, T_dict)
+T_stack = TypeVar("T_stack", T_tuple, T_dict)


 class Dataset(Generic[T_co]):
@@ -63,7 +64,7 @@ def __getitem__(self, index) -> T_co:
     # Not implemented to prevent false-positives in fetcher check in
     # torch.utils.data._utils.fetch._MapDatasetFetcher

-    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
+    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
         return ConcatDataset([self, other])

     # No `def __len__(self)` default?
@@ -199,7 +200,9 @@ class TensorDataset(Dataset[Tuple[Tensor, ...]]):
     tensors: Tuple[Tensor, ...]

     def __init__(self, *tensors: Tensor) -> None:
-        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
+        assert all(
+            tensors[0].size(0) == tensor.size(0) for tensor in tensors
+        ), "Size mismatch between tensors"
         self.tensors = tensors

     def __getitem__(self, index):
@@ -233,8 +236,10 @@ class StackDataset(Dataset[T_stack]):
     def __init__(self, *args: Dataset[T_co], **kwargs: Dataset[T_co]) -> None:
         if args:
             if kwargs:
-                raise ValueError("Supported either ``tuple``- (via ``args``) or"
-                                 "``dict``- (via ``kwargs``) like input/output, but both types are given.")
+                raise ValueError(
+                    "Supported either ``tuple``- (via ``args``) or"
+                    "``dict``- (via ``kwargs``) like input/output, but both types are given."
+                )
             self._length = len(args[0])  # type: ignore[arg-type]
             if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                 raise ValueError("Size mismatch between datasets")
@@ -261,8 +266,10 @@ def __getitems__(self, indices: list):
                 if callable(getattr(dataset, "__getitems__", None)):
                     items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                     if len(items) != len(indices):
-                        raise ValueError("Nested dataset's output size mismatch."
-                                         f" Expected {len(indices)}, got {len(items)}")
+                        raise ValueError(
+                            "Nested dataset's output size mismatch."
+                            f" Expected {len(indices)}, got {len(items)}"
+                        )
                     for data, d_sample in zip(items, dict_batch):
                         d_sample[k] = data
                 else:
@@ -276,8 +283,10 @@ def __getitems__(self, indices: list):
             if callable(getattr(dataset, "__getitems__", None)):
                 items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                 if len(items) != len(indices):
-                    raise ValueError("Nested dataset's output size mismatch."
-                                     f" Expected {len(indices)}, got {len(items)}")
+                    raise ValueError(
+                        "Nested dataset's output size mismatch."
+                        f" Expected {len(indices)}, got {len(items)}"
+                    )
                 for data, t_sample in zip(items, list_batch):
                     t_sample.append(data)
             else:
@@ -314,9 +323,11 @@ def cumsum(sequence):
     def __init__(self, datasets: Iterable[Dataset]) -> None:
         super().__init__()
         self.datasets = list(datasets)
-        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
+        assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
         for d in self.datasets:
-            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
+            assert not isinstance(
+                d, IterableDataset
+            ), "ConcatDataset does not support IterableDataset"
         self.cumulative_sizes = self.cumsum(self.datasets)

     def __len__(self):
@@ -325,7 +336,9 @@ def __len__(self):
     def __getitem__(self, idx):
         if idx < 0:
             if -idx > len(self):
-                raise ValueError("absolute value of index should not exceed dataset length")
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
             idx = len(self) + idx
         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
         if dataset_idx == 0:
@@ -336,8 +349,11 @@ def __getitem__(self, idx):

     @property
     def cummulative_sizes(self):
-        warnings.warn("cummulative_sizes attribute is renamed to "
-                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
+        warnings.warn(
+            "cummulative_sizes attribute is renamed to " "cumulative_sizes",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         return self.cumulative_sizes


@@ -358,13 +374,17 @@ def __init__(self, datasets: Iterable[Dataset]) -> None:

     def __iter__(self):
         for d in self.datasets:
-            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
+            assert isinstance(
+                d, IterableDataset
+            ), "ChainDataset only supports IterableDataset"
             yield from d

     def __len__(self):
         total = 0
         for d in self.datasets:
-            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
+            assert isinstance(
+                d, IterableDataset
+            ), "ChainDataset only supports IterableDataset"
             total += len(d)  # type: ignore[arg-type]
         return total

@@ -402,8 +422,11 @@ def __len__(self):
         return len(self.indices)


-def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
-                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
+def random_split(
+    dataset: Dataset[T],
+    lengths: Sequence[Union[int, float]],
+    generator: Optional[Generator] = default_generator,
+) -> List[Subset[T]]:
     r"""
     Randomly split a dataset into non-overlapping new datasets of given lengths.

@@ -446,12 +469,20 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]],
     lengths = subset_lengths
     for i, length in enumerate(lengths):
         if length == 0:
-            warnings.warn(f"Length of split at index {i} is 0. "
-                          f"This might result in an empty dataset.")
+            warnings.warn(
+                f"Length of split at index {i} is 0. "
+                f"This might result in an empty dataset."
+            )

     # Cannot verify that dataset is Sized
-    if sum(lengths) != len(dataset): # type: ignore[arg-type]
-        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
+    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
+        raise ValueError(
+            "Sum of input lengths does not equal the length of the input dataset!"
+        )

     indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
-    return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
+    lengths = cast(Sequence[int], lengths)
+    return [
+        Subset(dataset, indices[offset - length : offset])
+        for offset, length in zip(itertools.accumulate(lengths), lengths)
+    ]

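To make the new random_split slicing concrete, here is a small, self-contained sketch of the offset arithmetic (the lengths and indices values are hypothetical stand-ins for the real permutation):

import itertools

lengths = [2, 3, 5]  # hypothetical split sizes
indices = list(range(sum(lengths)))  # stand-in for randperm(10).tolist()

# itertools.accumulate yields running end-offsets (2, 5, 10); pairing each
# with its length recovers the [offset - length : offset] slice per split.
splits = [
    indices[offset - length : offset]
    for offset, length in zip(itertools.accumulate(lengths), lengths)
]
print(splits)  # [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]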