Skip to content

Commit

Permalink
Use io.ReadAtLeast to simplify getRequest in server.go
Browse files Browse the repository at this point in the history
  • Loading branch information
cyfdecyf committed Dec 15, 2012
1 parent a18ed95 commit d9e55b3
Showing 1 changed file with 27 additions and 39 deletions.
66 changes: 27 additions & 39 deletions cmd/shadowsocks-server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"flag"
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
"io"
"log"
"net"
"strconv"
Expand All @@ -15,7 +16,7 @@ import (

var debug ss.DebugLog

var errAddr = errors.New("addr type not supported")
var errAddrType = errors.New("addr type not supported")

func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
const (
Expand All @@ -35,58 +36,45 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
// request size (when addrType is 3, domain name has at most 256 bytes)
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
buf := make([]byte, 260, 260)
cur := 0 // current location in buf
var n int
// read till we get possible domain length field
ss.SetReadTimeout(conn)
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
return
}

// first read the complete request, may read extra bytes
for {
// hopefully, we should only need one read to get the complete request
// this read normally will read just the request, no extra data
reqLen := lenIP
if buf[idType] == typeDm {
reqLen = int(buf[idDmLen]) + lenDmBase
} else if buf[idType] != typeIP {
err = errAddrType
return
}

if n < reqLen { // rare case
ss.SetReadTimeout(conn)
var n int
if n, err = conn.Read(buf[cur:]); err != nil {
// debug.Println("read request error:", err)
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
return
}
cur += n
if buf[idType] == typeIP {
if cur >= lenIP {
// debug.Println("ip request complete, cur:", cur)
break
}
} else if buf[idType] == typeDm {
if cur < idDmLen+1 { // read until we get address length byte
continue
}
if cur >= lenDmBase+int(buf[idDmLen]) {
// debug.Println("domain request complete, cur:", cur)
break
}
} else {
err = errAddr
return
}
// debug.Println("request not complete, cur:", cur)
} else if n > reqLen {
// it's possible to read more than just the request head
extra = buf[reqLen:n]
}

reqLen := lenIP // default to IP request length
if buf[idType] == typeIP {
// TODO add ipv6 support
if buf[idType] == typeDm {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
} else if buf[idType] == typeIP {
addrIp := make(net.IP, 4)
copy(addrIp, buf[idIP0:idIP0+4])
host = addrIp.String()
} else if buf[idType] == typeDm {
reqLen = lenDmBase + int(buf[idDmLen])
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
}
// parse port
var port int16
sb := bytes.NewBuffer(buf[reqLen-2 : reqLen])
binary.Read(sb, binary.BigEndian, &port)

// debug.Println("requesting:", host, "header len", reqLen)
host += ":" + strconv.Itoa(int(port))
if cur > reqLen {
extra = buf[reqLen:cur]
// debug.Println("extra:", string(extra))
}
return
}

Expand All @@ -112,7 +100,7 @@ func handleConnection(conn *ss.Conn) {
defer remote.Close()
// write extra bytes read from
if extra != nil {
debug.Println("writing extra content to remote, len", len(extra))
debug.Println("getRequest read extra data, writing to remote, len", len(extra))
if _, err = remote.Write(extra); err != nil {
debug.Println("write request extra error:", err)
return
Expand Down

0 comments on commit d9e55b3

Please sign in to comment.