Skip to content

Commit

Permalink
fix validator and model check during invocation (#7846)
Browse files Browse the repository at this point in the history
  • Loading branch information
bentsku committed Mar 17, 2023
1 parent 1dfc637 commit 92dca7d
Show file tree
Hide file tree
Showing 7 changed files with 455 additions and 15 deletions.
2 changes: 2 additions & 0 deletions localstack/services/apigateway/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
# special tag name to allow specifying a custom ID for new REST APIs
TAG_KEY_CUSTOM_ID = "_custom_id_"

EMPTY_MODEL = "Empty"

# TODO: make the CRUD operations in this file generic for the different model types (authorizes, validators, ...)


Expand Down
31 changes: 21 additions & 10 deletions localstack/services/apigateway/invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from localstack.services.apigateway import helpers
from localstack.services.apigateway.context import ApiInvocationContext
from localstack.services.apigateway.helpers import (
EMPTY_MODEL,
extract_path_params,
extract_query_string_params,
get_cors_response,
Expand Down Expand Up @@ -58,11 +59,15 @@ def is_request_valid(self) -> bool:
return True

# check if there is a validator for this request
validator = self.apigateway_client.get_request_validator(
restApiId=self.context.api_id, requestValidatorId=resource["requestValidatorId"]
)
if validator is None:
return True
try:
validator = self.apigateway_client.get_request_validator(
restApiId=self.context.api_id, requestValidatorId=resource["requestValidatorId"]
)
except ClientError as e:
if "NotFoundException" in e:
return True

raise

# are we validating the body?
if self.should_validate_body(validator):
Expand All @@ -78,11 +83,13 @@ def is_request_valid(self) -> bool:
return True

def validate_body(self, resource):
# we need a model to validate the body
if "requestModels" not in resource or not resource["requestModels"]:
return False
# if there's no model to validate the body, use the Empty model
# https://docs.aws.amazon.com/cdk/api/v1/docs/@aws-cdk_aws-apigateway.EmptyModel.html
if not (request_models := resource.get("requestModels")):
schema_name = EMPTY_MODEL
else:
schema_name = request_models.get(APPLICATION_JSON, EMPTY_MODEL)

schema_name = resource["requestModels"].get(APPLICATION_JSON)
try:
model = self.apigateway_client.get_model(
restApiId=self.context.api_id,
Expand All @@ -96,7 +103,11 @@ def validate_body(self, resource):
validate(instance=json.loads(self.context.data), schema=json.loads(model["schema"]))
return True
except ValidationError as e:
LOG.warning("failed to validate request body", e)
LOG.warning("failed to validate request body %s", e)
return False
except json.JSONDecodeError as e:
# TODO: for now, it could also be the loading of the schema failing but it will be validated at some point
LOG.warning("failed to validate request body, request data is not valid JSON %s", e)
return False

# TODO implement parameters and headers
Expand Down
14 changes: 10 additions & 4 deletions localstack/services/apigateway/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from localstack.aws.forwarder import NotImplementedAvoidFallbackError, create_aws_request_context
from localstack.constants import APPLICATION_JSON
from localstack.services.apigateway.helpers import (
EMPTY_MODEL,
OpenApiExporter,
apply_json_patch_safe,
get_apigateway_store,
Expand Down Expand Up @@ -137,7 +138,7 @@ def create_rest_api(self, context: RequestContext, request: CreateRestApiRequest
rest_api_container = RestApiContainer(rest_api=response)
store.rest_apis[result["id"]] = rest_api_container
# add the 2 default models
rest_api_container.models["Empty"] = DEFAULT_EMPTY_MODEL
rest_api_container.models[EMPTY_MODEL] = DEFAULT_EMPTY_MODEL
rest_api_container.models["Error"] = DEFAULT_ERROR_MODEL

return response
Expand Down Expand Up @@ -442,7 +443,7 @@ def put_method(
if request_models:
for content_type, model_name in request_models.items():
# FIXME: add Empty model to rest api at creation
if model_name == "Empty":
if model_name == EMPTY_MODEL:
continue
if model_name not in rest_api_container.models:
raise BadRequestException(f"Invalid model identifier specified: {model_name}")
Expand Down Expand Up @@ -513,10 +514,15 @@ def update_method(
continue

elif path == "/requestValidatorId" and value not in rest_api.validators:
if not value:
# you can remove a requestValidator by passing an empty string as a value
patch_op = {"op": "remove", "path": path, "value": value}
applicable_patch_operations.append(patch_op)
continue
raise BadRequestException("Invalid Request Validator identifier specified")

elif path.startswith("/requestModels/"):
if value != "Empty" and value not in rest_api.models:
if value != EMPTY_MODEL and value not in rest_api.models:
raise BadRequestException(f"Invalid model identifier specified: {value}")

applicable_patch_operations.append(patch_operation)
Expand Down Expand Up @@ -1499,7 +1505,7 @@ def to_response_json(model_type, data, api_id=None, self_link=None, id_attr=None

DEFAULT_EMPTY_MODEL = Model(
id=short_uid()[:6],
name="Empty",
name=EMPTY_MODEL,
contentType="application/json",
description="This is a default empty schema model",
schema=json.dumps(
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/apigateway/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,17 @@ def _factory(
return api_id

yield _factory


@pytest.fixture
def apigw_redeploy_api(apigateway_client):
def _factory(rest_api_id: str, stage_name: str):
deployment_id = apigateway_client.create_deployment(restApiId=rest_api_id)["id"]

apigateway_client.update_stage(
restApiId=rest_api_id,
stageName=stage_name,
patchOperations=[{"op": "replace", "path": "/deploymentId", "value": deployment_id}],
)

return _factory
234 changes: 234 additions & 0 deletions tests/integration/apigateway/test_apigateway_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
import json

import pytest
import requests

from localstack.services.awslambda.lambda_utils import LAMBDA_RUNTIME_PYTHON39
from localstack.utils.aws.arns import parse_arn
from localstack.utils.strings import short_uid
from tests.integration.apigateway.apigateway_fixtures import api_invoke_url
from tests.integration.awslambda.test_lambda import TEST_LAMBDA_AWS_PROXY


class TestApiGatewayCommon:
"""
In this class we won't test individual CRUD API calls but how those will affect the integrations and
requests/responses from the API.
"""

@pytest.mark.aws_validated
@pytest.mark.skip_snapshot_verify(
paths=[
"$.invalid-request-body.Type",
"$..methodIntegration.cacheNamespace",
"$..methodIntegration.integrationResponses",
"$..methodIntegration.passthroughBehavior",
"$..methodIntegration.requestParameters",
"$..methodIntegration.timeoutInMillis",
]
)
def test_api_gateway_request_validator(
self,
apigateway_client,
create_lambda_function,
create_rest_apigw,
apigw_redeploy_api,
lambda_client,
snapshot,
):
# TODO: create fixture which will provide basic integrations where we can test behaviour
# see once we have more cases how we can regroup functionality into one or several fixtures
# example: create a basic echo lambda + integrations + deploy stage
snapshot.add_transformers_list(
[
snapshot.transform.key_value("requestValidatorId"),
snapshot.transform.key_value("id"), # deployment id
snapshot.transform.key_value("fn_name"), # lambda name
snapshot.transform.key_value("fn_arn"), # lambda arn
]
)

fn_name = f"test-{short_uid()}"
create_lambda_function(
func_name=fn_name,
handler_file=TEST_LAMBDA_AWS_PROXY,
runtime=LAMBDA_RUNTIME_PYTHON39,
)
lambda_arn = lambda_client.get_function(FunctionName=fn_name)["Configuration"][
"FunctionArn"
]
# matching on lambda id for reference replacement in snapshots
snapshot.match("register-lambda", {"fn_name": fn_name, "fn_arn": lambda_arn})

parsed_arn = parse_arn(lambda_arn)
region = parsed_arn["region"]
account_id = parsed_arn["account"]

api_id, _, root = create_rest_apigw(name="aws lambda api")

resource_1 = apigateway_client.create_resource(
restApiId=api_id, parentId=root, pathPart="test"
)["id"]

resource_id = apigateway_client.create_resource(
restApiId=api_id, parentId=resource_1, pathPart="{test}"
)["id"]

validator_id = apigateway_client.create_request_validator(
restApiId=api_id,
name="test-validator",
validateRequestParameters=True,
validateRequestBody=True,
)["id"]

apigateway_client.put_method(
restApiId=api_id,
resourceId=resource_id,
httpMethod="POST",
authorizationType="NONE",
requestValidatorId=validator_id,
requestParameters={"method.request.path.test": True},
)

apigateway_client.put_integration(
restApiId=api_id,
resourceId=resource_id,
httpMethod="POST",
integrationHttpMethod="POST",
type="AWS_PROXY",
uri=f"arn:aws:apigateway:{region}:lambda:path//2015-03-31/functions/"
f"{lambda_arn}/invocations",
)
apigateway_client.put_method_response(
restApiId=api_id,
resourceId=resource_id,
httpMethod="POST",
statusCode="200",
)
apigateway_client.put_integration_response(
restApiId=api_id,
resourceId=resource_id,
httpMethod="POST",
statusCode="200",
)

stage_name = "local"
deploy_1 = apigateway_client.create_deployment(restApiId=api_id, stageName=stage_name)
snapshot.match("deploy-1", deploy_1)

source_arn = f"arn:aws:execute-api:{region}:{account_id}:{api_id}/*/*/test/*"

lambda_client.add_permission(
FunctionName=lambda_arn,
StatementId=str(short_uid()),
Action="lambda:InvokeFunction",
Principal="apigateway.amazonaws.com",
SourceArn=source_arn,
)

url = api_invoke_url(api_id, stage=stage_name, path="/test/value")
response = requests.post(url, json={"test": "test"})
assert response.ok
assert json.loads(response.json()["body"]) == {"test": "test"}

response = apigateway_client.update_method(
restApiId=api_id,
resourceId=resource_id,
httpMethod="POST",
patchOperations=[
{
"op": "add",
"path": "/requestParameters/method.request.path.issuer",
"value": "true",
},
{
"op": "remove",
"path": "/requestParameters/method.request.path.test",
"value": "true",
},
],
)
snapshot.match("change-request-path-names", response)

apigw_redeploy_api(rest_api_id=api_id, stage_name=stage_name)

response = requests.post(url, json={"test": "test"})
# FIXME: for now, not implemented in LocalStack, we don't validate RequestParameters yet
# assert response.status_code == 400
if response.status_code == 400:
snapshot.match("missing-required-request-params", response.json())

# create Model schema to validate body
apigateway_client.create_model(
restApiId=api_id,
name="testSchema",
contentType="application/json",
schema=json.dumps(
{
"title": "testSchema",
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
"required": ["a", "b"],
}
),
)
# then attach the schema to the method
response = apigateway_client.update_method(
restApiId=api_id,
resourceId=resource_id,
httpMethod="POST",
patchOperations=[
{"op": "add", "path": "/requestModels/application~1json", "value": "testSchema"},
],
)
snapshot.match("add-schema", response)

response = apigateway_client.update_method(
restApiId=api_id,
resourceId=resource_id,
httpMethod="POST",
patchOperations=[
{
"op": "add",
"path": "/requestParameters/method.request.path.test",
"value": "true",
},
{
"op": "remove",
"path": "/requestParameters/method.request.path.issuer",
"value": "true",
},
],
)
snapshot.match("revert-request-path-names", response)

apigw_redeploy_api(rest_api_id=api_id, stage_name=stage_name)

# the validator should then check against this schema and fail
response = requests.post(url, json={"test": "test"})
assert response.status_code == 400
snapshot.match("invalid-request-body", response.json())

# remove the validator from the method
response = apigateway_client.update_method(
restApiId=api_id,
resourceId=resource_id,
httpMethod="POST",
patchOperations=[
{
"op": "replace",
"path": "/requestValidatorId",
"value": "",
},
],
)
snapshot.match("remove-validator", response)

apigw_redeploy_api(rest_api_id=api_id, stage_name=stage_name)

response = requests.post(url, json={"test": "test"})
assert response.ok
assert json.loads(response.json()["body"]) == {"test": "test"}

0 comments on commit 92dca7d

Please sign in to comment.