In [1]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import NNConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable
import networkx as nx

import os.path as osp
from scipy.linalg import sqrtm
from scipy.stats import wasserstein_distance
from torch.distributions import normal, kl


import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE, InnerProductDecoder, ARGVA
from torch_geometric.utils import train_test_split_edges
from sklearn.model_selection import KFold
import pandas as pd


In [2]:
from data_preparation import load_data_tensor

# Load the three tensors returned for the "dgl-icl" dataset key.
# NOTE(review): presumably "lr"/"hr" mean low-/high-resolution graphs - confirm
# against data_preparation.load_data_tensor.
lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

# ---- global configuration -------------------------------------------------
N_SUBJECTS = 167

N_LR_NODES = 160

N_HR_NODES = 268

EPOCHS = 10

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# n * (n - 1) / 2 for each node count; the product of two consecutive integers
# is always even, so floor division is exact.
N_LR_NODES_F = N_LR_NODES * (N_LR_NODES - 1) // 2
N_HR_NODES_F = N_HR_NODES * (N_HR_NODES - 1) // 2


In [3]:
class SheafConvLayer(nn.Module):
    """Graph convolution driven by a learned block-diagonal "sheaf Laplacian".

    Every node owns ``d`` consecutive rows of the feature matrix. The layer
    builds one ``d x d`` block per node from learned per-edge maps, assembles
    the blocks into a block-diagonal matrix ``L`` and computes
    ``relu(L @ kron(I_n, weight1) @ X @ weight2)``.

    Args:
        n_nodes: Number of graph nodes.
        d: Number of feature rows owned by each node.
        f_in: Number of input feature columns.
        f_out: Number of output feature columns. When ``None`` the layer keeps
            ``f_in`` columns and returns the residual form ``X - relu(...)``.
        device: Device on which the parameters are created. Defaults to the
            module-level ``DEVICE`` when omitted.
    """

    def __init__(self, n_nodes, d, f_in, f_out=None, device=None):
        super().__init__()
        self.d = d
        self.n_nodes = n_nodes
        # Store the caller's raw value: ``None`` selects the residual branch in forward().
        self.f_out = f_out
        if f_out is None:
            f_out = f_in
        if device is None:
            device = DEVICE
        # Wrap the tensor that already lives on the target device.
        # ``nn.Parameter(t).to(device)`` returns a plain non-leaf tensor whenever a
        # copy is made (e.g. CPU -> CUDA), so the weights were silently left out of
        # ``parameters()`` / ``state_dict()`` and were never seen by an optimizer.
        self.weight1 = nn.Parameter(torch.randn((d, d)).to(device))
        self.weight2 = nn.Parameter(torch.randn((f_in, f_out)).to(device))
        self.edge_weights = nn.Parameter(
            torch.randn((d * n_nodes, 2 * d * n_nodes)).to(device)
        )

    def forward(self, X, adj):
        """Apply the layer.

        Args:
            X: Feature matrix of shape ``(d * n_nodes, f_in)``.
            adj: ``(n_nodes, n_nodes)`` matrix of edge weights, indexed ``adj[v, u]``.

        Returns:
            ``(d * n_nodes, f_in)`` tensor ``X - relu(...)`` when the layer was built
            with ``f_out=None``; otherwise a ``(d * n_nodes, f_out)`` tensor ``relu(...)``.
        """
        # Follow the weights' device/dtype so the layer keeps working after ``.to(...)``.
        identity = torch.eye(
            self.n_nodes, device=self.weight1.device, dtype=self.weight1.dtype
        )
        kron_prod = torch.kron(identity, self.weight1)
        L = self.sheaf_laplacian(X, adj)
        diffused = F.relu(L @ kron_prod @ X @ self.weight2)
        if self.f_out is None:
            return X - diffused
        return diffused

    def sheaf_laplacian(self, X, adj):
        """Build the ``(d * n_nodes, d * n_nodes)`` block-diagonal operator.

        For node ``v`` the block is
        ``sum_u adj[v, u] * M_vu @ M_vu.T / sum(adj[v])`` where
        ``M_vu = relu(W_vu @ [X_v; X_u])`` and ``W_vu`` is the ``(d, 2d)`` slice of
        ``edge_weights`` belonging to the pair ``(v, u)``.
        """
        d = self.d
        blocks = []
        for v in range(self.n_nodes):
            x_v = X[v * d:(v + 1) * d]  # loop-invariant for the inner loop
            L_v = torch.zeros((d, d), device=X.device, dtype=X.dtype)
            for u in range(self.n_nodes):
                edge_weight = self.edge_weights[v * d:(v + 1) * d, u * 2 * d:(u + 1) * 2 * d]
                stacked_features = torch.cat((x_v, X[u * d:(u + 1) * d]))
                lin_trans = F.relu(edge_weight @ stacked_features)
                L_v = L_v + adj[v, u] * lin_trans @ lin_trans.T
            row_sum = torch.sum(adj[v])
            # A row summing to exactly zero would yield 0/0 = NaN and poison the whole
            # forward pass; divide by one in that case instead.
            row_sum = torch.where(row_sum == 0, torch.ones_like(row_sum), row_sum)
            blocks.append(L_v / row_sum)
        return torch.block_diag(*blocks)


