In this notebook, I will be checking if my batching implementation is up to snuff. Essentially, I hope to see that I have fewer batches than windows

In [67]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [68]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts
from utils.lorenz import load_lorenz96_2coupled

import numpy as np
import jax.numpy as jnp


In [130]:
import ml_collections

config = ml_collections.ConfigDict()

# Data params. 
config.n_samples=500
config.input_steps=1
config.output_delay=0 # predict 0 hours into the future
config.output_steps=2
config.timestep_duration=3
config.sample_buffer=-1 * (config.input_steps + config.output_delay + config.output_steps - 1) # negative buffer so that our sample input are continuous (i.e. the first sample would overlap a bit with consecutive samples) 
config.time_resolution=120
config.init_buffer_samples=0
config.train_pct=0.7
config.val_pct=0.2
config.test_pct=0.1
config.K=6
config.F=8
config.c=10
config.b=10
config.h=1
config.seed=42
config.normalize=True
config.fully_connected_edges=False

# Optimizer.
config.optimizer = 'adam'
config.learning_rate = 1e-3

# Training hyperparameters.
config.batch_size = 8
config.epochs = 50
config.log_every_epochs = 5
config.eval_every_epochs = 5
config.checkpoint_every_epochs = 10
config.max_checkpts_to_keep = 2 # None means keep all checkpoints

# GNN hyperparameters.
config.model = 'MLPGraphNetwork'
config.n_blocks = 1
config.activation = 'relu'
config.dropout_rate = 0.1
config.skip_connections = False # This was throwing a broadcast error in add_graphs_tuples_nodes when this was set to True
config.layer_norm = False # TODO perhaps we want to turn on later
config.edge_features = (4, 8) # the last feature size will be the number of features that the graph predicts
config.node_features = (32, 2)
config.global_features = None
config.share_params = False

In [131]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts

In [132]:
# generate desired dataset with train/val split and subsampled windows
graph_tuple_dict, batched_graph_tuple_dict = get_lorenz_graph_tuples(
    n_samples=config.n_samples,
    input_steps=config.input_steps,
    output_delay=config.output_delay,
    output_steps=config.output_steps,
    timestep_duration=config.timestep_duration,
    sample_buffer=config.sample_buffer,
    time_resolution=config.time_resolution,
    init_buffer_samples=config.init_buffer_samples,
    train_pct=config.train_pct,
    val_pct=config.val_pct,
    test_pct=config.test_pct,
    K=config.K,
    F=config.F,
    c=config.c,
    b=config.b,
    h=config.h,
    seed=config.seed,
    normalize=config.normalize,
    fully_connected_edges=config.fully_connected_edges,
    batch_size=config.batch_size)

INFO:root:starting integration


made it to the last graph
made it to the last graph
made it to the last graph


In [133]:
train_dataset = graph_tuple_dict['train']
val_dataset = graph_tuple_dict['val']
test_dataset = graph_tuple_dict['test']

train_inputs = train_dataset['inputs']
train_targets = train_dataset['targets']
val_inputs = val_dataset['inputs']
val_targets = val_dataset['targets']
test_inputs = test_dataset['inputs']
test_targets = test_dataset['targets']

sample_input_window = train_inputs[0]
sample_target_window = train_targets[0]
sample_graph = sample_input_window[0]

print("graph_tuple_dict type:", type(graph_tuple_dict))
print("graph_tuple_dict keys:", graph_tuple_dict.keys())
print("graph_tuple_dict value type:", type(train_dataset))
print("train dataset keys:", train_dataset.keys())
print("train dataset value type:", type(train_inputs))

print("size of train inputs:", len(train_inputs))
print("size of train targets:", len(train_targets))
print("size of val inputs:", len(val_inputs))
print("size of val targets:", len(val_targets))
print("size of test inputs:", len(test_inputs))
print("size of test targets:", len(test_targets))

print("train inputs window type:", type(sample_input_window))
print("train input window size (i.e. input steps per window):", len(sample_input_window))
print("element type in window:", type(sample_graph))
print("Edges shape:", sample_graph.edges.shape)

