Skip to content

Commit

Permalink
[GraphBolt] Refactor sampling (#7367)
Browse files Browse the repository at this point in the history
  • Loading branch information
RamonZhou committed Apr 29, 2024
1 parent 6b140f2 commit f0213d2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 178 deletions.
8 changes: 2 additions & 6 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace graphbolt {
namespace sampling {

enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT };
enum TemporalOption { NOT_TEMPORAL, TEMPORAL };

constexpr bool is_labor(SamplerType S) {
return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT;
Expand Down Expand Up @@ -413,18 +414,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm);

private:
template <typename NumPickFn, typename PickFn>
template <TemporalOption Temporal, typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& seeds,
torch::optional<std::vector<int64_t>>& seed_offsets,
const std::vector<int64_t>& fanouts, bool return_eids,
NumPickFn num_pick_fn, PickFn pick_fn) const;

template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const;

/** @brief CSC format index pointer array. */
torch::Tensor indptr_;

Expand Down
232 changes: 60 additions & 172 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ auto GetTemporalPickFn(
};
}

template <typename NumPickFn, typename PickFn>
template <TemporalOption Temporal, typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& seeds,
Expand All @@ -512,7 +512,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::optional<torch::Tensor> edge_offsets = torch::nullopt;

bool with_seed_offsets = seed_offsets.has_value();
bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1;
bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1 &&
Temporal == TemporalOption::NOT_TEMPORAL;

// Get the number of edge types. If it's homo or if the size of fanouts is 1
// (hetero graph but sampled as a homo graph), set num_etypes as 1.
Expand Down Expand Up @@ -584,24 +585,31 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;

const auto seed_type_id =
(hetero_with_seed_offsets)
? std::upper_bound(
seed_offsets->begin(), seed_offsets->end(),
i) -
seed_offsets->begin() - 1
: 0;
// `seed_index` indicates the index of the current
// seed within the group of seeds which have the same
// node type.
const auto seed_index =
(hetero_with_seed_offsets)
? i - seed_offsets->at(seed_type_id)
: i;
num_pick_fn(
offset, num_neighbors,
num_picked_neighbors_data_ptr + 1, seed_index,
etype_id_to_num_picked_offset);
if constexpr (Temporal == TemporalOption::TEMPORAL) {
num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0
? 0
: num_pick_fn(i, offset, num_neighbors);
} else {
const auto seed_type_id =
(hetero_with_seed_offsets)
? std::upper_bound(
seed_offsets->begin(),
seed_offsets->end(), i) -
seed_offsets->begin() - 1
: 0;
// `seed_index` indicates the index of the current
// seed within the group of seeds which have the same
// node type.
const auto seed_index =
(hetero_with_seed_offsets)
? i - seed_offsets->at(seed_type_id)
: i;
num_pick_fn(
offset, num_neighbors,
num_picked_neighbors_data_ptr + 1, seed_index,
etype_id_to_num_picked_offset);
}
}
});

Expand Down Expand Up @@ -684,16 +692,30 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
: i;

// Step 4. Pick neighbors for each node.
picked_number = pick_fn(
offset, num_neighbors, picked_eids_data_ptr,
seed_index, subgraph_indptr_data_ptr,
etype_id_to_num_picked_offset);
if (!hetero_with_seed_offsets) {
TORCH_CHECK(
num_picked_neighbors_data_ptr[i + 1] ==
picked_number,
"Actual picked count doesn't match the calculated "
"pick number.");
if constexpr (Temporal == TemporalOption::TEMPORAL) {
picked_number = num_picked_neighbors_data_ptr[i + 1];
auto picked_offset = subgraph_indptr_data_ptr[i];
if (picked_number > 0) {
auto actual_picked_count = pick_fn(
i, offset, num_neighbors,
picked_eids_data_ptr + picked_offset);
TORCH_CHECK(
actual_picked_count == picked_number,
"Actual picked count doesn't match the calculated"
" pick number.");
}
} else {
picked_number = pick_fn(
offset, num_neighbors, picked_eids_data_ptr,
seed_index, subgraph_indptr_data_ptr,
etype_id_to_num_picked_offset);
if (!hetero_with_seed_offsets) {
TORCH_CHECK(
num_picked_neighbors_data_ptr[i + 1] ==
picked_number,
"Actual picked count doesn't match the calculated"
" pick number.");
}
}

// Step 5. Calculate other attributes and return the
Expand Down Expand Up @@ -779,141 +801,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
subgraph_reverse_edge_ids, subgraph_type_per_edge, edge_offsets);
}

