<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/attempt_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis

  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 75 kB 2.7 MB/s 
[K     |████████████████████████████████| 70 kB 3.4 MB/s 
[K     |████████████████████████████████| 77 kB 7.0 MB/s 
[?25h  Building wheel for jaxline (setup.py) ... [?25l[?25hdone
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 78 kB 3.2 MB/s 
[?25h  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 4.4 MB 4.2 MB/s 
[K     |████████████████████████████████| 136 kB 4.1 MB/s 
[?25h  Building wheel for metis (setup.py) ... [?25l[?25hdone


In [2]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name = "ogbn-proteins")
split_idx = dataset.get_idx_split()

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip


Downloaded 0.21 GB: 100%|██████████| 216/216 [00:16<00:00, 12.83it/s]


Extracting dataset/proteins.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:01<00:00,  1.63s/it]


Converting graphs into DGL objects...


100%|██████████| 1/1 [00:00<00:00,  2.45it/s]


Saving...


In [3]:
# There is only one graph in Node Property Prediction datasets
ogbn_proteins_main_graph, ogbn_proteins_main_labels = dataset[0]

'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Get split labels
train_label = dataset.labels[split_idx['train']]  # (86619, 112) -- binary values (presence of protein functions)
valid_label = dataset.labels[split_idx['valid']]  # (21236, 112) -- binary values (presence of protein functions)
test_label = dataset.labels[split_idx['test']]    # (24679, 112) -- binary values (presence of protein functions)

In [4]:
import dgl

'''
  Generate graph partition using metis, with balanced number of edges in each partition.
  Note: 
    The subgraphs do not contain the node/edge data in the input graph (https://docs.dgl.ai/generated/dgl.metis_partition.html)
'''
num_partitions = 100
dgl_graph_metis_partition = dgl.metis_partition(ogbn_proteins_main_graph, num_partitions, balance_edges = True)

Convert a graph into a bidirected graph: 5.454 seconds
Construct multi-constraint weights: 0.108 seconds
Metis partitioning: 41.460 seconds
Split the graph: 5.518 seconds
Construct subgraphs: 0.126 seconds


In [5]:
import numpy as np
import torch
import jraph
import jax.numpy as jnp

def dgl_graph_to_jraph(node_ids):
  # First add back the node and edge features
  dgl_graph_with_features = dgl.node_subgraph(ogbn_proteins_main_graph, node_ids)

  node_features = jnp.array(dgl_graph_with_features.ndata['species'])
  
  # I think it should be -- DGLGraph.edges uses the default (uv) format for edges
  # with u == source and v == destination
  # From https://docs.dgl.ai/generated/dgl.DGLGraph.edges.html#dgl.DGLGraph.edges
  senders = jnp.array(dgl_graph_with_features.edges()[0])
  receivers = jnp.array(dgl_graph_with_features.edges()[1])

  # Edges -- here we should include the 8-dimensional edge features
  edges = jnp.array(dgl_graph_with_features.edata['feat'])

  n_node = jnp.array([dgl_graph_with_features.num_nodes()])
  n_edge = jnp.array([dgl_graph_with_features.num_edges()])

  return jraph.GraphsTuple(
            nodes = node_features.astype(np.float32),
            senders = senders.astype(np.int32), 
            receivers = receivers.astype(np.int32),
            edges = edges.astype(np.float32),  
            n_node = n_node, 
            n_edge = n_edge,
            globals = None  # No global features
          )
  
def get_labels_for_subgraph(node_ids):
  return ogbn_proteins_main_labels.index_select(0, node_ids)

In [6]:
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import optax

In [7]:
%pip install -q flax

[?25l[K     |█▉                              | 10 kB 33.3 MB/s eta 0:00:01[K     |███▊                            | 20 kB 8.5 MB/s eta 0:00:01[K     |█████▋                          | 30 kB 7.6 MB/s eta 0:00:01[K     |███████▍                        | 40 kB 3.5 MB/s eta 0:00:01[K     |█████████▎                      | 51 kB 3.5 MB/s eta 0:00:01[K     |███████████▏                    | 61 kB 4.2 MB/s eta 0:00:01[K     |█████████████                   | 71 kB 4.4 MB/s eta 0:00:01[K     |██████████████▉                 | 81 kB 4.6 MB/s eta 0:00:01[K     |████████████████▊               | 92 kB 5.2 MB/s eta 0:00:01[K     |██████████████████▌             | 102 kB 4.1 MB/s eta 0:00:01[K     |████████████████████▍           | 112 kB 4.1 MB/s eta 0:00:01[K     |██████████████████████▎         | 122 kB 4.1 MB/s eta 0:00:01[K     |████████████████████████        | 133 kB 4.1 MB/s eta 0:00:01[K     |██████████████████████████      | 143 kB 4.1 MB/s eta 0:00:01[K    

In [77]:
from typing import Sequence

# From https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py

mlp_sizes = (64, 128)
num_message_passing_steps = 7

@jraph.concatenated_args
def node_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Node update function for graph net."""
  net = hk.Sequential([hk.nets.MLP(output_sizes = mlp_sizes), jax.nn.relu, hk.LayerNorm(axis = -1, create_scale = False, create_offset = False)])
  return net(feats)

@jraph.concatenated_args
def edge_update_fn(feats: jnp.ndarray) -> jnp.ndarray:
  """Edge update function for graph net."""
  net = hk.Sequential([hk.nets.MLP(output_sizes = mlp_sizes), jax.nn.relu, hk.LayerNorm(axis = -1, create_scale = False, create_offset = False)])
  return net(feats)

def node_decoder_fn(feats: jnp.ndarray) -> jnp.ndarray:
  net = hk.Sequential([hk.nets.MLP(output_sizes = [112]), jax.nn.sigmoid])
  return net(feats)

def network_definition(graph):
  """Defines a graph neural network.
  Args:
    graph: Graphstuple the network processes.
  Returns:
    Decoded nodes.
  """
  model_fn = functools.partial(hk.nets.MLP)

  node_encoder = model_fn(output_sizes=mlp_sizes)
  edge_encoder = model_fn(output_sizes=mlp_sizes)
  node_decoder = node_decoder_fn

  node_encoding = node_encoder(graph.nodes)
  edge_encoding = edge_encoder(graph.edges)
  graph = graph._replace(nodes=node_encoding, edges=edge_encoding)

  gn = jraph.InteractionNetwork(
      update_node_fn=node_update_fn,
      update_edge_fn=edge_update_fn,
      include_sent_messages_in_node_update=True
      )

  for _ in range(num_message_passing_steps):
    graph = graph._replace(
        nodes=jnp.concatenate([graph.nodes, node_encoding], axis=-1),
        edges=jnp.concatenate([graph.edges, edge_encoding], axis=-1)
        )
    graph = gn(graph)

  return node_decoder(graph.nodes)


In [87]:
import functools
import haiku as hk

# Try to follow this tutorial https://github.com/YuxuanXie/mcl/blob/5f7ee92e2a6bc89736263873a4ba9c14d1a676ff/glassy_dynamics/train_using_jax.py

def compute_loss(params, graph, label, net):
  decoded_nodes = net.apply(params, graph) # Shape == label.shape
  predictions = jax.lax.round(decoded_nodes)

  #################################################################################################################################################
  ############################################################# Compute loss function #############################################################
  #################################################################################################################################################
  # From https://colab.research.google.com/github/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb#scrollTo=_Ld4b3D6Lwel
  def compute_bce_with_logits_loss(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Computes binary cross-entropy with logits loss.

    Combines sigmoid and BCE, and uses log-sum-exp trick for numerical stability.
    See https://stackoverflow.com/a/66909858 if you want to learn more.

    Args:
      x: Predictions (logits).
      y: Labels.

    Returns:
      Binary cross-entropy loss with mean aggregation.

    """
    max_val = jnp.clip(x, 0, None)
    loss = x - x * y + max_val + jnp.log(jnp.exp(-max_val) + jnp.exp((-x - max_val)))
    return loss.mean()
  #################################################################################################################################################

  loss = compute_bce_with_logits_loss(predictions, label)

  return loss

def evaluate(params, graph, label, net):
  # compute_loss_fn = jax.jit(functools.partial(compute_loss, net=net))
  compute_loss_fn = functools.partial(compute_loss, net=net)

  decoded_nodes = net.apply(params, graph) # Shape == label.shape
  predictions = jax.lax.round(decoded_nodes)

  loss = compute_loss_fn(params, graph, label)
  acc = jnp.sum(predictions == label) / (label.shape[0] * label.shape[1])

  print(f'Eval loss: {loss} | Accuracy: {acc}')

def train_and_evaluate(num_training_steps):
  net = hk.without_apply_rng(hk.transform(network_definition))

  training_graph = dgl_graph_to_jraph(dgl_graph_metis_partition[0].ndata['_ID'])
  labels_training = get_labels_for_subgraph(dgl_graph_metis_partition[0].ndata['_ID'])
  labels_training = jnp.array(labels_training)

  params = net.init(jax.random.PRNGKey(42), training_graph)

  opt_init, opt_update = optax.adam(learning_rate = 1e-5)
  opt_state = opt_init(params)

  # @jax.jit
  def update(params, opt_state, graph, targets):
    loss, grads = jax.value_and_grad(compute_loss)(params, graph, targets, net)
    updates, opt_state = opt_update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state, loss

  # Train
  for idx in range(1, num_training_steps):
    graph = training_graph
    labels = labels_training

    params, opt_state, loss = update(params, opt_state, graph, labels)
    print('Loss training:', loss)

  # Evaluate
  evaluate(params, training_graph, labels_training, net)

predicted_labels = train_and_evaluate(10)

Loss training: 0.9419627
Loss training: 0.9423258
Loss training: 0.8922813
Loss training: 0.95148146
Loss training: 0.94151914
Loss training: 0.9886112
Loss training: 0.96643955


RuntimeError: ignored

In [14]:
## TODO: Take into account that nodes might be reshufled
def predict_labels(params, net):
  y_pred = None

  for i in range(num_partitions):
    graph = dgl_graph_to_jraph(dgl_graph_metis_partition[i].ndata['_ID'])

    if i == 0:
      y_pred = np.array(jax.nn.softmax(net.apply(params, graph)))
    else:
      if i % 5 == 0:
        print(f'Reached iteration {i}')
      y_pred = np.append(
          y_pred,
          np.array(jax.nn.softmax(net.apply(params, graph))),
          axis = 0
      )

  return y_pred

In [22]:
(_, true_labels) = dataset[0]
print(true_labels.shape)
print(torch.tensor(predicted_labels).shape)

torch.Size([132534, 112])
torch.Size([132534, 112])


In [23]:
from ogb.nodeproppred import Evaluator

evaluator = Evaluator(name = 'ogbn-proteins')
final_results = evaluator.eval({
    'y_true': true_labels,
    'y_pred': torch.tensor(predicted_labels)
})
print(final_results)
# print(evaluator.expected_input_format) 
# print(evaluator.expected_output_format) 

{'rocauc': 0.5000289552030827}
