In this notebook, I will be checking if my batching implementation is up to snuff. Essentially, I hope to see that I have fewer batches than windows

In [1]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

In [2]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts
from utils.lorenz import load_lorenz96_2coupled

import numpy as np
import jax.numpy as jnp


In [17]:
import ml_collections

config = ml_collections.ConfigDict()

# Data params. 
config.n_samples=22
config.input_steps=1
config.output_delay=0 # predict 0 hours into the future
config.output_steps=2
config.timestep_duration=3
config.sample_buffer=-1 * (config.input_steps + config.output_delay + config.output_steps - 1) # negative buffer so that our sample input are continuous (i.e. the first sample would overlap a bit with consecutive samples) 
config.time_resolution=120
config.init_buffer_samples=0
config.train_pct=0.7
config.val_pct=0.2
config.test_pct=0.1
config.K=6
config.F=8
config.c=10
config.b=10
config.h=1
config.seed=42
config.normalize=True
config.fully_connected_edges=False

# Optimizer.
config.optimizer = 'adam'
config.learning_rate = 1e-3

# Training hyperparameters.
config.batch_size = 2
config.epochs = 50
config.log_every_epochs = 5
config.eval_every_epochs = 5
config.checkpoint_every_epochs = 10
config.max_checkpts_to_keep = 2 # None means keep all checkpoints

# GNN hyperparameters.
config.model = 'MLPGraphNetwork'
config.n_blocks = 1
config.activation = 'relu'
config.dropout_rate = 0.1
config.skip_connections = False # This was throwing a broadcast error in add_graphs_tuples_nodes when this was set to True
config.layer_norm = False # TODO perhaps we want to turn on later
config.edge_features = (4, 8) # the last feature size will be the number of features that the graph predicts
config.node_features = (32, 2)
config.global_features = None
config.share_params = False

In [4]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts

In [34]:
# generate desired dataset with train/val split and subsampled windows
graph_tuple_dict, batched_graph_tuple_dict = get_lorenz_graph_tuples(
    n_samples=config.n_samples,
    input_steps=config.input_steps,
    output_delay=config.output_delay,
    output_steps=config.output_steps,
    timestep_duration=config.timestep_duration,
    sample_buffer=config.sample_buffer,
    time_resolution=config.time_resolution,
    init_buffer_samples=config.init_buffer_samples,
    train_pct=config.train_pct,
    val_pct=config.val_pct,
    test_pct=config.test_pct,
    K=config.K,
    F=config.F,
    c=config.c,
    b=config.b,
    h=config.h,
    seed=config.seed,
    normalize=config.normalize,
    fully_connected_edges=config.fully_connected_edges,
    batch_size=config.batch_size)

ValueError: All input arrays must have the same shape.

In [6]:
train_dataset = graph_tuple_dict['train']
val_dataset = graph_tuple_dict['val']
test_dataset = graph_tuple_dict['test']

train_inputs = train_dataset['inputs']
train_targets = train_dataset['targets']
val_inputs = val_dataset['inputs']
val_targets = val_dataset['targets']
test_inputs = test_dataset['inputs']
test_targets = test_dataset['targets']

sample_input_window = train_inputs[0]
sample_target_window = train_targets[0]
sample_graph = sample_input_window[0]

print("graph_tuple_dict type:", type(graph_tuple_dict))
print("graph_tuple_dict keys:", graph_tuple_dict.keys())
print("graph_tuple_dict value type:", type(train_dataset))
print("train dataset keys:", train_dataset.keys())
print("train dataset value type:", type(train_inputs))

print("size of train inputs:", len(train_inputs))
print("size of train targets:", len(train_targets))
print("size of val inputs:", len(val_inputs))
print("size of val targets:", len(val_targets))
print("size of test inputs:", len(test_inputs))
print("size of test targets:", len(test_targets))

print("train inputs window type:", type(sample_input_window))
print("train input window size (i.e. input steps per window):", len(sample_input_window))
print("element type in window:", type(sample_graph))
print("Edges shape:", sample_graph.edges.shape)

graph_tuple_dict type: <class 'dict'>
graph_tuple_dict keys: dict_keys(['train', 'val', 'test'])
graph_tuple_dict value type: <class 'dict'>
train dataset keys: dict_keys(['inputs', 'targets'])
train dataset value type: <class 'list'>
size of train inputs: 14
size of train targets: 14
size of val inputs: 4
size of val targets: 4
size of test inputs: 2
size of test targets: 2
train inputs window type: <class 'list'>
train input window size (i.e. input steps per window): 1
element type in window: <class 'jraph._src.graph.GraphsTuple'>
Edges shape: (30, 1)


