Skip to content

Commit

Permalink
Merge pull request #498 from noqdev/fix-detect-interpreter
Browse files Browse the repository at this point in the history
Fix detect interpreter
  • Loading branch information
smoy committed Jul 18, 2023
2 parents 9d94091 + 78e1ec3 commit eae94c6
Show file tree
Hide file tree
Showing 18 changed files with 419 additions and 174 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@

# Change Log

## 0.10.11 (July 17, 2023)

BUG FIXES:
* Fixed race condition on iambic detect not using templated resource id grouping resources.
* Fixed issue where a resource could show as excluded on a resource it was never evaluated on.

ENHANCEMENTS:
* Improved ordering of template attributes.
* `base_group_dict_attribute` is now more deterministic in its grouping.
* `iambic detect` performance optimizations.
* Now only evaluates on the account a resource id change is detected on as opposed to all accounts.
* Example if `engineering` is on all accounts and detect is ran for account a, only `engineering` on account a is evaluated.
* Removed remaining AWS provider references from core.

## 0.10.1 (July 3, 2023)

DOCS:
Expand Down
46 changes: 46 additions & 0 deletions iambic/core/detect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from collections import defaultdict
from typing import Optional, Type

from iambic.core.models import BaseTemplate, ProviderChild
from iambic.core.utils import evaluate_on_provider


def group_detect_messages(group_by: str, messages: list) -> dict:
"""Group messages by a key in the message dict.
Args:
group_by (str): The key to group by.
messages (list): The messages to group.
Returns:
dict: The grouped messages.
"""
grouped_messages = defaultdict(list)
for message in messages:
grouped_messages[getattr(message, group_by)].append(message)

return grouped_messages


def generate_template_output(
excluded_provider_ids: list[str],
provider_child_map: dict[str, ProviderChild],
template: Optional[Type[BaseTemplate]],
) -> dict[str, dict]:
provider_children_value_map = dict()
if not template:
return provider_children_value_map
for provider_child_id, provider_child in provider_child_map.items():
if provider_child_id in excluded_provider_ids:
continue
elif not evaluate_on_provider(
template, provider_child, exclude_import_only=False
):
continue

if provider_child_value := template.apply_resource_dict(provider_child):
provider_children_value_map[provider_child_id] = provider_child_value

return provider_children_value_map
37 changes: 12 additions & 25 deletions iambic/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import dateparser
from deepdiff.model import PrettyOrderedSet
from git import Repo
from jinja2 import BaseLoader, Environment
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Extra, Field, root_validator, schema, validate_model, validator
from pydantic.fields import ModelField
Expand All @@ -40,8 +39,8 @@
LiteralScalarString,
apply_to_provider,
create_commented_map,
get_rendered_template_str_value,
get_writable_directory,
sanitize_string,
simplify_dt,
snake_to_camelcap,
sort_dict,
Expand All @@ -51,7 +50,6 @@

if TYPE_CHECKING:
from iambic.config.dynamic_config import Config
from iambic.plugins.v0_1_0.aws.models import AWSAccount

MappingIntStrAny = typing.Mapping[int | str, Any]
AbstractSetIntStr = typing.AbstractSet[int | str]
Expand Down Expand Up @@ -159,14 +157,14 @@ def get_field_type(field: Any) -> Any:

def get_attribute_val_for_account(
self,
aws_account: AWSAccount,
provider_child: Type[ProviderChild],
attr: str,
as_boto_dict: bool = True,
):
"""
Retrieve the value of an attribute for a specific AWS account.
:param aws_account: The AWSAccount object for which the attribute value should be retrieved.
:param provider_child: The ProviderChild object for which the attribute value should be retrieved.
:param attr: The attribute name (supports nested attributes via dot notation, e.g., properties.tags).
:param as_boto_dict: If True, the value will be transformed to a boto dictionary if applicable.
:return: The attribute value for the specified AWS account.
Expand All @@ -177,12 +175,12 @@ def get_attribute_val_for_account(
attr_val = getattr(attr_val, attr_key)

if as_boto_dict and hasattr(attr_val, "_apply_resource_dict"):
return attr_val._apply_resource_dict(aws_account)
return attr_val._apply_resource_dict(provider_child)
elif not isinstance(attr_val, list):
return attr_val

matching_definitions = [
val for val in attr_val if apply_to_provider(val, aws_account)
val for val in attr_val if apply_to_provider(val, provider_child)
]
if len(matching_definitions) == 0:
# Fallback to the default definition
Expand All @@ -194,15 +192,15 @@ def get_attribute_val_for_account(
return field.__fields__[split_key[-1]].default
elif as_boto_dict:
return [
match._apply_resource_dict(aws_account)
match._apply_resource_dict(provider_child)
if hasattr(match, "_apply_resource_dict")
else match
for match in matching_definitions
]
else:
return matching_definitions

def _apply_resource_dict(self, aws_account: AWSAccount = None) -> dict:
def _apply_resource_dict(self, provider_child: Type[ProviderChild] = None) -> dict:
exclude_keys = {
"deleted",
"expires_at",
Expand All @@ -220,10 +218,10 @@ def _apply_resource_dict(self, aws_account: AWSAccount = None) -> dict:
exclude_keys.update(self.exclude_keys)
has_properties = hasattr(self, "properties")
properties = getattr(self, "properties", self)
if aws_account:
if provider_child:
resource_dict = {
k: self.get_attribute_val_for_account(
aws_account,
provider_child,
f"properties.{k}" if has_properties else k,
)
for k in properties.__dict__.keys()
Expand All @@ -239,20 +237,9 @@ def _apply_resource_dict(self, aws_account: AWSAccount = None) -> dict:

return {self.case_convention(k): v for k, v in resource_dict.items()}

def apply_resource_dict(self, aws_account: AWSAccount) -> dict:
response = self._apply_resource_dict(aws_account)
variables = {var.key: var.value for var in aws_account.variables}
variables["account_id"] = aws_account.account_id
variables["account_name"] = aws_account.account_name
if hasattr(self, "owner") and (owner := getattr(self, "owner", None)):
variables["owner"] = owner

rtemplate = Environment(loader=BaseLoader()).from_string(json.dumps(response))
valid_characters_re = r"[\w_+=,.@-]"
variables = {
k: sanitize_string(v, valid_characters_re) for k, v in variables.items()
}
data = rtemplate.render(var=variables)
def apply_resource_dict(self, provider_child: Type[ProviderChild]) -> dict:
response = self._apply_resource_dict(provider_child)
data = get_rendered_template_str_value(json.dumps(response), provider_child)
return json.loads(data)

async def remove_expired_resources(self):
Expand Down
20 changes: 18 additions & 2 deletions iambic/core/template_generation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections import defaultdict
from collections import OrderedDict, defaultdict
from typing import Type, Union

import xxhash
Expand All @@ -20,6 +20,22 @@
)


def deep_sort(obj: Union[dict, list]) -> Union[dict, list]:
"""Recursively sorts a dict or list"""
if isinstance(obj, dict):
obj = OrderedDict(sorted(obj.items()))
for k, v in obj.items():
if isinstance(v, dict) or isinstance(v, list):
obj[k] = deep_sort(v)
elif isinstance(obj, list):
for i, v in enumerate(obj):
if isinstance(v, dict) or isinstance(v, list):
obj[i] = deep_sort(v)
obj = sorted(obj, key=lambda x: json.dumps(x))

return obj


async def get_existing_template_map(
repo_dir: str,
template_type: str,
Expand Down Expand Up @@ -275,7 +291,7 @@ async def base_group_dict_attribute(
] = provider_child_resource[provider_child_key_id]
# Set raw dict
resource_hash = xxhash.xxh32(
json.dumps(resource["resource_val"])
json.dumps(deep_sort(resource["resource_val"]))
).hexdigest()
hash_map[resource_hash] = resource["resource_val"]
# Set dict with attempted interpolation
Expand Down
28 changes: 27 additions & 1 deletion iambic/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import aiofiles
import jwt
from asgiref.sync import sync_to_async
from jinja2 import BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
from ruamel.yaml import YAML, scalarstring

from iambic.core import noq_json as json
Expand All @@ -26,7 +28,7 @@
from iambic.core.logger import log

if TYPE_CHECKING:
from iambic.core.models import ProposedChange
from iambic.core.models import ProposedChange, ProviderChild


NOQ_TEMPLATE_REGEX = r".*template_type:\n?.*NOQ::"
Expand Down Expand Up @@ -892,3 +894,27 @@ def decode_with_reference_time(encoded_jwt, public_key, algorithms, reference_ti
)

return payload


def get_rendered_template_str_value(
template_value: str, provider_child: typing.Type[ProviderChild]
) -> str:
"""
Render a template string with the variables from the provider child.
"""
valid_characters_re = r"[\w_+=,.@-]"
variables = {var.key: var.value for var in getattr(provider_child, "variables", [])}
for extra_attr in {"account_id", "account_name", "owner"}:
if attr_val := getattr(provider_child, extra_attr, None):
variables[extra_attr] = attr_val

if not variables:
return template_value

variables = {
k: sanitize_string(v, valid_characters_re) for k, v in variables.items()
}
rtemplate = ImmutableSandboxedEnvironment(loader=BaseLoader()).from_string(
template_value
)
return rtemplate.render(var=variables)
36 changes: 28 additions & 8 deletions iambic/plugins/v0_1_0/aws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
Variable,
)
from iambic.core.parser import load_templates
from iambic.core.template_generation import get_existing_template_map
from iambic.core.template_generation import (
get_existing_template_map,
templatize_resource,
)
from iambic.core.utils import async_batch_processor, gather_templates, yaml
from iambic.plugins.v0_1_0.aws.event_bridge.models import (
GroupMessageDetails,
Expand Down Expand Up @@ -69,6 +72,7 @@
from iambic.plugins.v0_1_0.aws.organizations.scp.utils import (
service_control_policy_is_enabled,
)
from iambic.plugins.v0_1_0.aws.utils import get_aws_account_map

if TYPE_CHECKING:
from iambic.plugins.v0_1_0.aws.iambic_plugin import AWSConfig
Expand Down Expand Up @@ -514,6 +518,7 @@ async def detect_changes( # noqa: C901
log.debug("No cloudtrail changes queue arn found. Returning")
return

aws_account_map = await get_aws_account_map(config)
role_messages = []
user_messages = []
group_messages = []
Expand Down Expand Up @@ -584,6 +589,7 @@ async def detect_changes( # noqa: C901
)
if actor != identity_arn:
account_id = decoded_message.get("recipientAccountId")
aws_account = aws_account_map[account_id]
request_params = decoded_message["requestParameters"]
response_elements = decoded_message["responseElements"]
event = decoded_message["eventName"]
Expand All @@ -595,7 +601,9 @@ async def detect_changes( # noqa: C901
role_messages.append(
RoleMessageDetails(
account_id=account_id,
role_name=role_name,
role_name=templatize_resource(
aws_account, role_name
),
delete=bool(event == "DeleteRole"),
)
)
Expand All @@ -605,7 +613,9 @@ async def detect_changes( # noqa: C901
user_messages.append(
UserMessageDetails(
account_id=account_id,
role_name=user_name,
user_name=templatize_resource(
aws_account, user_name
),
delete=bool(event == "DeleteUser"),
)
)
Expand All @@ -615,7 +625,9 @@ async def detect_changes( # noqa: C901
group_messages.append(
GroupMessageDetails(
account_id=account_id,
group_name=group_name,
group_name=templatize_resource(
aws_account, group_name
),
delete=bool(event == "DeleteGroup"),
)
)
Expand All @@ -632,8 +644,12 @@ async def detect_changes( # noqa: C901
managed_policy_messages.append(
ManagedPolicyMessageDetails(
account_id=account_id,
policy_name=policy_name,
policy_path=policy_path,
policy_name=templatize_resource(
aws_account, policy_name
),
policy_path=templatize_resource(
aws_account, policy_path
),
delete=bool(
decoded_message["eventName"] == "DeletePolicy"
),
Expand All @@ -647,8 +663,12 @@ async def detect_changes( # noqa: C901
permission_set_messages.append(
PermissionSetMessageDetails(
account_id=account_id,
instance_arn=request_params.get("instanceArn"),
permission_set_arn=permission_set_arn,
instance_arn=templatize_resource(
aws_account, request_params.get("instanceArn")
),
permission_set_arn=templatize_resource(
aws_account, permission_set_arn
),
)
)
elif scp_policy_id := SCPMessageDetails.get_policy_id(
Expand Down
Loading

0 comments on commit eae94c6

Please sign in to comment.