In [1]:
import numpy as np
import torch
import os
import pickle
import time
from Model import *
from itertools import chain
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
from synthetic import simulate_lorenz_96, simulate_var
from utils import build_flags, time_split, save_result, evaluate_result, count_accuracy, loss_sparsity, loss_divergence, loss_mmd, save_result

In [11]:
# Experiment configuration: start from the project's default flags and override
# them for this run. Every later cell reads its hyper-parameters from `args`.
parser = build_flags()
args = parser.parse_args(args=[])  # empty argv so the kernel's own CLI arguments are ignored
args.seed = 2
args.num_nodes = 10
args.dims = 1
args.threshold = 0.5
args.time_length = 500
args.time_step = 10
args.epochs = 3000
args.batch_size = 128
args.lr = 1e-3
args.weight_decay = 1e-3
args.encoder_alpha = 0.02
args.decoder_alpha = 0.04
args.beta_sparsity = 0.30  # weight of the 'log_sum' sparsity term (author's note: 0.25)
args.beta_kl = 0.1         # weight of the JS-divergence term
args.beta_mmd = 0.5        # weight of the MMD term (author's note: 1)
args.encoder_hidden = 20
args.decoder_hidden = 15   # author's note: 20
args.encoder_dropout = 0.1
args.decoder_dropout = 0.3

# Apply the configured seed. Without this, weight initialisation, dropout and
# the torch.randn sampling used later differ on every run even though
# `args.seed` is set.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

In [12]:
# Simulate the Lorenz-96 system and turn it into a batched dataset.
# All sizes come from `args` so the data cannot silently disagree with the
# configuration used to build the models.
X_np_ori, GC = simulate_lorenz_96(p=args.num_nodes, F=10, T=args.time_length, seed=args.seed)

# Swap the two axes of the raw simulation and append a trailing singleton axis
# (presumably time-major -> variable-major plus a feature axis — TODO confirm
# against simulate_lorenz_96 / time_split).
series = X_np_ori.transpose(1, 0)[:, :, np.newaxis]

# Cut the series into windows; downstream cells index the result as
# [sample, variable, step-in-window, feature].
windows = np.array(time_split(series, step=args.time_step))

X_np = torch.FloatTensor(windows)
data = X_np
data_loader = DataLoader(data, batch_size=args.batch_size)

In [13]:
# Stage 1: train one decoder per target variable to predict that variable's
# next step from all variables' previous steps, and checkpoint the weights
# with the lowest mean training MSE.
decoder_dir = '/home/omnisky/Public/ChenRongfa/Intrer_VAE_result/Lorenz96.10.250/help'
os.makedirs(decoder_dir, exist_ok=True)  # torch.save does not create missing directories

for idx in range(args.num_nodes):
    print('Begin training feature: {:04d}'.format(idx + 1))
    decoder_file = os.path.join(decoder_dir, 'decoder' + str(idx) + '.pt')
    Inter_decoder = decoder(args.dims, args.decoder_hidden, args.time_step - 1, args.num_nodes, args.decoder_dropout, args.decoder_alpha)
    Inter_decoder = Inter_decoder.cuda()
    optimizer = optim.Adam(params=Inter_decoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)
    loss_val = nn.MSELoss()
    best_loss = np.inf  # np.Inf was removed in NumPy 2.0

    for epoch in range(args.epochs):
        t = time.time()
        mse_loss = []
        # `batch` (not `data`) so the global dataset tensor is not overwritten.
        for batch in data_loader:
            batch = batch.cuda()
            inputs = batch[:, :, :-1, :]   # every variable, all but the last step of each window
            target = batch[:, idx, 1:, :]  # target variable, shifted one step ahead
            optimizer.zero_grad()
            pred = Inter_decoder(inputs, idx)
            loss = loss_val(pred, target)
            loss.backward()
            optimizer.step()
            mse_loss.append(loss.item())

        # Step the LR schedule once per epoch, after the optimizer updates;
        # calling it first skips the initial value of the schedule.
        scheduler.step()

        epoch_mse = np.mean(mse_loss)
        if epoch % 100 == 0:
            # The total loss is the MSE alone, so both columns show the same number.
            print('Feature: {:04d}'.format(idx + 1),
                  'Epoch: {:04d}'.format(epoch),
                  'Loss: {:.10f}'.format(epoch_mse),
                  'MSE_Loss: {:.10f}'.format(epoch_mse),
                  'time: {:.4f}s'.format(time.time() - t))

        # Model selection uses the training loss (no validation split here).
        if epoch_mse < best_loss:
            best_loss = epoch_mse
            torch.save(Inter_decoder.state_dict(), decoder_file)


            

