/
options.go
238 lines (203 loc) · 6.85 KB
/
options.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
package grpcserver
import (
"context"
"crypto/tls"
"crypto/x509"
"time"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
grpc_validator "github.com/grpc-ecosystem/go-grpc-middleware/validator"
"github.com/sirupsen/logrus"
"go.opencensus.io/plugin/ocgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"github.com/heroku/x/go-kit/metrics"
"github.com/heroku/x/grpc/grpcmetrics"
"github.com/heroku/x/grpc/panichandler"
"github.com/heroku/x/tlsconfig"
)
// defaultReadHeaderTimeout is the readHeaderTimeout installed by
// defaultOptions; it can be overridden with WithReadHeaderTimeout.
const defaultReadHeaderTimeout = time.Minute
// defaultLogOpts are the options passed to both grpc_logrus request-logging
// interceptors; the status code recorded for a call is computed from the
// handler's error by ErrorToCode (defined elsewhere in this package).
var defaultLogOpts = []grpc_logrus.Option{grpc_logrus.WithCodes(ErrorToCode)}
// options holds the configuration assembled from ServerOption values and read
// by unaryInterceptors, streamInterceptors and serverOptions.
type options struct {
	// logEntry is handed to the panic handlers and the grpc_logrus
	// interceptors; when nil, each interceptor chain creates a fresh
	// logrus entry instead.
	logEntry *logrus.Entry
	// metricsProvider, when set, backs the grpcmetrics interceptors. For a
	// given chain it is ignored if the matching high-card interceptor is set.
	metricsProvider metrics.Provider
	// authUnaryInterceptor and authStreamInterceptor are written by both
	// AuthInterceptors and WithPeerValidator; they sit inside the logging
	// interceptors in each chain.
	authUnaryInterceptor  grpc.UnaryServerInterceptor
	authStreamInterceptor grpc.StreamServerInterceptor
	// highCardUnaryInterceptor and highCardStreamInterceptor, when non-nil,
	// take the place of the metricsProvider-based interceptors in their chain.
	highCardUnaryInterceptor  grpc.UnaryServerInterceptor
	highCardStreamInterceptor grpc.StreamServerInterceptor
	// readHeaderTimeout starts at defaultReadHeaderTimeout (defaultOptions)
	// and is set by WithReadHeaderTimeout. NOTE(review): it is not read in
	// this file — presumably applied to an HTTP server elsewhere; confirm.
	readHeaderTimeout time.Duration
	// useValidateInterceptor appends the grpc_validator interceptors as the
	// last element of each chain.
	useValidateInterceptor bool
	// grpcOptions are raw grpc.ServerOptions that serverOptions appends after
	// the interceptor options.
	grpcOptions []grpc.ServerOption
}
// defaultOptions returns the baseline configuration: readHeaderTimeout is set
// to defaultReadHeaderTimeout and every other field keeps its zero value.
func defaultOptions() options {
	var base options
	base.readHeaderTimeout = defaultReadHeaderTimeout
	return base
}
// ServerOption sets optional fields on the standard gRPC server.
//
// Each ServerOption is a function that mutates the package-internal options
// value; build them with the constructors in this package (for example
// LogEntry, MetricsProvider, AuthInterceptors or GRPCOption).
type ServerOption func(*options)
// GRPCOption adds a grpc ServerOption to the server.
//
// The option is appended to the list of raw gRPC options, so repeated calls
// accumulate rather than replace each other.
func GRPCOption(opt grpc.ServerOption) ServerOption {
	appendOpt := func(dst *options) {
		dst.grpcOptions = append(dst.grpcOptions, opt)
	}
	return appendOpt
}
// LogEntry provided will be added to the context.
//
// The entry is used by the panic handlers and the request-logging
// interceptors. Without this option a fresh logrus entry is created instead.
func LogEntry(entry *logrus.Entry) ServerOption {
	return ServerOption(func(dst *options) { dst.logEntry = entry })
}
// MetricsProvider will have metrics reported to it.
//
// The provider backs the grpcmetrics interceptors, which are installed in a
// chain only when no HighCardInterceptors interceptor is set for that chain.
func MetricsProvider(provider metrics.Provider) ServerOption {
	setProvider := func(dst *options) { dst.metricsProvider = provider }
	return setProvider
}
// AuthInterceptors sets interceptors that are intended for
// authentication/authorization in the correct locations in the chain.
//
// WithPeerValidator writes the same two slots, so whichever of the two
// options is applied last determines the auth interceptors.
func AuthInterceptors(unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) ServerOption {
	return func(dst *options) {
		dst.authUnaryInterceptor, dst.authStreamInterceptor = unary, stream
	}
}
// HighCardInterceptors sets interceptors that use
// Attributes/Labels on the instrumentation.
//
// A non-nil interceptor takes the place of the MetricsProvider-based
// grpcmetrics interceptor in its chain.
func HighCardInterceptors(unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) ServerOption {
	return func(dst *options) {
		dst.highCardUnaryInterceptor, dst.highCardStreamInterceptor = unary, stream
	}
}
// WithOCGRPCServerHandler sets the grpc server up with provided ServerHandler
// as its StatsHandler.
//
// It is a convenience wrapper that appends grpc.StatsHandler(h) to the raw
// gRPC options.
func WithOCGRPCServerHandler(h *ocgrpc.ServerHandler) ServerOption {
	return func(dst *options) {
		statsOpt := grpc.StatsHandler(h)
		dst.grpcOptions = append(dst.grpcOptions, statsOpt)
	}
}
// WithReadHeaderTimeout overrides the read-header timeout, which otherwise
// starts at defaultReadHeaderTimeout. The duration is stored as given and is
// not validated here; where it is enforced is not shown in this file.
func WithReadHeaderTimeout(d time.Duration) ServerOption {
	return ServerOption(func(dst *options) { dst.readHeaderTimeout = d })
}
// ValidateInterceptor sets interceptors that will validate every
// message that has a receiver of the form `Validate() error`
//
// The validator interceptors are appended last, i.e. innermost, in both the
// unary and stream chains.
//
// See github.com/mwitkow/go-proto-validators for details.
func ValidateInterceptor() ServerOption {
	enable := func(dst *options) { dst.useValidateInterceptor = true }
	return enable
}
// unaryInterceptors assembles the unary interceptor chain for the server.
//
// The slice is ordered outermost first (grpc_middleware.ChainUnaryServer runs
// element 0 as the outermost wrapper):
//
//  1. the logging panic handler, using o.logEntry or a fresh logrus entry;
//  2. ctxtags setup, then the payload, request-ID and peer-name taggers;
//  3. metrics: the high-card interceptor if set, otherwise a grpcmetrics
//     interceptor if a MetricsProvider was configured, otherwise nothing;
//  4. the error unwrapper, outside the logger so errors are logged before
//     they are unwrapped, while metrics (further out) see unwrapped errors;
//  5. grpc_logrus request logging;
//  6. the optional auth interceptor, then optional message validation.
func (o *options) unaryInterceptors() []grpc.UnaryServerInterceptor {
	logger := o.logEntry
	if logger == nil {
		logger = logrus.NewEntry(logrus.New())
	}

	chain := make([]grpc.UnaryServerInterceptor, 0, 10)
	chain = append(chain,
		panichandler.LoggingUnaryPanicHandler(logger),
		grpc_ctxtags.UnaryServerInterceptor(),
		UnaryPayloadLoggingTagger,
		unaryRequestIDTagger,
		unaryPeerNameTagger,
	)

	switch {
	case o.highCardUnaryInterceptor != nil:
		chain = append(chain, o.highCardUnaryInterceptor)
	case o.metricsProvider != nil:
		// Placed outside the unwrapper so it reports on unwrapped errors.
		chain = append(chain, grpcmetrics.NewUnaryServerInterceptor(o.metricsProvider))
	}

	chain = append(chain,
		unaryServerErrorUnwrapper, // unwraps only after the logger below has seen the error
		grpc_logrus.UnaryServerInterceptor(logger, defaultLogOpts...),
	)

	if o.authUnaryInterceptor != nil {
		chain = append(chain, o.authUnaryInterceptor)
	}
	if o.useValidateInterceptor {
		chain = append(chain, grpc_validator.UnaryServerInterceptor())
	}
	return chain
}
// streamInterceptors assembles the stream interceptor chain for the server.
//
// The slice is ordered outermost first (grpc_middleware.ChainStreamServer runs
// element 0 as the outermost wrapper):
//
//  1. the logging panic handler, using o.logEntry or a fresh logrus entry;
//  2. ctxtags setup, then the request-ID and peer-name taggers (there is no
//     payload-logging tagger for streams);
//  3. metrics: the high-card interceptor if set, otherwise a grpcmetrics
//     interceptor if a MetricsProvider was configured, otherwise nothing;
//  4. the error unwrapper, outside the logger so errors are logged before
//     they are unwrapped, while metrics (further out) see unwrapped errors;
//  5. grpc_logrus request logging;
//  6. the optional auth interceptor, then optional message validation.
func (o *options) streamInterceptors() []grpc.StreamServerInterceptor {
	logger := o.logEntry
	if logger == nil {
		logger = logrus.NewEntry(logrus.New())
	}

	chain := make([]grpc.StreamServerInterceptor, 0, 9)
	chain = append(chain,
		panichandler.LoggingStreamPanicHandler(logger),
		grpc_ctxtags.StreamServerInterceptor(),
		streamRequestIDTagger,
		streamPeerNameTagger,
	)

	switch {
	case o.highCardStreamInterceptor != nil:
		chain = append(chain, o.highCardStreamInterceptor)
	case o.metricsProvider != nil:
		// Placed outside the unwrapper so it reports on unwrapped errors.
		chain = append(chain, grpcmetrics.NewStreamServerInterceptor(o.metricsProvider))
	}

	chain = append(chain,
		streamServerErrorUnwrapper, // unwraps only after the logger below has seen the error
		grpc_logrus.StreamServerInterceptor(logger, defaultLogOpts...),
	)

	if o.authStreamInterceptor != nil {
		chain = append(chain, o.authStreamInterceptor)
	}
	if o.useValidateInterceptor {
		chain = append(chain, grpc_validator.StreamServerInterceptor())
	}
	return chain
}
// serverOptions returns the grpc.ServerOptions used to build the server.
//
// The first two entries install the chained unary and stream interceptors.
// They are followed by every raw option collected via GRPCOption (and the
// helpers built on it, such as TLS and WithOCGRPCServerHandler), in the order
// those options were added.
func (o *options) serverOptions() []grpc.ServerOption {
	unaryChain := grpc_middleware.ChainUnaryServer(o.unaryInterceptors()...)
	streamChain := grpc_middleware.ChainStreamServer(o.streamInterceptors()...)

	result := make([]grpc.ServerOption, 0, 2+len(o.grpcOptions))
	result = append(result, grpc.UnaryInterceptor(unaryChain), grpc.StreamInterceptor(streamChain))
	return append(result, o.grpcOptions...)
}
// TLS returns a ServerOption which adds mutual-TLS to the gRPC server.
//
// The TLS configuration is built by tlsconfig.NewMutualTLS from caCerts and
// serverCert. Any error from it is returned unchanged, together with a nil
// option.
func TLS(caCerts [][]byte, serverCert tls.Certificate) (ServerOption, error) {
	cfg, err := tlsconfig.NewMutualTLS(caCerts, serverCert)
	if err != nil {
		return nil, err
	}
	transportCreds := credentials.NewTLS(cfg)
	return GRPCOption(grpc.Creds(transportCreds)), nil
}
// WithPeerValidator configures the gRPC server to reject calls from peers
// which do not provide a certificate or for which the provided function
// returns false.
//
// Rejections use the errors produced by validatePeer: Unauthenticated when no
// peer certificate is present, PermissionDenied when f returns false.
//
// This option writes the same auth interceptor slots as AuthInterceptors, so
// whichever of the two is applied last wins. They do not compose.
func WithPeerValidator(f func(*x509.Certificate) bool) ServerOption {
	return func(o *options) {
		// The first argument of a stream interceptor is the service
		// implementation, not a request message, hence srv.
		o.authStreamInterceptor = func(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
			if err := validatePeer(ss.Context(), f); err != nil {
				return err
			}
			return handler(srv, ss)
		}
		// Unnamed results avoid the named `err` being shadowed by the
		// if-scoped err below.
		o.authUnaryInterceptor = func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
			if err := validatePeer(ctx, f); err != nil {
				return nil, err
			}
			return handler(ctx, req)
		}
	}
}
// validatePeer checks the peer certificate carried by ctx against f.
//
// It returns:
//   - a codes.Unauthenticated error when getPeerCertFromContext finds no
//     certificate (f is not called in that case);
//   - a codes.PermissionDenied error when f rejects the certificate;
//   - nil otherwise.
func validatePeer(ctx context.Context, f func(*x509.Certificate) bool) error {
	cert, found := getPeerCertFromContext(ctx)
	switch {
	case !found:
		// TODO: SA1019: grpc.Errorf is deprecated: use status.Errorf instead. (staticcheck)
		return grpc.Errorf(codes.Unauthenticated, "unauthenticated") //nolint:staticcheck
	case !f(cert):
		// TODO: SA1019: grpc.Errorf is deprecated: use status.Errorf instead. (staticcheck)
		return grpc.Errorf(codes.PermissionDenied, "forbidden") //nolint:staticcheck
	default:
		return nil
	}
}