Ideas on Message Passing #2

Closed · reshinthadithyan opened this issue Dec 9, 2020 · 1 comment

reshinthadithyan commented Dec 9, 2020

I) I'm looking into implementing the convolution operation specified in Graph Isomorphism Network, which transforms the node features with a set of dense layers. Can a haiku.Module object be called inside the node transformation function? If not, how should that be done?

node_update_fn = haiku.Seq(node_feature) + haiku.Seq(incoming_edge_feature)

jg8610 (Contributor) commented Dec 9, 2020

Hi there, thanks for your question.

This question is mainly about how Haiku works. The answer may differ if you switch to another NN framework like Flax, but the good news is that you can keep using Jraph 👍

Haiku nets must be wrapped in another function. For example, from their docs:

import haiku as hk
import jax
import jax.numpy as jnp

def loss_fn(images, labels):
  # Modules are constructed inside the function that will be transformed.
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  # softmax_cross_entropy is assumed to be defined elsewhere.
  return jnp.mean(softmax_cross_entropy(logits, labels))

So in your case, you would write your node_update_fn as a function that constructs the Haiku nets inside it:

def node_update_fn(nodes, incoming_edge_feature):
  node_net = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  edge_net = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  return node_net(nodes) + edge_net(incoming_edge_feature)

It's convenient to write the whole graph net inside a single function; that way you only need to apply Haiku's transform to the outermost function.

def forward_pass_graph_net(graph):
  net = jraph.GraphNetwork(update_node_fn=node_update_fn, ...)
  return net(graph)

You can then use hk.transform to transform your function into a pure function (with no side effects) so it can be used with jax:
forward_pass_graph_net_t = hk.transform(forward_pass_graph_net)
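
To make that concrete, here's a rough sketch (not from the thread; the graph and feature shapes are arbitrary placeholders) of initializing and applying the transformed function on a toy jraph.GraphsTuple:

import jax
import jax.numpy as jnp
import jraph

# Toy graph: 3 nodes, 2 directed edges; feature sizes chosen arbitrarily.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),
    edges=jnp.ones((2, 4)),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=jnp.ones((1, 4)),
)

rng = jax.random.PRNGKey(42)
# init builds the Haiku parameters; apply runs the pure forward pass.
params = forward_pass_graph_net_t.init(rng, graph)
out_graph = forward_pass_graph_net_t.apply(params, rng, graph)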

Note: if you are using a configured graph net, by default you will also receive global features, and features from the edges for which your node is a sender. So, just for completeness, you will have to handle those:

def node_update_fn(nodes, incoming_edge_feature, unused_outgoing_edge_feature, unused_global_feature):
  del unused_outgoing_edge_feature
  del unused_global_feature
  node_net = ...
  edge_net = ...
  return node_net(nodes) + edge_net(incoming_edge_feature)
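
And to tie this back to your GIN question: here's a rough sketch (my own illustration, not a definitive implementation) of a GIN-style node update, assuming the incoming edge features arrive already sum-aggregated, per the GIN paper's update h_v' = MLP((1 + eps) * h_v + sum of neighbour features). The layer widths and the learnable epsilon are hypothetical choices:

import haiku as hk
import jax
import jax.numpy as jnp

def gin_node_update_fn(nodes, incoming_edge_feature, unused_outgoing_edge_feature, unused_global_feature):
  del unused_outgoing_edge_feature
  del unused_global_feature
  # Learnable epsilon, initialized to zero (a common GIN choice).
  eps = hk.get_parameter("eps", shape=(), init=jnp.zeros)
  mlp = hk.Sequential([
      hk.Linear(64), jax.nn.relu,
      hk.Linear(64),
  ])
  return mlp((1.0 + eps) * nodes + incoming_edge_feature)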

Hope that's helpful! If you have any more questions, let me know.

jg8610 closed this as completed Dec 9, 2020