<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/final_version_11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis
%pip install -q torch-scatter

!wget https://raw.githubusercontent.com/deepmind/jraph/master/jraph/experimental/sharded_graphnet.py

  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 75 kB 2.7 MB/s 
[K     |████████████████████████████████| 72 kB 628 kB/s 
[K     |████████████████████████████████| 77 kB 6.2 MB/s 
[?25h  Building wheel for jaxline (setup.py) ... [?25l[?25hdone
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 78 kB 3.3 MB/s 
[?25h  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 4.4 MB 5.4 MB/s 
[K     |████████████████████████████████| 136 kB 5.3 MB/s 
[?25h  Building wheel for metis (setup.py) ... [?25l[?25hdone
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
--2022-04-07 09:37:21--  https://raw.githubusercontent.com/deepmind/jraph/master/jraph/experimental/sharded_graphnet.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Co

In [2]:
import jax.tools.colab_tpu
# Point JAX at the Colab TPU runtime.
jax.tools.colab_tpu.setup_tpu()

# Number of devices local to this host; reused later as the shard count for the
# sharded graph network and the sharded GraphsTuple conversion.
num_devices = jax.local_device_count()
# Last expression of the cell: shows the detected devices as the cell output.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

# Load ogbn-proteins as a DGL dataset together with its official split and evaluator.
dataset = DglNodePropPredDataset(name = "ogbn-proteins")
# Node-index tensors keyed by 'train' / 'valid' / 'test'.
split_idx = dataset.get_idx_split()
# Evaluator whose eval() result is read via the 'rocauc' key further down.
evaluator = Evaluator(name = 'ogbn-proteins')
# Print the y_true / y_pred layout the evaluator expects.
print(evaluator.expected_input_format)

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip


Downloaded 0.21 GB: 100%|██████████| 216/216 [00:25<00:00,  8.36it/s]


Extracting dataset/proteins.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:01<00:00,  1.72s/it]


Converting graphs into DGL objects...


100%|██████████| 1/1 [00:00<00:00,  1.61it/s]


Saving...
==== Expected input format of Evaluator for ogbn-proteins
{'y_true': y_true, 'y_pred': y_pred}
- y_true: numpy ndarray or torch tensor of shape (num_node, num_task)
- y_pred: numpy ndarray or torch tensor of shape (num_node, num_task)
where y_pred stores score values (for computing ROC-AUC),
num_task is 112, and each row corresponds to one node.



In [4]:
import jax.numpy as jnp
import torch
from torch_scatter import scatter

# There is only one graph in Node Property Prediction datasets
ogbn_proteins_main_graph, ogbn_proteins_main_labels = dataset[0]
# Overwrite the 'species' node field with, for every node, the mean of the edge
# features of the edges whose source (edges()[0]) is that node. After this line
# ndata['species'] no longer holds the raw species feature.
ogbn_proteins_main_graph.ndata['species'] = scatter(
    ogbn_proteins_main_graph.edata['feat'],
    ogbn_proteins_main_graph.edges()[0],
    dim = 0,
    dim_size = ogbn_proteins_main_graph.num_nodes(),  # one output row per node
    reduce = 'mean'
)
'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Diameter ~ 9 (https://cs.stanford.edu/people/jure/pubs/ogb-neurips20.pdf)
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Get split labels
# These are the ground-truth arrays handed to the evaluator as y_true.
train_label = dataset.labels[split_idx['train']]  # (86619, 112) -- binary values (presence of protein functions)
valid_label = dataset.labels[split_idx['valid']]  # (21236, 112) -- binary values (presence of protein functions)
test_label = dataset.labels[split_idx['test']]    # (24679, 112) -- binary values (presence of protein functions)

# Create masks
# Each mask has shape (num_nodes, 1): 1 for nodes in the split, 0 elsewhere.
# The trailing axis of size 1 lets the mask broadcast over the per-class losses.
train_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['train'])].set(1)
valid_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['valid'])].set(1)
test_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['test'])].set(1)

In [5]:
import jraph

# From https://colab.research.google.com/github/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb#scrollTo=7vEmAsr5bKN8
def _nearest_multiple_of_8(x: int) -> int:
  """Computes the nearest power of two greater than x for padding."""
  if x % 8 == 0:
    return x
  else:
    return (x // 8 + 1) * 8 

def pad_graph_to_nearest_multiple_of_8(
    graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Pads a batched `GraphsTuple` so node and edge counts are multiples of 8.

  For example, if a `GraphsTuple` has 7 nodes, 5 edges and 3 graphs, this method
  would pad the `GraphsTuple` nodes and edges:
    7 nodes --> 8 nodes (next multiple of 8)
    5 edges --> 8 edges (next multiple of 8)
  And since padding is accomplished using `jraph.pad_with_graphs`, an extra
  graph and node is added:
    8 nodes --> 9 nodes
    3 graphs --> 4 graphs
  Args:
    graphs_tuple: a batched `GraphsTuple` (can be batch size 1).
  Returns:
    The input padded to (multiple of 8) + 1 nodes, a multiple of 8 edges and
    one additional (padding) graph.
  """
  # Convert the summed counts to plain Python ints so that the padding targets
  # handed to jraph are host-side integers rather than device arrays.
  total_nodes = int(jnp.sum(graphs_tuple.n_node))
  total_edges = int(jnp.sum(graphs_tuple.n_edge))

  # Add 1 since we need at least one padding node for pad_with_graphs.
  pad_nodes_to = _nearest_multiple_of_8(total_nodes) + 1
  pad_edges_to = _nearest_multiple_of_8(total_edges)
  # Add 1 since we need at least one padding graph for pad_with_graphs.
  # The graph count itself is not rounded because the batch size is fixed.
  pad_graphs_to = graphs_tuple.n_node.shape[0] + 1
  return jraph.pad_with_graphs(graphs_tuple, pad_nodes_to, pad_edges_to,
                               pad_graphs_to)

In [6]:
import numpy as np
import torch
import jraph
import sharded_graphnet

from sklearn.preprocessing import OneHotEncoder

# Seed NumPy's global RNG.
np.random.seed(42)

# Fit a one-hot encoder on the current 'species' node field.
# NOTE(review): in the visible notebook `enc` is only referenced by a
# commented-out line inside dgl_graph_to_jraph — confirm it is still needed.
enc = OneHotEncoder()
enc.fit(ogbn_proteins_main_graph.ndata['species'])

def dgl_graph_to_jraph(node_ids, labels, train_mask, valid_mask, test_mask):
  """Builds the padded, sharded jraph graph for one set of nodes.

  Args:
    node_ids: ids (in the full graph) of the nodes forming the subgraph.
    labels: per-node targets for those nodes.
    train_mask: per-node train-split indicator for those nodes.
    valid_mask: per-node validation-split indicator for those nodes.
    test_mask: per-node test-split indicator for those nodes.
  Returns:
    The result of `graphs_tuple_to_broadcasted_sharded_graphs_tuple` applied to
    the padded graph. Its `nodes` field is a dict with the keys 'inputs',
    'targets', 'train_mask', 'valid_mask', 'test_mask' and 'padding_mask'.
  """
  # Taking the node-induced subgraph of the full graph brings the node and
  # edge features back along with the structure.
  subgraph = dgl.node_subgraph(ogbn_proteins_main_graph, node_ids)
  edge_sources, edge_targets = subgraph.edges()

  node_inputs = jnp.array(subgraph.ndata['species']).astype(np.float32)
  # The 8-dimensional edge features.
  edge_inputs = jnp.array(subgraph.edata['feat']).astype(np.float32)

  node_fields = {
      'inputs': node_inputs,
      'targets': labels,
      'train_mask': train_mask,
      'valid_mask': valid_mask,
      'test_mask': test_mask,
      # TODO: Check this.
      # All ones for the real nodes; it is padded together with the other node
      # fields, so it can later be used to drop the predictions made for
      # padding nodes before handing results to the ogbn evaluator.
      'padding_mask': jnp.ones((node_inputs.shape[0], 1)),
  }

  unpadded_graph = jraph.GraphsTuple(
      nodes = node_fields,
      edges = edge_inputs,
      senders = jnp.array(edge_sources).astype(np.int32),
      receivers = jnp.array(edge_targets).astype(np.int32),
      n_node = jnp.array([subgraph.num_nodes()]),
      n_edge = jnp.array([subgraph.num_edges()]),
      globals = None  # No global features
  )

  padded_graph = pad_graph_to_nearest_multiple_of_8(unpadded_graph)

  return sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple(
      padded_graph,
      num_shards = num_devices
  )
  
def get_labels_for_subgraph(node_ids):
  """Returns the rows of the full label tensor selected by `node_ids`, as a jnp array."""
  selected_labels = ogbn_proteins_main_labels.index_select(0, node_ids)
  return jnp.array(selected_labels)

In [7]:
import dgl

'''
  Generate graph partition using metis, with balanced number of edges in each partition.
  Note: 
    The subgraphs do not contain the node/edge data in the input graph (https://docs.dgl.ai/generated/dgl.metis_partition.html)
'''
# Number of METIS partitions; every partition becomes one training "batch".
num_partitions = 150  ## TODO: Find some way to decrease this to something reasonable (< 50)

# Indexed by partition number below; each entry exposes the original node ids
# of its nodes through ndata['_ID'].
dgl_graph_metis_partition = dgl.metis_partition(ogbn_proteins_main_graph, num_partitions, balance_edges = True)

Convert a graph into a bidirected graph: 2.041 seconds
Construct multi-constraint weights: 0.016 seconds
Metis partitioning: 37.774 seconds
Split the graph: 0.558 seconds
Construct subgraphs: 0.167 seconds


In [8]:
# Convert every METIS partition to a sharded Jraph graph plus its side arrays.
processed_graphs = {}

for idx in range(num_partitions):
  node_ids = dgl_graph_metis_partition[idx].ndata['_ID']
  labels = get_labels_for_subgraph(node_ids)

  # Convert the ids once and reuse them to slice each of the global masks.
  partition_node_index = jnp.array(node_ids)
  graph = dgl_graph_to_jraph(
      node_ids,
      labels,
      train_mask = train_mask.at[partition_node_index].get(),
      valid_mask = valid_mask.at[partition_node_index].get(),
      test_mask = test_mask.at[partition_node_index].get()
  )

  # Split the node dict: the model only sees 'inputs', everything else is
  # stored next to the graph.
  node_fields = graph.nodes
  partition_entry = {
      'graph': graph._replace(nodes = node_fields['inputs']),
      'labels': node_fields['targets'],
  }
  for mask_name in ('train_mask', 'valid_mask', 'test_mask', 'padding_mask'):
    partition_entry[mask_name] = node_fields[mask_name]

  processed_graphs[f'partition_{idx}'] = partition_entry

In [9]:
import haiku as hk
import jax
import optax

from typing import Sequence

# See https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py for alternative to using GraphMapFeatures
# From https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py

# Width of the input embeddings and of every MLP layer in the update functions.
hidden_dimension = 128
# Number of residual message-passing rounds applied in network_definition.
num_message_passing_steps = 9 # Question: (256, 4) fails / (128, 6) works

@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Node update: a two-layer ReLU MLP followed by layer normalisation.

  Args:
    feats: the concatenated inputs supplied by `jraph.concatenated_args`.
  Returns:
    Updated features of width `hidden_dimension`.
  """
  mlp = hk.nets.MLP(
      output_sizes = [hidden_dimension, hidden_dimension],
      activation = jax.nn.relu,
      activate_final = False
  )
  layer_norm = hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
  return hk.Sequential([mlp, layer_norm])(feats)

@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Edge update: a two-layer ReLU MLP followed by layer normalisation.

  Args:
    feats: the concatenated inputs supplied by `jraph.concatenated_args`.
  Returns:
    Updated features of width `hidden_dimension`.
  """
  mlp = hk.nets.MLP(
      output_sizes = [hidden_dimension, hidden_dimension],
      activation = jax.nn.relu,
      activate_final = False
  )
  layer_norm = hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
  return hk.Sequential([mlp, layer_norm])(feats)

@hk.without_apply_rng
@hk.transform
def network_definition(graph):
  """Encode - process - decode graph network over a sharded graph.

  Args:
    graph: sharded graphs tuple whose `nodes` and `device_edges` hold the raw
      node and edge features.
  Returns:
    Per-node outputs of width 112 (no sigmoid is applied here).
  """
  # Encoder: project node and edge features to the hidden width.
  embedded_nodes = hk.Linear(hidden_dimension)(graph.nodes)
  embedded_edges = hk.Linear(hidden_dimension)(graph.device_edges)
  graph = graph._replace(nodes = embedded_nodes, device_edges = embedded_edges)

  # Processor: the same network object is invoked once per step, and its output
  # is added residually to the current node and edge states.
  message_passing = sharded_graphnet.ShardedEdgesGraphNetwork(
      update_node_fn = node_update_fn,
      update_edge_fn = edge_update_fn,
      num_shards = num_devices
  )
  for _ in range(num_message_passing_steps):
    update = message_passing(graph)
    next_nodes = graph.nodes + update.nodes
    next_edges = graph.device_edges + update.device_edges
    graph = graph._replace(nodes = next_nodes, device_edges = next_edges)

  # Decoder: two linear layers with a ReLU in between, one output per task.
  decoder = hk.Sequential([hk.Linear(hidden_dimension), jax.nn.relu, hk.Linear(112)])
  return decoder(graph.nodes)

In [13]:
def bcast_local_devices(value):
    """Places one copy of every leaf of `value` on each local device.

    Args:
      value: a pytree; each leaf is converted with `jnp.array`.
    Returns:
      A pytree of the same structure whose leaves come from
      `jax.device_put_sharded`, given one copy of the leaf per local device.
    """
    local_devices = jax.local_devices()
    copies_needed = len(local_devices)

    def _copy_to_every_device(leaf):
      """Stacks `copies_needed` copies of one leaf across the devices."""
      leaf_array = jnp.array(leaf)
      return jax.device_put_sharded([leaf_array] * copies_needed, local_devices)

    return jax.tree_util.tree_map(_copy_to_every_device, value)

In [15]:
def reshape_broadcasted_data(data):
  '''
    Drops the leading (device) axis by keeping only the first entry.

    Node predictions / labels / masks are identical on all devices, so the
    copy at index 0 is representative.
  '''
  stacked_copies = np.array(data)
  return stacked_copies[0]
  
def remove_mask_from_data(data, mask):
  '''
    Keeps only the rows of `data` whose mask entry is non-zero.

    data.shape = [num_nodes, 112]
    mask.shape = [num_nodes, 1]
  '''
  # Flatten the column mask into a 1-D boolean selector over the rows.
  keep_row = np.array(mask).reshape(-1).astype(bool)
  kept_rows = np.compress(keep_row, data, axis = 0)
  return np.array(kept_rows)

In [16]:
import functools
import haiku as hk

@functools.partial(jax.pmap, axis_name='i')
def predict_on_graph(params, graph, label, mask):
  """Runs the network on one sharded graph (pmapped over the local devices).

  Args:
    params: replicated network parameters.
    graph: sharded graph to predict on.
    label: node targets, used only for the loss.
    mask: node mask, used only for the loss.
  Returns:
    A pair (sigmoid of the network outputs, masked loss).
  """
  node_logits = network_definition.apply(params, graph)
  masked_loss = compute_loss(params, graph, label, mask)
  node_probabilities = jax.nn.sigmoid(node_logits)
  return node_probabilities, masked_loss

def evaluate_on_full_sets(params):
  """Predicts on every partition and scores the full train / valid / test splits.

  Predictions for the real (non-padding) nodes of each partition are written
  into one dense array at the rows given by the partition's original node ids,
  and the official evaluator is then run once per split.

  Args:
    params: replicated network parameters.
  Returns:
    A tuple (train ROC-AUC, validation ROC-AUC, test ROC-AUC).
  """
  all_predictions = None

  for i in range(num_partitions):
    # Original (full-graph) ids of this partition's nodes, as a NumPy index.
    node_ids = np.asarray(dgl_graph_metis_partition[i].ndata['_ID'])
    partition = processed_graphs[f'partition_{i}']

    predictions, _ = predict_on_graph(params,
                                      partition['graph'],
                                      partition['labels'],
                                      partition['test_mask']  # Only used in the loss computation, does not affect predictions
                                      )

    # Drop the rows that belong to padding nodes.
    real_node_predictions = remove_mask_from_data(
        reshape_broadcasted_data(predictions),
        reshape_broadcasted_data(partition['padding_mask'])
        )

    if all_predictions is None:
      # Allocated lazily so the number of tasks and the dtype follow the model.
      all_predictions = np.zeros(
          (ogbn_proteins_main_graph.num_nodes(), real_node_predictions.shape[-1]),
          dtype = real_node_predictions.dtype
          )

    # Vectorised scatter; raises if the number of real-node rows does not
    # match the number of node ids.
    all_predictions[node_ids] = real_node_predictions

    if (i + 1) % 10 == 0:
      print(f'Evaluated {i + 1} / {num_partitions} subgraphs...')

  def _split_rocauc(split_name, split_labels):
    """ROC-AUC of the predictions on the nodes of one split."""
    split_node_ids = np.asarray(split_idx[split_name])
    return evaluator.eval({
        "y_true": np.array(split_labels),
        "y_pred": all_predictions[split_node_ids]
        })['rocauc']

  final_roc_train = _split_rocauc('train', train_label)
  final_roc_valid = _split_rocauc('valid', valid_label)
  final_roc_test = _split_rocauc('test', test_label)

  print()
  print(f'Final ROC on the train set {final_roc_train}')
  print(f'Final ROC on the validation set {final_roc_valid}')
  print(f'Final ROC on the test set {final_roc_test}')

  return (final_roc_train, final_roc_valid, final_roc_test)

In [17]:
import pickle
import random

# Seed Python's `random` module; train() draws its partition indices with randint.
random.seed(42)

from random import randint
from google.colab import files

# Try to follow this tutorial https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py
def compute_loss(params, graph, label, mask):
  """Masked sigmoid binary cross-entropy of the network on one graph.

  The element-wise loss is zeroed for nodes outside `mask`, summed over all
  nodes and classes, and divided by the number of masked nodes.

  Args:
    params: network parameters.
    graph: sharded graph to run the network on.
    label: per-node binary targets, shape [num_nodes, num_classes].
    mask: per-node 0/1 selector, shape [num_nodes, 1].
  Returns:
    Scalar loss. When `mask` selects no node the result is 0 (with zero
    gradients) instead of NaN.
  """
  predictions = network_definition.apply(params, graph)

  # use optax here (https://github.com/deepmind/optax/blob/master/optax/_src/loss.py#L116#L139)
  loss = optax.sigmoid_binary_cross_entropy(predictions, label)  # shape [num_nodes, num_classes]
  loss = loss * mask

  # A partition may contain no node of the requested split; clamping the
  # denominator avoids 0 / 0 = NaN, which would otherwise propagate into the
  # gradients and the optimizer state. For a non-empty 0/1 mask the value is
  # unchanged.
  num_selected = jnp.maximum(jnp.sum(mask), 1)
  loss = jnp.sum(loss) / num_selected

  return loss

def train(num_training_steps, learning_rate, results_path):
  """Trains the network on randomly drawn partitions, with periodic evaluation.

  Args:
    num_training_steps: number of update steps (one random partition each).
    learning_rate: Adam learning rate.
    results_path: accepted for the caller's convenience; not used in the body.
  Returns:
    The replicated parameters after the last step.
  """
  losses = []

  roc_train_list = []
  roc_eval_list = []
  roc_test_list = []
  roc_step = []

  # Initialise identical parameters on every device from the same PRNG key.
  init_keys = bcast_local_devices(jax.random.PRNGKey(42))
  init_graph = processed_graphs['partition_0']['graph']
  replicated_params = jax.pmap(network_definition.init, axis_name = 'i')(init_keys, init_graph)

  optimizer = optax.adam(learning_rate = learning_rate)
  replicated_opt_state = jax.pmap(optimizer.init, axis_name = 'i')(replicated_params)

  @functools.partial(jax.pmap, axis_name='i')
  def update(params, opt_state, graph, targets, mask):
    """One optimisation step, executed on every device."""
    # Per-device loss and gradients.
    loss, grads = jax.value_and_grad(compute_loss)(params, graph, targets, mask)

    # Average the gradients over the devices.     ## Question / TODO: Change to psum
    grads = jax.lax.pmean(grads, axis_name='i')

    # Averaging the loss is only needed for logging.
    loss = jax.lax.pmean(loss, axis_name='i')

    updates, opt_state = optimizer.update(updates = grads, state = opt_state)
    new_params = optax.apply_updates(params, updates)

    return new_params, opt_state, loss

  for step in range(1, num_training_steps + 1):
    # One randomly chosen partition per step.
    partition_idx = randint(0, num_partitions - 1)
    batch = processed_graphs[f'partition_{partition_idx}']

    # Labels and mask were already broadcast together with the sharded graph.
    replicated_params, replicated_opt_state, loss = update(
        replicated_params,
        replicated_opt_state,
        batch['graph'],
        batch['labels'],
        batch['train_mask']
        )

    host_loss = reshape_broadcasted_data(loss)
    print('Loss training:', host_loss)
    losses.append(host_loss)

    if step % 10 == 0:
      print()
      print(f'***************************')
      print(f'Trained on {step} graphs')
      print(f'***************************')
      print()

    if step % 200 == 0:
      print()
      print(f'*** Full evaluations after {step} training steps ***')

      roc_train, roc_eval, roc_test = evaluate_on_full_sets(replicated_params)

      roc_train_list.append(roc_train)
      roc_eval_list.append(roc_eval)
      roc_test_list.append(roc_test)
      roc_step.append(step)

  plot_loss(losses)
  plot_rocs(roc_train_list, roc_eval_list, roc_test_list, roc_step)

  return replicated_params

In [53]:
import pandas as pd
import io

def get_params_count():
  """Derives a parameter count from a Haiku tabulation of one update step.

  A throw-away update step is defined on partition 0, `hk.experimental.tabulate`
  is asked for its 'params_size' column in TSV form, the first column of the
  parsed table is summed, and half of that sum is returned.

  Returns:
    int: half of the summed 'params_size' values.
  """
  @functools.partial(jax.pmap, axis_name='i')
  def demo_update(params, opt_state, graph, targets, mask):
    # Compute the gradients on the given minibatch (individually on each device).
    loss, grads = jax.value_and_grad(compute_loss)(params, graph, targets, mask)

    # Combine the gradient across all devices.
    grads = jax.lax.pmean(grads, axis_name='i')     ## Question / TODO: Change to psum

    # Also combine the loss. Unnecessary for the update, but useful for logging.
    loss = jax.lax.pmean(loss, axis_name='i')

    # demo_opt_update is bound further down in the enclosing function; it is
    # only looked up when demo_update is actually traced.
    updates, opt_state = demo_opt_update(updates = grads, state = opt_state, )

    return optax.apply_updates(params, updates), opt_state, loss

  # Partition 0 serves as the example input.
  demo_graph = processed_graphs['partition_0']['graph']
  demo_labels = processed_graphs['partition_0']['labels']   
  demo_mask = processed_graphs['partition_0']['train_mask'] 

  demo_replicated_params = jax.pmap(network_definition.init, axis_name = 'i')(
      bcast_local_devices(jax.random.PRNGKey(42)), 
      demo_graph
      )
  
  demo_opt_init, demo_opt_update = optax.adam(learning_rate = 0.001)  
  demo_replicated_opt_state = jax.pmap(demo_opt_init, axis_name = 'i')(demo_replicated_params)

  # Render the table as TSV and strip spaces and thousands separators so the
  # sizes can be parsed as integers.
  params_counts = hk.experimental.tabulate(demo_update, columns = ['params_size'], tabulate_kwargs={"tablefmt": "tsv"})(
    demo_replicated_params, 
    demo_replicated_opt_state, 
    demo_graph, 
    demo_labels,
    demo_mask 
  ).replace(' ', '').replace(',', '')

  # read_csv treats the first TSV line as the header row.
  df = pd.read_csv(io.StringIO(params_counts), sep = '\t', index_col = False)
  vals = list(df.iloc[:, 0])
  vals = list(map(int, vals))

  total_params = 0
  for count in vals:
    total_params += count

  # NOTE(review): the sum is halved — presumably the tabulated rows count each
  # parameter twice; confirm against the raw table before relying on this number.
  return int(total_params / 2)

1813104


In [None]:
import matplotlib.pyplot as plt

def plot_loss(loss_list):
  """Plots the training loss against the 1-based iteration number."""
  iteration_numbers = range(1, len(loss_list) + 1)
  plt.plot(iteration_numbers, loss_list)

  plt.xlabel('Iteration')
  plt.ylabel('Training loss')
  plt.show()

def plot_rocs(roc_train, roc_eval, roc_test, iters):
  """Plots train / validation / test ROC curves on shared axes.

  Args:
    roc_train: ROC values on the train split, one per entry of `iters`.
    roc_eval: ROC values on the validation split.
    roc_test: ROC values on the test split.
    iters: iteration numbers at which the ROC values were recorded.
  """
  curves = (
      ('Train ROC', roc_train),
      ('Valid ROC', roc_eval),
      ('Test ROC', roc_test),
  )
  for curve_label, roc_values in curves:
    plt.plot(iters, roc_values, label = curve_label)

  plt.xlabel('Iteration')
  plt.ylabel('ROC')
  plt.legend(loc = 'upper right')

  plt.show()

In [None]:
import os
from datetime import datetime

# Timestamped experiment directory for this run.
current_time = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
exp_path = f'/content/exp_{current_time}/'
# exist_ok = False: raise instead of silently reusing an existing directory.
os.makedirs(exp_path, exist_ok = False)

# Main training loop
# NOTE(review): train() as defined in this notebook accepts results_path but
# does not write anything to it.
final_params = train(
    num_training_steps = 10000, 
    learning_rate = 0.001,
    results_path = exp_path
    )

# with open('/content/exp_2022-03-30-09:51:17/params_epochs_9000.pickle', 'rb') as f:
#     loaded_params = pickle.load(f)
# loaded_params = final_params
# evaluate_on_full_sets(loaded_params)

'''
  Previous runs (padding to power of 2)
  (1) Configuration
        learning_rate = 0.001
        num_partitions = 50
        hidden_dimension = 128
        num_message_passing_steps = 5
        num_training_steps = 1000
    ROC on the train set 0.7348797273386144
    ROC on the validation set 0.6025038939324504
    ROC on the test set 0.5896861508337246

  (2) Configuration
        learning_rate = 0.001
        num_partitions = 50
        hidden_dimension = 128
        num_message_passing_steps = 5
        num_training_steps = 3000
    ROC on the train set 0.8050085464161815
    ROC on the validation set 0.6327603823722211
    ROC on the test set 0.5078022533003436

  (3) Configuration
        learning_rate = 0.1 (Question: I think this might be too high -- based on the results in (5) with lower number of epochs)
        num_partitions = 100
        hidden_dimension = 128
        num_message_passing_steps = 5
        num_training_steps = 1000
    ROC on the train set 0.5
    ROC on the validation set 0.5 
    ROC on the test set 0.5

  (4) Configuration
        learning_rate = 0.01
        num_partitions = 100
        hidden_dimension = 128
        num_message_passing_steps = 5
        num_training_steps = 100
    ROC on the train set 0.6501172261188106
    ROC on the validation set 0.5281974299591566
    ROC on the test set 0.47652056321124514

  (5) Configuration
        learning_rate = 0.01
        num_partitions = 100
        hidden_dimension = 128
        num_message_passing_steps = 5
        num_training_steps = 500
    ROC on the train set 0.6939371049645034
    ROC on the validation set 0.559224577731843
    ROC on the test set 0.5488968392833208

  (6) Configuration
        opt: LAMB
        learning_rate: 1e-4
        num-partitions = grad100
        hidden_dimension = 128
        num_message_passing_steps = 5
        num_training_steps = 500
  ROC on the train set 0.6299712777663571
  ROC on the validation set 0.5054189612195771
  ROC on the test set 0.5083185060310427

  ********************************************

  Previous runs (padding to multiple of 8)
  (1) Configuration
        opt: LAMB
        learning_rate: 1e-4
        num-partitions = 35
        hidden_dimension = 128
        num_message_passing_steps = 5
        num_training_steps = 500
  ROC on the train set 0.6438794374674618
  ROC on the validation set 0.5162833891590899
  ROC on the test set 0.5175085147535061

  (2) Configuration
        opt: ADAM
        learning_rate: 1e-4
        num-partitions = 35
        hidden_dimension = 128
        num_message_passing_steps = 5
        num_training_steps = 10000
  ROC on the train set 0.8873957875287084
  ROC on the validation set 0.604810879077169
  ROC on the test set 0.57961973018596

  (3) Configuration
      opt: ADAM
      learning_rate = 0.001
      num_partitions = 35
      hidden_dimension = 128
      num_message_passing_steps = 5
      num_training_steps = 100
      *** Edge features = Edge features + sin(edge_features) + cost(edge_features)
  ROC on the train set 0.6798269721961256
  ROC on the validation set 0.5658498109351704
  ROC on the test set 0.5796600014265777

  (4) Configuration
      opt: ADAM
      learning_rate = 0.001
      num_partitions = 35
      hidden_dimension = 128
      num_message_passing_steps = 5
      num_training_steps = 1000
      *** Edge features = edge_features + sort of pos encoding * edge_features

  ROC on the train set 0.7458021952579157
  ROC on the validation set 0.6625978573423217
  ROC on the test set 0.6162747073675334
'''

In [None]:
###################################### FUNCTIONS FOR TESTING ######################################
# Switches for the sanity-check cells at the end of the notebook; both are off by default.
run_overfit_on_single_partition = False
run_overfit_on_demo_graph = False

In [None]:
def overfit_on_single_graph(
    num_training_steps, 
    learning_rate,
    graph,
    labels,
    mask
    ):
  """Sanity check: repeatedly trains on one fixed graph and prints the loss.

  Args:
    num_training_steps: number of update steps to run.
    learning_rate: Adam learning rate.
    graph: sharded graph to fit.
    labels: node targets for `graph`.
    mask: node mask selecting the nodes that contribute to the loss.
  """
  init_keys = bcast_local_devices(jax.random.PRNGKey(42))
  replicated_params = jax.pmap(network_definition.init, axis_name = 'i')(init_keys, graph)

  optimizer = optax.adam(learning_rate = learning_rate)
  replicated_opt_state = jax.pmap(optimizer.init, axis_name = 'i')(replicated_params)

  @functools.partial(jax.pmap, axis_name='i')
  def update(params, opt_state, graph, targets, mask):
    """One optimisation step, executed on every device."""
    # Per-device loss and gradients.
    loss, grads = jax.value_and_grad(compute_loss)(params, graph, targets, mask)

    # Average gradients and loss over the devices; the loss is only logged.
    grads = jax.lax.pmean(grads, axis_name='i')
    loss = jax.lax.pmean(loss, axis_name='i')

    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)

    return new_params, opt_state, loss

  # Every step uses the same graph.
  for step in range(1, num_training_steps + 1):
    replicated_params, replicated_opt_state, loss = update(
      replicated_params,
      replicated_opt_state,
      graph,
      labels,
      mask
      )

    print('Loss training:', reshape_broadcasted_data(loss))

    if step % 10 == 0:
      print()
      print(f'***************************')
      print(f'Trained on {step} graphs')
      print(f'***************************')
      print()

In [None]:
def get_demo_training_graph(num_nodes, num_edges):
  """Builds a small random sharded graph with random features and labels.

  Args:
    num_nodes: number of nodes of the random graph.
    num_edges: number of edges of the random graph.
  Returns:
    A sharded graphs tuple (no padding applied) whose `nodes` field is a dict
    with the keys 'inputs', 'targets' and 'train_mask'.
  """
  random_graph = dgl.rand_graph(num_nodes = num_nodes, num_edges = num_edges)

  # Random draws happen in this order: node features, edge features, labels.
  node_rows = []
  for _ in range(num_nodes):
    node_rows.append([randint(0, 7)])
  node_features = jnp.array(node_rows)

  edge_rows = []
  for _ in range(num_edges):
    edge_rows.append([0.1 * randint(0, 10) for _ in range(8)])
  edge_features = jnp.array(edge_rows)

  edge_sources, edge_targets = random_graph.edges()

  label_rows = []
  for _ in range(num_nodes):
    label_rows.append([randint(0, 1) for _ in range(112)])
  labels = jnp.array(label_rows)

  demo_tuple = jraph.GraphsTuple(
      nodes = {
          'inputs': node_features.astype(np.float32),
          'targets': labels,
          'train_mask': jnp.ones((num_nodes, 1)),  # No nodes are masked
      },
      edges = edge_features.astype(np.float32),
      senders = jnp.array(edge_sources).astype(np.int32),
      receivers = jnp.array(edge_targets).astype(np.int32),
      n_node = jnp.array([num_nodes]),
      n_edge = jnp.array([num_edges]),
      globals = None  # No global features
  )

  # The demo graph is deliberately left unpadded
  # (pad_graph_to_nearest_multiple_of_8 is not applied here).

  return sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple(
      demo_tuple,
      num_shards = num_devices
  )

def overfit_on_demo_graph(num_training_steps, learning_rate):
  """Sanity check: overfits the network on a random 16-node, 8-edge graph.

  Args:
    num_training_steps: number of update steps to run.
    learning_rate: Adam learning rate.
  """
  sharded_demo = get_demo_training_graph(num_nodes = 16, num_edges = 8)
  node_fields = sharded_demo.nodes

  # The model only consumes the 'inputs' entry; targets and mask go separately.
  model_input_graph = sharded_demo._replace(nodes = node_fields['inputs'])

  overfit_on_single_graph(
      num_training_steps = num_training_steps,
      learning_rate = learning_rate,
      graph = model_input_graph,
      labels = node_fields['targets'],
      mask = node_fields['train_mask']
  )

In [None]:
# Overfit on an existing partition
# Only runs when the flag above is switched on; trains on partition 0 with its
# train mask.
if run_overfit_on_single_partition:
  print('*** Trying to overfit on a single partition ***')
  overfit_on_single_graph(
      num_training_steps = 5000,
      learning_rate = 0.001,
      graph = processed_graphs['partition_0']['graph'],
      labels = processed_graphs['partition_0']['labels'],
      mask = processed_graphs['partition_0']['train_mask']
      )

In [None]:
# Overfit on a small random graph; only runs when the flag above is switched on.
if run_overfit_on_demo_graph:
  print('*** Trying to overfit on a random demo graph ***')
  overfit_on_demo_graph(
      num_training_steps = 1000,
      learning_rate = 0.001
  )