-
Notifications
You must be signed in to change notification settings - Fork 0
/
__init__.py
353 lines (287 loc) · 10.8 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
"""Collection of methods used by various hvac test cases."""
import base64
import json
import logging
import operator
import os
import re
import socket
import subprocess
import typing as t
from shutil import which
from unittest import SkipTest, mock
from hvac import Client
from packaging.version import Version
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Matches the literal text "Vault v" followed by a run of digits and dots, capturing that run.
# NOTE(review): not referenced by the functions in this module — presumably used by importers; confirm.
VERSION_REGEX = re.compile(r"Vault v([0-9.]+)")
# NOTE(review): hard-coded version string that no function in this module reads;
# its meaning ("latest" as of when?) should be confirmed against its users.
LATEST_VAULT_VERSION = "1.1.3"
def get_vault_version_string():
    """Return the version reported by the local ``vault`` executable.

    The result is memoised on the function object itself (as the ``cache``
    attribute), so the executable is run at most once per process.

    :return: Second whitespace-separated token of the ``vault -version`` output,
        with any leading "v" characters removed.
    :rtype: str
    :raises unittest.SkipTest: If no ``vault`` executable can be located.
    """
    try:
        return get_vault_version_string.cache
    except AttributeError:
        # Nothing memoised yet; fall through and ask the executable.
        pass

    vault_path = which("vault")
    if not vault_path:
        raise SkipTest("Vault executable not found")

    popen_kwargs = get_popen_kwargs(
        args=["vault", "-version"], stdout=subprocess.PIPE
    )
    stdout, _ = subprocess.Popen(**popen_kwargs).communicate()
    tokens = stdout.strip().split()
    parsed = tokens[1].lstrip("v")

    get_vault_version_string.cache = parsed
    return parsed
def get_installed_vault_version():
    """Return the installed Vault version rewritten into a parsable version string.

    Pre-release markers are compacted (e.g. ``1.0.0-beta1`` => ``1.0.0b1``) and
    the ``+ent`` marker is removed.

    :return: The normalised version string.
    :rtype: str
    """
    normalised = get_vault_version_string()
    # Applied strictly in this order: drop dashes, shorten "beta", drop "+ent".
    for needle, substitute in (("-", ""), ("beta", "b"), ("+ent", "")):
        normalised = normalised.replace(needle, substitute)
    return normalised
def is_enterprise():
    """Tell whether the installed Vault version string ends with ``+ent``.

    :return: True when the version string carries a trailing ``+ent`` marker, False otherwise.
    :rtype: bool
    """
    enterprise_suffix = re.search(r"\+ent$", get_vault_version_string())
    return enterprise_suffix is not None
def if_vault_version(supported_version, comparison=operator.lt):
    """Compare the installed Vault version against ``supported_version``.

    :param supported_version: Version string to compare the installed version against.
    :type supported_version: str
    :param comparison: Two-argument callable invoked as
        ``comparison(installed, supported)``; defaults to ``operator.lt``.
    :type comparison: callable
    :return: Whatever ``comparison`` returns for the two parsed versions.
    """
    installed = Version(get_installed_vault_version())
    reference = Version(supported_version)
    return comparison(installed, reference)
def vault_version_lt(supported_version):
    """Tell whether the installed Vault version is strictly lower than ``supported_version``.

    :param supported_version: Version string to compare the installed version against.
    :type supported_version: str
    :return: Result of the ``operator.lt`` comparison.
    """
    is_lower = if_vault_version(supported_version, operator.lt)
    return is_lower
def vault_version_ge(supported_version):
    """Tell whether the installed Vault version is greater than or equal to ``supported_version``.

    :param supported_version: Version string to compare the installed version against.
    :type supported_version: str
    :return: Result of the ``operator.ge`` comparison.
    """
    is_at_least = if_vault_version(supported_version, operator.ge)
    return is_at_least
