In [1]:
!pip install jraph flax dm-haiku matplotlib networkx scikit-learn ogb



In [2]:
# Imports
%matplotlib inline
import functools
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import jraph
import flax
import haiku as hk
import optax
import pickle
import numpy as onp
import networkx as nx
from typing import Any, Callable, Dict, List, Optional, Tuple
from sklearn.metrics import roc_auc_score

In [3]:
def add_self_edges_fn(receivers: jnp.ndarray, senders: jnp.ndarray,
                      total_num_nodes: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Appends one self edge (i -> i) per node to the edge endpoint arrays.

    The graph is assumed not to contain self edges already; nothing is
    de-duplicated.

    Args:
      receivers: receiver node index of every existing edge.
      senders: sender node index of every existing edge.
      total_num_nodes: number of nodes; indices `0..total_num_nodes-1` are
        appended to both arrays.

    Returns:
      The extended `(receivers, senders)` pair, in that order.
    """
    self_loop_ids = jnp.arange(total_num_nodes)
    extended = [
        jnp.concatenate([endpoints, self_loop_ids], axis=0)
        for endpoints in (receivers, senders)
    ]
    return extended[0], extended[1]

class MLP(hk.Module):
    """Multi-layer perceptron: hidden Linear+ReLU layers and a linear output layer.

    Args:
      features: output width of each layer. Every entry except the last is a
        hidden `hk.Linear` followed by a ReLU; the last entry is the width of
        the final `hk.Linear`, which has no activation.
    """

    def __init__(self, features: List[int]):
        super().__init__()
        self.features = features

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies the stacked layers to `x` and returns the result."""
        layers = []
        for feat in self.features[:-1]:
            layers.append(hk.Linear(feat))
            layers.append(jax.nn.relu)
        # The output layer is appended exactly once, after the hidden layers.
        # Appending it inside the loop would drop it for a single-entry
        # `features` and repeat it after every hidden layer for longer ones.
        layers.append(hk.Linear(self.features[-1]))

        mlp = hk.Sequential(layers)
        return mlp(x)

# Node update built from the MLP block above: one hidden layer of width 8,
# then an output layer of width 4.
def update_node_fn(x):
    """Applies a newly constructed `MLP(features=[8, 4])` to `x`."""
    return MLP(features=[8, 4])(x)

# Adapted from https://github.com/deepmind/jraph/blob/master/jraph/_src/models.py#L506
def GraphConvolution(update_node_fn: Callable,
                     aggregate_nodes_fn: Callable = jax.ops.segment_sum,
                     add_self_edges: bool = False,
                     symmetric_normalization: bool = True) -> Callable:
  """Builds a Graph Convolution layer (https://arxiv.org/abs/1609.02907).

  The returned layer first transforms the node features with `update_node_fn`
  and then aggregates the transformed features along the edges.
  NOTE: no activation is applied after the aggregation, so when stacking
  layers you may want to add an activation between them.

  Args:
    update_node_fn: function applied to the node features before aggregation.
      In the paper a single layer MLP is used.
    aggregate_nodes_fn: function used to aggregate sender features per
      receiver; called as `fn(data, segment_ids, num_segments)`.
    add_self_edges: whether to add a self edge to every node before
      aggregating, as in the paper definition of GCN. Defaults to False.
    symmetric_normalization: whether to scale features by the inverse square
      root of the sender degree before aggregation and of the receiver degree
      after it. Defaults to True.

  Returns:
    A function mapping a `GraphsTuple` to a `GraphsTuple` with updated nodes.
  """

  def _ApplyGCN(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies a Graph Convolution layer."""
    node_feats, _, receivers, senders, _, _, _ = graph

    # Transform the features first; aggregation runs on the transformed nodes.
    node_feats = update_node_fn(node_feats)
    # Leading dimension of the first leaf: equals jnp.sum(n_node) but is a
    # static Python int, so it stays usable under jit.
    num_nodes = tree.tree_leaves(node_feats)[0].shape[0]

    conv_receivers, conv_senders = receivers, senders
    if add_self_edges:
      # Each node includes itself in the aggregation. A `GraphsTuple` should
      # in principle partition edges by n_edge, but a GCN does not care
      # whether it sees a batch of graphs or one large graph, so the extra
      # edges are only used locally.
      conv_receivers, conv_senders = add_self_edges_fn(receivers, senders,
                                                       num_nodes)

    def _aggregate(x):
      # Gather each edge's sender features and reduce them per receiver.
      return aggregate_nodes_fn(x[conv_senders], conv_receivers, num_nodes)

    if not symmetric_normalization:
      return graph._replace(nodes=tree.tree_map(_aggregate, node_feats))

    def _inv_sqrt_degree(endpoint_ids):
      # Count the edges touching each node, then take 1/sqrt(degree). The
      # degree is clamped to 1 so isolated nodes do not divide by zero.
      degree = jax.ops.segment_sum(
          jnp.ones_like(conv_senders), endpoint_ids, num_nodes)
      return jax.lax.rsqrt(jnp.maximum(degree, 1.0))[:, None]

    sender_scale = _inv_sqrt_degree(conv_senders)
    receiver_scale = _inv_sqrt_degree(conv_receivers)

    # Pre-normalize by sender degree, aggregate, post-normalize by receiver
    # degree.
    node_feats = tree.tree_map(
        lambda x: _aggregate(x * sender_scale) * receiver_scale, node_feats)
    return graph._replace(nodes=node_feats)

  return _ApplyGCN

# GAT implementation adapted from https://github.com/deepmind/jraph/blob/master/jraph/_src/models.py#L442.
def GAT(attention_query_fn: Callable,
        attention_logit_fn: Callable,
        node_update_fn: Optional[Callable] = None,
        add_self_edges: bool = True) -> Callable:
  """Returns a method that applies a Graph Attention Network layer.

  Graph Attention message passing as described in
  https://arxiv.org/pdf/1710.10903.pdf. This model expects node features as a
  jnp.array, may use edge features for computing attention weights, and
  ignores global features. It does not support nests.

  Args:
    attention_query_fn: function that generates attention queries from sender
      node features.
    attention_logit_fn: function that converts attention queries into logits for
      softmax attention. Called as `fn(sent_attributes, received_attributes,
      edges)`.
    node_update_fn: function that updates the aggregated messages. If None, will
      apply leaky relu and concatenate (if using multi-head attention).
    add_self_edges: whether to append one self edge per node before computing
      attention, so that each node also attends to itself. `graph.edges` is
      passed to `attention_logit_fn` unchanged, so when edge features are
      present the logit function receives fewer edge rows than
      sender/receiver rows. Defaults to True.

  Returns:
    A function that applies a Graph Attention layer. That function raises
    `IndexError` when the graph's nodes do not expose a leading dimension
    (e.g. `nodes` is None or a scalar).
  """
  # pylint: disable=g-long-lambda
  if node_update_fn is None:
    # By default, apply the leaky relu and then concatenate the heads on the
    # feature axis.
    node_update_fn = lambda x: jnp.reshape(
        jax.nn.leaky_relu(x), (x.shape[0], -1))

  def _ApplyGAT(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Applies a Graph Attention layer."""
    nodes, edges, receivers, senders, _, _, _ = graph
    # Equivalent to the sum of n_node, but statically known.
    try:
      sum_n_node = nodes.shape[0]
    except (AttributeError, IndexError) as err:
      # `None.shape` raises AttributeError and a 0-d array raises IndexError;
      # both mean there are no per-node features to attend over.
      raise IndexError('GAT requires node features') from err

    # Pass nodes through the attention query function to transform
    # node features, e.g. with an MLP.
    nodes = attention_query_fn(nodes)

    if add_self_edges:
      # We add self edges to the senders and receivers so that each node
      # includes itself in aggregation.
      receivers, senders = add_self_edges_fn(receivers, senders, sum_n_node)

    # We compute the softmax logits using a function that takes the
    # embedded sender and receiver attributes.
    sent_attributes = nodes[senders]
    received_attributes = nodes[receivers]
    att_softmax_logits = attention_logit_fn(sent_attributes,
                                            received_attributes, edges)

    # Normalize the logits over the incoming edges of each receiver node.
    att_weights = jraph.segment_softmax(
        att_softmax_logits, segment_ids=receivers, num_segments=sum_n_node)

    # Apply attention weights.
    messages = sent_attributes * att_weights
    # Aggregate messages to nodes.
    nodes = jax.ops.segment_sum(messages, receivers, num_segments=sum_n_node)

    # Apply an update function to the aggregated messages.
    nodes = node_update_fn(nodes)

    return graph._replace(nodes=nodes)

  # pylint: enable=g-long-lambda
  return _ApplyGAT

def attention_logit_fn(sender_attr: jnp.ndarray, receiver_attr: jnp.ndarray,
                       edges: jnp.ndarray) -> jnp.ndarray:
  """Computes one attention logit per edge from its two endpoint features.

  Sender and receiver features are concatenated along axis 1 and projected to
  a single value by a linear layer. Edge features are ignored.

  Args:
    sender_attr: per-edge sender features.
    receiver_attr: per-edge receiver features.
    edges: unused.

  Returns:
    The output of `hk.Linear(1)` on the concatenated features.
  """
  del edges  # Not used by this logit function.
  logit_layer = hk.Linear(1)
  endpoint_features = jnp.concatenate([sender_attr, receiver_attr], axis=1)
  return logit_layer(endpoint_features)




In [25]:
from ogb.nodeproppred import NodePropPredDataset

# Load ogbn-arxiv together with its train/valid/test node-index split.
dataset = NodePropPredDataset(name = 'ogbn-arxiv')
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]

# Subsample the number of nodes per split for computational reasons. A seeded
# generator keeps the subsample, and therefore every reported accuracy,
# identical across kernel restarts.
subnodes = 10_000
subsample_rng = onp.random.default_rng(0)
train_idx = subsample_rng.choice(train_idx, size=subnodes, replace=False)
valid_idx = subsample_rng.choice(valid_idx, size=subnodes//2, replace=False)
test_idx = subsample_rng.choice(test_idx, size=subnodes//2, replace=False)

# The dataset holds a single graph; indexing returns it with its node labels.
graph, labels = dataset[0]

def graph_to_jax_graph(graph: dict) -> jraph.GraphsTuple:
    """Converts a graph dictionary into a single-graph `jraph.GraphsTuple`.

    Args:
      graph: dictionary with keys `'node_feat'` (node features),
        `'edge_index'` (unpacked along its first axis into senders and
        receivers) and `'num_nodes'`.

    Returns:
      A `GraphsTuple` with no edge features and a dummy `[[1]]` global.
    """
    node_features = jnp.array(graph['node_feat'])
    # First row holds the sender ids, second row the receiver ids.
    senders, receivers = jnp.array(graph['edge_index'])

    return jraph.GraphsTuple(
        nodes=node_features,
        edges=None,  # no edge features
        senders=senders,
        receivers=receivers,
        n_node=jnp.array([graph['num_nodes']]),
        n_edge=jnp.array([len(senders)]),
        globals=jnp.array([[1]]),  # dummy global
    )

  loaded_dict = torch.load(pre_processed_file_path)


In [27]:
# Move the graph and its labels onto JAX arrays.
arxiv_graph = graph_to_jax_graph(graph)
arxiv_labels = jnp.array(labels)

# One boolean mask per split over all nodes: True at the subsampled node ids
# of that split, False everywhere else.
train_mask, valid_mask, test_mask = (
    jnp.zeros(graph['num_nodes'], dtype=bool).at[split_ids].set(True)
    for split_ids in (train_idx, valid_idx, test_idx)
)

In [28]:
def build_gcn_network(layers: list[int], n_classes: int) -> Callable:
    """Builds a stacked GCN node classifier to be wrapped in `hk.transform`.

    Args:
      layers: width of each hidden graph-convolution layer, in order.
      n_classes: width of the output layer, i.e. number of logits per node.

    Returns:
      A function mapping a `GraphsTuple` to a `GraphsTuple` whose nodes hold
      `n_classes` logits per node.
    """
    def _hidden_update(width: int) -> Callable:
        # Bind `width` here so the returned function does not depend on the
        # current value of the loop variable when it is eventually called.
        return lambda n: jax.nn.relu(hk.Linear(width)(n))

    def gcn_network(graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
        # Hidden layers: Linear -> ReLU on the nodes, then normalized
        # aggregation over the edges plus a self edge per node.
        for width in layers:
            gn = GraphConvolution(
                update_node_fn=_hidden_update(width),
                add_self_edges=True
            )
            graph = gn(graph)

        # Output layer. Self edges are added here as well, consistent with the
        # hidden layers: without them a node with no incoming edges aggregates
        # nothing and ends up with all-zero logits regardless of its features.
        gn = GraphConvolution(
            update_node_fn=hk.Linear(n_classes),
            add_self_edges=True
        )
        graph = gn(graph)

        return graph

    return gcn_network

In [29]:
# Three hidden graph-convolution layers of width 128.
layers = [128] * 3
# Number of distinct label values found in the dataset.
n_classes = len(set(labels.ravel().tolist()))

# Pure init/apply pair for the GCN; `apply` takes no RNG key.
gcn_forward = build_gcn_network(layers, n_classes)
network = hk.without_apply_rng(hk.transform(gcn_forward))

In [30]:
def optimize_network(
        network: hk.Transformed, 
        num_steps: int,
    ) -> jnp.ndarray:
    """Trains `network` on the arXiv graph and returns its node predictions.

    Reads the module-level `arxiv_graph`, `arxiv_labels`, `train_mask`,
    `test_mask` and `n_classes`. Parameters are initialized with PRNG key 42
    and optimized with Adam (learning rate 1e-2) on the summed cross-entropy
    over the training nodes. Accuracies are printed every 10 steps.

    Args:
      network: transformed Haiku network with `init` / `apply`; `apply` must
        return a graph whose `nodes` are per-node class logits.
      num_steps: number of optimization steps.

    Returns:
      The predicted class index of every node after training.
    """
    params = network.init(jax.random.PRNGKey(42), arxiv_graph)

    opt_init, opt_update = optax.adam(1e-2)
    opt_state = opt_init(params)

    # Labels may carry a trailing singleton axis, i.e. shape (num_nodes, 1).
    # Left as is, `one_hot` would give (M, 1, C) targets and the `==` below
    # would compare (M,) with (M, 1); both broadcast to pairwise results, so
    # the loss and the accuracies would be wrong. Flatten once up front.
    node_labels = arxiv_labels.reshape(-1)

    @jax.jit
    def predict(params: hk.Params) -> jnp.ndarray:
        decoded_graph = network.apply(params, arxiv_graph)
        return jnp.argmax(decoded_graph.nodes, axis=1)

    @jax.jit
    def cross_entropy_loss(params: hk.Params) -> jnp.ndarray:
        """Summed cross-entropy over the training nodes."""
        decoded_graph = network.apply(params, arxiv_graph)
        log_prob = jax.nn.log_softmax(decoded_graph.nodes[train_mask])
        target = jax.nn.one_hot(node_labels[train_mask], n_classes)
        return -jnp.sum(log_prob * target)

    @jax.jit
    def update(params: hk.Params, opt_state) -> Tuple[hk.Params, Any]:
        """Returns updated params and state."""
        g = jax.grad(cross_entropy_loss)(params)
        updates, opt_state = opt_update(g, opt_state)
        return optax.apply_updates(params, updates), opt_state

    @jax.jit
    def accuracy(params: hk.Params) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Computes the node-accuracy on the train split, test split and all nodes."""
        decoded_graph = network.apply(params, arxiv_graph)
        correct = jnp.argmax(decoded_graph.nodes, axis=1) == node_labels
        training_accuracy = jnp.mean(correct[train_mask])
        testing_accuracy = jnp.mean(correct[test_mask])
        overall_accuracy = jnp.mean(correct)
        return training_accuracy, testing_accuracy, overall_accuracy

    for step in range(num_steps):
        if step % 10 == 0:
            # One forward pass per log line instead of one per printed number.
            train_acc, test_acc, overall_acc = accuracy(params)
            print(f"step {step} | train acc. {float(train_acc)*100:.3f} % - test acc. {float(test_acc)*100:.3f} % - overall acc. {float(overall_acc)*100:.3f} %")

        params, opt_state = update(params, opt_state)
    return predict(params)

In [31]:
# Run the training loop for a single optimization step; the call returns the
# per-node class predictions produced after that step.
optimize_network(network, 1)

: 