In this notebook, I check whether my batching implementation is up to snuff. Essentially, I hope to see fewer batches than windows in each split.

In [82]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [83]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts
from utils.lorenz import load_lorenz96_2coupled

import numpy as np
import jax.numpy as jnp


In [88]:
import ml_collections

config = ml_collections.ConfigDict()

# Data params.
config.n_samples = 20
config.input_steps = 1
config.output_delay = 0  # predict 0 hours into the future
config.output_steps = 4
config.timestep_duration = 3
# Negative buffer so that the sample inputs are continuous, i.e. consecutive samples
# overlap a bit instead of being separated by a gap (-4 for the values above).
config.sample_buffer = -1 * (config.input_steps + config.output_delay + config.output_steps - 1)
config.time_resolution = 120
config.init_buffer_samples = 0
config.train_pct = 0.7
config.val_pct = 0.2
config.test_pct = 0.1
# The split fractions are expected to partition the samples. Compare with a tolerance:
# 0.7 + 0.2 + 0.1 is not exactly 1.0 in floating point.
assert abs(config.train_pct + config.val_pct + config.test_pct - 1.0) < 1e-6, (
    "train_pct + val_pct + test_pct should sum to 1")
# Lorenz-96 system parameters, passed straight through to get_lorenz_graph_tuples.
config.K = 6
config.F = 8
config.c = 10
config.b = 10
config.h = 1
config.seed = 42
config.normalize = True
config.fully_connected_edges = False

# Optimizer.
config.optimizer = 'adam'
config.learning_rate = 1e-3

# Training hyperparameters.
config.batch_size = 2
config.epochs = 150
config.log_every_epochs = 5
config.eval_every_epochs = 5
config.checkpoint_every_epochs = 10
config.max_checkpts_to_keep = 2  # None means keep all checkpoints

# GNN hyperparameters.
config.model = 'MLPGraphNetwork'
config.n_blocks = 1
config.activation = 'relu'
config.dropout_rate = 0.1
config.skip_connections = False  # This was throwing a broadcast error in add_graphs_tuples_nodes when this was set to True
config.layer_norm = False  # TODO perhaps we want to turn on later
# NOTE(review): the comment originally on this line said "the last feature size will be
# the number of features that the graph predicts". node_features below ends in 2, which
# matches the 2 features per node in the data, so that remark may belong on
# node_features instead. Confirm against the model code.
config.edge_features = (4, 8)
config.node_features = (32, 2)
config.global_features = None
config.share_params = False
config.share_params = False

In [89]:
from utils.jraph_data import get_lorenz_graph_tuples, print_graph_fts

In [91]:
# Build the windowed Lorenz graph datasets with train/val/test splits.
# The helper returns two dicts: one of individual windows and one batched using
# config.batch_size. Every keyword argument is taken verbatim from the same-named config
# field, so the argument list is just the tuple of field names below.
graph_tuple_dict, batched_graph_tuple_dict = get_lorenz_graph_tuples(**{
    field: getattr(config, field)
    for field in (
        'n_samples',
        'input_steps',
        'output_delay',
        'output_steps',
        'timestep_duration',
        'sample_buffer',
        'time_resolution',
        'init_buffer_samples',
        'train_pct',
        'val_pct',
        'test_pct',
        'K',
        'F',
        'c',
        'b',
        'h',
        'seed',
        'normalize',
        'fully_connected_edges',
        'batch_size',
    )
})

In [92]:
# Unpack the un-batched splits. Each split is a dict holding parallel 'inputs' and
# 'targets' lists of windows.
train_dataset, val_dataset, test_dataset = (
    graph_tuple_dict[split] for split in ('train', 'val', 'test'))

train_inputs, train_targets = train_dataset['inputs'], train_dataset['targets']
val_inputs, val_targets = val_dataset['inputs'], val_dataset['targets']
test_inputs, test_targets = test_dataset['inputs'], test_dataset['targets']

# First training window (and its target window), and the first graph in that window.
sample_input_window, sample_target_window = train_inputs[0], train_targets[0]
sample_graph = sample_input_window[0]

print("graph_tuple_dict type:", type(graph_tuple_dict))
print("graph_tuple_dict keys:", graph_tuple_dict.keys())
print("graph_tuple_dict value type:", type(train_dataset))
print("train dataset keys:", train_dataset.keys())
print("train dataset value type:", type(train_inputs))

# Window counts per split and part (inputs/targets should agree).
print("\n".join(
    f"size of {split} {part}: {len(graph_tuple_dict[split][part])}"
    for split in ('train', 'val', 'test')
    for part in ('inputs', 'targets')
))

