
using pmap with GraphsTuple #10

Closed
sooheon opened this issue Jan 30, 2021 · 6 comments

Comments


sooheon commented Jan 30, 2021

Looking at some other code that uses pmap for multi-GPU training, pmapped functions expect an additional leading axis representing the GPUs.

GraphsTuple is just an ordered tuple of feature arrays and metadata, where the feature arrays are rank 2 and the metadata is rank 1. Am I understanding correctly that the way to do multi-GPU learning here is to have rank-3 feature arrays and rank-2 metadata inside the GraphsTuple, like the following?

  1. take BS * N_GPU graphs from data source
  2. make N_GPU batches via jraph.batch
  3. pad_graphs_to_nearest_power_of_two (a variant that handles multiple inputs) => now you have equally shaped arrays
  4. create one GraphsTuple with pmappable (rank 3) features
  5. use graphnets with pmapped functions
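
Roughly, what I have in mind (just a sketch; dataset, batch_size, graph_net and the multi-input pad_graphs_to_nearest_power_of_two are placeholders I haven't written yet):

import jax
import jax.numpy as jnp
import jraph

n_gpus = jax.local_device_count()

# 1. take BS * N_GPU graphs from the data source
graphs = [next(dataset) for _ in range(batch_size * n_gpus)]

# 2. make N_GPU batches via jraph.batch
device_batches = [jraph.batch(graphs[i * batch_size:(i + 1) * batch_size])
                  for i in range(n_gpus)]

# 3. pad all batches to a single common shape (placeholder helper)
device_batches = pad_graphs_to_nearest_power_of_two(device_batches)

# 4. one GraphsTuple whose feature leaves are rank 3 (leading axis = n_gpus)
pmap_batch = jax.tree_map(lambda *x: jnp.stack(x, axis=0), *device_batches)

# 5. run the pmapped graph net
outputs = jax.pmap(graph_net)(pmap_batch)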

jg8610 commented Mar 3, 2021

Hi Sooheon, I'm sorry I haven't got to this sooner; I had somehow missed it.

Good question! You would have one GraphsTuple with pmappable features.

In pseudocode this would be something like:

# Pull batch_size graphs from the dataset iterator.
batch = []
for _ in range(batch_size):
  batch.append(next(dataset))

# At the moment we have N GraphsTuples. We pad each to the same common size.
batch = [pad(el) for el in batch]

# Stack leaf-wise into a single GraphsTuple with a leading batch axis.
batch = jax.tree_multimap(lambda *x: jnp.stack(x, axis=0), *batch)
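
Once stacked, every leaf of batch has a leading axis of size batch_size, so if batch_size matches your local device count you could feed it straight to pmap (graph_net below is a placeholder for your network's apply function):

# Sketch only: each leaf of batch now has shape [batch_size, ...].
outputs = jax.pmap(graph_net)(batch)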

sooheon closed this as completed Mar 5, 2021

sooheon commented Mar 5, 2021

Cool, what I was thinking was along the lines of what you wrote, except that the zeroth axis is the "n_gpus" dim and each GPU receives a batched GraphsTuple that's big enough to saturate its memory. Basically the same, but each element is already batched with jraph.batch. I wrote the following to pad multiple batches at once to a common size; would you like a PR for this?

from typing import List

import jax.numpy as jnp
import jraph

# _nearest_bigger_power_of_two is the helper from the jraph OGB example.
def pad_batch_of_graphs_to_nearest_power_of_two(
    batch: List[jraph.GraphsTuple]) -> List[jraph.GraphsTuple]:
  """Pads every GraphsTuple in the batch to one common size."""
  pad_nodes_to = 0
  pad_edges_to = 0
  pad_graphs_to = 0
  for graph in batch:
    # +1 node and +1 graph leave room for the padding graph added by pad_with_graphs.
    n = _nearest_bigger_power_of_two(jnp.sum(graph.n_node)) + 1
    e = _nearest_bigger_power_of_two(jnp.sum(graph.n_edge))
    ng = graph.n_node.shape[0] + 1
    pad_nodes_to = max(pad_nodes_to, n)
    pad_edges_to = max(pad_edges_to, e)
    pad_graphs_to = max(pad_graphs_to, ng)
  return [jraph.pad_with_graphs(g, pad_nodes_to, pad_edges_to, pad_graphs_to)
          for g in batch]
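
For example (a sketch; g1 and g2 stand for two already-batched GraphsTuples of different sizes), the padded outputs can then be stacked leaf-wise because their shapes match:

import jax

# Both padded GraphsTuples share the same node/edge/graph counts, so stacking works.
padded = pad_batch_of_graphs_to_nearest_power_of_two([g1, g2])
stacked = jax.tree_map(lambda *x: jnp.stack(x, axis=0), *padded)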


jg8610 commented Mar 8, 2021

Hey,

Thanks for this! I think the best place for it would probably be in an example - what do you think?

It may be that at a later date we add it to the library itself, but it's not clear to me that we have the final answer on the right way to pad - e.g. I'm not sure that padding to the nearest power of two is always best.


sooheon commented Mar 9, 2021

Yeah, I could make an extension to one of your OGB examples with a working pmap setup when I find some time.

engmubarak48 commented

> Yeah, I could make an extension to one of your OGB examples with a working pmap setup when I find some time.

@sooheon Did you manage to create that example with OGB?


jg8610 commented May 12, 2022

Hey there,

I wanted to follow up with another suggestion for pmapping, one we use internally: jraph.dynamically_batch.

The first step is to define a budget for your batch. The budget determines the maximum number of nodes, edges and graphs in your batch, and should be set to at least the size of the largest graph in your dataset.

Then, dynamic batching follows this pseudocode:

batch = []
while budget_not_hit:
  batch.append(next(dataset))
batch = jraph.batch(batch)

Once the budget is hit, the batch is padded and returned to the user.

To use this with pmap you would do something like:

batched_dataset = jraph.dynamically_batch(
    dataset_iterator, n_node=10, n_edge=10, n_graph=2)

# For use with pmap: take one padded batch per device...
pmap_batch = []
for _ in range(num_devices):
  pmap_batch.append(next(batched_dataset))

# ...then stack leaf-wise so every array gets a leading axis of size num_devices.
pmap_batch = jax.tree_map(lambda *x: jnp.stack(x, axis=0), *pmap_batch)

jax.pmap(graph_net)(pmap_batch)
