<a href="https://colab.research.google.com/github/dsteiner93/rubiks-ml/blob/main/toy_gnns.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook we define a toy graph problem and solve it using implementations of graph learning methods from scratch as well as the framework [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/).

In [None]:
import collections
import dataclasses
import jax
import jax.numpy as jnp
import optax
import numpy as np
import random
import torch
import torch.optim as optim
from torch import nn

### Define a toy graph problem

In [None]:
def set_seeds(seed: int) -> None:
  random.seed(seed)
  np.random.seed(seed + 1)
  torch.manual_seed(seed + 2)

In [None]:
@dataclasses.dataclass(frozen=True, kw_only=True)
class Node:
  node_id: int
  node_feature_a: int  # Must be in [0, node_feature_a_size]
  node_feature_b: float  # Must be in [-1.0, 1.0]
  distractor_feature: int  # Must be in [-1, 1]

  def __post_init__(self):
    assert self.node_feature_a >= 0
    assert self.node_feature_b >= -1.0 and self.node_feature_b <= 1.0
    assert self.distractor_feature >= -1 and self.distractor_feature <= 1

  def get_feature_vector(self, *, node_feature_a_size: int) -> np.ndarray:
    assert self.node_feature_a <= node_feature_a_size

    feature_vector = np.zeros(node_feature_a_size + 2, dtype=np.float32)
    feature_vector[self.node_feature_a] = 1.0
    feature_vector[-2] = self.node_feature_b
    feature_vector[-1] = self.distractor_feature
    return feature_vector

  def get_feature_vector_jnp(self, *, node_feature_a_size: int) -> jnp.ndarray:
    return jnp.array(self.get_feature_vector(node_feature_a_size=node_feature_a_size))

@dataclasses.dataclass(frozen=True, kw_only=True)
class UndirectedEdge:
  # By convention, we'll always put the lower id as the src and higher id as dst.
  src: int
  dst: int

class Graph:

  def __init__(self):
    self.id_to_node: dict[int, Node] = {}
    self.edges: set[UndirectedEdge] = set()
    self.adjacency_list: dict[int, set[int]] = collections.defaultdict(set)

  def add_node(self, node_to_add: Node) -> None:
    if node_to_add.node_id in self.id_to_node:
      raise ValueError(f"Node with id {node_to_add.node_id} is already in the graph.")
    self.id_to_node[node_to_add.node_id] = node_to_add

  def add_edge(self, src: int, dst: int) -> None:
    if src not in self.id_to_node:
      raise ValueError(f"Source {src} not in node ids {self.id_to_node.keys()}")
    if dst not in self.id_to_node:
      raise ValueError(f"Destination {dst} not in node ids {self.id_to_node.keys()}")
    if dst < src:
      tmp = src
      src = dst
      dst = tmp
    edge_to_add = UndirectedEdge(src=src, dst=dst)
    self.edges.add(edge_to_add)
    self.adjacency_list[src].add(dst)
    # Because it is undirected, the edge needs to go both ways.
    self.adjacency_list[dst].add(src)

  def get_adjacency_matrix(self, add_self_loops: bool = False) -> np.ndarray:
    if sorted(list(self.id_to_node.keys())) != list(range(len(self.id_to_node))):
      raise ValueError(f"Node ids {sorted(list(self.id_to_node.keys()))} must be range(len(id_to_node)).")
    adjacency_matrix = np.zeros((len(self.id_to_node), len(self.id_to_node)), dtype=np.int32)
    for node_id, neighbors in self.adjacency_list.items():
      for neighbor in neighbors:
        adjacency_matrix[node_id][neighbor] = 1
      if add_self_loops:
        adjacency_matrix[node_id][node_id] = 1
    return adjacency_matrix

  def get_adjacency_matrix_jnp(self, add_self_loops: bool = False) -> jnp.ndarray:
    return jnp.array(self.get_adjacency_matrix(add_self_loops=add_self_loops))

  def get_edge_connections_coo(self) -> torch.LongTensor:
    # Used by PyTorch Geomtric.
    # COO described here: https://docs.pytorch.org/docs/stable/sparse.html#sparse-coo-docs
    # Since the graph is undirected, we need to specify the edges going in both directions.
    a = torch.LongTensor([[e.src, e.dst] for e in self.edges]).t().reshape((2, -1))
    b = torch.LongTensor([[e.dst, e.src] for e in self.edges]).t().reshape((2, -1))
    return torch.cat([a, b], dim=1)

In [None]:
import unittest

class NodeTest(unittest.TestCase):

  def test_invalid_node_rejected(self):
    with self.assertRaises(AssertionError):
      node = Node(node_id=0, node_feature_a=1, node_feature_b=-2.0, distractor_feature=1)

  def test_get_feature_vector(self):
    node = Node(node_id=0, node_feature_a=5, node_feature_b=.2, distractor_feature=1)
    expected = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, .2, 1], dtype=np.float32)
    np.testing.assert_array_equal(expected, node.get_feature_vector(node_feature_a_size=10))
    np.testing.assert_array_equal(expected, node.get_feature_vector_jnp(node_feature_a_size=10))

