-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
391 lines (338 loc) · 11.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
import random
import numpy as np
import warnings
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import logging
from joblib import Parallel, delayed
from rdkit import Chem,DataStructs
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit.Chem import SaltRemover
from rdkit.Chem import rdmolops
# Import-time side effect: turn off RDKit's 'rdApp.error' log channel so that
# failures such as unparsable SMILES do not flood stderr.
rdBase.DisableLog('rdApp.error')
################ For data process ################
def _initialiseNeutralisationReactions():
    """Build the (query, replacement) pairs used to neutralise charged groups.

    Each pair holds a SMARTS query mol that matches a charged substructure and
    a replacement mol parsed from SMILES with sanitization turned off.

    Returns
    -------
    list of tuple
        ``[(query_mol, replacement_mol), ...]`` in the order listed below.
    """
    smarts_to_smiles = [
        ('[n+;H]', 'n'),                      # imidazoles
        ('[N+;!H0]', 'N'),                    # amines
        ('[$([O-]);!$([O-][#7])]', 'O'),      # carboxylic acids and alcohols
        ('[S-;X1]', 'S'),                     # thiols
        ('[$([N-;X2]S(=O)=O)]', 'N'),         # sulfonamides
        ('[$([N-;X2][C,N]=C)]', 'N'),         # enamines
        ('[n-]', '[nH]'),                     # tetrazoles
        ('[$([S-]=O)]', 'S'),                 # sulfoxides
        ('[$([N-]C=O)]', 'N'),                # amides
    ]
    reactions = []
    for query_smarts, product_smiles in smarts_to_smiles:
        query = Chem.MolFromSmarts(query_smarts)
        # Second positional argument False = do not sanitize the fragment.
        product = Chem.MolFromSmiles(product_smiles, False)
        reactions.append((query, product))
    return reactions


# Built once at import time; _neutraliseCharges falls back to this list when
# no explicit reactions are passed.
_reactions = _initialiseNeutralisationReactions()
def _neutraliseCharges(mol, reactions=None):
    """Neutralise charged groups in *mol* by repeated substructure replacement.

    Parameters
    ----------
    mol : rdkit Mol
        Molecule to neutralise.
    reactions : iterable of (query, replacement) pairs, optional
        Replacement rules; defaults to the module-level ``_reactions``.

    Returns
    -------
    tuple
        ``(new_mol, changed)``, where ``changed`` is True if any rule fired.
    """
    if reactions is None:
        reactions = _reactions
    changed = False
    for pattern, replacement in reactions:
        # Keep applying the same rule until no match remains.
        while mol.HasSubstructMatch(pattern):
            changed = True
            mol = AllChem.ReplaceSubstructs(mol, pattern, replacement)[0]
    return mol, changed
def valid_size(mol, min_heavy_atoms, max_heavy_atoms, remove_long_side_chains):
    """Check a molecule against the heavy-atom-count and side-chain filters.

    Parameters
    ----------
    mol : rdkit Mol or None
        Molecule to check. A falsy value (e.g. a failed parse) never passes.
    min_heavy_atoms, max_heavy_atoms : int
        Exclusive bounds on ``mol.GetNumHeavyAtoms()``.
    remove_long_side_chains : bool
        If True, also reject molecules that contain a chain of four aliphatic
        non-ring carbons (SMARTS ``[CR0]-[CR0]-[CR0]-[CR0]``).

    Returns
    -------
    bool
        True if the molecule passes every enabled filter, otherwise False.
    """
    if not mol:
        return False
    # Both bounds are exclusive: exactly min_heavy_atoms or max_heavy_atoms
    # heavy atoms is rejected.
    if not min_heavy_atoms < mol.GetNumHeavyAtoms() < max_heavy_atoms:
        return False
    if remove_long_side_chains:
        side_chain = Chem.MolFromSmarts('[CR0]-[CR0]-[CR0]-[CR0]')
        if mol.HasSubstructMatch(side_chain):
            return False
    return True
