Skip to content

Commit

Permalink
Require SSH prefix in proxySubsys connections
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonAM committed Oct 5, 2023
1 parent e0c9b35 commit 9be776d
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 1 deletion.
39 changes: 38 additions & 1 deletion lib/srv/regular/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ limitations under the License.
package regular

import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"strings"

Expand Down Expand Up @@ -260,12 +263,46 @@ func (t *proxySubsys) proxyToHost(ctx context.Context, ch ssh.Channel, clientSrc
}

go func() {
t.close(utils.ProxyConn(ctx, ch, conn))
t.close(utils.ProxyConn(ctx, &checkedPrefixReader{
ReadWriteCloser: ch,
// SSH connection MUST start with "SSH-2.0" bytes according to https://datatracker.ietf.org/doc/html/rfc4253#section-4.2
requiredPrefix: []byte("SSH-2.0"),
}, conn))
}()

return nil
}

// checkedPrefixReader checks that read data has the specified prefix.
type checkedPrefixReader struct {
io.ReadWriteCloser
requiredPrefix []byte

// upstreamReader reads from the underlying reader.
upstreamReader io.Reader
}

func (c *checkedPrefixReader) Read(b []byte) (int, error) {
// connection was already checked, forward upstream:
if c.upstreamReader != nil {
return c.upstreamReader.Read(b)
}

// make sure first bytes are equal to the required prefix
reader := bufio.NewReader(c.ReadWriteCloser)
data, err := reader.Peek(len(c.requiredPrefix))
if err != nil {
return 0, trace.Wrap(err)
}

if !bytes.Equal(data, c.requiredPrefix) {
return 0, trace.AccessDenied("required prefix %q was not found", string(c.requiredPrefix))
}

c.upstreamReader = reader
return c.upstreamReader.Read(b)
}

func (t *proxySubsys) close(err error) {
t.closeC <- err
}
Expand Down
44 changes: 44 additions & 0 deletions lib/srv/regular/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package regular

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -126,3 +128,45 @@ func TestParseBadRequests(t *testing.T) {
})
}
}

func TestCheckedPrefixReader(t *testing.T) {
getMockReadWriteCloser := func(data []byte) io.ReadWriteCloser {
buf := bytes.NewBuffer(data)
return struct {
io.ReadWriter
io.Closer
}{
ReadWriter: buf,
Closer: io.NopCloser(buf),
}
}

testData := []byte("test data")
t.Run("missing prefix", func(t *testing.T) {
cr := checkedPrefixReader{
ReadWriteCloser: getMockReadWriteCloser(testData),
requiredPrefix: []byte("wrong"),
}

_, err := io.ReadAll(&cr)
require.Error(t, err)
})
t.Run("success", func(t *testing.T) {
cr := checkedPrefixReader{
ReadWriteCloser: getMockReadWriteCloser(testData),
requiredPrefix: []byte("test"),
}

res, err := io.ReadAll(&cr)
require.NoError(t, err)
require.Equal(t, testData, res)

secondData := []byte("second data")
_, err = cr.Write(secondData)
require.NoError(t, err)

res, err = io.ReadAll(&cr)
require.NoError(t, err)
require.Equal(t, secondData, res)
})
}

0 comments on commit 9be776d

Please sign in to comment.