# readme examples

In [None]:
import jraph
import jax.numpy as jnp

# Build a graph with three nodes; every node carries one scalar (float) feature.
node_features = jnp.array([[0.0], [1.0], [2.0]])

# Optional edge attributes: one row per edge.
edges = jnp.array([[5.0], [6.0], [7.0]])

# Connectivity: a directed edge runs from every node to its successor
# (0 -> 1, 1 -> 2, 2 -> 0). Edge i goes from `senders[i]` (source node)
# to `receivers[i]` (destination node).
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])

# Node and edge counts are stored alongside the data. This information is
# what makes it possible to run GNNs over multiple graphs held in a single
# GraphsTuple.
n_node = jnp.array([3])
n_edge = jnp.array([3])

# Optional `global` information for the graph, such as a graph label.
global_context = jnp.array([[1]])

graph = jraph.GraphsTuple(
    n_node=n_node,
    n_edge=n_edge,
    nodes=node_features,
    edges=edges,
    senders=senders,
    receivers=receivers,
    globals=global_context,
)
graph

In [None]:
# A single GraphsTuple can hold more than one graph: `jraph.batch` combines
# the given graphs into one tuple.
two_graph_graphstuple = jraph.batch([graph, graph])
# Display the batched tuple itself rather than its bound `_replace` method.
two_graph_graphstuple

In [None]:
# Show the `n_node` field of a batch made of two copies of `graph`.
# NOTE(review): expected to list one node count per graph, i.e. [3, 3] — confirm.
jraph.batch([graph, graph]).n_node

In [None]:
# GraphNetwork


def update_edge_fn(edge, sender, receiver, globals):
    """Edge update e'_k = φ^e(e_k, v_{s_k}, v_{r_k}, u), here an identity map.

    Accepts the arguments in the order (edge, sender, receiver, globals) and
    returns `edge` untouched, so only the expected input/output signature is
    exercised.
    """
    # Placeholder implementation: everything except `edge` is unused.
    del sender, receiver, globals
    return edge


def update_node_fn(nodes, sent_aggregations, received_aggregations, global_attributes):
    """Node update, here an identity map that returns `nodes` unchanged.

    Paper form:
        v'_i = φ^v(ē'_i, v_i, u)

    jraph form (matches this function's argument order):
        v'_i = φ^v(v_i, ē'_i^out, ē'_i^in, u)
    """
    # Placeholder implementation: the aggregated edge inputs and the global
    # attributes are accepted only to honor the expected signature.
    del sent_aggregations, received_aggregations, global_attributes
    return nodes

In [None]:
# Assemble a GraphNetwork from the edge and node update functions; no other
# update functions are supplied.
net = jraph.GraphNetwork(
    update_edge_fn=update_edge_fn,
    update_node_fn=update_node_fn,
)

In [None]:
# Apply the network to `graph`; as the cell's last expression, the result is
# displayed by the notebook.
net(graph)