diff --git a/mcp/server.go b/mcp/server.go index c8878da3..e69a872e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -567,11 +567,25 @@ func (s *Server) Connect(ctx context.Context, t Transport) (*ServerSession, erro return connect(ctx, t, s) } -func (s *Server) callInitializedHandler(ctx context.Context, ss *ServerSession, params *InitializedParams) (Result, error) { - if s.opts.KeepAlive > 0 { - ss.startKeepalive(s.opts.KeepAlive) +func (ss *ServerSession) initialized(ctx context.Context, params *InitializedParams) (Result, error) { + if ss.server.opts.KeepAlive > 0 { + ss.startKeepalive(ss.server.opts.KeepAlive) } - return callNotificationHandler(ctx, s.opts.InitializedHandler, ss, params) + ss.mu.Lock() + hasParams := ss.initializeParams != nil + wasInitialized := ss._initialized + if hasParams { + ss._initialized = true + } + ss.mu.Unlock() + + if !hasParams { + return nil, fmt.Errorf("%q before %q", notificationInitialized, methodInitialize) + } + if wasInitialized { + return nil, fmt.Errorf("duplicate %q received", notificationInitialized) + } + return callNotificationHandler(ctx, ss.server.opts.InitializedHandler, ss, params) } func (s *Server) callRootsListChangedHandler(ctx context.Context, ss *ServerSession, params *RootsListChangedParams) (Result, error) { @@ -603,7 +617,7 @@ type ServerSession struct { mu sync.Mutex logLevel LoggingLevel initializeParams *InitializeParams - initialized bool + _initialized bool keepaliveCancel context.CancelFunc } @@ -702,7 +716,7 @@ var serverMethodInfos = map[string]methodInfo{ methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), 0), methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), 0), methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), 0), - notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), notification|missingParamsOK), + notificationInitialized: newMethodInfo(sessionMethod((*ServerSession).initialized), notification|missingParamsOK), notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), notification|missingParamsOK), notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), notification), } @@ -729,13 +743,13 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn } // handle invokes the method described by the given JSON RPC request. func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) { ss.mu.Lock() - initialized := ss.initialized + initialized := ss._initialized ss.mu.Unlock() // From the spec: // "The client SHOULD NOT send requests other than pings before the server // has responded to the initialize request." switch req.Method { - case "initialize", "ping": + case methodInitialize, methodPing, notificationInitialized: default: if !initialized { return nil, fmt.Errorf("method %q is invalid during session initialization", req.Method) @@ -756,17 +770,6 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam ss.initializeParams = params ss.mu.Unlock() - // Mark the connection as initialized when this method exits. - // TODO: Technically, the server should not be considered initialized until it has - // *responded*, but we don't have adequate visibility into the jsonrpc2 - // connection to implement that easily. In any case, once we've initialized - // here, we can handle requests. - defer func() { - ss.mu.Lock() - ss.initialized = true - ss.mu.Unlock() - }() - // If we support the client's version, reply with it. Otherwise, reply with our // latest version. version := params.ProtocolVersion diff --git a/mcp/testdata/conformance/server/bad_requests.txtar b/mcp/testdata/conformance/server/bad_requests.txtar index e9e9d483..44816189 100644 --- a/mcp/testdata/conformance/server/bad_requests.txtar +++ b/mcp/testdata/conformance/server/bad_requests.txtar @@ -38,11 +38,12 @@ code_review "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } -{"jsonrpc":"2.0", "id": 3, "method": "notifications/initialized"} -{"jsonrpc":"2.0", "method":"ping"} -{"jsonrpc":"2.0", "id": 4, "method": "logging/setLevel"} -{"jsonrpc":"2.0", "id": 5, "method": "completion/complete"} -{"jsonrpc":"2.0", "id": 4, "method": "logging/setLevel", "params": null} +{ "jsonrpc":"2.0", "id": 3, "method": "notifications/initialized" } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ "jsonrpc":"2.0", "method":"ping" } +{ "jsonrpc":"2.0", "id": 4, "method": "logging/setLevel" } +{ "jsonrpc":"2.0", "id": 5, "method": "completion/complete" } +{ "jsonrpc":"2.0", "id": 4, "method": "logging/setLevel", "params": null } -- server -- { diff --git a/mcp/testdata/conformance/server/lifecycle.txtar b/mcp/testdata/conformance/server/lifecycle.txtar new file mode 100644 index 00000000..eba287e0 --- /dev/null +++ b/mcp/testdata/conformance/server/lifecycle.txtar @@ -0,0 +1,57 @@ +This test checks that the server obeys the rules for initialization lifecycle, +and rejects non-ping requests until 'initialized' is received. + +See also modelcontextprotocol/go-sdk#225. + +-- client -- +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } + } +} +{ "jsonrpc":"2.0", "id": 1, "method":"ping" } +{ "jsonrpc": "2.0", "id": 2, "method": "tools/list" } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } +{ "jsonrpc": "2.0", "id": 3, "method": "tools/list" } + +-- server -- +{ + "jsonrpc": "2.0", + "id": 1, + "result": { + "capabilities": { + "logging": {} + }, + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "testServer", + "version": "v1.0.0" + } + } +} +{ + "jsonrpc": "2.0", + "id": 1, + "result": {} +} +{ + "jsonrpc": "2.0", + "id": 2, + "error": { + "code": 0, + "message": "method \"tools/list\" is invalid during session initialization" + } +} +{ + "jsonrpc": "2.0", + "id": 3, + "result": { + "tools": [] + } +} diff --git a/mcp/testdata/conformance/server/prompts.txtar b/mcp/testdata/conformance/server/prompts.txtar index 3fd036e6..fdaf7932 100644 --- a/mcp/testdata/conformance/server/prompts.txtar +++ b/mcp/testdata/conformance/server/prompts.txtar @@ -18,9 +18,11 @@ code_review "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } { "jsonrpc": "2.0", "id": 5, "method": "prompts/get" } + -- server -- { "jsonrpc": "2.0", diff --git a/mcp/testdata/conformance/server/resources.txtar b/mcp/testdata/conformance/server/resources.txtar index ae2e23cb..314817b8 100644 --- a/mcp/testdata/conformance/server/resources.txtar +++ b/mcp/testdata/conformance/server/resources.txtar @@ -21,6 +21,7 @@ info.txt "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{ "jsonrpc":"2.0", "method": "notifications/initialized" } { "jsonrpc": "2.0", "id": 2, "method": "resources/list" } { "jsonrpc": "2.0", "id": 3, diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index 01fdb266..29dfdc18 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -20,6 +20,7 @@ greet "clientInfo": { "name": "ExampleClient", "version": "1.0.0" } } } +{"jsonrpc":"2.0", "method": "notifications/initialized"} { "jsonrpc": "2.0", "id": 2, "method": "tools/list" } { "jsonrpc": "2.0", "id": 3, "method": "resources/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" }