diff --git a/api/rest/config.go b/api/rest/config.go index 5902d1e6a..ed95146c9 100644 --- a/api/rest/config.go +++ b/api/rest/config.go @@ -27,6 +27,15 @@ const ( DefaultIdleTimeout = 120 * time.Second ) +// These are the default values for Config. +var ( + DefaultHeaders = map[string][]string{ + "Access-Control-Allow-Headers": []string{"X-Requested-With", "Range"}, + "Access-Control-Allow-Methods": []string{"GET"}, + "Access-Control-Allow-Origin": []string{"*"}, + } +) + // Config is used to intialize the API object and allows to // customize the behaviour of it. It implements the config.ComponentConfig // interface. @@ -71,6 +80,10 @@ type Config struct { // BasicAuthCreds is a map of username-password pairs // which are authorized to use Basic Authentication BasicAuthCreds map[string]string + + // Headers provides customization for the headers returned + // by the API. By default it sets a CORS policy. + Headers map[string][]string } type jsonConfig struct { @@ -87,7 +100,8 @@ type jsonConfig struct { ID string `json:"id,omitempty"` PrivateKey string `json:"private_key,omitempty"` - BasicAuthCreds map[string]string `json:"basic_auth_credentials"` + BasicAuthCreds map[string]string `json:"basic_auth_credentials"` + Headers map[string][]string `json:"headers"` } // ConfigKey returns a human-friendly identifier for this type of @@ -116,6 +130,9 @@ func (cfg *Config) Default() error { // Auth cfg.BasicAuthCreds = nil + // Headers + cfg.Headers = DefaultHeaders + return nil } @@ -177,6 +194,7 @@ func (cfg *Config) LoadJSON(raw []byte) error { // Other options cfg.BasicAuthCreds = jcfg.BasicAuthCreds + cfg.Headers = jcfg.Headers return cfg.Validate() } @@ -295,6 +313,7 @@ func (cfg *Config) ToJSON() (raw []byte, err error) { WriteTimeout: cfg.WriteTimeout.String(), IdleTimeout: cfg.IdleTimeout.String(), BasicAuthCreds: cfg.BasicAuthCreds, + Headers: cfg.Headers, } if cfg.ID != "" { diff --git a/api/rest/restapi.go b/api/rest/restapi.go index 25e50d0c2..c2bee54ac 100644 --- a/api/rest/restapi.go +++ b/api/rest/restapi.go @@ -55,6 +55,9 @@ var ( ErrHTTPEndpointNotEnabled = errors.New("the HTTP endpoint is not enabled") ) +// Used by sendResponse to set the right status +const autoStatus = -1 + // For making a random sharding ID var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") @@ -479,7 +482,7 @@ func (api *API) idHandler(w http.ResponseWriter, r *http.Request) { struct{}{}, &idSerial) - sendResponse(w, err, idSerial) + api.sendResponse(w, autoStatus, err, idSerial) } func (api *API) versionHandler(w http.ResponseWriter, r *http.Request) { @@ -490,7 +493,7 @@ func (api *API) versionHandler(w http.ResponseWriter, r *http.Request) { struct{}{}, &v) - sendResponse(w, err, v) + api.sendResponse(w, autoStatus, err, v) } func (api *API) graphHandler(w http.ResponseWriter, r *http.Request) { @@ -500,22 +503,24 @@ func (api *API) graphHandler(w http.ResponseWriter, r *http.Request) { "ConnectGraph", struct{}{}, &graph) - sendResponse(w, err, graph) + api.sendResponse(w, autoStatus, err, graph) } func (api *API) addHandler(w http.ResponseWriter, r *http.Request) { reader, err := r.MultipartReader() if err != nil { - sendErrorResponse(w, http.StatusBadRequest, err.Error()) + api.sendResponse(w, http.StatusBadRequest, err, nil) return } params, err := types.AddParamsFromQuery(r.URL.Query()) if err != nil { - sendErrorResponse(w, http.StatusBadRequest, err.Error()) + api.sendResponse(w, http.StatusBadRequest, err, nil) return } + api.setHeaders(w) + // any errors sent as trailer adderutils.AddMultipartHTTPHandler( api.ctx, @@ -537,7 +542,7 @@ func (api *API) peerListHandler(w http.ResponseWriter, r *http.Request) { struct{}{}, &peersSerial) - sendResponse(w, err, peersSerial) + api.sendResponse(w, autoStatus, err, peersSerial) } func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) { @@ -547,13 +552,13 @@ func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) { var addInfo peerAddBody err := dec.Decode(&addInfo) if err != nil { - sendErrorResponse(w, 400, "error decoding request body") + api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding request body"), nil) return } _, err = peer.IDB58Decode(addInfo.PeerID) if err != nil { - sendErrorResponse(w, 400, "error decoding peer_id") + api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding peer_id"), nil) return } @@ -563,22 +568,22 @@ func (api *API) peerAddHandler(w http.ResponseWriter, r *http.Request) { "PeerAdd", addInfo.PeerID, &ids) - sendResponse(w, err, ids) + api.sendResponse(w, autoStatus, err, ids) } func (api *API) peerRemoveHandler(w http.ResponseWriter, r *http.Request) { - if p := parsePidOrError(w, r); p != "" { + if p := api.parsePidOrError(w, r); p != "" { err := api.rpcClient.Call("", "Cluster", "PeerRemove", p, &struct{}{}) - sendEmptyResponse(w, err) + api.sendResponse(w, autoStatus, err, nil) } } func (api *API) pinHandler(w http.ResponseWriter, r *http.Request) { - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { logger.Debugf("rest api pinHandler: %s", ps.Cid) err := api.rpcClient.Call("", @@ -586,20 +591,20 @@ func (api *API) pinHandler(w http.ResponseWriter, r *http.Request) { "Pin", ps, &struct{}{}) - sendAcceptedResponse(w, err) + api.sendResponse(w, http.StatusAccepted, err, nil) logger.Debug("rest api pinHandler done") } } func (api *API) unpinHandler(w http.ResponseWriter, r *http.Request) { - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { logger.Debugf("rest api unpinHandler: %s", ps.Cid) err := api.rpcClient.Call("", "Cluster", "Unpin", ps, &struct{}{}) - sendAcceptedResponse(w, err) + api.sendResponse(w, http.StatusAccepted, err, nil) logger.Debug("rest api unpinHandler done") } } @@ -626,11 +631,11 @@ func (api *API) allocationsHandler(w http.ResponseWriter, r *http.Request) { outPins = append(outPins, pinS) } } - sendResponse(w, err, outPins) + api.sendResponse(w, autoStatus, err, outPins) } func (api *API) allocationHandler(w http.ResponseWriter, r *http.Request) { - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { var pin types.PinSerial err := api.rpcClient.Call("", "Cluster", @@ -638,10 +643,10 @@ func (api *API) allocationHandler(w http.ResponseWriter, r *http.Request) { ps, &pin) if err != nil { // errors here are 404s - sendErrorResponse(w, 404, err.Error()) + api.sendResponse(w, http.StatusNotFound, err, nil) return } - sendJSONResponse(w, 200, pin) + api.sendResponse(w, autoStatus, nil, pin) } } @@ -656,7 +661,7 @@ func (api *API) statusAllHandler(w http.ResponseWriter, r *http.Request) { "StatusAllLocal", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfosToGlobal(pinInfos)) + api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos)) } else { var pinInfos []types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -664,7 +669,7 @@ func (api *API) statusAllHandler(w http.ResponseWriter, r *http.Request) { "StatusAll", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfos) + api.sendResponse(w, autoStatus, err, pinInfos) } } @@ -672,7 +677,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) { queryValues := r.URL.Query() local := queryValues.Get("local") - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { if local == "true" { var pinInfo types.PinInfoSerial err := api.rpcClient.Call("", @@ -680,7 +685,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) { "StatusLocal", ps, &pinInfo) - sendResponse(w, err, pinInfoToGlobal(pinInfo)) + api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo)) } else { var pinInfo types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -688,7 +693,7 @@ func (api *API) statusHandler(w http.ResponseWriter, r *http.Request) { "Status", ps, &pinInfo) - sendResponse(w, err, pinInfo) + api.sendResponse(w, autoStatus, err, pinInfo) } } } @@ -704,7 +709,7 @@ func (api *API) syncAllHandler(w http.ResponseWriter, r *http.Request) { "SyncAllLocal", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfosToGlobal(pinInfos)) + api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos)) } else { var pinInfos []types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -712,7 +717,7 @@ func (api *API) syncAllHandler(w http.ResponseWriter, r *http.Request) { "SyncAll", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfos) + api.sendResponse(w, autoStatus, err, pinInfos) } } @@ -720,7 +725,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) { queryValues := r.URL.Query() local := queryValues.Get("local") - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { if local == "true" { var pinInfo types.PinInfoSerial err := api.rpcClient.Call("", @@ -728,7 +733,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) { "SyncLocal", ps, &pinInfo) - sendResponse(w, err, pinInfoToGlobal(pinInfo)) + api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo)) } else { var pinInfo types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -736,7 +741,7 @@ func (api *API) syncHandler(w http.ResponseWriter, r *http.Request) { "Sync", ps, &pinInfo) - sendResponse(w, err, pinInfo) + api.sendResponse(w, autoStatus, err, pinInfo) } } } @@ -751,9 +756,9 @@ func (api *API) recoverAllHandler(w http.ResponseWriter, r *http.Request) { "RecoverAllLocal", struct{}{}, &pinInfos) - sendResponse(w, err, pinInfosToGlobal(pinInfos)) + api.sendResponse(w, autoStatus, err, pinInfosToGlobal(pinInfos)) } else { - sendErrorResponse(w, 400, "only requests with parameter local=true are supported") + api.sendResponse(w, http.StatusBadRequest, errors.New("only requests with parameter local=true are supported"), nil) } } @@ -761,7 +766,7 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) { queryValues := r.URL.Query() local := queryValues.Get("local") - if ps := parseCidOrError(w, r); ps.Cid != "" { + if ps := api.parseCidOrError(w, r); ps.Cid != "" { if local == "true" { var pinInfo types.PinInfoSerial err := api.rpcClient.Call("", @@ -769,7 +774,7 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) { "RecoverLocal", ps, &pinInfo) - sendResponse(w, err, pinInfoToGlobal(pinInfo)) + api.sendResponse(w, autoStatus, err, pinInfoToGlobal(pinInfo)) } else { var pinInfo types.GlobalPinInfoSerial err := api.rpcClient.Call("", @@ -777,18 +782,18 @@ func (api *API) recoverHandler(w http.ResponseWriter, r *http.Request) { "Recover", ps, &pinInfo) - sendResponse(w, err, pinInfo) + api.sendResponse(w, autoStatus, err, pinInfo) } } } -func parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial { +func (api *API) parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial { vars := mux.Vars(r) hash := vars["hash"] _, err := cid.Decode(hash) if err != nil { - sendErrorResponse(w, 400, "error decoding Cid: "+err.Error()) + api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding Cid: "+err.Error()), nil) return types.PinSerial{Cid: ""} } @@ -827,12 +832,12 @@ func parseCidOrError(w http.ResponseWriter, r *http.Request) types.PinSerial { return pin } -func parsePidOrError(w http.ResponseWriter, r *http.Request) peer.ID { +func (api *API) parsePidOrError(w http.ResponseWriter, r *http.Request) peer.ID { vars := mux.Vars(r) idStr := vars["peer"] pid, err := peer.IDB58Decode(idStr) if err != nil { - sendErrorResponse(w, 400, "error decoding Peer ID: "+err.Error()) + api.sendResponse(w, http.StatusBadRequest, errors.New("error decoding Peer ID: "+err.Error()), nil) return "" } return pid @@ -855,64 +860,70 @@ func pinInfosToGlobal(pInfos []types.PinInfoSerial) []types.GlobalPinInfoSerial return gPInfos } -func sendResponse(w http.ResponseWriter, err error, resp interface{}) { - if checkErr(w, err) { - sendJSONResponse(w, 200, resp) - } -} +// sendResponse wraps all the logic for writing the response to a request: +// * Write configured headers +// * Write application/json content type +// * Write status: determined automatically if given "autoStatus" +// * Write an error if there is or write the response if there is +func (api *API) sendResponse( + w http.ResponseWriter, + status int, + err error, + resp interface{}, +) { -// checkErr takes care of returning standard error responses if we -// pass an error to it. It returns true when everythings OK (no error -// was handled), or false otherwise. -func checkErr(w http.ResponseWriter, err error) bool { + api.setHeaders(w) + enc := json.NewEncoder(w) + + // Send an error if err != nil { - sendErrorResponse(w, http.StatusInternalServerError, err.Error()) - return false - } - return true -} + if status == autoStatus || status < 400 { // set a default error status + status = http.StatusInternalServerError + } + w.WriteHeader(status) -func sendEmptyResponse(w http.ResponseWriter, err error) { - if checkErr(w, err) { - w.WriteHeader(http.StatusNoContent) - } -} + errorResp := types.Error{ + Code: status, + Message: err.Error(), + } + logger.Errorf("sending error response: %d: %s", status, err.Error()) -func sendAcceptedResponse(w http.ResponseWriter, err error) { - if checkErr(w, err) { - w.WriteHeader(http.StatusAccepted) + if err := enc.Encode(errorResp); err != nil { + logger.Error(err) + } + return } -} -func sendJSONResponse(w http.ResponseWriter, code int, resp interface{}) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(code) - if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Error(err) - } -} + // Send a body + if resp != nil { + if status == autoStatus { + status = http.StatusOK + } -func sendErrorResponse(w http.ResponseWriter, code int, msg string) { - errorResp := types.Error{ - Code: code, - Message: msg, - } - logger.Errorf("sending error response: %d: %s", code, msg) - sendJSONResponse(w, code, errorResp) -} + w.WriteHeader(status) -func sendStreamResponse(w http.ResponseWriter, err error, resp <-chan interface{}) { - if !checkErr(w, err) { + if err = enc.Encode(resp); err != nil { + logger.Error(err) + } return } - enc := json.NewEncoder(w) - w.Header().Add("Content-Type", "application/octet-stream") - w.WriteHeader(http.StatusOK) - for v := range resp { - err := enc.Encode(v) - if err != nil { - logger.Error(err) + // Empty response + if status == autoStatus { + status = http.StatusNoContent + } + + w.WriteHeader(status) +} + +// this sets all the headers that are common to all responses +// from this API. Called from sendResponse() and /add. +func (api *API) setHeaders(w http.ResponseWriter) { + for header, values := range api.config.Headers { + for _, val := range values { + w.Header().Add(header, val) } } + + w.Header().Add("Content-Type", "application/json") } diff --git a/api/rest/restapi_test.go b/api/rest/restapi_test.go index 849543a29..7de0a3fa1 100644 --- a/api/rest/restapi_test.go +++ b/api/rest/restapi_test.go @@ -124,6 +124,17 @@ func processStreamingResp(t *testing.T, httpResp *http.Response, err error, resp } } +func checkHeaders(t *testing.T, rest *API, url string, headers http.Header) { + for k, v := range rest.config.Headers { + if strings.Join(v, ",") != strings.Join(headers[k], ",") { + t.Errorf("%s does not show configured headers: %s", url, k) + } + } + if headers.Get("Content-Type") != "application/json" { + t.Errorf("%s is not application/json", url) + } +} + // makes a libp2p host that knows how to talk to the rest API host. func makeHost(t *testing.T, rest *API) host.Host { h, err := libp2p.New(context.Background()) @@ -185,6 +196,7 @@ func makeGet(t *testing.T, rest *API, url string, resp interface{}) { c := httpClient(t, h, isHTTPS(url)) httpResp, err := c.Get(url) processResp(t, httpResp, err, resp) + checkHeaders(t, rest, url, httpResp.Header) } func makePost(t *testing.T, rest *API, url string, body []byte, resp interface{}) { @@ -193,6 +205,7 @@ func makePost(t *testing.T, rest *API, url string, body []byte, resp interface{} c := httpClient(t, h, isHTTPS(url)) httpResp, err := c.Post(url, "application/json", bytes.NewReader(body)) processResp(t, httpResp, err, resp) + checkHeaders(t, rest, url, httpResp.Header) } func makeDelete(t *testing.T, rest *API, url string, resp interface{}) { @@ -202,6 +215,7 @@ func makeDelete(t *testing.T, rest *API, url string, resp interface{}) { req, _ := http.NewRequest("DELETE", url, bytes.NewReader([]byte{})) httpResp, err := c.Do(req) processResp(t, httpResp, err, resp) + checkHeaders(t, rest, url, httpResp.Header) } func makeStreamingPost(t *testing.T, rest *API, url string, body io.Reader, contentType string, resp interface{}) { @@ -210,6 +224,7 @@ func makeStreamingPost(t *testing.T, rest *API, url string, body io.Reader, cont c := httpClient(t, h, isHTTPS(url)) httpResp, err := c.Post(url, contentType, body) processStreamingResp(t, httpResp, err, resp) + checkHeaders(t, rest, url, httpResp.Header) } type testF func(t *testing.T, url urlF) @@ -251,6 +266,7 @@ func TestRestAPIIDEndpoint(t *testing.T) { rest := testAPI(t) httpsrest := testHTTPSAPI(t) defer rest.Shutdown() + defer httpsrest.Shutdown() tf := func(t *testing.T, url urlF) { id := api.IDSerial{} diff --git a/config_test.go b/config_test.go index 8be40c142..af5dc5b6e 100644 --- a/config_test.go +++ b/config_test.go @@ -50,7 +50,19 @@ var testingAPICfg = []byte(`{ "read_timeout": "0", "read_header_timeout": "5s", "write_timeout": "0", - "idle_timeout": "2m0s" + "idle_timeout": "2m0s", + "headers": { + "Access-Control-Allow-Headers": [ + "X-Requested-With", + "Range" + ], + "Access-Control-Allow-Methods": [ + "GET" + ], + "Access-Control-Allow-Origin": [ + "*" + ] + } }`) var testingIpfsCfg = []byte(`{