In [1]:
import copy
from dgl import batch, unbatch
import rdkit.Chem as Chem
from rdkit import RDLogger
import torch
import torch.nn as nn
import torch.nn.functional as F

import os,sys,inspect
sys.path.insert(0,'/home/icarus/app/src') 

from models.jtnn_vae import DGLJTNNVAE
from models.modules import *

#from .nnutils import cuda, move_dgl_to_cuda

#from .jtnn_enc import DGLJTNNEncoder
#from .jtnn_dec import DGLJTNNDecoder
#from .mpn import DGLMPN
#from .mpn import mol2dgl_single as mol2dgl_enc
#from .jtmpn import DGLJTMPN
#from .jtmpn import mol2dgl_single as mol2dgl_dec

import os,sys,inspect
from tqdm import tqdm
sys.path.insert(0,'/home/icarus/T-CVAE-MolGen/src')

from utils.chem import set_atommap, copy_edit_mol, enum_assemble_nx, \
                            attach_mols_nx, decode_stereo

# Quiet RDKit: keep the logger handle in `lg` and set its level to CRITICAL.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)



In [2]:
class TCVAE(nn.Module):
    """Skeleton of the T-CVAE model; the encoder and decoder are not wired in yet.

    Args:
        vocab: vocabulary object; only ``vocab.size()`` is used here, to size
            the embedding table.
        hidden_size: width of the embedding and of the inputs to the latent
            projection layers.
        latent_size: total latent width; each mean / variance head outputs
            ``latent_size // 2`` features.
        depth: forwarded to ``MPN`` (presumably the number of message-passing
            rounds -- TODO confirm against ``models.modules``).
    """

    def __init__(self, vocab, hidden_size, latent_size, depth):
        super(TCVAE, self).__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        # Keep the original "embedding lives on the GPU" behaviour when a GPU
        # exists, but do not crash at construction time on CPU-only hosts.
        if torch.cuda.is_available():
            self.embedding = self.embedding.cuda()
        self.mpn = MPN(hidden_size, depth)
        self.encoder = None  # placeholder, not implemented yet
        self.decoder = None  # placeholder, not implemented yet

        # Two pairs of (mean, variance) heads, each hidden_size -> latent_size // 2.
        # NOTE(review): T_* / G_* presumably stand for tree / graph -- confirm.
        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        # Running counters, all starting at zero.
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

## Posterior

As in <cite data-cite="7333468/6Y976JUQ"></cite>

$q(z|x,y) \sim N(\mu,\sigma^2I)$, where

$\quad\quad h = \text{MultiHead}(c,E_\text{out}^L(x;y),E_\text{out}^L(x;y))$

$\quad\quad \begin{bmatrix}\mu\\\log(\sigma^2)\end{bmatrix} = hW_q+b_q$

In [3]:
    def sample_posterior(self, prob_decode=False):
        """Placeholder for sampling from the posterior q(z|x,y); not implemented.

        ``prob_decode`` is accepted but ignored. Always returns ``None``.
        """
        # NOTE(review): this def is indented like a TCVAE method but sits in its
        # own notebook cell -- confirm how it is meant to be attached to the class.
        return None

## Prior

As in <cite data-cite="7333468/6Y976JUQ"></cite>

$p_\theta (z|x) \sim N(\mu', \sigma'^2 I)$, where:

$\quad\quad h' = \text{MultiHead}(c, E_\text{out}^L(x), E_\text{out}^L(x))$

