In [2]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import NNConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable
import networkx as nx

import os.path as osp
import pickle
from scipy.linalg import sqrtm
import argparse
from scipy.stats import wasserstein_distance
from torch.distributions import normal, kl


import argparse
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE, InnerProductDecoder, ARGVA
from torch_geometric.utils import train_test_split_edges
import matplotlib.pyplot as plt
import warnings
from sklearn.model_selection import KFold


In [3]:
# Project-local loader; its implementation lives outside this notebook.
from data_preparation import load_data_tensor

# Unpack the three objects returned by the loader. Judging by the names these are the
# low-resolution train/test sets and the high-resolution train set -- TODO confirm
# against data_preparation.py.
lr_train, lr_test, hr_train = load_data_tensor()

In [11]:
# Print tensors with 8 decimal digits (this is a global PyTorch print setting).
torch.set_printoptions(precision=8)
# Sanity check of the loaded data; the recorded output shows 167 matrices of
# 160x160 (lr_train) and 167 matrices of 268x268 (hr_train).
print(lr_train.shape)
print(hr_train.shape)

torch.Size([167, 160, 160])
torch.Size([167, 268, 268])


In [None]:
# Number of subjects in the simulated data
N_SUBJECTS = 50

# Number of ROIs (nodes) in the source brain graph for the simulated data
N_SOURCE_NODES = 160

# Number of ROIs (nodes) in the target brain graph for the simulated data
N_TARGET_NODES = 268

# Number of training epochs
N_EPOCHS = 100


####** DO NOT MODIFY BELOW **####
# Derived sizes: n*(n-1)/2 for each graph, i.e. the number of unordered node pairs.
N_SOURCE_NODES_F =int((N_SOURCE_NODES*(N_SOURCE_NODES-1))/2)
N_TARGET_NODES_F =int((N_TARGET_NODES*(N_TARGET_NODES-1))/2)
###**************************####

