<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/final_version_12.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis
%pip install -q torch-scatter

!wget https://raw.githubusercontent.com/deepmind/jraph/master/jraph/experimental/sharded_graphnet.py

--2022-04-07 12:43:45--  https://raw.githubusercontent.com/deepmind/jraph/master/jraph/experimental/sharded_graphnet.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 22444 (22K) [text/plain]
Saving to: ‘sharded_graphnet.py.6’


2022-04-07 12:43:45 (14.0 MB/s) - ‘sharded_graphnet.py.6’ saved [22444/22444]



In [2]:
import jax.tools.colab_tpu
# Initialise the Colab TPU backend before any other JAX call.
jax.tools.colab_tpu.setup_tpu()

# Number of local devices; later used as the number of edge shards.
num_devices = jax.local_device_count()

# Display the devices that were found.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

# The dataset and its evaluator are looked up by the same OGB name.
_OGB_DATASET_NAME = "ogbn-proteins"

dataset = DglNodePropPredDataset(name = _OGB_DATASET_NAME)
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name = _OGB_DATASET_NAME)

# Show which inputs Evaluator.eval expects.
print(evaluator.expected_input_format)

Using backend: pytorch


==== Expected input format of Evaluator for ogbn-proteins
{'y_true': y_true, 'y_pred': y_pred}
- y_true: numpy ndarray or torch tensor of shape (num_node, num_task)
- y_pred: numpy ndarray or torch tensor of shape (num_node, num_task)
where y_pred stores score values (for computing ROC-AUC),
num_task is 112, and each row corresponds to one node.



In [4]:
import jax.numpy as jnp
import torch
from torch_scatter import scatter

# There is only one graph in Node Property Prediction datasets
ogbn_proteins_main_graph, ogbn_proteins_main_labels = dataset[0]

# NOTE: ndata['species'] is overwritten here with the mean of the edge features
# of the edges whose source is each node (scatter over edges()[0]). From here on
# the key holds these real-valued vectors, not the species id described in
# (1.2) of the notes below.
ogbn_proteins_main_graph.ndata['species'] = scatter(
    ogbn_proteins_main_graph.edata['feat'],
    ogbn_proteins_main_graph.edges()[0],
    dim = 0,
    dim_size = ogbn_proteins_main_graph.num_nodes(),
    reduce = 'mean',
)
'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Diameter ~ 9 (https://cs.stanford.edu/people/jure/pubs/ogb-neurips20.pdf)
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Split labels -- binary values (presence of protein functions):
#   train (86619, 112), valid (21236, 112), test (24679, 112)
train_label, valid_label, test_label = (
    dataset.labels[split_idx[split]] for split in ('train', 'valid', 'test')
)

# One (num_nodes, 1) column per split, set to 1 on the nodes of that split.
train_mask, valid_mask, test_mask = (
    jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx[split])].set(1)
    for split in ('train', 'valid', 'test')
)

In [5]:
import jraph