$\quad\quad \begin{bmatrix}\mu'\\\log(\sigma'^2)\end{bmatrix} = \text{MLP}_p(h')$

In [4]:
    def sample_prior(self, prob_decode=False):
        """Placeholder for sampling from the prior p(z|x); not implemented.

        ``prob_decode`` is accepted but ignored. Always returns ``None``.
        """
        # NOTE(review): this def is indented like a TCVAE method but sits in its
        # own notebook cell -- confirm how it is meant to be attached to the class.
        return None

In [5]:
from utils.data import JTNNDataset, JTNNCollator
torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

class ArgsTemp():
    """Plain container for the hyper-parameters used by this notebook.

    Args:
        hidden_size: stored as ``hidden_size``.
        depth: stored as ``depth``.
        device: stored as ``device`` (the notebook passes a ``torch.device``).
        batch_size: mini-batch size; defaults to 300.
        latent_size: latent width; defaults to 56.
        lr: learning rate; defaults to 1e-3.
        beta: value the training loop passes to the model as its second
            argument; defaults to 1.0.

    ``use_cuda`` is not an argument: it is always set from
    ``torch.cuda.is_available()``.
    """

    def __init__(self, hidden_size, depth, device,
                 batch_size=300, latent_size=56, lr=1e-3, beta=1.0):
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.latent_size = latent_size
        self.depth = depth
        self.device = device
        self.lr = lr
        self.beta = beta
        self.use_cuda = torch.cuda.is_available()
        
# Hyper-parameters; the device falls back to CPU when no GPU is present.
args = ArgsTemp(200,3, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
print(args.depth)

dataset = JTNNDataset(data='train', vocab='vocab', training=True, intermediates=False)
vocab = dataset.vocab
# No DataLoader is built here: train() and test() each construct their own.

# Place the model on args.device instead of a hard-coded .cuda(), so the CPU
# fallback chosen above is actually honoured.
model = DGLJTNNVAE(vocab, args.hidden_size, args.latent_size, args.depth, args).to(args.device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Multiplies the learning rate by 0.9 on every scheduler.step().
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)


3
Loading data...
Loading finished.
	Num samples: 220011
	Vocab file: /home/icarus/.dgl/jtnn/vocab.txt


In [6]:
MAX_EPOCH = 100   # number of passes over the dataset in train()
PRINT_ITER = 20   # train() prints averaged accuracies every PRINT_ITER batches

from tqdm import tqdm
from os import access, R_OK, W_OK
from os.path import isdir

save_path = '/home/icarus/app/data/05_model_output'
# train() writes checkpoints into this directory, so fail fast unless it is
# writable as well as readable (the previous check only tested R_OK).
assert isdir(save_path) and access(save_path, R_OK | W_OK), \
       "Directory {} doesn't exist or isn't readable and writable".format(save_path)

def train():
    """Train the module-level ``model`` for ``MAX_EPOCH`` epochs.

    Relies on notebook globals: ``dataset``, ``vocab``, ``args``, ``model``,
    ``optimizer``, ``scheduler``, ``save_path``, ``MAX_EPOCH`` and
    ``PRINT_ITER``.

    Side effects:
        * sets ``dataset.training = True``;
        * prints averaged accuracies every ``PRINT_ITER`` batches;
        * writes ``model.state_dict()`` to ``save_path`` every 100 batches and
          once more at the end of every epoch;
        * steps ``scheduler`` at the end of every epoch.

    Raises:
        Whatever the forward pass raises; the SMILES of the offending batch
        are printed before the exception is re-raised.
    """
    dataset.training = True
    dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=2,
            collate_fn=JTNNCollator(vocab, True),
            drop_last=True,
            worker_init_fn=None)

    for epoch in range(MAX_EPOCH):
        word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
        print("Epoch %d: " % epoch)

        # Named `mol_batch` rather than `batch` so the imported dgl.batch
        # function is not shadowed inside this function.
        for it, mol_batch in tqdm(enumerate(dataloader),total=len(dataloader)):
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(mol_batch, args.beta)
            except Exception:
                # Show which molecules broke the forward pass, then propagate.
                # `except Exception` (not a bare except) lets KeyboardInterrupt
                # through without dumping the whole batch first.
                print([t.smiles for t in mol_batch['mol_trees']])
                raise
            loss.backward()
            optimizer.step()

            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            if (it + 1) % PRINT_ITER == 0:
                # Turn the running sums into percentages averaged over the window.
                word_acc = word_acc / PRINT_ITER * 100
                topo_acc = topo_acc / PRINT_ITER * 100
                assm_acc = assm_acc / PRINT_ITER * 100
                steo_acc = steo_acc / PRINT_ITER * 100

                print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
                    kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
                word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
                sys.stdout.flush()

            # NOTE(review): `it` restarts at 0 every epoch, so this branch only
            # fires in epochs that have at least 1500 batches.
            if (it + 1) % 1500 == 0: #Fast annealing
                scheduler.step()
                # Read the rate from the optimizer: scheduler.get_lr() is not
                # meant to be called outside step() on recent PyTorch versions.
                print("learning rate: %.6f" % optimizer.param_groups[0]['lr'])

            if (it + 1) % 100 == 0:
                torch.save(model.state_dict(),
                            save_path + "/model.iter-%d-%d" % (epoch, it + 1))

        scheduler.step()
        print("learning rate: %.6f" % optimizer.param_groups[0]['lr'])
        torch.save(model.state_dict(), save_path + "/model.iter-" + str(epoch))

def test():
    """Encode, sample and decode every molecule in ``dataset``, one at a time.

    Relies on notebook globals ``dataset``, ``vocab`` and ``model``. Sets
    ``dataset.training = False``. For each molecule it prints the ground-truth
    SMILES followed by the decoded result.
    """
    dataset.training = False
    dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            collate_fn=JTNNCollator(vocab, False),
            drop_last=True,
            # The notebook defines no `worker_init_fn`; passing that name raised
            # NameError. None matches train() (and num_workers=0 spawns no workers).
            worker_init_fn=None)

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    for it, mol_batch in enumerate(dataloader):
        gt_smiles = mol_batch['mol_trees'][0].smiles
        print(gt_smiles)
        model.move_to_cuda(mol_batch)
        _, tree_vec, mol_vec = model.encode(mol_batch)
        tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
        smiles = model.decode(tree_vec, mol_vec)
        print(smiles)

In [None]:
import warnings

# Suppress every Python warning before kicking off the training run.
warnings.simplefilter('ignore')
train()

Epoch 0: 


  3%|▎         | 19/733 [12:56<9:00:05, 45.39s/it] 

KL: 0.1, Word: 19.91, Topo: 74.10, Assm: 33.09, Steo: 21.26, Loss: 66.228661


  5%|▌         | 39/733 [25:47<8:27:41, 43.89s/it]

KL: 0.6, Word: 30.18, Topo: 85.35, Assm: 41.57, Steo: 26.16, Loss: 55.948597


  8%|▊         | 59/733 [38:19<8:34:30, 45.80s/it]

KL: 0.9, Word: 38.24, Topo: 88.75, Assm: 50.85, Steo: 23.22, Loss: 50.182426


 11%|█         | 79/733 [51:33<7:58:11, 43.87s/it]

KL: 0.7, Word: 45.32, Topo: 90.50, Assm: 55.23, Steo: 28.24, Loss: 45.417885


 11%|█         | 80/733 [51:35<5:39:59, 31.24s/it]

## References

<div class="cite2c-biblio"></div>