Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport/http2: use HTTP 400 for bad requests instead of 500 #5804

Merged
merged 1 commit into from Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 21 additions & 9 deletions internal/transport/handler_server.go
Expand Up @@ -46,24 +46,32 @@ import (
"google.golang.org/grpc/status"
)

// NewServerHandlerTransport returns a ServerTransport handling gRPC
// from inside an http.Handler. It requires that the http Server
// supports HTTP/2.
// NewServerHandlerTransport returns a ServerTransport handling gRPC from
// inside an http.Handler, or writes an HTTP error to w and returns an error.
// It requires that the http Server supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) {
if r.ProtoMajor != 2 {
return nil, errors.New("gRPC requires HTTP/2")
msg := "gRPC requires HTTP/2"
http.Error(w, msg, http.StatusBadRequest)
return nil, errors.New(msg)
}
if r.Method != "POST" {
return nil, errors.New("invalid gRPC request method")
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
http.Error(w, msg, http.StatusBadRequest)
return nil, errors.New(msg)
}
contentType := r.Header.Get("Content-Type")
// TODO: do we assume contentType is lowercase? we did before
contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
if !validContentType {
return nil, errors.New("invalid gRPC request content-type")
msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType)
http.Error(w, msg, http.StatusBadRequest)
return nil, errors.New(msg)
}
if _, ok := w.(http.Flusher); !ok {
return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
msg := "gRPC requires a ResponseWriter supporting http.Flusher"
http.Error(w, msg, http.StatusInternalServerError)
return nil, errors.New(msg)
}

st := &serverHandlerTransport{
Expand All @@ -79,7 +87,9 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := decodeTimeout(v)
if err != nil {
return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err)
msg := fmt.Sprintf("malformed time-out: %v", err)
easwars marked this conversation as resolved.
Show resolved Hide resolved
http.Error(w, msg, http.StatusBadRequest)
return nil, status.Error(codes.Internal, msg)
}
st.timeoutSet = true
st.timeout = to
Expand All @@ -97,7 +107,9 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
for _, v := range vv {
v, err := decodeMetadataHeader(k, v)
if err != nil {
return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err)
msg := fmt.Sprintf("malformed binary metadata %q in header %q: %v", v, k, err)
http.Error(w, msg, http.StatusBadRequest)
return nil, status.Error(codes.Internal, msg)
}
metakv = append(metakv, k, v)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/transport/handler_server_test.go
Expand Up @@ -63,7 +63,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
Method: "GET",
Header: http.Header{},
},
wantErr: "invalid gRPC request method",
wantErr: `invalid gRPC request method "GET"`,
},
{
name: "bad content type",
Expand All @@ -74,7 +74,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
"Content-Type": {"application/foo"},
},
},
wantErr: "invalid gRPC request content-type",
wantErr: `invalid gRPC request content-type "application/foo"`,
},
{
name: "not flusher",
Expand Down
3 changes: 2 additions & 1 deletion server.go
Expand Up @@ -1008,7 +1008,8 @@ var _ http.Handler = (*Server)(nil)
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
// Errors returned from transport.NewServerHandlerTransport have
// already been written to w.
return
}
if !s.addConn(listenerAddressForServeHTTP, st) {
Expand Down