diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index ebe8bfe330a..e6626bf96e7 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -65,7 +65,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s contentSubtype, validContentType := grpcutil.ContentSubtype(contentType) if !validContentType { msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType) - http.Error(w, msg, http.StatusBadRequest) + http.Error(w, msg, http.StatusUnsupportedMediaType) return nil, errors.New(msg) } if _, ok := w.(http.Flusher); !ok { @@ -87,7 +87,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s if v := r.Header.Get("grpc-timeout"); v != "" { to, err := decodeTimeout(v) if err != nil { - msg := fmt.Sprintf("malformed time-out: %v", err) + msg := fmt.Sprintf("malformed grpc-timeout: %v", err) http.Error(w, msg, http.StatusBadRequest) return nil, status.Error(codes.Internal, msg) } diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index 82b4baca58b..fbd8058b79f 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -41,11 +41,12 @@ import ( func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { type testCase struct { - name string - req *http.Request - wantErr string - modrw func(http.ResponseWriter) http.ResponseWriter - check func(*serverHandlerTransport, *testCase) error + name string + req *http.Request + wantErr string + wantErrCode int + modrw func(http.ResponseWriter) http.ResponseWriter + check func(*serverHandlerTransport, *testCase) error } tests := []testCase{ { @@ -54,7 +55,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { ProtoMajor: 1, ProtoMinor: 1, }, - wantErr: "gRPC requires HTTP/2", + wantErr: "gRPC requires HTTP/2", + wantErrCode: http.StatusBadRequest, }, { name: "bad method", @@ -63,7 +65,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { Method: "GET", Header: http.Header{}, }, - wantErr: `invalid gRPC request method "GET"`, + wantErr: `invalid gRPC request method "GET"`, + wantErrCode: http.StatusBadRequest, }, { name: "bad content type", @@ -74,7 +77,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { "Content-Type": {"application/foo"}, }, }, - wantErr: `invalid gRPC request content-type "application/foo"`, + wantErr: `invalid gRPC request content-type "application/foo"`, + wantErrCode: http.StatusUnsupportedMediaType, }, { name: "not flusher", @@ -93,7 +97,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { } return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)} }, - wantErr: "gRPC requires a ResponseWriter supporting http.Flusher", + wantErr: "gRPC requires a ResponseWriter supporting http.Flusher", + wantErrCode: http.StatusInternalServerError, }, { name: "valid", @@ -153,7 +158,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { Path: "/service/foo.bar", }, }, - wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`, + wantErr: `rpc error: code = Internal desc = malformed grpc-timeout: transport: timeout unit is not recognized: "tomorrow"`, + wantErrCode: http.StatusBadRequest, }, { name: "with metadata", @@ -187,7 +193,12 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { } for _, tt := range tests { - rw := newTestHandlerResponseWriter() + rrec := httptest.NewRecorder() + rw := http.ResponseWriter(testHandlerResponseWriter{ + ResponseRecorder: rrec, + closeNotify: make(chan bool, 1), + }) + if tt.modrw != nil { rw = tt.modrw(rw) } @@ -196,6 +207,13 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr) continue } + if tt.wantErrCode == 0 { + tt.wantErrCode = http.StatusOK + } + if rrec.Code != tt.wantErrCode { + t.Errorf("%s: code = %d; want %d", tt.name, rrec.Code, tt.wantErrCode) + continue + } if gotErr != nil { continue } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 37e089bc843..bc3da706726 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -380,13 +380,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( fc: &inFlow{limit: uint32(t.initialWindowSize)}, } var ( - // If a gRPC Response-Headers has already been received, then it means - // that the peer is speaking gRPC and we are in gRPC mode. - isGRPC = false - mdata = make(map[string][]string) - httpMethod string - // headerError is set if an error is encountered while parsing the headers - headerError bool + // if false, content-type was missing or invalid + isGRPC = false + contentType = "" + mdata = make(map[string][]string) + httpMethod string + // these are set if an error is encountered while parsing the headers + protocolError bool + headerError *status.Status timeoutSet bool timeout time.Duration @@ -397,6 +398,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( case "content-type": contentSubtype, validContentType := grpcutil.ContentSubtype(hf.Value) if !validContentType { + contentType = hf.Value break } mdata[hf.Name] = append(mdata[hf.Name], hf.Value) @@ -412,7 +414,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( timeoutSet = true var err error if timeout, err = decodeTimeout(hf.Value); err != nil { - headerError = true + headerError = status.Newf(codes.Internal, "malformed grpc-timeout: %v", err) } // "Transports must consider requests containing the Connection header // as malformed." - A41 @@ -420,14 +422,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( if logger.V(logLevel) { logger.Errorf("transport: http2Server.operateHeaders parsed a :connection header which makes a request malformed as per the HTTP/2 spec") } - headerError = true + protocolError = true default: if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { break } v, err := decodeMetadataHeader(hf.Name, hf.Value) if err != nil { - headerError = true + headerError = status.Newf(codes.Internal, "malformed binary metadata %q in header %q: %v", hf.Value, hf.Name, err) logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) break } @@ -446,7 +448,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( logger.Errorf("transport: %v", errMsg) } t.controlBuf.put(&earlyAbortStream{ - httpStatus: 400, + httpStatus: http.StatusBadRequest, streamID: streamID, contentSubtype: s.contentSubtype, status: status.New(codes.Internal, errMsg), @@ -455,7 +457,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return nil } - if !isGRPC || headerError { + if protocolError { t.controlBuf.put(&cleanupStream{ streamID: streamID, rst: true, @@ -464,6 +466,26 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( }) return nil } + if !isGRPC { + t.controlBuf.put(&earlyAbortStream{ + httpStatus: http.StatusUnsupportedMediaType, + streamID: streamID, + contentSubtype: s.contentSubtype, + status: status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType), + rst: !frame.StreamEnded(), + }) + return nil + } + if headerError != nil { + t.controlBuf.put(&earlyAbortStream{ + httpStatus: http.StatusBadRequest, + streamID: streamID, + contentSubtype: s.contentSubtype, + status: headerError, + rst: !frame.StreamEnded(), + }) + return nil + } // "If :authority is missing, Host must be renamed to :authority." - A41 if len(mdata[":authority"]) == 0 { diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index b41378b0024..f61a1ed6972 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -1952,105 +1952,154 @@ func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) { grpcStatusWant: "13", grpcMessageWant: "which should be POST", }, + { + name: "Client Sending Wrong Content-Type", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/json"}}, + }, + httpStatusWant: "415", + grpcStatusWant: "3", + grpcMessageWant: `invalid gRPC request content-type "application/json"`, + }, + { + name: "Client Sending Bad Timeout", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/grpc"}}, + {name: "grpc-timeout", values: []string{"18f6n"}}, + }, + httpStatusWant: "400", + grpcStatusWant: "13", + grpcMessageWant: "malformed grpc-timeout", + }, + { + name: "Client Sending Bad Binary Header", + headers: []struct { + name string + values []string + }{ + {name: ":method", values: []string{"POST"}}, + {name: ":path", values: []string{"foo"}}, + {name: ":authority", values: []string{"localhost"}}, + {name: "content-type", values: []string{"application/grpc"}}, + {name: "foobar-bin", values: []string{"X()3e@#$-"}}, + }, + httpStatusWant: "400", + grpcStatusWant: "13", + grpcMessageWant: `header "foobar-bin": illegal base64 data`, + }, } for _, test := range tests { - server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) - defer server.stop() - // Create a client directly to not tie what you can send to API of - // http2_client.go (i.e. control headers being sent). - mconn, err := net.Dial("tcp", server.lis.Addr().String()) - if err != nil { - t.Fatalf("Client failed to dial: %v", err) - } - defer mconn.Close() + t.Run(test.name, func(t *testing.T) { + server := setUpServerOnly(t, 0, &ServerConfig{}, suspended) + defer server.stop() + // Create a client directly to not tie what you can send to API of + // http2_client.go (i.e. control headers being sent). + mconn, err := net.Dial("tcp", server.lis.Addr().String()) + if err != nil { + t.Fatalf("Client failed to dial: %v", err) + } + defer mconn.Close() - if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { - t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, ", n, err, len(clientPreface)) - } + if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) { + t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, ", n, err, len(clientPreface)) + } - framer := http2.NewFramer(mconn, mconn) - framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil) - if err := framer.WriteSettings(); err != nil { - t.Fatalf("Error while writing settings: %v", err) - } + framer := http2.NewFramer(mconn, mconn) + framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil) + if err := framer.WriteSettings(); err != nil { + t.Fatalf("Error while writing settings: %v", err) + } - // result chan indicates that reader received a Headers Frame with - // desired grpc status and message from server. An error will be passed - // on it if any other frame is received. - result := testutils.NewChannel() + // result chan indicates that reader received a Headers Frame with + // desired grpc status and message from server. An error will be passed + // on it if any other frame is received. + result := testutils.NewChannel() - // Launch a reader goroutine. - go func() { - for { - frame, err := framer.ReadFrame() - if err != nil { - return - } - switch frame := frame.(type) { - case *http2.SettingsFrame: - // Do nothing. A settings frame is expected from server preface. - case *http2.MetaHeadersFrame: - var httpStatus, grpcStatus, grpcMessage string - for _, header := range frame.Fields { - if header.Name == ":status" { - httpStatus = header.Value + // Launch a reader goroutine. + go func() { + for { + frame, err := framer.ReadFrame() + if err != nil { + return + } + switch frame := frame.(type) { + case *http2.SettingsFrame: + // Do nothing. A settings frame is expected from server preface. + case *http2.MetaHeadersFrame: + var httpStatus, grpcStatus, grpcMessage string + for _, header := range frame.Fields { + if header.Name == ":status" { + httpStatus = header.Value + } + if header.Name == "grpc-status" { + grpcStatus = header.Value + } + if header.Name == "grpc-message" { + grpcMessage = header.Value + } } - if header.Name == "grpc-status" { - grpcStatus = header.Value + if httpStatus != test.httpStatusWant { + result.Send(fmt.Errorf("incorrect HTTP Status got %v, want %v", httpStatus, test.httpStatusWant)) + return } - if header.Name == "grpc-message" { - grpcMessage = header.Value + if grpcStatus != test.grpcStatusWant { // grpc status code internal + result.Send(fmt.Errorf("incorrect gRPC Status got %v, want %v", grpcStatus, test.grpcStatusWant)) + return } - } - if httpStatus != test.httpStatusWant { - result.Send(fmt.Errorf("incorrect HTTP Status got %v, want %v", httpStatus, test.httpStatusWant)) - return - } - if grpcStatus != test.grpcStatusWant { // grpc status code internal - result.Send(fmt.Errorf("incorrect gRPC Status got %v, want %v", grpcStatus, test.grpcStatusWant)) - return - } - if !strings.Contains(grpcMessage, test.grpcMessageWant) { - result.Send(fmt.Errorf("incorrect gRPC message")) + if !strings.Contains(grpcMessage, test.grpcMessageWant) { + result.Send(fmt.Errorf("incorrect gRPC message, want %q got %q", test.grpcMessageWant, grpcMessage)) + return + } + + // Records that client successfully received a HeadersFrame + // with expected Trailers-Only response. + result.Send(nil) return + default: + // The server should send nothing but a single Settings and Headers frame. + result.Send(errors.New("the client received a frame other than Settings or Headers")) } - - // Records that client successfully received a HeadersFrame - // with expected Trailers-Only response. - result.Send(nil) - return - default: - // The server should send nothing but a single Settings and Headers frame. - result.Send(errors.New("the client received a frame other than Settings or Headers")) } - } - }() + }() - var buf bytes.Buffer - henc := hpack.NewEncoder(&buf) + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) - // Needs to build headers deterministically to conform to gRPC over - // HTTP/2 spec. - for _, header := range test.headers { - for _, value := range header.values { - if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { - t.Fatalf("Error while encoding header: %v", err) + // Needs to build headers deterministically to conform to gRPC over + // HTTP/2 spec. + for _, header := range test.headers { + for _, value := range header.values { + if err := henc.WriteField(hpack.HeaderField{Name: header.name, Value: value}); err != nil { + t.Fatalf("Error while encoding header: %v", err) + } } } - } - if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { - t.Fatalf("Error while writing headers: %v", err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - r, err := result.Receive(ctx) - if err != nil { - t.Fatalf("Error receiving from channel: %v", err) - } - if r != nil { - t.Fatalf("want nil, got %v", r) - } + if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil { + t.Fatalf("Error while writing headers: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + r, err := result.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from channel: %v", err) + } + if r != nil { + t.Fatalf("want nil, got %v", r) + } + }) } }