In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random, sys
import numpy as np
import argparse
from argparse import Namespace
from collections import deque
import pickle as pickle

from jtnn import *
from auxiliaries import build_parser, set_random_seed
import rdkit
import json, os
from rdkit import RDLogger

In [2]:
lg = RDLogger.logger() 
lg.setLevel(RDLogger.CRITICAL)

In [20]:
root = '/home/huang651/port-to-botorch/rafa-pred-model'
json_path = (root+'/default_gen_args2.json')
cmd_args = {'beta': 0.002, 'max_beta': 1.0}
with open(json_path) as handle:
    args = json.loads(handle.read())
args.update(cmd_args)
if 'seed' in args:
    set_random_seed(args['seed'])
else:
    args['seed'] = set_random_seed()
if 'save_dir' not in args:
    args['save_dir'] = "gen-{}-h{}-l{}-n{}-e{}-s{}".format(
        '2dmodel', args['hidden_size'], args['latent_size'],
        args['num_layers'], args['epoch'], args['seed']
    )
args['cuda'] = torch.cuda.is_available()

In [21]:
# save model settings
if not os.path.exists(args['save_dir']):
    os.makedirs(args['save_dir'])
dump_json_path = os.path.join(args['save_dir'], 'model.json')
if not os.path.exists(dump_json_path):
    with open(dump_json_path, "w") as fp:
        json.dump(args, fp, sort_keys=True, indent=4)
args = Namespace(**args)
print(args)
train_path = os.path.join(f'rafa-processed')
device = 'cuda' if args.cuda else 'cpu'

Namespace(anneal_iter=40000, anneal_rate=0.9, batch_size=32, beta=0.002, clip_norm=50.0, cuda=True, depthG=3, depthT=20, epoch=150, hidden_size=450, kl_anneal_iter=2000, latent_size=4, load_epoch=0, lr=0.001, max_beta=1.0, n_out=1, num_layers=3, print_iter=100, save_dir='gen-2dmodel-h450-l4-n3-e150-s764396391', save_iter=5000, seed=764396391, target='homo', total_trials=50, use_activation=True, vocab='/home/huang651/port-to-botorch/rafa-pred-model/data/rafa/vocab.txt', warmup=20000)


In [22]:
vocab = [x.strip("\r\n ") for x in open(args.vocab)] 
vocab = Vocab(vocab)

model = JTNNVAE(vocab, args)
if args.cuda:
    model = model.cuda()
print(model)

JTNNVAE(
  (jtnn): JTNNEncoder(
    (embedding): Embedding(43, 450)
    (outputNN): Sequential(
      (0): Linear(in_features=900, out_features=450, bias=True)
      (1): ReLU()
    )
    (GRU): GraphGRU(
      (W_z): Linear(in_features=900, out_features=450, bias=True)
      (W_r): Linear(in_features=450, out_features=450, bias=False)
      (U_r): Linear(in_features=450, out_features=450, bias=True)
      (W_h): Linear(in_features=900, out_features=450, bias=True)
    )
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(43, 450)
    (W_z): Linear(in_features=900, out_features=450, bias=True)
    (U_r): Linear(in_features=450, out_features=450, bias=False)
    (W_r): Linear(in_features=450, out_features=450, bias=True)
    (W_h): Linear(in_features=900, out_features=450, bias=True)
    (W): Linear(in_features=452, out_features=450, bias=True)
    (U): Linear(in_features=452, out_features=450, bias=True)
    (U_i): Linear(in_features=900, out_features=450, bias=True)
    (W_o): Li

In [23]:
for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

if args.load_epoch > 0:
    model.load_state_dict(torch.load(args.save_dir + "/model.iter-" + str(args.load_epoch)))

print(("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)))

Model #Params: 4599K


In [24]:
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate)
scheduler.step()

param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))

total_step = args.load_epoch
beta = args.beta
meters = np.zeros(4)

In [25]:
for epoch in range(args.epoch):
    print(f"Currently at epoch: {epoch+1}")
    loader = MolTreeFolder(train_path, vocab, args.batch_size, num_workers=4)
    for batch in loader:
        total_step += 1
        model.zero_grad()
        loss, kl_div, wacc, tacc, sacc = model(batch, beta)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
        optimizer.step()

        meters = meters + np.array([kl_div, wacc * 100, tacc * 100, sacc * 100])

        if total_step % args.print_iter == 0:
            meters /= args.print_iter
            print(("[%d] Beta: %.3f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], param_norm(model), grad_norm(model))))
            sys.stdout.flush()
            meters *= 0

        if total_step % args.save_iter == 0:
            torch.save(model.state_dict(), args.save_dir + "/model.iter-" + str(total_step))

        if total_step % args.anneal_iter == 0:
            scheduler.step()
            print(("learning rate: %.6f" % scheduler.get_lr()[0]))

        if total_step % args.kl_anneal_iter == 0 and total_step >= args.warmup:
            beta = min(args.max_beta, beta + args.step_beta)


Currently at epoch: 1
[100] Beta: 0.002, KL: 107.29, Word: 54.82, Topo: 90.38, Assm: 94.34, PNorm: 97.32, GNorm: 38.25
[200] Beta: 0.002, KL: 119.78, Word: 76.46, Topo: 97.70, Assm: 96.51, PNorm: 100.29, GNorm: 25.24
[300] Beta: 0.002, KL: 109.79, Word: 79.78, Topo: 98.19, Assm: 96.33, PNorm: 102.55, GNorm: 24.44
[400] Beta: 0.002, KL: 108.09, Word: 81.70, Topo: 98.67, Assm: 97.38, PNorm: 104.31, GNorm: 17.59
Currently at epoch: 2
[500] Beta: 0.002, KL: 102.16, Word: 82.79, Topo: 98.68, Assm: 97.02, PNorm: 106.26, GNorm: 22.44
[600] Beta: 0.002, KL: 99.65, Word: 83.68, Topo: 98.81, Assm: 97.33, PNorm: 107.89, GNorm: 23.38
[700] Beta: 0.002, KL: 97.66, Word: 84.18, Topo: 98.96, Assm: 96.77, PNorm: 109.25, GNorm: 12.73
[800] Beta: 0.002, KL: 97.25, Word: 84.75, Topo: 98.83, Assm: 97.19, PNorm: 110.43, GNorm: 31.14
Currently at epoch: 3
[900] Beta: 0.002, KL: 92.50, Word: 85.06, Topo: 98.94, Assm: 97.14, PNorm: 111.74, GNorm: 17.41
[1000] Beta: 0.002, KL: 97.20, Word: 85.84, Topo: 99.07, 

KeyboardInterrupt: 