Skip to content

Commit

Permalink
ttrpc: add a message length check before send
Browse files Browse the repository at this point in the history
Signed-off-by: Qian Zhang <cosmoer@qq.com>
  • Loading branch information
cosmoer committed May 26, 2022
1 parent 74421d1 commit a627d10
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 38 deletions.
11 changes: 3 additions & 8 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,6 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

if mh.Length > uint32(messageLengthMax) {
if _, err := ch.br.Discard(int(mh.Length)); err != nil {
return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err)
}

return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax)
}

Expand All @@ -143,10 +139,9 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
}

func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error {
// TODO: Error on send rather than on recv
//if len(p) > messageLengthMax {
// return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax)
//}
if len(p) > messageLengthMax {
return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax)
}
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil {
return err
}
Expand Down
38 changes: 8 additions & 30 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ import (
"reflect"
"testing"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"google.golang.org/grpc/codes"
)

func TestReadWriteMessage(t *testing.T) {
Expand Down Expand Up @@ -89,37 +90,14 @@ func TestReadWriteMessage(t *testing.T) {

func TestMessageOversize(t *testing.T) {
var (
w, r = net.Pipe()
wch, rch = newChannel(w), newChannel(r)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
w, _ = net.Pipe()
wch = newChannel(w)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
)

go func() {
if err := wch.send(1, 1, 0, msg); err != nil {
errs <- err
}
}()

_, _, err := rch.recv()
if err == nil {
t.Fatalf("error expected reading with small buffer")
}

status, ok := status.FromError(err)
if !ok {
t.Fatalf("expected grpc status error: %v", err)
}
err := wch.send(1, 1, 0, msg)

if status.Code() != codes.ResourceExhausted {
t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted)
}

select {
case err := <-errs:
if err != nil {
t.Fatal(err)
}
default:
if status.Convert(err).Code() != codes.InvalidArgument {
t.Fatalf("error expected while send a message of massive length")
}
}

0 comments on commit a627d10

Please sign in to comment.