In [58]:
import os
from tqdm import tqdm 
from rdkit.Chem import QED
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import pandas as pd
import time
from IPython.display import clear_output
import itertools
import importlib
import random
from utils import *

In [60]:
# --- Vocabulary -------------------------------------------------------------
# Read the SELFIES token list, one token per line.
with open("./data/zinc15_selfies_tokens.txt", "r") as file:
    token_in_dataset = [line.strip() for line in file]

print("Loaded tokens:", token_in_dataset)

# The four special tokens take indices 0-3 (padding is index 0); dataset
# tokens are numbered consecutively after them, in file order.
word2index = {"<pad>": 0, "<unk>": 1, "<sos>": 2, "<eos>": 3}
index2word = {0: "<pad>", 1: "<unk>", 2: "<sos>", 3: "<eos>"}

start_index = max(index2word.keys()) + 1

for i, token in enumerate(token_in_dataset, start=start_index):
    word2index[token] = i
    index2word[i] = token

print("word2index:", word2index)
print("index2word:", index2word)

# --- Data -------------------------------------------------------------------
data_path = './data/zinc15_sample.csv'
df = pd.read_csv(data_path)
smiles = df['smiles'].to_numpy()
selfies = df['selfies'].to_numpy()

# --- Device -----------------------------------------------------------------
GPU_NUM = 0
device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
# torch.cuda.set_device() rejects non-CUDA devices, so only call it when the
# CUDA branch above was actually taken; otherwise the CPU fallback would crash.
if device.type == 'cuda':
    torch.cuda.set_device(device)

# --- Dataset / loader -------------------------------------------------------
# selfiesDataset and collate_fn_pre come from the project's `utils` module.
dataset = selfiesDataset(selfies, None, word2index, device, num_samples=None)
data_loader = DataLoader(dataset,
                     batch_size=128,
                     shuffle=True,
                     collate_fn=lambda x: collate_fn_pre(x, word2index, dataset.pattern, device))
data_len = len(dataset)
print(data_len)

