Skip to content

Commit

Permalink
BUG: Update parametrized library calc_trajectory API
Browse files Browse the repository at this point in the history
calc_trajectory now returns tuple of (x, x_lhs)

Also reverted a error in merging master around inheritance of
calc_trajectory
  • Loading branch information
Jacob-Stevens-Haas committed Apr 11, 2023
1 parent c384bf6 commit 24877a2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 3 additions & 1 deletion pysindy/feature_library/generalized_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def __init__(
self.inputs_per_library_ = inputs_per_library
self.libraries_full_ = self.libraries_
self.exclude_libs_ = exclude_libraries
self.calc_trajectory = self.libraries_[0].calc_trajectory

@x_sequence_or_item
def fit(self, x_full, y=None):
Expand Down Expand Up @@ -315,6 +314,9 @@ def get_feature_names(self, input_features=None):
feature_names += lib.get_feature_names(input_features_i)
return feature_names

def calc_trajectory(self, diff_method, x, t):
return self.libraries_[0].calc_trajectory(diff_method, x, t)

def get_spatial_grid(self):
for lib_k in self.libraries_:
spatial_grid = lib_k.get_spatial_grid()
Expand Down
6 changes: 2 additions & 4 deletions pysindy/feature_library/parameterized_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ def calc_trajectory(self, diff_method, x, t):
constants_final = np.ones(self.libraries_[0].K)
for k in range(self.libraries_[0].K):
constants_final[k] = np.sum(self.libraries_[0].fullweights0[k])
return (
self.libraries_[0].calc_trajectory(diff_method, x, t)
* constants_final[:, np.newaxis]
)
x, x_int = self.libraries_[0].calc_trajectory(diff_method, x, t)
return x, x_int * constants_final[:, np.newaxis]
else:
return self.libraries_[0].calc_trajectory(diff_method, x, t)

0 comments on commit 24877a2

Please sign in to comment.