Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions elasticsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
Expand All @@ -47,9 +49,10 @@ func init() {
}

const (
defaultURL = "http://localhost:9200"
tagline = "You Know, for Search"
unknownProduct = "the client noticed that the server is not Elasticsearch and we do not support this unknown product"
defaultURL = "http://localhost:9200"
tagline = "You Know, for Search"
unknownProduct = "the client noticed that the server is not Elasticsearch and we do not support this unknown product"
unsupportedProduct = "the client noticed that the server is not a supported distribution of Elasticsearch"
)

// Version returns the package version as a string.
Expand Down Expand Up @@ -119,6 +122,7 @@ type info struct {
Tagline string `json:"tagline"`
}


// NewDefaultClient creates a new client with default options.
//
// It will use http://localhost:9200 as the default address.
Expand Down Expand Up @@ -228,7 +232,7 @@ func NewClient(cfg Config) (*Client, error) {
//
func genuineCheckHeader(header http.Header) error {
if header.Get("X-Elastic-Product") != "Elasticsearch" {
return fmt.Errorf(unknownProduct)
return errors.New(unknownProduct)
}
return nil
}
Expand All @@ -242,18 +246,19 @@ func genuineCheckInfo(info info) error {
}

if major < 6 {
return fmt.Errorf(unknownProduct)
return errors.New(unknownProduct)
}
if major < 7 {
if info.Tagline != tagline {
return fmt.Errorf(unknownProduct)
return errors.New(unknownProduct)
}
}
if major >= 7 {
if minor < 14 {
if info.Tagline != tagline ||
info.Version.BuildFlavor != "default" {
return fmt.Errorf(unknownProduct)
if info.Tagline != tagline {
return errors.New(unknownProduct)
} else if info.Version.BuildFlavor != "default" {
return errors.New(unsupportedProduct)
}
}
}
Expand Down Expand Up @@ -308,62 +313,73 @@ func (c *Client) doProductCheck(f func() error) error {
c.productCheckMu.RLock()
productCheckSuccess := c.productCheckSuccess
c.productCheckMu.RUnlock()

if productCheckSuccess {
return nil
}

c.productCheckMu.Lock()
defer c.productCheckMu.Unlock()

if c.productCheckSuccess {
return nil
}

if err := f(); err != nil {
return err
}

c.productCheckSuccess = true

return nil
}

// productCheck runs an esapi.Info query to retrieve informations of the current cluster
// decodes the response and decides if the cluster is a genuine Elasticsearch product.
func (c *Client) productCheck() error {
var info info

req := esapi.InfoRequest{}
res, err := req.Do(context.Background(), c.Transport)
if err != nil {
return err
}
defer res.Body.Close()

contentType := res.Header.Get("Content-Type")
if res.Body != nil {
defer res.Body.Close()

if strings.Contains(contentType, "json") {
decoder := json.NewDecoder(res.Body)
err = decoder.Decode(&info)
if err != nil {
return fmt.Errorf("error decoding Elasticsearch informations: %s", err)
}
if res.IsError() {
_, err = io.Copy(ioutil.Discard, res.Body)
if err != nil {
return err
}

switch res.StatusCode {
case http.StatusForbidden:
case http.StatusUnauthorized:
break
return nil
case http.StatusForbidden:
return nil
default:
err = genuineCheckHeader(res.Header)
return fmt.Errorf("cannot retrieve informations from Elasticsearch")
}
}

err = genuineCheckHeader(res.Header)

if err != nil {
var info info
contentType := res.Header.Get("Content-Type")
if strings.Contains(contentType, "json") {
err = json.NewDecoder(res.Body).Decode(&info)
if err != nil {
if info.Version.Number != "" {
err = genuineCheckInfo(info)
}
return fmt.Errorf("error decoding Elasticsearch informations: %s", err)
}
}
if err != nil {
return err

if info.Version.Number != "" {
err = genuineCheckInfo(info)
}
}

if err != nil {
return err
}

return nil
}

Expand Down
66 changes: 65 additions & 1 deletion elasticsearch_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ func TestGenuineCheckInfo(t *testing.T) {
name string
info info
wantErr bool
err error
}{
{
name: "Genuine Elasticsearch 7.14.0",
Expand All @@ -482,6 +483,7 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Search",
},
wantErr: false,
err: nil,
},
{
name: "Genuine Elasticsearch 6.15.1",
Expand All @@ -493,6 +495,7 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Search",
},
wantErr: false,
err: nil,
},
{
name: "Not so genuine Elasticsearch 7 major",
Expand All @@ -504,6 +507,7 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Search",
},
wantErr: true,
err: errors.New(unknownProduct),
},
{
name: "Not so genuine Elasticsearch 6 major",
Expand All @@ -515,6 +519,7 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Fun",
},
wantErr: true,
err: errors.New(unknownProduct),
},
{
name: "Way older Elasticsearch major",
Expand All @@ -526,11 +531,24 @@ func TestGenuineCheckInfo(t *testing.T) {
Tagline: "You Know, for Fun",
},
wantErr: true,
err: errors.New(unknownProduct),
},
{
name: "Elasticsearch oss",
info: info{
Version: esVersion{
Number: "7.10.0",
BuildFlavor: "oss",
},
Tagline: "You Know, for Search",
},
wantErr: true,
err: errors.New(unsupportedProduct),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := genuineCheckInfo(tt.info); (err != nil) != tt.wantErr {
if err := genuineCheckInfo(tt.info); (err != nil) != tt.wantErr && err != tt.err {
t.Errorf("genuineCheckInfo() error = %v, wantErr %v", err, tt.wantErr)
}
})
Expand Down Expand Up @@ -621,6 +639,46 @@ func TestResponseCheckOnly(t *testing.T) {
requestErr: errors.New("request failed"),
wantErr: true,
},
{
name: "Valid request, 500 response",
useResponseCheckOnly: false,
response: &http.Response{
StatusCode: http.StatusInternalServerError,
Body: ioutil.NopCloser(strings.NewReader("")),
},
requestErr: nil,
wantErr: true,
},
{
name: "Valid request, 404 response",
useResponseCheckOnly: false,
response: &http.Response{
StatusCode: http.StatusNotFound,
Body: ioutil.NopCloser(strings.NewReader("")),
},
requestErr: nil,
wantErr: true,
},
{
name: "Valid request, 403 response",
useResponseCheckOnly: false,
response: &http.Response{
StatusCode: http.StatusForbidden,
Body: ioutil.NopCloser(strings.NewReader("")),
},
requestErr: nil,
wantErr: false,
},
{
name: "Valid request, 401 response",
useResponseCheckOnly: false,
response: &http.Response{
StatusCode: http.StatusUnauthorized,
Body: ioutil.NopCloser(strings.NewReader("")),
},
requestErr: nil,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -657,6 +715,9 @@ func TestProductCheckError(t *testing.T) {
if _, err := c.Cat.Indices(); err == nil {
t.Fatal("expected error")
}
if c.productCheckSuccess {
t.Fatalf("product check should be invalid, got %v", c.productCheckSuccess)
}
if _, err := c.Cat.Indices(); err != nil {
t.Fatalf("unexpected error: %s", err)
}
Expand All @@ -666,4 +727,7 @@ func TestProductCheckError(t *testing.T) {
if !reflect.DeepEqual(requestPaths, []string{"/", "/", "/_cat/indices"}) {
t.Fatalf("unexpected request paths: %s", requestPaths)
}
if !c.productCheckSuccess {
t.Fatalf("product check should be valid, got : %v", c.productCheckSuccess)
}
}
9 changes: 8 additions & 1 deletion esapi/esapi_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,14 @@ var (
Body: ioutil.NopCloser(strings.NewReader("MOCK")),
}
defaultRoundTripFn = func(*http.Request) (*http.Response, error) { return defaultResponse, nil }
errorRoundTripFn = func(*http.Request) (*http.Response, error) {
errorRoundTripFn = func(request *http.Request) (*http.Response, error) {
if request.URL.Path == "/" {
return &http.Response{
StatusCode: 200,
Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}},
Body: ioutil.NopCloser(strings.NewReader("{}")),
}, nil
}
return &http.Response{
Header: http.Header{"X-Elastic-Product": []string{"Elasticsearch"}},
StatusCode: 400,
Expand Down