Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] Require SSH prefix in router.DialHost connections #33729

Merged
merged 12 commits into from Oct 25, 2023
41 changes: 41 additions & 0 deletions lib/proxy/router.go
Expand Up @@ -15,6 +15,7 @@
package proxy

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -334,9 +335,49 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.
return nil, trace.Wrap(err)
}

// SSH connection MUST start with "SSH-2.0" bytes according to https://datatracker.ietf.org/doc/html/rfc4253#section-4.2
conn = newCheckedPrefixWriter(conn, []byte("SSH-2.0"))
return NewProxiedMetricConn(conn), trace.Wrap(err)
}

// checkedPrefixWriter checks that first data written into it has the specified prefix.
type checkedPrefixWriter struct {
net.Conn

requiredPrefix []byte
requiredPointer int
}

func newCheckedPrefixWriter(conn net.Conn, requiredPrefix []byte) *checkedPrefixWriter {
return &checkedPrefixWriter{
Conn: conn,
requiredPrefix: requiredPrefix,
}
}

// Write writes data into connection, checking if it has required prefix. Not safe for concurrent calls.
func (c *checkedPrefixWriter) Write(p []byte) (int, error) {
// If pointer reached end of required prefix the check is done
if len(c.requiredPrefix) == c.requiredPointer {
n, err := c.Conn.Write(p)
return n, trace.Wrap(err)
}

// Decide which is smaller, provided data or remaining portion of the required prefix
small, big := c.requiredPrefix[c.requiredPointer:], p
if len(small) > len(big) {
big, small = small, big
}

if !bytes.HasPrefix(big, small) {
return 0, trace.AccessDenied("required prefix %q was not found", c.requiredPrefix)
}
n, err := c.Conn.Write(p)
// Advance pointer by confirmed portion of the prefix.
c.requiredPointer += min(n, len(small))
return n, trace.Wrap(err)
}

// getRemoteCluster looks up the provided clusterName to determine if a remote site exists with
// that name and determines if the user has access to it.
func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, checker services.AccessChecker) (reversetunnelclient.RemoteSite, error) {
Expand Down
82 changes: 82 additions & 0 deletions lib/proxy/router_test.go
Expand Up @@ -15,6 +15,7 @@
package proxy

import (
"bytes"
"context"
"math/rand"
"net"
Expand Down Expand Up @@ -346,6 +347,87 @@ func serverResolver(srv types.Server, err error) serverResolverFn {
}
}

type mockConn struct {
net.Conn
buff bytes.Buffer
}

func (o *mockConn) Read(p []byte) (n int, err error) {
return o.buff.Read(p)
}

func (o *mockConn) Write(p []byte) (n int, err error) {
return o.buff.Write(p)
}

func (o *mockConn) Close() error {
return nil
}

func TestCheckedPrefixWriter(t *testing.T) {
t.Parallel()
testData := []byte("test data")
t.Run("missing prefix", func(t *testing.T) {
t.Run("single write", func(t *testing.T) {
cpw := newCheckedPrefixWriter(&mockConn{}, []byte("wrong"))

_, err := cpw.Write(testData)
require.True(t, trace.IsAccessDenied(err), "expected trace.AccessDenied error, got: %v", err)
})
t.Run("two writes", func(t *testing.T) {
cpw := newCheckedPrefixWriter(&mockConn{}, append(testData, []byte("wrong")...))

_, err := cpw.Write(testData)
require.NoError(t, err)

_, err = cpw.Write(testData)
require.True(t, trace.IsAccessDenied(err), "expected trace.AccessDenied error, got: %v", err)
})
})
t.Run("success", func(t *testing.T) {
t.Run("single write", func(t *testing.T) {
cpw := newCheckedPrefixWriter(&mockConn{}, []byte("test"))

// First write with correct prefix should be successful
_, err := cpw.Write(testData)
require.NoError(t, err)

// Write some additional data
secondData := []byte("second data")
_, err = cpw.Write(secondData)
require.NoError(t, err)

// Resulting read should contain data from both writes
buf := make([]byte, len(testData)+len(secondData))
_, err = cpw.Read(buf)
require.NoError(t, err)
require.Equal(t, append(testData, secondData...), buf)
})
t.Run("two writes", func(t *testing.T) {
cpw := newCheckedPrefixWriter(&mockConn{}, []byte("test"))

// First write gives part of correct prefix
_, err := cpw.Write(testData[:3])
require.NoError(t, err)

// Second write gives the rest of correct prefix
_, err = cpw.Write(testData[3:])
require.NoError(t, err)

// Write some additional data
secondData := []byte("second data")
_, err = cpw.Write(secondData)
require.NoError(t, err)

// Resulting read should contain all written data
buf := make([]byte, len(testData)+len(secondData))
_, err = cpw.Read(buf)
require.NoError(t, err)
require.Equal(t, append(testData, secondData...), buf)
})
})
}

type tunnel struct {
reversetunnelclient.Tunnel

Expand Down