Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions mcp/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,21 @@ type EventStore interface {
// Open prepares the event store for a given stream. It ensures that the
// underlying data structure for the stream is initialized, making it
// ready to store event streams.
Open(_ context.Context, sessionID string, streamID StreamID) error
//
// streamIDs must be globally unique.
Open(_ context.Context, sessionID, streamID string) error

// Append appends data for an outgoing event to given stream, which is part of the
// given session.
Append(_ context.Context, sessionID string, _ StreamID, data []byte) error
Append(_ context.Context, sessionID, streamID string, data []byte) error

// After returns an iterator over the data for the given session and stream, beginning
// just after the given index.
// Once the iterator yields a non-nil error, it will stop.
// After's iterator must return an error immediately if any data after index was
// dropped; it must not return partial results.
// The stream must have been opened previously (see [EventStore.Open]).
After(_ context.Context, sessionID string, _ StreamID, index int) iter.Seq2[[]byte, error]
After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error]

// SessionClosed informs the store that the given session is finished, along
// with all of its streams.
Expand Down Expand Up @@ -217,9 +219,9 @@ func (dl *dataList) removeFirst() int {
// A MemoryEventStore is an [EventStore] backed by memory.
type MemoryEventStore struct {
mu sync.Mutex
maxBytes int // max total size of all data
nBytes int // current total size of all data
store map[string]map[StreamID]*dataList // session ID -> stream ID -> *dataList
maxBytes int // max total size of all data
nBytes int // current total size of all data
store map[string]map[string]*dataList // session ID -> stream ID -> *dataList
}

// MemoryEventStoreOptions are options for a [MemoryEventStore].
Expand Down Expand Up @@ -258,13 +260,13 @@ const defaultMaxBytes = 10 << 20 // 10 MiB
func NewMemoryEventStore(opts *MemoryEventStoreOptions) *MemoryEventStore {
return &MemoryEventStore{
maxBytes: defaultMaxBytes,
store: make(map[string]map[StreamID]*dataList),
store: make(map[string]map[string]*dataList),
}
}

// Open implements [EventStore.Open]. It ensures that the underlying data
// structures for the given session are initialized and ready for use.
func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID StreamID) error {
func (s *MemoryEventStore) Open(_ context.Context, sessionID, streamID string) error {
s.mu.Lock()
defer s.mu.Unlock()
s.init(sessionID, streamID)
Expand All @@ -275,10 +277,10 @@ func (s *MemoryEventStore) Open(_ context.Context, sessionID string, streamID St
// given sessionID and streamID exists, creating it if necessary. It returns the
// dataList associated with the specified IDs.
// Requires s.mu.
func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList {
func (s *MemoryEventStore) init(sessionID, streamID string) *dataList {
streamMap, ok := s.store[sessionID]
if !ok {
streamMap = make(map[StreamID]*dataList)
streamMap = make(map[string]*dataList)
s.store[sessionID] = streamMap
}
dl, ok := streamMap[streamID]
Expand All @@ -290,7 +292,7 @@ func (s *MemoryEventStore) init(sessionID string, streamID StreamID) *dataList {
}

// Append implements [EventStore.Append] by recording data in memory.
func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID StreamID, data []byte) error {
func (s *MemoryEventStore) Append(_ context.Context, sessionID, streamID string, data []byte) error {
s.mu.Lock()
defer s.mu.Unlock()
dl := s.init(sessionID, streamID)
Expand All @@ -307,7 +309,7 @@ func (s *MemoryEventStore) Append(_ context.Context, sessionID string, streamID
var ErrEventsPurged = errors.New("data purged")

// After implements [EventStore.After].
func (s *MemoryEventStore) After(_ context.Context, sessionID string, streamID StreamID, index int) iter.Seq2[[]byte, error] {
func (s *MemoryEventStore) After(_ context.Context, sessionID, streamID string, index int) iter.Seq2[[]byte, error] {
// Return the data items to yield.
// We must copy, because dataList.removeFirst nils out slice elements.
copyData := func() ([][]byte, error) {
Expand Down
10 changes: 5 additions & 5 deletions mcp/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ func TestScanEvents(t *testing.T) {
func TestMemoryEventStoreState(t *testing.T) {
ctx := context.Background()

appendEvent := func(s *MemoryEventStore, sess string, str StreamID, data string) {
if err := s.Append(ctx, sess, str, []byte(data)); err != nil {
appendEvent := func(s *MemoryEventStore, sess, stream string, data string) {
if err := s.Append(ctx, sess, stream, []byte(data)); err != nil {
t.Fatal(err)
}
}
Expand Down Expand Up @@ -218,7 +218,7 @@ func TestMemoryEventStoreAfter(t *testing.T) {

for _, tt := range []struct {
sessionID string
streamID StreamID
streamID string
index int
want []string
wantErr string // if non-empty, error should contain this string
Expand Down Expand Up @@ -277,11 +277,11 @@ func BenchmarkMemoryEventStore(b *testing.B) {
store.SetMaxBytes(test.limit)
ctx := context.Background()
sessionIDs := make([]string, test.sessions)
streamIDs := make([][3]StreamID, test.sessions)
streamIDs := make([][3]string, test.sessions)
for i := range sessionIDs {
sessionIDs[i] = fmt.Sprint(i)
for j := range 3 {
streamIDs[i][j] = StreamID(randText())
streamIDs[i][j] = randText()
}
}
payload := make([]byte, test.datasize)
Expand Down
30 changes: 13 additions & 17 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er
jsonResponse: t.jsonResponse,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
streams: make(map[StreamID]*stream),
requestStreams: make(map[jsonrpc.ID]StreamID),
streams: make(map[string]*stream),
requestStreams: make(map[jsonrpc.ID]string),
}
if t.connection.eventStore == nil {
t.connection.eventStore = NewMemoryEventStore(nil)
Expand Down Expand Up @@ -442,14 +442,14 @@ type streamableServerConn struct {
// bound. If we deleted a stream when the response is sent, we would lose the ability
// to replay if there was a cut just before the response was transmitted.
// Perhaps we could have a TTL for streams that starts just after the response.
streams map[StreamID]*stream
streams map[string]*stream

// requestStreams maps incoming requests to their logical stream ID.
//
// Lifecycle: requestStreams persist for the duration of the session.
//
// TODO: clean up once requests are handled. See the TODO for streams above.
requestStreams map[jsonrpc.ID]StreamID
requestStreams map[jsonrpc.ID]string
}

func (c *streamableServerConn) SessionID() string {
Expand All @@ -466,7 +466,7 @@ func (c *streamableServerConn) SessionID() string {
type stream struct {
// id is the logical ID for the stream, unique within a session.
// an empty string is used for messages that don't correlate with an incoming request.
id StreamID
id string

// If isInitialize is set, the stream is in response to an initialize request,
// and therefore should include the session ID header.
Expand Down Expand Up @@ -500,7 +500,7 @@ type stream struct {
requests map[jsonrpc.ID]struct{}
}

func (c *streamableServerConn) newStream(ctx context.Context, id StreamID, isInitialize, jsonResponse bool) (*stream, error) {
func (c *streamableServerConn) newStream(ctx context.Context, id string, isInitialize, jsonResponse bool) (*stream, error) {
if err := c.eventStore.Open(ctx, c.sessionID, id); err != nil {
return nil, err
}
Expand All @@ -517,10 +517,6 @@ func signalChanPtr() *chan struct{} {
return &c
}

// A StreamID identifies a stream of SSE events. It is globally unique.
// [ServerSession].
type StreamID string

// We track the incoming request ID inside the handler context using
// idContextValue, so that notifications and server->client calls that occur in
// the course of handling incoming requests are correlated with the incoming
Expand Down Expand Up @@ -569,7 +565,7 @@ func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.R
// It returns an HTTP status code and error message.
func (c *streamableServerConn) serveGET(w http.ResponseWriter, req *http.Request) {
// connID 0 corresponds to the default GET request.
id := StreamID("")
id := ""
// By default, we haven't seen a last index. Since indices start at 0, we represent
// that by -1. This is incremented just before each event is written, in streamResponse
// around L407.
Expand Down Expand Up @@ -669,7 +665,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
// notifications or server->client requests made in the course of handling.
// Update accounting for this incoming payload.
if len(requests) > 0 {
stream, err = c.newStream(req.Context(), StreamID(randText()), isInitialize, c.jsonResponse)
stream, err = c.newStream(req.Context(), randText(), isInitialize, c.jsonResponse)
if err != nil {
http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -860,25 +856,25 @@ func (c *streamableServerConn) messages(ctx context.Context, stream *stream, per
// streamID and message index idx.
//
// See also [parseEventID].
func formatEventID(sid StreamID, idx int) string {
func formatEventID(sid string, idx int) string {
return fmt.Sprintf("%s_%d", sid, idx)
}

// parseEventID parses a Last-Event-ID value into a logical stream id and
// index.
//
// See also [formatEventID].
func parseEventID(eventID string) (sid StreamID, idx int, ok bool) {
func parseEventID(eventID string) (streamID string, idx int, ok bool) {
parts := strings.Split(eventID, "_")
if len(parts) != 2 {
return "", 0, false
}
stream := StreamID(parts[0])
streamID = parts[0]
idx, err := strconv.Atoi(parts[1])
if err != nil || idx < 0 {
return "", 0, false
}
return StreamID(stream), idx, true
return streamID, idx, true
}

// Read implements the [Connection] interface.
Expand Down Expand Up @@ -922,7 +918,7 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e
//
// For messages sent outside of a request context, this is the default
// connection "".
var forStream StreamID
var forStream string
if forRequest.IsValid() {
c.mu.Lock()
forStream = c.requestStreams[forRequest]
Expand Down
2 changes: 1 addition & 1 deletion mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,7 @@ func mustMarshal(v any) json.RawMessage {

func TestEventID(t *testing.T) {
tests := []struct {
sid StreamID
sid string
idx int
}{
{"0", 0},
Expand Down