Skip to content

Commit

Permalink
Merge pull request #206 from chemprop/global_featurization_params
Browse files Browse the repository at this point in the history
Global featurization params
  • Loading branch information
cjmcgill committed Sep 14, 2021
2 parents 491085f + 0ce2bf2 commit 9c8ff40
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 68 deletions.
5 changes: 3 additions & 2 deletions chemprop/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
rdkit_2d_normalized_features_generator, register_features_generator
from .featurization import atom_features, bond_features, BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph, \
MolGraph, onek_encoding_unk, set_extra_atom_fdim, set_extra_bond_fdim, set_reaction, set_explicit_h, \
is_reaction, is_explicit_h
is_reaction, is_explicit_h, reset_featurization_parameters
from .utils import load_features, save_features, load_valid_atom_or_bond_features

__all__ = [
Expand All @@ -29,5 +29,6 @@
'onek_encoding_unk',
'load_features',
'save_features',
'load_valid_atom_or_bond_features'
'load_valid_atom_or_bond_features',
'reset_featurization_parameters'
]
141 changes: 79 additions & 62 deletions chemprop/features/featurization.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,66 @@
from typing import List, Tuple, Union
from itertools import zip_longest
import logging

from rdkit import Chem
import torch
import numpy as np

from chemprop.rdkit import make_mol

# Atom feature sizes
MAX_ATOMIC_NUM = 100
ATOM_FEATURES = {
'atomic_num': list(range(MAX_ATOMIC_NUM)),
'degree': [0, 1, 2, 3, 4, 5],
'formal_charge': [-1, -2, 1, 2, 0],
'chiral_tag': [0, 1, 2, 3],
'num_Hs': [0, 1, 2, 3, 4],
'hybridization': [
Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2
],
}

# Distance feature sizes
PATH_DISTANCE_BINS = list(range(10))
THREE_D_DISTANCE_MAX = 20
THREE_D_DISTANCE_STEP = 1
THREE_D_DISTANCE_BINS = list(range(0, THREE_D_DISTANCE_MAX + 1, THREE_D_DISTANCE_STEP))

# len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass
ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2
EXTRA_ATOM_FDIM = 0
BOND_FDIM = 14
EXTRA_BOND_FDIM = 0
REACTION_MODE = None
EXPLICIT_H = False
REACTION = False
class Featurization_parameters:
"""
A class holding molecule featurization parameters as attributes.
"""
def __init__(self) -> None:

# Atom feature sizes
self.MAX_ATOMIC_NUM = 100
self.ATOM_FEATURES = {
'atomic_num': list(range(self.MAX_ATOMIC_NUM)),
'degree': [0, 1, 2, 3, 4, 5],
'formal_charge': [-1, -2, 1, 2, 0],
'chiral_tag': [0, 1, 2, 3],
'num_Hs': [0, 1, 2, 3, 4],
'hybridization': [
Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2
],
}

# Distance feature sizes
self.PATH_DISTANCE_BINS = list(range(10))
self.THREE_D_DISTANCE_MAX = 20
self.THREE_D_DISTANCE_STEP = 1
self.THREE_D_DISTANCE_BINS = list(range(0, self.THREE_D_DISTANCE_MAX + 1, self.THREE_D_DISTANCE_STEP))

# len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass
self.ATOM_FDIM = sum(len(choices) + 1 for choices in self.ATOM_FEATURES.values()) + 2
self.EXTRA_ATOM_FDIM = 0
self.BOND_FDIM = 14
self.EXTRA_BOND_FDIM = 0
self.REACTION_MODE = None
self.EXPLICIT_H = False
self.REACTION = False

# Create a global parameter object for reference throughout this module
PARAMS = Featurization_parameters()


def reset_featurization_parameters(logger: logging.Logger = None) -> None:
"""
Function resets feature parameter values to defaults by replacing the parameters instance.
"""
if logger is not None:
debug = logger.debug
else:
debug = print
debug('Setting molecule featurization parameters to default.')
global PARAMS
PARAMS = Featurization_parameters()


def get_atom_fdim(overwrite_default_atom: bool = False) -> int:
Expand All @@ -45,7 +70,7 @@ def get_atom_fdim(overwrite_default_atom: bool = False) -> int:
:param overwrite_default_atom: Whether to overwrite the default atom descriptors
:return: The dimensionality of the atom feature vector.
"""
return (not overwrite_default_atom) * ATOM_FDIM + EXTRA_ATOM_FDIM
return (not overwrite_default_atom) * PARAMS.ATOM_FDIM + PARAMS.EXTRA_ATOM_FDIM


def set_explicit_h(explicit_h: bool) -> None:
Expand All @@ -54,8 +79,7 @@ def set_explicit_h(explicit_h: bool) -> None:
:param explicit_h: Boolean whether to keep explicit Hs from input.
"""
global EXPLICIT_H
EXPLICIT_H = explicit_h
PARAMS.EXPLICIT_H = explicit_h


