Skip to content

Commit

Permalink
implement SNS Filter/operators $or, suffix, equals-ignore-case, anyth…
Browse files Browse the repository at this point in the history
…ing-but (#10691)

Co-authored-by: Mathieu Cloutier <cloutier.mat0@gmail.com>
  • Loading branch information
bentsku and cloutierMat committed May 16, 2024
1 parent 98dbcbc commit 21bdd43
Show file tree
Hide file tree
Showing 5 changed files with 467 additions and 72 deletions.
176 changes: 113 additions & 63 deletions localstack/services/sns/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@ class SubscriptionFilter:
def check_filter_policy_on_message_attributes(
self, filter_policy: dict, message_attributes: dict
):
for criteria, conditions in filter_policy.items():
if not self._evaluate_filter_policy_conditions_on_attribute(
conditions,
message_attributes.get(criteria),
field_exists=criteria in message_attributes,
):
return False
if not filter_policy:
return True

return True
flat_policy_conditions = self.flatten_policy(filter_policy)

return any(
all(
self._evaluate_filter_policy_conditions_on_attribute(
conditions,
message_attributes.get(criteria),
field_exists=criteria in message_attributes,
)
for criteria, conditions in flat_policy.items()
)
for flat_policy in flat_policy_conditions
)

def check_filter_policy_on_message_body(self, filter_policy: dict, message_body: str):
try:
Expand Down Expand Up @@ -45,18 +52,26 @@ def _evaluate_nested_filter_policy_on_dict(self, filter_policy, payload: dict) -
:param payload: a dict, starting at the MessageBody
:return: True if the payload respect the filter policy, otherwise False
"""
flat_policy = self._flatten_dict(filter_policy)
flat_payloads = self._flatten_dict_with_list(payload)
for key, values in flat_policy.items():
if not any(
self._evaluate_condition(
flat_payload.get(key), condition, field_exists=key in flat_payload
if not filter_policy:
return True

# TODO: maybe save/cache the flattened/expanded policy?
flat_policy_conditions = self.flatten_policy(filter_policy)
flat_payloads = self.flatten_payload(payload)

return any(
all(
any(
self._evaluate_condition(
flat_payload.get(key), condition, field_exists=key in flat_payload
)
for condition in values
for flat_payload in flat_payloads
)
for condition in values
for flat_payload in flat_payloads
):
return False
return True
for key, values in flat_policy.items()
)
for flat_policy in flat_policy_conditions
)

def _evaluate_filter_policy_conditions_on_attribute(
self, conditions, attribute, field_exists: bool
Expand Down Expand Up @@ -94,11 +109,19 @@ def _evaluate_condition(self, value, condition, field_exists: bool):
# the remaining conditions require the value to not be None
return False
elif anything_but := condition.get("anything-but"):
# TODO: support with `prefix`
# https://docs.aws.amazon.com/sns/latest/dg/string-value-matching.html#string-anything-but-matching-prefix
return value not in anything_but
elif prefix := (condition.get("prefix")):
if isinstance(anything_but, dict):
not_prefix = anything_but.get("prefix")
return not value.startswith(not_prefix)
elif isinstance(anything_but, list):
return value not in anything_but
else:
return value != anything_but
elif prefix := condition.get("prefix"):
return value.startswith(prefix)
elif suffix := condition.get("suffix"):
return value.endswith(suffix)
elif equal_ignore_case := condition.get("equals-ignore-case"):
return equal_ignore_case.lower() == value.lower()
elif numeric_condition := condition.get("numeric"):
return self._evaluate_numeric_condition(numeric_condition, value)
return False
Expand Down Expand Up @@ -135,35 +158,59 @@ def _evaluate_numeric_condition(conditions, value):
return True

@staticmethod
def _flatten_dict(nested_dict: dict):
def flatten_policy(nested_dict: dict) -> list[dict]:
"""
Takes a dictionary as input and will output the dictionary on a single level.
Input:
`{"field1": {"field2: {"field3: "val1", "field4": "val2"}}}`
`{"field1": {"field2": {"field3": "val1", "field4": "val2"}}}`
Output:
`{
"field1.field2.field3": "val1",
"field1.field2.field4": "val1"
}`
`[
{
"field1.field2.field3": "val1",
"field1.field2.field4": "val2"
}
]`
Input with $or will create multiple outputs:
`{"$or": [{"field1": "val1"}, {"field2": "val2"}], "field3": "val3"}`
Output:
`[
{"field1": "val1", "field3": "val3"},
{"field2": "val2", "field3": "val3"}
]`
:param nested_dict: a (nested) dictionary
:return: a list of flattened dictionaries with no nested dict or list inside, flattened to a
single level, one list item for every list item encountered
"""
flatten = {}

def _traverse(_policy: dict, parent_key=None):
for key, values in _policy.items():
flattened_parent_key = key if not parent_key else f"{parent_key}.{key}"
if not isinstance(values, dict):
flatten[flattened_parent_key] = values
def _traverse_policy(obj, array=None, parent_key=None) -> list:
if array is None:
array = [{}]

for key, values in obj.items():
if key == "$or" and isinstance(values, list) and len(values) > 1:
# $or will create multiple new branches in the array.
# Each current branch will traverse with each choice in $or
array = [
i for value in values for i in _traverse_policy(value, array, parent_key)
]
else:
_traverse(values, parent_key=flattened_parent_key)
# We update the parent key do that {"key1": {"key2": ""}} becomes "key1.key2"
_parent_key = f"{parent_key}.{key}" if parent_key else key
if isinstance(values, dict):
# If the current key has child dict -- key: "key1", child: {"key2": ["val1", val2"]}
# We only update the parent_key and traverse its children with the current branches
array = _traverse_policy(values, array, _parent_key)
else:
# If the current key has no child, this means we found the values to match -- child: ["val1", val2"]
# we update the branches with the parent chain and the values -- {"key1.key2": ["val1, val2"]}
array = [{**item, _parent_key: values} for item in array]

_traverse(nested_dict)
return flatten
return array

return _traverse_policy(nested_dict)

@staticmethod
def _flatten_dict_with_list(nested_dict: dict) -> list[dict]:
def flatten_payload(nested_dict: dict) -> list[dict]:
"""
Takes a dictionary as input and will output the dictionary on a single level.
The dictionary can have lists containing other dictionaries, and one root level entry will be created for every
Expand All @@ -189,37 +236,22 @@ def _flatten_dict_with_list(nested_dict: dict) -> list[dict]:
:param nested_dict: a (nested) dictionary
:return: flatten_dict: a dictionary with no nested dict inside, flattened to a single level
"""
flattened = []
current_object = {}

def _traverse(_object, parent_key=None):
def _traverse(_object: dict, array=None, parent_key=None) -> list:
if isinstance(_object, dict):
for key, values in _object.items():
flattened_parent_key = key if not parent_key else f"{parent_key}.{key}"
_traverse(values, flattened_parent_key)
# We update the parent key do that {"key1": {"key2": ""}} becomes "key1.key2"
_parent_key = f"{parent_key}.{key}" if parent_key else key
array = _traverse(values, array, _parent_key)

# we don't have to worry about `parent_key` being None for list or any other type, because we have a check
# that the first object is always a dict, thus setting a parent key on first iteration
elif isinstance(_object, list):
for value in _object:
if isinstance(value, (dict, list)):
_traverse(value, parent_key=parent_key)
else:
current_object[parent_key] = value

if current_object:
flattened.append({**current_object})
current_object.clear()
array = [i for value in _object for i in _traverse(value, array, parent_key)]
else:
current_object[parent_key] = _object

_traverse(nested_dict)
array = [{**item, parent_key: _object} for item in array]

# if the payload did not have any list, we manually append the current object
if not flattened:
flattened.append(current_object)
return array

return flattened
return _traverse(nested_dict, array=[{}], parent_key=None)


class FilterPolicyValidator:
Expand Down Expand Up @@ -340,7 +372,6 @@ def _validate_rule(self, rule: t.Any) -> None:
operator, value = k, v

if operator in (
"anything-but",
"equals-ignore-case",
"prefix",
"suffix",
Expand All @@ -351,6 +382,25 @@ def _validate_rule(self, rule: t.Any) -> None:
)
return

elif operator == "anything-but":
# anything-but can actually contain any kind of simple rule (str, number, and list)
if isinstance(value, list):
for v in value:
self._validate_rule(v)

return

# or have a nested `prefix` pattern
elif isinstance(value, dict):
for inner_operator in value.keys():
if inner_operator != "prefix":
raise InvalidParameterException(
f"{self.error_prefix}FilterPolicy: Unsupported anything-but pattern: {inner_operator}"
)

self._validate_rule(value)
return

elif operator == "exists":
if not isinstance(value, bool):
raise InvalidParameterException(
Expand Down
81 changes: 79 additions & 2 deletions tests/aws/services/sns/test_sns_filter_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,6 @@ def get_messages(_queue_url: str, _received_messages: list):
snapshot.match("messages", {"Messages": received_messages})

@markers.aws.validated
@pytest.mark.skip("Not yet supported by LocalStack")
def test_filter_policy_on_message_body_or_attribute(
self,
sqs_create_queue,
Expand Down Expand Up @@ -1237,7 +1236,7 @@ def test_validate_policy_string_operators(
topic_arn = sns_create_topic()["TopicArn"]

def _subscribe(policy: dict):
sns_subscription(
return sns_subscription(
TopicArn=topic_arn,
Protocol="sms",
Endpoint=phone_number,
Expand All @@ -1262,6 +1261,18 @@ def _subscribe(policy: dict):
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-is-not-list-and-operator", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"suffix": []}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-empty-list", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"suffix": ["test", "test2"]}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-list-wrong-type", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": {"suffix": "value", "prefix": "value"}}
_subscribe(filter_policy)
Expand Down Expand Up @@ -1413,6 +1424,72 @@ def _subscribe(policy: dict):
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-string", e.value.response)

@markers.aws.validated
@markers.snapshot.skip_snapshot_verify(paths=["$..Error.Message"])
def test_validate_policy_nested_anything_but_operator(
self,
sns_create_topic,
sns_subscription,
snapshot,
aws_client,
):
phone_number = "+123123123"
topic_arn = sns_create_topic()["TopicArn"]

def _subscribe(policy: dict):
return sns_subscription(
TopicArn=topic_arn,
Protocol="sms",
Endpoint=phone_number,
Attributes={"FilterPolicy": json.dumps(policy)},
)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"anything-but": {"wrong-operator": None}}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-wrong-operator", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"anything-but": {"suffix": "test"}}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-anything-but-suffix", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"anything-but": {"exists": False}}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-anything-but-exists", e.value.response)

with pytest.raises(ClientError) as e:
filter_policy = {"key": [{"anything-but": {"prefix": False}}]}
_subscribe(filter_policy)
self._add_normalized_field_to_snapshot(e.value.response)
snapshot.match("error-condition-anything-but-prefix-wrong-type", e.value.response)

# positive testing
filter_policy = {"key": [{"anything-but": {"prefix": "test-"}}]}
response = _subscribe(filter_policy)
assert "SubscriptionArn" in response
subscription_arn = response["SubscriptionArn"]

filter_policy = {"key": [{"anything-but": ["test", "test2"]}]}
response = aws_client.sns.set_subscription_attributes(
SubscriptionArn=subscription_arn,
AttributeName="FilterPolicy",
AttributeValue=json.dumps(filter_policy),
)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

filter_policy = {"key": [{"anything-but": "test"}]}
response = aws_client.sns.set_subscription_attributes(
SubscriptionArn=subscription_arn,
AttributeName="FilterPolicy",
AttributeValue=json.dumps(filter_policy),
)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

@markers.aws.validated
def test_policy_complexity(
self,
Expand Down
Loading

0 comments on commit 21bdd43

Please sign in to comment.