Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ type Config struct {
}

type ServerConfig struct {
Port int `mapstructure:"port"`
SSL SSLConfig `mapstructure:"ssl"`
Port int `mapstructure:"port"`
SSL SSLConfig `mapstructure:"ssl"`
AccessControlAllowOrigin string `mapstructure:"accessControlAllowOrigin"`
}

type ProxyConfig struct {
Expand Down
1 change: 1 addition & 0 deletions config/config_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ server:
enable: false
certFilePath: "/etc/ssl/certs/cert.crt"
keyFilePath: "/etc/ssl/private/private.key"
accessControlAllowOrigin: "*"

proxy:
upstreamTarget: "https://api.form3.tech/v1"
Expand Down
1 change: 1 addition & 0 deletions config/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ func TestLoadConfig(t *testing.T) {
CertFilePath: "/etc/ssl/certs/cert.crt",
KeyFilePath: "/etc/ssl/private/private.key",
},
AccessControlAllowOrigin: "*",
},
Log: LogConfig{
Level: "debug",
Expand Down
2 changes: 2 additions & 0 deletions example/config_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ server:
certFilePath: "/etc/ssl/certs/cert.crt"
# Location of the proxy's private key, if SSL is enabled
keyFilePath: "/etc/ssl/private/private.key"
# Value to be used in the Access-Control-Allow-Origin response header
accessControlAllowOrigin: "*"

# Request forward proxy config
proxy:
Expand Down
63 changes: 63 additions & 0 deletions proxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,69 @@ func TestHandler(t *testing.T) {
}
}

func TestHandlerCORS(t *testing.T) {
tests := []struct {
name string
accessControlAllowOrigin string
}{
{
"no value",
"",
},
{
"*",
"*",
},
{
"single domain",
"https://test",
},
}

expectedRespBody := "OK"
mockURL := "mock"
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

// Mock dependencies
mockReqSigner := mockReqSigner(mockCtrl)
mockMetricPublisher := mockMetricPublisher(mockCtrl, mockURL)

// Test upstream target that returns 200 OK
targetSrv := testTargetServer(expectedRespBody)

// Reverse proxy pointing to test target
rs, err := NewReverseProxy(targetSrv.URL)
require.NoError(t, err)

// Test handler
var w *test.TestResponseRecorder

for _, tt := range tests {
h := NewHandler(rs, mockReqSigner, mockMetricPublisher)
_, e := gin.CreateTestContext(w)
e.NoRoute(
RecoverMiddleware(mockMetricPublisher),
LogAndMetricsMiddleware(mockMetricPublisher),
CORSMiddleware(tt.accessControlAllowOrigin),
h.ForwardRequest,
)

t.Run(tt.name, func(t *testing.T) {
w = test.NewTestResponseRecorder()

// Test request
req, err := http.NewRequest(http.MethodGet, mockURL, nil)
require.NoError(t, err)

e.ServeHTTP(w, req)

require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, w.Header().Get(AccessControlAllowOriginHeader), tt.accessControlAllowOrigin)
})
}
}

func mockReqSigner(mockCtrl *gomock.Controller) *MockRequestSigner {
mockReqSigner := NewMockRequestSigner(mockCtrl)
mockReqSigner.EXPECT().SignRequest(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Request, error) {
Expand Down
13 changes: 13 additions & 0 deletions proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import (
log "github.com/sirupsen/logrus"
)

const (
AccessControlAllowOriginHeader string = "Access-Control-Allow-Origin"
)

func RecoverMiddleware(metricPublisher MetricPublisher) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
Expand Down Expand Up @@ -71,3 +75,12 @@ func LogAndMetricsMiddleware(metricPublisher MetricPublisher) gin.HandlerFunc {
}).Info("request summary")
}
}

func CORSMiddleware(accessControlAllowOrigin string) gin.HandlerFunc {
if accessControlAllowOrigin != "" {
return func(c *gin.Context) {
c.Writer.Header().Set(AccessControlAllowOriginHeader, accessControlAllowOrigin)
}
}
return func(_ *gin.Context) {}
}
1 change: 1 addition & 0 deletions proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func NewServer(cfg config.ServerConfig, handler Handler, metric MetricPublisher)
router.NoRoute(
RecoverMiddleware(metric),
LogAndMetricsMiddleware(metric),
CORSMiddleware(cfg.AccessControlAllowOrigin),
handler.ForwardRequest,
)

Expand Down
15 changes: 9 additions & 6 deletions test/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"testing"
"time"

"github.com/form3tech-oss/go-http-message-signatures"
httpsignatures "github.com/form3tech-oss/go-http-message-signatures"
"github.com/form3tech-oss/http-message-signing-proxy/cmd"
"github.com/stretchr/testify/suite"
)
Expand All @@ -37,6 +37,7 @@ const (

type e2eTestSuite struct {
suite.Suite
accessControlAllowOrigin string
}

func (s *e2eTestSuite) msgVerifier() *httpsignatures.MessageVerifier {
Expand Down Expand Up @@ -86,10 +87,11 @@ func (s *e2eTestSuite) runProxy(upstreamTarget string) {
rootCmd.SetArgs(append(
[]string{"--config", cfgFile},
genSetFlags(map[string]string{
"server.ssl.certFilePath": sslCertFile,
"server.ssl.keyFilePath": sslKeyFile,
"proxy.signer.keyFilePath": privateKeyFile,
"proxy.upstreamTarget": upstreamTarget,
"server.ssl.certFilePath": sslCertFile,
"server.ssl.keyFilePath": sslKeyFile,
"proxy.signer.keyFilePath": privateKeyFile,
"proxy.upstreamTarget": upstreamTarget,
"server.accessControlAllowOrigin": s.accessControlAllowOrigin,
})...,
))
go func() {
Expand Down Expand Up @@ -167,6 +169,7 @@ func (s *e2eTestSuite) TestProxy() {
r, err := http.DefaultClient.Do(req)
s.NoError(err)
s.Equal(test.expectedStatus, r.StatusCode)
s.Equal(s.accessControlAllowOrigin, r.Header.Get("Access-Control-Allow-Origin"))

if test.expectedStatus == http.StatusOK {
resp, err := readHttpResp[successResp](r)
Expand All @@ -186,7 +189,7 @@ func (s *e2eTestSuite) TestProxy() {
}

func TestE2ETestSuite(t *testing.T) {
suite.Run(t, new(e2eTestSuite))
suite.Run(t, &e2eTestSuite{accessControlAllowOrigin: "*"})
}

func genSetFlags(m map[string]string) []string {
Expand Down