/
shard_by.go
104 lines (85 loc) · 2.74 KB
/
shard_by.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package tripperware
import (
"context"
"net/http"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/thanos-io/thanos/pkg/querysharding"
"github.com/thanos-io/thanos/pkg/store/storepb"
"github.com/weaveworks/common/httpgrpc"
querier_stats "github.com/cortexproject/cortex/pkg/querier/stats"
cquerysharding "github.com/cortexproject/cortex/pkg/querysharding"
"github.com/cortexproject/cortex/pkg/tenant"
util_log "github.com/cortexproject/cortex/pkg/util/log"
"github.com/cortexproject/cortex/pkg/util/validation"
)
// ShardByMiddleware returns a Middleware that wraps the next Handler in a
// shardBy handler. The handler may split a query into several sharded
// sub-queries, send them to the wrapped handler, and combine the responses
// with merger. queryAnalyzer decides whether, and by which labels, a query
// can be sharded; limits supplies the per-tenant shard count.
func ShardByMiddleware(logger log.Logger, limits Limits, merger Merger, queryAnalyzer querysharding.Analyzer) Middleware {
	wrap := func(next Handler) Handler {
		h := shardBy{
			next:     next,
			limits:   limits,
			logger:   logger,
			merger:   merger,
			analyzer: queryAnalyzer,
		}
		return h
	}
	return MiddlewareFunc(wrap)
}
// shardBy is the Handler built by ShardByMiddleware. It may split an incoming
// request into sharded sub-requests (see its Do method).
type shardBy struct {
next Handler // receives either the original request or each sharded sub-request
limits Limits // per-tenant limits; Do reads QueryVerticalShardSize and passes it to DoRequests
logger log.Logger // base logger; Do enriches it with request context via util_log.WithContext
merger Merger // combines the sub-request responses via MergeResponse
analyzer querysharding.Analyzer // analyzes the query to decide shardability and sharding labels
}
// Do implements Handler for shardBy.
//
// It resolves the tenant IDs from ctx and computes the shard count from their
// QueryVerticalShardSize limits via validation.SmallestPositiveIntPerTenant.
// The request is forwarded unchanged to the next handler in two cases: the
// shard count is <= 1, or the analyzer reports the query as not shardable.
// Otherwise the request is split into numShards sub-requests. Those are run
// with DoRequests, and the responses are combined with s.merger.
//
// Errors are handled as follows:
//   - Failures to resolve tenant IDs are returned as HTTP 400 errors.
//   - Failures to analyze the query are also returned as HTTP 400 errors.
//   - Errors from DoRequests are returned unchanged.
func (s shardBy) Do(ctx context.Context, r Request) (Response, error) {
	tenantIDs, err := tenant.TenantIDs(ctx)
	if err != nil {
		// Pass the message as an argument rather than as the format string:
		// a '%' inside err.Error() would otherwise be parsed as a verb.
		return nil, httpgrpc.Errorf(http.StatusBadRequest, "%s", err.Error())
	}

	numShards := validation.SmallestPositiveIntPerTenant(tenantIDs, s.limits.QueryVerticalShardSize)
	if numShards <= 1 {
		return s.next.Do(ctx, r)
	}

	logger := util_log.WithContext(ctx, s.logger)
	analysis, err := s.analyzer.Analyze(r.GetQuery())
	if err != nil {
		level.Warn(logger).Log("msg", "error analyzing query", "q", r.GetQuery(), "err", err)
		return nil, httpgrpc.Errorf(http.StatusBadRequest, "%s", err.Error())
	}

	// NOTE(review): FromContext may return nil when no stats are attached to
	// ctx. This relies on AddExtraFields tolerating a nil receiver; confirm.
	stats := querier_stats.FromContext(ctx)
	stats.AddExtraFields(
		"shard_by.is_shardable", analysis.IsShardable(),
		"shard_by.num_shards", numShards,
		"shard_by.sharding_labels", analysis.ShardingLabels(),
	)
	if !analysis.IsShardable() {
		return s.next.Do(ctx, r)
	}

	reqs := s.shardQuery(logger, numShards, r, analysis)
	reqResps, err := DoRequests(ctx, s.next, reqs, s.limits)
	if err != nil {
		return nil, err
	}

	resps := make([]Response, 0, len(reqResps))
	for _, reqResp := range reqResps {
		resps = append(resps, reqResp.Response)
	}
	return s.merger.MergeResponse(ctx, r, resps...)
}
// shardQuery builds one request per shard. Each request is r with sharding
// info for that shard index injected into its query string. The info covers:
//   - the total shard count;
//   - the shard index;
//   - the analysis' ShardBy flag;
//   - the analysis' sharding labels.
//
// If injection fails for any shard, the error is logged on l and a slice
// holding only the original, unsharded r is returned, so the query still
// runs unsharded.
//
// NOTE(review): assumes numShards >= 1 (Do only calls this with numShards > 1);
// a negative value would panic in make.
func (s shardBy) shardQuery(l log.Logger, numShards int, r Request, analysis querysharding.QueryAnalysis) []Request {
	// These do not vary per shard, so compute them once.
	query := r.GetQuery()
	by := analysis.ShardBy()
	labels := analysis.ShardingLabels()

	reqs := make([]Request, 0, numShards)
	for i := 0; i < numShards; i++ {
		q, err := cquerysharding.InjectShardingInfo(query, &storepb.ShardInfo{
			TotalShards: int64(numShards),
			ShardIndex:  int64(i),
			By:          by,
			Labels:      labels,
		})
		// Check the error before using q: on failure q is not a valid
		// sharded query and must not be attached to a request.
		if err != nil {
			level.Warn(l).Log("msg", "error sharding query", "q", query, "err", err)
			return []Request{r}
		}
		reqs = append(reqs, r.WithQuery(q))
	}
	return reqs
}