Skip to content

Commit

Permalink
[ENH] Rework of base series annotator API (sktime#6265)
Browse files Browse the repository at this point in the history
See sktime#3214 .


#### What does this implement/fix? Explain your changes.
<!--
A clear and concise description of what you have implemented.
-->

Reworks and refactors the `BaseSeriesAnnotator` class:

* Removes the `fmt` and `label` attributes.
* Adds the `learning_type` and `task` attributes.
* Adds default `predict` and `transform` method.
* Adds default `predict_points` and `predict_segments` methods.
* Adds methods for converting  between dense and sparse output formats.
* Adds default methods for `predict` and `transform`.

Also carries out a move of all existing annotation estimators to the new interface.
  • Loading branch information
Alex-JG3 authored and geetu040 committed Jun 4, 2024
1 parent ba1dcb0 commit 224af14
Show file tree
Hide file tree
Showing 12 changed files with 681 additions and 79 deletions.
21 changes: 7 additions & 14 deletions extension_templates/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,6 @@ class MySeriesAnnotator(BaseSeriesAnnotator):
Parameters
----------
fmt : str {"dense", "sparse"}, optional (default="dense")
Annotation output format:
* If "sparse", a sub-series of labels for only the outliers in X is returned,
* If "dense", a series of labels for all values in X is returned.
labels : str {"indicator", "score"}, optional (default="indicator")
Annotation output labels:
* If "indicator", returned values are boolean, indicating whether a value is an
outlier,
* If "score", returned values are floats, giving the outlier score.
parama : int
descriptive explanation of parama
paramb : string, optional (default='default')
Expand All @@ -69,6 +59,12 @@ class MySeriesAnnotator(BaseSeriesAnnotator):
and so on
"""

# Change the `task` and `learning_type` as needed
_tags = {
"task": "segmentation",
"learning_type": "unsupervised",
}

# todo: add any hyper-parameters and components to constructor
def __init__(
self,
Expand All @@ -77,8 +73,6 @@ def __init__(
est2=None,
paramb="default",
paramc=None,
fmt="dense",
labels="indicator",
):
# estimators should precede parameters
# if estimators have default values, set None and initialize below
Expand All @@ -89,8 +83,7 @@ def __init__(
self.paramb = paramb
self.paramc = paramc

# leave this as is
super().__init__(fmt=fmt, labels=labels)
super().__init__()

# todo: optional, parameter checking logic (if applicable) should happen here
# if writes derived values to self, should *not* overwrite self.parama etc
Expand Down
40 changes: 26 additions & 14 deletions sktime/annotation/adapters/_pyod.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from sktime.annotation.base._base import BaseSeriesAnnotator
from sktime.utils.validation._dependencies import _check_soft_dependencies
from sktime.utils.warnings import warn

__author__ = ["mloning", "satya-pattnaik", "fkiraly"]

Expand All @@ -21,22 +22,34 @@ class PyODAnnotator(BaseSeriesAnnotator):
estimator : PyOD estimator
See ``https://pyod.readthedocs.io/en/latest/`` documentation for a detailed
description of all options.
fmt : str {"dense", "sparse"}, optional (default="dense")
Annotation output format:
* If "sparse", a sub-series of labels for only the outliers in X is returned,
* If "dense", a series of labels for all values in X is returned.
labels : str {"indicator", "score"}, optional (default="indicator")
Annotation output labels:
* If "indicator", returned values are boolean, indicating whether a value is an
outlier,
* If "score", returned values are floats, giving the outlier score.
"""

_tags = {"python_dependencies": "pyod"}
_tags = {
"python_dependencies": "pyod",
"task": "anomaly_detection",
"learning_type": "unsupervised",
}

def __init__(self, estimator, fmt="dense", labels="indicator"):
# todo 0.31.0: remove fmt argument and warning
def __init__(self, estimator, fmt="deprecated", labels="indicator"):
self.estimator = estimator # pyod estimator
super().__init__(fmt=fmt, labels=labels)
self.fmt = fmt
self.labels = labels

super().__init__()

if fmt == "deprecated":
self._fmt = "sparse"
warn(
f"Warning from {type(self).__name__}: fmt argument will be removed in"
" 0.31.0. For behaviour equivalent to fmt=dense, use transform instead "
"of predict. In 0.31.0 the behaviour of predict will equivalent to the"
" current behaviour of predict when fmt=sparse.",
DeprecationWarning,
obj=self,
)
else:
self._fmt = fmt

def _fit(self, X, Y=None):
"""Fit to training data.
Expand Down Expand Up @@ -77,9 +90,8 @@ def _predict(self, X):
Returns
-------
Y : pd.Series - annotations for sequence X
exact format depends on annotation type
"""
fmt = self.fmt
fmt = self._fmt
labels = self.labels

X_np = X.to_numpy()
Expand Down
Loading

0 comments on commit 224af14

Please sign in to comment.