stack version of batching

In [36]:
import jraph
import jax
import jax.numpy as jnp

# Create consistent GraphsTuples
def create_graph(nodes, edges, receivers, senders, globals):
    return jraph.GraphsTuple(
        nodes=jnp.array(nodes),
        edges=jnp.array(edges),
        receivers=jnp.array(receivers),
        senders=jnp.array(senders),
        n_node=jnp.array([len(nodes)]),
        n_edge=jnp.array([len(edges)]),
        globals=jnp.array(globals)
    )

GraphTuple1 = create_graph(
    nodes=[[1.0, 2.0]],
    edges=[[1.0], [2.0]],
    receivers=[0, 1],
    senders=[1, 0],
    globals=[1.0]
)

GraphTuple2 = create_graph(
    nodes=[[3.0, 4.0]],
    edges=[[3.0], [4.0]],
    receivers=[0, 1],
    senders=[1, 0],
    globals=[2.0]
)

GraphTuple3 = create_graph(
    nodes=[[5.0, 6.0]],
    edges=[[5.0], [6.0]],
    receivers=[0, 1],
    senders=[1, 0],
    globals=[3.0]
)

GraphTuple4 = create_graph(
    nodes=[[7.0, 8.0]],
    edges=[[7.0], [8.0]],
    receivers=[0, 1],
    senders=[1, 0],
    globals=[4.0]
)



In [37]:
# Create a window containing two GraphsTuples
window1 = [GraphTuple1, GraphTuple2]

window2 = [GraphTuple3, GraphTuple4]

# Batch the window
batched_window1 = jraph.batch(window1)
batched_window2 = jraph.batch(window2)


In [38]:
# now, batch the two windows together
window_list = [batched_window1, batched_window2]
batched_graph = jraph.batch(window_list)

In [41]:
def model(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    # Example operation: multiply node features by 2
    updated_nodes = graph.nodes + 10
    return graph._replace(nodes=updated_nodes)

In [43]:
# Create a list of batched graphs (assuming multiple batches)
batched_graphs_list = [batched_graph, batched_graph]  # Just repeating the same for demonstration

# Define a batched version of the model function
vectorized_model = jax.vmap(model)

# Since `vmap` expects batched input, stack the batched graphs
stacked_batched_graphs = jax.tree_map(lambda *args: jnp.stack(args), *batched_graphs_list)

# Apply the vectorized model to the stacked batched graphs
output_graphs = vectorized_model(stacked_batched_graphs)

print(stacked_batched_graphs)
print(output_graphs)

GraphsTuple(nodes=Array([[[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]],

       [[1., 2.],
        [3., 4.],
        [5., 6.],
        [7., 8.]]], dtype=float32), edges=Array([[[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.]],

       [[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.]]], dtype=float32), receivers=Array([[0, 1, 1, 2, 2, 3, 3, 4],
       [0, 1, 1, 2, 2, 3, 3, 4]], dtype=int32), senders=Array([[1, 0, 2, 1, 3, 2, 4, 3],
       [1, 0, 2, 1, 3, 2, 4, 3]], dtype=int32), globals=Array([[1., 2., 3., 4.],
       [1., 2., 3., 4.]], dtype=float32), n_node=Array([[1, 1, 1, 1],
       [1, 1, 1, 1]], dtype=int32), n_edge=Array([[2, 2, 2, 2],
       [2, 2, 2, 2]], dtype=int32))
GraphsTuple(nodes=Array([[[11., 12.],
        [13., 14.],
        [15., 16.],
        [17., 18.]],

       [[11., 12.],
        [13., 14.],
        [15., 16.],
        [17., 18.]]], dty