diff --git a/mcp/schema_cache.go b/mcp/schema_cache.go new file mode 100644 index 00000000..f7176f8a --- /dev/null +++ b/mcp/schema_cache.go @@ -0,0 +1,74 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "reflect" + "sync" + + "github.com/google/jsonschema-go/jsonschema" +) + +// schemaCache provides concurrent-safe caching for JSON schemas. +// It caches both by reflect.Type (for auto-generated schemas) and +// by schema pointer (for pre-defined schemas). +// +// This cache significantly improves performance for stateless server deployments +// where tools are re-registered on every request. Without caching, each AddTool +// call would trigger expensive reflection-based schema generation and resolution. +// +// Create a cache using [NewSchemaCache] and pass it to [ServerOptions.SchemaCache]. +type schemaCache struct { + // byType caches schemas generated from Go types via jsonschema.ForType. + // Key: reflect.Type, Value: *cachedSchema + byType sync.Map + + // bySchema caches resolved schemas for pre-defined Schema objects. + // Key: *jsonschema.Schema (pointer identity), Value: *jsonschema.Resolved + // This uses pointer identity because integrators typically reuse the same + // Tool objects across requests, so the schema pointer remains stable. + bySchema sync.Map +} + +// cachedSchema holds both the generated schema and its resolved form. +type cachedSchema struct { + schema *jsonschema.Schema + resolved *jsonschema.Resolved +} + +// NewSchemaCache creates a new schema cache for use with [ServerOptions.SchemaCache]. +// Safe for concurrent use, unbounded. +func NewSchemaCache() *schemaCache { + return &schemaCache{} +} + +// getByType retrieves a cached schema by Go type. +// Returns the schema, resolved schema, and whether the cache hit. 
+func (c *schemaCache) getByType(t reflect.Type) (*jsonschema.Schema, *jsonschema.Resolved, bool) { + if v, ok := c.byType.Load(t); ok { + cs := v.(*cachedSchema) + return cs.schema, cs.resolved, true + } + return nil, nil, false +} + +// setByType caches a schema by Go type. +func (c *schemaCache) setByType(t reflect.Type, schema *jsonschema.Schema, resolved *jsonschema.Resolved) { + c.byType.Store(t, &cachedSchema{schema: schema, resolved: resolved}) +} + +// getBySchema retrieves a cached resolved schema by the original schema pointer. +// This is used when integrators provide pre-defined schemas (e.g., github-mcp-server pattern). +func (c *schemaCache) getBySchema(schema *jsonschema.Schema) (*jsonschema.Resolved, bool) { + if v, ok := c.bySchema.Load(schema); ok { + return v.(*jsonschema.Resolved), true + } + return nil, false +} + +// setBySchema caches a resolved schema by the original schema pointer. +func (c *schemaCache) setBySchema(schema *jsonschema.Schema, resolved *jsonschema.Resolved) { + c.bySchema.Store(schema, resolved) +} diff --git a/mcp/schema_cache_benchmark_test.go b/mcp/schema_cache_benchmark_test.go new file mode 100644 index 00000000..b2bb5f1b --- /dev/null +++ b/mcp/schema_cache_benchmark_test.go @@ -0,0 +1,160 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "testing" + + "github.com/google/jsonschema-go/jsonschema" +) + +// BenchmarkAddToolTypedHandler measures performance of AddTool with typed handlers. +// This simulates the stateless server pattern where new servers are created per request. 
+func BenchmarkAddToolTypedHandler(b *testing.B) { + type SearchInput struct { + Query string `json:"query" jsonschema:"required"` + Page int `json:"page"` + PerPage int `json:"per_page"` + } + + type SearchOutput struct { + Results []string `json:"results"` + Total int `json:"total"` + } + + handler := func(ctx context.Context, req *CallToolRequest, in SearchInput) (*CallToolResult, SearchOutput, error) { + return &CallToolResult{}, SearchOutput{}, nil + } + + tool := &Tool{ + Name: "search", + Description: "Search for items", + } + + // Create a shared cache for caching benefit + cache := NewSchemaCache() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + s := NewServer(&Implementation{Name: "test", Version: "1.0"}, &ServerOptions{ + SchemaCache: cache, + }) + AddTool(s, tool, handler) + } +} + +// BenchmarkAddToolPreDefinedSchema measures performance with pre-defined schemas. +// This simulates how github-mcp-server registers tools with manual InputSchema. +func BenchmarkAddToolPreDefinedSchema(b *testing.B) { + schema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "query": {Type: "string", Description: "Search query"}, + "page": {Type: "integer", Description: "Page number"}, + "per_page": {Type: "integer", Description: "Results per page"}, + }, + Required: []string{"query"}, + } + + handler := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + } + + tool := &Tool{ + Name: "search", + Description: "Search for items", + InputSchema: schema, // Pre-defined schema like github-mcp-server + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + s := NewServer(&Implementation{Name: "test", Version: "1.0"}, nil) + s.AddTool(tool, handler) + } +} + +// BenchmarkAddToolTypedHandlerNoCache measures performance without caching. +// Used to compare before/after performance. 
+func BenchmarkAddToolTypedHandlerNoCache(b *testing.B) { + type SearchInput struct { + Query string `json:"query" jsonschema:"required"` + Page int `json:"page"` + PerPage int `json:"per_page"` + } + + type SearchOutput struct { + Results []string `json:"results"` + Total int `json:"total"` + } + + handler := func(ctx context.Context, req *CallToolRequest, in SearchInput) (*CallToolResult, SearchOutput, error) { + return &CallToolResult{}, SearchOutput{}, nil + } + + tool := &Tool{ + Name: "search", + Description: "Search for items", + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // No cache - each iteration generates new schemas + s := NewServer(&Implementation{Name: "test", Version: "1.0"}, nil) + AddTool(s, tool, handler) + } +} + +// BenchmarkAddToolMultipleTools simulates registering multiple tools like github-mcp-server. +func BenchmarkAddToolMultipleTools(b *testing.B) { + type Input1 struct { + Query string `json:"query"` + } + type Input2 struct { + ID int `json:"id"` + } + type Input3 struct { + Name string `json:"name"` + Value string `json:"value"` + } + type Output struct { + Success bool `json:"success"` + } + + handler1 := func(ctx context.Context, req *CallToolRequest, in Input1) (*CallToolResult, Output, error) { + return &CallToolResult{}, Output{}, nil + } + handler2 := func(ctx context.Context, req *CallToolRequest, in Input2) (*CallToolResult, Output, error) { + return &CallToolResult{}, Output{}, nil + } + handler3 := func(ctx context.Context, req *CallToolRequest, in Input3) (*CallToolResult, Output, error) { + return &CallToolResult{}, Output{}, nil + } + + tool1 := &Tool{Name: "tool1", Description: "Tool 1"} + tool2 := &Tool{Name: "tool2", Description: "Tool 2"} + tool3 := &Tool{Name: "tool3", Description: "Tool 3"} + + // Create a shared cache for caching benefit + cache := NewSchemaCache() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + s := NewServer(&Implementation{Name: "test", 
Version: "1.0"}, &ServerOptions{ + SchemaCache: cache, + }) + AddTool(s, tool1, handler1) + AddTool(s, tool2, handler2) + AddTool(s, tool3, handler3) + } +} diff --git a/mcp/schema_cache_test.go b/mcp/schema_cache_test.go new file mode 100644 index 00000000..8e015f87 --- /dev/null +++ b/mcp/schema_cache_test.go @@ -0,0 +1,238 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "reflect" + "testing" + + "github.com/google/jsonschema-go/jsonschema" +) + +func TestSchemaCacheByType(t *testing.T) { + cache := NewSchemaCache() + + type TestInput struct { + Name string `json:"name"` + } + + rt := reflect.TypeFor[TestInput]() + + // Initially not in cache + _, _, ok := cache.getByType(rt) + if ok { + t.Error("expected cache miss for new type") + } + + // Add to cache + schema := &jsonschema.Schema{Type: "object"} + resolved, err := schema.Resolve(nil) + if err != nil { + t.Fatalf("failed to resolve schema: %v", err) + } + cache.setByType(rt, schema, resolved) + + // Now should hit + gotSchema, gotResolved, ok := cache.getByType(rt) + if !ok { + t.Error("expected cache hit after set") + } + if gotSchema != schema { + t.Error("schema mismatch") + } + if gotResolved != resolved { + t.Error("resolved schema mismatch") + } +} + +func TestSchemaCacheBySchema(t *testing.T) { + cache := NewSchemaCache() + + schema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "query": {Type: "string"}, + }, + } + + // Initially not in cache + _, ok := cache.getBySchema(schema) + if ok { + t.Error("expected cache miss for new schema") + } + + // Add to cache + resolved, err := schema.Resolve(nil) + if err != nil { + t.Fatalf("failed to resolve schema: %v", err) + } + cache.setBySchema(schema, resolved) + + // Now should hit + gotResolved, ok := cache.getBySchema(schema) + if !ok { + 
t.Error("expected cache hit after set") + } + if gotResolved != resolved { + t.Error("resolved schema mismatch") + } + + // Different schema pointer should miss + schema2 := &jsonschema.Schema{Type: "object"} + _, ok = cache.getBySchema(schema2) + if ok { + t.Error("expected cache miss for different schema pointer") + } +} + +func TestSetSchemaCachesGeneratedSchemas(t *testing.T) { + cache := NewSchemaCache() + + type TestInput struct { + Query string `json:"query"` + } + + rt := reflect.TypeFor[TestInput]() + + // First call should generate and cache + var sfield1 any + var rfield1 *jsonschema.Resolved + _, err := setSchema[TestInput](&sfield1, &rfield1, cache) + if err != nil { + t.Fatalf("setSchema failed: %v", err) + } + + // Verify it's in cache + cachedSchema, cachedResolved, ok := cache.getByType(rt) + if !ok { + t.Fatal("schema not cached after first setSchema call") + } + + // Second call should hit cache + var sfield2 any + var rfield2 *jsonschema.Resolved + _, err = setSchema[TestInput](&sfield2, &rfield2, cache) + if err != nil { + t.Fatalf("setSchema failed on second call: %v", err) + } + + // Should return same cached objects + if sfield2.(*jsonschema.Schema) != cachedSchema { + t.Error("expected cached schema to be returned") + } + if rfield2 != cachedResolved { + t.Error("expected cached resolved schema to be returned") + } +} + +func TestSetSchemaCachesProvidedSchemas(t *testing.T) { + cache := NewSchemaCache() + + // This simulates the github-mcp-server pattern: + // schema is created once and reused across requests + schema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "query": {Type: "string"}, + }, + } + + // First call should resolve and cache + var sfield1 any = schema + var rfield1 *jsonschema.Resolved + _, err := setSchema[map[string]any](&sfield1, &rfield1, cache) + if err != nil { + t.Fatalf("setSchema failed: %v", err) + } + + // Verify it's in cache + cachedResolved, ok := 
cache.getBySchema(schema)
+	if !ok {
+		t.Fatal("resolved schema not cached after first setSchema call")
+	}
+	if rfield1 != cachedResolved {
+		t.Error("expected same resolved schema")
+	}
+
+	// Second call with same schema pointer should hit cache
+	var sfield2 any = schema
+	var rfield2 *jsonschema.Resolved
+	_, err = setSchema[map[string]any](&sfield2, &rfield2, cache)
+	if err != nil {
+		t.Fatalf("setSchema failed on second call: %v", err)
+	}
+
+	if rfield2 != cachedResolved {
+		t.Error("expected cached resolved schema to be returned")
+	}
+}
+
+func TestSetSchemaNoCacheWhenNil(t *testing.T) {
+	type TestInput struct {
+		Query string `json:"query"`
+	}
+
+	// First call without cache
+	var sfield1 any
+	var rfield1 *jsonschema.Resolved
+	_, err := setSchema[TestInput](&sfield1, &rfield1, nil)
+	if err != nil {
+		t.Fatalf("setSchema failed: %v", err)
+	}
+
+	// Second call without cache - should still generate a new schema
+	var sfield2 any
+	var rfield2 *jsonschema.Resolved
+	_, err = setSchema[TestInput](&sfield2, &rfield2, nil)
+	if err != nil {
+		t.Fatalf("setSchema failed on second call: %v", err)
+	}
+
+	// Without a cache each call must generate a fresh schema: equivalent
+	// content but distinct pointers. Pointer equality would indicate caching.
+	if sfield1 == nil || sfield1 == sfield2 {
+		t.Error("expected a freshly generated schema on each call")
+	}
+	if rfield1 == nil || rfield1 == rfield2 {
+		t.Error("expected a freshly resolved schema on each call")
+	}
+}
+
+func TestAddToolCachesBetweenCalls(t *testing.T) {
+	cache := NewSchemaCache()
+
+	type GreetInput struct {
+		Name string `json:"name" jsonschema:"the name to greet"`
+	}
+
+	type GreetOutput struct {
+		Message string `json:"message"`
+	}
+
+	handler := func(ctx context.Context, req *CallToolRequest, in GreetInput) (*CallToolResult, GreetOutput, error) {
+		return &CallToolResult{}, GreetOutput{Message: "Hello, " + in.Name}, nil
+	}
+
+	tool := &Tool{
+		Name:        "greet",
+		Description: "Greet someone",
+	}
+
+	// Simulate stateless server pattern: 
create new server each time, but share cache + for i := 0; i < 3; i++ { + s := NewServer(&Implementation{Name: "test", Version: "1.0"}, &ServerOptions{ + SchemaCache: cache, + }) + AddTool(s, tool, handler) + } + + // Verify schema was cached by type + rt := reflect.TypeFor[GreetInput]() + _, _, ok := cache.getByType(rt) + if !ok { + t.Error("expected schema to be cached by type after multiple AddTool calls") + } +} diff --git a/mcp/server.go b/mcp/server.go index d4317222..0d9e33f8 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -89,6 +89,10 @@ type ServerOptions struct { // If true, advertises the tools capability during initialization, // even if no tools have been registered. HasTools bool + // SchemaCache, if non-nil, enables caching of JSON schemas for tools. + // This can significantly improve performance for stateless server + // deployments where tools are re-registered on every request. + SchemaCache *schemaCache // GetSessionID provides the next session ID to use for an incoming request. // If nil, a default randomly generated ID will be used. @@ -239,7 +243,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { func() bool { s.tools.add(st); return true }) } -func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { +func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], cache *schemaCache) (*Tool, ToolHandler, error) { tt := *t // Special handling for an "any" input: treat as an empty object. 
@@ -248,7 +252,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
 	}
 
 	var inputResolved *jsonschema.Resolved