def set_reaction(reaction: bool, mode: str) -> None:
Expand All @@ -66,37 +90,31 @@ def set_reaction(reaction: bool, mode: str) -> None:
:param mode: Reaction mode to construct atom and bond feature vectors.
"""
global REACTION
REACTION = reaction
PARAMS.REACTION = reaction
if reaction:
global REACTION_MODE
global EXTRA_BOND_FDIM
global EXTRA_ATOM_FDIM

EXTRA_ATOM_FDIM = ATOM_FDIM - MAX_ATOMIC_NUM -1
EXTRA_BOND_FDIM = BOND_FDIM
REACTION_MODE = mode
PARAMS.EXTRA_ATOM_FDIM = PARAMS.ATOM_FDIM - PARAMS.MAX_ATOMIC_NUM -1
PARAMS.EXTRA_BOND_FDIM = PARAMS.BOND_FDIM
PARAMS.REACTION_MODE = mode


def is_explicit_h() -> bool:
r"""Returns whether to use retain explicit Hs"""
return EXPLICIT_H
return PARAMS.EXPLICIT_H


def is_reaction() -> bool:
r"""Returns whether to use reactions as input"""
return REACTION
return PARAMS.REACTION


def reaction_mode() -> str:
r"""Returns the reaction mode"""
return REACTION_MODE
return PARAMS.REACTION_MODE


def set_extra_atom_fdim(extra):
"""Change the dimensionality of the atom feature vector."""
global EXTRA_ATOM_FDIM
EXTRA_ATOM_FDIM = extra
PARAMS.EXTRA_ATOM_FDIM = extra


def get_bond_fdim(atom_messages: bool = False,
Expand All @@ -113,14 +131,13 @@ def get_bond_fdim(atom_messages: bool = False,
:return: The dimensionality of the bond feature vector.
"""

return (not overwrite_default_bond) * BOND_FDIM + EXTRA_BOND_FDIM + \
return (not overwrite_default_bond) * PARAMS.BOND_FDIM + PARAMS.EXTRA_BOND_FDIM + \
(not atom_messages) * get_atom_fdim(overwrite_default_atom=overwrite_default_atom)


def set_extra_bond_fdim(extra):
"""Change the dimensionality of the bond feature vector."""
global EXTRA_BOND_FDIM
EXTRA_BOND_FDIM = extra
PARAMS.EXTRA_BOND_FDIM = extra


