From aba3ff50ee01d47f772cd63eaf6d3d8f687a2abb Mon Sep 17 00:00:00 2001 From: Edward McFarlane <3036610+emcfarlane@users.noreply.github.com> Date: Fri, 31 May 2024 16:51:26 -0400 Subject: [PATCH] Restrict handler metadata headers (#748) This PR fixes the setting of protocol headers to avoid multiple value headers when providing metadata to a handler. The metadata headers are further restricted to avoid setting protocol headers like "Content-Type". This restriction allows the user to pass the response of a proxy call to a handler without having to filter the response headers themselves. This enforces the protocol headers are set by the handler and that they are unaffected from any user provided metadata. Previously, returning the response of a client request to a handler would merge the headers together leading to protocol errors from invalid headers such as "Content-Type" having multiple values. --------- Signed-off-by: Edward McFarlane --- .golangci.yml | 3 +++ connect_ext_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++ error_writer.go | 2 +- handler.go | 4 +-- header.go | 59 +++++++++++++++++++++-------------------- protocol_connect.go | 4 +-- protocol_grpc.go | 2 +- 7 files changed, 104 insertions(+), 34 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index fc627292..d94bf062 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -132,3 +132,6 @@ issues: # We want to show examples with http.Get - linters: [noctx] path: internal/memhttp/memhttp_test.go + # We need to initialize a map of all protocol headers + - linters: [gochecknoglobals] + path: header.go diff --git a/connect_ext_test.go b/connect_ext_test.go index 42294086..81ea030a 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2631,6 +2631,70 @@ func TestBlankImportCodeGeneration(t *testing.T) { assert.NotNil(t, desc) } +// TestSetProtocolHeaders tests that headers required by the protocols are set +// overriding user provided headers. +func TestSetProtocolHeaders(t *testing.T) { + t.Parallel() + tests := []struct { + name string + clientOption connect.ClientOption + expectContentType string + }{{ + name: "connect", + expectContentType: "application/proto", + }, { + name: "grpc", + clientOption: connect.WithGRPC(), + expectContentType: "application/grpc", + }, { + name: "grpcweb", + clientOption: connect.WithGRPCWeb(), + expectContentType: "application/grpc-web+proto", + }} + for _, tt := range tests { + testcase := tt + t.Run(testcase.name, func(t *testing.T) { + t.Parallel() + pingServer := &pingServer{} + mux := http.NewServeMux() + mux.Handle(pingv1connect.NewPingServiceHandler(pingServer)) + server := memhttptest.NewServer(t, mux) + + clientOpts := []connect.ClientOption{} + if testcase.clientOption == nil { + // Use a different protocol to test the override. + clientOpts = append(clientOpts, connect.WithGRPC()) + } + client := pingv1connect.NewPingServiceClient(server.Client(), server.URL(), clientOpts...) + + pingProxyServer := &pluggablePingServer{ + ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + return client.Ping(ctx, request) + }, + } + proxyMux := http.NewServeMux() + proxyMux.Handle(pingv1connect.NewPingServiceHandler(pingProxyServer)) + proxyServer := memhttptest.NewServer(t, proxyMux) + + proxyClientOpts := []connect.ClientOption{} + if testcase.clientOption != nil { + proxyClientOpts = append(proxyClientOpts, testcase.clientOption) + } + proxyClient := pingv1connect.NewPingServiceClient(proxyServer.Client(), proxyServer.URL(), proxyClientOpts...) + + request := connect.NewRequest(&pingv1.PingRequest{Number: 42}) + request.Header().Set("X-Test", t.Name()) + response, err := proxyClient.Ping(context.Background(), request) + if !assert.Nil(t, err) { + return + } + // Assert the Content-Type is set for the proxy clients protocol and not the client's. + assert.Equal(t, response.Header().Get("Content-Type"), testcase.expectContentType) + assert.Equal(t, len(response.Header().Values("Content-Type")), 1) + }) + } +} + type unflushableWriter struct { w http.ResponseWriter } diff --git a/error_writer.go b/error_writer.go index 58ce3c42..f05d19ec 100644 --- a/error_writer.go +++ b/error_writer.go @@ -128,7 +128,7 @@ func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request, func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) error { if connectErr, ok := asError(err); ok && !connectErr.wireErr { - mergeMetadataHeaders(response.Header(), connectErr.meta) + mergeNonProtocolHeaders(response.Header(), connectErr.meta) } response.WriteHeader(connectCodeToHTTP(CodeOf(err))) data, marshalErr := json.Marshal(newConnectWireError(err)) diff --git a/handler.go b/handler.go index 1d573291..5eab6c71 100644 --- a/handler.go +++ b/handler.go @@ -71,8 +71,8 @@ func NewUnaryHandler[Req, Res any]( if err != nil { return err } - mergeHeaders(conn.ResponseHeader(), response.Header()) - mergeHeaders(conn.ResponseTrailer(), response.Trailer()) + mergeNonProtocolHeaders(conn.ResponseHeader(), response.Header()) + mergeNonProtocolHeaders(conn.ResponseTrailer(), response.Trailer()) return conn.Send(response.Any()) } diff --git a/header.go b/header.go index f3c7cacd..b3f05432 100644 --- a/header.go +++ b/header.go @@ -19,6 +19,33 @@ import ( "net/http" ) +var ( + protocolHeaders = map[string]struct{}{ + // HTTP headers. + headerContentType: {}, + headerContentLength: {}, + headerContentEncoding: {}, + headerHost: {}, + headerUserAgent: {}, + headerTrailer: {}, + headerDate: {}, + // Connect headers. + connectUnaryHeaderAcceptCompression: {}, + connectUnaryTrailerPrefix: {}, + connectStreamingHeaderCompression: {}, + connectStreamingHeaderAcceptCompression: {}, + connectHeaderTimeout: {}, + connectHeaderProtocolVersion: {}, + // gRPC headers. + grpcHeaderCompression: {}, + grpcHeaderAcceptCompression: {}, + grpcHeaderTimeout: {}, + grpcHeaderStatus: {}, + grpcHeaderMessage: {}, + grpcHeaderDetails: {}, + } +) + // EncodeBinaryHeader base64-encodes the data. It always emits unpadded values. // // In the Connect, gRPC, and gRPC-Web protocols, binary headers must have keys @@ -57,10 +84,9 @@ func mergeHeaders(into, from http.Header) { } } -// mergeMetdataHeaders merges the metadata headers from the "from" header into -// the "into" header. It skips over non metadata headers that should not be -// propagated from the server to the client. -func mergeMetadataHeaders(into, from http.Header) { +// mergeNonProtocolHeaders merges headers excluding protocol headers defined in +// protocolHeaders. +func mergeNonProtocolHeaders(into, from http.Header) { for key, vals := range from { if len(vals) == 0 { // For response trailers, net/http will pre-populate entries @@ -68,30 +94,7 @@ func mergeMetadataHeaders(into, from http.Header) { // are no actual values for those keys, we skip them. continue } - switch http.CanonicalHeaderKey(key) { - case headerContentType, - headerContentLength, - headerContentEncoding, - headerHost, - headerUserAgent, - headerTrailer, - headerDate: - // HTTP headers. - case connectUnaryHeaderAcceptCompression, - connectUnaryTrailerPrefix, - connectStreamingHeaderCompression, - connectStreamingHeaderAcceptCompression, - connectHeaderTimeout, - connectHeaderProtocolVersion: - // Connect headers. - case grpcHeaderCompression, - grpcHeaderAcceptCompression, - grpcHeaderTimeout, - grpcHeaderStatus, - grpcHeaderMessage, - grpcHeaderDetails: - // gRPC headers. - default: + if _, isProtocolHeader := protocolHeaders[key]; !isProtocolHeader { into[key] = append(into[key], vals...) } } diff --git a/protocol_connect.go b/protocol_connect.go index e3c5e4a5..bf26f1aa 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -765,7 +765,7 @@ func (hc *connectUnaryHandlerConn) mergeResponseHeader(err error) { } if err != nil { if connectErr, ok := asError(err); ok && !connectErr.wireErr { - mergeMetadataHeaders(header, connectErr.meta) + mergeNonProtocolHeaders(header, connectErr.meta) } } for k, v := range hc.responseTrailer { @@ -850,7 +850,7 @@ func (m *connectStreamingMarshaler) MarshalEndStream(err error, trailer http.Hea if err != nil { end.Error = newConnectWireError(err) if connectErr, ok := asError(err); ok && !connectErr.wireErr { - mergeMetadataHeaders(end.Trailer, connectErr.meta) + mergeNonProtocolHeaders(end.Trailer, connectErr.meta) } } data, marshalErr := json.Marshal(end) diff --git a/protocol_grpc.go b/protocol_grpc.go index 32b116f4..5addf31e 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -841,7 +841,7 @@ func grpcErrorToTrailer(trailer http.Header, protobuf Codec, err error) { return } if connectErr, ok := asError(err); ok && !connectErr.wireErr { - mergeMetadataHeaders(trailer, connectErr.meta) + mergeNonProtocolHeaders(trailer, connectErr.meta) } var ( status = grpcStatusFromError(err)