print("train inputs window type:", type(sample_input_window))
print("train input window size (i.e. input steps per window):", len(sample_input_window))
print("element type in window:", type(sample_graph))

graph_tuple_dict type: <class 'dict'>
graph_tuple_dict keys: dict_keys(['train', 'val', 'test'])
graph_tuple_dict value type: <class 'dict'>
train dataset keys: dict_keys(['inputs', 'targets'])
train dataset value type: <class 'list'>
size of train inputs: 14
size of train targets: 14
size of val inputs: 4
size of val targets: 4
size of test inputs: 2
size of test targets: 2
train inputs window type: <class 'list'>
train input window size (i.e. input steps per window): 1
element type in window: <class 'jraph._src.graph.GraphsTuple'>


In [93]:
import math

# Same inspection as the un-batched cell above, but for the batched splits.
# NOTE: this reuses the un-batched variable names, so after this cell `train_inputs`
# etc. refer to *batches*, not windows.
train_dataset = batched_graph_tuple_dict['train']
val_dataset = batched_graph_tuple_dict['val']
test_dataset = batched_graph_tuple_dict['test']

train_inputs = train_dataset['inputs']
train_targets = train_dataset['targets']
val_inputs = val_dataset['inputs']
val_targets = val_dataset['targets']
test_inputs = test_dataset['inputs']
test_targets = test_dataset['targets']

sample_input_window = train_inputs[0]  # first train *batch*
sample_target_window = train_targets[0]
batched_sample_graph = sample_input_window[0]

print("batched_graph_tuple_dict type:", type(batched_graph_tuple_dict))
print("batched_graph_tuple_dict keys:", batched_graph_tuple_dict.keys())
print("batched_graph_tuple_dict value type:", type(train_dataset))
print("train dataset keys:", train_dataset.keys())
print("train dataset value type:", type(train_inputs))

print("size of train inputs:", len(train_inputs))
print("size of train targets:", len(train_targets))
print("size of val inputs:", len(val_inputs))
print("size of val targets:", len(val_targets))
print("size of test inputs:", len(test_inputs))
print("size of test targets:", len(test_targets))

# Each entry is now a batch, so its length reflects config.batch_size, not input_steps.
print("train inputs batch type:", type(sample_input_window))
print("train input batch size (i.e. entries per batch):", len(sample_input_window))
print("element type in batch:", type(batched_sample_graph))

# The point of this notebook: batching should leave fewer batches than windows.
for split in ('train', 'val', 'test'):
    n_windows = len(graph_tuple_dict[split]['inputs'])
    n_batches = len(batched_graph_tuple_dict[split]['inputs'])
    print(f"{split}: {n_windows} windows -> {n_batches} batches")
    # Assumes a trailing partial batch is kept rather than dropped.
    # TODO confirm against get_lorenz_graph_tuples.
    assert n_batches == math.ceil(n_windows / config.batch_size), (
        f"{split}: expected ceil({n_windows}/{config.batch_size}) batches, got {n_batches}")

graph_tuple_dict type: <class 'dict'>
graph_tuple_dict keys: dict_keys(['train', 'val', 'test'])
graph_tuple_dict value type: <class 'dict'>
train dataset keys: dict_keys(['inputs', 'targets'])
train dataset value type: <class 'list'>
size of train inputs: 7
size of train targets: 7
size of val inputs: 2
size of val targets: 2
size of test inputs: 1
size of test targets: 1
train inputs window type: <class 'list'>
train input window size (i.e. input steps per window): 2
element type in window: <class 'jraph._src.graph.GraphsTuple'>


In [94]:
# Print the first un-batched validation input window (a list of GraphsTuples) so it can
# be compared with the first batched validation entry.
val_input_windows = graph_tuple_dict['val']['inputs']
print(val_input_windows[0])

