In [1]:
import copy
from dgl import batch, unbatch
import rdkit.Chem as Chem
from rdkit import RDLogger
import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
np.random.seed(0)

import os,sys,inspect
sys.path.insert(0,'/home/icarus/app/src') 

from models.jtnn_vae import DGLJTNNVAE
from models.modules import *

#from .nnutils import cuda, move_dgl_to_cuda

#from .jtnn_enc import DGLJTNNEncoder
#from .jtnn_dec import DGLJTNNDecoder
#from .mpn import DGLMPN
#from .mpn import mol2dgl_single as mol2dgl_enc
#from .jtmpn import DGLJTMPN
#from .jtmpn import mol2dgl_single as mol2dgl_dec

import os,sys,inspect
from tqdm import tqdm
sys.path.insert(0,'/home/icarus/T-CVAE-MolGen/src')

from utils.chem import set_atommap, copy_edit_mol, enum_assemble_nx, \
                            attach_mols_nx, decode_stereo

lg = RDLogger.logger()

lg.setLevel(RDLogger.CRITICAL)



In [2]:
class TCVAE(nn.Module):
    def __init__(self, vocab, hidden_size, latent_size, depth):
        super(TCVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth
        
        self.embedding = nn.Embedding(vocab.size(), hidden_size).cuda()
        self.mpn = MPN(hidden_size, depth)
        self.encoder = None  #
        self.decoder = None  #
        
        self.T_mean, self.T_var = nn.Linear(hidden_size, latent_size // 2), \
                                  nn.Linear(hidden_size, latent_size // 2)
        self.G_mean, self.G_var = nn.Linear(hidden_size, latent_size // 2), \
                                  nn.Linear(hidden_size, latent_size // 2)
            
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

## Posterior

As in <cite data-cite="7333468/6Y976JUQ"></cite>

$q(z|x,y) \sim N(\mu,\sigma^2I)$, where

$\quad\quad h = \text{MultiHead}(c,E_\text{out}^L(x;y),E_\text{out}^L(x;y))$

$\quad\quad \begin{bmatrix}\mu\\\log(\sigma^2)\end{bmatrix} = hW_q+b_q$

In [3]:
    def sample_posterior(self, prob_decode=False):
        return

### Prior

As in <cite data-cite="7333468/6Y976JUQ"></cite>

$p_\theta (z|x) \sim N(\mu', \sigma'^2 I)$, where:

$\quad\quad h' = \text{MultiHead}(c, E_\text{out}^L(x), E_\text{out}^L(x))$

$\quad\quad \begin{bmatrix}\mu'\\\log(\sigma'^2)\end{bmatrix} = MLP_p(h')$

In [4]:
    def sample_prior(self, prob_decode=False):
        return

In [5]:
from utils.data import JTNNDataset, JTNNCollator
#torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

class ArgsTemp():
    def __init__(self, hidden_size, depth, device):
        self.hidden_size = hidden_size
        self.batch_size = 350
        self.latent_size = 56
        self.depth = depth
        self.device = device
        self.lr = 1e-3
        self.beta = 1.0
        self.use_cuda = torch.cuda.is_available()
        
args = ArgsTemp(200,3, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
print(args.depth)

dataset = JTNNDataset(data='train', vocab='vocab', training=True, intermediates=False)
vocab = dataset.vocab
"""
dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=1,
        collate_fn=JTNNCollator(vocab, True, intermediates=False),
        drop_last=True,
        worker_init_fn=None)
        """

model = DGLJTNNVAE(vocab, args.hidden_size, args.latent_size, args.depth, args).cuda()
model.share_memory()
#if torch.cuda.device_count() > 1:
  #print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  #model = nn.DataParallel(model)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)


3
Loading data...
Loading finished.
	Num samples: 220011
	Vocab file: /home/icarus/.dgl/jtnn/vocab.txt
Initializing Embedding
Initializing MPN
Initializing JTNN
Initializing Decoder
Initializing JTMPN


In [6]:
MAX_EPOCH = 50
PRINT_ITER = 20

from tqdm import tqdm
from os import access, R_OK
from os.path import isdir
import sys

save_path = '/home/icarus/app/data/05_model_output'
assert isdir(save_path) and access(save_path, R_OK), \
       "File {} doesn't exist or isn't readable".format(save_path)

def train():
    dataset.training = True
    print("Loading data...")
    dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=12,
            collate_fn=JTNNCollator(vocab, True),
            drop_last=True,
            worker_init_fn=None)
    dataloader._use_shared_memory = False
    last_loss = sys.maxsize
    print("Beginning Training...")
    for epoch in range(MAX_EPOCH):
        word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
        print("Epoch %d: " % epoch)

        for it, batch in tqdm(enumerate(dataloader),total=len(dataloader)):
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, args.beta)
            except:
                print([t.smiles for t in batch['mol_trees']])
                raise
            loss.backward()
            optimizer.step()

            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            cur_loss = loss.item()
            
            if (it + 1) % PRINT_ITER == 0:
                word_acc = word_acc / PRINT_ITER * 100
                topo_acc = topo_acc / PRINT_ITER * 100
                assm_acc = assm_acc / PRINT_ITER * 100
                steo_acc = steo_acc / PRINT_ITER * 100

                print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f, Delta: %.6f" % (
                    kl_div, word_acc, topo_acc, assm_acc, steo_acc, cur_loss, last_loss - cur_loss))
                word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
                sys.stdout.flush()

            if (it + 1) % 1500 == 0: #Fast annealing
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_lr()[0])
                
            if (it + 1) % 100 == 0:
                torch.save(model.state_dict(),
                            save_path + "/model.iter-%d-%d" % (epoch, it + 1))
                      
            #if last_loss - cur_loss < 1e-5:
            #    break
            last_loss = cur_loss

        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_lr()[0])
        torch.save(model.state_dict(), save_path + "/model.iter-" + str(epoch))

