Skip to content

Commit

Permalink
Merge branch 'atom_features'
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Sep 10, 2020
2 parents 93e6420 + 17a6ba3 commit b4312e6
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 26 deletions.
12 changes: 12 additions & 0 deletions chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ class CommonArgs(Tap):
"""Number of workers for the parallel data loading (0 means sequential)."""
batch_size: int = 50
"""Batch size."""
atom_descriptors: Literal['feature', 'descriptor'] = None
"""
Custom extra atom descriptors.
:code:`feature`: used as atom features to featurize a given molecule.
:code:`descriptor`: used as descriptor and concatenated to the machine learned atomic representation.
"""
atom_descriptors_path: str = None
"""Path to the extra atom descriptors."""
no_cache_mol: bool = False
"""
Whether to not cache the RDKit molecule for each SMILES string to reduce memory usage (cached by default).
Expand Down Expand Up @@ -133,6 +141,10 @@ def process_args(self) -> None:
if self.features_generator is not None and 'rdkit_2d_normalized' in self.features_generator and self.features_scaling:
raise ValueError('When using rdkit_2d_normalized features, --no_features_scaling must be specified.')

# Validate atom descriptors
if self.atom_descriptors is not None and self.atom_descriptors_path is None:
raise ValueError('When using atom_descriptors, --atom_descriptors_path must be specified')

set_cache_mol(not self.no_cache_mol)


Expand Down
48 changes: 45 additions & 3 deletions chemprop/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(self,
targets: List[Optional[float]] = None,
row: OrderedDict = None,
features: np.ndarray = None,
features_generator: List[str] = None):
features_generator: List[str] = None,
atom_features: np.ndarray = None,
atom_descriptors: np.ndarray = None):
"""
:param smiles: The SMILES string for the molecule.
:param targets: A list of targets for the molecule (contains None for unknown target values).
Expand All @@ -68,6 +70,8 @@ def __init__(self,
self.row = row
self.features = features
self.features_generator = features_generator
self.atom_descriptors = atom_descriptors
self.atom_features = atom_features

# Generate additional features if given a generator
if self.features_generator is not None:
Expand All @@ -81,10 +85,18 @@ def __init__(self,
self.features = np.array(self.features)

# Fix nans in features
replace_token = 0
if self.features is not None:
replace_token = 0
self.features = np.where(np.isnan(self.features), replace_token, self.features)

# Fix nans in atom_descriptors
if self.atom_descriptors is not None:
self.atom_descriptors = np.where(np.isnan(self.atom_descriptors), replace_token, self.atom_descriptors)

# Fix nans in atom_features
if self.atom_features is not None:
self.atom_features = np.where(np.isnan(self.atom_features), replace_token, self.atom_features)

# Save a copy of the raw features and targets to enable different scaling later on
self.raw_features, self.raw_targets = features, targets

Expand Down Expand Up @@ -173,7 +185,7 @@ def batch_graph(self) -> BatchMolGraph:
if d.smiles in SMILES_TO_GRAPH:
mol_graph = SMILES_TO_GRAPH[d.smiles]
else:
mol_graph = MolGraph(d.mol)
mol_graph = MolGraph(d.mol, d.atom_features)
if cache_graph():
SMILES_TO_GRAPH[d.smiles] = mol_graph
mol_graphs.append(mol_graph)
Expand All @@ -193,6 +205,18 @@ def features(self) -> List[np.ndarray]:

return [d.features for d in self._data]

def atom_descriptors(self) -> List[np.ndarray]:
"""
Returns the atom descriptors associated with each molecule (if they exit).
:return: A list of 2D numpy arrays containing the atom descriptors
for each molecule or None if there are no features.
"""
if len(self._data) == 0 or self._data[0].atom_descriptors is None:
return None

return [d.atom_descriptors for d in self._data]

def targets(self) -> List[List[Optional[float]]]:
"""
Returns the targets associated with each molecule.
Expand All @@ -217,6 +241,24 @@ def features_size(self) -> int:
"""
return len(self._data[0].features) if len(self._data) > 0 and self._data[0].features is not None else None

def atom_descriptors_size(self) -> int:
"""
Returns the size of custom additional atom descriptors vector associated with the molecules.
:return: The size of the additional atom descriptor vector.
"""
return len(self._data[0].atom_descriptors[0]) \
if len(self._data) > 0 and self._data[0].atom_descriptors is not None else None

def atom_features_size(self) -> int:
"""
Returns the size of custom additional atom features vector associated with the molecules.
:return: The size of the additional atom feature vector.
"""
return len(self._data[0].atom_features[0]) \
if len(self._data) > 0 and self._data[0].atom_features is not None else None

def normalize_features(self, scaler: StandardScaler = None, replace_nan_token: int = 0) -> StandardScaler:
"""
Normalizes the features of the dataset using a :class:`~chemprop.data.StandardScaler`.
Expand Down
19 changes: 17 additions & 2 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .data import MoleculeDatapoint, MoleculeDataset
from .scaffold import log_scaffold_stats, scaffold_split
from chemprop.args import PredictArgs, TrainArgs
from chemprop.features import load_features
from chemprop.features import load_features, load_atom_features


def get_task_names(path: str,
Expand Down Expand Up @@ -109,6 +109,7 @@ def get_data(path: str,
args: Union[TrainArgs, PredictArgs] = None,
features_path: List[str] = None,
features_generator: List[str] = None,
atom_descriptors_path: str = None,
max_data_size: int = None,
store_row: bool = False,
logger: Logger = None,
Expand All @@ -127,6 +128,7 @@ def get_data(path: str,
in place of :code:`args.features_path`.
:param features_generator: A list of features generators to use. If provided, it is used
in place of :code:`args.features_generator`.
:param atom_descriptors_path: The path to the file containing the custom atom descriptors.
:param max_data_size: The maximum number of data points to load.
:param logger: A logger for recording output.
:param store_row: Whether to store the raw CSV row in each :class:`~chemprop.data.data.MoleculeDatapoint`.
Expand All @@ -137,15 +139,26 @@ def get_data(path: str,
"""
debug = logger.debug if logger is not None else print

# Load atomic descriptors
atom_features = None
atom_descriptors = None

if args is not None:
# Prefer explicit function arguments but default to args if not provided
smiles_column = smiles_column if smiles_column is not None else args.smiles_column
target_columns = target_columns if target_columns is not None else args.target_columns
ignore_columns = ignore_columns if ignore_columns is not None else args.ignore_columns
features_path = features_path if features_path is not None else args.features_path
features_generator = features_generator if features_generator is not None else args.features_generator
atom_descriptors_path = atom_descriptors_path if atom_descriptors_path is not None \
else args.atom_descriptors_path
max_data_size = max_data_size if max_data_size is not None else args.max_data_size

if args.atom_descriptors == 'feature':
atom_features = load_atom_features(atom_descriptors_path)
elif args.atom_descriptors == 'descriptor':
atom_descriptors = load_atom_features(atom_descriptors_path)

max_data_size = max_data_size or float('inf')

# Load features
Expand Down Expand Up @@ -204,7 +217,9 @@ def get_data(path: str,
targets=targets,
row=all_rows[i] if store_row else None,
features_generator=features_generator,
features=all_features[i] if features_data is not None else None
features=all_features[i] if features_data is not None else None,
atom_features=atom_features[i] if atom_features is not None else None,
atom_descriptors=atom_descriptors[i] if atom_descriptors is not None else None,
) for i, (smiles, targets) in tqdm(enumerate(zip(all_smiles, all_targets)),
total=len(all_smiles))
])
Expand Down
8 changes: 5 additions & 3 deletions chemprop/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
morgan_binary_features_generator, morgan_counts_features_generator, rdkit_2d_features_generator, \
rdkit_2d_normalized_features_generator, register_features_generator
from .featurization import atom_features, bond_features, BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph, \
MolGraph, onek_encoding_unk
from .utils import load_features, save_features
MolGraph, onek_encoding_unk, set_extra_atom_fdim
from .utils import load_features, save_features, load_atom_features

