Message Passing with edge updates #20

Closed
tisabe opened this issue Aug 4, 2021 · 4 comments

tisabe commented Aug 4, 2021

Hi there,

I was looking to implement a message passing network with edge updates as described in https://arxiv.org/abs/1806.03146.
According to the Jraph paper, the per-edge messages M_t should be computed by the edge update function phi_e of the GraphNetwork in the model zoo. As I understand it, however, this prevents me from also implementing a function that only updates the edges based on the edge features and the sending and receiving nodes.

Is there a workaround using the current model zoo to separate edge updates from edge-wise messages, or is this a known limitation?

Thanks!

Contributor

jg8610 commented Aug 16, 2021

Hey! Thanks for waiting for the response.

I was just looking at the paper, and I believe the existing GraphNetwork should do what you want. It follows this pseudocode:

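# Edges are updated first, from their own features and their sender/receiver node features.
updated_edges = edge_update_fn(previous_edges, senders, receivers, globals_)
# Nodes then see the updated edges, aggregated over their outgoing and incoming edges.
updated_nodes = node_update_fn(nodes, updated_edges_senders, updated_edges_receivers, globals_)
return GraphsTuple(nodes=updated_nodes, edges=updated_edges)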

Is this different to what is described in the paper?

Author

tisabe commented Aug 16, 2021

Yes, this is also how I intended to write the network. However, to me it looks like the inputs of the node_update_fn are too limited to compute the node update as described in the paper.
In the paper, each node update depends on the node features themselves and an aggregation of the incoming messages.
The messages (defined for every edge) depend on edge features, sending and receiving nodes.

I'll try to put it into pseudocode, as it is a bit hard to describe:

messages = message_fn(edges, senders, receivers)
aggregate_messages_per_node = aggregate_message_fn(messages)
updated_nodes = node_update_fn(nodes, aggregate_messages_per_node)

I think this does not fit within the update_node_fn in GraphNetwork: that function only receives the aggregated edges, whereas the message function needs the individual edge features.
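The three steps in the pseudocode above can be sketched in plain JAX as follows. This is only an illustration and does not use jraph. The bodies of `message_fn` and `node_update_fn` are hypothetical placeholders rather than the functions from the paper, and summing over receivers is an assumed choice of aggregation.

```python
import jax
import jax.numpy as jnp


def message_fn(edges, sender_feats, receiver_feats):
    # Hypothetical per-edge message: any function of the edge feature and
    # the features of its two endpoint nodes would fit here.
    return edges * (sender_feats + receiver_feats)


def node_update_fn(nodes, aggregated_messages):
    # Hypothetical node update: add the aggregated incoming messages.
    return nodes + aggregated_messages


def message_passing_step(nodes, edges, senders, receivers):
    num_nodes = nodes.shape[0]
    # 1. One message per edge, computed from individual edge features.
    messages = message_fn(edges, nodes[senders], nodes[receivers])
    # 2. Aggregate the messages arriving at each node (sum over receivers).
    aggregated = jax.ops.segment_sum(messages, receivers, num_segments=num_nodes)
    # 3. Update each node from its own features and its aggregated messages.
    return node_update_fn(nodes, aggregated)


# Toy graph with three nodes and two directed edges: 0 -> 1 and 1 -> 2.
nodes = jnp.array([[1.0], [2.0], [3.0]])
edges = jnp.array([[1.0], [10.0]])
senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])

updated_nodes = message_passing_step(nodes, edges, senders, receivers)
```

Node 0 has no incoming edge here, so its aggregated message is zero and its features stay unchanged.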

Contributor

jg8610 commented Aug 16, 2021

Ah, thanks for clarifying!

I think the easiest way to accomplish this is by using structured messages (pseudocode):

# Initialize an edge message with zeros.
# 'latent' is what the paper calls the edge 'update'; renamed here to avoid confusion with jraph's update functions.
edge_message = {'message': jnp.zeros(...), 'latent': jnp.zeros(...)}

def update_edge_fn(edges, senders, receivers, globals_):
  # The 'update_fn' in the paper: refresh the persistent per-edge state.
  latent = update_latent_fn(edges['latent'], senders, receivers)
  # Compute the per-edge message from the refreshed latent.
  message = update_message_fn(latent, senders, receivers)
  return {'latent': latent, 'message': message}

In the node update function, just make sure to use only the 'message' entry, not the 'latent'.
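As a rough, self-contained sketch of this structured-edge idea, the snippet below applies the edge-then-node order by hand in plain JAX instead of calling jraph. The bodies of `update_latent_fn`, `update_message_fn` and `update_node_fn` are made-up placeholders, and `step` is a hypothetical helper; with jraph, the update functions would be passed to `GraphNetwork` and would also receive the globals argument.

```python
import jax
import jax.numpy as jnp


def update_latent_fn(latent, sender_feats, receiver_feats):
    # Hypothetical edge update (the persistent per-edge state).
    return latent + sender_feats * receiver_feats


def update_message_fn(latent, sender_feats, receiver_feats):
    # Hypothetical message computed from the refreshed latent.
    return latent * sender_feats


def update_edge_fn(edges, sender_feats, receiver_feats):
    latent = update_latent_fn(edges['latent'], sender_feats, receiver_feats)
    message = update_message_fn(latent, sender_feats, receiver_feats)
    # Return a new dict instead of mutating the input.
    return {'latent': latent, 'message': message}


def update_node_fn(nodes, received_edges):
    # Only the aggregated 'message' entry is used; 'latent' is ignored.
    return nodes + received_edges['message']


def step(nodes, edges, senders, receivers):
    # Edge update first, using the features of each edge's endpoint nodes.
    edges = update_edge_fn(edges, nodes[senders], nodes[receivers])
    # Sum every entry of the edge dict over the edges arriving at each node.
    received = jax.tree_util.tree_map(
        lambda e: jax.ops.segment_sum(e, receivers, num_segments=nodes.shape[0]),
        edges)
    return update_node_fn(nodes, received), edges


# Toy graph with three nodes and two directed edges: 0 -> 1 and 1 -> 2.
nodes = jnp.array([[1.0], [2.0], [3.0]])
senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])
edges = {'message': jnp.zeros((2, 1)), 'latent': jnp.zeros((2, 1))}

new_nodes, new_edges = step(nodes, edges, senders, receivers)
```

Because the edge dict is returned alongside the nodes, the 'latent' entry carries over to the next step while 'message' is recomputed each time.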

Author

tisabe commented Aug 17, 2021

Yes, this looks like it solves my problem nicely. Thanks!

tisabe closed this as completed Aug 17, 2021