In [7]:
import warnings; warnings.simplefilter('ignore')
train() 

Loading data...
Beginning Training...
Epoch 0: 


  3%|▎         | 19/628 [03:48<1:09:03,  6.80s/it]

KL: 0.1, Word: 20.00, Topo: 73.63, Assm: 31.93, Steo: 27.29, Loss: 66.690994, Delta: -1.021431


  6%|▌         | 39/628 [07:12<3:45:28, 22.97s/it]

KL: 0.8, Word: 30.11, Topo: 84.75, Assm: 41.35, Steo: 31.33, Loss: 56.042809, Delta: 0.124046


  9%|▉         | 59/628 [09:03<36:45,  3.88s/it]  

KL: 0.8, Word: 39.18, Topo: 88.87, Assm: 50.81, Steo: 31.33, Loss: 50.798950, Delta: -1.014355


 13%|█▎        | 79/628 [12:19<1:07:33,  7.38s/it]

KL: 0.8, Word: 46.40, Topo: 90.58, Assm: 55.31, Steo: 29.39, Loss: 46.382809, Delta: -0.085739


 16%|█▌        | 99/628 [15:15<1:41:25, 11.50s/it]

KL: 1.0, Word: 49.99, Topo: 91.54, Assm: 56.65, Steo: 30.85, Loss: 44.333611, Delta: -1.550499


 19%|█▉        | 119/628 [17:46<51:19,  6.05s/it]  

KL: 1.0, Word: 52.39, Topo: 92.13, Assm: 59.03, Steo: 31.44, Loss: 41.728252, Delta: -0.834881


 22%|██▏       | 139/628 [20:44<1:45:41, 12.97s/it]

KL: 1.3, Word: 53.94, Topo: 92.60, Assm: 60.64, Steo: 30.58, Loss: 40.135433, Delta: -1.060310


 25%|██▌       | 159/628 [23:16<1:11:19,  9.12s/it]

KL: 1.6, Word: 54.54, Topo: 92.69, Assm: 62.43, Steo: 35.52, Loss: 38.698708, Delta: -0.634472


 29%|██▊       | 179/628 [26:52<1:15:14, 10.05s/it]

KL: 1.2, Word: 55.20, Topo: 92.90, Assm: 64.17, Steo: 32.22, Loss: 37.291447, Delta: -0.133656


 32%|███▏      | 199/628 [29:37<2:20:24, 19.64s/it]

KL: 1.2, Word: 55.59, Topo: 93.08, Assm: 65.56, Steo: 33.08, Loss: 36.709793, Delta: 0.476963


 35%|███▍      | 219/628 [31:48<35:15,  5.17s/it]  

KL: 1.2, Word: 55.87, Topo: 93.15, Assm: 66.34, Steo: 28.65, Loss: 35.316864, Delta: 1.063980


 38%|███▊      | 239/628 [34:53<51:43,  7.98s/it]  

KL: 1.1, Word: 56.34, Topo: 93.12, Assm: 67.42, Steo: 30.27, Loss: 36.079334, Delta: -0.106758


 41%|████      | 259/628 [37:54<1:19:06, 12.86s/it]

KL: 0.9, Word: 56.42, Topo: 93.33, Assm: 67.24, Steo: 28.87, Loss: 34.097462, Delta: 0.845039


 44%|████▍     | 279/628 [40:15<29:04,  5.00s/it]  

KL: 1.0, Word: 56.57, Topo: 93.38, Assm: 68.27, Steo: 28.33, Loss: 34.405842, Delta: 0.366703


 48%|████▊     | 299/628 [44:08<43:45,  7.98s/it]  

KL: 1.0, Word: 56.64, Topo: 93.46, Assm: 69.54, Steo: 29.77, Loss: 35.039536, Delta: -0.842480


 51%|█████     | 319/628 [47:14<1:07:50, 13.17s/it]

KL: 1.0, Word: 56.79, Topo: 93.55, Assm: 70.03, Steo: 30.50, Loss: 33.892181, Delta: 0.339996


 54%|█████▍    | 339/628 [49:54<22:52,  4.75s/it]  

KL: 0.9, Word: 57.18, Topo: 93.58, Assm: 69.48, Steo: 27.05, Loss: 34.099480, Delta: -1.080658


 57%|█████▋    | 359/628 [52:54<35:50,  8.00s/it]  

