Skip to content

Commit

Permalink
Changing default values for sys.initialize parameters `secret_share…
Browse files Browse the repository at this point in the history
…s` and `secret_threshold` (#1063)

* move mock adapter and fixture to conftest

* update sys.initialize with conditional defaults and warnings

* add integration test for sys.is_initialized

* add unit tests for sys.init module

* add init value passing test
  • Loading branch information
briantist committed Oct 13, 2023
1 parent 9befaa4 commit 42736b3
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 34 deletions.
64 changes: 47 additions & 17 deletions hvac/api/system_backend/init.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from hvac.api.system_backend.system_backend_mixin import SystemBackendMixin
from hvac.exceptions import ParamValidationError

Expand Down Expand Up @@ -28,8 +29,8 @@ def is_initialized(self):

def initialize(
self,
secret_shares=5,
secret_threshold=3,
secret_shares=None,
secret_threshold=None,
pgp_keys=None,
root_token_pgp_key=None,
stored_shares=None,
Expand All @@ -49,7 +50,7 @@ def initialize(
:type secret_shares: int
:param secret_threshold: Specifies the number of shares required to reconstruct the master key. This must be
less than or equal secret_shares. If using Vault HSM with auto-unsealing, this value must be the same as
secret_shares.
secret_shares, or ommitted, depending on the version of Vault and the seal type.
:type secret_threshold: int
:param pgp_keys: List of PGP public keys used to encrypt the output unseal keys.
Ordering is preserved. The keys must be base64-encoded from their original binary representation.
Expand All @@ -73,20 +74,49 @@ def initialize(
:return: The JSON response of the request.
:rtype: dict
"""

# TODO(v3.0.0): remove this
if recovery_shares is None and secret_shares is None:
msg = (
"The secret_shares parameter will default to None in hvac v3.0.0. "
"To use the old default with no warning, explicitly set this value to 5. "
"See https://github.com/hvac/hvac/issues/1030"
)
warnings.warn(
message=msg,
category=DeprecationWarning,
stacklevel=2,
)
secret_shares = 5

# TODO(v3.0.0): remove this
if recovery_threshold is None and secret_threshold is None:
msg = (
"The secret_threshold parameter will default to None in hvac v3.0.0. "
"To use the old default with no warning, explicitly set this value to 3. "
"See https://github.com/hvac/hvac/issues/1030"
)
warnings.warn(
message=msg,
category=DeprecationWarning,
stacklevel=2,
)
secret_threshold = 3

params = {
"secret_shares": secret_shares,
"secret_threshold": secret_threshold,
"root_token_pgp_key": root_token_pgp_key,
}

if pgp_keys is not None:
if pgp_keys is not None and secret_shares is not None:
if len(pgp_keys) != secret_shares:
raise ParamValidationError(
"length of pgp_keys list argument must equal secret_shares value"
)
params["pgp_keys"] = pgp_keys

if stored_shares is not None:
if stored_shares is not None and secret_shares is not None:
if stored_shares != secret_shares:
raise ParamValidationError(
"value for stored_shares argument must equal secret_shares argument"
Expand All @@ -96,18 +126,18 @@ def initialize(
if recovery_shares is not None:
params["recovery_shares"] = recovery_shares

if recovery_threshold is not None:
if recovery_threshold > recovery_shares:
error_msg = "value for recovery_threshold argument be less than or equal to recovery_shares argument"
raise ParamValidationError(error_msg)
params["recovery_threshold"] = recovery_threshold

if recovery_pgp_keys is not None:
if len(recovery_pgp_keys) != recovery_shares:
raise ParamValidationError(
"length of recovery_pgp_keys list argument must equal recovery_shares value"
)
params["recovery_pgp_keys"] = recovery_pgp_keys
if recovery_threshold is not None:
if recovery_threshold > recovery_shares:
error_msg = "value for recovery_threshold argument must be less than or equal to recovery_shares argument"
raise ParamValidationError(error_msg)
params["recovery_threshold"] = recovery_threshold

if recovery_pgp_keys is not None:
if len(recovery_pgp_keys) != recovery_shares:
raise ParamValidationError(
"length of recovery_pgp_keys list argument must equal recovery_shares value"
)
params["recovery_pgp_keys"] = recovery_pgp_keys

api_path = "/v1/sys/init"
return self._adapter.put(
Expand Down
5 changes: 5 additions & 0 deletions tests/integration_tests/api/system_backend/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ def test_read_init_status(self):
read_response = self.client.sys.read_init_status()
logging.debug("read_response: %s" % read_response)
self.assertTrue(expr=read_response["initialized"])

def test_is_initialized(self):
is_initialized_response = self.client.sys.is_initialized()
logging.debug("is_initialized_response: %s" % is_initialized_response)
self.assertTrue(expr=is_initialized_response)
19 changes: 2 additions & 17 deletions tests/unit_tests/api/auth_methods/test_token.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,11 @@
import pytest

from unittest import mock
from hvac.api.auth_methods.token import Token
from hvac.adapters import Adapter


# TODO: move this to a conftest.py somewhere
class MockAdapter(Adapter):
def __init__(self, *args, **kwargs):
kwargs["session"] = mock.MagicMock()
super().__init__(*args, **kwargs)

def request(self, *args, **kwargs):
return (args, kwargs)

def get_login_token(self, response):
raise NotImplementedError()


@pytest.fixture
def token_auth():
return Token(MockAdapter())
def token_auth(mock_adapter):
return Token(mock_adapter)


class TestToken:
Expand Down
175 changes: 175 additions & 0 deletions tests/unit_tests/api/system_backend/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import pytest

from hvac.exceptions import ParamValidationError
from hvac.api.system_backend.init import Init


@pytest.fixture
def sys_init(mock_adapter):
return Init(mock_adapter)


INIT_SECRET_PGP_ERROR_MSG = (
r"length of pgp_keys list argument must equal secret_shares value"
)
INIT_RECOVERY_PGP_ERROR_MSG = (
r"length of recovery_pgp_keys list argument must equal recovery_shares value"
)
INIT_RECOVERY_SHARES_ERROR_MSG = r"value for recovery_threshold argument must be less than or equal to recovery_shares argument"
INIT_STORED_SHARES_ERROR_MSG = (
r"value for stored_shares argument must equal secret_shares argument"
)


class TestInit:
@pytest.mark.parametrize(
["secret_shares", "recovery_shares", "expected_value", "expected_warn"],
[
(None, None, 5, True),
(3, None, 3, False),
(5, 7, 5, False),
(None, 9, None, False),
],
)
def test_initialize_default_secret_shares(
self,
sys_init,
mock_warn,
secret_shares,
recovery_shares,
expected_value,
expected_warn,
):
(r_args, r_kwargs) = sys_init.initialize(
secret_shares=secret_shares,
recovery_shares=recovery_shares,
recovery_threshold=0,
)
params = r_kwargs["json"]
assert params["secret_shares"] == expected_value

if expected_warn:
mock_warn.assert_called_once()
else:
mock_warn.assert_not_called()

@pytest.mark.parametrize(
["secret_threshold", "recovery_threshold", "expected_value", "expected_warn"],
[
(None, None, 3, True),
(3, None, 3, False),
(5, 7, 5, False),
(None, 9, None, False),
],
)
def test_initialize_default_secret_threshold(
self,
sys_init,
mock_warn,
secret_threshold,
recovery_threshold,
expected_value,
expected_warn,
):
(r_args, r_kwargs) = sys_init.initialize(
secret_threshold=secret_threshold,
recovery_threshold=recovery_threshold,
secret_shares=0,
)
params = r_kwargs["json"]
assert params["secret_threshold"] == expected_value

if expected_warn:
mock_warn.assert_called_once()
else:
mock_warn.assert_not_called()

@pytest.mark.parametrize(
[
"secret_shares",
"pgp_keys",
"stored_shares",
"recovery_shares",
"recovery_pgp_keys",
"recovery_threshold",
"exc_msg",
],
[
(
2,
[1, 2, 3],
2,
None,
None,
None,
INIT_SECRET_PGP_ERROR_MSG,
),
(
2,
[1, 2, 3],
3,
None,
None,
None,
INIT_SECRET_PGP_ERROR_MSG,
),
(
2,
[1, 2],
3,
None,
None,
None,
INIT_STORED_SHARES_ERROR_MSG,
),
(2, [1, 2], 2, 3, [1, 2], None, INIT_RECOVERY_PGP_ERROR_MSG),
(2, [1, 2], 2, 3, [1, 2], 1, INIT_RECOVERY_PGP_ERROR_MSG),
(2, [1, 2], 2, 3, [1, 2], 9, INIT_RECOVERY_SHARES_ERROR_MSG),
],
)
def test_initialize_errors(
self,
sys_init,
mock_adapter,
secret_shares,
pgp_keys,
stored_shares,
recovery_shares,
recovery_pgp_keys,
recovery_threshold,
exc_msg,
):
with pytest.raises(ParamValidationError, match=exc_msg):
sys_init.initialize(
secret_threshold=0, # TODO(v3.0.0): remove this, only set to suppress warning
secret_shares=secret_shares,
pgp_keys=pgp_keys,
stored_shares=stored_shares,
recovery_shares=recovery_shares,
recovery_pgp_keys=recovery_pgp_keys,
recovery_threshold=recovery_threshold,
)

mock_adapter.request.assert_not_called()

def test_initialize_value_pass(self, sys_init):
(r_args, r_kwargs) = sys_init.initialize(
secret_threshold=0,
secret_shares=2,
root_token_pgp_key="abc",
pgp_keys=[1, 2],
stored_shares=2,
recovery_shares=3,
recovery_pgp_keys=[1, 2, 3],
recovery_threshold=3,
)
params = r_kwargs["json"]

assert params["secret_threshold"] == 0
assert params["secret_shares"] == 2
assert params["root_token_pgp_key"] == "abc"
assert params["pgp_keys"] == [1, 2]
assert params["stored_shares"] == 2
assert params["recovery_shares"] == 3
assert params["recovery_pgp_keys"] == [1, 2, 3]
assert params["recovery_threshold"] == 3
31 changes: 31 additions & 0 deletions tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from unittest import mock

from hvac.adapters import Adapter


class MockAdapter(Adapter):
def __init__(self, *args, **kwargs):
if "session" not in kwargs:
kwargs["session"] = mock.MagicMock()
super().__init__(*args, **kwargs)

def request(self, *args, **kwargs):
return (args, kwargs)

def get_login_token(self, response):
raise NotImplementedError()


@pytest.fixture
def mock_adapter():
adapter = MockAdapter()
with mock.patch.object(adapter, "request", mock.Mock(wraps=MockAdapter.request)):
yield adapter


@pytest.fixture
def mock_warn():
with mock.patch("warnings.warn") as warn:
yield warn

0 comments on commit 42736b3

Please sign in to comment.