In [1]:
from fundl.layers.rnn import gru

In [2]:
import numpy.random as npr
import jax.numpy as np

# Toy input for the GRU call below: the integers 0..99 laid out row by row
# in a 10x10 grid.
x = np.reshape(np.arange(100), (10, 10))
x

array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
       [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
       [40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
       [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
       [60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
       [70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
       [80, 81, 82, 83, 84, 85, 86, 87, 88, 89],
       [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])

In [3]:
from fundl.weights import add_gru_params

# Start from an empty parameter dict and let fundl register the GRU weights
# under the "gru" key (read back later as params["gru"]).
# NOTE(review): 10 and 2 are presumably the input and output widths -- they
# match x's 10 columns and the 2-column GRU output; confirm against the
# signature of add_gru_params.
params = add_gru_params(dict(), "gru", 10, 2)



In [4]:
# Run the GRU layer over x with the parameters registered under "gru".
# NOTE(review): in the recorded output the second column is -0.0 everywhere
# and the first column grows to about -2.1e3; x is passed in unscaled
# (values 0..99) -- confirm this is the intended behaviour of fundl's gru.
gru(params["gru"], x)

DeviceArray([[ 3.83266062e-01, -0.00000000e+00],
             [ 3.85106236e-01, -0.00000000e+00],
             [ 3.16568166e-01, -0.00000000e+00],
             [-1.19605124e-01, -0.00000000e+00],
             [-1.69583106e+00, -0.00000000e+00],
             [-7.06654167e+00, -0.00000000e+00],
             [-2.64628792e+01, -0.00000000e+00],
             [-1.03442741e+02, -0.00000000e+00],
             [-4.42269745e+02, -0.00000000e+00],
             [-2.10259766e+03, -0.00000000e+00]], dtype=float32)

In [5]:
from jax.lax import batch_matmul
from collections import defaultdict

In [6]:
# Synthetic dataset: 1000 random graphs, grouped by node count so that graphs
# of equal size can be stacked into a single array later on.
#   node_feats[n]         -> list of (n, 10) node-feature matrices
#   adjacency_matrices[n] -> list of (n, n) 0/1 adjacency matrices (as floats)
# A dedicated, seeded generator makes Restart & Run All reproduce the same
# graphs without touching NumPy's global random state.
GRAPH_SEED = 42
graph_rng = npr.RandomState(GRAPH_SEED)

node_feats = defaultdict(list)
adjacency_matrices = defaultdict(list)
for _ in range(1000):
    num_nodes = graph_rng.randint(3, 20)  # upper bound is exclusive: 3..19 nodes
    node_feat = graph_rng.normal(size=(num_nodes, 10))
    # Every entry of the adjacency matrix is 1 with probability 0.3.
    amat = graph_rng.binomial(n=1, p=0.3, size=(num_nodes, num_nodes)).astype(float)

    # Both dicts are appended to in lockstep, so they share keys and ordering.
    node_feats[num_nodes].append(node_feat)
    adjacency_matrices[num_nodes].append(amat)

# Prepare data batched by graph size

In [7]:
# Collapse each per-size list into one array with a leading sample axis.
# Both dicts are driven by node_feats' keys, so they cover the same sizes.
node_feats_batched = {size: np.stack(mats) for size, mats in node_feats.items()}
adj_mats_batched = {size: np.stack(adjacency_matrices[size]) for size in node_feats}

In [8]:
def message_passing(feats, adjs):
    """Apply one adjacency-times-features product per bucket.

    Parameters
    ----------
    feats : dict
        Maps a key to a batch of node-feature matrices.
    adjs : dict
        Maps the same keys to the matching batch of adjacency matrices.

    Returns
    -------
    dict
        For every key in ``feats``, ``batch_matmul(adjs[key], feats[key])``.
    """
    return {size: batch_matmul(adjs[size], feats[size]) for size in feats}

output = message_passing(node_feats_batched, adj_mats_batched)

In [9]:
output[19].shape

(72, 19, 10)

In [12]:
num_nodes = 13
node_feats_batched[num_nodes].shape, adj_mats_batched[num_nodes].shape

((56, 13, 10), (56, 13, 13))

# Prepare data padded with zeros to largest graph size

In [13]:
largest_graph_size = 20  # node counts are drawn from 3..19, so no pad width goes negative

for size, feats in node_feats.items():
    # jax.numpy functions do not accept plain Python lists, so turn the list of
    # (size, 10) matrices into one (n_graphs, size, 10) array before padding.
    feats = np.asarray(feats)
    # Derive the pad width from the array itself rather than from the dict key:
    # re-running this cell then pads by zero instead of growing the arrays again.
    pad_size = largest_graph_size - feats.shape[1]
    node_feats[size] = np.pad(
        feats,
        pad_width=(
            # syntax is n_before, n_after
            (0, 0),  # sample dimension, do not touch
            (0, pad_size),  # node dimension, pad to largest size
            (0, 0),  # feats dimension, do not touch
        ),
    )

for size, adj in adjacency_matrices.items():
    # Same treatment for the adjacency matrices: (n_graphs, size, size) array,
    # pad width measured on the array so the cell stays idempotent.
    adj = np.asarray(adj)
    pad_size = largest_graph_size - adj.shape[1]
    adjacency_matrices[size] = np.pad(
        adj,
        pad_width=[
            # syntax is n_before, n_after
            (0, 0),  # sample dimension, do not touch
            (0, pad_size),  # node dimension (rows), pad to largest size
            (0, pad_size),  # node dimension (columns), pad to largest size
        ],
    )

In [16]:
node_feats[16].shape

(55, 20, 10)

Now, concatenate the padded arrays from every graph size along the sample axis (axis 0).

In [22]:
# Join the per-size arrays end to end along the sample axis. The two dicts
# were filled in lockstep, so iterating their values keeps features and
# adjacency matrices aligned graph by graph.
node_feats_stacked = np.concatenate(list(node_feats.values()), axis=0)
adj_mats_stacked = np.concatenate(list(adjacency_matrices.values()), axis=0)

In [29]:
# One adjacency-times-features product over the whole stacked dataset, then a
# reduction over axis 1 to get one feature vector per sample.
mp1 = batch_matmul(adj_mats_stacked, node_feats_stacked)
summation = mp1.sum(axis=1)
summation.shape  # n_samples by n_features

(1000, 10)

In [27]:
batch_matmul??

Signature: batch_matmul(lhs, rhs)
Source:   
def batch_matmul(lhs, rhs):
  """Batch matrix multiplication."""
  if _min(lhs.ndim, rhs.ndim) < 2:
    raise ValueError('Arguments to batch_matmul must be at least 2D, got {}, {}'
                     .format(lhs.ndim, rhs.ndim))
  if lhs.ndim != rhs.ndim:

# Test Equivalence

Make sure that message passing with padding is equivalent to message passing without padding.

In [45]:
# Tiny worked example: F holds the node features, A is the adjacency matrix
# (here the 2x2 identity) and M is the product A . F.
F = np.array([[1, 0],
              [1, 1]])
A = np.array([[1, 0],
              [0, 1]])

M = np.matmul(A, F)

In [46]:
M

DeviceArray([[1, 0],
             [1, 1]], dtype=int32)

In [41]:
# Append pad_size all-zero rows/columns to the toy example.
pad_size = 2
node_axis_pad = (0, pad_size)  # (n_before, n_after)

# Features: extra rows only; the feature columns are left alone.
F_pad = np.pad(F, pad_width=[node_axis_pad, (0, 0)])

# Adjacency: extra rows and extra columns, all zero.
A_pad = np.pad(A, pad_width=[node_axis_pad, node_axis_pad])

A_pad

DeviceArray([[1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]], dtype=int32)

In [50]:
A_pad

DeviceArray([[1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]], dtype=int32)

In [42]:
M_pad = np.dot(A_pad, F_pad)

In [51]:
M_pad

DeviceArray([[1, 0],
             [1, 1],
             [0, 0],
             [0, 0]], dtype=int32)

In [44]:
# HERE ARE THE TESTS OF EQUIVALENCE!
# 1. Zero-padding the unpadded result reproduces the padded result.
# 2. Dropping the padded rows from the padded result recovers the unpadded one.
# array_equal compares shapes as well as values, so an accidental broadcast
# cannot make a check pass; slicing by M's own row count (instead of
# [:-pad_size]) stays correct even when pad_size is 0.
assert np.array_equal(np.pad(M, [(0, pad_size), (0, 0)]), M_pad)
assert np.array_equal(M_pad[: M.shape[0], :], M)