In [16]:
from torch_geometric.datasets import Planetoid
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj
dataset = Planetoid("../datasets/Cora", "Cora")
data = dataset[0]

In [2]:
from gnn_180b.layer.sc_layer import SCLayer

In [3]:
sc = SCLayer([14, 14], "ReLU", 1433, 3, [14, 14])
clusters = sc(data.x, data.edge_index, data.edge_weight)[0]

In [24]:
class SparseAttention(nn.Module):
    def __init__(self, in_channels, out_channels, num_clusters, num_heads, concat=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_clusters = num_clusters
        self.num_heads = num_heads
        self.concat = concat

        self.weight = nn.Parameter(torch.Tensor(num_heads, in_channels, out_channels))
        self.bias = nn.Parameter(torch.Tensor(num_heads, out_channels))
        self.reset_parameters()

        self.cluster_weight = nn.Parameter(torch.Tensor(num_clusters, num_heads, out_channels))
        self.cluster_bias = nn.Parameter(torch.Tensor(num_heads, out_channels))
        self.reset_cluster_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def reset_cluster_parameters(self):
        nn.init.xavier_uniform_(self.cluster_weight)
        nn.init.zeros_(self.cluster_bias)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        print(x.shape, self.cluster_weight.shape)
        attn_coeffs = torch.einsum("ij,hjk->hik", x, self.cluster_weight) + self.cluster_bias
        attn_coeffs = F.softmax(attn_coeffs, dim=1)
        x = torch.einsum("ij,hjk->hik", x, attn_coeffs * self.weight) * self.bias
        x = F.relu(x)

        adj = to_dense_adj(edge_index)[0]

        deg = torch.sum(adj, dim=1, keepdim=True)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
        norm_adj = adj * deg_inv_sqrt * deg_inv_sqrt.t()

        output = []

        for i in range(self.num_heads):
            h = torch.mm(norm_adj, x[i])
            output.append(h)

        if self.concat:
            x = torch.cat(output, dim=-1)
        else:
            x = torch.stack(output, dim=0)

        return x

In [26]:
sp = SparseAttention(3, 16, 3, 3)
sp(clusters, data.edge_index)

torch.Size([2708, 3]) torch.Size([3, 3, 16])


RuntimeError: The size of tensor a (2708) must match the size of tensor b (3) at non-singleton dimension 1

In [21]:
clusters.shape

torch.Size([2708, 3])

In [23]:
data.edge_index.shape

torch.Size([2, 10556])