Skip to content

Commit

Permalink
Save test scores as CSV and validation classification/regression data…
Browse files Browse the repository at this point in the history
…set types
  • Loading branch information
swansonk14 committed Jul 16, 2020
1 parent 8c61787 commit a225ef0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
20 changes: 20 additions & 0 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,26 @@ def get_class_sizes(data: MoleculeDataset) -> List[List[float]]:
return class_sizes


def validate_dataset_type(data: MoleculeDataset, dataset_type: str) -> None:
"""
Validates the dataset type to ensure the data matches the provided type.
TODO: Validate multiclass dataset type.
:param data: A MoleculeDataset.
:param dataset_type: The dataset type to check.
"""
target_set = {target for targets in data.targets() for target in targets} - {None}
classification_target_set = {0, 1}

if dataset_type == 'classification' and not (target_set <= classification_target_set):
raise ValueError('Classification data targets must only be 0 or 1 (or None). '
'Please switch to regression.')
elif dataset_type == 'regression' and target_set <= classification_target_set:
raise ValueError('Regression data targets must be more than just 0 or 1 (or None). '
'Please switch to classification.')


def validate_data(data_path: str) -> Set[str]:
"""
Validates a data CSV file, returning a set of errors.
Expand Down
12 changes: 12 additions & 0 deletions chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
from logging import Logger
import os
from typing import Tuple
Expand Down Expand Up @@ -51,6 +52,17 @@ def cross_validate(args: TrainArgs, logger: Logger = None) -> Tuple[float, float
info(f'Overall test {task_name} {args.metric} = '
f'{np.nanmean(all_scores[:, task_num]):.6f} +/- {np.nanstd(all_scores[:, task_num]):.6f}')

# Save scores
with open(os.path.join(save_dir, 'test_scores.csv'), 'w') as f:
writer = csv.writer(f)
writer.writerow(['Task', f'Mean {args.metric}', f'Standard deviation {args.metric}'] +
[f'Fold {i} {args.metric}' for i in range(args.num_folds)])

for task_num, task_name in enumerate(task_names):
task_scores = all_scores[:, task_num]
mean, std = np.nanmean(task_scores), np.nanstd(task_scores)
writer.writerow([task_name, mean, std] + task_scores.tolist())

return mean_score, std_score


Expand Down
5 changes: 2 additions & 3 deletions chemprop/train/run_training.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import csv
from logging import Logger
import os
import sys
Expand All @@ -8,15 +7,14 @@
from tensorboardX import SummaryWriter
import torch
from tqdm import trange
import pickle
from torch.optim.lr_scheduler import ExponentialLR

from .evaluate import evaluate, evaluate_predictions
from .predict import predict
from .train import train
from chemprop.args import TrainArgs
from chemprop.data import StandardScaler, MoleculeDataLoader
from chemprop.data.utils import get_class_sizes, get_data, get_task_names, split_data
from chemprop.data.utils import get_class_sizes, get_data, get_task_names, split_data, validate_dataset_type
from chemprop.models import MoleculeModel
from chemprop.nn_utils import param_count
from chemprop.utils import build_optimizer, build_lr_scheduler, get_loss_func, get_metric_func, load_checkpoint,\
Expand Down Expand Up @@ -54,6 +52,7 @@ def run_training(args: TrainArgs, logger: Logger = None) -> List[float]:
debug('Loading data')
args.task_names = args.target_columns or get_task_names(args.data_path)
data = get_data(path=args.data_path, args=args, logger=logger)
validate_dataset_type(data, dataset_type=args.dataset_type)
args.num_tasks = data.num_tasks()
args.features_size = data.features_size()
debug(f'Number of tasks = {args.num_tasks}')
Expand Down

0 comments on commit a225ef0

Please sign in to comment.