From 8c77e53c35802ccea1d7e07af3737f01e480844f Mon Sep 17 00:00:00 2001 From: David Gardiner Date: Mon, 29 Jan 2024 19:20:51 -0800 Subject: [PATCH 1/3] Send activity signals during non-interactive codespace SSH command --- internal/codespaces/rpc/invoker.go | 33 ++++++++++++++++++++++-------- internal/codespaces/ssh.go | 14 +++++++++---- pkg/cmd/codespace/ssh.go | 11 +++++++++- 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index 7701d609285..a0830d2fce1 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -31,6 +31,7 @@ const ( codespacesInternalSessionName = "CodespacesInternal" clientName = "gh" connectedEventName = "connected" + keepAliveEventName = "keepAlive" ) type StartSSHServerOptions struct { @@ -43,16 +44,18 @@ type Invoker interface { RebuildContainer(ctx context.Context, full bool) error StartSSHServer(ctx context.Context) (int, string, error) StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error) + KeepAlive() } type invoker struct { - conn *grpc.ClientConn - fwd portforwarder.PortForwarder - listener net.Listener - jupyterClient jupyter.JupyterServerHostClient - codespaceClient codespace.CodespaceHostClient - sshClient ssh.SshServerHostClient - cancelPF context.CancelFunc + conn *grpc.ClientConn + fwd portforwarder.PortForwarder + listener net.Listener + jupyterClient jupyter.JupyterServerHostClient + codespaceClient codespace.CodespaceHostClient + sshClient ssh.SshServerHostClient + cancelPF context.CancelFunc + keepAliveOverride bool } // Connects to the internal RPC server and returns a new invoker for it @@ -256,6 +259,12 @@ func listenTCP() (*net.TCPListener, error) { return listener, nil } +// KeepAlive sets a flag to continuously send activity signals to +// the codespace even if there is no other activity (e.g. stdio) +func (i *invoker) KeepAlive() { + i.keepAliveOverride = true +} + // Periodically check whether there is a reason to keep the connection alive, and if so, notify the codespace to do so func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) @@ -266,7 +275,15 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) { case <-ctx.Done(): return case <-ticker.C: - reason := i.fwd.GetKeepAliveReason() + reason := "" + + // If the keep alive override flag is set, we don't need to check for activity on the forwarder + // Otherwise, grab the reason from the forwarder + if i.keepAliveOverride { + reason = keepAliveEventName + } else { + reason = i.fwd.GetKeepAliveReason() + } _ = i.notifyCodespaceOfClientActivity(ctx, reason) } } diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index a623d36f94a..d2bd0acf3cb 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -19,9 +19,9 @@ type printer interface { // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). func Shell( - ctx context.Context, p printer, sshArgs []string, port int, destination string, printConnDetails bool, + ctx context.Context, keepAliveOverride chan (bool), p printer, sshArgs []string, port int, destination string, printConnDetails bool, ) error { - cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs) + cmd, connArgs, err := newSSHCommand(ctx, keepAliveOverride, port, destination, sshArgs) if err != nil { return fmt.Errorf("failed to create ssh command: %w", err) } @@ -51,13 +51,13 @@ func Copy(ctx context.Context, scpArgs []string, port int, destination string) e // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) { - cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs) + cmd, _, err := newSSHCommand(ctx, nil, tunnelPort, destination, sshArgs) return cmd, err } // newSSHCommand populates an exec.Cmd to run a command (or if blank, // an interactive shell) over ssh. -func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) { +func newSSHCommand(ctx context.Context, keepAliveOverride chan (bool), port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) { connArgs := []string{ "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes", @@ -81,6 +81,12 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) if command != nil { cmdArgs = append(cmdArgs, command...) + + // If the user specified a command to run non-interactively, + // make sure we send activity signals to keep the connection alive + if keepAliveOverride != nil { + keepAliveOverride <- true + } } exe, err := safeexec.LookPath("ssh") diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 8302e24d85f..ea567a0e50e 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -281,8 +281,17 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e // args is the correct variable to use here, we just use scpArgs as the check for which command to run err = codespaces.Copy(ctx, args, localSSHServerPort, connectDestination) } else { + // Create a channel to send down to the shell to keep it alive + keepAliveOverride := make(chan bool, 1) + go func() { + // If we receive true on the channel, ignore the timeout + if <-keepAliveOverride { + invoker.KeepAlive() + } + }() + err = codespaces.Shell( - ctx, a.errLogger, args, localSSHServerPort, connectDestination, opts.printConnDetails, + ctx, keepAliveOverride, a.errLogger, args, localSSHServerPort, connectDestination, opts.printConnDetails, ) } shellClosed <- err From 2cb044caf58fbabb385697d5315f637a3116d4ff Mon Sep 17 00:00:00 2001 From: David Gardiner Date: Tue, 30 Jan 2024 13:16:46 -0800 Subject: [PATCH 2/3] Parse SSH args before creating the shell --- internal/codespaces/ssh.go | 39 ++++++++++++++------------------- internal/codespaces/ssh_test.go | 2 +- pkg/cmd/codespace/ssh.go | 23 +++++++++++-------- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index d2bd0acf3cb..abffdbf81fd 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -19,9 +19,9 @@ type printer interface { // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). func Shell( - ctx context.Context, keepAliveOverride chan (bool), p printer, sshArgs []string, port int, destination string, printConnDetails bool, + ctx context.Context, p printer, sshArgs []string, command []string, port int, destination string, printConnDetails bool, ) error { - cmd, connArgs, err := newSSHCommand(ctx, keepAliveOverride, port, destination, sshArgs) + cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs, command) if err != nil { return fmt.Errorf("failed to create ssh command: %w", err) } @@ -51,42 +51,30 @@ func Copy(ctx context.Context, scpArgs []string, port int, destination string) e // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) { - cmd, _, err := newSSHCommand(ctx, nil, tunnelPort, destination, sshArgs) + sshArgs, command, err := ParseSSHArgs(sshArgs) + if err != nil { + return nil, err + } + + cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs, command) return cmd, err } // newSSHCommand populates an exec.Cmd to run a command (or if blank, // an interactive shell) over ssh. -func newSSHCommand(ctx context.Context, keepAliveOverride chan (bool), port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) { +func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string, command []string) (*exec.Cmd, []string, error) { connArgs := []string{ "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes", "-o", "PasswordAuthentication=no", } - // The ssh command syntax is: ssh [flags] user@host command [args...] - // There is no way to specify the user@host destination as a flag. - // Unfortunately, that means we need to know which user-provided words are - // SSH flags and which are command arguments so that we can place - // them before or after the destination, and that means we need to know all - // the flags and their arities. - cmdArgs, command, err := parseSSHArgs(cmdArgs) - if err != nil { - return nil, nil, err - } - cmdArgs = append(cmdArgs, connArgs...) cmdArgs = append(cmdArgs, "-C") // Compression cmdArgs = append(cmdArgs, dst) // user@host if command != nil { cmdArgs = append(cmdArgs, command...) - - // If the user specified a command to run non-interactively, - // make sure we send activity signals to keep the connection alive - if keepAliveOverride != nil { - keepAliveOverride <- true - } } exe, err := safeexec.LookPath("ssh") @@ -102,7 +90,14 @@ func newSSHCommand(ctx context.Context, keepAliveOverride chan (bool), port int, return cmd, connArgs, nil } -func parseSSHArgs(args []string) (cmdArgs, command []string, err error) { +// ParseSSHArgs parses the given array of arguments into two distinct slices of flags and command. +// The ssh command syntax is: ssh [flags] user@host command [args...] +// There is no way to specify the user@host destination as a flag. +// Unfortunately, that means we need to know which user-provided words are +// SSH flags and which are command arguments so that we can place +// them before or after the destination, and that means we need to know all +// the flags and their arities. +func ParseSSHArgs(args []string) (cmdArgs, command []string, err error) { return parseArgs(args, "bcDeFIiLlmOopRSWw") } diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index 7389bd2f925..faea74ae2ca 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -74,7 +74,7 @@ func TestParseSSHArgs(t *testing.T) { } for _, tcase := range testCases { - args, command, err := parseSSHArgs(tcase.Args) + args, command, err := ParseSSHArgs(tcase.Args) checkParseResult(t, tcase, args, command, err) } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index ea567a0e50e..43e081103c7 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -281,17 +281,22 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e // args is the correct variable to use here, we just use scpArgs as the check for which command to run err = codespaces.Copy(ctx, args, localSSHServerPort, connectDestination) } else { - // Create a channel to send down to the shell to keep it alive - keepAliveOverride := make(chan bool, 1) - go func() { - // If we receive true on the channel, ignore the timeout - if <-keepAliveOverride { - invoker.KeepAlive() - } - }() + // Parse the ssh args to determine if the user specified a command + args, command, err := codespaces.ParseSSHArgs(args) + if err != nil { + shellClosed <- err + return + } + + // If the user specified a command, we need to keep the shell alive + // since it will be non-interactive and the codespace might shut down + // before the command finishes + if command != nil { + invoker.KeepAlive() + } err = codespaces.Shell( - ctx, keepAliveOverride, a.errLogger, args, localSSHServerPort, connectDestination, opts.printConnDetails, + ctx, a.errLogger, args, command, localSSHServerPort, connectDestination, opts.printConnDetails, ) } shellClosed <- err From 400db0f41b275383c92f282e6480fc6fb92ac819 Mon Sep 17 00:00:00 2001 From: David Gardiner Date: Tue, 30 Jan 2024 13:24:23 -0800 Subject: [PATCH 3/3] Fix linting error --- pkg/cmd/codespace/ssh.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 43e081103c7..571f7426777 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -276,10 +276,9 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e shellClosed := make(chan error, 1) go func() { - var err error if opts.scpArgs != nil { // args is the correct variable to use here, we just use scpArgs as the check for which command to run - err = codespaces.Copy(ctx, args, localSSHServerPort, connectDestination) + shellClosed <- codespaces.Copy(ctx, args, localSSHServerPort, connectDestination) } else { // Parse the ssh args to determine if the user specified a command args, command, err := codespaces.ParseSSHArgs(args) @@ -295,11 +294,10 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e invoker.KeepAlive() } - err = codespaces.Shell( + shellClosed <- codespaces.Shell( ctx, a.errLogger, args, command, localSSHServerPort, connectDestination, opts.printConnDetails, ) } - shellClosed <- err }() select {