-
Notifications
You must be signed in to change notification settings - Fork 287
/
tcp.go
105 lines (96 loc) · 3.56 KB
/
tcp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package tcp
import (
	"context"
	"fmt"
	"io"
	"net"
	"net/url"
	"sync"

	"github.com/hashicorp/boundary/globals"
	pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
	"github.com/hashicorp/boundary/internal/observability/event"
	"github.com/hashicorp/boundary/internal/servers/worker/proxy"
	"github.com/hashicorp/boundary/internal/servers/worker/session"
	"nhooyr.io/websocket"
)
// init registers handleTcpProxyV1 with the proxy package as the handler for
// the globals.TcpProxyV1 protocol. A registration failure leaves the worker
// unable to proxy TCP sessions, so it is treated as fatal and panics.
func init() {
	if regErr := proxy.RegisterHandler(globals.TcpProxyV1, handleTcpProxyV1); regErr != nil {
		panic(regErr)
	}
}
// handleTcpProxyV1 creates a tcp proxy between the incoming websocket conn and
// the connection it creates with the remote endpoint. handleTcpProxyV1 sets the
// connectionId as connected in the repository.
//
// conf.RemoteEndpoint must parse as a URL with the "tcp" scheme; its host is
// dialed with ctx, so cancelling ctx aborts a pending dial. If anything fails
// before proxying starts, the client websocket is closed with
// websocket.StatusInternalError and a short reason. Any remote connection
// already opened is closed as well.
//
// handleTcpProxyV1 blocks until an error (EOF on happy path) is received on
// either connection.
func handleTcpProxyV1(ctx context.Context, conf proxy.Config, _ ...proxy.Option) {
	const op = "tcp.HandleTcpProxyV1"
	si := conf.SessionInfo
	si.RLock()
	sessionId := si.LookupSessionResponse.GetAuthorization().GetSessionId()
	si.RUnlock()

	conn := conf.ClientConn

	// closeClient tears down the client websocket with an internal-error
	// status. We are already bailing out, so a failure to close is only logged.
	closeClient := func(reason string) {
		if closeErr := conn.Close(websocket.StatusInternalError, reason); closeErr != nil {
			event.WriteError(ctx, op, closeErr, event.WithInfoMsg("error closing client connection"))
		}
	}

	sessionUrl, err := url.Parse(conf.RemoteEndpoint)
	if err != nil {
		event.WriteError(ctx, op, err, event.WithInfoMsg("error parsing endpoint information", "session_id", sessionId, "endpoint", conf.RemoteEndpoint))
		closeClient("cannot parse endpoint url")
		return
	}
	if sessionUrl.Scheme != "tcp" {
		// err is nil here because the parse succeeded. Build an explicit error
		// so the event does not report a nil error.
		err = fmt.Errorf("invalid scheme %q for tcp proxy", sessionUrl.Scheme)
		event.WriteError(ctx, op, err, event.WithInfoMsg("invalid endpoint scheme", "session_id", sessionId, "endpoint", conf.RemoteEndpoint))
		closeClient("invalid scheme for type")
		return
	}

	// Dial with ctx so cancellation of the handler also aborts the dial.
	var dialer net.Dialer
	remoteConn, err := dialer.DialContext(ctx, "tcp", sessionUrl.Host)
	if err != nil {
		event.WriteError(ctx, op, err, event.WithInfoMsg("error dialing endpoint", "session_id", sessionId, "endpoint", conf.RemoteEndpoint))
		closeClient("endpoint dialing failed")
		return
	}
	// Assert this for better Go 1.11 splice support. A "tcp" dial always
	// yields a *net.TCPConn, so this assertion cannot fail.
	tcpRemoteConn := remoteConn.(*net.TCPConn)
	endpointAddr := tcpRemoteConn.RemoteAddr().(*net.TCPAddr)

	connectionInfo := &pbs.ConnectConnectionRequest{
		ConnectionId:       conf.ConnectionId,
		ClientTcpAddress:   conf.ClientAddress.IP.String(),
		ClientTcpPort:      uint32(conf.ClientAddress.Port),
		EndpointTcpAddress: endpointAddr.IP.String(),
		EndpointTcpPort:    uint32(endpointAddr.Port),
		Type:               "tcp",
	}
	connStatus, err := session.ConnectConnection(ctx, conf.SessionClient, connectionInfo)
	if err != nil {
		event.WriteError(ctx, op, err, event.WithInfoMsg("error marking connection as connected"))
		// The remote connection is already established. Release it so the
		// socket does not leak when the session is abandoned.
		if closeErr := tcpRemoteConn.Close(); closeErr != nil {
			event.WriteError(ctx, op, closeErr, event.WithInfoMsg("error closing remote connection"))
		}
		closeClient("failed to mark connection as connected")
		return
	}

	si.Lock()
	// NOTE(review): assumes an entry for conf.ConnectionId was added to
	// ConnInfoMap before this handler runs. Map values here must be pointers,
	// so a missing key would cause a nil-pointer panic. Confirm with the caller.
	si.ConnInfoMap[conf.ConnectionId].Status = connStatus
	si.Unlock()

	// Get a wrapped net.Conn so we can use io.Copy
	netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary)

	var connWg sync.WaitGroup
	connWg.Add(2)
	// Each direction closes both ends when its copy returns. That unblocks the
	// copy running in the opposite direction, so both goroutines finish and
	// Wait returns. Copy and close errors are deliberately ignored: EOF or a
	// closed connection is the normal way a proxied session ends.
	go func() {
		defer connWg.Done()
		_, _ = io.Copy(netConn, tcpRemoteConn)
		_ = netConn.Close()
		_ = tcpRemoteConn.Close()
	}()
	go func() {
		defer connWg.Done()
		_, _ = io.Copy(tcpRemoteConn, netConn)
		_ = tcpRemoteConn.Close()
		_ = netConn.Close()
	}()
	connWg.Wait()
}