diff --git a/mcp/event.go b/mcp/event.go index f4f4eeea..0dd8734b 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -153,6 +153,11 @@ func scanEvents(r io.Reader) iter.Seq2[Event, error] { // // All of an EventStore's methods must be safe for use by multiple goroutines. type EventStore interface { + // Open prepares the event store for a given stream. It ensures that the + // underlying data structure for the stream is initialized, making it + // ready to store event streams. + Open(_ context.Context, sessionID string, streamID StreamID) error + // Append appends data for an outgoing event to given stream, which is part of the // given session. Append(_ context.Context, sessionID string, _ StreamID, data []byte) error @@ -162,6 +167,7 @@ type EventStore interface { // Once the iterator yields a non-nil error, it will stop. // After's iterator must return an error immediately if any data after index was // dropped; it must not return partial results. + // The stream must have been opened previously (see [EventStore.Open]). After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error] // SessionClosed informs the store that the given session is finished, along @@ -256,11 +262,20 @@ func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { } } -// Append implements [EventStore.Append] by recording data in memory. -func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error { +// Open implements [EventStore.Open]. It ensures that the underlying data +// structures for the given session are initialized and ready for use. +func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID StreamID) error { s.mu.Lock() defer s.mu.Unlock() + s.init(sessionID, streamID) + return nil +} +// init is an internal helper function that ensures the nested map structure for a +// given sessionID and streamID exists, creating it if necessary. It returns the +// dataList associated with the specified IDs. +// Requires s.mu. +func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList { streamMap, ok := s.store[sessionID] if !ok { streamMap = make(map[StreamID]*dataList) @@ -271,6 +286,14 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID dl = &dataList{} streamMap[streamID] = dl } + return dl +} + +// Append implements [EventStore.Append] by recording data in memory. +func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + dl := s.init(sessionID, streamID) // Purge before adding, so at least the current data item will be present. // (That could result in nBytes > maxBytes, but we'll live with that.) s.purge() diff --git a/mcp/streamable.go b/mcp/streamable.go index f56b7084..92f20206 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -401,7 +401,7 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp } // Connect implements the [Transport] interface. -func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) { +func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, error) { if t.connection != nil { return nil, fmt.Errorf("transport already connected") } @@ -415,13 +415,17 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error) streams: make(map[StreamID]*stream), requestStreams: make(map[jsonrpc.ID]StreamID), } + if t.connection.eventStore == nil { + t.connection.eventStore = NewMemoryEventStore(nil) + } // Stream 0 corresponds to the hanging 'GET'. // // It is always text/event-stream, since it must carry arbitrarily many // messages. - t.connection.streams[""] = newStream("", false) - if t.connection.eventStore == nil { - t.connection.eventStore = NewMemoryEventStore(nil) + var err error + t.connection.streams[""], err = t.connection.newStream(ctx, "", false) + if err != nil { + return nil, err } return t.connection, nil } @@ -490,7 +494,7 @@ type stream struct { // that there are messages available to write into the HTTP response. // In addition, the presence of a channel guarantees that at most one HTTP response // can receive messages for a logical stream. After claiming the stream, incoming - // requests should read from outgoing, to ensure that no new messages are missed. + // requests should read from the event store, to ensure that no new messages are missed. // // To simplify locking, signal is an atomic. We need an atomic.Pointer, because // you can't set an atomic.Value to nil. @@ -502,22 +506,21 @@ type stream struct { // The following mutable fields are protected by the mutex of the containing // StreamableServerTransport. - // outgoing is the list of outgoing messages, enqueued by server methods that - // write notifications and responses, and dequeued by streamResponse. - outgoing [][]byte - // streamRequests is the set of unanswered incoming RPCs for the stream. // - // Requests persist until their response data has been added to outgoing. + // Requests persist until their response data has been added to the event store. requests map[jsonrpc.ID]struct{} } -func newStream(id StreamID, jsonResponse bool) *stream { +func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, jsonResponse bool) (*stream, error) { + if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { + return nil, err + } return &stream{ id: id, jsonResponse: jsonResponse, requests: make(map[jsonrpc.ID]struct{}), - } + }, nil } func signalChanPtr() *chan struct{} { @@ -668,7 +671,11 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream = newStream(StreamID(randText()), c.jsonResponse) + stream, err = c.newStream(req.Context(), StreamID(randText()), c.jsonResponse) + if err != nil { + http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) + return + } c.mu.Lock() c.streams[stream.id] = stream stream.requests = requests @@ -706,13 +713,13 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter var msgs []json.RawMessage ctx := req.Context() - for msg, ok := range c.messages(ctx, stream, false) { - if !ok { + for msg, err := range c.messages(ctx, stream, false, -1) { + if err != nil { if ctx.Err() != nil { w.WriteHeader(http.StatusNoContent) return } else { - http.Error(w, http.StatusText(http.StatusGone), http.StatusGone) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } } @@ -770,44 +777,18 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, } } - if lastIndex >= 0 { - // Resume. - for data, err := range c.eventStore.After(req.Context(), c.SessionID(), stream.id, lastIndex) { - if err != nil { - // TODO: reevaluate these status codes. - // Maybe distinguish between storage errors, which are 500s, and missing - // session or stream ID--can these arise from bad input? - status := http.StatusInternalServerError - if errors.Is(err, ErrEventsPurged) { - status = http.StatusInsufficientStorage - } - errorf(status, "failed to read events: %v", err) - return - } - // The iterator yields events beginning just after lastIndex, or it would have - // yielded an error. - if !write(data) { - return - } - } - } - // Repeatedly collect pending outgoing events and send them. ctx := req.Context() - for msg, ok := range c.messages(ctx, stream, persistent) { - if !ok { + for msg, err := range c.messages(ctx, stream, persistent, lastIndex) { + if err != nil { if ctx.Err() != nil && writes == 0 { // This probably doesn't matter, but respond with NoContent if the client disconnected. w.WriteHeader(http.StatusNoContent) } else { - errorf(http.StatusGone, "stream terminated") + errorf(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } return } - if err := c.eventStore.Append(req.Context(), c.SessionID(), stream.id, msg); err != nil { - errorf(http.StatusInternalServerError, "storing event: %v", err.Error()) - return - } if !write(msg) { return } @@ -816,27 +797,33 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, // messages iterates over messages sent to the current stream. // +// persistent indicates if it is the main GET listener, which should never be +// terminated. +// lastIndex is the index of the last seen event, iteration begins at lastIndex+1. +// // The first iterated value is the received JSON message. The second iterated -// value is an OK value indicating whether the stream terminated normally. +// value is an error value indicating whether the stream terminated normally. +// Iteration ends at the first non-nil error. // // If the stream did not terminate normally, it is either because ctx was // cancelled, or the connection is closed: check the ctx.Err() to differentiate // these cases. -func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool) iter.Seq2[json.RawMessage, bool] { - return func(yield func(json.RawMessage, bool) bool) { +func (c *streamableServerConn) messages(ctx context.Context, stream *stream, persistent bool, lastIndex int) iter.Seq2[json.RawMessage, error] { + return func(yield func(json.RawMessage, error) bool) { for { c.mu.Lock() - outgoing := stream.outgoing - stream.outgoing = nil nOutstanding := len(stream.requests) c.mu.Unlock() - - for _, data := range outgoing { - if !yield(data, true) { + for data, err := range c.eventStore.After(ctx, c.SessionID(), stream.id, lastIndex) { + if err != nil { + yield(nil, err) return } + if !yield(data, nil) { + return + } + lastIndex++ } - // If all requests have been handled and replied to, we should terminate this connection. // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream." // ยง6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server @@ -850,13 +837,14 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per case <-*stream.signal.Load(): // there are new outgoing messages // return to top of loop case <-c.done: // session is closed - yield(nil, false) + yield(nil, errors.New("session is closed")) return case <-ctx.Done(): - yield(nil, false) + yield(nil, ctx.Err()) return } } + } } @@ -963,9 +951,9 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e stream = c.streams[""] } - // TODO: if there is nothing to send these messages to (as would happen, for example, if forConn == "" - // and the client never did a GET), then memory will grow without bound. Consider a mitigation. - stream.outgoing = append(stream.outgoing, data) + if err := c.eventStore.Append(ctx, c.SessionID(), stream.id, data); err != nil { + return fmt.Errorf("error storing event: %w", err) + } if isResponse { // Once we've put the reply on the queue, it's no longer outstanding. delete(stream.requests, forRequest)