-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
parameterized.py
160 lines (123 loc) · 4.86 KB
/
parameterized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import functools
import itertools
import types
import unittest
import six
from chainer.testing import _bundle
from chainer import utils
def _param_to_str(obj):
    """Return a short, human-readable representation of a parameter value.

    Classes are rendered by their bare ``__name__`` (e.g. ``int`` rather than
    ``<class 'int'>``); every other object is rendered with ``repr``.
    """
    return obj.__name__ if isinstance(obj, type) else repr(obj)
def _shorten(s, maxlen):
    """Shorten ``s`` to at most ``maxlen`` characters.

    If ``s`` is longer than ``maxlen``, its middle part is replaced with the
    3-dots string ``'...'`` so that the result is exactly ``maxlen``
    characters long and keeps both the head and the tail of ``s``. Strings
    that already fit are returned unchanged.

    Raises:
        ValueError: if ``s`` must be shortened but ``maxlen`` is smaller than
            the ellipsis itself.
    """
    ellipsis = '...'
    if len(s) <= maxlen:
        return s
    if maxlen < len(ellipsis):
        raise ValueError(
            'maxlen must be at least {} to shorten a string; got {}'.format(
                len(ellipsis), maxlen))
    n1 = (maxlen - len(ellipsis)) // 2  # characters kept from the head
    n2 = maxlen - len(ellipsis) - n1  # characters kept from the tail
    # Use an explicit start index for the tail: ``s[-n2:]`` would return the
    # whole string when ``n2 == 0`` (i.e. ``maxlen == 3``).
    s = s[:n1] + ellipsis + s[len(s) - n2:]
    assert len(s) == maxlen
    return s
def _make_class_name(base_class_name, i_param, param):
    """Return the class name for the ``i_param``-th parameter combination.

    The name has the form ``<base>_param_<i>_{k1=v1, k2=v2}`` with keys in
    sorted order. Each value is rendered by ``_param_to_str`` and shortened
    to 100 characters, and the whole ``k=v`` list to 5000 characters (both
    via ``_shorten``), so large parameter values do not produce unwieldy
    class names.
    """
    value_maxlen = 100  # limit for one rendered parameter value
    joined_maxlen = 5000  # limit for the whole "k=v, ..." part

    pieces = []
    for key, value in sorted(param.items()):
        rendered = _shorten(_param_to_str(value), value_maxlen)
        pieces.append('{}={}'.format(key, rendered))
    joined = _shorten(', '.join(pieces), joined_maxlen)
    return '{}_param_{}_{{{}}}'.format(base_class_name, i_param, joined)
def _make_str_method(base, param):
    # Builds a ``__str__`` that appends this class's own parameter set to
    # the base test's string representation.
    def __str__(self):
        name = base.__str__(self)
        return '%s parameter: %s' % (name, param)
    return __str__


def _make_unbound_function(f):
    # Wraps a function-valued parameter so that calling it through a test
    # instance drops ``self`` and forwards only the explicit arguments.
    def new_v(self, *args, **kwargs):
        return f(*args, **kwargs)
    return new_v


def _make_method_generator(base, param):
    # Returns a callable that wraps a test method of ``base`` so that a
    # failure reports the parameter set the test ran with.
    def method_generator(base_method):
        @functools.wraps(base_method)
        def new_method(self, *args, **kwargs):
            try:
                return base_method(self, *args, **kwargs)
            except unittest.SkipTest:
                # Skips are not failures; propagate them untouched.
                raise
            except Exception as e:
                s = six.StringIO()
                s.write('Parameterized test failed.\n\n')
                s.write('Base test method: {}.{}\n'.format(
                    base.__name__, base_method.__name__))
                s.write('Test parameters:\n')
                for k, v in sorted(param.items()):
                    s.write(' {}: {}\n'.format(k, v))
                # Presumably re-raises ``e.__class__`` with this message,
                # chained to ``e`` -- see ``chainer.utils._raise_from``.
                utils._raise_from(e.__class__, s.getvalue(), e)
        return new_method
    return method_generator


def _parameterize_test_case_generator(base, params):
    """Yield the specifications of the parameterized test case classes.

    For the ``i``-th dict ``param`` in ``params`` this yields a 3-tuple
    ``(cls_name, members, method_generator)``:

    * ``cls_name``: the class name built by ``_make_class_name``.
    * ``members``: a dict of class members. Every parameter becomes a member;
      plain-function values are wrapped so that calling them through an
      instance does not pass ``self``. ``__str__`` is overridden to append
      the parameter set to the base test's string.
    * ``method_generator``: a callable that wraps a test method of ``base``
      so that non-skip failures are re-raised with the parameter values
      listed in the message.

    Each closure is built by a factory that receives its own ``param``.
    Closing over the loop variable directly (as ``__str__`` used to) makes
    every generated class report the *last* parameter set: ``__str__`` only
    runs when tests run, after this generator has been advanced/exhausted.
    """
    for i, param in enumerate(params):
        cls_name = _make_class_name(base.__name__, i, param)
        mb = {'__str__': _make_str_method(base, param)}
        for k, v in sorted(param.items()):
            if isinstance(v, types.FunctionType):
                mb[k] = _make_unbound_function(v)
            else:
                mb[k] = v
        yield (cls_name, mb, _make_method_generator(base, param))
