Skip to content

Commit

Permalink
Add CORSMethodMiddleware (#366)
Browse files Browse the repository at this point in the history
CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
on a request, by matching routes based only on paths. It also handles
OPTIONS requests, by settings Access-Control-Allow-Methods, and then
returning without calling the next HTTP handler.
  • Loading branch information
fharding1 authored and elithrar committed May 12, 2018
1 parent ded0c29 commit 5e55a4a
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
44 changes: 43 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package mux

import "net/http"
import (
"net/http"
"strings"
)

// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
Expand Down Expand Up @@ -28,3 +31,42 @@ func (r *Router) Use(mwf ...MiddlewareFunc) {
func (r *Router) useInterface(mw middleware) {
r.middlewares = append(r.middlewares, mw)
}

// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
// on a request, by matching routes based only on paths. It also handles
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
// returning without calling the next http handler.
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var allMethods []string

err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}

allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})

if err == nil {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))

if req.Method == "OPTIONS" {
return
}
}

next.ServeHTTP(w, req)
})
}
}
41 changes: 41 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mux
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)

Expand Down Expand Up @@ -334,3 +335,43 @@ func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
t.Fatal("Middleware was called for a method mismatch")
}
}

func TestCORSMethodMiddleware(t *testing.T) {
router := NewRouter()

cases := []struct {
path string
response string
method string
testURL string
expectedAllowedMethods string
}{
{"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
{"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
{"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
{"/g", "d", "POST", "/g", "POST,OPTIONS"},
}

for _, tt := range cases {
router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
}

router.Use(CORSMethodMiddleware(router))

for _, tt := range cases {
rr := httptest.NewRecorder()
req := newRequest(tt.method, tt.testURL)

router.ServeHTTP(rr, req)

if rr.Body.String() != tt.response {
t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
}

allowedMethods := rr.HeaderMap.Get("Access-Control-Allow-Methods")

if allowedMethods != tt.expectedAllowedMethods {
t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
}
}
}
8 changes: 8 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2315,6 +2315,14 @@ func stringMapEqual(m1, m2 map[string]string) bool {
return true
}

// stringHandler returns a handler func that writes a message 's' to the
// http.ResponseWriter.
func stringHandler(s string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(s))
}
}

// newRequest is a helper function to create a new request with a method and url.
// The request returned is a 'server' request as opposed to a 'client' one through
// simulated write onto the wire and read off of the wire.
Expand Down

0 comments on commit 5e55a4a

Please sign in to comment.