Begin training feature: 0001
Feature: 0001 Epoch: 0000 Loss: 22.3030369282 MSE_Loss: 22.3030369282 time: 0.0261s


  nn.init.xavier_normal(m.weight.data, gain=1.414)


Feature: 0001 Epoch: 0100 Loss: 13.9059495926 MSE_Loss: 13.9059495926 time: 0.0233s
Feature: 0001 Epoch: 0200 Loss: 9.4894406796 MSE_Loss: 9.4894406796 time: 0.0202s
Feature: 0001 Epoch: 0300 Loss: 7.0740630627 MSE_Loss: 7.0740630627 time: 0.0164s
Feature: 0001 Epoch: 0400 Loss: 6.1979525089 MSE_Loss: 6.1979525089 time: 0.0163s
Feature: 0001 Epoch: 0500 Loss: 4.7015128732 MSE_Loss: 4.7015128732 time: 0.0230s
Feature: 0001 Epoch: 0600 Loss: 4.8551487327 MSE_Loss: 4.8551487327 time: 0.0232s
Feature: 0001 Epoch: 0700 Loss: 3.7697503567 MSE_Loss: 3.7697503567 time: 0.0232s
Feature: 0001 Epoch: 0800 Loss: 4.3773422837 MSE_Loss: 4.3773422837 time: 0.0224s
Feature: 0001 Epoch: 0900 Loss: 4.0970312953 MSE_Loss: 4.0970312953 time: 0.0227s
Feature: 0001 Epoch: 1000 Loss: 3.9361172915 MSE_Loss: 3.9361172915 time: 0.0228s
Feature: 0001 Epoch: 1100 Loss: 4.5730556250 MSE_Loss: 4.5730556250 time: 0.0233s
Feature: 0001 Epoch: 1200 Loss: 4.4327391386 MSE_Loss: 4.4327391386 time: 0.0230s
Feature: 0001 

In [14]:
# Reload every per-variable decoder checkpoint and keep row `idx` of its
# `adj` attribute, i.e. the row belonging to the variable it was trained for.
decoder_dir = '/home/omnisky/Public/ChenRongfa/Intrer_VAE_result/Lorenz96.10.250/help'
adj = []
for idx in range(args.num_nodes):
    decoder_file = os.path.join(decoder_dir, 'decoder' + str(idx) + '.pt')
    decoder_net = decoder(args.dims, args.decoder_hidden, args.time_step - 1, args.num_nodes, args.decoder_dropout, args.decoder_alpha)
    # The checkpoints were saved from CUDA models; this model stays on the CPU,
    # so map the tensors there instead of restoring them onto the GPU first.
    decoder_net.load_state_dict(torch.load(decoder_file, map_location='cpu'))
    adj.append(decoder_net.adj[idx, :])

In [15]:
# Assemble the collected rows into one matrix (one row per entry of `adj`)
# and cut it loose from the decoders' autograd graphs.
stacked_rows = torch.stack(adj, dim=0)
init_adj = stacked_rows.clone().detach()
print(init_adj.shape)

torch.Size([10, 10])


In [16]:
# Score the decoder-derived initial adjacency against the ground truth `GC`
# using a 0.5 threshold.
init_adj_np = init_adj.detach().numpy()
result, _ = evaluate_result(GC, init_adj_np, 0.5)
print(result)

{'accuracy': 0.69, 'precision': 0.9090909090909091, 'recall': 0.25, 'F1': 0.392156862745098, 'ROC_AUC': 0.9470833333333333, 'PR_AUC': 0.9215491348619778}


In [17]:
# Stage 2: for every loss-weight combination, train one encoder per target
# variable on top of the frozen stage-1 decoder, then turn the trained
# encoders into an estimated adjacency matrix and save the scores.
decoder_dir = '/home/omnisky/Public/ChenRongfa/Intrer_VAE_result/Lorenz96.10.250/help'
encoder_epochs = 2000

