In [None]:
#!/usr/bin/env python
import os,sys
import numpy as np

import torch
from torch.nn.parameter import Parameter

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

def accu(pred,val,batch_l):
    """Token-level and sequence-level accuracy over variable-length rows.

    Only the first ``batch_l[i]`` positions of row ``i`` are compared.

    Args:
        pred: tensor of predicted token ids, indexed as ``pred[row, pos]``.
        val: tensor of reference token ids with the same indexing as ``pred``.
        batch_l: per-row valid lengths; ``batch_l.shape[0]`` is the row count.

    Returns:
        ``(acc, acc2)``: the fraction of compared positions that match, and the
        fraction of rows whose compared positions all match.
    """
    # Move both tensors to host memory once, rather than slice-by-slice.
    pred_np = pred.cpu().data.numpy()
    val_np = val.cpu().data.numpy()
    n_rows = batch_l.shape[0]

    n_match = 0
    n_compared = 0
    n_exact = 0
    for row in range(n_rows):
        length = int(batch_l[row])
        hits = pred_np[row, :length] == val_np[row, :length]
        n_match += hits.sum()
        n_compared += length
        n_exact += hits.all()

    return n_match / float(n_compared), n_exact / n_rows

def vec_to_char(out_num, chars=None):
    """Map a sequence of token indices to the concatenated symbol string.

    Args:
        out_num: iterable of integer indices into the vocabulary.
        chars: optional vocabulary (index -> symbol string). When omitted, the
            module-level ``char_list`` is used, as before. NOTE(review):
            ``char_list`` is not defined anywhere in the visible source; it is
            assumed to be injected into the module's globals elsewhere.

    Returns:
        The symbols for ``out_num`` joined into one string.
    """
    # The conditional expression only touches the global when no vocabulary is passed.
    vocab = char_list if chars is None else chars
    # join is linear; repeated ``+=`` on str is quadratic in the general case.
    return "".join(vocab[cha] for cha in out_num)

def cal_prec_rec(Ypred,Ydata,conf):
    """Precision and recall of thresholded predictions.

    Args:
        Ypred: tensor of scores; entries ``> conf`` count as predicted positives.
        Ydata: tensor of the same shape holding the targets (presumably 0/1).
        conf: decision threshold.

    Returns:
        ``(precision, recall)``. A tiny epsilon is added to the numerator and
        denominator of both ratios, so the vacuous cases (no predicted
        positives / no actual positives) evaluate to 1.0 instead of dividing
        by zero.
    """
    small=0.0000000001
    Ypred0=Ypred.cpu().data.numpy()
    Ydata0=Ydata.cpu().data.numpy()
    Ypred00=Ypred0>conf
    TP=(Ypred00*Ydata0).sum()
    A=Ydata0.sum()
    P=Ypred00.sum()
    precision=(TP+small)/(P+small)
    # Smooth the recall denominator the same way as precision's; previously an
    # all-negative target (A == 0) produced a division by zero (inf/nan).
    recall=(TP+small)/(A+small)

    return precision, recall

