Skip to content

Commit

Permalink
Adding split type cv-no-test for cross-validation without test split
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Dec 5, 2020
1 parent 8b9d985 commit b56ca98
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class TrainArgs(CommonArgs):
"""Path to separate val set, optional."""
separate_test_path: str = None
"""Path to separate test set, optional."""
- split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'cv', 'index_predetermined'] = 'random'
+ split_type: Literal['random', 'scaffold_balanced', 'predetermined', 'crossval', 'cv', 'cv-no-test', 'index_predetermined'] = 'random'
"""Method of splitting the data into train/val/test."""
split_sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1)
"""Split proportions for train/validation/test sets."""
Expand Down
4 changes: 2 additions & 2 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def split_data(data: MoleculeDataset,
train, val, test = tuple(data_split)
return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)

- elif split_type == 'cv':
+ elif split_type in {'cv', 'cv-no-test'}:
if num_folds <= 1 or num_folds > len(data):
raise ValueError('Number of folds for cross-validation must be between 2 and len(data), inclusive.')

Expand All @@ -362,7 +362,7 @@ def split_data(data: MoleculeDataset,

train, val, test = [], [], []
for d, index in zip(data, indices):
- if index == test_index:
+ if index == test_index and split_type != 'cv-no-test':
test.append(d)
elif index == val_index:
val.append(d)
Expand Down

0 comments on commit b56ca98

Please sign in to comment.