for mmd in [0.5]:
    for sparsity in [0.25]:
        for beta_kl in [0.01, 0.05]:

            args.beta_mmd = mmd
            args.beta_sparsity = sparsity
            args.beta_kl = beta_kl
            root_fodler = (r'/home/omnisky/Public/ChenRongfa/Intrer_VAE_result/11Lorenz96.' + str(args.num_nodes)
                           + '.' + str(args.time_length)
                           + '/mmd' + str(args.beta_mmd)
                           + '/sparsity' + str(args.beta_sparsity)
                           + '/beta_kl' + str(args.beta_kl))
            os.makedirs(root_fodler, exist_ok=True)

            # ---- training: one encoder per target variable -----------------
            for idx in range(args.num_nodes):
                print('Begin training feature: {:04d}'.format(idx + 1))
                decoder_file = os.path.join(decoder_dir, 'decoder' + str(idx) + '.pt')
                encoder_file = os.path.join(root_fodler, 'encoder' + str(idx) + '.pt')

                # The decoder is only used for its forward pass: it is in eval
                # mode and its parameters are not handed to the optimizer.
                Inter_decoder = decoder(args.dims, args.decoder_hidden, args.time_step - 1, args.num_nodes, args.decoder_dropout, args.decoder_alpha)
                Inter_decoder.load_state_dict(torch.load(decoder_file))
                Inter_decoder = Inter_decoder.cuda()
                Inter_decoder.eval()

                Inter_encoder = encoder(init_adj, args.dims, args.encoder_hidden, args.dims, args.time_step - 1, args.encoder_dropout, args.encoder_alpha)
                Inter_encoder = Inter_encoder.cuda()

                optimizer = optim.Adam(Inter_encoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)
                loss_mse = nn.MSELoss()
                best_loss = np.inf

                for epoch in range(encoder_epochs):
                    t = time.time()
                    Loss = []
                    MSE_loss = []
                    SPA_loss = []
                    KL_loss = []
                    MMD_loss = []
                    # `batch` (not `data`) so the global dataset tensor is not overwritten.
                    for batch in data_loader:
                        optimizer.zero_grad()
                        batch = batch.cuda()
                        target = batch[:, idx, 1:, :]  # target variable, one step ahead
                        inputs = batch[:, :, :-1, :]   # all variables, previous steps

                        # Reparameterised sample: gamma = mu + sigma * eps, squashed to a (0, 1) mask.
                        mu, log_var = Inter_encoder(inputs)
                        sigma = torch.exp(log_var / 2)
                        eps = torch.randn(size=mu.size()).cuda()
                        gamma = mu + sigma * eps
                        mask = torch.sigmoid(gamma)

                        masked_inputs = mask_inputs(mask, inputs)
                        pred = Inter_decoder(masked_inputs, idx)

                        mse_loss = loss_mse(pred, target)
                        spa_loss = loss_sparsity(mask, 'log_sum')
                        kl_loss = loss_divergence(mask, 'JS')
                        mmd_loss = loss_mmd(batch[:, :, 1:, :], pred, idx)

                        loss = mse_loss + args.beta_sparsity * spa_loss + args.beta_kl * kl_loss + args.beta_mmd * mmd_loss

                        loss.backward()
                        optimizer.step()

                        Loss.append(loss.item())
                        MSE_loss.append(mse_loss.item())
                        SPA_loss.append(spa_loss.item())
                        KL_loss.append(kl_loss.item())
                        MMD_loss.append(mmd_loss.item())

                    epoch_loss = np.mean(Loss)
                    if epoch % 100 == 0:
                        print(  'Feature: {:04d}'.format(idx + 1),
                                'Epoch: {:04d}'.format(epoch),
                                'Loss: {:.10f}'.format(epoch_loss),
                                'MSE_Loss: {:.10f}'.format(np.mean(MSE_loss)),
                                'Sparsity_loss: {:.10f}'.format(np.mean(SPA_loss)),
                                'KL_loss: {:.10f}'.format(np.mean(KL_loss)),
                                'MMD_loss: {:.10f}'.format(np.mean(MMD_loss)),
                                'time: {:.4f}s'.format(time.time() - t))

                    # Keep the encoder with the lowest mean total training loss.
                    if epoch_loss < best_loss:
                        best_loss = epoch_loss
                        torch.save(Inter_encoder.state_dict(), encoder_file)

            # ---- evaluation: masks from the best encoders -> adjacency -----
            causality_matrix = []
            with torch.no_grad():  # inference only; no graph over the full dataset
                for idx in range(args.num_nodes):
                    encoder_file = os.path.join(root_fodler, 'encoder' + str(idx) + '.pt')
                    est_net = encoder(init_adj, args.dims, args.encoder_hidden, args.dims, args.time_step - 1, args.encoder_dropout, args.encoder_alpha)
                    est_net.load_state_dict(torch.load(encoder_file, map_location='cpu'))
                    est_net.eval()
                    # Same input slice as in training: `:-1` drops the last step of
                    # each window (model input); `1:` would be the shifted target.
                    inputs = X_np[:, :, :-1, :]
                    mu, log_var = est_net(inputs)
                    sigma = torch.exp(log_var / 2)
                    # The mask is sampled, so the estimate varies between runs
                    # unless the RNG is seeded.
                    gamma = mu + sigma * torch.randn(size=mu.size())
                    mask_matrix = torch.sigmoid(gamma).squeeze()
                    causality_matrix.append(mask_matrix)
            causality_matrix = torch.stack(causality_matrix, dim=1)
            adj_gca = causality_matrix.mean(dim=3).mean(dim=0)
            adj_gca_np = adj_gca.detach().numpy()
            result, _ = evaluate_result(GC, adj_gca_np, 0.5)
            # np.save takes (file, arr): write the matrix to <root_fodler>/adj.npy.
            np.save(os.path.join(root_fodler, 'adj.npy'), adj_gca_np)
            save_result(result, 'encoder', root_fodler)
            