In [7]:
print(type(train_inputs[0]))
print(len(train_inputs))
print(type(train_targets[0]))
print(len(train_targets))
print(type(train_targets[0][0]))

<class 'list'>
14
<class 'list'>
14
<class 'jraph._src.graph.GraphsTuple'>


In [8]:
train_dataset = batched_graph_tuple_dict['train']
val_dataset = batched_graph_tuple_dict['val']
test_dataset = batched_graph_tuple_dict['test']

batched_train_inputs = train_dataset['inputs']
train_targets = train_dataset['targets']
val_inputs = val_dataset['inputs']
val_targets = val_dataset['targets']
test_inputs = test_dataset['inputs']
test_targets = test_dataset['targets']

sample_input_window = train_inputs[0]
sample_target_window = train_targets[0]
batched_sample_graph = sample_input_window[0]

print("graph_tuple_dict type:", type(batched_graph_tuple_dict))
print("graph_tuple_dict keys:", batched_graph_tuple_dict.keys())
print("graph_tuple_dict value type:", type(train_dataset))
print("train dataset keys:", train_dataset.keys())
print("train dataset value type:", type(train_inputs))

print("size of train inputs:", len(batched_train_inputs))
print("size of train targets:", len(train_targets))
print("size of val inputs:", len(val_inputs))
print("size of val targets:", len(val_targets))
print("size of test inputs:", len(test_inputs))
print("size of test targets:", len(test_targets))

print("train inputs window type:", type(sample_input_window))
print("train input window size (i.e. input steps per window):", len(sample_input_window))
print("element type in window:", type(batched_sample_graph))

graph_tuple_dict type: <class 'dict'>
graph_tuple_dict keys: dict_keys(['train', 'val', 'test'])
graph_tuple_dict value type: <class 'dict'>
train dataset keys: dict_keys(['inputs', 'targets'])
train dataset value type: <class 'list'>
size of train inputs: 7
size of train targets: 7
size of val inputs: 2
size of val targets: 2
size of test inputs: 1
size of test targets: 1
train inputs window type: <class 'list'>
train input window size (i.e. input steps per window): 1
element type in window: <class 'jraph._src.graph.GraphsTuple'>


In [9]:
print(type(train_inputs[0]))
print(len(train_inputs))
print(type(train_inputs[0]))
print(len(train_inputs))
print(type(train_inputs[0][0]))

<class 'list'>
14
<class 'list'>
14
<class 'jraph._src.graph.GraphsTuple'>


In [10]:
print(train_inputs)

