diff --git a/pkg/exec/exec.go b/pkg/exec/exec.go index 8867caa9a86..cc8a8287e61 100644 --- a/pkg/exec/exec.go +++ b/pkg/exec/exec.go @@ -84,7 +84,9 @@ func (w *wrap) CmdAsyncWithContext(ctx context.Context, host string, commands .. } func (w *wrap) CmdAsync(host string, commands ...string) error { - return w.CmdAsyncWithContext(context.Background(), host, commands...) + ctx, cancel := ssh.GetTimeoutContext() + defer cancel() + return w.CmdAsyncWithContext(ctx, host, commands...) } func warnIfNotAbs(path string) { diff --git a/pkg/ssh/option.go b/pkg/ssh/option.go index afd66f2d6c1..b78facb32e2 100644 --- a/pkg/ssh/option.go +++ b/pkg/ssh/option.go @@ -46,7 +46,7 @@ func (o *Option) BindFlags(fs *pflag.FlagSet) { fs.StringVarP(&o.privateKey, "private-key", "i", o.privateKey, "selects a file from which the identity (private key) for public key authentication is read") fs.StringVar(&o.passphrase, "passphrase", o.passphrase, "passphrase for decrypting a PEM encoded private key") - fs.DurationVar(&o.timeout, "timeout", o.timeout, "ssh connection timeout") + fs.DurationVar(&o.timeout, "timeout", o.timeout, "ssh connection establish timeout") } const ( diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index e6803e89f22..9f72bdcf545 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -16,6 +16,7 @@ package ssh import ( "context" + "time" "github.com/spf13/pflag" "golang.org/x/crypto/ssh" @@ -26,10 +27,21 @@ import ( "github.com/labring/sealos/pkg/utils/logger" ) -var defaultMaxRetry = 5 +var ( + defaultMaxRetry = 5 + defaultExecutionTimeout = 300 * time.Second +) func RegisterFlags(fs *pflag.FlagSet) { fs.IntVar(&defaultMaxRetry, "max-retry", defaultMaxRetry, "define max num of ssh retry times") + fs.DurationVar(&defaultExecutionTimeout, "execution-timeout", defaultExecutionTimeout, "timeout setting of command execution") +} + +// GetTimeoutContext create a context.Context with default timeout +// default execution timeout in sealos is just fine, if you want to customize the timeout setting, +// you must invoke the `RegisterFlags` function above. +func GetTimeoutContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), defaultExecutionTimeout) } type Interface interface { diff --git a/pkg/ssh/sshcmd.go b/pkg/ssh/sshcmd.go index 2e645719041..ad28425db01 100644 --- a/pkg/ssh/sshcmd.go +++ b/pkg/ssh/sshcmd.go @@ -91,7 +91,7 @@ func (c *Client) CmdAsyncWithContext(ctx context.Context, host string, cmds ...s }() select { case <-ctx.Done(): - return nil + return ctx.Err() case err = <-errCh: return err } @@ -99,7 +99,9 @@ func (c *Client) CmdAsyncWithContext(ctx context.Context, host string, cmds ...s // CmdAsync not actually asynchronously, just print output asynchronously func (c *Client) CmdAsync(host string, cmds ...string) error { - return c.CmdAsyncWithContext(context.Background(), host, cmds...) + ctx, cancel := GetTimeoutContext() + defer cancel() + return c.CmdAsyncWithContext(ctx, host, cmds...) } func (c *Client) Cmd(host, cmd string) ([]byte, error) {