Skip to content

Commit

Permalink
Only load data once during cross-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Aug 24, 2020
1 parent 46b9f64 commit befaabc
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 41 deletions.
16 changes: 4 additions & 12 deletions chemprop/sklearn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,15 @@ def multi_task_sklearn(model: Union[RandomForestRegressor, RandomForestClassifie
return scores


def run_sklearn(args: SklearnTrainArgs, logger: Logger = None) -> Dict[str, List[float]]:
def run_sklearn(args: SklearnTrainArgs,
data: MoleculeDataset,
logger: Logger = None) -> Dict[str, List[float]]:
"""
Loads data, trains a scikit-learn model, and returns test scores for the model checkpoint with the highest validation score.
:param args: A :class:`~chemprop.args.SklearnTrainArgs` object containing arguments for
loading data and training the scikit-learn model.
:param data: A :class:`~chemprop.data.MoleculeDataset` containing the data.
:param logger: A logger to record output.
:return: A dictionary mapping each metric in :code:`metrics` to a list of values for each task.
"""
Expand All @@ -175,17 +178,6 @@ def run_sklearn(args: SklearnTrainArgs, logger: Logger = None) -> Dict[str, List
else:
debug = info = print

debug(pformat(vars(args)))

debug('Loading data')
data = get_data(path=args.data_path, smiles_column=args.smiles_column, target_columns=args.target_columns)
args.task_names = get_task_names(
path=args.data_path,
smiles_column=args.smiles_column,
target_columns=args.target_columns,
ignore_columns=args.ignore_columns
)

if args.model_type == 'svm' and data.num_tasks() != 1:
raise ValueError(f'SVM can only handle single-task data but found {data.num_tasks()} tasks')

Expand Down
38 changes: 32 additions & 6 deletions chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@
import csv
from logging import Logger
import os
import sys
from typing import Callable, Dict, List, Tuple, Union

import numpy as np
import torch

from .run_training import run_training
from chemprop.args import SklearnTrainArgs, TrainArgs
from chemprop.args import TrainArgs
from chemprop.constants import TEST_SCORES_FILE_NAME, TRAIN_LOGGER_NAME
from chemprop.data import get_task_names
from chemprop.data import get_data, get_task_names, MoleculeDataset, validate_dataset_type
from chemprop.utils import create_logger, makedirs, timeit


@timeit(logger_name=TRAIN_LOGGER_NAME)
def cross_validate(args: Union[TrainArgs, SklearnTrainArgs],
train_func: Callable[[Union[TrainArgs, SklearnTrainArgs], Logger], Dict[str, List[float]]]
def cross_validate(args: TrainArgs,
train_func: Callable[[TrainArgs, MoleculeDataset, Logger], Dict[str, List[float]]]
) -> Tuple[float, float]:
"""
Runs k-fold cross-validation.
Expand All @@ -29,7 +31,10 @@ def cross_validate(args: Union[TrainArgs, SklearnTrainArgs],
:return: A tuple containing the mean and standard deviation performance across folds.
"""
logger = create_logger(name=TRAIN_LOGGER_NAME, save_dir=args.save_dir, quiet=args.quiet)
info = logger.info if logger is not None else print
if logger is not None:
debug, info = logger.debug, logger.info
else:
debug = info = print

# Initialize relevant variables
init_seed = args.seed
Expand All @@ -41,14 +46,35 @@ def cross_validate(args: Union[TrainArgs, SklearnTrainArgs],
ignore_columns=args.ignore_columns
)

# Print command line
debug('Command line')
debug(f'python {" ".join(sys.argv)}')

# Print args
debug('Args')
debug(args)

# Save args
args.save(os.path.join(args.save_dir, 'args.json'))

# Set pytorch seed for random initial weights
torch.manual_seed(args.pytorch_seed)

# Get data
debug('Loading data')
data = get_data(path=args.data_path, args=args, logger=logger, skip_none_targets=True)
validate_dataset_type(data, dataset_type=args.dataset_type)
args.features_size = data.features_size()
debug(f'Number of tasks = {args.num_tasks}')

# Run training on different random seeds for each fold
all_scores = defaultdict(list)
for fold_num in range(args.num_folds):
info(f'Fold {fold_num}')
args.seed = init_seed + fold_num
args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
makedirs(args.save_dir)
model_scores = train_func(args, logger)
model_scores = train_func(args, data, logger)
for metric, scores in model_scores.items():
all_scores[metric].append(scores)
all_scores = dict(all_scores)
Expand Down
28 changes: 5 additions & 23 deletions chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@
from .train import train
from chemprop.args import TrainArgs
from chemprop.constants import MODEL_FILE_NAME
from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, split_data, StandardScaler, validate_dataset_type
from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, MoleculeDataset, split_data, StandardScaler
from chemprop.models import MoleculeModel
from chemprop.nn_utils import param_count
from chemprop.utils import build_optimizer, build_lr_scheduler, get_loss_func, load_checkpoint,makedirs, \
save_checkpoint, save_smiles_splits


def run_training(args: TrainArgs, logger: Logger = None) -> Dict[str, List[float]]:
def run_training(args: TrainArgs,
data: MoleculeDataset,
logger: Logger = None) -> Dict[str, List[float]]:
"""
Loads data, trains a Chemprop model, and returns test scores for the model checkpoint with the highest validation score.
:param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for
loading data and training the Chemprop model.
:param data: A :class:`~chemprop.data.MoleculeDataset` containing the data.
:param logger: A logger to record output.
:return: A dictionary mapping each metric in :code:`args.metrics` to a list of values for each task.
Expand All @@ -36,27 +39,6 @@ def run_training(args: TrainArgs, logger: Logger = None) -> Dict[str, List[float
else:
debug = info = print

# Print command line
debug('Command line')
debug(f'python {" ".join(sys.argv)}')

# Print args
debug('Args')
debug(args)

# Save args
args.save(os.path.join(args.save_dir, 'args.json'))

# Set pytorch seed for random initial weights
torch.manual_seed(args.pytorch_seed)

# Get data
debug('Loading data')
data = get_data(path=args.data_path, args=args, logger=logger, skip_none_targets=True)
validate_dataset_type(data, dataset_type=args.dataset_type)
args.features_size = data.features_size()
debug(f'Number of tasks = {args.num_tasks}')

# Split data
debug(f'Splitting data with seed {args.seed}')
if args.separate_test_path:
Expand Down

0 comments on commit befaabc

Please sign in to comment.