In [1]:
import collections
import dataclasses
import datetime
import enum
import functools
import itertools
import json
import os
import pickle
import random
import tempfile
from typing import Sequence

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import optax
import psutil
import scipy.sparse as sp
import seaborn as sns

2023-03-18 12:04:35.327318: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.2/lib64::/usr/local/cuda-11.2/lib64:/usr/local/cuda-11.2/extras/CUPTI/lib64
2023-03-18 12:04:35.328423: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.2/lib64::/usr/local/cuda-11.2/lib64:/usr/local/cuda-11.2/extras/CUPTI/lib64
  PyTreeDef = type(jax.tree_structure(None))


In [2]:
from BH.data_loader import *
from BH.generate_data import *
from Model import Model,Direction,Reduction
# from Load_Data import batch
from Train import train,print_accuracies

In [3]:
#@title Load data

#@markdown Training this model is pretty slow - an hour or so on the free tier colab, but subject to inactivity timeouts and pre-emptions.

#@markdown In order to make it possible to recreate the results from the paper reliably and quickly, we provide several helpers to either speed things up, or reduce the memory footprint:
#@markdown * Pretrained weights - greatly speeds things up by loading the trained model parameters rather than learning from the data
#@markdown * If you are running on a high memory machine (i.e. *not* on the free colab instance!) the input graph data can be loaded from a pickle (which is faster to load) and kept in memory (faster to re-use, but uses ~12Gb of memory). This makes no difference to training speed (it's only relevant for `generate_graph_data()` and `get_saliency_vectors()`).

# Directory handed to `load_input_data` below.
# NOTE(review): absolute path - adjust to wherever the data lives on your machine.
DIR_PATH = "/Data/Ptab/n=5"

# Notebook form parameters.
use_pretrained_weights = True  #@param{type:"boolean"}
hold_graphs_in_memory = False  #@param{type:"boolean"}

# Number of bytes in one gibibyte; used to express the installed RAM in Gb.
gb = 2**30
total_memory = psutil.virtual_memory().total / gb

# Holding every graph in memory is only allowed on machines with at least 20Gb
# of RAM. Smaller machines (such as the free colab tier) must leave
# `hold_graphs_in_memory` switched off, so fail early instead of running out of
# memory half-way through.
if hold_graphs_in_memory and total_memory < 20:
    raise RuntimeError(
        f"It is unlikely your machine (with {total_memory}Gb) will have enough memory to complete the colab's execution!"
    )

print("Loading input data...")
# `load_input_data` yields three datasets: the full set plus its train and
# test parts (how the split is made is decided inside the loader).
full_dataset, train_dataset, test_dataset = load_input_data(DIR_PATH)

Loading input data...
Generating data from the directory /Data/Ptab/n=5


In [4]:
#@title Network Setup

# Optimiser learning rate and the number of examples per training batch.
step_size = 0.001
batch_size = 128

# Labels are treated as class indices starting at 0, so the number of classes
# is one more than the largest label seen in the training data.
num_classes = 1 + np.max(train_dataset.labels)

# Every model hyper-parameter in one place, so the configuration can be read
# (or logged) at a glance before the model is built.
model_config = dict(
    num_layers=3,
    num_features=64,
    num_classes=num_classes,
    direction=Direction.BOTH,
    reduction=Reduction.SUM,
    apply_relu_activation=True,
    use_mask=False,
    share=False,
    message_relu=True,
    with_bias=True,
)
model = Model(**model_config)

# Function returning both the loss value and its gradient in one call.
loss_val_gr = jax.value_and_grad(model.loss)

# Adam optimiser, exposed as its (init, update) pair of functions.
optimizer = optax.adam(step_size)
opt_init = optimizer.init
opt_update = optimizer.update

In [5]:
#@title Train the model

num_epochs = 10

# Dedicated, explicitly seeded generator for the per-epoch shuffle. Using the
# global `random` state made the batch order (and therefore every logged loss
# and accuracy) differ between runs; with a private seeded generator a fresh
# "Restart & Run All" replays exactly the same batches.
shuffle_rng = random.Random(42)

# Initialise the network parameters by tracing the model on the first training
# graph (as a batch of one), then build the matching optimiser state.
trained_params = model.net.init(
    jax.random.PRNGKey(42),
    features=train_dataset.features[0],
    rows=train_dataset.rows[0],
    cols=train_dataset.columns[0],
    batch_size=1,
    masks=train_dataset.features[0][np.newaxis, :, :])
trained_opt_state = opt_init(trained_params)

# One (features, rows, cols, label, edge_types) tuple per training example.
# The tuples never change between epochs, so the list is built once here and
# only re-shuffled in place at the start of each epoch.
tr_data = list(
    zip(
        train_dataset.features,
        train_dataset.rows,
        train_dataset.columns,
        train_dataset.labels,
        train_dataset.edge_types,
    ))

