diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index ce96eb60e96e..3f8472f856a3 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -1,7 +1,7 @@ """GraphBolt Itemset.""" import textwrap -from typing import Dict, Iterable, Iterator, Tuple, Union +from typing import Dict, Iterable, Tuple, Union import torch @@ -16,25 +16,25 @@ def is_scalar(x): class ItemSet: - r"""A wrapper of iterable data or tuple of iterable data. - - All itemsets that represent an iterable of items should subclass it. Such - form of itemset is particularly useful when items come from a stream. This - class requires each input itemset to be iterable. + r"""A wrapper of a tensor or tuple of tensors. Parameters ---------- - items: Union[int, Iterable, Tuple[Iterable]] - The items to be iterated over. If it is a single integer, a `range()` - object will be created and iterated over. If it's multi-dimensional - iterable such as `torch.Tensor`, it will be iterated over the first - dimension. If it is a tuple, each item in the tuple is an iterable of - items. + items: Union[int, torch.Tensor, Tuple[torch.Tensor]] + The tensors to be wrapped. + - If it is a single scalar (an integer or a tensor that holds a single + value), the item would be considered as a range_tensor created by + `torch.arange`. + - If it is a multi-dimensional tensor, the indexing will be performed + along the first dimension. + - If it is a tuple, each item in the tuple must be a tensor. + names: Union[str, Tuple[str]], optional - The names of the items. If it is a tuple, each name corresponds to an - item in the tuple. The naming is arbitrary, but in general practice, - the names should be chosen from ['labels', 'seeds', 'indexes'] to align - with the attributes of class `dgl.graphbolt.MiniBatch`. + The names of the items. If it is a tuple, each name must corresponds to + an item in the `items` parameter. The naming is arbitrary, but in + general practice, the names should be chosen from ['labels', 'seeds', + 'indexes'] to align with the attributes of class + `dgl.graphbolt.MiniBatch`. Examples -------- @@ -68,7 +68,7 @@ class requires each input itemset to be iterable. >>> item_set.names ('seeds',) - 3. Single iterable: seed nodes. + 3. Single tensor: seed nodes. >>> node_ids = torch.arange(0, 5) >>> item_set = gb.ItemSet(node_ids, names="seeds") @@ -79,7 +79,7 @@ class requires each input itemset to be iterable. >>> item_set.names ('seeds',) - 4. Tuple of iterables with same shape: seed nodes and labels. + 4. Tuple of tensors with same shape: seed nodes and labels. >>> node_ids = torch.arange(0, 5) >>> labels = torch.arange(5, 10) @@ -93,7 +93,7 @@ class requires each input itemset to be iterable. >>> item_set.names ('seeds', 'labels') - 5. Tuple of iterables with different shape: seeds and labels. + 5. Tuple of tensors with different shape: seeds and labels. >>> seeds = torch.arange(0, 10).reshape(-1, 2) >>> labels = torch.tensor([1, 1, 0, 0, 0]) @@ -111,7 +111,7 @@ class requires each input itemset to be iterable. >>> item_set.names ('seeds', 'labels') - 6. Tuple of iterables with different shape: hyperlink and labels. + 6. Tuple of tensors with different shape: hyperlink and labels. >>> seeds = torch.arange(0, 10).reshape(-1, 5) >>> labels = torch.tensor([1, 0]) @@ -129,31 +129,23 @@ class requires each input itemset to be iterable. def __init__( self, - items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]], + items: Union[int, torch.Tensor, Tuple[torch.Tensor]], names: Union[str, Tuple[str]] = None, ) -> None: if is_scalar(items): self._length = int(items) self._items = items - self._num_items = 1 elif isinstance(items, tuple): - try: - self._length = len(items[0]) - except TypeError: - self._length = None - if self._length is not None: - if any(self._length != len(item) for item in items): - raise ValueError("Size mismatch between items.") + self._length = len(items[0]) + if any(self._length != len(item) for item in items): + raise ValueError("Size mismatch between items.") self._items = items - self._num_items = len(items) else: - try: - self._length = len(items) - except TypeError: - self._length = None + self._length = len(items) self._items = (items,) - self._num_items = 1 - + self._num_items = ( + len(self._items) if isinstance(self._items, tuple) else 1 + ) if names is not None: if isinstance(names, tuple): self._names = names @@ -161,65 +153,39 @@ def __init__( self._names = (names,) assert self._num_items == len(self._names), ( f"Number of items ({self._num_items}) and " - f"names ({len(self._names)}) must match." + f"names ({len(self._names)}) don't match." ) else: self._names = None - def __iter__(self) -> Iterator: - if is_scalar(self._items): - dtype = getattr(self._items, "dtype", torch.int64) - yield from torch.arange(self._items, dtype=dtype) - return - - if self._num_items == 1: - yield from self._items[0] - return - - if self._length is not None: - # Use for-loop to iterate over the items. It can avoid a long - # waiting time when the items are torch tensors. Since torch - # tensors need to call self.unbind(0) to slice themselves. - # While for-loops are slower than zip, they prevent excessive - # wait times during the loading phase, and the impact on overall - # performance during the training/testing stage is minimal. - # For more details, see https://github.com/dmlc/dgl/pull/6293. - for i in range(self._length): - yield tuple(item[i] for item in self._items) - else: - # If the items are not Sized, we use zip to iterate over them. - zip_items = zip(*self._items) - for item in zip_items: - yield tuple(item) + def __len__(self) -> int: + return self._length - def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple: - if self._length is None: - raise TypeError( - f"{type(self).__name__} instance doesn't support indexing." - ) + def __getitem__(self, index: Union[int, slice, Iterable[int]]): if is_scalar(self._items): - if isinstance(idx, slice): - start, stop, step = idx.indices(self._length) - dtype = getattr(self._items, "dtype", torch.int64) + dtype = getattr(self._items, "dtype", torch.int64) + if isinstance(index, slice): + start, stop, step = index.indices(self._length) return torch.arange(start, stop, step, dtype=dtype) - if isinstance(idx, int): - if idx < 0: - idx += self._length - if idx < 0 or idx >= self._length: + elif isinstance(index, int): + if index < 0: + index += self._length + if index < 0 or index >= self._length: raise IndexError( f"{type(self).__name__} index out of range." ) - return ( - torch.tensor(idx, dtype=self._items.dtype) - if isinstance(self._items, torch.Tensor) - else idx + return torch.tensor(index, dtype=dtype) + elif isinstance(index, Iterable): + return torch.tensor(index, dtype=dtype) + else: + raise TypeError( + f"{type(self).__name__} indices must be int, slice, or " + f"iterable of int, not {type(index)}." ) - raise TypeError( - f"{type(self).__name__} indices must be integer or slice." - ) - if self._num_items == 1: - return self._items[0][idx] - return tuple(item[idx] for item in self._items) + elif self._num_items == 1: + return self._items[0][index] + else: + return tuple(item[index] for item in self._items) @property def names(self) -> Tuple[str]: @@ -231,13 +197,6 @@ def num_items(self) -> int: """Return the number of the items.""" return self._num_items - def __len__(self): - if self._length is None: - raise TypeError( - f"{type(self).__name__} instance doesn't have valid length." - ) - return self._length - def __repr__(self) -> str: ret = ( f"{self.__class__.__name__}(\n" @@ -245,15 +204,14 @@ def __repr__(self) -> str: f" names={self._names},\n" f")" ) - return ret class ItemSetDict: r"""Dictionary wrapper of **ItemSet**. - Each item is retrieved by iterating over each itemset and returned with - corresponding key as a dict. + This class aims to assemble existing itemsets with different tags, for + example, seed_nodes of different node types in a graph. Parameters ---------- @@ -264,7 +222,7 @@ class ItemSetDict: >>> import torch >>> from dgl import graphbolt as gb - 1. Single iterable: seed nodes. + 1. Each itemset is a single tensor: seed nodes. >>> node_ids_user = torch.arange(0, 5) >>> node_ids_item = torch.arange(5, 10) @@ -281,7 +239,8 @@ class ItemSetDict: >>> item_set.names ('seeds',) - 2. Tuple of iterables with same shape: seed nodes and labels. + 2. Each itemset is a tuple of tensors with same shape: seed nodes and + labels. >>> node_ids_user = torch.arange(0, 2) >>> labels_user = torch.arange(0, 2) @@ -304,7 +263,8 @@ class ItemSetDict: >>> item_set.names ('seeds', 'labels') - 3. Tuple of iterables with different shape: seeds and labels. + 3. Each itemset is a tuple of tensors with different shape: seeds and + labels. >>> seeds_like = torch.arange(0, 4).reshape(-1, 2) >>> labels_like = torch.tensor([1, 0]) @@ -331,7 +291,8 @@ class ItemSetDict: >>> item_set.names ('seeds', 'labels') - 4. Tuple of iterables with different shape: hyperlink and labels. + 4. Each itemset is a tuple of tensors with different shape: hyperlink and + labels. >>> first_seeds = torch.arange(0, 6).reshape(-1, 3) >>> first_labels = torch.tensor([1, 0]) @@ -359,55 +320,40 @@ class ItemSetDict: def __init__(self, itemsets: Dict[str, ItemSet]) -> None: self._itemsets = itemsets - self._names = itemsets[list(itemsets.keys())[0]].names + self._names = next(iter(itemsets.values())).names assert all( self._names == itemset.names for itemset in itemsets.values() ), "All itemsets must have the same names." - try: - # For indexable itemsets, we compute the offsets for each itemset - # in advance to speed up indexing. - offsets = [0] + [ - len(itemset) for itemset in self._itemsets.values() - ] - self._offsets = torch.tensor(offsets).cumsum(0) - except TypeError: - self._offsets = None - - def __iter__(self) -> Iterator: - for key, itemset in self._itemsets.items(): - for item in itemset: - yield {key: item} + offset = [0] + [len(itemset) for itemset in self._itemsets.values()] + self._offsets = torch.tensor(offset).cumsum(0) + self._length = int(self._offsets[-1]) + self._keys = list(self._itemsets.keys()) def __len__(self) -> int: - return sum(len(itemset) for itemset in self._itemsets.values()) + return self._length - def __getitem__(self, idx: Union[int, slice]) -> Dict[str, Tuple]: - if self._offsets is None: - raise TypeError( - f"{type(self).__name__} instance doesn't support indexing." - ) - total_num = self._offsets[-1] - if isinstance(idx, int): - if idx < 0: - idx += total_num - if idx < 0 or idx >= total_num: + def __getitem__(self, index: Union[int, slice, Iterable[int]]): + if isinstance(index, int): + if index < 0: + index += self._length + if index < 0 or index >= self._length: raise IndexError(f"{type(self).__name__} index out of range.") - offset_idx = torch.searchsorted(self._offsets, idx, right=True) + offset_idx = torch.searchsorted(self._offsets, index, right=True) offset_idx -= 1 - idx -= self._offsets[offset_idx] - key = list(self._itemsets.keys())[offset_idx] - return {key: self._itemsets[key][idx]} - elif isinstance(idx, slice): - start, stop, step = idx.indices(total_num) - assert step == 1, "Step must be 1." + index -= self._offsets[offset_idx] + key = self._keys[offset_idx] + return {key: self._itemsets[key][index]} + elif isinstance(index, slice): + start, stop, step = index.indices(self._length) + if step != 1: + return self.__getitem__(torch.arange(start, stop, step)) assert start < stop, "Start must be smaller than stop." data = {} offset_idx_start = max( 1, torch.searchsorted(self._offsets, start, right=False) ) - keys = list(self._itemsets.keys()) for offset_idx in range(offset_idx_start, len(self._offsets)): - key = keys[offset_idx - 1] + key = self._keys[offset_idx - 1] data[key] = self._itemsets[key][ max(0, start - self._offsets[offset_idx - 1]) : stop - self._offsets[offset_idx - 1] @@ -415,9 +361,39 @@ def __getitem__(self, idx: Union[int, slice]) -> Dict[str, Tuple]: if stop <= self._offsets[offset_idx]: break return data + elif isinstance(index, Iterable): + # TODO[Mingbang]: Might have performance issue. Tests needed. + data = {key: [] for key in self._keys} + for idx in index: + if idx < 0: + idx += self._length + if idx < 0 or idx >= self._length: + raise IndexError( + f"{type(self).__name__} index out of range." + ) + offset_idx = torch.searchsorted(self._offsets, idx, right=True) + offset_idx -= 1 + idx -= self._offsets[offset_idx] + key = self._keys[offset_idx] + data[key].append(int(idx)) + for key in self._keys: + indices = data[key] + if len(indices) == 0: + del data[key] + continue + item_set = self._itemsets[key] + try: + value = item_set[indices] + except TypeError: + # In case the itemset doesn't support list indexing. + value = tuple(item_set[idx] for idx in indices) + finally: + data[key] = value + return data else: raise TypeError( - f"{type(self).__name__} indices must be int or slice." + f"{type(self).__name__} indices must be int, slice, or " + f"iterable of int, not {type(index)}." ) @property diff --git a/tests/python/pytorch/graphbolt/test_item_sampler.py b/tests/python/pytorch/graphbolt/test_item_sampler.py index 127480d9aa15..744b81462e4a 100644 --- a/tests/python/pytorch/graphbolt/test_item_sampler.py +++ b/tests/python/pytorch/graphbolt/test_item_sampler.py @@ -68,39 +68,6 @@ def minibatcher(batch, names): assert len(minibatch.seeds) == 4 -@pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("shuffle", [True, False]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last): - num_ids = 103 - - class InvalidLength: - def __iter__(self): - return iter(torch.arange(0, num_ids)) - - seed_nodes = gb.ItemSet(InvalidLength()) - item_set = gb.ItemSet(seed_nodes, names="seeds") - item_sampler = gb.ItemSampler( - item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last - ) - minibatch_ids = [] - for i, minibatch in enumerate(item_sampler): - assert isinstance(minibatch, gb.MiniBatch) - assert minibatch.seeds is not None - assert minibatch.labels is None - is_last = (i + 1) * batch_size >= num_ids - if not is_last or num_ids % batch_size == 0: - assert len(minibatch.seeds) == batch_size - else: - if not drop_last: - assert len(minibatch.seeds) == num_ids % batch_size - else: - assert False - minibatch_ids.append(minibatch.seeds) - minibatch_ids = torch.cat(minibatch_ids) - assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle - - @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("drop_last", [True, False]) @@ -496,52 +463,6 @@ def test_append_with_other_datapipes(): assert len(data.seeds) == batch_size -@pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("shuffle", [True, False]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last): - class IterableOnly: - def __init__(self, start, stop): - self._start = start - self._stop = stop - - def __iter__(self): - return iter(torch.arange(self._start, self._stop)) - - num_ids = 205 - ids = { - "user": gb.ItemSet(IterableOnly(0, 99), names="seeds"), - "item": gb.ItemSet(IterableOnly(99, num_ids), names="seeds"), - } - chained_ids = [] - for key, value in ids.items(): - chained_ids += [(key, v) for v in value] - item_set = gb.ItemSetDict(ids) - item_sampler = gb.ItemSampler( - item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last - ) - minibatch_ids = [] - for i, minibatch in enumerate(item_sampler): - is_last = (i + 1) * batch_size >= num_ids - if not is_last or num_ids % batch_size == 0: - expected_batch_size = batch_size - else: - if not drop_last: - expected_batch_size = num_ids % batch_size - else: - assert False - assert isinstance(minibatch, gb.MiniBatch) - assert minibatch.seeds is not None - ids = [] - for _, v in minibatch.seeds.items(): - ids.append(v) - ids = torch.cat(ids) - assert len(ids) == expected_batch_size - minibatch_ids.append(ids) - minibatch_ids = torch.cat(minibatch_ids) - assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle - - @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("drop_last", [True, False]) diff --git a/tests/python/pytorch/graphbolt/test_itemset.py b/tests/python/pytorch/graphbolt/test_itemset.py index e41efcb2a2df..ae0d923e2081 100644 --- a/tests/python/pytorch/graphbolt/test_itemset.py +++ b/tests/python/pytorch/graphbolt/test_itemset.py @@ -25,14 +25,14 @@ def test_ItemSet_names(): # Integer-initiated ItemSet with excessive names. with pytest.raises( AssertionError, - match=re.escape("Number of items (1) and names (2) must match."), + match=re.escape("Number of items (1) and names (2) don't match."), ): _ = gb.ItemSet(5, names=("seeds", "labels")) # ItemSet with mismatched items and names. with pytest.raises( AssertionError, - match=re.escape("Number of items (1) and names (2) must match."), + match=re.escape("Number of items (1) and names (2) don't match."), ): _ = gb.ItemSet(torch.arange(0, 5), names=("seeds", "labels")) @@ -77,31 +77,16 @@ def __iter__(self): return iter([0, 1, 2]) # Single iterable with invalid length. - item_set = gb.ItemSet(InvalidLength()) with pytest.raises( - TypeError, match="ItemSet instance doesn't have valid length." + TypeError, match="object of type 'InvalidLength' has no len()" ): - _ = len(item_set) - with pytest.raises( - TypeError, match="ItemSet instance doesn't support indexing." - ): - _ = item_set[0] - for i, item in enumerate(item_set): - assert i == item + item_set = gb.ItemSet(InvalidLength()) # Tuple of iterables with invalid length. - item_set = gb.ItemSet((InvalidLength(), InvalidLength())) with pytest.raises( - TypeError, match="ItemSet instance doesn't have valid length." + TypeError, match="object of type 'InvalidLength' has no len()" ): - _ = len(item_set) - with pytest.raises( - TypeError, match="ItemSet instance doesn't support indexing." - ): - _ = item_set[0] - for i, (item1, item2) in enumerate(item_set): - assert i == item1 - assert i == item2 + item_set = gb.ItemSet((InvalidLength(), InvalidLength())) def test_ItemSet_seed_nodes(): @@ -113,7 +98,7 @@ def test_ItemSet_seed_nodes(): assert i == item.item() assert i == item_set[i] # Indexing with a slice. - assert torch.equal(item_set[:], torch.arange(0, 5)) + assert torch.equal(item_set[::2], torch.tensor([0, 2, 4])) # Indexing with an Iterable. assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5)) @@ -125,7 +110,8 @@ def test_ItemSet_seed_nodes(): assert i == item.item() assert i == item_set[i] # Indexing with a slice. - assert torch.equal(item_set[:], torch.arange(0, 5)) + assert torch.equal(item_set[::2], torch.tensor([0, 2, 4])) + assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5)) # Indexing with an integer. assert item_set[0] == 0 assert item_set[-1] == 4 @@ -134,11 +120,12 @@ def test_ItemSet_seed_nodes(): _ = item_set[5] with pytest.raises(IndexError, match="ItemSet index out of range."): _ = item_set[-10] - # Indexing with tensor. + # Indexing with invalid input type. with pytest.raises( - TypeError, match="ItemSet indices must be integer or slice." + TypeError, + match="ItemSet indices must be int, slice, or iterable of int, not .", ): - _ = item_set[torch.arange(3)] + _ = item_set[1.5] def test_ItemSet_seed_nodes_labels(): @@ -320,36 +307,30 @@ def __iter__(self): return iter([0, 1, 2]) # Single iterable with invalid length. - item_set = gb.ItemSetDict( - { - "user": gb.ItemSet(InvalidLength()), - "item": gb.ItemSet(InvalidLength()), - } - ) with pytest.raises( - TypeError, match="ItemSet instance doesn't have valid length." + TypeError, match="object of type 'InvalidLength' has no len()" ): - _ = len(item_set) - with pytest.raises( - TypeError, match="ItemSetDict instance doesn't support indexing." - ): - _ = item_set[0] + item_set = gb.ItemSetDict( + { + "user": gb.ItemSet(InvalidLength()), + "item": gb.ItemSet(InvalidLength()), + } + ) # Tuple of iterables with invalid length. - item_set = gb.ItemSetDict( - { - "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())), - "user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())), - } - ) - with pytest.raises( - TypeError, match="ItemSet instance doesn't have valid length." - ): - _ = len(item_set) with pytest.raises( - TypeError, match="ItemSetDict instance doesn't support indexing." + TypeError, match="object of type 'InvalidLength' has no len()" ): - _ = item_set[0] + item_set = gb.ItemSetDict( + { + "user:like:item": gb.ItemSet( + (InvalidLength(), InvalidLength()) + ), + "user:follow:user": gb.ItemSet( + (InvalidLength(), InvalidLength()) + ), + } + ) def test_ItemSetDict_iteration_seed_nodes(): @@ -383,14 +364,23 @@ def test_ItemSetDict_iteration_seed_nodes(): partial_data = item_set[7:] assert len(list(partial_data.keys())) == 1 assert torch.equal(partial_data["item"], item_ids[2:]) - partial_data = item_set[3:7] + partial_data = item_set[3:8:2] + assert len(list(partial_data.keys())) == 2 + assert torch.equal(partial_data["user"], user_ids[3:-1:2]) + assert torch.equal(partial_data["item"], item_ids[0:3:2]) + # Indexing with an iterable of int. + partial_data = item_set[torch.tensor([1, 0, 4])] + assert len(list(partial_data.keys())) == 1 + assert torch.equal(partial_data["user"], torch.tensor([1, 0, 4])) + partial_data = item_set[torch.tensor([9, 8, 5])] + assert len(list(partial_data.keys())) == 1 + assert torch.equal(partial_data["item"], torch.tensor([9, 8, 5])) + partial_data = item_set[torch.tensor([8, 1, 0, 9, 7, 5])] assert len(list(partial_data.keys())) == 2 - assert torch.equal(partial_data["user"], user_ids[3:5]) - assert torch.equal(partial_data["item"], item_ids[:2]) + assert torch.equal(partial_data["user"], torch.tensor([1, 0])) + assert torch.equal(partial_data["item"], torch.tensor([8, 9, 7, 5])) # Exception cases. - with pytest.raises(AssertionError, match="Step must be 1."): - _ = item_set[::2] with pytest.raises( AssertionError, match="Start must be smaller than stop." ): @@ -404,9 +394,10 @@ def test_ItemSetDict_iteration_seed_nodes(): with pytest.raises(IndexError, match="ItemSetDict index out of range."): _ = item_set[-20] with pytest.raises( - TypeError, match="ItemSetDict indices must be int or slice." + TypeError, + match="ItemSetDict indices must be int, slice, or iterable of int, not .", ): - _ = item_set[torch.arange(3)] + _ = item_set[1.5] def test_ItemSetDict_iteration_seed_nodes_labels():