Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding New metrics to ruptures #283

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ python -m pip install --editable .[dev]
Note that `python -m` can be omitted most of the times, but within virtualenvs, it can prevent certain errors.
Also, in certain terminals (such as `zsh`), the square brackets must be escaped, e.g. replace `.[dev]` by `.\[dev\]`.

In addition to `numpy`, `scipy` and `ruptures`, this command will install all packages needed to develop `ruptures`.
In addition to `numpy`, `scipy`, `scikit-learn` and `ruptures`, this command will install all packages needed to develop `ruptures`.
The exact list of librairies can be found in the [`setup.cfg` file](https://github.com/deepcharles/ruptures/blob/master/setup.cfg) (section `[options.extras_require]`).

### Pre-commit hooks
Expand Down
2 changes: 1 addition & 1 deletion docs/install.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Installation

This library requires Python >=3.6 and the following packages: `numpy`, `scipy` and `matplotlib` (the last one is optional and only for display purposes).
This library requires Python >=3.6 and the following packages: `numpy`, `scipy`, `scikit-learn` and `matplotlib` (the last one is optional and only for display purposes).
You can either install the latest stable release or the development version.

## Stable release
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ requires = [
"setuptools_scm[toml]>=3.4", # https://scikit-hep.org/developer/packaging#git-tags-official-pypa-method
"oldest-supported-numpy", # https://github.com/scipy/oldest-supported-numpy
"scipy>=0.19.1",
"scikit-learn>=1.0",
]
build-backend = "setuptools.build_meta"

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ python_requires = >= 3.6
install_requires =
numpy
scipy
scikit-learn
packages = find:
package_dir =
=src
Expand Down
1 change: 1 addition & 0 deletions src/ruptures/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .precisionrecall import precision_recall
from .hamming import hamming
from .randindex import randindex
from .adjusted_randindex import adjusted_randindex
41 changes: 41 additions & 0 deletions src/ruptures/metrics/adjusted_randindex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
r"""Adjusted Rand index (`adjusted_randindex`)"""
import numpy as np
from ruptures.metrics.sanity_check import sanity_check
from sklearn.metrics import adjusted_rand_score


def chpt_to_label(bkps):
"""Return the segment index each sample belongs to.

Example:
-------
>>> chpt_to_label([4, 10])
array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
"""
duration = np.diff([0] + bkps)
return np.repeat(np.arange(len(bkps)), duration)


def adjusted_randindex(bkps1, bkps2):
"""Compute the adjusted Rand index (between -0.5 and 1.) between two
segmentations.

The Rand index (RI) measures the similarity between two segmentations and
is equal to the proportion of aggreement between two partitions.

The metric implemented here is RI variant, adjusted for chance, and based
on [scikit-learn's implementation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html).

Args:
----
bkps1 (list): sorted list of the last index of each regime.
bkps2 (list): sorted list of the last index of each regime.

Return:
------
float: Adjusted Rand index
""" # noqa E501
sanity_check(bkps1, bkps2)
label1 = chpt_to_label(bkps1)
label2 = chpt_to_label(bkps2)
return adjusted_rand_score(label1, label2)
9 changes: 9 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
meantime,
precision_recall,
randindex,
adjusted_randindex,
)
from ruptures.metrics.sanity_check import BadPartitions

Expand All @@ -31,6 +32,14 @@ def test_randindex(b_mb):
assert m == 1


def test_adjusted_randindex(b_mb):
b, mb = b_mb
m = adjusted_randindex(b, mb)
assert 1 > m > -0.5
m = adjusted_randindex(b, b)
assert m == 1


def test_meantime(b_mb):
b, mb = b_mb
m = meantime(b, mb)
Expand Down