<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/attempt_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis

In [2]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name = "ogbn-proteins")
split_idx = dataset.get_idx_split()

Using backend: pytorch


In [3]:
# There is only one graph in Node Property Prediction datasets
ogbn_proteins_main_graph, ogbn_proteins_main_labels = dataset[0]

'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Get split labels
train_label = dataset.labels[split_idx['train']]  # (86619, 112) -- binary values (presence of protein functions)
valid_label = dataset.labels[split_idx['valid']]  # (21236, 112) -- binary values (presence of protein functions)
test_label = dataset.labels[split_idx['test']]    # (24679, 112) -- binary values (presence of protein functions)

In [4]:
import dgl

'''
  Generate graph partition using metis, with balanced number of edges in each partition.
  Note: 
    The subgraphs do not contain the node/edge data in the input graph (https://docs.dgl.ai/generated/dgl.metis_partition.html)
'''
num_partitions = 100
dgl_graph_metis_partition = dgl.metis_partition(ogbn_proteins_main_graph, num_partitions, balance_edges = True)

Convert a graph into a bidirected graph: 6.486 seconds
Construct multi-constraint weights: 0.005 seconds
Metis partitioning: 44.488 seconds
Split the graph: 5.638 seconds
Construct subgraphs: 0.055 seconds


In [5]:
import numpy as np
import torch
import jraph
import jax.numpy as jnp

def dgl_graph_to_jraph(node_ids):
  # First add back the node and edge features
  dgl_graph_with_features = dgl.node_subgraph(ogbn_proteins_main_graph, node_ids)

  node_features = jnp.array(dgl_graph_with_features.ndata['species'])
  
  # TODO: Check if order is correct
  # I think it should be -- DGLGraph.edges uses the default (uv) format for edges
  # with u == source and v == destination
  # From https://docs.dgl.ai/generated/dgl.DGLGraph.edges.html#dgl.DGLGraph.edges
  senders = jnp.array(dgl_graph_with_features.edges()[0])
  receivers = jnp.array(dgl_graph_with_features.edges()[1])

  # Edges -- here we should include the 8-dimensional edge features
  edges = jnp.array(dgl_graph_with_features.edata['feat'])

  n_node = jnp.array([dgl_graph_with_features.num_nodes()])
  n_edge = jnp.array([dgl_graph_with_features.num_edges()])

  return jraph.GraphsTuple(
            nodes = node_features.astype(np.float32),
            senders = senders.astype(np.int32), 
            receivers = receivers.astype(np.int32),
            edges = edges.astype(np.float32),  
            n_node = n_node, 
            n_edge = n_edge,
            globals = None  # No global features
          )
  
def get_labels_for_subgraph(node_ids):
  return ogbn_proteins_main_labels.index_select(0, node_ids)

In [6]:
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import optax

In [7]:
%pip install -q flax

