In [24]:
import pickle
import rdkit 
from rdkit.Chem import MolFromSmiles as get_mol
import torch 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import math
import model 
import utils
import numpy as np 
from model import Model , Encoder
from utils import parallel_f, get_dic, MyDataset

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def tokenize(smi):
    """Turn a SMILES string into a list of token ids.

    Every character of ``smi`` is looked up in the notebook-level ``smi_dic``.
    The result is framed by id 0 at the front and id 1 at the back. A character
    that ``smi_dic`` does not contain propagates the lookup error unchanged.
    """
    ids = [0]
    ids.extend(smi_dic[ch] for ch in smi)
    ids.append(1)
    return ids

def pad(smi):
    """Return ``smi`` extended on the right with id 2 up to ``max_len`` items.

    ``max_len`` is read from the notebook globals at call time. A sequence that
    is already ``max_len`` long (or longer) comes back as an unpadded copy.
    """
    missing = max_len - len(smi)
    filler = [2] * missing  # a non-positive count yields an empty list
    return smi + filler

def frange_cycle_cosine(start, stop, n_epoch, n_cycle=4, ratio=0.5):
    """Build a cyclical cosine schedule of length ``n_epoch``.

    The schedule is split into ``n_cycle`` periods of ``n_epoch / n_cycle``
    steps. Inside each period a variable ``v`` walks from ``start`` towards
    ``stop`` in increments of ``(stop - start) / (period * ratio)`` and the
    slot is set to ``0.5 - 0.5 * cos(v * pi)``. Slots that are never written
    keep their initial value of 1.0.

    Args:
        start: First value of ``v`` in every cycle.
        stop: Upper bound of ``v`` (inclusive) in every cycle.
        n_epoch: Number of entries in the returned schedule.
        n_cycle: Number of cycles the schedule is divided into.
        ratio: Fraction of each period spent ramping from ``start`` to ``stop``.

    Returns:
        A 1-D ``numpy`` array with ``n_epoch`` float entries.

    Raises:
        ZeroDivisionError: If ``n_cycle`` is 0 or ``period * ratio`` is 0.
    """
    L = np.ones(n_epoch)
    period = n_epoch / n_cycle
    step = (stop - start) / (period * ratio)  # increment of v per slot

    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop:
            idx = int(i + c * period)
            # The ramp of the last cycle can reach past the end of the array
            # (e.g. ratio=1.0 writes period+1 slots); stop instead of raising.
            if idx >= n_epoch:
                break
            L[idx] = 0.5 - .5 * math.cos(v * math.pi)
            v += step
            i += 1
    return L
def loss_fn(pred, tgt, mu, sigma, beta):
    """Compute the summed reconstruction loss plus a beta-weighted KL term.

    Args:
        pred: Model scores whose last dimension indexes the classes; all
            leading dimensions are flattened together.
        tgt: Integer token ids with the sequence on dimension 1. The first
            position is dropped so the labels line up with ``pred``.
        mu: Latent means.
        sigma: Latent spread parameter. It enters the KL term as ``sigma`` and
            ``exp(sigma)``, i.e. it is used as a log-variance.
        beta: Weight applied to the KL term.

    Returns:
        A scalar tensor: cross-entropy summed over all positions plus
        ``beta`` times the KL term summed over all latent elements.
    """
    # Take the class count from the prediction itself rather than from a
    # notebook global, so the function does not depend on cell execution order.
    vocab_size = pred.size(-1)
    logits = pred.reshape(-1, vocab_size)
    labels = tgt[:, 1:].reshape(-1)
    # NOTE(review): no ignore_index is set, so every label position
    # (padding included) contributes to the reconstruction term.
    reconstruction_loss = F.cross_entropy(logits, labels, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
    return reconstruction_loss + kl_loss * beta

In [25]:
# Read the pickled list of training SMILES.
# NOTE(review): pickle.load can execute arbitrary code; only point this at a trusted file.
with open('../data/chembl24_canon_train.pickle','rb') as file:
    smi_list = pickle.load(file)

# Keep only strings of at most 40 characters.
smi_list = list(filter(lambda s: len(s) <= 40, smi_list))
print(f'Number of data: {len(smi_list)}')

# Vocabulary built from the filtered strings, plus its inverse for decoding.
smi_dic = get_dic(smi_list)
inv_dic = dict((idx, tok) for tok, idx in smi_dic.items())

# Tokenize first, then pad every sequence to the longest tokenized length.
token_list = parallel_f(tokenize, smi_list)
max_len = max(map(len, token_list))
token_list = parallel_f(pad, token_list)

Number of data: 411544


In [26]:
# Hyperparameters for the data pipeline, the model and the optimizer.
BATCH_SIZE = 128
DIM_MODEL = 512
NUM_HEAD = 4
DROPOUT = 0.5
NORM = True
LEARNING_RATE = 0.003
SPLIT_SEED = 42  # fixes which samples land in the validation set

dataset = MyDataset(token_list)

# Seed the split so the train/validation membership is the same on every run;
# otherwise validation losses from different runs are not comparable.
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
train_set, val_set = random_split(dataset, [0.9, 0.1], generator=split_rng)

train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True)
# Validation order does not affect the summed loss, so no shuffling is needed.
val_loader = DataLoader(val_set, batch_size = BATCH_SIZE, shuffle = False)

