In [None]:
import os
import sys

# Restrict this process to a single GPU (index 1). This is set before torch is
# imported in the next cell so the restriction is in effect when CUDA starts.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Make the project modules imported below (data, stackRNN, utils, reinforcement,
# predictor) importable; presumably they live in ../release.
sys.path.append('../release')

In [None]:
import numpy as np
import pandas as pd
from tqdm import trange
import torch
# True when at least one CUDA device is visible to this process (after the
# CUDA_VISIBLE_DEVICES setting of the previous cell); passed to the generator.
use_cuda = torch.cuda.is_available()

In [None]:
from data import GeneratorData, PredictorData
from stackRNN import StackAugmentedRNN
from utils import get_fp
from reinforcement import Reinforcement

In [None]:
from sklearn.ensemble import RandomForestClassifier as RFC
from predictor import VanillaQSAR

# Training the predictor

In [None]:
# Seed NumPy's global RNG before building the predictor.
# NOTE(review): this only makes the random forests reproducible if they draw
# from the global NumPy RNG (no explicit random_state is passed) — confirm.
np.random.seed(42)

# Activity data (EGFR, per the filename), featurized with get_fp.
pred_data = PredictorData('../data/egfr_with_pubchem.csv', get_features=get_fp)

# QSAR model: presumably an ensemble of 10 RandomForestClassifier models,
# each with 250 trees trained on 10 parallel jobs.
model_instance = RFC
model_params = dict(n_estimators=250, n_jobs=10)
my_predictor = VanillaQSAR(
    model_instance=model_instance,
    model_params=model_params,
    ensemble_size=10,
)

In [None]:
# uncomment to train predictor model...
# my_predictor.fit_model(pred_data, cv_split='random')
# my_predictor.save_model('../checkpoints/predictor/egfr_rfc')

In [None]:
# ...or use pretrained model
# Restores the predictor from a saved checkpoint; this path must match the one
# the model was saved to when it was trained.
my_predictor.load_model('../checkpoints/predictor/egfr_rfc')

# Pre-train the generative model

In [None]:
# Re-seed NumPy and PyTorch so building and loading the generator below is
# reproducible. torch.manual_seed returns a torch.Generator; the trailing ';'
# stops the notebook from echoing that object as this cell's output.
np.random.seed(42)
torch.manual_seed(42);

In [None]:
# SMILES character vocabulary.
# NOTE(review): the list order presumably defines the character -> index
# mapping, so it must stay identical to the vocabulary the pre-trained generator
# checkpoint was built with — confirm before editing. '<' / '>' look like
# start/end markers and ' ' like padding — TODO confirm.
tokens = [' ', '<', '>', '#', '%', ')', '(', '+', '-', '/', '.', '1', '0', '3', '2', '5', '4', '7',
          '6', '9', '8', '=', 'A', '@', 'C', 'B', 'F', 'I', 'H', 'O', 'N', 'P', 'S', '[', ']',
          '\\', 'c', 'e', 'i', 'l', 'o', 'n', 'p', 's', 'r']
# ChEMBL-derived SMILES corpus (per the filename), read as tab-delimited and
# keeping only column 0.
gen_data_path = '../data/chembl_22_clean_1576904_sorted_std_final.smi'
gen_data = GeneratorData(training_data_path=gen_data_path, delimiter='\t',
                         cols_to_read=[0], keep_header=True, tokens=tokens)

In [None]:
# Stack-augmented GRU generator: hidden state and stack sizes, learning rate,
# and the optimizer class (not an instance) handed to the model.
hidden_size = 1500
stack_width = 1500
stack_depth = 200
layer_type = 'GRU'
lr = 0.0002
optimizer = torch.optim.Adadelta

# Input and output dimensions both equal the vocabulary size of gen_data.
my_generator = StackAugmentedRNN(
    input_size=gen_data.n_characters,
    hidden_size=hidden_size,
    output_size=gen_data.n_characters,
    layer_type=layer_type,
    n_layers=1,
    is_bidirectional=False,
    has_stack=True,
    stack_width=stack_width,
    stack_depth=stack_depth,
    use_cuda=use_cuda,
    lr=lr,
    optimizer_instance=optimizer,
)

