/
conformer_featurizer.py
256 lines (226 loc) · 9.62 KB
/
conformer_featurizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from deepchem.feat.graph_data import GraphData
from deepchem.feat import MolecularFeaturizer
# Similar to the SNAP featurizer; both are taken from the Open Graph Benchmark (OGB): github.com/snap-stanford/ogb
# The difference between this and the SNAP features is the lack of masking tokens, possible_implicit_valence_list and possible_bond_dirs,
# and the presence of possible_bond_stereo_list, possible_is_conjugated_list and possible_is_in_ring_list.
# Vocabulary of every categorical atom / bond feature. A feature value is
# encoded as its position in the corresponding list; lists that end in 'misc'
# use that trailing slot as the catch-all for unlisted values.
allowable_features = {
    # atom-level vocabularies
    'possible_atomic_num_list': [*range(1, 119), 'misc'],
    'possible_chirality_list': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
        'misc',
    ],
    'possible_degree_list': [*range(0, 11), 'misc'],
    'possible_formal_charge_list': [*range(-5, 6), 'misc'],
    'possible_numH_list': [*range(0, 9), 'misc'],
    'possible_number_radical_e_list': [*range(0, 5), 'misc'],
    'possible_hybridization_list': [
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'misc',
    ],
    'possible_is_aromatic_list': [False, True],
    'possible_is_in_ring_list': [False, True],
    # bond-level vocabularies
    'possible_bond_type_list': [
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'AROMATIC',
        'misc',
    ],
    'possible_bond_stereo_list': [
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    'possible_is_conjugated_list': [False, True],
}

# Vocabulary size of each atom feature, in the order the features are emitted
# by RDKitConformerFeaturizer.atom_to_feature_vector.
full_atom_feature_dims = [
    len(allowable_features[feature_name]) for feature_name in (
        'possible_atomic_num_list',
        'possible_chirality_list',
        'possible_degree_list',
        'possible_formal_charge_list',
        'possible_numH_list',
        'possible_number_radical_e_list',
        'possible_hybridization_list',
        'possible_is_aromatic_list',
        'possible_is_in_ring_list',
    )
]

# Vocabulary size of each bond feature, in the order the features are emitted
# by RDKitConformerFeaturizer.bond_to_feature_vector.
full_bond_feature_dims = [
    len(allowable_features[feature_name]) for feature_name in (
        'possible_bond_type_list',
        'possible_bond_stereo_list',
        'possible_is_conjugated_list',
    )
]
def safe_index(feature_list, e):
    """
    Look up the position of a value in a feature vocabulary.

    Parameters
    ----------
    feature_list : list
        Vocabulary of allowed values for one feature.
    e : object
        Value to look up in ``feature_list``.

    Returns
    -------
    int
        Index of the first occurrence of ``e`` in ``feature_list``. When ``e``
        is absent, the index of the last element (``len(feature_list) - 1``)
        is returned instead.
    """
    try:
        position = feature_list.index(e)
    except ValueError:
        # Unknown value: fall back to the trailing slot of the vocabulary.
        position = len(feature_list) - 1
    return position
