Merge pull request #230 from chemprop/split_key_molecule

Molecule Index for Splitting Functions and Assorted Other Splitting Changes

cjmcgill committed Jan 10, 2022
2 parents 651fa43 + 5b1bdad commit 02433d2
Showing 6 changed files with 94 additions and 35 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -182,15 +182,17 @@ Our code supports several methods of splitting data into train, validation, and

**Separate val/test:** If you have separate data files you would like to use as the validation or test set, you can specify them with `--separate_val_path <val_path>` and/or `--separate_test_path <test_path>`. If both are provided, then the data specified by `--data_path` is used entirely as the training data. If only one separate path is provided, the `--data_path` data is split between train data and either val or test data, whichever is not provided separately.

Note: By default, both random and scaffold split the data into 80% train, 10% validation, and 10% test. This can be changed with `--split_sizes <train_frac> <val_frac> <test_frac>`. For example, the default setting is `--split_sizes 0.8 0.1 0.1`. Both also involve a random component and can be seeded with `--seed <seed>`. The default setting is `--seed 0`.
When data contains multiple molecules per datapoint, scaffold and repeated-SMILES splitting will constrain the split based on only one of the molecules. The key molecule can be chosen with the argument `--split_key_molecule <int>`; the default index of 0 indicates the first molecule.

By default, both random and scaffold split the data into 80% train, 10% validation, and 10% test. This can be changed with `--split_sizes <train_frac> <val_frac> <test_frac>`. The default setting is `--split_sizes 0.8 0.1 0.1`. If a separate validation set or test set is provided, the split defaults to 80%-20%. Splitting involves a random component and can be seeded with `--seed <seed>`. The default setting is `--seed 0`.
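
A minimal sketch of how the flags above combine, assuming a hypothetical `reactions.csv` with two SMILES columns and one target column (the file name and column layout are illustrative only):

```python
from chemprop.args import TrainArgs

# Sketch only: 'reactions.csv' is a hypothetical multi-molecule dataset.
args = TrainArgs().parse_args([
    '--data_path', 'reactions.csv',
    '--dataset_type', 'regression',
    '--number_of_molecules', '2',
    '--split_type', 'scaffold_balanced',
    '--split_key_molecule', '1',          # constrain the split on the second molecule
    '--split_sizes', '0.8', '0.1', '0.1',
    '--seed', '0',
])
print(args.split_sizes, args.split_key_molecule)
```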

### Cross validation

k-fold cross-validation can be run by specifying `--num_folds <k>`. The default is `--num_folds 1`.
k-fold cross-validation can be run by specifying `--num_folds <k>`. The default is `--num_folds 1`. Each trained model will have different data splits. The reported test score will be the average of the metrics from each fold.

### Ensembling

To train an ensemble, specify the number of models in the ensemble with `--ensemble_size <n>`. The default is `--ensemble_size 1`.
To train an ensemble, specify the number of models in the ensemble with `--ensemble_size <n>`. The default is `--ensemble_size 1`. Each trained model within the ensemble will share data splits. The reported test score for one ensemble is the metric applied to the averaged prediction across the models. Ensembling and cross-validation can be used at the same time.
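
A small sketch (with made-up numbers) of the difference between the two reported scores: cross-validation averages the per-fold metrics, while an ensemble averages the predictions before the metric is applied once:

```python
import numpy as np

# Cross-validation: each fold has its own split; average the per-fold test metrics.
fold_rmses = [0.52, 0.49, 0.55]                 # hypothetical per-fold RMSEs
cv_score = float(np.mean(fold_rmses))           # reported cross-validation score

# Ensembling: models share splits; average predictions, then apply the metric.
preds = np.array([[1.1, 2.0, 2.9],              # hypothetical predictions, model 1
                  [0.9, 2.2, 3.1]])             # hypothetical predictions, model 2
targets = np.array([1.0, 2.1, 3.0])
ensemble_pred = preds.mean(axis=0)
ensemble_rmse = float(np.sqrt(np.mean((ensemble_pred - targets) ** 2)))
```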

### Hyperparameter Optimization

84 changes: 66 additions & 18 deletions chemprop/args.py
@@ -2,7 +2,7 @@
import os
from tempfile import TemporaryDirectory
import pickle
from typing import List, Optional, Tuple
from typing import List, Optional
from typing_extensions import Literal

import torch
@@ -243,8 +243,10 @@ class TrainArgs(CommonArgs):
"""Weights associated with each target, affecting the relative weight of targets in the loss function. Must match the number of target columns."""
split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'cv', 'cv-no-test', 'index_predetermined', 'random_with_repeated_smiles'] = 'random'
"""Method of splitting the data into train/val/test."""
split_sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1)
split_sizes: List[float] = None
"""Split proportions for train/validation/test sets."""
split_key_molecule: int = 0
"""The index of the key molecule used for splitting when multiple molecules are present and constrained split_type is used, like scaffold_balanced or random_with_repeated_smiles."""
num_folds: int = 1
"""Number of folds when performing cross validation."""
folds_file: str = None
@@ -586,27 +588,73 @@ def process_args(self) -> None:
self._crossval_index_sets = pickle.load(rf)
self.num_folds = len(self.crossval_index_sets)
self.seed = 0

# Validate split size entry and set default values
if self.split_sizes is None:
if self.separate_val_path is None and self.separate_test_path is None: # separate data paths are not provided
self.split_sizes = (0.8, 0.1, 0.1)
elif self.separate_val_path is not None and self.separate_test_path is None: # separate val path only
self.split_sizes = (0.8, 0., 0.2)
elif self.separate_val_path is None and self.separate_test_path is not None: # separate test path only
self.split_sizes = (0.8, 0.2, 0.)
else: # both separate data paths are provided
self.split_sizes = (1., 0., 0.)

else:
if sum(self.split_sizes) != 1.:
raise ValueError(f'Provided split sizes of {self.split_sizes} do not sum to 1.')

if len(self.split_sizes) not in [2, 3]:
raise ValueError(f'Two or three values should be provided for train/val/test split sizes. Instead received {len(self.split_sizes)} value(s).')

if self.separate_val_path is None and self.separate_test_path is None: # separate data paths are not provided
if len(self.split_sizes) != 3:
raise ValueError(f'Three values should be provided for train/val/test split sizes. Instead received {len(self.split_sizes)} value(s).')
if 0. in self.split_sizes:
raise ValueError(f'Provided split sizes must be nonzero if no separate data files are provided. Received split sizes of {self.split_sizes}.')

elif self.separate_val_path is not None and self.separate_test_path is None: # separate val path only
if len(self.split_sizes) == 2: # allow input of just 2 values
self.split_sizes = (self.split_sizes[0], 0., self.split_sizes[1])
if self.split_sizes[0] == 0.:
raise ValueError('Provided split size for train split must be nonzero.')
if self.split_sizes[1] != 0.:
raise ValueError('Provided split size for validation split must be 0 because validation set is provided separately.')
if self.split_sizes[2] == 0.:
raise ValueError('Provided split size for test split must be nonzero.')

elif self.separate_val_path is None and self.separate_test_path is not None: # separate test path only
if len(self.split_sizes) == 2: # allow input of just 2 values
self.split_sizes = (self.split_sizes[0], self.split_sizes[1], 0.)
if self.split_sizes[0] == 0.:
raise ValueError('Provided split size for train split must be nonzero.')
if self.split_sizes[1] == 0.:
raise ValueError('Provided split size for validation split must be nonzero.')
if self.split_sizes[2] != 0.:
raise ValueError('Provided split size for test split must be 0 because test set is provided separately.')


else: # both separate data paths are provided
if self.split_sizes != (1., 0., 0.):
raise ValueError(f'Separate data paths were provided for val and test splits. Split sizes should not also be provided.')

# Test settings
if self.test:
self.epochs = 0

# Validate extra atom or bond features for separate validation or test set
if self.separate_val_path is not None and self.atom_descriptors is not None \
and self.separate_val_atom_descriptors_path is None:
raise ValueError('Atom descriptors are required for the separate validation set.')

if self.separate_test_path is not None and self.atom_descriptors is not None \
and self.separate_test_atom_descriptors_path is None:
raise ValueError('Atom descriptors are required for the separate test set.')

if self.separate_val_path is not None and self.bond_features_path is not None \
and self.separate_val_bond_features_path is None:
raise ValueError('Bond descriptors are required for the separate validation set.')

if self.separate_test_path is not None and self.bond_features_path is not None \
and self.separate_test_bond_features_path is None:
raise ValueError('Bond descriptors are required for the separate test set.')
# Validate features are provided for separate validation or test set for each of the kinds of additional features
for (features_argument, base_features_path, val_features_path, test_features_path) in [
('`--features_path`', self.features_path, self.separate_val_features_path, self.separate_test_features_path),
('`--phase_features_path`', self.phase_features_path, self.separate_val_phase_features_path, self.separate_test_phase_features_path),
('`--atom_descriptors_path`', self.atom_descriptors_path, self.separate_val_atom_descriptors_path, self.separate_test_atom_descriptors_path),
('`--bond_features_path`', self.bond_features_path, self.separate_val_bond_features_path, self.separate_test_bond_features_path)
]:
if base_features_path is not None:
if self.separate_val_path is not None and val_features_path is None:
raise ValueError(f'Additional features were provided using the argument {features_argument}. The same kinds of features must be provided for the separate validation set.')
if self.separate_test_path is not None and test_features_path is None:
raise ValueError(f'Additional features were provided using the argument {features_argument}. The same kinds of features must be provided for the separate test set.')


# validate extra atom descriptor options
if self.overwrite_default_atom_features and self.atom_descriptors != 'feature':
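For reference, a short sketch that mirrors the defaults assigned above when `--split_sizes` is omitted (the helper name `default_split_sizes` is illustrative, not part of the codebase):

```python
# Illustrative only: mirrors the if/elif chain in TrainArgs.process_args.
def default_split_sizes(separate_val: bool, separate_test: bool):
    if not separate_val and not separate_test:
        return (0.8, 0.1, 0.1)   # no separate files: standard 80/10/10 split
    if separate_val and not separate_test:
        return (0.8, 0.0, 0.2)   # separate val set: split the rest 80/20 train/test
    if not separate_val and separate_test:
        return (0.8, 0.2, 0.0)   # separate test set: split the rest 80/20 train/val
    return (1.0, 0.0, 0.0)       # both provided: all of --data_path is training data

assert default_split_sizes(False, False) == (0.8, 0.1, 0.1)
assert default_split_sizes(True, True) == (1.0, 0.0, 0.0)
```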
8 changes: 4 additions & 4 deletions chemprop/data/scaffold.py
@@ -53,6 +53,7 @@ def scaffold_to_smiles(mols: Union[List[str], List[Chem.Mol], List[Tuple[Chem.Mo
def scaffold_split(data: MoleculeDataset,
sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
balanced: bool = False,
key_molecule_index: int = 0,
seed: int = 0,
logger: logging.Logger = None) -> Tuple[MoleculeDataset,
MoleculeDataset,
@@ -63,23 +64,22 @@ def scaffold_split(data: MoleculeDataset,
:param data: A :class:`MoleculeDataset`.
:param sizes: A length-3 tuple with the proportions of data in the train, validation, and test sets.
:param balanced: Whether to balance the sizes of scaffolds in each set rather than putting the smallest in test set.
:param key_molecule_index: For data with multiple molecules, this sets which molecule will be considered during splitting.
:param seed: Random seed for shuffling when doing balanced splitting.
:param logger: A logger for recording output.
:return: A tuple of :class:`~chemprop.data.MoleculeDataset`\ s containing the train,
validation, and test splits of the data.
"""
assert sum(sizes) == 1

if data.number_of_molecules > 1:
raise ValueError('Cannot perform a scaffold split with more than one molecule per datapoint.')

# Split
train_size, val_size, test_size = sizes[0] * len(data), sizes[1] * len(data), sizes[2] * len(data)
train, val, test = [], [], []
train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0

# Map from scaffold to index in the data
scaffold_to_indices = scaffold_to_smiles(data.mols(flatten=True), use_indices=True)
key_mols = [m[key_molecule_index] for m in data.mols(flatten=False)]
scaffold_to_indices = scaffold_to_smiles(key_mols, use_indices=True)

# Seed randomness
random = Random(seed)
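A minimal usage sketch of the updated signature, with made-up SMILES pairs, constraining the scaffold split on the second molecule of each datapoint (dataset contents are illustrative only):

```python
from chemprop.data import MoleculeDataset, MoleculeDatapoint
from chemprop.data.scaffold import scaffold_split

# Hypothetical two-molecule datapoints; only the second molecule (index 1)
# is considered when grouping by Bemis-Murcko scaffold.
pairs = [['CCO', 'c1ccccc1O'], ['CCN', 'c1ccccc1N'], ['CCC', 'C1CCCCC1']]
data = MoleculeDataset([MoleculeDatapoint(smiles=s, targets=[0.0]) for s in pairs])

train, val, test = scaffold_split(data, sizes=(0.8, 0.1, 0.1), balanced=True,
                                  key_molecule_index=1, seed=0)
print(len(train), len(val), len(test))
```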
11 changes: 8 additions & 3 deletions chemprop/data/utils.py
@@ -386,6 +386,7 @@ def get_data_from_smiles(smiles: List[List[str]],
def split_data(data: MoleculeDataset,
split_type: str = 'random',
sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
key_molecule_index: int = 0,
seed: int = 0,
num_folds: int = 1,
args: TrainArgs = None,
@@ -398,6 +399,7 @@ def split_data(data: MoleculeDataset,
:param data: A :class:`~chemprop.data.MoleculeDataset`.
:param split_type: Split type.
:param sizes: A length-3 tuple with the proportions of data in the train, validation, and test sets.
:param key_molecule_index: For data with multiple molecules, this sets which molecule will be considered during splitting.
:param seed: The random seed to use before shuffling data.
:param num_folds: Number of folds to create (only needed for "cv" split type).
:param args: A :class:`~chemprop.args.TrainArgs` object.
@@ -415,6 +417,9 @@ def split_data(data: MoleculeDataset,
args.folds_file, args.val_fold_index, args.test_fold_index
else:
folds_file = val_fold_index = test_fold_index = None

if key_molecule_index >= args.number_of_molecules:
raise ValueError('The index provided with the argument `--split_key_molecule` must be less than the number of molecules. Note that this index begins with 0 for the first molecule. ')

if split_type == 'crossval':
index_set = args.crossval_index_sets[args.seed]
@@ -501,12 +506,12 @@ def split_data(data: MoleculeDataset,
return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)

elif split_type == 'scaffold_balanced':
return scaffold_split(data, sizes=sizes, balanced=True, seed=seed, logger=logger)
return scaffold_split(data, sizes=sizes, balanced=True, key_molecule_index=key_molecule_index, seed=seed, logger=logger)

elif split_type == 'random_with_repeated_smiles': # Use to constrain data with the same smiles go in the same split. Considers first molecule only.
elif split_type == 'random_with_repeated_smiles': # Use to constrain datapoints with the same smiles to go in the same split.
smiles_dict=defaultdict(set)
for i,smiles in enumerate(data.smiles()):
smiles_dict[smiles[0]].add(i)
smiles_dict[smiles[key_molecule_index]].add(i)
index_sets=list(smiles_dict.values())
random.seed(seed)
random.shuffle(index_sets)
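A small self-contained sketch (made-up SMILES) of the grouping step above: datapoints whose key molecule shares a SMILES string always end up in the same index set, and therefore in the same split:

```python
from collections import defaultdict

# Hypothetical multi-molecule SMILES lists; the key molecule is at index 0.
smiles_lists = [['CCO', 'c1ccccc1O'],
                ['CCO', 'c1ccccc1N'],   # repeats the key molecule of the first entry
                ['CCN', 'C1CCCCC1']]
key_molecule_index = 0

smiles_dict = defaultdict(set)
for i, smiles in enumerate(smiles_lists):
    smiles_dict[smiles[key_molecule_index]].add(i)

print(list(smiles_dict.values()))  # [{0, 1}, {2}]: entries 0 and 1 stay together
```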
7 changes: 5 additions & 2 deletions chemprop/train/run_training.py
@@ -70,15 +70,17 @@ def run_training(args: TrainArgs,
elif args.separate_val_path:
train_data, _, test_data = split_data(data=data,
split_type=args.split_type,
sizes=(0.8, 0.0, 0.2),
sizes=args.split_sizes,
key_molecule_index=args.split_key_molecule,
seed=args.seed,
num_folds=args.num_folds,
args=args,
logger=logger)
elif args.separate_test_path:
train_data, val_data, _ = split_data(data=data,
split_type=args.split_type,
sizes=(0.8, 0.2, 0.0),
sizes=args.split_sizes,
key_molecule_index=args.split_key_molecule,
seed=args.seed,
num_folds=args.num_folds,
args=args,
@@ -87,6 +89,7 @@
train_data, val_data, test_data = split_data(data=data,
split_type=args.split_type,
sizes=args.split_sizes,
key_molecule_index=args.split_key_molecule,
seed=args.seed,
num_folds=args.num_folds,
args=args,
11 changes: 6 additions & 5 deletions scripts/create_crossval_splits.py
@@ -22,18 +22,19 @@ class Args(Tap):
val_folds_per_test: int = 3 # Number of val folds
time_folds_per_train_set: int = 3 # X:1:1 train:val:test for time split sliding window
smiles_columns: List[str] = None # columns in CSV dataset file containing SMILES
split_key_molecule: int = 0 # index of the molecule to use for splitting in multi-molecule data


def split_indices(all_indices: List[int],
num_folds: int,
scaffold: bool = False,
split_key_molecule: int = 0,
data: MoleculeDataset = None,
shuffle: bool = True) -> List[List[int]]:
num_data = len(all_indices)
if scaffold:
if data.number_of_molecules > 1:
raise ValueError('Cannot perform a scaffold split with more than one molecule per datapoint.')
scaffold_to_indices = scaffold_to_smiles(data.mols(flatten=True), use_indices=True)
key_mols = [m[split_key_molecule] for m in data.mols(flatten=False)]
scaffold_to_indices = scaffold_to_smiles(key_mols, use_indices=True)
index_sets = sorted(list(scaffold_to_indices.values()),
key=lambda index_set: len(index_set),
reverse=True)
@@ -68,7 +69,7 @@ def create_time_splits(args: Args):
subset_data = MoleculeDataset(data[begin:end])
fold_indices['random'].append(split_indices(deepcopy(subset_indices), args.time_folds_per_train_set + 2))
fold_indices['scaffold'].append(
split_indices(subset_indices, args.time_folds_per_train_set + 2, scaffold=True, data=subset_data))
split_indices(subset_indices, args.time_folds_per_train_set + 2, scaffold=True, split_key_molecule=args.split_key_molecule, data=subset_data))
fold_indices['time'].append(split_indices(subset_indices, args.time_folds_per_train_set + 2, shuffle=False))
for split_type in ['random', 'scaffold', 'time']:
all_splits = []
@@ -96,7 +97,7 @@ def create_crossval_splits(args: Args):
fold_indices = split_indices(all_indices, args.num_folds, scaffold=False)
elif args.split_type == 'scaffold':
all_indices = list(range(num_data))
fold_indices = split_indices(all_indices, args.num_folds, scaffold=True, data=data)
fold_indices = split_indices(all_indices, args.num_folds, scaffold=True, split_key_molecule=args.split_key_molecule, data=data)
else:
raise ValueError
random.shuffle(fold_indices)
