diff --git a/server/middlewares.go b/server/middlewares.go index 8948b70616a..17f63e41c30 100644 --- a/server/middlewares.go +++ b/server/middlewares.go @@ -119,10 +119,14 @@ func compressMiddleware() func(http.Handler) http.Handler { ) } +// clientUniqueIDMiddleware is a middleware that sets a unique client ID as a cookie if it's provided in the request header. +// If the unique client ID is not in the header but present as a cookie, it adds the ID to the request context. func clientUniqueIDMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() clientUniqueId := r.Header.Get(consts.UIClientUniqueIDHeader) + + // If clientUniqueId is found in the header, set it as a cookie if clientUniqueId != "" { c := &http.Cookie{ Name: consts.UIClientUniqueIDHeader, @@ -135,45 +139,69 @@ func clientUniqueIDMiddleware(next http.Handler) http.Handler { } http.SetCookie(w, c) } else { + // If clientUniqueId is not found in the header, check if it's present as a cookie c, err := r.Cookie(consts.UIClientUniqueIDHeader) if !errors.Is(err, http.ErrNoCookie) { clientUniqueId = c.Value } } + // If a valid clientUniqueId is found, add it to the request context if clientUniqueId != "" { ctx = request.WithClientUniqueId(ctx, clientUniqueId) r = r.WithContext(ctx) } + // Call the next middleware or handler in the chain next.ServeHTTP(w, r) }) } +// serverAddressMiddleware is a middleware function that modifies the request object +// to reflect the address of the server handling the request, as determined by the +// presence of X-Forwarded-* headers or the scheme and host of the request URL. func serverAddressMiddleware(h http.Handler) http.Handler { + // Define a new handler function that will be returned by this middleware function. fn := func(w http.ResponseWriter, r *http.Request) { + // Call the serverAddress function to get the scheme and host of the server + // handling the request. If a host is found, modify the request object to use + // that host and scheme instead of the original ones. if rScheme, rHost := serverAddress(r); rHost != "" { r.Host = rHost r.URL.Scheme = rScheme } + + // Call the next handler in the chain with the modified request and response. h.ServeHTTP(w, r) } + // Return the new handler function as an http.Handler object. return http.HandlerFunc(fn) } +// Define constants for the X-Forwarded-* header keys. var ( xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host") xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme") ) +// serverAddress is a helper function that returns the scheme and host of the server +// handling the given request, as determined by the presence of X-Forwarded-* headers +// or the scheme and host of the request URL. func serverAddress(r *http.Request) (scheme, host string) { + // Save the original request host for later comparison. origHost := r.Host + + // Determine the protocol of the request based on the presence of a TLS connection. protocol := "http" if r.TLS != nil { protocol = "https" } + + // Get the X-Forwarded-Host header and extract the first host name if there are + // multiple hosts listed. If there is no X-Forwarded-Host header, use the original + // request host as the default. xfh := r.Header.Get(xForwardedHost) if xfh != "" { i := strings.Index(xfh, ",") @@ -182,19 +210,29 @@ func serverAddress(r *http.Request) (scheme, host string) { } xfh = xfh[:i] } + host = firstOr(r.Host, xfh) + + // Determine the protocol and scheme of the request based on the presence of + // X-Forwarded-* headers or the scheme of the request URL. scheme = firstOr( protocol, r.Header.Get(xForwardedProto), r.Header.Get(xForwardedScheme), r.URL.Scheme, ) - host = firstOr(r.Host, xfh) + + // If the request host has changed due to the X-Forwarded-Host header, log a trace + // message with the original and new host values, as well as the scheme and URL. if host != origHost { log.Trace(r.Context(), "Request host has changed", "origHost", origHost, "host", host, "scheme", scheme, "url", r.URL) } + + // Return the scheme and host of the server handling the request. return scheme, host } +// firstOr is a helper function that returns the first non-empty string from a list +// of strings, or a default value if all the strings are empty. func firstOr(or string, strings ...string) string { for _, s := range strings { if s != "" { @@ -204,25 +242,33 @@ func firstOr(or string, strings ...string) string { return or } -// URLParamsMiddleware convert Chi URL params (from Context) to query params, as expected by our REST package +// URLParamsMiddleware is a middleware function that decodes the query string of +// the incoming HTTP request, adds the URL parameters from the routing context, +// and re-encodes the modified query string. func URLParamsMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Retrieve the routing context from the request context. ctx := chi.RouteContext(r.Context()) - parts := make([]string, 0) + + // Parse the existing query string into a URL values map. + params, _ := url.ParseQuery(r.URL.RawQuery) + + // Loop through each URL parameter in the routing context. for i, key := range ctx.URLParams.Keys { - value := ctx.URLParams.Values[i] - if key == "*" { + // Skip any wildcard URL parameter keys. + if strings.Contains(key, "*") { continue } - parts = append(parts, url.QueryEscape(":"+key)+"="+url.QueryEscape(value)) - } - q := strings.Join(parts, "&") - if r.URL.RawQuery == "" { - r.URL.RawQuery = q - } else { - r.URL.RawQuery += "&" + q + + // Add the URL parameter key-value pair to the URL values map. + params.Add(":"+key, ctx.URLParams.Values[i]) } + // Re-encode the URL values map as a query string and replace the + // existing query string in the request. + r.URL.RawQuery = params.Encode() + + // Call the next handler in the chain with the modified request and response. next.ServeHTTP(w, r) }) } diff --git a/server/middlewares_test.go b/server/middlewares_test.go index 946ac613f2d..7823b1ef8f0 100644 --- a/server/middlewares_test.go +++ b/server/middlewares_test.go @@ -3,18 +3,27 @@ package server import ( "net/http" "net/http/httptest" + "net/url" "os" + "github.com/go-chi/chi/v5" + "github.com/navidrome/navidrome/conf" + "github.com/navidrome/navidrome/conf/configtest" + "github.com/navidrome/navidrome/consts" + "github.com/navidrome/navidrome/model/request" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("middlewares", func() { - var nextCalled bool - next := func(w http.ResponseWriter, r *http.Request) { - nextCalled = true - } + BeforeEach(func() { + DeferCleanup(configtest.SetupConfig()) + }) Describe("robotsTXT", func() { + var nextCalled bool + next := func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + } BeforeEach(func() { nextCalled = false }) @@ -144,4 +153,178 @@ var _ = Describe("middlewares", func() { }) }) }) + + Describe("clientUniqueIDMiddleware", func() { + var ( + nextHandler http.Handler + middleware http.Handler + req *http.Request + nextReq *http.Request + rec *httptest.ResponseRecorder + ) + + BeforeEach(func() { + nextHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + nextReq = r + }) + middleware = clientUniqueIDMiddleware(nextHandler) + req, _ = http.NewRequest(http.MethodGet, "/", nil) + rec = httptest.NewRecorder() + }) + + Context("when the request header has the unique client ID", func() { + BeforeEach(func() { + req.Header.Set(consts.UIClientUniqueIDHeader, "123456") + conf.Server.BasePath = "/music" + }) + + It("sets the unique client ID as a cookie and adds it to the request context", func() { + middleware.ServeHTTP(rec, req) + + Expect(rec.Result().Cookies()).To(HaveLen(1)) + Expect(rec.Result().Cookies()[0].Name).To(Equal(consts.UIClientUniqueIDHeader)) + Expect(rec.Result().Cookies()[0].Value).To(Equal("123456")) + Expect(rec.Result().Cookies()[0].MaxAge).To(Equal(consts.CookieExpiry)) + Expect(rec.Result().Cookies()[0].HttpOnly).To(BeTrue()) + Expect(rec.Result().Cookies()[0].Secure).To(BeTrue()) + Expect(rec.Result().Cookies()[0].SameSite).To(Equal(http.SameSiteStrictMode)) + Expect(rec.Result().Cookies()[0].Path).To(Equal("/music")) + clientUniqueId, _ := request.ClientUniqueIdFrom(nextReq.Context()) + Expect(clientUniqueId).To(Equal("123456")) + }) + }) + + Context("when the request header does not have the unique client ID", func() { + Context("when the request has the unique client ID in a cookie", func() { + BeforeEach(func() { + req.AddCookie(&http.Cookie{ + Name: consts.UIClientUniqueIDHeader, + Value: "123456", + }) + }) + + It("adds the unique client ID to the request context", func() { + middleware.ServeHTTP(rec, req) + + Expect(rec.Result().Cookies()).To(HaveLen(0)) + + clientUniqueId, _ := request.ClientUniqueIdFrom(nextReq.Context()) + Expect(clientUniqueId).To(Equal("123456")) + }) + }) + + Context("when the request does not have the unique client ID in a cookie", func() { + It("does not add the unique client ID to the request context", func() { + middleware.ServeHTTP(rec, req) + + Expect(rec.Result().Cookies()).To(HaveLen(0)) + + clientUniqueId, _ := request.ClientUniqueIdFrom(nextReq.Context()) + Expect(clientUniqueId).To(BeEmpty()) + }) + }) + }) + }) + + Describe("URLParamsMiddleware", func() { + var ( + router *chi.Mux + middleware http.Handler + recorder *httptest.ResponseRecorder + testHandler http.HandlerFunc + ) + + BeforeEach(func() { + router = chi.NewRouter() + recorder = httptest.NewRecorder() + testHandler = func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("OK")) + } + }) + + Context("when request has no query parameters", func() { + It("adds URL parameters to the request", func() { + middleware = URLParamsMiddleware(testHandler) + router.Mount("/", middleware) + + req, _ := http.NewRequest("GET", "/?user=1", nil) + router.ServeHTTP(recorder, req) + + Expect(recorder.Code).To(Equal(http.StatusOK)) + Expect(recorder.Body.String()).To(Equal("OK")) + Expect(req.URL.RawQuery).To(ContainSubstring("user=1")) + }) + }) + + Context("when request has query parameters", func() { + It("merges URL parameters and query parameters", func() { + router.Route("/{key}", func(r chi.Router) { + r.Use(URLParamsMiddleware) + r.Get("/", testHandler) + }) + + req, _ := http.NewRequest("GET", "/test?key=value", nil) + router.ServeHTTP(recorder, req) + Expect(recorder.Code).To(Equal(http.StatusOK)) + Expect(recorder.Body.String()).To(Equal("OK")) + Expect(req.URL.RawQuery).To(ContainSubstring("key=value")) + Expect(req.URL.RawQuery).To(ContainSubstring("%3Akey=test")) + }) + }) + + Context("when URL parameter has wildcard key", func() { + It("does not include wildcard key in query parameters", func() { + router.Route("/{t*}", func(r chi.Router) { + r.Use(URLParamsMiddleware) + r.Get("/", testHandler) + }) + + req, _ := http.NewRequest("GET", "/test?key=value", nil) + router.ServeHTTP(recorder, req) + + Expect(recorder.Code).To(Equal(http.StatusOK)) + Expect(recorder.Body.String()).To(Equal("OK")) + Expect(req.URL.RawQuery).To(ContainSubstring("key=value")) + }) + }) + + Context("when URL parameters require encoding", func() { + It("encodes URL parameters correctly", func() { + router.Route("/{key}", func(r chi.Router) { + r.Use(URLParamsMiddleware) + r.Get("/", testHandler) + }) + + req, _ := http.NewRequest("GET", "/test with space?key=another value", nil) + router.ServeHTTP(recorder, req) + + Expect(recorder.Code).To(Equal(http.StatusOK)) + Expect(recorder.Body.String()).To(Equal("OK")) + queryValues, _ := url.ParseQuery(req.URL.RawQuery) + Expect(queryValues.Get(":key")).To(Equal("test with space")) + Expect(queryValues.Get("key")).To(Equal("another value")) + }) + }) + + Context("when there are multiple URL parameters", func() { + It("includes all URL parameters in the query string", func() { + router.Route("/{key}/{value}", func(r chi.Router) { + r.Use(URLParamsMiddleware) + r.Get("/", testHandler) + }) + + req, _ := http.NewRequest("GET", "/test/value?key=other_value", nil) + router.ServeHTTP(recorder, req) + + Expect(recorder.Code).To(Equal(http.StatusOK)) + Expect(recorder.Body.String()).To(Equal("OK")) + + queryValues, _ := url.ParseQuery(req.URL.RawQuery) + Expect(queryValues.Get(":key")).To(Equal("test")) + Expect(queryValues.Get(":value")).To(Equal("value")) + Expect(queryValues.Get("key")).To(Equal("other_value")) + }) + }) + }) })