Begin training feature: 0001
Feature: 0001 Epoch: 0000 Loss: 17.9841623306 MSE_Loss: 15.8561654091 Sparsity_loss: 8.3888163567 KL_loss: 0.0736840013 MMD_loss: 0.0601114146 time: 0.0517s


  nn.init.xavier_normal(m.weight.data, gain=1.414)


Feature: 0001 Epoch: 0100 Loss: 10.0013122559 MSE_Loss: 8.3927799463 Sparsity_loss: 6.3209041357 KL_loss: 0.0789477415 MMD_loss: 0.0550340116 time: 0.0464s
Feature: 0001 Epoch: 0200 Loss: 9.2066121101 MSE_Loss: 8.4056054354 Sparsity_loss: 3.0913723111 KL_loss: 0.0507447869 MMD_loss: 0.0553122722 time: 0.0476s
Feature: 0001 Epoch: 0300 Loss: 8.9428900480 MSE_Loss: 8.5026884079 Sparsity_loss: 1.6474025846 KL_loss: 0.0489704264 MMD_loss: 0.0557225198 time: 0.0451s
Feature: 0001 Epoch: 0400 Loss: 8.9056705236 MSE_Loss: 8.5447691679 Sparsity_loss: 1.3295034170 KL_loss: 0.0458138222 MMD_loss: 0.0561346970 time: 0.0450s
Feature: 0001 Epoch: 0500 Loss: 8.8814654350 MSE_Loss: 8.5410394669 Sparsity_loss: 1.2473345995 KL_loss: 0.0452468842 MMD_loss: 0.0562800094 time: 0.0447s
Feature: 0001 Epoch: 0600 Loss: 8.8964592218 MSE_Loss: 8.5669318438 Sparsity_loss: 1.2042451799 KL_loss: 0.0447562085 MMD_loss: 0.0560366437 time: 0.0446s
Feature: 0001 Epoch: 0700 Loss: 8.8769940138 MSE_Loss: 8.5451002121 S

  result.to_excel(filename)
  nn.init.xavier_normal(m.weight.data, gain=1.414)


