Skip to content

Commit

Permalink
Handle unsupported request content type
Browse files Browse the repository at this point in the history
Return 415 Unsupported Media Type when the request decoder does not support the request content type.
  • Loading branch information
raphael committed Apr 24, 2024
1 parent 026621f commit a3f64cc
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 50 deletions.
29 changes: 28 additions & 1 deletion http/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"mime"
"net/http"
"strings"

goa "goa.design/goa/v3/pkg"
)

const (
Expand All @@ -25,6 +27,13 @@ const (
ContentTypeKey
)

const (
// ErrUnsupportedMediaTypeName is the name of the Goa service error returned
// by the built-in decoder when the content type of the request body is not
// supported.
ErrUnsupportedMediaTypeName = "ErrUnsupportedMediaType"
)

type (
// Decoder provides the actual decoding algorithm used to load HTTP
// request and response bodies.
Expand Down Expand Up @@ -80,7 +89,7 @@ func RequestDecoder(r *http.Request) Decoder {
case "text/html", "text/plain":
return newTextDecoder(r.Body, contentType)
default:
return json.NewDecoder(r.Body)
return newUnsupportedDecoder(contentType)
}
}

Expand Down Expand Up @@ -306,3 +315,21 @@ func (e *textDecoder) Decode(v any) error {
}
return nil
}

// newUnsupportedDecoder returns a decoder that returns an error indicating that
// the content type is not supported.
func newUnsupportedDecoder(ct string) Decoder {
return &unsupportedDecoder{ct}
}

type unsupportedDecoder struct {
ct string
}

func (e *unsupportedDecoder) Decode(v any) error {
return &goa.ServiceError{
Name: ErrUnsupportedMediaTypeName,
ID: goa.NewErrorID(),
Message: fmt.Sprintf("unsupported media type %s", e.ct),
}
}
130 changes: 81 additions & 49 deletions http/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

goa "goa.design/goa/v3/pkg"
)

Expand All @@ -36,25 +39,70 @@ func TestRequestEncoder(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := &http.Request{
Header: http.Header{},
}
r := &http.Request{Header: http.Header{}}
if c.requestCT != "" {
r.Header.Set(ct, c.requestCT)
}

encoder := RequestEncoder(r)

if gotT := fmt.Sprintf("%T", encoder); gotT != wantT {
t.Errorf("got encoder type %s, want %s", gotT, wantT)
}
if gotCT := r.Header.Get(ct); gotCT != c.wantCT {
t.Errorf("got Content-Type %q, want %q", gotCT, c.wantCT)
assert.Equal(t, wantT, fmt.Sprintf("%T", encoder))
assert.Equal(t, c.wantCT, r.Header.Get(ct))
})
}
}

func TestRequestDecoder(t *testing.T) {
const (
ct = "Content-Type"
ctJSON = "application/json"
ctXML = "application/xml"
ctGob = "application/gob"
unsupportedT = "*http.unsupportedDecoder"
jsonT = "*json.Decoder"
xmlT = "*xml.Decoder"
gobT = "*gob.Decoder"
)
cases := []struct {
name string
requestCT string
wantCT string
}{
{"no ct", "", jsonT},
{"unsupported ct", "application/foo", unsupportedT},
{"json ct", ctJSON, jsonT},
{"xml ct", ctXML, xmlT},
{"gob ct", ctGob, gobT},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := &http.Request{Header: http.Header{}}
if c.requestCT != "" {
r.Header.Set(ct, c.requestCT)
}

decoder := RequestDecoder(r)

assert.Equal(t, c.wantCT, fmt.Sprintf("%T", decoder))
})
}
}

func TestUnsupportedDecoder(t *testing.T) {
// Write the response produced when writing the error returned the
// unsupported decoder to validate the response status code.
w := httptest.NewRecorder()
decoder := &unsupportedDecoder{"application/foo"}
err := decoder.Decode(nil)
require.Error(t, err)
encoder := ErrorEncoder(ResponseEncoder, nil)

err = encoder(context.Background(), w, err)

require.NoError(t, err)
assert.Equal(t, http.StatusUnsupportedMediaType, w.Code)
}

