In [None]:
## Molecule classification on MUTAG (adapted from the Jraph Colab notebook)

In [1]:
import jraph 
import jax
import pickle
import networkx as nx
import optax
import jax.numpy as jnp
import functools

from flax import linen as nn
from typing import Sequence
from typing import Any, Callable, Dict, List, Optional, Tuple

In [2]:
# Download the jraph version of MUTAG (uncomment on first run; it is saved to tmp/mutag.pickle).
# !wget -P tmp/ https://storage.googleapis.com/dm-educational/assets/graph-nets/jraph_datasets/mutag.pickle

In [3]:
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Builds a directed networkx graph from the first graph in a `GraphsTuple`.

    Node ids are 0..n_node[0]-1. If `nodes` is present, each node gets a
    `node_feature` attribute. If `edges` is present, each edge gets an
    `edge_feature` attribute. Edge `e` runs from `senders[e]` to `receivers[e]`.

    Args:
        jraph_graph: graph to convert; only the first `n_node[0]` nodes and
            `n_edge[0]` edges are read.

    Returns:
        A `networkx.DiGraph`.
    """
    node_feats = jraph_graph.nodes
    edge_feats = jraph_graph.edges
    senders = jraph_graph.senders
    receivers = jraph_graph.receivers
    nx_graph = nx.DiGraph()

    for node_id in range(jraph_graph.n_node[0]):
        if node_feats is None:
            nx_graph.add_node(node_id)
        else:
            nx_graph.add_node(node_id, node_feature=node_feats[node_id])

    for edge_id in range(jraph_graph.n_edge[0]):
        src, dst = int(senders[edge_id]), int(receivers[edge_id])
        if edge_feats is None:
            nx_graph.add_edge(src, dst)
        else:
            nx_graph.add_edge(src, dst, edge_feature=edge_feats[edge_id])

    return nx_graph


def draw_jraph_graph_structure(jraph_graph: jraph.GraphsTuple) -> None:
    """Draws the node/edge structure of a graph using a spring layout.

    Args:
        jraph_graph: graph to draw; converted with
            `convert_jraph_to_networkx_graph` (first graph only).
    """
    draw_kwargs = dict(with_labels=True, node_size=500, font_color='yellow')
    graph_nx = convert_jraph_to_networkx_graph(jraph_graph)
    nx.draw(graph_nx, pos=nx.spring_layout(graph_nx), **draw_kwargs)

In [4]:
# NOTE(security): `pickle.load` can execute arbitrary code embedded in the file.
# Only load a pickle from a source you trust (here, the dm-educational bucket
# referenced in the download cell above).
MUTAG_PICKLE_PATH = 'tmp/mutag.pickle'
with open(MUTAG_PICKLE_PATH, 'rb') as mutag_file:
    # Presumably a list of dicts with 'input_graph' and 'target' keys (as used
    # by `train`/`evaluate` below) -- verify if the dataset source changes.
    mutag_ds = pickle.load(mutag_file)

In [5]:
# for i, obj in enumerate(mutag_ds[1:150]):

#     g = obj['input_graph']
#     print(f'Number of nodes: {g.n_node[0]}')
#     print(f'Number of edges: {g.n_edge[0]}')
#     print(f'Node features shape: {g.nodes.shape}')
#     print(f'Edge features shape: {g.edges.shape}')
#     print(f"Target: {obj['target']}")

In [6]:
# draw_jraph_graph_structure(g)

In [7]:
# First NUM_TRAIN_GRAPHS examples are used for training, the rest for testing.
NUM_TRAIN_GRAPHS = 150
train_mutag_ds, test_mutag_ds = mutag_ds[:NUM_TRAIN_GRAPHS], mutag_ds[NUM_TRAIN_GRAPHS:]

In [8]:
# Padding

def _nearest_bigger_power_of_two(x: int) -> int:
    """Computes the nearest power of two greater than x for padding."""
    y = 2
    while y < x:
        y *= 2
    return y

def pad_graph_to_nearest_power_of_two(
    graphs_tuple: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Pads a batched `GraphsTuple` so node/edge counts are powers of two.

    For example, a `GraphsTuple` with 7 nodes, 5 edges and 3 graphs becomes:
      7 nodes  --> 8 nodes (2^3), plus 1 node for the padding graph = 9
      5 edges  --> 8 edges (2^3)
      3 graphs --> 4 graphs (one extra padding graph)
    The padding is done by `jraph.pad_with_graphs`, which needs at least one
    spare node and one spare graph to hold the padding.

    Args:
        graphs_tuple: a batched `GraphsTuple` (can be batch size 1).

    Returns:
        The padded `GraphsTuple`.
    """
    total_nodes = jnp.sum(graphs_tuple.n_node)
    total_edges = jnp.sum(graphs_tuple.n_edge)
    num_graphs = graphs_tuple.n_node.shape[0]

    # +1 node so pad_with_graphs always has room for its padding node.
    target_nodes = _nearest_bigger_power_of_two(total_nodes) + 1
    target_edges = _nearest_bigger_power_of_two(total_edges)
    # +1 graph for the padding graph. The graph count is not rounded up because
    # the batch size is fixed.
    target_graphs = num_graphs + 1

    return jraph.pad_with_graphs(graphs_tuple, target_nodes, target_edges, target_graphs)

