From 4a245adcaff4d147686e5d049129cc1c5fc20ffc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 10 May 2024 17:24:48 +0000 Subject: [PATCH 01/11] test done --- python/dgl/graphbolt/itemset.py | 200 ++++++++---------- .../pytorch/graphbolt/test_item_sampler.py | 79 ------- .../python/pytorch/graphbolt/test_itemset.py | 96 ++------- 3 files changed, 103 insertions(+), 272 deletions(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index ce96eb60e96e..041f2eaa0e10 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -1,7 +1,7 @@ """GraphBolt Itemset.""" import textwrap -from typing import Dict, Iterable, Iterator, Tuple, Union +from typing import Dict, Iterable, Iterator, Sequence, Tuple, Union import torch @@ -16,15 +16,16 @@ def is_scalar(x): class ItemSet: - r"""A wrapper of iterable data or tuple of iterable data. + r"""A wrapper of sequential data or tuple of sequential data. - All itemsets that represent an iterable of items should subclass it. Such + All itemsets that represent an sequence of items should subclass it. Such form of itemset is particularly useful when items come from a stream. This - class requires each input itemset to be iterable. + class requires each input itemset to be sequential (and iterable, off + course). Parameters ---------- - items: Union[int, Iterable, Tuple[Iterable]] + items: Union[int, Sequence, Tuple[Sequence]] The items to be iterated over. If it is a single integer, a `range()` object will be created and iterated over. If it's multi-dimensional iterable such as `torch.Tensor`, it will be iterated over the first @@ -129,31 +130,23 @@ class requires each input itemset to be iterable. 
def __init__( self, - items: Union[int, torch.Tensor, Iterable, Tuple[Iterable]], + items: Union[int, torch.Tensor, Sequence, Tuple[Sequence]], names: Union[str, Tuple[str]] = None, ) -> None: if is_scalar(items): self._length = int(items) self._items = items - self._num_items = 1 elif isinstance(items, tuple): - try: - self._length = len(items[0]) - except TypeError: - self._length = None - if self._length is not None: - if any(self._length != len(item) for item in items): - raise ValueError("Size mismatch between items.") + self._length = len(items[0]) + if any(self._length != len(item) for item in items): + raise ValueError("Size mismatch between items.") self._items = items - self._num_items = len(items) else: - try: - self._length = len(items) - except TypeError: - self._length = None + self._length = len(items) self._items = (items,) - self._num_items = 1 - + self._num_items = ( + len(self._items) if isinstance(self._items, tuple) else 1 + ) if names is not None: if isinstance(names, tuple): self._names = names @@ -161,65 +154,39 @@ def __init__( self._names = (names,) assert self._num_items == len(self._names), ( f"Number of items ({self._num_items}) and " - f"names ({len(self._names)}) must match." + f"names ({len(self._names)}) don't match." ) else: self._names = None - def __iter__(self) -> Iterator: - if is_scalar(self._items): - dtype = getattr(self._items, "dtype", torch.int64) - yield from torch.arange(self._items, dtype=dtype) - return - - if self._num_items == 1: - yield from self._items[0] - return - - if self._length is not None: - # Use for-loop to iterate over the items. It can avoid a long - # waiting time when the items are torch tensors. Since torch - # tensors need to call self.unbind(0) to slice themselves. - # While for-loops are slower than zip, they prevent excessive - # wait times during the loading phase, and the impact on overall - # performance during the training/testing stage is minimal. 
- # For more details, see https://github.com/dmlc/dgl/pull/6293. - for i in range(self._length): - yield tuple(item[i] for item in self._items) - else: - # If the items are not Sized, we use zip to iterate over them. - zip_items = zip(*self._items) - for item in zip_items: - yield tuple(item) + def __len__(self) -> int: + return self._length - def __getitem__(self, idx: Union[int, slice, Iterable]) -> Tuple: - if self._length is None: - raise TypeError( - f"{type(self).__name__} instance doesn't support indexing." - ) + def __getitem__(self, index: Union[int, slice, Iterable[int]]): if is_scalar(self._items): - if isinstance(idx, slice): - start, stop, step = idx.indices(self._length) - dtype = getattr(self._items, "dtype", torch.int64) + dtype = getattr(self._items, "dtype", torch.int64) + if isinstance(index, slice): + start, stop, step = index.indices(int(self._items)) return torch.arange(start, stop, step, dtype=dtype) - if isinstance(idx, int): - if idx < 0: - idx += self._length - if idx < 0 or idx >= self._length: + elif isinstance(index, int): + if index < 0: + index += int(self._items) + if index < 0 or index >= int(self._items): raise IndexError( f"{type(self).__name__} index out of range." ) - return ( - torch.tensor(idx, dtype=self._items.dtype) - if isinstance(self._items, torch.Tensor) - else idx + return torch.tensor(index, dtype=dtype) + elif isinstance(index, Iterable): + return torch.tensor(index, dtype=dtype) + else: + raise TypeError( + f"{type(self).__name__} indices must be int, slice, or " + f"iterable of int, not {type(index)}." ) - raise TypeError( - f"{type(self).__name__} indices must be integer or slice." 
- ) - if self._num_items == 1: - return self._items[0][idx] - return tuple(item[idx] for item in self._items) + elif self._num_items == 1: + return self._items[0][index] + else: + return tuple(item[index] for item in self._items) @property def names(self) -> Tuple[str]: @@ -231,13 +198,6 @@ def num_items(self) -> int: """Return the number of the items.""" return self._num_items - def __len__(self): - if self._length is None: - raise TypeError( - f"{type(self).__name__} instance doesn't have valid length." - ) - return self._length - def __repr__(self) -> str: ret = ( f"{self.__class__.__name__}(\n" @@ -245,7 +205,6 @@ def __repr__(self) -> str: f" names={self._names},\n" f")" ) - return ret @@ -359,55 +318,43 @@ class ItemSetDict: def __init__(self, itemsets: Dict[str, ItemSet]) -> None: self._itemsets = itemsets - self._names = itemsets[list(itemsets.keys())[0]].names + self._names = next(iter(itemsets.values())).names assert all( self._names == itemset.names for itemset in itemsets.values() ), "All itemsets must have the same names." - try: - # For indexable itemsets, we compute the offsets for each itemset - # in advance to speed up indexing. - offsets = [0] + [ - len(itemset) for itemset in self._itemsets.values() - ] - self._offsets = torch.tensor(offsets).cumsum(0) - except TypeError: - self._offsets = None - - def __iter__(self) -> Iterator: - for key, itemset in self._itemsets.items(): - for item in itemset: - yield {key: item} + offset = [0] + [len(itemset) for itemset in self._itemsets.values()] + self._offsets = torch.tensor(offset).cumsum(0) + self._length = int(self._offsets[-1]) + self._keys = list(self._itemsets.keys()) def __len__(self) -> int: return sum(len(itemset) for itemset in self._itemsets.values()) - def __getitem__(self, idx: Union[int, slice]) -> Dict[str, Tuple]: - if self._offsets is None: - raise TypeError( - f"{type(self).__name__} instance doesn't support indexing." 
- ) - total_num = self._offsets[-1] - if isinstance(idx, int): - if idx < 0: - idx += total_num - if idx < 0 or idx >= total_num: + def __len__(self) -> int: + return self._length + + def __getitem__(self, index: Union[int, slice, Iterable[int]]): + if isinstance(index, int): + if index < 0: + index += self._length + if index < 0 or index >= self._length: raise IndexError(f"{type(self).__name__} index out of range.") - offset_idx = torch.searchsorted(self._offsets, idx, right=True) + offset_idx = torch.searchsorted(self._offsets, index, right=True) offset_idx -= 1 - idx -= self._offsets[offset_idx] - key = list(self._itemsets.keys())[offset_idx] - return {key: self._itemsets[key][idx]} - elif isinstance(idx, slice): - start, stop, step = idx.indices(total_num) - assert step == 1, "Step must be 1." + index -= self._offsets[offset_idx] + key = self._keys[offset_idx] + return {key: self._itemsets[key][index]} + elif isinstance(index, slice): + start, stop, step = index.indices(self._length) + if step != 1: + return self.__getitem__(list(range(start, stop, step))) assert start < stop, "Start must be smaller than stop." data = {} offset_idx_start = max( 1, torch.searchsorted(self._offsets, start, right=False) ) - keys = list(self._itemsets.keys()) for offset_idx in range(offset_idx_start, len(self._offsets)): - key = keys[offset_idx - 1] + key = self._keys[offset_idx - 1] data[key] = self._itemsets[key][ max(0, start - self._offsets[offset_idx - 1]) : stop - self._offsets[offset_idx - 1] @@ -415,9 +362,38 @@ def __getitem__(self, idx: Union[int, slice]) -> Dict[str, Tuple]: if stop <= self._offsets[offset_idx]: break return data + elif isinstance(index, Iterable): + data = {key: [] for key in self._keys} + for idx in index: + if idx < 0: + idx += self._length + if idx < 0 or idx >= self._length: + raise IndexError( + f"{type(self).__name__} index out of range." 
+ ) + offset_idx = torch.searchsorted(self._offsets, idx, right=True) + offset_idx -= 1 + idx -= self._offsets[offset_idx] + key = self._keys[offset_idx] + data[key].append(int(idx)) + for key in self._keys: + indices = data[key] + if len(indices) == 0: + del data[key] + continue + item_set = self._itemsets[key] + try: + value = item_set[indices] + except TypeError: + # In case the itemset doesn't support list indexing. + value = tuple(item_set[idx] for idx in indices) + finally: + data[key] = value + return data else: raise TypeError( - f"{type(self).__name__} indices must be int or slice." + f"{type(self).__name__} indices must be int, slice, or " + f"iterable of int, not {type(index)}." ) @property diff --git a/tests/python/pytorch/graphbolt/test_item_sampler.py b/tests/python/pytorch/graphbolt/test_item_sampler.py index 127480d9aa15..744b81462e4a 100644 --- a/tests/python/pytorch/graphbolt/test_item_sampler.py +++ b/tests/python/pytorch/graphbolt/test_item_sampler.py @@ -68,39 +68,6 @@ def minibatcher(batch, names): assert len(minibatch.seeds) == 4 -@pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("shuffle", [True, False]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last): - num_ids = 103 - - class InvalidLength: - def __iter__(self): - return iter(torch.arange(0, num_ids)) - - seed_nodes = gb.ItemSet(InvalidLength()) - item_set = gb.ItemSet(seed_nodes, names="seeds") - item_sampler = gb.ItemSampler( - item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last - ) - minibatch_ids = [] - for i, minibatch in enumerate(item_sampler): - assert isinstance(minibatch, gb.MiniBatch) - assert minibatch.seeds is not None - assert minibatch.labels is None - is_last = (i + 1) * batch_size >= num_ids - if not is_last or num_ids % batch_size == 0: - assert len(minibatch.seeds) == batch_size - else: - if not drop_last: - assert len(minibatch.seeds) == num_ids % batch_size - 
else: - assert False - minibatch_ids.append(minibatch.seeds) - minibatch_ids = torch.cat(minibatch_ids) - assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle - - @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("drop_last", [True, False]) @@ -496,52 +463,6 @@ def test_append_with_other_datapipes(): assert len(data.seeds) == batch_size -@pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("shuffle", [True, False]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last): - class IterableOnly: - def __init__(self, start, stop): - self._start = start - self._stop = stop - - def __iter__(self): - return iter(torch.arange(self._start, self._stop)) - - num_ids = 205 - ids = { - "user": gb.ItemSet(IterableOnly(0, 99), names="seeds"), - "item": gb.ItemSet(IterableOnly(99, num_ids), names="seeds"), - } - chained_ids = [] - for key, value in ids.items(): - chained_ids += [(key, v) for v in value] - item_set = gb.ItemSetDict(ids) - item_sampler = gb.ItemSampler( - item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last - ) - minibatch_ids = [] - for i, minibatch in enumerate(item_sampler): - is_last = (i + 1) * batch_size >= num_ids - if not is_last or num_ids % batch_size == 0: - expected_batch_size = batch_size - else: - if not drop_last: - expected_batch_size = num_ids % batch_size - else: - assert False - assert isinstance(minibatch, gb.MiniBatch) - assert minibatch.seeds is not None - ids = [] - for _, v in minibatch.seeds.items(): - ids.append(v) - ids = torch.cat(ids) - assert len(ids) == expected_batch_size - minibatch_ids.append(ids) - minibatch_ids = torch.cat(minibatch_ids) - assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle - - @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("shuffle", [True, False]) 
@pytest.mark.parametrize("drop_last", [True, False]) diff --git a/tests/python/pytorch/graphbolt/test_itemset.py b/tests/python/pytorch/graphbolt/test_itemset.py index e41efcb2a2df..104d2a6ed8db 100644 --- a/tests/python/pytorch/graphbolt/test_itemset.py +++ b/tests/python/pytorch/graphbolt/test_itemset.py @@ -25,14 +25,14 @@ def test_ItemSet_names(): # Integer-initiated ItemSet with excessive names. with pytest.raises( AssertionError, - match=re.escape("Number of items (1) and names (2) must match."), + match=re.escape("Number of items (1) and names (2) don't match."), ): _ = gb.ItemSet(5, names=("seeds", "labels")) # ItemSet with mismatched items and names. with pytest.raises( AssertionError, - match=re.escape("Number of items (1) and names (2) must match."), + match=re.escape("Number of items (1) and names (2) don't match."), ): _ = gb.ItemSet(torch.arange(0, 5), names=("seeds", "labels")) @@ -72,37 +72,6 @@ def test_ItemSet_length(): assert i == item1.item() assert i + 5 == item2.item() - class InvalidLength: - def __iter__(self): - return iter([0, 1, 2]) - - # Single iterable with invalid length. - item_set = gb.ItemSet(InvalidLength()) - with pytest.raises( - TypeError, match="ItemSet instance doesn't have valid length." - ): - _ = len(item_set) - with pytest.raises( - TypeError, match="ItemSet instance doesn't support indexing." - ): - _ = item_set[0] - for i, item in enumerate(item_set): - assert i == item - - # Tuple of iterables with invalid length. - item_set = gb.ItemSet((InvalidLength(), InvalidLength())) - with pytest.raises( - TypeError, match="ItemSet instance doesn't have valid length." - ): - _ = len(item_set) - with pytest.raises( - TypeError, match="ItemSet instance doesn't support indexing." - ): - _ = item_set[0] - for i, (item1, item2) in enumerate(item_set): - assert i == item1 - assert i == item2 - def test_ItemSet_seed_nodes(): # Node IDs with tensor. 
@@ -113,7 +82,7 @@ def test_ItemSet_seed_nodes(): assert i == item.item() assert i == item_set[i] # Indexing with a slice. - assert torch.equal(item_set[:], torch.arange(0, 5)) + assert torch.equal(item_set[::2], torch.tensor([0, 2, 4])) # Indexing with an Iterable. assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5)) @@ -125,7 +94,8 @@ def test_ItemSet_seed_nodes(): assert i == item.item() assert i == item_set[i] # Indexing with a slice. - assert torch.equal(item_set[:], torch.arange(0, 5)) + assert torch.equal(item_set[::2], torch.tensor([0, 2, 4])) + assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5)) # Indexing with an integer. assert item_set[0] == 0 assert item_set[-1] == 4 @@ -134,11 +104,12 @@ def test_ItemSet_seed_nodes(): _ = item_set[5] with pytest.raises(IndexError, match="ItemSet index out of range."): _ = item_set[-10] - # Indexing with tensor. + # Indexing with invalid input type. with pytest.raises( - TypeError, match="ItemSet indices must be integer or slice." + TypeError, + match="ItemSet indices must be int, slice, or iterable of int, not .", ): - _ = item_set[torch.arange(3)] + _ = item_set[1.5] def test_ItemSet_seed_nodes_labels(): @@ -315,42 +286,6 @@ def test_ItemSetDict_length(): ) assert len(item_set) == node_pairs_like.size(0) + node_pairs_follow.size(0) - class InvalidLength: - def __iter__(self): - return iter([0, 1, 2]) - - # Single iterable with invalid length. - item_set = gb.ItemSetDict( - { - "user": gb.ItemSet(InvalidLength()), - "item": gb.ItemSet(InvalidLength()), - } - ) - with pytest.raises( - TypeError, match="ItemSet instance doesn't have valid length." - ): - _ = len(item_set) - with pytest.raises( - TypeError, match="ItemSetDict instance doesn't support indexing." - ): - _ = item_set[0] - - # Tuple of iterables with invalid length. 
- item_set = gb.ItemSetDict( - { - "user:like:item": gb.ItemSet((InvalidLength(), InvalidLength())), - "user:follow:user": gb.ItemSet((InvalidLength(), InvalidLength())), - } - ) - with pytest.raises( - TypeError, match="ItemSet instance doesn't have valid length." - ): - _ = len(item_set) - with pytest.raises( - TypeError, match="ItemSetDict instance doesn't support indexing." - ): - _ = item_set[0] - def test_ItemSetDict_iteration_seed_nodes(): # Node IDs. @@ -383,14 +318,12 @@ def test_ItemSetDict_iteration_seed_nodes(): partial_data = item_set[7:] assert len(list(partial_data.keys())) == 1 assert torch.equal(partial_data["item"], item_ids[2:]) - partial_data = item_set[3:7] + partial_data = item_set[3:8:2] assert len(list(partial_data.keys())) == 2 - assert torch.equal(partial_data["user"], user_ids[3:5]) - assert torch.equal(partial_data["item"], item_ids[:2]) + assert torch.equal(partial_data["user"], user_ids[3:-1:2]) + assert torch.equal(partial_data["item"], item_ids[0:3:2]) # Exception cases. - with pytest.raises(AssertionError, match="Step must be 1."): - _ = item_set[::2] with pytest.raises( AssertionError, match="Start must be smaller than stop." ): @@ -404,9 +337,10 @@ def test_ItemSetDict_iteration_seed_nodes(): with pytest.raises(IndexError, match="ItemSetDict index out of range."): _ = item_set[-20] with pytest.raises( - TypeError, match="ItemSetDict indices must be int or slice." 
+ TypeError, + match="ItemSetDict indices must be int, slice, or iterable of int, not .", ): - _ = item_set[torch.arange(3)] + _ = item_set[1.5] def test_ItemSetDict_iteration_seed_nodes_labels(): From 06ff62cf904655d86872eb52836573458a7e39e6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 10 May 2024 17:35:02 +0000 Subject: [PATCH 02/11] rename and rm --- python/dgl/graphbolt/itemset.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 041f2eaa0e10..f63036420ec3 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -166,12 +166,12 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): if is_scalar(self._items): dtype = getattr(self._items, "dtype", torch.int64) if isinstance(index, slice): - start, stop, step = index.indices(int(self._items)) + start, stop, step = index.indices(self._length) return torch.arange(start, stop, step, dtype=dtype) elif isinstance(index, int): if index < 0: - index += int(self._items) - if index < 0 or index >= int(self._items): + index += self._length + if index < 0 or index >= self._length: raise IndexError( f"{type(self).__name__} index out of range." 
) @@ -327,9 +327,6 @@ def __init__(self, itemsets: Dict[str, ItemSet]) -> None: self._length = int(self._offsets[-1]) self._keys = list(self._itemsets.keys()) - def __len__(self) -> int: - return sum(len(itemset) for itemset in self._itemsets.values()) - def __len__(self) -> int: return self._length From 70fc889e294ae95ab6efedc9360218e247791828 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 10 May 2024 18:06:14 +0000 Subject: [PATCH 03/11] itemset doc --- python/dgl/graphbolt/itemset.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index f63036420ec3..a987256b7ffc 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -18,24 +18,29 @@ def is_scalar(x): class ItemSet: r"""A wrapper of sequential data or tuple of sequential data. - All itemsets that represent an sequence of items should subclass it. Such - form of itemset is particularly useful when items come from a stream. This - class requires each input itemset to be sequential (and iterable, off - course). + This class requires each input item to be sequential (and, of course, + iterable), meaning that each item must have implemented `__getitem__` which + supports fetching a data for a given index, and `__len__` which is expected + to return the size of the item. + Parameters ---------- - items: Union[int, Sequence, Tuple[Sequence]] - The items to be iterated over. If it is a single integer, a `range()` - object will be created and iterated over. If it's multi-dimensional - iterable such as `torch.Tensor`, it will be iterated over the first - dimension. If it is a tuple, each item in the tuple is an iterable of - items. + items: Union[int, torch.Tensor, Sequence, Tuple[Sequence]] + The sequential items. + - If it is a single scalar (an integer or a tensor that holds a single + value), the item would be considered as a range_tensor created by + `torch.arange`. 
+ - If it is a multi-dimensional sequence such as `torch.Tensor`, the + indexing will be performed along the first dimension. + - If it is a tuple, each item in the tuple must be a sequence. + names: Union[str, Tuple[str]], optional - The names of the items. If it is a tuple, each name corresponds to an - item in the tuple. The naming is arbitrary, but in general practice, - the names should be chosen from ['labels', 'seeds', 'indexes'] to align - with the attributes of class `dgl.graphbolt.MiniBatch`. + The names of the items. If it is a tuple, each name must corresponds to + an item in the `items` parameter. The naming is arbitrary, but in + general practice, the names should be chosen from ['labels', 'seeds', + 'indexes'] to align with the attributes of class + `dgl.graphbolt.MiniBatch`. Examples -------- From 6f252279ab950972b8e828801e7894a4cff70273 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 10 May 2024 18:15:15 +0000 Subject: [PATCH 04/11] itemsetdice docstring --- python/dgl/graphbolt/itemset.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index a987256b7ffc..7fdac4d3176a 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -74,7 +74,7 @@ class ItemSet: >>> item_set.names ('seeds',) - 3. Single iterable: seed nodes. + 3. Single sequence: seed nodes. >>> node_ids = torch.arange(0, 5) >>> item_set = gb.ItemSet(node_ids, names="seeds") @@ -85,7 +85,7 @@ class ItemSet: >>> item_set.names ('seeds',) - 4. Tuple of iterables with same shape: seed nodes and labels. + 4. Tuple of sequences with same shape: seed nodes and labels. >>> node_ids = torch.arange(0, 5) >>> labels = torch.arange(5, 10) @@ -99,7 +99,7 @@ class ItemSet: >>> item_set.names ('seeds', 'labels') - 5. Tuple of iterables with different shape: seeds and labels. + 5. Tuple of sequences with different shape: seeds and labels. 
>>> seeds = torch.arange(0, 10).reshape(-1, 2) >>> labels = torch.tensor([1, 1, 0, 0, 0]) @@ -117,7 +117,7 @@ class ItemSet: >>> item_set.names ('seeds', 'labels') - 6. Tuple of iterables with different shape: hyperlink and labels. + 6. Tuple of sequences with different shape: hyperlink and labels. >>> seeds = torch.arange(0, 10).reshape(-1, 5) >>> labels = torch.tensor([1, 0]) @@ -216,8 +216,8 @@ def __repr__(self) -> str: class ItemSetDict: r"""Dictionary wrapper of **ItemSet**. - Each item is retrieved by iterating over each itemset and returned with - corresponding key as a dict. + This class is useful to assemble existing itemsets with different tags, for + example, seed_nodes of different node types in a graph. Parameters ---------- @@ -228,7 +228,7 @@ class ItemSetDict: >>> import torch >>> from dgl import graphbolt as gb - 1. Single iterable: seed nodes. + 1. Each itemset is a single sequence: seed nodes. >>> node_ids_user = torch.arange(0, 5) >>> node_ids_item = torch.arange(5, 10) @@ -245,7 +245,8 @@ class ItemSetDict: >>> item_set.names ('seeds',) - 2. Tuple of iterables with same shape: seed nodes and labels. + 2. Each itemset is a tuple of sequences with same shape: seed nodes and + labels. >>> node_ids_user = torch.arange(0, 2) >>> labels_user = torch.arange(0, 2) @@ -268,7 +269,8 @@ class ItemSetDict: >>> item_set.names ('seeds', 'labels') - 3. Tuple of iterables with different shape: seeds and labels. + 3. Each itemset is a tuple of sequences with different shape: seeds and + labels. >>> seeds_like = torch.arange(0, 4).reshape(-1, 2) >>> labels_like = torch.tensor([1, 0]) @@ -295,7 +297,8 @@ class ItemSetDict: >>> item_set.names ('seeds', 'labels') - 4. Tuple of iterables with different shape: hyperlink and labels. + 4. Each itemset is a tuple of sequences with different shape: hyperlink and + labels. 
>>> first_seeds = torch.arange(0, 6).reshape(-1, 3) >>> first_labels = torch.tensor([1, 0]) From f93ad91cbacad8765fe4c3766afa76a592b06587 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 10 May 2024 18:16:17 +0000 Subject: [PATCH 05/11] lint --- python/dgl/graphbolt/itemset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 7fdac4d3176a..39d1c51f9993 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -22,7 +22,7 @@ class ItemSet: iterable), meaning that each item must have implemented `__getitem__` which supports fetching a data for a given index, and `__len__` which is expected to return the size of the item. - + Parameters ---------- From a64786a6b2f89404da98377a90de321b352bd153 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 11 May 2024 04:50:23 +0000 Subject: [PATCH 06/11] lint1 --- python/dgl/graphbolt/itemset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 39d1c51f9993..57a06a1702a7 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -1,7 +1,7 @@ """GraphBolt Itemset.""" import textwrap -from typing import Dict, Iterable, Iterator, Sequence, Tuple, Union +from typing import Dict, Iterable, Sequence, Tuple, Union import torch From cd6c5c6a392255fca5eb066429ce59d614ecd113 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 11 May 2024 04:59:59 +0000 Subject: [PATCH 07/11] recover test of invalid length --- .../python/pytorch/graphbolt/test_itemset.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/python/pytorch/graphbolt/test_itemset.py b/tests/python/pytorch/graphbolt/test_itemset.py index 104d2a6ed8db..a5ca2aeba0d4 100644 --- a/tests/python/pytorch/graphbolt/test_itemset.py +++ b/tests/python/pytorch/graphbolt/test_itemset.py @@ -72,6 +72,22 @@ def test_ItemSet_length(): assert i == 
item1.item() assert i + 5 == item2.item() + class InvalidLength: + def __iter__(self): + return iter([0, 1, 2]) + + # Single iterable with invalid length. + with pytest.raises( + TypeError, match="object of type 'InvalidLength' has no len()" + ): + item_set = gb.ItemSet(InvalidLength()) + + # Tuple of iterables with invalid length. + with pytest.raises( + TypeError, match="object of type 'InvalidLength' has no len()" + ): + item_set = gb.ItemSet((InvalidLength(), InvalidLength())) + def test_ItemSet_seed_nodes(): # Node IDs with tensor. @@ -286,6 +302,36 @@ def test_ItemSetDict_length(): ) assert len(item_set) == node_pairs_like.size(0) + node_pairs_follow.size(0) + class InvalidLength: + def __iter__(self): + return iter([0, 1, 2]) + + # Single iterable with invalid length. + with pytest.raises( + TypeError, match="object of type 'InvalidLength' has no len()" + ): + item_set = gb.ItemSetDict( + { + "user": gb.ItemSet(InvalidLength()), + "item": gb.ItemSet(InvalidLength()), + } + ) + + # Tuple of iterables with invalid length. + with pytest.raises( + TypeError, match="object of type 'InvalidLength' has no len()" + ): + item_set = gb.ItemSetDict( + { + "user:like:item": gb.ItemSet( + (InvalidLength(), InvalidLength()) + ), + "user:follow:user": gb.ItemSet( + (InvalidLength(), InvalidLength()) + ), + } + ) + def test_ItemSetDict_iteration_seed_nodes(): # Node IDs. 
From 0d837567f6391efbb9d3564014a7bdffa5e2edc0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 13 May 2024 05:42:32 +0000 Subject: [PATCH 08/11] rm sequence --- python/dgl/graphbolt/itemset.py | 40 ++++++++++++++------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 57a06a1702a7..6a618ae4815e 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -1,7 +1,7 @@ """GraphBolt Itemset.""" import textwrap -from typing import Dict, Iterable, Sequence, Tuple, Union +from typing import Dict, Iterable, Tuple, Union import torch @@ -16,24 +16,18 @@ def is_scalar(x): class ItemSet: - r"""A wrapper of sequential data or tuple of sequential data. - - This class requires each input item to be sequential (and, of course, - iterable), meaning that each item must have implemented `__getitem__` which - supports fetching a data for a given index, and `__len__` which is expected - to return the size of the item. - + r"""A wrapper of a tensor or tuple of tensors. Parameters ---------- - items: Union[int, torch.Tensor, Sequence, Tuple[Sequence]] - The sequential items. + items: Union[int, torch.Tensor, Tuple[torch.Tensor]] + The tensors to be wrapped. - If it is a single scalar (an integer or a tensor that holds a single value), the item would be considered as a range_tensor created by `torch.arange`. - - If it is a multi-dimensional sequence such as `torch.Tensor`, the - indexing will be performed along the first dimension. - - If it is a tuple, each item in the tuple must be a sequence. + - If it is a multi-dimensional tensor, the indexing will be performed + along the first dimension. + - If it is a tuple, each item in the tuple must be a tensor. names: Union[str, Tuple[str]], optional The names of the items. If it is a tuple, each name must corresponds to @@ -74,7 +68,7 @@ class ItemSet: >>> item_set.names ('seeds',) - 3. Single sequence: seed nodes. + 3. 
Single tensor: seed nodes. >>> node_ids = torch.arange(0, 5) >>> item_set = gb.ItemSet(node_ids, names="seeds") @@ -85,7 +79,7 @@ class ItemSet: >>> item_set.names ('seeds',) - 4. Tuple of sequences with same shape: seed nodes and labels. + 4. Tuple of tensors with same shape: seed nodes and labels. >>> node_ids = torch.arange(0, 5) >>> labels = torch.arange(5, 10) @@ -99,7 +93,7 @@ class ItemSet: >>> item_set.names ('seeds', 'labels') - 5. Tuple of sequences with different shape: seeds and labels. + 5. Tuple of tensors with different shape: seeds and labels. >>> seeds = torch.arange(0, 10).reshape(-1, 2) >>> labels = torch.tensor([1, 1, 0, 0, 0]) @@ -117,7 +111,7 @@ class ItemSet: >>> item_set.names ('seeds', 'labels') - 6. Tuple of sequences with different shape: hyperlink and labels. + 6. Tuple of tensors with different shape: hyperlink and labels. >>> seeds = torch.arange(0, 10).reshape(-1, 5) >>> labels = torch.tensor([1, 0]) @@ -135,7 +129,7 @@ class ItemSet: def __init__( self, - items: Union[int, torch.Tensor, Sequence, Tuple[Sequence]], + items: Union[int, torch.Tensor, Tuple[torch.Tensor]], names: Union[str, Tuple[str]] = None, ) -> None: if is_scalar(items): @@ -216,7 +210,7 @@ def __repr__(self) -> str: class ItemSetDict: r"""Dictionary wrapper of **ItemSet**. - This class is useful to assemble existing itemsets with different tags, for + This class aims to assemble existing itemsets with different tags, for example, seed_nodes of different node types in a graph. Parameters @@ -228,7 +222,7 @@ class ItemSetDict: >>> import torch >>> from dgl import graphbolt as gb - 1. Each itemset is a single sequence: seed nodes. + 1. Each itemset is a single tensor: seed nodes. >>> node_ids_user = torch.arange(0, 5) >>> node_ids_item = torch.arange(5, 10) @@ -245,7 +239,7 @@ class ItemSetDict: >>> item_set.names ('seeds',) - 2. Each itemset is a tuple of sequences with same shape: seed nodes and + 2. 
Each itemset is a tuple of tensors with same shape: seed nodes and labels. >>> node_ids_user = torch.arange(0, 2) @@ -269,7 +263,7 @@ class ItemSetDict: >>> item_set.names ('seeds', 'labels') - 3. Each itemset is a tuple of sequences with different shape: seeds and + 3. Each itemset is a tuple of tensors with different shape: seeds and labels. >>> seeds_like = torch.arange(0, 4).reshape(-1, 2) @@ -297,7 +291,7 @@ class ItemSetDict: >>> item_set.names ('seeds', 'labels') - 4. Each itemset is a tuple of sequences with different shape: hyperlink and + 4. Each itemset is a tuple of tensors with different shape: hyperlink and labels. >>> first_seeds = torch.arange(0, 6).reshape(-1, 3) From d8ae48728500847a036cfb738c78ecb272ff16e0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 13 May 2024 07:25:15 +0000 Subject: [PATCH 09/11] add test for itemsetdict s list indexing --- tests/python/pytorch/graphbolt/test_itemset.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/python/pytorch/graphbolt/test_itemset.py b/tests/python/pytorch/graphbolt/test_itemset.py index a5ca2aeba0d4..ae0d923e2081 100644 --- a/tests/python/pytorch/graphbolt/test_itemset.py +++ b/tests/python/pytorch/graphbolt/test_itemset.py @@ -368,6 +368,17 @@ def test_ItemSetDict_iteration_seed_nodes(): assert len(list(partial_data.keys())) == 2 assert torch.equal(partial_data["user"], user_ids[3:-1:2]) assert torch.equal(partial_data["item"], item_ids[0:3:2]) + # Indexing with an iterable of int. 
+ partial_data = item_set[torch.tensor([1, 0, 4])] + assert len(list(partial_data.keys())) == 1 + assert torch.equal(partial_data["user"], torch.tensor([1, 0, 4])) + partial_data = item_set[torch.tensor([9, 8, 5])] + assert len(list(partial_data.keys())) == 1 + assert torch.equal(partial_data["item"], torch.tensor([9, 8, 5])) + partial_data = item_set[torch.tensor([8, 1, 0, 9, 7, 5])] + assert len(list(partial_data.keys())) == 2 + assert torch.equal(partial_data["user"], torch.tensor([1, 0])) + assert torch.equal(partial_data["item"], torch.tensor([8, 9, 7, 5])) # Exception cases. with pytest.raises( From 9de26a1546a8d3b2e777502ce8f0de50d81edf02 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 02:48:33 +0000 Subject: [PATCH 10/11] todo --- python/dgl/graphbolt/itemset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 6a618ae4815e..3797489df429 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -362,6 +362,7 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): break return data elif isinstance(index, Iterable): + # TODO[Mingbang]: The per-index loop may be slow for large inputs. Benchmark it.
data = {key: [] for key in self._keys} for idx in index: if idx < 0: From b3966ee42e9df6e58016d09709d88e1311fd7759 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 May 2024 03:02:00 +0000 Subject: [PATCH 11/11] use torch arange instead of list range --- python/dgl/graphbolt/itemset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 3797489df429..3f8472f856a3 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -346,7 +346,7 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): elif isinstance(index, slice): start, stop, step = index.indices(self._length) if step != 1: - return self.__getitem__(list(range(start, stop, step))) + return self.__getitem__(torch.arange(start, stop, step)) assert start < stop, "Start must be smaller than stop." data = {} offset_idx_start = max(