In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import gzip
import pandas
import h5py
import numpy as np
import argparse
import os
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn import model_selection
from scipy import stats

from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.rdMolDescriptors import CalcNumHBD    
from rdkit.Chem.rdMolDescriptors import CalcNumHBA
from rdkit.Chem.rdMolDescriptors import CalcTPSA
from rdkit import Chem
from rdkit.Chem.QED import qed

from utils import decode_smiles_from_indexes, load_dataset
import matplotlib.pyplot as plt
from scipy.special import softmax

In [None]:
# Encoded molecules (fed to the VAE as `oh`) and the character set used to
# turn decoder output back into SMILES strings.
X_train, X_test, charset = load_dataset('./data/processed.h5')

# Molecular weight: per-sample targets (y_*_norm) and conditioning vectors (pdf_*).
mw_train, mw_test = (
    np.load("./prop_np/weight/y_train_norm.npy"),
    np.load("./prop_np/weight/y_test_norm.npy"),
)
mw_pdf_train, mw_pdf_test = (
    np.load("./prop_np/weight/pdf_train.npy"),
    np.load("./prop_np/weight/pdf_test.npy"),
)

# TPSA: same layout as the molecular-weight files.
tpsa_train, tpsa_test = (
    np.load("./prop_np/tpsa/y_train_norm.npy"),
    np.load("./prop_np/tpsa/y_test_norm.npy"),
)
tpsa_pdf_train, tpsa_pdf_test = (
    np.load("./prop_np/tpsa/pdf_train.npy"),
    np.load("./prop_np/tpsa/pdf_test.npy"),
)

In [None]:
# Width of the latent code: sizes the encoder heads (linear_1 / linear_2) and
# the label arrays y_train / y_test. The decoder's input width is set
# separately and has to be kept in sync by hand.
lat_dim = 128

In [None]:
# Per-sample target means for the latent code (consumed as `y_arg` by the KL
# term): column 0 holds the molecular-weight label, column 1 the TPSA label,
# and every other column keeps the 0 it was initialised with.
y_train = np.zeros((len(mw_train), lat_dim))
y_train[:, 0], y_train[:, 1] = mw_train, tpsa_train

y_test = np.zeros((len(mw_test), lat_dim))
y_test[:, 0], y_test[:, 1] = mw_test, tpsa_test

In [None]:
# Conditioning vector per sample: molecular-weight pdf followed by TPSA pdf,
# joined along the last axis.
pdf_train = np.concatenate([mw_pdf_train, tpsa_pdf_train], axis=-1)
pdf_test = np.concatenate([mw_pdf_test, tpsa_pdf_test], axis=-1)

In [None]:
# Convert every NumPy array to a float32 CPU tensor.
torch_X_train = torch.from_numpy(X_train).float()
torch_X_test = torch.from_numpy(X_test).float()

torch_pdf_train = torch.from_numpy(pdf_train).float()
torch_pdf_test = torch.from_numpy(pdf_test).float()

torch_y_train = torch.from_numpy(y_train).float()
torch_y_test = torch.from_numpy(y_test).float()

# Each dataset item is (encoded molecule, conditioning pdf, latent target).
# NOTE: the names `train` and `test` are rebound further down by the
# `train(epoch)` / `test(epoch)` functions, so the DataLoader cell must be
# run before those functions are defined.
train = torch.utils.data.TensorDataset(torch_X_train, torch_pdf_train, torch_y_train)
test = torch.utils.data.TensorDataset(torch_X_test, torch_pdf_test, torch_y_test)

