In [35]:
import os
from tqdm import tqdm 
from rdkit.Chem import QED
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import pandas as pd
import time
from IPython.display import clear_output
import itertools
import importlib
import random
from utils import *

In [37]:
# --- Vocabulary ---------------------------------------------------------------
# Read the SELFIES token list (one token per line). Blank lines are skipped so a
# stray empty line cannot add an empty-string token and change the vocabulary
# size the checkpoint was trained with.
with open("./data/zinc15_selfies_tokens.txt", "r", encoding="utf-8") as file:
    token_in_dataset = [tok for tok in (line.strip() for line in file) if tok]

print("Loaded tokens:", token_in_dataset)

# Special tokens take the first indices; dataset tokens are appended after them.
word2index = {"<pad>": 0, "<unk>": 1, "<sos>": 2, "<eos>": 3}
index2word = {index: word for word, index in word2index.items()}

start_index = max(index2word.keys()) + 1

for i, token in enumerate(token_in_dataset, start=start_index):
    word2index[token] = i
    index2word[i] = token

print("word2index:", word2index)
print("index2word:", index2word)

# --- Data ---------------------------------------------------------------------
data_path = './data/zinc15_sample.csv'
df = pd.read_csv(data_path)
smiles = df['smiles'].to_numpy()
selfies = df['selfies'].to_numpy()

# Property columns used as prediction/optimization targets. Defined explicitly
# here because later cells read `prop_name`, and no other visible cell defines
# it; `prop_num` and `prop_data` are derived from it so they cannot drift apart.
prop_name = ['bcl2', 'bclxl', 'bclw']
prop_num = len(prop_name)
prop_data = [df[name].to_numpy() for name in prop_name]

# --- Hyperparameters (must match the checkpoint loaded below) -----------------
embed_dim = 256                    # Token embedding dimension
hidden_dim = 512                   # GRU hidden-state dimension
latent_dim = 256                   # Latent vector (z) dimension
en_n_l = 3                         # Encoder GRU number of layers
de_n_l = 2                         # Decoder GRU number of layers
base_batch_size = 128              # Batch size of training data
learning_rate = 1e-4

# --- Model --------------------------------------------------------------------
from VAE import *
model = VAE(voca_dim=len(word2index),
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            en_num_layers=en_n_l,
            de_num_layers=de_n_l,
            prop_num=prop_num,
            run_predictor=True,
            value_range=None).to(device)
pt_path = './model/finetuned_Model_A.pt'
# map_location lets a checkpoint saved on a GPU be restored on whatever
# `device` this session is using (including CPU-only machines).
state_dict = torch.load(pt_path, map_location=device)['model_state_dict']
model.load_state_dict(state_dict)
# Last expression of the cell: switches to eval mode and displays the model.
model.eval()

