diff --git a/mcp/transport.go b/mcp/transport.go index dccc920b..f2d5c72d 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -308,6 +308,19 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { for { var raw json.RawMessage err := dec.Decode(&raw) + // If decoding was successful, check for trailing data at the end of the stream. + if err == nil { + // Read the next byte to check if there is trailing data. + var tr [1]byte + if n, readErr := dec.Buffered().Read(tr[:]); n > 0 { + // If read byte is not a newline, it is an error. + if tr[0] != '\n' { + err = fmt.Errorf("invalid trailing data at the end of stream") + } + } else if readErr != nil && readErr != io.EOF { + err = readErr + } + } select { case incoming <- msgOrErr{msg: raw, err: err}: case <-closed: diff --git a/mcp/transport_test.go b/mcp/transport_test.go index c63b84ee..18a326e8 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -7,6 +7,7 @@ package mcp import ( "context" "io" + "strings" "testing" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -51,3 +52,41 @@ func TestBatchFraming(t *testing.T) { } } } + +func TestIOConnRead(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + + { + name: "valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`, + want: "", + }, + + { + name: "newline at the end of first valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}} + `, + want: "", + }, + { + name: "bad data at the end of first valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`, + want: "invalid trailing data at the end of stream", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := newIOConn(rwc{ + rc: io.NopCloser(strings.NewReader(tt.input)), + }) + _, err := tr.Read(context.Background()) + if err != nil && err.Error() != tt.want { + t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want) + } + }) + } +}