In [1]:
from layers import two_wl_aggregation
import neural_tangents as nt
from neural_tangents import stax
from jax import numpy as jnp
from jax import random
import jax
import shutil
import os
from jax.example_libraries import optimizers
from jax import jit, grad, vmap
from prepare_for_dataloader import twl_sparse_pattern, twl_sparse_edge_features
from dataloader import TWL_Dataloader


  from .autonotebook import tqdm as notebook_tqdm


# Example (two graphs) to test if the full network behaves as expected

In [2]:
# Toy graph 1: three nodes, every pair of nodes connected.
# 0---2
# | /
# 1

# Toy graph 2: two nodes joined by a single edge
# (drawn as 3---4; edges_2 below addresses the nodes with local ids 0 and 1).
# 3---4

# Each edge list holds one self-loop [i, i] per node plus the edges drawn above.
edges_1 = jnp.array([[0,0], [1,1], [2,2], [0,1], [0,2], [1,2]])
edges_2 = jnp.array([[0,0], [0,1], [1,1]])

# One scalar feature per node.
node_features_1 = jnp.array([[10], [20], [30]])
node_features_2 = jnp.array([[40], [50],])

# Second argument is the number of nodes of the graph.
pattern_1, edge_list_1 = twl_sparse_pattern(edges_1, 3)
pattern_2, edge_list_2 = twl_sparse_pattern(edges_2, 2)

edge_features_1 = twl_sparse_edge_features(node_features_1, edge_list_1.shape[0], 3)
edge_features_2 = twl_sparse_edge_features(node_features_2, edge_list_2.shape[0], 2)

# Write one folder per graph. The folder is created first because jnp.save
# does not create missing parent directories.
toy_samples = [(pattern_1, edge_features_1), (pattern_2, edge_features_2)]
for sample_id, (sample_pattern, sample_edge_features) in enumerate(toy_samples):
    sample_dir = f"Test_Data/TWL/twl_id_{sample_id}"
    os.makedirs(sample_dir, exist_ok=True)
    jnp.save(f"{sample_dir}/ref_matrix.npy", sample_pattern)
    jnp.save(f"{sample_dir}/edge_features.npy", sample_edge_features)
    # Both toy graphs get the label 1.
    jnp.save(f"{sample_dir}/y.npy", jnp.array([1]))


# NOTE(review): the two `2` arguments presumably are the number of graphs and
# the batch size -- confirm against TWL_Dataloader.
dataloader = TWL_Dataloader("Test_Data/TWL", 2)
arrays = next(dataloader.batch_iterator(2))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Reference output before the index aggregation layer

In [4]:
from layers import index_aggregation
# The aggregation block is assembled by hand instead of calling
# get_two_wl_aggregation_layer("standard", 10): the ReLU is left out so that
# the printed numbers can be checked by hand.

L_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
)

Gamma_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    two_wl_aggregation(),
)

two_wl_aggregation_layer = stax.serial(
    stax.FanOut(2),
    stax.parallel(L_branche, Gamma_branche),
    stax.FanInSum(),
    # stax.Relu(),
)

# index_aggregation() is deliberately left out of this network.
init_fn, apply_fn, kernel_fn = stax.serial(
    two_wl_aggregation_layer,
    # index_aggregation(),
)

key = random.PRNGKey(0)
key, subkey = jax.random.split(key)
input_shape = arrays["edge_features"].shape
_, params = init_fn(subkey, input_shape)

# Overwrite the two Dense layers: L_branche weights become all zeros,
# Gamma_branche weights become all ones, and the bias entry of both is
# replaced by None.
# => output is the sum of features
l_dense_weights = params[0][1][0][0][0]
gamma_dense_weights = params[0][1][1][0][0]
params[0][1][0][0] = (jnp.zeros(l_dense_weights.shape), None)
params[0][1][1][0] = (jnp.ones(gamma_dense_weights.shape), None)

out = apply_fn(
    params,
    arrays["edge_features"],
    pattern=arrays["ref_matrix"],
    graph_indx=arrays["edge_features_graph_indx"],
    nb_graphs=2,
)

# Show the inputs and the per-edge outputs.
print(jnp.squeeze(arrays["edge_features"]))
print(jnp.squeeze(out))

[[10.  1.]
 [20.  1.]
 [30.  1.]
 [ 0.  1.]
 [ 0.  1.]
 [ 0.  1.]
 [40.  1.]
 [50.  1.]
 [ 0.  1.]]
[ 26.  46.  66.  36.  46.  56.  84. 104.  94.]


The first 6 edges belong to the first graph and the last 3 edges belong to the second graph
=> after the index aggregation, each graph's output must be the sum of its edge outputs

In [6]:
# Sum the per-edge outputs of each graph separately: everything except the
# last three entries, then the last three entries.
edge_outputs = jnp.squeeze(out)
first_graph_total = jnp.sum(edge_outputs[:-3])
second_graph_total = jnp.sum(edge_outputs[-3:])
print(first_graph_total)
print(second_graph_total)

