Skip to content

Commit

Permalink
Use io.ReadAtLeast to simplify handshake and getRequest in local.go
Browse files Browse the repository at this point in the history
  • Loading branch information
cyfdecyf committed Dec 15, 2012
1 parent 9c448db commit a18ed95
Showing 1 changed file with 66 additions and 72 deletions.
138 changes: 66 additions & 72 deletions cmd/shadowsocks-local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@ import (
var debug ss.DebugLog

var (
errAddr = errors.New("socks addr type not supported")
errVer = errors.New("socks version not supported")
errMethod = errors.New("socks only support 1 method now")
errAuth = errors.New("socks authentication not required")
errCmd = errors.New("socks command not supported")
errAddrType = errors.New("socks addr type not supported")
errVer = errors.New("socks version not supported")
errMethod = errors.New("socks only support 1 method now")
errAuthExtraData = errors.New("socks authentication get extra data")
errReqExtraData = errors.New("socks request get extra data")
errCmd = errors.New("socks command not supported")
)

const (
socksVer5 = 5
socksCmdConnect = 1
)

func handShake(conn net.Conn) (err error) {
Expand All @@ -29,26 +35,36 @@ func handShake(conn net.Conn) (err error) {
)
// version identification and method selection message in theory can have
// at most 256 methods, plus version and nmethod field in total 258 bytes
// the current rfc defines only 3 authentication methods (plus 2 reserved)
// the current rfc defines only 3 authentication methods (plus 2 reserved),
// so it won't be such long in practice

buf := make([]byte, 258-2, 258-2) // reuse the buf to read nmethod field
buf := make([]byte, 258, 258)

if _, err = io.ReadFull(conn, buf[:2]); err != nil {
var n int
// make sure we get the nmethod field
if n, err = io.ReadAtLeast(conn, buf, idNmethod+1); err != nil {
return
}
if buf[idVer] != 5 {
if buf[idVer] != socksVer5 {
return errVer
}
nmethod := buf[idNmethod]
if _, err = io.ReadFull(conn, buf[:nmethod]); err != nil {
return
nmethod := int(buf[idNmethod])
msgLen := nmethod + 2
if n == msgLen { // handshake done, common case
// do nothing, jump directly to send confirmation
} else if n < msgLen { // has more methods to read, rare case
if _, err = io.ReadFull(conn, buf[n:msgLen]); err != nil {
return
}
} else { // error, should not get extra data
return errAuthExtraData
}
// version 5, no authentication required
_, err = conn.Write([]byte{5, 0})
// send confirmation: version 5, no authentication required
_, err = conn.Write([]byte{socksVer5, 0})
return
}

func getRequest(conn net.Conn) (rawaddr []byte, extra []byte, host string, err error) {
func getRequest(conn net.Conn) (rawaddr []byte, host string, err error) {
const (
idVer = 0
idCmd = 1
Expand All @@ -65,71 +81,56 @@ func getRequest(conn net.Conn) (rawaddr []byte, extra []byte, host string, err e
)
// refer to getRequest in server.go for why set buffer size to 263
buf := make([]byte, 263, 263)
cur := 0 // current location in buf
reqLen := 0
var n int
// read till we get possible domain length field
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
return
}
// check version and cmd
if buf[idVer] != socksVer5 {
err = errVer
return
}
if buf[idCmd] != socksCmdConnect {
err = errCmd
return
}

for {
var n int
// usually need to read only once
if n, err = conn.Read(buf[cur:]); err != nil {
// debug.Println("read request error:", err)
return
}
cur += n
if cur < idType+1 { // read till we get addr type
continue
}
// check version and cmd
if buf[idVer] != 5 {
err = errVer
return
}
if buf[idCmd] != 1 {
err = errCmd
return
}
// TODO following code is copied from server.go, fix code duplication?
if buf[idType] == typeIP {
if cur >= lenIP {
// debug.Println("ip request complete, cur:", cur)
reqLen = lenIP
break
}
} else if buf[idType] == typeDm {
if cur < idDmLen+1 { // read until we get address length byte
continue
}
if cur >= lenDmBase+int(buf[idDmLen]) {
// debug.Println("domain request complete, cur:", cur)
reqLen = lenDmBase + int(buf[idDmLen])
break
}
} else {
err = errAddr
reqLen := lenIP
if buf[idType] == typeDm {
reqLen = int(buf[idDmLen]) + lenDmBase
} else if buf[idType] != typeIP {
err = errAddrType
return
}

if n == reqLen {
// common case, do nothing
} else if n < reqLen { // rare case
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
return
}
// debug.Println("request not complete, cur:", cur)
} else {
err = errReqExtraData
return
}

rawaddr = buf[idType:reqLen]
if cur > reqLen {
extra = buf[reqLen:cur]
// debug.Println("extra:", string(extra))
}

if debug {
if buf[idType] == typeIP {
if buf[idType] == typeDm {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
} else if buf[idType] == typeIP {
addrIp := make(net.IP, 4)
copy(addrIp, buf[idIP0:idIP0+4])
host = addrIp.String()
} else if buf[idType] == typeDm {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
}
var port int16
sb := bytes.NewBuffer(buf[reqLen-2 : reqLen])
binary.Read(sb, binary.BigEndian, &port)
host += ":" + strconv.Itoa(int(port))
}

return
}

Expand All @@ -144,7 +145,7 @@ func handleConnection(conn net.Conn, server string, encTbl *ss.EncryptTable) {
log.Println("socks handshack:", err)
return
}
rawaddr, extra, addr, err := getRequest(conn)
rawaddr, addr, err := getRequest(conn)
if err != nil {
log.Println("error getting request:", err)
return
Expand All @@ -163,13 +164,6 @@ func handleConnection(conn net.Conn, server string, encTbl *ss.EncryptTable) {
return
}
defer remote.Close()
if extra != nil {
debug.Println("writing extra content to remote, len", len(extra))
if _, err = remote.Write(extra); err != nil {
debug.Println("write request extra error:", err)
return
}
}

c := make(chan byte, 2)
go ss.Pipe(conn, remote, c)
Expand Down Expand Up @@ -199,7 +193,7 @@ func main() {
var configFile string
flag.StringVar(&configFile, "c", "config.json", "specify config file")
flag.Parse()

config := ss.ParseConfig(configFile)
debug = ss.Debug
run(strconv.Itoa(config.LocalPort), config.Password,
Expand Down

0 comments on commit a18ed95

Please sign in to comment.