A simple example adapted from the jraph repository: building a single `GraphsTuple` and batching two copies of it.

In [1]:
import jraph
import jax.numpy as jnp

# Define a three-node graph; each node has a single float feature
# (one row per node, one column per feature).
node_features = jnp.array([[0.], [1.], [2.]])

# We will construct a graph for which there is a directed edge between each node
# and its successor, wrapping around from the last node back to the first
# (0 -> 1, 1 -> 2, 2 -> 0). We define this with `senders` (source nodes) and
# `receivers` (destination nodes).
senders = jnp.array([0, 1, 2])
receivers = jnp.array([1, 2, 0])

# You can optionally add edge attributes (one row per edge).
edges = jnp.array([[5.], [6.], [7.]])

# We then save the number of nodes and the number of edges.
# This information is used to make running GNNs over multiple graphs
# in a GraphsTuple possible. The counts are derived from the arrays above
# so they cannot silently drift out of sync with the data.
n_node = jnp.array([node_features.shape[0]])
n_edge = jnp.array([senders.shape[0]])

# Sanity checks: every edge has both endpoints and one attribute row,
# and every endpoint indexes an existing node.
assert senders.shape == receivers.shape
assert edges.shape[0] == senders.shape[0]
assert bool(jnp.all(senders < node_features.shape[0]))
assert bool(jnp.all(receivers < node_features.shape[0]))

# Optionally you can add `global` information, such as a graph label.
# Note: this literal is an integer array (int32 in the displayed repr),
# unlike the float node/edge features.
global_context = jnp.array([[1]])
graph = jraph.GraphsTuple(
    nodes=node_features,
    senders=senders,
    receivers=receivers,
    edges=edges,
    n_node=n_node,
    n_edge=n_edge,
    globals=global_context,
)

In [2]:
# Display the GraphsTuple; its repr lists every field together with its dtype.
graph

GraphsTuple(nodes=DeviceArray([[0.],
             [1.],
             [2.]], dtype=float32), edges=DeviceArray([[5.],
             [6.],
             [7.]], dtype=float32), receivers=DeviceArray([1, 2, 0], dtype=int32), senders=DeviceArray([0, 1, 2], dtype=int32), globals=DeviceArray([[1]], dtype=int32), n_node=DeviceArray([3], dtype=int32), n_edge=DeviceArray([3], dtype=int32))

In [3]:
# Batch two copies of `graph` into one GraphsTuple. jraph.batch concatenates
# the per-graph arrays, so the result holds both graphs side by side.
two_graph_graphstuple = jraph.batch([graph, graph])

# Invariant checks: the node count doubles and `n_node` keeps one entry per graph.
assert two_graph_graphstuple.nodes.shape[0] == 2 * graph.nodes.shape[0]
assert two_graph_graphstuple.n_node.tolist() == graph.n_node.tolist() * 2

In [4]:
# Reuse the batch built in the previous cell instead of calling jraph.batch
# again; the node features of the two copies are stacked along the first axis.
two_graph_graphstuple.nodes

DeviceArray([[0.],
             [1.],
             [2.],
             [0.],
             [1.],
             [2.]], dtype=float32)