class RDKitConformerFeaturizer(MolecularFeaturizer):
    """
    A featurizer that featurizes an RDKit mol object as a GraphData object with 3D coordinates.

    The 3D coordinates are stored in the node_pos_features attribute of the
    GraphData object. A single conformer is currently generated per molecule
    (multi-conformer support is work in progress), so the coordinates are the
    atom positions of that one conformer. Explicit hydrogens are added before
    embedding, so they appear as nodes of the graph as well.

    The ETKDGv2 algorithm is used to generate 3D coordinates for the molecule.
    The RDKit source for this algorithm can be found in RDkit/Code/GraphMol/DistGeomHelpers/Embedder.cpp
    The documentation can be found here:
    https://rdkit.org/docs/source/rdkit.Chem.rdDistGeom.html#rdkit.Chem.rdDistGeom.ETKDGv2

    This featurization requires RDKit.

    Examples
    --------
    >>> from deepchem.feat.molecule_featurizers.conformer_featurizer import RDKitConformerFeaturizer
    >>> featurizer = RDKitConformerFeaturizer()
    >>> molecule = "CCO"
    >>> conformer = featurizer.featurize(molecule)
    >>> print (type(conformer[0]))
    <class 'deepchem.feat.graph_data.GraphData'>
    """

    # FIXME Add support for multiple conformers (wip)
    # def __init__(self, num_conformers: int = 1, rmsd_cutoff: float = 2):
    #     """
    #     Initialize the RDKitConformerFeaturizer with the given parameters.
    #     Parameters
    #     ----------
    #     num_conformers : int, optional, default=1
    #         The number of conformers to generate for each molecule.
    #     rmsd_cutoff : float, optional, default=2
    #         The root-mean-square deviation (RMSD) cutoff value. Conformers with an RMSD
    #         greater than this value will be discarded.
    #     """
    #     self.num_conformers = num_conformers
    #     self.rmsd_cutoff = rmsd_cutoff

    def atom_to_feature_vector(self, atom):
        """
        Converts an RDKit atom object to a feature list of indices.

        Parameters
        ----------
        atom : Chem.rdchem.Atom
            RDKit atom object.

        Returns
        -------
        List[int]
            List of feature indices for the given atom, ordered as: atomic
            number, chirality tag, total degree, formal charge, total number
            of hydrogens, number of radical electrons, hybridization,
            aromaticity and ring membership.
        """
        atom_feature = [
            # Features with a 'misc' slot go through safe_index so that
            # unlisted values map to that trailing slot.
            safe_index(allowable_features['possible_atomic_num_list'],
                       atom.GetAtomicNum()),
            safe_index(allowable_features['possible_chirality_list'],
                       str(atom.GetChiralTag())),
            safe_index(allowable_features['possible_degree_list'],
                       atom.GetTotalDegree()),
            safe_index(allowable_features['possible_formal_charge_list'],
                       atom.GetFormalCharge()),
            safe_index(allowable_features['possible_numH_list'],
                       atom.GetTotalNumHs()),
            safe_index(allowable_features['possible_number_radical_e_list'],
                       atom.GetNumRadicalElectrons()),
            safe_index(allowable_features['possible_hybridization_list'],
                       str(atom.GetHybridization())),
            # Boolean features: the vocabulary is exactly [False, True].
            allowable_features['possible_is_aromatic_list'].index(
                atom.GetIsAromatic()),
            allowable_features['possible_is_in_ring_list'].index(
                atom.IsInRing()),
        ]
        return atom_feature

    def bond_to_feature_vector(self, bond):
        """
        Converts an RDKit bond object to a feature list of indices.

        Parameters
        ----------
        bond : Chem.rdchem.Bond
            RDKit bond object.

        Returns
        -------
        List[int]
            List of feature indices for the given bond, ordered as: bond type,
            bond stereo and conjugation.

        Raises
        ------
        ValueError
            If the bond stereo tag is not in
            ``allowable_features['possible_bond_stereo_list']``; that
            vocabulary has no 'misc' slot to fall back to.
        """
        bond_feature = [
            safe_index(allowable_features['possible_bond_type_list'],
                       str(bond.GetBondType())),
            # Plain list.index: an unlisted stereo tag raises ValueError.
            allowable_features['possible_bond_stereo_list'].index(
                str(bond.GetStereo())),
            allowable_features['possible_is_conjugated_list'].index(
                bond.GetIsConjugated()),
        ]
        return bond_feature

    def _featurize(self, datapoint):
        """
        Featurizes a molecule into a graph representation with 3D coordinates.

        Parameters
        ----------
        datapoint : RdkitMol
            RDKit molecule object

        Returns
        -------
        graph: GraphData
            GraphData object of the molecule (with explicit hydrogens) whose
            node_pos_features hold the 3D coordinates of the generated
            conformer.

        Raises
        ------
        ValueError
            If RDKit fails to embed a 3D conformer for the molecule.
        """
        # Derived from https://github.com/HannesStark/3DInfomax/blob/5cd32629c690e119bcae8726acedefdb0aa037fc/datasets/qm9_dataset_rdkit_conformers.py#L377
        # SMILES leaves hydrogens implicit; add them as explicit atoms so that
        # they receive coordinates and become nodes of the graph.
        mol = Chem.AddHs(datapoint)
        ps = AllChem.ETKDGv2()
        ps.useRandomCoords = True
        conf_id = AllChem.EmbedMolecule(mol, ps)
        if conf_id < 0:
            # EmbedMolecule signals failure with -1 and leaves the molecule
            # without a conformer; fail here with an explicit message instead
            # of letting the later conformer lookup fail obscurely.
            raise ValueError(
                'RDKit failed to embed a 3D conformer for the molecule '
                '(EmbedMolecule returned -1).')

        # FIXME Add support for multiple conformers (wip)
        # AllChem.EmbedMultipleConfs(mol, self.num_conformers)
        # AllChem.MMFFOptimizeMolecule(mol)
        # rmsd_list = []
        # rdMolAlign.AlignMolConformers(mol, RMSlist=rmsd_list)
        # # insert 0 RMSD for first conformer
        # rmsd_list.insert(0, 0)
        # conformers = [
        #     mol.GetConformer(i)
        #     for i in range(self.num_conformers)
        #     if rmsd_list[i] < self.rmsd_cutoff
        # ]
        # # if conformer list is less than num_conformers, pad by repeating conformers
        # conf_idx = 0
        # while len(conformers) < self.num_conformers:
        #     conformers.append(conformers[conf_idx])
        #     conf_idx += 1
        # coordinates = [conf.GetPositions() for conf in conformers]

        # Relax the embedded geometry with the MMFF force field. The return
        # status is not inspected; the current coordinates are used either way.
        AllChem.MMFFOptimizeMolecule(mol, confId=conf_id)
        conf = mol.GetConformer(conf_id)
        coordinates = conf.GetPositions()

        atom_features_list = [
            self.atom_to_feature_vector(atom) for atom in mol.GetAtoms()
        ]

        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = self.bond_to_feature_vector(bond)
            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # Graph connectivity in COO format with shape [2, num_edges]
        if edges_list:
            edge_index = np.array(edges_list).T
            edge_features = np.array(edge_features_list)
        else:
            # Molecule without bonds (e.g. a monoatomic ion): np.array([])
            # would be a 1-D float array, so build correctly shaped empty
            # integer arrays instead.
            edge_index = np.empty((2, 0), dtype=int)
            edge_features = np.empty((0, len(full_bond_feature_dims)),
                                     dtype=int)

        # FIXME Add support for multiple conformers (wip)
        # graph_list = []
        # for i in range(self.num_conformers):
        #     graph_list.append(
        #         GraphData(node_pos_features=np.array(coordinates[i]),
        #                   node_features=np.array(atom_features_list),
        #                   edge_features=np.array(edge_features_list),
        #                   edge_index=np.array(edges_list).T))
        # return graph_list
        return GraphData(node_pos_features=coordinates,
                         node_features=np.array(atom_features_list),
                         edge_features=edge_features,
                         edge_index=edge_index)