Merge pull request #197 from chemprop/spectra
Spectra Training
cjmcgill committed Sep 7, 2021
2 parents 241e671 + 03f816a commit 491085f
Showing 18 changed files with 1,483 additions and 79 deletions.
17 changes: 14 additions & 3 deletions README.md
@@ -36,6 +36,7 @@ Please see [aicures.mit.edu](https://aicures.mit.edu) and the associated [data G
* [RDKit 2D Features](#rdkit-2d-features)
* [Custom Features](#custom-features)
* [Atomic Features](#atomic-features)
* [Spectra](#spectra)
* [Reaction](#reaction)
* [Pretraining](#pretraining)
* [Missing target values](#missing-target-values)
@@ -153,7 +154,7 @@ To train a model, run:
```
chemprop_train --data_path <path> --dataset_type <type> --save_dir <dir>
```
where `<path>` is the path to a CSV file containing a dataset, `<type>` is either "classification" or "regression" depending on the type of the dataset, and `<dir>` is the directory where model checkpoints will be saved.
where `<path>` is the path to a CSV file containing a dataset, `<type>` is one of [classification, regression, multiclass, spectra] depending on the type of the dataset, and `<dir>` is the directory where model checkpoints will be saved.

For example:
```
@@ -177,6 +178,8 @@ Our code supports several methods of splitting data into train, validation, and

**Scaffold:** Alternatively, the data can be split by molecular scaffold so that the same scaffold never appears in more than one split. This can be specified by adding `--split_type scaffold_balanced`.

**Random With Repeated SMILES:** Some datasets have multiple entries with the same SMILES. To keep all entries that share a SMILES in the same split, use the argument `--split_type random_with_repeated_smiles`.
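
The idea behind this split type can be sketched in a few lines of Python. This is only an illustration of grouping rows by SMILES before splitting, not Chemprop's actual implementation; the function name `split_with_repeated_smiles` and its arguments are hypothetical.

```python
import random
from collections import defaultdict


def split_with_repeated_smiles(smiles, sizes=(0.8, 0.1, 0.1), seed=0):
    """Hypothetical helper: return (train, val, test) lists of row indices."""
    # Group row indices by SMILES so that duplicates always travel together.
    groups = defaultdict(list)
    for index, smi in enumerate(smiles):
        groups[smi].append(index)

    # Shuffle the unique SMILES rather than the individual rows.
    unique = sorted(groups)
    random.Random(seed).shuffle(unique)

    # Cut the shuffled unique SMILES into three contiguous chunks.
    train_end = int(sizes[0] * len(unique))
    val_end = int((sizes[0] + sizes[1]) * len(unique))
    chunks = (unique[:train_end], unique[train_end:val_end], unique[val_end:])

    # Expand each chunk back into row indices.
    return tuple([i for smi in chunk for i in groups[smi]] for chunk in chunks)
```

In this sketch the split fractions apply to unique SMILES, so the row counts per split can deviate from the requested fractions when some SMILES repeat often.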

**Separate val/test:** If you have separate data files you would like to use as the validation or test set, you can specify them with `--separate_val_path <val_path>` and/or `--separate_test_path <test_path>`. If both are provided, then the data specified by `--data_path` is used entirely as the training data. If only one separate path is provided, the `--data_path` data is split between train data and either val or test data, whichever is not provided separately.

Note: By default, both random and scaffold split the data into 80% train, 10% validation, and 10% test. This can be changed with `--split_sizes <train_frac> <val_frac> <test_frac>`. For example, the default setting is `--split_sizes 0.8 0.1 0.1`. Both also involve a random component and can be seeded with `--seed <seed>`. The default setting is `--seed 0`.
@@ -253,12 +256,20 @@ The bond-level features are concatenated with the bond feature vectors before th

Similar to molecule-, and atom-level features, the bond-level features are scaled by default. This can be disabled with the option `--no_bond_features_scaling`.

### Spectra

One of the dataset types that can be trained with Chemprop is "spectra". Spectra training differs from the other dataset types because it considers the predictions of all targets together. Each target column should hold the value of the spectrum at one specific position. The loss function for spectra is spectral information divergence (SID). Alternatively, Wasserstein distance (earthmover's distance) can be used as both the loss function and the metric with the arguments `--metric wasserstein --alternative_loss_function wasserstein`.
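
For orientation, the two measures can be written out for a single pair of spectra. This is a minimal sketch using the textbook definitions, assuming both spectra are strictly positive for SID, sum to 1, and sit on evenly spaced positions; Chemprop's own batched implementation may differ in details such as masking. The function names are illustrative.

```python
import math


def sid(pred, target):
    """Spectral information divergence: a symmetrized KL divergence."""
    return sum(p * math.log(p / t) + t * math.log(t / p)
               for p, t in zip(pred, target))


def wasserstein(pred, target):
    """1D earthmover's distance: summed absolute difference of the cumulative sums."""
    total = 0.0
    cum_pred = 0.0
    cum_target = 0.0
    for p, t in zip(pred, target):
        cum_pred += p
        cum_target += t
        total += abs(cum_pred - cum_target)
    return total
```

Both measures are zero for identical spectra and grow as the spectra diverge, which is why they are minimized during training.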

Spectra predictions are configured to return only positive values and are normalized so that each spectrum sums to 1. The activation that enforces positivity is an exponential function by default but can be set to a Softplus function with the argument `--spectra_activation <exp or softplus>`. Positivity is also enforced on the input targets: any target value below a floor value (default 1e-8), including negative values, is replaced with the floor. The floor can be changed with the argument `--spectra_target_floor <float>`.
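
The output handling described above can be sketched as follows. This is a simplified illustration for a single spectrum, not Chemprop's code; the helper names `to_spectrum` and `floor_targets` are hypothetical, and whether targets are renormalized after flooring is not stated here, so the sketch leaves them as they are.

```python
import math


def softplus(x):
    # log(1 + exp(x)), written to stay numerically stable for large |x|.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def to_spectrum(raw_outputs, activation="exp"):
    """Map unconstrained outputs to positive values that sum to 1."""
    func = math.exp if activation == "exp" else softplus
    positive = [func(x) for x in raw_outputs]
    total = sum(positive)
    return [value / total for value in positive]


def floor_targets(targets, floor=1e-8):
    """Replace target values below the floor (including negatives) with the floor."""
    return [max(value, floor) for value in targets]
```

Keeping every value strictly positive matters because SID takes logarithms of ratios between predicted and target values.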

In absorption spectra, the phase in which a spectrum was collected can create regions where data collection or prediction is unreliable. To exclude these regions, include paths to phase features for your data (`--phase_features_path <path>`) and a mask indicating the spectrum regions that are supported (`--spectra_phase_mask_path <path>`). The mask file is a `.csv` file with columns for the spectrum positions and rows for the phases, with column and row labels in the same order as they appear in the targets and features files.
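
One way to picture how the phase features and the mask fit together: the one-hot phase vector of a molecule selects one row of the mask, and that row marks which spectrum positions are kept. The sketch below assumes excluded positions are dropped and the remaining values renormalized; how Chemprop treats excluded positions internally is not described here, and the helper names are hypothetical.

```python
def molecule_mask(phase_one_hot, phase_mask):
    """Select the mask row for the phase flagged in the one-hot vector."""
    phase_index = phase_one_hot.index(1)
    return phase_mask[phase_index]


def masked_normalize(spectrum, mask):
    """Mark unsupported positions as None and renormalize the rest to sum to 1."""
    total = sum(value for value, keep in zip(spectrum, mask) if keep)
    return [value / total if keep else None
            for value, keep in zip(spectrum, mask)]


# Rows are phases, columns are spectrum positions (1 = supported region).
example_mask = [
    [1, 1, 0],
    [0, 1, 1],
]
second_phase = molecule_mask([0, 1], example_mask)
example_result = masked_normalize([0.2, 0.3, 0.5], second_phase)
```

Here the molecule belongs to the second phase, so the first spectrum position is excluded and the other two are rescaled to sum to 1.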

### Reaction

As an alternative to molecule SMILES, Chemprop can also process atom-mapped reaction SMILES (see [Daylight manual](https://www.daylight.com/meetings/summerschool01/course/basics/smirks.html) for details on reaction SMILES), which consist of three parts denoting reactants, agents and products, separated by ">". Use the option `--reaction` to enable the input of reactions, which transforms the reactants and products of each reaction to the corresponding condensed graph of reaction and changes the initial atom and bond features to hold information from both the reactant and product (option `--reaction_mode reac_prod`), or from the reactant and the difference upon reaction (option `--reaction_mode reac_diff`, default) or from the product and the difference upon reaction (option `--reaction_mode prod_diff`). In reaction mode, Chemprop thus concatenates information to each atomic and bond feature vector, for example, with option `--reaction_mode reac_prod`, each atomic feature vector holds information on the state of the atom in the reactant (similar to default Chemprop), and concatenates information on the state of the atom in the product, so that the size of the D-MPNN increases slightly. Agents are discarded. Functions incompatible with a reaction as input (scaffold splitting and feature generation) are carried out on the reactants only. If the atom-mapped reaction SMILES contain mapped hydrogens, enable explicit hydrogens via `--explicit_h`. Example of an atom-mapped reaction SMILES denoting the reaction of methanol to formaldehyde without hydrogens: `[CH3:1][OH:2]>>[CH2:1]=[O:2]` and with hydrogens: `[C:1]([H:3])([H:4])([H:5])[O:2][H:6]>>[C:1]([H:3])([H:4])=[O:2].[H:5][H:6]`. The reactions do not need to be balanced and can thus contain unmapped parts, for example leaving groups, if necessary.
For further details and benchmarking, as well as a citable reference, please see [DOI 10.33774/chemrxiv-2021-frfhz](https://doi.org/10.33774/chemrxiv-2021-frfhz).
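
To make the three-part structure of a reaction SMILES concrete, the sketch below splits one of the examples above on ">" and collects the atom-map numbers on each side. It uses only plain string handling, not RDKit or Chemprop, and the helper names are hypothetical.

```python
import re


def split_reaction_smiles(reaction_smiles):
    """Split 'reactants>agents>products' into its three parts."""
    reactants, agents, products = reaction_smiles.split(">")
    return reactants, agents, products


def atom_map_numbers(smiles):
    """Collect the atom-map numbers, e.g. the 1 in '[CH3:1]'."""
    return {int(number) for number in re.findall(r":(\d+)\]", smiles)}


reactants, agents, products = split_reaction_smiles(
    "[C:1]([H:3])([H:4])([H:5])[O:2][H:6]>>[C:1]([H:3])([H:4])=[O:2].[H:5][H:6]"
)
```

For this example the agents part is empty, and both sides carry the same six map numbers, which is what allows each reactant atom to be paired with its counterpart in the product.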

### Pretraining, With and Without Frozen Parameters
### Pretraining

Pretraining can be carried out using previously trained checkpoint files to set some or all of the initial values of a model for training. Additionally, some model parameters from the previous model can be frozen in place, so that they will not be updated during training.

@@ -312,7 +323,7 @@ If installed from source, `chemprop_predict` can be replaced with `python predic

### Epistemic Uncertainty

One method of obtaining the epistemic uncertainty of a prediction is to calculate the variance of an ensemble of models. To calculate these variances and write them as an additional column in the `--preds_path` file, use `--ensemble_variance`.
One method of obtaining the epistemic uncertainty of a prediction is to calculate the variance of an ensemble of models. To calculate these variances and write them as an additional column in the `--preds_path` file, use `--ensemble_variance`. If this flag is used with a spectra prediction, the column instead contains the average pairwise SID between the predictions of the different ensemble members.
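
As a rough illustration of the spectra case, the average pairwise SID over an ensemble can be computed as below. This is a sketch for one molecule using the textbook SID definition, assuming each prediction is strictly positive and normalized; the function names are illustrative rather than Chemprop's.

```python
import math
from itertools import combinations


def sid(a, b):
    """Spectral information divergence between two positive, normalized spectra."""
    return sum(x * math.log(x / y) + y * math.log(y / x) for x, y in zip(a, b))


def ensemble_sid(predictions):
    """Average SID over every unordered pair of ensemble predictions."""
    pairs = list(combinations(predictions, 2))
    return sum(sid(a, b) for a, b in pairs) / len(pairs)
```

A value of zero means all ensemble members predicted the same spectrum; larger values indicate more disagreement between members.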

## Encode Fingerprint Latent Representation

36 changes: 26 additions & 10 deletions chemprop/args.py
@@ -13,7 +13,7 @@
from chemprop.features import get_available_features_generators


Metric = Literal['auc', 'prc-auc', 'rmse', 'mae', 'mse', 'r2', 'accuracy', 'cross_entropy', 'binary_cross_entropy']
Metric = Literal['auc', 'prc-auc', 'rmse', 'mae', 'mse', 'r2', 'accuracy', 'cross_entropy', 'binary_cross_entropy', 'sid', 'wasserstein']


def get_checkpoint_paths(checkpoint_path: Optional[str] = None,
@@ -82,6 +82,8 @@ class CommonArgs(Tap):
"""Method(s) of generating additional features."""
features_path: List[str] = None
"""Path(s) to features to use in FNN (instead of features_generator)."""
phase_features_path: str = None
"""Path to features used to indicate the phase of the data in one-hot vector form. Used in spectra datatype."""
no_features_scaling: bool = False
"""Turn off scaling of features."""
max_data_size: int = None
@@ -225,19 +227,21 @@ class TrainArgs(CommonArgs):
"""
ignore_columns: List[str] = None
"""Name of the columns to ignore when :code:`target_columns` is not provided."""
dataset_type: Literal['regression', 'classification', 'multiclass']
dataset_type: Literal['regression', 'classification', 'multiclass', 'spectra']
"""Type of dataset. This determines the loss function used during training."""
multiclass_num_classes: int = 3
"""Number of classes when running multiclass classification."""
separate_val_path: str = None
"""Path to separate val set, optional."""
separate_test_path: str = None
"""Path to separate test set, optional."""
spectra_phase_mask_path: str = None
"""Path to a file containing a phase mask array, used for excluding particular regions in spectra predictions."""
data_weights_path: str = None
"""Path to weights for each molecule in the training data, affecting the relative weight of molecules in the loss function"""
target_weights: List[float] = None
"""Weights associated with each target, affecting the relative weight of targets in the loss function. Must match the number of target columns."""
split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'cv', 'cv-no-test', 'index_predetermined'] = 'random'
split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'cv', 'cv-no-test', 'index_predetermined', 'random_with_repeated_smiles'] = 'random'
"""Method of splitting the data into train/val/test."""
split_sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1)
"""Split proportions for train/validation/test sets."""
@@ -263,7 +267,7 @@ class TrainArgs(CommonArgs):
metric: Metric = None
"""
Metric to use during evaluation. It is also used with the validation set for early stopping.
Defaults to "auc" for classification and "rmse" for regression.
Defaults to "auc" for classification, "rmse" for regression, and "sid" for spectra.
"""
extra_metrics: List[Metric] = []
"""Additional metrics to use to evaluate the model. Not used for early stopping."""
@@ -324,6 +328,10 @@ class TrainArgs(CommonArgs):
"""Path to file with features for separate val set."""
separate_test_features_path: List[str] = None
"""Path to file with features for separate test set."""
separate_val_phase_features_path: str = None
"""Path to file with phase features for separate val set."""
separate_test_phase_features_path: str = None
"""Path to file with phase features for separate test set."""
separate_val_atom_descriptors_path: str = None
"""Path to file with extra atom descriptors for separate val set."""
separate_test_atom_descriptors_path: str = None
@@ -377,7 +385,12 @@ class TrainArgs(CommonArgs):
"""Maximum magnitude of gradient during training."""
class_balance: bool = False
"""Trains with an equal number of positives and negatives in each batch."""

spectra_activation: Literal['exp', 'softplus'] = 'exp'
"""Indicates which function to use in dataset_type spectra training to constrain outputs to be positive."""
spectra_target_floor: float = 1e-8
"""Values in targets for dataset type spectra that fall below this value are replaced with it. Intended to be a small positive number used to enforce positive values."""
alternative_loss_function: Literal['wasserstein'] = None
"""Option to replace the default loss function with an alternative. Currently only applied for the spectra dataset type with wasserstein loss."""
overwrite_default_atom_features: bool = False
"""
Overwrites the default atom descriptors with the new ones instead of concatenating them.
@@ -419,12 +432,12 @@ def metrics(self) -> List[str]:
@property
def minimize_score(self) -> bool:
"""Whether the model should try to minimize the score metric or maximize it."""
return self.metric in {'rmse', 'mae', 'mse', 'cross_entropy', 'binary_cross_entropy'}
return self.metric in {'rmse', 'mae', 'mse', 'cross_entropy', 'binary_cross_entropy', 'sid', 'wasserstein'}

@property
def use_input_features(self) -> bool:
"""Whether the model is using additional molecule-level features."""
return self.features_generator is not None or self.features_path is not None
return self.features_generator is not None or self.features_path is not None or self.phase_features_path is not None

@property
def num_lrs(self) -> int:
@@ -518,6 +531,8 @@ def process_args(self) -> None:
self.metric = 'auc'
elif self.dataset_type == 'multiclass':
self.metric = 'cross_entropy'
elif self.dataset_type == 'spectra':
self.metric = 'sid'
else:
self.metric = 'rmse'

@@ -526,9 +541,10 @@ def process_args(self) -> None:
f'Please only include it once.')

for metric in self.metrics:
if not ((self.dataset_type == 'classification' and metric in ['auc', 'prc-auc', 'accuracy', 'binary_cross_entropy']) or
(self.dataset_type == 'regression' and metric in ['rmse', 'mae', 'mse', 'r2']) or
(self.dataset_type == 'multiclass' and metric in ['cross_entropy', 'accuracy'])):
if not any([(self.dataset_type == 'classification' and metric in ['auc', 'prc-auc', 'accuracy', 'binary_cross_entropy']),
(self.dataset_type == 'regression' and metric in ['rmse', 'mae', 'mse', 'r2']),
(self.dataset_type == 'multiclass' and metric in ['cross_entropy', 'accuracy']),
(self.dataset_type == 'spectra' and metric in ['sid', 'wasserstein'])]):
raise ValueError(f'Metric "{metric}" invalid for dataset type "{self.dataset_type}".')

# Validate class balance
14 changes: 14 additions & 0 deletions chemprop/data/data.py
@@ -61,6 +61,7 @@ def __init__(self,
data_weight: float = 1,
features: np.ndarray = None,
features_generator: List[str] = None,
phase_features: List[float] = None,
atom_features: np.ndarray = None,
atom_descriptors: np.ndarray = None,
bond_features: np.ndarray = None,
@@ -73,6 +74,7 @@ def __init__(self,
:param data_weight: Weighting of the datapoint for the loss function.
:param features: A numpy array containing additional features (e.g., Morgan fingerprint).
:param features_generator: A list of features generators to use.
:param phase_features: A one-hot vector indicating the phase of the data, as used in spectra data.
:param atom_descriptors: A numpy array containing additional atom descriptors to featurize the molecule
:param bond_features: A numpy array containing additional bond features to featurize the molecule
:param overwrite_default_atom_features: Boolean to overwrite default atom features by atom_features
@@ -88,6 +90,7 @@ def __init__(self,
self.data_weight = data_weight
self.features = features
self.features_generator = features_generator
self.phase_features = phase_features
self.atom_descriptors = atom_descriptors
self.atom_features = atom_features
self.bond_features = bond_features
@@ -320,6 +323,17 @@ def features(self) -> List[np.ndarray]:

return [d.features for d in self._data]

def phase_features(self) -> List[np.ndarray]:
"""
Returns the phase features associated with each molecule (if they exist).
:return: A list of 1D numpy arrays containing the phase features for each molecule or None if there are no features.
"""
if len(self._data) == 0 or self._data[0].phase_features is None:
return None

return [d.phase_features for d in self._data]

def atom_features(self) -> List[np.ndarray]:
"""
Returns the atom descriptors associated with each molecule (if they exist).
