-
Notifications
You must be signed in to change notification settings - Fork 109
/
websocket.go
1153 lines (1058 loc) · 38.6 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// The package is a wrapper around gorilla websockets,
// aimed at simplifying the creation and usage of a websocket client/server.
//
// Check the Client and Server structure to get started.
package ws
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/url"
"path"
"sync"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/lorenzodonini/ocpp-go/logging"
)
const (
	// defaultWriteWait bounds how long a single write to the peer may take.
	defaultWriteWait = 10 * time.Second
	// defaultPongWait bounds how long to wait for the next pong message from the peer.
	defaultPongWait = 60 * time.Second
	// defaultPingWait bounds how long the server waits for a ping before closing a
	// connection due to inactivity. It mirrors defaultPongWait.
	defaultPingWait = defaultPongWait
	// defaultPingPeriod is the interval at which pings are sent to the peer.
	// It is 90% of defaultPongWait, so it always stays below the pong timeout.
	defaultPingPeriod = defaultPongWait * 9 / 10
	// defaultHandshakeTimeout bounds the duration of the initial websocket handshake.
	defaultHandshakeTimeout = 30 * time.Second
	// defaultRetryBackOffRepeatTimes is the number of times a reconnecting Charging Station
	// doubles its previous back-off time after a connection loss. Once this many increments
	// have been made, it keeps reconnecting with the same back-off time.
	defaultRetryBackOffRepeatTimes = 5
	// defaultRetryBackOffRandomRange is the upper bound of the random component added to
	// each back-off time while it is still being doubled, including the first connection
	// attempt. Once the maximum number of increments is reached, the back-off time no
	// longer changes.
	defaultRetryBackOffRandomRange = 15 // seconds
	// defaultRetryBackOffWaitMinimum is the back-off time used for the very first
	// reconnection attempt after a connection loss.
	defaultRetryBackOffWaitMinimum = 10 * time.Second
)
// log is the package-internal logger used to report websocket events.
var log logging.Logger

// SetLogger installs a custom logging.Logger implementation, allowing the package to log events.
// By default, a VoidLogger is used, so no logs will be sent to any output.
//
// Passing a nil logger makes the function panic.
func SetLogger(logger logging.Logger) {
	if logger != nil {
		log = logger
		return
	}
	panic("cannot set a nil logger")
}
// ServerTimeoutConfig holds optional timeout settings for a websocket server.
// Filling in its fields allows defining custom timeout intervals for websocket network operations.
//
// Custom values are applied through the server's SetTimeoutConfig method.
// When no configuration is passed, the defaults produced by NewServerTimeoutConfig are used.
type ServerTimeoutConfig struct {
	// WriteWait bounds how long a single write to a client may take.
	WriteWait time.Duration
	// PingWait bounds how long the server waits for a client ping before the read
	// deadline expires. A zero value disables the read deadline.
	PingWait time.Duration
}

// NewServerTimeoutConfig returns the default timeout configuration for a websocket server endpoint.
//
// Fields may be changed freely before passing the struct to a SetTimeoutConfig method.
func NewServerTimeoutConfig() ServerTimeoutConfig {
	var cfg ServerTimeoutConfig
	cfg.WriteWait = defaultWriteWait
	cfg.PingWait = defaultPingWait
	return cfg
}

// ClientTimeoutConfig holds optional timeout and reconnection settings for a websocket client.
// Filling in its fields allows defining custom timeout intervals for websocket network operations.
//
// Custom values are applied through the client's SetTimeoutConfig method.
// When no configuration is passed, the defaults produced by NewClientTimeoutConfig are used.
type ClientTimeoutConfig struct {
	WriteWait               time.Duration // default: defaultWriteWait (10s)
	HandshakeTimeout        time.Duration // default: defaultHandshakeTimeout (30s)
	PongWait                time.Duration // default: defaultPongWait (60s)
	PingPeriod              time.Duration // default: defaultPingPeriod (90% of PongWait)
	RetryBackOffRepeatTimes int           // default: defaultRetryBackOffRepeatTimes (5)
	RetryBackOffRandomRange int           // default: defaultRetryBackOffRandomRange (15, in seconds)
	RetryBackOffWaitMinimum time.Duration // default: defaultRetryBackOffWaitMinimum (10s)
}

// NewClientTimeoutConfig returns the default timeout configuration for a websocket client endpoint.
//
// Fields may be changed freely before passing the struct to a SetTimeoutConfig method.
func NewClientTimeoutConfig() ClientTimeoutConfig {
	var cfg ClientTimeoutConfig
	cfg.WriteWait = defaultWriteWait
	cfg.HandshakeTimeout = defaultHandshakeTimeout
	cfg.PongWait = defaultPongWait
	cfg.PingPeriod = defaultPingPeriod
	cfg.RetryBackOffRepeatTimes = defaultRetryBackOffRepeatTimes
	cfg.RetryBackOffRandomRange = defaultRetryBackOffRandomRange
	cfg.RetryBackOffWaitMinimum = defaultRetryBackOffWaitMinimum
	return cfg
}
// Channel represents a bi-directional communication channel, which provides at least a unique ID.
type Channel interface {
	// ID returns the unique identifier of the channel.
	ID() string
	// RemoteAddr returns the network address of the remote peer.
	RemoteAddr() net.Addr
	// TLSConnectionState returns the TLS state of the connection, if any.
	TLSConnectionState() *tls.ConnectionState
}

// WebSocket wraps a single websocket channel; the underlying connection
// is provided by the gorilla websocket package.
//
// Don't use a websocket directly, but refer to WsServer and WsClient.
type WebSocket struct {
	connection         *websocket.Conn
	id                 string
	outQueue           chan []byte                // outgoing messages, drained by the write loop.
	closeC             chan websocket.CloseError  // used to gracefully close a websocket connection.
	forceCloseC        chan error                 // used by the readPump to notify a forcefully closed connection to the writePump.
	pingMessage        chan []byte                // ping payloads to be answered with a pong by the write loop.
	tlsConnectionState *tls.ConnectionState
}