class Encoder(nn.Module):
    """LSTM encoder mapping token-id sequences to L2-normalised latent codes.

    ``para`` must provide ``'Nseq'``, ``'Nfea'`` (vocabulary size, also used as
    the embedding width), ``'hidden_dim'`` (LSTM/latent width) and
    ``'NLSTM_layer'``. ``bias`` is accepted for interface compatibility but is
    not used.
    """

    def __init__(self,para,bias=True):
        super(Encoder,self).__init__()

        self.Nseq=para['Nseq']
        self.Nfea=para['Nfea']

        self.hidden_dim=para['hidden_dim']
        self.NLSTM_layer=para['NLSTM_layer']

        # Token ids -> Nfea-wide embedding vectors.
        self.embedd = nn.Embedding(self.Nfea, self.Nfea)
        self.encoder_rnn = nn.LSTM(input_size=self.Nfea,hidden_size=self.hidden_dim,
                num_layers=self.NLSTM_layer,bias=True,
                batch_first=True,bidirectional=False)

        # Orthogonal init for weight matrices, standard normal for bias vectors.
        for param in self.encoder_rnn.parameters():
            if len(param.shape)>=2:
                nn.init.orthogonal_(param.data)
            else:
                nn.init.normal_(param.data)

    def forward(self,X0,L0):
        """Encode a batch of sequences.

        Args:
            X0: (batch, seq) integer token ids.
            L0: (batch,) LongTensor of valid lengths; presumably 1 <= L0 <= seq.

        Returns:
            (batch, hidden_dim) tensor: the LSTM output at each sequence's last
            valid step, L2-normalised along the feature dimension.
        """
        batch_size=X0.shape[0]
        device=X0.device
        state_shape=(self.NLSTM_layer, batch_size, self.hidden_dim)
        enc_h0 = torch.zeros(state_shape, device=device)
        enc_c0 = torch.zeros(state_shape, device=device)

        X = self.embedd(X0)
        out,_=self.encoder_rnn(X,(enc_h0,enc_c0))
        # Gather index of shape (batch, 1, hidden_dim) selecting step L0-1 per row.
        last_step_index_list = (L0 - 1).view(-1, 1).expand(out.size(0), out.size(2)).unsqueeze(1)
        # squeeze(1), not squeeze(): a bare squeeze() also dropped the batch
        # dimension for a batch of one, making normalize(dim=1) fail.
        Z=out.gather(1,last_step_index_list).squeeze(1)
        Z=F.normalize(Z,p=2,dim=1)

        return Z

class Decoder(nn.Module):
    """LSTM decoder producing per-step token logits from a latent code.

    At every step the LSTM input is the latent code concatenated with the
    embedding of the current token. ``para`` must provide ``'Nseq'`` (number
    of steps run by :meth:`decoding`), ``'Nfea'`` (vocabulary size, also the
    embedding width), ``'hidden_dim'`` (latent/LSTM width) and
    ``'NLSTM_layer'``. ``bias`` is accepted for interface compatibility but is
    not used.
    """

    def __init__(self,para,bias=True):
        super(Decoder,self).__init__()

        self.Nseq=para['Nseq']
        self.Nfea=para['Nfea']

        self.hidden_dim=para['hidden_dim']
        self.NLSTM_layer=para['NLSTM_layer']

        self.embedd = nn.Embedding(self.Nfea, self.Nfea)

        # Input = [latent code | token embedding].
        self.decoder_rnn = nn.LSTM(input_size=self.Nfea+self.hidden_dim,
            hidden_size=self.hidden_dim, num_layers=self.NLSTM_layer,
            bias=True, batch_first=True,bidirectional=False)

        # Orthogonal init for weight matrices, standard normal for bias vectors.
        for param in self.decoder_rnn.parameters():
            if len(param.shape)>=2:
                nn.init.orthogonal_(param.data)
            else:
                nn.init.normal_(param.data)

        # Projects LSTM outputs to Nfea logits.
        self.decoder_fc1=nn.Linear(self.hidden_dim,self.Nfea)
        nn.init.xavier_normal_(self.decoder_fc1.weight.data)
        nn.init.normal_(self.decoder_fc1.bias.data)

    def _zero_state(self, batch_size, device):
        """Return zero ``(h0, c0)``, each of shape (NLSTM_layer, batch, hidden_dim)."""
        shape = (self.NLSTM_layer, batch_size, self.hidden_dim)
        return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))

    def forward(self, Z, X0, L0):
        """Teacher-forced decoding.

        Args:
            Z: (batch, hidden_dim) latent codes.
            X0: (batch, seq) input token ids fed to the LSTM.
            L0: not used here; kept for interface compatibility.

        Returns:
            (batch, seq, Nfea) logits.
        """
        X = self.embedd(X0)
        # Broadcast the code along the actual input length; it used to be the
        # fixed self.Nseq, which broke inputs of any other length.
        Zm=Z.view(-1,1,self.hidden_dim).expand(-1,X.size(1),self.hidden_dim)
        ZX=torch.cat((Zm,X),2)

        dec_out,_=self.decoder_rnn(ZX,self._zero_state(Z.shape[0], Z.device))
        return self.decoder_fc1(dec_out)

    def decoding(self, Z):
        """Greedy free-running decoding for ``Nseq`` steps.

        Every sequence starts with token id ``Nfea - 2`` (presumably the
        start-of-sequence symbol). At each step the argmax token is appended
        and fed back as the next input.

        Args:
            Z: (batch, hidden_dim) latent codes.

        Returns:
            (batch, Nseq + 1) LongTensor; column 0 is the start token.
        """
        batch_size=Z.shape[0]
        device=Z.device

        seq=torch.full((batch_size,1),self.Nfea-2,dtype=torch.long,device=device)
        Y = seq
        Zm=Z.view(-1,1,self.hidden_dim)

        # The LSTM state is carried across the single-step calls.
        state=self._zero_state(batch_size, device)
        for _ in range(self.Nseq):
            ZX=torch.cat((Zm,self.embedd(Y)),2)
            dec_out,state=self.decoder_rnn(ZX,state)
            Y= torch.argmax(self.decoder_fc1(dec_out),dim=2)
            seq=torch.cat((seq,Y),dim=1)

        return seq