def vault_version_eq(supported_version):
    """Tell whether the installed Vault version equals ``supported_version``.

    :param supported_version: Version string to compare the installed version against.
    :type supported_version: str
    :return: Result of the ``operator.eq`` comparison.
    """
    is_same = if_vault_version(supported_version, operator.eq)
    return is_same
def get_generate_root_otp():
    """Pick the fixed test OTP matching the Vault version under test.

    Three different OTP values are used: one for Vault >= 1.10.0, one for
    Vault >= 1.0.0, and one for anything older.

    :return: OTP to use in generate root operations
    :rtype: str
    """
    if vault_version_ge("1.10.0"):
        return "BMjzW3wAsEzINXCM05Wbas3u9zSl"
    if vault_version_ge("1.0.0"):
        return "ygs0vL8GIxu0AjRVEmJ5jLCVq8"
    return "RSMGkAqBH5WnVLrDTbZ+UQ=="
def create_client(url, use_env=False, **kwargs):
    """Build a :py:class:`hvac.v1.Client` wired up with the test certificates.

    :param url: Vault address to configure the client with.
    :type url: str
    :param use_env: When true, the certificate paths are not passed as arguments;
        instead ``hvac.v1.VAULT_CAPATH``, ``hvac.v1.VAULT_CLIENT_CERT`` and
        ``hvac.v1.VAULT_CLIENT_KEY`` are patched for the duration of the
        client's construction.
    :type use_env: bool
    :param kwargs: Additional keyword arguments passed through to the Client being created.
    :type kwargs: dict
    :return: Instantiated :py:class:`hvac.v1.Client` class.
    :rtype: hvac.v1.Client
    """
    client_cert_path = get_config_file_path("client-cert.pem")
    client_key_path = get_config_file_path("client-key.pem")
    server_cert_path = get_config_file_path("server-cert.pem")

    if not use_env:
        # Explicit configuration: hand the certificate material straight to the client.
        return Client(
            url=url,
            cert=(client_cert_path, client_key_path),
            verify=server_cert_path,
            **kwargs,
        )

    # The patches are only active while the client is being constructed.
    with mock.patch("hvac.v1.VAULT_CAPATH", server_cert_path), \
            mock.patch("hvac.v1.VAULT_CLIENT_CERT", client_cert_path), \
            mock.patch("hvac.v1.VAULT_CLIENT_KEY", client_key_path):
        return Client(url=url, **kwargs)
