In [1]:
%load_ext autoreload
%autoreload 2
import copy
from dgl import batch, unbatch
import rdkit.Chem as Chem
from rdkit import RDLogger
import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

import numpy as np
np.random.seed(0)

import os,sys,inspect
sys.path.insert(0,'/home/icarus/app/src') 

from models.jtnn_vae import DGLJTNNVAE
from models.modules import *

#from .nnutils import cuda, move_dgl_to_cuda

#from .jtnn_enc import DGLJTNNEncoder
#from .jtnn_dec import DGLJTNNDecoder
#from .mpn import DGLMPN
#from .mpn import mol2dgl_single as mol2dgl_enc
#from .jtmpn import DGLJTMPN
#from .jtmpn import mol2dgl_single as mol2dgl_dec

import os,sys,inspect
from tqdm import tqdm
sys.path.insert(0,'/home/icarus/T-CVAE-MolGen/src')


from utils.chem import set_atommap, copy_edit_mol, enum_assemble_nx, \
                            attach_mols_nx, decode_stereo

# Raise RDKit's logger threshold to CRITICAL so lower-severity RDKit messages
# do not flood the notebook output.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)



In [2]:
class TCVAE(nn.Module):
    """Skeleton VAE module: vocabulary embedding, an MPN, and Gaussian heads.

    The encoder and decoder are not wired up yet (both are ``None``).  Four
    linear heads (``T_mean``/``T_var`` and ``G_mean``/``G_var``) each map a
    ``hidden_size`` vector to ``latent_size // 2`` outputs.

    Args:
        vocab: vocabulary object; only ``vocab.size()`` is used here.
        hidden_size: width of the embedding and of the inputs to the heads.
        latent_size: total latent width; each head produces half of it.
        depth: passed to ``MPN`` together with ``hidden_size``.
    """

    def __init__(self, vocab, hidden_size, latent_size, depth):
        super().__init__()
        self.vocab = vocab
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        # Keep the original behaviour of placing the embedding on the GPU, but
        # only when one exists, so the module can still be built on CPU hosts.
        self.embedding = nn.Embedding(vocab.size(), hidden_size)
        if torch.cuda.is_available():
            self.embedding = self.embedding.cuda()
        self.mpn = MPN(hidden_size, depth)
        self.encoder = None  # placeholder: encoder not implemented yet
        self.decoder = None  # placeholder: decoder not implemented yet

        # Each head emits latent_size // 2 values.
        half_latent = latent_size // 2
        self.T_mean = nn.Linear(hidden_size, half_latent)
        self.T_var = nn.Linear(hidden_size, half_latent)
        self.G_mean = nn.Linear(hidden_size, half_latent)
        self.G_var = nn.Linear(hidden_size, half_latent)

        # Running counters, initialised to zero; nothing in this cell updates them.
        self.n_nodes_total = 0
        self.n_edges_total = 0
        self.n_passes = 0

## Posterior

As in <cite data-cite="7333468/6Y976JUQ"></cite>

$q(z|x,y) \sim N(\mu,\sigma^2I)$, where

$\quad\quad h = \text{MultiHead}(c,E_\text{out}^L(x;y),E_\text{out}^L(x;y))$

$\quad\quad \begin{bmatrix}\mu\\\log(\sigma^2)\end{bmatrix} = hW_q+b_q$

In [3]:
    # NOTE(review): this def is indented like a TCVAE method but sits in its own
    # cell; confirm how it is meant to be attached to the class.
    def sample_posterior(self, prob_decode=False):
        """Placeholder for posterior sampling; not implemented, returns None."""
        return None

## Prior

As in <cite data-cite="7333468/6Y976JUQ"></cite>

$p_\theta (z|x) \sim N(\mu', \sigma'^2 I)$, where:

$\quad\quad h' = \text{MultiHead}(c, E_\text{out}^L(x), E_\text{out}^L(x))$

