Skip to content

Commit

Permalink
Cross-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Jul 30, 2020
1 parent 71e5e86 commit 8a2ad6c
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 5 deletions.
2 changes: 1 addition & 1 deletion chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class TrainArgs(CommonArgs):
"""Path to separate val set, optional."""
separate_test_path: str = None
"""Path to separate test set, optional."""
split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'index_predetermined'] = 'random'
split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'cv', 'index_predetermined'] = 'random'
"""Method of splitting the data into train/val/test."""
split_sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1)
"""Split proportions for train/validation/test sets."""
Expand Down
26 changes: 25 additions & 1 deletion chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def split_data(data: MoleculeDataset,
split_type: str = 'random',
sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
seed: int = 0,
num_folds: int = 1,
args: TrainArgs = None,
logger: Logger = None) -> Tuple[MoleculeDataset,
MoleculeDataset,
Expand All @@ -259,6 +260,7 @@ def split_data(data: MoleculeDataset,
:param split_type: Split type.
:param sizes: A length-3 tuple with the proportions of data in the train, validation, and test sets.
:param seed: The random seed to use before shuffling data.
:param num_folds: Number of folds to create (only needed for "cv" split type).
:param args: A :class:`~chemprop.args.TrainArgs` object.
:param logger: A logger for recording output.
:return: A tuple of :class:`~chemprop.data.MoleculeDataset`\ s containing the train,
Expand Down Expand Up @@ -286,7 +288,29 @@ def split_data(data: MoleculeDataset,
data_split.append([data[i] for i in split_indices])
train, val, test = tuple(data_split)
return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)


elif split_type == 'cv':
if num_folds <= 1 or num_folds > len(data):
raise ValueError('Number of folds for cross-validation must be between 2 and len(data), inclusive.')

random = Random(0)

indices = np.repeat(np.arange(num_folds), 1 + len(data) // num_folds)[:len(data)]
random.shuffle(indices)
test_index = seed % num_folds
val_index = (seed + 1) % num_folds

train, val, test = [], [], []
for d, index in zip(data, indices):
if index == test_index:
test.append(d)
elif index == val_index:
val.append(d)
else:
train.append(d)

return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)

elif split_type == 'index_predetermined':
split_indices = args.crossval_index_sets[args.seed]

Expand Down
1 change: 1 addition & 0 deletions chemprop/sklearn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def run_sklearn(args: SklearnTrainArgs, logger: Logger = None) -> List[float]:
split_type=args.split_type,
seed=args.seed,
sizes=args.split_sizes,
num_folds=args.num_folds,
args=args
)

Expand Down
6 changes: 6 additions & 0 deletions chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Tuple

import numpy as np
import pandas as pd

from .run_training import run_training
from chemprop.args import TrainArgs
Expand Down Expand Up @@ -79,6 +80,11 @@ def cross_validate(args: TrainArgs) -> Tuple[float, float]:
mean, std = np.nanmean(task_scores), np.nanstd(task_scores)
writer.writerow([task_name, mean, std] + task_scores.tolist())

# Merge and save test preds if doing cross-validation
all_preds = pd.concat([pd.read_csv(os.path.join(save_dir, f'fold_{fold_num}', 'test_preds.csv'))
for fold_num in range(args.num_folds)])
all_preds.to_csv(os.path.join(save_dir, 'test_preds.csv'), index=False)

return mean_score, std_score


Expand Down
16 changes: 13 additions & 3 deletions chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

import numpy as np
import pandas as pd
from tensorboardX import SummaryWriter
import torch
from tqdm import trange
Expand Down Expand Up @@ -66,11 +67,11 @@ def run_training(args: TrainArgs, logger: Logger = None) -> List[float]:
if args.separate_val_path and args.separate_test_path:
train_data = data
elif args.separate_val_path:
train_data, _, test_data = split_data(data=data, split_type=args.split_type, sizes=(0.8, 0.0, 0.2), seed=args.seed, args=args, logger=logger)
train_data, _, test_data = split_data(data=data, split_type=args.split_type, sizes=(0.8, 0.0, 0.2), seed=args.seed, num_folds=args.num_folds, args=args, logger=logger)
elif args.separate_test_path:
train_data, val_data, _ = split_data(data=data, split_type=args.split_type, sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger)
train_data, val_data, _ = split_data(data=data, split_type=args.split_type, sizes=(0.8, 0.2, 0.0), seed=args.seed, num_folds=args.num_folds, args=args, logger=logger)
else:
train_data, val_data, test_data = split_data(data=data, split_type=args.split_type, sizes=args.split_sizes, seed=args.seed, args=args, logger=logger)
train_data, val_data, test_data = split_data(data=data, split_type=args.split_type, sizes=args.split_sizes, seed=args.seed, num_folds=args.num_folds, args=args, logger=logger)

if args.dataset_type == 'classification':
class_sizes = get_class_sizes(data)
Expand Down Expand Up @@ -289,4 +290,13 @@ def run_training(args: TrainArgs, logger: Logger = None) -> List[float]:
for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
info(f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}')

# Save test preds if doing cross-validation
if args.split_type == 'cv':
test_preds_dataframe = pd.DataFrame(data={'smiles': test_data.smiles()})

for i, task_name in enumerate(args.task_names):
test_preds_dataframe[task_name] = [pred[i] for pred in avg_test_preds]

test_preds_dataframe.to_csv(os.path.join(args.save_dir, 'test_preds.csv'), index=False)

return ensemble_scores

0 comments on commit 8a2ad6c

Please sign in to comment.