Skip to content

Commit

Permalink
feat: add GPG forwarding to coder ssh (#5482)
Browse files Browse the repository at this point in the history
  • Loading branch information
deansheather committed Jan 6, 2023
1 parent 59e919a commit f1fe2b5
Show file tree
Hide file tree
Showing 12 changed files with 1,050 additions and 21 deletions.
14 changes: 10 additions & 4 deletions agent/agent.go
Expand Up @@ -480,12 +480,16 @@ func (a *agent) init(ctx context.Context) {
if err != nil {
panic(err)
}

sshLogger := a.logger.Named("ssh-server")
forwardHandler := &ssh.ForwardedTCPHandler{}
unixForwardHandler := &forwardedUnixHandler{log: a.logger}

a.sshServer = &ssh.Server{
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": ssh.DirectTCPIPHandler,
"session": ssh.DefaultSessionHandler,
"direct-tcpip": ssh.DirectTCPIPHandler,
"direct-streamlocal@openssh.com": directStreamLocalHandler,
"session": ssh.DefaultSessionHandler,
},
ConnectionFailedCallback: func(conn net.Conn, err error) {
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
Expand Down Expand Up @@ -525,8 +529,10 @@ func (a *agent) init(ctx context.Context) {
return true
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": forwardHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
"tcpip-forward": forwardHandler.HandleSSHRequest,
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest,
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
},
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
return &gossh.ServerConfig{
Expand Down
255 changes: 248 additions & 7 deletions agent/agent_test.go
Expand Up @@ -273,7 +273,7 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) {
}

//nolint:paralleltest // This test reserves a port.
func TestAgent_LocalForwarding(t *testing.T) {
func TestAgent_TCPLocalForwarding(t *testing.T) {
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
_ = random.Close()
Expand All @@ -286,24 +286,239 @@ func TestAgent_LocalForwarding(t *testing.T) {
defer local.Close()
tcpAddr, valid = local.Addr().(*net.TCPAddr)
require.True(t, valid)
localPort := tcpAddr.Port
remotePort := tcpAddr.Port
done := make(chan struct{})
go func() {
defer close(done)
conn, err := local.Accept()
if !assert.NoError(t, err) {
return
}
_ = conn.Close()
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()

err = setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, localPort)}, []string{"echo", "test"}).Start()
cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "10"})
err = cmd.Start()
require.NoError(t, err)

require.Eventually(t, func() bool {
conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort))
if err != nil {
return false
}
defer conn.Close()
_, err = conn.Write([]byte("test"))
if !assert.NoError(t, err) {
return false
}
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return false
}
if !assert.Equal(t, "test", string(b)) {
return false
}

return true
}, testutil.WaitLong, testutil.IntervalSlow)

<-done

_ = cmd.Process.Kill()
}

//nolint:paralleltest // This test reserves a port.
func TestAgent_TCPRemoteForwarding(t *testing.T) {
random, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
_ = random.Close()
tcpAddr, valid := random.Addr().(*net.TCPAddr)
require.True(t, valid)
randomPort := tcpAddr.Port

conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(localPort))
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
conn.Close()
defer l.Close()
tcpAddr, valid = l.Addr().(*net.TCPAddr)
require.True(t, valid)
localPort := tcpAddr.Port

done := make(chan struct{})
go func() {
defer close(done)

conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()

cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "10"})
err = cmd.Start()
require.NoError(t, err)

require.Eventually(t, func() bool {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort))
if err != nil {
return false
}
defer conn.Close()
_, err = conn.Write([]byte("test"))
if !assert.NoError(t, err) {
return false
}
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return false
}
if !assert.Equal(t, "test", string(b)) {
return false
}

return true
}, testutil.WaitLong, testutil.IntervalSlow)

<-done

_ = cmd.Process.Kill()
}

func TestAgent_UnixLocalForwarding(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}

tmpdir := tempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
localSocketPath := filepath.Join(tmpdir, "local-socket")

l, err := net.Listen("unix", remoteSocketPath)
require.NoError(t, err)
defer l.Close()

done := make(chan struct{})
go func() {
defer close(done)

conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()

cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "10"})
err = cmd.Start()
require.NoError(t, err)

require.Eventually(t, func() bool {
_, err := os.Stat(localSocketPath)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)

conn, err := net.Dial("unix", localSocketPath)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("test"))
require.NoError(t, err)
b := make([]byte, 4)
_, err = conn.Read(b)
require.NoError(t, err)
require.Equal(t, "test", string(b))
_ = conn.Close()
<-done

_ = cmd.Process.Kill()
}

func TestAgent_UnixRemoteForwarding(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}

tmpdir := tempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
localSocketPath := filepath.Join(tmpdir, "local-socket")

l, err := net.Listen("unix", localSocketPath)
require.NoError(t, err)
defer l.Close()

done := make(chan struct{})
go func() {
defer close(done)

conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
b := make([]byte, 4)
_, err = conn.Read(b)
if !assert.NoError(t, err) {
return
}
_, err = conn.Write(b)
if !assert.NoError(t, err) {
return
}
}()

cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "10"})
err = cmd.Start()
require.NoError(t, err)

require.Eventually(t, func() bool {
_, err := os.Stat(remoteSocketPath)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)

conn, err := net.Dial("unix", remoteSocketPath)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("test"))
require.NoError(t, err)
b := make([]byte, 4)
_, err = conn.Read(b)
require.NoError(t, err)
require.Equal(t, "test", string(b))
_ = conn.Close()

<-done

_ = cmd.Process.Kill()
}

func TestAgent_SFTP(t *testing.T) {
Expand Down Expand Up @@ -733,7 +948,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
args := append(beforeArgs,
"-o", "HostName "+tcpAddr.IP.String(),
"-o", "Port "+strconv.Itoa(tcpAddr.Port),
"-o", "StrictHostKeyChecking=no", "host")
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"host",
)
args = append(args, afterArgs...)
return exec.Command("ssh", args...)
}
Expand Down Expand Up @@ -919,3 +1137,26 @@ func (*client) PostWorkspaceAgentAppHealth(_ context.Context, _ codersdk.PostWor
func (*client) PostWorkspaceAgentVersion(_ context.Context, _ string) error {
return nil
}

// tempDirUnixSocket returns a temporary directory that can safely hold unix
// sockets (probably).
//
// During tests on darwin we hit the max path length limit for unix sockets
// pretty easily in the default location, so this function uses /tmp instead to
// get shorter paths.
func tempDirUnixSocket(t *testing.T) string {
t.Helper()
if runtime.GOOS == "darwin" {
testName := strings.ReplaceAll(t.Name(), "/", "_")
dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("coder-test-%s-", testName))
require.NoError(t, err, "create temp dir for gpg test")

t.Cleanup(func() {
err := os.RemoveAll(dir)
assert.NoError(t, err, "remove temp dir", dir)
})
return dir
}

return t.TempDir()
}

0 comments on commit f1fe2b5

Please sign in to comment.