In [1]:
import pickle
import random

import gentrl
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader

from utilities.candiconfig import CandiConfig

# Make GPU 0 the current CUDA device for everything that follows in this notebook.
# This raises on a machine without a usable CUDA device.
torch.cuda.set_device(0)

In [2]:
from moses.metrics import mol_passes_filters, QED, SA, logP
from moses.metrics.utils import get_n_rings, get_mol

def get_num_rings_6(mol):
    """Count the rings of ``mol`` that contain more than six atoms.

    Args:
        mol: Molecule object exposing ``GetRingInfo().AtomRings()``, where
            ``AtomRings()`` yields one sequence of atom indices per ring.

    Returns:
        int: Number of rings whose atom count is strictly greater than 6.
    """
    ring_atoms = mol.GetRingInfo().AtomRings()
    return sum(1 for atoms in ring_atoms if len(atoms) > 6)


def penalized_logP(mol_or_smiles, masked=False, default=-5):
    """Score a molecule as ``logP - SA - (number of rings with > 6 atoms)``.

    Args:
        mol_or_smiles: Input handed to ``get_mol`` to obtain a molecule object.
        masked: When truthy, a molecule that fails ``mol_passes_filters``
            receives ``default`` instead of its computed score.
        default: Value returned when ``get_mol`` yields ``None`` or when the
            molecule is rejected by the filters in masked mode.

    Returns:
        The penalized logP score, or ``default`` in the cases described above.
    """
    mol = get_mol(mol_or_smiles)
    if mol is None:
        return default

    # The score is computed before the filter check, matching the original
    # evaluation order.
    score = logP(mol) - SA(mol) - get_num_rings_6(mol)
    rejected = masked and not mol_passes_filters(mol)
    return default if rejected else score

In [3]:
# Run configuration; the FpsSOM model path used below comes from this object.
config = CandiConfig(
    smiles_format=2,
    topn_fp_features=5,
    mode='threshold',
    max_fp_features=2048,
    threshold=0.3,
    morgan_radius=2,
)

# Load the pickled FpsSOM model.
# NOTE: pickle.load can execute arbitrary code while deserializing; only point
# config.FpsSOM_model at files you produced yourself or otherwise trust.
with open(config.FpsSOM_model, 'rb') as infile:
    fps_som = pickle.load(infile)

In [7]:
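# Disabled preprocessing step: builds config.mixed_dataset by appending the LA
# SMILES (tagged 'train') to the ZINC list, shuffling, and writing a SMILES,SPLIT file.
# LA_smiles, _ = fps_som.get_LA_zinc_data(multiplier=10)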
# LA_list = [s+',train' for s in LA_smiles]

# zinc_list = []
# with open(config.zinc_file, 'r', encoding='utf-8') as f:
#     for line in f:
#         zinc_list.append(line.strip())

# zinc_list.remove('SMILES,SPLIT')
        
# mixed_list = zinc_list + LA_list

# random.shuffle(mixed_list)

# with open(config.mixed_dataset, 'w', encoding='utf-8') as f:
#     f.write('SMILES,SPLIT\n')
#     for s in mixed_list:
#         f.write(s+'\n')

In [5]:
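# Disabled preprocessing step: keeps the 'train' rows of config.mixed_dataset, adds a
# 'reward' column via fps_som.som_reward, and writes config.mixed_train_dataset
# (the file loaded by gentrl.MolecularDataset further down).
# df = pd.read_csv(config.mixed_dataset)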
# df = df[df['SPLIT'] == 'train']
# # df['plogP'] = df['SMILES'].apply(penalized_logP)
# df['reward'] = df['SMILES'].apply(fps_som.som_reward)
# df.to_csv(config.mixed_train_dataset, index=None)

In [8]:
# Training dataset: SMILES strings from the mixed training CSV together with
# their 'reward' column as the only property.
# To additionally use penalized logP, add 'plogP': 'plogP' to the source dict
# and 'plogP' to `props` (the CSV must then contain a 'plogP' column).
md = gentrl.MolecularDataset(
    sources=[
        {
            'path': config.mixed_train_dataset,
            'smiles': 'SMILES',
            'prob': 1,
            'reward': 'reward',
        },
    ],
    props=['reward'],
)

# Shuffled mini-batches of 50 samples, loaded by one worker process;
# drop_last=True discards the final incomplete batch so every batch has the
# same size.
# DataLoader is imported with the other imports at the top of the notebook.
train_loader = DataLoader(
    md,
    batch_size=50,
    shuffle=True,
    num_workers=1,
    drop_last=True,
)

In [9]:
# Encoder and decoder share a 50-dimensional latent space.
enc = gentrl.RNNEncoder(latent_size=50)
dec = gentrl.DilConvDecoder(latent_input_size=50)

# NOTE(review): the two list arguments presumably describe the latent
# dimensions (50 entries) and the property dimensions (1 entry, matching the
# single 'reward' property) — confirm against the gentrl.GENTRL signature.
model = gentrl.GENTRL(
    enc,
    dec,
    [('c', 20)] * 50,
    [('c', 20)],
    beta=0.001,
)

# Move the model to the GPU; kept as the last expression so the cell displays
# the module summary.
model.cuda()

GENTRL(
  (enc): RNNEncoder(
    (embs): Embedding(28, 256)
    (rnn): GRU(256, 256, num_layers=2)
    (final_mlp): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=256, out_features=100, bias=True)
    )
  )
  (dec): DilConvDecoder(
    (latent_fc): Linear(in_features=50, out_features=128, bias=True)
    (input_embeddings): Embedding(28, 128)
    (logits_1x1_layer): Conv1d(128, 28, kernel_size=(1,), stride=(1,))
    (parameters): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 28x128 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 28x128x1 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 28 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 128x50 (GPU 0)]
        (4): Parameter containing: [torch.cuda.FloatTensor of size 128 (GPU 0)]
        (5): Parameter containing: [torch

In [10]:
# Train the model on batches from train_loader with learning rate 1e-4.
# NOTE(review): presumably the VAE / learned-prior pretraining stage of GENTRL —
# confirm against the gentrl documentation.
# The recorded output shows "Epoch 0" followed by KeyboardInterrupt, so this
# saved run was interrupted and did not train to completion.
model.train_as_vaelp(train_loader, lr=1e-4)

Epoch 0 :
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

KeyboardInterrupt: 