__all__ = [
'get_available_features_generators',
Expand All @@ -16,10 +16,12 @@
'bond_features',
'BatchMolGraph',
'get_atom_fdim',
'set_extra_atom_fdim',
'get_bond_fdim',
'mol2graph',
'MolGraph',
'onek_encoding_unk',
'load_features',
'save_features'
'save_features',
'load_atom_features'
]
23 changes: 19 additions & 4 deletions chemprop/features/featurization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from rdkit import Chem
import torch
import numpy as np

# Atom feature sizes
MAX_ATOMIC_NUM = 100
Expand All @@ -28,12 +29,19 @@

# len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass
ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2
EXTRA_ATOM_FDIM = 0
BOND_FDIM = 14


def get_atom_fdim() -> int:
"""Gets the dimensionality of the atom feature vector."""
return ATOM_FDIM
return ATOM_FDIM + EXTRA_ATOM_FDIM


def set_extra_atom_fdim(extra) -> int:
"""Change the dimensionality of the atom feature vector."""
global EXTRA_ATOM_FDIM
EXTRA_ATOM_FDIM = extra


def get_bond_fdim(atom_messages: bool = False) -> int:
Expand Down Expand Up @@ -124,7 +132,7 @@ class MolGraph:
* :code:`b2revb`: A mapping from a bond index to the index of the reverse bond.
"""

def __init__(self, mol: Union[str, Chem.Mol]):
def __init__(self, mol: Union[str, Chem.Mol], atom_descriptors: np.ndarray = None):
"""
:param mol: A SMILES or an RDKit molecule.
"""
Expand All @@ -142,6 +150,9 @@ def __init__(self, mol: Union[str, Chem.Mol]):

# Get atom features
self.f_atoms = [atom_features(atom) for atom in mol.GetAtoms()]
if atom_descriptors is not None:
self.f_atoms = [f_atoms + descs.tolist() for f_atoms, descs in zip(self.f_atoms, atom_descriptors)]

self.n_atoms = len(self.f_atoms)

# Initialize atom to bond mapping for each atom
Expand Down Expand Up @@ -290,11 +301,15 @@ def get_a2a(self) -> torch.LongTensor:
return self.a2a


def mol2graph(mols: Union[List[str], List[Chem.Mol]]) -> BatchMolGraph:
def mol2graph(mols: Union[List[str], List[Chem.Mol]], atom_descriptors_batch: List[np.array] = None) -> BatchMolGraph:
"""
Converts a list of SMILES or RDKit molecules to a :class:`BatchMolGraph` containing the batch of molecular graphs.
:param mols: A list of SMILES or a list of RDKit molecules.
:param atom_descriptors_batch: A list of 2D numpy array containing additional atom descriptors to featurize the molecule
:return: A :class:`BatchMolGraph` containing the combined molecular graph for the molecules.
"""
return BatchMolGraph([MolGraph(mol) for mol in mols])
if atom_descriptors_batch is not None:
return BatchMolGraph([MolGraph(mol, atom_descriptors) for mol, atom_descriptors in zip(mols, atom_descriptors_batch)])
else:
return BatchMolGraph([MolGraph(mol) for mol in mols])
20 changes: 20 additions & 0 deletions chemprop/features/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

import numpy as np
import pandas as pd


def save_features(path: str, features: List[np.ndarray]) -> None:
Expand Down Expand Up @@ -53,3 +54,22 @@ def load_features(path: str) -> np.ndarray:
raise ValueError(f'Features path extension {extension} not supported.')

return features


def load_atom_features(path: str) -> List[np.ndarray]:
"""
Loads features saved in a .pkl file.
:param path: Path to file containing atomwise features.
:return: A list of 2D array.
"""

features_df = pd.read_pickle(path)
if features_df.iloc[0, 0].ndim == 1:
features = features_df.apply(lambda x: np.stack(x.tolist(), axis=1), axis=1).tolist()
elif features_df.iloc[0, 0].ndim == 2:
features = features_df.apply(lambda x: np.concatenate(x.tolist(), axis=1), axis=1).tolist()
else:
raise ValueError(f'Atom descriptors input {path} format not supported')

return features
17 changes: 12 additions & 5 deletions chemprop/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def create_ffn(self, args: TrainArgs) -> None:
if args.use_input_features:
first_linear_dim += args.features_size

if args.atom_descriptors == 'descriptor':
first_linear_dim += args.atom_descriptors_size

dropout = nn.Dropout(args.dropout)
activation = get_activation_function(args.activation)

Expand Down Expand Up @@ -97,33 +100,37 @@ def create_ffn(self, args: TrainArgs) -> None:

def featurize(self,
batch: Union[List[str], List[Chem.Mol], BatchMolGraph],
features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
features_batch: List[np.ndarray] = None,
atom_descriptors_batch: List[np.ndarray] = None) -> torch.FloatTensor:
"""
Computes feature vectors of the input by running the model except for the last layer.
:param batch: A list of SMILES, a list of RDKit molecules, or a
:class:`~chemprop.features.featurization.BatchMolGraph`.
:param features_batch: A list of numpy arrays containing additional features.
:param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
:return: The feature vectors computed by the :class:`MoleculeModel`.
"""
return self.ffn[:-1](self.encoder(batch, features_batch))
return self.ffn[:-1](self.encoder(batch, features_batch, atom_descriptors_batch))

def forward(self,
batch: Union[List[str], List[Chem.Mol], BatchMolGraph],
features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
features_batch: List[np.ndarray] = None,
atom_descriptors_batch: List[np.ndarray] = None) -> torch.FloatTensor:
"""
Runs the :class:`MoleculeModel` on input.
:param batch: A list of SMILES, a list of RDKit molecules, or a
:class:`~chemprop.features.featurization.BatchMolGraph`.
:param features_batch: A list of numpy arrays containing additional features.
:param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
:return: The output of the :class:`MoleculeModel`, which is either property predictions
or molecule features if :code:`self.featurizer=True`.
"""
if self.featurizer:
return self.featurize(batch, features_batch)
return self.featurize(batch, features_batch, atom_descriptors_batch)

output = self.ffn(self.encoder(batch, features_batch))
output = self.ffn(self.encoder(batch, features_batch, atom_descriptors_batch))

# Don't apply sigmoid during training b/c using BCEWithLogitsLoss
if self.classification and not self.training:
Expand Down

0 comments on commit b4312e6

Please sign in to comment.