/
conn.go
287 lines (254 loc) · 8.76 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
package pkg
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"time"
"github.com/calamity-of-subterfuge/cos/pkg/utils"
"github.com/gorilla/websocket"
)
// PreparablePacket is an optional hook for outgoing packets. If a packet
// placed on a Conn's SendQueue implements it, PrepareForMarshal is invoked
// right before the packet is JSON-encoded, letting the packet fill in
// derived fields first — most commonly its Type field.
type PreparablePacket interface {
	// PrepareForMarshal runs immediately before the packet is marshalled for
	// transmission over the connection.
	PrepareForMarshal()
}
// ReceivedMessage pairs a single incoming packet with the UID of the Conn
// that delivered it.
type ReceivedMessage struct {
	// ConnectionUID is the UID of the Conn this packet arrived on. It lets
	// several connections share one receive channel.
	ConnectionUID string

	// Message is one decoded packet. The calamity of subterfuge protocol may
	// batch several small messages into a single message frame to reduce
	// overhead, so this can be just one part of a larger frame.
	Message map[string]interface{}
}
// Conn wraps a gorilla websocket connection behind Go channels, using the
// packet format expected by the calamity of subterfuge lobby socket and game
// socket protocols. Callers push outgoing packets onto SendQueue and read
// decoded incoming packets from the receive queue given to NewConn. The Conn
// runs its own send and receive goroutines; Close stops them.
type Conn struct {
	// UID tags every message this connection forwards to the receive
	// channel, so several connections can share one receive channel. It may
	// be empty when the receive channel serves only this connection.
	UID string

	// SendQueue holds packets waiting to be written to the websocket; the
	// Conn's send goroutine consumes it.
	SendQueue chan interface{}

	// recvQueue receives one ReceivedMessage per decoded incoming packet.
	recvQueue chan ReceivedMessage
	// closedQueue receives UID once the send goroutine has shut the socket
	// down.
	closedQueue chan string
	// cancelSignal is the channel Close uses to ask the send goroutine to
	// stop.
	cancelSignal chan struct{}
	// conn is the underlying websocket; manageSend performs the writes and
	// manageRecv the reads.
	conn *websocket.Conn
}
// NewConn hands ownership of conn to a new Conn and returns it. After this
// call the caller must not use the websocket directly. The returned Conn has
// already started its send and receive goroutines, and its SendQueue is
// buffered for 128 packets.
//
// uid distinguishes this connection's traffic when recvQueue and closedQueue
// are shared by several connections; it may be "" when they serve only this
// one connection.
//
// Every message the peer sends is delivered on recvQueue, and uid is written
// to closedQueue exactly once when the underlying websocket is closed.
func NewConn(conn *websocket.Conn, uid string, recvQueue chan ReceivedMessage, closedQueue chan string) *Conn {
	managed := new(Conn)
	managed.UID = uid
	managed.SendQueue = make(chan interface{}, 128)
	managed.recvQueue = recvQueue
	managed.closedQueue = closedQueue
	// One slot is enough: Close only needs to leave a single pending request.
	managed.cancelSignal = make(chan struct{}, 1)
	managed.conn = conn

	go managed.manageSend()
	go managed.manageRecv()
	return managed
}
// Close asks the connection to shut down and returns immediately without
// blocking. Calling it more than once is harmless: if a shutdown request is
// already pending, the call does nothing. If the socket is still open, our
// uid is written to the closedQueue shortly afterwards.
func (c *Conn) Close() {
	var request struct{}
	select {
	case c.cancelSignal <- request:
		// Request queued for the send goroutine.
	default:
		// The one-slot buffer is already full, so a request is already pending.
	}
}
// manageSend performs this connection's data, ping, and close-frame writes.
// It batches packets from SendQueue into one frame, pings the peer
// periodically, and exits when a write fails or Close is called. On exit it:
//   - sends a best-effort normal-closure frame,
//   - closes the websocket, which also unblocks manageRecv,
//   - writes c.UID to closedQueue exactly once.
func (c *Conn) manageSend() {
	// Batching sends can significantly improve performance; cap the number of
	// packets coalesced into a single frame.
	const maxSendBatch = 16

	lowerTimeout := utils.CONN_READ_TIMEOUT
	if utils.CONN_WRITE_TIMEOUT < lowerTimeout {
		lowerTimeout = utils.CONN_WRITE_TIMEOUT
	}
	// Ping at 90% of the shorter timeout. On a healthy connection, the pong
	// handler installed by manageRecv then refreshes the read deadline before
	// it expires.
	pingInterval := (lowerTimeout * 9) / 10
	pingTicker := time.NewTicker(pingInterval)

	// drain appends already-queued packets to batch without blocking. It stops
	// once the batch is full or the queue is momentarily empty.
	drain := func(batch []interface{}) []interface{} {
		for len(batch) < maxSendBatch {
			select {
			case next := <-c.SendQueue:
				batch = append(batch, next)
			default:
				return batch
			}
		}
		return batch
	}

	packets := make([]interface{}, 0, maxSendBatch)
sendLoop:
	for {
		select {
		case packet := <-c.SendQueue:
			packets = drain(append(packets, packet))
			err := c.sendPackets(packets)
			packets = packets[:0]
			if err != nil {
				if !errors.Is(err, net.ErrClosed) {
					log.Printf("error sending packets to conn %s: %v", c.UID, err)
				}
				c.Close()
				break sendLoop
			}
		case <-pingTicker.C:
			if err := c.conn.SetWriteDeadline(time.Now().Add(utils.CONN_WRITE_TIMEOUT)); err != nil {
				log.Printf("Error setting write deadline for ping to %s: %v", c.UID, err)
				c.Close()
				break sendLoop
			}
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				if !errors.Is(err, net.ErrClosed) {
					log.Printf("Failed to write ping to connection %s: %v", c.UID, err)
				}
				c.Close()
				break sendLoop
			}
		case <-c.cancelSignal:
			// Flush whatever is already queued. Callers can then write their
			// final packets and Close right away instead of writing and waiting
			// "a bit". This is best effort: at most one batch is sent, so a very
			// full queue can still lose packets.
			packets = drain(packets)
			if len(packets) > 0 {
				if err := c.sendPackets(packets); err != nil && !errors.Is(err, net.ErrClosed) {
					log.Printf("error sending final packets to conn %s: %v", c.UID, err)
				}
			}
			c.Close()
			break sendLoop
		}
	}
	// Release the ticker as soon as the loop ends. The closedQueue send below
	// may block for a long time, so a deferred Stop could be delayed.
	pingTicker.Stop()

	cerr := c.conn.SetWriteDeadline(time.Now().Add(utils.CONN_WRITE_TIMEOUT))
	if cerr != nil {
		log.Printf("failed to set write deadline on %s for close code: %v", c.UID, cerr)
	} else {
		closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "forcibly disconnecting")
		cerr = c.conn.WriteMessage(websocket.CloseMessage, closeMsg)
		if cerr != nil && !errors.Is(cerr, net.ErrClosed) {
			log.Printf("failed to send close code to conn %s: %v", c.UID, cerr)
		}
	}
	cerr = c.conn.Close()
	if cerr != nil && !errors.Is(cerr, net.ErrClosed) {
		log.Printf("failed to close connection %s on send close: %v", c.UID, cerr)
	}
	c.closedQueue <- c.UID
}
// sendPackets writes packets to the websocket as one text frame containing a
// JSON array. Any packet implementing PreparablePacket has PrepareForMarshal
// called first. If the batch fails to marshal, each packet is re-marshalled on
// its own so the returned error names the packet at fault. Write-deadline and
// write errors are wrapped with %w, so callers can still match them with
// errors.Is (e.g. net.ErrClosed).
func (c *Conn) sendPackets(packets []interface{}) error {
	for _, pkt := range packets {
		if preparable, ok := pkt.(PreparablePacket); ok {
			preparable.PrepareForMarshal()
		}
	}
	marshalled, err := json.Marshal(packets)
	if err != nil {
		// Pinpoint the offending packet and report its own marshal error
		// rather than the batch-level one.
		for _, pkt := range packets {
			if _, subErr := json.Marshal(pkt); subErr != nil {
				return fmt.Errorf("failed to marshal packet %v: %w", pkt, subErr)
			}
		}
		return fmt.Errorf("failed to marshal packets despite each individual packet marshalling fine! packets: %v, err: %w", packets, err)
	}
	if err := c.conn.SetWriteDeadline(time.Now().Add(utils.CONN_WRITE_TIMEOUT)); err != nil {
		return fmt.Errorf("failed to set write deadline: %w", err)
	}
	if err := c.conn.WriteMessage(websocket.TextMessage, marshalled); err != nil {
		return fmt.Errorf("failed to write message: %w", err)
	}
	return nil
}
// manageRecv is the connection's reader goroutine. It reads text frames,
// decodes each one as JSON with numbers kept as json.Number, and forwards
// every packet in the frame to recvQueue tagged with c.UID. A frame must be
// either a single JSON object or a JSON array of objects. Any read, type,
// decode, or format error requests shutdown via Close and ends the goroutine.
//
// The loop needs no separate cancellation: manageSend closes the websocket,
// which makes the pending ReadMessage fail.
func (c *Conn) manageRecv() {
	// Each pong (the reply to manageSend's pings) extends the read deadline so
	// an idle but healthy connection is not timed out.
	c.conn.SetPongHandler(func(string) error {
		return c.conn.SetReadDeadline(time.Now().Add(utils.CONN_READ_TIMEOUT))
	})
	for {
		if err := c.conn.SetReadDeadline(time.Now().Add(utils.CONN_READ_TIMEOUT)); err != nil {
			log.Printf("Failed to set read deadline for %s: %v", c.UID, err)
			c.Close()
			return
		}
		messageType, message, err := c.conn.ReadMessage()
		if err != nil {
			// Reads fail routinely once manageSend closes the socket or the
			// peer closes cleanly; only log unexpected failures.
			if !errors.Is(err, net.ErrClosed) &&
				!websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				log.Printf("Failed to read message from %s: %v", c.UID, err)
			}
			c.Close()
			return
		}
		if messageType != websocket.TextMessage {
			log.Printf("Invalid incoming message type from %s: %v", c.UID, messageType)
			c.Close()
			return
		}
		decoder := json.NewDecoder(bytes.NewReader(message))
		decoder.UseNumber()
		var decoded interface{}
		err = decoder.Decode(&decoded)
		if err != nil {
			log.Printf("Failed to decode incoming message from %s: %v", c.UID, err)
			c.Close()
			return
		}
		// Validate the whole frame before delivering anything. A remote peer
		// can send any JSON, and an unchecked type assertion here would panic.
		packets, ok := splitReceivedPackets(decoded)
		if !ok {
			log.Printf("Unknown format for incoming message from %s: %v", c.UID, decoded)
			c.Close()
			return
		}
		for _, packet := range packets {
			c.recvQueue <- ReceivedMessage{
				ConnectionUID: c.UID,
				Message:       packet,
			}
		}
	}
}

// splitReceivedPackets turns one decoded frame into its list of packets. A
// JSON object yields a single packet; a JSON array yields one packet per
// element. ok is false for any other shape, including an array that contains
// a non-object element.
func splitReceivedPackets(decoded interface{}) (packets []map[string]interface{}, ok bool) {
	switch frame := decoded.(type) {
	case map[string]interface{}:
		return []map[string]interface{}{frame}, true
	case []interface{}:
		packets = make([]map[string]interface{}, 0, len(frame))
		for _, elem := range frame {
			obj, isObj := elem.(map[string]interface{})
			if !isObj {
				return nil, false
			}
			packets = append(packets, obj)
		}
		return packets, true
	default:
		return nil, false
	}
}