# Reference code: JAK2 activity optimization with ReLeaSE algorithm

## Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=0

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import time
import numpy as np
import pandas as pd
import selfies as sf
import torch
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [None]:
class GPUCONFIGS:
    """Pick the torch device for this notebook, falling back to the CPU.

    Attributes set in ``__init__``:
        use_cuda: result of ``torch.cuda.is_available()``.
        device: ``torch.device('cuda:0')`` when CUDA is available,
            otherwise ``torch.device('cpu')``.
    """

    def __init__(self):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
        # Only touch the CUDA runtime when a GPU is actually present.
        if self.use_cuda:
            torch.cuda.set_device(self.device)


gpuconfigs = GPUCONFIGS()
# torch.cuda.current_device() raises on machines without CUDA, so query it
# only when a GPU was detected; otherwise report the CPU device selected above.
print(torch.cuda.current_device() if gpuconfigs.use_cuda else gpuconfigs.device)

In [None]:
# Put the parent directory at the front of the import search path (once);
# the `rebadd` imports that follow presumably rely on it.
REBADD_LIB_PATH = os.path.abspath(os.pardir)
already_listed = any(entry == REBADD_LIB_PATH for entry in sys.path)
if not already_listed:
    sys.path = [REBADD_LIB_PATH, *sys.path]

In [None]:
from rebadd.stackVAE import StackAugmentedVAE
from rebadd.datautils import GeneratorData, normalize_SMILES

## Setting up the generator

### Loading data for the generator

In [None]:
class DATACONFIGS:
    """Input and output locations used by this notebook.

    Attributes set in ``__init__`` (all paths are relative to the current
    working directory):
        input_dir: ``processed_data/zinc15``.
        train_data_path: ``fragments_list.pkl`` inside ``input_dir``.
        vocab_data_path: ``vocabulary.csv`` inside ``input_dir``.
        output_dir: ``outputs_1_pretraining_ReBADD/zinc15``; created if it
            does not exist yet.
    """

    def __init__(self):
        ## input
        self.input_dir = os.path.join('processed_data', 'zinc15')
        self.train_data_path = os.path.join(self.input_dir, 'fragments_list.pkl')
        self.vocab_data_path = os.path.join(self.input_dir, 'vocabulary.csv')
        ## output
        self.output_dir = os.path.join('outputs_1_pretraining_ReBADD', 'zinc15')
        # Create the output directory instead of asserting that it exists:
        # an assert is stripped under `python -O`, and on a fresh checkout it
        # aborts the notebook with a message-less AssertionError.
        os.makedirs(self.output_dir, exist_ok=True)


dataconfigs = DATACONFIGS()

In [None]:
# Build the generator's data wrapper from the fragment pickle and the
# vocabulary file configured in `dataconfigs`.
gen_data = GeneratorData(
    pickle_data_path=dataconfigs.train_data_path,
    vocabulary_path=dataconfigs.vocab_data_path,
    use_cuda=gpuconfigs.use_cuda,
)

# Quick sanity summary of what was loaded.
print(f"Number of training samples: {len(gen_data.data)}")
print(f"Number of vocabulary: {len(gen_data.vocabs)}")
print(f"Maximum of seqlen: {gen_data.max_seqlen}")

## Initializing and training the generator

We will use a stack-augmented generative GRU as the generator. The model is trained to predict the next symbol of the SMILES alphabet given the already generated prefix, by minimizing the cross-entropy loss between the predicted symbol and the ground-truth symbol. The scheme of the generator when inferring new SMILES is shown below:

Initialize stack-augmented generative RNN:

In [None]:
# Keyword arguments later unpacked into the StackAugmentedVAE constructor.
# NOTE(review): max_seqlen is hard-coded to 44 — confirm it agrees with
# gen_data.max_seqlen for the dataset in use.
kwargs_generator = dict(
    input_size=gen_data.n_characters,
    output_size=gen_data.n_characters,
    max_seqlen=44,
    hidden_size=256,
    latent_size=64,
    n_layers=4,
    has_stack=True,
    stack_width=256,
    stack_depth=20,
    lr=0.001,
    use_cuda=gpuconfigs.use_cuda,
    device=gpuconfigs.device,
    optimizer_instance=torch.optim.RMSprop,
)

In [None]:
# Instantiate the generator from the keyword arguments in kwargs_generator.
my_generator = StackAugmentedVAE(**kwargs_generator)

If you want to train the model from scratch, uncomment the lines below:

In [None]:
# The checkpoint file and the loss log both live in the configured output directory.
model_path, losses_path = (
    os.path.join(dataconfigs.output_dir, filename)
    for filename in ('checkpoint.pth', 'losses.txt')
)

