<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/final_version_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis

!wget https://raw.githubusercontent.com/deepmind/jraph/master/jraph/experimental/sharded_graphnet.py

  Building wheel for dm-haiku (setup.py) ... done
  Building wheel for jaxline (setup.py) ... done
  Building wheel for ml-collections (setup.py) ... done
  Building wheel for littleutils (setup.py) ... done
  Building wheel for metis (setup.py) ... done
--2022-03-24 11:07:21--  https://raw.githubusercontent.com/deepmind/jraph/master/jraph/experimental/sharded_graphnet.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|1

In [2]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

num_devices = jax.local_device_count()
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name = "ogbn-proteins")
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name = 'ogbn-proteins')

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch

Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip
Downloaded 0.21 GB: 100%|██████████| 216/216 [00:25<00:00,  8.62it/s]
Extracting dataset/proteins.zip
Loading necessary files...
This might take a while.
Processing graphs...
100%|██████████| 1/1 [00:01<00:00,  1.74s/it]
Converting graphs into DGL objects...
100%|██████████| 1/1 [00:00<00:00,  2.11it/s]
Saving...

In [4]:
import jax.numpy as jnp

# There is only one graph in Node Property Prediction datasets
ogbn_proteins_main_graph, ogbn_proteins_main_labels = dataset[0]

'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features; each undirected edge is stored
          in both directions, hence the factor of 2
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Get split labels
train_label = dataset.labels[split_idx['train']]  # (86619, 112) -- binary values (presence of protein functions)
valid_label = dataset.labels[split_idx['valid']]  # (21236, 112) -- binary values (presence of protein functions)
test_label = dataset.labels[split_idx['test']]    # (24679, 112) -- binary values (presence of protein functions)

# Create masks
train_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['train'])].set(1)
valid_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['valid'])].set(1)
test_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['test'])].set(1)
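The masks above are dense 0/1 indicator columns over all nodes. The NumPy sketch below builds the same kind of mask on a hypothetical 6-node split (`toy_split` stands in for `split_idx`; it is not ogbn-proteins data), checks that the three masks are disjoint and cover every node, and shows a mask selecting rows of a label matrix:

```python
import numpy as np

def make_mask(num_nodes: int, idx) -> np.ndarray:
  """Return a (num_nodes, 1) float mask with ones at the given node indices."""
  mask = np.zeros((num_nodes, 1), dtype=np.float32)
  mask[np.asarray(idx, dtype=np.int64)] = 1.0
  return mask

# Hypothetical 6-node split standing in for split_idx.
toy_split = {"train": [0, 2, 3], "valid": [1], "test": [4, 5]}
toy_num_nodes = 6
toy_masks = {name: make_mask(toy_num_nodes, idx) for name, idx in toy_split.items()}

# Every node belongs to exactly one split.
coverage = toy_masks["train"] + toy_masks["valid"] + toy_masks["test"]
assert np.all(coverage == 1.0)

# A mask selects rows, e.g. the labels of the training nodes (2 hypothetical tasks).
toy_labels = np.arange(toy_num_nodes * 2).reshape(toy_num_nodes, 2)
toy_train_labels = toy_labels[toy_masks["train"].reshape(-1).astype(bool)]
print(toy_train_labels.shape)  # (3, 2)
```

For ogbn-proteins the three split sizes (86,619 + 21,236 + 24,679) add up to all 132,534 nodes, so the same coverage check should hold for `train_mask`, `valid_mask` and `test_mask`.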

In [5]:
import jraph

# From https://colab.research.google.com/github/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb#scrollTo=7vEmAsr5bKN8
def _nearest_bigger_power_of_two(x: int) -> int:
  """Returns the smallest power of two that is >= x (and at least 2), for padding."""
  y = 2
  while y < x:
    y *= 2
  return y

def pad_graph_to_nearest_power_of_two(
    graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Pads a batched `GraphsTuple` to the nearest power of two.
  For example, if a `GraphsTuple` has 7 nodes, 5 edges and 3 graphs, this method
  would pad the `GraphsTuple` nodes and edges:
    7 nodes --> 8 nodes (2^3)
    5 edges --> 8 edges (2^3)
  And since padding is accomplished using `jraph.pad_with_graphs`, an extra
  graph and node is added:
    8 nodes --> 9 nodes
    3 graphs --> 4 graphs
  Args:
    graphs_tuple: a batched `GraphsTuple` (can be batch size 1).
  Returns:
    A graphs_tuple batched to the nearest power of two.
  """
  # Add 1 since we need at least one padding node for pad_with_graphs.
  pad_nodes_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_node)) + 1
  pad_edges_to = _nearest_bigger_power_of_two(jnp.sum(graphs_tuple.n_edge))
  # Add 1 since we need at least one padding graph for pad_with_graphs.
  # We do not pad to nearest power of two because the batch size is fixed.
  pad_graphs_to = graphs_tuple.n_node.shape[0] + 1
  return jraph.pad_with_graphs(graphs_tuple, pad_nodes_to, pad_edges_to,
                               pad_graphs_to)
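The padding arithmetic can be checked by hand. This pure-Python sketch reproduces the size computation of `pad_graph_to_nearest_power_of_two` (it does not call jraph; `padded_sizes` is a hypothetical helper) and applies it to the docstring example and to the first METIS partition printed further down (3,586 nodes):

```python
def nearest_bigger_power_of_two(x: int) -> int:
  """Smallest power of two >= x (and at least 2); same loop as _nearest_bigger_power_of_two."""
  y = 2
  while y < x:
    y *= 2
  return y

def padded_sizes(n_node: int, n_edge: int, n_graph: int):
  """(nodes, edges, graphs) after pad_graph_to_nearest_power_of_two, computed by hand."""
  pad_nodes_to = nearest_bigger_power_of_two(n_node) + 1  # +1 node for the padding graph
  pad_edges_to = nearest_bigger_power_of_two(n_edge)
  pad_graphs_to = n_graph + 1                             # +1 padding graph
  return pad_nodes_to, pad_edges_to, pad_graphs_to

print(padded_sizes(7, 5, 3))                  # (9, 8, 4), the docstring example
print(nearest_bigger_power_of_two(3586) + 1)  # 4097 padded nodes for a 3586-node partition
```

So a 3,586-node partition carries 511 padding nodes. Any power of two that is at least 8 is divisible by 8, so the padded edge count can be split into equal parts across the 8 TPU cores listed earlier, which is convenient if the edges are sharded evenly across devices.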

In [42]:
import numpy as np
import torch
import dgl
import jraph

import sharded_graphnet

def dgl_graph_to_jraph(node_ids, labels, train_mask, valid_mask, test_mask):
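  # METIS partitions carry no node/edge data, so take the induced subgraph of the
  # original graph to recover the node ('species') and edge ('feat') features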
  dgl_graph_with_features = dgl.node_subgraph(ogbn_proteins_main_graph, node_ids)

  node_features = jnp.array(dgl_graph_with_features.ndata['species'])
  senders = jnp.array(dgl_graph_with_features.edges()[0])
  receivers = jnp.array(dgl_graph_with_features.edges()[1])

  # Edge features: the 8-dimensional features of each (directed) edge
  edges = jnp.array(dgl_graph_with_features.edata['feat'])

  n_node = jnp.array([dgl_graph_with_features.num_nodes()])
  n_edge = jnp.array([dgl_graph_with_features.num_edges()])

  in_tuple = jraph.GraphsTuple(
            nodes = node_features.astype(np.float32),
            senders = senders.astype(np.int32), 
            receivers = receivers.astype(np.int32),
            edges = edges.astype(np.float32),  
            n_node = n_node, 
            n_edge = n_edge,
            globals = None  # No global features
          )
  
  print(in_tuple.nodes.shape)

  in_tuple = in_tuple._replace(
      nodes = {
          'inputs': in_tuple.nodes, 
          'targets': labels, 
          'train_mask': train_mask, 
          'valid_mask': valid_mask, 
          'test_mask': test_mask,
          'padding_mask': jnp.ones_like(in_tuple.nodes)
          }
  )

  in_tuple = pad_graph_to_nearest_power_of_two(in_tuple)
  
  return sharded_graphnet.graphs_tuple_to_broadcasted_sharded_graphs_tuple(
      in_tuple,
      num_shards = num_devices
      )
  
def get_labels_for_subgraph(node_ids):
  return jnp.array(ogbn_proteins_main_labels.index_select(0, node_ids))
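`dgl.node_subgraph` is what puts the features back. Conceptually, an induced subgraph keeps only the edges whose two endpoints are both in `node_ids`, relabels nodes to local indices 0..n-1, and carries the edge features of the kept edges along. The NumPy sketch below illustrates that idea on hypothetical toy arrays; it is not DGL's implementation, and `induced_subgraph` is a hypothetical helper:

```python
import numpy as np

def induced_subgraph(senders, receivers, edge_feats, node_ids):
  """Keep edges with both endpoints in node_ids; local node i is global node node_ids[i]."""
  senders = np.asarray(senders)
  receivers = np.asarray(receivers)
  node_ids = np.asarray(node_ids)
  num_global = int(max(senders.max(), receivers.max(), node_ids.max())) + 1
  # Map global IDs to local IDs; -1 marks nodes outside the subgraph.
  global_to_local = np.full(num_global, -1, dtype=np.int64)
  global_to_local[node_ids] = np.arange(len(node_ids))
  keep = (global_to_local[senders] >= 0) & (global_to_local[receivers] >= 0)
  return (global_to_local[senders[keep]],
          global_to_local[receivers[keep]],
          np.asarray(edge_feats)[keep])

# Toy path 0-1-2-3 with both directions stored, and 8-dim edge features.
toy_senders = [0, 1, 1, 2, 2, 3]
toy_receivers = [1, 0, 2, 1, 3, 2]
toy_edge_feats = np.arange(6 * 8, dtype=np.float32).reshape(6, 8)

s, r, e = induced_subgraph(toy_senders, toy_receivers, toy_edge_feats, node_ids=[2, 1])
print(s, r, e.shape)  # [1 0] [0 1] (2, 8)
```

Edges that cross partition boundaries are dropped, which is the price of training on partitions instead of the full graph.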

In [7]:
import dgl

'''
  Partition the graph with METIS, balancing the number of edges across partitions.
  Note:
    The resulting subgraphs do not contain the node/edge data of the input graph (https://docs.dgl.ai/generated/dgl.metis_partition.html)
'''
num_partitions = 50  ## TODO: find a way to use fewer partitions (< 50)

dgl_graph_metis_partition = dgl.metis_partition(ogbn_proteins_main_graph, num_partitions, balance_edges = True)

Convert a graph into a bidirected graph: 2.604 seconds
Construct multi-constraint weights: 0.016 seconds
Metis partitioning: 27.301 seconds
Split the graph: 0.740 seconds
Construct subgraphs: 0.031 seconds
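Because each partition is later used as an induced subgraph, every edge whose endpoints fall in different partitions is lost during training. A small NumPy sketch to measure that edge-cut fraction (both helpers are hypothetical, shown on a toy 4-cycle):

```python
import numpy as np

def assignment_from_partitions(node_id_lists, num_nodes):
  """Build a per-node partition index from per-partition global node ID lists."""
  part_of_node = np.full(num_nodes, -1, dtype=np.int64)
  for part, ids in enumerate(node_id_lists):
    part_of_node[np.asarray(ids)] = part
  assert np.all(part_of_node >= 0), "every node must belong to some partition"
  return part_of_node

def edge_cut_fraction(senders, receivers, part_of_node) -> float:
  """Fraction of edges whose endpoints are assigned to different partitions."""
  part_of_node = np.asarray(part_of_node)
  cut = part_of_node[np.asarray(senders)] != part_of_node[np.asarray(receivers)]
  return float(cut.mean())

# Toy 4-cycle (both directions stored) split into partitions {0, 1} and {2, 3}.
toy_senders = [0, 1, 1, 2, 2, 3, 3, 0]
toy_receivers = [1, 0, 2, 1, 3, 2, 0, 3]
parts = assignment_from_partitions([[0, 1], [2, 3]], num_nodes=4)
print(edge_cut_fraction(toy_senders, toy_receivers, parts))  # 0.5
```

On the real graph one could pass `[dgl_graph_metis_partition[i].ndata['_ID'].numpy() for i in range(num_partitions)]` together with the two tensors returned by `ogbn_proteins_main_graph.edges()` (converted with `.numpy()`); this is untested on the full graph and may be slow and memory-heavy with about 79M directed edges.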


In [45]:
# Convert graphs to Jraph GraphsTuple
processed_graphs = {}

for idx in range(num_partitions):
  node_ids = dgl_graph_metis_partition[idx].ndata['_ID']

  labels = get_labels_for_subgraph(node_ids)
  graph = dgl_graph_to_jraph(node_ids, 
                             labels, 
                             train_mask = train_mask.at[jnp.array(node_ids)].get(),
                             valid_mask = valid_mask.at[jnp.array(node_ids)].get(),
                             test_mask = test_mask.at[jnp.array(node_ids)].get()
                             )

  processed_graphs[f'partition_{idx}'] = {
      'graph': graph._replace(nodes = graph.nodes['inputs']), 
      'labels': graph.nodes['targets'],
      'train_mask': graph.nodes['train_mask'],
      'valid_mask': graph.nodes['valid_mask'],
      'test_mask': graph.nodes['test_mask'],
      'padding_mask': graph.nodes['padding_mask']
      }

(3586, 1)
(3596, 1)
(1187, 1)
(2328, 1)
(1954, 1)
(3563, 1)
(5133, 1)
(4795, 1)
(2998, 1)
(3399, 1)
(3412, 1)
(1593, 1)
(2776, 1)
(2618, 1)
(2722, 1)
(4848, 1)
(2775, 1)
(2768, 1)
(2408, 1)
(2391, 1)
(2348, 1)
(2673, 1)
(2114, 1)
(2264, 1)
(2902, 1)
(3894, 1)
(2474, 1)
(1555, 1)
(2207, 1)
(2101, 1)
(2285, 1)
(759, 1)
(681, 1)
(678, 1)
(3355, 1)
(690, 1)
(1467, 1)
(3142, 1)
(3240, 1)
(1885, 1)
(3848, 1)
(4381, 1)
(4211, 1)
(1896, 1)
(3038, 1)
(1275, 1)
(2136, 1)
(3499, 1)
(2474, 1)
(2212, 1)


In [9]:
import haiku as hk
import jax
import optax

from typing import Sequence

# Adapted from https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py
# (which also shows an alternative to using GraphMapFeatures)

hidden_dimension = 128
num_message_passing_steps = 3

@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Node update function for graph net."""
  net = hk.Sequential([hk.nets.MLP(output_sizes = [hidden_dimension, hidden_dimension]), jax.nn.relu, hk.LayerNorm(axis = -1, create_scale = False, create_offset = False)])
  return net(feats)

@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Edge update function for graph net."""
  net = hk.Sequential([hk.nets.MLP(output_sizes = [hidden_dimension, hidden_dimension]), jax.nn.relu, hk.LayerNorm(axis = -1, create_scale = False, create_offset = False)])
  return net(feats)

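def node_decoder_fn(feats: jnp.ndarray) -> jnp.ndarray:
  # Currently unused: network_definition decodes with hk.Linear(112) and returns logits,
  # which is what optax.sigmoid_binary_cross_entropy expects.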
  net = hk.Sequential([hk.nets.MLP(output_sizes = [112], activate_final = False), jax.nn.sigmoid])
  return net(feats)

@hk.without_apply_rng
@hk.transform
def network_definition(graph):
  """Defines a graph neural network.
  Args:
    graph: the sharded GraphsTuple the network processes.
  Returns:
    Per-node logits for the 112 tasks.
  """
  graph = graph._replace(
      nodes = hk.Linear(hidden_dimension)(graph.nodes),
  )
  
  sharded_gn = sharded_graphnet.ShardedEdgesGraphNetwork(
      update_node_fn = node_update_fn,
      update_edge_fn = edge_update_fn,
      num_shards = num_devices
      )

  for _ in range(num_message_passing_steps):
    graph = sharded_gn(graph)

  graph = graph._replace(
      nodes = hk.Linear(112)(graph.nodes)
  )
  return graph.nodes
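To make one message-passing step concrete, here is a simplified NumPy sketch in the input layout of `jraph.GraphNetwork` (edge update sees the edge plus its sender and receiver features; node update sees the node plus the summed sent and received edges). Linear maps followed by ReLU stand in for the MLP and LayerNorm, globals are omitted, and the sharded variant may differ in details. Fresh weights are drawn at every step, assuming that Haiku gives each newly created `hk.nets.MLP` its own parameters, so the three steps above do not share weights:

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
  """Sum rows of data into num_segments buckets (like jax.ops.segment_sum)."""
  out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
  np.add.at(out, segment_ids, data)
  return out

def message_passing_step(nodes, edges, senders, receivers, w_edge, w_node):
  """One simplified graph-network step with linear + ReLU updates."""
  # Edge update on [edge, sender node, receiver node], as concatenated_args would build.
  edge_in = np.concatenate([edges, nodes[senders], nodes[receivers]], axis=-1)
  new_edges = np.maximum(edge_in @ w_edge, 0.0)
  # Aggregate the updated edges onto the nodes they leave and enter.
  sent = segment_sum(new_edges, senders, nodes.shape[0])
  received = segment_sum(new_edges, receivers, nodes.shape[0])
  node_in = np.concatenate([nodes, sent, received], axis=-1)
  new_nodes = np.maximum(node_in @ w_node, 0.0)
  return new_nodes, new_edges

rng = np.random.default_rng(0)
hidden = 4                              # stands in for hidden_dimension = 128
nodes = rng.normal(size=(3, hidden))    # after the initial hk.Linear encoder
edges = rng.normal(size=(2, 8))         # 8-dim raw edge features
senders, receivers = np.array([0, 1]), np.array([1, 2])

for _ in range(3):                      # num_message_passing_steps
  w_edge = rng.normal(size=(edges.shape[1] + 2 * nodes.shape[1], hidden))
  w_node = rng.normal(size=(nodes.shape[1] + 2 * hidden, hidden))
  nodes, edges = message_passing_step(nodes, edges, senders, receivers, w_edge, w_node)
print(nodes.shape, edges.shape)  # (3, 4) (2, 4)
```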

In [10]:
def bcast_local_devices(value):
    """Broadcasts an object to all local devices."""
    devices = jax.local_devices()

    def _replicate(x):
      """Replicate an object on each device."""
      x = jnp.array(x)
      return jax.device_put_sharded(len(devices) * [x], devices)

    return jax.tree_util.tree_map(_replicate, value)
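`bcast_local_devices` adds a leading device axis to every leaf of a pytree; that axis is what `jax.pmap` maps over. The host-only NumPy sketch below shows the resulting shapes (`replicate_tree` is a hypothetical stand-in for `tree_map` plus `device_put_sharded`, limited to dicts, lists and tuples). Since every device receives the same key, the pmapped `init` should produce identical initial parameters on all devices:

```python
import numpy as np

def replicate_tree(tree, num_devices: int):
  """Stack num_devices copies of every leaf along a new leading axis (no devices involved)."""
  if isinstance(tree, dict):
    return {k: replicate_tree(v, num_devices) for k, v in tree.items()}
  if isinstance(tree, (list, tuple)):
    return type(tree)(replicate_tree(v, num_devices) for v in tree)
  x = np.asarray(tree)
  return np.stack([x] * num_devices)

toy_key = np.array([0, 42], dtype=np.uint32)  # same shape/dtype as a legacy PRNGKey
toy_params = {"linear": {"w": np.ones((1, 128)), "b": np.zeros(128)}}

rep = replicate_tree({"key": toy_key, "params": toy_params}, num_devices=8)
print(rep["key"].shape)                    # (8, 2)
print(rep["params"]["linear"]["w"].shape)  # (8, 1, 128)
```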

In [32]:
def reshape_broadcasted_data(data):
  """Returns the copy held by device 0.

  Node-level arrays are broadcast to every device, so each entry along the
  leading (device) axis holds the same full set of nodes.
  """
  return np.array(data)[0]

def remove_mask_from_data(data, mask):
  """Keeps only the rows of `data` whose entry in `mask` is non-zero."""
  return np.array(np.compress(np.array(mask).reshape(-1).astype(bool), data, axis = 0))
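A toy run of the two helpers' logic, assuming (as `reshape_broadcasted_data` does) that all device copies are identical. Taking index 0 recovers one full copy, and the padding mask then drops the padding node that `jraph.pad_with_graphs` appends. The arrays are hypothetical, and `remove_rows_by_mask` is a renamed copy so the notebook's helper is not shadowed:

```python
import numpy as np

def remove_rows_by_mask(data, mask):
  """Keep the rows of data whose mask entry is non-zero (same logic as remove_mask_from_data)."""
  return np.compress(np.asarray(mask).reshape(-1).astype(bool), data, axis=0)

# Hypothetical broadcast output: 8 devices x (3 real nodes + 1 padding node) x 2 tasks.
per_device = np.arange(4 * 2, dtype=np.float32).reshape(4, 2)
broadcasted = np.stack([per_device] * 8)
toy_padding_mask = np.stack([np.array([[1], [1], [1], [0]], dtype=np.float32)] * 8)

one_copy = broadcasted[0]  # what reshape_broadcasted_data returns
real_rows = remove_rows_by_mask(one_copy, toy_padding_mask[0])
print(real_rows.shape)  # (3, 2): the padding row is gone
```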

In [37]:
import functools
import haiku as hk

from random import randint

# Try to follow this tutorial https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py
def compute_loss(params, graph, label, mask):
  predictions = network_definition.apply(params, graph)  # logits, shape [num_nodes, 112]

  # Sigmoid binary cross-entropy on logits (https://github.com/deepmind/optax/blob/master/optax/_src/loss.py#L116#L139)
  loss = optax.sigmoid_binary_cross_entropy(predictions, label)  # shape [num_nodes, num_classes]
  loss = loss * mask
  # Sum over the 112 tasks, then average over the masked nodes
  # (a per-task mean would also divide by the number of tasks)
  loss = jnp.sum(loss) / jnp.sum(mask)

  return loss

def train(num_training_steps):
  replicated_params = jax.pmap(network_definition.init, axis_name = 'i')(bcast_local_devices(jax.random.PRNGKey(42)), processed_graphs['partition_0']['graph'])

  opt_init, opt_update = optax.adam(learning_rate = 0.1)  
  replicated_opt_state = jax.pmap(opt_init, axis_name = 'i')(replicated_params)

  @functools.partial(jax.pmap, axis_name='i')
  def update(params, opt_state, graph, targets, mask):
    # Compute the loss and gradients on each device (every device holds all nodes but only its shard of the edges).
    loss, grads = jax.value_and_grad(compute_loss)(params, graph, targets, mask)

    # Combine the gradient across all devices (by taking their mean).
    grads = jax.lax.pmean(grads, axis_name='i')

    # Also combine the loss. Unnecessary for the update, but useful for logging.
    loss = jax.lax.pmean(loss, axis_name='i')

    updates, opt_state = opt_update(grads, opt_state)

    return optax.apply_updates(params, updates), opt_state, loss

  # Train
  for idx in range(num_training_steps):
    random_partition_idx = randint(0, num_partitions - 1)
    random_partition = processed_graphs[f'partition_{random_partition_idx}']

    graph = random_partition['graph']
    labels = random_partition['labels']   # Already broadcast to all devices during preprocessing
    mask = random_partition['train_mask'] # Already broadcast to all devices during preprocessing

    replicated_params, replicated_opt_state, loss = update(
        replicated_params, 
        replicated_opt_state, 
        graph, 
        labels,
        mask
        )
    
    print('Loss training:', loss)

    if (idx + 1) % 10 == 0:
      print('***************************')
      print(f'Trained on {idx + 1} graphs')
      print('***************************')

    # TODO: save params and opt_state (checkpointing)

  return replicated_params

def evaluate(params, num_graphs_eval):
  # Evaluate
  accumulated_loss = 0.0
  accumulated_roc = 0

  for idx in range(num_graphs_eval):
    # random_partition_idx = randint(0, num_partitions - 1)
    random_partition_idx = idx
    random_partition = processed_graphs[f'partition_{random_partition_idx}']

    graph = random_partition['graph']
    labels = random_partition['labels']   # Already broadcast to all devices during preprocessing
    mask = random_partition['test_mask']  # Already broadcast to all devices during preprocessing

    predictions, loss = predict_on_graph(params, graph, labels, mask)
    loss = float(np.mean(loss))  # pmap returns one loss value per device
    
    collected_labels = reshape_broadcasted_data(labels)
    collected_predictions = reshape_broadcasted_data(predictions)
    collected_mask = reshape_broadcasted_data(mask)

    roc = evaluator.eval({
        "y_true": remove_mask_from_data(collected_labels, collected_mask), 
        "y_pred": remove_mask_from_data(collected_predictions, collected_mask)
        })['rocauc']

    print(f'Test loss: {loss} | ROC: {roc}')

    accumulated_loss += loss
    accumulated_roc += roc

  print(f'Average test loss: {accumulated_loss / num_graphs_eval} | Average ROC: {accumulated_roc / num_graphs_eval}')

@functools.partial(jax.pmap, axis_name='i')
def predict_on_graph(params, graph, label, mask):
  logits = network_definition.apply(params, graph)
  loss = compute_loss(params, graph, label, mask)

  # ROC-AUC ranks continuous scores, so return probabilities instead of rounded values
  return jax.nn.sigmoid(logits), loss

final_params = train(num_training_steps = 1)
evaluate(final_params, num_graphs_eval = num_partitions)

############## END OF RELEVANT CODE ####################

Loss training: [90.620834 90.620834 90.620834 90.620834 90.620834 90.620834 90.620834
 90.620834]
*****************************
Evaluated on 10 graphs
*****************************
Nodes evaluated so far 3667

*****************************
Evaluated on 20 graphs
*****************************
Nodes evaluated so far 7349

*****************************
Evaluated on 30 graphs
*****************************
Nodes evaluated so far 11557

*****************************
Evaluated on 40 graphs
*****************************
Nodes evaluated so far 16261

*****************************
Evaluated on 50 graphs
*****************************
Nodes evaluated so far 24679

Final count 24679
Final count 24679
Final count 24679
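The training loss of about 90.6 reflects the normalization in `compute_loss`: the binary cross-entropy is summed over all 112 tasks and only then averaged over the masked nodes, so 90.62 / 112 ≈ 0.81 per task. A NumPy sketch of both normalizations, using the standard numerically stable sigmoid cross-entropy on logits (`masked_bce` is a hypothetical helper, with a guard against an empty mask that `compute_loss` does not have):

```python
import numpy as np

def sigmoid_bce(logits, labels):
  """Elementwise sigmoid binary cross-entropy on logits, in a numerically stable form."""
  return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

def masked_bce(logits, labels, mask):
  """Return (sum over tasks averaged over masked nodes, per-task mean)."""
  loss = sigmoid_bce(logits, labels) * mask      # mask has shape (num_nodes, 1)
  num_masked = max(float(mask.sum()), 1.0)       # guard against an empty mask
  per_node_sum = float(loss.sum()) / num_masked  # normalization used by compute_loss
  per_task_mean = per_node_sum / logits.shape[-1]
  return per_node_sum, per_task_mean

num_tasks = 112
toy_logits = np.zeros((4, num_tasks))                  # an uninformative model
toy_labels = np.tile([0.0, 1.0], (4, num_tasks // 2))
toy_mask = np.array([[1.0], [1.0], [1.0], [0.0]])

per_node, per_task = masked_bce(toy_logits, toy_labels, toy_mask)
print(round(per_node, 2), round(per_task, 3))  # 77.63 0.693
```

With all-zero logits every entry costs log 2, so the per-node sum is 112 · log 2 ≈ 77.6 while the per-task mean is log 2 ≈ 0.693.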


In [49]:
def evaluate_on_full_sets(params):
  # One row of predicted probabilities per node of the full graph. The METIS
  # partitions cover every node exactly once, so every row gets filled.
  all_predictions = np.zeros(
      (ogbn_proteins_main_graph.num_nodes(), 112), dtype = np.float32)

  for i in range(num_partitions):
    node_ids = dgl_graph_metis_partition[i].ndata['_ID']
    partition = processed_graphs[f'partition_{i}']
    padding_mask = reshape_broadcasted_data(partition['padding_mask'])

    predictions, _ = predict_on_graph(params, partition['graph'], partition['labels'], partition['test_mask'])
    predictions = reshape_broadcasted_data(predictions)

    # Drop the padding nodes; the remaining rows follow the order of node_ids
    predictions = remove_mask_from_data(predictions, padding_mask)

    # Scatter this partition's predictions to their global node IDs
    all_predictions[np.asarray(node_ids)] = predictions

    if (i + 1) % 10 == 0:
      print(f'Evaluated {i + 1} / {num_partitions} subgraphs...')

  final_roc_train = evaluator.eval({
      "y_true": np.array(train_label),
      "y_pred": all_predictions[np.asarray(split_idx['train'])]
      })['rocauc']

  final_roc_valid = evaluator.eval({
      "y_true": np.array(valid_label),
      "y_pred": all_predictions[np.asarray(split_idx['valid'])]
      })['rocauc']

  final_roc_test = evaluator.eval({
      "y_true": np.array(test_label),
      "y_pred": all_predictions[np.asarray(split_idx['test'])]
      })['rocauc']

  print()
  print(f'Final ROC on the train set {final_roc_train}')
  print(f'Final ROC on the validation set {final_roc_valid}')
  print(f'Final ROC on the test set {final_roc_test}')

evaluate_on_full_sets(final_params)

Streaming output truncated to the last 5000 lines.
Index 2085 -- Node ID 120200
Index 2086 -- Node ID 120230
Index 2087 -- Node ID 120234
Index 2088 -- Node ID 120261
Index 2089 -- Node ID 120264
Index 2090 -- Node ID 120324
Index 2091 -- Node ID 120389
Index 2092 -- Node ID 120410
Index 2093 -- Node ID 120423
Index 2094 -- Node ID 120479
Index 2095 -- Node ID 120490
Index 2096 -- Node ID 120569
Index 2097 -- Node ID 120591
Index 2098 -- Node ID 120596
Index 2099 -- Node ID 120653
Index 2100 -- Node ID 120669
Index 2101 -- Node ID 120817
Index 2102 -- Node ID 120845
Index 2103 -- Node ID 120871
Index 2104 -- Node ID 120952
Index 2105 -- Node ID 121021
Index 2106 -- Node ID 121074
Index 2107 -- Node ID 121182
Index 2108 -- Node ID 121211
Index 2109 -- Node ID 121232
Index 2110 -- Node ID 121343
Index 2111 -- Node ID 121345
Index 2112 -- Node ID 121353
Index 2113 -- Node ID 121355
Index 2114 -- Node ID 121357
Index 2115 -- Node ID 121362
Index 2116 -- Node ID 121375
Index 2

KeyboardInterrupt: ignored
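ROC-AUC depends only on how the scores rank positives against negatives, so rounding predictions to integers creates ties and can lower it, while a monotonic transform such as the sigmoid leaves it unchanged. The sketch below uses a pairwise (Mann-Whitney) formulation on hypothetical toy scores; `mean_rocauc` averages over tasks that contain both classes, which is roughly what the OGB evaluator reports for ogbn-proteins (its exact handling should be checked in the OGB source):

```python
import numpy as np

def roc_auc(y_true, scores) -> float:
  """P(random positive outscores random negative); ties count 1/2."""
  y_true = np.asarray(y_true).astype(bool)
  scores = np.asarray(scores, dtype=np.float64)
  pos, neg = scores[y_true], scores[~y_true]
  diff = pos[:, None] - neg[None, :]
  return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

def mean_rocauc(y_true, y_score) -> float:
  """Average per-task ROC-AUC over tasks that contain both classes."""
  aucs = [roc_auc(y_true[:, t], y_score[:, t])
          for t in range(y_true.shape[1])
          if 0 < y_true[:, t].sum() < y_true.shape[0]]
  return float(np.mean(aucs))

toy_labels = np.array([0, 0, 1, 1])
toy_logits = np.array([-2.0, 0.6, 0.9, 3.0])
toy_probs = 1.0 / (1.0 + np.exp(-toy_logits))

print(roc_auc(toy_labels, toy_logits))            # 1.0
print(roc_auc(toy_labels, toy_probs))             # 1.0 (sigmoid preserves the ranking)
print(roc_auc(toy_labels, np.round(toy_logits)))  # 0.875 (0.6 and 0.9 both become 1.0)
```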

In [None]:
# Question
# I am not sure I understand what sharding the graph actually means, so I want to confirm this.
# From what I understand, with ShardedEdgesGraphNetwork, sharding the graph means replicating the nodes / node features
# on all devices and splitting the edges of the graph across the devices. This would mean that on each training step
# I only use a single graph, right?

# This is in contrast with using a batch of 8 graphs on each training step and passing one graph to each device
# (similar to https://github.com/deepmind/jraph/issues/10).
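A host-only simulation of what edge sharding computes, under the assumption (consistent with the name `graphs_tuple_to_broadcasted_sharded_graphs_tuple` and with the identical loss printed for all 8 devices above) that nodes are replicated and the edge list is split into equal chunks, one per device. Each simulated device aggregates messages only from its own edges, and summing the partial results plays the role of the cross-device `psum`; the check is that this reproduces the full-graph aggregation. `sharded_aggregate` is a hypothetical helper, not jraph's code:

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
  """Sum rows of data into num_segments buckets (like jax.ops.segment_sum)."""
  out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
  np.add.at(out, segment_ids, data)
  return out

def sharded_aggregate(messages, receivers, num_nodes, num_shards):
  """Each shard sees all nodes but only 1/num_shards of the edges; partials are summed."""
  msg_shards = np.split(messages, num_shards)     # needs len(messages) % num_shards == 0
  recv_shards = np.split(receivers, num_shards)
  partial = [segment_sum(m, r, num_nodes) for m, r in zip(msg_shards, recv_shards)]
  return np.sum(partial, axis=0)                  # stands in for the cross-device psum

rng = np.random.default_rng(0)
toy_num_nodes, toy_num_edges, toy_num_shards = 5, 16, 8
toy_receivers = rng.integers(0, toy_num_nodes, size=toy_num_edges)
toy_messages = rng.normal(size=(toy_num_edges, 3))

full = segment_sum(toy_messages, toy_receivers, toy_num_nodes)
sharded = sharded_aggregate(toy_messages, toy_receivers, toy_num_nodes, toy_num_shards)
print(np.allclose(full, sharded))  # True
```

Under this reading, each training step does process a single (padded) graph, split by edges across the 8 cores. Data parallelism would instead give each device a different graph and average the gradients with `jax.lax.pmean`, which is what the `update` function already does with its `pmean` over `'i'`.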