diff --git a/server.py b/server.py index a1834dc..bb572e3 100644 --- a/server.py +++ b/server.py @@ -4,4 +4,4 @@ from src.server import main if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/components/customizers.py b/src/components/customizers.py index 0d61936..dd2ccca 100644 --- a/src/components/customizers.py +++ b/src/components/customizers.py @@ -51,6 +51,18 @@ def customize_components( # Hide this fully for now. component.output_schema = None + # SCHEMA REFERENCE HANDLING: + # OpenAPI-generated MCP tool schemas contain $defs with nested $ref references + # (e.g., $ref -> $defs -> $ref chains). + # FastMCP cannot resolve these complex reference chains. + # + # The following code strips $defs from MCP tool input/output schemas generated by OpenAPI docs. + # However, stripping $defs still resulted in "PointerToNowhere" errors due to input schemas + # containing nested $ref references failing to resolve. + # + # Openapi_resolver script replaces ALL $ref with inline schema definitions, and eliminates the need for $defs entirely. + # Keep this code to prevent any remaining $defs from breaking MCP clients. + # if hasattr(component, 'parameters') and isinstance(component.parameters, dict): if "$defs" in component.parameters: logger.debug(f" Found $defs with {len(component.parameters['$defs'])} definitions") diff --git a/src/server.py b/src/server.py index 17c9b49..1310563 100644 --- a/src/server.py +++ b/src/server.py @@ -9,6 +9,7 @@ from .config import Config from .routes.mappers import custom_route_mapper from .utils.logging import setup_logging +from .utils.openapi_resolver import resolve_refs logger = setup_logging() @@ -38,6 +39,11 @@ def create_mcp_server() -> FastMCP: openapi_spec = load_openapi_spec() + # Workaround to resolve all $ref references since FastMCP cannot resolve complex reference chains + logger.info("Resolving OpenAPI $ref references...") + openapi_spec = resolve_refs(openapi_spec) + logger.info("OpenAPI $ref references resolved") + client = create_cortex_client() mcp_server = FastMCP.from_openapi( diff --git a/src/utils/openapi_resolver.py b/src/utils/openapi_resolver.py new file mode 100644 index 0000000..f213536 --- /dev/null +++ b/src/utils/openapi_resolver.py @@ -0,0 +1,136 @@ +""" +OpenAPI $ref resolver for FastMCP compatibility. + +FastMCP cannot resolve complex reference chains (e.g., $ref -> $defs -> $ref chains). +This script replaces ALL $ref in the entire OpenAPI spec with inline schema definitions. + +Since customizers.py currently hides output schemas, $ref resolutions are only visible in INPUT schemas. + +Performance note: Currently processes all ~800 endpoints but only ~20 become MCP tools. +The dereferencing work is only visible in PointInTimeMetrics tool, since it's the only +MCP-enabled endpoint with $ref chains in its input schema (output schemas are hidden). + +TODO: Optimize to only process MCP-enabled paths +""" + +from typing import Any + + +def resolve_refs(spec: dict[str, Any]) -> dict[str, Any]: + """ + Recursively resolve all $ref references in an OpenAPI specification. + + This is a workaround for FastMCP's issue with $ref handling where it + doesn't properly include schema definitions when creating tool input schemas. + + Args: + spec: OpenAPI specification dictionary + + Returns: + Modified spec with all $refs resolved inline + """ + # Create a copy to avoid modifying the original + spec = spec.copy() + + # Get the components/schemas section for reference resolution + schemas = spec.get("components", {}).get("schemas", {}) + + def resolve_schema(obj: Any, visited: set[str] | None = None) -> Any: + """Recursively resolve $ref in an object.""" + if visited is None: + visited = set() + + if isinstance(obj, dict): + # Check if this is a $ref + if "$ref" in obj and len(obj) == 1: + ref_path = obj["$ref"] + + # Prevent infinite recursion + if ref_path in visited: + # Return the ref as-is to avoid infinite loop + return obj + + visited.add(ref_path) + + # Extract schema name from reference + if ref_path.startswith("#/components/schemas/"): + schema_name = ref_path.split("/")[-1] + if schema_name in schemas: + # Recursively resolve the referenced schema + resolved = resolve_schema(schemas[schema_name].copy(), visited) + visited.remove(ref_path) + return resolved + + # If we can't resolve, return as-is + visited.remove(ref_path) + return obj + else: + # Recursively process all values in the dict + result = {} + for key, value in obj.items(): + result[key] = resolve_schema(value, visited) + return result + + elif isinstance(obj, list): + # Recursively process all items in the list + return [resolve_schema(item, visited) for item in obj] + else: + # Return primitive values as-is + return obj + + # Resolve refs in all paths + if "paths" in spec: + spec["paths"] = resolve_schema(spec["paths"]) + + return spec + + +# Use if context becomes too large for inline definitions +def resolve_refs_with_defs(spec: dict[str, Any]) -> dict[str, Any]: + """ + Alternative approach: Keep $refs but ensure $defs section is populated. + + This transforms OpenAPI $refs to JSON Schema format and includes + all referenced schemas in a $defs section at the root level. + + Args: + spec: OpenAPI specification dictionary + + Returns: + Modified spec with $refs pointing to $defs and all definitions included + """ + # Create a copy to avoid modifying the original + spec = spec.copy() + + # Get the components/schemas section + schemas = spec.get("components", {}).get("schemas", {}) + + # Create $defs section at root level + if schemas: + spec["$defs"] = schemas.copy() + + def transform_refs(obj: Any) -> Any: + """Transform OpenAPI $refs to JSON Schema $refs.""" + if isinstance(obj, dict): + result = {} + for key, value in obj.items(): + if key == "$ref" and isinstance(value, str): + # Transform the reference format + if value.startswith("#/components/schemas/"): + schema_name = value.split("/")[-1] + result[key] = f"#/$defs/{schema_name}" + else: + result[key] = value + else: + result[key] = transform_refs(value) + return result + elif isinstance(obj, list): + return [transform_refs(item) for item in obj] + else: + return obj + + # Transform all refs in paths + if "paths" in spec: + spec["paths"] = transform_refs(spec["paths"]) + + return spec diff --git a/tests/test_openapi_resolver.py b/tests/test_openapi_resolver.py new file mode 100644 index 0000000..f0873d8 --- /dev/null +++ b/tests/test_openapi_resolver.py @@ -0,0 +1,283 @@ +"""Tests for OpenAPI $ref resolver.""" +import json + +from src.utils.openapi_resolver import resolve_refs, resolve_refs_with_defs + + +class TestOpenAPIResolver: + """Test suite for OpenAPI $ref resolver.""" + + def test_resolve_simple_ref(self): + """Test resolving a simple $ref to a schema.""" + spec = { + "paths": { + "/api/test": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/TestSchema"} + } + } + } + } + } + }, + "components": { + "schemas": { + "TestSchema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "value": {"type": "integer"} + }, + "required": ["name"] + } + } + } + } + + resolved = resolve_refs(spec) + + # Check that the $ref was resolved + schema = resolved["paths"]["/api/test"]["post"]["requestBody"]["content"]["application/json"]["schema"] + assert "$ref" not in schema + assert schema["type"] == "object" + assert "name" in schema["properties"] + assert "value" in schema["properties"] + assert schema["required"] == ["name"] + + def test_resolve_nested_refs(self): + """Test resolving nested $refs.""" + spec = { + "paths": { + "/api/test": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/ParentSchema"} + } + } + } + } + } + }, + "components": { + "schemas": { + "ParentSchema": { + "type": "object", + "properties": { + "child": {"$ref": "#/components/schemas/ChildSchema"} + } + }, + "ChildSchema": { + "type": "object", + "properties": { + "data": {"type": "string"} + } + } + } + } + } + + resolved = resolve_refs(spec) + + # Check that all $refs were resolved + schema = resolved["paths"]["/api/test"]["post"]["requestBody"]["content"]["application/json"]["schema"] + assert "$ref" not in json.dumps(schema) + assert schema["properties"]["child"]["properties"]["data"]["type"] == "string" + + def test_resolve_circular_refs(self): + """Test handling of circular references.""" + spec = { + "paths": { + "/api/test": { + "get": { + "responses": { + "200": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Node"} + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Node": { + "type": "object", + "properties": { + "value": {"type": "string"}, + "parent": {"$ref": "#/components/schemas/Node"} + } + } + } + } + } + + # Should not raise an exception and should handle circular refs + resolved = resolve_refs(spec) + + # Should have resolved the top-level ref but left circular ref intact + schema = resolved["paths"]["/api/test"]["get"]["responses"]["200"]["content"]["application/json"]["schema"] + assert schema["type"] == "object" + assert "value" in schema["properties"] + # Circular ref should be preserved to prevent infinite recursion + assert schema["properties"]["parent"]["$ref"] == "#/components/schemas/Node" + + def test_resolve_refs_in_arrays(self): + """Test resolving $refs inside arrays.""" + spec = { + "paths": { + "/api/test": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "array", + "items": {"$ref": "#/components/schemas/Item"} + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Item": { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "name": {"type": "string"} + } + } + } + } + } + + resolved = resolve_refs(spec) + + schema = resolved["paths"]["/api/test"]["post"]["requestBody"]["content"]["application/json"]["schema"] + assert schema["type"] == "array" + assert "$ref" not in schema["items"] + assert schema["items"]["properties"]["id"]["type"] == "integer" + + def test_preserve_non_schema_refs(self): + """Test that non-schema $refs are preserved.""" + spec = { + "paths": { + "/api/test": { + "get": { + "parameters": [ + {"$ref": "#/components/parameters/CommonParam"} + ] + } + } + }, + "components": { + "parameters": { + "CommonParam": { + "name": "test", + "in": "query", + "schema": {"type": "string"} + } + } + } + } + + resolved = resolve_refs(spec) + + # Non-schema refs should be preserved + param_ref = resolved["paths"]["/api/test"]["get"]["parameters"][0] + assert param_ref == {"$ref": "#/components/parameters/CommonParam"} + + def test_resolve_refs_with_defs(self): + """Test the alternative approach using $defs.""" + spec = { + "paths": { + "/api/test": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/TestSchema"} + } + } + } + } + } + }, + "components": { + "schemas": { + "TestSchema": { + "type": "object", + "properties": {"name": {"type": "string"}} + } + } + } + } + + resolved = resolve_refs_with_defs(spec) + + # Should have $defs section + assert "$defs" in resolved + assert "TestSchema" in resolved["$defs"] + + # Refs should be transformed to JSON Schema format + schema_ref = resolved["paths"]["/api/test"]["post"]["requestBody"]["content"]["application/json"]["schema"] + assert schema_ref["$ref"] == "#/$defs/TestSchema" + + def test_no_components_section(self): + """Test handling when components section is missing.""" + spec = { + "paths": { + "/api/test": { + "get": { + "responses": { + "200": { + "description": "OK" + } + } + } + } + } + } + + # Should not raise an exception + resolved = resolve_refs(spec) + assert resolved == spec + + def test_original_spec_unchanged(self): + """Test that the original spec is not modified.""" + spec = { + "paths": { + "/api/test": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/TestSchema"} + } + } + } + } + } + }, + "components": { + "schemas": { + "TestSchema": {"type": "object"} + } + } + } + + original_json = json.dumps(spec, sort_keys=True) + resolve_refs(spec) + + # Original should be unchanged + assert json.dumps(spec, sort_keys=True) == original_json