class Generator(nn.Module):
    """Three-layer ReLU MLP turning seed vectors into L2-normalised latent codes.

    ``para`` must provide ``'seed_dim'`` (input width) and ``'hidden_dim'``
    (hidden and output width). ``bias`` is accepted but unused.
    """

    def __init__(self,para,bias=True):
        super(Generator,self).__init__()

        self.seed_dim=para['seed_dim']
        self.hidden_dim=para['hidden_dim']

        # Layers are created and initialised one at a time, in this order.
        self.generator_fc1=self._linear(self.seed_dim,self.hidden_dim)
        self.generator_fc2=self._linear(self.hidden_dim,self.hidden_dim)
        self.generator_fc3=self._linear(self.hidden_dim,self.hidden_dim)

    @staticmethod
    def _linear(fan_in, fan_out):
        """Build a Linear layer: Xavier-normal weights, standard-normal bias."""
        layer=nn.Linear(fan_in,fan_out)
        nn.init.xavier_normal_(layer.weight.data)
        nn.init.normal_(layer.bias.data)
        return layer

    def forward(self,S0):
        """Map (batch, seed_dim) seeds to (batch, hidden_dim) unit-norm codes."""
        hidden=torch.relu(self.generator_fc1(S0))
        hidden=torch.relu(self.generator_fc2(hidden))
        return F.normalize(self.generator_fc3(hidden),p=2,dim=1)

class Critic(nn.Module):
    """Three-layer ReLU MLP scoring latent codes with a single scalar each.

    ``para`` must provide ``'hidden_dim'`` (input and hidden width). ``bias``
    is accepted but unused.
    """

    def __init__(self,para,bias=True):
        super(Critic,self).__init__()

        self.hidden_dim=para['hidden_dim']

        # Layers are created and initialised one at a time, in this order.
        self.critic_fc1=self._linear(self.hidden_dim,self.hidden_dim)
        self.critic_fc2=self._linear(self.hidden_dim,self.hidden_dim)
        self.critic_fc3=self._linear(self.hidden_dim,1)

    @staticmethod
    def _linear(fan_in, fan_out):
        """Build a Linear layer: Xavier-normal weights, standard-normal bias."""
        layer=nn.Linear(fan_in,fan_out)
        nn.init.xavier_normal_(layer.weight.data)
        nn.init.normal_(layer.bias.data)
        return layer

    def forward(self,Z0):
        """Map (batch, hidden_dim) codes to (batch, 1) scores."""
        hidden=torch.relu(self.critic_fc1(Z0))
        hidden=torch.relu(self.critic_fc2(hidden))
        return self.critic_fc3(hidden)

    def clip(self,epsi=0.01):
        """Clamp every weight and bias into ``[-epsi, epsi]`` in place.

        Presumably the weight clipping used to keep a WGAN critic Lipschitz.
        """
        for layer in (self.critic_fc1,self.critic_fc2,self.critic_fc3):
            layer.weight.data.clamp_(min=-epsi,max=epsi)
            layer.bias.data.clamp_(min=-epsi,max=epsi)


