In [1]:
from layers import two_wl_aggregation
import neural_tangents as nt
from neural_tangents import stax
from jax import numpy as jnp
from jax import random
import jax
import shutil
import os
from jax.example_libraries import optimizers
from jax import jit, grad, vmap
from prepare_for_dataloader import twl_sparse_pattern, twl_sparse_edge_features
from dataloader import TWL_Dataloader


  from .autonotebook import tqdm as notebook_tqdm


# Example (two graphs) to test if the full network behaves as expected

In [2]:
# Toy graph 1: three mutually connected nodes (each undirected edge listed once,
# self-loops included)
# 0---2
# | /
# 1

# Toy graph 2: two connected nodes (indices 0 and 1 in edges_2; drawn as 3---4,
# presumably their node ids after batching with graph 1)
# 3---4

TWL_TEST_DIR = "Test_Data/TWL"

edges_1 = jnp.array([[0,0], [1,1], [2,2], [0,1], [0,2], [1,2]])
edges_2 = jnp.array([[0,0], [0,1], [1,1]])

node_features_1 = jnp.array([[10], [20], [30]])
node_features_2 = jnp.array([[40], [50],])

pattern_1, edge_list_1 = twl_sparse_pattern(edges_1, 3)
pattern_2, edge_list_2 = twl_sparse_pattern(edges_2, 2)

edge_features_1 = twl_sparse_edge_features(node_features_1, edge_list_1.shape[0], 3)
edge_features_2 = twl_sparse_edge_features(node_features_2, edge_list_2.shape[0], 2)

# jnp.save does not create missing folders; create them so the cell also runs
# on a fresh checkout.
os.makedirs(f"{TWL_TEST_DIR}/twl_id_0", exist_ok=True)
os.makedirs(f"{TWL_TEST_DIR}/twl_id_1", exist_ok=True)

jnp.save(f"{TWL_TEST_DIR}/twl_id_0/ref_matrix.npy", pattern_1)
jnp.save(f"{TWL_TEST_DIR}/twl_id_0/edge_features.npy", edge_features_1)
jnp.save(f"{TWL_TEST_DIR}/twl_id_0/y.npy", jnp.array([1]))

jnp.save(f"{TWL_TEST_DIR}/twl_id_1/ref_matrix.npy", pattern_2)
jnp.save(f"{TWL_TEST_DIR}/twl_id_1/edge_features.npy", edge_features_2)
jnp.save(f"{TWL_TEST_DIR}/twl_id_1/y.npy", jnp.array([1]))


dataloader = TWL_Dataloader(TWL_TEST_DIR, 2)
arrays = next(dataloader.batch_iterator(2))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# Names of the arrays produced by the dataloader for one batch
[*arrays.keys()]

['edge_features',
 'ref_matrix',
 'edge_features_graph_indx',
 'ref_matrix_graph_indx',
 'nb_edges',
 'ys',
 'nb_graphs',
 'id_high']

In [4]:
from layers import index_aggregation
# two_wl_aggregation_layer = get_two_wl_aggregation_layer("standard", 10)
# Define the layer again here because the ReLU must be left out so the results
# can be checked by hand.

L_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    )

Gamma_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    # two_wl_aggregation(),
)

two_wl_aggregation_layer = stax.serial(
    stax.FanOut(2), 
    stax.parallel(L_branche, Gamma_branche),
    stax.FanInSum(),
    # stax.Relu(),
    )

init_fn, apply_fn, kernel_fn = stax.serial(
    two_wl_aggregation_layer,
    # index_aggregation(),
)

key = random.PRNGKey(0)
key, subkey = jax.random.split(key)
_, params = init_fn(subkey, arrays["edge_features"].shape)

# Overwrite the Dense weights: zeros for the L_branche, ones for the
# Gamma_branche (the second tuple entry, presumably the bias, is set to None).
params[0][1][0][0] = (jnp.zeros(params[0][1][0][0][0].shape), None)
params[0][1][1][0] = (jnp.ones(params[0][1][1][0][0].shape), None)

