From 6845c14aa31235927278801d432273a12b2cc352 Mon Sep 17 00:00:00 2001 From: Sam Phillips Date: Tue, 4 Oct 2022 16:05:17 +0100 Subject: [PATCH] feat: add config option to set Access-Control-Allow-Origin header on responses --- config/config.go | 5 +-- config/config_test.yaml | 1 + config/loader_test.go | 1 + example/config_example.yaml | 2 ++ proxy/handler_test.go | 63 +++++++++++++++++++++++++++++++++++++ proxy/middleware.go | 13 ++++++++ proxy/server.go | 1 + test/e2e_test.go | 15 +++++---- 8 files changed, 93 insertions(+), 8 deletions(-) diff --git a/config/config.go b/config/config.go index fa62951..7d1e696 100644 --- a/config/config.go +++ b/config/config.go @@ -7,8 +7,9 @@ type Config struct { } type ServerConfig struct { - Port int `mapstructure:"port"` - SSL SSLConfig `mapstructure:"ssl"` + Port int `mapstructure:"port"` + SSL SSLConfig `mapstructure:"ssl"` + AccessControlAllowOrigin string `mapstructure:"accessControlAllowOrigin"` } type ProxyConfig struct { diff --git a/config/config_test.yaml b/config/config_test.yaml index 3934218..b32d071 100644 --- a/config/config_test.yaml +++ b/config/config_test.yaml @@ -4,6 +4,7 @@ server: enable: false certFilePath: "/etc/ssl/certs/cert.crt" keyFilePath: "/etc/ssl/private/private.key" + accessControlAllowOrigin: "*" proxy: upstreamTarget: "https://api.form3.tech/v1" diff --git a/config/loader_test.go b/config/loader_test.go index 0b2003a..d39164e 100644 --- a/config/loader_test.go +++ b/config/loader_test.go @@ -129,6 +129,7 @@ func TestLoadConfig(t *testing.T) { CertFilePath: "/etc/ssl/certs/cert.crt", KeyFilePath: "/etc/ssl/private/private.key", }, + AccessControlAllowOrigin: "*", }, Log: LogConfig{ Level: "debug", diff --git a/example/config_example.yaml b/example/config_example.yaml index e7c477a..9d756b3 100644 --- a/example/config_example.yaml +++ b/example/config_example.yaml @@ -11,6 +11,8 @@ server: certFilePath: "/etc/ssl/certs/cert.crt" # Location of the proxy's private key, if SSL is enabled keyFilePath: "/etc/ssl/private/private.key" + # Value to be used in the Access-Control-Allow-Origin response header + accessControlAllowOrigin: "*" # Request forward proxy config proxy: diff --git a/proxy/handler_test.go b/proxy/handler_test.go index c8b8f3a..fc6b2f9 100644 --- a/proxy/handler_test.go +++ b/proxy/handler_test.go @@ -84,6 +84,69 @@ func TestHandler(t *testing.T) { } } +func TestHandlerCORS(t *testing.T) { + tests := []struct { + name string + accessControlAllowOrigin string + }{ + { + "no value", + "", + }, + { + "*", + "*", + }, + { + "single domain", + "https://test", + }, + } + + expectedRespBody := "OK" + mockURL := "mock" + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + // Mock dependencies + mockReqSigner := mockReqSigner(mockCtrl) + mockMetricPublisher := mockMetricPublisher(mockCtrl, mockURL) + + // Test upstream target that returns 200 OK + targetSrv := testTargetServer(expectedRespBody) + + // Reverse proxy pointing to test target + rs, err := NewReverseProxy(targetSrv.URL) + require.NoError(t, err) + + // Test handler + var w *test.TestResponseRecorder + + for _, tt := range tests { + h := NewHandler(rs, mockReqSigner, mockMetricPublisher) + _, e := gin.CreateTestContext(w) + e.NoRoute( + RecoverMiddleware(mockMetricPublisher), + LogAndMetricsMiddleware(mockMetricPublisher), + CORSMiddleware(tt.accessControlAllowOrigin), + h.ForwardRequest, + ) + + t.Run(tt.name, func(t *testing.T) { + w = test.NewTestResponseRecorder() + + // Test request + req, err := http.NewRequest(http.MethodGet, mockURL, nil) + require.NoError(t, err) + + e.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, w.Header().Get(AccessControlAllowOriginHeader), tt.accessControlAllowOrigin) + }) + } +} + func mockReqSigner(mockCtrl *gomock.Controller) *MockRequestSigner { mockReqSigner := NewMockRequestSigner(mockCtrl) mockReqSigner.EXPECT().SignRequest(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Request, error) { diff --git a/proxy/middleware.go b/proxy/middleware.go index 2de5c45..d59ff86 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -13,6 +13,10 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + AccessControlAllowOriginHeader string = "Access-Control-Allow-Origin" +) + func RecoverMiddleware(metricPublisher MetricPublisher) gin.HandlerFunc { return func(c *gin.Context) { defer func() { @@ -71,3 +75,12 @@ func LogAndMetricsMiddleware(metricPublisher MetricPublisher) gin.HandlerFunc { }).Info("request summary") } } + +func CORSMiddleware(accessControlAllowOrigin string) gin.HandlerFunc { + if accessControlAllowOrigin != "" { + return func(c *gin.Context) { + c.Writer.Header().Set(AccessControlAllowOriginHeader, accessControlAllowOrigin) + } + } + return func(_ *gin.Context) {} +} diff --git a/proxy/server.go b/proxy/server.go index f43423a..edcfbab 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -33,6 +33,7 @@ func NewServer(cfg config.ServerConfig, handler Handler, metric MetricPublisher) router.NoRoute( RecoverMiddleware(metric), LogAndMetricsMiddleware(metric), + CORSMiddleware(cfg.AccessControlAllowOrigin), handler.ForwardRequest, ) diff --git a/test/e2e_test.go b/test/e2e_test.go index 3d4b2c8..65887bc 100644 --- a/test/e2e_test.go +++ b/test/e2e_test.go @@ -15,7 +15,7 @@ import ( "testing" "time" - "github.com/form3tech-oss/go-http-message-signatures" + httpsignatures "github.com/form3tech-oss/go-http-message-signatures" "github.com/form3tech-oss/http-message-signing-proxy/cmd" "github.com/stretchr/testify/suite" ) @@ -37,6 +37,7 @@ const ( type e2eTestSuite struct { suite.Suite + accessControlAllowOrigin string } func (s *e2eTestSuite) msgVerifier() *httpsignatures.MessageVerifier { @@ -86,10 +87,11 @@ func (s *e2eTestSuite) runProxy(upstreamTarget string) { rootCmd.SetArgs(append( []string{"--config", cfgFile}, genSetFlags(map[string]string{ - "server.ssl.certFilePath": sslCertFile, - "server.ssl.keyFilePath": sslKeyFile, - "proxy.signer.keyFilePath": privateKeyFile, - "proxy.upstreamTarget": upstreamTarget, + "server.ssl.certFilePath": sslCertFile, + "server.ssl.keyFilePath": sslKeyFile, + "proxy.signer.keyFilePath": privateKeyFile, + "proxy.upstreamTarget": upstreamTarget, + "server.accessControlAllowOrigin": s.accessControlAllowOrigin, })..., )) go func() { @@ -167,6 +169,7 @@ func (s *e2eTestSuite) TestProxy() { r, err := http.DefaultClient.Do(req) s.NoError(err) s.Equal(test.expectedStatus, r.StatusCode) + s.Equal(s.accessControlAllowOrigin, r.Header.Get("Access-Control-Allow-Origin")) if test.expectedStatus == http.StatusOK { resp, err := readHttpResp[successResp](r) @@ -186,7 +189,7 @@ func (s *e2eTestSuite) TestProxy() { } func TestE2ETestSuite(t *testing.T) { - suite.Run(t, new(e2eTestSuite)) + suite.Run(t, &e2eTestSuite{accessControlAllowOrigin: "*"}) } func genSetFlags(m map[string]string) []string {