-
Notifications
You must be signed in to change notification settings - Fork 124
/
context.go
115 lines (90 loc) · 2.44 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Copyright 2021-2024 Zenauth Ltd.
// SPDX-License-Identifier: Apache-2.0
package audit
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"strings"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
auditv1 "github.com/cerbos/cerbos/api/genpb/cerbos/audit/v1"
"github.com/cerbos/cerbos/internal/util"
)
// Metadata keys read from incoming gRPC request metadata, plus the log field
// name used for the call ID.
const (
	grpcGWUserAgentKey = "grpcgateway-user-agent"
	userAgentKey       = "user-agent"
	xffKey             = "x-forwarded-for"

	// callIDTagKey is the key of the log field that carries the call ID.
	callIDTagKey = "call_id"
)

const (
	// SetByGRPCGatewayKey is the metadata key whose (last) value is compared
	// against SetByGRPCGatewayVal to decide whether a request was forwarded
	// by the gRPC gateway.
	SetByGRPCGatewayKey = "x-cerbos-set-by-grpc-gateway"

	// HTTPRemoteAddrKey is the metadata key holding the HTTP client's remote
	// address. It is only honoured for requests verified as coming from the
	// gRPC gateway.
	HTTPRemoteAddrKey = "x-cerbos-http-remote-addr"
)
// SetByGRPCGatewayVal is a random secret generated once per process at package
// initialisation. A request whose last SetByGRPCGatewayKey metadata value
// equals it is treated as having been forwarded by the gRPC gateway (see
// checkSetByGRPCGateway).
//
// It is assigned by a package-level initializer rather than in init() so that
// Go's dependency-ordered variable initialisation guarantees it is populated
// before any other package-level variable (or any init function) reads it.
var SetByGRPCGatewayVal = generateSetByGRPCGatewayVal()
// callIDCtxKeyType is an unexported context key type, so the call ID stored in
// a context.Context cannot collide with keys defined by other packages.
type callIDCtxKeyType struct{}

// callIDCtxKey is the context key under which the call ID is stored.
var callIDCtxKey callIDCtxKeyType
// NewContextWithCallID returns a context derived from ctx that carries id in
// two ways:
//   - as a log field ({"call_id": id}) injected under util.AppName via the
//     grpc-middleware logging package, and
//   - as a context value retrievable with CallIDFromContext.
func NewContextWithCallID(ctx context.Context, id ID) context.Context {
	logFields := map[string]any{callIDTagKey: id}
	withLogField := logging.InjectLogField(ctx, util.AppName, logFields)
	return context.WithValue(withLogField, callIDCtxKey, id)
}
// CallIDFromContext returns the call ID stored in ctx by NewContextWithCallID.
// The boolean is false (and the ID empty) when ctx carries no value under the
// call ID key, or when that value is not an ID.
func CallIDFromContext(ctx context.Context) (ID, bool) {
	// A type assertion on a nil interface value yields ok == false, so this
	// covers both the "missing" and the "wrong type" cases.
	if callID, found := ctx.Value(callIDCtxKey).(ID); found {
		return callID, true
	}
	return "", false
}
func PeerFromContext(ctx context.Context) *auditv1.Peer {
p := peerFromContext(ctx)
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return p
}
setByGateway := checkSetByGRPCGateway(md)
var ua []string
xff := md[xffKey]
if setByGateway {
if addr := md[HTTPRemoteAddrKey]; len(addr) > 0 {
p.Address = addr[len(addr)-1]
}
ua = md[grpcGWUserAgentKey]
} else {
ua = md[userAgentKey]
}
p.UserAgent = strings.Join(ua, "|")
p.ForwardedFor = strings.Join(xff, ", ")
return p
}
// peerFromContext builds an audit Peer from the gRPC transport peer stored in
// ctx. It fills in the remote address and, when present, the auth info type.
// It returns an empty Peer when ctx carries no peer information.
func peerFromContext(ctx context.Context) *auditv1.Peer {
	p, ok := peer.FromContext(ctx)
	// A nil *peer.Peer can be stored in the context (ok is still true), so
	// guard against it rather than dereferencing blindly.
	if !ok || p == nil {
		return &auditv1.Peer{}
	}

	pp := &auditv1.Peer{}
	// Addr is an interface and may be unset (e.g. a hand-built peer.Peer);
	// calling String() on a nil interface would panic.
	if p.Addr != nil {
		pp.Address = p.Addr.String()
	}
	if p.AuthInfo != nil {
		pp.AuthInfo = p.AuthInfo.AuthType()
	}
	return pp
}
// generateSetByGRPCGatewayVal returns 32 bytes from crypto/rand encoded with
// standard base64. It panics if the system random source fails, because there
// is no safe fallback for this secret.
func generateSetByGRPCGatewayVal() string {
	const size = 32
	secret := make([]byte, size)
	if _, err := rand.Read(secret); err != nil {
		panic(fmt.Errorf("failed to generate %s header value: %w", SetByGRPCGatewayKey, err))
	}
	return base64.StdEncoding.EncodeToString(secret)
}
// checkSetByGRPCGateway reports whether the last SetByGRPCGatewayKey value in
// md equals SetByGRPCGatewayVal. The comparison is constant-time so the secret
// cannot be recovered through timing differences.
func checkSetByGRPCGateway(md metadata.MD) bool {
	values := md[SetByGRPCGatewayKey]
	if len(values) == 0 {
		return false
	}
	latest := values[len(values)-1]
	return subtle.ConstantTimeCompare([]byte(latest), []byte(SetByGRPCGatewayVal)) == 1
}