KL: 0.8, Word: 57.01, Topo: 93.70, Assm: 70.01, Steo: 27.54, Loss: 33.953522, Delta: -0.518978


 60%|██████    | 379/628 [55:52<1:11:58, 17.34s/it]

KL: 0.9, Word: 57.33, Topo: 93.49, Assm: 70.54, Steo: 28.90, Loss: 33.075787, Delta: 0.598927


 64%|██████▎   | 399/628 [57:54<17:29,  4.58s/it]  

KL: 1.0, Word: 57.65, Topo: 93.57, Assm: 71.30, Steo: 27.04, Loss: 33.569759, Delta: -0.515316


 67%|██████▋   | 419/628 [1:00:55<27:12,  7.81s/it]  

KL: 0.9, Word: 57.70, Topo: 93.63, Assm: 71.35, Steo: 28.55, Loss: 33.741669, Delta: -0.092270


 70%|██████▉   | 439/628 [1:04:12<1:13:04, 23.20s/it]

KL: 1.0, Word: 57.97, Topo: 93.71, Assm: 70.95, Steo: 27.92, Loss: 33.383160, Delta: -0.161720


 73%|███████▎  | 459/628 [1:06:21<13:22,  4.75s/it]  

KL: 0.9, Word: 57.83, Topo: 93.81, Assm: 71.45, Steo: 26.80, Loss: 32.170933, Delta: 0.711266


 76%|███████▋  | 479/628 [1:09:24<18:31,  7.46s/it]

KL: 0.8, Word: 57.84, Topo: 93.86, Assm: 71.87, Steo: 26.21, Loss: 32.917568, Delta: -0.474979


 79%|███████▉  | 499/628 [1:12:16<44:46, 20.82s/it]

KL: 1.0, Word: 58.21, Topo: 93.83, Assm: 72.05, Steo: 27.97, Loss: 32.481724, Delta: -0.016621


 83%|████████▎ | 519/628 [1:14:58<09:30,  5.23s/it]

KL: 0.9, Word: 58.13, Topo: 93.89, Assm: 72.05, Steo: 28.90, Loss: 32.583691, Delta: 0.064186


 86%|████████▌ | 539/628 [1:17:44<10:45,  7.25s/it]

KL: 1.0, Word: 57.99, Topo: 93.86, Assm: 72.69, Steo: 27.35, Loss: 32.085445, Delta: 0.615997


 89%|████████▉ | 559/628 [1:21:03<18:37, 16.19s/it]

KL: 1.1, Word: 58.29, Topo: 93.93, Assm: 72.94, Steo: 29.02, Loss: 32.583992, Delta: -0.151691


 92%|█████████▏| 579/628 [1:23:07<03:50,  4.71s/it]

KL: 0.9, Word: 58.22, Topo: 93.92, Assm: 73.01, Steo: 30.30, Loss: 32.009804, Delta: 1.207069


 95%|█████████▌| 599/628 [1:25:55<03:26,  7.12s/it]

KL: 0.8, Word: 58.52, Topo: 94.02, Assm: 73.07, Steo: 28.29, Loss: 32.172657, Delta: 0.923794


 99%|█████████▊| 619/628 [1:29:17<00:21,  2.44s/it]

KL: 1.0, Word: 58.56, Topo: 94.10, Assm: 73.15, Steo: 23.38, Loss: 31.393181, Delta: 1.087078


100%|██████████| 628/628 [1:29:32<00:00,  8.56s/it]

learning rate: 0.000900
Epoch 1: 



  3%|▎         | 19/628 [04:07<1:06:38,  6.57s/it]

KL: 1.0, Word: 58.58, Topo: 94.11, Assm: 74.28, Steo: 26.69, Loss: 31.110888, Delta: 0.696451


  6%|▌         | 39/628 [07:10<2:16:46, 13.93s/it]

KL: 0.9, Word: 58.71, Topo: 94.08, Assm: 73.73, Steo: 25.27, Loss: 31.473078, Delta: 0.733130


  9%|▉         | 59/628 [09:27<47:58,  5.06s/it]  

KL: 1.0, Word: 58.78, Topo: 94.06, Assm: 74.18, Steo: 26.60, Loss: 31.829498, Delta: -0.807379


 13%|█▎        | 79/628 [12:33<55:29,  6.06s/it]  

KL: 0.9, Word: 58.58, Topo: 94.06, Assm: 74.25, Steo: 30.52, Loss: 31.033461, Delta: 0.226658


 16%|█▌        | 99/628 [15:41<1:45:39, 11.98s/it]

KL: 0.9, Word: 58.65, Topo: 94.16, Assm: 74.24, Steo: 28.00, Loss: 31.649208, Delta: 0.070528


 19%|█▉        | 119/628 [18:08<38:30,  4.54s/it]  

KL: 1.0, Word: 58.70, Topo: 94.26, Assm: 74.69, Steo: 23.40, Loss: 31.303322, Delta: -0.409109


 22%|██▏       | 139/628 [21:03<41:31,  5.09s/it]  

