Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug in socks5 when rendering remote address #3110

Merged
merged 2 commits into from Nov 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
44 changes: 26 additions & 18 deletions lib/utils/socks/socks.go
Expand Up @@ -19,6 +19,7 @@ package socks

import (
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
Expand Down Expand Up @@ -155,14 +156,7 @@ func readRequest(conn net.Conn) (string, error) {

// Read in the address type and determine how many more bytes need to be
// read in to read in the remote host address.
addrLen, err := readAddrType(conn)
if err != nil {
return "", trace.Wrap(err)
}

// Read in the destination address.
destAddr := make([]byte, addrLen)
_, err = io.ReadFull(conn, destAddr)
destAddr, err := readDestAddr(conn)
if err != nil {
return "", trace.Wrap(err)
}
Expand All @@ -174,34 +168,48 @@ func readRequest(conn net.Conn) (string, error) {
return "", trace.Wrap(err)
}

return net.JoinHostPort(string(destAddr), strconv.Itoa(int(destPort))), nil
return net.JoinHostPort(destAddr, strconv.Itoa(int(destPort))), nil
}

// readAddrType reads in the address type and returns the length of the dest
// addr field.
func readAddrType(conn net.Conn) (int, error) {
// readDestAddr reads in the destination address.
func readDestAddr(conn net.Conn) (string, error) {
// Read in the type of the remote host.
addrType, err := readByte(conn)
if err != nil {
return 0, trace.Wrap(err)
return "", trace.Wrap(err)
}

// Based off the type, determine how many more bytes to read in for the
// remote address. For IPv4 it's 4 bytes, for IPv6 it's 16, and for domain
// names read in another byte to determine the length of the field.
switch addrType {
case socks5AddressTypeIPv4:
return net.IPv4len, nil
destAddr := make([]byte, net.IPv4len)
_, err = io.ReadFull(conn, destAddr)
if err != nil {
return "", trace.Wrap(err)
}
return fmt.Sprintf("%s", net.IP(destAddr)), nil
case socks5AddressTypeIPv6:
return net.IPv6len, nil
destAddr := make([]byte, net.IPv6len)
_, err = io.ReadFull(conn, destAddr)
if err != nil {
return "", trace.Wrap(err)
}
return fmt.Sprintf("%s", net.IP(destAddr)), nil
case socks5AddressTypeDomainName:
len, err := readByte(conn)
if err != nil {
return 0, trace.Wrap(err)
return "", trace.Wrap(err)
}
destAddr := make([]byte, len)
_, err = io.ReadFull(conn, destAddr)
if err != nil {
return "", trace.Wrap(err)
}
return int(len), nil
return string(destAddr), nil
default:
return 0, trace.BadParameter("unsupported address type: %v", addrType)
return "", trace.BadParameter("unsupported address type: %v", addrType)
}
}

Expand Down
35 changes: 20 additions & 15 deletions lib/utils/socks/socks_test.go
Expand Up @@ -46,7 +46,10 @@ func (s *SOCKSSuite) SetUpTest(c *check.C) {}
func (s *SOCKSSuite) TearDownTest(c *check.C) {}

func (s *SOCKSSuite) TestHandshake(c *check.C) {
remoteAddr := "example.com:443"
remoteAddrs := []string{
"example.com:443",
"9.8.7.6:443",
}

// Create and start a debug SOCKS5 server that calls socks.Handshake().
socksServer, err := newDebugServer()
Expand All @@ -57,20 +60,22 @@ func (s *SOCKSSuite) TestHandshake(c *check.C) {
proxy, err := proxy.SOCKS5("tcp", socksServer.Addr().String(), nil, nil)
c.Assert(err, check.IsNil)

// Connect to the SOCKS5 server, this is where the handshake function is called.
conn, err := proxy.Dial("tcp", remoteAddr)
c.Assert(err, check.IsNil)

// Read in what was written on the connection. With the debug server it's
// always the address requested.
buf := make([]byte, len(remoteAddr))
_, err = io.ReadFull(conn, buf)
c.Assert(err, check.IsNil)
c.Assert(string(buf), check.Equals, remoteAddr)

// Close and cleanup.
err = conn.Close()
c.Assert(err, check.IsNil)
for _, remoteAddr := range remoteAddrs {
// Connect to the SOCKS5 server, this is where the handshake function is called.
conn, err := proxy.Dial("tcp", remoteAddr)
c.Assert(err, check.IsNil)

// Read in what was written on the connection. With the debug server it's
// always the address requested.
buf := make([]byte, len(remoteAddr))
_, err = io.ReadFull(conn, buf)
c.Assert(err, check.IsNil)
c.Assert(string(buf), check.Equals, remoteAddr)

// Close and cleanup.
err = conn.Close()
c.Assert(err, check.IsNil)
}
}

// debugServer is a debug SOCKS5 server that performs a SOCKS5 handshake
Expand Down