In this notebook, I will be checking if my batching implementation is up to snuff. Essentially, I hope to see that I have fewer batches than windows

In [45]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts
from utils.lorenz import load_lorenz96_2coupled

import numpy as np
import jax.numpy as jnp


In [47]:
import ml_collections

config = ml_collections.ConfigDict()

# Data params. 
config.n_samples=500
config.input_steps=1
config.output_delay=0 # predict 0 hours into the future
config.output_steps=2
config.timestep_duration=3
config.sample_buffer=-1 * (config.input_steps + config.output_delay + config.output_steps - 1) # negative buffer so that our sample input are continuous (i.e. the first sample would overlap a bit with consecutive samples) 
config.time_resolution=120
config.init_buffer_samples=0
config.train_pct=0.7
config.val_pct=0.2
config.test_pct=0.1
config.K=36
config.F=8
config.c=10
config.b=10
config.h=1
config.seed=42
config.normalize=True
config.fully_connected_edges=False

# Optimizer.
config.optimizer = 'adam'
config.learning_rate = 1e-3

# Training hyperparameters.
config.batch_size = 8
config.epochs = 50
config.log_every_epochs = 5
config.eval_every_epochs = 5
config.checkpoint_every_epochs = 10
config.max_checkpts_to_keep = 2 # None means keep all checkpoints

# GNN hyperparameters.
config.model = 'MLPGraphNetwork'
config.n_blocks = 1
config.activation = 'relu'
config.dropout_rate = 0.1
config.skip_connections = False # This was throwing a broadcast error in add_graphs_tuples_nodes when this was set to True
config.layer_norm = False # TODO perhaps we want to turn on later
config.edge_features = (4, 8) # the last feature size will be the number of features that the graph predicts
config.node_features = (32, 2)
config.global_features = None
config.share_params = False

In [48]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts

In [114]:
# generate desired dataset with train/val split and subsampled windows
graph_tuple_dict, batched_graph_tuple_dict = get_lorenz_graph_tuples(
    n_samples=config.n_samples,
    input_steps=config.input_steps,
    output_delay=config.output_delay,
    output_steps=config.output_steps,
    timestep_duration=config.timestep_duration,
    sample_buffer=config.sample_buffer,
    time_resolution=config.time_resolution,
    init_buffer_samples=config.init_buffer_samples,
    train_pct=config.train_pct,
    val_pct=config.val_pct,
    test_pct=config.test_pct,
    K=config.K,
    F=config.F,
    c=config.c,
    b=config.b,
    h=config.h,
    seed=config.seed,
    normalize=config.normalize,
    fully_connected_edges=config.fully_connected_edges,
    batch_size=config.batch_size)