In [None]:
# Mini-batch iterators over the two datasets (250 samples per batch, reshuffled
# on every pass).
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=250, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=250, shuffle=True)

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder conditioned on a property vector.

    The latent code and the conditioning vector are concatenated, passed
    through two dense layers, reshaped to a (64, 7, 3) feature map and
    upsampled by three transposed convolutions. With the layer settings
    below the final feature map is (N, 1, 120, 33).

    Args:
        latent_dim: Width of the latent code ``z``.
        cond_dim: Width of the conditioning vector ``y``.
            The defaults reproduce the previously hard-coded input width
            of 194 (= 128 + 66).
    """

    def __init__(self, latent_dim=128, cond_dim=66):
        super().__init__()
        self.linear_3 = nn.Linear(latent_dim + cond_dim, 256)
        self.linear_4 = nn.Linear(256, 1344)  # 1344 = 64 * 7 * 3
        self.conv_4 = nn.ConvTranspose2d(64, 32, kernel_size=(11, 3), stride=(2, 2),
                                         padding=0, output_padding=(0, 0))
        self.conv_5 = nn.ConvTranspose2d(32, 16, kernel_size=(11, 3), stride=(2, 2),
                                         padding=0, output_padding=(0, 1))
        self.conv_6 = nn.ConvTranspose2d(16, 1, kernel_size=(11, 3), stride=(2, 2),
                                         padding=0, output_padding=(1, 0))
        self.relu = nn.ReLU()

    def forward(self, z, y):
        """Decode latent codes into a normalised (N, H, W) grid.

        Args:
            z: Latent codes, shape (N, latent_dim).
            y: Conditioning vectors, shape (N, cond_dim).

        Returns:
            Tensor of shape (N, H, W) whose entries are non-negative and sum
            to 1 over the whole grid of each sample.
        """
        hidden = torch.cat((z, y), dim=1)
        hidden = F.selu(self.linear_3(hidden))
        hidden = F.selu(self.linear_4(hidden))

        feature_map = hidden.view(hidden.size(0), 64, 7, 3)  # (N, C, H, W)
        feature_map = self.relu(self.conv_4(feature_map))
        feature_map = self.relu(self.conv_5(feature_map))
        feature_map = self.relu(self.conv_6(feature_map))

        batch, _, height, width = feature_map.shape
        flat = feature_map.contiguous().view(batch, -1)
        # NOTE(review): the softmax normalises over the flattened grid, i.e.
        # across all positions and characters at once, not per position.
        # Kept as-is because changing it alters training; confirm intent.
        probs = F.softmax(flat, dim=1)
        return probs.contiguous().view(batch, height, width)

In [None]:
class MolecularVAE(nn.Module):
    """Conditional convolutional VAE: conv encoder plus a `Decoder` instance.

    The conditioning vector is appended to the input grid as two extra rows
    before encoding and is passed again to the decoder together with the
    sampled latent code.
    """

    def __init__(self):
        super().__init__()

        # Encoder: three strided convolutions, then dense heads.
        self.conv_1 = nn.Conv2d(1, 16, (11, 3), stride=(2, 2))
        self.conv_2 = nn.Conv2d(16, 32, (11, 3), stride=(2, 2))
        self.conv_3 = nn.Conv2d(32, 64, (11, 3), stride=(2, 2))
        self.linear_0 = nn.Linear(1344, 256)
        self.linear_1 = nn.Linear(256, lat_dim)  # latent mean
        self.linear_2 = nn.Linear(256, lat_dim)  # latent log-variance
        self.relu = nn.ReLU()
        self.decode = Decoder()

    def encode(self, x):
        """Return ``(z_mean, z_logvar)`` for the (already conditioned) input."""
        features = x
        for conv in (self.conv_1, self.conv_2, self.conv_3):
            features = self.relu(conv(features))
        flat = features.view(features.size(0), -1)
        hidden = F.selu(self.linear_0(flat))
        return self.linear_1(hidden), self.linear_2(hidden)

    def sampling(self, z_mean, z_logvar):
        """Draw a latent sample; the Gaussian noise is scaled by 1e-2."""
        noise = 1e-2 * torch.randn_like(z_logvar)
        std = torch.exp(0.5 * z_logvar)
        return std * noise + z_mean

    def forward(self, x, y):
        """Encode ``x`` conditioned on ``y``, sample, and decode.

        Returns:
            Tuple ``(decoded, z_mean, z_logvar, z)``.
        """
        # Reshape the condition into two rows and stack them under the input
        # along the height axis.
        cond_rows = y.view(y.size(0), 1, 2, -1)
        x_cond = torch.cat((x, cond_rows), dim=2)
        z_mean, z_logvar = self.encode(x_cond)
        z = self.sampling(z_mean, z_logvar)
        decoded = self.decode(z, y)

        return decoded, z_mean, z_logvar, z

In [None]:
def vae_loss(x_decoded_mean, x, z_mean, z_logvar, y_arg):
    """Weighted VAE objective with a label-shifted KL term.

    Args:
        x_decoded_mean: Decoder output, values in [0, 1], same shape as ``x``.
        x: Reconstruction target.
        z_mean: Latent means from the encoder.
        z_logvar: Latent log-variances from the encoder.
        y_arg: Per-sample target means; the KL term penalises the distance
            between ``z_mean`` and ``y_arg``, not between ``z_mean`` and 0.

    Returns:
        Tuple ``(total_loss, weighted_kl)``, both summed over the batch and
        scaled by 0.1.
    """
    # `reduction='sum'` replaces the deprecated `size_average=False`.
    xent_loss = F.binary_cross_entropy(x_decoded_mean, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - (z_mean - y_arg).pow(2) - z_logvar.exp())
    weighted_kl = 0.1 * kl_loss
    return 0.1 * xent_loss + weighted_kl, weighted_kl

In [None]:
# Seed first so that weight initialisation below is reproducible.
torch.manual_seed(42)
epochs = 100

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = MolecularVAE().to(device)
optimizer = optim.Adam(model.parameters())

In [None]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    NOTE: defining this function rebinds the module-level name ``train``,
    which until then refers to the training ``TensorDataset``.

    Args:
        epoch: Current epoch number (accepted for the caller's convenience;
            not used inside the function).

    Returns:
        Tuple ``(mean_loss, latent_arr, label_arr)``: the summed loss divided
        by the dataset size (a tensor), and two lists with one NumPy array
        per batch holding the sampled latents and the latent targets.
    """
    model.train()
    train_loss = 0
    KL_loss = 0
    latent_arr = []
    label_arr = []
    for oh, label, arg_label in train_loader:
        oh = oh.unsqueeze(1).to(device)  # add the channel axis for Conv2d
        label = label.to(device)
        arg_label = arg_label.to(device)

        optimizer.zero_grad()
        output, mean, logvar, latent = model(oh, label)
        loss, kl_loss = vae_loss(output, oh.squeeze(1), mean, logvar, arg_label)
        loss.backward()
        optimizer.step()

        # Detach before accumulating: summing the raw loss tensors keeps a
        # grad_fn chain alive for the whole epoch.
        train_loss += loss.detach()
        KL_loss += kl_loss.detach()
        latent_arr.append(latent.detach().cpu().numpy())
        label_arr.append(arg_label.detach().cpu().numpy())

    n_samples = len(train_loader.dataset)
    print('train', train_loss / n_samples)
    print('train KL', KL_loss / n_samples)

    return train_loss / n_samples, latent_arr, label_arr