// ID returns the unique identifier of the websocket (typically, the URL suffix).
func (ws *WebSocket) ID() string {
	return ws.id
}

// RemoteAddr returns the address of the remote peer, as reported by the underlying connection.
func (ws *WebSocket) RemoteAddr() net.Addr {
	conn := ws.connection
	return conn.RemoteAddr()
}

// TLSConnectionState returns the TLS connection state of the connection, if any.
func (ws *WebSocket) TLSConnectionState() *tls.ConnectionState {
	return ws.tlsConnectionState
}
// HttpConnectionError is an error that carries HTTP status information alongside a message.
// Its Error string has the form "<Message>, http status: <HttpStatus>".
type HttpConnectionError struct {
	Message    string
	HttpStatus string
	HttpCode   int    // presumably the numeric HTTP status code — confirm against the code that builds this error
	Details    string // additional information; not included in Error()
}

// Error implements the error interface.
func (e HttpConnectionError) Error() string {
	msg, status := e.Message, e.HttpStatus
	return fmt.Sprintf("%v, http status: %v", msg, status)
}
// ---------------------- SERVER ----------------------
// CheckClientHandler validates an incoming client before its connection is upgraded.
// It receives the client id (the final element of the request path) and the HTTP request,
// and returns true to accept the client.
type CheckClientHandler func(id string, r *http.Request) bool

// WsServer defines a websocket server, which passively listens for incoming connections on the ws or wss protocol.
// The offered API is asynchronous, and each incoming connection/message is handled using callbacks.
//
// To create a new ws server, use:
//
//	server := NewServer()
//
// If you need a TLS ws server instead, use:
//
//	server := NewTLSServer("cert.pem", "privateKey.pem")
//
// To support client basic authentication, use:
//
//	server.SetBasicAuthHandler(func (user, pass) bool {
//		ok := authenticate(user, pass) // ... check for user and pass correctness
//		return ok
//	})
//
// To specify supported sub-protocols, use:
//
//	server.AddSupportedSubprotocol("ocpp1.6")
//
// If you need to set a specific timeout configuration, refer to the SetTimeoutConfig method.
//
// Using Start and Stop you can respectively start and stop listening for incoming client websocket connections.
//
// To be notified of new and terminated connections,
// refer to SetNewClientHandler and SetDisconnectedClientHandler functions.
//
// To receive incoming messages, you will need to set your own handler using SetMessageHandler.
// To write data on the open socket, simply call the Write function.
type WsServer interface {
	// Start starts and runs the websocket server on a specific port and URL.
	// After start, incoming connections and messages are handled automatically, so no explicit read operation is required.
	//
	// The function blocks until the server is stopped (or fails to start), hence it is suggested to invoke it
	// in a goroutine, if the caller thread needs to perform other work, e.g.:
	//	go server.Start(8887, "/ws/{id}")
	//	doStuffOnMainThread()
	//	...
	//
	// To stop a running server, call the Stop function.
	Start(port int, listenPath string)
	// Stop shuts down a running websocket server.
	// All open channels will be forcefully closed, and the previously called Start function will return.
	Stop()
	// StopConnection closes a specific websocket connection, sending closeError to the peer.
	StopConnection(id string, closeError websocket.CloseError) error
	// Errors returns a channel for error messages. If it doesn't exist, it is created.
	// The channel is closed by the server when stopped.
	Errors() <-chan error
	// SetMessageHandler sets a callback function for all incoming messages.
	// The callback accepts a Channel and the received data.
	// It is up to the callback receiver to check the identifier of the channel, to determine the source of the message.
	SetMessageHandler(handler func(ws Channel, data []byte) error)
	// SetNewClientHandler sets a callback function for all new incoming client connections.
	// It is recommended to store a reference to the Channel in the received entity, so that the Channel may be recognized later on.
	SetNewClientHandler(handler func(ws Channel))
	// SetDisconnectedClientHandler sets a callback function for all client disconnection events.
	// Once a client is disconnected, it is not possible to read/write on the respective Channel any longer.
	SetDisconnectedClientHandler(handler func(ws Channel))
	// SetTimeoutConfig sets custom timeout configuration parameters. If not passed, a default ServerTimeoutConfig struct will be used.
	//
	// This function must be called before starting the server, otherwise it may lead to unexpected behavior.
	SetTimeoutConfig(config ServerTimeoutConfig)
	// Write sends a message on a specific Channel, identified by the webSocketId parameter.
	// If the passed ID is invalid, an error is returned.
	//
	// The data is queued and will be sent asynchronously in the background.
	Write(webSocketId string, data []byte) error
	// AddSupportedSubprotocol adds support for a specified subprotocol.
	// This is recommended in order to communicate the capabilities to the client during the handshake.
	// If left empty, any subprotocol will be accepted.
	//
	// Duplicates will be removed automatically.
	AddSupportedSubprotocol(subProto string)
	// SetBasicAuthHandler enables HTTP Basic Authentication and requires clients to pass credentials.
	// The handler function is called whenever a new client attempts to connect, to check for credentials correctness.
	// The handler must return true if the credentials were correct, false otherwise.
	SetBasicAuthHandler(handler func(username string, password string) bool)
	// SetCheckOriginHandler sets a handler for incoming websocket connections, allowing to perform
	// custom cross-origin checks.
	//
	// By default, if the Origin header is present in the request, and the Origin host is not equal
	// to the Host request header, the websocket handshake fails.
	SetCheckOriginHandler(handler func(r *http.Request) bool)
	// SetCheckClientHandler sets a handler for validating incoming websocket connections, allowing to perform
	// custom client connection checks.
	SetCheckClientHandler(handler func(id string, r *http.Request) bool)
	// Addr gives the address on which the server is listening, useful if, for
	// example, the port is system-defined (set to 0).
	Addr() *net.TCPAddr
}
// Server is the default implementation of a websocket server (see WsServer).
//
// Use the NewServer or NewTLSServer functions to create a new server.
type Server struct {
	// connections holds the open websockets keyed by their id; guarded by connMutex.
	connections map[string]*WebSocket
	httpServer  *http.Server
	// Callbacks registered through the Set...Handler methods.
	messageHandler      func(ws Channel, data []byte) error
	checkClientHandler  func(id string, r *http.Request) bool
	newClientHandler    func(ws Channel)
	disconnectedHandler func(ws Channel)
	basicAuthHandler    func(username string, password string) bool
	// Start serves TLS only when both of these paths are non-empty.
	tlsCertificatePath string
	tlsCertificateKey  string
	timeoutConfig      ServerTimeoutConfig
	upgrader           websocket.Upgrader
	// errC is created lazily by Errors and closed by Stop.
	errC      chan error
	connMutex sync.RWMutex
	// addr is the address the listener is bound to; it is set by Start.
	addr *net.TCPAddr
	// httpHandler routes the websocket endpoint plus any handler added via AddHttpHandler.
	httpHandler *mux.Router
}

