forked from scikit-learn-contrib/imbalanced-learn
/
test_edited_nearest_neighbours.py
115 lines (91 loc) · 4.48 KB
/
test_edited_nearest_neighbours.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Test the module edited nearest neighbour."""
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# Christos Aridas
# License: MIT
import pytest
import numpy as np
from sklearn.utils.testing import assert_array_equal
from sklearn.neighbors import NearestNeighbors
from imblearn.under_sampling import EditedNearestNeighbours
from imblearn.utils.testing import warns
# Shared fixture for the tests in this module: 20 two-dimensional samples
# (X) and their class labels (Y), drawn from the three classes 0, 1 and 2.
X = np.array([
    [2.59928271, 0.93323465],
    [0.25738379, 0.95564169],
    [1.42772181, 0.526027],
    [1.92365863, 0.82718767],
    [-0.10903849, -0.12085181],
    [-0.284881, -0.62730973],
    [0.57062627, 1.19528323],
    [0.03394306, 0.03986753],
    [0.78318102, 2.59153329],
    [0.35831463, 1.33483198],
    [-0.14313184, -1.0412815],
    [0.01936241, 0.17799828],
    [-1.25020462, -0.40402054],
    [-0.09816301, -0.74662486],
    [-0.01252787, 0.34102657],
    [0.52726792, -0.38735648],
    [0.2821046, -0.07862747],
    [0.05230552, 0.09043907],
    [0.15198585, 0.12512646],
    [0.70524765, 0.39816382],
])
Y = np.array([
    1, 2, 1, 1, 0, 2, 2, 2, 2, 2,
    2, 0, 1, 2, 2, 2, 2, 1, 2, 1,
])
def test_enn_init():
    """Check the attribute values of a default-constructed sampler."""
    sampler = EditedNearestNeighbours()

    # Attribute name -> value expected when no argument is passed.
    expected_defaults = (
        ('n_neighbors', 3),
        ('kind_sel', 'all'),
        ('n_jobs', 1),
    )
    for attribute, expected in expected_defaults:
        assert getattr(sampler, attribute) == expected
def test_enn_fit_resample():
    """Resample the module fixture with default settings.

    The output must match the seven expected samples and labels exactly.
    """
    expected_X = np.array([
        [-0.10903849, -0.12085181],
        [0.01936241, 0.17799828],
        [2.59928271, 0.93323465],
        [1.92365863, 0.82718767],
        [0.25738379, 0.95564169],
        [0.78318102, 2.59153329],
        [0.52726792, -0.38735648],
    ])
    expected_y = np.array([0, 0, 1, 1, 2, 2, 2])

    sampler = EditedNearestNeighbours()
    X_resampled, y_resampled = sampler.fit_resample(X, Y)

    assert_array_equal(X_resampled, expected_X)
    assert_array_equal(y_resampled, expected_y)
@pytest.mark.filterwarnings("ignore:'return_indices' is deprecated from 0.4")
def test_enn_fit_resample_with_indices():
    """Resample with ``return_indices=True`` and check all three outputs.

    The deprecation warning for ``return_indices`` is filtered out by the
    decorator so that it does not pollute the test run.
    """
    expected_X = np.array([
        [-0.10903849, -0.12085181],
        [0.01936241, 0.17799828],
        [2.59928271, 0.93323465],
        [1.92365863, 0.82718767],
        [0.25738379, 0.95564169],
        [0.78318102, 2.59153329],
        [0.52726792, -0.38735648],
    ])
    expected_y = np.array([0, 0, 1, 1, 2, 2, 2])
    expected_idx = np.array([4, 11, 0, 3, 1, 8, 15])

    sampler = EditedNearestNeighbours(return_indices=True)
    outputs = sampler.fit_resample(X, Y)
    X_resampled, y_resampled, idx_under = outputs

    assert_array_equal(X_resampled, expected_X)
    assert_array_equal(y_resampled, expected_y)
    assert_array_equal(idx_under, expected_idx)
def test_enn_fit_resample_mode():
    """Resample the module fixture with ``kind_sel='mode'``.

    The output must match the fourteen expected samples and labels exactly.
    """
    expected_X = np.array([
        [-0.10903849, -0.12085181],
        [0.01936241, 0.17799828],
        [2.59928271, 0.93323465],
        [1.42772181, 0.526027],
        [1.92365863, 0.82718767],
        [0.25738379, 0.95564169],
        [-0.284881, -0.62730973],
        [0.57062627, 1.19528323],
        [0.78318102, 2.59153329],
        [0.35831463, 1.33483198],
        [-0.14313184, -1.0412815],
        [-0.09816301, -0.74662486],
        [0.52726792, -0.38735648],
        [0.2821046, -0.07862747],
    ])
    expected_y = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2])

    sampler = EditedNearestNeighbours(kind_sel='mode')
    X_resampled, y_resampled = sampler.fit_resample(X, Y)

    assert_array_equal(X_resampled, expected_X)
    assert_array_equal(y_resampled, expected_y)
def test_enn_fit_resample_with_nn_object():
    """Resample with a ``NearestNeighbors`` instance given as ``n_neighbors``.

    The sampler is built with ``kind_sel='mode'`` and a
    ``NearestNeighbors(n_neighbors=4)`` estimator; the output must match the
    fourteen expected samples and labels exactly.
    """
    expected_X = np.array([
        [-0.10903849, -0.12085181],
        [0.01936241, 0.17799828],
        [2.59928271, 0.93323465],
        [1.42772181, 0.526027],
        [1.92365863, 0.82718767],
        [0.25738379, 0.95564169],
        [-0.284881, -0.62730973],
        [0.57062627, 1.19528323],
        [0.78318102, 2.59153329],
        [0.35831463, 1.33483198],
        [-0.14313184, -1.0412815],
        [-0.09816301, -0.74662486],
        [0.52726792, -0.38735648],
        [0.2821046, -0.07862747],
    ])
    expected_y = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2])

    neighbours = NearestNeighbors(n_neighbors=4)
    sampler = EditedNearestNeighbours(n_neighbors=neighbours, kind_sel='mode')
    X_resampled, y_resampled = sampler.fit_resample(X, Y)

    assert_array_equal(X_resampled, expected_X)
    assert_array_equal(y_resampled, expected_y)
def test_enn_not_good_object():
    """A string passed as ``n_neighbors`` must make ``fit_resample`` raise.

    The expected failure is a ``ValueError`` whose message contains
    "has to be one of".
    """
    sampler = EditedNearestNeighbours(n_neighbors='rnd', kind_sel='mode')
    expected_message = "has to be one of"
    with pytest.raises(ValueError, match=expected_message):
        sampler.fit_resample(X, Y)
def test_deprecation_random_state():
    """Passing ``random_state`` must emit a ``DeprecationWarning`` on resampling."""
    sampler = EditedNearestNeighbours(random_state=0)
    expected_message = "'random_state' is deprecated from 0.4"
    with warns(DeprecationWarning, match=expected_message):
        sampler.fit_resample(X, Y)