class Net(nn.Module):
    """Container holding the encoder, decoder, generator and critic submodules.

    ``para`` is forwarded unchanged to each submodule. This class reads
    ``'Nseq'``, ``'Nfea'``, ``'hidden_dim'`` and ``'NLSTM_layer'`` from it
    directly. ``bias`` is accepted but unused.
    """

    def __init__(self,para,bias=True):
        super(Net,self).__init__()

        self.Nseq=para['Nseq']
        self.Nfea=para['Nfea']
        self.hidden_dim=para['hidden_dim']
        self.NLSTM_layer=para['NLSTM_layer']

        # Submodule creation order is kept (it fixes parameter-init RNG order).
        self.Enc=Encoder(para)
        self.Dec=Decoder(para)
        self.Gen=Generator(para)
        self.Cri=Critic(para)

    def AE(self, X0, L0, noise):
        """Autoencoder pass: encode, add ``noise`` to the code, decode with teacher forcing.

        Returns whatever ``self.Dec(...)`` returns for the perturbed code.
        """
        perturbed_code=self.Enc(X0,L0)+noise
        return self.Dec(perturbed_code,X0,L0)


def main():
    """Placeholder entry point: prints ``main`` when the module is run directly."""
    banner = "main"
    print(banner)


if __name__ == "__main__":
    main()

valid.py

In [None]:
#!/usr/bin/env python
import sys
import os
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem as AllChem
from rdkit.Chem.QED import qed
from rdkit.Chem.Descriptors import MolWt, MolLogP, NumHDonors, NumHAcceptors, TPSA
from rdkit.Chem.rdMolDescriptors import CalcNumRotatableBonds
from rdkit.Chem import MolStandardize
#from molvs import tautomer
from rdkit import DataStructs

from multiprocessing import Manager
from multiprocessing import Process
from multiprocessing import Queue

import sascorer

# Command-line help text; main() prints it from its argument-count check.
# The single positional argument is the directory that holds ARAE_smiles.txt
# and that receives smiles_unique.txt and smiles_novel.txt.
USAGE = """
valid.py data_dir
"""


def creator(q, data, Nproc):
    """Producer: enqueue work items, then one ``'DONE'`` sentinel per worker.

    Args:
        q: queue shared with the worker processes.
        data: sequence of items whose first two elements are ``idx`` and
            ``smiles``; each is enqueued as the tuple ``(idx, smiles)``.
        Nproc: number of consumer processes. Each consumer stops after reading
            one ``'DONE'``, so exactly ``Nproc`` sentinels are sent.
    """
    for d in data:
        q.put((d[0], d[1]))

    for _ in range(Nproc):
        q.put('DONE')


def check_validity(q, return_dict_valid):
    """Worker: keep the generated strings that RDKit accepts as molecules.

    Consumes ``(idx, raw)`` items from ``q`` until a ``'DONE'`` sentinel.
    ``raw`` is cut at the first ``'>'`` and stripped of ``'<'`` characters
    (presumably the model's start/end markers). A molecule that RDKit parses
    and sanitizes is stored as ``return_dict_valid[idx] = [canonical_smiles]``.
    Invalid entries are skipped. ``idx`` is printed every 10000 items as
    progress.
    """
    while True:
        item = q.get()
        if item == 'DONE':
            break
        idx, smi0 = item

        # partition() keeps the whole string when no '>' is present. The old
        # find()-based slice got -1 there and silently dropped the last character.
        smi = smi0.partition('>')[0].strip('<')

        if idx % 10000 == 0:
            print(idx)

        m = Chem.MolFromSmiles(smi)
        if m is None:
            continue
        # A non-zero SanitizeFlags value means sanitization failed.
        if Chem.SanitizeMol(m, catchErrors=True):
            continue

        return_dict_valid[idx] = [Chem.MolToSmiles(m)]


