In [1]:
import os
import time

import pandas as pd
import torch
from torch.utils.data import DataLoader

import gentrl

# GPU used for training. Defaults to device 3 (the previously hard-coded choice);
# set the GENTRL_GPU environment variable to run on a machine with fewer GPUs.
GPU_ID = int(os.environ.get('GENTRL_GPU', '3'))
torch.cuda.set_device(GPU_ID)

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All. NOTE(review): gentrl may also draw from
# numpy/random internally — seed those too if runs still differ.
SEED = 42
torch.manual_seed(SEED)

In [2]:
# Latent width shared by the encoder output, the decoder input and the latent
# descriptor list; the three uses must agree, so the value lives in one place.
LATENT_SIZE = 50
# Descriptor repeated once per latent dimension, plus the descriptor list for the
# modelled properties (presumably one entry per property in `props`, here plogP).
# NOTE(review): ('c', 20) presumably means "continuous, 20 mixture components" —
# confirm against gentrl.GENTRL.
LATENT_DIM_DESCR = ('c', 20)
FEATURE_DESCR = [('c', 20)]

enc = gentrl.RNNEncoder(latent_size=LATENT_SIZE)
dec = gentrl.DilConvDecoder(latent_input_size=LATENT_SIZE)
model = gentrl.GENTRL(enc, dec, LATENT_SIZE * [LATENT_DIM_DESCR], FEATURE_DESCR, beta=0.001)
model.cuda();

In [3]:
# Single training source: molecules come from the 'SMILES' column of the CSV and
# the 'plogP' entry maps the property name to its CSV column, matching `props`.
# NOTE(review): 'prob' is presumably this source's sampling weight — confirm.
plogp_source = {
    'path': 'train_plogp_plogpm.csv',
    'smiles': 'SMILES',
    'prob': 1,
    'plogP': 'plogP',
}
md = gentrl.MolecularDataset(sources=[plogp_source], props=['plogP'])

In [4]:
# Training hyperparameters, read by the DataLoader and training cells below.
batch_size = 512    # molecules per optimisation step
num_workers = 8     # DataLoader worker processes
pin_memory = True   # page-locked host buffers to speed host->GPU copies
lr = 1e-4           # learning rate handed to train_as_vaelp
num_epochs = 1      # full passes over the training set

In [5]:
# Reshuffle every epoch; drop the trailing partial batch so every step sees
# exactly `batch_size` molecules.
train_loader = DataLoader(
    md,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=True,
)

In [6]:
# Run train_as_vaelp and report how long it took. perf_counter is a monotonic,
# high-resolution clock; time.time() can jump if the system clock is adjusted
# during a long training run, corrupting the measured duration.
start_time = time.perf_counter()
model.train_as_vaelp(train_loader, num_epochs=num_epochs, lr=lr)
end_time = time.perf_counter()
duration = end_time - start_time  # seconds
print(f"Time Taken is {duration / 60:.2f} min")

Epoch 0 :
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 2.694;rec: -2.609;kl: -66.05;log_p_y_by_z: -1.51;log_p_z_by_y: -70.09;
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 1.825;rec: -1.722;kl: -48.2;log_p_y_by_z: -1.508;log_p_z_by_y: -52.58;
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 1.519;rec: -1.412;kl: -32.75;log_p_y_by_z: -1.406;log_p_z_by_y: -40.85;
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 1.322;rec: -1.238;kl: -28.42;log_p_y_by_z: -1.127;log_p_z_by_y: -37.72;
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 1.186;rec: -1.121;kl: -24.61;log_p_y_by_z: -0.8992;log_p_z_by_y: -36.6;
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 1.103;rec: -1.046;kl: -22.72;log_p_y_by_z: -0.8059;log_p_z_by_y: -36.64;
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 1.036;rec: -0.9846;kl: -20.62;log_p_y_by_z: -0.7214;log_p_z_by_y: -36.18;
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!loss: 0.9878;rec: -0.9406;kl: -18.58;log_p_

In [None]:
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol


def get_num_rings_6(mol):
    """Count the rings in ``mol`` that contain more than six atoms.

    Despite the name, six-membered rings are *not* counted; only rings of
    seven or more atoms contribute.

    Args:
        mol: molecule object exposing ``GetRingInfo().AtomRings()`` (RDKit API).

    Returns:
        int: number of atom rings whose size exceeds 6.
    """
    atom_rings = mol.GetRingInfo().AtomRings()
    return sum(1 for ring in atom_rings if len(ring) > 6)


def penalized_logP(mol_or_smiles, masked=False, default=-5):
    """Score a molecule by logP, penalised by synthetic accessibility and large rings.

    reward = logP(mol) - SA(mol) - (number of rings with more than 6 atoms)

    NOTE(review): this is an unnormalised variant. The commonly used penalised-logP
    benchmark standardises each term and penalises max(largest ring size - 6, 0)
    rather than a ring count — confirm this matches how the training CSV's plogP
    column was produced before comparing scores.

    Args:
        mol_or_smiles: a molecule or SMILES string, resolved with moses ``get_mol``.
        masked: if True, molecules that fail moses ``mol_passes_filters`` score
            ``default`` instead of their reward.
        default: score returned for unparseable molecules (and for filtered-out
            molecules when ``masked`` is True).

    Returns:
        The penalised logP reward, or ``default``.
    """
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    # Check the mask before scoring: a filtered-out molecule's reward would be
    # discarded anyway, so skip the logP / SA / ring computations for it.
    if masked and not mol_passes_filters(mol):
        return default
    return logP(mol) - SA(mol) - get_num_rings_6(mol)

In [None]:
# Create the checkpoint directory in pure Python (portable, unlike the `mkdir -p`
# shell escape); exist_ok=True keeps the cell safe to re-run.
from pathlib import Path
Path('saved_gentrl').mkdir(parents=True, exist_ok=True)

In [None]:
# Persist the trained model into the checkpoint directory made by the mkdir cell.
save_dir = './saved_gentrl/'
model.save(save_dir)