Skip to content

Commit

Permalink
accounting cost of other geo types and code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Thejas-bhat committed Jun 2, 2023
1 parent d062372 commit c133e60
Show file tree
Hide file tree
Showing 11 changed files with 28 additions and 23 deletions.
2 changes: 1 addition & 1 deletion search/searcher/search_fuzzy.go
Expand Up @@ -60,7 +60,7 @@ func NewFuzzySearcher(ctx context.Context, indexReader index.IndexReader, term s

if ctx != nil {
reportIOStats(dictBytesRead, ctx)
aggregateBytesRead(ctx, dictBytesRead)
search.RecordSearchCost(ctx, "add", dictBytesRead)
}

return NewMultiTermSearcher(ctx, indexReader, candidates, field,
Expand Down
4 changes: 4 additions & 0 deletions search/searcher/search_geoboundingbox.go
Expand Up @@ -222,6 +222,10 @@ func buildRectFilter(ctx context.Context, dvReader index.DocValueReader, field s
}
})
if err == nil && found {
bytes := dvReader.BytesRead()
if bytes > 0 {
search.RecordSearchCost(ctx, "add", bytes)
}
for i := range lons {
if geo.BoundingBoxContains(lons[i], lats[i],
minLon, minLat, maxLon, maxLat) {
Expand Down
4 changes: 4 additions & 0 deletions search/searcher/search_geopointdistance.go
Expand Up @@ -134,6 +134,10 @@ func buildDistFilter(ctx context.Context, dvReader index.DocValueReader, field s
}
})
if err == nil && found {
bytes := dvReader.BytesRead()
if bytes > 0 {
search.RecordSearchCost(ctx, "add", bytes)
}
for i := range lons {
dist := geo.Haversin(lons[i], lats[i], centerLon, centerLat)
if dist <= maxDist/1000 {
Expand Down
4 changes: 4 additions & 0 deletions search/searcher/search_geopolygon.go
Expand Up @@ -107,6 +107,10 @@ func buildPolygonFilter(ctx context.Context, dvReader index.DocValueReader, fiel
// Note: this approach works for points which are strictly inside
// the polygon. ie it might fail for certain points on the polygon boundaries.
if err == nil && found {
bytes := dvReader.BytesRead()
if bytes > 0 {
search.RecordSearchCost(ctx, "add", bytes)
}
nVertices := len(coordinates)
if len(coordinates) < 3 {
return false
Expand Down
2 changes: 1 addition & 1 deletion search/searcher/search_geoshape.go
Expand Up @@ -119,7 +119,7 @@ func buildRelationFilterOnShapes(ctx context.Context, dvReader index.DocValueRea
if err == nil && found {
bytes := dvReader.BytesRead()
if bytes > 0 {
aggregateBytesRead(ctx, bytes)
search.RecordSearchCost(ctx, "add", bytes)
}
return found
}
Expand Down
4 changes: 2 additions & 2 deletions search/searcher/search_numeric_range.go
Expand Up @@ -89,7 +89,7 @@ func NewNumericRangeSearcher(ctx context.Context, indexReader index.IndexReader,
// loaded, using the context
if ctx != nil {
reportIOStats(dictBytesRead, ctx)
aggregateBytesRead(ctx, dictBytesRead)
search.RecordSearchCost(ctx, "add", dictBytesRead)
}

// cannot return MatchNoneSearcher because of interaction with
Expand All @@ -112,7 +112,7 @@ func NewNumericRangeSearcher(ctx context.Context, indexReader index.IndexReader,

if ctx != nil {
reportIOStats(dictBytesRead, ctx)
aggregateBytesRead(ctx, dictBytesRead)
search.RecordSearchCost(ctx, "add", dictBytesRead)
}

return NewMultiTermSearcherBytes(ctx, indexReader, terms, field,
Expand Down
2 changes: 1 addition & 1 deletion search/searcher/search_regexp.go
Expand Up @@ -103,7 +103,7 @@ func NewRegexpSearcher(ctx context.Context, indexReader index.IndexReader, patte

if ctx != nil {
reportIOStats(dictBytesRead, ctx)
aggregateBytesRead(ctx, dictBytesRead)
search.RecordSearchCost(ctx, "add", dictBytesRead)
}

return NewMultiTermSearcher(ctx, indexReader, candidateTerms, field, boost,
Expand Down
12 changes: 0 additions & 12 deletions search/searcher/search_term.go
Expand Up @@ -151,15 +151,3 @@ func getQueryType(ctx context.Context) bool {
}
return true
}

func aggregateBytesRead(ctx context.Context, bytes uint64) {
if ctx != nil {
queryType, ok := ctx.Value(search.QueryTypeKey).(string)
if ok {
aggCallbackFn := ctx.Value(search.SearchCostAggregatorKey)
if aggCallbackFn != nil {
aggCallbackFn.(search.SearchCostAggregatorCallbackFn)("add", queryType, bytes)
}
}
}
}
2 changes: 1 addition & 1 deletion search/searcher/search_term_prefix.go
Expand Up @@ -50,7 +50,7 @@ func NewTermPrefixSearcher(ctx context.Context, indexReader index.IndexReader, p

if ctx != nil {
reportIOStats(fieldDict.BytesRead(), ctx)
aggregateBytesRead(ctx, fieldDict.BytesRead())
search.RecordSearchCost(ctx, "add", fieldDict.BytesRead())
}

return NewMultiTermSearcher(ctx, indexReader, terms, field, boost, options, true)
Expand Down
2 changes: 1 addition & 1 deletion search/searcher/search_term_range.go
Expand Up @@ -85,7 +85,7 @@ func NewTermRangeSearcher(ctx context.Context, indexReader index.IndexReader,

if ctx != nil {
reportIOStats(fieldDict.BytesRead(), ctx)
aggregateBytesRead(ctx, fieldDict.BytesRead())
search.RecordSearchCost(ctx, "add", fieldDict.BytesRead())
}

return NewMultiTermSearcher(ctx, indexReader, terms, field, boost, options, true)
Expand Down
13 changes: 9 additions & 4 deletions search/util.go
Expand Up @@ -73,9 +73,14 @@ func MergeFieldTermLocations(dest []FieldTermLocation, matches []*DocumentMatch)
const SearchIOStatsCallbackKey = "_search_io_stats_callback_key"

type SearchIOStatsCallbackFunc func(uint64)
type SearchCostAggregatorCallbackFn func(string, string, uint64)

const SearchCostAggregatorKey = "_search_cost_aggregator_key"
// The callback signature is (message, queryType, cost) which allows
// the caller to act on a particular query type and what its the associated
// cost of an operation. "add" indicates to increment the cost for the query
// "done" indicates a finish of the accounting of the costs.
type SearchIncrementalCostCallbackFn func(string, string, uint64)

const SearchIncrementalCostKey = "_search_incremental_cost_key"
const QueryTypeKey = "_query_type_key"

func RecordSearchCost(ctx context.Context, msg string, bytes uint64) {
Expand All @@ -86,8 +91,8 @@ func RecordSearchCost(ctx context.Context, msg string, bytes uint64) {
queryType = ""
}

aggCallbackFn := ctx.Value(SearchCostAggregatorKey)
aggCallbackFn := ctx.Value(SearchIncrementalCostKey)
if aggCallbackFn != nil {
aggCallbackFn.(SearchCostAggregatorCallbackFn)(msg, queryType, bytes)
aggCallbackFn.(SearchIncrementalCostCallbackFn)(msg, queryType, bytes)
}
}

0 comments on commit c133e60

Please sign in to comment.