In [25]:
class ExplicitMLP(nn.Module):
    """Multi-layer perceptron built from a stack of `nn.Dense` layers.

    Attributes:
        features: output width of each Dense layer, in order. ReLU is applied
            after every layer except the last.
    """
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        last_layer = len(self.features) - 1
        hidden = inputs
        # Layers are created in the same order as `features`, so Flax's
        # auto-generated names (Dense_0, Dense_1, ...) follow that order.
        for layer_idx, width in enumerate(self.features):
            hidden = nn.Dense(width)(hidden)
            if layer_idx < last_layer:
                hidden = nn.relu(hidden)
        return hidden
    
def make_embed_fn(latent_size):
    """Returns a function that projects its input to `latent_size` features.

    Each call of the returned function creates and applies a new `nn.Dense`.
    It is intended to be called from inside a Flax `@nn.compact` method, as
    `GraphNetwork` does.
    """
    return lambda inputs: nn.Dense(latent_size)(inputs)

def make_mlp(features):
    """Returns a jraph update function backed by an `ExplicitMLP(features)`.

    The function is wrapped with `jraph.concatenated_args`, so the inputs jraph
    passes to the update function are concatenated before they reach the MLP.
    """
    def apply_mlp(concatenated_inputs):
        mlp = ExplicitMLP(features)
        return mlp(concatenated_inputs)

    return jraph.concatenated_args(apply_mlp)

class GraphNetwork(nn.Module):
    """Flax wrapper around `jraph.GraphNetwork` for graph classification.

    Embeds nodes, edges and globals to `latent_size`, then runs one
    `jraph.GraphNetwork` step. The global update emits `num_classes` logits
    per graph.

    Attributes:
        mlp_features: hidden widths of the node/edge/global update MLPs. Any
            sequence (tuple or list) is accepted.
        latent_size: width of the initial node/edge/global embeddings.
        num_classes: size of the per-graph output logits; defaults to 2
            (binary classification).
    """
    mlp_features: Sequence[int]
    latent_size: int
    num_classes: int = 2

    @nn.compact
    def __call__(self, graph):
        # Add a zero global feature per graph so the global update has an
        # input; its output is the graph-level prediction.
        graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))

        embedder = jraph.GraphMapFeatures(
            embed_node_fn=make_embed_fn(self.latent_size),
            embed_edge_fn=make_embed_fn(self.latent_size),
            embed_global_fn=make_embed_fn(self.latent_size))

        # Normalize to a tuple so list-valued `mlp_features` works too
        # (list + tuple raises TypeError).
        hidden_features = tuple(self.mlp_features)
        net = jraph.GraphNetwork(
            update_node_fn=make_mlp(hidden_features),
            update_edge_fn=make_mlp(hidden_features),
            # One logit per class. A single output would make log_softmax
            # identically zero, giving zero loss and zero gradients.
            update_global_fn=make_mlp(hidden_features + (self.num_classes,)))

        return net(embedder(graph))

In [26]:
def compute_loss(params, graph, label, net):
    """Computes masked cross-entropy loss and accuracy for a padded batch.

    Args:
        params: Flax variables for `net`.
        graph: a `GraphsTuple` padded with `jraph.pad_with_graphs`, so it
            contains a trailing padding graph.
        label: integer class per graph, including a dummy entry for the
            padding graph.
        net: the Flax module; `net.apply(params, graph).globals` holds the
            per-graph logits.

    Returns:
        (loss, accuracy): mean cross-entropy and accuracy over the real
        (non-padding) graphs only.
    """
    pred_graph = net.apply(params, graph)
    logits = pred_graph.globals
    log_probs = jax.nn.log_softmax(logits)
    # Take the class count from the model output instead of hard-coding it.
    targets = jax.nn.one_hot(label, logits.shape[-1])

    # The padding graph added by `pad_with_graphs` must not contribute to the
    # loss or accuracy. `get_graph_padding_mask` is True for real graphs.
    mask = jraph.get_graph_padding_mask(pred_graph)
    num_real = jnp.sum(mask)

    # Per-graph cross entropy, averaged over real graphs only. This keeps the
    # padding graph and its dummy label out of training.
    per_graph_nll = -jnp.sum(log_probs * targets, axis=-1)
    loss = jnp.sum(per_graph_nll * mask) / num_real

    accuracy = jnp.sum((jnp.argmax(logits, axis=-1) == label) * mask) / num_real
    return loss, accuracy

