Skip to content

Commit

Permalink
fix: 解决内存泄露,进一步提升性能
Browse files Browse the repository at this point in the history
  • Loading branch information
wuhaolin committed Oct 14, 2017
1 parent 25a2afe commit dfe7c2d
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 76 deletions.
22 changes: 0 additions & 22 deletions cmd/config_test.go

This file was deleted.

22 changes: 14 additions & 8 deletions cmd/e2e_test.go
Expand Up @@ -45,14 +45,12 @@ func runEchoServer() {
for {
conn, err := listener.Accept()
if err != nil {
log.Fatalln("listener.Accept", err)
log.Fatalln(err)
continue
}
log.Println("EchoServer", "listener.Accept")
go func() {
defer conn.Close()
io.Copy(conn, conn)
log.Println("EchoServer", "conn.Close")
}()
}
}
Expand All @@ -63,8 +61,8 @@ func runLightsocksProxyServer() {
serverAddr, _ := net.ResolveTCPAddr("tcp", LightSocksProxyServerAddr)
serverS := local.New(password, localAddr, serverAddr)
localS := server.New(password, serverAddr)
go serverS.Listen()
localS.Listen()
go serverS.Listen(nil)
localS.Listen(nil)
}

// 发生一次连接测试经过代理后的数据传输的正确性
Expand All @@ -73,21 +71,29 @@ func testConnect(packSize int) {
// 随机生产 MaxPackSize byte的[]byte
data := make([]byte, packSize)
_, err := rand.Read(data)
buf := make([]byte, len(data))

// 连接
conn, err := lightsocksDialer.Dial("tcp", EchoServerAddr)
if err != nil {
log.Fatalln(err)
}
defer conn.Close()

// 写
go func() {
conn.Write(data)
}()

// 读
buf := make([]byte, len(data))
_, err = io.ReadFull(conn, buf)
conn.Close()
if err != nil {
log.Fatalln("io.ReadFull", err)
log.Fatalln(err)
}
if !reflect.DeepEqual(data, buf) {
log.Fatalln("通过 Lightsocks 代理传输得到的数据前后不一致")
} else {
log.Println("数据一致性验证通过")
}
}