for ep in range(1, num_epochs + 1):
    # New random example order for this epoch, keeping each example's fields
    # aligned with one another.
    shuffle_rng.shuffle(tr_data)
    features_train, rows_train, cols_train, ys_train, edge_types_train = zip(
        *tr_data)

    features_train = list(features_train)
    rows_train = list(rows_train)
    cols_train = list(cols_train)
    ys_train = np.array(ys_train)
    edge_types_train = list(edge_types_train)

    # `i` is the offset of the first example of the batch; the final batch may
    # hold fewer than `batch_size` examples.
    for i in range(0, len(features_train), batch_size):
        batch_slice = slice(i, i + batch_size)
        b_features, b_rows, b_cols, b_ys, b_edges = batch(
            features_train[batch_slice],
            rows_train[batch_slice],
            cols_train[batch_slice],
            ys_train[batch_slice],
            edge_types_train[batch_slice],
        )

        # `train` hands back the new parameters, the new optimiser state and
        # the loss it measured on this batch.
        trained_params, trained_opt_state, curr_loss = train(
            loss_val_gr,
            opt_update,
            trained_params,
            trained_opt_state,
            b_features,
            b_rows,
            b_cols,
            b_ys,
            b_edges,
        )

        # Accuracy on the same batch, evaluated with the freshly updated
        # parameters (so it is not the accuracy that produced `curr_loss`).
        accs = model.accuracy(
            trained_params,
            b_features,
            b_rows,
            b_cols,
            b_ys,
            b_edges,
        )
        print(datetime.datetime.now(),
              f"Iteration {i:4d} | Batch loss {curr_loss:.6f}",
              f"Batch accuracy {accs:.2f}")

    print(datetime.datetime.now(), f"Epoch {ep:2d} completed!")

    # Calculate accuracy across full dataset once per epoch
    print(datetime.datetime.now(), f"Epoch {ep:2d}       | ", end="")
    print_accuracies(model, trained_params, test_dataset, train_dataset)

  return init(shape, dtype)
  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


2023-03-18 12:04:42.326668 Iteration    0 | Batch loss 4.190897 Batch accuracy 0.54
2023-03-18 12:04:42.393625 Iteration  128 | Batch loss 0.766553 Batch accuracy 0.41
2023-03-18 12:04:42.433375 Iteration  256 | Batch loss 1.418116 Batch accuracy 0.46
2023-03-18 12:04:42.472408 Iteration  384 | Batch loss 0.680156 Batch accuracy 0.49
2023-03-18 12:04:42.511329 Iteration  512 | Batch loss 0.535574 Batch accuracy 0.55
2023-03-18 12:04:42.550231 Iteration  640 | Batch loss 0.490587 Batch accuracy 0.53
2023-03-18 12:04:42.590626 Iteration  768 | Batch loss 0.352529 Batch accuracy 0.53
2023-03-18 12:04:42.645173 Iteration  896 | Batch loss 0.448731 Batch accuracy 0.51
2023-03-18 12:04:42.684754 Iteration 1024 | Batch loss 0.369457 Batch accuracy 0.55
2023-03-18 12:04:42.724314 Iteration 1152 | Batch loss 0.472284 Batch accuracy 0.50
2023-03-18 12:04:44.110854 Iteration 1280 | Batch loss 0.428029 Batch accuracy 0.59
2023-03-18 12:04:44.111098 Epoch  1 completed!
2023-03-18 12:04:44.111129 Ep

2023-03-18 12:04:48.887944 Iteration 1152 | Batch loss 0.330533 Batch accuracy 0.59
2023-03-18 12:04:48.912789 Iteration 1280 | Batch loss 0.274451 Batch accuracy 0.72
2023-03-18 12:04:48.912849 Epoch  8 completed!
2023-03-18 12:04:48.912858 Epoch  8       | Train accuracy: 0.532 | Test accuracy: 0.526 | Combined accuracy: 0.531
2023-03-18 12:04:49.004963 Iteration    0 | Batch loss 0.373263 Batch accuracy 0.48
2023-03-18 12:04:49.044091 Iteration  128 | Batch loss 0.337002 Batch accuracy 0.59
2023-03-18 12:04:49.083164 Iteration  256 | Batch loss 0.313133 Batch accuracy 0.53
2023-03-18 12:04:49.122248 Iteration  384 | Batch loss 0.335329 Batch accuracy 0.59
2023-03-18 12:04:49.161865 Iteration  512 | Batch loss 0.318038 Batch accuracy 0.62
2023-03-18 12:04:49.201358 Iteration  640 | Batch loss 0.325529 Batch accuracy 0.60
2023-03-18 12:04:49.239720 Iteration  768 | Batch loss 0.308404 Batch accuracy 0.62
2023-03-18 12:04:49.279591 Iteration  896 | Batch loss 0.343851 Batch accuracy 0.