Skip to content

Commit

Permalink
remove asserts (#257)
Browse files — browse the repository at this point in the history
* comment out inappropriate asserts and replace the appropriate ones with exceptions

* use `np.isclose` to check sum value

* Remove already addressed TODO
  • Loading branch information
davidegraff committed Mar 21, 2022
1 parent 0f04bfa commit 7aa766b
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 13 deletions.
6 changes: 5 additions & 1 deletion chemprop/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,11 @@ def set_targets(self, targets: List[List[Optional[float]]]) -> None:
:param targets: A list of lists of floats (or None) containing targets for each molecule. This must be the
same length as the underlying dataset.
"""
assert len(self._data) == len(targets)
if not len(self._data) == len(targets):
raise ValueError(
"number of molecules and targets must be of same length! "
f"num molecules: {len(self._data)}, num targets: {len(targets)}"
)
for i in range(len(self._data)):
self._data[i].set_targets(targets[i])

Expand Down
3 changes: 2 additions & 1 deletion chemprop/data/scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def scaffold_split(data: MoleculeDataset,
:return: A tuple of :class:`~chemprop.data.MoleculeDataset`\ s containing the train,
validation, and test splits of the data.
"""
assert sum(sizes) == 1
if not (len(sizes) == 3 and np.isclose(sum(sizes), 1)):
raise ValueError(f"Invalid train/val/test splits! got: {sizes}")

# Split
train_size, val_size, test_size = sizes[0] * len(data), sizes[1] * len(data), sizes[2] * len(data)
Expand Down
10 changes: 6 additions & 4 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,8 @@ def split_data(data: MoleculeDataset,
:return: A tuple of :class:`~chemprop.data.MoleculeDataset`\ s containing the train,
validation, and test splits of the data.
"""
if not (len(sizes) == 3 and sum(sizes) == 1):
raise ValueError('Valid split sizes must sum to 1 and must have three sizes: train, validation, and test.')
if not (len(sizes) == 3 and np.isclose(sum(sizes), 1)):
raise ValueError(f"Invalid train/val/test splits! got: {sizes}")

random = Random(seed)

Expand Down Expand Up @@ -571,8 +571,10 @@ def split_data(data: MoleculeDataset,
raise ValueError('Test size must be zero since test set is created separately '
'and we want to put all other data in train and validation')

assert folds_file is not None
assert test_fold_index is not None
if folds_file is None:
raise ValueError('arg "folds_file" can not be None!')
if test_fold_index is None:
raise ValueError('arg "test_fold_index" can not be None!')

try:
with open(folds_file, 'rb') as f:
Expand Down
15 changes: 13 additions & 2 deletions chemprop/nn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,19 @@ def __init__(self,
:param max_lr: The maximum learning rate (achieved after :code:`warmup_epochs`).
:param final_lr: The final learning rate (achieved after :code:`total_epochs`).
"""
assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \
len(max_lr) == len(final_lr)
if not (
len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) \
== len(init_lr) == len(max_lr) == len(final_lr)
):
raise ValueError(
"Number of param groups must match the number of epochs and learning rates! "
f"got: len(optimizer.param_groups)= {len(optimizer.param_groups)}, "
f"len(warmup_epochs)= {len(warmup_epochs)}, "
f"len(total_epochs)= {len(total_epochs)}, "
f"len(init_lr)= {len(init_lr)}, "
f"len(max_lr)= {len(max_lr)}, "
f"len(final_lr)= {len(final_lr)}"
)

self.num_lrs = len(optimizer.param_groups)

Expand Down
2 changes: 1 addition & 1 deletion chemprop/sklearn_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def predict_sklearn(args: SklearnPredictArgs) -> None:
avg_preds = avg_preds.tolist()

print(f'Saving predictions to {args.preds_path}')
assert len(data) == len(avg_preds)
# assert len(data) == len(avg_preds) #TODO: address with unit test later
makedirs(args.preds_path, isfile=True)

# Copy predictions over to data
Expand Down
7 changes: 4 additions & 3 deletions chemprop/train/make_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,10 @@ def predict_and_save(args: PredictArgs, train_args: TrainArgs, test_data: Molecu

# Save predictions
print(f'Saving predictions to {args.preds_path}')
assert len(test_data) == len(avg_preds)
if args.ensemble_variance:
assert len(test_data) == len(all_epi_uncs)
#TODO: add unit tests for this
# assert len(test_data) == len(avg_preds)
# if args.ensemble_variance:
# assert len(test_data) == len(all_epi_uncs)
makedirs(args.preds_path, isfile=True)

# Set multiclass column names, update num_tasks definition for multiclass
Expand Down
2 changes: 1 addition & 1 deletion chemprop/train/molecule_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def molecule_fingerprint(args: FingerprintArgs, smiles: List[List[str]] = None)

# Save predictions
print(f'Saving predictions to {args.preds_path}')
assert len(test_data) == len(all_fingerprints)
# assert len(test_data) == len(all_fingerprints) #TODO: add unit test for this
makedirs(args.preds_path, isfile=True)

# Set column names
Expand Down

0 comments on commit 7aa766b

Please sign in to comment.