def standardize_smiles(smiles, min_heavy_atoms=10, max_heavy_atoms=50,
                       remove_long_side_chains=False, neutralise_charges=True):
    """Standardize one SMILES string, or return None if it is rejected.

    Pipeline: parse -> (optionally) neutralise charges -> Cleanup ->
    SanitizeMol -> RemoveHs -> size/side-chain filter -> non-isomeric SMILES.

    Parameters
    ----------
    smiles : str
        Input SMILES.
    min_heavy_atoms, max_heavy_atoms : int
        Exclusive heavy-atom bounds passed to ``valid_size``.
    remove_long_side_chains : bool
        Passed to ``valid_size``.
    neutralise_charges : bool
        Apply ``_neutraliseCharges`` before sanitization.

    Returns
    -------
    str or None
        Non-isomeric RDKit SMILES, or None if the input fails to parse, fails
        sanitization, or does not pass the filters.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return None
    try:
        if neutralise_charges:
            mol, _ = _neutraliseCharges(mol)
        rdmolops.Cleanup(mol)
        rdmolops.SanitizeMol(mol)
        mol = rdmolops.RemoveHs(mol, implicitOnly=False, updateExplicitCount=False, sanitize=True)
    except (ValueError, RuntimeError):
        # RDKit sanitization errors (e.g. after charge replacement) derive from
        # ValueError; reject the molecule instead of crashing a whole batch.
        return None
    if mol and valid_size(mol, min_heavy_atoms, max_heavy_atoms, remove_long_side_chains):
        return Chem.MolToSmiles(mol, isomericSmiles=False)
    return None
def standardize_smiles_list(smiles_list):
    """Standardize SMILES in parallel and return the unique valid results.

    Parameters
    ----------
    smiles_list : iterable of str
        Raw SMILES strings.

    Returns
    -------
    list of str
        Unique standardized SMILES in order of first occurrence. Inputs that
        ``standardize_smiles`` rejects (None) are dropped. The ordering is
        deterministic across runs; iterating a ``set`` of strings is not,
        because of hash randomization.
    """
    standardized = Parallel(n_jobs=-1, verbose=0)(
        delayed(standardize_smiles)(line) for line in smiles_list)
    # dict.fromkeys de-duplicates while preserving first-occurrence order.
    unique = [smi for smi in dict.fromkeys(standardized) if smi is not None]
    logging.debug("%d unique SMILES retrieved", len(unique))
    return unique
def canonical_smiles(smiles):
    """
    Convert a SMILES string to RDKit canonical (non-isomeric) SMILES.

    Parameters
    ----------
    smiles: str
        SMILES string to canonicalize.

    Returns
    -------
    new_smiles: str or None
        Canonical SMILES, or None if *smiles* cannot be parsed.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
    except (TypeError, ValueError, RuntimeError):
        # Non-string input: Boost.Python raises ArgumentError, a TypeError subclass.
        return None
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        return None
    return Chem.MolToSmiles(mol, isomericSmiles=False)
def tokenize(smiles, tokens=None):
    """
    Build the character vocabulary for a SMILES corpus.

    Parameters
    ----------
    smiles: list
        SMILES strings whose characters form the alphabet.
    tokens: list or str (default None)
        Predefined alphabet. When given, it is used as-is and *smiles* is ignored.

    Returns
    -------
    tokens: str or list
        The alphabet. When computed here, it is a single string of the distinct
        characters sorted by code point. Otherwise it is the *tokens* argument
        unchanged.
    token2idx: dict
        Maps each token to its position in *tokens*.
    num_tokens: int
        Length of *tokens*.
    """
    if tokens is None:
        tokens = ''.join(sorted(set(''.join(smiles))))
    token2idx = {tok: position for position, tok in enumerate(tokens)}
    return tokens, token2idx, len(tokens)
def randomSmiles(mol):
    """Return a non-isomeric SMILES of *mol* written with a random atom ranking.

    Sets the ``_canonicalRankingNumbers`` flag on *mol* and gives every atom a
    shuffled ``_canonicalRankingNumber``. These properties are left on *mol*
    afterwards, so the input is mutated.

    NOTE(review): relies on RDKit honouring these private properties in
    MolToSmiles. Confirm for the installed RDKit; newer releases also offer
    ``MolToSmiles(..., doRandom=True)``.
    """
    mol.SetProp("_canonicalRankingNumbers", "True")
    ranks = list(range(mol.GetNumAtoms()))
    random.shuffle(ranks)
    for atom in mol.GetAtoms():
        atom.SetProp("_canonicalRankingNumber", str(ranks[atom.GetIdx()]))
    return Chem.MolToSmiles(mol, isomericSmiles=False)
def smile_augmentation(smile, augmentation, max_len, max_attempts=10000):
    """Generate up to *augmentation* distinct randomized SMILES for one molecule.

    Parameters
    ----------
    smile : str
        Source SMILES.
    augmentation : int
        Number of distinct variants wanted.
    max_len : int
        Variants longer than this many characters are discarded.
    max_attempts : int, optional
        Maximum number of random SMILES to draw (default 10000, the limit
        that was previously hard-coded).

    Returns
    -------
    list of str
        Distinct variants in generation order. The list may be shorter than
        *augmentation* if the attempt budget runs out. It is empty when
        *augmentation* <= 0 or *smile* cannot be parsed.
    """
    if augmentation <= 0:
        return []
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # randomSmiles would fail on None; an unparsable input has no variants.
        return []
    variants = {}  # insertion-ordered "set" so the output is reproducible
    for _ in range(max_attempts):
        candidate = randomSmiles(mol)
        if len(candidate) <= max_len:
            variants[candidate] = None
            if len(variants) == augmentation:
                break
    return list(variants)
