using pmap with GraphsTuple #10
Hi Sooheon, I'm sorry I haven't gotten to this sooner, I had somehow missed it. Good question! You would have one GraphsTuple with pmappable features. In pseudocode this would be something like:
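A minimal sketch of that idea, assuming every graph has already been padded to identical shapes (e.g. via jraph.pad_with_graphs); net_fn is a placeholder for your model's apply function:

import jax
import jax.numpy as jnp
import jraph

def stack_for_pmap(padded_graphs):
  # padded_graphs: one GraphsTuple per device, all with identical shapes.
  # Stacking every leaf adds a leading num_devices axis, which is the
  # layout jax.pmap expects.
  return jax.tree_map(lambda *leaves: jnp.stack(leaves), *padded_graphs)

# device_graphs = stack_for_pmap(padded_graphs)
# outputs = jax.pmap(net_fn)(device_graphs)  # net_fn is hypothetical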
Cool, what I was thinking was along the lines of what you wrote, except the zero axis is the "n_gpus" dim, and each GPU receives a batched GraphsTuple that's big enough to saturate the memory. Basically the same, but each element is already batched with:

from typing import List

import jax.numpy as jnp
import jraph

def _nearest_bigger_power_of_two(x: int) -> int:
  # Same helper as in jraph's OGB example.
  y = 2
  while y < x:
    y *= 2
  return y

def pad_batch_of_graphs_to_nearest_power_of_two(
    batch: List[jraph.GraphsTuple]) -> List[jraph.GraphsTuple]:
  # Find the largest padded size needed by any graph in the batch, so
  # that every padded GraphsTuple ends up with identical shapes.
  pad_nodes_to = 0
  pad_edges_to = 0
  pad_graphs_to = 0
  for graph in batch:
    # +1 accounts for the extra padding graph added by pad_with_graphs.
    n = _nearest_bigger_power_of_two(jnp.sum(graph.n_node)) + 1
    e = _nearest_bigger_power_of_two(jnp.sum(graph.n_edge))
    ng = graph.n_node.shape[0] + 1
    pad_nodes_to = max(pad_nodes_to, n)
    pad_edges_to = max(pad_edges_to, e)
    pad_graphs_to = max(pad_graphs_to, ng)
  return [jraph.pad_with_graphs(g, pad_nodes_to, pad_edges_to, pad_graphs_to)
          for g in batch]
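The returned batches then all share identical shapes, so they can be stacked along a leading device axis and pmapped; a usage sketch, where train_batches, update_fn, and params are placeholder names:

import jax

# train_batches: one batched GraphsTuple per device.
padded = pad_batch_of_graphs_to_nearest_power_of_two(train_batches)
device_batch = jax.tree_map(lambda *x: jnp.stack(x), *padded)
# in_axes=(None, 0): replicate params, split the batch across devices.
losses = jax.pmap(update_fn, in_axes=(None, 0))(params, device_batch)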
Hey, thanks for this! I think the best place for this would probably be in an example, what do you think? It may be that at a later date we add it to the final library, but it's not clear to me that we have the final answer for the right way to pad - e.g. I'm not sure that padding to the nearest power of two is always best.
Yeah, I could make an extension to one of your OGB examples with working pmap when I find some time.
@sooheon Did you manage to create that example with OGB?
Hey there, I wanted to follow up with another suggestion for batching. The first step is to define a budget: the maximum numbers of nodes, edges, and graphs allowed in a batch. Then, dynamic batching follows the pseudocode below:
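A minimal sketch of that loop, assuming a budget object with n_node, n_edge, and n_graph fields (jraph's dynamically_batch utility implements this idea; the sketch is illustrative rather than the library code):

import jax.numpy as jnp
import jraph

def dynamically_batch(graphs, budget):
  # Assumes the budget leaves room for the padding graph; graphs larger
  # than the budget are not handled in this sketch.
  batch, num_nodes, num_edges, num_graphs = [], 0, 0, 0
  for graph in graphs:
    n = int(jnp.sum(graph.n_node))
    e = int(jnp.sum(graph.n_edge))
    g = graph.n_node.shape[0]
    # If this graph would blow the budget, pad and yield the current
    # batch, then start a new one.
    if batch and (num_nodes + n >= budget.n_node or
                  num_edges + e > budget.n_edge or
                  num_graphs + g >= budget.n_graph):
      yield jraph.pad_with_graphs(jraph.batch(batch), budget.n_node,
                                  budget.n_edge, budget.n_graph)
      batch, num_nodes, num_edges, num_graphs = [], 0, 0, 0
    batch.append(graph)
    num_nodes += n
    num_edges += e
    num_graphs += g
  if batch:  # flush the final partial batch
    yield jraph.pad_with_graphs(jraph.batch(batch), budget.n_node,
                                budget.n_edge, budget.n_graph)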
Once the budget is hit, the batch is padded and returned to the user. To use this with pmap you would do something like:
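A hedged sketch, where dataset, budget, params, and update_fn are placeholders:

import jax
import jax.numpy as jnp

batcher = dynamically_batch(dataset, budget)
num_devices = jax.local_device_count()
p_update = jax.pmap(update_fn, in_axes=(None, 0))  # replicate params

# Every padded batch has identical shapes, so num_devices of them can be
# stacked into a leading device axis and fed to the pmapped function.
per_device = [next(batcher) for _ in range(num_devices)]
device_batch = jax.tree_map(lambda *x: jnp.stack(x), *per_device)
losses = p_update(params, device_batch)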
Looking at some other code that uses pmap for multi-GPU, pmapped functions want an additional leading axis representing the GPUs.
GraphsTuple is just an ordered tuple of feature arrays and metadata, where the feature arrays are rank 2 and the metadata is rank 1. Am I understanding correctly that the way to do multi-GPU learning here is to have rank-3 and rank-2 arrays inside GraphsTuple, like the following?
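A minimal sketch of those shapes, with all sizes chosen purely for illustration:

import jax.numpy as jnp
import jraph

# 8 devices; per device: 128 padded nodes, 512 padded edges, 16 graphs.
device_batch = jraph.GraphsTuple(
    nodes=jnp.zeros((8, 128, 64)),             # rank 3: [devices, nodes, feats]
    edges=jnp.zeros((8, 512, 32)),             # rank 3: [devices, edges, feats]
    senders=jnp.zeros((8, 512), jnp.int32),    # rank 2: [devices, edges]
    receivers=jnp.zeros((8, 512), jnp.int32),  # rank 2: [devices, edges]
    n_node=jnp.zeros((8, 16), jnp.int32),      # rank 2: [devices, graphs]
    n_edge=jnp.zeros((8, 16), jnp.int32),      # rank 2: [devices, graphs]
    globals=jnp.zeros((8, 16, 8)))             # rank 3: [devices, graphs, feats]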