Loaded tokens: ['[=Branch1]', '[#Branch1]', '[=Branch2]', '[#Branch2]', '[Branch1]', '[Branch2]', '[=Ring1]', '[=Ring2]', '[Ring1]', '[Ring2]', '[NH1+1]', '[CH1-1]', '[=N+1]', '[=N-1]', '[=S+1]', '[=PH1]', '[#N+1]', '[N+1]', '[O-1]', '[NH1]', '[CH0]', '[N-1]', '[OH0]', '[PH1]', '[C-1]', '[S+1]', '[CH1]', '[NH0]', '[PH0]', '[SH1]', '[=C]', '[=O]', '[=N]', '[Cl]', '[#C]', '[Br]', '[=S]', '[=P]', '[#N]', '[C]', '[N]', '[O]', '[S]', '[P]', '[F]']
word2index: {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3, '[=Branch1]': 4, '[#Branch1]': 5, '[=Branch2]': 6, '[#Branch2]': 7, '[Branch1]': 8, '[Branch2]': 9, '[=Ring1]': 10, '[=Ring2]': 11, '[Ring1]': 12, '[Ring2]': 13, '[NH1+1]': 14, '[CH1-1]': 15, '[=N+1]': 16, '[=N-1]': 17, '[=S+1]': 18, '[=PH1]': 19, '[#N+1]': 20, '[N+1]': 21, '[O-1]': 22, '[NH1]': 23, '[CH0]': 24, '[N-1]': 25, '[OH0]': 26, '[PH1]': 27, '[C-1]': 28, '[S+1]': 29, '[CH1]': 30, '[NH0]': 31, '[PH0]': 32, '[SH1]': 33, '[=C]': 34, '[=O]': 35, '[=N]': 36, '[Cl]': 37, '[#C]': 38, '

In [61]:
class KLWeightScheduler:
    """Sigmoid (logistic) annealing schedule for a KL-divergence weight.

    The weight follows

        w(e) = initial + (final - initial) / (1 + exp(-k * (e - midpoint)))

    with ``k = steepness / total_epochs``. ``initial_weight`` and
    ``final_weight`` are the asymptotes of the curve, so ``w(0)`` is only
    approximately ``initial_weight`` (how close depends on ``midpoint`` and
    ``steepness``).

    Args:
        initial_weight: Lower asymptote of the schedule.
        final_weight: Upper asymptote of the schedule.
        total_epochs: Epoch count used to scale the growth rate; must be
            non-zero.
        midpoint: Epoch at which the weight is exactly halfway between the
            two asymptotes. Defaults to ``total_epochs / 2``.
        steepness: Numerator of the growth rate. Larger values give a sharper
            transition around ``midpoint``. The default (10) reproduces the
            previously hard-coded constant.
    """

    def __init__(self, initial_weight=0.0, final_weight=10.0, total_epochs=100, midpoint=None,
                 steepness=10.0):
        self.initial_weight = initial_weight
        self.final_weight = final_weight
        self.total_epochs = total_epochs
        self.midpoint = midpoint if midpoint is not None else total_epochs / 2
        self.steepness = steepness

    def get_weight(self, current_epoch):
        """Return the KL weight for ``current_epoch`` (a NumPy float)."""
        # Growth rate is normalised by the run length so the curve keeps the
        # same shape regardless of how many epochs are trained.
        growth_rate = self.steepness / self.total_epochs

        span = self.final_weight - self.initial_weight
        weight = self.initial_weight + span / (1 + np.exp(-growth_rate * (current_epoch - self.midpoint)))
        return weight

In [62]:
from tqdm.notebook import tqdm
from tqdm.auto import trange

num_epoch = 100

model_name = 'pretrained_vae_zinc15'
save_path = f'./model/{model_name}/'
# exist_ok=True keeps this cell safe to re-run once the directory exists.
os.makedirs(save_path, exist_ok=True)

# KL weight is annealed along a sigmoid towards 1.0, centred on epoch 15.
scheduler = KLWeightScheduler(initial_weight = 0.0,
                                   final_weight=1.0,
                                   total_epochs=num_epoch,
                                   midpoint=15)
# Per-epoch KL weights. The original referenced an undefined name
# (`kl_w_scheduler`), which raised NameError on a fresh kernel.
weights = [scheduler.get_weight(epoch) for epoch in range(num_epoch)]

embed_dim = 256                    # Embedding vector dim
hidden_dim = 512                   # Hidden dim passed to VAE (presumably the GRU hidden size; confirm in VAE.py)
latent_dim = 256                   # Latent vector dim
en_n_l = 3                         # Encoder GRU number of layers
de_n_l = 2                         # Decoder GRU number of layers
base_batch_size = 128              # Batch size of training data (the DataLoader above hard-codes the same value)
learning_rate = 1e-4

# NOTE(review): `prop_num`, `nn` and `optim` are not defined in the visible
# cells; presumably they are re-exported by `from utils import *` or
# `from VAE import *`. Verify, or import/define them explicitly.
from VAE import *
model = VAE(voca_dim=len(word2index),
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
            en_num_layers=en_n_l,
            de_num_layers=de_n_l,
            prop_num = prop_num,
            run_predictor = False, #because pre-training
            value_range = None).to(device)

# Per-token loss (no reduction); targets equal to index 0 (<pad>) are ignored.
NLL = nn.NLLLoss(reduction='none',
                 ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Per-iteration training history; flushed to CSV at the end of every epoch.
epoch_log = []
iterate_log = []
loss_log = []
rec_loss_log = []
kl_loss_log = []
ori_kl_loss_log = []

total_loss = 0
is_print = True

#Teacher Forcing Annealing: linear decay from initial to final ratio over num_epoch.
initial_tf_ratio = 1.0
final_tf_ratio = 0.0
tf_decay_rate = (initial_tf_ratio - final_tf_ratio) / num_epoch

start_time = time.time()
non_zero = 1e-8  # added inside log() for numerical safety, e.g. log(std + non_zero)
i = -1           # global iteration counter; incremented before use, so the first batch is 0

start_epoch = 0

for epoch in trange(start_epoch, num_epoch, desc="Epochs", leave=False):
    kl_w_rate = scheduler.get_weight(current_epoch = epoch)
    tf_ratio = max(initial_tf_ratio - (epoch * tf_decay_rate), final_tf_ratio)

    for batch in tqdm(data_loader, desc=f"Dataset({data_len})/batch({base_batch_size})", leave=False):
        model.train()
        i += 1
        one_time = time.time()
        sorted_source, sorted_target, sorted_lengths, max_len, sorted_origin_indexs = batch
        batch_size = len(sorted_source)
        src = sorted_source.to(device)  # e.g. ['<sos>', 'O', '=', 'C']
        trg = sorted_target.to(device)  # e.g. ['O', '=', 'C', '<eos>']

        x, y, z, sample_z = model.forward(src, trg, sorted_lengths, max_len, tf_ratio = tf_ratio)
        x_ = x.view(-1, x.shape[2]).contiguous()
        x_label = trg.view(-1)

        # Reconstruction loss: per-token NLL, summed over tokens, averaged over the batch.
        rec_loss = NLL(x_, x_label)
        rec_loss = rec_loss.view(x.shape[0], x.shape[1])
        rec_loss = torch.sum(rec_loss, dim=-1)  # [batch, token_num] -> [batch]
        rec_loss = torch.sum(rec_loss) / batch_size

        # KLD loss.
        # NOTE(review): assuming z = (mean, std) of the posterior, this is the
        # Gaussian KL to N(0, 1) without the usual 0.5 factor and averaged
        # (not summed) over dim=1. Left unchanged because it only rescales the
        # KL term relative to kl_w_rate; confirm the scaling is intended.
        mu, std = z
        kl_loss = torch.mean(mu ** 2 + std ** 2 - 2 * torch.log(std + non_zero) - 1, dim=1)
        ori_kl_loss = torch.mean(kl_loss).item()   # unweighted KL, for logging only
        kl_loss = kl_w_rate * torch.mean(kl_loss)
        loss = rec_loss + kl_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Log
        epoch_log.append(epoch)
        iterate_log.append(i)
        loss_log.append(loss.item())
        rec_loss_log.append(rec_loss.item())
        kl_loss_log.append(kl_loss.item())
        ori_kl_loss_log.append(ori_kl_loss)

        if i % 100 == 0:
            if is_print:
                print(f'=========================epoch : {epoch}==========================')
                print(f'# Model Name : {model_name}')
                print("Iteration ", i, ", Total loss ", loss.item(), "\nKL loss ", kl_loss.item(),
                      ", Rec loss ", rec_loss.item(), sep="")
                print("Origin_KL_loss : ", ori_kl_loss)
                print("Time: ", convert_time(time.time() - one_time))
                print("Teacher Forcing Ratio : ", tf_ratio)
                print("KLD weight : ", kl_w_rate)
                idx = print_token2sf(sorted_target, x, word2index, index2word, batch_size, num=1)
                idx = idx[0].item()
                print(" idx : ",idx)
                print("-----")
                # Greedy (argmax) decode of the sampled sequence for a quick visual check.
                _, tmp_a = torch.max(x[idx], dim=-1)
                tmp_a = tmp_a.reshape(-1).cpu().detach().numpy()
                in_sentence = [index2word[tok_id] for tok_id in tmp_a]
                in_smi = ''.join(in_sentence)
                print(in_smi)
                print("===============================================================")

    # ---- End of epoch: persist state once per epoch -------------------------
    # The original ran these writes inside the batch loop, rewriting the
    # checkpoint and the whole, growing log CSV on every single iteration.
    # The files on disk at the end of each epoch are the same as before; they
    # are just no longer refreshed mid-epoch.
    checkpoint = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  }
    torch.save(checkpoint, f'{save_path}{model_name}_main.pt')

    # Named train_log_df so it does not overwrite the dataset frame `df`.
    train_log_df = pd.DataFrame({'epoch': epoch_log,
                                 'i': iterate_log,
                                 'loss': loss_log,
                                 'rec_loss': rec_loss_log,
                                 'kl_loss': kl_loss_log,
                                 'ori_kl_loss' : ori_kl_loss_log,})
    train_log_df.to_csv(f'{save_path}Log_{model_name}.csv',
                        index=False)

    # Extra snapshot every 5th epoch (file name uses the 0-based epoch index).
    if (epoch + 1) % 5 == 0:
        torch.save(checkpoint, f'{save_path}{model_name}_epoch_{epoch}.pt')

# `i` is the 0-based index of the last iteration, so i + 1 iterations ran.
num_iterations = i + 1
mean_iteration_loss = total_loss / num_iterations if num_iterations > 0 else float('nan')
print("===============================================================")
print("Completed Epoch", epoch, ", Total loss Mean: ", mean_iteration_loss, ", Time: ",
      convert_time(time.time() - start_time))
print("===============================================================")