$\quad\quad \begin{bmatrix}\mu'\\\log(\sigma'^2)\end{bmatrix} = MLP_p(h')$

In [4]:
    # NOTE(review): this def is indented like a TCVAE method but sits in its own
    # cell; confirm how it is meant to be attached to the class.
    def sample_prior(self, prob_decode=False):
        """Placeholder for prior sampling; not implemented, returns None."""
        return None

In [5]:
from utils.data import JTNNDataset, JTNNCollator
#torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

class ArgsTemp():
    """Plain container for the hyper-parameters used by this notebook.

    Args:
        hidden_size: stored as ``hidden_size``.
        depth: stored as ``depth``.
        device: stored as ``device`` (the notebook passes a ``torch.device``).
        batch_size: stored as ``batch_size`` (default 200).
        latent_size: stored as ``latent_size`` (default 56).
        lr: stored as ``lr`` (default 1e-3).
        beta: stored as ``beta`` (default 1.0).

    ``use_cuda`` is set from ``torch.cuda.is_available()`` at construction time.
    """

    def __init__(self, hidden_size, depth, device,
                 batch_size=200, latent_size=56, lr=1e-3, beta=1.0):
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.latent_size = latent_size
        self.depth = depth
        self.device = device
        self.lr = lr
        self.beta = beta
        self.use_cuda = torch.cuda.is_available()
        
args = ArgsTemp(200,3, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
print(args.depth)

from pathlib import Path
data_path = Path('/home/icarus/app/data/01_raw/moses/tinytrain.csv')
vocab_path = Path('/home/icarus/app/data/03_processed/tinytrainvocab.txt')

dataset = JTNNDataset(data=data_path, 
                      vocab=vocab_path, training=True, intermediates=False)
vocab = dataset.vocab

# Build the model and move it to the GPU only when one is available; an
# unconditional .cuda() would raise on CPU-only hosts.
model = DGLJTNNVAE(vocab, args.hidden_size, args.latent_size, args.depth, args)
if args.use_cuda:
    model = model.cuda()
model.share_memory()

# Set model_path to a saved state_dict to resume; None means fresh initialisation.
model_path = None
if model_path is not None:
    # map_location lets a checkpoint saved on another device load onto args.device.
    model.load_state_dict(torch.load(model_path, map_location=args.device))
else:
    # 1-D parameters are zeroed; every other parameter gets Xavier-normal initialisation.
    for param in model.parameters():
        if param.dim() == 1:
            nn.init.constant_(param, 0)
        else:
            nn.init.xavier_normal_(param)
#if torch.cuda.device_count() > 1:
  #print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  #model = nn.DataParallel(model)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Multiplies the learning rate by 0.9 on every scheduler.step().
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)


3
Loading data...
<class 'pathlib.PosixPath'>
Loading finished.
	Num samples: 1000
	Vocab file: /home/icarus/app/data/03_processed/tinytrainvocab.txt
Initializing Embedding
Initializing MPN
Initializing JTNN
Initializing Decoder
Initializing JTMPN


In [6]:
MAX_EPOCH = 5    # number of passes over the dataloader in train()
PRINT_ITER = 20  # train() prints averaged metrics every PRINT_ITER batches

from tqdm import tqdm
from os import access, R_OK, W_OK
from os.path import isdir
import sys

# Checkpoints are written into this directory, so it must exist and be
# writable; a read-only check would let train() fail at its first torch.save.
save_path = '/home/icarus/app/data/04_models'
assert isdir(save_path) and access(save_path, W_OK), \
       "Directory {} doesn't exist or isn't writable".format(save_path)

def train():
    """Train the module-level ``model`` for ``MAX_EPOCH`` epochs.

    Relies on notebook globals: ``dataset``, ``vocab``, ``args``, ``model``,
    ``optimizer``, ``scheduler``, ``save_path``, ``MAX_EPOCH`` and
    ``PRINT_ITER``.  Each batch is passed as ``model(batch, args.beta)``, which
    is unpacked into a loss, a KL term and four accuracy values.

    Side effects: prints averaged metrics every ``PRINT_ITER`` batches, steps
    the LR scheduler every 1500 batches and at the end of every epoch, and
    writes ``state_dict`` checkpoints into ``save_path`` every 100 batches and
    at the end of every epoch.
    """
    lr_decay_every = 1500    # batches between intra-epoch scheduler steps ("fast annealing")
    checkpoint_every = 100   # batches between intra-epoch checkpoints

    dataset.training = True
    print("Loading data...")
    dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=10,
            collate_fn=JTNNCollator(vocab, True),
            drop_last=True,
            worker_init_fn=None)
    dataloader._use_shared_memory = False
    last_loss = sys.maxsize
    print("Beginning Training...")
    for epoch in range(MAX_EPOCH):
        # Accumulators are reset here and after every metrics print.
        word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
        print("Epoch %d: " % epoch)

        for it, mol_batch in tqdm(enumerate(dataloader),total=len(dataloader)):
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(mol_batch, args.beta)
            except Exception:
                # Show which molecules were in the failing batch, then propagate.
                # `Exception` (not a bare except) leaves KeyboardInterrupt alone.
                print([t.smiles for t in mol_batch['mol_trees']])
                raise
            loss.backward()
            optimizer.step()

            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            cur_loss = loss.item()
            
            if (it + 1) % PRINT_ITER == 0:
                # Convert the running sums into mean percentages over the window.
                word_acc = word_acc / PRINT_ITER * 100
                topo_acc = topo_acc / PRINT_ITER * 100
                assm_acc = assm_acc / PRINT_ITER * 100
                steo_acc = steo_acc / PRINT_ITER * 100

                print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f, Delta: %.6f" % (
                    kl_div, word_acc, topo_acc, assm_acc, steo_acc, cur_loss, last_loss - cur_loss))
                word_acc,topo_acc,assm_acc,steo_acc = 0,0,0,0
                sys.stdout.flush()

            if (it + 1) % lr_decay_every == 0: #Fast annealing
                scheduler.step()
                # Read the LR from the optimizer: scheduler.get_lr() is not
                # meant to be called outside step() on recent PyTorch.
                print("learning rate: %.6f" % optimizer.param_groups[0]['lr'])
                
            if (it + 1) % checkpoint_every == 0:
                torch.save(model.state_dict(),
                            save_path + "/model.iter-%d-%d" % (epoch, it + 1))
                      
            last_loss = cur_loss

        scheduler.step()
        print("learning rate: %.6f" % optimizer.param_groups[0]['lr'])
        torch.save(model.state_dict(), save_path + "/model.iter-" + str(epoch))

