Skip to content

Commit

Permalink
Merge pull request #14 from configcat/fix-cors-allow-headers
Browse files Browse the repository at this point in the history
Auto put auth headers into CORS allowed headers
  • Loading branch information
z4kn4fein committed Nov 10, 2023
2 parents 5699354 + 108a5ce commit 8b7d63b
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 63 deletions.
2 changes: 1 addition & 1 deletion go.mod
Expand Up @@ -32,7 +32,7 @@ require (
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/yuin/gopher-lua v1.1.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/sys v0.14.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231009173412-8bfb1ae86b6c // indirect
)
4 changes: 2 additions & 2 deletions go.sum
Expand Up @@ -72,8 +72,8 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand Down
20 changes: 20 additions & 0 deletions internal/utils/utils.go
Expand Up @@ -77,3 +77,23 @@ func AddSdkIdContextParam(r *http.Request) {
ctx := context.WithValue(context.Background(), httprouter.ParamsKey, params)
*r = *r.WithContext(ctx)
}

func Keys[M ~map[K]V, K comparable, V any](m M) []K {
r := make([]K, 0, len(m))
for k := range m {
r = append(r, k)
}
return r
}

func DedupStringSlice[T string](strings []T) []T {
keys := make(map[T]bool)
var list []T
for _, item := range strings {
if _, value := keys[item]; !value {
keys[item] = true
list = append(list, item)
}
}
return list
}
2 changes: 1 addition & 1 deletion sdk/user_agent.go
Expand Up @@ -4,7 +4,7 @@ import (
"net/http"
)

const proxyVersion = "0.2.2"
const proxyVersion = "0.2.3"