Begin training feature: 0001
Feature: 0001 Epoch: 0000 Loss: 18.0316145420 MSE_Loss: 15.9031710625 Sparsity_loss: 8.3801233768 KL_loss: 0.0732744671 MMD_loss: 0.0594986603 time: 0.0406s
Feature: 0001 Epoch: 0100 Loss: 10.0753629208 MSE_Loss: 8.3936897516 Sparsity_loss: 6.6009343863 KL_loss: 0.0775318500 MMD_loss: 0.0551260784 time: 0.0491s
Feature: 0001 Epoch: 0200 Loss: 9.3602738380 MSE_Loss: 8.3984992504 Sparsity_loss: 3.7265461087 KL_loss: 0.0513949450 MMD_loss: 0.0551370606 time: 0.0492s
Feature: 0001 Epoch: 0300 Loss: 9.0289897919 MSE_Loss: 8.4877834320 Sparsity_loss: 2.0450813174 KL_loss: 0.0430744560 MMD_loss: 0.0555649363 time: 0.0460s
Feature: 0001 Epoch: 0400 Loss: 8.9482084513 MSE_Loss: 8.5342736244 Sparsity_loss: 1.5337091684 KL_loss: 0.0498676449 MMD_loss: 0.0560285598 time: 0.0469s
Feature: 0001 Epoch: 0500 Loss: 8.9118061066 MSE_Loss: 8.5466321707 Sparsity_loss: 1.3388533592 KL_loss: 0.0464617414 MMD_loss: 0.0562752821 time: 0.0452s
Feature: 0001 Epoch: 0600 Loss: 8.9040

In [8]:
# Evaluate the `adj` attribute of 20 encoder checkpoints from a different
# experiment directory (Lorenz96.20.500).
# NOTE(review): this cell uses 20 checkpoints and another user's path while the
# rest of the notebook is configured for 10 nodes; it presumably relies on
# `init_adj` / `GC` from a 20-variable session — confirm before re-running.
help_dir = '/home/jing_xuzijian/crf/Intrer_VAE_result/Lorenz96.20.500/help'
for idx in range(20):
    encoder_file = os.path.join(help_dir, 'encoder' + str(idx) + '.pt')
    est_net = encoder(init_adj, args.dims, args.encoder_hidden, args.dims, args.time_step - 1, args.encoder_dropout, args.encoder_alpha)
    checkpoint = torch.load(encoder_file)
    est_net.load_state_dict(checkpoint)
    est_adj = est_net.adj.cpu().detach().numpy()
    result, _ = evaluate_result(GC, est_adj, 0.5)
    print(result)

{'accuracy': 0.855, 'precision': 1.0, 'recall': 0.275, 'F1': 0.4313725490196079, 'ROC_AUC': 0.9642578125000001, 'PR_AUC': 0.9109595841158191}
{'accuracy': 0.855, 'precision': 1.0, 'recall': 0.275, 'F1': 0.4313725490196079, 'ROC_AUC': 0.9642578125000001, 'PR_AUC': 0.9109595841158191}
{'accuracy': 0.855, 'precision': 1.0, 'recall': 0.275, 'F1': 0.4313725490196079, 'ROC_AUC': 0.9642578125000001, 'PR_AUC': 0.9109595841158191}
{'accuracy': 0.855, 'precision': 1.0, 'recall': 0.275, 'F1': 0.4313725490196079, 'ROC_AUC': 0.9642578125000001, 'PR_AUC': 0.9109595841158191}
{'accuracy': 0.855, 'precision': 1.0, 'recall': 0.275, 'F1': 0.4313725490196079, 'ROC_AUC': 0.9642578125000001, 'PR_AUC': 0.9109595841158191}
{'accuracy': 0.855, 'precision': 1.0, 'recall': 0.275, 'F1': 0.4313725490196079, 'ROC_AUC': 0.9642578125000001, 'PR_AUC': 0.9109595841158191}
{'accuracy': 0.855, 'precision': 1.0, 'recall': 0.275, 'F1': 0.4313725490196079, 'ROC_AUC': 0.9642578125000001, 'PR_AUC': 0.9109595841158191}
{'accu

  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)
  nn.init.xavier_normal(m.weight.data, gain=1.414)


In [9]:
# Evaluate the `adj` attribute of whichever encoder `est_net` currently holds.
est_adj = est_net.adj.clone().detach()
# Hand evaluate_result a NumPy array, as every other call in this notebook
# does, instead of a torch.Tensor.
result, _ = evaluate_result(GC, est_adj.cpu().numpy(), 0.5)
print(result)

{'accuracy': 0.8025, 'precision': 1.0, 'recall': 0.0125, 'F1': 0.02469135802469136, 'ROC_AUC': 0.467734375, 'PR_AUC': 0.2090953304192029}


In [30]:
# Rebuild the per-variable masks from the saved encoders and stack them into
# `causality_matrix` (one slice per target variable along dim 1).
encoder_dir = '/home/omnisky/Public/ChenRongfa/Intrer_VAE_result/Lorenz96.10.250/help'
causality_matrix = []
# Kept for the gamma/theta variants, which are currently disabled.
total_gamma_matrix = []
total_theta_matrix = []

