In [1]:
import sys

# Folder added to the module search path; presumably where the project-local
# modules imported in later cells (util_gnn, trainer, ...) live — confirm.
code_folder = '../codes'

# Guard the append so that re-running this cell does not stack duplicate
# entries onto sys.path.
if code_folder not in sys.path:
    sys.path.append(code_folder)

In [2]:
### Import packages
import numpy as np
from rdkit import RDLogger    
from torch_geometric.data import Data
# Turn off RDKit log output for every logger matching 'rdApp.*' so that
# RDKit messages do not flood the notebook output.
RDLogger.DisableLog('rdApp.*')
from util_gnn import draw_graph


In [3]:
from torch_geometric.data import DataLoader
import torch

In [4]:
### Load data
# Both files are read relative to the current working directory.
# NOTE(review): torch.load unpickles arbitrary objects — only load files you trust.

# Passed to train.train(mol_data=...) further down.
mol_data = torch.load('zinc_smiles_noar.pt')

# Collection of graph objects; filtered by node count and wrapped in a
# DataLoader in the next cell.
data_list = torch.load('zinc_data_noar.pt')

# Mini-batch size used for the DataLoader.
batch_size = 64



In [6]:
# Train only on graphs that have between 11 and 36 nodes (inclusive).
# Every kept graph is rebuilt as a fresh Data object carrying just
# x, edge_index, edge_attr and y.
data_list = [
    Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, y=graph.y)
    for graph in data_list
    if 11 <= len(graph.x) <= 36
]

# Shuffled mini-batches over the filtered graphs.
data_loader = DataLoader(
    data_list,
    batch_size=batch_size,
    shuffle=True,
    follow_batch=['edge_index', 'y'],
)

In [7]:
# Fraction of graphs that survived the node-count filter.
# NOTE(review): 249456 is presumably the size of the unfiltered dataset — confirm.
len(data_list)/249456

0.9982842665640433

In [8]:
from gcn_model_sim_summ_z2 import *
from ugcn_model_summ_2 import pre_GCNModel_edge_3eos


In [9]:
# Pick the compute device.
# GPU index 2 is preferred (the training log in this notebook was produced on
# cuda:2), but it is only selected when the machine actually exposes at least
# three GPUs. Otherwise fall back to the first GPU, or to the CPU when CUDA is
# not available at all.
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
else:
    device = torch.device("cuda:0")


In [11]:
# Build discriminator
# Main discriminator network; it is handed to GANTrainer as `d` below.
# in_dim=15 / edge_dim=3 presumably are the node / edge feature widths
# (train.train is called with use_data_x=15, use_data_edgeattr=3) — confirm.
gcn_model = pre_GCNModel_edge_3eosZv4(
    in_dim=15,
    hidden_dim=128,
    edge_dim=3,
    edge_hidden_dim=32,
    lin_hidden_dim=128,
    out_hidden_dim=256,
    device=device,
    check_batch=None,
    useBN=True,
    droprate=0.3,
    pool_method='trivial',
    add_edge_link=False,
    add_conv_link=True,
    outBN=True,
    out_drop=0.3,
    out_divide=4.0,
    add_edge_agg=False,
    real_trivial=False,
    final_layers=2,
    add_trivial_feature=False,
    ln_comp=False,
    without_ar=True,
).to(device)

# Additional "trivial" discriminator (real_trivial=True,
# add_trivial_feature=True); it is handed to GANTrainer as `d_add` below.
# Configured with no batch norm (useBN=False, outBN=False).
d_add = pre_GCNModel_edge_3eosZ(
    in_dim=15,
    hidden_dim=64,
    edge_dim=3,
    edge_hidden_dim=32,
    lin_hidden_dim=64,
    out_hidden_dim=128,
    device=device,
    check_batch=None,
    useBN=False,
    droprate=0.3,
    pool_method='trivial',
    add_edge_link=False,
    add_conv_link=False,
    outBN=False,
    out_drop=0.3,
    out_divide=4.0,
    add_edge_agg=False,
    real_trivial=True,
    final_layers=2,
    add_trivial_feature=True,
    ln_comp=False,
    without_ar=True,
).to(device)

# Override set after construction: no means, just sums.
d_add.e1_ind = 0.0

In [12]:
# Build generator
# Unpooling-based generator; it is handed to GANTrainer as `g` below.
# in_dim=128 equals the rand_dim=128 given to the trainer — presumably the
# width of the random input vector (confirm).
generator = UnpoolGeneratorZ(
    in_dim=128,
    edge_dim=3,
    node_dim=15,
    node_hidden_dim=[32, 32, 64, 64, 64, 128, 128, 128, 256],
    edge_hidden_dim=32,
    use_x_bn=True,
    use_e_bn=True,
    unpool_bn=True,
    link_bn=True,
    attr_bn=True,
    skip_z=True,
    skip_zdim=None,
    conv_type='nn',
    device=device,
    last_act='leaky',
    link_act='leaky',
    unpool_type='edge',
    # Key spelling ('add_perference') is what the library expects; keep as-is.
    unpool_para=dict(add_perference=True, roll_bn=False, roll_simple=True),
    without_ar=True,
).to(device)

