diff --git a/middleware.go b/middleware.go index eb9e15e6..7e0e793c 100644 --- a/middleware.go +++ b/middleware.go @@ -125,7 +125,7 @@ func parseRequestBody(c *Client, r *Request) (err error) { CL: // by default resty won't set content length, you can if you want to :) - if c.setContentLength || r.setContentLength { + if (c.setContentLength || r.setContentLength) && r.bodyBuf != nil { r.Header.Set(hdrContentLengthKey, fmt.Sprintf("%d", r.bodyBuf.Len())) } diff --git a/request.go b/request.go index 6b4b394a..773a3499 100644 --- a/request.go +++ b/request.go @@ -11,11 +11,18 @@ import ( "encoding/xml" "fmt" "io" + "net" "net/url" "reflect" "strings" ) +// SRVRecord holds the data to query the SRV record for the following service +type SRVRecord struct { + Service string + Domain string +} + // SetHeader method is to set a single header field and its value in the current request. // Example: To set `Content-Type` and `Accept` as `application/json`. // resty.R(). @@ -345,6 +352,16 @@ func (r *Request) SetProxy(proxyURL string) *Request { return r } +// SetSRV method sets the details to query the service SRV record and execute the +// request. +// resty.R(). +// SetSRV(SRVRecord{"web", "testservice.com"}). +// Get("/get") +func (r *Request) SetSRV(srv *SRVRecord) *Request { + r.SRV = srv + return r +} + // // HTTP verb method starts here // @@ -389,22 +406,34 @@ func (r *Request) Patch(url string) (*Response, error) { // resp, err := resty.R().Execute(resty.GET, "http://httpbin.org/get") // func (r *Request) Execute(method, url string) (*Response, error) { + var addrs []*net.SRV + var err error + if r.isMultiPart && !(method == MethodPost || method == MethodPut) { return nil, fmt.Errorf("Multipart content is not allowed in HTTP verb [%v]", method) } + if r.SRV != nil { + _, addrs, err = net.LookupSRV(r.SRV.Service, "tcp", r.SRV.Domain) + if err != nil { + return nil, err + } + } + r.Method = method - r.URL = url + r.URL = r.selectAddr(addrs, url, 0) if r.client.RetryCount == 0 { return r.client.execute(r) } var resp *Response - var err error attempt := 0 _ = Backoff(func() (*Response, error) { attempt++ + + r.URL = r.selectAddr(addrs, url, attempt) + resp, err = r.client.execute(r) if err != nil { r.client.Log.Printf("ERROR [%v] Attempt [%v]", err, attempt) @@ -465,3 +494,15 @@ func (r *Request) fmtBodyString() (body string) { return } + +func (r *Request) selectAddr(addrs []*net.SRV, path string, attempt int) string { + if addrs == nil { + return path + } + + idx := attempt % len(addrs) + domain := strings.TrimRight(addrs[idx].Target, ".") + path = strings.TrimLeft(path, "/") + + return fmt.Sprintf("%s://%s:%d/%s", r.client.scheme, domain, addrs[idx].Port, path) +} diff --git a/request16.go b/request16.go index bfab4701..946a6f73 100644 --- a/request16.go +++ b/request16.go @@ -33,6 +33,7 @@ type Request struct { Error interface{} Time time.Time RawRequest *http.Request + SRV *SRVRecord client *Client bodyBuf *bytes.Buffer diff --git a/request17.go b/request17.go index 8fa7c947..c7729045 100644 --- a/request17.go +++ b/request17.go @@ -34,6 +34,7 @@ type Request struct { Error interface{} Time time.Time RawRequest *http.Request + SRV *SRVRecord client *Client bodyBuf *bytes.Buffer diff --git a/resty_test.go b/resty_test.go index 07eba115..b92064bf 100644 --- a/resty_test.go +++ b/resty_test.go @@ -1409,6 +1409,31 @@ func TestClientOptions(t *testing.T) { SetLogger(ioutil.Discard) } +func TestSRV(t *testing.T) { + c := dc(). + SetRedirectPolicy(FlexibleRedirectPolicy(20)). + SetScheme("http") + + r := c.R(). + SetSRV(&SRVRecord{"xmpp-server", "google.com"}) + + assertEqual(t, "xmpp-server", r.SRV.Service) + assertEqual(t, "google.com", r.SRV.Domain) + + resp, err := r.Get("/") + assertError(t, err) + assertEqual(t, http.StatusOK, resp.StatusCode()) +} + +func TestSRVInvalidService(t *testing.T) { + _, err := R(). + SetSRV(&SRVRecord{"nonexistantservice", "sampledomain"}). + Get("/") + + assertEqual(t, true, (err != nil)) + assertEqual(t, true, strings.Contains(err.Error(), "no such host")) +} + func getTestDataPath() string { pwd, _ := os.Getwd() return pwd + "/test-data"