ensure_polling.go
package handler

import (
	"context"
	"sync"

	"github.com/matrix-org/sliding-sync/internal"
	"github.com/matrix-org/sliding-sync/pubsub"
	"github.com/matrix-org/sliding-sync/sync2"
	"github.com/prometheus/client_golang/prometheus"
)
// pendingInfo tracks the status of a poller that we are (or previously were) waiting
// to start.
type pendingInfo struct {
// done is set to true when the EnsurePolling request has received a response.
done bool
// expired is true when the token is expired. Any 'done'ness should be ignored for expired tokens.
expired bool
// ch is a dummy channel which never receives any data. A call to
// EnsurePoller.OnInitialSyncComplete will close the channel (unblocking any
// EnsurePoller.EnsurePolling calls which are waiting on it) and then set the ch
// field to nil.
ch chan struct{}
}
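
// pendingInfo.ch relies on Go's close-to-broadcast idiom: closing a channel wakes
// every goroutine blocked receiving from it, and any receive on an already-closed
// channel returns immediately. A minimal sketch of the idiom (illustrative only,
// not part of this package):
//
//	ch := make(chan struct{})
//	go func() { <-ch }() // unblocked by the close below
//	go func() { <-ch }() // ditto: one close wakes all waiters
//	close(ch)            // broadcast; later receives do not block either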
// EnsurePoller is a gadget used by the sliding sync request handler to ensure that
// we are running a v2 poller for a given device.
type EnsurePoller struct {
chanName string
// mu guards reads and writes to pendingPolls.
mu *sync.Mutex
// pendingPolls tracks the status of pollers that we are waiting to start.
pendingPolls map[sync2.PollerID]pendingInfo
notifier pubsub.Notifier
// numPendingEnsurePolling is the total number of outstanding EnsurePolling requests.
numPendingEnsurePolling prometheus.Gauge
}
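
// A rough sketch of the flow this struct coordinates, inferred from the pubsub
// payloads used below (illustrative, not a normative diagram):
//
//	request handler            EnsurePoller                    v2 poller
//	   |-- EnsurePolling(pid) ---->|                               |
//	   |                           |--- V3EnsurePolling ---------->|
//	   |      (blocks on ch)       |                               |
//	   |                           |<-- V2InitialSyncComplete -----|
//	   |<----- close(ch) ----------|                               |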
func NewEnsurePoller(notifier pubsub.Notifier, enablePrometheus bool) *EnsurePoller {
p := &EnsurePoller{
chanName: pubsub.ChanV3,
mu: &sync.Mutex{},
pendingPolls: make(map[sync2.PollerID]pendingInfo),
notifier: notifier,
}
if enablePrometheus {
p.numPendingEnsurePolling = prometheus.NewGauge(prometheus.GaugeOpts{
Namespace: "sliding_sync",
Subsystem: "api",
Name: "num_devices_pending_ensure_polling",
Help: "Number of devices blocked on EnsurePolling returning.",
})
prometheus.MustRegister(p.numPendingEnsurePolling)
}
return p
}
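
// A hedged usage sketch, assuming some pubsub.Notifier `bus` has already been
// constructed elsewhere (wiring elided; illustrative only):
//
//	ep := NewEnsurePoller(bus, true) // true registers the Prometheus gauge
//	defer ep.Teardown()              // closes the notifier, unregisters the gauge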
// EnsurePolling blocks until the V2InitialSyncComplete response is received for this device. It is
// the caller's responsibility to call OnInitialSyncComplete when that response arrives on the
// pubsub channel. Returns true if the device's access token has expired.
func (p *EnsurePoller) EnsurePolling(ctx context.Context, pid sync2.PollerID, tokenHash string) bool {
ctx, region := internal.StartSpan(ctx, "EnsurePolling")
defer region.End()
p.mu.Lock()
// do we need to wait? Expired devices ALWAYS need a fresh poll
expired := p.pendingPolls[pid].expired
if !expired && p.pendingPolls[pid].done {
internal.Logf(ctx, "EnsurePolling", "user %s device %s already done", pid.UserID, pid.DeviceID)
p.mu.Unlock()
return expired // always false
}
// Either the token has expired, or it hasn't and we are still waiting on the initial sync.
// If it has expired, nuke the pid from the map now so that we take the same code path as a
// fresh device sync.
if expired {
// close any existing channel
if p.pendingPolls[pid].ch != nil {
close(p.pendingPolls[pid].ch)
}
delete(p.pendingPolls, pid)
// at this point, ch == nil so we will do an initial sync
}
// have we called EnsurePolling for this user/device before?
ch := p.pendingPolls[pid].ch
if ch != nil {
p.mu.Unlock()
// we already called EnsurePolling on this device, so just listen for the close
// TODO: several times there have been problems getting the response back from the poller.
// We should time out here after 100s and return an error or something to kick conns into
// trying again.
internal.Logf(ctx, "EnsurePolling", "user %s device %s channel exists, listening for channel close", pid.UserID, pid.DeviceID)
_, r2 := internal.StartSpan(ctx, "waitForExistingChannelClose")
<-ch
r2.End()
p.mu.Lock()
expired := p.pendingPolls[pid].expired
p.mu.Unlock()
return expired
}
// Make a channel to wait until we have done an initial sync
ch = make(chan struct{})
p.pendingPolls[pid] = pendingInfo{
done: false,
ch: ch,
}
p.calculateNumOutstanding() // increment total
p.mu.Unlock()
// ask the pollers to poll for this device
p.notifier.Notify(p.chanName, &pubsub.V3EnsurePolling{
UserID: pid.UserID,
DeviceID: pid.DeviceID,
AccessTokenHash: tokenHash,
})
// If by some miracle the notify AND sync complete before we receive on ch, this is
// still fine: a receive on a closed channel returns immediately.
internal.Logf(ctx, "EnsurePolling", "user %s device %s just made channel, listening for channel close", pid.UserID, pid.DeviceID)
_, r2 := internal.StartSpan(ctx, "waitForNewChannelClose")
<-ch
r2.End()
p.mu.Lock()
expired = p.pendingPolls[pid].expired
p.mu.Unlock()
return expired
}
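
// A minimal caller sketch with hypothetical names (serveSlidingSync and the fmt
// error handling are assumptions for illustration, not code from this repo):
//
//	func serveSlidingSync(ctx context.Context, ep *EnsurePoller, userID, deviceID, tokenHash string) error {
//		pid := sync2.PollerID{UserID: userID, DeviceID: deviceID}
//		// Blocks until the initial v2 sync for this device completes (or already has).
//		if expired := ep.EnsurePolling(ctx, pid, tokenHash); expired {
//			return fmt.Errorf("access token for %s/%s has expired", userID, deviceID)
//		}
//		// ... safe to serve the sliding sync request now ...
//		return nil
//	}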
func (p *EnsurePoller) OnInitialSyncComplete(payload *pubsub.V2InitialSyncComplete) {
log := logger.With().Str("user", payload.UserID).Str("device", payload.DeviceID).Logger()
log.Trace().Msg("OnInitialSyncComplete: got payload")
pid := sync2.PollerID{UserID: payload.UserID, DeviceID: payload.DeviceID}
p.mu.Lock()
defer p.mu.Unlock()
pending, ok := p.pendingPolls[pid]
// were we waiting for this initial sync to complete?
if !ok {
// This can happen when the v2 poller spontaneously starts polling even without us asking it to,
// e.g. when it starts pollers from the database on startup.
log.Trace().Msg("OnInitialSyncComplete: we weren't waiting for this")
p.pendingPolls[pid] = pendingInfo{
done: true,
expired: !payload.Success,
}
return
}
if pending.done {
log.Trace().Msg("OnInitialSyncComplete: already done")
// nothing to do, we just got OnInitialSyncComplete called twice
return
}
// we get here if we asked the poller to start via EnsurePolling, so let's make that goroutine
// wake up now
ch := pending.ch
pending.done = true
pending.ch = nil
// If for whatever reason we get OnExpiredToken prior to OnInitialSyncComplete, don't forget
// that we expired the token, i.e. expiry latches true.
if !pending.expired {
pending.expired = !payload.Success
}
p.pendingPolls[pid] = pending
p.calculateNumOutstanding() // decrement total
log.Trace().Msg("OnInitialSyncComplete: closing channel")
close(ch)
}
func (p *EnsurePoller) OnExpiredToken(payload *pubsub.V2ExpiredToken) {
pid := sync2.PollerID{UserID: payload.UserID, DeviceID: payload.DeviceID}
p.mu.Lock()
defer p.mu.Unlock()
pending, exists := p.pendingPolls[pid]
if !exists {
// We weren't tracking the state of this poller, so we have nothing to clean up.
return
}
pending.expired = true
p.pendingPolls[pid] = pending
// We used to delete the entry from the map at this point to force the next
// EnsurePolling call to do a fresh EnsurePolling request, but now we do that
// by signalling via the expired flag.
}
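
// A worked timeline of why expiry latches (illustrative ordering):
//
//	t0: EnsurePolling(pid) blocks on ch
//	t1: OnExpiredToken(pid)        -> pending.expired = true, entry kept in map
//	t2: OnInitialSyncComplete(pid) -> the `if !pending.expired` guard keeps
//	                                  expired == true, then ch is closed
//	t3: EnsurePolling wakes, re-reads pending.expired under p.mu, returns true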
func (p *EnsurePoller) Teardown() {
p.notifier.Close()
if p.numPendingEnsurePolling != nil {
prometheus.Unregister(p.numPendingEnsurePolling)
}
}
// calculateNumOutstanding recalculates the number of devices still blocked on
// EnsurePolling and updates the gauge. The caller must hold p.mu.
func (p *EnsurePoller) calculateNumOutstanding() {
if p.numPendingEnsurePolling == nil {
return
}
var total int
for _, pi := range p.pendingPolls {
if !pi.done {
total++
}
}
p.numPendingEnsurePolling.Set(float64(total))
}