In [17]:
from trainer import GANTrainer


In [19]:
# Assemble the GAN trainer from the discriminator (gcn_model), the generator
# and the additional trivial discriminator (d_add).
# - batch_size reuses the notebook-level `batch_size` (64) that the DataLoader
#   was built with, so the two values cannot drift apart.
# - rand_dim=128 equals the generator's in_dim; presumably they must match — confirm.
# - use_loss='wgan' selects the loss; the training log reports a "GP" term.
train = GANTrainer(
    d=gcn_model,
    g=generator,
    rand_dim=128,
    train_folder='ULGAN_ZINC',
    tot_epoch_num=100,
    eval_iter_num=1000,
    batch_size=batch_size,
    device=device,
    d_add=d_add,
    learning_rate_g=1e-3,
    learning_rate_d=2e-4,
    lambda_g=10.0,
    max_train_G=2,
    tresh_add_trainG=0.2,
    use_loss='wgan',
    g_out_prob=True,
    lambda_rl=2e-3,
    lambda_nonodes=0.,
    lambda_noedges=0.,
    zinc=True,
    without_ar=True,
)

In [None]:
# Run training on the filtered ZINC loader.
# use_data_x / use_data_edgeattr equal the node (15) and edge (3) dimensions
# the models were built with.
# NOTE(review): NN=200 appears as the epoch total in the log ("[1/200]") even
# though the trainer was built with tot_epoch_num=100 — confirm which applies.
train.train(
    data_loader,
    verbose=False,
    use_data_x=15,
    use_data_edgeattr=3,
    evaluate_num=1000,
    mol_data=mol_data,
    alter_trainer=True,
    NN=200,
    reinforce_acclerate=True,
)

[1/200][0/3892]	G Loss: 0.5876;D Loss: -1.3593; GP: 0.2405
now, we train G 2 times with (prob fake = -0.588, prob real = 0.775)
Mean x/edge attr:  tensor([0.7179, 0.0983, 0.1393, 0.0273, 0.0172, 0.0000, 0.0000, 0.0000, 0.0000,
        0.9828, 0.0172, 0.0000, 0.9777, 0.0036, 0.0187], device='cuda:2',
       grad_fn=<MeanBackward1>) tensor([0.7482, 0.2518, 0.0000], device='cuda:2', grad_fn=<MeanBackward1>) tensor(43.8750, device='cuda:2', grad_fn=<DivBackward0>)
size x/ some distribution: 21.765625 tensor([25., 23., 24., 24., 24., 21., 27., 16., 22., 24., 23., 26., 26., 23.,
        21., 24., 17., 23., 23., 23., 24., 20., 28., 14., 23., 23., 26., 28.,
        20., 19., 23., 22., 22., 18., 23., 24., 20., 25., 20., 26., 24., 15.,
        16., 23., 21., 19., 20., 24., 23., 20., 24., 20., 15., 20., 18., 23.,
        20., 17., 16., 13., 25., 21., 22., 27.], device='cuda:2')
Sample prob: -24.980266571044922
Validation, uniqueness, novelty:  (0.625, 1.0, 1.0)
[1/200][100/3892]	G Loss: 0.3311;D 

In [25]:
import os

# Select the best generator checkpoint recorded during training.
# Each entry of train.evals is indexed so that its first element is compared
# with the threshold; the training log prints "Validation, uniqueness,
# novelty" triples, so entries are presumably such triples — confirm.
min_validity = 0.9

# (evaluation index, metrics) pairs whose first metric exceeds the threshold.
candidate_evals = [
    (eval_idx, metrics)
    for eval_idx, metrics in enumerate(train.evals)
    if metrics[0] > min_validity
]
if not candidate_evals:
    # Without this guard np.argmax fails on an empty list with an opaque message.
    raise ValueError(
        f'No evaluation in train.evals has a first metric above {min_validity}; '
        'cannot choose a generator checkpoint.'
    )

# Rank candidates by the product of all their metrics; np.argmax returns the
# first index among ties.
candidate_metrics = [np.prod(metrics) for _, metrics in candidate_evals]
choose_model = candidate_evals[np.argmax(candidate_metrics)][0]

# Reload the chosen checkpoint from the trainer's folder and store it under a
# fixed file name.
# NOTE(review): this unpickles a whole model object. On PyTorch releases where
# torch.load defaults to weights_only=True it needs weights_only=False —
# confirm against the installed version, and only load trusted files.
generator = torch.load(os.path.join(train.folder, f'generator_{choose_model}.pt'))
torch.save(generator, 'ul_gan_zinc.pt')