-
Notifications
You must be signed in to change notification settings - Fork 1
/
server.go
486 lines (409 loc) · 13.2 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
package restserver
import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/effective-security/porto/restserver/authz"
	"github.com/effective-security/porto/restserver/ready"
	"github.com/effective-security/porto/restserver/telemetry"
	"github.com/effective-security/porto/xhttp/correlation"
	"github.com/effective-security/porto/xhttp/header"
	"github.com/effective-security/porto/xhttp/httperror"
	"github.com/effective-security/porto/xhttp/identity"
	"github.com/effective-security/porto/xhttp/marshal"
	"github.com/effective-security/x/netutil"
	"github.com/effective-security/xlog"
	"github.com/pkg/errors"
)
// logger is the package-level logger used by the REST server code in this file.
// NOTE(review): the two arguments are presumably the repository path and the
// package tag xlog uses to configure per-package log levels — confirm against
// the xlog documentation.
var logger = xlog.NewPackageLogger("github.com/effective-security/porto", "rest")
// MaxRequestSize is the maximum size, in bytes, of a regular HTTP POST
// request: 64 MiB (64 << 20 bytes).
const MaxRequestSize = 64 << 20
// EvtSourceStatus names the source used for service status events.
const EvtSourceStatus = "status"

// Service status event names.
const (
	// EvtServiceStarted names the event emitted when a service has started.
	EvtServiceStarted = "service started"
	// EvtServiceStopped names the event emitted when a service has stopped.
	EvtServiceStopped = "service stopped"
)
// ServerEvent identifies a server lifecycle notification delivered to the
// handlers registered with OnEvent.
type ServerEvent int

// Server lifecycle events. The numeric values are part of the public
// contract and are spelled out explicitly.
const (
	// ServerStartedEvent is fired when the server starts.
	ServerStartedEvent ServerEvent = 0
	// ServerStoppedEvent is fired after the server has stopped.
	ServerStoppedEvent ServerEvent = 1
	// ServerStoppingEvent is fired before the server stops.
	ServerStoppingEvent ServerEvent = 2
)

