In [1]:
import torch
from torch.utils.data import DataLoader
import gentrl
import gentrl.lp
from gentrl.tokenizer import encode, get_vocab_size
from gentrl.new_dataloader import NewMolecularDataset
import time
import torch.multiprocessing as mp

## Set GPU

In [2]:
# Index of the CUDA device this notebook runs on; it becomes the process-wide
# "current" device, so later bare 'cuda' references resolve to it.
GPU_INDEX = 3
torch.cuda.set_device(GPU_INDEX)

In [3]:
# Seed used for every source of randomness in this notebook.
RANDOM_SEED = 42

# Actually apply the seed: without this call the constant above is never used,
# so weight initialisation and DataLoader shuffling differ on every run.
# torch.manual_seed seeds the default generator on all devices; the trailing
# semicolon keeps the returned Generator out of the cell output.
torch.manual_seed(RANDOM_SEED);

In [4]:
# Encoder, decoder and latent description all share the same latent width.
LATENT_SIZE = 50

enc = gentrl.RNNEncoder(latent_size=LATENT_SIZE)
dec = gentrl.DilConvDecoder(latent_input_size=LATENT_SIZE)
model = gentrl.GENTRL(
    enc,
    dec,
    latent_descr=[('c', 20)] * LATENT_SIZE,
    feature_descr=[('c', 20)],
    beta=0.001,
)
# Move the model to the GPU; the trailing semicolon suppresses the cell output.
model.to('cuda');

In [5]:
# Device handle passed to the dataset. An un-indexed 'cuda' device refers to
# the current CUDA device (selected earlier with torch.cuda.set_device).
device = torch.device("cuda")

In [6]:
# Single CSV source: which file to read and which columns hold the SMILES
# strings and the plogP property.
plogp_source = {
    'path': 'train_plogp_plogpm.csv',
    'smiles': 'SMILES',
    'prob': 1,
    'plogP': 'plogP',
}
md = NewMolecularDataset(device=device, sources=[plogp_source], props=['plogP'])

In [7]:
# Materialise the dataset object that the DataLoader below iterates over.
# NOTE(review): the name suggests the tensors already live on `device` (GPU) —
# confirm in NewMolecularDataset.create_tensor_dataset.
gpu_dataset = md.create_tensor_dataset()

In [8]:
# --- Training hyperparameters ---
BATCH_SIZE = 512   # samples per optimisation step
LR = 1e-4          # learning rate passed to train_as_vaelp
NUM_EPOCHS = 1     # passes over the training set

# --- DataLoader settings ---
# NOTE(review): presumably 0 workers / no pinning because the dataset is
# already GPU-resident — confirm against NewMolecularDataset.
NUM_WORKERS = 0
PIN_MEMORY = False

In [9]:
# Batching options for the training loader: reshuffle every epoch and drop the
# final incomplete batch so every step sees exactly BATCH_SIZE samples.
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=True,
)
train_loader = DataLoader(gpu_dataset, **loader_options)

In [10]:
# Train the model and report how long the run took.
#
# time.perf_counter() is used instead of time.time(): it is monotonic and
# high-resolution, so the measured interval cannot be skewed by system clock
# adjustments. CUDA kernels are launched asynchronously, so the GPU is
# synchronised before each clock read to make sure queued work is included.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start_time = time.perf_counter()

model.train_as_vaelp(train_loader, lr=LR, num_epochs=NUM_EPOCHS)

if torch.cuda.is_available():
    torch.cuda.synchronize()
end_time = time.perf_counter()

duration = end_time - start_time  # seconds
print(f"Time Taken is {duration/60} min")

Epoch 0 :
loss: 2.619;rec: -2.532;kl: -65.83;log_p_y_by_z: -1.527;log_p_z_by_y: -69.82;
loss: 1.843;rec: -1.741;kl: -49.16;log_p_y_by_z: -1.518;log_p_z_by_y: -52.89;
loss: 1.511;rec: -1.396;kl: -34.41;log_p_y_by_z: -1.492;log_p_z_by_y: -40.24;
loss: 1.314;rec: -1.212;kl: -27.66;log_p_y_by_z: -1.304;log_p_z_by_y: -38.73;
loss: 1.175;rec: -1.1;kl: -23.55;log_p_y_by_z: -0.9842;log_p_z_by_y: -37.69;
loss: 1.084;rec: -1.023;kl: -20.61;log_p_y_by_z: -0.8156;log_p_z_by_y: -37.03;
loss: 1.025;rec: -0.969;kl: -18.4;log_p_y_by_z: -0.7396;log_p_z_by_y: -36.43;
loss: 0.9764;rec: -0.9257;kl: -16.74;log_p_y_by_z: -0.6749;log_p_z_by_y: -36.16;
loss: 0.9379;rec: -0.8927;kl: -15.29;log_p_y_by_z: -0.6057;log_p_z_by_y: -36.07;
loss: 0.9057;rec: -0.8644;kl: -14.08;log_p_y_by_z: -0.5544;log_p_z_by_y: -36.31;
loss: 0.8817;rec: -0.8415;kl: -13.41;log_p_y_by_z: -0.5366;log_p_z_by_y: -36.54;
loss: 0.8533;rec: -0.8194;kl: -12.62;log_p_y_by_z: -0.465;log_p_z_by_y: -36.58;
loss: 0.8351;rec: -0.8007;kl: -12.4;log_

In [None]:
# NOTE: disabled reference code — every line in this cell is commented out.
# It defines penalized_logP, which the (also disabled) preprocessing cell below
# applies to the SMILES column to produce the 'plogP' column.
# from moses.metrics import mol_passes_filters, QED, SA, logP
# from moses.metrics.utils import get_n_rings, get_mol


# Counts the rings that contain more than six atoms (len(x) > 6).
# def get_num_rings_6(mol):
#     r = mol.GetRingInfo()
#     return len([x for x in r.AtomRings() if len(x) > 6])


# Returns logP - SA - (number of rings with more than six atoms).
# Returns `default` instead when the molecule cannot be built (get_mol gives
# None), or when masked=True and the molecule fails mol_passes_filters.
# def penalized_logP(mol_or_smiles, masked=False, default=-5):
#     mol = get_mol(mol_or_smiles)
#     if mol is None:
#         return default
#     reward = logP(mol) - SA(mol) - get_num_rings_6(mol)
#     if masked and not mol_passes_filters(mol):
#         return default
#     return reward

In [None]:
# NOTE: disabled one-off preprocessing that writes 'train_plogp_plogpm.csv',
# the file loaded by NewMolecularDataset earlier in this notebook.
# If re-enabled, it needs `import pandas as pd` (pandas is not imported in the
# notebook's import cell) and the penalized_logP helper from the cell above.
# df = pd.read_csv('dataset_v1.csv')
# df = df[df['SPLIT'] == 'train']
# df['plogP'] = df['SMILES'].apply(penalized_logP)
# df.to_csv('train_plogp_plogpm.csv', index=None)

In [None]:
# Disabled: creates the output directory used by the model.save cell below.
# ! mkdir -p saved_gentrl

In [None]:
# Disabled: while this stays commented out, the trained model is not written to
# disk by any cell visible in this notebook.
# model.save('./saved_gentrl/')