len curr batch inputs: 8
[[GraphsTuple(nodes=array([[ 0.99839735, -0.15272506],
       [ 0.5373525 , -0.5844919 ],
       [-0.7494299 ,  0.15804267],
       [-0.15865315,  0.88742316],
       [ 1.192225  , -0.6378343 ],
       [-0.8070752 , -0.9933975 ],
       [-0.9008007 , -0.74469674],
       [-1.0211302 , -0.65930045],
       [ 0.8171373 , -0.25954586],
       [ 1.1780626 ,  0.83491665],
       [-0.7134528 ,  0.4888478 ],
       [ 0.7419451 , -0.56325823],
       [-0.3684877 , -1.0183238 ],
       [-1.7734832 ,  0.10879993],
       [-0.39404944, -1.5159448 ],
       [ 0.8239101 ,  0.12712447],
       [ 0.8791535 ,  0.75964206],
       [-1.0998439 , -1.0779039 ],
       [-1.3527946 , -1.7397071 ],
       [-0.05148394,  0.61863536],
       [ 0.22290471, -0.1617617 ],
       [ 2.381793  ,  2.5122936 ],
       [-0.9449286 ,  1.1262387 ],
       [-0.11906508,  0.61627793],
       [-0.25794533,  0.23470293],
       [ 0.66963   , -1.0680293 ],
       [ 1.5758544 ,  1.440244  ],
       [-0

In [112]:
train_dataset = graph_tuple_dict['train']
val_dataset = graph_tuple_dict['val']
test_dataset = graph_tuple_dict['test']

train_inputs = train_dataset['inputs']
train_targets = train_dataset['targets']
val_inputs = val_dataset['inputs']
val_targets = val_dataset['targets']
test_inputs = test_dataset['inputs']
test_targets = test_dataset['targets']

sample_input_window = train_inputs[0]
sample_target_window = train_targets[0]
sample_graph = sample_input_window[0]

print("graph_tuple_dict type:", type(graph_tuple_dict))
print("graph_tuple_dict keys:", graph_tuple_dict.keys())
print("graph_tuple_dict value type:", type(train_dataset))
print("train dataset keys:", train_dataset.keys())
print("train dataset value type:", type(train_inputs))

print("size of train inputs:", len(train_inputs))
print("size of train targets:", len(train_targets))
print("size of val inputs:", len(val_inputs))
print("size of val targets:", len(val_targets))
print("size of test inputs:", len(test_inputs))
print("size of test targets:", len(test_targets))

print("train inputs window type:", type(sample_input_window))
print("train input window size (i.e. input steps per window):", len(sample_input_window))
print("element type in window:", type(sample_graph))
print("Edges shape:", sample_graph.edges.shape)

graph_tuple_dict type: <class 'dict'>
graph_tuple_dict keys: dict_keys(['train', 'val', 'test'])
graph_tuple_dict value type: <class 'dict'>
train dataset keys: dict_keys(['inputs', 'targets'])
train dataset value type: <class 'list'>
size of train inputs: 350
size of train targets: 350
size of val inputs: 100
size of val targets: 100
size of test inputs: 50
size of test targets: 50
train inputs window type: <class 'list'>
train input window size (i.e. input steps per window): 1
element type in window: <class 'jraph._src.graph.GraphsTuple'>
Edges shape: (180, 1)


In [113]:
print(type(train_inputs[0]))
print(len(train_inputs))
print(type(train_targets[0]))
print(len(train_targets))
print(type(train_targets[0][0]))

<class 'list'>
350
<class 'list'>
350
<class 'jraph._src.graph.GraphsTuple'>


In [72]:
train_dataset = batched_graph_tuple_dict['train']
val_dataset = batched_graph_tuple_dict['val']
test_dataset = batched_graph_tuple_dict['test']

batched_train_inputs = train_dataset['inputs']
train_targets = train_dataset['targets']
val_inputs = val_dataset['inputs']
val_targets = val_dataset['targets']
batched_test_inputs = test_dataset['inputs']
test_targets = test_dataset['targets']

sample_input_window = train_inputs[0]
sample_target_window = train_targets[0]
batched_sample_graph = sample_input_window[0]

print("graph_tuple_dict type:", type(batched_graph_tuple_dict))
print("graph_tuple_dict keys:", batched_graph_tuple_dict.keys())
print("graph_tuple_dict value type:", type(train_dataset))
print("train dataset keys:", train_dataset.keys())
print("train dataset value type:", type(batched_train_inputs))

# this is the number of batches.
print("size of train inputs batches:", len(batched_train_inputs))
print("size of train targets batches:", len(train_targets))
print("size of val inputs batches:", len(val_inputs))
print("size of val targets batches:", len(val_targets))
print("size of test inputs batches:", len(batched_test_inputs))
print("size of test targets batches:", len(test_targets))

# this is the number of items in the batch. all will be batch_size.
print("train dataset value type:", type(batched_train_inputs[0]))
print("num items of train inputs axis:", len(batched_train_inputs[0]))
print("num tems of train targets axis:", len(train_targets[0]))
print("num items of size of val inputs axis:", len(val_inputs[0]))
print("num items of val targets axis:", len(val_targets[0]))
print("num items of of test inputs axis:", len(batched_test_inputs[0]))
print("num items of test targets axis:", len(test_targets[0]))

# inside each batch, the data is organized by GraphsTuple
# each has 7 items because the GraphsTuple data structure has 7 parameters (nodes, edges, receivers, senders, globals, n_node, n_edge
print("train dataset value type:", type(batched_train_inputs[0][0]))
print("size of train inputs:", len(batched_train_inputs[0][0]))
print("size of train targets:", len(train_targets[0][0]))
print("size of val inputs:", len(val_inputs[0][0]))
print("size of val targets:", len(val_targets[0][0]))
print("size of test inputs:", len(batched_test_inputs[0][0]))
print("size of test targets:", len(test_targets[0][0]))

print("train inputs window type:", type(sample_input_window))
print("train input window size (i.e. input steps per window):", len(sample_input_window))
print("element type in window:", type(batched_sample_graph))

graph_tuple_dict type: <class 'dict'>
graph_tuple_dict keys: dict_keys(['train', 'val', 'test'])
graph_tuple_dict value type: <class 'dict'>
train dataset keys: dict_keys(['inputs', 'targets'])
train dataset value type: <class 'list'>
size of train inputs batches: 8
size of train targets batches: 8
size of val inputs batches: 8
size of val targets batches: 8
size of test inputs batches: 8
size of test targets batches: 8
train dataset value type: <class 'jraph._src.graph.GraphsTuple'>
num items of train inputs axis: 7
num tems of train targets axis: 7
num items of size of val inputs axis: 7
num items of val targets axis: 7
num items of of test inputs axis: 7
num items of test targets axis: 7
train dataset value type: <class 'jaxlib.xla_extension.ArrayImpl'>
size of train inputs: 43
size of train targets: 43
size of val inputs: 12
size of val targets: 12
size of test inputs: 6
size of test targets: 6
train inputs window type: <class 'list'>
train input window size (i.e. input steps per w

In [73]:
print(batched_test_inputs)

[GraphsTuple(nodes=Array([[[ 2.44008154e-01, -5.91096044e-01],
        [ 9.05892432e-01,  1.49847835e-01],
        [ 1.89256680e+00,  2.00257063e+00],
        [ 4.64071810e-01,  6.33024096e-01],
        [-4.45582598e-01, -1.37226716e-01],
        [ 1.22530067e+00, -1.03552401e+00],
        [ 1.10967588e+00,  1.04645276e+00],
        [-6.24471724e-01,  1.78544152e+00],
        [ 1.08447897e+00, -4.83662605e-01],
        [-2.52050847e-01, -9.41360712e-01],
        [-1.09418738e+00, -1.62594184e-01],
        [-4.80061352e-01, -1.39728928e+00],
        [ 1.43129957e+00, -9.49485227e-02],
        [-9.33846176e-01,  1.04941106e+00],
        [-2.91001469e-01,  4.21531588e-01],
        [-2.46066302e-01, -5.80154538e-01],
        [ 4.61132646e-01,  8.70635033e-01],
        [ 9.17534828e-01, -5.21893024e-01],
        [-1.49767399e+00, -1.15308225e+00],
        [-3.54688317e-01, -1.60847807e+00],
        [-2.90655047e-01,  4.58534390e-01],
        [ 3.15749735e-01,  1.41874626e-01],
        [ 9.5

recall that we called jnp.stack() the batched data. because of this, the way we previously iterated through to train (seen below) will no longer work. solution: vmap!

In [74]:
print(test_inputs)

[[GraphsTuple(nodes=array([[ 0.24400815, -0.59109604],
       [ 0.90589243,  0.14984784],
       [ 1.8925668 ,  2.0025706 ],
       [ 0.4640718 ,  0.6330241 ],
       [-0.4455826 , -0.13722672],
       [ 1.2253007 , -1.035524  ],
       [ 1.1096759 ,  1.0464528 ],
       [-0.6244717 ,  1.7854415 ],
       [ 1.084479  , -0.4836626 ],
       [-0.25205085, -0.9413607 ],
       [-1.0941874 , -0.16259418],
       [-0.48006135, -1.3972893 ],
       [ 1.4312996 , -0.09494852],
       [-0.9338462 ,  1.049411  ],
       [-0.29100147,  0.4215316 ],
       [-0.2460663 , -0.58015454],
       [ 0.46113265,  0.87063503],
       [ 0.9175348 , -0.521893  ],
       [-1.497674  , -1.1530823 ],
       [-0.35468832, -1.6084781 ],
       [-0.29065505,  0.4585344 ],
       [ 0.31574973,  0.14187463],
       [ 0.95250857,  1.2174339 ],
       [-1.1355298 ,  0.09053281],
       [ 0.17716452, -0.1107817 ],
       [-0.2955976 , -2.0366735 ],
       [ 0.08295456,  0.47286272],
       [ 2.0425375 ,  0.12580045],


In [75]:
print(batched_train_inputs)

[GraphsTuple(nodes=Array([[[ 0.99839735, -0.15272506],
        [ 0.5373525 , -0.5844919 ],
        [-0.7494299 ,  0.15804267],
        ...,
        [ 1.1945474 ,  2.1708732 ],
        [ 0.33699018, -0.5155333 ],
        [-0.89399034, -0.94752365]],

       [[ 1.1363362 ,  0.05768512],
        [ 0.34467852,  0.54176795],
        [-1.1669252 , -0.271508  ],
        ...,
        [-1.1560173 , -0.37617916],
        [-0.5561475 , -0.49043205],
        [-0.43386835, -1.0037299 ]],

       [[ 1.0348173 ,  0.41358483],
        [-0.9960111 , -0.9060984 ],
        [-0.98625994, -0.17224056],
        ...,
        [-1.7019933 , -1.0716041 ],
        [-0.2950277 , -0.9655564 ],
        [ 0.77597743,  2.0876803 ]],

       ...,

       [[ 0.5896881 , -0.4687768 ],
        [ 0.9684779 ,  1.0085859 ],
        [-0.6375074 , -0.7186262 ],
        ...,
        [-1.0262748 , -1.6156013 ],
        [-1.4303272 , -0.7703537 ],
        [-1.6185899 , -1.2360436 ]],

       [[-0.16479048, -1.4331877 ],
        

lets see if this data structure works with a basic vmapped function

In [84]:
import jax
import jraph

def model(graph):
    print("graph happens now")
    print("len graph:", len(graph)) # a batch passed in with vmap has a length of batch_size (yay)
    print("type graph:", type(graph[0])) # each graph in the batch is a GraphsTuple. yay
    print("graph nodes shape:", graph[0][0].shape) # these are the nodes!
    stacked_graph = jax.tree_map(lambda *args: jnp.stack(args), *graph)
    def window_model(window):
        print("window happens now")
        print("len graph:", len(window)) # a window passed in with vmap should contain a GraphsTuple, of length 7
        print("type graph:", type(window[0])) # each graph in the batch is a GraphsTuple. yay
        total = 42
        return total
    vectorized_window = jax.vmap(window_model, in_axes=0, out_axes=0)
    sum = vectorized_window(stacked_graph)
    return sum

# Define a batched version of the model function
vectorized_model = jax.vmap(model, in_axes=0, out_axes=0)

stacked_batched_graph = batched_graph_tuple_dict['test']['inputs']

# Apply the vectorized model to the stacked batched graphs
num = vectorized_model(stacked_batched_graph)
print(num) # one 42 for every batch!


graph happens now
len graph: 8
type graph: <class 'jraph._src.graph.GraphsTuple'>
graph nodes shape: (36, 2)
window happens now
len graph: 7
type graph: <class 'jax._src.interpreters.batching.BatchTracer'>
[[42 42 42 42 42 42 42 42]
 [42 42 42 42 42 42 42 42]
 [42 42 42 42 42 42 42 42]
 [42 42 42 42 42 42 42 42]
 [42 42 42 42 42 42 42 42]
 [42 42 42 42 42 42 42 42]]


YIPPPEEEEE!!!!!! IT WORKS WITH A BASIC FUNCTION!!!!!!!!!!

ok so we have the batching data structure mostly working. lets see if this runs! at all!

In [85]:
from utils.jraph_training import train_and_evaluate_with_data
from utils.jraph_vis import plot_predictions

In [86]:
# set up logging
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [109]:
workdir="tests/outputs/batch_test1"

trained_state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(
    config=config, workdir=workdir, window_datasets=graph_tuple_dict, batched_datasets=batched_graph_tuple_dict)


INFO:absl:Hyperparameters: {'F': 8, 'K': 36, 'activation': 'relu', 'b': 10, 'batch_size': 8, 'c': 10, 'checkpoint_every_epochs': 10, 'dropout_rate': 0.1, 'edge_features': (4, 8), 'epochs': 50, 'eval_every_epochs': 5, 'fully_connected_edges': False, 'global_features': None, 'h': 1, 'init_buffer_samples': 0, 'input_steps': 1, 'layer_norm': False, 'learning_rate': 0.001, 'log_every_epochs': 5, 'max_checkpts_to_keep': 2, 'model': 'MLPGraphNetwork', 'n_blocks': 1, 'n_samples': 500, 'node_features': (32, 2), 'normalize': True, 'optimizer': 'adam', 'output_delay': 0, 'output_steps': 2, 'sample_buffer': -2, 'seed': 42, 'share_params': False, 'skip_connections': False, 'test_pct': 0.1, 'time_resolution': 120, 'timestep_duration': 3, 'train_pct': 0.7, 'val_pct': 0.2}
INFO:absl:Initializing network.
INFO:absl:
+----------------------------------------+----------+------+----------+-------+
| Name                                   | Shape    | Size | Mean     | Std   |
+----------------------------

input graph type: <class 'jraph._src.graph.GraphsTuple'>
GraphsTuple(nodes=Traced<ShapedArray(float32[36,2])>with<DynamicJaxprTrace(level=1/0)>, edges=Traced<ShapedArray(float32[180,1])>with<DynamicJaxprTrace(level=1/0)>, receivers=Traced<ShapedArray(int32[180])>with<DynamicJaxprTrace(level=1/0)>, senders=Traced<ShapedArray(int32[180])>with<DynamicJaxprTrace(level=1/0)>, globals=Traced<ShapedArray(float32[1,1])>with<DynamicJaxprTrace(level=1/0)>, n_node=Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>, n_edge=Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>)
blank 7
[0] 36
[0][0] 2
<class 'jraph._src.graph.GraphsTuple'>
GraphsTuple(nodes=Traced<ShapedArray(float32[36,2])>with<DynamicJaxprTrace(level=2/0)>, edges=Traced<ShapedArray(float32[180,1])>with<DynamicJaxprTrace(level=2/0)>, receivers=Traced<ShapedArray(int32[180])>with<DynamicJaxprTrace(level=2/0)>, senders=Traced<ShapedArray(int32[180])>with<DynamicJaxprTrace(level=2/0)>, globals=Traced<ShapedArra

ValueError: too many values to unpack (expected 7)

unbatched:

batched:

In [None]:
plot_predictions(
    config=config,
    workdir=workdir, # for loading checkpoints 
    plot_ith_rollout_step=0, # 0 indexed # for this study, we have a 4-step rollout 
    # dataset,
    # preds,
    # timestep_duration,
    # n_rollout_steps,
    #  total_steps,
    node=0, # 0-indexed 
    plot_mode="val", # i.e. "train"/"val"/"test"
    plot_days=60,
    title="Val Predictions for node 0, rollout step 0"
)

ok so the training isn't working that well
we need to see what the 