Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper ssh knownhosts hostkey checking #75

Merged
merged 2 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions pkg/ssh/hostkey/callbacks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Package hostkey implements a callback for the ssh.ClientConfig.HostKeyCallback
package hostkey

import (
"encoding/base64"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)

var (
// InsecureIgnoreHostKeyCallback is an insecure HostKeyCallback that accepts any host key.
InsecureIgnoreHostKeyCallback = ssh.InsecureIgnoreHostKey() //nolint:gosec

// ErrHostKeyMismatch is returned when the host key does not match the host key or a key in known_hosts file
ErrHostKeyMismatch = errors.New("host key mismatch")

// ErrInvalidPath is returned for unusable file paths
ErrInvalidPath = errors.New("invalid path")

// DefaultKnownHostsPath is the default path to the known_hosts file - make sure to homedir-expand it
DefaultKnownHostsPath = "~/.ssh/known_hosts2"

mu sync.Mutex
)

// StaticKeyCallback returns a HostKeyCallback that checks the host key against a given host key
func StaticKeyCallback(trustedKey string) ssh.HostKeyCallback {
return func(_ string, _ net.Addr, k ssh.PublicKey) error {
ks := keyString(k)
if trustedKey != ks {
return ErrHostKeyMismatch
}

return nil
}
}

// KnownHostsPathFromEnv returns the path to a known_hosts file from the environment variable SSH_KNOWN_HOSTS
var KnownHostsPathFromEnv = func() (string, bool) {
return os.LookupEnv("SSH_KNOWN_HOSTS")
}

// KnownHostsFileCallback returns a HostKeyCallback that uses a known hosts file to verify host keys
func KnownHostsFileCallback(path string) (ssh.HostKeyCallback, error) {
if path == "/dev/null" {
return InsecureIgnoreHostKeyCallback, nil
}

mu.Lock()
defer mu.Unlock()

if err := ensureFile(path); err != nil {
return nil, err
}

hkc, err := knownhosts.New(path)
if err != nil {
return nil, fmt.Errorf("create knownhosts callback: %w", err)
}

return wrapCallback(hkc, path), nil
}

// extends a knownhosts callback to not return an error when the key
// is not found in the known_hosts file but instead adds it to the file as new
// entry
func wrapCallback(hkc ssh.HostKeyCallback, path string) ssh.HostKeyCallback {
return ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error {
mu.Lock()
defer mu.Unlock()
err := hkc(hostname, remote, key)
if err == nil {
return nil
}

var keyErr *knownhosts.KeyError
if !errors.As(err, &keyErr) || len(keyErr.Want) > 0 {
// keyErr.Want is empty if the host key is not in the known_hosts file
// non-empty is a mismatch
return fmt.Errorf("%w: %s", ErrHostKeyMismatch, err)
}

dbFile, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o600)
if err != nil {
return fmt.Errorf("failed to open ssh known_hosts file %s for writing: %w", path, err)
}

knownHostsEntry := knownhosts.Normalize(remote.String())
row := knownhosts.Line([]string{knownHostsEntry}, key)
row = fmt.Sprintf("%s\n", strings.TrimSpace(row))

if _, err := dbFile.WriteString(row); err != nil {
return fmt.Errorf("failed to write to known hosts file %s: %w", path, err)
}
if err := dbFile.Close(); err != nil {
return fmt.Errorf("failed to close known_hosts file after writing: %w", err)
}
return nil
})
}

func fileExists(path string) bool {
stat, err := os.Stat(path)
return err == nil && stat.Mode().IsRegular()
}

func ensureDir(path string) error {
stat, err := os.Stat(path)
if err == nil && !stat.Mode().IsDir() {
return fmt.Errorf("%w: path %s is not a directory", ErrInvalidPath, path)
}
if err := os.MkdirAll(path, 0o700); err != nil {
return fmt.Errorf("failed to create directory %s: %w", path, err)
}
return nil
}

func ensureFile(path string) error {
if fileExists(path) {
return nil
}
if err := ensureDir(filepath.Dir(path)); err != nil {
return err
}
f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND, 0o600)
if err != nil {
return fmt.Errorf("failed to create known_hosts file: %w", err)
}
if err := f.Close(); err != nil {
return fmt.Errorf("failed to close known_hosts file: %w", err)
}
return nil
}

// create human-readable SSH-key strings e.g. "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTY...."
func keyString(k ssh.PublicKey) string {
return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
}
58 changes: 56 additions & 2 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
"github.com/acarl005/stripansi"
"github.com/alessio/shellescape"
"github.com/creasty/defaults"
"github.com/google/shlex"
"github.com/k0sproject/rig/exec"
"github.com/k0sproject/rig/log"
"github.com/k0sproject/rig/pkg/ssh/hostkey"
ps "github.com/k0sproject/rig/powershell"
"github.com/kevinburke/ssh_config"
ssh "golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -56,6 +58,7 @@ var (
defaultKeypaths = []string{"~/.ssh/id_rsa", "~/.ssh/identity", "~/.ssh/id_dsa"}
dummyhostKeyPaths []string
globalOnce sync.Once
knownHostsMU sync.Mutex

// ErrNoSignerFound is returned when no signer is found for a key
ErrNoSignerFound = errors.New("no signer found for key")
Expand Down Expand Up @@ -255,8 +258,59 @@ func (c *SSH) IsWindows() bool {
return c.isWindows
}

func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { //nolint:unparam
return ssh.InsecureIgnoreHostKey(), nil //nolint:gosec
func knownhostsCallback(path string) (ssh.HostKeyCallback, error) {
cb, err := hostkey.KnownHostsFileCallback(path)
if err != nil {
return nil, fmt.Errorf("create host key validator: %w", err)
}
return cb, nil
}

func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) {
if c.HostKey != "" {
log.Debugf("%s: using host key from config", c)
return hostkey.StaticKeyCallback(c.HostKey), nil
}

knownHostsMU.Lock()
defer knownHostsMU.Unlock()

if path, ok := hostkey.KnownHostsPathFromEnv(); ok {
if path == "" {
return hostkey.InsecureIgnoreHostKeyCallback, nil
}
log.Tracef("%s: using known_hosts file from SSH_KNOWN_HOSTS: %s", c, path)
return knownhostsCallback(path)
}

var khPath string

// Ask ssh_config for a known hosts file
kfs := SSHConfigGetAll(c.Address, "UserKnownHostsFile")
// splitting the result as for some reason ssh_config sometimes seems to
// return a single string containing space separated paths
if files, err := shlex.Split(strings.Join(kfs, " ")); err == nil {
for _, f := range files {
exp, err := expandAndValidatePath(f)
khPath = exp
if err == nil {
break
}
}
}

if khPath != "" {
log.Tracef("%s: using known_hosts file from ssh config %s", c, khPath)
return knownhostsCallback(khPath)
}

log.Tracef("%s: using default known_hosts file %s", c, hostkey.DefaultKnownHostsPath)
defaultPath, err := expandPath(hostkey.DefaultKnownHostsPath)
if err != nil {
return nil, err
}

return knownhostsCallback(defaultPath)
}

func (c *SSH) clientConfig() (*ssh.ClientConfig, error) {
Expand Down