-
Notifications
You must be signed in to change notification settings - Fork 155
/
wss-server.go
300 lines (269 loc) · 7.73 KB
/
wss-server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
// Copyright (c) 2021 Zededa, Inc.
// SPDX-License-Identifier: Apache-2.0
package main
import (
"flag"
"fmt"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// endPoint describes one websocket peer registered with the dispatcher.
type endPoint struct {
// hostname is the value taken from the peer's X-Hostname request header;
// it is empty when the header was not sent
hostname string
// wsConn is the websocket connection to this peer
wsConn *websocket.Conn
}
// Text payloads the dispatcher sends to websocket clients.
const (
	// noDeviceMsg is sent when no peer is registered for the client's token
	noDeviceMsg = "no device online\n+++Done+++"
	// tokenReqMsg is sent when the request carries no X-Session-Token header
	tokenReqMsg = "token is required"
	// moretwoMsg is sent when a token already has two peers registered
	moretwoMsg = "can't have more than 2 peers"
	// clientIPMsg prefixes the client's own address in the first message sent to it
	clientIPMsg = "YourEndPointIPAddr:"
)
// Package-level state shared by all connection handlers.
var (
	// upgrader turns an HTTP request into a websocket connection, using default options
	upgrader = websocket.Upgrader{}

	// reqAddrTokenEP holds the registered endpoints, indexed by 'token'
	// and then by 'remoteAddr' strings
	reqAddrTokenEP map[string]map[string]endPoint

	// connMutex serializes access to the endpoint maps
	connMutex sync.Mutex

	// connID is the connection id counter; it only ever increases
	connID int

	// needDebug turns on the extra debug output
	needDebug bool
)
// There are three entities in the edge-view data operation, the user,
// the dispatcher and the edge-node.
// From TCP/TLS/Websocket connection POV, user does not have any
// relation to the edge-node. The connections are between the
// user with the wss-server, and the edge-node with the wss-server.
// Think of this as the Hub-spoke model, with the wss-server as the hub,
// and user and edge-node are two spokes.
// The user and the edge-node only have a 'virtual' connection which
// contains the 'application' layer packets, and the wss-server is switching
// the packets for user and edge-node based on a 'token'. This is
// analogous to the hub-spoke in SD-WAN, where the hub installs the
// routing from each spoke node, and based on the packet destination
// and VPN-index to do a lookup to forward packets to the right
// destination spoke. Here the 'token' lookup is similar to lookup
// for a VPN-ID to find the VPN-table. Since we only allow one user
// to interact with one edge-node (only two spokes within the same VPN),
// the hub only needs to find the 'other' spoke for the packet switching.
// This may change for more complex topology.
func socketHandler(w http.ResponseWriter, r *http.Request) {
// Upgrade our raw HTTP connection to a websocket based one
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Printf("Error during connection upgradation: %v\n", err)
return
}
defer conn.Close()
if _, ok := r.Header["X-Session-Token"]; !ok {
err := conn.WriteMessage(websocket.TextMessage, []byte(tokenReqMsg))
if needDebug {
fmt.Printf("websocket write: %v\n", err)
}
return
}
if len(r.Header["X-Session-Token"]) == 0 {
err := conn.WriteMessage(websocket.TextMessage, []byte(tokenReqMsg))
if needDebug {
fmt.Printf("websocket write: %v\n", err)
}
return
}
token := r.Header["X-Session-Token"][0]
connID++
myConnID := connID
var hostname string
if _, ok := r.Header["X-Hostname"]; ok {
if len(r.Header["X-Hostname"]) > 0 {
hostname = r.Header["X-Hostname"][0]
}
}
remoteAddr := r.RemoteAddr
if addrStr, ok := r.Header["Cf-Connecting-Ip"]; ok {
if len(addrStr) > 0 {
remoteAddr = addrStr[0]
}
}
connMutex.Lock()
tmpMap := reqAddrTokenEP[token]
if tmpMap == nil {
tmpMap := make(map[string]endPoint)
reqAddrTokenEP[token] = tmpMap
}
if len(tmpMap) == 2 {
var addOK bool
// check to see if this one is from the same host
for addr, e := range tmpMap {
if e.hostname == hostname {
fmt.Printf("%v received connection with same hostname %s, close old w/Addr %s\n", time.Now(), hostname, addr)
e.wsConn.Close()
addOK = true
}
}
if !addOK {
err := conn.WriteMessage(websocket.TextMessage, []byte(moretwoMsg))
if needDebug {
fmt.Printf("websocket write: %v\n", err)
}
connMutex.Unlock()
return
}
}
ep := endPoint{
wsConn: conn,
hostname: hostname,
}
if _, ok := reqAddrTokenEP[token][remoteAddr]; !ok {
reqAddrTokenEP[token][remoteAddr] = ep
}
sizeMap := len(tmpMap)
connMutex.Unlock()
if sizeMap < 2 {
err := conn.WriteMessage(websocket.TextMessage, []byte(noDeviceMsg))
if needDebug {
fmt.Printf("websocket write: %v\n", err)
}
}
fmt.Printf("%v client %s from %s connected, ID: %d\n",
time.Now().Format("2006-01-02 15:04:05"), hostname, remoteAddr, myConnID)
// send peer's own endpoint IP over first
_ = conn.WriteMessage(websocket.TextMessage, []byte(clientIPMsg+remoteAddr))
cnt := 0
nopeerPkts := 0
for {
messageType, message, err := conn.ReadMessage()
now := time.Now()
nowStr := now.Format("2006-01-02 15:04:05")
if err != nil {
fmt.Printf("%s on reading host %s from %s, ID %d: %v\n", nowStr, hostname, remoteAddr, myConnID, err)
cleanConnMap(token, remoteAddr)
break
}
connMutex.Lock()
tmpMap = reqAddrTokenEP[token]
if tmpMap == nil {
connMutex.Unlock()
continue
}
myEP := endPoint{}
var peerAddr string
for addr, e := range tmpMap {
if remoteAddr == addr {
continue
}
dest := strings.Split(addr, ":")
if len(dest) == 2 {
addr = dest[1]
}
if needDebug {
fmt.Printf("%s (%d/%d): [%v], t-%d len %d, to %s\n",
nowStr, myConnID, cnt, hostname, messageType, len(message), addr)
}
peerAddr = addr
myEP = e
nopeerPkts = 0
break
}
connMutex.Unlock()
if myEP.wsConn == nil {
nopeerPkts++
fmt.Printf("%s can not find peer %d\n", nowStr, nopeerPkts)
if nopeerPkts < 50 { // need sometime for ep to reconnect
continue
}
err = conn.WriteMessage(websocket.TextMessage, []byte(noDeviceMsg))
if err != nil {
fmt.Printf("Error during message writing: %v\n", err)
cleanConnMap(token, remoteAddr)
break
}
continue
}
err = myEP.wsConn.WriteMessage(messageType, message)
if err != nil {
fmt.Printf("Error during message from %s writing to %s, ID %d: %v\n", hostname, peerAddr, myConnID, err)
cleanConnMap(token, remoteAddr)
break
}
cnt++
}
}
// cleanConnMap removes the endpoint stored under token/remoteAddr and drops
// the token's table once it holds no more endpoints. An unknown token is a no-op.
func cleanConnMap(token, remoteAddr string) {
	connMutex.Lock()
	defer connMutex.Unlock()

	peers := reqAddrTokenEP[token]
	if peers == nil {
		return
	}
	delete(peers, remoteAddr)
	if len(peers) == 0 {
		delete(reqAddrTokenEP, token)
	}
}
// getOutboundIP returns the preferred outbound IP of this machine, found by
// dialing a UDP socket toward 8.8.8.8:80 and reading the socket's local address.
//
// The dial is attempted up to 10 times, 2 seconds apart, and each failure is
// printed. An empty string is returned when every attempt fails or the local
// address is not a UDP address.
func getOutboundIP() string {
	const (
		retryMax   = 10
		retryDelay = 2 * time.Second
	)

	for attempt := 1; attempt <= retryMax; attempt++ {
		conn, err := net.Dial("udp", "8.8.8.8:80")
		if err == nil {
			// checked assertion: an unexpected address type must not panic
			localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
			conn.Close()
			if !ok {
				return ""
			}
			return localAddr.IP.String()
		}
		fmt.Println(err)
		// no point in waiting once the last attempt has failed
		if attempt < retryMax {
			time.Sleep(retryDelay)
		}
	}
	return ""
}
// pingHandler answers a GET request with the body "pong\n".
// For any other method nothing is written by this handler.
func pingHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		return
	}
	fmt.Fprint(w, "pong\n")
}
// main runs the edge-view websocket dispatcher example.
//
// It parses the command line, requires a port plus a server certificate and
// key, and then serves TLS on the machine's outbound IP with two routes:
// /edge-view (websocket switching) and /v1/ping (liveness check).
// A missing mandatory option prints a message and returns without serving.
func main() {
	reqAddrTokenEP = make(map[string]map[string]endPoint)

	showHelp := flag.Bool("h", false, "help string")
	debug := flag.Bool("debug", false, "more debug info")
	port := flag.String("port", "", "websocket listen port")
	certFile := flag.String("cert", "", "server certificate pem file")
	keyFile := flag.String("key", "", "server key pem file")
	flag.Parse()

	if *showHelp {
		usage := []string{
			" -h this help",
			" -port <port number> mandatory, tcp port number",
			" -cert <path> mandatory, server certificate path in PEM format",
			" -key <path> mandatory, server key file path in PEM format",
			" -debug optional, turn on more debug",
		}
		for _, line := range usage {
			fmt.Println(line)
		}
		return
	}

	if *debug {
		needDebug = true
	}

	// the mandatory options are checked in order; the first one missing ends the run
	switch {
	case *port == "":
		fmt.Println("port needs to be specified")
		return
	case *certFile == "" || *keyFile == "":
		fmt.Println("server cert and key files need to be specified")
		return
	}

	localIP := getOutboundIP()

	// the routes go onto the default mux, which the server below uses
	// because no Handler is set on it
	http.HandleFunc("/edge-view", socketHandler)
	http.HandleFunc("/v1/ping", pingHandler)

	server := &http.Server{
		Addr: localIP + ":" + *port,
	}
	fmt.Printf("Listen TLS on: %s:%s\n", localIP, *port)
	log.Fatal(server.ListenAndServeTLS(*certFile, *keyFile))
}