[?25l[K     |█▉                              | 10 kB 22.6 MB/s eta 0:00:01[K     |███▊                            | 20 kB 29.3 MB/s eta 0:00:01[K     |█████▋                          | 30 kB 23.3 MB/s eta 0:00:01[K     |███████▍                        | 40 kB 18.7 MB/s eta 0:00:01[K     |█████████▎                      | 51 kB 13.1 MB/s eta 0:00:01[K     |███████████▏                    | 61 kB 15.2 MB/s eta 0:00:01[K     |█████████████                   | 71 kB 13.8 MB/s eta 0:00:01[K     |██████████████▉                 | 81 kB 15.0 MB/s eta 0:00:01[K     |████████████████▊               | 92 kB 16.3 MB/s eta 0:00:01[K     |██████████████████▌             | 102 kB 14.2 MB/s eta 0:00:01[K     |████████████████████▍           | 112 kB 14.2 MB/s eta 0:00:01[K     |██████████████████████▎         | 122 kB 14.2 MB/s eta 0:00:01[K     |████████████████████████        | 133 kB 14.2 MB/s eta 0:00:01[K     |██████████████████████████      | 143 kB 14.2 MB/s eta 0:

In [50]:
from typing import Sequence

def network_definition(graph):
  """Defines a graph neural network.
  Args:
    graph: Graphstuple the network processes.
  Returns:
    Decoded nodes.
  """
  model_fn = functools.partial(
      hk.nets.MLP,
      w_init=hk.initializers.VarianceScaling(1.0),
      b_init=hk.initializers.VarianceScaling(1.0))
  mlp_sizes = (64, 64)
  num_message_passing_steps = 7

  node_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True)
  edge_encoder = model_fn(output_sizes=mlp_sizes, activate_final=True)
  node_decoder = model_fn(output_sizes=[112], activate_final=False)

  node_encoding = node_encoder(graph.nodes)
  edge_encoding = edge_encoder(graph.edges)
  graph = graph._replace(nodes=node_encoding, edges=edge_encoding)

  update_edge_fn = jraph.concatenated_args(
      model_fn(output_sizes=mlp_sizes, activate_final=True))
  update_node_fn = jraph.concatenated_args(
      model_fn(output_sizes=mlp_sizes, activate_final=True))
  gn = jraph.InteractionNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
      include_sent_messages_in_node_update=True)
  
  for _ in range(num_message_passing_steps):
    graph = graph._replace(
        nodes=jnp.concatenate([graph.nodes, node_encoding], axis=-1),
        edges=jnp.concatenate([graph.edges, edge_encoding], axis=-1))
    graph = gn(graph)

  # return jnp.squeeze(node_decoder(graph.nodes), axis=-1)
  return node_decoder(graph.nodes)


In [64]:
import functools
import haiku as hk

# Try to follow this tutorial https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py

def compute_loss(params, graph, label, net):
  decoded_nodes = net.apply(params, graph) # Shape == label.shape
  predictions = jax.nn.softmax(decoded_nodes)

  loss = -jnp.mean(predictions * label)

  return loss

# def train_step(optimizer, graph, label, net):
#   partial_loss_fn = functools.partial(
#       compute_loss, graph=graph, label=label, net=net)
#   grad_fn = jax.value_and_grad(partial_loss_fn, has_aux=True)
#   loss, grad = grad_fn(optimizer.target)
#   optimizer = optimizer.apply_gradient(grad)
#   return optimizer, loss

def evaluate(params, graph, label):
  accumulated_loss = 0
  accumulated_accuracy = 0
  idx = 0

  # net = GraphNetwork(mlp_features=[112], latent_size=128)
  net = hk.without_apply_rng(hk.transform(network_definition))
  compute_loss_fn = jax.jit(functools.partial(compute_loss, net=net))

  ## Evaluate on a batch of graphs
  loss = compute_loss_fn(params, graph, label)
  acc = jnp.sum(jax.nn.softmax(net.apply(params, graph)) == label) / (label.shape[0] * label.shape[1])
  idx += 1

  ## After evaluating on a number of graphs, compute the mean loss and accuracy
  accumulated_loss += loss
  accumulated_accuracy += acc

  print(f'Eval loss: {accumulated_loss / idx} | Accuracy: {accumulated_accuracy / idx}')

def train_and_evaluate(num_training_steps):
  # net = GraphNetwork(mlp_features = [112], latent_size = 128)
  net = hk.without_apply_rng(hk.transform(network_definition))

  training_graph = dgl_graph_to_jraph(dgl_graph_metis_partition[0].ndata['_ID'])
  labels_training = get_labels_for_subgraph(dgl_graph_metis_partition[0].ndata['_ID'])
  labels_training = jnp.array(labels_training)

  params = net.init(jax.random.PRNGKey(42), training_graph)

  opt_init, opt_update = optax.adam(learning_rate = 1e-5)
  opt_state = opt_init(params)

  @jax.jit
  def update(params, opt_state, graph, targets):
    loss, grads = jax.value_and_grad(compute_loss)(params, graph, targets, net)
    updates, opt_state = opt_update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss

  # Train
  for idx in range(num_training_steps):
    # optimizer, scalars = train_step(optimizer, training_graph, labels_training, net)
    graph = dgl_graph_to_jraph(dgl_graph_metis_partition[idx].ndata['_ID'])
    label = get_labels_for_subgraph(dgl_graph_metis_partition[idx].ndata['_ID'])
    label = jnp.array(label)

    params, opt_state, loss = update(params, opt_state, graph, label)
    print('Loss training:', loss)

  # Evaluate
  evaluate(params, training_graph, labels_training)

train_and_evaluate(5)

Loss training: -0.0022140632
Loss training: -0.0010719331
Loss training: -0.00086903904
Loss training: -0.0016110134
Loss training: -0.0018560488
Eval loss: -0.0022140631917864084 | Accuracy: 0.8766355514526367