type userAgentInterceptor struct {
http.RoundTripper
Expand Down
50 changes: 38 additions & 12 deletions web/mware/cors.go
Expand Up @@ -2,50 +2,76 @@ package mware

import (
"github.com/configcat/configcat-proxy/config"
"github.com/configcat/configcat-proxy/internal/utils"
"net/http"
"slices"
"strings"
)

var defaultAllowedHeaders = strings.Join([]string{
var defaultAllowedHeaders = []string{
"Cache-Control",
"Content-Type",
"Content-Length",
"Accept-Encoding",
"If-None-Match",
}, ",")
}

var defaultExposedHeaders = strings.Join([]string{
var defaultExposedHeaders = []string{
"Content-Length",
"ETag",
"Date",
"Content-Encoding",
}, ",")
}

var defaultAllowedOrigin = "*"

func CORS(allowedMethods []string, allowedOrigins []string, originRegexConfig *config.OriginRegexConfig, next http.HandlerFunc) http.HandlerFunc {
func CORS(allowedMethods []string, allowedOrigins []string, headers []string, authHeaders []string, originRegexConfig *config.OriginRegexConfig, next http.HandlerFunc) http.HandlerFunc {
var exposedHeaders = defaultExposedHeaders
if len(headers) > 0 {
exposedHeaders = append(exposedHeaders, headers...)
exposedHeaders = utils.DedupStringSlice(exposedHeaders)
}

var allowedHeaders = defaultAllowedHeaders
if len(authHeaders) > 0 {
allowedHeaders = append(allowedHeaders, authHeaders...)
allowedHeaders = utils.DedupStringSlice(allowedHeaders)
}

exposedHeadersString := strings.Join(exposedHeaders, ",")
allowedHeadersString := strings.Join(allowedHeaders, ",")

return func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if r.Method == http.MethodOptions {
setOptionsCORSHeaders(w, origin, allowedOrigins, originRegexConfig, allowedMethods)
setOptionsCORSHeaders(w, origin, allowedOrigins, originRegexConfig, allowedMethods, exposedHeadersString, allowedHeadersString)
} else {
setDefaultCORSHeaders(w, origin, allowedOrigins, originRegexConfig)
setDefaultCORSHeaders(w, origin, allowedOrigins, exposedHeadersString, originRegexConfig)
}
next(w, r)
}
}

func setDefaultCORSHeaders(w http.ResponseWriter, requestOrigin string, allowedOrigins []string, originRegexConfig *config.OriginRegexConfig) {
func setDefaultCORSHeaders(w http.ResponseWriter,
requestOrigin string,
allowedOrigins []string,
exposedHeaders string,
originRegexConfig *config.OriginRegexConfig) {
w.Header().Set("Access-Control-Allow-Origin", determineOrigin(requestOrigin, allowedOrigins, originRegexConfig))
w.Header().Set("Access-Control-Expose-Headers", defaultExposedHeaders)
w.Header().Set("Access-Control-Expose-Headers", exposedHeaders)
}

func setOptionsCORSHeaders(w http.ResponseWriter, requestOrigin string, allowedOrigins []string, originRegexConfig *config.OriginRegexConfig, allowedMethods []string) {
setDefaultCORSHeaders(w, requestOrigin, allowedOrigins, originRegexConfig)
func setOptionsCORSHeaders(w http.ResponseWriter,
requestOrigin string,
allowedOrigins []string,
originRegexConfig *config.OriginRegexConfig,
allowedMethods []string,
exposeHeaders string,
allowedHeaders string) {
setDefaultCORSHeaders(w, requestOrigin, allowedOrigins, exposeHeaders, originRegexConfig)
w.Header().Set("Access-Control-Allow-Credentials", "false")
w.Header().Set("Access-Control-Max-Age", "600")
w.Header().Set("Access-Control-Allow-Headers", defaultAllowedHeaders)
w.Header().Set("Access-Control-Allow-Headers", allowedHeaders)
if allowedMethods != nil && len(allowedMethods) > 0 {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ","))
}
Expand Down
34 changes: 17 additions & 17 deletions web/mware/cors_test.go
Expand Up @@ -13,7 +13,7 @@ import (

func TestCORS(t *testing.T) {
t.Run("* origin, options", func(t *testing.T) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, nil, nil, func(writer http.ResponseWriter, request *http.Request) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, nil, nil, nil, nil, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
})
srv := httptest.NewServer(handler)
Expand All @@ -30,7 +30,7 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding", resp.Header.Get("Access-Control-Expose-Headers"))
})
t.Run("custom origin, options", func(t *testing.T) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"http://localhost"}, nil, func(writer http.ResponseWriter, request *http.Request) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"http://localhost"}, nil, nil, nil, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
})
srv := httptest.NewServer(handler)
Expand All @@ -48,7 +48,7 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding", resp.Header.Get("Access-Control-Expose-Headers"))
})
t.Run("* origin, get", func(t *testing.T) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, nil, nil, func(writer http.ResponseWriter, request *http.Request) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, nil, nil, nil, nil, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
})
srv := httptest.NewServer(handler)
Expand All @@ -61,7 +61,7 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding", resp.Header.Get("Access-Control-Expose-Headers"))
})
t.Run("custom origin, get", func(t *testing.T) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"http://localhost"}, nil, func(writer http.ResponseWriter, request *http.Request) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"http://localhost"}, nil, nil, nil, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
})
srv := httptest.NewServer(handler)
Expand All @@ -75,7 +75,7 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding", resp.Header.Get("Access-Control-Expose-Headers"))
})
t.Run("custom origin, options, multiple origins", func(t *testing.T) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"https://test1.com", "https://test2.com"}, nil, func(writer http.ResponseWriter, request *http.Request) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"https://test1.com", "https://test2.com"}, []string{"h1", "ETag"}, []string{"X-AUTH"}, nil, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
})
srv := httptest.NewServer(handler)
Expand All @@ -87,45 +87,45 @@ func TestCORS(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET,OPTIONS", resp.Header.Get("Access-Control-Allow-Methods"))
assert.Equal(t, "false", resp.Header.Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "Cache-Control,Content-Type,Content-Length,Accept-Encoding,If-None-Match", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "Cache-Control,Content-Type,Content-Length,Accept-Encoding,If-None-Match,X-AUTH", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "600", resp.Header.Get("Access-Control-Max-Age"))
assert.Equal(t, "https://test1.com", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding", resp.Header.Get("Access-Control-Expose-Headers"))
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding,h1", resp.Header.Get("Access-Control-Expose-Headers"))

req, _ = http.NewRequest(http.MethodOptions, srv.URL, http.NoBody)
req.Header.Set("Origin", "https://test2.com")
resp, _ = client.Do(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET,OPTIONS", resp.Header.Get("Access-Control-Allow-Methods"))
assert.Equal(t, "false", resp.Header.Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "Cache-Control,Content-Type,Content-Length,Accept-Encoding,If-None-Match", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "Cache-Control,Content-Type,Content-Length,Accept-Encoding,If-None-Match,X-AUTH", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "600", resp.Header.Get("Access-Control-Max-Age"))
assert.Equal(t, "https://test2.com", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding", resp.Header.Get("Access-Control-Expose-Headers"))
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding,h1", resp.Header.Get("Access-Control-Expose-Headers"))

req, _ = http.NewRequest(http.MethodOptions, srv.URL, http.NoBody)
req.Header.Set("Origin", "something-else")
resp, _ = client.Do(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET,OPTIONS", resp.Header.Get("Access-Control-Allow-Methods"))
assert.Equal(t, "false", resp.Header.Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "Cache-Control,Content-Type,Content-Length,Accept-Encoding,If-None-Match", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "Cache-Control,Content-Type,Content-Length,Accept-Encoding,If-None-Match,X-AUTH", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "600", resp.Header.Get("Access-Control-Max-Age"))
assert.Equal(t, "https://test1.com", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding", resp.Header.Get("Access-Control-Expose-Headers"))
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding,h1", resp.Header.Get("Access-Control-Expose-Headers"))

req, _ = http.NewRequest(http.MethodOptions, srv.URL, http.NoBody)
resp, _ = client.Do(req)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET,OPTIONS", resp.Header.Get("Access-Control-Allow-Methods"))
assert.Equal(t, "false", resp.Header.Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "Cache-Control,Content-Type,Content-Length,Accept-Encoding,If-None-Match", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "Cache-Control,Content-Type,Content-Length,Accept-Encoding,If-None-Match,X-AUTH", resp.Header.Get("Access-Control-Allow-Headers"))
assert.Equal(t, "600", resp.Header.Get("Access-Control-Max-Age"))
assert.Equal(t, "https://test1.com", resp.Header.Get("Access-Control-Allow-Origin"))
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding", resp.Header.Get("Access-Control-Expose-Headers"))
assert.Equal(t, "Content-Length,ETag,Date,Content-Encoding,h1", resp.Header.Get("Access-Control-Expose-Headers"))
})
t.Run("custom origin, get, multiple origins", func(t *testing.T) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"https://test1.com", "https://test2.com"}, nil, func(writer http.ResponseWriter, request *http.Request) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"https://test1.com", "https://test2.com"}, nil, nil, nil, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
})
srv := httptest.NewServer(handler)
Expand Down Expand Up @@ -162,7 +162,7 @@ func TestCORS(t *testing.T) {
regex1, _ := regexp.Compile(".*test1\\.com")
regex2, _ := regexp.Compile(".*test2\\.com")
t.Run("only regex", func(t *testing.T) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, nil, &config.OriginRegexConfig{
handler := CORS([]string{http.MethodGet, http.MethodOptions}, nil, nil, nil, &config.OriginRegexConfig{
Regexes: []*regexp.Regexp{
regex1,
regex2,
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "https://test3.com", resp.Header.Get("Access-Control-Allow-Origin"))
})
t.Run("both", func(t *testing.T) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"https://test3.com", "https://test4.com"}, &config.OriginRegexConfig{
handler := CORS([]string{http.MethodGet, http.MethodOptions}, []string{"https://test3.com", "https://test4.com"}, nil, nil, &config.OriginRegexConfig{
Regexes: []*regexp.Regexp{
regex1,
regex2,
Expand Down Expand Up @@ -268,7 +268,7 @@ http:
conf, err := config.LoadConfigFromFileAndEnvironment(file)
require.NoError(t, err)

handler := CORS([]string{http.MethodGet, http.MethodOptions}, conf.Http.Api.CORS.AllowedOrigins, &conf.Http.Api.CORS.AllowedOriginsRegex, func(writer http.ResponseWriter, request *http.Request) {
handler := CORS([]string{http.MethodGet, http.MethodOptions}, conf.Http.Api.CORS.AllowedOrigins, nil, nil, &conf.Http.Api.CORS.AllowedOriginsRegex, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
})
srv := httptest.NewServer(handler)
Expand Down
7 changes: 4 additions & 3 deletions web/router.go
Expand Up @@ -2,6 +2,7 @@ package web

import (
"github.com/configcat/configcat-proxy/config"
"github.com/configcat/configcat-proxy/internal/utils"
"github.com/configcat/configcat-proxy/log"
"github.com/configcat/configcat-proxy/metrics"
"github.com/configcat/configcat-proxy/sdk"
Expand Down Expand Up @@ -74,7 +75,7 @@ func (s *HttpRouter) setupSSERoutes(conf *config.SseConfig, sdkClients map[strin
endpoint.handler = mware.ExtraHeaders(conf.Headers, endpoint.handler)
}
if conf.CORS.Enabled {
endpoint.handler = mware.CORS([]string{endpoint.method, http.MethodOptions}, conf.CORS.AllowedOrigins, &conf.CORS.AllowedOriginsRegex, endpoint.handler)
endpoint.handler = mware.CORS([]string{endpoint.method, http.MethodOptions}, conf.CORS.AllowedOrigins, utils.Keys(conf.Headers), nil, &conf.CORS.AllowedOriginsRegex, endpoint.handler)
}
if l.Level() == log.Debug {
endpoint.handler = mware.DebugLog(l, endpoint.handler)
Expand Down Expand Up @@ -114,7 +115,7 @@ func (s *HttpRouter) setupCDNProxyRoutes(conf *config.CdnProxyConfig, sdkClients
handler = mware.ExtraHeaders(conf.Headers, handler)
}
if conf.CORS.Enabled {
handler = mware.CORS([]string{http.MethodGet, http.MethodOptions}, conf.CORS.AllowedOrigins, &conf.CORS.AllowedOriginsRegex, handler)
handler = mware.CORS([]string{http.MethodGet, http.MethodOptions}, conf.CORS.AllowedOrigins, utils.Keys(conf.Headers), nil, &conf.CORS.AllowedOriginsRegex, handler)
}
if s.metrics != nil {
handler = metrics.Measure(s.metrics, handler)
Expand Down Expand Up @@ -161,7 +162,7 @@ func (s *HttpRouter) setupAPIRoutes(conf *config.ApiConfig, sdkClients map[strin
endpoint.handler = mware.ExtraHeaders(conf.Headers, endpoint.handler)
}
if conf.CORS.Enabled {
endpoint.handler = mware.CORS([]string{endpoint.method, http.MethodOptions}, conf.CORS.AllowedOrigins, &conf.CORS.AllowedOriginsRegex, endpoint.handler)
endpoint.handler = mware.CORS([]string{endpoint.method, http.MethodOptions}, conf.CORS.AllowedOrigins, utils.Keys(conf.Headers), utils.Keys(conf.AuthHeaders), &conf.CORS.AllowedOriginsRegex, endpoint.handler)
}
if s.metrics != nil {
endpoint.handler = metrics.Measure(s.metrics, endpoint.handler)
Expand Down

0 comments on commit 8b7d63b

Please sign in to comment.