class PortGetter:
    """Context manager that hands out ``(address, port)`` pairs backed by bound sockets.

    Every socket bound by :py:meth:`get_port` is kept open until the context
    manager exits, at which point all of them are closed. An instance can be
    entered only once at a time, but may be re-entered after it has exited.
    """

    # Whether the context manager is currently entered.
    _entered: bool = False
    # Sockets bound by get_port(). Created per instance in __init__: a class-level
    # list would be shared by all instances, letting one getter clear or close
    # the sockets of another.
    _sockets: t.List[socket.socket]

    def __init__(self, default_address: str = "localhost"):
        """
        :param default_address: Address to bind to when get_port() is called without one.
        :type default_address: str
        """
        self._default_addr = default_address
        self._sockets = []

    class PortGetterProtocol(t.Protocol):
        """Structural type with the same call signature as :py:meth:`PortGetter.get_port`."""

        def __call__(
            self,
            *,
            address: t.Optional[str] = None,
            port: t.Optional[int] = None,
            proto: socket.SocketKind = socket.SOCK_STREAM,
        ) -> t.Tuple[str, int]:
            pass

    def get_port(
        self,
        *,
        address: t.Optional[str] = None,
        port: t.Optional[int] = None,
        proto: socket.SocketKind = socket.SOCK_STREAM,
    ) -> t.Tuple[str, int]:
        """Bind an IPv4 socket and return the address it ended up bound to.

        :param address: Address to bind to; defaults to the instance's default address.
        :type address: str | None
        :param port: Preferred port. If binding it fails with ``OSError``, a port
            chosen by the OS is used instead. ``None`` always lets the OS choose.
        :type port: int | None
        :param proto: Socket type passed to :py:func:`socket.socket`.
        :type proto: socket.SocketKind
        :return: The bound socket's name as reported by ``getsockname()``.
        :rtype: tuple
        :raises RuntimeError: If called outside of the context manager.
        """
        if not self._entered:
            raise RuntimeError("Enter the context manager before calling get_port.")

        if address is None:
            address = self._default_addr

        s = socket.socket(socket.AF_INET, type=proto)
        try:
            if port is not None:
                try:
                    s.bind((address, port))
                except OSError:
                    # Preferred port unavailable: fall back to an OS-assigned one.
                    s.bind((address, 0))
            else:
                s.bind((address, 0))
        except Exception:
            # The socket is not tracked yet, so close it here rather than leak it.
            s.close()
            raise

        self._sockets.append(s)
        return s.getsockname()

    def __enter__(self):
        """Mark the getter as active and start with an empty socket list.

        :raises RuntimeError: If the instance is already entered.
        """
        if self._entered:
            raise RuntimeError(
                "This context manager can only be entered once at a time. Exit first or use a new instance."
            )
        self._entered = True
        self._sockets.clear()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Close every socket handed out during this context and reset the state."""
        for sock in self._sockets:
            try:
                sock.close()
            except Exception:
                # Deliberately best-effort: one failing close must not prevent
                # the remaining sockets from being closed.
                pass
        self._sockets.clear()
        self._entered = False
def get_free_port():
    """Small helper method used to discover an open port to use by mock API HTTP servers.

    The probing socket is closed before this function returns, so the port is
    no longer held by the time the caller uses it.

    :return: An available port number.
    :rtype: int
    """
    # TODO: deprecated: remove in favor of port getter class once LDAP mock is refactored
    # The context manager closes the socket even if bind()/getsockname() raises.
    with socket.socket(socket.AF_INET, type=socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        _, port = s.getsockname()
    return port
def load_config_file(filename):
    """Read a test config file and return its contents as text.

    :param filename: Name of the test data file, resolved via get_config_file_path().
    :type filename: str | unicode
    :return: Test data contents
    :rtype: str | unicode
    """
    with open(get_config_file_path(filename)) as config_file:
        return config_file.read()
def get_config_file_path(*components):
    """Build the absolute path of a file under the "config_files" test data directory.

    The directory sits next to this module's parent directory, and holds the
    self-signed certificates, configuration files, etc. used by various tests.

    :param components: One or more path components, the last of which is usually the name of the test data file.
    :type components: str | unicode
    :return: Absolute path of the config_files directory joined with ``components``.
    :rtype: str | unicode
    """
    # Anchor on this module's real location so the result is independent of the working directory.
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_dir = os.path.abspath(os.path.join(module_dir, "..", "config_files"))
    return os.path.join(config_dir, *components)
def decode_generated_root_token(encoded_token, otp, url):
    """Decode a newly generated root token via Vault CLI.

    Runs ``vault [operator] generate-root -decode ...`` against ``url`` (with
    TLS verification skipped) and extracts the token from its output.

    :param encoded_token: The token to decode.
    :type encoded_token: str | unicode
    :param otp: OTP code to use when decoding the token.
    :type otp: str | unicode
    :param url: Vault address passed to the CLI via ``-address``.
    :type url: str | unicode
    :return: The decoded root token.
    :rtype: str | unicode
    """
    # Log through this module's named logger rather than the root-logger
    # convenience functions, which implicitly configure the root logger.
    log = logging.getLogger(__name__)

    command = ["vault"]
    if vault_version_ge("0.9.6"):
        # before Vault ~0.9.6, the generate-root command was the first positional argument
        # afterwards, it was moved under the "operator" category
        command.append("operator")
    command.extend(
        [
            "generate-root",
            "-address",
            url,
            "-tls-skip-verify",
            "-decode",
            encoded_token,
            "-otp",
            otp,
        ]
    )

    process = subprocess.Popen(
        **get_popen_kwargs(args=command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    )
    stdout, stderr = process.communicate()

    # Lazy %-style arguments: the message is only formatted if the record is emitted.
    log.debug('decode_generated_root_token stdout: "%s"', stdout)
    if stderr:
        log.error("decode_generated_root_token stderr: %s", stderr)

    try:
        # On the off chance VAULT_FORMAT=json or such is set in the test environment:
        new_token = json.loads(stdout)["token"]
    except ValueError:
        # Plain-text output: drop the optional "Root token:" label.
        new_token = stdout.replace("Root token:", "")
    return new_token.strip()
def get_popen_kwargs(**popen_kwargs):
    """Return the given subprocess.Popen keyword arguments with ``encoding`` forced to "utf-8".

    Any ``encoding`` value supplied by the caller is overwritten.

    :param popen_kwargs: Keyword arguments intended for subprocess.Popen.
    :type popen_kwargs: **kwargs
    :return: The same keyword arguments plus ``encoding="utf-8"``.
    :rtype: dict
    """
    popen_kwargs.update(encoding="utf-8")
    return popen_kwargs
def base64ify(bytes_or_str):
    """Encode the input with the URL-safe base64 alphabet.

    :param bytes_or_str: Data to encode; a ``str`` is first encoded as UTF-8,
        anything else is passed to the encoder unchanged.
    :type bytes_or_str: str | bytes
    :return: The URL-safe base64 encoding, decoded to text as ASCII.
    :rtype: str
    """
    raw = bytes_or_str.encode("utf8") if isinstance(bytes_or_str, str) else bytes_or_str
    return base64.urlsafe_b64encode(raw).decode("ascii")
def configure_pki(
    client, common_name="hvac.com", role_name="my-role", mount_point="pki"
):
    """Helper function to configure a pki backend for integration tests that need to work with lease IDs.

    If a secrets engine is already mounted at ``mount_point`` it is disabled
    first; a pki engine is then enabled there, and an internal root certificate,
    the URL configuration and a role are written, in that order.

    :param client: Authenticated hvac Client instance.
    :type client: hvac.Client
    :param common_name: Common name to configure in the pki backend
    :type common_name: str
    :param role_name: Name of the test role to configure.
    :type role_name: str
    :param mount_point: The path the pki backend is mounted under.
    :type mount_point: str
    :return: Nothing.
    :rtype: None.
    """
    mounted_engines = client.sys.list_mounted_secrets_engines()
    if f"{mount_point}/" in mounted_engines:
        client.sys.disable_secrets_engine(mount_point)
    client.sys.enable_secrets_engine(backend_type="pki", path=mount_point)

    # (path, payload) pairs, written in exactly this order.
    writes = (
        (
            f"{mount_point}/root/generate/internal",
            {"common_name": common_name, "ttl": "8760h"},
        ),
        (
            f"{mount_point}/config/urls",
            {
                "issuing_certificates": "http://127.0.0.1:8200/v1/pki/ca",
                "crl_distribution_points": "http://127.0.0.1:8200/v1/pki/crl",
            },
        ),
        (
            f"{mount_point}/roles/{role_name}",
            {
                "allowed_domains": common_name,
                "allow_subdomains": True,
                "generate_lease": True,
                "max_ttl": "72h",
            },
        ),
    )
    for target_path, payload in writes:
        client.write_data(path=target_path, data=payload)
def disable_pki(client, mount_point="pki"):
    """Disable the secrets engine mounted at ``mount_point``.

    :param client: Authenticated hvac Client instance.
    :type client: hvac.Client
    :param mount_point: The path the pki backend is mounted under.
    :type mount_point: str
    """
    system_backend = client.sys
    system_backend.disable_secrets_engine(mount_point)