diff --git a/cors.go b/cors.go index 1f92d1a..1cf7581 100644 --- a/cors.go +++ b/cors.go @@ -110,7 +110,17 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set(corsVaryHeader, corsOriginHeader) } - w.Header().Set(corsAllowOriginHeader, origin) + returnOrigin := origin + for _, o := range ch.allowedOrigins { + // A configuration of * is different than explicitly setting an allowed + // origin. Returning arbitrary origin headers an an access control allow + // origin header is unsafe and is not required by any use case. + if o == corsOriginMatchAll { + returnOrigin = "*" + break + } + } + w.Header().Set(corsAllowOriginHeader, returnOrigin) if r.Method == corsOptionMethod { return diff --git a/cors_test.go b/cors_test.go index c63913e..61eb18f 100644 --- a/cors_test.go +++ b/cors_test.go @@ -327,10 +327,45 @@ func TestCORSHandlerWithCustomValidator(t *testing.T) { return false } - CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) + // Specially craft a CORS object. + handleFunc := func(h http.Handler) http.Handler { + c := &cors{ + allowedMethods: defaultCorsMethods, + allowedHeaders: defaultCorsHeaders, + allowedOrigins: []string{"http://a.example.com"}, + h: h, + } + AllowedOriginValidator(originValidator)(c) + return c + } + + handleFunc(testHandler).ServeHTTP(rr, r) header := rr.HeaderMap.Get(corsAllowOriginHeader) if header != r.URL.String() { t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header) } } + +func TestCORSAllowStar(t *testing.T) { + r := newRequest("GET", "http://a.example.com") + r.Header.Set("Origin", r.URL.String()) + rr := httptest.NewRecorder() + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + originValidator := func(origin string) bool { + if strings.HasSuffix(origin, ".example.com") { + return true + } + return false + } + + CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) + header := rr.HeaderMap.Get(corsAllowOriginHeader) + // Because * is the default CORS policy (which is safe), we should be + // expect a * returned here as the Access Control Allow Origin header + if header != "*" { + t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header) + } + +}