In [None]:
# Training settings gathered in one place and unpacked into fit().
# The returned `losses` is indexed by string keys further down
# ('LOSS_VAE', 'LOSS_RECONSTRUCTION', 'LOSS_KLDIVERGENCE', 'BETA').
fit_settings = dict(
    n_iterations=15000,
    batch_size=32,
    print_every=1000,
    ckpt_every=1000,
    model_path=model_path,
    losses_path=losses_path,
)
losses = my_generator.fit(gen_data, **fit_settings)

In [None]:
# Save the generator to model_path through its save_model method.
# NOTE(review): fit() above was also given model_path and ckpt_every, so it
# presumably writes checkpoints to the same file — confirm this final save is
# intended to overwrite them.
my_generator.save_model(model_path)

In [None]:
# Dump the recorded loss history as a tab-separated table: a header row,
# then one row per recorded entry.
with open(losses_path, 'w') as fout:
    fout.write("LOSS_VAE\tLOSS_RECONSTRUCTION\tLOSS_KLDIVERGENCE\tBETA\n")
    history_columns = (
        losses['LOSS_VAE'],
        losses['LOSS_RECONSTRUCTION'],
        losses['LOSS_KLDIVERGENCE'],
        losses["BETA"],
    )
    fout.writelines(
        f"{vae:.6f}\t{rec:.6f}\t{kld:.6f}\t{beta_value:.3f}\n"
        for vae, rec, kld, beta_value in zip(*history_columns)
    )

In [None]:
sns.set_theme(style='whitegrid')

fig, axes = plt.subplots(3, 1, figsize=(6, 9.9))

# One stacked panel per loss term: (key in `losses`, legend label).
loss_panels = (
    ('LOSS_VAE', 'ELBO Loss'),
    ('LOSS_RECONSTRUCTION', 'Reconstruction Loss'),
    ('LOSS_KLDIVERGENCE', 'KL divergence'),
)
for panel, (loss_key, legend_label) in zip(axes, loss_panels):
    panel.plot(losses[loss_key], label=legend_label, linewidth=2)
    panel.legend(loc='best')

# Only the bottom panel carries the x-axis label; it is also the one drawn
# on a logarithmic y-scale.
bottom_panel = axes[2]
bottom_panel.set_xlabel('Iterations', fontsize=16)
bottom_panel.set_yscale('log')

plt.tight_layout()
plt.show()

In [None]:
def SMILES_generate(generator, n_to_generate, gen_data):
    """Sample ``n_to_generate`` sequences from ``generator`` as SMILES strings.

    Each sample comes from ``generator.evaluate(gen_data)``. The start and
    end tokens of ``gen_data`` are stripped from it, the remainder is passed
    through ``sf.decoder`` and the result through ``normalize_SMILES``.

    Args:
        generator: object exposing ``evaluate(gen_data)`` that returns a string.
        n_to_generate: number of samples to draw.
        gen_data: data wrapper providing ``start_token`` and ``end_token``.

    Returns:
        A list with one ``normalize_SMILES`` result per sample, in sampling order.
    """

    def _sample_one():
        # Draw one raw sequence and remove the boundary tokens before decoding.
        raw = generator.evaluate(gen_data)
        without_start = raw.replace(gen_data.start_token, "")
        selfies_string = without_start.replace(gen_data.end_token, "")
        return normalize_SMILES(sf.decoder(selfies_string))

    # trange supplies the progress bar; the counter itself is not needed.
    return [_sample_one() for _ in trange(n_to_generate)]

In [None]:
def save_smiles(filepath, smiles):
    """Write ``smiles`` to ``filepath``, one entry per line, and report the count.

    Args:
        filepath: destination file; it is opened in ``'w'`` mode, so any
            existing content is replaced.
        smiles: sized iterable of entries; each is formatted with an f-string
            and followed by a newline.
    """
    with open(filepath, 'w') as handle:
        handle.writelines(f"{entry}\n" for entry in smiles)
    print(f"[INFO] {len(smiles)} SMILES were saved in {filepath}")

In [None]:
# Number of molecules to sample from the generator.
n_sampling = 30000

# Draw the samples; SMILES_generate returns them as a list of SMILES strings.
smi_after = SMILES_generate(my_generator, n_sampling, gen_data)

In [None]:
# Store the sampled SMILES in the output directory, one per line
# (despite the .csv extension, save_smiles writes no header or delimiter).
save_smiles(os.path.join(dataconfigs.output_dir, "smi_after.csv"), smi_after)