From 70b402f1cdbdf1d124f717b154bb95b20c865891 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laurent=20Saint-F=C3=A9lix?= Date: Thu, 22 Jul 2021 15:38:32 +0200 Subject: [PATCH] Client: Rework product check according to latest spec Handling of Info special cases if not accessible --- elasticsearch.go | 42 ++++++++++++++++++++-------------- elasticsearch_internal_test.go | 22 ++++++++++++++++++ esapi/esapi_benchmark_test.go | 1 + 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/elasticsearch.go b/elasticsearch.go index a438d367d7..ff106477b4 100644 --- a/elasticsearch.go +++ b/elasticsearch.go @@ -23,6 +23,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "io/ioutil" "net/http" "net/url" "os" @@ -339,41 +341,47 @@ func (c *Client) doProductCheck(f func() error) error { // productCheck runs an esapi.Info query to retrieve informations of the current cluster // decodes the response and decides if the cluster is a genuine Elasticsearch product. func (c *Client) productCheck() error { - var info info - req := esapi.InfoRequest{} res, err := req.Do(context.Background(), c.Transport) if err != nil { return err } + defer res.Body.Close() if res.IsError() { - return fmt.Errorf("cannot retrieve info from Elasticsearch") + _, err = io.Copy(ioutil.Discard, res.Body) + if err != nil { + return err + } + switch res.StatusCode { + case http.StatusUnauthorized: + return nil + case http.StatusForbidden: + return nil + default: + return fmt.Errorf("cannot retrieve informations from Elasticsearch") + } } - contentType := res.Header.Get("Content-Type") - if res.Body != nil { - defer res.Body.Close() + err = genuineCheckHeader(res.Header) + if err != nil { + var info info + contentType := res.Header.Get("Content-Type") if strings.Contains(contentType, "json") { - decoder := json.NewDecoder(res.Body) - err = decoder.Decode(&info) + err = json.NewDecoder(res.Body).Decode(&info) if err != nil { return fmt.Errorf("error decoding Elasticsearch informations: %s", err) } } - err = genuineCheckHeader(res.Header) - - if err != nil { - if info.Version.Number != "" { - err = genuineCheckInfo(info) - } + if info.Version.Number != "" { + err = genuineCheckInfo(info) } + } - if err != nil { - return err - } + if err != nil { + return err } return nil diff --git a/elasticsearch_internal_test.go b/elasticsearch_internal_test.go index 08f4a5827d..f9133589b1 100644 --- a/elasticsearch_internal_test.go +++ b/elasticsearch_internal_test.go @@ -644,6 +644,7 @@ func TestResponseCheckOnly(t *testing.T) { useResponseCheckOnly: false, response: &http.Response{ StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(strings.NewReader("")), }, requestErr: nil, wantErr: true, @@ -653,10 +654,31 @@ func TestResponseCheckOnly(t *testing.T) { useResponseCheckOnly: false, response: &http.Response{ StatusCode: http.StatusNotFound, + Body: ioutil.NopCloser(strings.NewReader("")), }, requestErr: nil, wantErr: true, }, + { + name: "Valid request, 403 response", + useResponseCheckOnly: false, + response: &http.Response{ + StatusCode: http.StatusForbidden, + Body: ioutil.NopCloser(strings.NewReader("")), + }, + requestErr: nil, + wantErr: false, + }, + { + name: "Valid request, 401 response", + useResponseCheckOnly: false, + response: &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: ioutil.NopCloser(strings.NewReader("")), + }, + requestErr: nil, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/esapi/esapi_benchmark_test.go b/esapi/esapi_benchmark_test.go index 78d7a4f832..88e0c30e34 100644 --- a/esapi/esapi_benchmark_test.go +++ b/esapi/esapi_benchmark_test.go @@ -43,6 +43,7 @@ var ( return &http.Response{ StatusCode: 200, Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}}, + Body: ioutil.NopCloser(strings.NewReader("{}")), }, nil } return &http.Response{