func TestResponseEncoder(t *testing.T) {
cases := []struct {
name string
Expand Down Expand Up @@ -101,10 +149,10 @@ func TestResponseEncoder(t *testing.T) {
ctx = context.WithValue(ctx, AcceptTypeKey, c.acceptType)
ctx = context.WithValue(ctx, ContentTypeKey, c.contentType)
w := httptest.NewRecorder()

encoder := ResponseEncoder(ctx, w)
if c.encoderType != fmt.Sprintf("%T", encoder) {
t.Errorf("got encoder type %s, expected %s", fmt.Sprintf("%T", encoder), c.encoderType)
}

assert.Equal(t, c.encoderType, fmt.Sprintf("%T", encoder))
})
}
}
Expand Down Expand Up @@ -138,13 +186,12 @@ func TestResponseEncoder_Encode_ErrorResponse(t *testing.T) {
w := httptest.NewRecorder()
ErrorResponseXMLName = c.xmlName
encoder := ResponseEncoder(ctx, w)
if err := encoder.Encode(NewErrorResponse(ctx, serviceError)); err != nil {
t.Error(err)
}

err := encoder.Encode(NewErrorResponse(ctx, serviceError))

assert.NoError(t, err)
body := strings.TrimSpace(w.Body.String())
if body != c.encoded {
t.Errorf("got %s, expected %s", body, c.encoded)
}
assert.Equal(t, c.encoded, body)
})
}
}
Expand Down Expand Up @@ -184,9 +231,8 @@ func TestResponseDecoder(t *testing.T) {
},
}
decoder := ResponseDecoder(r)
if c.decoderType != fmt.Sprintf("%T", decoder) {
t.Errorf("got decoder type %s, expected %s", fmt.Sprintf("%T", decoder), c.decoderType)
}

assert.Equal(t, c.decoderType, fmt.Sprintf("%T", decoder))
})
}
}
Expand All @@ -211,57 +257,43 @@ func TestTextEncoder_Encode(t *testing.T) {
buffer.Reset()
err := encoder.Encode(c.value)
if c.error != nil {
if err == nil || c.error.Error() != err.Error() {
t.Errorf("got error %q, expected %q", err, c.error)
}
} else {
if err != nil {
t.Errorf("got error %q, expected <nil>", err)
}
if buffer.String() != testString {
t.Errorf("got string %s, expected %s", buffer.String(), testString)
}
assert.Error(t, err, c.error)
return
}
require.NoError(t, err)
assert.Equal(t, testString, buffer.String())
})
}
}

func TestTextPlainDecoder_Decode_String(t *testing.T) {
decoder := makeTextDecoder()

var value string

err := decoder.Decode(&value)
if err != nil {
t.Errorf("got error %q, expected <nil>", err)
}
if testString != value {
t.Errorf("got string %s, expected %s", value, testString)
}

assert.NoError(t, err)
assert.Equal(t, testString, value)
}

func TestTextPlainDecoder_Decode_Bytes(t *testing.T) {
decoder := makeTextDecoder()

var value []byte

err := decoder.Decode(&value)
if err != nil {
t.Errorf("got error %q, expected <nil>", err)
}
if testString != string(value) {
t.Errorf("got string %s, expected %s", value, testString)
}

assert.NoError(t, err)
assert.Equal(t, testString, string(value))
}

func TestTextPlainDecoder_Decode_Other(t *testing.T) {
decoder := makeTextDecoder()

expectedErr := fmt.Errorf("can't decode content/type to *int")

var value int

err := decoder.Decode(&value)
if err == nil || err.Error() != expectedErr.Error() {
t.Errorf("got error %q, expectedErr %q", err, expectedErr)
}

assert.Error(t, err, expectedErr)
}

func makeTextDecoder() Decoder {
Expand Down
3 changes: 3 additions & 0 deletions http/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ func (resp *ErrorResponse) MarshalXML(e *xml.Encoder, _ xml.StartElement) error
// error. This method is used by the generated server code when the error is not
// described explicitly in the design.
func (resp *ErrorResponse) StatusCode() int {
if resp.Name == ErrUnsupportedMediaTypeName {
return http.StatusUnsupportedMediaType
}
if resp.Fault {
return http.StatusInternalServerError
}
Expand Down

0 comments on commit a3f64cc

Please sign in to comment.