In [1]:
import jax.numpy as jnp
import torch_geometric as tg
from torch_geometric.datasets import Planetoid

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the Planetoid "Cora" dataset, stored under the relative directory 'data',
# and keep its first (graph) element for the rest of the notebook.
cora_dataset = Planetoid(name='Cora', root='data')
data = cora_dataset[0]

In [3]:
# Convert the PyG graph tensors into JAX arrays.
node_features = jnp.array(data.x)
edge_index = jnp.array(data.edge_index)
node_labels = jnp.array(data.y)

# Dense (N, N) adjacency matrix.
# - max_num_nodes: without it, to_dense_adj infers N from the largest index in
#   edge_index, so isolated nodes at the end of the node list would be dropped
#   and the matrix would no longer line up with node_features.
# - [0] instead of .squeeze(): drop only the leading batch dimension, never a
#   node dimension (e.g. for a 1-node graph).
adjacency_matrix = jnp.array(
    tg.utils.to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0]
)

In [4]:
# Quick summary of the graph: node count, feature width, edge count and the
# dense adjacency shape (should be (num_nodes, num_nodes)).
summary_lines = [
    f'Num nodes: {node_features.shape[0]}',
    f'Feature dimension: {node_features.shape[1]}',
    f'Num edges: {edge_index.shape[1]}',
    f'Shape adjacency matrix: {adjacency_matrix.shape}',
]
for summary_line in summary_lines:
    print(summary_line)


Num nodes: 2708
Feature dimension: 1433
Num edges: 10556
Shape adjacency matrix: (2708, 2708)


In [5]:
# Diagonal entry for node 0: a value of 0 means the raw adjacency has no
# self-loop there (the next cell adds the identity to introduce them).
adjacency_matrix[0][0]

DeviceArray(0., dtype=float32)

In [6]:
# Adjacency with self-loops: every node is treated as connected to itself.
# Reuses the dense adjacency built earlier instead of recomputing to_dense_adj.
adj_matrix = adjacency_matrix + jnp.identity(adjacency_matrix.shape[0])

# Additive mask: 0 for connected pairs (incl. self), -inf for non-connected ones.
# Presumably added to attention logits before a softmax — confirm in later cells.
# Must use jnp.where: `(adj_matrix == 0) * -jnp.inf` evaluates `0 * -inf`, which
# is NaN under IEEE 754, so every connected pair would become NaN instead of 0.
connectivity_mask = jnp.where(adj_matrix == 0, -jnp.inf, 0.0)

connectivity_mask.shape

(2708, 2708)

In [24]:
# Row-normalise the node features (D^-1 X): divide each node's feature vector
# by its own sum so that every non-empty row sums to 1.
nodes_features_sum = node_features.sum(-1)

# An all-zero feature row would give 1/0 = inf and then inf * 0 = NaN for that
# whole row. Give such rows an inverse of 0 so they simply stay all-zero.
# The inner where keeps the division itself free of inf.
has_features = nodes_features_sum != 0
nodes_features_sum_inv = jnp.where(
    has_features, 1.0 / jnp.where(has_features, nodes_features_sum, 1.0), 0.0
)

# Explicit diagonal matrix D^-1, kept in case later cells refer to it.
diagonal_sum_inv = jnp.diag(nodes_features_sum_inv)

# Same result as diagonal_sum_inv.dot(node_features), computed by broadcasting
# (scale row i by nodes_features_sum_inv[i]) rather than a dense (N, N) @ (N, F) matmul.
nodes_features_normalized = nodes_features_sum_inv[:, None] * node_features

In [27]:
# Sanity check: per-node sums of the normalised features; expected to be ~1
# for every node whose feature vector is not empty.
jnp.sum(nodes_features_normalized, axis=1)

DeviceArray([1.       , 1.       , 1.0000001, ..., 1.0000001, 1.0000001,
             1.       ], dtype=float32)