From 6e468d25f1fed4b59ebaad264c07d4c5d60c4078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Goran=20Mari=C4=87?= <45515666+GocaMaric@users.noreply.github.com> Date: Tue, 5 Dec 2023 20:16:59 +0100 Subject: [PATCH] Update content_type.go --- middleware/content_type.go | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/middleware/content_type.go b/middleware/content_type.go index 023978fa..c3d76ff6 100644 --- a/middleware/content_type.go +++ b/middleware/content_type.go @@ -6,44 +6,46 @@ import ( ) // SetHeader is a convenience handler to set a response header key/value -func SetHeader(key, value string) func(next http.Handler) http.Handler { +func SetHeader(key, value string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set(key, value) next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) + }) } } // AllowContentType enforces a whitelist of request Content-Types otherwise responds // with a 415 Unsupported Media Type status. -func AllowContentType(contentTypes ...string) func(next http.Handler) http.Handler { +func AllowContentType(contentTypes ...string) func(http.Handler) http.Handler { allowedContentTypes := make(map[string]struct{}, len(contentTypes)) for _, ctype := range contentTypes { allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{} } return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.ContentLength == 0 { - // skip check for empty content body + // Skip check for empty content body next.ServeHTTP(w, r) return } - s := strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))) - if i := strings.Index(s, ";"); i > -1 { - s = s[0:i] + contentType := r.Header.Get("Content-Type") + if contentType == "" { + // Handle case where Content-Type is empty + w.WriteHeader(http.StatusUnsupportedMediaType) + return } + s := strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0])) + if _, ok := allowedContentTypes[s]; ok { next.ServeHTTP(w, r) return } w.WriteHeader(http.StatusUnsupportedMediaType) - } - return http.HandlerFunc(fn) + }) } }