/
ws.go
128 lines (111 loc) · 2.83 KB
/
ws.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package ws
import (
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"sync"

	"github.com/gorilla/websocket"
	"github.com/labstack/echo/v4"
)
// MessageHandler is a callback run for each message a connection processes.
// It receives the message, the connection's WebSocket and the echo context of
// the request that opened the connection. A returned error is logged by the
// caller.
type MessageHandler func(WebSocketMessage, *websocket.Conn, echo.Context) error
// WS is the WebSocket hub for the echo server. Messages read from any client
// are pushed onto feed, and the fanOut goroutine copies each of them to the
// per-connection channel of every registered client.
type WS struct {
	// feed carries every message accepted from any client.
	feed chan WebSocketMessage
	// mu guards connections, which Handle mutates from request goroutines
	// while fanOut iterates over it on its own goroutine.
	mu sync.RWMutex
	// connections maps the hashed remote address of a client to the channel
	// its message processor reads from.
	connections map[string]chan WebSocketMessage
	// handlers are the callbacks run for every message a connection processes.
	handlers []MessageHandler
	// upgrader switches an HTTP request over to the WebSocket protocol.
	upgrader websocket.Upgrader
}

var (
	// singleton is the package-wide WS instance handed out by GetWS.
	singleton *WS
	// singletonMu serializes the lazy creation of singleton.
	singletonMu sync.Mutex
)

// GetWS retrieves the WebSocket singleton, creating it and starting its
// fan-out goroutine if it hasn't been created yet. It is safe to call from
// multiple goroutines.
func GetWS() *WS {
	singletonMu.Lock()
	defer singletonMu.Unlock()

	if singleton == nil {
		singleton = &WS{
			feed:        make(chan WebSocketMessage, 5),
			connections: map[string]chan WebSocketMessage{},
			handlers:    []MessageHandler{},
			upgrader:    websocket.Upgrader{},
		}
		go singleton.fanOut()
	}
	return singleton
}

// fanOut copies every message from the shared feed to each registered
// connection. It returns once the shared feed is closed.
func (w *WS) fanOut() {
	for msg := range w.feed {
		// The read lock is held for the whole delivery so Handle cannot
		// remove and close a channel while a send to it is in flight;
		// sending on a closed channel panics.
		w.mu.RLock()
		for _, conn := range w.connections {
			conn <- msg
		}
		w.mu.RUnlock()
	}
}

// RegisterHandle appends handler to the list of callbacks run for each
// message.
//
// The handler list is not synchronized with message processing, so handlers
// should be registered before connections are served.
func (w *WS) RegisterHandle(handler MessageHandler) {
	w.handlers = append(w.handlers, handler)
}

// Handle is the websocket handler for the echo server.
//
// It rejects a second connection from the same remote address, upgrades the
// request, registers a per-connection feed seeded with a greeting, starts the
// message processor and then reads client messages until readMessage reports
// that the loop must stop. It returns an error if the address is already
// connected or the upgrade fails, and nil otherwise.
func (w *WS) Handle(c echo.Context) error {
	key := hashStr(c.Request().RemoteAddr)

	w.mu.RLock()
	_, exists := w.connections[key]
	w.mu.RUnlock()
	if exists {
		return fmt.Errorf("A connection from this address already exists, only one connection per use allowed")
	}

	ws, err := w.upgrader.Upgrade(c.Response(), c.Request(), nil)
	if err != nil {
		return err
	}
	defer ws.Close()

	feed := make(chan WebSocketMessage, 5)
	// The channel is fresh and buffered, so queuing the greeting cannot block.
	feed <- WebSocketMessage{Message: "Hello! How can I help you today?", Type: "message"}

	w.mu.Lock()
	w.connections[key] = feed
	w.mu.Unlock()

	defer func() {
		// fanOut may be blocked sending to this feed while holding the read
		// lock. Keep draining so it can finish and let the write lock
		// through; the drain stops when feed is closed below.
		go func() {
			for range feed {
			}
		}()

		w.mu.Lock()
		delete(w.connections, key)
		w.mu.Unlock()

		// The entry is gone and sends only happen under the read lock, so no
		// send can still be in flight. Closing also makes processMessages
		// return, because it stops on the zero-value message it receives.
		close(feed)
	}()

	go w.processMessages(feed, ws, c)

	for w.readMessage(ws, c) {
	}
	return nil
}
// processMessages consumes feed and runs every registered handler on each
// message, logging any error a handler returns. It stops at the first message
// that is a Close, has an empty Type or has a blank Message, and also when
// feed is closed.
func (w *WS) processMessages(feed chan WebSocketMessage, ws *websocket.Conn, c echo.Context) {
	for msg := range feed {
		blank := strings.TrimSpace(msg.Message) == ""
		if msg.Type == Close || msg.Type == "" || blank {
			return
		}

		for _, handle := range w.handlers {
			err := handle(msg, ws, c)
			if err != nil {
				c.Logger().Error(err)
			}
		}
	}
}
// readMessage reads one JSON message from ws and forwards it to the shared
// feed. It returns true when the caller should keep reading and false when
// the read loop must stop.
//
// The loop stops when the client sends a Close message or when the read fails
// with anything other than a JSON decoding error. A payload that merely fails
// to decode is logged and skipped, as are messages whose text is blank.
//
// NOTE(review): a Close message is forwarded to the shared feed, so fanOut
// delivers it to every connection and each processMessages returns on it.
// Confirm that stopping every client's processor is intended.
func (w *WS) readMessage(ws *websocket.Conn, c echo.Context) bool {
	var chatMessage WebSocketMessage
	if err := ws.ReadJSON(&chatMessage); err != nil {
		if websocket.IsCloseError(err, websocket.CloseGoingAway) {
			c.Logger().Warn("connection broken, closing reader...")
			return false
		}
		c.Logger().Error(err)

		// Connection-level read errors are permanent: every later read
		// returns the same error and the library eventually panics on
		// repeated reads. Only a payload that failed to decode leaves the
		// connection usable, so that is the only case worth retrying.
		var syntaxErr *json.SyntaxError
		var typeErr *json.UnmarshalTypeError
		return errors.As(err, &syntaxErr) || errors.As(err, &typeErr)
	}
	c.Logger().Info(chatMessage)

	if chatMessage.Type != Close && len(strings.TrimSpace(chatMessage.Message)) < 1 {
		// Ignore empty messages
		return true
	}

	w.feed <- chatMessage
	return chatMessage.Type != Close
}
// hashStr returns the SHA-256 digest of in as a 32-byte string. The result
// holds the raw digest bytes, not a hex encoding.
func hashStr(in string) string {
	digest := sha256.Sum256([]byte(in))
	return string(digest[:])
}