In [4]:
class SheafAligner(nn.Module):
    """Three stacked residual sheaf-convolution stages over the LR graph.

    Each stage is ``SheafConvLayer -> BatchNorm -> sigmoid``; the first two
    stages are followed by dropout (default ``p``), the last one is not.
    The feature width stays at 16 columns throughout.
    """

    def __init__(self, d):
        super().__init__()
        width = 16

        # Attribute names and construction order are kept so that state_dict
        # keys and the random-initialisation sequence are unchanged.
        self.sheafconv1 = SheafConvLayer(N_LR_NODES, d, width)
        self.batchnorm1 = BatchNorm(width)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, d, width)
        self.batchnorm2 = BatchNorm(width)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, d, width)
        self.batchnorm3 = BatchNorm(width)

    def forward(self, X, adj):
        """Run ``X`` through the three stages using ``adj`` as the graph."""
        stages = (
            (self.sheafconv1, self.batchnorm1, True),
            (self.sheafconv2, self.batchnorm2, True),
            (self.sheafconv3, self.batchnorm3, False),
        )
        hidden = X
        for conv, norm, with_dropout in stages:
            hidden = torch.sigmoid(norm(conv(hidden, adj)))
            if with_dropout:
                hidden = F.dropout(hidden, training=self.training)
        return hidden

        

In [5]:
class SheafGenerator(nn.Module):
    """Map LR node features and an LR adjacency to a symmetric HR-sized matrix.

    Two residual sheaf-convolution stages (16 columns) are followed by a third
    stage that widens the features to ``N_HR_NODES`` columns. The result is
    combined with the LR adjacency through the learned ``out_mat`` to produce
    an ``(N_HR_NODES, N_HR_NODES)`` matrix.

    Args:
        d: Number of feature rows per node; inputs have ``d * N_LR_NODES`` rows.
    """

    def __init__(self, d):
        super().__init__()

        self.sheafconv1 = SheafConvLayer(N_LR_NODES, d, 16)
        self.batchnorm1 = BatchNorm(16)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, d, 16)
        self.batchnorm2 = BatchNorm(16)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, d, 16, N_HR_NODES)
        self.batchnorm3 = BatchNorm(N_HR_NODES)

        # Wrap the tensor that already lives on DEVICE: ``nn.Parameter(t).to(DEVICE)``
        # yields a plain non-leaf tensor on CUDA, which is not registered with the
        # module and therefore never optimised or saved.
        # The width is d * N_LR_NODES (it was hard-coded to 2 * N_LR_NODES, which
        # only matches the feature rows when d == 2).
        self.out_mat = nn.Parameter(
            torch.randn((N_LR_NODES, d * N_LR_NODES)).to(DEVICE)
        )

    def forward(self, X, adj):
        """Generate the HR-sized matrix.

        Args:
            X: ``(d * N_LR_NODES, 16)`` feature matrix.
            adj: ``(N_LR_NODES, N_LR_NODES)`` adjacency matrix.

        Returns:
            Symmetric ``(N_HR_NODES, N_HR_NODES)`` tensor with entries in (0, 1).
        """
        x1 = self.sheafconv1(X, adj)  # (d * lr_n, 16)
        x1 = torch.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, p=0.1, training=self.training)

        x2 = self.sheafconv2(x1, adj)  # (d * lr_n, 16)
        x2 = torch.sigmoid(self.batchnorm2(x2))
        x2 = F.dropout(x2, p=0.1, training=self.training)

        x3 = self.sheafconv3(x2, adj)  # (d * lr_n, hr_n)
        x3 = torch.sigmoid(self.batchnorm3(x3))
        # (hr_n, d*lr_n) @ (d*lr_n, lr_n) @ (lr_n, lr_n) @ (lr_n, d*lr_n) @ (d*lr_n, hr_n)
        x3 = torch.sigmoid(x3.T @ self.out_mat.T @ adj @ self.out_mat @ x3)

        return (x3 + x3.T) / 2  # to ensure the matrix is symmetric
 

