In [2]:
import numpy as np
import pandas as pd
import networkx as nx
import torch
import sys
# Make the project's local modules (presumably model_graph, evaluate, ... imported below)
# importable from this notebook.
util_path = '../codes/'
if util_path not in sys.path:  # re-running this cell must not keep appending duplicates
    sys.path.append(util_path)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from networkx import average_clustering, clustering, average_node_connectivity

In [4]:
from networkx import Graph

In [147]:
from model_graph import RGDiscrimininator, RGULGenerator, GANTrainerComb
from evaluate import convert_sample_to_space, graph_set_info

In [7]:
# Load the training graphs written earlier with torch.save
# (presumably a list of torch_geometric Data objects — confirm).
# NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which may refuse pickled
# PyG objects; if so, pass weights_only=False (only for files you created yourself).
all_g = torch.load(f='rand_data.pt')

In [11]:
from torch_geometric.loader import DataLoader

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# The experiment was run on GPU 5; fall back gracefully on machines with fewer GPUs
# instead of crashing with an invalid-device error.
PREFERRED_GPU = 5
if torch.cuda.device_count() > PREFERRED_GPU:
    device = f'cuda:{PREFERRED_GPU}'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Generator: maps a 128-dim latent vector to a graph (node_nums grows 3 -> 6 -> 12).
generator = RGULGenerator(
    in_dim=128, initial_dim=128,
    hidden_dims=[[64], [128], [512]],
    final_layer_dims=[64], hidden_act='relu',
    edge_hidden_dim=8,
    leaky_relu_coef=0.05, device=device,
    skip_z=True, skip_z_dims=None,
    unpool_bn=True, link_bn=True,
    link_act='leaky',
    unpool_para=dict(add_additional_link=True, add_perference=True,
                     roll_bn=True, roll_simple=True, add_bn=False),
    attr_bn=True, edge_out_dim=None,
    fix_points=[0, 0], roll_ones=[2, 5], node_nums=[3, 6, 12],
)

discriminator = RGDiscrimininator(
    in_dim=2, hidden_dim=[64, 128], lin_hidden_dim=128,
    out_hidden_dim=[128, 258],  # NOTE(review): 258 looks like a typo for 256 — confirm before changing
    device=device,
    useBN=True, droprate=None, outBN=True, out_drop=0.3,
    final_layers=2,
    conv_layers=2,
    last_act='sigmoid', relu_coef=0.05,
)

generator = generator.to(device)
discriminator = discriminator.to(device)

batch_size = 64
# follow_batch=['edge_index'] makes PyG add an `edge_index_batch` vector to each mini-batch.
data_loader = DataLoader(all_g, batch_size=batch_size, shuffle=True, follow_batch=['edge_index'])


In [18]:
# Adversarial trainer. `rand_dim` must match the generator's latent `in_dim` (128).
trainer = GANTrainerComb(
    discriminator, generator,
    rand_dim=128,
    train_folder='random_graph_0905_gpu',
    tot_epoch_num=500, eval_iter_num=100,
    batch_size=batch_size,  # reuse the DataLoader's value instead of repeating the literal 64
    device=device,
    d_add=None,
    learning_rate_g=2e-5, learning_rate_d=1e-4,
    lambda_g=0.0,
    max_train_G=5, tresh_add_trainG=0.2,
    use_loss='bce',
    g_out_prob=True,
    lambda_rl=1.0,
    lambda_nonodes=0.,
    lambda_noedges=0.,
    trainD=True,
    initial_weight=True,
)

In [None]:
# Run the full adversarial training loop (long-running; this cell has no recorded output).
trainer.train(
    data_loader,
    verbose=False,
    NN=200,
    evaluate_num=1000,
    mol_data=None,
    alter_trainer=False,
    reinforce_acclerate=True,
)

In [19]:
# Persist the generated graphs. `test_data` is created by the sampling cell further down,
# so under Restart & Run All it does not exist yet at this point; warn instead of raising
# NameError and halting the whole notebook.
if 'test_data' in globals():
    torch.save(test_data, 'UL_GAN_random_graph.pt')
else:
    import warnings
    warnings.warn("test_data is not defined yet; run the sampling cell first. Skipping save.")

In [148]:
# Raw statistics of the training graphs; they are bundled below as the reference
# distributions that generated graphs are compared against.
# NOTE(review): `clustering` shadows the networkx `clustering` function imported earlier —
# rename it if the networkx function is needed after this cell.
train_stats = graph_set_info(all_g, return_raw=True)
(node_cnts, edge_cnts, node_features, degrees,
 dense_edge, connectivities, clustering) = train_stats

