Skip to content

Commit

Permalink
Merge pull request #2109 from nd-02110114/gat-pyg-2
Browse files Browse the repository at this point in the history
Implement sample GAT model for working PyG with DeepChem
  • Loading branch information
nissy-dev committed Sep 2, 2020
2 parents 96de1d1 + b7b56fa commit 3d257a0
Show file tree
Hide file tree
Showing 16 changed files with 1,327 additions and 30 deletions.
9 changes: 9 additions & 0 deletions deepchem/feat/__init__.py
@@ -1,12 +1,16 @@
"""
Making it easy to import in classes.
"""
# flake8: noqa

# base classes for featurizers
from deepchem.feat.base_classes import Featurizer
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.feat.base_classes import MaterialStructureFeaturizer
from deepchem.feat.base_classes import MaterialCompositionFeaturizer
from deepchem.feat.base_classes import ComplexFeaturizer
from deepchem.feat.base_classes import UserDefinedFeaturizer

from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.graph_features import WeaveFeaturizer
from deepchem.feat.fingerprints import CircularFingerprint
Expand All @@ -22,6 +26,11 @@
from deepchem.feat.atomic_coordinates import NeighborListComplexAtomicCoordinates
from deepchem.feat.adjacency_fingerprints import AdjacencyFingerprint
from deepchem.feat.smiles_featurizers import SmilesToSeq, SmilesToImage

# molecule featurizers
from deepchem.feat.molecule_featurizers import MolGraphConvFeaturizer

# material featurizers
from deepchem.feat.material_featurizers import ElementPropertyFingerprint
from deepchem.feat.material_featurizers import SineCoulombMatrix
from deepchem.feat.material_featurizers import CGCNNFeaturizer
55 changes: 55 additions & 0 deletions deepchem/feat/base_classes.py
@@ -1,6 +1,7 @@
"""
Feature calculations.
"""
import inspect
import logging
import numpy as np
import multiprocessing
Expand Down Expand Up @@ -75,6 +76,60 @@ def _featurize(self, datapoint: Any):
"""
raise NotImplementedError('Featurizer is not defined.')

def __repr__(self) -> str:
"""Convert self to repr representation.
Returns
-------
str
The string represents the class.
Examples
--------
>>> import deepchem as dc
>>> dc.feat.CircularFingerprint(size=1024, radius=4)
CircularFingerprint[radius=4, size=1024, chiral=False, bonds=True, features=False, sparse=False, smiles=False]
>>> dc.feat.CGCNNFeaturizer()
CGCNNFeaturizer[radius=8.0, max_neighbors=8, step=0.2]
"""
args_spec = inspect.getfullargspec(self.__init__) # type: ignore
args_names = [arg for arg in args_spec.args if arg != 'self']
args_info = ''
for arg_name in args_names:
args_info += arg_name + '=' + str(self.__dict__[arg_name]) + ', '
return self.__class__.__name__ + '[' + args_info[:-2] + ']'

def __str__(self) -> str:
"""Convert self to str representation.
Returns
-------
str
The string represents the class.
Examples
--------
>>> import deepchem as dc
>>> str(dc.feat.CircularFingerprint(size=1024, radius=4))
'CircularFingerprint_radius_4_size_1024'
>>> str(dc.feat.CGCNNFeaturizer())
'CGCNNFeaturizer'
"""
args_spec = inspect.getfullargspec(self.__init__) # type: ignore
args_names = [arg for arg in args_spec.args if arg != 'self']
args_num = len(args_names)
args_default_values = [None for _ in range(args_num)]
if args_spec.defaults is not None:
defaults = list(args_spec.defaults)
args_default_values[-len(defaults):] = defaults

override_args_info = ''
for arg_name, default in zip(args_names, args_default_values):
arg_value = self.__dict__[arg_name]
if default != arg_value:
override_args_info += '_' + arg_name + '_' + str(arg_value)
return self.__class__.__name__ + override_args_info


class ComplexFeaturizer(object):
""""
Expand Down
2 changes: 2 additions & 0 deletions deepchem/feat/molecule_featurizers/__init__.py
@@ -0,0 +1,2 @@
# flake8: noqa
from deepchem.feat.molecule_featurizers.mol_graph_conv_featurizer import MolGraphConvFeaturizer
196 changes: 196 additions & 0 deletions deepchem/feat/molecule_featurizers/mol_graph_conv_featurizer.py
@@ -0,0 +1,196 @@
from typing import List, Sequence, Tuple
import numpy as np

