-
Notifications
You must be signed in to change notification settings - Fork 1
/
ctx.go
198 lines (167 loc) · 5.53 KB
/
ctx.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
// Package identity extracts the caller's contextual identity information from
// HTTP/TLS requests and exposes it for access via the Go context model.
package identity
import (
"context"
"net/http"
"sync"
"github.com/go-phorce/dolly/algorithms/guid"
"github.com/go-phorce/dolly/netutil"
"github.com/go-phorce/dolly/xhttp/header"
"github.com/go-phorce/dolly/xhttp/httperror"
"github.com/go-phorce/dolly/xhttp/marshal"
"github.com/go-phorce/dolly/xlog"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// logger is this package's logger, created via xlog.NewPackageLogger.
// NOTE(review): the package tag is "xhttp/context" although this package is
// named identity — confirm this is intentional.
var logger = xlog.NewPackageLogger("github.com/go-phorce/dolly", "xhttp/context")
// contextKey is the unexported key type for values this package stores in a
// context.Context. Because the type is unexported, these keys cannot collide
// with keys defined by other packages.
type contextKey int

// Context keys used by this package.
const (
	// keyContext is the key under which the *RequestContext is stored.
	keyContext contextKey = 0
	// keyIdentity is not referenced in this file; presumably used elsewhere
	// in the package — TODO confirm.
	keyIdentity contextKey = 1
)
// NodeInfoFactory is a function that returns the netutil.NodeInfo describing
// the current node.
type NodeInfoFactory func() netutil.NodeInfo

// nodeInfoFactory is the package-wide factory that NewContextHandler calls to
// obtain the host name; SetGlobalNodeInfo replaces it.
var nodeInfoFactory = newNodeInfoFactory()
// RequestContext carries contextual information about a request being
// processed by the server: the caller's identity, the correlation ID used to
// correlate the request across systems, and the client IP address.
type RequestContext struct {
	identity                Identity
	correlationID, clientIP string
}
// NewRequestContext returns a RequestContext carrying the given identity.
// Its correlation ID and client IP are left empty.
func NewRequestContext(id Identity) *RequestContext {
	rc := new(RequestContext)
	rc.identity = id
	return rc
}
// Context is the read-only view of the contextual information about a
// request being processed by the server.
type Context interface {
	// Identity returns the identity associated with the request.
	Identity() Identity
	// CorrelationID returns the ID used to correlate the request across systems.
	CorrelationID() string
	// ClientIP returns the IP address recorded for the request's client.
	ClientIP() string
}
// defaultNodeInfoFactory caches the NodeInfo of the local node, creating it
// lazily on first use. The mutex guards the cached value.
type defaultNodeInfoFactory struct {
	lock     sync.Mutex
	nodeInfo netutil.NodeInfo
}

// getNodeInfo returns the cached NodeInfo. When none is cached yet, it builds
// one with netutil.NewNodeInfo(nil) and caches it. A creation error is
// reported through logger.Panicf (presumably panicking).
func (f *defaultNodeInfoFactory) getNodeInfo() netutil.NodeInfo {
	f.lock.Lock()
	defer f.lock.Unlock()

	if f.nodeInfo != nil {
		return f.nodeInfo
	}

	info, err := netutil.NewNodeInfo(nil)
	if err != nil {
		logger.Panicf("err=[%v]", err.Error())
	}
	f.nodeInfo = info
	return info
}
// newNodeInfoFactory returns a NodeInfoFactory backed by a fresh, empty
// defaultNodeInfoFactory, so the NodeInfo is resolved on the first call.
func newNodeInfoFactory() NodeInfoFactory {
	return (&defaultNodeInfoFactory{}).getNodeInfo
}
// SetGlobalNodeInfo installs n as the package-wide NodeInfo, e.g. the source
// of the host name that NewContextHandler writes into responses.
// It calls logger.Panic when n is nil.
//
// NOTE(review): the assignment below is not synchronized with concurrent
// readers; presumably this is meant to be called during startup — confirm.
func SetGlobalNodeInfo(n netutil.NodeInfo) {
	if n == nil {
		logger.Panic("NodeInfo must not be nil")
	}
	nodeInfoFactory = (&defaultNodeInfoFactory{nodeInfo: n}).getNodeInfo
}
// FromContext returns the RequestContext stored in ctx, for example by
// AddToContext or NewContextHandler.
//
// It never returns nil. If ctx carries no RequestContext, or carries a nil
// one, a new RequestContext is returned. That RequestContext has the guest
// identity and an empty correlation ID and client IP.
func FromContext(ctx context.Context) *RequestContext {
	if rc, _ := ctx.Value(keyContext).(*RequestContext); rc != nil {
		return rc
	}
	return &RequestContext{identity: guestIdentity}
}
// AddToContext returns a context derived from ctx that carries rq as the
// request context, retrievable later with FromContext.
func AddToContext(ctx context.Context, rq *RequestContext) context.Context {
	derived := context.WithValue(ctx, keyContext, rq)
	return derived
}
// FromRequest returns the RequestContext associated with the HTTP request r.
// It is shorthand for FromContext(r.Context()).
func FromRequest(r *http.Request) *RequestContext {
	ctx := r.Context()
	return FromContext(ctx)
}
// NewContextHandler returns a handler that extracts the caller's identity and
// correlation ID from the request. It stores them in the request context for
// later handlers to use.
//
// It also sets the header.XHostname response header to show which host is
// servicing the request, and echoes the correlation ID in the
// header.XCorrelationID response header.
//
// Behavior in the edge cases:
//   - If the request context already carries a non-nil *RequestContext, that
//     value is reused and identityMapper is not called.
//   - If identityMapper fails, the error is logged, an Unauthorized error is
//     written to w, and delegate is not invoked.
//   - If identityMapper returns a nil identity without an error, the guest
//     identity is used, matching NewAuthUnaryInterceptor and FromContext.
func NewContextHandler(delegate http.Handler, identityMapper ProviderFromRequest) http.Handler {
	h := func(w http.ResponseWriter, r *http.Request) {
		// Set XHostname on the response
		w.Header().Set(header.XHostname, nodeInfoFactory().HostName())

		// Checked assertion: a nil *RequestContext stored via AddToContext
		// would otherwise be dereferenced below and panic.
		rctx, _ := r.Context().Value(keyContext).(*RequestContext)
		if rctx == nil {
			clientIP := ClientIPFromRequest(r)
			identity, err := identityMapper(r)
			if err != nil {
				logger.Errorf("reason=identityMapper, ip=%q, err=[%v]", clientIP, err.Error())
				marshal.WriteJSON(w, r, httperror.WithUnauthorized(err.Error()))
				return
			}
			if identity == nil {
				identity = guestIdentity
			}
			rctx = &RequestContext{
				identity:      identity,
				correlationID: extractCorrelationID(r),
				clientIP:      clientIP,
			}
			r = r.WithContext(context.WithValue(r.Context(), keyContext, rctx))
		}

		w.Header().Set(header.XCorrelationID, rctx.correlationID)
		delegate.ServeHTTP(w, r)
	}
	return http.HandlerFunc(h)
}
// guestIdentity is the fallback identity, built from GuestRoleName. It is used
// when no caller identity is available: FromContext falls back to it, and so
// does NewAuthUnaryInterceptor when the mapper returns nil.
var guestIdentity = NewIdentity(GuestRoleName, "", "")
// NewAuthUnaryInterceptor returns a grpc.UnaryServerInterceptor that resolves
// the caller's identity with identityMapper. It adds that identity to the
// context as a RequestContext before invoking the handler.
//
// If identityMapper fails, the call is rejected with codes.PermissionDenied.
// If it returns a nil identity, the guest identity is used instead. The
// resulting RequestContext has an empty correlation ID and client IP.
func NewAuthUnaryInterceptor(identityMapper ProviderFromContext) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		id, err := identityMapper(ctx)
		if err != nil {
			return nil, status.Errorf(codes.PermissionDenied, "unable to get identity: %v", err)
		}
		if id == nil {
			id = guestIdentity
		}
		return handler(AddToContext(ctx, NewRequestContext(id)), req)
	}
}
// Identity returns the identity associated with the request.
func (c *RequestContext) Identity() Identity {
	return c.identity
}
// CorrelationID returns the request's correlation ID.
//
// For contexts created by NewContextHandler, the ID comes from the
// header.XCorrelationID request header, falling back to header.XDeviceID. A
// random GUID is generated when neither header is present. Contexts created
// with NewRequestContext, including those from NewAuthUnaryInterceptor, have
// an empty correlation ID.
func (c *RequestContext) CorrelationID() string {
	return c.correlationID
}
// ClientIP returns the client IP address recorded for the request. For
// contexts created by NewContextHandler it comes from ClientIPFromRequest; it
// is empty for contexts created with NewRequestContext.
func (c *RequestContext) ClientIP() string {
	return c.clientIP
}
// extractCorrelationID returns the correlation ID for req. It takes the first
// non-empty value among:
//   - the header.XCorrelationID header,
//   - the header.XDeviceID header,
//   - a freshly generated GUID.
func extractCorrelationID(req *http.Request) string {
	if id := req.Header.Get(header.XCorrelationID); id != "" {
		return id
	}
	if id := req.Header.Get(header.XDeviceID); id != "" {
		return id
	}
	return guid.MustCreate()
}