In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import selfies as sf
import torch
import time
from tqdm import tqdm, trange
from rdkit import RDLogger, Chem
RDLogger.DisableLog('rdApp.*')

In [3]:
## Inter-op parallelism
# Set the number of threads PyTorch uses for inter-op parallelism to 4.
torch.set_num_interop_threads(4)
# NOTE(review): this value is not shown - a notebook cell only displays its last expression.
torch.get_num_interop_threads()
## Intra-op parallelism
# Set the number of threads PyTorch uses for intra-op parallelism to 4.
torch.set_num_threads(4)
# Last expression of the cell, so this is the value rendered as the cell output.
torch.get_num_threads()

4

In [4]:
class GPUCONFIGS:
    """Device selection for this notebook.

    Attributes:
        use_cuda: Result of ``torch.cuda.is_available()`` at construction time.
        device: ``torch.device('cuda:0')`` when ``use_cuda`` is true,
            otherwise ``torch.device('cpu')``.
    """

    def __init__(self):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
        # Only touch the CUDA runtime when a GPU is actually available.
        if self.use_cuda:
            torch.cuda.set_device(self.device)


gpuconfigs = GPUCONFIGS()
# `torch.cuda.current_device()` raises when no CUDA device is usable, so only
# query it on the GPU path; on a CPU-only machine report the selected device.
print(torch.cuda.current_device() if gpuconfigs.use_cuda else gpuconfigs.device)

0


In [5]:
# Absolute path of the working directory's parent; presumably where the
# `rebadd` package imported right after this lives - confirm the notebook is
# always launched from its own folder.
REBADD_LIB_PATH = os.path.abspath(os.pardir)
if REBADD_LIB_PATH not in sys.path:
    # Prepend in place (highest import priority) instead of rebinding
    # `sys.path` to a new list object.
    sys.path.insert(0, REBADD_LIB_PATH)

from rebadd.stackVAE import StackAugmentedVAE
from rebadd.datautils import GeneratorData

In [6]:
class DATACONFIGS:
    """Input and output locations used by this notebook.

    All paths are relative to the current working directory.

    Attributes:
        input_dir: ``../TASK3/outputs_0_preprocess_data``.
        train_data_path: ``fragments_list.pkl`` inside ``input_dir``.
        vocab_data_path: ``vocabulary.csv`` inside ``input_dir``.
        output_dir: Directory that receives this notebook's results; it is
            created on construction when it does not exist yet.

    Raises:
        OSError: If ``output_dir`` cannot be created (e.g. a regular file
            with that name already exists).
    """

    def __init__(self):
        ## input
        self.input_dir = os.path.join(os.pardir, 'TASK3', 'outputs_0_preprocess_data')
        self.train_data_path = os.path.join(self.input_dir, 'fragments_list.pkl')
        self.vocab_data_path = os.path.join(self.input_dir, 'vocabulary.csv')
        ## output
        self.output_dir = 'outputs_1_generate_molecules+frag'
        # Create the directory rather than asserting it exists: a fresh
        # checkout would otherwise fail here, and `assert` is skipped under -O.
        os.makedirs(self.output_dir, exist_ok=True)

# Notebook-level instance; later cells read the data paths and `output_dir` from it.
dataconfigs = DATACONFIGS()

In [7]:
# Build the data wrapper from the TASK3 preprocessing outputs
# (`fragments_list.pkl` and `vocabulary.csv`).
# NOTE(review): the `pickle_data_path` argument suggests the file is unpickled
# inside GeneratorData - only point it at trusted files; confirm.
gen_data = GeneratorData(pickle_data_path=dataconfigs.train_data_path,
                         vocabulary_path=dataconfigs.vocab_data_path,
                         use_cuda=gpuconfigs.use_cuda)

# Quick summary of what was loaded: sample count, vocabulary size, longest sequence.
print(f"Number of training samples: {len(gen_data.data)}")
print(f"Number of vocabulary: {len(gen_data.vocabs)}")
print(f"Maximum of seqlen: {gen_data.max_seqlen}")

