diff --git a/elasticsearch.go b/elasticsearch.go index c68d5c33cf..fb000c67f9 100644 --- a/elasticsearch.go +++ b/elasticsearch.go @@ -101,13 +101,12 @@ type Config struct { // Client represents the Elasticsearch client. // type Client struct { - *esapi.API // Embeds the API methods - Transport estransport.Interface - - productCheckOnce sync.Once - responseCheckOnce sync.Once - productCheckError error + *esapi.API // Embeds the API methods + Transport estransport.Interface useResponseCheckOnly bool + + productCheckMu sync.RWMutex + productCheckSuccess bool } type esVersion struct { @@ -280,35 +279,50 @@ func ParseElasticsearchVersion(version string) (int64, int64, int64, error) { // Perform delegates to Transport to execute a request and return a response. // func (c *Client) Perform(req *http.Request) (*http.Response, error) { - // ProductCheck validation - c.productCheckOnce.Do(func() { - // We skip this validation of we only want the header validation. - // ResponseCheck path continues after original request. - if c.useResponseCheckOnly { - return - } - + // ProductCheck validation. We skip this validation of we only want the + // header validation. ResponseCheck path continues after original request. + if !c.useResponseCheckOnly { // Launch product check for 7.x, request info, check header then payload. - c.productCheckError = c.productCheck() - return - }) + if err := c.doProductCheck(c.productCheck); err != nil { + return nil, err + } + } // Retrieve the original request. res, err := c.Transport.Perform(req) - c.responseCheckOnce.Do(func() { - // ResponseCheck path continues, we run the header check on the first answer from ES. - if c.useResponseCheckOnly { - c.productCheckError = genuineCheckHeader(res.Header) + // ResponseCheck path continues, we run the header check on the first answer from ES. + if err == nil { + checkHeader := func() error { return genuineCheckHeader(res.Header) } + if err := c.doProductCheck(checkHeader); err != nil { + res.Body.Close() + return nil, err } - }) - - if c.productCheckError != nil { - return nil, c.productCheckError } return res, err } +// doProductCheck calls f if there as not been a prior successful call to doProductCheck, +// returning nil otherwise. +func (c *Client) doProductCheck(f func() error) error { + c.productCheckMu.RLock() + productCheckSuccess := c.productCheckSuccess + c.productCheckMu.RUnlock() + if productCheckSuccess { + return nil + } + c.productCheckMu.Lock() + defer c.productCheckMu.Unlock() + if c.productCheckSuccess { + return nil + } + if err := f(); err != nil { + return err + } + c.productCheckSuccess = true + return nil +} + // productCheck runs an esapi.Info query to retrieve informations of the current cluster // decodes the response and decides if the cluster is a genuine Elasticsearch product. func (c *Client) productCheck() error { diff --git a/elasticsearch_internal_test.go b/elasticsearch_internal_test.go index 037f145d92..941c38dce0 100644 --- a/elasticsearch_internal_test.go +++ b/elasticsearch_internal_test.go @@ -22,19 +22,22 @@ package elasticsearch import ( "encoding/base64" "errors" - "github.com/elastic/go-elasticsearch/v7/estransport" "io/ioutil" "net/http" + "net/http/httptest" "net/url" "os" + "reflect" "regexp" "strings" "testing" + + "github.com/elastic/go-elasticsearch/v7/estransport" ) var called bool -type mockTransp struct{ +type mockTransp struct { RoundTripFunc func(*http.Request) (*http.Response, error) } @@ -64,7 +67,6 @@ func (t *mockTransp) RoundTrip(req *http.Request) (*http.Response, error) { return t.RoundTripFunc(req) } - func TestClientConfiguration(t *testing.T) { t.Parallel() @@ -219,7 +221,7 @@ func TestClientConfiguration(t *testing.T) { }, nil }, }, - }) + }) if err != nil { t.Errorf("Unexpected error, got: %+v", err) } @@ -444,7 +446,7 @@ func TestParseElasticsearchVersion(t *testing.T) { wantErr: true, }, } - for _, tt := range tests { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1, got2, err := ParseElasticsearchVersion(tt.version) if (err != nil) != tt.wantErr { @@ -556,7 +558,7 @@ func TestGenuineCheckHeader(t *testing.T) { wantErr: true, }, { - name: "Unavailable product header", + name: "Unavailable product header", headers: http.Header{}, wantErr: true, }, @@ -575,49 +577,56 @@ func TestResponseCheckOnly(t *testing.T) { name string useResponseCheckOnly bool response *http.Response + requestErr error wantErr bool - } { + }{ { - name: "Valid answer with header", + name: "Valid answer with header", useResponseCheckOnly: false, response: &http.Response{ Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}}, - Body: ioutil.NopCloser(strings.NewReader("{}")), + Body: ioutil.NopCloser(strings.NewReader("{}")), }, - wantErr: false, + wantErr: false, }, { - name: "Valid answer without header", + name: "Valid answer without header", useResponseCheckOnly: false, response: &http.Response{ Body: ioutil.NopCloser(strings.NewReader("{}")), }, - wantErr: true, + wantErr: true, }, { - name: "Valid answer with header and response check", + name: "Valid answer with header and response check", useResponseCheckOnly: true, response: &http.Response{ Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}}, - Body: ioutil.NopCloser(strings.NewReader("{}")), + Body: ioutil.NopCloser(strings.NewReader("{}")), }, - wantErr: false, + wantErr: false, }, { - name: "Valid answer withouth header and response check", + name: "Valid answer without header and response check", useResponseCheckOnly: true, response: &http.Response{ Body: ioutil.NopCloser(strings.NewReader("{}")), }, - wantErr: true, + wantErr: true, + }, + { + name: "Request failed", + useResponseCheckOnly: true, + response: nil, + requestErr: errors.New("request failed"), + wantErr: true, }, - } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, _ := NewClient(Config{ Transport: &mockTransp{RoundTripFunc: func(request *http.Request) (*http.Response, error) { - return tt.response, nil + return tt.response, tt.requestErr }}, UseResponseCheckOnly: tt.useResponseCheckOnly, }) @@ -628,3 +637,33 @@ func TestResponseCheckOnly(t *testing.T) { }) } } + +func TestProductCheckError(t *testing.T) { + var requestPaths []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestPaths = append(requestPaths, r.URL.Path) + if len(requestPaths) == 1 { + // Simulate transient error from a proxy on the first request. + // This must not be cached by the client. + w.WriteHeader(http.StatusBadGateway) + return + } + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.Write([]byte("{}")) + })) + defer server.Close() + + c, _ := NewClient(Config{Addresses: []string{server.URL}, DisableRetry: true}) + if _, err := c.Cat.Indices(); err == nil { + t.Fatal("expected error") + } + if _, err := c.Cat.Indices(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if n := len(requestPaths); n != 3 { + t.Fatalf("expected 3 requests, got %d", n) + } + if !reflect.DeepEqual(requestPaths, []string{"/", "/", "/_cat/indices"}) { + t.Fatalf("unexpected request paths: %s", requestPaths) + } +}