From 322e87dd59bf305b461c1167b38527b7a8134452 Mon Sep 17 00:00:00 2001 From: Hector Sanjuan Date: Tue, 16 Oct 2018 15:23:06 +0200 Subject: [PATCH 1/2] Restapi: Add configurable response headers By default, CORS headers allowing GET requests from everywhere are set. This should facilitate the IPFS Web UI integration with the Cluster API. This commit refactors the sendResponse methods in the API, merging them into one as it was difficult to follow the flows that actually send something to the client. All tests now check the presence of the configured headers too, to make sure no route was missed. License: MIT Signed-off-by: Hector Sanjuan --- api/rest/config.go | 21 ++++- api/rest/restapi.go | 181 +++++++++++++++++++++------------------ api/rest/restapi_test.go | 16 ++++ config_test.go | 14 ++- 4 files changed, 145 insertions(+), 87 deletions(-) diff --git a/api/rest/config.go b/api/rest/config.go index 5902d1e6a..ed95146c9 100644 --- a/api/rest/config.go +++ b/api/rest/config.go @@ -27,6 +27,15 @@ const ( DefaultIdleTimeout = 120 * time.Second ) +// These are the default values for Config. +var ( + DefaultHeaders = map[string][]string{ + "Access-Control-Allow-Headers": []string{"X-Requested-With", "Range"}, + "Access-Control-Allow-Methods": []string{"GET"}, + "Access-Control-Allow-Origin": []string{"*"}, + } +) + // Config is used to intialize the API object and allows to // customize the behaviour of it. It implements the config.ComponentConfig // interface. @@ -71,6 +80,10 @@ type Config struct { // BasicAuthCreds is a map of username-password pairs // which are authorized to use Basic Authentication BasicAuthCreds map[string]string + + // Headers provides customization for the headers returned + // by the API. By default it sets a CORS policy. + Headers map[string][]string } type jsonConfig struct { @@ -87,7 +100,8 @@ type jsonConfig struct { ID string `json:"id,omitempty"` PrivateKey string `json:"private_key,omitempty"` - BasicAuthCreds map[string]string `json:"basic_auth_credentials"` + BasicAuthCreds map[string]string `json:"basic_auth_credentials"` + Headers map[string][]string `json:"headers"` } // ConfigKey returns a human-friendly identifier for this type of @@ -116,6 +130,9 @@ func (cfg *Config) Default() error { // Auth cfg.BasicAuthCreds = nil + // Headers + cfg.Headers = DefaultHeaders + return nil } @@ -177,6 +194,7 @@ func (cfg *Config) LoadJSON(raw []byte) error { // Other options cfg.BasicAuthCreds = jcfg.BasicAuthCreds + cfg.Headers = jcfg.Headers return cfg.Validate() } @@ -295,6 +313,7 @@ func (cfg *Config) ToJSON() (raw []byte, err error) { WriteTimeout: cfg.WriteTimeout.String(), IdleTimeout: cfg.IdleTimeout.String(), BasicAuthCreds: cfg.BasicAuthCreds, + Headers: cfg.Headers, } if cfg.ID != "" { diff --git a/api/rest/restapi.go b/api/rest/restapi.go index 25e50d0c2..acc303c94 100644 --- a/api/rest/restapi.go +++ b/api/rest/restapi.go @@ -55,6 +55,9 @@ var ( ErrHTTPEndpointNotEnabled = errors.New("the HTTP endpoint is not enabled") ) +// Used by sendResponse to set the right status +const autoStatus = -1 + // For making a random sharding ID var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") @@ -479,7 +482,7 @@ func (api *API) idHandler(w http.ResponseWriter, r *http.Request) { struct{}{}, &idSerial) - sendResponse(w, err, idSerial) + api.sendResponse(w, autoStatus, err, idSerial) } func (api *API) versionHandler(w http.ResponseWriter, r *http.Request) { @@ -490,7 +493,7 @@ func (api *API) versionHandler(w http.ResponseWriter, r *http.Request) { struct{}{}, &v) - sendResponse(w, err, v) + api.sendResponse(w, autoStatus, err, v) } func (api *API) graphHandler(w http.ResponseWriter, r *http.Request) { @@ -500,22 +503,24 @@ func (api *API) graphHandler(w http.ResponseWriter, r *http.Request) { "ConnectGraph", struct{}{}, &graph) - sendResponse(w, err, graph) + api.sendResponse(w, autoStatus, err, graph) } func (api *API) addHandler(w http.ResponseWriter, r *http.Request) { reader, err := r.MultipartReader() if err != nil { - sendErrorResponse(w, http.StatusBadRequest, err.Error()) + api.sendResponse(w, http.StatusBadRequest, err, nil) return } params, err := types.AddParamsFromQuery(r.URL.Query()) if err != nil { - sendErrorResponse(w, http.StatusBadRequest, err.Error()) + api.sendResponse(w, http.StatusBadRequest, err, nil) return } + api.setHeaders(w) + // any errors sent as trailer adderutils.AddMultipartHTTPHandler( api.ctx, @@ -537,7 +542,7 @@ func (api *API) peerListHandler(w http.ResponseWriter, r *http.Request) { struct{}{}, &peersSerial) - sendResponse(w, err, peersSerial) + api.sendResponse(w, autoStatus, err, peersSerial) } func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) { @@ -547,13 +552,13 @@ func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) { var addInfo peerAddBody err := dec.Decode(&addInfo) if err != nil { - sendErrorResponse(w, 400, "error decoding request body") + api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding request body"), nil) return } _, err = peer.IDB58Decode(addInfo.PeerID) if err != nil { - sendErrorResponse(w, 400, "error decoding peer_id") + api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding peer_id"), nil) return } @@ -563,22 +568,22 @@ func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) { "PeerAdd", addInfo.PeerID, &ids) - sendResponse(w, err, ids) + api.sendResponse(w, autoStatus, err, ids) } func (api *API) peerRemoveHandler(w http.ResponseWriter, r *http.Request) { - if p := parsePidOrError(w, r); p != "" { + if p := api.parsePidOrError(w, r); p != "" { err := api.rpcClient.Call("", "Cluster", "PeerRemove", p, &struct{}{}) - sendEmptyResponse(w, err) + api.sendResponse(w, autoStatus, err, nil) } } func (api *API) pinHandler(w http.ResponseWriter, r *http.Request) { - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { logger.Debugf("rest api pinHandler: %s", ps.Cid) err := api.rpcClient.Call("", @@ -586,20 +591,20 @@ func (api *API) pinHandler(w http.ResponseWriter, r *http.Request) { "Pin", ps, &struct{}{}) - sendAcceptedResponse(w, err) + api.sendResponse(w, http.StatusAccepted, err, nil) logger.Debug("rest api pinHandler done") } } func (api *API) unpinHandler(w http.ResponseWriter, r *http.Request) { - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { logger.Debugf("rest api unpinHandler: %s", ps.Cid) err := api.rpcClient.Call("", "Cluster", "Unpin", ps, &struct{}{}) - sendAcceptedResponse(w, err) + api.sendResponse(w, http.StatusAccepted, err, nil) logger.Debug("rest api unpinHandler done") } } @@ -626,11 +631,11 @@ func (api *API) allocationsHandler(w http.ResponseWriter, r *http.Request) { outPins = append(outPins, pinS) } } - sendResponse(w, err, outPins) + api.sendResponse(w, autoStatus, err, outPins) } func (api *API) allocationHandler(w http.ResponseWriter, r *http.Request) { - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { var pin types.PinSerial err := api.rpcClient.Call("", "Cluster", @@ -638,10 +643,10 @@ func (api *API) allocationHandler(w http.ResponseWriter, r *http.Request) { ps, &pin) if err != nil { // errors here are 404s - sendErrorResponse(w, 404, err.Error()) + api.sendResponse(w, http.StatusNotFound, err, nil) return } - sendJSONResponse(w, 200, pin) + api.sendResponse(w, autoStatus, nil, pin) } } @@ -656,7 +661,7 @@ func (api *API) statusAllHandler(w http.ResponseWriter, r *http.Request) { "StatusAllLocal", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfosToGlobal(pinInfos)) + api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos)) } else { var pinInfos []types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -664,7 +669,7 @@ func (api *API) statusAllHandler(w http.ResponseWriter, r *http.Request) { "StatusAll", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfos) + api.sendResponse(w, autoStatus, err, pinInfos) } } @@ -672,7 +677,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) { queryValues := r.URL.Query() local := queryValues.Get("local") - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { if local == "true" { var pinInfo types.PinInfoSerial err := api.rpcClient.Call("", @@ -680,7 +685,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) { "StatusLocal", ps, &pinInfo) - sendResponse(w, err, pinInfoToGlobal(pinInfo)) + api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo)) } else { var pinInfo types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -688,7 +693,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) { "Status", ps, &pinInfo) - sendResponse(w, err, pinInfo) + api.sendResponse(w, autoStatus, err, pinInfo) } } } @@ -704,7 +709,7 @@ func (api *API) syncAllHandler(w http.ResponseWriter, r *http.Request) { "SyncAllLocal", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfosToGlobal(pinInfos)) + api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos)) } else { var pinInfos []types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -712,7 +717,7 @@ func (api *API) syncAllHandler(w http.ResponseWriter, r *http.Request) { "SyncAll", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfos) + api.sendResponse(w, autoStatus, err, pinInfos) } } @@ -720,7 +725,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) { queryValues := r.URL.Query() local := queryValues.Get("local") - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { if local == "true" { var pinInfo types.PinInfoSerial err := api.rpcClient.Call("", @@ -728,7 +733,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) { "SyncLocal", ps, &pinInfo) - sendResponse(w, err, pinInfoToGlobal(pinInfo)) + api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo)) } else { var pinInfo types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -736,7 +741,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) { "Sync", ps, &pinInfo) - sendResponse(w, err, pinInfo) + api.sendResponse(w, autoStatus, err, pinInfo) } } } @@ -751,9 +756,9 @@ func (api *API) recoverAllHandler(w http.ResponseWriter, r *http.Request) { "RecoverAllLocal", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfosToGlobal(pinInfos)) + api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos)) } else { - sendErrorResponse(w, 400, "only requests with parameter local=true are supported") + api.sendResponse(w, http.StatusBadRequest, errors.New("only requests with parameter local=true are supported"), nil) } } @@ -761,7 +766,7 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) { queryValues := r.URL.Query() local := queryValues.Get("local") - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { if local == "true" { var pinInfo types.PinInfoSerial err := api.rpcClient.Call("", @@ -769,7 +774,7 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) { "RecoverLocal", ps, &pinInfo) - sendResponse(w, err, pinInfoToGlobal(pinInfo)) + api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo)) } else { var pinInfo types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -777,18 +782,18 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) { "Recover", ps, &pinInfo) - sendResponse(w, err, pinInfo) + api.sendResponse(w, autoStatus, err, pinInfo) } } } -func parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial { +func (api *API) parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial { vars := mux.Vars(r) hash := vars["hash"] _, err := cid.Decode(hash) if err != nil { - sendErrorResponse(w, 400, "error decoding Cid: "+err.Error()) + api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding Cid: "+err.Error()), nil) return types.PinSerial{Cid: ""} } @@ -827,12 +832,12 @@ func parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial { return pin } -func parsePidOrError(w http.ResponseWriter, r *http.Request) peer.ID { +func (api *API) parsePidOrError(w http.ResponseWriter, r *http.Request) peer.ID { vars := mux.Vars(r) idStr := vars["peer"] pid, err := peer.IDB58Decode(idStr) if err != nil { - sendErrorResponse(w, 400, "error decoding Peer ID: "+err.Error()) + api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding Peer ID: "+err.Error()), nil) return "" } return pid @@ -855,64 +860,70 @@ func pinInfosToGlobal(pInfos []types.PinInfoSerial) []types.GlobalPinInfoSerial return gPInfos } -func sendResponse(w http.ResponseWriter, err error, resp interface{}) { - if checkErr(w, err) { - sendJSONResponse(w, 200, resp) - } -} +// sendResponse wraps all the logic for writing the response to a request: +// * Write configured headers +// * Write application/json content type +// * Write status: determined automatically if given 0 +// * Write an error if there is or write the response if there is +func (api *API) sendResponse( + w http.ResponseWriter, + status int, + err error, + resp interface{}, +) { -// checkErr takes care of returning standard error responses if we -// pass an error to it. It returns true when everythings OK (no error -// was handled), or false otherwise. -func checkErr(w http.ResponseWriter, err error) bool { + api.setHeaders(w) + enc := json.NewEncoder(w) + + // Send an error if err != nil { - sendErrorResponse(w, http.StatusInternalServerError, err.Error()) - return false - } - return true -} + if status == autoStatus || status < 400 { // set a default error status + status = http.StatusInternalServerError + } + w.WriteHeader(status) -func sendEmptyResponse(w http.ResponseWriter, err error) { - if checkErr(w, err) { - w.WriteHeader(http.StatusNoContent) - } -} + errorResp := types.Error{ + Code: status, + Message: err.Error(), + } + logger.Errorf("sending error response: %d: %s", status, err.Error()) -func sendAcceptedResponse(w http.ResponseWriter, err error) { - if checkErr(w, err) { - w.WriteHeader(http.StatusAccepted) + if err := enc.Encode(errorResp); err != nil { + logger.Error(err) + } + return } -} -func sendJSONResponse(w http.ResponseWriter, code int, resp interface{}) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(code) - if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Error(err) - } -} + // Send a body + if resp != nil { + if status == autoStatus { + status = http.StatusOK + } -func sendErrorResponse(w http.ResponseWriter, code int, msg string) { - errorResp := types.Error{ - Code: code, - Message: msg, - } - logger.Errorf("sending error response: %d: %s", code, msg) - sendJSONResponse(w, code, errorResp) -} + w.WriteHeader(status) -func sendStreamResponse(w http.ResponseWriter, err error, resp <-chan interface{}) { - if !checkErr(w, err) { + if err = enc.Encode(resp); err != nil { + logger.Error(err) + } return } - enc := json.NewEncoder(w) - w.Header().Add("Content-Type", "application/octet-stream") - w.WriteHeader(http.StatusOK) - for v := range resp { - err := enc.Encode(v) - if err != nil { - logger.Error(err) + // Empty response + if status == autoStatus { + status = http.StatusNoContent + } + + w.WriteHeader(status) +} + +// this sets all the headers that are common to all responses +// from this API. Called from sendResponse() and /add. +func (api *API) setHeaders(w http.ResponseWriter) { + for header, values := range api.config.Headers { + for _, val := range values { + w.Header().Add(header, val) } } + + w.Header().Add("Content-Type", "application/json") } diff --git a/api/rest/restapi_test.go b/api/rest/restapi_test.go index 849543a29..7de0a3fa1 100644 --- a/api/rest/restapi_test.go +++ b/api/rest/restapi_test.go @@ -124,6 +124,17 @@ func processStreamingResp(t *testing.T, httpResp *http.Response, err error, resp } } +func checkHeaders(t *testing.T, rest *API, url string, headers http.Header) { + for k, v := range rest.config.Headers { + if strings.Join(v, ",") != strings.Join(headers[k], ",") { + t.Errorf("%s does not show configured headers: %s", url, k) + } + } + if headers.Get("Content-Type") != "application/json" { + t.Errorf("%s is not application/json", url) + } +} + // makes a libp2p host that knows how to talk to the rest API host. func makeHost(t *testing.T, rest *API) host.Host { h, err := libp2p.New(context.Background()) @@ -185,6 +196,7 @@ func makeGet(t *testing.T, rest *API, url string, resp interface{}) { c := httpClient(t, h, isHTTPS(url)) httpResp, err := c.Get(url) processResp(t, httpResp, err, resp) + checkHeaders(t, rest, url, httpResp.Header) } func makePost(t *testing.T, rest *API, url string, body []byte, resp interface{}) { @@ -193,6 +205,7 @@ func makePost(t *testing.T, rest *API, url string, body []byte, resp interface{} c := httpClient(t, h, isHTTPS(url)) httpResp, err := c.Post(url, "application/json", bytes.NewReader(body)) processResp(t, httpResp, err, resp) + checkHeaders(t, rest, url, httpResp.Header) } func makeDelete(t *testing.T, rest *API, url string, resp interface{}) { @@ -202,6 +215,7 @@ func makeDelete(t *testing.T, rest *API, url string, resp interface{}) { req, _ := http.NewRequest("DELETE", url, bytes.NewReader([]byte{})) httpResp, err := c.Do(req) processResp(t, httpResp, err, resp) + checkHeaders(t, rest, url, httpResp.Header) } func makeStreamingPost(t *testing.T, rest *API, url string, body io.Reader, contentType string, resp interface{}) { @@ -210,6 +224,7 @@ func makeStreamingPost(t *testing.T, rest *API, url string, body io.Reader, cont c := httpClient(t, h, isHTTPS(url)) httpResp, err := c.Post(url, contentType, body) processStreamingResp(t, httpResp, err, resp) + checkHeaders(t, rest, url, httpResp.Header) } type testF func(t *testing.T, url urlF) @@ -251,6 +266,7 @@ func TestRestAPIIDEndpoint(t *testing.T) { rest := testAPI(t) httpsrest := testHTTPSAPI(t) defer rest.Shutdown() + defer httpsrest.Shutdown() tf := func(t *testing.T, url urlF) { id := api.IDSerial{} diff --git a/config_test.go b/config_test.go index 8be40c142..af5dc5b6e 100644 --- a/config_test.go +++ b/config_test.go @@ -50,7 +50,19 @@ var testingAPICfg = []byte(`{ "read_timeout": "0", "read_header_timeout": "5s", "write_timeout": "0", - "idle_timeout": "2m0s" + "idle_timeout": "2m0s", + "headers": { + "Access-Control-Allow-Headers": [ + "X-Requested-With", + "Range" + ], + "Access-Control-Allow-Methods": [ + "GET" + ], + "Access-Control-Allow-Origin": [ + "*" + ] + } }`) var testingIpfsCfg = []byte(`{ From 562ad713fc8ff739e8c89d74ec9c8484c9edd4eb Mon Sep 17 00:00:00 2001 From: Hector Sanjuan Date: Wed, 17 Oct 2018 13:43:57 +0200 Subject: [PATCH 2/2] Update docs for sendResponse License: MIT Signed-off-by: Hector Sanjuan --- api/rest/restapi.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/rest/restapi.go b/api/rest/restapi.go index acc303c94..c2bee54ac 100644 --- a/api/rest/restapi.go +++ b/api/rest/restapi.go @@ -863,7 +863,7 @@ func pinInfosToGlobal(pInfos []types.PinInfoSerial) []types.GlobalPinInfoSerial // sendResponse wraps all the logic for writing the response to a request: // * Write configured headers // * Write application/json content type -// * Write status: determined automatically if given 0 +// * Write status: determined automatically if given "autoStatus" // * Write an error if there is or write the response if there is func (api *API) sendResponse( w http.ResponseWriter,