-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.go
75 lines (62 loc) · 1.84 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
package remotedialer
import (
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// Package-level error values.
var (
// errFailedAuth is the error reported with a 401 response when the
// Authorizer declines a request.
errFailedAuth = errors.New("failed authentication")
// errWrongMessageType is not referenced in this file; presumably it is
// returned when a websocket message of an unexpected type is received
// (confirm at its use site).
errWrongMessageType = errors.New("wrong websocket message type")
)
// Authorizer inspects an incoming request and decides whether it may open a
// session. ServeHTTP answers 400 when err is non-nil and 401 when authed is
// false; otherwise clientKey is the key under which the new session is added
// to the session manager.
type Authorizer func(req *http.Request) (clientKey string, authed bool, err error)
// ErrorWriter reports err for req with the given HTTP status code. ServeHTTP
// calls it for every failure that happens before the websocket upgrade, and
// also installs it as the Error callback of the websocket.Upgrader.
type ErrorWriter func(rw http.ResponseWriter, req *http.Request, code int, err error)
// Server is an http.Handler that authenticates incoming requests, upgrades
// them to websocket connections and serves one session per connection.
type Server struct {
// ready is consulted at the start of every request; when it returns false
// the request is answered with a 503.
ready func() bool
// authorizer decides whether a request may proceed and yields its client key.
authorizer Authorizer
// errorWriter reports failures that occur before the connection is upgraded.
errorWriter ErrorWriter
// sessions tracks the sessions added by ServeHTTP; each one is removed
// again when ServeHTTP returns.
sessions *sessionManager
}
// New returns a Server wired with the given authorizer, error writer and
// readiness callback, backed by a freshly created session manager.
//
// All three callbacks are invoked unconditionally by ServeHTTP, so none of
// them may be nil.
func New(auth Authorizer, errorWriter ErrorWriter, ready func() bool) *Server {
	srv := new(Server)
	srv.ready = ready
	srv.authorizer = auth
	srv.errorWriter = errorWriter
	srv.sessions = newSessionManager()
	return srv
}
// ServeHTTP authenticates an incoming request, upgrades it to a websocket
// connection and serves the resulting session until it ends.
//
// Responses produced before the upgrade:
//   - 503 when the ready callback returns false,
//   - 400 when the authorizer returns an error,
//   - 401 when the authorizer declines the request.
//
// A failed upgrade is only logged here, because the upgrader is configured to
// report handshake failures through the same error writer itself.
func (s *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	if !s.ready() {
		s.errorWriter(rw, req, http.StatusServiceUnavailable, errors.New("tunnel server not active"))
		return
	}

	clientKey, authed, err := s.authorizer(req)
	if err != nil {
		s.errorWriter(rw, req, http.StatusBadRequest, err)
		return
	}
	if !authed {
		s.errorWriter(rw, req, http.StatusUnauthorized, errFailedAuth)
		return
	}

	logrus.Infof("Handling backend connection request [%s]", clientKey)

	upgrader := websocket.Upgrader{
		HandshakeTimeout: 5 * time.Second,
		// Every Origin is accepted; requests are gated by the authorizer above.
		CheckOrigin: func(r *http.Request) bool { return true },
		Error:       s.errorWriter,
	}

	wsConn, err := upgrader.Upgrade(rw, req, nil)
	if err != nil {
		// Do not call s.errorWriter again: handshake failures have already
		// been reported to the client through upgrader.Error, and failures
		// after the connection was hijacked cannot be written to rw at all.
		// A second write would only produce a duplicate or invalid response.
		logrus.Errorf("Error during upgrade for host [%v]: %v", clientKey, err)
		return
	}

	session := s.sessions.add(clientKey, wsConn)
	defer s.sessions.remove(session)

	// Don't need to associate req.Context() to the session, it will cancel otherwise
	code, err := session.serve()
	if err != nil {
		// Hijacked so we can't write to the client
		logrus.Debugf("error in remotedialer server [%d]: %v", code, err)
	}
}