KL: 0.9, Word: 58.70, Topo: 94.19, Assm: 74.41, Steo: 28.72, Loss: 31.232355, Delta: 0.142033


 25%|██▌       | 159/628 [24:16<1:37:55, 12.53s/it]

KL: 1.0, Word: 58.87, Topo: 94.17, Assm: 74.44, Steo: 29.54, Loss: 31.542109, Delta: -0.638308


 29%|██▊       | 179/628 [26:18<30:29,  4.08s/it]  

KL: 1.0, Word: 58.89, Topo: 94.19, Assm: 74.91, Steo: 25.80, Loss: 31.186951, Delta: 0.386005


 32%|███▏      | 199/628 [29:16<43:47,  6.12s/it]  

KL: 1.0, Word: 58.78, Topo: 94.08, Assm: 74.79, Steo: 27.81, Loss: 31.744499, Delta: -0.023743


 35%|███▍      | 219/628 [32:15<1:00:03,  8.81s/it]

KL: 1.1, Word: 58.84, Topo: 94.22, Assm: 75.04, Steo: 29.93, Loss: 31.766392, Delta: -0.097967


 38%|███▊      | 239/628 [35:05<28:20,  4.37s/it]  

KL: 1.0, Word: 59.12, Topo: 94.38, Assm: 75.04, Steo: 28.37, Loss: 31.015980, Delta: 0.288607


 41%|████      | 259/628 [37:29<29:03,  4.73s/it]  

KL: 1.0, Word: 59.04, Topo: 94.31, Assm: 75.19, Steo: 27.33, Loss: 30.854376, Delta: 0.071880


 44%|████▍     | 279/628 [40:30<48:02,  8.26s/it]  

KL: 1.0, Word: 59.01, Topo: 94.33, Assm: 75.27, Steo: 26.38, Loss: 31.430140, Delta: -0.112406


 48%|████▊     | 299/628 [43:20<1:03:13, 11.53s/it]

KL: 1.0, Word: 59.18, Topo: 94.28, Assm: 75.53, Steo: 29.25, Loss: 30.943493, Delta: 0.150566


 51%|█████     | 319/628 [45:56<36:41,  7.13s/it]  

KL: 1.0, Word: 59.43, Topo: 94.39, Assm: 75.66, Steo: 28.28, Loss: 31.032104, Delta: 0.212664


 54%|█████▍    | 339/628 [48:38<38:04,  7.90s/it]  

KL: 1.0, Word: 59.08, Topo: 94.28, Assm: 75.77, Steo: 28.39, Loss: 30.939226, Delta: 0.125141


 57%|█████▋    | 359/628 [51:40<43:17,  9.66s/it]  

KL: 0.9, Word: 59.07, Topo: 94.38, Assm: 75.81, Steo: 23.79, Loss: 31.376236, Delta: 0.205757


 60%|██████    | 379/628 [54:25<24:54,  6.00s/it]  

KL: 1.1, Word: 59.26, Topo: 94.35, Assm: 75.86, Steo: 27.45, Loss: 31.755093, Delta: -0.757301


 64%|██████▎   | 399/628 [57:25<51:43, 13.55s/it]  

KL: 1.0, Word: 59.13, Topo: 94.37, Assm: 75.69, Steo: 30.87, Loss: 30.930008, Delta: -0.043879


 67%|██████▋   | 419/628 [59:43<20:20,  5.84s/it]  

KL: 1.1, Word: 59.40, Topo: 94.38, Assm: 75.60, Steo: 29.38, Loss: 31.098349, Delta: -0.763536


 70%|██████▉   | 439/628 [1:02:45<17:47,  5.65s/it]  

KL: 0.9, Word: 59.58, Topo: 94.57, Assm: 76.38, Steo: 28.38, Loss: 30.512894, Delta: 0.447886


 73%|███████▎  | 459/628 [1:06:07<24:04,  8.54s/it]  

KL: 1.0, Word: 59.62, Topo: 94.47, Assm: 75.76, Steo: 31.28, Loss: 30.506632, Delta: 0.177830


 76%|███████▋  | 479/628 [1:08:21<11:11,  4.51s/it]

KL: 1.0, Word: 59.27, Topo: 94.61, Assm: 76.51, Steo: 31.19, Loss: 30.951612, Delta: -0.379509


 79%|███████▉  | 499/628 [1:11:24<14:48,  6.89s/it]

KL: 1.0, Word: 59.55, Topo: 94.50, Assm: 76.38, Steo: 27.25, Loss: 30.451078, Delta: 0.109652


 83%|████████▎ | 519/628 [1:15:24<50:33, 27.83s/it]  

KL: 1.1, Word: 59.42, Topo: 94.45, Assm: 75.89, Steo: 32.75, Loss: 30.701025, Delta: -0.209267


 86%|████████▌ | 539/628 [1:16:55<06:03,  4.09s/it]

KL: 1.0, Word: 59.09, Topo: 94.50, Assm: 76.34, Steo: 29.31, Loss: 31.522182, Delta: -0.604328


 89%|████████▉ | 559/628 [1:20:12<07:57,  6.92s/it]

