Skip to content

Commit

Permalink
Merge pull request #52 from dynamicslab/check-optimizer-attributes
Browse files Browse the repository at this point in the history
Check optimizer attributes
  • Loading branch information
briandesilva committed Feb 20, 2020
2 parents 9069ea3 + b5fa7f9 commit 2bb361d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
29 changes: 22 additions & 7 deletions pysindy/optimizers/sindy_optimizer.py
Expand Up @@ -32,8 +32,13 @@ class SINDyOptimizer(BaseEstimator):
"""

def __init__(self, optimizer, unbias=True):
# TODO: add a check that optimizer has the necessary attributes
# and methods
if not hasattr(optimizer, "fit") or not callable(getattr(optimizer, "fit")):
raise AttributeError("optimizer does not have a callable fit method")
if not hasattr(optimizer, "predict") or not callable(
getattr(optimizer, "predict")
):
raise AttributeError("optimizer does not have a callable predict method")

self.optimizer = optimizer
self.unbias = unbias

Expand All @@ -42,6 +47,8 @@ def fit(self, x, y):
if not supports_multiple_targets(self.optimizer):
self.optimizer = _MultiTargetLinearRegressor(self.optimizer)
self.optimizer.fit(x, y)
if not hasattr(self.optimizer, "coef_"):
raise AttributeError("optimizer has no attribute coef_")
self.ind_ = np.abs(self.coef_) > 1e-14

if self.unbias:
Expand All @@ -51,13 +58,18 @@ def fit(self, x, y):

def _unbias(self, x, y):
coef = np.zeros((y.shape[1], x.shape[1]))
if hasattr(self.optimizer, "fit_intercept"):
fit_intercept = self.optimizer.fit_intercept
else:
fit_intercept = False
if hasattr(self.optimizer, "normalize"):
normalize = self.optimizer.normalize
else:
normalize = False
for i in range(self.ind_.shape[0]):
if np.any(self.ind_[i]):
coef[i, self.ind_[i]] = (
LinearRegression(
fit_intercept=self.optimizer.fit_intercept,
normalize=self.optimizer.normalize,
)
LinearRegression(fit_intercept=fit_intercept, normalize=normalize)
.fit(x[:, self.ind_[i]], y[:, i])
.coef_
)
Expand All @@ -82,7 +94,10 @@ def coef_(self):

@property
def intercept_(self):
return self.optimizer.intercept_
if hasattr(self.optimizer, "intercept_"):
return self.optimizer.intercept_
else:
return 0.0

@property
def complexity(self):
Expand Down
32 changes: 32 additions & 0 deletions test/optimizers/test_optimizers.py
Expand Up @@ -22,6 +22,26 @@ def fit(self, x, y):
self.intercept_ = 0
return self

def predict(self, x):
    """Identity prediction used only for testing: echo the input back."""
    return x


class DummyEmptyModel(BaseEstimator):
    """Stub estimator that intentionally defines neither ``fit`` nor
    ``predict``, for testing interface validation."""

    def __init__(self):
        # Attributes consulted by the unbiasing step; both disabled.
        self.fit_intercept = False
        self.normalize = False


class DummyModelNoCoef(BaseEstimator):
    """Stub estimator whose ``fit`` deliberately never sets ``coef_``,
    for testing post-fit attribute validation."""

    def fit(self, x, y):
        # Set only the intercept; the coef_ attribute is intentionally omitted.
        self.intercept_ = 0
        return self

    def predict(self, x):
        # Identity prediction; sufficient for interface checks.
        return x


@pytest.mark.parametrize(
"cls, support",
Expand Down Expand Up @@ -103,6 +123,18 @@ def test_bad_parameters(data_derivative_1d):
SR3(max_iter=0)


def test_bad_optimizers(data_derivative_1d):
    """Non-conforming optimizers should trigger AttributeError."""
    x, x_dot = data_derivative_1d
    x = x.reshape(-1, 1)

    # Missing fit/predict methods are caught at construction time.
    with pytest.raises(AttributeError):
        SINDyOptimizer(DummyEmptyModel())

    # A fit() that never sets coef_ is caught when fitting.
    with pytest.raises(AttributeError):
        opt = SINDyOptimizer(DummyModelNoCoef())
        opt.fit(x, x_dot)


# The different capitalizations are intentional;
# I want to make sure different versions are recognized
@pytest.mark.parametrize("thresholder", ["L0", "l1", "CAD"])
Expand Down

0 comments on commit 2bb361d

Please sign in to comment.