Skip to content

Commit

Permalink
Merge pull request #42 from getlantern/ox/wrapped
Browse files Browse the repository at this point in the history
  • Loading branch information
oxtoacart committed Mar 2, 2023
2 parents 624088a + 1c13163 commit 6e479a5
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 39 deletions.
4 changes: 4 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ func newConn(c ptlshs.Conn, cfg *tls.Config, isClient bool, preshared ptlshs.Sec
return &conn{c, cfg, isClient, preshared, sync.Once{}, nil}
}

func (c *conn) Wrapped() net.Conn {
return c.Conn
}

func (c *conn) Read(b []byte) (n int, err error) {
if err := c.Handshake(); err != nil {
return 0, fmt.Errorf("handshake failed: %w", err)
Expand Down
11 changes: 10 additions & 1 deletion hijack.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func hijack(conn ptlshs.Conn, cfg *tls.Config, preshared ptlshs.Secret, client b
// successfully hijacked and further communication will be conducted with the appropriate
// version and suite, but newly-negotiated symmetric keys.
disguisedConn.shedDisguise()
return hijacked, nil
return &tlsConn{hijacked, conn}, nil
}

func ensureParameters(cfg *tls.Config, conn ptlshs.Conn) (*tls.Config, error) {
Expand Down Expand Up @@ -170,3 +170,12 @@ func (dc *disguisedConn) Write(b []byte) (n int, err error) {
}
return
}

type tlsConn struct {
*tls.Conn
wrapped net.Conn
}

func (conn *tlsConn) Wrapped() net.Conn {
return conn.wrapped
}
120 changes: 82 additions & 38 deletions ptlshs/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ func (s *connState) nextSeq() [8]byte {

type clientConn struct {
// The underlying connection to the server. This is likely just a TCP connection.
net.Conn
wrapped net.Conn
wrappedLock sync.RWMutex

cfg DialerConfig

Expand All @@ -91,21 +92,58 @@ type clientConn struct {
// Client initializes a client-side connection.
func Client(toServer net.Conn, cfg DialerConfig) Conn {
cfg = cfg.withDefaults()
return &clientConn{toServer, cfg, nil, nil, newOnce(), newOnce()}
return &clientConn{
wrapped: toServer,
cfg: cfg,
shakeOnce: newOnce(),
closeOnce: newOnce(),
}
}

func (c *clientConn) Wrapped() net.Conn {
c.wrappedLock.RLock()
defer c.wrappedLock.RUnlock()
return c.wrapped
}

func (c *clientConn) setWrapped(conn net.Conn) {
c.wrappedLock.Lock()
c.wrapped = conn
c.wrappedLock.Unlock()
}

func (c *clientConn) LocalAddr() net.Addr {
return c.Wrapped().LocalAddr()
}

func (c *clientConn) RemoteAddr() net.Addr {
return c.Wrapped().RemoteAddr()
}

func (c *clientConn) SetDeadline(t time.Time) error {
return c.Wrapped().SetDeadline(t)
}

func (c *clientConn) SetReadDeadline(t time.Time) error {
return c.Wrapped().SetReadDeadline(t)
}

func (c *clientConn) SetWriteDeadline(t time.Time) error {
return c.Wrapped().SetWriteDeadline(t)
}

func (c *clientConn) Read(b []byte) (n int, err error) {
if err := c.Handshake(); err != nil {
return 0, wrapError("handshake failed", err)
}
return c.Conn.Read(b)
return c.Wrapped().Read(b)
}

func (c *clientConn) Write(b []byte) (n int, err error) {
if err := c.Handshake(); err != nil {
return 0, wrapError("handshake failed", err)
}
return c.Conn.Write(b)
return c.Wrapped().Write(b)
}

// Handshake performs the ptlshs handshake protocol, if it has not yet been performed. Note that,
Expand Down Expand Up @@ -149,11 +187,11 @@ func (c *clientConn) handshake() error {
}
defer func() { transcriptDone = true }()

originalConn := c.Conn
c.Conn = mitm(c.Conn, onClientRead, nil)
defer func() { c.Conn = originalConn }()
originalConn := c.Wrapped()
c.setWrapped(mitm(originalConn, onClientRead, nil))
defer func() { c.setWrapped(originalConn) }()

hsResult, err := c.cfg.Handshaker.Handshake(c.Conn)
hsResult, err := c.cfg.Handshaker.Handshake(c.Wrapped())
if err != nil {
return err
}
Expand All @@ -173,7 +211,7 @@ func (c *clientConn) handshake() error {
if err != nil {
return fmt.Errorf("failed to create completion signal: %w", err)
}
if _, err = tlsutil.WriteRecord(c.Conn, *signal, tlsState); err != nil {
if _, err = tlsutil.WriteRecord(c.Wrapped(), *signal, tlsState); err != nil {
return wrapError("failed to signal completion", err)
}
// The watchForCompletionFunction needs direct control over what is written to the transcript.
Expand Down Expand Up @@ -203,7 +241,7 @@ func (c *clientConn) watchForCompletion(tlsState *tlsutil.ConnectionState, trans
readBuf.Write(b)
return nil
}
conn := mitm(c.Conn, onRead, nil)
conn := mitm(c.Wrapped(), onRead, nil)

// We attempt to decrypt every record we see from the server. We assume that any records we are
// unable to decrypt must have come from the origin. The first record we successfully decrypt
Expand Down Expand Up @@ -240,7 +278,7 @@ func (c *clientConn) watchForCompletion(tlsState *tlsutil.ConnectionState, trans
return fmt.Errorf("server signal contains bad transcript MAC")
}
// Put unprocessed post-signal data back on the connection.
c.Conn = preconn.WrapReader(c.Conn, unprocessedBuf)
c.setWrapped(preconn.WrapReader(c.Wrapped(), unprocessedBuf))
return nil
}
}
Expand Down Expand Up @@ -279,7 +317,7 @@ func (c *clientConn) IV() [16]byte {

func (c *clientConn) Close() error {
return c.closeOnce.do(func() error {
return c.Conn.Close()
return c.Wrapped().Close()
})
}

Expand Down Expand Up @@ -317,7 +355,7 @@ func serverConnWithCache(toClient net.Conn, cfg ListenerConfig, cache *nonceCach
newOnce(), newOnce(), newDeadline(), newDeadline(), sync.Mutex{}}
}

func (c *serverConn) getWrapped() net.Conn {
func (c *serverConn) Wrapped() net.Conn {
c.wrappedLock.RLock()
defer c.wrappedLock.RUnlock()
return c.wrapped
Expand All @@ -337,7 +375,7 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
if err != nil {
return 0, fmt.Errorf("handshake failed: %w", err)
}
return c.getWrapped().Read(b)
return c.Wrapped().Read(b)
case <-c.hsReadDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
Expand All @@ -351,7 +389,7 @@ func (c *serverConn) Write(b []byte) (n int, err error) {
if err != nil {
return 0, fmt.Errorf("handshake failed: %w", err)
}
return c.getWrapped().Write(b)
return c.Wrapped().Write(b)
case <-c.hsWriteDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
Expand All @@ -362,7 +400,7 @@ func (c *serverConn) Close() error {
if c.closeCache {
c.nonceCache.close()
}
return c.getWrapped().Close()
return c.Wrapped().Close()
})
}

Expand All @@ -376,7 +414,7 @@ func (c *serverConn) SetDeadline(t time.Time) error {
c.hsWriteDeadline.set(t)
return nil
}
return c.getWrapped().SetDeadline(t)
return c.Wrapped().SetDeadline(t)
}

func (c *serverConn) SetReadDeadline(t time.Time) error {
Expand All @@ -387,7 +425,7 @@ func (c *serverConn) SetReadDeadline(t time.Time) error {
c.hsReadDeadline.set(t)
return nil
}
return c.getWrapped().SetReadDeadline(t)
return c.Wrapped().SetReadDeadline(t)
}

func (c *serverConn) SetWriteDeadline(t time.Time) error {
Expand All @@ -398,15 +436,15 @@ func (c *serverConn) SetWriteDeadline(t time.Time) error {
c.hsWriteDeadline.set(t)
return nil
}
return c.getWrapped().SetWriteDeadline(t)
return c.Wrapped().SetWriteDeadline(t)
}

func (c *serverConn) LocalAddr() net.Addr {
return c.getWrapped().LocalAddr()
return c.Wrapped().LocalAddr()
}

func (c *serverConn) RemoteAddr() net.Addr {
return c.getWrapped().RemoteAddr()
return c.Wrapped().RemoteAddr()
}

// Handshake performs the ptlshs handshake protocol, if it has not yet been performed. Note that,
Expand Down Expand Up @@ -449,23 +487,23 @@ func (c *serverConn) handshake() error {

transcriptHMAC := signalHMAC(c.cfg.Secret)
transcriptLock := new(sync.Mutex)
originalWrapped := c.getWrapped()
originalWrapped := c.Wrapped()
onClientWrite := func(b []byte) error {
transcriptLock.Lock()
transcriptHMAC.Write(b)
transcriptLock.Unlock()
return nil
}
c.setWrapped(mitm(c.getWrapped(), nil, onClientWrite))
c.setWrapped(mitm(c.Wrapped(), nil, onClientWrite))
stopTranscript := func() { c.setWrapped(originalWrapped) }
defer stopTranscript() // may end up a no-op, but that's okay

// Read and copy ClientHello.
b, err := readClientHello(ctx, c.getWrapped(), listenerReadBufferSize)
b, err := readClientHello(ctx, c.Wrapped(), listenerReadBufferSize)
if err != nil && !errors.As(err, new(networkError)) {
// Client sent something other than ClientHello. Proxy everything to match origin behavior.
stopTranscript()
proxyUntilClose(ctx, preconn.Wrap(c.getWrapped(), b), origin)
proxyUntilClose(ctx, preconn.Wrap(c.Wrapped(), b), origin)
return fmt.Errorf("did not receive ClientHello: %w", err)
}
if err != nil {
Expand All @@ -481,7 +519,7 @@ func (c *serverConn) handshake() error {
if err != nil && !errors.As(err, new(networkError)) {
// Origin sent something other than ServerHello. Proxy everything to match origin behavior.
stopTranscript()
proxyUntilClose(ctx, c.getWrapped(), preconn.Wrap(origin, b))
proxyUntilClose(ctx, c.Wrapped(), preconn.Wrap(origin, b))
return fmt.Errorf("did not receive ServerHello: %w", err)
}
if err != nil {
Expand All @@ -492,7 +530,7 @@ func (c *serverConn) handshake() error {
if err != nil {
return fmt.Errorf("failed to init conn state based on hello info: %w", err)
}
_, err = c.getWrapped().Write(b)
_, err = c.Wrapped().Write(b)
if err != nil {
return fmt.Errorf("failed to write to client: %w", err)
}
Expand All @@ -510,15 +548,15 @@ func (c *serverConn) handshake() error {
return fmt.Errorf("failed to create completion signal: %w", err)
}

_, err = tlsutil.WriteRecord(c.getWrapped(), *signal, tlsState)
_, err = tlsutil.WriteRecord(c.Wrapped(), *signal, tlsState)
if err != nil {
return fmt.Errorf("failed to signal completion: %w", err)
}

// Transfer handshake deadlines to the underlying connection.
c.deadlinesLock.Lock()
c.getWrapped().SetReadDeadline(c.hsReadDeadline.get())
c.getWrapped().SetWriteDeadline(c.hsWriteDeadline.get())
c.Wrapped().SetReadDeadline(c.hsReadDeadline.get())
c.Wrapped().SetWriteDeadline(c.hsWriteDeadline.get())
c.hsReadDeadline.close()
c.hsWriteDeadline.close()
c.deadlinesLock.Unlock()
Expand All @@ -537,7 +575,7 @@ func (c *serverConn) watchForCompletion(ctx context.Context, bufferSize int,
// state. If we see the signal, we close the connection with the origin. Otherwise, we continue
// to proxy.

toClient, toOrigin := newCancelConn(c.getWrapped()), newCancelConn(originConn)
toClient, toOrigin := newCancelConn(c.Wrapped()), newCancelConn(originConn)

nonFatalError := func(err error) {
select {
Expand Down Expand Up @@ -577,7 +615,7 @@ func (c *serverConn) watchForCompletion(ctx context.Context, bufferSize int,
// We also need to ensure the unprocessed post-signal data is not lost. We prepend it to
// the client connection. Access to c.Conn is single-threaded until the handshake is
// complete, so this is safe to do without synchronization.
c.setWrapped(preconn.Wrap(c.getWrapped(), postSignal))
c.setWrapped(preconn.Wrap(c.Wrapped(), postSignal))
}

rr := new(recordReader)
Expand Down Expand Up @@ -666,9 +704,10 @@ func (c *serverConn) IV() [16]byte {
}

// Continuously reads off of conn until one of the following:
// - The bytes read constitute a valid TLS ClientHello.
// - The bytes read could not possibly constitute a valid TLS ClientHello.
// - A non-temporary network error is encountered.
// - The bytes read constitute a valid TLS ClientHello.
// - The bytes read could not possibly constitute a valid TLS ClientHello.
// - A non-temporary network error is encountered.
//
// Whatever was read is always returned.
func readClientHello(ctx context.Context, conn net.Conn, bufferSize int) ([]byte, error) {
var (
Expand Down Expand Up @@ -705,9 +744,10 @@ func readClientHello(ctx context.Context, conn net.Conn, bufferSize int) ([]byte
}

// Continuously reads off of conn until one of the following:
// - The bytes read constitute a valid TLS ServerHello.
// - The bytes read could not possibly constitute a valid TLS ServerHello.
// - A non-temporary network error is encountered.
// - The bytes read constitute a valid TLS ServerHello.
// - The bytes read could not possibly constitute a valid TLS ServerHello.
// - A non-temporary network error is encountered.
//
// Whatever was read is always returned. When a valid ServerHello is read, it is parsed and used to
// create a connection state.
func readServerHello(ctx context.Context, conn net.Conn, bufferSize int) ([]byte, *connState, error) {
Expand Down Expand Up @@ -788,6 +828,10 @@ func mitm(conn net.Conn, onRead, onWrite func([]byte) error) mitmConn {
return mitmConn{conn, onRead, onWrite}
}

func (c mitmConn) Wrapped() net.Conn {
return c.Conn
}

func (c mitmConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
if n > 0 {
Expand Down

0 comments on commit 6e479a5

Please sign in to comment.