def save_smiles_to_file(filename, smiles, unique=True):
    """
    Write SMILES strings to *filename*, one per line (overwriting the file).

    Args:
        filename (str): path to the output file.
        smiles (iterable of str): SMILES strings to write.
        unique (bool): if True, write each distinct SMILES once, in order of
            first occurrence.
    Returns:
        bool: True once the file has been written and closed. I/O errors
        propagate as exceptions.
    """
    if unique:
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        smiles = list(dict.fromkeys(smiles))
    else:
        smiles = list(smiles)
    # `with` guarantees the handle is closed even if a write raises.
    with open(filename, 'w') as f:
        f.writelines(mol + '\n' for mol in smiles)
    return f.closed
def read_smiles_from_file(filename, unique=True, add_start_end_tokens=False):
    """
    Read SMILES from a file containing one SMILES string per line.

    Args:
        filename (str): path to the file.
        unique (bool): if True, return each distinct entry once, in order of
            first occurrence.
        add_start_end_tokens (bool): if True, wrap each entry as '<' + smiles + '>'.
    Returns:
        smiles (list): SMILES strings read from the file (only unique copies
            if ``unique=True``).
        success (bool): True once the file has been read and closed.
    """
    with open(filename, 'r') as f:
        # rstrip('\n') instead of line[:-1]: the last line may have no
        # trailing newline, and slicing would cut off its final character.
        molecules = [line.rstrip('\n') for line in f]
    if add_start_end_tokens:
        molecules = ['<' + mol + '>' for mol in molecules]
    if unique:
        molecules = list(dict.fromkeys(molecules))
    return molecules, f.closed
################ For experiment ################
def fp2arr(fp):
    """Convert an RDKit fingerprint into a NumPy array with one entry per bit."""
    bits = np.zeros((1,))
    # ConvertToNumpyArray resizes `bits` to the fingerprint length, so the
    # initial length of 1 is only a placeholder.
    DataStructs.ConvertToNumpyArray(fp, bits)
    return bits
def fp_array_from_smiles_list(smiles, radius=2, nbits=2048):
    """Return Morgan fingerprints, as NumPy arrays, for the parsable SMILES.

    Parameters
    ----------
    smiles : iterable of str
        Input SMILES.
    radius : int
        Morgan radius (default 2).
    nbits : int
        Bit-vector length (default 2048).

    Returns
    -------
    list of numpy.ndarray
        One array per SMILES that RDKit could parse. Unparsable or non-string
        entries are skipped, so the result may be shorter than the input.
    """
    fps = []
    for smile in smiles:
        try:
            mol = Chem.MolFromSmiles(smile)
        except (TypeError, ValueError, RuntimeError):
            # e.g. a non-string entry such as NaN from a DataFrame column.
            continue
        if mol is None:
            # MolFromSmiles reports a parse failure by returning None, not by
            # raising, so this check is what actually skips invalid SMILES.
            continue
        fp = AllChem.GetMorganFingerprintAsBitVect(mol=mol, radius=radius, nBits=nbits)
        fps.append(fp2arr(fp))
    return fps
