Experiment with jraph and karateclub (Zachary's Karate Club). The example was originally taken from: https://github.com/deepmind/jraph/blob/master/jraph/examples/zacharys_karate_club.py

In [1]:
import logging

import jax
import jax.numpy as jnp
import jraph
import optax

ModuleNotFoundError: No module named 'jax'

In [None]:
def optimize_club(
    num_steps: int, *, learning_rate: float = 1e-2, seed: int = 42):
  """Solves the karate club problem by optimizing the assignments of students.

  Trains `network_definition` on Zachary's karate club graph. Only the two
  a-priori known assignments are used as supervision: node 0 (Mr. Hi) and
  node 33 (John A). Accuracy is logged against the full ground-truth labels.

  Args:
    num_steps: number of Adam update steps to run.
    learning_rate: Adam learning rate (default matches the original 1e-2).
    seed: seed for the PRNG key used to initialise the network parameters
      (default matches the original 42).

  Returns:
    The trained network parameters, so callers can inspect or reuse them.
  """
  # NOTE(review): `hk` (presumably `import haiku as hk`) is not imported in the
  # visible import cell; make sure it is imported before this cell runs.
  network = hk.without_apply_rng(hk.transform(network_definition))
  zacharys_karate_club = get_zacharys_karate_club()
  labels = get_ground_truth_assignments_for_zacharys_karate_club()
  params = network.init(jax.random.PRNGKey(seed), zacharys_karate_club)

  @jax.jit
  def prediction_loss(params):
    """Negative log-likelihood of the two a-priori known assignments."""
    decoded_nodes = network.apply(params, zacharys_karate_club)
    # We interpret the decoded nodes as a pair of logits for each node.
    log_prob = jax.nn.log_softmax(decoded_nodes)
    # The only two assignments we know a-priori are those of Mr. Hi (Node 0)
    # and John A (Node 33); every other node is left unsupervised.
    return -(log_prob[0, 0] + log_prob[33, 1])

  opt_init, opt_update = optax.adam(learning_rate)
  opt_state = opt_init(params)

  @jax.jit
  def update(params, opt_state):
    """Applies one Adam step; returns the new params and optimizer state."""
    grads = jax.grad(prediction_loss)(params)
    updates, opt_state = opt_update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

  @jax.jit
  def accuracy(params):
    """Fraction of nodes whose argmax prediction matches the true label."""
    decoded_nodes = network.apply(params, zacharys_karate_club)
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == labels)

  # NOTE: logging.info output is only shown when the root logger level is INFO
  # or lower, e.g. after `logging.basicConfig(level=logging.INFO)`.
  for step in range(num_steps):
    # Accuracy is logged *before* each update, so step 0 reports the
    # untrained network.
    logging.info("step %r accuracy %r", step, accuracy(params).item())
    params, opt_state = update(params, opt_state)
  # The loop never evaluates the parameters produced by the last update, so
  # report the final accuracy explicitly.
  logging.info("step %r accuracy %r", num_steps, accuracy(params).item())
  return params
