Skip to content

Commit

Permalink
Add MaxResponseSize to guard against OOMs
Browse files Browse the repository at this point in the history
This commit adds an optional `PerformRequestOptions.MaxResponseSize` to guard against OOMs. It is added to the `SearchService` as well.

Close #929
  • Loading branch information
Erik Grinaker authored and olivere committed Oct 11, 2018
1 parent 12de707 commit d4bc527
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 17 deletions.
21 changes: 11 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1230,14 +1230,15 @@ func (c *Client) mustActiveConn() error {

// PerformRequestOptions must be passed into PerformRequest.
type PerformRequestOptions struct {
Method string
Path string
Params url.Values
Body interface{}
ContentType string
IgnoreErrors []int
Retrier Retrier
Headers http.Header
Method string
Path string
Params url.Values
Body interface{}
ContentType string
IgnoreErrors []int
Retrier Retrier
Headers http.Header
MaxResponseSize int64
}

// PerformRequest does a HTTP request to Elasticsearch.
Expand Down Expand Up @@ -1376,14 +1377,14 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions)
if err := checkResponse((*http.Request)(req), res, opt.IgnoreErrors...); err != nil {
// No retry if request succeeded
// We still try to return a response.
resp, _ = c.newResponse(res)
resp, _ = c.newResponse(res, opt.MaxResponseSize)
return resp, err
}

// We successfully made a request with this connection
conn.MarkAsHealthy()

resp, err = c.newResponse(res)
resp, err = c.newResponse(res, opt.MaxResponseSize)
if err != nil {
return nil, err
}
Expand Down
27 changes: 27 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,33 @@ func TestPerformRequestWithCustomLogger(t *testing.T) {
}
}

func TestPerformRequestWithMaxResponseSize(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
Method: "GET",
Path: "/",
MaxResponseSize: 1000,
})
if err != nil {
t.Fatal(err)
}
if res == nil {
t.Fatal("expected response to be != nil")
}

res, err = client.PerformRequest(context.TODO(), PerformRequestOptions{
Method: "GET",
Path: "/",
MaxResponseSize: 100,
})
if err != ErrResponseSize {
t.Fatal("expected response size error")
}
}

// failingTransport will run a fail callback if it sees a given URL path prefix.
type failingTransport struct {
path string // path prefix to look for
Expand Down
21 changes: 19 additions & 2 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@ package elastic

import (
"encoding/json"
"errors"
"io"
"io/ioutil"
"net/http"
)

var (
// ErrResponseSize is raised if a response body exceeds the given max body size.
ErrResponseSize = errors.New("response size too large")
)

// Response represents a response from Elasticsearch.
type Response struct {
// StatusCode is the HTTP status code, e.g. 200.
Expand All @@ -22,16 +29,26 @@ type Response struct {
}

// newResponse creates a new response from the HTTP response.
func (c *Client) newResponse(res *http.Response) (*Response, error) {
func (c *Client) newResponse(res *http.Response, maxBodySize int64) (*Response, error) {
r := &Response{
StatusCode: res.StatusCode,
Header: res.Header,
}
if res.Body != nil {
slurp, err := ioutil.ReadAll(res.Body)
body := io.Reader(res.Body)
if maxBodySize > 0 {
if res.ContentLength > maxBodySize {
return nil, ErrResponseSize
}
body = io.LimitReader(body, maxBodySize+1)
}
slurp, err := ioutil.ReadAll(body)
if err != nil {
return nil, err
}
if maxBodySize > 0 && int64(len(slurp)) > maxBodySize {
return nil, ErrResponseSize
}
// HEAD requests return a body but no content
if len(slurp) > 0 {
r.Body = json.RawMessage(slurp)
Expand Down
2 changes: 1 addition & 1 deletion response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func BenchmarkResponse(b *testing.B) {
StatusCode: http.StatusOK,
}
var err error
resp, err = c.newResponse(res)
resp, err = c.newResponse(res, 0)
if err != nil {
b.Fatal(err)
}
Expand Down
17 changes: 13 additions & 4 deletions search.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type SearchService struct {
ignoreUnavailable *bool
allowNoIndices *bool
expandWildcards string
maxResponseSize int64
}

// NewSearchService creates a new service for searching in Elasticsearch.
Expand Down Expand Up @@ -312,6 +313,13 @@ func (s *SearchService) ExpandWildcards(expandWildcards string) *SearchService {
return s
}

// MaxResponseSize sets an upper limit on the response body size that we accept,
// to guard against OOM situations.
func (s *SearchService) MaxResponseSize(maxResponseSize int64) *SearchService {
s.maxResponseSize = maxResponseSize
return s
}

// buildURL builds the URL for the operation.
func (s *SearchService) buildURL() (string, url.Values, error) {
var err error
Expand Down Expand Up @@ -399,10 +407,11 @@ func (s *SearchService) Do(ctx context.Context) (*SearchResult, error) {
body = src
}
res, err := s.client.PerformRequest(ctx, PerformRequestOptions{
Method: "POST",
Path: path,
Params: params,
Body: body,
Method: "POST",
Path: path,
Params: params,
Body: body,
MaxResponseSize: s.maxResponseSize,
})
if err != nil {
return nil, err
Expand Down

0 comments on commit d4bc527

Please sign in to comment.