forked from jpillora/chisel
/
tunnel_out_ssh_udp.go
155 lines (143 loc) · 2.93 KB
/
tunnel_out_ssh_udp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package tunnel
import (
"encoding/gob"
"io"
"net"
"os"
"sync"
"time"
"github.com/defane/chisel/share/cio"
"github.com/defane/chisel/share/settings"
)
// handleUDP serves one UDP forwarding channel. Packets are gob-decoded from
// rwc and forwarded to hostPort; each distinct packet source gets its own
// tracked UDP conn, and every tracked conn is closed when this returns.
//
// The maximum datagram size used for reply buffers is read from the
// UDP_MAX_SIZE setting (default 9012 bytes). The function only returns when
// handling a packet fails, and it returns that error.
func (t *Tunnel) handleUDP(l *cio.Logger, rwc io.ReadWriteCloser, hostPort string) error {
	conns := &udpConns{Logger: l, m: map[string]*udpConn{}}
	defer conns.closeAll()

	channel := &udpChannel{
		r: gob.NewDecoder(rwc),
		w: gob.NewEncoder(rwc),
		c: rwc,
	}
	h := &udpHandler{
		Logger:     l,
		hostPort:   hostPort,
		udpChannel: channel,
		udpConns:   conns,
		maxMTU:     settings.EnvInt("UDP_MAX_SIZE", 9012),
	}
	h.Debugf("UDP max size: %d bytes", h.maxMTU)

	for {
		// a fresh packet per iteration: a reader goroutine may keep
		// a reference to the previous one
		var p udpPacket
		if err := h.handleWrite(&p); err != nil {
			return err
		}
	}
}
// udpHandler carries the per-channel state shared by the write loop and
// the reply-reading goroutines.
type udpHandler struct {
	*cio.Logger
	// hostPort is the UDP destination every packet is forwarded to.
	hostPort string
	// udpChannel is the gob-framed SSH side of the tunnel.
	*udpChannel
	// udpConns tracks one outbound conn per packet source.
	*udpConns
	// maxMTU is the size of each reply read buffer, in bytes.
	maxMTU int
}
// handleWrite reads the next packet from the SSH side into p and forwards
// its payload to hostPort over a UDP conn keyed by the packet source (p.Src).
//
// The first packet from a new source dials a fresh conn. While at most
// maxConns conns are tracked, it also starts a handleRead goroutine that
// relays replies and owns the conn's cleanup. Above that limit the packet
// is still sent, but the conn is removed and closed right after the write.
// Otherwise nothing would ever read from it or release its socket.
//
// Decode and dial errors are returned to the caller. A failed Write only
// affects that one datagram, so it is logged and the packet dropped rather
// than tearing down the whole channel.
func (h *udpHandler) handleWrite(p *udpPacket) error {
	// p is already a pointer; decode straight into it
	if err := h.r.Decode(p); err != nil {
		return err
	}
	//dial now, we know we must write
	conn, exists, err := h.udpConns.dial(p.Src, h.hostPort)
	if err != nil {
		return err
	}
	//however, we dont know if we must read...
	//spawn up to <max-conns> go-routines to wait
	//for a reply.
	//TODO configurable
	//TODO++ dont use go-routines, switch to pollable
	// array of listeners where all listeners are
	// sweeped periodically, removing the idle ones
	const maxConns = 100
	// an existing conn always has a reader: conns without one are
	// released below and never stay in the map
	hasReader := exists
	if !exists {
		// len includes the conn dial just added
		if h.udpConns.len() <= maxConns {
			go h.handleRead(p, conn)
			hasReader = true
		} else {
			h.Debugf("exceeded max udp connections (%d)", maxConns)
		}
	}
	_, err = conn.Write(p.Payload)
	if !hasReader {
		// no goroutine owns this conn; release it instead of leaking it
		h.udpConns.remove(conn.id)
		conn.Close()
	}
	if err != nil {
		// UDP is lossy by design: drop this datagram, keep the channel up
		h.Debugf("write error: %s", err)
	}
	return nil
}
// handleRead relays datagrams received on conn back over the SSH channel,
// tagged with the original packet source p.Src.
//
// It stops when a read fails or when encoding onto the channel fails. A
// read fails when no reply arrives within the UDP_DEADLINE setting (default
// 15s) of the previous read. Timeouts and io.EOF end the loop silently;
// other errors are logged at debug level.
//
// On exit the conn is removed from the tracking map and then closed. The
// original code only removed it, which leaked the socket.
func (h *udpHandler) handleRead(p *udpPacket, conn *udpConn) {
	// untrack first so later packets from this source dial a fresh conn,
	// then release the underlying socket.
	// NOTE(review): a writer that already fetched this conn may still
	// attempt a Write after the Close and see a closed-connection error.
	defer func() {
		h.udpConns.remove(conn.id)
		conn.Close()
	}()
	// loop-invariant: look the idle deadline up once, not per read
	deadline := settings.EnvDuration("UDP_DEADLINE", 15*time.Second)
	buff := make([]byte, h.maxMTU)
	for {
		// each reply must arrive within deadline of this point. The
		// SetReadDeadline error is ignored; if the conn is unusable,
		// the Read below fails and ends the loop anyway.
		conn.SetReadDeadline(time.Now().Add(deadline))
		n, err := conn.Read(buff)
		if err != nil {
			if !os.IsTimeout(err) && err != io.EOF {
				h.Debugf("read error: %s", err)
			}
			return
		}
		// buff is reused on the next iteration.
		// NOTE(review): assumes encode does not retain the slice.
		if err := h.udpChannel.encode(p.Src, buff[:n]); err != nil {
			h.Debugf("encode error: %s", err)
			return
		}
	}
}
// udpConns is a mutex-guarded registry of outbound UDP conns keyed by the
// packet source identifier they serve.
type udpConns struct {
	*cio.Logger
	// Mutex guards m.
	sync.Mutex
	m map[string]*udpConn
}
// dial returns the conn registered under id. If there is none, it dials a
// new UDP conn to addr and registers it. The boolean result reports whether
// the conn already existed. A dial failure registers nothing.
func (cs *udpConns) dial(id, addr string) (*udpConn, bool, error) {
	cs.Lock()
	defer cs.Unlock()
	if existing, found := cs.m[id]; found {
		return existing, true, nil
	}
	raw, err := net.Dial("udp", addr)
	if err != nil {
		return nil, false, err
	}
	fresh := &udpConn{
		id:   id,
		Conn: raw, // cnet.MeterConn(cs.Logger.Fork(addr), c),
	}
	cs.m[id] = fresh
	return fresh, false, nil
}
// len reports how many conns are currently registered.
func (cs *udpConns) len() int {
	cs.Lock()
	defer cs.Unlock()
	return len(cs.m)
}
// remove unregisters the conn stored under id without closing it.
// Removing an unknown id is a no-op.
func (cs *udpConns) remove(id string) {
	cs.Lock()
	defer cs.Unlock()
	delete(cs.m, id)
}
// closeAll closes every registered conn and empties the registry.
// Close errors are ignored.
func (cs *udpConns) closeAll() {
	cs.Lock()
	defer cs.Unlock()
	for key, c := range cs.m {
		c.Close()
		// deleting the current key while ranging is permitted in Go
		delete(cs.m, key)
	}
}
// udpConn is an outbound UDP conn tagged with the registry key
// (packet source) it was dialed for.
type udpConn struct {
	// id is the key under which this conn is stored in udpConns.m.
	id string
	net.Conn
}