// NewServer creates a new plain websocket server (the websockets are not secured).
func NewServer() *Server {
	return &Server{
		httpServer:    &http.Server{},
		timeoutConfig: NewServerTimeoutConfig(),
		upgrader:      websocket.Upgrader{Subprotocols: []string{}},
		httpHandler:   mux.NewRouter(),
	}
}

// NewTLSServer creates a new secure websocket server. All created websocket channels will use TLS.
//
// You need to pass a filepath to the server TLS certificate and key.
//
// It is recommended to pass a valid TLSConfig for the server to use.
// For example to require client certificate verification:
//
//	tlsConfig := &tls.Config{
//		ClientAuth: tls.RequireAndVerifyClientCert,
//		ClientCAs: clientCAs,
//	}
//
// If no tlsConfig parameter is passed, the server will by default
// not perform any client certificate verification.
func NewTLSServer(certificatePath string, certificateKey string, tlsConfig *tls.Config) *Server {
	// Start from the plain server and layer the TLS settings on top of it.
	server := NewServer()
	server.tlsCertificatePath = certificatePath
	server.tlsCertificateKey = certificateKey
	server.httpServer.TLSConfig = tlsConfig
	return server
}
// SetMessageHandler registers the callback invoked for every message received from any client.
func (server *Server) SetMessageHandler(handler func(ws Channel, data []byte) error) {
	server.messageHandler = handler
}

// SetCheckClientHandler registers a callback that may reject a client (by id and request) before the upgrade.
func (server *Server) SetCheckClientHandler(handler func(id string, r *http.Request) bool) {
	server.checkClientHandler = handler
}

// SetNewClientHandler registers the callback invoked once a new client connection is registered.
func (server *Server) SetNewClientHandler(handler func(ws Channel)) {
	server.newClientHandler = handler
}

// SetDisconnectedClientHandler registers the callback invoked after a client connection was cleaned up.
func (server *Server) SetDisconnectedClientHandler(handler func(ws Channel)) {
	server.disconnectedHandler = handler
}

// SetTimeoutConfig replaces the server's timeout configuration.
func (server *Server) SetTimeoutConfig(config ServerTimeoutConfig) {
	server.timeoutConfig = config
}

// AddSupportedSubprotocol registers subProto as accepted by the server, ignoring duplicates.
func (server *Server) AddSupportedSubprotocol(subProto string) {
	supported := server.upgrader.Subprotocols
	for i := range supported {
		if supported[i] == subProto {
			// Already registered: keep the list free of duplicates.
			return
		}
	}
	server.upgrader.Subprotocols = append(supported, subProto)
}

// SetBasicAuthHandler enables HTTP Basic Authentication, validating credentials with handler.
func (server *Server) SetBasicAuthHandler(handler func(username string, password string) bool) {
	server.basicAuthHandler = handler
}

// SetCheckOriginHandler installs a custom cross-origin check on the underlying upgrader.
func (server *Server) SetCheckOriginHandler(handler func(r *http.Request) bool) {
	server.upgrader.CheckOrigin = handler
}
// error logs err and, if an error channel was requested via Errors, forwards err to it.
func (server *Server) error(err error) {
	log.Error(err)
	if server.errC == nil {
		return
	}
	server.errC <- err
}

// Errors returns the server's error channel, creating it (with a buffer of one) on first use.
func (server *Server) Errors() <-chan error {
	if server.errC != nil {
		return server.errC
	}
	server.errC = make(chan error, 1)
	return server.errC
}

// Addr returns the TCP address the server is listening on, as recorded by Start.
func (server *Server) Addr() *net.TCPAddr {
	return server.addr
}

