/
server.go
191 lines (156 loc) · 6 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
package grpcserver
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"net/http"
"strconv"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/lstoll/grpce/h2c"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
healthgrpc "google.golang.org/grpc/health"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/peer"
"github.com/heroku/x/cmdutil"
"github.com/heroku/x/grpc/requestid"
)
// New builds a *grpc.Server from the given ServerOptions and registers the
// standard gRPC health service on it.
//
// Each option is applied, in order, to a zero-valued options struct; the
// resulting serverOptions() are passed to grpc.NewServer. Callers register
// their own services on the returned server.
func New(opts ...ServerOption) *grpc.Server {
	var cfg options
	for i := range opts {
		opts[i](&cfg)
	}

	server := grpc.NewServer(cfg.serverOptions()...)
	healthpb.RegisterHealthServer(server, healthgrpc.NewServer())

	return server
}
// Starter is implemented by services that attach themselves to a
// *grpc.Server.
//
// Implementations are expected to call the relevant generated
// RegisterXXXServer function with the server they are handed.
type Starter interface {
	// Start registers the implementation on s, returning a non-nil error
	// if it could not do so.
	Start(s *grpc.Server) error
}
// RunStandardServer runs a gRPC server with the standard setup: metrics (if a
// provider is passed), panic handling, a health check service, TLS
// termination with client authentication, and proxy-protocol wrapping.
//
// serverCert and serverKey are parsed with tls.X509KeyPair (PEM-encoded); a
// parse failure is returned wrapped with context. Otherwise the server built
// by NewStandardServer is run and its Run error is returned.
//
// Deprecated: Use NewStandardServer instead.
func RunStandardServer(logger log.FieldLogger, port int, serverCACerts [][]byte, serverCert, serverKey []byte, server Starter, opts ...ServerOption) error {
	keyPair, err := tls.X509KeyPair(serverCert, serverKey)
	if err != nil {
		return errors.Wrap(err, "creating X509 key pair")
	}

	srv := NewStandardServer(logger, port, serverCACerts, keyPair, server, opts...)
	return srv.Run()
}
// NewStandardServer configures a gRPC server with a standard setup including
// metrics (if a provider is passed), panic handling, a health check service,
// TLS termination with client authentication, and proxy-protocol wrapping.
//
// The TLS option (built from serverCACerts and serverCert) and a logging
// option tagged component=grpc are appended after the caller's opts. The
// caller's opts slice is never modified: the combined options are built in a
// private slice.
//
// Errors building the TLS option or from server.Start are not returned; they
// are reported via logger.Fatal.
//
// The returned cmdutil.Server is the one produced by TCP for the address
// ":<port>" (empty host).
func NewStandardServer(logger log.FieldLogger, port int, serverCACerts [][]byte, serverCert tls.Certificate, server Starter, opts ...ServerOption) cmdutil.Server {
	// Named tlsOpt rather than tls so the crypto/tls package is not shadowed.
	tlsOpt, err := TLS(serverCACerts, serverCert)
	if err != nil {
		logger.Fatal(err)
	}

	// Build a fresh slice: appending directly to opts could write into spare
	// capacity of a slice the caller passed as opts... and still owns.
	all := make([]ServerOption, 0, len(opts)+2)
	all = append(all, opts...)
	all = append(all, tlsOpt, LogEntry(logger.WithField("component", "grpc")))

	grpcsrv := New(all...)
	if err := server.Start(grpcsrv); err != nil {
		logger.Fatal(err)
	}

	return TCP(logger, grpcsrv, net.JoinHostPort("", strconv.Itoa(port)))
}
// NewStandardH2C creates the pair of servers needed to serve gRPC services
// over H2C (HTTP/2 cleartext, including client upgrades). This is suitable
// for serving gRPC services via both hermes and dogwood-router.
//
// It returns:
//   - a *grpc.Server built from defaultOptions() plus opts, with the standard
//     health service already registered; callers register their services on it.
//   - an *http.Server whose handler is an h2c.Server that passes HTTP/2
//     traffic to the gRPC server and HTTP/1.x (non-upgrade) traffic to
//     http11. This is the server that should be served on a listener.
//
// The *http.Server's ReadHeaderTimeout comes from the resolved options.
func NewStandardH2C(http11 http.Handler, opts ...ServerOption) (*grpc.Server, *http.Server) {
	cfg := defaultOptions()
	for _, apply := range opts {
		apply(&cfg)
	}

	grpcServer := grpc.NewServer(cfg.serverOptions()...)
	healthpb.RegisterHealthServer(grpcServer, healthgrpc.NewServer())

	httpServer := &http.Server{
		Handler: &h2c.Server{
			HTTP2Handler:      grpcServer,
			NonUpgradeHandler: http11,
		},
		ReadHeaderTimeout: cfg.readHeaderTimeout,
	}

	return grpcServer, httpServer
}
// unaryServerErrorUnwrapper is a unary interceptor that replaces the
// handler's error with errors.Cause of it. This strips errors.Wrap
// annotations so that the underlying gRPC status error is what interceptors
// later in the chain and clients observe. The handler's response is passed
// through untouched.
func unaryServerErrorUnwrapper(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
	resp, handlerErr := handler(ctx, req)
	if handlerErr == nil {
		return resp, nil
	}
	return resp, errors.Cause(handlerErr)
}
// streamServerErrorUnwrapper is the streaming counterpart of
// unaryServerErrorUnwrapper: it returns errors.Cause of the handler's error,
// stripping errors.Wrap annotations so that the gRPC status code survives
// for later interceptors and clients.
func streamServerErrorUnwrapper(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	if err := handler(srv, ss); err != nil {
		return errors.Cause(err)
	}
	return nil
}
// unaryRequestIDTagger is a unary interceptor that records the request ID
// (as reported by requestid.FromContext) under the "request_id" key of the
// grpc_ctxtags tags extracted from ctx, then invokes the handler.
//
// NOTE(review): presumably relies on the grpc_ctxtags interceptor running
// earlier in the chain so that Extract returns real tags; confirm ordering.
func unaryRequestIDTagger(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
	id, found := requestid.FromContext(ctx)
	if found {
		tags := grpc_ctxtags.Extract(ctx)
		tags.Set("request_id", id)
	}
	return handler(ctx, req)
}
// streamRequestIDTagger is the streaming counterpart of unaryRequestIDTagger:
// when the stream's context carries a request ID, it is stored under the
// "request_id" grpc_ctxtags tag before the handler runs.
func streamRequestIDTagger(req interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	streamCtx := ss.Context()
	if id, ok := requestid.FromContext(streamCtx); ok {
		grpc_ctxtags.Extract(streamCtx).Set("request_id", id)
	}
	return handler(req, ss)
}
// unaryPeerNameTagger is a unary interceptor that sets the "peer.name"
// grpc_ctxtags tag to the Common Name of the caller's TLS client certificate,
// when the caller presented one and its Common Name is non-empty.
func unaryPeerNameTagger(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
	if name := getPeerNameFromContext(ctx); name != "" {
		grpc_ctxtags.Extract(ctx).Set("peer.name", name)
	}
	return handler(ctx, req)
}
// streamPeerNameTagger is the streaming counterpart of unaryPeerNameTagger:
// it tags the stream's context with "peer.name" set to the Common Name of the
// caller's TLS client certificate, if one with a non-empty name is present.
func streamPeerNameTagger(req interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	streamCtx := ss.Context()
	if name := getPeerNameFromContext(streamCtx); name != "" {
		grpc_ctxtags.Extract(streamCtx).Set("peer.name", name)
	}
	return handler(req, ss)
}
// getPeerNameFromContext returns the Subject Common Name of the first peer
// certificate found by getPeerCertFromContext, or "" when there is none.
func getPeerNameFromContext(ctx context.Context) string {
	if cert, ok := getPeerCertFromContext(ctx); ok {
		return cert.Subject.CommonName
	}
	return ""
}
// getPeerCertFromContext returns the first certificate in the TLS state of
// the gRPC peer stored in ctx (the leaf, per crypto/tls ConnectionState).
// The boolean is false when ctx has no peer, the peer's AuthInfo is not
// credentials.TLSInfo, or the peer presented no certificates.
func getPeerCertFromContext(ctx context.Context) (*x509.Certificate, bool) {
	p, ok := peer.FromContext(ctx)
	if !ok {
		return nil, false
	}

	info, isTLS := p.AuthInfo.(credentials.TLSInfo)
	if !isTLS || len(info.State.PeerCertificates) == 0 {
		return nil, false
	}

	return info.State.PeerCertificates[0], true
}