In [1]:
import sys

import torch

util_path = '../codes'
sys.path.append(util_path)
from model_graph import RGDiscrimininator, RGULGenerator, RGVAE


In [5]:
# Compute device. GPU index 5 is a machine-specific choice; fall back to the
# first GPU, then to the CPU, so the notebook also runs on other hardware
# instead of failing with "invalid device ordinal".
GPU_INDEX = 5
if torch.cuda.device_count() > GPU_INDEX:
    device = f'cuda:{GPU_INDEX}'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [6]:
# Graph generator (used as the VAE decoder). `node_nums=[3, 4, 8]` presumably
# gives the node count after each unpooling stage, and `node_feature=3` the
# number of features per generated node — TODO confirm in model_graph.
generator = RGULGenerator(
    in_dim=128,
    initial_dim=128,
    hidden_dims=[[64, 64], [64, 128], [256]],
    final_layer_dims=[128, 256],
    hidden_act='relu',
    edge_hidden_dim=8,
    leaky_relu_coef=0.05,
    device=device,
    skip_z=True,
    skip_z_dims=None,
    unpool_bn=False,
    link_bn=False,
    link_act='leaky',
    unpool_para=dict(
        add_additional_link=True,
        add_perference=True,
        roll_bn=False,
        roll_simple=True,
        add_bn=False,
    ),
    attr_bn=False,
    edge_out_dim=None,
    fix_points=[2, 0],
    roll_ones=[0, 0],
    node_nums=[3, 4, 8],
    use_bn=False,
    final_bn=False,
    node_feature=3,
    skipz_bn=False,
)

# Graph discriminator; its in_dim=3 matches the generator's node_feature=3.
discriminator = RGDiscrimininator(
    in_dim=3,
    hidden_dim=[128, 256],
    lin_hidden_dim=128,
    out_hidden_dim=[128, 256],
    device=device,
    useBN=False,
    droprate=None,
    outBN=False,
    out_drop=None,
    final_layers=2,
    conv_layers=2,
    last_act='linear',
    relu_coef=0.05,
    outdim=128,
)

# VAE wrapper. max_num_nodes=8 equals the last entry of the generator's
# node_nums, and batch_size=64 is the batch size the training DataLoader uses.
vae = RGVAE(
    generator,
    discriminator,
    128,
    128,
    lr=2e-4,
    beta=(0.5, 0.999),
    g_prob=True,
    permutation=False,
    max_num_nodes=8,
    folder='vae_protein_1002',
    device=device,
    lambda_rl=1e-1,
    beta_node=2.0,
    beta_edge=3.0,
    beta_edge_total=1,
    beta_node_degree=1,
    beta_node_feature=1,
    batch_size=64,
)
vae = vae.to(device)

In [10]:
from torch import nn
# Baseline initialisation of every VAE parameter:
#   * parameters with fewer than 2 dims (biases, 1-D scales, any 0-D scalars) -> 0
#   * matrices / higher-rank weights -> Xavier-normal
# Uses the in-place `constant_` / `xavier_normal_`; the un-suffixed versions are
# deprecated aliases that emit a warning for every parameter.
# NOTE(review): this also zeroes 1-D *scale* parameters (e.g. BatchNorm/LayerNorm
# weights), which silences those layers until trained — the config above
# disables BN, but confirm no other norm layers are present.
for param in vae.parameters():
    if param.dim() < 2:
        # Xavier needs fan-in/fan-out, which is undefined below 2 dimensions.
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

def weight_initiate(m):
    """Re-initialise the parameters of one module with low-variance random values.

    Meant to be passed to ``nn.Module.apply`` so it visits every submodule.

    * ``BatchNorm1d`` / ``BatchNorm2d`` (including subclasses): scale ~ N(1.0, 0.02),
      shift = 0.
    * Any other module: its ``weight``, ``bias`` and ``root`` tensors, when present
      and not ``None``, are drawn from N(0.0, 0.02).

    Attributes that are ``None`` (e.g. ``bias=False`` layers or ``affine=False``
    BatchNorm) or that are not tensors (e.g. a ``root`` submodule, which ``apply``
    visits on its own) are skipped.

    Args:
        m: the ``nn.Module`` currently being visited.
    """
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
        return

    # `root` presumably is the separate root-node weight some graph convolutions
    # keep alongside `weight` — TODO confirm for the layers used in model_graph.
    # (The previous version re-initialised `bias` here instead of `root`.)
    for attr_name in ("weight", "bias", "root"):
        tensor = getattr(m, attr_name, None)
        if isinstance(tensor, torch.Tensor):
            nn.init.normal_(tensor.data, 0.0, 0.02)


# Re-initialise only the decoder (presumably the generator passed to RGVAE):
# Module.apply visits every submodule, so weight_initiate overwrites the
# Xavier/zero values set above for each `weight`/`bias`/`root` tensor it finds.
# Assigning to `_` suppresses the module repr in the cell output.
_ = vae.decoder.apply(weight_initiate)

  nn.init.xavier_normal(param)
  nn.init.constant(param, 0)


