In [10]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch

from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.rdMolDescriptors import CalcNumHBD    
from rdkit.Chem.rdMolDescriptors import CalcNumHBA
from rdkit.Chem.rdMolDescriptors import CalcTPSA
from rdkit import Chem
from rdkit.Chem.QED import qed
from utils import decode_smiles_from_indexes, load_dataset

from model import MolecularICVAE
from rdkit import RDLogger   
RDLogger.DisableLog('rdApp.*')

In [12]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MolecularICVAE().to(device)

In [13]:
model.load_state_dict(torch.load('./result/model/MW_model.pth'))

<All keys matched successfully>

**please note because the input value should be the normalized property value ranging from 0 to 500. So you need to transfer the true property value to normalized property firstly. In the following, we give you the two value y.min() and y_range = (y.max()-y.min()) for esily transfer:**

    
MW: [150.025169192,          201.12958056818]

SAS: [1.0574495523703487,  6.351893861381157]

TPSA: [0.,                149.48999999999998]

logP: [-4.9017,           10.970920000000003]

QED: [0.2335469098782191, 0.7145762918458708]

HBA: [0,                                  10]

HBD: [0,                                   5]

In [14]:
 def reconstructed(autoencoder, charset, p1):
    valid_smile = []
    
    p1 = (p1 - 150.025169192)/201.12958056818*500
    ## the two value 150.025169192 and 201.12958056818 need to change if you use want to generate other property.

    nums = 1000
    x = np.linspace(0, 10, 33)

    prop_np  = np.zeros((nums, 128))
    prop_np[:,0:2] = np.ones((nums,2))*p1
    
    p_ar = np.array([p1])
    p = np.repeat(p_ar[:], nums,0)
    for i in range(100):
        
        lat =  np.random.normal(0, 1, size=(nums, 128)).astype ('float32')    
        lat = lat + prop_np
                
        lat_torch =  torch.Tensor(lat).to(device)
        cond_torch =  torch.Tensor(p).to(device)
        
        output = autoencoder.decode(lat_torch, cond_torch)
        outp = output.cpu().detach().numpy()
        
        for j in range(nums):
            decode_smi = outp[j].reshape(1, 120, len(charset)).argmax(axis=2)[0]
            smi = decode_smiles_from_indexes(decode_smi, charset)
            m = Chem.MolFromSmiles(smi)
            #if (m != None) and (' ' not in smi) and (abs(ExactMolWt(m)-p1)<10*2)  and (abs(MolLogP(m)-p2)<0.548546*2)  and (abs(calculateScore(m)-p3)<0.3176*2) and (abs(qed(m)-p4)<0.0357*2) and (abs(CalcNumHBA(m)-p5)<1*2) and (abs(CalcNumHBD(m)-p6)<1*2)and (abs(CalcTPSA(m)-p7)<7.47*2):
            if (m != None) and (' ' not in smi):
                valid_smile.append(smi)
                
    valid_smile = list(set(valid_smile))
 
    return valid_smile

In [17]:
lines = []
valid_smile_all = []

condition = 200
## set the condition value (MW value) 
X_train, X_test, charset = load_dataset('./data/processed.h5')

valid_smile = reconstructed(model, charset, condition)

valid_smile_all.append(valid_smile)

In [18]:
valid_smile_all

[['CCCCCCCCCCCCCCCCCCOOOOO',
  'CCOC(OO)CCCCCCCCCCCC',
  'CC1cOcccccccccc1BCCNN',
  'Cc1cnnnnn1CCCNNN',
  'c1ccccccccccccccc1',
  'c1ccc(cc1)CCCCCCCCCCCC',
  'CCOC1=cccccccccccccc1',
  'c1ccccccc1BBOO',
  'CCCCOCCCCCCCCCCCOOOO',
  'CCCCCCCCCCCCCCCCCCCCC',
  'c1ccccccc1CCCCCCCCCOOO',
  'CCOC(OO)CC(CCCCCCCCCCCC)',
  'CC(C)COCCCCCCCCCCCCOOOOO',
  'c1ccccccc1CCCCCCCCC',
  'c1ccccccccccnnn1',
  'CCOC(OO)CCCCCCCCCCCN',
  'CCCCCCCCCCCCCCCCCCOOOO',
  'CC(CCCCCCCCCCCCCOOO)NNN',
  'CCCCCCCCCCCC1CCCCCC1',
  'c1ccccccc1CCCCCCNN',
  'c1ccc(cc1)COOOOONCCCCCCCCCCC',
  'CCCCOCCCCCOOOCCOOOOOO',
  'C[C@@H]CCCCCCCCCCCCCCCCCO',
  'CCCCOCCCOOOOSSCOOOO',
  'c1ccc(cc1)CCCOOCOO',
  'Cc1cccnnn1BBCCOO',
  'Cc1ccccccccccc1COOOO',
  'Cc1cccccccccCC1CCCCCCCCCCF',
  'c1cc(ccc1CCCN)CCCCCCCC',
  'Cc1cccnnccccccn1CCCO',
  'c1ccc2ccc12OOCOOOOOO',
  'C[C@@H]CCCCCCC=CCCCCCCC',
  'c1ccc(cc1)COOOONNNN',
  'CC1C2CCCC1cCCCCCcccccc2',
  'c1ccc(cc1)CCCC=CCCCCCCCOOO',
  'CC1cccccccc1CCCCOOOOOOO',
  'C1cccccc1NCCCCCCCCCCCCCCCC(C