Skip to content

Commit

Permalink
Update for function input validation (#1347)
Browse files Browse the repository at this point in the history
* Update for function input validation

* Add sklearn 1.3 version to CI matrix
  • Loading branch information
Alexsandruss committed Jun 30, 2023
1 parent 40ccd1c commit c40bdda
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
9 changes: 9 additions & 0 deletions .ci/pipeline/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ jobs:
Python3.11_Sklearn1.2:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.2'
Python3.11_Sklearn1.3:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.3'
pool:
vmImage: 'ubuntu-22.04'
steps:
Expand All @@ -101,6 +104,9 @@ jobs:
Python3.11_Sklearn1.2:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.2'
Python3.11_Sklearn1.3:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.3'
pool:
vmImage: 'macos-12'
steps:
Expand All @@ -124,6 +130,9 @@ jobs:
Python3.11_Sklearn1.2:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.2'
Python3.11_Sklearn1.3:
PYTHON_VERSION: '3.11'
SKLEARN_VERSION: '1.3'
pool:
vmImage: 'windows-latest'
steps:
Expand Down
4 changes: 1 addition & 3 deletions daal4py/sklearn/metrics/_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,12 @@ def pairwise_distances(X, Y=None, metric="euclidean", n_jobs=None,


if sklearn_check_version('1.3'):
validation_kwargs = {'prefer_skip_nested_validation': True} \
if sklearn_check_version('1.4') else {}
pairwise_distances = validate_params(
{
"X": ["array-like", "sparse matrix"],
"Y": ["array-like", "sparse matrix", None],
"metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable],
"n_jobs": [Integral, None],
"force_all_finite": ["boolean", StrOptions({"allow-nan"})],
}, **validation_kwargs
}, prefer_skip_nested_validation=True
)(pairwise_distances)
4 changes: 1 addition & 3 deletions daal4py/sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def roc_auc_score(


if sklearn_check_version('1.3'):
validation_kwargs = {'prefer_skip_nested_validation': True} \
if sklearn_check_version('1.4') else {}
roc_auc_score = validate_params(
{
"y_true": ["array-like"],
Expand All @@ -194,5 +192,5 @@ def roc_auc_score(
"max_fpr": [Interval(Real, 0.0, 1, closed="right"), None],
"multi_class": [StrOptions({"raise", "ovr", "ovo"})],
"labels": ["array-like", None],
}, **validation_kwargs
}, prefer_skip_nested_validation=True
)(roc_auc_score)
4 changes: 1 addition & 3 deletions daal4py/sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,6 @@ def train_test_split(*arrays, **options):


if sklearn_check_version('1.3'):
validation_kwargs = {'prefer_skip_nested_validation': True} \
if sklearn_check_version('1.4') else {}
train_test_split = validate_params({
"test_size": [
Interval(RealNotInt, 0, 1, closed="neither"),
Expand All @@ -274,4 +272,4 @@ def train_test_split(*arrays, **options):
"random_state": ["random_state"],
"shuffle": ["boolean"],
"stratify": ["array-like", None],
}, **validation_kwargs)(train_test_split)
}, prefer_skip_nested_validation=True)(train_test_split)

0 comments on commit c40bdda

Please sign in to comment.