forked from hashicorp/vault
/
request_forwarding_rpc.go
133 lines (115 loc) · 3.47 KB
/
request_forwarding_rpc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package vault
import (
"context"
"net/http"
"runtime"
"sync/atomic"
"time"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/forwarding"
cache "github.com/patrickmn/go-cache"
)
// forwardedRequestRPCServer is the RPC-side receiver for forwarded requests
// and heartbeat echoes. Its ForwardRequest method replays a forwarded HTTP
// request through handler, and its Echo method answers heartbeat pings.
type forwardedRequestRPCServer struct {
	// core is the Vault core; the methods in this file use it for logging,
	// the cluster peer address cache, and replication state.
	core *Core
	// handler serves the HTTP request rebuilt from each forwarded RPC request.
	handler http.Handler
	// NOTE(review): the perfStandby* fields are not used by the methods
	// visible in this file; presumably they support performance standby
	// handling elsewhere — confirm against their users.
	perfStandbySlots      chan struct{}
	perfStandbyRepCluster *ReplicatedCluster
	perfStandbyCache      *cache.Cache
}
// ForwardRequest rebuilds the HTTP request carried in freq, runs it through
// the server's handler, and packages the handler's status code, body and
// headers into a forwarding.Response together with the last remote WAL value.
//
// An error is returned only when the forwarded request cannot be parsed. A
// panic raised by the handler is recovered and logged with a stack trace;
// whatever the handler had written up to that point is still returned.
func (s *forwardedRequestRPCServer) ForwardRequest(ctx context.Context, freq *forwarding.Request) (*forwarding.Response, error) {
	// Parse an http.Request out of the RPC message.
	httpReq, err := forwarding.ParseForwardedRequest(freq)
	if err != nil {
		return nil, err
	}

	// A very dummy response writer that doesn't follow normal semantics, just
	// lets you write a status code (last written wins) and a body. But it
	// meets the interface requirements.
	rw := forwarding.NewRPCResponseWriter()

	// Serve inside a function literal so the deferred recover covers exactly
	// the handler invocation and execution continues normally afterwards.
	func() {
		defer func() {
			// Logic here comes mostly from the Go source code
			r := recover()
			if r == nil {
				return
			}
			const size = 64 << 10
			stack := make([]byte, size)
			stack = stack[:runtime.Stack(stack, false)]
			s.core.logger.Error("panic serving forwarded request", "path", httpReq.URL.Path, "error", r, "stacktrace", string(stack))
		}()
		s.handler.ServeHTTP(rw, httpReq)
	}()

	out := &forwarding.Response{
		StatusCode: uint32(rw.StatusCode()),
		Body:       rw.Body().Bytes(),
	}

	// Copy every header the handler set into the RPC representation.
	if hdr := rw.Header(); hdr != nil {
		out.HeaderEntries = make(map[string]*forwarding.HeaderEntry, len(hdr))
		for name, vals := range hdr {
			out.HeaderEntries[name] = &forwarding.HeaderEntry{Values: vals}
		}
	}

	out.LastRemoteWal = LastRemoteWAL(s.core)

	return out, nil
}
// Echo answers a heartbeat request. When the caller supplies a non-empty
// cluster address it is recorded in the core's cluster peer address cache.
// The reply always carries the message "pong" and this node's replication
// state; the returned error is always nil.
func (s *forwardedRequestRPCServer) Echo(ctx context.Context, in *EchoRequest) (*EchoReply, error) {
	if addr := in.ClusterAddr; addr != "" {
		s.core.clusterPeerClusterAddrsCache.Set(addr, nil, 0)
	}

	reply := &EchoReply{
		Message:          "pong",
		ReplicationState: uint32(s.core.ReplicationState()),
	}
	return reply, nil
}
// forwardingClient wraps a RequestForwardingClient with the state needed to
// run the periodic heartbeat started by startHeartbeat.
type forwardingClient struct {
	// RequestForwardingClient is the embedded RPC client; the heartbeat calls
	// its Echo method.
	RequestForwardingClient

	// core supplies the cluster address sent with each echo, the logger, and
	// the activeNodeReplicationState value updated from echo replies.
	core *Core

	// echoTicker triggers each heartbeat; it is stopped when echoContext is
	// done.
	echoTicker *time.Ticker
	// echoContext is the parent context of every echo call; once it is done
	// the heartbeat goroutine exits.
	echoContext context.Context
}
// startHeartbeat launches a goroutine that sends an echo request immediately
// and then again on every echoTicker tick, until echoContext is done. On
// shutdown the ticker is stopped and the recorded active node replication
// state is reset to consts.ReplicationUnknown.
//
// NOTE: we also take advantage of gRPC's keepalive bits, but as we send data
// with these requests it's useful to keep this as well
func (c *forwardingClient) startHeartbeat() {
	// sendEcho performs a single ping/pong exchange and, on a valid reply,
	// records the replication state reported by the active node.
	sendEcho := func() {
		// Read the cluster address under the state lock.
		c.core.stateLock.RLock()
		addr := c.core.clusterAddr
		c.core.stateLock.RUnlock()

		// Each echo gets its own two second deadline derived from echoContext.
		echoCtx, cancel := context.WithTimeout(c.echoContext, 2*time.Second)
		reply, err := c.RequestForwardingClient.Echo(echoCtx, &EchoRequest{
			Message:     "ping",
			ClusterAddr: addr,
		})
		cancel()

		switch {
		case err != nil:
			c.core.logger.Debug("forwarding: error sending echo request to active node", "error", err)
		case reply == nil:
			c.core.logger.Debug("forwarding: empty echo response from active node")
		case reply.Message != "pong":
			c.core.logger.Debug("forwarding: unexpected echo response from active node", "message", reply.Message)
		default:
			// Store the active node's replication state to display in
			// sys/health calls
			atomic.StoreUint32(c.core.activeNodeReplicationState, reply.ReplicationState)
		}
	}

	go func() {
		// Fire once right away rather than waiting for the first tick.
		sendEcho()

		for {
			select {
			case <-c.echoTicker.C:
				sendEcho()

			case <-c.echoContext.Done():
				c.echoTicker.Stop()
				c.core.logger.Debug("forwarding: stopping heartbeating")
				atomic.StoreUint32(c.core.activeNodeReplicationState, uint32(consts.ReplicationUnknown))
				return
			}
		}
	}()
}