diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index ce96eb60e96e..9dbe5c074196 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -1,11 +1,12 @@ """GraphBolt Itemset.""" import textwrap -from typing import Dict, Iterable, Iterator, Tuple, Union +from typing import Dict, Iterable, Iterator, Mapping, Tuple, Union import torch +from torch.utils.data import Dataset -__all__ = ["ItemSet", "ItemSetDict"] +__all__ = ["ItemSet", "ItemSetDict", "ItemSet4", "ItemSetDict4"] def is_scalar(x): @@ -442,3 +443,188 @@ def __repr__(self) -> str: itemsets=itemsets_str, names=self._names, ) + + +class ItemSet4(Dataset): + r"""Class for iterating over tensor-like data. + Experimental. Implemented only __getitem__() accepting slice and list. + """ + + def __init__( + self, + items: Union[torch.Tensor, Mapping, Tuple[Mapping]], + names: Union[str, Tuple[str]] = None, + ): + if is_scalar(items): + self._length = int(items) + self._items = items + elif isinstance(items, tuple): + self._length = len(items[0]) + if any(self._length != len(item) for item in items): + raise ValueError("Size mismatch between items.") + self._items = items + else: + self._length = len(items) + self._items = (items,) + if names is not None: + num_items = ( + len(self._items) if isinstance(self._items, tuple) else 1 + ) + if isinstance(names, tuple): + self._names = names + else: + self._names = (names,) + assert num_items == len(self._names), ( + f"Number of items ({num_items}) and " + f"names ({len(self._names)}) don't match." + ) + else: + self._names = None + + def __len__(self) -> int: + return self._length + + def __getitem__(self, index: Union[int, slice, Iterable[int]]): + if is_scalar(self._items): + if isinstance(index, slice): + start, stop, step = index.indices(int(self._items)) + dtype = getattr(self._items, "dtype", torch.int64) + return torch.arange(start, stop, step, dtype=dtype) + elif isinstance(index, int): + if index < 0: + index += int(self._items) + if index < 0 or index >= int(self._items): + raise IndexError( + f"{type(self).__name__} index out of range." + ) + return torch.tensor(index, dtype=self._items.dtype) + elif isinstance(index, Iterable): + dtype = getattr(self._items, "dtype", torch.int64) + return torch.tensor(index, dtype=dtype) + else: + raise TypeError( + f"{type(self).__name__} indices must be int, slice, or " + f"iterable of int, but got {type(index)}." + ) + elif len(self._items) == 1: + return self._items[0][index] + else: + return tuple(item[index] for item in self._items) + + @property + def names(self) -> Tuple[str]: + """Return the names of the items.""" + return self._names + + def __repr__(self) -> str: + _repr = ( + f"{self.__class__.__name__}(\n" + f" items={self._items},\n" + f" names={self._names},\n" + f")" + ) + return _repr + + +class ItemSetDict4(Dataset): + r"""Experimental.""" + + def __init__(self, itemsets: Dict[str, ItemSet4]) -> None: + super().__init__() + self._itemsets = itemsets + self._names = next(iter(itemsets.values())).names + if any(self._names != itemset.names for itemset in itemsets.values()): + raise ValueError("All itemsets must have the same names.") + offset = [0] + [len(itemset) for itemset in self._itemsets.values()] + self._offsets = torch.tensor(offset).cumsum(0) + self._length = int(self._offsets[-1]) + self._keys = list(self._itemsets.keys()) + + def __len__(self) -> int: + return self._length + + def __getitem__(self, index: Union[int, slice, Iterable[int]]): + if isinstance(index, int): + if index < 0: + index += self._length + if index < 0 or index >= self._length: + raise IndexError(f"{type(self).__name__} index out of range.") + offset_idx = torch.searchsorted(self._offsets, index, right=True) + offset_idx -= 1 + index -= self._offsets[offset_idx] + key = self._keys[offset_idx] + return {key: self._itemsets[key][index]} + elif isinstance(index, slice): + start, stop, step = index.indices(self._length) + # print(f"slice: {slice}, start, stop, step: {(start, stop, step)}") + # print(f"res list: {list(range(start, stop, step))}") + if step != 1: + return self.__getitem__(list(range(start, stop, step))) + assert start < stop, "Start must be smaller than stop." + data = {} + offset_idx_start = max( + 1, torch.searchsorted(self._offsets, start, right=False) + ) + for offset_idx in range(offset_idx_start, len(self._offsets)): + key = self._keys[offset_idx - 1] + data[key] = self._itemsets[key][ + max(0, start - self._offsets[offset_idx - 1]) : stop + - self._offsets[offset_idx - 1] + ] + if stop <= self._offsets[offset_idx]: + break + return data + elif isinstance(index, Iterable): + data = {key: [] for key in self._keys} + for idx in index: + if idx < 0: + idx += self._length + if idx < 0 or idx >= self._length: + raise IndexError( + f"{type(self).__name__} index out of range." + ) + offset_idx = torch.searchsorted(self._offsets, idx, right=True) + offset_idx -= 1 + idx -= self._offsets[offset_idx] + key = self._keys[offset_idx] + data[key].append(int(idx)) + for key in self._keys: + indices = data[key] + if len(indices) == 0: + del data[key] + continue + item_set = self._itemsets[key] + try: + value = item_set[indices] + except TypeError: + # In case the itemset doesn't support list indexing. + value = tuple(item_set[idx] for idx in indices) + finally: + data[key] = value + return data + else: + raise TypeError( + f"{type(self).__name__} indices must be int, slice, or " + f"iterable of int, but got {type(index)}." + ) + + @property + def names(self) -> Tuple[str]: + """Return the names of the items.""" + return self._names + + def __repr__(self) -> str: + _repr = ( + "{Classname}(\n" + " itemsets={itemsets},\n" + " names={names},\n" + ")" + ) + itemsets_str = textwrap.indent( + repr(self._itemsets), " " * len(" itemsets=") + ).strip() + return _repr.format( + Classname=self.__class__.__name__, + itemsets=itemsets_str, + names=self._names, + ) diff --git a/tests/python/pytorch/graphbolt/test_itemset.py b/tests/python/pytorch/graphbolt/test_itemset.py index e41efcb2a2df..57c47a7dd1f1 100644 --- a/tests/python/pytorch/graphbolt/test_itemset.py +++ b/tests/python/pytorch/graphbolt/test_itemset.py @@ -615,3 +615,33 @@ def test_ItemSetDict_repr(): ")" ) assert str(item_set) == expected_str, item_set + + +def test_ItemSetDict4_indexing_with_list_of_integers(): + """Test indexing a ItemSetdict4 with iterable of integers.""" + item_set = gb.ItemSetDict4( + { + "user": gb.ItemSet(torch.arange(0, 5), names="seeds"), + "item": gb.ItemSet(torch.arange(5, 10), names="seeds"), + } + ) + indexing_res = item_set[1, 2, 3, 9, 8, 5] + assert torch.equal(indexing_res["user"], torch.tensor([1, 2, 3])) + assert torch.equal(indexing_res["item"], torch.tensor([9, 8, 5])) + + +def test_ItemSetDict4_slicing_with_step_not_equal_to_1(): + """Test indexing a ItemSetdict4 with slice whose step is other than 1.""" + item_set = gb.ItemSetDict4( + { + "user": gb.ItemSet(torch.arange(0, 5), names="seeds"), + "item": gb.ItemSet(torch.arange(5, 10), names="seeds"), + } + ) + res = item_set[::2] + assert torch.equal(res["user"], torch.tensor([0, 2, 4])) + assert torch.equal(res["item"], torch.tensor([6, 8])) + + res1 = item_set[::-2] + assert torch.equal(res1["user"], torch.tensor([3, 1])) + assert torch.equal(res1["item"], torch.tensor([9, 7, 5]))