In [27]:
def train(dataset, num_train_steps, learning_rate=1e-4):
    """Trains a `GraphNetwork` with Adam, using one dataset example per step.

    Args:
        dataset: sequence of dicts with 'input_graph' (a `GraphsTuple`) and
            'target' (label array) entries. It is cycled if `num_train_steps`
            exceeds its length.
        num_train_steps: number of optimizer steps.
        learning_rate: Adam learning rate (default 1e-4).

    Returns:
        The trained Flax parameters.
    """
    net = GraphNetwork(mlp_features=(128, 128), latent_size=128)

    # Any example works for initialization: parameter shapes do not depend on
    # the number of nodes or edges.
    params = net.init(jax.random.PRNGKey(42), dataset[0]['input_graph'])

    opt_init, opt_update = optax.adam(learning_rate)
    opt_state = opt_init(params)

    # jit the loss + gradient computation, since it is the main cost. jax.jit
    # targets a single accelerator; use jax.pmap to spread over several.
    compute_loss_fn = jax.jit(
        jax.value_and_grad(functools.partial(compute_loss, net=net), has_aux=True))

    for idx in range(num_train_steps):
        example = dataset[idx % len(dataset)]

        # jax re-jits for every new input shape. Padding to the nearest power
        # of two limits compilation to a small, cached set of shapes.
        graph = pad_graph_to_nearest_power_of_two(example['input_graph'])

        # Append a dummy label for the padding graph added above, so labels
        # line up one-to-one with the graphs in the padded batch.
        label = jnp.concatenate([example['target'], jnp.array([0])])

        (loss, acc), grad = compute_loss_fn(params, graph, label)
        updates, opt_state = opt_update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)

        if idx % 50 == 0:
            print(f'step: {idx}, loss: {loss}, acc: {acc}')

    print('Training finished')

    return params

In [28]:
# params = train(train_mutag_ds, num_train_steps=1000)

In [29]:
def evaluate(dataset: List[Dict[str, Any]],
             params) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluates `params` on `dataset`, one graph at a time.

    Args:
        dataset: list of dicts with 'input_graph' and 'target' entries.
        params: Flax parameters, e.g. as returned by `train`.

    Returns:
        (loss, accuracy), each averaged over all graphs in `dataset`.

    Raises:
        ValueError: if `dataset` is empty.
    """
    if len(dataset) == 0:
        raise ValueError('evaluate() needs a non-empty dataset.')

    # Use the same architecture (and tuple-valued mlp_features) as `train` so
    # `params` fits the module.
    net = GraphNetwork(mlp_features=(128, 128), latent_size=128)
    compute_loss_fn = jax.jit(functools.partial(compute_loss, net=net))

    accumulated_loss = 0
    accumulated_accuracy = 0
    for idx, example in enumerate(dataset):
        graph = pad_graph_to_nearest_power_of_two(example['input_graph'])
        # Dummy label for the padding graph added by the padding step.
        label = jnp.concatenate([example['target'], jnp.array([0])])
        loss, acc = compute_loss_fn(params, graph, label)
        accumulated_accuracy += acc
        accumulated_loss += loss
        if idx % 50 == 0:
            print(f'Evaluated {idx + 1} graphs')

    print('Completed evaluation.')
    # Average over every graph. Dividing by the last index (len - 1) overstated
    # both metrics and divided by zero for a single-graph dataset.
    num_graphs = len(dataset)
    loss = accumulated_loss / num_graphs
    accuracy = accumulated_accuracy / num_graphs
    print(f'Eval loss: {loss}, accuracy {accuracy}')
    return loss, accuracy

In [30]:
# `params` comes from the (commented-out) training cell above. Train here if it
# is missing, so this cell works after Restart & Run All.
if 'params' not in globals():
    params = train(train_mutag_ds, num_train_steps=1000)
evaluate(test_mutag_ds, params)

In [31]:
# dataset = train_mutag_ds

# ### AUC etc. 

# net = GraphNetwork(mlp_features=[128, 128], latent_size=128)
# # Get a candidate graph and label to initialize the network.
# graph = dataset[0]['input_graph']
# accumulated_loss = 0
# accumulated_accuracy = 0
# compute_loss_fn = jax.jit(functools.partial(compute_loss, net=net))
# preds = []

# for idx in range(len(dataset)):
#     graph = dataset[idx]['input_graph']
#     # label = dataset[idx]['target']
#     graph = pad_graph_to_nearest_power_of_two(graph)
#     # label = jnp.concatenate([label, jnp.array([0])])
#     pred_graph = net.apply(params, graph)
#     preds.append(jax.nn.log_softmax(pred_graph.globals).flatten()[0])
    
# true_y = jnp.array([d['target'] for d in dataset]).flatten()
# preds = jnp.exp(jnp.array(preds))
# plot_and_print_auc(true_y, preds)