## Imports

In [1]:
# Expose only GPU 0 to this process; set here, before torch is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import os
import sys
import time
import numpy as np
import pandas as pd
import selfies as sf
import torch
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [3]:
class GPUCONFIGS:
    """Detect whether CUDA is usable and pick the torch device to run on.

    Attributes:
        use_cuda: result of ``torch.cuda.is_available()``.
        device: ``torch.device('cuda:0')`` when CUDA is available, otherwise
            ``torch.device('cpu')``.
    """

    def __init__(self):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
        # Only touch the CUDA runtime when a GPU is actually present.
        if self.use_cuda:
            torch.cuda.set_device(self.device)


gpuconfigs = GPUCONFIGS()
# torch.cuda.current_device() raises on CPU-only machines, so print the selected
# device there instead; on a GPU machine the printed value is unchanged.
print(torch.cuda.current_device() if gpuconfigs.use_cuda else gpuconfigs.device)

0


In [4]:
# Put the parent of the current working directory at the front of the import
# search path (presumably so the `rebadd` imports right below resolve to the
# local checkout).
REBADD_LIB_PATH = os.path.abspath(os.pardir)
if REBADD_LIB_PATH not in sys.path:
    # Prepend in place instead of rebinding sys.path to a new list, so any
    # existing reference to the list object keeps seeing the update.
    sys.path.insert(0, REBADD_LIB_PATH)

from rebadd.stackVAE import StackAugmentedVAE
from rebadd.datautils import GeneratorData

## Setting up the generator

### Loading data for the generator

In [5]:
class DATACONFIGS:
    """Filesystem locations read and written by this notebook.

    Args:
        input_dir: directory that holds ``fragments_list.pkl`` and
            ``vocabulary.csv``.
        output_dir: directory for this notebook's outputs; created
            (including parents) when it does not exist yet.

    Raises:
        FileExistsError: if ``output_dir`` exists but is not a directory.
    """

    def __init__(self, input_dir='outputs_0_preprocess_data', output_dir='outputs_1_pretraining'):
        ## input
        self.input_dir = input_dir
        self.train_data_path = os.path.join(self.input_dir, 'fragments_list.pkl')
        self.vocab_data_path = os.path.join(self.input_dir, 'vocabulary.csv')
        ## output - created on demand so the notebook runs on a fresh checkout.
        ## (An `assert` is skipped under `python -O`, so it is not used for this check.)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)


dataconfigs = DATACONFIGS()

In [6]:
# Build the training dataset from the pickled fragment list and the vocabulary file.
gen_data = GeneratorData(
    pickle_data_path=dataconfigs.train_data_path,
    vocabulary_path=dataconfigs.vocab_data_path,
    use_cuda=gpuconfigs.use_cuda,
)

# Summarize what was loaded: sample count, vocabulary size and longest sequence.
print(f"Number of training samples: {len(gen_data.data)}")
print(f"Number of vocabulary: {len(gen_data.vocabs)}")
print(f"Maximum of seqlen: {gen_data.max_seqlen}")

Number of training samples: 295601
Number of vocabulary: 34620
Maximum of seqlen: 34


## Initializing and training the generator

We will use a stack-augmented generative GRU as the generator. The model is trained to predict the next symbol from the SMILES alphabet given the already generated prefix, by minimizing the cross-entropy loss between the predicted symbol and the ground-truth symbol. A scheme of the generator when inferring new SMILES is shown below:

Initialize stack-augmented generative RNN:

In [7]:
# Hyperparameters that are unpacked into the StackAugmentedVAE constructor.
kwargs_generator = dict(
    # Input and output sizes are both tied to gen_data.n_characters.
    input_size=gen_data.n_characters,
    output_size=gen_data.n_characters,
    max_seqlen=40,
    hidden_size=256,
    latent_size=64,
    n_layers=4,
    has_stack=True,
    stack_width=256,
    stack_depth=20,
    lr=1e-3,
    use_cuda=gpuconfigs.use_cuda,
    device=gpuconfigs.device,
    optimizer_instance=torch.optim.RMSprop,
)

In [8]:
# Instantiate the generator from the hyperparameters collected in kwargs_generator.
my_generator = StackAugmentedVAE(**kwargs_generator)

If you want to train the model from scratch, run the cells below:

In [9]:
# Destination files inside the output directory: the model checkpoint and the loss log.
model_path = os.path.join(dataconfigs.output_dir, 'checkpoint.pth')
losses_path = os.path.join(dataconfigs.output_dir, 'losses.txt')

In [None]:
# Train the generator. The returned `losses` is indexed by loss name
# ('LOSS_VAE', 'LOSS_RECONSTRUCTION', 'LOSS_KLDIVERGENCE', 'BETA') in later cells.
# NOTE(review): print_every / ckpt_every are presumably iteration intervals for
# logging and checkpointing — confirm against StackAugmentedVAE.fit.
losses = my_generator.fit(gen_data, n_iterations=20000,
                          batch_size=50,
                          print_every=1000,
                          ckpt_every=1000,
                          model_path=model_path,
                          losses_path=losses_path)