Loaded tokens: ['[=Branch1]', '[#Branch1]', '[=Branch2]', '[#Branch2]', '[Branch1]', '[Branch2]', '[=Ring1]', '[=Ring2]', '[Ring1]', '[Ring2]', '[NH1+1]', '[CH1-1]', '[=N+1]', '[=N-1]', '[=S+1]', '[=PH1]', '[#N+1]', '[N+1]', '[O-1]', '[NH1]', '[CH0]', '[N-1]', '[OH0]', '[PH1]', '[C-1]', '[S+1]', '[CH1]', '[NH0]', '[PH0]', '[SH1]', '[=C]', '[=O]', '[=N]', '[Cl]', '[#C]', '[Br]', '[=S]', '[=P]', '[#N]', '[C]', '[N]', '[O]', '[S]', '[P]', '[F]']
word2index: {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3, '[=Branch1]': 4, '[#Branch1]': 5, '[=Branch2]': 6, '[#Branch2]': 7, '[Branch1]': 8, '[Branch2]': 9, '[=Ring1]': 10, '[=Ring2]': 11, '[Ring1]': 12, '[Ring2]': 13, '[NH1+1]': 14, '[CH1-1]': 15, '[=N+1]': 16, '[=N-1]': 17, '[=S+1]': 18, '[=PH1]': 19, '[#N+1]': 20, '[N+1]': 21, '[O-1]': 22, '[NH1]': 23, '[CH0]': 24, '[N-1]': 25, '[OH0]': 26, '[PH1]': 27, '[C-1]': 28, '[S+1]': 29, '[CH1]': 30, '[NH0]': 31, '[PH0]': 32, '[SH1]': 33, '[=C]': 34, '[=O]': 35, '[=N]': 36, '[Cl]': 37, '[#C]': 38, '

VAE(
  (embed): Embedding(49, 256)
  (generator): Sequential(
    (0): Linear(in_features=512, out_features=49, bias=False)
    (1): LogSoftmax(dim=-1)
  )
  (encoder): Encoder(
    (gru): GRU(256, 512, num_layers=3, batch_first=True, bidirectional=True)
    (mu): Linear(in_features=1024, out_features=256, bias=True)
    (std): Sequential(
      (0): Linear(in_features=1024, out_features=256, bias=True)
      (1): Softplus(beta=1, threshold=20)
    )
  )
  (decoder): Decoder(
    (bridge): Linear(in_features=256, out_features=1024, bias=True)
    (decoder_gru): GRU(256, 512, num_layers=2, batch_first=True)
    (pre_output): Linear(in_features=768, out_features=512, bias=True)
  )
  (predictors): ModuleList(
    (0): Predictor(
      (latent_layer_1): Linear(in_features=256, out_features=512, bias=True)
      (latent_layer_2): Linear(in_features=512, out_features=512, bias=True)
      (latent_layer_3): Linear(in_features=512, out_features=512, bias=True)
      (latent_layer_4): Linear(i

In [38]:
# Per-property optimization settings:
#   name -> (target value, direction 'up'/'down', {'min': ..., 'max': ...})
# Only the entries whose names appear in `prop_name` are used below.
target_score_dict = {'qed' : (0.8, 'up', {'min' : 0, 'max' : 1}),
                     'SAs' : (3.0, 'down', {'min' : 1, 'max' : 10}),
                     'bcl2' : (10.34, 'up', {'min' : 2.15, 'max' : 10.34}),
                     'bclxl' : (9.85, 'up', {'min' : 1.91, 'max' : 9.85}),
                     'bclw' : (8.68, 'up', {'min' : 2.23, 'max' : 8.68})}
target_list = [target_score_dict[name][0] for name in prop_name]
target_up = [target_score_dict[name][1] == 'up' for name in prop_name]  # True when the target direction is 'up'
target_tensor = torch.tensor(target_list).to(device)
target_max = [target_score_dict[name][2]['max'] for name in prop_name]
target_min = [target_score_dict[name][2]['min'] for name in prop_name]

# Fixed-seed sample of 100 molecules used as optimization seeds.
final_df = df.sample(n=100, random_state=2025)
smiles = final_df['smiles'].to_numpy()
selfies = final_df['selfies'].to_numpy()

prop_data = [final_df[prop_name[i]].to_numpy() for i in range(prop_num)]

dataset = selfiesDataset(selfies, prop_data, word2index, device, num_samples=None)
# batch_size=1 and shuffle=False: one seed molecule per batch, in sample order.
data_loader = DataLoader(dataset,
                         batch_size=1,
                         shuffle=False,
                         collate_fn=lambda samples: collate_fn(samples, word2index, dataset.pattern, device))
print(final_df[['smiles'] + prop_name])


# Setting odds
# One weight per latent dimension and per property, starting at 1. Thresholds
# are visited from highest to lowest; a dimension whose mask entry equals 1 gets
# weight `3 + thold` for the FIRST (i.e. highest) threshold at which it is
# flagged and is never overwritten afterwards. Dimensions never flagged keep 1.
weight_vectors = [torch.ones(latent_dim) for _ in range(prop_num)]
updated_indices_list = [set() for _ in range(prop_num)]

thold_list = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3] # if odds < 0.3, don't care

for thold in thold_list:
    # presumably one mask row per property — confirm against VAE.mask_inference
    mask_output = model.mask_inference(thold)
    indices_sets = [set(torch.nonzero(row == 1, as_tuple=True)[0].tolist()) for row in mask_output]

    for i, indices_set in enumerate(indices_sets):
        # Only dimensions not already assigned at a higher threshold.
        new_indices = indices_set - updated_indices_list[i]
        for idx in new_indices:
            weight_vectors[i][idx] = 3 + thold
        updated_indices_list[i].update(new_indices)

                                                smiles      bcl2     bclxl  \
267  CN(CC(=O)NC(C)(C)C)C(=O)C1CCN(c2ccc(S(=O)(=O)N...  5.038860  4.875372   
987  O=C(NCCc1ccccc1)C1CN(C(=O)CN(C2CCCCC2)S(=O)(=O...  5.755924  5.237386   
874  Cc1sc2ncnc(N3CCC(C(=O)N4CCN(S(=O)(=O)c5ccc(C(C...  5.225762  4.832924   
153  COc1cc(C=c2sc(=CC(=O)C(C)(C)C)n(CC(=O)N3CCCC(C...  4.366474  4.462648   
802  CCN(CC)S(=O)(=O)c1cc(C(=O)Nc2ccc3c(c2)N(S(=O)(...  5.821074  6.577568   
..                                                 ...       ...       ...   
560  COc1ccccc1N(CC(=O)N(CCc1ccccc1)C(C)C(=O)NC(C)C...  5.398428  5.515744   
361  O=C(c1ccc(N2CCCC2)c([N+](=O)[O-])c1)N1CCN(S(=O...  5.626795  4.916760   
190  CCN(CC)S(=O)(=O)c1ccc(N2CCN(CC(O)COc3ccccc3F)C...  6.846522  7.228673   
66   COC1=C(OC)C(=O)C(=C2C(C)=C(C(=O)O)NC(c3ccc4c(n...  6.228081  6.555745   
88   CC(=O)CC(c1ccccc1)c1c(O)c2ccc(OC3OC(C(=O)O)C(O...  4.410439  3.976376   

         bclw  
267  5.214093  
987  6.090002  
874  5.098664  

In [39]:

# Log save path
model_name = "Model_A_Optim"
save_optim_path = f'./model/{model_name}/'

parent_num = 10
stop_condi = 3000

t_weight = 15              # scale applied to the reported target loss
prop_weight = [1, 1, 1]    # per-property loss weights

MSE = torch.nn.MSELoss(reduction="sum")
GaussianNLL = torch.nn.GaussianNLLLoss(reduction="sum")

optim_step = 50            # number of logged steps per seed
optim_iter = 20            # optimizer updates between two logged steps
decode_max_len = 154       # maximum number of tokens produced when decoding z


def _property_values(y):
    """Return the first element of every tensor in ``y`` as a NumPy scalar."""
    return [tensor_item.detach().cpu().numpy().flatten()[0] for tensor_item in y]


# Define Model
# Built and loaded ONCE: every parameter is frozen below and only `z_param` is
# handed to the optimizer, so re-reading the checkpoint for each seed molecule
# would only repeat identical disk I/O.
from VAE import *
model = VAE(voca_dim=len(word2index),
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            en_num_layers=en_n_l,
            de_num_layers=de_n_l,
            prop_num=prop_num,
            run_predictor=True,
            value_range=None).to(device)
pt_path = './model/finetuned_Model_A.pt'
# map_location lets a GPU-saved checkpoint load on the current `device`.
state_dict = torch.load(pt_path, map_location=device)['model_state_dict']
model.load_state_dict(state_dict)
model.eval()
for param in model.parameters():
    param.requires_grad = False

sos_tensor = torch.tensor([word2index['<sos>']]).to(device)

# NOTE: y_list accumulates across ALL seeds (it is not reset per seed).
y_list = []
for p_i, batch in enumerate(data_loader):
    sorted_source, sorted_target, sorted_lengths, max_len, sorted_props, sorted_origin_indexs = batch
    batch_size = len(sorted_source)
    src = sorted_source.to(device)
    trg = sorted_target.to(device)
    x, y, _, sample_z, _ = model.inference(src,
                                           sorted_lengths,
                                           max_len)

    # \eta parameter: the latent vector being optimized, initialised from the
    # latent sample of the seed molecule (detached from the encoder graph).
    z_param = torch.nn.Parameter(sample_z.detach().clone(), requires_grad=True)
    optimizer = torch.optim.Adam([z_param], lr=0.01)

    smi_src, sf = tensor2smiles_v2(x, word2index, index2word)
    smi = smi_src[0]

    # Record the values returned for the un-optimized latent vector.
    with torch.no_grad():
        _, y = model.decode_z(z_param, max_len=decode_max_len, sos_tensor=sos_tensor, return_y=True)
    y_values = _property_values(y)
    y_list.append(y_values)

    # Per-property latent weights, converted once per seed instead of once per
    # property per optimizer step.
    weights_on_device = [w.to(dtype=z_param.dtype, device=z_param.device) for w in weight_vectors]

    for gen_try in range(optim_step * optim_iter):
        is_log_step = (gen_try % optim_iter == 0)
        individual_losses = []
        ori_t_losses = []
        pred_list = []

        if is_log_step:
            # The decoded molecule is only displayed, so decode just on logging
            # steps (from z BEFORE this step's update) and without building an
            # autograd graph.
            with torch.no_grad():
                x = model.decode_z(z_param, max_len=decode_max_len, sos_tensor=sos_tensor)
            smi_src, sf = tensor2smiles_sampling(x, word2index, index2word)
            smi = normalize_SMILES(smi_src[0])

        optimizer.zero_grad()
        total_grad_z = torch.zeros_like(z_param)

        for idx, (predictor, target, is_up, p_max, p_min) in enumerate(zip(model.predictors, target_tensor, target_up, target_max, target_min)):
            pred_y = predictor(z_param)
            pred_list.append(pred_y)
            t_loss = MSE(pred_y.squeeze(0).view(target.shape), target)
            weighted_loss = prop_weight[idx] * t_loss
            individual_losses.append(weighted_loss)
            ori_t_losses.append(t_loss)

            # Gradient of this property's loss w.r.t. z, rescaled per latent
            # dimension by that property's weight vector before accumulation.
            grad_z = torch.autograd.grad(weighted_loss, z_param, retain_graph=True)[0]
            total_grad_z += grad_z * weights_on_device[idx].unsqueeze(0)

        target_loss = torch.mean(torch.stack(individual_losses))
        loss = (t_weight * target_loss)
        # The gradient is assigned manually (no loss.backward()) because each
        # property's gradient is weighted differently per latent dimension.
        z_param.grad = total_grad_z
        optimizer.step()

        if is_log_step:
            with torch.no_grad():
                _, y = model.decode_z(z_param, max_len=decode_max_len, sos_tensor=sos_tensor, return_y=True)
            y_values = _property_values(y)
            y_list.append(y_values)

            clear_output(wait=True)
            print(f'= =step : {gen_try // optim_iter}, seed {p_i} Optimize Results === = = =')
            print("Current Step Z mean:", z_param.data.mean().item())
            print("#" * 20)
            print("smi : ", smi)
            print("Target Loss : ", t_weight * target_loss.item())

            for i in range(prop_num):
                print('#' * 20)
                print(f'# Property Info {i}  - {prop_name[i]} -  # ')
                print(f'Property Range | {target_min[i]}(min) ~ {target_max[i]}(max) |')
                print(f'Target : {target_list[i]}')
                print(f'Predict : {pred_list[i].item()}')

            print('= = = === = = = === = = = === = = = === = = = === = = =')
    # Transpose y_list (records x values -> values x records). Refreshed after
    # every seed so it stays usable if the run is interrupted.
    mask_y_list = [list(column) for column in zip(*y_list)]

print("End of Optimization")

= =step : 1, seed 0 Optimize Results === = = =
Current Step Z mean: 0.04734654352068901
####################
smi :  CN(CC(=O)NC(C)(C)C)C(=O)C1CCN(c2ccc(S(=O)(=O)N3CCCCC3)cc2[N+](=O)[O-])CC1
Target Loss :  281.3097381591797
####################
# Property Info 0  - bcl2 -  # 
Property Range | 2.15(min) ~ 10.34(max) |
Target : 10.34
Predict : 5.457488536834717
####################
# Property Info 1  - bclxl -  # 
Property Range | 1.91(min) ~ 9.85(max) |
Target : 9.85
Predict : 5.508421421051025
####################
# Property Info 2  - bclw -  # 
Property Range | 2.23(min) ~ 8.68(max) |
Target : 8.68
Predict : 4.995748043060303
= = = === = = = === = = = === = = = === = = = === = = =


KeyboardInterrupt: 