Skip to content

Commit

Permalink
Merge 060d2fa into 97b4a9d
Browse files Browse the repository at this point in the history
  • Loading branch information
cadigun committed Jul 9, 2020
2 parents 97b4a9d + 060d2fa commit de55803
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 7 deletions.
58 changes: 54 additions & 4 deletions cors.go
Expand Up @@ -85,9 +85,9 @@ type Config struct {
// One time, do the conversion from our the public facing Configuration,
// to all the formats we use internally strings for headers.. slices for looping
func (config *Config) prepare() {
config.origins = strings.Split(config.Origins, ", ")
config.methods = strings.Split(config.Methods, ", ")
config.requestHeaders = strings.Split(config.RequestHeaders, ", ")
config.origins = strings.Split(strings.ReplaceAll(config.Origins, " ", ""), ",")
config.methods = strings.Split(strings.ReplaceAll(config.Methods, " ", ""), ",")
config.requestHeaders = strings.Split(strings.ReplaceAll(config.RequestHeaders, " ", ""), ",")
config.maxAge = fmt.Sprintf("%.f", config.MaxAge.Seconds())

// Generates a boolean of value "true".
Expand Down Expand Up @@ -205,13 +205,63 @@ func handleRequest(context *gin.Context, config Config) bool {
// Case-sensitive match of origin header
func matchOrigin(origin string, config Config) bool {
for _, value := range config.origins {
if value == origin {
if matchString(value, origin) {
return true
}
}
return false
}

func matchString(pattern string, str string) bool {
if pattern == str {
return true
}
EOF := len(str)
final := len(pattern)
skip := -1
for i := 0; i < final; i++ {
if pattern[i] == '*' {
skip = i
}
}
if skip == -1 && pattern != str {
return false
}
nstr := EOF - (final - skip) + 1
if nstr < 0 {
// input shorter than pattern
return false
}
if skip != -1 && pattern[skip+1:final] != str[nstr:EOF] {
return false
}

loopback := -1
current := 0
for cursor := 0; cursor < EOF && current != final; cursor++ {
if str[cursor] == pattern[current] {
current++
} else if pattern[current] == '*' {
loopback = current
if current+1 < final && pattern[current+1] == str[cursor] ||
current+1 < final && pattern[current+1] == '*' {
current++
cursor--
}
} else if loopback == -1 {
break
} else {
current = loopback
cursor--
}
}

if current == skip && skip+1 == final {
return pattern[current] == '*'
}
return current == final
}

// Case-sensitive match of request method
func validateRequestMethod(requestMethod string, config Config) bool {
if !config.ValidateHeaders {
Expand Down
42 changes: 39 additions & 3 deletions cors_test.go
Expand Up @@ -102,6 +102,25 @@ func TestMismatchOrigin(t *testing.T) {
}
}

func TestWildMismatchOrigin(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()

req.Header.Set("Origin", "http://files.testing.com")

router := gin.New()

router.Use(Middleware(Config{
Origins: "http://*testing.io/*, http://sample.testing.com/*, http://this-is-not-a-typical-short-url.*.testing.com/expected-prefix",
}))

router.ServeHTTP(w, req)

if w.Header().Get(AllowOriginKey) != "" {
t.Fatal("This should not match.")
}
}

func TestPreflightRequest(t *testing.T) {
req, _ := http.NewRequest("OPTIONS", "/", nil)
w := httptest.NewRecorder()
Expand Down Expand Up @@ -192,6 +211,23 @@ func TestMatchOrigin(t *testing.T) {
}
}

func TestWildMatchOrigin(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()

req.Header.Set("Origin", "http://files.testing.com")

router := gin.New()
router.Use(Middleware(Config{
Origins: "http://files.*testing, *://files.testing*",
}))
router.ServeHTTP(w, req)

if w.Header().Get(AllowOriginKey) == "" {
t.Fatal("Origin matches, this header should be set.")
}
}

func TestForceOrigin(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
Expand All @@ -210,9 +246,9 @@ func TestForceOrigin(t *testing.T) {
func TestForceOriginCredentails(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()

req.Header.Set("Origin", "http://localhost")

router := gin.New()
router.Use(Middleware(Config{
Origins: "http://localhost",
Expand All @@ -224,7 +260,7 @@ func TestForceOriginCredentails(t *testing.T) {
MaxAge: 1 * time.Minute,
}))
router.ServeHTTP(w, req)

if w.Header().Get(AllowOriginKey) != "http://localhost" {
t.Fatal("Improper Origin is set.")
}
Expand Down

0 comments on commit de55803

Please sign in to comment.