Skip to content

Commit

Permalink
Merge branch 'solvation'
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Sep 24, 2020
2 parents fc35354 + 7f13a08 commit a387d18
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 118 deletions.
39 changes: 33 additions & 6 deletions chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ def get_checkpoint_paths(checkpoint_path: Optional[str] = None,
class CommonArgs(Tap):
""":class:`CommonArgs` contains arguments that are used in both :class:`TrainArgs` and :class:`PredictArgs`."""

smiles_column: str = None
"""Name of the column containing SMILES strings. By default, uses the first column."""
smiles_columns: List[str] = None
"""List of names of the columns containing SMILES strings.
By default, uses the first :code:`number_of_molecules` columns."""
number_of_molecules: int = 1
"""Number of molecules in each input to the model.
This must equal the length of :code:`smiles_columns` (if not :code:`None`)."""
checkpoint_dir: str = None
"""Directory from which to load model checkpoints (walks directory and ensembles all models that are found)."""
checkpoint_path: str = None
Expand Down Expand Up @@ -164,9 +168,19 @@ def process_args(self) -> None:
if self.features_generator is not None and 'rdkit_2d_normalized' in self.features_generator and self.features_scaling:
raise ValueError('When using rdkit_2d_normalized features, --no_features_scaling must be specified.')

if self.smiles_columns is None:
self.smiles_columns = [None] * self.number_of_molecules
elif len(self.smiles_columns) != self.number_of_molecules:
raise ValueError('Length of smiles_columns must match number_of_molecules.')

# Validate atom descriptors
if self.atom_descriptors is not None and self.atom_descriptors_path is None:
raise ValueError('When using atom_descriptors, --atom_descriptors_path must be specified')
if (self.atom_descriptors is None) != (self.atom_descriptors_path is None):
raise ValueError('If atom_descriptors is specified, then an atom_descriptors_path must be provided '
'and vice versa.')

if self.atom_descriptors is not None and self.number_of_molecules > 1:
raise NotImplementedError('Atom descriptors are currently only supported with one molecule '
'per input (i.e., number_of_molecules = 1).')

set_cache_mol(not self.no_cache_mol)

Expand Down Expand Up @@ -251,6 +265,9 @@ class TrainArgs(CommonArgs):
"""Dimensionality of hidden layers in MPN."""
depth: int = 3
"""Number of message passing steps."""
mpn_shared: bool = False
"""Whether to use the same message passing neural network for all input molecules.
Only relevant if :code:`number_of_molecules > 1`."""
dropout: float = 0.0
"""Dropout probability."""
activation: Literal['ReLU', 'LeakyReLU', 'PReLU', 'tanh', 'SELU', 'ELU'] = 'ReLU'
Expand Down Expand Up @@ -531,8 +548,12 @@ class SklearnPredictArgs(Tap):

test_path: str
"""Path to CSV file containing testing data for which predictions will be made."""
smiles_column: str = None
"""Name of the column containing SMILES strings. By default, uses the first column."""
smiles_columns: List[str] = None
"""List of names of the columns containing SMILES strings.
By default, uses the first :code:`number_of_molecules` columns."""
number_of_molecules: int = 1
"""Number of molecules in each input to the model.
This must equal the length of :code:`smiles_columns` (if not :code:`None`)."""
preds_path: str
"""Path to CSV file where predictions will be saved."""
checkpoint_dir: str = None
Expand All @@ -543,6 +564,12 @@ class SklearnPredictArgs(Tap):
"""List of paths to model checkpoints (:code:`.pkl` files)"""

def process_args(self) -> None:

if self.smiles_columns is None:
self.smiles_columns = [None] * self.number_of_molecules
elif len(self.smiles_columns) != self.number_of_molecules:
raise ValueError('Length of smiles_columns must match number_of_molecules.')

# Load checkpoint paths
self.checkpoint_paths = get_checkpoint_paths(
checkpoint_path=self.checkpoint_path,
Expand Down
76 changes: 50 additions & 26 deletions chemprop/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class MoleculeDatapoint:
"""A :class:`MoleculeDatapoint` contains a single molecule and its associated features and targets."""

def __init__(self,
smiles: str,
smiles: List[str],
targets: List[Optional[float]] = None,
row: OrderedDict = None,
features: np.ndarray = None,
features_generator: List[str] = None,
atom_features: np.ndarray = None,
atom_descriptors: np.ndarray = None):
"""
:param smiles: The SMILES string for the molecule.
:param smiles: A list of the SMILES strings for the molecules.
:param targets: A list of targets for the molecule (contains None for unknown target values).
:param row: The raw CSV row containing the information for this molecule.
:param features: A numpy array containing additional features (e.g., Morgan fingerprint).
Expand All @@ -79,8 +79,13 @@ def __init__(self,

for fg in self.features_generator:
features_generator = get_features_generator(fg)
if self.mol is not None and self.mol.GetNumHeavyAtoms() > 0:
self.features.extend(features_generator(self.mol))
for m in self.mol:
if m is not None and m.GetNumHeavyAtoms() > 0:
self.features.extend(features_generator(m))
# for H2
elif m is not None and m.GetNumHeavyAtoms() == 0:
# not all features are equally long, so use methane as dummy molecule to determine length
self.features.extend(np.zeros(len(features_generator(Chem.MolFromSmiles('C')))))

self.features = np.array(self.features)

Expand All @@ -101,12 +106,13 @@ def __init__(self,
self.raw_features, self.raw_targets = self.features, self.targets

@property
def mol(self) -> List[Chem.Mol]:
    """
    Gets the list of RDKit molecules corresponding to this datapoint's list of SMILES.

    A SMILES that is already a key of ``SMILES_TO_MOL`` reuses the cached entry; any other
    SMILES is parsed with ``Chem.MolFromSmiles``. When ``cache_mol()`` is true, every
    SMILES/molecule pair is written back to ``SMILES_TO_MOL``.

    :return: A list with one entry per SMILES in ``self.smiles``, in the same order.
    """
    # dict.get(key, default) evaluates its default eagerly, which would re-parse every SMILES
    # even on a cache hit; only parse on an actual miss.
    mol = [SMILES_TO_MOL[s] if s in SMILES_TO_MOL else Chem.MolFromSmiles(s) for s in self.smiles]

    if cache_mol():
        for s, m in zip(self.smiles, mol):
            SMILES_TO_MOL[s] = m

    return mol

Expand All @@ -118,6 +124,14 @@ def set_features(self, features: np.ndarray) -> None:
"""
self.features = features

def extend_features(self, features: np.ndarray) -> None:
    """
    Adds extra features to the features already stored on the molecule.

    If no features are stored yet, the given array becomes the stored features;
    otherwise it is appended to the existing ones with ``np.append``.

    :param features: A 1D numpy array of extra features for the molecule.
    """
    if self.features is None:
        self.features = features
        return

    self.features = np.append(self.features, features)

def num_tasks(self) -> int:
"""
Returns the number of prediction tasks.
Expand Down Expand Up @@ -151,23 +165,23 @@ def __init__(self, data: List[MoleculeDatapoint]):
self._batch_graph = None
self._random = Random()

def smiles(self) -> List[List[str]]:
    """
    Collects the SMILES list stored on each :class:`MoleculeDatapoint` of the dataset.

    :return: A list of lists of SMILES strings, one inner list per datapoint in ``self._data``.
    """
    smiles_lists = []
    for datapoint in self._data:
        smiles_lists.append(datapoint.smiles)

    return smiles_lists

def mols(self) -> List[List[Chem.Mol]]:
    """
    Collects the ``mol`` value of each :class:`MoleculeDatapoint` of the dataset.

    :return: A list of lists of RDKit molecules, one inner list per datapoint in ``self._data``.
    """
    mol_lists = []
    for datapoint in self._data:
        mol_lists.append(datapoint.mol)

    return mol_lists

def batch_graph(self) -> BatchMolGraph:
def batch_graph(self) -> List[BatchMolGraph]:
r"""
Constructs a :class:`~chemprop.features.BatchMolGraph` with the graph featurization of all the molecules.
Expand All @@ -177,20 +191,30 @@ def batch_graph(self) -> BatchMolGraph:
set of :class:`MoleculeDatapoint`\ s changes, then the returned :class:`~chemprop.features.BatchMolGraph`
will be incorrect for the underlying data.
:return: A :class:`~chemprop.features.BatchMolGraph` containing the graph featurization of all the molecules.
:return: A list of :class:`~chemprop.features.BatchMolGraph` containing the graph featurization of all the
molecules in each :class:`MoleculeDatapoint`.
"""
if self._batch_graph is None:
self._batch_graph = []

mol_graphs = []
for d in self._data:
if d.smiles in SMILES_TO_GRAPH:
mol_graph = SMILES_TO_GRAPH[d.smiles]
else:
mol_graph = MolGraph(d.mol, d.atom_features)
if cache_graph():
SMILES_TO_GRAPH[d.smiles] = mol_graph
mol_graphs.append(mol_graph)

self._batch_graph = BatchMolGraph(mol_graphs)
mol_graphs_list = []
for s, m in zip(d.smiles, d.mol):
if s in SMILES_TO_GRAPH:
mol_graph = SMILES_TO_GRAPH[s]
else:
if len(d.smiles) > 1 and d.atom_features is not None:
raise NotImplementedError('Atom descriptors are currently only supported with one molecule '
'per input (i.e., number_of_molecules = 1).')

mol_graph = MolGraph(m, d.atom_features)
if cache_graph():
SMILES_TO_GRAPH[s] = mol_graph
mol_graphs_list.append(mol_graph)
mol_graphs.append(mol_graphs_list)

self._batch_graph = [BatchMolGraph([g[i] for g in mol_graphs]) for i in range(len(mol_graphs[0]))]

return self._batch_graph

Expand Down
57 changes: 33 additions & 24 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def get_task_names(path: str,
smiles_column: str = None,
smiles_columns: List[str] = None,
target_columns: List[str] = None,
ignore_columns: List[str] = None) -> List[str]:
"""
Expand All @@ -29,7 +29,8 @@ def get_task_names(path: str,
the :code:`ignore_columns`.
:param path: Path to a CSV file.
:param smiles_column: The name of the column containing SMILES. By default, uses the first column.
:param smiles_columns: The names of the columns containing SMILES.
By default, uses the first :code:`number_of_molecules` columns.
:param target_columns: Name of the columns containing target values. By default, uses all columns
except the :code:`smiles_columns` and the :code:`ignore_columns`.
:param ignore_columns: Name of the columns to ignore when :code:`target_columns` is not provided.
Expand All @@ -40,10 +41,12 @@ def get_task_names(path: str,

columns = get_header(path)

if smiles_column is None:
smiles_column = columns[0]
smiles_columns = smiles_columns if smiles_columns is not None else [None]

ignore_columns = set([smiles_column] + ([] if ignore_columns is None else ignore_columns))
if None in smiles_columns:
smiles_columns = columns[:len(smiles_columns)]

ignore_columns = set(smiles_columns + ([] if ignore_columns is None else ignore_columns))

target_names = [column for column in columns if column not in ignore_columns]

Expand All @@ -63,28 +66,31 @@ def get_header(path: str) -> List[str]:
return header


def get_smiles(path: str, smiles_columns: List[str] = None, header: bool = True) -> List[List[str]]:
    """
    Returns the SMILES from a data CSV file.

    :param path: Path to a CSV file.
    :param smiles_columns: A list of the names of the columns containing SMILES.
                           If the list contains :code:`None`, the first :code:`len(smiles_columns)`
                           columns are used instead; if no list is given, the first column is used.
    :param header: Whether the CSV file contains a header.
    :return: A list with one entry per data row, each entry being the list of SMILES read from that row.
    :raises ValueError: If column names are provided but the CSV file has no header.
    """
    smiles_columns = smiles_columns if smiles_columns is not None else [None]

    # Names can only be resolved through a header row; placeholder (None) entries are positional
    # and therefore remain valid for header-less files.
    if not header and any(column is not None for column in smiles_columns):
        raise ValueError('If smiles_column is provided, the CSV file must have a header.')

    with open(path) as f:
        if header:
            reader = csv.DictReader(f)
            if None in smiles_columns:
                # fieldnames is None for an empty file; fall back to no columns instead of slicing None
                smiles_columns = (reader.fieldnames or [])[:len(smiles_columns)]
        else:
            reader = csv.reader(f)
            # Rows are plain lists here, so address the leading columns by integer position
            # (previously this was the bare integer 0, which cannot be iterated below).
            smiles_columns = list(range(len(smiles_columns)))

        smiles = [[row[c] for c in smiles_columns] for row in reader]

    return smiles
return smiles

Expand All @@ -97,12 +103,12 @@ def filter_invalid_smiles(data: MoleculeDataset) -> MoleculeDataset:
:return: A :class:`~chemprop.data.MoleculeDataset` with only the valid molecules.
"""
return MoleculeDataset([datapoint for datapoint in tqdm(data)
if datapoint.smiles != '' and datapoint.mol is not None
and datapoint.mol.GetNumHeavyAtoms() > 0])
if all(s != '' for s in datapoint.smiles) and all(m is not None for m in datapoint.mol)
and all(m.GetNumHeavyAtoms() > 0 for m in datapoint.mol)])


def get_data(path: str,
smiles_column: str = None,
smiles_columns: List[str] = None,
target_columns: List[str] = None,
ignore_columns: List[str] = None,
skip_invalid_smiles: bool = True,
Expand All @@ -118,7 +124,8 @@ def get_data(path: str,
Gets SMILES and target values from a CSV file.
:param path: Path to a CSV file.
:param smiles_column: The name of the column containing SMILES. By default, uses the first column.
:param smiles_columns: The names of the columns containing SMILES.
By default, uses the first :code:`number_of_molecules` columns.
:param target_columns: Name of the columns containing target values. By default, uses all columns
except the :code:`smiles_columns` and the :code:`ignore_columns`.
:param ignore_columns: Name of the columns to ignore when :code:`target_columns` is not provided.
Expand All @@ -145,7 +152,7 @@ def get_data(path: str,

if args is not None:
# Prefer explicit function arguments but default to args if not provided
smiles_column = smiles_column if smiles_column is not None else args.smiles_column
smiles_columns = smiles_columns if smiles_columns is not None else args.smiles_columns
target_columns = target_columns if target_columns is not None else args.target_columns
ignore_columns = ignore_columns if ignore_columns is not None else args.ignore_columns
features_path = features_path if features_path is not None else args.features_path
Expand All @@ -159,6 +166,8 @@ def get_data(path: str,
elif args.atom_descriptors == 'descriptor':
atom_descriptors = load_atom_features(atom_descriptors_path)

smiles_columns = smiles_columns if smiles_columns is not None else [None]

max_data_size = max_data_size or float('inf')

# Load features
Expand All @@ -170,25 +179,25 @@ def get_data(path: str,
else:
features_data = None

skip_smiles = set()
skip_smiles = [set() for _ in range(len(smiles_columns))]

# Load data
with open(path) as f:
reader = csv.DictReader(f)
columns = reader.fieldnames

# By default, the SMILES column is the first column
if smiles_column is None:
smiles_column = columns[0]
if None in smiles_columns:
smiles_columns = columns[:len(smiles_columns)]

# By default, the targets columns are all the columns except the SMILES column
if target_columns is None:
ignore_columns = set([smiles_column] + ([] if ignore_columns is None else ignore_columns))
ignore_columns = set(smiles_columns + ([] if ignore_columns is None else ignore_columns))
target_columns = [column for column in columns if column not in ignore_columns]

all_smiles, all_targets, all_rows, all_features = [], [], [], []
for i, row in tqdm(enumerate(reader)):
smiles = row[smiles_column]
smiles = [row[c] for c in smiles_columns]

if smiles in skip_smiles:
continue
Expand Down Expand Up @@ -235,14 +244,14 @@ def get_data(path: str,
return data


def get_data_from_smiles(smiles: List[str],
def get_data_from_smiles(smiles: List[List[str]],
skip_invalid_smiles: bool = True,
logger: Logger = None,
features_generator: List[str] = None) -> MoleculeDataset:
"""
Converts a list of SMILES to a :class:`~chemprop.data.MoleculeDataset`.
:param smiles: A list of SMILES.
:param smiles: A list of lists of SMILES with length depending on the number of molecules.
:param skip_invalid_smiles: Whether to skip and filter out invalid smiles using :func:`filter_invalid_smiles`
:param logger: A logger for recording output.
:param features_generator: List of features generators.
Expand Down
2 changes: 1 addition & 1 deletion chemprop/interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def scoring_function(smiles: List[str]) -> List[float]:
C_PUCT = args.c_puct
MIN_ATOMS = args.min_atoms

all_smiles = get_smiles(path=args.data_path, smiles_column=args.smiles_column)
all_smiles = get_smiles(path=args.data_path, smiles_columns=args.smiles_columns)
header = get_header(path=args.data_path)

property_name = header[args.property_id] if len(header) > args.property_id else 'score'
Expand Down
2 changes: 1 addition & 1 deletion chemprop/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def create_ffn(self, args: TrainArgs) -> None:
if args.features_only:
first_linear_dim = args.features_size
else:
first_linear_dim = args.hidden_size
first_linear_dim = args.hidden_size * args.number_of_molecules
if args.use_input_features:
first_linear_dim += args.features_size

Expand Down

0 comments on commit a387d18

Please sign in to comment.