-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
ssh_command.go
347 lines (305 loc) · 9.18 KB
/
ssh_command.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
// Copyright 2017-2019 Authors of Cilium
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package helpers
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"strconv"
"sync"
"time"
"github.com/kevinburke/ssh_config"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
// SSHCommand describes one command invocation together with the environment
// and I/O streams associated with it.
//
// TODO: the type name suggests it only applies to commands run over SSH,
// which is not the case; it should be renamed.
type SSHCommand struct {
	// Path is the command line handed to the remote session for execution.
	// TODO: "Path" is an unclear field name; rename it.
	Path string
	// Env holds environment entries for the command (presumably KEY=VALUE;
	// NOTE(review): not consumed anywhere in this file).
	Env []string
	// Stdin is an input stream for the command (NOTE(review): not consumed
	// by runCommand in this file).
	Stdin io.Reader
	// Stdout and Stderr receive the command's standard output and error.
	Stdout, Stderr io.Writer
}
// SSHClient bundles the settings used to open SSH sessions against a remote
// machine for running tests.
type SSHClient struct {
	// Config is the ssh client configuration information.
	Config *ssh.ClientConfig
	// Host is the IP address or hostname of the target virtual server.
	Host string
	// Port is the port to connect to on the target server.
	Port int

	// client is the underlying SSH connection (it supports shells,
	// subprocesses, TCP port/streamlocal forwarding and tunneled dialing).
	// newSession dials it on first use and reuses it for later sessions.
	client *ssh.Client
}
// GetHostPort returns the "host:port" address of the ssh client. The pair is
// joined with net.JoinHostPort, so IPv6 literals come out bracketed
// (e.g. "[::1]:22").
func (cli *SSHClient) GetHostPort() string {
	portStr := strconv.Itoa(cli.Port)
	return net.JoinHostPort(cli.Host, portStr)
}
// SSHConfig holds the per-host settings of an SSH session, as populated by
// ImportSSHconfig from ssh_config data.
type SSHConfig struct {
	// target is the Host alias, host the resolved Hostname, user the login.
	target, host, user string
	// port is the TCP port to connect to.
	port int
	// identityFile is the path of the private key used to authenticate.
	identityFile string
}
// SSHConfigs indexes the SSH configuration of each host (VM) by its name.
type SSHConfigs map[string]*SSHConfig
// GetSSHClient builds an SSHClient for the host described by cfg. The client
// logs in as cfg.user using the key returned by cfg.GetSSHAgent, skips host
// key verification and uses a 15 second connection timeout.
func (cfg *SSHConfig) GetSSHClient() *SSHClient {
	authMethods := []ssh.AuthMethod{cfg.GetSSHAgent()}
	return &SSHClient{
		Config: &ssh.ClientConfig{
			User:            cfg.user,
			Auth:            authMethods,
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
			Timeout:         15 * time.Second,
		},
		Host: cfg.host,
		Port: cfg.port,
	}
}
// String describes the client's connection target and login user.
func (client *SSHClient) String() string {
	const layout = "host: %s, port: %d, user: %s"
	user := client.Config.User
	return fmt.Sprintf(layout, client.Host, client.Port, user)
}
// String renders cfg for logging, e.g.
// "target: k8s1, host: 127.0.0.1, port: 2222, user: vagrant, identityFile: /path/key".
// Every field is written as "name: value", matching (*SSHClient).String.
func (cfg *SSHConfig) String() string {
	return fmt.Sprintf("target: %s, host: %s, port: %d, user: %s, identityFile: %s",
		cfg.target, cfg.host, cfg.port, cfg.user, cfg.identityFile)
}
// GetSSHAgent returns a public-key ssh.AuthMethod built from the private key
// stored at cfg.identityFile. Despite its name, it does not contact an
// ssh-agent. Failures to read or parse the key are reported through
// log.Fatalf (which presumably terminates the process — confirm against the
// package-level logger).
func (cfg *SSHConfig) GetSSHAgent() ssh.AuthMethod {
	pemBytes, readErr := ioutil.ReadFile(cfg.identityFile)
	if readErr != nil {
		log.Fatalf("unable to retrieve ssh-key on target '%s': %s", cfg.target, readErr)
	}

	signer, parseErr := ssh.ParsePrivateKey(pemBytes)
	if parseErr != nil {
		log.Fatalf("unable to parse private key on target '%s': %s", cfg.target, parseErr)
	}

	return ssh.PublicKeys(signer)
}
// ImportSSHconfig parses config, which holds ssh_config-formatted content
// (not a file path), and returns the settings of every named Host entry.
// Each entry is keyed by its first pattern; the catch-all "*" entry and
// entries without any pattern are skipped.
//
// Returns an error only if config cannot be decoded. Per-host lookups are
// best-effort: a missing Hostname, IdentityFile or User yields an empty
// string, and a missing or non-numeric Port yields 0.
func ImportSSHconfig(config []byte) (SSHConfigs, error) {
	cfg, err := ssh_config.Decode(bytes.NewReader(config))
	if err != nil {
		return nil, err
	}

	result := make(SSHConfigs, len(cfg.Hosts))
	for _, host := range cfg.Hosts {
		// A Host line without patterns cannot be looked up by name; skip it
		// rather than panicking on Patterns[0].
		if len(host.Patterns) == 0 {
			continue
		}
		key := host.Patterns[0].String()
		if key == "*" {
			continue
		}

		// Lookup/parse errors are deliberately ignored: absent values fall
		// back to the zero value (see the function comment).
		hostConfig := &SSHConfig{target: key}
		hostConfig.host, _ = cfg.Get(key, "Hostname")
		hostConfig.identityFile, _ = cfg.Get(key, "identityFile")
		hostConfig.user, _ = cfg.Get(key, "User")
		port, _ := cfg.Get(key, "Port")
		hostConfig.port, _ = strconv.Atoi(port)

		result[key] = hostConfig
	}
	return result, nil
}
// copyWait runs io.Copy(dst, src) in a goroutine and returns a channel that
// receives the copy's error (nil on success) once the copy completes.
//
// The channel has a buffer of one, so the goroutine can always deliver its
// result and exit even if the caller never receives from it (e.g. because it
// returned early on another error). With an unbuffered channel the goroutine
// would block forever on the send and leak.
func copyWait(dst io.Writer, src io.Reader) chan error {
	result := make(chan error, 1)
	go func() {
		_, err := io.Copy(dst, src)
		result <- err
	}()
	return result
}
// runCommand runs cmd.Path on session and streams the remote stdout and
// stderr into cmd.Stdout and cmd.Stderr.
//
// The writers are attached directly to the session rather than read through
// StdoutPipe/StderrPipe. With direct attachment, ssh.Session.Wait (called by
// Run) waits until all output has been copied before it returns, even when
// the command exits with a non-zero status. When pipes were used, a failing
// Run returned right away while background copies were still writing to
// cmd.Stdout/cmd.Stderr. That truncated the output and raced with the caller
// reading it. A nil writer makes the session discard that stream instead of
// panicking.
//
// Returns nil when the command exits with status zero and all output was
// copied. Otherwise it returns the error from session.Run: an *ssh.ExitError
// for a non-zero exit status, or the first copy error.
func runCommand(session *ssh.Session, cmd *SSHCommand) error {
	session.Stdout = cmd.Stdout
	session.Stderr = cmd.Stderr
	return session.Run(cmd.Path)
}
// RunCommand executes cmd on a fresh session obtained from client and closes
// that session before returning. The returned error is nil if the command
// runs, has no problems copying stdin, stdout, and stderr, and exits with a
// zero exit status; session-creation failures are returned as-is.
func (client *SSHClient) RunCommand(cmd *SSHCommand) error {
	session, err := client.newSession()
	if err == nil {
		defer session.Close()
		err = runCommand(session, cmd)
	}
	return err
}
// RunCommandInBackground runs cmd on a new session with a pseudo-terminal
// attached, in a similar way to RunCommandContext. When ctx is cancelled, a
// ^C byte is written to the terminal. Once the command has stopped, the
// session is sent SIGHUP and closed. Unlike RunCommandContext, the command's
// own error is returned instead of the context error. The call itself blocks
// until the command completes.
//
// Panics if ctx is nil.
func (client *SSHClient) RunCommandInBackground(ctx context.Context, cmd *SSHCommand) error {
	if ctx == nil {
		panic("nil context provided to RunCommandInBackground()")
	}
	session, err := client.newSession()
	if err != nil {
		return err
	}
	defer session.Close()

	modes := ssh.TerminalModes{
		ssh.ECHO:          1,     // enable echoing
		ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
		ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
	}
	// The PTY is presumably what lets the ^C byte written on cancellation be
	// treated as an interrupt by the remote side.
	if err := session.RequestPty("xterm-256color", 80, 80, modes); err != nil {
		log.Errorf("Could not request pty: %s", err)
	}

	stdin, err := session.StdinPipe()
	if err != nil {
		log.Errorf("Could not get stdin: %s", err)
	}

	// done is closed once runCommand has returned, i.e. once the command has
	// stopped. The watcher uses it for two things:
	//   - it can exit when ctx is never cancelled, instead of leaking;
	//   - it waits on done rather than calling session.Wait() a second time.
	// A second Wait would compete with the Wait inside session.Run for the
	// session's single exit status, and one of the two would block forever.
	done := make(chan struct{})
	watcherExited := make(chan struct{})
	go func() {
		defer close(watcherExited)
		select {
		case <-ctx.Done():
			if stdin != nil {
				if _, werr := stdin.Write([]byte{3}); werr != nil {
					log.Errorf("write ^C error: %s", werr)
				}
			}
			<-done // wait for the command to stop
			if serr := session.Signal(ssh.SIGHUP); serr != nil {
				log.Errorf("failed to kill command: %s", serr)
			}
			if cerr := session.Close(); cerr != nil {
				log.Errorf("failed to close session: %s", cerr)
			}
		case <-done:
		}
	}()

	err = runCommand(session, cmd)
	close(done)
	<-watcherExited
	return err
}
// RunCommandContext runs cmd on a new session of client in a similar way to
// RunCommand, but watches ctx while the command runs. When ctx is cancelled,
// a ^C byte is written to the command's stdin. Once the command has stopped,
// the session is sent SIGHUP and closed, and ctx.Err() is returned instead of
// the command's error.
//
// NOTE(review): unlike RunCommandInBackground, no PTY is requested here. The
// ^C byte is therefore plain input to the remote command and may not
// interrupt it; in that case cancellation only takes effect once the command
// finishes on its own.
//
// Panics if ctx is nil.
func (client *SSHClient) RunCommandContext(ctx context.Context, cmd *SSHCommand) error {
	if ctx == nil {
		panic("nil context provided to RunCommandContext()")
	}
	session, err := client.newSession()
	if err != nil {
		return err
	}
	defer session.Close()

	stdin, err := session.StdinPipe()
	if err != nil {
		log.Errorf("Could not get stdin %s", err)
	}

	// done is closed once runCommand has returned, i.e. once the command has
	// stopped. The watcher uses it for two things:
	//   - it can exit when ctx is never cancelled, instead of leaking;
	//   - it waits on done rather than calling session.Wait() a second time.
	// A second Wait would compete with the Wait inside session.Run for the
	// session's single exit status and could block forever, deadlocking
	// wg.Wait() below.
	done := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		select {
		case <-ctx.Done():
			if stdin != nil {
				if _, werr := stdin.Write([]byte{3}); werr != nil {
					log.Errorf("write ^C error: %s", werr)
				}
			}
			<-done // wait for the command to stop
			if serr := session.Signal(ssh.SIGHUP); serr != nil {
				log.Errorf("failed to kill command: %s", serr)
			}
			if cerr := session.Close(); cerr != nil {
				log.Errorf("failed to close session: %s", cerr)
			}
		case <-done:
		}
	}()

	err = runCommand(session, cmd)
	close(done)
	// Wait until the watcher is finished with the ssh session.
	wg.Wait()

	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		return err
	}
}
// newSession opens a new session on the client's SSH connection. On first
// use it dials the connection and caches it in client.client, so later
// sessions reuse it. The dial address comes from GetHostPort
// (net.JoinHostPort), so IPv6 literal hosts are bracketed correctly, e.g.
// "[::1]:22".
//
// NOTE(review): client.client is read and written without synchronization,
// so concurrent first calls on the same SSHClient are not safe.
func (client *SSHClient) newSession() (*ssh.Session, error) {
	if client.client == nil {
		connection, err := ssh.Dial("tcp", client.GetHostPort(), client.Config)
		if err != nil {
			return nil, fmt.Errorf("failed to dial: %s", err)
		}
		client.client = connection
	}

	session, err := client.client.NewSession()
	if err != nil {
		return nil, fmt.Errorf("failed to create session: %s", err)
	}
	return session, nil
}
// SSHAgent returns an ssh.AuthMethod that authenticates with the public keys
// held by the ssh-agent listening on the SSH_AUTH_SOCK unix socket.
//
// Returns nil in two cases: SSH_AUTH_SOCK is unset or empty, or the socket
// cannot be dialed. The agent connection is never closed; it stays open for
// as long as the returned method is used.
func SSHAgent() ssh.AuthMethod {
	socket := os.Getenv("SSH_AUTH_SOCK")
	if socket == "" {
		return nil
	}
	conn, err := net.Dial("unix", socket)
	if err != nil {
		return nil
	}
	return ssh.PublicKeysCallback(agent.NewClient(conn).Signers)
}
// GetSSHClient initializes an SSHClient for the specified host/port/user
// combination. It authenticates through the local ssh-agent (see SSHAgent),
// skips host key verification and uses a 15 second connection timeout.
//
// If no agent is reachable, SSHAgent returns nil and no auth method is
// configured, so connecting fails with an authentication error. Adding that
// nil method to Auth instead would make the ssh handshake call a method on a
// nil interface and panic.
func GetSSHClient(host string, port int, user string) *SSHClient {
	var auth []ssh.AuthMethod
	if agentAuth := SSHAgent(); agentAuth != nil {
		auth = append(auth, agentAuth)
	}

	sshConfig := &ssh.ClientConfig{
		User:            user,
		Auth:            auth,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         15 * time.Second,
	}
	return &SSHClient{
		Config: sshConfig,
		Host:   host,
		Port:   port,
	}
}