[[GraphsTuple(nodes=array([[ 1.4310255 ,  0.9362091 ],
       [-0.23781072,  0.9620181 ],
       [-0.2907959 , -0.59176767],
       [-1.3872254 , -1.2863369 ],
       [-1.3582085 , -1.94962   ],
       [ 0.79842323, -0.67540556]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6]

In [11]:
print(batched_train_inputs)

[GraphsTuple(nodes=Array([[ 1.4310255 ,  0.9362091 ],
       [-0.23781072,  0.9620181 ],
       [-0.2907959 , -0.59176767],
       [-1.3872254 , -1.2863369 ],
       [-1.3582085 , -1.94962   ],
       [ 0.79842323, -0.67540556],
       [ 1.5526141 ,  1.2326598 ],
       [-0.3102697 , -0.6572716 ],
       [-0.305259  , -1.565111  ],
       [-1.3175128 , -0.90410674],
       [-1.289803  , -1.7428546 ],
       [ 0.71848804, -1.2394744 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
      

ok so we have the batching data structure mostly working. lets see if this runs! at all!

In [12]:
from utils.jraph_training import train_and_evaluate_with_data
from utils.jraph_vis import plot_predictions

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
# set up logging
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [14]:
workdir="tests/outputs/batch_test"

trained_state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(
    config=config, workdir=workdir, datasets=batched_graph_tuple_dict)


INFO:absl:Hyperparameters: {'F': 8, 'K': 6, 'activation': 'relu', 'b': 10, 'batch_size': 2, 'c': 10, 'checkpoint_every_epochs': 10, 'dropout_rate': 0.1, 'edge_features': (4, 8), 'epochs': 50, 'eval_every_epochs': 5, 'fully_connected_edges': False, 'global_features': None, 'h': 1, 'init_buffer_samples': 0, 'input_steps': 1, 'layer_norm': False, 'learning_rate': 0.001, 'log_every_epochs': 5, 'max_checkpts_to_keep': 2, 'model': 'MLPGraphNetwork', 'n_blocks': 1, 'n_samples': 20, 'node_features': (32, 2), 'normalize': True, 'optimizer': 'adam', 'output_delay': 0, 'output_steps': 2, 'sample_buffer': -2, 'seed': 42, 'share_params': False, 'skip_connections': False, 'test_pct': 0.1, 'time_resolution': 120, 'timestep_duration': 3, 'train_pct': 0.7, 'val_pct': 0.2}
INFO:absl:Initializing network.


ValueError: too many values to unpack (expected 7)

unbatched:

In [None]:
[[GraphsTuple(nodes=array([[ 1.4310255 ,  0.9362091 ],
       [-0.23781072,  0.9620181 ],
       [-0.2907959 , -0.59176767],
       [-1.3872254 , -1.2863369 ],
       [-1.3582085 , -1.94962   ],
       [ 0.79842323, -0.67540556]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.5526141 ,  1.2326598 ],
       [-0.3102697 , -0.6572716 ],
       [-0.305259  , -1.565111  ],
       [-1.3175128 , -0.90410674],
       [-1.289803  , -1.7428546 ],
       [ 0.71848804, -1.2394744 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.6542437 ,  1.2344846 ],
       [-0.3747552 , -1.9412924 ],
       [-0.30857062, -1.285931  ],
       [-1.2460263 , -0.45623457],
       [-1.2167777 , -1.9310918 ],
       [ 0.6529863 , -0.9085993 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.7392533 ,  0.95759296],
       [-0.43252218, -2.1865675 ],
       [-0.30307558, -0.85924834],
       [-1.1735647 ,  0.06826954],
       [-1.1398444 , -1.8501254 ],
       [ 0.60270166, -0.25738627]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.810329  ,  0.53846467],
       [-0.48556253, -1.7714826 ],
       [-0.2893304 , -0.5998158 ],
       [-1.100758  ,  0.6552498 ],
       [-1.0604134 , -1.5036267 ],
       [ 0.56787527,  0.23045582]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.8690554 ,  0.2939899 ],
       [-0.5350527 , -1.1591498 ],
       [-0.26760244, -0.44173434],
       [-1.0276515 ,  1.0719203 ],
       [-0.97968096, -0.9424964 ],
       [ 0.548275  ,  0.52994066]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.9160509 ,  0.32571945],
       [-0.58140254, -0.6382894 ],
       [-0.23829897, -0.37629175],
       [-0.95356774,  1.1534646 ],
       [-0.89854074, -0.2603695 ],
       [ 0.54331297,  0.75573635]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.9516468 ,  0.53705376],
       [-0.62466705, -0.30256033],
       [-0.20179231, -0.43128425],
       [-0.8774405 ,  0.8398546 ],
       [-0.81745327,  0.42124888],
       [ 0.55214036,  1.0080566 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.9761665 ,  0.78660697],
       [-0.66475624, -0.23715739],
       [-0.15871377, -0.3898767 ],
       [-0.7979403 ,  0.08430386],
       [-0.73636436,  0.8898245 ],
       [ 0.5737296 ,  1.4118342 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.990026  ,  0.92309517],
       [-0.70164853, -0.2702189 ],
       [-0.11059661,  0.0703419 ],
       [-0.7139699 , -0.6550989 ],
       [-0.65410423,  0.55338085],
       [ 0.6069867 ,  1.8810234 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.9931473 ,  1.1659274 ],
       [-0.73643506, -0.01060382],
       [-0.05860712,  0.30287394],
       [-0.6262841 , -0.47436485],
       [-0.56845033, -0.392648  ],
       [ 0.6509435 ,  1.8739427 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.9845006 ,  1.5737634 ],
       [-0.7700825 ,  0.23049435],
       [-0.00250266,  0.3548358 ],
       [-0.53610665,  0.11151384],
       [-0.4796957 , -0.6321924 ],
       [ 0.70462346,  1.2958279 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.963262  ,  1.9060342 ],
       [-0.8029301 ,  0.5273168 ],
       [ 0.05697121,  0.5734809 ],
       [-0.4424307 ,  0.31367725],
       [-0.3893648 , -0.4921234 ],
       [ 0.7666247 ,  0.5045918 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=array([[ 1.9291928 ,  1.833312  ],
       [-0.83551645,  0.9811772 ],
       [ 0.11901677,  0.7879613 ],
       [-0.3434139 ,  0.20988221],
       [-0.29762927, -0.25796196],
       [ 0.8351705 , -0.37164098]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))]]


batched:

In [None]:
[[GraphsTuple(nodes=Array([[ 1.4310255 ,  0.9362091 ],
       [-0.23781072,  0.9620181 ],
       [-0.2907959 , -0.59176767],
       [-1.3872254 , -1.2863369 ],
       [-1.3582085 , -1.94962   ],
       [ 0.79842323, -0.67540556]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32)), GraphsTuple(nodes=Array([[ 1.5526141 ,  1.2326598 ],
       [-0.3102697 , -0.6572716 ],
       [-0.305259  , -1.565111  ],
       [-1.3175128 , -0.90410674],
       [-1.289803  , -1.7428546 ],
       [ 0.71848804, -1.2394744 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=Array([[ 1.6542437 ,  1.2344846 ],
       [-0.3747552 , -1.9412924 ],
       [-0.30857062, -1.285931  ],
       [-1.2460263 , -0.45623457],
       [-1.2167777 , -1.9310918 ],
       [ 0.6529863 , -0.9085993 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32)), GraphsTuple(nodes=Array([[ 1.7392533 ,  0.95759296],
       [-0.43252218, -2.1865675 ],
       [-0.30307558, -0.85924834],
       [-1.1735647 ,  0.06826954],
       [-1.1398444 , -1.8501254 ],
       [ 0.60270166, -0.25738627]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=Array([[ 1.810329  ,  0.53846467],
       [-0.48556253, -1.7714826 ],
       [-0.2893304 , -0.5998158 ],
       [-1.100758  ,  0.6552498 ],
       [-1.0604134 , -1.5036267 ],
       [ 0.56787527,  0.23045582]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32)), GraphsTuple(nodes=Array([[ 1.8690554 ,  0.2939899 ],
       [-0.5350527 , -1.1591498 ],
       [-0.26760244, -0.44173434],
       [-1.0276515 ,  1.0719203 ],
       [-0.97968096, -0.9424964 ],
       [ 0.548275  ,  0.52994066]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=Array([[ 1.9160509 ,  0.32571945],
       [-0.58140254, -0.6382894 ],
       [-0.23829897, -0.37629175],
       [-0.95356774,  1.1534646 ],
       [-0.89854074, -0.2603695 ],
       [ 0.54331297,  0.75573635]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32)), GraphsTuple(nodes=Array([[ 1.9516468 ,  0.53705376],
       [-0.62466705, -0.30256033],
       [-0.20179231, -0.43128425],
       [-0.8774405 ,  0.8398546 ],
       [-0.81745327,  0.42124888],
       [ 0.55214036,  1.0080566 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=Array([[ 1.9761665 ,  0.78660697],
       [-0.66475624, -0.23715739],
       [-0.15871377, -0.3898767 ],
       [-0.7979403 ,  0.08430386],
       [-0.73636436,  0.8898245 ],
       [ 0.5737296 ,  1.4118342 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32)), GraphsTuple(nodes=Array([[ 1.990026  ,  0.92309517],
       [-0.70164853, -0.2702189 ],
       [-0.11059661,  0.0703419 ],
       [-0.7139699 , -0.6550989 ],
       [-0.65410423,  0.55338085],
       [ 0.6069867 ,  1.8810234 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=Array([[ 1.9931473 ,  1.1659274 ],
       [-0.73643506, -0.01060382],
       [-0.05860712,  0.30287394],
       [-0.6262841 , -0.47436485],
       [-0.56845033, -0.392648  ],
       [ 0.6509435 ,  1.8739427 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32)), GraphsTuple(nodes=Array([[ 1.9845006 ,  1.5737634 ],
       [-0.7700825 ,  0.23049435],
       [-0.00250266,  0.3548358 ],
       [-0.53610665,  0.11151384],
       [-0.4796957 , -0.6321924 ],
       [ 0.70462346,  1.2958279 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))], [GraphsTuple(nodes=Array([[ 1.963262  ,  1.9060342 ],
       [-0.8029301 ,  0.5273168 ],
       [ 0.05697121,  0.5734809 ],
       [-0.4424307 ,  0.31367725],
       [-0.3893648 , -0.4921234 ],
       [ 0.7666247 ,  0.5045918 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32)), GraphsTuple(nodes=Array([[ 1.9291928 ,  1.833312  ],
       [-0.83551645,  0.9811772 ],
       [ 0.11901677,  0.7879613 ],
       [-0.3434139 ,  0.20988221],
       [-0.29762927, -0.25796196],
       [ 0.8351705 , -0.37164098]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6], dtype=int32), n_edge=Array([30], dtype=int32))]]


In [None]:
plot_predictions(
    config=config,
    workdir=workdir, # for loading checkpoints 
    plot_ith_rollout_step=0, # 0 indexed # for this study, we have a 4-step rollout 
    # dataset,
    # preds,
    # timestep_duration,
    # n_rollout_steps,
    #  total_steps,
    node=0, # 0-indexed 
    plot_mode="val", # i.e. "train"/"val"/"test"
    plot_days=60,
    title="Val Predictions for node 0, rollout step 0"
)

ok so the training isn't working that well
we need to see what the 