# NOTE(review): this rebinds the name `model`, which the import cell also uses
# for the `model` module; re-running the import cell afterwards hides this
# instance again. Consider renaming one of the two.
model = Model(dim_model=DIM_MODEL,
              num_head=NUM_HEAD,
              dropout=DROPOUT,
              smi_dic=smi_dic,
              norm = NORM).to(device)

optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)


In [27]:
NUM_EPOCHS = 100
NUM_GENERATE = 1

# NOTE(review): this schedule is built but never read in this cell; `beta`
# below is set from the epoch number instead.
beta_np_cyc = frange_cycle_cosine(0.0, 1.0, len(train_loader), 4)

# Silence RDKit's parser messages while generated strings are being checked.
rdkit.rdBase.DisableLog('rdApp.*')


def _sample_smiles(num_samples):
    """Greedily decode ``num_samples`` strings from random latent tensors.

    Starts every sequence from token id 0 and appends the highest-scoring
    token at the last position for ``max_len - 1`` steps. The literal marker
    strings "<START>", "<PAD>" and "<END>" are stripped from the decoded text.
    Assumes ``model.inference(z, target)`` returns scores shaped
    (batch, position, vocabulary) - TODO confirm against the model code.
    """
    model.eval()
    z = torch.randn(num_samples, max_len, DIM_MODEL // 4).to(device)
    target = torch.zeros(num_samples, 1, dtype = torch.long).to(device)

    # Sampling never needs gradients; skip building the autograd graph.
    with torch.no_grad():
        for _step in range(max_len - 1):
            out = model.inference(z, target)
            _scores, idx = torch.topk(out, 1, dim = -1)
            idx = idx[:, -1, :]
            target = torch.cat([target, idx], dim = 1)

    # Decode row by row so any NUM_GENERATE works, not only 1.
    decoded = []
    for row in target.tolist():
        text = ''.join(inv_dic[tok] for tok in row)
        text = text.replace("<START>", "").replace("<PAD>", "").replace("<END>","")
        decoded.append(text)
    return decoded


def _validation_loss(beta):
    """Return the loss summed over all batches of ``val_loader`` (no gradients)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            pred, mu, sigma = model(batch, batch[:, :-1])
            total += loss_fn(pred, batch, mu, sigma, beta).item()
    return total


for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = 0
    # The KL term is switched off for the first 19 epochs, then weighted lightly.
    beta = 0 if epoch < 20 else 0.00001

    for batch in train_loader:
        # Sampling below switches to eval mode, so re-enable training each step.
        model.train()
        batch = batch.to(device)
        pred, mu, sigma = model(batch, batch[:, :-1])

        loss = loss_fn(pred, batch, mu, sigma, beta)
        train_loss += loss.item()
        loss.backward()
        optim.step()
        optim.zero_grad()

        # Print freshly generated strings after every optimizer step.
        for smiles in _sample_smiles(NUM_GENERATE):
            valid = "Valid" if get_mol(smiles) else "Not"
            print(f'{smiles} - {valid}')
    print('\n\n\n\n\n')

    # Previously val_loss stayed at 0 and the log line reported a constant.
    val_loss = _validation_loss(beta)
    print(f'epoch : {epoch}, train loss : {train_loss / len(train_loader)} val loss : {val_loss / len(val_loader)}')


CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC - Valid
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC - Valid
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC - Valid
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC - Valid
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC - Valid
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC - Valid
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC - Valid
CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC - Valid
C(=O)cccccccccccccccccccccccccccccccccccc - Not
C(=O)cccccccccccccccccccccccccccccccccccc - Not
C(C(CC(CC(=O)cccccccccccccccccccccccccccc - Not
C(C(CC(CC(C(CC(C(CCC(C(C(CCC(C(C(C(CCC(C( - Not
C(C(CC(=O)ccccccccccccccccccccccccccccccc - Not
C(=O)cccccccccccccccccccccccccccccccccccc - Not
C(=O)c1 - Not
C(=O)c1 - Not
C(=O)cccccccccccccccccccccccccccccccccccc - Not
C(=O)cccccccccccccccccccccccccccccccccccc - Not
C(=O)cccccccccccccccccccccccccccccccccccc - Not
C(=O)c1 - Not
C(=O)c1 - Not
C(=O)c1 - Not
C(=O)cccccccccccccccccccccccccccccccccccc - Not
C(=O)cccccccccccccccccccccccccccccccccccc - Not
C(

In [20]:
CCCCC(C)CC(=O)C(=O)CCc2c(=O)c1c2c(C#N)cc1 - Valid
OCCCC(=O)CC1C(=O)CC1=O - Valid
OCCCC(C)CC(=S)CC(=S)CC(=O)NCCc1c(=S)c1CCC - Valid


In [22]:
# Scratch check: run a hand-typed string through the RDKit SMILES parser.
# NOTE(review): the literal includes the double-quote characters themselves,
# so presumably this does not parse - confirm whether the quotes are intended.
mol = get_mol('"((O)(O)(O)OC2)cc1O)NCC1c2"')

In [23]:
# Bare last expression: lets the notebook display the value bound to `mol`.
mol