Skip to content

Commit

Permalink
refactor(secrets): error handling for Azure auth problems
Browse files Browse the repository at this point in the history
  • Loading branch information
thekaveman committed Feb 7, 2024
1 parent 0e55a8c commit e055d86
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 17 deletions.
15 changes: 13 additions & 2 deletions benefits/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys

from azure.core.exceptions import ClientAuthenticationError
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from django.conf import settings
Expand Down Expand Up @@ -34,8 +35,18 @@ def get_secret_by_name(secret_name, client=None):
credential = DefaultAzureCredential()
client = SecretClient(vault_url=vault_url, credential=credential)

secret = client.get_secret(secret_name)
return secret.value
secret_value = None

if client is not None:
try:
secret = client.get_secret(secret_name)
secret_value = secret.value
except ClientAuthenticationError:
logger.error("Could not authenticate to Azure KeyVault")
else:
logger.error("Azure KeyVault SecretClient was not configured")

return secret_value


if __name__ == "__main__":
Expand Down
60 changes: 45 additions & 15 deletions tests/pytest/test_secrets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from azure.core.exceptions import ClientAuthenticationError

from benefits.secrets import KEY_VAULT_URL, get_secret_by_name

Expand All @@ -11,14 +12,20 @@ def mock_DefaultAzureCredential(mocker):
return credential_cls


@pytest.fixture
def secret_name():
return "the secret name"


@pytest.fixture
def secret_value():
return "the secret value"


@pytest.mark.parametrize("runtime_env", ["dev", "test", "prod"])
def test_get_secret_by_name__with_client__returns_secret_value(mocker, runtime_env, settings):
def test_get_secret_by_name__with_client__returns_secret_value(mocker, runtime_env, settings, secret_name, secret_value):
settings.RUNTIME_ENVIRONMENT = lambda: runtime_env

# set up the mock client class and expected return values

secret_name = "the secret name"
secret_value = "the secret value"
client = mocker.patch("benefits.secrets.SecretClient")
client.get_secret.return_value = mocker.Mock(value=secret_value)

Expand All @@ -29,16 +36,13 @@ def test_get_secret_by_name__with_client__returns_secret_value(mocker, runtime_e


@pytest.mark.parametrize("runtime_env", ["dev", "test", "prod"])
def test_get_secret_by_name__None_client__returns_secret_value(mocker, runtime_env, settings, mock_DefaultAzureCredential):
def test_get_secret_by_name__None_client__returns_secret_value(
mocker, runtime_env, settings, mock_DefaultAzureCredential, secret_name, secret_value
):
settings.RUNTIME_ENVIRONMENT = lambda: runtime_env
expected_keyvault_url = KEY_VAULT_URL.format(env=runtime_env[0])

# set up the mock client class and expected return values
# this test does not pass in a known client, instead checking that a client is constructed as expected

secret_name = "the secret name"
secret_value = "the secret value"

mock_credential = mock_DefaultAzureCredential.return_value
client_cls = mocker.patch("benefits.secrets.SecretClient")
client = client_cls.return_value
Expand All @@ -51,11 +55,37 @@ def test_get_secret_by_name__None_client__returns_secret_value(mocker, runtime_e
assert actual_value == secret_value


def test_get_secret_by_name__local__returns_environment_variable(mocker, settings):
settings.RUNTIME_ENVIRONMENT = lambda: "local"
@pytest.mark.parametrize("runtime_env", ["dev", "test", "prod"])
def test_get_secret_by_name__None_client__returns_None(mocker, runtime_env, settings, secret_name):
settings.RUNTIME_ENVIRONMENT = lambda: runtime_env

# this test forces construction of a new client to None
client_cls = mocker.patch("benefits.secrets.SecretClient", return_value=None)

actual_value = get_secret_by_name(secret_name)

client_cls.assert_called_once()
assert actual_value is None

secret_name = "the secret name"
secret_value = "the secret value"

@pytest.mark.parametrize("runtime_env", ["dev", "test", "prod"])
def test_get_secret_by_name__unauthenticated_client__returns_None(mocker, runtime_env, settings, secret_name):
settings.RUNTIME_ENVIRONMENT = lambda: runtime_env

# this test forces client.get_secret to throw an exception
client_cls = mocker.patch("benefits.secrets.SecretClient")
client = client_cls.return_value
client.get_secret.side_effect = ClientAuthenticationError

actual_value = get_secret_by_name(secret_name)

client_cls.assert_called_once()
client.get_secret.assert_called_once_with(secret_name)
assert actual_value is None


def test_get_secret_by_name__local__returns_environment_variable(mocker, settings, secret_name, secret_value):
settings.RUNTIME_ENVIRONMENT = lambda: "local"

env_spy = mocker.patch("benefits.secrets.os.environ.get", return_value=secret_value)
client_cls = mocker.patch("benefits.secrets.SecretClient")
Expand Down

0 comments on commit e055d86

Please sign in to comment.