276.0
282.0


In [28]:
from layers import index_aggregation
# The aggregation block is assembled by hand instead of calling
# get_two_wl_aggregation_layer("standard", 10): the ReLU is left out so that
# the printed numbers can be checked by hand.

L_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
)

Gamma_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    two_wl_aggregation(),
)

two_wl_aggregation_layer = stax.serial(
    stax.FanOut(2),
    stax.parallel(L_branche, Gamma_branche),
    stax.FanInSum(),
    # stax.Relu(),
)

# Same block as before, now followed by the index aggregation layer.
init_fn, apply_fn, kernel_fn = stax.serial(
    two_wl_aggregation_layer,
    index_aggregation(),
)

key = random.PRNGKey(0)
key, subkey = jax.random.split(key)
input_shape = arrays["edge_features"].shape
_, params = init_fn(subkey, input_shape)

# Overwrite the two Dense layers: L_branche weights become all zeros,
# Gamma_branche weights become all ones, and the bias entry of both is
# replaced by None.
# => output is the sum of features
l_dense_weights = params[0][1][0][0][0]
gamma_dense_weights = params[0][1][1][0][0]
params[0][1][0][0] = (jnp.zeros(l_dense_weights.shape), None)
params[0][1][1][0] = (jnp.ones(gamma_dense_weights.shape), None)

out = apply_fn(
    params,
    arrays["edge_features"],
    pattern=arrays["ref_matrix"],
    graph_indx=arrays["edge_features_graph_indx"],
    nb_graphs=2,
)

# Only the first two entries are shown, one per graph.
aggregated = jnp.squeeze(out)
print(aggregated[:2])

[276. 282.]


# Example to test if the pattern in the kernel function behaves as expected!

Check the shape of the kernel that is passed into the index aggregation => the input shape is (nb_edges, nb_edges)

In [32]:
from layers import index_aggregation
# The aggregation block is assembled by hand instead of calling
# get_two_wl_aggregation_layer("standard", 10), without the ReLU.

L_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    )

Gamma_branche = stax.serial(
    stax.Dense(1, parameterization="standard"),
    two_wl_aggregation(),
)

two_wl_aggregation_layer = stax.serial(
    stax.FanOut(2), 
    stax.parallel(L_branche, Gamma_branche),
    stax.FanInSum(),
    # stax.Relu(),
    )

# index_aggregation() is left out so that the kernel is inspected in the form
# in which it would enter that layer.
init_fn, apply_fn, kernel_fn = stax.serial(
    two_wl_aggregation_layer,
    # index_aggregation(),
)

# No parameters are initialised in this cell: kernel_fn is called with the
# inputs and keyword arguments only, so weights produced by init_fn would not
# enter the result.
# Every per-input keyword argument is passed as an (x1, x2) pair; both inputs
# are the same batch here.
# NOTE(review): nb_edges=(9, 9) is hard-coded for the two toy graphs (6 + 3
# edges) -- adjust if the toy data changes.
out_kernel = kernel_fn(
    arrays["edge_features"],
    arrays["edge_features"],
    pattern=(arrays["ref_matrix"], arrays["ref_matrix"]),
    graph_indx=(arrays["edge_features_graph_indx"], arrays["edge_features_graph_indx"]),
    nb_edges=(9,9),
    nb_graphs=(2,2),
)
print(out_kernel.ntk)
print(out_kernel.ntk.shape)

[[  537.  1037.  1537.   637.   837.  1037.  2025.  2525.  1825.]
 [ 1037.  2037.  3037.  1237.  1637.  2037.  4025.  5025.  3625.]
 [ 1537.  3037.  4537.  1837.  2437.  3037.  6025.  7525.  5425.]
 [  637.  1237.  1837.   937.  1237.  1537.  2425.  3025.  2725.]
 [  837.  1637.  2437.  1237.  1637.  2037.  3225.  4025.  3625.]
 [ 1037.  2037.  3037.  1537.  2037.  2537.  4025.  5025.  4525.]
 [ 2025.  4025.  6025.  2425.  3225.  4025.  8017. 10017.  7217.]
 [ 2525.  5025.  7525.  3025.  4025.  5025. 10017. 12517.  9017.]
 [ 1825.  3625.  5425.  2725.  3625.  4525.  7217.  9017.  8117.]]
(9, 9)


In [47]:
# copy the kernel function from the layers module

from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
from neural_tangents import Kernel
from neural_tangents._src.stax.requirements import (
    Bool,
    Diagonal,
    get_diagonal_outer_prods,
    layer,
    mean_and_var,
    requires,
    supports_masking,
)
from utils import row_wise_karthesian_prod
from jax import numpy as np


