[PinSAGE sampler] Adjust the APIs for PinSAGESampler #3529
A simple performance test script is as follows:

    import dgl
    import torch
    import scipy.sparse  # the submodule must be imported explicitly
    import time

    if __name__ == '__main__':
        # Random sparse 6000 x 8000 matrix with density 0.3, used as the A-B adjacency.
        g = scipy.sparse.random(6000, 8000, 0.3)
        G = dgl.heterograph({
            ('A', 'AB', 'B'): g.nonzero(),
            ('B', 'BA', 'A'): g.T.nonzero()
        })
        print('edge size AB and BA: {}, {}'.format(G.num_edges('AB'), G.num_edges('BA')))
        sampler = dgl.sampling.PinSAGESampler(G, 'A', 'B', 3, 0.5, 200, 10)
        seeds = torch.LongTensor([i for i in range(1000)])
        repeat = 10
        print('graph edges: {}'.format(G.num_edges()))
        # Time `repeat` sampling rounds and report the average per round.
        t1 = time.time()
        for _ in range(repeat):
            frontier = sampler(seeds)
            print('frontier edges: {}'.format(frontier.num_edges()))
        time_cost = (time.time() - t1)
        # frontier.all_edges(form='uv')
        print('time cost: {:.5f}'.format(time_cost / repeat))
python/dgl/sampling/randomwalks.py
Outdated
@@ -209,6 +210,47 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
     eids = F.from_dgl_nd(eids)
     return (traces, eids, types) if return_eids else (traces, types)


 def randomwalk_topk(src, dst, num_samples_per_node, k):
     """Select the top-k nodes in src for each node in dst(dedup).
From what I understand, this is fusing to_simple, select_topk, and counting the number of occurrences together. The docstring did not show that. Could you rewrite the docstring? Also, I don't think randomwalk_topk is a suitable name since this function is not related to random walk.
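For readers following the thread, the fused behavior described in this comment can be illustrated with a minimal pure-Python sketch. It is not the code under review and not a DGL API: the function name `topk_by_count`, its return layout, and the tie-breaking rule (smaller source ID first) are assumptions made purely for illustration.

```python
from collections import Counter, defaultdict


def topk_by_count(src, dst, k):
    """Hypothetical reference: for each dst node, collapse duplicate
    (src, dst) pairs, count how often each src occurs, and keep the
    k most frequent src nodes together with their counts."""
    counts = defaultdict(Counter)
    for s, d in zip(src, dst):
        counts[d][s] += 1  # deduplication and counting in a single pass

    out_src, out_dst, out_cnt = [], [], []
    for d in sorted(counts):
        # Highest count first; ties go to the smaller src ID (an assumption).
        ranked = sorted(counts[d].items(), key=lambda kv: (-kv[1], kv[0]))
        for s, c in ranked[:k]:
            out_src.append(s)
            out_dst.append(d)
            out_cnt.append(c)
    return out_src, out_dst, out_cnt


# Node 0 is reached from node 1 twice, node 2 once, and node 3 three times.
src = [1, 1, 2, 3, 3, 3, 4]
dst = [0, 0, 0, 0, 0, 0, 5]
print(topk_by_count(src, dst, k=2))
```

With `k=2`, node 0 keeps sources 3 and 1 (counts 3 and 2) and drops source 2, while node 5 keeps its only source.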
Hello, thank you for your reply. I have updated the docstring as you requested. I originally wrote this code to optimize the random walk algorithm, so I simply named the API "randomwalk_topk". I tried to think of another name but could not come up with a suitable one, because of my poor English. Could you suggest a name?
I think this function is only used for PinSAGE. I'll change it to an internal function instead.
Description
Fuse "to_simple", "select_topk" and "gather_row" into a single API, "randomwalk_topk", in class RandomWalkNeighborSampler. The new API is up to about 2X faster than the three old APIs used together.
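To make the three stages concrete, here is a rough pure-Python sketch of the staged pipeline that the fused API replaces. The helpers `dedup_edges`, `pick_topk_edges`, and `gather` are hypothetical stand-ins that only mimic the role of each stage; they are not the DGL functions and do not share their signatures, and the tie-breaking rule (lower edge ID first) is an assumption.

```python
def dedup_edges(src, dst):
    """Stage 1 (stand-in for the to_simple step): unique edges plus multiplicity."""
    counts = {}
    for edge in zip(src, dst):
        counts[edge] = counts.get(edge, 0) + 1
    edges = sorted(counts)
    return ([s for s, _ in edges],
            [d for _, d in edges],
            [counts[e] for e in edges])


def pick_topk_edges(dst, weights, k):
    """Stage 2 (stand-in for the select_topk step): IDs of the k heaviest
    in-edges of every destination node."""
    by_dst = {}
    for eid, d in enumerate(dst):
        by_dst.setdefault(d, []).append(eid)
    picked = []
    for d in sorted(by_dst):
        # Heaviest edge first; ties go to the lower edge ID (an assumption).
        ranked = sorted(by_dst[d], key=lambda e: (-weights[e], e))
        picked.extend(ranked[:k])
    return picked


def gather(values, index):
    """Stage 3 (stand-in for the gather_row step): select rows by index."""
    return [values[i] for i in index]


src = [1, 1, 2, 3, 3, 3, 4]
dst = [0, 0, 0, 0, 0, 0, 5]
u, v, w = dedup_edges(src, dst)        # pass 1: builds a full intermediate edge list
eids = pick_topk_edges(v, w, k=2)      # pass 2: ranks the intermediate edges
result = (gather(u, eids), gather(v, eids), gather(w, eids))  # pass 3
print(result)
```

Each stage makes its own pass and materializes an intermediate list. That is one plausible source of the overhead a fused implementation avoids, although the sketch says nothing about where the measured speedup actually comes from.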
Changes