In [None]:
class Aligner(torch.nn.Module):
    """Edge-conditioned graph network that keeps the source feature width.

    Three ``NNConv`` layers, each followed by ``BatchNorm`` and a sigmoid, take the
    node-feature width through N_SOURCE_NODES, N_SOURCE_NODES, 1 and back to
    N_SOURCE_NODES. The filter of every ``NNConv`` is generated from the
    1-dimensional edge attribute by a small ``Linear`` + ``ReLU`` network.
    """

    def __init__(self):
        super().__init__()

        width = N_SOURCE_NODES
        conv_opts = dict(aggr='mean', root_weight=True, bias=True)
        norm_opts = dict(eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

        # Layer 1: width to width (edge network emits a width*width filter).
        edge_net_1 = Sequential(Linear(1, width * width), ReLU())
        self.conv1 = NNConv(width, width, edge_net_1, **conv_opts)
        self.conv11 = BatchNorm(width, **norm_opts)

        # Layer 2: width to 1.
        edge_net_2 = Sequential(Linear(1, width), ReLU())
        self.conv2 = NNConv(width, 1, edge_net_2, **conv_opts)
        self.conv22 = BatchNorm(1, **norm_opts)

        # Layer 3: 1 back to width.
        edge_net_3 = Sequential(Linear(1, width), ReLU())
        self.conv3 = NNConv(1, width, edge_net_3, **conv_opts)
        self.conv33 = BatchNorm(width, **norm_opts)

    def forward(self, data):
        """Return the mean of the layer-3 output and the (dropped-out) layer-1 output.

        ``data`` must expose ``x``, ``pos_edge_index`` and ``edge_attr``.
        The result has N_SOURCE_NODES columns.
        """
        feats, edges, edge_feats = data.x, data.pos_edge_index, data.edge_attr

        first = F.sigmoid(self.conv11(self.conv1(feats, edges, edge_feats)))
        first = F.dropout(first, training=self.training)

        bottleneck = F.sigmoid(self.conv22(self.conv2(first, edges, edge_feats)))
        bottleneck = F.dropout(bottleneck, training=self.training)

        third = F.sigmoid(self.conv33(self.conv3(bottleneck, edges, edge_feats)))

        # Skip connection: both tensors have N_SOURCE_NODES columns, so averaging them
        # directly equals concatenating, slicing the two halves and averaging those.
        return (third + first) / 2








class Generator(nn.Module):
    """Edge-conditioned graph network that emits a target-sized square matrix.

    ``forward`` applies two ``NNConv`` + ``BatchNorm`` + sigmoid stages
    (N_SOURCE_NODES to N_SOURCE_NODES, then N_SOURCE_NODES to N_TARGET_NODES) and
    returns the Gram matrix of the resulting node embedding.

    ``conv2`` / ``conv22`` are constructed but not used by ``forward``; they still
    contribute parameters to ``state_dict()``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        conv_opts = dict(aggr='mean', root_weight=True, bias=True)
        norm_opts = dict(eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

        edge_net_1 = Sequential(Linear(1, N_SOURCE_NODES * N_SOURCE_NODES), ReLU())
        self.conv1 = NNConv(N_SOURCE_NODES, N_SOURCE_NODES, edge_net_1, **conv_opts)
        self.conv11 = BatchNorm(N_SOURCE_NODES, **norm_opts)

        # Not referenced in forward().
        edge_net_2 = Sequential(Linear(1, N_TARGET_NODES * N_SOURCE_NODES), ReLU())
        self.conv2 = NNConv(N_TARGET_NODES, N_SOURCE_NODES, edge_net_2, **conv_opts)
        self.conv22 = BatchNorm(N_SOURCE_NODES, **norm_opts)

        edge_net_3 = Sequential(Linear(1, N_TARGET_NODES * N_SOURCE_NODES), ReLU())
        self.conv3 = NNConv(N_SOURCE_NODES, N_TARGET_NODES, edge_net_3, **conv_opts)
        self.conv33 = BatchNorm(N_TARGET_NODES, **norm_opts)

    def forward(self, data):
        """Return ``E.T @ E`` where ``E`` is the final node embedding.

        ``data`` must expose ``x``, ``pos_edge_index`` and ``edge_attr``. Because the
        embedding has N_TARGET_NODES columns, the result is an
        N_TARGET_NODES x N_TARGET_NODES symmetric matrix.
        """
        feats, edges, edge_feats = data.x, data.pos_edge_index, data.edge_attr

        hidden = F.sigmoid(self.conv11(self.conv1(feats, edges, edge_feats)))
        hidden = F.dropout(hidden, training=self.training)

        embedding = F.sigmoid(self.conv33(self.conv3(hidden, edges, edge_feats)))
        embedding = F.dropout(embedding, training=self.training)

        return torch.matmul(embedding.t(), embedding)

class Discriminator(torch.nn.Module):
    """Two-layer GCN that scores every node of a target-sized graph.

    The layers map N_TARGET_NODES features to N_TARGET_NODES and then to 1, each
    followed by a sigmoid, so the output holds one value in (0, 1) per node.
    Only ``x`` and ``pos_edge_index`` are used; edge attributes are not passed to
    the convolutions.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # cached=False: GCNConv's cache stores the normalized adjacency of the first
        # graph it sees, which is only valid when the model is applied to one fixed
        # graph. This discriminator is called on many different graphs (real and
        # generated), so the normalization must be recomputed per call.
        self.conv1 = GCNConv(N_TARGET_NODES, N_TARGET_NODES, cached=False)
        self.conv2 = GCNConv(N_TARGET_NODES, 1, cached=False)

    def forward(self, data):
        """Return the per-node score tensor for ``data`` (needs ``x``, ``pos_edge_index``)."""
        x, edge_index = data.x, data.pos_edge_index
        # Drop singleton dimensions that some callers leave on the feature matrix.
        x = torch.squeeze(x)
        x1 = F.sigmoid(self.conv1(x, edge_index))
        x1 = F.dropout(x1, training=self.training)
        x2 = F.sigmoid(self.conv2(x1, edge_index))

        return x2


In [None]:
# Centrality-based topological measures computed from a dense connectivity matrix


def topological_measures(data):
    """Compute node-centrality vectors for a graph given as a dense square matrix.

    Parameters
    ----------
    data : numpy.ndarray
        Square connectivity matrix. NOTE: its diagonal is set to 0 **in place**, so
        the caller's array is modified.

    Returns
    -------
    list of numpy.ndarray
        ``[closeness, betweenness, eigenvector]``, one value per node, in node order.
        Index 1 currently holds a copy of the closeness values (see note below).
    """
    # Remove self-connections (in place, as callers have always observed).
    np.fill_diagonal(data, 0)

    # Build an undirected weighted graph from the absolute connectivity values.
    # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is its
    # replacement and already returns an undirected nx.Graph, so no extra
    # to_undirected() copy is needed.
    graph = nx.from_numpy_array(np.absolute(data))

    # Closeness centrality, using the edge weight as the distance.
    closeness = nx.closeness_centrality(graph, distance="weight")
    closeness_centrality = np.array([closeness[node] for node in graph])

    # NOTE(review): the betweenness computation is disabled in this notebook; slot 1
    # is filled with the closeness values so the indices of the returned list stay
    # stable. Do not interpret index 1 as real betweenness centrality.
    betweenness_centrality = closeness_centrality.copy()

    # NOTE(review): called without `weight`, so NetworkX treats every edge as equally
    # weighted here -- confirm whether weighted centrality was intended.
    eigen = nx.eigenvector_centrality_numpy(graph)
    eigenvector_centrality = np.array([eigen[node] for node in graph])

    return [
        closeness_centrality,    # 0
        betweenness_centrality,  # 1
        eigenvector_centrality,  # 2
    ]
# Eigenvector centrality only: the single topological measure used by GT_loss

def eigen_centrality(data):
    """Compute the eigenvector centrality of a graph given as a dense square matrix.

    Parameters
    ----------
    data : numpy.ndarray
        Square connectivity matrix. NOTE: its diagonal is set to 0 **in place**, so
        the caller's array is modified.

    Returns
    -------
    list of numpy.ndarray
        A one-element list holding the per-node eigenvector centrality, in node order.
    """
    # Remove self-connections (in place, as callers have always observed).
    np.fill_diagonal(data, 0)

    # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is its
    # replacement and already returns an undirected nx.Graph.
    graph = nx.from_numpy_array(np.absolute(data))

    # NOTE(review): called without `weight`, so NetworkX treats every edge as equally
    # weighted here -- confirm whether weighted centrality was intended.
    centrality = nx.eigenvector_centrality_numpy(graph)
    eigenvector_centrality = np.array([centrality[node] for node in graph])

    return [eigenvector_centrality]

In [None]:
# Select the compute device once and report which one is in use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("running on GPU" if device.type == "cuda" else "running on CPU")

# Shared loss modules, placed on the selected device.
l1_loss = torch.nn.L1Loss().to(device)
adversarial_loss = torch.nn.BCELoss().to(device)


def pearson_coor(input, target):
    """Return the Pearson correlation between two tensors, taken over all elements.

    Both tensors are centered by their global mean; the result is the sum of the
    element-wise products divided by the product of the two L2 norms. No guard is
    applied against a zero norm (a constant input yields NaN).
    """
    centered_input = input - input.mean()
    centered_target = target - target.mean()

    covariance = (centered_input * centered_target).sum()
    input_norm = torch.sqrt((centered_input ** 2).sum())
    target_norm = torch.sqrt((centered_target ** 2).sum())

    return covariance / (input_norm * target_norm)


def GT_loss(target, predicted):
    """Combine reconstruction, correlation and topology terms into one loss.

    The returned value is ``L1(target, predicted) + (1 - pearson) + topo`` where
    ``topo`` is the L1 distance between the eigenvector centralities of both
    matrices. The topology term is computed on detached NumPy copies, so it adds to
    the loss value but carries no gradient.
    """
    # Element-wise reconstruction error.
    reconstruction = l1_loss(target, predicted)

    # Independent NumPy copies: eigen_centrality() zeroes the diagonal of the array
    # it receives, and clone() keeps that from touching the original tensors.
    target_np = target.detach().cpu().clone().numpy()
    predicted_np = predicted.detach().cpu().clone().numpy()
    torch.cuda.empty_cache()

    real_topology = torch.tensor(eigen_centrality(target_np))
    fake_topology = torch.tensor(eigen_centrality(predicted_np))
    topo_loss = l1_loss(fake_topology, real_topology)

    # Global Pearson correlation between the two matrices (1 means perfectly aligned).
    correlation = pearson_coor(target, predicted).to(device)
    torch.cuda.empty_cache()

    return reconstruction + (1 - correlation) + topo_loss


def Alignment_loss(target, predicted):
    """Scaled KL divergence between the softmax distributions of two tensors.

    Both tensors are turned into distributions with a softmax over the last
    dimension. The result is ``(1/350) * |KL(softmax(predicted) || softmax(target))|``
    summed over all elements; it is 0 when both inputs are equal.

    ``F.kl_div`` expects its first argument as *log*-probabilities, so
    ``log_softmax`` is used for it. Passing plain probabilities there, as the
    previous version did, does not compute a KL divergence.
    """
    log_q = F.log_softmax(target, dim=-1)   # reference distribution, in log space
    p = F.softmax(predicted, dim=-1)        # distribution being compared
    # abs() only guards against tiny negative values caused by rounding.
    kl_loss = torch.abs(F.kl_div(log_q, p, reduction='sum'))
    # Fixed scaling constant kept from the original code; its derivation is not
    # shown in this notebook.
    kl_loss = (1/350) * kl_loss
    return kl_loss

In [None]:
# Silence all Python warnings for the remainder of the session.
warnings.filterwarnings("ignore")

# Networks, created in the order aligner, generator, discriminator and moved to the
# selected device.
aligner = Aligner().to(device)
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss modules.
adversarial_loss1 = torch.nn.BCELoss().to(device)
l1_loss = torch.nn.L1Loss().to(device)

# One AdamW optimizer per network, all with the same hyper-parameters.
Aligner_optimizer, generator_optimizer, discriminator_optimizer = (
    torch.optim.AdamW(network.parameters(), lr=0.025, betas=(0.5, 0.999))
    for network in (aligner, generator, discriminator)
)
def IMANGraphNet(X_train_source, X_test_source, X_train_target, X_test_target):
    """Train the aligner / generator / discriminator on one split and evaluate it.

    The function works on the module-level ``aligner``, ``generator`` and
    ``discriminator`` networks and their optimizers. It trains for ``N_EPOCHS``
    epochs, saves and reloads the aligner and generator weights, then evaluates on
    the test part of the split.

    NOTE(review): the networks are not re-created here, so calling this function
    repeatedly (e.g. once per cross-validation fold) continues training from the
    weights left by the previous call.

    Parameters
    ----------
    X_train_source, X_test_source
        Source-domain samples; converted with ``cast_data_vector_RH``.
    X_train_target, X_test_target
        Target-domain samples; converted with ``cast_data_vector_FC``.

    Returns
    -------
    tuple
        ``(predicted_test, data_target, source_test, losses_test,
        eigenvector_losses_test)``. The first three are NumPy matrices of the *last*
        test sample (``None`` if the test split is empty); the last two are
        one-element lists with the mean L1 error and the mean L1 error of the
        eigenvector centrality over the test samples.
    """
    X_casted_train_source = cast_data_vector_RH(X_train_source)
    X_casted_test_source = cast_data_vector_RH(X_test_source)
    X_casted_train_target = cast_data_vector_FC(X_train_target)
    X_casted_test_target = cast_data_vector_FC(X_test_target)

    aligner.train()
    generator.train()
    discriminator.train()

    nbre_epochs = N_EPOCHS
    for epochs in range(nbre_epochs):
        # Anomaly detection makes backward() raise on NaN gradients; it also slows
        # training down noticeably.
        with torch.autograd.set_detect_anomaly(True):
            Al_losses = []
            Ge_losses = []
            losses_discriminator = []

            # Losses are accumulated over all training samples; each optimizer takes
            # a single step per epoch on the mean loss.
            for data_source, data_target in zip(X_casted_train_source, X_casted_train_target):
                targett = data_target.edge_attr.view(N_TARGET_NODES, N_TARGET_NODES)

                # ************    Domain alignment    ************
                A_output = aligner(data_source)
                A_casted = convert_generated_to_graph_Al(A_output)
                A_casted = A_casted[0]

                # Draw a random source-sized graph whose entries follow a normal
                # distribution with the mean/std of the current target matrix; the
                # aligner output is pulled towards it by the KL term.
                target = data_target.edge_attr.view(N_TARGET_NODES, N_TARGET_NODES).detach().cpu().clone().numpy()
                target_mean = np.mean(target)
                target_std = np.std(target)

                d_target = torch.normal(target_mean, target_std, size=(1, N_SOURCE_NODES_F))
                dd_target = cast_data_vector_RH(d_target)
                dd_target = dd_target[0]
                target_d = dd_target.edge_attr.view(N_SOURCE_NODES, N_SOURCE_NODES)

                kl_loss = Alignment_loss(target_d, A_output)
                Al_losses.append(kl_loss)

                # ************     Super-resolution    ************
                G_output = generator(A_casted)  # N_TARGET_NODES x N_TARGET_NODES
                # NOTE(review): the graph handed to the discriminator is built from a
                # detached copy of G_output, so the adversarial term below cannot
                # back-propagate into the generator through this path. Also,
                # .type(torch.FloatTensor) yields a CPU tensor.
                G_output_reshaped = (G_output.view(1, N_TARGET_NODES, N_TARGET_NODES, 1).type(torch.FloatTensor)).detach()
                G_output_casted = convert_generated_to_graph(G_output_reshaped)
                G_output_casted = G_output_casted[0]
                torch.cuda.empty_cache()

                Gg_loss = GT_loss(targett, G_output)
                torch.cuda.empty_cache()
                D_real = discriminator(data_target)
                D_fake = discriminator(G_output_casted)
                torch.cuda.empty_cache()
                G_adversarial = adversarial_loss(D_fake, (torch.ones_like(D_fake, requires_grad=False)))
                G_loss = G_adversarial + Gg_loss
                Ge_losses.append(G_loss)

                D_real_loss = adversarial_loss(D_real, (torch.ones_like(D_real, requires_grad=False)))
                D_fake_loss = adversarial_loss(D_fake.detach(), torch.zeros_like(D_fake))
                D_loss = (D_real_loss + D_fake_loss) / 2
                losses_discriminator.append(D_loss)

            generator_optimizer.zero_grad()
            Ge_losses = torch.mean(torch.stack(Ge_losses))
            Ge_losses.backward(retain_graph=True)
            generator_optimizer.step()

            Aligner_optimizer.zero_grad()
            Al_losses = torch.mean(torch.stack(Al_losses))
            Al_losses.backward(retain_graph=True)
            Aligner_optimizer.step()

            discriminator_optimizer.zero_grad()
            losses_discriminator = torch.mean(torch.stack(losses_discriminator))
            losses_discriminator.backward(retain_graph=True)
            discriminator_optimizer.step()

        print("[Epoch: %d]| [Al loss: %f]| [Ge loss: %f]| [D loss: %f]" % (epochs, Al_losses, Ge_losses, losses_discriminator))

    # Persist the trained weights, then reload them from the same files.
    restore_aligner = "./weight" + "aligner_fold" + "_" + ".model"
    restore_generator = "./weight" + "generator_fold" + "_" + ".model"

    torch.save(aligner.state_dict(), restore_aligner)
    torch.save(generator.state_dict(), restore_generator)

    torch.cuda.empty_cache()

    # ######################################### TESTING PART #########################################
    aligner.load_state_dict(torch.load(restore_aligner))
    generator.load_state_dict(torch.load(restore_generator))

    aligner.eval()
    generator.eval()

    # Defaults so an empty test split returns None instead of raising NameError.
    predicted_test = data_target = source_test = None
    l1_tests = []
    Eigenvector_test = []

    # Evaluation needs no gradients; no_grad() avoids building autograd graphs.
    with torch.no_grad():
        for data_source, target_graph in zip(X_casted_test_source, X_casted_test_target):
            data_source_test = data_source.x.view(N_SOURCE_NODES, N_SOURCE_NODES)
            data_target_test = target_graph.x.view(N_TARGET_NODES, N_TARGET_NODES)

            A_test = aligner(data_source)
            A_test_casted = convert_generated_to_graph_Al(A_test)
            A_test_casted = A_test_casted[0]
            data_target = data_target_test.detach().cpu().clone().numpy()

            # ************     Super-resolution    ************
            G_output_test = generator(A_test_casted)  # N_TARGET_NODES x N_TARGET_NODES
            # NOTE(review): the converted graph is not used afterwards; the call is
            # kept because the helper is defined outside this notebook section.
            G_output_test_casted = convert_generated_to_graph(G_output_test)
            G_output_test_casted = G_output_test_casted[0]
            torch.cuda.empty_cache()

            L1_test = l1_loss(data_target_test, G_output_test)

            # clone() gives independent arrays: topological_measures() zeroes the
            # diagonal of whatever array it is given.
            target_test = data_target_test.detach().cpu().clone().numpy()
            predicted_test = G_output_test.detach().cpu().clone().numpy()
            source_test = data_source_test.detach().cpu().clone().numpy()

            torch.cuda.empty_cache()
            # Stack the measure vectors into one array first; building a tensor
            # straight from a list of ndarrays takes PyTorch's slow path.
            fake_topology_test = torch.tensor(np.array(topological_measures(predicted_test)))
            real_topology_test = torch.tensor(np.array(topological_measures(target_test)))

            # Index 2 is the eigenvector centrality.
            eigenvector_test = l1_loss(fake_topology_test[2], real_topology_test[2])

            l1_tests.append(L1_test.detach().cpu().numpy())
            Eigenvector_test.append(eigenvector_test.detach().cpu().numpy())

    mean_l1 = np.mean(l1_tests)
    mean_eigenvector = np.mean(Eigenvector_test)

    losses_test = [mean_l1]
    eigenvector_losses_test = [mean_eigenvector]

    return (predicted_test, data_target, source_test, losses_test, eigenvector_losses_test)


In [None]:
warnings.filterwarnings("ignore")

# ---- Training: 3-fold cross-validation on simulated data ----

torch.cuda.empty_cache()

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("running on GPU")
else:
    device = torch.device("cpu")
    print("running on CPU")

# Simulated connectivity vectors (one row per subject). A seeded generator makes the
# simulated data identical across runs; PyTorch's own RNG is not seeded here.
_sim_rng = np.random.default_rng(1773)
source_data = _sim_rng.normal(0, 0.5, (N_SUBJECTS, N_SOURCE_NODES_F))
target_data = _sim_rng.normal(0, 0.5, (N_SUBJECTS, N_TARGET_NODES_F))

kf = KFold(n_splits=3, shuffle=True, random_state=1773)

fold = 0
losses_test = []
closeness_losses_test = []
eigenvector_losses_test = []

# NOTE(review): IMANGraphNet trains the module-level networks, which are not
# re-created between folds, so later folds start from weights that were already
# trained on earlier folds.
for train_index, test_index in kf.split(source_data):
    X_train_source, X_test_source = source_data[train_index], source_data[test_index]
    X_train_target, X_test_target = target_data[train_index], target_data[test_index]

    predicted_test, data_target, source_test, l1_test, eigenvector_test = IMANGraphNet(
        X_train_source, X_test_source, X_train_target, X_test_target
    )

    # Collect the per-fold results; previously each fold overwrote the last one, so
    # the reported "means" only covered the final fold.
    losses_test.extend(l1_test)
    eigenvector_losses_test.extend(eigenvector_test)
    fold += 1

# Averages over all folds.
test_mean = np.mean(losses_test)
Eigenvector_test_mean = np.mean(eigenvector_losses_test)

# Plots show the last test sample of the last fold.
plot_source(source_test)
plot_target(data_target)
plot_target(predicted_test)

print("Mean L1 Test", test_mean)

print("Mean Eigenvector Test", Eigenvector_test_mean)