def kernel_fn(
    k: Kernel,
    *,
    graph_indx: Tuple[Optional[np.ndarray], Optional[np.ndarray]] = (None, None),
    nb_graphs: Tuple[Optional[int], Optional[int]] = (None, None),
    **kwargs
):
    """Sum the entries of ``k.ntk`` and ``k.nngp`` per pair of graphs.

    Both matrices are flattened and reduced with ``jax.ops.segment_sum``. The
    segment id of an entry is the flat index (``np.ravel_multi_index`` over
    ``nb_graphs``) of the graph-index pair that ``row_wise_karthesian_prod``
    produces for it.

    Args:
        k: Kernel whose ``ntk`` and ``nngp`` are aggregated.
        graph_indx: Pair of index arrays, one per kernel axis. Presumably
            entry ``i`` names the graph that row/column ``i`` belongs to --
            verify against ``row_wise_karthesian_prod``.
        nb_graphs: Pair of integers; used as the shape of the aggregated
            matrices. The ``None`` defaults are placeholders only, since
            ``nb_graphs[0] * nb_graphs[1]`` needs integers.
        **kwargs: Accepted and ignored.

    Returns:
        ``k`` with ``ntk`` and ``nngp`` replaced by their aggregated versions
        of shape ``nb_graphs`` and with ``is_gaussian=True``,
        ``is_input=False`` and ``channel_axis=1``.
    """
    total_segments = nb_graphs[0] * nb_graphs[1]

    # Pair up the graph indices of the two kernel axes; each index array is
    # turned into a column vector first.
    first_indx_col = np.expand_dims(graph_indx[0], 1)
    second_indx_col = np.expand_dims(graph_indx[1], 1)
    graph_pairs = row_wise_karthesian_prod(first_indx_col, second_indx_col)

    # Collapse every (graph, graph) pair into one flat segment id.
    segment_ids = np.ravel_multi_index(
        [graph_pairs[:, 0], graph_pairs[:, 1]], nb_graphs
    )

    def _sum_per_graph_pair(matrix):
        # Flatten, sum the entries that share a segment id, restore 2-D shape.
        flat_entries = np.reshape(matrix, (-1))
        pooled = jax.ops.segment_sum(flat_entries, segment_ids, total_segments)
        return np.reshape(pooled, nb_graphs)

    pooled_ntk = _sum_per_graph_pair(k.ntk)
    pooled_nngp = _sum_per_graph_pair(k.nngp)

    return k.replace(
        ntk=pooled_ntk,
        nngp=pooled_nngp,
        is_gaussian=True,
        is_input=False,
        channel_axis=1,
    )

define a toy kernel matrix (the input to the kernel function of the GCN layer is the kernel of the previous layer)

In [45]:
# Symmetric 5x5 toy matrix that stands in for a kernel.
toy_kernel_rows = [
    [1, 2, 3, 4, 5],
    [2, 1, 3, 4, 5],
    [3, 3, 1, 4, 5],
    [4, 4, 4, 1, 5],
    [5, 5, 5, 5, 1],
]
toy_kernel_2_new = np.array(toy_kernel_rows)

# Index array with one entry per row/column of the toy matrix: 0 for the
# first three, 1 for the last two.
toy_kernel_2_new_indx = jnp.array([0, 0, 0, 1, 1])

for shown in (toy_kernel_2_new, toy_kernel_2_new.shape):
    print(shown)

[[1 2 3 4 5]
 [2 1 3 4 5]
 [3 3 1 4 5]
 [4 4 4 1 5]
 [5 5 5 5 1]]
(5, 5)


In [48]:
# Only nngp and ntk carry data; every other Kernel field is passed as None.
unset_kernel_fields = (
    "cov1",
    "cov2",
    "x1_is_x2",
    "is_gaussian",
    "is_reversed",
    "is_input",
    "diagonal_batch",
    "diagonal_spatial",
    "shape1",
    "shape2",
    "batch_axis",
    "channel_axis",
    "mask1",
    "mask2",
)
k = nt.Kernel(
    nngp=toy_kernel_2_new,
    ntk=toy_kernel_2_new,
    **dict.fromkeys(unset_kernel_fields),
)

# The same index array is used for both kernel axes.
toy_graph_indx = (toy_kernel_2_new_indx, toy_kernel_2_new_indx)
kernel_matrix = kernel_fn(k=k, graph_indx=toy_graph_indx, nb_graphs=(2, 2))

kernel_matrix.ntk

Array([[19, 27],
       [27, 12]], dtype=int32)

Expected

In [49]:
# Block sums of the toy kernel, split between rows/columns 0-2 (index 0)
# and rows/columns 3-4 (index 1).
sum_00 = jnp.sum(toy_kernel_2_new[:3, :3])
sum_11 = jnp.sum(toy_kernel_2_new[3:, 3:])
sum_01 = jnp.sum(toy_kernel_2_new[:3, 3:])
sum_10 = jnp.sum(toy_kernel_2_new[3:, :3])

print(sum_00)
print(sum_11)
print(sum_01)
print(sum_10)

# Fail loudly instead of relying on comparing the printed numbers by eye.
expected_ntk = jnp.array([[sum_00, sum_01], [sum_10, sum_11]])
assert jnp.array_equal(kernel_matrix.ntk, expected_ntk), (
    "aggregated NTK does not match the block sums of the toy kernel"
)

19
12
27
27
