Skip to content

Commit

Permalink
Merge pull request #673 from shihchengli/fix_keep_atom_maps
Browse files Browse the repository at this point in the history
Fix atom/bond property prediction with atom-mapped SMILES and target classification
  • Loading branch information
oscarwumit committed Feb 29, 2024
2 parents 36f92ac + 14ff7d4 commit 1785627
Show file tree
Hide file tree
Showing 4 changed files with 541 additions and 8 deletions.
16 changes: 15 additions & 1 deletion chemprop/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,21 @@ def __init__(self,
@property
def mol(self) -> List[Union[Chem.Mol, Tuple[Chem.Mol, Chem.Mol]]]:
"""Gets the corresponding list of RDKit molecules for the corresponding SMILES list."""
mol = make_mols(self.smiles, self.is_reaction_list, self.is_explicit_h_list, self.is_adding_hs_list, self.is_keeping_atom_map_list)
if self.atom_targets is not None or self.bond_targets is not None:
# When the original atom mapping is used, the explicit hydrogens specified in the input SMILES should be used
# However, the explicit Hs can only be added for reactions with `--explicit_h` flag
# To fix this, the attribute of `keep_h_list` in make_mols() is set to match the `keep_atom_map_list`
mol = make_mols(smiles=self.smiles,
reaction_list=self.is_reaction_list,
keep_h_list=self.is_keeping_atom_map_list,
add_h_list=self.is_adding_hs_list,
keep_atom_map_list=self.is_keeping_atom_map_list)
else:
mol = make_mols(smiles=self.smiles,
reaction_list=self.is_reaction_list,
keep_h_list=self.is_explicit_h_list,
add_h_list=self.is_adding_hs_list,
keep_atom_map_list=self.is_keeping_atom_map_list)
if cache_mol():
for s, m in zip(self.smiles, mol):
SMILES_TO_MOL[s] = m
Expand Down
22 changes: 17 additions & 5 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,17 @@ def get_mixed_task_names(path: str,
for row in reader:
atom_target_names, bond_target_names, molecule_target_names = [], [], []
smiles = [row[c] for c in smiles_columns]
mol = make_mol(smiles[0], keep_h, add_h, keep_atom_map)
for s in smiles:
if keep_atom_map:
# When the original atom mapping is used, the explicit hydrogens specified in the input SMILES should be used
# However, the explicit Hs can only be added for reactions with `--explicit_h` flag
# To fix this, `keep_h` is set to True when `keep_atom_map` is also True
mol = make_mol(s, keep_h=True, add_h=add_h, keep_atom_map=True)
else:
mol = make_mol(s, keep_h=keep_h, add_h=add_h, keep_atom_map=False)
if len(mol.GetAtoms()) != len(mol.GetBonds()):
break

for column in target_names:
value = row[column]
value = value.replace('None', 'null')
Expand All @@ -160,16 +170,18 @@ def get_mixed_task_names(path: str,
if len(target.shape) == 0:
is_molecule_target = True
elif len(target.shape) == 1:
if len(mol.GetAtoms()) == len(mol.GetBonds()):
break
elif len(target) == len(mol.GetAtoms()): # Atom targets saved as 1D list
if len(target) == len(mol.GetAtoms()): # Atom targets saved as 1D list
is_atom_target = True
elif len(target) == len(mol.GetBonds()): # Bond targets saved as 1D list
is_bond_target = True
else:
raise RuntimeError(f'Unrecognized targets of column {column} in {path}. '
'Expected targets should be either atomic or bond targets. '
'Please ensure the content is correct.')
elif len(target.shape) == 2: # Bond targets saved as 2D list
is_bond_target = True
else:
raise ValueError('Unrecognized targets of column {column} in {path}.')
raise ValueError(f'Unrecognized targets of column {column} in {path}.')

if is_atom_target:
atom_target_names.append(column)
Expand Down

0 comments on commit 1785627

Please sign in to comment.