Skip to content

Commit

Permalink
fix: issue with failing check_array if dataframe contains objects and…
Browse files Browse the repository at this point in the history
… nan
  • Loading branch information
chrislemke committed Jan 12, 2023
1 parent 91f9712 commit b99a734
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
12 changes: 10 additions & 2 deletions src/sk_transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def check_ready_to_transform(
X: pd.DataFrame,
features: Optional[Union[str, List[str]]] = None,
force_all_finite: Union[bool, str] = True,
dtype: Optional[Union[str, List[str]]] = None,
) -> pd.DataFrame:
"""
Args:
Expand All @@ -20,6 +21,10 @@ def check_ready_to_transform(
- True: Force all values of array to be finite.
- False: accepts np.inf, np.nan, pd.NA in array.
- "allow-nan": accepts only np.nan and pd.NA values in array. Values cannot be infinite.
dtype (Optional[Union[str, List[str]]]): Data type of result. If None, the `dtype` of the input is preserved.
If "numeric", `dtype` is preserved unless `array.dtype` is object.
If dtype is a list of types, conversion on the first type is only performed if the dtype of the input
is not in the list.
Raises:
TypeError: If the input `transformer` is not a subclass of `BaseEstimator`.
Expand Down Expand Up @@ -60,10 +65,13 @@ def check_ready_to_transform(
See https://github.com/scikit-learn-contrib/project-template/blob/master/skltemplate/_template.py#L146 for an example/template.
"""
)

check_is_fitted(transformer, "fitted_")

X_tmp = check_array(
X, dtype=None, accept_sparse=True, force_all_finite=force_all_finite
X.to_numpy(),
dtype=dtype,
accept_large_sparse=False,
force_all_finite=force_all_finite,
)
X_tmp = pd.DataFrame(X_tmp, columns=X.columns, index=X.index)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_transformer/test_generic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_column_dropper_transformer_in_pipeline(X) -> None:
assert pipeline.steps[0][0] == "columndroppertransformer"


def test_nan_transform_in_pipeline(X_nan_values) -> None:
def test_nan_transformer_in_pipeline(X_nan_values) -> None:
pipeline = make_pipeline(NaNTransformer([("a", -1), ("b", -1), ("c", "missing")]))
X = pipeline.fit_transform(X_nan_values)

Expand Down

0 comments on commit b99a734

Please sign in to comment.