Number of training samples: 6291
Number of vocabulary: 7496
Maximum of seqlen: 38


In [8]:
# Constructor arguments for the generator model.
# NOTE(review): `max_seqlen` is hard-coded to 44 instead of being read from
# `gen_data.max_seqlen` (38 in the recorded output); presumably it has to match
# the settings the TASK3 checkpoint was trained with - confirm before changing.
kwargs_generator = dict(
    input_size=gen_data.n_characters,
    output_size=gen_data.n_characters,
    max_seqlen=44,
    hidden_size=256,
    latent_size=64,
    n_layers=4,
    has_stack=True,
    stack_width=256,
    stack_depth=20,
    lr=1e-4,
    use_cuda=gpuconfigs.use_cuda,
    device=gpuconfigs.device,
    optimizer_instance=torch.optim.RMSprop,
)

generator = StackAugmentedVAE(**kwargs_generator)

In [9]:
class CKPTCONFIGS:
    """Location of the checkpoint directory used by this notebook.

    Attributes:
        input_dir: ``../TASK3/outputs_1_pretraining`` relative to the working
            directory.
    """

    def __init__(self):
        ## Get a pretrained model from the TASK3
        path_parts = (os.pardir, 'TASK3', 'outputs_1_pretraining')
        self.input_dir = os.path.join(*path_parts)


ckptconfigs = CKPTCONFIGS()

In [10]:
def SELFIES_generate(data, generator, sample_size=5000):
    """Sample ``sample_size`` strings from ``generator``, one at a time.

    Args:
        data: Object exposing ``start_token`` and ``end_token``; it is also
            passed through to ``generator.evaluate``.
        generator: Model exposing ``eval()``, ``sample_latent_vectors()`` and
            ``evaluate(data, z=..., return_z=..., greedy=...)``.
        sample_size: Number of strings to generate.

    Returns:
        list: The generated strings in sampling order, with every occurrence
        of the start and end tokens removed.

    Note:
        ``generator.eval()`` is called and the previous mode is not restored.
    """
    generator.eval()

    generated = []

    # Sampling is inference only, so run it without autograd tracking.
    with torch.no_grad():
        for _ in trange(sample_size):
            ## SELFIES
            z = generator.sample_latent_vectors()
            sel = generator.evaluate(data, z=z, return_z=False, greedy=False)
            sel = sel.replace(data.start_token, '').replace(data.end_token, '')

            generated.append(sel)

    return generated
        

def save_SELFIES(filepath, selfies):
    """Write every entry of ``selfies`` to ``filepath``, one per line.

    Args:
        filepath: Destination path; an existing file is overwritten.
        selfies: Sized iterable of strings to write.

    Prints an ``[INFO]`` line with the number of entries and the path.
    """
    with open(filepath, 'w') as fout:
        fout.writelines(f"{sel}\n" for sel in selfies)
    print(f"[INFO] {len(selfies)} SELFIES were saved in {filepath}")

In [11]:
# Load `checkpoint.pth` from the TASK3 pretraining output directory into `generator`.
# NOTE(review): there is no existence check here; a missing checkpoint surfaces
# as whatever error `load_model` raises.
filepath = os.path.join(ckptconfigs.input_dir, 'checkpoint.pth')
generator.load_model(filepath)

In [12]:
# Number of strings to sample from the loaded generator.
# NOTE(review): no RNG seed is set in the visible cells and sampling runs with
# greedy=False, so the generated set presumably differs between runs - consider seeding.
n_sampling = 5000
generated = SELFIES_generate(gen_data, generator, sample_size=n_sampling)   

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:11<00:00, 446.30it/s]


In [13]:
# NOTE(review): the file is named `smi_after.csv`, but `save_SELFIES` writes one
# SELFIES string per line (no header, no columns); the name is kept unchanged
# because other steps may read this path.
save_SELFIES(os.path.join(dataconfigs.output_dir, 'smi_after.csv'), generated)

[INFO] 5000 SELFIES were saved in outputs_1_generate_molecules+frag/smi_after.csv
