-
Notifications
You must be signed in to change notification settings - Fork 124
/
middleware.go
126 lines (103 loc) · 3.49 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright 2021 Zenauth Ltd.
// SPDX-License-Identifier: Apache-2.0
package server
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/gorilla/handlers"
grpc_logging "github.com/grpc-ecosystem/go-grpc-middleware/logging"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
)
// XForwardedHostUnaryServerInterceptor is a gRPC unary server interceptor that
// copies the "x-forwarded-host" and "x-forwarded-for" values found in the
// incoming request metadata into the context tags under the "http" key, using
// the field names "x_forwarded_host" and "x_forwarded_for".
//
// When the context carries no incoming metadata, or neither header is present,
// the handler is invoked with the original, unmodified context.
func XForwardedHostUnaryServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	incoming, found := metadata.FromIncomingContext(ctx)
	if !found {
		return handler(ctx, req)
	}

	forwarded := make(map[string]interface{}, 2) //nolint:gomnd

	if hosts, present := incoming["x-forwarded-host"]; present {
		forwarded["x_forwarded_host"] = hosts
	}

	if addrs, present := incoming["x-forwarded-for"]; present {
		forwarded["x_forwarded_for"] = addrs
	}

	// Nothing to record: pass the request through untouched.
	if len(forwarded) == 0 {
		return handler(ctx, req)
	}

	tagged := grpc_ctxtags.Extract(ctx).Set("http", forwarded)

	return handler(grpc_ctxtags.SetInContext(ctx, tagged), req)
}
// accessLogExclude reports whether the given method must be kept out of the
// access log. A method is excluded when its name starts with "/grpc.".
func accessLogExclude(method string) bool {
	const excludedPrefix = "/grpc."

	excluded := strings.HasPrefix(method, excludedPrefix)

	return excluded
}
// loggingDecider reports whether a call to fullMethodName should be logged.
// It returns false only for "/grpc.health.v1.Health/Check"; every other method
// is logged. The error argument is ignored.
func loggingDecider(fullMethodName string, _ error) bool {
	const healthCheckMethod = "/grpc.health.v1.Health/Check"

	switch fullMethodName {
	case healthCheckMethod:
		return false
	default:
		return true
	}
}
// payloadLoggingDecider builds the decider that controls request payload
// logging. The returned function reports true only when conf.LogRequestPayloads
// is set and the method name starts with "/cerbos.svc.v1". The configuration
// flag is read on every call rather than captured once.
func payloadLoggingDecider(conf *Conf) grpc_logging.ServerPayloadLoggingDecider {
	const servicePrefix = "/cerbos.svc.v1"

	return func(_ context.Context, fullMethodName string, _ interface{}) bool {
		if !conf.LogRequestPayloads {
			return false
		}

		return strings.HasPrefix(fullMethodName, servicePrefix)
	}
}
// messageProducer writes the "Handled request" log entry for a gRPC call.
// It uses the logger extracted from ctx, logs at the supplied level, and
// attaches the error, the gRPC status code (as "grpc.code") and the duration
// field. The second (string) argument is ignored.
func messageProducer(ctx context.Context, _ string, level zapcore.Level, code codes.Code, err error, duration zapcore.Field) {
	logger := ctxzap.Extract(ctx)
	entry := logger.Check(level, "Handled request")

	entry.Write(
		zap.Error(err),
		zap.String("grpc.code", code.String()),
		duration,
	)
}
// prettyJSON wraps h so that requests carrying a "pretty" query parameter
// (with or without a value) have their Accept header overwritten with
// "application/json+pretty" before being passed on. Requests without the
// parameter are forwarded unchanged.
func prettyJSON(h http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()

		// Only the presence of the key matters; its value is ignored.
		_, wantsPretty := query["pretty"]
		if wantsPretty {
			r.Header.Set("Accept", "application/json+pretty")
		}

		h.ServeHTTP(w, r)
	}

	return http.HandlerFunc(serve)
}
// customHTTPResponseCode overrides the HTTP status code of a response based on
// the "x-http-code" entry of the server header metadata found in ctx.
//
// When the entry is present, its first value is parsed as an integer, the
// entry is removed from the header metadata, the "Grpc-Metadata-X-Http-Code"
// response header is removed from w, and the parsed value is written as the
// status code. When ctx carries no server metadata, or the entry is absent,
// nothing is written and nil is returned.
//
// An error is returned, and neither the metadata nor w is modified, when the
// value is not an integer or lies outside the range 100-999.
func customHTTPResponseCode(ctx context.Context, w http.ResponseWriter, _ proto.Message) error {
	// The standard net/http server panics in WriteHeader for status codes
	// outside this range, so reject them here instead of crashing the request.
	const (
		minHTTPCode = 100
		maxHTTPCode = 999
	)

	md, ok := runtime.ServerMetadataFromContext(ctx)
	if !ok {
		return nil
	}

	vals := md.HeaderMD.Get("x-http-code")
	if len(vals) == 0 {
		return nil
	}

	code, err := strconv.Atoi(vals[0])
	if err != nil {
		return fmt.Errorf("invalid http code: %w", err)
	}

	if code < minHTTPCode || code > maxHTTPCode {
		return fmt.Errorf("invalid http code: %d", code)
	}

	// Strip the control header so it is not exposed to the HTTP client.
	delete(md.HeaderMD, "x-http-code")
	delete(w.Header(), "Grpc-Metadata-X-Http-Code")
	w.WriteHeader(code)

	return nil
}
// withCORS wraps handler with CORS handling configured from conf.CORS.
//
// If conf.CORS.Disabled is set, handler is returned as-is. Otherwise the
// allowed origins come from conf.CORS.AllowedOrigins, falling back to "*" when
// that list is empty, and conf.CORS.AllowedHeaders is applied only when it is
// non-empty.
func withCORS(conf *Conf, handler http.Handler) http.Handler {
	if conf.CORS.Disabled {
		return handler
	}

	origins := conf.CORS.AllowedOrigins
	if len(origins) == 0 {
		// No explicit origins configured: allow any origin.
		origins = []string{"*"}
	}

	opts := []handlers.CORSOption{handlers.AllowedOrigins(origins)}

	if allowedHeaders := conf.CORS.AllowedHeaders; len(allowedHeaders) > 0 {
		opts = append(opts, handlers.AllowedHeaders(allowedHeaders))
	}

	wrap := handlers.CORS(opts...)

	return wrap(handler)
}