diff --git a/mcp/event.go b/mcp/event.go index d309c4e0..bd78cdee 100644 --- a/mcp/event.go +++ b/mcp/event.go @@ -156,11 +156,13 @@ type EventStore interface { // Open prepares the event store for a given stream. It ensures that the // underlying data structure for the stream is initialized, making it // ready to store event streams. - Open(_ context.Context, sessionID string, streamID StreamID) error + // + // streamIDs must be globally unique. + Open(_ context.Context, sessionID, streamID string) error // Append appends data for an outgoing event to given stream, which is part of the // given session. - Append(_ context.Context, sessionID string, _ StreamID, data []byte) error + Append(_ context.Context, sessionID, streamID string, data []byte) error // After returns an iterator over the data for the given session and stream, beginning // just after the given index. @@ -168,7 +170,7 @@ type EventStore interface { // After's iterator must return an error immediately if any data after index was // dropped; it must not return partial results. // The stream must have been opened previously (see [EventStore.Open]). - After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error] + After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] // SessionClosed informs the store that the given session is finished, along // with all of its streams. @@ -217,9 +219,9 @@ func (dl *dataList) removeFirst() int { // A MemoryEventStore is an [EventStore] backed by memory. type MemoryEventStore struct { mu sync.Mutex - maxBytes int // max total size of all data - nBytes int // current total size of all data - store map[string]map[StreamID]*dataList // session ID -> stream ID -> *dataList + maxBytes int // max total size of all data + nBytes int // current total size of all data + store map[string]map[string]*dataList // session ID -> stream ID -> *dataList } // MemoryEventStoreOptions are options for a [MemoryEventStore]. @@ -258,13 +260,13 @@ const defaultMaxBytes = 10 << 20 // 10 MiB func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore { return &MemoryEventStore{ maxBytes: defaultMaxBytes, - store: make(map[string]map[StreamID]*dataList), + store: make(map[string]map[string]*dataList), } } // Open implements [EventStore.Open]. It ensures that the underlying data // structures for the given session are initialized and ready for use. -func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID StreamID) error { +func (s *MemoryEventStore) Open(_ context.Context, sessionID, streamID string) error { s.mu.Lock() defer s.mu.Unlock() s.init(sessionID, streamID) @@ -275,10 +277,10 @@ func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID St // given sessionID and streamID exists, creating it if necessary. It returns the // dataList associated with the specified IDs. // Requires s.mu. -func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList { +func (s *MemoryEventStore) init(sessionID, streamID string) *dataList { streamMap, ok := s.store[sessionID] if !ok { - streamMap = make(map[StreamID]*dataList) + streamMap = make(map[string]*dataList) s.store[sessionID] = streamMap } dl, ok := streamMap[streamID] @@ -290,7 +292,7 @@ func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList { } // Append implements [EventStore.Append] by recording data in memory. -func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error { +func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string, data []byte) error { s.mu.Lock() defer s.mu.Unlock() dl := s.init(sessionID, streamID) @@ -307,7 +309,7 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID var ErrEventsPurged = errors.New("data purged") // After implements [EventStore.After]. -func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[[]byte, error] { +func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] { // Return the data items to yield. // We must copy, because dataList.removeFirst nils out slice elements. copyData := func() ([][]byte, error) { diff --git a/mcp/event_test.go b/mcp/event_test.go index ef4e080b..dacf30e8 100644 --- a/mcp/event_test.go +++ b/mcp/event_test.go @@ -105,8 +105,8 @@ func TestScanEvents(t *testing.T) { func TestMemoryEventStoreState(t *testing.T) { ctx := context.Background() - appendEvent := func(s *MemoryEventStore, sess string, str StreamID, data string) { - if err := s.Append(ctx, sess, str, []byte(data)); err != nil { + appendEvent := func(s *MemoryEventStore, sess, stream string, data string) { + if err := s.Append(ctx, sess, stream, []byte(data)); err != nil { t.Fatal(err) } } @@ -218,7 +218,7 @@ func TestMemoryEventStoreAfter(t *testing.T) { for _, tt := range []struct { sessionID string - streamID StreamID + streamID string index int want []string wantErr string // if non-empty, error should contain this string @@ -277,11 +277,11 @@ func BenchmarkMemoryEventStore(b *testing.B) { store.SetMaxBytes(test.limit) ctx := context.Background() sessionIDs := make([]string, test.sessions) - streamIDs := make([][3]StreamID, test.sessions) + streamIDs := make([][3]string, test.sessions) for i := range sessionIDs { sessionIDs[i] = fmt.Sprint(i) for j := range 3 { - streamIDs[i][j] = StreamID(randText()) + streamIDs[i][j] = randText() } } payload := make([]byte, test.datasize) diff --git a/mcp/streamable.go b/mcp/streamable.go index 8ac6f59a..67f187ea 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -396,8 +396,8 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er jsonResponse: t.jsonResponse, incoming: make(chan jsonrpc.Message, 10), done: make(chan struct{}), - streams: make(map[StreamID]*stream), - requestStreams: make(map[jsonrpc.ID]StreamID), + streams: make(map[string]*stream), + requestStreams: make(map[jsonrpc.ID]string), } if t.connection.eventStore == nil { t.connection.eventStore = NewMemoryEventStore(nil) @@ -442,14 +442,14 @@ type streamableServerConn struct { // bound. If we deleted a stream when the response is sent, we would lose the ability // to replay if there was a cut just before the response was transmitted. // Perhaps we could have a TTL for streams that starts just after the response. - streams map[StreamID]*stream + streams map[string]*stream // requestStreams maps incoming requests to their logical stream ID. // // Lifecycle: requestStreams persist for the duration of the session. // // TODO: clean up once requests are handled. See the TODO for streams above. - requestStreams map[jsonrpc.ID]StreamID + requestStreams map[jsonrpc.ID]string } func (c *streamableServerConn) SessionID() string { @@ -466,7 +466,7 @@ func (c *streamableServerConn) SessionID() string { type stream struct { // id is the logical ID for the stream, unique within a session. // an empty string is used for messages that don't correlate with an incoming request. - id StreamID + id string // If isInitialize is set, the stream is in response to an initialize request, // and therefore should include the session ID header. @@ -500,7 +500,7 @@ type stream struct { requests map[jsonrpc.ID]struct{} } -func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, isInitialize, jsonResponse bool) (*stream, error) { +func (c *streamableServerConn) newStream(ctx context.Context, id string, isInitialize, jsonResponse bool) (*stream, error) { if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil { return nil, err } @@ -517,10 +517,6 @@ func signalChanPtr() *chan struct{} { return &c } -// A StreamID identifies a stream of SSE events. It is globally unique. -// [ServerSession]. -type StreamID string - // We track the incoming request ID inside the handler context using // idContextValue, so that notifications and server->client calls that occur in // the course of handling incoming requests are correlated with the incoming @@ -569,7 +565,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R // It returns an HTTP status code and error message. func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) { // connID 0 corresponds to the default GET request. - id := StreamID("") + id := "" // By default, we haven't seen a last index. Since indices start at 0, we represent // that by -1. This is incremented just before each event is written, in streamResponse // around L407. @@ -669,7 +665,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // notifications or server->client requests made in the course of handling. // Update accounting for this incoming payload. if len(requests) > 0 { - stream, err = c.newStream(req.Context(), StreamID(randText()), isInitialize, c.jsonResponse) + stream, err = c.newStream(req.Context(), randText(), isInitialize, c.jsonResponse) if err != nil { http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) return @@ -860,7 +856,7 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per // streamID and message index idx. // // See also [parseEventID]. -func formatEventID(sid StreamID, idx int) string { +func formatEventID(sid string, idx int) string { return fmt.Sprintf("%s_%d", sid, idx) } @@ -868,17 +864,17 @@ func formatEventID(sid StreamID, idx int) string { // index. // // See also [formatEventID]. -func parseEventID(eventID string) (sid StreamID, idx int, ok bool) { +func parseEventID(eventID string) (streamID string, idx int, ok bool) { parts := strings.Split(eventID, "_") if len(parts) != 2 { return "", 0, false } - stream := StreamID(parts[0]) + streamID = parts[0] idx, err := strconv.Atoi(parts[1]) if err != nil || idx < 0 { return "", 0, false } - return StreamID(stream), idx, true + return streamID, idx, true } // Read implements the [Connection] interface. @@ -922,7 +918,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e // // For messages sent outside of a request context, this is the default // connection "". - var forStream StreamID + var forStream string if forRequest.IsValid() { c.mu.Lock() forStream = c.requestStreams[forRequest] diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 79e9645f..e24478be 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1139,7 +1139,7 @@ func mustMarshal(v any) json.RawMessage { func TestEventID(t *testing.T) { tests := []struct { - sid StreamID + sid string idx int }{ {"0", 0},