In [6]:
class Discriminator(torch.nn.Module):
    """Two-layer GCN discriminator that emits one sigmoid score per node.

    ``forward`` expects an object exposing ``x`` (node features with
    ``N_HR_NODES`` columns) and ``pos_edge_index`` (the edge index passed to
    both ``GCNConv`` layers).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(N_HR_NODES, N_HR_NODES)
        self.conv2 = GCNConv(N_HR_NODES, 1)
        # Not used by forward(); kept so existing checkpoints still load.
        self.linear = torch.nn.Linear(N_HR_NODES, 1)

    def forward(self, data):
        """Return the per-node sigmoid output of the second GCN layer."""
        x, edge_index = data.x, data.pos_edge_index
        x1 = torch.sigmoid(self.conv1(x, edge_index))
        # F.dropout defaults to training=True; without this flag dropout stayed
        # active after model.eval(), making evaluation outputs stochastic.
        x1 = F.dropout(x1, p=0.1, training=self.training)
        x2 = torch.sigmoid(self.conv2(x1, edge_index))
        return x2


In [7]:
# NEEDS TO CHANGE AND ADAPT
class SheafDiscriminator(nn.Module):
    """Sheaf-convolution discriminator that reduces a graph to one sigmoid score.

    A residual sheaf layer (16 columns) is followed by a sheaf layer that
    projects to a single column; the flattened result is passed through a
    linear read-out.

    Args:
        d: Number of feature rows per node; inputs have ``d * N_HR_NODES`` rows.
    """

    def __init__(self, d):
        super().__init__()
        self.sheafconv1 = SheafConvLayer(N_HR_NODES, d, 16)

        self.sheafconv2 = SheafConvLayer(N_HR_NODES, d, 16, 1)
        # sheafconv2 yields (d * N_HR_NODES, 1); size the read-out from d rather
        # than the hard-coded 2 so the model also works for d != 2.
        self.out = torch.nn.Linear(d * N_HR_NODES, 1)

    def forward(self, X, adj):
        """Score the graph.

        Args:
            X: ``(d * N_HR_NODES, 16)`` feature matrix.
            adj: ``(N_HR_NODES, N_HR_NODES)`` adjacency matrix.

        Returns:
            Tensor of shape ``(1,)`` holding a value in (0, 1).
        """
        x1 = torch.sigmoid(self.sheafconv1(X, adj))
        x1 = F.dropout(x1, p=0.1, training=self.training)
        # Feed the first layer's activations forward. Previously ``X`` was passed
        # here, so sheafconv1 and the dropout above were computed and thrown away.
        x2 = torch.sigmoid(self.sheafconv2(x1, adj))
        x3 = torch.sigmoid(self.out(x2.flatten()))
        return x3

In [8]:
# Scratch 3-node graph object with random features, 9 random edges and one
# random attribute per edge.
k = Data(
    x=torch.randn((3, 3)),
    pos_edge_index=torch.randint(0, 3, (2, 9)),
    edge_attr=torch.randn((9, 1)),
)

In [9]:
# Random input features: 2 rows per LR node (the value later passed as `d`)
# and 16 columns (the layers' input width).
X = torch.randn(2 * N_LR_NODES, 16)
# Use the first training sample as the adjacency for the smoke test below.
adj = lr_train[0]

In [10]:
# Move the model to DEVICE (not a hard-coded 'cuda') so it lands on the same
# device as its inputs and the cell also runs on CPU-only machines.
aligner = SheafAligner(2).to(DEVICE)
aligned = aligner(X.to(DEVICE), adj.to(DEVICE)) # should return (n_lr * d) * f matrix

In [11]:
# Sanity check: expect (d * N_LR_NODES, f) = (2 * 160, 16) = (320, 16).
aligned.shape

torch.Size([320, 16])

In [12]:
# Move the model to DEVICE (not a hard-coded 'cuda') so it lands on the same
# device as its inputs and the cell also runs on CPU-only machines.
generator = SheafGenerator(2).to(DEVICE)
generated = generator(aligned, adj.to(DEVICE))

In [13]:
# Sanity check: expect a square (N_HR_NODES, N_HR_NODES) = (268, 268) matrix.
generated.shape

torch.Size([268, 268])

In [14]:
# Random HR node features: 2 rows per HR node (matching d=2 below), 16 columns.
Y = torch.randn((N_HR_NODES*2, 16))
# Move the model to DEVICE (not a hard-coded 'cuda') so it lands on the same
# device as its inputs and the cell also runs on CPU-only machines.
discriminator = SheafDiscriminator(2).to(DEVICE)
# The generated HR matrix is used as the adjacency for the discriminator.
dis_decision = discriminator(Y.to(DEVICE), generated.to(DEVICE))

In [15]:
# Display the discriminator's output: a single sigmoid-activated value.
dis_decision

tensor([0.4299], device='cuda:0', grad_fn=<SigmoidBackward0>)