<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/version_1_restart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis

In [2]:
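## Not needed: dgl.metis_partition (used below) ran successfully without this system-wide METIS install, so the cell is kept commented out for reference only.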

# ## Install METIS
# %rm metis-*
# !wget http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/metis-5.1.0.tar.gz
# !gunzip metis-5.1.0.tar.gz
# !tar -xvf metis-5.1.0.tar

# %cd metis-5.1.0/
# !make config shared=1
# !make install

# %env METIS_DLL=/usr/local/lib/libmetis.so

In [3]:
# Initialize the TPU: configure JAX to use the Colab TPU runtime, presumably
# before any other JAX work in this notebook is done.
# NOTE(review): jax.tools.colab_tpu only exists in older JAX releases; confirm
# the installed jax version still provides it.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [4]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

# Download (on first run) and load ogbn-proteins as a DGL dataset.
dataset = DglNodePropPredDataset(name = "ogbn-proteins")
# Dict with 'train' / 'valid' / 'test' node-index tensors (used below to slice labels).
split_idx = dataset.get_idx_split()

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip


Downloaded 0.21 GB: 100%|██████████| 216/216 [00:38<00:00,  5.55it/s]


Extracting dataset/proteins.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:04<00:00,  4.19s/it]


Converting graphs into DGL objects...


100%|██████████| 1/1 [00:00<00:00,  2.34it/s]


Saving...


In [5]:
# There is only one graph in Node Property Prediction datasets
ogbn_proteins_main_graph, ogbn_proteins_main_labels = dataset[0]

'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Get split labels. Index the label tensor unpacked from dataset[0] above so
# that one name is used for the labels throughout the notebook.
train_label = ogbn_proteins_main_labels[split_idx['train']]  # (86619, 112)
valid_label = ogbn_proteins_main_labels[split_idx['valid']]  # (21236, 112)
test_label = ogbn_proteins_main_labels[split_idx['test']]    # (24679, 112)

# The three splits together cover every node (86619 + 21236 + 24679 == 132534).
assert (len(split_idx['train']) + len(split_idx['valid']) + len(split_idx['test'])
        == ogbn_proteins_main_graph.num_nodes())

In [31]:
import torch
import jraph
import jax.numpy as jnp

def dgl_graph_to_jraph(node_ids, graph=None):
  """Build a single-graph ``jraph.GraphsTuple`` from a node-induced subgraph.

  Node and edge features are not copied into the METIS partitions. The subgraph
  is therefore re-extracted from the full graph, which does carry them.

  Args:
    node_ids: ids of the nodes to keep, in ``graph``'s numbering (e.g. a
      METIS partition's ``ndata['_ID']``).
    graph: DGL graph to take the subgraph from; defaults to
      ``ogbn_proteins_main_graph``.

  Returns:
    ``jraph.GraphsTuple`` where:
      - ``nodes`` = ``ndata['species']`` and ``edges`` = ``edata['feat']``;
      - senders/receivers use the subgraph's local node numbering;
      - ``n_node``/``n_edge`` are one-element arrays;
      - ``globals`` is None.
  """
  # Local import: the notebook's own `import dgl` lives in a later cell.
  import dgl

  if graph is None:
    graph = ogbn_proteins_main_graph

  # Node and edge features are not copied when creating the metis partition,
  # so take the induced subgraph of the featured full graph instead.
  subgraph = dgl.node_subgraph(graph, node_ids)

  # edges() builds the (src, dst) tensors; call it once rather than twice.
  src, dst = subgraph.edges()

  return jraph.GraphsTuple(
      nodes=jnp.asarray(subgraph.ndata['species'].numpy()),
      # Direction does not matter: the undirected graph stores each edge in
      # both directions (edata['feat'] has 2 * #edges rows).
      senders=jnp.asarray(src.numpy()),
      receivers=jnp.asarray(dst.numpy()),
      # 8-dimensional edge features.
      edges=jnp.asarray(subgraph.edata['feat'].numpy()),
      n_node=jnp.array([subgraph.num_nodes()]),
      n_edge=jnp.array([subgraph.num_edges()]),
      globals=None,  # No global features
  )

In [8]:
import dgl

# Number of METIS parts to split the full graph into.
NUM_PARTITIONS = 10

# balance_edges=True asks METIS (per DGL docs) to also balance edge counts
# across parts. The result maps partition id -> DGL subgraph; each subgraph
# keeps the original node ids in ndata['_ID'] but none of the features.
dgl_graph_metis_partition = dgl.metis_partition(ogbn_proteins_main_graph, NUM_PARTITIONS, balance_edges = True)
assert len(dgl_graph_metis_partition) == NUM_PARTITIONS

Convert a graph into a bidirected graph: 9.997 seconds
Construct multi-constraint weights: 0.191 seconds
Metis partitioning: 54.945 seconds
Split the graph: 13.826 seconds
Construct subgraphs: 0.125 seconds


In [30]:
def get_labels_for_subgraph(ids, labels=None):
  """Return the label rows for the given (original, full-graph) node ids.

  Args:
    ids: node ids in the full graph's numbering (e.g. a partition's
      ``ndata['_ID']``). Any input ``torch.as_tensor`` accepts works: a
      tensor, list or numpy array.
    labels: label matrix to index rows from; defaults to
      ``ogbn_proteins_main_labels``.

  Returns:
    Tensor with one row of ``labels`` per id, in the order of ``ids``.
  """
  if labels is None:
    labels = ogbn_proteins_main_labels
  # index_select needs an integer index tensor on the same device as `labels`.
  index = torch.as_tensor(ids, dtype=torch.long, device=labels.device)
  return labels.index_select(0, index)

# Sanity check: display the labels of METIS partition 0, looked up through the
# full-graph node ids stored in the partition's ndata['_ID'].
get_labels_for_subgraph(dgl_graph_metis_partition[0].ndata['_ID'])

tensor([[1, 0, 1,  ..., 0, 0, 0],
        [1, 0, 1,  ..., 0, 0, 0],
        [1, 1, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [37]:
# How many METIS partitions to convert. Only the first one is converted for
# now; node_ids / jraph_graph / labels hold the last converted partition.
NUM_PARTITIONS_TO_CONVERT = 1

for i in range(NUM_PARTITIONS_TO_CONVERT):
  node_ids = dgl_graph_metis_partition[i].ndata['_ID']

  jraph_graph = dgl_graph_to_jraph(node_ids)
  labels = get_labels_for_subgraph(node_ids)

  # Every node of the converted subgraph must get exactly one label row.
  assert labels.shape[0] == jraph_graph.nodes.shape[0] == int(jraph_graph.n_node[0])

(15651, 1)
torch.Size([15651, 112])


In [None]:
import haiku as hk
import jax
import jax.numpy as jnp
import jraph

def network_definition(graph: jraph.GraphsTuple):
  """Defines a graph neural network.
  Args:
    graph: GraphsTuple the network processes.
  Returns:
  """
  update_node_fn = hk.Sequential(
      hk.nets.MLP([128]),
      hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
      )
  
  update_edge_fn = hk.Sequential(
      hk.nets.MLP([128]),
      hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
      )

  gn = jraph.InteractionNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn
      )