<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/version_1_restart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab dependency installs for this notebook: Haiku and Jaxline (installed
# from their GitHub repositories), Jraph, OGB, DGL and Optax.
# NOTE(review): nothing here is version-pinned, and the two git installs track
# the repositories' current HEAD, so a fresh runtime may resolve different
# versions than the ones this notebook was developed against. Consider pinning
# (e.g. `%pip install -q pkg==X.Y.Z` or `git+https://...@<tag>`) for
# reproducibility.
# Comments are kept on their own lines: %pip forwards the rest of the line to
# pip, so a trailing `# ...` would not be a safe place for a comment.
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax

In [2]:
# Initialize the TPU.
# NOTE(review): jax.tools.colab_tpu.setup_tpu() targets the legacy Colab TPU
# runtime, which advertises the TPU through the COLAB_TPU_ADDR environment
# variable. On other runtimes (CPU/GPU Colab, local Jupyter) that variable is
# absent and the setup call cannot succeed, so guard it. The rest of the
# notebook then runs on JAX's default backend instead of crashing here.
import os

import jax.tools.colab_tpu

if 'COLAB_TPU_ADDR' in os.environ:
  jax.tools.colab_tpu.setup_tpu()
else:
  print('COLAB_TPU_ADDR is not set; skipping Colab TPU setup and using the default JAX backend.')

In [10]:
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

# Load ogbn-proteins through OGB's DGL wrapper, together with its predefined
# node split (a dict keyed by 'train' / 'valid' / 'test', indexed below).
# `Evaluator` is not used in this cell; presumably it is used later for the
# ROC-AUC metric.
dataset = DglNodePropPredDataset(name="ogbn-proteins")
split_idx = dataset.get_idx_split()

In [43]:
# There is only one graph in Node Property Prediction datasets.
graph, labels = dataset[0]

# OGBN-Proteins
#   #Nodes      = 132,534
#   #Edges      = 39,561,252
#   #Tasks      = 112
#   #Split Type = Species
#   #Task Type  = Binary classification
#   #Metric     = ROC-AUC
#
#   Task:
#     Predict the presence of protein functions in a multi-label binary
#     classification setup, with 112 kinds of labels to predict in total.
#     Performance is measured by the average ROC-AUC score across the 112 tasks.
#
#   Other properties: undirected, weighted, typed (according to species).
#
#   (1) Nodes represent proteins.
#     (1.1) The proteins come from 8 species:
#           len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
#
#   (2) Edges indicate different types of biologically meaningful
#       associations between proteins.
#     (2.1) All edges come with 8-dimensional features:
#           graph.edata['feat'].shape == (2 * #edges, 8)
#           (presumably because the DGL graph stores each undirected edge
#           once per direction)

# Per-split label matrices. Index the `labels` tensor unpacked above rather
# than reaching back into `dataset.labels`, so this cell depends only on the
# public `dataset[0]` accessor.
train_label = labels[split_idx['train']]  # (86619, 112)
valid_label = labels[split_idx['valid']]  # (21236, 112)
test_label = labels[split_idx['test']]    # (24679, 112)

# Turn the documented split sizes into checked invariants. Each split has one
# column per task, and the three splits together cover every node of the graph
# (86,619 + 21,236 + 24,679 = 132,534).
assert train_label.shape == (86619, 112), train_label.shape
assert valid_label.shape == (21236, 112), valid_label.shape
assert test_label.shape == (24679, 112), test_label.shape
assert len(train_label) + len(valid_label) + len(test_label) == graph.num_nodes()

In [7]:
import haiku as hk
import jax
import jax.numpy as jnp
import jraph

def network_definition(graph: jraph.GraphsTuple):
  """Defines a graph neural network.
  Args:
    graph: GraphsTuple the network processes.
  Returns:
  """
  update_node_fn = hk.Sequential(
      hk.nets.MLP([128]),
      hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
      )
  
  update_edge_fn = hk.Sequential(
      hk.nets.MLP([128]),
      hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
      )

  gn = jraph.InteractionNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn
      )