Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] Add fixed header and write skipping to multiplexer #35859

Merged
merged 4 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/kube/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ func (t *TLSServer) Serve(listener net.Listener, options ...ServeOption) error {
// It's required to accommodate setups with high latency and where the time
// between the TCP being accepted and the time for the first byte is longer
// than the default value - 1s.
ReadDeadline: 10 * time.Second,
DetectTimeout: 10 * time.Second,
}
for _, opt := range options {
opt(&muxConfig)
Expand Down
42 changes: 29 additions & 13 deletions lib/multiplexer/multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ type Config struct {
Listener net.Listener
// Context is a context to signal stops, cancellations
Context context.Context
// ReadDeadline is a connection read deadline,
// set to defaults.ReadHeadersTimeout if unspecified
ReadDeadline time.Duration
// DetectTimeout is a timeout applied to the whole detection phase of the
// connection, set to defaults.ReadHeadersTimeout if unspecified
DetectTimeout time.Duration
// Clock is a clock to override in tests, set to real time clock
// by default
Clock clockwork.Clock
Expand All @@ -98,6 +98,12 @@ type Config struct {
// connection (coming from same IP as the listening address) when deciding if it should drop connection with
// missing required PROXY header. This is needed since all connections in tests are self connections.
IgnoreSelfConnections bool

// FixedHeader contains data that's sent to the client at the beginning of
// every connection, before protocol detection happens. An equal amount of
// data is then skipped from the connection when the application writes into
// it. Mostly useful for SSH servers.
FixedHeader string
}

// CheckAndSetDefaults verifies configuration and sets defaults
Expand All @@ -108,8 +114,8 @@ func (c *Config) CheckAndSetDefaults() error {
if c.Context == nil {
c.Context = context.TODO()
}
if c.ReadDeadline == 0 {
c.ReadDeadline = defaults.ReadHeadersTimeout
if c.DetectTimeout == 0 {
c.DetectTimeout = defaults.ReadHeadersTimeout
}
if c.Clock == nil {
c.Clock = clockwork.NewRealClock()
Expand Down Expand Up @@ -277,13 +283,22 @@ func (m *Mux) protocolListener(proto Protocol) *Listener {
// protocol without a registered protocol listener are closed. This
// method is called as a goroutine by Serve for each connection.
func (m *Mux) detectAndForward(conn net.Conn) {
err := conn.SetReadDeadline(m.Clock.Now().Add(m.ReadDeadline))
if err != nil {
if err := conn.SetDeadline(m.Clock.Now().Add(m.DetectTimeout)); err != nil {
m.Warning(err.Error())
conn.Close()
return
}

if m.FixedHeader != "" {
if _, err := conn.Write([]byte(m.FixedHeader)); err != nil {
if !utils.IsOKNetworkError(err) {
m.WithError(err).Warn("Failed to send connection header.")
}
conn.Close()
return
}
}

connWrapper, err := m.detect(conn)
if err != nil {
if trace.Unwrap(err) != io.EOF {
Expand All @@ -295,8 +310,8 @@ func (m *Mux) detectAndForward(conn net.Conn) {
conn.Close()
return
}
err = conn.SetReadDeadline(time.Time{})
if err != nil {

if err := connWrapper.SetDeadline(time.Time{}); err != nil {
m.Warning(trace.DebugReport(err))
connWrapper.Close()
return
Expand Down Expand Up @@ -568,10 +583,11 @@ func (m *Mux) detect(conn net.Conn) (*Conn, error) {
}

return &Conn{
protocol: proto,
Conn: conn,
reader: reader,
proxyLine: proxyLine,
protocol: proto,
Conn: conn,
reader: reader,
proxyLine: proxyLine,
alreadyWritten: []byte(m.FixedHeader),
}, nil
}
}
Expand Down
52 changes: 51 additions & 1 deletion lib/multiplexer/multiplexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ func TestMux(t *testing.T) {
Listener: listener,
// Set read deadline in the past to remove reliance on real time
// and simulate scenario when read deadline has elapsed.
ReadDeadline: -time.Millisecond,
DetectTimeout: -time.Millisecond,
}
mux, err := New(config)
require.NoError(t, err)
Expand Down Expand Up @@ -1393,3 +1393,53 @@ func TestIsDifferentTCPVersion(t *testing.T) {
fmt.Sprintf("Unexpected result for %q, %q", tt.addr1, tt.addr2))
}
}

func TestFixedHeader(t *testing.T) {
t.Parallel()
require := require.New(t)

listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(err)
t.Cleanup(func() { listener.Close() })

const defaultSSHVersionIdentifier = "SSH-2.0-Go\r\n"
mux, err := New(Config{
Listener: listener,
FixedHeader: defaultSSHVersionIdentifier,
})
require.NoError(err)
t.Cleanup(func() { mux.Close() })
go mux.Serve()

go startSSHServer(t, mux.SSH())

netConn, err := net.DialTimeout(listener.Addr().Network(), listener.Addr().String(), 5*time.Second)
require.NoError(err)
t.Cleanup(func() { netConn.Close() })

// the SSH transport layer protocol rfc (5423) states that SSH servers must
// send a version string immediately after the connection is established, so
// we expect (a specific) version string without sending anything
buf := make([]byte, len(defaultSSHVersionIdentifier))
_, err = io.ReadFull(netConn, buf)
require.NoError(err)
require.Equal(defaultSSHVersionIdentifier, string(buf))

// the SSH server hasn't even been touched yet, so we can connect to it from
// a separate connection (we have to, in fact, or startSSHServer will fail
// the test)

sshClient, err := ssh.Dial(listener.Addr().Network(), listener.Addr().String(), &ssh.ClientConfig{
User: "bob",
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
})
require.NoError(err)
t.Cleanup(func() { sshClient.Close() })

const payload = "this is a bit useless since we already went through a full handshake"
ok, echoReply, err := sshClient.Conn.SendRequest("echo", true, []byte(payload))
require.NoError(err)
require.True(ok)
require.Equal(payload, string(echoReply))
}
31 changes: 31 additions & 0 deletions lib/multiplexer/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -124,3 +125,33 @@ func TestPROXYEnabledListener_Accept(t *testing.T) {
})
}
}

func TestAlreadyWritten(t *testing.T) {
require := require.New(t)

c := &Conn{
Conn: zeroConn{},
alreadyWritten: []byte("aa"),
}

n, err := c.Write([]byte("a"))
require.NoError(err)
require.Equal(1, n)
require.Equal([]byte("a"), c.alreadyWritten)

n, err = c.Write([]byte("b"))
require.Error(err)
require.ErrorAs(err, new(*trace.BadParameterError))
require.Equal(0, n)

n, err = c.Write([]byte("ab"))
require.NoError(err)
require.Equal(2, n)
require.Empty(c.alreadyWritten)
}

type zeroConn struct{ net.Conn }

func (zeroConn) Write(p []byte) (int, error) {
return len(p), nil
}
31 changes: 31 additions & 0 deletions lib/multiplexer/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package multiplexer

import (
"bufio"
"bytes"
"context"
"net"

Expand All @@ -35,6 +36,11 @@ type Conn struct {
protocol Protocol
proxyLine *ProxyLine
reader *bufio.Reader

// alreadyWritten is a slice of data that we expect the application to
// Write() on the connection (because it was already sent on the wire). As
// the application writes, the slice gets smaller.
alreadyWritten []byte
}

// NewConn returns a net.Conn wrapper that supports peeking into the connection.
Expand All @@ -55,6 +61,31 @@ func (c *Conn) Read(p []byte) (int, error) {
return c.reader.Read(p)
}

// Write implements [io.Writer] and [net.Conn].
func (c *Conn) Write(p []byte) (int, error) {
if len(c.alreadyWritten) < 1 {
return c.Conn.Write(p)
}

s := min(len(p), len(c.alreadyWritten))
if !bytes.Equal(p[:s], c.alreadyWritten[:s]) {
return 0, trace.BadParameter("new application data doesn't match already written data (this is a bug)")
}

// we should do the write even if it's zero-length to check that the
// connection is still open and that we're not past the write deadline
n, err := c.Conn.Write(p[s:])
if n > 0 || err == nil {
n += s
c.alreadyWritten = c.alreadyWritten[s:]
if len(c.alreadyWritten) < 1 {
c.alreadyWritten = nil
}
}

return n, trace.Wrap(err)
}

// LocalAddr returns local address of the connection
func (c *Conn) LocalAddr() net.Addr {
if c.proxyLine != nil {
Expand Down
2 changes: 2 additions & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ import (
"github.com/gravitational/teleport/lib/srv/ingress"
"github.com/gravitational/teleport/lib/srv/regular"
"github.com/gravitational/teleport/lib/srv/transport/transportv1"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/system"
usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -2662,6 +2663,7 @@ func (process *TeleportProcess) initSSH() error {
ID: teleport.Component(teleport.ComponentNode, process.id),
CertAuthorityGetter: authClient.GetCertAuthority,
LocalClusterName: conn.ServerIdentity.ClusterName,
FixedHeader: sshutils.SSHVersionPrefix + "\r\n",
})
if err != nil {
return trace.Wrap(err)
Expand Down