Merge pull request #212 from chemprop/reaction_correction
added reaction balancing and bugfix
cjmcgill committed Sep 27, 2021
2 parents 91c1b1b + 5a5bc8a commit d0adea6
Showing 4 changed files with 59 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -277,7 +277,7 @@ In absorption spectra, sometimes the phase of collection will create regions in

 ### Reaction

-As an alternative to molecule SMILES, Chemprop can also process atom-mapped reaction SMILES (see [Daylight manual](https://www.daylight.com/meetings/summerschool01/course/basics/smirks.html) for details on reaction SMILES), which consist of three parts denoting reactants, agents and products, separated by ">". Use the option `--reaction` to enable the input of reactions, which transforms the reactants and products of each reaction to the corresponding condensed graph of reaction and changes the initial atom and bond features to hold information from both the reactant and product (option `--reaction_mode reac_prod`), or from the reactant and the difference upon reaction (option `--reaction_mode reac_diff`, default) or from the product and the difference upon reaction (option `--reaction_mode prod_diff`). In reaction mode, Chemprop thus concatenates information to each atomic and bond feature vector, for example, with option `--reaction_mode reac_prod`, each atomic feature vector holds information on the state of the atom in the reactant (similar to default Chemprop), and concatenates information on the state of the atom in the product, so that the size of the D-MPNN increases slightly. Agents are discarded. Functions incompatible with a reaction as input (scaffold splitting and feature generation) are carried out on the reactants only. If the atom-mapped reaction SMILES contain mapped hydrogens, enable explicit hydrogens via `--explicit_h`. Example of an atom-mapped reaction SMILES denoting the reaction of methanol to formaldehyde without hydrogens: `[CH3:1][OH:2]>>[CH2:1]=[O:2]` and with hydrogens: `[C:1]([H:3])([H:4])([H:5])[O:2][H:6]>>[C:1]([H:3])([H:4])=[O:2].[H:5][H:6]`. The reactions do not need to be balanced and can thus contain unmapped parts, for example leaving groups, if necessary.
+As an alternative to molecule SMILES, Chemprop can also process atom-mapped reaction SMILES (see [Daylight manual](https://www.daylight.com/meetings/summerschool01/course/basics/smirks.html) for details on reaction SMILES), which consist of three parts denoting reactants, agents and products, separated by ">". Use the option `--reaction` to enable the input of reactions, which transforms the reactants and products of each reaction to the corresponding condensed graph of reaction and changes the initial atom and bond features to hold information from both the reactant and product (option `--reaction_mode reac_prod`), or from the reactant and the difference upon reaction (option `--reaction_mode reac_diff`, default) or from the product and the difference upon reaction (option `--reaction_mode prod_diff`). In reaction mode, Chemprop thus concatenates information to each atomic and bond feature vector, for example, with option `--reaction_mode reac_prod`, each atomic feature vector holds information on the state of the atom in the reactant (similar to default Chemprop), and concatenates information on the state of the atom in the product, so that the size of the D-MPNN increases slightly. Agents are discarded. Functions incompatible with a reaction as input (scaffold splitting and feature generation) are carried out on the reactants only. If the atom-mapped reaction SMILES contain mapped hydrogens, enable explicit hydrogens via `--explicit_h`. Example of an atom-mapped reaction SMILES denoting the reaction of methanol to formaldehyde without hydrogens: `[CH3:1][OH:2]>>[CH2:1]=[O:2]` and with hydrogens: `[C:1]([H:3])([H:4])([H:5])[O:2][H:6]>>[C:1]([H:3])([H:4])=[O:2].[H:5][H:6]`. The reactions do not need to be balanced and can thus contain unmapped parts, for example leaving groups, if necessary. With reaction modes `reac_prod`, `reac_diff` and `prod_diff`, the atom and bond features of unbalanced atoms are set to zero on the side of the reaction they are not specified. Alternatively, features can be set to the same values on the reactant and product side via the modes `reac_prod_balance`, `reac_diff_balance` and `prod_diff_balance`, which corresponds to a rough balancing of the reaction.
 For further details and benchmarking, as well as a citable reference, please see [DOI 10.33774/chemrxiv-2021-frfhz](https://doi.org/10.33774/chemrxiv-2021-frfhz).
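
The three ways of concatenating reactant and product information described in this README paragraph can be illustrated with a minimal sketch. It is not Chemprop code: `combine_features` and the two-element toy feature vectors are hypothetical, and the sketch ignores that the atom featurization in this commit drops the duplicated atomic-number block from the second vector.

```python
from typing import List


def combine_features(reac: List[float], prod: List[float], mode: str) -> List[float]:
    """Concatenate reactant/product feature vectors according to the reaction mode.

    Hypothetical helper for illustration only; the `_balance` suffix changes how
    unbalanced atoms are filled in, not how the two vectors are concatenated.
    """
    suffix = "_balance"
    base_mode = mode[:-len(suffix)] if mode.endswith(suffix) else mode
    diff = [p - r for r, p in zip(reac, prod)]  # difference upon reaction: product minus reactant
    if base_mode == "reac_prod":
        return reac + prod
    if base_mode == "reac_diff":
        return reac + diff
    if base_mode == "prod_diff":
        return prod + diff
    raise ValueError(f"unknown reaction mode: {mode}")


# Toy features (assumed): [formal charge, number of attached hydrogens]
reac = [0.0, 3.0]
prod = [0.0, 2.0]
for mode in ("reac_prod", "reac_diff", "prod_diff"):
    print(mode, combine_features(reac, prod, mode))
```

In every mode the combined vector is twice as long as a single-molecule feature vector, which is why the README notes that the D-MPNN grows slightly.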

### Pretraining
5 changes: 4 additions & 1 deletion chemprop/args.py
@@ -355,12 +355,15 @@ class TrainArgs(CommonArgs):
     """
     Whether to adjust MPNN layer to take reactions as input instead of molecules.
     """
-    reaction_mode: Literal['reac_prod', 'reac_diff', 'prod_diff'] = 'reac_diff'
+    reaction_mode: Literal['reac_prod', 'reac_diff', 'prod_diff', 'reac_prod_balance', 'reac_diff_balance', 'prod_diff_balance'] = 'reac_diff'
     """
     Choices for construction of atom and bond features for reactions
     :code:`reac_prod`: concatenates the reactants feature with the products feature.
     :code:`reac_diff`: concatenates the reactants feature with the difference in features between reactants and products.
     :code:`prod_diff`: concatenates the products feature with the difference in features between reactants and products.
+    :code:`reac_prod_balance`: concatenates the reactants feature with the products feature, balances imbalanced reactions.
+    :code:`reac_diff_balance`: concatenates the reactants feature with the difference in features between reactants and products, balances imbalanced reactions.
+    :code:`prod_diff_balance`: concatenates the products feature with the difference in features between reactants and products, balances imbalanced reactions.
     """
     explicit_h: bool = False
     """
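
Each of the six `reaction_mode` values in the `args.py` change is one of the three original modes with an optional `_balance` suffix. A minimal sketch of that decomposition follows, assuming Python 3.8+ for `typing.Literal`; the names `ReactionMode` and `split_reaction_mode` are hypothetical and do not appear in Chemprop.

```python
from typing import Literal, Tuple, get_args

# Same six choices as the updated TrainArgs.reaction_mode annotation.
ReactionMode = Literal['reac_prod', 'reac_diff', 'prod_diff',
                       'reac_prod_balance', 'reac_diff_balance', 'prod_diff_balance']


def split_reaction_mode(mode: str) -> Tuple[str, bool]:
    """Return (base mode, whether unbalanced atoms are balanced). Hypothetical helper."""
    if mode not in get_args(ReactionMode):
        raise ValueError(f"unsupported reaction mode: {mode!r}")
    suffix = '_balance'
    if mode.endswith(suffix):
        return mode[:-len(suffix)], True
    return mode, False


print(split_reaction_mode('reac_diff'))
print(split_reaction_mode('prod_diff_balance'))
```

The commit itself does not split the string; it tests membership in explicit lists of mode names at each decision point.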
65 changes: 51 additions & 14 deletions chemprop/features/featurization.py
@@ -180,6 +180,21 @@ def atom_features(atom: Chem.rdchem.Atom, functional_groups: List[int] = None) -
     return features


+def atom_features_zeros(atom: Chem.rdchem.Atom) -> List[Union[bool, int, float]]:
+    """
+    Builds a feature vector for an atom containing only the atom number information.
+    :param atom: An RDKit atom.
+    :return: A list containing the atom features.
+    """
+    if atom is None:
+        features = [0] * PARAMS.ATOM_FDIM
+    else:
+        features = onek_encoding_unk(atom.GetAtomicNum() - 1, PARAMS.ATOM_FEATURES['atomic_num']) + \
+            [0] * (PARAMS.ATOM_FDIM - PARAMS.MAX_ATOMIC_NUM - 1) #set other features to zero
+    return features
+
+
 def bond_features(bond: Chem.rdchem.Bond) -> List[Union[bool, int, float]]:
     """
     Builds a feature vector for a bond.
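
The layout produced by `atom_features_zeros` can be sketched without RDKit: a one-hot atomic-number block followed by zeros. The block below is an illustration under stated assumptions. `MAX_ATOMIC_NUM` and `ATOM_FDIM` are toy values, the function takes a plain atomic number instead of an RDKit atom, and `onek_encoding_unk` is a stand-in written for this sketch; it assumes the real helper returns one slot per choice plus a trailing "unknown" slot, which is what the padding arithmetic in the diff implies.

```python
from typing import List, Optional

MAX_ATOMIC_NUM = 5  # toy value (assumption), not Chemprop's setting
ATOM_FDIM = 10      # toy total length of an atom feature vector (assumption)


def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
    """One-hot encode `value`; the last slot marks values that are not in `choices`."""
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else -1
    encoding[index] = 1
    return encoding


def atom_features_zeros_sketch(atomic_num: Optional[int]) -> List[int]:
    """Keep only the atomic-number block; every other feature is zero."""
    if atomic_num is None:
        return [0] * ATOM_FDIM
    header = onek_encoding_unk(atomic_num - 1, list(range(MAX_ATOMIC_NUM)))
    return header + [0] * (ATOM_FDIM - MAX_ATOMIC_NUM - 1)


print(atom_features_zeros_sketch(3))     # atomic number inside the toy range
print(atom_features_zeros_sketch(9))     # outside the toy range, lands in the "unknown" slot
print(atom_features_zeros_sketch(None))  # no atom at all
```

Keeping the atomic-number block means an atom that is missing on one side of the reaction still carries its element identity there.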
@@ -350,17 +365,30 @@ def __init__(self, mol: Union[str, Chem.Mol, Tuple[Chem.Mol, Chem.Mol]],
             ri2pi, pio, rio = map_reac_to_prod(mol_reac, mol_prod)

             # Get atom features
-            f_atoms_reac = [atom_features(atom) for atom in mol_reac.GetAtoms()] + [atom_features(None) for index in pio]
-            f_atoms_prod = [atom_features(mol_prod.GetAtomWithIdx(ri2pi[atom.GetIdx()])) if atom.GetIdx() not in rio else
-                            atom_features(None) for atom in mol_reac.GetAtoms()] + [atom_features(mol_prod.GetAtomWithIdx(index)) for index in pio]
-
-            if self.reaction_mode in ['reac_diff','prod_diff']:
+            if self.reaction_mode in ['reac_diff','prod_diff', 'reac_prod']:
+                #Reactant: regular atom features for each atom in the reactants, as well as zero features for atoms that are only in the products (indices in pio)
+                f_atoms_reac = [atom_features(atom) for atom in mol_reac.GetAtoms()] + [atom_features_zeros(mol_prod.GetAtomWithIdx(index)) for index in pio]
+
+                #Product: regular atom features for each atom that is in both reactants and products (not in rio), other atom features zero,
+                #regular features for atoms that are only in the products (indices in pio)
+                f_atoms_prod = [atom_features(mol_prod.GetAtomWithIdx(ri2pi[atom.GetIdx()])) if atom.GetIdx() not in rio else
+                                atom_features_zeros(atom) for atom in mol_reac.GetAtoms()] + [atom_features(mol_prod.GetAtomWithIdx(index)) for index in pio]
+            else: #balance
+                #Reactant: regular atom features for each atom in the reactants, copy features from product side for atoms that are only in the products (indices in pio)
+                f_atoms_reac = [atom_features(atom) for atom in mol_reac.GetAtoms()] + [atom_features(mol_prod.GetAtomWithIdx(index)) for index in pio]
+
+                #Product: regular atom features for each atom that is in both reactants and products (not in rio), copy features from reactant side for
+                #other atoms, regular features for atoms that are only in the products (indices in pio)
+                f_atoms_prod = [atom_features(mol_prod.GetAtomWithIdx(ri2pi[atom.GetIdx()])) if atom.GetIdx() not in rio else
+                                atom_features(atom) for atom in mol_reac.GetAtoms()] + [atom_features(mol_prod.GetAtomWithIdx(index)) for index in pio]
+
+            if self.reaction_mode in ['reac_diff', 'prod_diff', 'reac_diff_balance', 'prod_diff_balance']:
                 f_atoms_diff = [list(map(lambda x, y: x - y, ii, jj)) for ii, jj in zip(f_atoms_prod, f_atoms_reac)]
-            if self.reaction_mode == 'reac_prod':
+            if self.reaction_mode in ['reac_prod', 'reac_prod_balance']:
                 self.f_atoms = [x+y[PARAMS.MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_reac, f_atoms_prod)]
-            elif self.reaction_mode == 'reac_diff':
+            elif self.reaction_mode in ['reac_diff', 'reac_diff_balance']:
                 self.f_atoms = [x+y[PARAMS.MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_reac, f_atoms_diff)]
-            elif self.reaction_mode == 'prod_diff':
+            elif self.reaction_mode in ['prod_diff', 'prod_diff_balance']:
                 self.f_atoms = [x+y[PARAMS.MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_prod, f_atoms_diff)]
             self.n_atoms = len(self.f_atoms)
             n_atoms_reac = mol_reac.GetNumAtoms()
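
The effect of the two branches in this hunk on an atom that exists only in the products can be traced with toy numbers. The sketch below is a simplified stand-in, not the Chemprop implementation: `HEADER_LEN` plays the role of `PARAMS.MAX_ATOMIC_NUM + 1`, the feature vector is invented, and both function names are hypothetical. It follows the `reac_diff` family only.

```python
from typing import List

HEADER_LEN = 3  # toy length of the atomic-number one-hot block (assumption)


def reactant_side_for_product_only_atom(prod_features: List[int], balance: bool) -> List[int]:
    """Reactant-side features of an atom that appears only in the products."""
    if balance:
        # *_balance modes: copy the product-side features unchanged.
        return list(prod_features)
    # Plain modes: keep the atomic-number block, zero everything else.
    return prod_features[:HEADER_LEN] + [0] * (len(prod_features) - HEADER_LEN)


def reac_diff_vector(reac: List[int], prod: List[int]) -> List[int]:
    """Reactant features followed by (product - reactant), without the header block."""
    diff = [p - r for r, p in zip(reac, prod)]
    return reac + diff[HEADER_LEN:]


prod = [0, 1, 0, 2, 1]  # header [0, 1, 0] plus two toy features
zeroed = reactant_side_for_product_only_atom(prod, balance=False)
copied = reactant_side_for_product_only_atom(prod, balance=True)
print(reac_diff_vector(zeroed, prod))  # the atom shows up as a change
print(reac_diff_vector(copied, prod))  # the difference part is all zeros
```

With balancing, an unmapped atom looks as if it were present and unchanged on both sides, which matches the README's description of a rough balancing of the reaction.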
@@ -373,8 +401,11 @@ def __init__(self, mol: Union[str, Chem.Mol, Tuple[Chem.Mol, Chem.Mol]],
             for a1 in range(self.n_atoms):
                 for a2 in range(a1 + 1, self.n_atoms):
                     if a1 >= n_atoms_reac and a2 >= n_atoms_reac: # Both atoms only in product
-                        bond_reac = None
                         bond_prod = mol_prod.GetBondBetweenAtoms(pio[a1 - n_atoms_reac], pio[a2 - n_atoms_reac])
+                        if self.reaction_mode in ['reac_prod_balance', 'reac_diff_balance', 'prod_diff_balance']:
+                            bond_reac = bond_prod
+                        else:
+                            bond_reac = None
                     elif a1 < n_atoms_reac and a2 >= n_atoms_reac: # One atom only in product
                         bond_reac = None
                         if a1 in ri2pi.keys():
@@ -386,20 +417,26 @@ def __init__(self, mol: Union[str, Chem.Mol, Tuple[Chem.Mol, Chem.Mol]],
                         if a1 in ri2pi.keys() and a2 in ri2pi.keys():
                             bond_prod = mol_prod.GetBondBetweenAtoms(ri2pi[a1], ri2pi[a2]) #Both atoms in both reactant and product
                         else:
-                            bond_prod = None # One or both atoms only in reactant
+                            if self.reaction_mode in ['reac_prod_balance', 'reac_diff_balance', 'prod_diff_balance']:
+                                if a1 in ri2pi.keys() or a2 in ri2pi.keys():
+                                    bond_prod = None # One atom only in reactant
+                                else:
+                                    bond_prod = bond_reac # Both atoms only in reactant
+                            else:
+                                bond_prod = None # One or both atoms only in reactant

                     if bond_reac is None and bond_prod is None:
                         continue

                     f_bond_reac = bond_features(bond_reac)
                     f_bond_prod = bond_features(bond_prod)
-                    if self.reaction_mode in ['reac_diff', 'prod_diff']:
+                    if self.reaction_mode in ['reac_diff', 'prod_diff', 'reac_diff_balance', 'prod_diff_balance']:
                         f_bond_diff = [y - x for x, y in zip(f_bond_reac, f_bond_prod)]
-                    if self.reaction_mode == 'reac_prod':
+                    if self.reaction_mode in ['reac_prod', 'reac_prod_balance']:
                         f_bond = f_bond_reac + f_bond_prod
-                    elif self.reaction_mode == 'reac_diff':
+                    elif self.reaction_mode in ['reac_diff', 'reac_diff_balance']:
                         f_bond = f_bond_reac + f_bond_diff
-                    elif self.reaction_mode == 'prod_diff':
+                    elif self.reaction_mode in ['prod_diff', 'prod_diff_balance']:
                         f_bond = f_bond_prod + f_bond_diff
                     self.f_bonds.append(self.f_atoms[a1] + f_bond)
                     self.f_bonds.append(self.f_atoms[a2] + f_bond)
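
The product-side bond choice for a pair of atoms that both exist in the reactants can be summarized in a small decision function. This is a sketch covering only that branch of the hunk above; the branches for atoms that exist only in the products are partly elided in the diff and are left out. `product_side_bond` is a hypothetical name, and bonds are represented by plain strings rather than RDKit bond objects.

```python
from typing import Optional


def product_side_bond(bond_reac: Optional[str],
                      bond_prod_if_mapped: Optional[str],
                      a1_mapped: bool,
                      a2_mapped: bool,
                      balance: bool) -> Optional[str]:
    """Pick the product-side bond for two atoms that are both present in the reactants.

    `a1_mapped` / `a2_mapped` say whether each atom also appears in the products,
    i.e. whether its index is a key of `ri2pi` in the diff.
    """
    if a1_mapped and a2_mapped:
        return bond_prod_if_mapped  # both atoms on both sides: use the real product bond
    if balance and not (a1_mapped or a2_mapped):
        return bond_reac            # both atoms only in reactant: copy the reactant bond
    return None                     # exactly one atom missing, or no balancing requested


print(product_side_bond("single", "double", True, True, balance=False))
print(product_side_bond("single", None, False, False, balance=True))
print(product_side_bond("single", None, True, False, balance=True))
```

A bond between one mapped and one unmapped atom is never copied, so bonds to a leaving group still register as broken even in the `_balance` modes.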
6 changes: 3 additions & 3 deletions tests/test_integration.py
@@ -696,19 +696,19 @@ def test_predict_spectra(self,
         (
             'chemprop_reaction',
             'chemprop',
-            2.025709,
+            2.019870,
             ['--reaction', '--data_path', os.path.join(TEST_DATA_DIR, 'reaction_regression.csv')]
         ),
         (
             'chemprop_scaffold_split',
             'chemprop',
-            1.890371,
+            1.907502,
             ['--reaction', '--data_path', os.path.join(TEST_DATA_DIR, 'reaction_regression.csv'),'--split_type', 'scaffold_balanced']
         ),
         (
             'chemprop_morgan_features_generator',
             'chemprop',
-            2.848752,
+            2.846405,
             ['--reaction', '--data_path', os.path.join(TEST_DATA_DIR, 'reaction_regression.csv'),'--features_generator', 'morgan']
         ),
         (