// AddHttpHandler registers an additional plain HTTP handler for listenPath on the server's router.
func (server *Server) AddHttpHandler(listenPath string, handler func(w http.ResponseWriter, r *http.Request)) {
	router := server.httpHandler
	router.HandleFunc(listenPath, handler)
}
// Start resets the connection registry, binds a TCP listener on the given port, routes
// listenPath to the websocket handler and serves HTTP (or HTTPS, when both certificate
// paths are configured) until the server is shut down.
//
// Listen and serve failures are reported through the server's error mechanism; a normal
// shutdown (http.ErrServerClosed) is not reported.
func (server *Server) Start(port int, listenPath string) {
	server.connMutex.Lock()
	server.connections = make(map[string]*WebSocket)
	server.connMutex.Unlock()

	if server.httpServer == nil {
		server.httpServer = &http.Server{}
	}
	addr := fmt.Sprintf(":%v", port)
	server.httpServer.Addr = addr

	server.AddHttpHandler(listenPath, server.wsHandler)
	server.httpServer.Handler = server.httpHandler

	listener, err := net.Listen("tcp", addr)
	if err != nil {
		server.error(fmt.Errorf("failed to listen: %w", err))
		return
	}
	server.addr = listener.Addr().(*net.TCPAddr)
	defer listener.Close()

	log.Infof("listening on tcp network %v", addr)
	// Open websockets are asked to close when the HTTP server shuts down.
	server.httpServer.RegisterOnShutdown(server.stopConnections)

	useTLS := server.tlsCertificatePath != "" && server.tlsCertificateKey != ""
	if useTLS {
		err = server.httpServer.ServeTLS(listener, server.tlsCertificatePath, server.tlsCertificateKey)
	} else {
		err = server.httpServer.Serve(listener)
	}
	if err != http.ErrServerClosed {
		server.error(fmt.Errorf("failed to listen: %w", err))
	}
}
// Stop shuts down the HTTP server (which triggers stopConnections via the shutdown hook
// registered in Start), then closes and clears the error channel, if one was created.
func (server *Server) Stop() {
	log.Info("stopping websocket server")
	if err := server.httpServer.Shutdown(context.TODO()); err != nil {
		server.error(fmt.Errorf("shutdown failed: %w", err))
	}
	if server.errC == nil {
		return
	}
	close(server.errC)
	server.errC = nil
}
// StopConnection asks the open connection identified by id to close gracefully; closeError
// is sent to the peer as the close frame. The close itself is performed asynchronously by
// the connection's writePump.
//
// An error is returned if no connection with that id is open. If a close request is already
// pending for the connection, the new request is dropped (the pending one wins) instead of
// blocking the caller.
func (server *Server) StopConnection(id string, closeError websocket.CloseError) error {
	// The read lock is held across the send: cleanupConnection closes closeC only while
	// holding the write lock, so while the connection is still registered its closeC is
	// guaranteed to be open (no "send on closed channel" panic).
	server.connMutex.RLock()
	defer server.connMutex.RUnlock()
	ws, ok := server.connections[id]
	if !ok {
		return fmt.Errorf("couldn't stop websocket connection. No connection with id %s is open", id)
	}
	log.Debugf("sending stop signal for websocket %s", ws.ID())
	// Non-blocking send: closeC is created with capacity 1, and blocking here while holding
	// the read lock would deadlock against cleanupConnection if a close is already queued.
	select {
	case ws.closeC <- closeError:
	default:
		log.Debugf("close already pending for websocket %s, ignoring stop signal", ws.ID())
	}
	return nil
}

