In [1]:
import os
import gentrl
import torch
import pandas as pd
# torch.cuda.set_device(0)

In [2]:
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol


def get_num_rings_6(mol):
    r = mol.GetRingInfo()
    return len([x for x in r.AtomRings() if len(x) > 6])


def penalized_logP(mol_or_smiles, masked=False, default=-5):
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default
    reward = logP(mol) - SA(mol) - get_num_rings_6(mol)
    if masked and not mol_passes_filters(mol):
        return default
    return reward

In [3]:
# ! wget --no-clobber https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv

In [4]:
import os
target_path = 'osm_plogp_plogpm.csv'
input_datafile = 'Ion Regulation Data for OSM Competition - Malaria Molecules.csv'
if os.path.exists(target_path):
    print(f"Loading from {target_path}")
    df = pd.read_csv(target_path)
    # df = df[df['SPLIT'] == 'train']
else:
    print(f"Loading original data from {input_datafile} and calculating penalized_logP")
    df = pd.read_csv(input_datafile)
    # df = df[df['SPLIT'] == 'train']
    column_name = 'SMILES'
    df = df[(df[column_name].notnull()) & (df[column_name]!=u'')]
    df['plogP'] = df[column_name].apply(penalized_logP)
    df.to_csv(target_path, index=None)
print(f"Training data shape {df.shape}")

Loading from osm_plogp_plogpm.csv
Training data shape (1248, 18)


In [5]:
print(f"Training data shape {df.shape}")

Training data shape (1248, 18)


In [6]:
torch.cuda.is_available()
torch.cuda.device_count()
# torch.cuda.get_device_capability()

0

In [7]:
enc = gentrl.RNNEncoder(latent_size=50)
dec = gentrl.DilConvDecoder(latent_input_size=50)
model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)
#model.cuda();

In [8]:
md = gentrl.MolecularDataset(sources=[
    {'path': target_path,
     'smiles': 'SMILES',
     'prob': 1,
     'plogP' : 'plogP',
    }], 
    props=['plogP'])

from torch.utils.data import DataLoader
train_loader = DataLoader(md, batch_size=50, shuffle=True, num_workers=1, drop_last=True)

In [9]:
#sm_str = 'CCN(CC)CCN1C(=N)N(CC(O)c2ccc(Cl)c(Cl)c2)c3ccccc13'
#from gentrl.tokenizer import encode, smiles_tokenizer
#tenc = encode([sm_str])
#print(tenc)

In [10]:
for ind, sm_str in enumerate(df['SMILES']):
    try:
        test_enc = model.enc.encode([sm_str])
    except Exception as exc:
        print(f'Error encoding line {ind} {sm_str}: {exc}')

In [11]:
model.train_as_vaelp(train_loader, lr=1e-4, buf_size=500)

Epoch 0 :
!!!!!!!!!!!!!!!!!!!!!!!!loss:  3.6;rec: -3.463;kl: -61.76;log_p_y_by_z: -1.985;log_p_z_by_y: -77.77;
Epoch 2 :
!!!!!!!!!!!!!!!!!!!!!!!!loss: 2.74;rec: -2.621;kl: -62.47;log_p_y_by_z: -1.819;log_p_z_by_y: -74.17;
Epoch 4 :
!!!!!!!!!!!!!!!!!!!!!!!!loss: 2.351;rec: -2.226;kl: -58.07;log_p_y_by_z: -1.83;log_p_z_by_y: -70.13;
Epoch 6 :
!!!!!!!!!!!!!!!!!!!!!!!!loss: 2.137;rec: -2.006;kl: -51.25;log_p_y_by_z: -1.825;log_p_z_by_y: -65.24;
Epoch 8 :
!!!!!!!!!!!!!!!!!!!!!!!!loss: 1.978;rec: -1.84;kl: -43.65;log_p_y_by_z: -1.815;log_p_z_by_y: -60.59;


<gentrl.gentrl.TrainStats at 0x7faa51235d90>

In [12]:
! mkdir -p osm_saved_gentrl

In [13]:
model.save('./osm_saved_gentrl/')