# From https://colab.research.google.com/github/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb#scrollTo=7vEmAsr5bKN8
def _nearest_multiple_of_8(x: int) -> int:
  """Computes the nearest power of two greater than x for padding."""
  if x % 8 == 0:
    return x
  else:
    return (x // 8 + 1) * 8 

def pad_graph_to_nearest_multiple_of_8(
    graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Pads a batched `GraphsTuple` so node and edge counts are multiples of 8.

  The total node and edge counts are rounded up to the next multiple of 8
  (counts that already are multiples of 8 are kept). Padding is done with
  `jraph.pad_with_graphs`, which needs at least one padding node and one
  padding graph, so one extra node and one extra graph are added on top.
  For example, with 7 nodes, 5 edges and 3 graphs:
    7 nodes  --> 8 (next multiple of 8) --> 9 nodes (+1 padding node)
    5 edges  --> 8 edges
    3 graphs --> 4 graphs (+1 padding graph)
  Args:
    graphs_tuple: a batched `GraphsTuple` (can be batch size 1).
  Returns:
    The padded `GraphsTuple`.
  """
  # Convert the summed counts to Python ints: they are used as padded sizes.
  num_real_nodes = int(jnp.sum(graphs_tuple.n_node))
  num_real_edges = int(jnp.sum(graphs_tuple.n_edge))

  # Add 1 since we need at least one padding node for pad_with_graphs.
  pad_nodes_to = _nearest_multiple_of_8(num_real_nodes) + 1
  pad_edges_to = _nearest_multiple_of_8(num_real_edges)
  # Add 1 since we need at least one padding graph for pad_with_graphs.
  # The graph count is not rounded because the batch size is fixed.
  pad_graphs_to = graphs_tuple.n_node.shape[0] + 1
  return jraph.pad_with_graphs(graphs_tuple, pad_nodes_to, pad_edges_to,
                               pad_graphs_to)

In [6]:
import numpy as np
import torch
import jraph
import sharded_graphnet

from sklearn.preprocessing import OneHotEncoder

# Seed NumPy's global RNG for reproducibility.
np.random.seed(42)

# NOTE(review): `enc` does not appear to be used anywhere else in this
# notebook. ndata['species'] holds real-valued mean edge features at this
# point, so a one-hot fit over it presumably creates one category per distinct
# value. Consider removing it.
enc = OneHotEncoder().fit(ogbn_proteins_main_graph.ndata['species'])

def dgl_graph_to_jraph(node_ids, labels, train_mask, valid_mask, test_mask):
  """Builds a padded, edge-sharded jraph graph for one node-induced subgraph.

  Args:
    node_ids: ids (in `ogbn_proteins_main_graph`) of the nodes to keep.
    labels: per-node targets, row-aligned with `node_ids`.
    train_mask, valid_mask, test_mask: per-node split masks, row-aligned with
      `node_ids`.

  Returns:
    The result of `sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple`
    with `num_devices` shards. Its `nodes` field is a dict with keys 'inputs',
    'targets', 'train_mask', 'valid_mask', 'test_mask' and 'padding_mask'.
  """
  # Local import: the notebook's top-level `import dgl` lives in a later cell,
  # so this function must not rely on it having been run.
  import dgl

  # NOTE(review): assumes the subgraph keeps the order of `node_ids`, so the
  # per-node arrays passed in stay aligned with its nodes. Confirm this.
  subgraph = dgl.node_subgraph(ogbn_proteins_main_graph, node_ids)

  node_features = jnp.array(subgraph.ndata['species'])

  # Query the edge list once and split it into senders / receivers.
  edge_sources, edge_destinations = subgraph.edges()
  senders = jnp.array(edge_sources)
  receivers = jnp.array(edge_destinations)

  # Edges -- the 8-dimensional edge features
  edges = jnp.array(subgraph.edata['feat'])

  in_tuple = jraph.GraphsTuple(
      nodes = {
          'inputs': node_features.astype(np.float32),
          'targets': labels,
          'train_mask': train_mask,
          'valid_mask': valid_mask,
          'test_mask': test_mask,
          # 1 for every real node. pad_with_graphs fills padded nodes with 0,
          # so this mask lets predictions on padding nodes be dropped before
          # handing them to the ogbn evaluator.
          'padding_mask': jnp.ones((node_features.shape[0], 1)),
      },
      senders = senders.astype(np.int32),
      receivers = receivers.astype(np.int32),
      edges = edges.astype(np.float32),
      n_node = jnp.array([subgraph.num_nodes()]),
      n_edge = jnp.array([subgraph.num_edges()]),
      globals = None  # No global features
  )

  in_tuple = pad_graph_to_nearest_multiple_of_8(in_tuple)

  return sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple(
      in_tuple,
      num_shards = num_devices
  )
  
def get_labels_for_subgraph(node_ids):
  """Returns the rows of the full-graph label matrix for `node_ids` as a jnp array."""
  selected_rows = ogbn_proteins_main_labels.index_select(0, node_ids)
  return jnp.array(selected_rows)

In [7]:
def preprocess_graph_to_jraph(dgl_graph_metis_partition, num_partitions):
  """Converts each METIS partition into a sharded jraph graph plus its node data.

  Returns:
    A dict keyed 'partition_{idx}'. Each value holds the 'graph' (with its
    'inputs' node features as `nodes`), its 'labels', the split masks and
    the 'padding_mask'.
  """
  def _convert_partition(idx):
    node_ids = dgl_graph_metis_partition[idx].ndata['_ID']
    node_ids_jnp = jnp.array(node_ids)

    graph = dgl_graph_to_jraph(
        node_ids,
        get_labels_for_subgraph(node_ids),
        train_mask = train_mask.at[node_ids_jnp].get(),
        valid_mask = valid_mask.at[node_ids_jnp].get(),
        test_mask = test_mask.at[node_ids_jnp].get(),
    )

    node_data = graph.nodes
    return {
        'graph': graph._replace(nodes = node_data['inputs']),
        'labels': node_data['targets'],
        'train_mask': node_data['train_mask'],
        'valid_mask': node_data['valid_mask'],
        'test_mask': node_data['test_mask'],
        'padding_mask': node_data['padding_mask'],
    }

  return {f'partition_{idx}': _convert_partition(idx) for idx in range(num_partitions)}

def bcast_local_devices(value):
  """Places one copy of every leaf of the pytree `value` on each local device.

  Each leaf is converted with `jnp.array` and passed to
  `jax.device_put_sharded` with one copy per local device.
  """
  local_devices = jax.local_devices()

  def _copy_to_every_device(leaf):
    as_array = jnp.array(leaf)
    copies = [as_array] * len(local_devices)
    return jax.device_put_sharded(copies, local_devices)

  return jax.tree_util.tree_map(_copy_to_every_device, value)

def reshape_broadcasted_data(data):
  '''
    Node predictions / Labels / Masks are identical on all the devices, so this
    returns a NumPy copy of the first device's slice, which drops the leading
    device axis.
  '''
  as_numpy = np.array(data)
  first_device_slice = as_numpy[0]
  return first_device_slice
  
def remove_mask_from_data(data, mask):
  '''
    Keeps only the rows of `data` whose entry in `mask` is truthy.

    data.shape = [num_nodes, 112]
    mask.shape = [num_nodes, 1]
  '''
  keep_rows = np.array(mask).reshape(-1).astype(bool)
  kept = np.compress(keep_rows, data, axis = 0)
  return np.array(kept)

In [8]:
import matplotlib.pyplot as plt

def plot_loss(loss_list):
  """Plots the training loss per iteration (iterations are numbered from 1).

  Args:
    loss_list: sequence of loss values, one per training iteration.
  """
  # Explicit figure/axes instead of the pyplot state machine, so the plot
  # never draws into an unrelated figure that happens to be current.
  fig, ax = plt.subplots()
  ax.plot(range(1, len(loss_list) + 1), loss_list)
  ax.set(title = 'Training loss', xlabel = 'Iteration', ylabel = 'Training loss')

  plt.show()

def plot_rocs(roc_train, roc_eval, roc_test, iters):
  """Plots train / valid / test ROC-AUC against the iteration they were measured at.

  Args:
    roc_train, roc_eval, roc_test: ROC-AUC values, one per entry of `iters`.
    iters: iterations at which the ROC-AUC values were computed.
  """
  fig, ax = plt.subplots()
  for values, label in ((roc_train, 'Train ROC'),
                        (roc_eval, 'Valid ROC'),
                        (roc_test, 'Test ROC')):
    ax.plot(iters, values, label = label)

  ax.set(title = 'ROC-AUC during training', xlabel = 'Iteration', ylabel = 'ROC')

  # 'best' avoids covering rising ROC curves, which a fixed 'upper right' would.
  ax.legend(loc = 'best')
  plt.show()

In [13]:
import csv

def append_row_to_csv(file_path, values):
  """Appends one row to a CSV file, creating the file if it does not exist.

  Args:
    file_path: path of the CSV file.
    values: iterable of cell values, written as a single comma-separated row.
  """
  # The csv docs require newline='' so the writer controls line endings
  # (otherwise blank lines appear on platforms that translate '\n').
  # Leaving the `with` block flushes and closes the file.
  with open(file_path, 'a', newline = '') as csvfile:
    csv.writer(csvfile, delimiter = ',').writerow(values)

In [18]:
import io
import dgl
import haiku as hk
import jax
import optax
import functools
import random
import pandas as pd

from typing import Sequence
from random import randint
from datetime import datetime

# Seed Python's `random` module, which picks the training partition at each step.
random.seed(a = 42)

def run_for_configuration(config, results_path):
  """Trains a sharded GraphNet on METIS partitions of ogbn-proteins and logs metrics.

  The full graph is split into `num_partitions` METIS partitions. Each training
  step updates the model on one randomly chosen partition that contains at
  least one training node. Every `evaluate_every` steps the model predicts all
  nodes, ROC-AUC is computed on the official splits, and one row is appended
  to `results_path`.

  Args:
    config: dict with keys 'num_partitions', 'hidden_dimension',
      'num_message_passing_steps', 'num_training_steps' and 'evaluate_every'.
    results_path: CSV file that receives one row per evaluation. The columns
      are parameters, partitions, hidden dimension, message passing steps,
      step, ROC train / valid / test, average loss, running time and memory
      usage.

  Raises:
    ValueError: if no partition contains a training node.
    RuntimeError: if, at evaluation time, the partitions do not yield exactly
      one prediction for every node of the full graph.
  """
  num_partitions = config['num_partitions']
  hidden_dimension = config['hidden_dimension']
  num_message_passing_steps = config['num_message_passing_steps']
  num_training_steps = config['num_training_steps']
  evaluate_every = config['evaluate_every']

  # Create partitions
  dgl_graph_metis_partition = dgl.metis_partition(ogbn_proteins_main_graph, num_partitions, balance_edges = True)

  # Convert graphs to Jraph GraphsTuple
  processed_graphs = preprocess_graph_to_jraph(dgl_graph_metis_partition, num_partitions)

  ####################
  # Network definition
  ####################
  def _mlp_with_layer_norm():
    """Builds the 2-layer ReLU MLP + LayerNorm used by both update functions."""
    return hk.Sequential([
        hk.nets.MLP(output_sizes = [hidden_dimension, hidden_dimension], activation = jax.nn.relu, activate_final = False),
        hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
    ])

  @jraph.concatenated_args
  def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
    """Node update function for graph net."""
    return _mlp_with_layer_norm()(feats)

  @jraph.concatenated_args
  def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
    """Edge update function for graph net."""
    return _mlp_with_layer_norm()(feats)

  @hk.without_apply_rng
  @hk.transform
  def network_definition(graph):
    """Defines a graph neural network.

    Encodes nodes and device edges to `hidden_dimension`, applies
    `num_message_passing_steps` residual sharded GraphNet steps, then decodes
    every node to 112 logits.

    Args:
      graph: GraphsTuple the network processes.
    Returns:
      Decoded nodes (logits).
    """
    graph = graph._replace(
        nodes = hk.Linear(hidden_dimension)(graph.nodes),
        device_edges = hk.Linear(hidden_dimension)(graph.device_edges)
    )

    sharded_gn = sharded_graphnet.ShardedEdgesGraphNetwork(
        update_node_fn = node_update_fn,
        update_edge_fn = edge_update_fn,
        num_shards = num_devices
        )

    for _ in range(num_message_passing_steps):
      residual_graph = sharded_gn(graph)
      graph = graph._replace(
          nodes = graph.nodes + residual_graph.nodes,
          device_edges = graph.device_edges + residual_graph.device_edges
      )

    decoder = hk.Sequential([hk.Linear(hidden_dimension), jax.nn.relu, hk.Linear(112)])
    return decoder(graph.nodes)

  ######
  # Loss
  ######
  def _masked_loss(logits, label, mask):
    """Per-task BCE summed over the masked nodes, divided by the number of masked nodes.

    The denominator is clamped to at least 1. Otherwise a graph with no masked
    node would give 0 / 0 = NaN, which also turns the gradients (and, after the
    optimizer step, the parameters) into NaN.
    """
    # use optax here (https://github.com/deepmind/optax/blob/master/optax/_src/loss.py#L116#L139)
    loss = optax.sigmoid_binary_cross_entropy(logits, label)  # shape [num_nodes, num_classes]
    loss = loss * mask
    return jnp.sum(loss) / jnp.maximum(jnp.sum(mask), 1.0)

  def compute_loss(params, graph, label, mask):
    """Runs the network on `graph` and returns the masked loss of its logits."""
    return _masked_loss(network_definition.apply(params, graph), label, mask)

  @functools.partial(jax.pmap, axis_name='i')
  def predict_on_graph(params, graph, label, mask):
    """Returns (sigmoid probabilities, masked loss) from a single forward pass."""
    decoded_nodes = network_definition.apply(params, graph)
    loss = _masked_loss(decoded_nodes, label, mask)
    return jax.nn.sigmoid(decoded_nodes), loss

  #########################
  # Evaluations on full set
  #########################
  def evaluate_on_full_sets(params, dgl_graph_metis_partition, processed_graphs, num_partitions):
    """Predicts every node of the full graph and returns (train, valid, test) ROC-AUC.

    Each partition's predictions are written into one array indexed by the
    nodes' ids in the full graph, so the official split indices can select
    rows directly.
    """
    num_nodes = ogbn_proteins_main_graph.num_nodes()
    all_predictions = None
    node_is_covered = np.zeros(num_nodes, dtype = bool)

    for i in range(num_partitions):
      node_ids = np.asarray(dgl_graph_metis_partition[i].ndata['_ID'])
      partition = processed_graphs[f'partition_{i}']

      predictions, _ = predict_on_graph(params,
                                        partition['graph'],
                                        partition['labels'],
                                        partition['test_mask']  # Only used in the loss computation, does not affect predictions
                                        )

      # Drop the rows of padding nodes. The remaining rows are taken to follow
      # the order of `node_ids`.
      real_node_predictions = remove_mask_from_data(
          reshape_broadcasted_data(predictions),
          reshape_broadcasted_data(partition['padding_mask'])
          )
      if real_node_predictions.shape[0] != node_ids.shape[0]:
        raise RuntimeError(
            f'Partition {i}: got {real_node_predictions.shape[0]} predictions '
            f'for {node_ids.shape[0]} nodes.')

      if all_predictions is None:
        all_predictions = np.zeros((num_nodes,) + real_node_predictions.shape[1:],
                                   dtype = real_node_predictions.dtype)
      all_predictions[node_ids] = real_node_predictions
      node_is_covered[node_ids] = True

      if (i + 1) % 10 == 0:
        print(f'Evaluated {i + 1} / {num_partitions} subgraphs...')

    if all_predictions is None or not node_is_covered.all():
      raise RuntimeError('The partitions do not cover every node of the graph.')

    def _roc_on_split(split_name, split_labels):
      """ROC-AUC of the predictions on the nodes of one official split."""
      return evaluator.eval({
          "y_true": np.array(split_labels),
          "y_pred": all_predictions[np.asarray(split_idx[split_name])]
          })['rocauc']

    final_roc_train = _roc_on_split('train', train_label)
    final_roc_valid = _roc_on_split('valid', valid_label)
    final_roc_test = _roc_on_split('test', test_label)

    print()
    print(f'Final ROC on the train set {final_roc_train}')
    print(f'Final ROC on the validation set {final_roc_valid}')
    print(f'Final ROC on the test set {final_roc_test}')

    return (final_roc_train, final_roc_valid, final_roc_test)

  ####################
  # Training procedure
  ####################
  def train(num_training_steps, evaluate_every):
    """Runs the training loop and appends one CSV row every `evaluate_every` steps."""
    # Only partitions with at least one training node carry a training signal;
    # sampling the others would only produce empty (zero) updates.
    trainable_partition_ids = [
        partition_idx for partition_idx in range(num_partitions)
        if np.asarray(processed_graphs[f'partition_{partition_idx}']['train_mask']).any()
    ]
    if not trainable_partition_ids:
      raise ValueError('No partition contains a training node.')

    losses = []

    replicated_params = jax.pmap(network_definition.init, axis_name = 'i')(
        bcast_local_devices(jax.random.PRNGKey(42)),
        processed_graphs['partition_0']['graph']
        )

    opt_init, opt_update = optax.adam(learning_rate = 0.001)
    replicated_opt_state = jax.pmap(opt_init, axis_name = 'i')(replicated_params)

    @functools.partial(jax.pmap, axis_name='i')
    def update(params, opt_state, graph, targets, mask):
      # Compute the gradients on the given minibatch (individually on each device).
      loss, grads = jax.value_and_grad(compute_loss)(params, graph, targets, mask)

      # Combine the gradient across all devices.
      grads = jax.lax.pmean(grads, axis_name='i')

      # Also combine the loss. Unnecessary for the update, but useful for logging.
      loss = jax.lax.pmean(loss, axis_name='i')

      updates, opt_state = opt_update(updates = grads, state = opt_state)

      return optax.apply_updates(params, updates), opt_state, loss

    ############################################################
    ### Count the number of parameters in this configuration ###
    ############################################################
    params_counts = hk.experimental.tabulate(update, columns = ['params_size'], tabulate_kwargs={"tablefmt": "tsv"})(
      replicated_params,
      replicated_opt_state,
      processed_graphs['partition_0']['graph'],
      processed_graphs['partition_0']['labels'],
      processed_graphs['partition_0']['train_mask']
    ).replace(' ', '').replace(',', '')

    df = pd.read_csv(io.StringIO(params_counts), sep = '\t', index_col = False)
    # NOTE(review): halved as in the previous implementation, presumably
    # because the table reports every parameter twice. TODO confirm.
    total_params = sum(int(count) for count in df.iloc[:, 0]) // 2

    # Train
    for idx in range(num_training_steps):
      random_partition_idx = random.choice(trainable_partition_ids)
      random_partition = processed_graphs[f'partition_{random_partition_idx}']

      graph = random_partition['graph']
      labels = random_partition['labels']   # Automatically broadcasted by the sharded graph net
      mask = random_partition['train_mask'] # Automatically broadcasted by the sharded graph net

      replicated_params, replicated_opt_state, loss = update(
          replicated_params,
          replicated_opt_state,
          graph,
          labels,
          mask
          )

      step_loss = reshape_broadcasted_data(loss)
      print('Loss training:', step_loss)
      losses.append(step_loss)

      step = idx + 1

      if step % 10 == 0:
        print()
        print('***************************')
        print(f'Trained on {step} graphs')
        print('***************************')
        print()

      if step % evaluate_every == 0:
        print()
        print(f'*** Full evaluations after {step} training steps ***')

        roc_train, roc_eval, roc_test = evaluate_on_full_sets(replicated_params, dgl_graph_metis_partition, processed_graphs, num_partitions)

        # Average training loss since the previous evaluation.
        avg_loss = sum(losses) / len(losses)
        losses = []

        append_row_to_csv(results_path, [
          str(total_params),
          str(num_partitions),
          str(hidden_dimension),
          str(num_message_passing_steps),
          str(step),
          str(roc_train),
          str(roc_eval),
          str(roc_test),
          str(avg_loss),
          'N/A', # Running time
          'N/A', # Memory consumption
        ])

  ### Now actually run the training loop and get the evaluations for this run
  train(
      num_training_steps = num_training_steps,
      evaluate_every = evaluate_every
  )

In [None]:
# Timestamped results file; each evaluation appends one row to it.
current_time = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
results_path = f'/content/results_{current_time}.csv'

# Header row. The column order matches the rows appended by run_for_configuration.
results_header = [
  'Parameters',
  'Partitions',
  'Hidden dimension',
  'Message passing steps',
  'Epoch',
  'ROC train',
  'ROC eval',
  'ROC test',
  'Loss',
  'Running time',
  'Memory usage'
]
append_row_to_csv(results_path, results_header)

experiment_config = {
    'num_partitions': 100,
    'hidden_dimension': 256,
    'num_message_passing_steps': 3,
    'num_training_steps': 5000,
    'evaluate_every': 100
}
run_for_configuration(config = experiment_config, results_path = results_path)

Convert a graph into a bidirected graph: 3.535 seconds
Construct multi-constraint weights: 0.003 seconds
Metis partitioning: 41.338 seconds
Split the graph: 0.820 seconds
Construct subgraphs: 0.044 seconds
Loss training: 89.44355
Loss training: 70.50785
