Skip to content

Commit

Permalink
TST: Simplify test_bad_ensemble_params and fix guard code in Ensemble…
Browse files Browse the repository at this point in the history
…Optimizer
  • Loading branch information
Jacob-Stevens-Haas committed Jun 30, 2023
1 parent 59f8e6c commit b36ccf6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 17 deletions.
16 changes: 8 additions & 8 deletions pysindy/optimizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,19 +281,19 @@ def __init__(
raise ValueError(
"If not ensembling data or library terms, use another optimizer"
)
if n_subset is not None and n_subset <= 0:
if bagging and (n_subset is None or n_subset < 1):
raise ValueError("n_subset must be a positive integer if bagging")
if n_candidates_to_drop is not None and n_candidates_to_drop <= 0:
if library_ensemble and (
n_candidates_to_drop is None or n_candidates_to_drop < 1
):
raise ValueError(
"n_candidates_to_drop must be a positive integer if ensembling library"
)
self.opt = opt
if n_models is None or n_models == 0:
warnings.warn(
"n_models must be a positive integer. Explicitly initialized to zero"
" or None, defaulting to 20."
if n_models < 1:
raise ValueError(
"n_candidates_to_drop must be a positive integer if ensembling library"
)
n_models = 20
self.opt = opt
self.n_models = n_models
self.n_subset = n_subset
self.bagging = bagging
Expand Down
13 changes: 4 additions & 9 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,15 +1046,10 @@ def test_ensemble_optimizer(data_lorenz, optimizer_params):
@pytest.mark.parametrize(
"params",
[
dict(ensemble=False, n_models=-1, n_subset=1),
dict(ensemble=False, n_models=0, n_subset=1),
dict(ensemble=False, n_models=1, n_subset=0),
dict(ensemble=False, n_models=1, n_subset=-1),
dict(ensemble=True, n_models=-1, n_subset=1),
dict(ensemble=True, n_models=0, n_subset=1),
dict(ensemble=True, n_models=1, n_subset=0),
dict(ensemble=True, n_models=1, n_subset=-1),
dict(ensemble=True, n_models=1, n_subset=0),
dict(),
dict(bagging=True, n_models=0),
dict(bagging=True, n_subset=0),
dict(library_ensemble=True, n_candidates_to_drop=0),
],
)
def test_bad_ensemble_params(data_lorenz, params):
Expand Down

0 comments on commit b36ccf6

Please sign in to comment.