In [1]:
from rdkit import Chem

import numpy
from rdkit.Chem import rdMolDescriptors, rdmolops, Mol, AllChem

from chainerchem.dataset.preprocessor.mol_preprocessor import MolPreprocessor
from chainerchem.dataset.preprocessor.mol_preprocessor import MolFeatureExtractFailure


def get_atom_array(mol, zero_padding=False ,n_max_atoms=-1):
        atom_list = [a.GetAtomicNum() for a in mol.GetAtoms()]
        if zero_padding:
            # 'empty' padding for atom_list
            # 0 represents empty place for atom
            n_atom = len(atom_list)
            atom_list = atom_list + [0] * (n_max_atoms - n_atom)
            # atom_array = numpy.zeros((max_num_atom,), dtype=numpy.int32)
            # atom_array[:n_atom] = numpy.array(atom_list)
        return numpy.array(atom_list, dtype=numpy.int32)

def constructBond(mol, i, j):
    bondFeature = numpy.zeros((4,), dtype = 'f')
    k = mol.GetBondBetweenAtoms(i,j)
    if k is not None:
        bondType = str(k.GetBondType())
        if bondType == 'SINGLE':
            bondFeature[0] = 1.0
        elif bondType =='DOUBLE':
            bondFeature[1] = 1.0            
        elif bondType =='TRIPLE':
            bondFeature[2] = 1.0
        elif bondType =='AROMATIC':
            bondFeature[3] = 1.0
        else:
            print("Unknown bond type", bondType)
            assert(False)
    return bondFeature

MAX_DISTANCE = 7

def constructDistance(distanceMatrix, i, j):
    distance = min(MAX_DISTANCE, int(distanceMatrix[i][j]))
    distanceFeature = numpy.zeros((MAX_DISTANCE,), dtype = 'f')
    distanceFeature[:distance] = 1.0
    return distanceFeature

def constructRingFeature(mol,n_max_atoms=-1):
    MAX_NUMBER_ATOM = n_max_atoms
    nAtom = mol.GetNumAtoms()    
    rinfo = mol.GetRingInfo()
    sssr = Chem.GetSymmSSSR(mol)
    ringFeature = numpy.zeros((MAX_NUMBER_ATOM**2, 1,), dtype = 'f')
    for ring in sssr:
        ring = list(ring)
        nAtomInRing = len(ring)
        for i in range(nAtomInRing):
            for j in range(nAtomInRing):
                a0 = ring[i]
                a1 = ring[j]
                ringFeature[a0*nAtom + a1] = 1
    return ringFeature

def constructPairFeature(mol=Chem.Mol, zero_padding=False, n_max_atoms=-1):
    MAX_NUMBER_ATOM = n_max_atoms
    nAtom = mol.GetNumAtoms()  
    distanceMatrix = Chem.GetDistanceMatrix(mol)
    distanceFeature = numpy.zeros((MAX_NUMBER_ATOM**2, MAX_DISTANCE,), dtype = 'f')
    print(distanceMatrix)
    for i in range(nAtom):
        for j in range(nAtom):
            distanceFeature[i*nAtom + j] = constructDistance(distanceMatrix, i, j)
    bondFeature = numpy.zeros((MAX_NUMBER_ATOM**2,4,), dtype = 'f')
    for i in range(nAtom):
        for j in range(nAtom):
            bondFeature[i*nAtom + j] = constructBond(mol, i, j)
    ringFeature = constructRingFeature(mol,n_max_atoms=MAX_NUMBER_ATOM)
    feature = numpy.hstack((distanceFeature, bondFeature, ringFeature))
    return feature
    
class WeaveNetPreprocessor(MolPreprocessor):
    def __init__(self, labels=None, max_atoms=-1, zero_padding=False):
        """

        Args:
            labels (str or list): label names to extract
            max_atoms (int): Max number of atoms for each molecule, if the number 
            of atoms is more than this value, this data is simply ignored.
            Setting negative value indicates no limit for max atoms.
            zero_padding (bool): True

        """
        super(WeaveNetPreprocessor, self).__init__(labels=labels)
        if zero_padding and max_atoms <= 0:
            raise ValueError('max_atoms must be set to positive value when '
                             'zero_padding is True')

        self.max_atoms = max_atoms
        self.zero_padding = zero_padding
        if zero_padding:
            self.feature_array_dtype_list = [numpy.float32, numpy.float32,
                                             numpy.int32]
        else:
            self.feature_array_dtype_list = [None, None, numpy.int32]
            
    def get_descriptor(self, mol):
        """

        Args:
            mol (Mol): 

        Returns:

        """
        mol =Chem.AddHs(mol)
        num_atoms = mol.GetNumAtoms()
        if self.max_atoms >= 0 and num_atoms > self.max_atoms:
            # Skip extracting feature. ignore this case.
            raise MolFeatureExtractFailure
        atom_array = get_atom_array(mol,self.zero_padding,
                                                   n_max_atoms=self.max_atoms)
        print(atom_array)
        pair_feature = constructPairFeature(mol, self.zero_padding,
                                                   n_max_atoms=self.max_atoms)
        print(pair_feature)
        return atom_array, pair_feature

In [2]:
smi ="CN1CCN(CC2=CC=C(C=C2)C(=O)NC2=CC(NC3=NC=CC(=N3)C3=CN=CC=C3)=C(C)C=C2)CC1"
mol = Chem.MolFromSmiles(smi)

In [7]:
wp = WeaveNetPreprocessor(max_atoms=100,zero_padding=True)

In [8]:
wp.get_descriptor(mol)

[6 7 6 6 7 6 6 6 6 6 6 6 6 8 7 6 6 6 7 6 7 6 6 6 7 6 6 7 6 6 6 6 6 6 6 6 6
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[[ 0.  1.  2. ...,  4.  3.  3.]
 [ 1.  0.  1. ...,  3.  2.  2.]
 [ 2.  1.  0. ...,  4.  3.  3.]
 ..., 
 [ 4.  3.  4. ...,  0.  3.  3.]
 [ 3.  2.  3. ...,  3.  0.  2.]
 [ 3.  2.  3. ...,  3.  2.  0.]]
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 1.  0.  0. ...,  0.  0.  0.]
 [ 1.  1.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]


(array([6, 7, 6, 6, 7, 6, 6, 6, 6, 6, 6, 6, 6, 8, 7, 6, 6, 6, 7, 6, 7, 6, 6,
        6, 7, 6, 6, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], dtype=int32),
 array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 1.,  0.,  0., ...,  0.,  0.,  0.],
        [ 1.,  1.,  0., ...,  0.,  0.,  0.],
        ..., 
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32))