In [7]:
import warnings

# Suppress every Python warning for the rest of the session, then run training.
warnings.simplefilter('ignore')
train()

Loading data...
Beginning Training...
Epoch 0: 


100%|██████████| 5/5 [00:51<00:00, 10.27s/it]

learning rate: 0.000900
Epoch 1: 



100%|██████████| 5/5 [00:57<00:00, 11.50s/it]


learning rate: 0.000810
Epoch 2: 


100%|██████████| 5/5 [00:56<00:00, 11.35s/it]


learning rate: 0.000729
Epoch 3: 


100%|██████████| 5/5 [00:55<00:00, 11.17s/it]

learning rate: 0.000656
Epoch 4: 



100%|██████████| 5/5 [00:53<00:00, 10.62s/it]


learning rate: 0.000590


## References

<div class="cite2c-biblio"></div>

In [24]:
from utils.data import JTNNDataset
# Dataset consumed by test(); built in non-training mode with the same
# vocabulary file as the training set.
# NOTE(review): the loader output below prints "Data[0]: SMILES", which looks
# like the CSV header row being read as a sample — confirm in JTNNDataset.
test_path = Path('/home/icarus/app/data/01_raw/moses/test.csv')
testset = JTNNDataset(
    data=test_path,
    vocab=vocab_path,
    training=False,
    intermediates=False,
)
def test(max_samples=1000):
    """Encode, sample and decode test molecules, writing the SMILES to a CSV.

    Relies on notebook globals ``testset``, ``vocab`` and ``model``.  Each
    molecule is encoded, passed through ``model.sample`` and decoded; every
    non-``None`` result is appended to the output CSV (header ``SMILES``) and
    printed with its running index.

    Args:
        max_samples: stop once this many SMILES have been written.

    Molecules whose encode/sample/decode step raises are skipped and counted;
    the count is printed at the end.  Interrupts and file-write errors are not
    swallowed.
    """
    testset.training = False
    dataloader = DataLoader(
            testset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            collate_fn=JTNNCollator(vocab, False),
            drop_last=True,
            worker_init_fn=None)

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    samples_path = '/home/icarus/app/data/05_model_output/tinytrain_sampled.csv'

    n = 0          # SMILES written so far
    n_failed = 0   # molecules skipped because the model raised
    # Open the file once for the whole run; no_grad avoids building autograd
    # graphs that inference never uses.
    with open(samples_path, 'w') as f, torch.no_grad():
        f.write('SMILES\n')
        for mol_batch in dataloader:
            if n >= max_samples:
                break
            try:
                model.move_to_cuda(mol_batch)
                _, tree_vec, mol_vec = model.encode(mol_batch)
                tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
                smiles = model.decode(tree_vec, mol_vec)
            except Exception:
                # Best-effort: a molecule the model cannot handle is skipped,
                # but KeyboardInterrupt/SystemExit still propagate.
                n_failed += 1
                continue
            if smiles is None:
                continue
            n += 1
            f.write(f'{smiles}\n')
            f.flush()  # keep partial results on disk if the run is interrupted
            print(f'{n}: {smiles}')
    if n_failed:
        print(f'Skipped {n_failed} molecules that raised during encode/sample/decode')

