diff --git a/elasticsearch.go b/elasticsearch.go index fb000c67f9..ea5bd4d411 100644 --- a/elasticsearch.go +++ b/elasticsearch.go @@ -23,6 +23,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "io/ioutil" "net/http" "net/url" "os" @@ -47,9 +49,10 @@ func init() { } const ( - defaultURL = "http://localhost:9200" - tagline = "You Know, for Search" - unknownProduct = "the client noticed that the server is not Elasticsearch and we do not support this unknown product" + defaultURL = "http://localhost:9200" + tagline = "You Know, for Search" + unknownProduct = "the client noticed that the server is not Elasticsearch and we do not support this unknown product" + unsupportedProduct = "the client noticed that the server is not a supported distribution of Elasticsearch" ) // Version returns the package version as a string. @@ -119,6 +122,7 @@ type info struct { Tagline string `json:"tagline"` } + // NewDefaultClient creates a new client with default options. // // It will use http://localhost:9200 as the default address. @@ -228,7 +232,7 @@ func NewClient(cfg Config) (*Client, error) { // func genuineCheckHeader(header http.Header) error { if header.Get("X-Elastic-Product") != "Elasticsearch" { - return fmt.Errorf(unknownProduct) + return errors.New(unknownProduct) } return nil } @@ -242,18 +246,19 @@ func genuineCheckInfo(info info) error { } if major < 6 { - return fmt.Errorf(unknownProduct) + return errors.New(unknownProduct) } if major < 7 { if info.Tagline != tagline { - return fmt.Errorf(unknownProduct) + return errors.New(unknownProduct) } } if major >= 7 { if minor < 14 { - if info.Tagline != tagline || - info.Version.BuildFlavor != "default" { - return fmt.Errorf(unknownProduct) + if info.Tagline != tagline { + return errors.New(unknownProduct) + } else if info.Version.BuildFlavor != "default" { + return errors.New(unsupportedProduct) } } } @@ -308,62 +313,73 @@ func (c *Client) doProductCheck(f func() error) error { c.productCheckMu.RLock() productCheckSuccess := c.productCheckSuccess c.productCheckMu.RUnlock() + if productCheckSuccess { return nil } + c.productCheckMu.Lock() defer c.productCheckMu.Unlock() + if c.productCheckSuccess { return nil } + if err := f(); err != nil { return err } + c.productCheckSuccess = true + return nil } // productCheck runs an esapi.Info query to retrieve informations of the current cluster // decodes the response and decides if the cluster is a genuine Elasticsearch product. func (c *Client) productCheck() error { - var info info - req := esapi.InfoRequest{} res, err := req.Do(context.Background(), c.Transport) if err != nil { return err } + defer res.Body.Close() - contentType := res.Header.Get("Content-Type") - if res.Body != nil { - defer res.Body.Close() - - if strings.Contains(contentType, "json") { - decoder := json.NewDecoder(res.Body) - err = decoder.Decode(&info) - if err != nil { - return fmt.Errorf("error decoding Elasticsearch informations: %s", err) - } + if res.IsError() { + _, err = io.Copy(ioutil.Discard, res.Body) + if err != nil { + return err } - switch res.StatusCode { - case http.StatusForbidden: case http.StatusUnauthorized: - break + return nil + case http.StatusForbidden: + return nil default: - err = genuineCheckHeader(res.Header) + return fmt.Errorf("cannot retrieve informations from Elasticsearch") + } + } + err = genuineCheckHeader(res.Header) + + if err != nil { + var info info + contentType := res.Header.Get("Content-Type") + if strings.Contains(contentType, "json") { + err = json.NewDecoder(res.Body).Decode(&info) if err != nil { - if info.Version.Number != "" { - err = genuineCheckInfo(info) - } + return fmt.Errorf("error decoding Elasticsearch informations: %s", err) } } - if err != nil { - return err + + if info.Version.Number != "" { + err = genuineCheckInfo(info) } } + if err != nil { + return err + } + return nil } diff --git a/elasticsearch_internal_test.go b/elasticsearch_internal_test.go index 941c38dce0..f9133589b1 100644 --- a/elasticsearch_internal_test.go +++ b/elasticsearch_internal_test.go @@ -471,6 +471,7 @@ func TestGenuineCheckInfo(t *testing.T) { name string info info wantErr bool + err error }{ { name: "Genuine Elasticsearch 7.14.0", @@ -482,6 +483,7 @@ func TestGenuineCheckInfo(t *testing.T) { Tagline: "You Know, for Search", }, wantErr: false, + err: nil, }, { name: "Genuine Elasticsearch 6.15.1", @@ -493,6 +495,7 @@ func TestGenuineCheckInfo(t *testing.T) { Tagline: "You Know, for Search", }, wantErr: false, + err: nil, }, { name: "Not so genuine Elasticsearch 7 major", @@ -504,6 +507,7 @@ func TestGenuineCheckInfo(t *testing.T) { Tagline: "You Know, for Search", }, wantErr: true, + err: errors.New(unknownProduct), }, { name: "Not so genuine Elasticsearch 6 major", @@ -515,6 +519,7 @@ func TestGenuineCheckInfo(t *testing.T) { Tagline: "You Know, for Fun", }, wantErr: true, + err: errors.New(unknownProduct), }, { name: "Way older Elasticsearch major", @@ -526,11 +531,24 @@ func TestGenuineCheckInfo(t *testing.T) { Tagline: "You Know, for Fun", }, wantErr: true, + err: errors.New(unknownProduct), + }, + { + name: "Elasticsearch oss", + info: info{ + Version: esVersion{ + Number: "7.10.0", + BuildFlavor: "oss", + }, + Tagline: "You Know, for Search", + }, + wantErr: true, + err: errors.New(unsupportedProduct), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := genuineCheckInfo(tt.info); (err != nil) != tt.wantErr { + if err := genuineCheckInfo(tt.info); (err != nil) != tt.wantErr && err != tt.err { t.Errorf("genuineCheckInfo() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -621,6 +639,46 @@ func TestResponseCheckOnly(t *testing.T) { requestErr: errors.New("request failed"), wantErr: true, }, + { + name: "Valid request, 500 response", + useResponseCheckOnly: false, + response: &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: ioutil.NopCloser(strings.NewReader("")), + }, + requestErr: nil, + wantErr: true, + }, + { + name: "Valid request, 404 response", + useResponseCheckOnly: false, + response: &http.Response{ + StatusCode: http.StatusNotFound, + Body: ioutil.NopCloser(strings.NewReader("")), + }, + requestErr: nil, + wantErr: true, + }, + { + name: "Valid request, 403 response", + useResponseCheckOnly: false, + response: &http.Response{ + StatusCode: http.StatusForbidden, + Body: ioutil.NopCloser(strings.NewReader("")), + }, + requestErr: nil, + wantErr: false, + }, + { + name: "Valid request, 401 response", + useResponseCheckOnly: false, + response: &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: ioutil.NopCloser(strings.NewReader("")), + }, + requestErr: nil, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -657,6 +715,9 @@ func TestProductCheckError(t *testing.T) { if _, err := c.Cat.Indices(); err == nil { t.Fatal("expected error") } + if c.productCheckSuccess { + t.Fatalf("product check should be invalid, got %v", c.productCheckSuccess) + } if _, err := c.Cat.Indices(); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -666,4 +727,7 @@ func TestProductCheckError(t *testing.T) { if !reflect.DeepEqual(requestPaths, []string{"/", "/", "/_cat/indices"}) { t.Fatalf("unexpected request paths: %s", requestPaths) } + if !c.productCheckSuccess { + t.Fatalf("product check should be valid, got : %v", c.productCheckSuccess) + } } diff --git a/esapi/esapi_benchmark_test.go b/esapi/esapi_benchmark_test.go index bb447aeac9..88e0c30e34 100644 --- a/esapi/esapi_benchmark_test.go +++ b/esapi/esapi_benchmark_test.go @@ -38,7 +38,14 @@ var ( Body: ioutil.NopCloser(strings.NewReader("MOCK")), } defaultRoundTripFn = func(*http.Request) (*http.Response, error) { return defaultResponse, nil } - errorRoundTripFn = func(*http.Request) (*http.Response, error) { + errorRoundTripFn = func(request *http.Request) (*http.Response, error) { + if request.URL.Path == "/" { + return &http.Response{ + StatusCode: 200, + Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}}, + Body: ioutil.NopCloser(strings.NewReader("{}")), + }, nil + } return &http.Response{ Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}}, StatusCode: 400,