Skip to content

Commit

Permalink
optimize implementation && tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidCai1111 committed Jan 12, 2017
1 parent 7299421 commit dea9e42
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 63 deletions.
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
test:
go test -v
go test -v -race

cover:
rm -rf *.coverprofile
go test -coverprofile=fresh.coverprofile
gover
go tool cover -html=fresh.coverprofile
go tool cover -html=fresh.coverprofile
rm -rf *.coverprofile
119 changes: 74 additions & 45 deletions cors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cors

import (
"fmt"
"net/http"
"strconv"
"strings"
Expand All @@ -9,79 +10,107 @@ import (
)

// Version is this package's version
const Version = "0.1.0"
const Version = "1.0.0"

// Handler wraps the http.Handler with CORS support.
func Handler(h http.Handler, opts ...Option) http.Handler {
option := &options{
allowOrigin: true,
methods: []string{
http.MethodGet,
http.MethodHead,
http.MethodPut,
http.MethodPost,
http.MethodDelete,
},
}

for _, opt := range opts {
opt(option)
}

return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
option := &options{
allowOrigin: true,
methods: []string{
http.MethodGet,
http.MethodHead,
http.MethodPut,
http.MethodPost,
http.MethodDelete,
},
}
origin := req.Header.Get(headers.Origin)

for _, opt := range opts {
opt(option)
}
// Not a CORS request.
if origin == "" {
h.ServeHTTP(res, req)

origin := ""
return
}

if option.allowOrigin {
origin = req.Header.Get(headers.Origin)
allowOrigin := ""

if origin == "" {
origin = "*"
}
} else if option.allowOriginValidator != nil {
origin = option.allowOriginValidator(req)
if option.allowOriginValidator != nil {
allowOrigin = option.allowOriginValidator(req)
} else {
allowOrigin = req.Header.Get(headers.Origin)
}

if origin == "" {
if allowOrigin == "" {
res.WriteHeader(http.StatusForbidden)
res.Write([]byte(fmt.Sprintf("Invalid origin %v", origin)))
return
}

resHeader := res.Header()

resHeader.Set(headers.AccessControlAllowOrigin, origin)
if allowOrigin != "*" {
resHeader.Add(headers.Vary, headers.Origin)

if len(option.exposeHeaders) > 0 {
resHeader.Set(headers.AccessControlExposeHeaders,
strings.Join(option.exposeHeaders, ","))
if option.credentials {
// When responding to a credentialed request, server must specify a
// domain, and cannot use wild carding.
// See *important note* in https://developer.mozilla.org/en-US/docs/Web/HTTP/Access_control_CORS#Requests_with_credentials .
resHeader.Set(headers.AccessControlAllowCredentials, "true")
}
}

if option.maxAge > 0 {
resHeader.Set(headers.AccessControlMaxAge, strconv.Itoa(option.maxAge))
}
resHeader.Set(headers.AccessControlAllowOrigin, allowOrigin)

if option.credentials == true {
resHeader.Set(headers.AccessControlAllowCredentials, "true")
}
// Preflighted requests
if req.Method == http.MethodOptions {
requestMethod := req.Header.Get(headers.AccessControlRequestMethod)

if len(option.methods) > 0 {
resHeader.Set(headers.AccessControlAllowMethods,
strings.Join(option.methods, ","))
}
if requestMethod == "" {
resHeader.Del(headers.AccessControlAllowOrigin)
resHeader.Del(headers.AccessControlAllowCredentials)

var allowHeaders string
res.WriteHeader(http.StatusForbidden)
res.Write([]byte("Invalid preflighted request, missing Access-Control-Request-Method header"))

if len(option.allowHeaders) > 0 {
allowHeaders = strings.Join(option.allowHeaders, ",")
} else {
allowHeaders = req.Header.Get(headers.AccessControlRequestHeaders)
}
return
}

resHeader.Set(headers.AccessControlAllowHeaders, allowHeaders)
if len(option.methods) > 0 {
resHeader.Set(headers.AccessControlAllowMethods,
strings.Join(option.methods, ","))
}

var allowHeaders string

if len(option.allowHeaders) > 0 {
allowHeaders = strings.Join(option.allowHeaders, ",")
} else {
allowHeaders = req.Header.Get(headers.AccessControlRequestHeaders)
}

resHeader.Set(headers.AccessControlAllowHeaders, allowHeaders)

if option.maxAge > 0 {
resHeader.Set(headers.AccessControlMaxAge, strconv.Itoa(option.maxAge))
}

if req.Method == http.MethodOptions {
res.WriteHeader(http.StatusNoContent)

return
}

if len(option.exposeHeaders) > 0 {
resHeader.Set(headers.AccessControlExposeHeaders,
strings.Join(option.exposeHeaders, ","))
}

h.ServeHTTP(res, req)
})
}
104 changes: 88 additions & 16 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@ func (s *CorsSuite) TestDefaultAllowOrigin() {

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")
req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)
req.Header.Set(headers.Origin, "test.org")

s.Nil(err)

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal("*", res.Header.Get(headers.AccessControlAllowOrigin))
s.Equal("test.org", res.Header.Get(headers.AccessControlAllowOrigin))
}

func (s *CorsSuite) TestReflectAllowOrigin() {
Expand Down Expand Up @@ -105,7 +110,13 @@ func (s *CorsSuite) TestExpose() {

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")
req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.Origin, "test.rog")

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
Expand All @@ -120,10 +131,17 @@ func (s *CorsSuite) TestMaxAge() {

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")
req, err := http.NewRequest(http.MethodOptions, server.URL+"/", nil)

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)

req.Header.Set(headers.Origin, "test.rog")
req.Header.Set(headers.AccessControlRequestMethod, http.MethodPatch)

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusNoContent, res.StatusCode)
s.Equal("600", res.Header.Get(headers.AccessControlMaxAge))
}

Expand All @@ -133,10 +151,18 @@ func (s *CorsSuite) TestDefualtMethods() {

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")
req, err := http.NewRequest(http.MethodOptions, server.URL+"/", nil)

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)

req.Header.Set(headers.Origin, "test.rog")
req.Header.Set(headers.AccessControlRequestMethod, http.MethodPost)
req.Header.Set(headers.AccessControlRequestHeaders, "FOO-BAR")

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusNoContent, res.StatusCode)
s.Equal("GET,HEAD,PUT,POST,DELETE",
res.Header.Get(headers.AccessControlAllowMethods))
}
Expand All @@ -148,21 +174,37 @@ func (s *CorsSuite) TestMethods() {

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")
req, err := http.NewRequest(http.MethodOptions, server.URL+"/", nil)

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)