# With L_branche weights 0, Gamma_branche weights 1 and no aggregation in the
# Gamma_branche, each output entry is the sum of that edge's features.
out = apply_fn(params, arrays["edge_features"], pattern=arrays["ref_matrix"], graph_indx=arrays["edge_features_graph_indx"], nb_graphs=2)

print(jnp.squeeze(arrays["edge_features"]))
print(jnp.squeeze(out))

# Turn the visual check into an explicit one.
assert jnp.allclose(
    jnp.squeeze(out),
    jnp.sum(jnp.squeeze(arrays["edge_features"]), axis=-1),
), "output should equal the per-edge sum of features"

[[10.  1.]
 [20.  1.]
 [30.  1.]
 [ 0.  1.]
 [ 0.  1.]
 [ 0.  1.]
 [40.  1.]
 [50.  1.]
 [ 0.  1.]]
[11. 21. 31.  1.  1.  1. 41. 51.  1.]


In [5]:
from layers import index_aggregation
# two_wl_aggregation_layer = get_two_wl_aggregation_layer("standard", 10)
# Define the layer again here because the ReLU must be left out so the results
# can be checked by hand.

L_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    )

Gamma_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    two_wl_aggregation(),
)

two_wl_aggregation_layer = stax.serial(
    stax.FanOut(2), 
    stax.parallel(L_branche, Gamma_branche),
    stax.FanInSum(),
    # stax.Relu(),
    )

init_fn, apply_fn, kernel_fn = stax.serial(
    two_wl_aggregation_layer,
    # index_aggregation(),
)

key = random.PRNGKey(0)
key, subkey = jax.random.split(key)
_, params = init_fn(subkey, arrays["edge_features"].shape)

# Overwrite the Dense weights: zeros for the L_branche, ones for the
# Gamma_branche (the second tuple entry, presumably the bias, is set to None).
params[0][1][0][0] = (jnp.zeros(params[0][1][0][0][0].shape), None)
params[0][1][1][0] = (jnp.ones(params[0][1][1][0][0].shape), None)

# With L_branche weights 0 and Gamma_branche weights 1, the L_branche adds
# nothing and the input to the twl aggregation is
# [11. 21. 31.  1.  1.  1. 41. 51.  1.]
# Hand calculation per target edge (edge (i,j), edge (i,a), edge (a,j) | values -> sum):
# 0,0 0,1 1,0 |  1  1 ->  2
# 0,0 0,2 2,0 |  1  1 ->  2
# 0,0 0,0 0,0 | 11 11 -> 22
#                        26
# 1,1 1,1 1,1 | 21 21 -> 42
# 1,1 1,0 0,1 |  1  1 ->  2
# 1,1 1,2 2,1 |  1  1 ->  2
#                        46
# etc.
out = apply_fn(params, arrays["edge_features"], pattern=arrays["ref_matrix"], graph_indx=arrays["edge_features_graph_indx"], nb_graphs=2)

print(jnp.squeeze(arrays["edge_features"]))
print(jnp.squeeze(out))

# Turn the hand calculation above into an explicit check.
expected_out = jnp.array([26., 46., 66., 36., 46., 56., 84., 104., 94.])
assert jnp.allclose(jnp.squeeze(out), expected_out), "twl aggregation output differs from hand calculation"


[[10.  1.]
 [20.  1.]
 [30.  1.]
 [ 0.  1.]
 [ 0.  1.]
 [ 0.  1.]
 [40.  1.]
 [50.  1.]
 [ 0.  1.]]
[ 26.  46.  66.  36.  46.  56.  84. 104.  94.]


In [6]:
from layers import index_aggregation
# two_wl_aggregation_layer = get_two_wl_aggregation_layer("standard", 10)
# Define the layer again here because the ReLU must be left out so the results
# can be checked by hand.

L_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    )

Gamma_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    # two_wl_aggregation(),
)

two_wl_aggregation_layer = stax.serial(
    stax.FanOut(2), 
    stax.parallel(L_branche, Gamma_branche),
    stax.FanInSum(),
    # stax.Relu(),
    )

init_fn, apply_fn, kernel_fn = stax.serial(
    two_wl_aggregation_layer,
    # index_aggregation(),
)

