Skip to content

Commit

Permalink
Merge pull request #190 from ipfs/fix/api-post
Browse files Browse the repository at this point in the history
http: configurable allowed request methods for the API.
  • Loading branch information
Stebalien committed Apr 3, 2020
2 parents 59c18d0 + 2fbebbe commit 90182ce
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 17 deletions.
20 changes: 20 additions & 0 deletions http/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ type ServerConfig struct {
// Headers is an optional map of headers that is written out.
Headers map[string][]string

// HandledMethods set which methods will be handled for the HTTP
// requests. Other methods will return 405. This is different from CORS
// AllowedMethods (the API may handle GET and POST, but only allow GETs
// for CORS-enabled requests via AllowedMethods).
HandledMethods []string

// corsOpts is a set of options for CORS headers.
corsOpts *cors.Options

Expand All @@ -32,6 +38,7 @@ type ServerConfig struct {
func NewServerConfig() *ServerConfig {
cfg := new(ServerConfig)
cfg.corsOpts = new(cors.Options)
cfg.HandledMethods = []string{http.MethodPost}
return cfg
}

Expand Down Expand Up @@ -142,3 +149,16 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool {

return false
}

// handleRequestMethod returns true if the request method is among
// HandledMethods.
func handleRequestMethod(r *http.Request, cfg *ServerConfig) bool {
// For very small slices as these, this should be faster than
// a map lookup.
for _, m := range cfg.HandledMethods {
if r.Method == m {
return true
}
}
return false
}
16 changes: 14 additions & 2 deletions http/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"strings"
"testing"

"github.com/ipfs/go-ipfs-cmds"
cmds "github.com/ipfs/go-ipfs-cmds"
)

func TestErrors(t *testing.T) {
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestErrors(t *testing.T) {

mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) {
_, srv := getTestServer(t, nil) // handler_test:/^func getTestServer/
_, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot)
if err != nil {
Expand Down Expand Up @@ -158,3 +158,15 @@ func TestErrors(t *testing.T) {
t.Run(fmt.Sprintf("%d-%s", i, strings.Join(tc.path, "/")), mkTest(tc))
}
}

func TestUnhandledMethod(t *testing.T) {
tc := httpTestCase{
Method: "GET",
HandledMethods: []string{"POST"},
Code: http.StatusMethodNotAllowed,
ResHeaders: map[string]string{
"Allow": "POST",
},
}
tc.test(t)
}
15 changes: 15 additions & 0 deletions http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}()

// First of all, check if we are allowed to handle the request method
// or we are configured not to.
if !handleRequestMethod(r, h.cfg) {
setAllowedHeaders(w, h.cfg.HandledMethods)
http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed)
log.Warningf("The IPFS API does not support %s requests. All requests must use %s", h.cfg.HandledMethods)
return
}

if !allowOrigin(r, h.cfg) || !allowReferer(r, h.cfg) {
http.Error(w, "403 - Forbidden", http.StatusForbidden)
log.Warningf("API blocked request to %s. (possible CSRF)", r.URL)
Expand Down Expand Up @@ -170,3 +179,9 @@ func sanitizedErrStr(err error) string {
s = strings.Split(s, "\r")[0]
return s
}

func setAllowedHeaders(w http.ResponseWriter, methods []string) {
for _, m := range methods {
w.Header().Add("Allow", m)
}
}
12 changes: 10 additions & 2 deletions http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ var (
}
)

func getTestServer(t *testing.T, origins []string) (cmds.Environment, *httptest.Server) {
func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmds.Environment, *httptest.Server) {
if len(origins) == 0 {
origins = defaultOrigins
}
Expand All @@ -305,7 +305,15 @@ func getTestServer(t *testing.T, origins []string) (cmds.Environment, *httptest.
wait: make(chan struct{}),
}

return env, httptest.NewServer(NewHandler(env, cmdRoot, originCfg(origins)))
srvCfg := originCfg(origins)

if len(handledMethods) == 0 {
srvCfg.HandledMethods = []string{"GET", "POST"}
} else {
srvCfg.HandledMethods = handledMethods
}

return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg))
}

func errEq(err1, err2 error) bool {
Expand Down
6 changes: 3 additions & 3 deletions http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
"strings"
"testing"

"github.com/ipfs/go-ipfs-cmds"
cmds "github.com/ipfs/go-ipfs-cmds"

"github.com/ipfs/go-ipfs-files"
files "github.com/ipfs/go-ipfs-files"
)

func newReaderPathFile(t *testing.T, path string, reader io.ReadCloser, stat os.FileInfo) files.File {
Expand Down Expand Up @@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) {

mkTest := func(tc testcase) func(*testing.T) {
return func(t *testing.T) {
env, srv := getTestServer(t, nil) // handler_test:/^func getTestServer/
env, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
c := NewClient(srv.URL)
req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot)
if err != nil {
Expand Down
22 changes: 12 additions & 10 deletions http/reforigin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"net/url"
"testing"

"github.com/ipfs/go-ipfs-cmds"
cmds "github.com/ipfs/go-ipfs-cmds"
)

func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
Expand All @@ -27,6 +27,7 @@ func originCfg(origins []string) *ServerConfig {
cfg := NewServerConfig()
cfg.SetAllowedOrigins(origins...)
cfg.SetAllowedMethods("GET", "PUT", "POST")
cfg.HandledMethods = []string{"GET", "POST"}
return cfg
}

Expand All @@ -38,14 +39,15 @@ var defaultOrigins = []string{
}

type httpTestCase struct {
Method string
Path string
Code int
Origin string
Referer string
AllowOrigins []string
ReqHeaders map[string]string
ResHeaders map[string]string
Method string
Path string
Code int
Origin string
Referer string
AllowOrigins []string
HandledMethods []string
ReqHeaders map[string]string
ResHeaders map[string]string
}

func (tc *httpTestCase) test(t *testing.T) {
Expand Down Expand Up @@ -83,7 +85,7 @@ func (tc *httpTestCase) test(t *testing.T) {
}

// server
_, server := getTestServer(t, tc.AllowOrigins)
_, server := getTestServer(t, tc.AllowOrigins, tc.HandledMethods)
if server == nil {
return
}
Expand Down

0 comments on commit 90182ce

Please sign in to comment.