graph_tuple_dict type: <class 'dict'>
graph_tuple_dict keys: dict_keys(['train', 'val', 'test'])
graph_tuple_dict value type: <class 'dict'>
train dataset keys: dict_keys(['inputs', 'targets'])
train dataset value type: <class 'list'>
size of train inputs: 350
size of train targets: 350
size of val inputs: 100
size of val targets: 100
size of test inputs: 50
size of test targets: 50
train inputs window type: <class 'list'>
train input window size (i.e. input steps per window): 2
element type in window: <class 'jraph._src.graph.GraphsTuple'>
Edges shape: (30, 1)


In [134]:
print(type(train_inputs[0]))
print(len(train_inputs))
print(type(train_targets[0]))
print(len(train_targets))
print(type(train_targets[0][0]))

<class 'list'>
350
<class 'list'>
350
<class 'jraph._src.graph.GraphsTuple'>


In [152]:
train_dataset = batched_graph_tuple_dict['train']
val_dataset = batched_graph_tuple_dict['val']
test_dataset = batched_graph_tuple_dict['test']

batched_train_inputs = train_dataset['inputs']
train_targets = train_dataset['targets']
val_inputs = val_dataset['inputs']
val_targets = val_dataset['targets']
test_inputs = test_dataset['inputs']
test_targets = test_dataset['targets']

sample_input_window = train_inputs[0]
sample_target_window = train_targets[0]
batched_sample_graph = sample_input_window[0]

print("graph_tuple_dict type:", type(batched_graph_tuple_dict))
print("graph_tuple_dict keys:", batched_graph_tuple_dict.keys())
print("graph_tuple_dict value type:", type(train_dataset))
print("train dataset keys:", train_dataset.keys())
print("train dataset value type:", type(train_inputs))

# each has 7 items because the GraphsTuple data structure has 7 parameters (nodes, edges, receivers, senders, globals, n_node, n_edge)
print("num items of train inputs axis:", len(batched_train_inputs))
print("num tems of train targets axis:", len(train_targets))
print("num items of size of val inputs axis:", len(val_inputs))
print("num items of val targets axis:", len(val_targets))
print("num items of of test inputs axis:", len(test_inputs))
print("num items of test targets axis:", len(test_targets))

print("size of train inputs:", len(batched_train_inputs[0]))
print("size of train targets:", len(train_targets[0]))
print("size of val inputs:", len(val_inputs[0]))
print("size of val targets:", len(val_targets[0]))
print("size of test inputs:", len(test_inputs[0]))
print("size of test targets:", len(test_targets[0]))

print("train inputs window type:", type(sample_input_window))
print("train input window size (i.e. input steps per window):", len(sample_input_window))
print("element type in window:", type(batched_sample_graph))

graph_tuple_dict type: <class 'dict'>
graph_tuple_dict keys: dict_keys(['train', 'val', 'test'])
graph_tuple_dict value type: <class 'dict'>
train dataset keys: dict_keys(['inputs', 'targets'])
train dataset value type: <class 'list'>
num items of train inputs axis: 7
num tems of train targets axis: 7
num items of size of val inputs axis: 7
num items of val targets axis: 7
num items of of test inputs axis: 7
num items of test targets axis: 7
size of train inputs: 43
size of train targets: 43
size of val inputs: 12
size of val targets: 12
size of test inputs: 6
size of test targets: 6
train inputs window type: <class 'list'>
train input window size (i.e. input steps per window): 2
element type in window: <class 'jraph._src.graph.GraphsTuple'>


recall that we called jnp.stack() the batched data. because of this, the way we previously iterated through to train (seen below) will no longer work. solution: vmap!

In [151]:
input_data = batched_graph_tuple_dict['train']['inputs']
target_data = batched_graph_tuple_dict['train']['targets']
i = 0
for (input_batch_graphs, target_batch_graphs) in zip(input_data, target_data):
    print("iteration ", i)
    print(input_batch_graphs)
    print(target_batch_graphs)
    i+=1

