-
Notifications
You must be signed in to change notification settings - Fork 161
/
base.py
70 lines (54 loc) · 1.91 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
r"""Base classes for change point detection estimators and cost functions.

All estimators and cost functions are subclasses of
[`BaseEstimator`][ruptures.base.BaseEstimator] and
[`BaseCost`][ruptures.base.BaseCost] respectively.
"""
import abc
from ruptures.utils import pairwise
class BaseEstimator(metaclass=abc.ABCMeta):
    """Abstract interface shared by every change point detection estimator.

    The class cannot be instantiated directly: a concrete estimator has to
    override ``fit``, ``predict`` and ``fit_predict``.

    Notes:
        All estimators should specify all the parameters that can be set
        at the class level in their ``__init__`` as explicit keyword
        arguments (no ``*args`` or ``**kwargs``).
    """

    @abc.abstractmethod
    def fit(self, *args, **kwargs):
        """Run the fitting stage of the segmentation algorithm.

        Abstract: the accepted arguments and the return value are defined
        by the concrete estimator.
        """

    @abc.abstractmethod
    def predict(self, *args, **kwargs):
        """Run the prediction stage of the segmentation algorithm.

        Abstract: the accepted arguments and the return value are defined
        by the concrete estimator (presumably the detected change points;
        confirm against the subclasses).
        """

    @abc.abstractmethod
    def fit_predict(self, *args, **kwargs):
        """Run the segmentation algorithm in a single call.

        Abstract: the accepted arguments and the return value are defined
        by the concrete estimator (presumably ``fit`` followed by
        ``predict``; confirm against the subclasses).
        """
class BaseCost(metaclass=abc.ABCMeta):
    """Base class for all segment cost classes.

    A concrete cost has to override ``fit``, ``error`` and the ``model``
    property; ``sum_of_costs`` is provided on top of ``error``.

    Notes:
        All classes should specify all the parameters that can be set
        at the class level in their ``__init__`` as explicit keyword
        arguments (no ``*args`` or ``**kwargs``).
    """

    @abc.abstractmethod
    def fit(self, *args, **kwargs):
        """Set the parameters of the cost function, for instance the Gram
        matrix, etc."""

    @abc.abstractmethod
    def error(self, start, end):
        """Return the cost on segment [start:end].

        Args:
            start: first index of the segment.
            end: end index of the segment.
        """

    def sum_of_costs(self, bkps):
        """Return the sum of segment costs for the given segmentation.

        The segments are the consecutive pairs of ``[0, *bkps]``, i.e.
        ``(0, bkps[0]), (bkps[0], bkps[1]), ...``; ``error`` is evaluated
        on each of them.

        Args:
            bkps (iterable of int): change points. By convention,
                bkps[-1]==n_samples. Any iterable is accepted (list,
                tuple, array, ...).

        Returns:
            float: sum of costs (``0`` when ``bkps`` is empty).
        """
        # Unpack instead of ``[0] + bkps``: list concatenation raises on a
        # tuple, and with a NumPy array ``+`` is an element-wise addition
        # that silently drops the leading 0 (and thus the first segment).
        boundaries = [0, *bkps]
        return sum(self.error(start, end) for start, end in pairwise(boundaries))

    @property
    @abc.abstractmethod
    def model(self):
        """Abstract read-only property that every concrete cost must define.

        NOTE(review): presumably an identifier of the cost model (e.g. a
        short name); confirm against the concrete cost classes.
        """