Expand Down
8 changes: 4 additions & 4 deletions cmd/lightsocks-local/main.go
Expand Up @@ -18,13 +18,14 @@ var version = "master"
func main() {
log.SetFlags(log.Lshortfile)

var err error
// 默认配置
config := &cmd.Config{
ListenAddr: DefaultListenAddr,
}
config.ReadConfig()
config.SaveConfig()

// 解析配置
password, err := core.ParsePassword(config.Password)
if err != nil {
log.Fatalln(err)
Expand All @@ -40,7 +41,7 @@ func main() {

// 启动 local 端并监听
lsLocal := local.New(password, listenAddr, remoteAddr)
lsLocal.AfterListen = func(listenAddr net.Addr) {
log.Fatalln(lsLocal.Listen(func(listenAddr net.Addr) {
log.Printf("lightsocks-local:%s 启动成功 监听在 %s\n", version, listenAddr.String())
log.Println("使用配置:", fmt.Sprintf(`
本地监听地址 listen:
Expand All @@ -50,6 +51,5 @@ func main() {
密码 password:
%s
`, listenAddr, remoteAddr, password))
}
log.Fatalln(lsLocal.Listen())
}))
}
9 changes: 4 additions & 5 deletions cmd/lightsocks-server/main.go
Expand Up @@ -15,14 +15,13 @@ var version = "master"
func main() {
log.SetFlags(log.Lshortfile)

var err error

// 服务端监听端口随机生成
port, err := freeport.GetFreePort()
if err != nil {
// 随机端口失败就采用 7448
port = 7448
}
// 默认配置
config := &cmd.Config{
ListenAddr: fmt.Sprintf(":%d", port),
// 密码随机生成
Expand All @@ -31,6 +30,7 @@ func main() {
config.ReadConfig()
config.SaveConfig()

// 解析配置
password, err := core.ParsePassword(config.Password)
if err != nil {
log.Fatalln(err)
Expand All @@ -42,14 +42,13 @@ func main() {

// 启动 server 端并监听
lsServer := server.New(password, listenAddr)
lsServer.AfterListen = func(listenAddr net.Addr) {
log.Fatalln(lsServer.Listen(func(listenAddr net.Addr) {
log.Printf("lightsocks-server:%s 启动成功 监听在 %s\n", version, listenAddr.String())
log.Println("使用配置:", fmt.Sprintf(`
本地监听地址 listen:
%s
密码 password:
%s
`, listenAddr, password))
}
log.Fatalln(lsServer.Listen())
}))
}
8 changes: 6 additions & 2 deletions core/securesocket.go
Expand Up @@ -22,6 +22,8 @@ type SecureSocket struct {

// 从输入流里读取加密过的数据,解密后把原数据放到bs里
func (secureSocket *SecureSocket) DecodeRead(conn *net.TCPConn, bs []byte) (n int, err error) {
// 设置读超时
conn.SetReadDeadline(time.Now().Add(TIMEOUT))
n, err = conn.Read(bs)
if err != nil {
return
Expand All @@ -33,6 +35,8 @@ func (secureSocket *SecureSocket) DecodeRead(conn *net.TCPConn, bs []byte) (n in
// 把放在bs里的数据加密后立即全部写入输出流
func (secureSocket *SecureSocket) EncodeWrite(conn *net.TCPConn, bs []byte) (int, error) {
secureSocket.Cipher.encode(bs)
// 设置写超时
conn.SetWriteDeadline(time.Now().Add(TIMEOUT))
return conn.Write(bs)
}

Expand Down Expand Up @@ -84,8 +88,8 @@ func (secureSocket *SecureSocket) DecodeCopy(dst *net.TCPConn, src *net.TCPConn)
}
}

// 和远程的socket建立连接,他们直接的数据传输会加密
func (secureSocket *SecureSocket) DialServer() (*net.TCPConn, error) {
// 和远程的socket建立连接,他们之间的数据传输会加密
func (secureSocket *SecureSocket) DialRemote() (*net.TCPConn, error) {
remoteConn, err := net.DialTCP("tcp", nil, secureSocket.RemoteAddr)
if err != nil {
return nil, errors.New(fmt.Sprintf("连接到远程服务器 %s 失败:%s", secureSocket.RemoteAddr, err))
Expand Down
23 changes: 12 additions & 11 deletions local/local.go
Expand Up @@ -3,13 +3,11 @@ package local
import (
"net"
"log"
"time"
"github.com/gwuhaolin/lightsocks/core"
)

type LsLocal struct {
*core.SecureSocket
AfterListen func(listenAddr net.Addr)
}

// 新建一个本地端
Expand All @@ -29,21 +27,22 @@ func New(password *core.Password, listenAddr, remoteAddr *net.TCPAddr) *LsLocal
}

// 本地端启动监听给用户的浏览器调用
func (local *LsLocal) Listen() error {
func (local *LsLocal) Listen(didListen func(listenAddr net.Addr)) error {
listener, err := net.ListenTCP("tcp", local.ListenAddr)
if err != nil {
return err
}

defer listener.Close()

if local.AfterListen != nil {
local.AfterListen(listener.Addr())
if didListen != nil {
didListen(listener.Addr())
}

for {
userConn, err := listener.AcceptTCP()
if err != nil {
log.Println(err)
continue
}
// userConn被关闭时直接清除所有数据 不管没有发送的数据
Expand All @@ -55,15 +54,17 @@ func (local *LsLocal) Listen() error {

func (local *LsLocal) handleConn(userConn *net.TCPConn) {
defer userConn.Close()
server, err := local.DialServer()

proxyServer, err := local.DialRemote()
if err != nil {
log.Println(err)
return
}
defer server.Close()
server.SetLinger(0)
server.SetDeadline(time.Now().Add(core.TIMEOUT))
defer proxyServer.Close()
// Conn被关闭时直接清除所有数据 不管没有发送的数据
proxyServer.SetLinger(0)

// 进行转发
go local.EncodeCopy(server, userConn)
local.DecodeCopy(userConn, server)
go local.EncodeCopy(proxyServer, userConn)
local.DecodeCopy(userConn, proxyServer)
}
53 changes: 29 additions & 24 deletions server/server.go
Expand Up @@ -3,13 +3,12 @@ package server
import (
"net"
"encoding/binary"
"time"
"log"
"github.com/gwuhaolin/lightsocks/core"
)

type LsServer struct {
*core.SecureSocket
AfterListen func(listenAddr net.Addr)
}

// 新建一个服务端
Expand All @@ -27,34 +26,35 @@ func New(password *core.Password, listenAddr *net.TCPAddr) *LsServer {
}

// 运行服务端并且监听来自本地代理客户端的请求
func (server *LsServer) Listen() error {
listener, err := net.ListenTCP("tcp", server.ListenAddr)
func (lsServer *LsServer) Listen(didListen func(listenAddr net.Addr)) error {
listener, err := net.ListenTCP("tcp", lsServer.ListenAddr)
if err != nil {
return err
}

defer listener.Close()

if server.AfterListen != nil {
server.AfterListen(listener.Addr())
if didListen != nil {
didListen(listener.Addr())
}

for {
localConn, err := listener.AcceptTCP()
if err != nil {
log.Println(err)
continue
}
// localConn被关闭时直接清除所有数据 不管没有发送的数据
localConn.SetLinger(0)
go server.handleConn(localConn)
go lsServer.handleConn(localConn)
}
return nil
}

// socks5实现
// https://www.ietf.org/rfc/rfc1928.txt
// http://www.jianshu.com/p/172810a70fad
func (server *LsServer) handleConn(localConn *net.TCPConn) {
func (lsServer *LsServer) handleConn(localConn *net.TCPConn) {
defer localConn.Close()
buf := make([]byte, 256)

Expand All @@ -71,7 +71,7 @@ func (server *LsServer) handleConn(localConn *net.TCPConn) {
appear in the METHODS field.
*/
// 第一个字段VER代表Socks的版本,Socks5默认为0x05,其固定长度为1个字节
_, err := server.DecodeRead(localConn, buf)
_, err := lsServer.DecodeRead(localConn, buf)
// 只支持版本5
if err != nil || buf[0] != 0x05 {
return
Expand All @@ -88,7 +88,7 @@ func (server *LsServer) handleConn(localConn *net.TCPConn) {
+----+--------+
*/
// 不需要验证,直接验证通过
server.EncodeWrite(localConn, []byte{0x05, 0x00})
lsServer.EncodeWrite(localConn, []byte{0x05, 0x00})

/**
+----+-----+-------+------+----------+----------+
Expand All @@ -105,7 +105,8 @@ func (server *LsServer) handleConn(localConn *net.TCPConn) {
return
}

n, err := server.DecodeRead(localConn, buf)
// 获取真正的远程服务的地址
n, err := lsServer.DecodeRead(localConn, buf)
// n 最短的长度为7 情况为 ATYP=3 DST.ADDR占用1字节 值为0x0
if err != nil || n < 7 {
return
Expand Down Expand Up @@ -134,25 +135,29 @@ func (server *LsServer) handleConn(localConn *net.TCPConn) {
IP: dIP,
Port: int(binary.BigEndian.Uint16(dPort)),
}
dstServer, err := net.DialTCP("tcp", nil, dstAddr)

/**
+----+-----+-------+------+----------+----------+
|VER | REP | RSV | ATYP | BND.ADDR | BND.PORT |
+----+-----+-------+------+----------+----------+
| 1 | 1 | X'00' | 1 | Variable | 2 |
+----+-----+-------+------+----------+----------+
*/
// 连接真正的远程服务
dstServer, err := net.DialTCP("tcp", nil, dstAddr)
if err != nil {
return
} else {
defer dstServer.Close()
// 响应客户端连接成功
server.EncodeWrite(localConn, []byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
// Conn被关闭时直接清除所有数据 不管没有发送的数据
dstServer.SetLinger(0)
dstServer.SetDeadline(time.Now().Add(core.TIMEOUT))

// 响应客户端连接成功
/**
+----+-----+-------+------+----------+----------+
|VER | REP | RSV | ATYP | BND.ADDR | BND.PORT |
+----+-----+-------+------+----------+----------+
| 1 | 1 | X'00' | 1 | Variable | 2 |
+----+-----+-------+------+----------+----------+
*/
// 响应客户端连接成功
go lsServer.EncodeWrite(localConn, []byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
}

// 进行转发
go server.DecodeCopy(dstServer, localConn)
server.EncodeCopy(localConn, dstServer)
go lsServer.DecodeCopy(dstServer, localConn)
lsServer.EncodeCopy(localConn, dstServer)
}

0 comments on commit dfe7c2d

Please sign in to comment.