diff --git a/src/firecracker/swagger/firecracker.yaml b/src/firecracker/swagger/firecracker.yaml index 5bf55108b09..b107b49fe04 100644 --- a/src/firecracker/swagger/firecracker.yaml +++ b/src/firecracker/swagger/firecracker.yaml @@ -482,6 +482,7 @@ paths: description: The MMDS data store JSON. schema: type: object + additionalProperties: true 404: description: The MMDS data store content can not be found. schema: @@ -898,21 +899,119 @@ definitions: The CPU configuration template defines a set of bit maps as modifiers of flags accessed by register to be disabled/enabled for the microvm. properties: + kvm_capabilities: + type: array + description: A collection of KVM capabilities to be added or removed (both x86_64 and aarch64) + items: + type: string + description: KVM capability as a numeric string. Prefix with '!' to remove capability. Example "121" (add) or "!121" (remove) cpuid_modifiers: - type: object - description: A collection of CPUIDs to be modified. (x86_64) + type: array + description: A collection of CPUID leaf modifiers (x86_64 only) + items: + $ref: "#/definitions/CpuidLeafModifier" msr_modifiers: - type: object - description: A collection of model specific registers to be modified. (x86_64) + type: array + description: A collection of model specific register modifiers (x86_64 only) + items: + $ref: "#/definitions/MsrModifier" reg_modifiers: - type: object - description: A collection of registers to be modified. (aarch64) + type: array + description: A collection of register modifiers (aarch64 only) + items: + $ref: "#/definitions/ArmRegisterModifier" vcpu_features: - type: object - description: A collection of vcpu features to be modified. (aarch64) - kvm_capabilities: - type: object - description: A collection of kvm capabilities to be modified. (aarch64) + type: array + description: A collection of vCPU features to be modified (aarch64 only) + items: + $ref: "#/definitions/VcpuFeatures" + + CpuidLeafModifier: + type: object + description: Modifier for a CPUID leaf and subleaf (x86_64) + required: + - leaf + - subleaf + - flags + - modifiers + properties: + leaf: + type: string + description: CPUID leaf index as hex, binary, or decimal string (e.g., "0x0", "0b0", "0")) + subleaf: + type: string + description: CPUID subleaf index as hex, binary, or decimal string (e.g., "0x0", "0b0", "0") + flags: + type: integer + format: int32 + description: KVM feature flags for this leaf-subleaf + modifiers: + type: array + description: Register modifiers for this CPUID leaf + items: + $ref: "#/definitions/CpuidRegisterModifier" + + CpuidRegisterModifier: + type: object + description: Modifier for a specific CPUID register within a leaf (x86_64) + required: + - register + - bitmap + properties: + register: + type: string + description: Target CPUID register name + enum: + - eax + - ebx + - ecx + - edx + bitmap: + type: string + description: 32-bit bitmap string defining which bits to modify. Format is "0b" followed by 32 characters where '0' = clear bit, '1' = set bit, 'x' = don't modify. Example "0b00000000000000000000000000000001" or "0bxxxxxxxxxxxxxxxxxxxxxxxxxxxx0001" + + MsrModifier: + type: object + description: Modifier for a model specific register (x86_64) + required: + - addr + - bitmap + properties: + addr: + type: string + description: 32-bit MSR address as hex, binary, or decimal string (e.g., "0x10a", "0b100001010", "266") + bitmap: + type: string + description: 64-bit bitmap string defining which bits to modify. Format is "0b" followed by 64 characters where '0' = clear bit, '1' = set bit, 'x' = don't modify. Underscores can be used for readability. Example "0b0000000000000000000000000000000000000000000000000000000000000001" + + ArmRegisterModifier: + type: object + description: Modifier for an ARM register (aarch64) + required: + - addr + - bitmap + properties: + addr: + type: string + description: 64-bit register address as hex, binary, or decimal string (e.g., "0x0", "0b0", "0") + bitmap: + type: string + description: 128-bit bitmap string defining which bits to modify. Format is "0b" followed by up to 128 characters where '0' = clear bit, '1' = set bit, 'x' = don't modify. Underscores can be used for readability. Example "0b0000000000000000000000000000000000000000000000000000000000000001" + + VcpuFeatures: + type: object + description: vCPU feature modifier (aarch64) + required: + - index + - bitmap + properties: + index: + type: integer + format: int32 + description: Index in the kvm_vcpu_init.features array + bitmap: + type: string + description: 32-bit bitmap string defining which bits to modify. Format is "0b" followed by 32 characters where '0' = clear bit, '1' = set bit, 'x' = don't modify. Example "0b00000000000000000000000001100000" Drive: type: object @@ -1025,6 +1124,11 @@ definitions: description: Configurations for all net devices. items: $ref: "#/definitions/NetworkInterface" + pmem: + type: array + description: Configurations for all pmem devices. + items: + $ref: "#/definitions/Pmem" vsock: $ref: "#/definitions/Vsock" entropy: @@ -1213,6 +1317,7 @@ definitions: type: object description: Describes the contents of MMDS in JSON format. + additionalProperties: true NetworkInterface: type: object @@ -1416,7 +1521,7 @@ definitions: description: The configuration of the serial device properties: - output_path: + serial_out_path: type: string description: Path to a file or named pipe on the host to which serial output should be written. diff --git a/tests/framework/http_api.py b/tests/framework/http_api.py index 0ae2e279571..e5887aed0b0 100644 --- a/tests/framework/http_api.py +++ b/tests/framework/http_api.py @@ -9,6 +9,8 @@ import requests from requests_unixsocket import DEFAULT_SCHEME, UnixAdapter +from framework.swagger_validator import SwaggerValidator, ValidationError + class Session(requests.Session): """An HTTP over UNIX sockets Session @@ -65,6 +67,21 @@ def get(self): self._api.error_callback("GET", self.resource, str(e)) raise assert res.status_code == HTTPStatus.OK, res.json() + + # Validate response against Swagger specification + # only validate successful requests + if self._api.validator and res.status_code == HTTPStatus.OK: + try: + response_body = res.json() + self._api.validator.validate_response( + "GET", self.resource, 200, response_body + ) + except ValidationError as e: + # Re-raise with more context + raise ValidationError( + f"Response validation failed for GET {self.resource}: {e.message}" + ) from e + return res def request(self, method, path, **kwargs): @@ -85,6 +102,32 @@ def request(self, method, path, **kwargs): elif "error" in json: msg = json["error"] raise RuntimeError(msg, json, res) + + # Validate request against Swagger specification + # do this after the actual request as we only want to validate successful + # requests as the tests may be trying to pass bad requests and assert an + # error is raised. + if self._api.validator: + if kwargs: + try: + self._api.validator.validate_request(method, path, kwargs) + except ValidationError as e: + # Re-raise with more context + raise ValidationError( + f"Request validation failed for {method} {path}: {e.message}" + ) from e + + if res.status_code == HTTPStatus.OK: + try: + response_body = res.json() + self._api.validator.validate_response( + method, path, 200, response_body + ) + except ValidationError as e: + # Re-raise with more context + raise ValidationError( + f"Response validation failed for {method} {path}: {e.message}" + ) from e return res def put(self, **kwargs): @@ -105,13 +148,16 @@ def patch(self, **kwargs): class Api: """A simple HTTP client for the Firecracker API""" - def __init__(self, api_usocket_full_name, *, on_error=None): + def __init__(self, api_usocket_full_name, *, validate=True, on_error=None): self.error_callback = on_error self.socket = api_usocket_full_name url_encoded_path = urllib.parse.quote_plus(api_usocket_full_name) self.endpoint = DEFAULT_SCHEME + url_encoded_path self.session = Session() + # Initialize the swagger validator + self.validator = SwaggerValidator() if validate else None + self.describe = Resource(self, "/") self.vm = Resource(self, "/vm") self.vm_config = Resource(self, "/vm/config") diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py index 74ae180950c..9331ed51653 100644 --- a/tests/framework/microvm.py +++ b/tests/framework/microvm.py @@ -634,6 +634,7 @@ def spawn( log_show_origin=False, metrics_path="fc.ndjson", emit_metrics: bool = False, + validate_api: bool = True, ): """Start a microVM as a daemon or in a screen session.""" # pylint: disable=subprocess-run-check @@ -641,6 +642,7 @@ def spawn( self.jailer.setup() self.api = Api( self.jailer.api_socket_path(), + validate=validate_api, on_error=lambda verb, uri, err_msg: self._dump_debug_information( f"Error during {verb} {uri}: {err_msg}" ), diff --git a/tests/framework/swagger_validator.py b/tests/framework/swagger_validator.py new file mode 100644 index 00000000000..0ad9e310268 --- /dev/null +++ b/tests/framework/swagger_validator.py @@ -0,0 +1,186 @@ +# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""A validator for Firecracker API Swagger schema""" + +from pathlib import Path + +import yaml +from jsonschema import Draft4Validator, ValidationError + + +def _filter_none_recursive(data): + if isinstance(data, dict): + return {k: _filter_none_recursive(v) for k, v in data.items() if v is not None} + if isinstance(data, list): + return [_filter_none_recursive(item) for item in data if item is not None] + return data + + +class SwaggerValidator: + """Validator for API requests against the Swagger/OpenAPI specification""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + """Initialize the validator with the Swagger specification.""" + if self._initialized: + return + self._initialized = True + + swagger_path = ( + Path(__file__).parent.parent.parent + / "src" + / "firecracker" + / "swagger" + / "firecracker.yaml" + ) + + with open(swagger_path, "r", encoding="utf-8") as f: + self.swagger_spec = yaml.safe_load(f) + + # Cache validators for each endpoint + self._validators = {} + self._build_validators() + + def _build_validators(self): + """Build JSON schema validators for each endpoint.""" + paths = self.swagger_spec.get("paths", {}) + definitions = self.swagger_spec.get("definitions", {}) + + for path, methods in paths.items(): + for method, spec in methods.items(): + if method.upper() not in ["GET", "PUT", "PATCH", "POST", "DELETE"]: + continue + + # Build request body validators + parameters = spec.get("parameters", []) + for param in parameters: + if param.get("in") == "body" and "schema" in param: + schema = self._resolve_schema(param["schema"], definitions) + if method.upper() == "PATCH": + # do not validate required fields on PATCH requests + schema["required"] = [] + key = ("request", method.upper(), path) + self._validators[key] = Draft4Validator(schema) + + # Build response validators for 200/204 responses + responses = spec.get("responses", {}) + for status_code, response_spec in responses.items(): + if str(status_code) in ["200", "204"] and "schema" in response_spec: + schema = self._resolve_schema( + response_spec["schema"], definitions + ) + key = ("response", method.upper(), path, str(status_code)) + self._validators[key] = Draft4Validator(schema) + + def _resolve_schema(self, schema, definitions): + """Resolve $ref references in schema.""" + if "$ref" in schema: + ref_path = schema["$ref"] + if ref_path.startswith("#/definitions/"): + def_name = ref_path.split("/")[-1] + if def_name in definitions: + return self._resolve_schema(definitions[def_name], definitions) + + # Recursively resolve nested schemas + resolved = schema.copy() + if "properties" in resolved: + resolved["properties"] = { + k: self._resolve_schema(v, definitions) + for k, v in resolved["properties"].items() + } + if "items" in resolved and isinstance(resolved["items"], dict): + resolved["items"] = self._resolve_schema(resolved["items"], definitions) + + if not "additionalProperties" in resolved: + resolved["additionalProperties"] = False + + return resolved + + def validate_request(self, method, path, body): + """ + Validate a request body against the Swagger specification. + + Args: + method: HTTP method (GET, PUT, PATCH, etc.) + path: API path (e.g., "/drives/{drive_id}") + body: Request body as a dictionary + + Raises: + ValidationError: If the request body doesn't match the schema + """ + # Normalize path - replace specific IDs with parameter placeholders + normalized_path = self._normalize_path(path) + key = ("request", method.upper(), normalized_path) + + if key in self._validators: + validator = self._validators[key] + # Remove None values from body before validation + cleaned_body = _filter_none_recursive(body) + validator.validate(cleaned_body) + else: + raise ValidationError(f"{key} is not in the schema") + + def validate_response(self, method, path, status_code, body): + """ + Validate a response body against the Swagger specification. + + Args: + method: HTTP method (GET, PUT, PATCH, etc.) + path: API path (e.g., "/drives/{drive_id}") + status_code: HTTP status code (e.g., 200, 204) + body: Response body as a dictionary + + Raises: + ValidationError: If the response body doesn't match the schema + """ + # Normalize path - replace specific IDs with parameter placeholders + normalized_path = self._normalize_path(path) + key = ("response", method.upper(), normalized_path, str(status_code)) + + if key in self._validators: + validator = self._validators[key] + # Remove None values from body before validation + cleaned_body = _filter_none_recursive(body) + validator.validate(cleaned_body) + else: + raise ValidationError(f"{key} is not in the schema") + + def _normalize_path(self, path): + """ + Normalize a path by replacing specific IDs with parameter placeholders. + + E.g., "/drives/rootfs" -> "/drives/{drive_id}" + """ + # Match against known patterns in the swagger spec + paths = self.swagger_spec.get("paths", {}) + + # Direct match + if path in paths: + return path + + # Try to match parameterized paths + parts = path.split("/") + for swagger_path in paths.keys(): + swagger_parts = swagger_path.split("/") + if len(parts) == len(swagger_parts): + match = True + for _, (part, swagger_part) in enumerate(zip(parts, swagger_parts)): + # Check if it's a parameter placeholder or exact match + if swagger_part.startswith("{") and swagger_part.endswith("}"): + continue # This is a parameter, any value matches + if part != swagger_part: + match = False + break + + if match: + return swagger_path + + return path