with torch.no_grad():  # inference only; avoids building a graph over the full dataset
    for idx in range(args.num_nodes):
        encoder_file = os.path.join(encoder_dir, 'encoder' + str(idx) + '.pt')
        est_net = encoder(init_adj, args.dims, args.encoder_hidden, args.dims, args.time_step - 1, args.encoder_dropout, args.encoder_alpha)
        # The model stays on the CPU here, so map the checkpoint there as well.
        est_net.load_state_dict(torch.load(encoder_file, map_location='cpu'))
        est_net.eval()
        # `:-1` keeps all but the last step of each window (the model input);
        # `1:` would be the one-step-ahead target used during training.
        inputs = X_np[:, :, :-1, :]
        mu, log_var = est_net(inputs)
        sigma = torch.exp(log_var / 2)
        # Reparameterised sample; the result is stochastic unless the RNG is seeded.
        gamma = mu + sigma * torch.randn(size=mu.size())
        mask_matrix = torch.sigmoid(gamma).squeeze()
        causality_matrix.append(mask_matrix)

causality_matrix = torch.stack(causality_matrix, dim=1)

In [31]:
# Collapse the stacked masks to a 2-D adjacency estimate: average over dim 3
# first, then over dim 0 (presumably time steps and samples — TODO confirm).
# The separate gamma/theta adjacency variants are currently disabled.
mask_mean_over_dim3 = causality_matrix.mean(dim=3)
adj_gca = mask_mean_over_dim3.mean(dim=0)
print(adj_gca.shape)

torch.Size([10, 10])


In [52]:
# NOTE(review): `theta_adj` is not assigned by any active code visible in this
# notebook, so this cell presumably depends on leftover kernel state and would
# raise NameError after Restart & Run All — confirm whether it can be removed.
print(theta_adj)

