Skip to content

Commit

Permalink
fix: handle extra SSM parameters (#11)
Browse files Browse the repository at this point in the history
* bugfix: handle extra SSM parameters

Avoid issues of `extra fields not permitted (type=value_error.extra)` when there are non-relevant
params stored in SSM.

* Fixes

* Fix

* Flake8 fix

* Flake8 fix again

* Cleanup docs
  • Loading branch information
alukach committed Sep 3, 2022
1 parent aee7145 commit 4b436de
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 20 deletions.
7 changes: 6 additions & 1 deletion pydantic_ssm_settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def customise_sources(
file_secret_settings: SecretsSettingsSource,
) -> Tuple[SettingsSourceCallable, ...]:

ssm_settings = AwsSsmSettingsSource(
ssm_prefix=file_secret_settings.secrets_dir,
env_nested_delimiter=env_settings.env_nested_delimiter,
)

return (
init_settings,
env_settings,
Expand All @@ -30,5 +35,5 @@ def customise_sources(
# about unexpected arguments. `secrets_dir` comes from `_secrets_dir`,
# one of the few special kwargs that Pydantic will allow:
# https://github.com/samuelcolvin/pydantic/blob/45db4ad3aa558879824a91dd3b011d0449eb2977/pydantic/env_settings.py#L33
AwsSsmSettingsSource(ssm_prefix=file_secret_settings.secrets_dir),
ssm_settings,
)
136 changes: 117 additions & 19 deletions pydantic_ssm_settings/source.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple

from botocore.exceptions import ClientError
from botocore.client import Config
import boto3

from pydantic import BaseSettings, typing
from pydantic import BaseSettings
from pydantic.typing import StrPath, get_origin, is_union
from pydantic.utils import deep_update
from pydantic.fields import ModelField

if TYPE_CHECKING:
from mypy_boto3_ssm.client import SSMClient
Expand All @@ -16,11 +19,20 @@
logger = logging.getLogger(__name__)


class SettingsError(ValueError):
pass


class AwsSsmSettingsSource:
__slots__ = ("ssm_prefix",)
__slots__ = ("ssm_prefix", "env_nested_delimiter")

def __init__(self, ssm_prefix: Optional[typing.StrPath]):
self.ssm_prefix: Optional[typing.StrPath] = ssm_prefix
def __init__(
self,
ssm_prefix: Optional[StrPath],
env_nested_delimiter: Optional[str] = None,
):
self.ssm_prefix: Optional[StrPath] = ssm_prefix
self.env_nested_delimiter: Optional[str] = env_nested_delimiter

@property
def client(self) -> "SSMClient":
Expand All @@ -31,38 +43,124 @@ def client_config(self) -> Config:
timeout = float(os.environ.get("SSM_TIMEOUT", 0.5))
return Config(connect_timeout=timeout, read_timeout=timeout)

def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
"""
Returns lazy SSM values for all settings.
"""
secrets: Dict[str, Optional[Any]] = {}

if self.ssm_prefix is None:
return secrets

secrets_path = Path(self.ssm_prefix)
def load_from_ssm(self, secrets_path: Path, case_sensitive: bool):

if not secrets_path.is_absolute():
raise ValueError("SSM prefix must be absolute path")

logger.debug(f"Building SSM settings with prefix of {secrets_path=}")

output = {}
try:
paginator = self.client.get_paginator("get_parameters_by_path")
response_iterator = paginator.paginate(
Path=str(secrets_path), WithDecryption=True
)

output = {}
for page in response_iterator:
for parameter in page["Parameters"]:
key = Path(parameter["Name"]).relative_to(secrets_path).as_posix()
output[key] = parameter["Value"]
return output
output[key if case_sensitive else key.lower()] = parameter["Value"]

except ClientError:
logger.exception("Failed to get parameters from %s", secrets_path)
return {}

return output

def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
"""
Returns SSM values for all settings.
"""
d: Dict[str, Optional[Any]] = {}

if self.ssm_prefix is None:
return d

ssm_values = self.load_from_ssm(
secrets_path=Path(self.ssm_prefix),
case_sensitive=settings.__config__.case_sensitive,
)

# The following was lifted from https://github.com/samuelcolvin/pydantic/blob/a21f0763ee877f0c86f254a5d60f70b1002faa68/pydantic/env_settings.py#L165-L237 # noqa
for field in settings.__fields__.values():
env_val: Optional[str] = None
for env_name in field.field_info.extra["env_names"]:
env_val = ssm_values.get(env_name)
if env_val is not None:
break

is_complex, allow_json_failure = self._field_is_complex(field)
if is_complex:
if env_val is None:
# field is complex but no value found so far, try explode_env_vars
env_val_built = self._explode_ssm_values(field, ssm_values)
if env_val_built:
d[field.alias] = env_val_built
else:
# field is complex and there's a value, decode that as JSON, then
# add explode_env_vars
try:
env_val = settings.__config__.json_loads(env_val)
except ValueError as e:
if not allow_json_failure:
raise SettingsError(
f'error parsing JSON for "{env_name}"'
) from e

if isinstance(env_val, dict):
d[field.alias] = deep_update(
env_val, self._explode_ssm_values(field, ssm_values)
)
else:
d[field.alias] = env_val
elif env_val is not None:
# simplest case, field is not complex, we only need to add the
# value if it was found
d[field.alias] = env_val

return d

def _field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
"""
Find out if a field is complex, and if so whether JSON errors should be ignored
"""
if field.is_complex():
allow_json_failure = False
elif (
is_union(get_origin(field.type_))
and field.sub_fields
and any(f.is_complex() for f in field.sub_fields)
):
allow_json_failure = True
else:
return False, False

return True, allow_json_failure

def _explode_ssm_values(
self, field: ModelField, env_vars: Mapping[str, Optional[str]]
) -> Dict[str, Any]:
"""
Process env_vars and extract the values of keys containing
env_nested_delimiter into nested dictionaries.
This is applied to a single field, hence filtering by env_var prefix.
"""
prefixes = [
f"{env_name}{self.env_nested_delimiter}"
for env_name in field.field_info.extra["env_names"]
]
result: Dict[str, Any] = {}
for env_name, env_val in env_vars.items():
if not any(env_name.startswith(prefix) for prefix in prefixes):
continue
_, *keys, last_key = env_name.split(self.env_nested_delimiter)
env_var = result
for key in keys:
env_var = env_var.setdefault(key, {})
env_var[last_key] = env_val

return result

def __repr__(self) -> str:
return f"AwsSsmSettingsSource(ssm_prefix={self.ssm_prefix!r})"

0 comments on commit 4b436de

Please sign in to comment.