Skip to content

Commit

Permalink
fix basePath
Browse files Browse the repository at this point in the history
  • Loading branch information
bentsku committed Jun 4, 2023
1 parent bb80d52 commit 061ed81
Show file tree
Hide file tree
Showing 10 changed files with 4,651 additions and 1,554 deletions.
79 changes: 64 additions & 15 deletions localstack/services/apigateway/helpers.py
Expand Up @@ -552,13 +552,34 @@ def import_api_from_openapi_spec(
store = get_apigateway_store(account_id=account_id, region=region)
rest_api_container = store.rest_apis[rest_api.id]

def is_api_key_required(path_payload: dict) -> bool:
# TODO: consolidate and refactor with `create_authorizer`, duplicate logic for now
if "security" not in path_payload:
return False

security_schemes = path_payload.get("security")
for security_scheme in security_schemes:
for security_scheme_name, _ in security_scheme.items():
# $.securityDefinitions is Swagger 2.0
# $.components.SecuritySchemes is OpenAPI 3.0
security_definitions = body.get("securityDefinitions") or body.get(
"components", {}
).get("securitySchemes", {})
if security_scheme_name in security_definitions:
security_config = security_definitions.get(security_scheme_name)
if security_config.get("type") == "apiKey":
return True
return False

def create_authorizer(path_payload: dict) -> Optional[Authorizer]:
if "security" not in path_payload:
return None

security_schemes = path_payload.get("security")
for security_scheme in security_schemes:
for security_scheme_name, _ in security_scheme.items():
# $.securityDefinitions is Swagger 2.0
# $.components.SecuritySchemes is OpenAPI 3.0
security_definitions = body.get("securityDefinitions") or body.get(
"components", {}
).get("securitySchemes", {})
Expand All @@ -570,8 +591,8 @@ def create_authorizer(path_payload: dict) -> Optional[Authorizer]:
if not aws_apigateway_authorizer:
continue

if authorizers.get(security_scheme_name):
return authorizers.get(security_scheme_name)
if authorizer := authorizers.get(security_scheme_name):
return authorizer

authorizer_type = aws_apigateway_authorizer.get("type", "").upper()
# TODO: do we need validation of resources here?
Expand Down Expand Up @@ -664,8 +685,8 @@ def add_path_methods(rel_path: str, parts: List[str], parent_id=""):

# Get the `Method` requestParameters and requestModels
request_parameters_schema = field_schema.get("parameters", [])
request_parameters = {}
if request_parameters_schema:
request_parameters = {}
request_models = {}
for req_param_data in request_parameters_schema:
# TODO: does `required` attribute maps to a RequestValidator? check with AWS
Expand Down Expand Up @@ -693,10 +714,16 @@ def add_path_methods(rel_path: str, parts: List[str], parent_id=""):
param_location,
)
continue

method_resource.request_parameters = request_parameters or None
method_resource.request_models = request_models or None

# we check if there's a path parameter, AWS adds the requestParameter automatically
resource_name = parts[-1].strip("/")
if resource_name.startswith("{") and not resource_name.endswith("+}"):
path_parameter = resource_name[1:-1] # remove the curly braces
request_parameters[f"method.request.path.{path_parameter}"] = True

method_resource.request_parameters = request_parameters or None

# Create the `MethodResponse` for the previously created `Method`
method_responses = field_schema.get("responses", {})
for method_status_code, method_response in method_responses.items():
Expand Down Expand Up @@ -790,21 +817,23 @@ def add_path_methods(rel_path: str, parts: List[str], parent_id=""):
return resource

def create_method_resource(child, method, method_schema):
kwargs = {
"api_key_required": None,
"authorization_type": "NONE",
}
authorization_type = "NONE"
api_key_required = is_api_key_required(method_schema)
kwargs = {}

if authorizer := create_authorizer(method_schema):
# override the authorizer_type if it's a TOKEN or REQUEST to CUSTOM
if (authorizer_type := authorizer["type"]) in ("TOKEN", "REQUEST"):
authorizer_type = "CUSTOM"
authorization_type = "CUSTOM"
else:
authorization_type = authorizer_type

kwargs["authorization_type"] = authorizer_type
kwargs["authorizer_id"] = authorizer["id"]

return child.add_method(
method,
api_key_required=api_key_required,
authorization_type=authorization_type,
operation_name=method_schema.get("operationId"),
**kwargs,
)
Expand All @@ -828,12 +857,32 @@ def create_method_resource(child, method, method_schema):
store.rest_apis[rest_api.id].models[name] = model

# determine base path
basepath_mode = query_params.get("basepath") or "prepend"
# default basepath mode is "ignore"
# see https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-import-api-basePath.html
basepath_mode = query_params.get("basepath") or "ignore"
base_path = ""
if basepath_mode == "prepend":
base_path = resolved_schema.get("basePath") or ""

if basepath_mode != "ignore":
# in Swagger 2.0, the basePath is a top-level property
if "basePath" in resolved_schema:
base_path = resolved_schema["basePath"]

# in OpenAPI 3.0, the basePath is contained in the server object
elif "servers" in resolved_schema:
servers_property = resolved_schema.get("servers", [])
for server in servers_property:
# first, we check if there are a basePath variable (1st choice)
if "basePath" in server.get("variables", {}):
base_path = server["variables"]["basePath"].get("default", "")
break
# TODO: this allows both absolute and relative part, but AWS might not manage relative
url_path = urlparse.urlparse(server.get("url", "")).path
if url_path:
base_path = url_path if url_path != "/" else ""
break

if basepath_mode == "split":
base_path = (resolved_schema.get("basePath") or "").strip("/").split("/")[0]
base_path = base_path.strip("/").partition("/")[-1]
base_path = f"/{base_path}" if base_path else ""

for path in resolved_schema.get("paths", {}):
Expand Down
9 changes: 7 additions & 2 deletions localstack/services/apigateway/provider.py
Expand Up @@ -256,7 +256,9 @@ def put_rest_api(self, context: RequestContext, request: PutRestApiRequest) -> R
store = get_apigateway_store(account_id=context.account_id, region=context.region)
store.rest_apis[request["restApiId"]].rest_api = response
# TODO: verify this
return to_rest_api_response_json(response)
response = to_rest_api_response_json(response)
response.setdefault("tags", {})
return response

def delete_rest_api(self, context: RequestContext, rest_api_id: String) -> None:
try:
Expand Down Expand Up @@ -1313,7 +1315,10 @@ def import_rest_api(
"PutRestApi",
put_api_request,
)
return self.put_rest_api(put_api_context, put_api_request)
put_api_response = self.put_rest_api(put_api_context, put_api_request)
if not put_api_response.get("tags"):
put_api_response.pop("tags", None)
return put_api_response

# integrations

Expand Down

0 comments on commit 061ed81

Please sign in to comment.