In [None]:
import torch
import torch.nn as nn

import torch.nn.functional as F


class VanillaCGN(nn.Module):
    def __init__(self, input_dim, node_dim, n_layers) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.node_dim = node_dim
        self.n_layers = n_layers
        self.U0 = nn.init.kaiming_normal_(torch.empty(node_dim, input_dim))
        self.b0 = nn.Parameter(torch.randn(node_dim))
        self.convLayers = [ConvNetLayer(self.node_dim) for _ in range(self.n_layers)]

    @staticmethod
    def build_adj_mat(edge_index):
        """Build the adjacency matrix of a graph from edge_index that looks like this: [[1, 4, 5, 6], [2, 3, 3, 5]]
        cf. https://huggingface.co/datasets/graphs-datasets/ZINC for more details"""
        n_nodes = max(max(edge_index[0]), max(edge_index[1]))
        n_edges = len(edge_index[0])
        adj_mat = torch.zeros((n_nodes, n_nodes))
        for i in range(n_edges):
            adj_mat[edge_index[0][i], edge_index[1][i]] = 1
        return adj_mat

    def forward(self, x):
        x = x @ self.U0 + self.b0  # self.b0 is broadcasted properly?
        for i in range(self.n_layers):
            x = self.convLayers[i](x)
        return x


class ConvNetLayer(nn.Module):
    def __init__(self, node_dim) -> None:
        super().__init__()
        self.U = nn.Parameter(torch.randn(node_dim, node_dim))

    def forward(self, x):
        new_x = torch.empty_like(x)
        for i in range(x.shape[0]):
            new_x[i, :] = F.relu(
                self.U @ (x[mask_i, :].sum(dim=0)).to(torch.float32) / deg_i
            )
        return new_x

In [14]:
adj_mat = torch.tensor([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
X = torch.tensor([[5, 2], [3, 4], [10, 20]], dtype=torch.float32)
layer = ConvNetLayer(node_dim=2, adj_mat=adj_mat)

X = layer(X)
print(X)

tensor([[1.9000, 0.0000],
        [4.2847, 0.0000],
        [1.9000, 0.0000]], grad_fn=<CopySlices>)


In [24]:
model = VanillaCGN(input_dim=2, node_dim=2, n_layers=2, adj_mat=adj_mat)
Y = torch.rand((3, 2))
Y = model(Y)
print(Y)

tensor([[0.8336, 4.5861],
        [0.7860, 4.3481],
        [0.8336, 4.5861]], grad_fn=<CopySlices>)


In [None]:
@staticmethod
def build_adj_mat(edge_index):
    """Build the adjacency matrix of a graph from edge_index that looks like this: [[1, 4, 5, 6], [2, 3, 3, 5]]
    cf. https://huggingface.co/datasets/graphs-datasets/ZINC for more details"""
    n_nodes = max(max(edge_index[0]), max(edge_index[1])) + 1 # +1 because the nodes are 0-indexed
    n_edges = len(edge_index[0])
    adj_mat = torch.zeros((n_nodes, n_nodes))
    for i in range(n_edges):
        adj_mat[edge_index[0][i], edge_index[1][i]] = 1
    return adj_mat

edge_index = [ [ 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12, 13, 14, 15 ], [ 1, 0, 2, 10, 1, 3, 2, 4, 3, 5, 4, 6, 9, 5, 7, 6, 8, 7, 9, 5, 8, 1, 11, 10, 12, 11, 13, 14, 15, 12, 12, 12 ] ]
adj_mat = build_adj_mat(edge_index)
print(adj_mat)

IndexError: index 15 is out of bounds for dimension 1 with size 15