Skip to content

Commit

Permalink
feat: finish writing tests for new package
Browse files Browse the repository at this point in the history
  • Loading branch information
bassosimone committed Jun 10, 2021
1 parent b19ed2e commit e3f5c01
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 18 deletions.
4 changes: 2 additions & 2 deletions internal/cmd/ptxclient/ptxclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ func main() {
os.Exit(1)
}
listener := &ptx.Listener{
ContextDialer: dialer,
Logger: log.Log,
PTDialer: dialer,
Logger: log.Log,
}
if err := listener.Start(); err != nil {
log.WithError(err).Fatal("listener.Start failed")
Expand Down
88 changes: 72 additions & 16 deletions internal/ptx/ptx.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ type PTDialer interface {
// you fill the mandatory fields before using. Do not modify public
// fields after you called Start, since this causes data races.
type Listener struct {
// ContextDialer is the MANDATORY pluggable transports dialer
// PTDialer is the MANDATORY pluggable transports dialer
// to use. Both SnowflakeDialer and OBFS4Dialer implement this
// interface and can be thus safely used here.
ContextDialer PTDialer
PTDialer PTDialer

// Logger is the optional logger. When not set, this library
// will not emit logs. (But the underlying pluggable transport
Expand All @@ -80,11 +80,17 @@ type Listener struct {
// mu provides mutual exclusion for accessing internals.
mu sync.Mutex

// cancel allows to stop the listener.
// cancel allows to stop the forwarders.
cancel context.CancelFunc

// laddr is the listen address.
laddr net.Addr

// listener allows us to stop the listener.
listener ptxSocksListener

// overrideListenSocks allows us to override pt.ListenSocks.
overrideListenSocks func(network string, laddr string) (ptxSocksListener, error)
}

// logger returns the Logger, if set, or the defaultLogger.
Expand Down Expand Up @@ -131,35 +137,55 @@ func (lst *Listener) forwardWithContext(ctx context.Context, left, right net.Con
// handleSocksConn handles a new SocksConn connection by establishing
// the corresponding PT connection and forwarding traffic. This
// function TAKES OWNERSHIP of the socksConn argument.
func (lst *Listener) handleSocksConn(ctx context.Context, socksConn *pt.SocksConn) {
func (lst *Listener) handleSocksConn(ctx context.Context, socksConn ptxSocksConn) error {
err := socksConn.Grant(&net.TCPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
lst.logger().Warnf("ptx: socksConn.Grant error: %s", err)
return
return err // used for testing
}
ptConn, err := lst.ContextDialer.DialContext(ctx)
ptConn, err := lst.PTDialer.DialContext(ctx)
if err != nil {
socksConn.Close() // we own it
lst.logger().Warnf("ptx: ContextDialer.DialContext error: %s", err)
return
return err // used for testing
}
lst.forwardWithContext(ctx, socksConn, ptConn)
return nil // used for testing
}

// ptxSocksListener is a pt.SocksListener-like structure.
type ptxSocksListener interface {
// AcceptSocks accepts a socks conn
AcceptSocks() (ptxSocksConn, error)

// Addr returns the listening address.
Addr() net.Addr

// Close closes the listener
Close() error
}

// ptxSocksConn is a pt.SocksConn-like structure.
type ptxSocksConn interface {
// net.Conn is the embedded interface.
net.Conn

// Grants access to a specific IP address.
Grant(addr *net.TCPAddr) error
}

// acceptLoop accepts and handles local socks connection. This function
// TAKES OWNERSHIP of the socks listener.
func (lst *Listener) acceptLoop(ctx context.Context, ln *pt.SocksListener) {
defer ln.Close()
// DOES NOT take ownership of the socks listener.
func (lst *Listener) acceptLoop(ctx context.Context, ln ptxSocksListener) {
for {
conn, err := ln.AcceptSocks()
if err != nil {
if err, ok := err.(net.Error); ok && err.Temporary() {
continue
}
lst.logger().Warnf("ptx: socks accept error: %s", err)
return
break
}
lst.logger().Infof("ptx: SOCKS accepted: %v", conn.Req)
go lst.handleSocksConn(ctx, conn)
}
}
Expand All @@ -184,19 +210,46 @@ func (lst *Listener) Start() error {
return nil // already started
}
// TODO(bassosimone): be able to recover when SOCKS dies?
ln, err := pt.ListenSocks("tcp", "127.0.0.1:0")
ln, err := lst.listenSocks("tcp", "127.0.0.1:0")
if err != nil {
return err
}
lst.laddr = ln.Addr()
ctx, cancel := context.WithCancel(context.Background())
lst.cancel = cancel
lst.listener = ln
go lst.acceptLoop(ctx, ln)
lst.logger().Infof("ptx: started socks listener at %v", ln.Addr())
lst.logger().Debugf("ptx: test with `%s`", lst.torCmdLine())
return nil
}

// listenSocks calles either pt.ListenSocks or lst.overrideListenSocks.
func (lst *Listener) listenSocks(network string, laddr string) (ptxSocksListener, error) {
if lst.overrideListenSocks != nil {
return lst.overrideListenSocks(network, laddr)
}
return lst.castListener(pt.ListenSocks(network, laddr))
}

// castListener casts a pt.SocksListener to ptxSocksListener.
func (lst *Listener) castListener(in *pt.SocksListener, err error) (ptxSocksListener, error) {
if err != nil {
return nil, err
}
return &ptxSocksListenerAdapter{in}, nil
}

// ptxSocksListenerAdapter adapts pt.SocksListener to ptxSocksListener.
type ptxSocksListenerAdapter struct {
*pt.SocksListener
}

// AcceptSocks adapts pt.SocksListener.AcceptSocks to ptxSockListener.AcceptSocks.
func (la *ptxSocksListenerAdapter) AcceptSocks() (ptxSocksConn, error) {
return la.SocksListener.AcceptSocks()
}

// torCmdLine prints the command line for testing this listener.
func (lst *Listener) torCmdLine() string {
return strings.Join([]string{
Expand All @@ -208,19 +261,22 @@ func (lst *Listener) torCmdLine() string {
"ClientTransportPlugin",
"'" + lst.AsClientTransportPluginArgument() + "'",
"Bridge",
"'" + lst.ContextDialer.AsBridgeArgument() + "'",
"'" + lst.PTDialer.AsBridgeArgument() + "'",
}, " ")
}

// Stop stops the pluggable transport. This method is idempotent
// and asks the background goroutine to stop just once. Also, this
// and asks the background goroutine(s) to stop just once. Also, this
// method is safe to call from any goroutine.
func (lst *Listener) Stop() {
defer lst.mu.Unlock()
lst.mu.Lock()
if lst.cancel != nil {
lst.cancel() // cancel is idempotent
}
if lst.listener != nil {
lst.listener.Close() // should be idempotent
}
}

// AsClientTransportPluginArgument converts the current configuration
Expand All @@ -242,5 +298,5 @@ func (lst *Listener) Stop() {
// in the returned string. In fact, ClientTransportPlugin and its
// arguments need to be two consecutive argv strings.
func (lst *Listener) AsClientTransportPluginArgument() string {
return fmt.Sprintf("%s socks5 %s", lst.ContextDialer.Name(), lst.laddr.String())
return fmt.Sprintf("%s socks5 %s", lst.PTDialer.Name(), lst.laddr.String())
}
Loading

0 comments on commit e3f5c01

Please sign in to comment.