diff --git a/docs/source/api/python/dgl.graphbolt.rst b/docs/source/api/python/dgl.graphbolt.rst index 156b3a5712d7..ba7c7e129d9c 100644 --- a/docs/source/api/python/dgl.graphbolt.rst +++ b/docs/source/api/python/dgl.graphbolt.rst @@ -187,6 +187,7 @@ Utilities etype_tuple_to_str isin seed + index_select expand_indptr add_reverse_edges exclude_seed_edges diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py index b9fa73353300..b144fdb1d5cc 100644 --- a/examples/multigpu/graphbolt/node_classification.py +++ b/examples/multigpu/graphbolt/node_classification.py @@ -284,6 +284,12 @@ def run(rank, world_size, args, devices, dataset): hidden_size = 256 out_size = num_classes + if args.gpu_cache_size > 0: + dataset.feature._features[("node", None, "feat")] = gb.GPUCachedFeature( + dataset.feature._features[("node", None, "feat")], + args.gpu_cache_size, + ) + # Create GraphSAGE model. It should be copied onto a GPU as a replica. model = SAGE(in_size, hidden_size, out_size).to(device) model = DDP(model) @@ -381,6 +387,12 @@ def parse_args(): parser.add_argument( "--num-workers", type=int, default=0, help="The number of processes." ) + parser.add_argument( + "--gpu-cache-size", + type=int, + default=0, + help="The capacity of the GPU cache, the number of features to store.", + ) parser.add_argument( "--mode", default="pinned-cuda", diff --git a/examples/sampling/graphbolt/link_prediction.py b/examples/sampling/graphbolt/link_prediction.py index 20e169b570a5..0794f79ceb8b 100644 --- a/examples/sampling/graphbolt/link_prediction.py +++ b/examples/sampling/graphbolt/link_prediction.py @@ -79,14 +79,10 @@ def forward(self, blocks, x): hidden_x = F.relu(hidden_x) return hidden_x - def inference(self, graph, features, dataloader, device): + def inference(self, graph, features, dataloader, storage_device): """Conduct layer-wise inference to get all the node embeddings.""" - feature = features.read("node", None, "feat") - - buffer_device = torch.device("cpu") - # Enable pin_memory for faster CPU to GPU data transfer if the - # model is running on a GPU. - pin_memory = buffer_device != device + pin_memory = storage_device == "pinned" + buffer_device = torch.device("cpu" if pin_memory else storage_device) print("Start node embedding inference.") for layer_idx, layer in enumerate(self.layers): @@ -99,17 +95,17 @@ def inference(self, graph, features, dataloader, device): device=buffer_device, pin_memory=pin_memory, ) - feature = feature.to(device) - for step, data in tqdm.tqdm(enumerate(dataloader)): - x = feature[data.input_nodes] - hidden_x = layer(data.blocks[0], x) # len(blocks) = 1 + for data in tqdm.tqdm(dataloader): + # len(blocks) = 1 + hidden_x = layer(data.blocks[0], data.node_features["feat"]) if not is_last_layer: hidden_x = F.relu(hidden_x) # By design, our seed nodes are contiguous. y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to( buffer_device, non_blocking=True ) - feature = y + if not is_last_layer: + features.update("node", None, "feat", y) return y @@ -185,7 +181,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True): # [Role]: # Initialize a neighbor sampler for sampling the neighborhoods of nodes. ############################################################################ - datapipe = datapipe.sample_neighbor(graph, args.fanout) + datapipe = datapipe.sample_neighbor( + graph, args.fanout if is_train else [-1] + ) ############################################################################ # [Input]: @@ -213,12 +211,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True): # A FeatureFetcher object to fetch node features. # [Role]: # Initialize a feature fetcher for fetching features of the sampled - # subgraphs. This step is skipped in evaluation/inference because features - # are updated as a whole during it, thus storing features in minibatch is - # unnecessary. + # subgraphs. ############################################################################ - if is_train: - datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) + datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) ############################################################################ # [Input]: @@ -286,15 +281,12 @@ def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set): model.eval() evaluator = Evaluator(name="ogbl-citation2") - # Since we need to use all neghborhoods for evaluation, we set the fanout - # to -1. - args.fanout = [-1] dataloader = create_dataloader( args, graph, features, all_nodes_set, is_train=False ) # Compute node embeddings for the entire graph. - node_emb = model.inference(graph, features, dataloader, args.device) + node_emb = model.inference(graph, features, dataloader, args.storage_device) results = [] # Loop over both validation and test sets. @@ -340,6 +332,8 @@ def train(args, model, graph, features, train_set): total_loss += loss.item() if step + 1 == args.early_stop: + # Early stopping requires a new dataloader to reset its state. + dataloader = create_dataloader(args, graph, features, train_set) break end_epoch_time = time.time() diff --git a/examples/sampling/graphbolt/node_classification.py b/examples/sampling/graphbolt/node_classification.py index c8eaf9a47f79..e5496e23a567 100644 --- a/examples/sampling/graphbolt/node_classification.py +++ b/examples/sampling/graphbolt/node_classification.py @@ -131,11 +131,9 @@ def create_dataloader( # A FeatureFetcher object to fetch node features. # [Role]: # Initialize a feature fetcher for fetching features of the sampled - # subgraphs. This step is skipped in inference because features are updated - # as a whole during it, thus storing features in minibatch is unnecessary. + # subgraphs. ############################################################################ - if job != "infer": - datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) + datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) ############################################################################ # [Step-5]: @@ -194,14 +192,10 @@ def forward(self, blocks, x): hidden_x = self.dropout(hidden_x) return hidden_x - def inference(self, graph, features, dataloader, device): + def inference(self, graph, features, dataloader, storage_device): """Conduct layer-wise inference to get all the node embeddings.""" - feature = features.read("node", None, "feat") - - buffer_device = torch.device("cpu") - # Enable pin_memory for faster CPU to GPU data transfer if the - # model is running on a GPU. - pin_memory = buffer_device != device + pin_memory = storage_device == "pinned" + buffer_device = torch.device("cpu" if pin_memory else storage_device) for layer_idx, layer in enumerate(self.layers): is_last_layer = layer_idx == len(self.layers) - 1 @@ -213,11 +207,9 @@ def inference(self, graph, features, dataloader, device): device=buffer_device, pin_memory=pin_memory, ) - feature = feature.to(device) - - for step, data in tqdm(enumerate(dataloader)): - x = feature[data.input_nodes] - hidden_x = layer(data.blocks[0], x) # len(blocks) = 1 + for data in tqdm(dataloader): + # len(blocks) = 1 + hidden_x = layer(data.blocks[0], data.node_features["feat"]) if not is_last_layer: hidden_x = F.relu(hidden_x) hidden_x = self.dropout(hidden_x) @@ -225,7 +217,8 @@ def inference(self, graph, features, dataloader, device): y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to( buffer_device ) - feature = y + if not is_last_layer: + features.update("node", None, "feat", y) return y @@ -245,7 +238,7 @@ def layerwise_infer( num_workers=args.num_workers, job="infer", ) - pred = model.inference(graph, features, dataloader, args.device) + pred = model.inference(graph, features, dataloader, args.storage_device) pred = pred[test_set._items[0]] label = test_set._items[1].to(pred.device) diff --git a/graphbolt/src/cuda/gpu_cache.cu b/graphbolt/src/cuda/gpu_cache.cu index 0a47bbbddc18..f72446ec2626 100644 --- a/graphbolt/src/cuda/gpu_cache.cu +++ b/graphbolt/src/cuda/gpu_cache.cu @@ -43,20 +43,19 @@ std::tuple GpuCache::Query( torch::empty(keys.size(0), keys.options().dtype(torch::kLong)); auto missing_keys = torch::empty(keys.size(0), keys.options().dtype(torch::kLong)); - cuda::CopyScalar missing_len; - auto stream = cuda::GetCurrentStream(); + auto allocator = cuda::GetAllocator(); + auto missing_len_device = allocator.AllocateStorage(1); cache_->Query( reinterpret_cast(keys.data_ptr()), keys.size(0), values.data_ptr(), reinterpret_cast(missing_index.data_ptr()), - reinterpret_cast(missing_keys.data_ptr()), missing_len.get(), - stream); + reinterpret_cast(missing_keys.data_ptr()), + missing_len_device.get(), cuda::GetCurrentStream()); values = values.view(torch::kByte) .slice(1, 0, num_bytes_) .view(dtype_) .view(shape_); - // To safely read missing_len, we synchronize - stream.synchronize(); + cuda::CopyScalar missing_len(missing_len_device.get()); missing_index = missing_index.slice(0, 0, static_cast(missing_len)); missing_keys = missing_keys.slice(0, 0, static_cast(missing_len)); return std::make_tuple(values, missing_index, missing_keys); @@ -79,6 +78,7 @@ void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) { "Values should have the correct dimensions."); TORCH_CHECK( values.scalar_type() == dtype_, "Values should have the correct dtype."); + if (keys.numel() == 0) return; keys = keys.to(torch::kLong); torch::Tensor float_values; if (num_bytes_ % sizeof(float) != 0) { diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index 66de586f683e..4dacb9792448 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -97,12 +97,21 @@ c10::intrusive_ptr FusedCSCSamplingGraph::Create( } if (node_attributes.has_value()) { for (const auto& pair : node_attributes.value()) { - TORCH_CHECK(pair.value().size(0) == indptr.size(0) - 1); + TORCH_CHECK( + pair.value().size(0) == indptr.size(0) - 1, + "Expected node_attribute.size(0) and num_nodes to be equal, " + "but node_attribute.size(0) was ", + pair.value().size(0), ", and num_nodes was ", indptr.size(0) - 1, + "."); } } if (edge_attributes.has_value()) { for (const auto& pair : edge_attributes.value()) { - TORCH_CHECK(pair.value().size(0) == indices.size(0)); + TORCH_CHECK( + pair.value().size(0) == indices.size(0), + "Expected edge_attribute.size(0) and num_edges to be equal, " + "but edge_attribute.size(0) was ", + pair.value().size(0), ", and num_edges was ", indices.size(0), "."); } } return c10::make_intrusive( @@ -810,12 +819,71 @@ torch::Tensor TemporalMask( return mask; } +/** + * @brief Fast path for temporal sampling without probability. It is used when + * the number of neighbors is large. It randomly samples neighbors and checks + * the timestamp of the neighbors. It is successful if the number of sampled + * neighbors in kTriedThreshold trials is equal to the fanout. + */ +std::pair> FastTemporalPick( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, int64_t fanout, + bool replace, const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, int64_t seed_offset, + int64_t offset, int64_t num_neighbors) { + constexpr int64_t kTriedThreshold = 1000; + auto timestamp = utils::GetValueByIndex(seed_timestamp, seed_offset); + std::vector sampled_edges; + sampled_edges.reserve(fanout); + std::set sampled_edge_set; + int64_t sample_count = 0; + int64_t tried = 0; + while (sample_count < fanout && tried < kTriedThreshold) { + int64_t edge_id = + RandomEngine::ThreadLocal()->RandInt(offset, offset + num_neighbors); + ++tried; + if (!replace && sampled_edge_set.count(edge_id) > 0) { + continue; + } + if (node_timestamp.has_value()) { + int64_t neighbor_id = + utils::GetValueByIndex(csc_indices, edge_id); + if (utils::GetValueByIndex( + node_timestamp.value(), neighbor_id) >= timestamp) + continue; + } + if (edge_timestamp.has_value() && + utils::GetValueByIndex(edge_timestamp.value(), edge_id) >= + timestamp) { + continue; + } + if (!replace) { + sampled_edge_set.insert(edge_id); + } + sampled_edges.push_back(edge_id); + sample_count++; + } + if (sample_count < fanout) { + return {false, {}}; + } + return {true, sampled_edges}; +} + int64_t TemporalNumPick( torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout, bool replace, const torch::optional& probs_or_mask, const torch::optional& node_timestamp, const torch::optional& edge_timestamp, int64_t seed_offset, int64_t offset, int64_t num_neighbors) { + constexpr int64_t kFastPathThreshold = 1000; + if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value()) { + // TODO: Currently we use the fast path both in TemporalNumPick and + // TemporalPick. We may only sample once in TemporalNumPick and use the + // sampled edges in TemporalPick to avoid sampling twice. + auto [success, sampled_edges] = FastTemporalPick( + seed_timestamp, csc_indics, fanout, replace, node_timestamp, + edge_timestamp, seed_offset, offset, num_neighbors); + if (success) return sampled_edges.size(); + } auto mask = TemporalMask( utils::GetValueByIndex(seed_timestamp, seed_offset), csc_indics, probs_or_mask, node_timestamp, edge_timestamp, @@ -1183,6 +1251,19 @@ int64_t TemporalPick( const torch::optional& node_timestamp, const torch::optional& edge_timestamp, SamplerArgs args, PickedType* picked_data_ptr) { + constexpr int64_t kFastPathThreshold = 1000; + if (S == SamplerType::NEIGHBOR && num_neighbors > kFastPathThreshold && + !probs_or_mask.has_value()) { + auto [success, sampled_edges] = FastTemporalPick( + seed_timestamp, csc_indices, fanout, replace, node_timestamp, + edge_timestamp, seed_offset, offset, num_neighbors); + if (success) { + for (size_t i = 0; i < sampled_edges.size(); ++i) { + picked_data_ptr[i] = static_cast(sampled_edges[i]); + } + return sampled_edges.size(); + } + } auto mask = TemporalMask( utils::GetValueByIndex(seed_timestamp, seed_offset), csc_indices, probs_or_mask, node_timestamp, edge_timestamp, diff --git a/python/dgl/convert.py b/python/dgl/convert.py index 49578a168831..1ab64ddbb116 100644 --- a/python/dgl/convert.py +++ b/python/dgl/convert.py @@ -1,4 +1,5 @@ """Module for converting graph from/to other object.""" + from collections import defaultdict from collections.abc import Mapping @@ -296,9 +297,9 @@ def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None): >>> g = dgl.heterograph(data_dict) >>> g Graph(num_nodes={'game': 5, 'topic': 3, 'user': 4}, - num_edges={('user', 'follows', 'user'): 2, ('user', 'follows', 'topic'): 2, + num_edges={('user', 'follows', 'topic'): 2, ('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 2}, - metagraph=[('user', 'user', 'follows'), ('user', 'topic', 'follows'), + metagraph=[('user', 'topic', 'follows'), ('user', 'user', 'follows'), ('user', 'game', 'plays')]) Explicitly specify the number of nodes for each node type in the graph. @@ -1810,11 +1811,11 @@ def to_networkx( ... ('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])), ... ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4])) ... }) - ... g.ndata['n'] = { + >>> g.ndata['n'] = { ... 'game': torch.zeros(5, 1), ... 'user': torch.ones(4, 1) ... } - ... g.edata['e'] = { + >>> g.edata['e'] = { ... ('user', 'follows', 'user'): torch.zeros(2, 1), ... 'plays': torch.ones(2, 1) ... } diff --git a/python/dgl/distributed/graph_services.py b/python/dgl/distributed/graph_services.py index 0a732ca0e7b0..58eeb6de1f89 100644 --- a/python/dgl/distributed/graph_services.py +++ b/python/dgl/distributed/graph_services.py @@ -3,8 +3,10 @@ import numpy as np -from .. import backend as F -from ..base import EID, NID +import torch + +from .. import backend as F, graphbolt as gb +from ..base import EID, ETYPE, NID from ..convert import graph, heterograph from ..sampling import ( sample_etype_neighbors as local_sample_etype_neighbors, @@ -38,16 +40,29 @@ class SubgraphResponse(Response): """The response for sampling and in_subgraph""" - def __init__(self, global_src, global_dst, global_eids): + def __init__( + self, global_src, global_dst, *, global_eids=None, etype_ids=None + ): self.global_src = global_src self.global_dst = global_dst self.global_eids = global_eids + self.etype_ids = etype_ids def __setstate__(self, state): - self.global_src, self.global_dst, self.global_eids = state + ( + self.global_src, + self.global_dst, + self.global_eids, + self.etype_ids, + ) = state def __getstate__(self): - return self.global_src, self.global_dst, self.global_eids + return ( + self.global_src, + self.global_dst, + self.global_eids, + self.etype_ids, + ) class FindEdgeResponse(Response): @@ -65,8 +80,99 @@ def __getstate__(self): return self.global_src, self.global_dst, self.order_id -def _sample_neighbors( - local_g, partition_book, seed_nodes, fan_out, edge_dir, prob, replace +def _sample_neighbors_graphbolt( + g, gpb, nodes, fanout, edge_dir="in", prob=None, replace=False +): + """Sample from local partition via graphbolt. + + The input nodes use global IDs. We need to map the global node IDs to local + node IDs, perform sampling and map the sampled results to the global IDs + space again. The sampled results are stored in three vectors that store + source nodes, destination nodes, etype IDs and edge IDs. + + Parameters + ---------- + g : FusedCSCSamplingGraph + The local partition. + gpb : GraphPartitionBook + The graph partition book. + nodes : tensor + The nodes to sample neighbors from. + fanout : tensor or int + The number of edges to be sampled for each node. + edge_dir : str, optional + Determines whether to sample inbound or outbound edges. + prob : tensor, optional + The probability associated with each neighboring edge of a node. + replace : bool, optional + If True, sample with replacement. + + Returns + ------- + tensor + The source node ID array. + tensor + The destination node ID array. + tensor + The edge ID array. + tensor + The edge type ID array. + """ + assert ( + edge_dir == "in" + ), f"GraphBolt only supports inbound edge sampling but got {edge_dir}." + + # 1. Map global node IDs to local node IDs. + nodes = gpb.nid2localnid(nodes, gpb.partid) + + # 2. Perform sampling. + # [Rui][TODO] `prob` and `replace` are not tested yet. Skip for now. + assert ( + prob is None + ), "DistGraphBolt does not support sampling with probability." + assert ( + not replace + ), "DistGraphBolt does not support sampling with replacement." + + # Sanity checks. + assert isinstance( + g, gb.FusedCSCSamplingGraph + ), "Expect a FusedCSCSamplingGraph." + assert isinstance(nodes, torch.Tensor), "Expect a tensor of nodes." + if isinstance(fanout, int): + fanout = torch.LongTensor([fanout]) + assert isinstance(fanout, torch.Tensor), "Expect a tensor of fanout." + # [Rui][TODO] Support multiple fanouts. + assert fanout.numel() == 1, "Expect a single fanout." + + return_eids = g.edge_attributes is not None and EID in g.edge_attributes + subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids) + + # 3. Map local node IDs to global node IDs. + local_src = subgraph.indices + local_dst = torch.repeat_interleave( + subgraph.original_column_node_ids, torch.diff(subgraph.indptr) + ) + global_nid_mapping = g.node_attributes[NID] + global_src = global_nid_mapping[local_src] + global_dst = global_nid_mapping[local_dst] + + global_eids = None + if return_eids: + global_eids = g.edge_attributes[EID][subgraph.original_edge_ids] + return LocalSampledGraph( + global_src, global_dst, global_eids, subgraph.type_per_edge + ) + + +def _sample_neighbors_dgl( + local_g, + partition_book, + seed_nodes, + fan_out, + edge_dir="in", + prob=None, + replace=False, ): """Sample from local partition. @@ -93,7 +199,38 @@ def _sample_neighbors( global_nid_mapping, src ), F.gather_row(global_nid_mapping, dst) global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID]) - return global_src, global_dst, global_eids + return LocalSampledGraph(global_src, global_dst, global_eids) + + +def _sample_neighbors(use_graphbolt, *args, **kwargs): + """Wrapper for sampling neighbors. + + The actual sampling function depends on whether to use GraphBolt. + + Parameters + ---------- + use_graphbolt : bool + Whether to use GraphBolt for sampling. + args : list + The arguments for the sampling function. + kwargs : dict + The keyword arguments for the sampling function. + + Returns + ------- + tensor + The source node ID array. + tensor + The destination node ID array. + tensor + The edge ID array. + tensor + The edge type ID array. + """ + func = ( + _sample_neighbors_graphbolt if use_graphbolt else _sample_neighbors_dgl + ) + return func(*args, **kwargs) def _sample_etype_neighbors( @@ -134,7 +271,7 @@ def _sample_etype_neighbors( global_nid_mapping, src ), F.gather_row(global_nid_mapping, dst) global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID]) - return global_src, global_dst, global_eids + return LocalSampledGraph(global_src, global_dst, global_eids) def _find_edges(local_g, partition_book, seed_edges): @@ -180,7 +317,7 @@ def _in_subgraph(local_g, partition_book, seed_nodes): src, dst = sampled_graph.edges() global_src, global_dst = global_nid_mapping[src], global_nid_mapping[dst] global_eids = F.gather_row(local_g.edata[EID], sampled_graph.edata[EID]) - return global_src, global_dst, global_eids + return LocalSampledGraph(global_src, global_dst, global_eids) # --- NOTE 1 --- @@ -212,12 +349,21 @@ def _in_subgraph(local_g, partition_book, seed_nodes): class SamplingRequest(Request): """Sampling Request""" - def __init__(self, nodes, fan_out, edge_dir="in", prob=None, replace=False): + def __init__( + self, + nodes, + fan_out, + edge_dir="in", + prob=None, + replace=False, + use_graphbolt=False, + ): self.seed_nodes = nodes self.edge_dir = edge_dir self.prob = prob self.replace = replace self.fan_out = fan_out + self.use_graphbolt = use_graphbolt def __setstate__(self, state): ( @@ -226,6 +372,7 @@ def __setstate__(self, state): self.prob, self.replace, self.fan_out, + self.use_graphbolt, ) = state def __getstate__(self): @@ -235,6 +382,7 @@ def __getstate__(self): self.prob, self.replace, self.fan_out, + self.use_graphbolt, ) def process_request(self, server_state): @@ -245,16 +393,22 @@ def process_request(self, server_state): prob = [kv_store.data_store[self.prob]] else: prob = None - global_src, global_dst, global_eids = _sample_neighbors( + res = _sample_neighbors( + self.use_graphbolt, local_g, partition_book, self.seed_nodes, self.fan_out, - self.edge_dir, - prob, - self.replace, + edge_dir=self.edge_dir, + prob=prob, + replace=self.replace, + ) + return SubgraphResponse( + res.global_src, + res.global_dst, + global_eids=res.global_eids, + etype_ids=res.etype_ids, ) - return SubgraphResponse(global_src, global_dst, global_eids) class SamplingRequestEtype(Request): @@ -309,7 +463,7 @@ def process_request(self, server_state): ] else: probs = None - global_src, global_dst, global_eids = _sample_etype_neighbors( + res = _sample_etype_neighbors( local_g, partition_book, self.seed_nodes, @@ -320,7 +474,12 @@ def process_request(self, server_state): self.replace, self.etype_sorted, ) - return SubgraphResponse(global_src, global_dst, global_eids) + return SubgraphResponse( + res.global_src, + res.global_dst, + global_eids=res.global_eids, + etype_ids=res.etype_ids, + ) class EdgesRequest(Request): @@ -434,7 +593,7 @@ def process_request(self, server_state): global_src, global_dst, global_eids = _in_subgraph( local_g, partition_book, self.seed_nodes ) - return SubgraphResponse(global_src, global_dst, global_eids) + return SubgraphResponse(global_src, global_dst, global_eids=global_eids) def merge_graphs(res_list, num_nodes): @@ -443,24 +602,33 @@ def merge_graphs(res_list, num_nodes): srcs = [] dsts = [] eids = [] + etype_ids = [] for res in res_list: srcs.append(res.global_src) dsts.append(res.global_dst) eids.append(res.global_eids) + etype_ids.append(res.etype_ids) src_tensor = F.cat(srcs, 0) dst_tensor = F.cat(dsts, 0) - eid_tensor = F.cat(eids, 0) + eid_tensor = None if eids[0] is None else F.cat(eids, 0) + etype_id_tensor = None if etype_ids[0] is None else F.cat(etype_ids, 0) else: src_tensor = res_list[0].global_src dst_tensor = res_list[0].global_dst eid_tensor = res_list[0].global_eids + etype_id_tensor = res_list[0].etype_ids g = graph((src_tensor, dst_tensor), num_nodes=num_nodes) - g.edata[EID] = eid_tensor + if eid_tensor is not None: + g.edata[EID] = eid_tensor + if etype_id_tensor is not None: + g.edata[ETYPE] = etype_id_tensor return g -LocalSampledGraph = namedtuple( - "LocalSampledGraph", "global_src global_dst global_eids" +LocalSampledGraph = namedtuple( # pylint: disable=unexpected-keyword-arg + "LocalSampledGraph", + "global_src global_dst global_eids etype_ids", + defaults=(None, None, None, None), ) @@ -491,7 +659,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access): """ req_list = [] partition_book = g.get_partition_book() - nodes = toindex(nodes).tousertensor() + if not isinstance(nodes, torch.Tensor): + nodes = toindex(nodes).tousertensor() partition_id = partition_book.nid2partid(nodes) local_nids = None for pid in range(partition_book.num_partitions()): @@ -515,10 +684,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access): # sample neighbors for the nodes in the local partition. res_list = [] if local_nids is not None: - src, dst, eids = local_access( - g.local_partition, partition_book, local_nids - ) - res_list.append(LocalSampledGraph(src, dst, eids)) + res = local_access(g.local_partition, partition_book, local_nids) + res_list.append(res) # receive responses from remote machines. if msgseq2pos is not None: @@ -721,7 +888,15 @@ def local_access(local_g, partition_book, local_nids): return frontier -def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False): +def sample_neighbors( + g, + nodes, + fanout, + edge_dir="in", + prob=None, + replace=False, + use_graphbolt=False, +): """Sample from the neighbors of the given nodes from a distributed graph. For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges @@ -764,6 +939,8 @@ def sample_neighbors(g, nodes, fanout, edge_dir="in", prob=None, replace=False): For sampling without replacement, if fanout > the number of neighbors, all the neighbors are sampled. If fanout == -1, all neighbors are collected. + use_graphbolt : bool, optional + Whether to use GraphBolt for sampling. Returns ------- @@ -795,20 +972,26 @@ def issue_remote_req(node_ids): else: _prob = None return SamplingRequest( - node_ids, fanout, edge_dir=edge_dir, prob=_prob, replace=replace + node_ids, + fanout, + edge_dir=edge_dir, + prob=_prob, + replace=replace, + use_graphbolt=use_graphbolt, ) def local_access(local_g, partition_book, local_nids): # See NOTE 1 _prob = [g.edata[prob].local_partition] if prob is not None else None return _sample_neighbors( + use_graphbolt, local_g, partition_book, local_nids, fanout, - edge_dir, - _prob, - replace, + edge_dir=edge_dir, + prob=_prob, + replace=replace, ) frontier = _distributed_access(g, nodes, issue_remote_req, local_access) diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py index f9086ef1e888..a32357cea760 100644 --- a/python/dgl/graphbolt/base.py +++ b/python/dgl/graphbolt/base.py @@ -1,5 +1,6 @@ """Base types and utilities for Graph Bolt.""" +from collections import deque from dataclasses import dataclass import torch @@ -14,7 +15,12 @@ "etype_str_to_tuple", "etype_tuple_to_str", "CopyTo", + "FutureWaiter", + "Waiter", + "Bufferer", + "EndMarker", "isin", + "index_select", "expand_indptr", "CSCFormatBase", "seed", @@ -102,6 +108,33 @@ def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None): ) +def index_select(tensor, index): + """Returns a new tensor which indexes the input tensor along dimension dim + using the entries in index. + + The returned tensor has the same number of dimensions as the original tensor + (tensor). The first dimension has the same size as the length of index; + other dimensions have the same size as in the original tensor. + + When tensor is a pinned tensor and index.is_cuda is True, the operation runs + on the CUDA device and the returned tensor will also be on CUDA. + + Parameters + ---------- + tensor : torch.Tensor + The input tensor. + index : torch.Tensor + The 1-D tensor containing the indices to index. + + Returns + ------- + torch.Tensor + The indexed input tensor, equivalent to tensor[index]. + """ + assert index.dim() == 1, "Index should be 1D tensor." + return torch.ops.graphbolt.index_select(tensor, index) + + def etype_tuple_to_str(c_etype): """Convert canonical etype from tuple to string. @@ -219,6 +252,76 @@ def __iter__(self): yield data +@functional_datapipe("mark_end") +class EndMarker(IterDataPipe): + """Used to mark the end of a datapipe and is a no-op.""" + + def __init__(self, datapipe): + self.datapipe = datapipe + + def __iter__(self): + yield from self.datapipe + + +@functional_datapipe("buffer") +class Bufferer(IterDataPipe): + """Buffers items before yielding them. + + Parameters + ---------- + datapipe : DataPipe + The data pipeline. + buffer_size : int, optional + The size of the buffer which stores the fetched samples. If data coming + from datapipe has latency spikes, consider setting to a higher value. + Default is 1. + """ + + def __init__(self, datapipe, buffer_size=1): + self.datapipe = datapipe + if buffer_size <= 0: + raise ValueError( + "'buffer_size' is required to be a positive integer." + ) + self.buffer = deque(maxlen=buffer_size) + + def __iter__(self): + for data in self.datapipe: + if len(self.buffer) < self.buffer.maxlen: + self.buffer.append(data) + else: + return_data = self.buffer.popleft() + self.buffer.append(data) + yield return_data + while len(self.buffer) > 0: + yield self.buffer.popleft() + + +@functional_datapipe("wait") +class Waiter(IterDataPipe): + """Calls the wait function of all items.""" + + def __init__(self, datapipe): + self.datapipe = datapipe + + def __iter__(self): + for data in self.datapipe: + data.wait() + yield data + + +@functional_datapipe("wait_future") +class FutureWaiter(IterDataPipe): + """Calls the result function of all items and returns their results.""" + + def __init__(self, datapipe): + self.datapipe = datapipe + + def __iter__(self): + for data in self.datapipe: + yield data.result() + + @dataclass class CSCFormatBase: r"""Basic class representing data in Compressed Sparse Column (CSC) format. diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py index b0dd9daccfaf..cffb24070a06 100644 --- a/python/dgl/graphbolt/dataloader.py +++ b/python/dgl/graphbolt/dataloader.py @@ -1,6 +1,6 @@ """Graph Bolt DataLoaders""" -from collections import deque +from concurrent.futures import ThreadPoolExecutor import torch import torch.utils.data @@ -9,6 +9,7 @@ from .base import CopyTo from .feature_fetcher import FeatureFetcher +from .impl.neighbor_sampler import SamplePerLayer from .internal import datapipe_graph_to_adjlist from .item_sampler import ItemSampler @@ -16,8 +17,6 @@ __all__ = [ "DataLoader", - "Awaiter", - "Bufferer", ] @@ -40,61 +39,6 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs): return datapipe_graph -class EndMarker(dp.iter.IterDataPipe): - """Used to mark the end of a datapipe and is a no-op.""" - - def __init__(self, datapipe): - self.datapipe = datapipe - - def __iter__(self): - yield from self.datapipe - - -class Bufferer(dp.iter.IterDataPipe): - """Buffers items before yielding them. - - Parameters - ---------- - datapipe : DataPipe - The data pipeline. - buffer_size : int, optional - The size of the buffer which stores the fetched samples. If data coming - from datapipe has latency spikes, consider setting to a higher value. - Default is 1. - """ - - def __init__(self, datapipe, buffer_size=1): - self.datapipe = datapipe - if buffer_size <= 0: - raise ValueError( - "'buffer_size' is required to be a positive integer." - ) - self.buffer = deque(maxlen=buffer_size) - - def __iter__(self): - for data in self.datapipe: - if len(self.buffer) < self.buffer.maxlen: - self.buffer.append(data) - else: - return_data = self.buffer.popleft() - self.buffer.append(data) - yield return_data - while len(self.buffer) > 0: - yield self.buffer.popleft() - - -class Awaiter(dp.iter.IterDataPipe): - """Calls the wait function of all items.""" - - def __init__(self, datapipe): - self.datapipe = datapipe - - def __iter__(self): - for data in self.datapipe: - data.wait() - yield data - - class MultiprocessingWrapper(dp.iter.IterDataPipe): """Wraps a datapipe with multiprocessing. @@ -156,6 +100,10 @@ class DataLoader(torch.utils.data.DataLoader): If True, the data loader will overlap the UVA feature fetcher operations with the rest of operations by using an alternative CUDA stream. Default is True. + overlap_graph_fetch : bool, optional + If True, the data loader will overlap the UVA graph fetching operations + with the rest of operations by using an alternative CUDA stream. Default + is False. max_uva_threads : int, optional Limits the number of CUDA threads used for UVA copies so that the rest of the computations can run simultaneously with it. Setting it to a too @@ -170,6 +118,7 @@ def __init__( num_workers=0, persistent_workers=True, overlap_feature_fetch=True, + overlap_graph_fetch=False, max_uva_threads=6144, ): # Multiprocessing requires two modifications to the datapipe: @@ -179,7 +128,7 @@ def __init__( # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe # of the FeatureFetcher with a multiprocessing PyTorch DataLoader. - datapipe = EndMarker(datapipe) + datapipe = datapipe.mark_end() datapipe_graph = dp_utils.traverse_dps(datapipe) # (1) Insert minibatch distribution. @@ -223,7 +172,25 @@ def __init__( datapipe_graph = dp_utils.replace_dp( datapipe_graph, feature_fetcher, - Awaiter(Bufferer(feature_fetcher, buffer_size=1)), + feature_fetcher.buffer(1).wait(), + ) + + if ( + overlap_graph_fetch + and num_workers == 0 + and torch.cuda.is_available() + ): + torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) + samplers = dp_utils.find_dps( + datapipe_graph, + SamplePerLayer, + ) + executor = ThreadPoolExecutor(max_workers=1) + for sampler in samplers: + datapipe_graph = dp_utils.replace_dp( + datapipe_graph, + sampler, + sampler.fetch_and_sample(_get_uva_stream(), executor, 1), ) # (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py index 7b94d1e1b3d8..01ff25af8c15 100644 --- a/python/dgl/graphbolt/feature_fetcher.py +++ b/python/dgl/graphbolt/feature_fetcher.py @@ -174,10 +174,5 @@ def _read(self, data): with torch.cuda.stream(self.stream): data = self._read_data(data, current_stream) if self.stream is not None: - event = torch.cuda.current_stream().record_event() - - def _wait(): - event.wait() - - data.wait = _wait + data.wait = torch.cuda.current_stream().record_event().wait return data diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index b8cee5e18a7c..de81c137833b 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -625,8 +625,16 @@ def sample_neighbors( if isinstance(nodes, dict): nodes = self._convert_to_homogeneous_nodes(nodes) + return_eids = ( + self.edge_attributes is not None + and ORIGINAL_EDGE_ID in self.edge_attributes + ) C_sampled_subgraph = self._sample_neighbors( - nodes, fanouts, replace, probs_name + nodes, + fanouts, + replace=replace, + probs_name=probs_name, + return_eids=return_eids, ) return self._convert_to_sampled_subgraph(C_sampled_subgraph) @@ -679,6 +687,7 @@ def _sample_neighbors( fanouts: torch.Tensor, replace: bool = False, probs_name: Optional[str] = None, + return_eids: bool = False, ) -> torch.ScriptObject: """Sample neighboring edges of the given nodes and return the induced subgraph. @@ -714,6 +723,9 @@ def _sample_neighbors( corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor, with the number of elements equalling the total number of edges. + return_eids: bool, optional + Boolean indicating whether to return the original edge IDs of the + sampled edges. Returns ------- @@ -722,16 +734,12 @@ def _sample_neighbors( """ # Ensure nodes is 1-D tensor. self._check_sampler_arguments(nodes, fanouts, probs_name) - has_original_eids = ( - self.edge_attributes is not None - and ORIGINAL_EDGE_ID in self.edge_attributes - ) return self._c_csc_graph.sample_neighbors( nodes, fanouts.tolist(), replace, False, - has_original_eids, + return_eids, probs_name, ) @@ -1018,7 +1026,13 @@ def sample_negative_edges_uniform_2( torch.cat( ( pos_src.repeat_interleave(negative_ratio), - torch.randint(0, max_node_id, (num_negative,)), + torch.randint( + 0, + max_node_id, + (num_negative,), + dtype=node_pairs.dtype, + device=node_pairs.device, + ), ), ) .view(2, num_negative) diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 605da8ff5ce3..737b475a94e6 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -1,18 +1,163 @@ """Neighbor subgraph samplers for GraphBolt.""" +from concurrent.futures import ThreadPoolExecutor from functools import partial import torch from torch.utils.data import functional_datapipe +from torchdata.datapipes.iter import Mapper from ..internal import compact_csc_format, unique_and_compact_csc_formats from ..minibatch_transformer import MiniBatchTransformer from ..subgraph_sampler import SubgraphSampler +from .fused_csc_sampling_graph import fused_csc_sampling_graph from .sampled_subgraph_impl import SampledSubgraphImpl -__all__ = ["NeighborSampler", "LayerNeighborSampler"] +__all__ = [ + "NeighborSampler", + "LayerNeighborSampler", + "SamplePerLayer", + "SamplePerLayerFromFetchedSubgraph", + "FetchInsubgraphData", +] + + +@functional_datapipe("fetch_insubgraph_data") +class FetchInsubgraphData(Mapper): + """Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If + the provided sample_per_layer_obj has a valid prob_name, then it reads the + probabilies of all the fetched edges. Furthermore, if type_per_array tensor + exists in the underlying graph, then the types of all the fetched edges are + read as well.""" + + def __init__( + self, datapipe, sample_per_layer_obj, stream=None, executor=None + ): + super().__init__(datapipe, self._fetch_per_layer) + self.graph = sample_per_layer_obj.sampler.__self__ + self.prob_name = sample_per_layer_obj.prob_name + self.stream = stream + if executor is None: + self.executor = ThreadPoolExecutor(max_workers=1) + else: + self.executor = executor + + def _fetch_per_layer_impl(self, minibatch, stream): + with torch.cuda.stream(self.stream): + index = minibatch._seed_nodes + if isinstance(index, dict): + for idx in index.values(): + idx.record_stream(torch.cuda.current_stream()) + index = self.graph._convert_to_homogeneous_nodes(index) + else: + index.record_stream(torch.cuda.current_stream()) + + def record_stream(tensor): + if stream is not None and tensor.is_cuda: + tensor.record_stream(stream) + return tensor + + if self.graph.node_type_offset is None: + # sorting not needed. + minibatch._subgraph_seed_nodes = None + else: + index, original_positions = index.sort() + if (original_positions.diff() == 1).all().item(): + # already sorted. + minibatch._subgraph_seed_nodes = None + else: + minibatch._subgraph_seed_nodes = record_stream( + original_positions.sort()[1] + ) + index_select_csc_with_indptr = partial( + torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr + ) + + indptr, indices = index_select_csc_with_indptr( + self.graph.indices, index, None + ) + record_stream(indptr) + record_stream(indices) + output_size = len(indices) + if self.graph.type_per_edge is not None: + _, type_per_edge = index_select_csc_with_indptr( + self.graph.type_per_edge, index, output_size + ) + record_stream(type_per_edge) + else: + type_per_edge = None + if self.graph.edge_attributes is not None: + probs_or_mask = self.graph.edge_attributes.get( + self.prob_name, None + ) + if probs_or_mask is not None: + _, probs_or_mask = index_select_csc_with_indptr( + probs_or_mask, index, output_size + ) + record_stream(probs_or_mask) + else: + probs_or_mask = None + if self.graph.node_type_offset is not None: + node_type_offset = torch.searchsorted( + index, self.graph.node_type_offset + ) + else: + node_type_offset = None + subgraph = fused_csc_sampling_graph( + indptr, + indices, + node_type_offset=node_type_offset, + type_per_edge=type_per_edge, + node_type_to_id=self.graph.node_type_to_id, + edge_type_to_id=self.graph.edge_type_to_id, + ) + if self.prob_name is not None and probs_or_mask is not None: + subgraph.edge_attributes = {self.prob_name: probs_or_mask} + + minibatch.sampled_subgraphs.insert(0, subgraph) + + if self.stream is not None: + minibatch.wait = torch.cuda.current_stream().record_event().wait + + return minibatch + + def _fetch_per_layer(self, minibatch): + current_stream = None + if self.stream is not None: + current_stream = torch.cuda.current_stream() + self.stream.wait_stream(current_stream) + return self.executor.submit( + self._fetch_per_layer_impl, minibatch, current_stream + ) + + +@functional_datapipe("sample_per_layer_from_fetched_subgraph") +class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer): + """Sample neighbor edges from a graph for a single layer.""" + + def __init__(self, datapipe, sample_per_layer_obj): + super().__init__(datapipe, self._sample_per_layer_from_fetched_subgraph) + self.sampler_name = sample_per_layer_obj.sampler.__name__ + self.fanout = sample_per_layer_obj.fanout + self.replace = sample_per_layer_obj.replace + self.prob_name = sample_per_layer_obj.prob_name + + def _sample_per_layer_from_fetched_subgraph(self, minibatch): + subgraph = minibatch.sampled_subgraphs[0] + + sampled_subgraph = getattr(subgraph, self.sampler_name)( + minibatch._subgraph_seed_nodes, + self.fanout, + self.replace, + self.prob_name, + ) + delattr(minibatch, "_subgraph_seed_nodes") + sampled_subgraph.original_column_node_ids = minibatch._seed_nodes + minibatch.sampled_subgraphs[0] = sampled_subgraph + + return minibatch @functional_datapipe("sample_per_layer") @@ -72,6 +217,19 @@ def _compact_per_layer(self, minibatch): return minibatch +@functional_datapipe("fetch_and_sample") +class FetcherAndSampler(MiniBatchTransformer): + """Overlapped graph sampling operation replacement.""" + + def __init__(self, sampler, stream, executor, buffer_size): + datapipe = sampler.datapipe.fetch_insubgraph_data( + sampler, stream, executor + ) + datapipe = datapipe.buffer(buffer_size).wait_future().wait() + datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler) + super().__init__(datapipe) + + @functional_datapipe("sample_neighbor") class NeighborSampler(SubgraphSampler): # pylint: disable=abstract-method @@ -173,7 +331,8 @@ def __init__( datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler ) - def _prepare(self, node_type_to_id, minibatch): + @staticmethod + def _prepare(node_type_to_id, minibatch): seeds = minibatch._seed_nodes # Enrich seeds with all node types. if isinstance(seeds, dict): diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index 0799c93ea93a..577e29b7325b 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -7,6 +7,7 @@ import numpy as np import torch +from ..base import index_select from ..feature_store import Feature from .basic_feature_store import BasicFeatureStore from .ondisk_metadata import OnDiskFeatureData @@ -117,7 +118,7 @@ def read(self, ids: torch.Tensor = None): if self._tensor.is_pinned(): return self._tensor.cuda() return self._tensor - return torch.ops.graphbolt.index_select(self._tensor, ids) + return index_select(self._tensor, ids) def size(self): """Get the size of the feature. @@ -144,11 +145,6 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None): updated. """ if ids is None: - assert self.size() == value.size()[1:], ( - f"ids is None, so the entire feature will be updated. " - f"But the size of the feature is {self.size()}, " - f"while the size of the value is {value.size()[1:]}." - ) self._tensor = value else: assert ids.shape[0] == value.shape[0], ( diff --git a/python/dgl/graphbolt/impl/uniform_negative_sampler.py b/python/dgl/graphbolt/impl/uniform_negative_sampler.py index cc7fa4e8fb1f..1b95d07f6601 100644 --- a/python/dgl/graphbolt/impl/uniform_negative_sampler.py +++ b/python/dgl/graphbolt/impl/uniform_negative_sampler.py @@ -76,15 +76,30 @@ def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False): # Construct indexes for all node pairs. num_pos_node_pairs = node_pairs.shape[0] negative_ratio = self.negative_ratio - pos_indexes = torch.arange(0, num_pos_node_pairs) + pos_indexes = torch.arange( + 0, + num_pos_node_pairs, + device=seeds.device, + ) neg_indexes = pos_indexes.repeat_interleave(negative_ratio) indexes = torch.cat((pos_indexes, neg_indexes)) # Construct labels for all node pairs. pos_num = node_pairs.shape[0] neg_num = seeds.shape[0] - pos_num labels = torch.cat( - (torch.ones(pos_num), torch.zeros(neg_num)) - ).bool() + ( + torch.ones( + pos_num, + dtype=torch.bool, + device=seeds.device, + ), + torch.zeros( + neg_num, + dtype=torch.bool, + device=seeds.device, + ), + ), + ) return seeds, labels, indexes else: return self.graph.sample_negative_edges_uniform( diff --git a/python/dgl/graphbolt/minibatch_transformer.py b/python/dgl/graphbolt/minibatch_transformer.py index 8822f2ac6203..b7b00b7a1b29 100644 --- a/python/dgl/graphbolt/minibatch_transformer.py +++ b/python/dgl/graphbolt/minibatch_transformer.py @@ -29,10 +29,10 @@ class MiniBatchTransformer(Mapper): def __init__( self, datapipe, - transformer, + transformer=None, ): super().__init__(datapipe, self._transformer) - self.transformer = transformer + self.transformer = transformer or self._identity def _transformer(self, minibatch): minibatch = self.transformer(minibatch) @@ -40,3 +40,7 @@ def _transformer(self, minibatch): minibatch, (MiniBatch,) ), "The transformer output should be an instance of MiniBatch" return minibatch + + @staticmethod + def _identity(minibatch): + return minibatch diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py index b05b8ca30619..ab7c969063c9 100644 --- a/python/dgl/graphbolt/subgraph_sampler.py +++ b/python/dgl/graphbolt/subgraph_sampler.py @@ -46,11 +46,7 @@ def __init__( datapipe = datapipe.transform(self._preprocess) datapipe = self.sampling_stages(datapipe, *args, **kwargs) datapipe = datapipe.transform(self._postprocess) - super().__init__(datapipe, self._identity) - - @staticmethod - def _identity(minibatch): - return minibatch + super().__init__(datapipe) @staticmethod def _postprocess(minibatch): diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py index 9eb47342455b..0795d4a03d25 100644 --- a/tests/distributed/test_distributed_sampling.py +++ b/tests/distributed/test_distributed_sampling.py @@ -1,7 +1,7 @@ import multiprocessing as mp import os import random -import sys +import tempfile import time import traceback import unittest @@ -31,6 +31,7 @@ def start_server( disable_shared_mem, graph_name, graph_format=["csc", "coo"], + use_graphbolt=False, ): g = DistGraphServer( rank, @@ -40,6 +41,7 @@ def start_server( tmpdir / (graph_name + ".json"), disable_shared_mem=disable_shared_mem, graph_format=graph_format, + use_graphbolt=use_graphbolt, ) g.start() @@ -72,6 +74,8 @@ def start_sample_client_shuffle( group_id, orig_nid, orig_eid, + use_graphbolt=False, + return_eids=False, ): os.environ["DGL_GROUP_ID"] = str(group_id) gpb = None @@ -80,17 +84,26 @@ def start_sample_client_shuffle( tmpdir / "test_sampling.json", rank ) dgl.distributed.initialize("rpc_ip_config.txt") - dist_graph = DistGraph("test_sampling", gpb=gpb) - sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3) + dist_graph = DistGraph( + "test_sampling", gpb=gpb, use_graphbolt=use_graphbolt + ) + sampled_graph = sample_neighbors( + dist_graph, [0, 10, 99, 66, 1024, 2008], 3, use_graphbolt=use_graphbolt + ) src, dst = sampled_graph.edges() src = orig_nid[src] dst = orig_nid[dst] assert sampled_graph.num_nodes() == g.num_nodes() assert np.all(F.asnumpy(g.has_edges_between(src, dst))) - eids = g.edge_ids(src, dst) - eids1 = orig_eid[sampled_graph.edata[dgl.EID]] - assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids)) + if use_graphbolt and not return_eids: + assert ( + dgl.EID not in sampled_graph.edata + ), "EID should not be in sampled graph if use_graphbolt=True." + else: + eids = g.edge_ids(src, dst) + eids1 = orig_eid[sampled_graph.edata[dgl.EID]] + assert np.array_equal(F.asnumpy(eids1), F.asnumpy(eids)) def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids, etype=None): @@ -378,7 +391,9 @@ def test_rpc_sampling(): check_rpc_sampling(Path(tmpdirname), 1) -def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): +def check_rpc_sampling_shuffle( + tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False +): generate_ip_config("rpc_ip_config.txt", num_server, num_server) g = CitationGraphDataset("cora")[0] @@ -393,6 +408,8 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): num_hops=num_hops, part_method="metis", return_mapping=True, + use_graphbolt=use_graphbolt, + store_eids=return_eids, ) pserver_list = [] @@ -406,6 +423,7 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): num_server > 1, "test_sampling", ["csc", "coo"], + use_graphbolt, ), ) p.start() @@ -427,6 +445,8 @@ def check_rpc_sampling_shuffle(tmpdir, num_server, num_groups=1): group_id, orig_nids, orig_eids, + use_graphbolt, + return_eids, ), ) p.start() @@ -996,44 +1016,89 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server): assert np.all(F.asnumpy(orig_dst1) == orig_dst) -# Wait non shared memory graph store -@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") -@unittest.skipIf( - dgl.backend.backend_name == "tensorflow", - reason="Not support tensorflow for now", -) -@unittest.skipIf( - dgl.backend.backend_name == "mxnet", reason="Turn off Mxnet support" -) @pytest.mark.parametrize("num_server", [1]) -def test_rpc_sampling_shuffle(num_server): +@pytest.mark.parametrize("use_graphbolt", [False, True]) +@pytest.mark.parametrize("return_eids", [False, True]) +def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids): reset_envs() - import tempfile + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as tmpdirname: + check_rpc_sampling_shuffle( + Path(tmpdirname), + num_server, + use_graphbolt=use_graphbolt, + return_eids=return_eids, + ) + +@pytest.mark.parametrize("num_server", [1]) +def test_rpc_hetero_sampling_shuffle(num_server): + reset_envs() os.environ["DGL_DIST_MODE"] = "distributed" with tempfile.TemporaryDirectory() as tmpdirname: - check_rpc_sampling_shuffle(Path(tmpdirname), num_server) - # [TODO][Rhett] Tests for multiple groups may fail sometimes and - # root cause is unknown. Let's disable them for now. - # check_rpc_sampling_shuffle(Path(tmpdirname), num_server, num_groups=2) check_rpc_hetero_sampling_shuffle(Path(tmpdirname), num_server) + + +@pytest.mark.parametrize("num_server", [1]) +def test_rpc_hetero_sampling_empty_shuffle(num_server): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as tmpdirname: check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server) - check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server) - check_rpc_hetero_etype_sampling_shuffle( - Path(tmpdirname), num_server, ["csc"] - ) - check_rpc_hetero_etype_sampling_shuffle( - Path(tmpdirname), num_server, ["csr"] - ) + + +@pytest.mark.parametrize("num_server", [1]) +@pytest.mark.parametrize( + "graph_formats", [None, ["csc"], ["csr"], ["csc", "coo"]] +) +def test_rpc_hetero_etype_sampling_shuffle(num_server, graph_formats): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as tmpdirname: check_rpc_hetero_etype_sampling_shuffle( - Path(tmpdirname), num_server, ["csc", "coo"] + Path(tmpdirname), num_server, graph_formats=graph_formats ) + + +@pytest.mark.parametrize("num_server", [1]) +def test_rpc_hetero_etype_sampling_empty_shuffle(num_server): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as tmpdirname: check_rpc_hetero_etype_sampling_empty_shuffle( Path(tmpdirname), num_server ) + + +@pytest.mark.parametrize("num_server", [1]) +def test_rpc_bipartite_sampling_empty_shuffle(num_server): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as tmpdirname: check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server) + + +@pytest.mark.parametrize("num_server", [1]) +def test_rpc_bipartite_sampling_shuffle(num_server): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as tmpdirname: check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server) + + +@pytest.mark.parametrize("num_server", [1]) +def test_rpc_bipartite_etype_sampling_empty_shuffle(num_server): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as tmpdirname: check_rpc_bipartite_etype_sampling_empty(Path(tmpdirname), num_server) + + +@pytest.mark.parametrize("num_server", [1]) +def test_rpc_bipartite_etype_sampling_shuffle(num_server): + reset_envs() + os.environ["DGL_DIST_MODE"] = "distributed" + with tempfile.TemporaryDirectory() as tmpdirname: check_rpc_bipartite_etype_sampling_shuffle(Path(tmpdirname), num_server) diff --git a/tests/python/pytorch/graphbolt/gb_test_utils.py b/tests/python/pytorch/graphbolt/gb_test_utils.py index dd7abc74da0c..59c4c3a90276 100644 --- a/tests/python/pytorch/graphbolt/gb_test_utils.py +++ b/tests/python/pytorch/graphbolt/gb_test_utils.py @@ -269,7 +269,7 @@ def genereate_raw_data_for_hetero_dataset( # Generate train/test/valid set. os.makedirs(os.path.join(test_dir, "set"), exist_ok=True) user_ids = torch.arange(num_nodes["user"]) - np.random.shuffle(user_ids) + np.random.shuffle(user_ids.numpy()) num_train = int(num_nodes["user"] * 0.6) num_validation = int(num_nodes["user"] * 0.2) num_test = num_nodes["user"] - num_train - num_validation diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py index d251701cdaf9..eb9a62babff1 100644 --- a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py +++ b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py @@ -28,14 +28,16 @@ torch.float64, ], ) -def test_gpu_cached_feature(dtype): +@pytest.mark.parametrize("cache_size_a", [1, 1024]) +@pytest.mark.parametrize("cache_size_b", [1, 1024]) +def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b): a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype, pin_memory=True) b = torch.tensor( [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True ) - feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), 2) - feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), 1) + feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), cache_size_a) + feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), cache_size_b) # Test read the entire feature. assert torch.equal(feat_store_a.read(), a.to("cuda")) @@ -52,6 +54,23 @@ def test_gpu_cached_feature(dtype): "cuda" ), ) + assert torch.equal( + feat_store_a.read(torch.tensor([1, 1]).to("cuda")), + torch.tensor([[4, 5, 6], [4, 5, 6]], dtype=dtype).to("cuda"), + ) + assert torch.equal( + feat_store_b.read(torch.tensor([0]).to("cuda")), + torch.tensor([[[1, 2], [3, 4]]], dtype=dtype).to("cuda"), + ) + # The cache should be full now for the large cache sizes, %100 hit expected. + if cache_size_a >= 1024: + total_miss = feat_store_a._feature.total_miss + feat_store_a.read(torch.tensor([0, 1]).to("cuda")) + assert total_miss == feat_store_a._feature.total_miss + if cache_size_b >= 1024: + total_miss = feat_store_b._feature.total_miss + feat_store_b.read(torch.tensor([0, 1]).to("cuda")) + assert total_miss == feat_store_b._feature.total_miss # Test get the size of the entire feature with ids. assert feat_store_a.size() == torch.Size([3]) diff --git a/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py b/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py index 44aab2d8b8bb..9b2f783b7d86 100644 --- a/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_negative_sampler.py @@ -1,5 +1,7 @@ import re +import backend as F + import dgl.graphbolt as gb import pytest import torch @@ -14,7 +16,9 @@ def test_NegativeSampler_invoke(): torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs" ) batch_size = 10 - item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) + item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( + F.ctx() + ) negative_ratio = 2 # Invoke NegativeSampler via class constructor. @@ -35,13 +39,17 @@ def test_NegativeSampler_invoke(): def test_UniformNegativeSampler_invoke(): # Instantiate graph and required datapipes. - graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) + graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to( + F.ctx() + ) num_seeds = 30 item_set = gb.ItemSet( torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="seeds" ) batch_size = 10 - item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) + item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( + F.ctx() + ) negative_ratio = 2 def _verify(negative_sampler): @@ -70,13 +78,17 @@ def _verify(negative_sampler): def test_UniformNegativeSampler_node_pairs_invoke(): # Instantiate graph and required datapipes. - graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) + graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to( + F.ctx() + ) num_seeds = 30 item_set = gb.ItemSet( torch.arange(0, 2 * num_seeds).reshape(-1, 2), names="node_pairs" ) batch_size = 10 - item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) + item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( + F.ctx() + ) negative_ratio = 2 # Verify iteration over UniformNegativeSampler. @@ -106,13 +118,17 @@ def _verify(negative_sampler): @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) def test_Uniform_NegativeSampler_node_pairs(negative_ratio): # Construct FusedCSCSamplingGraph. - graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) + graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to( + F.ctx() + ) num_seeds = 30 item_set = gb.ItemSet( torch.arange(0, num_seeds * 2).reshape(-1, 2), names="node_pairs" ) batch_size = 10 - item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) + item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( + F.ctx() + ) # Construct NegativeSampler. negative_sampler = gb.UniformNegativeSampler( item_sampler, @@ -134,13 +150,17 @@ def test_Uniform_NegativeSampler_node_pairs(negative_ratio): @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) def test_Uniform_NegativeSampler(negative_ratio): # Construct FusedCSCSamplingGraph. - graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) + graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to( + F.ctx() + ) num_seeds = 30 item_set = gb.ItemSet( torch.arange(0, num_seeds * 2).reshape(-1, 2), names="seeds" ) batch_size = 10 - item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) + item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( + F.ctx() + ) # Construct NegativeSampler. negative_sampler = gb.UniformNegativeSampler( item_sampler, @@ -159,12 +179,15 @@ def test_Uniform_NegativeSampler(negative_ratio): neg_src = data.seeds[batch_size:, 0] assert torch.equal(pos_src.repeat_interleave(negative_ratio), neg_src) # Check labels. - assert torch.equal(data.labels[:batch_size], torch.ones(batch_size)) assert torch.equal( - data.labels[batch_size:], torch.zeros(batch_size * negative_ratio) + data.labels[:batch_size], torch.ones(batch_size).to(F.ctx()) + ) + assert torch.equal( + data.labels[batch_size:], + torch.zeros(batch_size * negative_ratio).to(F.ctx()), ) # Check indexes. - pos_indexes = torch.arange(0, batch_size) + pos_indexes = torch.arange(0, batch_size).to(F.ctx()) neg_indexes = pos_indexes.repeat_interleave(negative_ratio) expected_indexes = torch.cat((pos_indexes, neg_indexes)) assert torch.equal(data.indexes, expected_indexes) @@ -173,13 +196,17 @@ def test_Uniform_NegativeSampler(negative_ratio): def test_Uniform_NegativeSampler_error_shape(): # 1. seeds with shape N*3. # Construct FusedCSCSamplingGraph. - graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True) + graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True).to( + F.ctx() + ) num_seeds = 30 item_set = gb.ItemSet( torch.arange(0, num_seeds * 3).reshape(-1, 3), names="seeds" ) batch_size = 10 - item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) + item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( + F.ctx() + ) negative_ratio = 2 # Construct NegativeSampler. negative_sampler = gb.UniformNegativeSampler( @@ -201,7 +228,9 @@ def test_Uniform_NegativeSampler_error_shape(): item_set = gb.ItemSet( torch.arange(0, num_seeds * 2).reshape(-1, 2, 1), names="seeds" ) - item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) + item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( + F.ctx() + ) # Construct NegativeSampler. negative_sampler = gb.UniformNegativeSampler( item_sampler, @@ -220,7 +249,9 @@ def test_Uniform_NegativeSampler_error_shape(): # 3. seeds with shape N. # Construct FusedCSCSamplingGraph. item_set = gb.ItemSet(torch.arange(0, num_seeds), names="seeds") - item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) + item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to( + F.ctx() + ) # Construct NegativeSampler. negative_sampler = gb.UniformNegativeSampler( item_sampler, @@ -260,7 +291,7 @@ def get_hetero_graph(): def test_NegativeSampler_Hetero_node_pairs_Data(): - graph = get_hetero_graph() + graph = get_hetero_graph().to(F.ctx()) itemset = gb.ItemSetDict( { "n1:e1:n2": gb.ItemSet( @@ -274,13 +305,13 @@ def test_NegativeSampler_Hetero_node_pairs_Data(): } ) - item_sampler = gb.ItemSampler(itemset, batch_size=2) + item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) negative_dp = gb.UniformNegativeSampler(item_sampler, graph, 1) assert len(list(negative_dp)) == 5 def test_NegativeSampler_Hetero_Data(): - graph = get_hetero_graph() + graph = get_hetero_graph().to(F.ctx()) itemset = gb.ItemSetDict( { "n1:e1:n2": gb.ItemSet( @@ -295,7 +326,9 @@ def test_NegativeSampler_Hetero_Data(): ) batch_size = 2 negative_ratio = 1 - item_sampler = gb.ItemSampler(itemset, batch_size=batch_size) + item_sampler = gb.ItemSampler(itemset, batch_size=batch_size).copy_to( + F.ctx() + ) negative_dp = gb.UniformNegativeSampler(item_sampler, graph, negative_ratio) assert len(list(negative_dp)) == 5 # Perform negative sampling. @@ -311,5 +344,5 @@ def test_NegativeSampler_Hetero_Data(): for etype, seeds_data in data.seeds.items(): neg_src = seeds_data[batch_size:, 0] neg_dst = seeds_data[batch_size:, 1] - assert torch.equal(expected_neg_src[i][etype], neg_src) + assert torch.equal(expected_neg_src[i][etype].to(F.ctx()), neg_src) assert (neg_dst < 3).all(), neg_dst diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py new file mode 100644 index 000000000000..09528d98899d --- /dev/null +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -0,0 +1,77 @@ +import unittest +from functools import partial + +import backend as F + +import dgl +import dgl.graphbolt as gb +import pytest +import torch + + +def get_hetero_graph(): + # COO graph: + # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] + # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1] + # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type. + # num_nodes = 5, num_n1 = 2, num_n2 = 3 + ntypes = {"n1": 0, "n2": 1} + etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} + indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) + indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) + type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) + edge_attributes = { + "weight": torch.FloatTensor( + [2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5] + ), + "mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]), + } + node_type_offset = torch.LongTensor([0, 2, 5]) + return gb.fused_csc_sampling_graph( + indptr, + indices, + node_type_offset=node_type_offset, + type_per_edge=type_per_edge, + node_type_to_id=ntypes, + edge_type_to_id=etypes, + edge_attributes=edge_attributes, + ) + + +@unittest.skipIf(F._default_context_str != "gpu", reason="Enabled only on GPU.") +@pytest.mark.parametrize("hetero", [False, True]) +@pytest.mark.parametrize("prob_name", [None, "weight", "mask"]) +@pytest.mark.parametrize("sorted", [False, True]) +def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted): + if sorted: + items = torch.arange(3) + else: + items = torch.tensor([2, 0, 1]) + names = "seed_nodes" + itemset = gb.ItemSet(items, names=names) + graph = get_hetero_graph().to(F.ctx()) + if hetero: + itemset = gb.ItemSetDict({"n2": itemset}) + else: + graph.type_per_edge = None + item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) + fanout = torch.LongTensor([2]) + datapipe = item_sampler.map(gb.SubgraphSampler._preprocess) + datapipe = datapipe.map( + partial(gb.NeighborSampler._prepare, graph.node_type_to_id) + ) + sample_per_layer = gb.SamplePerLayer( + datapipe, graph.sample_neighbors, fanout, False, prob_name + ) + compact_per_layer = sample_per_layer.compact_per_layer(True) + gb.seed(123) + expected_results = list(compact_per_layer) + datapipe = gb.FetchInsubgraphData(datapipe, sample_per_layer) + datapipe = datapipe.wait_future() + datapipe = gb.SamplePerLayerFromFetchedSubgraph(datapipe, sample_per_layer) + datapipe = datapipe.compact_per_layer(True) + gb.seed(123) + new_results = list(datapipe) + assert len(expected_results) == len(new_results) + for a, b in zip(expected_results, new_results): + assert repr(a) == repr(b) diff --git a/tests/python/pytorch/graphbolt/test_base.py b/tests/python/pytorch/graphbolt/test_base.py index b25b28166294..5d7d6c477c33 100644 --- a/tests/python/pytorch/graphbolt/test_base.py +++ b/tests/python/pytorch/graphbolt/test_base.py @@ -250,6 +250,34 @@ def test_isin_non_1D_dim(): gb.isin(elements, test_elements) +@pytest.mark.parametrize( + "dtype", + [ + torch.bool, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ], +) +@pytest.mark.parametrize("idtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("pinned", [False, True]) +def test_index_select(dtype, idtype, pinned): + if F._default_context_str != "gpu" and pinned: + pytest.skip("Pinned tests are available only on GPU.") + tensor = torch.tensor([[2, 3], [5, 5], [20, 13]], dtype=dtype) + tensor = tensor.pin_memory() if pinned else tensor.to(F.ctx()) + index = torch.tensor([0, 2], dtype=idtype, device=F.ctx()) + gb_result = gb.index_select(tensor, index) + torch_result = tensor.to(F.ctx())[index.long()] + assert torch.equal(torch_result, gb_result) + + def torch_expand_indptr(indptr, dtype, nodes=None): if nodes is None: nodes = torch.arange(len(indptr) - 1, dtype=dtype, device=indptr.device) diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 80a8c7164a57..2ee78bf3be6a 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -47,11 +47,21 @@ def test_DataLoader(): F._default_context_str != "gpu", reason="This test requires the GPU.", ) -@pytest.mark.parametrize("overlap_feature_fetch", [True, False]) +@pytest.mark.parametrize( + "sampler_name", ["NeighborSampler", "LayerNeighborSampler"] +) @pytest.mark.parametrize("enable_feature_fetch", [True, False]) -def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch): +@pytest.mark.parametrize("overlap_feature_fetch", [True, False]) +@pytest.mark.parametrize("overlap_graph_fetch", [True, False]) +def test_gpu_sampling_DataLoader( + sampler_name, + enable_feature_fetch, + overlap_feature_fetch, + overlap_graph_fetch, +): N = 40 B = 4 + num_layers = 2 itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes") graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True).to( F.ctx() @@ -68,10 +78,10 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch): datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B) datapipe = datapipe.copy_to(F.ctx(), extra_attrs=["seed_nodes"]) - datapipe = dgl.graphbolt.NeighborSampler( + datapipe = getattr(dgl.graphbolt, sampler_name)( datapipe, graph, - fanouts=[torch.LongTensor([2]) for _ in range(2)], + fanouts=[torch.LongTensor([2]) for _ in range(num_layers)], ) if enable_feature_fetch: datapipe = dgl.graphbolt.FeatureFetcher( @@ -81,14 +91,18 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch): ) dataloader = dgl.graphbolt.DataLoader( - datapipe, overlap_feature_fetch=overlap_feature_fetch + datapipe, + overlap_feature_fetch=overlap_feature_fetch, + overlap_graph_fetch=overlap_graph_fetch, ) bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch) + if overlap_graph_fetch: + bufferer_awaiter_cnt += num_layers datapipe = dataloader.dataset datapipe_graph = dp_utils.traverse_dps(datapipe) awaiters = dp_utils.find_dps( datapipe_graph, - dgl.graphbolt.Awaiter, + dgl.graphbolt.Waiter, ) assert len(awaiters) == bufferer_awaiter_cnt bufferers = dp_utils.find_dps(