In [None]:
!pip install -q --upgrade pip
!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
!pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.13.0+cpu.html

In [None]:
!pip uninstall jaxlib
!pip install jaxlib==0.4.1

In [None]:
import neural_tangents as nt
from neural_tangents import stax
from jax import numpy as np
from jax import random
import jax


def diag(x, batched=True):
  """
  Turn vectors into diagonal matrices.

  batched=True:  x of shape (batch, n) -> array of shape (batch, n, n), where
                 out[b] is the diagonal matrix whose diagonal is x[b].
  batched=False: x of shape (n,)       -> array of shape (n, n) with diagonal x.

  The output is allocated with np.zeros' default float dtype and the values
  of x are written (cast) into it.

  e.g.
  diag(np.array([[1, 3], [3, 4], [5, 6]]))
  = [[[1, 0], [0, 3]], [[3, 0], [0, 4]], [[5, 0], [0, 6]]]
  """
  if batched:
    n = x.shape[1]
    out = np.zeros((x.shape[0], n, n))
    idx = np.arange(n)
    # One scatter writes every diagonal entry of every matrix at once
    # (out[:, idx, idx] has shape (batch, n), matching x), instead of one
    # full-array copy per column.
    return out.at[:, idx, idx].set(x)
  n = x.shape[0]
  out = np.zeros((n, n))
  return out.at[np.diag_indices(n)].set(x)


def grap_conv_pattern(A, batched=True):
  """
  Normalised graph-convolution propagation matrix
  D~^{-1/2} (A + I) D~^{-1/2}.

  A:       adjacency matrix of shape (n, n) (batched=False) or a stack of
           adjacency matrices of shape (batch, n, n) (batched=True).
  Returns: array of the same shape as A.

  Self loops are added via A + I; entries that become 2 (a self loop that was
  already present in A) are reset to 1, so existing self loops are not counted
  twice. D~ is the diagonal matrix of row sums of A + I. Since every node gets
  a self loop, all degrees of a 0/1 adjacency matrix are >= 1, so zero-padded
  (isolated) nodes map to 1 on the diagonal instead of dividing by zero.
  """
  A_tilde = A + np.identity(A.shape[1])
  # np.where instead of a boolean-mask .at[] update keeps this jit-compatible.
  A_tilde = np.where(A_tilde == 2, 1, A_tilde)
  degree = np.sum(A_tilde, axis=2 if batched else 1)
  d_inv_sqrt = 1 / np.sqrt(degree)
  # D @ A @ D with a diagonal D is a row scaling followed by a column scaling;
  # doing it elementwise avoids building dense diagonal matrices and matmuls.
  return d_inv_sqrt[..., :, None] * A_tilde * d_inv_sqrt[..., None, :]

def expand_pattern_at_channels_dim(pattern_in, nr_channels):
  """
  Expand a batched node pattern of shape (batch, n, n) into a pattern of
  shape (batch, n, nr_channels, n, nr_channels) in which every node pair
  (i, j) gets a full nr_channels x nr_channels block filled with
  pattern_in[k, i, j], i.e. every channel pair is connected.
  pattern_in is assumed to be square in its last two axes.

  NOTE: `expand_pattern_at_channels_dim` is redefined right after this
  function (with a `batched` argument and channel-diagonal blocks), so any
  call made after that redefinition uses the other version.
  """
  batch, n = pattern_in.shape[0], pattern_in.shape[1]
  out = np.zeros((batch, n, nr_channels, n, nr_channels))
  # One broadcasted set (value shape (batch, n, 1, n, 1)) replaces
  # batch * n * n single-block updates, each of which copied the whole array.
  return out.at[...].set(pattern_in[:, :, None, :, None])

def expand_pattern_at_channels_dim(pattern_in, nr_channels, batched=True):
  """
  Lift a node-level graph pattern to a (node, channel)-level pattern.

  batched=True:  pattern_in of shape (batch, n, n) ->
                 output of shape (batch, n, nr_channels, n, nr_channels)
  batched=False: pattern_in of shape (n, n) ->
                 output of shape (n, nr_channels, n, nr_channels)

  out[..., i, c, j, d] = pattern_in[..., i, j] if c == d else 0, i.e. only
  matching channels of two nodes are connected.
  """
  n = pattern_in.shape[1]
  if batched:
    out = np.zeros((pattern_in.shape[0], n, nr_channels, n, nr_channels))
  else:
    out = np.zeros((n, nr_channels, n, nr_channels))
  for c in range(nr_channels):
    # `...` stands for the optional batch axis, so the whole batch is
    # written in one update instead of one update per batch element.
    out = out.at[..., c, :, c].set(pattern_in)
  return out

