In [None]:
!pip install -q --upgrade pip
!pip install -q --upgrade jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q git+https://www.github.com/google/neural-tangents
!pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
!pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.13.0+cpu.html

In [None]:
!pip uninstall jaxlib
!pip install jaxlib==0.4.1

In [None]:
import neural_tangents as nt
from neural_tangents import stax
from jax import numpy as np
from jax import random
import jax


def diag(x, batched=True):
  """
  Arrange the values of `x` on the diagonal of otherwise-zero matrices.

  Args:
    x: with `batched=True` a 2-dim array of shape (batch, n); with
      `batched=False` a 1-dim array of shape (n,).
    batched: whether the leading axis of `x` indexes independent vectors.

  Returns:
    An array of shape (batch, n, n) if `batched`, else (n, n). Entry
    [..., i, i] holds x[..., i]; every other entry is zero. The result is
    allocated with `np.zeros`, so it has that function's default dtype.

  e.g.
  diag(np.array([[1,3],[3,4],[5,6]]))
  = [ [[1,0], [0,3]], [[3,0], [0,4]], [[5,0], [0,6]] ]
  """
  if batched:
    n = x.shape[1]
    out = np.zeros((x.shape[0], n, n))
    idx = np.arange(n)
    # A single scatter writes out[b, i, i] = x[b, i] for all b and i,
    # instead of one full-array update per column.
    out = out.at[:, idx, idx].set(x)
  else:
    n = x.shape[0]
    out = np.zeros((n, n))
    out = out.at[np.diag_indices(n)].set(x)
  return out


def grap_conv_pattern(A, batched=True):
  """
  Build the normalised graph-convolution pattern D^-1/2 (A + I) D^-1/2.

  Self loops are added to `A`; entries that become exactly 2 (a self loop
  that was already present in a 0/1 adjacency) are reset to 1. D is the
  diagonal matrix of row sums of the resulting matrix.

  Args:
    A: adjacency matrix of shape (n, n), or (batch, n, n) when `batched`.
    batched: whether `A` carries a leading batch axis.

  Returns:
    An array with the same shape as `A` holding the normalised pattern.

  Raises:
    ValueError: if the rank of `A` does not match `batched`.
  """
  expected_rank = 3 if batched else 2
  if len(A.shape) != expected_rank:
    raise ValueError(
        f"expected a {expected_rank}-dim adjacency for batched={batched}, "
        f"got shape {tuple(A.shape)}")

  A_tilde = A + np.identity(A.shape[1])
  # Do not count a pre-existing self loop twice.
  A_tilde = np.where(A_tilde == 2, 1, A_tilde)

  # Row sums live on the last axis in both the batched and unbatched layout.
  inv_sqrt_deg = 1 / np.sqrt(np.sum(A_tilde, axis=-1))

  # Multiplying by a diagonal matrix from the left / right scales rows /
  # columns, so broadcasting replaces the two dense matrix products.
  return inv_sqrt_deg[..., :, None] * A_tilde * inv_sqrt_deg[..., None, :]

def expand_pattern_at_channels_dim(pattern_in, nr_channels):
  """
  Expand a batched two dimensional pattern into a batched four dimensional
  pattern by repeating every entry over two added channel axes.

  NOTE: a later definition with the same name (taking an extra `batched`
  argument) appears further down in this file and replaces this one once
  both have been executed.

  Args:
    pattern_in: array of shape (batch, rows, cols).
    nr_channels: size of each of the two added channel axes.

  Returns:
    An array `out` of shape (batch, rows, nr_channels, cols, nr_channels)
    with `out[k, i, c1, j, c2] == pattern_in[k, i, j]` for all c1, c2.
  """
  batch, rows, cols = pattern_in.shape
  pattern_out = np.zeros((batch, rows, nr_channels, cols, nr_channels))
  # One broadcast write instead of a scatter per (k, i, j) entry.
  return pattern_out.at[:].set(pattern_in[:, :, None, :, None])

def expand_pattern_at_channels_dim(pattern_in, nr_channels, batched=True):
  """
  Expand a (batched) two dimensional pattern by two added channel axes,
  placing a copy of the pattern on the diagonal of those axes.

  Args:
    pattern_in: array of shape (batch, n, n) if `batched`, else (n, n).
    nr_channels: size of each of the two added channel axes.
    batched: whether `pattern_in` carries a leading batch axis.

  Returns:
    An array `out` of shape (batch, n, nr_channels, n, nr_channels), or
    (n, nr_channels, n, nr_channels) when not batched, with
    `out[..., i, c, j, c] == pattern_in[..., i, j]` and zeros wherever the
    two channel indices differ.
  """
  n_nodes = pattern_in.shape[1]
  channel_idx = np.arange(nr_channels)

  if batched:
    out = np.zeros((pattern_in.shape[0],
                    n_nodes, nr_channels,
                    n_nodes, nr_channels))
    # Indexing both channel axes with the same index array selects the
    # channel diagonal; the pattern is broadcast over it in one scatter.
    return out.at[:, :, channel_idx, :, channel_idx].set(pattern_in)

  out = np.zeros((n_nodes, nr_channels,
                  n_nodes, nr_channels))
  return out.at[:, channel_idx, :, channel_idx].set(pattern_in)

In [None]:
# Graph Convolution Toy Example

# Sizes of the toy problem.
N_GRAPHS = 10
N_NODES = 5
N_LAYERS = 10
WIDTH = 100

# Batched feature tensor
# Batch x Height x Width x Channel
# NHWC
x = random.normal(random.PRNGKey(1), (N_GRAPHS, N_NODES, 1, 1))

A = random.bernoulli(random.PRNGKey(2), 0.5, (N_GRAPHS, N_NODES, N_NODES))

pattern_1 = grap_conv_pattern(A)
pattern_2 = expand_pattern_at_channels_dim(pattern_1, 1)

# `A[n, h1, w1, h2, w2] == True`
# means an edge between pixels `(h1, w1)` and `(h2, w2)` in image `n`.

# Adjacency matrix for a graph with 3 vertices
#[[1,1,0],
# [1,0,1],
# [0,1,0],]
# assuming the graph features are two dimensional
# [h1, w1, h2, w2]
# [0, 0, 0, 0] = 1 a self edge between h1 and h1 for channel dimension 0
# [0, 1, 0, 1] = 1 a self edge between h1 and h1 for channel dimension 1
# [1, 0, 0, 0] = 1
# [1, 1, 0, 1] = 1
# [0, 0, 1, 0] = 1
# [0, 1, 1, 1] = 1
# [2, 0, 1, 0] = 1
# [2, 1, 1, 1] = 1
# [2, 0, 2, 0] = 0
# [2, 1, 2, 1] = 0

print(pattern_2.shape)

# N_LAYERS identical (aggregate -> 1x1 conv -> ReLU) stages, then sum pooling.
layers = []
for _ in range(N_LAYERS):
  layers += [
      stax.Aggregate(aggregate_axis=(1, 2), batch_axis=0, channel_axis=3),
      stax.Conv(WIDTH, (1, 1)),
      stax.Relu(),
  ]

init_fn, apply_fn, kernel_fn = stax.serial(*layers, stax.GlobalSumPool())

# Initialise one parameter set per layer; the forward pass needs them for
# the convolutions. The same expanded pattern is used for the finite network
# and for the kernel so both describe the same graphs.
_, params = init_fn(random.PRNGKey(3), x.shape)
out = apply_fn(params, x, pattern=pattern_2)

kernel = kernel_fn(x, x, 'nngp', pattern=(pattern_2, pattern_2))