In [39]:
import os

In [149]:
# Training-set statistics, in the same order graph_set_info returned them.
distributions = (
    node_cnts,
    edge_cnts,
    node_features,
    degrees,
    dense_edge,
    connectivities,
    clustering,
)

In [150]:
from torch_geometric.utils import dense_to_sparse, to_dense_adj, to_dense_batch
from evaluate import kld_evaluation
# Evaluate whatever `test_data` the kernel currently holds against the training distributions.
# NOTE(review): `test_data` is only created by the sampling cell further down; the recorded
# output of this cell presumably came from samples of an earlier, since-deleted cell.
# Guard so a fresh Restart & Run All does not stop here with a NameError.
if 'test_data' in globals():
    kld_previous = kld_evaluation(all_g, test_data, distributions)
else:
    import warnings
    warnings.warn("test_data is not defined yet; skipping this evaluation.")
    kld_previous = None
kld_previous

{'kl_connectivity': 0.024769900943519815,
 'wd_connectivity': 0.048033344245805215,
 'kl_edge_density': 0.05628528797056805,
 'wd_edge_density': 0.010992660790975398,
 'kl_clustering_coef': 0.030925141305384186,
 'wd_clustering_coef': 0.038220395449757374,
 'kl_avg_degrees': 0.1077621244276535,
 'wd_avg_degrees': 0.05911188757021568,
 'kl_node_features': [0.16375693393765967, 0.19472734861483174],
 'wd_node_features': [0.07884708782282389, 0.09150449185723417]}

In [158]:
# Sample graphs from the trained generator for evaluation.
# Latents are drawn uniformly from [-1, 1); their width must match the generator's
# `in_dim` / the trainer's `rand_dim` (both 128 in the setup cells above).
N_SAMPLE_BATCHES = 160
SAMPLE_BATCH_SIZE = 64
LATENT_DIM = 128
SAMPLE_SEED = 0
SAMPLES_PATH = 'UL_GAN_random_graph.pt'

from model_random_graph import convert_Batch_to_datalist  # hoisted out of the sampling loop

torch.manual_seed(SAMPLE_SEED)  # draw on CPU then move, so samples don't depend on the device
_ = generator.eval()
test_data = []
with torch.no_grad():
    for _batch_idx in range(N_SAMPLE_BATCHES):
        z = (torch.rand(SAMPLE_BATCH_SIZE, LATENT_DIM) * 2 - 1).to(device)
        data, _ = generator(z)  # second output is not needed here
        data_list = convert_Batch_to_datalist(
            data.x, data.edge_index, batch=data.batch, edge_batch=data.edge_index_batch
        )
        test_data.extend(data_list)

# Persist the samples here, where they are guaranteed to exist.
torch.save(test_data, SAMPLES_PATH)



In [159]:
from torch_geometric.utils import dense_to_sparse, to_dense_adj, to_dense_batch
from evaluate import kld_evaluation
# Compare the freshly sampled graphs with the training-set distributions.
kld_ul_gan = kld_evaluation(all_g, test_data, distributions)
kld_ul_gan

{'kl_connectivity': 0.010564814866864699,
 'wd_connectivity': 0.03374130897432577,
 'kl_edge_density': 0.003672168791364585,
 'wd_edge_density': 0.0009746384686460247,
 'kl_clustering_coef': 0.02703565696837029,
 'wd_clustering_coef': 0.022620992844200254,
 'kl_avg_degrees': 0.021279956097155616,
 'wd_avg_degrees': 0.043432854367453325,
 'kl_node_features': [0.23175848875031635, 0.05813050802825191],
 'wd_node_features': [0.04016307662315018, 0.04367399380382586]}

In [164]:
# Collapse the per-feature node-feature scores into one KL and one Wasserstein number.
# Computed from the evaluation result instead of numbers copied out of a printed output,
# which go stale as soon as the generator or the samples change.
kld_summary = kld_evaluation(all_g, test_data, distributions)
np.mean(kld_summary['kl_node_features']), np.mean(kld_summary['wd_node_features'])

(0.14494449838928414, 0.041918535213488026)

In [162]:
# Statistics of the generated graphs, saved so they can be compared or plotted later
# without re-sampling.
distributions_test = tuple(graph_set_info(test_data, return_raw=True))
(node_cnts_test, edge_cnts_test, node_features_test, degrees_test,
 dense_edge_test, connectivities_test, clustering_test) = distributions_test
torch.save(distributions_test, 'UL_GAN_distributions.pt')