In [None]:
def test(epoch):
    """Evaluate the model on ``test_loader`` without updating weights.

    NOTE: defining this function rebinds the module-level name ``test``,
    which until then refers to the test ``TensorDataset``.

    Args:
        epoch: Current epoch number (accepted for the caller's convenience;
            not used inside the function).

    Returns:
        The summed loss divided by the dataset size (a tensor).
    """
    model.eval()
    test_loss = 0
    KL_loss = 0
    # No gradients are needed for evaluation; without this every batch's
    # autograd graph stays referenced by the running totals.
    with torch.no_grad():
        for oh, label, arg_label in test_loader:
            oh = oh.unsqueeze(1).to(device)  # add the channel axis for Conv2d
            label = label.to(device)
            arg_label = arg_label.to(device)

            output, mu, logvar, latent = model(oh, label)
            loss, kl_loss = vae_loss(output, oh.squeeze(1), mu, logvar, arg_label)
            KL_loss += kl_loss
            test_loss += loss

    n_samples = len(test_loader.dataset)
    print('test', test_loss / n_samples)
    print('test KL', KL_loss / n_samples)

    return test_loss / n_samples

In [None]:
# One training pass followed by one evaluation pass per epoch. `latent_arr`
# and `label_arr` are overwritten each time, so afterwards they hold the
# batches of the final epoch only.
for epoch in range(1, epochs + 1):
    train_loss, latent_arr, label_arr = train(epoch)
    test_loss = test(epoch)

In [None]:
# Stack the per-batch latent arrays of the last epoch into a single array.
latent_np = np.asarray(latent_arr)

In [None]:
# Keep only the first two latent dimensions, one row per sample. Built
# directly from the per-batch list so this cell can be re-run safely and does
# not rely on every batch having the same size.
latent_np = np.concatenate([np.asarray(batch)[:, :2] for batch in latent_arr], axis=0)

In [None]:
# Stack the per-batch latent targets of the last epoch into a single array.
label_np = np.asarray(label_arr)

In [None]:
# Keep only the two populated target columns, one row per sample. Built
# directly from the per-batch list so this cell can be re-run safely and does
# not rely on every batch having the same size.
label_np = np.concatenate([np.asarray(batch)[:, :2] for batch in label_arr], axis=0)

In [None]:
# Sanity check: expect (number of samples, 2).
np.shape(label_np)

In [None]:
# Largest target value in column 0 (the molecular-weight label).
label_np[:, 0].max()

In [None]:
# Largest sampled value along latent dimension 0.
latent_np[:, 0].max()

In [None]:
# Smallest sampled value along latent dimension 0. The call previously had no
# argument, which raises a TypeError and stops "Run All".
# NOTE(review): the intended array is inferred from the neighbouring cells.
np.min(latent_np[:, 0])

In [None]:
# Largest sampled value along latent dimension 1.
latent_np[:, 1].max()

In [None]:
# Sanity check: expect (number of samples, 2).
np.shape(latent_np)

In [None]:
# Euclidean norm of each (column 0, column 1) label pair; used later as the
# colour value of the scatter plot.
cor_ord = np.sqrt(np.square(label_np[:, 0]) + np.square(label_np[:, 1]))

