diff --git a/design/design.md b/design/design.md index bfabeac7..33dc3a63 100644 --- a/design/design.md +++ b/design/design.md @@ -776,9 +776,9 @@ If a server author wants to support resource subscriptions, they must provide ha type ServerOptions struct { ... // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, *SubscribeParams) error + SubscribeHandler func(context.Context, ss *ServerSession, *SubscribeParams) error // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, *UnsubscribeParams) error + UnsubscribeHandler func(context.Context, ss *ServerSession, *UnsubscribeParams) error } ``` diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 819edeb6..da53465c 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -78,11 +78,11 @@ func TestEndToEnd(t *testing.T) { ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { notificationChans["progress_server"] <- 0 }, - SubscribeHandler: func(context.Context, *SubscribeParams) error { + SubscribeHandler: func(context.Context, *ServerSession, *SubscribeParams) error { notificationChans["subscribe"] <- 0 return nil }, - UnsubscribeHandler: func(context.Context, *UnsubscribeParams) error { + UnsubscribeHandler: func(context.Context, *ServerSession, *UnsubscribeParams) error { notificationChans["unsubscribe"] <- 0 return nil }, diff --git a/mcp/server.go b/mcp/server.go index ba2bd0a9..c8878da3 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -66,9 +66,9 @@ type ServerOptions struct { // the session is automatically closed. KeepAlive time.Duration // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, *SubscribeParams) error + SubscribeHandler func(context.Context, *ServerSession, *SubscribeParams) error // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, *UnsubscribeParams) error + UnsubscribeHandler func(context.Context, *ServerSession, *UnsubscribeParams) error // If true, advertises the prompts capability during initialization, // even if no prompts have been registered. HasPrompts bool @@ -469,7 +469,7 @@ func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *Subsc if s.opts.SubscribeHandler == nil { return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) } - if err := s.opts.SubscribeHandler(ctx, params); err != nil { + if err := s.opts.SubscribeHandler(ctx, ss, params); err != nil { return nil, err } @@ -488,7 +488,7 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns return nil, jsonrpc2.ErrMethodNotFound } - if err := s.opts.UnsubscribeHandler(ctx, params); err != nil { + if err := s.opts.UnsubscribeHandler(ctx, ss, params); err != nil { return nil, err } diff --git a/mcp/server_test.go b/mcp/server_test.go index 0b853a33..5a161b72 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -281,10 +281,10 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + SubscribeHandler: func(ctx context.Context, _ *ServerSession, sp *SubscribeParams) error { return nil }, - UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + UnsubscribeHandler: func(ctx context.Context, _ *ServerSession, up *UnsubscribeParams) error { return nil }, }, @@ -325,10 +325,10 @@ func TestServerCapabilities(t *testing.T) { s.AddTool(tool, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + SubscribeHandler: func(ctx context.Context, _ *ServerSession, sp *SubscribeParams) error { return nil }, - UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + UnsubscribeHandler: func(ctx context.Context, _ *ServerSession, up *UnsubscribeParams) error { return nil }, CompletionHandler: func(ctx context.Context, ss *ServerSession, params *CompleteParams) (*CompleteResult, error) {