class GraphTest(unittest.TestCase):

  def test_graph_representations(self):
    graph = Graph()
    node0 = Node(node_id=0, node_feature_a=5, node_feature_b=.0, distractor_feature=1)
    node1 = Node(node_id=1, node_feature_a=2, node_feature_b=-.5, distractor_feature=0)
    node2 = Node(node_id=2, node_feature_a=6, node_feature_b=-.123, distractor_feature=1)
    node3 = Node(node_id=3, node_feature_a=9, node_feature_b=.55, distractor_feature=-1)
    graph.add_node(node0)
    graph.add_node(node1)
    graph.add_node(node2)
    graph.add_node(node3)
    graph.add_edge(0, 1)
    graph.add_edge(0, 2)
    graph.add_edge(2, 1)
    graph.add_edge(3, 0)
    expected_adjacency_matrix_no_self_loops = np.array([
        [0, 1, 1, 1],
        [1, 0, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
    ])
    expected_adjacency_matrix_with_self_loops = np.array([
        [1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 1, 1, 0],
        [1, 0, 0, 1],
    ])
    np.testing.assert_array_equal(
        expected_adjacency_matrix_no_self_loops, graph.get_adjacency_matrix(add_self_loops=False))
    np.testing.assert_array_equal(
        expected_adjacency_matrix_with_self_loops, graph.get_adjacency_matrix(add_self_loops=True))
    np.testing.assert_array_equal(
        expected_adjacency_matrix_no_self_loops, graph.get_adjacency_matrix_jnp(add_self_loops=False))
    np.testing.assert_array_equal(
        expected_adjacency_matrix_with_self_loops, graph.get_adjacency_matrix_jnp(add_self_loops=True))
    coo = graph.get_edge_connections_coo()
    self.assertEqual((2, 8), coo.shape)


unittest.main(NodeTest(), argv=[''], verbosity=2, exit=False)
unittest.main(GraphTest(), argv=[''], verbosity=2, exit=False)

test_get_feature_vector (__main__.NodeTest.test_get_feature_vector) ... ok
test_invalid_node_rejected (__main__.NodeTest.test_invalid_node_rejected) ... ok

----------------------------------------------------------------------
Ran 2 tests in 1.696s

OK
test_graph_representations (__main__.GraphTest.test_graph_representations) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.078s

OK


<unittest.main.TestProgram at 0x7d97d03fda10>

In [None]:
# We will construct random graphs of variable size for a toy node classification problem.
# Each node will be assigned to one of 8 possible classes.

# Each node has 3 features with the following rules:
# node_feature_a is an int that can take values from [0, 9]
# node_feature_b is a float that can take values from [-1, 1]
# distractor_feature is an int that can take values in {-1, 0, 1}, but it is never relevant to making a correct prediction.

# The output prediction is one of 8 classes {0, 1, 2, 3, 4, 5, 6, 7}.
# If a node has exactly 3 neighbors (not including self), its class label should be 0.
# Elif all of a node's neighbors within a 2-hop neighborhood (including self) have node_feature_b <= 0, it should be 1.
# Elif >= 75% of a node's neighbors (including self) have an even value for node_feature_a, it should be 2.
# Elif >= 75% of a node's neighbors (including self) have an odd value for node_feature_a, it should be 3.
# Elif >= 50% of a node's neighbors (not including self) have node_feature_b > 0, it should be 4.
# Elif node_feature_a is even and node_feature_b is <= 0, it should be 5.
# Elif node_feature_a is odd and node_feature_b is > 0, it should be 6.
# Everything else should be 7.

set_seeds(1)

node_feature_a_size = 10
num_classes = 8

def generate_random_graph(
    mean_nodes: int = 8,
) -> Graph:
  num_nodes = max(1, int(np.random.normal(loc=mean_nodes, scale=2)))
  graph = Graph()
  for node_id in range(num_nodes):
    node_feature_b = random.uniform(-1, 1)
    if node_feature_b >= 0 and random.uniform(0, 1) <= .25:
      # Make it a bit more likely for node_feature_b to be negative.
      node_feature_b = -1.0
    graph.add_node(
        Node(node_id=node_id,
             node_feature_a=random.randint(0, node_feature_a_size-1),
             node_feature_b=node_feature_b,
             distractor_feature=random.randint(-1, 1))
    )
  mean_edges = 4 * mean_nodes
  for _ in range(mean_edges):
    src = random.randint(0, num_nodes-1)
    dst = random.randint(0, num_nodes-1)
    if src == dst:
      continue
    graph.add_edge(src, dst)

  return graph

def get_node_labels_for_graph(graph: Graph) -> list[int]:
  labels = []
  num_nodes = len(graph.id_to_node)
  for node_id in range(num_nodes):
    label = -1

    all_1_hop_neighbors_not_including_self = [e for e in graph.adjacency_list[node_id] if e != node_id]
    num_neighbors_of_node_not_including_self = len(all_1_hop_neighbors_not_including_self)

    all_2_hop_have_negative_node_feature_b = graph.id_to_node[node_id].node_feature_b <= 0
    for one_hop_node_id in all_1_hop_neighbors_not_including_self:
      one_hop_node = graph.id_to_node[one_hop_node_id]
      if one_hop_node.node_feature_b > 0:
        all_2_hop_have_negative_node_feature_b = False
        break
      for two_hop_node_id in graph.adjacency_list[one_hop_node_id]:
        two_hop_node = graph.id_to_node[two_hop_node_id]
        if two_hop_node.node_feature_b > 0:
          all_2_hop_have_negative_node_feature_b = False
          break

    all_1_hop_neighbors_including_self = set(graph.adjacency_list[node_id])
    all_1_hop_neighbors_including_self.add(node_id)
    even_value_for_feature_a_neighbors = [e for e in all_1_hop_neighbors_including_self if graph.id_to_node[e].node_feature_a % 2 == 0]
    odd_value_for_feature_a_neighbors = [e for e in all_1_hop_neighbors_including_self if graph.id_to_node[e].node_feature_a % 2 != 0]
    positive_feature_b_neighbors_not_including_self = [e for e in all_1_hop_neighbors_not_including_self if graph.id_to_node[e].node_feature_b > 0]

    if num_neighbors_of_node_not_including_self == 3:
      label = 0
    elif all_2_hop_have_negative_node_feature_b:
      label = 1
    elif float(len(even_value_for_feature_a_neighbors)) / len(all_1_hop_neighbors_including_self) >= .75:
      label = 2
    elif float(len(odd_value_for_feature_a_neighbors)) / len(all_1_hop_neighbors_including_self) >= .75:
      label = 3
    elif len(all_1_hop_neighbors_not_including_self) > 0 and (float(len(positive_feature_b_neighbors_not_including_self)) / len(all_1_hop_neighbors_not_including_self)) >= .5:
      label = 4
    elif graph.id_to_node[node_id].node_feature_a % 2 == 0 and graph.id_to_node[node_id].node_feature_b <= 0:
      label = 5
    elif graph.id_to_node[node_id].node_feature_a % 2 != 0 and graph.id_to_node[node_id].node_feature_b > 0:
      label = 6
    else:
      label = 7

    assert label >= 0 and label <= 7
    labels.append(label)

  assert len(labels) == num_nodes
  return labels

def get_train_and_test_set(*, train_set_size: int, test_set_size: int) -> tuple[list[tuple[Graph, list[int]]], list[tuple[Graph, list[int]]]]:
  # Returns tuple of (train_set, test_set)
  # train_set and test_set are both a list of pairs of Graphs and labels for all nodes in that graph (in order).
  train_set = []
  for _ in range(train_set_size):
    graph = generate_random_graph()
    labels = get_node_labels_for_graph(graph)
    train_set.append((graph, labels))

  test_set = []
  for _ in range(test_set_size):
    graph = generate_random_graph()
    labels = get_node_labels_for_graph(graph)
    test_set.append((graph, labels))

  return train_set, test_set


train_set, test_set = get_train_and_test_set(train_set_size=10000, test_set_size=1000)
example_node = Node(node_id=0, node_feature_a=0, node_feature_b=0.0, distractor_feature=0)

def print_statistics(set_to_check: list[tuple[Graph, list[int]]]) -> None:
  all_graph_sizes = []
  all_labels = collections.defaultdict(int)
  for graph, labels in set_to_check:
    all_graph_sizes.append(len(graph.id_to_node))
    for label in labels:
      all_labels[label] += 1

  all_labels_count = sum(all_labels.values())
  print(f"Average graph size: {np.mean(all_graph_sizes)}")
  print(f"Total node count: {all_labels_count}")
  for i in range(num_classes):
    print(f"Label {i}. Count: {all_labels[i]}. Percentage: {all_labels[i] / all_labels_count:.2f}.")

# Compute some statistics on the train and test sets.
print("Stats for train set:")
print_statistics(train_set)
print()
print("Stats for test set:")
print_statistics(test_set)

Stats for train set:
Average graph size: 7.459
Total node count: 74590
Label 0. Count: 13973. Percentage: 0.19.
Label 1. Count: 2120. Percentage: 0.03.
Label 2. Count: 7907. Percentage: 0.11.
Label 3. Count: 8231. Percentage: 0.11.
Label 4. Count: 17112. Percentage: 0.23.
Label 5. Count: 7622. Percentage: 0.10.
Label 6. Count: 5056. Percentage: 0.07.
Label 7. Count: 12569. Percentage: 0.17.

Stats for test set:
Average graph size: 7.538
Total node count: 7538
Label 0. Count: 1487. Percentage: 0.20.
Label 1. Count: 166. Percentage: 0.02.
Label 2. Count: 802. Percentage: 0.11.
Label 3. Count: 749. Percentage: 0.10.
Label 4. Count: 1736. Percentage: 0.23.
Label 5. Count: 776. Percentage: 0.10.
Label 6. Count: 531. Percentage: 0.07.
Label 7. Count: 1291. Percentage: 0.17.


### Solve using only node features

In [None]:
def convert_dataset_to_vectors(*, dataset: list[tuple[Graph, list[int]]], batch_size: int | None = None) -> tuple[jnp.ndarray, jnp.ndarray]:
  all_node_vectors = []
  all_labels = []
  for graph, labels in dataset:
    for node_id in range(len(graph.id_to_node)):
      all_node_vectors.append(graph.id_to_node[node_id].get_feature_vector_jnp(node_feature_a_size=node_feature_a_size))
    targets = jnp.array(labels).reshape(-1)
    one_hot_targets = jnp.eye(num_classes)[targets]
    all_labels.append(one_hot_targets)
  stacked_node_features = jnp.vstack(all_node_vectors)
  stacked_labels = jnp.vstack(all_labels)
  if batch_size is None:
    return stacked_node_features, stacked_labels

  if stacked_node_features.shape[0] % batch_size != 0:
    number_to_add = batch_size - (stacked_node_features.shape[0] % batch_size)
    stacked_node_features = jnp.vstack([stacked_node_features, jnp.zeros((number_to_add, stacked_node_features.shape[1]))])
  if stacked_labels.shape[0] % batch_size != 0:
    number_to_add = batch_size - (stacked_labels.shape[0] % batch_size)
    stacked_labels = jnp.vstack([stacked_labels, jnp.zeros((number_to_add, stacked_labels.shape[1]))])
  batched_node_features = stacked_node_features.reshape((-1, batch_size, stacked_node_features.shape[1]))
  batched_labels = stacked_labels.reshape((-1, batch_size, stacked_labels.shape[1]))
  return batched_node_features, batched_labels


batched_train_set, batched_train_labels = convert_dataset_to_vectors(dataset=train_set, batch_size=64)
vectorized_test_set, vectorized_test_labels = convert_dataset_to_vectors(dataset=test_set, batch_size=None)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batched_train_set_torch = torch.Tensor(np.array(batched_train_set)).to(device)
batched_train_labels_torch = torch.Tensor(np.array(batched_train_labels)).to(device)
vectorized_test_set_torch = torch.Tensor(np.array(vectorized_test_set)).to(device)
vectorized_test_labels_torch = torch.Tensor(np.array(vectorized_test_labels)).to(device)

In [None]:
# Solve it with raw Jax using just node features. Using only node features (no graph structure)
# is not expected to work well since we specifically designed the toy problem to need information
# about the local neighborhood in order to make correct predictions.
set_seeds(2)

def init_mlp(layer_dimensions: list[int], parent_random_key: jax.random.PRNGKey, scale: float=.1) -> list[list[jnp.ndarray]]:
  params = []
  keys = jax.random.split(parent_random_key, num=len(layer_dimensions)-1)
  for index, _ in enumerate(layer_dimensions[:-1]):
    in_width = layer_dimensions[index]
    out_width = layer_dimensions[index+1]
    weight_key, bias_key = jax.random.split(keys[index])
    params.append([
      scale * jax.random.normal(weight_key, shape=(in_width, out_width)),
      scale * jax.random.normal(bias_key, shape=(out_width,)),
    ])
  return params

@jax.jit
def forward(params: list[list[jnp.ndarray]], inputs: jnp.ndarray) -> jnp.ndarray:
  starting_vector = inputs
  for index, layer in enumerate(params):
    starting_vector = jnp.dot(starting_vector, layer[0]) + layer[1]
    if index < len(params)-1:
      # Don't RELU the final values.
      starting_vector = jax.nn.relu(starting_vector)
  return starting_vector

forward_batched = jax.jit(jax.vmap(forward, in_axes=(None, 0)))

@jax.jit
def loss_function_single(params: list[list[jnp.ndarray]], inputs: jnp.ndarray, correct_labels: jnp.ndarray) -> float:
  # inputs is shape (node_feature_vector_len,)
  # correct_labels shape is (num_classes,)
  predicted = forward(params, inputs)  # (num_classes,)
  return optax.softmax_cross_entropy(predicted, correct_labels)

@jax.jit
def loss_function_batched(params: list[list[jnp.ndarray]], inputs: jnp.ndarray, correct_labels: jnp.ndarray) -> float:
  # inputs is now shape (batch_size, node_feature_vector_len)
  # correct_labels is now shape (batch_size, num_classes)
  predicted = forward_batched(params, inputs)  # (batch_size, num_classes)
  return jnp.sum(optax.softmax_cross_entropy(predicted, correct_labels))

def calculate_accuracy_over_test_set_jax(params: list[list[jnp.ndarray]], test_set: jnp.ndarray, test_labels: jnp.ndarray) -> float:
  # Returns the percentage of test examples classified correctly.
  prediction_logits = forward(params, test_set)
  softmaxed = jax.nn.softmax(prediction_logits)
  predictions_argmax = jnp.argmax(softmaxed, axis=1)
  ground_truth_argmax = jnp.argmax(test_labels, axis=1)
  correct_count = jnp.sum(predictions_argmax == ground_truth_argmax)
  return correct_count / test_labels.shape[0]


seed = 12
num_epochs = 50
learning_rate = .001
network_params = init_mlp(
    [example_node.get_feature_vector(node_feature_a_size=node_feature_a_size).shape[0], 128, 128, num_classes],
    jax.random.PRNGKey(seed),
)
for epoch in range(num_epochs):
  print(f"Beginning training for epoch {epoch+1} of {num_epochs}...")
  for i in range(batched_train_set.shape[0]):
    train_batch = batched_train_set[i]  # (batch_size, node_feature_vector_len)
    train_labels = batched_train_labels[i]  # (batch_size, num_classes)
    shuffled_indices = jax.random.permutation(jax.random.PRNGKey(epoch), jnp.arange(train_batch.shape[0]))
    train_batch = train_batch[shuffled_indices]
    train_labels = train_labels[shuffled_indices]
    loss_value, loss_gradient = jax.value_and_grad(loss_function_batched)(network_params, train_batch, train_labels)
    network_params = jax.tree.map(lambda p, g: p - learning_rate*g, network_params, loss_gradient)
  accuracy = calculate_accuracy_over_test_set_jax(network_params, vectorized_test_set, vectorized_test_labels)
  print(f"After {epoch+1} epochs, test accuracy was {accuracy:.2f}. Final batch loss was {loss_value:.4f}.")

Beginning training for epoch 1 of 50...
After 1 epochs, test accuracy was 0.32. Final batch loss was 43.4771.
Beginning training for epoch 2 of 50...
After 2 epochs, test accuracy was 0.33. Final batch loss was 43.1715.
Beginning training for epoch 3 of 50...
After 3 epochs, test accuracy was 0.33. Final batch loss was 43.1604.
Beginning training for epoch 4 of 50...
After 4 epochs, test accuracy was 0.33. Final batch loss was 43.2434.
Beginning training for epoch 5 of 50...
After 5 epochs, test accuracy was 0.33. Final batch loss was 43.2694.
Beginning training for epoch 6 of 50...
After 6 epochs, test accuracy was 0.34. Final batch loss was 43.2728.
Beginning training for epoch 7 of 50...
After 7 epochs, test accuracy was 0.34. Final batch loss was 43.3746.
Beginning training for epoch 8 of 50...
After 8 epochs, test accuracy was 0.34. Final batch loss was 43.4371.
Beginning training for epoch 9 of 50...
After 9 epochs, test accuracy was 0.34. Final batch loss was 43.5352.
Beginning 

In [None]:
# Solve it in PyTorch using just node features.
set_seeds(2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleMlp(nn.Module):
  """Basic MLP to attempt to classify node features."""

  def __init__(self, *, node_feature_vector_len: int, num_classes: int):
    super().__init__()
    self.network = nn.Sequential(
        nn.Linear(node_feature_vector_len, 128),
        nn.ReLU(),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes),
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.network(x)

def calculate_accuracy_over_test_set_pytorch(params: SimpleMlp, test_set: torch.Tensor, test_labels: torch.Tensor) -> float:
  # Returns the percentage of test examples classified correctly.
  prediction_logits = params.forward(test_set)
  softmaxed = nn.functional.softmax(prediction_logits, dim=1)
  predictions_argmax = torch.argmax(softmaxed, dim=1)
  ground_truth_argmax = torch.argmax(test_labels, dim=1)
  correct_count = torch.sum(predictions_argmax == ground_truth_argmax)
  return correct_count / test_labels.shape[0]


mlp = SimpleMlp(node_feature_vector_len=example_node.get_feature_vector(node_feature_a_size=node_feature_a_size).shape[0], num_classes=num_classes).to(device)
print(f"Param count {sum(param.numel() for param in mlp.parameters())}")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mlp.parameters(), lr=.001)
num_epochs = 50
for epoch in range(num_epochs):
  print(f"Beginning training for epoch {epoch+1} of {num_epochs}...")
  for i in range(batched_train_set_torch.shape[0]):
    train_batch = batched_train_set_torch[i]  # (batch_size, node_feature_vector_len)
    train_labels = torch.argmax(batched_train_labels_torch[i], dim=1)  # (batch_size,)
    num_rows = train_batch.shape[0]
    shuffled_indices = torch.randperm(num_rows)
    train_batch = train_batch[shuffled_indices]
    train_labels = train_labels[shuffled_indices]

    optimizer.zero_grad()
    outputs = mlp.forward(train_batch)  # (batch_size, num_classes)
    loss = criterion(outputs, train_labels)
    loss.backward()
    optimizer.step()

  accuracy = calculate_accuracy_over_test_set_pytorch(mlp, vectorized_test_set_torch, vectorized_test_labels_torch)
  print(f"After {epoch+1} epochs, test accuracy was {accuracy:.2f}. Final batch loss was {loss:.4f}.")

Param count 19208
Beginning training for epoch 1 of 50...
After 1 epochs, test accuracy was 0.33. Final batch loss was 1.5793.
Beginning training for epoch 2 of 50...
After 2 epochs, test accuracy was 0.34. Final batch loss was 1.5649.
Beginning training for epoch 3 of 50...
After 3 epochs, test accuracy was 0.34. Final batch loss was 1.5515.
Beginning training for epoch 4 of 50...
After 4 epochs, test accuracy was 0.34. Final batch loss was 1.5382.
Beginning training for epoch 5 of 50...
After 5 epochs, test accuracy was 0.34. Final batch loss was 1.5351.
Beginning training for epoch 6 of 50...
After 6 epochs, test accuracy was 0.34. Final batch loss was 1.5322.
Beginning training for epoch 7 of 50...
After 7 epochs, test accuracy was 0.34. Final batch loss was 1.5225.
Beginning training for epoch 8 of 50...
After 8 epochs, test accuracy was 0.34. Final batch loss was 1.5143.
Beginning training for epoch 9 of 50...
After 9 epochs, test accuracy was 0.34. Final batch loss was 1.5087.
B

### Solve using GCN

In [None]:
# GCN in Jax
set_seeds(2)

def init_gnn(*, node_feature_vector_len: int, num_classes: int, parent_random_key: jax.random.PRNGKey, scale: float=.1) -> dict[str, jnp.ndarray]:
  params = {}
  keys = jax.random.split(parent_random_key, num=12)

  # Could do this in a loop but making everything super explicit since there are only 2 layers.
  params["gnn_layer_1_w"] = scale * jax.random.normal(keys[0], shape=(node_feature_vector_len, node_feature_vector_len))
  params["gnn_layer_1_b"] = scale * jax.random.normal(keys[1], shape=(node_feature_vector_len,))
  params["self_update_1_w"] = scale * jax.random.normal(keys[2], shape=(node_feature_vector_len, node_feature_vector_len))
  params["self_update_1_b"] = scale * jax.random.normal(keys[3], shape=(node_feature_vector_len,))
  params["gnn_layer_2_w"] = scale * jax.random.normal(keys[4], shape=(node_feature_vector_len, node_feature_vector_len))
  params["gnn_layer_2_b"] = scale * jax.random.normal(keys[5], shape=(node_feature_vector_len,))
  params["self_update_2_w"] = scale * jax.random.normal(keys[6], shape=(node_feature_vector_len, node_feature_vector_len))
  params["self_update_2_b"] = scale * jax.random.normal(keys[7], shape=(node_feature_vector_len,))
  params["classification_head_mlp_1_w"] = scale * jax.random.normal(keys[8], shape=(node_feature_vector_len, 128))
  params["classification_head_mlp_1_b"] = scale * jax.random.normal(keys[9], shape=(128,))
  params["classification_head_mlp_2_w"] = scale * jax.random.normal(keys[10], shape=(128, num_classes))
  params["classification_head_mlp_2_b"] = scale * jax.random.normal(keys[11], shape=(num_classes,))

  return params

@jax.jit
def forward_batched(params: dict[str, jnp.ndarray], h0: jnp.ndarray, D_inv: jnp.ndarray, A: jnp.ndarray) -> jnp.ndarray:
  # h0 is (V, d)
  # D_inv is (V, V)
  # A is (V, V)
  d_inv_a = D_inv @ A  # (V, V)
  gnn_layer_1 = (d_inv_a @ h0) @ params["gnn_layer_1_w"] + params["gnn_layer_1_b"]  # (V, d)
  self_update_1 = (h0 @ params["self_update_1_w"]) + params["self_update_1_b"]  # (V, d)
  h1 = jax.nn.relu(gnn_layer_1 + self_update_1)  # (V, d)
  gnn_layer_2 = (d_inv_a @ h1) @ params["gnn_layer_2_w"] + params["gnn_layer_2_b"]  # (V, d)
  self_update_2 = (h1 @ params["self_update_2_w"]) + params["self_update_2_b"]  # (V, d)
  h2 = jax.nn.relu(gnn_layer_2 + self_update_2)  # (V, d)

  mlp1 = (h2 @ params["classification_head_mlp_1_w"]) + params["classification_head_mlp_1_b"]  # (V, 128)
  mlp1 = jax.nn.relu(mlp1)
  return (mlp1 @ params["classification_head_mlp_2_w"]) + params["classification_head_mlp_2_b"]  # (V, num_classes)

@jax.jit
def loss_function_batched(
    params: dict[str, jnp.ndarray],
    h0: jnp.ndarray,
    D_inv: jnp.ndarray,
    A: jnp.ndarray,
    correct_labels: jnp.ndarray
) -> float:
  # h0 is (V, d)
  # D_inv is (V, V)
  # A is (V, V)
  # correct_labels is now shape (V, num_classes)
  predicted = forward_batched(params, h0, D_inv, A)  # (V, num_classes)
  return jnp.sum(optax.softmax_cross_entropy(predicted, correct_labels))

def calculate_accuracy_over_test_set_jax(params: dict[str, jnp.ndarray], test_set: list[tuple[Graph, list[int]]]) -> float:
  # Returns the percentage of test examples classified correctly.
  correct_count = 0
  total_count = 0
  for graph, labels in test_set:
    A = graph.get_adjacency_matrix_jnp(add_self_loops=False)
    summed = jnp.sum(A, axis=1)
    D_inv = jnp.diag(jnp.where(summed > 0, 1 / summed, 0))
    all_node_feature_vectors = [
        graph.id_to_node[i].get_feature_vector_jnp(node_feature_a_size=node_feature_a_size)
        for i in range(len(graph.id_to_node))]
    h0 = jnp.vstack(all_node_feature_vectors)
    prediction_logits = forward_batched(params, h0, D_inv, A)
    softmaxed = jax.nn.softmax(prediction_logits)
    predictions_argmax = jnp.argmax(softmaxed, axis=1)
    correct_count += jnp.sum(predictions_argmax == jnp.array(labels, dtype=jnp.int32))
    total_count += len(labels)
  return float(correct_count) / total_count


num_epochs = 50
learning_rate = .001
network_params = init_gnn(
    node_feature_vector_len=example_node.get_feature_vector(node_feature_a_size=node_feature_a_size).shape[0],
    num_classes=num_classes,
    parent_random_key=jax.random.PRNGKey(12),
)
for epoch in range(num_epochs):
  print(f"Beginning training for epoch {epoch+1} of {num_epochs}...")
  train_set_copy = list(train_set)
  random.shuffle(train_set_copy)
  for graph, labels in train_set_copy:
    A = graph.get_adjacency_matrix_jnp(add_self_loops=False)
    summed = jnp.sum(A, axis=1)
    D_inv = jnp.diag(jnp.where(summed > 0, 1 / summed, 0))
    all_node_feature_vectors = [
        graph.id_to_node[i].get_feature_vector_jnp(node_feature_a_size=node_feature_a_size)
        for i in range(len(graph.id_to_node))]
    h0 = jnp.vstack(all_node_feature_vectors)
    targets = jnp.array(labels, dtype=jnp.int32)
    one_hot_targets = jnp.eye(num_classes)[targets]
    loss_value, loss_gradient = jax.value_and_grad(loss_function_batched)(network_params, h0, D_inv, A, one_hot_targets)
    network_params = jax.tree.map(lambda p, g: p - learning_rate*g, network_params, loss_gradient)
  accuracy = calculate_accuracy_over_test_set_jax(network_params, test_set)
  print(f"After {epoch+1} epochs, test accuracy was {accuracy:.2f}. Final batch loss was {loss_value:.4f}.")

Beginning training for epoch 1 of 50...
After 1 epochs, test accuracy was 0.60. Final batch loss was 9.1806.
Beginning training for epoch 2 of 50...
After 2 epochs, test accuracy was 0.67. Final batch loss was 3.7298.
Beginning training for epoch 3 of 50...
After 3 epochs, test accuracy was 0.69. Final batch loss was 6.0752.
Beginning training for epoch 4 of 50...
After 4 epochs, test accuracy was 0.67. Final batch loss was 0.7907.
Beginning training for epoch 5 of 50...
After 5 epochs, test accuracy was 0.73. Final batch loss was 3.8979.
Beginning training for epoch 6 of 50...
After 6 epochs, test accuracy was 0.76. Final batch loss was 3.5668.
Beginning training for epoch 7 of 50...
After 7 epochs, test accuracy was 0.77. Final batch loss was 3.7210.
Beginning training for epoch 8 of 50...
After 8 epochs, test accuracy was 0.71. Final batch loss was 3.7890.
Beginning training for epoch 9 of 50...
After 9 epochs, test accuracy was 0.72. Final batch loss was 6.0502.
Beginning training 

In [None]:
# GCN in PyTorch
set_seeds(2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GCN(nn.Module):
  """Basic GCN to take graph structure into account when predicting."""

  def __init__(self, *, node_feature_vector_len: int, num_classes: int):
    super().__init__()
    self.gnn_layer_1 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.self_update_1 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.gnn_layer_2 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.self_update_2 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.classification_head = nn.Sequential(
        nn.Linear(node_feature_vector_len, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes),
    )

  def forward(self, h0: torch.Tensor, D_inv: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    # h0 is (V, d)
    # D_inv is (V, V)
    # A is (V, V)
    h1 = nn.functional.relu(self.gnn_layer_1((D_inv @ A) @ h0) + self.self_update_1(h0))
    h2 = nn.functional.relu(self.gnn_layer_2((D_inv @ A) @ h1) + self.self_update_2(h1))
    return self.classification_head(h2)

def calculate_accuracy_over_test_set_pytorch(params: GCN, test_set: list[tuple[Graph, list[int]]]) -> float:
  # Returns the percentage of test examples classified correctly.
  correct_count = 0
  total_count = 0
  for graph, labels in test_set:
    A = torch.Tensor(graph.get_adjacency_matrix(add_self_loops=False)).to(device)
    summed = torch.sum(A, dim=1).to(device)
    D_inv = torch.diag(torch.where(summed > 0, 1 / summed, 0)).to(device)
    all_node_feature_vectors = [
        graph.id_to_node[i].get_feature_vector(node_feature_a_size=node_feature_a_size)
        for i in range(len(graph.id_to_node))]
    h0 = torch.Tensor(np.vstack(all_node_feature_vectors)).to(device)
    prediction_logits = params.forward(h0, D_inv, A)
    softmaxed = nn.functional.softmax(prediction_logits, dim=1)
    predictions_argmax = torch.argmax(softmaxed, dim=1)
    correct_count += torch.sum(predictions_argmax == torch.IntTensor(labels).to(device))
    total_count += len(labels)
  return float(correct_count) / total_count

gcn = GCN(node_feature_vector_len=example_node.get_feature_vector(node_feature_a_size=node_feature_a_size).shape[0], num_classes=num_classes).to(device)
print(f"Param count {sum(param.numel() for param in gcn.parameters())}")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gcn.parameters(), lr=.001)
num_epochs = 50
for epoch in range(num_epochs):
  print(f"Beginning training for epoch {epoch+1} of {num_epochs}...")
  train_set_copy = list(train_set)
  random.shuffle(train_set_copy)
  for graph, labels in train_set_copy:
    A = torch.Tensor(graph.get_adjacency_matrix(add_self_loops=False)).to(device)
    summed = torch.sum(A, dim=1).to(device)
    D_inv = torch.diag(torch.where(summed > 0, 1 / summed, 0)).to(device)
    all_node_feature_vectors = [
        graph.id_to_node[i].get_feature_vector(node_feature_a_size=node_feature_a_size)
        for i in range(len(graph.id_to_node))]
    h0 = torch.Tensor(np.vstack(all_node_feature_vectors)).to(device)

    optimizer.zero_grad()
    outputs = gcn.forward(h0, D_inv, A)  # (batch_size, num_classes)
    loss = criterion(outputs, torch.LongTensor(labels).to(device))
    loss.backward()
    optimizer.step()

  accuracy = calculate_accuracy_over_test_set_pytorch(gcn, test_set)
  print(f"After {epoch+1} epochs, test accuracy was {accuracy:.2f}. Final batch loss was {loss:.4f}.")

Param count 3320
Beginning training for epoch 1 of 50...
After 1 epochs, test accuracy was 0.67. Final batch loss was 0.7736.
Beginning training for epoch 2 of 50...
After 2 epochs, test accuracy was 0.72. Final batch loss was 0.2663.
Beginning training for epoch 3 of 50...
After 3 epochs, test accuracy was 0.76. Final batch loss was 0.4433.
Beginning training for epoch 4 of 50...
After 4 epochs, test accuracy was 0.75. Final batch loss was 0.1447.
Beginning training for epoch 5 of 50...
After 5 epochs, test accuracy was 0.79. Final batch loss was 0.3416.
Beginning training for epoch 6 of 50...
After 6 epochs, test accuracy was 0.78. Final batch loss was 0.3632.
Beginning training for epoch 7 of 50...
After 7 epochs, test accuracy was 0.76. Final batch loss was 0.3755.
Beginning training for epoch 8 of 50...
After 8 epochs, test accuracy was 0.79. Final batch loss was 0.7155.
Beginning training for epoch 9 of 50...
After 9 epochs, test accuracy was 0.79. Final batch loss was 0.4404.
Be

### Solve using GAT

In [None]:
# GAT in Jax
set_seeds(2)

def init_gnn(*, node_feature_vector_len: int, num_classes: int, parent_random_key: jax.random.PRNGKey, scale: float=.1) -> dict[str, jnp.ndarray]:
  params = {}
  keys = jax.random.split(parent_random_key, num=20)

  # Could do this in a loop but making everything super explicit since there are only 2 layers.
  params["gnn_layer_1_w"] = scale * jax.random.normal(keys[0], shape=(node_feature_vector_len, node_feature_vector_len))
  params["gnn_layer_1_b"] = scale * jax.random.normal(keys[1], shape=(node_feature_vector_len,))
  params["self_update_1_w"] = scale * jax.random.normal(keys[2], shape=(node_feature_vector_len, node_feature_vector_len))
  params["self_update_1_b"] = scale * jax.random.normal(keys[3], shape=(node_feature_vector_len,))
  params["layer_1_mha_head_1_w"] = scale * jax.random.normal(keys[4], shape=(2 * node_feature_vector_len, 1))
  params["layer_1_mha_head_1_b"] = scale * jax.random.normal(keys[5], shape=(1,))
  params["layer_1_mha_head_2_w"] = scale * jax.random.normal(keys[6], shape=(2 * node_feature_vector_len, 1))
  params["layer_1_mha_head_2_b"] = scale * jax.random.normal(keys[7], shape=(1,))
  params["gnn_layer_2_w"] = scale * jax.random.normal(keys[8], shape=(node_feature_vector_len, node_feature_vector_len))
  params["gnn_layer_2_b"] = scale * jax.random.normal(keys[9], shape=(node_feature_vector_len,))
  params["self_update_2_w"] = scale * jax.random.normal(keys[10], shape=(node_feature_vector_len, node_feature_vector_len))
  params["self_update_2_b"] = scale * jax.random.normal(keys[11], shape=(node_feature_vector_len,))
  params["layer_2_mha_head_1_w"] = scale * jax.random.normal(keys[12], shape=(2 * node_feature_vector_len, 1))
  params["layer_2_mha_head_1_b"] = scale * jax.random.normal(keys[13], shape=(1,))
  params["layer_2_mha_head_2_w"] = scale * jax.random.normal(keys[14], shape=(2 * node_feature_vector_len, 1))
  params["layer_2_mha_head_2_b"] = scale * jax.random.normal(keys[15], shape=(1,))
  params["classification_head_mlp_1_w"] = scale * jax.random.normal(keys[16], shape=(node_feature_vector_len, 128))
  params["classification_head_mlp_1_b"] = scale * jax.random.normal(keys[17], shape=(128,))
  params["classification_head_mlp_2_w"] = scale * jax.random.normal(keys[18], shape=(128, num_classes))
  params["classification_head_mlp_2_b"] = scale * jax.random.normal(keys[19], shape=(num_classes,))

  return params

@jax.jit
def forward_batched(params: dict[str, jnp.ndarray], h0: jnp.ndarray, A: jnp.ndarray) -> jnp.ndarray:
  # h0 is (V, d)
  # A is (V, V)
  starting_vals = (h0 @ params["gnn_layer_1_w"]) + params["gnn_layer_1_b"]  # (V, d)
  lhs = jnp.repeat(starting_vals, starting_vals.shape[0], axis=0)  # (V^2, d)
  rhs = jnp.tile(starting_vals, (starting_vals.shape[0], 1))  # (V^2, d)
  ans = jnp.hstack([lhs, rhs])  # (V^2, 2d)
  attn_logits_head_1 = (ans @ params["layer_1_mha_head_1_w"]) + params["layer_1_mha_head_1_b"]  # (V^2, 1)
  attn_logits_head_1 = jax.nn.relu(attn_logits_head_1).reshape((starting_vals.shape[0], starting_vals.shape[0]))  # (V, V)
  attn_weights_layer_1_head_1 = jax.nn.softmax(attn_logits_head_1 * A, axis=1)  # (V, V)
  attn_logits_head_2 = (ans @ params["layer_1_mha_head_2_w"]) + params["layer_1_mha_head_2_b"]  # (V^2, 1)
  attn_logits_head_2 = jax.nn.relu(attn_logits_head_2).reshape((starting_vals.shape[0], starting_vals.shape[0]))  # (V, V)
  attn_weights_layer_1_head_2 = jax.nn.softmax(attn_logits_head_2 * A, axis=1)  # (V, V)
  self_update_1 = (h0 @ params["self_update_1_w"]) + params["self_update_1_b"]  # (V, d)
  h1_head1 = jax.nn.relu((attn_weights_layer_1_head_1 @ starting_vals) + self_update_1)
  h1_head2 = jax.nn.relu((attn_weights_layer_1_head_2 @ starting_vals) + self_update_1)
  h1 = h1_head1 + h1_head2  # Sum pool

  starting_vals = (h1 @ params["gnn_layer_2_w"]) + params["gnn_layer_2_b"]  # (V, d)
  lhs = jnp.repeat(starting_vals, starting_vals.shape[0], axis=0)  # (V^2, d)
  rhs = jnp.tile(starting_vals, (starting_vals.shape[0], 1))  # (V^2, d)
  ans = jnp.hstack([lhs, rhs])  # (V^2, 2d)
  attn_logits_head_1 = (ans @ params["layer_2_mha_head_1_w"]) + params["layer_2_mha_head_1_b"]  # (V^2, 1)
  attn_logits_head_1 = jax.nn.relu(attn_logits_head_1).reshape((starting_vals.shape[0], starting_vals.shape[0]))  # (V, V)
  attn_weights_layer_2_head_1 = jax.nn.softmax(attn_logits_head_1 * A, axis=1)  # (V, V)
  attn_logits_head_2 = (ans @ params["layer_2_mha_head_2_w"]) + params["layer_2_mha_head_2_b"]  # (V^2, 1)
  attn_logits_head_2 = jax.nn.relu(attn_logits_head_2).reshape((starting_vals.shape[0], starting_vals.shape[0]))  # (V, V)
  attn_weights_layer_2_head_2 = jax.nn.softmax(attn_logits_head_2 * A, axis=1)  # (V, V)
  self_update_2 = (h1 @ params["self_update_2_w"]) + params["self_update_2_b"]  # (V, d)
  h2_head1 = jax.nn.relu((attn_weights_layer_2_head_1 @ starting_vals) + self_update_2)
  h2_head2 = jax.nn.relu((attn_weights_layer_2_head_2 @ starting_vals) + self_update_2)
  h2 = h2_head1 + h2_head2  # Sum pool

  mlp1 = (h2 @ params["classification_head_mlp_1_w"]) + params["classification_head_mlp_1_b"]  # (V, 128)
  mlp1 = jax.nn.relu(mlp1)
  return (mlp1 @ params["classification_head_mlp_2_w"]) + params["classification_head_mlp_2_b"]  # (V, num_classes)

@jax.jit
def loss_function_batched(
    params: dict[str, jnp.ndarray],
    h0: jnp.ndarray,
    A: jnp.ndarray,
    correct_labels: jnp.ndarray
) -> float:
  # h0 is (V, d)
  # A is (V, V)
  # correct_labels is now shape (V, num_classes)
  predicted = forward_batched(params, h0, A)  # (V, num_classes)
  return jnp.sum(optax.softmax_cross_entropy(predicted, correct_labels))

def calculate_accuracy_over_test_set_jax(params: dict[str, jnp.ndarray], test_set: list[tuple[Graph, list[int]]]) -> float:
  # Returns the percentage of test examples classified correctly.
  correct_count = 0
  total_count = 0
  for graph, labels in test_set:
    A = graph.get_adjacency_matrix_jnp(add_self_loops=False)
    all_node_feature_vectors = [
        graph.id_to_node[i].get_feature_vector_jnp(node_feature_a_size=node_feature_a_size)
        for i in range(len(graph.id_to_node))]
    h0 = jnp.vstack(all_node_feature_vectors)
    prediction_logits = forward_batched(params, h0, A)
    softmaxed = jax.nn.softmax(prediction_logits)
    predictions_argmax = jnp.argmax(softmaxed, axis=1)
    correct_count += jnp.sum(predictions_argmax == jnp.array(labels, dtype=jnp.int32))
    total_count += len(labels)
  return float(correct_count) / total_count


num_epochs = 50
learning_rate = .001
network_params = init_gnn(
    node_feature_vector_len=example_node.get_feature_vector(node_feature_a_size=node_feature_a_size).shape[0],
    num_classes=num_classes,
    parent_random_key=jax.random.PRNGKey(12),
)
for epoch in range(num_epochs):
  print(f"Beginning training for epoch {epoch+1} of {num_epochs}...")
  train_set_copy = list(train_set)
  random.shuffle(train_set_copy)
  for graph, labels in train_set_copy:
    A = graph.get_adjacency_matrix_jnp(add_self_loops=False)
    all_node_feature_vectors = [
        graph.id_to_node[i].get_feature_vector_jnp(node_feature_a_size=node_feature_a_size)
        for i in range(len(graph.id_to_node))]
    h0 = jnp.vstack(all_node_feature_vectors)
    targets = jnp.array(labels, dtype=jnp.int32)
    one_hot_targets = jnp.eye(num_classes)[targets]
    loss_value, loss_gradient = jax.value_and_grad(loss_function_batched)(network_params, h0, A, one_hot_targets)
    network_params = jax.tree.map(lambda p, g: p - learning_rate*g, network_params, loss_gradient)
  accuracy = calculate_accuracy_over_test_set_jax(network_params, test_set)
  print(f"After {epoch+1} epochs, test accuracy was {accuracy:.2f}. Final batch loss was {loss_value:.4f}.")

Beginning training for epoch 1 of 50...
After 1 epochs, test accuracy was 0.63. Final batch loss was 7.3880.
Beginning training for epoch 2 of 50...
After 2 epochs, test accuracy was 0.68. Final batch loss was 3.8438.
Beginning training for epoch 3 of 50...
After 3 epochs, test accuracy was 0.62. Final batch loss was 5.5485.
Beginning training for epoch 4 of 50...
After 4 epochs, test accuracy was 0.64. Final batch loss was 0.7052.
Beginning training for epoch 5 of 50...
After 5 epochs, test accuracy was 0.74. Final batch loss was 2.7534.
Beginning training for epoch 6 of 50...
After 6 epochs, test accuracy was 0.76. Final batch loss was 6.0981.
Beginning training for epoch 7 of 50...
After 7 epochs, test accuracy was 0.76. Final batch loss was 3.5568.
Beginning training for epoch 8 of 50...
After 8 epochs, test accuracy was 0.74. Final batch loss was 4.3175.
Beginning training for epoch 9 of 50...
After 9 epochs, test accuracy was 0.70. Final batch loss was 4.6210.
Beginning training 

In [None]:
# GAT in PyTorch
set_seeds(2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GAT(nn.Module):
  """Use attention to attend to different neighbors differently."""

  def __init__(self, *, node_feature_vector_len: int, num_classes: int):
    super().__init__()
    self.gnn_layer_1 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.self_update_1 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.layer_1_mha_head_1 = nn.Linear(2 * node_feature_vector_len, 1)
    self.layer_1_mha_head_2 = nn.Linear(2 * node_feature_vector_len, 1)
    self.gnn_layer_2 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.self_update_2 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.layer_2_mha_head_1 = nn.Linear(2 * node_feature_vector_len, 1)
    self.layer_2_mha_head_2 = nn.Linear(2 * node_feature_vector_len, 1)
    self.classification_head = nn.Sequential(
        nn.Linear(node_feature_vector_len, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes),
    )

  def forward(self, h0: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    # h0 is (V, d)
    # A is (V, V)
    # Realistically the MHA should be extracted to a module and use loops instead of copy/paste,
    # but this is meant to be quick, in-line, and explicit.
    starting_vals = self.gnn_layer_1(h0)  # (V, d)
    lhs = torch.repeat_interleave(starting_vals, starting_vals.shape[0], dim=0)  # (V^2, d)
    rhs = starting_vals.repeat(starting_vals.shape[0], 1)  # (V^2, d)
    ans = torch.hstack([lhs, rhs])  # (V^2, 2d)
    attn_logits_head_1 = self.layer_1_mha_head_1(ans)  # (V^2, 1)
    attn_logits_head_1 = torch.nn.functional.relu(attn_logits_head_1).reshape((starting_vals.shape[0], starting_vals.shape[0]))  # (V, V)
    attn_weights_layer_1_head_1 = torch.nn.functional.softmax(attn_logits_head_1 * A, dim=1)  # (V, V)
    attn_logits_head_2 = self.layer_1_mha_head_2(ans)  # (V^2, 1)
    attn_logits_head_2 = torch.nn.functional.relu(attn_logits_head_2).reshape((starting_vals.shape[0], starting_vals.shape[0]))  # (V, V)
    attn_weights_layer_1_head_2 = torch.nn.functional.softmax(attn_logits_head_2 * A, dim=1)  # (V, V)
    self_update = self.self_update_1(h0)
    h1_head1 = nn.functional.relu((attn_weights_layer_1_head_1 @ starting_vals) + self_update)
    h1_head2 = nn.functional.relu((attn_weights_layer_1_head_2 @ starting_vals) + self_update)
    h1 = h1_head1 + h1_head2  # Sum pool

    starting_vals = self.gnn_layer_2(h1)  # (V, d)
    lhs = torch.repeat_interleave(starting_vals, starting_vals.shape[0], dim=0)  # (V^2, d)
    rhs = starting_vals.repeat(starting_vals.shape[0], 1)  # (V^2, d)
    ans = torch.hstack([lhs, rhs])  # (V^2, 2d)
    attn_logits_head_1 = self.layer_2_mha_head_1(ans)  # (V^2, 1)
    attn_logits_head_1 = torch.nn.functional.relu(attn_logits_head_1).reshape((starting_vals.shape[0], starting_vals.shape[0]))  # (V, V)
    attn_weights_layer_2_head_1 = torch.nn.functional.softmax(attn_logits_head_1 * A, dim=1)  # (V, V)
    attn_logits_head_2 = self.layer_2_mha_head_2(ans)  # (V^2, 1)
    attn_logits_head_2 = torch.nn.functional.relu(attn_logits_head_2).reshape((starting_vals.shape[0], starting_vals.shape[0]))  # (V, V)
    attn_weights_layer_2_head_2 = torch.nn.functional.softmax(attn_logits_head_2 * A, dim=1)  # (V, V)
    self_update = self.self_update_2(h1)
    h2_head1 = nn.functional.relu((attn_weights_layer_2_head_1 @ starting_vals) + self_update)
    h2_head2 = nn.functional.relu((attn_weights_layer_2_head_2 @ starting_vals) + self_update)
    h2 = h2_head1 + h2_head2  # Sum pool

    return self.classification_head(h2)

def calculate_accuracy_over_test_set_pytorch(params: GAT, test_set: list[tuple[Graph, list[int]]]) -> float:
  # Returns the percentage of test examples classified correctly.
  correct_count = 0
  total_count = 0
  for graph, labels in test_set:
    A = torch.Tensor(graph.get_adjacency_matrix(add_self_loops=False)).to(device)
    all_node_feature_vectors = [
        graph.id_to_node[i].get_feature_vector(node_feature_a_size=node_feature_a_size)
        for i in range(len(graph.id_to_node))]
    h0 = torch.Tensor(np.vstack(all_node_feature_vectors)).to(device)
    prediction_logits = params.forward(h0, A)
    softmaxed = nn.functional.softmax(prediction_logits, dim=1)
    predictions_argmax = torch.argmax(softmaxed, dim=1)
    correct_count += torch.sum(predictions_argmax == torch.IntTensor(labels).to(device))
    total_count += len(labels)
  return float(correct_count) / total_count

gat = GAT(node_feature_vector_len=example_node.get_feature_vector(node_feature_a_size=node_feature_a_size).shape[0], num_classes=num_classes).to(device)
print(f"Param count {sum(param.numel() for param in gat.parameters())}")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gat.parameters(), lr=.001)
num_epochs = 50
for epoch in range(num_epochs):
  print(f"Beginning training for epoch {epoch+1} of {num_epochs}...")
  train_set_copy = list(train_set)
  random.shuffle(train_set_copy)
  for graph, labels in train_set_copy:
    A = torch.Tensor(graph.get_adjacency_matrix(add_self_loops=False)).to(device)
    all_node_feature_vectors = [
        graph.id_to_node[i].get_feature_vector(node_feature_a_size=node_feature_a_size)
        for i in range(len(graph.id_to_node))]
    h0 = torch.Tensor(np.vstack(all_node_feature_vectors)).to(device)

    optimizer.zero_grad()
    outputs = gat.forward(h0, A)  # (batch_size, num_classes)
    loss = criterion(outputs, torch.LongTensor(labels).to(device))
    loss.backward()
    optimizer.step()

  accuracy = calculate_accuracy_over_test_set_pytorch(gat, test_set)
  print(f"After {epoch+1} epochs, test accuracy was {accuracy:.2f}. Final batch loss was {loss:.4f}.")

Param count 3420
Beginning training for epoch 1 of 50...
After 1 epochs, test accuracy was 0.64. Final batch loss was 0.5924.
Beginning training for epoch 2 of 50...
After 2 epochs, test accuracy was 0.65. Final batch loss was 0.4348.
Beginning training for epoch 3 of 50...
After 3 epochs, test accuracy was 0.71. Final batch loss was 0.4477.
Beginning training for epoch 4 of 50...
After 4 epochs, test accuracy was 0.71. Final batch loss was 0.1075.
Beginning training for epoch 5 of 50...
After 5 epochs, test accuracy was 0.73. Final batch loss was 0.2920.
Beginning training for epoch 6 of 50...
After 6 epochs, test accuracy was 0.75. Final batch loss was 0.3675.
Beginning training for epoch 7 of 50...
After 7 epochs, test accuracy was 0.77. Final batch loss was 0.4094.
Beginning training for epoch 8 of 50...
After 8 epochs, test accuracy was 0.75. Final batch loss was 0.6922.
Beginning training for epoch 9 of 50...
After 9 epochs, test accuracy was 0.79. Final batch loss was 0.4744.
Be

### PyTorch Geometric

In [None]:
# https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html
! pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.1/1.1 MB[0m [31m41.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
from torch_geometric.data import Data as PygData
from torch_geometric.loader import DataLoader as PygDataLoader
from torch_geometric.nn import GCNConv, GATConv

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def convert_graph_to_pyg_data(graph: Graph, labels: list[int]) -> PygData:
  # data.x: Node feature matrix with shape [num_nodes, num_node_features]
  # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long
  # data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
  all_node_feature_vectors = [
      graph.id_to_node[i].get_feature_vector(node_feature_a_size=node_feature_a_size)
      for i in range(len(graph.id_to_node))]
  h0 = torch.Tensor(np.vstack(all_node_feature_vectors)).to(device)
  data = PygData(x=h0, edge_index=graph.get_edge_connections_coo().to(device), y=torch.tensor(labels, dtype=torch.int64).to(device))
  data.validate(raise_on_error=True)
  return data

pyg_dataset = []
for graph, labels in train_set:
  pyg_dataset.append(convert_graph_to_pyg_data(graph, labels))

print(f"{len(pyg_dataset)} training examples converted.")
print(f"First graph num nodes: {pyg_dataset[0].num_nodes}")
print(f"First graph num edges: {pyg_dataset[0].num_edges}")
print(f"First graph num node features: {pyg_dataset[0].num_node_features}")

10000 training examples converted.
First graph num nodes: 7
First graph num edges: 36
First graph num node features: 12


In [None]:
set_seeds(2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PygGCN(nn.Module):

  def __init__(self, *, node_feature_vector_len: int, num_classes: int):
    super().__init__()
    # self.dropout_rate = .01  # No dropout worked better than any dropout so disabling here.
    self.gcn_conv_1 = GCNConv(node_feature_vector_len, node_feature_vector_len, add_self_loops=False, normalize=False, bias=True)
    # This is weird. For some reason adding this self-update like I did manually above
    # works significantly (as in, over 30 points better) than just using add_self_loops=True.
    # That is likely because information from the node itself has a significant impact on
    # its classification (but I feel that would likely often be the case for node classification
    # tasks).
    self.self_update_1 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.gcn_conv_2 = GCNConv(node_feature_vector_len, node_feature_vector_len, add_self_loops=False, normalize=False, bias=True)
    self.self_update_2 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.classification_head = nn.Sequential(
        nn.Linear(node_feature_vector_len, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes),
    )

  def forward(self, data: PygData) -> torch.Tensor:
    x, edge_index = data.x, data.edge_index
    x = self.gcn_conv_1(x, edge_index) + self.self_update_1(x)
    x = nn.functional.relu(x)
    # x = nn.functional.dropout(x, p=self.dropout_rate, training=self.training)
    x = self.gcn_conv_2(x, edge_index) + self.self_update_2(x)
    x = nn.functional.relu(x)
    # x = nn.functional.dropout(x, p=self.dropout_rate, training=self.training)
    return self.classification_head(x)

def calculate_accuracy_over_test_set_pytorch_geometric(
    params: PygGCN, test_set: list[tuple[Graph, list[int]]]
) -> float:
  # Returns the percentage of test examples classified correctly.
  original_training = params.training
  params.eval()
  correct_count = 0
  total_count = 0
  for graph, labels in test_set:
    pyg_data = convert_graph_to_pyg_data(graph, labels)
    prediction_logits = params.forward(pyg_data)
    softmaxed = nn.functional.softmax(prediction_logits, dim=1)
    predictions_argmax = torch.argmax(softmaxed, dim=1)
    correct_count += torch.sum(predictions_argmax == torch.IntTensor(labels).to(device))
    total_count += len(labels)
  params.train()
  return float(correct_count) / total_count


pyg_gcn = PygGCN(node_feature_vector_len=example_node.get_feature_vector(node_feature_a_size=node_feature_a_size).shape[0], num_classes=num_classes).to(device)
print(f"Param count {sum(param.numel() for param in pyg_gcn.parameters())}")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pyg_gcn.parameters(), lr=.001)
num_epochs = 50
for epoch in range(num_epochs):
  print(f"Beginning training for epoch {epoch+1} of {num_epochs}...")
  train_set_copy = list(pyg_dataset)
  random.shuffle(train_set_copy)
  # Don't need a Dataset/InMemoryDataset object here since we already have everything
  # in memory and don't need to download the dataset or import it from files.
  # pytorch_geometric batching works by taking batch_size individual graphs and combining them
  # into a large, disconnected graph.
  loader = PygDataLoader(train_set_copy, batch_size=32)
  for batch in loader:
    # https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Batch.html#torch_geometric.data.Batch
    # "torch_geometric.data.Batch inherits from torch_geometric.data.Data and contains an additional attribute called batch"
    optimizer.zero_grad()
    outputs = pyg_gcn.forward(batch.to(device))  # (batch_size, num_classes)
    loss = criterion(outputs, batch.y.to(device))
    loss.backward()
    optimizer.step()

  accuracy = calculate_accuracy_over_test_set_pytorch_geometric(pyg_gcn, test_set)
  print(f"After {epoch+1} epochs, test accuracy was {accuracy:.2f}. Final batch loss was {loss:.4f}.")

Param count 3320
Beginning training for epoch 1 of 50...
After 1 epochs, test accuracy was 0.53. Final batch loss was 1.2818.
Beginning training for epoch 2 of 50...
After 2 epochs, test accuracy was 0.67. Final batch loss was 0.8820.
Beginning training for epoch 3 of 50...
After 3 epochs, test accuracy was 0.73. Final batch loss was 0.6417.
Beginning training for epoch 4 of 50...
After 4 epochs, test accuracy was 0.84. Final batch loss was 0.4031.
Beginning training for epoch 5 of 50...
After 5 epochs, test accuracy was 0.87. Final batch loss was 0.2279.
Beginning training for epoch 6 of 50...
After 6 epochs, test accuracy was 0.88. Final batch loss was 0.3545.
Beginning training for epoch 7 of 50...
After 7 epochs, test accuracy was 0.88. Final batch loss was 0.2972.
Beginning training for epoch 8 of 50...
After 8 epochs, test accuracy was 0.89. Final batch loss was 0.4528.
Beginning training for epoch 9 of 50...
After 9 epochs, test accuracy was 0.89. Final batch loss was 0.1584.
Be

In [None]:
set_seeds(2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PygGAT(nn.Module):

  def __init__(self, *, node_feature_vector_len: int, num_classes: int):
    super().__init__()
    # self.dropout_rate = .01  # No dropout worked better than any dropout so disabling here.
    num_heads = 2
    self.gat_conv_1 = GATConv(node_feature_vector_len, node_feature_vector_len // num_heads, heads=num_heads, concat=True, add_self_loops=False, bias=True)
    self.self_update_1 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.gat_conv_2 = GATConv(node_feature_vector_len, node_feature_vector_len // num_heads, heads=num_heads, concat=True, add_self_loops=False, bias=True)
    self.self_update_2 = nn.Linear(node_feature_vector_len, node_feature_vector_len)
    self.classification_head = nn.Sequential(
        nn.Linear(node_feature_vector_len, 128),
        nn.ReLU(),
        nn.Linear(128, num_classes),
    )

  def forward(self, data: PygData) -> torch.Tensor:
    x, edge_index = data.x, data.edge_index
    x = self.gat_conv_1(x, edge_index) + self.self_update_1(x)
    x = nn.functional.relu(x)
    # x = nn.functional.dropout(x, p=self.dropout_rate, training=self.training)
    x = self.gat_conv_2(x, edge_index) + self.self_update_2(x)
    x = nn.functional.relu(x)
    # x = nn.functional.dropout(x, p=self.dropout_rate, training=self.training)
    return self.classification_head(x)

def calculate_accuracy_over_test_set_pytorch_geometric(
    params: PygGAT, test_set: list[tuple[Graph, list[int]]]
) -> float:
  # Returns the percentage of test examples classified correctly.
  original_training = params.training
  params.eval()
  correct_count = 0
  total_count = 0
  for graph, labels in test_set:
    pyg_data = convert_graph_to_pyg_data(graph, labels)
    prediction_logits = params.forward(pyg_data)
    softmaxed = nn.functional.softmax(prediction_logits, dim=1)
    predictions_argmax = torch.argmax(softmaxed, dim=1)
    correct_count += torch.sum(predictions_argmax == torch.IntTensor(labels).to(device))
    total_count += len(labels)
  params.train()
  return float(correct_count) / total_count


pyg_gat = PygGAT(node_feature_vector_len=example_node.get_feature_vector(node_feature_a_size=node_feature_a_size).shape[0], num_classes=num_classes).to(device)
print(f"Param count {sum(param.numel() for param in pyg_gat.parameters())}")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pyg_gat.parameters(), lr=.001)
num_epochs = 50
for epoch in range(num_epochs):
  print(f"Beginning training for epoch {epoch+1} of {num_epochs}...")
  train_set_copy = list(pyg_dataset)
  random.shuffle(train_set_copy)
  # Don't need a Dataset/InMemoryDataset object here since we already have everything
  # in memory and don't need to download the dataset or import it from files.
  # pytorch_geometric batching works by taking batch_size individual graphs and combining them
  # into a large, disconnected graph.
  loader = PygDataLoader(train_set_copy, batch_size=32)
  for batch in loader:
    # https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Batch.html#torch_geometric.data.Batch
    # "torch_geometric.data.Batch inherits from torch_geometric.data.Data and contains an additional attribute called batch"
    optimizer.zero_grad()
    outputs = pyg_gat.forward(batch.to(device))  # (batch_size, num_classes)
    loss = criterion(outputs, batch.y.to(device))
    loss.backward()
    optimizer.step()

  accuracy = calculate_accuracy_over_test_set_pytorch_geometric(pyg_gat, test_set)
  print(f"After {epoch+1} epochs, test accuracy was {accuracy:.2f}. Final batch loss was {loss:.4f}.")

Param count 3368
Beginning training for epoch 1 of 50...
After 1 epochs, test accuracy was 0.57. Final batch loss was 1.1272.
Beginning training for epoch 2 of 50...
After 2 epochs, test accuracy was 0.64. Final batch loss was 0.9342.
Beginning training for epoch 3 of 50...
After 3 epochs, test accuracy was 0.66. Final batch loss was 0.8695.
Beginning training for epoch 4 of 50...
After 4 epochs, test accuracy was 0.68. Final batch loss was 0.6681.
Beginning training for epoch 5 of 50...
After 5 epochs, test accuracy was 0.69. Final batch loss was 0.5843.
Beginning training for epoch 6 of 50...
After 6 epochs, test accuracy was 0.71. Final batch loss was 0.7093.
Beginning training for epoch 7 of 50...
After 7 epochs, test accuracy was 0.72. Final batch loss was 0.7248.
Beginning training for epoch 8 of 50...
After 8 epochs, test accuracy was 0.74. Final batch loss was 0.7649.
Beginning training for epoch 9 of 50...
After 9 epochs, test accuracy was 0.74. Final batch loss was 0.5400.
Be