def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
Expand Down Expand Up @@ -148,14 +165,14 @@ def atom_features(atom: Chem.rdchem.Atom, functional_groups: List[int] = None) -
:return: A list containing the atom features.
"""
if atom is None:
features = [0] * ATOM_FDIM
features = [0] * PARAMS.ATOM_FDIM
else:
features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \
onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \
onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \
onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \
onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \
onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \
features = onek_encoding_unk(atom.GetAtomicNum() - 1, PARAMS.ATOM_FEATURES['atomic_num']) + \
onek_encoding_unk(atom.GetTotalDegree(), PARAMS.ATOM_FEATURES['degree']) + \
onek_encoding_unk(atom.GetFormalCharge(), PARAMS.ATOM_FEATURES['formal_charge']) + \
onek_encoding_unk(int(atom.GetChiralTag()), PARAMS.ATOM_FEATURES['chiral_tag']) + \
onek_encoding_unk(int(atom.GetTotalNumHs()), PARAMS.ATOM_FEATURES['num_Hs']) + \
onek_encoding_unk(int(atom.GetHybridization()), PARAMS.ATOM_FEATURES['hybridization']) + \
[1 if atom.GetIsAromatic() else 0] + \
[atom.GetMass() * 0.01] # scaled to about the same range as other features
if functional_groups is not None:
Expand All @@ -171,7 +188,7 @@ def bond_features(bond: Chem.rdchem.Bond) -> List[Union[bool, int, float]]:
:return: A list containing the bond features.
"""
if bond is None:
fbond = [1] + [0] * (BOND_FDIM - 1)
fbond = [1] + [0] * (PARAMS.BOND_FDIM - 1)
else:
bt = bond.GetBondType()
fbond = [
Expand Down Expand Up @@ -340,11 +357,11 @@ def __init__(self, mol: Union[str, Chem.Mol, Tuple[Chem.Mol, Chem.Mol]],
if self.reaction_mode in ['reac_diff','prod_diff']:
f_atoms_diff = [list(map(lambda x, y: x - y, ii, jj)) for ii, jj in zip(f_atoms_prod, f_atoms_reac)]
if self.reaction_mode == 'reac_prod':
self.f_atoms = [x+y[MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_reac, f_atoms_prod)]
self.f_atoms = [x+y[PARAMS.MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_reac, f_atoms_prod)]
elif self.reaction_mode == 'reac_diff':
self.f_atoms = [x+y[MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_reac, f_atoms_diff)]
self.f_atoms = [x+y[PARAMS.MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_reac, f_atoms_diff)]
elif self.reaction_mode == 'prod_diff':
self.f_atoms = [x+y[MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_prod, f_atoms_diff)]
self.f_atoms = [x+y[PARAMS.MAX_ATOMIC_NUM+1:] for x,y in zip(f_atoms_prod, f_atoms_diff)]
self.n_atoms = len(self.f_atoms)
n_atoms_reac = mol_reac.GetNumAtoms()

Expand Down
3 changes: 2 additions & 1 deletion chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from chemprop.constants import TEST_SCORES_FILE_NAME, TRAIN_LOGGER_NAME
from chemprop.data import get_data, get_task_names, MoleculeDataset, validate_dataset_type
from chemprop.utils import create_logger, makedirs, timeit
from chemprop.features import set_extra_atom_fdim, set_extra_bond_fdim, set_explicit_h, set_reaction
from chemprop.features import set_extra_atom_fdim, set_extra_bond_fdim, set_explicit_h, set_reaction, reset_featurization_parameters


@timeit(logger_name=TRAIN_LOGGER_NAME)
Expand Down Expand Up @@ -61,6 +61,7 @@ def cross_validate(args: TrainArgs,
args.save(os.path.join(args.save_dir, 'args.json'), with_reproducibility=False)

#set explicit H option and reaction option
reset_featurization_parameters(logger=logger)
set_explicit_h(args.explicit_h)
set_reaction(args.reaction, args.reaction_mode)

Expand Down
4 changes: 3 additions & 1 deletion chemprop/train/make_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from chemprop.args import PredictArgs, TrainArgs
from chemprop.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset, StandardScaler
from chemprop.utils import load_args, load_checkpoint, load_scalers, makedirs, timeit, update_prediction_args
from chemprop.features import set_extra_atom_fdim, set_extra_bond_fdim, set_reaction, set_explicit_h
from chemprop.features import set_extra_atom_fdim, set_extra_bond_fdim, set_reaction, set_explicit_h, reset_featurization_parameters
from chemprop.models import MoleculeModel

def load_model(args: PredictArgs, generator: bool = False):
Expand Down Expand Up @@ -91,6 +91,8 @@ def set_features(args: PredictArgs, train_args: TrainArgs):
loading data and a model and making predictions.
:param train_args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
"""
reset_featurization_parameters()

if args.atom_descriptors == 'feature':
set_extra_atom_fdim(train_args.atom_features_size)

Expand Down
3 changes: 2 additions & 1 deletion chemprop/train/molecule_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from chemprop.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset
from chemprop.utils import load_args, load_checkpoint, makedirs, timeit, load_scalers, update_prediction_args
from chemprop.data import MoleculeDataLoader, MoleculeDataset
from chemprop.features import set_reaction, set_explicit_h
from chemprop.features import set_reaction, set_explicit_h, reset_featurization_parameters
from chemprop.models import MoleculeModel

@timeit()
Expand All @@ -35,6 +35,7 @@ def molecule_fingerprint(args: FingerprintArgs, smiles: List[List[str]] = None)
args: Union[FingerprintArgs, TrainArgs]

#set explicit H option and reaction option
reset_featurization_parameters()
set_explicit_h(train_args.explicit_h)
set_reaction(train_args.reaction, train_args.reaction_mode)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def test_predict_spectra(self,
['--reaction', '--data_path', os.path.join(TEST_DATA_DIR, 'reaction_regression.csv'), '--explicit_h']
)
])
def test_z_train_single_task_regression_reaction(self,
def test_train_single_task_regression_reaction(self,
name: str,
model_type: str,
expected_score: float,
Expand Down

0 comments on commit 9c8ff40

Please sign in to comment.