diff --git a/internal/metadata/metadata.go b/internal/metadata/metadata.go index b2980f8ac44..c82e608e077 100644 --- a/internal/metadata/metadata.go +++ b/internal/metadata/metadata.go @@ -76,33 +76,11 @@ func Set(addr resolver.Address, md metadata.MD) resolver.Address { return addr } -// Validate returns an error if the input md contains invalid keys or values. -// -// If the header is not a pseudo-header, the following items are checked: -// - header names must contain one or more characters from this set [0-9 a-z _ - .]. -// - if the header-name ends with a "-bin" suffix, no validation of the header value is performed. -// - otherwise, the header value must contain one or more characters from the set [%x20-%x7E]. +// Validate validates every pair in md with ValidatePair. func Validate(md metadata.MD) error { for k, vals := range md { - // pseudo-header will be ignored - if k[0] == ':' { - continue - } - // check key, for i that saving a conversion if not using for range - for i := 0; i < len(k); i++ { - r := k[i] - if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && r != '.' && r != '-' && r != '_' { - return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", k) - } - } - if strings.HasSuffix(k, "-bin") { - continue - } - // check value - for _, val := range vals { - if hasNotPrintable(val) { - return fmt.Errorf("header key %q contains value with non-printable ASCII characters", k) - } + if err := ValidatePair(k, vals...); err != nil { + return err } } return nil @@ -118,3 +96,37 @@ func hasNotPrintable(msg string) bool { } return false } + +// ValidatePair validate a key-value pair with the following rules (the pseudo-header will be skipped) : +// +// - key must contain one or more characters. +// - the characters in the key must be contained in [0-9 a-z _ - .]. +// - if the key ends with a "-bin" suffix, no validation of the corresponding value is performed. +// - the characters in the every value must be printable (in [%x20-%x7E]). +func ValidatePair(key string, vals ...string) error { + // key should not be empty + if key == "" { + return fmt.Errorf("there is an empty key in the header") + } + // pseudo-header will be ignored + if key[0] == ':' { + return nil + } + // check key, for i that saving a conversion if not using for range + for i := 0; i < len(key); i++ { + r := key[i] + if !(r >= 'a' && r <= 'z') && !(r >= '0' && r <= '9') && r != '.' && r != '-' && r != '_' { + return fmt.Errorf("header key %q contains illegal characters not in [0-9a-z-_.]", key) + } + } + if strings.HasSuffix(key, "-bin") { + return nil + } + // check value + for _, val := range vals { + if hasNotPrintable(val) { + return fmt.Errorf("header key %q contains value with non-printable ASCII characters", key) + } + } + return nil +} diff --git a/internal/metadata/metadata_test.go b/internal/metadata/metadata_test.go index 80f1a44bb6a..8f0e430e5ed 100644 --- a/internal/metadata/metadata_test.go +++ b/internal/metadata/metadata_test.go @@ -100,6 +100,10 @@ func TestValidate(t *testing.T) { md: map[string][]string{"test": {string(rune(0x19))}}, want: errors.New("header key \"test\" contains value with non-printable ASCII characters"), }, + { + md: map[string][]string{"": {"valid"}}, + want: errors.New("there is an empty key in the header"), + }, { md: map[string][]string{"test-bin": {string(rune(0x19))}}, want: nil, diff --git a/stream.go b/stream.go index 89936a4f166..34b0cb4593e 100644 --- a/stream.go +++ b/stream.go @@ -168,10 +168,19 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { - if md, _, ok := metadata.FromOutgoingContextRaw(ctx); ok { + if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok { + // validate md if err := imetadata.Validate(md); err != nil { return nil, status.Error(codes.Internal, err.Error()) } + // validate added + for _, kvs := range added { + for i := 0; i < len(kvs); i += 2 { + if err := imetadata.ValidatePair(kvs[i], kvs[i+1]); err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + } + } } if channelz.IsOn() { cc.incrCallsStarted() diff --git a/test/metadata_test.go b/test/metadata_test.go index ad2b12cfc77..a15e5cb1c6e 100644 --- a/test/metadata_test.go +++ b/test/metadata_test.go @@ -36,29 +36,55 @@ import ( ) func (s) TestInvalidMetadata(t *testing.T) { - grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 2) + grpctest.TLogger.ExpectErrorN("stream: failed to validate md when setting trailer", 5) tests := []struct { - md metadata.MD - want error - recv error + name string + md metadata.MD + appendMD []string + want error + recv error }{ { + name: "invalid key", md: map[string][]string{string(rune(0x19)): {"testVal"}}, want: status.Error(codes.Internal, "header key \"\\x19\" contains illegal characters not in [0-9a-z-_.]"), recv: status.Error(codes.Internal, "invalid header field"), }, { + name: "invalid value", md: map[string][]string{"test": {string(rune(0x19))}}, want: status.Error(codes.Internal, "header key \"test\" contains value with non-printable ASCII characters"), recv: status.Error(codes.Internal, "invalid header field"), }, { + name: "invalid appended value", + md: map[string][]string{"test": {"test"}}, + appendMD: []string{"/", "value"}, + want: status.Error(codes.Internal, "header key \"/\" contains illegal characters not in [0-9a-z-_.]"), + recv: status.Error(codes.Internal, "invalid header field"), + }, + { + name: "empty appended key", + md: map[string][]string{"test": {"test"}}, + appendMD: []string{"", "value"}, + want: status.Error(codes.Internal, "there is an empty key in the header"), + recv: status.Error(codes.Internal, "invalid header field"), + }, + { + name: "empty key", + md: map[string][]string{"": {"test"}}, + want: status.Error(codes.Internal, "there is an empty key in the header"), + recv: status.Error(codes.Internal, "invalid header field"), + }, + { + name: "-bin key with arbitrary value", md: map[string][]string{"test-bin": {string(rune(0x19))}}, want: nil, recv: io.EOF, }, { + name: "valid key and value", md: map[string][]string{"test": {"value"}}, want: nil, recv: io.EOF, @@ -77,13 +103,16 @@ func (s) TestInvalidMetadata(t *testing.T) { } test := tests[testNum] testNum++ - if err := stream.SetHeader(test.md); !reflect.DeepEqual(test.want, err) { - return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want) + // merge original md and added md. + md := metadata.Join(test.md, metadata.Pairs(test.appendMD...)) + + if err := stream.SetHeader(md); !reflect.DeepEqual(test.want, err) { + return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", md, err, test.want) } - if err := stream.SendHeader(test.md); !reflect.DeepEqual(test.want, err) { - return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want) + if err := stream.SendHeader(md); !reflect.DeepEqual(test.want, err) { + return fmt.Errorf("call stream.SendHeader(md) validate metadata which is %v got err :%v, want err :%v", md, err, test.want) } - stream.SetTrailer(test.md) + stream.SetTrailer(md) return nil }, } @@ -93,29 +122,33 @@ func (s) TestInvalidMetadata(t *testing.T) { defer ss.Stop() for _, test := range tests { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - ctx = metadata.NewOutgoingContext(ctx, test.md) - if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) { - t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want) - } + t.Run("unary "+test.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + ctx = metadata.NewOutgoingContext(ctx, test.md) + ctx = metadata.AppendToOutgoingContext(ctx, test.appendMD...) + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(test.want, err) { + t.Errorf("call ss.Client.EmptyCall() validate metadata which is %v got err :%v, want err :%v", test.md, err, test.want) + } + }) } // call the stream server's api to drive the server-side unit testing for _, test := range tests { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - stream, err := ss.Client.FullDuplexCall(ctx) - defer cancel() - if err != nil { - t.Errorf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err) - continue - } - if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { - t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err) - } - if _, err := stream.Recv(); status.Code(err) != status.Code(test.recv) || !strings.Contains(err.Error(), test.recv.Error()) { - t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv) - } + t.Run("streaming "+test.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Errorf("call ss.Client.FullDuplexCall(context.Background()) will success but got err :%v", err) + return + } + if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Errorf("call ss.Client stream Send(nil) will success but got err :%v", err) + } + if _, err := stream.Recv(); status.Code(err) != status.Code(test.recv) || !strings.Contains(err.Error(), test.recv.Error()) { + t.Errorf("stream.Recv() = _, get err :%v, want err :%v", err, test.recv) + } + }) } }