def cal_fp(q, return_dict_fp, nbits=1024, radius=2):
    """Worker: compute Morgan bit-vector fingerprints for SMILES from ``q``.

    Consumes ``(idx, smiles)`` items until a ``'DONE'`` sentinel. For each
    molecule that parses and sanitizes, stores
    ``return_dict_fp[idx] = [fingerprint]``. Others are skipped. ``idx`` is
    printed every 10000 items as progress.

    Args:
        nbits: fingerprint length (default 1024, as before).
        radius: Morgan radius (default 2, as before).
    """
    while True:
        item = q.get()
        if item == 'DONE':
            break
        idx, smi = item

        if idx % 10000 == 0:
            print(idx)

        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        # A non-zero SanitizeFlags value means sanitization failed.
        if Chem.SanitizeMol(mol, catchErrors=True):
            continue

        com_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits)
        return_dict_fp[idx] = [com_fp]


def cal_sim(q, ref_data, return_dict_sim, nbits=1024, radius=2):
    """Worker: find each query molecule's most Tanimoto-similar reference.

    Consumes ``(idx, smiles)`` items from ``q`` until a ``'DONE'`` sentinel.
    For each molecule that parses and sanitizes, computes its Morgan
    fingerprint and stores ``return_dict_sim[idx] = [sim_max, j_max]``, where
    ``j_max`` indexes ``ref_data``. ``idx`` is printed every 10000 items.

    Args:
        ref_data: sequence whose items hold a reference fingerprint at
            position 1. Must be non-empty; ``argmax`` of an empty array raises
            ``ValueError``.
        nbits, radius: Morgan fingerprint settings (defaults as before). They
            should match how the reference fingerprints were built.
    """
    # The reference fingerprints are loop-invariant: extract them once.
    ref_fps = [ref[1] for ref in ref_data]
    while True:
        item = q.get()
        if item == 'DONE':
            break
        idx, smi = item

        if idx % 10000 == 0:
            print(idx)

        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        # A non-zero SanitizeFlags value means sanitization failed.
        if Chem.SanitizeMol(mol, catchErrors=True):
            continue

        com_fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits)
        # One C-level call instead of a Python loop over every reference.
        similarity = np.array(DataStructs.BulkTanimotoSimilarity(com_fp, ref_fps))
        j_max = similarity.argmax()
        return_dict_sim[idx] = [similarity[j_max], j_max]


def cal_prop(q, return_dict_prop):
    """Worker: compute molecular properties for SMILES read from ``q``.

    Consumes ``(idx, smiles)`` items until a ``'DONE'`` sentinel and stores
    ``return_dict_prop[idx] = [logP, SAS, QED, MW, TPSA]``, computed with
    MolLogP, sascorer.calculateScore, qed, MolWt and TPSA. SMILES that RDKit
    cannot parse are skipped.
    """
    while True:
        item = q.get()
        if item == 'DONE':
            break
        idx, smi = item

        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # Otherwise MolLogP(None) raises and kills this worker process.
            continue

        return_dict_prop[idx] = [
            MolLogP(mol),
            sascorer.calculateScore(mol),
            qed(mol),
            MolWt(mol),
            TPSA(mol),
        ]


def _ratio(num, den):
    """Return ``num / den`` as a float, or 0.0 when ``den`` is zero (empty input)."""
    return float(num) / den if den else 0.0


