<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/final_version_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis

  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 75 kB 2.6 MB/s 
[K     |████████████████████████████████| 70 kB 3.5 MB/s 
[K     |████████████████████████████████| 77 kB 5.4 MB/s 
[?25h  Building wheel for jaxline (setup.py) ... [?25l[?25hdone
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 78 kB 3.4 MB/s 
[?25h  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 4.4 MB 5.4 MB/s 
[K     |████████████████████████████████| 136 kB 4.8 MB/s 
[?25h  Building wheel for metis (setup.py) ... [?25l[?25hdone


In [2]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

num_devices = jax.local_device_count()
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name = "ogbn-proteins")
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name = 'ogbn-proteins')

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip


Downloaded 0.21 GB: 100%|██████████| 216/216 [00:06<00:00, 35.07it/s]


Extracting dataset/proteins.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:01<00:00,  1.72s/it]


Converting graphs into DGL objects...


100%|██████████| 1/1 [00:00<00:00,  2.23it/s]


Saving...


In [4]:
import jax.numpy as jnp

# There is only one graph in Node Property Prediction datasets
ogbn_proteins_main_graph, ogbn_proteins_main_labels = dataset[0]

'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Get split labels
train_label = dataset.labels[split_idx['train']]  # (86619, 112) -- binary values (presence of protein functions)
valid_label = dataset.labels[split_idx['valid']]  # (21236, 112) -- binary values (presence of protein functions)
test_label = dataset.labels[split_idx['test']]    # (24679, 112) -- binary values (presence of protein functions)

# Create masks
train_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['train'])].set(1)
valid_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['valid'])].set(1)
test_mask = jnp.zeros((ogbn_proteins_main_graph.num_nodes(), 1)).at[jnp.array(split_idx['test'])].set(1)

In [5]:
import numpy as np
import torch
import jraph

def dgl_graph_to_jraph(node_ids):
  # First add back the node and edge features
  dgl_graph_with_features = dgl.node_subgraph(ogbn_proteins_main_graph, node_ids)

  node_features = jnp.array(dgl_graph_with_features.ndata['species'])
  senders = jnp.array(dgl_graph_with_features.edges()[0])
  receivers = jnp.array(dgl_graph_with_features.edges()[1])

  # Edges -- here we should include the 8-dimensional edge features
  edges = jnp.array(dgl_graph_with_features.edata['feat'])

  n_node = jnp.array([dgl_graph_with_features.num_nodes()])
  n_edge = jnp.array([dgl_graph_with_features.num_edges()])

  return jraph.GraphsTuple(
            nodes = node_features.astype(np.float32),
            senders = senders.astype(np.int32), 
            receivers = receivers.astype(np.int32),
            edges = edges.astype(np.float32),  
            n_node = n_node, 
            n_edge = n_edge,
            globals = None  # No global features
          )
  
def get_labels_for_subgraph(node_ids):
  return jnp.array(ogbn_proteins_main_labels.index_select(0, node_ids))

In [6]:
import dgl

'''
  Generate graph partition using metis, with balanced number of edges in each partition.
  Note: 
    The subgraphs do not contain the node/edge data in the input graph (https://docs.dgl.ai/generated/dgl.metis_partition.html)
'''
num_partitions = 256

# reshuffle == False, so my understandinng is that
# partition.ndata['_ID'] uses the same ids for the initial nodes
# therefore, the initial train / valid / test masks should still work.
dgl_graph_metis_partition = dgl.metis_partition(ogbn_proteins_main_graph, num_partitions, balance_edges = True)

Convert a graph into a bidirected graph: 2.694 seconds
Construct multi-constraint weights: 0.015 seconds
Metis partitioning: 35.802 seconds
Split the graph: 0.494 seconds
Construct subgraphs: 0.236 seconds


In [7]:
# Convert graphs to Jraph GraphsTuple
processed_graphs = {}

for idx in range(num_partitions):
  node_ids = dgl_graph_metis_partition[idx].ndata['_ID']

  graph = dgl_graph_to_jraph(node_ids)
  labels = get_labels_for_subgraph(node_ids)

  processed_graphs[f'partition_{idx}'] = {
      'graph': graph, 
      'labels': labels,
      'train_mask': train_mask.at[jnp.array(node_ids)].get(),
      'valid_mask': valid_mask.at[jnp.array(node_ids)].get(),
      'test_mask': test_mask.at[jnp.array(node_ids)].get()
      }

In [8]:
import haiku as hk
import jax
import optax

from typing import Sequence

# See https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py for alternative to using GraphMapFeatures
# From https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py

hidden_dimension = 128
num_message_passing_steps = 3

@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Node update function for graph net."""
  net = hk.Sequential([hk.nets.MLP(output_sizes = [hidden_dimension, hidden_dimension]), jax.nn.relu, hk.LayerNorm(axis = -1, create_scale = False, create_offset = False)])
  return net(feats)

@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Edge update function for graph net."""
  net = hk.Sequential([hk.nets.MLP(output_sizes = [hidden_dimension, hidden_dimension]), jax.nn.relu, hk.LayerNorm(axis = -1, create_scale = False, create_offset = False)])
  return net(feats)

def node_decoder_fn(feats: jnp.ndarray) -> jnp.ndarray:
  ## TODO: Check if this is correct
  net = hk.Sequential([hk.nets.MLP(output_sizes = [112], activate_final = False), jax.nn.sigmoid])
  return net(feats)

def network_definition(graph):
  """Defines a graph neural network.
  Args:
    graph: Graphstuple the network processes.
  Returns:
    Decoded nodes.
  """

  embedder = jraph.GraphMapFeatures(
      embed_node_fn=hk.Linear(hidden_dimension),
      embed_edge_fn=hk.Linear(hidden_dimension),
      )
  graph = embedder(graph)

  gn = jraph.InteractionNetwork(
      update_node_fn=node_update_fn,
      update_edge_fn=edge_update_fn,
      include_sent_messages_in_node_update=True
      )

  for _ in range(num_message_passing_steps):
    graph = gn(graph)

  decoder = jraph.GraphMapFeatures(embed_node_fn = node_decoder_fn)
  
  processed_graph = decoder(graph)
  return processed_graph.nodes

In [12]:
import functools
import haiku as hk

from random import randint

# Try to follow this tutorial https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py
def compute_loss(params, graph, label, mask):
  # Question: would the net need to be passed from the train function?
  # I don't think so because this is just a function and we pass the parameters
  # below, but I just wanted to confirm.
  net = hk.without_apply_rng(hk.transform(network_definition))

  predictions = net.apply(params, graph) # Shape == label.shape
  
  # Question: Is node masking implemented properly?
  loss = optax.sigmoid_binary_cross_entropy(predictions, label)
  loss = loss * mask
  loss = loss.mean()
  return loss

def train(num_training_steps):
  # Transform the function (MPNN) into a pure function (with no side effects) so that it can be used with jax
  net = hk.without_apply_rng(hk.transform(network_definition))

  params = net.init(jax.random.PRNGKey(42), processed_graphs['partition_0']['graph'])

  opt_init, opt_update = optax.adam(learning_rate = 0.1)  # This learning rate could be a bit low, try 1e-1 for full graph descent?  
  opt_state = opt_init(params)

  # n_devices = 8
  # replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)
  # replicated_opt_state = jax.tree_map(lambda x: jnp.array([x] * n_devices), opt_state)

  # @functools.partial(jax.pmap, axis_name='num_devices')
  def update(params, opt_state, graph, targets, mask):
    # Compute the gradients on the given minibatch (individually on each device).
    loss, grads = jax.value_and_grad(compute_loss)(params, graph, targets, mask)

    # # Combine the gradient across all devices (by taking their mean).
    # grads = jax.lax.pmean(grads, axis_name='num_devices')

    # # Also combine the loss. Unnecessary for the update, but useful for logging.
    # loss = jax.lax.pmean(loss, axis_name='num_devices')

    updates, opt_state = opt_update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss

  ## TODO - Question: How to train only using the train_split?

  # Train
  for idx in range(num_training_steps):
    random_partition_idx = randint(0, num_partitions - 1)
    random_partition = processed_graphs[f'partition_{random_partition_idx}']

    graph = random_partition['graph']
    labels = random_partition['labels']
    mask = random_partition['train_mask']

    params, opt_state, loss = update(
        params, 
        opt_state, 
        graph, 
        labels,
        mask
        )
    print('Loss training:', loss)

    # Write params and opt_state

  return params

def evaluate(params, num_graphs_eval):
  # Evaluate
  accumulated_loss = 0.0
  accumulated_roc = 0

  for idx in range(num_graphs_eval):
    random_partition_idx = randint(0, num_partitions - 1)
    random_partition = processed_graphs[f'partition_{random_partition_idx}']

    graph = random_partition['graph']
    labels = random_partition['labels']
    mask = random_partition['test_mask']

    (loss, roc) = evaluate_graph(params, graph, labels, mask)

    accumulated_loss += loss
    accumulated_roc += roc

  print(f'Average loss: {accumulated_loss / num_graphs_eval} | Average ROC: {accumulated_roc / num_graphs_eval}')

def evaluate_graph(params, graph, label, mask):
  compute_loss_fn = jax.jit(functools.partial(compute_loss))

  net = hk.without_apply_rng(hk.transform(network_definition))
  decoded_nodes = net.apply(params, graph) # Shape == label.shape
  
  ## TODO -- Question: Should I round the predictions here? (predictions = jax.lax.round(decoded_nodes))
  predictions = jax.lax.round(decoded_nodes)

  loss = compute_loss_fn(params, graph, label, mask)
  roc = evaluator.eval({"y_true": np.array(label), "y_pred": np.array(predictions)})['rocauc']

  print(f'Test loss: {loss} | ROC: {roc}')
  return (loss, roc)

final_params = train(num_training_steps = 10)
evaluate(final_params, 5)

Loss training: 0.9700462
Loss training: 0.30946547
Loss training: 0.25340915
Loss training: 0.33717442
Loss training: 0.3528573
Loss training: 0.20972906
Loss training: 0.69314766
Loss training: 0.69314766
Loss training: 0.50661224
Loss training: 0.51450115
Eval loss: 0.2184399664402008 | ROC: 0.5
Eval loss: 0.006370839662849903 | ROC: 0.5
Eval loss: 0.23177120089530945 | ROC: 0.5
Eval loss: 0.42447325587272644 | ROC: 0.5
Eval loss: 0.12577266991138458 | ROC: 0.5
Average loss: 0.20136559009552002 | Average ROC: 0.5


In [None]:
  ## TODO: Implement masking



  # Train split, [0, 1, 10]
  # Train mask, jnp.zeros((num_nodes, 1)).at[train_split].set(1)
  # bce(pred, target) * train_mask

  # # Preprocess data
  # processed_data = []
  # for idx in len(data):
  #   graph_idx = dgl_graph_to_jraph(dgl_graph_metis_partition[idx].ndata['_ID'])
  #   labels_idx = get_labels_for_subgraph(dgl_graph_metis_partition[idx].ndata['_ID'])
  #   labels_idx = jnp.array(labels_idx)


  # # num training steps = 1000
  # for idx in range(num_training_steps):
  #   graph = random.choice(processed_data)
  #   labels = random.choice(labels)
  #   params, opt_state, loss = update(params, opt_state, graph_idx, labels_idx)
  #   print('Loss training:', loss)
  #   # save parameters and opt state
  # return params