In [188]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import NNConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable
import networkx as nx

import os.path as osp
import pickle
from scipy.linalg import sqrtm
import argparse
from scipy.stats import wasserstein_distance
from torch.distributions import normal, kl


import argparse
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE, InnerProductDecoder, ARGVA
from torch_geometric.utils import train_test_split_edges
import matplotlib.pyplot as plt
import warnings
from sklearn.model_selection import KFold
import pandas as pd


In [189]:
from data_preparation import load_data_tensor

lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

# set global variables
N_SUBJECTS = 167

N_LR_NODES = 160

N_HR_NODES = 268

EPOCHS = 10

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

N_LR_NODES_F = int(N_LR_NODES * (N_LR_NODES-1) / 2)
N_HR_NODES_F = int(N_HR_NODES * (N_HR_NODES-1) / 2)


In [195]:
class SheafConvLayer(nn.Module):
    def __init__(self, n_nodes, d, f_in, f_out=None):
        super().__init__()
        self.d = d
        self.n_nodes = n_nodes
        self.f_out = f_out
        # random init weight matrices
        if f_out is None:
            f_out = f_in 
        self.weight1 = nn.Parameter(torch.randn((d, d))).to(DEVICE)
        self.weight2 = nn.Parameter(torch.randn((f_in, f_out))).to(DEVICE)
        self.edge_weights = nn.Parameter(torch.randn((d*n_nodes,2*d*n_nodes))).to(DEVICE)


    def forward(self, X, adj):
        kron_prod = torch.kron(torch.eye(self.n_nodes).to(DEVICE), self.weight1)
        L = self.sheaf_laplacian(X, adj)
        if self.f_out is None:
            return X - F.relu(L @ kron_prod @ X @ self.weight2) 
        else:
            return F.relu(L @ kron_prod @ X @ self.weight2) 


    def sheaf_laplacian(self, X, adj):
        laplacian_ls = []
        for v in range(self.n_nodes):
            L_v = torch.zeros((self.d, self.d)).to(DEVICE)
            for u in range(self.n_nodes):
                edge_weight = self.edge_weights[v*self.d:(v+1)*self.d, u*2*self.d:(u+1)*2*self.d]
                stacked_features = torch.concat((X[v*self.d:(v+1)*self.d], X[u*self.d:(u+1)*self.d]))
                lin_trans = F.relu(edge_weight @ stacked_features).to(DEVICE)
                L_v += adj[v, u] * lin_trans @ lin_trans.T
            laplacian_ls.append(L_v / torch.sum(adj[v]))
        return torch.block_diag(*laplacian_ls)


In [218]:
class SheafAligner(nn.Module):
    
    def __init__(self):
        super().__init__()

        self.sheafconv1 = SheafConvLayer(N_LR_NODES, 2, 16)
        self.batchnorm1 = BatchNorm(16)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, 2, 16)
        self.batchnorm2 = BatchNorm(16)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, 2, 16)
        self.batchnorm3 = BatchNorm(16)

    def forward(self, X, adj):

        x1 = self.sheafconv1(X, adj)
        x1 = F.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, training=self.training)

        x2 = self.sheafconv2(x1, adj)
        x2 = F.sigmoid(self.batchnorm2(x2))
        x2 = F.dropout(x2, training=self.training)

        x3 = self.sheafconv3(x2, adj)
        x3 = F.sigmoid(self.batchnorm3(x3))
        x3 = torch.cat([x3, x1], dim=1)

        x4 = x3[:, 0:16]
        x5 = x3[:, 16:2*16]


        return (x4 + x5) / 2

        

In [219]:
class SheafGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.sheafconv1 = SheafConvLayer(N_LR_NODES, 2, 16)
        self.batchnorm1 = BatchNorm(16)

        self.sheafconv2 = SheafConvLayer(N_LR_NODES, 2, 16)
        self.batchnorm2 = BatchNorm(16)

        self.sheafconv3 = SheafConvLayer(N_LR_NODES, 2, 16, N_HR_NODES)
        self.batchnorm3 = BatchNorm(N_HR_NODES)

        self.out_mat = nn.Parameter(torch.randn((N_LR_NODES, 2*N_LR_NODES))).to(DEVICE)
        
        


    def forward(self, X, adj):
        x1 = self.sheafconv1(X, adj)
        x1 = F.sigmoid(self.batchnorm1(x1))
        x1 = F.dropout(x1, p=0.1, training=self.training)

        x2 = self.sheafconv2(x1, adj)
        x2 = F.sigmoid(self.batchnorm2(x2))
        x2 = F.dropout(x2, p=0.1, training=self.training)

        x3 = self.sheafconv3(x2, adj)
        x3 = F.sigmoid(self.batchnorm3(x3))
        
        return F.sigmoid(x3.T @  self.out_mat.T @ adj @ self.out_mat @ x3)
 

In [220]:
# NEEDS TO CHANGE AND ADAPT
class SheafDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.sheafconv1 = SheafConvLayer(N_HR_NODES, 2, N_HR_NODES)

        self.sheafconv2 = SheafConvLayer(N_HR_NODES, 2, N_HR_NODES, 1)
        self.out = torch.nn.Linear(N_HR_NODES, 1)

    def forward(self, X, adj):
        x1 = F.sigmoid(self.sheafconv1(X, adj))
        x1 = F.dropout(x1, p=0.1, training=self.training)
        x2 = F.sigmoid(self.sheafconv2(X, adj))
        x3 = F.sigmoid(self.out(x2))
        return x3

In [112]:
k = Data(x=torch.randn((3, 3)), pos_edge_index=torch.randint(0, 3, (2, 9)), edge_attr=torch.randn((9, 1)))

In [115]:
x = Their_Generator()
x(k).shape

torch.Size([1, 10, 10])

In [199]:
X = torch.randn((2*N_LR_NODES, 16))
adj = lr_train[0]

In [221]:
aligner = SheafAligner().to('cuda')
aligned = aligner(X.to(DEVICE), adj.to(DEVICE))

In [212]:
aligned.shape

torch.Size([160, 32])

In [222]:
generator = SheafGenerator().to('cuda')
generator(aligned, adj.to(DEVICE))

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0',
       grad_fn=<SigmoidBackward0>)

In [184]:
hr_adj = hr_train[0]
Y = torch.randn((N_HR_NODES, N_HR_NODES))
discriminator = SheafDiscriminator().to('cuda')
discriminator(Y.to(DEVICE), hr_adj.to(DEVICE))

KeyboardInterrupt: 