diff --git a/python/dgl/graphbolt/item_sampler.py b/python/dgl/graphbolt/item_sampler.py index f92798e6b5cf..bfa2dbed88ce 100644 --- a/python/dgl/graphbolt/item_sampler.py +++ b/python/dgl/graphbolt/item_sampler.py @@ -1,19 +1,15 @@ """Item Sampler""" from collections.abc import Mapping -from functools import partial from typing import Callable, Iterator, Optional, Union import numpy as np import torch import torch.distributed as dist -from torch.utils.data import default_collate -from torchdata.datapipes.iter import IterableWrapper, IterDataPipe +from torchdata.datapipes.iter import IterDataPipe from ..base import dgl_warning -from ..batch import batch as dgl_batch -from ..heterograph import DGLGraph from .internal import calculate_range from .itemset import ItemSet, ItemSetDict from .minibatch import MiniBatch @@ -110,166 +106,11 @@ def minibatcher_default(batch, names): return minibatch -class ItemShufflerAndBatcher: - """A shuffler to shuffle items and create batches. - - This class is used internally by :class:`ItemSampler` to shuffle items and - create batches. It is not supposed to be used directly. The intention of - this class is to avoid time-consuming iteration over :class:`ItemSet`. As - an optimization, it slices from the :class:`ItemSet` via indexing first, - then shuffle and create batches. - - Parameters - ---------- - item_set : ItemSet - Data to be iterated. - shuffle : bool - Option to shuffle before batching. - batch_size : int - The size of each batch. - drop_last : bool - Option to drop the last batch if it's not full. - buffer_size : int - The size of the buffer to store items sliced from the :class:`ItemSet` - or :class:`ItemSetDict`. - distributed : bool - Option to apply on :class:`DistributedItemSampler`. - drop_uneven_inputs : bool - Option to make sure the numbers of batches for each replica are the - same. Applies only when `distributed` is True. - world_size : int - The number of model replicas that will be created during Distributed - Data Parallel (DDP) training. It should be the same as the real world - size, otherwise it could cause errors. Applies only when `distributed` - is True. - rank : int - The rank of the current replica. Applies only when `distributed` is - True. - rng : np.random.Generator - The random number generator to use for shuffling. - """ - - def __init__( - self, - item_set: ItemSet, - shuffle: bool, - batch_size: int, - drop_last: bool, - buffer_size: int, - distributed: Optional[bool] = False, - drop_uneven_inputs: Optional[bool] = False, - world_size: Optional[int] = 1, - rank: Optional[int] = 0, - rng: Optional[np.random.Generator] = None, - ): - self._item_set = item_set - self._shuffle = shuffle - self._batch_size = batch_size - self._drop_last = drop_last - self._buffer_size = buffer_size - # Round up the buffer size to the nearest multiple of batch size. - self._buffer_size = ( - (self._buffer_size + batch_size - 1) // batch_size * batch_size - ) - self._distributed = distributed - self._drop_uneven_inputs = drop_uneven_inputs - self._num_replicas = world_size - self._rank = rank - self._rng = rng - - def _collate_batch(self, buffer, indices, offsets=None): - """Collate a batch from the buffer. For internal use only.""" - if isinstance(buffer, torch.Tensor): - # For item set that's initialized with integer or single tensor, - # `buffer` is a tensor. 
- return torch.index_select(buffer, dim=0, index=indices) - elif isinstance(buffer, list) and isinstance(buffer[0], DGLGraph): - # For item set that's initialized with a list of - # DGLGraphs, `buffer` is a list of DGLGraphs. - return dgl_batch([buffer[idx] for idx in indices]) - elif isinstance(buffer, tuple): - # For item set that's initialized with a tuple of items, - # `buffer` is a tuple of tensors. - return tuple(item[indices] for item in buffer) - elif isinstance(buffer, Mapping): - # For item set that's initialized with a dict of items, - # `buffer` is a dict of tensors/lists/tuples. - keys = list(buffer.keys()) - key_indices = torch.searchsorted(offsets, indices, right=True) - 1 - batch = {} - for j, key in enumerate(keys): - mask = (key_indices == j).nonzero().squeeze(1) - if len(mask) == 0: - continue - batch[key] = self._collate_batch( - buffer[key], indices[mask] - offsets[j] - ) - return batch - raise TypeError(f"Unsupported buffer type {type(buffer).__name__}.") - - def _calculate_offsets(self, buffer): - """Calculate offsets for each item in buffer. For internal use only.""" - if not isinstance(buffer, Mapping): - return None - offsets = [0] - for value in buffer.values(): - if isinstance(value, torch.Tensor): - offsets.append(offsets[-1] + len(value)) - elif isinstance(value, tuple): - offsets.append(offsets[-1] + len(value[0])) - else: - raise TypeError( - f"Unsupported buffer type {type(value).__name__}." - ) - return torch.tensor(offsets) - - def __iter__(self): - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - num_workers = worker_info.num_workers - worker_id = worker_info.id - else: - num_workers = 1 - worker_id = 0 - buffer = None - total = len(self._item_set) - start_offset, assigned_count, output_count = calculate_range( - self._distributed, - total, - self._num_replicas, - self._rank, - num_workers, - worker_id, - self._batch_size, - self._drop_last, - self._drop_uneven_inputs, - ) - start = 0 - while start < assigned_count: - end = min(start + self._buffer_size, assigned_count) - buffer = self._item_set[start_offset + start : start_offset + end] - indices = torch.arange(end - start) - if self._shuffle: - self._rng.shuffle(indices.numpy()) - offsets = self._calculate_offsets(buffer) - for i in range(0, len(indices), self._batch_size): - if output_count <= 0: - break - batch_indices = indices[ - i : i + min(self._batch_size, output_count) - ] - output_count -= self._batch_size - yield self._collate_batch(buffer, batch_indices, offsets) - buffer = None - start = end - - class ItemSampler(IterDataPipe): - """A sampler to iterate over input items and create subsets. + """A sampler to iterate over input items and create minibatches. Input items could be node IDs, node pairs with or without labels, node - pairs with negative sources/destinations, DGLGraphs and heterogeneous - counterparts. + pairs with negative sources/destinations. Note: This class `ItemSampler` is not decorated with `torchdata.datapipes.functional_datapipe` on purpose. This indicates it @@ -288,24 +129,9 @@ class ItemSampler(IterDataPipe): Option to drop the last batch if it's not full. shuffle : bool Option to shuffle before sample. - use_indexing : bool - Option to use indexing to slice items from the item set. This is an - optimization to avoid time-consuming iteration over the item set. If - the item set does not support indexing, this option will be disabled - automatically. 
If the item set supports indexing but the user wants to - disable it, this option can be set to False. By default, it is set to - True. - buffer_size : int - The size of the buffer to store items sliced from the :class:`ItemSet` - or :class:`ItemSetDict`. By default, it is set to -1, which means the - buffer size will be set as the total number of items in the item set if - indexing is supported. If indexing is not supported, it is set to 10 * - batch size. If the item set is too large, it is recommended to set a - smaller buffer size to avoid out of memory error. As items are shuffled - within each buffer, a smaller buffer size may incur less randomness and - such less randomness can further affect the training performance such as - convergence speed and accuracy. Therefore, it is recommended to set a - larger buffer size if possible. + seed: int + The seed for reproducible stochastic shuffling. If None, a random seed + will be generated. Examples -------- @@ -369,21 +195,7 @@ class ItemSampler(IterDataPipe): indexes=tensor([0, 1, 0, 0]), edge_features=None, compacted_seeds=None, blocks=None,) - 5. DGLGraphs. - - >>> import dgl - >>> graphs = [ dgl.rand_graph(10, 20) for _ in range(5) ] - >>> item_set = gb.ItemSet(graphs) - >>> item_sampler = gb.ItemSampler(item_set, 3) - >>> list(item_sampler) - [Graph(num_nodes=30, num_edges=60, - ndata_schemes={} - edata_schemes={}), - Graph(num_nodes=20, num_edges=40, - ndata_schemes={} - edata_schemes={})] - - 6. Further process batches with other datapipes such as + 5. Further process batches with other datapipes such as :class:`torchdata.datapipes.iter.Mapper`. >>> item_set = gb.ItemSet(torch.arange(0, 10)) @@ -394,7 +206,7 @@ class ItemSampler(IterDataPipe): >>> list(data_pipe) [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])] - 7. Heterogeneous node IDs. + 6. Heterogeneous node IDs. >>> ids = { ... "user": gb.ItemSet(torch.arange(0, 5), names="seeds"), @@ -407,7 +219,7 @@ class ItemSampler(IterDataPipe): node_features=None, labels=None, input_nodes=None, indexes=None, edge_features=None, compacted_seeds=None, blocks=None,) - 8. Heterogeneous node pairs. + 7. Heterogeneous node pairs. >>> seeds_like = torch.arange(0, 10).reshape(-1, 2) >>> seeds_follow = torch.arange(10, 20).reshape(-1, 2) @@ -424,7 +236,7 @@ class ItemSampler(IterDataPipe): node_features=None, labels=None, input_nodes=None, indexes=None, edge_features=None, compacted_seeds=None, blocks=None,) - 9. Heterogeneous node pairs and labels. + 8. Heterogeneous node pairs and labels. >>> seeds_like = torch.arange(0, 10).reshape(-1, 2) >>> labels_like = torch.arange(0, 5) @@ -444,7 +256,7 @@ class ItemSampler(IterDataPipe): input_nodes=None, indexes=None, edge_features=None, compacted_seeds=None, blocks=None,) - 10. Heterogeneous node pairs, labels and indexes. + 9. Heterogeneous node pairs, labels and indexes. >>> seeds_like = torch.arange(0, 10).reshape(-1, 2) >>> labels_like = torch.tensor([1, 1, 0, 0, 0]) @@ -474,34 +286,11 @@ def __init__( minibatcher: Optional[Callable] = minibatcher_default, drop_last: Optional[bool] = False, shuffle: Optional[bool] = False, - # [TODO][Rui] For now, it's a temporary knob to disable indexing. In - # the future, we will enable indexing for all the item sets. - use_indexing: Optional[bool] = True, - buffer_size: Optional[int] = -1, + seed: Optional[int] = None, ) -> None: super().__init__() + self._item_set = item_set self._names = item_set.names - # Check if the item set supports indexing. 
- indexable = True - try: - item_set[0] - except TypeError: - indexable = False - self._use_indexing = use_indexing and indexable - self._item_set = ( - item_set if self._use_indexing else IterableWrapper(item_set) - ) - if buffer_size == -1: - if indexable: - # Set the buffer size to the total number of items in the item - # set if indexing is supported and the buffer size is not - # specified. - buffer_size = len(self._item_set) - else: - # Set the buffer size to 10 * batch size if indexing is not - # supported and the buffer size is not specified. - buffer_size = 10 * batch_size - self._buffer_size = buffer_size self._batch_size = batch_size self._minibatcher = minibatcher self._drop_last = drop_last @@ -510,68 +299,102 @@ def __init__( self._drop_uneven_inputs = False self._world_size = None self._rank = None - self._rng = np.random.default_rng() - - def _organize_items(self, data_pipe) -> None: - # Shuffle before batch. - if self._shuffle: - data_pipe = data_pipe.shuffle(buffer_size=self._buffer_size) + # For the sake of reproducibility, the seed should be allowed to be + # manually set by the user. + if seed is None: + self._seed = np.random.randint(0, np.iinfo(np.int32).max) + else: + self._seed = seed + # The attribute `self._epoch` is added to make shuffling work properly + # across multiple epochs. Otherwise, the same ordering will always be + # used in every epoch. + self._epoch = 0 - # Batch. - data_pipe = data_pipe.batch( - batch_size=self._batch_size, - drop_last=self._drop_last, - ) + def _collate_batch(self, buffer, indices, offsets=None): + """Collate a batch from the buffer. For internal use only.""" + if isinstance(buffer, torch.Tensor): + # For item set that's initialized with integer or single tensor, + # `buffer` is a tensor. + return torch.index_select(buffer, dim=0, index=indices) + elif isinstance(buffer, tuple): + # For item set that's initialized with a tuple of items, + # `buffer` is a tuple of tensors. + return tuple(item[indices] for item in buffer) + elif isinstance(buffer, Mapping): + # For item set that's initialized with a dict of items, + # `buffer` is a dict of tensors/lists/tuples. + keys = list(buffer.keys()) + key_indices = torch.searchsorted(offsets, indices, right=True) - 1 + batch = {} + for j, key in enumerate(keys): + mask = (key_indices == j).nonzero().squeeze(1) + if len(mask) == 0: + continue + batch[key] = self._collate_batch( + buffer[key], indices[mask] - offsets[j] + ) + return batch + raise TypeError(f"Unsupported buffer type {type(buffer).__name__}.") - return data_pipe - - @staticmethod - def _collate(batch): - """Collate items into a batch. For internal use only.""" - data = next(iter(batch)) - if isinstance(data, DGLGraph): - return dgl_batch(batch) - elif isinstance(data, Mapping): - assert len(data) == 1, "Only one type of data is allowed." - # Collect all the keys. - keys = {key for item in batch for key in item.keys()} - # Collate each key. - return { - key: default_collate( - [item[key] for item in batch if key in item] + def _calculate_offsets(self, buffer): + """Calculate offsets for each item in buffer. For internal use only.""" + if not isinstance(buffer, Mapping): + return None + offsets = [0] + for value in buffer.values(): + if isinstance(value, torch.Tensor): + offsets.append(offsets[-1] + len(value)) + elif isinstance(value, tuple): + offsets.append(offsets[-1] + len(value[0])) + else: + raise TypeError( + f"Unsupported buffer type {type(value).__name__}." 
) - for key in keys - } - return default_collate(batch) + return torch.tensor(offsets) def __iter__(self) -> Iterator: - if self._use_indexing: - seed = self._rng.integers(0, np.iinfo(np.int32).max) - data_pipe = IterableWrapper( - ItemShufflerAndBatcher( - self._item_set, - self._shuffle, - self._batch_size, - self._drop_last, - self._buffer_size, - distributed=self._distributed, - drop_uneven_inputs=self._drop_uneven_inputs, - world_size=self._world_size, - rank=self._rank, - rng=np.random.default_rng(seed), - ) - ) + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + num_workers = worker_info.num_workers + worker_id = worker_info.id else: - # Organize items. - data_pipe = self._organize_items(self._item_set) - - # Collate. - data_pipe = data_pipe.collate(collate_fn=self._collate) - - # Map to minibatch. - data_pipe = data_pipe.map(partial(self._minibatcher, names=self._names)) + num_workers = 1 + worker_id = 0 + total = len(self._item_set) + start_offset, assigned_count, output_count = calculate_range( + self._distributed, + total, + self._world_size, + self._rank, + num_workers, + worker_id, + self._batch_size, + self._drop_last, + self._drop_uneven_inputs, + ) + if self._shuffle: + g = torch.Generator() + g.manual_seed(self._seed + self._epoch) + _permutation = torch.randperm(total, generator=g) + buffer = self._item_set[ + _permutation[start_offset : start_offset + assigned_count] + ] + else: + buffer = self._item_set[ + start_offset : start_offset + assigned_count + ] + offsets = self._calculate_offsets(buffer) + for i in range(0, assigned_count, self._batch_size): + if output_count <= 0: + break + indices = torch.arange(i, i + min(self._batch_size, output_count)) + output_count -= self._batch_size + yield self._minibatcher( + self._collate_batch(buffer, indices, offsets), + self._names, + ) - return iter(data_pipe) + self._epoch += 1 class DistributedItemSampler(ItemSampler): @@ -623,16 +446,9 @@ class DistributedItemSampler(ItemSampler): https://pytorch.org/tutorials/advanced/generic_join.html. However, this option can be used if the Join Context Manager is not helpful for any reason. - buffer_size : int - The size of the buffer to store items sliced from the :class:`ItemSet` - or :class:`ItemSetDict`. By default, it is set to -1, which means the - buffer size will be set as the total number of items in the item set. - If the item set is too large, it is recommended to set a smaller buffer - size to avoid out of memory error. As items are shuffled within each - buffer, a smaller buffer size may incur less randomness and such less - randomness can further affect the training performance such as - convergence speed and accuracy. Therefore, it is recommended to set a - larger buffer size if possible. + seed: int + The seed for reproducible stochastic shuffling. If None, a random seed + will be generated. 
Examples -------- @@ -737,7 +553,7 @@ def __init__( drop_last: Optional[bool] = False, shuffle: Optional[bool] = False, drop_uneven_inputs: Optional[bool] = False, - buffer_size: Optional[int] = -1, + seed: Optional[int] = None, ) -> None: super().__init__( item_set, @@ -745,8 +561,7 @@ def __init__( minibatcher, drop_last, shuffle, - use_indexing=True, - buffer_size=buffer_size, + seed, ) self._distributed = True self._drop_uneven_inputs = drop_uneven_inputs @@ -756,6 +571,40 @@ def __init__( ) self._world_size = dist.get_world_size() self._rank = dist.get_rank() + if self._world_size > 1: + # For the sake of reproducibility, the seed should be allowed to be + # manually set by the user. + self._align_seeds(src=0, seed=seed) + + def _align_seeds( + self, src: Optional[int] = 0, seed: Optional[int] = None + ) -> None: + """Aligns seeds across distributed processes. + + This method synchronizes seeds across distributed processes, ensuring + consistent randomness. + + Parameters + ---------- + src: int, optional + The source process rank. Defaults to 0. + seed: int, optional + The seed value to synchronize. If None, a random seed will be + generated. Defaults to None. + """ + device = ( + torch.cuda.current_device() + if torch.cuda.is_available() and dist.get_backend() == "nccl" + else "cpu" + ) + if seed is None: + seed = np.random.randint(0, np.iinfo(np.int32).max) + if self._rank == src: + seed_tensor = torch.tensor(seed, dtype=torch.int32, device=device) + else: + seed_tensor = torch.empty([], dtype=torch.int32, device=device) + dist.broadcast(seed_tensor, src=src) + self._seed = seed_tensor.item() def _construct_seeds(pos_seeds, neg_srcs=None, neg_dsts=None): diff --git a/python/dgl/graphbolt/itemset.py b/python/dgl/graphbolt/itemset.py index 3f8472f856a3..848a61561f45 100644 --- a/python/dgl/graphbolt/itemset.py +++ b/python/dgl/graphbolt/itemset.py @@ -362,33 +362,20 @@ def __getitem__(self, index: Union[int, slice, Iterable[int]]): break return data elif isinstance(index, Iterable): - # TODO[Mingbang]: Might have performance issue. Tests needed. - data = {key: [] for key in self._keys} - for idx in index: - if idx < 0: - idx += self._length - if idx < 0 or idx >= self._length: - raise IndexError( - f"{type(self).__name__} index out of range." - ) - offset_idx = torch.searchsorted(self._offsets, idx, right=True) - offset_idx -= 1 - idx -= self._offsets[offset_idx] - key = self._keys[offset_idx] - data[key].append(int(idx)) - for key in self._keys: - indices = data[key] - if len(indices) == 0: - del data[key] + if not isinstance(index, torch.Tensor): + index = torch.tensor(index) + assert torch.all((index >= 0) & (index < self._length)) + key_indices = ( + torch.searchsorted(self._offsets, index, right=True) - 1 + ) + data = {} + for key_id, key in enumerate(self._keys): + mask = (key_indices == key_id).nonzero().squeeze(1) + if len(mask) == 0: continue - item_set = self._itemsets[key] - try: - value = item_set[indices] - except TypeError: - # In case the itemset doesn't support list indexing. 
- value = tuple(item_set[idx] for idx in indices) - finally: - data[key] = value + data[key] = self._itemsets[key][ + index[mask] - self._offsets[key_id] + ] return data else: raise TypeError( diff --git a/tests/python/pytorch/graphbolt/test_item_sampler.py b/tests/python/pytorch/graphbolt/test_item_sampler.py index 744b81462e4a..8a5d78ee9583 100644 --- a/tests/python/pytorch/graphbolt/test_item_sampler.py +++ b/tests/python/pytorch/graphbolt/test_item_sampler.py @@ -162,54 +162,6 @@ def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last): ) -@pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("shuffle", [True, False]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_ItemSet_graphs(batch_size, shuffle, drop_last): - # Graphs. - num_graphs = 103 - num_nodes = 10 - num_edges = 20 - graphs = [ - dgl.rand_graph(num_nodes * (i + 1), num_edges * (i + 1)) - for i in range(num_graphs) - ] - item_set = gb.ItemSet(graphs, names="graphs") - # DGLGraph is not supported in gb.MiniBatch yet. Let's use a customized - # minibatcher to return the original graphs. - customized_minibatcher = lambda batch, names: batch - item_sampler = gb.ItemSampler( - item_set, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - minibatcher=customized_minibatcher, - ) - minibatch_num_nodes = [] - minibatch_num_edges = [] - for i, minibatch in enumerate(item_sampler): - is_last = (i + 1) * batch_size >= num_graphs - if not is_last or num_graphs % batch_size == 0: - assert minibatch.batch_size == batch_size - else: - if not drop_last: - assert minibatch.batch_size == num_graphs % batch_size - else: - assert False - minibatch_num_nodes.append(minibatch.batch_num_nodes()) - minibatch_num_edges.append(minibatch.batch_num_edges()) - minibatch_num_nodes = torch.cat(minibatch_num_nodes) - minibatch_num_edges = torch.cat(minibatch_num_edges) - assert ( - torch.all(minibatch_num_nodes[:-1] <= minibatch_num_nodes[1:]) - is not shuffle - ) - assert ( - torch.all(minibatch_num_edges[:-1] <= minibatch_num_edges[1:]) - is not shuffle - ) - - @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
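The sketches below (not part of the patch) walk through the mechanisms this change introduces. First, the epoch-aware shuffling in the new ItemSampler.__iter__: seeding a torch.Generator with seed + epoch yields a permutation that is reproducible for a given epoch but different across epochs. A minimal standalone approximation, assuming a plain 1-D tensor item set; toy_shuffled_batches is a hypothetical helper, not the patched method:

import torch


def toy_shuffled_batches(items, batch_size, seed, epoch):
    # Same (seed, epoch) pair -> same permutation; bumping epoch changes it.
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    permutation = torch.randperm(len(items), generator=g)
    shuffled = items[permutation]
    for i in range(0, len(shuffled), batch_size):
        yield shuffled[i : i + batch_size]


items = torch.arange(10)
epoch0_a = list(toy_shuffled_batches(items, 4, seed=123, epoch=0))
epoch0_b = list(toy_shuffled_batches(items, 4, seed=123, epoch=0))
epoch1 = list(toy_shuffled_batches(items, 4, seed=123, epoch=1))
# Reproducible within an epoch; epoch 1 almost surely uses a different order.
assert all(torch.equal(a, b) for a, b in zip(epoch0_a, epoch0_b))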
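Second, the dict-buffer collation path kept in _collate_batch/_calculate_offsets: global indices into the logically concatenated buffer are mapped to their owning key with torch.searchsorted over cumulative offsets, then rebased to per-key local indices. A self-contained sketch of that bookkeeping, assuming every value is a 1-D tensor:

import torch

# Two item sets concatenated logically: "user" occupies global positions 0..4,
# "item" occupies 5..12.
buffer = {
    "user": torch.arange(0, 5),
    "item": torch.arange(100, 108),
}
lengths = torch.tensor([len(v) for v in buffer.values()])
offsets = torch.cat(
    [torch.zeros(1, dtype=torch.long), torch.cumsum(lengths, dim=0)]
)
# Global indices requested for one batch.
indices = torch.tensor([1, 4, 6, 12])
# Map each global index to the key that owns it, then rebase per key.
key_indices = torch.searchsorted(offsets, indices, right=True) - 1
batch = {}
for j, (key, value) in enumerate(buffer.items()):
    mask = (key_indices == j).nonzero().squeeze(1)
    if len(mask) == 0:
        continue
    batch[key] = value[indices[mask] - offsets[j]]
# batch == {"user": tensor([1, 4]), "item": tensor([101, 107])}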
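Third, the seed alignment performed by DistributedItemSampler._align_seeds: rank 0's seed is broadcast so every replica draws the same shuffling permutation. The sketch below runs a single-process gloo group purely for illustration (the TCP address and port are arbitrary); in real DDP training each rank executes the same code and receives rank 0's value:

import numpy as np
import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
src = 0
if dist.get_rank() == src:
    seed = np.random.randint(0, np.iinfo(np.int32).max)
    seed_tensor = torch.tensor(seed, dtype=torch.int32)
else:
    seed_tensor = torch.empty([], dtype=torch.int32)
dist.broadcast(seed_tensor, src=src)  # all ranks now hold rank 0's seed
shared_seed = seed_tensor.item()
dist.destroy_process_group()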
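Finally, a usage sketch of the new seed argument, assuming the patched graphbolt API behaves as documented above: two samplers constructed with the same seed produce identical shuffled batch sequences in their first epoch.

import torch
from dgl import graphbolt as gb

item_set = gb.ItemSet(torch.arange(0, 10), names="seeds")
sampler_a = gb.ItemSampler(item_set, batch_size=4, shuffle=True, seed=42)
sampler_b = gb.ItemSampler(item_set, batch_size=4, shuffle=True, seed=42)
for minibatch_a, minibatch_b in zip(sampler_a, sampler_b):
    # Same seed and same epoch -> identical shuffled batches on both samplers.
    assert torch.equal(minibatch_a.seeds, minibatch_b.seeds)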