-
Notifications
You must be signed in to change notification settings - Fork 124
/
interceptor.go
89 lines (73 loc) · 2.21 KB
/
interceptor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Copyright 2021-2024 Zenauth Ltd.
// SPDX-License-Identifier: Apache-2.0
package audit
import (
"context"
"time"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/timestamppb"
auditv1 "github.com/cerbos/cerbos/api/genpb/cerbos/audit/v1"
"github.com/cerbos/cerbos/internal/observability/logging"
"github.com/cerbos/cerbos/internal/observability/tracing"
)
// ExcludeMethod is a predicate over a gRPC full method name. When it returns
// true for a call, the audit unary interceptor invokes the handler directly
// and writes no access log entry for that call.
type ExcludeMethod func(string) bool

// IncludeKeysMethod is a string predicate. It is not referenced in this file;
// presumably it selects which metadata keys are retained — TODO confirm
// against its callers.
type IncludeKeysMethod func(string) bool
// NewUnaryInterceptor builds a gRPC unary server interceptor that records an
// audit access log entry for each call not matched by exclude.
//
// For audited calls the interceptor generates a call ID and runs the handler
// with a context carrying that ID. It then hands log a builder for the access
// log entry and stamps the ID into the response when the response supports it.
// Failures to generate the ID or to write the entry are logged as warnings and
// never fail the RPC. The handler's response and error are always returned
// unchanged apart from the call ID stamp.
//
// An error is returned only when the metadata extractor cannot be created.
func NewUnaryInterceptor(log Log, exclude ExcludeMethod) (grpc.UnaryServerInterceptor, error) {
	extractMetadata, setupErr := NewMetadataExtractor()
	if setupErr != nil {
		return nil, setupErr
	}

	interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		if exclude(info.FullMethod) {
			return handler(ctx, req)
		}

		// Captured before the handler runs so the entry reflects call start.
		startedAt := time.Now()

		callID, idErr := NewID()
		if idErr != nil {
			// Best effort: without an ID the call is served but not audited.
			logging.FromContext(ctx).Warn("Failed to generate call ID", zap.Error(idErr))
			return handler(ctx, req)
		}

		resp, handlerErr := handler(NewContextWithCallID(ctx, callID), req)

		// The entry is produced lazily; log decides whether/when to invoke it.
		buildEntry := func() (*auditv1.AccessLogEntry, error) {
			spanCtx, span := tracing.StartSpan(ctx, "audit.WriteAccessLog")
			defer span.End()

			entry := &auditv1.AccessLogEntry{
				CallId:     string(callID),
				Timestamp:  timestamppb.New(startedAt),
				Peer:       PeerFromContext(spanCtx),
				Method:     info.FullMethod,
				StatusCode: uint32(status.Code(handlerErr)),
				Metadata:   extractMetadata(spanCtx),
			}
			return entry, nil
		}

		if logErr := log.WriteAccessLogEntry(ctx, buildEntry); logErr != nil {
			logging.FromContext(ctx).Warn("Failed to write access log entry", zap.Error(logErr))
		}

		setCerbosCallID(string(callID), resp)
		return resp, handlerErr
	}

	return interceptor, nil
}
// cerbosAPIResponse is satisfied by protobuf messages that expose a
// GetCerbosCallId accessor. setCerbosCallID uses it to recognise responses
// that can carry the audit call ID.
type cerbosAPIResponse interface {
	GetCerbosCallId() string
	proto.Message
}
// setCerbosCallID writes callID into the cerbos_call_id field of resp when resp
// is a cerbosAPIResponse whose message has a singular string field with that
// text name.
//
// It is strictly best-effort. Nil responses, typed-nil message pointers,
// messages without the field, and fields of an unexpected shape are all left
// untouched. Nothing is reported to the caller.
func setCerbosCallID(callID string, resp any) {
	if resp == nil {
		return
	}

	r, ok := resp.(cerbosAPIResponse)
	if !ok {
		return
	}

	// Last-resort guard: this runs on the request path, so an unexpected panic
	// from protobuf reflection must never fail the RPC. The explicit checks
	// below cover the known failure modes without relying on it.
	defer func() {
		_ = recover()
	}()

	msg := r.ProtoReflect()
	// A typed nil pointer (e.g. (*SomeResponse)(nil)) passes the nil check
	// above but yields an invalid message; Set on it would panic.
	if !msg.IsValid() {
		return
	}

	fd := msg.Descriptor().Fields().ByTextName("cerbos_call_id")
	// Only a singular string field can hold the ID: Set panics when the
	// value's kind does not match the field.
	if fd == nil || fd.Kind() != protoreflect.StringKind || fd.Cardinality() == protoreflect.Repeated {
		return
	}

	msg.Set(fd, protoreflect.ValueOfString(callID))
}