Skip to content

Commit

Permalink
[Graphbolt]Fix negative sampler (#6933) (#6938)
Browse files Browse the repository at this point in the history
Co-authored-by: peizhou001 <110809584+peizhou001@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-218.ap-northeast-1.compute.internal>
  • Loading branch information
3 people committed Jan 11, 2024
1 parent c047950 commit 92c8f08
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 151 deletions.
26 changes: 0 additions & 26 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,32 +356,6 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const;

/**
* @brief Sample negative edges by randomly choosing negative
* source-destination pairs according to a uniform distribution. For each edge
* ``(u, v)``, it is supposed to generate `negative_ratio` pairs of negative
* edges ``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
* the graph.
*
* @param node_pairs A tuple of two 1D tensors that represent the source and
* destination of positive edges, with 'positive' indicating that these edges
* are present in the graph. It's important to note that within the context of
* a heterogeneous graph, the ids in these tensors signify heterogeneous ids.
* Only the source tensor is read; the destination tensor is ignored.
* @param negative_ratio The ratio of the number of negative samples to
* positive samples.
* @param max_node_id The exclusive upper bound of the node IDs to be
* selected, i.e. destinations are drawn from ``[0, max_node_id)``. It
* should correspond to the number of nodes of a specific type.
*
* @return A tuple consisting of two 1D tensors that represent the source and
* destination of negative edges, each holding
* ``num_positive_edges * negative_ratio`` elements and sharing the tensor
* options of the positive source tensor. The source tensor is the whole
* positive source tensor tiled `negative_ratio` times (``[u0, u1, u0, u1]``
* for two edges and a ratio of 2), not each source repeated consecutively.
* In the context of a heterogeneous graph, both the input nodes and the
* selected nodes are represented by heterogeneous IDs. Note that the sampled
* edges are not checked against the graph, so they may be false negatives,
* which means the edge could be present or not present in the graph.
*/
std::tuple<torch::Tensor, torch::Tensor> SampleNegativeEdgesUniform(
const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
int64_t negative_ratio, int64_t max_node_id) const;

/**
* @brief Copy the graph to shared memory.
* @param shared_memory_name The name of the shared memory.
Expand Down
12 changes: 0 additions & 12 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,18 +692,6 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
edge_timestamp));
}

/**
 * @brief Uniformly sample negative edges for a batch of positive edges.
 *
 * @param node_pairs Source and destination tensors of the positive edges;
 * only the source tensor is read.
 * @param negative_ratio Number of negative edges generated per positive edge.
 * @param max_node_id Exclusive upper bound of the sampled destination IDs.
 *
 * @return A tuple of (negative sources, negative destinations), each with
 * `num_positive_edges * negative_ratio` elements.
 */
std::tuple<torch::Tensor, torch::Tensor>
FusedCSCSamplingGraph::SampleNegativeEdgesUniform(
    const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
    int64_t negative_ratio, int64_t max_node_id) const {
  // The positive destinations play no role in uniform sampling.
  const torch::Tensor& positive_sources = std::get<0>(node_pairs);
  const int64_t num_negatives = positive_sources.size(0) * negative_ratio;

  // Tile the whole source tensor `negative_ratio` times.
  torch::Tensor negative_sources = positive_sources.repeat(negative_ratio);

  // Draw destinations from [0, max_node_id), reusing the tensor options of
  // the positive sources.
  torch::Tensor negative_destinations = torch::randint(
      0, max_node_id, {num_negatives}, positive_sources.options());

  return std::make_tuple(
      std::move(negative_sources), std::move(negative_destinations));
}

static c10::intrusive_ptr<FusedCSCSamplingGraph>
BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) {
helper.InitializeRead();
Expand Down
3 changes: 0 additions & 3 deletions graphbolt/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ TORCH_LIBRARY(graphbolt, m) {
.def(
"temporal_sample_neighbors",
&FusedCSCSamplingGraph::TemporalSampleNeighbors)
.def(
"sample_negative_edges_uniform",
&FusedCSCSamplingGraph::SampleNegativeEdgesUniform)
.def("copy_to_shared_memory", &FusedCSCSamplingGraph::CopyToSharedMemory)
.def_pickle(
// __getstate__
Expand Down
32 changes: 16 additions & 16 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,8 @@ def sample_negative_edges_uniform(
pairs according to a uniform distribution. For each edge ``(u, v)``,
it is supposed to generate `negative_ratio` pairs of negative edges
``(u, v')``, where ``v'`` is chosen uniformly from all the nodes in
the graph.
the graph. As ``u`` is exactly the same as in the corresponding positive
edges, it returns None for negative sources.
Parameters
----------
Expand All @@ -877,23 +878,22 @@ def sample_negative_edges_uniform(
`edge_type`. Note that negative refers to false negatives, which
means the edge could be present or not present in the graph.
"""
if edge_type is not None:
assert (
self.node_type_offset is not None
), "The 'node_type_offset' array is necessary for performing \
negative sampling by edge type."
_, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.node_type_to_id[dst_node_type]
max_node_id = (
self.node_type_offset[dst_node_type_id + 1]
- self.node_type_offset[dst_node_type_id]
)
if edge_type:
_, _, dst_ntype = etype_str_to_tuple(edge_type)
max_node_id = self.num_nodes[dst_ntype]
else:
max_node_id = self.total_num_nodes
return self._c_csc_graph.sample_negative_edges_uniform(
node_pairs,
negative_ratio,
max_node_id,
pos_src, _ = node_pairs
num_negative = pos_src.size(0) * negative_ratio
return (
None,
torch.randint(
0,
max_node_id,
(num_negative,),
dtype=pos_src.dtype,
device=pos_src.device,
),
)

def copy_to_shared_memory(self, shared_memory_name: str):
Expand Down
15 changes: 9 additions & 6 deletions python/dgl/graphbolt/impl/uniform_negative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,23 @@ class UniformNegativeSampler(NegativeSampler):
Examples
--------
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5])
>>> indices = torch.LongTensor([1, 2, 0, 2, 0])
>>> indptr = torch.LongTensor([0, 1, 2, 3, 4])
>>> indices = torch.LongTensor([1, 2, 3, 0])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> node_pairs = (torch.tensor([0, 1]), torch.tensor([1, 2]))
>>> node_pairs = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
>>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=1,)
... item_set, batch_size=4,)
>>> neg_sampler = gb.UniformNegativeSampler(
... item_sampler, graph, 2)
>>> for minibatch in neg_sampler:
... print(minibatch.negative_srcs)
... print(minibatch.negative_dsts)
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([1, 0, 0]))
(tensor([1, 1, 1]), tensor([2, 1, 2]), tensor([1, 0, 0]))
None
tensor([[2, 1],
[2, 1],
[3, 2],
[1, 3]])
"""

def __init__(
Expand Down
8 changes: 2 additions & 6 deletions tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def test_UniformNegativeSampler_invoke():
def _verify(negative_sampler):
for data in negative_sampler:
# Assertion
assert data.negative_srcs.size(0) == batch_size
assert data.negative_srcs.size(1) == negative_ratio
assert data.negative_srcs is None
assert data.negative_dsts.size(0) == batch_size
assert data.negative_dsts.size(1) == negative_ratio

Expand Down Expand Up @@ -90,12 +89,9 @@ def test_Uniform_NegativeSampler(negative_ratio):
# Assertion
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
assert len(neg_src) == batch_size
assert len(neg_dst) == batch_size
assert neg_src.numel() == batch_size * negative_ratio
assert neg_src is None
assert neg_dst.numel() == batch_size * negative_ratio
expected_src = pos_src.repeat(negative_ratio).view(-1, negative_ratio)
assert torch.equal(expected_src, neg_src)


def get_hetero_graph():
Expand Down

0 comments on commit 92c8f08

Please sign in to comment.