[GraphsTuple(nodes=array([[ 1.1636245 ,  1.0790482 ],
       [-1.293495  , -0.35579923],
       [ 0.37819758,  0.18474741],
       [ 0.37264496, -1.640342  ],
       [ 0.20188326, -1.047298  ],
       [ 1.1561959 ,  3.074735  ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6],

In [100]:
# Print the first batched validation entry. Then check programmatically, rather than by
# eye, that its first graph is identical to the first graph of the first un-batched
# validation window: batching should regroup graphs, not alter them.
# NOTE(review): assumes batch entry 0 corresponds to window 0 / input step 0, which
# holds for input_steps == 1. Revisit if input_steps > 1.
first_val_batch = batched_graph_tuple_dict['val']['inputs'][0]
print(first_val_batch)

reference_graph = graph_tuple_dict['val']['inputs'][0][0]
batched_graph = first_val_batch[0]
for field in reference_graph._fields:  # GraphsTuple is a namedtuple
    np.testing.assert_array_equal(
        np.asarray(getattr(batched_graph, field)),
        np.asarray(getattr(reference_graph, field)),
        err_msg=f"GraphsTuple field {field!r} differs after batching",
    )

[GraphsTuple(nodes=Array([[ 1.1636245 ,  1.0790482 ],
       [-1.293495  , -0.35579923],
       [ 0.37819758,  0.18474741],
       [ 0.37264496, -1.640342  ],
       [ 0.20188326, -1.047298  ],
       [ 1.1561959 ,  3.074735  ]], dtype=float32), edges=Array([[ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.],
       [ 0.],
       [ 1.],
       [ 2.],
       [-1.],
       [-2.]], dtype=float32), receivers=Array([0, 1, 2, 5, 4, 1, 2, 3, 0, 5, 2, 3, 4, 1, 0, 3, 4, 5, 2, 1, 4, 5,
       0, 3, 2, 5, 0, 1, 4, 3], dtype=int32), senders=Array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4,
       4, 4, 4, 5, 5, 5, 5, 5], dtype=int32), globals=Array([[1.]], dtype=float32), n_node=Array([6],

OK, so the batched data structure is mostly working. Let's see whether training runs at all!

In [101]:
from utils.jraph_training import train_and_evaluate_with_data
from utils.jraph_vis import plot_predictions

In [102]:
# Set up logging: raise the root logger to INFO. Child loggers left at NOTSET inherit
# this as their effective level.
import logging

logger = logging.getLogger()  # no name -> the root logger
logger.setLevel("INFO")

In [111]:
# Train on the batched splits. plot_predictions later loads checkpoints from `workdir`.
# NOTE(review): the last run raised
#   ValueError: vmap in_axes specification must be a tree prefix of the corresponding value
# The value tree in that message is a pair of *lists* of GraphsTuples, while the in_axes
# spec has one entry per GraphsTuple field. So the vmapped function appears to expect a
# single GraphsTuple per input/target, but each batch here is a list of GraphsTuples.
workdir = "tests/outputs/batch_test"

trained_state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(
    datasets=batched_graph_tuple_dict,
    workdir=workdir,
    config=config,
)


INFO:absl:Hyperparameters: {'F': 8, 'K': 6, 'activation': 'relu', 'b': 10, 'batch_size': 2, 'c': 10, 'checkpoint_every_epochs': 10, 'dropout_rate': 0.1, 'edge_features': (4, 8), 'epochs': 150, 'eval_every_epochs': 5, 'fully_connected_edges': False, 'global_features': None, 'h': 1, 'init_buffer_samples': 0, 'input_steps': 1, 'layer_norm': False, 'learning_rate': 0.001, 'log_every_epochs': 5, 'max_checkpts_to_keep': 2, 'model': 'MLPGraphNetwork', 'n_blocks': 1, 'n_samples': 20, 'node_features': (32, 2), 'normalize': True, 'optimizer': 'adam', 'output_delay': 0, 'output_steps': 4, 'sample_buffer': -4, 'seed': 42, 'share_params': False, 'skip_connections': False, 'test_pct': 0.1, 'time_resolution': 120, 'timestep_duration': 3, 'train_pct': 0.7, 'val_pct': 0.2}
INFO:absl:Initializing network.
INFO:absl:
+----------------------------------------+----------+------+----------+-------+
| Name                                   | Shape    | Size | Mean     | Std   |
+-----------------------------

ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification ((1, 1, 1, 1, None, None, None), (1, 1, 1, 1, None, None, None)) for value tree PyTreeDef(([CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *]), CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *])], [CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *]), CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *])])).

In [None]:
# Plot predictions for one node at one rollout step, using the checkpoints saved under
# `workdir` by the training cell. The title is built from the same values passed as
# arguments, so the two cannot drift apart.
plot_node = 0           # 0-indexed
plot_rollout_step = 0   # 0-indexed; this study uses a config.output_steps-step (4-step) rollout
plot_mode = "val"       # i.e. "train"/"val"/"test"

plot_predictions(
    config=config,
    workdir=workdir,  # for loading checkpoints
    plot_ith_rollout_step=plot_rollout_step,
    node=plot_node,
    plot_mode=plot_mode,
    plot_days=60,
    title=f"{plot_mode.capitalize()} Predictions for node {plot_node}, rollout step {plot_rollout_step}",
)

OK, so training isn't working yet (the training cell above raised a `vmap` `in_axes` ValueError).
we need to see what the 