from deepchem.utils.typing import RDKitAtom, RDKitBond, RDKitMol
from deepchem.feat.graph_data import GraphData
from deepchem.feat.base_classes import MolecularFeaturizer
from deepchem.utils.molecule_feature_utils import get_atom_type_one_hot, \
construct_hydrogen_bonding_info, get_atom_hydrogen_bonding_one_hot, \
get_atom_is_in_aromatic_one_hot, get_atom_hybridization_one_hot, \
get_atom_total_num_Hs_one_hot, get_atom_chirality_one_hot, get_atom_formal_charge, \
get_atom_partial_charge, get_atom_ring_size_one_hot, get_atom_total_degree_one_hot, \
get_bond_type_one_hot, get_bond_is_in_same_ring_one_hot, get_bond_is_conjugated_one_hot, \
get_bond_stereo_one_hot


def _construct_atom_feature(atom: RDKitAtom,
h_bond_infos: List[Tuple[int, str]],
sssr: List[Sequence]) -> List[float]:
"""Construct an atom feature from a RDKit atom object.
Parameters
----------
atom: rdkit.Chem.rdchem.Atom
RDKit atom object
h_bond_infos: List[Tuple[int, str]]
A list of tuple `(atom_index, hydrogen_bonding_type)`.
Basically, it is expected that this value is the return value of
`construct_hydrogen_bonding_info`. The `hydrogen_bonding_type`
value is "Acceptor" or "Donor".
sssr: List[Sequence]
The return value of `Chem.GetSymmSSSR(mol)`.
The value is a sequence of rings.
Returns
-------
List[float]
A one-hot vector of the atom feature.
"""
atom_type = get_atom_type_one_hot(atom)
chirality = get_atom_chirality_one_hot(atom)
formal_charge = get_atom_formal_charge(atom)
partial_charge = get_atom_partial_charge(atom)
ring_size = get_atom_ring_size_one_hot(atom, sssr)
hybridization = get_atom_hybridization_one_hot(atom)
acceptor_donor = get_atom_hydrogen_bonding_one_hot(atom, h_bond_infos)
aromatic = get_atom_is_in_aromatic_one_hot(atom)
degree = get_atom_total_degree_one_hot(atom)
total_num = get_atom_total_num_Hs_one_hot(atom)
return atom_type + chirality + formal_charge + partial_charge + \
ring_size + hybridization + acceptor_donor + aromatic + degree + total_num


def _construct_bond_feature(bond: RDKitBond) -> List[float]:
"""Construct a bond feature from a RDKit bond object.
Parameters
---------
bond: rdkit.Chem.rdchem.Bond
RDKit bond object
Returns
-------
List[float]
A one-hot vector of the bond feature.
"""
bond_type = get_bond_type_one_hot(bond)
same_ring = get_bond_is_in_same_ring_one_hot(bond)
conjugated = get_bond_is_conjugated_one_hot(bond)
stereo = get_bond_stereo_one_hot(bond)
return bond_type + same_ring + conjugated + stereo