// ServerEventFunc is a callback that receives the ServerEvent it was
// registered for.
type ServerEventFunc func(evt ServerEvent)
// Server is an interface to provide server status
type Server interface {
http.Handler
Name() string
Version() string
HostName() string
LocalIP() string
Port() string
Protocol() string
PublicURL() string
StartedAt() time.Time
Service(name string) Service
Config() Config
TLSConfig() *tls.Config
// IsReady indicates that all subservices are ready to serve
IsReady() bool
AddService(s Service)
StartHTTP() error
StopHTTP()
OnEvent(evt ServerEvent, handler ServerEventFunc)
}
// MuxFactory creates http handlers.
//
// HTTPServer.StartHTTP calls NewMux once per start and serves the returned
// handler. New installs the server itself as the default factory;
// HTTPServer.WithMuxFactory replaces it.
type MuxFactory interface {
// NewMux returns the root handler to be served.
NewMux() http.Handler
}
// HTTPServer is responsible for exposing the collection of the services
// as a single HTTP server.
//
// Create it with New. The zero value is not usable: its maps are nil, so
// AddService and OnEvent would panic on the nil-map write.
type HTTPServer struct {
// Server is embedded so that *HTTPServer satisfies the Server interface.
// NOTE(review): New never assigns it. Any Server method that is not defined
// directly on *HTTPServer would be dispatched to this nil interface value
// and panic.
Server
// authz, when non-nil, wraps the handler chain built by NewMux.
authz authz.HTTPAuthz
// identityMapper attaches an identity to each request. NewMux falls back to
// identity.GuestIdentityMapper when it is nil.
identityMapper identity.ProviderFromRequest
// httpConfig is the configuration passed to New (bind address, server name,
// public URL).
httpConfig Config
// tlsConfig is nil for a plain-HTTP server; see Protocol.
tlsConfig *tls.Config
// httpServer is created by StartHTTP and is nil until then.
httpServer *http.Server
// cors, when non-nil, makes NewMux build the router with NewRouterWithCORS.
cors *CORSOptions
// muxFactory builds the root handler in StartHTTP. New sets it to the
// server itself.
muxFactory MuxFactory
// hostname and port are derived from the bind address in New.
hostname string
port string
// ipaddr is the address given to New, or the detected local IP.
ipaddr string
version string
// serving is set to true by StartHTTP's goroutine right before it serves,
// and back to false once serving returns.
serving bool
// startedAt is set (in UTC) by New, i.e. at construction rather than in
// StartHTTP.
startedAt time.Time
// clientAuth is the name of the TLS client-auth mode; NewMux logs it.
clientAuth string
// services holds the registered services, keyed by Service.Name().
services map[string]Service
// evtHandlers holds the callbacks registered via OnEvent, per event.
evtHandlers map[ServerEvent][]ServerEventFunc
// lock guards services and evtHandlers (see AddService, OnEvent, Service).
lock sync.RWMutex
// shutdownTimeout bounds StopHTTP's graceful shutdown; New defaults it to 5s.
shutdownTimeout time.Duration
}
// New creates an HTTPServer for the bind address in httpConfig.
//
// version is reported by Version. When ipaddr is empty, the local IP is
// detected with netutil.GetLocalIP; if detection fails, "127.0.0.1" is used
// and the failure is logged. tlsConfig may be nil for a plain-HTTP server;
// otherwise its ClientAuth mode is recorded by name. The server acts as its
// own MuxFactory until WithMuxFactory is called.
//
// The returned error is currently always nil.
func New(
	version string,
	ipaddr string,
	httpConfig Config,
	tlsConfig *tls.Config,
) (*HTTPServer, error) {
	// TODO: shall extract from bindAddr?
	if ipaddr == "" {
		detected, err := netutil.GetLocalIP()
		if err != nil {
			detected = "127.0.0.1"
			logger.KV(xlog.ERROR, "reason", "unable_determine_ipaddr", "use", detected, "err", err)
		}
		ipaddr = detected
	}

	clientAuth := tls.NoClientCert
	if tlsConfig != nil {
		clientAuth = tlsConfig.ClientAuth
	}

	s := &HTTPServer{
		services:        map[string]Service{},
		startedAt:       time.Now().UTC(),
		version:         version,
		ipaddr:          ipaddr,
		evtHandlers:     make(map[ServerEvent][]ServerEventFunc),
		clientAuth:      tlsClientAuthToStrMap[clientAuth],
		httpConfig:      httpConfig,
		hostname:        GetHostName(httpConfig.GetBindAddr()), // TODO: hostname shall be from os.Host
		port:            GetPort(httpConfig.GetBindAddr()),
		tlsConfig:       tlsConfig,
		shutdownTimeout: 5 * time.Second,
	}
	s.muxFactory = s
	return s, nil
}
// WithAuthz sets the authorization provider. When it is non-nil, NewMux
// wraps the handler chain with the handler the provider creates. It returns
// the server so calls can be chained.
func (server *HTTPServer) WithAuthz(provider authz.HTTPAuthz) *HTTPServer {
	server.authz = provider
	return server
}

// WithIdentityProvider sets the provider used to attach an identity to each
// request. When no provider is set (nil), NewMux uses
// identity.GuestIdentityMapper. It returns the server so calls can be chained.
func (server *HTTPServer) WithIdentityProvider(mapper identity.ProviderFromRequest) *HTTPServer {
	server.identityMapper = mapper
	return server
}

// WithCORS sets the CORS options. When they are non-nil, NewMux builds the
// router with NewRouterWithCORS. It returns the server so calls can be
// chained.
func (server *HTTPServer) WithCORS(opts *CORSOptions) *HTTPServer {
	server.cors = opts
	return server
}

// WithShutdownTimeout sets how long StopHTTP waits for in-flight requests to
// drain before giving up. New defaults it to 5 seconds. It returns the
// server so calls can be chained.
func (server *HTTPServer) WithShutdownTimeout(d time.Duration) *HTTPServer {
	server.shutdownTimeout = d
	return server
}
// tlsClientAuthToStrMap maps each crypto/tls client-certificate policy to the
// name of its Go constant. New uses it to record the policy as a string for
// logging. Entries are listed from strictest to most permissive.
var tlsClientAuthToStrMap = map[tls.ClientAuthType]string{
	tls.RequireAndVerifyClientCert: "RequireAndVerifyClientCert",
	tls.VerifyClientCertIfGiven:    "VerifyClientCertIfGiven",
	tls.RequireAnyClientCert:       "RequireAnyClientCert",
	tls.RequestClientCert:          "RequestClientCert",
	tls.NoClientCert:               "NoClientCert",
}
// AddService registers s under its Name(). Registering two services with the
// same name is a programming error and panics.
func (server *HTTPServer) AddService(s Service) {
	name := s.Name()

	server.lock.Lock()
	defer server.lock.Unlock()

	if existing := server.services[name]; existing != nil {
		logger.Panicf("service already registered: %s", name)
	}
	server.services[name] = s
}

