In [1]:
import os
import sys

import torch

# Make ../codes importable *before* importing project modules from it.
util_path = '../codes'
sys.path.append(util_path)

from model_graph import RGDiscrimininator, RGULGenerator, RGVAE


In [5]:
# The experiment was run on GPU index 5 of a multi-GPU machine. Fall back to the
# default GPU, or to the CPU, when that device does not exist; the previous
# hard-coded 'cuda:5' raised on any machine with fewer than six GPUs.
if torch.cuda.is_available() and torch.cuda.device_count() > 5:
    device = 'cuda:5'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Graph generator; passed as the first argument to RGVAE below.
# NOTE(review): node_nums=[3, 6, 12] presumably gives the node count after each
# unpooling stage, and its last entry should match RGVAE's max_num_nodes=12 — confirm.
generator = RGULGenerator(
    in_dim=128, initial_dim=128,
    hidden_dims=[[128], [128], [512]],
    final_layer_dims=[64], hidden_act='relu',
    edge_hidden_dim=16,
    leaky_relu_coef=0.05, device=device,
    skip_z=True, skip_z_dims=None,
    unpool_bn=False, link_bn=False,
    link_act='leaky',
    unpool_para=dict(add_additional_link=True, add_perference=True,
                     roll_bn=False, roll_simple=True, add_bn=False),
    attr_bn=False, edge_out_dim=None,
    fix_points=[0, 0], roll_ones=[2, 5], node_nums=[3, 6, 12],
    use_bn=False, skipz_bn=False, final_bn=False,
)

# Graph encoder network; passed as the second argument to RGVAE below.
discriminator = RGDiscrimininator(
    in_dim=2, hidden_dim=[128, 256], lin_hidden_dim=128,
    out_hidden_dim=[128, 256], device=device,
    useBN=False, droprate=None, outBN=False, out_drop=None,
    final_layers=2,
    conv_layers=2,
    last_act='linear', relu_coef=0.05, outdim=128,
)

vae = RGVAE(
    generator, discriminator, 128, 128,
    lr=1e-4, beta=(0.5, 0.999), g_prob=True,
    permutation=False, max_num_nodes=12, folder='ulvae_rg',
    device=device, lambda_rl=5e-2,
    beta_node=1.0, beta_edge=1.0, beta_edge_total=1, beta_node_degree=1, beta_node_feature=1,
    batch_size=64,
)
vae = vae.to(device)

In [9]:
from torch import nn

# Initialise every VAE parameter: tensors with fewer than two dimensions
# (biases and any other 1-D/0-D parameters) are set to zero, and matrices and
# higher-rank weights get Xavier-normal init. The in-place `*_` initialisers are
# used because `constant` / `xavier_normal` are deprecated aliases. Xavier cannot
# compute fan-in/fan-out for 0-d tensors, so those are zeroed too.
# NOTE(review): this also zeroes 1-D scale parameters of normalisation layers.
# The decoder's modules are re-initialised again by weight_initiate in the next
# cell; confirm the encoder has no such layers.
for param in vae.parameters():
    if param.dim() < 2:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

def weight_initiate(m):
    """Re-initialise one module's own parameters with small random values.

    Intended for ``nn.Module.apply``, which calls it once per sub-module.

    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` (exact type match): weight ~ N(1, 0.02),
      bias = 0.
    - Any other module: each of its ``weight``, ``bias`` and ``root`` attributes that
      is a tensor gets ~ N(0, 0.02). Attributes that are missing, ``None``, or not
      tensors are skipped.

    Args:
        m: the module to initialise in place.
    """
    if type(m) in (nn.BatchNorm1d, nn.BatchNorm2d):
        # Affine-free BatchNorm has weight/bias set to None; skip those.
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
        return

    # `root` is a separate weight tensor on some graph layers. It must be
    # initialised itself; the old code re-initialised `bias` here instead.
    for attr_name in ("weight", "bias", "root"):
        tensor = getattr(m, attr_name, None)
        if isinstance(tensor, torch.Tensor):
            nn.init.normal_(tensor.data, 0.0, 0.02)


# Re-initialise every decoder sub-module via weight_initiate: BatchNorm1d/2d
# scales ~ N(1, 0.02) with zero bias, other weight/bias/root tensors ~ N(0, 0.02).
# This overrides the Xavier/zero init from the previous cell for the decoder only.
# `_ =` discards the returned module so the cell does not print its repr.
_ = vae.decoder.apply(weight_initiate)

In [11]:
import inspect

# rand_data.pt presumably holds a list of torch_geometric Data objects (it is fed
# to a PyG DataLoader below), i.e. arbitrary pickled objects rather than plain
# tensors. torch>=2.6 defaults to weights_only=True and would refuse them, so opt
# out explicitly when the argument exists (torch>=1.13). Older versions have no
# such argument and always unpickle.
# Only load trusted, locally generated files this way: unpickling can execute code.
_load_kwargs = (
    {'weights_only': False}
    if 'weights_only' in inspect.signature(torch.load).parameters
    else {}
)
real_data = torch.load('rand_data.pt', **_load_kwargs)

In [12]:
from torch_geometric.loader import DataLoader

