forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_base.py
111 lines (82 loc) · 3.06 KB
/
test_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Author: Gael Varoquaux
# License: BSD
from nose.tools import assert_true, assert_false, assert_equal, \
assert_raises
from ..base import BaseEstimator, clone, is_classifier
################################################################################
# A few test classes
class MyEstimator(BaseEstimator):
    """Minimal estimator with one hyper-parameter, ``l1``.

    Used by the repr/str smoke tests below.
    """

    def __init__(self, l1=0):
        # Keep the argument under the same attribute name as the constructor
        # parameter (compare ``Buggy``, which deliberately does not).
        self.l1 = l1
class K(BaseEstimator):
    """Toy estimator with two optional parameters, ``c`` and ``d``.

    The tests nest instances of it inside ``T`` to exercise nested
    parameter handling and nested repr.
    """

    def __init__(self, c=None, d=None):
        self.c, self.d = c, d
class T(BaseEstimator):
    """Toy estimator with two optional parameters, ``a`` and ``b``.

    The tests pass ``K`` instances as ``a`` and ``b`` to build a two-level
    estimator tree.
    """

    def __init__(self, a=None, b=None):
        self.a, self.b = a, b
class Buggy(BaseEstimator):
    """A buggy estimator that does not set its parameters right.

    ``__init__`` accepts ``a`` but always stores ``1``, so the stored
    attribute does not reflect the constructor argument.
    ``test_clone_buggy`` relies on this.
    """

    def __init__(self, a=None):
        # Deliberately ignore ``a``: this mismatch is the bug under test.
        self.a = 1
################################################################################
# The tests
def test_clone():
    """Tests that clone creates a correct copy.

    We create an estimator, make a copy of its original state
    (which, in this case, is the current state of the estimator),
    and check that the copy is a distinct object with the same parameters.
    """
    # Import relative to the package under test, like the other tests in this
    # module, so that an installed sklearn cannot shadow the in-tree code.
    from ..feature_selection import SelectFpr, f_classif

    selector = SelectFpr(f_classif, alpha=0.1)
    new_selector = clone(selector)

    # The clone must be a different object...
    assert_true(selector is not new_selector)
    # ...carrying the same constructor parameters.
    assert_equal(selector._get_params(), new_selector._get_params())
def test_clone_2():
    """Tests that clone doesn't copy everything.

    We first create an estimator, give it an attribute of its own, and make
    a copy of its original state. Then we check that the copy doesn't have
    the attribute we manually added to the initial estimator.
    """
    # Import relative to the package under test, like the other tests in this
    # module, so that an installed sklearn cannot shadow the in-tree code.
    from ..feature_selection import SelectFpr, f_classif

    selector = SelectFpr(f_classif, alpha=0.1)
    # An attribute that is not a constructor parameter.
    selector.own_attribute = "test"

    new_selector = clone(selector)
    assert_false(hasattr(new_selector, "own_attribute"))
def test_clone_buggy():
    """Check that clone raises an error on buggy estimators.

    ``Buggy`` ignores its ``a`` argument. After ``a`` is set to 2, an
    estimator rebuilt from its parameters would hold ``a == 1`` instead.
    The test expects ``clone`` to raise ``AssertionError`` in this case.
    """
    estimator = Buggy()
    estimator.a = 2
    assert_raises(AssertionError, clone, estimator)
def test_repr():
    """Smoke test the repr of estimators, and check it for nested ones."""
    my_estimator = MyEstimator()
    # Only checks that repr does not raise.
    repr(my_estimator)

    # Nested estimators must be rendered recursively, parameter by parameter.
    test = T(K(), K())
    assert_equal(repr(test),
                 "T(a=K(c=None, d=None), b=K(c=None, d=None))")
def test_str():
    """Smoke test the str of an estimator (only checks it does not raise)."""
    my_estimator = MyEstimator()
    str(my_estimator)
def test_get_params():
    """Check deep/shallow parameter listing and nested set_params.

    ``_get_params(deep=True)`` must expose the nested estimator's parameter
    as ``a__d``, and ``deep=False`` must not. ``set_params(a__d=...)`` must
    reach the nested estimator. Setting a parameter the nested estimator
    does not have is expected to raise AssertionError.
    """
    test = T(K(), K())

    assert_true('a__d' in test._get_params(deep=True))
    assert_true('a__d' not in test._get_params(deep=False))

    test.set_params(a__d=2)
    # assert_equal instead of a bare assert: a bare assert statement is
    # removed under ``python -O``, which would silently skip this check.
    assert_equal(test.a.d, 2)

    # K has no parameter named 'a'.
    assert_raises(AssertionError, test.set_params, a__a=2)
def test_is_classifier():
    """is_classifier recognises an SVC and meta-estimators wrapping one.

    Covers a bare SVC, a GridSearchCV around it, a Pipeline containing it,
    and a Pipeline containing a GridSearchCV around it.
    """
    from ..svm import SVC
    from ..pipeline import Pipeline
    from ..grid_search import GridSearchCV

    svc = SVC()
    assert_true(is_classifier(svc))

    searched = GridSearchCV(svc, {'C': [0.1, 1]})
    assert_true(is_classifier(searched))

    piped = Pipeline([('svc', svc)])
    assert_true(is_classifier(piped))

    piped_search = Pipeline([('svc_cv',
                              GridSearchCV(svc, {'C': [0.1, 1]}))])
    assert_true(is_classifier(piped_search))