KL: 1.0, Word: 59.26, Topo: 94.58, Assm: 76.05, Steo: 27.75, Loss: 30.411140, Delta: 0.526186


 92%|█████████▏| 579/628 [1:23:37<04:57,  6.06s/it]

KL: 0.9, Word: 59.64, Topo: 94.42, Assm: 75.68, Steo: 34.01, Loss: 30.988131, Delta: -0.103933


 95%|█████████▌| 599/628 [1:26:05<02:09,  4.48s/it]

KL: 1.0, Word: 59.66, Topo: 94.40, Assm: 76.61, Steo: 30.65, Loss: 30.958740, Delta: 0.171646


 99%|█████████▊| 619/628 [1:29:29<01:01,  6.84s/it]

KL: 1.0, Word: 59.33, Topo: 94.46, Assm: 76.18, Steo: 31.87, Loss: 30.692131, Delta: -0.709757


100%|██████████| 628/628 [1:29:49<00:00,  8.58s/it]

learning rate: 0.000810
Epoch 2: 



  3%|▎         | 19/628 [03:46<50:21,  4.96s/it]   

KL: 1.0, Word: 59.69, Topo: 94.59, Assm: 77.18, Steo: 27.55, Loss: 30.045231, Delta: -0.271477


  6%|▌         | 39/628 [06:58<2:13:42, 13.62s/it]

KL: 1.0, Word: 59.81, Topo: 94.48, Assm: 77.07, Steo: 31.52, Loss: 30.470297, Delta: -0.403978


  9%|▉         | 59/628 [09:12<41:27,  4.37s/it]  

KL: 1.0, Word: 59.56, Topo: 94.59, Assm: 76.60, Steo: 28.09, Loss: 30.651028, Delta: -0.787212


 13%|█▎        | 79/628 [12:30<57:56,  6.33s/it]  

KL: 1.0, Word: 59.60, Topo: 94.62, Assm: 77.41, Steo: 31.31, Loss: 30.323420, Delta: 0.011894


 16%|█▌        | 99/628 [15:36<2:11:52, 14.96s/it]

KL: 1.0, Word: 59.84, Topo: 94.70, Assm: 77.00, Steo: 29.01, Loss: 30.171867, Delta: 0.781565


 19%|█▉        | 119/628 [17:40<41:09,  4.85s/it]  

KL: 1.0, Word: 59.64, Topo: 94.65, Assm: 76.86, Steo: 30.95, Loss: 30.816250, Delta: -0.177923


 22%|██▏       | 139/628 [20:42<51:34,  6.33s/it]  

KL: 1.0, Word: 59.81, Topo: 94.65, Assm: 76.83, Steo: 30.40, Loss: 30.428251, Delta: 0.271229


 25%|██▌       | 159/628 [23:42<1:55:30, 14.78s/it]

KL: 1.0, Word: 59.82, Topo: 94.68, Assm: 77.18, Steo: 33.22, Loss: 30.440168, Delta: -0.401974


 29%|██▊       | 179/628 [26:43<40:39,  5.43s/it]  

KL: 1.0, Word: 59.82, Topo: 94.64, Assm: 77.51, Steo: 32.81, Loss: 30.652216, Delta: -0.061300


 32%|███▏      | 199/628 [29:37<1:01:27,  8.60s/it]

KL: 1.0, Word: 60.09, Topo: 94.67, Assm: 77.61, Steo: 32.47, Loss: 29.974176, Delta: -0.214119


 35%|███▍      | 219/628 [32:17<1:13:16, 10.75s/it]

KL: 1.1, Word: 60.01, Topo: 94.71, Assm: 77.43, Steo: 29.94, Loss: 30.075224, Delta: -0.465319


 38%|███▊      | 239/628 [35:12<31:02,  4.79s/it]  

KL: 1.0, Word: 59.72, Topo: 94.66, Assm: 77.63, Steo: 30.53, Loss: 29.971336, Delta: -0.026598


 41%|████      | 259/628 [38:06<49:40,  8.08s/it]  

KL: 1.0, Word: 59.76, Topo: 94.71, Assm: 77.33, Steo: 26.02, Loss: 30.132479, Delta: -0.169291


 44%|████▍     | 279/628 [40:51<24:19,  4.18s/it]  

KL: 1.0, Word: 59.85, Topo: 94.76, Assm: 76.72, Steo: 30.72, Loss: 30.067556, Delta: -0.369997


 48%|████▊     | 299/628 [43:39<29:00,  5.29s/it]  

KL: 1.0, Word: 59.93, Topo: 94.79, Assm: 77.99, Steo: 31.03, Loss: 29.756643, Delta: 0.331556


 51%|█████     | 319/628 [46:48<49:21,  9.58s/it]  

KL: 1.0, Word: 59.89, Topo: 94.69, Assm: 77.07, Steo: 31.85, Loss: 29.315945, Delta: 0.743355


 54%|█████▍    | 339/628 [49:04<22:15,  4.62s/it]  