-	if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil {
+	if _, err := setSchema[In](&tt.InputSchema, &inputResolved, cache); err != nil {
 		return nil, nil, fmt.Errorf("input schema: %w", err)
 	}
 
@@ -263,7 +267,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
 	)
 	if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
 		var err error
-		elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved)
+		elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved, cache)
 		if err != nil {
 			return nil, nil, fmt.Errorf("output schema: %v", err)
 		}
@@ -364,29 +368,87 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
 // pointer: if the user provided the schema, they may have intentionally
 // derived it from the pointer type, and handling of zero values is up to them.
 //
+// If cache is non-nil, schemas are cached to avoid repeated reflection.
+//
 // TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we
 // should have a jsonschema.Zero(schema) helper?
-func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) {
+func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, cache *schemaCache) (zero any, err error) {
+	// Hoist only the type computation: it is needed as the cache key.
+	rt := reflect.TypeFor[T]()
+	isPtr := rt.Kind() == reflect.Pointer
+	if isPtr {
+		rt = rt.Elem()
+	}
+	var internalSchema *jsonschema.Schema
+
 	if *sfield == nil {
-		rt := reflect.TypeFor[T]()
-		if rt.Kind() == reflect.Pointer {
-			rt = rt.Elem()
-			zero = reflect.Zero(rt).Interface()
+		// Preserve the contract in the doc comment above: only
+		// auto-generated schemas yield a non-nil zero value for pointer
+		// type parameters; user-provided schemas manage zero values
+		// themselves.
+		if isPtr {
+			zero = reflect.Zero(rt).Interface()
+		}
+		// Case 1: No schema provided - check type cache first
+		if cache != nil {
+			if schema, resolved, ok := cache.getByType(rt); ok {
+				*sfield = schema
+				*rfield = resolved
+				return zero, nil
+			}
 		}
-		// TODO: we should be able to pass nil opts here.
+ + // Generate schema via reflection (expensive, but cached for next time if cache is set) internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) - if err == nil { - *sfield = internalSchema + if err != nil { + return zero, err } - } else if err := remarshal(*sfield, &internalSchema); err != nil { - return zero, err + *sfield = internalSchema + + // Resolve and optionally cache + resolved, err := internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + if err != nil { + return zero, err + } + *rfield = resolved + if cache != nil { + cache.setByType(rt, internalSchema, resolved) + } + return zero, nil } + + // Case 2: Schema was provided + // Check if it's a *jsonschema.Schema we can cache by pointer + if providedSchema, ok := (*sfield).(*jsonschema.Schema); ok { + if cache != nil { + if resolved, ok := cache.getBySchema(providedSchema); ok { + *rfield = resolved + return zero, nil + } + } + // Need to resolve and optionally cache + internalSchema = providedSchema + } else { + // Schema provided as different type (e.g., map) - need to remarshal + if err := remarshal(*sfield, &internalSchema); err != nil { + return zero, err + } + } + + // Resolve the schema + resolved, err := internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) if err != nil { return zero, err } - *rfield, err = internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - return zero, err + *rfield = resolved + + // Cache by schema pointer if we got a direct *jsonschema.Schema + if cache != nil { + if providedSchema, ok := (*sfield).(*jsonschema.Schema); ok { + cache.setBySchema(providedSchema, resolved) + } + } + + return zero, nil } // AddTool adds a tool and typed tool handler to the server. @@ -409,7 +465,7 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err // tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed // description of this automatic behavior. 
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { - tt, hh, err := toolForErr(t, h) + tt, hh, err := toolForErr(t, h, s.opts.SchemaCache) if err != nil { panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err)) } diff --git a/mcp/server_test.go b/mcp/server_test.go index d8c0df65..310eb694 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -562,7 +562,7 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { return nil, out, nil } - gott, goth, err := toolForErr(tool, th) + gott, goth, err := toolForErr(tool, th, nil) if err != nil { t.Fatal(err) }