// stopConnections requests a normal closure of every open connection. Start registers it
// as a shutdown hook of the HTTP server.
func (server *Server) stopConnections() {
	server.connMutex.RLock()
	defer server.connMutex.RUnlock()
	for _, conn := range server.connections {
		// Non-blocking for the same reason as in StopConnection: a connection that already
		// has a close request queued will close anyway, and blocking under the read lock
		// could deadlock against cleanupConnection.
		select {
		case conn.closeC <- websocket.CloseError{Code: websocket.CloseNormalClosure, Text: ""}:
		default:
		}
	}
}
// Write queues data for asynchronous delivery on the connection identified by webSocketId.
// An error is returned if no such connection is open.
func (server *Server) Write(webSocketId string, data []byte) error {
	server.connMutex.RLock()
	defer server.connMutex.RUnlock()
	ws := server.connections[webSocketId]
	if ws == nil {
		return fmt.Errorf("couldn't write to websocket. No socket with id %v is open", webSocketId)
	}
	log.Debugf("queuing data for websocket %s", webSocketId)
	// The read lock stays held during the send, so cleanupConnection (which takes the write
	// lock) cannot close outQueue underneath this call.
	// NOTE(review): the send blocks while the lock is held; if outQueue is full and the
	// connection's writePump is itself waiting for the lock in cleanupConnection, this can
	// deadlock — confirm whether a bounded wait is acceptable here.
	ws.outQueue <- data
	return nil
}
// wsHandler is the HTTP handler for the websocket endpoint (installed by Start).
//
// Steps, in order:
//  1. the client id is taken from the final element of the request path;
//  2. a subprotocol is negotiated against server.upgrader.Subprotocols (first client-requested
//     protocol the server supports, or the first requested one if the server lists none);
//  3. the optional basic-auth and check-client hooks may reject the request with HTTP 401;
//  4. the connection is upgraded; afterwards, a missing/unsupported subprotocol or a duplicate
//     id is rejected with a websocket close frame;
//  5. the connection is registered, its read/write pumps are started, and newClientHandler
//     (if set) is invoked on this goroutine.
func (server *Server) wsHandler(w http.ResponseWriter, r *http.Request) {
	responseHeader := http.Header{}
	url := r.URL
	// The id of the charge point is the final path element
	id := path.Base(url.Path)
	log.Debugf("handling new connection for %s from %s", id, r.RemoteAddr)
	// Negotiate sub-protocol
	clientSubprotocols := websocket.Subprotocols(r)
	negotiatedSuprotocol := ""
out:
	for _, requestedProto := range clientSubprotocols {
		if len(server.upgrader.Subprotocols) == 0 {
			// All subProtocols are accepted, pick first
			negotiatedSuprotocol = requestedProto
			break
		}
		// Check if requested subprotocol is supported by server
		for _, supportedProto := range server.upgrader.Subprotocols {
			if requestedProto == supportedProto {
				negotiatedSuprotocol = requestedProto
				break out
			}
		}
	}
	// NOTE(review): gorilla's Upgrader selects the subprotocol from its own Subprotocols list
	// when that list is non-nil (it is initialized to an empty slice by the constructors), and
	// may not forward this header. Confirm that the negotiated protocol is actually echoed to
	// the client when no supported subprotocols are configured.
	if negotiatedSuprotocol != "" {
		responseHeader.Add("Sec-WebSocket-Protocol", negotiatedSuprotocol)
	}
	// Handle client authentication
	if server.basicAuthHandler != nil {
		username, password, ok := r.BasicAuth()
		if ok {
			ok = server.basicAuthHandler(username, password)
		}
		// Missing credentials and rejected credentials are both answered with 401.
		if !ok {
			server.error(fmt.Errorf("basic auth failed: credentials invalid"))
			w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
	}
	// Optional custom validation of the client id / request, before upgrading.
	if server.checkClientHandler != nil {
		ok := server.checkClientHandler(id, r)
		if !ok {
			server.error(fmt.Errorf("client validation: invalid client"))
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
	}
	// Upgrade websocket
	conn, err := server.upgrader.Upgrade(w, r, responseHeader)
	if err != nil {
		server.error(fmt.Errorf("upgrade failed: %w", err))
		return
	}
	// All channels have capacity 1; writePump is the receiver of each of them.
	ws := WebSocket{
		connection:         conn,
		id:                 id,
		outQueue:           make(chan []byte, 1),
		closeC:             make(chan websocket.CloseError, 1),
		forceCloseC:        make(chan error, 1),
		pingMessage:        make(chan []byte, 1),
		tlsConnectionState: r.TLS,
	}
	log.Debugf("upgraded websocket connection for %s from %s", id, conn.RemoteAddr().String())
	// If unsupported subprotocol, terminate the connection immediately.
	// This also applies when the client requested no subprotocol at all. The check runs after
	// the upgrade, which lets the server answer with a websocket close frame.
	if negotiatedSuprotocol == "" {
		server.error(fmt.Errorf("unsupported subprotocols %v for new client %v (%v)", clientSubprotocols, id, r.RemoteAddr))
		_ = conn.WriteControl(websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseProtocolError, "invalid or unsupported subprotocol"),
			time.Now().Add(server.timeoutConfig.WriteWait))
		_ = conn.Close()
		return
	}
	// Check whether client exists
	server.connMutex.Lock()
	// There is already a connection with the same ID. Close the new one immediately with a PolicyViolation.
	if _, exists := server.connections[id]; exists {
		server.connMutex.Unlock()
		server.error(fmt.Errorf("client %s already exists, closing duplicate client", id))
		_ = conn.WriteControl(websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "a connection with this ID already exists"),
			time.Now().Add(server.timeoutConfig.WriteWait))
		_ = conn.Close()
		return
	}
	// Add new client (check and insert happen under the same lock, so duplicates cannot race in)
	server.connections[ws.id] = &ws
	server.connMutex.Unlock()
	// Read and write routines are started in separate goroutines and function will return immediately
	go server.writePump(&ws)
	go server.readPump(&ws)
	if server.newClientHandler != nil {
		var channel Channel = &ws
		server.newClientHandler(channel)
	}
}
// getReadTimeout returns the next read deadline: PingWait from now, or the zero time
// (meaning no deadline) when PingWait is zero.
func (server *Server) getReadTimeout() time.Time {
	wait := server.timeoutConfig.PingWait
	if wait != 0 {
		return time.Now().Add(wait)
	}
	return time.Time{}
}
// readPump is the read loop of a server-side connection; wsHandler runs it in its own goroutine.
//
// Incoming pings are forwarded to writePump (which answers with a pong) and extend the read
// deadline. Every received message is passed to the messageHandler, if set; handler errors
// are reported but do not terminate the connection. On the first read error, the error is
// forwarded to writePump via forceCloseC, which performs the cleanup, and the loop returns.
func (server *Server) readPump(ws *WebSocket) {
	conn := ws.connection
	conn.SetPingHandler(func(appData string) error {
		log.Debugf("ping received from %s", ws.ID())
		ws.pingMessage <- []byte(appData)
		return conn.SetReadDeadline(server.getReadTimeout())
	})
	_ = conn.SetReadDeadline(server.getReadTimeout())
	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
				server.error(fmt.Errorf("read failed unexpectedly for %s: %w", ws.ID(), err))
			}
			log.Debugf("handling read error for %s: %v", ws.ID(), err.Error())
			// Notify writePump of error. Force close will be handled there
			ws.forceCloseC <- err
			return
		}
		if server.messageHandler != nil {
			var channel Channel = ws
			if err = server.messageHandler(channel, message); err != nil {
				server.error(fmt.Errorf("handling failed for %s: %w", ws.ID(), err))
			}
		}
		// A message was just received, so the peer is alive: extend the read deadline
		// regardless of whether the handler succeeded. Skipping this on handler errors
		// would let a live connection time out early.
		_ = conn.SetReadDeadline(server.getReadTimeout())
	}
}
// writePump is the write loop of a server-side connection; wsHandler runs it in its own goroutine.
// In the server code of this file, once a connection is registered its outgoing frames
// (data, pongs, close frames) are written here, driven by four channels:
//   - outQueue: data queued by Write, sent as text messages;
//   - pingMessage: ping payloads forwarded by readPump, answered with a pong;
//   - closeC: graceful close requests, answered with a close frame followed by cleanup;
//   - forceCloseC: read errors reported by readPump, followed by cleanup.
//
// Every write failure, close request or forced close ends the loop via cleanupConnection.
func (server *Server) writePump(ws *WebSocket) {
	conn := ws.connection
	for {
		select {
		case data, ok := <-ws.outQueue:
			_ = conn.SetWriteDeadline(time.Now().Add(server.timeoutConfig.WriteWait))
			if !ok {
				// Unexpected closed queue, should never happen
				// (outQueue is only closed by cleanupConnection, which this loop itself calls).
				server.error(fmt.Errorf("output queue for socket %v was closed, forcefully closing", ws.id))
				// Don't invoke cleanup
				return
			}
			// Send data
			err := conn.WriteMessage(websocket.TextMessage, data)
			if err != nil {
				server.error(fmt.Errorf("write failed for %s: %w", ws.ID(), err))
				// Invoking cleanup, as socket was forcefully closed
				server.cleanupConnection(ws)
				return
			}
			log.Debugf("written %d bytes to %s", len(data), ws.ID())
		case ping := <-ws.pingMessage:
			// Answer a ping received by readPump, echoing its payload.
			_ = conn.SetWriteDeadline(time.Now().Add(server.timeoutConfig.WriteWait))
			err := conn.WriteMessage(websocket.PongMessage, ping)
			if err != nil {
				server.error(fmt.Errorf("write failed for %s: %w", ws.ID(), err))
				// Invoking cleanup, as socket was forcefully closed
				server.cleanupConnection(ws)
				return
			}
			log.Debugf("pong sent to %s", ws.ID())
		case closeErr := <-ws.closeC:
			log.Debugf("closing connection to %s", ws.ID())
			// Closing connection gracefully; a failure to send the close frame is reported,
			// but cleanup happens either way.
			if err := conn.WriteControl(
				websocket.CloseMessage,
				websocket.FormatCloseMessage(closeErr.Code, closeErr.Text),
				time.Now().Add(server.timeoutConfig.WriteWait),
			); err != nil {
				server.error(fmt.Errorf("failed to write close message for connection %s: %w", ws.id, err))
			}
			// Invoking cleanup
			server.cleanupConnection(ws)
			return
		case closed, ok := <-ws.forceCloseC:
			// readPump sends its (non-nil) read error here before exiting.
			if !ok || closed != nil {
				// Connection was forcefully closed, invoke cleanup
				log.Debugf("handling forced close signal for %s", ws.ID())
				server.cleanupConnection(ws)
			}
			return
		}
	}
}
// cleanupConnection frees the internal resources of a connection that was signaled to close.
// From this moment onwards, no new messages may be sent.
//
// It closes the underlying connection, then, under the connection write lock, closes
// outQueue and closeC and removes the connection from the registry. Finally, it invokes
// the disconnectedHandler (if set) after the lock has been released.
//
// In this file it is only called from writePump.
// NOTE(review): pingMessage and forceCloseC are not closed here; confirm readPump cannot stay
// blocked on pingMessage once writePump has exited.
func (server *Server) cleanupConnection(ws *WebSocket) {
	_ = ws.connection.Close()
	// Closing the channels under the write lock guarantees that no Write call (which sends on
	// outQueue while holding the read lock) is in the middle of a send.
	server.connMutex.Lock()
	close(ws.outQueue)
	close(ws.closeC)
	delete(server.connections, ws.id)
	server.connMutex.Unlock()
	log.Infof("closed connection to %s", ws.ID())
	if server.disconnectedHandler != nil {
		server.disconnectedHandler(ws)
	}
}
// ---------------------- CLIENT ----------------------
// WsClient defines a websocket client, needed to connect to a websocket server.
// The offered API is asynchronous, and each incoming message is handled using callbacks.
//
// To create a new ws client, use:
//
//	client := NewClient()
//
// If you need a TLS ws client instead, use:
//
//	certPool, err := x509.SystemCertPool()
//	if err != nil {
//		log.Fatal(err)
//	}
//	// You may add more trusted certificates to the pool before creating the TLSClientConfig
//	client := NewTLSClient(&tls.Config{
//		RootCAs: certPool,
//	})
//
// To add additional dial options, use:
//
//	client.AddOption(func(*websocket.Dialer) {
//		// Your option ...
//	})
//
// To add basic HTTP authentication, use:
//
//	client.SetBasicAuth("username","password")
//
// If you need to set a specific timeout configuration, refer to the SetTimeoutConfig method.
//
// Using Start and Stop you can respectively open/close a websocket to a websocket server.
//
// To receive incoming messages, you will need to set your own handler using SetMessageHandler.
// To write data on the open socket, simply call the Write function.
type WsClient interface {
	// Start starts the client and attempts to connect to the server on a specified URL.
	// If the connection fails, an error is returned.
	//
	// For example:
	//	err := client.Start("ws://localhost:8887/ws/1234")
	//
	// The function returns immediately, after the connection has been established.
	// Incoming messages are passed automatically to the callback function, so no explicit read operation is required.
	//
	// To stop a running client, call the Stop function.
	Start(url string) error
	// StartWithRetries starts the client and attempts to connect to the server on a specified URL.
	// If the connection fails, it keeps retrying with the back-off strategy from the ClientTimeoutConfig.
	//
	// For example:
	//	client.StartWithRetries("ws://localhost:8887/ws/1234")
	//
	// The function returns only when the connection has been established.
	// Incoming messages are passed automatically to the callback function, so no explicit read operation is required.
	//
	// To stop a running client, call the Stop function.
	StartWithRetries(url string)
	// Stop closes the output of the websocket Channel, effectively closing the connection to the server with a normal closure.
	Stop()
	// Errors returns a channel for error messages. If it doesn't exist, it is created.
	// The channel is closed by the client when stopped.
	Errors() <-chan error
	// SetMessageHandler sets a callback function for all incoming messages.
	SetMessageHandler(handler func(data []byte) error)
	// SetTimeoutConfig sets custom timeout configuration parameters. If not passed, a default ClientTimeoutConfig struct will be used.
	//
	// This function must be called before connecting to the server, otherwise it may lead to unexpected behavior.
	SetTimeoutConfig(config ClientTimeoutConfig)
	// SetDisconnectedHandler sets a callback function for receiving notifications about an unexpected disconnection from the server.
	// The callback is invoked even if the automatic reconnection mechanism is active.
	//
	// If the client was stopped using the Stop function, the callback will NOT be invoked.
	SetDisconnectedHandler(handler func(err error))
	// SetReconnectedHandler sets a callback function for receiving notifications whenever the connection to the server is re-established.
	// Connections are re-established automatically thanks to the auto-reconnection mechanism.
	//
	// If set, the DisconnectedHandler will always be invoked before the Reconnected callback is invoked.
	SetReconnectedHandler(handler func())
	// IsConnected returns information about the current connection status.
	// If the client is currently attempting to auto-reconnect to the server, the function returns false.
	IsConnected() bool
	// Write sends a message to the server over the websocket.
	//
	// The data is queued and will be sent asynchronously in the background.
	Write(data []byte) error
	// AddOption adds a websocket option to the client.
	AddOption(option interface{})
	// SetRequestedSubProtocol will negotiate the specified sub-protocol during the websocket handshake.
	// Internally this creates a dialer option and invokes the AddOption method on the client.
	//
	// Duplicates generated by invoking this method multiple times will be ignored.
	SetRequestedSubProtocol(subProto string)
	// SetBasicAuth adds basic authentication credentials, to use when connecting to the server.
	// The credentials are automatically encoded in base64.
	SetBasicAuth(username string, password string)
	// SetHeaderValue sets a value on the HTTP header sent when opening a websocket connection to the server.
	//
	// The function overwrites previous header fields with the same key.
	SetHeaderValue(key string, value string)
}
// Client is the default implementation of a Websocket client.
//
// Use the NewClient or NewTLSClient functions to create a new client.
type Client struct {
webSocket WebSocket
url url.URL
messageHandler func(data []byte) error
dialOptions []func(*websocket.Dialer)
header http.Header
timeoutConfig ClientTimeoutConfig
connected bool
onDisconnected func(err error)
onReconnected func()
mutex sync.Mutex
errC chan error
reconnectC chan struct{} // used for signaling, that a reconnection attempt should be interrupted
}
// NewClient creates a new simple websocket client (the channel is not secured).
//
// Additional options may be added using the AddOption function.
//
// Basic authentication can be set using the SetBasicAuth function.
//
// By default, the client will not negotiate any subprotocol. This value needs to be set via the
// respective SetRequestedSubProtocol method.
//
// The returned client starts with no dial options, an empty HTTP header and the
// timeouts returned by NewClientTimeoutConfig.
func NewClient() *Client {
	return &Client{
		dialOptions:   []func(*websocket.Dialer){},
		timeoutConfig: NewClientTimeoutConfig(),
		header:        http.Header{},
	}
}
// NewTLSClient creates a new secure websocket client. If supported by the server, the websocket channel will use TLS.
//
// The given tlsConfig is installed as the dialer's TLSClientConfig through a
// pre-registered dial option. Additional options may be added using the
// AddOption function, and basic authentication can be set using the
// SetBasicAuth function.
//
// To set a client certificate, you may do:
//
//	certificate, _ := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
//	clientCertificates := []tls.Certificate{certificate}
//	client := ws.NewTLSClient(&tls.Config{
//		RootCAs:      certPool,
//		Certificates: clientCertificates,
//	})
//
// Any other TLS option can be set within the same config as well. For example,
// to test against a server with a self-signed certificate (do not use in
// production!), pass:
//
//	InsecureSkipVerify: true
func NewTLSClient(tlsConfig *tls.Config) *Client {
	useTLS := func(dialer *websocket.Dialer) {
		dialer.TLSClientConfig = tlsConfig
	}
	return &Client{
		dialOptions:   []func(*websocket.Dialer){useTLS},
		timeoutConfig: NewClientTimeoutConfig(),
		header:        http.Header{},
	}
}
// SetMessageHandler registers the callback that readPump invokes with the payload
// of every received message. If the callback returns an error, the error is
// reported and reading continues with the next message.
func (client *Client) SetMessageHandler(handler func(data []byte) error) {
	client.messageHandler = handler
}
// SetTimeoutConfig replaces the ping, pong and write timings used by the
// client's read and write pumps. A zero PongWait means reads get no deadline.
func (client *Client) SetTimeoutConfig(config ClientTimeoutConfig) {
	client.timeoutConfig = config
}
// SetDisconnectedHandler registers the callback that writePump invokes whenever
// it tears the current connection down. The callback receives the error that
// caused the teardown. After a user-requested close it is invoked with a nil
// error.
func (client *Client) SetDisconnectedHandler(handler func(err error)) {
	client.onDisconnected = handler
}
// SetReconnectedHandler registers the callback to run once the auto-reconnection
// mechanism has re-established the connection to the server.
func (client *Client) SetReconnectedHandler(handler func()) {
	client.onReconnected = handler
}
// AddOption appends a dialer option to the client's list of dial options.
//
// Only values of type func(*websocket.Dialer) are accepted. Any other value is
// silently ignored.
func (client *Client) AddOption(option interface{}) {
	if dialOption, isDialOption := option.(func(*websocket.Dialer)); isDialOption {
		client.dialOptions = append(client.dialOptions, dialOption)
	}
}
// SetRequestedSubProtocol registers, via AddOption, a dialer option that adds
// subProto to the sub-protocols requested during the websocket handshake.
// When applied, the option leaves dialer.Subprotocols unchanged if subProto is
// already listed, so repeated calls do not produce duplicates.
func (client *Client) SetRequestedSubProtocol(subProto string) {
	client.AddOption(func(dialer *websocket.Dialer) {
		for _, existing := range dialer.Subprotocols {
			if existing == subProto {
				return
			}
		}
		dialer.Subprotocols = append(dialer.Subprotocols, subProto)
	})
}
// SetBasicAuth sets the Authorization header sent during the handshake to
// "Basic " followed by the standard base64 encoding of "username:password",
// replacing any previous Authorization value.
func (client *Client) SetBasicAuth(username string, password string) {
	credentials := username + ":" + password
	encoded := base64.StdEncoding.EncodeToString([]byte(credentials))
	client.header.Set("Authorization", "Basic "+encoded)
}
// SetHeaderValue stores value under key in the HTTP header sent when opening the
// websocket connection. The key is canonicalized by http.Header.Set, and any
// existing values for that key are replaced.
func (client *Client) SetHeaderValue(key string, value string) {
	client.header.Set(key, value)
}
// getReadTimeout computes the read deadline for the connection: PongWait from
// now, or the zero time.Time (meaning "no deadline") when PongWait is zero.
func (client *Client) getReadTimeout() time.Time {
	wait := client.timeoutConfig.PongWait
	if wait != 0 {
		return time.Now().Add(wait)
	}
	return time.Time{}
}
// writePump performs all writes on the current connection. It selects over:
//   - outQueue: queued messages, sent as text frames;
//   - periodic pings every timeoutConfig.PingPeriod (disabled when PingPeriod <= 0);
//   - closeC: a user-requested close, answered with a close frame;
//   - forceCloseC: errors signaled by readPump when reading fails.
//
// Every write uses a deadline of timeoutConfig.WriteWait. When a write fails,
// or readPump forces a close, the connection is cleaned up, onDisconnected (if
// set) is notified and handleReconnection is invoked. A user-requested close
// cleans up without invoking handleReconnection.
func (client *Client) writePump() {
	conn := client.webSocket.connection
	// time.NewTicker panics on a non-positive period, so pings are only
	// scheduled when PingPeriod > 0. Receiving from a nil channel never
	// proceeds, which disables the ping case of the select below.
	var ticker *time.Ticker
	var pingC <-chan time.Time
	if client.timeoutConfig.PingPeriod > 0 {
		ticker = time.NewTicker(client.timeoutConfig.PingPeriod)
		pingC = ticker.C
	}
	// Closure function correctly closes the current connection
	closure := func(err error) {
		if ticker != nil {
			ticker.Stop()
		}
		client.cleanup()
		// Invoke callback. This happens for every err value, including nil.
		if client.onDisconnected != nil {
			client.onDisconnected(err)
		}
	}
	for {
		select {
		case data := <-client.webSocket.outQueue:
			// Send data
			log.Debugf("sending data")
			_ = conn.SetWriteDeadline(time.Now().Add(client.timeoutConfig.WriteWait))
			err := conn.WriteMessage(websocket.TextMessage, data)
			if err != nil {
				client.error(fmt.Errorf("write failed: %w", err))
				closure(err)
				client.handleReconnection()
				return
			}
			log.Debugf("written %d bytes", len(data))
		case <-pingC:
			// Send periodic ping
			_ = conn.SetWriteDeadline(time.Now().Add(client.timeoutConfig.WriteWait))
			if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
				client.error(fmt.Errorf("failed to send ping message: %w", err))
				closure(err)
				client.handleReconnection()
				return
			}
			log.Debugf("ping sent")
		case closeErr := <-client.webSocket.closeC:
			log.Debugf("closing connection")
			// Closing connection gracefully
			if err := conn.WriteControl(
				websocket.CloseMessage,
				websocket.FormatCloseMessage(closeErr.Code, closeErr.Text),
				time.Now().Add(client.timeoutConfig.WriteWait),
			); err != nil {
				client.error(fmt.Errorf("failed to write close message: %w", err))
			}
			// Disconnected by user command: auto-reconnect is not triggered.
			// NOTE(review): closure(nil) still invokes onDisconnected (with a nil
			// error) when a handler is set; confirm this is the intended contract.
			closure(nil)
			return
		case closed, ok := <-client.webSocket.forceCloseC:
			log.Debugf("handling forced close signal")
			// Read pump sent a forceClose signal (reading failed -> aborting the connection)
			if !ok || closed != nil {
				closure(closed)
				client.handleReconnection()
				return
			}
		}
	}
}
// readPump reads messages from the current connection until a read fails and
// hands each payload to the registered message handler.
//
// The read deadline is initialized from getReadTimeout and is pushed forward
// every time a pong arrives. Only close errors whose code is not GoingAway,
// AbnormalClosure or NormalClosure are reported via client.error. Every read
// failure is forwarded to writePump through forceCloseC, and writePump is
// responsible for tearing the connection down.
func (client *Client) readPump() {
	conn := client.webSocket.connection
	_ = conn.SetReadDeadline(client.getReadTimeout())
	conn.SetPongHandler(func(string) error {
		log.Debugf("pong received")
		return conn.SetReadDeadline(client.getReadTimeout())
	})
	for {
		_, payload, readErr := conn.ReadMessage()
		if readErr != nil {
			expected := !websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure)
			if !expected {
				client.error(fmt.Errorf("read failed: %w", readErr))
			}
			// NOTE(review): this send blocks until writePump receives it. If
			// writePump has already returned (e.g. after a write failure), this
			// goroutine may block forever unless forceCloseC is buffered — confirm.
			client.webSocket.forceCloseC <- readErr
			return
		}
		log.Debugf("received %v bytes", len(payload))
		if client.messageHandler == nil {
			continue
		}
		// Handler failures are reported but do not stop the read loop.
		if handleErr := client.messageHandler(payload); handleErr != nil {
			client.error(fmt.Errorf("handle failed: %w", handleErr))
		}
	}
}
// Frees internal resources after a websocket connection was signaled to be closed.
// From this moment onwards, no new messages may be sent.
//
// cleanup marks the client as disconnected and closes the underlying connection,
// ignoring any Close error. Then, while holding client.mutex, it closes the
// outQueue and closeC channels of the current webSocket.
//
// It must run at most once per connection, because closing an already-closed
// channel panics. NOTE(review): presumably Write checks the connected flag under
// the same mutex before enqueueing, so that no send races with close(outQueue) —
// confirm.
func (client *Client) cleanup() {
	client.setConnected(false)
	// webSocket is stored by value, so ws is a copy taken before the lock is held.
	ws := client.webSocket
	_ = ws.connection.Close()
	client.mutex.Lock()
	defer client.mutex.Unlock()
	close(ws.outQueue)
	close(ws.closeC)
}
func (client *Client) handleReconnection() {
log.Info("started automatic reconnection handler")