/
sshstub.go
118 lines (100 loc) · 3.24 KB
/
sshstub.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package main
import (
"bytes"
"io"
"net"
"net/http"
"os"
"time"
"github.com/edgenesis/shifu/pkg/logger"
"golang.org/x/crypto/ssh"
)
// Driver configuration, read from the environment once at package
// initialisation. Empty values are checked, and where possible defaulted, in init.
var (
	// sshUser is the account name used to authenticate to the SSH server.
	sshUser = os.Getenv("EDGEDEVICE_DRIVER_SSH_USER")
	// privateSSHKeyFile is the path of the private key file used for
	// public-key authentication.
	privateSSHKeyFile = os.Getenv("EDGEDEVICE_DRIVER_SSH_KEY_PATH")
	// driverHTTPPort is the port, on the SSH server's localhost, where the
	// driver's HTTP endpoint is exposed.
	driverHTTPPort = os.Getenv("EDGEDEVICE_DRIVER_HTTP_PORT")
	// sshExecTimeoutSecond is passed verbatim as the duration argument of the
	// `timeout` command that wraps every executed command.
	sshExecTimeoutSecond = os.Getenv("EDGEDEVICE_DRIVER_EXEC_TIMEOUT_SECOND")
)
// init validates the environment-derived configuration. A missing SSH key
// path is reported with logger.Fatalf. Every other setting falls back to a
// built-in default when empty, and the chosen default is logged.
func init() {
	if privateSSHKeyFile == "" {
		logger.Fatalf("SSH Keyfile needs to be specified")
	}

	// Optional settings, checked in this order.
	optional := []struct {
		setting  *string
		fallback string
		message  string
	}{
		{&driverHTTPPort, "11112", "No HTTP Port specified for driver, default to %v"},
		{&sshExecTimeoutSecond, "5", "No SSH exec timeout specified for driver, default to %v seconds"},
		{&sshUser, "root", "No SSH user specified for driver, default to %v"},
	}
	for _, opt := range optional {
		if *opt.setting != "" {
			continue
		}
		*opt.setting = opt.fallback
		logger.Infof(opt.message, *opt.setting)
	}
}
// main connects to the SSH server on localhost:22 with public-key
// authentication. It then asks that server to listen on
// "localhost:<driverHTTPPort>" and forward connections back over the SSH
// connection. Finally it serves HTTP command-execution requests on that
// forwarded listener. Any setup failure is reported with logger.Fatalf.
func main() {
	key, err := os.ReadFile(privateSSHKeyFile)
	if err != nil {
		logger.Fatalf("unable to read private key: %v", err)
	}

	signer, err := ssh.ParsePrivateKey(key)
	if err != nil {
		logger.Fatalf("unable to parse private key: %v", err)
	}

	config := &ssh.ClientConfig{
		User: sshUser,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
		// NOTE(security): every host key is accepted (equivalent to
		// ssh.InsecureIgnoreHostKey). The target is the local sshd, but any
		// process able to bind localhost:22 would be trusted with the key.
		HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
		Timeout:         time.Minute,
	}

	sshClient, err := ssh.Dial("tcp", "localhost:22", config)
	if err != nil {
		logger.Fatalf("unable to connect: %v", err)
	}
	defer sshClient.Close()
	logger.Infof("Driver SSH established")

	// Listen asks the remote peer to open the listening socket. Accepted
	// connections arrive over the SSH connection.
	sshListener, err := sshClient.Listen("tcp", "localhost:"+driverHTTPPort)
	if err != nil {
		logger.Fatalf("unable to register tcp forward: %v", err)
	}
	defer sshListener.Close()
	logger.Infof("Driver HTTP listener established")

	server := &http.Server{
		Handler: httpCmdlinePostHandler(sshClient),
		// Bound how long a client may take to send request headers, so a
		// slow or stalled client cannot pin a connection indefinitely. The
		// body read and the command run are not affected by this limit.
		ReadHeaderTimeout: 10 * time.Second,
	}
	// Serve blocks until the listener fails and always returns a non-nil error.
	err = server.Serve(sshListener)
	if err != nil {
		logger.Errorf("cannot start server, error: %v", err)
	}
}
// httpCmdlinePostHandler returns an HTTP handler that runs the request body
// as a command on the SSH server, through sshConnection.
//
// The body is prefixed with "timeout <sshExecTimeoutSecond> " and run in a
// fresh SSH session. The handler replies as follows:
//   - 200 with the command's stdout when the command succeeds;
//   - 400 with stderr followed by stdout when the run returns an error;
//   - 400 if the request body cannot be read;
//   - 500 if an SSH session cannot be opened.
//
// Per-request failures are reported to the client instead of terminating the
// process.
//
// SECURITY: the body is executed verbatim on the remote host, so anyone able
// to reach this endpoint can run arbitrary commands as sshUser.
func httpCmdlinePostHandler(sshConnection *ssh.Client) http.HandlerFunc {
	return func(resp http.ResponseWriter, req *http.Request) {
		// Read the command first so a bad request does not cost an SSH session.
		httpCommand, err := io.ReadAll(req.Body)
		if err != nil {
			logger.Errorf("Failed to read request body: %v", err)
			http.Error(resp, "failed to read request body", http.StatusBadRequest)
			return
		}

		// Do not exit the whole driver on a per-request failure (e.g. the
		// server refusing another session); report it to this caller only.
		session, err := sshConnection.NewSession()
		if err != nil {
			logger.Errorf("Failed to create session: %v", err)
			http.Error(resp, "failed to create SSH session", http.StatusInternalServerError)
			return
		}
		defer session.Close()

		cmdString := "timeout " + sshExecTimeoutSecond + " " + string(httpCommand)
		logger.Infof("running command: %v", cmdString)

		var stdout, stderr bytes.Buffer
		session.Stdout = &stdout
		session.Stderr = &stderr
		if err := session.Run(cmdString); err != nil {
			logger.Errorf("Failed to run cmd: %v, error: %v\n stderr: %v \n stdout: %v", cmdString, err, stderr.String(), stdout.String())
			resp.WriteHeader(http.StatusBadRequest)
			// The status is already sent, so a write error cannot be reported
			// to the client; it is deliberately ignored.
			_, _ = resp.Write(append(stderr.Bytes(), stdout.Bytes()...))
			return
		}

		logger.Infof("cmd: %v success", cmdString)
		resp.WriteHeader(http.StatusOK)
		_, _ = resp.Write(stdout.Bytes())
	}
}