Skip to content

Commit

Permalink
0.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidCai1111 committed Nov 25, 2016
1 parent 014d716 commit cb0761e
Show file tree
Hide file tree
Showing 6 changed files with 473 additions and 53 deletions.
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,32 @@
# cors
# cors
[![Build Status](https://travis-ci.org/go-http-utils/cors.svg?branch=master)](https://travis-ci.org/go-http-utils/cors)
[![Coverage Status](https://coveralls.io/repos/github/go-http-utils/cors/badge.svg?branch=master)](https://coveralls.io/github/go-http-utils/cors?branch=master)

CORS middleware for Go.

## Installation

```
go get -u github.com/go-http-utils/cors
```

## Documentation

API documentation can be found here: https://godoc.org/github.com/go-http-utils/cors

## Usage

```go
import (
"github.com/go-http-utils/cors"
)
```

```go
mux := http.NewServeMux()
mux.HandleFunc("/", func(res http.ResponseWriter, req *http.Request) {
res.Write([]byte("Hello World"))
})

http.ListenAndServe(":8080", cors.Handler(mux))
```
58 changes: 33 additions & 25 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,46 @@ import (
)

// Version is this package's version
const Version = "0.0.1"
const Version = "0.1.0"

// Handle wraps the http.Handler with CORS support.
func Handle(h http.Handler, opts ...Option) http.Handler {
// Handler wraps the http.Handler with CORS support.
func Handler(h http.Handler, opts ...Option) http.Handler {
return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
option := &options{
exposeHeaders: defualtExposeHeaders,
methods: defualtMethods,
allowOrigins: []string{"*"},
allowOrigin: true,
methods: []string{
http.MethodGet,
http.MethodHead,
http.MethodPut,
http.MethodPost,
http.MethodDelete,
},
}

for _, opt := range opts {
opt(option)
}

passed := false
origin := req.Header.Get(headers.Origin)
origin := ""

if option.allowOriginValidator != nil {
passed = option.allowOriginValidator(req)
}
if option.allowOrigin {
origin = req.Header.Get(headers.Origin)

if !passed && option.allowOrigins != nil {
if allowAllOrigins(option.allowOrigins) {
passed = true
} else {
passed = has(option.allowOrigins, origin)
if origin == "" {
origin = "*"
}
} else if option.allowOriginValidator != nil {
origin = option.allowOriginValidator(req)
}

if !passed {
h.ServeHTTP(res, req)
if origin == "" {
return
}

resHeader := res.Header()

resHeader.Set(headers.AccessControlAllowOrigin, origin)

if len(option.exposeHeaders) > 0 {
resHeader.Set(headers.AccessControlExposeHeaders,
strings.Join(option.exposeHeaders, ","))
Expand All @@ -64,16 +67,21 @@ func Handle(h http.Handler, opts ...Option) http.Handler {
strings.Join(option.methods, ","))
}

// TODO: Allow headers
var allowHeaders string

if len(option.allowHeaders) > 0 {
allowHeaders = strings.Join(option.allowHeaders, ",")
} else {
allowHeaders = req.Header.Get(headers.AccessControlRequestHeaders)
}

resHeader.Set(headers.AccessControlAllowHeaders, allowHeaders)

if req.Method == http.MethodOptions {
res.WriteHeader(http.StatusNoContent)
} else {
h.ServeHTTP(res, req)
return
}
})
}

func allowAllOrigins(allowOrigins []string) bool {
return len(allowOrigins) == 1 && allowOrigins[0] == "*"
h.ServeHTTP(res, req)
})
}
239 changes: 239 additions & 0 deletions cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package cors

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/go-http-utils/headers"
"github.com/stretchr/testify/suite"
)

type CorsSuite struct {
suite.Suite
}

func (s *CorsSuite) TestDefaultAllowOrigin() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc)))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal("*", res.Header.Get(headers.AccessControlAllowOrigin))
}

func (s *CorsSuite) TestReflectAllowOrigin() {
origin := "helloworld.org"

mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc)))

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.Origin, origin)

res, err := sendRequest(req)

s.Nil(err)
s.Equal(origin, res.Header.Get(headers.AccessControlAllowOrigin))
}

func (s *CorsSuite) TestValidatorAllowOrigin() {
origin := "helloworld.org"
validator := func(req *http.Request) string {
return origin
}

mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetAllowOriginValidator(validator), SetAllowOrigin(false)))

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.Origin, "test.org")

res, err := sendRequest(req)

s.Nil(err)
s.Equal(origin, res.Header.Get(headers.AccessControlAllowOrigin))
}

func (s *CorsSuite) TestNotAllowOrigin() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetAllowOrigin(false)))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal("", res.Header.Get(headers.AccessControlAllowOrigin))
s.Equal("", res.Header.Get(headers.AccessControlAllowMethods))
}

func (s *CorsSuite) TestEmptyExpose() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc)))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal("", res.Header.Get(headers.AccessControlExposeHeaders))
}

func (s *CorsSuite) TestExpose() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetExposeHeaders([]string{headers.AcceptRanges, headers.AcceptDatetime})))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal(headers.AcceptRanges+","+headers.AcceptDatetime,
res.Header.Get(headers.AccessControlExposeHeaders))
}

func (s *CorsSuite) TestMaxAge() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetMaxAge(600)))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal("600", res.Header.Get(headers.AccessControlMaxAge))
}

func (s *CorsSuite) TestDefualtMethods() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc)))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal("GET,HEAD,PUT,POST,DELETE",
res.Header.Get(headers.AccessControlAllowMethods))
}

func (s *CorsSuite) TestMethods() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetMethods([]string{http.MethodHead, http.MethodTrace})))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal("HEAD,TRACE", res.Header.Get(headers.AccessControlAllowMethods))
}

func (s *CorsSuite) TestCredentials() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetCredentials(true)))

server := httptest.NewServer(mux)

res, err := http.Get(server.URL + "/")

s.Nil(err)
s.Equal(http.StatusOK, res.StatusCode)
s.Equal("true", res.Header.Get(headers.AccessControlAllowCredentials))
}

func (s *CorsSuite) TestDefualtAllowHeader() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc)))

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.AccessControlRequestHeaders, "FOO-BAR")

res, err := sendRequest(req)

s.Nil(err)
s.Equal("FOO-BAR", res.Header.Get(headers.AccessControlAllowHeaders))
}

func (s *CorsSuite) TestAllowHeader() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc),
SetAllowHeaders([]string{"FOO", "BAR"})))

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodGet, server.URL+"/", nil)

s.Nil(err)

req.Header.Set(headers.AccessControlRequestHeaders, "FOO-BAR")

res, err := sendRequest(req)

s.Nil(err)
s.Equal("Foo,Bar", res.Header.Get(headers.AccessControlAllowHeaders))
}

func (s *CorsSuite) TestOptionRequest() {
mux := http.NewServeMux()
mux.Handle("/", Handler(http.HandlerFunc(helloHandlerFunc)))

server := httptest.NewServer(mux)

req, err := http.NewRequest(http.MethodOptions, server.URL+"/", nil)

s.Nil(err)

res, err := sendRequest(req)

s.Nil(err)
s.Equal(http.StatusNoContent, res.StatusCode)
}

func TestCors(t *testing.T) {
suite.Run(t, new(CorsSuite))
}

func helloHandlerFunc(res http.ResponseWriter, req *http.Request) {
res.WriteHeader(http.StatusOK)

res.Write([]byte("Hello World"))
}

func sendRequest(req *http.Request) (*http.Response, error) {
cli := &http.Client{}

return cli.Do(req)
}
Loading

0 comments on commit cb0761e

Please sign in to comment.