diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py
index 3df09bf852ea..c8b53eb73b74 100644
--- a/examples/multigpu/graphbolt/node_classification.py
+++ b/examples/multigpu/graphbolt/node_classification.py
@@ -36,6 +36,7 @@
       │
       └───> Test set evaluation
 """
+
 import argparse
 import os
 import time
diff --git a/python/dgl/graphbolt/item_sampler.py b/python/dgl/graphbolt/item_sampler.py
index f92798e6b5cf..2cd2379380b3 100644
--- a/python/dgl/graphbolt/item_sampler.py
+++ b/python/dgl/graphbolt/item_sampler.py
@@ -15,10 +15,16 @@
 from ..batch import batch as dgl_batch
 from ..heterograph import DGLGraph
 from .internal import calculate_range
-from .itemset import ItemSet, ItemSetDict
+from .itemset import ItemSet, ItemSet4, ItemSetDict, ItemSetDict4
 from .minibatch import MiniBatch
 
-__all__ = ["ItemSampler", "DistributedItemSampler", "minibatcher_default"]
+__all__ = [
+    "ItemSampler",
+    "DistributedItemSampler",
+    "minibatcher_default",
+    "ItemSampler4",
+    "DistributedItemSampler4",
+]
 
 
 def minibatcher_default(batch, names):
@@ -833,3 +839,201 @@ def _construct_seeds(pos_seeds, neg_srcs=None, neg_dsts=None):
             neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
             indexes[etype] = torch.cat((pos_indexes, neg_indexes))
     return seeds, labels, indexes
+
+
+class ItemSampler4(IterDataPipe):
+    """Experimental item sampler for indexable item sets.
+
+    On every iteration the range of items assigned to the current worker
+    is fetched from the item set with a single ``__getitem__`` call and
+    mini-batches are then collated from that in-memory buffer.
+
+    Parameters
+    ----------
+    item_set : Union[ItemSet4, ItemSetDict4]
+        Data to be sampled.
+    batch_size : int
+        Number of items in a mini-batch.
+    minibatcher : Optional[Callable]
+        Called as ``minibatcher(batch, names)`` on every collated batch;
+        its return value is what the sampler yields.
+    drop_last : Optional[bool]
+        Forwarded to ``calculate_range``.
+    shuffle : Optional[bool]
+        If True, items are read through a random permutation seeded with
+        ``seed + epoch``; the epoch advances after each completed pass.
+    """
+
+    def __init__(
+        self,
+        item_set: Union[ItemSet4, ItemSetDict4],
+        batch_size: int,
+        minibatcher: Optional[Callable] = minibatcher_default,
+        drop_last: Optional[bool] = False,
+        shuffle: Optional[bool] = False,
+    ) -> None:
+        super().__init__()
+        self._names = item_set.names
+        self._item_set = item_set
+        self._batch_size = batch_size
+        self._minibatcher = minibatcher
+        self._drop_last = drop_last
+        self._shuffle = shuffle
+        # Single-process defaults; DistributedItemSampler4 overrides them.
+        self._distributed = False
+        self._drop_uneven_inputs = False
+        self._world_size = None
+        self._rank = None
+        self._seed = np.random.randint(0, np.iinfo(np.int32).max)
+        self._epoch = 0
+
+    def _collate_batch(self, buffer, indices, offsets=None):
+        """Collate a batch from the buffer. For internal use only.
+
+        ``indices`` are positions inside ``buffer``. ``offsets`` is only
+        used for mapping buffers, where it holds the cumulative sizes of
+        the values as returned by ``_calculate_offsets``.
+        """
+        if isinstance(buffer, torch.Tensor):
+            # For item set that's initialized with integer or single tensor,
+            # `buffer` is a tensor.
+            return torch.index_select(buffer, dim=0, index=indices)
+        elif isinstance(buffer, list) and isinstance(buffer[0], DGLGraph):
+            # For item set that's initialized with a list of
+            # DGLGraphs, `buffer` is a list of DGLGraphs.
+            return dgl_batch([buffer[idx] for idx in indices])
+        elif isinstance(buffer, tuple):
+            # For item set that's initialized with a tuple of items,
+            # `buffer` is a tuple of tensors.
+            return tuple(item[indices] for item in buffer)
+        elif isinstance(buffer, Mapping):
+            # For item set that's initialized with a dict of items,
+            # `buffer` is a dict of tensors/lists/tuples.
+            keys = list(buffer.keys())
+            # Map every flat index to the key whose range contains it.
+            key_indices = torch.searchsorted(offsets, indices, right=True) - 1
+            batch = {}
+            for j, key in enumerate(keys):
+                mask = (key_indices == j).nonzero().squeeze(1)
+                if len(mask) == 0:
+                    continue
+                # Shift to indices local to this key's buffer.
+                batch[key] = self._collate_batch(
+                    buffer[key], indices[mask] - offsets[j]
+                )
+            return batch
+        raise TypeError(f"Unsupported buffer type {type(buffer).__name__}.")
+
+    def _calculate_offsets(self, buffer):
+        """Calculate offsets for each item in buffer. For internal use only.
+
+        Returns None for non-mapping buffers; otherwise a tensor holding 0
+        followed by the running totals of the value sizes.
+        """
+        if not isinstance(buffer, Mapping):
+            return None
+        offsets = [0]
+        for value in buffer.values():
+            if isinstance(value, torch.Tensor):
+                offsets.append(offsets[-1] + len(value))
+            elif isinstance(value, tuple):
+                offsets.append(offsets[-1] + len(value[0]))
+            else:
+                raise TypeError(
+                    f"Unsupported buffer type {type(value).__name__}."
+                )
+        return torch.tensor(offsets)
+
+    def __iter__(self) -> Iterator:
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            num_workers = worker_info.num_workers
+            worker_id = worker_info.id
+        else:
+            num_workers = 1
+            worker_id = 0
+        total = len(self._item_set)
+        start_offset, assigned_count, output_count = calculate_range(
+            self._distributed,
+            total,
+            self._world_size,
+            self._rank,
+            num_workers,
+            worker_id,
+            self._batch_size,
+            self._drop_last,
+            self._drop_uneven_inputs,
+        )
+        if self._shuffle:
+            # Samplers sharing `_seed` and `_epoch` build the same permutation.
+            g = torch.Generator()
+            g.manual_seed(self._seed + self._epoch)
+            indices = torch.randperm(total, generator=g)
+            buffer = self._item_set[
+                indices[start_offset : start_offset + assigned_count]
+            ]
+        else:
+            buffer = self._item_set[
+                start_offset : start_offset + assigned_count
+            ]
+        offsets = self._calculate_offsets(buffer)
+        for i in range(0, assigned_count, self._batch_size):
+            if output_count <= 0:
+                break
+            batch_indices = torch.arange(
+                i, i + min(self._batch_size, output_count)
+            )
+            output_count -= self._batch_size
+            yield self._minibatcher(
+                self._collate_batch(buffer, batch_indices, offsets), self._names
+            )
+
+        self._epoch += 1
+
+
+class DistributedItemSampler4(ItemSampler4):
+    """Experimental distributed counterpart of ItemSampler4.
+
+    Reads world size and rank from ``torch.distributed`` and replaces the
+    random seed with the one broadcast from rank 0 so that all ranks share
+    it. ``drop_uneven_inputs`` is forwarded to ``calculate_range``.
+    """
+
+    def __init__(
+        self,
+        item_set: Union[ItemSet4, ItemSetDict4],
+        batch_size: int,
+        minibatcher: Optional[Callable] = minibatcher_default,
+        drop_last: Optional[bool] = False,
+        shuffle: Optional[bool] = False,
+        drop_uneven_inputs: Optional[bool] = False,
+    ) -> None:
+        super().__init__(
+            item_set,
+            batch_size,
+            minibatcher,
+            drop_last,
+            shuffle,
+        )
+        self._distributed = True
+        self._drop_uneven_inputs = drop_uneven_inputs
+        if not dist.is_available():
+            raise RuntimeError(
+                "Distributed item sampler requires distributed package."
+            )
+        self._world_size = dist.get_world_size()
+        self._rank = dist.get_rank()
+
+        # With the NCCL backend the seed tensor goes on the current GPU.
+        device = (
+            torch.cuda.current_device()
+            if torch.cuda.is_available() and dist.get_backend() == "nccl"
+            else "cpu"
+        )
+        if self._rank == 0:
+            seed = np.random.randint(0, np.iinfo(np.int32).max)
+            seed_tensor = torch.tensor(seed, dtype=torch.int32, device=device)
+        else:
+            seed_tensor = torch.empty([], dtype=torch.int32, device=device)
+        dist.broadcast(seed_tensor, src=0)
+        self._seed = seed_tensor.item()
diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py
index ce96eb60e96e..6ec83c680e7f 100644
--- a/python/dgl/graphbolt/itemset.py
+++ b/python/dgl/graphbolt/itemset.py
@@ -1,11 +1,12 @@
 """GraphBolt Itemset."""
 
 import textwrap
-from typing import Dict, Iterable, Iterator, Tuple, Union
+from typing import Dict, Iterable, Iterator, List, Mapping, Tuple, Union
 
 import torch
+from torch.utils.data import Dataset
 
-__all__ = ["ItemSet", "ItemSetDict"]
+__all__ = ["ItemSet", "ItemSetDict", "ItemSet4", "ItemSetDict4"]
 
 
 def is_scalar(x):
@@ -442,3 +443,226 @@ def __repr__(self) -> str:
             itemsets=itemsets_str,
             names=self._names,
         )
+
+
+class ItemSet4(Dataset):
+    r"""Experimental indexable item set.
+
+    Wraps a scalar ``N`` (standing for the ids ``0 .. N-1``), a single
+    indexable object, or a tuple of equally sized indexable objects, and
+    serves them through ``__len__`` and ``__getitem__``.
+
+    Parameters
+    ----------
+    items : scalar, indexable object or tuple of indexable objects
+        The items. All members of a tuple must have the same length.
+    names : Union[str, Tuple[str]], optional
+        One name per item.
+    """
+
+    def __init__(
+        self,
+        items: Union[torch.Tensor, Mapping, Tuple[Mapping]],
+        names: Union[str, Tuple[str]] = None,
+    ):
+        super().__init__()
+        if is_scalar(items):
+            self._length = int(items)
+            self._items = items
+        elif isinstance(items, tuple):
+            self._length = len(items[0])
+            if any(self._length != len(item) for item in items):
+                raise ValueError("Size mismatch between items.")
+            self._items = items
+        else:
+            self._length = len(items)
+            self._items = (items,)
+        if names is not None:
+            num_items = (
+                len(self._items) if isinstance(self._items, tuple) else 1
+            )
+            if isinstance(names, tuple):
+                self._names = names
+            else:
+                self._names = (names,)
+            assert num_items == len(self._names), (
+                f"Number of items ({num_items}) and "
+                f"names ({len(self._names)}) must match."
+            )
+        else:
+            self._names = None
+
+    def __len__(self) -> int:
+        return self._length
+
+    def __getitem__(self, index: Union[int, slice, List[int]]):
+        """Return the item(s) selected by ``index``.
+
+        A scalar-backed item set returns the selected ids as a tensor and
+        accepts an int, slice, list of int or index tensor. Otherwise
+        ``index`` is applied to every wrapped object and a tuple is
+        returned when more than one object is wrapped.
+        """
+        if is_scalar(self._items):
+            # A plain Python scalar has no ``dtype``; fall back to int64 in
+            # every branch instead of only some of them.
+            dtype = getattr(self._items, "dtype", torch.int64)
+            if isinstance(index, slice):
+                start, stop, step = index.indices(self._length)
+                return torch.arange(start, stop, step, dtype=dtype)
+            elif isinstance(index, int):
+                if index < 0:
+                    index += self._length
+                if index < 0 or index >= self._length:
+                    raise IndexError(
+                        f"{type(self).__name__} index out of range."
+                    )
+                return torch.tensor(index, dtype=dtype)
+            elif isinstance(index, list):
+                return torch.tensor(index, dtype=dtype)
+            elif isinstance(index, torch.Tensor):
+                # E.g. a slice of the permutation built by ItemSampler4.
+                return index.to(dtype)
+            else:
+                raise TypeError(
+                    f"{type(self).__name__} indices must be int, slice or list of int."
+                )
+        elif len(self._items) == 1:
+            return self._items[0][index]
+        else:
+            return tuple(item[index] for item in self._items)
+
+    @property
+    def names(self) -> Tuple[str]:
+        """Return the names of the items."""
+        return self._names
+
+    def __repr__(self) -> str:
+        ret = (
+            f"{self.__class__.__name__}(\n"
+            f"    items={self._items},\n"
+            f"    names={self._names},\n"
+            f")"
+        )
+        return ret
+
+
+class ItemSetDict4(Dataset):
+    r"""Experimental dictionary of :class:`ItemSet4` objects.
+
+    Indices address the concatenation of all item sets in dictionary
+    order. ``__getitem__`` returns a dict holding, for every key that the
+    index touches, the items selected from that key's item set.
+    """
+
+    def __init__(self, itemsets: Dict[str, ItemSet4]) -> None:
+        super().__init__()
+        self._itemsets = itemsets
+        self._names = next(iter(itemsets.values())).names
+        self._length = sum(len(itemset) for itemset in itemsets.values())
+        if any(self._names != itemset.names for itemset in itemsets.values()):
+            raise ValueError("All itemsets must have the same names.")
+        # ``_offsets[i]`` is the global index of the first item of the i-th
+        # item set; the last entry is the total number of items.
+        offset = [0] + [len(itemset) for itemset in self._itemsets.values()]
+        self._offsets = torch.tensor(offset).cumsum(0)
+
+    def __len__(self) -> int:
+        return self._length
+
+    def __getitem__(self, index: Union[int, slice, List[int]]):
+        """Return ``{key: items}`` for the global ``index``.
+
+        ``index`` may be an int, a slice with step 1, a list of int or an
+        index tensor. Only keys that receive at least one index appear in
+        the result.
+        """
+        total_num = self._length
+        keys = list(self._itemsets.keys())
+        if isinstance(index, torch.Tensor):
+            # A 0-dim tensor becomes an int, anything else a list.
+            index = index.tolist()
+        if isinstance(index, int):
+            if index < 0:
+                index += total_num
+            if index < 0 or index >= total_num:
+                raise IndexError(f"{type(self).__name__} index out of range.")
+            offset_idx = (
+                int(torch.searchsorted(self._offsets, index, right=True)) - 1
+            )
+            key = keys[offset_idx]
+            # Keep the local index a Python int so that scalar-backed item
+            # sets, which test ``isinstance(index, int)``, accept it.
+            local = index - int(self._offsets[offset_idx])
+            return {key: self._itemsets[key][local]}
+        elif isinstance(index, slice):
+            start, stop, step = index.indices(total_num)
+            assert step == 1, "Step must be 1."
+            if start >= stop:
+                # Empty range, e.g. a worker that has no items assigned.
+                return {}
+            data = {}
+            # ``right=True`` so that a start lying exactly on a boundary
+            # selects the item set beginning there, not the one before it.
+            first = int(torch.searchsorted(self._offsets, start, right=True))
+            for offset_idx in range(first, len(self._offsets)):
+                lower = int(self._offsets[offset_idx - 1])
+                upper = int(self._offsets[offset_idx])
+                if upper > lower:  # Skip empty item sets.
+                    key = keys[offset_idx - 1]
+                    data[key] = self._itemsets[key][
+                        max(0, start - lower) : stop - lower
+                    ]
+                if stop <= upper:
+                    break
+            return data
+        elif isinstance(index, list):
+            indices = torch.tensor(index, dtype=torch.int64)
+            indices = torch.where(indices < 0, indices + total_num, indices)
+            if ((indices < 0) | (indices >= total_num)).any():
+                raise IndexError(f"{type(self).__name__} index out of range.")
+            key_ids = (
+                torch.searchsorted(self._offsets, indices, right=True) - 1
+            )
+            # Collect the local indices per key. Building the containers with
+            # ``dict.fromkeys(keys, [])`` would make all keys share one list.
+            data = {}
+            for key_id, key in enumerate(keys):
+                local = (
+                    indices[key_ids == key_id] - self._offsets[key_id]
+                ).tolist()
+                if not local:
+                    continue
+                item_set = self._itemsets[key]
+                try:
+                    data[key] = item_set[local]
+                except TypeError:
+                    # Fall back to item-wise access for item sets that do
+                    # not accept a list index.
+                    data[key] = tuple(item_set[idx] for idx in local)
+            return data
+        else:
+            raise TypeError(
+                f"{type(self).__name__} indices must be int, slice or list of int."
+            )
+
+    @property
+    def names(self) -> Tuple[str]:
+        """Return the names of the items."""
+        return self._names
+
+    def __repr__(self) -> str:
+        ret = (
+            "{Classname}(\n"
+            "    itemsets={itemsets},\n"
+            "    names={names},\n"
+            ")"
+        )
+        itemsets_str = textwrap.indent(
+            repr(self._itemsets), " " * len("    itemsets=")
+        ).strip()
+        return ret.format(
+            Classname=self.__class__.__name__,
+            itemsets=itemsets_str,
+            names=self._names,
+        )
diff --git a/tests/python/pytorch/graphbolt/test_item_sampler.py b/tests/python/pytorch/graphbolt/test_item_sampler.py
index 127480d9aa15..a1b721ae6d0e 100644
--- a/tests/python/pytorch/graphbolt/test_item_sampler.py
+++ b/tests/python/pytorch/graphbolt/test_item_sampler.py
@@ -4,8 +4,6 @@
 from collections import defaultdict
 from sys import platform
 
-import backend as F
-
 import dgl
 import pytest
 import torch
@@ -1151,7 +1149,9 @@ def test_RangeCalculation(params):
     assert key == answer
 
 
-@unittest.skipIf(F._default_context_str != "cpu", reason="GPU not required.")
+@unittest.skipIf(
+    os.getenv("DGLTESTDEV", "cpu") != "cpu", reason="GPU not required."
+)
 @pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
 @pytest.mark.parametrize("num_workers", [0, 2])
 @pytest.mark.parametrize("drop_last", [False, True])
@@ -1185,3 +1185,108 @@ def test_DistributedItemSampler(
         nprocs=nprocs,
         join=True,
     )
+
+
+def distributed_item_sampler_subprocess4(
+    proc_id,
+    nprocs,
+    item_set,
+    num_ids,
+    num_workers,
+    batch_size,
+    drop_last,
+    drop_uneven_inputs,
+):
+    """Check on one rank that no item is sampled more than once."""
+    # On Windows, the init method can only be file.
+    init_method = (
+        f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}"
+        if platform == "win32"
+        else "tcp://127.0.0.1:12345"
+    )
+    dist.init_process_group(
+        backend="gloo",  # Use Gloo backend for CPU multiprocessing
+        init_method=init_method,
+        world_size=nprocs,
+        rank=proc_id,
+    )
+
+    # Create a DistributedItemSampler4.
+    item_sampler = gb.DistributedItemSampler4(
+        item_set,
+        batch_size=batch_size,
+        shuffle=True,
+        drop_last=drop_last,
+        drop_uneven_inputs=drop_uneven_inputs,
+    )
+    feature_fetcher = gb.FeatureFetcher(
+        item_sampler,
+        gb.BasicFeatureStore({}),
+        [],
+    )
+    data_loader = gb.DataLoader(feature_fetcher, num_workers=num_workers)
+
+    # Count the numbers of items and batches.
+    num_items = 0
+    sampled_count = torch.zeros(num_ids, dtype=torch.int32)
+    for i in data_loader:
+        # Count how many times each item is sampled.
+        sampled_count[i.seeds] += 1
+        if drop_last:
+            assert i.seeds.size(0) == batch_size
+        num_items += i.seeds.size(0)
+    num_batches = len(list(item_sampler))
+
+    if drop_uneven_inputs:
+        num_batches_tensor = torch.tensor(num_batches)
+        dist.broadcast(num_batches_tensor, 0)
+        # Test if the number of batches are the same for all processes.
+        assert num_batches_tensor == num_batches
+
+    # Add up results from all processes.
+    dist.reduce(sampled_count, 0)
+
+    try:
+        # Make sure no item is sampled more than once.
+        assert sampled_count.max() <= 1
+    finally:
+        dist.destroy_process_group()
+
+
+@unittest.skipIf(
+    os.getenv("DGLTESTDEV", "cpu") != "cpu", reason="GPU not required."
+)
+@pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
+@pytest.mark.parametrize("num_workers", [0, 2])
+@pytest.mark.parametrize("drop_last", [False, True])
+@pytest.mark.parametrize("drop_uneven_inputs", [False, True])
+def test_DistributedItemSampler4(
+    num_ids, num_workers, drop_last, drop_uneven_inputs
+):
+    """Spawn 4 processes that each run the DistributedItemSampler4 check."""
+    nprocs = 4
+    batch_size = 4
+    item_set = gb.ItemSet4(torch.arange(0, num_ids), names="seeds")
+
+    # On Windows, if the process group initialization file already exists,
+    # the program may hang. So we need to delete it if it exists.
+    if platform == "win32":
+        try:
+            os.remove(os.path.join(os.getcwd(), "dis_tempfile"))
+        except FileNotFoundError:
+            pass
+
+    mp.spawn(
+        distributed_item_sampler_subprocess4,
+        args=(
+            nprocs,
+            item_set,
+            num_ids,
+            num_workers,
+            batch_size,
+            drop_last,
+            drop_uneven_inputs,
+        ),
+        nprocs=nprocs,
+        join=True,
+    )