In [11]:
import torch
from torch_geometric.loader import DataLoader
# Training graphs, presumably a list of torch_geometric Data objects — confirm.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# On PyTorch >= 2.6 the default weights_only=True rejects pickled graph objects,
# so `weights_only=False` may be needed there.
real_data = torch.load('protein_train.pt')
# Must match the batch_size RGVAE was constructed with (64). This was previously
# set to 128 while the loader silently used a hard-coded 64.
batch_size = 64
# follow_batch=['edge_index'] makes each batch carry an `edge_index_batch`
# vector mapping every edge to its graph.
data_loader = DataLoader(real_data, batch_size=batch_size, shuffle=True, follow_batch=['edge_index'])

In [None]:
# RGVAE.train is the project's training loop (it takes the DataLoader), not
# nn.Module.train(mode). `epoch=range(200)` presumably means 200 epochs, and
# `beta=0` presumably weights the KL term (0 disables it) — TODO confirm.
vae.train(
    data_loader,
    epoch=range(200),
    beta=0,
    verbose_step=50,
    save_step=1000,
)

In [20]:
# Pickles the whole RGVAE object (class references + weights). Loading it back
# requires the project modules (e.g. model_graph) to be importable under the same
# names. NOTE(review): saving vae.state_dict() would survive code refactors better.
torch.save(vae, 'ULVAE_protein.pt')

In [21]:
# Put the sub-networks used for sampling into inference mode one by one.
# vae.train is the project's training loop (called with a DataLoader above), so
# vae.eval() — which calls self.train(False) — is presumably unsafe here; hence
# the per-submodule calls.
_ = vae.encoder.eval()
_ = vae.decoder.eval()
_ = vae.z_mu.eval()
_ = vae.z_sigma.eval()
from model_random_graph import convert_Batch_to_datalist

N_SAMPLE_BATCHES = 160   # 160 batches x 64 graphs = 10,240 generated graphs
SAMPLE_BATCH_SIZE = 64

test_data = []
with torch.no_grad():
    for _ in range(N_SAMPLE_BATCHES):
        # NOTE(review): N(0, I) noise is passed through the encoder's mu /
        # log-variance heads before decoding, instead of being decoded directly.
        # A previous (commented-out) variant rescaled the noise as
        #   z_rand * tot_h_data.std(axis=0).unsqueeze(0) + tot_h_data.mean(axis=0).unsqueeze(0)
        # — confirm which sampling scheme is intended.
        # The noise is created on the model's device: a CPU tensor fed to the
        # z_mu / z_sigma layers (moved to `device` earlier) raises a device mismatch.
        z_rand = torch.randn(SAMPLE_BATCH_SIZE, vae.decoder.in_dim, device=vae.device)
        z_mu = vae.z_mu(z_rand)
        z_lsgms = vae.z_sigma(z_rand)       # treated as log-variance
        z_sgm = z_lsgms.mul(0.5).exp_()     # std = exp(0.5 * log-variance)
        eps = torch.randn_like(z_sgm)
        z = eps * z_sgm + z_mu              # reparameterised latent sample
        data, _ = vae.decoder(z)
        data_list = convert_Batch_to_datalist(
            data.x, data.edge_index, batch=data.batch, edge_batch=data.edge_index_batch
        )
        test_data.extend(data_list)


In [23]:
from torch_geometric.utils import dense_to_sparse, to_dense_adj, to_dense_batch
from evaluate import kld_evaluation
# Compare the generated graphs (test_data) with the training graphs (real_data).
# The returned dict reports `kl_*` / `wd_*` values (presumably KL divergence and
# Wasserstein distance) for each graph statistic.
# NOTE(review): `distributions` is not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel; it must come from a deleted or
# out-of-order cell. Judging by the output keys, it presumably selects the
# statistics (connectivity, edge_density, clustering_coef, avg_degrees,
# node_features). TODO: define it explicitly and confirm against
# evaluate.kld_evaluation.
# NOTE(review): the torch_geometric.utils helpers imported above are unused here.
kld_evaluation(real_data, test_data, distributions)

{'kl_connectivity': 0.7909808086944171,
 'wd_connectivity': 0.37333691406249997,
 'kl_edge_density': 0.4917165139725243,
 'wd_edge_density': 0.040913831208881576,
 'kl_clustering_coef': 0.8893497330796853,
 'wd_clustering_coef': 0.07206787060424498,
 'kl_avg_degrees': 0.4917165139725243,
 'wd_avg_degrees': 0.3273106496710526,
 'kl_node_features': [0.1054675920273398,
  0.1684474793322888,
  0.26862284392368063],
 'wd_node_features': [2.7561263463979624,
  7.0281805451755375,
  3.2462799598981746]}