iteration  0
[[[ 1.0854319e+00  1.0938451e+00]
  [ 8.0362368e-01 -8.2057786e-01]
  [-1.5494671e+00 -8.1950420e-01]
  ...
  [-7.3889124e-01 -4.0229657e-01]
  [-1.8935165e-03 -1.8134503e-01]
  [ 1.4562833e+00  1.3316844e+00]]

 [[ 7.6141250e-01 -5.7864499e-01]
  [-1.5537400e+00 -6.1634815e-01]
  [-1.0359222e+00 -1.7511472e-01]
  ...
  [-1.2141046e-01  9.3992949e-02]
  [ 4.6704370e-01  5.6095117e-01]
  [ 1.2834107e+00  2.1811616e-01]]

 [[-1.0011412e+00 -3.7116888e-01]
  [-1.3818300e+00 -1.1013335e+00]
  [-2.6139760e-01 -4.8096818e-03]
  ...
  [ 6.3763362e-01  5.2901411e-01]
  [ 1.0736444e+00  9.1244847e-01]
  [-1.1884686e-01  3.5082124e-04]]

 ...

 [[ 1.5066472e-01  1.5783851e-01]
  [-9.0608507e-02 -4.4729108e-01]
  [ 4.8095603e-02  3.2209247e-01]
  ...
  [ 8.9830953e-01  1.0816194e+00]
  [-4.2269218e-01 -1.6746642e-01]
  [-1.1046348e+00 -9.0411937e-01]]

 [[ 4.1601068e-01  8.8523871e-01]
  [ 1.2725252e+00  6.5662414e-02]
  [ 8.4157699e-01  1.1365423e+00]
  ...
  [-8.6041248e-01  1.0232

In [137]:
print(train_inputs)

[[GraphsTuple(nodes=array([[ 1.0854319 ,  1.0938451 ],
       [ 0.8036237 , -0.82057786],
       [-1.5494671 , -0.8195042 ],
       [-1.1777799 , -0.51347125],
       [ 0.0165526 , -0.12953193],
       [ 1.0170294 ,  0.0603915 ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6]

In [138]:
print(batched_train_inputs[0])

[[[ 1.0854319e+00  1.0938451e+00]
  [ 8.0362368e-01 -8.2057786e-01]
  [-1.5494671e+00 -8.1950420e-01]
  ...
  [-7.3889124e-01 -4.0229657e-01]
  [-1.8935165e-03 -1.8134503e-01]
  [ 1.4562833e+00  1.3316844e+00]]

 [[ 7.6141250e-01 -5.7864499e-01]
  [-1.5537400e+00 -6.1634815e-01]
  [-1.0359222e+00 -1.7511472e-01]
  ...
  [-1.2141046e-01  9.3992949e-02]
  [ 4.6704370e-01  5.6095117e-01]
  [ 1.2834107e+00  2.1811616e-01]]

 [[-1.0011412e+00 -3.7116888e-01]
  [-1.3818300e+00 -1.1013335e+00]
  [-2.6139760e-01 -4.8096818e-03]
  ...
  [ 6.3763362e-01  5.2901411e-01]
  [ 1.0736444e+00  9.1244847e-01]
  [-1.1884686e-01  3.5082124e-04]]

 ...

 [[ 1.5066472e-01  1.5783851e-01]
  [-9.0608507e-02 -4.4729108e-01]
  [ 4.8095603e-02  3.2209247e-01]
  ...
  [ 8.9830953e-01  1.0816194e+00]
  [-4.2269218e-01 -1.6746642e-01]
  [-1.1046348e+00 -9.0411937e-01]]

 [[ 4.1601068e-01  8.8523871e-01]
  [ 1.2725252e+00  6.5662414e-02]
  [ 8.4157699e-01  1.1365423e+00]
  ...
  [-8.6041248e-01  1.0232806e+00]
  [-

lets see if this data structure works with a basic vmapped function

In [146]:
import jax
import jraph

def model(graph):
    # Example operation: add 10 to node features
    print("in function")
    updated_nodes = graph.nodes + 10
    return graph._replace(nodes=updated_nodes)

# Define a batched version of the model function
vectorized_model = jax.vmap(model, in_axes=0, out_axes=0)

stacked_batched_graph = batched_graph_tuple_dict['train']['inputs']
print(stacked_batched_graph)

# Apply the vectorized model to the stacked batched graphs
output_graphs = vectorized_model(stacked_batched_graph)

print(output_graphs)

GraphsTuple(nodes=Array([[[ 1.0854319e+00,  1.0938451e+00],
        [ 8.0362368e-01, -8.2057786e-01],
        [-1.5494671e+00, -8.1950420e-01],
        ...,
        [-7.3889124e-01, -4.0229657e-01],
        [-1.8935165e-03, -1.8134503e-01],
        [ 1.4562833e+00,  1.3316844e+00]],

       [[ 7.6141250e-01, -5.7864499e-01],
        [-1.5537400e+00, -6.1634815e-01],
        [-1.0359222e+00, -1.7511472e-01],
        ...,
        [-1.2141046e-01,  9.3992949e-02],
        [ 4.6704370e-01,  5.6095117e-01],
        [ 1.2834107e+00,  2.1811616e-01]],

       [[-1.0011412e+00, -3.7116888e-01],
        [-1.3818300e+00, -1.1013335e+00],
        [-2.6139760e-01, -4.8096818e-03],
        ...,
        [ 6.3763362e-01,  5.2901411e-01],
        [ 1.0736444e+00,  9.1244847e-01],
        [-1.1884686e-01,  3.5082124e-04]],

       ...,

       [[ 1.5066472e-01,  1.5783851e-01],
        [-9.0608507e-02, -4.4729108e-01],
        [ 4.8095603e-02,  3.2209247e-01],
        ...,
        [ 8.9830953e-01,  1.0

YIPPPEEEEE!!!!!! IT WORKS WITH A BASIC FUNCTION!!!!!!!!!!

ok so we have the batching data structure mostly working. lets see if this runs! at all!

In [147]:
from utils.jraph_training import train_and_evaluate_with_data
from utils.jraph_vis import plot_predictions

In [148]:
# set up logging
import logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [149]:
workdir="tests/outputs/batch_test"

trained_state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(
    config=config, workdir=workdir, window_datasets=graph_tuple_dict, batched_datasets=batched_graph_tuple_dict)


INFO:absl:Hyperparameters: {'F': 8, 'K': 6, 'activation': 'relu', 'b': 10, 'batch_size': 8, 'c': 10, 'checkpoint_every_epochs': 10, 'dropout_rate': 0.1, 'edge_features': (4, 8), 'epochs': 50, 'eval_every_epochs': 5, 'fully_connected_edges': False, 'global_features': None, 'h': 1, 'init_buffer_samples': 0, 'input_steps': 2, 'layer_norm': False, 'learning_rate': 0.001, 'log_every_epochs': 5, 'max_checkpts_to_keep': 2, 'model': 'MLPGraphNetwork', 'n_blocks': 1, 'n_samples': 500, 'node_features': (32, 2), 'normalize': True, 'optimizer': 'adam', 'output_delay': 0, 'output_steps': 2, 'sample_buffer': -3, 'seed': 42, 'share_params': False, 'skip_connections': False, 'test_pct': 0.1, 'time_resolution': 120, 'timestep_duration': 3, 'train_pct': 0.7, 'val_pct': 0.2}
INFO:absl:Initializing network.
INFO:absl:
+----------------------------------------+----------+------+----------+-------+
| Name                                   | Shape    | Size | Mean     | Std   |
+-----------------------------

AssertionError: (96, 2)

unbatched:

batched:

In [None]:
plot_predictions(
    config=config,
    workdir=workdir, # for loading checkpoints 
    plot_ith_rollout_step=0, # 0 indexed # for this study, we have a 4-step rollout 
    # dataset,
    # preds,
    # timestep_duration,
    # n_rollout_steps,
    #  total_steps,
    node=0, # 0-indexed 
    plot_mode="val", # i.e. "train"/"val"/"test"
    plot_days=60,
    title="Val Predictions for node 0, rollout step 0"
)

ok so the training isn't working that well
we need to see what the 