Skip to content

Commit

Permalink
New get_invalid_smiles functions and output options for make_predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
cjmcgill committed Jan 23, 2022
1 parent 0672527 commit ec45082
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 32 deletions.
32 changes: 8 additions & 24 deletions chemprop/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,10 @@
from .data import (
cache_graph,
cache_mol,
MoleculeDatapoint,
MoleculeDataset,
MoleculeDataLoader,
MoleculeSampler,
set_cache_graph,
empty_cache,
set_cache_mol
)
from .data import cache_graph, cache_mol, MoleculeDatapoint, MoleculeDataset, MoleculeDataLoader, \
MoleculeSampler, set_cache_graph, empty_cache, set_cache_mol
from .scaffold import generate_scaffold, log_scaffold_stats, scaffold_split, scaffold_to_smiles
from .scaler import StandardScaler
from .utils import (
filter_invalid_smiles,
get_class_sizes,
get_data,
get_data_from_smiles,
get_header,
get_smiles,
get_task_names,
preprocess_smiles_columns,
split_data,
validate_data,
validate_dataset_type,
)
from .utils import filter_invalid_smiles, get_class_sizes, get_data, get_data_from_smiles, \
get_header, get_smiles, get_task_names, get_data_weights, preprocess_smiles_columns, split_data, \
validate_data, validate_dataset_type, get_invalid_smiles_from_file, get_invalid_smiles_from_list

__all__ = [
'cache_graph',
Expand All @@ -44,6 +25,9 @@
'get_class_sizes',
'get_data',
'get_data_from_smiles',
'get_data_weights',
'get_invalid_smiles_from_file',
'get_invalid_smiles_from_list',
'get_header',
'get_smiles',
'get_task_names',
Expand Down
46 changes: 45 additions & 1 deletion chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from tqdm import tqdm

from .data import MoleculeDatapoint, MoleculeDataset
from .data import MoleculeDatapoint, MoleculeDataset, make_mols
from .scaffold import log_scaffold_stats, scaffold_split
from chemprop.args import PredictArgs, TrainArgs
from chemprop.features import load_features, load_valid_atom_or_bond_features
Expand Down Expand Up @@ -168,6 +168,50 @@ def filter_invalid_smiles(data: MoleculeDataset) -> MoleculeDataset:
and all(m[0].GetNumHeavyAtoms() + m[1].GetNumHeavyAtoms() > 0 for m in datapoint.mol if isinstance(m, tuple))])


def get_invalid_smiles_from_file(path: str = None,
smiles_columns: Union[str, List[str]] = None,
header: bool = True,
reaction: bool = False,
) -> Union[List[str], List[List[str]]]:
"""
Returns the invalid SMILES from a data CSV file.
:param path: Path to a CSV file.
:param smiles_columns: A list of the names of the columns containing SMILES.
By default, uses the first :code:`number_of_molecules` columns.
:param header: Whether the CSV file contains a header.
:param reaction: Boolean whether the SMILES strings are to be treated as a reaction.
:return: A list of lists of SMILES, for the invalid SMILES in the file.
"""
smiles = get_smiles(path=path, smiles_columns=smiles_columns, header=header)

invalid_smiles = get_invalid_smiles_from_list(smiles=smiles, reaction=reaction)

return invalid_smiles


def get_invalid_smiles_from_list(smiles: List[List[str]], reaction: bool = False) -> List[List[str]]:
"""
Returns the invalid SMILES from a list of lists of SMILES strings.
:param smiles: A list of list of SMILES.
:param reaction: Boolean whether the SMILES strings are to be treated as a reaction.
:return: A list of lists of SMILES, for the invalid SMILES among the lists provided.
"""
invalid_smiles = []

for mol_smiles in smiles:
mols = make_mols(smiles=mol_smiles, reaction=reaction, keep_h=False, add_h=False)
if any(s == '' for s in mol_smiles) or \
any(m is None for m in mols) or \
any(m.GetNumHeavyAtoms() == 0 for m in mols if not isinstance(m, tuple)) or \
any(m[0].GetNumHeavyAtoms() + m[1].GetNumHeavyAtoms() == 0 for m in mols if isinstance(m, tuple)):

invalid_smiles.append(mol_smiles)

return invalid_smiles


