Skip to content

Commit

Permalink
Change signatures of _constrained_ordination helper functions
Browse files Browse the repository at this point in the history
- Changed first argument to be positional, rather than required keyword
- Edit error message to return list rather than dict_keys
  • Loading branch information
grovduck committed Oct 17, 2023
1 parent 1735983 commit a5d0d84
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions src/sknnr/transformers/_constrained_ordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from numpy.typing import NDArray


def is_2d_numeric_array(*, arr: NDArray) -> bool:
def is_2d_numeric_array(arr: NDArray) -> bool:
"""Verify that arr is a 2D numeric array."""
return arr.ndim == 2 and np.issubdtype(arr.dtype, np.number)


def zero_sum_vectors(*, arr: NDArray, axis: int) -> NDArray:
def zero_sum_vectors(arr: NDArray, *, axis: int) -> NDArray:
"""Find any rows or columns in arr that sum to 0."""
return arr.sum(axis=axis) <= 0.0

Expand Down Expand Up @@ -44,19 +44,19 @@ def _set_initialization_attributes(self, X: NDArray, Y: NDArray) -> None:
@staticmethod
def _check_inputs(X: NDArray, Y: NDArray) -> tuple[NDArray, NDArray]:
"""Verify that X and Y are valid inputs in preparation for ordination."""
if not is_2d_numeric_array(arr=X):
if not is_2d_numeric_array(X):
raise ValueError("X must be a 2D numeric numpy array")

if not is_2d_numeric_array(arr=Y):
if not is_2d_numeric_array(Y):
raise ValueError("Y must be a 2D numeric numpy array")

if X.shape[0] != Y.shape[0]:
raise ValueError("X and Y must have the same number of rows")

if np.any(zero_sum_vectors(arr=X, axis=1)):
if np.any(zero_sum_vectors(X, axis=1)):
raise ValueError("All row sums in X must be greater than 0")

excluded_columns = zero_sum_vectors(arr=Y, axis=0)
excluded_columns = zero_sum_vectors(Y, axis=0)
return X, Y[:, ~excluded_columns]

@property
Expand Down
2 changes: 1 addition & 1 deletion src/sknnr/transformers/_constrained_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def fit(self, X, y):
method_cls = self.CONSTRAINED_METHODS.get(self.constrained_method)
if method_cls is None:
raise ValueError(
f"`method` must be one of {self.CONSTRAINED_METHODS.keys()}, not"
f"`method` must be one of {list(self.CONSTRAINED_METHODS.keys())}, not"
f" {self.constrained_method}."
)
self.ordination_ = method_cls(X, y)
Expand Down

0 comments on commit a5d0d84

Please sign in to comment.