Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 39 additions & 25 deletions elasticsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,12 @@ type Config struct {
// Client represents the Elasticsearch client.
//
type Client struct {
*esapi.API // Embeds the API methods
Transport estransport.Interface

productCheckOnce sync.Once
responseCheckOnce sync.Once
productCheckError error
*esapi.API // Embeds the API methods
Transport estransport.Interface
useResponseCheckOnly bool

productCheckMu sync.RWMutex
productCheckSuccess bool
}

type esVersion struct {
Expand Down Expand Up @@ -280,35 +279,50 @@ func ParseElasticsearchVersion(version string) (int64, int64, int64, error) {
// Perform delegates to Transport to execute a request and return a response.
//
func (c *Client) Perform(req *http.Request) (*http.Response, error) {
// ProductCheck validation
c.productCheckOnce.Do(func() {
// We skip this validation of we only want the header validation.
// ResponseCheck path continues after original request.
if c.useResponseCheckOnly {
return
}

// ProductCheck validation. We skip this validation of we only want the
// header validation. ResponseCheck path continues after original request.
if !c.useResponseCheckOnly {
// Launch product check for 7.x, request info, check header then payload.
c.productCheckError = c.productCheck()
return
})
if err := c.doProductCheck(c.productCheck); err != nil {
return nil, err
}
}

// Retrieve the original request.
res, err := c.Transport.Perform(req)

c.responseCheckOnce.Do(func() {
// ResponseCheck path continues, we run the header check on the first answer from ES.
if c.useResponseCheckOnly {
c.productCheckError = genuineCheckHeader(res.Header)
// ResponseCheck path continues, we run the header check on the first answer from ES.
if err == nil {
checkHeader := func() error { return genuineCheckHeader(res.Header) }
if err := c.doProductCheck(checkHeader); err != nil {
res.Body.Close()
return nil, err
}
})

if c.productCheckError != nil {
return nil, c.productCheckError
}
return res, err
}

// doProductCheck calls f if there as not been a prior successful call to doProductCheck,
// returning nil otherwise.
func (c *Client) doProductCheck(f func() error) error {
c.productCheckMu.RLock()
productCheckSuccess := c.productCheckSuccess
c.productCheckMu.RUnlock()
if productCheckSuccess {
return nil
}
c.productCheckMu.Lock()
defer c.productCheckMu.Unlock()
if c.productCheckSuccess {
return nil
}
if err := f(); err != nil {
return err
}
c.productCheckSuccess = true
return nil
}

// productCheck runs an esapi.Info query to retrieve informations of the current cluster
// decodes the response and decides if the cluster is a genuine Elasticsearch product.
func (c *Client) productCheck() error {
Expand Down
77 changes: 58 additions & 19 deletions elasticsearch_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@ package elasticsearch
import (
"encoding/base64"
"errors"
"github.com/elastic/go-elasticsearch/v7/estransport"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"reflect"
"regexp"
"strings"
"testing"

"github.com/elastic/go-elasticsearch/v7/estransport"
)

var called bool

type mockTransp struct{
type mockTransp struct {
RoundTripFunc func(*http.Request) (*http.Response, error)
}

Expand Down Expand Up @@ -64,7 +67,6 @@ func (t *mockTransp) RoundTrip(req *http.Request) (*http.Response, error) {
return t.RoundTripFunc(req)
}


func TestClientConfiguration(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -219,7 +221,7 @@ func TestClientConfiguration(t *testing.T) {
}, nil
},
},
})
})
if err != nil {
t.Errorf("Unexpected error, got: %+v", err)
}
Expand Down Expand Up @@ -444,7 +446,7 @@ func TestParseElasticsearchVersion(t *testing.T) {
wantErr: true,
},
}
for _, tt := range tests {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, got2, err := ParseElasticsearchVersion(tt.version)
if (err != nil) != tt.wantErr {
Expand Down Expand Up @@ -556,7 +558,7 @@ func TestGenuineCheckHeader(t *testing.T) {
wantErr: true,
},
{
name: "Unavailable product header",
name: "Unavailable product header",
headers: http.Header{},
wantErr: true,
},
Expand All @@ -575,49 +577,56 @@ func TestResponseCheckOnly(t *testing.T) {
name string
useResponseCheckOnly bool
response *http.Response
requestErr error
wantErr bool
} {
}{
{
name: "Valid answer with header",
name: "Valid answer with header",
useResponseCheckOnly: false,
response: &http.Response{
Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}},
Body: ioutil.NopCloser(strings.NewReader("{}")),
Body: ioutil.NopCloser(strings.NewReader("{}")),
},
wantErr: false,
wantErr: false,
},
{
name: "Valid answer without header",
name: "Valid answer without header",
useResponseCheckOnly: false,
response: &http.Response{
Body: ioutil.NopCloser(strings.NewReader("{}")),
},
wantErr: true,
wantErr: true,
},
{
name: "Valid answer with header and response check",
name: "Valid answer with header and response check",
useResponseCheckOnly: true,
response: &http.Response{
Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}},
Body: ioutil.NopCloser(strings.NewReader("{}")),
Body: ioutil.NopCloser(strings.NewReader("{}")),
},
wantErr: false,
wantErr: false,
},
{
name: "Valid answer withouth header and response check",
name: "Valid answer without header and response check",
useResponseCheckOnly: true,
response: &http.Response{
Body: ioutil.NopCloser(strings.NewReader("{}")),
},
wantErr: true,
wantErr: true,
},
{
name: "Request failed",
useResponseCheckOnly: true,
response: nil,
requestErr: errors.New("request failed"),
wantErr: true,
},

}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, _ := NewClient(Config{
Transport: &mockTransp{RoundTripFunc: func(request *http.Request) (*http.Response, error) {
return tt.response, nil
return tt.response, tt.requestErr
}},
UseResponseCheckOnly: tt.useResponseCheckOnly,
})
Expand All @@ -628,3 +637,33 @@ func TestResponseCheckOnly(t *testing.T) {
})
}
}

func TestProductCheckError(t *testing.T) {
var requestPaths []string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestPaths = append(requestPaths, r.URL.Path)
if len(requestPaths) == 1 {
// Simulate transient error from a proxy on the first request.
// This must not be cached by the client.
w.WriteHeader(http.StatusBadGateway)
return
}
w.Header().Set("X-Elastic-Product", "Elasticsearch")
w.Write([]byte("{}"))
}))
defer server.Close()

c, _ := NewClient(Config{Addresses: []string{server.URL}, DisableRetry: true})
if _, err := c.Cat.Indices(); err == nil {
t.Fatal("expected error")
}
if _, err := c.Cat.Indices(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
if n := len(requestPaths); n != 3 {
t.Fatalf("expected 3 requests, got %d", n)
}
if !reflect.DeepEqual(requestPaths, []string{"/", "/", "/_cat/indices"}) {
t.Fatalf("unexpected request paths: %s", requestPaths)
}
}