# NOTE(review): these arrays come from the weight/TPSA model but are written
# into the `hba_hbd` folder — confirm this is the intended destination and
# does not overwrite results from another experiment.
np.save("./prop_np/hba_hbd/latent.npy", latent_np)
np.save("./prop_np/hba_hbd/label.npy", label_np)

In [None]:
# Sanity check: one colour value per sample.
np.shape(cor_ord)

In [None]:
import pandas as pd

# Plot table: the two latent coordinates scaled down by 10, plus the label
# norm that colours each point.
latten_pd = pd.DataFrame(
    {
        'x': latent_np[:, 0] / 10,
        'y': latent_np[:, 1] / 10,
        'label': cor_ord,
    }
)

In [None]:
# Scatter of the first two latent dimensions, coloured by the label norm.
im = latten_pd.plot(kind="scatter", x='x', y='y', alpha=0.7, figsize=(10,7),
    c='label', cmap=plt.get_cmap("jet"), colorbar=True,
    sharex=False)

#im.axes.set_title("CVAE Latent Distribution",y=1.05, fontsize=25)
# Latent dimensions 0 and 1 are tied to the molecular-weight and TPSA labels
# in this notebook, so the axes are named after those properties (they were
# previously labelled HBA / HBD).
im.set_xlabel("normalized MW",fontsize=20)
im.set_ylabel("normalized TPSA",fontsize=20)
im.tick_params(labelsize=20)


#plt.savefig("./image/MW_TPSA.jpg", format='jpg',edgecolor='none', dpi=300)
plt.show()

In [None]:
 def reconstructed(autoencoder, charset, prop1=5, prop2=250, nums=1000,
                   n_batches=1000, latent_dim=128):
    """Sample molecules from the decoder around two target property values.

    For each batch, standard-normal latent vectors are drawn, their first two
    dimensions are shifted by ``prop1`` and ``prop2``, and they are decoded
    together with a fixed conditioning vector. Decoded strings that RDKit can
    parse and that contain no space character are collected.

    Args:
        autoencoder: Model exposing ``decode(z, y)``.
        charset: Character set passed to ``decode_smiles_from_indexes``; its
            length is the width of the decoder output.
        prop1: Target added to latent dimension 0 and used as the mean of the
            first conditioning density.
        prop2: Target added to latent dimension 1 and used as the mean of the
            second conditioning density.
        nums: Number of latent samples per batch.
        n_batches: Number of batches to draw.
        latent_dim: Width of the latent vectors.

    Returns:
        List of unique valid SMILES strings (order is not defined).

    NOTE(review): the conditioning densities are evaluated on a grid over
    [0, 10]; with the default ``prop2=250`` that density underflows to all
    zeros. Presumably the targets should be given in the same normalised
    units as the training labels — TODO confirm.
    """
    grid = np.linspace(0, 10, 33)

    # The conditioning vector does not change between batches, so build it
    # (and move it to the device) once.
    cond = np.concatenate((stats.norm(prop1, 1).pdf(grid),
                           stats.norm(prop2, 1).pdf(grid)), axis=-1)
    cond_repeat = np.repeat(cond[np.newaxis, :], nums, axis=0)
    cond_torch = torch.Tensor(cond_repeat).to(device)

    valid_smile = set()
    for _ in range(n_batches):
        lat = np.random.normal(0, 1., size=(nums, latent_dim)).astype('float32')
        lat[:, 0] += prop1
        lat[:, 1] += prop2

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            output = autoencoder.decode(torch.from_numpy(lat).to(device), cond_torch)

        # Most probable character index at every position of every sample.
        indexes = output.cpu().numpy().reshape(nums, -1, len(charset)).argmax(axis=2)

        for row in indexes:
            smile = decode_smiles_from_indexes(row, charset)
            if ' ' in smile:
                continue
            if Chem.MolFromSmiles(smile) is not None:
                valid_smile.add(smile)

    return list(valid_smile)

In [None]:
from rdkit import RDLogger

# Turn off every RDKit 'rdApp.*' log channel.
RDLogger.DisableLog('rdApp.*')

In [None]:
# Sample from the trained model and keep the unique, parseable SMILES.
valid_smile = reconstructed(autoencoder=model, charset=charset)

In [None]:
# Show a preview only; the full list can be very long and would flood the
# notebook output.
valid_smile[:20]

In [None]:
# Column 0: exact molecular weight, column 1: TPSA, one row per generated
# SMILES string.
mw_tpsa = np.zeros((len(valid_smile), 2))
for i, s in enumerate(valid_smile):
    m = Chem.MolFromSmiles(s)
    mw_tpsa[i, 0] = ExactMolWt(m)
    mw_tpsa[i, 1] = CalcTPSA(m)

In [None]:
# Display the (molecular weight, TPSA) table of the generated molecules.
mw_tpsa