diff --git a/middleware/cors.go b/middleware/cors.go index 96ed16985..48964e23f 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -193,8 +193,6 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { res := c.Response() origin := req.Header.Get(echo.HeaderOrigin) - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) - // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method, // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request // For simplicity we just consider method type and later `Origin` header. @@ -217,8 +215,12 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain if origin == "" { if preflight { // req.Method=OPTIONS + addVaryHeader(res.Header(), echo.HeaderOrigin) return c.NoContent(http.StatusNoContent) } + res.Before(func() { + addVaryHeader(res.Header(), echo.HeaderOrigin) + }) return next(c) // let non-browser calls through } @@ -239,21 +241,28 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // no CORS middleware should block non-preflight requests; // such requests should be let through. One reason is that not all requests that // carry an Origin header participate in the CORS protocol. + res.Before(func() { + addVaryHeader(res.Header(), echo.HeaderOrigin) + }) return next(c) } // Origin existed and was allowed - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) - if config.AllowCredentials { - res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") - } - // Simple request will be let though if !preflight { - if exposeHeaders != "" { - res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) - } + res.Before(func() { + addVaryHeader(res.Header(), echo.HeaderOrigin) + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) + if config.AllowCredentials { + res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + } else { + res.Header().Del(echo.HeaderAccessControlAllowCredentials) + } + if exposeHeaders != "" { + res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) + } + }) return next(c) } // Below code is for Preflight (OPTIONS) request @@ -261,8 +270,15 @@ func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if // at the end of handler chain is actual OPTIONS route or 404/405 route which // response code will confuse browsers - res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) - res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) + addVaryHeader(res.Header(), echo.HeaderOrigin) + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) + if config.AllowCredentials { + res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + } else { + res.Header().Del(echo.HeaderAccessControlAllowCredentials) + } + addVaryHeader(res.Header(), echo.HeaderAccessControlRequestMethod) + addVaryHeader(res.Header(), echo.HeaderAccessControlRequestHeaders) if !hasCustomAllowMethods && routerAllowMethods != "" { res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods) @@ -298,3 +314,18 @@ func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) } return "", false, nil } + +func addVaryHeader(h http.Header, value string) { + if h.Get(echo.HeaderVary) == "" { + h.Set(echo.HeaderVary, value) + return + } + for _, v := range h.Values(echo.HeaderVary) { + for _, part := range strings.Split(v, ",") { + if strings.EqualFold(strings.TrimSpace(part), value) { + return + } + } + } + h.Add(echo.HeaderVary, value) +} diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 5de4ca063..b6a5d3e7f 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -626,3 +626,49 @@ func Test_allowOriginFunc(t *testing.T) { } } } + +func TestCORSProxyChainedHeaders(t *testing.T) { + e := echo.New() + + // CORS middleware on the proxy + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"http://example.com"}, + }) + + // Proxy handler simulating upstream call that also returns CORS headers + proxyHandler := func(c *echo.Context) error { + // Mock upstream copying headers to response + // This simulates the behavior of httputil.ReverseProxy which copies headers from upstream + c.Response().Header().Add(echo.HeaderAccessControlAllowOrigin, "http://example.com") + c.Response().Header().Add(echo.HeaderVary, echo.HeaderOrigin) + c.Response().WriteHeader(http.StatusOK) + return nil + } + + h := cors(proxyHandler) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderOrigin, "http://example.com") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + // Verify that Access-Control-Allow-Origin is not duplicated + acaoHeaders := rec.Header()[echo.HeaderAccessControlAllowOrigin] + assert.Len(t, acaoHeaders, 1, "Access-Control-Allow-Origin should not be duplicated") + assert.Equal(t, "http://example.com", acaoHeaders[0]) + + // Verify that Vary: Origin is not duplicated + varyHeaders := rec.Header()[echo.HeaderVary] + originCount := 0 + for _, v := range varyHeaders { + for _, part := range strings.Split(v, ",") { + if strings.EqualFold(strings.TrimSpace(part), echo.HeaderOrigin) { + originCount++ + } + } + } + assert.Equal(t, 1, originCount, "Vary Origin should not be duplicated") +}