Skip to content

Commit

Permalink
TST: Simplify another legacy ensemble test
Browse files Browse the repository at this point in the history
see commit 3445cb9
  • Loading branch information
Jacob-Stevens-Haas committed May 12, 2023
1 parent 5843574 commit c9286d8
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,20 +1053,7 @@ def test_ensemble_optimizer(data_lorenz, optimizer_params):
assert model.coefficients().shape == (3, 10)


@pytest.mark.parametrize(
"optimizer",
[
STLSQ,
SSR,
FROLS,
SR3,
ConstrainedSR3,
StableLinearSR3,
TrappingSR3,
MIOSR,
],
)
def test_legacy_ensemble_pdes(optimizer):
def test_legacy_ensemble_pdes():
u = np.random.randn(10, 10, 2)
t = np.linspace(1, 10, 10)
x = np.linspace(1, 10, 10)
Expand All @@ -1084,11 +1071,11 @@ def test_legacy_ensemble_pdes(optimizer):
spatial_grid=x,
include_bias=True,
)
opt = optimizer(normalize_columns=True)
opt = STLSQ(normalize_columns=True)
model = SINDy(optimizer=opt, feature_library=pde_lib)
model.fit(u, x_dot=u_dot, ensemble=True, n_models=10, n_subset=20)
model.fit(u, x_dot=u_dot, ensemble=True, n_models=2, n_subset=2)
n_features = len(model.get_feature_names())
assert np.shape(model.coef_list) == (10, 2, n_features)
assert np.shape(model.coef_list) == (2, 2, n_features)


def test_ssr_criteria(data_lorenz):
Expand Down

0 comments on commit c9286d8

Please sign in to comment.