In [106]:
# Third-party: MOSES metric implementations used by process_sampled_folder().
import moses
import os
import numpy as np
import logging, sys
# Disable logging at every level (sys.maxsize exceeds all standard levels).
# This runs after `import moses`, so only messages emitted from here on are suppressed.
logging.disable(sys.maxsize)
import warnings
# Ignore all Python warnings for the rest of the kernel session.
warnings.filterwarnings('ignore')
# NOTE(review): get_mol() references `Chem`; while this import is commented
# out, calling get_mol() with a non-empty string raises NameError.
#import rdkit.Chem as Chem

In [107]:
def get_mol(smiles_or_mol):
    '''
    Convert a SMILES string into an RDKit molecule object.

    Behaviour by input:
    - Non-string: assumed to already be a molecule and returned unchanged.
    - Empty string: returns None.
    - SMILES that fails to parse or to sanitize: returns None.

    NOTE(review): relies on a module-level name ``Chem`` (rdkit.Chem). The
    notebook's import of it is commented out, so string inputs raise NameError
    unless ``Chem`` is imported elsewhere.
    '''
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    molecule = Chem.MolFromSmiles(smiles_or_mol)
    if molecule is None:
        return None
    try:
        Chem.SanitizeMol(molecule)
    except ValueError:
        return None
    return molecule

In [108]:
def process_sampled_folder(sampled_dir, epochs, cpus, train=None):
    '''
    Compute MOSES sampling metrics for a series of epochs.

    For every epoch, reads ``<sampled_dir>/<epoch>.smi`` with smiles_to_list()
    and computes:

    - validity:   fraction of sampled SMILES that are valid;
    - uniqueness: fraction of all samples that are valid AND unique
      (validity * fraction_unique of the valid set);
    - novelty:    fraction of all samples that are valid, unique AND not in
      the training set (uniqueness * moses novelty);
    - diversity:  MOSES internal diversity of the valid molecules.

    Parameters
    ----------
    sampled_dir : str
        Directory holding one ``<epoch>.smi`` file per epoch.
    epochs : sequence of int or str
        Epoch identifiers, each formatted into a file name.
    cpus : int
        ``n_jobs`` passed to every moses metric call.
    train : list of str, optional
        Training-set SMILES used for novelty. Defaults to the notebook-level
        global ``train_list`` (the original behaviour).

    Returns
    -------
    numpy.ndarray
        Shape (5, len(epochs)). The rows are epochs, validity, uniqueness,
        novelty and internal diversity. Diversity is NaN for an epoch with no
        valid molecules.
    '''
    if train is None:
        # Backward-compatible fallback to the global loaded in an earlier cell.
        train = train_list
    valid_list = []
    unique_list = []
    novel_list = []
    div_list = []
    for epoch in epochs:
        gen_smiles = os.path.join(sampled_dir, f'{epoch}.smi')
        gen_list = smiles_to_list(gen_smiles)
        c_valid = moses.metrics.fraction_valid(gen_list, n_jobs=cpus)
        gen_valid = moses.metrics.metrics.remove_invalid(gen_list, canonize=True, n_jobs=cpus)
        if gen_valid:
            c_unique = moses.metrics.fraction_unique(gen_valid, n_jobs=cpus, check_validity=False)
            # moses novelty discards invalid SMILES itself, so the raw list is fine here.
            c_novel = moses.metrics.metrics.novelty(gen_list, train, n_jobs=cpus)
            # Use only valid molecules. moses gives invalid SMILES placeholder
            # (NaN) fingerprints, which distort the pairwise similarity average;
            # moses.get_all_metrics likewise uses the valid set.
            div = moses.metrics.internal_diversity(gen_valid, n_jobs=cpus)
        else:
            # No valid molecules: the uniqueness/novelty ratios would divide by zero.
            c_unique = 0.0
            c_novel = 0.0
            div = np.nan
        o_unique = c_valid * c_unique
        valid_list.append(c_valid)
        unique_list.append(o_unique)
        novel_list.append(o_unique * c_novel)
        div_list.append(div)
    return np.array([epochs, valid_list, unique_list, novel_list, div_list])

In [109]:
def smiles_to_list(smiles_file):
    '''
    Read a text file of SMILES into a list of strings.

    Each line becomes one entry with leading and trailing whitespace removed.
    Blank lines are kept as empty strings.
    '''
    with open(smiles_file) as handle:
        return [line.strip() for line in handle]

In [110]:
# Training-set SMILES (file name suggests canonical form); process_sampled_folder
# measures novelty against this list.
# Mounted-volume alternative: '/Volumes/mg4417/home/reinvent-2/data/model1_cano.smi'
train_smiles = os.path.expanduser("~/reinvent-2/data/model1_cano.smi")
train_list = smiles_to_list(train_smiles)

In [111]:
# Directory with one '<epoch>.smi' file of sampled SMILES per epoch.
# Transfer-learning alternative: '/Volumes/mg4417/home/reinvent-2/outputs/REINVENT_transfer_learning_demo/sampled'
sampled_dir = os.path.expanduser("~/reinvent-2/outputs/REINVENT_RL_demo/report/pure_gds_alt/sampled")

# Epochs to evaluate: 10, then every 100 up to 1000.
# Transfer-learning alternative: epochs = ['1'] + [str(e) for e in range(10, 201, 10)]
epochs = [10] + list(range(100, 1001, 100))

In [112]:
%%capture output
out_arr = process_sampled_folder(sampled_dir, epochs, cpus=2)

In [113]:
out_arr[0, :]  # row 0 of the metrics array holds the epoch numbers

array([  10.,  100.,  200.,  300.,  400.,  500.,  600.,  700.,  800.,
        900., 1000.])

In [114]:
# Save the metrics array returned by process_sampled_folder as a .npy file.
output_dir = os.path.expanduser("~/reinvent-2/outputs/REINVENT_RL_demo/report/metrics")
# Transfer-learning alternative: '/Volumes/mg4417/home/reinvent-2/outputs/REINVENT_transfer_learning_demo/report'
# np.save raises FileNotFoundError when the target directory is missing, so create it first.
os.makedirs(output_dir, exist_ok=True)
np.save(os.path.join(output_dir, 'pure_gds_alt.npy'), out_arr)