# The batch holds 9 edge-feature rows (6 edges of toy graph 1 + 3 of toy
# graph 2) from 2 graphs; both kernel inputs are this same batch.
NB_EDGES = 9
NB_GRAPHS = 2

# Only the shape of the NTK is checked here.
out_kernel = kernel_fn(
    arrays["edge_features"],
    arrays["edge_features"],
    pattern=(arrays["ref_matrix"], arrays["ref_matrix"]),
    graph_indx=(arrays["edge_features_graph_indx"], arrays["edge_features_graph_indx"]),
    nb_edges=(NB_EDGES, NB_EDGES),
    nb_graphs=(NB_GRAPHS, NB_GRAPHS),
)
assert out_kernel.ntk.shape == (NB_EDGES, NB_EDGES), out_kernel.ntk.shape
print(out_kernel.ntk)

[[2.020e+02 4.020e+02 6.020e+02 2.000e+00 2.000e+00 2.000e+00 8.020e+02
  1.002e+03 2.000e+00]
 [4.020e+02 8.020e+02 1.202e+03 2.000e+00 2.000e+00 2.000e+00 1.602e+03
  2.002e+03 2.000e+00]
 [6.020e+02 1.202e+03 1.802e+03 2.000e+00 2.000e+00 2.000e+00 2.402e+03
  3.002e+03 2.000e+00]
 [2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00
  2.000e+00 2.000e+00]
 [2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00
  2.000e+00 2.000e+00]
 [2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00
  2.000e+00 2.000e+00]
 [8.020e+02 1.602e+03 2.402e+03 2.000e+00 2.000e+00 2.000e+00 3.202e+03
  4.002e+03 2.000e+00]
 [1.002e+03 2.002e+03 3.002e+03 2.000e+00 2.000e+00 2.000e+00 4.002e+03
  5.002e+03 2.000e+00]
 [2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00
  2.000e+00 2.000e+00]]




# Example to test if the pattern in the kernel function behaves as expected!

In [7]:
# copy the kernel function from the layers module

from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
from neural_tangents import Kernel
from neural_tangents._src.stax.requirements import (
    Bool,
    Diagonal,
    get_diagonal_outer_prods,
    layer,
    mean_and_var,
    requires,
    supports_masking,
)
from utils import row_wise_karthesian_prod
from jax import numpy as np


def kernel_fn(
    k: Kernel,
    *,
    pattern: Tuple[Optional[np.ndarray], Optional[np.ndarray]] = (None, None),
    nb_edges: Tuple[Optional[int], Optional[int]] = (None, None),
    **kwargs
):
    """Apply the 2-WL aggregation to the ``ntk`` and ``nngp`` of a kernel.

    Notebook copy of the kernel function in the layers module. For every pair
    of target edges it sums, over all matching pattern rows, the four kernel
    entries Theta(ia, i'a') + Theta(ia, a'j') + Theta(aj, i'a') + Theta(aj, a'j').

    Args:
        k: Kernel of the previous layer; its ``ntk`` and ``nngp`` are aggregated.
        pattern: One sparse 2-WL pattern per kernel input. Columns 0, 1, 2 of
            each row are used as the (ij, ia, aj) edge indices.
        nb_edges: Number of edges of each input. It equals the corresponding
            ``shape[0]`` of the apply_fn inputs and is used to flatten
            (left edge, right edge) index pairs.
        **kwargs: Accepted and ignored (e.g. ``nb_graphs``).

    Returns:
        ``k`` with aggregated ``ntk``/``nngp`` (reshaped back to the input
        shape), ``is_gaussian=True`` and ``is_input=False``.
    """
    grid = (nb_edges[0], nb_edges[1])
    total = int(np.prod(np.array(k.ntk.shape)))

    # Columns 0-2 come from pattern[0] (ij, ia, aj), columns 3-5 from
    # pattern[1] (i'j', i'a', a'j').
    rows = row_wise_karthesian_prod(pattern[0], pattern[1])

    def flat(left_col, right_col):
        # Flat position of each (left edge, right edge) pair in the edge grid.
        return np.ravel_multi_index([rows[:, left_col], rows[:, right_col]], grid)

    target = flat(0, 3)
    # Order: (ia, i'a'), (ia, a'j'), (aj, i'a'), (aj, a'j').
    sources = [flat(1, 4), flat(1, 5), flat(2, 4), flat(2, 5)]

    def aggregate(mat):
        terms = [
            jax.ops.segment_sum(np.take(mat, src), target, total)
            for src in sources
        ]
        summed = np.sum(np.array(terms), 0)
        return np.reshape(summed, mat.shape)

    return k.replace(
        ntk=aggregate(k.ntk),
        nngp=aggregate(k.nngp),
        is_gaussian=True,
        is_input=False,
    )