class MolGraphConvFeaturizer(MolecularFeaturizer):
"""This class is a featurizer of general graph convolution networks for molecules.
The default node(atom) and edge(bond) representations are based on
`WeaveNet paper <https://arxiv.org/abs/1603.00856>`_. If you want to use your own representations,
you could use this class as a guide to define your original Featurizer. In many cases, it's enough
to modify return values of `construct_atom_feature` or `construct_bond_feature`.
The default node representation are constructed by concatenating the following values,
and the feature length is 39.
- Atom type: A one-hot vector of this atom, "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "other atoms".
- Chirality: A one-hot vector of the chirality, "R" or "S".
- Formal charge: Integer electronic charge.
- Partial charge: Calculated partial charge.
- Ring sizes: A one-hot vector of the size (3-8) of rings that include this atom.
- Hybridization: A one-hot vector of "sp", "sp2", "sp3".
- Hydrogen bonding: A one-hot vector of whether this atom is a hydrogen bond donor or acceptor.
- Aromatic: A one-hot vector of whether the atom belongs to an aromatic ring.
- Degree: A one-hot vector of the degree (0-5) of this atom.
- Number of Hydrogens: A one-hot vector of the number of hydrogens (0-4) that this atom connected.
The default edge representation are constructed by concatenating the following values,
and the feature length is 11.
- Bond type: A one-hot vector of the bond type, "single", "double", "triple", or "aromatic".
- Same ring: A one-hot vector of whether the atoms in the pair are in the same ring.
- Conjugated: A one-hot vector of whether this bond is conjugated or not.
- Stereo: A one-hot vector of the stereo configuration of a bond.
If you want to know more details about features, please check the paper [1]_ and
utilities in deepchem.utils.molecule_feature_utils.py.
Examples
--------
>>> smiles = ["C1CCC1", "C1=CC=CN=C1"]
>>> featurizer = MolGraphConvFeaturizer()
>>> out = featurizer.featurize(smiles)
>>> type(out[0])
<class 'deepchem.feat.graph_data.GraphData'>
>>> out[0].num_node_features
39
>>> out[0].num_edge_features
11
References
----------
.. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints."
Journal of computer-aided molecular design 30.8 (2016):595-608.
Notes
-----
This class requires RDKit to be installed.
"""

def __init__(self, add_self_edges: bool = False):
"""
Parameters
----------
add_self_edges: bool, default False
Whether to add self-connected edges or not. If you want to use DGL,
you sometimes need to add explict self-connected edges.
"""
self.add_self_edges = add_self_edges

def _featurize(self, mol: RDKitMol) -> GraphData:
"""Calculate molecule graph features from RDKit mol object.
Parameters
----------
mol: rdkit.Chem.rdchem.Mol
RDKit mol object.
Returns
-------
graph: GraphData
A molecule graph with some features.
"""
try:
from rdkit import Chem
from rdkit.Chem import AllChem
except ModuleNotFoundError:
raise ValueError("This method requires RDKit to be installed.")

# construct atom and bond features
try:
mol.GetAtomWithIdx(0).GetProp('_GasteigerCharge')
except:
# If partial charges were not computed
AllChem.ComputeGasteigerCharges(mol)

h_bond_infos = construct_hydrogen_bonding_info(mol)
sssr = Chem.GetSymmSSSR(mol)

# construct atom (node) feature
atom_features = np.array(
[
_construct_atom_feature(atom, h_bond_infos, sssr)
for atom in mol.GetAtoms()
],
dtype=np.float,
)

# construct edge (bond) information
src, dest, bond_features = [], [], []
for bond in mol.GetBonds():
# add edge list considering a directed graph
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
src += [start, end]
dest += [end, start]
bond_features += 2 * [_construct_bond_feature(bond)]

if self.add_self_edges:
num_atoms = mol.GetNumAtoms()
src += [i for i in range(num_atoms)]
dest += [i for i in range(num_atoms)]
# add dummy edge features
bond_fea_length = len(bond_features[0])
bond_features += num_atoms * [[0 for _ in range(bond_fea_length)]]

