// client_status_check.go — scheduled task that checks client connectivity status.
package chserver
import (
"context"
"time"
"github.com/openrport/openrport/server/clients"
"github.com/openrport/openrport/server/clients/clientdata"
"github.com/openrport/openrport/share/comm"
"github.com/openrport/openrport/share/logger"
)
// DefaultMaxWorkers caps the number of concurrent ping worker goroutines per run.
const DefaultMaxWorkers = 100

// ClientsStatusCheckTask pings active clients whose heartbeat is stale and
// marks unresponsive ones as disconnected. It is intended to be run periodically.
type ClientsStatusCheckTask struct {
	log         *logger.Logger
	clientsRepo *clients.ClientRepository // source of the active-client list to check
	threshold   time.Duration             // Threshold after which a client to server ping is considered outdated.
	pingTimeout time.Duration             // Don't wait longer than pingTimeout for a response
}
// NewClientsStatusCheckTask builds the periodic task that pings all active
// clients and marks them disconnected on ping failure. th is the heartbeat
// age after which a client is pinged; pingTimeout bounds each ping round trip.
func NewClientsStatusCheckTask(log *logger.Logger, cr *clients.ClientRepository, th time.Duration, pingTimeout time.Duration) *ClientsStatusCheckTask {
	task := &ClientsStatusCheckTask{
		log:         log.Fork("clients-status-check"),
		clientsRepo: cr,
		threshold:   th,
		pingTimeout: pingTimeout,
	}
	return task
}
// Run pings every active client whose last heartbeat is older than the
// threshold, using a bounded pool of worker goroutines, and logs a summary.
// Clients that fail the ping are marked disconnected by the workers.
// It returns ctx.Err() if the context is canceled while gathering results,
// nil otherwise.
func (t *ClientsStatusCheckTask) Run(ctx context.Context) error {
	t.log.Debugf("status check running")
	timerStart := time.Now()

	dueClients, totalClientsCount := t.getDueClients()
	// Clients with a recent heartbeat were filtered out and count as "skipped".
	// (Previously this was logged as a hard-coded 0.)
	confirmedClients := totalClientsCount - len(dueClients)
	if len(dueClients) == 0 {
		// Nothing to do
		t.log.Debugf("ended after %s, no clients to ping", time.Since(timerStart))
		return nil
	}

	// make sure no more workers than clients and limit to max workers
	maxWorkers := DefaultMaxWorkers
	if maxWorkers > len(dueClients) {
		maxWorkers = len(dueClients)
	}

	// Both channels are buffered to len(dueClients) so neither producers nor
	// workers can block on send, even if this function returns early.
	clientsToPing := make(chan *clientdata.Client, len(dueClients))
	results := make(chan bool, len(dueClients))

	// create workers to ping clients
	for w := 1; w <= maxWorkers; w++ {
		go t.PingClients(ctx, w, clientsToPing, results)
	}

	// send the clients to ping to the workers
	for _, dueClient := range dueClients {
		clientsToPing <- dueClient
	}
	// we're done queuing clients for processing, so close the channel
	close(clientsToPing)

	// Gather one result per due client. Honor context cancellation so the
	// task cannot hang forever if a worker fails to deliver a result.
	var dead = 0
	var alive = 0
	for a := 0; a < len(dueClients); a++ {
		select {
		case ok := <-results:
			if ok {
				alive++
			} else {
				dead++
			}
		case <-ctx.Done():
			t.log.Debugf("ended after %s, context canceled while waiting for ping results", time.Since(timerStart))
			return ctx.Err()
		}
	}

	t.log.Debugf("ended after %s, skipped: %d, pinged: %d, alive: %d, dead: %d, total: %d", time.Since(timerStart), confirmedClients, len(dueClients), alive, dead, totalClientsCount)
	return nil
}
// getDueClients returns the active clients whose last heartbeat is older than
// the threshold (and therefore need a server-initiated ping), together with
// the total number of active clients. Clients without any recorded heartbeat
// are always considered due.
func (t *ClientsStatusCheckTask) getDueClients() (dueClients []*clientdata.Client, totalCount int) {
	// NOTE: a dead "confirmedClients" counter that was incremented but never
	// read has been removed; Run derives the skipped count from totalCount.
	var now = time.Now()
	activeClients := t.clientsRepo.GetAllActiveClients()
	for _, c := range activeClients {
		// Shorten the threshold aka make heartbeat older than it is because the ping response is stored after this check.
		// Clients would get checked only every second time otherwise.
		if c.HasLastHeartbeatAt() {
			lastHeartbeatAt := c.GetLastHeartbeatAtValue()
			if now.Sub(lastHeartbeatAt) < t.threshold-(10*time.Second) {
				// Skip all clients having sent a heartbeat from client to server recently
				continue
			}
		}
		dueClients = append(dueClients, c)
	}
	return dueClients, len(activeClients)
}
// PingClients is a worker: it consumes clients from clientsToPing until the
// channel is closed, pings each over its connection, updates the client's
// heartbeat or disconnected state accordingly, and reports one boolean per
// client on results (true = alive, false = dead).
func (t *ClientsStatusCheckTask) PingClients(ctx context.Context, workerNum int, clientsToPing <-chan *clientdata.Client, results chan<- bool) {
	for cl := range clientsToPing {
		name := cl.GetName()
		id := cl.GetID()
		ok, response, rtt, err := comm.PingConnectionWithTimeout(ctx, cl.GetConnection(), t.pingTimeout, cl.Log())

		alive := true
		switch {
		case !ok && err == nil && t.isLegacyClientResponse(response):
			// Old clients cannot respond properly to a ping request yet
			t.log.Debugf("ping to %s [%s] succeeded in %s. client < 0.8.2", name, id, rtt)
		case ok && err == nil && string(response) == "null":
			// client versions from 0.9.2 to 0.9.6 can return "null" as a ping response. this is due to a bug
			// in the client ping handling that cause 2 replies to be sent by the client. this breaks stuff.
			// for the server, assume the null reply is a successful ping. unfortunately, the extra reply
			// confuses the next send by the client which means it won't get a reply from the server and
			// will ultimately disconnect and reconnect. the work around is to make sure that the client
			// pings the server faster than the server pings the client. as the server has a recent heartbeat
			// already (from the client) it won't ping the client again, meaning that the client won't get a
			// chance to double reply to the server and cause ssh protocol confusion.
			t.log.Debugf("ping to %s [%s] succeeded in %s. client >= 0.8.2 *", name, id, rtt)
		case ok && err == nil && len(response) == 0:
			// Only an empty response confirms the ping
			t.log.Debugf("ping to %s [%s] succeeded in %s. client >= 0.8.2", name, id, rtt)
		default:
			// None of the above. Ping must have failed or timed out.
			t.log.Infof("ping to %s [%s] failed: %s", name, id, err)
			alive = false
		}

		if alive {
			cl.SetHeartbeatNow()
		} else {
			cl.SetDisconnectedNow()
			cl.Close()
		}
		results <- alive
	}
}
// isLegacyClientResponse reports whether the ping response is the fixed
// "unknown request" reply sent by clients that predate ping support (< 0.8.2).
func (t *ClientsStatusCheckTask) isLegacyClientResponse(response []byte) (isLegacy bool) {
	const legacyReply = "unknown request"
	return string(response) == legacyReply
}