In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from data_preparation import load_data_tensor


In [43]:
# Load the 'dgl-icl' dataset through the project helper. The call must return
# exactly three objects, which are unpacked into the names below.
# NOTE(review): the names suggest low-resolution train/test and high-resolution
# train tensors -- confirm against data_preparation.load_data_tensor.
lr_train, lr_test, hr_train = load_data_tensor('dgl-icl')

# Sheaf Neural Diffusion

In [15]:
# ignore this for now
class Laplacian_MLP(nn.Module):
    """Single linear layer + ReLU applied to a pair of concatenated node features.

    Args:
        d: Feature size of one node. The layer maps the ``2 * d`` concatenated
            features of a node pair back down to ``d``.
    """

    def __init__(self, d):
        super().__init__()
        self.layer = nn.Linear(2*d, d)

    def forward(self, own_node, other_node):
        """Combine the features of two nodes.

        Args:
            own_node: Tensor whose last dimension has size ``d``.
            other_node: Tensor with the same shape as ``own_node``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a last
            dimension of size ``d``; values are non-negative because of the ReLU.
        """
        # Concatenate along the feature (last) axis. For 1-D inputs this is the
        # same as the default dim=0, but it also keeps batched (..., d) inputs
        # compatible with the Linear(2*d, d) layer instead of stacking rows.
        x = torch.cat((own_node, other_node), dim=-1)
        return F.relu(self.layer(x))


In [71]:
class SheafConvLayer(nn.Module):
    """One diffusion step ``X - relu(L @ kron(I, W1) @ X @ W2)``.

    ``L`` is the block-diagonal matrix built by :meth:`sheaf_laplacian` from
    the node features and the adjacency matrix.

    Args:
        n_nodes: Number of graph nodes.
        d: Number of rows each node occupies in ``X`` (size of one diagonal
            block of ``L``).
        f: Number of feature columns of ``X``.

    Parameters:
        weight1: ``(d, d)`` matrix applied to every node block via a Kronecker
            product with the identity.
        weight2: ``(f, f)`` matrix mixing the feature columns.
        edge_weights: ``(d * n_nodes, 2 * d * n_nodes)`` matrix, read as an
            ``n_nodes x n_nodes`` grid of ``(d, 2d)`` blocks; block ``(v, u)``
            is the linear map applied to the stacked features of ``v`` and ``u``.
    """

    def __init__(self, n_nodes, d, f):
        super().__init__()
        self.d = d
        self.n_nodes = n_nodes
        # random init weight matrices
        self.weight1 = nn.Parameter(torch.randn((d, d)))
        self.weight2 = nn.Parameter(torch.randn((f, f)))
        self.edge_weights = nn.Parameter(torch.randn((d*n_nodes, 2*d*n_nodes)))

    def forward(self, X, adj):
        """Apply one diffusion step.

        Args:
            X: Node features of shape ``(n_nodes * d, f)``; rows
                ``v*d:(v+1)*d`` belong to node ``v``.
            adj: ``(n_nodes, n_nodes)`` adjacency matrix.

        Returns:
            Tensor with the same shape as ``X``.
        """
        # Build the identity next to the parameters so the layer keeps working
        # after ``.to(device)`` / ``.double()``.
        identity = torch.eye(
            self.n_nodes, device=self.weight1.device, dtype=self.weight1.dtype
        )
        kron_prod = torch.kron(identity, self.weight1)
        L = self.sheaf_laplacian(X, adj)
        return X - F.relu(L @ kron_prod @ X @ self.weight2)

    def sheaf_laplacian(self, X, adj):
        """Build the block-diagonal matrix ``L`` used by :meth:`forward`.

        For every node ``v`` the diagonal block is::

            L_v = (1 / deg(v)) * sum_u adj[v, u] * T_vu @ T_vu.T
            T_vu = relu(W_vu @ concat(X_v, X_u))

        where ``W_vu`` is the ``(d, 2d)`` block ``(v, u)`` of ``edge_weights``
        and ``deg(v)`` is the sum of row ``v`` of ``adj``. Rows whose degree is
        zero are left as an all-zero block instead of producing NaNs.

        Args:
            X: Node features of shape ``(n_nodes * d, f)``.
            adj: ``(n_nodes, n_nodes)`` adjacency matrix.

        Returns:
            ``(n_nodes * d, n_nodes * d)`` block-diagonal tensor.
        """
        n, d = self.n_nodes, self.d
        adj = adj.to(device=self.edge_weights.device, dtype=self.edge_weights.dtype)

        # (n*d, n*2d) -> (n, n, d, 2d) so that blocks[v, u] is the (d, 2d) map
        # of the ordered pair (v, u). Indexing edge_weights[v][u] would only
        # pick a single scalar, which cannot be matrix-multiplied.
        blocks = self.edge_weights.reshape(n, d, n, 2 * d).permute(0, 2, 1, 3)

        # (n*d, f) -> (n, d, f), then pair every node v with every node u.
        X_nodes = X.reshape(n, d, -1)
        own = X_nodes[:, None].expand(-1, n, -1, -1)    # [v, u] -> X_v
        other = X_nodes[None, :].expand(n, -1, -1, -1)  # [v, u] -> X_u
        stacked = torch.cat((own, other), dim=2)        # (n, n, 2d, f)

        transformed = F.relu(torch.matmul(blocks, stacked))   # (n, n, d, f)
        outer = transformed @ transformed.transpose(-1, -2)   # (n, n, d, d)

        # Adjacency-weighted sum over the neighbours u of each node v.
        weighted = (adj[:, :, None, None] * outer).sum(dim=1)  # (n, d, d)

        # Normalise by the row sum; isolated nodes (degree 0) divide by 1 so
        # their zero block stays zero rather than becoming 0/0 = NaN.
        degree = adj.sum(dim=1)
        safe_degree = torch.where(degree == 0, torch.ones_like(degree), degree)
        per_node = weighted / safe_degree.view(n, 1, 1)

        return torch.block_diag(*per_node.unbind(0))


In [82]:
# Total number of trainable scalars in a layer with 168 nodes, d=1 and f=5:
# weight1 (1*1) + weight2 (5*5) + edge_weights (168 * 336).
sum(param.numel() for param in SheafConvLayer(n_nodes=168, d=1, f=5).parameters())


56474