KL: 1.0, Word: 59.91, Topo: 94.67, Assm: 77.72, Steo: 32.89, Loss: 29.819572, Delta: 0.629190


 57%|█████▋    | 359/628 [51:46<20:27,  4.56s/it]  

KL: 1.0, Word: 60.16, Topo: 94.81, Assm: 77.67, Steo: 29.17, Loss: 30.100533, Delta: 0.207323


 60%|██████    | 379/628 [55:10<33:15,  8.02s/it]  

KL: 1.0, Word: 60.06, Topo: 94.60, Assm: 78.15, Steo: 29.67, Loss: 30.325207, Delta: -0.158333


 64%|██████▎   | 399/628 [57:47<36:57,  9.68s/it]  

KL: 1.2, Word: 59.65, Topo: 94.80, Assm: 77.82, Steo: 34.33, Loss: 29.802151, Delta: 0.858004


 67%|██████▋   | 419/628 [1:01:33<20:27,  5.87s/it]  

KL: 1.0, Word: 59.72, Topo: 94.63, Assm: 77.75, Steo: 27.10, Loss: 30.468138, Delta: 0.067579


 70%|██████▉   | 439/628 [1:04:04<33:34, 10.66s/it]  

KL: 1.0, Word: 60.16, Topo: 94.71, Assm: 77.82, Steo: 29.57, Loss: 30.542118, Delta: -0.577440


 73%|███████▎  | 459/628 [1:06:02<11:23,  4.04s/it]

KL: 1.1, Word: 60.15, Topo: 94.79, Assm: 78.25, Steo: 31.38, Loss: 29.690935, Delta: 0.736540


 76%|███████▋  | 479/628 [1:09:20<13:25,  5.41s/it]  

KL: 1.0, Word: 59.97, Topo: 94.85, Assm: 78.46, Steo: 29.19, Loss: 29.601782, Delta: -0.148376


 79%|███████▉  | 499/628 [1:12:24<18:13,  8.48s/it]

KL: 1.0, Word: 60.01, Topo: 94.81, Assm: 77.96, Steo: 29.63, Loss: 29.500988, Delta: 0.463432


 83%|████████▎ | 519/628 [1:14:58<08:41,  4.78s/it]

KL: 1.0, Word: 59.92, Topo: 94.92, Assm: 78.32, Steo: 32.04, Loss: 29.824104, Delta: -0.298809


 86%|████████▌ | 539/628 [1:17:55<08:26,  5.69s/it]

KL: 1.0, Word: 60.15, Topo: 94.62, Assm: 77.99, Steo: 33.24, Loss: 29.812492, Delta: -0.109535


 89%|████████▉ | 559/628 [1:21:22<13:41, 11.91s/it]

KL: 1.0, Word: 60.10, Topo: 94.74, Assm: 78.05, Steo: 32.44, Loss: 29.478935, Delta: 0.328014


 92%|█████████▏| 579/628 [1:23:21<03:45,  4.60s/it]

KL: 1.0, Word: 59.81, Topo: 94.86, Assm: 77.77, Steo: 30.31, Loss: 30.095285, Delta: -0.618998


 95%|█████████▌| 599/628 [1:26:32<02:25,  5.03s/it]

KL: 1.0, Word: 60.03, Topo: 94.78, Assm: 78.08, Steo: 30.41, Loss: 29.574100, Delta: 0.001072


 99%|█████████▊| 619/628 [1:30:08<01:20,  8.94s/it]

KL: 1.0, Word: 60.26, Topo: 94.86, Assm: 78.05, Steo: 31.00, Loss: 29.725216, Delta: 0.173376


100%|██████████| 628/628 [1:30:49<00:00,  8.68s/it]


learning rate: 0.000729
Epoch 3: 


  3%|▎         | 19/628 [03:52<54:25,  5.36s/it]   

KL: 1.0, Word: 60.27, Topo: 94.86, Assm: 78.32, Steo: 30.34, Loss: 29.923973, Delta: -0.147261


  6%|▌         | 39/628 [07:05<2:11:24, 13.39s/it]

KL: 1.0, Word: 60.38, Topo: 94.82, Assm: 78.96, Steo: 32.66, Loss: 29.234980, Delta: 0.509592


  9%|▉         | 59/628 [09:39<39:47,  4.20s/it]  

KL: 1.1, Word: 60.33, Topo: 94.87, Assm: 78.39, Steo: 30.26, Loss: 29.692696, Delta: -0.398027


 13%|█▎        | 79/628 [12:35<50:36,  5.53s/it]  

KL: 1.0, Word: 60.24, Topo: 94.93, Assm: 78.15, Steo: 33.96, Loss: 29.349636, Delta: 0.328260


 16%|█▌        | 99/628 [15:47<2:08:53, 14.62s/it]

KL: 1.0, Word: 60.28, Topo: 94.90, Assm: 78.77, Steo: 34.93, Loss: 29.416794, Delta: -0.195383


 19%|█▉        | 119/628 [18:10<35:36,  4.20s/it]  

