/
forwarded-tcpip.go
151 lines (131 loc) · 3.89 KB
/
forwarded-tcpip.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package serv
import (
"github.com/nishoushun/gosshd"
"golang.org/x/crypto/ssh"
"net"
"strconv"
"sync"
)
// ForwardedTcpIpRequestHandler handles the tcpip-forward global request and
// its cancel-tcpip-forward counterpart. It keeps the listeners opened for
// active forwards in a map keyed by "host:port" address string.
type ForwardedTcpIpRequestHandler struct {
	// bufSize is the size of the buffer allocated for each direction of a
	// relayed connection; buffers are only allocated when it is > 0.
	bufSize int
	// forwards holds the active forwarding listeners, keyed by bind address.
	forwards map[string]net.Listener
	// Mutex guards forwards.
	sync.Mutex
}
// NewForwardedTcpIpHandler returns a handler with an empty forwarding table.
// bufSize is the size of the buffer used for each copy direction of a relayed
// connection; when it is not positive, no buffer is allocated by the handler.
func NewForwardedTcpIpHandler(bufSize int) *ForwardedTcpIpRequestHandler {
	h := new(ForwardedTcpIpRequestHandler)
	h.bufSize = bufSize
	h.forwards = make(map[string]net.Listener)
	// The embedded sync.Mutex is ready to use in its zero state.
	return h
}
// HandleRequest dispatches a global request by its type. Requests of type
// gosshd.GlobalReqTcpIpForward go to ServeForward, requests of type
// gosshd.GlobalReqCancelTcpIpForward go to CancelForward, and any other type
// is rejected with an empty failure reply. It can be registered as the handler
// for both the tcpip-forward and cancel-tcpip-forward global requests.
func (h *ForwardedTcpIpRequestHandler) HandleRequest(ctx gosshd.Context, request gosshd.Request) {
	if request.Type == gosshd.GlobalReqTcpIpForward {
		h.ServeForward(ctx, request)
		return
	}
	if request.Type == gosshd.GlobalReqCancelTcpIpForward {
		h.CancelForward(ctx, request)
		return
	}
	request.Reply(false, nil)
}
// ServeForward handles a tcpip-forward global request (RFC 4254 section 7.1).
// It listens on the bind address and port carried in the request payload. For
// every accepted TCP connection it opens a forwarded-tcpip channel to the
// client and relays data between the connection and that channel.
//
// The global request is answered exactly once. When the client asks for port
// 0, the success reply carries the port that was actually bound, as the RFC
// requires. The listener is registered under that actual port, so a later
// cancel-tcpip-forward naming it can find it.
//
// ServeForward blocks in the accept loop until the listener is closed, which
// happens through CancelForward, when ctx is done, or on an Accept error. It
// then removes the listener from h.forwards if that entry is still its own.
func (h *ForwardedTcpIpRequestHandler) ServeForward(ctx gosshd.Context, request gosshd.Request) {
	forwardReq := &gosshd.RemoteForwardRequestMsg{}
	if err := ssh.Unmarshal(request.Payload, forwardReq); err != nil {
		request.Reply(false, invalidPayload)
		return
	}
	addr := net.JoinHostPort(forwardReq.BindAddr, strconv.Itoa(int(forwardReq.BindPort)))
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		request.Reply(false, []byte(err.Error()))
		return
	}
	tcpAddr, ok := ln.Addr().(*net.TCPAddr)
	if !ok {
		// Not expected for a "tcp" listener; close it rather than leak it.
		ln.Close()
		request.Reply(false, nil)
		return
	}
	destPort := uint32(tcpAddr.Port)

	// Register before replying so that a cancel racing with the reply still
	// finds the listener. For BindPort == 0 this keys by the allocated port.
	key := net.JoinHostPort(forwardReq.BindAddr, strconv.Itoa(tcpAddr.Port))
	h.Lock()
	h.forwards[key] = ln
	h.Unlock()

	// RFC 4254 7.1: when port 0 was requested, the reply carries the bound port.
	var reply []byte
	if forwardReq.BindPort == 0 {
		reply = ssh.Marshal(&struct{ Port uint32 }{destPort})
	}
	request.Reply(true, reply)

	// Close the listener when the SSH connection's context ends. Stop watching
	// once the accept loop exits, so this goroutine does not outlive the forward.
	acceptDone := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			h.releaseListener(key, ln)
		case <-acceptDone:
		}
	}()

	for {
		conn, err := ln.Accept()
		if err != nil {
			break
		}
		go h.serveForwardedConn(ctx, conn, forwardReq.BindAddr, destPort)
	}
	close(acceptDone)
	h.releaseListener(key, ln)
}

