diff --git a/middleware/content_type.go b/middleware/content_type.go index 023978fa..c3d76ff6 100644 --- a/middleware/content_type.go +++ b/middleware/content_type.go @@ -6,44 +6,46 @@ import ( ) // SetHeader is a convenience handler to set a response header key/value -func SetHeader(key, value string) func(next http.Handler) http.Handler { +func SetHeader(key, value string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set(key, value) next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) + }) } } // AllowContentType enforces a whitelist of request Content-Types otherwise responds // with a 415 Unsupported Media Type status. -func AllowContentType(contentTypes ...string) func(next http.Handler) http.Handler { +func AllowContentType(contentTypes ...string) func(http.Handler) http.Handler { allowedContentTypes := make(map[string]struct{}, len(contentTypes)) for _, ctype := range contentTypes { allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{} } return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.ContentLength == 0 { - // skip check for empty content body + // Skip check for empty content body next.ServeHTTP(w, r) return } - s := strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))) - if i := strings.Index(s, ";"); i > -1 { - s = s[0:i] + contentType := r.Header.Get("Content-Type") + if contentType == "" { + // Handle case where Content-Type is empty + w.WriteHeader(http.StatusUnsupportedMediaType) + return } + s := strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0])) + if _, ok := allowedContentTypes[s]; ok { next.ServeHTTP(w, r) return } w.WriteHeader(http.StatusUnsupportedMediaType) - } - return http.HandlerFunc(fn) + }) } }