From f1fe2b5c06c662079a7c6651a8ad520db7f401a8 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 6 Jan 2023 01:52:19 -0600 Subject: [PATCH] feat: add GPG forwarding to coder ssh (#5482) --- agent/agent.go | 14 +- agent/agent_test.go | 255 ++++++++++++++++++++++++++- agent/ssh.go | 203 +++++++++++++++++++++ cli/server_test.go | 9 +- cli/ssh.go | 211 +++++++++++++++++++++- cli/ssh_other.go | 26 +++ cli/ssh_test.go | 249 +++++++++++++++++++++++++- cli/ssh_windows.go | 78 ++++++++ cli/testdata/coder_ssh_--help.golden | 8 + pty/pty.go | 13 +- pty/start_other.go | 3 + scripts/develop.sh | 2 +- 12 files changed, 1050 insertions(+), 21 deletions(-) create mode 100644 agent/ssh.go diff --git a/agent/agent.go b/agent/agent.go index 53def9472ae8e..dd900700c8913 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -480,12 +480,16 @@ func (a *agent) init(ctx context.Context) { if err != nil { panic(err) } + sshLogger := a.logger.Named("ssh-server") forwardHandler := &ssh.ForwardedTCPHandler{} + unixForwardHandler := &forwardedUnixHandler{log: a.logger} + a.sshServer = &ssh.Server{ ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": ssh.DirectTCPIPHandler, - "session": ssh.DefaultSessionHandler, + "direct-tcpip": ssh.DirectTCPIPHandler, + "direct-streamlocal@openssh.com": directStreamLocalHandler, + "session": ssh.DefaultSessionHandler, }, ConnectionFailedCallback: func(conn net.Conn, err error) { sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) @@ -525,8 +529,10 @@ func (a *agent) init(ctx context.Context) { return true }, RequestHandlers: map[string]ssh.RequestHandler{ - "tcpip-forward": forwardHandler.HandleSSHRequest, - "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, }, ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { return &gossh.ServerConfig{ diff --git a/agent/agent_test.go b/agent/agent_test.go index 25f1f86e95062..44b1ff7bf93ae 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -273,7 +273,7 @@ func TestAgent_Session_TTY_Hushlogin(t *testing.T) { } //nolint:paralleltest // This test reserves a port. -func TestAgent_LocalForwarding(t *testing.T) { +func TestAgent_TCPLocalForwarding(t *testing.T) { random, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) _ = random.Close() @@ -286,7 +286,7 @@ func TestAgent_LocalForwarding(t *testing.T) { defer local.Close() tcpAddr, valid = local.Addr().(*net.TCPAddr) require.True(t, valid) - localPort := tcpAddr.Port + remotePort := tcpAddr.Port done := make(chan struct{}) go func() { defer close(done) @@ -294,16 +294,231 @@ func TestAgent_LocalForwarding(t *testing.T) { if !assert.NoError(t, err) { return } - _ = conn.Close() + defer conn.Close() + b := make([]byte, 4) + _, err = conn.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = conn.Write(b) + if !assert.NoError(t, err) { + return + } }() - err = setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, localPort)}, []string{"echo", "test"}).Start() + cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%d:127.0.0.1:%d", randomPort, remotePort)}, []string{"sleep", "10"}) + err = cmd.Start() + require.NoError(t, err) + + require.Eventually(t, func() bool { + conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(randomPort)) + if err != nil { + return false + } + defer conn.Close() + _, err = conn.Write([]byte("test")) + if !assert.NoError(t, err) { + return false + } + b := make([]byte, 4) + _, err = conn.Read(b) + if !assert.NoError(t, err) { + return false + } + if !assert.Equal(t, "test", string(b)) { + return false + } + + return true + }, testutil.WaitLong, testutil.IntervalSlow) + + <-done + + _ = cmd.Process.Kill() +} + +//nolint:paralleltest // This test reserves a port. +func TestAgent_TCPRemoteForwarding(t *testing.T) { + random, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + _ = random.Close() + tcpAddr, valid := random.Addr().(*net.TCPAddr) + require.True(t, valid) + randomPort := tcpAddr.Port - conn, err := net.Dial("tcp", "127.0.0.1:"+strconv.Itoa(localPort)) + l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - conn.Close() + defer l.Close() + tcpAddr, valid = l.Addr().(*net.TCPAddr) + require.True(t, valid) + localPort := tcpAddr.Port + + done := make(chan struct{}) + go func() { + defer close(done) + + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + b := make([]byte, 4) + _, err = conn.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = conn.Write(b) + if !assert.NoError(t, err) { + return + } + }() + + cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("127.0.0.1:%d:127.0.0.1:%d", randomPort, localPort)}, []string{"sleep", "10"}) + err = cmd.Start() + require.NoError(t, err) + + require.Eventually(t, func() bool { + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", randomPort)) + if err != nil { + return false + } + defer conn.Close() + _, err = conn.Write([]byte("test")) + if !assert.NoError(t, err) { + return false + } + b := make([]byte, 4) + _, err = conn.Read(b) + if !assert.NoError(t, err) { + return false + } + if !assert.Equal(t, "test", string(b)) { + return false + } + + return true + }, testutil.WaitLong, testutil.IntervalSlow) + <-done + + _ = cmd.Process.Kill() +} + +func TestAgent_UnixLocalForwarding(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("unix domain sockets are not fully supported on Windows") + } + + tmpdir := tempDirUnixSocket(t) + remoteSocketPath := filepath.Join(tmpdir, "remote-socket") + localSocketPath := filepath.Join(tmpdir, "local-socket") + + l, err := net.Listen("unix", remoteSocketPath) + require.NoError(t, err) + defer l.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + b := make([]byte, 4) + _, err = conn.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = conn.Write(b) + if !assert.NoError(t, err) { + return + } + }() + + cmd := setupSSHCommand(t, []string{"-L", fmt.Sprintf("%s:%s", localSocketPath, remoteSocketPath)}, []string{"sleep", "10"}) + err = cmd.Start() + require.NoError(t, err) + + require.Eventually(t, func() bool { + _, err := os.Stat(localSocketPath) + return err == nil + }, testutil.WaitLong, testutil.IntervalFast) + + conn, err := net.Dial("unix", localSocketPath) + require.NoError(t, err) + defer conn.Close() + _, err = conn.Write([]byte("test")) + require.NoError(t, err) + b := make([]byte, 4) + _, err = conn.Read(b) + require.NoError(t, err) + require.Equal(t, "test", string(b)) + _ = conn.Close() + <-done + + _ = cmd.Process.Kill() +} + +func TestAgent_UnixRemoteForwarding(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("unix domain sockets are not fully supported on Windows") + } + + tmpdir := tempDirUnixSocket(t) + remoteSocketPath := filepath.Join(tmpdir, "remote-socket") + localSocketPath := filepath.Join(tmpdir, "local-socket") + + l, err := net.Listen("unix", localSocketPath) + require.NoError(t, err) + defer l.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + b := make([]byte, 4) + _, err = conn.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = conn.Write(b) + if !assert.NoError(t, err) { + return + } + }() + + cmd := setupSSHCommand(t, []string{"-R", fmt.Sprintf("%s:%s", remoteSocketPath, localSocketPath)}, []string{"sleep", "10"}) + err = cmd.Start() + require.NoError(t, err) + + require.Eventually(t, func() bool { + _, err := os.Stat(remoteSocketPath) + return err == nil + }, testutil.WaitLong, testutil.IntervalFast) + + conn, err := net.Dial("unix", remoteSocketPath) + require.NoError(t, err) + defer conn.Close() + _, err = conn.Write([]byte("test")) + require.NoError(t, err) + b := make([]byte, 4) + _, err = conn.Read(b) + require.NoError(t, err) + require.Equal(t, "test", string(b)) + _ = conn.Close() + + <-done + + _ = cmd.Process.Kill() } func TestAgent_SFTP(t *testing.T) { @@ -733,7 +948,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe args := append(beforeArgs, "-o", "HostName "+tcpAddr.IP.String(), "-o", "Port "+strconv.Itoa(tcpAddr.Port), - "-o", "StrictHostKeyChecking=no", "host") + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "host", + ) args = append(args, afterArgs...) return exec.Command("ssh", args...) } @@ -919,3 +1137,26 @@ func (*client) PostWorkspaceAgentAppHealth(_ context.Context, _ codersdk.PostWor func (*client) PostWorkspaceAgentVersion(_ context.Context, _ string) error { return nil } + +// tempDirUnixSocket returns a temporary directory that can safely hold unix +// sockets (probably). +// +// During tests on darwin we hit the max path length limit for unix sockets +// pretty easily in the default location, so this function uses /tmp instead to +// get shorter paths. +func tempDirUnixSocket(t *testing.T) string { + t.Helper() + if runtime.GOOS == "darwin" { + testName := strings.ReplaceAll(t.Name(), "/", "_") + dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("coder-test-%s-", testName)) + require.NoError(t, err, "create temp dir for gpg test") + + t.Cleanup(func() { + err := os.RemoveAll(dir) + assert.NoError(t, err, "remove temp dir", dir) + }) + return dir + } + + return t.TempDir() +} diff --git a/agent/ssh.go b/agent/ssh.go new file mode 100644 index 0000000000000..edebbeffd1c4c --- /dev/null +++ b/agent/ssh.go @@ -0,0 +1,203 @@ +package agent + +import ( + "context" + "fmt" + "net" + "os" + "path/filepath" + "sync" + + "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" + + "cdr.dev/slog" +) + +// streamLocalForwardPayload describes the extra data sent in a +// streamlocal-forward@openssh.com containing the socket path to bind to. +type streamLocalForwardPayload struct { + SocketPath string +} + +// forwardedStreamLocalPayload describes the data sent as the payload in the new +// channel request when a Unix connection is accepted by the listener. +type forwardedStreamLocalPayload struct { + SocketPath string + Reserved uint32 +} + +// forwardedUnixHandler is a clone of ssh.ForwardedTCPHandler that does +// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding. +type forwardedUnixHandler struct { + sync.Mutex + log slog.Logger + forwards map[string]net.Listener +} + +func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[string]net.Listener) + } + h.Unlock() + conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) + if !ok { + h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection") + return false, nil + } + + switch req.Type { + case "streamlocal-forward@openssh.com": + var reqPayload streamLocalForwardPayload + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + h.log.Warn(ctx, "parse streamlocal-forward@openssh.com request payload from client", slog.Error(err)) + return false, nil + } + + addr := reqPayload.SocketPath + h.Lock() + _, ok := h.forwards[addr] + h.Unlock() + if ok { + h.log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)", + slog.F("socket_path", addr), + ) + return false, nil + } + + // Create socket parent dir if not exists. + parentDir := filepath.Dir(addr) + err = os.MkdirAll(parentDir, 0700) + if err != nil { + h.log.Warn(ctx, "create parent dir for SSH unix forward request", + slog.F("parent_dir", parentDir), + slog.F("socket_path", addr), + slog.Error(err), + ) + return false, nil + } + + ln, err := net.Listen("unix", addr) + if err != nil { + h.log.Warn(ctx, "listen on Unix socket for SSH unix forward request", + slog.F("socket_path", addr), + slog.Error(err), + ) + return false, nil + } + + // The listener needs to successfully start before it can be added to + // the map, so we don't have to worry about checking for an existing + // listener. + // + // This is also what the upstream TCP version of this code does. + h.Lock() + h.forwards[addr] = ln + h.Unlock() + + ctx, cancel := context.WithCancel(ctx) + go func() { + <-ctx.Done() + _ = ln.Close() + }() + go func() { + defer cancel() + + for { + c, err := ln.Accept() + if err != nil { + if !xerrors.Is(err, net.ErrClosed) { + h.log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", + slog.F("socket_path", addr), + slog.Error(err), + ) + } + // closed below + break + } + payload := gossh.Marshal(&forwardedStreamLocalPayload{ + SocketPath: addr, + }) + + go func() { + ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload) + if err != nil { + h.log.Warn(ctx, "open SSH channel to forward Unix connection to client", + slog.F("socket_path", addr), + slog.Error(err), + ) + _ = c.Close() + return + } + go gossh.DiscardRequests(reqs) + Bicopy(ctx, ch, c) + }() + } + + h.Lock() + ln2, ok := h.forwards[addr] + if ok && ln2 == ln { + delete(h.forwards, addr) + } + h.Unlock() + _ = ln.Close() + }() + + return true, nil + + case "cancel-streamlocal-forward@openssh.com": + var reqPayload streamLocalForwardPayload + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com request payload from client", slog.Error(err)) + return false, nil + } + h.Lock() + ln, ok := h.forwards[reqPayload.SocketPath] + h.Unlock() + if ok { + _ = ln.Close() + } + return true, nil + + default: + return false, nil + } +} + +// directStreamLocalPayload describes the extra data sent in a +// direct-streamlocal@openssh.com channel request containing the socket path. +type directStreamLocalPayload struct { + SocketPath string + + Reserved1 string + Reserved2 uint32 +} + +func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { + var reqPayload directStreamLocalPayload + err := gossh.Unmarshal(newChan.ExtraData(), &reqPayload) + if err != nil { + _ = newChan.Reject(gossh.ConnectionFailed, "could not parse direct-streamlocal@openssh.com channel payload") + return + } + + var dialer net.Dialer + dconn, err := dialer.DialContext(ctx, "unix", reqPayload.SocketPath) + if err != nil { + _ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", reqPayload.SocketPath, err.Error())) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = dconn.Close() + return + } + go gossh.DiscardRequests(reqs) + + Bicopy(ctx, ch, dconn) +} diff --git a/cli/server_test.go b/cli/server_test.go index 5a1f6ed3fa9e0..de68027260622 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -575,14 +575,17 @@ func TestServer(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() + httpListenAddr := "" + if c.httpListener { + httpListenAddr = ":0" + } + certPath, keyPath := generateTLSCertificate(t) flags := []string{ "server", "--in-memory", "--cache-dir", t.TempDir(), - } - if c.httpListener { - flags = append(flags, "--http-address", ":0") + "--http-address", httpListenAddr, } if c.tlsListener { flags = append(flags, diff --git a/cli/ssh.go b/cli/ssh.go index 57a8c4aab4ac4..3c1671c7849b1 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -1,12 +1,15 @@ package cli import ( + "bytes" "context" "errors" "fmt" "io" + "net" "net/url" "os" + "os/exec" "path/filepath" "strings" "time" @@ -21,6 +24,7 @@ import ( "golang.org/x/term" "golang.org/x/xerrors" + "github.com/coder/coder/agent" "github.com/coder/coder/cli/cliflag" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" @@ -39,6 +43,7 @@ func ssh() *cobra.Command { stdio bool shuffle bool forwardAgent bool + forwardGPG bool identityAgent string wsPollInterval time.Duration ) @@ -138,7 +143,7 @@ func ssh() *cobra.Command { if forwardAgent && identityAgent != "" { err = gosshagent.ForwardToRemote(sshClient, identityAgent) if err != nil { - return xerrors.Errorf("forward agent failed: %w", err) + return xerrors.Errorf("forward agent: %w", err) } err = gosshagent.RequestAgentForwarding(sshSession) if err != nil { @@ -146,6 +151,22 @@ func ssh() *cobra.Command { } } + if forwardGPG { + if workspaceAgent.OperatingSystem == "windows" { + return xerrors.New("GPG forwarding is not supported for Windows workspaces") + } + + err = uploadGPGKeys(ctx, sshClient) + if err != nil { + return xerrors.Errorf("upload GPG public keys and ownertrust to workspace: %w", err) + } + closer, err := forwardGPGAgent(ctx, cmd.ErrOrStderr(), sshClient) + if err != nil { + return xerrors.Errorf("forward GPG socket: %w", err) + } + defer closer.Close() + } + stdoutFile, validOut := cmd.OutOrStdout().(*os.File) stdinFile, validIn := cmd.InOrStdin().(*os.File) if validOut && validIn && isatty.IsTerminal(stdoutFile.Fd()) { @@ -199,10 +220,12 @@ func ssh() *cobra.Command { _ = sshSession.WindowChange(height, width) } } + err = sshSession.Wait() if err != nil { - // If the connection drops unexpectedly, we get an ExitMissingError but no other - // error details, so try to at least give the user a better message + // If the connection drops unexpectedly, we get an + // ExitMissingError but no other error details, so try to at + // least give the user a better message if errors.Is(err, &gossh.ExitMissingError{}) { return xerrors.New("SSH connection ended unexpectedly") } @@ -216,6 +239,7 @@ func ssh() *cobra.Command { cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace") _ = cmd.Flags().MarkHidden("shuffle") cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK") + cliflag.BoolVarP(cmd.Flags(), &forwardGPG, "forward-gpg", "G", "CODER_SSH_FORWARD_GPG", false, "Specifies whether to forward the GPG agent. Unsupported on Windows workspaces, but supports all clients. Requires gnupg (gpg, gpgconf) on both the client and workspace. The GPG agent must already be running locally and will not be started for you. If a GPG agent is already running in the workspace, it will be attempted to be killed.") cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled") cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") return cmd @@ -364,3 +388,184 @@ func verifyWorkspaceOutdated(client *codersdk.Client, workspace codersdk.Workspa func buildWorkspaceLink(serverURL *url.URL, workspace codersdk.Workspace) *url.URL { return serverURL.ResolveReference(&url.URL{Path: fmt.Sprintf("@%s/%s", workspace.OwnerName, workspace.Name)}) } + +// runLocal runs a command on the local machine. +func runLocal(ctx context.Context, stdin io.Reader, name string, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, name, args...) + cmd.Stdin = stdin + + out, err := cmd.Output() + if err != nil { + var stderr []byte + if exitErr := new(exec.ExitError); errors.As(err, &exitErr) { + stderr = exitErr.Stderr + } + + return out, xerrors.Errorf( + "`%s %s` failed: stderr: %s\n\nstdout: %s\n\n%w", + name, + strings.Join(args, " "), + bytes.TrimSpace(stderr), + bytes.TrimSpace(out), + err, + ) + } + + return out, nil +} + +// runRemoteSSH runs a command on a remote machine/workspace via SSH. +func runRemoteSSH(sshClient *gossh.Client, stdin io.Reader, cmd string) ([]byte, error) { + sess, err := sshClient.NewSession() + if err != nil { + return nil, xerrors.Errorf("create SSH session") + } + defer sess.Close() + + stderr := bytes.NewBuffer(nil) + sess.Stdin = stdin + sess.Stderr = stderr + + out, err := sess.Output(cmd) + if err != nil { + return out, xerrors.Errorf( + "`%s` failed: stderr: %s\n\nstdout: %s:\n\n%w", + cmd, + bytes.TrimSpace(stderr.Bytes()), + bytes.TrimSpace(out), + err, + ) + } + + return out, nil +} + +func uploadGPGKeys(ctx context.Context, sshClient *gossh.Client) error { + // Check if the agent is running in the workspace already. + // + // Note: we don't support windows in the workspace for GPG forwarding so + // using shell commands is fine. + // + // Note: we sleep after killing the agent because it doesn't always die + // immediately. + agentSocketBytes, err := runRemoteSSH(sshClient, nil, ` +set -eux +agent_socket=$(gpgconf --list-dir agent-socket) +echo "$agent_socket" +if [ -S "$agent_socket" ]; then + echo "agent socket exists, attempting to kill it" >&2 + gpgconf --kill gpg-agent + rm -f "$agent_socket" + sleep 1 +fi + +test ! -S "$agent_socket" +`) + agentSocket := strings.TrimSpace(string(agentSocketBytes)) + if err != nil { + return xerrors.Errorf("check if agent socket is running (check if %q exists): %w", agentSocket, err) + } + if agentSocket == "" { + return xerrors.Errorf("agent socket path is empty, check the output of `gpgconf --list-dir agent-socket`") + } + + // Read the user's public keys and ownertrust from GPG. + pubKeyExport, err := runLocal(ctx, nil, "gpg", "--armor", "--export") + if err != nil { + return xerrors.Errorf("export local public keys from GPG: %w", err) + } + ownerTrustExport, err := runLocal(ctx, nil, "gpg", "--export-ownertrust") + if err != nil { + return xerrors.Errorf("export local ownertrust from GPG: %w", err) + } + + // Import the public keys and ownertrust into the workspace. + _, err = runRemoteSSH(sshClient, bytes.NewReader(pubKeyExport), "gpg --import") + if err != nil { + return xerrors.Errorf("import public keys into workspace: %w", err) + } + _, err = runRemoteSSH(sshClient, bytes.NewReader(ownerTrustExport), "gpg --import-ownertrust") + if err != nil { + return xerrors.Errorf("import ownertrust into workspace: %w", err) + } + + // Kill the agent in the workspace if it was started by one of the above + // commands. + _, err = runRemoteSSH(sshClient, nil, fmt.Sprintf("gpgconf --kill gpg-agent && rm -f %q", agentSocket)) + if err != nil { + return xerrors.Errorf("kill existing agent in workspace: %w", err) + } + + return nil +} + +func localGPGExtraSocket(ctx context.Context) (string, error) { + localSocket, err := runLocal(ctx, nil, "gpgconf", "--list-dir", "agent-extra-socket") + if err != nil { + return "", xerrors.Errorf("get local GPG agent socket: %w", err) + } + + return string(bytes.TrimSpace(localSocket)), nil +} + +func remoteGPGAgentSocket(sshClient *gossh.Client) (string, error) { + remoteSocket, err := runRemoteSSH(sshClient, nil, "gpgconf --list-dir agent-socket") + if err != nil { + return "", xerrors.Errorf("get remote GPG agent socket: %w", err) + } + + return string(bytes.TrimSpace(remoteSocket)), nil +} + +// cookieAddr is a special net.Addr accepted by sshForward() which includes a +// cookie which is written to the connection before forwarding. +type cookieAddr struct { + net.Addr + cookie []byte +} + +// sshForwardRemote starts forwarding connections from a remote listener to a +// local address via SSH in a goroutine. +// +// Accepts a `cookieAddr` as the local address. +func sshForwardRemote(ctx context.Context, stderr io.Writer, sshClient *gossh.Client, localAddr, remoteAddr net.Addr) (io.Closer, error) { + listener, err := sshClient.Listen(remoteAddr.Network(), remoteAddr.String()) + if err != nil { + return nil, xerrors.Errorf("listen on remote SSH address %s: %w", remoteAddr.String(), err) + } + + go func() { + for { + remoteConn, err := listener.Accept() + if err != nil { + if ctx.Err() == nil { + _, _ = fmt.Fprintf(stderr, "Accept SSH listener connection: %+v\n", err) + } + return + } + + go func() { + defer remoteConn.Close() + + localConn, err := net.Dial(localAddr.Network(), localAddr.String()) + if err != nil { + _, _ = fmt.Fprintf(stderr, "Dial local address %s: %+v\n", localAddr.String(), err) + return + } + defer localConn.Close() + + if c, ok := localAddr.(cookieAddr); ok { + _, err = localConn.Write(c.cookie) + if err != nil { + _, _ = fmt.Fprintf(stderr, "Write cookie to local connection: %+v\n", err) + return + } + } + + agent.Bicopy(ctx, localConn, remoteConn) + }() + } + }() + + return listener, nil +} diff --git a/cli/ssh_other.go b/cli/ssh_other.go index 8799030949283..064436da31406 100644 --- a/cli/ssh_other.go +++ b/cli/ssh_other.go @@ -5,9 +5,12 @@ package cli import ( "context" + "io" + "net" "os" "os/signal" + gossh "golang.org/x/crypto/ssh" "golang.org/x/sys/unix" ) @@ -20,3 +23,26 @@ func listenWindowSize(ctx context.Context) <-chan os.Signal { }() return windowSize } + +func forwardGPGAgent(ctx context.Context, stderr io.Writer, sshClient *gossh.Client) (io.Closer, error) { + localSocket, err := localGPGExtraSocket(ctx) + if err != nil { + return nil, err + } + + remoteSocket, err := remoteGPGAgentSocket(sshClient) + if err != nil { + return nil, err + } + + localAddr := &net.UnixAddr{ + Name: localSocket, + Net: "unix", + } + remoteAddr := &net.UnixAddr{ + Name: remoteSocket, + Net: "unix", + } + + return sshForwardRemote(ctx, stderr, sshClient, localAddr, remoteAddr) +} diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 8e828b4e047d5..ceb34cc7f6b80 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -1,15 +1,20 @@ package cli_test import ( + "bytes" "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "errors" + "fmt" "io" "net" + "os" + "os/exec" "path/filepath" "runtime" + "strings" "testing" "time" @@ -27,6 +32,7 @@ import ( "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/pty" "github.com/coder/coder/pty/ptytest" "github.com/coder/coder/testutil" ) @@ -226,7 +232,7 @@ func TestSSH(t *testing.T) { }) // Start up ssh agent listening on unix socket. - tmpdir := t.TempDir() + tmpdir := tempDirUnixSocket(t) agentSock := filepath.Join(tmpdir, "agent.sock") l, err := net.Listen("unix", agentSock) require.NoError(t, err) @@ -283,6 +289,224 @@ func TestSSH(t *testing.T) { pty.WriteLine("exit") <-cmdDone }) + + //nolint:paralleltest // This test uses t.Setenv. + t.Run("ForwardGPG", func(t *testing.T) { + if runtime.GOOS == "windows" { + // While GPG forwarding from a Windows client works, we currently do + // not support forwarding to a Windows workspace. Our tests use the + // same platform for the "client" and "workspace" as they run in the + // same process. + t.Skip("Test not supported on windows") + } + + // This key is for dean@coder.com. + const randPublicKeyFingerprint = "7BDFBA0CC7F5A96537C806C427BC6335EB5117F1" + const randPublicKey = `-----BEGIN PGP PUBLIC KEY BLOCK----- + +mQINBF6SWkEBEADB8sAhBaT36VQ6HEhAmtKexLldu1HUdXNw16rdF+1wiBzSFfJN +aPeX4Y9iFIZgC2wU0wOjJ04BpioyOLtJngbThI5WpeoQ/1yQZOpnDaCMPPLp+uJ+ +Gy4tMZYWQq21PukrFm3XDRGKjVN58QN6uCPb1S/YzteP8Epmq590GYIYLiAHnMt6 +5iyxIFhXj/fq5Fddp2+efI7QWvNl2wTNnCaTziOSKYcbNmQpn9gy0WvKktWYtB8E +JJtWES0DzgCnDpm/hYx79Wkb+F7qY54y2uauDx+z97QXrON47lsIyGm8/T59ZfSd +/yrBqDLHYrHlt9RkFpAnBzO402y2eHsKTB6/EAHv9H2apxahyJlcxGbE5QE+fOJk +LdPlako0cSljz0g9Icesr2nZL0MhWwLnwk7DHkg/PUUijkbuR/TD9dti2/yOTFrf +Y7DdZpoZ0ZkcGu9lMh2vOTWc96RNCyIZfE5WNDKKo+u5Txzndsc/qIgKohwDSxTC +3hAulG5Wt05UeyHBEAAvGV2szG88VsGwd1juqXAbEzk+kLQzNyoQX188/4V4X+MV +pY9Wz7JudmQpB/3+YTcA/ziK/+wu3c2wNlr7gMZYMOwDWTLfW64nux7zHWDytrP0 +HfgJIgqP7F7SnChpTFdb1hr1WDox99ZG+/eDkwxnuXYWm9xx5/crqQ0POQARAQAB +tClEZWFuIFNoZWF0aGVyICh3b3JrIGtleSkgPGRlYW5AY29kZXIuY29tPokCVAQT +AQgAPhYhBHvfugzH9allN8gGxCe8YzXrURfxBQJeklpBAhsDBQkJZgGABQsJCAcC +BhUKCQgLAgQWAgMBAh4BAheAAAoJECe8YzXrURfxIVkP/3UJMzvIjTNF63WiK4xk +TXlBbPKodnzUmAJ+8DVXmJMJpNsSI2czw6eFUXMcrT3JMlviOXhRWMLHr2FsQhyS +AJOQo0x9z7nntPIkvj96ihCdgRn7VN1WzaMwOOesGPr57StWLE84bg9/R0aSsxtX +LgfBCyNkv6FFlruhnw8+JdZJEjvIXQ9swvwD6L68ZLWIWcdnj/CjQmnmgFA+O4UO +SFXMUjklbrq8mJ0sAPUUATJK0SOTyqkZPkhqjlTZa8p0XoJF25trhwLhzDi4GPR6 +SK/9SkqB/go9ZwkNZOjs2tP7eMExy4zQ21MFH09JMKQB7H5CG8GwdMwz4+VKc9aP +y9Ncova/p7Y8kJ7oQPWhACJT1jMP6620oC2N/7wwS0Vtc6E9LoPrfXC2TtvOA9qx +aOf6riWSjo8BEcXDuMtlW4g6IQFNd0+wcgcKrAd+vPLZnG4rtYL0Etdd1ymBT4pi +5E5uT8oUT9rLHX+2tD/E8SE5PzsaKEOJKzcOB8ESb3YBGic7+VvX/AuJuSFsuWnZ +FqAUENqfdz6+0dEJe1pfWyje+Q+o7B7u+ffMT4dOQOC8NfHFnz1kU+DA3VDE6xsu +3YN1L8KlYON92s9VWDA8VuvmU2d9pq5ysUeg133ftDSwj3X+5GYcBv4VFcSRCBW5 +w0hDpMDun1t8xcXdo1LQ4R4NuQINBF6SWkEBEADF4Nrhlqc5M3Sz9sNHDJZR68zb +4CjkoOpYwsKj/ZCukzRCGKpT5Agn0zOycUjbAyCZVjREeIRRURyAhfpOmZY5yF6b +PD93+04OzWk1AaDRmMfvi1Crn/WUEVHIbDaisxDzNuAJgLrt93I/lOz06GczhCb6 +sPBeKuaXCLl/5LSwTahGWsweeSCmfyrYsOc11T+SjdyWXWXEpzFNNIhvqiEoJCw3 +IcdktTBJYuHsN4jh5kVemi/ttqRN3z7rBMKR1sPG3ux1MfCfSTSCeZLTN9eVvqm9 +ne8brk8ZC6sdwlZ9IofPbmSaAh+F5Kfcnd3KjmyQ63t+8plpJ2YH3Fx6IwTwVEQ8 +Ii3WQInTpBSPqf0EwnzRBvhYeKusRpcmX3JSmosLbd5uhvJdgotzuwZYzgay/6DL +OlwElZ//ecXNhU8iYmx1BwNuquvGcGVpkP5eaaT6O9qDznB7TT0xztfAK0LaAuRJ +HOFCc8iiHtQ4o0OkRhg/0KkUGBU5Iw5SIDimkgwJMtD3ZiYOqLaXS6kmmVw2u6YD +LB8rTpegz/tcX+4uyfnIZ28JCOYFTeaDT4FixFW2hrfo/VJzMI5IIv9XAAmtAiEU +f+CY2BT6kg9NkQuke0p4/W8yTaScapYZa5I2bzFpJJyzh1TKE6x3qcbBs9vVX+6E +vK4FflNwu9WSWojO2wARAQABiQI8BBgBCAAmFiEEe9+6DMf1qWU3yAbEJ7xjNetR +F/EFAl6SWkECGwwFCQlmAYAACgkQJ7xjNetRF/FpnQ//SIYePQzhvWj9drnT2krG +dUGSxCN0pA2UQZNkreAaKmyxn2/6xEdxYSz0iUEk+I0HKay+NLCxJ5PDoDBypFtM +f0yOnbWRObhim8HmED4JRw678G4hRU7KEN0L/9SUYlsBNbgr1xYM/CUX/Ih9NT+P +eApxs2VgjKii6m81nfBCFpWSxAs+TOnbshp8dlDZk9kxjFH9+h1ffgZjntqeyiWe +F1UE1Wh32MbJdtc2Y3mrA6i+7+3OXmqMHoiG1obhISgdpaCJ/ub3ywnAmeXSiAKE +IuS6CriR71Wqv8LMQ8kPM8On9Q26d1dsKKBnlFop9oexxf1AFsbbf9gkcgb+uNno +1Qr/R6l2H1TcV1gmiyQLzVnkgLRORosLvSlFrisrsLv9uTYYgcGvwKiU/o3PTdQg +fv0D7LB+a3C9KsCBFjihW3bTOcHKX2sAWEQXZMtKGf5aNTBmWQ+eKWUGpudXIvLE +od5lgfk9p8T1R50KDieG/+2X95zxFSYBoPRAfp7JNT7h+TZ55qUmQXZGI1VqhWiq +b6y/yqfI17JCm4oWpXYbgeruLuye2c/ptDc3S3d26hbWYiWKVT4bLtUGR0wuE6lS +DK0u4LK+mnrYfIvRDYJGx18/nbLpR+ivWLIssJT2Jyyj8w9+hk10XkODySNjHCxj +p7KeSZdlk47pMBGOfnvEmoQ= +=OxHv +-----END PGP PUBLIC KEY BLOCK-----` + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + gpgPath, err := exec.LookPath("gpg") + if err != nil { + t.Skip("gpg not found") + } + gpgConfPath, err := exec.LookPath("gpgconf") + if err != nil { + t.Skip("gpgconf not found") + } + gpgAgentPath, err := exec.LookPath("gpg-agent") + if err != nil { + t.Skip("gpg-agent not found") + } + + // Setup GPG home directory on the "client". + gnupgHomeClient := tempDirUnixSocket(t) + t.Setenv("GNUPGHOME", gnupgHomeClient) + + // Get the agent extra socket path. + var ( + stdout = bytes.NewBuffer(nil) + stderr = bytes.NewBuffer(nil) + ) + c := exec.CommandContext(ctx, gpgConfPath, "--list-dir", "agent-extra-socket") + c.Stdout = stdout + c.Stderr = stderr + err = c.Run() + require.NoError(t, err, "get extra socket path failed: %s", stderr.String()) + extraSocketPath := strings.TrimSpace(stdout.String()) + + // Generate private key non-interactively. + genKeyScript := ` +Key-Type: 1 +Key-Length: 2048 +Subkey-Type: 1 +Subkey-Length: 2048 +Name-Real: Coder Test +Name-Email: test@coder.com +Expire-Date: 0 +%no-protection +` + c = exec.CommandContext(ctx, gpgPath, "--batch", "--gen-key") + c.Stdin = strings.NewReader(genKeyScript) + out, err := c.CombinedOutput() + require.NoError(t, err, "generate key failed: %s", out) + + // Import a random public key. + stdin := strings.NewReader(randPublicKey + "\n") + c = exec.CommandContext(ctx, gpgPath, "--import", "-") + c.Stdin = stdin + out, err = c.CombinedOutput() + require.NoError(t, err, "import key failed: %s", out) + + // Set ultimate trust on imported key. + stdin = strings.NewReader(randPublicKeyFingerprint + ":6:\n") + c = exec.CommandContext(ctx, gpgPath, "--import-ownertrust") + c.Stdin = stdin + out, err = c.CombinedOutput() + require.NoError(t, err, "import ownertrust failed: %s", out) + + // Start the GPG agent. + agentCmd := exec.CommandContext(ctx, gpgAgentPath, "--no-detach", "--extra-socket", extraSocketPath) + agentCmd.Env = append(agentCmd.Env, "GNUPGHOME="+gnupgHomeClient) + agentPTY, agentProc, err := pty.Start(agentCmd, pty.WithPTYOption(pty.WithGPGTTY())) + require.NoError(t, err, "launch agent failed") + defer func() { + _ = agentProc.Kill() + _ = agentPTY.Close() + }() + + // Get the agent socket path in the "workspace". + gnupgHomeWorkspace := tempDirUnixSocket(t) + + stdout = bytes.NewBuffer(nil) + stderr = bytes.NewBuffer(nil) + c = exec.CommandContext(ctx, gpgConfPath, "--list-dir", "agent-socket") + c.Env = append(c.Env, "GNUPGHOME="+gnupgHomeWorkspace) + c.Stdout = stdout + c.Stderr = stderr + err = c.Run() + require.NoError(t, err, "get agent socket path in workspace failed: %s", stderr.String()) + workspaceAgentSocketPath := strings.TrimSpace(stdout.String()) + require.NotEqual(t, extraSocketPath, workspaceAgentSocketPath, "socket path should be different") + + client, workspace, agentToken := setupWorkspaceForAgent(t, nil) + + agentClient := codersdk.New(client.URL) + agentClient.SetSessionToken(agentToken) + agentCloser := agent.New(agent.Options{ + Client: agentClient, + EnvironmentVariables: map[string]string{ + "GNUPGHOME": gnupgHomeWorkspace, + }, + Logger: slogtest.Make(t, nil).Named("agent"), + }) + defer agentCloser.Close() + + cmd, root := clitest.New(t, + "ssh", + workspace.Name, + "--forward-gpg", + ) + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) + cmd.SetErr(pty.Output()) + cmdDone := tGo(t, func() { + err := cmd.ExecuteContext(ctx) + assert.NoError(t, err, "ssh command failed") + }) + // Prevent the test from hanging if the asserts below kill the test + // early. This will cause the command to exit with an error, which will + // let the t.Cleanup'd `<-done` inside of `tGo` exit and not hang. + // Without this, the test will hang forever on failure, preventing the + // real error from being printed. + t.Cleanup(cancel) + + pty.WriteLine("echo hello 'world'") + pty.ExpectMatch("hello world") + + // Check the GNUPGHOME was correctly inherited via shell. + pty.WriteLine("env && echo env-''-command-done") + match := pty.ExpectMatch("env--command-done") + require.Contains(t, match, "GNUPGHOME="+gnupgHomeWorkspace, match) + + // Get the agent extra socket path in the "workspace" via shell. + pty.WriteLine("gpgconf --list-dir agent-socket && echo gpgconf-''-agentsocket-command-done") + pty.ExpectMatch(workspaceAgentSocketPath) + pty.ExpectMatch("gpgconf--agentsocket-command-done") + + // List the keys in the "workspace". + pty.WriteLine("gpg --list-keys && echo gpg-''-listkeys-command-done") + listKeysOutput := pty.ExpectMatch("gpg--listkeys-command-done") + require.Contains(t, listKeysOutput, "[ultimate] Coder Test ") + require.Contains(t, listKeysOutput, "[ultimate] Dean Sheather (work key) ") + + // Try to sign something. This demonstrates that the forwarding is + // working as expected, since the workspace doesn't have access to the + // private key directly and must use the forwarded agent. + pty.WriteLine("echo 'hello world' | gpg --clearsign && echo gpg-''-sign-command-done") + pty.ExpectMatch("BEGIN PGP SIGNED MESSAGE") + pty.ExpectMatch("Hash:") + pty.ExpectMatch("hello world") + pty.ExpectMatch("gpg--sign-command-done") + + // And we're done. + pty.WriteLine("exit") + <-cmdDone + }) } // tGoContext runs fn in a goroutine passing a context that will be @@ -356,3 +580,26 @@ func (*stdioConn) SetReadDeadline(_ time.Time) error { func (*stdioConn) SetWriteDeadline(_ time.Time) error { return nil } + +// tempDirUnixSocket returns a temporary directory that can safely hold unix +// sockets (probably). +// +// During tests on darwin we hit the max path length limit for unix sockets +// pretty easily in the default location, so this function uses /tmp instead to +// get shorter paths. +func tempDirUnixSocket(t *testing.T) string { + t.Helper() + if runtime.GOOS == "darwin" { + testName := strings.ReplaceAll(t.Name(), "/", "_") + dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("coder-test-%s-", testName)) + require.NoError(t, err, "create temp dir for gpg test") + + t.Cleanup(func() { + err := os.RemoveAll(dir) + assert.NoError(t, err, "remove temp dir", dir) + }) + return dir + } + + return t.TempDir() +} diff --git a/cli/ssh_windows.go b/cli/ssh_windows.go index 2b1cbb4dd6a8f..bf579c9df56b4 100644 --- a/cli/ssh_windows.go +++ b/cli/ssh_windows.go @@ -4,9 +4,16 @@ package cli import ( + "bufio" "context" + "io" + "net" "os" + "strconv" "time" + + gossh "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" ) func listenWindowSize(ctx context.Context) <-chan os.Signal { @@ -25,3 +32,74 @@ func listenWindowSize(ctx context.Context) <-chan os.Signal { }() return windowSize } + +func forwardGPGAgent(ctx context.Context, stderr io.Writer, sshClient *gossh.Client) (io.Closer, error) { + // Read TCP port and cookie from extra socket file. A gpg-agent socket + // file looks like the following: + // + // 49955 + // abcdefghijklmnop + // + // The first line is the TCP port that gpg-agent is listening on, and + // the second line is a 16 byte cookie that MUST be sent as the first + // bytes of any connection to this port (otherwise the connection is + // closed by gpg-agent). + localSocket, err := localGPGExtraSocket(ctx) + if err != nil { + return nil, err + } + f, err := os.Open(localSocket) + if err != nil { + return nil, xerrors.Errorf("open gpg-agent-extra socket file %q: %w", localSocket, err) + } + + // Scan lines from file to get port and cookie. + var ( + port uint16 + cookie []byte + scanner = bufio.NewScanner(f) + ) + for i := 0; scanner.Scan(); i++ { + switch i { + case 0: + port64, err := strconv.ParseUint(scanner.Text(), 10, 16) + if err != nil { + return nil, xerrors.Errorf("parse gpg-agent-extra socket file %q: line 1: convert string to integer: %w", localSocket, err) + } + port = uint16(port64) + + case 1: + cookie = scanner.Bytes() + if len(cookie) != 16 { + return nil, xerrors.Errorf("parse gpg-agent-extra socket file %q: line 2: expected 16 bytes, got %v bytes", localSocket, len(cookie)) + } + + default: + return nil, xerrors.Errorf("parse gpg-agent-extra socket file %q: file contains more than 2 lines", localSocket) + } + } + + err = scanner.Err() + if err != nil { + return nil, xerrors.Errorf("parse gpg-agent-extra socket file: %q: %w", localSocket, err) + } + + remoteSocket, err := remoteGPGAgentSocket(sshClient) + if err != nil { + return nil, err + } + + localAddr := cookieAddr{ + Addr: &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: int(port), + }, + cookie: cookie, + } + remoteAddr := &net.UnixAddr{ + Name: remoteSocket, + Net: "unix", + } + + return sshForwardRemote(ctx, stderr, sshClient, localAddr, remoteAddr) +} diff --git a/cli/testdata/coder_ssh_--help.golden b/cli/testdata/coder_ssh_--help.golden index 0b543b251b0c9..86010356f49a6 100644 --- a/cli/testdata/coder_ssh_--help.golden +++ b/cli/testdata/coder_ssh_--help.golden @@ -7,6 +7,14 @@ Flags: -A, --forward-agent Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK. Consumes $CODER_SSH_FORWARD_AGENT + -G, --forward-gpg Specifies whether to forward the GPG agent. + Unsupported on Windows workspaces, but supports all + clients. Requires gnupg (gpg, gpgconf) on both the + client and workspace. The GPG agent must already be + running locally and will not be started for you. If + a GPG agent is already running in the workspace, it + will be attempted to be killed. + Consumes $CODER_SSH_FORWARD_GPG -h, --help help for ssh --identity-agent string Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be diff --git a/pty/pty.go b/pty/pty.go index b37369c9c1ec2..5910509ba45c8 100644 --- a/pty/pty.go +++ b/pty/pty.go @@ -61,8 +61,9 @@ type WithFlags interface { type Option func(*ptyOptions) type ptyOptions struct { - logger *log.Logger - sshReq *ssh.Pty + logger *log.Logger + sshReq *ssh.Pty + setGPGTTY bool } // WithSSHRequest applies the ssh.Pty request to the PTY. @@ -81,6 +82,14 @@ func WithLogger(logger *log.Logger) Option { } } +// WithGPGTTY sets the GPG_TTY environment variable to the PTY name. This only +// applies to non-Windows platforms. +func WithGPGTTY() Option { + return func(opts *ptyOptions) { + opts.setGPGTTY = true + } +} + // New constructs a new Pty. func New(opts ...Option) (PTY, error) { return newPty(opts...) diff --git a/pty/start_other.go b/pty/start_other.go index 2bf3bdfebc0d9..c38b6dcf8faee 100644 --- a/pty/start_other.go +++ b/pty/start_other.go @@ -27,6 +27,9 @@ func startPty(cmd *exec.Cmd, opt ...StartOption) (retPTY *otherPty, proc Process if opty.opts.sshReq != nil { cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", opty.Name())) } + if opty.opts.setGPGTTY { + cmd.Env = append(cmd.Env, fmt.Sprintf("GPG_TTY=%s", opty.Name())) + } cmd.SysProcAttr = &syscall.SysProcAttr{ Setsid: true, diff --git a/scripts/develop.sh b/scripts/develop.sh index 901c69d498786..533e984666ec6 100755 --- a/scripts/develop.sh +++ b/scripts/develop.sh @@ -121,7 +121,7 @@ fatal() { trap 'fatal "Script encountered an error"' ERR cdroot - start_cmd API "" "${CODER_DEV_SHIM}" server --http-address 0.0.0.0:3000 --swagger-enable + start_cmd API "" "${CODER_DEV_SHIM}" server --http-address 0.0.0.0:3000 --swagger-enable --access-url "http://127.0.0.1:3000" echo '== Waiting for Coder to become ready' # Start the timeout in the background so interrupting this script