KL: 0.9, Word: 60.20, Topo: 94.72, Assm: 79.33, Steo: 36.13, Loss: 29.892704, Delta: -0.066954


 22%|██▏       | 139/628 [21:12<45:04,  5.53s/it]  

KL: 1.0, Word: 60.30, Topo: 94.85, Assm: 79.15, Steo: 28.99, Loss: 28.847120, Delta: 0.506290


 25%|██▌       | 159/628 [24:20<1:35:35, 12.23s/it]

KL: 1.0, Word: 60.23, Topo: 94.94, Assm: 79.01, Steo: 30.99, Loss: 30.105276, Delta: -0.232508


 29%|██▊       | 179/628 [26:29<31:46,  4.25s/it]  

KL: 1.0, Word: 60.04, Topo: 94.98, Assm: 79.17, Steo: 34.81, Loss: 29.535795, Delta: -0.508995


 32%|███▏      | 199/628 [29:49<57:31,  8.05s/it]  

KL: 0.9, Word: 60.21, Topo: 94.93, Assm: 78.87, Steo: 31.23, Loss: 29.328407, Delta: 0.251625


 35%|███▍      | 219/628 [33:33<1:48:51, 15.97s/it]

KL: 1.0, Word: 60.17, Topo: 94.93, Assm: 78.99, Steo: 30.29, Loss: 30.045839, Delta: -0.605671


 38%|███▊      | 239/628 [35:06<30:12,  4.66s/it]  

KL: 1.0, Word: 60.33, Topo: 94.94, Assm: 79.23, Steo: 32.51, Loss: 28.824295, Delta: 0.524622


 41%|████      | 259/628 [38:16<35:03,  5.70s/it]  

KL: 1.1, Word: 60.28, Topo: 94.96, Assm: 79.04, Steo: 33.59, Loss: 29.207348, Delta: 0.540924


 44%|████▍     | 279/628 [41:46<1:31:29, 15.73s/it]

KL: 1.0, Word: 60.41, Topo: 94.97, Assm: 79.16, Steo: 33.99, Loss: 29.524364, Delta: 0.362511


 48%|████▊     | 299/628 [43:38<24:10,  4.41s/it]  

KL: 1.1, Word: 60.38, Topo: 94.88, Assm: 79.13, Steo: 33.85, Loss: 29.893641, Delta: -0.720242


 51%|█████     | 319/628 [47:11<33:44,  6.55s/it]  

KL: 1.0, Word: 60.50, Topo: 94.94, Assm: 79.43, Steo: 30.53, Loss: 29.449535, Delta: 0.158014


 54%|█████▍    | 339/628 [49:49<54:31, 11.32s/it]  

KL: 1.0, Word: 60.43, Topo: 94.92, Assm: 78.74, Steo: 32.27, Loss: 29.793627, Delta: -0.876883


 57%|█████▋    | 359/628 [52:14<21:03,  4.70s/it]  

KL: 1.0, Word: 60.29, Topo: 94.96, Assm: 79.31, Steo: 32.04, Loss: 29.033302, Delta: 0.724869


 60%|██████    | 379/628 [55:22<24:17,  5.85s/it]  

KL: 1.0, Word: 60.19, Topo: 94.93, Assm: 79.19, Steo: 29.94, Loss: 29.485163, Delta: -0.249985


 64%|██████▎   | 399/628 [58:26<30:13,  7.92s/it]  

KL: 1.1, Word: 60.53, Topo: 94.98, Assm: 79.10, Steo: 34.23, Loss: 29.265802, Delta: -0.006845


 67%|██████▋   | 419/628 [1:00:45<16:05,  4.62s/it]  

KL: 1.0, Word: 60.53, Topo: 94.93, Assm: 79.39, Steo: 32.84, Loss: 28.971266, Delta: -0.104469


 70%|██████▉   | 439/628 [1:03:44<19:12,  6.10s/it]  

KL: 1.0, Word: 60.64, Topo: 94.96, Assm: 79.32, Steo: 31.91, Loss: 29.127401, Delta: 0.211578


 73%|███████▎  | 459/628 [1:06:39<37:41, 13.38s/it]  

KL: 1.0, Word: 60.34, Topo: 95.02, Assm: 78.69, Steo: 35.37, Loss: 29.978994, Delta: -0.126566


 76%|███████▋  | 479/628 [1:09:00<11:24,  4.59s/it]  

KL: 1.0, Word: 60.07, Topo: 94.97, Assm: 79.08, Steo: 29.21, Loss: 28.637239, Delta: 1.114359


 79%|███████▉  | 499/628 [1:12:03<11:58,  5.57s/it]

KL: 1.0, Word: 60.46, Topo: 95.01, Assm: 79.57, Steo: 33.47, Loss: 29.360977, Delta: 0.064987


 83%|████████▎ | 519/628 [1:14:39<16:46,  9.24s/it]

KL: 1.0, Word: 60.33, Topo: 94.95, Assm: 79.01, Steo: 33.32, Loss: 29.384272, Delta: 0.274626


 86%|████████▌ | 539/628 [1:17:26<07:38,  5.16s/it]