In [None]:
# Checkpoint the pre-trained generator is saved to (commented code below) and
# loaded from (next cell), plus the settings used to pretrain it.
model_path = '../checkpoints/generator/checkpoint_biggest_rnn1'
batch_size = 16
n_iterations = 1_500_000

# To pretrain the generator from scratch instead of loading the checkpoint,
# uncomment the lines below; the per-iteration losses are written one per line.
# losses = my_generator.fit(gen_data, batch_size, n_iterations)
# my_generator.save_model(model_path)
# with open('losses.txt','wt') as f:
#     for val in losses:
#         print(val, file=f)

In [None]:
# ... or load pre-trained model
# Use load_model, the counterpart of the save_model call used when pretraining
# and the same method the fine-tuning loop uses to restore these weights.
my_generator.load_model(model_path)

# Fine-tune the generative model on molecules with predicted activity against EGFR

In [None]:
def get_reward_max(smiles, predictor, threshold, invalid_reward=1.0, get_features=get_fp):
    """Return a two-level reward for a single SMILES string.

    The SMILES is scored with ``predictor.predict``. If the predictor reports it
    in the third item of its result (presumably SMILES it could not featurize —
    TODO confirm), ``invalid_reward`` is returned. Otherwise the reward is 10.0
    when the first prediction is at least ``threshold``, and ``invalid_reward``
    when it is below.

    Args:
        smiles: one SMILES string.
        predictor: object whose ``predict([smiles], get_features=...)`` returns a
            3-tuple; the 2nd item holds the predictions, the 3rd the SMILES it
            rejected.
        threshold: minimum prediction that earns the high reward.
        invalid_reward: reward for rejected SMILES and for sub-threshold ones.
        get_features: featurizer forwarded to ``predictor.predict``.

    Returns:
        10.0 for an accepted SMILES at or above ``threshold``, otherwise
        ``invalid_reward``.
    """
    _, predictions, rejected = predictor.predict([smiles], get_features=get_features)
    accepted = len(rejected) != 1
    # Short-circuit: predictions[0] is only inspected for accepted SMILES.
    if accepted and predictions[0] >= threshold:
        return 10.0
    return invalid_reward

In [None]:
# Bundle the generator, the predictor and the reward function into one
# Reinforcement object (positional order: generator, predictor, reward).
RL_max = Reinforcement(my_generator, my_predictor, get_reward_max)

In [None]:
# Fine-tune a fresh copy of the pre-trained generator on each dataset and save
# one checkpoint per dataset; data_path[k] is paired with save_path[k].
data_path = ['../data/egfr_actives.smi',
             '../data/egfr_enamine.smi',
             '../data/egfr_mixed.smi']
save_path = ['../checkpoints/generator/egfr_clf_rnn_primed',
             '../checkpoints/generator/egfr_clf_rnn_enamine_primed',
             '../checkpoints/generator/egfr_clf_rnn_mixed_primed']
# Note: this overwrites the pretraining value of n_iterations set earlier.
n_iterations = 250

# zip() would silently drop unmatched entries; fail loudly instead.
assert len(data_path) == len(save_path), 'data_path and save_path must have the same length'

for dpath, mpath in zip(data_path, save_path):
    print('Pretraining on %s' % dpath)
    # Re-seed per dataset so each run is reproducible on its own.
    np.random.seed(42)
    torch.manual_seed(42)

    actives_data = GeneratorData(dpath,
                                 tokens=tokens,
                                 cols_to_read=[0],
                                 keep_header=True)
    # Restart from the pre-trained weights so the runs do not build on each other.
    RL_max.generator.load_model(model_path)
    # NOTE(review): fine_tune is called n_iterations times with
    # n_steps=n_iterations, i.e. n_iterations**2 steps per dataset. Kept as-is;
    # confirm this is intended rather than a single call.
    for _ in trange(n_iterations, desc=dpath):
        RL_max.fine_tune(data=actives_data, n_steps=n_iterations, batch_size=16)
    RL_max.generator.save_model(mpath)