diff --git a/cors_test.go b/cors_test.go index 7878f4c..6efcee0 100644 --- a/cors_test.go +++ b/cors_test.go @@ -3,6 +3,7 @@ package cors import ( "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -94,6 +95,14 @@ func TestNormalize(t *testing.T) { assert.Equal(t, values, []string{}) } +func TestConvert(t *testing.T) { + methods := []string{"Get", "GET", "get"} + headers := []string{"X-CSRF-TOKEN", "X-CSRF-Token", "x-csrf-token"} + + assert.Equal(t, []string{"GET", "GET", "GET"}, convert(methods, strings.ToUpper)) + assert.Equal(t, []string{"X-Csrf-Token", "X-Csrf-Token", "X-Csrf-Token"}, convert(headers, http.CanonicalHeaderKey)) +} + func TestGenerateNormalHeaders_AllowAllOrigins(t *testing.T) { header := generateNormalHeaders(Config{ AllowAllOrigins: false, @@ -123,7 +132,7 @@ func TestGenerateNormalHeaders_ExposedHeaders(t *testing.T) { header := generateNormalHeaders(Config{ ExposeHeaders: []string{"X-user", "xPassword"}, }) - assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "x-user,xpassword") + assert.Equal(t, header.Get("Access-Control-Expose-Headers"), "X-User,Xpassword") assert.Equal(t, header.Get("Vary"), "Origin") assert.Len(t, header, 2) } @@ -157,7 +166,7 @@ func TestGeneratePreflightHeaders_AllowedMethods(t *testing.T) { header := generatePreflightHeaders(Config{ AllowMethods: []string{"GET ", "post", "PUT", " put "}, }) - assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "get,post,put") + assert.Equal(t, header.Get("Access-Control-Allow-Methods"), "GET,POST,PUT") assert.Equal(t, header.Get("Vary"), "Origin") assert.Len(t, header, 2) } @@ -166,7 +175,7 @@ func TestGeneratePreflightHeaders_AllowedHeaders(t *testing.T) { header := generatePreflightHeaders(Config{ AllowHeaders: []string{"X-user", "Content-Type"}, }) - assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "x-user,content-type") + assert.Equal(t, header.Get("Access-Control-Allow-Headers"), "X-User,Content-Type") assert.Equal(t, header.Get("Vary"), "Origin") assert.Len(t, header, 2) } @@ -227,7 +236,7 @@ func TestPassesAllowedOrigins(t *testing.T) { assert.Equal(t, w.Body.String(), "get") assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://google.com") assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true") - assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data,x-user") + assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "Data,X-User") // deny CORS request w = performRequest(router, "GET", "https://google.com") @@ -241,8 +250,8 @@ func TestPassesAllowedOrigins(t *testing.T) { assert.Equal(t, w.Code, 200) assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "http://github.com") assert.Equal(t, w.Header().Get("Access-Control-Allow-Credentials"), "true") - assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "get,post,put,head") - assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type,timestamp") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "GET,POST,PUT,HEAD") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "Content-Type,Timestamp") assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "43200") // deny CORS prefligh request @@ -276,15 +285,15 @@ func TestPassesAllowedAllOrigins(t *testing.T) { w = performRequest(router, "POST", "example.com") assert.Equal(t, w.Body.String(), "post") assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*") - assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "data2,x-user2") + assert.Equal(t, w.Header().Get("Access-Control-Expose-Headers"), "Data2,X-User2") assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) // allowed CORS prefligh request w = performRequest(router, "OPTIONS", "https://facebook.com") assert.Equal(t, w.Code, 200) assert.Equal(t, w.Header().Get("Access-Control-Allow-Origin"), "*") - assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "patch,get,post") - assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "content-type,testheader") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Methods"), "PATCH,GET,POST") + assert.Equal(t, w.Header().Get("Access-Control-Allow-Headers"), "Content-Type,Testheader") assert.Equal(t, w.Header().Get("Access-Control-Max-Age"), "36000") assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) } diff --git a/utils.go b/utils.go index 6cb7ea3..460ef17 100644 --- a/utils.go +++ b/utils.go @@ -7,13 +7,15 @@ import ( "time" ) +type converter func(string) string + func generateNormalHeaders(c Config) http.Header { headers := make(http.Header) if c.AllowCredentials { headers.Set("Access-Control-Allow-Credentials", "true") } if len(c.ExposeHeaders) > 0 { - exposeHeaders := normalize(c.ExposeHeaders) + exposeHeaders := convert(normalize(c.ExposeHeaders), http.CanonicalHeaderKey) headers.Set("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ",")) } if c.AllowAllOrigins { @@ -30,12 +32,12 @@ func generatePreflightHeaders(c Config) http.Header { headers.Set("Access-Control-Allow-Credentials", "true") } if len(c.AllowMethods) > 0 { - allowMethods := normalize(c.AllowMethods) + allowMethods := convert(normalize(c.AllowMethods), strings.ToUpper) value := strings.Join(allowMethods, ",") headers.Set("Access-Control-Allow-Methods", value) } if len(c.AllowHeaders) > 0 { - allowHeaders := normalize(c.AllowHeaders) + allowHeaders := convert(normalize(c.AllowHeaders), http.CanonicalHeaderKey) value := strings.Join(allowHeaders, ",") headers.Set("Access-Control-Allow-Headers", value) } @@ -46,7 +48,13 @@ func generatePreflightHeaders(c Config) http.Header { if c.AllowAllOrigins { headers.Set("Access-Control-Allow-Origin", "*") } else { - headers.Set("Vary", "Origin") + // Always set Vary headers + // see https://github.com/rs/cors/issues/10, + // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 + + headers.Add("Vary", "Origin") + headers.Add("Vary", "Access-Control-Request-Method") + headers.Add("Vary", "Access-Control-Request-Headers") } return headers } @@ -67,3 +75,11 @@ func normalize(values []string) []string { } return normalized } + +func convert(s []string, c converter) []string { + var out []string + for _, i := range s { + out = append(out, c(i)) + } + return out +}