batch_size = 64
# Shuffled mini-batches of the training graphs. follow_batch=['edge_index'] makes
# PyG also build an `edge_index_batch` vector that maps each edge to its graph.
data_loader = DataLoader(
    real_data,
    batch_size=batch_size,
    shuffle=True,
    follow_batch=['edge_index'],
)


In [None]:
# Run RGVAE's own training loop over data_loader for the epochs in range(200).
# NOTE(review): verbose_step / save_step presumably set the logging and checkpoint
# intervals, and beta=0.1 a loss weight — confirm against RGVAE.train.
# NOTE(review): RGVAE.train(...) appears to override nn.Module.train(mode), so
# vae.train()/vae.eval() should not be used to switch modes (nn.Module.eval()
# calls self.train(False)). The sampling cell calls eval() on sub-modules instead.
vae.train(data_loader, epoch=range(200), beta=0.1, verbose_step=50, save_step=1000)

In [17]:
# Pickle the whole model object (not just a state_dict) into the run folder.
# Create the folder first so the save cannot fail just because nothing was
# written to it during training. Reloading this file needs a trusted torch.load
# with weights_only=False on torch>=2.6.
os.makedirs(vae.folder, exist_ok=True)
torch.save(vae, os.path.join(vae.folder, 'ul_vae.pt'))

In [18]:
from evaluate import convert_sample_to_space, graph_set_info

# Raw statistics of the training graphs; the 7-tuple `distributions` is passed
# to kld_evaluation later in the notebook.
(node_cnts, edge_cnts, node_features, degrees,
 dense_edge, connectivities, clustering) = graph_set_info(real_data, return_raw=True)
distributions = (node_cnts, edge_cnts, node_features, degrees,
                 dense_edge, connectivities, clustering)

In [21]:
# Generate data by vae.
# Switch only the sub-modules used for sampling into eval mode.
# NOTE(review): vae.eval() is presumably avoided because RGVAE.train is the
# training loop, and nn.Module.eval() delegates to self.train(False).
_ = vae.encoder.eval()
_ = vae.decoder.eval()
_ = vae.z_mu.eval()
_ = vae.z_sigma.eval()
from model_random_graph import convert_Batch_to_datalist

N_SAMPLE_BATCHES = 160   # 160 batches x 64 graphs = 10,240 generated graphs
SAMPLE_BATCH_SIZE = 64

test_data = []
with torch.no_grad():
    for _batch_idx in range(N_SAMPLE_BATCHES):
        # Draw the noise on the model's device. The multiply with eps below shows
        # that z_mu/z_sigma produce tensors on vae.device; the noise used to be
        # created on the CPU, which is a device mismatch.
        z_rand = torch.randn(SAMPLE_BATCH_SIZE, vae.decoder.in_dim, device=vae.device)
        # NOTE(review): the noise is pushed through the z_mu / z_sigma heads and
        # then reparameterised, rather than decoded directly; confirm this is the
        # intended sampling scheme.
        z_mu = vae.z_mu(z_rand)
        z_lsgms = vae.z_sigma(z_rand)
        # exp(0.5 * z_lsgms): a standard deviation if z_lsgms is a log-variance.
        z_sgm = z_lsgms.mul(0.5).exp_()
        eps = torch.randn_like(z_sgm)
        z = eps * z_sgm + z_mu
        data, _ = vae.decoder(z)
        # Split the decoded batch back into one graph per sample.
        data_list = convert_Batch_to_datalist(data.x, data.edge_index, batch=data.batch,
                                              edge_batch=data.edge_index_batch)
        test_data.extend(data_list)

In [23]:
# NOTE(review): the torch_geometric.utils names below are not used in this cell.
from torch_geometric.utils import dense_to_sparse, to_dense_adj, to_dense_batch
from evaluate import kld_evaluation
# Compare the generated graphs (test_data) with the training graphs (real_data),
# passing in the raw training-set statistics computed earlier (`distributions`).
# The returned dict presumably holds KL-divergence (kl_*) and Wasserstein-distance
# (wd_*) scores per graph statistic, judging by its keys.
kld_evaluation(real_data, test_data, distributions)

{'kl_connectivity': 0.10795093097872549,
 'wd_connectivity': 0.10466167716999164,
 'kl_edge_density': 0.09283738668703742,
 'wd_edge_density': 0.009848578279281248,
 'kl_clustering_coef': 0.4078172542221576,
 'wd_clustering_coef': 0.09955565111308703,
 'kl_avg_degrees': 0.1881074304649374,
 'wd_avg_degrees': 0.18302501576936875,
 'kl_node_features': [0.5568519278479547, 0.7129989260420687],
 'wd_node_features': [0.12277281242492655, 0.14029801737649358]}

In [24]:
# Raw statistics of the generated graphs, in the same order as `distributions`,
# written to UL_VAE_distributions.pt.
(node_cnts_test, edge_cnts_test, node_features_test, degrees_test,
 dense_edge_test, connectivities_test, clustering_test) = graph_set_info(test_data, return_raw=True)
distributions_test = (node_cnts_test, edge_cnts_test, node_features_test, degrees_test,
                      dense_edge_test, connectivities_test, clustering_test)
torch.save(distributions_test, 'UL_VAE_distributions.pt')