In [None]:
# Graph Convolution MUTAG Dataset Example

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

def to_dense(node_list, size):
  """
  Build a dense adjacency matrix from an edge list in COO format.

  node_list: edge list of shape 2 x num_edges (e.g. a PyG `edge_index`
             tensor, a NumPy/JAX array or a nested list); column k is the
             edge node_list[0][k] -> node_list[1][k].
  size:      number of nodes n.
  Returns an n x n array with 1 at every (source, target) position of an
  edge and 0 elsewhere.
  """
  # .tolist() handles torch / numpy / jax arrays; plain lists are used as-is.
  edges = node_list.tolist() if hasattr(node_list, "tolist") else node_list
  rows = np.asarray(edges[0], dtype=np.int32)
  cols = np.asarray(edges[1], dtype=np.int32)
  A = np.zeros((size, size))
  # One vectorised scatter instead of one full-array copy per edge.
  return A.at[rows, cols].set(1)

def zero_append(a, shape):
  """
  Zero-pad the 2-dim array `a` at the bottom and on the right so that the
  result has shape (shape[0], shape[1]); `a` sits in the top-left corner.
  """
  target_rows, target_cols = shape[0], shape[1]
  n_rows, n_cols = a.shape[0], a.shape[1]
  padded = np.zeros((target_rows, target_cols))
  padded = padded.at[:n_rows, :n_cols].set(a)
  return padded

In [None]:
# make lists of dense adjacency matrix, node feature array and y

dataset = TUDataset(root="Masterarbeit", name="MUTAG")

As = list()
graps_edge_features = list()  # per-graph node-feature matrices (data.x)
ys = list()

for data in dataset:
  As.append(to_dense(data.edge_index, len(data.x)))
  graps_edge_features.append(np.array(data.x))
  ys.append(np.array(data.y))

# unify the sizes for the nn input: zero-pad every graph to the largest one
max_nodes = max(len(ef) for ef in graps_edge_features)
graps_edge_features = [zero_append(ef, (max_nodes, ef.shape[1]))
                       for ef in graps_edge_features]
As = [zero_append(a, (max_nodes, max_nodes)) for a in As]

# Number of node-feature channels (7 for MUTAG). Derived from the data so the
# pattern's channel axes always match the feature array's channel axis.
nr_feature_channels = graps_edge_features[0].shape[1]

# calculate the graph convolution pattern for each graph
patterns = list()
for A in As:
  p = grap_conv_pattern(A, False)
  patterns.append(expand_pattern_at_channels_dim(p, nr_feature_channels, False))

graps_edge_features_2 = np.array(graps_edge_features)
# (graphs, max_nodes, features) -> (graphs, max_nodes, features, 1)
graps_edge_features_2 = np.expand_dims(graps_edge_features_2, 3)
patterns = np.array(patterns)
assert patterns.shape == (len(As), max_nodes, nr_feature_channels,
                          max_nodes, nr_feature_channels)


# define a graph convolution network and calculate the kernel matrix for it
DEPTH = 10  # number of (Aggregate -> Conv -> Relu) blocks
layers = []
for _ in range(DEPTH):
  layers += [
      stax.Aggregate(aggregate_axis=(1, 2), batch_axis=0, channel_axis=3),
      stax.Conv(100, (1, 1)),
      stax.Relu(),
  ]
init_fn, apply_fn, kernel_fn = stax.serial(*layers, stax.GlobalSumPool())

# The NNGP kernel is computed block-wise (presumably to bound memory use).
# The last block may be shorter than `size`, so every graph is covered even
# when the number of graphs is not a multiple of `size` (MUTAG: 188 graphs,
# 5 * 37 = 185 would silently drop the last three). File names encode the
# end indices of the two blocks.
size = 37
n_graphs = len(graps_edge_features_2)
for start_1 in range(0, n_graphs, size):
  end_1 = min(start_1 + size, n_graphs)
  for start_2 in range(0, n_graphs, size):
    end_2 = min(start_2 + size, n_graphs)
    x1 = graps_edge_features_2[start_1:end_1]
    x2 = graps_edge_features_2[start_2:end_2]
    p1 = patterns[start_1:end_1]
    p2 = patterns[start_2:end_2]

    kernel_matrix = kernel_fn(x1, x2, 'nngp', pattern=(p1, p2))
    np.save(f"kernel_matrix_{end_1}_{end_2}", kernel_matrix)
np.save("ys", ys)