Skip to content

Commit

Permalink
add support for the google.api.HttpBody proto as a response
Browse files Browse the repository at this point in the history
  • Loading branch information
theRealWardo committed Sep 18, 2017
1 parent 8bec008 commit 136b143
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 17 deletions.
6 changes: 3 additions & 3 deletions runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@ func (*errorBody) ProtoMessage() {}
func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) {
const fallback = `{"error": "failed to marshal error message"}`

w.Header().Del("Trailer")
w.Header().Set("Content-Type", marshaler.ContentType())

s, ok := status.FromError(err)
if !ok {
s = status.New(codes.Unknown, err.Error())
Expand All @@ -94,6 +91,9 @@ func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w
Code: int32(s.Code()),
}

w.Header().Del("Trailer")
w.Header().Set("Content-Type", marshaler.ContentType(body))

buf, merr := marshaler.Marshal(body)
if merr != nil {
grpclog.Printf("Failed to marshal error message %q: %v", body, merr)
Expand Down
14 changes: 10 additions & 4 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"net/textproto"

"github.com/golang/protobuf/proto"
"github.com/grpc-ecosystem/grpc-gateway/runtime/internal"
"github.com/therealwardo/grpc-gateway/runtime/internal"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
Expand All @@ -32,13 +32,14 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
handleForwardResponseServerMetadata(w, mux, md)

w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("Content-Type", marshaler.ContentType())
w.Header().Set("Content-Type", marshaler.ContentType(nil))
if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
f.Flush()
ctSet := false
for {
resp, err := recv()
if err == io.EOF {
Expand All @@ -53,7 +54,12 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
return
}

buf, err := marshaler.Marshal(streamChunk(resp, nil))
chunk := streamChunk(resp, nil)
if !ctSet {
w.Header().Set("Content-Type", marshaler.ContentType(chunk))
ctSet = true
}
buf, err := marshaler.Marshal(chunk)
if err != nil {
grpclog.Printf("Failed to marshal response chunk: %v", err)
return
Expand Down Expand Up @@ -101,7 +107,7 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha

handleForwardResponseServerMetadata(w, mux, md)
handleForwardResponseTrailerHeader(w, md)
w.Header().Set("Content-Type", marshaler.ContentType())
w.Header().Set("Content-Type", marshaler.ContentType(resp))
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
HTTPError(ctx, mux, marshaler, w, req, err)
return
Expand Down
74 changes: 74 additions & 0 deletions runtime/marshal_httpbody.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package runtime

import (
"io"
"reflect"

"github.com/golang/protobuf/proto"
hb "google.golang.org/genproto/googleapis/api/httpbody"
)

var (
backupMarshaler = &JSONPb{OrigName: true}
)

// HttpBodyMarshaler is a Marshaler which supports marshaling of a
// google.api.HttpBody message as the full response body if it is
// the actual message used as the response. If not, then this will
// simply fallback to the JSONPb marshaler.
type HttpBodyMarshaler struct{}

// ContentType returns the type specified in the google.api.HttpBody
// proto if "v" is a google.api.HttpBody proto, otherwise returns
// "application/json".
func (*HttpBodyMarshaler) ContentType(v interface{}) string {
if h := tryHttpBody(v); h != nil {
return h.GetContentType()
}
return "application/json"
}

// Marshal marshals "v" by returning the body bytes if v is a
// google.api.HttpBody message, or it marshals to JSON.
func (*HttpBodyMarshaler) Marshal(v interface{}) ([]byte, error) {
if h := tryHttpBody(v); h != nil {
return h.GetData(), nil
}
return backupMarshaler.Marshal(v)
}

// Unmarshal unmarshals JSON data into "v".
// google.api.HttpBody messages are not supported on the request.
func (*HttpBodyMarshaler) Unmarshal(data []byte, v interface{}) error {
return backupMarshaler.Unmarshal(data, v)
}

// NewDecoder returns a Decoder which reads JSON stream from "r".
func (*HttpBodyMarshaler) NewDecoder(r io.Reader) Decoder {
return backupMarshaler.NewDecoder(r)
}

// NewEncoder returns an Encoder which writes JSON stream into "w".
func (*HttpBodyMarshaler) NewEncoder(w io.Writer) Encoder {
return backupMarshaler.NewEncoder(w)
}

func tryHttpBody(v interface{}) *hb.HttpBody {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return nil
}
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
if rv.Type().ConvertibleTo(typeProtoMessage) {
pb := rv.Interface().(proto.Message)
if proto.MessageName(pb) == "google.api.HttpBody" {
return v.(*hb.HttpBody)
}
}
rv = rv.Elem()
}
return nil
}
2 changes: 1 addition & 1 deletion runtime/marshal_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
type JSONBuiltin struct{}

// ContentType always Returns "application/json".
func (*JSONBuiltin) ContentType() string {
func (*JSONBuiltin) ContentType(v interface{}) string {
return "application/json"
}

Expand Down
2 changes: 1 addition & 1 deletion runtime/marshal_jsonpb.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
type JSONPb jsonpb.Marshaler

// ContentType always returns "application/json".
func (*JSONPb) ContentType() string {
func (*JSONPb) ContentType(v interface{}) string {
return "application/json"
}

Expand Down
4 changes: 2 additions & 2 deletions runtime/marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ type Marshaler interface {
NewDecoder(r io.Reader) Decoder
// NewEncoder returns an Encoder which writes bytes sequence into "w".
NewEncoder(w io.Writer) Encoder
// ContentType returns the Content-Type which this marshaler is responsible for.
ContentType() string
// ContentType returns the response Content-Type for "v".
ContentType(v interface{}) string
}

// Decoder decodes a byte sequence
Expand Down
2 changes: 1 addition & 1 deletion runtime/marshaler_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var (
acceptHeader = http.CanonicalHeaderKey("Accept")
contentTypeHeader = http.CanonicalHeaderKey("Content-Type")

defaultMarshaler = &JSONPb{OrigName: true}
defaultMarshaler = &HttpBodyMarshaler{}
)

// MarshalerForRequest returns the inbound/outbound marshalers for this request.
Expand Down
2 changes: 1 addition & 1 deletion runtime/marshaler_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func TestMarshalerForRequest(t *testing.T) {

type dummyMarshaler struct{}

func (dummyMarshaler) ContentType() string { return "" }
func (dummyMarshaler) ContentType(v interface{}) string { return "" }
func (dummyMarshaler) Marshal(interface{}) ([]byte, error) {
return nil, errors.New("not implemented")
}
Expand Down
9 changes: 5 additions & 4 deletions runtime/proto_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ func DefaultHTTPProtoErrorHandler(ctx context.Context, mux *ServeMux, marshaler
// return Internal when Marshal failed
const fallback = `{"code": 13, "message": "failed to marshal error message"}`

w.Header().Del("Trailer")
w.Header().Set("Content-Type", marshaler.ContentType())

s, ok := status.FromError(err)
if !ok {
s = status.New(codes.Unknown, err.Error())
}

buf, merr := marshaler.Marshal(s.Proto())
pb := s.Proto()
w.Header().Del("Trailer")
w.Header().Set("Content-Type", marshaler.ContentType(pb))

buf, merr := marshaler.Marshal(pb)
if merr != nil {
grpclog.Printf("Failed to marshal error message %q: %v", s.Proto(), merr)
w.WriteHeader(http.StatusInternalServerError)
Expand Down

0 comments on commit 136b143

Please sign in to comment.