tensor([[7.6896e-01, 7.8296e-02, 5.1112e-05, 5.2073e-05, 5.6759e-05, 4.9417e-02,
         5.4774e-05, 5.5814e-05, 5.4371e-05, 5.4734e-05, 5.6489e-05, 5.5717e-05,
         5.4487e-05, 5.5353e-05, 7.1983e-02, 5.1452e-05, 2.6985e-02, 1.0124e-01,
         3.2742e-01, 4.2244e-01],
        [1.2391e-01, 8.4173e-01, 5.9056e-03, 2.3101e-03, 2.3056e-03, 2.3278e-03,
         2.3600e-03, 2.2972e-03, 2.3250e-03, 2.2922e-03, 2.2934e-03, 2.3774e-03,
         2.3698e-03, 2.2588e-03, 2.3671e-03, 2.3379e-03, 2.3117e-03, 2.3549e-03,
         2.3715e-03, 3.0842e-01],
        [2.4178e-01, 3.8542e-01, 9.8017e-01, 1.8064e-01, 3.7534e-05, 3.9863e-05,
         3.8117e-05, 2.2170e-03, 7.2371e-02, 3.7368e-05, 3.9130e-05, 1.9649e-03,
         3.9108e-05, 3.8433e-05, 3.7392e-05, 4.0173e-05, 3.8965e-05, 3.8006e-05,
         3.9025e-05, 2.4334e-03],
        [7.9220e-05, 3.2196e-01, 3.1193e-01, 8.7320e-01, 2.5348e-03, 1.0796e-01,
         8.2627e-05, 7.7709e-04, 8.1066e-05, 8.2243e-05, 8.0188e-05, 8.1518e-05,
       

In [51]:
# Show the estimated adjacency matrix followed by row 1 of the ground truth `GC`.
# NOTE(review): the recorded output has 20 values per row although this notebook
# is configured for 10 nodes — presumably stale output from a 20-variable run;
# re-run to confirm.
print(adj_gca)
print(GC[1, :])

tensor([[7.4909e-01, 7.4239e-02, 1.0852e-07, 1.0922e-07, 1.2867e-07, 4.6770e-02,
         1.2158e-07, 1.2443e-07, 1.2163e-07, 1.1989e-07, 1.2625e-07, 1.2637e-07,
         1.2013e-07, 1.2420e-07, 6.7696e-02, 1.1373e-07, 2.5288e-02, 9.6654e-02,
         3.1617e-01, 4.0582e-01],
        [1.1625e-01, 8.3866e-01, 4.9221e-03, 2.5222e-07, 2.6218e-07, 2.7182e-07,
         2.5677e-07, 2.6314e-07, 2.6949e-07, 2.5992e-07, 2.6494e-07, 2.7781e-07,
         2.6403e-07, 2.6299e-07, 2.7177e-07, 2.6159e-07, 2.5859e-07, 2.7005e-07,
         2.7816e-07, 2.9851e-01],
        [2.3491e-01, 3.7936e-01, 9.7944e-01, 1.7477e-01, 1.1422e-08, 1.2310e-08,
         1.1873e-08, 2.0021e-03, 6.9232e-02, 1.0887e-08, 1.1675e-08, 1.7427e-03,
         1.2874e-08, 1.2228e-08, 1.1453e-08, 1.2806e-08, 1.1806e-08, 1.1823e-08,
         1.1806e-08, 2.1349e-03],
        [1.5731e-08, 3.1262e-01, 2.9733e-01, 8.4128e-01, 2.0902e-03, 1.0322e-01,
         1.7735e-08, 6.6305e-04, 1.6666e-08, 1.7288e-08, 1.6466e-08, 1.7279e-08,
       

In [33]:
# Score the encoder-derived adjacency estimate against the ground truth `GC`
# using a 0.5 threshold.
adj_gca_np = adj_gca.detach().numpy()
result, _ = evaluate_result(GC, adj_gca_np, 0.5)
print(result)

{'accuracy': 0.76, 'precision': 1.0, 'recall': 0.4, 'F1': 0.5714285714285715, 'ROC_AUC': 0.8783333333333333, 'PR_AUC': 0.8890493942929301}


In [31]:
# Display `GC`, the ground-truth structure returned by simulate_lorenz_96 and
# used as the reference in evaluate_result.
print(GC)

[[1 1 0 0 0 0 0 0 1 1]
 [1 1 1 0 0 0 0 0 0 1]
 [1 1 1 1 0 0 0 0 0 0]
 [0 1 1 1 1 0 0 0 0 0]
 [0 0 1 1 1 1 0 0 0 0]
 [0 0 0 1 1 1 1 0 0 0]
 [0 0 0 0 1 1 1 1 0 0]
 [0 0 0 0 0 1 1 1 1 0]
 [0 0 0 0 0 0 1 1 1 1]
 [1 0 0 0 0 0 0 1 1 1]]


In [26]:
# Display the estimated adjacency matrix for visual comparison with `GC`.
print(adj_gca)

tensor([[9.8090e-01, 2.4526e-01, 3.0937e-08, 2.7574e-08, 3.4536e-08, 8.7615e-02,
         3.8875e-08, 4.2935e-02, 1.9347e-02, 4.1903e-08],
        [2.2222e-01, 9.8181e-01, 2.2222e-01, 2.2222e-01, 2.2222e-01, 2.2222e-01,
         2.2222e-01, 2.2222e-01, 2.2222e-01, 2.2222e-01],
        [1.0000e+00, 1.6369e-01, 1.0000e+00, 8.2401e-02, 1.7285e-07, 1.9243e-07,
         9.6406e-02, 1.8654e-07, 1.3874e-07, 2.0195e-07],
        [4.3069e-07, 7.3208e-02, 4.6705e-01, 9.9998e-01, 3.7583e-07, 3.4952e-02,
         3.5127e-07, 5.4465e-02, 3.3362e-07, 3.4672e-07],
        [9.0340e-07, 9.2341e-07, 2.4287e-02, 9.2389e-07, 9.9945e-01, 9.5101e-07,
         9.4541e-07, 9.4353e-07, 9.4964e-07, 9.3344e-07],
        [9.6907e-04, 5.4347e-08, 6.6765e-08, 4.4504e-02, 1.0755e-01, 9.9631e-01,
         1.0056e-01, 1.0439e-02, 4.9131e-08, 6.4329e-08],
        [2.5810e-07, 2.6611e-07, 3.2856e-07, 1.9583e-07, 1.8355e-01, 9.9956e-01,
         1.0000e+00, 4.2288e-07, 2.5699e-07, 3.1616e-07],
        [2.2524e-07, 4.7017