// OnEvent registers handler to be called whenever evt is broadcast. Handlers
// for the same event are kept in registration order.
func (server *HTTPServer) OnEvent(evt ServerEvent, handler ServerEventFunc) {
	server.lock.Lock()
	registered := server.evtHandlers[evt]
	server.evtHandlers[evt] = append(registered, handler)
	server.lock.Unlock()
}
// Service returns the service registered under name via AddService, or nil
// when no such service exists.
func (server *HTTPServer) Service(name string) Service {
	// A lookup only reads the map, so a shared lock is sufficient and lets
	// concurrent lookups proceed in parallel.
	server.lock.RLock()
	defer server.lock.RUnlock()
	return server.services[name]
}
// HostName returns the host part of the configured bind address, as computed
// by GetHostName when the server was created.
func (server *HTTPServer) HostName() string {
	return server.hostname
}

// Port returns the port part of the configured bind address, as computed by
// GetPort when the server was created.
func (server *HTTPServer) Port() string {
	return server.port
}

// Protocol returns the URL scheme the server uses: "https" when a TLS config
// was supplied to New, "http" otherwise.
func (server *HTTPServer) Protocol() string {
	scheme := "http"
	if server.tlsConfig != nil {
		scheme = "https"
	}
	return scheme
}

// LocalIP returns the IP address given to New, or the detected local address
// when none was given.
func (server *HTTPServer) LocalIP() string {
	return server.ipaddr
}

// PublicURL returns the public URL from the server configuration.
func (server *HTTPServer) PublicURL() string {
	return server.httpConfig.GetPublicURL()
}

// StartedAt returns the UTC time at which the server was created by New.
func (server *HTTPServer) StartedAt() time.Time {
	return server.startedAt
}

// Uptime returns the time elapsed since StartedAt.
func (server *HTTPServer) Uptime() time.Duration {
	return time.Since(server.startedAt)
}

// Version returns the version string passed to New.
func (server *HTTPServer) Version() string {
	return server.version
}

// Name returns the server name from the server configuration.
func (server *HTTPServer) Name() string {
	return server.httpConfig.GetServerName()
}
// HTTPConfig returns the configuration the server was created with.
func (server *HTTPServer) HTTPConfig() Config {
	return server.httpConfig
}

// Config returns the configuration the server was created with; it is the
// Server interface accessor and is equivalent to HTTPConfig. Without this
// method, a Config() call would be promoted to the embedded Server field,
// which New never assigns, and panic with a nil dereference.
func (server *HTTPServer) Config() Config {
	return server.httpConfig
}

