Skip to content

Commit

Permalink
Merge pull request #69 from billtubbs/master
Browse files Browse the repository at this point in the history
Change to SINDy init defaults for optimizer, feature_library, and differentiation_method
  • Loading branch information
briandesilva committed Apr 26, 2020
2 parents 5c45a36 + f9ac15b commit c8e7324
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
18 changes: 6 additions & 12 deletions pysindy/optimizers/stlsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@


class STLSQ(BaseOptimizer):
"""
Sequentially thresholded least squares algorithm.
"""Sequentially thresholded least squares algorithm.
Attempts to minimize the objective function
:math:`\\|y - Xw\\|^2_2 + alpha \\times \\|w\\|^2_2`
Expand Down Expand Up @@ -128,8 +127,7 @@ def _no_change(self):
return all(bool(i) == bool(j) for i, j in zip(this_coef, last_coef))

def _reduce(self, x, y):
"""
Iterates the thresholding. Assumes an initial guess is saved in
"""Iterates the thresholding. Assumes an initial guess is saved in
self.coef_ and self.ind_
"""
ind = self.ind_
Expand All @@ -140,10 +138,8 @@ def _reduce(self, x, y):
for _ in range(self.max_iter):
if np.count_nonzero(ind) == 0:
warnings.warn(
"""Sparsity parameter is too big ({}) and eliminated all
coefficients""".format(
self.threshold
)
"Sparsity parameter is too big ({}) and eliminated all "
"coefficients".format(self.threshold)
)
coef = np.zeros((n_targets, n_features))
break
Expand All @@ -152,10 +148,8 @@ def _reduce(self, x, y):
for i in range(n_targets):
if np.count_nonzero(ind[i]) == 0:
warnings.warn(
"""Sparsity parameter is too big ({}) and eliminated all
coefficients""".format(
self.threshold
)
"Sparsity parameter is too big ({}) and eliminated all "
"coefficients".format(self.threshold)
)
continue
coef_i = self._regress(x[:, ind[i]], y[:, i])
Expand Down
12 changes: 9 additions & 3 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,21 @@ class SINDy(BaseEstimator):

def __init__(
self,
optimizer=STLSQ(),
feature_library=PolynomialFeatures(),
differentiation_method=FiniteDifference(),
optimizer=None,
feature_library=None,
differentiation_method=None,
feature_names=None,
discrete_time=False,
n_jobs=1,
):
if optimizer is None:
optimizer = STLSQ()
self.optimizer = optimizer
if feature_library is None:
feature_library = PolynomialFeatures()
self.feature_library = feature_library
if differentiation_method is None:
differentiation_method = FiniteDifference()
self.differentiation_method = differentiation_method
self.feature_names = feature_names
self.discrete_time = discrete_time
Expand Down

0 comments on commit c8e7324

Please sign in to comment.