def get_data(path: str,
smiles_columns: Union[str, List[str]] = None,
target_columns: List[str] = None,
Expand Down
53 changes: 46 additions & 7 deletions chemprop/train/make_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def set_features(args: PredictArgs, train_args: TrainArgs):

def predict_and_save(args: PredictArgs, train_args: TrainArgs, test_data: MoleculeDataset,
task_names: List[str], num_tasks: int, test_data_loader: MoleculeDataLoader, full_data: MoleculeDataset,
full_to_valid_indices: dict, models: List[MoleculeModel], scalers: List[List[StandardScaler]]):
full_to_valid_indices: dict, models: List[MoleculeModel], scalers: List[List[StandardScaler]],
return_invalid_smiles: bool = False):
"""
Function to predict with a model and save the predictions to file.
Expand All @@ -124,6 +125,7 @@ def predict_and_save(args: PredictArgs, train_args: TrainArgs, test_data: Molecu
:param full_to_valid_indices: A dictionary dictionary mapping full to valid indices.
:param models: A list or generator object of :class:`~chemprop.models.MoleculeModel`\ s.
:param scalers: A list or generator object of :class:`~chemprop.features.scaler.StandardScaler` objects.
:param return_invalid_smiles: Whether to return predictions of "Invalid SMILES" for invalid SMILES, otherwise will skip them in returned predictions.
:return: A list of lists of target predictions.
"""
# Predict with each model individually and sum predictions
Expand Down Expand Up @@ -244,14 +246,25 @@ def predict_and_save(args: PredictArgs, train_args: TrainArgs, test_data: Molecu
for datapoint in full_data:
writer.writerow(datapoint.row)

# Return predicted values
avg_preds = avg_preds.tolist()

return avg_preds
if return_invalid_smiles:
full_preds = []
for full_index in range(len(full_data)):
valid_index = full_to_valid_indices.get(full_index, None)
preds = avg_preds[valid_index] if valid_index is not None else ['Invalid SMILES'] * num_tasks
full_preds.append(preds)
return full_preds
else:
return avg_preds


@timeit()
def make_predictions(args: PredictArgs, smiles: List[List[str]] = None,
model_objects: Tuple[PredictArgs, TrainArgs, List[MoleculeModel], List[StandardScaler], int, List[str]] = None) -> List[List[Optional[float]]]:
model_objects: Tuple[PredictArgs, TrainArgs, List[MoleculeModel], List[StandardScaler], int, List[str]] = None,
return_invalid_smiles: bool = True,
return_index_dict: bool = False) -> List[List[Optional[float]]]:
"""
Loads data and a trained model and uses the model to make predictions on the data.
Expand All @@ -262,6 +275,8 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None,
loading data and a model and making predictions.
:param smiles: List of list of SMILES to make predictions on.
:param model_objects: Tuple of output of load_model function which can be called separately.
:param return_invalid_smiles: Whether to return predictions of "Invalid SMILES" for invalid SMILES, otherwise will skip them in returned predictions.
:param return_index_dict: Whether to return the prediction results as a dictionary keyed from the initial data indexes.
:return: A list of lists of target predictions.
"""
if model_objects:
Expand All @@ -271,15 +286,39 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None,

set_features(args, train_args)

# Note: to get the invalid SMILES for your data, use the get_invalid_smiles_from_file or get_invalid_smiles_from_list functions from data/utils.py
full_data, test_data, test_data_loader, full_to_valid_indices = load_data(args, smiles)

# Edge case if empty list of smiles is provided
if len(test_data) == 0:
return [None] * len(full_data)

avg_preds = predict_and_save(args, train_args, test_data, task_names, num_tasks, test_data_loader, full_data, full_to_valid_indices, models, scalers)
avg_preds = [None] * len(full_data)
else:
avg_preds = predict_and_save(
args=args,
train_args=train_args,
test_data=test_data,
task_names=task_names,
num_tasks=num_tasks,
test_data_loader=test_data_loader,
full_data=full_data,
full_to_valid_indices=full_to_valid_indices,
models=models,
scalers=scalers,
return_invalid_smiles=return_invalid_smiles,
)

return avg_preds
if return_index_dict:
preds_dict = {}
for i in range(len(full_data)):
if return_invalid_smiles:
preds_dict[i] = avg_preds[i]
else:
valid_index = full_to_valid_indices.get(i, None)
if valid_index is not None:
preds_dict[i] = avg_preds[valid_index]
return preds_dict
else:
return avg_preds


def chemprop_predict() -> None:
Expand Down

0 comments on commit ec45082

Please sign in to comment.