/
shell.go
158 lines (132 loc) · 3.77 KB
/
shell.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package shell
import (
"bufio"
"bytes"
"encoding/base64"
"fmt"
"github.com/kevinoula/beach/log"
"golang.org/x/crypto/ssh"
"io"
"os"
"time"
)
// SSH runs a single interactive SSH session for one user against one host.
// Callers set the exported fields; the unexported fields are populated by
// CreateSession and StartSession.
type SSH struct {
	// Username is the account name used to sign onto the remote host.
	Username string
	// Password is the password used to sign onto the remote host, stored
	// base64-encoded (standard encoding). CreateSession decodes it and trims
	// surrounding whitespace before authenticating.
	Password string
	// Hostname is the remote host to connect to; CreateSession dials it on port 22.
	Hostname string
	// client is the SSH client connection to the remote host, set by CreateSession.
	client *ssh.Client
	// session is the SSH session opened on client that carries user input and remote output.
	session *ssh.Session
	// stdin is the session's standard-input pipe; bytes written to it are sent to the remote shell.
	stdin io.WriteCloser
	// stdout is the session's standard-output pipe carrying the remote shell's output.
	stdout io.Reader
	// stderr is the session's standard-error pipe carrying the remote shell's error output.
	stderr io.Reader
}
// CreateSession dials s.Hostname on port 22 as s.Username and opens a new SSH
// session, storing the resulting client and session on s.
//
// s.Password must be base64 (standard encoding); it is decoded and surrounding
// whitespace is trimmed before password authentication. The TCP dial uses a
// 10-second timeout.
//
// It returns an error if the password cannot be decoded, the connection or
// authentication fails, or the session cannot be opened. On any error, no
// connection is left open and s is not modified.
//
// Host keys are NOT verified (ssh.InsecureIgnoreHostKey), so the connection is
// exposed to man-in-the-middle attacks.
func (s *SSH) CreateSession() error {
	decodedPass, err := base64.StdEncoding.DecodeString(s.Password)
	if err != nil {
		// Previously ignored: a malformed value silently produced a partial or
		// empty password and a confusing authentication failure later on.
		return fmt.Errorf("decoding password: %w", err)
	}
	trimmedPass := bytes.TrimSpace(decodedPass)

	sshConfig := &ssh.ClientConfig{
		User: s.Username,
		Auth: []ssh.AuthMethod{ssh.Password(string(trimmedPass))},
		// TODO validate host key
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         10 * time.Second,
	}

	// NOTE(review): bare IPv6 literals must be given in brackets ("[::1]"),
	// since the port is appended by simple concatenation.
	client, err := ssh.Dial("tcp", s.Hostname+":22", sshConfig)
	if err != nil {
		return err
	}
	session, err := client.NewSession()
	if err != nil {
		// Don't leak the connection when the session can't be opened.
		_ = client.Close()
		return err
	}
	s.client = client
	s.session = session
	return nil
}
// StartSession connects to the remote host (via CreateSession) and runs an
// interactive loop. Each line typed on local stdin is forwarded to a remote
// shell, and the remote stdout and stderr are printed locally line by line.
//
// It returns nil when the user types "exit" or local stdin reaches EOF. It
// returns an error in any of these cases:
//   - the session cannot be created;
//   - its pipes cannot be attached;
//   - the remote shell cannot be started;
//   - local stdin cannot be read;
//   - input cannot be written to the remote shell.
//
// Once connected, the session and client are closed before returning.
func (s *SSH) StartSession() error {
	log.Info.Printf("Attempting to SSH into %s@%s...\n", s.Username, s.Hostname)
	if err := s.CreateSession(); err != nil {
		return fmt.Errorf("creating session resulted in %w", err)
	}

	// Capture what CreateSession opened so cleanup closes exactly that:
	// the session first, then the client connection underneath it.
	client, session := s.client, s.session
	defer func() {
		_ = session.Close()
		_ = client.Close()
	}()

	// Pipes must be attached before Shell() starts the remote command.
	var err error
	if s.stdin, err = session.StdinPipe(); err != nil {
		return fmt.Errorf("connecting stdin to pipe resulted in %w", err)
	}
	if s.stdout, err = session.StdoutPipe(); err != nil {
		return fmt.Errorf("connecting stdout to pipe resulted in %w", err)
	}
	if s.stderr, err = session.StderrPipe(); err != nil {
		return fmt.Errorf("connecting stderr to pipe resulted in %w", err)
	}

	// Print remote stdout line by line. The goroutine ends on EOF (e.g. after
	// the deferred session.Close) or on the first read error. A bufio.Scanner
	// cannot recover from an error, so retrying would spin forever.
	go func() {
		out := bufio.NewScanner(s.stdout)
		for out.Scan() {
			fmt.Println(out.Text())
		}
		if err := out.Err(); err != nil {
			log.Err.Printf("error scanning: %v\n", err)
			return
		}
		log.Err.Println("io.EOF")
	}()

	// Print remote stderr line by line until the pipe closes.
	go func() {
		errOut := bufio.NewScanner(s.stderr)
		for errOut.Scan() {
			fmt.Println(errOut.Text())
		}
	}()

	// Open SSH session
	if err := session.Shell(); err != nil {
		return fmt.Errorf("starting remote shell resulted in %w", err)
	}

	// Use one scanner for the whole loop. A bufio.Scanner reads ahead, so
	// creating a new one per prompt would discard input already buffered.
	input := bufio.NewScanner(os.Stdin)
	for {
		fmt.Printf("%s@%s $ ", s.Username, s.Hostname)
		if !input.Scan() {
			// EOF on local stdin (Ctrl-D, closed pipe) ends the session
			// cleanly; a genuine read error is reported to the caller.
			if err := input.Err(); err != nil {
				return fmt.Errorf("reading local stdin resulted in %w", err)
			}
			return nil
		}
		text := input.Text()
		if text == "exit" {
			return nil
		}
		// A failed write means the remote side is gone; stop instead of
		// continuing to prompt for input that can never be delivered.
		if _, err := s.stdin.Write([]byte(text + "\n")); err != nil {
			return fmt.Errorf("writing to stdin resulted in %w", err)
		}
		time.Sleep(time.Second) // short input delay to allow output to populate
	}
}