-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
helper.py
131 lines (100 loc) · 3.53 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import contextlib
import sys
import unittest
import warnings
import pkg_resources
# Prefer the standalone ``mock`` package when it is installed, otherwise fall
# back to the standard library's ``unittest.mock``. If neither can be
# imported, ``_mock_error`` keeps the original ImportError so that
# ``_check_mock_available`` can report why mock is unavailable.
try:
    import mock
    _mock_error = None
except ImportError as e:
    try:
        from unittest import mock
        _mock_error = None
    except ImportError:
        # Still inside the outer ``except``, so ``e`` is bound here.
        _mock_error = e
def _check_mock_available():
    """Fail fast when the ``mock`` library could not be imported.

    Raises:
        RuntimeError: If ``_mock_error`` holds the ImportError recorded when
            importing mock; its text is embedded in the message.
    """
    if _mock_error is None:
        return
    raise RuntimeError(
        'mock is not available: Reason: {}'.format(_mock_error))
def with_requires(*requirements):
    """Run a test case only when given requirements are satisfied.

    .. admonition:: Example

       This test case runs only when `numpy>=1.10` is installed.

       >>> import unittest
       >>> from chainer import testing
       >>> class Test(unittest.TestCase):
       ...     @testing.with_requires('numpy>=1.10')
       ...     def test_for_numpy_1_10(self):
       ...         pass

    Args:
        requirements: Requirement strings that must all be satisfied for
            the decorated test case to run.

    Returns:
        A ``unittest.skipIf`` decorator that skips the test when any
        requirement cannot be resolved.
    """
    working_set = pkg_resources.WorkingSet()
    try:
        working_set.require(*requirements)
    except pkg_resources.ResolutionError:
        unmet = True
    else:
        unmet = False
    reason = 'requires: {}'.format(','.join(requirements))
    return unittest.skipIf(unmet, reason)
def without_requires(*requirements):
    """Run a test case only when given requirements are not satisfied.

    .. admonition:: Example

       This test case runs only when `numpy>=1.10` is not installed.

       >>> import unittest
       >>> from chainer import testing
       >>> class Test(unittest.TestCase):
       ...     @testing.without_requires('numpy>=1.10')
       ...     def test_without_numpy_1_10(self):
       ...         pass

    Args:
        requirements: Requirement strings; the decorated test case is
            skipped when all of them can be resolved.

    Returns:
        A ``unittest.skipIf`` decorator that skips the test when every
        requirement is satisfied.
    """
    ws = pkg_resources.WorkingSet()
    try:
        ws.require(*requirements)
    except pkg_resources.ResolutionError:
        # At least one requirement is missing or conflicting: run the test.
        skip = False
    else:
        skip = True
    # The test is skipped because the requirements ARE satisfied, so the
    # reason must say the test needs them to be absent.
    msg = 'requires not: {}'.format(','.join(requirements))
    return unittest.skipIf(skip, msg)
@contextlib.contextmanager
def assert_warns(expected):
    """Assert that the body of a ``with`` statement emits a given warning.

    Every warning raised inside the block is recorded, with the filter set
    to ``'always'`` for the duration. When the block finishes without an
    exception, an ``AssertionError`` is raised unless at least one recorded
    warning is an instance of ``expected``.

    Args:
        expected: Warning class, or tuple of classes, accepted by
            ``isinstance``.

    Raises:
        AssertionError: If no matching warning was recorded. The check is
            only performed on Python 3.
    """
    with warnings.catch_warnings(record=True) as records:
        warnings.simplefilter('always')
        yield

    # Python 2 does not raise warnings multiple times from the same stack
    # frame, so the check would be unreliable there.
    if sys.version_info < (3, 0):
        return
    for record in records:
        if isinstance(record.message, expected):
            return
    try:
        name = expected.__name__
    except AttributeError:
        # e.g. a tuple of warning classes has no ``__name__``.
        name = str(expected)
    raise AssertionError('%s not triggerred' % name)
def _import_object_from_name(fullname):
    """Resolve a dotted name to an object, starting from a loaded module.

    The first component of ``fullname`` must name a module already present
    in ``sys.modules`` (it is not imported here). Each following component
    is looked up as an attribute of the previous object.

    Args:
        fullname (str): Dotted path such as ``'package.module.function'``.

    Returns:
        The object the dotted path refers to.

    Raises:
        RuntimeError: If the top-level module is not in ``sys.modules``, or
            if any later component is missing or resolves to ``None``.
    """
    comps = fullname.split('.')
    obj = sys.modules.get(comps[0])
    if obj is None:
        raise RuntimeError('Can\'t import {}'.format(comps[0]))
    # Start counting at 1 so ``i`` is the real index of ``comp`` in
    # ``comps``; the error then names the component that failed.
    for i, comp in enumerate(comps[1:], 1):
        # Missing attributes are reported like ``None`` values instead of
        # leaking an AttributeError to the caller.
        obj = getattr(obj, comp, None)
        if obj is None:
            raise RuntimeError(
                'Can\'t find object {}'.format('.'.join(comps[:i + 1])))
    return obj
def patch(target, *args, **kwargs):
    """A wrapper of ``mock.patch`` which appends the ``wraps`` argument.

    Unless the caller supplies ``wraps`` explicitly, the object currently
    found at ``target`` is passed to ``mock.patch`` as ``wraps``.

    .. note::
       Unbound methods are not supported as ``wraps`` argument.

    Args:
        target (str): Full name of target object.
        wraps: Wrapping object which will be passed to ``mock.patch`` as
            ``wraps`` argument. If omitted, the object specified by
            ``target`` is used.
        *args: Passed to ``mock.patch``.
        **kwargs: Passed to ``mock.patch``.

    Returns:
        Whatever ``mock.patch`` returns for these arguments.

    Raises:
        RuntimeError: If mock is unavailable, or if ``wraps`` is omitted and
            ``target`` cannot be resolved.
    """
    _check_mock_available()
    if 'wraps' in kwargs:
        wraps = kwargs.pop('wraps')
    else:
        # Resolve the target lazily: only when no wrapper was given.
        wraps = _import_object_from_name(target)
    return mock.patch(target, *args, wraps=wraps, **kwargs)