Skip to content

Commit

Permalink
fix sklearn warning (#2442)
Browse files Browse the repository at this point in the history
* fix sklearn warning

* lint
  • Loading branch information
matanper committed Apr 10, 2023
1 parent 9449798 commit d246a34
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion deepchecks/tabular/metric_utils/scorers.py
Expand Up @@ -17,6 +17,8 @@

import numpy as np
import pandas as pd
from packaging import version
from sklearn import __version__ as scikit_version
from sklearn.base import ClassifierMixin
from sklearn.metrics import get_scorer, log_loss, make_scorer, mean_absolute_error, mean_squared_error
from sklearn.metrics._scorer import _BaseScorer, _ProbaScorer
Expand Down Expand Up @@ -442,7 +444,9 @@ def _transform_to_multi_label_format(y: np.ndarray, classes):
# Some classifiers like catboost might return shape like (n_rows, 1), therefore squeezing the array.
y = np.squeeze(y) if y.ndim > 1 else y
if y.ndim == 1:
ohe = OneHotEncoder(handle_unknown='ignore', sparse=False)
kwargs = {'sparse_output': False} if version.parse(scikit_version) >= version.parse('1.2') \
else {'sparse': False}
ohe = OneHotEncoder(handle_unknown='ignore', **kwargs) # pylint: disable=unexpected-keyword-arg
ohe.fit(np.array(classes).reshape(-1, 1))
return ohe.transform(y.reshape(-1, 1))
# If after squeeze there are still 2 dimensions, then it must have column for each model class.
Expand Down

0 comments on commit d246a34

Please sign in to comment.