diff --git a/ctx.go b/ctx.go index ec3a11eee6..29224e270a 100644 --- a/ctx.go +++ b/ctx.go @@ -288,6 +288,8 @@ func (c *Ctx) BodyParser(out interface{}) error { // Get content-type ctype := utils.ToLower(utils.UnsafeString(c.fasthttp.Request.Header.ContentType())) + ctype = utils.ParseVendorSpecificContentType(ctype) + // Parse body accordingly if strings.HasPrefix(ctype, MIMEApplicationJSON) { schemaDecoder.SetAliasTag("json") diff --git a/utils/http.go b/utils/http.go index 4584f3c7a7..da7d6ec893 100644 --- a/utils/http.go +++ b/utils/http.go @@ -4,6 +4,8 @@ package utils +import "strings" + const MIMEOctetStream = "application/octet-stream" // GetMIME returns the content-type of a file extension @@ -22,6 +24,32 @@ func GetMIME(extension string) (mime string) { return mime } +// ParseVendorSpecificContentType check if content type is vendor specific and +// if it is parsable to any known types. If its not vendor specific then returns +// the original content type. +func ParseVendorSpecificContentType(cType string) string { + plusIndex := strings.Index(cType, "+") + + if plusIndex == -1 { + return cType + } + + var parsableType string + if semiColonIndex := strings.Index(cType, ";"); semiColonIndex == -1 { + parsableType = cType[plusIndex+1:] + } else { + parsableType = cType[plusIndex+1 : semiColonIndex] + } + + slashIndex := strings.Index(cType, "/") + + if slashIndex == -1 { + return cType + } + + return cType[0:slashIndex+1] + parsableType +} + // limits for HTTP statuscodes const ( statusMessageMin = 100 diff --git a/utils/http_test.go b/utils/http_test.go index 65bfa7a792..baddd94f86 100644 --- a/utils/http_test.go +++ b/utils/http_test.go @@ -53,6 +53,42 @@ func Benchmark_GetMIME(b *testing.B) { }) } +func Test_ParseVendorSpecificContentType(t *testing.T) { + t.Parallel() + + cType := ParseVendorSpecificContentType("application/json") + AssertEqual(t, "application/json", cType) + + cType = ParseVendorSpecificContentType("application/vnd.api+json; version=1") + AssertEqual(t, "application/json", cType) + + cType = ParseVendorSpecificContentType("application/vnd.api+json") + AssertEqual(t, "application/json", cType) + + cType = ParseVendorSpecificContentType("application/vnd.dummy+x-www-form-urlencoded") + AssertEqual(t, "application/x-www-form-urlencoded", cType) + + cType = ParseVendorSpecificContentType("something invalid") + AssertEqual(t, "something invalid", cType) +} + +func Benchmark_ParseVendorSpecificContentType(b *testing.B) { + var cType string + b.Run("vendorContentType", func(b *testing.B) { + for n := 0; n < b.N; n++ { + cType = ParseVendorSpecificContentType("application/vnd.api+json; version=1") + } + AssertEqual(b, "application/json", cType) + }) + + b.Run("defaultContentType", func(b *testing.B) { + for n := 0; n < b.N; n++ { + cType = ParseVendorSpecificContentType("application/json") + } + AssertEqual(b, "application/json", cType) + }) +} + func Test_StatusMessage(t *testing.T) { t.Parallel() res := StatusMessage(204)