def parameterize(*params):
    """Return a class decorator that parameterizes a test case class.

    Each positional argument is a dict mapping attribute names to values.
    For every dict, ``_parameterize_test_case_generator`` yields the
    specification of one generated test case class carrying those values as
    members. ``_bundle.make_decorator`` turns that generator into the
    decorator (how the original class is treated is defined in ``_bundle``).
    """
    def generate(base):
        return _parameterize_test_case_generator(base, params)

    return _bundle.make_decorator(generate)
def _values_to_dicts(names, values):
    """Convert pytest-style ``(names, values)`` into a list of param dicts.

    Args:
        names: Comma-separated parameter names, e.g. ``'shape, dtype'``.
            Whitespace around each name is stripped, matching how
            ``pytest.mark.parametrize`` parses its ``argnames`` string.
        values: A tuple or list with one entry per test case. With a single
            name, each entry is used as-is as that parameter's value (even if
            it is itself a tuple); otherwise each entry must be a tuple or
            list holding exactly one value per name.

    Returns:
        A list of dicts, one per entry of ``values``, mapping each name to
        its value.
    """
    assert isinstance(names, six.string_types)
    assert isinstance(values, (tuple, list))

    def safe_zip(ns, vs):
        if len(ns) == 1:
            return [(ns[0], vs)]
        assert isinstance(vs, (tuple, list)) and len(ns) == len(vs)
        return zip(ns, vs)

    # Strip surrounding whitespace so 'a, b' yields keys 'a' and 'b',
    # not 'a' and ' b'.
    names = [name.strip() for name in names.split(',')]
    params = [dict(safe_zip(names, value_list)) for value_list in values]
    return params
def from_pytest_parameterize(names, values):
    """Build a list of parameter dicts from pytest-style arguments.

    ``names`` is a comma-separated string of parameter names and ``values``
    a tuple or list with one entry per test case, in the style of
    ``pytest.mark.parametrize``. The returned dicts can be passed to
    ``parameterize`` as positional arguments.
    """
    param_dicts = _values_to_dicts(names, values)
    return param_dicts
def parameterize_pytest(names, values):
    """Pytest-style counterpart of ``parameterize``.

    Equivalent to ``parameterize(*from_pytest_parameterize(names, values))``:
    ``names`` is a comma-separated string of parameter names and ``values``
    a tuple or list with one entry per generated test case.
    """
    param_dicts = from_pytest_parameterize(names, values)
    return parameterize(*param_dicts)
def product(parameter):
    """Return the Cartesian product of parameter sets as a list of dicts.

    Args:
        parameter: Either

            * a dict mapping pytest-style comma-separated names to a tuple or
              list of values; each item is converted with
              ``_values_to_dicts`` (in sorted key order) and the results are
              combined as below, or
            * a list of lists of dicts: one dict is picked from each inner
              list and the picked dicts are merged, later ones overriding
              earlier ones on duplicate keys.

    Returns:
        A list with one merged dict per combination.

    Raises:
        TypeError: if ``parameter`` is neither a dict nor a list, or is a
            list that is not a list of lists of dicts.
    """
    if isinstance(parameter, dict):
        groups = [
            _values_to_dicts(names, values)
            for names, values in sorted(parameter.items())]
        return product(groups)

    if not isinstance(parameter, list):
        raise TypeError(
            'parameter must be either dict or list. Actual: {}'.format(
                type(parameter)))

    # From here on, ``parameter`` must be a list of lists of dicts.
    if not all(isinstance(group, list) for group in parameter):
        raise TypeError('parameter must be list of lists of dicts')
    if not all(isinstance(item, dict)
               for group in parameter for item in group):
        raise TypeError('parameter must be list of lists of dicts')

    combined = []
    for combination in itertools.product(*parameter):
        merged = {}
        for chosen in combination:
            merged.update(chosen)
        combined.append(merged)
    return combined
def product_dict(*parameters):
    """Merge every combination of one dict taken from each argument.

    Each positional argument is an iterable of dicts. For every combination
    produced by ``itertools.product`` (one dict per argument), the chosen
    dicts are merged into a single dict, with keys from later arguments
    overriding earlier ones. Returns the list of merged dicts.
    """
    def merge(dicts):
        return {key: value
                for dic in dicts
                for key, value in six.iteritems(dic)}

    return [merge(combo) for combo in itertools.product(*parameters)]