// Command libvirt-sshd is an SSH server that authenticates each user against
// the libvirt domain definitions in /etc/libvirt/qemu and attaches the session
// to the `virsh console` of the domain named after the SSH user.
package main
import (
"encoding/xml"
"flag"
"fmt"
gossh "golang.org/x/crypto/ssh"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"syscall"
"unsafe"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
)
// release is the version label printed in the usage banner. It defaults to
// "dev" and is set by the build process.
var release = "dev"
// domain holds the subset of a libvirt domain XML document that the server
// needs in order to authenticate a user.
type domain struct {
	// XMLName pins the expected root element to <domain>.
	XMLName xml.Name `xml:"domain"`

	// Name is the text of the <name> element; it is compared against the SSH
	// user name.
	Name string `xml:"name"`

	// Password is the text of the <description> element; it is compared
	// against the password supplied by the SSH client.
	Password string `xml:"description"`
}
// Command-line flags. Each variable is a pointer that holds the flag's default
// until flag.Parse runs in main.
var (
	// verbose turns on extra output while domain files are scanned.
	verbose = flag.Bool("v", false, "Enable verbose logging")

	// hostKeyFile is the path of the private key used as the SSH host key.
	hostKeyFile = flag.String("k", "~/.ssh/id_ed25519", "SSH host key file")

	// bindHost is the <host:port> address the SSH server listens on.
	bindHost = flag.String("l", ":2222", "Listen <host:port>")
)
func handleAuth(ctx ssh.Context, providedPassword string) bool {
log.Printf("New connection from %s user %s password %s\n", ctx.RemoteAddr(), ctx.User(), providedPassword)
files, err := filepath.Glob("/etc/libvirt/qemu/*.xml")
if err != nil {
log.Fatalf("Unable to parse qemu config file glob: %v\n", err)
}
for _, f := range files {
// Read libvirt XML file
xmlFile, err := os.Open(f)
if err != nil {
log.Printf("XML open error: %v\n", err)
}
// Parse libvirt XML file
byteValue, _ := ioutil.ReadAll(xmlFile)
var currentDomain domain
err = xml.Unmarshal(byteValue, ¤tDomain)
if err != nil {
log.Println(err)
return false
}
_ = xmlFile.Close()
if *verbose {
fmt.Printf("Found VM %s password %s\n", currentDomain.Name, currentDomain.Password)
}
if currentDomain.Name == ctx.User() && currentDomain.Password == providedPassword {
return true // Allow access
}
}
return false // If there are no defined VMs, deny access
}
// handleSession is the SSH session handler. It runs `virsh console --safe
// <user>` on a pseudo-terminal and relays data between the SSH session and
// that terminal until virsh's output ends, then waits for virsh to exit.
//
// Sessions that did not request a PTY, and sessions for which virsh cannot be
// started, receive a short message and exit status 1.
func handleSession(s ssh.Session) {
	ptyReq, winCh, isPty := s.Pty() // get SSH PTY information
	if !isPty {
		_, _ = io.WriteString(s, "No PTY requested.\n")
		_ = s.Exit(1)
		return
	}

	cmd := exec.Command("virsh", "console", "--safe", s.User())
	cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))

	f, err := pty.Start(cmd)
	if err != nil {
		// Without this check a failed start leaves f nil and the code below
		// would panic.
		log.Printf("virsh start error: %v\n", err)
		_, _ = io.WriteString(s, "Unable to start console.\n")
		_ = s.Exit(1)
		return
	}
	// Release the PTY when the session ends; otherwise every session leaks a
	// file descriptor.
	defer func() { _ = f.Close() }()

	// Propagate client window-size changes to the PTY.
	go func() {
		for win := range winCh {
			setWinsize(f, win.Width, win.Height)
		}
	}()

	// Session -> PTY (the user's keystrokes).
	go func() {
		if _, err := io.Copy(f, s); err != nil {
			log.Printf("virsh s->f copy error: %v\n", err)
		}
	}()

	// PTY -> session (console output); blocks until the PTY stops producing data.
	if _, err := io.Copy(s, f); err != nil {
		log.Printf("virsh f->s copy error: %v\n", err)
	}

	if err := cmd.Wait(); err != nil {
		log.Printf("virsh wait error: %v\n", err)
	}
}

// setWinsize applies a terminal size of w columns by h rows to the PTY f using
// the TIOCSWINSZ ioctl, logging any failure.
func setWinsize(f *os.File, w, h int) {
	// The pointer conversion must stay inside the Syscall argument list so the
	// value is kept alive for the duration of the call.
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
	if errno != 0 {
		log.Printf("virsh window resize error: %v\n", errno)
	}
}
// main parses the command-line flags, loads the SSH host key and serves SSH
// connections on the configured address until the server stops with an error.
func main() {
	flag.Usage = func() {
		// Write the banner to the same stream flag.PrintDefaults uses so the
		// two halves of the usage text are not split across stdout/stderr.
		fmt.Fprintf(flag.CommandLine.Output(), "Usage for libvirt-sshd (%s) https://github.com/natesales/libvirt-sshd:\n", release)
		flag.PrintDefaults()
	}
	flag.Parse()

	// The flag default starts with "~/", which only a shell would expand.
	keyPath := expandHome(*hostKeyFile)

	pemBytes, err := ioutil.ReadFile(keyPath)
	if err != nil {
		log.Fatalf("Unable to read host key %s: %v\n", keyPath, err)
	}

	signer, err := gossh.ParsePrivateKey(pemBytes)
	if err != nil {
		log.Fatalf("Unable to parse host key %s: %v\n", keyPath, err)
	}

	sshServer := &ssh.Server{
		Addr:            *bindHost,
		HostSigners:     []ssh.Signer{signer},
		Handler:         handleSession,
		PasswordHandler: handleAuth,
	}

	log.Printf("Starting libvirt-sshd on %s\n", *bindHost)
	log.Fatal(sshServer.ListenAndServe())
}

// expandHome replaces a leading "~" or "~/" in path with the current user's
// home directory. Any other path, or any path when the home directory cannot
// be determined, is returned unchanged.
func expandHome(path string) string {
	if path != "~" && !(len(path) > 1 && path[0] == '~' && path[1] == '/') {
		return path
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return path
	}
	return filepath.Join(home, path[1:])
}