<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/version_1_restart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab-only dependency installation.
# NOTE(review): versions are unpinned (dm-haiku and jaxline are installed from
# their git HEAD), so re-running later may pull incompatible releases; consider
# pinning once a working set is known.
# NOTE(review): `metis` appears to be a Python wrapper that locates the native
# METIS library through the METIS_DLL environment variable (set in the METIS
# install cell) — confirm before relying on it.
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis

  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 75 kB 2.7 MB/s 
[K     |████████████████████████████████| 70 kB 3.7 MB/s 
[K     |████████████████████████████████| 77 kB 5.6 MB/s 
[?25h  Building wheel for jaxline (setup.py) ... [?25l[?25hdone
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 78 kB 3.1 MB/s 
[?25h  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 4.4 MB 5.5 MB/s 
[K     |████████████████████████████████| 136 kB 5.5 MB/s 
[?25h  Building wheel for metis (setup.py) ... [?25l[?25hdone


In [None]:
## Install METIS
# Builds the native METIS shared library used by the `metis` Python package.
# Remove any previous download/extraction first so the cell can be re-run
# (plain `rm metis-*` cannot delete the extracted directory).
!rm -rf metis-5.1.0 metis-5.1.0.tar metis-5.1.0.tar.gz
!wget http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/metis-5.1.0.tar.gz
!tar -xzf metis-5.1.0.tar.gz

# Build inside the source tree, then restore the previous working directory.
# A bare `%cd` would leak into every later cell (e.g. the OGB dataset would be
# downloaded into metis-5.1.0/dataset).
%pushd metis-5.1.0/
!make config shared=1
!make install
%popd

# `make install` places the shared library under /usr/local/lib by default.
%env METIS_DLL=/usr/local/lib/libmetis.so

In [2]:
# Initialize the TPU
# Configures JAX in this Colab runtime to use the attached TPU.
# NOTE(review): presumably this must run before other JAX computations in the
# session — confirm against the jax.tools.colab_tpu documentation.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [3]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

# Load ogbn-proteins as a DGL dataset (downloaded and processed on first run;
# see the download log in the output below).
dataset = DglNodePropPredDataset(name = "ogbn-proteins")
# Node-index tensors for each split, keyed by 'train' / 'valid' / 'test'.
split_idx = dataset.get_idx_split()

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip


Downloaded 0.21 GB: 100%|██████████| 216/216 [00:27<00:00,  7.74it/s]


Extracting dataset/proteins.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:01<00:00,  1.84s/it]


Converting graphs into DGL objects...


100%|██████████| 1/1 [00:00<00:00,  1.35it/s]


Saving...


In [4]:
# There is only one graph in Node Property Prediction datasets
# dataset[0] yields the DGL graph together with the node label tensor.
graph, labels = dataset[0]

'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Get split labels: index the `labels` tensor unpacked from dataset[0] above
# with each split's node ids. Expected shapes:
#   train (86619, 112), valid (21236, 112), test (24679, 112)
train_label, valid_label, test_label = (
    labels[split_idx[split_name]] for split_name in ('train', 'valid', 'test')
)

In [38]:
import torch
import jraph
import jax.numpy as jnp
import metis

def dgl_graph_to_jraph(graph, max_edges=None):
  """Converts a single DGL graph into a one-graph ``jraph.GraphsTuple``.

  Node features come from ``graph.ndata['species']`` and edge features from
  ``graph.edata['feat']``. Senders, receivers, edge features and ``n_edge``
  are always sliced together, so every array in the result describes the same
  set of edges.

  Args:
    graph: DGL graph with ``ndata['species']`` and ``edata['feat']`` (as in
      ogbn-proteins).
    max_edges: Optional cap on the number of edges kept, taking the first
      ``max_edges`` edges in DGL edge-id order. ``None`` (default) keeps every
      edge. Handy for quick experiments on the large ogbn-proteins graph.

  Returns:
    A ``jraph.GraphsTuple`` holding a single graph, with ``globals=None``.

  Raises:
    ValueError: if ``max_edges`` is negative.
  """
  if max_edges is not None and max_edges < 0:
    raise ValueError(f'max_edges must be non-negative, got {max_edges}')

  node_features = jnp.array(graph.ndata['species'])

  # graph.edges() returns (src, dst) ordered by edge id, i.e. aligned with the
  # rows of graph.edata['feat']. Fetch it once instead of twice.
  src, dst = graph.edges()
  edge_features = graph.edata['feat']

  # TODO: Figure out how to use **metis** to split the graph instead of
  # truncating the edge list.
  if max_edges is not None:
    # Truncate endpoints and features together; truncating only the features
    # would leave senders/receivers longer than edges.
    src = src[:max_edges]
    dst = dst[:max_edges]
    edge_features = edge_features[:max_edges]

  # Order should not matter here because the graph is undirected (DGL stores
  # each undirected edge in both directions).
  senders = jnp.array(src)
  receivers = jnp.array(dst)
  edges = jnp.array(edge_features)

  n_node = jnp.array([graph.num_nodes()])
  # Count the edges actually kept so n_edge agrees with senders/receivers.
  n_edge = jnp.array([senders.shape[0]])

  return jraph.GraphsTuple(
      nodes=node_features,
      senders=senders,
      receivers=receivers,
      edges=edges,
      n_node=n_node,
      n_edge=n_edge,
      globals=None,  # No global features
  )

# Convert the ogbn-proteins DGL graph loaded above into a jraph GraphsTuple.
jraph_ogb_proteins_graph = dgl_graph_to_jraph(graph)

In [7]:
import haiku as hk
import jax
import jax.numpy as jnp
import jraph

def network_definition(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
  """Defines a graph neural network and applies it to ``graph``.

  Builds one ``jraph.InteractionNetwork`` layer whose edge and node update
  functions are each an ``MLP([128])`` followed by ``LayerNorm``. Because it
  instantiates Haiku modules, it must be called inside ``hk.transform``.

  Args:
    graph: GraphsTuple the network processes.

  Returns:
    The GraphsTuple produced by the interaction network: same connectivity as
    ``graph``, with nodes and edges replaced by 128-dimensional features.
  """
  def _mlp_with_layer_norm():
    # hk.Sequential takes a single *list* of layers; passing the layers as
    # separate positional arguments would bind the LayerNorm to `name`.
    return hk.Sequential([
        hk.nets.MLP([128]),
        hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
    ])

  # InteractionNetwork calls update_edge_fn(edges, sent, received) and
  # update_node_fn(nodes, aggregated_received_edges) with several positional
  # arrays. concatenated_args joins them along the last axis so the MLP sees a
  # single feature tensor, instead of hk.Sequential forwarding the extra
  # arrays to MLP.__call__ as dropout_rate / rng.
  update_node_fn = jraph.concatenated_args(_mlp_with_layer_norm())
  update_edge_fn = jraph.concatenated_args(_mlp_with_layer_norm())

  gn = jraph.InteractionNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
  )

  return gn(graph)

RuntimeError: ignored