/
sessionwork.go
115 lines (103 loc) · 2.5 KB
/
sessionwork.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package pssh
import (
"context"
"errors"
"fmt"
"io"
"strings"
"golang.org/x/crypto/ssh"
)
// sess lists the session operations that run relies on: executing a command,
// reading its two output streams and releasing the session afterwards.
// NOTE(review): presumably satisfied by *ssh.Session in production — confirm.
type sess interface {
	// Run executes cmd and reports how it ended.
	Run(cmd string) error

	// StdoutPipe and StderrPipe expose the command's output streams.
	StdoutPipe() (io.Reader, error)
	StderrPipe() (io.Reader, error)

	// Close releases the session.
	Close() error
}
// sessionWork describes one command execution on one session. The embedded
// *input supplies the command, stdin and results fields that the methods of
// this type reach through promotion.
type sessionWork struct {
// id is passed to con.newResult, together with con.id, when a result is
// created.
id int
*input
// con is the owning *conWork; newResult reads con.id and calls
// con.newResult.
con *conWork
// runner is invoked by worker with the result and the newly opened session.
// NOTE(review): presumably set to (*sessionWork).run outside of tests —
// confirm at the construction site.
runner func(ctx context.Context, res *result, session sess)
}
// newResult returns whatever con.newResult produces for the pair
// (connection id, session id).
func (s *sessionWork) newResult() *result {
	owner := s.con
	return owner.newResult(owner.id, s.id)
}
// getPipe calls pipe and returns its reader and error unchanged.
//
// When pipe fails, an error of the form "cannot open <name>Pipe: <cause>" is
// first published through s.result using res, so callers only need to stop
// on a non-nil error.
func (s *sessionWork) getPipe(ctx context.Context, pipe func() (io.Reader, error), res *result, name string) (io.Reader, error) {
	reader, pipeErr := pipe()
	if pipeErr == nil {
		return reader, nil
	}

	s.result(ctx, fmt.Errorf("cannot open %sPipe: %v", name, pipeErr), res)
	return reader, pipeErr
}
// result records err on res and then hands res to errResult, which sends it
// on the results channel or gives up once ctx is done.
func (s *sessionWork) result(ctx context.Context, err error, res *result) {
res.err = err
s.errResult(ctx, res)
}
// sessErr pairs an error with the label that getAllError puts in front of its
// message.
type sessErr struct {
// name is concatenated directly before err.Error(); it may be empty.
name string
// err is the error for this slot; getAllError skips entries where it is nil.
err error
}
// run executes s.command on session and publishes the outcome through res.
//
// Steps, in order:
//  1. Open the stdout and stderr pipes; if either fails, getPipe has already
//     published the failure and run returns.
//  2. Start one readStream goroutine per pipe, each with its own buffered
//     error channel.
//  3. Run the command. An *ssh.ExitError sets res.code to its exit status;
//     any other error is reported under the "I/O err:" label.
//  4. Collect both stream errors with getErr, merge everything with
//     getAllError into res.err and publish res via errResult.
//
// The session is closed when run returns.
func (s *sessionWork) run(ctx context.Context, res *result, session sess) {
	// nolint: errcheck,gosec
	defer session.Close()

	outPipe, err := s.getPipe(ctx, session.StdoutPipe, res, "stdout")
	if err != nil {
		return
	}
	errPipe, err := s.getPipe(ctx, session.StderrPipe, res, "stderr")
	if err != nil {
		return
	}

	// Positions inside report; they fix the order of the merged message.
	const (
		stdoutSlot = iota
		stderrSlot
		exitSlot
		ioSlot
	)

	outDone := make(chan error, 1)
	errDone := make(chan error, 1)
	report := []sessErr{
		stdoutSlot: {name: "stdoutStream err:"},
		stderrSlot: {name: "stderrStream err:"},
		exitSlot:   {name: ""},
		ioSlot:     {name: "I/O err:"},
	}

	go readStream(ctx, res.stdout, outPipe, outDone)
	go readStream(ctx, res.stderr, errPipe, errDone)

	switch runErr := session.Run(s.command).(type) {
	case nil:
		// Command finished without error; nothing to record.
	case *ssh.ExitError:
		report[exitSlot].err = runErr
		res.code = runErr.ExitStatus()
	default:
		report[ioSlot].err = runErr
	}

	// Stream errors are gathered only after Run has returned, stdout first.
	report[stdoutSlot].err = getErr(ctx, outDone)
	report[stderrSlot].err = getErr(ctx, errDone)

	res.err = getAllError(report)
	s.errResult(ctx, res)
}
// getAllError merges every non-nil error in errs into a single error.
//
// Each contributing entry is rendered as name immediately followed by the
// error's message, and entries are separated by a newline in slice order.
// It returns nil when no entry carries an error.
func getAllError(errs []sessErr) error {
	var msg strings.Builder
	found := false

	for i := range errs {
		entry := errs[i]
		if entry.err == nil {
			continue
		}
		if found {
			msg.WriteByte('\n')
		}
		msg.WriteString(entry.name)
		msg.WriteString(entry.err.Error())
		// Tracked separately from msg.Len(): an entry with an empty name and
		// an empty message must still yield a non-nil error.
		found = true
	}

	if !found {
		return nil
	}
	return errors.New(msg.String())
}
// worker creates a result, opens a session on conn, points the session's
// Stdin at a reader over s.stdin and passes everything to s.runner.
//
// If the session cannot be opened, the failure is published through s.result
// and the runner is not invoked.
func (s *sessionWork) worker(ctx context.Context, conn client) {
	out := s.newResult()

	sshSession, openErr := conn.NewSession()
	if openErr != nil {
		failure := fmt.Errorf("cannot open new session: %v", openErr)
		s.result(ctx, failure, out)
		return
	}

	// nolint: errcheck
	sshSession.Stdin = strings.NewReader(s.stdin)
	s.runner(ctx, out, sshSession)
}
// errResult sends res on s.results, blocking until either the send goes
// through or ctx is done, whichever happens first.
func (s *sessionWork) errResult(ctx context.Context, res *result) {
	select {
	case s.results <- res:
		// Delivered.
	case <-ctx.Done():
		// Context finished before the send went through; res is dropped.
	}
}