return GraphData(
node_features=atom_features,
edge_index=np.array([src, dest], dtype=np.int),
edge_features=np.array(bond_features, dtype=np.float))
9 changes: 4 additions & 5 deletions deepchem/feat/tests/test_graph_data.py
@@ -1,5 +1,4 @@
import unittest
import pytest
import numpy as np
from deepchem.feat.graph_data import GraphData, BatchGraphData

Expand Down Expand Up @@ -38,7 +37,7 @@ def test_graph_data(self):
assert isinstance(dgl_graph, DGLGraph)

def test_invalid_graph_data(self):
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
invalid_node_features_type = list(np.random.random_sample((5, 32)))
edge_index = np.array([
[0, 1, 2, 2, 3, 4],
Expand All @@ -49,7 +48,7 @@ def test_invalid_graph_data(self):
edge_index=edge_index,
)

with pytest.raises(ValueError):
with self.assertRaises(ValueError):
node_features = np.random.random_sample((5, 32))
invalid_edge_index_shape = np.array([
[0, 1, 2, 2, 3, 4],
Expand All @@ -60,7 +59,7 @@ def test_invalid_graph_data(self):
edge_index=invalid_edge_index_shape,
)

with pytest.raises(ValueError):
with self.assertRaises(ValueError):
node_features = np.random.random_sample((5, 5))
invalid_edge_index_shape = np.array([
[0, 1, 2, 2, 3, 4],
Expand All @@ -72,7 +71,7 @@ def test_invalid_graph_data(self):
edge_index=invalid_edge_index_shape,
)

with pytest.raises(TypeError):
with self.assertRaises(TypeError):
node_features = np.random.random_sample((5, 32))
_ = GraphData(node_features=node_features)

Expand Down
42 changes: 42 additions & 0 deletions deepchem/feat/tests/test_mol_graph_conv_featurizer.py
@@ -0,0 +1,42 @@
import unittest

from deepchem.feat import MolGraphConvFeaturizer


class TestMolGraphConvFeaturizer(unittest.TestCase):

def test_default_featurizer(self):
smiles = ["C1=CC=CN=C1", "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"]
featurizer = MolGraphConvFeaturizer()
graph_feat = featurizer.featurize(smiles)
assert len(graph_feat) == 2

# assert "C1=CC=CN=C1"
assert graph_feat[0].num_nodes == 6
assert graph_feat[0].num_node_features == 39
assert graph_feat[0].num_edges == 12
assert graph_feat[0].num_edge_features == 11

# assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
assert graph_feat[1].num_nodes == 22
assert graph_feat[1].num_node_features == 39
assert graph_feat[1].num_edges == 44
assert graph_feat[1].num_edge_features == 11

def test_featurizer_with_self_loop(self):
smiles = ["C1=CC=CN=C1", "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"]
featurizer = MolGraphConvFeaturizer(add_self_edges=True)
graph_feat = featurizer.featurize(smiles)
assert len(graph_feat) == 2

# assert "C1=CC=CN=C1"
assert graph_feat[0].num_nodes == 6
assert graph_feat[0].num_node_features == 39
assert graph_feat[0].num_edges == 12 + 6
assert graph_feat[0].num_edge_features == 11

# assert "O=C(NCc1cc(OC)c(O)cc1)CCCC/C=C/C(C)C"
assert graph_feat[1].num_nodes == 22
assert graph_feat[1].num_node_features == 39
assert graph_feat[1].num_edges == 44 + 22
assert graph_feat[1].num_edge_features == 11
1 change: 1 addition & 0 deletions deepchem/models/__init__.py
Expand Up @@ -37,6 +37,7 @@
try:
from deepchem.models.torch_models import TorchModel
from deepchem.models.torch_models import CGCNN, CGCNNModel
from deepchem.models.torch_models import GAT, GATModel
except ModuleNotFoundError:
pass

Expand Down

0 comments on commit 3d257a0

Please sign in to comment.