// TLSConfig returns the TLS configuration passed to New, or nil when the
// server serves plain HTTP.
func (server *HTTPServer) TLSConfig() *tls.Config {
	return server.tlsConfig
}
// IsReady reports whether the server is currently serving and every
// registered service reports itself ready.
func (server *HTTPServer) IsReady() bool {
	// Read shared state under the same lock AddService writes services
	// with, then release it before calling into the services: a service's
	// IsReady may itself call back into the server (e.g. Service), which
	// must not happen while the lock is held.
	server.lock.RLock()
	if !server.serving {
		server.lock.RUnlock()
		return false
	}
	services := make([]Service, 0, len(server.services))
	for _, ss := range server.services {
		services = append(services, ss)
	}
	server.lock.RUnlock()

	for _, ss := range services {
		if !ss.IsReady() {
			return false
		}
	}
	return true
}
// WithMuxFactory replaces the factory StartHTTP uses to build the root
// handler (by default, New installs the server itself). Call it before
// StartHTTP for the replacement to take effect.
func (server *HTTPServer) WithMuxFactory(factory MuxFactory) {
	server.muxFactory = factory
}
// broadcast invokes, in registration order, every handler registered for evt
// via OnEvent. Handlers run on the calling goroutine with the server lock
// released, so they may call OnEvent or other locking methods themselves.
func (server *HTTPServer) broadcast(evt ServerEvent) {
	// Read the slice header under the lock OnEvent writes with. OnEvent only
	// ever appends, so the first len(handlers) elements are never modified
	// afterwards and are safe to range over after unlocking.
	server.lock.RLock()
	handlers := server.evtHandlers[evt]
	server.lock.RUnlock()

	for _, handler := range handlers {
		handler(evt)
	}
}
// StartHTTP binds the configured address and serves requests on a background
// goroutine. It returns as soon as the listener is open.
//
// The handler chain comes from the configured MuxFactory (the server itself
// unless WithMuxFactory was called). If a TLS config was supplied to New, the
// listener terminates TLS with it; otherwise plain HTTP is served.
//
// Errors resolving or binding the address, such as the address already being
// in use, are returned to the caller. ServerStartedEvent handlers run on the
// serving goroutine just before it starts accepting connections. Once the
// server is serving, any error other than http.ErrServerClosed (the result of
// StopHTTP) panics, because there is no caller left to report it to.
func (server *HTTPServer) StartHTTP() error {
	bindAddr := server.httpConfig.GetBindAddr()

	if _, err := net.ResolveTCPAddr("tcp", bindAddr); err != nil {
		return errors.WithMessage(err, "unable to resolve address")
	}

	server.httpServer = &http.Server{
		IdleTimeout: time.Hour, // TODO: via config
		ErrorLog:    xlog.Stderr,
	}

	// Build the handler chain before binding, so that a panic in the mux
	// factory cannot leave an open listener behind.
	// NOTE: request profiling (telemetry.NewRequestProfiler, gated on
	// GetAllowProfiling) is intentionally not wired in here.
	server.httpServer.Handler = server.muxFactory.NewMux()

	// Bind synchronously in both modes. Plain-HTTP servers used to bind
	// inside the serving goroutine (ListenAndServe), so a bind failure
	// panicked the process instead of being returned.
	var (
		listener net.Listener
		err      error
	)
	if server.tlsConfig != nil {
		listener, err = tls.Listen("tcp", bindAddr, server.tlsConfig)
		server.httpServer.TLSConfig = server.tlsConfig
	} else {
		server.httpServer.Addr = bindAddr
		addr := bindAddr
		if addr == "" {
			// Same default http.Server.ListenAndServe applies.
			addr = ":http"
		}
		listener, err = net.Listen("tcp", addr)
	}
	if err != nil {
		return errors.WithMessagef(err, "%s: unable to listen: %q",
			server.Name(), bindAddr)
	}

	go func() {
		server.broadcast(ServerStartedEvent)
		logger.KV(xlog.INFO, "server", server.Name(), "bind", bindAddr, "status", "starting", "protocol", server.Protocol())

		// serving is read by IsReady from request goroutines; guard it.
		server.lock.Lock()
		server.serving = true
		server.lock.Unlock()

		// Serve blocks until the server is shut down or fails, and always
		// returns a non-nil error.
		serveErr := server.httpServer.Serve(listener)

		server.lock.Lock()
		server.serving = false
		server.lock.Unlock()

		// http.ErrServerClosed is the expected outcome of StopHTTP; any other
		// error is fatal.
		if serveErr != http.ErrServerClosed {
			logger.Panicf("server=%s, err=[%v]", server.Name(), errors.WithStack(serveErr))
		}
		logger.KV(xlog.WARNING, "server", server.Name(), "status", "stopped", "reason", serveErr.Error())
	}()
	return nil
}
// StopHTTP performs a graceful shutdown of the server:
//  1. ServerStoppingEvent handlers are invoked;
//  2. every registered service is closed;
//  3. the HTTP server stops accepting new connections and waits for
//     in-flight requests to finish, giving up once the shutdown timeout
//     (see WithShutdownTimeout; 5s by default) expires;
//  4. ServerStoppedEvent handlers are invoked.
//
// Calling StopHTTP before StartHTTP is safe: step 3 is skipped. The server
// instance is not meant to be used again afterwards; to restart, create
// another server instance.
//
// NOTE(review): services are closed before in-flight requests are drained,
// so those requests may observe closed services. This ordering is kept as-is.
func (server *HTTPServer) StopHTTP() {
	server.broadcast(ServerStoppingEvent)

	// Copy the services out under the lock AddService writes with, so that
	// Close runs without holding it.
	server.lock.RLock()
	services := make([]Service, 0, len(server.services))
	for _, f := range server.services {
		services = append(services, f)
	}
	server.lock.RUnlock()

	for _, f := range services {
		logger.KV(xlog.TRACE, "service", f.Name(), "status", "closing")
		f.Close()
	}

	// httpServer is only created by StartHTTP.
	if server.httpServer != nil {
		ctx, cancel := context.WithTimeout(context.Background(), server.shutdownTimeout)
		defer cancel()
		if err := server.httpServer.Shutdown(ctx); err != nil {
			logger.KV(xlog.ERROR, "reason", "Shutdown", "err", err)
		}
	}

	server.broadcast(ServerStoppedEvent)
}
// NewMux creates a new http handler for the http server; typically you only
// need to call this directly for tests (StartHTTP calls it through the
// default MuxFactory).
//
// Every registered service registers its routes on a router (CORS-enabled
// when WithCORS was set). The router is then wrapped innermost-first, so at
// request time the layers run in this order:
// correlation ID → identity → metrics → request logging → authz (if set) →
// service-status verifier → router.
//
// It panics if the configured authz provider fails to create its handler.
func (server *HTTPServer) NewMux() http.Handler {
	// NOTE: the handlers are executed in the reverse order
	var router Router
	if server.cors != nil {
		router = NewRouterWithCORS(notFoundHandler, server.cors)
	} else {
		router = NewRouter(notFoundHandler)
	}

	// Copy the services out under the lock AddService writes with, then
	// register them without holding it.
	server.lock.RLock()
	services := make([]Service, 0, len(server.services))
	for _, svc := range server.services {
		services = append(services, svc)
	}
	server.lock.RUnlock()

	for _, svc := range services {
		svc.Register(router)
	}
	logger.KV(xlog.DEBUG, "server", server.Name(), "service_count", len(services))

	httpHandler := router.Handler()
	logger.KV(xlog.INFO, "server", server.Name(), "ClientAuth", server.clientAuth)

	// service ready
	httpHandler = ready.NewServiceStatusVerifier(server, httpHandler)

	if server.authz != nil {
		var err error
		httpHandler, err = server.authz.NewHandler(httpHandler)
		if err != nil {
			logger.Panicf("failed to create authz handler: %+v", err)
		}
	}

	// logging wrapper
	httpHandler = telemetry.NewRequestLogger(
		httpHandler,
		time.Millisecond,
		logger)

	// metrics wrapper
	httpHandler = telemetry.NewRequestMetrics(httpHandler)

	// role/contextID wrapper
	if server.identityMapper != nil {
		httpHandler = identity.NewContextHandler(httpHandler, server.identityMapper)
	} else {
		httpHandler = identity.NewContextHandler(httpHandler, identity.GuestIdentityMapper)
	}

	// Add correlationID
	httpHandler = correlation.NewHandler(httpHandler)
	return httpHandler
}
// ServeHTTP dispatches the request to the handler chain installed by
// StartHTTP, which lets *HTTPServer itself be used as an http.Handler.
//
// Before StartHTTP has run there is no handler yet. In that case the request
// is answered with 503 Service Unavailable instead of panicking on a nil
// server.
func (server *HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	hs := server.httpServer
	if hs == nil || hs.Handler == nil {
		http.Error(w, "server is not started", http.StatusServiceUnavailable)
		return
	}
	hs.Handler.ServeHTTP(w, r)
}
func notFoundHandler(w http.ResponseWriter, r *http.Request) {
marshal.WriteJSON(w, r, httperror.NotFound(r.URL.Path))
}
// GetServerURL returns the absolute URL of relativeEndpoint on the server
// that received r.
//
// The scheme defaults to s.Protocol(). It is overridden by the first entry of
// the X-Forwarded-Proto header, when present, so that servers behind a
// TLS-terminating proxy report the externally visible scheme. The host is
// taken from r.URL.Host, then r.Host, and finally falls back to
// s.HostName() joined with s.Port().
func GetServerURL(s Server, r *http.Request, relativeEndpoint string) *url.URL {
	proto := s.Protocol()
	// Allow upstream proxies to specify the forwarded protocol and override
	// our own guess. Chained proxies may append to the header
	// ("https, http"); only the left-most entry is a usable scheme.
	if fwd := r.Header.Get(header.XForwardedProto); fwd != "" {
		if i := strings.IndexByte(fwd, ','); i >= 0 {
			fwd = fwd[:i]
		}
		if fwd = strings.TrimSpace(fwd); fwd != "" {
			proto = fwd
		}
	}

	host := r.URL.Host
	if host == "" {
		host = r.Host
	}
	if host == "" {
		// JoinHostPort brackets IPv6 literals, which plain concatenation
		// would leave ambiguous.
		host = net.JoinHostPort(s.HostName(), s.Port())
	}

	return &url.URL{
		Scheme: proto,
		Host:   host,
		Path:   relativeEndpoint,
	}
}
// GetServerBaseURL returns the server's base URL, built from s.Protocol(),
// s.HostName() and s.Port(), with no path.
func GetServerBaseURL(s Server) *url.URL {
	return &url.URL{
		Scheme: s.Protocol(),
		// JoinHostPort brackets IPv6 literals ("[::1]:8080"); for other hosts
		// the result equals host + ":" + port.
		Host: net.JoinHostPort(s.HostName(), s.Port()),
	}
}