-
Notifications
You must be signed in to change notification settings - Fork 721
/
reconnect.go
138 lines (124 loc) · 3.68 KB
/
reconnect.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package supervisor
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/cloudflare/cloudflared/retry"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// errJWTUnset is the error ReconnectToken returns while no JWT has been
// stored in the credential manager.
var errJWTUnset = errors.New("JWT unset")
// reconnectCredentialManager is invoked by functions in tunnel.go to
// get/set parameters for ReconnectTunnel RPC calls. It also owns the
// Prometheus counters that RefreshAuth updates on each authentication attempt.
type reconnectCredentialManager struct {
// mu guards jwt, eventDigest and connDigest; the accessor methods take it
// for reading or writing as appropriate.
mu sync.RWMutex
// jwt is the most recently stored reconnect token; nil until one is set.
jwt []byte
// eventDigest maps a connection ID to its stored event digest.
eventDigest map[uint8][]byte
// connDigest maps a connection ID to its stored connection digest.
connDigest map[uint8][]byte
// authSuccess is incremented by RefreshAuth on a successful authentication.
authSuccess prometheus.Counter
// authFail is incremented by RefreshAuth on failures, labelled by the
// error text under the "error" label.
authFail *prometheus.CounterVec
}
// newReconnectCredentialManager builds a credential manager whose digest maps
// are pre-sized for haConnections entries.
//
// It creates the "tunnel_authenticate_success" counter and the
// "tunnel_authenticate_fail" counter vector (labelled by "error") under the
// given namespace and subsystem, and registers both through
// prometheus.MustRegister before returning.
func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
	successOpts := prometheus.CounterOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "tunnel_authenticate_success",
		Help:      "Count of successful tunnel authenticate",
	}
	failOpts := prometheus.CounterOpts{
		Namespace: namespace,
		Subsystem: subsystem,
		Name:      "tunnel_authenticate_fail",
		Help:      "Count of tunnel authenticate errors by type",
	}

	successCounter := prometheus.NewCounter(successOpts)
	failCounters := prometheus.NewCounterVec(failOpts, []string{"error"})
	// Register the metrics before the manager is assembled, as the
	// original construction order did.
	prometheus.MustRegister(successCounter, failCounters)

	cm := new(reconnectCredentialManager)
	cm.eventDigest = make(map[uint8][]byte, haConnections)
	cm.connDigest = make(map[uint8][]byte, haConnections)
	cm.authSuccess = successCounter
	cm.authFail = failCounters
	return cm
}
// ReconnectToken returns the currently stored JWT, read under the read lock.
// It returns errJWTUnset when the stored token is nil. The returned slice is
// the stored slice itself, not a copy.
func (cm *reconnectCredentialManager) ReconnectToken() ([]byte, error) {
	cm.mu.RLock()
	token := cm.jwt
	cm.mu.RUnlock()

	if token != nil {
		return token, nil
	}
	return nil, errJWTUnset
}
// SetReconnectToken replaces the stored JWT with jwt while holding the write
// lock. The slice is stored as given, without copying.
func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
	cm.mu.Lock()
	cm.jwt = jwt
	cm.mu.Unlock()
}
// EventDigest looks up the event digest stored for connID under the read
// lock. It returns an error when no digest has been stored for that
// connection.
func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
	cm.mu.RLock()
	stored, found := cm.eventDigest[connID]
	cm.mu.RUnlock()

	if found {
		return stored, nil
	}
	return nil, fmt.Errorf("no event digest for connection %v", connID)
}
// SetEventDigest stores digest as the event digest for connection connID,
// replacing any value previously stored for that connection. The slice is
// stored as given, without copying.
func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
cm.mu.Lock()
// Deferred so the write lock is released even if the map write panics.
defer cm.mu.Unlock()
cm.eventDigest[connID] = digest
}
// ConnDigest looks up the connection digest stored for connID under the read
// lock. It returns an error when no digest has been stored for that
// connection.
func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
	cm.mu.RLock()
	stored, found := cm.connDigest[connID]
	cm.mu.RUnlock()

	if found {
		return stored, nil
	}
	return nil, fmt.Errorf("no connection digest for connection %v", connID)
}
// SetConnDigest stores digest as the connection digest for connection connID,
// replacing any value previously stored for that connection. The slice is
// stored as given, without copying.
func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) {
cm.mu.Lock()
// Deferred so the write lock is released even if the map write panics.
defer cm.mu.Unlock()
cm.connDigest[connID] = digest
}
// RefreshAuth runs one authentication attempt via authenticate, passing it the
// number of retries recorded by backoff, and records the result in the
// authSuccess / authFail metrics.
//
// Outcomes:
//   - authenticate returns an error: the failure is counted; if
//     backoff.GetMaxBackoffDuration reports ok, the backoff timer is returned
//     with a nil error, otherwise the error is returned.
//   - AuthSuccess: the JWT is stored as the reconnect token and a timer for
//     the outcome's RefreshAfter duration is returned.
//   - AuthUnknown: the failure is counted and a timer for the outcome's
//     RefreshAfter duration is returned.
//   - AuthFail: the failure is counted and the outcome itself is returned as
//     the error.
//   - any other outcome type: counted as a failure and reported as an error.
func (cm *reconnectCredentialManager) RefreshAuth(
	ctx context.Context,
	backoff *retry.BackoffHandler,
	authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
	result, authErr := authenticate(ctx, backoff.Retries())
	if authErr != nil {
		cm.authFail.WithLabelValues(authErr.Error()).Inc()
		_, canRetry := backoff.GetMaxBackoffDuration(ctx)
		if !canRetry {
			return nil, authErr
		}
		return backoff.BackoffTimer(), nil
	}

	// The call itself succeeded, so clear the backoff timer before
	// inspecting the outcome.
	backoff.SetGracePeriod()

	switch typed := result.(type) {
	case tunnelpogs.AuthSuccess:
		cm.SetReconnectToken(typed.JWT())
		cm.authSuccess.Inc()
		return retry.Clock.After(typed.RefreshAfter()), nil

	case tunnelpogs.AuthUnknown:
		wait := typed.RefreshAfter()
		cm.authFail.WithLabelValues(typed.Error()).Inc()
		return retry.Clock.After(wait), nil

	case tunnelpogs.AuthFail:
		cm.authFail.WithLabelValues(typed.Error()).Inc()
		return nil, typed

	default:
		unexpected := fmt.Errorf("refresh_auth: Unexpected outcome type %T", result)
		cm.authFail.WithLabelValues(unexpected.Error()).Inc()
		return nil, unexpected
	}
}