template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0);
const auto indptr_options = indptr_.options();
torch::Tensor num_picked_neighbors_per_node =
torch::empty({num_nodes + 1}, indptr_options);

// Calculate GrainSize for parallel_for.
// Set the default grain size to 64.
const int64_t grain_size = 64;
torch::Tensor picked_eids;
torch::Tensor subgraph_indptr;
torch::Tensor subgraph_indices;
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;

AT_DISPATCH_INDEX_TYPES(
indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
using indptr_t = index_t;
AT_DISPATCH_INDEX_TYPES(
nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] {
using nodes_t = index_t;
const auto indptr_data = indptr_.data_ptr<indptr_t>();
auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<indptr_t>();
num_picked_neighbors_data_ptr[0] = 0;
const auto nodes_data_ptr = nodes.data_ptr<nodes_t>();

// Step 1. Calculate pick number of each node.
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of "
"the "
"graph's node IDs.");
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;

num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0
? 0
: num_pick_fn(i, offset, num_neighbors);
}
});

// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr = num_picked_neighbors_per_node.cumsum(
0, indptr_.scalar_type());

// Step 3. Allocate the tensor for picked neighbors.
const auto total_length =
subgraph_indptr.data_ptr<indptr_t>()[num_nodes];
picked_eids = torch::empty({total_length}, indptr_options);
subgraph_indices =
torch::empty({total_length}, indices_.options());
if (type_per_edge_.has_value()) {
subgraph_type_per_edge = torch::empty(
{total_length}, type_per_edge_.value().options());
}

// Step 4. Pick neighbors for each node.
auto picked_eids_data_ptr = picked_eids.data_ptr<indptr_t>();
auto subgraph_indptr_data_ptr =
subgraph_indptr.data_ptr<indptr_t>();
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto picked_number =
num_picked_neighbors_data_ptr[i + 1];
const auto picked_offset = subgraph_indptr_data_ptr[i];
if (picked_number > 0) {
auto actual_picked_count = pick_fn(
i, offset, num_neighbors,
picked_eids_data_ptr + picked_offset);
TORCH_CHECK(
actual_picked_count == picked_number,
"Actual picked count doesn't match the calculated "
"pick "
"number.");

// Step 5. Calculate other attributes and return the
// subgraph.
AT_DISPATCH_INDEX_TYPES(
subgraph_indices.scalar_type(),
"IndexSelectSubgraphIndices", ([&] {
auto subgraph_indices_data_ptr =
subgraph_indices.data_ptr<index_t>();
auto indices_data_ptr =
indices_.data_ptr<index_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_indices_data_ptr[i] =
indices_data_ptr[picked_eids_data_ptr[i]];
}
}));
if (type_per_edge_.has_value()) {
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_type_per_edge.value().scalar_type(),
"IndexSelectTypePerEdge", ([&] {
auto subgraph_type_per_edge_data_ptr =
subgraph_type_per_edge.value()
.data_ptr<scalar_t>();
auto type_per_edge_data_ptr =
type_per_edge_.value().data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_type_per_edge_data_ptr[i] =
type_per_edge_data_ptr
[picked_eids_data_ptr[i]];
}
}));
}
}
}
});
}));
}));

torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);

return c10::make_intrusive<FusedSampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, subgraph_type_per_edge);
}

c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
Expand Down Expand Up @@ -969,7 +856,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl(
return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask,
Expand All @@ -990,7 +877,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
NumNodes()};
}
}();
return SampleNeighborsImpl(
return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask,
Expand All @@ -1001,7 +888,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets),
Expand All @@ -1019,6 +906,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
bool return_eids, torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const {
torch::optional<std::vector<int64_t>> seed_offsets = torch::nullopt;
// 1. Get probs_or_mask.
auto probs_or_mask = this->EdgeAttribute(probs_name);
if (probs_name.has_value()) {
Expand All @@ -1039,8 +927,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
return TemporalSampleNeighborsImpl(
input_nodes, return_eids,
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, seed_offsets, fanouts, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
Expand All @@ -1050,8 +938,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
edge_timestamp, args));
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return TemporalSampleNeighborsImpl(
input_nodes, return_eids,
return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, seed_offsets, fanouts, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
Expand Down

0 comments on commit f0213d2

Please sign in to comment.