<a href="https://colab.research.google.com/github/bf319/Scaling_MPNNs/blob/main/version_1_restart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -q git+https://github.com/deepmind/dm-haiku
%pip install -q jraph
%pip install -q git+https://github.com/deepmind/jaxline
%pip install -q ogb
%pip install -q dgl
%pip install -q optax
%pip install -q metis

  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 75 kB 2.0 MB/s 
[K     |████████████████████████████████| 70 kB 3.4 MB/s 
[K     |████████████████████████████████| 77 kB 4.2 MB/s 
[?25h  Building wheel for jaxline (setup.py) ... [?25l[?25hdone
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 78 kB 2.9 MB/s 
[?25h  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 4.4 MB 5.2 MB/s 
[K     |████████████████████████████████| 136 kB 5.2 MB/s 
[?25h  Building wheel for metis (setup.py) ... [?25l[?25hdone


In [2]:
## A manual METIS build does not seem to be needed for the partitioning: `dgl.metis_partition` (used below) ran with this cell commented out.

# ## Install METIS
# %rm metis-*
# !wget http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/metis-5.1.0.tar.gz
# !gunzip metis-5.1.0.tar.gz
# !tar -xvf metis-5.1.0.tar

# %cd metis-5.1.0/
# !make config shared=1
# !make install

# %env METIS_DLL=/usr/local/lib/libmetis.so

In [3]:
# Initialize the TPU: call JAX's Colab helper so that later JAX code in this
# notebook can target the Colab TPU runtime.
# NOTE(review): presumably this has to run before any other JAX call and only
# works on a Colab TPU runtime — confirm against the installed JAX version.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [4]:
from ogb.nodeproppred import Evaluator
from ogb.nodeproppred import DglNodePropPredDataset

# Load ogbn-proteins as a DGL dataset (the cell output shows the archive being
# downloaded and extracted under `dataset/` on first run).
dataset = DglNodePropPredDataset(name = "ogbn-proteins")
# Node index sets for the splits; the 'train', 'valid' and 'test' keys are read
# further down when slicing the labels.
split_idx = dataset.get_idx_split()

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Downloading http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip


Downloaded 0.21 GB: 100%|██████████| 216/216 [00:16<00:00, 12.72it/s]


Extracting dataset/proteins.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:05<00:00,  5.72s/it]


Converting graphs into DGL objects...


100%|██████████| 1/1 [00:00<00:00,  1.53it/s]


Saving...


In [5]:
# There is only one graph in Node Property Prediction datasets, so item 0 is the
# whole (graph, labels) pair: the DGL graph and the labels of its nodes.
ogbn_proteins_main_graph, ogbn_proteins_main_labels = dataset[0]

'''
  OGBN-Proteins
    #Nodes = 132,534
    #Edges = 39,561,252
    #Tasks = 112
    #Split Type = Species
    #Task Type = Binary classification
    #Metric = ROC-AUC

    Task:
      The task is to predict the presence of protein functions in a multi-label binary classification setup,
      where there are 112 kinds of labels to predict in total. 
      The performance is measured by the average of ROC-AUC scores across the 112 tasks.

    #Others:
      **undirected**
      **weighted**
      **typed (according to species)**

  (1) Nodes represent proteins
    (1.1) The proteins come from 8 species
      len(set(graph.ndata['species'].reshape(-1).tolist())) == 8
    (1.2) Each node has one feature associated with it (its species)
      graph.ndata['species'].shape == (#nodes, 1)
  
  (2) Edges indicate different types of biologically meaningful associations between proteins
    (2.1) All edges come with 8-dimensional features
      graph.edata['feat'].shape == (2 * #edges, 8)

'''
# Slice the full label matrix once per split, in train / valid / test order.
# Shapes recorded by the original author:
#   train (86619, 112), valid (21236, 112), test (24679, 112)
train_label, valid_label, test_label = (
    dataset.labels[split_idx[split_name]]
    for split_name in ('train', 'valid', 'test')
)

In [6]:
import dgl

'''
  Generate graph partition using metis, with balanced number of edges in each partition.
  Note: 
    The subgraphs do not contain the node/edge data in the input graph (https://docs.dgl.ai/generated/dgl.metis_partition.html)
'''
# Number of METIS partitions the full graph is split into.
num_partitions = 10
# Partition result; later cells index it by partition id (0 .. num_partitions - 1)
# and read each partition's ndata['_ID'] to get the node IDs used to re-extract
# the subgraph (with features) from the full graph.
dgl_graph_metis_partition = dgl.metis_partition(ogbn_proteins_main_graph, num_partitions, balance_edges = True)

Convert a graph into a bidirected graph: 8.475 seconds
Construct multi-constraint weights: 0.073 seconds
Metis partitioning: 42.739 seconds
Split the graph: 11.870 seconds
Construct subgraphs: 0.198 seconds


In [7]:
import torch
import jraph
import jax.numpy as jnp

def dgl_graph_to_jraph(node_ids, graph = None):
  """Builds a single-graph `jraph.GraphsTuple` for the subgraph induced by `node_ids`.

  The METIS partitions do not carry the node/edge data of the input graph, so
  the induced subgraph is re-extracted from the full graph to get the
  'species' node features and the 'feat' edge features back.

  Args:
    node_ids: IDs of the nodes to keep; forwarded unchanged to `dgl.node_subgraph`.
    graph: DGL graph to extract the subgraph from. When omitted, the
      module-level `ogbn_proteins_main_graph` is used (the previous behaviour).

  Returns:
    A `jraph.GraphsTuple` whose `nodes` are the 'species' features, `edges` the
    'feat' edge features, `senders`/`receivers` the edge endpoints, `n_node` and
    `n_edge` one-element arrays with the subgraph sizes, and `globals` None.
  """
  if graph is None:
    graph = ogbn_proteins_main_graph

  # Taking the node-induced subgraph of the full graph adds back the node and
  # edge features.
  subgraph = dgl.node_subgraph(graph, node_ids)

  # DGLGraph.edges uses the default (uv) format for edges, with u == source and
  # v == destination, see
  # https://docs.dgl.ai/generated/dgl.DGLGraph.edges.html#dgl.DGLGraph.edges
  # Query the edge list once and unpack both endpoints from the same result.
  source_ids, destination_ids = subgraph.edges()

  return jraph.GraphsTuple(
            nodes = jnp.array(subgraph.ndata['species']),
            senders = jnp.array(source_ids),
            receivers = jnp.array(destination_ids),
            edges = jnp.array(subgraph.edata['feat']),  # edge features from edata['feat']
            n_node = jnp.array([subgraph.num_nodes()]),
            n_edge = jnp.array([subgraph.num_edges()]),
            globals = None  # No global features
          )
  
def get_labels_for_subgraph(node_ids):
  """Returns the rows of the full label tensor that belong to `node_ids`.

  Args:
    node_ids: index tensor selecting rows (dimension 0) of the module-level
      `ogbn_proteins_main_labels`.

  Returns:
    The selected label rows, in the order given by `node_ids`.
  """
  subgraph_labels = torch.index_select(ogbn_proteins_main_labels, 0, node_ids)
  return subgraph_labels

In [8]:
# Convert every METIS partition to a jraph graph plus its label rows.
# NOTE(review): `jraph_graph` and `labels` are overwritten on each iteration, so
# only the last partition's values remain after the loop — this looks like a
# smoke test of the conversion; confirm whether the results should be collected.
for i in range(num_partitions):
  # Node IDs of partition i, as stored by the partitioner in ndata['_ID'].
  node_ids = dgl_graph_metis_partition[i].ndata['_ID']

  jraph_graph = dgl_graph_to_jraph(node_ids)
  labels = get_labels_for_subgraph(node_ids)

In [None]:
import haiku as hk
import jax
import jax.numpy as jnp
import jraph

In [10]:
# def network_definition(graph: jraph.GraphsTuple):
#   """Defines a graph neural network.
#   Args:
#     graph: GraphsTuple the network processes.
#   Returns:
#   """
#   update_node_fn = hk.Sequential(
#       hk.nets.MLP([128]),
#       hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
#       )
  
#   update_edge_fn = hk.Sequential(
#       hk.nets.MLP([128]),
#       hk.LayerNorm(axis = -1, create_scale = True, create_offset = True)
#       )

#   gn = jraph.InteractionNetwork(
#       update_edge_fn=update_edge_fn,
#       update_node_fn=update_node_fn
#       )
  
# # From https://github.com/deepmind/deepmind-research/blob/master/ogb_lsc/pcq/model.py
# def _softmax_cross_entropy(
#     logits: jnp.DeviceArray,
#     targets: jnp.DeviceArray,
# ) -> jnp.DeviceArray:
#   logits = jax.nn.log_softmax(logits)
#   return -jnp.sum(targets * logits, axis=-1)

# def get_loss(pred, targets):
#   targets /= jnp.maximum(1., jnp.sum(targets, axis=-1, keepdims=True))
#   loss = _softmax_cross_entropy(pred, targets)

#   return loss


In [13]:
# from typing import Callable, NamedTuple, Sequence

# def _get_activation_fn(name: str) -> Callable[[jnp.ndarray], jnp.ndarray]:
#   if name == 'identity':
#     return lambda x: x
#   if hasattr(jax.nn, name):
#     return getattr(jax.nn, name)
#   raise ValueError('Unknown activation function %s specified. '
#                    'See https://jax.readthedocs.io/en/latest/jax.nn.html'
#                    'for the list of supported function names.')
  
# class ModelOutput(NamedTuple):
#   node_embeddings: jnp.ndarray
#   node_embedding_projections: jnp.ndarray
#   node_projection_predictions: jnp.ndarray
#   node_logits: jnp.ndarray

# class MyModuleVersion1(hk.Module):
#   def __init__(
#       self,
#       mlp_hidden_sizes: Sequence[int],
#       latent_size: int,
#       num_classes: int,
#       num_message_passing_steps: int = 2,
#       activation: str = 'relu',
#       dropout_rate: float = 0.0,
#       dropedge_rate: float = 0.0,
#       use_sent_edges: bool = False,
#       disable_edge_updates: bool = False,
#       normalization_type: str = 'layer_norm',
#       aggregation_function: str = 'sum',
#       name='MyModuleVersion1',
#   ):
#     super().__init__(name=name)
#     self._num_classes = num_classes
#     self._latent_size = latent_size
#     self._output_sizes = list(mlp_hidden_sizes) + [latent_size]
#     self._num_message_passing_steps = num_message_passing_steps
#     self._activation = _get_activation_fn(activation)
#     self._dropout_rate = dropout_rate
#     self._dropedge_rate = dropedge_rate
#     self._use_sent_edges = use_sent_edges
#     self._disable_edge_updates = disable_edge_updates
#     self._normalization_type = normalization_type
#     self._aggregation_function = aggregation_function

#     def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
#       node_key, edge_key = hk.next_rng_keys(2)
#       nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
#       edges = graph.edges
#       if not self._disable_edge_updates:
#         edges = hk.dropout(edge_key, self._dropout_rate, edges)
#       return graph._replace(nodes=nodes, edges=edges)

#     def build_gn(
#         output_sizes: Sequence[int],
#         activation: Callable[[jnp.ndarray], jnp.ndarray],
#         suffix: str,
#         use_sent_edges: bool,
#         is_training: bool,
#         dropedge_rate: float,
#         normalization_type: str,
#         aggregation_function: str,
#     ):
#       """Builds an InteractionNetwork with MLP update functions."""
#       node_update_fn = build_update_fn(
#           f'node_processor_{suffix}',
#           output_sizes,
#           activation=activation,
#           normalization_type=normalization_type,
#           is_training=is_training,
#       )
#       edge_update_fn = build_update_fn(
#           f'edge_processor_{suffix}',
#           output_sizes,
#           activation=activation,
#           normalization_type=normalization_type,
#           is_training=is_training,
#       )

#     def build_update_fn(
#         name: str,
#         output_sizes: Sequence[int],
#         activation: Callable[[jnp.ndarray], jnp.ndarray],
#         normalization_type: str,
#         is_training: bool,
#     ):
#       """Builds update function."""

#       def single_mlp(inner_name: str):
#         """Creates a single MLP performing the update."""
#         mlp = hk.nets.MLP(
#             output_sizes=output_sizes,
#             name=inner_name,
#             activation=activation)
#         mlp = jraph.concatenated_args(mlp)
#         if normalization_type == 'layer_norm':
#           norm = hk.LayerNorm(
#               axis=-1,
#               create_scale=True,
#               create_offset=True,
#               name=name + '_layer_norm')
#         elif normalization_type == 'batch_norm':
#           batch_norm = hk.BatchNorm(
#               create_scale=True,
#               create_offset=True,
#               decay_rate=0.9,
#               name=f'{inner_name}_batch_norm',
#               cross_replica_axis=None if hk.running_init() else 'i',
#           )
#           norm = lambda x: batch_norm(x, is_training)
#         elif normalization_type == 'none':
#           return mlp
#         else:
#           raise ValueError(f'Unknown normalization type {normalization_type}')
#         return jraph.concatenated_args(hk.Sequential([mlp, norm]))

#       return single_mlp(f'{name}_homogeneous')

#     def _encode(
#         self,
#         graph: jraph.GraphsTuple,
#         is_training: bool,
#     ) -> jraph.GraphsTuple:
#       node_embed_fn = build_update_fn(
#           'node_encoder',
#           self._output_sizes,
#           activation=self._activation,
#           normalization_type=self._normalization_type,
#           is_training=is_training,
#       )
#       edge_embed_fn = build_update_fn(
#           'edge_encoder',
#           self._output_sizes,
#           activation=self._activation,
#           normalization_type=self._normalization_type,
#           is_training=is_training,
#       )
#       gn = jraph.GraphMapFeatures(edge_embed_fn, node_embed_fn)
#       graph = gn(graph)
#       if is_training:
#         graph = self._dropout_graph(graph)
#       return graph

#     def _process(
#         self,
#         graph: jraph.GraphsTuple,
#         is_training: bool,
#     ) -> jraph.GraphsTuple:
#       for idx in range(self._num_message_passing_steps):
#         net = build_gn(
#             output_sizes=self._output_sizes,
#             activation=self._activation,
#             suffix=str(idx),
#             use_sent_edges=self._use_sent_edges,
#             is_training=is_training,
#             dropedge_rate=self._dropedge_rate,
#             normalization_type=self._normalization_type,
#             aggregation_function=self._aggregation_function)
#         residual_graph = net(graph)
#         graph = graph._replace(nodes=graph.nodes + residual_graph.nodes)
#         if not self._disable_edge_updates:
#           graph = graph._replace(edges=graph.edges + residual_graph.edges)
#         if is_training:
#           graph = self._dropout_graph(graph)
#       return graph

#     def _node_mlp(
#         self,
#         graph: jraph.GraphsTuple,
#         is_training: bool,
#         output_size: int,
#         name: str,
#     ) -> jnp.ndarray:
#       decoder_sizes = list(self._output_sizes[:-1]) + [output_size]
#       net = build_update_fn(
#           name,
#           decoder_sizes,
#           self._activation,
#           normalization_type=self._normalization_type,
#           is_training=is_training,
#       )
#       return net(graph.nodes)

#     def __call__(
#         self,
#         graph: jraph.GraphsTuple,
#         is_training: bool,
#         stop_gradient_embedding_to_logits: bool = False,
#     ) -> ModelOutput:
#       # Note that these update configs may need to change if
#       # we switch back to GraphNetwork rather than InteractionNetwork.

#       graph = self._encode(graph, is_training)
#       graph = self._process(graph, is_training)
#       node_embeddings = graph.nodes
#       node_projections = self._node_mlp(graph, is_training, self._latent_size,
#                                         'projector')
#       node_predictions = self._node_mlp(
#           graph._replace(nodes=node_projections),
#           is_training,
#           self._latent_size,
#           'predictor',
#       )
#       if stop_gradient_embedding_to_logits:
#         graph = jax.tree_map(jax.lax.stop_gradient, graph)
#       node_logits = self._node_mlp(graph, is_training, self._num_classes,
#                                   'logits_decoder')
#       return ModelOutput(
#           node_embeddings=node_embeddings,
#           node_logits=node_logits,
#           node_embedding_projections=node_projections,
#           node_projection_predictions=node_predictions,
#       )

In [43]:
def try_to_apply_interaction_network():
  """Smoke-tests a small jraph InteractionNetwork on the first METIS partition.

  Converts partition 0 to a `jraph.GraphsTuple`, initialises the
  Haiku-transformed network with a fixed PRNG key, runs one forward pass and
  prints the shape of the updated node features followed by the shape of the
  input node features.
  """
  def network_definition(
      graph: jraph.GraphsTuple,
      num_message_passing_steps: int = 1) -> jraph.GraphsTuple:
    """Defines a graph neural network.
    Args:
      graph: GraphsTuple the network processes.
      num_message_passing_steps: number of message passing steps.
    Returns:
      The GraphsTuple produced by the last message passing step.
    """
    # Update-function signatures follow https://github.com/deepmind/jraph/blob/38817e3a75b8a70e87abdc2c8ed00c40822f10b4/jraph/examples/lstm.py#L112
    def update_edge_fn(edges, sender_nodes, receiver_nodes):
      # Only the edge features are transformed (Linear(64) + ReLU); the sender
      # and receiver node features are accepted but not used.
      return hk.Sequential([hk.Linear(64), jax.nn.relu])(edges)

    def update_node_fn(nodes, received_edges):
      # Concatenate each node's features with the edge features passed in for
      # that node, then apply Linear(64) + ReLU.
      node_inputs = jnp.concatenate([nodes, received_edges], axis=-1)

      return hk.Sequential([hk.Linear(64), jax.nn.relu])(node_inputs)

    gn = jraph.InteractionNetwork(
          update_edge_fn=update_edge_fn,
          update_node_fn=update_node_fn)

    for _ in range(num_message_passing_steps):
      graph = gn(graph)

    return graph

  # No RNG is needed at apply time, so drop the rng argument from `apply`.
  network = hk.without_apply_rng(hk.transform(network_definition))

  # Only the graph is needed for this forward pass; the labels of the partition
  # are not used, so they are no longer looked up here.
  jraph_graph = dgl_graph_to_jraph(dgl_graph_metis_partition[0].ndata['_ID'])

  params = network.init(jax.random.PRNGKey(42), jraph_graph)

  print(network.apply(params, jraph_graph).nodes.shape)
  print(jraph_graph.nodes.shape)

# Run the smoke test once; the two printed shapes are the network's output node
# features and the input node features of partition 0.
try_to_apply_interaction_network()

(14984, 64)
(14984, 1)