// releaseListener closes ln and unregisters it from h.forwards, but only if it
// is still the listener stored under key. The identity check stops a stale
// cleanup from removing a listener registered later under the same address.
func (h *ForwardedTcpIpRequestHandler) releaseListener(key string, ln net.Listener) {
	h.Lock()
	if cur, ok := h.forwards[key]; ok && cur == ln {
		delete(h.forwards, key)
	}
	h.Unlock()
	// Closing an already-closed listener only returns an error; ignore it.
	ln.Close()
}

// serveForwardedConn opens a forwarded-tcpip channel for one connection
// accepted on a forwarded port and copies data in both directions. Whichever
// direction finishes first closes both the channel and the connection.
func (h *ForwardedTcpIpRequestHandler) serveForwardedConn(ctx gosshd.Context, conn net.Conn, bindAddr string, destPort uint32) {
	// RFC 4254 7.2: the originator is the peer that connected to the forwarded
	// port, not the SSH client. A TCP remote address always has host:port form.
	// If parsing failed anyway, the fields stay empty/zero, which is harmless
	// for an informational field.
	originAddr, originPortStr, _ := net.SplitHostPort(conn.RemoteAddr().String())
	originPort, _ := strconv.Atoi(originPortStr)
	payload := ssh.Marshal(&gosshd.RemoteForwardChannelDataMsg{
		DestAddr:   bindAddr,
		DestPort:   destPort,
		OriginAddr: originAddr,
		OriginPort: uint32(originPort),
	})

	channel, requests, err := ctx.Conn().OpenChannel(gosshd.ForwardedTcpIpChannelType, payload)
	if err != nil {
		// The global request has already been answered. Replying to it again
		// would send the client an unsolicited reply, so just drop this
		// connection.
		conn.Close()
		return
	}
	go ssh.DiscardRequests(requests)

	var wbuf, rbuf []byte
	if h.bufSize > 0 {
		wbuf = make([]byte, h.bufSize)
		rbuf = make([]byte, h.bufSize)
	}
	go func() {
		defer channel.Close()
		defer conn.Close()
		CopyBufferWithContext(channel, conn, rbuf, ctx)
	}()
	go func() {
		defer channel.Close()
		defer conn.Close()
		CopyBufferWithContext(conn, channel, wbuf, ctx)
	}()
}
// CancelForward handles a cancel-tcpip-forward global request. It decodes the
// bind address and port from the payload, then closes and unregisters the
// listener stored under that "host:port" key, if there is one. It replies with
// success whether or not a listener was found. A payload that cannot be
// decoded is rejected with invalidPayload.
func (h *ForwardedTcpIpRequestHandler) CancelForward(ctx gosshd.Context, request gosshd.Request) {
	var msg gosshd.RemoteForwardCancelRequestMsg
	if err := ssh.Unmarshal(request.Payload, &msg); err != nil {
		request.Reply(false, invalidPayload)
		return
	}
	port := strconv.Itoa(int(msg.BindPort))
	h.CloseAndDel(net.JoinHostPort(msg.BindAddr, port))
	request.Reply(true, nil)
}
// CloseAndDel closes the listener registered under addr and removes it from
// the forwarding table, while holding the handler's lock. It does nothing if
// no listener is registered under addr.
func (h *ForwardedTcpIpRequestHandler) CloseAndDel(addr string) {
	h.Lock()
	defer h.Unlock()
	if ln, found := h.forwards[addr]; found {
		ln.Close()
		delete(h.forwards, addr)
	}
}
// Del removes the listener registered under addr from the forwarding table,
// while holding the handler's lock. Unlike CloseAndDel, it does not close the
// listener.
func (h *ForwardedTcpIpRequestHandler) Del(addr string) {
	h.Mutex.Lock()
	delete(h.forwards, addr)
	h.Mutex.Unlock()
}
// invalidPayload is the failure-reply payload sent when a global request's
// payload cannot be decoded.
var (
	invalidPayload = []byte("invalid payload")
)