Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow cors spec to be defined as regexps #723

Merged
merged 4 commits into from
Aug 11, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package cors

import (
"net/http"
"regexp"
"strings"

"golang.org/x/net/context"
Expand All @@ -21,10 +22,22 @@ const OriginKey key = "origin"

// MatchOrigin returns true if the given Origin header value matches the
// origin specification.
// Spec can be one of:
// - a plain string identifying an origin. eg http://swagger.goa.design
// - a plain string containing a wildcard. eg *.goa.design
// - the special string * that matches every host
func MatchOrigin(origin, spec string) bool {
if spec == "*" {
return true
}

// Check regular expression
if strings.HasPrefix(spec, "/") && strings.HasSuffix(spec, "/") {
stripped := strings.Trim(spec, "/")
r := regexp.MustCompile(stripped)
return r.Match([]byte(origin))
}

if !strings.Contains(spec, "*") {
return origin == spec
}
Expand All @@ -38,6 +51,13 @@ func MatchOrigin(origin, spec string) bool {
return true
}

// MatchOriginRegexp returns true if the given Origin header value matches the
// origin specification.
// Spec must be a valid regex
func MatchOriginRegexp(origin string, spec *regexp.Regexp) bool {
return spec.Match([]byte(origin))
}

// HandlePreflight returns a simple 200 response. The middleware takes care of handling CORS.
func HandlePreflight() goa.Handler {
return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
Expand Down
53 changes: 53 additions & 0 deletions cors/cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package cors_test

import (
"regexp"
"testing"

"github.com/goadesign/goa/cors"
)

func TestMatchOrigin(t *testing.T) {
data := []struct {
Origin string
Spec string
Result bool
}{
{"http://example.com", "*", true},
{"http://example.com", "http://example.com", true},
{"http://example.com", "https://example.com", false},
{"http://test.example.com", "*.example.com", true},
{"http://test.example.com:80", "*.example.com", false},
{"http://test.example.com:80", "http://test.example.com*", true},
}

for _, test := range data {
result := cors.MatchOrigin(test.Origin, test.Spec)
if result != test.Result {
t.Errorf("cors.MatchOrigin(%s, %s) should return %t", test.Origin, test.Spec, test.Result)
}
}
}

func TestMatchOriginRegexp(t *testing.T) {
data := []struct {
Origin string
Spec string
Result bool
}{
{"http://test.example.com:80", "(.*).example.com(.*)", true},
{"http://test.example.com:80", ".*.example.com.*", true},
{"http://test.example.com:80", ".*.other.com.*", false},
{"http://test.example.com", "[test|swag].example.com", true},
{"http://swag.example.com", "[test|swag].example.com", true},
{"http://other.example.com", "[test|swag].example.com", false},
{"http://other.example.com", "[test|swag].other.com", false},
}

for _, test := range data {
result := cors.MatchOriginRegexp(test.Origin, regexp.MustCompile(test.Spec))
if result != test.Result {
t.Errorf("cors.MatchOrigin(%s, %s) should return %t", test.Origin, test.Spec, test.Result)
}
}
}
11 changes: 10 additions & 1 deletion design/apidsl/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ func BasePath(val string) {

// Origin defines the CORS policy for a given origin. The origin can use a wildcard prefix
// such as "https://*.mydomain.com". The special value "*" defines the policy for all origins
// (in which case there should be only one Origin DSL in the parent resource). Example:
// (in which case there should be only one Origin DSL in the parent resource).
// The origin can also be a regular expression wrapped into "/".
// Example:
//
// Origin("http://swagger.goa.design", func() { // Define CORS policy, may be prefixed with "*" wildcard
// Headers("X-Shared-Secret") // One or more authorized headers, use "*" to authorize all
Expand All @@ -159,8 +161,15 @@ func BasePath(val string) {
// Credentials() // Sets Access-Control-Allow-Credentials header
// })
//
// Origin("/[api|swagger].goa.design/", func() {}) // Define CORS policy with a regular expression
func Origin(origin string, dsl func()) {
cors := &design.CORSDefinition{Origin: origin}

if strings.HasPrefix(origin, "/") && strings.HasSuffix(origin, "/") {
cors.Regexp = true
cors.Origin = strings.Trim(origin, "/")
}

if !dslengine.Execute(dsl, cors) {
return
}
Expand Down
2 changes: 2 additions & 0 deletions design/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ type (
MaxAge uint
// Sets Access-Control-Allow-Credentials header
Credentials bool
// Sets Whether the Origin string is a regular expression
Regexp bool
}

// EncodingDefinition defines an encoder supported by the API.
Expand Down
9 changes: 8 additions & 1 deletion design/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/url"
"os"
"path/filepath"
"regexp"
"sort"
"strings"

Expand Down Expand Up @@ -250,9 +251,15 @@ func (r *ResourceDefinition) validateParent(verr *dslengine.ValidationErrors) {
// Validate makes sure the CORS definition origin is valid.
func (cors *CORSDefinition) Validate() *dslengine.ValidationErrors {
verr := new(dslengine.ValidationErrors)
if strings.Count(cors.Origin, "*") > 1 {
if !cors.Regexp && strings.Count(cors.Origin, "*") > 1 {
verr.Add(cors, "invalid origin, can only contain one wildcard character")
}
if cors.Regexp {
_, err := regexp.Compile(cors.Origin)
if err != nil {
verr.Add(cors, "invalid origin, should be a valid regular expression")
}
}
return verr
}

Expand Down
1 change: 1 addition & 0 deletions goagen/gen_app/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ func (g *Generator) generateControllers() error {
codegen.SimpleImport("golang.org/x/net/context"),
codegen.SimpleImport("github.com/goadesign/goa"),
codegen.SimpleImport("github.com/goadesign/goa/cors"),
codegen.SimpleImport("regexp"),
}
encoders, err := BuildEncoders(g.API.Produces, true)
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion goagen/gen_app/writers.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,13 +695,15 @@ func Mount{{ .Resource }}Controller(service *goa.Service, ctrl {{ .Resource }}Co
// template input: *ControllerTemplateData
handleCORST = `// handle{{ .Resource }}Origin applies the CORS response headers corresponding to the origin.
func handle{{ .Resource }}Origin(h goa.Handler) goa.Handler {
{{ range $i, $policy := .Origins }}{{ if $policy.Regexp }} spec{{$i}} := regexp.MustCompile({{ printf "%q" $policy.Origin }})
{{ end }}{{ end }}
return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
origin := req.Header.Get("Origin")
if origin == "" {
// Not a CORS request
return h(ctx, rw, req)
}
{{ range $policy := .Origins }} if cors.MatchOrigin(origin, {{ printf "%q" $policy.Origin }}) {
{{ range $i, $policy := .Origins }} {{ if $policy.Regexp }}if cors.MatchOriginRegexp(origin, spec{{$i}}){{else}}if cors.MatchOrigin(origin, {{ printf "%q" $policy.Origin }}){{end}} {
ctx = goa.WithLogContext(ctx, "origin", origin)
rw.Header().Set("Access-Control-Allow-Origin", origin)
{{ if not (eq $policy.Origin "*") }} rw.Header().Set("Vary", "Origin")
Expand Down
78 changes: 78 additions & 0 deletions goagen/gen_app/writers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,42 @@ var _ = Describe("ControllersWriter", func() {
})
})

Context("with regexp origins", func() {
BeforeEach(func() {
actions = []string{"List"}
verbs = []string{"GET"}
paths = []string{"/accounts"}
contexts = []string{"ListBottleContext"}
origins = []*design.CORSDefinition{
{
Origin: "[here|there].example.com",
Headers: []string{"X-One", "X-Two"},
Methods: []string{"GET", "POST"},
Exposed: []string{"X-Three"},
Credentials: true,
Regexp: true,
},
{
Origin: "there.example.com",
Headers: []string{"*"},
Methods: []string{"*"},
},
}

})

It("writes the controller code", func() {
err := writer.Execute(data)
Ω(err).ShouldNot(HaveOccurred())
b, err := ioutil.ReadFile(filename)
Ω(err).ShouldNot(HaveOccurred())
written := string(b)
Ω(written).ShouldNot(BeEmpty())
Ω(written).Should(ContainSubstring(originsIntegration))
Ω(written).Should(ContainSubstring(regexpOriginsHandler))
})
})

})
})
})
Expand Down Expand Up @@ -1232,6 +1268,7 @@ type BottlesController interface {

originsHandler = `// handleBottlesOrigin applies the CORS response headers corresponding to the origin.
func handleBottlesOrigin(h goa.Handler) goa.Handler {

return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
origin := req.Header.Get("Origin")
if origin == "" {
Expand Down Expand Up @@ -1267,6 +1304,47 @@ func handleBottlesOrigin(h goa.Handler) goa.Handler {
return h(ctx, rw, req)
}
}
`

regexpOriginsHandler = `// handleBottlesOrigin applies the CORS response headers corresponding to the origin.
func handleBottlesOrigin(h goa.Handler) goa.Handler {
spec0 := regexp.MustCompile("[here|there].example.com")

return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
origin := req.Header.Get("Origin")
if origin == "" {
// Not a CORS request
return h(ctx, rw, req)
}
if cors.MatchOriginRegexp(origin, spec0) {
ctx = goa.WithLogContext(ctx, "origin", origin)
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Vary", "Origin")
rw.Header().Set("Access-Control-Expose-Headers", "X-Three")
rw.Header().Set("Access-Control-Allow-Credentials", "true")
if acrm := req.Header.Get("Access-Control-Request-Method"); acrm != "" {
// We are handling a preflight request
rw.Header().Set("Access-Control-Allow-Methods", "GET, POST")
rw.Header().Set("Access-Control-Allow-Headers", "X-One, X-Two")
}
return h(ctx, rw, req)
}
if cors.MatchOrigin(origin, "there.example.com") {
ctx = goa.WithLogContext(ctx, "origin", origin)
rw.Header().Set("Access-Control-Allow-Origin", origin)
rw.Header().Set("Vary", "Origin")
rw.Header().Set("Access-Control-Allow-Credentials", "false")
if acrm := req.Header.Get("Access-Control-Request-Method"); acrm != "" {
// We are handling a preflight request
rw.Header().Set("Access-Control-Allow-Methods", "*")
rw.Header().Set("Access-Control-Allow-Headers", "*")
}
return h(ctx, rw, req)
}

return h(ctx, rw, req)
}
}
`

encoderController = `
Expand Down