From c4a521bdd675ba5beddb3e011977d1ee219c238d Mon Sep 17 00:00:00 2001
From: Thejas-bhat <35959007+Thejas-bhat@users.noreply.github.com>
Date: Tue, 13 Jun 2023 22:19:18 +0530
Subject: [PATCH] recording the cost for different types of queries (#1829)

* new callback for tracking bytes read for different query types

* bug fix: invocations of the done logic, accounting regexp cost

* removing duplicate code

* accounting cost of other geo types and code cleanup

* unit test fixes

* rename bytesRead to search_cost

* adding an abort invocation, to indicate that the context was cancelled

* code cleanup with respect to messages for callback

* comments around the searchResult struct

* bug fix: handling a nil pointer check in statsMap() API
---
 index/scorch/scorch.go                     |  4 ++
 index/scorch/snapshot_index_tfr.go         |  8 ++--
 index_impl.go                              | 16 ++++---
 index_test.go                              | 26 +++++------
 search.go                                  | 30 +++++++++----
 search/collector/topn.go                   |  4 ++
 search/query/geo_boundingbox.go            |  2 +
 search/query/geo_boundingpolygon.go        |  2 +
 search/query/geo_distance.go               |  2 +
 search/query/geo_shape.go                  |  2 +
 search/query/numeric_range.go              |  1 +
 search/search.go                           |  4 --
 search/searcher/search_fuzzy.go            | 13 +++---
 search/searcher/search_geoboundingbox.go   | 11 +++--
 search/searcher/search_geopointdistance.go |  9 +++-
 search/searcher/search_geopolygon.go       |  9 +++-
 search/searcher/search_geoshape.go         |  9 +++-
 search/searcher/search_numeric_range.go    |  6 ++-
 search/searcher/search_regexp.go           |  3 +-
 search/searcher/search_term.go             | 14 ++++++
 search/searcher/search_term_prefix.go      |  3 +-
 search/searcher/search_term_range.go       |  3 +-
 search/util.go                             | 51 ++++++++++++++++++++++
 test/versus_test.go                        |  4 +-
 24 files changed, 181 insertions(+), 55 deletions(-)

diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index a4c88b765..f30d795e9 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -588,6 +588,10 @@ func (s *Scorch) StatsMap() map[string]interface{} {
 	m := s.stats.ToMap()
 
 	indexSnapshot := s.currentSnapshot()
+	if indexSnapshot == nil {
+		return nil
+	}
+
 	defer func() {
 		_ = indexSnapshot.Close()
 	}()
diff --git a/index/scorch/snapshot_index_tfr.go b/index/scorch/snapshot_index_tfr.go
index 349620c71..9f0315fa8 100644
--- a/index/scorch/snapshot_index_tfr.go
+++ b/index/scorch/snapshot_index_tfr.go
@@ -102,10 +102,10 @@ func (i *IndexSnapshotTermFieldReader) Next(preAlloced *index.TermFieldDoc) (*in
 			// this is because there are chances of having a series of loadChunk calls,
 			// and they have to be added together before sending the bytesRead at this point
 			// upstream.
-			if delta := i.iterators[i.segmentOffset].BytesRead() - prevBytesRead; delta > 0 {
-				i.incrementBytesRead(delta)
+			bytesRead := i.iterators[i.segmentOffset].BytesRead()
+			if bytesRead > prevBytesRead {
+				i.incrementBytesRead(bytesRead - prevBytesRead)
 			}
-
 			return rv, nil
 		}
 		i.segmentOffset++
@@ -204,6 +204,8 @@ func (i *IndexSnapshotTermFieldReader) Close() error {
 			// reader's bytesRead value
 			statsCallbackFn.(search.SearchIOStatsCallbackFunc)(i.bytesRead)
 		}
+
+		search.RecordSearchCost(i.ctx, search.AddM, i.bytesRead)
 	}
 
 	if i.snapshot != nil {
diff --git a/index_impl.go b/index_impl.go
index c5a0c46f4..b5f115411 100644
--- a/index_impl.go
+++ b/index_impl.go
@@ -474,9 +474,9 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 	// accounted by invoking this callback when the TFR is closed.
 	// 2. the docvalues portion (accounted in collector) and the retrieval
 	// of stored fields bytes (by LoadAndHighlightFields)
-	var totalBytesRead uint64
+	var totalSearchCost uint64
 	sendBytesRead := func(bytesRead uint64) {
-		totalBytesRead += bytesRead
+		totalSearchCost += bytesRead
 	}
 	ctx = context.WithValue(ctx, search.SearchIOStatsCallbackKey,
@@ -495,11 +495,13 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 			err = serr
 		}
 		if sr != nil {
-			sr.BytesRead = totalBytesRead
+			sr.Cost = totalSearchCost
 		}
 		if sr, ok := indexReader.(*scorch.IndexSnapshot); ok {
-			sr.UpdateIOStats(totalBytesRead)
+			sr.UpdateIOStats(totalSearchCost)
 		}
+
+		search.RecordSearchCost(ctx, search.DoneM, 0)
 	}()
 
 	if req.Facets != nil {
@@ -574,6 +576,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 		}
 	}
 
+	var storedFieldsCost uint64
 	for _, hit := range hits {
 		if i.name != "" {
 			hit.Index = i.name
@@ -582,9 +585,12 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 		if err != nil {
 			return nil, err
 		}
-		totalBytesRead += storedFieldsBytes
+		storedFieldsCost += storedFieldsBytes
 	}
 
+	totalSearchCost += storedFieldsCost
+	search.RecordSearchCost(ctx, search.AddM, storedFieldsCost)
+
 	atomic.AddUint64(&i.stats.searches, 1)
 	searchDuration := time.Since(searchStart)
 	atomic.AddUint64(&i.stats.searchTime, uint64(searchDuration))
diff --git a/index_test.go b/index_test.go
index 0c89b4e6d..e9853101e 100644
--- a/index_test.go
+++ b/index_test.go
@@ -401,7 +401,7 @@ func TestBytesRead(t *testing.T) {
 	}
 	stats, _ := idx.StatsMap()["index"].(map[string]interface{})
 	prevBytesRead, _ := stats["num_bytes_read_at_query_time"].(uint64)
-	if prevBytesRead != 32349 && res.BytesRead == prevBytesRead {
+	if prevBytesRead != 32349 && res.Cost == prevBytesRead {
 		t.Fatalf("expected bytes read for query string 32349, got %v", prevBytesRead)
 	}
 
@@ -415,7 +415,7 @@ func TestBytesRead(t *testing.T) {
 	}
 	stats, _ = idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ := stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 23 && res.BytesRead == bytesRead-prevBytesRead {
+	if bytesRead-prevBytesRead != 23 && res.Cost == bytesRead-prevBytesRead {
 		t.Fatalf("expected bytes read for query string 23, got %v", bytesRead-prevBytesRead)
 	}
 
@@ -431,7 +431,7 @@ func TestBytesRead(t *testing.T) {
 	}
 	stats, _ = idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 8468 && res.BytesRead == bytesRead-prevBytesRead {
+	if bytesRead-prevBytesRead != 8468 && res.Cost == bytesRead-prevBytesRead {
 		t.Fatalf("expected bytes read for fuzzy query is 8468, got %v", bytesRead-prevBytesRead)
 	}
 
@@ -448,7 +448,7 @@ func TestBytesRead(t *testing.T) {
 
 	stats, _ = idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if !approxSame(bytesRead-prevBytesRead, 150) && res.BytesRead == bytesRead-prevBytesRead {
+	if !approxSame(bytesRead-prevBytesRead, 150) && res.Cost == bytesRead-prevBytesRead {
 		t.Fatalf("expected bytes read for faceted query is around 150, got %v", bytesRead-prevBytesRead)
 	}
 
@@ -466,7 +466,7 @@ func TestBytesRead(t *testing.T) {
 
 	stats, _ = idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 924 && res.BytesRead == bytesRead-prevBytesRead {
+	if bytesRead-prevBytesRead != 924 && res.Cost == bytesRead-prevBytesRead {
 		t.Fatalf("expected bytes read for numeric range query is 924, got %v", bytesRead-prevBytesRead)
 	}
 
@@ -481,7 +481,7 @@ func TestBytesRead(t *testing.T) {
 
 	stats, _ = idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 60 && res.BytesRead == bytesRead-prevBytesRead {
+	if bytesRead-prevBytesRead != 60 && res.Cost == bytesRead-prevBytesRead {
 		t.Fatalf("expected bytes read for query with highlighter is 60, got %v", bytesRead-prevBytesRead)
 	}
 
@@ -498,7 +498,7 @@ func TestBytesRead(t *testing.T) {
 	// since it's created afresh and not reused
 	stats, _ = idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 83 && res.BytesRead == bytesRead-prevBytesRead {
+	if bytesRead-prevBytesRead != 83 && res.Cost == bytesRead-prevBytesRead {
 		t.Fatalf("expected bytes read for disjunction query is 83, got %v", bytesRead-prevBytesRead)
 	}
 
@@ -580,7 +580,7 @@ func TestBytesReadStored(t *testing.T) {
 
 	stats, _ := idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ := stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead != 25928 && bytesRead == res.BytesRead {
+	if bytesRead != 25928 && bytesRead == res.Cost {
 		t.Fatalf("expected the bytes read stat to be around 25928, got %v", bytesRead)
 	}
 	prevBytesRead := bytesRead
@@ -592,7 +592,7 @@ func TestBytesReadStored(t *testing.T) {
 	}
 	stats, _ = idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 15 && bytesRead-prevBytesRead == res.BytesRead {
+	if bytesRead-prevBytesRead != 15 && bytesRead-prevBytesRead == res.Cost {
 		t.Fatalf("expected the bytes read stat to be around 15, got %v", bytesRead-prevBytesRead)
 	}
 	prevBytesRead = bytesRead
@@ -607,7 +607,7 @@ func TestBytesReadStored(t *testing.T) {
 
 	stats, _ = idx.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 26478 && bytesRead-prevBytesRead == res.BytesRead {
+	if bytesRead-prevBytesRead != 26478 && bytesRead-prevBytesRead == res.Cost {
 		t.Fatalf("expected the bytes read stat to be around 26478, got %v", bytesRead-prevBytesRead)
 	}
 
@@ -651,7 +651,7 @@ func TestBytesReadStored(t *testing.T) {
 
 	stats, _ = idx1.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead != 18114 && bytesRead == res.BytesRead {
+	if bytesRead != 18114 && bytesRead == res.Cost {
 		t.Fatalf("expected the bytes read stat to be around 18114, got %v", bytesRead)
 	}
 	prevBytesRead = bytesRead
@@ -662,7 +662,7 @@ func TestBytesReadStored(t *testing.T) {
 	}
 	stats, _ = idx1.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 12 && bytesRead-prevBytesRead == res.BytesRead {
+	if bytesRead-prevBytesRead != 12 && bytesRead-prevBytesRead == res.Cost {
 		t.Fatalf("expected the bytes read stat to be around 12, got %v", bytesRead-prevBytesRead)
 	}
 	prevBytesRead = bytesRead
@@ -675,7 +675,7 @@ func TestBytesReadStored(t *testing.T) {
 
 	stats, _ = idx1.StatsMap()["index"].(map[string]interface{})
 	bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64)
-	if bytesRead-prevBytesRead != 42 && bytesRead-prevBytesRead == res.BytesRead {
+	if bytesRead-prevBytesRead != 42 && bytesRead-prevBytesRead == res.Cost {
 		t.Fatalf("expected the bytes read stat to be around 42, got %v", bytesRead-prevBytesRead)
around 42, got %v", bytesRead-prevBytesRead) } } diff --git a/search.go b/search.go index acb812ada..fe426164a 100644 --- a/search.go +++ b/search.go @@ -485,15 +485,27 @@ func (ss *SearchStatus) Merge(other *SearchStatus) { // A SearchResult describes the results of executing // a SearchRequest. +// +// Status - Whether the search was executed on the underlying indexes successfully +// or failed, and the corresponding errors. +// Request - The SearchRequest that was executed. +// Hits - The list of documents that matched the query and their corresponding +// scores, score explanation, location info and so on. +// Total - The total number of documents that matched the query. +// Cost - indicates how expensive was the query with respect to bytes read +// from the mmaped index files. +// MaxScore - The maximum score seen across all document hits seen for this query. +// Took - The time taken to execute the search. +// Facets - The facet results for the search. type SearchResult struct { - Status *SearchStatus `json:"status"` - Request *SearchRequest `json:"request"` - Hits search.DocumentMatchCollection `json:"hits"` - Total uint64 `json:"total_hits"` - BytesRead uint64 `json:"bytesRead"` - MaxScore float64 `json:"max_score"` - Took time.Duration `json:"took"` - Facets search.FacetResults `json:"facets"` + Status *SearchStatus `json:"status"` + Request *SearchRequest `json:"request"` + Hits search.DocumentMatchCollection `json:"hits"` + Total uint64 `json:"total_hits"` + Cost uint64 `json:"cost"` + MaxScore float64 `json:"max_score"` + Took time.Duration `json:"took"` + Facets search.FacetResults `json:"facets"` } func (sr *SearchResult) Size() int { @@ -566,7 +578,7 @@ func (sr *SearchResult) Merge(other *SearchResult) { sr.Status.Merge(other.Status) sr.Hits = append(sr.Hits, other.Hits...) sr.Total += other.Total - sr.BytesRead += other.BytesRead + sr.Cost += other.Cost if other.MaxScore > sr.MaxScore { sr.MaxScore = other.MaxScore } diff --git a/search/collector/topn.go b/search/collector/topn.go index 4d19cd455..808f396d0 100644 --- a/search/collector/topn.go +++ b/search/collector/topn.go @@ -200,6 +200,7 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, hc.needDocIds = hc.needDocIds || loadID select { case <-ctx.Done(): + search.RecordSearchCost(ctx, search.AbortM, 0) return ctx.Err() default: next, err = searcher.Next(searchContext) @@ -208,6 +209,7 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, if hc.total%CheckDoneEvery == 0 { select { case <-ctx.Done(): + search.RecordSearchCost(ctx, search.AbortM, 0) return ctx.Err() default: } @@ -232,6 +234,8 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, // total bytes read as part of docValues being read every hit // which must be accounted by invoking the callback. 
 		statsCallbackFn.(search.SearchIOStatsCallbackFunc)(hc.bytesRead)
+
+		search.RecordSearchCost(ctx, search.AddM, hc.bytesRead)
 	}
 
 	// help finalize/flush the results in case
diff --git a/search/query/geo_boundingbox.go b/search/query/geo_boundingbox.go
index ac9125393..1397c7799 100644
--- a/search/query/geo_boundingbox.go
+++ b/search/query/geo_boundingbox.go
@@ -63,6 +63,8 @@ func (q *GeoBoundingBoxQuery) Searcher(ctx context.Context, i index.IndexReader,
 		field = m.DefaultSearchField()
 	}
 
+	ctx = context.WithValue(ctx, search.QueryTypeKey, search.Geo)
+
 	if q.BottomRight[0] < q.TopLeft[0] {
 		// cross date line, rewrite as two parts
diff --git a/search/query/geo_boundingpolygon.go b/search/query/geo_boundingpolygon.go
index 467f39b28..baae514d9 100644
--- a/search/query/geo_boundingpolygon.go
+++ b/search/query/geo_boundingpolygon.go
@@ -61,6 +61,8 @@ func (q *GeoBoundingPolygonQuery) Searcher(ctx context.Context, i index.IndexRea
 		field = m.DefaultSearchField()
 	}
 
+	ctx = context.WithValue(ctx, search.QueryTypeKey, search.Geo)
+
 	return searcher.NewGeoBoundedPolygonSearcher(ctx, i, q.Points, field,
 		q.BoostVal.Value(), options)
 }
diff --git a/search/query/geo_distance.go b/search/query/geo_distance.go
index f05bf6723..7977d1538 100644
--- a/search/query/geo_distance.go
+++ b/search/query/geo_distance.go
@@ -64,6 +64,8 @@ func (q *GeoDistanceQuery) Searcher(ctx context.Context, i index.IndexReader, m
 		field = m.DefaultSearchField()
 	}
 
+	ctx = context.WithValue(ctx, search.QueryTypeKey, search.Geo)
+
 	dist, err := geo.ParseDistance(q.Distance)
 	if err != nil {
 		return nil, err
 	}
diff --git a/search/query/geo_shape.go b/search/query/geo_shape.go
index a63ec80f7..2229dbe9c 100644
--- a/search/query/geo_shape.go
+++ b/search/query/geo_shape.go
@@ -107,6 +107,8 @@ func (q *GeoShapeQuery) Searcher(ctx context.Context, i index.IndexReader,
 		field = m.DefaultSearchField()
 	}
 
+	ctx = context.WithValue(ctx, search.QueryTypeKey, search.Geo)
+
 	return searcher.NewGeoShapeSearcher(ctx, i, q.Geometry.Shape, q.Geometry.Relation, field,
 		q.BoostVal.Value(), options)
 }
diff --git a/search/query/numeric_range.go b/search/query/numeric_range.go
index ad2474167..205ceecf6 100644
--- a/search/query/numeric_range.go
+++ b/search/query/numeric_range.go
@@ -77,6 +77,7 @@ func (q *NumericRangeQuery) Searcher(ctx context.Context, i index.IndexReader, m
 	if q.FieldVal == "" {
 		field = m.DefaultSearchField()
 	}
+	ctx = context.WithValue(ctx, search.QueryTypeKey, search.Numeric)
 	return searcher.NewNumericRangeSearcher(ctx, i, q.Min, q.Max, q.InclusiveMin, q.InclusiveMax, field, q.BoostVal.Value(), options)
 }
diff --git a/search/search.go b/search/search.go
index 69d8945f9..d2dd33712 100644
--- a/search/search.go
+++ b/search/search.go
@@ -27,10 +27,6 @@ var reflectStaticSizeDocumentMatch int
 var reflectStaticSizeSearchContext int
 var reflectStaticSizeLocation int
 
-const SearchIOStatsCallbackKey = "_search_io_stats_callback_key"
-
-type SearchIOStatsCallbackFunc func(uint64)
-
 func init() {
 	var dm DocumentMatch
 	reflectStaticSizeDocumentMatch = int(reflect.TypeOf(dm).Size())
diff --git a/search/searcher/search_fuzzy.go b/search/searcher/search_fuzzy.go
index 9423b611e..5345c272b 100644
--- a/search/searcher/search_fuzzy.go
+++ b/search/searcher/search_fuzzy.go
@@ -59,7 +59,8 @@ func NewFuzzySearcher(ctx context.Context, indexReader index.IndexReader, term s
 	}
 
 	if ctx != nil {
-		reportIOStats(dictBytesRead, ctx)
+		reportIOStats(ctx, dictBytesRead)
+		search.RecordSearchCost(ctx, search.AddM, dictBytesRead)
 	}
 
 	return NewMultiTermSearcher(ctx, indexReader, candidates, field,
@@ -71,13 +72,15 @@ type fuzzyCandidates struct {
 	bytesRead uint64
 }
 
-func reportIOStats(bytesRead uint64, ctx context.Context) {
+func reportIOStats(ctx context.Context, bytesRead uint64) {
 	// The fuzzy, regexp like queries essentially load a dictionary,
 	// which potentially incurs a cost that must be accounted by
 	// using the callback to report the value.
-	statsCallbackFn := ctx.Value(search.SearchIOStatsCallbackKey)
-	if statsCallbackFn != nil {
-		statsCallbackFn.(search.SearchIOStatsCallbackFunc)(bytesRead)
+	if ctx != nil {
+		statsCallbackFn := ctx.Value(search.SearchIOStatsCallbackKey)
+		if statsCallbackFn != nil {
+			statsCallbackFn.(search.SearchIOStatsCallbackFunc)(bytesRead)
+		}
 	}
 }
diff --git a/search/searcher/search_geoboundingbox.go b/search/searcher/search_geoboundingbox.go
index 05ca1bf95..c889ddce0 100644
--- a/search/searcher/search_geoboundingbox.go
+++ b/search/searcher/search_geoboundingbox.go
@@ -49,7 +49,7 @@ func NewGeoBoundingBoxSearcher(ctx context.Context, indexReader index.IndexReade
 			return nil, err
 		}
 
-		return NewFilteringSearcher(ctx, boxSearcher, buildRectFilter(dvReader,
+		return NewFilteringSearcher(ctx, boxSearcher, buildRectFilter(ctx, dvReader,
 			field, minLon, minLat, maxLon, maxLat)), nil
 	}
 }
@@ -85,7 +85,7 @@ func NewGeoBoundingBoxSearcher(ctx context.Context, indexReader index.IndexReade
 		}
 		// add filter to check points near the boundary
 		onBoundarySearcher = NewFilteringSearcher(ctx, rawOnBoundarySearcher,
-			buildRectFilter(dvReader, field, minLon, minLat, maxLon, maxLat))
+			buildRectFilter(ctx, dvReader, field, minLon, minLat, maxLon, maxLat))
 		openedSearchers = append(openedSearchers, onBoundarySearcher)
 	}
 
@@ -201,7 +201,7 @@ func buildIsIndexedFunc(ctx context.Context, indexReader index.IndexReader, fiel
 	return isIndexed, closeF, err
 }
 
-func buildRectFilter(dvReader index.DocValueReader, field string,
+func buildRectFilter(ctx context.Context, dvReader index.DocValueReader, field string,
 	minLon, minLat, maxLon, maxLat float64) FilterFunc {
 	return func(d *search.DocumentMatch) bool {
 		// check geo matches against all numeric type terms indexed
@@ -222,6 +222,11 @@ func buildRectFilter(dvReader index.DocValueReader, field string,
 			}
 		})
 		if err == nil && found {
+			bytes := dvReader.BytesRead()
+			if bytes > 0 {
+				reportIOStats(ctx, bytes)
+				search.RecordSearchCost(ctx, search.AddM, bytes)
+			}
 			for i := range lons {
 				if geo.BoundingBoxContains(lons[i], lats[i],
 					minLon, minLat, maxLon, maxLat) {
diff --git a/search/searcher/search_geopointdistance.go b/search/searcher/search_geopointdistance.go
index 01ed20929..fbe958953 100644
--- a/search/searcher/search_geopointdistance.go
+++ b/search/searcher/search_geopointdistance.go
@@ -66,7 +66,7 @@ func NewGeoPointDistanceSearcher(ctx context.Context, indexReader index.IndexRea
 
 	// wrap it in a filtering searcher which checks the actual distance
 	return NewFilteringSearcher(ctx, rectSearcher,
-		buildDistFilter(dvReader, field, centerLon, centerLat, dist)), nil
+		buildDistFilter(ctx, dvReader, field, centerLon, centerLat, dist)), nil
 }
 
 // boxSearcher builds a searcher for the described bounding box
@@ -113,7 +113,7 @@ func boxSearcher(ctx context.Context, indexReader index.IndexReader,
 	return boxSearcher, nil
 }
 
-func buildDistFilter(dvReader index.DocValueReader, field string,
+func buildDistFilter(ctx context.Context, dvReader index.DocValueReader, field string,
 	centerLon, centerLat, maxDist float64) FilterFunc {
 	return func(d *search.DocumentMatch) bool {
 		// check geo matches against all numeric type terms indexed
@@ -134,6 +134,11 @@ func buildDistFilter(dvReader index.DocValueReader, field string,
 			}
 		})
 		if err == nil && found {
+			bytes := dvReader.BytesRead()
+			if bytes > 0 {
+				reportIOStats(ctx, bytes)
+				search.RecordSearchCost(ctx, search.AddM, bytes)
+			}
 			for i := range lons {
 				dist := geo.Haversin(lons[i], lats[i], centerLon, centerLat)
 				if dist <= maxDist/1000 {
diff --git a/search/searcher/search_geopolygon.go b/search/searcher/search_geopolygon.go
index 1d6538adf..a43edafbb 100644
--- a/search/searcher/search_geopolygon.go
+++ b/search/searcher/search_geopolygon.go
@@ -71,7 +71,7 @@ func NewGeoBoundedPolygonSearcher(ctx context.Context, indexReader index.IndexRe
 
 	// wrap it in a filtering searcher that checks for the polygon inclusivity
 	return NewFilteringSearcher(ctx, rectSearcher,
-		buildPolygonFilter(dvReader, field, coordinates)), nil
+		buildPolygonFilter(ctx, dvReader, field, coordinates)), nil
 }
 
 const float64EqualityThreshold = 1e-6
@@ -83,7 +83,7 @@ func almostEqual(a, b float64) bool {
 // buildPolygonFilter returns true if the point lies inside the
 // polygon. It is based on the ray-casting technique as referred
 // here: https://wrf.ecse.rpi.edu/nikola/pubdetails/pnpoly.html
-func buildPolygonFilter(dvReader index.DocValueReader, field string,
+func buildPolygonFilter(ctx context.Context, dvReader index.DocValueReader, field string,
 	coordinates []geo.Point) FilterFunc {
 	return func(d *search.DocumentMatch) bool {
 		// check geo matches against all numeric type terms indexed
@@ -107,6 +107,11 @@ func buildPolygonFilter(dvReader index.DocValueReader, field string,
 		// Note: this approach works for points which are strictly inside
 		// the polygon. ie it might fail for certain points on the polygon boundaries.
 		if err == nil && found {
+			bytes := dvReader.BytesRead()
+			if bytes > 0 {
+				reportIOStats(ctx, bytes)
+				search.RecordSearchCost(ctx, search.AddM, bytes)
+			}
 			nVertices := len(coordinates)
 			if len(coordinates) < 3 {
 				return false
 			}
diff --git a/search/searcher/search_geoshape.go b/search/searcher/search_geoshape.go
index d2c6b1c55..1107c9438 100644
--- a/search/searcher/search_geoshape.go
+++ b/search/searcher/search_geoshape.go
@@ -54,7 +54,7 @@ func NewGeoShapeSearcher(ctx context.Context, indexReader index.IndexReader, sha
 	}
 
 	return NewFilteringSearcher(ctx, mSearcher,
-		buildRelationFilterOnShapes(dvReader, field, relation, shape)), nil
+		buildRelationFilterOnShapes(ctx, dvReader, field, relation, shape)), nil
 }
 
@@ -63,7 +63,7 @@ func NewGeoShapeSearcher(ctx context.Context, indexReader index.IndexReader, sha
 // implementation of doc values.
 var termSeparatorSplitSlice = []byte{0xff}
 
-func buildRelationFilterOnShapes(dvReader index.DocValueReader, field string,
+func buildRelationFilterOnShapes(ctx context.Context, dvReader index.DocValueReader, field string,
 	relation string, shape index.GeoJSON) FilterFunc {
 	// this is for accumulating the shape's actual complete value
 	// spread across multiple docvalue visitor callbacks.
@@ -116,6 +116,11 @@ func buildRelationFilterOnShapes(dvReader index.DocValueReader, field string,
 	})
 	if err == nil && found {
+		bytes := dvReader.BytesRead()
+		if bytes > 0 {
+			reportIOStats(ctx, bytes)
+			search.RecordSearchCost(ctx, search.AddM, bytes)
+		}
 		return found
 	}
diff --git a/search/searcher/search_numeric_range.go b/search/searcher/search_numeric_range.go
index 68728c94c..f086051c1 100644
--- a/search/searcher/search_numeric_range.go
+++ b/search/searcher/search_numeric_range.go
@@ -88,7 +88,8 @@ func NewNumericRangeSearcher(ctx context.Context, indexReader index.IndexReader,
 	// reporting back the IO stats with respect to the dictionary
 	// loaded, using the context
 	if ctx != nil {
-		reportIOStats(dictBytesRead, ctx)
+		reportIOStats(ctx, dictBytesRead)
+		search.RecordSearchCost(ctx, search.AddM, dictBytesRead)
 	}
 
 	// cannot return MatchNoneSearcher because of interaction with
@@ -110,7 +111,8 @@ func NewNumericRangeSearcher(ctx context.Context, indexReader index.IndexReader,
 	}
 
 	if ctx != nil {
-		reportIOStats(dictBytesRead, ctx)
+		reportIOStats(ctx, dictBytesRead)
+		search.RecordSearchCost(ctx, search.AddM, dictBytesRead)
 	}
 
 	return NewMultiTermSearcherBytes(ctx, indexReader, terms, field,
diff --git a/search/searcher/search_regexp.go b/search/searcher/search_regexp.go
index b419d5470..b88133e31 100644
--- a/search/searcher/search_regexp.go
+++ b/search/searcher/search_regexp.go
@@ -102,7 +102,8 @@ func NewRegexpSearcher(ctx context.Context, indexReader index.IndexReader, patte
 	}
 
 	if ctx != nil {
-		reportIOStats(dictBytesRead, ctx)
+		reportIOStats(ctx, dictBytesRead)
+		search.RecordSearchCost(ctx, search.AddM, dictBytesRead)
 	}
 
 	return NewMultiTermSearcher(ctx, indexReader, candidateTerms, field, boost,
diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go
index db18e5376..cd794ea32 100644
--- a/search/searcher/search_term.go
+++ b/search/searcher/search_term.go
@@ -39,6 +39,9 @@ type TermSearcher struct {
 }
 
 func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, term string, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) {
+	if isTermQuery(ctx) {
+		ctx = context.WithValue(ctx, search.QueryTypeKey, search.Term)
+	}
 	return NewTermSearcherBytes(ctx, indexReader, []byte(term), field, boost, options)
 }
 
@@ -140,3 +143,14 @@ func (s *TermSearcher) Optimize(kind string, octx index.OptimizableContext) (
 
 	return nil, nil
 }
+
+func isTermQuery(ctx context.Context) bool {
+	if ctx != nil {
+		// if the ctx already has a value set for query type,
+		// it would've been set at a non-term searcher level.
+		_, ok := ctx.Value(search.QueryTypeKey).(search.SearchQueryType)
+		return !ok
+	}
+	// if the context is nil, then don't set the query type
+	return false
+}
diff --git a/search/searcher/search_term_prefix.go b/search/searcher/search_term_prefix.go
index 89f836a50..dc16e4864 100644
--- a/search/searcher/search_term_prefix.go
+++ b/search/searcher/search_term_prefix.go
@@ -49,7 +49,8 @@ func NewTermPrefixSearcher(ctx context.Context, indexReader index.IndexReader, p
 	}
 
 	if ctx != nil {
-		reportIOStats(fieldDict.BytesRead(), ctx)
+		reportIOStats(ctx, fieldDict.BytesRead())
+		search.RecordSearchCost(ctx, search.AddM, fieldDict.BytesRead())
 	}
 
 	return NewMultiTermSearcher(ctx, indexReader, terms, field, boost, options, true)
diff --git a/search/searcher/search_term_range.go b/search/searcher/search_term_range.go
index a2fb4e993..990c7386b 100644
--- a/search/searcher/search_term_range.go
+++ b/search/searcher/search_term_range.go
@@ -84,7 +84,8 @@ func NewTermRangeSearcher(ctx context.Context, indexReader index.IndexReader,
 	}
 
 	if ctx != nil {
-		reportIOStats(fieldDict.BytesRead(), ctx)
+		reportIOStats(ctx, fieldDict.BytesRead())
+		search.RecordSearchCost(ctx, search.AddM, fieldDict.BytesRead())
 	}
 
 	return NewMultiTermSearcher(ctx, indexReader, terms, field, boost, options, true)
diff --git a/search/util.go b/search/util.go
index 19dd5d68b..7a946868e 100644
--- a/search/util.go
+++ b/search/util.go
@@ -14,6 +14,8 @@
 
 package search
 
+import "context"
+
 func MergeLocations(locations []FieldTermLocationMap) FieldTermLocationMap {
 	rv := locations[0]
 
@@ -67,3 +69,52 @@ func MergeFieldTermLocations(dest []FieldTermLocation, matches []*DocumentMatch)
 
 	return dest
 }
+
+const SearchIOStatsCallbackKey = "_search_io_stats_callback_key"
+
+type SearchIOStatsCallbackFunc func(uint64)
+
+// Implementations of SearchIncrementalCostCallbackFn should handle
+// the following messages:
+// - add: increment the cost of a search operation
+//   (which can be specific to a query type as well)
+// - abort: the query was aborted, e.g. because the search's context
+//   was cancelled, which can be handled differently as well
+// - done: indicates that the search is complete and the tracked cost
+//   can be safely consumed by the implementation.
+type SearchIncrementalCostCallbackFn func(SearchIncrementalCostCallbackMsg,
+	SearchQueryType, uint64)
+type SearchIncrementalCostCallbackMsg uint
+type SearchQueryType uint
+
+const (
+	Term = SearchQueryType(1 << iota)
+	Geo
+	Numeric
+	GenericCost
+)
+
+const (
+	AddM = SearchIncrementalCostCallbackMsg(1 << iota)
+	AbortM
+	DoneM
+)
+
+const SearchIncrementalCostKey = "_search_incremental_cost_key"
+const QueryTypeKey = "_query_type_key"
+
+func RecordSearchCost(ctx context.Context,
+	msg SearchIncrementalCostCallbackMsg, bytes uint64) {
+	if ctx != nil {
+		queryType, ok := ctx.Value(QueryTypeKey).(SearchQueryType)
+		if !ok {
+			// attribute the cost to factors that are not specific to a
+			// query type, such as the doc values and stored fields sections.
+			queryType = GenericCost
+		}
+
+		aggCallbackFn := ctx.Value(SearchIncrementalCostKey)
+		if aggCallbackFn != nil {
+			aggCallbackFn.(SearchIncrementalCostCallbackFn)(msg, queryType, bytes)
+		}
+	}
+}
diff --git a/test/versus_test.go b/test/versus_test.go
index 795e37855..248a27f00 100644
--- a/test/versus_test.go
+++ b/test/versus_test.go
@@ -359,8 +359,8 @@ func testVersusSearches(vt *VersusTest, searchTemplates []string, idxA, idxB ble
 		resA.Hits = nil
 		resB.Hits = nil
 
-		resA.BytesRead = 0
-		resB.BytesRead = 0
+		resA.Cost = 0
+		resB.Cost = 0
 
 		if !reflect.DeepEqual(resA, resB) {
 			resAj, _ := json.Marshal(resA)
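
A minimal usage sketch (not part of the patch) of how an application embedding bleve might consume the incremental cost callback added in search/util.go above. The index path "example.bleve", the query string, and the per-type tally are assumptions for illustration; the context keys, message constants, callback types, and the SearchResult.Cost field are the ones introduced by this change.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/blevesearch/bleve/v2"
	"github.com/blevesearch/bleve/v2/search"
)

func main() {
	// hypothetical pre-existing index; any bleve v2 index works here
	idx, err := bleve.Open("example.bleve")
	if err != nil {
		log.Fatal(err)
	}
	defer idx.Close()

	// Tally the incrementally reported cost per query type.
	costByType := map[search.SearchQueryType]uint64{}
	cb := search.SearchIncrementalCostCallbackFn(func(
		msg search.SearchIncrementalCostCallbackMsg,
		qType search.SearchQueryType, bytes uint64) {
		switch msg {
		case search.AddM:
			costByType[qType] += bytes
		case search.AbortM:
			// search's context was cancelled; discard the partial tally
			costByType = map[search.SearchQueryType]uint64{}
		case search.DoneM:
			// search is complete; the tally is now safe to consume
		}
	})

	// Register the callback on the context passed to SearchInContext.
	ctx := context.WithValue(context.Background(),
		search.SearchIncrementalCostKey, cb)

	req := bleve.NewSearchRequest(bleve.NewQueryStringQuery("united"))
	res, err := idx.SearchInContext(ctx, req)
	if err != nil {
		log.Fatal(err)
	}

	// res.Cost carries the aggregate bytes read for the whole request.
	fmt.Printf("total cost: %d bytes, per query type: %v\n",
		res.Cost, costByType)
}

Note that the func literal is wrapped in the named SearchIncrementalCostCallbackFn type: RecordSearchCost retrieves the context value with a type assertion on exactly that type, so storing a bare func(...) would silently be ignored.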