From 3aa6048db1a1552ef93bfe47dc56bedd60b7dabd Mon Sep 17 00:00:00 2001 From: Rishabh Nrupnarayan Date: Tue, 29 Jul 2025 19:24:59 +0530 Subject: [PATCH 1/5] fix: Obtain server response even with malformed json input (#179) - graciously handled json lazy loading validation --- mcp/transport.go | 41 +++++++++++++++++++++++++++++++++++++++ mcp/transport_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/mcp/transport.go b/mcp/transport.go index dccc920b..86bc144d 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -420,6 +420,17 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { case <-t.closed: return nil, io.EOF } + // Read remaining data in the buffer. + tr := &trail{} + buf := in.Buffered() + err := tr.load(buf) + if err != nil { + return nil, err + } + // If trailing data exists, it is an error. + if err := tr.validate(); err != nil { + return nil, err + } msgs, batch, err := readBatch(raw) if err != nil { @@ -453,6 +464,36 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { return msgs[0], err } +// trail is a helper type to store and validate remaining data in decoder buffer. +type trail struct { + data []byte +} + +// load reads remaining data from the buffer. +func (t *trail) load(buf io.Reader) error { + data, err := io.ReadAll(buf) + if err != nil { + return err + } + t.data = data + log.Println("trail", string(t.data)) + return nil +} + +// validate checks if the trailing data exists. +// if it does, it returns an error. +func (t *trail) validate() error { + // Ignore newline to be deemed as trailing data. + // It is usual for stdio transport. + if t.data[len(t.data)-1] == '\n' { + t.data = t.data[:len(t.data)-1] + } + if len(t.data) > 0 { + return fmt.Errorf("invalid trailing data '%s' at the end of stream", string(t.data)) + } + return nil +} + // readBatch reads batch data, which may be either a single JSON-RPC message, // or an array of JSON-RPC messages. func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { diff --git a/mcp/transport_test.go b/mcp/transport_test.go index c63b84ee..58b9626b 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -7,6 +7,8 @@ package mcp import ( "context" "io" + "reflect" + "strings" "testing" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" @@ -51,3 +53,46 @@ func TestBatchFraming(t *testing.T) { } } } + +func Test_ioConn_Read_BadTrailingData(t *testing.T) { + type fields struct { + rwc io.ReadWriteCloser + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + { + name: "bad data at the end of first valid json", + fields: fields{ + rwc: rwc{ + rc: io.NopCloser(strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`)), + }, + }, + args: args{ + ctx: context.Background(), + }, + want: "invalid trailing data ',' at the end of stream", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := newIOConn(tt.fields.rwc) + _, err := tr.Read(tt.args.ctx) + if (err != nil) != tt.wantErr { + t.Errorf("ioConn.Read() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(err.Error(), tt.want) { + t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want) + } + }) + } +} From d22b2b0f0b7ec3bedcc567f0a4c6b570a3975a85 Mon Sep 17 00:00:00 2001 From: Rishabh Nrupnarayan Date: Tue, 29 Jul 2025 21:50:48 +0530 Subject: [PATCH 2/5] fix: simplified check by validating next buffered byte to newline --- mcp/transport.go | 45 ++++++++----------------------------------- mcp/transport_test.go | 2 +- 2 files changed, 9 insertions(+), 38 deletions(-) diff --git a/mcp/transport.go b/mcp/transport.go index 86bc144d..461e6515 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -420,16 +420,17 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { case <-t.closed: return nil, io.EOF } - // Read remaining data in the buffer. - tr := &trail{} - buf := in.Buffered() - err := tr.load(buf) + + // Read the next byte to check if there is trailing data. + tr := make([]byte, 1) + _, err := in.Buffered().Read(tr) if err != nil { return nil, err } - // If trailing data exists, it is an error. - if err := tr.validate(); err != nil { - return nil, err + + // If the next byte is not a newline, it is an error. + if tr[0] != '\n' { + return nil, fmt.Errorf("invalid trailing data at the end of stream") } msgs, batch, err := readBatch(raw) @@ -464,36 +465,6 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { return msgs[0], err } -// trail is a helper type to store and validate remaining data in decoder buffer. -type trail struct { - data []byte -} - -// load reads remaining data from the buffer. -func (t *trail) load(buf io.Reader) error { - data, err := io.ReadAll(buf) - if err != nil { - return err - } - t.data = data - log.Println("trail", string(t.data)) - return nil -} - -// validate checks if the trailing data exists. -// if it does, it returns an error. -func (t *trail) validate() error { - // Ignore newline to be deemed as trailing data. - // It is usual for stdio transport. - if t.data[len(t.data)-1] == '\n' { - t.data = t.data[:len(t.data)-1] - } - if len(t.data) > 0 { - return fmt.Errorf("invalid trailing data '%s' at the end of stream", string(t.data)) - } - return nil -} - // readBatch reads batch data, which may be either a single JSON-RPC message, // or an array of JSON-RPC messages. func readBatch(data []byte) (msgs []jsonrpc.Message, isBatch bool, _ error) { diff --git a/mcp/transport_test.go b/mcp/transport_test.go index 58b9626b..ee39231b 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -78,7 +78,7 @@ func Test_ioConn_Read_BadTrailingData(t *testing.T) { args: args{ ctx: context.Background(), }, - want: "invalid trailing data ',' at the end of stream", + want: "invalid trailing data at the end of stream", wantErr: true, }, } From 0657788859c8207ff8ce7eb8918e1fccdc34880e Mon Sep 17 00:00:00 2001 From: Rishabh Nrupnarayan Date: Thu, 31 Jul 2025 18:20:53 +0530 Subject: [PATCH 3/5] refactor: code review changes --- mcp/transport.go | 18 +++++++------- mcp/transport_test.go | 56 +++++++++++++++++++------------------------ 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/mcp/transport.go b/mcp/transport.go index 461e6515..92f2ba3b 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -422,15 +422,17 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { } // Read the next byte to check if there is trailing data. - tr := make([]byte, 1) - _, err := in.Buffered().Read(tr) - if err != nil { - return nil, err + var tr [1]byte + n, err := in.Buffered().Read(tr[:]) + if n > 0 { + // If read byte is not a newline, it is an error. + if tr[0] != '\n' { + return nil, fmt.Errorf("invalid trailing data at the end of stream") + } } - - // If the next byte is not a newline, it is an error. - if tr[0] != '\n' { - return nil, fmt.Errorf("invalid trailing data at the end of stream") + // Return error except for EOF + if err != nil && err != io.EOF { + return nil, err } msgs, batch, err := readBatch(raw) diff --git a/mcp/transport_test.go b/mcp/transport_test.go index ee39231b..18a326e8 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -7,7 +7,6 @@ package mcp import ( "context" "io" - "reflect" "strings" "testing" @@ -54,43 +53,38 @@ func TestBatchFraming(t *testing.T) { } } -func Test_ioConn_Read_BadTrailingData(t *testing.T) { - type fields struct { - rwc io.ReadWriteCloser - } - type args struct { - ctx context.Context - } +func TestIOConnRead(t *testing.T) { tests := []struct { - name string - fields fields - args args - want string - wantErr bool + name string + input string + want string }{ + { - name: "bad data at the end of first valid json", - fields: fields{ - rwc: rwc{ - rc: io.NopCloser(strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`)), - }, - }, - args: args{ - ctx: context.Background(), - }, - want: "invalid trailing data at the end of stream", - wantErr: true, + name: "valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`, + want: "", + }, + + { + name: "newline at the end of first valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}} + `, + want: "", + }, + { + name: "bad data at the end of first valid json input", + input: `{"jsonrpc":"2.0","id":1,"method":"test","params":{}},`, + want: "invalid trailing data at the end of stream", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tr := newIOConn(tt.fields.rwc) - _, err := tr.Read(tt.args.ctx) - if (err != nil) != tt.wantErr { - t.Errorf("ioConn.Read() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(err.Error(), tt.want) { + tr := newIOConn(rwc{ + rc: io.NopCloser(strings.NewReader(tt.input)), + }) + _, err := tr.Read(context.Background()) + if err != nil && err.Error() != tt.want { t.Errorf("ioConn.Read() = %v, want %v", err.Error(), tt.want) } }) From e6cf5be96d61b6b082edfdc96878098805571918 Mon Sep 17 00:00:00 2001 From: Rishabh Nrupnarayan Date: Thu, 31 Jul 2025 18:28:07 +0530 Subject: [PATCH 4/5] refactor: inline condition check --- mcp/transport.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mcp/transport.go b/mcp/transport.go index 92f2ba3b..06073ce4 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -423,15 +423,12 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { // Read the next byte to check if there is trailing data. var tr [1]byte - n, err := in.Buffered().Read(tr[:]) - if n > 0 { + if n, err := in.Buffered().Read(tr[:]); n > 0 { // If read byte is not a newline, it is an error. if tr[0] != '\n' { return nil, fmt.Errorf("invalid trailing data at the end of stream") } - } - // Return error except for EOF - if err != nil && err != io.EOF { + } else if err != nil && err != io.EOF { return nil, err } From 3c1003e35cad35c9784b5e30ed45c6e6ee9c601d Mon Sep 17 00:00:00 2001 From: Rishabh Nrupnarayan Date: Wed, 6 Aug 2025 17:56:46 +0530 Subject: [PATCH 5/5] refactor: reloacted the fix after rebasing --- mcp/transport.go | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/mcp/transport.go b/mcp/transport.go index 06073ce4..f2d5c72d 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -308,6 +308,19 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { for { var raw json.RawMessage err := dec.Decode(&raw) + // If decoding was successful, check for trailing data at the end of the stream. + if err == nil { + // Read the next byte to check if there is trailing data. + var tr [1]byte + if n, readErr := dec.Buffered().Read(tr[:]); n > 0 { + // If read byte is not a newline, it is an error. + if tr[0] != '\n' { + err = fmt.Errorf("invalid trailing data at the end of stream") + } + } else if readErr != nil && readErr != io.EOF { + err = readErr + } + } select { case incoming <- msgOrErr{msg: raw, err: err}: case <-closed: @@ -421,17 +434,6 @@ func (t *ioConn) Read(ctx context.Context) (jsonrpc.Message, error) { return nil, io.EOF } - // Read the next byte to check if there is trailing data. - var tr [1]byte - if n, err := in.Buffered().Read(tr[:]); n > 0 { - // If read byte is not a newline, it is an error. - if tr[0] != '\n' { - return nil, fmt.Errorf("invalid trailing data at the end of stream") - } - } else if err != nil && err != io.EOF { - return nil, err - } - msgs, batch, err := readBatch(raw) if err != nil { return nil, err