req.Header.Set(headers.Origin, "test.rog")
req.Header.Set(headers.AccessControlRequestMethod, http.MethodPost)

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusNoContent, res.StatusCode)
s.Equal("HEAD,TRACE", res.Header.Get(headers.AccessControlAllowMethods))
}

func (s *CorsSuite) TestCredentials() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetCredentials(true)))
SetCredentials(true), SetAllowOriginValidator(func(req *http.Request) string {
return "test.org"
})))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")
req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)
req.Header.Set(headers.Origin, "test.rog")

s.Nil(err)

req.Header.Set(headers.AccessControlRequestHeaders, "FOO-BAR")

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
Expand All @@ -175,10 +217,12 @@ func (s *CorsSuite) TestDefualtAllowHeader() {

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)
req, err := http.NewRequest(http.MethodOptions, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.Origin, "test.org")
req.Header.Set(headers.AccessControlRequestMethod, http.MethodPut)
req.Header.Set(headers.AccessControlRequestHeaders, "FOO-BAR")

res, err := sendRequest(req)
Expand All @@ -194,10 +238,12 @@ func (s *CorsSuite) TestAllowHeader() {

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)
req, err := http.NewRequest(http.MethodOptions, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.Origin, "test.org")
req.Header.Set(headers.AccessControlRequestMethod, http.MethodPut)
req.Header.Set(headers.AccessControlRequestHeaders, "FOO-BAR")

res, err := sendRequest(req)
Expand All @@ -206,20 +252,46 @@ func (s *CorsSuite) TestAllowHeader() {
s.Equal("Foo,Bar", res.Header.Get(headers.AccessControlAllowHeaders))
}

func (s *CorsSuite) TestOptionRequest() {
func (s *CorsSuite) TestLackAllowMethod() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc)))
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetAllowHeaders([]string{"FOO", "BAR"})))

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodOptions, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.Origin, "test.org")

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusForbidden, res.StatusCode)
}

func (s *CorsSuite) TestOriginNotAllow() {
validator := func(req *http.Request) string {
return ""
}

mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetAllowOriginValidator(validator)))

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodOptions, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.Origin, "test.org")

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusNoContent, res.StatusCode)
s.Equal(http.StatusForbidden, res.StatusCode)
}

func TestCors(t *testing.T) {
Expand Down

0 comments on commit dea9e42

Please sign in to comment.