In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import selfies as sf
import torch
import time
from tqdm import tqdm, trange
from rdkit import RDLogger, Chem
RDLogger.DisableLog('rdApp.*')

In [4]:
## Inter-op parallelism
torch.set_num_interop_threads(4)
torch.get_num_interop_threads()
## Intra-op parallelism
torch.set_num_threads(4)
torch.get_num_threads()

4

In [5]:
class GPUCONFIGS:
    def __init__(self):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
        if self.use_cuda: torch.cuda.set_device(self.device)
        
gpuconfigs = GPUCONFIGS()
print(torch.cuda.current_device())

0


In [6]:
REBADD_LIB_PATH = os.path.abspath(os.pardir)
if REBADD_LIB_PATH not in sys.path:
    sys.path = [REBADD_LIB_PATH] + sys.path

from rebadd.stackVAE import StackAugmentedVAE
from rebadd.datautils import GeneratorData

In [7]:
from ReBADD_config import Reward_bcl2_bclxl_bclw as Reward

In [8]:
class DATACONFIGS:
    def __init__(self):
        ## input
        self.input_dir = os.path.join('processed_data', 'zinc15')
        self.train_data_path = os.path.join(self.input_dir, 'fragments_list.pkl')
        self.vocab_data_path = os.path.join(self.input_dir, 'vocabulary.csv')
        ## output
        self.output_dir = os.path.join('outputs_6_generate_molecules', 'zinc15')
        assert os.path.exists(self.output_dir)

dataconfigs = DATACONFIGS()

In [9]:
gen_data = GeneratorData(pickle_data_path=dataconfigs.train_data_path,
                         vocabulary_path=dataconfigs.vocab_data_path,
                         use_cuda=gpuconfigs.use_cuda)

print(f"Number of training samples: {len(gen_data.data)}")
print(f"Number of vocabulary: {len(gen_data.vocabs)}")
print(f"Maximum of seqlen: {gen_data.max_seqlen}")

Number of training samples: 599957
Number of vocabulary: 44766
Maximum of seqlen: 44


In [10]:
reward_ft = Reward(use_cuda=gpuconfigs.use_cuda, device=gpuconfigs.device)

[DEBUG] BA(navitoclax,P10415) = 9.746 (GT:9.745)
[DEBUG] BA(navitoclax,Q07817) = 7.525 (GT:7.524)
[DEBUG] BA(navitoclax,Q92843) = 6.598 (GT:6.597)
[DEBUG] SA(navitoclax) = 4.131 (GT:4.131)


In [11]:
kwargs_generator = {"input_size"         : gen_data.n_characters,
                    "output_size"        : gen_data.n_characters,
                    "max_seqlen"         : 44,
                    "hidden_size"        : 256,
                    "latent_size"        : 64,
                    "n_layers"           : 4,
                    "has_stack"          : True,
                    "stack_width"        : 256,
                    "stack_depth"        : 20,
                    "lr"                 : 1e-4,
                    "use_cuda"           : gpuconfigs.use_cuda,
                    "device"             : gpuconfigs.device,
                    "optimizer_instance" : torch.optim.RMSprop}

generator = StackAugmentedVAE(**kwargs_generator)

In [12]:
class CKPTCONFIGS:
    def __init__(self):
        self.input_dir = 'outputs_2_optimize_ReBADD'
        self.modelnames = ['zinc15']
        self.numbers = ['0150']
        
ckptconfigs = CKPTCONFIGS()

In [13]:
def normalize_SMILES(smi):
    mol = Chem.MolFromSmiles(smi)
    smi_rdkit = Chem.MolToSmiles(
        mol,
        isomericSmiles=False,   # modified because this option allows special tokens (e.g. [125I])
        kekuleSmiles=False,     # default
        rootedAtAtom=-1,        # default
        canonical=True,         # default
        allBondsExplicit=False, # default
        allHsExplicit=False     # default
    )
    return smi_rdkit


def generate_single_SMILES(data, generator, reward_ft, K, threshold):
    best_smi = 'C'
    best_rwd = threshold
    
    for _ in range(K):
        ## SELFIES
        z = generator.sample_latent_vectors()
        sel = generator.evaluate(data, z=z, return_z=False, greedy=False)
        sel = sel.replace(data.start_token, '').replace(data.end_token, '')
        
        ## SMILES
        smi = sf.decoder(sel)
    
        ## Reward
        try:
            smi = normalize_SMILES(smi)
            rwd = reward_ft(smi, return_min=True)
        except:
            rwd = threshold
            
        if rwd > best_rwd:
            best_smi = smi
            best_rwd = rwd
        
    return best_smi


def generate_SMILES(data, generator, reward_ft, sample_size, K, threshold):
    results = []
    for _ in trange(sample_size):
        best_smi = generate_single_SMILES(data, generator, reward_ft, K, threshold)
        results.append(best_smi)
    return results


def generate_novel_SMILES(data, generator, reward_ft, sample_size, K, threshold, calc_sim):
    results = []
    for _ in trange(sample_size):
        best_smi = 'C'
        best_sim = 1.
        for _ in range(K):
            smi = generate_single_SMILES(data, generator, reward_ft, K, threshold)
            sim = calc_sim(smi)
            if sim < best_sim:
                best_sim = sim
                best_smi = smi
        results.append(best_smi)
    return results


def SMILES_generate(data, generator, reward_ft, sample_size=5000, K=5, threshold=0., calc_sim=None):
    generator.eval()
    if calc_sim:
        return generate_novel_SMILES(data, generator, reward_ft, sample_size, K, threshold, calc_sim)
    else:
        return generate_SMILES(data, generator, reward_ft, sample_size, K, threshold)
    

def save_smiles(filepath, smiles):
    with open(filepath, 'w') as fout:
        for smi in smiles:
            fout.write(f"{smi}\n")
    print(f"[INFO] {len(smiles)} SMILES were saved in {filepath}")

In [14]:
n_sampling = 5000
K = 10

In [15]:
for modelname in ckptconfigs.modelnames:
    for num in ckptconfigs.numbers:
        
        filepath = os.path.join(ckptconfigs.input_dir, modelname, f'checkpoint.pth.{num}')
        generator.load_model(filepath)
        
        for k in range(K):
            generated = SMILES_generate(gen_data, generator, reward_ft, sample_size=n_sampling)

            save_smiles(os.path.join(dataconfigs.output_dir, f'smi_after.csv.{k}'), generated)
        

100%|██████████| 5000/5000 [18:41<00:00,  4.46it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.0


100%|██████████| 5000/5000 [19:02<00:00,  4.38it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.1


100%|██████████| 5000/5000 [19:41<00:00,  4.23it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.2


100%|██████████| 5000/5000 [29:08<00:00,  2.86it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.3


100%|██████████| 5000/5000 [34:04<00:00,  2.45it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.4


100%|██████████| 5000/5000 [24:07<00:00,  3.46it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.5


100%|██████████| 5000/5000 [22:18<00:00,  3.74it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.6


100%|██████████| 5000/5000 [22:20<00:00,  3.73it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.7


100%|██████████| 5000/5000 [22:18<00:00,  3.74it/s]


[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.8


100%|██████████| 5000/5000 [18:34<00:00,  4.49it/s]

[INFO] 5000 SMILES were saved in outputs_6_generate_molecules/zinc15/smi_after.csv.9