Loading data...
<class 'pathlib.PosixPath'>
Data[0]: SMILES
Loading finished.
	Num samples: 176074
	Vocab file: /home/icarus/app/data/03_processed/tinytrainvocab.txt


In [25]:
torch.cuda.empty_cache()
test()

1: CSCSC
2: SC(S)(S)Sc1ccc2c(ccc3ccccc32)c1
3: CS[C@H](S)NCNc1ccccc1
4: c1ccc(OCOc2cccc3ccccc23)cc1
5: OCO[C@H](Cc1cccc2cc3cc4cc5cc6ccccc6cc5cc4cc3cc12)c1ccc(O)c(O)c1
6: SCc1cc(-c2ccc3ccccc3c2)cc2ccccc12
7: OC(O)(O)Oc1cccc2ccccc12
8: CNCNCNc1ccccc1
9: COCOCOCO
10: CCCCCc1cc(C)c2ccccc2c1
11: NSCOCOCO
12: COCO[C@H](O)OC
13: COCO[C@H](O)OC
14: O=S
15: CC(C)(C)CC(C)(C)C
16: Sc1ccc2cc3ccccc3cc2c1
17: OCOc1ccc2cc3ccccc3cc2c1
18: COCOCOCO
19: CN(c1ccccc1)[C@H](N)S
20: SCNc1ccc2c(ccc3ccccc32)c1
21: SCOCOc1ccc2ccccc2c1
22: Cc1ccc2cc3ccccc3cc2c1
23: CC(C)c1cccc2c1ccc1ccccc12
24: CSCS
25: OC(O)OC(O)(O)O
26: CO[C@@H](O)Oc1ccc2ccccc2c1
27: CSCSc1ccc2cc3ccccc3cc2c1
28: CCCCc1ccc2cc3ccccc3cc2c1
29: C[SH](C)c1c2ccccc2cc2c1ccc1ccccc12
30: BrC(Br)c1ccc2cc3ccccc3cc2c1
31: COCOc1cccc2ccccc12
32: SCSc1cccc2c1ccc1c3ccccc3ccc21
33: COCOCOCO
34: CC(C)c1cccc2ccccc12
35: O=[SH]CNc1cccc2ccccc12
36: CSCS
37: c1ccc2c(c1)ccc1c2ccc2c1ccc1c3ccc4ccccc4c3ccc12
38: c1ccc2cc3cc4cc5cc6ccccc6cc5cc4cc3cc2c1
39: C[SH](C)Cc1c