# JAX on IPU: GNN Tutorial
In this tutorial we use JAX on Graphcore IPUs to build a simple GNN for a small node classification task. For educational purposes we rely on plain JAX without using any higher-level libraries such as [Flax](https://github.com/google/flax), [Haiku](https://github.com/deepmind/dm-haiku), or [Jraph](https://github.com/deepmind/jraph).

First install and import some dependencies:

In [None]:
%pip uninstall -q -y jax jaxlib
%pip install -q https://github.com/graphcore-research/jax-experimental/releases/latest/download/jaxlib-0.3.15-cp38-none-manylinux2014_x86_64.whl
%pip install -q https://github.com/graphcore-research/jax-experimental/releases/latest/download/jax-0.3.16-py3-none-any.whl
%pip install -q matplotlib
%pip install -q networkx

In [None]:
import numpy as np
import jax
from jax import numpy as jnp
from jax.config import config
import networkx as nx
import matplotlib.pyplot as plt
from functools import partial
import time

We set `jax_platforms = "cpu,ipu"` so that the CPU is the default platform, which we use to initialise the parameters and the dataset. The IPU remains available when it is requested explicitly.

In [None]:
config.FLAGS.jax_platforms = "cpu,ipu"
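Here is a small, hedged illustration of what the platform setting means in practice. Arrays created without an explicit device are placed on the default device, which is the first device of the first platform listed in `jax_platforms`. `jax.device_put` can move them to a device of another platform. The sketch substitutes `"cpu"` for `"ipu"` (an assumption, so it runs on machines without IPUs), and `SKETCH_DEVICE` is a hypothetical name that keeps it from overwriting the notebook's `DEVICE`.

```python
import jax
import jax.numpy as jnp

# Assumption: "cpu" stands in for "ipu" so the sketch also runs without IPU hardware.
SKETCH_DEVICE = "cpu"

# Created without an explicit device, so it is placed on the default device:
# the first device of the default (first-listed) platform.
x = jnp.arange(6.0).reshape(2, 3)
default_device = jax.devices()[0]

# Explicitly transfer the array to the first device of the chosen platform.
target_device = jax.devices(SKETCH_DEVICE)[0]
x_on_target = jax.device_put(x, target_device)

print(default_device.platform, target_device.platform, x_on_target.shape)
```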

We can switch between devices to train the model by setting the `DEVICE` variable:

In [None]:
DEVICE = "ipu"

In [None]:
devices = jax.devices(DEVICE)
print(devices)

## Define the graph
For this notebook we use the [Zachary's karate club](https://en.wikipedia.org/wiki/Zachary%27s_karate_club) graph, a well-known example of node classification on a small social graph. The 34 nodes of the graph represent the members of a karate club, and the edges represent social interactions between club members. A conflict between the members represented by nodes 0 and 33 led to the club splitting in two. The task is to predict, for every member, which of the two new clubs they will join.

We define the graph directly in this notebook as a list of edges, together with the ground-truth club of every node:

In [None]:
edges = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), 
         (3, 2), (4, 0), (5, 0), (6, 0), (6, 4), 
         (6, 5), (7, 0), (7, 1), (7, 2), (7, 3), 
         (8, 0), (8, 2), (9, 2), (10, 0), (10, 4), 
         (10, 5), (11, 0), (12, 0), (12, 3), (13, 0), 
         (13, 1), (13, 2), (13, 3), (16, 5), (16, 6), 
         (17, 0), (17, 1), (19, 0), (19, 1), (21, 0), 
         (21, 1), (25, 23), (25, 24), (27, 2), (27, 23), 
         (27, 24), (28, 2), (29, 23), (29, 26), (30, 1), 
         (30, 8), (31, 0), (31, 24), (31, 25), (31, 28), 
         (32, 2), (32, 8), (32, 14), (32, 15), (32, 18), 
         (32, 20), (32, 22), (32, 23), (32, 29), (32, 30), 
         (32, 31), (33, 8), (33, 9), (33, 13), (33, 14), 
         (33, 15), (33, 18), (33, 19), (33, 20), (33, 22), 
         (33, 23), (33, 26), (33, 27), (33, 28), (33, 29), 
         (33, 30), (33, 31), (33, 32)]

node_labels = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
num_nodes = len(node_labels)

#### Visualisation
Next we create a networkx graph, which we use only for visualisation. The model itself works on a `jax.numpy` array of edges, which is built further below.

In [None]:
g = nx.Graph()
g.add_nodes_from(range(num_nodes))
g.add_edges_from(edges)

Now we can draw the graph of the club members and their social interactions. Members 0 and 33 are treated as the only nodes with a known label, namely their new club. All other nodes are drawn in grey.

In [None]:
c_0 = np.array([0.1, 0.5, 1.0])
c_1 = np.array([1.0, 0.5, 0.1])
c_default = np.array([.5, .5, .5])
c_error = np.array([1.0, 0.1, 0.1])

color_map = [c_default for _ in range(num_nodes)]
color_map[0] = c_0
color_map[-1] = c_1

NODE_SIZE = 200

In [None]:
pos = nx.kamada_kawai_layout(g)
fig, ax = plt.subplots(1, 1, figsize=[6, 7])
nx.draw(g, pos, node_size=NODE_SIZE, node_color=color_map, with_labels=True, font_color="w", font_size=10, ax=ax)

For modelling purposes we add inverse edges as well as self-loops to the graph:

In [None]:
all_edges = edges + [(edge[1], edge[0]) for edge in edges] + [(i, i) for i in range(num_nodes)]
graph = jnp.array(all_edges)
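The edge list above stores each undirected interaction only once, as a `(source, target)` pair. Adding the reversed pairs lets messages flow in both directions, and the self-loops make every node also receive its own embedding during aggregation. For the karate club graph, `graph` therefore has 2 × 78 + 34 = 190 rows. The same construction on a hypothetical three-node path graph makes the effect easy to inspect:

```python
import numpy as np

# Hypothetical toy graph: a path 0 - 1 - 2, stored as (source, target) pairs.
toy_edges = [(1, 0), (2, 1)]
toy_num_nodes = 3

# Same construction as above: original edges, their reverses, and one self-loop per node.
toy_all_edges = (
    toy_edges
    + [(edge[1], edge[0]) for edge in toy_edges]
    + [(i, i) for i in range(toy_num_nodes)]
)
toy_graph = np.array(toy_all_edges)

print(toy_graph.shape)  # (7, 2): 2 edges + 2 reverses + 3 self-loops
print(toy_graph.tolist())
```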

## Model definition
We solve this task with a simple Graph Convolutional Network (GCN) ([Kipf and Welling, 2016](https://arxiv.org/abs/1609.02907)).
First, we define functions to explicitly initialise the parameters and apply a GCN layer:

In [None]:
def gcn_layer(input_size, output_size, nonlinearity=None, use_bias=False):
    def parameter_init(key, scale=0.02):
        # Weight matrix of shape (output_size, input_size), plus an optional zero-initialised bias.
        if use_bias:
            return (scale * jax.random.normal(key, (output_size, input_size)), jnp.zeros(output_size))
        return (scale * jax.random.normal(key, (output_size, input_size)), )

    def apply(params, node_embeddings, graph):
        # Linear transform of every node: (num_nodes, input_size) -> (num_nodes, output_size).
        node_embeddings = jnp.dot(params[0], node_embeddings.T).T
        if use_bias:
            node_embeddings = node_embeddings + params[1]
        if nonlinearity:
            node_embeddings = nonlinearity(node_embeddings)
        # Message passing: every edge (source, target) sends the source node's embedding...
        messages = node_embeddings[graph[:, 0]]
        # ...and every target node sums all the messages it receives.
        node_embeddings = jax.ops.segment_sum(messages, graph[:, 1], num_nodes)
        return node_embeddings, graph

    return parameter_init, apply
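The aggregation in `apply` is a scatter-add. In matrix form, with $A_{ij} = 1$ for every row `(i, j)` of `graph`, one layer computes $H' = A^\top \sigma(H W^\top + b)$, where the bias and $\sigma$ apply only if enabled. Because `graph` contains both edge directions and self-loops, $A$ is symmetric with ones on the diagonal. The propagation rule of Kipf and Welling is $H' = \sigma(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H W)$. This layer differs from it in two ways: it has no degree normalisation, and it applies the nonlinearity before aggregation. The sketch below uses a hypothetical three-node toy graph with random embeddings. It checks that `jax.ops.segment_sum` gives the same result as the dense product:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical toy graph: path 0 - 1 - 2 with reverse edges and self-loops.
toy_graph = jnp.array([(1, 0), (2, 1), (0, 1), (1, 2), (0, 0), (1, 1), (2, 2)])
toy_num_nodes = 3

# Random node embeddings (assumption: 4 features per node).
h = jax.random.normal(jax.random.PRNGKey(0), (toy_num_nodes, 4))

# Sparse aggregation as in `gcn_layer.apply`: send h[source] to target and sum.
messages = h[toy_graph[:, 0]]
aggregated_sparse = jax.ops.segment_sum(messages, toy_graph[:, 1], toy_num_nodes)

# Dense equivalent: A[source, target] += 1 for every edge, then aggregate with A^T.
adjacency = jnp.zeros((toy_num_nodes, toy_num_nodes)).at[toy_graph[:, 0], toy_graph[:, 1]].add(1.0)
aggregated_dense = adjacency.T @ h

print(np.allclose(aggregated_sparse, aggregated_dense, atol=1e-6))
```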

Now, we can define a multi-layer GCN in a similar way:

In [None]:
def gcn(layer_sizes):
    layers = []
    parameter_inits = []
    for n, (input_size, output_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        layer_parameter_init, layer_apply = gcn_layer(
            input_size,
            output_size,
            nonlinearity=jax.nn.relu if n < len(layer_sizes) - 2 else None,
            use_bias=False
        )
        parameter_inits.append(layer_parameter_init)
        layers.append(layer_apply)
    
    def parameter_init(key, scale=0.02):
        keys = jax.random.split(key, len(layer_sizes))
        params = []
        for layer_parameter_init, layer_key in zip(parameter_inits, keys):
            params.append(layer_parameter_init(layer_key, scale))
        return params
    
    def apply(params, node_embeddings, graph):
        for layer, layer_params in zip(layers, params):
            node_embeddings, graph = layer(layer_params, node_embeddings, graph)
        return node_embeddings
    
    return parameter_init, apply

In [None]:
layer_size = [num_nodes] + [64, 64] + [2]
gcn_init, gcn_predict = gcn(layer_size)
params = gcn_init(jax.random.PRNGKey(1))
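As a sanity check on the architecture, each layer holds one weight matrix of shape `(output_size, input_size)` and no bias (`use_bias=False`), so the parameter count follows directly from `layer_size`. Because the inputs are one-hot vectors (`jnp.eye(num_nodes)`), the first weight matrix effectively stores a learnable 64-dimensional embedding per node. A minimal sketch of the arithmetic:

```python
# Same sizes as `layer_size` above: 34 one-hot inputs, two hidden layers of 64, two classes.
layer_sizes = [34, 64, 64, 2]

# One (output_size, input_size) weight matrix per layer; no biases since use_bias=False.
weight_shapes = [(out_size, in_size) for in_size, out_size in zip(layer_sizes[:-1], layer_sizes[1:])]
num_params = sum(out_size * in_size for out_size, in_size in weight_shapes)

print(weight_shapes)  # [(64, 34), (64, 64), (2, 64)]
print(num_params)     # 6400
```

In the notebook, `[p[0].shape for p in params]` should list the same shapes.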

Finally, we define three functions: a prediction function that returns the softmax of the final, 2-dimensional node embeddings, a loss function, and a training step that uses plain SGD to update the model parameters. Note that the loss function only uses the predictions for the two nodes whose labels we know at this point: node 0 and node 33.

In [None]:
def prediction(params, graph):
    initial_node_embeddings = jnp.eye(num_nodes)
    return jax.nn.softmax(gcn_predict(params, initial_node_embeddings, graph))

def loss_fun(params, graph):
    log_prob = jnp.log(prediction(params, graph))
    return -(log_prob[0, 0] + log_prob[-1, 1]) / 2

# Explicit jitting for the IPU backend.
# Donate `params` to keep the parameters in IPU SRAM.
@partial(jax.jit, backend=DEVICE, donate_argnums=(0,))
def training_step(params, graph, learning_rate):
    grads = jax.grad(loss_fun)(params, graph)
    # Plain SGD update. tree_map preserves the list-of-tuples structure of `params`,
    # so later calls match the first compilation instead of triggering a recompile.
    return jax.tree_util.tree_map(lambda p, dp: p - learning_rate * dp, params, grads)
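`loss_fun` takes the logarithm of a softmax, which can underflow to `log(0)` for very confident predictions. A hedged alternative applies `jax.nn.log_softmax` to the raw GCN outputs. The version below is written as a masked mean over labelled nodes, so further labels could be added later. `masked_cross_entropy`, `train_mask` and the random logits are illustrative assumptions, not part of the notebook. With only nodes 0 and 33 in the mask, the function reduces to the same two-term loss:

```python
import jax
import jax.numpy as jnp


def masked_cross_entropy(logits, labels, mask):
    """Mean negative log-likelihood over the nodes where `mask` is True.

    Hypothetical helper: `logits` are raw (pre-softmax) outputs of shape
    (num_nodes, num_classes), `labels` are integer classes, `mask` is boolean.
    """
    log_prob = jax.nn.log_softmax(logits)  # stable alternative to log(softmax(...))
    # Log-probability assigned to the true class of every node.
    true_class_log_prob = jnp.take_along_axis(log_prob, labels[:, None], axis=1)[:, 0]
    return -jnp.sum(jnp.where(mask, true_class_log_prob, 0.0)) / jnp.sum(mask)


# Placeholder data: random logits for 34 nodes and 2 classes.
n_nodes = 34
toy_logits = jax.random.normal(jax.random.PRNGKey(0), (n_nodes, 2))
# Only the labels of the masked nodes matter: node 0 -> class 0, node 33 -> class 1.
toy_labels = jnp.zeros(n_nodes, dtype=jnp.int32).at[n_nodes - 1].set(1)
train_mask = jnp.zeros(n_nodes, dtype=bool).at[jnp.array([0, n_nodes - 1])].set(True)

# The two-term loss from `loss_fun`, evaluated on the same logits for comparison.
log_prob_ref = jnp.log(jax.nn.softmax(toy_logits))
reference = -(log_prob_ref[0, 0] + log_prob_ref[-1, 1]) / 2

print(masked_cross_entropy(toy_logits, toy_labels, train_mask), reference)
```

In the notebook, the logits would be `gcn_predict(params, jnp.eye(num_nodes), graph)`, that is, the model output before the softmax in `prediction`.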

## Training
We train the model on IPU for 20 steps. The first step includes compilation and therefore takes longer.

Every fifth step we copy the parameters back to the host to run a validation step there and visualise the node probabilities. The accuracy (the proportion of correctly classified nodes) should reach a steady state of about 0.97 (33 of 34 nodes) quite early on. The validation loss keeps dropping after that, which shows a better separation of the two classes.
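Only the first step should pay the compilation cost. This holds only because `jax.jit` caches compiled code per input structure (pytree layout, shapes and dtypes). Suppose an update function returns a different container type than it receives, for example lists where the initialiser produced tuples. The next call then no longer matches the cache and is compiled again. `jax.tree_util.tree_map` avoids this because it preserves the structure. Below is a minimal CPU sketch that uses a toy quadratic loss in place of `loss_fun` (an assumption):

```python
import jax
import jax.numpy as jnp

# Toy parameters shaped like the notebook's: a list of per-layer tuples.
toy_params = [(jnp.ones((2, 3)),), (jnp.ones((1, 2)),)]


def toy_loss(params):
    # Assumption: a simple quadratic loss standing in for `loss_fun`.
    return sum(jnp.sum(p ** 2) for layer in params for p in layer)


@jax.jit
def sgd_step(params, learning_rate):
    grads = jax.grad(toy_loss)(params)
    # tree_map keeps the list-of-tuples structure of `params` intact.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)


new_params = sgd_step(toy_params, 0.1)

# Same structure in and out, so later calls reuse the compiled function.
print(jax.tree_util.tree_structure(new_params) == jax.tree_util.tree_structure(toy_params))

# Rebuilding with nested lists, as a nested list comprehension would, changes the structure.
as_lists = [[p for p in layer] for layer in toy_params]
print(jax.tree_util.tree_structure(as_lists) == jax.tree_util.tree_structure(toy_params))
```

The gradient of the toy loss is `2 * p`, so each parameter moves from 1.0 to 0.8 after one step.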

In [None]:
def visualise(prob, ax):
    color_map_predicted = [p[0] * c_0 + p[1] * c_1 for p in prob]
    color_map_predicted[0] = c_0
    color_map_predicted[-1] = c_1
    
    nx.draw(g, pos, node_size=NODE_SIZE, node_color=color_map_predicted, with_labels=True, font_color="w", font_size=10, ax=ax)
    ax.set_title("Soft Predictions")

In [None]:
learning_rate = 0.02
num_steps = 20
validation_step = 5

fig, ax = plt.subplots(1, num_steps//validation_step, figsize=[20, 6])
fig.suptitle("Node classification", fontsize=20)
fig.tight_layout()

for step in range(1, num_steps + 1):
    t0 = time.time()
    params = training_step(params, graph, learning_rate)
    # JAX dispatches work asynchronously: wait for the result so the timing covers the computation itself.
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), params)
    duration_ms = (time.time() - t0) * 1000
    if step % validation_step == 0:
        params_host = jax.device_get(params)
        probs = prediction(params_host, graph)
        acc = jnp.mean(jnp.argmax(probs, axis=1) == node_labels)
        log_probs = jnp.log(probs)
        valid_loss = -np.mean([log_probs[n, node_labels[n]] for n in range(num_nodes)])
        visualise(np.array(probs), ax[step//validation_step - 1])
        ax[step//validation_step - 1].set_title(f"Step {step}")
        print(f"Step {step}, duration = {duration_ms:.2f} ms, Validation Loss = {valid_loss:.3f}, Accuracy = {acc:.3f}")
    else:
        print(f"Step {step}, duration = {duration_ms:.2f} ms")
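The validation loss in the loop is computed with a Python list comprehension over all 34 nodes. A vectorised sketch of the same metrics uses integer array indexing to pick each node's true-class log-probability in one operation. The random probabilities and labels below are placeholders, not the notebook's results. The `toy_` prefixes are hypothetical names that avoid overwriting `probs` and `node_labels`:

```python
import numpy as np
import jax
import jax.numpy as jnp

n_nodes = 34
# Placeholder inputs: random class probabilities and random 0/1 labels.
toy_probs = jax.nn.softmax(jax.random.normal(jax.random.PRNGKey(0), (n_nodes, 2)))
toy_labels = np.array(jax.random.bernoulli(jax.random.PRNGKey(1), 0.5, (n_nodes,)), dtype=np.int32)

log_probs_toy = jnp.log(toy_probs)

# Loop version, as in the training cell.
valid_loss_loop = -np.mean([log_probs_toy[n, toy_labels[n]] for n in range(n_nodes)])

# Vectorised version: row n, column toy_labels[n], for all n at once.
valid_loss_vec = -jnp.mean(log_probs_toy[jnp.arange(n_nodes), toy_labels])
acc_toy = jnp.mean(jnp.argmax(toy_probs, axis=1) == toy_labels)

print(valid_loss_loop, valid_loss_vec, acc_toy)
```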


We now plot the final results. We should find that node 8 is misclassified. Many methods make this error on this dataset, including the analysis in [Zachary's original 1977 publication](https://www.jstor.org/stable/3629752).

In [None]:
params_host = jax.device_get(params)
probs = np.array(prediction(params_host, graph))

In [None]:
predicted_class = np.argmax(probs, axis=1)
errors = (node_labels != predicted_class)

color_map_predicted = [(1-p) * 0.75 * c_0 + p * 0.75 * c_1 for p in predicted_class]
color_map_predicted[0] = c_0
color_map_predicted[-1] = c_1

color_map_labels = [l * 0.75 * c_1 + (1 - l) * 0.75 * c_0 for l in node_labels]
color_map_labels[0] = c_0
color_map_labels[-1] = c_1

color_map_err = [l * c_error + (1 - l) * c_default for l in errors]
color_map_err[0] = c_0
color_map_err[-1] = c_1

In [None]:
fig, ax = plt.subplots(1, 3, figsize=[18, 7])
nx.draw(g, pos, node_size=NODE_SIZE, node_color=color_map_predicted, with_labels=True, font_color="w", font_size=10, ax=ax[0])
nx.draw(g, pos, node_size=NODE_SIZE, node_color=color_map_labels, with_labels=True, font_color="w", font_size=10, ax=ax[1])
nx.draw(g, pos, node_size=NODE_SIZE, node_color=color_map_err, with_labels=True, font_color="w", font_size=10, ax=ax[2])

ax[0].set_title("Predicted Labels")
ax[1].set_title("Ground Truth")
ax[2].set_title("Difference")
plt.show()