Define a toy kernel matrix (the input to the kernel function of the 2-WL aggregation layer is the kernel of the previous layer).

In [12]:
# Toy previous-layer kernel: a 3x3 matrix with distinct entries 0..8, so that
# every aggregated term can be traced by hand. Four leading singleton axes are
# added, giving shape (1, 1, 1, 1, 3, 3).
toy_kernel_2_new = np.reshape(np.arange(3 * 3), (1, 1, 1, 1, 3, 3))

# NOTE: this rebinds edge_list_2, which the first example cell set to the
# edge list returned by twl_sparse_pattern.
edge_list_2 = np.array([[0,0],
                      [1,1],
                      [0,1],
                      [1,0]])

pattern, edge_list = twl_sparse_pattern(edge_list_2, 2)

# Every edge index in the pattern must address a row/column of the toy kernel.
assert int(np.max(pattern)) < toy_kernel_2_new.shape[-1], "toy kernel too small for pattern"

print(toy_kernel_2_new)
print(pattern)

[[[[[[0 1 2]
     [3 4 5]
     [6 7 8]]]]]]
[[0 0 0]
 [1 1 1]
 [2 2 1]
 [2 0 2]
 [1 2 2]
 [0 2 2]]


The output of the kernel function for the toy kernel, $\Theta_{(i,j,i',j')}^{(l)}$, must be:
$$
    \Theta_{(i,j,i',j')}^{(l)}
    =
    \sum_{v_a \in \Gamma_G^r(v_i) \cap \Gamma_G^r(v_j)}
    \sum_{v_{a'} \in \Gamma_G^r(v_{i'}) \cap \Gamma_G^r(v_{j'})}
    \left(
    \Theta_{(i,a,i',a')}^{(l-1)}
    +
    \Theta_{(a,j,a',j')}^{(l-1)}
    +
    \Theta_{(i,a,a',j')}^{(l-1)}
    +
    \Theta_{(a,j,i',a')}^{(l-1)}
    \right)
$$
We can calculate this manually with a table (see the Excel file "twl_kernel").

In [13]:
k = nt.Kernel(nngp=toy_kernel_2_new, 
          ntk=toy_kernel_2_new,
          cov1=None,
          cov2=None, 
          x1_is_x2=None, 
          is_gaussian=None, 
          is_reversed=None, 
          is_input=None, 
          diagonal_batch=None, 
          diagonal_spatial=None, 
          shape1=None, 
          shape2=None, 
          batch_axis=None, 
          channel_axis=None, 
          mask1=None, 
          mask2=None)

# The edge grid must match the last two axes of the toy kernel (3 edges here).
NB_TOY_EDGES = toy_kernel_2_new.shape[-1]
kernel_matrix = kernel_fn(k=k, pattern=(pattern, pattern), nb_edges=(NB_TOY_EDGES, NB_TOY_EDGES), nb_graphs=(1,1))

# Hand-checkable reference values for this toy input, e.g. entry (0, 0):
# 0 + 8 + 24 + 32 = 64. ntk and nngp start from the same matrix, so both must match.
expected_kernel = np.array([[64, 72, 68],
                            [88, 96, 92],
                            [76, 84, 80]])
assert np.array_equal(np.squeeze(kernel_matrix.ntk), expected_kernel)
assert np.array_equal(np.squeeze(kernel_matrix.nngp), expected_kernel)

kernel_matrix.ntk

Array([[[[[[64, 72, 68],
           [88, 96, 92],
           [76, 84, 80]]]]]], dtype=int32)