[GraphBolt] Add experimental ItemSet/Dict4 and ItemSampler4 #7371

Closed
1 change: 1 addition & 0 deletions examples/multigpu/graphbolt/node_classification.py
@@ -36,6 +36,7 @@
└───> Test set evaluation
"""

Collaborator

does it work well with --num-workers 2 for multiple GPUs?

Collaborator Author

Both the old and the new implementation encounter the same error with --num-workers 2:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/ubuntu/miniconda3/envs/dgl/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/ubuntu/miniconda3/envs/dgl/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/home/ubuntu/miniconda3/envs/dgl/lib/python3.9/site-packages/torch/utils/data/datapipes/datapipe.py", line 359, in __setstate__
    self._datapipe = dill.loads(value)
  File "/home/ubuntu/miniconda3/envs/dgl/lib/python3.9/site-packages/dill/_dill.py", line 303, in loads
    return load(file, ignore, **kwds)
  File "/home/ubuntu/miniconda3/envs/dgl/lib/python3.9/site-packages/dill/_dill.py", line 289, in load
    return Unpickler(file, ignore=ignore, **kwds).load()
  File "/home/ubuntu/miniconda3/envs/dgl/lib/python3.9/site-packages/dill/_dill.py", line 444, in load
    obj = StockUnpickler.load(self)
AttributeError: 'PyCapsule' object has no attribute 'cudaHostUnregister'

Is this a long-standing problem? Or is there something wrong with my package version?

Collaborator

I'm afraid no one has run the multi-GPU example with multiple num_workers before. Please file an issue and look into it.

import argparse
import os
import time
171 changes: 169 additions & 2 deletions python/dgl/graphbolt/item_sampler.py
@@ -15,10 +15,16 @@
from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .internal import calculate_range
from .itemset import ItemSet, ItemSetDict
from .itemset import ItemSet, ItemSet4, ItemSetDict, ItemSetDict4
from .minibatch import MiniBatch

__all__ = ["ItemSampler", "DistributedItemSampler", "minibatcher_default"]
__all__ = [
    "ItemSampler",
    "DistributedItemSampler",
    "minibatcher_default",
    "ItemSampler4",
    "DistributedItemSampler4",
]


def minibatcher_default(batch, names):
@@ -833,3 +839,164 @@ def _construct_seeds(pos_seeds, neg_srcs=None, neg_dsts=None):
        neg_indexes = pos_indexes.repeat_interleave(negative_ratio)
        indexes[etype] = torch.cat((pos_indexes, neg_indexes))
    return seeds, labels, indexes


class ItemSampler4(IterDataPipe):
    """Experimental. An item sampler implemented based on the current
    ItemSampler."""

    def __init__(
        self,
        item_set: Union[ItemSet4, ItemSetDict4],
        batch_size: int,
        minibatcher: Optional[Callable] = minibatcher_default,
        drop_last: Optional[bool] = False,
        shuffle: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self._names = item_set.names
        self._item_set = item_set
        self._batch_size = batch_size
        self._minibatcher = minibatcher
        self._drop_last = drop_last
        self._shuffle = shuffle
        self._distributed = False
        self._drop_uneven_inputs = False
        self._world_size = None
        self._rank = None
        self._seed = np.random.randint(0, np.iinfo(np.int32).max)
        self._epoch = 0

    def _collate_batch(self, buffer, indices, offsets=None):
        """Collate a batch from the buffer. For internal use only."""
        if isinstance(buffer, torch.Tensor):
            # For item set that's initialized with integer or single tensor,
            # `buffer` is a tensor.
            return torch.index_select(buffer, dim=0, index=indices)
        elif isinstance(buffer, list) and isinstance(buffer[0], DGLGraph):
            # For item set that's initialized with a list of
            # DGLGraphs, `buffer` is a list of DGLGraphs.
            return dgl_batch([buffer[idx] for idx in indices])
        elif isinstance(buffer, tuple):
            # For item set that's initialized with a tuple of items,
            # `buffer` is a tuple of tensors.
            return tuple(item[indices] for item in buffer)
        elif isinstance(buffer, Mapping):
            # For item set that's initialized with a dict of items,
            # `buffer` is a dict of tensors/lists/tuples.
            keys = list(buffer.keys())
            key_indices = torch.searchsorted(offsets, indices, right=True) - 1
            batch = {}
            for j, key in enumerate(keys):
                mask = (key_indices == j).nonzero().squeeze(1)
                if len(mask) == 0:
                    continue
                batch[key] = self._collate_batch(
                    buffer[key], indices[mask] - offsets[j]
                )
            return batch
        raise TypeError(f"Unsupported buffer type {type(buffer).__name__}.")

    def _calculate_offsets(self, buffer):
        """Calculate offsets for each item in buffer. For internal use only."""
        if not isinstance(buffer, Mapping):
            return None
        offsets = [0]
        for value in buffer.values():
            if isinstance(value, torch.Tensor):
                offsets.append(offsets[-1] + len(value))
            elif isinstance(value, tuple):
                offsets.append(offsets[-1] + len(value[0]))
            else:
                raise TypeError(
                    f"Unsupported buffer type {type(value).__name__}."
                )
        return torch.tensor(offsets)

    def __iter__(self) -> Iterator:
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
        else:
            num_workers = 1
            worker_id = 0
        total = len(self._item_set)
        start_offset, assigned_count, output_count = calculate_range(
            self._distributed,
            total,
            self._world_size,
            self._rank,
            num_workers,
            worker_id,
            self._batch_size,
            self._drop_last,
            self._drop_uneven_inputs,
        )
        if self._shuffle:
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            indices = torch.randperm(total, generator=g)
            buffer = self._item_set[
                indices[start_offset : start_offset + assigned_count]
            ]
        else:
            buffer = self._item_set[
                start_offset : start_offset + assigned_count
            ]
        offsets = self._calculate_offsets(buffer)
        for i in range(0, assigned_count, self._batch_size):
            if output_count <= 0:
                break
            batch_indices = torch.arange(
                i, i + min(self._batch_size, output_count)
            )
            output_count -= self._batch_size
            yield self._minibatcher(
                self._collate_batch(buffer, batch_indices, offsets), self._names
            )

        self._epoch += 1

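For illustration only (not part of this diff): a minimal usage sketch of ItemSampler4 over an ItemSet4 of seed/label pairs. The field names, sizes, and the dgl.graphbolt re-export of the new classes are assumptions here.

import torch
from dgl.graphbolt import ItemSampler4, ItemSet4  # assumed to be re-exported like the existing classes

# 10 seed nodes paired with binary labels; the field names are arbitrary.
item_set = ItemSet4(
    (torch.arange(0, 10), torch.randint(0, 2, (10,))),
    names=("seeds", "labels"),
)
# Shuffled minibatches of size 4; minibatcher_default packs the named fields
# into a MiniBatch.
sampler = ItemSampler4(item_set, batch_size=4, shuffle=True)
for minibatch in sampler:
    print(minibatch)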

class DistributedItemSampler4(ItemSampler4):
    """Experimental. A distributed item sampler implemented based on the
    current DistributedItemSampler."""

    def __init__(
        self,
        item_set: Union[ItemSet4, ItemSetDict4],
        batch_size: int,
        minibatcher: Optional[Callable] = minibatcher_default,
        drop_last: Optional[bool] = False,
        shuffle: Optional[bool] = False,
        drop_uneven_inputs: Optional[bool] = False,
    ) -> None:
        super().__init__(
            item_set,
            batch_size,
            minibatcher,
            drop_last,
            shuffle,
        )
        self._distributed = True
        self._drop_uneven_inputs = drop_uneven_inputs
        if not dist.is_available():
            raise RuntimeError(
                "Distributed item sampler requires distributed package."
            )
        self._world_size = dist.get_world_size()
        self._rank = dist.get_rank()

        device = (
            torch.cuda.current_device()
            if torch.cuda.is_available() and dist.get_backend() == "nccl"
            else "cpu"
        )
        if self._rank == 0:
            seed = np.random.randint(0, np.iinfo(np.int32).max)
            seed_tensor = torch.tensor(seed, dtype=torch.int32, device=device)
        else:
            seed_tensor = torch.empty([], dtype=torch.int32, device=device)
        dist.barrier()
        dist.broadcast(seed_tensor, src=0)
        self._seed = seed_tensor.item()
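For illustration only (not part of this diff): a hedged sketch of per-rank usage of DistributedItemSampler4. It assumes the process group is already initialized on each rank and that the class is re-exported from dgl.graphbolt; because every rank receives the same broadcast seed, all ranks shuffle identically and calculate_range assigns each rank a disjoint slice.

import torch
import torch.distributed as dist
from dgl.graphbolt import DistributedItemSampler4, ItemSet4  # assumed re-exports

def per_rank_sampler(item_set: ItemSet4) -> DistributedItemSampler4:
    # Assumes dist.init_process_group() has already been called on this rank.
    assert dist.is_initialized()
    return DistributedItemSampler4(
        item_set,
        batch_size=64,
        shuffle=True,             # same broadcast seed on every rank
        drop_uneven_inputs=True,  # equal number of batches across ranks
    )

# Per rank: iterate only the slice assigned to this rank.
# for minibatch in per_rank_sampler(ItemSet4(torch.arange(0, 1024), names="seeds")):
#     ...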
181 changes: 179 additions & 2 deletions python/dgl/graphbolt/itemset.py
@@ -1,11 +1,12 @@
"""GraphBolt Itemset."""

import textwrap
from typing import Dict, Iterable, Iterator, Tuple, Union
from typing import Dict, Iterable, Iterator, List, Mapping, Tuple, Union

import torch
from torch.utils.data import Dataset

__all__ = ["ItemSet", "ItemSetDict"]
__all__ = ["ItemSet", "ItemSetDict", "ItemSet4", "ItemSetDict4"]


def is_scalar(x):
@@ -442,3 +443,179 @@ def __repr__(self) -> str:
            itemsets=itemsets_str,
            names=self._names,
        )


class ItemSet4(Dataset):
    r"""Class for iterating over tensor-like data.

    Experimental. Only ``__getitem__()`` is implemented, and it accepts int,
    slice and list indices.
    """

    def __init__(
        self,
        items: Union[torch.Tensor, Mapping, Tuple[Mapping]],
        names: Union[str, Tuple[str]] = None,
    ):
        if is_scalar(items):
            self._length = int(items)
            self._items = items
        elif isinstance(items, tuple):
            self._length = len(items[0])
            if any(self._length != len(item) for item in items):
                raise ValueError("Size mismatch between items.")
            self._items = items
        else:
            self._length = len(items)
            self._items = (items,)
        if names is not None:
            num_items = (
                len(self._items) if isinstance(self._items, tuple) else 1
            )
            if isinstance(names, tuple):
                self._names = names
            else:
                self._names = (names,)
            assert num_items == len(self._names), (
                f"Number of items ({num_items}) and "
                f"names ({len(self._names)}) must match."
            )
        else:
            self._names = None

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: Union[int, slice, List[int]]):
        if is_scalar(self._items):
            if isinstance(index, slice):
                start, stop, step = index.indices(int(self._items))
                dtype = getattr(self._items, "dtype", torch.int64)
                return torch.arange(start, stop, step, dtype=dtype)
            elif isinstance(index, int):
                if index < 0:
                    index += int(self._items)
                if index < 0 or index >= int(self._items):
                    raise IndexError(
                        f"{type(self).__name__} index out of range."
                    )
                # Fall back to int64 when the scalar has no dtype attribute
                # (e.g. a Python int), as in the slice and list branches.
                dtype = getattr(self._items, "dtype", torch.int64)
                return torch.tensor(index, dtype=dtype)
            elif isinstance(index, list):
                dtype = getattr(self._items, "dtype", torch.int64)
                return torch.tensor(index, dtype=dtype)
            else:
                raise TypeError(
                    f"{type(self).__name__} indices must be int, slice or list of int."
                )
        elif len(self._items) == 1:
            return self._items[0][index]
        else:
            return tuple(item[index] for item in self._items)

    @property
    def names(self) -> Tuple[str]:
        """Return the names of the items."""
        return self._names

    def __repr__(self) -> str:
        ret = (
            f"{self.__class__.__name__}(\n"
            f"  items={self._items},\n"
            f"  names={self._names},\n"
            f")"
        )
        return ret

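For illustration only (not part of this diff): a short sketch of the indexing behavior ItemSet4 supports; the tensors and names below are made up and the dgl.graphbolt re-export is assumed.

import torch
from dgl.graphbolt import ItemSet4  # assumed re-export

# Tuple of aligned tensors: seed nodes and labels.
item_set = ItemSet4(
    (torch.arange(0, 5), torch.tensor([1, 1, 0, 0, 1])),
    names=("seeds", "labels"),
)
print(len(item_set))     # 5
print(item_set[1:3])     # (tensor([1, 2]), tensor([1, 0]))
print(item_set[[0, 4]])  # (tensor([0, 4]), tensor([1, 1]))

# A scalar-backed item set materializes only the requested IDs.
num_nodes = ItemSet4(torch.tensor(100), names="seeds")
print(num_nodes[10:13])  # tensor([10, 11, 12])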

class ItemSetDict4(Dataset):
    r"""Experimental. A dictionary of ItemSet4 instances keyed by type,
    implemented based on the current ItemSetDict."""

    def __init__(self, itemsets: Dict[str, ItemSet4]) -> None:
        super().__init__()
        self._itemsets = itemsets
        self._names = next(iter(itemsets.values())).names
        self._length = sum(len(itemset) for itemset in itemsets.values())
        if any(self._names != itemset.names for itemset in itemsets.values()):
            raise ValueError("All itemsets must have the same names.")
        offset = [0] + [len(itemset) for itemset in self._itemsets.values()]
        self._offsets = torch.tensor(offset).cumsum(0)

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: Union[int, slice, List[int]]):
        total_num = self._offsets[-1]
        if isinstance(index, int):
            if index < 0:
                index += total_num
            if index < 0 or index >= total_num:
                raise IndexError(f"{type(self).__name__} index out of range.")
            offset_idx = torch.searchsorted(self._offsets, index, right=True)
            offset_idx -= 1
            index -= self._offsets[offset_idx]
            key = list(self._itemsets.keys())[offset_idx]
            return {key: self._itemsets[key][index]}
        elif isinstance(index, slice):
            start, stop, step = index.indices(total_num)
            assert step == 1, "Step must be 1."
            assert start < stop, "Start must be smaller than stop."
            data = {}
            offset_idx_start = max(
                1, torch.searchsorted(self._offsets, start, right=False)
            )
            keys = list(self._itemsets.keys())
            for offset_idx in range(offset_idx_start, len(self._offsets)):
                key = keys[offset_idx - 1]
                data[key] = self._itemsets[key][
                    max(0, start - self._offsets[offset_idx - 1]) : stop
                    - self._offsets[offset_idx - 1]
                ]
                if stop <= self._offsets[offset_idx]:
                    break
            return data
        elif isinstance(index, list):
            # Use a distinct list per key; `dict.fromkeys(..., [])` would make
            # every key share the same list object.
            data = {key: [] for key in self._itemsets.keys()}
            for idx in index:
                if idx < 0:
                    idx += total_num
                if idx < 0 or idx >= total_num:
                    raise IndexError(
                        f"{type(self).__name__} index out of range."
                    )
                offset_idx = torch.searchsorted(self._offsets, idx, right=True)
                offset_idx -= 1
                idx -= self._offsets[offset_idx]
                key = list(self._itemsets.keys())[offset_idx]
                data[key].append(idx)
            for key, value in data.items():
                item_set = self._itemsets[key]
                try:
                    value = item_set[value]
                except TypeError:
                    value = tuple(item_set[idx] for idx in value)
                finally:
                    data[key] = value
            return data
        else:
            raise TypeError(
                f"{type(self).__name__} indices must be int, slice or list of int."
            )

    @property
    def names(self) -> Tuple[str]:
        """Return the names of the items."""
        return self._names

    def __repr__(self) -> str:
        ret = (
            "{Classname}(\n"
            "  itemsets={itemsets},\n"
            "  names={names},\n"
            ")"
        )
        itemsets_str = textwrap.indent(
            repr(self._itemsets), " " * len(" itemsets=")
        ).strip()
        return ret.format(
            Classname=self.__class__.__name__,
            itemsets=itemsets_str,
            names=self._names,
        )
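For illustration only (not part of this diff): a short sketch of ItemSetDict4 indexing across two node types; the type names are made up and the dgl.graphbolt re-export is assumed. Global indices are translated into per-type local indices via the cumulative offsets, and each lookup returns a dict keyed by type.

import torch
from dgl.graphbolt import ItemSet4, ItemSetDict4  # assumed re-exports

item_set = ItemSetDict4(
    {
        "user": ItemSet4(torch.arange(0, 5), names="seeds"),
        "item": ItemSet4(torch.arange(5, 10), names="seeds"),
    }
)
print(len(item_set))     # 10
print(item_set[3])       # {'user': tensor(3)}
print(item_set[4:7])     # {'user': tensor([4]), 'item': tensor([5, 6])}
print(item_set[[1, 8]])  # {'user': tensor([1]), 'item': tensor([8])}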