def fingerprint(smiles, radius=2, nbits=2048):
    """
    Compute the Morgan bit-vector fingerprint of a SMILES string.

    Parameters:
        smiles: SMILES string
        radius: Morgan radius (default 2)
        nbits: bit-vector length (default 2048)
    Returns:
        The RDKit bit-vector fingerprint, or None if *smiles* cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return AllChem.GetMorganFingerprintAsBitVect(mol=mol, radius=radius, nBits=nbits)
    return None
def scaffold(mol):
    """
    Return the Murcko scaffold of *mol* as a SMILES string.

    Returns None when *mol* is None, when scaffold extraction raises
    ValueError/RuntimeError, or when the scaffold SMILES is empty (e.g. an
    acyclic molecule).
    """
    if mol is None:
        return None
    # Import explicitly: `import rdkit.Chem` does not load the Scaffolds
    # subpackage, so `Chem.Scaffolds` can raise AttributeError.
    from rdkit.Chem.Scaffolds import MurckoScaffold
    try:
        core = MurckoScaffold.GetScaffoldForMol(mol)
    except (ValueError, RuntimeError):
        return None
    scaffold_smiles = Chem.MolToSmiles(core)
    if scaffold_smiles == '':
        return None
    return scaffold_smiles
def scaffolds(smiles_list):
    """Count how often each scaffold (as returned by ``scaffold``) occurs.

    Unparsable SMILES and molecules for which ``scaffold`` returns None are
    ignored.

    Returns
    -------
    collections.Counter
        Maps scaffold SMILES to its number of occurrences.
    """
    counts = Counter()
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        core = scaffold(mol)
        if core is not None:
            counts[core] += 1
    return counts
def fragment(mol):
    """
    Break *mol* at its BRICS bonds and return the fragment SMILES list
    (the fragmented molecule's SMILES split on '.').
    """
    broken = Chem.AllChem.FragmentOnBRICSBonds(mol)
    return Chem.MolToSmiles(broken).split(".")
def fragments(smiles_list):
    """
    Count BRICS fragments (via ``fragment``) over a list of SMILES.
    Unparsable SMILES are skipped. Returns a collections.Counter.
    """
    counts = Counter()
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            counts.update(fragment(mol))
    return counts
def get_structures(smiles_list):
    """Compute fingerprint, BRICS fragments and scaffold for each SMILES.

    Returns
    -------
    tuple of three lists
        ``(fps, frags, scaffs)``, each aligned index-for-index with
        *smiles_list*. For a SMILES that RDKit cannot parse, all three entries
        are None. This matches ``fingerprint``, which already returns None for
        such input, and avoids passing None into ``fragment``/``scaffold``.
    """
    fps = []
    frags = []
    scaffs = []
    for smile in smiles_list:
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            fps.append(None)
            frags.append(None)
            scaffs.append(None)
            continue
        fps.append(fingerprint(smile))
        frags.append(fragment(mol))
        scaffs.append(scaffold(mol))
    return fps, frags, scaffs
def get_TanimotoSimilarity(sources_fps, target_fps, option="max"):
    """Summarise each source fingerprint's similarity to a set of targets.

    Similarity is ``DataStructs.FingerprintSimilarity`` (Tanimoto by default).

    Parameters
    ----------
    sources_fps : iterable of RDKit fingerprints
    target_fps : sized iterable of RDKit fingerprints
    option : {"max", "mean"}
        Statistic to return per source fingerprint.

    Returns
    -------
    list of float or None
        One value per source fingerprint: the maximum similarity (floored at
        0) or the mean similarity. When *target_fps* is empty, every value is
        0 instead of raising ZeroDivisionError. Returns None for any other
        *option*.
    """
    if option not in ("max", "mean"):
        return None
    n_targets = len(target_fps)
    results = []
    for s_fp in sources_fps:
        sims = [DataStructs.FingerprintSimilarity(s_fp, t_fp) for t_fp in target_fps]
        if option == "max":
            # default=0 mirrors the original running maximum that started at 0.
            results.append(max(sims, default=0))
        else:
            results.append(sum(sims) / n_targets if n_targets else 0.0)
    return results
################ For train ################
def valid_score(smiles):
    """
    Score a SMILES string by whether RDKit can parse it.

    Parameters
    ----------
    smiles: str
        SMILES string to check.

    Returns
    -------
    score: int
        1 if ``Chem.MolFromSmiles`` returns a molecule, 0 if it returns None.
    """
    return 0 if Chem.MolFromSmiles(smiles) is None else 1
def get_reward(sample, dis1, dis2, gen_loader):
    """Score one generated sample for adversarial training.

    *sample* is treated as a string whose first character is a start token and
    which may carry a '>' end token.

    Returns
    -------
    tuple
        ``(0, 0, 0)`` if the sample has length 2 (e.g. '<>') or contains another
        '<' after its first character. Otherwise, with the tokens stripped:
        ``(dis1.classify(...), dis2.classify(...), valid_score(...))``.
    """
    body = sample[1:]
    if len(sample) == 2 or '<' in body:
        return 0, 0, 0
    # NOTE(review): when '>' appears anywhere, the LAST character is dropped,
    # which assumes '>' only ever occurs as the final token.
    x_temp = sample[1:-1] if '>' in sample else body
    score1 = dis1.classify(gen_loader.char_tensor(x_temp))
    score2 = dis2.classify(gen_loader.char_tensor(x_temp))
    return score1, score2, valid_score(x_temp)
class GANLoss(nn.Module):
    """Reward-weighted loss for adversarial (policy-gradient style) generator training.

    Computes ``-sum(prob * reward)`` over all elements.
    """

    def __init__(self):
        super().__init__()

    def forward(self, prob, reward):
        """
        Args:
            prob: torch tensor; presumably per-token log-probabilities of the
                sampled sequence (confirm at the call site).
            reward: torch tensor broadcastable against *prob*.
        Returns:
            Scalar tensor equal to the negated sum of ``prob * reward``.
        """
        weighted = prob * reward
        return -torch.sum(weighted)
class NLLLoss(nn.Module):
    """Negative-sum loss for the generator: returns ``-sum(prob)``."""

    def __init__(self):
        super().__init__()

    def forward(self, prob):
        """
        Args:
            prob: torch tensor; presumably log-probabilities of the target
                tokens (confirm at the call site).
        Returns:
            Scalar tensor equal to the negated sum of all elements of *prob*.
        """
        total = torch.sum(prob)
        return -total