diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index 37d03437f70e..2f145970ec7c 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -18,6 +18,7 @@ namespace graphbolt { namespace sampling { enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT }; +enum TemporalOption { NOT_TEMPORAL, TEMPORAL }; constexpr bool is_labor(SamplerType S) { return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT; @@ -413,18 +414,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm); private: - template + template c10::intrusive_ptr SampleNeighborsImpl( const torch::Tensor& seeds, torch::optional>& seed_offsets, const std::vector& fanouts, bool return_eids, NumPickFn num_pick_fn, PickFn pick_fn) const; - template - c10::intrusive_ptr TemporalSampleNeighborsImpl( - const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn, - PickFn pick_fn) const; - /** @brief CSC format index pointer array. */ torch::Tensor indptr_; diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index b5c2587f693a..e2000b410458 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -492,7 +492,7 @@ auto GetTemporalPickFn( }; } -template +template c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighborsImpl( const torch::Tensor& seeds, @@ -512,7 +512,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( torch::optional edge_offsets = torch::nullopt; bool with_seed_offsets = seed_offsets.has_value(); - bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1; + bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1 && + Temporal == TemporalOption::NOT_TEMPORAL; // Get the number of edge types. If it's homo or if the size of fanouts is 1 // (hetero graph but sampled as a homo graph), set num_etypes as 1. @@ -584,24 +585,31 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( const auto offset = indptr_data[nid]; const auto num_neighbors = indptr_data[nid + 1] - offset; - const auto seed_type_id = - (hetero_with_seed_offsets) - ? std::upper_bound( - seed_offsets->begin(), seed_offsets->end(), - i) - - seed_offsets->begin() - 1 - : 0; - // `seed_index` indicates the index of the current - // seed within the group of seeds which have the same - // node type. - const auto seed_index = - (hetero_with_seed_offsets) - ? i - seed_offsets->at(seed_type_id) - : i; - num_pick_fn( - offset, num_neighbors, - num_picked_neighbors_data_ptr + 1, seed_index, - etype_id_to_num_picked_offset); + if constexpr (Temporal == TemporalOption::TEMPORAL) { + num_picked_neighbors_data_ptr[i + 1] = + num_neighbors == 0 + ? 0 + : num_pick_fn(i, offset, num_neighbors); + } else { + const auto seed_type_id = + (hetero_with_seed_offsets) + ? std::upper_bound( + seed_offsets->begin(), + seed_offsets->end(), i) - + seed_offsets->begin() - 1 + : 0; + // `seed_index` indicates the index of the current + // seed within the group of seeds which have the same + // node type. + const auto seed_index = + (hetero_with_seed_offsets) + ? i - seed_offsets->at(seed_type_id) + : i; + num_pick_fn( + offset, num_neighbors, + num_picked_neighbors_data_ptr + 1, seed_index, + etype_id_to_num_picked_offset); + } } }); @@ -684,16 +692,30 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( : i; // Step 4. Pick neighbors for each node. - picked_number = pick_fn( - offset, num_neighbors, picked_eids_data_ptr, - seed_index, subgraph_indptr_data_ptr, - etype_id_to_num_picked_offset); - if (!hetero_with_seed_offsets) { - TORCH_CHECK( - num_picked_neighbors_data_ptr[i + 1] == - picked_number, - "Actual picked count doesn't match the calculated " - "pick number."); + if constexpr (Temporal == TemporalOption::TEMPORAL) { + picked_number = num_picked_neighbors_data_ptr[i + 1]; + auto picked_offset = subgraph_indptr_data_ptr[i]; + if (picked_number > 0) { + auto actual_picked_count = pick_fn( + i, offset, num_neighbors, + picked_eids_data_ptr + picked_offset); + TORCH_CHECK( + actual_picked_count == picked_number, + "Actual picked count doesn't match the calculated" + " pick number."); + } + } else { + picked_number = pick_fn( + offset, num_neighbors, picked_eids_data_ptr, + seed_index, subgraph_indptr_data_ptr, + etype_id_to_num_picked_offset); + if (!hetero_with_seed_offsets) { + TORCH_CHECK( + num_picked_neighbors_data_ptr[i + 1] == + picked_number, + "Actual picked count doesn't match the calculated" + " pick number."); + } } // Step 5. Calculate other attributes and return the @@ -779,141 +801,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( subgraph_reverse_edge_ids, subgraph_type_per_edge, edge_offsets); } -template -c10::intrusive_ptr -FusedCSCSamplingGraph::TemporalSampleNeighborsImpl( - const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn, - PickFn pick_fn) const { - const int64_t num_nodes = nodes.size(0); - const auto indptr_options = indptr_.options(); - torch::Tensor num_picked_neighbors_per_node = - torch::empty({num_nodes + 1}, indptr_options); - - // Calculate GrainSize for parallel_for. - // Set the default grain size to 64. - const int64_t grain_size = 64; - torch::Tensor picked_eids; - torch::Tensor subgraph_indptr; - torch::Tensor subgraph_indices; - torch::optional subgraph_type_per_edge = torch::nullopt; - - AT_DISPATCH_INDEX_TYPES( - indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] { - using indptr_t = index_t; - AT_DISPATCH_INDEX_TYPES( - nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] { - using nodes_t = index_t; - const auto indptr_data = indptr_.data_ptr(); - auto num_picked_neighbors_data_ptr = - num_picked_neighbors_per_node.data_ptr(); - num_picked_neighbors_data_ptr[0] = 0; - const auto nodes_data_ptr = nodes.data_ptr(); - - // Step 1. Calculate pick number of each node. - torch::parallel_for( - 0, num_nodes, grain_size, [&](int64_t begin, int64_t end) { - for (int64_t i = begin; i < end; ++i) { - const auto nid = nodes_data_ptr[i]; - TORCH_CHECK( - nid >= 0 && nid < NumNodes(), - "The seed nodes' IDs should fall within the range of " - "the " - "graph's node IDs."); - const auto offset = indptr_data[nid]; - const auto num_neighbors = indptr_data[nid + 1] - offset; - - num_picked_neighbors_data_ptr[i + 1] = - num_neighbors == 0 - ? 0 - : num_pick_fn(i, offset, num_neighbors); - } - }); - - // Step 2. Calculate prefix sum to get total length and offsets of - // each node. It's also the indptr of the generated subgraph. - subgraph_indptr = num_picked_neighbors_per_node.cumsum( - 0, indptr_.scalar_type()); - - // Step 3. Allocate the tensor for picked neighbors. - const auto total_length = - subgraph_indptr.data_ptr()[num_nodes]; - picked_eids = torch::empty({total_length}, indptr_options); - subgraph_indices = - torch::empty({total_length}, indices_.options()); - if (type_per_edge_.has_value()) { - subgraph_type_per_edge = torch::empty( - {total_length}, type_per_edge_.value().options()); - } - - // Step 4. Pick neighbors for each node. - auto picked_eids_data_ptr = picked_eids.data_ptr(); - auto subgraph_indptr_data_ptr = - subgraph_indptr.data_ptr(); - torch::parallel_for( - 0, num_nodes, grain_size, [&](int64_t begin, int64_t end) { - for (int64_t i = begin; i < end; ++i) { - const auto nid = nodes_data_ptr[i]; - const auto offset = indptr_data[nid]; - const auto num_neighbors = indptr_data[nid + 1] - offset; - const auto picked_number = - num_picked_neighbors_data_ptr[i + 1]; - const auto picked_offset = subgraph_indptr_data_ptr[i]; - if (picked_number > 0) { - auto actual_picked_count = pick_fn( - i, offset, num_neighbors, - picked_eids_data_ptr + picked_offset); - TORCH_CHECK( - actual_picked_count == picked_number, - "Actual picked count doesn't match the calculated " - "pick " - "number."); - - // Step 5. Calculate other attributes and return the - // subgraph. - AT_DISPATCH_INDEX_TYPES( - subgraph_indices.scalar_type(), - "IndexSelectSubgraphIndices", ([&] { - auto subgraph_indices_data_ptr = - subgraph_indices.data_ptr(); - auto indices_data_ptr = - indices_.data_ptr(); - for (auto i = picked_offset; - i < picked_offset + picked_number; ++i) { - subgraph_indices_data_ptr[i] = - indices_data_ptr[picked_eids_data_ptr[i]]; - } - })); - if (type_per_edge_.has_value()) { - AT_DISPATCH_INTEGRAL_TYPES( - subgraph_type_per_edge.value().scalar_type(), - "IndexSelectTypePerEdge", ([&] { - auto subgraph_type_per_edge_data_ptr = - subgraph_type_per_edge.value() - .data_ptr(); - auto type_per_edge_data_ptr = - type_per_edge_.value().data_ptr(); - for (auto i = picked_offset; - i < picked_offset + picked_number; ++i) { - subgraph_type_per_edge_data_ptr[i] = - type_per_edge_data_ptr - [picked_eids_data_ptr[i]]; - } - })); - } - } - } - }); - })); - })); - - torch::optional subgraph_reverse_edge_ids = torch::nullopt; - if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids); - - return c10::make_intrusive( - subgraph_indptr, subgraph_indices, nodes, torch::nullopt, - subgraph_reverse_edge_ids, subgraph_type_per_edge); -} - c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( torch::optional seeds, torch::optional> seed_offsets, @@ -969,7 +856,7 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( indices_, {random_seed.value(), static_cast(seed2_contribution)}, NumNodes()}; - return SampleNeighborsImpl( + return SampleNeighborsImpl( seeds.value(), seed_offsets, fanouts, return_eids, GetNumPickFn( fanouts, replace, type_per_edge_, probs_or_mask, @@ -990,7 +877,7 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( NumNodes()}; } }(); - return SampleNeighborsImpl( + return SampleNeighborsImpl( seeds.value(), seed_offsets, fanouts, return_eids, GetNumPickFn( fanouts, replace, type_per_edge_, probs_or_mask, @@ -1001,7 +888,7 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( } } else { SamplerArgs args; - return SampleNeighborsImpl( + return SampleNeighborsImpl( seeds.value(), seed_offsets, fanouts, return_eids, GetNumPickFn( fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets), @@ -1019,6 +906,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( bool return_eids, torch::optional probs_name, torch::optional node_timestamp_attr_name, torch::optional edge_timestamp_attr_name) const { + torch::optional> seed_offsets = torch::nullopt; // 1. Get probs_or_mask. auto probs_or_mask = this->EdgeAttribute(probs_name); if (probs_name.has_value()) { @@ -1039,8 +927,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt( static_cast(0), std::numeric_limits::max()); SamplerArgs args{indices_, random_seed, NumNodes()}; - return TemporalSampleNeighborsImpl( - input_nodes, return_eids, + return SampleNeighborsImpl( + input_nodes, seed_offsets, fanouts, return_eids, GetTemporalNumPickFn( input_nodes_timestamp, this->indices_, fanouts, replace, type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp), @@ -1050,8 +938,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( edge_timestamp, args)); } else { SamplerArgs args; - return TemporalSampleNeighborsImpl( - input_nodes, return_eids, + return SampleNeighborsImpl( + input_nodes, seed_offsets, fanouts, return_eids, GetTemporalNumPickFn( input_nodes_timestamp, this->indices_, fanouts, replace, type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),