Training in progress...:   5%|███████▏                                                                                                                                        | 1000/20000 [20:28<6:37:17,  1.25s/it]

[01000 (5.0%) 20m 28s], Loss_vae:4.834, Loss_rec:4.583, Loss_kld:13.933, Beta:0.050
selfies: [O][=C][Branch1][#Branch2][C][C][=C][C][=C][C][=C][Ring1][=Branch1][C][C][C][C][C][N][Ring1][=Branch1][C][C][=C][C][=C][Branch1][C][Cl][C][=C][Ring1][#Branch1]
smiles: O=C(CC1=CC=CC=C1)C2CCCCN2CC3=CC=C(Cl)C=C3


Training in progress...:  10%|██████████████▍                                                                                                                                 | 2000/20000 [40:55<6:13:42,  1.25s/it]

[02000 (10.0%) 40m 55s], Loss_vae:4.218, Loss_rec:3.834, Loss_kld:5.316, Beta:0.100
selfies: [C][O][C][=C][C][=C][C][=Branch1][Ring2][=C][Ring1][=Branch1][C][=Branch1][C][=O][N][Branch1][Branch1][C][C][C][O][C][=C][C][=C][N][=C][Ring1][=Branch1][N][Ring1][=Branch2]
smiles: COC1=CC=CC(=C1)C(=O)N(CCC2O)C3=CC=CN=C3N2


Training in progress...:  15%|█████████████████████▎                                                                                                                        | 3000/20000 [1:01:24<5:59:34,  1.27s/it]

[03000 (15.0%) 61m 24s], Loss_vae:4.087, Loss_rec:3.794, Loss_kld:2.428, Beta:0.150
selfies: [C][N][Branch1][C][C][C][=Branch1][C][=O][C][C][C][C][C][S][C][N][Branch1][=C][C][N][C][=Branch1][C][=O][C][C][C][C][C][Ring1][Branch1][C][C][Ring2][Ring1][Branch2]
smiles: C1N(C)C(=O)CCCCCSCN(CNC(=O)C2CCCC2)CC1


Training in progress...:  20%|████████████████████████████▍                                                                                                                 | 4000/20000 [1:21:49<5:33:32,  1.25s/it]

[04000 (20.0%) 81m 48s], Loss_vae:3.959, Loss_rec:3.826, Loss_kld:0.782, Beta:0.200
selfies: [C][C][=Branch1][C][=O][N][=C][S][C][=C][N][Ring1][#C][Branch1][Ring2][C][Ring1][=Branch2][Branch1][Branch1][C][C][O][C][C][Ring1][#Branch1][C][=C][C][=C][Branch1][C][Cl][C][=C][Ring1][#Branch1]
smiles: C1C(=O)N=CSC2=CN12


Training in progress...:  25%|███████████████████████████████████▌                                                                                                          | 5000/20000 [1:42:12<5:13:25,  1.25s/it]

[05000 (25.0%) 102m 12s], Loss_vae:3.822, Loss_rec:3.766, Loss_kld:0.253, Beta:0.250
selfies: [N][=N][C][=C][C][N][Ring1][Ring2][C][=C][C][=C][Branch1][=Branch2][O][C][C][C][C][N][Ring1][Ring2][C][C][C][N][C][C][O][C][Ring1][#Branch2]
smiles: N=NC1=CCN1C=CC=C(OCC2CC3N2)CCCNCCOC3


Training in progress...:  30%|██████████████████████████████████████████▌                                                                                                   | 6000/20000 [2:02:39<5:02:01,  1.29s/it]

[06000 (30.0%) 122m 39s], Loss_vae:3.711, Loss_rec:3.682, Loss_kld:0.108, Beta:0.300
selfies: [C][O][C][=Branch1][C][=O][N][C][Branch1][C][C][N][C][=Branch1][C][=O][C][C][C][N][N][=C][Branch1][C][C][C][=C][C][=C][Ring1][#Branch1][C][Branch1][=N][C][=C][C][=C][Branch1][Ring1][C][#N][C][=C][Ring1][Branch2][C][Ring2][Ring1][Ring2]
smiles: COC(=O)NC(C)NC(=O)CC1CNN=C(C)C=CC=CC(C2=CC=C(C#N)C=C2)C1


Training in progress...:  35%|█████████████████████████████████████████████████▋                                                                                            | 7000/20000 [2:23:06<4:37:31,  1.28s/it]

[07000 (35.0%) 143m 6s], Loss_vae:3.628, Loss_rec:3.605, Loss_kld:0.073, Beta:0.350
selfies: [C][=C][C][Branch1][Branch1][C][C][C][C][=C][Branch1][C][C][C][=Branch1][C][=O][C][=C][Branch1][C][O][C][N][C][C][C][N][C][=C][C][=N][C][=C][C][Ring1][=Branch1]
smiles: C=CC(CCCC)=C(C)C(=O)C=C(O)CNCCCNC=C1C=NC=CC1


Training in progress...:  40%|████████████████████████████████████████████████████████▊                                                                                     | 8000/20000 [2:43:36<4:14:17,  1.27s/it]

[08000 (40.0%) 163m 36s], Loss_vae:3.555, Loss_rec:3.533, Loss_kld:0.060, Beta:0.400
selfies: [O][=C][S][N][=C][Ring1][N][C][=C][N][=C][C][Branch1][C][Br][=C][Ring1][#Branch1]
smiles: O=CSN=CC1=CN=CC(Br)=C1


Training in progress...:  42%|███████████████████████████████████████████████████████████▊                                                                                  | 8416/20000 [2:52:06<3:53:21,  1.21s/it]

In [None]:
# Save the generator to the checkpoint path defined earlier.
my_generator.save_model(model_path)

In [None]:
# Dump the recorded loss values as a tab-separated table: one header line,
# then one row per recorded entry.
with open(losses_path, 'w') as fout:
    fout.write("LOSS_VAE\tLOSS_RECONSTRUCTION\tLOSS_KLDIVERGENCE\tBETA\n")
    fout.writelines(
        f"{loss_vae:.6f}\t{loss_rec:.6f}\t{loss_kld:.6f}\t{beta:.3f}\n"
        for loss_vae, loss_rec, loss_kld, beta in zip(
            losses['LOSS_VAE'],
            losses['LOSS_RECONSTRUCTION'],
            losses['LOSS_KLDIVERGENCE'],
            losses["BETA"],
        )
    )

In [None]:
sns.set_theme(style='whitegrid')

fig, axes = plt.subplots(3,1,figsize=(6,9.9))

# One panel per loss term, top to bottom: (key in `losses`, legend label).
# The first 10 recorded values of every curve are left out of the plot.
panels = [
    ('LOSS_VAE', 'ELBO Loss'),
    ('LOSS_RECONSTRUCTION', 'Reconstruction Loss'),
    ('LOSS_KLDIVERGENCE', 'KL divergence'),
]
for ax, (loss_key, legend_label) in zip(axes, panels):
    ax.plot(losses[loss_key][10:], label=legend_label, linewidth=2)
    ax.legend(loc='best')

# Only the bottom panel gets an x-axis label; it is also drawn on a log scale.
axes[2].set_xlabel('Iterations', fontsize=16)
axes[2].set_yscale('log')

plt.tight_layout()
plt.show()

In [None]:
def SMILES_generate(generator, n_to_generate, gen_data):
    """Sample sequences from the generator and decode them with ``sf.decoder``.

    For each of the ``n_to_generate`` iterations a fresh latent vector is drawn
    with ``generator.sample_latent_vectors()`` and passed to
    ``generator.evaluate``. The start and end tokens of ``gen_data`` are removed
    from the returned string before it is decoded.

    Args:
        generator: object providing ``sample_latent_vectors()`` and
            ``evaluate(gen_data, z=...)``.
        n_to_generate: number of sequences to sample.
        gen_data: object providing ``start_token`` and ``end_token``; also
            forwarded to ``generator.evaluate``.

    Returns:
        List with one decoded result per sample, in sampling order.
    """
    decoded_smiles = []
    for _ in trange(n_to_generate):
        latent = generator.sample_latent_vectors()
        raw_sequence = generator.evaluate(gen_data, z=latent)
        # Strip the sequence delimiters before handing the string to the decoder.
        without_start = raw_sequence.replace(gen_data.start_token, "")
        selfies_string = without_start.replace(gen_data.end_token, "")
        decoded_smiles.append(sf.decoder(selfies_string))
    return decoded_smiles

In [None]:
def save_smiles(filepath, smiles):
    """Write the given entries to ``filepath``, one per line, and print a summary.

    Args:
        filepath: destination path; opened in text write mode, so existing
            content is replaced.
        smiles: sized iterable of entries; each is formatted into a line
            terminated by a newline.
    """
    with open(filepath, 'w') as fout:
        fout.writelines(f"{smi}\n" for smi in smiles)
    print(f"[INFO] {len(smiles)} SMILES were saved in {filepath}")

In [None]:
# Number of sequences to sample from the generator.
n_sampling = 30000

smi_after = SMILES_generate(my_generator, n_sampling, gen_data)

In [None]:
# Write the sampled SMILES to the output directory, one entry per line.
save_smiles(os.path.join(dataconfigs.output_dir, "smi_after.csv"), smi_after)