diff --git a/channel.go b/channel.go index feafd9a6b..a7497c322 100644 --- a/channel.go +++ b/channel.go @@ -124,10 +124,6 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } if mh.Length > uint32(messageLengthMax) { - if _, err := ch.br.Discard(int(mh.Length)); err != nil { - return mh, nil, fmt.Errorf("failed to discard after receiving oversized message: %w", err) - } - return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) } @@ -143,10 +139,9 @@ func (ch *channel) recv() (messageHeader, []byte, error) { } func (ch *channel) send(streamID uint32, t messageType, flags uint8, p []byte) error { - // TODO: Error on send rather than on recv - //if len(p) > messageLengthMax { - // return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax) - //} + if len(p) > messageLengthMax { + return status.Errorf(codes.InvalidArgument, "refusing to send, message length %v exceed maximum message size of %v", len(p), messageLengthMax) + } if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t, Flags: flags}); err != nil { return err } diff --git a/channel_test.go b/channel_test.go index de8b66d38..3de0823a7 100644 --- a/channel_test.go +++ b/channel_test.go @@ -24,8 +24,9 @@ import ( "reflect" "testing" - "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "google.golang.org/grpc/codes" ) func TestReadWriteMessage(t *testing.T) { @@ -89,37 +90,14 @@ func TestReadWriteMessage(t *testing.T) { func TestMessageOversize(t *testing.T) { var ( - w, r = net.Pipe() - wch, rch = newChannel(w), newChannel(r) - msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) - errs = make(chan error, 1) + w, _ = net.Pipe() + wch = newChannel(w) + msg = bytes.Repeat([]byte("a message of massive length"), 512<<10) ) - go func() { - if err := wch.send(1, 1, 0, msg); err != nil { - errs <- err - } - }() - - _, _, err := rch.recv() - if err == nil { - t.Fatalf("error expected reading with small buffer") - } - - status, ok := status.FromError(err) - if !ok { - t.Fatalf("expected grpc status error: %v", err) - } + err := wch.send(1, 1, 0, msg) - if status.Code() != codes.ResourceExhausted { - t.Fatalf("expected grpc status code: %v != %v", status.Code(), codes.ResourceExhausted) - } - - select { - case err := <-errs: - if err != nil { - t.Fatal(err) - } - default: + if status.Convert(err).Code() != codes.InvalidArgument { + t.Fatalf("error expected while send a message of massive length") } }