Commit aa0c473

[ENH] sktime integration - time series classification (#173)

Adds integration for another learning task in `sktime` - time series classification.

* `SktimeClassificationExperiment` experiment
* `TSCOptCV` TSC tuning algorithm that takes any `hyperactive` tuner

Since time series classification uses ordinary sklearn metrics, the `experiment.integrations` contents are refactored, moving utilities concerned with `sklearn` metrics into a separate, private file `_skl_metrics`.

Parent: 13537d1
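For orientation, a hypothetical usage sketch of the new experiment. The constructor of `SktimeClassificationExperiment` is not part of the hunks shown below, so this assumes it mirrors `SklearnCvExperiment(estimator, X, y, scoring=None, cv=None)` from the sklearn integration; the dataset loader and classifier are stock sktime examples.

```python
# Hypothetical sketch: assumes SktimeClassificationExperiment mirrors the
# SklearnCvExperiment(estimator, X, y, scoring=None, cv=None) signature;
# the actual constructor lives in sktime_classification.py (not shown here).
from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier
from sktime.datasets import load_unit_test

from hyperactive.experiment.integrations import SktimeClassificationExperiment

X, y = load_unit_test(return_X_y=True)

experiment = SktimeClassificationExperiment(
    estimator=KNeighborsTimeSeriesClassifier(),
    X=X,
    y=y,
    scoring="accuracy",  # an ordinary sklearn metric, per the commit message
)
```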

File tree

7 files changed: +807 −96 lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion

The single changed line (old/new line 102) renders as blank on both sides, i.e. a whitespace-only change; YAML indentation is reconstructed.

```diff
@@ -99,7 +99,7 @@ jobs:
         exclude:
           - os: "windows-latest"
             python-version: "3.13"
-
+
       fail-fast: false

     runs-on: ${{ matrix.os }}
```

src/hyperactive/experiment/integrations/__init__.py

Lines changed: 8 additions & 1 deletion

```diff
@@ -2,8 +2,15 @@
 # copyright: hyperactive developers, MIT License (see LICENSE file)

 from hyperactive.experiment.integrations.sklearn_cv import SklearnCvExperiment
+from hyperactive.experiment.integrations.sktime_classification import (
+    SktimeClassificationExperiment,
+)
 from hyperactive.experiment.integrations.sktime_forecasting import (
     SktimeForecastingExperiment,
 )

-__all__ = ["SklearnCvExperiment", "SktimeForecastingExperiment"]
+__all__ = [
+    "SklearnCvExperiment",
+    "SktimeClassificationExperiment",
+    "SktimeForecastingExperiment",
+]
```
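With the re-export in place, all three experiments are importable directly from the package, exactly as listed in `__all__`:

```python
from hyperactive.experiment.integrations import (
    SklearnCvExperiment,
    SktimeClassificationExperiment,
    SktimeForecastingExperiment,
)
```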
src/hyperactive/experiment/integrations/_skl_metrics.py (new file; the path follows from the import added in sklearn_cv.py below)

Lines changed: 128 additions & 0 deletions

```python
"""Integration utilities for sklearn metrics with Hyperactive."""

__all__ = ["_coerce_to_scorer", "_guess_sign_of_sklmetric"]


def _coerce_to_scorer(scoring, estimator):
    """Coerce scoring argument into a sklearn scorer.

    Parameters
    ----------
    scoring : str, callable, or None
        The scoring strategy to use.
    estimator : estimator object or str
        The estimator to use for default scoring if scoring is None.

        If str, indicates estimator type, should be one of
        {"classifier", "regressor"}.

    Returns
    -------
    scorer : callable
        A sklearn scorer callable.
        Follows the unified sklearn scorer interface.
    """
    from sklearn.metrics import check_scoring

    # check if scoring is a scorer by checking for "estimator" in signature
    if scoring is None:
        if isinstance(estimator, str):
            if estimator == "classifier":
                from sklearn.metrics import accuracy_score

                scoring = accuracy_score
            elif estimator == "regressor":
                from sklearn.metrics import r2_score

                scoring = r2_score
        else:
            return check_scoring(estimator)

    # check using inspect.signature for "estimator" in signature
    if callable(scoring):
        from inspect import signature

        if "estimator" in signature(scoring).parameters:
            return scoring
        else:
            from sklearn.metrics import make_scorer

            return make_scorer(scoring)
    else:
        # scoring is a string (scorer name)
        return check_scoring(estimator, scoring=scoring)


def _guess_sign_of_sklmetric(scorer):
    """Guess the sign of a sklearn metric scorer.

    Parameters
    ----------
    scorer : callable
        The sklearn metric scorer to guess the sign for.

    Returns
    -------
    int
        1 if higher scores are better, -1 if lower scores are better.
    """
    HIGHER_IS_BETTER = {
        # Classification
        "accuracy_score": True,
        "auc": True,
        "average_precision_score": True,
        "balanced_accuracy_score": True,
        "brier_score_loss": False,
        "class_likelihood_ratios": False,
        "cohen_kappa_score": True,
        "d2_log_loss_score": True,
        "dcg_score": True,
        "f1_score": True,
        "fbeta_score": True,
        "hamming_loss": False,
        "hinge_loss": False,
        "jaccard_score": True,
        "log_loss": False,
        "matthews_corrcoef": True,
        "ndcg_score": True,
        "precision_score": True,
        "recall_score": True,
        "roc_auc_score": True,
        "top_k_accuracy_score": True,
        "zero_one_loss": False,
        # Regression
        "d2_absolute_error_score": True,
        "d2_pinball_score": True,
        "d2_tweedie_score": True,
        "explained_variance_score": True,
        "max_error": False,
        "mean_absolute_error": False,
        "mean_absolute_percentage_error": False,
        "mean_gamma_deviance": False,
        "mean_pinball_loss": False,
        "mean_poisson_deviance": False,
        "mean_squared_error": False,
        "mean_squared_log_error": False,
        "mean_tweedie_deviance": False,
        "median_absolute_error": False,
        "r2_score": True,
        "root_mean_squared_error": False,
        "root_mean_squared_log_error": False,
    }

    scorer_name = getattr(scorer, "__name__", None)

    if hasattr(scorer, "greater_is_better"):
        return 1 if scorer.greater_is_better else -1
    elif scorer_name in HIGHER_IS_BETTER:
        return 1 if HIGHER_IS_BETTER[scorer_name] else -1
    elif scorer_name.endswith("_score"):
        # If the scorer name ends with "_score", we assume higher is better
        return 1
    elif scorer_name.endswith("_loss") or scorer_name.endswith("_deviance"):
        # If the scorer name ends with "_loss", we assume lower is better
        return -1
    elif scorer_name.endswith("_error"):
        return -1
    else:
        # If we cannot determine the sign, we assume lower is better
        return -1
```
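A quick sketch of how the two helpers behave, using the module path implied by the import in sklearn_cv.py below. Both functions are private, so this illustrates the internals rather than public API:

```python
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.svm import SVC

from hyperactive.experiment.integrations._skl_metrics import (
    _coerce_to_scorer,
    _guess_sign_of_sklmetric,
)

# a plain metric callable has no "estimator" parameter in its signature,
# so it is wrapped into a scorer via make_scorer
scorer = _coerce_to_scorer(accuracy_score, SVC())

# a string is resolved through sklearn's check_scoring
scorer = _coerce_to_scorer("accuracy", SVC())

# scoring=None with a string estimator type falls back to a default metric
scorer = _coerce_to_scorer(None, "classifier")  # wraps accuracy_score

# sign guessing on raw metric functions goes through the lookup table
assert _guess_sign_of_sklmetric(accuracy_score) == 1
assert _guess_sign_of_sklmetric(mean_squared_error) == -1
```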

src/hyperactive/experiment/integrations/sklearn_cv.py

Lines changed: 5 additions & 93 deletions

```diff
@@ -3,11 +3,14 @@
 # copyright: hyperactive developers, MIT License (see LICENSE file)

 from sklearn import clone
-from sklearn.metrics import check_scoring
 from sklearn.model_selection import cross_validate
 from sklearn.utils.validation import _num_samples

 from hyperactive.base import BaseExperiment
+from hyperactive.experiment.integrations._skl_metrics import (
+    _coerce_to_scorer,
+    _guess_sign_of_sklmetric,
+)
@@ -97,22 +100,7 @@ def __init__(self, estimator, X, y, scoring=None, cv=None):
         else:
             self._cv = cv

-        # check if scoring is a scorer by checking for "estimator" in signature
-        if scoring is None:
-            self._scoring = check_scoring(self.estimator)
-        # check using inspect.signature for "estimator" in signature
-        elif callable(scoring):
-            from inspect import signature
-
-            if "estimator" in signature(scoring).parameters:
-                self._scoring = scoring
-            else:
-                from sklearn.metrics import make_scorer
-
-                self._scoring = make_scorer(scoring)
-        else:
-            # scoring is a string (scorer name)
-            self._scoring = check_scoring(self.estimator, scoring=scoring)
+        self._scoring = _coerce_to_scorer(scoring, self.estimator)
         self.scorer_ = self._scoring

         # Set the sign of the scoring function
@@ -281,79 +269,3 @@ def _get_score_params(self):
             score_params_defaults,
         ]
         return params
-
-
-def _guess_sign_of_sklmetric(scorer):
-    """Guess the sign of a sklearn metric scorer.
-
-    Parameters
-    ----------
-    scorer : callable
-        The sklearn metric scorer to guess the sign for.
-
-    Returns
-    -------
-    int
-        1 if higher scores are better, -1 if lower scores are better.
-    """
-    HIGHER_IS_BETTER = {
-        # Classification
-        "accuracy_score": True,
-        "auc": True,
-        "average_precision_score": True,
-        "balanced_accuracy_score": True,
-        "brier_score_loss": False,
-        "class_likelihood_ratios": False,
-        "cohen_kappa_score": True,
-        "d2_log_loss_score": True,
-        "dcg_score": True,
-        "f1_score": True,
-        "fbeta_score": True,
-        "hamming_loss": False,
-        "hinge_loss": False,
-        "jaccard_score": True,
-        "log_loss": False,
-        "matthews_corrcoef": True,
-        "ndcg_score": True,
-        "precision_score": True,
-        "recall_score": True,
-        "roc_auc_score": True,
-        "top_k_accuracy_score": True,
-        "zero_one_loss": False,
-        # Regression
-        "d2_absolute_error_score": True,
-        "d2_pinball_score": True,
-        "d2_tweedie_score": True,
-        "explained_variance_score": True,
-        "max_error": False,
-        "mean_absolute_error": False,
-        "mean_absolute_percentage_error": False,
-        "mean_gamma_deviance": False,
-        "mean_pinball_loss": False,
-        "mean_poisson_deviance": False,
-        "mean_squared_error": False,
-        "mean_squared_log_error": False,
-        "mean_tweedie_deviance": False,
-        "median_absolute_error": False,
-        "r2_score": True,
-        "root_mean_squared_error": False,
-        "root_mean_squared_log_error": False,
-    }
-
-    scorer_name = getattr(scorer, "__name__", None)
-
-    if hasattr(scorer, "greater_is_better"):
-        return 1 if scorer.greater_is_better else -1
-    elif scorer_name in HIGHER_IS_BETTER:
-        return 1 if HIGHER_IS_BETTER[scorer_name] else -1
-    elif scorer_name.endswith("_score"):
-        # If the scorer name ends with "_score", we assume higher is better
-        return 1
-    elif scorer_name.endswith("_loss") or scorer_name.endswith("_deviance"):
-        # If the scorer name ends with "_loss", we assume lower is better
-        return -1
-    elif scorer_name.endswith("_error"):
-        return -1
-    else:
-        # If we cannot determine the sign, we assume lower is better
-        return -1
```
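The refactor leaves `SklearnCvExperiment` behavior unchanged: `_coerce_to_scorer` reproduces the deleted inline logic, and `_guess_sign_of_sklmetric` is identical to the removed module-level copy. A minimal construction sketch (how the experiment is subsequently evaluated is outside the hunks shown here):

```python
from sklearn.datasets import load_iris
from sklearn.svm import SVC

from hyperactive.experiment.integrations import SklearnCvExperiment

X, y = load_iris(return_X_y=True)

# scoring may be None, a scorer name, a plain metric callable, or a full
# scorer; _coerce_to_scorer now normalizes all four cases in one place
experiment = SklearnCvExperiment(estimator=SVC(), X=X, y=y, scoring="accuracy")
```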
