Skip to content

Commit

Permalink
Option to save test predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Aug 24, 2020
1 parent df32e06 commit 8d5d0c6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 2 additions & 0 deletions chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ class TrainArgs(CommonArgs):
Below this number, caching is used and data loading is sequential.
Above this number, caching is not used and data loading is parallel.
"""
save_preds: bool = False
"""Whether to save test split predictions during training."""

# Model arguments
bias: bool = False
Expand Down
9 changes: 5 additions & 4 deletions chemprop/train/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ def cross_validate(args: TrainArgs) -> Tuple[float, float]:
mean, std = np.nanmean(task_scores), np.nanstd(task_scores)
writer.writerow([task_name, mean, std] + task_scores.tolist())

# Merge and save test preds if doing cross-validation
all_preds = pd.concat([pd.read_csv(os.path.join(save_dir, f'fold_{fold_num}', 'test_preds.csv'))
for fold_num in range(args.num_folds)])
all_preds.to_csv(os.path.join(save_dir, 'test_preds.csv'), index=False)
# Optionally merge and save test preds
if args.save_preds:
all_preds = pd.concat([pd.read_csv(os.path.join(save_dir, f'fold_{fold_num}', 'test_preds.csv'))
for fold_num in range(args.num_folds)])
all_preds.to_csv(os.path.join(save_dir, 'test_preds.csv'), index=False)

return mean_score, std_score

Expand Down
4 changes: 2 additions & 2 deletions chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ def run_training(args: TrainArgs, logger: Logger = None) -> List[float]:
for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
info(f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}')

# Save test preds if doing cross-validation
if args.split_type == 'cv':
# Optionally save test preds
if args.save_preds:
test_preds_dataframe = pd.DataFrame(data={'smiles': test_data.smiles()})

for i, task_name in enumerate(args.task_names):
Expand Down

0 comments on commit 8d5d0c6

Please sign in to comment.