Skip to content

Commit

Permalink
Merge pull request #200 from chemprop/preload_model
Browse files Browse the repository at this point in the history
Enable preloading of a model
  • Loading branch information
hesther committed Sep 3, 2021
2 parents ef3a26c + ed8667c commit 241e671
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 34 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Please see [aicures.mit.edu](https://aicures.mit.edu) and the associated [data G
* [Option 2: Installing from source](#option-2-installing-from-source)
* [Docker](#docker)
- [Web Interface](#web-interface)
- [Within Python](#within-python)
- [Data](#data)
- [Training](#training)
* [Train/Validation/Test Splits](#trainvalidationtest-splits)
Expand Down Expand Up @@ -122,6 +123,12 @@ Next, navigate to `chemprop/web` and run `gunicorn --bind {host}:{port} 'wsgi:bu
* Arguments including `init_db` and `demo` can be passed with this pattern: `'wsgi:build_app(init_db=True, demo=True)'`
* Gunicorn documentation can be found [here](http://docs.gunicorn.org/en/stable/index.html).

## Within Python

For information on the use of Chemprop within a python script, refer to the [Within a python script](https://chemprop.readthedocs.io/en/latest/tutorial.html#within-a-python-script)
section of the documentation.


## Data

In order to train a model, you must provide training data containing molecules (as SMILES strings) and known target values. Targets can either be real numbers, if performing regression, or binary (i.e. 0s and 1s), if performing classification. Target values which are unknown can be left as blanks.
Expand Down Expand Up @@ -249,6 +256,7 @@ Similar to molecule-, and atom-level features, the bond-level features are scale
### Reaction

As an alternative to molecule SMILES, Chemprop can also process atom-mapped reaction SMILES (see [Daylight manual](https://www.daylight.com/meetings/summerschool01/course/basics/smirks.html) for details on reaction SMILES), which consist of three parts denoting reactants, agents and products, separated by ">". Use the option `--reaction` to enable the input of reactions, which transforms the reactants and products of each reaction to the corresponding condensed graph of reaction and changes the initial atom and bond features to hold information from both the reactant and product (option `--reaction_mode reac_prod`), or from the reactant and the difference upon reaction (option `--reaction_mode reac_diff`, default) or from the product and the difference upon reaction (option `--reaction_mode prod_diff`). In reaction mode, Chemprop thus concatenates information to each atomic and bond feature vector, for example, with option `--reaction_mode reac_prod`, each atomic feature vector holds information on the state of the atom in the reactant (similar to default Chemprop), and concatenates information on the state of the atom in the product, so that the size of the D-MPNN increases slightly. Agents are discarded. Functions incompatible with a reaction as input (scaffold splitting and feature generation) are carried out on the reactants only. If the atom-mapped reaction SMILES contain mapped hydrogens, enable explicit hydrogens via `--explicit_h`. Example of an atom-mapped reaction SMILES denoting the reaction of methanol to formaldehyde without hydrogens: `[CH3:1][OH:2]>>[CH2:1]=[O:2]` and with hydrogens: `[C:1]([H:3])([H:4])([H:5])[O:2][H:6]>>[C:1]([H:3])([H:4])=[O:2].[H:5][H:6]`. The reactions do not need to be balanced and can thus contain unmapped parts, for example leaving groups, if necessary.
For further details and benchmarking, as well as a citable reference, please see [DOI 10.33774/chemrxiv-2021-frfhz](https://doi.org/10.33774/chemrxiv-2021-frfhz).

### Pretraining, With and Without Frozen Parameters

Expand Down
3 changes: 2 additions & 1 deletion chemprop/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .cross_validate import chemprop_train, cross_validate, TRAIN_LOGGER_NAME
from .evaluate import evaluate, evaluate_predictions
from .make_predictions import chemprop_predict, make_predictions
from .make_predictions import chemprop_predict, make_predictions, load_model
from .molecule_fingerprint import chemprop_fingerprint, model_fingerprint
from .predict import predict
from .run_training import run_training
Expand All @@ -15,6 +15,7 @@
'chemprop_predict',
'chemprop_fingerprint',
'make_predictions',
'load_model',
'predict',
'run_training',
'train'
Expand Down
143 changes: 110 additions & 33 deletions chemprop/train/make_predictions.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
from collections import OrderedDict
import csv
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple

import numpy as np
from tqdm import tqdm

from .predict import predict
from chemprop.args import PredictArgs, TrainArgs
from chemprop.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset
from chemprop.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset, StandardScaler
from chemprop.utils import load_args, load_checkpoint, load_scalers, makedirs, timeit, update_prediction_args
from chemprop.features import set_extra_atom_fdim, set_extra_bond_fdim, set_reaction, set_explicit_h
from chemprop.models import MoleculeModel

@timeit()
def make_predictions(args: PredictArgs, smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
def load_model(args: PredictArgs, generator: bool = False):
"""
Loads data and a trained model and uses the model to make predictions on the data.
If SMILES are provided, then makes predictions on smiles.
Otherwise makes predictions on :code:`args.test_data`.
Function to load a model or ensemble of models from file. If generator is True, a generator of the respective model and scaler
objects is returned (memory efficient), else the full list (holding all models in memory, necessary for preloading).
:param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
loading data and a model and making predictions.
:param smiles: List of list of SMILES to make predictions on.
:return: A list of lists of target predictions.
:param generator: A boolean to return a generator instead of a list of models and scalers.
:return: A tuple of updated prediction arguments, training arguments, a list or generator object of models, a list or
generator object of scalers, the number of tasks and their respective names.
"""
print('Loading training args')
train_args = load_args(args.checkpoint_paths[0])
Expand All @@ -31,16 +30,25 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None) -> List[
update_prediction_args(predict_args=args, train_args=train_args)
args: Union[PredictArgs, TrainArgs]

if args.atom_descriptors == 'feature':
set_extra_atom_fdim(train_args.atom_features_size)

if args.bond_features_path is not None:
set_extra_bond_fdim(train_args.bond_features_size)
# Load model and scalers
models = (load_checkpoint(checkpoint_path, device=args.device) for checkpoint_path in args.checkpoint_paths)
scalers = (load_scalers(checkpoint_path) for checkpoint_path in args.checkpoint_paths)
if not generator:
models = list(models)
scalers = list(scalers)

#set explicit H option and reaction option
set_explicit_h(train_args.explicit_h)
set_reaction(train_args.reaction, train_args.reaction_mode)
return args, train_args, models, scalers, num_tasks, task_names

def load_data(args: PredictArgs, smiles: List[List[str]]):
"""
Function to load data from a list of smiles or a file.
:param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
loading data and a model and making predictions.
:param smiles: A list of list of smiles, or None if data is to be read from file
:return: A tuple of a :class:`~chemprop.data.MoleculeDataset` containing all datapoints, a :class:`~chemprop.data.MoleculeDataset` containing only valid datapoints,
a :class:`~chemprop.data.MoleculeDataLoader` and a dictionary mapping full to valid indices.
"""
print('Loading data')
if smiles is not None:
full_data = get_data_from_smiles(
Expand All @@ -62,12 +70,56 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None) -> List[

test_data = MoleculeDataset([full_data[i] for i in sorted(full_to_valid_indices.keys())])

# Edge case if empty list of smiles is provided
if len(test_data) == 0:
return [None] * len(full_data)

print(f'Test size = {len(test_data):,}')

# Create data loader
test_data_loader = MoleculeDataLoader(
dataset=test_data,
batch_size=args.batch_size,
num_workers=args.num_workers
)

return full_data, test_data, test_data_loader, full_to_valid_indices


def set_features(args: PredictArgs, train_args: TrainArgs):
"""
Function to set extra options.
:param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
loading data and a model and making predictions.
:param train_args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
"""
if args.atom_descriptors == 'feature':
set_extra_atom_fdim(train_args.atom_features_size)

if args.bond_features_path is not None:
set_extra_bond_fdim(train_args.bond_features_size)

#set explicit H option and reaction option
set_explicit_h(train_args.explicit_h)
set_reaction(train_args.reaction, train_args.reaction_mode)


def predict_and_save(args: PredictArgs, train_args: TrainArgs, test_data: MoleculeDataset,
task_names: List[str], num_tasks: int, test_data_loader: MoleculeDataLoader, full_data: MoleculeDataset,
full_to_valid_indices: dict, models: List[MoleculeModel], scalers: List[List[StandardScaler]]):
"""
Function to predict with a model and save the predictions to file.
:param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
loading data and a model and making predictions.
:param train_args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
:param test_data: A :class:`~chemprop.data.MoleculeDataset` containing valid datapoints.
:param task_names: A list of task names.
:param num_tasks: Number of tasks.
:param test_data_loader: A :class:`~chemprop.data.MoleculeDataLoader` to load the test data.
:param full_data: A :class:`~chemprop.data.MoleculeDataset` containing all (valid and invalid) datapoints.
:param full_to_valid_indices: A dictionary dictionary mapping full to valid indices.
:param models: A list or generator object of :class:`~chemprop.models.MoleculeModel`\ s.
:param scalers: A list or generator object of :class:`~chemprop.features.scaler.StandardScaler` objects.
:return: A list of lists of target predictions.
"""
# Predict with each model individually and sum predictions
if args.dataset_type == 'multiclass':
sum_preds = np.zeros((len(test_data), num_tasks, args.multiclass_num_classes))
Expand All @@ -79,19 +131,10 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None) -> List[
else:
all_preds = np.zeros((len(test_data), num_tasks, len(args.checkpoint_paths)))

# Create data loader
test_data_loader = MoleculeDataLoader(
dataset=test_data,
batch_size=args.batch_size,
num_workers=args.num_workers
)

# Partial results for variance robust calculation.
print(f'Predicting with an ensemble of {len(args.checkpoint_paths)} models')
for index, checkpoint_path in enumerate(tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths))):
# Load model and scalers
model = load_checkpoint(checkpoint_path, device=args.device)
scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = load_scalers(checkpoint_path)
for index, (model, scaler_list) in enumerate(tqdm(zip(models, scalers), total=len(args.checkpoint_paths))):
scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = scaler_list

# Normalize features
if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
Expand Down Expand Up @@ -179,6 +222,40 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None) -> List[
writer.writerow(datapoint.row)

avg_preds = avg_preds.tolist()

return avg_preds


@timeit()
def make_predictions(args: PredictArgs, smiles: List[List[str]] = None,
model_objects: Tuple[PredictArgs, TrainArgs, List[MoleculeModel], List[StandardScaler], int, List[str]] = None) -> List[List[Optional[float]]]:
"""
Loads data and a trained model and uses the model to make predictions on the data.
If SMILES are provided, then makes predictions on smiles.
Otherwise makes predictions on :code:`args.test_data`.
:param args: A :class:`~chemprop.args.PredictArgs` object containing arguments for
loading data and a model and making predictions.
:param smiles: List of list of SMILES to make predictions on.
:param model_objects: Tuple of output of load_model function which can be called separately.
:return: A list of lists of target predictions.
"""
if model_objects:
args, train_args, models, scalers, num_tasks, task_names = model_objects
else:
args, train_args, models, scalers, num_tasks, task_names = load_model(args, generator=True)

set_features(args, train_args)

full_data, test_data, test_data_loader, full_to_valid_indices = load_data(args, smiles)

# Edge case if empty list of smiles is provided
if len(test_data) == 0:
return [None] * len(full_data)

avg_preds = predict_and_save(args, train_args, test_data, task_names, num_tasks, test_data_loader, full_data, full_to_valid_indices, models, scalers)

return avg_preds


Expand Down
66 changes: 66 additions & 0 deletions docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ Reaction
^^^^^^^^

As an alternative to molecule SMILES, Chemprop can also process atom-mapped reaction SMILES (see `Daylight manual <https://www.daylight.com/meetings/summerschool01/course/basics/smirks.html>`_ for details on reaction SMILES), which consist of three parts denoting reactants, agents and products, separated by ">". Use the option :code:`--reaction` to enable the input of reactions, which transforms the reactants and products of each reaction to the corresponding condensed graph of reaction and changes the initial atom and bond features to hold information from both the reactant and product (option :code:`--reaction_mode reac_prod`), or from the reactant and the difference upon reaction (option :code:`--reaction_mode reac_diff`, default) or from the product and the difference upon reaction (option :code:`--reaction_mode prod_diff`). In reaction mode, Chemprop thus concatenates information to each atomic and bond feature vector, for example, with option :code:`--reaction_mode reac_prod`, each atomic feature vector holds information on the state of the atom in the reactant (similar to default Chemprop), and concatenates information on the state of the atom in the product, so that the size of the D-MPNN increases slightly. Agents are discarded. Functions incompatible with a reaction as input (scaffold splitting and feature generation) are carried out on the reactants only. If the atom-mapped reaction SMILES contain mapped hydrogens, enable explicit hydrogens via :code:`--explicit_h`. Example of an atom-mapped reaction SMILES denoting the reaction of methanol to formaldehyde without hydrogens: :code:`[CH3:1][OH:2]>>[CH2:1]=[O:2]` and with hydrogens: :code:`[C:1]([H:3])([H:4])([H:5])[O:2][H:6]>>[C:1]([H:3])([H:4])=[O:2].[H:5][H:6]`. The reactions do not need to be balanced and can thus contain unmapped parts, for example leaving groups, if necessary.
For further details and benchmarking, as well as a citable reference, please see `DOI 10.33774/chemrxiv-2021-frfhz <https://doi.org/10.33774/chemrxiv-2021-frfhz>`_.

Pretraining
^^^^^^^^^^^
Expand Down Expand Up @@ -242,3 +243,68 @@ Web Interface
-------------

For those less familiar with the command line, Chemprop also includes a web interface which allows for basic training and predicting. See :ref:`web` for more details.

Within a python script
----------------------

Model training and predicting can also be embedded within a python script. To train a model, provide arguments as a list of strings (arguments are identical to command line mode),
parse the arguments, and then call :code:`chemprop.train.cross_validate()`::

import chemprop

arguments = [
'--data_path', 'data/tox21.csv',
'--dataset_type', 'classification',
'--save_dir', 'tox21_checkpoints'
]

args = chemprop.args.TrainArgs().parse_args(arguments)
mean_score, std_score = chemprop.train.cross_validate(args=args, train_func=chemprop.train.run_training)

For predicting with a given model, either a list of smiles or a csv file can be used as input. To use a csv file ::

import chemprop

arguments = [
'--test_path', 'data/tox21.csv',
'--preds_path', 'tox21_preds.csv',
'--checkpoint_dir', 'tox21_checkpoints'
]
args = chemprop.args.PredictArgs().parse_args(arguments)
preds = chemprop.train.make_predictions(args=args)

If you only want to use the predictions :code:`preds` within the script, and not save the file, set :code:`preds_path` to :code:`/dev/null`. To predict on a list of smiles, run::

import chemprop

smiles = [['CCC', 'CCCC', 'OCC']]
arguments = [
'--test_path', '/dev/null',
'--preds_path', '/dev/null',
'--checkpoint_dir', 'tox21_checkpoints'
]

args = chemprop.args.PredictArgs().parse_args(arguments)
preds = chemprop.train.make_predictions(args=args, smiles=smiles)

where the given :code:`test_path` will be discarded if a list of smiles is provided. If you want to predict multiple sets of molecules consecutively, it is more efficient to
only load the chemprop model once, and then predict with the preloaded model (instead of loading the model for every prediction)::

import chemprop

arguments = [
'--test_path', '/dev/null',
'--preds_path', '/dev/null',
'--checkpoint_dir', 'tox21_checkpoints'
]

args = chemprop.args.PredictArgs().parse_args(arguments)

model_objects = chemprop.train.load_model(args=args)
smiles = [['CCC', 'CCCC', 'OCC']]
preds = chemprop.train.make_predictions(args=args, smiles=smiles, model_objects=model_objects)

smiles = [['CCCC', 'CCCCC', 'COCC']]
preds = chemprop.train.make_predictions(args=args, smiles=smiles, model_objects=model_objects)

0 comments on commit 241e671

Please sign in to comment.