[Model] GraphSAGE with control variate sampling on new sampler #1355
Conversation
Wondering why not start with a non-distributed version?
There's not much difference between a non-distributed version and a distributed one: the only differences are initializing a distributed context and synchronizing the gradients among GPUs.
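The two differences mentioned above can be sketched in plain PyTorch. This is a minimal illustration, not the PR's code; the function names, address, and port are placeholders:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def setup_distributed(rank, world_size, backend='nccl',
                      init_method='tcp://127.0.0.1:29500'):
    # Difference (1): initialize a distributed context so the per-GPU
    # worker processes can communicate.  Address/port are placeholders.
    dist.init_process_group(backend=backend, init_method=init_method,
                            world_size=world_size, rank=rank)

def wrap_model(model, dev_id=None):
    # Difference (2): DDP all-reduces gradients across workers during
    # backward(), so every replica steps with the same averaged gradient.
    device_ids = [dev_id] if dev_id is not None else None
    return DistributedDataParallel(model, device_ids=device_ids)
```

Everything else in the training loop (sampling, forward, loss, optimizer step) can stay identical between the single-GPU and multi-GPU versions.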
""" | ||
Copys features and labels of a set of nodes onto GPU. | ||
""" | ||
blocks[0].srcdata['features'] = g.ndata['features'][blocks[0].srcdata[dgl.NID]].to(dev_id) |
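The pattern in the line above, stripped of the DGL specifics, is: gather only the rows for the sampled input nodes on CPU, then copy that small slice to the GPU, rather than keeping the whole feature matrix on the device. A hypothetical standalone version (names assumed, not from the PR):

```python
import torch

def load_subtensor(feats, labels, input_nodes, output_nodes, dev_id):
    # Gather the rows for the sampled nodes on CPU, then move only that
    # slice to the target device, instead of the full feature matrix.
    batch_feats = feats[input_nodes].to(dev_id)
    batch_labels = labels[output_nodes].to(dev_id)
    return batch_feats, batch_labels
```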
BTW, the original paper plays a trick: it aggregates the features in the first layer in advance.
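The trick referred to here can be sketched as a one-time precomputation: since the input features never change, the first layer's neighbor aggregation can be done exactly, once, before training, so no sampling error is incurred at that layer. A minimal sketch, assuming a normalized sparse adjacency matrix is available (all names hypothetical):

```python
import torch

def preaggregate_first_layer(adj_norm, feats):
    # One-time exact aggregation of input features over the full
    # (normalized) adjacency; the result replaces the raw features as
    # the input to the first sampled layer.
    return torch.sparse.mm(adj_norm, feats)
```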
    else:
        assert isinstance(exception, Exception)
        raise exception.__class__(trace)
return decorated_function
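The fragment above is the tail of a decorator that runs the wrapped function in a separate thread and re-raises any exception, together with its traceback, in the caller. A self-contained sketch of what the whole decorator plausibly looks like (the name `thread_wrapped_func` is assumed from similar DGL examples, not confirmed by this diff):

```python
import threading
import traceback
from functools import wraps
from queue import Queue

def thread_wrapped_func(func):
    """Run ``func`` in a thread; surface its result or exception here."""
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = Queue()
        def _run():
            result, exception, trace = None, None, None
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                exception, trace = e, traceback.format_exc()
            queue.put((result, exception, trace))
        thread = threading.Thread(target=_run)
        thread.start()
        thread.join()
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            # Re-raise with the formatted traceback as the message so the
            # parent sees where the worker actually failed.
            raise exception.__class__(trace)
    return decorated_function
```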
can we keep this somewhere inside DGL?
Totally agree. Will put it in utils.
It is everywhere in our newly contributed examples already.
I'm now wondering if utils is a good place to put this function, since it depends on PyTorch multiprocessing module.
Why not follow the original example?
I thought using
hist_col = 'hist_%d' % (i + 1)
h_new = block.dstdata['h_new'].cpu()
g.ndata[hist_col][ids] = h_new
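The lines above write fresh activations back into a per-layer history store: control variate sampling keeps the last computed activation of every node, and after each minibatch the nodes just computed update their entries (on CPU) so later batches can use them as a baseline. A minimal standalone sketch of this update (names hypothetical):

```python
import torch

def update_history(history, ids, h_new):
    # Write the freshly computed activations back into the CPU-resident
    # history tensor for exactly the nodes computed in this minibatch.
    history[ids] = h_new.cpu()
```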
Will there be a data race issue?
The children don't share the same frame apparently.
LGTM
Other suggestions:
I'll leave them as is.
Description
This implements control variate sampling from https://arxiv.org/abs/1710.10568 with the new sampler interface.
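The key idea of control variate sampling in that paper: instead of estimating the neighbor aggregation from sampled neighbors alone, aggregate only the *change* of each neighbor's activation relative to a stored history over the sampled neighbors, and add the exact aggregation of the (cheap, already-stored) history. A minimal dense sketch of the estimator, not the PR's implementation (names and shapes assumed for illustration):

```python
import torch

def cv_aggregate(adj_sampled, h, adj_full, h_hist):
    # Sampled term: aggregates only the deviation from history, which
    # shrinks as training converges, reducing estimator variance.
    # Exact term: full-graph aggregation of the stored history, which is
    # precomputed and does not require the current layer's activations.
    return torch.mm(adj_sampled, h - h_hist) + torch.mm(adj_full, h_hist)
```

When the history matches the current activations, the estimator is exact regardless of which neighbors were sampled, which is where the variance reduction comes from.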
@zheng-da @lingfanyu Please take a look if you are interested.
TODO
Refactor the code with vanilla GraphSAGE and extract common logic. (Skipping this for now.)
Checklist
Please feel free to remove inapplicable items for your PR.
or have been fixed to be compatible with this change