def _run_workers(worker, data, Nproc):
    """Fan ``data`` out to ``Nproc`` worker processes and collect their results.

    A ``creator`` process feeds the ``[idx, smiles]`` items of ``data``, plus
    one ``'DONE'`` sentinel per worker, into a shared queue. Each worker runs
    as ``worker(queue, result_dict)``. Returns the results as a plain ``dict``,
    copied out of the manager proxy before the manager shuts down.
    """
    q = Queue()
    with Manager() as manager:
        results = manager.dict()
        producer = Process(target=creator, args=(q, data, Nproc))
        producer.start()

        procs = [Process(target=worker, args=(q, results)) for _ in range(Nproc)]
        for proc in procs:
            proc.start()

        # The parent never uses the queue itself; release its handle.
        q.close()
        q.join_thread()
        producer.join()
        for proc in procs:
            proc.join()

        return dict(results)


def main():
    """Evaluate generated SMILES: validity, uniqueness and novelty.

    Reads ``<data_dir>/ARAE_smiles.txt`` (lines starting with ``#`` are
    skipped) and prints three ``count total ratio`` lines: valid/total,
    unique/valid and novel/unique. Writes the sorted unique canonical SMILES
    to ``<data_dir>/smiles_unique.txt``. Writes the novel ones (not present in
    ``ZINC/train_5.txt``, resolved relative to the working directory) with
    logP, SAS, QED, MW and TPSA to ``<data_dir>/smiles_novel.txt``.
    """
    # sys.argv[0] is the script name, so one argument means len(sys.argv) == 2.
    # The old ``< 1`` test never fired and fell through to an IndexError.
    if len(sys.argv) < 2:
        print(USAGE)
        sys.exit()

    data_dir = sys.argv[1]
    Nproc = 30

    # --- Validity ---
    gen_file = os.path.join(data_dir, "ARAE_smiles.txt")
    with open(gen_file) as fp:
        gen_smiles = [line.strip() for line in fp if not line.startswith("#")]
    gen_data = [[k, smi] for k, smi in enumerate(gen_smiles)]
    Ndata = len(gen_data)

    return_dict_valid = _run_workers(check_validity, gen_data, Nproc)
    valid_smi_list = [return_dict_valid[idx][0] for idx in sorted(return_dict_valid)]
    num_valid = len(valid_smi_list)
    print("valid:  %6d %6d %6.4f" % (num_valid, Ndata, _ratio(num_valid, Ndata)))

    # --- Uniqueness ---
    unique_set = set(valid_smi_list)
    num_set = len(unique_set)
    print("Unique:  %6d %6d %6.4f" % (num_set, num_valid, _ratio(num_set, num_valid)))

    with open(os.path.join(data_dir, "smiles_unique.txt"), "w") as fp_out2:
        fp_out2.write("#smi\n")
        fp_out2.writelines("%s\n" % smi for smi in sorted(unique_set))

    # --- Novelty against the ZINC training set ---
    ZINC_file = "ZINC/train_5.txt"
    ZINC_set = set()
    with open(ZINC_file) as fp:
        for x in fp:
            if x.startswith("SMILES"):  # presumably the header line
                continue
            fields = x.split()
            if fields:  # tolerate blank lines instead of raising IndexError
                ZINC_set.add(fields[0])
    # Sorted so the output order is reproducible; str set order depends on the hash seed.
    novel_list = sorted(unique_set - ZINC_set)
    novel_data = [[idx, smi] for idx, smi in enumerate(novel_list)]

    return_dict_prop = _run_workers(cal_prop, novel_data, Nproc)

    num_novel = len(novel_list)
    print("Novel:  %6d %6d %6.4f" % (num_novel, num_set, _ratio(num_novel, num_set)))

    with open(os.path.join(data_dir, "smiles_novel.txt"), "w") as fp_out3:
        fp_out3.write('#SMILES logP SAS QED MW TPSA\n')
        for key in sorted(return_dict_prop):
            smi = novel_data[key][1]
            # Lower-case ``tpsa`` so the imported TPSA() function is not shadowed.
            logP, SAS, QED, MW, tpsa = return_dict_prop[key]
            fp_out3.write("%s %6.3f %6.3f %5.3f %7.3f %7.3f\n" % (
                smi, logP, SAS, QED, MW, tpsa))


if __name__ == "__main__":
    main()