KL: 1.0, Word: 60.36, Topo: 94.98, Assm: 79.25, Steo: 29.30, Loss: 28.986525, Delta: 0.641905


 89%|████████▉ | 559/628 [1:20:29<08:49,  7.68s/it]

KL: 1.0, Word: 60.54, Topo: 94.98, Assm: 79.32, Steo: 32.48, Loss: 28.914799, Delta: 0.986511


 92%|█████████▏| 579/628 [1:23:41<07:40,  9.40s/it]

KL: 1.1, Word: 60.37, Topo: 95.01, Assm: 79.37, Steo: 31.19, Loss: 29.334246, Delta: -0.449234


 95%|█████████▌| 599/628 [1:26:43<03:59,  8.27s/it]

KL: 1.0, Word: 60.62, Topo: 95.07, Assm: 79.54, Steo: 34.35, Loss: 28.823063, Delta: 1.029133


 99%|█████████▊| 619/628 [1:29:04<00:54,  6.07s/it]

KL: 1.0, Word: 60.71, Topo: 95.03, Assm: 79.20, Steo: 29.77, Loss: 29.314081, Delta: -0.550636


100%|██████████| 628/628 [1:29:19<00:00,  8.53s/it]

learning rate: 0.000656
Epoch 4: 



  3%|▎         | 19/628 [03:47<1:08:08,  6.71s/it]

KL: 1.0, Word: 60.88, Topo: 94.98, Assm: 79.77, Steo: 33.88, Loss: 29.309559, Delta: -0.732872


  6%|▌         | 39/628 [07:07<2:18:14, 14.08s/it]

KL: 1.0, Word: 60.53, Topo: 95.07, Assm: 80.45, Steo: 32.42, Loss: 29.310076, Delta: -0.486910


  9%|▉         | 59/628 [09:21<41:11,  4.34s/it]  

KL: 1.0, Word: 60.70, Topo: 95.11, Assm: 79.98, Steo: 33.74, Loss: 29.236412, Delta: -0.020956


 11%|█         | 67/628 [10:39<51:53,  5.55s/it]  

['O=C(COC(=O)C[C@@H]1C[C@H]2CC[C@@H]1C2)Nc1cccc(Cl)c1', 'COc1cc2[nH]c(O)c(C(=O)NCCC[NH+](C)C)c(=O)c2cc1OC', 'Nc1ccnc(CN(Cc2ccncc2)C2CC2)n1', 'NC(=O)c1ccc(/C=C/C(=O)Nc2cccc([N+](=O)[O-])c2)cc1', 'COc1ccc(-c2ocnc2C(=O)Nc2ccc3c(c2)CCC3)cc1', 'Cc1cc(NC(=O)CN2CCN(Cc3cc(-c4ccc(Cl)cc4)no3)CC2)no1', 'O=C([O-])c1ccc(CN2C(=O)S[C@@H](Nc3ccc(C(=O)[O-])cc3)C2=O)cc1', 'CC[C@H]([NH2+][C@@H](CO)c1ccccc1F)c1cccs1', 'Cc1ccc2ncc(C[NH+]3CCC(Nc4ccccc4)CC3)n2c1', 'C[C@@H]1CC[C@H](C(=O)O[C@H](C(F)(F)F)C(C)(C)c2ccccc2)O1', 'CN(CC(=O)NCC[NH+](C)C)S(=O)(=O)c1cc(Cl)ccc1Cl', 'COc1cc(C(=O)N(C2CCCC2)C2CC2)cc(OC)c1OC(F)F', 'COc1ccc(CNC(=O)CSc2nnc3c(=O)n(-c4ccc(F)cc4)ccn23)cc1', 'C[C@H](C(=O)NC(=O)NC12CC3CC(CC(C3)C1)C2)N1CC[NH+](CC2CC2)CC1', 'O=C(CSc1ccc(F)cc1)N[C@H]1CCS(=O)(=O)C1', 'CCCCC1([C@H]2CC(=O)N(c3ccccc3)C2=O)C(=O)N(c2ccccc2)N(c2ccccc2)C1=O', 'CC(C)CNC(=O)C(=O)Nc1ccccc1N(C)Cc1ccccc1', 'CCOc1ccccc1NC(=O)C1CCN(S(=O)(=O)c2cc(-c3noc(C)n3)cs2)CC1', 'O=C(Cn1cnc(-c2ccccc2)cc1=O)Nc1nc2ccccc2s1', 'Cc1cccc(CNS(=O)(=O)

KeyboardInterrupt: 

## References

<div class="cite2c-biblio"></div>

In [None]:
def test():
    dataset.training = False
    dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            collate_fn=JTNNCollator(vocab, False),
            drop_last=True,
            worker_init_fn=None)#worker_init_fn)

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    for it, batch in enumerate(dataloader):
        #print(batch['mol_trees'])
        gt_smiles = batch['mol_trees'][0].smiles
        #print(gt_smiles)
        model.move_to_cuda(batch)
        _, tree_vec, mol_vec = model.encode(batch)
        tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
        smiles = model.decode(tree_vec, mol_vec)
        print(smiles)

In [None]:
torch.cuda.empty_cache()
test()