// securetcp.go (forked from gwuhaolin/lightsocks)

package lightsocks
import (
	"io"
	"log"
	"net"
	"sync"
	"time"
)
const (
	// bufSize is the length of every buffer handed out by bufferPoolGet.
	bufSize = 1024
)

// bpool recycles bufSize-byte scratch buffers used by the copy loops.
//
// NOTE(review): storing a []byte (rather than a *[]byte) in a sync.Pool costs
// a small allocation per Put (staticcheck SA6002). Switching to *[]byte would
// change the helper signatures, so it is left as is.
var bpool = sync.Pool{
	New: func() interface{} {
		return make([]byte, bufSize)
	},
}

// bufferPoolGet returns a buffer of exactly bufSize bytes.
// Its contents are unspecified and may hold data from a previous use.
func bufferPoolGet() []byte {
	return bpool.Get().([]byte)
}

// bufferPoolPut returns b to the pool for reuse.
//
// Slices whose capacity is smaller than bufSize are discarded, so the pool
// never hands out a short buffer. Accepted slices are resliced back to
// bufSize, so a caller may safely put back a truncated view such as buf[:n].
func bufferPoolPut(b []byte) {
	if cap(b) < bufSize {
		return
	}
	bpool.Put(b[:bufSize])
}
// SecureTCPConn is a TCP socket that carries encrypted traffic.
//
// Raw reads and writes go through the embedded ReadWriteCloser. The
// DecodeRead, EncodeWrite, DecodeCopy and EncodeCopy methods apply Cipher to
// the data on the way in or out.
type SecureTCPConn struct {
	io.ReadWriteCloser
	// Cipher is applied to byte slices in place via Encode (outgoing) and
	// Decode (incoming).
	Cipher *Cipher
}
// DecodeRead reads encrypted bytes from the underlying stream into bs and
// decodes them in place. It returns the number of bytes read and any read
// error.
//
// Per the io.Reader contract, Read may return n > 0 together with a non-nil
// error (for example the final chunk alongside io.EOF). Those n bytes are
// decoded as well, so bs[:n] always holds plaintext on return.
func (secureSocket *SecureTCPConn) DecodeRead(bs []byte) (n int, err error) {
	n, err = secureSocket.Read(bs)
	if n > 0 {
		secureSocket.Cipher.Decode(bs[:n])
	}
	return n, err
}
// EncodeWrite encodes bs with the connection's Cipher and immediately writes
// all of it to the underlying stream. It returns the byte count and error
// from that single Write call.
//
// Cipher.Encode is applied to bs itself before bs is written, so the caller's
// slice is left holding the encoded bytes. Pass a copy if the plaintext is
// still needed.
func (secureSocket *SecureTCPConn) EncodeWrite(bs []byte) (int, error) {
	c := secureSocket.Cipher
	c.Encode(bs)
	written, err := secureSocket.Write(bs)
	return written, err
}
// EncodeCopy repeatedly reads plaintext from this connection, encodes each
// chunk with the connection's Cipher and writes it to dst. It stops when the
// source is exhausted.
//
// It returns nil when reading hits io.EOF. Otherwise it returns the first
// read or write error, or io.ErrShortWrite if dst accepted fewer bytes than
// it was given.
//
// As with io.Copy, bytes returned together with a read error are forwarded
// before the error is reported.
func (secureSocket *SecureTCPConn) EncodeCopy(dst io.ReadWriteCloser) error {
	buf := bufferPoolGet()
	defer bufferPoolPut(buf)

	// The encoding wrapper around dst does not change between chunks, so
	// build it once instead of allocating a new one per iteration.
	encodedDst := &SecureTCPConn{
		ReadWriteCloser: dst,
		Cipher:          secureSocket.Cipher,
	}

	for {
		readCount, errRead := secureSocket.Read(buf)
		if readCount > 0 {
			writeCount, errWrite := encodedDst.EncodeWrite(buf[:readCount])
			if errWrite != nil {
				return errWrite
			}
			if writeCount != readCount {
				return io.ErrShortWrite
			}
		}
		if errRead == io.EOF {
			return nil
		}
		if errRead != nil {
			return errRead
		}
	}
}
// DecodeCopy repeatedly reads encrypted data from this connection, decodes
// each chunk with the connection's Cipher and writes the plaintext to dst. It
// stops when the source is exhausted.
//
// It returns nil when reading hits io.EOF. Otherwise it returns the first
// read or write error, or io.ErrShortWrite if dst accepted fewer bytes than
// it was given.
//
// As with io.Copy, bytes returned together with a read error are decoded and
// forwarded before the error is reported.
func (secureSocket *SecureTCPConn) DecodeCopy(dst io.Writer) error {
	buf := bufferPoolGet()
	defer bufferPoolPut(buf)
	for {
		readCount, errRead := secureSocket.Read(buf)
		if readCount > 0 {
			chunk := buf[:readCount]
			// Decode in place before forwarding, including a final chunk
			// that arrives together with a read error.
			secureSocket.Cipher.Decode(chunk)
			writeCount, errWrite := dst.Write(chunk)
			if errWrite != nil {
				return errWrite
			}
			if writeCount != readCount {
				return io.ErrShortWrite
			}
		}
		if errRead == io.EOF {
			return nil
		}
		if errRead != nil {
			return errRead
		}
	}
}
// DialEncryptedTCP opens a TCP connection to raddr (see net.DialTCP) and
// wraps it in a SecureTCPConn that uses cipher. Any dial error is returned
// unchanged with a nil connection.
func DialEncryptedTCP(raddr *net.TCPAddr, cipher *Cipher) (*SecureTCPConn, error) {
	tcpConn, dialErr := net.DialTCP("tcp", nil, raddr)
	if dialErr != nil {
		return nil, dialErr
	}
	// Linger 0: when the connection is closed, discard any data not yet
	// sent instead of waiting to deliver it. The returned error is not checked.
	_ = tcpConn.SetLinger(0)

	secure := new(SecureTCPConn)
	secure.ReadWriteCloser = tcpConn
	secure.Cipher = cipher
	return secure, nil
}
// ListenEncryptedTCP listens on laddr (see net.ListenTCP) and serves every
// accepted connection as a SecureTCPConn using cipher.
//
// Parameters:
//   - handleConn is invoked in its own goroutine for each accepted connection.
//   - didListen, if non-nil, is invoked once in a separate goroutine with the
//     address actually bound by the listener (useful when laddr's port is 0).
//
// It returns only the error from net.ListenTCP. Once listening succeeds, the
// accept loop runs forever. Accept errors are logged and retried with a
// capped exponential backoff.
func ListenEncryptedTCP(laddr *net.TCPAddr, cipher *Cipher, handleConn func(localConn *SecureTCPConn), didListen func(listenAddr *net.TCPAddr)) error {
	listener, err := net.ListenTCP("tcp", laddr)
	if err != nil {
		return err
	}
	defer listener.Close()
	if didListen != nil {
		// didListen may block, so keep it from stalling the accept loop.
		go didListen(listener.Addr().(*net.TCPAddr))
	}

	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var acceptDelay time.Duration
	for {
		localConn, acceptErr := listener.AcceptTCP()
		if acceptErr != nil {
			log.Println(acceptErr)
			// Back off before retrying. Without this, a persistent failure
			// (e.g. running out of file descriptors) turns the loop into a
			// busy spin that also floods the log.
			if acceptDelay == 0 {
				acceptDelay = minAcceptDelay
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			time.Sleep(acceptDelay)
			continue
		}
		acceptDelay = 0
		// Linger 0: when localConn is closed, discard any data not yet sent
		// instead of waiting to deliver it.
		localConn.SetLinger(0)
		go handleConn(&SecureTCPConn{
			ReadWriteCloser: localConn,
			Cipher:          cipher,
		})
	}
}