-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
ssh_command.go
348 lines (302 loc) · 9.59 KB
/
ssh_command.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
// SPDX-License-Identifier: Apache-2.0
// Copyright Authors of Cilium
package helpers
import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"strconv"
"time"
"github.com/kevinburke/ssh_config"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// SSHCommand describes a command to execute together with the streams it is
// wired to.
// TODO: this is poorly named in that it's not related to a command only
// run over SSH - rename this.
type SSHCommand struct {
	// Path is the full command line handed to the remote session.
	// TODO: path is not a clear name - rename to something more clear.
	Path string
	// Env lists environment entries for the command.
	// NOTE(review): not consulted by runCommand in this file.
	Env []string
	// Stdin is the command's input.
	// NOTE(review): not consulted by runCommand in this file.
	Stdin io.Reader
	// Stdout and Stderr receive the command's standard output and error.
	Stdout, Stderr io.Writer
}
// SSHClient carries what is needed to open SSH sessions against a remote
// machine used for running tests.
type SSHClient struct {
	Config *ssh.ClientConfig // SSH client configuration: user, auth methods, timeout.
	Host   string            // IP or hostname of the target server.
	Port   int               // Port to connect to on the target server.

	// client is the connection dialed by newSession on first use and reused
	// for every later session. An ssh.Client supports shells, subprocesses,
	// TCP port/streamlocal forwarding and tunneled dialing.
	client *ssh.Client
}
// GetHostPort formats the client's target as a "host:port" address, with
// IPv6 literals bracketed as required by net.JoinHostPort.
func (cli *SSHClient) GetHostPort() string {
	portStr := strconv.Itoa(cli.Port)
	return net.JoinHostPort(cli.Host, portStr)
}
type (
	// SSHConfig contains the connection metadata of a single Host entry, as
	// imported by ImportSSHconfig.
	SSHConfig struct {
		target       string // Host alias the entry is declared under
		host         string // value of HostName
		user         string // value of User
		port         int    // value of Port (0 when absent or not numeric)
		identityFile string // value of IdentityFile (path to a private key)
	}

	// SSHConfigs maps the name of a host (VM) to its SSHConfig.
	SSHConfigs map[string]*SSHConfig
)
// GetSSHClient builds an SSHClient for cfg's host, port and user. When cfg
// names an identity file, the key loaded from it is the only auth method;
// otherwise no auth methods are configured. Host keys are not verified and
// the connection timeout is 15 seconds.
func (cfg *SSHConfig) GetSSHClient() *SSHClient {
	var auths []ssh.AuthMethod
	if keyAuth := cfg.GetSSHAgent(); keyAuth != nil {
		auths = append(auths, keyAuth)
	}

	client := &SSHClient{Host: cfg.host, Port: cfg.port}
	client.Config = &ssh.ClientConfig{
		User: cfg.user,
		Auth: auths,
		// ssh.InsecureIgnoreHostKey is OK in test code.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // lgtm[go/insecure-hostkeycallback]
		Timeout:         15 * time.Second,
	}
	return client
}
// String describes the client's target endpoint and login user.
func (client *SSHClient) String() string {
	cfg := client.Config
	return fmt.Sprintf("host: %s, port: %d, user: %s", client.Host, client.Port, cfg.User)
}
// String returns a human-readable summary of cfg, one "name: value" pair per
// field, intended for logging.
func (cfg *SSHConfig) String() string {
	// Every field uses the same "name: value" shape. The previous format
	// rendered "port 22, user, root" because of a missing colon and a stray
	// comma.
	return fmt.Sprintf("target: %s, host: %s, port: %d, user: %s, identityFile: %s",
		cfg.target, cfg.host, cfg.port, cfg.user, cfg.identityFile)
}
// GetSSHAgent returns a public-key ssh.AuthMethod built from the private key
// stored at cfg.identityFile. Despite its name, no ssh-agent is involved.
// It returns nil when no identity file is configured. When the key cannot be
// read or parsed it calls log.Fatalf, which presumably terminates the
// process (depends on the package-level logger).
func (cfg *SSHConfig) GetSSHAgent() ssh.AuthMethod {
	keyPath := cfg.identityFile
	if len(keyPath) == 0 {
		return nil
	}

	pemBytes, readErr := os.ReadFile(keyPath)
	if readErr != nil {
		log.Fatalf("unable to retrieve ssh-key on target '%s': %s", cfg.target, readErr)
	}

	signer, parseErr := ssh.ParsePrivateKey(pemBytes)
	if parseErr != nil {
		log.Fatalf("unable to parse private key on target '%s': %s", cfg.target, parseErr)
	}
	return ssh.PublicKeys(signer)
}
// ImportSSHconfig parses config, which holds the contents of an
// ssh_config-style file (not a path to one). It returns one SSHConfig per
// Host block, keyed by that block's first pattern. Blocks whose first
// pattern is the wildcard "*" are skipped. It returns an error if config
// cannot be decoded.
//
// Each field is looked up with cfg.Get using the block's alias. A key that
// has no value yields the field's zero value, and a missing or non-numeric
// Port yields port 0.
func ImportSSHconfig(config []byte) (SSHConfigs, error) {
	cfg, err := ssh_config.Decode(bytes.NewReader(config))
	if err != nil {
		return nil, err
	}

	result := make(SSHConfigs, len(cfg.Hosts))
	for _, host := range cfg.Hosts {
		// Guard the index below rather than panicking on a block without
		// patterns.
		if len(host.Patterns) == 0 {
			continue
		}
		key := host.Patterns[0].String()
		if key == "*" {
			continue
		}

		// Lookup and conversion errors are deliberately ignored: absent
		// values fall back to zero values, as documented above.
		hostConfig := &SSHConfig{target: key}
		hostConfig.host, _ = cfg.Get(key, "Hostname")
		hostConfig.identityFile, _ = cfg.Get(key, "identityFile")
		hostConfig.user, _ = cfg.Get(key, "User")
		port, _ := cfg.Get(key, "Port")
		hostConfig.port, _ = strconv.Atoi(port)

		result[key] = hostConfig
	}
	return result, nil
}
// copyWait streams src into dst on a separate goroutine. It returns a
// channel with capacity 1 that receives exactly one value: the error from
// io.Copy (nil on a clean EOF). Because of the buffer, the goroutine
// finishes even if nobody ever reads the result.
func copyWait(dst io.Writer, src io.Reader) chan error {
	done := make(chan error, 1)
	go func(w io.Writer, r io.Reader) {
		_, copyErr := io.Copy(w, r)
		done <- copyErr
	}(dst, src)
	return done
}
// runCommand runs cmd.Path on the provided SSH session and copies the
// session's stderr and stdout into cmd.Stderr and cmd.Stdout.
//
// Return values:
//   - (false, err) when a pipe cannot be set up, the command cannot be
//     started, or waiting for it fails (e.g. a non-zero exit status).
//   - (true, err) when the command completed successfully but copying
//     stderr (checked first) or stdout failed.
//   - (true, nil) when the command completed successfully and all of its
//     output has been written.
//
// Once the command has started, runCommand does not return until both copy
// goroutines have finished. After it returns, nothing else writes to
// cmd.Stdout or cmd.Stderr, even when the command failed.
func runCommand(session *ssh.Session, cmd *SSHCommand) (bool, error) {
	stderr, err := session.StderrPipe()
	if err != nil {
		return false, fmt.Errorf("Unable to setup stderr for session: %w", err)
	}
	errChan := copyWait(cmd.Stderr, stderr)

	stdout, err := session.StdoutPipe()
	if err != nil {
		return false, fmt.Errorf("Unable to setup stdout for session: %w", err)
	}
	outChan := copyWait(cmd.Stdout, stdout)

	// Start and Wait are called separately instead of Run. If the command
	// never started, the pipes may not reach EOF until the caller closes the
	// session, so draining them here could block forever.
	if err = session.Start(cmd.Path); err != nil {
		return false, err
	}
	waitErr := session.Wait()

	// Drain both copy goroutines even when the command failed. Its output
	// matters most in exactly that case, and returning early would leave the
	// goroutines writing into cmd's writers while the caller reads them.
	// Wait only returns once the session channel is finished, so the pipes
	// hit EOF after their buffered data.
	stderrErr := <-errChan
	stdoutErr := <-outChan

	switch {
	case waitErr != nil:
		return false, waitErr
	case stderrErr != nil:
		return true, stderrErr
	default:
		return true, stdoutErr
	}
}
// RunCommand executes cmd on a new session of client and waits for it to
// finish. The session is closed before returning. The returned error is nil
// only if a session could be opened, the command exited with status zero,
// and its stderr and stdout were copied without error.
func (client *SSHClient) RunCommand(cmd *SSHCommand) error {
	session, sessionErr := client.newSession()
	if sessionErr != nil {
		return sessionErr
	}
	defer session.Close()

	_, runErr := runCommand(session, cmd)
	return runErr
}
// RunCommandInBackground runs an SSH command in a similar way to
// RunCommandContext, but with a context which allows the command to be
// cancelled at any time.
//
// If ctx is cancelled while the command is still running, an interrupt byte
// (^C) is written to the command's pseudo-terminal. If the command has not
// exited within a grace period, SIGHUP is sent and the session is closed.
// In every case the error of the command itself is returned, not the
// context error.
func (client *SSHClient) RunCommandInBackground(ctx context.Context, cmd *SSHCommand) error {
	if ctx == nil {
		panic("nil context provided to RunCommandInBackground()")
	}
	// How long the command gets to react to ^C before the session is torn
	// down forcibly.
	const interruptGracePeriod = 10 * time.Second

	session, err := client.newSession()
	if err != nil {
		return err
	}
	defer session.Close()

	modes := ssh.TerminalModes{
		ssh.ECHO:          1,     // enable echoing
		ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
		ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
	}
	// With a pty, the remote terminal's line discipline typically turns the
	// ^C byte written on cancellation into an interrupt for the command.
	if err := session.RequestPty("xterm-256color", 80, 80, modes); err != nil {
		log.Errorf("Could not request pty: %s", err)
	}
	stdin, err := session.StdinPipe()
	if err != nil {
		// Without stdin, cancellation could not be delivered (and writing to
		// a nil pipe would panic), so don't run the command at all.
		return fmt.Errorf("could not get stdin: %w", err)
	}

	// done is closed once runCommand has returned, so the watcher always
	// exits: either immediately (no cancellation) or after interrupting.
	// The watcher must not call session.Wait itself. runCommand is already
	// waiting, and only one Wait can receive the exit status, so a second
	// Wait blocks forever.
	done := make(chan struct{})
	watcherExited := make(chan struct{})
	go func() {
		defer close(watcherExited)

		select {
		case <-done:
			return // the command finished before any cancellation
		case <-ctx.Done():
		}

		if _, err := stdin.Write([]byte{3}); err != nil {
			log.Errorf("write ^C error: %s", err)
		}

		grace := time.NewTimer(interruptGracePeriod)
		defer grace.Stop()
		select {
		case <-done:
			return // the command exited in response to ^C
		case <-grace.C:
		}

		if err := session.Signal(ssh.SIGHUP); err != nil {
			log.Errorf("failed to kill command: %s", err)
		}
		// Closing the session lets runCommand's pending Wait return.
		if err := session.Close(); err != nil {
			log.Errorf("failed to close session: %s", err)
		}
	}()

	_, err = runCommand(session, cmd)
	close(done)
	<-watcherExited
	return err
}
// RunCommandContext runs an SSH command in a similar way to RunCommand but with
// a context.
//
// If the command finishes first, its error (or the error creating the
// session) is returned and the session is closed. If ctx is done first and a
// session exists, SIGHUP is sent to it and it is closed. In that case
// ctx.Err() is returned.
func (client *SSHClient) RunCommandContext(ctx context.Context, cmd *SSHCommand) error {
	if ctx == nil {
		panic("nil context provided to RunCommandContext()")
	}

	// The worker hands its session over through sessionChan instead of
	// assigning a captured variable. A captured variable would be written by
	// the worker and read on the cancellation path without synchronization,
	// which is a data race. Both channels are buffered so the worker never
	// blocks on them, even after this function has returned.
	var (
		sessionChan    = make(chan *ssh.Session, 1)
		sessionErrChan = make(chan error, 1)
	)

	go func() {
		// This may block depending on the state of the setup tests are being
		// run against. As a result, these goroutines may leak, but the logic
		// below will fail and propagate to the rest of the CI framework, which
		// will error out anyway. It's better to leak in really bad cases since
		// the CI will fail anyway. Unfortunately, the golang SSH library does
		// not provide a way to propagate context through to creating sessions.
		session, sessionErr := client.newSession()
		if sessionErr != nil {
			log.Infof("error creating session: %s", sessionErr)
			sessionErrChan <- sessionErr
			return
		}
		// Publish the session before running, so a cancellation can kill it.
		sessionChan <- session
		_, runErr := runCommand(session, cmd)
		sessionErrChan <- runErr
	}()

	select {
	case asyncErr := <-sessionErrChan:
		// The worker publishes its session (if any) before sending on
		// sessionErrChan, so it is already available here. Release it. Once
		// the command has exited, the remote end has normally closed the
		// channel already, so a Close error carries no useful information.
		select {
		case session := <-sessionChan:
			_ = session.Close()
		default:
		}
		return asyncErr
	case <-ctx.Done():
		select {
		case session := <-sessionChan:
			log.Warning("sending SIGHUP to session due to canceled context")
			if err := session.Signal(ssh.SIGHUP); err != nil {
				log.Errorf("failed to kill command when context is canceled: %s", err)
			}
			if closeErr := session.Close(); closeErr != nil {
				log.WithError(closeErr).Error("failed to close session")
			}
		default:
			log.Error("timeout reached; no session was able to be created")
		}
		return ctx.Err()
	}
}
// newSession opens a new SSH session on the client's connection. On first
// use it dials client's host and port with client.Config and caches the
// connection in client.client for later calls. The cache is read and
// written without synchronization, so concurrent calls on one client are
// not safe.
func (client *SSHClient) newSession() (*ssh.Session, error) {
	conn := client.client
	if conn == nil {
		dialed, dialErr := ssh.Dial("tcp", client.GetHostPort(), client.Config)
		if dialErr != nil {
			return nil, fmt.Errorf("failed to dial: %w", dialErr)
		}
		client.client = dialed
		conn = dialed
	}

	session, sessErr := conn.NewSession()
	if sessErr != nil {
		return nil, fmt.Errorf("failed to create session: %w", sessErr)
	}
	return session, nil
}
// SSHAgent returns an ssh.AuthMethod that signs with the keys held by the
// local ssh-agent, reached through the unix socket named by SSH_AUTH_SOCK.
// It returns nil if that socket cannot be dialed. On success the agent
// connection is kept open for the lifetime of the returned method.
func SSHAgent() ssh.AuthMethod {
	sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
	if err != nil {
		return nil
	}
	return ssh.PublicKeysCallback(agent.NewClient(sock).Signers)
}
// GetSSHClient initializes an SSHClient for the specified host/port/user
// combination. It authenticates through the local ssh-agent when one is
// reachable via SSH_AUTH_SOCK; otherwise no auth methods are configured.
// Host keys are not verified and the connection timeout is 15 seconds.
func GetSSHClient(host string, port int, user string) *SSHClient {
	var auths []ssh.AuthMethod
	// SSHAgent returns nil when the agent socket cannot be dialed. The SSH
	// library calls a method on every ClientConfig.Auth entry during
	// authentication, so a nil entry would panic. Only add a real method.
	if agentAuth := SSHAgent(); agentAuth != nil {
		auths = append(auths, agentAuth)
	}

	sshConfig := &ssh.ClientConfig{
		User: user,
		Auth: auths,
		// ssh.InsecureIgnoreHostKey is OK in test code.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // lgtm[go/insecure-hostkeycallback]
		Timeout:         15 * time.Second,
	}
	return &SSHClient{
		Config: sshConfig,
		Host:   host,
		Port:   port,
	}
}