Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions elasticsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ func init() {
}

const (
defaultURL = "http://localhost:9200"
tagline = "You Know, for Search"
unknownProduct = "the client noticed that the server is not Elasticsearch and we do not support this unknown product"
defaultURL = "http://localhost:9200"
tagline = "You Know, for Search"
unknownProduct = "the client noticed that the server is not Elasticsearch and we do not support this unknown product"
unsupportedProduct = "the client noticed that the server is not a supported distribution of Elasticsearch"
)

// Version returns the package version as a string.
Expand Down Expand Up @@ -121,6 +122,7 @@ type info struct {
Tagline string `json:"tagline"`
}


// NewDefaultClient creates a new client with default options.
//
// It will use http://localhost:9200 as the default address.
Expand Down Expand Up @@ -232,7 +234,7 @@ func NewClient(cfg Config) (*Client, error) {
//
func genuineCheckHeader(header http.Header) error {
if header.Get("X-Elastic-Product") != "Elasticsearch" {
return fmt.Errorf(unknownProduct)
return errors.New(unknownProduct)
}
return nil
}
Expand All @@ -246,18 +248,19 @@ func genuineCheckInfo(info info) error {
}

if major < 6 {
return fmt.Errorf(unknownProduct)
return errors.New(unknownProduct)
}
if major < 7 {
if info.Tagline != tagline {
return fmt.Errorf(unknownProduct)
return errors.New(unknownProduct)
}
}
if major >= 7 {
if minor < 14 {
if info.Tagline != tagline ||
info.Version.BuildFlavor != "default" {
return fmt.Errorf(unknownProduct)
if info.Tagline != tagline {
return errors.New(unknownProduct)
} else if info.Version.BuildFlavor != "default" {
return errors.New(unsupportedProduct)
}
}
}
Expand Down Expand Up @@ -312,18 +315,24 @@ func (c *Client) doProductCheck(f func() error) error {
c.productCheckMu.RLock()
productCheckSuccess := c.productCheckSuccess
c.productCheckMu.RUnlock()

if productCheckSuccess {
return nil
}

c.productCheckMu.Lock()
defer c.productCheckMu.Unlock()

if c.productCheckSuccess {
return nil
}

if err := f(); err != nil {
return err
}

c.productCheckSuccess = true

return nil
}

Expand All @@ -338,6 +347,10 @@ func (c *Client) productCheck() error {
return err
}

if res.IsError() {
return fmt.Errorf("cannot retrieve info from Elasticsearch")
}

contentType := res.Header.Get("Content-Type")
if res.Body != nil {
defer res.Body.Close()
Expand All @@ -350,19 +363,14 @@ func (c *Client) productCheck() error {
}
}

switch res.StatusCode {
case http.StatusForbidden:
case http.StatusUnauthorized:
break
default:
err = genuineCheckHeader(res.Header)
err = genuineCheckHeader(res.Header)

if err != nil {
if info.Version.Number != "" {
err = genuineCheckInfo(info)
}
if err != nil {
if info.Version.Number != "" {
err = genuineCheckInfo(info)
}
}

if err != nil {
return err
}
Expand Down
44 changes: 43 additions & 1 deletion elasticsearch_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ func TestGenuineCheckInfo(t *testing.T) {
name string
info info
wantErr bool
err error
}{
{
name: "Genuine Elasticsearch 7.14.0",
Expand All @@ -482,6 +483,7 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Search",
},
wantErr: false,
err: nil,
},
{
name: "Genuine Elasticsearch 6.15.1",
Expand All @@ -493,6 +495,7 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Search",
},
wantErr: false,
err: nil,
},
{
name: "Not so genuine Elasticsearch 7 major",
Expand All @@ -504,6 +507,7 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Search",
},
wantErr: true,
err: errors.New(unknownProduct),
},
{
name: "Not so genuine Elasticsearch 6 major",
Expand All @@ -515,6 +519,7 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Fun",
},
wantErr: true,
err: errors.New(unknownProduct),
},
{
name: "Way older Elasticsearch major",
Expand All @@ -526,11 +531,24 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Fun",
},
wantErr: true,
err: errors.New(unknownProduct),
},
{
name: "Elasticsearch oss",
info: info{
Version: esVersion{
Number: "7.10.0",
BuildFlavor: "oss",
},
Tagline: "You Know, for Search",
},
wantErr: true,
err: errors.New(unsupportedProduct),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := genuineCheckInfo(tt.info); (err != nil) != tt.wantErr {
if err := genuineCheckInfo(tt.info); (err != nil) != tt.wantErr && err != tt.err {
t.Errorf("genuineCheckInfo() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down Expand Up @@ -621,6 +639,24 @@ func TestResponseCheckOnly(t *testing.T) {
requestErr: errors.New("request failed"),
wantErr: true,
},
{
name: "Valid request, 500 response",
useResponseCheckOnly: false,
response: &http.Response{
StatusCode: http.StatusInternalServerError,
},
requestErr: nil,
wantErr: true,
},
{
name: "Valid request, 404 response",
useResponseCheckOnly: false,
response: &http.Response{
StatusCode: http.StatusNotFound,
},
requestErr: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -657,6 +693,9 @@ func TestProductCheckError(t *testing.T) {
if _, err := c.Cat.Indices(); err == nil {
t.Fatal("expected error")
}
if c.productCheckSuccess {
t.Fatalf("product check should be invalid, got %v", c.productCheckSuccess)
}
if _, err := c.Cat.Indices(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
Expand All @@ -666,4 +705,7 @@ func TestProductCheckError(t *testing.T) {
if !reflect.DeepEqual(requestPaths, []string{"/", "/", "/_cat/indices"}) {
t.Fatalf("unexpected request paths: %s", requestPaths)
}
if !c.productCheckSuccess {
t.Fatalf("product check should be valid, got : %v", c.productCheckSuccess)
}
}
8 changes: 7 additions & 1 deletion esapi/esapi_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ var (
Body: ioutil.NopCloser(strings.NewReader("MOCK")),
}
defaultRoundTripFn = func(*http.Request) (*http.Response, error) { return defaultResponse, nil }
errorRoundTripFn = func(*http.Request) (*http.Response, error) {
errorRoundTripFn = func(request *http.Request) (*http.Response, error) {
if request.URL.Path == "/" {
return &http.Response{
StatusCode: 200,
Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}},
}, nil
}
return &http.Response{
Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}},
StatusCode: 400,
Expand Down