# jraph.examples.basic

In [None]:
import jax
import jax.numpy as jnp
import jraph

In [None]:
# Building GraphsTuples by hand.

# A GraphsTuple holding exactly one graph, built from scratch:
#   - 3 nodes, each with a 4-dimensional feature vector,
#   - 2 edges (0 -> 2 and 1 -> 2), each with a 5-dimensional feature vector,
#   - one 6-dimensional global (graph-level) feature vector.
single_graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),
    edges=jnp.ones((2, 5)),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([2, 2]),
    globals=jnp.ones((1, 6)),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
)

single_graph

# NOTE:
# In each feature array, row i is the feature vector of node (or edge) i, and
# each column is one dimension of that feature vector.

In [None]:
# A single graph whose features are nested containers rather than bare arrays.
# Same topology as before: 3 nodes and 2 edges.
# Features may be any nesting of dicts, lists and tuples, or any other type
# registered with jax.tree_util.register_pytree_node.
nested_graph = jraph.GraphsTuple(
    nodes={"a": jnp.ones((3, 4))},
    edges={"b": jnp.ones((2, 5))},
    globals={"c": jnp.ones((1, 6))},
    senders=jnp.asarray([0, 1]),
    receivers=jnp.asarray([2, 2]),
    n_node=jnp.asarray([3]),
    n_edge=jnp.asarray([2]),
)

nested_graph

# NOTE:
# Every node carries the same feature key: nodes["a"] has shape (3, 4), i.e.
# one 4-dimensional vector per node. Likewise edges["b"] holds one
# 5-dimensional vector per edge and globals["c"] one 6-dimensional vector for
# the graph.

In [None]:
# Two graphs packed into one GraphsTuple along an implicit batch dimension:
# their nodes, edges and globals are simply concatenated.
#   - graph 0: 3 nodes, 2 edges
#   - graph 1: 1 node, 1 edge
# Feature sizes: 4 per node, 5 per edge, 6 per graph (globals).
implicitly_batched_graph = jraph.GraphsTuple(
    nodes=jnp.ones((4, 4)),
    edges=jnp.ones((3, 5)),
    globals=jnp.ones((2, 6)),
    senders=jnp.asarray([0, 1, 3]),
    receivers=jnp.asarray([2, 2, 3]),
    n_node=jnp.asarray([3, 1]),
    n_edge=jnp.asarray([2, 1]),
)

implicitly_batched_graph

# NOTE:
# senders/receivers use node indices that are global across the batch, so the
# second graph's only node is index 3. Its single edge (3 -> 3) is a self-loop.

In [None]:
# Batching graphs can be challenging. There are in general two approaches:
# 1. Implicit batching: Independent graphs are combined into the same
#    GraphsTuple first, and the padding is added to the combined graph.
# 2. Explicit batching: Pad all graphs to a maximum size, stack them together
#    using an explicit batch dimension followed by jax.vmap.
# Both approaches are shown below.

# --- Implicit batching ------------------------------------------------------
# Creates a GraphsTuple from two existing GraphsTuples using an implicit batch
# dimension. `single_graph` holds one graph and `implicitly_batched_graph`
# holds two, so the result holds three.
# The result gets its own name instead of overwriting
# `implicitly_batched_graph`. Overwriting it would make this cell unsafe to
# re-run: a second run would batch the 3-graph result again into 4 graphs,
# and the 3-way unpacking below would raise a ValueError.
batched_graph = jraph.batch([single_graph, implicitly_batched_graph])

# Creates one GraphsTuple per graph from a GraphsTuple with an implicit
# batch dimension.
graph_1, graph_2, graph_3 = jraph.unbatch(batched_graph)

print("=" * 5, "Implicit Batching", "=" * 5)
print(batched_graph)
print(graph_1)
print(graph_2)
print(graph_3)


# --- Padding ----------------------------------------------------------------
# Creates a padded GraphsTuple from an existing GraphsTuple.
# The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
# Three graphs are added for the padding: first a dummy graph which contains
# the padding nodes and edges, and then two empty graphs without nodes or
# edges to pad out the graph count.
padded_graph = jraph.pad_with_graphs(single_graph, n_node=10, n_edge=5, n_graph=4)

# Creates a GraphsTuple from an existing padded GraphsTuple.
# The previously added padding is removed. A separate name is used so that
# `single_graph` from the earlier cell is left untouched.
unpadded_graph = jraph.unpad_with_graphs(padded_graph)

print("\n", "=" * 5, "Padding", "=" * 5)
print(padded_graph)
print(unpadded_graph)


# --- Explicit batching ------------------------------------------------------
# Pads a second, smaller graph (graph_3: 1 node, 1 edge) to the same fixed
# size as `padded_graph`, so every leaf of the two GraphsTuples has the same
# shape.
padded_graph_3 = jraph.pad_with_graphs(graph_3, n_node=10, n_edge=5, n_graph=4)

# Stacks corresponding leaves along a new leading axis, which gives an
# explicit batch dimension of size 2 (e.g. nodes: (10, 4) -> (2, 10, 4)).
explicitly_batched_graph = jax.tree_util.tree_map(
    lambda *leaves: jnp.stack(leaves), padded_graph, padded_graph_3
)

# jax.vmap maps a per-graph function over the explicit batch dimension.
# jraph pads node features with zeros, so each sum covers only the real
# nodes (expected here: [12., 4.]).
node_feature_sums = jax.vmap(lambda graph: jnp.sum(graph.nodes))(
    explicitly_batched_graph
)

print("\n", "=" * 5, "Explicit Batching", "=" * 5)
print(explicitly_batched_graph)
print(node_feature_sums)