-
Notifications
You must be signed in to change notification settings - Fork 491
/
ssh_keycheck.go
108 lines (96 loc) · 3.14 KB
/
ssh_keycheck.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// Copyright 2017 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.
package main
import (
"fmt"
"io"
"net"
"os"
"os/user"
"path"
"strings"
"time"
"github.com/juju/gnuflag"
"github.com/juju/loggo"
"golang.org/x/crypto/ssh"
"github.com/juju/juju/core/network"
jujussh "github.com/juju/juju/network/ssh"
)
// knownHostFilename returns the default known_hosts location for the user
// running this process, i.e. <home>/.ssh/known_hosts.
//
// It panics when the current user cannot be determined.
func knownHostFilename() string {
	current, lookupErr := user.Current()
	if lookupErr == nil {
		return path.Join(current.HomeDir, ".ssh", "known_hosts")
	}
	panic(fmt.Sprintf("unable to find current user: %v", lookupErr))
}
// getKnownHostKeys reads the known_hosts file at fname and returns every
// public key found in it, each re-encoded in AuthorizedKeys format.
//
// Juju reports the files in /etc/ssh/ssh_host_key_*_key.pub, so they are all
// in AuthorizedKey format; the keys parsed here are converted to match.
//
// It panics if the file cannot be opened or read, or if an entry cannot be
// parsed. Blank lines and comments, including trailing ones, are skipped.
func getKnownHostKeys(fname string) []string {
	f, err := os.Open(fname)
	if err != nil {
		panic(fmt.Sprintf("unable to read known-hosts file: %q %v", fname, err))
	}
	defer func() { _ = f.Close() }()
	content, err := io.ReadAll(f)
	if err != nil {
		panic(fmt.Sprintf("failed while reading known-hosts file: %q %v", fname, err))
	}
	pubKeys := make([]string, 0)
	for len(content) > 0 {
		// marker, hosts, pubkey, comment, rest, err
		_, _, pubkey, _, remaining, err := ssh.ParseKnownHosts(content)
		if err == io.EOF {
			// ParseKnownHosts reports io.EOF when the remaining input holds
			// no entries (only blank lines or comments). That is the normal
			// end of the file, not a parse failure.
			break
		}
		if err != nil {
			panic(fmt.Sprintf("failed while parsing known hosts: %q %v", fname, err))
		}
		content = remaining
		// We convert the "known_hosts" format into AuthorizedKeys format to
		// match what Juju records.
		pubKeys = append(pubKeys, string(ssh.MarshalAuthorizedKey(pubkey)))
	}
	return pubKeys
}
// logger is the package-level loggo logger, registered under the name
// "juju.ssh_keyscan".
var logger = loggo.GetLogger("juju.ssh_keyscan")
// main is the command entry point.
//
// It parses the command-line flags, loads the public keys from the
// known-hosts file, turns each positional argument into a host:port address
// (appending ":22" when the argument contains no colon), and hands the
// addresses and keys to the jujussh reachable checker's FindHost. The dial
// address of the host FindHost returns is printed on stdout.
//
// Flags:
//
//	-v              dump debugging information
//	--dial-timeout  time to try a single connection (in milliseconds)
//	--wait-timeout  overall time to wait for answers (in milliseconds)
//	--known-hosts   point to an alternate known-hosts file
//
// If an argument cannot be parsed, or FindHost returns an error, the problem
// is written to stderr and the process exits with status 1 so that callers
// can detect the failure.
func main() {
	var verbose bool
	var hostFile string
	dialTimeout := 500
	waitTimeout := 5000
	gnuflag.BoolVar(&verbose, "v", false, "dump debugging information")
	gnuflag.IntVar(&dialTimeout, "dial-timeout", 500, "time to try a single connection (in milliseconds)")
	gnuflag.IntVar(&waitTimeout, "wait-timeout", 5000, "overall time to wait for answers (in milliseconds)")
	gnuflag.StringVar(&hostFile, "known-hosts", knownHostFilename(), "point to an alternate known-hosts file")
	gnuflag.Parse(true)
	if verbose {
		// Best effort: a failure to raise the log level is not fatal.
		_ = loggo.ConfigureLoggers(`<root>=DEBUG`)
	}
	args := gnuflag.Args()
	pubKeys := getKnownHostKeys(hostFile)
	dialAddresses := make(network.HostPorts, 0, len(args))
	for _, arg := range args {
		if !strings.Contains(arg, ":") {
			// Not valid for IPv6, but good enough for testing
			arg = arg + ":22"
		}
		hp, err := network.ParseMachineHostPort(arg)
		if err != nil {
			_, _ = fmt.Fprintf(os.Stderr, "invalid host:port value: %v\n%v\n", arg, err)
			os.Exit(1)
		}
		dialAddresses = append(dialAddresses, *hp)
	}
	addrs := make([]string, len(dialAddresses))
	for i, addr := range dialAddresses {
		addrs[i] = network.DialAddress(addr)
	}
	logger.Infof("host ports: %v\n", addrs)
	logger.Infof("found %d known hosts\n", len(pubKeys))
	logger.Debugf("known hosts: %v\n", pubKeys)
	dialer := &net.Dialer{Timeout: time.Duration(dialTimeout) * time.Millisecond}
	checker := jujussh.NewReachableChecker(dialer, time.Duration(waitTimeout)*time.Millisecond)
	found, err := checker.FindHost(dialAddresses, pubKeys)
	if err != nil {
		_, _ = fmt.Fprintf(os.Stderr, "could not find valid host: %v\n", err)
		os.Exit(1)
	}
	fmt.Printf("%s\n", network.DialAddress(found))
}